Advertisement
Amaboh

Pytorch run time error calling backward()

Aug 6th, 2023
1,407
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 6.91 KB | Help | 0 0
  1. # custom GNN
import dgl.nn as dglnn
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
  5.  
  6. class CustomGNNLayer(nn.Module):
  7.     def __init__(self, user_in_feats, product_in_feats, image_in_feats, hidden_feats):
  8.         super(CustomGNNLayer, self).__init__()
  9.         # Define weight matrices for each node type
  10.         self.weight_user = nn.Linear(user_in_feats, hidden_feats)
  11.         self.weight_product = nn.Linear(product_in_feats, hidden_feats)
  12.         self.weight_image = nn.Linear(image_in_feats, hidden_feats)
  13.         self.weight_self = nn.Linear(hidden_feats, hidden_feats)
  14.  
  15.     def forward(self, g, h):
  16.         with g.local_scope():
  17.             # Extract features from the dictionaries
  18.             user_feats = h['user']['features']
  19.             product_feats = h['product']['features']
  20.             image_feats = h['image']['features']
  21.  
  22.             # Assign features to each node type
  23.             g.nodes['user'].data['h'] = self.weight_user(user_feats)
  24.             g.nodes['product'].data['h'] = self.weight_product(product_feats)
  25.             g.nodes['image'].data['h'] = self.weight_image(image_feats)
  26.            
  27.             # Message function to fetch incoming messages
  28.             def message_func(edges):
  29.                 return {'msg': edges.src['h']}
  30.            
  31.             # Reduce function to aggregate messages
  32.             def reduce_func(nodes):
  33.                 neigh_msg = nodes.mailbox['msg'].mean(dim=1)
  34.                 self_msg = self.weight_self(nodes.data['h'])
  35.                 return {'h': torch.relu(neigh_msg + self_msg)}
  36.  
  37.             # Update all node types
  38.             g.update_all(message_func, reduce_func, etype=('user', 'rates', 'product'))
  39.             g.update_all(message_func, reduce_func, etype=('user', 'has', 'image'))
  40.  
  41.             # Extract updated features for each node type
  42.             user_feats_out = g.nodes['user'].data['h']
  43.             product_feats_out = g.nodes['product'].data['h']
  44.             image_feats_out = g.nodes['image'].data['h']
  45.  
  46.             return {'user': user_feats_out, 'product': product_feats_out, 'image': image_feats_out}
  47.  
  48.  
  49. #embedding generation
  50. class EmbeddingGenerationModel(nn.Module):
  51.     def __init__(self, user_in_feats, product_in_feats, image_in_feats, hidden_feats):
  52.         super(EmbeddingGenerationModel, self).__init__()
  53.         self.layers = CustomGNNLayer(user_in_feats, product_in_feats, image_in_feats, hidden_feats)
  54.         self.user_final_layer = nn.Linear(hidden_feats, hidden_feats)
  55.         self.product_final_layer = nn.Linear(hidden_feats, hidden_feats)
  56.         self.image_final_layer = nn.Linear(hidden_feats, hidden_feats)
  57.  
  58.     def forward(self, g, h):
  59.         h = self.layers(g, h)
  60.         user_out = self.user_final_layer(h['user'])
  61.         product_out = self.product_final_layer(h['product'])
  62.         image_out = self.image_final_layer(h['image'])
  63.         return user_out, product_out, image_out
  64. # model
  65. from torch.nn.functional import cosine_similarity
  66.  
  67. class LinkPredictionModel(nn.Module):
  68.     def __init__(self, user_in_feats, product_in_feats, image_in_feats, hidden_feats):
  69.         super().__init__()
  70.         self.embedding_model = EmbeddingGenerationModel(
  71.             user_in_feats, product_in_feats, image_in_feats, hidden_feats)
  72.         self.fc = nn.Linear(2, 1)  # 2 similarity scores: user-image and user-product
  73.        
  74.     def forward(self, g, user_feats, product_feats, image_feats, edges):
  75.         # Generate embeddings
  76.         user_embeddings, product_embeddings, image_embeddings = self.embedding_model(g, {'user': user_feats, 'product': product_feats, 'image': image_feats})
  77.        
  78.         # Select relevant embeddings based on edges
  79.         user_embed_selected = user_embeddings[edges[0]]
  80.         product_embed_selected = product_embeddings[edges[1]]
  81.         image_embed_selected = image_embeddings[edges[0]]  # Assuming image embeddings correspond to users
  82.  
  83.         # Check if selected embeddings match edge sizes
  84.         assert user_embed_selected.size(0) == edges[0].size(0), "Mismatch between user embeddings and edges"
  85.         assert product_embed_selected.size(0) == edges[1].size(0), "Mismatch between product embeddings and edges"
  86.         assert image_embed_selected.size(0) == edges[0].size(0), "Mismatch between image embeddings and edges"
  87.  
  88.         # Calculate user-image similarity (cosine similarity)
  89.         user_image_similarity = cosine_similarity(user_embed_selected, image_embed_selected, dim=1).unsqueeze(1)
  90.  
  91.         # Calculate user-product similarity (cosine similarity)
  92.         user_product_similarity = cosine_similarity(user_embed_selected, product_embed_selected, dim=1).unsqueeze(1)
  93.  
  94.         # Concatenate user_image_similarity and user_product_similarity
  95.         similarities = torch.cat([user_image_similarity, user_product_similarity], dim=1)
  96.  
  97.         # Prediction using similarities
  98.         interaction_probabilities = torch.sigmoid(self.fc(similarities))
  99.  
  100.         return interaction_probabilities
  101.  
  102.  
  103. user_in_feats = 512
  104. product_in_feats = 512
  105. image_in_feats= 512
  106.  
  107. link_prediction_model = LinkPredictionModel(
  108.     user_in_feats=user_in_feats, product_in_feats=product_in_feats, image_in_feats=image_in_feats, hidden_feats=512
  109. )
  110. embedding_model = EmbeddingGenerationModel(
  111.     user_in_feats=user_in_feats, product_in_feats=product_in_feats, image_in_feats=image_in_feats, hidden_feats=512
  112. )
  113.  
  114.  
  115.  
  116.  
  117.  
  118. #training loop
  119. learning_rate = 0.010
  120. num_epochs = 100
  121. #loss_fn = nn.BCEWithLogitsLoss() # Binary Cross Entropy with Logits
  122. optimizer = torch.optim.Adam(link_prediction_model.parameters(), lr=learning_rate)
  123.  
  124. train_edges = train_g.edges(etype='rates')
  125.  
  126. for epoch in range(num_epochs):
  127.     epoch_loss = 0.0  # Accumulate loss for the epoch
  128.    
  129.     # DataLoader for positive and negative edge samples
  130.     dataloader = DataLoader(train_edges, batch_size=64, shuffle=True)
  131.     for i, batch in enumerate(dataloader):
  132.         optimizer.zero_grad()  # Reset gradients
  133.         pos_u, pos_v = batch[:, 0], batch[:, 1]
  134.         neg_u = pos_u  # Negative samples have the same users as positive samples
  135.         neg_v = torch.randint(0, train_num_products, (len(pos_v),))  # Random negative products
  136.  
  137.         # Forward pass for positive edges
  138.         pos_scores = link_prediction_model(train_g, train_user_feats, train_product_feats, train_image_feats, (pos_u, pos_v))
  139.         pos_scores = pos_scores.squeeze(-1)
  140.  
  141.         # Forward pass for negative edges
  142.         neg_scores = link_prediction_model(train_g, train_user_feats, train_product_feats, train_image_feats, (neg_u, neg_v))
  143.         neg_scores = neg_scores.squeeze(-1)
  144.  
  145.         # Max-margin loss
  146.         loss = torch.sum(torch.clamp(1 - pos_scores + neg_scores, min=0))
  147.        
  148.         # Backward pass
  149.         loss.backward()  # Compute gradients
  150.  
  151.         # Perform parameter update
  152.         optimizer.step()  # Update parameters
  153.  
  154.         epoch_loss += loss.item()
  155.  
  156.     print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}")
  157.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement