def preprocess_args()

in python/graphscope/learning/graph.py [0:0]


    def preprocess_args(handle, nodes, edges, gen_labels):  # noqa: C901
        """Normalize the node/edge/label selection arguments into one dict.

        Args:
            handle: Base64-encoded JSON text. The decoded object is read for
                the keys ``"node_schema"`` (only when ``nodes`` is None),
                ``"node_attribute_types"`` and ``"edge_attribute_types"``
                (only when attributes are selected).
            nodes: ``None`` or an iterable whose items are either a node type
                name (``str``) or a tuple ``(name, attributes)``. When
                ``None``, the node names are taken from the part before the
                first ``":"`` of every ``handle["node_schema"]`` entry.
            edges: ``None`` or an iterable whose items take one of the forms

                * ``label`` (``str``) -- endpoints are inferred, which needs
                  exactly one known node type;
                * a tuple whose first two items are ``str`` -- kept as is
                  (presumably ``(src, label, dst)``; not checked here);
                * ``(label, [attributes])`` -- endpoints are inferred;
                * ``((src, label, dst), [attributes])``.
            gen_labels: ``None`` or an iterable of entries of length 3 or 4,
                grouped by their second item. Within a group of 3-item
                entries, the third item of each entry is accumulated into a
                running ``(start, end)`` range and the entry is rewritten to
                ``(entry[0], entry[1], total, (start, end))``. 4-item entries
                are passed through unchanged.

        Returns:
            A dict with the keys ``"nodes"`` and ``"edges"`` (a list, or
            ``None`` when empty), ``"node_attributes"`` and
            ``"edge_attributes"`` (name -> ``(attributes, [n_i, n_f, n_s])``
            where the counts are the number of selected attributes whose type
            code is ``"i"``, ``"f"`` and ``"s"``), and ``"gen_labels"``.

        Raises:
            InvalidArgumentError: On duplicate node types, malformed node,
                edge or label entries, or when edge endpoints have to be
                inferred but there is not exactly one node type.
        """
        handle = json.loads(base64.b64decode(handle).decode("utf-8", errors="ignore"))
        node_names = []
        node_attributes = {}
        edge_names = []
        edge_attributes = {}

        def selected_property_schema(attr_types, attributes):
            # Count the selected attributes per type code; Counter yields 0
            # for type codes that never occur.
            prop_counts = collections.Counter(attr_types[attr] for attr in attributes)
            return [prop_counts["i"], prop_counts["f"], prop_counts["s"]]

        def inferred_edge(edge_label):
            # An edge given by label only connects the single known node type
            # to itself; with zero or several node types that is ambiguous.
            if len(node_names) > 1:
                raise InvalidArgumentError(
                    "Cannot inference edge type when multiple kinds of nodes exists"
                )
            if not node_names:
                # Previously surfaced as a bare IndexError on node_names[0].
                raise InvalidArgumentError(
                    "Cannot inference edge type when no node type is available"
                )
            return (node_names[0], edge_label, node_names[0])

        if nodes is not None:
            for node in nodes:
                if isinstance(node, str):
                    if node in node_names:
                        raise InvalidArgumentError("Duplicate node type: %s" % node)
                    node_names.append(node)
                elif isinstance(node, tuple) and len(node) >= 2:
                    node_name, selected = node[0], node[1]
                    if node_name in node_names:
                        raise InvalidArgumentError(
                            "Duplicate node type: %s" % node_name
                        )
                    node_names.append(node_name)
                    attr_types = handle["node_attribute_types"][node_name]
                    attr_schema = selected_property_schema(attr_types, selected)
                    node_attributes[node_name] = (selected, attr_schema)
                else:
                    # Wrap in a 1-tuple: a bare tuple operand would be
                    # unpacked by "%" and raise TypeError instead.
                    raise InvalidArgumentError(
                        "The node parameter is in bad format: %s" % (node,)
                    )
        else:
            for node in handle["node_schema"]:
                node_names.append(node.split(":")[0])

        if edges is not None:
            for edge in edges:
                # Every tuple form below reads edge[0] and edge[1].
                is_pair = isinstance(edge, tuple) and len(edge) >= 2
                if isinstance(edge, str):
                    edge_names.append(inferred_edge(edge))
                elif (
                    is_pair
                    and isinstance(edge[0], str)
                    and isinstance(edge[1], str)
                ):
                    edge_names.append(edge)
                elif (
                    is_pair
                    and isinstance(edge[0], str)
                    and isinstance(edge[1], list)
                ):
                    edge_names.append(inferred_edge(edge[0]))
                    attr_types = handle["edge_attribute_types"][edge[0]]
                    attr_schema = selected_property_schema(attr_types, edge[1])
                    edge_attributes[edge[0]] = (edge[1], attr_schema)
                elif (
                    is_pair
                    and isinstance(edge[0], (list, tuple))
                    and isinstance(edge[1], list)
                ):
                    edge_names.append(edge[0])
                    # edge[0][1] is used as the edge label of the triple.
                    attr_types = handle["edge_attribute_types"][edge[0][1]]
                    attr_schema = selected_property_schema(attr_types, edge[1])
                    edge_attributes[edge[0][1]] = (edge[1], attr_schema)
                else:
                    # 1-tuple wrapper: see the node branch above.
                    raise InvalidArgumentError(
                        "The edge parameter is in bad format: %s" % (edge,)
                    )

        split_groups = collections.defaultdict(list)
        if gen_labels is not None:
            for label in gen_labels:
                if len(label) in (3, 4):
                    split_groups[label[1]].append(label)
                else:
                    raise InvalidArgumentError(
                        "Bad gen_labels arguments: %s" % (gen_labels,)
                    )

        split_labels = []
        for group in split_groups.values():
            lengths = [len(split) for split in group]
            # All entries of one group must have the same arity.
            check_argument(
                lengths[:-1] == lengths[1:], "Invalid gen labels: %s" % group
            )
            if len(group[0]) == 3:
                length_sum = sum(split[2] for split in group)
                start = 0
                for split in group:
                    end = start + split[2]
                    split_labels.append((split[0], split[1], length_sum, (start, end)))
                    start = end
            else:
                split_labels.extend(group)

        return {
            "nodes": node_names if node_names else None,
            "edges": edge_names if edge_names else None,
            "node_attributes": node_attributes,
            "edge_attributes": edge_attributes,
            "gen_labels": split_labels,
        }