# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""Bindings to org.apache.graphar.writer."""


from __future__ import annotations

import os
from typing import Optional

from py4j.java_gateway import JavaObject
from pyspark.sql import DataFrame

from graphar_pyspark import GraphArSession, _check_session
from graphar_pyspark.enums import AdjListType
from graphar_pyspark.info import EdgeInfo, PropertyGroup, VertexInfo


class VertexWriter:
    """Writer for vertex DataFrame.

    Python-side handle around a JVM ``VertexWriter`` object; the write methods
    delegate to that object.
    """

    def __init__(
        self,
        prefix: Optional[str],
        vertex_info: Optional[VertexInfo],
        vertex_df: Optional[DataFrame],
        num_vertices: Optional[int],
        jvm_obj: Optional[JavaObject],
    ) -> None:
        """One should not use this constructor directly, please use `from_scala` or `from_python`.

        :param prefix: the absolute prefix (ignored when `jvm_obj` is given).
        :param vertex_info: the vertex info (ignored when `jvm_obj` is given).
        :param vertex_df: the input vertex DataFrame (ignored when `jvm_obj` is given).
        :param num_vertices: the number of vertices; None is replaced with -1 before the JVM call.
        :param jvm_obj: an existing JVM writer object to wrap; when not None it is stored as is.
        """
        _check_session()
        if jvm_obj is not None:
            self._jvm_vertex_writer_obj = jvm_obj
        else:
            # -1 is passed in place of a missing count
            # (presumably the JVM writer then derives it itself -- TODO confirm).
            num_vertices = -1 if num_vertices is None else num_vertices
            self._jvm_vertex_writer_obj = GraphArSession.graphar.writer.VertexWriter(
                prefix,
                vertex_info.to_scala(),
                vertex_df._jdf,
                num_vertices,
            )

    def to_scala(self) -> JavaObject:
        """Transform object to JVM representation.

        :returns: JavaObject
        """
        return self._jvm_vertex_writer_obj

    @staticmethod
    def from_scala(jvm_obj: JavaObject) -> "VertexWriter":
        """Create an instance of the Class from the corresponding JVM object.

        :param jvm_obj: scala object in JVM.
        :returns: instance of Python Class.
        """
        return VertexWriter(None, None, None, None, jvm_obj)

    @staticmethod
    def from_python(
        prefix: str,
        vertex_info: VertexInfo,
        vertex_df: DataFrame,
        num_vertices: Optional[int] = None,
    ) -> "VertexWriter":
        """Create an instance of the Class from Python arguments.

        :param prefix: the absolute prefix; a trailing "/" is appended when it is missing.
        :param vertex_info: the vertex info that describes the vertex type.
        :param vertex_df: the input vertex DataFrame.
        :param num_vertices: the number of vertices, optional (default is None)
        :returns: instance of Python Class.
        """
        # The prefix is consumed on the JVM side, where "/" is the path
        # separator regardless of the host OS, so append "/" rather than
        # os.sep. A prefix already terminated by the native separator is left
        # untouched to stay compatible with existing callers.
        if not prefix.endswith(("/", os.sep)):
            prefix += "/"
        return VertexWriter(prefix, vertex_info, vertex_df, num_vertices, None)

    def write_vertex_properties(
        self,
        property_group: Optional[PropertyGroup] = None,
    ) -> None:
        """Generate chunks of the property group (or all property groups) for vertex DataFrame.

        :param property_group: property group (optional, default is None)
        if provided, generate chunks of the property group, otherwise generate for all property groups.
        """
        if property_group is not None:
            self._jvm_vertex_writer_obj.writeVertexProperties(property_group.to_scala())
        else:
            self._jvm_vertex_writer_obj.writeVertexProperties()

class EdgeWriter:
    """Writer for edge DataFrame.

    Python-side handle around a JVM ``EdgeWriter`` object; the write methods
    delegate to that object.
    """

    def __init__(
        self,
        prefix: Optional[str],
        edge_info: Optional[EdgeInfo],
        adj_list_type: Optional[AdjListType],
        vertex_num: Optional[int],
        edge_df: Optional[DataFrame],
        jvm_obj: Optional[JavaObject],
    ) -> None:
        """One should not use this constructor directly, please use `from_scala` or `from_python`.

        :param prefix: the absolute prefix (ignored when `jvm_obj` is given).
        :param edge_info: the edge info (ignored when `jvm_obj` is given).
        :param adj_list_type: the adj list type (ignored when `jvm_obj` is given).
        :param vertex_num: vertex number, forwarded unchanged to the JVM constructor.
        :param edge_df: the input edge DataFrame (ignored when `jvm_obj` is given).
        :param jvm_obj: an existing JVM writer object to wrap; when not None it is stored as is.
        """
        _check_session()
        if jvm_obj is not None:
            self._jvm_edge_writer_obj = jvm_obj
        else:
            self._jvm_edge_writer_obj = GraphArSession.graphar.writer.EdgeWriter(
                prefix,
                edge_info.to_scala(),
                adj_list_type.to_scala(),
                vertex_num,
                edge_df._jdf,
            )

    def to_scala(self) -> JavaObject:
        """Transform object to JVM representation.

        :returns: JavaObject
        """
        return self._jvm_edge_writer_obj

    @staticmethod
    def from_scala(jvm_obj: JavaObject) -> "EdgeWriter":
        """Create an instance of the Class from the corresponding JVM object.

        :param jvm_obj: scala object in JVM.
        :returns: instance of Python Class.
        """
        return EdgeWriter(None, None, None, None, None, jvm_obj)

    @staticmethod
    def from_python(
        prefix: str,
        edge_info: EdgeInfo,
        adj_list_type: AdjListType,
        vertex_num: int,
        edge_df: DataFrame,
    ) -> "EdgeWriter":
        """Create an instance of the Class from Python arguments.

        :param prefix: the absolute prefix; a trailing "/" is appended when it is missing.
        :param edge_info: the edge info that describes the edge type.
        :param adj_list_type: the adj list type for the edge.
        :param vertex_num: vertex number of the primary vertex type
        :param edge_df: the input edge DataFrame.
        :returns: instance of Python Class.
        """
        # The prefix is consumed on the JVM side, where "/" is the path
        # separator regardless of the host OS, so append "/" rather than
        # os.sep. A prefix already terminated by the native separator is left
        # untouched to stay compatible with existing callers.
        if not prefix.endswith(("/", os.sep)):
            prefix += "/"
        return EdgeWriter(prefix, edge_info, adj_list_type, vertex_num, edge_df, None)

    def write_adj_list(self) -> None:
        """Generate the chunks of AdjList from edge DataFrame for this edge type."""
        self._jvm_edge_writer_obj.writeAdjList()

    def write_edge_properties(
        self,
        property_group: Optional[PropertyGroup] = None,
    ) -> None:
        """Generate the chunks of all or selected property groups from edge DataFrame.

        :param property_group: property group (optional, default is None)
        if provided, generate the chunks of selected property group, otherwise generate for all groups.
        """
        if property_group is not None:
            self._jvm_edge_writer_obj.writeEdgeProperties(property_group.to_scala())
        else:
            self._jvm_edge_writer_obj.writeEdgeProperties()

    def write_edges(self) -> None:
        """Generate the chunks for the AdjList and all property groups from edge."""
        self._jvm_edge_writer_obj.writeEdges()