Skip to content

Commit 52b86df

Browse files
authored
[PIR] Reconstruct the Verify system (PaddlePaddle#58052)
* refine verify of if op * fix * fix * fix * refine * fix * fix * fix * fix
1 parent e638c66 commit 52b86df

32 files changed

+205
-94
lines changed

paddle/cinn/hlir/dialect/operator/ir/manual_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ std::vector<pir::Operation *> GroupOp::ops() {
4444
inner_block->end());
4545
}
4646

47-
void GroupOp::Verify() {}
47+
void GroupOp::VerifySig() {}
4848

4949
void GroupOp::Print(pir::IrPrinter &printer) {
5050
auto &os = printer.os;

paddle/cinn/hlir/dialect/operator/ir/manual_op.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class GroupOp : public pir::Op<GroupOp> {
3636
pir::Block *block();
3737
std::vector<pir::Operation *> ops();
3838

39-
void Verify();
39+
void VerifySig();
4040
void Print(pir::IrPrinter &printer); // NOLINT
4141
};
4242

paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ namespace dialect {
2222

2323
const char* JitKernelOp::attributes_name[attributes_num] = {kAttrName};
2424

25-
void JitKernelOp::Verify() {
25+
void JitKernelOp::VerifySig() {
2626
VLOG(4) << "Verifying inputs, outputs and attributes for: JitKernelOp.";
2727

2828
auto& attributes = this->attributes();

paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class JitKernelOp : public ::pir::Op<JitKernelOp> {
5151

5252
hlir::framework::Instruction* instruction();
5353

54-
void Verify();
54+
void VerifySig();
5555
};
5656

5757
} // namespace dialect

paddle/fluid/ir_adaptor/translator/program_translator.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,8 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation(
367367
true);
368368
}
369369
VLOG(4) << "[general op][conditional_block] IfOp false block translate end.";
370+
371+
operation->Verify();
370372
VLOG(4) << "[general op][conditional_block] IfOp translate end.";
371373
return operation;
372374
}

paddle/fluid/pir/dialect/kernel/ir/kernel_op.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ const char* PhiKernelOp::attributes_name[attributes_num] = { // NOLINT
2525
"kernel_name",
2626
"kernel_key"};
2727

28-
void PhiKernelOp::Verify() {
28+
void PhiKernelOp::VerifySig() {
2929
VLOG(4) << "Verifying inputs, outputs and attributes for: PhiKernelOp.";
3030

3131
auto& attributes = this->attributes();
@@ -64,7 +64,7 @@ const char* LegacyKernelOp::attributes_name[attributes_num] = { // NOLINT
6464
"kernel_name",
6565
"kernel_key"};
6666

67-
void LegacyKernelOp::Verify() {
67+
void LegacyKernelOp::VerifySig() {
6868
VLOG(4) << "Verifying inputs, outputs and attributes for: LegacyKernelOp.";
6969

7070
auto& attributes = this->attributes();

paddle/fluid/pir/dialect/kernel/ir/kernel_op.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class PhiKernelOp : public pir::Op<PhiKernelOp> {
2929
std::string op_name();
3030
std::string kernel_name();
3131
phi::KernelKey kernel_key();
32-
void Verify();
32+
void VerifySig();
3333
};
3434

3535
class LegacyKernelOp : public pir::Op<LegacyKernelOp> {
@@ -41,7 +41,7 @@ class LegacyKernelOp : public pir::Op<LegacyKernelOp> {
4141
std::string op_name();
4242
std::string kernel_name();
4343
phi::KernelKey kernel_key();
44-
void Verify();
44+
void VerifySig();
4545
};
4646

4747
} // namespace dialect

paddle/fluid/pir/dialect/op_generator/op_gen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ class {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{
9999
{build_mutable_attr_is_input}
100100
{build_attr_num_over_1}
101101
{build_mutable_attr_is_input_attr_num_over_1}
102-
void Verify();
102+
void VerifySig();
103103
{get_inputs_and_outputs}
104104
{exclusive_interface}
105105
}};

paddle/fluid/pir/dialect/op_generator/op_verify_gen.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
# verify
1616
OP_VERIFY_TEMPLATE = """
17-
void {op_name}::Verify() {{
17+
void {op_name}::VerifySig() {{
1818
VLOG(4) << "Start Verifying inputs, outputs and attributes for: {op_name}.";
1919
VLOG(4) << "Verifying inputs:";
2020
{{
@@ -36,7 +36,7 @@
3636
"""
3737

3838
GRAD_OP_VERIFY_TEMPLATE = """
39-
void {op_name}::Verify() {{}}
39+
void {op_name}::VerifySig() {{}}
4040
"""
4141

4242
INPUT_TYPE_CHECK_TEMPLATE = """

paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ paddle::dialect::IfOp, paddle::dialect::WhileOp
1919

2020
#include "paddle/phi/core/enforce.h"
2121
#include "paddle/pir/core/builder.h"
22+
#include "paddle/pir/core/builtin_type.h"
2223
#include "paddle/pir/core/ir_printer.h"
2324
#include "paddle/pir/core/operation_utils.h"
2425
#include "paddle/pir/dialect/control_flow/ir/cf_ops.h"
@@ -109,7 +110,74 @@ void IfOp::Print(pir::IrPrinter &printer) {
109110
}
110111
os << "\n }";
111112
}
112-
void IfOp::Verify() {}
113+
void IfOp::VerifySig() {
114+
VLOG(4) << "Start Verifying inputs, outputs and attributes for: IfOp.";
115+
auto input_size = num_operands();
116+
PADDLE_ENFORCE_EQ(
117+
input_size,
118+
1u,
119+
phi::errors::PreconditionNotMet(
120+
"The size %d of inputs must be equal to 1.", input_size));
121+
122+
if ((*this)->operand_source(0).type().isa<pir::DenseTensorType>()) {
123+
PADDLE_ENFORCE(
124+
(*this)
125+
->operand_source(0)
126+
.type()
127+
.dyn_cast<pir::DenseTensorType>()
128+
.dtype()
129+
.isa<pir::BoolType>(),
130+
phi::errors::PreconditionNotMet(
131+
"Type validation failed for the 1th input, it should be a "
132+
"bool DenseTensorType."));
133+
}
134+
135+
PADDLE_ENFORCE_EQ((*this)->num_regions(),
136+
2u,
137+
phi::errors::PreconditionNotMet(
138+
"The size %d of regions must be equal to 2.",
139+
(*this)->num_regions()));
140+
}
141+
142+
void IfOp::VerifyRegion() {
143+
VLOG(4) << "Start Verifying sub regions for: IfOp.";
144+
PADDLE_ENFORCE_EQ(
145+
(*this)->region(0).size(),
146+
1u,
147+
phi::errors::PreconditionNotMet("The size %d of true_region must be 1.",
148+
(*this)->region(0).size()));
149+
150+
if ((*this)->num_results() != 0) {
151+
PADDLE_ENFORCE_EQ(
152+
(*this)->region(0).size(),
153+
(*this)->region(1).size(),
154+
phi::errors::PreconditionNotMet("The size %d of true_region must be "
155+
"equal to the size %d of false_region.",
156+
(*this)->region(0).size(),
157+
(*this)->region(1).size()));
158+
159+
auto *true_last_op = (*this)->region(0).front()->back();
160+
auto *false_last_op = (*this)->region(1).front()->back();
161+
PADDLE_ENFORCE_EQ(true_last_op->isa<pir::YieldOp>(),
162+
true,
163+
phi::errors::PreconditionNotMet(
164+
"The last of true block must be YieldOp"));
165+
PADDLE_ENFORCE_EQ(true_last_op->num_operands(),
166+
(*this)->num_results(),
167+
phi::errors::PreconditionNotMet(
168+
"The size of last of true block op's input must be "
169+
"equal to IfOp's outputs num."));
170+
PADDLE_ENFORCE_EQ(false_last_op->isa<pir::YieldOp>(),
171+
true,
172+
phi::errors::PreconditionNotMet(
173+
"The last of false block must be YieldOp"));
174+
PADDLE_ENFORCE_EQ(false_last_op->num_operands(),
175+
(*this)->num_results(),
176+
phi::errors::PreconditionNotMet(
177+
"The size of last of false block op's input must be "
178+
"equal to IfOp's outputs num."));
179+
}
180+
}
113181

114182
void WhileOp::Build(pir::Builder &builder, // NOLINT
115183
pir::OperationArgument &argument, // NOLINT

paddle/fluid/pir/dialect/operator/ir/control_flow_op.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ class IfOp : public pir::Op<IfOp> {
4141
pir::Block *true_block();
4242
pir::Block *false_block();
4343
void Print(pir::IrPrinter &printer); // NOLINT
44-
void Verify();
44+
void VerifySig();
45+
void VerifyRegion();
4546
};
4647

4748
class WhileOp : public pir::Op<WhileOp> {
@@ -57,7 +58,8 @@ class WhileOp : public pir::Op<WhileOp> {
5758
pir::Block *cond_block();
5859
pir::Block *body_block();
5960
void Print(pir::IrPrinter &printer); // NOLINT
60-
void Verify() {}
61+
void VerifySig() {}
62+
void VerifyRegion() {}
6163
};
6264

6365
} // namespace dialect

paddle/fluid/pir/dialect/operator/ir/manual_op.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ OpInfoTuple AddNOp::GetOpInfo() {
5050
return std::make_tuple(inputs, attributes, outputs, run_time_info, "add_n");
5151
}
5252

53-
void AddNOp::Verify() {
53+
void AddNOp::VerifySig() {
5454
VLOG(4) << "Start Verifying inputs, outputs and attributes for: AddNOp.";
5555
VLOG(4) << "Verifying inputs:";
5656
{
@@ -222,7 +222,7 @@ void AddN_Op::Build(pir::Builder &builder,
222222
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
223223
}
224224

225-
void AddN_Op::Verify() {
225+
void AddN_Op::VerifySig() {
226226
VLOG(4) << "Start Verifying inputs, outputs and attributes for: AddN_Op.";
227227
VLOG(4) << "Verifying inputs:";
228228
{
@@ -345,7 +345,7 @@ void AddNWithKernelOp::Build(pir::Builder &builder,
345345
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
346346
}
347347

348-
void AddNWithKernelOp::Verify() {
348+
void AddNWithKernelOp::VerifySig() {
349349
VLOG(4) << "Start Verifying inputs, outputs and attributes for: "
350350
"AddNWithKernelOp.";
351351
VLOG(4) << "Verifying inputs:";
@@ -561,7 +561,7 @@ void FusedGemmEpilogueOp::Build(pir::Builder &builder,
561561
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
562562
}
563563

564-
void FusedGemmEpilogueOp::Verify() {
564+
void FusedGemmEpilogueOp::VerifySig() {
565565
VLOG(4) << "Start Verifying inputs, outputs and attributes for: "
566566
"FusedGemmEpilogueOp.";
567567
VLOG(4) << "Verifying inputs:";
@@ -833,7 +833,7 @@ void FusedGemmEpilogueGradOp::Build(pir::Builder &builder,
833833
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
834834
}
835835

836-
void FusedGemmEpilogueGradOp::Verify() {}
836+
void FusedGemmEpilogueGradOp::VerifySig() {}
837837

838838
void FusedGemmEpilogueGradOp::InferMeta(phi::InferMetaContext *infer_meta) {
839839
auto fn = PD_INFER_META(phi::FusedGemmEpilogueGradInferMeta);
@@ -983,7 +983,7 @@ void SplitGradOp::Build(pir::Builder &builder,
983983
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
984984
}
985985

986-
void SplitGradOp::Verify() {
986+
void SplitGradOp::VerifySig() {
987987
VLOG(4) << "Start Verifying inputs, outputs and attributes for: SplitGradOp.";
988988
VLOG(4) << "Verifying inputs:";
989989
{

paddle/fluid/pir/dialect/operator/ir/manual_op.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class AddNOp : public pir::Op<AddNOp,
4545
pir::OperationArgument &argument, // NOLINT
4646
pir::Value inputs);
4747

48-
void Verify();
48+
void VerifySig();
4949
pir::Value inputs() { return operand_source(0); }
5050
pir::OpResult out() { return result(0); }
5151
static void InferMeta(phi::InferMetaContext *infer_meta);
@@ -69,7 +69,7 @@ class AddN_Op : public pir::Op<AddN_Op,
6969
pir::OperationArgument &argument, // NOLINT
7070
pir::Value inputs_);
7171

72-
void Verify();
72+
void VerifySig();
7373
pir::Value inputs() { return operand_source(0); }
7474
pir::OpResult out() { return result(0); }
7575

@@ -89,7 +89,7 @@ class AddNWithKernelOp : public pir::Op<AddNWithKernelOp,
8989
pir::OperationArgument &argument, // NOLINT
9090
pir::Value inputs_);
9191

92-
void Verify();
92+
void VerifySig();
9393
pir::Value inputs() { return operand_source(0); }
9494
pir::OpResult out() { return result(0); }
9595

@@ -113,7 +113,7 @@ class FusedGemmEpilogueOp
113113
pir::Value y_,
114114
pir::Value bias_,
115115
pir::AttributeMap attributes);
116-
void Verify();
116+
void VerifySig();
117117
pir::Value x() { return operand_source(0); }
118118
pir::Value y() { return operand_source(1); }
119119
pir::Value bias() { return operand_source(2); }
@@ -141,7 +141,7 @@ class FusedGemmEpilogueGradOp
141141
pir::Value reserve_space_,
142142
pir::Value out_grad_,
143143
pir::AttributeMap attributes);
144-
void Verify();
144+
void VerifySig();
145145
pir::Value x() { return operand_source(0); }
146146
pir::Value y() { return operand_source(1); }
147147
pir::Value reserve_space() { return operand_source(2); }
@@ -169,7 +169,7 @@ class SplitGradOp : public pir::Op<SplitGradOp, OpYamlInfoInterface> {
169169
pir::Value out_grad_,
170170
pir::Value axis_);
171171

172-
void Verify();
172+
void VerifySig();
173173
pir::Value out_grad() { return operand_source(0); }
174174
pir::Value axis() { return operand_source(1); }
175175
pir::OpResult x_grad() { return result(0); }

paddle/pir/core/builtin_op.cc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ void ModuleOp::Destroy() {
8282
}
8383
}
8484

85-
void ModuleOp::Verify() const {
85+
void ModuleOp::VerifySig() const {
8686
VLOG(4) << "Verifying inputs, outputs and attributes for: ModuleOp.";
8787
// Verify inputs:
8888
IR_ENFORCE(num_operands() == 0u, "The size of inputs must be equal to 0.");
@@ -118,7 +118,7 @@ void GetParameterOp::PassStopGradients(OperationArgument &argument) {
118118
pir::ArrayAttribute::get(pir::IrContext::Instance(), outs_stop_gradient));
119119
}
120120

121-
void GetParameterOp::Verify() const {
121+
void GetParameterOp::VerifySig() const {
122122
VLOG(4) << "Verifying inputs, outputs and attributes for: GetParameterOp.";
123123
// Verify inputs:
124124
IR_ENFORCE(num_operands() == 0u, "The size of inputs must be equal to 0.");
@@ -144,7 +144,7 @@ void SetParameterOp::Build(Builder &builder, // NOLINT
144144
argument.AddAttribute(attributes_name[0],
145145
pir::StrAttribute::get(builder.ir_context(), name));
146146
}
147-
void SetParameterOp::Verify() const {
147+
void SetParameterOp::VerifySig() const {
148148
VLOG(4) << "Verifying inputs, outputs and attributes for: SetParameterOp.";
149149
// Verify inputs:
150150
IR_ENFORCE(num_operands() == 1, "The size of outputs must be equal to 1.");
@@ -170,7 +170,7 @@ void ShadowOutputOp::Build(Builder &builder, // NOLINT
170170
argument.AddAttribute(attributes_name[0],
171171
pir::StrAttribute::get(builder.ir_context(), name));
172172
}
173-
void ShadowOutputOp::Verify() const {
173+
void ShadowOutputOp::VerifySig() const {
174174
VLOG(4) << "Verifying inputs, outputs and attributes for: ShadowOutputOp.";
175175
// Verify inputs:
176176
IR_ENFORCE(num_operands() == 1, "The size of outputs must be equal to 1.");
@@ -198,7 +198,7 @@ void CombineOp::Build(Builder &builder,
198198
PassStopGradientsDefaultly(argument);
199199
}
200200

201-
void CombineOp::Verify() const {
201+
void CombineOp::VerifySig() const {
202202
// outputs.size() == 1
203203
IR_ENFORCE(num_results() == 1u, "The size of outputs must be equal to 1.");
204204

@@ -260,7 +260,7 @@ void SliceOp::PassStopGradients(OperationArgument &argument, int index) {
260260
pir::ArrayAttribute::get(pir::IrContext::Instance(), outs_stop_gradient));
261261
}
262262

263-
void SliceOp::Verify() const {
263+
void SliceOp::VerifySig() const {
264264
// inputs.size() == 1
265265
auto input_size = num_operands();
266266
IR_ENFORCE(
@@ -364,7 +364,7 @@ void SplitOp::PassStopGradients(OperationArgument &argument) {
364364
pir::ArrayAttribute::get(pir::IrContext::Instance(), outs_stop_gradient));
365365
}
366366

367-
void SplitOp::Verify() const {
367+
void SplitOp::VerifySig() const {
368368
// inputs.size() == 1
369369
IR_ENFORCE(num_operands() == 1u, "The size of inputs must be equal to 1.");
370370

@@ -393,7 +393,7 @@ void ConstantOp::Build(Builder &builder,
393393
argument.output_types.push_back(output_type);
394394
}
395395

396-
void ConstantOp::Verify() const {
396+
void ConstantOp::VerifySig() const {
397397
IR_ENFORCE(num_operands() == 0, "The size of inputs must be equal to 0.");
398398
IR_ENFORCE(num_results() == 1, "The size of outputs must be equal to 1.");
399399
IR_ENFORCE(attributes().count("value") > 0, "must has value attribute");

0 commit comments

Comments
 (0)