Advertisement
Guest User

Untitled

a guest
May 24th, 2019
99
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.59 KB | None | 0 0
  1. from typing import Any, Callable, Optional, Type, TypeVar
  2. from mypy.plugin import Plugin, FunctionContext # pylint: disable=no-name-in-module
  3. from mypy.types import TypedDictType # pylint: disable=no-name-in-module
  4.  
  5. A = TypeVar('A')
  6.  
  7. def safe_typeddict_cast(_: Type[A], data: Any) -> A:
  8. # This function does nothing on its own, but the types are checked in mypy_plugin.py.
  9. # The goal is to safely convert the data argument to the type of the first argument.
  10. # The type will only be converted if it passes type checks. If it doesn't pass,
  11. # error messages will tell you about missing keys or incompatible key types.
  12. return data
  13.  
  14.  
  15. class TypedDictCastPlugin(Plugin):
  16. @staticmethod
  17. def get_function_hook(fullname: str) -> Optional[Callable[[FunctionContext], TypedDictType]]:
  18. def convert_data_type(context: FunctionContext) -> TypedDictType:
  19. # Compare the types
  20. target_type = context.arg_types[0][0].ret_type.type.typeddict_type # type: ignore
  21. target_type_name = context.args[0][0].name # type: ignore
  22. target_type_items = target_type.items
  23. input_type = context.arg_types[1][0]
  24. input_type_items = input_type.items # type: ignore
  25.  
  26. # All keys in the input type must be present and subtyped on the target type
  27. missing_keys_in_target_type = []
  28. mistyped_keys_in_target_type = []
  29. for key, value in input_type_items.items():
  30. if key not in target_type_items:
  31. missing_keys_in_target_type.append(key)
  32. elif value != target_type_items[key]:
  33. # TODO: Recursively support checking nested typed dicts for better errors?
  34. mistyped_keys_in_target_type.append(key)
  35.  
  36. if missing_keys_in_target_type:
  37. if len(missing_keys_in_target_type) == 1:
  38. context.api.fail(
  39. f'Input data type has extra key "{missing_keys_in_target_type[0]}" '
  40. f'which is missing on type "{target_type_name}"', context.context
  41. )
  42. else:
  43. joined_keys = '"' + '", "'.join(missing_keys_in_target_type) + '"'
  44. context.api.fail(
  45. f'Input data type has extra keys {joined_keys} which are missing on type "{target_type_name}"',
  46. context.context
  47. )
  48.  
  49. if mistyped_keys_in_target_type:
  50. if len(mistyped_keys_in_target_type) == 1:
  51. context.api.fail(
  52. f'Input data type has key "{mistyped_keys_in_target_type[0]}" '
  53. f'which has incompatible type on type {target_type_name}', context.context
  54. )
  55. else:
  56. joined_keys = '"' + '", "'.join(mistyped_keys_in_target_type) + '"'
  57. context.api.fail(
  58. f'Input data type has keys {joined_keys} '
  59. f'which have incompatible types on type "{target_type_name}"', context.context
  60. )
  61.  
  62. # TODO: All required keys in the target type must be present and subtyped on the input type
  63.  
  64. # If any type checks failed, don't convert the type
  65. if missing_keys_in_target_type or mistyped_keys_in_target_type:
  66. return input_type # type: ignore
  67.  
  68. return target_type
  69.  
  70. if '.safe_typeddict_cast' in fullname:
  71. return convert_data_type
  72.  
  73. return None
  74.  
  75.  
  76. def plugin(_: str) -> Type[TypedDictCastPlugin]:
  77. return TypedDictCastPlugin
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement