Advertisement
Guest User

Untitled

a guest
Jul 21st, 2019
107
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.65 KB | None | 0 0
  1. import dgl
  2. import dgl.function as fn
  3.  
  4. import numpy as np
  5. import scipy.sparse as sp
  6.  
  7. import mxnet as mx
  8. from mxnet import gluon, autograd
  9. from mxnet.gluon import nn
  10. import mxnet.ndarray as nd
  11.  
  12.  
  def load_data(data_dir='./reddit-top50'):
      """Load the reddit-top50 post graph, features, labels and split masks.

      Files are read from ``data_dir``. It defaults to the previously
      hard-coded location, so ``load_data()`` reads the same files as before.

      * ``post-feat.txt``: tab-separated. The first line holds ``n`` and
        ``m`` (number of posts, feature dimension). Every other non-blank
        line is a ``row, col, value`` entry of the sparse post-by-feature
        matrix.
      * ``u.npy`` / ``v.npy``: node-id arrays. Entry ``i`` sets position
        ``(u[i], v[i])`` of the adjacency matrix.
      * ``post-labels.txt``: one integer label per non-blank line.

      Returns:
          A tuple ``(g, feat, label, train_mask, test_mask)``:

          * ``g``: DGLGraph built from the ``(n, n)`` adjacency of the
            ``(u, v)`` pairs plus the identity matrix (a self-loop per node).
          * ``feat``: the ``(n, m)`` feature matrix as an MXNet array, with
            every row divided by its row sum. Rows whose sum is 0 are
            divided by 1 instead, so they stay finite rather than NaN/inf.
          * ``label``: int64 numpy array, one entry per label line.
          * ``train_mask``: numpy array of node ids ``0 .. n//2 - 1``.
          * ``test_mask``: numpy array of node ids ``n - 1000 .. n - 1``.

      Raises:
          ValueError: if a number in ``post-feat.txt`` or
              ``post-labels.txt`` cannot be parsed.
      """
      import os  # local import: only used to build the input file paths

      print('Features...')
      row, col, val = [], [], []
      with open(os.path.join(data_dir, 'post-feat.txt'), 'r') as f:
          header = f.readline().strip().split('\t')
          n, m = int(header[0]), int(header[1])
          for line in f:
              if not line.strip():
                  continue  # tolerate blank / trailing lines
              parts = line.strip().split('\t')
              row.append(int(parts[0]))
              col.append(int(parts[1]))
              # float() accepts integer counts as well as real-valued weights.
              val.append(float(parts[2]))
      feat = sp.coo_matrix((val, (row, col)), shape=(n, m), dtype=np.float32)
      feat = nd.sparse.csr_matrix(feat.tocsr(), shape=(n, m), dtype=np.float32)
      # Row-normalize. A post without features has row sum 0, and 0/0 would
      # yield NaN that the sum aggregation in message passing then spreads to
      # every neighbour. Adding 1 to exactly the zero sums avoids that and
      # leaves all other rows' normalization unchanged.
      row_sum = feat.sum(1).reshape(-1, 1)
      row_sum = row_sum + (row_sum == 0)
      feat = feat / row_sum

      print('Generating graph')
      u = np.load(os.path.join(data_dir, 'u.npy'))
      v = np.load(os.path.join(data_dir, 'v.npy'))
      adj = sp.coo_matrix((np.ones((len(u),)), (u, v)), shape=(n, n))
      adj += sp.eye(n, n)  # identity term: one self-loop per node
      g = dgl.DGLGraph(adj)
      print('#Nodes:', g.number_of_nodes())
      print('#Edges:', g.number_of_edges())

      print('Labels...')
      with open(os.path.join(data_dir, 'post-labels.txt'), 'r') as f:
          label = np.array([int(line) for line in f if line.strip()],
                           dtype=np.int64)

      print('Making training/testing masks')
      # NOTE(review): the two ranges overlap (train/test leakage) once n is
      # below ~2000, and test_mask contains negative ids when n < 1000.
      train_mask = np.arange(0, n // 2)
      test_mask = np.arange(n - 1000, n)
      return g, feat, label, train_mask, test_mask
  54.  
  def evaluate(model, g, feats, labels, mask):
      """Return the accuracy of ``model`` on the nodes selected by ``mask``.

      Runs one full forward pass ``model(g, feats)`` and keeps the rows picked
      by ``mask``. The arg-max class of each kept row is compared with the
      matching entry of ``labels[mask]``.

      Returns:
          The fraction of correctly classified masked nodes, converted to a
          Python scalar with ``asscalar()``.
      """
      predicted = model(g, feats)[mask].argmax(axis=1)
      expected = labels[mask]
      n_correct = (predicted == expected).sum()
      return (n_correct / expected.shape[0]).asscalar()
  62.  
  g, feat, label, train_mask, test_mask = load_data()

  # Train on the first GPU when one is usable; otherwise fall back to the CPU
  # instead of crashing. Allocating on an unavailable GPU raises MXNetError.
  try:
      ctx = mx.gpu(0)
      nd.zeros((1,), ctx=ctx)
  except mx.base.MXNetError:
      ctx = mx.cpu()

  # Move features, labels and index masks onto the training device.
  # NOTE(review): nd.array defaults to float32 for non-NDArray input, so the
  # int64 labels and the mask indices become float32 here. This is exact only
  # for values < 2**24; confirm that is acceptable for the dataset size.
  feat = nd.array(feat, ctx=ctx)
  label = nd.array(label, ctx=ctx)
  train_mask = nd.array(train_mask, ctx=ctx)
  test_mask = nd.array(test_mask, ctx=ctx)
  n_train_samples = train_mask.shape[0]

  # Degree normalization (in-degree ** -0.5, with isolated nodes mapped to 0).
  # It is currently disabled and kept here for experimentation; GraphConv
  # would also have to multiply by g.ndata['norm'] for it to take effect.
  #degs = g.in_degrees().astype('float32').asnumpy()
  #norm = np.power(degs, -0.5).reshape(-1, 1)
  #norm[np.isinf(norm)] = 0.
  #norm = nd.array(norm, ctx=ctx)
  #g.ndata['norm'] = norm
  78.  
  class GraphConv(gluon.Block):
      """One graph-convolution layer with a concatenation skip path.

      ``forward`` proceeds in four steps:

      1. Project every node's features with a dense layer.
      2. Sum the projections arriving over each node's incoming edges (one
         round of DGL message passing).
      3. Concatenate the node's own projection with that sum.
      4. Project the concatenation with a second dense layer.

      Args:
          n_in: width of the input node features.
          n_out: output width of both dense projections.
      """

      def __init__(self, n_in, n_out):
          super(GraphConv, self).__init__()
          # Previously n_in was accepted but ignored, and Gluon silently
          # inferred the width from the first batch. Declaring in_units makes
          # a mismatching input width fail loudly.
          self.fc = nn.Dense(n_out, in_units=n_in)
          # fc2 consumes [own projection, neighbour sum] -> 2 * n_out features.
          self.fc2 = nn.Dense(n_out, in_units=2 * n_out)

      def forward(self, g, feats):
          """Return the layer output for every node of ``g`` (``n_out`` columns)."""
          h = self.fc(feats)
          # Degree normalization via g.ndata['norm'] is intentionally not
          # applied here; it is disabled in the script as well.
          g.ndata['h'] = h
          # Send 'h' along every edge and sum the messages at the destination.
          g.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'h'))
          hh = g.ndata.pop('h')  # pop so no stale 'h' is left on the graph
          h = nd.concat(h, hh, dim=1)
          return self.fc2(h)
  92.  
  class GCN(gluon.Block):
      """Two stacked GraphConv layers with a ReLU in between.

      Args:
          n_in: input feature width. When omitted, it falls back to
              ``feat.shape[1]`` of the module-level ``feat``, which is what
              ``GCN()`` always used.
          n_hidden: width of the hidden GraphConv layer (default 64).
          n_classes: number of output logits per node (default 50).
      """

      def __init__(self, n_in=None, n_hidden=64, n_classes=50):
          super(GCN, self).__init__()
          if n_in is None:
              # Backward-compatible fallback on the script's global features.
              n_in = feat.shape[1]
          self.gc1 = GraphConv(n_in, n_hidden)
          self.gc2 = GraphConv(n_hidden, n_classes)

      def forward(self, g, feats):
          """Return ``n_classes`` logits for each row of ``feats``."""
          h = self.gc1(g, feats)
          h = nd.relu(h)
          h = self.gc2(g, h)
          return h
  104.  
  model = GCN()
  model.initialize(ctx=ctx)
  adam_config = {'learning_rate': 0.01, 'wd': 5e-4}
  trainer = gluon.Trainer(model.collect_params(), 'adam', adam_config)
  loss_fcn = gluon.loss.SoftmaxCELoss()

  # Make sure the training inputs live on the same device as the parameters.
  feat = feat.as_in_context(ctx)
  label = label.as_in_context(ctx)

  # Full-batch training. Each epoch runs the whole graph forward, averages the
  # cross-entropy over the training nodes, takes one optimizer step, then
  # reports train/test accuracy.
  for epoch in range(200):
      with autograd.record():
          logits = model(g, feat)
          per_node_loss = loss_fcn(logits[train_mask], label[train_mask])
          loss = per_node_loss.sum() / n_train_samples
      loss.backward()
      # The loss is already a mean, so batch_size=1 avoids rescaling the
      # gradients a second time.
      trainer.step(batch_size=1)

      train_acc = evaluate(model, g, feat, label, train_mask)
      test_acc = evaluate(model, g, feat, label, test_mask)
      report = (epoch, loss.asscalar(), train_acc, test_acc)
      print('Epoch %d, Loss %f, Train acc %f, Test acc %f' % report)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement