def read_dataset_video_and_frame()

in courses/machine_learning/asl/open_project/ASL_youtube8m_models/video_and_frame_using_datasets/trainer/model.py
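Builds an Estimator input function that reads YouTube-8M TFRecords at both the video level (mean RGB and audio features) and the frame level (per-frame RGB and audio features), zips the two datasets together example-by-example, and returns batched features plus multi-hot labels.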


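# Assumed module-level context from model.py (not shown in this excerpt):
# import tensorflow as tf
# NUM_CLASSES and MAX_FRAMES are constants defined elsewhere in the file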
def read_dataset_video_and_frame(file_pattern, mode, batch_size):
    def _input_fn():
        print("\nread_dataset_video_and_frame: _input_fn: file_pattern = {}".format(file_pattern))
        print("read_dataset_video_and_frame: _input_fn: mode = {}".format(mode))
        print("read_dataset_video_and_frame: _input_fn: batch_size = {}".format(batch_size))

        # This function dequantizes our uint8 feature tensors back to approximate floating point values
        def dequantize(feat_vector, max_quantized_value = 2, min_quantized_value = -2):
            assert max_quantized_value > min_quantized_value # ensure the max value is larger than the min value
            quantized_range = max_quantized_value - min_quantized_value # find the range between max and min
            scalar = quantized_range / 255.0 # scale factor that maps the quantized byte range [0, 255] onto quantized_range
            bias = (quantized_range / 512.0) + min_quantized_value # bias shifts values to min_quantized_value plus half a quantization step, centering each bucket
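            # For example, with the defaults (max = 2, min = -2): scalar = 4 / 255 and bias = -1.9921875,
            # so a quantized byte of 0 maps to ~-1.992 and 255 maps to ~2.008, recovering roughly [-2, 2]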
            return feat_vector * scalar + bias # return the scaled and shifted feature vector

        # This function resizes the frame axis of a tensor, truncating or zero-padding so we get exactly new_size frames
        def resize_axis(tensor, axis, new_size, fill_value = 0):
            tensor = tf.convert_to_tensor(value = tensor) # ensure tensor is a tensor
            shape = tf.unstack(value = tf.shape(input = tensor)) # unstack the dynamic shape into a list of scalar tensors, one per dimension

            pad_shape = shape[:] # create a copy of the shape list of 1-D tensors
            pad_shape[axis] = tf.maximum(x = 0, y = new_size - shape[axis]) # along axis, the pad tensor needs max(0, new_size - current size) rows

            shape[axis] = tf.minimum(x = shape[axis], y = new_size) # along axis, keep min(current size, new_size) rows of the original tensor
            shape = tf.stack(values = shape) # stack the list of sizes back into a single shape tensor

            resized = tf.concat(values = [
                tf.slice(input_ = tensor, begin = tf.zeros_like(tensor = shape), size = shape), # slice the tensor starting at the 0th index in each dimension and going as far as our adjusted shape in each dimension
                tf.fill(dims = tf.stack(values = pad_shape), value = tf.cast(x = fill_value, dtype = tensor.dtype)) # fill the rest of the tensor with the fill value
            ], axis = axis) # concatenate our sliced tensor with our fill value tensor together

            new_shape = tensor.get_shape().as_list() # get the static shape of the tensor and output it to a list
            new_shape[axis] = new_size # change the static shape's axis to our new size
            resized.set_shape(shape = new_shape) # set the static shape of our resized tensor to our new shape
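            # For example, with new_size = 300 and axis = 0, a (120, 1024) tensor is zero-padded to
            # (300, 1024), while a (450, 1024) tensor is truncated to (300, 1024)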
            return resized # return the resized tensor

        # This function will decode video examples from the video-level TFRecords
        def video_decode_example(serialized_examples):
            # Create feature map
            feature_map = {
                'video_id': tf.FixedLenFeature(shape = [], dtype = tf.string),
                'labels': tf.VarLenFeature(dtype = tf.int64),
                'mean_rgb': tf.FixedLenFeature(shape = [1024], dtype = tf.float32),
                'mean_audio': tf.FixedLenFeature(shape = [128], dtype = tf.float32)
            }

            # Parse the serialized example into our features
            features = tf.parse_single_example(serialized = serialized_examples, features = feature_map)
            print("\nread_dataset_video_and_frame: _input_fn: video_decode_example: features = {}".format(features)) # shape = video_id = (), mean_rgb = (1024,), mean_audio = (128,), labels = SparseTensor object

            # Extract and format labels
            sparse_labels = features.pop("labels") # SparseTensor object
            print("read_dataset_video_and_frame: _input_fn: video_decode_example: sparse_labels = {}".format(sparse_labels))
            labels = tf.cast(x = tf.sparse_to_dense(sparse_indices = sparse_labels.values, output_shape = (NUM_CLASSES,), sparse_values = 1, validate_indices = False), dtype = tf.float32)
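            # For example, sparse label values [3, 7] become a length-NUM_CLASSES multi-hot vector
            # with 1.0 at indices 3 and 7 and 0.0 everywhere else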
            print("read_dataset_video_and_frame: _input_fn: video_decode_example: labels = {}\n".format(labels)) # shape = (NUM_CLASSES,)

            return features, labels

        # This function will decode frame examples from the frame-level TFRecords
        def frame_decode_example(serialized_examples):
            # Create context and sequence feature map
            context_features = {
                "video_id": tf.FixedLenFeature(shape = [], dtype = tf.string),
                "labels": tf.VarLenFeature(dtype = tf.int64)
            }
            sequence_features = {
                "rgb": tf.FixedLenSequenceFeature(shape = [], dtype = tf.string),
                "audio": tf.FixedLenSequenceFeature(shape = [], dtype = tf.string)
            }

            # Parse the serialized sequence example into context and sequence features
            contexts, features = tf.parse_single_sequence_example(
                serialized = serialized_examples, 
                context_features = context_features,
                sequence_features = sequence_features)
            print("\nread_dataset_video_and_frame: _input_fn: frame_decode_example: contexts = {}".format(contexts)) # shape = video_id = (), labels = SparseTensor object
            print("read_dataset_video_and_frame: _input_fn: frame_decode_example: features = {}".format(features)) # shape = rgb = (frames_per_video,), audio = (frames_per_video,)

            # Create features
            # Pass video_id to features
            features['video_id'] = contexts['video_id'] # shape = video_id = (), rgb = (frames_per_video,), audio = (frames_per_video,)
            print("read_dataset_video_and_frame: _input_fn: frame_decode_example: features = {}".format(features))

            # Decode, dequantize, and resize the rgb data
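            # Each element of features['rgb'] is a byte string holding one quantized 1024-dimensional
            # frame feature, so decode_raw yields one uint8 per byte and the reshape restores the
            # (frames_per_video, 1024) layout (the audio path below is analogous, with 128 dimensions)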
            decoded_rgb = tf.reshape(tensor = tf.cast(x = tf.decode_raw(bytes = features["rgb"], out_type = tf.uint8), dtype = tf.float32), shape = [-1, 1024]) # shape = (frames_per_video, 1024)
            print("read_dataset_video_and_frame: _input_fn: frame_decode_example: decoded_rgb = {}".format(decoded_rgb))
            rgb_matrix = resize_axis(tensor = dequantize(decoded_rgb), axis = 0, new_size = MAX_FRAMES) # shape = (MAX_FRAMES, 1024)
            print("read_dataset_video_and_frame: _input_fn: frame_decode_example: rgb_matrix = {}".format(rgb_matrix))
            features['rgb'] = rgb_matrix
            print("read_dataset_video_and_frame: _input_fn: frame_decode_example: features = {}".format(features)) # shape = video_id = (), rgb = (MAX_FRAMES, 1024), audio = (frames_per_video,)

            # Decode, dequantize, and resize the audio data
            decoded_audio = tf.reshape(tensor = tf.cast(x = tf.decode_raw(bytes = features["audio"], out_type = tf.uint8), dtype = tf.float32), shape = [-1, 128]) # shape = (frames_per_video, 128)
            print("read_dataset_video_and_frame: _input_fn: frame_decode_example: decoded_audio = {}".format(decoded_audio))
            audio_matrix = resize_axis(tensor = dequantize(decoded_audio), axis = 0, new_size = MAX_FRAMES) # shape = (MAX_FRAMES, 128)
            print("read_dataset_video_and_frame: _input_fn: frame_decode_example: audio_matrix = {}".format(audio_matrix))
            features['audio'] = audio_matrix
            print("read_dataset_video_and_frame: _input_fn: frame_decode_example: features = {}".format(features)) # shape = video_id = (), rgb = (MAX_FRAMES, 1024), audio = (MAX_FRAMES, 128)

            # Extract and format labels
            labels = tf.cast(x = tf.sparse_to_dense(sparse_indices = contexts['labels'].values, output_shape = (NUM_CLASSES,), sparse_values = 1, validate_indices = False), dtype = tf.float32)
            print("read_dataset_video_and_frame: _input_fn: frame_decode_example: labels = {}\n".format(labels)) # shape = (NUM_CLASSES,)

            return features, labels

        # Create list of files from file pattern
        if mode == tf.estimator.ModeKeys.TRAIN:
            video_file_list = tf.gfile.Glob(filename = file_pattern + "/video_level/train/train*.tfrecord")
            frame_file_list = tf.gfile.Glob(filename = file_pattern + "/frame_level/train/train*.tfrecord")
        else:
            video_file_list = tf.gfile.Glob(filename = file_pattern + "/video_level/validate/validate-0.tfrecord")
            frame_file_list = tf.gfile.Glob(filename = file_pattern + "/frame_level/validate/validate-0.tfrecord")
        #print("read_dataset_video_and_frame: _input_fn: video_file_list = {}".format(video_file_list))
        #print("read_dataset_video_and_frame: _input_fn: frame_file_list = {}".format(frame_file_list))

        # Create dataset from file list
        video_dataset = tf.data.TFRecordDataset(filenames = video_file_list)
        print("read_dataset_video_and_frame: _input_fn: video_dataset.TFRecordDataset = {}".format(video_dataset))
        frame_dataset = tf.data.TFRecordDataset(filenames = frame_file_list)
        print("read_dataset_video_and_frame: _input_fn: frame_dataset.TFRecordDataset = {}".format(frame_dataset))

        # Decode the serialized examples in each dataset
        video_dataset = video_dataset.map(map_func = lambda x: video_decode_example(serialized_examples = x))
        print("read_dataset_video_and_frame: _input_fn: video_dataset.map = {}".format(video_dataset))
        frame_dataset = frame_dataset.map(map_func = lambda x: frame_decode_example(serialized_examples = x))
        print("read_dataset_video_and_frame: _input_fn: frame_dataset.map = {}".format(frame_dataset))

        # Zip together video and frame datasets
        combined_dataset = tf.data.Dataset.zip(datasets = (video_dataset, frame_dataset))
        print("read_dataset_video_and_frame: _input_fn: combined_dataset = {}".format(combined_dataset))

        # Determine how many times to repeat the files and whether to shuffle, based on whether we are training or evaluating
        if mode == tf.estimator.ModeKeys.TRAIN:
            num_epochs = None # read files forever

            # Shuffle the dataset within a buffer
            combined_dataset = combined_dataset.shuffle(buffer_size = batch_size * 10, seed = None)
            print("read_dataset_video_and_frame: _input_fn: combined_dataset.shuffle = {}".format(combined_dataset))
        else:
            num_epochs = 1 # read files only once

        # Repeat files num_epoch times
        combined_dataset = combined_dataset.repeat(count = num_epochs)
        print("read_dataset_video_and_frame: _input_fn: combined_dataset.repeat = {}".format(combined_dataset))

        # Group the data into batches
        combined_dataset = combined_dataset.batch(batch_size = batch_size)
        print("read_dataset_video_and_frame: _input_fn: combined_dataset.batch = {}".format(combined_dataset))

        # Create an iterator and then pull the next batch of features and labels from each dataset
        (video_batch_features, video_batch_labels), (frame_batch_features, frame_batch_labels) = combined_dataset.make_one_shot_iterator().get_next()
        print("read_dataset_video_and_frame: _input_fn: video_batch_features = {}".format(video_batch_features))
        print("read_dataset_video_and_frame: _input_fn: video_batch_labels = {}".format(video_batch_labels))
        print("read_dataset_video_and_frame: _input_fn: frame_batch_features = {}".format(frame_batch_features))
        print("read_dataset_video_and_frame: _input_fn: frame_batch_labels = {}\n".format(frame_batch_labels))

        # Combine features from the two datasets
        batch_features = video_batch_features
        batch_features["rgb"] = frame_batch_features["rgb"]
        batch_features["audio"] = frame_batch_features["audio"]
        print("read_dataset_video_and_frame: _input_fn: batch_features = {}".format(batch_features))

        # Only need one set of labels
        batch_labels = video_batch_labels
        print("read_dataset_video_and_frame: _input_fn: batch_labels = {}".format(batch_labels))

        return batch_features, batch_labels
    return _input_fn
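
A minimal usage sketch, assuming a model_fn and output_dir are defined elsewhere in model.py (train_steps and the batch size of 64 are hypothetical values):

# Build the estimator from the model function defined elsewhere in model.py
estimator = tf.estimator.Estimator(model_fn = model_fn, model_dir = output_dir)

# Wire the returned input functions into train and eval specs
train_spec = tf.estimator.TrainSpec(
    input_fn = read_dataset_video_and_frame(file_pattern, tf.estimator.ModeKeys.TRAIN, batch_size = 64),
    max_steps = train_steps)
eval_spec = tf.estimator.EvalSpec(
    input_fn = read_dataset_video_and_frame(file_pattern, tf.estimator.ModeKeys.EVAL, batch_size = 64),
    steps = None) # evaluate on the full validation file (it is read only once)

tf.estimator.train_and_evaluate(estimator = estimator, train_spec = train_spec, eval_spec = eval_spec)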