"""
Custom handler for Vietnamese POS Tagger inference on Hugging Face.

Supports two model formats:
- CRFsuite format (.crfsuite) - loaded with pycrfsuite
- underthesea-core format (.crf) - loaded with underthesea_core
"""
|
|
| import os |
| import re |
| from typing import Dict, List, Any |
|
|
| |
# Optional backend: python-crfsuite, used for ``.crfsuite`` model files.
# The flag records whether the import succeeded so callers can fail with a
# clear message later instead of at import time.
try:
    import pycrfsuite
except ImportError:
    HAS_PYCRFSUITE = False
else:
    HAS_PYCRFSUITE = True
|
|
# Optional backend: underthesea-core, used for ``.crf`` model files.
# The classes are first looked up on the package itself and, if that fails,
# on its ``underthesea_core`` submodule. The flag is False only when both
# import attempts fail.
try:
    from underthesea_core import CRFModel, CRFTagger
except ImportError:
    try:
        from underthesea_core.underthesea_core import CRFModel, CRFTagger
    except ImportError:
        HAS_UNDERTHESEA_CORE = False
    else:
        HAS_UNDERTHESEA_CORE = True
else:
    HAS_UNDERTHESEA_CORE = True
|
|
|
|
class PythonCRFFeaturizer:
    """
    Python implementation of CRFFeaturizer compatible with underthesea_core API.

    A feature template is a string of the form ``T[i]``, ``T[i].attr``,
    ``T[i,j,...]`` or ``T[i,j,...].attr``. The integers are token offsets
    relative to the position being featurized; ``attr`` is an optional
    transformation applied to the token (see ``_apply_attribute``).
    """

    # Compiled once for the class. The pattern is only anchored at the start
    # of the string, so trailing text after a valid template is tolerated.
    _TEMPLATE_RE = re.compile(r'T\[([^\]]+)\](?:\.(\w+))?')

    def __init__(self, feature_templates, dictionary=None):
        """
        Args:
            feature_templates: Iterable of template strings.
            dictionary: Optional container used for ``is_in_dict`` membership
                tests. A falsy value is replaced by an empty set.
        """
        self.feature_templates = feature_templates
        self.dictionary = dictionary or set()
        # template string -> (tuple of offsets or None, attribute or None).
        # Parsing is invariant across tokens and calls, so it is done once
        # per distinct template instead of once per token.
        self._template_cache = {}

    def _parse_template(self, template):
        """
        Split a template into its offsets and optional attribute.

        Returns:
            ``(indices, attribute, template)`` where ``indices`` is a list of
            ints and ``attribute`` is a string or None. Returns
            ``(None, None, None)`` when the template does not match the
            expected syntax or contains a non-integer offset.
        """
        try:
            indices, attribute = self._template_cache[template]
        except KeyError:
            indices, attribute = self._parse_uncached(template)
            self._template_cache[template] = (indices, attribute)
        if indices is None:
            return None, None, None
        # Hand out a fresh list so callers cannot mutate the cached offsets.
        return list(indices), attribute, template

    def _parse_uncached(self, template):
        """Parse ``template`` without consulting the cache."""
        match = self._TEMPLATE_RE.match(template)
        if not match:
            return None, None
        try:
            indices = tuple(int(part.strip()) for part in match.group(1).split(','))
        except ValueError:
            # A non-integer offset (e.g. "T[a]") is treated like any other
            # malformed template: it is reported as unparseable, not raised.
            return None, None
        return indices, match.group(2)

    def _get_token_value(self, tokens, position, index):
        """
        Return the token at ``position + index``.

        Positions before the start yield ``'__BOS__'`` and positions at or
        past the end yield ``'__EOS__'``.
        """
        actual_pos = position + index
        if actual_pos < 0:
            return '__BOS__'
        if actual_pos >= len(tokens):
            return '__EOS__'
        return tokens[actual_pos]

    def _apply_attribute(self, value, attribute):
        """
        Apply the named transformation to a token value.

        Boundary markers (``'__BOS__'`` / ``'__EOS__'``) are returned
        unchanged whatever the attribute is. Boolean predicates are returned
        as the strings ``'True'`` / ``'False'``. Unknown attributes, including
        ``prefix``/``suffix`` followed by a non-numeric length, return the
        value unchanged.
        """
        if value in ('__BOS__', '__EOS__'):
            return value
        if attribute is None:
            return value
        if attribute == 'lower':
            return value.lower()
        if attribute == 'upper':
            return value.upper()
        if attribute == 'istitle':
            return str(value.istitle())
        if attribute == 'isupper':
            return str(value.isupper())
        if attribute == 'islower':
            return str(value.islower())
        if attribute == 'isdigit':
            return str(value.isdigit())
        if attribute == 'isalpha':
            return str(value.isalpha())
        if attribute == 'is_in_dict':
            return str(value in self.dictionary)
        if attribute.startswith(('prefix', 'suffix')):
            # "prefix" and "suffix" are both six characters long; whatever
            # follows is the affix length, defaulting to 2 when absent.
            try:
                n = int(attribute[6:]) if len(attribute) > 6 else 2
            except ValueError:
                return value
            if attribute.startswith('prefix'):
                return value[:n] if len(value) >= n else value
            return value[-n:] if len(value) >= n else value
        return value

    def extract_features(self, tokens, position):
        """
        Build the feature dict for the token at ``position``.

        Args:
            tokens: Sequence of token strings for the sentence.
            position: Index of the token being featurized.

        Returns:
            Dict mapping each parseable template string to its string value.
            Single-offset templates go through ``_apply_attribute``.
            Multi-offset templates join the raw token values with ``'|'``,
            except ``is_in_dict``, which joins them with a space and tests
            membership in the dictionary. Any other attribute on a
            multi-offset template is ignored.
        """
        features = {}
        for template in self.feature_templates:
            indices, attribute, template_str = self._parse_template(template)
            if indices is None:
                continue
            values = [
                self._get_token_value(tokens, position, offset)
                for offset in indices
            ]
            if len(values) == 1:
                features[template_str] = self._apply_attribute(values[0], attribute)
            elif attribute == 'is_in_dict':
                features[template_str] = str(' '.join(values) in self.dictionary)
            else:
                features[template_str] = '|'.join(values)
        return features
