Skip to content

Commit 7c5bb3c

Browse files
committed
cleanup
1 parent c1fc33b commit 7c5bb3c

File tree

2 files changed

+249
-127
lines changed

2 files changed

+249
-127
lines changed

fastcore/transform.py

Lines changed: 49 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -15,60 +15,51 @@
1515
from .dispatch import *
1616
import inspect
1717
from copy import copy
18-
from plum import add_conversion_method, dispatch, Dispatcher
19-
20-
# Cell
21-
# TODO: Shouldn't sig of method set first parameter type to self? how do i get that...
22-
# i.e. self assumes type(self) / self.__class__?
23-
def _mk_plum_func(d, n, f=None, cls=None):
24-
f = (lambda x: x) if f is None else copy(f)
25-
f.__name__ = n
26-
# plum uses __qualname__ to infer f's owner
27-
f.__qualname__ = n if cls is None else '.'.join([cls.__name__,n])
28-
# TODO: Shouldn't we create a Function here and dispatch from that?
29-
# We don't need this in the dispatch table i think...
30-
# TODO: Should Function take name qualname etc and have a .from_callable method?
31-
# since the func isn't even dispatched by default
32-
pf = d(f)
33-
if cls is not None:
34-
setattr(cls,n,pf)
35-
# plum uses __set_name__ to resolve a plum.Function's owner.
36-
# since we assign after class creation, __set_name__ must be called directly
37-
# source: https://docs.python.org/3/reference/datamodel.html#object.__set_name__
38-
pf.__set_name__(cls,n)
39-
return pf
18+
from plum import add_conversion_method, dispatch, Function
4019

4120
# Cell
4221
_tfm_methods = 'encodes','decodes','setups'
4322

4423
def _is_tfm_method(f,n):
4524
return n in _tfm_methods and callable(f)
4625

47-
# TODO: Do we still need this given the fact that plum searches mro without needing them to be Functions?
4826
class _TfmDict(dict):
4927
def __setitem__(self,k,v):
5028
if _is_tfm_method(v,k): v = dispatch(v)
5129
super().__setitem__(k,v)
5230

31+
# Cell
32+
def _dispatch(f, cls):
33+
"Dispatch and set a function as an instance method"
34+
f = copy(f)
35+
n = f.__name__
36+
# plum uses __qualname__ to infer f's owner
37+
f.__qualname__ = f'{cls.__name__}.{n}'
38+
pf = dispatch(f)
39+
setattr(cls, n, pf)
40+
# plum uses __set_name__ to resolve a plum.Function's owner.
41+
# since we assign after class creation, __set_name__ must be called directly
42+
# source: https://docs.python.org/3/reference/datamodel.html#object.__set_name__
43+
pf.__set_name__(cls, n)
44+
return pf
45+
5346
# Cell
5447
class _TfmMeta(type):
55-
# TODO: commenting since this breaks inspect.signature of Transform instances,
56-
# which then breaks inspect.signature of a partial of a Transform instance
57-
#def __new__(cls, name, bases, dict):
58-
# res = super().__new__(cls, name, bases, dict)
59-
# res.__signature__ = inspect.signature(res.__init__)
60-
# return res
61-
62-
# TODO: Can we move this to Transform.__init__? Then we don't reeeeeally need a metaclass anymore...?
63-
# Ohhhhh man, can we dispatch this? If called with callable, do this, and so on
48+
def __new__(cls, name, bases, dict):
49+
# _TfmMeta.__call__ shadows the signature of inheriting classes, set it back
50+
res = super().__new__(cls, name, bases, dict)
51+
res.__signature__ = inspect.signature(res.__init__)
52+
return res
53+
6454
def __call__(cls, *args, **kwargs):
6555
f = first(args)
66-
n = getattr(f,'__name__',None)
67-
if _is_tfm_method(f,n):
68-
# use __dict__ over hasattr since it excludes parent classes
69-
if n in cls.__dict__: return getattr(cls,n).dispatch(f)
70-
return _mk_plum_func(dispatch,n,f,cls)
71-
return super().__call__(*args,**kwargs)
56+
n = getattr(f, '__name__', None)
57+
if _is_tfm_method(f, n): return _dispatch(f, cls)
58+
obj = super().__call__(*args,**kwargs)
59+
# _TfmMeta.__new__ replaces cls.__signature__ which breaks the signature of a callable
60+
# instances of cls, fix it
61+
if hasattr(obj, '__call__'): obj.__signature__ = inspect.signature(obj.__call__)
62+
return obj
7263

