def device_function()

in tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_patch.py


def device_function(self, op):
  """Choose a device for `op`.

  Args:
    op: an `Operation`.

  Returns:
    The device to use for the `Operation`.
  """
  # If we don't return early here, either merge_devices is True, or op.device
  # is empty (in which case merging is a no-op). So we can always merge below.
  if not self._merge_devices and op.device:
    return op.device

  current_device = pydev.DeviceSpec.from_string(op.device or "")

  # The ps_device will be used for specified ops (ps_ops) whenever it is
  # present and ps_tasks is non-zero. However, its task number will only be
  # set (using ps_strategy) if there is a job field in ps_device that won't be
  # changed by the job field (if present) in current_device.
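  # A device function may receive either an `Operation` or a bare `NodeDef`;
  # normalize so the node's name and op type can be inspected below.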
  node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def

  # TODO(rhdong): `TrainableWrapper` is not thread-safe, so we try to keep it
  #  off the PS. Matching on the node name below is a crude but workable guard
  #  against placing `TrainableWrapper` handles on PS devices, since node_def
  #  carries too little information to tell whether a node was created by
  #  `TrainableWrapper`.
  if ("TrainableWrapper" not in node_def.name and self._ps_tasks
      and self._ps_device and node_def.op in self._ps_ops):
    ps_device = pydev.DeviceSpec.from_string(self._ps_device)

    current_job, ps_job = current_device.job, ps_device.job
    if ps_job and (not current_job or current_job == ps_job):
      ps_device = ps_device.replace(task=self._ps_strategy(op))

    ps_device = ps_device.make_merged_spec(current_device)
    return ps_device.to_string()
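
  # Everything else (including `TrainableWrapper` handles) falls through to
  # the worker device, merged with any device already set on the op.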
  worker_device = pydev.DeviceSpec.from_string(self._worker_device or "")
  worker_device = worker_device.make_merged_spec(current_device)
  return worker_device.to_string()
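
Usage sketch: this method mirrors the device chooser behind
`tf.compat.v1.train.replica_device_setter`, with the extra `TrainableWrapper`
name check. The snippet below is a minimal illustration, not code from the
repository: it assumes the patch is installed as a side effect of importing
`tensorflow_recommenders_addons`, and the cluster addresses are made up.

import tensorflow as tf
import tensorflow_recommenders_addons as tfra  # assumed to install the patch

cluster = tf.train.ClusterSpec({
    "ps": ["ps0:2222", "ps1:2222"],
    "worker": ["worker0:2222"],
})

# replica_device_setter returns a device function whose core logic is the
# `device_function` shown above.
setter = tf.compat.v1.train.replica_device_setter(
    cluster=cluster, worker_device="/job:worker/task:0")

with tf.Graph().as_default(), tf.device(setter):
  v = tf.compat.v1.get_variable("v", shape=[10])  # in ps_ops -> a PS task
  y = v * 2.0  # not in ps_ops -> worker device

print(v.device)  # e.g. /job:ps/task:0 (task chosen by ps_strategy)
print(y.device)  # /job:worker/task:0

Because the chooser merges device specs rather than overwriting them, an
explicit `with tf.device(...)` inside the block still wins for any field it
sets on an op.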