
Commit bf1b7bc

[ONNX] Add support for asymmetric padding for Onnx.AveragePool op
This commit also refactors the lowerings of the ONNX AveragePool and MaxPool ops by creating a common utility that both lowerings use to get the pooling op parameters.

Signed-off-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
1 parent 71cb942 commit bf1b7bc

5 files changed: +175 −200 lines changed


include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h

Lines changed: 9 additions & 0 deletions
@@ -119,6 +119,15 @@ LogicalResult createTorchPermuteOp(OpBinder binder,
                                    SmallVector<int64_t> permuteDims,
                                    Value &permuted);
 
+// Checks the validity of pooling parameters and stores them in the respective
+// vector.
+LogicalResult checkAndGetOnnxPoolingOpParameters(
+    OpBinder binder, ConversionPatternRewriter &rewriter, Type resultDtype,
+    std::string autoPad, int64_t spatialRank, Value &input,
+    SmallVectorImpl<int64_t> &kernelSizeInts,
+    SmallVectorImpl<int64_t> &strideInts, SmallVectorImpl<int64_t> &paddingInts,
+    SmallVectorImpl<int64_t> &dilationInts);
+
 } // namespace mlir::torch::onnx_c
 
 #endif // TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H
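Note: the new helper's call pattern can be seen in the two lowerings below. A minimal sketch of its intended use, assuming a pattern that has already bound `operand`, `autoPad`, and `resultType` and computed `rank`:

// Sketch only; mirrors the call sites added in this commit. `binder`,
// `rewriter`, `operand`, `autoPad`, `resultType`, and `rank` are assumed to
// come from the surrounding OnnxToTorch pattern.
SmallVector<int64_t> kernel, strides, padding, dilations;
if (failed(checkAndGetOnnxPoolingOpParameters(
        binder, rewriter, resultType.getDtype(), autoPad,
        /*spatialRank=*/rank - 2, /*input=*/operand,
        kernel, strides, padding, dilations)))
  return rewriter.notifyMatchFailure(binder.op, "invalid pooling parameters");
// `input` is taken by reference, presumably so the helper can rewrite it
// (e.g. insert an explicit pad op) when the ONNX padding is asymmetric.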

lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp

Lines changed: 26 additions & 86 deletions
@@ -456,107 +456,47 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
   patterns.onOp(
       "AveragePool", 11,
       [](OpBinder binder, ConversionPatternRewriter &rewriter) {
-        std::string autoPad;
-        SmallVector<int64_t> dilations;
-        if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
-          return failure();
-        if (autoPad != "NOTSET") {
-          // TODO: Add support for `auto_pad` != "NOTSET"
-          return rewriter.notifyMatchFailure(
-              binder.op, "unsupported conversion: auto_pad != NOTSET");
-        }
-
         Torch::ValueTensorType resultType;
         Value operand;
-        bool ceilMode, countIncludePad;
+        int64_t ceilMode, countIncludePad;
+        std::string autoPad;
         if (binder.tensorOperand(operand) ||
-            binder.s64BoolAttr(ceilMode, "ceil_mode", false) ||
-            binder.s64BoolAttr(countIncludePad, "count_include_pad", false) ||
+            binder.s64IntegerAttr(ceilMode, "ceil_mode", 0) ||
+            binder.s64IntegerAttr(countIncludePad, "count_include_pad", 0) ||
+            binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET") ||
             binder.tensorResultType(resultType))
-          return failure();
+          return rewriter.notifyMatchFailure(
+              binder.op, "operand/ceil_mode/count_include_pad/auto_pad/"
+                         "resultType bind failure");
+
         // Determine the rank of input tensor.
         std::optional<unsigned> maybeRank = Torch::getTensorRank(operand);
         if (!maybeRank)
           return rewriter.notifyMatchFailure(binder.op,
                                              "Unimplemented: unranked tensor");
         unsigned rank = *maybeRank;
 
-        SmallVector<int64_t> kernel, padding, strides;
-        if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {})) {
-          return failure();
-        }
-        if (kernel.size() != rank - 2) {
-          return rewriter.notifyMatchFailure(
-              binder.op, "kernel list size does not match the number of axes");
-        }
-        SmallVector<int64_t> defaultPadding(2 * (rank - 2), 0);
-        if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) {
-          return failure();
-        }
-        if (padding.size() != 2 * (rank - 2)) {
-          return rewriter.notifyMatchFailure(
-              binder.op,
-              "padding list size does not match twice the number of axes");
-        }
-        if (binder.s64IntegerArrayAttr(
-                strides, "strides", llvm::SmallVector<int64_t>(rank - 2, 1))) {
-          return failure();
-        }
-        if (strides.size() != 1 && strides.size() != rank - 2) {
-          return rewriter.notifyMatchFailure(
-              binder.op, "strides list size does not match the number of axes");
-        }
-
-        SmallVector<Value> cstKernel, cstPadding, cstStridesDilations;
-        for (int64_t i : kernel) {
-          cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
-              binder.getLoc(), rewriter.getI64IntegerAttr(i)));
-        }
-        // Onnx pads format: [x1_begin, x2_begin…x1_end, x2_end,…]
-        // Pytorch pads format: [x1, x2,...] or [x], assume begin==end for all
-        // axes x.
-        int64_t paddingSizeHalf = padding.size() / 2;
-        for (int64_t i = 0; i < paddingSizeHalf; ++i) {
-          // Check if onnx padding attribute is symmetric.
-          if (padding[i] != padding[i + paddingSizeHalf])
-            return rewriter.notifyMatchFailure(
-                binder.op, "onnx padding attribute is not symmetric");
-          cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
-              binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
-        }
-        for (int64_t i : strides) {
-          cstStridesDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
-              binder.getLoc(), rewriter.getI64IntegerAttr(i)));
-        }
+        SmallVector<int64_t> kernel, padding, strides, dilations,
+            stridesDilations;
+        if (failed(checkAndGetOnnxPoolingOpParameters(
+                binder, rewriter, resultType.getDtype(), autoPad,
+                /*spatialRank=*/rank - 2,
+                /*input=*/operand, kernel, strides, padding, dilations)))
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "invalid pooling parameters");
 
-        // No dilations attribute in pytorch avgpool op, so use this trick to
-        // encode dilation into strides. Then in the following torchtolinalg
-        // lowering, decode strides into strides + dilation.
+        // Since the PyTorch AvgPool op does not contain the `dilation` arg,
+        // hence we use the trick of encoding dilation into strides. Then,
+        // during the torch->linalg lowering of the `AvgPool` op we decode the
+        // `strides` arg into strides values followed by dilation like:
         // [strideDim1,strideDim2,...,dilationDim1,dilationDim2,...]
-        if (binder.s64IntegerArrayAttr(
-                dilations, "dilations",
-                llvm::SmallVector<int64_t>(rank - 2, 1))) {
-          return failure();
-        }
-        for (auto dilation : dilations) {
-          cstStridesDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
-              binder.getLoc(), rewriter.getI64IntegerAttr(dilation)));
-        }
+        stridesDilations = strides;
+        stridesDilations.append(dilations);
 
-        Value kernelSizeList = rewriter.create<Torch::PrimListConstructOp>(
-            binder.getLoc(),
-            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
-            cstKernel);
-        Value paddingList = rewriter.create<Torch::PrimListConstructOp>(
-            binder.getLoc(),
-            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
-            cstPadding);
+        Value kernelSizeList = createConstantIntList(binder, rewriter, kernel);
+        Value paddingList = createConstantIntList(binder, rewriter, padding);
         Value stridesDilationsList =
-            rewriter.create<Torch::PrimListConstructOp>(
-                binder.getLoc(),
-                Torch::ListType::get(
-                    Torch::IntType::get(binder.op->getContext())),
-                cstStridesDilations);
+            createConstantIntList(binder, rewriter, stridesDilations);
         Value cstCeilMode =
             rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), ceilMode);
         Value cstCountIncludePad = rewriter.create<Torch::ConstantBoolOp>(
lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 14 additions & 114 deletions
@@ -1124,138 +1124,38 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
       });
   patterns.onOp(
       "MaxPool", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
-        std::string autoPad;
-        if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
-          return rewriter.notifyMatchFailure(binder.op,
-                                             "auto_pad bind failure");
-
         Torch::ValueTensorType resultTypeOut;
         Value operand;
         int64_t ceilMode, storageOrder;
-        // TODO: Add support for indices output and storage_order
+        std::string autoPad;
         if (binder.tensorOperand(operand) ||
             binder.s64IntegerAttr(ceilMode, "ceil_mode", 0) ||
             binder.s64IntegerAttr(storageOrder, "storage_order", 0) ||
+            binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET") ||
             binder.tensorResultTypeAtIndex(resultTypeOut, 0))
           return rewriter.notifyMatchFailure(
-              binder.op,
-              "operand/ceil_mode/storage_order/resultType bind failure");
+              binder.op, "operand/ceil_mode/storage_order/auto_pad/resultType "
+                         "bind failure");
+        // TODO: Add support for storage_order
         if (storageOrder != 0)
           return rewriter.notifyMatchFailure(
               binder.op, "storage_order setting is not supported.");
+
         // Determine the rank of input tensor.
         std::optional<unsigned> maybeRank = Torch::getTensorRank(operand);
         if (!maybeRank)
           return rewriter.notifyMatchFailure(binder.op,
                                              "Unimplemented: unranked tensor");
-        int64_t rank = *maybeRank;
-        int64_t spatial = rank - 2;
+        unsigned rank = *maybeRank;
 
-        SmallVector<int64_t> kernel, padding, strides, dilations;
-        if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {}))
+        SmallVector<int64_t> kernel, padding, strides, dilations,
+            stridesDilations;
+        if (failed(checkAndGetOnnxPoolingOpParameters(
+                binder, rewriter, resultTypeOut.getDtype(), autoPad,
+                /*spatialRank=*/rank - 2,
+                /*input=*/operand, kernel, strides, padding, dilations)))
           return rewriter.notifyMatchFailure(binder.op,
-                                             "kernel_shape bind failure");
-        if (kernel.size() != static_cast<size_t>(spatial))
-          return rewriter.notifyMatchFailure(
-              binder.op, "kernel list size does not match the number of axes");
-        if (binder.s64IntegerArrayAttr(padding, "pads", {}))
-          return rewriter.notifyMatchFailure(binder.op, "pads bind failure");
-        if (!padding.empty() &&
-            padding.size() != static_cast<size_t>(2 * spatial))
-          return rewriter.notifyMatchFailure(
-              binder.op, "padding list must contain (begin,end) pair for each "
-                         "spatial axis");
-        if (binder.s64IntegerArrayAttr(strides, "strides", {}))
-          return rewriter.notifyMatchFailure(binder.op, "strides bind failure");
-        if (!strides.empty() && strides.size() != static_cast<size_t>(spatial))
-          return rewriter.notifyMatchFailure(
-              binder.op, "strides list size does not match the number of axes");
-        if (binder.s64IntegerArrayAttr(dilations, "dilations", {}))
-          return rewriter.notifyMatchFailure(binder.op,
-                                             "dilations bind failure");
-
-        // set default padding
-        if (padding.empty())
-          padding.resize(spatial, 0);
-        if (strides.empty())
-          strides.resize(spatial, 1);
-        if (dilations.empty())
-          dilations.resize(spatial, 1);
-
-        auto inputTensorType = cast<Torch::ValueTensorType>(operand.getType());
-
-        // Padding for the beginning and ending along each spatial axis, it can
-        // take any value greater than or equal to 0. The value represent the
-        // number of pixels added to the beginning and end part of the
-        // corresponding axis. pads format should be as follow [x1_begin,
-        // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added
-        // at the beginning of axis i and xi_end, the number of pixels added at
-        // the end of axis i.
-        if (autoPad != "NOTSET" && autoPad != "VALID") {
-          const bool isSameLower = autoPad == "SAME_LOWER";
-          ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
-          padding.resize_for_overwrite(2 * spatial);
-          for (unsigned dimIdx = 0; dimIdx < spatial; dimIdx++) {
-            const int64_t dilatedKernelSize =
-                dilations[dimIdx] * (kernel[dimIdx] - 1) + 1;
-            int64_t totalPad = ((inputShape[dimIdx + 2] + strides[dimIdx] - 1) /
-                                    strides[dimIdx] -
-                                1) *
-                                   strides[dimIdx] +
-                               dilatedKernelSize - inputShape[dimIdx + 2];
-            totalPad = totalPad >= 0 ? totalPad : 0;
-            padding[dimIdx] =
-                isSameLower ? ((totalPad + 1) / 2) : (totalPad / 2);
-            padding[spatial + dimIdx] = totalPad - padding[dimIdx];
-          }
-        }
-
-        // If the padding is symmetric we can push the padding operation to the
-        // torch operator.
-        if (padding.size() == static_cast<size_t>(2 * spatial)) {
-          bool equal = true;
-          for (int i = 0; i < spatial; ++i) {
-            equal = equal && (padding[i] == padding[i + spatial]);
-          }
-          if (equal)
-            padding.resize(spatial);
-        }
-
-        // Torch pool operators require equal padding on each size of each
-        // dimension so we materialize the padding behavior explicitly and set
-        // the padding to 0.
-        if (padding.size() == static_cast<size_t>(2 * spatial)) {
-          auto operandTy = cast<Torch::ValueTensorType>(operand.getType());
-          llvm::SmallVector<int64_t> shuffledPadding(spatial * 2);
-          llvm::SmallVector<int64_t> paddedShape(operandTy.getSizes());
-          for (int i = 0; i < spatial; ++i) {
-            paddedShape[i + 2] += padding[i] + padding[i + spatial];
-            shuffledPadding[2 * i] = padding[spatial - i - 1];
-            shuffledPadding[2 * i + 1] = padding[2 * spatial - i - 1];
-          }
-
-          Value shuffledPaddingList =
-              createConstantIntList(binder, rewriter, shuffledPadding);
-          Value zero;
-          if (isa<FloatType>(resultTypeOut.getDtype())) {
-            zero = rewriter.create<Torch::ConstantFloatOp>(
-                binder.getLoc(), rewriter.getType<Torch::FloatType>(),
-                rewriter.getF64FloatAttr(
-                    std::numeric_limits<double>::lowest()));
-          } else if (isa<IntegerType>(resultTypeOut.getDtype())) {
-            zero = rewriter.create<Torch::ConstantIntOp>(
-                binder.getLoc(), rewriter.getI64IntegerAttr(
-                                     std::numeric_limits<int64_t>::lowest()));
-          }
-
-          auto paddedInputTy = rewriter.getType<Torch::ValueTensorType>(
-              paddedShape, operandTy.getDtype());
-          operand = rewriter.create<Torch::AtenConstantPadNdOp>(
-              binder.getLoc(), paddedInputTy, operand, shuffledPaddingList,
-              zero);
-          padding.clear();
-          padding.resize(spatial, 0);
-        }
+                                             "invalid pooling parameters");
 
         Value kernelSizeList = createConstantIntList(binder, rewriter, kernel);
         Value paddingList = createConstantIntList(binder, rewriter, padding);
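For context, the `auto_pad` arithmetic deleted here (and presumably now shared via `checkAndGetOnnxPoolingOpParameters`) produces asymmetric pads whenever the total padding for a dimension is odd, which is exactly the case this commit adds support for in the AveragePool lowering. A self-contained worked example with assumed values:

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical values: input spatial size 6, stride 2, kernel 3, dilation 1.
  int64_t inputSize = 6, stride = 2, kernelSize = 3, dilation = 1;
  int64_t dilatedKernelSize = dilation * (kernelSize - 1) + 1;              // 3
  // Same formula as the removed MaxPool code above.
  int64_t totalPad = ((inputSize + stride - 1) / stride - 1) * stride +
                     dilatedKernelSize - inputSize;                         // 1
  totalPad = std::max<int64_t>(totalPad, 0);
  // SAME_UPPER puts the extra pixel at the end, SAME_LOWER at the beginning.
  int64_t upperBegin = totalPad / 2, upperEnd = totalPad - upperBegin;      // 0, 1
  int64_t lowerBegin = (totalPad + 1) / 2, lowerEnd = totalPad - lowerBegin; // 1, 0
  std::printf("SAME_UPPER pads: [%lld, %lld]\n", (long long)upperBegin,
              (long long)upperEnd);
  std::printf("SAME_LOWER pads: [%lld, %lld]\n", (long long)lowerBegin,
              (long long)lowerEnd);
  return 0;
}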
