Commit e688390
[ONNX] Add support for asymmetric padding for Onnx.AveragePool op

This commit also refactors the lowerings of the ONNX AveragePool and MaxPool ops by extracting a common utility, shared by both lowerings, that validates and retrieves the pooling op parameters.

Signed-off-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>

1 parent 80a3dfd commit e688390
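For context: ONNX lays out the pads attribute as [x1_begin, x2_begin, ..., x1_end, x2_end, ...], and padding is asymmetric when a begin value differs from its end counterpart. A minimal standalone sketch (my illustration, not code from this commit) of the symmetry check that previously caused such AveragePool inputs to be rejected:

#include <cstdint>
#include <cstdio>
#include <vector>

// ONNX pads layout: [x1_begin, x2_begin, ..., x1_end, x2_end, ...].
// Symmetric means every begin value equals its end counterpart, which is the
// only case PyTorch pooling ops can express with a single per-axis pad value.
bool isSymmetricPadding(const std::vector<int64_t> &pads) {
  size_t half = pads.size() / 2;
  for (size_t i = 0; i < half; ++i)
    if (pads[i] != pads[i + half])
      return false;
  return true;
}

int main() {
  // begin pads (0, 0), end pads (1, 1): asymmetric. The old AveragePool
  // lowering rejected this ("onnx padding attribute is not symmetric");
  // this commit makes it legalize.
  std::vector<int64_t> pads{0, 0, 1, 1};
  std::printf("symmetric: %s\n", isSymmetricPadding(pads) ? "yes" : "no");
  return 0;
}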

File tree: 5 files changed, +173 / -200 lines changed

include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h

Lines changed: 9 additions & 0 deletions
@@ -130,6 +130,15 @@ LogicalResult createDequantizeTensor(ConversionPatternRewriter &rewriter,
                                      Location loc, Value input, Value scale,
                                      Value zeroPoint, Value &output);
 
+// Checks the validity of pooling parameters and stores them in the respective
+// vectors.
+LogicalResult checkAndGetOnnxPoolingOpParameters(
+    OpBinder binder, ConversionPatternRewriter &rewriter, Type resultDtype,
+    std::string autoPad, int64_t spatialRank, Value &input,
+    SmallVectorImpl<int64_t> &kernelSizeInts,
+    SmallVectorImpl<int64_t> &strideInts, SmallVectorImpl<int64_t> &paddingInts,
+    SmallVectorImpl<int64_t> &dilationInts);
+
 } // namespace mlir::torch::onnx_c
 
 #endif // TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H
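The header only declares the interface. As a rough standalone sketch of the kind of validation and defaulting such a helper performs, inferred from the call sites in the files below rather than the committed implementation:

#include <cassert>
#include <cstdint>
#include <vector>

// Minimal sketch, NOT the torch-mlir implementation: returns false if the
// pooling attribute lists are inconsistent with the spatial rank, and fills
// in the ONNX defaults (pads = 0, strides/dilations = 1) otherwise.
bool checkPoolingParams(int64_t spatialRank, std::vector<int64_t> &kernel,
                        std::vector<int64_t> &strides,
                        std::vector<int64_t> &pads,
                        std::vector<int64_t> &dilations) {
  if (kernel.size() != static_cast<size_t>(spatialRank))
    return false; // kernel_shape needs one entry per spatial axis.
  if (pads.empty())
    pads.assign(2 * spatialRank, 0); // ONNX default: no padding.
  if (pads.size() != static_cast<size_t>(2 * spatialRank))
    return false; // pads stores [x1_begin, x2_begin, ..., x1_end, x2_end, ...].
  if (strides.empty())
    strides.assign(spatialRank, 1);
  if (dilations.empty())
    dilations.assign(spatialRank, 1);
  return strides.size() == static_cast<size_t>(spatialRank) &&
         dilations.size() == static_cast<size_t>(spatialRank);
}

int main() {
  std::vector<int64_t> kernel{3, 3}, strides, pads{1, 1, 2, 2}, dilations;
  assert(checkPoolingParams(/*spatialRank=*/2, kernel, strides, pads,
                            dilations));
  // pads is asymmetric here (begin=1, end=2 per axis); the commit teaches
  // the lowering to handle exactly this case instead of bailing out.
  return 0;
}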

lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp

Lines changed: 26 additions & 86 deletions
@@ -456,107 +456,47 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
   patterns.onOp(
       "AveragePool", 1,
       [](OpBinder binder, ConversionPatternRewriter &rewriter) {
-        std::string autoPad;
-        SmallVector<int64_t> dilations;
-        if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
-          return failure();
-        if (autoPad != "NOTSET") {
-          // TODO: Add support for `auto_pad` != "NOTSET"
-          return rewriter.notifyMatchFailure(
-              binder.op, "unsupported conversion: auto_pad != NOTSET");
-        }
-
         Torch::ValueTensorType resultType;
         Value operand;
-        bool ceilMode, countIncludePad;
+        int64_t ceilMode, countIncludePad;
+        std::string autoPad;
         if (binder.tensorOperand(operand) ||
-            binder.s64BoolAttr(ceilMode, "ceil_mode", false) ||
-            binder.s64BoolAttr(countIncludePad, "count_include_pad", false) ||
+            binder.s64IntegerAttr(ceilMode, "ceil_mode", 0) ||
+            binder.s64IntegerAttr(countIncludePad, "count_include_pad", 0) ||
+            binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET") ||
             binder.tensorResultType(resultType))
-          return failure();
+          return rewriter.notifyMatchFailure(
+              binder.op, "operand/ceil_mode/count_include_pad/auto_pad/"
+                         "resultType bind failure");
+
         // Determine the rank of input tensor.
         std::optional<unsigned> maybeRank = Torch::getTensorRank(operand);
         if (!maybeRank)
           return rewriter.notifyMatchFailure(binder.op,
                                              "Unimplemented: unranked tensor");
         unsigned rank = *maybeRank;
 
-        SmallVector<int64_t> kernel, padding, strides;
-        if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {})) {
-          return failure();
-        }
-        if (kernel.size() != rank - 2) {
-          return rewriter.notifyMatchFailure(
-              binder.op, "kernel list size does not match the number of axes");
-        }
-        SmallVector<int64_t> defaultPadding(2 * (rank - 2), 0);
-        if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) {
-          return failure();
-        }
-        if (padding.size() != 2 * (rank - 2)) {
-          return rewriter.notifyMatchFailure(
-              binder.op,
-              "padding list size does not match twice the number of axes");
-        }
-        if (binder.s64IntegerArrayAttr(
-                strides, "strides", llvm::SmallVector<int64_t>(rank - 2, 1))) {
-          return failure();
-        }
-        if (strides.size() != 1 && strides.size() != rank - 2) {
-          return rewriter.notifyMatchFailure(
-              binder.op, "strides list size does not match the number of axes");
-        }
-
-        SmallVector<Value> cstKernel, cstPadding, cstStridesDilations;
-        for (int64_t i : kernel) {
-          cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
-              binder.getLoc(), rewriter.getI64IntegerAttr(i)));
-        }
-        // Onnx pads format: [x1_begin, x2_begin…x1_end, x2_end,…]
-        // Pytorch pads format: [x1, x2,...] or [x], assume begin==end for all
-        // axes x.
-        int64_t paddingSizeHalf = padding.size() / 2;
-        for (int64_t i = 0; i < paddingSizeHalf; ++i) {
-          // Check if onnx padding attribute is symmetric.
-          if (padding[i] != padding[i + paddingSizeHalf])
-            return rewriter.notifyMatchFailure(
-                binder.op, "onnx padding attribute is not symmetric");
-          cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
-              binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
-        }
-        for (int64_t i : strides) {
-          cstStridesDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
-              binder.getLoc(), rewriter.getI64IntegerAttr(i)));
-        }
+        SmallVector<int64_t> kernel, padding, strides, dilations,
+            stridesDilations;
+        if (failed(checkAndGetOnnxPoolingOpParameters(
+                binder, rewriter, resultType.getDtype(), autoPad,
+                /*spatialRank=*/rank - 2,
+                /*input=*/operand, kernel, strides, padding, dilations)))
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "invalid pooling parameters");
 
-        // No dilations attribute in pytorch avgpool op, so use this trick to
-        // encode dilation into strides. Then in the following torchtolinalg
-        // lowering, decode strides into strides + dilation.
+        // Since the PyTorch AvgPool op does not contain the `dilation` arg,
+        // hence we use the trick of encoding dilation into strides. Then,
+        // during the torch->linalg lowering of the `AvgPool` op we decode the
+        // `strides` arg into strides values followed by dilation like:
         // [strideDim1,strideDim2,...,dilationDim1,dilationDim2,...]
-        if (binder.s64IntegerArrayAttr(
-                dilations, "dilations",
-                llvm::SmallVector<int64_t>(rank - 2, 1))) {
-          return failure();
-        }
-        for (auto dilation : dilations) {
-          cstStridesDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
-              binder.getLoc(), rewriter.getI64IntegerAttr(dilation)));
-        }
+        stridesDilations = strides;
+        stridesDilations.append(dilations);
 
-        Value kernelSizeList = rewriter.create<Torch::PrimListConstructOp>(
-            binder.getLoc(),
-            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
-            cstKernel);
-        Value paddingList = rewriter.create<Torch::PrimListConstructOp>(
-            binder.getLoc(),
-            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
-            cstPadding);
+        Value kernelSizeList = createConstantIntList(binder, rewriter, kernel);
+        Value paddingList = createConstantIntList(binder, rewriter, padding);
         Value stridesDilationsList =
-            rewriter.create<Torch::PrimListConstructOp>(
-                binder.getLoc(),
-                Torch::ListType::get(
-                    Torch::IntType::get(binder.op->getContext())),
-                cstStridesDilations);
+            createConstantIntList(binder, rewriter, stridesDilations);
         Value cstCeilMode =
             rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), ceilMode);
         Value cstCountIncludePad = rewriter.create<Torch::ConstantBoolOp>(
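The comment in the new code describes packing strides and dilations into one list because the Torch average-pool op has no dilation argument. As a standalone illustration of the encode/decode round trip (plain C++, independent of the MLIR APIs; the variable names are mine):

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Encode (ONNX -> Torch side): concatenate strides then dilations into a
  // single list, since the Torch op takes no dilation argument.
  std::vector<int64_t> strides{2, 2}, dilations{1, 3};
  std::vector<int64_t> stridesDilations = strides;
  stridesDilations.insert(stridesDilations.end(), dilations.begin(),
                          dilations.end());
  // -> [strideDim1, strideDim2, dilationDim1, dilationDim2]

  // Decode (Torch -> Linalg side): split the packed list back in half.
  size_t half = stridesDilations.size() / 2;
  std::vector<int64_t> decodedStrides(stridesDilations.begin(),
                                      stridesDilations.begin() + half);
  std::vector<int64_t> decodedDilations(stridesDilations.begin() + half,
                                        stridesDilations.end());
  assert(decodedStrides == strides && decodedDilations == dilations);
  return 0;
}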

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 14 additions & 114 deletions
@@ -1185,138 +1185,38 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
       });
   patterns.onOp(
       "MaxPool", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
-        std::string autoPad;
-        if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
-          return rewriter.notifyMatchFailure(binder.op,
-                                             "auto_pad bind failure");
-
         Torch::ValueTensorType resultTypeOut;
         Value operand;
         int64_t ceilMode, storageOrder;
-        // TODO: Add support for indices output and storage_order
+        std::string autoPad;
         if (binder.tensorOperand(operand) ||
             binder.s64IntegerAttr(ceilMode, "ceil_mode", 0) ||
             binder.s64IntegerAttr(storageOrder, "storage_order", 0) ||
+            binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET") ||
             binder.tensorResultTypeAtIndex(resultTypeOut, 0))
           return rewriter.notifyMatchFailure(
-              binder.op,
-              "operand/ceil_mode/storage_order/resultType bind failure");
+              binder.op, "operand/ceil_mode/storage_order/auto_pad/resultType "
+                         "bind failure");
+        // TODO: Add support for storage_order
         if (storageOrder != 0)
           return rewriter.notifyMatchFailure(
              binder.op, "storage_order setting is not supported.");
+
        // Determine the rank of input tensor.
        std::optional<unsigned> maybeRank = Torch::getTensorRank(operand);
        if (!maybeRank)
          return rewriter.notifyMatchFailure(binder.op,
                                             "Unimplemented: unranked tensor");
-        int64_t rank = *maybeRank;
-        int64_t spatial = rank - 2;
+        unsigned rank = *maybeRank;
 
-        SmallVector<int64_t> kernel, padding, strides, dilations;
-        if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {}))
+        SmallVector<int64_t> kernel, padding, strides, dilations,
+            stridesDilations;
+        if (failed(checkAndGetOnnxPoolingOpParameters(
+                binder, rewriter, resultTypeOut.getDtype(), autoPad,
+                /*spatialRank=*/rank - 2,
+                /*input=*/operand, kernel, strides, padding, dilations)))
           return rewriter.notifyMatchFailure(binder.op,
-                                             "kernel_shape bind failure");
-        if (kernel.size() != static_cast<size_t>(spatial))
-          return rewriter.notifyMatchFailure(
-              binder.op, "kernel list size does not match the number of axes");
-        if (binder.s64IntegerArrayAttr(padding, "pads", {}))
-          return rewriter.notifyMatchFailure(binder.op, "pads bind failure");
-        if (!padding.empty() &&
-            padding.size() != static_cast<size_t>(2 * spatial))
-          return rewriter.notifyMatchFailure(
-              binder.op, "padding list must contain (begin,end) pair for each "
-                         "spatial axis");
-        if (binder.s64IntegerArrayAttr(strides, "strides", {}))
-          return rewriter.notifyMatchFailure(binder.op, "strides bind failure");
-        if (!strides.empty() && strides.size() != static_cast<size_t>(spatial))
-          return rewriter.notifyMatchFailure(
-              binder.op, "strides list size does not match the number of axes");
-        if (binder.s64IntegerArrayAttr(dilations, "dilations", {}))
-          return rewriter.notifyMatchFailure(binder.op,
-                                             "dilations bind failure");
-
-        // set default padding
-        if (padding.empty())
-          padding.resize(spatial, 0);
-        if (strides.empty())
-          strides.resize(spatial, 1);
-        if (dilations.empty())
-          dilations.resize(spatial, 1);
-
-        auto inputTensorType = cast<Torch::ValueTensorType>(operand.getType());
-
-        // Padding for the beginning and ending along each spatial axis, it can
-        // take any value greater than or equal to 0. The value represent the
-        // number of pixels added to the beginning and end part of the
-        // corresponding axis. pads format should be as follow [x1_begin,
-        // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added
-        // at the beginning of axis i and xi_end, the number of pixels added at
-        // the end of axis i.
-        if (autoPad != "NOTSET" && autoPad != "VALID") {
-          const bool isSameLower = autoPad == "SAME_LOWER";
-          ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
-          padding.resize_for_overwrite(2 * spatial);
-          for (unsigned dimIdx = 0; dimIdx < spatial; dimIdx++) {
-            const int64_t dilatedKernelSize =
-                dilations[dimIdx] * (kernel[dimIdx] - 1) + 1;
-            int64_t totalPad = ((inputShape[dimIdx + 2] + strides[dimIdx] - 1) /
-                                    strides[dimIdx] -
-                                1) *
-                                   strides[dimIdx] +
-                               dilatedKernelSize - inputShape[dimIdx + 2];
-            totalPad = totalPad >= 0 ? totalPad : 0;
-            padding[dimIdx] =
-                isSameLower ? ((totalPad + 1) / 2) : (totalPad / 2);
-            padding[spatial + dimIdx] = totalPad - padding[dimIdx];
-          }
-        }
-
-        // If the padding is symmetric we can push the padding operation to the
-        // torch operator.
-        if (padding.size() == static_cast<size_t>(2 * spatial)) {
-          bool equal = true;
-          for (int i = 0; i < spatial; ++i) {
-            equal = equal && (padding[i] == padding[i + spatial]);
-          }
-          if (equal)
-            padding.resize(spatial);
-        }
-
-        // Torch pool operators require equal padding on each size of each
-        // dimension so we materialize the padding behavior explicitly and set
-        // the padding to 0.
-        if (padding.size() == static_cast<size_t>(2 * spatial)) {
-          auto operandTy = cast<Torch::ValueTensorType>(operand.getType());
-          llvm::SmallVector<int64_t> shuffledPadding(spatial * 2);
-          llvm::SmallVector<int64_t> paddedShape(operandTy.getSizes());
-          for (int i = 0; i < spatial; ++i) {
-            paddedShape[i + 2] += padding[i] + padding[i + spatial];
-            shuffledPadding[2 * i] = padding[spatial - i - 1];
-            shuffledPadding[2 * i + 1] = padding[2 * spatial - i - 1];
-          }
-
-          Value shuffledPaddingList =
-              createConstantIntList(binder, rewriter, shuffledPadding);
-          Value zero;
-          if (isa<FloatType>(resultTypeOut.getDtype())) {
-            zero = rewriter.create<Torch::ConstantFloatOp>(
-                binder.getLoc(), rewriter.getType<Torch::FloatType>(),
-                rewriter.getF64FloatAttr(
-                    std::numeric_limits<double>::lowest()));
-          } else if (isa<IntegerType>(resultTypeOut.getDtype())) {
-            zero = rewriter.create<Torch::ConstantIntOp>(
-                binder.getLoc(), rewriter.getI64IntegerAttr(
-                                     std::numeric_limits<int64_t>::lowest()));
-          }
-
-          auto paddedInputTy = rewriter.getType<Torch::ValueTensorType>(
-              paddedShape, operandTy.getDtype());
-          operand = rewriter.create<Torch::AtenConstantPadNdOp>(
-              binder.getLoc(), paddedInputTy, operand, shuffledPaddingList,
-              zero);
-          padding.clear();
-          padding.resize(spatial, 0);
-        }
+                                             "invalid pooling parameters");
 
         Value kernelSizeList = createConstantIntList(binder, rewriter, kernel);
         Value paddingList = createConstantIntList(binder, rewriter, padding);
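The auto_pad handling removed from this file (and presumably folded into the shared utility) computes, per spatial axis, the total padding needed so the output covers ceil(inputSize / stride) windows, then splits it between begin and end; SAME_LOWER puts the extra pixel at the beginning. A minimal standalone sketch of that arithmetic, with hypothetical sample values:

#include <cstdint>
#include <cstdio>

int main() {
  // One spatial axis, mirroring the formula in the removed MaxPool code.
  const int64_t in = 7, stride = 2, kernel = 3, dilation = 1;
  const int64_t dilatedKernel = dilation * (kernel - 1) + 1;
  // ceil(in / stride) via integer arithmetic, as in the diff.
  int64_t totalPad =
      ((in + stride - 1) / stride - 1) * stride + dilatedKernel - in;
  totalPad = totalPad >= 0 ? totalPad : 0; // never negative
  const bool isSameLower = false;          // SAME_UPPER in this example
  int64_t padBegin = isSameLower ? (totalPad + 1) / 2 : totalPad / 2;
  int64_t padEnd = totalPad - padBegin;
  // For in=7, stride=2, kernel=3: totalPad=2, so pads are (1, 1); but e.g.
  // in=8 gives totalPad=1, i.e. asymmetric pads (0, 1) under SAME_UPPER.
  std::printf("begin=%lld end=%lld\n", (long long)padBegin, (long long)padEnd);
  return 0;
}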
