Advertisement
Guest User

Untitled

a guest
Sep 21st, 2019
119
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 14.43 KB | None | 0 0
  1. import seaborn as sns
  2. import pandas as pd
  3. import matplotlib.pyplot as plt
  4. import numpy as np
  5.  
  6. # agents_name = ["ours", "DDPG"]
  7. x_name = "uniqueness"
  8. y_name = "performance"
  9. dpi = 300
  10. figsize = (12, 8)
  11. xlim = None # (1, 4.5)
  12. ylim = None # (-100, 1100)
  13.  
  14. def build_dataframe(matrix, performance, name):
  15. # print(name)
  16. df = []
  17. for agent_id, (row, p) in enumerate(zip(matrix, performance)):
  18. for other_id, other_val in enumerate(row):
  19. if other_id == agent_id:
  20. continue
  21. df.append({
  22. "me": agent_id,
  23. "other": other_id,
  24. x_name: other_val,
  25. y_name: p,
  26. "agent": name
  27. })
  28. # print(df)
  29. return pd.DataFrame(df)
  30.  
  31.  
  32. def build_merge_dataframe(name_matrix_perf_dict):
  33. df = None
  34. for name, (matrix, perf) in name_matrix_perf_dict.items():
  35. if df is None:
  36. df = build_dataframe(matrix, perf, name)
  37. else:
  38. df = df.append(build_dataframe(matrix, perf, name))
  39. return df
  40.  
  41. def draw(df):
  42. num_agents = len(df.me.unique())
  43. num_groups = len(df.agent.unique())
  44. agents_name = list(df.agent.unique())
  45. sns.set(style="white", rc={'figure.figsize':figsize, 'figure.dpi': dpi})
  46. g = sns.JointGrid(x=x_name, y=y_name, data=df, xlim=xlim, ylim=ylim)
  47.  
  48. def _draw(x, y, *args, **kw):
  49. x = x.reshape(num_groups, num_agents, num_agents-1)
  50. y = y.reshape(num_groups, num_agents, num_agents-1)
  51. reconstruct = []
  52. for agent, group_x, group_y in zip(agents_name, x, y):
  53. for me, (agent_x, agent_y) in enumerate(zip(group_x, group_y)):
  54. reconstruct.append({
  55. x_name: agent_x.min(), # HERE! @SUNHAO
  56. y_name: agent_y.min(), # HERE! @SUNHAO
  57. "me": me,
  58. "agent": agent
  59. })
  60. reconstruct = pd.DataFrame(reconstruct)
  61. sns.scatterplot(data=reconstruct, x=x_name, y=y_name, hue="agent")# style='agent',
  62.  
  63. g = g.plot_joint(_draw)
  64.  
  65. def _draw_marginal(data, vertical):
  66. data = data.reshape(num_groups, num_agents, num_agents-1)
  67. if vertical:
  68. plot_data = data.mean(axis=2)
  69. else:
  70. plot_data = data.min(axis=2)
  71. for d in plot_data:
  72. print(d.mean())
  73. sns.kdeplot(d, vertical=vertical)
  74. g.plot_marginals(_draw_marginal)
  75.  
  76.  
  77.  
  78. # Draw DDPG
  79.  
  80. import json
  81. data = data = json.loads('{"walker2d_max_number_ours": [[0.0, 2.88953722490357, 3.022784289161591, 3.840421080213634, 3.09722991600208, 2.873850139834197, 2.853839487844594, 1.8387867344135893, 2.762383673340883, 3.4146540462963304], [2.458295183400117, 0.0, 3.129153521917165, 2.9393374679101933, 2.727387307650072, 2.7541793317957812, 3.3297877239920126, 2.6531661895986085, 2.4947579428891267, 3.2814910413399034], [2.342549697912158, 3.084715377267008, 0.0, 3.029600557284626, 2.834681980776867, 3.163460951827083, 3.00056136041045, 1.7253983044328365, 2.496222965085463, 3.2313406544690935], [2.699587345479628, 2.898538320908109, 3.034193626605158, 0.0, 2.795924686777543, 2.3922472367554346, 2.0117530561619126, 2.687938691534147, 3.21254662030195, 2.5681170972226206], [2.617154017230794, 2.385957566826345, 2.7686206110456304, 2.9999514238186356, 0.0, 2.802059495732546, 2.419418149409109, 2.452975929284793, 2.942619506673026, 2.5272353649812396], [2.421765742338517, 2.534225489963491, 3.405945216252822, 2.484921653137105, 2.7883689361010737, 0.0, 2.5050103228556835, 2.6275862985817326, 3.324875952922372, 2.8033926411224437], [2.47401905794973, 2.7225017515241166, 2.9495417549049514, 2.390772211335728, 2.7714218706338167, 2.716286539139286, 0.0, 2.375898443611422, 2.849160649489158, 2.7597931345585303], [2.136356668412712, 2.972189261636036, 2.932498573666196, 4.288715460766748, 2.7224164055691897, 3.4468895929186885, 2.7972489455369036, 0.0, 2.5779986577595895, 3.9199779616675485], [2.1949050797475715, 3.0156242411440735, 2.2084298916667544, 3.1021347155323786, 3.262171866606589, 2.8691368861011033, 3.325638577282577, 2.407346270030818, 0.0, 3.5023687506672205], [3.206195642639988, 3.4754804900122886, 2.4456670557815574, 2.7119826115021213, 2.9457660228115787, 2.7034322252568344, 2.9137183321424622, 2.6446686976333775, 3.0369404023155004, 0.0]], "walker2d_max_number_baseline": [[0.0, 2.29576748, 2.52579896, 2.88362751, 2.08214255, 2.64223037, 2.08666234, 2.43250145, 2.65218304, 2.45631894], [1.993507, 0.0, 2.5575764, 2.80604826, 2.13127137, 1.88994288, 2.11734079, 2.74962078, 1.34913155, 2.34690375], [2.16798781, 2.46710363, 0.0, 2.51245988, 2.05241123, 1.98162312, 1.89625356, 2.2304532, 2.38115364, 2.08036499], [2.35368704, 2.10524029, 2.28512216, 0.0, 2.78916009, 2.2073295, 2.20244733, 2.49064934, 2.45962174, 1.99736344], [2.20549185, 1.96470518, 2.32841908, 3.0773535, 0.0, 1.97001718, 2.33341498, 2.42369041, 2.02042056, 2.1379389], [2.13478754, 1.78614258, 1.95971379, 2.95214276, 1.74279205, 0.0, 2.11534533, 2.89788501, 1.16709214, 2.1787284], [2.26084103, 1.89432125, 2.48378627, 2.30890816, 1.94309234, 2.09896865, 0.0, 2.19293059, 1.93878875, 2.52055762], [2.51508277, 2.58875395, 2.45424625, 2.70138026, 2.05176432, 2.22905602, 2.4253792, 0.0, 1.86562291, 2.65833271], [2.37073205, 2.06878469, 2.94265596, 2.76383859, 1.95884215, 2.38344007, 2.24114604, 2.42976342, 0.0, 2.79148603], [2.35306189, 2.32941189, 2.01619864, 2.80295485, 2.04776463, 2.16971383, 2.65009399, 2.08498885, 2.09189488, 0.0]], "halfcheetah_max_number_ours": [[0.0, 2.8071758561845064, 3.1080679100655613, 2.375455593257357, 3.1511008156897633, 2.7989483142839737, 2.8551141039083365, 2.737426404697361, 3.470051251513629, 3.5033105921286762], [2.869897237262939, 0.0, 3.416831683904215, 2.978604047585906, 2.8194653906120632, 2.6463849766733896, 3.1986719218788684, 2.633681435624167, 3.175908455209049, 3.430492386257123], [3.035501598993617, 3.305309391797198, 0.0, 2.9419621864538428, 2.756344134873879, 3.2525088414699668, 2.7157455137538937, 3.2496443873030114, 2.7798088445682754, 3.283096803039421], [2.4101875382125395, 3.0544852860414005, 2.89242985684197, 0.0, 2.9973913664944036, 2.4770837357504463, 2.8819066893493863, 2.723098419284253, 2.799188867111816, 3.6045490943020733], [2.535644197744192, 2.8243058416257796, 3.164864352090903, 2.4719239629450795, 0.0, 2.2624804451347273, 2.755781520543618, 2.5251185029669267, 3.2137257085509083, 3.385939302571523], [2.481994994005666, 2.92978393781533, 3.163880972160604, 2.285488071987585, 2.7254592889270826, 0.0, 2.69201777198226, 2.580512016897533, 3.320058926059918, 3.4180481397053084], [2.785788394519207, 3.1785039381099756, 3.0071676496708792, 2.689090777652821, 2.9945517972466593, 2.445088416240049, 0.0, 3.0360942479600666, 3.036104488147577, 2.6534024827940446], [2.7341215003910535, 2.762292858523859, 3.0275086804166405, 2.691443870890807, 2.6310517890743634, 2.4242804144322694, 2.9269646506568794, 0.0, 2.751129658393651, 3.313920643253297], [3.143761803105406, 3.202881987745064, 2.7033984872115395, 2.7742044277358278, 3.1690406155377855, 2.8366159152793466, 2.7059889509300974, 2.704211978825707, 0.0, 3.085900698572481], [2.799286554012256, 3.2892659442652366, 3.0089406091775985, 3.463984403603907, 3.49096528446259, 3.485483943740849, 3.396108693629085, 3.111451550680112, 3.077128187233566, 0.0]], "halfcheetah_max_number_baseline": [[0.0, 1.7483053266994144, 2.894907343347984, 2.255704648541671, 2.1991540770562747, 1.937740631344138, 2.0017303705349394, 2.8049173445775657, 2.0177797247054423, 2.006515488309106], [1.359436046510323, 0.0, 2.9540176424311477, 2.1643299451251092, 1.8700850479549254, 1.0562655200954432, 1.4759903696652736, 2.575798607842702, 1.4591658126363973, 1.0971954031343245], [2.800382693223351, 3.0903680099301276, 0.0, 2.873489211875157, 3.1738110793753673, 2.733209961297599, 3.0832761092774645, 2.779985010903522, 3.1572751076060515, 3.4702220704030244], [2.12293021079454, 2.5285957836631106, 2.8428564907976854, 0.0, 2.5553093852264763, 2.5026791712508905, 2.3797309541897085, 2.0280032586913577, 2.6477992840409486, 3.5502767337354992], [1.9449482158654345, 1.8140329483743587, 3.2029499244669646, 2.1437843355457096, 0.0, 2.26745384789745, 1.818075753761372, 2.2563079275377578, 2.033386244112059, 2.0651671542834107], [1.2973285672029784, 2.9140286282642247, 2.8805756090994334, 2.362296789240141, 2.1174524614211383, 0.0, 2.629891212193543, 2.5248340852058164, 1.6850395240575227, 1.6739597199945866], [1.9525160091815394, 1.5650729544879134, 2.911456320356061, 2.3638629841740575, 2.180783122640742, 1.6711744342680945, 0.0, 2.781470346856108, 1.9931396373948602, 1.8988405420080108], [2.411483767831222, 2.786155971620397, 2.6688540019963107, 1.9771408696755002, 2.860282328438058, 2.7313102534426212, 2.893303506890569, 0.0, 3.0706017145150755, 3.132312877229809], [1.6690957931779822, 1.4291549963314722, 2.9081424500836066, 2.277371835020888, 1.8081355774522263, 1.6194100545076997, 1.7940030170245849, 2.4488186810063195, 0.0, 1.7259118408540377], [1.4616850346369077, 1.13061491960142, 3.0905549908440944, 2.0828644606041022, 1.7524922932218239, 1.4050233891833706, 1.472653235476726, 2.539628995651544, 1.6492439640032825, 0.0]], "hopper_max_number_ours": [[0.0, 2.02475261823282, 2.105581265506597, 2.5400889885349103, 2.1739192989614784, 1.8191642515511988, 2.172356726702115, 1.9106871915028707, 2.4196739249821677, 1.7009495552126677], [2.062172905193527, 0.0, 1.6330785141700073, 1.993356055607567, 2.001038086661073, 1.3695566299438848, 1.8810288398036559, 2.0142157979149915, 1.94987126898818, 1.7522579055632488], [1.8822150603712888, 1.8470517108884899, 0.0, 2.1566285738626365, 1.9394144276537626, 1.7438540082936964, 1.8891563355005938, 2.2140949908186096, 2.167089753097826, 2.0746794919607097], [1.6453234971027622, 2.0412607915510117, 1.4323703270997812, 0.0, 1.7230641990546114, 1.7802862240288904, 1.6300773999770621, 1.499829991197122, 1.6734707744617765, 1.657075821421672], [1.7188758400492075, 1.7803743632074587, 1.7686248155047906, 1.8857614017699407, 0.0, 2.2371799658602147, 1.9639782044822123, 2.253121167149921, 1.996374566447358, 1.8310060263755878], [2.0688137822730264, 1.641272557577225, 1.6397718334574276, 1.7325427512840228, 1.8359369736855333, 0.0, 1.7609052586400948, 1.1912116468977973, 1.8088861513849033, 1.7241281960198678], [1.8655361515605786, 2.011761259395777, 1.7130823568539149, 2.190904954746187, 2.1059964543716396, 1.780255719376753, 0.0, 2.10419034493609, 2.0299236183267984, 1.9681354441090058], [1.5185975102067024, 1.9556452319491215, 1.7188189486305514, 1.9522953442561668, 1.7261724690290332, 1.4748917595599191, 1.812120476073891, 0.0, 1.9019578837722881, 1.649292763680439], [1.563980107482434, 2.2127609810642133, 1.8157810659505644, 1.774292985565909, 1.7613767066918338, 1.932771410304421, 1.8526652999677427, 1.8013172608725794, 0.0, 1.955754387145012], [1.788783803897514, 1.5851022040053906, 1.6953962995056076, 1.7241847156627903, 1.613746460813442, 1.7163728959711377, 1.634866737366901, 2.0521306486767954, 1.722488447823253, 0.0]], "hopper_max_number_baseline": [[0.0, 1.053798788267681, 1.8256712713128194, 1.5448159943588824, 1.453366891575919, 1.562587145353531, 1.42850556104579, 1.8307098457153885, 1.5034847951826953, 1.7609133407706716], [1.2546232562430881, 0.0, 1.9105504008101866, 1.576732557058491, 1.554215942807333, 1.762122033886072, 1.5367824267336265, 1.9546385521876233, 1.7291893406460057, 1.708080204792175], [2.4414290330815187, 2.2378348462097737, 0.0, 2.1178758137974514, 2.2541187147800312, 2.225615106667455, 2.3474885186155436, 2.1001823643879343, 2.3518960035642587, 1.8557683275710926], [1.4181817678120254, 1.3574095772367774, 1.6330377361438364, 0.0, 1.271792056114869, 1.1452214849388929, 1.2120189524735898, 2.1465382085565996, 1.3240164252985325, 1.8786421021810187], [1.329122878722604, 1.1902496677239713, 1.582924122436589, 1.323147329278146, 0.0, 1.2709015425314558, 1.1826324218179551, 1.8878925255019026, 1.2866711655907437, 1.6542456975935333], [1.3381938803636066, 1.2469160223918967, 1.4701129178045889, 1.127494668919077, 1.167526708096727, 0.0, 1.2177388088700778, 2.1074853125285182, 1.3361027477043879, 1.9030243409797571], [1.0294341479246338, 0.9839778057287712, 1.4896143513388904, 1.2621253489549695, 1.4082671689779616, 1.4598265647378827, 0.0, 1.759142166324117, 1.3533645471864397, 1.7486146229352513], [1.4306982062814446, 1.6496832444679597, 1.8119722008198296, 1.6332372368612245, 2.057964108559832, 1.9350188309355678, 1.848287538936522, 0.0, 1.743562606201064, 1.5461156065616564], [1.4791286051103032, 1.4854942108798241, 2.0821904949198973, 1.5914081714548691, 1.6815917215354164, 1.5959461774794208, 1.5396342323546444, 2.3450520937399926, 0.0, 2.452537345360658], [1.8953381954145776, 1.781168020918467, 1.98622951449728, 1.8422291369908377, 2.0062106805032363, 2.039555445353528, 1.8357246991778153, 1.7494396083150314, 1.939194731066756, 0.0]], "novel_walker_ours": [1032.5883179695581, 945.9293835542866, 382.8126511782732, 912.7684734993741, 863.8945221460688, 670.1998112932936, 115.49154600940514, 695.7679810275692, 137.6842650368118, 68.82742067080497], "novel_walker_baseline": [157.06887681750106, 287.1869723054354, 316.80287193318406, 16.91832823778375, 340.7819051371001, 415.96532403746016, 644.7012640607605, 870.3393495594191, 328.16271237351424, 113.84822613853055], "novel_cheetah_ours": [7316.168653105899, 4366.572190170884, 3446.487155046419, 5452.319851852781, 3627.7778821781817, 4655.128109240334, 2836.2332640627114, 3459.5723405868093, 2309.642705044114, 3304.8533708683585], "novel_cheetah_baseline": [4864.696956089928, 5610.000069577717, 4993.078750895224, 4526.692358805851, 4052.1704059637555, 5540.689724635217, 5515.161949362478, 4823.379978195207, 5549.10911668712, 5420.443219718892], "novel_hopper_ours": [276.50186737328005, 892.3392366213386, 1175.182220404519, 1028.4590340647019, 785.3882626805608, 59.09378762992622, 703.3066796446836, 1179.1295854731543, 527.3837036247781, 250.04034829354177, 812.4770727119684, 1887.8268789151323], "novel_hopper_baseline": [143.87811352294398, 197.88779149208224, 232.91468664935007, 236.77523679179725, 239.07342762389834, 241.07290814550998, 260.8757582876272, 349.91058747824684, 792.0795706605692, 2403.805055898077]}')
  82.  
  83.  
  84.  
  85. # num_groups = 2
  86. matrix = np.array(data['walker2d_max_number_ours'])
  87. performance = np.array(data['novel_walker_ours'])
  88. matrix_org = np.array(data['walker2d_max_number_baseline'])
  89. performance_org = np.array(data['novel_walker_baseline'])
  90.  
  91.  
  92. name_matrix_perf_dict = {
  93. "ours": [matrix, performance],
  94. "DDPG": [matrix_org, performance_org]
  95. }
  96.  
  97. df = build_merge_dataframe(name_matrix_perf_dict)
  98. draw(df)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement