Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import asyncio
- import logging
- from enum import Enum
- from typing import Any, AsyncIterable, AsyncIterator, Callable, Generator, Generic, TypeVar
# Module logger; named after this module so callers can tune its level.
logger = logging.getLogger(__name__)

# NOTE(review): configuring the root logger at import time affects every
# program that imports this module; consider moving this call under the
# ``if __name__ == "__main__":`` guard.
logging.basicConfig(level=logging.DEBUG)

# Item type produced by the async iterable an Operation consumes.
T = TypeVar("T")
class OperationState(Enum):
    """Lifecycle states of an ``Operation``."""

    CANCELED = -1  # cancel() was called while running or suspended
    IDLE = 0  # created, run() not called yet
    RUNNING = 1  # runner task started and not paused
    SUSPENDED = 2  # paused by suspend(); waiting for resume()
    DONE = 3  # the iterable was fully consumed
class InvalidOperationStateError(RuntimeError):
    """Raised when an ``Operation`` method is called in a state that forbids it
    (e.g. resuming an operation that is not suspended)."""
class Operation(Generic[T]):
    """Consume an async iterable in a background task that can be suspended,
    resumed, cancelled and awaited.

    Every item produced by ``async_iterator`` is passed to
    ``progress_callback`` (when given). Suspension is cooperative and happens
    between items: the item currently being fetched still reaches the progress
    callback, then the runner blocks until :meth:`resume` or :meth:`cancel` is
    called. When the iterable is exhausted the state becomes
    ``OperationState.DONE`` and ``done_callback`` (when given) is called.

    ``await operation`` is equivalent to ``await operation.wait()``.

    Args:
        async_iterator: the async iterable to consume.
        progress_callback: called synchronously with each produced item.
        done_callback: called with no arguments once the iterable is exhausted.
    """

    def __init__(
        self,
        async_iterator: AsyncIterable[T],
        *,
        progress_callback: Callable[[T], Any] | None = None,
        done_callback: Callable[[], Any] | None = None,
    ):
        self._async_iterator = async_iterator
        self._progress_callback = progress_callback
        self._done_callback = done_callback
        self._state = OperationState.IDLE
        # Cleared by suspend(), set by resume(); the runner waits on it while
        # the state is SUSPENDED.
        self._resume_event = asyncio.Event()
        self._runner_task: asyncio.Task[None] | None = None

    @property
    def state(self) -> OperationState:
        """Current lifecycle state."""
        return self._state

    def run(self) -> None:
        """Start consuming the iterable in a new asyncio task.

        Must be called while an event loop is running (uses
        ``asyncio.create_task``).

        Raises:
            InvalidOperationStateError: if the state is not ``IDLE``.
        """
        if self._state != OperationState.IDLE:
            raise InvalidOperationStateError("Operation is already started")
        self._runner_task = asyncio.create_task(self._runner())
        self._state = OperationState.RUNNING
        logger.debug("Operation is runned.")

    def suspend(self) -> None:
        """Pause the operation after the item currently in flight.

        Raises:
            InvalidOperationStateError: if the state is not ``RUNNING``.
        """
        if self._state != OperationState.RUNNING:
            raise InvalidOperationStateError("Operation is not running")
        self._state = OperationState.SUSPENDED
        self._resume_event.clear()
        logger.debug("Operation is suspended.")

    def resume(self) -> None:
        """Continue a suspended operation.

        Raises:
            InvalidOperationStateError: if the state is not ``SUSPENDED``.
        """
        if self._state != OperationState.SUSPENDED:
            raise InvalidOperationStateError("Operation is not suspended")
        self._state = OperationState.RUNNING
        self._resume_event.set()
        logger.debug("Operation is resumed.")

    def cancel(self) -> None:
        """Cancel the runner task and mark the operation ``CANCELED``.

        Raises:
            InvalidOperationStateError: if the state is neither ``RUNNING``
                nor ``SUSPENDED``.
        """
        if self._state not in (OperationState.RUNNING, OperationState.SUSPENDED):
            raise InvalidOperationStateError("Operation is not running")
        if self._runner_task is not None:
            self._runner_task.cancel()
        self._state = OperationState.CANCELED
        logger.debug("Operation is canceled.")

    async def wait(self) -> None:
        """Wait until the runner task has finished.

        Returns normally if the iterable was exhausted or the runner task was
        cancelled.

        NOTE(review): if the runner fails, the state is left as it was
        (RUNNING/SUSPENDED); there is no dedicated failure state.

        Raises:
            InvalidOperationStateError: if :meth:`run` was never called.
            Exception: whatever the iterable or a callback raised inside the
                runner task, re-raised here so failures are not lost.
        """
        task = self._runner_task
        if task is None:
            raise InvalidOperationStateError("Operation is not started")
        # asyncio.wait (instead of awaiting the task directly) does not raise
        # CancelledError when the runner was cancelled through cancel().
        await asyncio.wait({task})
        if task.cancelled():
            return
        error = task.exception()
        if error is not None:
            raise error

    async def _runner(self) -> None:
        """Consume the iterable, honouring suspension between items."""
        async for target in self._async_iterator:
            if self._progress_callback is not None:
                self._progress_callback(target)
            # A loop, not a single wait: resume() followed by another suspend()
            # before this task runs again leaves a stale wake-up behind, and the
            # operation must then stay paused.
            while self._state == OperationState.SUSPENDED:
                await self._resume_event.wait()
        self._state = OperationState.DONE
        logger.debug("Operation is done.")
        if self._done_callback is not None:
            self._done_callback()

    def __await__(self) -> Generator[Any, None, None]:
        """Allow ``await operation`` as shorthand for ``await operation.wait()``."""
        return self.wait().__await__()
