07_training/serverlessml/flowers/utils/augment.py (19 lines of code) (raw):
#!/usr/bin/env python
# Copyright 2020 Google Inc. Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
# OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.
import numpy as np
import tensorflow as tf
class RandomColorDistortion(tf.keras.layers.Layer):
    """Randomly perturb the contrast and brightness of images during training.

    When ``training`` is falsy the input is returned unchanged. Otherwise one
    contrast factor and one brightness delta are sampled uniformly per call
    (so every image in the batch gets the same perturbation), applied via
    ``tf.image.adjust_contrast`` then ``tf.image.adjust_brightness``, and the
    result is clipped to [0, 1].

    NOTE(review): the final clip to [0, 1] suggests inputs are expected to be
    float images already scaled to [0, 1] — confirm against the caller.

    Args:
      contrast_range: ``(low, high)`` bounds of the uniformly sampled contrast
        factor. Any 2-element sequence is accepted.
      brightness_delta: ``(low, high)`` bounds of the uniformly sampled
        brightness delta added to pixel values. Any 2-element sequence is
        accepted.
      **kwargs: forwarded unchanged to ``tf.keras.layers.Layer``.
    """

    def __init__(self, contrast_range=(0.5, 1.5),
                 brightness_delta=(-0.2, 0.2), **kwargs):
        super().__init__(**kwargs)
        # Tuple defaults avoid Python's shared-mutable-default pitfall; callers
        # passing lists (the previous defaults) keep working unchanged.
        self.contrast_range = contrast_range
        self.brightness_delta = brightness_delta

    def call(self, images, training=None):
        """Apply the random color distortion when ``training`` is truthy.

        Args:
          images: image tensor accepted by ``tf.image.adjust_contrast``.
          training: Keras training flag; falsy means pass-through.

        Returns:
          ``images`` unchanged at inference, otherwise the distorted images
          clipped to [0, 1].
        """
        if not training:
            return images
        # Sample with TF ops rather than NumPy: when Keras traces `call` into a
        # graph (tf.function / model.fit), a np.random value would be evaluated
        # once at trace time and baked in as a constant, giving every batch the
        # same "random" distortion. tf.random.uniform is re-sampled per step.
        contrast = tf.random.uniform(
            [], self.contrast_range[0], self.contrast_range[1])
        brightness = tf.random.uniform(
            [], self.brightness_delta[0], self.brightness_delta[1])
        images = tf.image.adjust_contrast(images, contrast)
        images = tf.image.adjust_brightness(images, brightness)
        return tf.clip_by_value(images, 0, 1)

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized.

        Without this, saving a model that contains the layer and re-creating
        it would silently fall back to the default ranges.
        """
        config = super().get_config()
        config.update({
            'contrast_range': list(self.contrast_range),
            'brightness_delta': list(self.brightness_delta),
        })
        return config