in src/rime/models/graph_conv.py [0:0]
def __init__(self, n_users=None, n_items=None,
             user_rec=True, item_rec=True, no_components=32,
             n_negatives=10, lr=1, weight_decay=1e-5,
             user_conv_model='GCN',  # plain_average
             user_embeddings=None, item_embeddings=None, item_zero_bias=False,
             recency_boundary_multipliers=(0.1, 0.3, 1, 3, 10), horizon=float("inf")):
    """Build the item encoder, item bias, user convolution and recency modules.

    Args:
        n_users: accepted for interface compatibility; not read in this
            constructor.
        n_items: number of rows in the item embedding and item bias tables.
        user_rec, item_rec, n_negatives, lr, weight_decay: forwarded
            unchanged to the parent class initializer.
        no_components: embedding width. Overridden by the last dimension of
            ``item_embeddings`` when those are provided.
        user_conv_model: ``'GCN'`` selects a DGL ``GraphConv`` layer;
            ``'plain_average'`` selects the module-level ``_plain_average``
            callable.
        user_embeddings: optional external user features; when given, a
            ``LayerNorm`` over ``user_embeddings.shape[1]`` is created.
        item_embeddings: optional pretrained item embeddings; when given,
            they are copied into the item encoder, which is then frozen.
        item_zero_bias: when true, the item bias table is zeroed and frozen.
        recency_boundary_multipliers: sequence of factors multiplied by
            ``horizon`` to form the ``recency_boundaries`` buffer.
        horizon: scale applied to ``recency_boundary_multipliers``.

    Raises:
        ValueError: if ``user_conv_model`` is not one of the supported names.
    """
    super().__init__(user_rec, item_rec, n_negatives, lr, weight_decay)

    pretrained_items = item_embeddings is not None
    if pretrained_items:
        warnings.warn("setting no_components according to provided embeddings")
        no_components = item_embeddings.shape[-1]

    self.item_encoder = torch.nn.Embedding(n_items, no_components)
    if pretrained_items:
        # Disable grad first so the in-place copy into the leaf parameter is allowed.
        self.item_encoder.weight.requires_grad = False
        self.item_encoder.weight.copy_(torch.as_tensor(item_embeddings))

    self.item_bias_vec = torch.nn.Embedding(n_items, 1)
    if item_zero_bias:
        bias_weight = self.item_bias_vec.weight
        bias_weight.requires_grad = False
        bias_weight.copy_(torch.zeros_like(bias_weight))

    if user_conv_model == 'GCN':
        # Third positional argument is GraphConv's `norm` option.
        self.user_conv = dgl.nn.pytorch.conv.GraphConv(no_components, no_components, "none")
    elif user_conv_model == 'plain_average':
        self.user_conv = _plain_average
    else:
        # Fail here instead of leaving `self.user_conv` undefined, which
        # would only surface later as an unrelated AttributeError.
        raise ValueError(
            f"unknown user_conv_model {user_conv_model!r}; "
            "expected 'GCN' or 'plain_average'"
        )

    self.user_layer_norm = torch.nn.LayerNorm(no_components)
    if user_embeddings is not None:
        self.user_ext_layer_norm = torch.nn.LayerNorm(user_embeddings.shape[1])

    # Stored as a buffer: part of the module state but not a trainable parameter.
    boundaries = torch.as_tensor(recency_boundary_multipliers) * horizon
    self.register_buffer("recency_boundaries", boundaries)
    # One more row than there are boundaries
    # (presumably one scalar per recency bucket; confirm against the forward pass).
    self.recency_encoder = torch.nn.Embedding(len(recency_boundary_multipliers) + 1, 1)

    self.init_weights()