{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running on PyMC3 v3.5\n",
"\n",
" Patient_ID Family_ID Household_ID Age Sex Classroom Loc-x Loc-y \\\n",
"0 184 51 6 13.0 0 2 73 80 \n",
"1 173 9 35 6.0 1 0 85 43 \n",
"2 177 17 6 8.0 2 1 73 80 \n",
"3 178 17 6 4.0 1 0 73 80 \n",
"4 174 9 35 3.0 2 0 85 43 \n",
"5 45 48 51 7.0 1 1 30 22 \n",
"6 183 4 6 4.0 1 0 73 80 \n",
"7 175 9 35 2.0 1 0 85 43 \n",
"8 181 4 6 10.0 2 2 73 80 \n",
"9 42 33 45 4.0 2 0 68 7 \n",
"\n",
" Infector Prodromal_Onset Eruption_Onset Death_Day Died \\\n",
"0 0 1 8 -1 0 \n",
"1 0 3 5 -1 0 \n",
"2 184 9 13 -1 0 \n",
"3 184 9 13 -1 0 \n",
"4 0 10 10 -1 0 \n",
"5 184 13 15 -1 0 \n",
"6 184 13 17 -1 0 \n",
"7 173 14 17 -1 0 \n",
"8 184 15 19 -1 0 \n",
"9 178 17 21 -1 0 \n",
"\n",
" Infectious_Onset Infectious_End Infectious_Duration \n",
"0 0 11 11 \n",
"1 2 8 6 \n",
"2 8 16 8 \n",
"3 8 16 8 \n",
"4 9 13 4 \n",
"5 12 18 6 \n",
"6 12 20 8 \n",
"7 13 20 7 \n",
"8 14 22 8 \n",
"9 16 24 8 \n",
"hi\n"
]
}
],
"source": [
"import csv\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import pymc3 as pm\n",
"from jupyterthemes import jtplot\n",
"from collections import namedtuple\n",
"import os\n",
"import time\n",
"from datetime import datetime\n",
"import pickle\n",
"%matplotlib inline\n",
"##plt.style.use(\"ggplot\")\n",
"print(f\"Running on PyMC3 v{pm.__version__}\")\n",
"print(\"\")\n",
"jtplot.style()\n",
"\n",
"pd.options.mode.chained_assignment = None  # default=\"warn\"\n",
"\n",
"measles_data_all = pd.read_csv(\"hagelloch_cleaned.csv\",\n",
"                               sep=\" \")\n",
"\n",
"## remove useless columns\n",
"patients_data = measles_data_all[[\"PN\", \"FN\", \"HN\", \"AGE\", \"SEX\", \"CL\",\n",
"                                  \"HNX\", \"HNY\", \"IFTO\", \"PRO_DAYS\",\n",
"                                  \"ERU_DAYS\", \"DEAD_DAYS\"]].copy()\n",
"\n",
"patients_data.columns = [\"Patient_ID\",\n",
"                         \"Family_ID\",\n",
"                         \"Household_ID\",\n",
"                         \"Age\",\n",
"                         \"Sex\",\n",
"                         \"Classroom\",\n",
"                         \"Loc-x\",\n",
"                         \"Loc-y\",\n",
"                         \"Infector\",\n",
"                         \"Prodromal_Onset\",\n",
"                         \"Eruption_Onset\",\n",
"                         \"Death_Day\"]\n",
"\n",
"\n",
"death_mask = ~(patients_data[\"Death_Day\"] == -1)\n",
"\n",
"d = 3  # time still infectious after eruption\n",
"\n",
"patients_data[\"Died\"] = death_mask\n",
"patients_data[\"Died\"] = patients_data[\"Died\"].astype(np.int)\n",
"patients_data[\"Infectious_Onset\"] = patients_data[\"Prodromal_Onset\"] - 1\n",
"patients_data[\"Infectious_End\"] = np.where(death_mask,\n",
"                                           patients_data[\"Death_Day\"],\n",
"                                           (patients_data[\"Eruption_Onset\"] + d))\n",
"patients_data[\"Infectious_Duration\"] = (patients_data[\"Infectious_End\"] -\n",
"                                        patients_data[\"Infectious_Onset\"])\n",
"\n",
"print(patients_data.head(10))\n",
"print(\"hi\")"
]
},
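{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Added sketch (not part of the original run):* a quick sanity check of the prepared frame, plotting how many patients are infectious on each day using only the `Infectious_Onset`/`Infectious_End` columns built above. The names `days`, `onsets`, `ends` and `infectious_per_day` are illustrative; the day convention (`onset <= day < end`) mirrors `get_global_infectious` defined below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"## sanity check (added): epidemic curve of infectious patients per day\n",
"days = np.arange(patients_data[\"Infectious_Onset\"].min(),\n",
"                 patients_data[\"Infectious_End\"].max() + 1)\n",
"onsets = patients_data[\"Infectious_Onset\"].values[:, None]\n",
"ends = patients_data[\"Infectious_End\"].values[:, None]\n",
"## (n_patients, n_days) boolean array: patient i is infectious on day j\n",
"infectious_per_day = ((onsets <= days[None, :]) & (days[None, :] < ends)).sum(axis=0)\n",
"\n",
"plt.plot(days, infectious_per_day)\n",
"plt.xlabel(\"Day\")\n",
"plt.ylabel(\"Number infectious\")\n",
"plt.title(\"Infectious patients per day\")\n",
"plt.show()"
]
},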
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def lognormal_parameters(desired_mean, desired_variance):\n",
"    mu = np.log(desired_mean / np.sqrt(1 + desired_variance / (desired_mean**2)))\n",
"    sigma = np.sqrt(np.log(1 + desired_variance/(desired_mean**2)))\n",
"    return (mu, sigma)\n",
"\n",
"class Observed_Outbreak:\n",
"    def __init__(self, data_frame, duration=None):\n",
"        self.data_frame = data_frame.copy()\n",
"        self.n_patients = self.data_frame.shape[0]\n",
"        self.slabs = {}\n",
"        if not (duration is None):\n",
"            self.duration = duration\n",
"            self.duration_bc = self.duration[None, :]\n",
"\n",
"    def impute_latencies(self, latency_params):\n",
"        latencies = np.random.lognormal(*latency_params,\n",
"                                        self.n_patients).astype(np.int)\n",
"        self.data_frame[\"Latency_Onset\"] = self.data_frame[\"Infectious_Onset\"] - latencies\n",
"        earliest_time = self.data_frame[\"Latency_Onset\"].min()\n",
"        latest_time = self.data_frame[\"Infectious_End\"].max()\n",
"        self.duration = np.arange(earliest_time, latest_time+1)\n",
"        self.duration_bc = self.duration[None, :]\n",
"\n",
"    def get_global_infectious(self, valid_days=None):\n",
"        infectious_onsets_bc = self.data_frame[\"Infectious_Onset\"][:, None]\n",
"        infectious_ends_bc = self.data_frame[\"Infectious_End\"][:, None]\n",
"        self.global_infectious_array = ((infectious_onsets_bc <= self.duration_bc) &\n",
"                                        (self.duration_bc < infectious_ends_bc)).T\n",
"        self.global_infectious = self.global_infectious_array.sum(axis=1)\n",
"        if valid_days is None:\n",
"            self.valid_days = self.global_infectious.copy() != 0\n",
"        else:\n",
"            self.valid_days = valid_days\n",
"        self.n_days = sum(self.valid_days)\n",
"        self.global_infectious_array = self.global_infectious_array[self.valid_days, :]\n",
"        self.global_infectious = self.global_infectious[self.valid_days]\n",
"        self.slabs[\"Community\"] = np.broadcast_to(self.global_infectious[:, None],\n",
"                                                  (self.n_days, self.n_patients))\n",
"\n",
"    def get_susceptible_states(self):\n",
"        latency_onsets_bc = self.data_frame[\"Latency_Onset\"][:, None]\n",
"        self.states_before = ((self.duration_bc < latency_onsets_bc).T).astype(np.int)\n",
"        self.states_after = np.roll(self.states_before, -1, axis=0)\n",
"        self.states_after[-1, :] = 0\n",
"        self.states_before = self.states_before[self.valid_days, :]\n",
"        self.states_after = self.states_after[self.valid_days, :]\n",
"\n",
"    def get_patient_distances(self):\n",
"        self.distance_array = np.zeros((self.n_patients, self.n_patients))\n",
"        for first in range(self.n_patients):\n",
"            first_patient = self.data_frame.iloc[first]\n",
"            first_x = first_patient[\"Loc-x\"]\n",
"            first_y = first_patient[\"Loc-y\"]\n",
"\n",
"            dx2 = (first_x - self.data_frame[\"Loc-x\"])**2\n",
"            dy2 = (first_y - self.data_frame[\"Loc-y\"])**2\n",
"            distance_1d = np.sqrt(dx2 + dy2)\n",
"            self.distance_array[first, :] = distance_1d\n",
"\n",
"    def get_distance_infectious(self):\n",
"        distance_slab = np.zeros((self.n_days, self.n_patients))\n",
"        for today in range(self.n_days):\n",
"            global_infectious_today = self.global_infectious_array[today, :]\n",
"            for patient in range(self.n_patients):\n",
"                distances_from_patient = self.distance_array[patient, :]\n",
"                distances_exped = np.exp(-distances_from_patient)\n",
"                product = global_infectious_today * distances_exped\n",
"                distance_slab[today, patient] = product.sum()\n",
"        self.slabs[\"Distance\"] = distance_slab\n",
"\n",
"    def get_group_infectious(self, group_name, saturated=False):\n",
"        output_slab = np.zeros((self.n_days, self.n_patients))\n",
"        for group_id, group in self.data_frame.groupby(group_name):\n",
"            ## create an Observed Outbreak class for this particular group,\n",
"            ## but using the duration from the entire outbreak\n",
"            group_suboutbreak = Observed_Outbreak(group, self.duration)\n",
"            group_suboutbreak.get_global_infectious(self.valid_days)\n",
"            ## get number of infectious in this particular group for each day\n",
"            group_infectious = group_suboutbreak.global_infectious\n",
"            group_infectious_bc = np.broadcast_to(group_infectious[:, None],\n",
"                                                  (self.n_days, self.n_patients))\n",
"            ## nonzero entries when group_id matches the column's (i.e. patient's) group\n",
"            group_id_mask = (self.data_frame[group_name] == group_id).values\n",
"            subslab = group_infectious_bc * group_id_mask\n",
"            output_slab += subslab.astype(np.int)\n",
"        if saturated:\n",
"            output_slab = output_slab.astype(np.bool).astype(np.int)\n",
"        self.slabs[group_name] = output_slab\n",
"\n",
"    def get_classroom_infectious(self, classroom, infectious_override=False):\n",
"        data_frame_new = self.data_frame.copy()\n",
"        if infectious_override:\n",
"            data_frame_new[\"Infectious_End\"] = (data_frame_new[\"Infectious_Onset\"] +\n",
"                                                infectious_override)\n",
"\n",
"        classroom_frame = data_frame_new[data_frame_new[\"Classroom\"] == classroom]\n",
"        classroom_suboutbreak = Observed_Outbreak(classroom_frame, self.duration)\n",
"        classroom_suboutbreak.get_global_infectious(self.valid_days)\n",
"        classroom_infectious = classroom_suboutbreak.global_infectious\n",
"        classroom_infectious_bc = np.broadcast_to(classroom_infectious[:, None],\n",
"                                                  (self.n_days, self.n_patients))\n",
"        classroom_id_mask = (self.data_frame[\"Classroom\"] == classroom).values\n",
"        classroom_slab = classroom_infectious_bc * classroom_id_mask\n",
"        self.slabs[f\"Classroom_{classroom}\"] = classroom_slab\n",
"\n"
]
},
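{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Added sketch (not executed in the original):* two quick checks of the helpers above. First, `lognormal_parameters` should round-trip the requested mean and variance when its output is fed to `np.random.lognormal`. Second, the distance-based exposure slab (`get_patient_distances` / `get_distance_infectious`) is defined but never used in the sampling loop below; this shows how it would be built. The names `mu_chk`, `sigma_chk`, `draws_chk` and `demo_outbreak` are illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"## check (added): samples should have mean ~12 and variance ~5.433\n",
"mu_chk, sigma_chk = lognormal_parameters(12, 5.433)\n",
"draws_chk = np.random.lognormal(mu_chk, sigma_chk, 100000)\n",
"print(draws_chk.mean(), draws_chk.var())\n",
"\n",
"## sketch (added): building the otherwise unused distance-weighted exposure slab\n",
"demo_outbreak = Observed_Outbreak(patients_data)\n",
"demo_outbreak.impute_latencies(lognormal_parameters(12, 5.433))\n",
"demo_outbreak.get_global_infectious()\n",
"demo_outbreak.get_patient_distances()\n",
"demo_outbreak.get_distance_infectious()\n",
"print(demo_outbreak.slabs[\"Distance\"].shape)  ## (n_days, n_patients)"
]
},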
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Auto-assigning NUTS sampler...\n",
"Initializing NUTS using jitter+adapt_diag...\n",
"Multiprocess sampling (4 chains in 4 jobs)\n",
"NUTS: [rv_q_sex, rv_q_family, rv_q_household, rv_q_class_2, rv_q_class_1, rv_q_community]\n",
"Sampling 4 chains: 100%|██████████| 8000/8000 [03:28<00:00, 15.39draws/s]\n",
"D:\\Anaconda3\\lib\\site-packages\\mkl_fft\\_numpy_fft.py:1044: FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this will be interpreted as an array index, `arr[np.array(seq)]`, which will result either in an error or a different result.\n",
" output = mkl_fft.rfftn_numpy(a, s, axes)\n",
"There was 1 divergence after tuning. Increase `target_accept` or reparameterize.\n",
"There were 2 divergences after tuning. Increase `target_accept` or reparameterize.\n",
"The acceptance probability does not match the target. It is 0.6616778933467038, but should be close to 0.8. Try to increase the number of tuning steps.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" mean sd mc_error hpd_2.5 hpd_97.5 n_eff \\\n",
"rv_q_community 0.994322 0.000972 0.000038 0.992559 0.996216 505.819056 \n",
"rv_q_class_1 0.882522 0.058837 0.001676 0.768720 0.984137 1196.684376 \n",
"rv_q_class_2 0.946839 0.018331 0.000500 0.914268 0.982833 1425.523988 \n",
"rv_q_household 0.892061 0.052584 0.001962 0.790368 0.984600 697.508660 \n",
"rv_q_family 0.772249 0.105571 0.004480 0.601709 0.998689 516.690902 \n",
"rv_q_sex 0.996833 0.002024 0.000079 0.993130 0.999992 544.947321 \n",
"\n",
" Rhat \n",
"rv_q_community 1.004532 \n",
"rv_q_class_1 1.000721 \n",
"rv_q_class_2 1.002293 \n",
"rv_q_household 1.002365 \n",
"rv_q_family 1.003017 \n",
"rv_q_sex 1.006037 \n",
"Model successfully dumped to file: models/model_sah_1_sat_0_2018-08-23T15-15-18.pkl.\n",
"1\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Auto-assigning NUTS sampler...\n",
"Initializing NUTS using jitter+adapt_diag...\n",
"Multiprocess sampling (4 chains in 4 jobs)\n",
"NUTS: [rv_q_sex, rv_q_family, rv_q_household, rv_q_class_2, rv_q_class_1, rv_q_community]\n",
"Sampling 4 chains: 90%|█████████ | 7229/8000 [02:53<00:16, 46.84draws/s]"
]
}
],
"source": [
"latency_mean = 12\n",
"latency_var = 5.433\n",
"latency_parameters = lognormal_parameters(latency_mean, latency_var)\n",
"stay_at_home = 1\n",
"saturated = False\n",
"\n",
"n_samples = 500\n",
"n_tune = 1500\n",
"n_chains = 4\n",
"\n",
"model_name = \"model_sah_\"+str(int(stay_at_home))+\"_sat_\"+str(int(saturated))+\"_\"\n",
"marginalization_mc = 66\n",
"for i in range(marginalization_mc):\n",
"    print(i)\n",
"    outbreak = Observed_Outbreak(patients_data)\n",
"    outbreak.impute_latencies(latency_parameters)\n",
"    outbreak.get_global_infectious()\n",
"    outbreak.get_susceptible_states()\n",
"\n",
"    outbreak.get_group_infectious(\"Household_ID\")\n",
"    outbreak.get_group_infectious(\"Family_ID\")\n",
"    outbreak.get_group_infectious(\"Sex\")\n",
"    outbreak.get_classroom_infectious(1, infectious_override=stay_at_home)\n",
"    outbreak.get_classroom_infectious(2, infectious_override=stay_at_home)\n",
"\n",
"    community_slab = outbreak.slabs[\"Community\"]\n",
"    c1_slab = outbreak.slabs[\"Classroom_1\"]\n",
"    c2_slab = outbreak.slabs[\"Classroom_2\"]\n",
"    household_slab = outbreak.slabs[\"Household_ID\"]\n",
"    family_slab = outbreak.slabs[\"Family_ID\"]\n",
"    sex_slab = outbreak.slabs[\"Sex\"]\n",
"    s_before = outbreak.states_before\n",
"    s_after = outbreak.states_after\n",
"\n",
"    this_model = pm.Model()\n",
"    with this_model:\n",
"        rv_q_community = pm.Uniform(\"rv_q_community\")\n",
"        ##rv_q_class_0 = pm.Uniform(\"rv_q_class_0\")\n",
"        rv_q_class_1 = pm.Uniform(\"rv_q_class_1\")\n",
"        rv_q_class_2 = pm.Uniform(\"rv_q_class_2\")\n",
"        rv_q_household = pm.Uniform(\"rv_q_household\")\n",
"        rv_q_family = pm.Uniform(\"rv_q_family\")\n",
"        rv_q_sex = pm.Uniform(\"rv_q_sex\")\n",
"\n",
"        probabilities = ((rv_q_community**community_slab) *\n",
"                         (rv_q_class_1**c1_slab) *\n",
"                         (rv_q_class_2**c2_slab) *\n",
"                         (rv_q_household**household_slab) *\n",
"                         (rv_q_family**family_slab) *\n",
"                         (rv_q_sex**sex_slab))\n",
"\n",
"        ## probability is zero if already not susceptible\n",
"        probabilities = s_before * probabilities\n",
"        s_observed = pm.Bernoulli(\"s_observed\", p=probabilities, observed=s_after)\n",
"        this_trace = pm.sample(draws=n_samples, cores=n_chains,\n",
"                               chains=n_chains, tune=n_tune)\n",
"    pm.traceplot(this_trace)\n",
"    ##print(\"Likelihood sampled using MCMC NUTS:\")\n",
"    ##plt.show()\n",
"    print(pm.summary(this_trace))\n",
"\n",
"\n",
"    now = datetime.now().strftime(\"%Y-%m-%dT%H-%M-%S\")\n",
"    dump_filename = \"models/\"+model_name+now+\".pkl\"\n",
"    os.makedirs(os.path.dirname(dump_filename), exist_ok=True)\n",
"    model_parameters = {\"latency_parameters\": latency_parameters,\n",
"                        \"stay_at_home\": stay_at_home,\n",
"                        \"saturated\": saturated,\n",
"                        \"n_samples\": n_samples,\n",
"                        \"n_tunes\": n_tune,\n",
"                        \"n_chains\": n_chains}\n",
"\n",
"    with open(dump_filename, \"wb\") as the_file:\n",
"        pickle.dump({\"model_parameters\": model_parameters,\n",
"                     \"latencies\": outbreak.data_frame[\"Latency_Onset\"],\n",
"                     \"model\": this_model,\n",
"                     \"trace\": this_trace}, the_file)\n",
"\n",
"    print(f\"Model successfully dumped to file: {dump_filename}.\")\n",
"\n",
"print(\"Marginalization finished.\")"
]
},
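{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Added sketch (assumes the pickle files written by the loop above exist on disk):* reload the dumped runs and pool the posterior draws across the marginalization repetitions. The glob pattern mirrors the `dump_filename` format used above; `dump_files`, `traces` and `pooled_q_community` are illustrative names."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import glob\n",
"\n",
"## load every dump written with the current stay_at_home/saturated settings\n",
"dump_files = sorted(glob.glob(\"models/\"+model_name+\"*.pkl\"))\n",
"traces = []\n",
"for fname in dump_files:\n",
"    with open(fname, \"rb\") as the_file:\n",
"        dump = pickle.load(the_file)\n",
"    traces.append(dump[\"trace\"])\n",
"\n",
"## pool posterior draws for one escape probability across runs\n",
"if traces:\n",
"    pooled_q_community = np.concatenate([t[\"rv_q_community\"] for t in traces])\n",
"    print(len(dump_files), pooled_q_community.mean())"
]
},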
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}