in tensorflow-examples-legacy/training/src/main/java/Train.java [29:89]
/**
 * Trains the model in a serialized GraphDef to fit {@code y = 3x + 2}, persisting its variables in
 * a checkpoint so that successive runs continue where the previous one stopped.
 *
 * <p>If a checkpoint saved by an earlier run is found under the checkpoint directory, the variables
 * are restored from it. Otherwise they are initialized by running the graph's {@code init}
 * operation. The model is then trained for 5 rounds of 500 random single-value examples. The
 * variable values are printed after each round. Finally the variables are saved back to the
 * checkpoint, and one inference for input {@code 1.0} is printed.
 *
 * <p>The operation and tensor names used here ({@code init}, {@code input}, {@code target},
 * {@code train}, {@code output}, {@code save/...}) are the ones printed by the program that created
 * the graph:
 * https://github.com/tensorflow/models/blob/master/samples/languages/java/training/model/create_graph.py
 *
 * @param args {@code args[0]} is the path of the GraphDef file and {@code args[1]} is the checkpoint
 *     directory; the process exits with status 1 unless exactly two arguments are given
 * @throws Exception if the GraphDef file cannot be read or a TensorFlow operation fails
 */
public static void main(String[] args) throws Exception {
  if (args.length != 2) {
    System.err.println("Require two arguments: The GraphDef file and checkpoint directory");
    System.exit(1);
  }
  final byte[] graphDef = Files.readAllBytes(Paths.get(args[0]));
  final String checkpointDir = args[1];
  // Single source of truth for the checkpoint prefix: it is both the value fed to the graph's
  // save/restore ops and the stem of the files those ops read and write.
  final String checkpointPrefixPath = Paths.get(checkpointDir, "ckpt").toString();
  // Look for the checkpoint files rather than just the directory. An existing but empty directory
  // (created by hand, or left behind by a run that failed before saving) would otherwise be taken
  // as "restore", and running save/restore_all would fail. The V2 checkpoint format writes
  // "<prefix>.index"; the legacy V1 format writes a single file named exactly "<prefix>".
  final boolean checkpointExists =
      Files.isRegularFile(Paths.get(checkpointPrefixPath + ".index"))
          || Files.isRegularFile(Paths.get(checkpointPrefixPath));
  try (Graph graph = new Graph();
      Session sess = new Session(graph);
      Tensor<String> checkpointPrefix = Tensors.create(checkpointPrefixPath)) {
    graph.importGraphDef(graphDef);

    // Initialize or restore.
    if (checkpointExists) {
      sess.runner().feed("save/Const", checkpointPrefix).addTarget("save/restore_all").run();
    } else {
      sess.runner().addTarget("init").run();
    }
    System.out.print("Starting from : ");
    printVariables(sess);

    // Train a bunch of times.
    // (Would be much more efficient if we sent batches instead of individual values.)
    final Random r = new Random();
    final int numRounds = 5;
    final int examplesPerRound = 500;
    for (int round = 1; round <= numRounds; round++) {
      for (int n = 0; n < examplesPerRound; n++) {
        float in = r.nextFloat();
        // Each example is a fresh pair of scalar tensors; close them right after the step.
        try (Tensor<Float> input = Tensors.create(in);
            Tensor<Float> target = Tensors.create(3 * in + 2)) {
          sess.runner().feed("input", input).feed("target", target).addTarget("train").run();
        }
      }
      System.out.printf("After %5d examples: ", round * examplesPerRound);
      printVariables(sess);
    }

    // Checkpoint: feed the same prefix used for restoring and run the save op.
    sess.runner().feed("save/Const", checkpointPrefix).addTarget("save/control_dependency").run();

    // Example of "inference" in the same graph. Only one tensor is fetched, so closing
    // the first (and only) result releases everything run() returned.
    try (Tensor<Float> input = Tensors.create(1.0f);
        Tensor<Float> output =
            sess.runner().feed("input", input).fetch("output").run().get(0).expect(Float.class)) {
      System.out.printf(
          "For input %f, produced %f (ideally would produce 3*%f + 2)%n",
          input.floatValue(), output.floatValue(), input.floatValue());
    }
  }
}