diff --git a/include/torch-mlir/Conversion/TorchToLinalg/Utils.h b/include/torch-mlir/Conversion/TorchToLinalg/Utils.h index b59d183b4084..2069bf381ab3 100644 --- a/include/torch-mlir/Conversion/TorchToLinalg/Utils.h +++ b/include/torch-mlir/Conversion/TorchToLinalg/Utils.h @@ -47,6 +47,16 @@ Value getOutputDimForConvOps(OpBuilder &b, Location loc, Value in, Value kernelSizeInt, Value strideInt, bool ceilMode = false); +// Helper function to caculate the output tensor dims for pooling-like ops. +// Along each dim: +// dim_out = +// floor((dim_in + totalPadding - dilation * (kernelSize - 1) - 1) / stride) + +// 1 +Value getOutputDimForPoolOps(OpBuilder &b, Location loc, Value in, + int64_t totalPadding, int64_t leftPadding, + Value dilationInt, Value kernelSizeInt, + Value strideInt, bool ceilMode); + // As above but for transposed convolution ops // Along each dim: // dim_out = diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 95ae068369c1..8f14515e425c 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -456,24 +456,19 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( patterns.onOp( "AveragePool", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { - std::string autoPad; - SmallVector dilations; - if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) - return failure(); - if (autoPad != "NOTSET") { - // TODO: Add support for `auto_pad` != "NOTSET" - return rewriter.notifyMatchFailure( - binder.op, "unsupported conversion: auto_pad != NOTSET"); - } - Torch::ValueTensorType resultType; Value operand; bool ceilMode, countIncludePad; + std::string autoPad; if (binder.tensorOperand(operand) || binder.s64BoolAttr(ceilMode, "ceil_mode", false) || binder.s64BoolAttr(countIncludePad, "count_include_pad", false) || + binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET") || binder.tensorResultType(resultType)) - return failure(); + return rewriter.notifyMatchFailure( + binder.op, "operand/ceil_mode/count_include_pad/auto_pad/" + "resultType bind failure"); + // Determine the rank of input tensor. std::optional maybeRank = Torch::getTensorRank(operand); if (!maybeRank) @@ -481,82 +476,93 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( "Unimplemented: unranked tensor"); unsigned rank = *maybeRank; - SmallVector kernel, padding, strides; - if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {})) { - return failure(); - } - if (kernel.size() != rank - 2) { + int64_t spatialRank = rank - 2; + SmallVector kernel, padding, strides, dilations; + + if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {})) + return rewriter.notifyMatchFailure(binder.op, + "kernel_shape bind failure"); + if (kernel.size() != static_cast(spatialRank)) return rewriter.notifyMatchFailure( binder.op, "kernel list size does not match the number of axes"); - } - SmallVector defaultPadding(2 * (rank - 2), 0); - if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) { - return failure(); - } - if (padding.size() != 2 * (rank - 2)) { + + if (binder.s64IntegerArrayAttr(padding, "pads", {})) + return rewriter.notifyMatchFailure(binder.op, "pads bind failure"); + if (!padding.empty() && + padding.size() != static_cast(2 * spatialRank)) return rewriter.notifyMatchFailure( - binder.op, - "padding list size does not match twice the number of axes"); - } - if (binder.s64IntegerArrayAttr( - strides, "strides", llvm::SmallVector(rank - 2, 1))) { - return failure(); - } - if (strides.size() != 1 && strides.size() != rank - 2) { + binder.op, "padding list must contain (begin,end) pair for each " + "spatial axis"); + + if (binder.s64IntegerArrayAttr(strides, "strides", {})) + return rewriter.notifyMatchFailure(binder.op, "strides bind failure"); + if (!strides.empty() && + strides.size() != static_cast(spatialRank)) return rewriter.notifyMatchFailure( binder.op, "strides list size does not match the number of axes"); - } - SmallVector cstKernel, cstPadding, cstStridesDilations; - for (int64_t i : kernel) { - cstKernel.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i))); - } - // Onnx pads format: [x1_begin, x2_begin…x1_end, x2_end,…] - // Pytorch pads format: [x1, x2,...] or [x], assume begin==end for all - // axes x. - int64_t paddingSizeHalf = padding.size() / 2; - for (int64_t i = 0; i < paddingSizeHalf; ++i) { - // Check if onnx padding attribute is symmetric. - if (padding[i] != padding[i + paddingSizeHalf]) - return rewriter.notifyMatchFailure( - binder.op, "onnx padding attribute is not symmetric"); - cstPadding.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(padding[i]))); + if (binder.s64IntegerArrayAttr(dilations, "dilations", {})) + return rewriter.notifyMatchFailure(binder.op, + "dilations bind failure"); + + // set default values for padding, strides, and dilations. + if (padding.empty()) + padding.resize(spatialRank, 0); + if (strides.empty()) + strides.resize(spatialRank, 1); + if (dilations.empty()) + dilations.resize(spatialRank, 1); + + // Padding for the beginning and ending along each spatial axis, it can + // take any value greater than or equal to 0. The value represent the + // number of pixels added to the beginning and end part of the + // corresponding axis. pads format should be as follow [x1_begin, + // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added + // at the beginning of axis i and xi_end, the number of pixels added at + // the end of axis i. + auto inputTensorType = cast(operand.getType()); + if (autoPad != "NOTSET" && autoPad != "VALID") { + const bool isSameLower = autoPad == "SAME_LOWER"; + ArrayRef inputShape = inputTensorType.getSizes(); + padding.resize_for_overwrite(2 * spatialRank); + for (unsigned dimIdx = 0; dimIdx < spatialRank; dimIdx++) { + const int64_t dilatedKernelSize = + dilations[dimIdx] * (kernel[dimIdx] - 1) + 1; + int64_t totalPad = ((inputShape[dimIdx + 2] + strides[dimIdx] - 1) / + strides[dimIdx] - + 1) * + strides[dimIdx] + + dilatedKernelSize - inputShape[dimIdx + 2]; + totalPad = totalPad >= 0 ? totalPad : 0; + padding[dimIdx] = + isSameLower ? ((totalPad + 1) / 2) : (totalPad / 2); + padding[spatialRank + dimIdx] = totalPad - padding[dimIdx]; + } } - for (int64_t i : strides) { - cstStridesDilations.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i))); + + // If the padding is symmetric then we don't need seperate low/high + // padding values. + if (padding.size() == static_cast(2 * spatialRank)) { + bool equal = true; + for (int i = 0; i < spatialRank; ++i) { + equal = equal && (padding[i] == padding[i + spatialRank]); + } + if (equal) + padding.resize(spatialRank); } - // No dilations attribute in pytorch avgpool op, so use this trick to - // encode dilation into strides. Then in the following torchtolinalg - // lowering, decode strides into strides + dilation. + // Since the PyTorch AvgPool op does not contain the `dilation` arg, + // hence we use the trick of encoding dilation into strides. Then, + // during the torch->linalg lowering of the `AvgPool` op we decode the + // `strides` arg into strides values followed by dilation like: // [strideDim1,strideDim2,...,dilationDim1,dilationDim2,...] - if (binder.s64IntegerArrayAttr( - dilations, "dilations", - llvm::SmallVector(rank - 2, 1))) { - return failure(); - } - for (auto dilation : dilations) { - cstStridesDilations.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(dilation))); - } + SmallVector stridesDilations = strides; + stridesDilations.append(dilations); - Value kernelSizeList = rewriter.create( - binder.getLoc(), - Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), - cstKernel); - Value paddingList = rewriter.create( - binder.getLoc(), - Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), - cstPadding); + Value kernelSizeList = createConstantIntList(binder, rewriter, kernel); + Value paddingList = createConstantIntList(binder, rewriter, padding); Value stridesDilationsList = - rewriter.create( - binder.getLoc(), - Torch::ListType::get( - Torch::IntType::get(binder.op->getContext())), - cstStridesDilations); + createConstantIntList(binder, rewriter, stridesDilations); Value cstCeilMode = rewriter.create(binder.getLoc(), ceilMode); Value cstCountIncludePad = rewriter.create( diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index 45268452a992..e307c39973c8 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -83,8 +83,6 @@ computeOutputTensor(Operation *op, ConversionPatternRewriter &rewriter, Value N = getDimOp(rewriter, loc, self, 0); Value C = getDimOp(rewriter, loc, self, 1); - SmallVector paddingIntValues = - getAsConstantIntValues(rewriter, loc, paddingInts); SmallVector dilationIntValues = getAsConstantIntValues(rewriter, loc, dilationInts); SmallVector strideIntValues = @@ -92,9 +90,17 @@ computeOutputTensor(Operation *op, ConversionPatternRewriter &rewriter, // Get dimension size for each dimension and calculate output size for (int64_t i = dimensionality - 1; i > -1; --i) { + // In case of asymmetric padding the total padding value would be the sum of + // low and high padding. And, in case of symmetric padding it would just be + // the double of padding value for the corresponding dimension. + int64_t totalPadding = paddingInts[i] * 2; + if ((int64_t)paddingInts.size() == 2 * dimensionality) + totalPadding = paddingInts[i] + paddingInts[i + dimensionality]; + Value dimSize = getDimOp(rewriter, loc, self, i + 2); - Value outDim = torch_to_linalg::getOutputDimForConvOps( - rewriter, loc, dimSize, paddingIntValues[i], dilationIntValues[i], + Value outDim = torch_to_linalg::getOutputDimForPoolOps( + rewriter, loc, dimSize, /*totalPadding=*/totalPadding, + /*leftPadding=*/paddingInts[i], dilationIntValues[i], kernelSizeIntValues[i], strideIntValues[i], ceilMode); outTensorShape.insert(outTensorShape.begin(), {outDim}); } diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index aa4ae60b76e1..e98ad5dca084 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -135,6 +135,53 @@ Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc, return castIntToIndex(b, loc, out); } +Value torch_to_linalg::getOutputDimForPoolOps(OpBuilder &b, Location loc, + Value in, int64_t totalPadding, + int64_t leftPadding, + Value dilationInt, + Value kernelSizeInt, + Value strideInt, bool ceilMode) { + Value c1 = b.create(loc, b.getI64IntegerAttr(1)); + Value totalPaddingIntCst = + b.create(loc, b.getI64IntegerAttr(totalPadding)); + + // in + totalPadding + Value inAddTotalPadding = b.createOrFold( + loc, castIndexToInt64(b, loc, in), totalPaddingIntCst); + + // dilation * (kernelSize - 1) + Value kernelSizeSub1 = b.createOrFold(loc, kernelSizeInt, c1); + Value dilationTimesKernelSize = + b.createOrFold(loc, dilationInt, kernelSizeSub1); + + Value temp = b.createOrFold(loc, inAddTotalPadding, + dilationTimesKernelSize); + Value dividend = b.createOrFold(loc, temp, c1); + Value division; + if (ceilMode) + division = b.createOrFold(loc, dividend, strideInt); + else + division = b.createOrFold(loc, dividend, strideInt); + Value out = b.createOrFold(loc, division, c1); + + if (!ceilMode) + return castIntToIndex(b, loc, out); + + Value outMinusOneTimesStride = + b.createOrFold(loc, division, strideInt); + Value leftPaddingIntCst = + b.create(loc, b.getI64IntegerAttr(leftPadding)); + Value inAddLeftPadding = b.createOrFold( + loc, castIndexToInt64(b, loc, in), leftPaddingIntCst); + + auto reduceOutputDimCond = b.createOrFold( + loc, arith::CmpIPredicate::uge, outMinusOneTimesStride, inAddLeftPadding); + + auto reducedDim = + b.createOrFold(loc, reduceOutputDimCond, division, out); + return castIntToIndex(b, loc, reducedDim); +} + Value torch_to_linalg::getOutputDimForConvTransposeOps( OpBuilder &b, Location loc, Value in, Value paddingInt, Value dilationInt, Value kernelSizeInt, Value strideInt, Value outputPaddingInt) { diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 48a1547a740f..8d83dc181987 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -996,6 +996,16 @@ func.func @test_averagepool_with_padding(%arg0: !torch.vtensor<[1,20,64,48],f32> // ----- +// CHECK-LABEL: @test_averagepool_with_asymmetric_padding +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[1,1024,6,6],f32> +func.func @test_averagepool_with_asymmetric_padding(%arg1: !torch.vtensor<[1,1024,6,6],f32>) -> !torch.vtensor<[1,1024,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.opset_versions = {ai.onnx.contrib = 1000 : si64, ai.onnx.ml = 3 : si64, ai.onnx.preview.training = 1 : si64, ai.onnx.training = 1 : si64, com.microsoft = 1 : si64, com.microsoft.experimental = 1 : si64, com.microsoft.nchwc = 1 : si64, com.ms.internal.nhwc = 1 : si64, org.pytorch.aten = 1 : si64}, torch.onnx_meta.producer_name = "onnx.quantize", torch.onnx_meta.producer_version = "0.1.0"} { + %1 = torch.operator "onnx.AveragePool"(%arg1) {torch.onnx.auto_pad = "NOTSET", torch.onnx.ceil_mode = 0 : si64, torch.onnx.count_include_pad = 0 : si64, torch.onnx.kernel_shape = [7 : si64, 7 : si64], torch.onnx.pads = [0 : si64, 0 : si64, 1 : si64, 1 : si64], torch.onnx.strides = [1 : si64, 1 : si64]} : (!torch.vtensor<[1,1024,6,6],f32>) -> !torch.vtensor<[1,1024,1,1],f32> + // CHECK: torch.aten.avg_pool2d %[[ARG]], {{.*}} : !torch.vtensor<[1,1024,6,6],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1024,1,1],f32> + return %1 : !torch.vtensor<[1,1024,1,1],f32> +} + +// ----- + // CHECK-LABEL: @test_conv_with_strides_no_padding func.func @test_conv_with_strides_no_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>, %arg1: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,3,2],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[C0:.*]] = torch.constant.int 0