SHARE
TWEET

Untitled

a guest Jul 16th, 2019 64 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. from typing import List, Dict
  2. import itertools
  3. import numpy as np
  4. import pandas as pd
  5. from sklearn import metrics
  6. from sklearn.feature_extraction.text import CountVectorizer
  7. import torch
  8. import torch.nn as nn
  9. from torch.utils.data import Dataset, DataLoader, TensorDataset
  10. import torch.nn.functional as F
  11.    
  12. class xDeepFM(nn.Module):
  13.  
  14.     def __init__(self,
  15.                  n_outputs:int=None,
  16.                  cat_cols: List[str]=None,
  17.                  cat_nuniques: List[int]=None,
  18.                  ordinal_cols:List[str]=None,
  19.                  use_raw:bool=True,
  20.                  use_bias:bool=True,
  21.                  D:int=4,
  22.                  H_sizes:List[int]=None,
  23.                  hidden_dims:List[int]=None,
  24.                  dropouts:List[float]=None,
  25.                  device:str='cpu') -> None:
  26.         super().__init__()
  27.         self.n_outputs = n_outputs
  28.         self.device = device
  29.         self.use_raw = use_raw
  30.         self.use_bias = use_bias
  31.         if use_bias:
  32.             self.bias = nn.Parameter(torch.rand(1))
  33.         if hidden_dims and (len(dropouts) != len(hidden_dims)):
  34.             raise ValueError
  35.         if use_raw:          
  36.             if cat_nuniques:
  37.                 raw_size = len(cat_nuniques)
  38.             else:
  39.                 raw_size = 0
  40.         if H_sizes and cat_nuniques:
  41.             self.H_sizes = [len(cat_nuniques)] + H_sizes
  42.         ordinal_size = 0
  43.         if ordinal_cols:
  44.             ordinal_size = len(ordinal_cols)
  45.             if use_raw:
  46.                 raw_size += ordinal_size
  47.         output_embed_size = 0
  48.         self.W_l = []
  49.         if cat_nuniques:
  50.             if use_raw:
  51.                 self.cat_raw_inputs = nn.ModuleList([nn.Embedding(cat_nunique, 1) for cat_nunique in cat_nuniques])
  52.             self.cat_embeddings = nn.ModuleList([nn.Embedding(cat_nunique, D) for cat_nunique in cat_nuniques])
  53.             output_embed_size = D * len(cat_nuniques)
  54.             ## CIN
  55.             if H_sizes:
  56.                 for H_k_1, H_k in zip(self.H_sizes[:-1], self.H_sizes[1:]):
  57.                     self.W_l += [nn.Parameter(torch.randn((H_k,H_k_1,len(cat_nuniques))))]
  58.                 self.W_l = nn.ParameterList(self.W_l)
  59.                 self.cin_weights = nn.Linear(sum(H_sizes), n_outputs, bias=False)
  60.             ## Deep  
  61.         self.dnn = []
  62.         if hidden_dims:
  63.             output_size = output_embed_size + ordinal_size
  64.             dims = [output_size] + hidden_dims
  65.             for k in range(1, len(hidden_dims) + 1):
  66.                 input_dim = dims[k-1]
  67.                 output_dim = dims[k]
  68.                 dropout = dropouts[k-1]
  69.                 self.dnn += [nn.Sequential(nn.Linear(input_dim,output_dim),
  70.                                                nn.ReLU(),
  71.                                                #nn.BatchNorm1d(output_dim),
  72.                                                nn.Dropout(dropout)
  73.                                                )]
  74.             self.dnn = nn.ModuleList(self.dnn)
  75.             self.dnn_weights = nn.Linear(hidden_dims[-1], n_outputs, bias=False)
  76.  
  77.         if use_raw:
  78.             self.raw_weights = nn.Linear(raw_size, n_outputs, bias=False)
  79.        
  80.     def forward(self, inputs: Dict) -> torch.Tensor:  
  81.         batch_size = len(inputs['index'])
  82.         # cat columns
  83.         if 'cat_X' in inputs.keys():
  84.             X_0_batch = torch.stack([emb(inputs['cat_X'][:,k]) for k, emb in enumerate(self.cat_embeddings)],1)
  85.             if self.use_raw:
  86.                 raw_inputs = torch.cat([emb(inputs['cat_X'][:,k]) for k, emb in enumerate(self.cat_raw_inputs)],1)  
  87.             if self.dnn:
  88.                 dnn_inputs = X_0_batch.reshape(X_0_batch.size(0), -1)
  89.         else:
  90.             raw_inputs = dnn_inputs = None
  91.         # ordinals
  92.         if 'ord_X' in inputs.keys():
  93.             ordinal_dense = inputs['ord_X']
  94.             if raw_inputs is not None:
  95.                 raw_inputs = torch.cat([raw_inputs, ordinal_dense], dim=1)
  96.             else:
  97.                 raw_inputs = ordinal_dense
  98.             if dnn_inputs is not None:
  99.                 dnn_inputs = torch.cat([dnn_inputs, ordinal_dense], dim=1)
  100.             else:
  101.                 dnn_inputs = ordinal_dense                
  102.         out = torch.zeros((batch_size,self.n_outputs), dtype=torch.float, device=self.device)
  103.         if self.use_raw:
  104.             out += self.raw_weights(raw_inputs)
  105.         # CIN        
  106.         if self.W_l:
  107.             X_batch = [X_0_batch]
  108.             for k, W in enumerate(self.W_l):
  109.                 X_batch += [torch.einsum('him,nimd->nhd', W, torch.einsum('nhd, nmd->nhmd', X_batch[k], X_0_batch))]
  110.             p_plus = torch.cat([torch.sum(X, dim=2) for X in X_batch[1:]],dim=1)
  111.             out += (self.cin_weights(p_plus))
  112.         # DNN
  113.         if self.dnn:
  114.             for k, layer in enumerate(self.dnn):
  115.                 dnn_inputs = layer(dnn_inputs)
  116.             out += self.dnn_weights(dnn_inputs)
  117.         if self.use_bias:
  118.             out += self.bias
  119.        
  120.         return out.cpu().squeeze()
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top