Skip to content

Commit 1868861

Browse files
committed
Improve ease of use of datasets, and remove some unfinished code
1 parent 810cd15 commit 1868861

File tree

3 files changed

+91
-107
lines changed

3 files changed

+91
-107
lines changed

datasets.py

+83
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import keras.datasets as _kds
2+
from keras.utils.np_utils import to_categorical
3+
4+
5+
class _Dataset:
6+
def __init__(self, name, input_shape, load_fn, use_1d=False):
7+
if use_1d:
8+
name += '_1d'
9+
10+
input_shape_1d = 1
11+
for dim in input_shape:
12+
input_shape_1d *= dim
13+
input_shape = input_shape_1d,
14+
15+
self.name = name
16+
self.input_shape = input_shape
17+
self._load_fn = load_fn
18+
19+
def load_train_data(self):
20+
(train_data, train_targets), _ = self._load_fn()
21+
return self._process(train_data, train_targets)
22+
23+
def load_test_data(self):
24+
_, (test_data, test_targets) = self._load_fn()
25+
return self._process(test_data, test_targets)
26+
27+
def _process(self, data, targets):
28+
# reshape data (to fit input layer)
29+
data = data.reshape((data.shape[0],) + self.input_shape)
30+
31+
# normalize data to [0.0, 1.0]
32+
data = data.astype('float32')
33+
data /= 255.0
34+
35+
# make targets categorical (to fit output layer)
36+
targets = to_categorical(targets, 10)
37+
38+
return data, targets
39+
40+
41+
class _CIFAR10(_Dataset):
42+
def __init__(self, use_1d=False):
43+
name = 'cifar10'
44+
input_shape = (32, 32, 3)
45+
load_fn = _kds.cifar10.load_data
46+
super(_CIFAR10, self).__init__(name, input_shape, load_fn, use_1d)
47+
48+
49+
class _CIFAR100(_Dataset):
50+
def __init__(self, use_1d=False):
51+
name = 'cifar100'
52+
input_shape = (32, 32, 3)
53+
load_fn = _kds.cifar100.load_data
54+
super(_CIFAR100, self).__init__(name, input_shape, load_fn, use_1d)
55+
56+
57+
class _FASHION_MNIST(_Dataset):
58+
def __init__(self, use_1d=False):
59+
name = 'fashion_mnist'
60+
input_shape = (28, 28, 1)
61+
load_fn = _kds.fashion_mnist.load_data
62+
super(_FASHION_MNIST, self).__init__(name, input_shape, load_fn, use_1d)
63+
64+
65+
class _MNIST(_Dataset):
66+
def __init__(self, use_1d=False):
67+
name = 'mnist'
68+
input_shape = (28, 28, 1)
69+
load_fn = _kds.mnist.load_data
70+
super(_MNIST, self).__init__(name, input_shape, load_fn, use_1d)
71+
72+
73+
cifar10 = _CIFAR10()
74+
cifar10_1d = _CIFAR10(use_1d=True)
75+
76+
cifar100 = _CIFAR100()
77+
cifar100_1d = _CIFAR100(use_1d=True)
78+
79+
fashion_mnist = _FASHION_MNIST()
80+
fashion_mnist_1d = _FASHION_MNIST(use_1d=True)
81+
82+
mnist = _MNIST()
83+
mnist_1d = _MNIST(use_1d=True)

loader.py

-76
This file was deleted.

main.py

+8-31
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,20 @@
1-
from loader import get_input_shape, mnist_train_data, mnist_test_data
2-
31
import numpy as np
42
import tensorflow as tf
5-
from os import mkdir
6-
from os.path import join
73
from time import time
84

9-
from keras.models import Sequential, load_model
5+
from keras.models import Sequential
106
from keras.layers import Dense
117
from keras.activations import relu, softmax
128
from keras.losses import categorical_crossentropy
139
from keras.optimizers import Adam
1410
from keras.metrics import categorical_accuracy
1511

12+
# NOTE: change import to use different dataset
13+
from datasets import mnist_1d as dataset
14+
1615
# removes pestering tensorflow warnings
1716
tf.logging.set_verbosity(tf.logging.ERROR)
1817

19-
MODELS_FOLDER = 'models'
20-
MODEL_EXT = '.h5'
21-
2218

2319
def main():
2420
start_time = time()
@@ -33,10 +29,8 @@ def main():
3329

3430

3531
def create():
36-
input_shape = get_input_shape()
37-
3832
model = Sequential([
39-
Dense(units=1, activation=relu, input_shape=input_shape),
33+
Dense(units=1, activation=relu, input_shape=dataset.input_shape),
4034
Dense(units=10, activation=softmax)
4135
])
4236

@@ -53,7 +47,7 @@ def create():
5347
def train(model, epochs=10):
5448
print('\nCommence model training\n')
5549

56-
train_images, train_targets = mnist_train_data()
50+
train_images, train_targets = dataset.load_train_data()
5751

5852
# https://keras.io/models/sequential/#fit
5953
model.fit(
@@ -74,7 +68,7 @@ def train(model, epochs=10):
7468
def test(model, verbose=False):
7569
print('\nCommence model testing\n')
7670

77-
test_images, test_targets = mnist_test_data()
71+
test_images, test_targets = dataset.load_test_data()
7872

7973
# test for all indices and count correctly classified
8074
correctly_classified = 0
@@ -104,7 +98,7 @@ def test(model, verbose=False):
10498
def test_with_evaluate(model, verbose=True):
10599
print('\nCommence model testing\n')
106100

107-
test_images, test_targets = mnist_test_data()
101+
test_images, test_targets = dataset.load_test_data()
108102

109103
# https://keras.io/models/sequential/#evaluate
110104
loss_and_metrics = model.evaluate(
@@ -117,21 +111,4 @@ def test_with_evaluate(model, verbose=True):
117111
print('\nCompleted model testing')
118112

119113

120-
def save(model, filename):
121-
# make path, if not exists
122-
try:
123-
mkdir(MODELS_FOLDER)
124-
except FileExistsError:
125-
# file exists, which is what we want
126-
pass
127-
128-
path = join(MODELS_FOLDER, filename + MODEL_EXT)
129-
model.save(path)
130-
131-
132-
def load(filename):
133-
path = join(MODELS_FOLDER, filename + MODEL_EXT)
134-
return load_model(path)
135-
136-
137114
main()

0 commit comments

Comments
 (0)