in parlai/core/torch_classifier_agent.py [0:0]
def __init__(self, opt: Opt, shared=None):
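        # resolve a checkpoint to initialize from (if any) and whether this run is a fine-tune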
init_model, self.is_finetune = self._get_init_model(opt, shared)
super().__init__(opt, shared)
# set up classes
if opt.get('classes') is None and opt.get('classes_from_file') is None:
raise RuntimeError(
'Must specify --classes or --classes-from-file argument.'
)
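        # the primary worker builds the class bookkeeping; shared copies reuse it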
if not shared:
if opt['classes_from_file'] is not None:
with PathManager.open(opt['classes_from_file']) as f:
self.class_list = f.read().splitlines()
else:
self.class_list = opt['classes']
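            # map each class name to its index in class_list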
self.class_dict = {val: i for i, val in enumerate(self.class_list)}
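            # per-class loss weights; default to uniform weighting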
if opt.get('class_weights', None) is not None:
self.class_weights = opt['class_weights']
else:
                self.class_weights = [1.0 for _ in self.class_list]
self.reset_metrics()
else:
self.class_list = shared['class_list']
self.class_dict = shared['class_dict']
self.class_weights = shared['class_weights']
        # in binary classification, opt['threshold'] applies to the reference class
if opt['ref_class'] is None or opt['ref_class'] not in self.class_dict:
self.ref_class = self.class_list[0]
else:
self.ref_class = opt['ref_class']
ref_class_id = self.class_list.index(self.ref_class)
            if ref_class_id != 0:
                # move to the front of the class list, and rebuild the mapping so
                # class_dict stays consistent with the reordered class_list
                self.class_list.insert(0, self.class_list.pop(ref_class_id))
                self.class_dict = {val: i for i, val in enumerate(self.class_list)}
# set up threshold, only used in binary classification
if len(self.class_list) == 2 and opt.get('threshold', 0.5) != 0.5:
self.threshold = opt['threshold']
else:
self.threshold = None
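        # e.g. with class_list == ['pos', 'neg'] and --threshold 0.7, the intent is
        # that an example is labeled 'pos' only when P(pos) >= 0.7 (a sketch of the
        # semantics; the threshold is applied in the prediction code, not here)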
# set up calculating auc
self.calc_auc = opt.get('area_under_curve_digits', -1) > 0
if self.calc_auc:
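            # number of decimal places used to bucket probabilities when
            # approximating the AUC curve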
self.auc_bucket_decimal_size = opt.get('area_under_curve_digits')
if opt.get('area_under_curve_class') is None:
                # default to computing AUC over every class
interested_classes = self.class_list
else:
interested_classes = opt.get('area_under_curve_class')
try:
self.auc_class_indices = [
self.class_dict[class_name] for class_name in interested_classes
]
            except KeyError as e:
                raise RuntimeError(
                    'Invalid class name(s) given for AUC.\n'
                    f'Current class names: {self.class_list}\n'
                    f'Names of AUC classes passed in: {interested_classes}'
                ) from e
self.reset_auc()
# set up model and optimizers
states = {}
if shared:
self.model = shared['model']
else:
self.model = self.build_model()
# freeze the encoder and update the classifier only
            if opt.get('update_classifier_head_only', False):
                for param_name, param in self.model.named_parameters():
                    if not param_name.startswith('additional_linear_layer'):
                        param.requires_grad = False
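                # note: this assumes the classifier head is the module named
                # 'additional_linear_layer'; models without it end up fully frozen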
self.criterion = self.build_criterion()
if self.model is None or self.criterion is None:
raise AttributeError(
                'build_model() and build_criterion() must return the model '
                'and criterion, respectively'
)
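        # optionally load existing weights; `states` also feeds the LR scheduler below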
if init_model:
logging.info(f'Loading existing model parameters from {init_model}')
states = self.load(init_model)
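        # move model and criterion to GPU, with optional pipeline or data parallelism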
if self.use_cuda:
if self.model_parallel:
ph = PipelineHelper()
ph.check_compatibility(self.opt)
self.model = ph.make_parallel(self.model)
else:
self.model.cuda()
if self.data_parallel:
self.model = torch.nn.DataParallel(self.model)
self.criterion.cuda()
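        # report total vs. trainable parameter counts (they differ when the encoder is frozen)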
train_params = trainable_parameters(self.model)
total_params = total_parameters(self.model)
logging.info(
f"Total parameters: {total_params:,d} ({train_params:,d} trainable)"
)
if shared:
# We don't use get here because hasattr is used on optimizer later.
if 'optimizer' in shared:
self.optimizer = shared['optimizer']
elif self._should_initialize_optimizer():
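            # only optimize parameters that still require grad (e.g. just the
            # classifier head when update_classifier_head_only is set)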
optim_params = [p for p in self.model.parameters() if p.requires_grad]
self.init_optim(optim_params)
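            # when fine-tuning, hard-reset the LR scheduler rather than resuming
            # any scheduler state loaded from the checkpoint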
self.build_lr_scheduler(states, hard_reset=self.is_finetune)
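
    # Usage sketch (hypothetical invocation; the flags mirror the opt keys read
    # above, e.g. --classes / --ref-class / --threshold):
    #   parlai train_model -m transformer/classifier -t <some_task> \
    #       --classes pos neg --ref-class pos --threshold 0.7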