# -*- coding: utf-8 -*-
# SPDX-FileCopyrightText: 2016-2025 PyThaiNLP Project
# SPDX-FileType: SOURCE
# SPDX-License-Identifier: Apache-2.0
import os
from typing import List, Union

from pythainlp.corpus import get_hf_hub


class FastTextEncoder:
    """
    A class to load pre-trained FastText-like word embeddings,
    compute word and sentence vectors, and interact with an ONNX
    model for nearest neighbor suggestions.
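
    A minimal usage sketch (hypothetical paths; the model directory is
    assumed to contain ``embeddings.npy`` and ``vocabulary.txt``)::

        encoder = FastTextEncoder("./char2vec", "./nearest_neighbors.onnx", ["คดี", "กระเพาะ"])
        vec = encoder.get_word_vector("คดี")            # normalized word vector
        sent_vec = encoder.get_sentence_vector("ค ดี")  # mean of word vectors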
    """

    # --- Initialization and Data Loading ---

    def __init__(self, model_dir, nn_model_path, words_list, bucket=2000000, nb_words=2000000, minn=5, maxn=5):
        """
        Initializes the FastTextEncoder, loading embeddings, vocabulary,
        nearest neighbor model, and suggestion words list.

        Args:
            model_dir (str): Directory containing 'embeddings.npy' and 'vocabulary.txt'.
            nn_model_path (str): Path to the ONNX nearest neighbors model.
            words_list (List[str]): The list of words used for suggestions.
            bucket (int): The size of the hash bucket for subword hashing.
            nb_words (int): The number of words in the vocabulary (used as an offset for subword indices).
            minn (int): Minimum character length for subwords.
            maxn (int): Maximum character length for subwords.
        """
        try:
            # Lazy imports so that the module stays importable without
            # these optional dependencies.
            import numpy as np
            import onnxruntime  # noqa: F401  # availability check only
            self.np = np
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "numpy and onnxruntime are required. "
                "Please install them via 'pip install numpy onnxruntime'."
            )
        except Exception as e:
            raise Exception(f"An unexpected error occurred: {e}")
        self.model_dir = model_dir
        self.nn_model_path = nn_model_path
        self.bucket = bucket
        self.nb_words = nb_words
        self.minn = minn
        self.maxn = maxn

        # Load data and models
        self.vocabulary, self.embeddings = self._load_embeddings()
        self.words_for_suggestion = self._load_suggestion_words(words_list)
        self.nn_session = self._load_onnx_session(nn_model_path)
        self.embedding_dim = self.embeddings.shape[1]

    def _load_embeddings(self):
        """Loads the embeddings matrix and vocabulary list."""
        input_matrix = self.np.load(os.path.join(self.model_dir, "embeddings.npy"))
        words = []
        vocab_path = os.path.join(self.model_dir, "vocabulary.txt")
        with open(vocab_path, "r", encoding="utf-8") as f:
            for line in f:
                words.append(line.rstrip())
        return words, input_matrix

    def _load_suggestion_words(self, words_list):
        """Loads the list of words used for suggestions."""
        words = self.np.array(words_list)
        return words

    def _load_onnx_session(self, onnx_path):
        """Loads the ONNX inference session."""
        # Note: using providers=["CPUExecutionProvider"] for platform independence
        import onnxruntime as rt
        sess = rt.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
        return sess

    # --- Helper Methods for Encoding ---

    def _get_hash(self, subword):
        """Computes the FastText-like hash for a subword."""
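        # FNV-1a-style hash over each character's code point (mod 256),
        # mirroring fastText's subword hashing. The returned id is offset
        # by nb_words so that it indexes the bucket (subword) region of the
        # embedding matrix, which is assumed to store nb_words word vectors
        # followed by `bucket` hashed subword vectors.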
        h = 2166136261  # FNV-1a basis
        for c in subword:
            c_ord = ord(c) % 2**8
            h = (h ^ c_ord) % 2**32
            h = (h * 16777619) % 2**32  # FNV-1a prime
        return h % self.bucket + self.nb_words

    def _get_subwords(self, word):
        """Extracts subwords and their corresponding indices for a given word."""
        _word = "<" + word + ">"
        _subwords = []
        _subword_ids = []

        # 1. Check for the word in vocabulary (full word is the first subword)
        if word in self.vocabulary:
            _subwords.append(word)
            _subword_ids.append(self.vocabulary.index(word))
            if word == "</s>":
                return _subwords, self.np.array(_subword_ids)

        # 2. Extract n-grams (subwords) and get their hash indices
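        # For example, with minn = maxn = 5 the word "abcde" is padded to
        # "<abcde>" and yields the n-grams "<abcd", "abcde" and "bcde>".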
        for ngram_start in range(0, len(_word)):
            for ngram_length in range(self.minn, self.maxn + 1):
                if ngram_start + ngram_length <= len(_word):
                    _candidate_subword = _word[ngram_start:ngram_start + ngram_length]
                    # Only append if not already included (e.g., as the full word)
                    if _candidate_subword not in _subwords:
                        _subwords.append(_candidate_subword)
                        _subword_ids.append(self._get_hash(_candidate_subword))

        return _subwords, self.np.array(_subword_ids)

    def get_word_vector(self, word):
        """Computes the normalized vector for a single word."""
        # _get_subwords() returns (subwords, subword_ids); only the ids are needed here
        subword_ids = self._get_subwords(word)[1]

        # Check if the array of subword indices is empty
        if subword_ids.size == 0:
            # Return a zero vector of the embedding dimension if no word/subword is found.
            return self.np.zeros(self.embedding_dim)

        # Compute the mean of the embeddings for all subword indices
        vector = self.np.mean([self.embeddings[s] for s in subword_ids], axis=0)

        # Normalize the vector
        norm = self.np.linalg.norm(vector)
        if norm > 0:
            vector /= norm

        return vector

    def _tokenize(self, sentence):
        """Tokenizes a sentence based on whitespace."""
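        # A "</s>" end-of-line token is emitted for each newline,
        # e.g. "a b\nc" -> ["a", "b", "</s>", "c"].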
        tokens = []
        word = ""
        for c in sentence:
            if c in [' ', '\n', '\r', '\t', '\v', '\f', '\0']:
                if word:
                    tokens.append(word)
                    word = ""
                if c == '\n':
                    tokens.append("</s>")
            else:
                word += c
        if word:
            tokens.append(word)
        return tokens

    def get_sentence_vector(self, line):
        """Computes the mean vector for a sentence."""
        tokens = self._tokenize(line)
        vectors = []
        for t in tokens:
            # get_word_vector already handles normalization, so no need to do it again here
            vec = self.get_word_vector(t)
            vectors.append(vec)

        # If the sentence was empty and resulted in no vectors, return a zero vector
        if not vectors:
            return self.np.zeros(self.embedding_dim)

        return self.np.mean(vectors, axis=0)

    # --- Nearest Neighbor Method ---

    def get_word_suggestion(self, list_word):
        """
        Queries the ONNX model to find the nearest neighbor word(s)
        for the given word or list of words.

        Args:
            list_word (str or list of str): A single word or a list of words
                                            to get suggestions for.

        Returns:
            list of str or list of list of str: The nearest neighbor words
                from the pre-loaded suggestion list (one list per input word).
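
        Example (a sketch; the suggestions depend on the loaded model)::

            encoder.get_word_suggestion("คมดี")
            # -> ['คนดีผีคุ้ม', 'มีดคอม้า', 'คดี', 'มีดสองคม', 'มูลคดี']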
        """
        if isinstance(list_word, str):
            input_words = [list_word]
            return_single = True
        else:
            input_words = list_word
            return_single = False

        # Encode each input word at the character level: the characters are
        # joined with spaces so that get_sentence_vector() averages one
        # vector per character (char2vec-style encoding).
        word_input_vecs = [self.get_sentence_vector(' '.join(list(word))) for word in input_words]

        # Convert to numpy array for ONNX input (ensure float32)
        input_data = self.np.array(word_input_vecs, dtype=self.np.float32)

        # Run ONNX inference
        indices = self.nn_session.run(None, {"X": input_data})[0]

        # Look up suggestions
        suggestions = [self.words_for_suggestion[i].tolist() for i in indices]

        return suggestions[0] if return_single else suggestions


class Words_Spelling_Correction(FastTextEncoder):
    """Word spelling correction encoder backed by the
    pythainlp/word-spelling-correction-char2vec model from Hugging Face Hub."""

    def __init__(self):
        self.model_name = "pythainlp/word-spelling-correction-char2vec"
        self.model_path = get_hf_hub(self.model_name)
        self.model_onnx = get_hf_hub(self.model_name, "nearest_neighbors.onnx")
        with open(get_hf_hub(self.model_name, "list_word-spelling-correction-char2vec.txt"), encoding="utf-8") as f:
            self.list_word = [i.strip() for i in f]
        super().__init__(self.model_path, self.model_onnx, self.list_word)


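# Module-level singleton: the model is downloaded and loaded on the first
# call to get_words_spell_suggestion() and then reused.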
_WSC = None


def get_words_spell_suggestion(list_words: Union[str, List[str]]) -> Union[List[str], List[List[str]]]:
    """
    Get word spelling suggestions

    The function is designed to retrieve spelling suggestions \
    for one or more input Thai words.

    Requirements: numpy and onnxruntime (install them before using this function)

    :param Union[str, List[str]] list_words: a word or a list of words.
    :return: the spelling suggestions (max 5 items per word)
    :rtype: Union[List[str], List[List[str]]]

    :Example:
    ::

        from pythainlp.spell import get_words_spell_suggestion

        print(get_words_spell_suggestion("คมดี"))
        # output: ['คนดีผีคุ้ม', 'มีดคอม้า', 'คดี', 'มีดสองคม', 'มูลคดี']

        print(get_words_spell_suggestion(["คมดี", "กระเพาะ"]))
        # output: [['คนดีผีคุ้ม', 'มีดคอม้า', 'คดี', 'มีดสองคม', 'มูลคดี'],
        #          ['กระเพาะ', 'กระพา', 'กะเพรา', 'กระเพาะปลา', 'พระประธาน']]
    """
    global _WSC
    if _WSC is None:
        _WSC = Words_Spelling_Correction()
    return _WSC.get_word_suggestion(list_words)