Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# Connection settings for the PostgreSQL server this script talks to.
DB = 'postgres'
USER = 'postgres'
HOST = 'localhost'
import time

import matplotlib.pyplot as plt
import mpl_toolkits.mplot3d.axes3d as p3
import numpy as np
import psycopg2
from sklearn import datasets
from sklearn.cluster import AgglomerativeClustering
from sklearn.neighbors import kneighbors_graph
# Schema and schema-qualified table name used by this script.
SCHEMA = 'scikit_learn'
TABLE = SCHEMA + '.scikit_learn_table'

# Keyword/value connection string built from DB, USER and HOST.  The single
# module-level connection opened here is shared by every function and
# top-level statement in this script.
connectionString = "dbname=%s user=%s host=%s" % (DB, USER, HOST)
conn = psycopg2.connect(connectionString)
# DDL statements for the demo schema and table.  The SQL contains
# double-quoted identifiers, so each statement is held in a triple-quoted
# Python string to keep those inner quotes literal.
create_schema = """CREATE SCHEMA "scikit_learn" ;"""

drop_table = """
DROP TABLE
"scikit_learn"."scikit_learn_table" """

# One row per sample: an id, a normalised importance, the sample's point
# geometry, and the cluster columns that are filled in after clustering.
create_table = """
CREATE TABLE IF NOT EXISTS
"scikit_learn"."scikit_learn_table" (
"CAL_ID" numeric not null,
"IMPORTANCE" numeric default 0,
"POINT_WGS" geometry not null,
"CLUSTER_GEOM" geometry default null,
"CLUSTER_ID" integer default null,
"PARENT_ID" integer default null );"""
def insert_california_housing(X):
    """Replace the table contents with one row per sample of *X*.

    Args:
        X: 2-D NumPy array indexed as ``X[i, j]``.  Column 0 is written as
            the point's latitude, column 1 as its longitude, and column 2 is
            the raw importance, which is stored divided by that column's
            maximum.  The row index is stored as "CAL_ID".

    The DELETE and all INSERTs run in one transaction on the shared
    module-level connection.  If an insert fails, the error is printed and
    the whole batch is rolled back: PostgreSQL rejects every further
    statement in a failed transaction, so continuing would only repeat the
    error for each remaining row.
    """
    insert_sql = (
        'INSERT INTO "scikit_learn"."scikit_learn_table" '
        '("CAL_ID", "IMPORTANCE", "POINT_WGS") '
        'VALUES (%s, %s / %s, ST_GeomFromEWKT(%s));'
    )

    # Loop-invariant; guarded because max() of an empty column raises.
    max_importance = float(X[:, 2].max()) if len(X) else None

    cur = conn.cursor()
    try:
        cur.execute('DELETE FROM "scikit_learn"."scikit_learn_table" ;')
        for i, row in enumerate(X):
            lat, lon = float(row[0]), float(row[1])
            params = (
                i,
                # Plain Python floats: the driver does not adapt every
                # NumPy scalar type.
                float(row[2]),
                max_importance,
                'SRID=4326;POINT(%r %r)' % (lon, lat),
            )
            try:
                cur.execute(insert_sql, params)
            except psycopg2.Error as e:
                print('%s %s %r' % (e, insert_sql, params))
                conn.rollback()
                return
        conn.commit()
    finally:
        cur.close()
def _execute_isolated(sql, error_format):
    """Run *sql* in its own transaction; print and roll back on a DB error.

    *error_format* is a %-format string that may use ``%(error)s`` and
    ``%(sql)s``.  Returns True when the statement was committed.
    """
    cur = conn.cursor()
    try:
        cur.execute(sql)
        conn.commit()
        return True
    except psycopg2.Error as e:
        # Roll back so the next statement does not inherit a failed
        # transaction.
        conn.rollback()
        print(error_format % {'error': e, 'sql': sql})
        return False
    finally:
        cur.close()


