Advertisement
coloriot

Mamba_main_function

Apr 12th, 2025 (edited)
19
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 7.02 KB | None | 0 0
  1. import torch
  2. import torch.nn as nn
  3. import torch.optim as optim
  4. import torch.nn.functional as F
  5. import numpy as np
  6. import pandas as pd
  7. from sklearn.preprocessing import StandardScaler
  8.  
  9. device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cuda" if torch.cuda.is_available() else "cpu")
  10.  
  11. class TrueMambaBlock(nn.Module):
  12.     def __init__(self, d_model, rank=1):
  13.         super().__init__()
  14.         self.d_model = d_model
  15.  
  16.         # Structured A: low-rank + diagonal (HiPPO-inspired initialization)
  17.         self.D = nn.Parameter(-torch.abs(torch.randn(d_model)))
  18.         self.low_rank_U = nn.Parameter(torch.randn(d_model, rank))
  19.         self.low_rank_V = nn.Parameter(torch.randn(rank, d_model))
  20.  
  21.         # Input and output projections (B, C)
  22.         self.B = nn.Parameter(torch.randn(d_model))
  23.         self.C = nn.Parameter(torch.randn(d_model))
  24.  
  25.         # Gating mechanism
  26.         self.input_proj = nn.Linear(d_model, d_model)
  27.         self.gate_proj = nn.Linear(d_model, d_model)
  28.         self.activation = nn.Sigmoid()
  29.  
  30.     def compute_A(self):
  31.         return torch.diag(self.D) + self.low_rank_U @ self.low_rank_V
  32.  
  33.     def compute_kernel(self, A, L):
  34.         dt = 1.0
  35.         kernel = []
  36.         Ak = torch.eye(self.d_model, device=A.device)
  37.         for _ in range(L):
  38.             kernel.append((Ak @ self.B).unsqueeze(0))
  39.             Ak = Ak @ (torch.eye(self.d_model, device=A.device) + dt * A)
  40.         kernel = torch.cat(kernel, dim=0)
  41.         return kernel
  42.  
  43.     def forward(self, x):
  44.         batch_size, seq_len, d_model = x.shape
  45.  
  46.         u = self.input_proj(x)
  47.         gate = self.activation(self.gate_proj(x))
  48.  
  49.         A = self.compute_A()
  50.         kernel = self.compute_kernel(A, seq_len)
  51.  
  52.         u = u.permute(0, 2, 1)
  53.         k = kernel.permute(1, 0).unsqueeze(1)
  54.         y = F.conv1d(u, k, padding=seq_len - 1, groups=self.d_model)
  55.         y = y[:, :, :seq_len]
  56.         y = y.permute(0, 2, 1)
  57.  
  58.         y = gate * y
  59.         y = self.C * y
  60.         return y
  61.  
  62. class MambaTabularModel(nn.Module):
  63.     def __init__(self, input_dim, hidden_dim=64, output_dim=1):
  64.         super().__init__()
  65.         self.input_layer = nn.Linear(input_dim, hidden_dim)
  66.         self.mamba_block = TrueMambaBlock(d_model=hidden_dim)
  67.         self.output_layer = nn.Linear(hidden_dim, output_dim)
  68.  
  69.     def forward(self, x):
  70.         x = F.relu(self.input_layer(x))
  71.         x = x.unsqueeze(1)
  72.         x = self.mamba_block(x)
  73.         x = x.squeeze(1)
  74.         return self.output_layer(x)
  75.  
  76. def analyze_portfolio_mamba_backtest(df, window_years=5, test_years=1):
  77.     independent_vars = ["MARKET_RETURN_ADJ", "SMB", "HML", "MOM",
  78.                         'MARKET_RETURN_ADJ_lag1', 'SMB_lag1', 'HML_lag1', 'MOM_lag1',
  79.                         'MARKET_RETURN_ADJ_lag2', 'SMB_lag2', 'HML_lag2', 'MOM_lag2',
  80.                         'MARKET_RETURN_ADJ_rollmean_3', 'MARKET_RETURN_ADJ_rollstd_3',
  81.                         'SMB_rollmean_3', 'SMB_rollstd_3', 'HML_rollmean_3', 'HML_rollstd_3',
  82.                         'MOM_rollmean_3', 'MOM_rollstd_3', 'MARKET_RETURN_ADJ_rollmean_6',
  83.                         'MARKET_RETURN_ADJ_rollstd_6', 'SMB_rollmean_6', 'SMB_rollstd_6',
  84.                         'HML_rollmean_6', 'HML_rollstd_6', 'MOM_rollmean_6', 'MOM_rollstd_6',
  85.                         'MARKET_RETURN_ADJ_x_SMB', 'MARKET_RETURN_ADJ_x_HML',
  86.                         'MARKET_RETURN_ADJ_x_MOM', 'SMB_x_HML', 'SMB_x_MOM', 'HML_x_MOM',
  87.                         'MARKET_RETURN_ADJ_squared', 'MARKET_RETURN_ADJ_cubed',
  88.                         'SMB_squared', 'SMB_cubed', 'HML_squared', 'HML_cubed',
  89.                         'MOM_squared', 'MOM_cubed', 'Month_sin', 'Month_cos', 'Quarter_sin', 'Quarter_cos']
  90.  
  91.     exclude = set(independent_vars + ["TRADEDATE", 'Year', 'Month', 'Quarter'])
  92.  
  93.     df = df.copy()
  94.     df['TRADEDATE'] = pd.to_datetime(df['TRADEDATE'])
  95.     df = df.sort_values('TRADEDATE')
  96.  
  97.     tradedates = df['TRADEDATE'].drop_duplicates().sort_values().tolist()
  98.     window_months = window_years * 12
  99.     test_months = test_years * 12
  100.  
  101.     results = []
  102.  
  103.     for start_idx in range(0, len(tradedates) - window_months - test_months + 1):
  104.         train_start = tradedates[start_idx]
  105.         train_end = tradedates[start_idx + window_months - 1]
  106.         test_start = tradedates[start_idx + window_months]
  107.         test_end = tradedates[start_idx + window_months + test_months - 1]
  108.  
  109.         train_df = df[(df['TRADEDATE'] >= train_start) & (df['TRADEDATE'] <= train_end)]
  110.         test_df = df[(df['TRADEDATE'] >= test_start) & (df['TRADEDATE'] <= test_end)]
  111.  
  112.         target_cols = [col for col in df.columns if col not in exclude]
  113.         expected_returns = {}
  114.         trained_models = {}
  115.  
  116.         for target in target_cols:
  117.             df_temp = pd.concat([train_df[independent_vars], train_df[target]], axis=1).dropna()
  118.             if df_temp.empty:
  119.                 continue
  120.  
  121.             X = df_temp[independent_vars]
  122.             y = df_temp[target]
  123.  
  124.             scaler = StandardScaler()
  125.             X_scaled = scaler.fit_transform(X)
  126.             X_tensor = torch.tensor(X_scaled, dtype=torch.float32, device=device)
  127.             y_tensor = torch.tensor(y.values.reshape(-1, 1), dtype=torch.float32, device=device)
  128.  
  129.             model = MambaTabularModel(input_dim=X_tensor.shape[1]).to(device)
  130.             if hasattr(torch, "compile") and device.type != "mps":
  131.                 model = torch.compile(model)
  132.  
  133.             optimizer = optim.Adam(model.parameters(), lr=0.05)
  134.             loss_fn = nn.MSELoss()
  135.  
  136.             for _ in range(5):
  137.                 model.train()
  138.                 optimizer.zero_grad()
  139.                 loss = loss_fn(model(X_tensor), y_tensor)
  140.                 loss.backward()
  141.                 optimizer.step()
  142.  
  143.             model.eval()
  144.             with torch.no_grad():
  145.                 predictions = model(X_tensor).cpu().numpy()
  146.             expected_returns[target] = predictions.mean()
  147.             trained_models[target] = model
  148.  
  149.         if not expected_returns:
  150.             continue
  151.  
  152.         all_stocks = list(expected_returns.keys())
  153.         mu = np.array([expected_returns[s] for s in all_stocks])
  154.         train_returns_df = train_df[all_stocks].dropna()
  155.         cov_matrix = train_returns_df.cov().values
  156.  
  157.         optimal_weights = optimize_portfolio(mu, cov_matrix)
  158.  
  159.         test_returns_df = test_df[all_stocks].dropna()
  160.         if test_returns_df.empty:
  161.             continue
  162.  
  163.         portfolio_returns = backtest_portfolio(test_returns_df, optimal_weights, all_stocks)
  164.         portfolio_beta = 1.0
  165.         risk_metrics = compute_risk_metrics(portfolio_returns, portfolio_beta)
  166.  
  167.         result_row = {
  168.             'Train Start': train_start,
  169.             'Train End': train_end,
  170.             'Test Start': test_start,
  171.             'Test End': test_end,
  172.         }
  173.         result_row.update(risk_metrics)
  174.  
  175.         for stock, weight in zip(all_stocks, optimal_weights):
  176.             result_row[f'Weight_{stock}'] = weight
  177.  
  178.         results.append(result_row)
  179.  
  180.     return pd.DataFrame(results)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement