def get_rtr_rewriter()

in python/hlo/optimize.py [0:0]


def get_rtr_rewriter(inst, id_to_inst, output_ids):
    """Return an ``RtrRewriter`` for ``inst`` if it qualifies for the rtr shuffle.

    The rewrite folds the convolution stride into the feature dimension: every
    input spatial dimension is divided by its stride and the input feature
    dimension is multiplied by it, while the constant kernel is zero-padded
    and passed through ``_rtr_transform`` with the same strides.

    Args:
        inst: Instruction to inspect. Reads ``opcode``, ``operand_ids``,
            ``window.dimensions``, ``convolution_dimension_numbers`` and
            ``feature_group_count``.
        id_to_inst: Mapping from instruction id to instruction, used to look
            up the two operands of ``inst``.
        output_ids: Container of instruction ids; the rewrite is skipped when
            either operand id is contained in it.

    Returns:
        An ``RtrRewriter`` built from the instruction, its kernel, the
        reshaped input shape, the flat shuffle indices and the transformed
        kernel array, or ``None`` when any precondition is not met.
    """
    if inst.opcode != 'convolution':
        return None
    input_id, kernel_id = inst.operand_ids
    if input_id in output_ids or kernel_id in output_ids:
        return None
    input_inst = id_to_inst[input_id]
    kernel_inst = id_to_inst[kernel_id]
    input_shape = list(input_inst.shape.dimensions)
    window_dims = inst.window.dimensions
    window_sizes = [dim.size for dim in window_dims]
    strides = [dim.stride for dim in window_dims]
    padding_lows = [dim.padding_low for dim in window_dims]
    padding_highs = [dim.padding_high for dim in window_dims]
    base_dilations = [dim.base_dilation for dim in window_dims]
    dim_nums = inst.convolution_dimension_numbers
    spatial_dimensions = dim_nums.input_spatial_dimensions
    feature_dimension = dim_nums.input_feature_dimension
    kernel_spatial_dimensions = dim_nums.kernel_spatial_dimensions
    kernel_input_feature_dimension = dim_nums.kernel_input_feature_dimension
    base_sizes = [input_shape[dim] for dim in spatial_dimensions]

    # necessary conditions for enabling rtr shuffle
    # strides: identical across all window dimensions, all > 1, and dividing
    # the corresponding input spatial size exactly
    symm_strides = len(set(strides)) == 1
    all_strided = all(strd > 1 for strd in strides)
    spatial_divisible = all(bsz % strd == 0 for bsz, strd in zip(base_sizes, strides))
    stride_condition = symm_strides and all_strided and spatial_divisible
    # input: an unpadded parameter with fewer than 32 features
    input_is_param = input_inst.opcode == 'parameter'
    no_padding = all(pad == 0 for pad in padding_lows + padding_highs)
    few_channels = input_shape[feature_dimension] < 32
    input_condition = input_is_param and no_padding and few_channels
    # kernel: a constant whose window is larger than 1 in every dimension
    window_condition = all(ws > 1 for ws in window_sizes) and kernel_inst.opcode == 'constant'
    group_condition = inst.feature_group_count == 1
    if not (stride_condition and input_condition and window_condition and group_condition):
        return None

    # pad and shuffle kernel
    kernel_array = HloOp(kernel_inst).literal_value
    kernel_spatial_paddings = [[0, 0] for _ in kernel_array.shape]
    for dim, stride in zip(kernel_spatial_dimensions, strides):
        # Zero-pad the high side up to the next multiple of the stride.
        # ``(-size) % stride`` is the distance to that multiple (0 when the
        # size is already aligned); ``size % stride`` only coincides with it
        # for stride == 2 and leaves the kernel misaligned for larger strides.
        # NOTE(review): presumably _rtr_transform splits each spatial
        # dimension by its stride, as done for the input shape below - confirm.
        kernel_spatial_paddings[dim][1] = (-kernel_array.shape[dim]) % stride
    kernel_array_p = np.pad(kernel_array, kernel_spatial_paddings)
    kernel_array_prtr = _rtr_transform(kernel_array_p, kernel_spatial_dimensions,
                                       kernel_input_feature_dimension, strides)
    prtr_window_sizes = [kernel_array_prtr.shape[dim] for dim in kernel_spatial_dimensions]

    # don't enable rtr shuffle if explicit dilation is still required after padding kernel
    if not all(ws >= dil for ws, dil in zip(prtr_window_sizes, base_dilations)):
        return None

    # shuffle indices and new input shape
    # Transforming an array of flat element indices with the same routine
    # yields, once flattened, the index order handed to the rewriter.
    num_elements = np.prod([int(size) for size in input_shape])
    indices = np.arange(num_elements).reshape(input_shape)
    indices_rtr = _rtr_transform(indices, spatial_dimensions, feature_dimension, strides)
    shuffle_indices = indices_rtr.reshape([-1])
    input_r_shape = [int(size) for size in input_shape]
    for dim, stride in zip(spatial_dimensions, strides):
        # each spatial dimension shrinks by its stride and the feature
        # dimension grows by the same factor, so the element count is kept
        input_r_shape[dim] //= stride
        input_r_shape[feature_dimension] *= stride

    # return rewriter
    return RtrRewriter(inst, kernel_inst, input_r_shape, shuffle_indices, kernel_array_prtr)