-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathNetwork.java
More file actions
97 lines (89 loc) · 3.11 KB
/
Network.java
File metadata and controls
97 lines (89 loc) · 3.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import java.util.LinkedList;
import java.util.Iterator;
import java.io.FileWriter;
import java.io.IOException;
/**
 * A chain of {@link Layer} objects that is evaluated front to back by
 * {@link #forwardProp()} and updated back to front by
 * {@link #backProp(double[], double)}.
 *
 * <p>One {@code Layer} is created per adjacent pair of entries in the
 * {@code numNodes} array given to the constructor, so {@code numNodes.length - 1}
 * layers exist. The result of the final layer's {@code feedForward()} is passed
 * through softmax and stored in {@link #output}.
 *
 * <p>NOTE(review): the exact meaning of the {@code Layer} constructor arguments
 * and of {@code Layer.addVectors}/{@code Layer.reLU} is defined outside this
 * file; comments below that depend on them are marked as assumptions.
 */
public class Network {
    /** Layers in forward order; the first element receives the input values. */
    LinkedList<Layer> network;

    /** Softmax result of the latest {@link #forwardProp()}; {@code null} until it has run. */
    double[] output;

    /**
     * Value computed by the most recent {@link #error(double[])} call, or
     * {@code NaN} if {@code error} has not been called yet.
     */
    double lastError = Double.NaN;

    /**
     * Builds the layer chain and installs the initial input values.
     *
     * @param numNodes     sizes used to construct the layers; layer {@code i} is
     *                     built from {@code numNodes[i]}, {@code numNodes[i + 1]} and {@code i}
     *                     (presumably node count, next-layer node count and layer index — confirm
     *                     against {@code Layer})
     * @param initialNodes array assigned (not copied) to the first layer's {@code nodes}
     * @throws IllegalArgumentException if {@code numNodes} has fewer than two entries,
     *                                  because no layer could be created
     */
    public Network(int[] numNodes, double[] initialNodes) {
        if (numNodes.length < 2) {
            throw new IllegalArgumentException(
                    "numNodes must contain at least two layer sizes, got " + numNodes.length);
        }
        network = new LinkedList<>();
        for (int i = 0; i < numNodes.length - 1; i++) {
            network.add(new Layer(numNodes[i], numNodes[i + 1], i));
        }
        network.getFirst().nodes = initialNodes;
    }

    /**
     * Replaces the first layer's {@code nodes} with the given array.
     *
     * @param vals new input values; the array is stored by reference, not copied
     */
    public void setInitialNodes(double[] vals) {
        network.getFirst().nodes = vals;
    }

    /**
     * Runs every layer in forward order and stores the softmax of the last
     * layer's result in {@link #output}.
     *
     * <p>For each layer after the first, the previous layer's raw
     * {@code feedForward()} result is stored in {@code prev_unactivatedNodes}
     * and its {@code Layer.reLU(...)} image becomes the layer's {@code nodes}.
     */
    public void forwardProp() {
        Iterator<Layer> it = network.iterator();
        double[] raw = it.next().feedForward();
        while (it.hasNext()) {
            Layer layer = it.next();
            layer.prev_unactivatedNodes = raw;
            layer.nodes = Layer.reLU(raw);
            raw = layer.feedForward();
        }
        output = softMax(raw);
    }

    /**
     * Computes softmax of {@code raw}.
     *
     * <p>The maximum entry is subtracted before exponentiating so that
     * {@code Math.exp} never receives a positive argument and cannot overflow.
     *
     * @param raw input values; must contain at least one element
     * @return a new array of the same length whose entries sum to 1
     */
    private static double[] softMax(double[] raw) {
        double[] probs = new double[raw.length];
        double max = raw[0];
        for (double d : raw) {
            if (d > max) {
                max = d;
            }
        }
        double sum = 0;
        for (int i = 0; i < raw.length; i++) {
            probs[i] = Math.exp(raw[i] - max);
            sum += probs[i];
        }
        for (int i = 0; i < raw.length; i++) {
            probs[i] /= sum;
        }
        return probs;
    }

    /**
     * Returns the entry of {@code map} at the index of the largest value in
     * {@link #output}. Intended for inference/testing rather than training.
     *
     * @param map lookup table indexed by output position; must be at least as
     *            long as the winning index + 1
     * @return {@code map[argmax(output)]}; ties resolve to the lowest index
     * @throws IllegalStateException if {@link #forwardProp()} has not produced an output yet
     */
    public double classify(double[] map) {
        requireOutput("classify");
        int maxIndex = 0;
        // NEGATIVE_INFINITY, not Double.MIN_VALUE: MIN_VALUE is the smallest
        // positive double, so it is not a valid "lower than everything" seed.
        double maxVal = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < output.length; i++) {
            if (output[i] > maxVal) {
                maxIndex = i;
                maxVal = output[i];
            }
        }
        return map[maxIndex];
    }

    /**
     * Computes {@code sum(0.5 * (expected[i] - output[i])^2)} over
     * {@code expected} and stores it in {@link #lastError}.
     *
     * @param expected target values; must not be longer than {@link #output}
     * @throws IllegalStateException if {@link #forwardProp()} has not produced an output yet
     */
    public void error(double[] expected) {
        requireOutput("error");
        double sum = 0;
        for (int i = 0; i < expected.length; i++) {
            double diff = expected[i] - output[i];
            sum += 0.5 * diff * diff;
        }
        lastError = sum;
    }

    /**
     * Walks the layers from last to first, calling {@code Layer.backProp(alpha)}
     * on each and handing each layer's {@code dCdZ} to the next one visited as
     * its {@code dCdZ_before}.
     *
     * <p>The starting gradient is {@code Layer.addVectors(output, expected, -1.0)}
     * (presumably {@code output - expected} — confirm against {@code Layer}).
     * The walk stops after the layer whose {@code numLayer} is 0, or when no
     * layers remain.
     *
     * @param expected target values for the current {@link #output}
     * @param alpha    value forwarded unchanged to every {@code Layer.backProp} call
     *                 (presumably the learning rate)
     * @throws IllegalStateException if {@link #forwardProp()} has not produced an output yet
     */
    public void backProp(double[] expected, double alpha) {
        requireOutput("backProp");
        double[] dCdZ = Layer.addVectors(output, expected, -1.0);
        Iterator<Layer> it = network.descendingIterator();
        Layer currLayer = it.next();
        currLayer.dCdZ_before = dCdZ;
        while (true) {
            System.out.println("Backpropagating Layer: " + currLayer.numLayer);
            currLayer.backProp(alpha);
            // Also stop when the iterator is exhausted; otherwise a chain with no
            // layer numbered 0 would re-process the same layer forever.
            if (currLayer.numLayer == 0 || !it.hasNext()) {
                break;
            }
            Layer nextLayer = it.next();
            nextLayer.dCdZ_before = currLayer.dCdZ;
            currLayer = nextLayer;
        }
    }

    /** Fails fast with a descriptive message when {@link #output} has not been computed. */
    private void requireOutput(String caller) {
        if (output == null) {
            throw new IllegalStateException(
                    caller + "() requires forwardProp() to have been called first");
        }
    }
}