-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTrainTest.java
More file actions
75 lines (72 loc) · 2.93 KB
/
TrainTest.java
File metadata and controls
75 lines (72 loc) · 2.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import java.util.Arrays;
import java.util.Scanner;
import java.io.File;
/**
 * Trains and evaluates a {@link Network} on MNIST-style CSV data.
 *
 * <p>Each data row holds the digit label in column 0 followed by the image's pixel values;
 * pixel values are divided by 255 before being handed to the network.
 * NOTE(review): assumes pixels are in 0-255 and each row has 784 pixel columns to match the
 * network's input layer — confirm against the data files.
 */
public class TrainTest {
    /** The network being trained and evaluated (layer sizes 784, 72, 10). */
    Network network;
    /** Path of the training CSV, relative to the working directory. */
    String TRAIN_DATA = "mnist_train.csv";
    /** Path of the test CSV, relative to the working directory. */
    String TEST_DATA = "mnist_test.csv";
    /** Running accuracy (correct / images seen) written by {@link #Test()}. */
    double accuracy;
    File train, test;

    /** Learning rate used by the no-argument {@link #Train()}. */
    private static final double DEFAULT_LEARNING_RATE = 0.01;
    /** Number of output classes (digits 0-9). */
    private static final int NUM_CLASSES = 10;
    /** Divisor applied to raw pixel values (presumably 0-255) to scale them for the network. */
    private static final double MAX_PIXEL_VALUE = 255;

    /** Builds a 784-72-10 network and points at the default training and test files. */
    public TrainTest() {
        network = new Network(new int[]{784, 72, 10}, new double[784]);
        train = new File(TRAIN_DATA);
        test = new File(TEST_DATA);
    }

    /**
     * Runs one pass of training over every row of the training file using learning rate 0.01.
     * Prints a progress line per image; if the file is missing, reports it on stderr and returns.
     */
    public void Train() {
        Train(DEFAULT_LEARNING_RATE);
    }

    /**
     * Runs one pass of training over every row of the training file.
     *
     * <p>Blank lines are skipped, as are non-numeric rows appearing before the first data row
     * (i.e. a CSV header). If the file is missing, reports it on stderr and returns.
     *
     * @param learningRate step size passed to {@code Network.backProp}
     * @throws NumberFormatException if a data row contains a non-numeric field
     */
    public void Train(double learningRate) {
        try (Scanner s = new Scanner(train)) {
            int j = 1;
            boolean seenData = false;
            while (s.hasNextLine()) {
                double[] line = parseRow(s.nextLine(), !seenData);
                if (line == null) {
                    continue; // blank line or leading header row
                }
                seenData = true;
                network.setInitialNodes(normalizedPixels(line));
                network.forwardProp();
                network.backProp(oneHot(line[0]), learningRate);
                System.out.println("Train number: " + j);
                j++;
            }
        } catch (java.io.FileNotFoundException e) {
            System.err.println("Training data file not found: " + e.getMessage());
        }
    }

    /**
     * Classifies every row of the test file, updating {@link #accuracy} after each image and
     * printing the running accuracy.
     *
     * <p>Blank lines and a leading CSV header row are skipped, as in {@link #Train(double)}.
     * If the file is missing, reports it on stderr and returns.
     *
     * @throws NumberFormatException if a data row contains a non-numeric field
     */
    public void Test() {
        try (Scanner s = new Scanner(test)) {
            int j = 1;
            int correct = 0;
            boolean seenData = false;
            while (s.hasNextLine()) {
                double[] line = parseRow(s.nextLine(), !seenData);
                if (line == null) {
                    continue; // blank line or leading header row
                }
                seenData = true;
                network.setInitialNodes(normalizedPixels(line));
                network.forwardProp();
                // NOTE(review): presumably classify maps the strongest output to one of these
                // labels — confirm in Network.
                double output = network.classify(new double[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
                if (output == line[0]) {
                    correct++;
                }
                accuracy = correct / (double) j;
                System.out.println("Accuracy for image " + j + ":" + accuracy);
                j++;
            }
        } catch (java.io.FileNotFoundException e) {
            System.err.println("Test data file not found: " + e.getMessage());
        }
    }

    /**
     * Parses one comma-separated row into doubles.
     *
     * @param row raw text of the row
     * @param allowHeader whether a non-numeric row may be skipped as a header
     * @return the parsed values, or {@code null} if the row is blank or a skippable header
     * @throws NumberFormatException if the row is non-numeric and {@code allowHeader} is false
     */
    private static double[] parseRow(String row, boolean allowHeader) {
        if (row.trim().isEmpty()) {
            return null;
        }
        try {
            return Arrays.stream(row.split(",")).mapToDouble(Double::parseDouble).toArray();
        } catch (NumberFormatException e) {
            if (allowHeader) {
                return null;
            }
            throw e;
        }
    }

    /** Returns every column after the label (column 0), each divided by 255. */
    private static double[] normalizedPixels(double[] row) {
        double[] pixels = Arrays.copyOfRange(row, 1, row.length);
        for (int i = 0; i < pixels.length; i++) {
            pixels[i] /= MAX_PIXEL_VALUE;
        }
        return pixels;
    }

    /** Builds a length-10 target vector holding 1 at index {@code label} (all zeros if no index matches). */
    private static double[] oneHot(double label) {
        double[] expected = new double[NUM_CLASSES];
        for (int i = 0; i < expected.length; i++) {
            expected[i] = (label == i) ? 1 : 0;
        }
        return expected;
    }
}