SHARE
TWEET

Untitled

a guest Jul 21st, 2019 71 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import dgl
  2. import dgl.function as fn
  3.  
  4. import numpy as np
  5. import scipy.sparse as sp
  6.  
  7. import mxnet as mx
  8. from mxnet import gluon, autograd
  9. from mxnet.gluon import nn
  10. import mxnet.ndarray as nd
  11.  
  12.  
  13. def load_data():
  14.     print('Features...')
  15.     n = 0
  16.     m = 0
  17.     row = []
  18.     col = []
  19.     val = []
  20.     with open('./reddit-top50/post-feat.txt', 'r') as f:
  21.         for i, l in enumerate(f):
  22.             parts = l.strip().split('\t')
  23.             if i == 0:
  24.                 n = int(parts[0])
  25.                 m = int(parts[1])
  26.             else:
  27.                 row.append(int(parts[0]))
  28.                 col.append(int(parts[1]))
  29.                 val.append(int(parts[2]))
  30.     feat = sp.coo_matrix((val, (row, col)), shape=(n, m), dtype=np.float32)
  31.     feat = nd.sparse.csr_matrix(feat.tocsr(), shape=(n, m), dtype=np.float32)
  32.     feat = feat / feat.sum(1).reshape(-1,1)
  33.  
  34.     print('Generating graph')
  35.     u = np.load('./reddit-top50/u.npy')
  36.     v = np.load('./reddit-top50/v.npy')
  37.     adj = sp.coo_matrix((np.ones((len(u),)), (u, v)), shape=(n, n))
  38.     adj += sp.eye(n, n)
  39.     g = dgl.DGLGraph(adj)
  40.     print('#Nodes:', g.number_of_nodes())
  41.     print('#Edges:', g.number_of_edges())
  42.  
  43.     print('Labels...')
  44.     label = []
  45.     with open('./reddit-top50/post-labels.txt', 'r') as f:
  46.         for i, l in enumerate(f):
  47.             label.append(int(l.strip()))
  48.     label = np.array(label, dtype=np.int64)
  49.  
  50.     print('Making training/testing masks')
  51.     train_mask = np.arange(0, n//2)
  52.     test_mask = np.arange(n - 1000, n)
  53.     return g, feat, label, train_mask, test_mask
  54.  
  55. def evaluate(model, g, feats, labels, mask):
  56.     logits = model(g, feats)
  57.     logits = logits[mask]
  58.     labels = labels[mask]
  59.     indices = logits.argmax(axis=1)
  60.     accuracy = (indices == labels).sum() / labels.shape[0]
  61.     return accuracy.asscalar()
  62.  
# Load the dataset, then move features, labels, and masks onto the GPU.
g, feat, label, train_mask, test_mask = load_data()

ctx = mx.gpu(0)
feat = nd.array(feat, ctx=ctx)
label = nd.array(label, ctx=ctx)
train_mask = nd.array(train_mask, ctx=ctx)
test_mask = nd.array(test_mask, ctx=ctx)
n_train_samples = train_mask.shape[0]

# calculate normalization
# NOTE(review): symmetric degree normalization (D^-1/2) is disabled; the
# GraphConv layers below sum unnormalized neighbour messages instead.
#degs = g.in_degrees().astype('float32').asnumpy()
#norm = np.power(degs, -0.5).reshape(-1, 1)
#norm[np.isinf(norm)] = 0.
#norm = nd.array(norm, ctx=ctx)
#g.ndata['norm'] = norm
  78.  
  79. class GraphConv(gluon.Block):
  80.     def __init__(self, n_in, n_out):
  81.         super(GraphConv, self).__init__()
  82.         self.fc = nn.Dense(n_out)
  83.         self.fc2 = nn.Dense(n_out)
  84.  
  85.     def forward(self, g, feats):
  86.         h = self.fc(feats)
  87.         g.ndata['h'] = h #* g.ndata['norm']
  88.         g.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'h'))
  89.         hh = g.ndata.pop('h') #* g.ndata['norm']
  90.         h = nd.concat(h, hh, dim=1)
  91.         return self.fc2(h)
  92.  
  93. class GCN(gluon.Block):
  94.     def __init__(self):
  95.         super(GCN, self).__init__()
  96.         self.gc1 = GraphConv(feat.shape[1], 64)
  97.         self.gc2 = GraphConv(64, 50)
  98.  
  99.     def forward(self, g, feats):
  100.         h = self.gc1(g, feats)
  101.         h = nd.relu(h)
  102.         h = self.gc2(g, h)
  103.         return h
  104.  
# Build the model on the GPU and set up Adam with weight decay.
model = GCN()
model.initialize(ctx=ctx)
trainer = gluon.Trainer(model.collect_params(), 'adam',
                        {'learning_rate': 0.01, 'wd': 5e-4})
loss_fcn = gluon.loss.SoftmaxCELoss()

# NOTE(review): feat and label were already created with ctx=ctx above,
# so these two transfers appear redundant — confirm before removing.
feat = feat.as_in_context(ctx)
label = label.as_in_context(ctx)

# Full-batch training: every epoch runs the whole graph forward once.
for epoch in range(200):
    with autograd.record():
        logits = model(g, feat)
        # Mean cross-entropy over the training nodes only.
        loss = loss_fcn(logits[train_mask], label[train_mask]).sum() / n_train_samples

    loss.backward()
    # batch_size=1 because the loss is already averaged over the batch.
    trainer.step(batch_size=1)

    train_acc = evaluate(model, g, feat, label, train_mask)
    test_acc = evaluate(model, g, feat, label, test_mask)
    print('Epoch %d, Loss %f, Train acc %f, Test acc %f' % (epoch, loss.asscalar(), train_acc, test_acc))
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
Not a member of Pastebin yet?
Sign Up, it unlocks many cool features!
 
Top