Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def genSankey(df, cat_cols=None, value_cols='', title='Sankey Diagram'):
    """Build a Sankey figure dictionary from a flat DataFrame.

    Every consecutive pair of columns in ``cat_cols`` forms one stage of the
    diagram: each row contributes a link from its value in the left column to
    its value in the right column, weighted by ``value_cols``. Links sharing
    the same (source, target) pair are summed.

    Args:
        df: pandas DataFrame containing the category columns and the value
            column.
        cat_cols: Ordered column names describing the stages, left to right.
            At least two are required.
        value_cols: Name of the column whose values are summed into the link
            weights.
        title: Title stored in the returned layout.

    Returns:
        A dict ``{'data': [trace], 'layout': layout}`` where ``trace`` has
        ``type='sankey'`` plus ``node`` (labels and colours) and ``link``
        (source/target node indices, values and colours as pandas Series).
        The structure follows the Plotly figure-dict convention.

    Raises:
        ValueError: If fewer than two category columns are given, since no
            source/target pair can be formed.
    """
    cat_cols = [] if cat_cols is None else list(cat_cols)
    if len(cat_cols) < 2:
        raise ValueError(
            'cat_cols must contain at least two column names to build '
            'source-target pairs'
        )

    # One colour per category column; cycled when there are more columns
    # than colours so extra columns do not raise IndexError.
    color_palette = [
        '#ff0000', '#00ff00', '#0000ff', '#ffff00', '#00ffff', '#ff00ff',
    ]

    # Map each distinct label to the colour of the first column it appears
    # in. Using one dict keeps labels and colours the same length even when
    # a label occurs in several columns, and first-appearance order makes
    # the node order deterministic. NaN values are skipped because groupby
    # below drops them, so they could never be linked.
    node_colors = {}
    for level, cat_col in enumerate(cat_cols):
        level_color = color_palette[level % len(color_palette)]
        for label in df[cat_col].dropna().unique().tolist():
            node_colors.setdefault(label, level_color)
    label_list = list(node_colors)
    color_list = list(node_colors.values())
    label_index = {label: pos for pos, label in enumerate(label_list)}

    # Stack every adjacent column pair into one source/target/count table.
    stage_frames = []
    for source_col, target_col in zip(cat_cols, cat_cols[1:]):
        stage = df[[source_col, target_col, value_cols]].copy()
        stage.columns = ['source', 'target', 'count']
        stage_frames.append(stage)
    links = (
        pd.concat(stage_frames)
        .groupby(['source', 'target'])
        .agg({'count': 'sum'})
        .reset_index()
    )

    # Translate labels to node positions with an O(1) dict lookup.
    links['sourceID'] = links['source'].map(label_index)
    links['targetID'] = links['target'].map(label_index)

    # One random, mostly transparent colour per source node. The colour has
    # an alpha channel, so it is written as rgba(...). randint's upper bound
    # is exclusive, hence 256 to allow the full 0-255 range.
    source_colors = {
        source: 'rgba({}, {}, {}, 0.05)'.format(
            *np.random.randint(0, 256, size=3)
        )
        for source in links['source'].unique()
    }
    links['colors'] = links['source'].map(source_colors)

    # Assemble the sankey trace.
    data = dict(
        type='sankey',
        node=dict(
            pad=15,
            thickness=20,
            line=dict(
                color="black",
                width=0.5
            ),
            label=label_list,
            color=color_list
        ),
        link=dict(
            source=links['sourceID'],
            target=links['targetID'],
            value=links['count'],
            color=links['colors']
        )
    )
    layout = dict(
        title=title,
        font=dict(
            size=10
        )
    )
    fig = dict(data=[data], layout=layout)
    return fig
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement