in src/train.rs [406:510]
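// Trains a tiny two-layer network, y = tanh(x * w + b) * w2 + b2, to compute
// XOR with plain gradient descent, then checks the squared error on all four
// input combinations.
//
// Assumed imports, not shown in this excerpt (names follow the `tensorflow`
// crate's public API; treat this as a sketch and adjust to the version in use):
//   use tensorflow::train::{GradientDescentOptimizer, MinimizeOptions, Optimizer};
//   use tensorflow::{ops, DataType, Output, Scope, Session, SessionOptions,
//                    SessionRunArgs, Tensor, Variable};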
fn xor_nn() {
let mut scope = Scope::new_root_scope();
let scope = &mut scope;
let hidden_size: u64 = 4;
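    // Graph inputs: one 1x2 example and its scalar label are fed per training step.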
let input = ops::Placeholder::new()
.dtype(DataType::Float)
.shape([1u64, 2])
.build(&mut scope.with_op_name("input"))
.unwrap();
let label = ops::Placeholder::new()
.dtype(DataType::Float)
.shape([1u64])
.build(&mut scope.with_op_name("label"))
.unwrap();
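    // Hidden layer parameters: w is a 2x4 weight matrix initialized from a
    // standard normal; b is a 4-element bias (a freshly created Tensor is zero-filled).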
let w_shape = ops::constant(&[2, hidden_size as i64][..], scope).unwrap();
let w_init = ops::RandomStandardNormal::new()
.dtype(DataType::Float)
.build(w_shape, scope)
.unwrap();
let w = Variable::builder()
.initial_value(w_init)
.data_type(DataType::Float)
.shape([2, hidden_size])
.build(&mut scope.with_op_name("w"))
.unwrap();
let b = Variable::builder()
.const_initial_value(Tensor::<f32>::new(&[hidden_size]))
.build(&mut scope.with_op_name("b"))
.unwrap();
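    // Hidden layer: layer1 = tanh(input * w + b).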
let layer1a = ops::MatMul::new()
.build(input.clone(), w.output.clone(), scope)
.unwrap();
let layer1b = ops::Add::new()
.build(layer1a, b.output.clone(), scope)
.unwrap();
let layer1 = ops::Tanh::new().build(layer1b, scope).unwrap();
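    // Output layer parameters: a 4x1 weight matrix (random normal) and a single zero bias.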
let w2_shape = ops::constant(&[hidden_size as i64, 1][..], scope).unwrap();
let w2_init = ops::RandomStandardNormal::new()
.dtype(DataType::Float)
.build(w2_shape, scope)
.unwrap();
let w2 = Variable::builder()
.initial_value(w2_init)
.data_type(DataType::Float)
.shape([hidden_size, 1])
.build(&mut scope.with_op_name("w2"))
.unwrap();
let b2 = Variable::builder()
.const_initial_value(Tensor::<f32>::new(&[1]))
.build(&mut scope.with_op_name("b2"))
.unwrap();
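    // Output layer: layer2 = layer1 * w2 + b2 (linear; no activation).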
    let layer2a = ops::mat_mul(layer1, w2.output.clone(), scope).unwrap();
    let layer2 = ops::add(layer2a, b2.output.clone(), scope).unwrap();
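    // Loss: the squared error of the prediction against the label.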
let error = ops::sub(layer2.clone(), label.clone(), scope).unwrap();
let error_squared = ops::mul(error.clone(), error, scope).unwrap();
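    // Plain gradient descent with a fixed learning rate of 0.1, expressed as a
    // constant in the graph.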
let sgd = GradientDescentOptimizer {
learning_rate: Output {
operation: ops::constant(0.1f32, scope).unwrap(),
index: 0,
},
};
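    // minimize adds gradient and update ops for the listed variables; it also
    // returns any variables the optimizer itself created, so they can be
    // initialized below alongside the network's own variables.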
let variables = vec![w.clone(), b.clone(), w2.clone(), b2.clone()];
let (minimizer_vars, minimize) = sgd
.minimize(
scope,
error_squared.clone().into(),
MinimizeOptions::default().with_variables(&variables),
)
.unwrap();
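    // Create a session and run every initializer once: the four network
    // variables plus any optimizer-internal variables.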
let options = SessionOptions::new();
let g = scope.graph_mut();
let session = Session::new(&options, &g).unwrap();
let mut run_args = SessionRunArgs::new();
for var in &variables {
run_args.add_target(&var.initializer);
}
for var in &minimizer_vars {
run_args.add_target(&var.initializer);
}
session.run(&mut run_args).unwrap();
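    // Reusable feed tensors for one training example.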
let mut input_tensor = Tensor::<f32>::new(&[1, 2]);
let mut label_tensor = Tensor::<f32>::new(&[1]);
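    // One training step: the two input bits come from the low bits of i and the
    // label is their XOR; run the minimize op and fetch the squared error.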
let mut train = |i| {
input_tensor[0] = (i & 1) as f32;
input_tensor[1] = ((i >> 1) & 1) as f32;
label_tensor[0] = ((i & 1) ^ ((i >> 1) & 1)) as f32;
let mut run_args = SessionRunArgs::new();
run_args.add_target(&minimize);
let error_squared_fetch = run_args.request_fetch(&error_squared, 0);
run_args.add_feed(&input, 0, &input_tensor);
run_args.add_feed(&label, 0, &label_tensor);
session.run(&mut run_args).unwrap();
run_args.fetch::<f32>(error_squared_fetch).unwrap()[0]
};
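    // Train for 1000 steps; the low bits of i cycle through all four cases.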
for i in 0..1000 {
train(i);
}
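    // After training, every XOR case should be predicted with squared error
    // below 0.01.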
for i in 0..4 {
let error = train(i);
assert!(error < 0.01, "error = {}", error);
}
}