def calculate_class_weights()

in tensorflow_privacy/privacy/bolt_on/models.py [0:0]


  def calculate_class_weights(self,
                              class_weights=None,
                              class_counts=None,
                              num_classes=None):
    """Calculates class weighting to be used in training.

    Args:
      class_weights: One of:
        * None: no weighting is applied and the scalar ``1`` is returned.
        * str: the name of a weighting scheme. Only ``'balanced'`` is
          supported; it weights class ``i`` by
          ``sum(class_counts) / (num_classes * class_counts[i])``.
        * array-like: explicit per-class weights; must be 1D with exactly
          ``num_classes`` entries.
      class_counts: Array of the number of samples in each class. Required,
        and must be 1D, when class_weights is a string.
      num_classes: Number of classes. Required whenever class_weights is not
        None.

    Returns:
      The scalar ``1`` when class_weights is None; otherwise a 1D tensor of
      per-class weights, to be passed to the model's fit method.

    Raises:
      ValueError: If class_weights is an unsupported string, if a required
        class_counts / num_classes argument is missing, or if class_counts or
        class_weights does not have the expected shape.
    """
    class_keys = ['balanced']

    if class_weights is None:
      # No weighting requested.
      return 1

    if isinstance(class_weights, str):
      if class_weights not in class_keys:
        raise ValueError('Detected string class_weights with '
                         'value: {0}, which is not one of {1}. '
                         'Please select a valid class_weight type '
                         'or pass an array'.format(class_weights, class_keys))
      if class_counts is None:
        raise ValueError('Class counts must be provided if using '
                         'class_weights=%s' % class_weights)
      # Convert and cast once. A tf.Variable is unnecessary here: nothing is
      # trained, and creating one would allocate variable state per call.
      # NOTE(review): self._dtype is assumed to be a TF dtype (it was passed
      # as `dtype=` to tf.Variable in the original code).
      counts = tf.cast(_ops.convert_to_tensor_v2(class_counts), self._dtype)
      if len(counts.shape) != 1:
        raise ValueError('class counts must be a 1D array. '
                         'Detected: {0}'.format(counts.shape))
      if num_classes is None:
        raise ValueError('num_classes must be provided if using '
                         'class_weights=%s' % class_weights)
      # 'balanced': n_samples / (n_classes * n_samples_in_class). Both
      # factors are cast to self._dtype before multiplying so the op never
      # mixes an integer num_classes with floating-point counts.
      num_samples = tf.reduce_sum(counts)
      return num_samples / (tf.cast(num_classes, self._dtype) * counts)

    # Explicit per-class weights supplied as an array.
    if num_classes is None:
      raise ValueError('You must pass a value for num_classes if '
                       'creating an array of class_weights')
    class_weights = _ops.convert_to_tensor_v2(class_weights)
    if len(class_weights.shape) != 1:
      raise ValueError('Detected class_weights shape: {0} instead of '
                       '1D array'.format(class_weights.shape))
    if class_weights.shape[0] != num_classes:
      raise ValueError('Detected array length: {0} instead of: {1}'.format(
          class_weights.shape[0], num_classes))
    return class_weights