
Commit a49ee53

Authored by SigureMo, 2742195759, feifei-111, 0x45f, and gouzil
[SOT] merge PaddleSOT into Paddle (PaddlePaddle#57824)
PaddleSOT is a bytecode-level implementation of a Symbolic OpCode Translator for PaddlePaddle. It was originally developed in the [PaddleSOT](https://github.com/PaddlePaddle/PaddleSOT) repository; to keep it consistent across Paddle versions, we are now merging PaddleSOT into Paddle. Thanks to all the contributors of this project! See more details in https://github.com/PaddlePaddle/PaddleSOT/graphs/contributors

---------

Co-authored-by: xiongkun <xiongkun03@baidu.com>
Co-authored-by: feifei-111 <2364819892@qq.com>
Co-authored-by: 0x45f <23097963+0x45f@users.noreply.github.com>
Co-authored-by: gouzil <66515297+gouzil@users.noreply.github.com>
Co-authored-by: 六个骨头 <46243324+zrr1999@users.noreply.github.com>
Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: Wang Xin <xinwang614@gmail.com>
Co-authored-by: haozi <64006169+NotHaozi@users.noreply.github.com>
Co-authored-by: RedContritio <RedContritio@qq.com>
Co-authored-by: Sanbu <96160062+sanbuphy@users.noreply.github.com>
Co-authored-by: Difer <c7070655110@gmail.com>
Co-authored-by: cyberslack_lee <luhputu0815@gmail.com>
Co-authored-by: jjyaoao <jjyaoao@126.com>
Co-authored-by: PuQing <me@puqing.work>
Co-authored-by: Ran chongzhi <57489288+ranchongzhi@users.noreply.github.com>
Co-authored-by: Zhenghai Zhang <65210872+ccsuzzh@users.noreply.github.com>
1 parent 52b86df commit a49ee53

File tree

201 files changed: +22410 additions, -196 deletions


.flake8
Lines changed: 3 additions & 0 deletions

@@ -26,6 +26,9 @@ per-file-ignores =
     # These files need tabs for testing.
     test/dygraph_to_static/test_error.py:E101,W191

+    # Ignore compare with True in sot unittest
+    test/sot/test_dup_top.py:E712
+
     # temp ignore base directory
     python/paddle/base/*:
         E712,
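
For context: flake8's E712 flags literal comparisons such as `x == True`. The new per-file ignore indicates that test/sot/test_dup_top.py makes such comparisons on purpose; the snippet below only illustrates the pattern (the names in it are invented, not taken from the test):

# Illustration only: the comparison style that E712 would normally reject.
def run_twice(x: int) -> bool:
    return (x * 2) == (x + x)

# Spelling the assertion as `== True` keeps an explicit comparison in the
# compiled bytecode, which is presumably what the SOT bytecode tests rely on,
# so the check is silenced for that file instead of rewriting the tests.
assert run_twice(3) == True  # noqa: E712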

paddle/scripts/paddle_build.sh
Lines changed: 4 additions & 8 deletions

@@ -933,23 +933,21 @@ set -ex
 }

 function run_sot_test() {
-    PADDLE_SOT_ROOT=$1
-    PY_VERSION=$2
+    PY_VERSION=$1
     PYTHON_WITH_SPECIFY_VERSION=python$PY_VERSION
     PY_VERSION_NO_DOT=`echo $PY_VERSION | sed 's/\.//g'`

     export STRICT_MODE=1
     export COST_MODEL=False
     export MIN_GRAPH_SIZE=0
+    export SOT_LOG_LEVEL=0

     # Install PaddlePaddle
     $PYTHON_WITH_SPECIFY_VERSION -m pip install ${PADDLE_ROOT}/dist/paddlepaddle-0.0.0-cp${PY_VERSION_NO_DOT}-cp${PY_VERSION_NO_DOT}-linux_x86_64.whl
     # Install PaddleSOT
-    cd $PADDLE_SOT_ROOT
-    $PYTHON_WITH_SPECIFY_VERSION -m pip install -e .
+    cd $PADDLE_ROOT/test/sot/

     # Run unittest
-    cd tests
     failed_tests=()

     for file in ./test_*.py; do
@@ -4128,14 +4126,12 @@ function main() {
         ;;
     cicheck_sot)
         export WITH_SHARED_PHI=ON
-        PADDLE_SOT_ROOT=${PADDLE_ROOT}/sot
-        git clone https://github.com/PaddlePaddle/PaddleSOT.git ${PADDLE_SOT_ROOT}
         PYTHON_VERSIONS=(3.8 3.9 3.10 3.11)
         for PY_VERSION in ${PYTHON_VERSIONS[@]}; do
             ln -sf $(which python${PY_VERSION}) /usr/local/bin/python
             ln -sf $(which pip${PY_VERSION}) /usr/local/bin/pip
             run_setup ${PYTHON_ABI:-""} bdist_wheel ${parallel_number}
-            run_sot_test $PADDLE_SOT_ROOT $PY_VERSION
+            run_sot_test $PY_VERSION
             rm -rf ${PADDLE_ROOT}/build/CMakeCache.txt
         done
         ;;

python/paddle/jit/dy2static/program_translator.py
Lines changed: 3 additions & 11 deletions

@@ -692,30 +692,22 @@ class SymbolicStaticFunction(StaticFunction):
     def __init__(self, function, input_spec=None, **kwargs):
         if input_spec is not None:
             warnings.warn(
-                "\nSymbolic Trace don't support input_spec arguments. It will Will not produce any effect.\n"
+                "\nSymbolic Trace don't support input_spec arguments. It will not produce any effect.\n"
                 "1. You can disable fallback mode by `paddle.jit.to_static(enable_fallback=False)` to switch to AST to static, then you can assign input spec.\n"
             )
         super().__init__(function, input_spec, **kwargs)
         self.last_call_input_spec = None

     def _perform_call(self, *args, **kwargs):
+        from ..sot import symbolic_translate
+
         args, kwargs = self._function_spec.unified_args_and_kwargs(args, kwargs)
         (
             input_args_with_spec,
             input_kwargs_with_spec,
         ) = self._function_spec.args_to_input_spec(args, kwargs)
         self.last_call_input_spec = input_args_with_spec

-        try:
-            from sot import symbolic_translate
-        except:
-            import os
-
-            os.system(
-                "pip install git+https://github.com/PaddlePaddle/PaddleSOT@develop"
-            )
-            from sot import symbolic_translate
-
         build_strategy = self._kwargs.get("build_strategy", None)
         backend = self._kwargs.get("backend", None)
         traced_fun = symbolic_translate(
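
With the try/except gone, `symbolic_translate` is imported directly from the in-tree `paddle.jit.sot` package rather than pip-installed at call time. Below is a rough sketch of the two modes the warning above distinguishes; the function and InputSpec are illustrative and not part of this commit:

import paddle

def double(x):
    return x * 2

# With fallback (SOT) enabled, to_static builds a SymbolicStaticFunction;
# per the warning above, any input_spec passed here would have no effect.
sot_fn = paddle.jit.to_static(double, enable_fallback=True)

# Disabling fallback switches to the AST-based dy2static transform,
# where input_spec can be assigned as usual.
ast_fn = paddle.jit.to_static(
    double,
    input_spec=[paddle.static.InputSpec(shape=[None, 8], dtype="float32")],
    enable_fallback=False,
)

out = sot_fn(paddle.rand([4, 8]))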

python/paddle/jit/sot/__init__.py
Lines changed: 22 additions & 0 deletions

@@ -0,0 +1,22 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import psdb  # noqa: F401
+from .opcode_translator.breakpoint import (  # noqa: F401
+    BM,
+    add_breakpoint,
+    add_event,
+)
+from .opcode_translator.skip_files import skip_function  # noqa: F401
+from .translate import symbolic_translate  # noqa: F401
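
With the package now in-tree, the PaddleSOT entry points are importable from `paddle.jit.sot`. A minimal sketch of `symbolic_translate`, assuming it keeps PaddleSOT's convention of wrapping an eager function and returning a translated callable:

import paddle
from paddle.jit.sot import symbolic_translate

def foo(x, y):
    z = x + y
    return z * 2

x = paddle.rand([4])
y = paddle.rand([4])

# The wrapped callable traces the Python bytecode of `foo` and runs the
# captured graph; eager execution of foo(x, y) should give the same result.
out = symbolic_translate(foo)(x, y)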

python/paddle/jit/sot/infer_meta.py
Lines changed: 282 additions & 0 deletions

@@ -0,0 +1,282 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from paddle.amp.auto_cast import amp_state
+from paddle.base.unique_name import UniqueNameGenerator
+from paddle.base.unique_name import guard as UniqueNameGuard
+from paddle.static import Program
+from paddle.utils import flatten, is_sequence
+
+from .utils import Cache, Singleton, map_if_extend, meta_str
+
+
+class MetaInfo:
+    def __init__(
+        self, shape, dtype, stop_gradient, name, persistable, type, place
+    ):
+        self.name = name
+        self.persistable = persistable
+        self.type = type
+        self.place = place
+        self.shape = shape
+        self.dtype = dtype
+        self.stop_gradient = stop_gradient
+
+    @staticmethod
+    def from_tensor(tensor):
+        # We always use float32 in simulation if AMP is enabled.
+        dtype = tensor.dtype
+        current_amp_state = amp_state()
+        if (
+            dtype == paddle.float16
+            and current_amp_state is not None
+            and current_amp_state["dtype"] == "float16"
+        ):
+            dtype = paddle.float32
+        return MetaInfo(
+            list(tensor.shape),
+            dtype,
+            tensor.stop_gradient,
+            tensor.name,
+            tensor.persistable,
+            tensor.type,
+            tensor.place,
+        )
+
+    def is_dynamic_shape(self):
+        """
+        if -1 in shape, return True
+        else: return False
+        """
+        return -1 in self.shape
+
+    def to_input_spec(self):
+        return paddle.static.InputSpec(
+            self.shape, dtype=self.dtype, stop_gradient=self.stop_gradient
+        )
+
+    def guard_str(self):
+        return f"({self.shape}, {self.dtype}, {self.stop_gradient})"
+
+    def __repr__(self):
+        return meta_str(self.shape, self.dtype, self.stop_gradient)
+
+    def __eq__(self, meta):
+        return (
+            self.shape == meta.shape
+            and self.dtype == meta.dtype
+            and self.stop_gradient == meta.stop_gradient
+        )
+
+    def __hash__(self):
+        return hash((tuple(self.shape), self.dtype, self.stop_gradient))
+
+
+@Singleton
+class VariableCreator:
+    """
+    We use the static graph Variable to infer the meta information of Tensor.
+    This singleton class is used to create Variable for infer meta.
+    """
+
+    def __init__(self):
+        self.var_cache = {}
+        self.main_program = Program()
+        self.startup_program = Program()
+        self.var_name_generator = UniqueNameGenerator("infer_meta_variable_")
+
+    def gen_name(self, meta):
+        name = f"{meta.dtype}_{meta.stop_gradient}"
+        for l in meta.shape:
+            name += f"_{l}"
+        return name
+
+    def create_var(self, meta):
+        var = self.main_program.global_block().create_var(
+            shape=meta.shape,
+            dtype=meta.dtype,
+            stop_gradient=meta.stop_gradient,
+        )
+        assert not isinstance(
+            var, paddle.Tensor
+        ), "Expect a Variable, but got a Tensor."
+        return var
+
+    def get_variable(self, meta):
+        var_feature_name = self.gen_name(meta)
+        if var_feature_name not in self.var_cache:
+            self.var_cache[var_feature_name] = self.create_var(meta)
+        return self.var_cache[var_feature_name]
+
+    def infer_meta(self, func, *args, **kwargs):
+        with paddle.base.framework._dygraph_guard(None), UniqueNameGuard(
+            self.var_name_generator
+        ):
+            args, kwargs = convert_meta_to_variable(
+                args
+            ), convert_meta_to_variable(kwargs)
+
+            with paddle.static.program_guard(
+                self.main_program, self.startup_program
+            ):
+                if isinstance(func, str):
+                    # TODO(Aurelius84): Is length of args always greater than 0?
+                    # Do we need add condition check here?
+                    out = getattr(args[0], func)(*args[1:], **kwargs)
+                else:
+                    out = func(*args, **kwargs)
+
+        return convert_variable_to_meta_info(out)
+
+
+def convert_meta_to_variable(args):
+    return map_if_extend(
+        args,
+        pred=lambda x: isinstance(x, MetaInfo),
+        true_fn=lambda x: VariableCreator().get_variable(x),
+        false_fn=lambda x: x,
+    )
+
+
+def convert_meta_to_input_spec(args):
+    return map_if_extend(
+        args,
+        pred=lambda x: isinstance(x, MetaInfo),
+        true_fn=lambda x: x.to_input_spec(),
+        # TODO(xiongkun): can x be tensor ?
+        false_fn=lambda x: paddle.static.InputSpec.from_tensor(x)
+        if isinstance(x, paddle.Tensor)
+        else x,
+    )
+
+
+def convert_variable_to_meta_info(args):
+    return map_if_extend(
+        args,
+        pred=lambda x: isinstance(x, paddle.static.Variable),
+        true_fn=lambda x: MetaInfo.from_tensor(x),
+        false_fn=lambda x: x,
+    )
+
+
+def infer_meta(func, *args, **kwargs):
+    fn = SpecialInferMeta().get_infermeta_fn(func)
+    if fn:
+        return fn(*args, **kwargs)
+    return VariableCreator().infer_meta(func, *args, **kwargs)
+
+
+def infer_meta_for_layer(layer, *args, **kwargs):
+    assert isinstance(
+        layer, paddle.nn.Layer
+    ), f"Expect a Layer, but got {layer}."
+    layer = paddle.jit.to_static(layer, enable_fallback=False)
+
+    args_, kwargs_ = convert_meta_to_input_spec((args, kwargs))
+
+    (
+        concrete_program,
+        partial_program_layer,
+    ) = layer.forward.get_concrete_program(*args_, **kwargs_)
+
+    out = partial_program_layer._restore_out(
+        paddle.utils.flatten(
+            convert_variable_to_meta_info(concrete_program.outputs)
+        )
+    )
+    layer.forward.rollback()
+    return out
+
+
+@Singleton
+class SpecialInferMeta:
+    """
+    There are some functions that cannot be inferred directly through static graph,
+    and need to be implemented manually. This class is used to implement infer meta
+    for these functions.
+    """
+
+    def __init__(self):
+        pass
+
+    def get_infermeta_fn(self, fn):
+        try:
+            funcname = fn.__name__
+            return getattr(self, f"infermeta_{funcname}")
+        except:
+            pass
+        return None
+
+    def infermeta_grad(
+        self,
+        outputs,
+        inputs,
+        grad_outputs=None,
+        retain_graph=None,
+        create_graph=False,
+        only_inputs=True,
+        allow_unused=False,
+        no_grad_vars=None,
+    ):
+        if not is_sequence(inputs):
+            inputs = [inputs]
+        return inputs
+
+
+@Singleton
+class InferMetaCache(Cache):
+    def key_fn(
+        self, func, *args, **kwargs
+    ):  # args & kwargs have transformed to MetaInfo
+        try:
+            retval = hash(
+                (
+                    func,
+                    tuple(flatten(args)),
+                    tuple(kwargs.keys()),
+                    tuple(flatten(kwargs)),
+                )
+            )
+        except Exception as e:
+            return None
+        return retval
+
+    def value_fn(self, func, *args, **kwargs):
+        return infer_meta(func, *args, **kwargs)
+
+
+@Singleton
+class LayerInferMetaCache(Cache):
+    def key_fn(self, layer, *args, **kwargs):
+        params = [
+            MetaInfo.from_tensor(x)
+            for x in layer.parameters(include_sublayers=True)
+        ]
+        try:
+            retval = hash(
+                (
+                    layer,
+                    tuple(params),
+                    tuple(flatten(args)),
+                    tuple(kwargs.keys()),
+                    tuple(flatten(kwargs)),
+                )
+            )
+        except Exception as e:
+            return None
+        return retval
+
+    def value_fn(self, layer, *args, **kwargs):
+        return infer_meta_for_layer(layer, *args, **kwargs)
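
To see how the pieces above fit together: MetaInfo records shape/dtype metadata, VariableCreator materializes matching static-graph Variables in a scratch Program, and infer_meta runs an op on those Variables only to recover the output metadata, without executing any kernel. A small sketch, assuming the module is importable as `paddle.jit.sot.infer_meta` in a build that contains this commit:

import paddle
from paddle.jit.sot.infer_meta import MetaInfo, infer_meta

# Describe two inputs by metadata only; type and place are not used by
# VariableCreator.create_var, so None is passed here purely for illustration.
x_meta = MetaInfo([4, 8], paddle.float32, True, "x", False, None, None)
y_meta = MetaInfo([8, 16], paddle.float32, True, "y", False, None, None)

# matmul is applied to cached static-graph Variables inside a scratch
# Program, and only the resulting MetaInfo is returned.
out_meta = infer_meta(paddle.matmul, x_meta, y_meta)
print(out_meta)  # expected: meta with shape [4, 16] and dtype float32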
