forked from Hannah-Richert/iannwtf_hw7
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathclassify.py
61 lines (48 loc) · 2.4 KB
/
classify.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import tensorflow as tf
from util import train_step, test
def classify(model, optimizer, num_epochs, train_ds, valid_ds):
    """
    Train the given model and record loss/accuracy curves.

    Before any training, the model is evaluated once on both datasets so every
    returned curve starts with a pre-training baseline. Each epoch then runs
    ``train_step`` once per batch of ``train_ds`` and re-evaluates the model on
    ``valid_ds``.

    Args:
    - model <tensorflow.keras.Model>: our untrained model
    - optimizer: optimizer object, passed through unchanged to ``train_step``
    - num_epochs <int>: number of training epochs
    - train_ds <tf.data.Dataset>: training dataset iterated as (input, target) batches
    - valid_ds <tf.data.Dataset>: validation set for testing and optimizing hyperparameters
    Returns:
    - results <list>: ``[train_losses, valid_losses, valid_accuracies]``; each
      inner list has ``num_epochs + 1`` entries, where index 0 is the
      pre-training baseline
    - model <tensorflow.keras.Model>: the same model object, now trained
    """
    # Reset Keras' global state before training.
    # NOTE(review): `model` was already constructed by the caller before this
    # reset; confirm that clearing the session at this point is intended.
    tf.keras.backend.clear_session()

    # Loss: binary cross-entropy.
    # NOTE(review): assumes binary targets matching the model's output — confirm.
    cross_entropy_loss = tf.keras.losses.BinaryCrossentropy()

    # Per-epoch history, used later for visualization.
    train_losses = []
    valid_losses = []
    valid_accuracies = []

    # Baseline: evaluate on the validation set once before training.
    # The trailing boolean is forwarded to util.test; presumably a
    # "training" flag — TODO confirm against util.
    valid_loss, valid_accuracy = test(model, valid_ds, cross_entropy_loss, False)
    valid_losses.append(valid_loss)
    valid_accuracies.append(valid_accuracy)

    # Baseline: training-set loss before training (accuracy is not tracked).
    train_loss, _ = test(model, train_ds, cross_entropy_loss, False)
    train_losses.append(train_loss)

    for epoch in range(num_epochs):
        print(
            f'Epoch: {epoch + 1} starting with (validation set) accuracy '
            f'{valid_accuracies[-1]} and loss {valid_losses[-1]}')

        # One optimizer step per batch; collect the per-batch losses so the
        # epoch's training loss is their mean.
        batch_losses = []
        for inputs, target in train_ds:
            batch_loss = train_step(
                model, inputs, target, cross_entropy_loss, optimizer, True)
            batch_losses.append(batch_loss)

        train_losses.append(tf.reduce_mean(batch_losses))
        print(f'Epoch: {epoch + 1} train loss: {train_losses[-1]}')

        # Re-evaluate on the validation set after each epoch.
        valid_loss, valid_accuracy = test(model, valid_ds, cross_entropy_loss, False)
        valid_losses.append(valid_loss)
        valid_accuracies.append(valid_accuracy)

    results = [train_losses, valid_losses, valid_accuracies]
    return results, model