in src/graph.rs [2504:2642]
fn operation_attributes() {
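    // Exercises the set_attr_* / get_attr_* round trip for each supported attribute type.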
    let mut g = Graph::new();
    let shape = Shape(Some(vec![None, Some(3)]));
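    // String, type, and shape attributes on a Variable node.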
    let variable_op = {
        let mut nd = g.new_operation("Variable", "Variable").unwrap();
        nd.set_attr_type("dtype", DataType::Int32).unwrap();
        nd.set_attr_shape("shape", &shape).unwrap();
        nd.set_attr_string("shared_name", "bar").unwrap();
        nd.finish().unwrap()
    };
    assert_eq!("bar", variable_op.get_attr_string("shared_name").unwrap());
    assert_eq!(DataType::Int32, variable_op.get_attr_type("dtype").unwrap());
    assert_eq!(shape, variable_op.get_attr_shape("shape").unwrap());
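    // A shape attribute with unknown rank round-trips as Shape(None).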
    let op = {
        let mut nd = g
            .new_operation("Variable", "Variable_unknown_rank")
            .unwrap();
        nd.set_attr_type("dtype", DataType::Int32).unwrap();
        nd.set_attr_shape("shape", &Shape(None)).unwrap();
        nd.finish().unwrap()
    };
    assert_eq!(Shape(None), op.get_attr_shape("shape").unwrap());
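    // Tensor-valued attribute on a Const node.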
    let value = Tensor::<i32>::new(&[1, 3]).with_values(&[1, 2, 3]).unwrap();
    let const_op = {
        let mut nd = g.new_operation("Const", "Const").unwrap();
        nd.set_attr_tensor("value", value.clone()).unwrap();
        nd.set_attr_type("dtype", DataType::Int32).unwrap();
        nd.finish().unwrap()
    };
    assert_eq!(value, const_op.get_attr_tensor("value").unwrap());
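    // Boolean attributes on an Assign node.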
    let op = {
        let mut nd = g.new_operation("Assign", "Assign").unwrap();
        nd.add_input(variable_op.clone());
        nd.add_input(variable_op.clone());
        nd.set_attr_bool("validate_shape", true).unwrap();
        nd.set_attr_bool("use_locking", false).unwrap();
        nd.finish().unwrap()
    };
    assert_eq!(true, op.get_attr_bool("validate_shape").unwrap());
    assert_eq!(false, op.get_attr_bool("use_locking").unwrap());
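    // Integer-list attributes on a MaxPool node.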
    let op = {
        let variable_op = {
            let mut nd = g.new_operation("Variable", "MaxPool_in1").unwrap();
            nd.set_attr_type("dtype", DataType::Int32).unwrap();
            nd.set_attr_shape(
                "shape",
                &Shape(Some(vec![Some(5), Some(5), Some(5), Some(5)])),
            )
            .unwrap();
            nd.finish().unwrap()
        };
        let mut nd = g.new_operation("MaxPool", "MaxPool").unwrap();
        nd.add_input(variable_op);
        nd.set_attr_int_list("ksize", &[1, 2, 3, 4]).unwrap();
        nd.set_attr_int_list("strides", &[1, 1, 1, 1]).unwrap();
        nd.set_attr_string("padding", "VALID").unwrap();
        nd.finish().unwrap()
    };
    assert_eq!(
        &[1, 2, 3, 4],
        &op.get_attr_int_list("ksize").unwrap() as &[i64]
    );
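    // String-list attribute on a TensorSummary node.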
    let op = {
        let mut nd = g.new_operation("TensorSummary", "TensorSummary").unwrap();
        nd.add_input(variable_op.clone());
        nd.set_attr_string_list("labels", &["foo", "bar"]).unwrap();
        nd.finish().unwrap()
    };
    assert_eq!(
        &["foo".to_string(), "bar".to_string()],
        &op.get_attr_string_list("labels").unwrap() as &[_]
    );
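    // Float attribute on an ApproximateEqual node.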
    let op = {
        let mut nd = g
            .new_operation("ApproximateEqual", "ApproximateEqual")
            .unwrap();
        nd.add_input(variable_op.clone());
        nd.add_input(variable_op.clone());
        nd.set_attr_float("tolerance", 3.14).unwrap();
        nd.finish().unwrap()
    };
    assert_eq!(3.14, op.get_attr_float("tolerance").unwrap());
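    // Float-list attribute on a Bucketize node.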
    let op = {
        let mut nd = g.new_operation("Bucketize", "Bucketize").unwrap();
        nd.add_input(variable_op.clone());
        nd.set_attr_float_list("boundaries", &[0.1, 2.3]).unwrap();
        nd.finish().unwrap()
    };
    assert_eq!(
        &[0.1f32, 2.3],
        &op.get_attr_float_list("boundaries").unwrap() as &[_]
    );
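    // Shape-list, type-list, and integer attributes on a RandomShuffleQueue node.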
    let shape_list = &[
        Shape(None),
        Shape(Some(vec![])),
        Shape(Some(vec![None])),
        Shape(Some(vec![Some(1)])),
    ];
    let op = {
        let mut nd = g
            .new_operation("RandomShuffleQueue", "RandomShuffleQueue")
            .unwrap();
        nd.set_attr_shape_list("shapes", shape_list).unwrap();
        nd.set_attr_type_list("component_types", &[DataType::Float, DataType::Int32])
            .unwrap();
        nd.set_attr_int("seed", 42).unwrap();
        nd.finish().unwrap()
    };
    assert_eq!(
        shape_list,
        &op.get_attr_shape_list("shapes").unwrap() as &[_]
    );
    assert_eq!(
        &[DataType::Float, DataType::Int32],
        &op.get_attr_type_list("component_types").unwrap() as &[_]
    );
    assert_eq!(42, op.get_attr_int("seed").unwrap());
    // TODO: Support get_attr_*/set_attr_*:
    // - bool_list
    // - tensor_list
    // - tensor_shape_proto
    // - tensor_shape_proto_list
    // - value_proto
    // - func_name
    // The protos are tricky because we don't currently support proto
    // serialization/deserialization, and bool_list and tensor_list (a.k.a.
    // list(bool) and list(tensor)) don't seem to be used for any standard
    // ops. TF_GetAttrFuncName doesn't exist yet.
}