|
69 | 69 | "#export\n",
|
70 | 70 | "_tfm_methods = 'encodes','decodes','setups'\n",
|
71 | 71 | "\n",
|
| 72 | + "def _is_tfm_method(n, f): return n in _tfm_methods and callable(f)\n", |
| 73 | + "\n", |
72 | 74 | "class _TfmDict(dict):\n",
|
73 |
| - " def __setitem__(self,k,v):\n", |
74 |
| - " if k not in _tfm_methods or not callable(v): return super().__setitem__(k,v)\n", |
| 75 | + " def __setitem__(self, k, v):\n", |
| 76 | + " if not _is_tfm_method(k, v): return super().__setitem__(k,v)\n", |
75 | 77 | " if k not in self: super().__setitem__(k,TypeDispatch())\n",
|
76 | 78 | " self[k].add(v)"
|
77 | 79 | ]
|
|
90 | 92 | " base_td = [getattr(b,nm,None) for b in bases]\n",
|
91 | 93 | " if nm in res.__dict__: getattr(res,nm).bases = base_td\n",
|
92 | 94 | " else: setattr(res, nm, TypeDispatch(bases=base_td))\n",
|
| 95 | + " # _TfmMeta.__call__ shadows the signature of inheriting classes, set it back\n", |
93 | 96 | " res.__signature__ = inspect.signature(res.__init__)\n",
|
94 | 97 | " return res\n",
|
95 | 98 | "\n",
|
96 | 99 | " def __call__(cls, *args, **kwargs):\n",
|
97 |
| - " f = args[0] if args else None\n", |
98 |
| - " n = getattr(f,'__name__',None)\n", |
99 |
| - " if callable(f) and n in _tfm_methods:\n", |
| 100 | + " f = first(args)\n", |
| 101 | + " n = getattr(f, '__name__', None)\n", |
| 102 | + " if _is_tfm_method(n, f):\n", |
100 | 103 | " getattr(cls,n).add(f)\n",
|
101 | 104 | " return f\n",
|
102 |
| - " return super().__call__(*args, **kwargs)\n", |
| 105 | + " obj = super().__call__(*args, **kwargs)\n", |
| 106 | + " # _TfmMeta.__new__ replaces cls.__signature__ which breaks the signature of a callable\n", |
| 107 | + " # instances of cls, fix it\n", |
| 108 | + " if hasattr(obj, '__call__'): obj.__signature__ = inspect.signature(obj.__call__)\n", |
| 109 | + " return obj\n", |
103 | 110 | "\n",
|
104 | 111 | " @classmethod\n",
|
105 | 112 | " def __prepare__(cls, name, bases): return _TfmDict()"
|
|
368 | 375 | "test_eq_type(f3(2), 2)"
|
369 | 376 | ]
|
370 | 377 | },
|
| 378 | + { |
| 379 | + "cell_type": "markdown", |
| 380 | + "metadata": {}, |
| 381 | + "source": [ |
| 382 | + "Transforms can be created from class methods too:" |
| 383 | + ] |
| 384 | + }, |
| 385 | + { |
| 386 | + "cell_type": "code", |
| 387 | + "execution_count": null, |
| 388 | + "metadata": {}, |
| 389 | + "outputs": [], |
| 390 | + "source": [ |
| 391 | + "class A:\n", |
| 392 | + " @classmethod\n", |
| 393 | + " def create(cls, x:int): return x+1\n", |
| 394 | + "test_eq(Transform(A.create)(1), 2)" |
| 395 | + ] |
| 396 | + }, |
| 397 | + { |
| 398 | + "cell_type": "code", |
| 399 | + "execution_count": null, |
| 400 | + "metadata": {}, |
| 401 | + "outputs": [], |
| 402 | + "source": [ |
| 403 | + "#hide\n", |
| 404 | + "# Test extension of a tfm method defined in the class\n", |
| 405 | + "class A(Transform):\n", |
| 406 | + " def encodes(self, x): return 'obj'\n", |
| 407 | + "\n", |
| 408 | + "@A\n", |
| 409 | + "def encodes(self, x:int): return 'int'\n", |
| 410 | + "\n", |
| 411 | + "a = A()\n", |
| 412 | + "test_eq(a.encodes(0), 'int')\n", |
| 413 | + "test_eq(a.encodes(0.0), 'obj')" |
| 414 | + ] |
| 415 | + }, |
371 | 416 | {
|
372 | 417 | "cell_type": "markdown",
|
373 | 418 | "metadata": {},
|
|
845 | 890 | "def encodes(self, x:str): return x+'hello'\n",
|
846 | 891 | "\n",
|
847 | 892 | "@B\n",
|
848 |
| - "def encodes(self, x)->None: return str(x)+'!'" |
| 893 | + "def encodes(self, x): return str(x)+'!'" |
849 | 894 | ]
|
850 | 895 | },
|
851 | 896 | {
|
|
0 commit comments