Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- from typing import List, Dict
- import itertools
- import numpy as np
- import pandas as pd
- from sklearn import metrics
- from sklearn.feature_extraction.text import CountVectorizer
- import torch
- import torch.nn as nn
- from torch.utils.data import Dataset, DataLoader, TensorDataset
- import torch.nn.functional as F
class xDeepFM(nn.Module):
    """xDeepFM model: a linear ("raw") term over per-field scalar embeddings
    and dense features, a Compressed Interaction Network (CIN) over the
    categorical embeddings, and a plain feed-forward DNN, all summed into a
    single output plus an optional global bias.

    Reference: Lian et al., "xDeepFM: Combining Explicit and Implicit
    Feature Interactions for Recommender Systems" (KDD 2018).

    Args:
        n_outputs: size of the output layer (e.g. 1 for binary/regression).
        cat_cols: names of the categorical columns. NOTE(review): kept for
            interface compatibility but not used internally — the column
            order of ``inputs['cat_X']`` must match ``cat_nuniques``.
        cat_nuniques: cardinality of each categorical column; one embedding
            table is created per column.
        ordinal_cols: names of the dense/ordinal columns; only their count
            is used (to size the linear and DNN input layers).
        use_raw: include the linear term (1-d categorical embeddings
            concatenated with the ordinal features).
        use_bias: add a learnable scalar bias to the output.
        D: embedding dimension shared by the CIN and the DNN.
        H_sizes: feature-map sizes of the CIN layers (H_1..H_T); only
            effective when ``cat_nuniques`` is also given.
        hidden_dims: hidden layer sizes of the DNN.
        dropouts: dropout rate per DNN hidden layer; must have the same
            length as ``hidden_dims``.
        device: device on which the output accumulator is allocated.

    Raises:
        ValueError: if ``hidden_dims`` is given and ``dropouts`` is missing
            or of a different length.
    """

    def __init__(self,
                 n_outputs: int = None,
                 cat_cols: List[str] = None,
                 cat_nuniques: List[int] = None,
                 ordinal_cols: List[str] = None,
                 use_raw: bool = True,
                 use_bias: bool = True,
                 D: int = 4,
                 H_sizes: List[int] = None,
                 hidden_dims: List[int] = None,
                 dropouts: List[float] = None,
                 device: str = 'cpu') -> None:
        super().__init__()
        self.n_outputs = n_outputs
        self.device = device
        self.use_raw = use_raw
        self.use_bias = use_bias
        if use_bias:
            # Learnable global bias, uniform [0, 1) init.
            self.bias = nn.Parameter(torch.rand(1))
        if hidden_dims and (dropouts is None or len(dropouts) != len(hidden_dims)):
            # Fail fast with a clear message (the original raised a bare
            # ValueError, or a TypeError when dropouts was None).
            raise ValueError(
                f"dropouts must match hidden_dims in length; got "
                f"dropouts={dropouts!r} for hidden_dims={hidden_dims!r}")

        # --- linear ("raw") part sizing ---
        if use_raw:
            raw_size = len(cat_nuniques) if cat_nuniques else 0
        if H_sizes and cat_nuniques:
            # CIN layer sizes, prepended with H_0 = number of categorical fields.
            self.H_sizes = [len(cat_nuniques)] + H_sizes
        ordinal_size = len(ordinal_cols) if ordinal_cols else 0
        if use_raw:
            raw_size += ordinal_size

        output_embed_size = 0
        self.W_l = []  # stays a plain empty list when the CIN is disabled
        if cat_nuniques:
            if use_raw:
                # One scalar embedding per category value -> the linear term.
                self.cat_raw_inputs = nn.ModuleList(
                    [nn.Embedding(cat_nunique, 1) for cat_nunique in cat_nuniques])
            # D-dimensional field embeddings shared by the CIN and the DNN.
            self.cat_embeddings = nn.ModuleList(
                [nn.Embedding(cat_nunique, D) for cat_nunique in cat_nuniques])
            output_embed_size = D * len(cat_nuniques)
            ## CIN
            if H_sizes:
                # W_l[k] compresses the (H_k-1 x m) pairwise interaction maps
                # into H_k feature maps (m = number of categorical fields).
                for H_k_1, H_k in zip(self.H_sizes[:-1], self.H_sizes[1:]):
                    self.W_l += [nn.Parameter(
                        torch.randn((H_k, H_k_1, len(cat_nuniques))))]
                self.W_l = nn.ParameterList(self.W_l)
                self.cin_weights = nn.Linear(sum(H_sizes), n_outputs, bias=False)

        ## Deep
        self.dnn = []  # stays a plain empty list when the DNN is disabled
        if hidden_dims:
            # DNN input: flattened field embeddings + dense ordinal features.
            output_size = output_embed_size + ordinal_size
            dims = [output_size] + hidden_dims
            for k in range(1, len(hidden_dims) + 1):
                self.dnn += [nn.Sequential(
                    nn.Linear(dims[k - 1], dims[k]),
                    nn.ReLU(),
                    nn.Dropout(dropouts[k - 1]),
                )]
            self.dnn = nn.ModuleList(self.dnn)
            self.dnn_weights = nn.Linear(hidden_dims[-1], n_outputs, bias=False)
        if use_raw:
            self.raw_weights = nn.Linear(raw_size, n_outputs, bias=False)

    def forward(self, inputs: Dict) -> torch.Tensor:
        """Compute the model output for one batch.

        Args:
            inputs: dict with key ``'index'`` (any sized object whose length
                is the batch size), optional ``'cat_X'`` (LongTensor of shape
                (batch, n_cat_fields), column order matching ``cat_nuniques``)
                and optional ``'ord_X'`` (float tensor of shape
                (batch, n_ordinal_cols)).

        Returns:
            Tensor of shape (batch, n_outputs), squeezed and moved to CPU.
        """
        batch_size = len(inputs['index'])
        # Initialize up front so the ordinal branch below is safe even when
        # 'cat_X' is absent, or present with use_raw/DNN disabled (the
        # original code hit a NameError in those configurations).
        raw_inputs = dnn_inputs = None
        # --- categorical columns ---
        if 'cat_X' in inputs:
            # (batch, n_fields, D): one D-dim embedding per categorical field.
            X_0_batch = torch.stack(
                [emb(inputs['cat_X'][:, k])
                 for k, emb in enumerate(self.cat_embeddings)], 1)
            if self.use_raw:
                # (batch, n_fields): scalar embedding per field (linear term).
                raw_inputs = torch.cat(
                    [emb(inputs['cat_X'][:, k])
                     for k, emb in enumerate(self.cat_raw_inputs)], 1)
            if self.dnn:
                # Flatten the field embeddings for the DNN input.
                dnn_inputs = X_0_batch.reshape(X_0_batch.size(0), -1)
        # --- ordinal columns ---
        if 'ord_X' in inputs:
            ordinal_dense = inputs['ord_X']
            raw_inputs = (ordinal_dense if raw_inputs is None
                          else torch.cat([raw_inputs, ordinal_dense], dim=1))
            dnn_inputs = (ordinal_dense if dnn_inputs is None
                          else torch.cat([dnn_inputs, ordinal_dense], dim=1))

        out = torch.zeros((batch_size, self.n_outputs),
                          dtype=torch.float, device=self.device)
        if self.use_raw:
            out += self.raw_weights(raw_inputs)
        # --- CIN ---
        if self.W_l:
            # X_batch[k]: (batch, H_k, D). Each layer takes the outer product
            # of the previous layer's maps with X_0 along the embedding dim,
            # then compresses it with W (Eqs. 4-6 of the xDeepFM paper).
            X_batch = [X_0_batch]
            for k, W in enumerate(self.W_l):
                pairwise = torch.einsum('nhd,nmd->nhmd', X_batch[k], X_0_batch)
                X_batch += [torch.einsum('him,nimd->nhd', W, pairwise)]
            # Sum-pool each layer's feature maps over D and concatenate (p+).
            p_plus = torch.cat([torch.sum(X, dim=2) for X in X_batch[1:]], dim=1)
            out += self.cin_weights(p_plus)
        # --- DNN ---
        if self.dnn:
            for layer in self.dnn:
                dnn_inputs = layer(dnn_inputs)
            out += self.dnn_weights(dnn_inputs)
        if self.use_bias:
            out += self.bias
        return out.cpu().squeeze()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement