in archs/models.py [0:0]
def __init__(self,
             module_list,
             start_modules=None,
             end_modules=None,
             single_head=False,
             chain=False):
    """Initialise the container from per-layer groups of modules.

    :param module_list: iterable of per-layer module collections. Each inner
        collection is wrapped in its own ``nn.ModuleList``, and the wrapped
        layers are stored together as ``self._module_list``.
    :param start_modules: optional collection of modules, stored as an
        ``nn.ModuleList`` in ``self._start_modules`` (``None`` when omitted).
        These do not contribute to ``self.num_layers``.
    :param end_modules: optional collection of modules, stored as an
        ``nn.ModuleList`` in ``self._end_modules`` (``None`` when omitted).
        When given, they count as one extra layer in ``self.num_layers``.
    :param single_head: flag stored unchanged as ``self.single_head``.
    :param chain: flag stored unchanged as ``self._chain``.
    """
    nn.Module.__init__(self)

    # Submodules are registered in this order (per-layer lists, then start,
    # then end); registration order fixes named_children()/state_dict() order.
    per_layer = [nn.ModuleList(group) for group in module_list]
    self._module_list = nn.ModuleList(per_layer)
    self._start_modules = (None if start_modules is None
                           else nn.ModuleList(start_modules))
    self._end_modules = (None if end_modules is None
                         else nn.ModuleList(end_modules))

    # The end block, when supplied, is counted as one additional layer;
    # start modules are not counted.
    extra_layers = 1 if end_modules is not None else 0
    self.num_layers = len(self._module_list) + extra_layers

    # Only initialised here; presumably assigned elsewhere in the class
    # (not visible in this view) -- verify.
    self.sampled_g = None
    self.single_head = single_head
    self._chain = chain