|
|
|
|
class EndpointHandler:
    """
    Inference handler that POS-tags whitespace-tokenized text with a CRF model.

    On construction it looks for a model file under ``path`` and loads it
    with the matching backend: ``pycrfsuite`` for ``.crfsuite`` files,
    ``underthesea_core`` for ``.crf`` files.
    """

    def __init__(self, path: str = ""):
        """
        Locate and load the CRF model.

        Args:
            path: Directory containing the model file.

        Raises:
            FileNotFoundError: If none of the candidate model files exists.
            ImportError: If the backend required by the model is not installed.
        """
        self.feature_templates = [
            "T[0]", "T[0].lower", "T[0].istitle", "T[0].isupper",
            "T[0].isdigit", "T[0].isalpha", "T[0].prefix2", "T[0].prefix3",
            "T[0].suffix2", "T[0].suffix3", "T[-1]", "T[-1].lower",
            "T[-1].istitle", "T[-1].isupper", "T[-2]", "T[-2].lower",
            "T[1]", "T[1].lower", "T[1].istitle", "T[1].isupper",
            "T[2]", "T[2].lower", "T[-1,0]", "T[0,1]",
            "T[0].is_in_dict", "T[-1,0].is_in_dict", "T[0,1].is_in_dict",
        ]

        # NOTE: no dictionary is passed, so the ``is_in_dict`` templates are
        # evaluated against the featurizer's default (empty) dictionary.
        self.featurizer = PythonCRFFeaturizer(self.feature_templates)

        # Candidates are probed in order; the first existing file wins.
        model_candidates = [
            (os.path.join(path, "model.crfsuite"), "pycrfsuite"),
            (os.path.join(path, "pos_tagger.crfsuite"), "pycrfsuite"),
            (os.path.join(path, "model.crf"), "underthesea-core"),
        ]
        model_path, model_format = next(
            (
                (candidate, fmt)
                for candidate, fmt in model_candidates
                if os.path.exists(candidate)
            ),
            (None, None),
        )
        if model_path is None:
            raise FileNotFoundError(
                f"No model found. Checked: {[c for c, _ in model_candidates]}"
            )

        self.model_format = model_format
        if model_format == "pycrfsuite":
            if not HAS_PYCRFSUITE:
                raise ImportError("pycrfsuite not installed. Install with: pip install python-crfsuite")
            self.tagger = pycrfsuite.Tagger()
            self.tagger.open(model_path)
        elif model_format == "underthesea-core":
            if not HAS_UNDERTHESEA_CORE:
                raise ImportError("underthesea-core not installed")
            model = CRFModel.load(model_path)
            self.tagger = CRFTagger.from_model(model)

    def _tokenize(self, text: str) -> List[str]:
        """Simple whitespace tokenization."""
        return text.strip().split()

    def _extract_features(self, tokens: List[str]) -> List[List[str]]:
        """
        Extract features for all tokens in a sentence.

        Each token's feature dict is flattened into ``"name=value"`` strings.
        """
        return [
            [
                f"{name}={value}"
                for name, value in self.featurizer.extract_features(tokens, i).items()
            ]
            for i in range(len(tokens))
        ]

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Handle inference requests.

        Args:
            data: Dict with "inputs" key containing text or list of texts.
                The "text" key is used when "inputs" is absent. A ``None``
                value is treated as empty text.

        Returns:
            For a single text (or a one-element list): a list of
            ``{"token": ..., "tag": ...}`` dicts. For a list of several
            texts: a list of such lists, one per text. Empty text and an
            empty list of texts both yield ``[]``.
        """
        inputs = data.get("inputs", data.get("text", ""))
        if inputs is None:
            inputs = ""
        if isinstance(inputs, str):
            inputs = [inputs]

        results = []
        for text in inputs:
            tokens = self._tokenize(text)
            if not tokens:
                # Nothing to tag; skip the model call for this text.
                results.append([])
                continue
            tags = self.tagger.tag(self._extract_features(tokens))
            results.append(
                [{"token": token, "tag": tag} for token, tag in zip(tokens, tags)]
            )

        # An empty batch has no first element to unwrap; return an empty
        # result instead of raising IndexError.
        if not results:
            return []
        # A single result is returned unwrapped, even when the request was a
        # one-element list; kept for backward compatibility.
        return results if len(results) > 1 else results[0]
|
|