def update()

in theseus/core/objective.py [0:0]


    def update(self, input_data: Optional[Dict[str, torch.Tensor]] = None) -> None:
        """Write new tensors into the objective's variables and refresh batch size.

        Each entry of ``input_data`` is routed by name to the matching variable in
        ``self.optim_vars``, ``self.aux_vars`` or ``self.cost_weight_optim_vars``
        (checked in that order) and applied through that variable's ``update()``.
        Updating a variable found only in ``self.cost_weight_optim_vars`` emits a
        warning, since it is not tied to any cost function. Names matching none of
        the three mappings emit a warning and are otherwise ignored.

        After the updates, the batch size is recomputed from the leading dimension
        of ``.data`` of every optimization and auxiliary variable (including those
        not mentioned in ``input_data``) and stored in ``self._batch_size``.

        Args:
            input_data: optional mapping from variable name to a new tensor. Every
                tensor must have at least 2 dimensions (a batch dimension plus one
                or more data dimensions). ``None`` or an empty mapping only
                re-runs the batch-size check.

        Raises:
            ValueError: if any tensor in ``input_data`` has fewer than 2
                dimensions. This is checked for all tensors before any variable
                is modified, so on this error the objective is left untouched.
            ValueError: if the resulting batch sizes are not broadcastable, i.e.
                they are neither all equal nor exactly two distinct sizes where the
                smaller one is 1. This is also raised when there are no
                optimization or auxiliary variables at all (empty set of sizes).
        """

        def _get_batch_size(batch_sizes: Sequence[int]) -> int:
            # Broadcasting rule: either every variable shares one batch size, or
            # there are exactly two distinct sizes and the smaller one is 1.
            unique_batch_sizes = set(batch_sizes)
            if len(unique_batch_sizes) == 1:
                return batch_sizes[0]
            if len(unique_batch_sizes) == 2:
                min_bs = min(unique_batch_sizes)
                max_bs = max(unique_batch_sizes)
                if min_bs == 1:
                    return max_bs
            raise ValueError("Provided data tensors must be broadcastable.")

        input_data = input_data or {}

        # Validate every tensor up front so that a bad entry does not leave the
        # objective partially updated (earlier variables already overwritten).
        for var_name, data in input_data.items():
            if data.ndim < 2:
                raise ValueError(
                    f"Input data tensors must have a batch dimension and "
                    f"one or more data dimensions, but data.ndim={data.ndim} for "
                    f"tensor with name {var_name}."
                )

        self._batch_size = None
        for var_name, data in input_data.items():
            if var_name in self.optim_vars:
                self.optim_vars[var_name].update(data)
            elif var_name in self.aux_vars:
                self.aux_vars[var_name].update(data)
            elif var_name in self.cost_weight_optim_vars:
                self.cost_weight_optim_vars[var_name].update(data)
                warnings.warn(
                    "Updated a variable declared as optimization, but it is "
                    "only associated to cost weights and not to any cost functions. "
                    "Theseus optimizers will only update optimization variables "
                    "that are associated to one or more cost functions."
                )
            else:
                warnings.warn(
                    f"Attempted to update a tensor with name {var_name}, "
                    "which is not associated to any variable in the objective."
                )

        # Check that the batch size of all data is consistent after update.
        # Cost-weight-only variables are intentionally not part of this check.
        batch_sizes = [v.data.shape[0] for v in self.optim_vars.values()]
        batch_sizes.extend([v.data.shape[0] for v in self.aux_vars.values()])
        self._batch_size = _get_batch_size(batch_sizes)