1
- from loader import get_input_shape , mnist_train_data , mnist_test_data
2
-
3
1
import numpy as np
4
2
import tensorflow as tf
5
- from os import mkdir
6
- from os .path import join
7
3
from time import time
8
4
9
- from keras .models import Sequential , load_model
5
+ from keras .models import Sequential
10
6
from keras .layers import Dense
11
7
from keras .activations import relu , softmax
12
8
from keras .losses import categorical_crossentropy
13
9
from keras .optimizers import Adam
14
10
from keras .metrics import categorical_accuracy
15
11
12
+ # NOTE: change import to use different dataset
13
+ from datasets import mnist_1d as dataset
14
+
16
15
# removes pestering tensorflow warnings
17
16
tf .logging .set_verbosity (tf .logging .ERROR )
18
17
19
- MODELS_FOLDER = 'models'
20
- MODEL_EXT = '.h5'
21
-
22
18
23
19
def main ():
24
20
start_time = time ()
@@ -33,10 +29,8 @@ def main():
33
29
34
30
35
31
def create ():
36
- input_shape = get_input_shape ()
37
-
38
32
model = Sequential ([
39
- Dense (units = 1 , activation = relu , input_shape = input_shape ),
33
+ Dense (units = 1 , activation = relu , input_shape = dataset . input_shape ),
40
34
Dense (units = 10 , activation = softmax )
41
35
])
42
36
@@ -53,7 +47,7 @@ def create():
53
47
def train (model , epochs = 10 ):
54
48
print ('\n Commence model training\n ' )
55
49
56
- train_images , train_targets = mnist_train_data ()
50
+ train_images , train_targets = dataset . load_train_data ()
57
51
58
52
# https://keras.io/models/sequential/#fit
59
53
model .fit (
@@ -74,7 +68,7 @@ def train(model, epochs=10):
74
68
def test (model , verbose = False ):
75
69
print ('\n Commence model testing\n ' )
76
70
77
- test_images , test_targets = mnist_test_data ()
71
+ test_images , test_targets = dataset . load_test_data ()
78
72
79
73
# test for all indices and count correctly classified
80
74
correctly_classified = 0
@@ -104,7 +98,7 @@ def test(model, verbose=False):
104
98
def test_with_evaluate (model , verbose = True ):
105
99
print ('\n Commence model testing\n ' )
106
100
107
- test_images , test_targets = mnist_test_data ()
101
+ test_images , test_targets = dataset . load_test_data ()
108
102
109
103
# https://keras.io/models/sequential/#evaluate
110
104
loss_and_metrics = model .evaluate (
@@ -117,21 +111,4 @@ def test_with_evaluate(model, verbose=True):
117
111
print ('\n Completed model testing' )
118
112
119
113
120
- def save (model , filename ):
121
- # make path, if not exists
122
- try :
123
- mkdir (MODELS_FOLDER )
124
- except FileExistsError :
125
- # file exists, which is what we want
126
- pass
127
-
128
- path = join (MODELS_FOLDER , filename + MODEL_EXT )
129
- model .save (path )
130
-
131
-
132
- def load (filename ):
133
- path = join (MODELS_FOLDER , filename + MODEL_EXT )
134
- return load_model (path )
135
-
136
-
137
114
main ()
0 commit comments