-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathstatic_model.py
43 lines (32 loc) · 1.11 KB
/
static_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from arithmetic_compressor.models.base_adaptive_model import BaseFrequencyTable
from arithmetic_compressor.util import *
# Fixed-point scale: StaticModel multiplies each symbol probability by this
# value (and rounds) to obtain the integer frequencies behind its CDF table.
SCALE_FACTOR = 1 << 12
class StaticModel:
    """A static model, which does not adapt to input data or statistics.

    Symbol probabilities are fixed at construction time and ``update`` is a
    no-op, so ``cdf()`` always returns the same table.
    """

    def __init__(self, probability):
        """Build the fixed frequency / cumulative table.

        Args:
            probability: mapping of symbol -> probability. Presumably the
                values sum to 1 (not checked here); symbol order in the
                mapping determines the order of intervals in the CDF.
        """
        self.name = "Static"
        self.symbols = list(probability.keys())
        self.__prob = dict(probability)
        # Integer frequencies on the SCALE_FACTOR fixed-point scale.
        self.freq = self._scaled_frequencies(probability, SCALE_FACTOR)
        # Cumulative table: each symbol gets Range(low, low + frequency), with
        # consecutive symbols laid out back to back starting from 0.
        cdf = {}
        low = 0
        for sym, count in self.freq.items():
            cdf[sym] = Range(low, low + count)
            low += count
        self.cdf_object = cdf

    @staticmethod
    def _scaled_frequencies(probability, scale):
        """Return integer frequencies ``round(scale * p)`` for each symbol.

        Plain rounding maps any positive probability below ``1 / (2 * scale)``
        to 0, which would give that symbol an empty interval and make it
        impossible to encode. Such symbols are raised to 1, and the surplus is
        taken from the most frequent symbol (when it can spare it) so the
        overall total equals what plain rounding produced. Symbols whose
        probability is exactly 0 (or negative) keep frequency 0.
        """
        freqs = {sym: round(scale * prob) for sym, prob in probability.items()}
        bumped = [
            sym
            for sym, prob in probability.items()
            if prob > 0 and freqs[sym] == 0
        ]
        for sym in bumped:
            freqs[sym] = 1
        if bumped:
            donor = max(freqs, key=freqs.get)
            # Only rebalance if the donor keeps a non-empty interval itself.
            if freqs[donor] - len(bumped) >= 1:
                freqs[donor] -= len(bumped)
        return freqs

    def cdf(self):
        """Return the precomputed symbol -> Range table built in __init__."""
        return self.cdf_object

    def probability(self):
        """Return the copy of the probability mapping given at construction."""
        return self.__prob

    def predict(self, symbol):
        """Return the fixed probability of ``symbol``.

        The assert rejects symbols not supplied at construction (note: it is
        stripped under ``python -O``, in which case the dict lookup raises
        KeyError instead).
        """
        assert symbol in self.symbols
        return self.probability()[symbol]

    def update(self, symbol):
        """No-op: a static model never changes its statistics."""
        pass

    def test_model(self, gen_random=True, N=10000, custom_data=None):
        """Run BaseFrequencyTable.test_model against this model.

        Side effect: permanently renames ``self.name`` to "Static Model".
        """
        self.name = "Static Model"
        return BaseFrequencyTable.test_model(self, gen_random, N, custom_data)