def convert_conv_type()

in pretrain/pointcontrast/model/modules/common.py [0:0]


def convert_conv_type(conv_type, kernel_size, D):
  """Resolve a ``ConvType`` into a region type, axis types and kernel size.

  Args:
    conv_type: A ``ConvType`` member. It is used as the key into
      ``conv_to_region_type`` to obtain the region type.
    kernel_size: Either a single size, which is repeated for the three
      spatial axes, or a sequence of per-axis sizes. It is only rewritten
      for the ``SPATIAL_*`` conv types; every other type passes it through
      unchanged.
    D: Number of dimensions of the convolution. ``D == 4`` adds a fourth
      (temporal, per the conv type names) axis to the spatial-only kernels
      and is required by the ``SPATIO_TEMPORAL_*`` types.

  Returns:
    A ``(region_type, axis_types, kernel_size)`` tuple. ``axis_types`` is
    ``None`` except for ``SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS``, where it
    is a list with one region type per axis.

  Raises:
    AssertionError: If ``conv_type`` is not a ``ConvType``, or if a
      ``SPATIO_TEMPORAL_*`` type is requested with ``D != 4``.
  """
  # ``collections.Sequence`` was removed in Python 3.10; the ABC lives in
  # ``collections.abc``. Imported locally so this function does not depend on
  # how the enclosing module imports ``collections``.
  from collections.abc import Sequence

  assert isinstance(conv_type, ConvType), "conv_type must be of ConvType"
  region_type = conv_to_region_type[conv_type]
  axis_types = None

  if conv_type in (ConvType.SPATIAL_HYPERCUBE, ConvType.SPATIAL_HYPERCROSS):
    # Spatial-only kernels: keep three spatial extents, no temporal convolution.
    if isinstance(kernel_size, Sequence):
      kernel_size = kernel_size[:3]
    else:
      kernel_size = [kernel_size] * 3
    if D == 4:
      # Build a fresh list rather than calling ``append`` so that tuples and
      # other immutable sequences are accepted; size 1 on the extra axis
      # means no convolution along it.
      kernel_size = list(kernel_size) + [1]
  elif conv_type in (ConvType.SPATIO_TEMPORAL_HYPERCUBE,
                     ConvType.SPATIO_TEMPORAL_HYPERCROSS):
    # conv_type conversion already handled by the region type lookup; these
    # types are only valid for 4-D convolutions.
    assert D == 4
  elif conv_type == ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS:
    # Define the CUBIC conv kernel for spatial dims and CROSS conv for temp dim
    axis_types = [ME.RegionType.HYPERCUBE] * 3
    if D == 4:
      axis_types.append(ME.RegionType.HYPERCROSS)
  # ConvType.HYPERCUBE / ConvType.HYPERCROSS (and anything else): the region
  # type lookup above is all that is needed; kernel_size passes through.

  return region_type, axis_types, kernel_size