Skip to content

Commit 4e2d0fd

Browse files
[ONNX] Add support for asymmetric padding for Onnx.AveragePool op (#3923)
This commit adds support for the asymmetric padding for Onnx's AveragePool op. This commit also extends the Torch->Linalg lowering of the pooling ops to consider asymmetric padding during the output dim computation. --------- Signed-off-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
1 parent 389541f commit 4e2d0fd

File tree

5 files changed

+158
-79
lines changed

5 files changed

+158
-79
lines changed

include/torch-mlir/Conversion/TorchToLinalg/Utils.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,16 @@ Value getOutputDimForConvOps(OpBuilder &b, Location loc, Value in,
4747
Value kernelSizeInt, Value strideInt,
4848
bool ceilMode = false);
4949

50+
// Helper function to caculate the output tensor dims for pooling-like ops.
51+
// Along each dim:
52+
// dim_out =
53+
// floor((dim_in + totalPadding - dilation * (kernelSize - 1) - 1) / stride) +
54+
// 1
55+
Value getOutputDimForPoolOps(OpBuilder &b, Location loc, Value in,
56+
int64_t totalPadding, int64_t leftPadding,
57+
Value dilationInt, Value kernelSizeInt,
58+
Value strideInt, bool ceilMode);
59+
5060
// As above but for transposed convolution ops
5161
// Along each dim:
5262
// dim_out =

lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp

Lines changed: 81 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -456,107 +456,113 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
456456
patterns.onOp(
457457
"AveragePool", 1,
458458
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
459-
std::string autoPad;
460-
SmallVector<int64_t> dilations;
461-
if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
462-
return failure();
463-
if (autoPad != "NOTSET") {
464-
// TODO: Add support for `auto_pad` != "NOTSET"
465-
return rewriter.notifyMatchFailure(
466-
binder.op, "unsupported conversion: auto_pad != NOTSET");
467-
}
468-
469459
Torch::ValueTensorType resultType;
470460
Value operand;
471461
bool ceilMode, countIncludePad;
462+
std::string autoPad;
472463
if (binder.tensorOperand(operand) ||
473464
binder.s64BoolAttr(ceilMode, "ceil_mode", false) ||
474465
binder.s64BoolAttr(countIncludePad, "count_include_pad", false) ||
466+
binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET") ||
475467
binder.tensorResultType(resultType))
476-
return failure();
468+
return rewriter.notifyMatchFailure(
469+
binder.op, "operand/ceil_mode/count_include_pad/auto_pad/"
470+
"resultType bind failure");
471+
477472
// Determine the rank of input tensor.
478473
std::optional<unsigned> maybeRank = Torch::getTensorRank(operand);
479474
if (!maybeRank)
480475
return rewriter.notifyMatchFailure(binder.op,
481476
"Unimplemented: unranked tensor");
482477
unsigned rank = *maybeRank;
483478

484-
SmallVector<int64_t> kernel, padding, strides;
485-
if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {})) {
486-
return failure();
487-
}
488-
if (kernel.size() != rank - 2) {
479+
int64_t spatialRank = rank - 2;
480+
SmallVector<int64_t> kernel, padding, strides, dilations;
481+
482+
if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {}))
483+
return rewriter.notifyMatchFailure(binder.op,
484+
"kernel_shape bind failure");
485+
if (kernel.size() != static_cast<size_t>(spatialRank))
489486
return rewriter.notifyMatchFailure(
490487
binder.op, "kernel list size does not match the number of axes");
491-
}
492-
SmallVector<int64_t> defaultPadding(2 * (rank - 2), 0);
493-
if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) {
494-
return failure();
495-
}
496-
if (padding.size() != 2 * (rank - 2)) {
488+
489+
if (binder.s64IntegerArrayAttr(padding, "pads", {}))
490+
return rewriter.notifyMatchFailure(binder.op, "pads bind failure");
491+
if (!padding.empty() &&
492+
padding.size() != static_cast<size_t>(2 * spatialRank))
497493
return rewriter.notifyMatchFailure(
498-
binder.op,
499-
"padding list size does not match twice the number of axes");
500-
}
501-
if (binder.s64IntegerArrayAttr(
502-
strides, "strides", llvm::SmallVector<int64_t>(rank - 2, 1))) {
503-
return failure();
504-
}
505-
if (strides.size() != 1 && strides.size() != rank - 2) {
494+
binder.op, "padding list must contain (begin,end) pair for each "
495+
"spatial axis");
496+
497+
if (binder.s64IntegerArrayAttr(strides, "strides", {}))
498+
return rewriter.notifyMatchFailure(binder.op, "strides bind failure");
499+
if (!strides.empty() &&
500+
strides.size() != static_cast<size_t>(spatialRank))
506501
return rewriter.notifyMatchFailure(
507502
binder.op, "strides list size does not match the number of axes");
508-
}
509503

