in src/bert_train.py [0:0]
def __init__(self, model_dir, device=None, epochs=10, early_stopping_patience=20, checkpoint_frequency=1,
             checkpoint_dir=None,
             accumulation_steps=1):
    """Store the training configuration and resolve the device(s) to use.

    Every argument except ``device`` is stored unchanged on the instance
    under the same name; how each one is consumed is decided by code
    outside this method.

    Args:
        model_dir: Stored as ``self.model_dir``.
        device: A device string (e.g. ``"cpu"``, ``"cuda:0"``), a
            ``torch.device``, or a sequence of devices for multi-GPU use.
            When falsy (``None``, ``""``, empty list) the device is
            auto-detected from what CUDA reports.
        epochs: Stored as ``self.epochs``.
        early_stopping_patience: Stored as ``self.early_stopping_patience``.
        checkpoint_frequency: Stored as ``self.checkpoint_frequency``
            (presumably measured in epochs -- confirm against the caller).
        checkpoint_dir: Stored as ``self.checkpoint_dir``.
        accumulation_steps: Stored as ``self.accumulation_steps``.
    """
    self.model_dir = model_dir
    self.accumulation_steps = accumulation_steps
    self.checkpoint_dir = checkpoint_dir
    self.checkpoint_frequency = checkpoint_frequency
    self.early_stopping_patience = early_stopping_patience
    self.epochs = epochs
    # Not configured here; starts out unset.
    self.snapshotter = None

    # Work out a fallback device for when the caller did not pass one:
    # every visible GPU when there are several, otherwise the single GPU,
    # otherwise the CPU.
    available_device = "cuda:0" if torch.cuda.is_available() else "cpu"
    if torch.cuda.device_count() > 1:
        available_device = [f"cuda:{i}" for i in range(torch.cuda.device_count())]
    self.device = device or available_device

    # A lone str or torch.device means one device; anything else is taken
    # to be a sequence of devices. torch.device must be checked explicitly
    # because it is neither a str nor subscriptable, so treating it as a
    # sequence would fail on the [0] lookup below.
    self._is_multigpu = not isinstance(self.device, (str, torch.device))
    self._default_device = self.device[0] if self._is_multigpu else self.device