in tensorflow_privacy/privacy/bolt_on/models.py [0:0]
def calculate_class_weights(self,
                            class_weights=None,
                            class_counts=None,
                            num_classes=None):
  """Calculates class weighting to be used in training.

  Args:
    class_weights: str specifying type (only 'balanced' is accepted), array
      giving weights, or None.
    class_counts: If class_weights is a str, then a 1D array of the number of
      samples for each class. Not consulted otherwise.
    num_classes: If class_weights is not None, then the number of classes.

  Returns:
    The scalar 1 if class_weights is None. Otherwise class_weights as a 1D
    tensor, to be passed to model's fit method: for 'balanced' this is
    sum(class_counts) / (num_classes * class_counts); for an array it is the
    array converted to a tensor.

  Raises:
    ValueError: If class_weights is a str that is not a recognised type; if
      class_counts is missing or not 1D when class_weights is a str; if
      num_classes is None while class_weights is not None; or if an array of
      class_weights is not 1D or its length differs from num_classes.
  """
  # Value checking
  class_keys = ['balanced']
  is_string = isinstance(class_weights, str)
  if is_string:
    if class_weights not in class_keys:
      raise ValueError('Detected string class_weights with '
                       'value: {0}, which is not one of {1}. '
                       'Please select a valid class_weight type '
                       'or pass an array'.format(class_weights, class_keys))
    if class_counts is None:
      raise ValueError('Class counts must be provided if using '
                       'class_weights=%s' % class_weights)
    # Only the static shape is needed here, so a plain tensor conversion (with
    # the same dtype coercion) is used rather than allocating a variable.
    class_counts_shape = tf.convert_to_tensor(
        class_counts, dtype=self._dtype).shape
    if len(class_counts_shape) != 1:
      raise ValueError('class counts must be a 1D array. '
                       'Detected: {0}'.format(class_counts_shape))
    if num_classes is None:
      raise ValueError('num_classes must be provided if using '
                       'class_weights=%s' % class_weights)
  elif class_weights is not None and num_classes is None:
    raise ValueError('You must pass a value for num_classes if '
                     'creating an array of class_weights')

  # Performing class weight calculation
  if class_weights is None:
    # No weighting requested: a neutral scalar weight.
    return 1

  if is_string and class_weights == 'balanced':
    num_samples = sum(class_counts)
    weighted_counts = tf.dtypes.cast(
        tf.math.multiply(num_classes, class_counts), self._dtype)
    # weighted_counts is already a tensor of self._dtype; the numerator is
    # converted to the same dtype so the division is element-wise over classes.
    return tf.convert_to_tensor(num_samples, dtype=self._dtype) / weighted_counts

  # Explicit array of weights: must be 1D with one entry per class.
  class_weights = _ops.convert_to_tensor_v2(class_weights)
  if len(class_weights.shape) != 1:
    raise ValueError('Detected class_weights shape: {0} instead of '
                     '1D array'.format(class_weights.shape))
  if class_weights.shape[0] != num_classes:
    raise ValueError('Detected array length: {0} instead of: {1}'.format(
        class_weights.shape[0], num_classes))
  return class_weights