in archs/models.py [0:0]
def forward(self, tdesc=None, return_additional=False, gating_wt=None):
    """Run the gater on ``tdesc`` and return its gating output.

    Args:
        tdesc: Input forwarded as the first positional argument to
            ``self.gater``.
        return_additional: If True, return a ``(gating, weights)`` pair
            instead of just the gating output.
        gating_wt: Optional value forwarded to ``self.gater`` as the
            ``gating_wt`` keyword. It is only passed when not None.
            Unless the gater returns a tuple, it is also the returned
            weights.

    Returns:
        When ``self.gater`` is None: ``None``, or ``(None, None)`` if
        ``return_additional`` is True.
        Otherwise the gating output, or ``(gating, weights)`` if
        ``return_additional`` is True. If the gater returns a tuple, its
        element 0 is the gating output and element 1 is the weights.
        Otherwise the weights are ``gating_wt``.

    Raises:
        NotImplementedError: If ``self._stoch_sample`` is truthy
            (stochastic sampling of the gating output is not implemented).
    """
    if self.gater is None:
        return (None, None) if return_additional else None

    # Fail fast: previously the gater's forward pass ran and its result was
    # then discarded by this raise.
    if self._stoch_sample:
        raise NotImplementedError(
            "stochastic sampling of the gating output is not implemented"
        )

    # gating_wt is only forwarded when supplied. Presumably some gaters do not
    # accept this keyword; confirm against the gater implementations.
    if gating_wt is not None:
        gating_g = self.gater(tdesc, gating_wt=gating_wt)
    else:
        gating_g = self.gater(tdesc)

    return_wts = gating_wt
    # A gater may return (gating, weights); its weights take precedence.
    if isinstance(gating_g, tuple):
        return_wts = gating_g[1]
        gating_g = gating_g[0]

    sampled_g = gating_g  # deterministic path: use the gating output as-is
    if return_additional:
        return sampled_g, return_wts
    return sampled_g