• API
• FAQ
• Tools
• Archive
SHARE
TWEET # Untitled a guest Oct 21st, 2019 101 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
1. import streamlit as st
2. from datetime import time
3. from datetime import date
4. import pandas as pd
5. import matplotlib.pyplot as plt
6. import seaborn as sns
7. import plotly.figure_factory as ff
8. import plotly.graph_objs  as go
9. from sklearn.linear_model import LinearRegression
10. from sklearn.model_selection import train_test_split
11. from sklearn.metrics import mean_squared_error, r2_score
12. from math import sqrt
13.
14. import numpy as np
sns.set_style("darkgrid")

# Introductory text shown at the top of the app. Streamlit dedents the string
# before rendering, so the uniform 4-space indent below is stripped.
st.markdown("""

    ## Problem Statement
    The dataset records sales (in thousands of units) of a particular product together with the
    advertising budgets (in thousands of dollars) spent on TV, radio, and newspaper media.
    Suppose that in our role as statistical consultants we are asked to suggest, on the basis of
    this data, a marketing plan for next year that will result in high product sales.

    Here are a few important questions that you might seek to address:
    - Is there a relationship between advertising budget and sales?
    - How strong is the relationship between the advertising budget and sales?
    - Which media contribute to sales?
    - How accurately can we estimate the effect of each medium on sales?
    - How accurately can we predict future sales?
    - Is the relationship linear?

    We want to find a function that, given input budgets for TV, radio and newspaper, predicts the
    output sales, and to visualize the relationship between the features and the response using
    scatter plots.

    The objective is to use linear regression to understand how advertisement spending impacts sales.

    ### Data Description
    - **TV**: advertising budget spent on TV (thousands of dollars)
    - **Radio**: advertising budget spent on radio (thousands of dollars)
    - **Newspaper**: advertising budget spent on newspapers (thousands of dollars)
    - **Sales**: product sales (thousands of units)

""")
st.sidebar.title("Operations on the Dataset")

# One sidebar checkbox per optional section of the page. They are created in
# this order (which is also their display order) and all start unchecked; the
# returned booleans gate the sections further down the script.
w1         = st.sidebar.checkbox("show table",      value=False)
plot       = st.sidebar.checkbox("show plots",      value=False)
plothist   = st.sidebar.checkbox("show hist plots", value=False)
trainmodel = st.sidebar.checkbox("Train model",     value=False)
dokfold    = st.sidebar.checkbox("DO KFold",        value=False)
distView   = st.sidebar.checkbox("Dist View",       value=False)
_3dplot    = st.sidebar.checkbox("3D plots",        value=False)
linechart  = st.sidebar.checkbox("Linechart",       value=False)
62.
63.
64. @st.cache
67.
69.
# Raw views of the loaded dataset, each toggled from the sidebar.
if w1:
    # Full table in a large scrollable widget.
    st.dataframe(df, width=2000, height=500)

if linechart:
    # Every column drawn as a line against the row index.
    st.line_chart(df)
if plothist:
    # Histogram of one user-selected column.
    # NOTE(review): `options` (the selectable column names) is not assigned in
    # this part of the script -- presumably defined just above; confirm.
    sel_cols = st.selectbox("select columns", options, 1)
    st.write(sel_cols)

    # Wrap the trace in a Figure so the chart gets a title and labelled axes
    # (a bare trace list renders with unlabelled axes).
    fig = go.Figure(
        data=[go.Histogram(x=df[sel_cols], nbinsx=50)],
        layout=go.Layout(
            title=f"Distribution of {sel_cols}",
            xaxis=dict(title=sel_cols),
            yaxis=dict(title="count"),
        ),
    )
    st.plotly_chart(fig)
92.
if plot:
    # Scatter plot of one advertising medium against sales.
    # NOTE(review): `options` is not assigned in this part of the script --
    # presumably the list of ad-medium column names; confirm.
    w7 = st.selectbox("Ad medium", options, 1)
    st.write(w7)

    f, ax = plt.subplots()
    ax.scatter(df[w7], df["sales"])
    ax.set_xlabel(w7)
    ax.set_ylabel("sales")
    ax.set_title(f"{w7} vs Sales")
    st.plotly_chart(f)
    # Streamlit re-runs the whole script on every widget interaction; close the
    # figure once it has been sent so pyplot does not keep accumulating open
    # figures (and their memory) across reruns.
    plt.close(f)
