Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import collections
- import typing
T = typing.TypeVar('T')


class ChainableIterator(typing.Iterator[T]):
    """Iterator yielding the items of several iterables in sequence.

    Behaves like :func:`itertools.chain`, except that more iterables can be
    appended later with :meth:`chain` -- even after the iterator has raised
    ``StopIteration``, in which case iteration resumes with the new items.

    Example::

        it = ChainableIterator((1, 2), (3,))
        it.chain((4, 5))
        list(it)  # [1, 2, 3, 4, 5]
    """

    def __init__(self, *iterators: typing.Iterable[T]) -> None:
        """Queue *iterators* (any iterables) to be consumed in order.

        ``iter()`` is applied to every argument immediately, so a
        non-iterable argument raises ``TypeError`` here rather than later.
        """
        # A deque lets exhausted iterators be dropped from the front in O(1).
        self._iterators = collections.deque(iter(it) for it in iterators)

    def __next__(self) -> T:
        """Return the next item, moving past any exhausted iterables.

        Raises:
            StopIteration: when every queued iterable is exhausted. A later
                call to :meth:`chain` makes the iterator productive again.
        """
        # Loop instead of recursing, so that a long run of empty iterables
        # cannot exceed the interpreter's recursion limit.
        while self._iterators:
            try:
                return next(self._iterators[0])
            except StopIteration:
                # Drop the finished iterator so it is neither retained nor
                # polled again on subsequent calls.
                self._iterators.popleft()
        raise StopIteration

    def chain(self, *iterators: typing.Iterable[T]) -> 'ChainableIterator[T]':
        """Append *iterators* to the end of the queue.

        Returns:
            ``self``, so that calls can be chained fluently.
        """
        self._iterators.extend(iter(it) for it in iterators)
        return self
if __name__ == '__main__':
    # Self-test. Uses explicit checks rather than bare ``assert`` statements:
    # ``python -O`` strips asserts -- including any ``next()`` call placed
    # inside one -- which would make the test silently exercise nothing.
    iterator = ChainableIterator((1, 2, 3), (4, 5, 6))
    iterator.chain((7, 8, 9), (10, 11, 12))
    for i in range(1, 13):
        value = next(iterator)
        if value != i:
            raise AssertionError('expected %r, got %r' % (i, value))
    try:
        extra = next(iterator)
    except StopIteration:
        pass
    else:
        raise AssertionError('iterator not exhausted; got extra %r' % (extra,))
Add Comment
Please, Sign In to add comment