Advertisement
Guest User

Untitled

a guest
Jun 24th, 2020
58
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import datetime
  2. import pprint
  3.  
  4. import matplotlib.pyplot as plt
  5. import numpy as np
  6. import pandas as pd
  7. import pandas_datareader as pdr
  8. import pymc3 as pm
  9. from pymc3.distributions.timeseries import GaussianRandomWalk
  10. import seaborn as sns
  11.  
  12. def obtain_amazon_prices(start_date, end_date):
  13.     print("Downloading and plotting AWZN log returns...")
  14.     amzn = pdr.get_data_yahoo("AMZN", start_date, end_date)
  15.     amzn["returns"] = amzn["Adj Close"]/amzn["Adj Close"].shift(1)
  16.     amzn.dropna(inplace=True);
  17.     amzn["log_returns"] = np.log(amzn["returns"])
  18.     amzn["log_returns"].plot(linewidth=0.5)
  19.     plt.ylabel("AMZN daily percentage returns");
  20.     plt.show();
  21.     return amzn
  22.  
  23. def stoch_vol_model(log_returns, samples):
  24.     print("Configuring stochastic volatility with PyMC3...")
  25.     model = pm.Model()
  26.     with model:
  27.         sigma = pm.Exponential('sigma', 10.0, testval=0.1)
  28.         nu = pm.Exponential('nu', 0.1)
  29.         s = GaussianRandomWalk('s', sigma**-2, shape=len(log_returns))
  30.         logrets = pm.StudentT('logrets', nu, lam=pm.math.exp(-2.0*s), observed=log_returns)
  31.     print("Fitting the stochastic volatility model...")
  32.     with model:
  33.         trace = pm.sample(samples);
  34.     pm.traceplot(trace) #pm.traceplot(trace, model.vars[:-1])
  35.     plt.show()
  36.  
  37.     print("Plotting the log volatility...")
  38.     k = 10 #step size
  39.     opacity = 0.03
  40.     plt.plot(trace[s][::k].T, 'b', alpha=opacity)
  41.     plt.xlabel('Time')
  42.     plt.ylabel('Log volatility')
  43.     plt.show()
  44.  
  45.     print("Plotting the absolute returns overlaid with vol...")
  46.     plt.plot(np.abs(np.exp(log_returns))-1.0, linewidth=0.5)
  47.     plt.plot(np.exp(trace[s][::k].T), 'r', alpha=opacity)
  48.     plt.xlabel("Trading days")
  49.     plt.ylabel("Absolute returns/volatility")
  50.     plt.show()
  51.  
  52. if __name__ == '__main__':
  53.     print(pm.__version__)
  54.     start = datetime.datetime(2006, 1, 1)
  55.     end = datetime.datetime(2015, 12, 31)
  56.     amzn = obtain_amazon_prices(start, end)
  57.     log_returns = np.array(amzn["log_returns"])
  58.     samples = 2000;
  59.     stoch_vol_model(log_returns, samples)
Advertisement
RAW Paste Data Copied
Advertisement