in dlrm_s_caffe2.py [0:0]
def create_parallel_forward_ops(self):
    """Build the hybrid-parallel forward pass.

    Embedding tables are model-parallel (each table is pinned to one
    device), the bottom/top MLPs are data-parallel (replicated on every
    device), and a butterfly shuffle in between redistributes the
    per-table embedding outputs so each device holds its mini-batch slice
    of every feature. Sets ``self.last_output`` to the final top-MLP blob.
    """
    # embeddings: model parallelism (tables sharded across devices)
    tag = (self.temb, self.tsin, self.tsout)
    self.emb_l, self.emb_w, self.emb_vw = self.create_emb(
        self.m_spa, self.ln_emb, self.model, tag
    )
    # bottom mlp: data parallelism (one replica per device)
    tag = (self.tbot, self.tdin, self.tdout)
    self.bot_l, self.bot_w = self.create_mlp(
        self.ln_bot, self.sigmoid_bot, self.model, tag
    )

    # butterfly shuffle: split each embedding output along the batch axis
    # on its home device, then scatter the shards to their target devices
    shuffled = []
    for idx, emb_out in enumerate(self.emb_l):
        home_dev = idx % self.ndevices
        shard_blobs = [
            emb_out + "_split_" + str(d) for d in range(self.ndevices)
        ]
        # approach 1: np and caffe2 operators assume the mini-batch size
        # is divisible exactly by the number of available devices
        with core.DeviceScope(
            core.DeviceOption(workspace.GpuDeviceType, home_dev)
        ):
            self.model.net.Split(emb_out, shard_blobs, axis=0)
        # approach 2 (unused): np and caffe2 operators do not assume
        # exact divisibility
        # ls = where_to_split(args.mini_batch_size, self.ndevices, _add_leftover=True)
        # with core.DeviceScope(core.DeviceOption(workspace.GpuDeviceType, home_dev)):
        #     emb_output_split = self.model.net.Split(
        #         emb_out, shard_blobs, split=lp, axis=0
        #     )

        # scatter: copy shard d onto device d; a shard already resident on
        # its target device (name unchanged) is passed through as-is
        per_dev = []
        for dev in range(len(shard_blobs)):
            src = shard_blobs[dev]
            dst = str(src).replace(
                "gpu_" + str(home_dev), "gpu_" + str(dev), 1
            )
            if src == dst:
                per_dev.append(dst)
            else:
                with core.DeviceScope(
                    core.DeviceOption(workspace.GpuDeviceType, dev)
                ):
                    per_dev.append(self.model.Copy(src, dst))
        shuffled.append(per_dev)

    # transpose [table][device] -> [device][table] so each device sees
    # all of its local blobs together
    bot_per_dev = [list(cols) for cols in zip(*self.bot_l)]
    emb_per_dev = [list(cols) for cols in zip(*shuffled)]

    # feature interactions, computed independently on every device
    for dev in range(self.ndevices):
        prefix = "gpu_" + str(dev) + "/"
        tag = (prefix + self.tdout, prefix + self.tsout, prefix + self.tint)
        with core.DeviceScope(core.DeviceOption(workspace.GpuDeviceType, dev)):
            self.create_interactions(
                [bot_per_dev[dev][-1]], emb_per_dev[dev], self.model, tag
            )

    # top mlp: data parallelism (one replica per device)
    tag = (self.ttop, self.tint, self.tout)
    self.top_l, self.top_w = self.create_mlp(
        self.ln_top, self.sigmoid_top, self.model, tag
    )
    # debug prints
    # print(self.model.net.Proto(),end='\n')
    # sys.exit("ERROR: debugging")

    # the final top-MLP blob is the model output
    self.last_output = self.top_l[-1]