# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import itertools
import json
import os
import shutil
from collections import OrderedDict
from queue import Queue
from threading import Thread
from typing import Any, Dict, List, Optional, Tuple, Union

import pyarrow as pa
import torch
from torch import distributed as dist
from torch import nn, optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from torchrec.distributed.train_pipeline import TrainPipelineSparseDist  # NOQA
from torchrec.inference.modules import quantize_embeddings
from torchrec.optim.apply_optimizer_in_backward import (
    apply_optimizer_in_backward,  # NOQA
)
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
from torchrec.optim.optimizers import in_backward_optimizer_filter

from tzrec.acc.aot_utils import export_model_aot
from tzrec.acc.trt_utils import export_model_trt, get_trt_max_batch_size
from tzrec.acc.utils import (
    export_acc_config,
    is_aot,
    is_cuda_export,
    is_input_tile_emb,
    is_quant,
    is_trt,
    is_trt_predict,
    quant_dtype,
    write_mapping_file_for_input_tile,
)
from tzrec.constant import PREDICT_QUEUE_TIMEOUT, Mode
from tzrec.datasets.dataset import BaseDataset, BaseWriter, create_writer
from tzrec.datasets.utils import Batch, RecordBatchTensor
from tzrec.features.feature import (
    BaseFeature,
    create_feature_configs,
    create_features,
    create_fg_json,
)
from tzrec.models.match_model import (
    MatchModel,
    MatchTower,
    MatchTowerWoEG,
    TowerWoEGWrapper,
    TowerWrapper,
)
from tzrec.models.model import BaseModel, CudaExportWrapper, ScriptWrapper, TrainWrapper
from tzrec.models.tdm import TDM, TDMEmbedding
from tzrec.modules.embedding import EmbeddingGroup
from tzrec.modules.utils import BaseModule
from tzrec.ops import Kernel
from tzrec.optim import optimizer_builder
from tzrec.optim.lr_scheduler import BaseLR
from tzrec.protos.data_pb2 import DataConfig, DatasetType
from tzrec.protos.eval_pb2 import EvalConfig
from tzrec.protos.feature_pb2 import FeatureConfig
from tzrec.protos.model_pb2 import Kernel as KernelProto
from tzrec.protos.model_pb2 import ModelConfig
from tzrec.protos.pipeline_pb2 import EasyRecConfig
from tzrec.protos.train_pb2 import TrainConfig
from tzrec.utils import checkpoint_util, config_util
from tzrec.utils.dist_util import DistributedModelParallel
from tzrec.utils.fx_util import symbolic_trace
from tzrec.utils.logging_util import ProgressLogger, logger
from tzrec.utils.plan_util import create_planner, get_default_sharders
from tzrec.utils.state_dict_util import fix_mch_state, init_parameters
from tzrec.version import __version__ as tzrec_version


def init_process_group() -> Tuple[torch.device, str]:
    """Init process_group, device, rank, backend."""
    rank = int(os.environ.get("LOCAL_RANK", 0))
    if torch.cuda.is_available():
        device: torch.device = torch.device(f"cuda:{rank}")
        backend = "nccl"
        torch.cuda.set_device(device)
    else:
        device: torch.device = torch.device("cpu")
        backend = "gloo"
    dist.init_process_group(backend=backend)
    return device, backend


def _create_features(
    feature_configs: List[FeatureConfig], data_config: DataConfig
) -> List[BaseFeature]:
    neg_fields = None
    if data_config.HasField("sampler"):
        sampler_type = data_config.WhichOneof("sampler")
        if sampler_type != "tdm_sampler":
            neg_fields = list(getattr(data_config, sampler_type).attr_fields)

    features = create_features(
        feature_configs,
        fg_mode=data_config.fg_mode,
        neg_fields=neg_fields,
        fg_encoded_multival_sep=data_config.fg_encoded_multival_sep,
        force_base_data_group=data_config.force_base_data_group,
    )
    return features


def _get_dataloader(
    data_config: DataConfig,
    features: List[BaseFeature],
    input_path: str,
    reserved_columns: Optional[List[str]] = None,
    mode: Mode = Mode.TRAIN,
    gl_cluster: Optional[Dict[str, Union[int, str]]] = None,
    debug_level: int = 0,
) -> DataLoader:
    """Build dataloader.

    Args:
        data_config (DataConfig): dataloader config.
        features (list): list of feature.
        input_path (str): input data path.
        reserved_columns (list): reserved columns in predict mode.
        mode (Mode): train or eval or predict.
        gl_cluster (dict, optional): if set, reuse the graphlearn cluster.
        debug_level (int): dataset debug level; when mode=predict and
            debug_level > 0, fg encoded data will be dumped to debug_str.

    Return:
        dataloader (dataloader): a DataLoader.
    """
    dataset_name = DatasetType.Name(data_config.dataset_type)
    # pyre-ignore [16]
    dataset_cls = BaseDataset.create_class(dataset_name)
    dataset = dataset_cls(
        data_config=data_config,
        features=features,
        input_path=input_path,
        reserved_columns=reserved_columns,
        mode=mode,
        debug_level=debug_level,
    )

    kwargs = {}
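    # data_config.num_workers < 1 means in-process loading: leave kwargs empty so
    # the DataLoader runs in the main process. Otherwise, for file based datasets,
    # cap num_workers by the number of files available to each rank.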
    if data_config.num_workers < 1:
        num_workers = 1
    else:
        num_workers = data_config.num_workers
        # Check whether the number of files is valid for file based datasets.
        if hasattr(dataset._reader, "num_files"):
            num_files = dataset._reader.num_files()
            world_size = int(os.environ.get("WORLD_SIZE", 1))
            if num_files >= world_size:
                num_files_per_worker = num_files // world_size
                if num_files_per_worker < num_workers:
                    logger.info(
                        f"data_config.num_workers reset to {num_files_per_worker}"
                    )
                    num_workers = num_files_per_worker
            else:
                raise ValueError(
                    f"Number of files in the dataset[{input_path}] must be greater "
                    f"than or equal to world_size: {world_size}, but got {num_files}"
                )

        kwargs["num_workers"] = num_workers
        kwargs["persistent_workers"] = True

    if mode == Mode.TRAIN:
        # In train_and_eval mode, use 2x workers in the gl cluster to serve
        # both train_dataloader and eval_dataloader.
        dataset.launch_sampler_cluster(num_client_per_rank=num_workers * 2)
    else:
        if gl_cluster:
            # Reuse the gl cluster for eval_dataloader
            dataset.launch_sampler_cluster(
                num_client_per_rank=num_workers * 2,
                client_id_bias=num_workers,
                cluster=gl_cluster,
            )
        else:
            dataset.launch_sampler_cluster(num_client_per_rank=num_workers)

    dataloader = DataLoader(
        dataset=dataset,
        batch_size=None,
        pin_memory=data_config.pin_memory if mode != Mode.PREDICT else False,
        collate_fn=lambda x: x,
        **kwargs,
    )
    # For PyTorch versions 2.6 and above, we initialize the data iterator before
    # beginning the training process to avoid potential CUDA-related issues following
    # model saving.
    iter(dataloader)
    return dataloader


def _create_model(
    model_config: ModelConfig,
    features: List[BaseFeature],
    labels: List[str],
    sample_weights: Optional[List[str]] = None,
) -> BaseModel:
    """Build model.

    Args:
        model_config (ModelConfig): EasyRec model config.
        features (list): list of features.
        labels (list): list of label names.
        sample_weights (list): list of sample weight names.

    Return:
        model: an EasyRec model.
    """
    model_cls_name = config_util.which_msg(model_config, "model")
    # pyre-ignore [16]
    model_cls = BaseModel.create_class(model_cls_name)

    model: BaseModel = model_cls(
        model_config, features, labels, sample_weights=sample_weights
    )

    kernel = Kernel[KernelProto.Name(model_config.kernel)]
    model.set_kernel(kernel)
    return model


def _evaluate(
    model: nn.Module,
    eval_dataloader: DataLoader,
    eval_config: EvalConfig,
    eval_result_filename: Optional[str] = None,
    global_step: Optional[int] = None,
    eval_summary_writer: Optional[SummaryWriter] = None,
    global_epoch: Optional[int] = None,
) -> Dict[str, torch.Tensor]:
    """Evaluate the model."""
    is_rank_zero = int(os.environ.get("RANK", 0)) == 0
    is_local_rank_zero = int(os.environ.get("LOCAL_RANK", 0)) == 0
    model.eval()
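    # Reuse the sparse-dist train pipeline for evaluation; forward-only passes
    # need no optimizer, so None is passed.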
    pipeline = TrainPipelineSparseDist(
        model,
        # pyre-fixme [6]
        None,
        model.device,
        execute_all_batches=True,
    )

    use_step = eval_config.num_steps and eval_config.num_steps > 0
    iterator = iter(eval_dataloader)
    step_iter = range(eval_config.num_steps) if use_step else itertools.count(0)

    desc_suffix = ""
    if global_epoch:
        desc_suffix += f" Epoch-{global_epoch}"
    if global_step:
        desc_suffix += f" model-{global_step}"
    _model = model.module.model

    plogger = None
    if is_local_rank_zero:
        plogger = ProgressLogger(desc=f"Evaluating{desc_suffix}")

    with torch.no_grad():
        i_step = 0
        for i_step in step_iter:
            try:
                losses, predictions, batch = pipeline.progress(iterator)
                _model.update_metric(predictions, batch, losses)
                if (
                    plogger is not None
                    and i_step % eval_config.log_step_count_steps == 0
                ):
                    plogger.log(i_step)
            except StopIteration:
                break
        if plogger is not None:
            plogger.log(i_step)

    metric_result = _model.compute_metric()

    if is_rank_zero:
        metric_str = " ".join([f"{k}:{v:0.6f}" for k, v in metric_result.items()])
        logger.info(f"Eval Result{desc_suffix}: {metric_str}")
        metric_result = {k: v.item() for k, v in metric_result.items()}
        if eval_result_filename:
            with open(eval_result_filename, "w") as f:
                json.dump(metric_result, f, indent=4)
        if eval_summary_writer:
            for k, v in metric_result.items():
                eval_summary_writer.add_scalar(f"metric/{k}", v, global_step or 0)
    return metric_result


def _log_train(
    step: int,
    losses: Dict[str, torch.Tensor],
    param_groups: List[Dict[str, Any]],
    plogger: Optional[ProgressLogger] = None,
    summary_writer: Optional[SummaryWriter] = None,
) -> None:
    """Logging current training step."""
    if plogger is not None:
        loss_strs = []
        lr_strs = []
        for k, v in losses.items():
            loss_strs.append(f"{k}:{v:.5f}")
        for i, g in enumerate(param_groups):
            lr_strs.append(f"lr_g{i}:{g['lr']:.5f}")
        total_loss = sum(losses.values())
        plogger.log(
            step,
            f"{' '.join(lr_strs)} {' '.join(loss_strs)} total_loss: {total_loss:.5f}",
        )
    if summary_writer is not None:
        total_loss = sum(losses.values())
        for k, v in losses.items():
            summary_writer.add_scalar(f"loss/{k}", v, step)
        summary_writer.add_scalar("loss/total", total_loss, step)
        for i, g in enumerate(param_groups):
            summary_writer.add_scalar(f"lr/g{i}", g["lr"], step)


def _train_and_evaluate(
    model: nn.Module,
    optimizer: optim.Optimizer,
    train_dataloader: DataLoader,
    eval_dataloader: Optional[DataLoader],
    lr_scheduler: List[BaseLR],
    model_dir: str,
    train_config: TrainConfig,
    eval_config: EvalConfig,
    skip_steps: int = -1,
    ckpt_path: Optional[str] = None,
    eval_result_filename: str = "train_eval_result.txt",
) -> None:
    """Train and evaluate the model."""
    is_rank_zero = int(os.environ.get("RANK", 0)) == 0
    is_local_rank_zero = int(os.environ.get("LOCAL_RANK", 0)) == 0
    model.train()

    assert (train_config.num_steps > 0) ^ (train_config.num_epochs > 0), (
        "either train_config.num_epochs or train_config.num_steps should be set, "
        "but not both at the same time."
    )
    use_epoch = train_config.num_epochs and train_config.num_epochs > 0
    use_step = train_config.num_steps and train_config.num_steps > 0
    epoch_iter = range(train_config.num_epochs) if use_epoch else itertools.count(0, 0)
    step_iter = range(train_config.num_steps) if use_step else itertools.count(0)

    save_checkpoints_steps, save_checkpoints_epochs = 0, 0
    if train_config.save_checkpoints_epochs > 0:
        save_checkpoints_epochs = train_config.save_checkpoints_epochs
    else:
        save_checkpoints_steps = train_config.save_checkpoints_steps

    plogger = None
    summary_writer = None
    eval_summary_writer = None
    if is_local_rank_zero:
        plogger = ProgressLogger(desc="Training Epoch 0", start_n=skip_steps)
    if is_rank_zero and train_config.use_tensorboard:
        summary_writer = SummaryWriter(model_dir)
        eval_summary_writer = SummaryWriter(os.path.join(model_dir, "eval_val"))
    eval_result_filename = os.path.join(model_dir, eval_result_filename)

    if train_config.is_profiling:
        if is_rank_zero:
            logger.info(str(model))
        prof = torch.profiler.profile(
            schedule=torch.profiler.schedule(wait=1, warmup=1, active=10, repeat=1),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(
                os.path.join(model_dir, "train_eval_trace")
            ),
            record_shapes=True,
            profile_memory=True,
            with_stack=True,
        )
        prof.start()

    last_ckpt_step = -1
    i_step = 0
    i_epoch = 0
    losses = {}
    for i_epoch in epoch_iter:
        pipeline = TrainPipelineSparseDist(
            model, optimizer, model.device, execute_all_batches=True
        )
        if plogger is not None:
            plogger.set_description(f"Training Epoch {i_epoch}")

        train_iterator = iter(train_dataloader)

        # Restore model and optimizer checkpoint. Because the optimizer's state
        # is lazily initialized, we run a dummy step before restoring.
        if i_step == 0 and ckpt_path is not None:
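            # Pull one batch for the dummy step, then chain it back into the
            # iterator so that no data is skipped.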
            peek_batch = next(train_iterator)
            pipeline.progress(iter([peek_batch]))
            train_iterator = itertools.chain([peek_batch], train_iterator)
            checkpoint_util.restore_model(
                ckpt_path, model, optimizer, train_config.fine_tune_ckpt_param_map
            )

        for i_step in step_iter:
            if i_step <= skip_steps:
                continue
            try:
                losses, _, _ = pipeline.progress(train_iterator)

                if i_step % train_config.log_step_count_steps == 0:
                    _log_train(
                        i_step,
                        losses,
                        param_groups=optimizer.param_groups,
                        plogger=plogger,
                        summary_writer=summary_writer,
                    )

                for lr in lr_scheduler:
                    if not lr.by_epoch:
                        lr.step()
            except StopIteration:
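                # Dataset exhausted before reaching num_steps: re-queue the
                # current step for the next epoch and roll i_step back to the
                # last completed step.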
                step_iter = itertools.chain([i_step], step_iter)
                i_step -= 1
                break

            if save_checkpoints_steps > 0 and i_step > 0:
                if i_step % save_checkpoints_steps == 0:
                    last_ckpt_step = i_step
                    checkpoint_util.save_model(
                        os.path.join(model_dir, f"model.ckpt-{i_step}"),
                        model,
                        optimizer,
                    )
                    if eval_dataloader is not None:
                        _evaluate(
                            model,
                            eval_dataloader,
                            eval_config,
                            eval_result_filename=eval_result_filename,
                            global_step=i_step,
                            eval_summary_writer=eval_summary_writer,
                            global_epoch=i_epoch,
                        )
                        model.train()
            if train_config.is_profiling:
                prof.step()

        if save_checkpoints_epochs > 0 and i_step > 0:
            if i_epoch % save_checkpoints_epochs == 0:
                last_ckpt_step = i_step
                checkpoint_util.save_model(
                    os.path.join(model_dir, f"model.ckpt-{i_step}"),
                    model,
                    optimizer,
                )
                if eval_dataloader is not None:
                    _evaluate(
                        model,
                        eval_dataloader,
                        eval_config,
                        eval_result_filename=eval_result_filename,
                        global_step=i_step,
                        eval_summary_writer=eval_summary_writer,
                        global_epoch=i_epoch,
                    )
                    model.train()

        if use_step and i_step >= train_config.num_steps - 1:
            break

        for lr in lr_scheduler:
            if lr.by_epoch:
                lr.step()

    _log_train(
        i_step,
        losses,
        param_groups=optimizer.param_groups,
        plogger=plogger,
        summary_writer=summary_writer,
    )
    if summary_writer is not None:
        summary_writer.close()
    if train_config.is_profiling:
        prof.stop()
    if last_ckpt_step != i_step:
        checkpoint_util.save_model(
            os.path.join(model_dir, f"model.ckpt-{i_step}"),
            model,
            optimizer,
        )
        if eval_dataloader is not None:
            _evaluate(
                model,
                eval_dataloader,
                eval_config,
                eval_result_filename=eval_result_filename,
                global_step=i_step,
                eval_summary_writer=eval_summary_writer,
                global_epoch=i_epoch,
            )
            model.train()


def train_and_evaluate(
    pipeline_config_path: str,
    train_input_path: Optional[str] = None,
    eval_input_path: Optional[str] = None,
    model_dir: Optional[str] = None,
    continue_train: Optional[bool] = True,
    fine_tune_checkpoint: Optional[str] = None,
    edit_config_json: Optional[str] = None,
) -> None:
    """Train and evaluate a EasyRec model.

    Args:
        pipeline_config_path (str): path to EasyRecConfig object.
        train_input_path (str, optional): train data path.
        eval_input_path (str, optional): eval data path.
        model_dir (str, optional): model directory.
        continue_train (bool, optional): whether to resume training from
            an existing checkpoint in model_dir.
        fine_tune_checkpoint (str, optional): path to an existing
            finetune checkpoint.
        edit_config_json (str, optional): edit pipeline config json str.
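
    Example (illustrative; the config and data paths below are placeholders, and
    the job is normally launched with a distributed launcher such as torchrun):

        train_and_evaluate(
            pipeline_config_path="examples/dlrm.config",
            train_input_path="data/train.parquet",
            eval_input_path="data/eval.parquet",
            model_dir="experiments/dlrm",
        )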
    """
    pipeline_config = config_util.load_pipeline_config(pipeline_config_path)
    if fine_tune_checkpoint:
        pipeline_config.train_config.fine_tune_checkpoint = fine_tune_checkpoint
    if train_input_path:
        pipeline_config.train_input_path = train_input_path
    if eval_input_path:
        pipeline_config.eval_input_path = eval_input_path
    if model_dir:
        pipeline_config.model_dir = model_dir
    if edit_config_json:
        edit_config_json = json.loads(edit_config_json)
        config_util.edit_config(pipeline_config, edit_config_json)

    device, _ = init_process_group()
    is_rank_zero = int(os.environ.get("RANK", 0)) == 0
    is_local_rank_zero = int(os.environ.get("LOCAL_RANK", 0)) == 0

    data_config = pipeline_config.data_config
    # Build feature
    features = _create_features(list(pipeline_config.feature_configs), data_config)

    # Build dataloader
    train_dataloader = _get_dataloader(
        data_config, features, pipeline_config.train_input_path, mode=Mode.TRAIN
    )
    eval_dataloader = None
    if pipeline_config.eval_input_path:
        # pyre-ignore [16]
        gl_cluster = train_dataloader.dataset.get_sampler_cluster()
        eval_dataloader = _get_dataloader(
            data_config,
            features,
            pipeline_config.eval_input_path,
            mode=Mode.EVAL,
            gl_cluster=gl_cluster,
        )

    # Build model
    model = _create_model(
        pipeline_config.model_config,
        features,
        list(data_config.label_fields),
        sample_weights=list(data_config.sample_weight_fields),
    )
    model = TrainWrapper(model)

    sparse_optim_cls, sparse_optim_kwargs = optimizer_builder.create_sparse_optimizer(
        pipeline_config.train_config.sparse_optimizer
    )
    apply_optimizer_in_backward(
        sparse_optim_cls, model.model.sparse_parameters(), sparse_optim_kwargs
    )

    planner = create_planner(
        device=device,
        # pyre-ignore [16]
        batch_size=train_dataloader.dataset.sampled_batch_size,
    )

    plan = planner.collective_plan(
        model, get_default_sharders(), dist.GroupMember.WORLD
    )
    if is_rank_zero:
        logger.info(str(plan))

    model = DistributedModelParallel(
        module=model,
        device=device,
        plan=plan,
    )

    dense_optim_cls, dense_optim_kwargs = optimizer_builder.create_dense_optimizer(
        pipeline_config.train_config.dense_optimizer
    )
    dense_optimizer = KeyedOptimizerWrapper(
        dict(in_backward_optimizer_filter(model.named_parameters())),
        lambda params: dense_optim_cls(params, **dense_optim_kwargs),
    )
    optimizer = CombinedOptimizer([model.fused_optimizer, dense_optimizer])
    sparse_lr = optimizer_builder.create_scheduler(
        model.fused_optimizer, pipeline_config.train_config.sparse_optimizer
    )
    dense_lr = optimizer_builder.create_scheduler(
        dense_optimizer, pipeline_config.train_config.dense_optimizer
    )

    ckpt_path = None
    skip_steps = -1
    if pipeline_config.train_config.fine_tune_checkpoint:
        ckpt_path, _ = checkpoint_util.latest_checkpoint(
            pipeline_config.train_config.fine_tune_checkpoint
        )
        if ckpt_path is None or not os.path.exists(ckpt_path):
            raise RuntimeError(
                "fine_tune_checkpoint["
                f"{pipeline_config.train_config.fine_tune_checkpoint}] does not exist."
            )
    if os.path.exists(pipeline_config.model_dir):
        # TODO(hongsheng.jhs): save and restore dataloader state.
        latest_ckpt_path, skip_steps = checkpoint_util.latest_checkpoint(
            pipeline_config.model_dir
        )
        if latest_ckpt_path:
            if continue_train:
                ckpt_path = latest_ckpt_path
            else:
                raise RuntimeError(
                    f"model_dir[{pipeline_config.model_dir}] already exists and is "
                    "not empty. If you want to continue training on the current "
                    "model_dir, please specify --continue_train; otherwise delete "
                    "the model_dir first."
                )

    # Use a barrier to sync all workers. Otherwise rank zero may save the config
    # and create model_dir first, and a slower rank would then find that model_dir
    # already exists and improperly enter continue-train mode.
    dist.barrier()

    if is_rank_zero:
        config_util.save_message(
            pipeline_config, os.path.join(pipeline_config.model_dir, "pipeline.config")
        )
        with open(os.path.join(pipeline_config.model_dir, "version"), "w") as f:
            f.write(tzrec_version + "\n")

    _train_and_evaluate(
        model,
        optimizer,
        train_dataloader,
        eval_dataloader,
        [sparse_lr, dense_lr],
        pipeline_config.model_dir,
        train_config=pipeline_config.train_config,
        eval_config=pipeline_config.eval_config,
        skip_steps=skip_steps,
        ckpt_path=ckpt_path,
    )
    if is_local_rank_zero:
        logger.info("Train and Evaluate Finished.")


def evaluate(
    pipeline_config_path: str,
    checkpoint_path: Optional[str] = None,
    eval_input_path: Optional[str] = None,
    eval_result_filename: str = "eval_result.txt",
) -> None:
    """Evaluate a EasyRec model.

    Args:
        pipeline_config_path (str): path to EasyRecConfig object.
        checkpoint_path (str, optional): if specified, use this checkpoint instead
            of the latest one under model_dir in pipeline_config_path.
        eval_input_path (str, optional): eval data path; defaults to the eval data
            in pipeline_config. Could be a path or a list of paths.
        eval_result_filename (str): evaluation result metrics save path.
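
    Example (illustrative; paths are placeholders, and the job is normally
    launched with a distributed launcher such as torchrun):

        evaluate(
            pipeline_config_path="experiments/dlrm/pipeline.config",
            eval_input_path="data/eval.parquet",
        )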
    """
    pipeline_config = config_util.load_pipeline_config(pipeline_config_path)

    device, _ = init_process_group()
    is_rank_zero = int(os.environ.get("RANK", 0)) == 0
    is_local_rank_zero = int(os.environ.get("LOCAL_RANK", 0)) == 0

    data_config = pipeline_config.data_config
    # Build feature
    features = _create_features(list(pipeline_config.feature_configs), data_config)

    eval_dataloader = _get_dataloader(
        data_config,
        features,
        eval_input_path or pipeline_config.eval_input_path,
        mode=Mode.EVAL,
    )

    # Build model
    model = _create_model(
        pipeline_config.model_config,
        features,
        list(data_config.label_fields),
        sample_weights=list(data_config.sample_weight_fields),
    )
    model = TrainWrapper(model)

    planner = create_planner(
        device=device,
        # pyre-ignore [16]
        batch_size=eval_dataloader.dataset.sampled_batch_size,
    )
    plan = planner.collective_plan(
        model, get_default_sharders(), dist.GroupMember.WORLD
    )
    if is_rank_zero:
        logger.info(str(plan))

    model = DistributedModelParallel(module=model, device=device, plan=plan)

    global_step = None
    if not checkpoint_path:
        checkpoint_path, global_step = checkpoint_util.latest_checkpoint(
            pipeline_config.model_dir
        )
    if checkpoint_path:
        checkpoint_util.restore_model(checkpoint_path, model)
    else:
        raise ValueError("Eval checkpoint path should be specified.")

    summary_writer = None
    if is_rank_zero:
        summary_writer = SummaryWriter(os.path.join(pipeline_config.model_dir, "eval"))
    _evaluate(
        model,
        eval_dataloader,
        eval_config=pipeline_config.eval_config,
        eval_result_filename=os.path.join(
            pipeline_config.model_dir, eval_result_filename
        ),
        global_step=global_step,
        eval_summary_writer=summary_writer,
    )
    if is_local_rank_zero:
        logger.info("Evaluate Finished.")


def _script_model(
    pipeline_config: EasyRecConfig,
    model: BaseModule,
    state_dict: Optional[Dict[str, Any]],
    dataloader: DataLoader,
    save_dir: str,
) -> None:
    is_rank_zero = int(os.environ.get("RANK", 0)) == 0
    is_trt_convert = is_trt()
    if is_rank_zero:
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        model.set_is_inference(True)
        if state_dict is not None:
            model.to_empty(device="cpu")
            model.load_state_dict(state_dict, strict=False)

        # for mc modules, fix output_segments_tensor is a meta tensor.
        fix_mch_state(model)

        batch = next(iter(dataloader))

        if is_cuda_export():
            model = model.cuda()

        if is_quant():
            logger.info("quantize embeddings...")
            quantize_embeddings(model, dtype=quant_dtype(), inplace=True)
            logger.info("finish quantize embeddings...")

        model.eval()

        if is_trt_convert:
            data_cuda = batch.to_dict(sparse_dtype=torch.int64)
            result = model(data_cuda, "cuda:0")
            result_info = {k: (v.size(), v.dtype) for k, v in result.items()}
            logger.info(f"Model Outputs: {result_info}")
            export_model_trt(model, data_cuda, save_dir)

        elif is_aot():
            data_cuda = batch.to_dict(sparse_dtype=torch.int64)
            result = model(data_cuda)
            export_model_aot(model, data_cuda, save_dir)
        else:
            data = batch.to_dict(sparse_dtype=torch.int64)
            result = model(data)
            result_info = {k: (v.size(), v.dtype) for k, v in result.items()}
            logger.info(f"Model Outputs: {result_info}")

            gm = symbolic_trace(model)
            with open(os.path.join(save_dir, "gm.code"), "w") as f:
                f.write(gm.code)

            scripted_model = torch.jit.script(gm)
            scripted_model.save(os.path.join(save_dir, "scripted_model.pt"))

        features = model._features
        feature_configs = create_feature_configs(features, asset_dir=save_dir)
        pipeline_config = copy.copy(pipeline_config)
        pipeline_config.ClearField("feature_configs")
        pipeline_config.feature_configs.extend(feature_configs)
        config_util.save_message(
            pipeline_config, os.path.join(save_dir, "pipeline.config")
        )
        logger.info("saving fg json...")
        fg_json = create_fg_json(features, asset_dir=save_dir)
        with open(os.path.join(save_dir, "fg.json"), "w") as f:
            json.dump(fg_json, f, indent=4)
        with open(os.path.join(save_dir, "model_acc.json"), "w") as f:
            json.dump(export_acc_config(), f, indent=4)


def export(
    pipeline_config_path: str,
    export_dir: str,
    checkpoint_path: Optional[str] = None,
    asset_files: Optional[str] = None,
) -> None:
    """Export a EasyRec model.

    Args:
        pipeline_config_path (str): path to EasyRecConfig object.
        export_dir (str): base directory where the model should be exported.
        checkpoint_path (str, optional): if specified, use this checkpoint instead
            of the latest one under model_dir in pipeline_config_path.
        asset_files (str, optional): comma-separated list of extra files to copy
            into export_dir.
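
    Example (illustrative; paths and the asset file are placeholders):

        export(
            pipeline_config_path="experiments/dlrm/pipeline.config",
            export_dir="experiments/dlrm/export",
            asset_files="data/id_vocab.txt",
        )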
    """
    is_rank_zero = int(os.environ.get("RANK", 0)) == 0
    if not is_rank_zero:
        logger.warning("Only first rank will be used for export now.")
        return
    else:
        if os.environ.get("WORLD_SIZE") != "1":
            logger.warning(
                "export only support WORLD_SIZE=1 now, we set WORLD_SIZE to 1."
            )
            os.environ["WORLD_SIZE"] = "1"

    pipeline_config = config_util.load_pipeline_config(pipeline_config_path)
    ori_pipeline_config = copy.copy(pipeline_config)

    dist.init_process_group("gloo")
    if is_rank_zero:
        if os.path.exists(export_dir):
            raise RuntimeError(f"directory {export_dir} already exist.")

    assets = []
    if asset_files:
        assets = asset_files.split(",")

    data_config = pipeline_config.data_config
    is_trt_convert = is_trt()
    if is_trt_convert:
        # A too-large export batch_size may cause OOM in the trt convert phase.
        max_batch_size = get_trt_max_batch_size()
        data_config.batch_size = min(data_config.batch_size, max_batch_size)
        logger.info("using new batch_size: %s in trt export", data_config.batch_size)

    # Build feature
    features = _create_features(list(pipeline_config.feature_configs), data_config)

    # Build the dataloader first so the data parser can resolve user feats
    # before creating the model.
    data_config.num_workers = 1
    dataloader = _get_dataloader(
        data_config, features, pipeline_config.train_input_path, mode=Mode.PREDICT
    )

    # Build model
    model = _create_model(
        pipeline_config.model_config,
        features,
        list(data_config.label_fields),
    )
    InferWrapper = CudaExportWrapper if is_aot() else ScriptWrapper
    model = InferWrapper(model)
    init_parameters(model, torch.device("cpu"))

    if not checkpoint_path:
        checkpoint_path, _ = checkpoint_util.latest_checkpoint(
            pipeline_config.model_dir
        )
    if checkpoint_path:
        if is_input_tile_emb():
            remap_file_path = os.path.join(export_dir, "emb_ckpt_mapping.txt")
            if is_rank_zero:
                if not os.path.exists(export_dir):
                    os.makedirs(export_dir)
                write_mapping_file_for_input_tile(model.state_dict(), remap_file_path)

            dist.barrier()
            checkpoint_util.restore_model(
                checkpoint_path, model, ckpt_param_map_path=remap_file_path
            )
        else:
            checkpoint_util.restore_model(checkpoint_path, model)
    else:
        raise ValueError("checkpoint path should be specified.")

    if isinstance(model.model, MatchModel):
        for name, module in model.model.named_children():
            if isinstance(module, MatchTower) or isinstance(module, MatchTowerWoEG):
                wrapper = (
                    TowerWrapper if isinstance(module, MatchTower) else TowerWoEGWrapper
                )
                tower = InferWrapper(wrapper(module, name))
                tower_export_dir = os.path.join(export_dir, name.replace("_tower", ""))
                _script_model(
                    ori_pipeline_config,
                    tower,
                    model.state_dict(),
                    dataloader,
                    tower_export_dir,
                )
                for asset in assets:
                    shutil.copy(asset, tower_export_dir)
    elif isinstance(model.model, TDM):
        for name, module in model.model.named_children():
            if isinstance(module, EmbeddingGroup):
                emb_module = InferWrapper(TDMEmbedding(module, name))
                _script_model(
                    ori_pipeline_config,
                    emb_module,
                    model.state_dict(),
                    dataloader,
                    os.path.join(export_dir, "embedding"),
                )
                break
        _script_model(
            ori_pipeline_config,
            model,
            None,
            dataloader,
            os.path.join(export_dir, "model"),
        )
        for asset in assets:
            shutil.copy(asset, os.path.join(export_dir, "model"))
    else:
        _script_model(
            ori_pipeline_config,
            model,
            None,
            dataloader,
            export_dir,
        )
        for asset in assets:
            shutil.copy(asset, export_dir)


def predict(
    predict_input_path: str,
    predict_output_path: str,
    scripted_model_path: str,
    reserved_columns: Optional[str] = None,
    output_columns: Optional[str] = None,
    batch_size: Optional[int] = None,
    is_profiling: bool = False,
    debug_level: int = 0,
    dataset_type: Optional[str] = None,
    predict_threads: Optional[int] = None,
    writer_type: Optional[str] = None,
    edit_config_json: Optional[str] = None,
) -> None:
    """Evaluate a EasyRec model.

    Args:
        predict_input_path (str): inference input data path.
        predict_output_path (str): inference output data path.
        scripted_model_path (str): path to scripted model.
        reserved_columns (str, optional): columns to reserve in the output.
        output_columns (str, optional): columns of model output.
        batch_size (int, optional): predict batch_size.
        is_profiling (bool): profiling predict process or not.
        debug_level (int, optional): debug level for debugging parsed inputs etc.
        dataset_type (str, optional): dataset type; defaults to the type in the
            pipeline config.
        predict_threads (int, optional): number of predict threads; defaults to
            num_workers in data_config.
        writer_type (str, optional): data writer type; defaults to the
            dataset_type in data_config.
        edit_config_json (str, optional): edit pipeline config json str.
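
    Example (illustrative; paths are placeholders, and scripted_model_path is
    expected to contain pipeline.config and scripted_model.pt):

        predict(
            predict_input_path="data/test.parquet",
            predict_output_path="data/test_output",
            scripted_model_path="experiments/dlrm/export",
            reserved_columns="user_id,item_id",
        )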
    """
    reserved_cols: Optional[List[str]] = None
    output_cols: Optional[List[str]] = None
    if reserved_columns is not None:
        reserved_cols = [x.strip() for x in reserved_columns.split(",")]
    if output_columns is not None:
        output_cols = [x.strip() for x in output_columns.split(",")]

    pipeline_config = config_util.load_pipeline_config(
        os.path.join(scripted_model_path, "pipeline.config"), allow_unknown_field=True
    )
    if batch_size:
        pipeline_config.data_config.batch_size = batch_size

    is_trt_convert: bool = is_trt_predict(scripted_model_path)
    if is_trt_convert:
        # A too-large predict batch_size may be out of the trt engine's range.
        max_batch_size = get_trt_max_batch_size()
        pipeline_config.data_config.batch_size = min(
            pipeline_config.data_config.batch_size, max_batch_size
        )
        logger.info(
            "using new batch_size: %s in trt predict",
            pipeline_config.data_config.batch_size,
        )

    if dataset_type:
        pipeline_config.data_config.dataset_type = getattr(DatasetType, dataset_type)
    if edit_config_json:
        edit_config_json = json.loads(edit_config_json)
        config_util.edit_config(pipeline_config, edit_config_json)

    device_and_backend = init_process_group()
    device: torch.device = device_and_backend[0]

    is_rank_zero = int(os.environ.get("RANK", 0)) == 0
    is_local_rank_zero = int(os.environ.get("LOCAL_RANK", 0)) == 0

    data_config: DataConfig = pipeline_config.data_config
    data_config.drop_remainder = False
    # Build feature
    features = _create_features(list(pipeline_config.feature_configs), data_config)

    infer_dataloader = _get_dataloader(
        data_config,
        features,
        predict_input_path,
        reserved_columns=reserved_cols,
        mode=Mode.PREDICT,
        debug_level=debug_level,
    )
    infer_iterator = iter(infer_dataloader)

    if writer_type is None:
        writer_type = DatasetType.Name(data_config.dataset_type).replace(
            "Dataset", "Writer"
        )
    writer: BaseWriter = create_writer(
        predict_output_path,
        writer_type,
        quota_name=data_config.odps_data_quota_name,
    )

    # Disable jit compile, as it compiles too slowly now.
    if "PYTORCH_TENSOREXPR_FALLBACK" not in os.environ:
        os.environ["PYTORCH_TENSOREXPR_FALLBACK"] = "2"

    model: torch.jit.ScriptModule = torch.jit.load(
        os.path.join(scripted_model_path, "scripted_model.pt"), map_location=device
    )
    model.eval()

    if is_local_rank_zero:
        plogger = ProgressLogger(desc="Predicting", miniters=10)

    if is_profiling:
        if is_rank_zero:
            logger.info(str(model))
        prof = torch.profiler.profile(
            schedule=torch.profiler.schedule(wait=1, warmup=1, active=10, repeat=1),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(
                os.path.join(scripted_model_path, "predict_trace")
            ),
            record_shapes=True,
            profile_memory=True,
            with_stack=True,
        )
        prof.start()

    if predict_threads is None:
        predict_threads = max(data_config.num_workers, 1)
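    # Producer/consumer pipeline: the main loop feeds batches into data_queue,
    # predict_threads forward threads run the model and push results into
    # pred_queue, and a single writer thread drains pred_queue into the writer.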
    data_queue: Queue[Optional[Batch]] = Queue(maxsize=predict_threads * 2)
    pred_queue: Queue[
        Tuple[Optional[Dict[str, torch.Tensor]], Optional[RecordBatchTensor]]
    ] = Queue(maxsize=predict_threads * 2)

    def _forward(batch: Batch) -> Tuple[Dict[str, torch.Tensor], RecordBatchTensor]:
        with torch.no_grad():
            parsed_inputs = batch.to_dict(sparse_dtype=torch.int64)
            # when predicting with a model exported using INPUT_TILE,
            #  we set the batch size tensor to 1 to disable tiling.
            parsed_inputs["batch_size"] = torch.tensor(1, dtype=torch.int64)
            if is_trt_convert:
                predictions = model(parsed_inputs)
            else:
                predictions = model(parsed_inputs, device)
            predictions = {k: v.to("cpu") for k, v in predictions.items()}
            return predictions, batch.reserves

    def _write(
        predictions: Dict[str, torch.Tensor],
        reserves: RecordBatchTensor,
        output_cols: List[str],
    ) -> None:
        output_dict = OrderedDict()
        for c in output_cols:
            v = predictions[c]
            v = v.tolist() if v.ndim > 1 else v.numpy()
            output_dict[c] = pa.array(v)
        reserve_batch_record = reserves.get()
        if reserve_batch_record is not None:
            for k, v in zip(
                reserve_batch_record.column_names, reserve_batch_record.columns
            ):
                output_dict[k] = v
        writer.write(output_dict)

    def _write_loop(output_cols: List[str]) -> None:
        while True:
            predictions, reserves = pred_queue.get(timeout=PREDICT_QUEUE_TIMEOUT)
            if predictions is None:
                break
            assert predictions is not None and reserves is not None
            _write(predictions, reserves, output_cols)

    def _forward_loop() -> None:
        while True:
            batch = data_queue.get(timeout=PREDICT_QUEUE_TIMEOUT)
            if batch is None:
                break
            assert batch is not None
            pred = _forward(batch)
            pred_queue.put(pred, timeout=PREDICT_QUEUE_TIMEOUT)

    forward_t_list = []
    write_t = None
    i_step = 0
    while True:
        try:
            batch = next(infer_iterator)

            if i_step == 0:
                # Lazily init the writer, then create the write and forward threads.
                predictions, reserves = _forward(batch)
                if output_cols is None:
                    output_cols = sorted(predictions.keys())
                _write(predictions, reserves, output_cols)
                for _ in range(predict_threads):
                    t = Thread(target=_forward_loop)
                    t.start()
                    forward_t_list.append(t)
                write_t = Thread(target=_write_loop, args=(output_cols,))
                write_t.start()
            else:
                data_queue.put(batch, timeout=PREDICT_QUEUE_TIMEOUT)

            if is_local_rank_zero:
                plogger.log(i_step)
            if is_profiling:
                prof.step()
            i_step += 1
        except StopIteration:
            break

    for _ in range(predict_threads):
        data_queue.put(None, timeout=PREDICT_QUEUE_TIMEOUT)
    for t in forward_t_list:
        t.join()
    pred_queue.put((None, None), timeout=PREDICT_QUEUE_TIMEOUT)
    assert write_t is not None
    write_t.join()
    writer.close()

    if is_profiling:
        prof.stop()
    if is_local_rank_zero:
        logger.info("Predict Finished.")
