in tzrec/features/feature.py [0:0]
def mc_module(self, device: torch.device) -> Optional[ManagedCollisionModule]:
    """Get ManagedCollisionModule.

    Builds an ``MCHManagedCollisionModule`` from the feature's ``zch`` config
    when the feature is sparse and ``zch`` is set.

    Args:
        device: device passed through to ``MCHManagedCollisionModule``.

    Returns:
        An ``MCHManagedCollisionModule`` configured from ``self.config.zch``,
        or None if the feature is not sparse or has no ``zch`` config.

    Raises:
        ValueError: if ``zch.eviction_policy`` is unset, or is not one of
            ``lfu``, ``lru`` or ``distance_lfu``.
    """
    if not self.is_sparse:
        return None
    # Not every feature config message defines a ``zch`` field.
    if not (hasattr(self.config, "zch") and self.config.HasField("zch")):
        return None

    zch_cfg = self.config.zch
    evict_type = zch_cfg.WhichOneof("eviction_policy")
    if evict_type is None:
        # WhichOneof returns None when no member of the oneof is set;
        # getattr(zch_cfg, None) would otherwise fail with an opaque TypeError.
        raise ValueError("zch.eviction_policy is not set.")
    evict_config = getattr(zch_cfg, evict_type)

    threshold_filtering_func = None
    if zch_cfg.HasField("threshold_filtering_func"):
        # NOTE(review): eval() executes arbitrary Python taken from the config
        # string; only load pipeline configs from trusted sources.
        threshold_filtering_func = eval(zch_cfg.threshold_filtering_func)

    if evict_type == "lfu":
        eviction_policy = LFU_EvictionPolicy(
            threshold_filtering_func=threshold_filtering_func
        )
    elif evict_type == "lru":
        eviction_policy = LRU_EvictionPolicy(
            decay_exponent=evict_config.decay_exponent,
            threshold_filtering_func=threshold_filtering_func,
        )
    elif evict_type == "distance_lfu":
        eviction_policy = DistanceLFU_EvictionPolicy(
            decay_exponent=evict_config.decay_exponent,
            threshold_filtering_func=threshold_filtering_func,
        )
    else:
        raise ValueError(f"Unknown evict policy type: {evict_type}")

    return MCHManagedCollisionModule(
        zch_size=zch_cfg.zch_size,
        device=device,
        eviction_interval=zch_cfg.eviction_interval,
        eviction_policy=eviction_policy,
    )