fn xor_nn()

in src/train.rs [406:510]

A test that builds a minimal 2-4-1 feed-forward network (tanh hidden layer, linear output), trains it on the four XOR input patterns with plain gradient descent, and asserts that the squared error on every pattern drops below 0.01.

    fn xor_nn() {
        let mut scope = Scope::new_root_scope();
        let scope = &mut scope;
        let hidden_size: u64 = 4;
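        // Placeholders for a single training example: two input bits and the XOR label.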
        let input = ops::Placeholder::new()
            .dtype(DataType::Float)
            .shape([1u64, 2])
            .build(&mut scope.with_op_name("input"))
            .unwrap();
        let label = ops::Placeholder::new()
            .dtype(DataType::Float)
            .shape([1u64])
            .build(&mut scope.with_op_name("label"))
            .unwrap();
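        // Hidden-layer weights w (2 x hidden_size), initialized from a standard normal.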
        let w_shape = ops::constant(&[2, hidden_size as i64][..], scope).unwrap();
        let w_init = ops::RandomStandardNormal::new()
            .dtype(DataType::Float)
            .build(w_shape, scope)
            .unwrap();
        let w = Variable::builder()
            .initial_value(w_init)
            .data_type(DataType::Float)
            .shape([2, hidden_size])
            .build(&mut scope.with_op_name("w"))
            .unwrap();
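        // Hidden-layer bias b, zero-initialized (Tensor::new fills with the default value).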
        let b = Variable::builder()
            .const_initial_value(Tensor::<f32>::new(&[hidden_size]))
            .build(&mut scope.with_op_name("b"))
            .unwrap();
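        // Hidden layer: tanh(input * w + b).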
        let layer1a = ops::MatMul::new()
            .build(input.clone(), w.output.clone(), scope)
            .unwrap();
        let layer1b = ops::Add::new()
            .build(layer1a, b.output.clone(), scope)
            .unwrap();
        let layer1 = ops::Tanh::new().build(layer1b, scope).unwrap();
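        // Output-layer weights w2 (hidden_size x 1), also random-normal initialized.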
        let w2_shape = ops::constant(&[hidden_size as i64, 1][..], scope).unwrap();
        let w2_init = ops::RandomStandardNormal::new()
            .dtype(DataType::Float)
            .build(w2_shape, scope)
            .unwrap();
        let w2 = Variable::builder()
            .initial_value(w2_init)
            .data_type(DataType::Float)
            .shape([hidden_size, 1])
            .build(&mut scope.with_op_name("w2"))
            .unwrap();
        let b2 = Variable::builder()
            .const_initial_value(Tensor::<f32>::new(&[1]))
            .build(&mut scope.with_op_name("b2"))
            .unwrap();
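        // Output layer: linear, layer1 * w2 + b2.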
        let layer2a = ops::mat_mul(layer1, w2.output.clone(), scope).unwrap();
        let layer2b = ops::add(layer2a, b2.output.clone(), scope).unwrap();
        let layer2 = layer2b;
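        // Squared-error loss: (prediction - label)^2.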
        let error = ops::sub(layer2.clone(), label.clone(), scope).unwrap();
        let error_squared = ops::mul(error.clone(), error, scope).unwrap();
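        // Plain gradient descent with a fixed learning rate of 0.1.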
        let sgd = GradientDescentOptimizer {
            learning_rate: Output {
                operation: ops::constant(0.1f32, scope).unwrap(),
                index: 0,
            },
        };
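        // Build the op that applies one gradient-descent update to all four trainable
        // variables; minimize also returns any extra variables the optimizer creates.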
        let variables = vec![w.clone(), b.clone(), w2.clone(), b2.clone()];
        let (minimizer_vars, minimize) = sgd
            .minimize(
                scope,
                error_squared.clone().into(),
                MinimizeOptions::default().with_variables(&variables),
            )
            .unwrap();
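        // The graph is complete; create a session to run it.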
        let options = SessionOptions::new();
        let g = scope.graph_mut();
        let session = Session::new(&options, &g).unwrap();

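        // Run every variable initializer once before training starts.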
        let mut run_args = SessionRunArgs::new();
        for var in &variables {
            run_args.add_target(&var.initializer);
        }
        for var in &minimizer_vars {
            run_args.add_target(&var.initializer);
        }
        session.run(&mut run_args).unwrap();

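        // Reusable feed tensors plus a closure that runs one SGD step on example i
        // (inputs = the low two bits of i, label = their XOR) and returns the
        // resulting squared error.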
        let mut input_tensor = Tensor::<f32>::new(&[1, 2]);
        let mut label_tensor = Tensor::<f32>::new(&[1]);
        let mut train = |i| {
            input_tensor[0] = (i & 1) as f32;
            input_tensor[1] = ((i >> 1) & 1) as f32;
            label_tensor[0] = ((i & 1) ^ ((i >> 1) & 1)) as f32;
            let mut run_args = SessionRunArgs::new();
            run_args.add_target(&minimize);
            let error_squared_fetch = run_args.request_fetch(&error_squared, 0);
            run_args.add_feed(&input, 0, &input_tensor);
            run_args.add_feed(&label, 0, &label_tensor);
            session.run(&mut run_args).unwrap();
            run_args.fetch::<f32>(error_squared_fetch).unwrap()[0]
        };
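        // Train for 1000 steps, cycling through the four XOR input patterns.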
        for i in 0..1000 {
            train(i);
        }
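        // The network should now have learned XOR: check all four patterns.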
        for i in 0..4 {
            let error = train(i);
            assert!(error < 0.01, "error = {}", error);
        }
    }
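
This listing lives inside the crate's own test module in src/train.rs, so its imports come from the surrounding file and it can build GradientDescentOptimizer with a struct literal. As a rough sketch, an external crate adapting it would need use declarations along these lines (names taken from the tensorflow crate's public API; the exact set may vary across crate versions):

    use tensorflow::ops;
    use tensorflow::train::{GradientDescentOptimizer, MinimizeOptions, Optimizer};
    use tensorflow::{
        DataType, Output, Scope, Session, SessionOptions, SessionRunArgs, Tensor, Variable,
    };

The Optimizer trait must be in scope for the minimize call. Note also that outside src/train.rs the learning_rate field is private, so the optimizer would have to be built through a public constructor such as GradientDescentOptimizer::new rather than the struct literal shown above (assuming the crate version in use provides one).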