fbgemm_gpu/codegen/lookup_args.py (39 lines of code) (raw):

#!/usr/bin/env python3

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import NamedTuple, Optional

import torch


class CommonArgs(NamedTuple):
    """Immutable bundle of the arguments shared by lookup invocations.

    Field order is part of the interface: instances may be built
    positionally, so new fields must never be inserted in the middle.

    NOTE(review): the per-field remarks below are inferred from the field
    names only; confirm against the code that constructs this tuple.
    """

    placeholder_autograd_tensor: torch.Tensor

    # Weight storage tensors; the dev/host/uvm prefixes presumably denote
    # where each buffer lives -- TODO confirm with the caller.
    dev_weights: torch.Tensor
    host_weights: torch.Tensor
    uvm_weights: torch.Tensor
    lxu_cache_weights: torch.Tensor
    weights_placements: torch.Tensor
    weights_offsets: torch.Tensor

    # Dimension bookkeeping (one tensor plus two plain ints).
    D_offsets: torch.Tensor
    total_D: int
    max_D: int

    # Hash-size bookkeeping.
    hash_size_cumsum: torch.Tensor
    total_hash_size_bits: int

    # Lookup inputs.
    indices: torch.Tensor
    offsets: torch.Tensor
    pooling_mode: int

    # The only two fields that may be None.
    indice_weights: Optional[torch.Tensor]
    feature_requires_grad: Optional[torch.Tensor]

    lxu_cache_locations: torch.Tensor
    # Stored as a plain int; presumably an enum-like dtype code -- verify.
    output_dtype: int


class OptimizerArgs(NamedTuple):
    """Immutable bundle of optimizer flags and scalar hyperparameters.

    All fields are plain Python bools or floats; which of them are actually
    consumed is decided by whoever receives this tuple.
    """

    # Boolean switches.
    stochastic_rounding: bool
    gradient_clipping: bool

    # Scalar hyperparameters.
    max_gradient: float
    learning_rate: float
    eps: float
    beta1: float
    beta2: float
    weight_decay: float
    eta: float
    momentum: float


class Momentum(NamedTuple):
    """Immutable bundle of five tensors describing one momentum state.

    NOTE(review): the dev/host/uvm/offsets/placements naming parallels the
    weight fields of ``CommonArgs``; presumably the same layout convention
    applies -- confirm against the caller.
    """

    dev: torch.Tensor
    host: torch.Tensor
    uvm: torch.Tensor
    offsets: torch.Tensor
    placements: torch.Tensor