def update_california_housing_clusters(labels, keys):
    """Store cluster labels and per-cluster convex hulls in the table.

    Args:
        labels: sequence indexed by row number; ``labels[i]`` becomes the
            "CLUSTER_ID" of the row whose "CAL_ID" is ``i``.
        keys: only its length is used, as the number of rows to update.

    Steps, each best-effort (a database error is printed, not raised):
      1. reset existing cluster columns and write the new labels, as one
         transaction;
      2. set "CLUSTER_GEOM" to the convex hull of each cluster's points;
      3. register "CLUSTER_GEOM" and "POINT_WGS" in ``geometry_columns``.

    Steps 2 and 3 each run in their own transaction so that a failing
    ``geometry_columns`` insert cannot undo the hull update.
    """
    # NOTE(review): "CLUSTER_ID" = 0 doubles as the "unassigned" marker here
    # and is excluded from the hull update below, so a cluster labelled 0
    # never gets a hull -- confirm that is intended.
    reset_sql = (
        'UPDATE "scikit_learn"."scikit_learn_table" '
        'set "CLUSTER_GEOM" = null, "CLUSTER_ID" = 0 '
        'where "CLUSTER_ID" <> 0 OR "CLUSTER_GEOM" is NOT NULL;'
    )
    assign_sql = (
        'UPDATE "scikit_learn"."scikit_learn_table" '
        'set "CLUSTER_ID" = %s where "CAL_ID" = %s ;'
    )
    hull_sql = """update "scikit_learn"."scikit_learn_table"
        set "CLUSTER_GEOM" = nn.the_geom from (
            SELECT ST_ConvexHull(ST_Collect("POINT_WGS")) As the_geom, "CLUSTER_ID"
            FROM "scikit_learn"."scikit_learn_table" GROUP BY "CLUSTER_ID"
        ) nn
        where "scikit_learn"."scikit_learn_table"."CLUSTER_ID" <> 0
        and nn."CLUSTER_ID" = "scikit_learn"."scikit_learn_table"."CLUSTER_ID" ; """
    # {column} is only ever filled with the fixed column names used below.
    register_sql = """INSERT INTO geometry_columns(f_table_catalog, f_table_schema, f_table_name, f_geometry_column, coord_dimension, srid, "type")
        SELECT '', 'scikit_learn', 'scikit_learn_table', '{column}', ST_CoordDim("{column}"), ST_SRID("{column}"), GeometryType("{column}")
        FROM "scikit_learn"."scikit_learn_table" LIMIT 1;"""

    # int(): the driver cannot adapt NumPy integer scalars directly.
    row_count = len(keys)
    assignments = [(int(labels[i]), i) for i in range(row_count)]

    cur = conn.cursor()
    statement = reset_sql
    try:
        cur.execute(reset_sql)
        statement = assign_sql
        cur.executemany(assign_sql, assignments)
        conn.commit()
    except psycopg2.Error as e:
        conn.rollback()
        print('%s %s' % (e, statement))
    finally:
        cur.close()

    _execute_isolated(hull_sql, 'Exception %(error)s')

    # Register both geometry columns (the original registered CLUSTER_GEOM
    # twice and never POINT_WGS).
    for column in ('CLUSTER_GEOM', 'POINT_WGS'):
        _execute_isolated(register_sql.format(column=column),
                          'Exception %(error)s')
#########################
# Create the schema, then drop and recreate the table.  Each statement is
# attempted on its own: a failure is reported with the matching note and
# rolled back so the next statement starts in a clean transaction.
for setup_sql, failure_note in (
    (create_schema, 'schema may already have been created '),
    (drop_table, 'table may not already have been created '),
    (create_table, 'table may already have been created '),
):
    main_cur = conn.cursor()
    try:
        main_cur.execute(setup_sql)
    except psycopg2.Error as e:
        conn.rollback()
        print('%s %s' % (failure_note, e))
    else:
        conn.commit()
    finally:
        main_cur.close()
###################
# note: this downloads about 500K of test data
X = datasets.fetch_california_housing()
n_samples = len(X.data)

# Put x, y first: source columns 6 and 7 become columns 0 and 1, followed by
# the remaining columns in their original order.
rearrange = np.array([6, 7, 0, 1, 2, 3, 4, 5])
XData = X.data[:, rearrange]
insert_california_housing(XData)

# Cluster on the two coordinate columns only.  XData keeps every column
# because the plot below still needs column 3.
coords = XData[:, 0:2]
connectivity = kneighbors_graph(coords, n_neighbors=10)

st = time.time()
ward = AgglomerativeClustering(
    # Floor division: n_clusters must be an integer on Python 3 as well.
    n_clusters=n_samples // 10,
    connectivity=connectivity,
    linkage='ward',
).fit(coords)
elapsed_time = time.time() - st
labels = ward.labels_
print("Elapsed time: %.2fs" % elapsed_time)
print("Number of points: %i" % labels.size)

update_california_housing_clusters(labels, coords)

# ###############################################################################
# # Plot result
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.view_init(7, -80)
color_scale = float(np.max(labels + 1))
for l in np.unique(labels):
    members = labels == l
    # z axis: column 3 of the rearranged data, i.e. source column 1.
    ax.plot3D(coords[members, 0], coords[members, 1], XData[members, 3],
              'o', color=plt.cm.jet(float(l) / color_scale))
plt.title('With connectivity constraints (time %.2fs)' % elapsed_time)
plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement