in tensorflow_recommenders/experimental/optimizers/composite_optimizer.py [0:0]
def apply_gradients(self, grads_and_vars: Sequence[Tuple[Tensor, Tensor]],
                    name: Optional[str] = None,
                    experimental_aggregate_gradients: bool = True) -> None:
    """Dispatches each (gradient, variable) pair to its owning optimizer.

    Ownership is rebuilt on every call by invoking the variable callables
    stored in `self._optimizers_and_vars`. The pairs are then grouped per
    optimizer and each optimizer is asked to apply its own group.

    Args:
      grads_and_vars: Iterable of (gradient, variable) pairs.
      name: Forwarded unchanged to every delegate optimizer.
      experimental_aggregate_gradients: Forwarded unchanged to every delegate
        optimizer.

    Raises:
      ValueError: If a variable is claimed by more than one optimizer, or if a
        variable in `grads_and_vars` is not claimed by any optimizer.
    """
    # Step 1: map every variable reference to the one optimizer that owns it.
    owner_by_ref = {}
    for candidate, list_variables in self._optimizers_and_vars:
        for variable in list_variables():
            key = variable.ref()
            if key in owner_by_ref:
                raise ValueError(
                    f"The set of variables handled by each optimizer should be "
                    f"disjoint, but variable {variable} is handled both "
                    f"by {owner_by_ref[key]} and {candidate}.")
            owner_by_ref[key] = candidate

    # Step 2: bucket the incoming pairs by owner. All validation happens here,
    # before any optimizer is touched, so a failure applies no updates.
    pairs_by_optimizer = {}
    for gradient, variable in grads_and_vars:
        key = variable.ref()
        if key not in owner_by_ref:
            raise ValueError(
                f"Variable {variable} is not handled by any optimizer. "
                f"This would cause it to be not trained.")
        bucket = pairs_by_optimizer.setdefault(owner_by_ref[key], [])
        bucket.append((gradient, variable))

    # Step 3: let each optimizer apply only the pairs it is responsible for.
    for delegate, delegate_pairs in pairs_by_optimizer.items():
        delegate.apply_gradients(
            delegate_pairs,
            name=name,
            experimental_aggregate_gradients=experimental_aggregate_gradients)