in compiler_gym/views/observation_space_spec.py [0:0]
def from_proto(cls, index: int, proto: ObservationSpace):
    """Construct a space from an ObservationSpace message.

    The message's :code:`opaque_data_format` is checked first (JSON
    formats). Otherwise the field set in its :code:`shape` oneof selects
    the kind of space to build.

    :param index: The index of the observation space. It is stored on the
        returned instance.
    :param proto: The ObservationSpace protocol buffer message to translate.
    :return: An instance of :code:`cls` with these attributes:

        * :code:`space` describing the observation.
        * A :code:`translate` callback that converts an Observation message
          to a python value.
        * A :code:`to_string` callback that converts such a value to a
          printable string.
        * The message's :code:`deterministic` and
          :code:`platform_dependent` flags.
        * A :code:`default_value` produced by translating
          :code:`proto.default_value`.
    :raises TypeError: If the :code:`shape` oneof holds a field that is not
        handled here (including when no field is set) and the opaque data
        format is not a recognized JSON format.
    """
    shape_type = proto.WhichOneof("shape")

    # Default bounds passed to _scalar_range2tuple. Presumably these fill in
    # any bound the message leaves unset (verify against the helper).
    int64_bounds = (np.iinfo(np.int64).min, np.iinfo(np.int64).max)
    double_bounds = (-np.inf, np.inf)
    # Sequence lengths are non-negative; None is used as the "no upper
    # limit" marker for every sequence space built below.
    unbounded_size = (0, None)

    def make_box(scalar_range_list, dtype, defaults):
        # One (low, high) pair per element of the range list.
        bounds = [_scalar_range2tuple(r, defaults) for r in scalar_range_list]
        return Box(
            name=proto.name,
            low=np.array([b[0] for b in bounds], dtype=dtype),
            high=np.array([b[1] for b in bounds], dtype=dtype),
            dtype=dtype,
        )

    def make_scalar(scalar_range, dtype, defaults):
        scalar_range_tuple = _scalar_range2tuple(scalar_range, defaults)
        return Scalar(
            name=proto.name,
            min=dtype(scalar_range_tuple[0]),
            max=dtype(scalar_range_tuple[1]),
            dtype=dtype,
        )

    def make_seq(size_range, dtype, defaults, scalar_range=None):
        return Sequence(
            name=proto.name,
            size_range=_scalar_range2tuple(size_range, defaults),
            dtype=dtype,
            opaque_data_format=proto.opaque_data_format,
            scalar_range=scalar_range,
        )

    # Translate from protocol buffer specification to python. There are
    # three variables to derive:
    #   (1) space: the gym.Space instance describing the space.
    #   (2) translate: a callback that translates from an Observation
    #       message to a python type.
    #   (3) to_string: a callback that translates from a python type to a
    #       string for printing.
    if proto.opaque_data_format == "json://networkx/MultiDiGraph":
        # TODO(cummins): Add a Graph space.
        space = make_seq(proto.string_size_range, str, unbounded_size)

        def translate(observation):
            return nx.readwrite.json_graph.node_link_graph(
                json.loads(observation.string_value), multigraph=True, directed=True
            )

        def to_string(observation):
            return json.dumps(
                nx.readwrite.json_graph.node_link_data(observation), indent=2
            )

    elif proto.opaque_data_format == "json://":
        space = make_seq(proto.string_size_range, str, unbounded_size)

        def translate(observation):
            return json.loads(observation.string_value)

        def to_string(observation):
            return json.dumps(observation, indent=2)

    elif shape_type == "int64_range_list":
        space = make_box(proto.int64_range_list.range, np.int64, int64_bounds)

        def translate(observation):
            return np.array(observation.int64_list.value, dtype=np.int64)

        to_string = str
    elif shape_type == "double_range_list":
        space = make_box(proto.double_range_list.range, np.float64, double_bounds)

        def translate(observation):
            return np.array(observation.double_list.value, dtype=np.float64)

        to_string = str
    elif shape_type == "string_size_range":
        space = make_seq(proto.string_size_range, str, unbounded_size)

        def translate(observation):
            return observation.string_value

        to_string = str
    elif shape_type == "binary_size_range":
        space = make_seq(proto.binary_size_range, bytes, unbounded_size)

        def translate(observation):
            return observation.binary_value

        to_string = str
    elif shape_type == "scalar_int64_range":
        space = make_scalar(proto.scalar_int64_range, int, int64_bounds)

        def translate(observation):
            return int(observation.scalar_int64)

        to_string = str
    elif shape_type == "scalar_double_range":
        space = make_scalar(proto.scalar_double_range, float, double_bounds)

        def translate(observation):
            return float(observation.scalar_double)

        to_string = str
    elif shape_type == "double_sequence":
        # length_range bounds the *number* of elements, so it takes the same
        # non-negative, open-ended default as every other sequence space.
        # Previously it was given (-inf, inf), which is only meaningful for
        # the element values. That range is applied via scalar_range.
        space = make_seq(
            proto.double_sequence.length_range,
            np.float64,
            unbounded_size,
            make_scalar(
                proto.double_sequence.scalar_range, np.float64, double_bounds
            ),
        )

        def translate(observation):
            return np.array(observation.double_list.value, dtype=np.float64)

        to_string = str
    else:
        raise TypeError(
            f"Unknown shape '{shape_type}' for ObservationSpace:\n{proto}"
        )

    return cls(
        id=proto.name,
        index=index,
        space=space,
        translate=translate,
        to_string=to_string,
        deterministic=proto.deterministic,
        platform_dependent=proto.platform_dependent,
        default_value=translate(proto.default_value),
    )