Advertisement
Guest User

Untitled

a guest
Oct 23rd, 2019
137
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.64 KB | None | 0 0
  1. def genSankey(df,cat_cols=[],value_cols='',title='Sankey Diagram'):
  2. # maximum of 6 value cols -> 6 colors
  3. colorPalette = ['#ff0000','#00ff00','#0000ff','#ffff00','#00ffff']
  4. labelList = []
  5. colorNumList = []
  6. for catCol in cat_cols:
  7. labelListTemp = list(set(df[catCol].values))
  8. colorNumList.append(len(labelListTemp))
  9. labelList = labelList + labelListTemp
  10.  
  11. # remove duplicates from labelList
  12. labelList = list(dict.fromkeys(labelList))
  13.  
  14. # define colors based on number of levels
  15. colorList = []
  16. for idx, colorNum in enumerate(colorNumList):
  17. colorList = colorList + [colorPalette[idx]]*colorNum
  18.  
  19. # transform df into a source-target pair
  20. for i in range(len(cat_cols)-1):
  21. if i==0:
  22. sourceTargetDf = df[[cat_cols[i],cat_cols[i+1],value_cols]]
  23. sourceTargetDf.columns = ['source','target','count']
  24. else:
  25. tempDf = df[[cat_cols[i],cat_cols[i+1],value_cols]]
  26. tempDf.columns = ['source','target','count']
  27. sourceTargetDf = pd.concat([sourceTargetDf,tempDf])
  28. sourceTargetDf = sourceTargetDf.groupby(['source','target']).agg({'count':'sum'}).reset_index()
  29.  
  30. # add index for source-target pair
  31. sourceTargetDf['sourceID'] = sourceTargetDf['source'].apply(lambda x: labelList.index(x))
  32. sourceTargetDf['targetID'] = sourceTargetDf['target'].apply(lambda x: labelList.index(x))
  33.  
  34. #Source Color
  35. source_color = sourceTargetDf.source.unique()
  36. diccionario = {}
  37. for j in range(len(source_color)):
  38. diccionario.update({source_color[j]:'rgb({}, {}, {}, 0.05)'.format(np.random.randint(low = 0, high = 255),
  39. np.random.randint(low = 0, high = 255),
  40. np.random.randint(low = 0, high = 255))})
  41.  
  42. sourceTargetDf['colors'] = sourceTargetDf['source'].map(diccionario)
  43.  
  44. # creating the sankey diagram
  45. data = dict(
  46. type='sankey',
  47. node = dict(
  48. pad = 15,
  49. thickness = 20,
  50. line = dict(
  51. color = "black",
  52. width = 0.5
  53. ),
  54. label = labelList,
  55. color = colorList
  56. ),
  57. link = dict(
  58. source = sourceTargetDf['sourceID'],
  59. target = sourceTargetDf['targetID'],
  60. value = sourceTargetDf['count'],
  61. color = sourceTargetDf['colors'].dropna(axis=0, how='any')
  62. )
  63. )
  64.  
  65. layout = dict(
  66. title = title,
  67. font = dict(
  68. size = 10
  69. )
  70. )
  71.  
  72. fig = dict(data=[data], layout=layout)
  73.  
  74.  
  75. return fig
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement