# tzrec/main.py
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import itertools
import json
import os
import shutil
from collections import OrderedDict
from queue import Queue
from threading import Thread
from typing import Any, Dict, List, Optional, Tuple, Union
import pyarrow as pa
import torch
from torch import distributed as dist
from torch import nn, optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchrec.distributed.train_pipeline import TrainPipelineSparseDist  # NOQA
from torchrec.inference.modules import quantize_embeddings
from torchrec.optim.apply_optimizer_in_backward import (
apply_optimizer_in_backward, # NOQA
)
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
from torchrec.optim.optimizers import in_backward_optimizer_filter
from tzrec.acc.aot_utils import export_model_aot
from tzrec.acc.trt_utils import export_model_trt, get_trt_max_batch_size
from tzrec.acc.utils import (
export_acc_config,
is_aot,
is_cuda_export,
is_input_tile_emb,
is_quant,
is_trt,
is_trt_predict,
quant_dtype,
write_mapping_file_for_input_tile,
)
from tzrec.constant import PREDICT_QUEUE_TIMEOUT, Mode
from tzrec.datasets.dataset import BaseDataset, BaseWriter, create_writer
from tzrec.datasets.utils import Batch, RecordBatchTensor
from tzrec.features.feature import (
BaseFeature,
create_feature_configs,
create_features,
create_fg_json,
)
from tzrec.models.match_model import (
MatchModel,
MatchTower,
MatchTowerWoEG,
TowerWoEGWrapper,
TowerWrapper,
)
from tzrec.models.model import BaseModel, CudaExportWrapper, ScriptWrapper, TrainWrapper
from tzrec.models.tdm import TDM, TDMEmbedding
from tzrec.modules.embedding import EmbeddingGroup
from tzrec.modules.utils import BaseModule
from tzrec.ops import Kernel
from tzrec.optim import optimizer_builder
from tzrec.optim.lr_scheduler import BaseLR
from tzrec.protos.data_pb2 import DataConfig, DatasetType
from tzrec.protos.eval_pb2 import EvalConfig
from tzrec.protos.feature_pb2 import FeatureConfig
from tzrec.protos.model_pb2 import Kernel as KernelProto
from tzrec.protos.model_pb2 import ModelConfig
from tzrec.protos.pipeline_pb2 import EasyRecConfig
from tzrec.protos.train_pb2 import TrainConfig
from tzrec.utils import checkpoint_util, config_util
from tzrec.utils.dist_util import DistributedModelParallel
from tzrec.utils.fx_util import symbolic_trace
from tzrec.utils.logging_util import ProgressLogger, logger
from tzrec.utils.plan_util import create_planner, get_default_sharders
from tzrec.utils.state_dict_util import fix_mch_state, init_parameters
from tzrec.version import __version__ as tzrec_version
def init_process_group() -> Tuple[torch.device, str]:
"""Init process_group, device, rank, backend."""
rank = int(os.environ.get("LOCAL_RANK", 0))
if torch.cuda.is_available():
device: torch.device = torch.device(f"cuda:{rank}")
backend = "nccl"
torch.cuda.set_device(device)
else:
device: torch.device = torch.device("cpu")
backend = "gloo"
dist.init_process_group(backend=backend)
return device, backend
def _create_features(
feature_configs: List[FeatureConfig], data_config: DataConfig
) -> List[BaseFeature]:
neg_fields = None
if data_config.HasField("sampler"):
sampler_type = data_config.WhichOneof("sampler")
if sampler_type != "tdm_sampler":
neg_fields = list(
getattr(data_config, data_config.WhichOneof("sampler")).attr_fields
)
features = create_features(
feature_configs,
fg_mode=data_config.fg_mode,
neg_fields=neg_fields,
fg_encoded_multival_sep=data_config.fg_encoded_multival_sep,
force_base_data_group=data_config.force_base_data_group,
)
return features
def _get_dataloader(
data_config: DataConfig,
features: List[BaseFeature],
input_path: str,
reserved_columns: Optional[List[str]] = None,
mode: Mode = Mode.TRAIN,
gl_cluster: Optional[Dict[str, Union[int, str]]] = None,
debug_level: int = 0,
) -> DataLoader:
"""Build dataloader.
Args:
data_config (DataConfig): dataloader config.
        features (list): list of features.
input_path (str): input data path.
reserved_columns (list): reserved columns in predict mode.
mode (Mode): train or eval or predict.
        gl_cluster (dict, optional): if set, reuse the graphlearn cluster.
        debug_level (int): dataset debug level; when mode=predict and
            debug_level > 0, fg encoded data will be dumped to debug_str.
Return:
dataloader (dataloader): a DataLoader.
"""
dataset_name = DatasetType.Name(data_config.dataset_type)
# pyre-ignore [16]
dataset_cls = BaseDataset.create_class(dataset_name)
dataset = dataset_cls(
data_config=data_config,
features=features,
input_path=input_path,
reserved_columns=reserved_columns,
mode=mode,
debug_level=debug_level,
)
kwargs = {}
if data_config.num_workers < 1:
num_workers = 1
else:
num_workers = data_config.num_workers
    # check that the number of files is valid for file-based datasets.
if hasattr(dataset._reader, "num_files"):
num_files = dataset._reader.num_files()
world_size = int(os.environ.get("WORLD_SIZE", 1))
if num_files >= world_size:
num_files_per_worker = num_files // world_size
if num_files_per_worker < num_workers:
logger.info(
f"data_config.num_workers reset to {num_files_per_worker}"
)
num_workers = num_files_per_worker
else:
raise ValueError(
f"Number of files in the dataset[{input_path}] must greater "
f"than world_size: {world_size}, but got {num_files}"
)
kwargs["num_workers"] = num_workers
kwargs["persistent_workers"] = True
if mode == Mode.TRAIN:
        # When in train_and_eval mode, use 2x workers in the gl cluster
        # for train_dataloader and eval_dataloader
dataset.launch_sampler_cluster(num_client_per_rank=num_workers * 2)
else:
if gl_cluster:
# Reuse the gl cluster for eval_dataloader
dataset.launch_sampler_cluster(
num_client_per_rank=num_workers * 2,
client_id_bias=num_workers,
cluster=gl_cluster,
)
else:
dataset.launch_sampler_cluster(num_client_per_rank=num_workers)
dataloader = DataLoader(
dataset=dataset,
batch_size=None,
pin_memory=data_config.pin_memory if mode != Mode.PREDICT else False,
collate_fn=lambda x: x,
**kwargs,
)
# For PyTorch versions 2.6 and above, we initialize the data iterator before
# beginning the training process to avoid potential CUDA-related issues following
# model saving.
iter(dataloader)
return dataloader
def _create_model(
model_config: ModelConfig,
features: List[BaseFeature],
labels: List[str],
sample_weights: Optional[List[str]] = None,
) -> BaseModel:
"""Build model.
Args:
model_config (ModelConfig): easyrec model config.
features (list): list of features.
labels (list): list of label names.
sample_weights (list): list of sample weight names.
Return:
        model: an EasyRec model.
"""
model_cls_name = config_util.which_msg(model_config, "model")
# pyre-ignore [16]
model_cls = BaseModel.create_class(model_cls_name)
model: BaseModel = model_cls(
model_config, features, labels, sample_weights=sample_weights
)
kernel = Kernel[KernelProto.Name(model_config.kernel)]
model.set_kernel(kernel)
return model
def _evaluate(
model: nn.Module,
eval_dataloader: DataLoader,
eval_config: EvalConfig,
eval_result_filename: Optional[str] = None,
global_step: Optional[int] = None,
eval_summary_writer: Optional[SummaryWriter] = None,
global_epoch: Optional[int] = None,
) -> Dict[str, torch.Tensor]:
"""Evaluate the model."""
is_rank_zero = int(os.environ.get("RANK", 0)) == 0
is_local_rank_zero = int(os.environ.get("LOCAL_RANK", 0)) == 0
model.eval()
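    # Reuse TrainPipelineSparseDist (with no optimizer) for evaluation so sparse
    # input redistribution is still overlapped with the model forward pass.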
pipeline = TrainPipelineSparseDist(
model,
# pyre-fixme [6]
None,
model.device,
execute_all_batches=True,
)
use_step = eval_config.num_steps and eval_config.num_steps > 0
iterator = iter(eval_dataloader)
step_iter = range(eval_config.num_steps) if use_step else itertools.count(0)
desc_suffix = ""
if global_epoch:
desc_suffix += f" Epoch-{global_epoch}"
if global_step:
desc_suffix += f" model-{global_step}"
_model = model.module.model
plogger = None
if is_local_rank_zero:
plogger = ProgressLogger(desc=f"Evaluating{desc_suffix}")
with torch.no_grad():
i_step = 0
for i_step in step_iter:
try:
losses, predictions, batch = pipeline.progress(iterator)
_model.update_metric(predictions, batch, losses)
if (
plogger is not None
and i_step % eval_config.log_step_count_steps == 0
):
plogger.log(i_step)
except StopIteration:
break
if plogger is not None:
plogger.log(i_step)
metric_result = _model.compute_metric()
if is_rank_zero:
metric_str = " ".join([f"{k}:{v:0.6f}" for k, v in metric_result.items()])
logger.info(f"Eval Result{desc_suffix}: {metric_str}")
metric_result = {k: v.item() for k, v in metric_result.items()}
if eval_result_filename:
with open(eval_result_filename, "w") as f:
json.dump(metric_result, f, indent=4)
if eval_summary_writer:
for k, v in metric_result.items():
eval_summary_writer.add_scalar(f"metric/{k}", v, global_step or 0)
return metric_result
def _log_train(
step: int,
losses: Dict[str, torch.Tensor],
param_groups: List[Dict[str, Any]],
plogger: Optional[ProgressLogger] = None,
summary_writer: Optional[SummaryWriter] = None,
) -> None:
"""Logging current training step."""
if plogger is not None:
loss_strs = []
lr_strs = []
for k, v in losses.items():
loss_strs.append(f"{k}:{v:.5f}")
for i, g in enumerate(param_groups):
lr_strs.append(f"lr_g{i}:{g['lr']:.5f}")
total_loss = sum(losses.values())
plogger.log(
step,
f"{' '.join(lr_strs)} {' '.join(loss_strs)} total_loss: {total_loss:.5f}",
)
if summary_writer is not None:
total_loss = sum(losses.values())
for k, v in losses.items():
summary_writer.add_scalar(f"loss/{k}", v, step)
summary_writer.add_scalar("loss/total", total_loss, step)
for i, g in enumerate(param_groups):
summary_writer.add_scalar(f"lr/g{i}", g["lr"], step)
def _train_and_evaluate(
model: nn.Module,
optimizer: optim.Optimizer,
train_dataloader: DataLoader,
eval_dataloader: Optional[DataLoader],
lr_scheduler: List[BaseLR],
model_dir: str,
train_config: TrainConfig,
eval_config: EvalConfig,
skip_steps: int = -1,
ckpt_path: Optional[str] = None,
eval_result_filename: str = "train_eval_result.txt",
) -> None:
"""Train and evaluate the model."""
is_rank_zero = int(os.environ.get("RANK", 0)) == 0
is_local_rank_zero = int(os.environ.get("LOCAL_RANK", 0)) == 0
model.train()
assert train_config.num_steps ^ train_config.num_epochs, (
"train_config.num_epochs or train_config.num_steps should be set, "
"and can not be set at the same time."
)
use_epoch = train_config.num_epochs and train_config.num_epochs > 0
use_step = train_config.num_steps and train_config.num_steps > 0
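    # When training by num_steps, itertools.count(0, 0) yields epoch 0 forever, so
    # the outer loop simply restarts the dataloader (still reported as Epoch 0)
    # until num_steps is reached; when training by num_epochs, the step counter
    # below is unbounded instead.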
epoch_iter = range(train_config.num_epochs) if use_epoch else itertools.count(0, 0)
step_iter = range(train_config.num_steps) if use_step else itertools.count(0)
save_checkpoints_steps, save_checkpoints_epochs = 0, 0
if train_config.save_checkpoints_epochs > 0:
save_checkpoints_epochs = train_config.save_checkpoints_epochs
else:
save_checkpoints_steps = train_config.save_checkpoints_steps
plogger = None
summary_writer = None
eval_summary_writer = None
if is_local_rank_zero:
plogger = ProgressLogger(desc="Training Epoch 0", start_n=skip_steps)
if is_rank_zero and train_config.use_tensorboard:
summary_writer = SummaryWriter(model_dir)
eval_summary_writer = SummaryWriter(os.path.join(model_dir, "eval_val"))
eval_result_filename = os.path.join(model_dir, eval_result_filename)
if train_config.is_profiling:
if is_rank_zero:
logger.info(str(model))
prof = torch.profiler.profile(
schedule=torch.profiler.schedule(wait=1, warmup=1, active=10, repeat=1),
on_trace_ready=torch.profiler.tensorboard_trace_handler(
os.path.join(model_dir, "train_eval_trace")
),
record_shapes=True,
profile_memory=True,
with_stack=True,
)
prof.start()
last_ckpt_step = -1
i_step = 0
i_epoch = 0
losses = {}
for i_epoch in epoch_iter:
pipeline = TrainPipelineSparseDist(
model, optimizer, model.device, execute_all_batches=True
)
if plogger is not None:
plogger.set_description(f"Training Epoch {i_epoch}")
train_iterator = iter(train_dataloader)
        # Restore model and optimizer checkpoint; because the optimizer's state
        # is lazily initialized, we do a dummy step before restoring.
if i_step == 0 and ckpt_path is not None:
peek_batch = next(train_iterator)
pipeline.progress(iter([peek_batch]))
train_iterator = itertools.chain([peek_batch], train_iterator)
checkpoint_util.restore_model(
ckpt_path, model, optimizer, train_config.fine_tune_ckpt_param_map
)
for i_step in step_iter:
if i_step <= skip_steps:
continue
try:
losses, _, _ = pipeline.progress(train_iterator)
if i_step % train_config.log_step_count_steps == 0:
_log_train(
i_step,
losses,
param_groups=optimizer.param_groups,
plogger=plogger,
summary_writer=summary_writer,
)
for lr in lr_scheduler:
if not lr.by_epoch:
lr.step()
except StopIteration:
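                # Train data exhausted for this epoch: requeue the unconsumed step
                # index for the next epoch and roll i_step back to the last
                # completed step before epoch-level checkpointing and evaluation.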
step_iter = itertools.chain([i_step], step_iter)
i_step -= 1
break
if save_checkpoints_steps > 0 and i_step > 0:
if i_step % save_checkpoints_steps == 0:
last_ckpt_step = i_step
checkpoint_util.save_model(
os.path.join(model_dir, f"model.ckpt-{i_step}"),
model,
optimizer,
)
if eval_dataloader is not None:
_evaluate(
model,
eval_dataloader,
eval_config,
eval_result_filename=eval_result_filename,
global_step=i_step,
eval_summary_writer=eval_summary_writer,
global_epoch=i_epoch,
)
model.train()
if train_config.is_profiling:
prof.step()
if save_checkpoints_epochs > 0 and i_step > 0:
if i_epoch % save_checkpoints_epochs == 0:
last_ckpt_step = i_step
checkpoint_util.save_model(
os.path.join(model_dir, f"model.ckpt-{i_step}"),
model,
optimizer,
)
if eval_dataloader is not None:
_evaluate(
model,
eval_dataloader,
eval_config,
eval_result_filename=eval_result_filename,
global_step=i_step,
eval_summary_writer=eval_summary_writer,
global_epoch=i_epoch,
)
model.train()
if use_step and i_step >= train_config.num_steps - 1:
break
for lr in lr_scheduler:
if lr.by_epoch:
lr.step()
_log_train(
i_step,
losses,
param_groups=optimizer.param_groups,
plogger=plogger,
summary_writer=summary_writer,
)
if summary_writer is not None:
summary_writer.close()
if train_config.is_profiling:
prof.stop()
if last_ckpt_step != i_step:
checkpoint_util.save_model(
os.path.join(model_dir, f"model.ckpt-{i_step}"),
model,
optimizer,
)
if eval_dataloader is not None:
_evaluate(
model,
eval_dataloader,
eval_config,
eval_result_filename=eval_result_filename,
global_step=i_step,
eval_summary_writer=eval_summary_writer,
global_epoch=i_epoch,
)
model.train()
def train_and_evaluate(
pipeline_config_path: str,
train_input_path: Optional[str] = None,
eval_input_path: Optional[str] = None,
model_dir: Optional[str] = None,
continue_train: Optional[bool] = True,
fine_tune_checkpoint: Optional[str] = None,
edit_config_json: Optional[str] = None,
) -> None:
"""Train and evaluate a EasyRec model.
Args:
pipeline_config_path (str): path to EasyRecConfig object.
train_input_path (str, optional): train data path.
eval_input_path (str, optional): eval data path.
        model_dir (str, optional): model directory.
        continue_train (bool, optional): whether to resume training from
            an existing checkpoint.
fine_tune_checkpoint (str, optional): path to an existing
finetune checkpoint.
        edit_config_json (str, optional): JSON string used to edit the
            pipeline config.
"""
pipeline_config = config_util.load_pipeline_config(pipeline_config_path)
if fine_tune_checkpoint:
pipeline_config.train_config.fine_tune_checkpoint = fine_tune_checkpoint
if train_input_path:
pipeline_config.train_input_path = train_input_path
if eval_input_path:
pipeline_config.eval_input_path = eval_input_path
if model_dir:
pipeline_config.model_dir = model_dir
if edit_config_json:
edit_config_json = json.loads(edit_config_json)
config_util.edit_config(pipeline_config, edit_config_json)
device, _ = init_process_group()
is_rank_zero = int(os.environ.get("RANK", 0)) == 0
is_local_rank_zero = int(os.environ.get("LOCAL_RANK", 0)) == 0
data_config = pipeline_config.data_config
# Build feature
features = _create_features(list(pipeline_config.feature_configs), data_config)
# Build dataloader
train_dataloader = _get_dataloader(
data_config, features, pipeline_config.train_input_path, mode=Mode.TRAIN
)
eval_dataloader = None
if pipeline_config.eval_input_path:
# pyre-ignore [16]
gl_cluster = train_dataloader.dataset.get_sampler_cluster()
eval_dataloader = _get_dataloader(
data_config,
features,
pipeline_config.eval_input_path,
mode=Mode.EVAL,
gl_cluster=gl_cluster,
)
# Build model
model = _create_model(
pipeline_config.model_config,
features,
list(data_config.label_fields),
sample_weights=list(data_config.sample_weight_fields),
)
model = TrainWrapper(model)
sparse_optim_cls, sparse_optim_kwargs = optimizer_builder.create_sparse_optimizer(
pipeline_config.train_config.sparse_optimizer
)
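    # Fuse the sparse (embedding) optimizer into the backward pass; these parameters
    # are excluded from the dense optimizer below and resurface through
    # model.fused_optimizer after sharding.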
apply_optimizer_in_backward(
sparse_optim_cls, model.model.sparse_parameters(), sparse_optim_kwargs
)
planner = create_planner(
device=device,
# pyre-ignore [16]
batch_size=train_dataloader.dataset.sampled_batch_size,
)
plan = planner.collective_plan(
model, get_default_sharders(), dist.GroupMember.WORLD
)
if is_rank_zero:
logger.info(str(plan))
model = DistributedModelParallel(
module=model,
device=device,
plan=plan,
)
dense_optim_cls, dense_optim_kwargs = optimizer_builder.create_dense_optimizer(
pipeline_config.train_config.dense_optimizer
)
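    # Parameters whose optimizer was fused into backward are filtered out here; the
    # remaining dense parameters use a regular keyed optimizer.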
dense_optimizer = KeyedOptimizerWrapper(
dict(in_backward_optimizer_filter(model.named_parameters())),
lambda params: dense_optim_cls(params, **dense_optim_kwargs),
)
optimizer = CombinedOptimizer([model.fused_optimizer, dense_optimizer])
sparse_lr = optimizer_builder.create_scheduler(
model.fused_optimizer, pipeline_config.train_config.sparse_optimizer
)
dense_lr = optimizer_builder.create_scheduler(
dense_optimizer, pipeline_config.train_config.dense_optimizer
)
ckpt_path = None
skip_steps = -1
if pipeline_config.train_config.fine_tune_checkpoint:
ckpt_path, _ = checkpoint_util.latest_checkpoint(
pipeline_config.train_config.fine_tune_checkpoint
)
if ckpt_path is None or not os.path.exists(ckpt_path):
raise RuntimeError(
"fine_tune_checkpoint"
"[{pipeline_config.train_config.fine_tune_checkpoint}] not exists."
)
if os.path.exists(pipeline_config.model_dir):
# TODO(hongsheng.jhs): save and restore dataloader state.
latest_ckpt_path, skip_steps = checkpoint_util.latest_checkpoint(
pipeline_config.model_dir
)
if latest_ckpt_path:
if continue_train:
ckpt_path = latest_ckpt_path
else:
raise RuntimeError(
f"model_dir[{pipeline_config.model_dir}] already exists "
"and not empty(if you want to continue train on current "
"model_dir please delete dir model_dir or specify "
"--continue_train)"
)
    # Use a barrier to sync all workers: otherwise rank zero may save the config and
    # create model_dir first, and slower ranks would then see an existing model_dir
    # and improperly take the continue-train path.
dist.barrier()
if is_rank_zero:
config_util.save_message(
pipeline_config, os.path.join(pipeline_config.model_dir, "pipeline.config")
)
with open(os.path.join(pipeline_config.model_dir, "version"), "w") as f:
f.write(tzrec_version + "\n")
_train_and_evaluate(
model,
optimizer,
train_dataloader,
eval_dataloader,
[sparse_lr, dense_lr],
pipeline_config.model_dir,
train_config=pipeline_config.train_config,
eval_config=pipeline_config.eval_config,
skip_steps=skip_steps,
ckpt_path=ckpt_path,
)
if is_local_rank_zero:
logger.info("Train and Evaluate Finished.")
def evaluate(
pipeline_config_path: str,
checkpoint_path: Optional[str] = None,
eval_input_path: Optional[str] = None,
eval_result_filename: str = "eval_result.txt",
) -> None:
"""Evaluate a EasyRec model.
Args:
pipeline_config_path (str): path to EasyRecConfig object.
        checkpoint_path (str, optional): if specified, will use this model instead of
            the model specified by model_dir in pipeline_config_path.
        eval_input_path (str, optional): eval data path; defaults to the eval data in
            pipeline_config. Could be a path or a list of paths.
eval_result_filename (str): evaluation result metrics save path.
"""
pipeline_config = config_util.load_pipeline_config(pipeline_config_path)
device, _ = init_process_group()
is_rank_zero = int(os.environ.get("RANK", 0)) == 0
is_local_rank_zero = int(os.environ.get("LOCAL_RANK", 0)) == 0
data_config = pipeline_config.data_config
# Build feature
features = _create_features(list(pipeline_config.feature_configs), data_config)
eval_dataloader = _get_dataloader(
data_config,
features,
eval_input_path or pipeline_config.eval_input_path,
mode=Mode.EVAL,
)
# Build model
model = _create_model(
pipeline_config.model_config,
features,
list(data_config.label_fields),
sample_weights=list(data_config.sample_weight_fields),
)
model = TrainWrapper(model)
planner = create_planner(
device=device,
# pyre-ignore [16]
batch_size=eval_dataloader.dataset.sampled_batch_size,
)
plan = planner.collective_plan(
model, get_default_sharders(), dist.GroupMember.WORLD
)
if is_rank_zero:
logger.info(str(plan))
model = DistributedModelParallel(module=model, device=device, plan=plan)
global_step = None
if not checkpoint_path:
checkpoint_path, global_step = checkpoint_util.latest_checkpoint(
pipeline_config.model_dir
)
if checkpoint_path:
checkpoint_util.restore_model(checkpoint_path, model)
else:
raise ValueError("Eval checkpoint path should be specified.")
summary_writer = None
if is_rank_zero:
summary_writer = SummaryWriter(os.path.join(pipeline_config.model_dir, "eval"))
_evaluate(
model,
eval_dataloader,
eval_config=pipeline_config.eval_config,
eval_result_filename=os.path.join(
pipeline_config.model_dir, eval_result_filename
),
global_step=global_step,
eval_summary_writer=summary_writer,
)
if is_local_rank_zero:
logger.info("Evaluate Finished.")
def _script_model(
pipeline_config: EasyRecConfig,
model: BaseModule,
state_dict: Optional[Dict[str, Any]],
dataloader: DataLoader,
save_dir: str,
) -> None:
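    """Script a model and save it with its serving assets.
    A sketch of the flow below: restore the state dict, optionally quantize
    embeddings, run one batch to log output shapes, then FX-trace and
    torch.jit.script the model (or export it via TRT/AOT), and save
    scripted_model.pt together with pipeline.config, fg.json and model_acc.json
    into save_dir.
    """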
is_rank_zero = int(os.environ.get("RANK", 0)) == 0
is_trt_convert = is_trt()
if is_rank_zero:
if not os.path.exists(save_dir):
os.makedirs(save_dir)
model.set_is_inference(True)
if state_dict is not None:
model.to_empty(device="cpu")
model.load_state_dict(state_dict, strict=False)
        # for mc modules, fix the case where output_segments_tensor is a meta tensor.
fix_mch_state(model)
batch = next(iter(dataloader))
if is_cuda_export():
model = model.cuda()
if is_quant():
logger.info("quantize embeddings...")
quantize_embeddings(model, dtype=quant_dtype(), inplace=True)
logger.info("finish quantize embeddings...")
model.eval()
if is_trt_convert:
data_cuda = batch.to_dict(sparse_dtype=torch.int64)
result = model(data_cuda, "cuda:0")
result_info = {k: (v.size(), v.dtype) for k, v in result.items()}
logger.info(f"Model Outputs: {result_info}")
export_model_trt(model, data_cuda, save_dir)
elif is_aot():
data_cuda = batch.to_dict(sparse_dtype=torch.int64)
result = model(data_cuda)
export_model_aot(model, data_cuda, save_dir)
else:
data = batch.to_dict(sparse_dtype=torch.int64)
result = model(data)
result_info = {k: (v.size(), v.dtype) for k, v in result.items()}
logger.info(f"Model Outputs: {result_info}")
gm = symbolic_trace(model)
with open(os.path.join(save_dir, "gm.code"), "w") as f:
f.write(gm.code)
scripted_model = torch.jit.script(gm)
scripted_model.save(os.path.join(save_dir, "scripted_model.pt"))
features = model._features
feature_configs = create_feature_configs(features, asset_dir=save_dir)
pipeline_config = copy.copy(pipeline_config)
pipeline_config.ClearField("feature_configs")
pipeline_config.feature_configs.extend(feature_configs)
config_util.save_message(
pipeline_config, os.path.join(save_dir, "pipeline.config")
)
logger.info("saving fg json...")
fg_json = create_fg_json(features, asset_dir=save_dir)
with open(os.path.join(save_dir, "fg.json"), "w") as f:
json.dump(fg_json, f, indent=4)
with open(os.path.join(save_dir, "model_acc.json"), "w") as f:
json.dump(export_acc_config(), f, indent=4)
def export(
pipeline_config_path: str,
export_dir: str,
checkpoint_path: Optional[str] = None,
asset_files: Optional[str] = None,
) -> None:
"""Export a EasyRec model.
Args:
pipeline_config_path (str): path to EasyRecConfig object.
export_dir (str): base directory where the model should be exported.
checkpoint_path (str, optional): if specified, will use this model instead of
model specified by model_dir in pipeline_config_path.
        asset_files (str, optional): comma-separated paths of extra files to copy
            into export_dir.
"""
is_rank_zero = int(os.environ.get("RANK", 0)) == 0
if not is_rank_zero:
logger.warning("Only first rank will be used for export now.")
return
else:
if os.environ.get("WORLD_SIZE") != "1":
logger.warning(
"export only support WORLD_SIZE=1 now, we set WORLD_SIZE to 1."
)
os.environ["WORLD_SIZE"] = "1"
pipeline_config = config_util.load_pipeline_config(pipeline_config_path)
ori_pipeline_config = copy.copy(pipeline_config)
dist.init_process_group("gloo")
if is_rank_zero:
if os.path.exists(export_dir):
raise RuntimeError(f"directory {export_dir} already exist.")
assets = []
if asset_files:
assets = asset_files.split(",")
data_config = pipeline_config.data_config
is_trt_convert = is_trt()
if is_trt_convert:
        # a too-large export batch_size may cause OOM in the trt convert phase
max_batch_size = get_trt_max_batch_size()
data_config.batch_size = min(data_config.batch_size, max_batch_size)
logger.info("using new batch_size: %s in trt export", data_config.batch_size)
# Build feature
features = _create_features(list(pipeline_config.feature_configs), data_config)
    # build the data parser to get user feats before creating the model
data_config.num_workers = 1
dataloader = _get_dataloader(
data_config, features, pipeline_config.train_input_path, mode=Mode.PREDICT
)
# Build model
model = _create_model(
pipeline_config.model_config,
features,
list(data_config.label_fields),
)
InferWrapper = CudaExportWrapper if is_aot() else ScriptWrapper
model = InferWrapper(model)
init_parameters(model, torch.device("cpu"))
if not checkpoint_path:
checkpoint_path, _ = checkpoint_util.latest_checkpoint(
pipeline_config.model_dir
)
if checkpoint_path:
if is_input_tile_emb():
remap_file_path = os.path.join(export_dir, "emb_ckpt_mapping.txt")
if is_rank_zero:
if not os.path.exists(export_dir):
os.makedirs(export_dir)
write_mapping_file_for_input_tile(model.state_dict(), remap_file_path)
dist.barrier()
checkpoint_util.restore_model(
checkpoint_path, model, ckpt_param_map_path=remap_file_path
)
else:
checkpoint_util.restore_model(checkpoint_path, model)
else:
raise ValueError("checkpoint path should be specified.")
if isinstance(model.model, MatchModel):
for name, module in model.model.named_children():
if isinstance(module, MatchTower) or isinstance(module, MatchTowerWoEG):
wrapper = (
TowerWrapper if isinstance(module, MatchTower) else TowerWoEGWrapper
)
tower = InferWrapper(wrapper(module, name))
tower_export_dir = os.path.join(export_dir, name.replace("_tower", ""))
_script_model(
ori_pipeline_config,
tower,
model.state_dict(),
dataloader,
tower_export_dir,
)
for asset in assets:
shutil.copy(asset, tower_export_dir)
elif isinstance(model.model, TDM):
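        # For TDM, the embedding group is additionally scripted into
        # <export_dir>/embedding so item embeddings can be computed separately
        # (e.g. when updating the tree index).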
for name, module in model.model.named_children():
if isinstance(module, EmbeddingGroup):
emb_module = InferWrapper(TDMEmbedding(module, name))
_script_model(
ori_pipeline_config,
emb_module,
model.state_dict(),
dataloader,
os.path.join(export_dir, "embedding"),
)
break
_script_model(
ori_pipeline_config,
model,
None,
dataloader,
os.path.join(export_dir, "model"),
)
for asset in assets:
shutil.copy(asset, os.path.join(export_dir, "model"))
else:
_script_model(
ori_pipeline_config,
model,
None,
dataloader,
export_dir,
)
for asset in assets:
shutil.copy(asset, export_dir)
def predict(
predict_input_path: str,
predict_output_path: str,
scripted_model_path: str,
reserved_columns: Optional[str] = None,
output_columns: Optional[str] = None,
batch_size: Optional[int] = None,
is_profiling: bool = False,
debug_level: int = 0,
dataset_type: Optional[str] = None,
predict_threads: Optional[int] = None,
writer_type: Optional[str] = None,
edit_config_json: Optional[str] = None,
) -> None:
"""Evaluate a EasyRec model.
Args:
predict_input_path (str): inference input data path.
predict_output_path (str): inference output data path.
scripted_model_path (str): path to scripted model.
        reserved_columns (str, optional): columns to reserve in output.
output_columns (str, optional): columns of model output.
batch_size (int, optional): predict batch_size.
        is_profiling (bool): whether to profile the predict process.
        debug_level (int, optional): debug level for debugging parsed inputs etc.
        dataset_type (str, optional): dataset type; defaults to the type in the
            pipeline config.
        predict_threads (int, optional): number of predict threads; defaults to
            num_workers in data_config.
        writer_type (str, optional): data writer type; defaults to match the
            dataset_type in data_config.
        edit_config_json (str, optional): JSON string used to edit the
            pipeline config.
"""
reserved_cols: Optional[List[str]] = None
output_cols: Optional[List[str]] = None
if reserved_columns is not None:
reserved_cols = [x.strip() for x in reserved_columns.split(",")]
if output_columns is not None:
output_cols = [x.strip() for x in output_columns.split(",")]
pipeline_config = config_util.load_pipeline_config(
os.path.join(scripted_model_path, "pipeline.config"), allow_unknown_field=True
)
if batch_size:
pipeline_config.data_config.batch_size = batch_size
is_trt_convert: bool = is_trt_predict(scripted_model_path)
if is_trt_convert:
        # a too-large predict batch_size may go out of the supported range
max_batch_size = get_trt_max_batch_size()
pipeline_config.data_config.batch_size = min(
pipeline_config.data_config.batch_size, max_batch_size
)
logger.info(
"using new batch_size: %s in trt predict",
pipeline_config.data_config.batch_size,
)
if dataset_type:
pipeline_config.data_config.dataset_type = getattr(DatasetType, dataset_type)
if edit_config_json:
edit_config_json = json.loads(edit_config_json)
config_util.edit_config(pipeline_config, edit_config_json)
device_and_backend = init_process_group()
device: torch.device = device_and_backend[0]
is_rank_zero = int(os.environ.get("RANK", 0)) == 0
is_local_rank_zero = int(os.environ.get("LOCAL_RANK", 0)) == 0
data_config: DataConfig = pipeline_config.data_config
data_config.drop_remainder = False
# Build feature
features = _create_features(list(pipeline_config.feature_configs), data_config)
infer_dataloader = _get_dataloader(
data_config,
features,
predict_input_path,
reserved_columns=reserved_cols,
mode=Mode.PREDICT,
debug_level=debug_level,
)
infer_iterator = iter(infer_dataloader)
if writer_type is None:
writer_type = DatasetType.Name(data_config.dataset_type).replace(
"Dataset", "Writer"
)
writer: BaseWriter = create_writer(
predict_output_path,
writer_type,
quota_name=data_config.odps_data_quota_name,
)
    # disable jit compile, as it compiles too slowly for now.
if "PYTORCH_TENSOREXPR_FALLBACK" not in os.environ:
os.environ["PYTORCH_TENSOREXPR_FALLBACK"] = "2"
model: torch.jit.ScriptModule = torch.jit.load(
os.path.join(scripted_model_path, "scripted_model.pt"), map_location=device
)
model.eval()
if is_local_rank_zero:
plogger = ProgressLogger(desc="Predicting", miniters=10)
if is_profiling:
if is_rank_zero:
logger.info(str(model))
prof = torch.profiler.profile(
schedule=torch.profiler.schedule(wait=1, warmup=1, active=10, repeat=1),
on_trace_ready=torch.profiler.tensorboard_trace_handler(
os.path.join(scripted_model_path, "predict_trace")
),
record_shapes=True,
profile_memory=True,
with_stack=True,
)
prof.start()
if predict_threads is None:
predict_threads = max(data_config.num_workers, 1)
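    # Producer/consumer layout: the main loop feeds batches into data_queue,
    # predict_threads forward threads run the scripted model and push results into
    # pred_queue, and a single writer thread drains pred_queue so that all output
    # goes through one writer instance.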
data_queue: Queue[Optional[Batch]] = Queue(maxsize=predict_threads * 2)
pred_queue: Queue[
Tuple[Optional[Dict[str, torch.Tensor]], Optional[RecordBatchTensor]]
] = Queue(maxsize=predict_threads * 2)
def _forward(batch: Batch) -> Tuple[Dict[str, torch.Tensor], RecordBatchTensor]:
with torch.no_grad():
parsed_inputs = batch.to_dict(sparse_dtype=torch.int64)
# when predicting with a model exported using INPUT_TILE,
# we set the batch size tensor to 1 to disable tiling.
parsed_inputs["batch_size"] = torch.tensor(1, dtype=torch.int64)
if is_trt_convert:
predictions = model(parsed_inputs)
else:
predictions = model(parsed_inputs, device)
predictions = {k: v.to("cpu") for k, v in predictions.items()}
return predictions, batch.reserves
def _write(
predictions: Dict[str, torch.Tensor],
reserves: RecordBatchTensor,
output_cols: List[str],
) -> None:
output_dict = OrderedDict()
for c in output_cols:
v = predictions[c]
v = v.tolist() if v.ndim > 1 else v.numpy()
output_dict[c] = pa.array(v)
reserve_batch_record = reserves.get()
if reserve_batch_record is not None:
for k, v in zip(
reserve_batch_record.column_names, reserve_batch_record.columns
):
output_dict[k] = v
writer.write(output_dict)
def _write_loop(output_cols: List[str]) -> None:
while True:
predictions, reserves = pred_queue.get(timeout=PREDICT_QUEUE_TIMEOUT)
if predictions is None:
break
assert predictions is not None and reserves is not None
_write(predictions, reserves, output_cols)
def _forward_loop() -> None:
while True:
batch = data_queue.get(timeout=PREDICT_QUEUE_TIMEOUT)
if batch is None:
break
assert batch is not None
pred = _forward(batch)
pred_queue.put(pred, timeout=PREDICT_QUEUE_TIMEOUT)
forward_t_list = []
write_t = None
i_step = 0
while True:
try:
batch = next(infer_iterator)
if i_step == 0:
                # lazily init the writer and create the write and forward threads
predictions, reserves = _forward(batch)
if output_cols is None:
output_cols = sorted(predictions.keys())
_write(predictions, reserves, output_cols)
for _ in range(predict_threads):
t = Thread(target=_forward_loop)
t.start()
forward_t_list.append(t)
write_t = Thread(target=_write_loop, args=(output_cols,))
write_t.start()
else:
data_queue.put(batch, timeout=PREDICT_QUEUE_TIMEOUT)
if is_local_rank_zero:
plogger.log(i_step)
if is_profiling:
prof.step()
i_step += 1
except StopIteration:
break
for _ in range(predict_threads):
data_queue.put(None, timeout=PREDICT_QUEUE_TIMEOUT)
for t in forward_t_list:
t.join()
pred_queue.put((None, None), timeout=PREDICT_QUEUE_TIMEOUT)
assert write_t is not None
write_t.join()
writer.close()
if is_profiling:
prof.stop()
if is_local_rank_zero:
logger.info("Predict Finished.")