Advertisement
Guest User

Untitled

a guest
Sep 20th, 2014
220
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.55 KB | None | 0 0
  1. DB = 'postgres'
  2. USER = 'postgres'
  3. HOST = 'localhost'
  4.  
  5. import psycopg2
  6.  
  7. import matplotlib.pyplot as plt
  8. import mpl_toolkits.mplot3d.axes3d as p3
  9. import time
  10. import numpy as np
  11. from sklearn.cluster import AgglomerativeClustering
  12. from sklearn import datasets
  13.  
  14. SCHEMA = 'scikit_learn'
  15. TABLE = SCHEMA + '.scikit_learn_table'
  16.  
  17. connectionString = "dbname=%s user=%s host=%s" % (DB, USER, HOST)
  18. conn = psycopg2.connect(connectionString)
  19.  
  20. create_schema = "CREATE SCHEMA "scikit_learn" ;"
  21.  
  22. drop_table = "
  23. DROP TABLE
  24. "scikit_learn"."scikit_learn_table" "
  25.  
  26. create_table = "
  27. CREATE TABLE IF NOT EXISTS
  28. "scikit_learn"."scikit_learn_table" (
  29. "CAL_ID" numeric not null,
  30. "IMPORTANCE" numeric default 0,
  31. "POINT_WGS" geometry not null,
  32. "CLUSTER_GEOM" geometry default null,
  33. "CLUSTER_ID" integer default null,
  34. "PARENT_ID" integer default null );"
  35.  
  36. def insert_california_housing(X):
  37. cur = conn.cursor()
  38. clearq = "DELETE FROM "scikit_learn"."scikit_learn_table" ;"
  39. cur.execute(clearq)
  40. for i in range (len (XData)):
  41. q = "INSERT INTO "scikit_learn"."scikit_learn_table" (
  42. "CAL_ID" ,
  43. "IMPORTANCE" ,
  44. "POINT_WGS"
  45. ) VALUES (
  46. {i} ,
  47. {importance} / {maximp} ,
  48. ST_GeomFromEWKT ( 'SRID=4326;POINT({lon} {lat})' ) ); ".format
  49. ( i=i, importance=XData[i,2], maximp = XData[:,2].max(), lat = XData[i,0], lon = XData[i,1] )
  50. try:
  51. cur.execute(q)
  52. except Exception, e:
  53. print e , q
  54. conn.commit()
  55. cur.close()
  56.  
  57. def update_california_housing_clusters(labels, keys):
  58. cur = conn.cursor()
  59. clearq = "UPDATE "scikit_learn"."scikit_learn_table" set "CLUSTER_GEOM" = null, "CLUSTER_ID" = 0 where "CLUSTER_ID" <> 0 OR "CLUSTER_GEOM" is NOT NULL;"
  60. try:
  61. cur.execute(clearq)
  62. except Exception, e:
  63. print e , clearq
  64.  
  65. for i in range (len (keys)):
  66. q = "UPDATE "scikit_learn"."scikit_learn_table" set "CLUSTER_ID" = {cluster_fkey_value}
  67. where "CAL_ID" = {key_value} ; ".format
  68. ( key_value=i, cluster_fkey_value = labels[i] )
  69. try:
  70. cur.execute(q)
  71. except Exception, e:
  72. print e , q
  73. conn.commit()
  74. create = "update "scikit_learn"."scikit_learn_table"
  75. set "CLUSTER_GEOM" = nn.the_geom from (
  76. SELECT ST_ConvexHull(ST_Collect("POINT_WGS")) As the_geom, "CLUSTER_ID"
  77. FROM "scikit_learn"."scikit_learn_table" GROUP BY "CLUSTER_ID"
  78. ) nn
  79. where "scikit_learn"."scikit_learn_table"."CLUSTER_ID" <> 0
  80. and nn."CLUSTER_ID" = "scikit_learn"."scikit_learn_table"."CLUSTER_ID" ; "
  81. try:
  82. cur.execute(create)
  83. except Exception, e:
  84. print 'Exception', e
  85. cur.close()
  86. cur = conn.cursor()
  87.  
  88. regGeom = "INSERT INTO geometry_columns(f_table_catalog, f_table_schema, f_table_name, f_geometry_column, coord_dimension, srid, "type")
  89. SELECT '', 'scikit_learn', 'scikit_learn_table', 'CLUSTER_GEOM', ST_CoordDim("CLUSTER_GEOM"), ST_SRID("CLUSTER_GEOM"), GeometryType("CLUSTER_GEOM")
  90. FROM "scikit_learn"."scikit_learn_table" LIMIT 1;"
  91. try:
  92. cur.execute(regGeom)
  93. except Exception, e:
  94. print 'Exception', e
  95. cur.close()
  96. cur = conn.cursor()
  97.  
  98. regGeomPoint = "INSERT INTO geometry_columns(f_table_catalog, f_table_schema, f_table_name, f_geometry_column, coord_dimension, srid, "type")
  99. SELECT '', 'scikit_learn', 'scikit_learn_table', 'POINT_WGS', ST_CoordDim("POINT_WGS"), ST_SRID("POINT_WGS"), GeometryType("POINT_WGS")
  100. FROM "scikit_learn"."scikit_learn_table" LIMIT 1;"
  101. try:
  102. cur.execute(regGeom)
  103. except Exception, e:
  104. print 'Exception', e
  105. cur.close()
  106. cur = conn.cursor()
  107. conn.commit()
  108. cur.close()
  109.  
  110. #########################
  111. main_cur = conn.cursor()
  112. try:
  113. main_cur.execute(create_schema)
  114. except Exception, e:
  115. print 'schema may already have been created ' , e
  116. main_cur.close()
  117. conn.commit()
  118. main_cur = conn.cursor()
  119.  
  120. try:
  121. main_cur.execute(drop_table)
  122. except Exception, e:
  123. print 'table may not already have been created ', e
  124. main_cur.close()
  125. conn.commit()
  126. main_cur = conn.cursor()
  127.  
  128. try:
  129. main_cur.execute(create_table)
  130. except Exception, e:
  131. print 'table may already have been created ', e
  132. main_cur.close()
  133. conn.commit()
  134. main_cur = conn.cursor()
  135.  
  136. conn.commit()
  137. main_cur.close()
  138.  
  139. ###################
  140. #note this downloads 500K of test data
  141. X = datasets.fetch_california_housing ()
  142.  
  143. n_samples = len (X.data)
  144. XData = X.data
  145.  
  146. #put x, y first
  147. rearrange = np.array([6, 7, 0, 1, 2, 3, 4, 5])
  148. XData = XData [:, rearrange]
  149. insert_california_housing ( XData )
  150. XData = XData [:,0:2]
  151. from sklearn.neighbors import kneighbors_graph
  152. connectivity = kneighbors_graph(XData, n_neighbors=10)
  153.  
  154. st = time.time()
  155. ward = AgglomerativeClustering(n_clusters=n_samples/10, connectivity=connectivity,
  156. linkage='ward').fit(XData)
  157. elapsed_time = time.time() - st
  158. labels = ward.labels_
  159. print("Elapsed time: %.2fs" % elapsed_time)
  160. print("Number of points: %i" % labels.size)
  161.  
  162. update_california_housing_clusters(labels, XData)
  163.  
  164. # ###############################################################################
  165. # # Plot result
  166. fig = plt.figure()
  167. ax = p3.Axes3D(fig)
  168. ax.view_init(7, -80)
  169. for l in np.unique(labels):
  170. ax.plot3D(XData[labels == l, 0], XData[labels== l, 1], XData[labels == l, 3],
  171. 'o', color=plt.cm.jet(float(l) / np.max(labels + 1)))
  172. plt.title('With connectivity constraints (time %.2fs)' % elapsed_time)
  173.  
  174. plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement