Not a member of Pastebin yet?
Sign up —
it unlocks many cool features!
- from collections import defaultdict
- import unittest
- from unittest import TestCase
- import numpy as np
- from scipy import sparse
- from hypothesis import given
- from hypothesis.extra.numpy import arrays
- from hypothesis.extra.numpy import floating_dtypes
- from hypothesis.extra.numpy import array_shapes
def duplicate_indices(X_sparse_r):
    """Find rows of a CSR matrix that repeat an earlier row.

    Rows are compared on their stored column indices and stored values
    (bitwise, via ``tobytes``). The first occurrence of each distinct row
    is kept; every later occurrence is reported. Explicitly stored zeros
    count as entries, so a row with a stored 0 differs from one without it.

    Args:
        X_sparse_r (scipy CSR matrix or array): Machine learning design
            matrix. Any scipy sparse object whose ``format`` is ``'csr'``
            is accepted. The input is never modified.

    Returns:
        list: Ascending indices of rows that duplicate an earlier row.
        For example, ``[[1, 0], [1, 0]]`` yields ``[1]``.

    Raises:
        ValueError: If ``X_sparse_r`` is not a CSR sparse matrix.
    """
    # Accept both csr_matrix and the newer csr_array. The old
    # ``sparse.csr.csr_matrix`` path lives in a namespace SciPy deprecated.
    if not (sparse.issparse(X_sparse_r) and X_sparse_r.format == 'csr'):
        raise ValueError('Must be a CSR Matrix!')

    X = X_sparse_r
    if not X.has_canonical_format:
        # Sort column indices and merge repeated entries on a copy, so that
        # equal rows stored in a different order still compare equal
        # without mutating the caller's matrix.
        X = X.copy()
        X.sum_duplicates()

    indptr, indices, data = X.indptr, X.indices, X.data
    seen = set()
    drop_row_indices = []
    for i in range(X.shape[0]):
        start, end = indptr[i], indptr[i + 1]
        # Key on the raw bytes themselves, not on hash(): distinct rows whose
        # hashes collide must not be reported as duplicates.
        key = (indices[start:end].tobytes(), data[start:end].tobytes())
        if key in seen:
            drop_row_indices.append(i)
        else:
            seen.add(key)
    return drop_row_indices
class TestDuplicateIndices(TestCase):
    """Unit and property-based tests for ``duplicate_indices``."""

    def test_duplicate_indices_example(self):
        # Two identical rows: only the second occurrence is reported.
        data = [[1, 0], [1, 0]]
        X_sparse_r = sparse.csr_matrix(data)
        dupe_idx = duplicate_indices(X_sparse_r)
        self.assertListEqual(dupe_idx, [1])

    def test_duplicate_indices_nontrivial(self):
        data = [[1, 0, 0, 1],
                [0, 1, 1, 0],
                [1.1, 1.1, 1.2, 0.9],
                # Duplicates
                [1, 0, 0, 1],
                [0, 1, 1, 0],
                [1.1, 1.1, 1.2, 0.9]]
        X_sparse_r = sparse.csr_matrix(data)
        dupe_idx = duplicate_indices(X_sparse_r)
        self.assertListEqual(dupe_idx, [3, 4, 5])

    # scipy.sparse rejects float16 and non-native byte orders, so restrict
    # the generated dtypes to native float32/float64. array_shapes takes
    # keyword-only arguments in current Hypothesis.
    @given(arrays(floating_dtypes(sizes=(32, 64), endianness='='),
                  array_shapes(min_dims=2, max_dims=2,
                               min_side=1, max_side=10)))
    def test_duplicate_indices_random(self, data):
        # Append the rows in reverse so every example contains duplicates;
        # random data alone almost never repeats a non-zero row.
        data = np.concatenate([data, data[::-1]])
        np.random.shuffle(data)
        X_sparse_r = sparse.csr_matrix(data)

        # Reference answer: naive O(n^2) pairwise comparison of the stored
        # indices and the stored values' raw bytes.
        rows = [X_sparse_r.getrow(i) for i in range(X_sparse_r.shape[0])]
        expected = set()
        for i, row1 in enumerate(rows[:-1]):
            for j in range(i + 1, len(rows)):
                row2 = rows[j]
                if row1.indices.shape != row2.indices.shape:
                    continue
                if (np.array_equal(row1.indices, row2.indices)
                        and row1.data.tobytes() == row2.data.tobytes()):
                    expected.add(j)

        actual = duplicate_indices(X_sparse_r)
        # Each duplicate row must be reported exactly once.
        self.assertEqual(len(actual), len(set(actual)))
        self.assertSetEqual(set(actual), expected)

    def test_non_csr_raises(self):
        with self.assertRaises(ValueError):
            duplicate_indices('test')
if __name__ == "__main__":
    # Executing this file directly runs the unittest suite defined above.
    unittest.main()
Add Comment
Please sign in to add a comment.