105.
106.
if distView:
    # Overlaid distribution plots of the advertising budgets, one curve per
    # medium, each with its own histogram bin width.
    # NOTE(review): `hist_data` is not assigned in this part of the script --
    # presumably a list of three per-medium series matching the labels; confirm.
    group_labels = ["TV", "Radio", "newspaper"]
    bin_widths = [0.1, 0.25, 0.5]
    fig = ff.create_distplot(hist_data, group_labels, bin_size=bin_widths)
    st.plotly_chart(fig)
121.
if _3dplot:
    # 3D scatter: two user-chosen advertising media on the x/y axes, sales on z.
    # NOTE(review): candidate axes are every column other than "sales"; this
    # assumes df holds only the budget columns plus the target -- confirm.
    media = [col for col in df.columns if col != "sales"]
    options = st.multiselect(
        "Select two advertising media for the x and y axes",
        media,
        default=media[:2],
    )
    st.write('You selected:', options)

    if len(options) != 2:
        # A 3D scatter needs exactly one column per horizontal axis.
        st.warning("Please select exactly two media for the 3D plot.")
    else:
        x_col, y_col = options
        trace1 = go.Scatter3d(
            x=df[x_col].values,
            y=df[y_col].values,
            z=df["sales"].values,
            mode="markers",
            marker=dict(
                size=8,
                # colorscale only takes effect when a color array is given.
                color=df["sales"].values,
                colorscale="Viridis",
            ),
        )
        layout = go.Layout(
            margin=dict(l=0, r=0, b=0, t=0),
            scene=dict(
                xaxis=dict(title=x_col),
                yaxis=dict(title=y_col),
                zaxis=dict(title="sales"),
            ),
        )
        fig = go.Figure(data=[trace1], layout=layout)
        st.write(fig)
147.
148.
149.
if trainmodel:
    # Hold out 30% of the rows, fit ordinary least squares on the rest, and
    # report the error on the held-out rows.
    y = df.sales
    # NOTE(review): features are taken to be every non-target column (the TV /
    # radio / newspaper budgets per the problem statement); confirm df carries
    # no extra columns such as a saved row index.
    X = df.drop(columns=["sales"])
    # Fixed seed: Streamlit re-runs the script on every widget change, so an
    # unseeded split would report different metrics after each click.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=0
    )

    lrgr = LinearRegression()
    lrgr.fit(X_train, y_train)
    pred = lrgr.predict(X_test)

    mse = mean_squared_error(y_test, pred)
    rmse = sqrt(mse)
    r2 = r2_score(y_test, pred)

    # Explicit newlines: bullets indented deeper than the heading line are
    # parsed as paragraph continuation, not as a Markdown list.
    st.markdown(
        "Linear Regression model trained:\n\n"
        f"- MSE: {mse:.4f}\n"
        f"- RMSE: {rmse:.4f}\n"
        f"- R²: {r2:.4f}\n"
    )
    st.success('Model trained successfully')
172.
173.
if dokfold:
    # K-fold cross-validation of a linear regression: fit on k-1 folds, score
    # on the held-out fold, plot each fold's predictions, and tabulate metrics.
    from sklearn.model_selection import KFold

    n_splits = 10
    my_bar = st.progress(0)

    # The target (sales) is the last column; every other column is a feature.
    # The previous code used the target column itself as the only feature,
    # which leaks the answer and makes every fold score MSE ~ 0 and R^2 ~ 1.
    values = df.values
    X = values[:, :-1]
    y = values[:, -1]

    kf = KFold(n_splits=n_splits)
    mse_list = []
    rmse_list = []
    r2_list = []
    fig, ax = plt.subplots()
    for fold, (train_index, test_index) in enumerate(kf.split(X), start=1):
        # Progress in percent, derived from the split count, not hard-coded.
        my_bar.progress(int(fold * 100 / n_splits))
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]

        lrgr = LinearRegression()
        lrgr.fit(X_train, y_train)
        pred = lrgr.predict(X_test)

        mse = mean_squared_error(y_test, pred)
        rmse = sqrt(mse)
        r2 = r2_score(y_test, pred)
        mse_list.append(mse)
        rmse_list.append(rmse)
        r2_list.append(r2)
        ax.plot(pred, label=f"dataset-{fold}")

    ax.legend()
    ax.set_xlabel("Data points")
    ax.set_ylabel("Predictions")
    st.plotly_chart(fig)
    # Release the figure; Streamlit reruns would otherwise accumulate them.
    plt.close(fig)

    res = pd.DataFrame({"MSE": mse_list, "RMSE": rmse_list, "r2_SCORE": r2_list})
    st.write(res)
    st.balloons()