def create_optimizer_for_sync_aggregator()

in flsim/optimizers/sync_aggregators.py [0:0]


def create_optimizer_for_sync_aggregator(config: SyncAggregatorConfig, model: Model):
    """Build the server-side optimizer for a sync aggregator.

    The optimizer class is selected by comparing ``config._target_`` with the
    ``_target_`` of each supported aggregator config:

    * ``FedAvgWithLRSyncAggregatorConfig`` -> ``torch.optim.SGD``
      (reads ``lr``, ``momentum``)
    * ``FedAdamSyncAggregatorConfig`` -> ``torch.optim.Adam``
      (reads ``lr``, ``weight_decay``, ``beta1``, ``beta2``, ``eps``)
    * ``FedLARSSyncAggregatorConfig`` -> ``LARS``
      (reads ``lr``, ``beta``, ``weight_decay``)
    * ``FedLAMBSyncAggregatorConfig`` -> ``LAMB``
      (reads ``lr``, ``beta1``, ``beta2``, ``weight_decay``, ``eps``)

    Args:
        config: Aggregator config. Its ``_target_`` chooses the optimizer
            class and its fields supply the hyperparameters listed above.
        model: Model whose ``parameters()`` are handed to the optimizer.

    Returns:
        The constructed optimizer instance.

    Raises:
        ValueError: If ``config._target_`` matches none of the supported
            aggregator configs.
    """
    target = config._target_

    if target == FedAvgWithLRSyncAggregatorConfig._target_:
        return torch.optim.SGD(
            model.parameters(),
            # pyre-fixme[16]: `SyncAggregatorConfig` has no attribute `lr`.
            lr=config.lr,
            # pyre-fixme[16]: `SyncAggregatorConfig` has no attribute `momentum`.
            momentum=config.momentum,
        )
    if target == FedAdamSyncAggregatorConfig._target_:
        return torch.optim.Adam(
            model.parameters(),
            lr=config.lr,
            # pyre-fixme[16]: `SyncAggregatorConfig` has no attribute `weight_decay`.
            weight_decay=config.weight_decay,
            # pyre-fixme[16]: `SyncAggregatorConfig` has no attribute `beta1`.
            # pyre-fixme[16]: `SyncAggregatorConfig` has no attribute `beta2`.
            betas=(config.beta1, config.beta2),
            # pyre-fixme[16]: `SyncAggregatorConfig` has no attribute `eps`.
            eps=config.eps,
        )
    if target == FedLARSSyncAggregatorConfig._target_:
        return LARS(
            model.parameters(),
            lr=config.lr,
            # pyre-fixme[16]: `SyncAggregatorConfig` has no attribute `beta`.
            beta=config.beta,
            weight_decay=config.weight_decay,
        )
    if target == FedLAMBSyncAggregatorConfig._target_:
        return LAMB(
            model.parameters(),
            lr=config.lr,
            beta1=config.beta1,
            beta2=config.beta2,
            weight_decay=config.weight_decay,
            eps=config.eps,
        )

    # Fail fast instead of implicitly returning None: a None "optimizer"
    # would only break later, far from the misconfiguration, when the
    # caller first tries to use it.
    raise ValueError(
        f"Unsupported sync aggregator config target for optimizer creation: {target!r}"
    )