Commit 959a166

Add get_words_spell_suggestion (#1157)
1 parent a14b4b5 commit 959a166

4 files changed, +261 −0 lines changed

docs/api/spell.rst

Lines changed: 6 additions & 0 deletions

@@ -19,6 +19,12 @@ correct_sent
 
 The `correct_sent` function is an extension of the `correct` function and is used to correct an entire sentence. It tokenizes the input sentence, corrects each word, and returns the corrected sentence. This is beneficial for proofreading and improving the readability of Thai text.
 
+get_words_spell_suggestion
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autofunction:: get_words_spell_suggestion
+
+The `get_words_spell_suggestion` function retrieves spelling suggestions for one or more input Thai words.
+
 spell
 ~~~~~
 .. autofunction:: spell
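
For reference, a minimal usage sketch of the new function (it mirrors the example in the function's docstring added by this commit; numpy and onnxruntime must be installed, and the model is fetched from the Hugging Face Hub on first call):

    # Sketch based on the docstring example in this commit
    from pythainlp.spell import get_words_spell_suggestion

    # A single word returns one list of up to 5 suggestions
    print(get_words_spell_suggestion("คมดี"))
    # ['คนดีผีคุ้ม', 'มีดคอม้า', 'คดี', 'มีดสองคม', 'มูลคดี']

    # A list of words returns one suggestion list per word
    print(get_words_spell_suggestion(["คมดี", "กระเพาะ"]))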

pythainlp/spell/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -13,6 +13,7 @@
     "correct_sent",
     "spell",
     "spell_sent",
+    "get_words_spell_suggestion",
 ]
 
 from pythainlp.spell.pn import NorvigSpellChecker
@@ -21,3 +22,4 @@
 
 # these imports are placed here to avoid circular imports
 from pythainlp.spell.core import correct, correct_sent, spell, spell_sent
+from pythainlp.spell.words_spelling_correction import get_words_spell_suggestion
pythainlp/spell/words_spelling_correction.py

Lines changed: 248 additions & 0 deletions

@@ -0,0 +1,248 @@
# -*- coding: utf-8 -*-
# SPDX-FileCopyrightText: 2016-2025 PyThaiNLP Project
# SPDX-FileType: SOURCE
# SPDX-License-Identifier: Apache-2.0
import os
from typing import List, Union

from pythainlp.corpus import get_hf_hub


class FastTextEncoder:
    """
    A class to load pre-trained FastText-like word embeddings,
    compute word and sentence vectors, and interact with an ONNX
    model for nearest neighbor suggestions.
    """

    # --- Initialization and Data Loading ---

    def __init__(self, model_dir, nn_model_path, words_list, bucket=2000000, nb_words=2000000, minn=5, maxn=5):
        """
        Initializes the FastTextEncoder, loading embeddings, vocabulary,
        nearest neighbor model, and suggestion words list.

        Args:
            model_dir (str): Directory containing 'embeddings.npy' and 'vocabulary.txt'.
            nn_model_path (str): Path to the ONNX nearest neighbors model.
            words_list (list of str): The list of words offered as suggestions.
            bucket (int): The size of the hash bucket for subword hashing.
            nb_words (int): The number of words in the vocabulary (used as an offset for subword indices).
            minn (int): Minimum character length for subwords.
            maxn (int): Maximum character length for subwords.
        """
        try:
            import numpy as np  # imported lazily to reduce import-time cost
            import onnxruntime  # availability check; re-imported where used
            self.np = np
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "Please install the required packages via 'pip install numpy onnxruntime'."
            )
        self.model_dir = model_dir
        self.nn_model_path = nn_model_path
        self.bucket = bucket
        self.nb_words = nb_words
        self.minn = minn
        self.maxn = maxn

        # Load data and models
        self.vocabulary, self.embeddings = self._load_embeddings()
        self.words_for_suggestion = self._load_suggestion_words(words_list)
        self.nn_session = self._load_onnx_session(nn_model_path)
        self.embedding_dim = self.embeddings.shape[1]

    def _load_embeddings(self):
        """Loads the embeddings matrix and vocabulary list."""
        input_matrix = self.np.load(os.path.join(self.model_dir, "embeddings.npy"))
        words = []
        vocab_path = os.path.join(self.model_dir, "vocabulary.txt")
        with open(vocab_path, "r", encoding="utf-8") as f:
            for line in f:
                words.append(line.rstrip())
        return words, input_matrix

    def _load_suggestion_words(self, words_list):
        """Loads the list of words used for suggestions."""
        return self.np.array(words_list)

    def _load_onnx_session(self, onnx_path):
        """Loads the ONNX inference session."""
        # Note: using providers=["CPUExecutionProvider"] for platform independence
        import onnxruntime as rt
        return rt.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

    # --- Helper Methods for Encoding ---

    def _get_hash(self, subword):
        """Computes the FastText-like FNV-1a hash for a subword."""
        h = 2166136261  # FNV-1a offset basis
        for c in subword:
            c_ord = ord(c) % 2**8
            h = (h ^ c_ord) % 2**32
            h = (h * 16777619) % 2**32  # FNV-1a prime
        return h % self.bucket + self.nb_words

    def _get_subwords(self, word):
        """Extracts subwords and their corresponding indices for a given word."""
        _word = "<" + word + ">"
        _subwords = []
        _subword_ids = []

        # 1. Check for the word in the vocabulary (the full word is the first subword)
        if word in self.vocabulary:
            _subwords.append(word)
            _subword_ids.append(self.vocabulary.index(word))
            if word == "</s>":
                return _subwords, self.np.array(_subword_ids)

        # 2. Extract n-grams (subwords) and get their hash indices
        for ngram_start in range(0, len(_word)):
            for ngram_length in range(self.minn, self.maxn + 1):
                if ngram_start + ngram_length <= len(_word):
                    _candidate_subword = _word[ngram_start:ngram_start + ngram_length]
                    # Only append if not already included (e.g., as the full word)
                    if _candidate_subword not in _subwords:
                        _subwords.append(_candidate_subword)
                        _subword_ids.append(self._get_hash(_candidate_subword))

        return _subwords, self.np.array(_subword_ids)

    def get_word_vector(self, word):
        """Computes the normalized vector for a single word."""
        # _get_subwords()[1] is the array of indices for the word and its subwords
        subword_ids = self._get_subwords(word)[1]

        # Check if the array of subword indices is empty
        if subword_ids.size == 0:
            # Return a zero vector of the embedding dimension if no word/subword is found
            return self.np.zeros(self.embedding_dim)

        # Compute the mean of the embeddings for all subword indices
        vector = self.np.mean([self.embeddings[s] for s in subword_ids], axis=0)

        # Normalize the vector
        norm = self.np.linalg.norm(vector)
        if norm > 0:
            vector /= norm

        return vector

    def _tokenize(self, sentence):
        """Tokenizes a sentence on whitespace, emitting "</s>" for newlines."""
        tokens = []
        word = ""
        for c in sentence:
            if c in [' ', '\n', '\r', '\t', '\v', '\f', '\0']:
                if word:
                    tokens.append(word)
                    word = ""
                if c == '\n':
                    tokens.append("</s>")
            else:
                word += c
        if word:
            tokens.append(word)
        return tokens

    def get_sentence_vector(self, line):
        """Computes the mean vector for a sentence."""
        tokens = self._tokenize(line)
        vectors = []
        for t in tokens:
            # get_word_vector already normalizes, so no need to do it again here
            vec = self.get_word_vector(t)
            vectors.append(vec)

        # If the sentence was empty and produced no vectors, return a zero vector
        if not vectors:
            return self.np.zeros(self.embedding_dim)

        return self.np.mean(vectors, axis=0)

    # --- Nearest Neighbor Method ---

    def get_word_suggestion(self, list_word):
        """
        Queries the ONNX model to find the nearest neighbor word(s)
        for the given word or list of words.

        Args:
            list_word (str or list of str): A single word or a list of words
                to get suggestions for.

        Returns:
            list of str or list of list of str: The nearest neighbor word(s)
            from the pre-loaded suggestion list.
        """
        if isinstance(list_word, str):
            input_words = [list_word]
            return_single = True
        else:
            input_words = list_word
            return_single = False

        # Encode each word as a space-separated character sequence, so that
        # get_sentence_vector() averages one normalized vector per character
        word_input_vecs = [self.get_sentence_vector(' '.join(list(word))) for word in input_words]

        # Convert to a numpy array for ONNX input (must be float32)
        input_data = self.np.array(word_input_vecs, dtype=self.np.float32)

        # Run ONNX inference; the first output holds the neighbor indices
        indices = self.nn_session.run(None, {"X": input_data})[0]

        # Look up suggestions
        suggestions = [self.words_for_suggestion[i].tolist() for i in indices]

        return suggestions[0] if return_single else suggestions


class Words_Spelling_Correction(FastTextEncoder):
    def __init__(self):
        self.model_name = "pythainlp/word-spelling-correction-char2vec"
        self.model_path = get_hf_hub(self.model_name)
        self.model_onnx = get_hf_hub(self.model_name, "nearest_neighbors.onnx")
        with open(
            get_hf_hub(self.model_name, "list_word-spelling-correction-char2vec.txt"),
            encoding="utf-8",
        ) as f:
            self.list_word = [i.strip() for i in f.readlines()]
        super().__init__(self.model_path, self.model_onnx, self.list_word)


_WSC = None


def get_words_spell_suggestion(list_words: Union[str, List[str]]) -> Union[List[str], List[List[str]]]:
    """
    Get spelling suggestions for words.

    The function retrieves spelling suggestions for one or more
    input Thai words.

    Requirements: numpy and onnxruntime (install them before using this function).

    :param Union[str, List[str]] list_words: a word or a list of words.
    :return: spelling suggestions (max 5 items per word)
    :rtype: Union[List[str], List[List[str]]]

    :Example:
    ::

        from pythainlp.spell import get_words_spell_suggestion

        print(get_words_spell_suggestion("คมดี"))
        # output: ['คนดีผีคุ้ม', 'มีดคอม้า', 'คดี', 'มีดสองคม', 'มูลคดี']

        print(get_words_spell_suggestion(["คมดี", "กระเพาะ"]))
        # output: [['คนดีผีคุ้ม', 'มีดคอม้า', 'คดี', 'มีดสองคม', 'มูลคดี'],
        #          ['กระเพาะ', 'กระพา', 'กะเพรา', 'กระเพาะปลา', 'พระประธาน']]
    """
    global _WSC
    if _WSC is None:
        _WSC = Words_Spelling_Correction()
    return _WSC.get_word_suggestion(list_words)
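
As a side note, the subword hashing in this file follows the FastText FNV-1a scheme. Here is a small self-contained sketch of it; the constants mirror the class defaults, while the helper name `fnv1a_bucket` and the sample word are illustrative only:

    # Standalone illustration of the FastText-style FNV-1a subword hashing
    # used by FastTextEncoder._get_hash; parameters mirror the class defaults.

    def fnv1a_bucket(subword: str, bucket: int = 2_000_000, nb_words: int = 2_000_000) -> int:
        h = 2166136261  # FNV-1a offset basis
        for c in subword:
            h = (h ^ (ord(c) % 2**8)) % 2**32
            h = (h * 16777619) % 2**32  # FNV-1a prime
        # Hashed subwords are offset past the nb_words vocabulary rows
        return h % bucket + nb_words

    word = "<" + "คมดี" + ">"  # boundary markers, as in _get_subwords
    # With minn = maxn = 5, every 5-character window is a subword
    ngrams = [word[i:i + 5] for i in range(len(word) - 4)]
    print(ngrams)                             # ['<คมดี', 'คมดี>']
    print([fnv1a_bucket(g) for g in ngrams])  # row indices into embeddings.npy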

tests/extra/testx_spell.py

Lines changed: 5 additions & 0 deletions

@@ -11,6 +11,7 @@
     spell,
     spell_sent,
     symspellpy,
+    get_words_spell_suggestion,
 )
 
 from ..core.test_spell import SENT_TOKS
@@ -66,3 +67,7 @@ def test_correct_sent(self):
             correct_sent(SENT_TOKS, engine="wanchanberta_thai_grammarly")
         )
         self.assertIsNotNone(symspellpy.correct_sent(SENT_TOKS))
+
+    def test_get_words_spell_suggestion(self):
+        self.assertIsNotNone(get_words_spell_suggestion("คมดี"))
+        self.assertIsNotNone(get_words_spell_suggestion(["คมดี", "มะนา"]))
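
To try just the new test locally, something along these lines should work (assuming pytest is available; the test needs numpy, onnxruntime, and network access to fetch the model from the Hugging Face Hub):

    pytest tests/extra/testx_spell.py -k test_get_words_spell_suggestion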
