fn operation_attributes()

in src/graph.rs [2504:2642]


    /// Round-trips operation attributes through the `set_attr_*` builder
    /// methods and the matching `get_attr_*` accessors.
    ///
    /// For each attribute kind exercised here (string, type, shape, tensor,
    /// bool, int, float and the list variants of int, string, float, shape and
    /// type), an operation is built in a fresh graph with the attribute set,
    /// and the value read back from the finished operation is asserted to
    /// equal the value that was written.
    fn operation_attributes() {
        let mut g = Graph::new();

        // string / type / shape attributes on a "Variable" op. The shape has
        // two dimensions, the first of which is `None`.
        let shape = Shape(Some(vec![None, Some(3)]));
        let variable_op = {
            let mut nd = g.new_operation("Variable", "Variable").unwrap();
            nd.set_attr_type("dtype", DataType::Int32).unwrap();
            nd.set_attr_shape("shape", &shape).unwrap();
            nd.set_attr_string("shared_name", "bar").unwrap();
            nd.finish().unwrap()
        };
        assert_eq!("bar", variable_op.get_attr_string("shared_name").unwrap());
        assert_eq!(DataType::Int32, variable_op.get_attr_type("dtype").unwrap());
        assert_eq!(shape, variable_op.get_attr_shape("shape").unwrap());

        // A shape attribute of `Shape(None)` must also survive the round trip.
        let op = {
            let mut nd = g
                .new_operation("Variable", "Variable_unknown_rank")
                .unwrap();
            nd.set_attr_type("dtype", DataType::Int32).unwrap();
            nd.set_attr_shape("shape", &Shape(None)).unwrap();
            nd.finish().unwrap()
        };
        assert_eq!(Shape(None), op.get_attr_shape("shape").unwrap());

        // tensor attribute; `value` is cloned because it is compared against
        // the read-back tensor afterwards.
        let value = Tensor::<i32>::new(&[1, 3]).with_values(&[1, 2, 3]).unwrap();
        let const_op = {
            let mut nd = g.new_operation("Const", "Const").unwrap();
            nd.set_attr_tensor("value", value.clone()).unwrap();
            nd.set_attr_type("dtype", DataType::Int32).unwrap();
            nd.finish().unwrap()
        };
        assert_eq!(value, const_op.get_attr_tensor("value").unwrap());

        // bool attributes, one set to `true` and one to `false` so both
        // values are covered.
        let op = {
            let mut nd = g.new_operation("Assign", "Assign").unwrap();
            nd.add_input(variable_op.clone());
            nd.add_input(variable_op.clone());
            nd.set_attr_bool("validate_shape", true).unwrap();
            nd.set_attr_bool("use_locking", false).unwrap();
            nd.finish().unwrap()
        };
        assert!(op.get_attr_bool("validate_shape").unwrap());
        assert!(!op.get_attr_bool("use_locking").unwrap());

        // int list attribute. The "MaxPool" op gets its own four-dimensional
        // variable as input; the inner `variable_op` shadows the outer one
        // only inside this block.
        let op = {
            let variable_op = {
                let mut nd = g.new_operation("Variable", "MaxPool_in1").unwrap();
                nd.set_attr_type("dtype", DataType::Int32).unwrap();
                nd.set_attr_shape(
                    "shape",
                    &Shape(Some(vec![Some(5), Some(5), Some(5), Some(5)])),
                )
                .unwrap();
                nd.finish().unwrap()
            };
            let mut nd = g.new_operation("MaxPool", "MaxPool").unwrap();
            nd.add_input(variable_op);
            nd.set_attr_int_list("ksize", &[1, 2, 3, 4]).unwrap();
            nd.set_attr_int_list("strides", &[1, 1, 1, 1]).unwrap();
            nd.set_attr_string("padding", "VALID").unwrap();
            nd.finish().unwrap()
        };
        assert_eq!(
            &[1, 2, 3, 4],
            &op.get_attr_int_list("ksize").unwrap() as &[i64]
        );

        // string list attribute.
        let op = {
            let mut nd = g.new_operation("TensorSummary", "TensorSummary").unwrap();
            nd.add_input(variable_op.clone());
            nd.set_attr_string_list("labels", &["foo", "bar"]).unwrap();
            nd.finish().unwrap()
        };
        assert_eq!(
            &["foo".to_string(), "bar".to_string()],
            &op.get_attr_string_list("labels").unwrap() as &[_]
        );

        // float attribute. The value itself is arbitrary; it is named once so
        // the setter and the assertion cannot drift apart, and it is chosen
        // so that it does not resemble a well-known mathematical constant.
        let tolerance: f32 = 2.5;
        let op = {
            let mut nd = g
                .new_operation("ApproximateEqual", "ApproximateEqual")
                .unwrap();
            nd.add_input(variable_op.clone());
            nd.add_input(variable_op.clone());
            nd.set_attr_float("tolerance", tolerance).unwrap();
            nd.finish().unwrap()
        };
        assert_eq!(tolerance, op.get_attr_float("tolerance").unwrap());

        // float list attribute.
        let op = {
            let mut nd = g.new_operation("Bucketize", "Bucketize").unwrap();
            nd.add_input(variable_op.clone());
            nd.set_attr_float_list("boundaries", &[0.1, 2.3]).unwrap();
            nd.finish().unwrap()
        };
        assert_eq!(
            &[0.1f32, 2.3],
            &op.get_attr_float_list("boundaries").unwrap() as &[_]
        );

        // shape list, type list and int attributes on a single op. The shape
        // list mixes `Shape(None)`, an empty dimension list, a `None`
        // dimension and a `Some` dimension.
        let shape_list = &[
            Shape(None),
            Shape(Some(vec![])),
            Shape(Some(vec![None])),
            Shape(Some(vec![Some(1)])),
        ];
        let op = {
            let mut nd = g
                .new_operation("RandomShuffleQueue", "RandomShuffleQueue")
                .unwrap();
            nd.set_attr_shape_list("shapes", shape_list).unwrap();
            nd.set_attr_type_list("component_types", &[DataType::Float, DataType::Int32])
                .unwrap();
            nd.set_attr_int("seed", 42).unwrap();
            nd.finish().unwrap()
        };
        assert_eq!(
            shape_list,
            &op.get_attr_shape_list("shapes").unwrap() as &[_]
        );
        assert_eq!(
            &[DataType::Float, DataType::Int32],
            &op.get_attr_type_list("component_types").unwrap() as &[_]
        );
        assert_eq!(42, op.get_attr_int("seed").unwrap());

        // TODO: Support get_attr_*/set_attr_*:
        // - bool_list
        // - tensor_list
        // - tensor_shape_proto
        // - tensor_shape_proto_list
        // - value_proto
        // - func_name
        // The protos are tricky because we don't currently support proto
        // serialization/deserialization, and bool_list and tensor_list (a.k.a.
        // list(bool) and list(tensor)) don't seem to be used for any standard
        // ops. TF_GetAttrFuncName doesn't exist yet.
    }