pyspark/graphar_pyspark/graph.py (145 lines of code) (raw):

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Bindings to org.apache.graphar.graph."""

from __future__ import annotations

from collections.abc import Iterator, Mapping
from dataclasses import dataclass
from typing import Optional, Union

from py4j.java_gateway import JavaObject
from pyspark.sql import DataFrame

from graphar_pyspark import GraphArSession, _check_session
from graphar_pyspark.enums import FileType
from graphar_pyspark.errors import InvalidGraphFormatError
from graphar_pyspark.info import GraphInfo


def _scala_map_items(scala_map: JavaObject) -> Iterator[tuple[object, JavaObject]]:
    """Yield ``(key, value)`` pairs of a JVM Scala map accessed through py4j.

    :param scala_map: JVM-side Scala map.
    :returns: iterator over key/value pairs, in the map's key-set iteration order.
    """
    keys = scala_map.keySet().iterator()
    while keys.hasNext():
        key = keys.next()
        # Scala ``Map.get`` returns an ``Option``; the key was taken from the
        # map's own key set, so unwrapping it with ``.get()`` is safe.
        yield key, scala_map.get(key).get()


@dataclass(frozen=True)
class EdgeTypes:
    """A triplet that describes an edge type: source, edge and destination types. Immutable."""

    src_type: str
    edge_type: str
    dst_type: str


@dataclass(frozen=True)
class GraphReaderResult:
    """Immutable result of reading a graph with :class:`GraphReader`.

    ``vertex_dataframes`` maps a vertex type to its DataFrame.
    ``edge_dataframes`` maps an :class:`EdgeTypes` triplet to an inner mapping
    of string keys to DataFrames, exactly as returned by the JVM reader
    (NOTE(review): the inner keys are presumably adjacency-list types; confirm
    against the Scala ``GraphReader``).
    """

    vertex_dataframes: Mapping[str, DataFrame]
    edge_dataframes: Mapping[EdgeTypes, Mapping[str, DataFrame]]

    @staticmethod
    def from_scala(jvm_result: JavaObject) -> GraphReaderResult:
        """Create an instance of the class from the JVM reader output.

        :param jvm_result: JVM tuple object (accessed via ``_1()`` / ``_2()``).
            ``_1()`` is a Scala map ``vertex type -> DataFrame``; ``_2()`` is a
            Scala map whose keys are ``(src, edge, dst)`` Scala tuples and whose
            values are Scala maps ``str -> DataFrame``.
        :returns: instance of the Python class.
        """
        session = GraphArSession.ss

        vertex_dataframes = {
            vertex_type: DataFrame(jdf, session)
            for vertex_type, jdf in _scala_map_items(jvm_result._1())
        }

        edge_dataframes = {
            EdgeTypes(triplet._1(), triplet._2(), triplet._3()): {
                key: DataFrame(jdf, session)
                for key, jdf in _scala_map_items(nested_map)
            }
            for triplet, nested_map in _scala_map_items(jvm_result._2())
        }

        return GraphReaderResult(
            vertex_dataframes=vertex_dataframes,
            edge_dataframes=edge_dataframes,
        )


class GraphReader:
    """The helper object for reading a graph through the definitions of its graph info."""

    @staticmethod
    def read(
        graph_info: Union[GraphInfo, str],
    ) -> GraphReaderResult:
        """Read the graph as vertex and edge DataFrames using a graph info yaml file or GraphInfo object.

        :param graph_info: the path of the graph info yaml or a GraphInfo instance.
        :returns: GraphReaderResult that contains the vertex and edge dataframes.
        """
        _check_session()
        if isinstance(graph_info, str):
            graph_info = GraphInfo.load_graph_info(graph_info)

        jvm_result = GraphArSession.graphar.graph.GraphReader.readWithGraphInfo(
            graph_info.to_scala(),
            GraphArSession.jss,
        )
        return GraphReaderResult.from_scala(jvm_result)


class GraphWriter:
    """The helper class for writing a graph."""

    def __init__(self, jvm_obj: JavaObject) -> None:
        """One should not use this constructor directly, please use `from_scala` or `from_python`."""
        _check_session()
        self._jvm_graph_writer_obj = jvm_obj

    def to_scala(self) -> JavaObject:
        """Transform object to JVM representation.

        :returns: JavaObject
        """
        return self._jvm_graph_writer_obj

    @staticmethod
    def from_scala(jvm_obj: JavaObject) -> GraphWriter:
        """Create an instance of the class from the corresponding JVM object.

        :param jvm_obj: scala object in JVM.
        :returns: instance of Python class.
        """
        return GraphWriter(jvm_obj)

    @staticmethod
    def from_python() -> GraphWriter:
        """Create an instance of the class from Python arguments."""
        return GraphWriter(GraphArSession.graphar.graph.GraphWriter())

    def put_vertex_data(
        self,
        vertex_type: str,
        df: DataFrame,
        primary_key: str = "",
    ) -> None:
        """Put the vertex DataFrame into the writer.

        :param vertex_type: type of the vertex.
        :param df: DataFrame of the vertex type.
        :param primary_key: primary key of the vertex type. The default is the
            empty string, which (per the JVM writer's contract) takes the first
            property column as primary key.
        """
        self._jvm_graph_writer_obj.PutVertexData(vertex_type, df._jdf, primary_key)

    def put_edge_data(
        self,
        relation: Union[tuple[str, str, str], EdgeTypes],
        df: DataFrame,
    ) -> None:
        """Put the edge DataFrame into the writer.

        :param relation: 3-tuple (source type, edge type, target type) or an
            :class:`EdgeTypes` instance (e.g. a key of
            ``GraphReaderResult.edge_dataframes``) indicating the edge relation.
        :param df: DataFrame of the edge relation.
        """
        if isinstance(relation, EdgeTypes):
            # Allow round-tripping keys produced by GraphReaderResult.
            relation = (relation.src_type, relation.edge_type, relation.dst_type)

        relation_jvm = GraphArSession.jvm.scala.Tuple3(
            relation[0],
            relation[1],
            relation[2],
        )
        self._jvm_graph_writer_obj.PutEdgeData(relation_jvm, df._jdf)

    def write_with_graph_info(self, graph_info: Union[GraphInfo, str]) -> None:
        """Write the graph data in GraphAr format with graph info.

        Note: the original JVM method is `write`, but Python has no direct
        method overloading.

        :param graph_info: the graph info object for the graph or the path to the graph info file.
        """
        # The JVM `write` accepts either a path string or a JVM GraphInfo.
        jvm_info = graph_info if isinstance(graph_info, str) else graph_info.to_scala()
        self._jvm_graph_writer_obj.write(jvm_info, GraphArSession.jss)

    def write(
        self,
        path: str,
        name: str = "graph",
        vertex_chunk_size: Optional[int] = None,
        edge_chunk_size: Optional[int] = None,
        file_type: Optional[FileType] = None,
        version: Optional[str] = None,
    ) -> None:
        """Write graph data in GraphAr format.

        Note: for default parameters check org.apache.graphar.GeneralParams;
        for this method None for any of the arguments means that the default
        value will be used.

        :param path: the directory to write.
        :param name: the name of the graph, default is 'graph'.
        :param vertex_chunk_size: the chunk size for vertices, default is 2^18.
        :param edge_chunk_size: the chunk size for edges, default is 2^22.
        :param file_type: the file type for data payload files, support
            [parquet, orc, csv, json], default is parquet.
        :param version: version of GraphAr format, default is v1.
        """
        if vertex_chunk_size is None:
            vertex_chunk_size = (
                GraphArSession.graphar.GeneralParams.defaultVertexChunkSize
            )
        if edge_chunk_size is None:
            edge_chunk_size = GraphArSession.graphar.GeneralParams.defaultEdgeChunkSize
        if file_type is None:
            file_type_value = GraphArSession.graphar.GeneralParams.defaultFileType
        else:
            # The JVM side expects the plain string value of the enum.
            file_type_value = file_type.value
        if version is None:
            version = GraphArSession.graphar.GeneralParams.defaultVersion

        self._jvm_graph_writer_obj.write(
            path,
            GraphArSession.jss,
            name,
            vertex_chunk_size,
            edge_chunk_size,
            file_type_value,
            version,
        )


class GraphTransformer:
    """The helper object for transforming graphs through the definitions of their infos."""

    @staticmethod
    def transform(
        source_graph_info: Union[str, GraphInfo],
        dest_graph_info: Union[str, GraphInfo],
    ) -> None:
        """Transform the graphs following the metadata provided or defined in info files.

        Note: both arguments should be strings or both GraphInfo instances!
        Mixed argument types are not supported.

        :param source_graph_info: the path of the graph info yaml file for the
            source graph OR the info object for the source graph.
        :param dest_graph_info: the path of the graph info yaml file for the
            destination graph OR the info object for the destination graph.
        :raise InvalidGraphFormatError: if source and dest graph info are of mixed types.
        """
        _check_session()

        if isinstance(source_graph_info, str) and isinstance(dest_graph_info, str):
            GraphArSession.graphar.graph.GraphTransformer.transform(
                source_graph_info,
                dest_graph_info,
                GraphArSession.jss,
            )
        elif isinstance(source_graph_info, GraphInfo) and isinstance(
            dest_graph_info,
            GraphInfo,
        ):
            GraphArSession.graphar.graph.GraphTransformer.transform(
                source_graph_info.to_scala(),
                dest_graph_info.to_scala(),
                GraphArSession.jss,
            )
        else:
            msg = "Both src and dst graph info objects should be of the same type. "
            msg += f"But {type(source_graph_info)} and {type(dest_graph_info)} were provided!"
            raise InvalidGraphFormatError(msg)