510-
SmallVector<Value> cstKernel, cstPadding, cstStridesDilations;
511-
for (int64_t i : kernel) {
512-
cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
513-
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
514-
}
515-
// Onnx pads format: [x1_begin, x2_begin…x1_end, x2_end,…]
516-
// Pytorch pads format: [x1, x2,...] or [x], assume begin==end for all
517-
// axes x.
518-
int64_t paddingSizeHalf = padding.size() / 2;
519-
for (int64_t i = 0; i < paddingSizeHalf; ++i) {
520-
// Check if onnx padding attribute is symmetric.
521-
if (padding[i] != padding[i + paddingSizeHalf])
522-
return rewriter.notifyMatchFailure(
523-
binder.op, "onnx padding attribute is not symmetric");
524-
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
525-
binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
504+
if (binder.s64IntegerArrayAttr(dilations, "dilations", {}))
505+
return rewriter.notifyMatchFailure(binder.op,
506+
"dilations bind failure");
507+
508+
// set default values for padding, strides, and dilations.
509+
if (padding.empty())
510+
padding.resize(spatialRank, 0);
511+
if (strides.empty())
512+
strides.resize(spatialRank, 1);
513+
if (dilations.empty())
514+
dilations.resize(spatialRank, 1);
515+
516+
// Padding for the beginning and ending along each spatial axis, it can
517+
// take any value greater than or equal to 0. The value represent the
518+
// number of pixels added to the beginning and end part of the
519+
// corresponding axis. pads format should be as follow [x1_begin,
520+
// x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added
521+
// at the beginning of axis i and xi_end, the number of pixels added at
522+
// the end of axis i.
523+
auto inputTensorType = cast<Torch::ValueTensorType>(operand.getType());
524+
if (autoPad != "NOTSET" && autoPad != "VALID") {
525+
const bool isSameLower = autoPad == "SAME_LOWER";
526+
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
527+
padding.resize_for_overwrite(2 * spatialRank);
528+
for (unsigned dimIdx = 0; dimIdx < spatialRank; dimIdx++) {
529+
const int64_t dilatedKernelSize =
530+
dilations[dimIdx] * (kernel[dimIdx] - 1) + 1;
531+
int64_t totalPad = ((inputShape[dimIdx + 2] + strides[dimIdx] - 1) /
532+
strides[dimIdx] -
533+
1) *
534+
strides[dimIdx] +
535+
dilatedKernelSize - inputShape[dimIdx + 2];
536+
totalPad = totalPad >= 0 ? totalPad : 0;
537+
padding[dimIdx] =
538+
isSameLower ? ((totalPad + 1) / 2) : (totalPad / 2);
539+
padding[spatialRank + dimIdx] = totalPad - padding[dimIdx];
540+
}
526541
}
527-
for (int64_t i : strides) {
528-
cstStridesDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
529-
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
542+
543+
// If the padding is symmetric then we don't need seperate low/high
544+
// padding values.
545+
if (padding.size() == static_cast<size_t>(2 * spatialRank)) {
546+
bool equal = true;
547+
for (int i = 0; i < spatialRank; ++i) {
548+
equal = equal && (padding[i] == padding[i + spatialRank]);
549+
}
550+
if (equal)
551+
padding.resize(spatialRank);
530552
}
531553

532-
// No dilations attribute in pytorch avgpool op, so use this trick to
533-
// encode dilation into strides. Then in the following torchtolinalg
534-
// lowering, decode strides into strides + dilation.
554+
// Since the PyTorch AvgPool op does not contain the `dilation` arg,
555+
// hence we use the trick of encoding dilation into strides. Then,
556+
// during the torch->linalg lowering of the `AvgPool` op we decode the
557+
// `strides` arg into strides values followed by dilation like:
535558
// [strideDim1,strideDim2,...,dilationDim1,dilationDim2,...]
536-
if (binder.s64IntegerArrayAttr(
537-
dilations, "dilations",
538-
llvm::SmallVector<int64_t>(rank - 2, 1))) {
539-
return failure();
540-
}
541-
for (auto dilation : dilations) {
542-
cstStridesDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
543-
binder.getLoc(), rewriter.getI64IntegerAttr(dilation)));
544-
}
559+
SmallVector<int64_t> stridesDilations = strides;
560+
stridesDilations.append(dilations);
545561

