Guest User

Untitled

a guest
Jul 16th, 2018
73
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.07 KB | None | 0 0
  1. from collections import defaultdict
  2. import unittest
  3. from unittest import TestCase
  4.  
  5. import numpy as np
  6. from scipy import sparse
  7.  
  8. from hypothesis import given
  9. from hypothesis.extra.numpy import arrays
  10. from hypothesis.extra.numpy import floating_dtypes
  11. from hypothesis.extra.numpy import array_shapes
  12.  
  13.  
  14. def duplicate_indices(X_sparse_r):
  15. ''' Function to determine duplicate row entries of a sparse matrix.
  16.  
  17. Args:
  18. X_sparse_r (Scipy CSR matrix): Machine learning design matrix.
  19.  
  20. Returns:
  21. list: Set of row indices which are another entries duplicates.
  22.  
  23. See unit test for an example
  24. '''
  25. if not isinstance(X_sparse_r, sparse.csr.csr_matrix):
  26. raise ValueError('Must be a CSR Matrix!')
  27.  
  28. drop_row_indices = []
  29. hash_map = defaultdict(bool)
  30. for i in range(X_sparse_r.shape[0]):
  31. row = X_sparse_r.getrow(i)
  32. hsh = hash(row.indices.tostring() + row.data.tostring())
  33. if not hash_map[hsh]:
  34. hash_map[hsh] = True
  35. else:
  36. drop_row_indices.append(i)
  37.  
  38. return drop_row_indices
  39.  
  40.  
  41. class TestDuplicateIndices(TestCase):
  42. def test_duplicate_indices_example(self):
  43. data = [[1, 0], [1, 0]]
  44. X_sparse_r = sparse.csr_matrix(data)
  45. dupe_idx = duplicate_indices(X_sparse_r)
  46.  
  47. self.assertListEqual(dupe_idx, [1])
  48.  
  49. def test_duplicate_indices_nontrivial(self):
  50. data = [[1, 0, 0, 1],
  51. [0, 1, 1, 0],
  52. [1.1, 1.1, 1.2, 0.9],
  53. # Duplicates
  54. [1, 0, 0, 1],
  55. [0, 1, 1, 0],
  56. [1.1, 1.1, 1.2, 0.9]]
  57.  
  58. X_sparse_r = sparse.csr_matrix(data)
  59.  
  60. dupe_idx = duplicate_indices(X_sparse_r)
  61.  
  62. self.assertListEqual(dupe_idx, [3, 4, 5])
  63.  
  64. @given(arrays(floating_dtypes(), array_shapes(2, 2, 1, 10)))
  65. def test_duplicate_indices_random(self, data):
  66. np.random.shuffle(data)
  67.  
  68. # Naive double for loop solution:
  69. X_sparse_r = sparse.csr_matrix(data)
  70. drop_row_indices_naive = []
  71. for i in range(X_sparse_r.shape[0]-1):
  72. row1 = X_sparse_r.getrow(i)
  73. for j in range(i+1, X_sparse_r.shape[0]):
  74. row2 = X_sparse_r.getrow(j)
  75. if not (row1.indices.shape[0] == row2.indices.shape[0]) \
  76. or not (row1.data.shape[0] == row2.data.shape[0]):
  77. continue
  78. if np.all(row1.indices == row2.indices) \
  79. and np.all(row1.data.tostring() == row2.data.tostring()):
  80. drop_row_indices_naive.append(j)
  81. # Expected set of indices
  82. drop_row_indices_naive = list(set(drop_row_indices_naive))
  83.  
  84. # Actual
  85. drop_row_indices = duplicate_indices(X_sparse_r)
  86.  
  87. self.assertEqual(len(drop_row_indices),
  88. len(drop_row_indices_naive))
  89. self.assertSetEqual(set(drop_row_indices),
  90. set(drop_row_indices_naive))
  91.  
  92. def test_non_csr_raises(self):
  93. with self.assertRaises(ValueError):
  94. duplicate_indices('test')
  95.  
  96.  
  97. if __name__ == '__main__':
  98. unittest.main()
Add Comment
Please, Sign In to add comment