07_training/serverlessml/flowers/utils/augment.py (19 lines of code) (raw):

#!/usr/bin/env python

# Copyright 2020 Google Inc. Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
# OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import numpy as np  # kept: existing file-level import, no longer used by the layer itself
import tensorflow as tf


class RandomColorDistortion(tf.keras.layers.Layer):
    """Keras layer that randomly perturbs contrast and brightness while training.

    When ``training`` is falsy (including the default ``None``) the input is
    returned unchanged. Otherwise one contrast factor and one brightness delta
    are sampled per call, applied to the whole ``images`` tensor, and the
    result is clipped to ``[0, 1]``.

    Args:
        contrast_range: ``(low, high)`` bounds of the uniform distribution the
            contrast factor is drawn from.
        brightness_delta: ``(low, high)`` bounds of the uniform distribution
            the brightness offset is drawn from.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, contrast_range=(0.5, 1.5), brightness_delta=(-0.2, 0.2),
                 **kwargs):
        super().__init__(**kwargs)
        # Copy into fresh lists so instances never alias the caller's sequence
        # (or a shared default) while the attributes remain lists as before.
        self.contrast_range = list(contrast_range)
        self.brightness_delta = list(brightness_delta)

    def call(self, images, training=None):
        """Return ``images`` with random contrast/brightness applied if training.

        Args:
            images: Image tensor; the output is clipped to ``[0, 1]``, so
                values are presumably expected to be floats in that range.
            training: Truthy to apply the distortion; falsy or ``None`` to
                pass ``images`` through untouched.

        Returns:
            The distorted (and clipped) tensor, or ``images`` itself when not
            training.
        """
        if not training:
            return images

        # Sample with TensorFlow ops rather than NumPy: a NumPy draw executes
        # once while the function is being traced into a graph and is then
        # frozen as a constant, so every batch would get the same distortion.
        contrast = tf.random.uniform(
            [], self.contrast_range[0], self.contrast_range[1])
        brightness = tf.random.uniform(
            [], self.brightness_delta[0], self.brightness_delta[1])

        images = tf.image.adjust_contrast(images, contrast)
        images = tf.image.adjust_brightness(images, brightness)
        # The adjustments can push values outside the valid range.
        images = tf.clip_by_value(images, 0, 1)
        return images

    def get_config(self):
        """Return the layer config, including the constructor arguments.

        Needed so the layer can be re-created via ``from_config`` when a model
        containing it is saved and loaded.
        """
        config = super().get_config()
        config.update({
            "contrast_range": list(self.contrast_range),
            "brightness_delta": list(self.brightness_delta),
        })
        return config