in dlrm_s_pytorch.py [0:0]
def parallel_forward(self, dense_x, lS_o, lS_i):
### prepare model (overwrite) ###
# WARNING: the number of devices used must be <= the batch size in the parallel_forward call
batch_size = dense_x.size()[0]
ndevices = min(self.ndevices, batch_size, len(self.emb_l))
device_ids = range(ndevices)
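# devices cuda:0 .. cuda:(ndevices-1) will each hold a replica of the bottom/top
# mlps (data parallelism) and a subset of the embedding tables (model parallelism)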
# WARNING: the model must be redistributed if the mini-batch size changes (this is common
# for the last mini-batch, when the dataset size is not evenly divisible by the batch size)
if self.parallel_model_batch_size != batch_size:
self.parallel_model_is_not_prepared = True
if self.parallel_model_is_not_prepared or self.sync_dense_params:
# replicate mlp (data parallelism)
self.bot_l_replicas = replicate(self.bot_l, device_ids)
self.top_l_replicas = replicate(self.top_l, device_ids)
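# replicate() copies the bottom/top mlps onto every device in device_ids; when
# sync_dense_params is set this is redone on every call so the replicas track the
# current weights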
self.parallel_model_batch_size = batch_size
if self.parallel_model_is_not_prepared:
# distribute embeddings (model parallelism)
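# tables are placed round-robin: table k (and its pooling weights, if any) goes
# to cuda:(k % ndevices)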
t_list = []
w_list = []
for k, emb in enumerate(self.emb_l):
d = torch.device("cuda:" + str(k % ndevices))
t_list.append(emb.to(d))
if self.weighted_pooling == "learned":
w_list.append(Parameter(self.v_W_l[k].to(d)))
elif self.weighted_pooling == "fixed":
w_list.append(self.v_W_l[k].to(d))
else:
w_list.append(None)
self.emb_l = nn.ModuleList(t_list)
if self.weighted_pooling == "learned":
self.v_W_l = nn.ParameterList(w_list)
else:
self.v_W_l = w_list
self.parallel_model_is_not_prepared = False
### prepare input (overwrite) ###
# scatter dense features (data parallelism)
# print(dense_x.device)
dense_x = scatter(dense_x, device_ids, dim=0)
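# dense_x is now a tuple of ndevices chunks, split along the batch (first)
# dimension, with chunk k resident on cuda:k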
# distribute sparse features (model parallelism)
if (len(self.emb_l) != len(lS_o)) or (len(self.emb_l) != len(lS_i)):
sys.exit("ERROR: corrupted model input detected in parallel_forward call")
t_list = []
i_list = []
for k, _ in enumerate(self.emb_l):
d = torch.device("cuda:" + str(k % ndevices))
t_list.append(lS_o[k].to(d))
i_list.append(lS_i[k].to(d))
lS_o = t_list
lS_i = i_list
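# lS_o[k] / lS_i[k] (offsets / indices for table k) now live on the same device
# as embedding table k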
### compute results in parallel ###
# bottom mlp
# WARNING: self.bot_l_replicas is a list of bottom mlp replicas, one per device,
# while dense_x is a tuple of dense inputs that have been scattered across devices
# along the first (batch) dimension. The output is a list of tensors distributed
# across devices according to the distribution of dense_x.
x = parallel_apply(self.bot_l_replicas, dense_x, None, device_ids)
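# x is a list of ndevices tensors; x[k] is the bottom mlp output for batch
# shard k, resident on cuda:k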
# debug prints
# print(x)
# embeddings
ly = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l)
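# ly[k] is the pooled output of embedding table k for the entire batch,
# resident on the device that owns table k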
# debug prints
# print(ly)
# butterfly shuffle (implemented inefficiently for now)
# WARNING: at this point each device holds the full-batch lookup result for the
# embedding tables it owns. We instead want partial results covering all embedding
# tables but only part of the batch on each device, matching the distribution of
# the bottom mlp output, so that both can be used for the subsequent interactions
# on each device.
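# Illustrative example with ndevices = 2 and two tables T0 (on cuda:0) and T1
# (on cuda:1): before the shuffle ly = [T0(full batch), T1(full batch)];
# after the scatter + transpose below
# ly = [[T0 shard 0, T1 shard 0] on cuda:0, [T0 shard 1, T1 shard 1] on cuda:1].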
if len(self.emb_l) != len(ly):
sys.exit("ERROR: corrupted intermediate result in parallel_forward call")
t_list = []
for k, _ in enumerate(self.emb_l):
d = torch.device("cuda:" + str(k % ndevices))
y = scatter(ly[k], device_ids, dim=0)
t_list.append(y)
# adjust the list to be ordered per device
ly = list(map(lambda y: list(y), zip(*t_list)))
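# ly[j] now holds, for every embedding table, the slice of its output belonging
# to batch shard j, all resident on cuda:j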
# debug prints
# print(ly)
# interactions
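# each device computes the feature interactions between its bottom mlp output
# and its embedding slices for its own batch shard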
z = []
for k in range(ndevices):
zk = self.interact_features(x[k], ly[k])
z.append(zk)
# debug prints
# print(z)
# top mlp
# WARNING: self.top_l_replicas is a list of top mlp replicas, one per device,
# while z is a list of interaction results that, by construction, are scattered
# across devices along the first (batch) dimension. The output is a list of
# tensors distributed across devices according to the distribution of z.
p = parallel_apply(self.top_l_replicas, z, None, device_ids)
### gather the distributed results ###
p0 = gather(p, self.output_d, dim=0)
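# p0 holds the predictions for the full mini-batch, concatenated along the
# batch dimension on device self.output_d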
# clamp output if needed
if 0.0 < self.loss_threshold and self.loss_threshold < 1.0:
z0 = torch.clamp(
p0, min=self.loss_threshold, max=(1.0 - self.loss_threshold)
)
else:
z0 = p0
return z0