Advertisement
Guest User

Untitled

a guest
Jul 16th, 2019
99
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.89 KB | None | 0 0
  1. from typing import List, Dict
  2. import itertools
  3. import numpy as np
  4. import pandas as pd
  5. from sklearn import metrics
  6. from sklearn.feature_extraction.text import CountVectorizer
  7. import torch
  8. import torch.nn as nn
  9. from torch.utils.data import Dataset, DataLoader, TensorDataset
  10. import torch.nn.functional as F
  11.  
  12. class xDeepFM(nn.Module):
  13.  
  14. def __init__(self,
  15. n_outputs:int=None,
  16. cat_cols: List[str]=None,
  17. cat_nuniques: List[int]=None,
  18. ordinal_cols:List[str]=None,
  19. use_raw:bool=True,
  20. use_bias:bool=True,
  21. D:int=4,
  22. H_sizes:List[int]=None,
  23. hidden_dims:List[int]=None,
  24. dropouts:List[float]=None,
  25. device:str='cpu') -> None:
  26. super().__init__()
  27. self.n_outputs = n_outputs
  28. self.device = device
  29. self.use_raw = use_raw
  30. self.use_bias = use_bias
  31. if use_bias:
  32. self.bias = nn.Parameter(torch.rand(1))
  33. if hidden_dims and (len(dropouts) != len(hidden_dims)):
  34. raise ValueError
  35. if use_raw:
  36. if cat_nuniques:
  37. raw_size = len(cat_nuniques)
  38. else:
  39. raw_size = 0
  40. if H_sizes and cat_nuniques:
  41. self.H_sizes = [len(cat_nuniques)] + H_sizes
  42. ordinal_size = 0
  43. if ordinal_cols:
  44. ordinal_size = len(ordinal_cols)
  45. if use_raw:
  46. raw_size += ordinal_size
  47. output_embed_size = 0
  48. self.W_l = []
  49. if cat_nuniques:
  50. if use_raw:
  51. self.cat_raw_inputs = nn.ModuleList([nn.Embedding(cat_nunique, 1) for cat_nunique in cat_nuniques])
  52. self.cat_embeddings = nn.ModuleList([nn.Embedding(cat_nunique, D) for cat_nunique in cat_nuniques])
  53. output_embed_size = D * len(cat_nuniques)
  54. ## CIN
  55. if H_sizes:
  56. for H_k_1, H_k in zip(self.H_sizes[:-1], self.H_sizes[1:]):
  57. self.W_l += [nn.Parameter(torch.randn((H_k,H_k_1,len(cat_nuniques))))]
  58. self.W_l = nn.ParameterList(self.W_l)
  59. self.cin_weights = nn.Linear(sum(H_sizes), n_outputs, bias=False)
  60. ## Deep
  61. self.dnn = []
  62. if hidden_dims:
  63. output_size = output_embed_size + ordinal_size
  64. dims = [output_size] + hidden_dims
  65. for k in range(1, len(hidden_dims) + 1):
  66. input_dim = dims[k-1]
  67. output_dim = dims[k]
  68. dropout = dropouts[k-1]
  69. self.dnn += [nn.Sequential(nn.Linear(input_dim,output_dim),
  70. nn.ReLU(),
  71. #nn.BatchNorm1d(output_dim),
  72. nn.Dropout(dropout)
  73. )]
  74. self.dnn = nn.ModuleList(self.dnn)
  75. self.dnn_weights = nn.Linear(hidden_dims[-1], n_outputs, bias=False)
  76.  
  77. if use_raw:
  78. self.raw_weights = nn.Linear(raw_size, n_outputs, bias=False)
  79.  
  80. def forward(self, inputs: Dict) -> torch.Tensor:
  81. batch_size = len(inputs['index'])
  82. # cat columns
  83. if 'cat_X' in inputs.keys():
  84. X_0_batch = torch.stack([emb(inputs['cat_X'][:,k]) for k, emb in enumerate(self.cat_embeddings)],1)
  85. if self.use_raw:
  86. raw_inputs = torch.cat([emb(inputs['cat_X'][:,k]) for k, emb in enumerate(self.cat_raw_inputs)],1)
  87. if self.dnn:
  88. dnn_inputs = X_0_batch.reshape(X_0_batch.size(0), -1)
  89. else:
  90. raw_inputs = dnn_inputs = None
  91. # ordinals
  92. if 'ord_X' in inputs.keys():
  93. ordinal_dense = inputs['ord_X']
  94. if raw_inputs is not None:
  95. raw_inputs = torch.cat([raw_inputs, ordinal_dense], dim=1)
  96. else:
  97. raw_inputs = ordinal_dense
  98. if dnn_inputs is not None:
  99. dnn_inputs = torch.cat([dnn_inputs, ordinal_dense], dim=1)
  100. else:
  101. dnn_inputs = ordinal_dense
  102. out = torch.zeros((batch_size,self.n_outputs), dtype=torch.float, device=self.device)
  103. if self.use_raw:
  104. out += self.raw_weights(raw_inputs)
  105. # CIN
  106. if self.W_l:
  107. X_batch = [X_0_batch]
  108. for k, W in enumerate(self.W_l):
  109. X_batch += [torch.einsum('him,nimd->nhd', W, torch.einsum('nhd, nmd->nhmd', X_batch[k], X_0_batch))]
  110. p_plus = torch.cat([torch.sum(X, dim=2) for X in X_batch[1:]],dim=1)
  111. out += (self.cin_weights(p_plus))
  112. # DNN
  113. if self.dnn:
  114. for k, layer in enumerate(self.dnn):
  115. dnn_inputs = layer(dnn_inputs)
  116. out += self.dnn_weights(dnn_inputs)
  117. if self.use_bias:
  118. out += self.bias
  119.  
  120. return out.cpu().squeeze()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement