|
1 | 1 | from typing import Callable
|
2 | 2 | from typing import Optional
|
3 | 3 |
|
| 4 | +import mypy.nodes |
| 5 | +import mypy.plugin |
4 | 6 | import mypy.types
|
5 |
| -from mypy.nodes import NameExpr |
6 |
| -from mypy.nodes import TypeInfo |
7 |
| -from mypy.plugin import FunctionContext |
8 |
| -from mypy.plugin import Plugin |
9 | 7 |
|
10 | 8 | ATTR_FULL_NAME = 'pynamodb.attributes.Attribute'
|
11 |
| -NULL_ATTR_WRAPPER_FULL_NAME = 'pynamodb.attributes._NullableAttributeWrapper' |
12 | 9 |
|
13 | 10 |
|
14 |
| -class PynamodbPlugin(Plugin): |
15 |
| - def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext], mypy.types.Type]]: |
| 11 | +class PynamodbPlugin(mypy.plugin.Plugin): |
| 12 | + def get_function_hook(self, fullname: str) -> Optional[Callable[[mypy.plugin.FunctionContext], mypy.types.Type]]: |
16 | 13 | sym = self.lookup_fully_qualified(fullname)
|
17 |
| - if sym and isinstance(sym.node, TypeInfo): |
18 |
| - attr_underlying_type = _get_attribute_underlying_type(sym.node) |
19 |
| - if attr_underlying_type: |
20 |
| - _underlying_type = attr_underlying_type # https://github.com/python/mypy/issues/4297 |
21 |
| - return lambda ctx: _attribute_instantiation_hook(ctx, _underlying_type) |
| 14 | + if sym and isinstance(sym.node, mypy.nodes.TypeInfo) and _is_attribute_type_node(sym.node): |
| 15 | + return _attribute_instantiation_hook |
| 16 | + return None |
22 | 17 |
|
| 18 | + def get_method_signature_hook(self, fullname: str |
| 19 | + ) -> Optional[Callable[[mypy.plugin.MethodSigContext], mypy.types.CallableType]]: |
| 20 | + class_name, method_name = fullname.rsplit('.', 1) |
| 21 | + sym = self.lookup_fully_qualified(class_name) |
| 22 | + if sym is not None and sym.node is not None and _is_attribute_type_node(sym.node): |
| 23 | + if method_name == '__get__': |
| 24 | + return _get_method_sig_hook |
| 25 | + elif method_name == '__set__': |
| 26 | + return _set_method_sig_hook |
23 | 27 | return None
|
24 | 28 |
|
25 | 29 |
|
26 |
| -def _get_attribute_underlying_type(attribute_class: TypeInfo) -> Optional[mypy.types.Type]: |
| 30 | +def _is_attribute_type_node(node: mypy.nodes.Node) -> bool: |
| 31 | + return ( |
| 32 | + isinstance(node, mypy.nodes.TypeInfo) and |
| 33 | + node.has_base(ATTR_FULL_NAME) |
| 34 | + ) |
| 35 | + |
| 36 | + |
| 37 | +def _attribute_marked_as_nullable(t: mypy.types.Instance) -> mypy.types.Instance: |
| 38 | + return t.copy_modified(args=t.args + [mypy.types.NoneType()]) |
| 39 | + |
| 40 | + |
| 41 | +def _is_attribute_marked_nullable(t: mypy.types.Type) -> bool: |
| 42 | + return ( |
| 43 | + isinstance(t, mypy.types.Instance) and |
| 44 | + _is_attribute_type_node(t.type) and |
| 45 | + # In lieu of being able to attach metadata to an instance, |
| 46 | + # having a None "fake" type argument is our way of marking the attribute as nullable |
| 47 | + bool(t.args) and isinstance(t.args[-1], mypy.types.NoneType) |
| 48 | + ) |
| 49 | + |
| 50 | + |
| 51 | +def _get_bool_literal(node: mypy.nodes.Node) -> Optional[bool]: |
| 52 | + return { |
| 53 | + 'builtins.False': False, |
| 54 | + 'builtins.True': True, |
| 55 | + }.get(node.fullname or '') if isinstance(node, mypy.nodes.NameExpr) else None |
| 56 | + |
| 57 | + |
| 58 | +def _make_optional(t: mypy.types.Type) -> mypy.types.UnionType: |
| 59 | + """Wraps a type in optionality""" |
| 60 | + return mypy.types.UnionType([t, mypy.types.NoneType()]) |
| 61 | + |
| 62 | + |
| 63 | +def _unwrap_optional(t: mypy.types.Type) -> mypy.types.Type: |
| 64 | + """Unwraps a potentially optional type""" |
| 65 | + if not isinstance(t, mypy.types.UnionType): # pragma: no cover |
| 66 | + return t |
| 67 | + t = mypy.types.UnionType([item for item in t.items if not isinstance(item, mypy.types.NoneType)]) |
| 68 | + if len(t.items) == 0: # pragma: no cover |
| 69 | + return mypy.types.NoneType() |
| 70 | + elif len(t.items) == 1: |
| 71 | + return t.items[0] |
| 72 | + else: |
| 73 | + return t # pragma: no cover |
| 74 | + |
| 75 | + |
| 76 | +def _get_method_sig_hook(ctx: mypy.plugin.MethodSigContext) -> mypy.types.CallableType: |
27 | 77 | """
|
28 |
| - For attribute classes, will return the underlying type. |
29 |
| - e.g. for `class MyAttribute(Attribute[int])`, this will return `int`. |
| 78 | + Patches up the signature of Attribute.__get__ to respect attribute's nullability. |
30 | 79 | """
|
31 |
| - for base_instance in attribute_class.bases: |
32 |
| - if base_instance.type.fullname() == ATTR_FULL_NAME: |
33 |
| - return base_instance.args[0] |
34 |
| - return None |
| 80 | + sig = ctx.default_signature |
| 81 | + if not _is_attribute_marked_nullable(ctx.type): |
| 82 | + return sig |
| 83 | + try: |
| 84 | + (instance_type, owner_type) = sig.arg_types |
| 85 | + except ValueError: # pragma: no cover |
| 86 | + return sig |
| 87 | + if isinstance(instance_type, mypy.types.NoneType): # class attribute access |
| 88 | + return sig |
| 89 | + return sig.copy_modified(ret_type=_make_optional(sig.ret_type)) |
35 | 90 |
|
36 | 91 |
|
37 |
| -def _attribute_instantiation_hook(ctx: FunctionContext, |
38 |
| - underlying_type: mypy.types.Type) -> mypy.types.Type: |
| 92 | +def _set_method_sig_hook(ctx: mypy.plugin.MethodSigContext) -> mypy.types.CallableType: |
| 93 | + """ |
| 94 | + Patches up the signature of Attribute.__set__ to respect attribute's nullability. |
| 95 | + """ |
| 96 | + sig = ctx.default_signature |
| 97 | + if _is_attribute_marked_nullable(ctx.type): |
| 98 | + return sig |
| 99 | + try: |
| 100 | + (instance_type, value_type) = sig.arg_types |
| 101 | + except ValueError: # pragma: no cover |
| 102 | + return sig |
| 103 | + return sig.copy_modified(arg_types=[instance_type, _unwrap_optional(value_type)]) |
| 104 | + |
| 105 | + |
| 106 | +def _attribute_instantiation_hook(ctx: mypy.plugin.FunctionContext) -> mypy.types.Type: |
39 | 107 | """
|
40 | 108 | Handles attribute instantiation, e.g. MyAttribute(null=True)
|
41 | 109 | """
|
42 | 110 | args = dict(zip(ctx.callee_arg_names, ctx.args))
|
43 | 111 |
|
44 |
| - # If initializer is passed null=True, wrap in _NullableAttribute |
45 |
| - # to make the underlying type optional |
| 112 | + # If initializer is passed null=True, mark attribute type instance as nullable |
46 | 113 | null_arg_exprs = args.get('null')
|
| 114 | + nullable = False |
47 | 115 | if null_arg_exprs and len(null_arg_exprs) == 1:
|
48 |
| - (null_arg_expr,) = null_arg_exprs |
49 |
| - if ( |
50 |
| - not isinstance(null_arg_expr, NameExpr) or |
51 |
| - null_arg_expr.fullname not in ('builtins.False', 'builtins.True') |
52 |
| - ): |
53 |
| - ctx.api.fail("'null' argument is not constant False or True, " |
54 |
| - "cannot deduce optionality", ctx.context) |
55 |
| - return ctx.default_return_type |
56 |
| - |
57 |
| - if null_arg_expr.fullname == 'builtins.True': |
58 |
| - return ctx.api.named_generic_type(NULL_ATTR_WRAPPER_FULL_NAME, [ |
59 |
| - ctx.default_return_type, |
60 |
| - underlying_type, |
61 |
| - ]) |
62 |
| - |
63 |
| - return ctx.default_return_type |
| 116 | + null_literal = _get_bool_literal(null_arg_exprs[0]) |
| 117 | + if null_literal is not None: |
| 118 | + nullable = null_literal |
| 119 | + else: |
| 120 | + ctx.api.fail("'null' argument is not constant False or True, cannot deduce optionality", ctx.context) |
| 121 | + |
| 122 | + assert isinstance(ctx.default_return_type, mypy.types.Instance) |
| 123 | + return _attribute_marked_as_nullable(ctx.default_return_type) if nullable else ctx.default_return_type |
0 commit comments