in torchbiggraph/model.py [0:0]
def make_model(config: ConfigSchema) -> MultiRelationEmbedder:
    """Build a ``MultiRelationEmbedder`` from a training configuration.

    Validates the parts of ``config`` that the model depends on, instantiates
    one LHS and one RHS operator per entry in ``config.relations``, and builds
    the comparator (optionally wrapped in ``BiasedComparator``) and the
    optional regularizer.

    Args:
        config: The full PBG configuration.

    Returns:
        A ``MultiRelationEmbedder`` wired up with the operators, comparator,
        regularizer and negative-sampling settings taken from ``config``.

    Raises:
        RuntimeError: if dynamic relations are enabled and
            ``config.relations`` does not have exactly one entry; if dynamic
            relations are enabled and the relation count cannot be loaded
            from ``config.entity_path`` (the original ``CouldNotLoadData`` is
            chained as the cause); or if ``num_batch_negs`` is positive and
            does not evenly divide ``batch_size``.
    """
    if config.dynamic_relations:
        # With dynamic relations a single relation config is shared by all
        # relation types, so exactly one entry is expected.
        if len(config.relations) != 1:
            raise RuntimeError(
                "Dynamic relations are enabled, so there should only be one "
                "entry in config.relations with config for all relations."
            )
        try:
            relation_type_storage = RELATION_TYPE_STORAGES.make_instance(
                config.entity_path
            )
            num_dynamic_rels = relation_type_storage.load_count()
        except CouldNotLoadData as err:
            # Chain the storage error so the root cause shows up in the
            # traceback instead of being hidden behind implicit context.
            raise RuntimeError(
                "Dynamic relations are enabled, so there should be a file called "
                "dynamic_rel_count.txt in the entity path with their count."
            ) from err
    else:
        num_dynamic_rels = 0

    # Guard the modulo with `> 0` so a zero num_batch_negs never divides.
    if config.num_batch_negs > 0 and config.batch_size % config.num_batch_negs != 0:
        raise RuntimeError(
            "Batch size (%d) must be a multiple of num_batch_negs (%d)"
            % (config.batch_size, config.num_batch_negs)
        )

    lhs_operators: List[Optional[Union[AbstractOperator, AbstractDynamicOperator]]] = []
    rhs_operators: List[Optional[Union[AbstractOperator, AbstractDynamicOperator]]] = []
    # LHS and RHS operators are instantiated in interleaved order per relation;
    # keep this order in case operator construction has side effects.
    for r in config.relations:
        lhs_operators.append(
            instantiate_operator(
                r.operator, Side.LHS, num_dynamic_rels, config.entity_dimension(r.lhs)
            )
        )
        rhs_operators.append(
            instantiate_operator(
                r.operator, Side.RHS, num_dynamic_rels, config.entity_dimension(r.rhs)
            )
        )

    comparator_class = COMPARATORS.get_class(config.comparator)
    comparator = comparator_class()
    if config.bias:
        comparator = BiasedComparator(comparator)

    # A zero coefficient disables regularization entirely (no regularizer
    # object is created).
    if config.regularization_coef != 0:
        regularizer_class = REGULARIZERS.get_class(config.regularizer)
        regularizer = regularizer_class(config.regularization_coef)
    else:
        regularizer = None

    return MultiRelationEmbedder(
        config.dimension,
        config.relations,
        config.entities,
        num_uniform_negs=config.num_uniform_negs,
        num_batch_negs=config.num_batch_negs,
        disable_lhs_negs=config.disable_lhs_negs,
        disable_rhs_negs=config.disable_rhs_negs,
        lhs_operators=lhs_operators,
        rhs_operators=rhs_operators,
        comparator=comparator,
        regularizer=regularizer,
        global_emb=config.global_emb,
        max_norm=config.max_norm,
        num_dynamic_rels=num_dynamic_rels,
        half_precision=config.half_precision,
    )