in ppuda/deepnets1m/net.py [0:0]
def forward(self, s0, s1, drop_path_prob=0):
    """Compute the intermediate states from two inputs and concatenate them.

    Each of the ``self._steps`` steps picks two earlier states through
    ``self._indices``, applies the matching pair of ops from ``self._ops``
    and stores the sum of the two results as a new state.  An input is
    skipped when its op is a ``Zero`` instance, when ``_is_none(op)`` is
    true, or when the input state itself is ``None``; a step whose two
    inputs are both skipped produces a ``None`` state.

    Args:
        s0: First input state, or ``None``.  It is passed through
            ``self.preprocess0`` unless it is ``None`` or
            ``_is_none(self.preprocess0)`` is true, in which case the
            state becomes ``None``.
        s1: Second input state, handled like ``s0`` but with
            ``self.preprocess1``.
        drop_path_prob: Value forwarded to ``drop_path``.  Drop-path is
            applied only while ``self.training`` is true, when this value
            is greater than 0, and never to the output of an
            ``nn.Identity`` op.

    Returns:
        ``torch.cat`` along ``dim=1`` of the states selected by
        ``self._concat``, or ``None`` when every selected state is ``None``.

    Raises:
        AssertionError: If a selected state is ``None`` although
            ``self._has_none`` is false.
    """
    s0 = None if (s0 is None or _is_none(self.preprocess0)) else self.preprocess0(s0)
    s1 = None if (s1 is None or _is_none(self.preprocess1)) else self.preprocess1(s1)

    states = [s0, s1]
    for i in range(self._steps):
        h1 = states[self._indices[2 * i]]
        h2 = states[self._indices[2 * i + 1]]
        op1 = self._ops[2 * i]
        op2 = self._ops[2 * i + 1]

        # `s` stays None unless at least one of the two inputs is processed.
        s = None
        if not (isinstance(op1, Zero) or _is_none(op1) or h1 is None):
            h1 = op1(h1)
            if self.training and drop_path_prob > 0 and not isinstance(op1, nn.Identity):
                h1 = drop_path(h1, drop_path_prob)
            s = h1

        if not (isinstance(op2, Zero) or _is_none(op2) or h2 is None):
            h2 = op2(h2)
            if self.training and drop_path_prob > 0 and not isinstance(op2, nn.Identity):
                h2 = drop_path(h2, drop_path_prob)
            try:
                # `s is not None` implies the first branch ran, so `h1`
                # already holds the output of `op1` here.
                s = h2 if s is None else (h1 + h2)
            except Exception:
                # Print debugging context for the failed addition, then
                # re-raise the original error unchanged.  `Exception` (not a
                # bare `except`) keeps KeyboardInterrupt/SystemExit out of
                # this diagnostic path.
                print(h1.shape, h2.shape, self.genotype)
                raise

        states.append(s)

    if any(states[i] is None for i in self._concat):
        # Replace None states with zeros to match feature dimensionalities
        # and enable the forward pass.
        assert self._has_none, self.genotype

        # Use the first non-None selected state, multiplied by 0, as the
        # stand-in for every missing state.
        s_dummy = None
        for i in self._concat:
            if states[i] is not None:
                s_dummy = states[i] * 0
                break

        if s_dummy is None:
            # Nothing to concatenate: every selected state is None.
            return None

        for i in self._concat:
            if states[i] is None:
                states[i] = s_dummy

    return torch.cat([states[i] for i in self._concat], dim=1)