tzrec/utils/fx_util.py:

# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Dict, List, Optional, Union

import torch
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
    BoundsCheckMode,
)
from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
    IntNBitTableBatchedEmbeddingBagsCodegen,
)
from torch import Tensor
from torchrec.fx import symbolic_trace as _symbolic_trace


def symbolic_trace(
    # pyre-ignore[24]
    root: Union[torch.nn.Module, Callable],
    concrete_args: Optional[Dict[str, Any]] = None,
    leaf_modules: Optional[List[str]] = None,
) -> torch.fx.GraphModule:
    """Symbolic tracing API.

    Given an `nn.Module` or function instance `root`, this function will return a
    `GraphModule` constructed by recording operations seen while tracing through
    `root`. `concrete_args` allows you to partially specialize your function,
    whether it's to remove control flow or data structures.

    Args:
        root (Union[torch.nn.Module, Callable]): Module or function to be traced
            and converted into a Graph representation.
        concrete_args (Optional[Dict[str, Any]]): inputs to be partially specialized.
        leaf_modules (Optional[List[str]]): class names of modules that should not
            be traced into.

    Returns:
        GraphModule: a Module created from the recorded operations from ``root``.
    """
    # ComputeJTDictToKJT could not be traced
    _leaf_modules = ["ComputeJTDictToKJT"]
    if leaf_modules:
        _leaf_modules.extend(leaf_modules)
    return _symbolic_trace(root, concrete_args, _leaf_modules)


@torch.fx.wrap
def fx_arange(len: int, device: torch.device) -> torch.Tensor:
    """Fx trace wrapper for `torch.arange`."""
    return torch.arange(len, device=device)


@torch.fx.wrap
def fx_unwrap_optional_tensor(optional: Optional[torch.Tensor]) -> torch.Tensor:
    """Unwrap an optional tensor for trace."""
    assert optional is not None, "Expected optional to be non-None Tensor"
    return optional


@torch.fx.wrap
def fx_int_item(x: torch.Tensor) -> int:
    """Fx trace wrapper for `int(x.item())`."""
    return int(x.item())
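

# Illustrative usage sketch: because the helpers above are decorated with
# `@torch.fx.wrap`, they are recorded as single `call_function` nodes during
# tracing instead of being traced through, so patterns like `int(x.item())`
# that would otherwise break symbolic tracing keep working. `_ToyModule` is a
# hypothetical example class, not part of the tzrec API:
#
#     class _ToyModule(torch.nn.Module):
#         def forward(self, x: torch.Tensor) -> torch.Tensor:
#             length = fx_int_item(x.sum())
#             return fx_arange(length, x.device)
#
#     gm = symbolic_trace(_ToyModule())
#     print(gm.graph)  # fx_int_item / fx_arange show up as call_function nodes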


# We remove `inputs_to_device` so that `IntNBitTableBatchedEmbeddingBagsCodegen`
# can temporarily run on both CPU and GPU after applying `symbolic_trace`.
# Alternatively, the code below can be uncommented to make device placement
# explicit, but this may introduce unnecessary to_device operations.
# @torch.fx.wrap
# def inputs_to_device(
#     indices: torch.Tensor,
#     offsets: torch.Tensor,
#     per_sample_weights: Optional[torch.Tensor],
#     bounds_check_warning: torch.Tensor,
# ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
#     if bounds_check_warning.device.type == "meta":
#         return indices, offsets, per_sample_weights
#
#     non_blocking = bounds_check_warning.device.type != "cpu"
#     if indices.device != bounds_check_warning.device:
#         indices = indices.to(bounds_check_warning.device, non_blocking=non_blocking)
#     if offsets.device != bounds_check_warning.device:
#         offsets = offsets.to(bounds_check_warning.device, non_blocking=non_blocking)
#     if (
#         per_sample_weights is not None
#         and per_sample_weights.device != bounds_check_warning.device
#     ):
#         per_sample_weights = per_sample_weights.to(
#             bounds_check_warning.device, non_blocking=non_blocking
#         )
#     return indices, offsets, per_sample_weights


def _forward_impl(
    # pyre-ignore[2]
    self,
    indices: Tensor,
    offsets: Tensor,
    per_sample_weights: Optional[Tensor] = None,
) -> Tensor:
    assert self.weight_initialized, (
        "weight needs to be initialized before forward function"
    )

    # indices, offsets, per_sample_weights = inputs_to_device(
    #     indices, offsets, per_sample_weights, self.bounds_check_warning
    # )

    # First bound check: check if the indices/offsets are within the boundary
    # of the original embedding rows before pruning.
    # Note that this is only applied when we enable pruning (if the perf becomes
    # an issue, we can fuse it inside the remapping kernel).
    if (
        self.index_remapping_hash_table_cpu is not None
        or self.index_remapping_hash_table.numel() > 0
        or self.index_remappings_array.numel() > 0
    ):
        if self.bounds_check_mode_int != BoundsCheckMode.NONE.value:
            torch.ops.fbgemm.bounds_check_indices(
                self.original_rows_per_table,
                indices,
                offsets,
                self.bounds_check_mode_int,
                self.bounds_check_warning,
                per_sample_weights,
            )

        # Index remapping changes input indices, and some of them become -1
        # (pruned rows). Hence, remapping should be done before prefetch and
        # emb lookup so that these operations are with the remapped indices.
        if self.index_remapping_hash_table_cpu is not None:
            indices = self.index_remapping_hash_table_cpu.lookup(indices, offsets)
        elif self.index_remapping_hash_table.numel() > 0:
            # Convert from raw indices to pruned indices
            indices = torch.ops.fbgemm.pruned_hashmap_lookup(
                indices,
                offsets,
                self.index_remapping_hash_table,
                self.index_remapping_hash_table_offsets,
            )
        elif self.index_remappings_array.numel() > 0:
            indices = torch.ops.fbgemm.pruned_array_lookup(
                indices,
                offsets,
                self.index_remappings_array,
                self.index_remappings_array_offsets,
            )

    if self.lxu_cache_weights.numel() > 0:
        if self.timestep_prefetch_size.get() <= 0:
            self.prefetch(indices, offsets)
        self.timestep_prefetch_size.decrement()

    lxu_cache_locations = self.lxu_cache_locations_list.pop()

    # Second bound check: check if the indices/offsets are within the boundary
    # of the pruned embedding rows after pruning.
    # Note: we cast to int as a TorchScript workaround.
    if self.bounds_check_mode_int != BoundsCheckMode.NONE.value:
        torch.ops.fbgemm.bounds_check_indices(
            self.rows_per_table,
            indices,
            offsets,
            self.bounds_check_mode_int,
            self.bounds_check_warning,
            per_sample_weights,
        )

    # Note: CPU and CUDA ops use the same interface to facilitate JIT IR
    # generation for CUDA/CPU.
    # For CPU op, we don't need weights_uvm and weights_placements.
    return torch.ops.fbgemm.int_nbit_split_embedding_codegen_lookup_function(
        dev_weights=self.weights_host if self.host_size > 0 else self.weights_dev,
        uvm_weights=self.weights_uvm,
        weights_placements=self.weights_placements,
        weights_offsets=self.weights_offsets,
        weights_tys=self.weights_tys,
        D_offsets=self.D_offsets,
        total_D=self.total_D,
        max_int2_D=self.max_int2_D,
        max_int4_D=self.max_int4_D,
        max_int8_D=self.max_int8_D,
        max_float16_D=self.max_float16_D,
        max_float32_D=self.max_float32_D,
        indices=indices,
        offsets=offsets,
        pooling_mode=int(self.pooling_mode),
        indice_weights=per_sample_weights,
        output_dtype=self.output_dtype,
        lxu_cache_weights=self.lxu_cache_weights,
        lxu_cache_locations=lxu_cache_locations,
        row_alignment=self.row_alignment,
        max_float8_D=self.max_float8_D,
        fp8_exponent_bits=self.fp8_exponent_bits,
        fp8_exponent_bias=self.fp8_exponent_bias,
    )


IntNBitTableBatchedEmbeddingBagsCodegen._forward_impl = _forward_impl
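

# Minimal smoke-test sketch: it keeps `_Gate` opaque during tracing by listing
# its class name in `leaf_modules`, assuming the tracer matches leaf modules by
# class name (as the `ComputeJTDictToKJT` entry above suggests). `_Gate` and
# `_Net` are hypothetical example classes, not part of the tzrec API.
if __name__ == "__main__":

    class _Gate(torch.nn.Module):
        def forward(self, x: Tensor) -> Tensor:
            return torch.sigmoid(x)

    class _Net(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.gate = _Gate()

        def forward(self, x: Tensor) -> Tensor:
            return x * self.gate(x)

    gm = symbolic_trace(_Net(), leaf_modules=["_Gate"])
    print(gm.graph)  # `gate` appears as a single call_module node
    print(gm(torch.randn(2, 4)))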