meehai

Untitled

Sep 14th, 2020
738
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import numpy as np
  2. import torch as tr
  3. import torch.nn as nn
  4.  
  5. device = tr.device("cuda") if tr.cuda.is_available() else tr.device("cpu")
  6.  
  7. def printDataStats(data):
  8.     nUsers = len(data["users"])
  9.     print("User features: %s" % data["users"][0].shape)
  10.     print("Session features: %d" % (data["sessions"][0][0].shape))
  11.     print("Hits features: %d" % (data["hits"][0][0][0].shape))
  12.     print("Targets features: %d" % (data["targets"][0][0].shape))
  13.  
  14.     for i in range(nUsers):
  15.         nSessions = len(data["sessions"][i])
  16.         print("User %d. Num sessions: %d" % (i, nSessions))
  17.         for j in range(nSessions):
  18.             nHits = len(data["hits"][j])
  19.             print("  - session %d. Num hits: %d" % (j, nHits))
  20.  
  21. def getData(nUsers, NU, NS, NH, NT):
  22.     dataUsers = tr.randn(nUsers, NU).to(device)
  23.     dataSessions = []
  24.     dataHits = []
  25.     dataTargets = []
  26.     # For each user, generate some number of sessions between 1 and 10
  27.     for i in range(nUsers):
  28.         sessCount = np.random.randint(1, 10)
  29.         dataSessions.append(tr.randn(sessCount, NS).to(device))
  30.         dataTargets.append(tr.randn(sessCount, NT).to(device))
  31.         # For each session, generate some number of hits between 1 and 10
  32.         sessionHits = []
  33.         for j in range(sessCount):
  34.             hitsCount = np.random.randint(1, 10)
  35.             sessionHits.append(tr.randn(hitsCount, NH).to(device))
  36.         dataHits.append(sessionHits)
  37.    
  38.     data = {"users" : dataUsers, "sessions" : dataSessions, "hits" : dataHits, "targets" : dataTargets}
  39.     return data
  40.  
  41. class Model(nn.Module):
  42.     def __init__(self, NU, NS, NH, NT):
  43.         super(Model, self).__init__()
  44.         self.rnnHits = nn.RNN(input_size=NH, hidden_size=10)
  45.         self.rnnSessions = nn.RNN(input_size=NS, hidden_size=10)
  46.         self.linearOut = nn.Linear(10 + 10 + NU, NT)
  47.  
  48.     def forward(self, data):
  49.         # Hidden state for each session of each user:
  50.         #  hiddenSessions :: [numSessions x 1 (batch) x 10] for each user (total nUsers)
  51.         hiddenSessions = [self.rnnSessions(data["sessions"][i].unsqueeze(dim=1))[0][:, 0] \
  52.             for i in range(len(data["sessions"]))]
  53.  
  54.         # Hidden state for each hits of each session of each user:
  55.         hiddenHits = []
  56.         for i in range(len(data["sessions"])):
  57.             userNumSessions = len(data["sessions"][i])
  58.             # Hidden state for each session of ith user user:
  59.             #  userHits :: [numHits x 1 (batch) x 10] for this user (total userNumSessions)
  60.             userHits =  [self.rnnHits(data["hits"][i][j].unsqueeze(dim=1))[0] \
  61.                 for j in range(len(data["hits"][i]))]
  62.             # Get only the last hidden state (basically the cummulation of all hidden states), so we transform the
  63.             #  hits into session features.
  64.             lastUserHit = tr.cat([x[0] for x in userHits]).to(device)
  65.             hiddenHits.append(lastUserHit)
  66.  
  67.         # Now, concatenate everything, resulting a list of nUsers x [nSessionsEachUser, NU + 10 + 10]
  68.         allFeatures = []
  69.         for i in range(len(data["users"])):
  70.             nSessions = len(data["sessions"][i])
  71.             user = data["users"][i].unsqueeze(dim=0).repeat(nSessions, 1)
  72.             session = hiddenSessions[i]
  73.             hit = hiddenHits[i]
  74.             concatenated = tr.cat([user, session, hit], dim=1)
  75.             allFeatures.append(concatenated)
  76.  
  77.         # Finally, project this output to the targets space
  78.         results = [self.linearOut(x) for x in allFeatures]
  79.         return results
  80.  
  81. def main():
  82.     nUsers = 10
  83.     # number of features per users, sessions and hits and target valiues (5 shopping stages)
  84.     NU, NS, NH, NT = 5, 10, 15, 5
  85.     data = getData(nUsers, NU, NS, NH, NT)
  86.  
  87.     # Print for each user the generated data
  88.     printDataStats(data)
  89.  
  90.     # Pass the data through the users>sessions>hits model
  91.     model = Model(NU, NS, NH, NT).to(device)
  92.     res = model.forward(data)
  93.     print([x.shape for x in res])
  94.     print([x.shape for x in data["targets"]])
  95.  
  96.     # Compute the loss (note, in reality we use one hot encoding, not [0-4] values for targets and cross entropy)
  97.     loss = [((x - y)**2).mean() for x, y in zip(res, data["targets"])]
  98.     loss = sum(loss)
  99.     print("Loss: %2.5f" % loss)
  100.  
  101.     # Optimize the loss, use train-val split etc.
  102.     # ...
  103.  
  104. if __name__ == "__main__":
  105.     main()
RAW Paste Data