lucid/misc/redirected_relu_grad.py [90:145]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
def redirected_relu_grad(op, grad):
  """Gradient function for `Relu` ops with a per-example fallback gradient.

  The ordinary ReLU gradient (incoming gradient zeroed wherever the op's input
  is negative) is used for every batch entry through which it has a non-zero
  norm. For batch entries where it is all zeros, a "redirected" gradient is
  substituted: the incoming gradient with only those elements zeroed whose
  input is negative and whose gradient is positive. Once the global step
  exceeds 16 the ordinary ReLU gradient is returned unconditionally.

  Args:
    op: the operation being differentiated; its type must be "Relu" and its
      first input is the tensor the ReLU was applied to.
    grad: gradient flowing in from the op's output. The first dimension is
      treated as the batch dimension, and a rank of at least 2 is asserted
      at graph execution time.

  Returns:
    The selected gradient tensor with respect to the op's input.
  """
  assert op.type == "Relu"
  pre_activation = op.inputs[0]
  is_negative = pre_activation < 0.
  no_grad = tf.zeros_like(grad)

  # Ordinary ReLU backward pass: nothing flows where the input is negative.
  plain_grad = tf.where(is_negative, no_grad, grad)

  # Redirected variant: only block gradient that would drive an already
  # negative input even lower; everything else passes through unchanged.
  blocked = tf.logical_and(is_negative, grad > 0.)
  fallback_grad = tf.where(blocked, no_grad, grad)

  # A batch dimension is expected, so require rank >= 2 before flattening.
  grad_rank = tf.rank(plain_grad)
  rank_check = tf.Assert(tf.greater(grad_rank, 1), [grad_rank])
  with tf.control_dependencies([rank_check]):
    # Norm of the ordinary gradient for each batch entry.
    batch_size = tf.shape(plain_grad)[0]
    per_example = tf.reshape(plain_grad, [batch_size, -1])
    per_example_norm = tf.norm(per_example, axis=1)
  # Fall back to the redirected gradient only for batch entries where the
  # ordinary gradient vanished completely.
  example_has_grad = per_example_norm > 0.
  mixed_grad = tf.where(example_has_grad, plain_grad, fallback_grad)

  # The redirection is only active while the global step is at most 16.
  step = tf.train.get_or_create_global_step()
  past_redirect_phase = tf.greater(step, tf.constant(16, tf.int64))
  return tf.where(past_redirect_phase, plain_grad, mixed_grad)


def redirected_relu6_grad(op, grad):
  """Gradient function for `Relu6` ops with a per-example fallback gradient.

  The ordinary gradient here zeroes the incoming gradient wherever the op's
  input is below 0 or above 6. It is used for every batch entry through which
  it has a non-zero norm. For batch entries where it is all zeros, a
  "redirected" gradient is substituted: the incoming gradient with only those
  elements zeroed whose input is below 0 with a positive gradient, or above 6
  with a negative gradient. Once the global step exceeds 16 the ordinary
  gradient is returned unconditionally.

  Args:
    op: the operation being differentiated; its type must be "Relu6" and its
      first input is the tensor the ReLU6 was applied to.
    grad: gradient flowing in from the op's output. The first dimension is
      treated as the batch dimension, and a rank of at least 2 is asserted
      at graph execution time.

  Returns:
    The selected gradient tensor with respect to the op's input.
  """
  assert op.type == "Relu6"
  pre_activation = op.inputs[0]
  below_zero = pre_activation < 0.
  above_six = pre_activation > 6.
  no_grad = tf.zeros_like(grad)

  # Ordinary backward pass: nothing flows outside the [0, 6] input range.
  outside_range = tf.logical_or(below_zero, above_six)
  plain_grad = tf.where(outside_range, no_grad, grad)

  # Redirected variant: only block gradient that would drive an already
  # negative input lower, or an input already above 6 higher.
  pushing_lower = tf.logical_and(below_zero, grad > 0.)
  pushing_higher = tf.logical_and(above_six, grad < 0.)
  blocked = tf.logical_or(pushing_lower, pushing_higher)
  fallback_grad = tf.where(blocked, no_grad, grad)

  # A batch dimension is expected, so require rank >= 2 before flattening.
  grad_rank = tf.rank(plain_grad)
  rank_check = tf.Assert(tf.greater(grad_rank, 1), [grad_rank])
  with tf.control_dependencies([rank_check]):
    # Norm of the ordinary gradient for each batch entry.
    batch_size = tf.shape(plain_grad)[0]
    per_example = tf.reshape(plain_grad, [batch_size, -1])
    per_example_norm = tf.norm(per_example, axis=1)
  # Fall back to the redirected gradient only for batch entries where the
  # ordinary gradient vanished completely.
  example_has_grad = per_example_norm > 0.
  mixed_grad = tf.where(example_has_grad, plain_grad, fallback_grad)

  # The redirection is only active while the global step is at most 16.
  step = tf.train.get_or_create_global_step()
  past_redirect_phase = tf.greater(step, tf.constant(16, tf.int64))
  return tf.where(past_redirect_phase, plain_grad, mixed_grad)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



