in python/hlo/optimize.py [0:0]
def get_rtr_rewriter(inst, id_to_inst, output_ids):
    """Return an ``RtrRewriter`` for ``inst`` if it qualifies, else ``None``.

    The rewrite is attempted only when all of the following hold:

    * ``inst.opcode`` is ``'convolution'`` and neither operand id is in
      ``output_ids``;
    * all window strides are equal, greater than 1, and evenly divide the
      matching input spatial sizes;
    * the input operand is a ``'parameter'`` with fewer than 32 entries along
      the input feature dimension, and the window has no low/high padding;
    * every window size is greater than 1 and the kernel operand is a
      ``'constant'``;
    * ``inst.feature_group_count`` is 1;
    * after padding and transforming the kernel, every spatial window size
      is at least the matching base dilation.

    Args:
        inst: Instruction to inspect. The attributes read here are
            ``opcode``, ``operand_ids``, ``window.dimensions``,
            ``convolution_dimension_numbers`` and ``feature_group_count``.
        id_to_inst: Mapping from operand id to the instruction producing it.
        output_ids: Container of ids; only membership is tested.

    Returns:
        ``RtrRewriter(inst, kernel_inst, input_r_shape, shuffle_indices,
        kernel_array_prtr)`` when the rewrite applies, otherwise ``None``.
    """
    if inst.opcode != 'convolution':
        return None
    input_id, kernel_id = inst.operand_ids
    if input_id in output_ids or kernel_id in output_ids:
        return None
    input_inst = id_to_inst[input_id]
    kernel_inst = id_to_inst[kernel_id]
    input_shape = list(input_inst.shape.dimensions)

    window_dims = inst.window.dimensions
    window_sizes = [dim.size for dim in window_dims]
    strides = [dim.stride for dim in window_dims]
    padding_lows = [dim.padding_low for dim in window_dims]
    padding_highs = [dim.padding_high for dim in window_dims]
    base_dilations = [dim.base_dilation for dim in window_dims]

    dim_nums = inst.convolution_dimension_numbers
    spatial_dimensions = dim_nums.input_spatial_dimensions
    feature_dimension = dim_nums.input_feature_dimension
    kernel_spatial_dimensions = dim_nums.kernel_spatial_dimensions
    kernel_input_feature_dimension = dim_nums.kernel_input_feature_dimension
    base_sizes = [input_shape[dim] for dim in spatial_dimensions]

    # necessary conditions for enabling rtr shuffle
    symm_strides = len(set(strides)) == 1
    all_strided = all(strd > 1 for strd in strides)
    spatial_divisible = all(bsz % strd == 0 for bsz, strd in zip(base_sizes, strides))
    stride_condition = symm_strides and all_strided and spatial_divisible
    input_is_param = input_inst.opcode == 'parameter'
    no_padding = all(pad == 0 for pad in padding_lows + padding_highs)
    few_channels = input_shape[feature_dimension] < 32
    input_condition = input_is_param and no_padding and few_channels
    window_condition = all(ws > 1 for ws in window_sizes) and kernel_inst.opcode == 'constant'
    group_condition = inst.feature_group_count == 1
    if not (stride_condition and input_condition and window_condition and group_condition):
        return None

    # pad and shuffle kernel
    kernel_array = HloOp(kernel_inst).literal_value
    kernel_spatial_paddings = [[0, 0] for _ in kernel_array.shape]
    for dim, stride in zip(kernel_spatial_dimensions, strides):
        # Pad on the high side up to the next multiple of the stride.
        # ``(-size) % stride`` is that amount; ``size % stride`` only matches
        # it when the remainder is 0 or stride / 2 (e.g. stride 2).
        # NOTE(review): assumes `_rtr_transform` needs stride-aligned spatial
        # sizes, as the input-side `spatial_divisible` check suggests -- confirm.
        kernel_spatial_paddings[dim][1] = (-kernel_array.shape[dim]) % stride
    # np.pad defaults to constant mode, i.e. the padding is zero-filled.
    kernel_array_p = np.pad(kernel_array, kernel_spatial_paddings)
    kernel_array_prtr = _rtr_transform(kernel_array_p, kernel_spatial_dimensions,
                                       kernel_input_feature_dimension, strides)
    prtr_window_sizes = [kernel_array_prtr.shape[dim] for dim in kernel_spatial_dimensions]

    # don't enable rtr shuffle if explicit dilation is still required after padding kernel
    if not all(ws >= dil for ws, dil in zip(prtr_window_sizes, base_dilations)):
        return None

    # shuffle indices and new input shape
    input_r_shape = [int(size) for size in input_shape]
    num_elements = np.prod(input_r_shape)
    # Applying the same transform to a flat index grid records, for every
    # position of the transformed input, which original element lands there.
    indices = np.arange(num_elements).reshape(input_shape)
    indices_rtr = _rtr_transform(indices, spatial_dimensions, feature_dimension, strides)
    shuffle_indices = indices_rtr.reshape([-1])
    # Each spatial dimension shrinks by its stride and the feature dimension
    # grows by the same factor, so the element count is preserved.
    for dim, stride in zip(spatial_dimensions, strides):
        input_r_shape[dim] //= stride
        input_r_shape[feature_dimension] *= stride

    # return rewriter
    return RtrRewriter(inst, kernel_inst, input_r_shape, shuffle_indices, kernel_array_prtr)