Advertisement
VisualPaul

Untitled

Apr 8th, 2016
330
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.28 KB | None | 0 0
import numpy as np
from geopy.distance import vincenty
from sklearn.cluster import KMeans
from sklearn.neighbors import BallTree
from sklearn.preprocessing import OneHotEncoder
  5.  
  6. smallKMeans = KMeans(n_clusters=3).fit(array([[ 55.73008165, 37.59531199], [ 59.91301691, 30.31944249],[ 55.67814337, 46.11249841]]))
  7. holidays = "1.01,2.01,3.01,4.01,5.01,6.01,7.01,8.01,23.02,8.03,9.03,10.03,1.05,2.05,3.05,4.05,9.05,10.05,11.05,12.06,13.06,14.06,15.06".split(',')
  8. holidays = set(tuple(map(int, x.split('.'))) for x in holidays)
  9.  
  10. def get_features(data):
  11. dist = data.dist.values
  12. lat, lon = data.lat.values, data.lon.values
  13. weekday, month = data.day_of_week.values, data.month.values
  14. hourx, houry = cos(data.hour / 23), sin(data.hour / 23)
  15. hour = data.hour
  16. hota, hotb, hotc = zeros_like(hour, dtype=float32), zeros_like(hour, dtype=float32), zeros_like(hour, dtype=float32)
  17.  
  18. hota[(data.f_class == 'econom').values] += 1.00
  19. hota[(data.s_class == 'econom').values] += 0.50
  20. hota[(data.t_class == 'econom').values] += 0.25
  21.  
  22. hotb[(data.f_class == 'business').values] += 1.00
  23. hotb[(data.s_class == 'business').values] += 0.50
  24. hotb[(data.t_class == 'business').values] += 0.25
  25.  
  26. hotc[(data.f_class == 'vip').values] += 1.00
  27. hotc[(data.s_class == 'vip').values] += 0.50
  28. hotc[(data.t_class == 'vip').values] += 0.25
  29.  
  30. isHoliday = array([[1 if x in holidays else 0] for x in zip(data['day'], data['month'])])
  31. city = smallKMeans.predict(ds[['lat', 'lon']])
  32. smallClusters = OneHotEncoder().fit_transform(city.reshape(-1, 1)).toarray()
  33. cityDistance = array([vincenty(smallKMeans.cluster_centers_[city[i]], (data['lat'][i], data['lon'][i])).meters for i in range(len(data))])
  34. ballTree = BallTree(data[['lat', 'lon']])
  35. coord = array(list(zip(data['lat'], data['lon'])))
  36. sumDist = array([sum(vincenty(coord[i], coord[x]).meters for x in ballTree.query(coord[i].reshape(1, -1), 5)[0]) for i in range(len(coord))])
  37.  
  38. weekday = OneHotEncoder().fit_transform(ds['day_of_week'].reshape(-1, 1)).toarray()
  39. features = array(list(zip(dist, lat, lon, month, hourx, houry, hour, hota, hotb, hotc)))
  40. features = hstack((features, smallClusters, isHoliday, weekday, cityDistance.reshape(-1, 1), sumDist.reshape(-1, 1)))
  41. return features
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement