in flsim/optimizers/sync_aggregators.py [0:0]
def create_optimizer_for_sync_aggregator(config: SyncAggregatorConfig, model: Model):
    """Build the server-side optimizer described by a sync aggregator config.

    The config is matched on its ``_target_`` string against the known sync
    aggregator config classes, and the corresponding optimizer is created over
    ``model.parameters()``:

    * ``FedAvgWithLRSyncAggregatorConfig`` -> ``torch.optim.SGD``
      (``lr``, ``momentum``)
    * ``FedAdamSyncAggregatorConfig`` -> ``torch.optim.Adam``
      (``lr``, ``weight_decay``, ``betas=(beta1, beta2)``, ``eps``)
    * ``FedLARSSyncAggregatorConfig`` -> ``LARS``
      (``lr``, ``beta``, ``weight_decay``)
    * ``FedLAMBSyncAggregatorConfig`` -> ``LAMB``
      (``lr``, ``beta1``, ``beta2``, ``weight_decay``, ``eps``)

    Args:
        config: Sync aggregator config. Its ``_target_`` selects the optimizer,
            and its hyperparameter fields are forwarded to the optimizer.
        model: Model whose ``parameters()`` the optimizer will update.

    Returns:
        The newly constructed optimizer instance.

    Raises:
        ValueError: If ``config._target_`` matches none of the supported
            sync aggregator configs. Previously this case silently returned
            ``None``, which only failed later when the optimizer was used.
    """
    # Dispatch on the config's `_target_` string rather than its Python type.
    if config._target_ == FedAvgWithLRSyncAggregatorConfig._target_:
        return torch.optim.SGD(
            model.parameters(),
            # pyre-fixme[16]: `SyncAggregatorConfig` has no attribute `lr`.
            lr=config.lr,
            # pyre-fixme[16]: `SyncAggregatorConfig` has no attribute `momentum`.
            momentum=config.momentum,
        )
    elif config._target_ == FedAdamSyncAggregatorConfig._target_:
        return torch.optim.Adam(
            model.parameters(),
            lr=config.lr,
            # pyre-fixme[16]: `SyncAggregatorConfig` has no attribute `weight_decay`.
            weight_decay=config.weight_decay,
            # pyre-fixme[16]: `SyncAggregatorConfig` has no attribute `beta1`.
            # pyre-fixme[16]: `SyncAggregatorConfig` has no attribute `beta2`.
            betas=(config.beta1, config.beta2),
            # pyre-fixme[16]: `SyncAggregatorConfig` has no attribute `eps`.
            eps=config.eps,
        )
    elif config._target_ == FedLARSSyncAggregatorConfig._target_:
        return LARS(
            model.parameters(),
            lr=config.lr,
            # pyre-fixme[16]: `SyncAggregatorConfig` has no attribute `beta`.
            beta=config.beta,
            weight_decay=config.weight_decay,
        )
    elif config._target_ == FedLAMBSyncAggregatorConfig._target_:
        return LAMB(
            model.parameters(),
            lr=config.lr,
            beta1=config.beta1,
            beta2=config.beta2,
            weight_decay=config.weight_decay,
            eps=config.eps,
        )

    # Fail loudly here instead of handing the caller `None`, which would only
    # surface later as an opaque AttributeError on `.step()` / `.zero_grad()`.
    supported_targets = (
        FedAvgWithLRSyncAggregatorConfig._target_,
        FedAdamSyncAggregatorConfig._target_,
        FedLARSSyncAggregatorConfig._target_,
        FedLAMBSyncAggregatorConfig._target_,
    )
    raise ValueError(
        f"Unsupported sync aggregator config _target_={config._target_!r}; "
        f"expected one of {supported_targets}"
    )