Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- from enum import Enum
- from functools import lru_cache
- from typing import Dict, Sequence, Any, Callable, Iterable, Tuple, Type, Generic, TypeVar, Union
# Any one-argument function: takes an input value and returns a transformed one.
Transformer = Callable[[Any], Any]
# Type variable standing for "a member of some Enum subclass".
E = TypeVar("E", bound=Enum)
def name_of(enum_or_name: Union[E, str]) -> str:
    """Normalize an enum member or a member name to the member's name.

    A ``str`` argument is returned unchanged (it is taken to already be a
    name); any other argument is treated as an enum member and its ``name``
    attribute is returned.
    """
    return enum_or_name if isinstance(enum_or_name, str) else enum_or_name.name
class EnumCallable(Generic[E]):
    """Dispatch a value to a transformation function chosen by enum member.

    Functions are registered per enum member and stored keyed by the member's
    ``name``. Lookups accept either the member itself or its name.
    """

    def __init__(
        self,
        enum_type: Type[E],
        enum_func_registry: Union[
            Iterable[Tuple[Union[E, str], Transformer]],
            Dict[Union[E, str], Transformer],
        ],
        validate: bool = True,
    ) -> None:
        """Build the name -> function registry.

        :param enum_type: The enum class whose members are dispatched on.
        :param enum_func_registry: A dict, or an iterable of
            ``(member_or_name, function)`` pairs. Keys may be enum members or
            member names; both are normalized to the name. If the same name
            occurs more than once, the last function wins.
        :param validate: When true, require that ``enum_type`` is an ``Enum``
            subclass, that every registered name is a member of it, and that
            every member has a registered function.
        :raises TypeError: If validating and ``enum_type`` is not an Enum class.
        :raises ValueError: If validating and the registry names a non-member,
            or leaves some member without a function.
        """
        self._enum_type = enum_type
        if validate and not issubclass(self._enum_type, Enum):
            raise TypeError(f"Expecting an enum type, not '{self._enum_type}'")
        pairs = (
            enum_func_registry.items()
            if isinstance(enum_func_registry, dict)
            else enum_func_registry
        )
        self._enum2func = {name_of(enum_val): func for enum_val, func in pairs}
        if validate:
            member_names = [member.name for member in self.values]
            known = set(member_names)
            # Every registered key must name a real member of the enum.
            for enum_name in self._enum2func:
                if enum_name not in known:
                    raise ValueError(
                        f"Unknown enum value name: '{enum_name}' is not a "
                        f"member of '{self._enum_type.__name__}'"
                    )
            # Every member must have a function (checked in definition order
            # so the reported name is deterministic).
            for enum_name in member_names:
                if enum_name not in self._enum2func:
                    raise ValueError(f"No function for enum value named: '{enum_name}'")

    @property
    def values(self) -> tuple:
        """Obtain all distinct enum values, in definition order.

        Iterating an Enum class yields canonical members only (aliases are
        skipped), so each value appears once. This is recomputed on access:
        it is cheap, and ``lru_cache`` on a method would keep the instance
        alive through the cache.
        """
        return tuple(self._enum_type)

    def get_transformer(self, e: Union[E, str]) -> Transformer:
        """Obtain the registered transformation function for an enum value.

        :param e: An enum member or its name.
        :raises KeyError: If no function is registered under that name.
        """
        return self._enum2func[name_of(e)]

    def transform(self, enum_val_or_name: Union[E, str], value: Any) -> Any:
        """Apply the function registered for the enum's specific value.

        :param enum_val_or_name: An enum member or its name.
        :param value: The input passed to the registered function.
        :raises KeyError: If no function is registered under that name.
        """
        return self._enum2func[name_of(enum_val_or_name)](value)

    def __call__(self, *args, **kwargs):
        """Alias for :meth:`transform`."""
        return self.transform(*args, **kwargs)
def test_enum_callable():
    """Exercise EnumCallable lookup by member, by name, via ``__call__``,
    ``get_transformer`` and ``values``, plus registry validation."""
    ET = Enum("ET", ["a", "b", "c"])
    ec = EnumCallable(
        enum_type=ET,
        enum_func_registry={
            ET["a"]: lambda x: x + 1,
            "b": lambda x: x % 2 == 0,
            ET["c"]: lambda x: x * 10,
        },
        validate=True,
    )
    # Lookup by enum member.
    assert ec.transform(ET["a"], 0) == 1
    assert ec.transform(ET["b"], 6) is True
    assert ec.transform(ET["c"], -10) == -100
    # Lookup by member name.
    assert ec.transform("a", -1) == 0
    assert ec.transform("b", 2) is True
    assert ec.transform("b", 3) is False
    assert ec.transform("c", 99) == 990
    # Calling the instance is an alias for transform.
    assert ec(ET["a"], 41) == 42
    assert ec("c", 3) == 30
    # get_transformer hands back the registered function.
    assert ec.get_transformer(ET["c"])(2) == 20
    # values lists each member once, in definition order.
    assert ec.values == (ET["a"], ET["b"], ET["c"])
    # The iterable-of-pairs form works and rejects a non-member name.
    try:
        EnumCallable(ET, [("a", abs), ("b", abs), ("c", abs), ("d", abs)])
    except ValueError:
        pass
    else:
        raise AssertionError("expected ValueError for unknown member 'd'")
    # Without validation, arbitrary names are accepted as-is.
    loose = EnumCallable(ET, {"zzz": abs}, validate=False)
    assert loose.transform("zzz", -3) == 3
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement