# pyspark/graphar_pyspark/graph.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Bidnings to org.apache.graphar.graph."""
from __future__ import annotations
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Optional, Union
from py4j.java_gateway import JavaObject
from pyspark.sql import DataFrame
from graphar_pyspark import GraphArSession, _check_session
from graphar_pyspark.enums import FileType
from graphar_pyspark.errors import InvalidGraphFormatError
from graphar_pyspark.info import GraphInfo


@dataclass(frozen=True)
class EdgeTypes:
    """A triplet that describes an edge: source, edge and destination types. Immutable."""

    src_type: str
    edge_type: str
    dst_type: str


@dataclass(frozen=True)
class GraphReaderResult:
    """A simple immutable class that represents the results of reading a graph with GraphReader."""

    vertex_dataframes: Mapping[str, DataFrame]
    edge_dataframes: Mapping[EdgeTypes, Mapping[str, DataFrame]]

    @staticmethod
    def from_scala(
        jvm_result: tuple[
            dict[str, JavaObject],
            dict[tuple[str, str, str], dict[str, JavaObject]],
        ],
    ) -> "GraphReaderResult":
        """Create an instance of the class from JVM method output.

        :param jvm_result: structure returned from the JVM.
        :returns: instance of the Python class.
        """
        first_dict = {}
        first_scala_map = jvm_result._1()
        first_scala_map_iter = first_scala_map.keySet().iterator()
        # Unpack the Scala Map of vertex type -> DataFrame into a Python dict.
        while first_scala_map_iter.hasNext():
            k = first_scala_map_iter.next()
            first_dict[k] = DataFrame(first_scala_map.get(k).get(), GraphArSession.ss)

        second_dict = {}
        second_scala_map = jvm_result._2()
        second_scala_map_iter = second_scala_map.keySet().iterator()
        # Unpack the nested Scala Map of (src, edge, dst) -> Map[str, DataFrame],
        # re-keying the outer map with EdgeTypes instances.
        while second_scala_map_iter.hasNext():
            k = second_scala_map_iter.next()
            nested_scala_map = second_scala_map.get(k).get()
            nested_scala_map_iter = nested_scala_map.keySet().iterator()
            inner_dict = {}
            while nested_scala_map_iter.hasNext():
                kk = nested_scala_map_iter.next()
                inner_dict[kk] = DataFrame(
                    nested_scala_map.get(kk).get(),
                    GraphArSession.ss,
                )
            second_dict[EdgeTypes(k._1(), k._2(), k._3())] = inner_dict

        return GraphReaderResult(
            vertex_dataframes=first_dict,
            edge_dataframes=second_dict,
        )
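
# A short illustration (hedged) of the structure this class exposes; the
# "person"/"knows" type names and the "ordered_by_source" adjacency-list key
# below are hypothetical examples, not guaranteed by the library:
#
#     result.vertex_dataframes["person"]                              # -> DataFrame
#     result.edge_dataframes[EdgeTypes("person", "knows", "person")]
#     # -> {"ordered_by_source": DataFrame, ...}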


class GraphReader:
    """The helper object for reading a graph through the definitions of graph info."""

    @staticmethod
    def read(
        graph_info: Union[GraphInfo, str],
    ) -> GraphReaderResult:
        """Read the graph as vertex and edge DataFrames from a graph info yaml file or a GraphInfo object.

        :param graph_info: the path of the graph info yaml or a GraphInfo instance.
        :returns: a GraphReaderResult that contains the vertex and edge DataFrames.
        """
_check_session()
if isinstance(graph_info, str):
graph_info = GraphInfo.load_graph_info(graph_info)
jvm_result = GraphArSession.graphar.graph.GraphReader.readWithGraphInfo(
graph_info.to_scala(),
GraphArSession.jss,
)
return GraphReaderResult.from_scala(jvm_result)
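
# A minimal usage sketch (hedged): assumes the GraphAr session has already been
# initialized against a running SparkSession; the yaml path and the "person"
# vertex type are hypothetical.
#
#     results = GraphReader.read("/tmp/graphar/ldbc_sample.graph.yml")
#     person_df = results.vertex_dataframes["person"]
#     person_df.show()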


class GraphWriter:
    """The helper class for writing a graph."""

    def __init__(self, jvm_obj: JavaObject) -> None:
        """Do not use this constructor directly; use `from_scala` or `from_python` instead."""
        _check_session()
        self._jvm_graph_writer_obj = jvm_obj

    def to_scala(self) -> JavaObject:
        """Transform object to JVM representation.

        :returns: JavaObject
        """
        return self._jvm_graph_writer_obj

    @staticmethod
    def from_scala(jvm_obj: JavaObject) -> "GraphWriter":
        """Create an instance of the class from the corresponding JVM object.

        :param jvm_obj: scala object in JVM.
        :returns: instance of the Python class.
        """
        return GraphWriter(jvm_obj)

    @staticmethod
    def from_python() -> "GraphWriter":
        """Create an instance of the class from Python arguments."""
        return GraphWriter(GraphArSession.graphar.graph.GraphWriter())

    def put_vertex_data(self, vertex_type: str, df: DataFrame, primary_key: str) -> None:
        """Put the vertex DataFrame into the writer.

        :param vertex_type: type of the vertex.
        :param df: DataFrame of the vertex type.
        :param primary_key: primary key of the vertex type; if empty, the first property column is taken as the primary key.
        """
        self._jvm_graph_writer_obj.PutVertexData(vertex_type, df._jdf, primary_key)

    def put_edge_data(self, relation: tuple[str, str, str], df: DataFrame) -> None:
        """Put the edge DataFrame into the writer.

        :param relation: 3-tuple (source type, edge type, target type) that indicates the edge relation.
        :param df: DataFrame of the edge relation.
        """
        relation_jvm = GraphArSession.jvm.scala.Tuple3(
            relation[0], relation[1], relation[2],
        )
        self._jvm_graph_writer_obj.PutEdgeData(relation_jvm, df._jdf)
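
    # A minimal usage sketch (hedged): the DataFrames, type names and the
    # "id" primary key below are hypothetical.
    #
    #     writer = GraphWriter.from_python()
    #     writer.put_vertex_data("person", person_df, "id")
    #     writer.put_edge_data(("person", "knows", "person"), knows_df)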

    def write_with_graph_info(self, graph_info: Union[GraphInfo, str]) -> None:
        """Write the graph data in GraphAr format with graph info.

        Note: the original method is `write`, but direct overloading is not possible in Python.

        :param graph_info: the graph info object for the graph, or the path to the graph info file.
        """
        if isinstance(graph_info, str):
            self._jvm_graph_writer_obj.write(graph_info, GraphArSession.jss)
        else:
            self._jvm_graph_writer_obj.write(graph_info.to_scala(), GraphArSession.jss)
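
    # Continuing the sketch above (hedged): the yaml path is hypothetical.
    #
    #     writer.write_with_graph_info("/tmp/graphar/ldbc_sample.graph.yml")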

    def write(
        self,
        path: str,
        name: str = "graph",
        vertex_chunk_size: Optional[int] = None,
        edge_chunk_size: Optional[int] = None,
        file_type: Optional[FileType] = None,
        version: Optional[str] = None,
    ) -> None:
        """Write graph data in GraphAr format.

        Note: for default parameters check org.apache.graphar.GeneralParams;
        for this method, None for any argument means that the default value will be used.

        :param path: the directory to write to.
        :param name: the name of the graph, default is "graph".
        :param vertex_chunk_size: the chunk size for vertices, default is 2^18.
        :param edge_chunk_size: the chunk size for edges, default is 2^22.
        :param file_type: the file type for the data payload files, one of [parquet, orc, csv, json], default is parquet.
        :param version: version of the GraphAr format, default is v1.
        """
if vertex_chunk_size is None:
vertex_chunk_size = (
GraphArSession.graphar.GeneralParams.defaultVertexChunkSize
)
if edge_chunk_size is None:
edge_chunk_size = GraphArSession.graphar.GeneralParams.defaultEdgeChunkSize
file_type = (
GraphArSession.graphar.GeneralParams.defaultFileType
if file_type is None
else file_type.value
)
if version is None:
version = GraphArSession.graphar.GeneralParams.defaultVersion
self._jvm_graph_writer_obj.write(
path,
GraphArSession.jss,
name,
vertex_chunk_size,
edge_chunk_size,
file_type,
version,
)
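
    # A minimal usage sketch (hedged): writes with library defaults except for
    # an explicit vertex chunk size; the output path and graph name are
    # hypothetical.
    #
    #     writer.write("/tmp/graphar/output", name="ldbc_sample", vertex_chunk_size=2**17)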


class GraphTransformer:
    """The helper object for transforming graphs through the definitions of their infos."""

    @staticmethod
    def transform(
        source_graph_info: Union[str, GraphInfo],
        dest_graph_info: Union[str, GraphInfo],
    ) -> None:
        """Transform the graphs following the metadata provided or defined in info files.

        Note: both arguments must be strings or GraphInfo instances; mixed argument types are not supported.

        :param source_graph_info: the path of the graph info yaml file for the source graph, OR the info object for the source graph.
        :param dest_graph_info: the path of the graph info yaml file for the destination graph, OR the info object for the destination graph.
        :raise InvalidGraphFormatError: if you pass mixed formats of source and dest graph info.
        """
_check_session()
if isinstance(source_graph_info, str) and isinstance(dest_graph_info, str):
GraphArSession.graphar.graph.GraphTransformer.transform(
source_graph_info,
dest_graph_info,
GraphArSession.jss,
)
elif isinstance(source_graph_info, GraphInfo) and isinstance(
dest_graph_info,
GraphInfo,
):
GraphArSession.graphar.graph.GraphTransformer.transform(
source_graph_info.to_scala(),
dest_graph_info.to_scala(),
GraphArSession.jss,
)
else:
msg = "Both src and dst graph info objects should be of the same type. "
msg += f"But {type(source_graph_info)} and {type(dest_graph_info)} were provided!"
raise InvalidGraphFormatError(msg)
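
# A minimal usage sketch (hedged): both yaml paths are hypothetical, and both
# arguments must be of the same type (two paths or two GraphInfo objects).
#
#     GraphTransformer.transform(
#         "/tmp/graphar/source.graph.yml",
#         "/tmp/graphar/dest.graph.yml",
#     )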