
Neural Network Framework

Overview

This project is a neural network framework implemented entirely in Java, built from the ground up without relying on external deep learning libraries. It supports both convolutional and fully connected models, with a range of layer types, activation functions, and optimization techniques implemented from first principles.

Features

  • Supports arbitrary network topologies including:
    • Convolutional layers
    • Pooling layers
    • Flattening layers
    • Dense (fully connected) layers
  • Implements core CNN operations:
    • Convolution
    • Pooling
    • Backpropagation
  • Includes a variety of activation functions:
    • ReLU
    • Sigmoid
    • Tanh
  • Offers multiple optimization methods:
    • Stochastic Gradient Descent (SGD)
    • Momentum-based Gradient Descent
  • Provides various loss functions:
    • Softmax Cross-Entropy
    • Binary Cross-Entropy
  • Uses vectorized matrix operations (via ND4J) for performance
  • Offers an intuitive API for model definition, training, and evaluation
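To illustrate the core convolution and ReLU operations listed above, here is a standalone sketch using plain Java arrays. This is an illustrative example, not the framework's own (ND4J-based) implementation, and the class and method names are made up for the sketch:

```java
public class ConvSketch {
    // "Valid" 2D convolution: slide a kH x kW kernel over an H x W input
    // with stride 1 and no padding, producing an (H-kH+1) x (W-kW+1) output.
    public static double[][] conv2dValid(double[][] input, double[][] kernel) {
        int kh = kernel.length, kw = kernel[0].length;
        int oh = input.length - kh + 1, ow = input[0].length - kw + 1;
        double[][] out = new double[oh][ow];
        for (int i = 0; i < oh; i++)
            for (int j = 0; j < ow; j++)
                for (int a = 0; a < kh; a++)
                    for (int b = 0; b < kw; b++)
                        out[i][j] += input[i + a][j + b] * kernel[a][b];
        return out;
    }

    // Element-wise ReLU activation: max(0, x).
    public static double[][] relu(double[][] x) {
        double[][] out = new double[x.length][x[0].length];
        for (int i = 0; i < x.length; i++)
            for (int j = 0; j < x[0].length; j++)
                out[i][j] = Math.max(0.0, x[i][j]);
        return out;
    }

    public static void main(String[] args) {
        double[][] input = {
            {1, 2, 0},
            {0, 1, -1},
            {2, 0, 1}
        };
        double[][] kernel = {
            {1, 0},
            {0, -1}
        };
        double[][] out = relu(conv2dValid(input, kernel));
        for (double[] row : out) {
            for (double v : row) System.out.print(v + " ");
            System.out.println();
        }
    }
}
```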

Usage

The framework allows for easy creation and training of neural networks. Here's a basic example of creating a simple feedforward neural network:

```java
NeuralNetwork nn = NeuralNetwork.newBuilder()
    .setInputShape(2)
    .addLayer("Dense", 10)
    .addLayer("Sigmoid")
    .addLayer("Dense", 2)
    .addLayer("SoftmaxCrossEntropy")
    .addOptimizer(new VanillaGradientDescent())
    .build();

// Train the network
nn.train(inputs, targets, 1000, 2, 0.05);

// Make predictions
INDArray predictions = nn.predict(inputs);
```
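The momentum-based gradient descent listed under Features keeps a per-parameter velocity buffer that accumulates past gradients. A minimal standalone sketch of that update rule follows; the class and method names here are illustrative, not the repo's API:

```java
public class MomentumSketch {
    // One momentum update over a flat parameter vector:
    //   v <- mu * v - lr * grad
    //   w <- w + v
    // All three arrays are updated in place / read element-wise.
    public static void step(double[] w, double[] grad, double[] velocity,
                            double lr, double mu) {
        for (int i = 0; i < w.length; i++) {
            velocity[i] = mu * velocity[i] - lr * grad[i];
            w[i] += velocity[i];
        }
    }

    public static void main(String[] args) {
        double[] w = {1.0, -0.5};
        double[] v = {0.0, 0.0};     // velocity starts at zero
        double[] g = {0.2, -0.4};
        // First step reduces to plain SGD since the velocity is zero.
        step(w, g, v, 0.1, 0.9);
        System.out.println(w[0] + " " + w[1]);
    }
}
```

With a zero initial velocity the first step equals plain SGD; on later steps the `mu * velocity` term carries momentum from previous gradients, which is what distinguishes this optimizer from the vanilla SGD also provided by the framework.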

Testing

The project includes a comprehensive test suite (NeuralNetworkTest.java) that demonstrates various use cases:

  1. TestFFNN: Tests a simple feedforward neural network on a binary classification task.
  2. TestCNN: Evaluates a basic convolutional neural network on a small dataset.
  3. CnnMNistTest: Tests a more complex CNN on the MNIST dataset.
  4. FFNNMNistTest: Evaluates a feedforward neural network on the MNIST dataset.

Performance

The framework has been validated on standard benchmarks such as MNIST, achieving results comparable to equivalent models built with popular deep learning libraries.

Dependencies

The project's only external dependency is ND4J, an n-dimensional array library for the JVM, used here for efficient matrix operations. All neural network layers, loss functions, optimizers, and training algorithms are implemented from scratch.

Conclusion

This neural network framework works through the mathematical foundations of neural networks in working code, from convolution and backpropagation to optimization, and reflects the software engineering involved in designing and implementing a complex system from the ground up.
