################################################################################
#  Licensed to the Apache Software Foundation (ASF) under one
#  or more contributor license agreements.  See the NOTICE file
#  distributed with this work for additional information
#  regarding copyright ownership.  The ASF licenses this file
#  to you under the Apache License, Version 2.0 (the
#  "License"); you may not use this file except in compliance
#  with the License.  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
# limitations under the License.
################################################################################

import pyarrow as pa

from pypaimon import Schema
from pypaimon.py4j.java_gateway import get_gateway


def to_j_catalog_context(catalog_options: dict):
    """Create a Java ``CatalogContext`` from a dict of catalog options.

    The dict is wrapped in a Java ``Options`` object and handed to
    ``CatalogContext.create`` on the JVM.

    Args:
        catalog_options: catalog configuration key/value pairs.

    Returns:
        The Java ``CatalogContext`` proxy.
    """
    jvm = get_gateway().jvm
    j_opts = jvm.Options(catalog_options)
    return jvm.CatalogContext.create(j_opts)


def to_j_identifier(identifier: str):
    """Parse a string into a Java ``Identifier`` via ``Identifier.fromString``.

    Args:
        identifier: the textual table identifier to parse on the JVM side.

    Returns:
        The Java ``Identifier`` proxy.
    """
    identifier_cls = get_gateway().jvm.Identifier
    return identifier_cls.fromString(identifier)


def to_paimon_schema(schema: Schema):
    """Translate a pypaimon ``Schema`` into a Java ``Schema`` via its builder.

    Partition keys, primary keys and options are forwarded only when they are
    not ``None``. The comment is always forwarded, even when it is ``None``.
    Each field of ``schema.pa_schema`` becomes a column whose Java type is
    chosen by ``_to_j_type``.

    Args:
        schema: the Python-side schema description.

    Returns:
        The built Java ``Schema`` object.

    Raises:
        ValueError: propagated from ``_to_j_type`` for unsupported arrow types.
    """
    builder = get_gateway().jvm.Schema.newBuilder()

    # Builder setter paired with the value it receives; None values are skipped.
    optional_settings = (
        (builder.partitionKeys, schema.partition_keys),
        (builder.primaryKey, schema.primary_keys),
        (builder.options, schema.options),
    )
    for setter, value in optional_settings:
        if value is not None:
            setter(value)

    builder.comment(schema.comment)

    for pa_field in schema.pa_schema:
        j_type = _to_j_type(pa_field.name, pa_field.type)
        builder.column(pa_field.name, j_type)
    return builder.build()


def check_batch_write(j_table):
    """Reject tables whose bucket mode batch writing does not support.

    Args:
        j_table: Java table proxy exposing ``bucketMode()``.

    Raises:
        TypeError: if the bucket mode is ``HASH_DYNAMIC`` or ``CROSS_PARTITION``.
    """
    gateway = get_gateway()
    mode = j_table.bucketMode()
    bucket_modes = gateway.jvm.BucketMode
    is_dynamic_bucket = mode == bucket_modes.HASH_DYNAMIC
    if is_dynamic_bucket or mode == bucket_modes.CROSS_PARTITION:
        raise TypeError("Doesn't support writing dynamic bucket or cross partition table.")


def _to_j_type(name, pa_type):
    """Map a pyarrow data type to a Java ``DataType`` built by ``DataTypes``.

    A null-typed column falls back to STRING, and a warning is printed to
    stdout asking the user to verify the intended type.

    Args:
        name: column name; used only in the warning and error messages.
        pa_type: the pyarrow type to convert.

    Returns:
        The Java ``DataType`` proxy.

    Raises:
        ValueError: if ``pa_type`` has no mapping.
    """
    jvm = get_gateway().jvm

    if pa.types.is_null(pa_type):
        # Null-typed columns carry no type information; default to STRING.
        print(f"WARN: The type of column '{name}' is null, "
              "and it will be converted to string type by default. "
              "Please check if the original type is string. "
              f"If not, please manually specify the type of '{name}'.")
        return jvm.DataTypes.STRING()

    # (pyarrow predicate, name of the DataTypes factory method to call).
    # The predicates are mutually exclusive, so table order does not matter.
    type_mapping = (
        (pa.types.is_int8, 'TINYINT'),
        (pa.types.is_int16, 'SMALLINT'),
        (pa.types.is_int32, 'INT'),
        (pa.types.is_int64, 'BIGINT'),
        (pa.types.is_float16, 'FLOAT'),
        (pa.types.is_float32, 'FLOAT'),
        (pa.types.is_float64, 'DOUBLE'),
        (pa.types.is_string, 'STRING'),
        (pa.types.is_boolean, 'BOOLEAN'),
    )
    for matches, factory_name in type_mapping:
        if matches(pa_type):
            return getattr(jvm.DataTypes, factory_name)()

    raise ValueError(f'Found unsupported data type {str(pa_type)} for field {name}.')


def to_arrow_schema(j_row_type):
    """Convert a Java row type into a ``pyarrow.Schema``.

    ``SchemaUtil.getArrowSchema`` on the JVM returns bytes that are read as an
    Arrow IPC stream. Only the stream's schema header is consumed.

    Args:
        j_row_type: Java row-type proxy (presumably a Paimon ``RowType`` --
            confirm against callers).

    Returns:
        The decoded ``pyarrow.Schema``.
    """
    schema_bytes = get_gateway().jvm.SchemaUtil.getArrowSchema(j_row_type)
    schema_reader = pa.RecordBatchStreamReader(pa.BufferReader(schema_bytes))
    try:
        return schema_reader.schema
    finally:
        # Close the reader even if reading the schema fails.
        schema_reader.close()


def serialize_java_object(java_obj) -> bytes:
    """Serialize a Java object with Paimon's ``InstantiationUtil.serializeObject``.

    Args:
        java_obj: the Java object proxy to serialize.

    Returns:
        The serialized form as Python ``bytes``.

    Raises:
        RuntimeError: wrapping any exception raised by the JVM call or the
            bytes conversion. The original exception is chained as
            ``__cause__``.
    """
    gateway = get_gateway()
    util = gateway.jvm.org.apache.paimon.utils.InstantiationUtil
    try:
        java_bytes = util.serializeObject(java_obj)
        return bytes(java_bytes)
    except Exception as e:
        # Chain explicitly so the underlying Py4J/Java error stays visible.
        raise RuntimeError(f"Java serialization failed: {e}") from e


def deserialize_java_object(bytes_data):
    """Deserialize bytes into a Java object with ``InstantiationUtil.deserializeObject``.

    The current JVM thread's context class loader is used to resolve classes.
    This is intended as the counterpart of ``serialize_java_object``: both go
    through ``InstantiationUtil``.

    Args:
        bytes_data: the serialized object bytes.

    Returns:
        The deserialized Java object proxy.
    """
    gateway = get_gateway()
    jvm = gateway.jvm
    cl = jvm.Thread.currentThread().getContextClassLoader()
    util = jvm.org.apache.paimon.utils.InstantiationUtil
    return util.deserializeObject(bytes_data, cl)
