Skip to content

Commit 71ec9e9

Browse files
authored
Merge pull request #424 from seeM/refactor-transform
refactor transform
2 parents 06922b7 + a52b382 commit 71ec9e9

File tree

2 files changed

+65
-13
lines changed

2 files changed

+65
-13
lines changed

fastcore/transform.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
# Cell
1414
_tfm_methods = 'encodes','decodes','setups'
1515

16+
def _is_tfm_method(n, f): return n in _tfm_methods and callable(f)
17+
1618
class _TfmDict(dict):
17-
def __setitem__(self,k,v):
18-
if k not in _tfm_methods or not callable(v): return super().__setitem__(k,v)
19+
def __setitem__(self, k, v):
20+
if not _is_tfm_method(k, v): return super().__setitem__(k,v)
1921
if k not in self: super().__setitem__(k,TypeDispatch())
2022
self[k].add(v)
2123

@@ -27,16 +29,21 @@ def __new__(cls, name, bases, dict):
2729
base_td = [getattr(b,nm,None) for b in bases]
2830
if nm in res.__dict__: getattr(res,nm).bases = base_td
2931
else: setattr(res, nm, TypeDispatch(bases=base_td))
32+
# _TfmMeta.__call__ shadows the signature of inheriting classes, set it back
3033
res.__signature__ = inspect.signature(res.__init__)
3134
return res
3235

3336
def __call__(cls, *args, **kwargs):
34-
f = args[0] if args else None
35-
n = getattr(f,'__name__',None)
36-
if callable(f) and n in _tfm_methods:
37+
f = first(args)
38+
n = getattr(f, '__name__', None)
39+
if _is_tfm_method(n, f):
3740
getattr(cls,n).add(f)
3841
return f
39-
return super().__call__(*args, **kwargs)
42+
obj = super().__call__(*args, **kwargs)
43+
# _TfmMeta.__new__ replaces cls.__signature__ which breaks the signature of a callable
44+
# instances of cls, fix it
45+
if hasattr(obj, '__call__'): obj.__signature__ = inspect.signature(obj.__call__)
46+
return obj
4047

4148
@classmethod
4249
def __prepare__(cls, name, bases): return _TfmDict()

nbs/05_transform.ipynb

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,11 @@
6969
"#export\n",
7070
"_tfm_methods = 'encodes','decodes','setups'\n",
7171
"\n",
72+
"def _is_tfm_method(n, f): return n in _tfm_methods and callable(f)\n",
73+
"\n",
7274
"class _TfmDict(dict):\n",
73-
" def __setitem__(self,k,v):\n",
74-
" if k not in _tfm_methods or not callable(v): return super().__setitem__(k,v)\n",
75+
" def __setitem__(self, k, v):\n",
76+
" if not _is_tfm_method(k, v): return super().__setitem__(k,v)\n",
7577
" if k not in self: super().__setitem__(k,TypeDispatch())\n",
7678
" self[k].add(v)"
7779
]
@@ -90,16 +92,21 @@
9092
" base_td = [getattr(b,nm,None) for b in bases]\n",
9193
" if nm in res.__dict__: getattr(res,nm).bases = base_td\n",
9294
" else: setattr(res, nm, TypeDispatch(bases=base_td))\n",
95+
" # _TfmMeta.__call__ shadows the signature of inheriting classes, set it back\n",
9396
" res.__signature__ = inspect.signature(res.__init__)\n",
9497
" return res\n",
9598
"\n",
9699
" def __call__(cls, *args, **kwargs):\n",
97-
" f = args[0] if args else None\n",
98-
" n = getattr(f,'__name__',None)\n",
99-
" if callable(f) and n in _tfm_methods:\n",
100+
" f = first(args)\n",
101+
" n = getattr(f, '__name__', None)\n",
102+
" if _is_tfm_method(n, f):\n",
100103
" getattr(cls,n).add(f)\n",
101104
" return f\n",
102-
" return super().__call__(*args, **kwargs)\n",
105+
" obj = super().__call__(*args, **kwargs)\n",
106+
" # _TfmMeta.__new__ replaces cls.__signature__ which breaks the signature of a callable\n",
107+
" # instances of cls, fix it\n",
108+
" if hasattr(obj, '__call__'): obj.__signature__ = inspect.signature(obj.__call__)\n",
109+
" return obj\n",
103110
"\n",
104111
" @classmethod\n",
105112
" def __prepare__(cls, name, bases): return _TfmDict()"
@@ -368,6 +375,44 @@
368375
"test_eq_type(f3(2), 2)"
369376
]
370377
},
378+
{
379+
"cell_type": "markdown",
380+
"metadata": {},
381+
"source": [
382+
"Transforms can be created from class methods too:"
383+
]
384+
},
385+
{
386+
"cell_type": "code",
387+
"execution_count": null,
388+
"metadata": {},
389+
"outputs": [],
390+
"source": [
391+
"class A:\n",
392+
" @classmethod\n",
393+
" def create(cls, x:int): return x+1\n",
394+
"test_eq(Transform(A.create)(1), 2)"
395+
]
396+
},
397+
{
398+
"cell_type": "code",
399+
"execution_count": null,
400+
"metadata": {},
401+
"outputs": [],
402+
"source": [
403+
"#hide\n",
404+
"# Test extension of a tfm method defined in the class\n",
405+
"class A(Transform):\n",
406+
" def encodes(self, x): return 'obj'\n",
407+
"\n",
408+
"@A\n",
409+
"def encodes(self, x:int): return 'int'\n",
410+
"\n",
411+
"a = A()\n",
412+
"test_eq(a.encodes(0), 'int')\n",
413+
"test_eq(a.encodes(0.0), 'obj')"
414+
]
415+
},
371416
{
372417
"cell_type": "markdown",
373418
"metadata": {},
@@ -845,7 +890,7 @@
845890
"def encodes(self, x:str): return x+'hello'\n",
846891
"\n",
847892
"@B\n",
848-
"def encodes(self, x)->None: return str(x)+'!'"
893+
"def encodes(self, x): return str(x)+'!'"
849894
]
850895
},
851896
{

0 commit comments

Comments
 (0)