def from_proto()

in compiler_gym/views/observation_space_spec.py [0:0]


    def from_proto(cls, index: int, proto: ObservationSpace):
        """Construct a space from an ObservationSpace message.

        The space type is resolved in two steps. First, the message's
        ``opaque_data_format`` is checked for the JSON encodings handled here
        ("json://networkx/MultiDiGraph" and "json://"). Otherwise, the
        ``shape`` oneof of the message selects the space type.

        :param index: The index of the observation space, forwarded unchanged
            to the ``cls`` constructor.
        :param proto: The ObservationSpace message to translate.
        :return: An instance of ``cls`` holding the space together with a
            ``translate`` callback (Observation message -> python value) and
            a ``to_string`` callback (python value -> printable string). Its
            ``default_value`` is ``translate(proto.default_value)``.
        :raises TypeError: If ``proto``'s ``shape`` field is not one of the
            shape types recognized below.
        """
        shape_type = proto.WhichOneof("shape")

        # Defaults for the (min, max) bounds of a sequence *length* when the
        # message leaves them unset: lengths cannot be negative, and None is
        # used for "no upper bound". Shared by every sequence space below so
        # that they all agree.
        length_defaults = (0, None)

        def make_box(scalar_range_list, dtype, defaults):
            # One (low, high) pair per element of the box.
            bounds = [_scalar_range2tuple(r, defaults) for r in scalar_range_list]
            return Box(
                name=proto.name,
                low=np.array([b[0] for b in bounds], dtype=dtype),
                high=np.array([b[1] for b in bounds], dtype=dtype),
                dtype=dtype,
            )

        def make_scalar(scalar_range, dtype, defaults):
            scalar_range_tuple = _scalar_range2tuple(scalar_range, defaults)
            return Scalar(
                name=proto.name,
                min=dtype(scalar_range_tuple[0]),
                max=dtype(scalar_range_tuple[1]),
                dtype=dtype,
            )

        def make_seq(size_range, dtype, defaults, scalar_range=None):
            return Sequence(
                name=proto.name,
                size_range=_scalar_range2tuple(size_range, defaults),
                dtype=dtype,
                opaque_data_format=proto.opaque_data_format,
                scalar_range=scalar_range,
            )

        # Translate from protocol buffer specification to python. There are
        # three variables to derive:
        #   (1) space: the gym.Space instance describing the space.
        #   (2) translate: is a callback that translates from an Observation
        #           message to a python type.
        #   (3) to_string: is a callback that translates from a python type to a
        #           string for printing.
        #
        # Opaque JSON formats are checked before the shape oneof because they
        # are transported as strings but decoded into richer python objects.
        if proto.opaque_data_format == "json://networkx/MultiDiGraph":
            # TODO(cummins): Add a Graph space.
            space = make_seq(proto.string_size_range, str, length_defaults)

            def translate(observation):
                return nx.readwrite.json_graph.node_link_graph(
                    json.loads(observation.string_value), multigraph=True, directed=True
                )

            def to_string(observation):
                return json.dumps(
                    nx.readwrite.json_graph.node_link_data(observation), indent=2
                )

        elif proto.opaque_data_format == "json://":
            space = make_seq(proto.string_size_range, str, length_defaults)

            def translate(observation):
                return json.loads(observation.string_value)

            def to_string(observation):
                return json.dumps(observation, indent=2)

        elif shape_type == "int64_range_list":
            space = make_box(
                proto.int64_range_list.range,
                np.int64,
                (np.iinfo(np.int64).min, np.iinfo(np.int64).max),
            )

            def translate(observation):
                return np.array(observation.int64_list.value, dtype=np.int64)

            to_string = str
        elif shape_type == "double_range_list":
            space = make_box(
                proto.double_range_list.range, np.float64, (-np.inf, np.inf)
            )

            def translate(observation):
                return np.array(observation.double_list.value, dtype=np.float64)

            to_string = str
        elif shape_type == "string_size_range":
            space = make_seq(proto.string_size_range, str, length_defaults)

            def translate(observation):
                return observation.string_value

            to_string = str
        elif shape_type == "binary_size_range":
            space = make_seq(proto.binary_size_range, bytes, length_defaults)

            def translate(observation):
                return observation.binary_value

            to_string = str
        elif shape_type == "scalar_int64_range":
            space = make_scalar(
                proto.scalar_int64_range,
                int,
                (np.iinfo(np.int64).min, np.iinfo(np.int64).max),
            )

            def translate(observation):
                return int(observation.scalar_int64)

            to_string = str
        elif shape_type == "scalar_double_range":
            space = make_scalar(proto.scalar_double_range, float, (-np.inf, np.inf))

            def translate(observation):
                return float(observation.scalar_double)

            to_string = str
        elif shape_type == "double_sequence":
            # The length range takes the sequence-length defaults. The
            # (-inf, inf) defaults apply only to the range of the element
            # values. Previously the length range also used (-inf, inf),
            # which produced a negative lower bound on the sequence length.
            space = make_seq(
                proto.double_sequence.length_range,
                np.float64,
                length_defaults,
                make_scalar(
                    proto.double_sequence.scalar_range, np.float64, (-np.inf, np.inf)
                ),
            )

            def translate(observation):
                return np.array(observation.double_list.value, dtype=np.float64)

            to_string = str
        else:
            raise TypeError(
                f"Unknown shape '{shape_type}' for ObservationSpace:\n{proto}"
            )

        return cls(
            id=proto.name,
            index=index,
            space=space,
            translate=translate,
            to_string=to_string,
            deterministic=proto.deterministic,
            platform_dependent=proto.platform_dependent,
            default_value=translate(proto.default_value),
        )