in vihds/encoders.py [0:0]
def forward(self, delta_obs, conds, devs, *, stop_grad=False):
    """Encode the enabled conditioning inputs into a new parameter distribution.

    The inputs selected by ``self.condition_data``,
    ``self.condition_treatments`` and ``self.condition_devices`` are
    concatenated along dim 1. The result is fed to one layer per free
    parameter of ``self.description`` (``self.layers[free_name]``). Each
    layer output is mapped to its constrained counterpart with
    ``constrain_parameter``. The free values, the constrained values and
    ``self.description.other_params`` are then assigned to a fresh instance
    of ``self.description.class_type``.

    Args:
        delta_obs: Tensor used only when ``self.condition_data`` is set.
            Assumed to be at least 2-D (e.g. (batch, features)) because it
            is concatenated along dim 1 -- TODO confirm against callers.
        conds: Tensor used only when ``self.condition_treatments`` is set.
        devs: Tensor used only when ``self.condition_devices`` is set.
        stop_grad: If True, detach each free parameter from the autograd
            graph before it is constrained (the PyTorch counterpart of
            ``tf.stop_gradient``). The original intent was to eliminate the
            score-function term from autodiff. Defaults to False, which
            matches the previous behaviour.

    Returns:
        A new ``self.description.class_type`` instance, built with
        ``wait_for_assigned=True, variable=True``, on which
        ``assign_free_and_constrained`` has been called with all parameters.
    """
    # Collect the enabled inputs and concatenate them in a single call.
    # Seeding with ``torch.Tensor([])`` and concatenating repeatedly relied on
    # torch.cat's special case for 1-D empty tensors. It also introduced a CPU
    # float32 seed tensor regardless of the inputs' device and dtype.
    selected = [
        tensor
        for enabled, tensor in (
            (self.condition_data, delta_obs),
            (self.condition_treatments, conds),
            (self.condition_devices, devs),
        )
        if enabled
    ]
    # Keep the original empty-tensor fallback when no input is enabled, so
    # such configurations behave exactly as before.
    x = torch.cat(selected, 1) if selected else torch.Tensor([])

    params = OrderedDict()
    for free_name, constrained_name, free_to_constrained in zip(
        self.description.free_params,
        self.description.params,
        self.description.free_to_constrained,
    ):
        free_param = self.layers[free_name](x)
        if stop_grad:
            # Equivalent of tf.stop_gradient: cut this value out of autodiff.
            free_param = free_param.detach()
        params[free_name] = free_param
        params[constrained_name] = constrain_parameter(free_param, free_to_constrained)

    # Copy the description's remaining parameters in unchanged, after the
    # encoded ones (same insertion order as before).
    params.update(self.description.other_params)

    new_distribution = self.description.class_type(wait_for_assigned=True, variable=True)
    new_distribution.assign_free_and_constrained(**params)
    return new_distribution