in shap_e/models/transmitter/channels_encoder.py [0:0]
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
param_shapes: Dict[str, Tuple[int]],
params_proj: Dict[str, Any],
min_unrolls: int,
max_unrolls: int,
d_latent: int = 512,
latent_bottleneck: Optional[Dict[str, Any]] = None,
latent_warp: Optional[Dict[str, Any]] = None,
width: int = 512,
layers: int = 12,
xattn_layers: int = 1,
heads: int = 8,
init_scale: float = 0.25,
# Training hparams
inner_batch_size: Union[int, List[int]] = 1,
data_ctx: int = 1,