|
23 | 23 | "from fastcore.dispatch import *\n",
|
24 | 24 | "import inspect\n",
|
25 | 25 | "from copy import copy\n",
|
26 |
| - "from plum import add_conversion_method, dispatch, Function" |
| 26 | + "from plum import add_conversion_method, dispatch, Function\n", |
| 27 | + "from typing import get_args, get_origin" |
27 | 28 | ]
|
28 | 29 | },
|
29 | 30 | {
|
|
69 | 70 | "outputs": [],
|
70 | 71 | "source": [
|
71 | 72 | "#export\n",
|
72 |
| - "_tfm_methods = 'encodes','decodes','setups'\n", |
| 73 | + "# Convert tuple annotations to unions to work with plum\n", |
| 74 | + "def _annot_tuple_to_union(f):\n", |
| 75 | + " for k, v in type_hints(f).items():\n", |
| 76 | + " if isinstance(v, tuple): f.__annotations__[k] = Union[v]\n", |
| 77 | + " return f\n", |
73 | 78 | "\n",
|
74 |
| - "def _is_tfm_method(f,n):\n", |
75 |
| - " return n in _tfm_methods and callable(f)\n", |
| 79 | + "def _dispatch(f): return dispatch(_annot_tuple_to_union(f))\n", |
76 | 80 | "\n",
|
77 |
| - "class _TfmDict(dict):\n", |
78 |
| - " def __setitem__(self,k,v):\n", |
79 |
| - " if _is_tfm_method(v,k): v = dispatch(v)\n", |
80 |
| - " super().__setitem__(k,v)" |
81 |
| - ] |
82 |
| - }, |
83 |
| - { |
84 |
| - "cell_type": "code", |
85 |
| - "execution_count": null, |
86 |
| - "metadata": {}, |
87 |
| - "outputs": [], |
88 |
| - "source": [ |
89 |
| - "#export\n", |
90 |
| - "def _dispatch(f, cls):\n", |
91 |
| - " \"Dispatch and set a function as an instance method\"\n", |
| 81 | + "def _dispatch_method(f, cls):\n", |
92 | 82 | " f = copy(f)\n",
|
93 | 83 | " n = f.__name__\n",
|
94 | 84 | " # plum uses __qualname__ to infer f's owner\n",
|
95 | 85 | " f.__qualname__ = f'{cls.__name__}.{n}'\n",
|
96 |
| - " pf = dispatch(f)\n", |
| 86 | + " pf = _dispatch(f)\n", |
97 | 87 | " setattr(cls, n, pf)\n",
|
98 | 88 | " # plum uses __set_name__ to resolve a plum.Function's owner.\n",
|
99 | 89 | " # since we assign after class creation, __set_name__ must be called directly\n",
|
100 | 90 | " # source: https://docs.python.org/3/reference/datamodel.html#object.__set_name__\n",
|
101 | 91 | " pf.__set_name__(cls, n)\n",
|
102 |
| - " return pf" |
| 92 | + " return pf\n", |
| 93 | + "\n", |
| 94 | + "def _pf_dispatch(pf, f): return pf.dispatch(_annot_tuple_to_union(f))" |
| 95 | + ] |
| 96 | + }, |
| 97 | + { |
| 98 | + "cell_type": "code", |
| 99 | + "execution_count": null, |
| 100 | + "metadata": {}, |
| 101 | + "outputs": [], |
| 102 | + "source": [ |
| 103 | + "#export\n", |
| 104 | + "_tfm_methods = 'encodes','decodes','setups'\n", |
| 105 | + "\n", |
| 106 | + "def _is_tfm_method(f, n): return n in _tfm_methods and callable(f)\n", |
| 107 | + "\n", |
| 108 | + "class _TfmDict(dict):\n", |
| 109 | + " def __setitem__(self, k, v):\n", |
| 110 | + " if _is_tfm_method(v, k): v = _dispatch(v)\n", |
| 111 | + " super().__setitem__(k, v)" |
103 | 112 | ]
|
104 | 113 | },
|
105 | 114 | {
|
|
119 | 128 | " def __call__(cls, *args, **kwargs):\n",
|
120 | 129 | " f = first(args)\n",
|
121 | 130 | " n = getattr(f, '__name__', None)\n",
|
122 |
| - " if _is_tfm_method(f, n): return _dispatch(f, cls)\n", |
123 |
| - " obj = super().__call__(*args,**kwargs)\n", |
| 131 | + " if _is_tfm_method(f, n): return _dispatch_method(f, cls)\n", |
| 132 | + " obj = super().__call__(*args, **kwargs)\n", |
124 | 133 | " # _TfmMeta.__new__ replaces cls.__signature__ which breaks the signature of a callable\n",
|
125 | 134 | " # instances of cls, fix it\n",
|
126 | 135 | " if hasattr(obj, '__call__'): obj.__signature__ = inspect.signature(obj.__call__)\n",
|
|
160 | 169 | "outputs": [],
|
161 | 170 | "source": [
|
162 | 171 | "#export\n",
|
163 |
| - "@dispatch\n", |
164 | 172 | "def _pt_repr(o):\n",
|
165 | 173 | " n = type(o).__name__\n",
|
166 | 174 | " if n == 'Tuple': return f\"{n.lower()}[{','.join(_pt_repr(t) for t in o._el_types)}]\"\n",
|
167 | 175 | " if n == 'List': return f'{n.lower()}[{_pt_repr(o._el_type)}]'\n",
|
168 | 176 | " if n == 'Dict': return f'{n.lower()}[{_pt_repr(o._key_type)},{_pt_repr(o._value_type)}]'\n",
|
169 | 177 | " if n in ('Sequence','Iterable'): return f'{n}[{_pt_repr(o._el_type)}]'\n",
|
170 |
| - " if n == 'VarArgs': return f'VarArgs[{_pt_repr(o.type)}]'\n", |
171 |
| - " return '|'.join(t.__name__ for t in o.get_types())" |
| 178 | + " if n == 'VarArgs': return f'{n}[{_pt_repr(o.type)}]'\n", |
| 179 | + " if n == 'Union': return '|'.join(sorted(t.__name__ for t in (o.get_types())))\n", |
| 180 | + " assert len(o.get_types()) == 1\n", |
| 181 | + " return o.get_types()[0].__name__" |
172 | 182 | ]
|
173 | 183 | },
|
174 | 184 | {
|
|
188 | 198 | "test_eq(_pt_repr(ptype(Iterable[int])), 'Iterable[int]')\n",
|
189 | 199 | "test_eq(_pt_repr(ptype(Dict[str, int])), 'dict[str,int]')\n",
|
190 | 200 | "test_eq(_pt_repr(ptype(VarArgs[str])), 'VarArgs[str]')\n",
|
191 |
| - "test_eq(_pt_repr(ptype(Dict[Tuple[Union[int,str],float], List[Tuple[object]]])), 'dict[tuple[int|str,float],list[tuple[object]]]')" |
| 201 | + "test_eq(_pt_repr(ptype(Dict[Tuple[Union[int,str],float], List[Tuple[object]]])),\n", |
| 202 | + " 'dict[tuple[int|str,float],list[tuple[object]]]')" |
192 | 203 | ]
|
193 | 204 | },
|
194 | 205 | {
|
|
233 | 244 | " def identity(x): return x\n",
|
234 | 245 | " for n in _tfm_methods: setattr(self,n,Function(identity).dispatch(identity))\n",
|
235 | 246 | " if enc:\n",
|
236 |
| - " self.encodes.dispatch(enc)\n", |
| 247 | + " _pf_dispatch(self.encodes, enc)\n", |
237 | 248 | " self.order = getattr(enc,'order',self.order)\n",
|
238 |
| - " if len(type_hints(enc)) > 0: self.input_types = first(type_hints(enc).values())\n", |
| 249 | + " if len(type_hints(enc)) > 0:\n", |
| 250 | + " self.input_types = first(type_hints(enc).values())\n", |
| 251 | + " # Convert Union to tuple, remove once the rest of fastai supports Union\n", |
| 252 | + " if get_origin(self.input_types) is Union: self.input_types=get_args(self.input_types)\n", |
239 | 253 | " self._name = _get_name(enc)\n",
|
240 |
| - " if dec: self.decodes.dispatch(dec)\n", |
| 254 | + " if dec: _pf_dispatch(self.decodes, dec)\n", |
241 | 255 | "\n",
|
242 | 256 | " @property\n",
|
243 | 257 | " def name(self): return getattr(self, '_name', _get_name(self))\n",
|
|
292 | 306 | "text/markdown": [
|
293 | 307 | "<h2 id=\"Transform\" class=\"doc_header\"><code>class</code> <code>Transform</code><a href=\"\" class=\"source_link\" style=\"float:right\">[source]</a></h2>\n",
|
294 | 308 | "\n",
|
295 |
| - "> <code>Transform</code>(**`enc`**=*`None`*, **`dec`**=*`None`*, **`split_idx`**=*`None`*, **`order`**=*`None`*)\n", |
| 309 | + "> <code>Transform</code>(**`self`**, **`enc`**=*`None`*, **`dec`**=*`None`*, **`split_idx`**=*`None`*, **`order`**=*`None`*)\n", |
296 | 310 | "\n",
|
297 | 311 | "Delegates (`__call__`,`decode`,`setup`) to (<code>encodes</code>,<code>decodes</code>,<code>setups</code>) if `split_idx` matches"
|
298 | 312 | ],
|
|
473 | 487 | "test_eq_type(f3(2), 2)"
|
474 | 488 | ]
|
475 | 489 | },
|
| 490 | + { |
| 491 | + "cell_type": "markdown", |
| 492 | + "metadata": {}, |
| 493 | + "source": [ |
| 494 | + "Transforms can be created from class methods too:" |
| 495 | + ] |
| 496 | + }, |
| 497 | + { |
| 498 | + "cell_type": "code", |
| 499 | + "execution_count": null, |
| 500 | + "metadata": {}, |
| 501 | + "outputs": [], |
| 502 | + "source": [ |
| 503 | + "class A:\n", |
| 504 | + " @classmethod\n", |
| 505 | + " def create(cls, x:int): return x+1\n", |
| 506 | + "test_eq(Transform(A.create)(1), 2)" |
| 507 | + ] |
| 508 | + }, |
476 | 509 | {
|
477 | 510 | "cell_type": "markdown",
|
478 | 511 | "metadata": {},
|
|
607 | 640 | "class MyClass(int): pass\n",
|
608 | 641 | "\n",
|
609 | 642 | "class A(Transform):\n",
|
610 |
| - " def encodes(self, x:Union[MyClass,float]): return x/2\n", |
611 |
| - " def encodes(self, x:Union[str,list]): return str(x)+'_1'\n", |
| 643 | + " def encodes(self, x:(MyClass,float)): return x/2\n", |
| 644 | + " def encodes(self, x:(str,list)): return str(x)+'_1'\n", |
612 | 645 | "\n",
|
613 | 646 | "f = A()"
|
614 | 647 | ]
|
|
647 | 680 | "test_eq(f(['a','b','c']), \"['a', 'b', 'c']_1\") # input is of type list"
|
648 | 681 | ]
|
649 | 682 | },
|
| 683 | + { |
| 684 | + "cell_type": "code", |
| 685 | + "execution_count": null, |
| 686 | + "metadata": {}, |
| 687 | + "outputs": [], |
| 688 | + "source": [ |
| 689 | + "@Transform\n", |
| 690 | + "def f(x:(int,float)): return x+1\n", |
| 691 | + "test_eq(f(0), 1)\n", |
| 692 | + "test_eq(f('a'), 'a')" |
| 693 | + ] |
| 694 | + }, |
650 | 695 | {
|
651 | 696 | "cell_type": "markdown",
|
652 | 697 | "metadata": {},
|
|
901 | 946 | {
|
902 | 947 | "data": {
|
903 | 948 | "text/plain": [
|
904 |
| - "Promise(obj=<function <function AL.encodes at 0x1220569d0> with 2 method(s)>)" |
| 949 | + "Promise(obj=<function <function AL.encodes at 0x11e3fb670> with 2 method(s)>)" |
905 | 950 | ]
|
906 | 951 | },
|
907 | 952 | "execution_count": null,
|
|
1364 | 1409 | "test_eq(type(f.decode((1,1))), _T)"
|
1365 | 1410 | ]
|
1366 | 1411 | },
|
1367 |
| - { |
1368 |
| - "cell_type": "markdown", |
1369 |
| - "metadata": {}, |
1370 |
| - "source": [ |
1371 |
| - "#### Transform tests -" |
1372 |
| - ] |
1373 |
| - }, |
1374 |
| - { |
1375 |
| - "cell_type": "code", |
1376 |
| - "execution_count": null, |
1377 |
| - "metadata": {}, |
1378 |
| - "outputs": [], |
1379 |
| - "source": [] |
1380 |
| - }, |
1381 | 1412 | {
|
1382 | 1413 | "cell_type": "markdown",
|
1383 | 1414 | "metadata": {},
|
|
0 commit comments