in theseus/core/objective.py
# Imports used by this method (from the top of the file):
import warnings
from typing import Dict, Optional, Sequence

import torch

# Method of the Objective class.
def update(self, input_data: Optional[Dict[str, torch.Tensor]] = None):
    self._batch_size = None

    # Resolve a single batch size from the per-variable batch sizes. Sizes
    # are consistent if they are all equal, or if there are exactly two
    # distinct values and the smaller one is 1 (broadcastable).
    def _get_batch_size(batch_sizes: Sequence[int]) -> int:
        unique_batch_sizes = set(batch_sizes)
        if len(unique_batch_sizes) == 1:
            return batch_sizes[0]
        if len(unique_batch_sizes) == 2:
            min_bs = min(unique_batch_sizes)
            max_bs = max(unique_batch_sizes)
            if min_bs == 1:
                return max_bs
        raise ValueError("Provided data tensors must be broadcastable.")
    input_data = input_data or {}
    for var_name, data in input_data.items():
        if data.ndim < 2:
            raise ValueError(
                f"Input data tensors must have a batch dimension and "
                f"one or more data dimensions, but data.ndim={data.ndim} for "
                f"tensor with name {var_name}."
            )
        if var_name in self.optim_vars:
            self.optim_vars[var_name].update(data)
        elif var_name in self.aux_vars:
            self.aux_vars[var_name].update(data)
        elif var_name in self.cost_weight_optim_vars:
            self.cost_weight_optim_vars[var_name].update(data)
            warnings.warn(
                "Updated a variable declared as optimization, but it is "
                "only associated to cost weights and not to any cost functions. "
                "Theseus optimizers will only update optimization variables "
                "that are associated to one or more cost functions."
            )
        else:
            warnings.warn(
                f"Attempted to update a tensor with name {var_name}, "
                "which is not associated to any variable in the objective."
            )
    # Check that the batch size of all data is consistent after update
    batch_sizes = [v.data.shape[0] for v in self.optim_vars.values()]
    batch_sizes.extend([v.data.shape[0] for v in self.aux_vars.values()])
    self._batch_size = _get_batch_size(batch_sizes)
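
# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of objective.py). It assumes the
# public Theseus API for variables and cost functions; the names "x",
# "x_target", and error_fn are hypothetical. It demonstrates the broadcasting
# rule enforced by _get_batch_size: a batch size of 1 broadcasts against one
# larger common size, and any other mismatch raises ValueError.
# ---------------------------------------------------------------------------
import torch
import theseus as th

x = th.Vector(2, name="x")                # optimization variable
x_target = th.Vector(2, name="x_target")  # auxiliary variable

def error_fn(optim_vars, aux_vars):
    # Residual of shape (batch_size, 2); ".data" matches this version's API.
    return optim_vars[0].data - aux_vars[0].data

objective = th.Objective()
objective.add(
    th.AutoDiffCostFunction([x], error_fn, 2, aux_vars=[x_target], name="diff")
)

# Batch sizes {4, 1} are broadcastable: the resolved batch size is 4.
objective.update({"x": torch.zeros(4, 2), "x_target": torch.ones(1, 2)})

# Batch sizes {4, 3} are not broadcastable, so this would raise ValueError:
# objective.update({"x": torch.zeros(4, 2), "x_target": torch.ones(3, 2)})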