# archs/models.py
def forward(self, x, sampled_g=None, t=None, return_feat=False):
"""TODO: Docstring for forward.
:x: Input data
:g: Gating tensor (#Task x )#num_layer x #num_mods x #num_mods
:t: task ID
:returns: TODO
"""
    if t is None:
        # No task IDs provided: default every sample to task 0.
        t_not_set = True
        t = torch.zeros(x.shape[0], dtype=torch.long, device=x.device)
    else:
        t_not_set = False
        # Flatten to a 1-D vector of per-sample task IDs.
        t = t.reshape(-1)
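    # Optional stem: each start module independently transforms the raw input.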
if self._start_modules is not None:
prev_out = [mod(x) for mod in self._start_modules]
else:
prev_out = [x]
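    # Two paths: without sampled_g the parallel modules are averaged uniformly;
    # with sampled_g their outputs are mixed by the sampled gating weights.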
if sampled_g is None:
        # Non-gated path: every module at a layer contributes equally (uniform mixing).
        prev_out = sum(prev_out) / float(len(prev_out))
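        # Hidden layers: average the outputs of each layer's parallel modules.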
for li in range(len(self._module_list)):
prev_out = sum([
mod(prev_out) for mod in self._module_list[li]
]) / float(len(self._module_list[li]))
features = prev_out
if self._end_modules is not None:
if t_not_set or self.single_head:
prev_out = self._end_modules[0](prev_out)
else:
prev_out = torch.cat([
self._end_modules[tid](prev_out[bi:bi + 1])
for bi, tid in enumerate(t)
], 0)
if return_feat:
return prev_out, features
return prev_out
else:
# Forward prop with sampled Gs
for li in range(len(self._module_list)):
curr_out = []
for j in range(len(self._module_list[li])):
gind = j if not self._chain else 0
                # Gating weights for module j. Dim: #Batch x #num_mods (previous layer)
module_in_wt = sampled_g[li + 1][gind]
# Module input weights rearranged to match inputs
module_in_wt = module_in_wt.transpose(0, 1)
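                # After the transpose: #num_mods x #Batch, i.e. one scalar weight
                # per (previous module, sample) pair.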
add_dims = prev_out[0].dim() + 1 - module_in_wt.dim()
module_in_wt = module_in_wt.view(*module_in_wt.shape,
*([1] * add_dims))
module_in_wt = module_in_wt.expand(
len(prev_out), *prev_out[0].shape)
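                # Mix the previous layer's outputs with the per-sample weights to
                # form this module's input.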
module_in = sum([
module_in_wt[i] * prev_out[i]
for i in range(len(prev_out))
])
mod = self._module_list[li][j]
curr_out.append(mod(module_in))
prev_out = curr_out
# Output modules (with sampled Gs)
if self._end_modules is not None:
li = self.num_layers - 1
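            # The output heads consume the final gating entry,
            # sampled_g[li + 1] == sampled_g[self.num_layers].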
if t_not_set or self.single_head:
# Dim: #Batch x C
module_in_wt = sampled_g[li + 1][0]
# Module input weights rearranged to match inputs
module_in_wt = module_in_wt.transpose(0, 1)
add_dims = prev_out[0].dim() + 1 - module_in_wt.dim()
module_in_wt = module_in_wt.view(*module_in_wt.shape,
*([1] * add_dims))
module_in_wt = module_in_wt.expand(
len(prev_out), *prev_out[0].shape)
module_in = sum([
module_in_wt[i] * prev_out[i]
for i in range(len(prev_out))
])
features = module_in
prev_out = self._end_modules[0](module_in)
else:
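                # Multi-head: mix each sample with its own task's gating row and
                # route it through that task's head.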
                curr_out = []
                feat_out = []
for bi, tid in enumerate(t):
                    gind = tid if not self._chain else 0
                    # Dim: #Batch x C
                    module_in_wt = sampled_g[li + 1][gind]
# Module input weights rearranged to match inputs
module_in_wt = module_in_wt.transpose(0, 1)
add_dims = prev_out[0].dim() + 1 - module_in_wt.dim()
module_in_wt = module_in_wt.view(
*module_in_wt.shape, *([1] * add_dims))
module_in_wt = module_in_wt.expand(
len(prev_out), *prev_out[0].shape)
module_in = sum([
module_in_wt[i] * prev_out[i]
for i in range(len(prev_out))
])
                    # Keep this sample's pre-head features alongside its head output.
                    feat_out.append(module_in[bi:bi + 1])
                    mod = self._end_modules[tid]
                    curr_out.append(mod(module_in[bi:bi + 1]))
                prev_out = curr_out
                prev_out = torch.cat(prev_out, 0)
                features = torch.cat(feat_out, 0)
if return_feat:
return prev_out, features
return prev_out
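
# Illustrative sketch of the gating mix performed above (example sizes are
# assumptions, not taken from this file): three previous-module outputs of
# shape (4, 16, 8, 8) mixed by a per-sample gate of shape #Batch x #num_mods.
#
#     import torch
#     prev_out = [torch.randn(4, 16, 8, 8) for _ in range(3)]
#     g = torch.rand(4, 3)                     # Dim: #Batch x #num_mods
#     w = g.transpose(0, 1)                    # Dim: #num_mods x #Batch
#     w = w.view(*w.shape, 1, 1, 1)            # add singleton feature dims
#     w = w.expand(3, *prev_out[0].shape)      # broadcast over feature dims
#     mixed = sum(w[i] * prev_out[i] for i in range(3))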