Commit c4a3b3c

highlight anomalies
1 parent 77426a9 commit c4a3b3c

5 files changed: +346 -126 lines changed

notebooks/highlight.ipynb  +189
Large diffs are not rendered by default.

python_autocomplete/dataset/bpe.py  +3 -89

@@ -1,10 +1,11 @@
 from functools import lru_cache
 from heapq import heappush, heappop
-from typing import List, Tuple
+from typing import List
 
 from labml import lab, monit
 from labml.utils.cache import cache_set
-from python_autocomplete.dataset import Tokenizer, ID_CHARS
+from python_autocomplete.dataset import Tokenizer
+from python_autocomplete.dataset.break_words import SourceCodeTokenizer
 
 
 class BPE(Tokenizer):
@@ -142,93 +143,6 @@ def encode(self, word: str):
         return self.encoder.encode([self.char_stoi[c] for c in word if c in self.char_stoi])
 
 
-class WordTokenizer:
-    def collect_words(self, data: str):
-        raise NotImplementedError
-
-    def get_words(self) -> Tuple[List[str], List[int]]:
-        raise NotImplementedError
-
-    def tokenize(self, data: str, *, is_silent: bool = False) -> List[str]:
-        raise NotImplementedError
-
-
-class SourceCodeTokenizer(WordTokenizer):
-    def __init__(self):
-        self.words = {}
-
-    def add_word(self, word):
-        if not word:
-            return
-
-        if word not in self.words:
-            self.words[word] = 1
-        else:
-            self.words[word] += 1
-
-    def tokenize(self, data: str, *, is_silent: bool = False) -> List[str]:
-        last_idx = 0
-        is_id = False
-        res = []
-
-        for i, c in monit.enum('Collect words', data, is_silent=is_silent):
-            if c in ID_CHARS:
-                if not is_id:
-                    if last_idx < i:
-                        res.append(data[last_idx:i])
-                    last_idx = i
-                    is_id = True
-            else:
-                if is_id:
-                    if last_idx < i:
-                        res.append(data[last_idx:i])
-                    last_idx = i
-                    is_id = False
-
-        if last_idx < len(data):
-            res.append(data[last_idx:])
-
-        return res
-
-    def collect_words(self, data: str):
-        last_idx = 0
-        is_id = False
-
-        for i, c in monit.enum('Collect words', data):
-            if c in ID_CHARS:
-                if not is_id:
-                    self.add_word(data[last_idx:i])
-                    last_idx = i
-                    is_id = True
-            else:
-                if is_id:
-                    self.add_word(data[last_idx:i])
-                    last_idx = i
-                    is_id = False
-
-        self.add_word(data[last_idx:])
-
-    def get_words(self):
-        words_list = [(f, w) for w, f in self.words.items()]
-        words_list.sort(key=lambda x: -x[0])
-
-        return [w for _, w in words_list], [f for f, _ in words_list]
-
-
-class NoTokenizer(WordTokenizer):
-    def __init__(self):
-        self.data = ''
-
-    def collect_words(self, data):
-        self.data += data
-
-    def get_words(self):
-        return [self.data], [1]
-
-    def tokenize(self, data: str, *, is_silent: bool = False) -> List[str]:
-        return [data]
-
-
 class BPELearner:
     def __init__(self, words_list: List[str], word_freq: List[int]):
         self.words_list = words_list

python_autocomplete/dataset/break_words.py  +91 (new file)

@@ -0,0 +1,91 @@
+from typing import List, Tuple
+
+from labml import monit
+from python_autocomplete.dataset import ID_CHARS
+
+
+class WordTokenizer:
+    def collect_words(self, data: str):
+        raise NotImplementedError
+
+    def get_words(self) -> Tuple[List[str], List[int]]:
+        raise NotImplementedError
+
+    def tokenize(self, data: str, *, is_silent: bool = False) -> List[str]:
+        raise NotImplementedError
+
+
+class SourceCodeTokenizer(WordTokenizer):
+    def __init__(self):
+        self.words = {}
+
+    def add_word(self, word):
+        if not word:
+            return
+
+        if word not in self.words:
+            self.words[word] = 1
+        else:
+            self.words[word] += 1
+
+    def tokenize(self, data: str, *, is_silent: bool = False) -> List[str]:
+        last_idx = 0
+        is_id = False
+        res = []
+
+        for i, c in monit.enum('Collect words', data, is_silent=is_silent):
+            if c in ID_CHARS:
+                if not is_id:
+                    if last_idx < i:
+                        res.append(data[last_idx:i])
+                    last_idx = i
+                    is_id = True
+            else:
+                if is_id:
+                    if last_idx < i:
+                        res.append(data[last_idx:i])
+                    last_idx = i
+                    is_id = False
+
+        if last_idx < len(data):
+            res.append(data[last_idx:])
+
+        return res
+
+    def collect_words(self, data: str):
+        last_idx = 0
+        is_id = False
+
+        for i, c in monit.enum('Collect words', data):
+            if c in ID_CHARS:
+                if not is_id:
+                    self.add_word(data[last_idx:i])
+                    last_idx = i
+                    is_id = True
+            else:
+                if is_id:
+                    self.add_word(data[last_idx:i])
+                    last_idx = i
+                    is_id = False
+
+        self.add_word(data[last_idx:])
+
+    def get_words(self):
+        words_list = [(f, w) for w, f in self.words.items()]
+        words_list.sort(key=lambda x: -x[0])
+
+        return [w for _, w in words_list], [f for f, _ in words_list]
+
+
+class NoTokenizer(WordTokenizer):
+    def __init__(self):
+        self.data = ''
+
+    def collect_words(self, data):
+        self.data += data
+
+    def get_words(self):
+        return [self.data], [1]
+
+    def tokenize(self, data: str, *, is_silent: bool = False) -> List[str]:
+        return [data]
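
The module above is the WordTokenizer / SourceCodeTokenizer / NoTokenizer code moved out of bpe.py (see the import change in that file). A minimal usage sketch, not part of the commit and assuming only what the diff shows: collect_words() counts runs of identifier and non-identifier characters, and get_words() returns them sorted by descending frequency, which is the (words_list, word_freq) pair BPELearner expects.

    from python_autocomplete.dataset.break_words import SourceCodeTokenizer

    tokenizer = SourceCodeTokenizer()
    tokenizer.collect_words("def add(a, b):\n    return a + b\n")
    words, freqs = tokenizer.get_words()  # words sorted by descending count
    print(words[:5], freqs[:5])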

python_autocomplete/evaluate/anomalies.py  +55 -35

@@ -1,59 +1,79 @@
+import torch
+from torch import nn
+
 from labml import logger, lab, monit
 from labml.logger import Text, Style
-from python_autocomplete.evaluate import Predictor
-from python_autocomplete.evaluate.factory import get_predictor
+from labml_helpers.module import Module
+from python_autocomplete.dataset import Tokenizer
+from python_autocomplete.evaluate.factory import load_experiment
+from python_autocomplete.train import StateUpdater
+
 
+def anomalies(tokenizer: Tokenizer, text: str, model: Module, state_updater: StateUpdater, is_token_by_token: bool):
+    tokens = tokenizer.encode(text)
 
-def anomalies(predictor: Predictor, text: str):
     line_no = 1
-    logs = [(f"{line_no: 4d}: ", Text.meta), (text[0], Text.subtle)]
+    logs = [(f"{line_no: 4d}: ", Text.meta), (tokenizer.itos[tokens[0]], Style.bold)]
+
+    text = torch.tensor(tokens, dtype=torch.long, device=model.device)
+    prompt = text[:1].unsqueeze(-1)
+
+    state = None
+    softmax = nn.Softmax(-1)
 
-    i = 0
+    i = 1
 
     while i + 1 < len(text):
-        # print(i, self.predictor.prompt)
-        preds, _ = predictor.get_predictions(text[:i + 1], None, calc_probs=True)
-        preds = preds[0, :]
-        c = text[i + 1]
-
-        if c == '\n':
-            logger.log(logs)
-            line_no += 1
-            logs = [(f"{line_no: 4d}: ", Text.meta)]
-        elif c == '\r':
-            continue
-        elif c not in predictor.tokenizer.stoi:
-            logs.append(c)
+        with torch.no_grad():
+            prediction, new_state = model(prompt, state)
+
+        state = state_updater(state, new_state)
+        prediction = softmax(prediction[-1, 0])
+
+        if is_token_by_token:
+            prompt = text[i: i + 1].unsqueeze(-1)
         else:
-            next_id = predictor.tokenizer.stoi[c]
-            prob = preds[next_id]
-            if prob > 0.9:
-                logs.append((c, [Style.bold, Text.success, Style.underline]))
-            elif prob > 0.75:
-                logs.append((c, [Text.success, Style.underline]))
-            elif prob > 0.2:
-                logs.append(c)
-            elif prob > 0.1:
-                logs.append((c, [Text.warning, Style.underline]))
-            elif prob > 0.01:
-                logs.append((c, [Style.bold, Text.warning, Style.underline]))
-            elif prob > 0.001:
-                logs.append((c, [Text.danger, Style.underline]))
+            prompt = text[:i + 1]
+            prompt = prompt[-512:].unsqueeze(-1)
+
+        token_str = tokenizer.itos[text[i]]
+        prob = prediction[text[i]].item()
+
+        for c in token_str:
+            if c == '\n':
+                logger.log(logs)
+                line_no += 1
+                logs = [(f"{line_no: 4d}: ", Text.meta)]
+            elif c == '\r':
+                continue
             else:
-                logs.append((c, [Style.bold, Text.danger, Style.underline]))
+                if prob > 0.9:
+                    logs.append((c, [Text.subtle, Style.underline]))
+                elif prob > 0.75:
+                    logs.append((c, [Text.success, Style.underline]))
+                elif prob > 0.2:
+                    logs.append(c)
+                elif prob > 0.1:
+                    logs.append((c, [Text.warning, Style.underline]))
+                elif prob > 0.01:
+                    logs.append((c, [Style.bold, Text.warning, Style.underline]))
+                elif prob > 0.001:
+                    logs.append((c, [Text.danger, Style.underline]))
+                else:
+                    logs.append((c, [Style.bold, Text.danger, Style.underline]))
 
         i += 1
 
     logger.log(logs)
 
 
 def main():
-    predictor = get_predictor()
+    conf = load_experiment()
 
     with open(str(lab.get_data_path() / 'sample.py'), 'r') as f:
         sample = f.read()
     with monit.section('Anomalies'):
-        anomalies(predictor, sample)
+        anomalies(conf.text.tokenizer, sample, conf.model, conf.state_updater, conf.is_token_by_token)
 
 
 if __name__ == '__main__':
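
The rewritten anomalies() runs the model directly instead of going through Predictor: it walks the tokenized sample one step at a time (carrying the model state token by token, or re-feeding the last 512 tokens of the prompt when is_token_by_token is off), takes the softmax probability the model assigned to the actual next token, and colors every character of that token accordingly. As a reading aid only, not part of the commit, here is the probability-to-style mapping used above, pulled out into a standalone helper:

    from labml.logger import Text, Style

    def style_for(prob: float):
        # Highly predictable tokens fade into the background; unlikely ones get
        # progressively stronger warning/danger styles. Thresholds as in the diff.
        if prob > 0.9:
            return [Text.subtle, Style.underline]
        elif prob > 0.75:
            return [Text.success, Style.underline]
        elif prob > 0.2:
            return None  # logged as plain text
        elif prob > 0.1:
            return [Text.warning, Style.underline]
        elif prob > 0.01:
            return [Style.bold, Text.warning, Style.underline]
        elif prob > 0.001:
            return [Text.danger, Style.underline]
        else:
            return [Style.bold, Text.danger, Style.underline]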

python_autocomplete/evaluate/factory.py  +8 -2

@@ -4,7 +4,7 @@
 from python_autocomplete.train import Configs
 
 
-def get_predictor() -> Predictor:
+def load_experiment() -> Configs:
     conf = Configs()
     experiment.evaluate()
 
@@ -29,7 +29,13 @@ def get_predictor() -> Predictor:
     experiment.load(run_uuid, checkpoint)
 
     experiment.start()
+
+    return conf
+
+
+def get_predictor() -> Predictor:
+    conf = load_experiment()
     conf.model.eval()
     return Predictor(conf.model, conf.text.tokenizer,
                      state_updater=conf.state_updater,
-                     is_token_by_token=conf.is_token_by_token)
+                     is_token_by_token=conf.is_token_by_token)
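
Splitting load_experiment() out of get_predictor() lets anomalies.py pull the raw model, tokenizer, and state updater from the loaded Configs without building a Predictor, while existing callers keep the old interface. A rough sketch of the new call path, not part of the commit:

    from python_autocomplete.evaluate.factory import load_experiment

    conf = load_experiment()  # loads the checkpoint and starts the experiment
    model, tokenizer = conf.model, conf.text.tokenizer
    # anomalies(tokenizer, source_text, model, conf.state_updater, conf.is_token_by_token)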
