Advertisement
tarstars

causal impact

Apr 7th, 2022
905
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.69 KB | None | 0 0
  1. import pandas as pd
  2. import tensorflow_probability as tfp
  3.  
  4. from causalimpact import CausalImpact
  5. from pandas import Timestamp
  6. from pandas.tseries.offsets import DateOffset
  7.  
  8.  
  9. def generate_data():
  10.     f = DateOffset(months=1)
  11.     d = {
  12.             Timestamp("2012-01-01 00:00:00", freq=f): 20660.0,
  13.             Timestamp("2012-02-01 00:00:00", freq=f): 21652.0,
  14.             Timestamp("2012-03-01 00:00:00", freq=f): 28327.0,
  15.             Timestamp("2012-04-01 00:00:00", freq=f): 20055.0,
  16.             Timestamp("2012-05-01 00:00:00", freq=f): 8987.0,
  17.             Timestamp("2012-06-01 00:00:00", freq=f): 8706.0,
  18.             Timestamp("2012-07-01 00:00:00", freq=f): 8888.0,
  19.             Timestamp("2012-08-01 00:00:00", freq=f): 8312.0,
  20.             Timestamp("2012-09-01 00:00:00", freq=f): 6249.0,
  21.             Timestamp("2012-10-01 00:00:00", freq=f): 7475.0,
  22.             Timestamp("2012-11-01 00:00:00", freq=f): 13007.0,
  23.             Timestamp("2012-12-01 00:00:00", freq=f): 18157.0,
  24.             Timestamp("2013-01-01 00:00:00", freq=f): 21455.0,
  25.             Timestamp("2013-02-01 00:00:00", freq=f): 23113.0,
  26.             Timestamp("2013-03-01 00:00:00", freq=f): 29942.0,
  27.             Timestamp("2013-04-01 00:00:00", freq=f): 20183.0,
  28.             Timestamp("2013-05-01 00:00:00", freq=f): 9433.0,
  29.             Timestamp("2013-06-01 00:00:00", freq=f): 8872.0,
  30.             Timestamp("2013-07-01 00:00:00", freq=f): 9352.0,
  31.             Timestamp("2013-08-01 00:00:00", freq=f): 7806.0,
  32.             Timestamp("2013-09-01 00:00:00", freq=f): 6699.0,
  33.             Timestamp("2013-10-01 00:00:00", freq=f): 7404.0,
  34.             Timestamp("2013-11-01 00:00:00", freq=f): 12524.0,
  35.             Timestamp("2013-12-01 00:00:00", freq=f): 17350.0,
  36.             Timestamp("2014-01-01 00:00:00", freq=f): 21417.0,
  37.             Timestamp("2014-02-01 00:00:00", freq=f): 23143.0,
  38.             Timestamp("2014-03-01 00:00:00", freq=f): 31041.0,
  39.             Timestamp("2014-04-01 00:00:00", freq=f): 22819.0,
  40.             Timestamp("2014-05-01 00:00:00", freq=f): 9254.0,
  41.             Timestamp("2014-06-01 00:00:00", freq=f): 9087.0,
  42.             Timestamp("2014-07-01 00:00:00", freq=f): 9135.0,
  43.             Timestamp("2014-08-01 00:00:00", freq=f): 8846.0,
  44.             Timestamp("2014-09-01 00:00:00", freq=f): 5620.0,
  45.             Timestamp("2014-10-01 00:00:00", freq=f): 9693.0,
  46.             Timestamp("2014-11-01 00:00:00", freq=f): 13065.0,
  47.             Timestamp("2014-12-01 00:00:00", freq=f): 19775.0,
  48.             Timestamp("2015-01-01 00:00:00", freq=f): 24978.0,
  49.             Timestamp("2015-02-01 00:00:00", freq=f): 26439.0,
  50.             Timestamp("2015-03-01 00:00:00", freq=f): 34992.0,
  51.             Timestamp("2015-04-01 00:00:00", freq=f): 21477.0,
  52.             Timestamp("2015-05-01 00:00:00", freq=f): 10316.0,
  53.             Timestamp("2015-06-01 00:00:00", freq=f): 9352.0,
  54.             Timestamp("2015-07-01 00:00:00", freq=f): 9601.0,
  55.             Timestamp("2015-08-01 00:00:00", freq=f): 9407.0,
  56.             Timestamp("2015-09-01 00:00:00", freq=f): 6994.0,
  57.             Timestamp("2015-10-01 00:00:00", freq=f): 10123.0,
  58.             Timestamp("2015-11-01 00:00:00", freq=f): 13429.0,
  59.             Timestamp("2015-12-01 00:00:00", freq=f): 20732.0,
  60.             Timestamp("2016-01-01 00:00:00", freq=f): 25286.0,
  61.             Timestamp("2016-02-01 00:00:00", freq=f): 27630.0,
  62.             Timestamp("2016-03-01 00:00:00", freq=f): 42085.0,
  63.             Timestamp("2016-04-01 00:00:00", freq=f): 32153.0,
  64.             Timestamp("2016-05-01 00:00:00", freq=f): 11538.0,
  65.             Timestamp("2016-06-01 00:00:00", freq=f): 9637.0,
  66.             Timestamp("2016-07-01 00:00:00", freq=f): 9553.0,
  67.             Timestamp("2016-08-01 00:00:00", freq=f): 9682.0,
  68.             Timestamp("2016-09-01 00:00:00", freq=f): 8135.0,
  69.             Timestamp("2016-10-01 00:00:00", freq=f): 13733.0,
  70.             Timestamp("2016-11-01 00:00:00", freq=f): 15500.0,
  71.         }
  72.     return pd.DataFrame.from_dict(d, orient="index")
  73.  
  74.  
  75. def main():
  76.     data = generate_data()
  77.     obs_series = data.iloc[:, 0].astype("float32")
  78.     regular_data = tfp.sts.regularize_series(series=obs_series)
  79.     local_linear = tfp.sts.LocalLinearTrend(observed_time_series=regular_data)
  80.     seasonal = tfp.sts.Seasonal(num_seasons=12, observed_time_series=regular_data)
  81.     model = tfp.sts.Sum([local_linear, seasonal], observed_time_series=regular_data)
  82.  
  83.     pre_period = ["2012-01-01", "2014-12-01"]
  84.     post_period = ["2015-01-01", "2016-11-01"]
  85.     ci = CausalImpact(regular_data, pre_period, post_period) # works
  86.     ci = CausalImpact(regular_data, pre_period, post_period, model=model) # doesn't work
  87.  
  88.     print(ci.summary())
  89.  
  90.  
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()
  93.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement