in optimum/tpu/xla_model_parallel.py [0:0]
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False,
init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
keep_master_weight_for_test: bool = False,
world_size: Optional[int] = None,
rank: Optional[int] = None,
groups: Optional[List] = None,
quant: bool = False,