torchbenchmark/util/framework/timm/timm_config.py (19 lines of code) (raw):

import dataclasses
from typing import Optional

import torch.nn as nn
from timm.optim import create_optimizer


@dataclasses.dataclass
class OptimizerOption:
    """Optimizer settings passed to timm's ``create_optimizer`` as its first argument.

    The attribute names are presumably the ones timm reads from that argument
    object; check the ``create_optimizer`` reference before adding fields.
    """

    # Learning rate.
    lr: float
    # Optimizer name handed to timm, e.g. "sgd".
    opt: str
    # Weight-decay coefficient.
    weight_decay: float
    # Momentum factor (relevant to momentum-based optimizers such as SGD).
    momentum: float


class TimmConfig:
    """Bundle the loss, input size and optimizer used to train a timm model.

    Args:
        model: A timm model. It must expose ``num_classes`` and a
            ``default_cfg`` mapping containing an ``"input_size"`` entry.
        device: Device the loss module is moved to (anything accepted by
            ``nn.Module.to``).
        opt_args: Optional optimizer settings. When omitted, the SGD defaults
            defined in ``__init__`` are used.
    """

    def __init__(self, model, device, *, opt_args: Optional[OptimizerOption] = None):
        self.model = model
        self.device = device

        # Configurations
        self.num_classes = self.model.num_classes
        self.loss = nn.CrossEntropyLoss().to(self.device)
        # Empty per-sample target shape; presumably one class index per sample,
        # as CrossEntropyLoss expects -- confirm against the benchmark harness.
        self.target_shape = ()
        # NOTE(review): presumably (channels, height, width) per timm's
        # default_cfg convention -- confirm before relying on the layout.
        self.input_size = self.model.default_cfg["input_size"]

        if opt_args is None:
            # Default optimizer configurations borrowed from:
            # https://github.com/rwightman/pytorch-image-models/blob/779107b693010934ac87c8cecbeb65796e218488/timm/optim/optim_factory.py#L78
            opt_args = OptimizerOption(
                lr=1e-4,
                opt="sgd",
                weight_decay=0.0001,
                momentum=0.9,
            )
        self.optimizer = create_optimizer(opt_args, self.model)