Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- from typing import Any, Callable, Optional, Type, TypeVar
- from mypy.plugin import Plugin, FunctionContext # pylint: disable=no-name-in-module
- from mypy.types import TypedDictType # pylint: disable=no-name-in-module
A = TypeVar('A')


def safe_typeddict_cast(_: Type[A], data: Any) -> A:
    """Return *data* unchanged, presented to the type checker as type ``A``.

    At runtime this is the identity function: the first argument (the target
    type) is ignored and ``data`` is handed straight back. All of the real
    work happens at type-check time in the mypy plugin (mypy_plugin.py),
    which only gives the call the target type when ``data``'s type passes
    its checks. Otherwise it reports errors describing the offending keys,
    such as missing keys or keys with incompatible types.
    """
    passthrough = data
    return passthrough
def _report_keys(
    context: FunctionContext,
    keys: 'list[str]',
    singular: str,
    plural: str,
    target_name: str,
) -> None:
    """Report *keys* as a single mypy error at the call site; do nothing if empty.

    ``singular`` and ``plural`` are ``str.format`` templates with a ``{keys}``
    field (the double-quoted, comma-separated key names) and a ``{target}``
    field (the target type name). ``singular`` is used for exactly one key.
    """
    if not keys:
        return
    template = singular if len(keys) == 1 else plural
    quoted_keys = '"' + '", "'.join(keys) + '"'
    context.api.fail(template.format(keys=quoted_keys, target=target_name), context.context)


def _resolve_target(context: FunctionContext) -> Any:
    """Return ``(typeddict_type, type_name)`` for the call's first argument, or ``None``.

    The first argument should be a TypedDict class object, which mypy types as
    a callable. Its return type may be the ``TypedDictType`` itself or an
    ``Instance`` whose ``TypeInfo`` carries ``typeddict_type``; this has
    varied between mypy versions, so both shapes are accepted. Anything else
    yields ``None``.
    """
    # pylint: disable=no-name-in-module
    from mypy.types import CallableType, Instance, get_proper_type

    callee = get_proper_type(context.arg_types[0][0])
    if not isinstance(callee, CallableType):
        return None
    ret_type = get_proper_type(callee.ret_type)
    if isinstance(ret_type, TypedDictType):
        target_type, type_name = ret_type, ret_type.fallback.type.name
    elif isinstance(ret_type, Instance) and ret_type.type.typeddict_type is not None:
        target_type, type_name = ret_type.type.typeddict_type, ret_type.type.name
    else:
        return None
    # Prefer the name as written at the call site (e.g. an alias); fall back to
    # the TypeInfo name when the argument expression has no ``name``.
    written_name = getattr(context.args[0][0], 'name', None)
    return target_type, written_name or type_name


class TypedDictCastPlugin(Plugin):
    """mypy plugin that type-checks calls to ``safe_typeddict_cast(Target, data)``.

    The call is given the type ``Target`` only when ``data``'s TypedDict type
    is compatible with it. Otherwise errors are reported at the call site and
    the call keeps the input data's type.
    """

    @staticmethod
    def get_function_hook(fullname: str) -> Optional[Callable[[FunctionContext], Any]]:
        """Return the checking hook for ``safe_typeddict_cast``; ``None`` otherwise."""
        # Match the exact function name. A substring test would also hook
        # unrelated functions such as ``safe_typeddict_cast_v2``.
        if not fullname.endswith('.safe_typeddict_cast'):
            return None

        def convert_data_type(context: FunctionContext) -> Any:
            """Infer the call's type: the target TypedDict if every check passes."""
            # pylint: disable=no-name-in-module
            from mypy.types import get_proper_type

            # If an argument is missing, mypy reports the arity error itself.
            if not context.arg_types[0] or not context.arg_types[1]:
                return context.default_return_type

            target = _resolve_target(context)
            if target is None:
                context.api.fail(
                    'First argument to safe_typeddict_cast must be a TypedDict type',
                    context.context,
                )
                return context.default_return_type
            target_type, target_type_name = target
            target_items = target_type.items

            raw_input_type = context.arg_types[1][0]
            input_type = get_proper_type(raw_input_type)
            if not isinstance(input_type, TypedDictType):
                # Without key/type information there is nothing to check safely.
                context.api.fail(
                    'Input data type must be a TypedDict to be checked '
                    f'against type "{target_type_name}"',
                    context.context,
                )
                return raw_input_type
            input_items = input_type.items

            # Every input key must exist on the target with exactly the same type.
            # NOTE: `!=` is type equality, which is stricter than subtyping.
            # TODO: Recursively support checking nested typed dicts for better errors?
            extra_keys = [key for key in input_items if key not in target_items]
            mistyped_keys = [
                key
                for key, value in input_items.items()
                if key in target_items and value != target_items[key]
            ]
            # Every key the target requires must be present on the input.
            # NOTE: a key present on both sides that is optional on the input
            # but required on the target is not reported.
            missing_keys = [
                key
                for key in target_items
                if key in target_type.required_keys and key not in input_items
            ]

            _report_keys(
                context,
                extra_keys,
                'Input data type has extra key {keys} which is missing on type "{target}"',
                'Input data type has extra keys {keys} which are missing on type "{target}"',
                target_type_name,
            )
            _report_keys(
                context,
                mistyped_keys,
                'Input data type has key {keys} which has incompatible type on type "{target}"',
                'Input data type has keys {keys} which have incompatible types on type "{target}"',
                target_type_name,
            )
            _report_keys(
                context,
                missing_keys,
                'Input data type is missing key {keys} which is required on type "{target}"',
                'Input data type is missing keys {keys} which are required on type "{target}"',
                target_type_name,
            )

            # If any check failed, keep the input type instead of converting it.
            if extra_keys or mistyped_keys or missing_keys:
                return raw_input_type
            return target_type

        return convert_data_type
def plugin(_: str) -> Type[TypedDictCastPlugin]:
    """mypy plugin entry point: return the plugin class.

    The version string mypy passes in is ignored; the same class is returned
    for every mypy version.
    """
    plugin_class = TypedDictCastPlugin
    return plugin_class
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement