
Commit e2349e0

Misc update

Add code for reproducing experiments

1 parent: eac494f

4 files changed: +465 -5 lines

README.md (+31)

@@ -66,6 +66,37 @@ X_test_transform = apply_kernels(X_test, kernels)
predictions = classifier.predict(X_test_transform)
```

## Reproducing the Experiments

### [`reproduce_experiments_ucr.py`](code/reproduce_experiments_ucr.py)

```
Arguments:

-d --dataset_names : txt file of dataset names
-i --input_path    : parent directory for datasets
-o --output_path   : path for results
-n --num_runs      : number of runs (optional, default 10)
-k --num_kernels   : number of kernels (optional, default 10,000)

Examples:

> python reproduce_experiments_ucr.py -d bakeoff.txt -i ./Univariate_arff -o ./
> python reproduce_experiments_ucr.py -d additional.txt -i ./Univariate_arff -o ./ -n 1 -k 1000
```
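
The format of the dataset-names file is not shown in this commit; it is assumed here to be a plain-text list with one UCR dataset name per line, e.g.:

```
Adiac
ArrowHead
GunPoint
```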

### [`reproduce_experiments_scalability.py`](code/reproduce_experiments_scalability.py)

```
Arguments:

-tr --training_path : training dataset (csv)
-te --test_path     : test dataset (csv)
-o  --output_path   : path for results
-k  --num_kernels   : number of kernels

Examples:

> python reproduce_experiments_scalability.py -tr training.csv -te test.csv -o ./ -k 100
> python reproduce_experiments_scalability.py -tr training.csv -te test.csv -o ./ -k 1000
```
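
The scalability script in this commit expects csv files without a header, with the class label in the first column and the time-series values in the remaining columns; the first 2 ** 11 = 2,048 rows of the training csv are reserved for validation, and labels are integers in [0, 24) (the script hard-codes `num_classes = 24`). Below is a minimal Python sketch for generating synthetic files of this shape, e.g. for a smoke test; the sizes, series length, and file names are illustrative only, not from the original experiments:

```python
import numpy as np
import pandas as pd

num_classes  = 24   # matches num_classes = 24 in reproduce_experiments_scalability.py
input_length = 100  # illustrative time-series length

def make_csv(path, num_examples):
    # label in column 0, time-series values in the remaining columns, no header
    Y = np.random.randint(num_classes, size = (num_examples, 1))
    X = np.random.normal(size = (num_examples, input_length))
    pd.DataFrame(np.hstack([Y, X])).to_csv(path, header = False, index = False)

make_csv("training.csv", 2 ** 11 + 2 ** 8) # 2,048 validation rows + 256 training rows
make_csv("test.csv", 2 ** 11)
```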

## Acknowledgements

We thank Professor Eamonn Keogh and all the people who have contributed to the UCR time series classification archive. Figures in our paper showing the ranking of different classifiers and variants of ROCKET were produced using code from [Ismail Fawaz et al. (2019)](https://github.com/hfawaz/cd-diagram).

code/reproduce_experiments_scalability.py (+286)

@@ -0,0 +1,286 @@
# Angus Dempster, Francois Petitjean, Geoff Webb
#
# @article{dempster_etal_2020,
#   author  = {Dempster, Angus and Petitjean, Fran\c{c}ois and Webb, Geoffrey I},
#   title   = {ROCKET: Exceptionally fast and accurate time series classification using random convolutional kernels},
#   year    = {2020},
#   journal = {Data Mining and Knowledge Discovery},
#   doi     = {https://doi.org/10.1007/s10618-020-00701-z}
# }
#
# https://arxiv.org/abs/1910.13051 (preprint)

import argparse
import time

import numpy as np
import pandas as pd
import torch, torch.nn as nn, torch.optim as optim

from rocket_functions import apply_kernels, generate_kernels

# == notes =====================================================================

# Reproduce the scalability experiments.
#
# Arguments:
# -tr --training_path : training dataset (csv)
# -te --test_path     : test dataset (csv)
# -o  --output_path   : path for results
# -k  --num_kernels   : number of kernels

# == parse arguments ===========================================================

parser = argparse.ArgumentParser()

parser.add_argument("-tr", "--training_path", required = True)
parser.add_argument("-te", "--test_path", required = True)
parser.add_argument("-o", "--output_path", required = True)
parser.add_argument("-k", "--num_kernels", type = int, required = True) # required: there is no default number of kernels

arguments = parser.parse_args()

# == training function =========================================================

def train(X,
          Y,
          X_validation,
          Y_validation,
          kernels,
          num_features,
          num_classes,
          minibatch_size = 256,
          max_epochs = 100,
          patience = 2,           # x10 minibatches; reset if loss improves
          tranche_size = 2 ** 11,
          cache_size = 2 ** 14):  # as much as possible

    # -- init ------------------------------------------------------------------

    def init(layer):
        if isinstance(layer, nn.Linear):
            nn.init.constant_(layer.weight.data, 0)
            nn.init.constant_(layer.bias.data, 0)

    # -- model -----------------------------------------------------------------

    model = nn.Sequential(nn.Linear(num_features, num_classes)) # logistic / softmax regression
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor = 0.5, min_lr = 1e-8)
    model.apply(init)

    # -- run -------------------------------------------------------------------

    minibatch_count = 0
    best_validation_loss = np.inf
    stall_count = 0
    stop = False

    num_examples = len(X)
    num_tranches = int(np.ceil(num_examples / tranche_size)) # int, not np.int (removed in newer NumPy)

    cache = np.zeros((min(cache_size, num_examples), num_features))
    cache_count = 0

    for epoch in range(max_epochs):

        if epoch > 0 and stop:
            break

        for tranche_index in range(num_tranches):

            if epoch > 0 and stop:
                break

            a = tranche_size * tranche_index
            b = a + tranche_size

            Y_tranche = Y[a:b]

            # if cached, use the cached transform; else, transform and cache the result
            if b <= cache_count:

                X_tranche_transform = cache[a:b]

            else:

                X_tranche = X[a:b]
                X_tranche = (X_tranche - X_tranche.mean(axis = 1, keepdims = True)) / X_tranche.std(axis = 1, keepdims = True) # normalise time series
                X_tranche_transform = apply_kernels(X_tranche, kernels)

                if epoch == 0 and tranche_index == 0:

                    # per-feature mean and standard deviation (estimated on first tranche)
                    f_mean = X_tranche_transform.mean(0)
                    f_std = X_tranche_transform.std(0) + 1e-8

                    # normalise and transform validation data
                    X_validation = (X_validation - X_validation.mean(axis = 1, keepdims = True)) / X_validation.std(axis = 1, keepdims = True) # normalise time series
                    X_validation_transform = apply_kernels(X_validation, kernels)
                    X_validation_transform = (X_validation_transform - f_mean) / f_std # normalise transformed features
                    X_validation_transform = torch.FloatTensor(X_validation_transform)
                    Y_validation = torch.LongTensor(Y_validation)

                X_tranche_transform = (X_tranche_transform - f_mean) / f_std # normalise transformed features

                if b <= cache_size:

                    cache[a:b] = X_tranche_transform
                    cache_count = b

            X_tranche_transform = torch.FloatTensor(X_tranche_transform)
            Y_tranche = torch.LongTensor(Y_tranche)

            minibatches = torch.randperm(len(X_tranche_transform)).split(minibatch_size)

            for minibatch_index, minibatch in enumerate(minibatches):

                if epoch > 0 and stop:
                    break

                # abandon undersized minibatches
                if minibatch_index > 0 and len(minibatch) < minibatch_size:
                    break

                # -- (optional) minimal lr search ------------------------------

                # default lr for Adam may cause training loss to diverge for a
                # large number of kernels; the lr minimising training loss on
                # the first update should ensure that training loss converges

                if epoch == 0 and tranche_index == 0 and minibatch_index == 0:

                    candidate_lr = 10 ** np.linspace(-1, -6, 6)

                    best_lr = None
                    best_training_loss = np.inf

                    for lr in candidate_lr:

                        lr_model = nn.Sequential(nn.Linear(num_features, num_classes))
                        lr_optimizer = optim.Adam(lr_model.parameters())
                        lr_model.apply(init)

                        for param_group in lr_optimizer.param_groups:
                            param_group["lr"] = lr

                        # perform a single update
                        lr_optimizer.zero_grad()
                        Y_tranche_predictions = lr_model(X_tranche_transform[minibatch])
                        training_loss = loss_function(Y_tranche_predictions, Y_tranche[minibatch])
                        training_loss.backward()
                        lr_optimizer.step()

                        # evaluate training loss on the whole (first) tranche
                        Y_tranche_predictions = lr_model(X_tranche_transform)
                        training_loss = loss_function(Y_tranche_predictions, Y_tranche).item()

                        if training_loss < best_training_loss:
                            best_training_loss = training_loss
                            best_lr = lr

                    for param_group in optimizer.param_groups:
                        param_group["lr"] = best_lr

                # -- training --------------------------------------------------

                optimizer.zero_grad()
                Y_tranche_predictions = model(X_tranche_transform[minibatch])
                training_loss = loss_function(Y_tranche_predictions, Y_tranche[minibatch])
                training_loss.backward()
                optimizer.step()

                minibatch_count += 1

                # check validation loss every ten minibatches
                if minibatch_count % 10 == 0:

                    Y_validation_predictions = model(X_validation_transform)
                    validation_loss = loss_function(Y_validation_predictions, Y_validation)

                    scheduler.step(validation_loss)

                    if validation_loss.item() >= best_validation_loss:
                        stall_count += 1
                        if stall_count >= patience:
                            stop = True
                    else:
                        best_validation_loss = validation_loss.item()
                        if not stop:
                            stall_count = 0

    return model, f_mean, f_std

# == run =======================================================================

# -- run through dataset sizes -------------------------------------------------

all_num_training_examples = 2 ** np.arange(8, 20 + 1)

results = pd.DataFrame(index = all_num_training_examples,
                       columns = ["accuracy", "time_training_seconds"],
                       data = 0.0) # float dtype (accuracy is fractional)
results.index.name = "num_training_examples"

print(f" {arguments.num_kernels:,} Kernels ".center(80, "="))

for num_training_examples in all_num_training_examples:

    if num_training_examples == all_num_training_examples[0]:
        print("Number of training examples:" + f"{num_training_examples:,}".rjust(75 - 28 - 5, " ") + ".....", end = "", flush = True)
    else:
        print(f"{num_training_examples:,}".rjust(75 - 5, " ") + ".....", end = "", flush = True)

    # -- read training and validation data -------------------------------------

    # if training data does not fit in memory, it is possible to load the
    # training data inside the train(...) function, using the *chunksize*
    # argument for pandas.read_csv(...) (and roughly substituting chunks for
    # tranches, as sketched below); similarly, if the cache does not fit in
    # memory, consider caching the transformed features on disk
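
    # a minimal sketch of that chunked variant (hypothetical, not part of the
    # original experiments); it assumes the same csv layout used below, i.e.,
    # label in column 0 and time-series values in the remaining columns:
    #
    #     for chunk in pd.read_csv(arguments.training_path, header = None, chunksize = 2 ** 11):
    #         Y_chunk, X_chunk = chunk.values[:, 0], chunk.values[:, 1:]
    #         # ...normalise X_chunk, transform with apply_kernels(...), and
    #         # update the model on this chunk as for a tranche in train(...)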

    # here, validation data is always the first 2 ** 11 = 2,048 examples
    validation_data = pd.read_csv(arguments.training_path, header = None, nrows = 2 ** 11).values
    Y_validation, X_validation = validation_data[:, 0], validation_data[:, 1:]

    training_data = pd.read_csv(arguments.training_path, header = None, skiprows = 2 ** 11, nrows = num_training_examples).values
    Y_training, X_training = training_data[:, 0], training_data[:, 1:]

    # -- generate kernels ------------------------------------------------------

    kernels = generate_kernels(X_training.shape[1], arguments.num_kernels)

    # -- train -----------------------------------------------------------------

    time_a = time.perf_counter()
    model, f_mean, f_std = train(X_training,
                                 Y_training,
                                 X_validation,
                                 Y_validation,
                                 kernels,
                                 arguments.num_kernels * 2, # two features (ppv and max) per kernel
                                 num_classes = 24)
    time_b = time.perf_counter()

    results.loc[num_training_examples, "time_training_seconds"] = time_b - time_a

    # -- test ------------------------------------------------------------------

    # read test data (here, we test on a subset of the full test data)
    test_data = pd.read_csv(arguments.test_path, header = None, nrows = 2 ** 11).values
    Y_test, X_test = test_data[:, 0].astype(int), test_data[:, 1:] # int, not np.int (removed in newer NumPy)

    # normalise and transform test data
    X_test = (X_test - X_test.mean(axis = 1, keepdims = True)) / X_test.std(axis = 1, keepdims = True) # normalise time series
    X_test_transform = apply_kernels(X_test, kernels)
    X_test_transform = (X_test_transform - f_mean) / f_std # normalise transformed features

    # predict
    model.eval()
    Y_test_predictions = model(torch.FloatTensor(X_test_transform))

    results.loc[num_training_examples, "accuracy"] = (Y_test_predictions.max(1)[1].numpy() == Y_test).mean()

    print("Done.")

print(" FINISHED ".center(80, "="))

results.to_csv(f"{arguments.output_path}/results_scalability_k={arguments.num_kernels}.csv")
