in torchbiggraph/filtered_eval.py [0:0]
def __init__(self, config: ConfigSchema, filter_paths: List[str]) -> None:
    """Set up a filtered ranking evaluator and index the edges in ``filter_paths``.

    First builds the loss function from ``config.loss_fn`` (with
    ``margin=config.margin``). It is passed, with the per-relation weights, to
    the parent initializer. Then checks that the configuration is one this
    evaluator supports, and loads every edge from each path in
    ``filter_paths`` into two lookup tables:

    * ``self.lhs_map[(lhs, rel)]`` -> list of rhs ids seen with that pair
    * ``self.rhs_map[(rhs, rel)]`` -> list of lhs ids seen with that pair

    An edge that occurs more than once (within one path or across several
    paths) is appended once per occurrence, so the lists may hold duplicates.
    The tables are presumably used later to filter known-true edges out of
    the ranking (NOTE(review): confirm against the evaluation code).

    Args:
        config: The run configuration. It must declare exactly one relation,
            with ``all_negs`` enabled. It must also declare exactly one entity
            type, which must be neither featurized nor partitioned.
        filter_paths: Edge-storage paths whose edges are indexed.

    Raises:
        RuntimeError: If ``config`` violates any of the constraints above.
    """
    loss_fn = LOSS_FUNCTIONS.get_class(config.loss_fn)(margin=config.margin)
    relation_weights = [r.weight for r in config.relations]
    super().__init__(loss_fn, relation_weights)

    if len(config.relations) != 1 or len(config.entities) != 1:
        raise RuntimeError(
            "Filtered ranking evaluation should only be used "
            "with dynamic relations and one entity type."
        )
    if not config.relations[0].all_negs:
        raise RuntimeError("Filtered Eval can only be done with all negatives.")

    (entity,) = config.entities.values()
    if entity.featurized:
        raise RuntimeError("Entity cannot be featurized for filtered eval.")
    if entity.num_partitions > 1:
        raise RuntimeError("Entity cannot be partitioned for filtered eval.")

    self.lhs_map: Dict[Tuple[int, int], List[int]] = defaultdict(list)
    self.rhs_map: Dict[Tuple[int, int], List[int]] = defaultdict(list)
    for path in filter_paths:
        logger.info(f"Building links map from path {path}")
        e_storage = EDGE_STORAGES.make_instance(path)
        # Partitioned entities were rejected above, so only the single
        # unpartitioned bucket is loaded.
        edges = e_storage.load_edges(UNPARTITIONED, UNPARTITIONED)
        # Convert each column to plain Python ints once. This avoids calling
        # to_tensor() and indexing a tensor element-by-element for every
        # edge. Entities are non-featurized (checked above), so each
        # to_tensor() entry is taken to be one entity id per edge.
        lhs_ids = edges.lhs.to_tensor().tolist()
        # Assumes dynamic relations, i.e. edges.rel holds one relation id per
        # edge (a 1-D tensor). NOTE(review): confirm for this storage format.
        rel_ids = edges.rel.tolist()
        rhs_ids = edges.rhs.to_tensor().tolist()
        for cur_lhs, cur_rel, cur_rhs in zip(lhs_ids, rel_ids, rhs_ids):
            self.lhs_map[cur_lhs, cur_rel].append(cur_rhs)
            self.rhs_map[cur_rhs, cur_rel].append(cur_lhs)
        logger.info(f"Done building links map from path {path}")