import asyncio import logging from enum import Enum from typing import Any, AsyncIterable, AsyncIterator, Callable, Generator, Generic, TypeVar T = TypeVar("T") logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) class OperationState(Enum): CANCELED = -1 IDLE = 0 RUNNING = 1 SUSPENDED = 2 DONE = 3 class InvalidOperationStateError(RuntimeError): """Raised when an `Operation` enters an invalid state.""" pass class Operation(Generic[T]): def __init__( self, async_iterator: AsyncIterable[T], *, progress_callback: Callable[[T], Any] | None = None, done_callback: Callable[[], Any] | None = None, ): self._async_iterator = async_iterator self._progress_callback = progress_callback self._done_callback = done_callback self._state = OperationState.IDLE self._resume_event = asyncio.Event() self._runner_task: asyncio.Task[None] | None = None @property def state(self) -> OperationState: return self._state def run(self) -> None: if self._state == OperationState.IDLE: self._runner_task = asyncio.create_task(self._runner()) self._state = OperationState.RUNNING logger.debug("Operation is runned.") else: raise InvalidOperationStateError("Operation is already started") def suspend(self) -> None: if self._state == OperationState.RUNNING: self._state = OperationState.SUSPENDED self._resume_event.clear() logger.debug("Operation is suspended.") else: raise InvalidOperationStateError("Operation is not running") def resume(self) -> None: if self._state == OperationState.SUSPENDED: self._state = OperationState.RUNNING self._resume_event.set() logger.debug("Operation is resumed.") else: raise InvalidOperationStateError("Operation is not suspended") def cancel(self) -> None: if self._state in (OperationState.RUNNING, OperationState.SUSPENDED): if self._runner_task is not None: self._runner_task.cancel() self._state = OperationState.CANCELED logger.debug("Operation is canceled.") else: raise InvalidOperationStateError("Operation is not running") async def wait(self) -> None: if self._runner_task is not None: await asyncio.wait({self._runner_task}) else: raise InvalidOperationStateError("Operation is not started") async def _runner(self) -> None: async for target in self._async_iterator: if self._progress_callback is not None: self._progress_callback(target) if self._state == OperationState.SUSPENDED: await self._resume_event.wait() self._resume_event.clear() self._state = OperationState.DONE logger.debug("Operation is done.") if self._done_callback is not None: self._done_callback() def __await__(self) -> Generator[Any, None, None]: return self.wait().__await__() ###### class ConcreteAsyncGenerator: def __init__(self, start: int, stop: int): self.current = start self.start = start self.stop = stop def __aiter__(self) -> AsyncIterator[int]: return self async def __anext__(self) -> int: if self.current < self.stop: self.current += 1 await asyncio.sleep(0.1) return self.current raise StopAsyncIteration def concrete_progress_callback(process: int) -> None: # UI, log or something update logger.info(f"Iteration complete with: {process}...") def concrete_done_callback() -> None: # result and errors handling logger.info("All iterations complete!") async def main() -> None: operation = Operation( async_iterator=ConcreteAsyncGenerator(0, 100), progress_callback=concrete_progress_callback, done_callback=concrete_done_callback, ) operation.run() await asyncio.sleep(1) operation.suspend() await asyncio.sleep(2) operation.resume() await asyncio.sleep(1) # operation.cancel() await operation print("done") if __name__ == "__main__": asyncio.run(main())