xiuwei před 6 měsíci
revize
f8bd8af788
2 změnil soubory, kde provedl 107 přidání a 0 odebrání
  1. 75 0
      pom.xml
  2. 32 0
      src/main/java/org/example/Main.java

+ 75 - 0
pom.xml

@@ -0,0 +1,75 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <modelVersion>4.0.0</modelVersion>
+
+    <groupId>org.example</groupId>
+    <artifactId>testDs</artifactId>
+    <version>1.0-SNAPSHOT</version>
+
+    <properties>
+        <maven.compiler.source>1.8</maven.compiler.source>
+        <maven.compiler.target>1.8</maven.compiler.target>
+        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
+    </properties>
+    <repositories>
+        <repository>
+            <id>syjy-repos</id>
+            <name>syjy Repository</name>
+            <url>http://49.4.68.219:8888/nexus/content/groups/public/</url>
+            <releases>
+                <enabled>true</enabled>
+            </releases>
+            <snapshots>
+                <enabled>true</enabled>
+            </snapshots>
+        </repository>
+    </repositories>
+    <dependencies>
+        <dependency>
+            <groupId>org.deeplearning4j</groupId>
+            <artifactId>deeplearning4j-core</artifactId>
+            <version>1.0.0-beta7</version>
+        </dependency>
+        <dependency>
+            <groupId>org.nd4j</groupId>
+            <artifactId>nd4j-native-platform</artifactId>
+            <version>1.0.0-beta7</version>
+        </dependency>
+    </dependencies>
+    <build>
+        <plugins>
+            <plugin>
+                <groupId>org.codehaus.mojo</groupId>
+                <artifactId>exec-maven-plugin</artifactId>
+                <version>3.0.0</version>
+                <configuration>
+                    <mainClass>org.example.Main</mainClass>
+                </configuration>
+            </plugin>
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-shade-plugin</artifactId>
+                <version>2.4.3</version>
+                <executions>
+                    <execution>
+                        <phase>package</phase>
+                        <goals>
+                            <goal>shade</goal>
+                        </goals>
+                        <configuration>
+                            <createDependencyReducedPom>false</createDependencyReducedPom>
+                            <transformers>
+                                <transformer
+                                        implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
+                                    <mainClass>org.example.Main</mainClass>
+                                </transformer>
+                            </transformers>
+                        </configuration>
+                    </execution>
+                </executions>
+            </plugin>
+        </plugins>
+    </build>
+</project>

+ 32 - 0
src/main/java/org/example/Main.java

@@ -0,0 +1,32 @@
+package org.example;
+
+import org.deeplearning4j.nn.api.OptimizationAlgorithm;
+import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.layers.DenseLayer;
+import org.deeplearning4j.nn.conf.layers.OutputLayer;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.nn.weights.WeightInit;
+import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.learning.config.Nesterovs;
+import org.nd4j.linalg.lossfunctions.LossFunctions;
+
+public class Main {
+    public static void main(String[] args) {
+
+        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
+                .updater(new Nesterovs(0.01, 0.9))
+                .weightInit(WeightInit.XAVIER)
+                .list()
+                .layer(new DenseLayer.Builder().nIn(784).nOut(250).activation(Activation.RELU).build())
+                .layer(new OutputLayer.Builder().nIn(250).nOut(10).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build())
+                .build();
+        // 生成MultiLayerNetwork实例
+        MultiLayerNetwork model = new MultiLayerNetwork(conf);
+        model.init();
+        // 在此添加训练代码
+
+        System.out.println("Hello world!");
+    }
+}