Main.java 1.4 KB

1234567891011121314151617181920212223242526272829303132
  1. package org.example;
  2. import org.deeplearning4j.nn.api.OptimizationAlgorithm;
  3. import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
  4. import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
  5. import org.deeplearning4j.nn.conf.layers.DenseLayer;
  6. import org.deeplearning4j.nn.conf.layers.OutputLayer;
  7. import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
  8. import org.deeplearning4j.nn.weights.WeightInit;
  9. import org.nd4j.linalg.activations.Activation;
  10. import org.nd4j.linalg.learning.config.Nesterovs;
  11. import org.nd4j.linalg.lossfunctions.LossFunctions;
  12. public class Main {
  13. public static void main(String[] args) {
  14. MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
  15. .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
  16. .updater(new Nesterovs(0.01, 0.9))
  17. .weightInit(WeightInit.XAVIER)
  18. .list()
  19. .layer(new DenseLayer.Builder().nIn(784).nOut(250).activation(Activation.RELU).build())
  20. .layer(new OutputLayer.Builder().nIn(250).nOut(10).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build())
  21. .build();
  22. // 生成MultiLayerNetwork实例
  23. MultiLayerNetwork model = new MultiLayerNetwork(conf);
  24. model.init();
  25. // 在此添加训练代码
  26. System.out.println("Hello world!");
  27. }
  28. }