- ######
class ConcreteAsyncGenerator:
    """Demo async iterator yielding ``start + 1`` through ``stop`` inclusive.

    Each step awaits ``asyncio.sleep(delay)`` to simulate asynchronous work.
    The object is its own iterator, so it can be consumed only once.

    Args:
        start: value just before the first one produced (the first yielded
            value is ``start + 1``).
        stop: last value produced; nothing is yielded if ``start >= stop``.
        delay: seconds to sleep before producing each value (default 0.1).
    """

    def __init__(self, start: int, stop: int, *, delay: float = 0.1):
        self.current = start
        self.start = start
        self.stop = stop
        self.delay = delay

    def __aiter__(self) -> AsyncIterator[int]:
        return self

    async def __anext__(self) -> int:
        if self.current >= self.stop:
            raise StopAsyncIteration
        # Advance before sleeping so each call claims its value immediately.
        self.current += 1
        await asyncio.sleep(self.delay)
        return self.current
def concrete_progress_callback(process: int) -> None:
    """Report progress for one item produced by the operation.

    Demo callback; a real application would update a UI or log here.

    Args:
        process: the item just produced by the iterable.
    """
    # Lazy %-style arguments: the message is only built if INFO is enabled.
    logger.info("Iteration complete with: %s...", process)
def concrete_done_callback() -> None:
    """Announce that the operation consumed every item.

    Demo callback; a real application would handle results and errors here.
    """
    message = "All iterations complete!"
    logger.info(message)
async def main() -> None:
    """Demonstrate the Operation lifecycle: run, suspend, resume, then wait."""
    op = Operation(
        async_iterator=ConcreteAsyncGenerator(0, 100),
        progress_callback=concrete_progress_callback,
        done_callback=concrete_done_callback,
    )
    op.run()

    await asyncio.sleep(1)  # let some items through
    op.suspend()

    await asyncio.sleep(2)  # stay paused for a while
    op.resume()

    await asyncio.sleep(1)
    # Calling op.cancel() at this point would stop the run early instead.
    await op
    print("done")
if __name__ == "__main__":
    # Script entry point: run the demo coroutine on a fresh event loop.
    asyncio.run(main())
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement