# custom GNN
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.nn as dglnn


class CustomGNNLayer(nn.Module):
    def __init__(self, user_in_feats, product_in_feats, image_in_feats, hidden_feats):
        super(CustomGNNLayer, self).__init__()
        # Define a projection (weight matrix) for each node type
        self.weight_user = nn.Linear(user_in_feats, hidden_feats)
        self.weight_product = nn.Linear(product_in_feats, hidden_feats)
        self.weight_image = nn.Linear(image_in_feats, hidden_feats)
        self.weight_self = nn.Linear(hidden_feats, hidden_feats)

    def forward(self, g, h):
        with g.local_scope():
            # Extract features from the per-node-type dictionaries
            user_feats = h['user']['features']
            product_feats = h['product']['features']
            image_feats = h['image']['features']

            # Project and assign features to each node type
            g.nodes['user'].data['h'] = self.weight_user(user_feats)
            g.nodes['product'].data['h'] = self.weight_product(product_feats)
            g.nodes['image'].data['h'] = self.weight_image(image_feats)

            # Message function: send the source node's projected features
            def message_func(edges):
                return {'msg': edges.src['h']}

            # Reduce function: mean-aggregate incoming messages and combine with a self term
            def reduce_func(nodes):
                neigh_msg = nodes.mailbox['msg'].mean(dim=1)
                self_msg = self.weight_self(nodes.data['h'])
                return {'h': torch.relu(neigh_msg + self_msg)}

            # Propagate along each edge type (destination types: 'product' and 'image')
            g.update_all(message_func, reduce_func, etype=('user', 'rates', 'product'))
            g.update_all(message_func, reduce_func, etype=('user', 'has', 'image'))

            # Read back the updated features for each node type
            user_feats_out = g.nodes['user'].data['h']
            product_feats_out = g.nodes['product'].data['h']
            image_feats_out = g.nodes['image'].data['h']

            return {'user': user_feats_out, 'product': product_feats_out, 'image': image_feats_out}
# embedding generation
class EmbeddingGenerationModel(nn.Module):
    def __init__(self, user_in_feats, product_in_feats, image_in_feats, hidden_feats):
        super(EmbeddingGenerationModel, self).__init__()
        self.layers = CustomGNNLayer(user_in_feats, product_in_feats, image_in_feats, hidden_feats)
        self.user_final_layer = nn.Linear(hidden_feats, hidden_feats)
        self.product_final_layer = nn.Linear(hidden_feats, hidden_feats)
        self.image_final_layer = nn.Linear(hidden_feats, hidden_feats)

    def forward(self, g, h):
        h = self.layers(g, h)
        user_out = self.user_final_layer(h['user'])
        product_out = self.product_final_layer(h['product'])
        image_out = self.image_final_layer(h['image'])
        return user_out, product_out, image_out
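
# A minimal usage sketch (not part of the original paste): it builds a tiny
# hypothetical heterograph with random 8-dim features, only to show the nested
# input format {'node_type': {'features': tensor}} that CustomGNNLayer expects.
# The names toy_g, toy_inputs, and toy_model are illustrative assumptions.
toy_g = dgl.heterograph({
    ('user', 'rates', 'product'): (torch.tensor([0, 1]), torch.tensor([0, 1])),
    ('user', 'has', 'image'): (torch.tensor([0, 1]), torch.tensor([0, 1])),
})
toy_inputs = {
    'user': {'features': torch.randn(toy_g.num_nodes('user'), 8)},
    'product': {'features': torch.randn(toy_g.num_nodes('product'), 8)},
    'image': {'features': torch.randn(toy_g.num_nodes('image'), 8)},
}
toy_model = EmbeddingGenerationModel(8, 8, 8, 16)
u_emb, p_emb, i_emb = toy_model(toy_g, toy_inputs)  # shapes: (2, 16), (2, 16), (2, 16)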
# link prediction model
from torch.nn.functional import cosine_similarity


class LinkPredictionModel(nn.Module):
    def __init__(self, user_in_feats, product_in_feats, image_in_feats, hidden_feats):
        super().__init__()
        self.embedding_model = EmbeddingGenerationModel(
            user_in_feats, product_in_feats, image_in_feats, hidden_feats)
        self.fc = nn.Linear(2, 1)  # 2 similarity scores: user-image and user-product

    def forward(self, g, user_feats, product_feats, image_feats, edges):
        # Generate embeddings
        user_embeddings, product_embeddings, image_embeddings = self.embedding_model(
            g, {'user': user_feats, 'product': product_feats, 'image': image_feats})

        # Select the embeddings involved in the given edges
        user_embed_selected = user_embeddings[edges[0]]
        product_embed_selected = product_embeddings[edges[1]]
        image_embed_selected = image_embeddings[edges[0]]  # Assuming image embeddings correspond to users

        # Check that the selected embeddings match the edge sizes
        assert user_embed_selected.size(0) == edges[0].size(0), "Mismatch between user embeddings and edges"
        assert product_embed_selected.size(0) == edges[1].size(0), "Mismatch between product embeddings and edges"
        assert image_embed_selected.size(0) == edges[0].size(0), "Mismatch between image embeddings and edges"

        # Cosine similarity between user and image embeddings
        user_image_similarity = cosine_similarity(user_embed_selected, image_embed_selected, dim=1).unsqueeze(1)
        # Cosine similarity between user and product embeddings
        user_product_similarity = cosine_similarity(user_embed_selected, product_embed_selected, dim=1).unsqueeze(1)

        # Concatenate the two similarity scores
        similarities = torch.cat([user_image_similarity, user_product_similarity], dim=1)

        # Predict interaction probabilities from the similarities
        interaction_probabilities = torch.sigmoid(self.fc(similarities))
        return interaction_probabilities
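
# A quick toy forward pass (illustrative assumption; toy_g and toy_inputs come
# from the hypothetical sketch above). It shows the expected `edges` argument:
# a tuple (user_indices, product_indices), where the user indices are also used
# to look up image embeddings.
toy_lp = LinkPredictionModel(8, 8, 8, 16)
toy_edges = (torch.tensor([0, 1]), torch.tensor([1, 0]))  # (user ids, product ids)
toy_probs = toy_lp(toy_g, toy_inputs['user'], toy_inputs['product'], toy_inputs['image'], toy_edges)
print(toy_probs.shape)  # torch.Size([2, 1])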
user_in_feats = 512
product_in_feats = 512
image_in_feats = 512

link_prediction_model = LinkPredictionModel(
    user_in_feats=user_in_feats, product_in_feats=product_in_feats,
    image_in_feats=image_in_feats, hidden_feats=512
)
embedding_model = EmbeddingGenerationModel(
    user_in_feats=user_in_feats, product_in_feats=product_in_feats,
    image_in_feats=image_in_feats, hidden_feats=512
)
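
# The training loop below uses train_g, train_user_feats, train_product_feats,
# train_image_feats, and train_num_products, none of which are defined in this
# paste. The following is a hypothetical stand-in with random edges and random
# 512-dim features, only so the loop can run end to end; replace it with the
# real data pipeline. Image i is paired with user i to match the model's
# assumption that image embeddings are indexed by user.
num_users, num_products, num_images = 100, 50, 100  # hypothetical sizes
train_g = dgl.heterograph({
    ('user', 'rates', 'product'): (torch.randint(0, num_users, (500,)),
                                   torch.randint(0, num_products, (500,))),
    ('user', 'has', 'image'): (torch.arange(num_users), torch.arange(num_images)),
}, num_nodes_dict={'user': num_users, 'product': num_products, 'image': num_images})
train_user_feats = {'features': torch.randn(num_users, user_in_feats)}
train_product_feats = {'features': torch.randn(num_products, product_in_feats)}
train_image_feats = {'features': torch.randn(num_images, image_in_feats)}
train_num_products = train_g.num_nodes('product')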
# training loop
from torch.utils.data import DataLoader

learning_rate = 0.01
num_epochs = 100
# loss_fn = nn.BCEWithLogitsLoss()  # Binary Cross Entropy with Logits (unused; max-margin loss is used below)
optimizer = torch.optim.Adam(link_prediction_model.parameters(), lr=learning_rate)

# Stack the 'rates' edges into a (num_edges, 2) tensor so the DataLoader yields (user, product) pairs
src, dst = train_g.edges(etype='rates')
train_edges = torch.stack([src, dst], dim=1)

for epoch in range(num_epochs):
    epoch_loss = 0.0  # Accumulate loss for the epoch

    # DataLoader over positive edges; negative samples are drawn per batch
    dataloader = DataLoader(train_edges, batch_size=64, shuffle=True)
    for i, batch in enumerate(dataloader):
        optimizer.zero_grad()  # Reset gradients

        pos_u, pos_v = batch[:, 0], batch[:, 1]
        neg_u = pos_u  # Negative samples keep the same users as the positive samples
        neg_v = torch.randint(0, train_num_products, (len(pos_v),))  # Random negative products

        # Forward pass for positive edges
        pos_scores = link_prediction_model(
            train_g, train_user_feats, train_product_feats, train_image_feats, (pos_u, pos_v))
        pos_scores = pos_scores.squeeze(-1)

        # Forward pass for negative edges
        neg_scores = link_prediction_model(
            train_g, train_user_feats, train_product_feats, train_image_feats, (neg_u, neg_v))
        neg_scores = neg_scores.squeeze(-1)

        # Max-margin loss: push positive scores above negative scores by a margin of 1
        loss = torch.sum(torch.clamp(1 - pos_scores + neg_scores, min=0))

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}")
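
# A possible inference sketch (an assumption, not part of the original paste):
# after training, score a batch of candidate (user, product) pairs with the
# trained link_prediction_model. candidate_users and candidate_products are
# hypothetical index tensors.
link_prediction_model.eval()
with torch.no_grad():
    candidate_users = torch.tensor([0, 1, 2])
    candidate_products = torch.tensor([3, 4, 5])
    probs = link_prediction_model(
        train_g, train_user_feats, train_product_feats, train_image_feats,
        (candidate_users, candidate_products))
    print(probs.squeeze(-1))  # predicted interaction probabilities in [0, 1]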