# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import atexit
import linecache
import os
import uuid
from typing import List, Callable, Union

from pyfury.resolver import NULL_FLAG, NOT_NULL_VALUE_FLAG
from pyfury.error import CompileError


_type_mapping = {
    bool: ("write_bool", "read_bool", "write_nullable_pybool", "read_nullable_pybool"),
    int: (
        "write_varint64",
        "read_varint64",
        "write_nullable_pyint64",
        "read_nullable_pyint64",
    ),
    float: (
        "write_double",
        "read_double",
        "write_nullable_pyfloat64",
        "read_nullable_pyfloat64",
    ),
    str: ("write_string", "read_string", "write_nullable_pystr", "read_nullable_pystr"),
}


def gen_write_nullable_basic_stmts(
    buffer: str,
    value: str,
    type_: type,
) -> List[str]:
    methods = _type_mapping[type_]
    from pyfury import ENABLE_FURY_CYTHON_SERIALIZATION

    if ENABLE_FURY_CYTHON_SERIALIZATION:
        return [f"{methods[2]}({buffer}, {value})"]
    return [
        f"if {value} is None:",
        f"    {buffer}.write_int8({NULL_FLAG})",
        "else: ",
        f"    {buffer}.write_int8({NOT_NULL_VALUE_FLAG})",
        f"    {buffer}.{methods[0]}({value})",
    ]


def gen_read_nullable_basic_stmts(
    buffer: str,
    type_: type,
    set_action: Callable[[str], str],
) -> List[str]:
    methods = _type_mapping[type_]
    from pyfury import ENABLE_FURY_CYTHON_SERIALIZATION

    if ENABLE_FURY_CYTHON_SERIALIZATION:
        return [set_action(f"{methods[3]}({buffer})")]

    read_value = f"{buffer}.{methods[1]}()"
    return [
        f"if {buffer}.read_int8() == {NULL_FLAG}:",
        f"    {set_action('None')}",
        "else: ",
        f"    {set_action(read_value)}",
    ]


def compile_function(
    function_name: str,
    params: List[str],
    stmts: List[str],
    context: dict,
):
    from pyfury import ENABLE_FURY_CYTHON_SERIALIZATION

    if ENABLE_FURY_CYTHON_SERIALIZATION:
        from pyfury import _serialization

        context["write_nullable_pybool"] = _serialization.write_nullable_pybool
        context["read_nullable_pybool"] = _serialization.read_nullable_pybool
        context["write_nullable_pyint64"] = _serialization.write_nullable_pyint64
        context["read_nullable_pyint64"] = _serialization.read_nullable_pyint64
        context["write_nullable_pyfloat64"] = _serialization.write_nullable_pyfloat64
        context["read_nullable_pyfloat64"] = _serialization.read_nullable_pyfloat64
        context["write_nullable_pystr"] = _serialization.write_nullable_pystr
        context["read_nullable_pystr"] = _serialization.read_nullable_pystr
    stmts = [f"{ident(statement)}" for statement in stmts]
    stmts.insert(0, f"def {function_name}({', '.join(params)}):")
    stmts = [f"{statement}  # line {idx + 1}" for idx, statement in enumerate(stmts)]
    code = "\n".join(stmts)
    filename = _generate_filename(function_name)
    code_dir = _get_code_dir()
    if code_dir:
        filename = os.path.join(code_dir, filename)
        with open(filename, "w") as f:
            f.write(code)
            f.flush()
        if _delete_code_on_exit():
            atexit.register(os.remove, filename)
    try:
        compiled = compile(code, filename, "exec")
    except Exception as e:
        raise CompileError(f"Failed to compile code:\n{code}") from e
    exec(compiled, context, context)
    # See https://stackoverflow.com/questions/64879414/how-does-attrs-fool-the-debugger-to-step-into-auto-generated-code # noqa: E501
    # In order of debuggers like PDB being able to step through the code,
    # we add a fake linecache entry.
    linecache.cache[filename] = (
        len(code),
        None,
        code.splitlines(True),
        filename,
    )
    return code, context[function_name]


# Based on https://github.com/python-attrs/attrs/blob/32fb12789e5cba4b2e71c09e47196b10763ddd7d/src/attr/_make.py#L1863 # noqa: E501
def _generate_filename(func_name):
    """
    Create a "filename" suitable for a function being generated.
    """
    unique_id = uuid.uuid4()
    extra = "0"
    count = 1

    while True:
        filename = f"fury_generated_{func_name}_{extra}.py"
        # To handle concurrency we essentially "reserve" our spot in
        # the linecache with a dummy line.  The caller can then
        # set this value correctly.
        cache_line = (1, None, [str(unique_id)], filename)
        if linecache.cache.setdefault(filename, cache_line) == cache_line:
            return filename

        # Looks like this spot is taken. Try again.
        count += 1
        extra = "{0}".format(count)


def _get_code_dir():
    code_dir = os.environ.get("FURY_CODE_DIR")
    if code_dir is not None and not os.path.exists(code_dir):
        os.makedirs(code_dir)
    return code_dir


def _delete_code_on_exit():
    return os.environ.get("DELETE_CODE_ON_EXIT", "True").lower() in ("true", "1")


def ident_lines(lines: Union[List[str], str]):
    is_str = type(lines) is str
    if is_str:
        lines = lines.split("\n")
    lines = [ident(line) for line in lines]
    return lines if not is_str else "\n".join(lines)


def ident(line: str):
    assert type(line) is str, type(line)
    return " " * 4 + line
