benchmarks/dlrm/ootb/dlrm_s_pytorch.py
def parallel_forward(self, dense_x, lS_o, lS_i):
### prepare model (overwrite) ###
# WARNING: each device must receive at least one sample, so the number of devices
# used in parallel_forward is capped at the batch size (and at the number of tables)
batch_size = dense_x.size()[0]
ndevices = min(self.ndevices_available, batch_size, self.ntables)
device_ids = range(ndevices)
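# e.g., with 8 visible GPUs, batch_size = 128 and 26 embedding tables:
# ndevices = min(8, 128, 26) = 8 and device_ids = range(8)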
# WARNING: the model must be redistributed if the mini-batch size changes (this is
# common for the last mini-batch, when the number of elements in the dataset is not
# divisible by the batch size)
if self.ndevices_in_use != ndevices:
self.ndevices_in_use = ndevices
self.prepare_parallel_model(ndevices)
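# (prepare_parallel_model presumably re-replicates the MLPs and re-places the
# embedding tables for the new device count, per the redistribution WARNING above)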
elif self.sync_dense_params:
# When training, replicate the new/updated mlp weights each iteration.
# For inference-only, this code should never run.
self.bot_l_replicas = replicate(self.bot_l, device_ids)
self.top_l_replicas = replicate(self.top_l, device_ids)
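# torch.nn.parallel.replicate returns one copy of the module per entry in
# device_ids, so the dense MLPs can run data-parallel on the scattered batch slices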
### prepare input (overwrite) ###
# scatter dense features (data parallelism)
# print(dense_x.device)
dense_x = scatter(dense_x, device_ids, dim=0)
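# scatter splits dense_x roughly evenly along dim 0 and returns a tuple of
# ndevices chunks, with chunk k placed on cuda:k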
# distribute sparse features (model parallelism)
if (self.ntables != len(lS_o)) or (self.ntables != len(lS_i)):
sys.exit("ERROR: corrupted model input detected in parallel_forward call")
lS_o = [
lS_o[k].to(torch.device("cuda:" + str(k % ndevices)))
for k in range(self.ntables)
]
lS_i = [
lS_i[k].to(torch.device("cuda:" + str(k % ndevices)))
for k in range(self.ntables)
]
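# the inputs for table k are moved to device k % ndevices, where that table's
# weights are expected to live; each device still sees indices for the full batch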
### compute results in parallel ###
# bottom mlp
# WARNING: Note that self.bot_l_replicas is a list of bottom mlp replicas, one
# per device, while dense_x is a tuple of dense input chunks that have been
# scattered across devices along the first (batch) dimension.
# The output is a list of tensors scattered across devices according to the
# distribution of dense_x.
x = parallel_apply(self.bot_l_replicas, dense_x, None, device_ids)
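# x is a list of per-device tensors; x[k] lives on device k and covers only that
# device's slice of the batch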
# debug prints
# print(x)
# embeddings
ly = self.apply_emb(lS_o, lS_i)
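# ly has one entry per embedding table; ly[k] holds the lookup result for the
# entire batch and lives on device k % ndevices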
# debug prints
# print(ly)
# butterfly shuffle (implemented inefficiently for now)
# WARNING: Note that at this point each device holds the embedding lookup results
# for the entire batch, but only for its own tables. We instead want each device
# to hold partial results for all tables, restricted to its slice of the batch.
# This matches the distribution of the bottom mlp output, so that both can be used
# for the subsequent interactions on each device.
if self.ntables != len(ly):
sys.exit("ERROR: corrupted intermediate result in parallel_forward call")
t_list = [scatter(ly[k], device_ids, dim=0) for k in range(self.ntables)]
# adjust the list to be ordered per device
ly = list(map(lambda y: list(y), zip(*t_list)))
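# illustrative example with 2 devices and 3 tables:
#   t_list = [(y0_d0, y0_d1), (y1_d0, y1_d1), (y2_d0, y2_d1)]  # per table, split by device
#   ly     = [[y0_d0, y1_d0, y2_d0], [y0_d1, y1_d1, y2_d1]]    # per device, all tables
# i.e., ly[d] now holds every table's slice for device d's part of the batch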
# debug prints
# print(ly)
# interactions
z = parallel_apply(
    self.interact_features_l,
    list(zip(itertools.repeat(self.interact_features), x, ly)),
)
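# z[k] holds the feature-interaction output for device k's slice of the batch,
# combining the bottom mlp output with all embedding slices on that device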
# debug prints
# print(z)
if self.quantize_mlp_input_with_half_call:
z = [tens.half() for tens in z]
# top mlp
# WARNING: Note that self.top_l_replicas is a list of top mlp replicas, one per
# device, while z is a list of interaction results that by construction are
# scattered across devices along the first (batch) dimension.
# The output is a list of tensors scattered across devices according to the
# distribution of z.
p = parallel_apply(self.top_l_replicas, z, None, device_ids)
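# p is a list of per-device predictions, still scattered along the batch dimension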
### gather the distributed results ###
p0 = gather(p, self.output_d, dim=0)
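# gather concatenates the per-device outputs along dim 0 onto self.output_d,
# producing a single tensor for the full batch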
# clamp output if needed
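# (keeps predictions strictly inside [loss_threshold, 1 - loss_threshold] so that
# a log-based loss, e.g. BCE, never sees exactly 0 or 1)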
if 0.0 < self.loss_threshold < 1.0:
z0 = torch.clamp(
p0, min=self.loss_threshold, max=(1.0 - self.loss_threshold)
)
else:
z0 = p0
return z0