def __new__()

in tensorflow_model_optimization/python/core/sparsity/keras/estimator_utils.py [0:0]


  def __new__(cls, model, step=None, train_op=None, **kwargs):
    """Creates an EstimatorSpec whose train_op also updates pruning state.

    Args:
      model: Keras model; its `PruneLowMagnitude` layers get their
        `pruning_step` updated on every run of the returned train_op.
      step: Optional step value. If None, each prunable layer's
        `pruning_step` is incremented by 1; otherwise it is set to `step`
        cast to int64.
      train_op: The original training op. Required.
      **kwargs: Forwarded to the parent class's `__new__`. Must contain
        `mode`. An optional `scaffold` (None is treated as absent) must not
        already define an `init_fn`.

    Returns:
      A new `PruningEstimatorSpec` instance.

    Raises:
      ValueError: If `mode` is missing, `train_op` is None, or the provided
        scaffold already sets an `init_fn`.
    """
    if "mode" not in kwargs:
      raise ValueError("Must provide a mode (TRAIN/EVAL/PREDICT) when "
                       "creating an EstimatorSpec")

    if train_op is None:
      raise ValueError(
          "Must provide train_op for creating a PruningEstimatorSpec")

    # NOTE: Freshly initialized layers have their pruning_step advanced by
    # `init_fn` below, which runs the step-update ops when the scaffold
    # initializes the session. A Python-level check like
    # `layer.pruning_step == -1` must not be used here: in graph mode it does
    # not read the variable's value, and a `tf.assign` op created at this
    # point would never be executed.

    def _get_step_increment_ops(model, step=None):
      """Returns one grouped op updating pruning_step in prunable layers.

      With `step` None each layer's step is incremented by 1; otherwise it is
      overwritten with `step` cast to int64.
      """
      increment_ops = []
      for layer in model.layers:
        if not isinstance(layer, PruneLowMagnitude):
          continue
        if step is None:
          # Add ops to increment the pruning_step by 1
          increment_ops.append(tf.assign_add(layer.pruning_step, 1))
        else:
          increment_ops.append(
              tf.assign(layer.pruning_step, tf.cast(step, tf.int64)))
      return tf.group(increment_ops)

    # The pruning-step updates and the model's own update ops are grouped
    # with the caller's train_op so they run on every training step.
    step_increment_ops = _get_step_increment_ops(model, step)
    pruning_ops = [step_increment_ops, model.updates]
    kwargs["train_op"] = tf.group(pruning_ops, train_op)

    def init_fn(scaffold, session):  # pylint: disable=unused-argument
      # Runs the pruning-step update once when the scaffold initializes.
      return session.run(step_increment_ops)

    # `scaffold=None` is the EstimatorSpec default, so an explicit None is
    # treated exactly like a missing scaffold instead of being dereferenced.
    old_scaffold = kwargs.get("scaffold")
    if old_scaffold is None:
      scaffold = tf.compat.v1.train.Scaffold(init_fn=init_fn)
    elif old_scaffold.init_fn is None:
      scaffold = tf.compat.v1.train.Scaffold(
          init_fn=init_fn, copy_from_scaffold=old_scaffold)
    else:
      # TODO(suyoggupta): Figure out a way to merge the init_fn of the
      # original scaffold with the one defined above.
      raise ValueError("Scaffold provided to PruningEstimatorSpec must not "
                       "set an init_fn.")

    kwargs["scaffold"] = scaffold

    return super(PruningEstimatorSpec, cls).__new__(cls, **kwargs)