7364
@classmethod
7465
def __prepare__(cls, name, bases): return _TfmDict()
@@ -82,6 +73,21 @@ def _get_name(o):
8273
# Cell
8374
def _is_tuple(o): return isinstance(o, tuple) and not hasattr(o, '_fields')
8475

76+
# Cell
77+
@dispatch
78+
def _pt_repr(o):
79+
n = type(o).__name__
80+
if n == 'Tuple': return f"{n.lower()}[{','.join(_pt_repr(t) for t in o._el_types)}]"
81+
if n == 'List': return f'{n.lower()}[{_pt_repr(o._el_type)}]'
82+
if n == 'Dict': return f'{n.lower()}[{_pt_repr(o._key_type)},{_pt_repr(o._value_type)}]'
83+
if n in ('Sequence','Iterable'): return f'{n}[{_pt_repr(o._el_type)}]'
84+
if n == 'VarArgs': return f'VarArgs[{_pt_repr(o.type)}]'
85+
return '|'.join(t.__name__ for t in o.get_types())
86+
87+
# Cell
88+
def _pf_repr(pf): return '\n'.join(f"{f.__name__}: ({','.join(_pt_repr(t) for t in s.types)}) -> {_pt_repr(r)}"
89+
for s, (f, r) in pf.methods.items())
90+
8591
# Cell
8692
class Transform(metaclass=_TfmMeta):
8793
"Delegates (`__call__`,`decode`,`setup`) to (<code>encodes</code>,<code>decodes</code>,<code>setups</code>) if `split_idx` matches"
@@ -92,10 +98,8 @@ def __init__(self, enc=None, dec=None, split_idx=None, order=None):
9298
self.init_enc = enc or dec
9399
if not self.init_enc: return
94100

95-
self._d = Dispatcher() # TODO: do we need to hold this reference?
96-
# TODO: do u have to set the name from the original func here? for pipelines to work
97-
# TODO: I don't think this is registering any methods! Why is a func needed then???
98-
for n in _tfm_methods: setattr(self,n,_mk_plum_func(self._d,n))
101+
def identity(x): return x
102+
for n in _tfm_methods: setattr(self,n,Function(identity).dispatch(identity))
99103
if enc:
100104
self.encodes.dispatch(enc)
101105
self.order = getattr(enc,'order',self.order)
@@ -107,18 +111,7 @@ def __init__(self, enc=None, dec=None, split_idx=None, order=None):
107111
def name(self): return getattr(self, '_name', _get_name(self))
108112
def __call__(self, x, **kwargs): return self._call('encodes', x, **kwargs)
109113
def decode (self, x, **kwargs): return self._call('decodes', x, **kwargs)
110-
def __repr__(self):
111-
def _pf_repr(pf):
112-
e = []
113-
for s, (f_, r) in self.encodes.methods.items():
114-
types = ','.join(str(o) for o in s.types)
115-
types = f'({types})'
116-
e.append(f'{f_.__name__}: {types} -> {r}')
117-
return e
118-
r = f'{self.name}:'
119-
r += '\n encodes:\n' + '\n'.join(' ' + o for o in _pf_repr(self.encodes))
120-
r += '\n decodes:\n' + '\n'.join(' ' + o for o in _pf_repr(self.decodes))
121-
return r
114+
def __repr__(self): return f'{self.name}:\nencodes: {_pf_repr(self.encodes)}\ndecodes: {_pf_repr(self.decodes)}'
122115

123116
def setup(self, items=None, train_setup=False):
124117
train_setup = train_setup if self.train_setup is None else self.train_setup
@@ -134,7 +127,7 @@ def _do_call(self, f, x, **kwargs):
134127
ts = [type(self),type(x)] if hasattr(f,'instance') else [type(x)]
135128
_, ret = f.resolve_method(*ts)
136129
ret = ret._type
137-
# plum reads empty return annot as object, fastcore reads as None
130+
# plum reads empty return annotation as object, retain_type expects it as None
138131
if ret is object: ret = None
139132
return retain_type(f(x,**kwargs), x, ret)
140133
res = tuple(self._do_call(f, x_, **kwargs) for x_ in x)
@@ -146,7 +139,7 @@ def setups(self, dl): return dl
146139
add_docs(Transform, decode="Delegate to <code>decodes</code> to undo transform", setup="Delegate to <code>setups</code> to set up transform")
147140

148141
# Cell
149-
#Transform interpret's None return type as no conversion
142+
#Implement the Transform convention that a None return annotation disables conversion
150143
add_conversion_method(object, NoneType, lambda x: x)
151144

152145
# Cell

0 commit comments

Comments
 (0)