Skip to content

Commit 5d1a117

Browse files
committed
works with tupple annotations; working autoreload; repr
1 parent 7c5bb3c commit 5d1a117

File tree

3 files changed

+121
-68
lines changed

3 files changed

+121
-68
lines changed

fastcore/imports.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,15 @@
1414
MethodDescriptorType = type(str.join)
1515
from types import BuiltinFunctionType,BuiltinMethodType,MethodType,FunctionType,SimpleNamespace
1616

17+
#Patch autoreload (if its loaded) to work with plum
18+
try: from IPython import get_ipython
19+
except ImportError: pass
20+
else:
21+
ip = get_ipython()
22+
if ip is not None and 'IPython.extensions.storemagic' in ip.extension_manager.loaded:
23+
from plum.autoreload import activate
24+
activate()
25+
1726
NoneType = type(None)
1827
string_classes = (str,bytes)
1928

fastcore/transform.py

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,33 +16,42 @@
1616
import inspect
1717
from copy import copy
1818
from plum import add_conversion_method, dispatch, Function
19+
from typing import get_args, get_origin
1920

2021
# Cell
21-
_tfm_methods = 'encodes','decodes','setups'
22-
23-
def _is_tfm_method(f,n):
24-
return n in _tfm_methods and callable(f)
22+
# Convert tuple annotations to unions to work with plum
23+
def _annot_tuple_to_union(f):
24+
for k, v in type_hints(f).items():
25+
if isinstance(v, tuple): f.__annotations__[k] = Union[v]
26+
return f
2527

26-
class _TfmDict(dict):
27-
def __setitem__(self,k,v):
28-
if _is_tfm_method(v,k): v = dispatch(v)
29-
super().__setitem__(k,v)
28+
def _dispatch(f): return dispatch(_annot_tuple_to_union(f))
3029

31-
# Cell
32-
def _dispatch(f, cls):
33-
"Dispatch and set a function as an instance method"
30+
def _dispatch_method(f, cls):
3431
f = copy(f)
3532
n = f.__name__
3633
# plum uses __qualname__ to infer f's owner
3734
f.__qualname__ = f'{cls.__name__}.{n}'
38-
pf = dispatch(f)
35+
pf = _dispatch(f)
3936
setattr(cls, n, pf)
4037
# plum uses __set_name__ to resolve a plum.Function's owner.
4138
# since we assign after class creation, __set_name__ must be called directly
4239
# source: https://docs.python.org/3/reference/datamodel.html#object.__set_name__
4340
pf.__set_name__(cls, n)
4441
return pf
4542

43+
def _pf_dispatch(pf, f): return pf.dispatch(_annot_tuple_to_union(f))
44+
45+
# Cell
46+
_tfm_methods = 'encodes','decodes','setups'
47+
48+
def _is_tfm_method(f, n): return n in _tfm_methods and callable(f)
49+
50+
class _TfmDict(dict):
51+
def __setitem__(self, k, v):
52+
if _is_tfm_method(v, k): v = _dispatch(v)
53+
super().__setitem__(k, v)
54+
4655
# Cell
4756
class _TfmMeta(type):
4857
def __new__(cls, name, bases, dict):
@@ -54,8 +63,8 @@ def __new__(cls, name, bases, dict):
5463
def __call__(cls, *args, **kwargs):
5564
f = first(args)
5665
n = getattr(f, '__name__', None)
57-
if _is_tfm_method(f, n): return _dispatch(f, cls)
58-
obj = super().__call__(*args,**kwargs)
66+
if _is_tfm_method(f, n): return _dispatch_method(f, cls)
67+
obj = super().__call__(*args, **kwargs)
5968
# _TfmMeta.__new__ replaces cls.__signature__ which breaks the signature of a callable
6069
# instances of cls, fix it
6170
if hasattr(obj, '__call__'): obj.__signature__ = inspect.signature(obj.__call__)
@@ -74,15 +83,16 @@ def _get_name(o):
7483
def _is_tuple(o): return isinstance(o, tuple) and not hasattr(o, '_fields')
7584

