in horovod/spark/torch/remote.py [0:0]
def RemoteTrainer(estimator, metadata, last_checkpoint_state, run_id, dataset_idx):
    # Estimator parameters
    gradient_compression = estimator.getGradientCompression()
    input_shapes = estimator.getInputShapes()
    label_shapes = estimator.getLabelShapes()
    feature_columns = estimator.getFeatureCols()
    label_columns = estimator.getLabelCols()
    num_labels = len(label_columns)
    should_validate = estimator.getValidation()
    batch_size = estimator.getBatchSize()
    epochs = estimator.getEpochs()
    train_steps_per_epoch = estimator.getTrainStepsPerEpoch()
    validation_steps_per_epoch = estimator.getValidationStepsPerEpoch()
    sample_weight_col = estimator.getSampleWeightCol()
    metric_fn_groups = estimator.getMetrics()
    user_shuffle_buffer_size = estimator.getShufflingBufferSize()
    user_verbose = estimator.getVerbose()
    train_minibatch_fn = estimator.getTrainMinibatchFn()
    train_minibatch = train_minibatch_fn if train_minibatch_fn else _train_minibatch_fn()
    loss_fns_pre_train = to_list(estimator.getLoss(), num_labels)
    loss_constructors = to_list(estimator.getLossConstructors(), num_labels)
    transformation_fn = estimator.getTransformationFn()
    transformation = transformation_fn if transformation_fn else None

    # If loss weights are not provided, weight all labels equally.
    loss_weights = estimator.getLossWeights()
    if not loss_weights:
        loss_weights = [float(1) / num_labels for _ in range(num_labels)]
    else:
        if not isinstance(loss_weights, list) or \
                len(loss_weights) != len(label_columns):
            raise ValueError('loss_weights needs to be a list with the same '
                             'length as the label_columns.')

    # Data reader parameters
    train_reader_worker_count = estimator.getTrainReaderNumWorker()
    val_reader_worker_count = estimator.getValReaderNumWorker()

    # Utility functions
    deserialize = deserialize_fn()
    get_optimizer_with_unscaled_lr = _get_optimizer_with_unscaled_lr_fn()
    calculate_shuffle_buffer_size = _calculate_shuffle_buffer_size_fn()
    construct_metric_value_holders = _construct_metric_value_holders_fn()
    metric_cls = _metric_cls()
    prepare_np_data = _prepare_np_data_fn()
    get_metric_avgs = _get_metric_avgs_fn()
    update_metrics = _update_metrics_fn(metric_fn_groups)
    write_metrics_summary = _write_metrics_summary_fn()
    calculate_loss = _calculate_loss_fn()

    # Storage
    store = estimator.getStore()
    remote_store = store.to_remote(run_id, dataset_idx)

    @contextlib.contextmanager
    def empty_batch_reader():
        yield None
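
    # train() is the function executed on every Horovod worker. It receives the serialized
    # model, the optimizer class and serialized optimizer state, and dataset statistics, and
    # returns the training history plus the final checkpoint on rank 0.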
    def train(serialized_model, optimizer_cls, model_opt_state_serialized,
              train_rows, val_rows, avg_row_size):
        from petastorm import TransformSpec, make_reader, make_batch_reader
        from petastorm.pytorch import BatchedDataLoader
        import torch
        import horovod.torch as hvd

        # Deserializing objects
        model_opt_state = torch.load(model_opt_state_serialized)
        model = deserialize(serialized_model)
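
        # Resolve the per-label loss functions: use the losses supplied to the estimator, or
        # build them from the loss constructors, which receive this worker's local variables.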
        if loss_fns_pre_train:
            loss_fns = loss_fns_pre_train
        if loss_constructors:
            local_vars = locals()
            loss_fns = [loss_constructor(**local_vars) for loss_constructor in loss_constructors]

        # Horovod: initialize library.
        hvd.init()
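
        # If the user did not set a shuffling buffer size, estimate one from the average row
        # size and this worker's share of the training rows (train_rows / hvd.size()).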
        if not user_shuffle_buffer_size:
            shuffle_buffer_size = \
                calculate_shuffle_buffer_size(hvd, avg_row_size, train_rows / hvd.size())
        else:
            shuffle_buffer_size = user_shuffle_buffer_size

        cuda_available = torch.cuda.is_available()
        if cuda_available:
            # Horovod: pin GPU to local rank.
            torch.cuda.set_device(hvd.local_rank())
            # Move model to GPU.
            model.cuda()

        # The optimizer object needs to be re-instantiated. Internally, it uses memory addresses
        # of objects as their identity, so it cannot simply be serialized and then deserialized:
        # the deserialized optimizer would refer to parameters by their old memory addresses,
        # which no longer match the parameters of the reconstructed model, and that creates
        # problems.
        # The learning rate is a required parameter for the SGD optimizer; it will be overridden
        # by load_state_dict.
        optimizer = optimizer_cls(model.parameters(), lr=1)
        optimizer_state = model_opt_state['optimizer']

        if last_checkpoint_state is not None:
            model.load_state_dict(last_checkpoint_state['model'])
            optimizer.load_state_dict(last_checkpoint_state['optimizer'])
        else:
            # Scale the learning rate with the number of horovod workers.
            for i in range(len(optimizer_state['param_groups'])):
                optimizer_state['param_groups'][i]['lr'] = \
                    optimizer_state['param_groups'][i]['lr'] * hvd.size()

            optimizer.load_state_dict(optimizer_state)

        # Horovod: broadcast parameters & optimizer state.
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
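
        # Zero-fill gradients for parameters that have no optimizer state yet and take one
        # optimizer step, so that a state entry exists for every parameter before the optimizer
        # state is broadcast from rank 0.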
        for group in optimizer.param_groups:
            for p in group['params']:
                if id(p) not in optimizer.state_dict()['state']:
                    p.grad = p.data.new(p.size()).zero_()
        optimizer.step()
        hvd.broadcast_optimizer_state(optimizer, root_rank=0)

        dist_optimizer_args = dict(optimizer=optimizer,
                                   named_parameters=model.named_parameters())
        if gradient_compression:
            # Pass the compression arg only if it is specified by the user.
            dist_optimizer_args['compression'] = gradient_compression
        # Horovod: wrap optimizer with DistributedOptimizer.
        optimizer = hvd.DistributedOptimizer(**dist_optimizer_args)

        # get_optimizer_with_unscaled_lr (used by save_checkpoint below) takes the current
        # optimizer and constructs a new optimizer with the same state, except with the learning
        # rate scaled back down by the number of horovod workers. This matters when retraining
        # the model: the user may retrain with a different number of workers, so the raw
        # learning rate is needed to adjust it to the new worker count.
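
        # Wrap the user-provided transformation, if any, in a petastorm TransformSpec so the
        # reader applies it to the data it produces.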
        transform_spec = None
        if transformation:
            transform_spec = TransformSpec(transformation)

        schema_fields = feature_columns + label_columns
        if sample_weight_col:
            schema_fields.append(sample_weight_col)
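
        # By default, one epoch covers this worker's shard of the training data once:
        # ceil(train_rows / batch_size / hvd.size()) steps.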
        if train_steps_per_epoch is None:
            steps_per_epoch = int(math.ceil(float(train_rows) / batch_size / hvd.size()))
        else:
            steps_per_epoch = train_steps_per_epoch

        with remote_store.get_local_output_dir() as run_output_dir:
            logs_dir = os.path.join(run_output_dir, remote_store.logs_subdir)
            log_writer = SummaryWriter(logs_dir) if hvd.rank() == 0 else None
            ckpt_file = os.path.join(run_output_dir, remote_store.checkpoint_filename)
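
            # save_checkpoint moves the model to the CPU before capturing state dicts, writes
            # the model and the unscaled-learning-rate optimizer state to the local checkpoint
            # file, and then moves the model back to the GPU if one is in use.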
            def save_checkpoint():
                model.cpu()
                optimizer_with_scaled_down_lr = \
                    get_optimizer_with_unscaled_lr(hvd, optimizer, optimizer_cls, model)
                state = {
                    'model': model.state_dict(),
                    'optimizer': optimizer_with_scaled_down_lr.state_dict(),
                }
                torch.save(state, ckpt_file)
                if cuda_available:
                    model.cuda()

            # In general, make_batch_reader is faster than make_reader for reading the dataset.
            # However, we found that make_reader performs data transformations much faster than
            # make_batch_reader with parallel worker processes. Therefore, the default reader
            # we choose is make_batch_reader unless there are data transformations.
            reader_factory = None
            reader_factory_kwargs = dict()
            if transform_spec:
                reader_factory = make_reader
                reader_factory_kwargs['pyarrow_serialize'] = True
            else:
                reader_factory = make_batch_reader

            # Petastorm: read data from the store with the correct shard for this rank.
            # Setting num_epochs=None creates an infinite iterator, which lets ranks perform
            # training and validation with an unequal number of samples.
            with reader_factory(remote_store.train_data_path,
                                num_epochs=None,
                                cur_shard=hvd.rank(),
                                reader_pool_type='process',
                                workers_count=train_reader_worker_count,
                                shard_count=hvd.size(),
                                hdfs_driver=PETASTORM_HDFS_DRIVER,
                                schema_fields=schema_fields,
                                transform_spec=transform_spec,
                                **reader_factory_kwargs) as train_reader:
                with reader_factory(remote_store.val_data_path,
                                    num_epochs=None,
                                    cur_shard=hvd.rank(),
                                    reader_pool_type='process',
                                    workers_count=val_reader_worker_count,
                                    shard_count=hvd.size(),
                                    hdfs_driver=PETASTORM_HDFS_DRIVER,
                                    schema_fields=schema_fields,
                                    transform_spec=transform_spec,
                                    **reader_factory_kwargs) \
                        if should_validate else empty_batch_reader() as val_reader:
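
                    # petastorm's BatchedDataLoader batches rows from the reader and shuffles
                    # them through an in-memory queue of the capacity computed above.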
                    train_loader = BatchedDataLoader(train_reader,
                                                     batch_size=batch_size,
                                                     shuffling_queue_capacity=shuffle_buffer_size)
                    train_loader_iter = iter(train_loader)
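
                    # Convert one batch of rows into model inputs, labels, and optional sample
                    # weights, reshaping features and moving tensors to the GPU when available.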
                    def prepare_batch(row):
                        inputs = [
                            prepare_np_data(
                                row[col].float(), col, metadata).reshape(shape)
                            for col, shape in zip(feature_columns, input_shapes)]
                        labels = [
                            prepare_np_data(
                                row[col].float(), col, metadata)
                            for col in label_columns]

                        sample_weights = row.get(sample_weight_col, None)
                        if sample_weights is not None:
                            sample_weights = sample_weights.float()
                        if cuda_available:
                            inputs = [input.cuda() for input in inputs]
                            labels = [label.cuda() for label in labels]
                            if sample_weights is not None:
                                sample_weights = sample_weights.cuda()
                        return inputs, labels, sample_weights
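
                    # Normalize the model outputs to a list and reshape the labels so that the
                    # loss and metric functions compare tensors of matching shapes.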
                    def transform_outputs(outputs, labels):
                        if type(outputs) != tuple and type(outputs) != list:
                            outputs = [outputs]

                        # Reshape labels to match the output shape of the model.
                        if hasattr(outputs[0], 'shape'):
                            if label_shapes:
                                labels = [label.reshape(label_shape)
                                          for label, label_shape in zip(labels, label_shapes)]
                            else:
                                # If the label_shapes parameter is not provided, reshape the
                                # label columns' data to match the shape of the model output.
                                labels = [label.reshape(output.shape) if
                                          output.shape.numel() == label.shape.numel() else label
                                          for label, output in zip(labels, outputs)]

                        return outputs, labels
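
                    # Collect the averaged metric values, write them to the TensorBoard log when
                    # runs are being saved, and return them together with the average loss.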
                    def aggregate_metrics(stage, epoch, loss, metric_value_groups):
                        all_metric_groups_values = get_metric_avgs(metric_value_groups)
                        if remote_store.saving_runs:
                            write_metrics_summary(
                                stage, epoch, loss, all_metric_groups_values, log_writer)
                        return {
                            loss.name: loss.avg.item(),
                            'all_metrics': all_metric_groups_values
                        }

                    def loss_fn(outputs, labels, sample_weights):
                        loss = calculate_loss(outputs, labels, loss_weights, loss_fns, sample_weights)
                        return loss

                    def print_metrics(batch_idx, loss, metric_value_groups, phase):
                        if user_verbose > 0 and hvd.rank() == 0 and \
                                batch_idx % METRIC_PRINT_FREQUENCY == 0:
                            print("epoch:\t{epoch}\tstep\t{batch_idx}:\t{metrics}".
                                  format(epoch=epoch,
                                         batch_idx=batch_idx,
                                         metrics=aggregate_metrics(phase, epoch, loss,
                                                                   metric_value_groups)))
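
                    # Train for a single epoch: pull steps_per_epoch batches from the infinite
                    # reader, run the (possibly user-supplied) train_minibatch function, and
                    # accumulate the loss and metrics.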
                    def _train(epoch):
                        model.train()
                        train_loss = metric_cls('loss', hvd)
                        metric_value_groups = construct_metric_value_holders(
                            metric_cls, metric_fn_groups, label_columns, hvd)

                        # Iterate over one epoch.
                        for batch_idx in range(steps_per_epoch):
                            row = next(train_loader_iter)
                            inputs, labels, sample_weights = prepare_batch(row)

                            outputs, loss = train_minibatch(model, optimizer, transform_outputs,
                                                            loss_fn, inputs, labels, sample_weights)
                            update_metrics(metric_value_groups, outputs, labels)
                            train_loss.update(loss)
                            print_metrics(batch_idx, train_loss, metric_value_groups, 'train')

                        return aggregate_metrics('train', epoch, train_loss, metric_value_groups)
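
                    # Validation mirrors training but runs the model in eval mode and performs
                    # no optimizer updates; by default it covers each worker's shard of the
                    # validation rows once per epoch.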
                    if should_validate:
                        val_loader = BatchedDataLoader(val_reader, batch_size=batch_size)
                        val_loader_iter = iter(val_loader)
                        if validation_steps_per_epoch is None:
                            validation_steps = int(math.ceil(float(val_rows) / batch_size / hvd.size()))
                        else:
                            validation_steps = validation_steps_per_epoch

                        def _validate(epoch):
                            model.eval()
                            val_loss = metric_cls('loss', hvd)
                            metric_value_groups = construct_metric_value_holders(
                                metric_cls, metric_fn_groups, label_columns, hvd)

                            # Iterate over one epoch.
                            for batch_idx in range(validation_steps):
                                row = next(val_loader_iter)
                                inputs, labels, sample_weights = prepare_batch(row)

                                outputs = model(*inputs)
                                outputs, labels = transform_outputs(outputs, labels)

                                loss = calculate_loss(
                                    outputs, labels, loss_weights, loss_fns, sample_weights)
                                val_loss.update(loss)

                                update_metrics(metric_value_groups, outputs, labels)
                                print_metrics(batch_idx, val_loss, metric_value_groups, 'val')

                            return aggregate_metrics('val', epoch, val_loss, metric_value_groups)
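
                    # Main loop: train (and optionally validate) for the requested number of
                    # epochs. Rank 0 checkpoints after every epoch and, when runs are being
                    # saved, syncs the local run directory back to the store.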
                    history = []
                    for epoch in range(epochs):
                        epoch_metrics = {
                            'epoch': epoch,
                            'train': _train(epoch)
                        }
                        if should_validate:
                            epoch_metrics['validation'] = _validate(epoch)
                        if user_verbose > 0:
                            print(epoch_metrics)
                        history.append(epoch_metrics)

                        if hvd.rank() == 0:
                            # Save the model after every epoch.
                            save_checkpoint()
                            if remote_store.saving_runs:
                                remote_store.sync(run_output_dir)
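
                    # Only rank 0 returns results: reload the last checkpoint from disk and
                    # re-serialize it into an in-memory buffer that can be sent back to the
                    # driver. Other ranks fall through and implicitly return None.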
                    if hvd.rank() == 0:
                        best_checkpoint = torch.load(ckpt_file)
                        serialized_checkpoint = io.BytesIO()
                        torch.save(best_checkpoint, serialized_checkpoint)
                        serialized_checkpoint.seek(0)
                        return history, serialized_checkpoint

    return train
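
# Usage sketch (not part of this file; the backend name and call shape below are assumptions
# for illustration): the estimator builds this trainer and hands the returned `train` function
# to its Horovod-on-Spark backend, which runs it on every worker, roughly:
#
#     trainer = RemoteTrainer(estimator, metadata, last_checkpoint_state, run_id, dataset_idx)
#     results = backend.run(trainer,
#                           args=(serialized_model, optimizer_cls, model_opt_state_serialized,
#                                 train_rows, val_rows, avg_row_size),
#                           env={})
#     history, serialized_checkpoint = results[0]  # rank 0's return value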