1234567891011121314151617181920212223242526272829303132 |
- package org.example;
- import org.deeplearning4j.nn.api.OptimizationAlgorithm;
- import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
- import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
- import org.deeplearning4j.nn.conf.layers.DenseLayer;
- import org.deeplearning4j.nn.conf.layers.OutputLayer;
- import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
- import org.deeplearning4j.nn.weights.WeightInit;
- import org.nd4j.linalg.activations.Activation;
- import org.nd4j.linalg.learning.config.Nesterovs;
- import org.nd4j.linalg.lossfunctions.LossFunctions;
- public class Main {
- public static void main(String[] args) {
- MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
- .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
- .updater(new Nesterovs(0.01, 0.9))
- .weightInit(WeightInit.XAVIER)
- .list()
- .layer(new DenseLayer.Builder().nIn(784).nOut(250).activation(Activation.RELU).build())
- .layer(new OutputLayer.Builder().nIn(250).nOut(10).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build())
- .build();
- // 生成MultiLayerNetwork实例
- MultiLayerNetwork model = new MultiLayerNetwork(conf);
- model.init();
- // 在此添加训练代码
- System.out.println("Hello world!");
- }
- }
|