7685
# Cell
77-
@dispatch
7886
def _pt_repr(o):
7987
n = type(o).__name__
8088
if n == 'Tuple': return f"{n.lower()}[{','.join(_pt_repr(t) for t in o._el_types)}]"
8189
if n == 'List': return f'{n.lower()}[{_pt_repr(o._el_type)}]'
8290
if n == 'Dict': return f'{n.lower()}[{_pt_repr(o._key_type)},{_pt_repr(o._value_type)}]'
8391
if n in ('Sequence','Iterable'): return f'{n}[{_pt_repr(o._el_type)}]'
84-
if n == 'VarArgs': return f'VarArgs[{_pt_repr(o.type)}]'
85-
return '|'.join(t.__name__ for t in o.get_types())
92+
if n == 'VarArgs': return f'{n}[{_pt_repr(o.type)}]'
93+
if n == 'Union': return '|'.join(sorted(t.__name__ for t in (o.get_types())))
94+
assert len(o.get_types()) == 1
95+
return o.get_types()[0].__name__
8696

8797
# Cell
8898
def _pf_repr(pf): return '\n'.join(f"{f.__name__}: ({','.join(_pt_repr(t) for t in s.types)}) -> {_pt_repr(r)}"
@@ -101,11 +111,14 @@ def __init__(self, enc=None, dec=None, split_idx=None, order=None):
101111
def identity(x): return x
102112
for n in _tfm_methods: setattr(self,n,Function(identity).dispatch(identity))
103113
if enc:
104-
self.encodes.dispatch(enc)
114+
_pf_dispatch(self.encodes, enc)
105115
self.order = getattr(enc,'order',self.order)
106-
if len(type_hints(enc)) > 0: self.input_types = first(type_hints(enc).values())
116+
if len(type_hints(enc)) > 0:
117+
self.input_types = first(type_hints(enc).values())
118+
# Convert Union to tuple, remove once the rest of fastai supports Union
119+
if get_origin(self.input_types) is Union: self.input_types=get_args(self.input_types)
107120
self._name = _get_name(enc)
108-
if dec: self.decodes.dispatch(dec)
121+
if dec: _pf_dispatch(self.decodes, dec)
109122

110123
@property
111124
def name(self): return getattr(self, '_name', _get_name(self))

nbs/05_transform.ipynb

Lines changed: 79 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
"from fastcore.dispatch import *\n",
2424
"import inspect\n",
2525
"from copy import copy\n",
26-
"from plum import add_conversion_method, dispatch, Function"
26+
"from plum import add_conversion_method, dispatch, Function\n",
27+
"from typing import get_args, get_origin"
2728
]
2829
},
2930
{
@@ -69,37 +70,45 @@
6970
"outputs": [],
7071
"source": [
7172
"#export\n",
72-
"_tfm_methods = 'encodes','decodes','setups'\n",
73+
"# Convert tuple annotations to unions to work with plum\n",
74+
"def _annot_tuple_to_union(f):\n",
75+
" for k, v in type_hints(f).items():\n",
76+
" if isinstance(v, tuple): f.__annotations__[k] = Union[v]\n",
77+
" return f\n",
7378
"\n",
74-
"def _is_tfm_method(f,n):\n",
75-
" return n in _tfm_methods and callable(f)\n",
79+
"def _dispatch(f): return dispatch(_annot_tuple_to_union(f))\n",
7680
"\n",
77-
"class _TfmDict(dict):\n",
78-
" def __setitem__(self,k,v):\n",
79-
" if _is_tfm_method(v,k): v = dispatch(v)\n",
80-
" super().__setitem__(k,v)"
81-
]
82-
},
83-
{
84-
"cell_type": "code",
85-
"execution_count": null,
86-
"metadata": {},
87-
"outputs": [],
88-
"source": [
89-
"#export\n",
90-
"def _dispatch(f, cls):\n",
91-
" \"Dispatch and set a function as an instance method\"\n",
81+
"def _dispatch_method(f, cls):\n",
9282
" f = copy(f)\n",
9383
" n = f.__name__\n",
9484
" # plum uses __qualname__ to infer f's owner\n",
9585
" f.__qualname__ = f'{cls.__name__}.{n}'\n",
96-
" pf = dispatch(f)\n",
86+
" pf = _dispatch(f)\n",
9787
" setattr(cls, n, pf)\n",
9888
" # plum uses __set_name__ to resolve a plum.Function's owner.\n",
9989
" # since we assign after class creation, __set_name__ must be called directly\n",
10090
" # source: https://docs.python.org/3/reference/datamodel.html#object.__set_name__\n",
10191
" pf.__set_name__(cls, n)\n",
102-
" return pf"
92+
" return pf\n",
93+
"\n",
94+
"def _pf_dispatch(pf, f): return pf.dispatch(_annot_tuple_to_union(f))"
95+
]
96+
},
97+
{
98+
"cell_type": "code",
99+
"execution_count": null,
100+
"metadata": {},
101+
"outputs": [],
102+
"source": [
103+
"#export\n",
104+
"_tfm_methods = 'encodes','decodes','setups'\n",
105+
"\n",
106+
"def _is_tfm_method(f, n): return n in _tfm_methods and callable(f)\n",
107+
"\n",
108+
"class _TfmDict(dict):\n",
109+
" def __setitem__(self, k, v):\n",
110+
" if _is_tfm_method(v, k): v = _dispatch(v)\n",
111+
" super().__setitem__(k, v)"
103112
]
104113
},
105114
{
@@ -119,8 +128,8 @@
119128
" def __call__(cls, *args, **kwargs):\n",
120129
" f = first(args)\n",
121130
" n = getattr(f, '__name__', None)\n",
122-
" if _is_tfm_method(f, n): return _dispatch(f, cls)\n",
123-
" obj = super().__call__(*args,**kwargs)\n",
131+
" if _is_tfm_method(f, n): return _dispatch_method(f, cls)\n",
132+
" obj = super().__call__(*args, **kwargs)\n",
124133
" # _TfmMeta.__new__ replaces cls.__signature__ which breaks the signature of a callable\n",
125134
" # instances of cls, fix it\n",
126135
" if hasattr(obj, '__call__'): obj.__signature__ = inspect.signature(obj.__call__)\n",
@@ -160,15 +169,16 @@
160169
"outputs": [],
161170
"source": [
162171
"#export\n",
163-
"@dispatch\n",
164172
"def _pt_repr(o):\n",
165173
" n = type(o).__name__\n",
166174
" if n == 'Tuple': return f\"{n.lower()}[{','.join(_pt_repr(t) for t in o._el_types)}]\"\n",
167175
" if n == 'List': return f'{n.lower()}[{_pt_repr(o._el_type)}]'\n",
168176
" if n == 'Dict': return f'{n.lower()}[{_pt_repr(o._key_type)},{_pt_repr(o._value_type)}]'\n",
169177
" if n in ('Sequence','Iterable'): return f'{n}[{_pt_repr(o._el_type)}]'\n",
170-
" if n == 'VarArgs': return f'VarArgs[{_pt_repr(o.type)}]'\n",
171-
" return '|'.join(t.__name__ for t in o.get_types())"
178+
" if n == 'VarArgs': return f'{n}[{_pt_repr(o.type)}]'\n",
179+
" if n == 'Union': return '|'.join(sorted(t.__name__ for t in (o.get_types())))\n",
180+
" assert len(o.get_types()) == 1\n",
181+
" return o.get_types()[0].__name__"
172182
]
173183
},
174184
{
@@ -188,7 +198,8 @@
188198
"test_eq(_pt_repr(ptype(Iterable[int])), 'Iterable[int]')\n",
189199
"test_eq(_pt_repr(ptype(Dict[str, int])), 'dict[str,int]')\n",
190200
"test_eq(_pt_repr(ptype(VarArgs[str])), 'VarArgs[str]')\n",
191-
"test_eq(_pt_repr(ptype(Dict[Tuple[Union[int,str],float], List[Tuple[object]]])), 'dict[tuple[int|str,float],list[tuple[object]]]')"
201+
"test_eq(_pt_repr(ptype(Dict[Tuple[Union[int,str],float], List[Tuple[object]]])),\n",
202+
" 'dict[tuple[int|str,float],list[tuple[object]]]')"
192203
]
193204
},
194205
{
@@ -233,11 +244,14 @@
233244
" def identity(x): return x\n",
234245
" for n in _tfm_methods: setattr(self,n,Function(identity).dispatch(identity))\n",
235246
" if enc:\n",
236-
" self.encodes.dispatch(enc)\n",
247+
" _pf_dispatch(self.encodes, enc)\n",
237248
" self.order = getattr(enc,'order',self.order)\n",
238-
" if len(type_hints(enc)) > 0: self.input_types = first(type_hints(enc).values())\n",
249+
" if len(type_hints(enc)) > 0:\n",
250+
" self.input_types = first(type_hints(enc).values())\n",
251+
" # Convert Union to tuple, remove once the rest of fastai supports Union\n",
252+
" if get_origin(self.input_types) is Union: self.input_types=get_args(self.input_types)\n",
239253
" self._name = _get_name(enc)\n",
240-
" if dec: self.decodes.dispatch(dec)\n",
254+
" if dec: _pf_dispatch(self.decodes, dec)\n",
241255
"\n",
242256
" @property\n",
243257
" def name(self): return getattr(self, '_name', _get_name(self))\n",
@@ -292,7 +306,7 @@
292306
"text/markdown": [
293307
"<h2 id=\"Transform\" class=\"doc_header\"><code>class</code> <code>Transform</code><a href=\"\" class=\"source_link\" style=\"float:right\">[source]</a></h2>\n",
294308
"\n",
295-
"> <code>Transform</code>(**`enc`**=*`None`*, **`dec`**=*`None`*, **`split_idx`**=*`None`*, **`order`**=*`None`*)\n",
309+
"> <code>Transform</code>(**`self`**, **`enc`**=*`None`*, **`dec`**=*`None`*, **`split_idx`**=*`None`*, **`order`**=*`None`*)\n",
296310
"\n",
297311
"Delegates (`__call__`,`decode`,`setup`) to (<code>encodes</code>,<code>decodes</code>,<code>setups</code>) if `split_idx` matches"
298312
],
@@ -473,6 +487,25 @@
473487
"test_eq_type(f3(2), 2)"
474488
]
475489
},
490+
{
491+
"cell_type": "markdown",
492+
"metadata": {},
493+
"source": [
494+
"Transforms can be created from class methods too:"
495+
]
496+
},
497+
{
498+
"cell_type": "code",
499+
"execution_count": null,
500+
"metadata": {},
501+
"outputs": [],
502+
"source": [
503+
"class A:\n",
504+
" @classmethod\n",
505+
" def create(cls, x:int): return x+1\n",
506+
"test_eq(Transform(A.create)(1), 2)"
507+
]
508+
},
476509
{
477510
"cell_type": "markdown",
478511
"metadata": {},
@@ -607,8 +640,8 @@
607640
"class MyClass(int): pass\n",
608641
"\n",
609642
"class A(Transform):\n",
610-
" def encodes(self, x:Union[MyClass,float]): return x/2\n",
611-
" def encodes(self, x:Union[str,list]): return str(x)+'_1'\n",
643+
" def encodes(self, x:(MyClass,float)): return x/2\n",
644+
" def encodes(self, x:(str,list)): return str(x)+'_1'\n",
612645
"\n",
613646
"f = A()"
614647
]
@@ -647,6 +680,18 @@
647680
"test_eq(f(['a','b','c']), \"['a', 'b', 'c']_1\") # input is of type list"
648681
]
649682
},
683+
{
684+
"cell_type": "code",
685+
"execution_count": null,
686+
"metadata": {},
687+
"outputs": [],
688+
"source": [
689+
"@Transform\n",
690+
"def f(x:(int,float)): return x+1\n",
691+
"test_eq(f(0), 1)\n",
692+
"test_eq(f('a'), 'a')"
693+
]
694+
},
650695
{
651696
"cell_type": "markdown",
652697
"metadata": {},
@@ -901,7 +946,7 @@
901946
{
902947
"data": {
903948
"text/plain": [
904-
"Promise(obj=<function <function AL.encodes at 0x1220569d0> with 2 method(s)>)"
949+
"Promise(obj=<function <function AL.encodes at 0x11e3fb670> with 2 method(s)>)"
905950
]
906951
},
907952
"execution_count": null,
@@ -1364,20 +1409,6 @@
13641409
"test_eq(type(f.decode((1,1))), _T)"
13651410
]
13661411
},
1367-
{
1368-
"cell_type": "markdown",
1369-
"metadata": {},
1370-
"source": [
1371-
"#### Transform tests -"
1372-
]
1373-
},
1374-
{
1375-
"cell_type": "code",
1376-
"execution_count": null,
1377-
"metadata": {},
1378-
"outputs": [],
1379-
"source": []
1380-
},
13811412
{
13821413
"cell_type": "markdown",
13831414
"metadata": {},

0 commit comments

Comments
 (0)