tzrec/optim/optimizer_builder.py

# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Tuple, Type, Union

import torch
from torch.optim.optimizer import Optimizer
from torchrec.optim import optimizers, rowwise_adagrad

from tzrec.optim.lr_scheduler import BaseLR
from tzrec.protos import optimizer_pb2
from tzrec.utils.config_util import config_to_kwargs


def create_sparse_optimizer(
    optimizer_config: optimizer_pb2.SparseOptimizer,
) -> Tuple[Type[Optimizer], Dict[str, Any]]:
    """Create optimizer class and kwargs for an embedding (sparse) module.

    Args:
        optimizer_config (optimizer_pb2.SparseOptimizer): an instance of
            SparseOptimizer config.

    Returns:
        optimizer (Type[Optimizer]): an optimizer class (not an instance).
        optimizer_kwargs (dict): optimizer params.
    """
    # Resolve which optimizer message is set in the proto oneof and convert
    # its fields to plain keyword arguments.
    optimizer_type = optimizer_config.WhichOneof("optimizer")
    oneof_optim_config = getattr(optimizer_config, optimizer_type)
    optimizer_kwargs = config_to_kwargs(oneof_optim_config)
    if optimizer_type == "sgd_optimizer":
        return optimizers.SGD, optimizer_kwargs
    elif optimizer_type == "adagrad_optimizer":
        return optimizers.Adagrad, optimizer_kwargs
    elif optimizer_type == "adam_optimizer":
        return optimizers.Adam, optimizer_kwargs
    elif optimizer_type == "lars_sgd_optimizer":
        return optimizers.LarsSGD, optimizer_kwargs
    elif optimizer_type == "lamb_optimizer":
        return optimizers.LAMB, optimizer_kwargs
    elif optimizer_type == "partial_rowwise_lamb_optimizer":
        return optimizers.PartialRowWiseLAMB, optimizer_kwargs
    elif optimizer_type == "partial_rowwise_adam_optimizer":
        return optimizers.PartialRowWiseAdam, optimizer_kwargs
    elif optimizer_type == "rowwise_adagrad_optimizer":
        return rowwise_adagrad.RowWiseAdagrad, optimizer_kwargs
    else:
        raise ValueError(f"Unknown optimizer: {optimizer_type}")


def create_dense_optimizer(
    optimizer_config: optimizer_pb2.DenseOptimizer,
) -> Tuple[Type[Optimizer], Dict[str, Any]]:
    """Create optimizer class and kwargs for a dense module.

    Args:
        optimizer_config (optimizer_pb2.DenseOptimizer): an instance of
            DenseOptimizer config.

    Returns:
        optimizer (Type[Optimizer]): an optimizer class (not an instance).
        optimizer_kwargs (dict): optimizer params.
    """
    optimizer_type = optimizer_config.WhichOneof("optimizer")
    oneof_optim_config = getattr(optimizer_config, optimizer_type)
    optimizer_kwargs = config_to_kwargs(oneof_optim_config)
    if optimizer_type == "sgd_optimizer":
        return torch.optim.SGD, optimizer_kwargs
    elif optimizer_type == "adagrad_optimizer":
        return torch.optim.Adagrad, optimizer_kwargs
    elif optimizer_type == "adam_optimizer":
        # torch.optim.Adam expects a single `betas` tuple, while the proto
        # config carries separate beta1/beta2 fields; repack them here.
        beta1 = optimizer_kwargs.pop("beta1")
        beta2 = optimizer_kwargs.pop("beta2")
        optimizer_kwargs["betas"] = (beta1, beta2)
        return torch.optim.Adam, optimizer_kwargs
    else:
        raise ValueError(f"Unknown optimizer: {optimizer_type}")
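
# A minimal usage sketch for the sparse path, assuming the usual torchrec
# pattern of fusing the sparse optimizer into the embedding backward pass;
# `sparse_config` and `emb_module` are hypothetical placeholders, and
# `apply_optimizer_in_backward` comes from `torch.distributed.optim`:
#
#     optimizer_cls, optimizer_kwargs = create_sparse_optimizer(sparse_config)
#     apply_optimizer_in_backward(
#         optimizer_cls, emb_module.parameters(), optimizer_kwargs
#     )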


def create_scheduler(
    optimizer: Optimizer,
    optimizer_config: Union[
        optimizer_pb2.SparseOptimizer, optimizer_pb2.DenseOptimizer
    ],
) -> BaseLR:
    """Create a learning rate scheduler for an optimizer.

    Args:
        optimizer (Optimizer): an instance of Optimizer.
        optimizer_config (optimizer_pb2.SparseOptimizer|optimizer_pb2.DenseOptimizer):
            an instance of Optimizer config.

    Returns:
        lr (BaseLR): a lr scheduler.
    """
    # Resolve which learning rate message is set in the proto oneof; the
    # message class name doubles as the registry key on BaseLR.
    lr_type = optimizer_config.WhichOneof("learning_rate")
    oneof_lr_config = getattr(optimizer_config, lr_type)
    lr_cls_name = oneof_lr_config.__class__.__name__
    lr_kwargs = config_to_kwargs(oneof_lr_config)
    lr_kwargs["optimizer"] = optimizer
    # pyre-ignore [16]
    return BaseLR.create_class(lr_cls_name)(**lr_kwargs)
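
# A minimal end-to-end sketch for the dense path, assuming a DenseOptimizer
# config with its `optimizer` and `learning_rate` oneofs set; `dense_config`,
# `model`, and `dataloader` are hypothetical placeholders, and whether the
# scheduler steps per batch or per epoch depends on the concrete BaseLR
# subclass:
#
#     optimizer_cls, optimizer_kwargs = create_dense_optimizer(dense_config)
#     optimizer = optimizer_cls(model.parameters(), **optimizer_kwargs)
#     scheduler = create_scheduler(optimizer, dense_config)
#     for batch in dataloader:
#         loss = model(batch)
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#         scheduler.step()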