def check_shape_compatibility()

in horovod/spark/common/util.py


import numpy as np


def check_shape_compatibility(metadata, feature_columns, label_columns,
                              input_shapes=None, output_shapes=None, label_shapes=None):
    # Check for model and input shape incompatibility. Each column must have the
    # same size (total number of elements) as the corresponding model input.
    feature_count = len(feature_columns)
    if input_shapes is not None:
        if feature_count != len(input_shapes):
            raise ValueError('Feature column count {features} must equal '
                             'model inputs count {inputs}'
                             .format(features=feature_count, inputs=len(input_shapes)))

        for idx, (col, input_shape) in enumerate(zip(feature_columns, input_shapes)):
            # The column metadata stores the flattened element count under 'shape'.
            col_size = metadata[col]['shape']
            if col_size is None:
                # When training directly on Parquet, we do not compute shape metadata
                continue

            # abs() keeps the product positive when the shape has a -1 (unknown) dimension.
            input_size = abs(np.prod(input_shape))
            if col_size != input_size:
                raise ValueError(
                    'Feature column \'{col}\' with size {feature} must equal that of the '
                    'model input at index {idx} with size {input}'
                    .format(col=col, feature=col_size, idx=idx, input=input_size))

    label_count = len(label_columns)
    if label_shapes is not None and label_count != len(label_shapes):
        raise ValueError('Label column count {labels} must equal '
                         'provided label shapes count {outputs}'
                         .format(labels=label_count, outputs=len(label_shapes)))

    if output_shapes is not None and label_count != len(output_shapes):
        raise ValueError('Label column count {labels} must equal '
                         'model outputs count {outputs}'
                         .format(labels=label_count, outputs=len(output_shapes)))

    def _check_label_cols_size(target_shapes, target_name):
        for idx, (col, target_shape) in enumerate(zip(label_columns, target_shapes)):
            col_size = metadata[col]['shape']
            if col_size is None:
                # When training directly on Parquet, we do not compute shape metadata
                continue

            target_size = abs(np.prod(target_shape))
            if col_size != target_size:
                raise ValueError('Label column \'{col}\' with size {label} must equal that of the '
                                 '{target_name} shape at index {idx} with size {output}'
                                 .format(col=col, label=col_size, idx=idx, output=target_size,
                                         target_name=target_name))

    if label_shapes is not None:
        _check_label_cols_size(label_shapes, 'label')
    elif output_shapes is not None:
        # Check the label size against the model output shapes only if label_shapes is not provided.
        _check_label_cols_size(output_shapes, 'model output')
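
A minimal usage sketch, assuming the function above is in scope. The metadata
dict, column names, and shapes below are hypothetical; they only mimic the
structure the function reads, with metadata[col]['shape'] holding a column's
flattened element count:

# Hypothetical metadata: 'shape' holds each column's flattened element count.
metadata = {
    'features': {'shape': 8},
    'label': {'shape': 1},
}

# Passes: [2, 4] flattens to 8 elements, matching the 'features' column, and
# the explicit label shape [1] matches the 'label' column.
check_shape_compatibility(metadata, ['features'], ['label'],
                          input_shapes=[[2, 4]], output_shapes=[[1]],
                          label_shapes=[[1]])

# Raises ValueError: [3, 3] implies 9 elements, but the column holds 8.
check_shape_compatibility(metadata, ['features'], ['label'],
                          input_shapes=[[3, 3]])

When label_shapes is provided it takes precedence, and output_shapes is not
checked against the label column sizes.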