lucid/optvis/overrides/redirected_relu_grad.py [90:145]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
def redirected_relu_grad(op, grad):
  """Gradient function for `Relu` ops with a per-example fallback gradient.

  The ordinary ReLU gradient (incoming gradient zeroed wherever the op's input
  is negative) is used for every batch entry through which it has a non-zero
  norm. For batch entries where it is all zeros, a "redirected" gradient is
  substituted: the incoming gradient with only those elements zeroed whose
  input is negative and whose gradient is positive. Once the global step
  exceeds 16 the ordinary ReLU gradient is returned unconditionally.

  Args:
    op: the operation being differentiated; its type must be "Relu" and its
      first input is the tensor the ReLU was applied to.
    grad: gradient flowing in from the op's output. The first dimension is
      treated as the batch dimension, and a rank of at least 2 is asserted
      at graph execution time.

  Returns:
    The selected gradient tensor with respect to the op's input.
  """
  assert op.type == "Relu"
  pre_activation = op.inputs[0]
  is_negative = pre_activation < 0.
  no_grad = tf.zeros_like(grad)

  # Ordinary ReLU backward pass: nothing flows where the input is negative.
  plain_grad = tf.where(is_negative, no_grad, grad)

  # Redirected variant: only block gradient that would drive an already
  # negative input even lower; everything else passes through unchanged.
  blocked = tf.logical_and(is_negative, grad > 0.)
  fallback_grad = tf.where(blocked, no_grad, grad)

  # A batch dimension is expected, so require rank >= 2 before flattening.
  grad_rank = tf.rank(plain_grad)
  rank_check = tf.Assert(tf.greater(grad_rank, 1), [grad_rank])
  with tf.control_dependencies([rank_check]):
    # Norm of the ordinary gradient for each batch entry.
    batch_size = tf.shape(plain_grad)[0]
    per_example = tf.reshape(plain_grad, [batch_size, -1])
    per_example_norm = tf.norm(per_example, axis=1)
  # Fall back to the redirected gradient only for batch entries where the
  # ordinary gradient vanished completely.
  example_has_grad = per_example_norm > 0.
  mixed_grad = tf.where(example_has_grad, plain_grad, fallback_grad)

  # The redirection is only active while the global step is at most 16.
  step = tf.train.get_or_create_global_step()
  past_redirect_phase = tf.greater(step, tf.constant(16, tf.int64))
  return tf.where(past_redirect_phase, plain_grad, mixed_grad)


def redirected_relu6_grad(op, grad):
  """Gradient function for `Relu6` ops with a per-example fallback gradient.

  The ordinary gradient here zeroes the incoming gradient wherever the op's
  input is below 0 or above 6. It is used for every batch entry through which
  it has a non-zero norm. For batch entries where it is all zeros, a
  "redirected" gradient is substituted: the incoming gradient with only those
  elements zeroed whose input is below 0 with a positive gradient, or above 6
  with a negative gradient. Once the global step exceeds 16 the ordinary
  gradient is returned unconditionally.

  Args:
    op: the operation being differentiated; its type must be "Relu6" and its
      first input is the tensor the ReLU6 was applied to.
    grad: gradient flowing in from the op's output. The first dimension is
      treated as the batch dimension, and a rank of at least 2 is asserted
      at graph execution time.

  Returns:
    The selected gradient tensor with respect to the op's input.
  """
  assert op.type == "Relu6"
  pre_activation = op.inputs[0]
  below_zero = pre_activation < 0.
  above_six = pre_activation > 6.
  no_grad = tf.zeros_like(grad)

  # Ordinary backward pass: nothing flows outside the [0, 6] input range.
  outside_range = tf.logical_or(below_zero, above_six)
  plain_grad = tf.where(outside_range, no_grad, grad)

  # Redirected variant: only block gradient that would drive an already
  # negative input lower, or an input already above 6 higher.
  pushing_lower = tf.logical_and(below_zero, grad > 0.)
  pushing_higher = tf.logical_and(above_six, grad < 0.)
  blocked = tf.logical_or(pushing_lower, pushing_higher)
  fallback_grad = tf.where(blocked, no_grad, grad)

  # A batch dimension is expected, so require rank >= 2 before flattening.
  grad_rank = tf.rank(plain_grad)
  rank_check = tf.Assert(tf.greater(grad_rank, 1), [grad_rank])
  with tf.control_dependencies([rank_check]):
    # Norm of the ordinary gradient for each batch entry.
    batch_size = tf.shape(plain_grad)[0]
    per_example = tf.reshape(plain_grad, [batch_size, -1])
    per_example_norm = tf.norm(per_example, axis=1)
  # Fall back to the redirected gradient only for batch entries where the
  # ordinary gradient vanished completely.
  example_has_grad = per_example_norm > 0.
  mixed_grad = tf.where(example_has_grad, plain_grad, fallback_grad)

  # The redirection is only active while the global step is at most 16.
  step = tf.train.get_or_create_global_step()
  past_redirect_phase = tf.greater(step, tf.constant(16, tf.int64))
  return tf.where(past_redirect_phase, plain_grad, mixed_grad)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