546-
Value kernelSizeList = rewriter.create<Torch::PrimListConstructOp>(
547-
binder.getLoc(),
548-
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
549-
cstKernel);
550-
Value paddingList = rewriter.create<Torch::PrimListConstructOp>(
551-
binder.getLoc(),
552-
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
553-
cstPadding);
562+
Value kernelSizeList = createConstantIntList(binder, rewriter, kernel);
563+
Value paddingList = createConstantIntList(binder, rewriter, padding);
554564
Value stridesDilationsList =
555-
rewriter.create<Torch::PrimListConstructOp>(
556-
binder.getLoc(),
557-
Torch::ListType::get(
558-
Torch::IntType::get(binder.op->getContext())),
559-
cstStridesDilations);
565+
createConstantIntList(binder, rewriter, stridesDilations);
560566
Value cstCeilMode =
561567
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), ceilMode);
562568
Value cstCountIncludePad = rewriter.create<Torch::ConstantBoolOp>(

lib/Conversion/TorchToLinalg/Pooling.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,18 +83,24 @@ computeOutputTensor(Operation *op, ConversionPatternRewriter &rewriter,
8383
Value N = getDimOp(rewriter, loc, self, 0);
8484
Value C = getDimOp(rewriter, loc, self, 1);
8585

86-
SmallVector<Value> paddingIntValues =
87-
getAsConstantIntValues(rewriter, loc, paddingInts);
8886
SmallVector<Value> dilationIntValues =
8987
getAsConstantIntValues(rewriter, loc, dilationInts);
9088
SmallVector<Value> strideIntValues =
9189
getAsConstantIntValues(rewriter, loc, strideInts);
9290

9391
// Get dimension size for each dimension and calculate output size
9492
for (int64_t i = dimensionality - 1; i > -1; --i) {
93+
// In case of asymmetric padding the total padding value would be the sum of
94+
// low and high padding. And, in case of symmetric padding it would just be
95+
// the double of padding value for the corresponding dimension.
96+
int64_t totalPadding = paddingInts[i] * 2;
97+
if ((int64_t)paddingInts.size() == 2 * dimensionality)
98+
totalPadding = paddingInts[i] + paddingInts[i + dimensionality];
99+
95100
Value dimSize = getDimOp(rewriter, loc, self, i + 2);
96-
Value outDim = torch_to_linalg::getOutputDimForConvOps(
97-
rewriter, loc, dimSize, paddingIntValues[i], dilationIntValues[i],
101+
Value outDim = torch_to_linalg::getOutputDimForPoolOps(
102+
rewriter, loc, dimSize, /*totalPadding=*/totalPadding,
103+
/*leftPadding=*/paddingInts[i], dilationIntValues[i],
98104
kernelSizeIntValues[i], strideIntValues[i], ceilMode);
99105
outTensorShape.insert(outTensorShape.begin(), {outDim});
100106
}

lib/Conversion/TorchToLinalg/Utils.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,53 @@ Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc,
135135
return castIntToIndex(b, loc, out);
136136
}
137137

138+
Value torch_to_linalg::getOutputDimForPoolOps(OpBuilder &b, Location loc,
139+
Value in, int64_t totalPadding,
140+
int64_t leftPadding,
141+
Value dilationInt,
142+
Value kernelSizeInt,
143+
Value strideInt, bool ceilMode) {
144+
Value c1 = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(1));
145+
Value totalPaddingIntCst =
146+
b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(totalPadding));
147+
148+
// in + totalPadding
149+
Value inAddTotalPadding = b.createOrFold<arith::AddIOp>(
150+
loc, castIndexToInt64(b, loc, in), totalPaddingIntCst);
151+
152+
// dilation * (kernelSize - 1)
153+
Value kernelSizeSub1 = b.createOrFold<arith::SubIOp>(loc, kernelSizeInt, c1);
154+
Value dilationTimesKernelSize =
155+
b.createOrFold<arith::MulIOp>(loc, dilationInt, kernelSizeSub1);
156+
157+
Value temp = b.createOrFold<arith::SubIOp>(loc, inAddTotalPadding,
158+
dilationTimesKernelSize);
159+
Value dividend = b.createOrFold<arith::SubIOp>(loc, temp, c1);
160+
Value division;
161+
if (ceilMode)
162+
division = b.createOrFold<arith::CeilDivSIOp>(loc, dividend, strideInt);
163+
else
164+
division = b.createOrFold<arith::FloorDivSIOp>(loc, dividend, strideInt);
165+
Value out = b.createOrFold<arith::AddIOp>(loc, division, c1);
166+
167+
if (!ceilMode)
168+
return castIntToIndex(b, loc, out);
169+
170+
Value outMinusOneTimesStride =
171+
b.createOrFold<arith::MulIOp>(loc, division, strideInt);
172+
Value leftPaddingIntCst =
173+
b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(leftPadding));
174+
Value inAddLeftPadding = b.createOrFold<arith::AddIOp>(
175+
loc, castIndexToInt64(b, loc, in), leftPaddingIntCst);
176+
177+
auto reduceOutputDimCond = b.createOrFold<arith::CmpIOp>(
178+
loc, arith::CmpIPredicate::uge, outMinusOneTimesStride, inAddLeftPadding);
179+
180+
auto reducedDim =
181+
b.createOrFold<arith::SelectOp>(loc, reduceOutputDimCond, division, out);
182+
return castIntToIndex(b, loc, reducedDim);
183+
}
184+
138185
Value torch_to_linalg::getOutputDimForConvTransposeOps(
139186
OpBuilder &b, Location loc, Value in, Value paddingInt, Value dilationInt,
140187
Value kernelSizeInt, Value strideInt, Value outputPaddingInt) {

test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -996,6 +996,16 @@ func.func @test_averagepool_with_padding(%arg0: !torch.vtensor<[1,20,64,48],f32>
996996

997997
// -----
998998

999+
// CHECK-LABEL: @test_averagepool_with_asymmetric_padding
1000+
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[1,1024,6,6],f32>
1001+
func.func @test_averagepool_with_asymmetric_padding(%arg1: !torch.vtensor<[1,1024,6,6],f32>) -> !torch.vtensor<[1,1024,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.opset_versions = {ai.onnx.contrib = 1000 : si64, ai.onnx.ml = 3 : si64, ai.onnx.preview.training = 1 : si64, ai.onnx.training = 1 : si64, com.microsoft = 1 : si64, com.microsoft.experimental = 1 : si64, com.microsoft.nchwc = 1 : si64, com.ms.internal.nhwc = 1 : si64, org.pytorch.aten = 1 : si64}, torch.onnx_meta.producer_name = "onnx.quantize", torch.onnx_meta.producer_version = "0.1.0"} {
1002+
%1 = torch.operator "onnx.AveragePool"(%arg1) {torch.onnx.auto_pad = "NOTSET", torch.onnx.ceil_mode = 0 : si64, torch.onnx.count_include_pad = 0 : si64, torch.onnx.kernel_shape = [7 : si64, 7 : si64], torch.onnx.pads = [0 : si64, 0 : si64, 1 : si64, 1 : si64], torch.onnx.strides = [1 : si64, 1 : si64]} : (!torch.vtensor<[1,1024,6,6],f32>) -> !torch.vtensor<[1,1024,1,1],f32>
1003+
// CHECK: torch.aten.avg_pool2d %[[ARG]], {{.*}} : !torch.vtensor<[1,1024,6,6],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1024,1,1],f32>
1004+
return %1 : !torch.vtensor<[1,1024,1,1],f32>
1005+
}
1006+
1007+
// -----
1008+
9991009
// CHECK-LABEL: @test_conv_with_strides_no_padding
10001010
func.func @test_conv_with_strides_no_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>, %arg1: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,3,2],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
10011011
// CHECK: %[[C0:.*]] = torch.constant.int 0

0 commit comments

Comments
 (0)