@@ -456,107 +456,113 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
456
456
patterns.onOp (
457
457
" AveragePool" , 1 ,
458
458
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
459
- std::string autoPad;
460
- SmallVector<int64_t > dilations;
461
- if (binder.customOpNameStringAttr (autoPad, " auto_pad" , " NOTSET" ))
462
- return failure ();
463
- if (autoPad != " NOTSET" ) {
464
- // TODO: Add support for `auto_pad` != "NOTSET"
465
- return rewriter.notifyMatchFailure (
466
- binder.op , " unsupported conversion: auto_pad != NOTSET" );
467
- }
468
-
469
459
Torch::ValueTensorType resultType;
470
460
Value operand;
471
461
bool ceilMode, countIncludePad;
462
+ std::string autoPad;
472
463
if (binder.tensorOperand (operand) ||
473
464
binder.s64BoolAttr (ceilMode, " ceil_mode" , false ) ||
474
465
binder.s64BoolAttr (countIncludePad, " count_include_pad" , false ) ||
466
+ binder.customOpNameStringAttr (autoPad, " auto_pad" , " NOTSET" ) ||
475
467
binder.tensorResultType (resultType))
476
- return failure ();
468
+ return rewriter.notifyMatchFailure (
469
+ binder.op , " operand/ceil_mode/count_include_pad/auto_pad/"
470
+ " resultType bind failure" );
471
+
477
472
// Determine the rank of input tensor.
478
473
std::optional<unsigned > maybeRank = Torch::getTensorRank (operand);
479
474
if (!maybeRank)
480
475
return rewriter.notifyMatchFailure (binder.op ,
481
476
" Unimplemented: unranked tensor" );
482
477
unsigned rank = *maybeRank;
483
478
484
- SmallVector<int64_t > kernel, padding, strides;
485
- if (binder.s64IntegerArrayAttr (kernel, " kernel_shape" , {})) {
486
- return failure ();
487
- }
488
- if (kernel.size () != rank - 2 ) {
479
+ int64_t spatialRank = rank - 2 ;
480
+ SmallVector<int64_t > kernel, padding, strides, dilations;
481
+
482
+ if (binder.s64IntegerArrayAttr (kernel, " kernel_shape" , {}))
483
+ return rewriter.notifyMatchFailure (binder.op ,
484
+ " kernel_shape bind failure" );
485
+ if (kernel.size () != static_cast <size_t >(spatialRank))
489
486
return rewriter.notifyMatchFailure (
490
487
binder.op , " kernel list size does not match the number of axes" );
491
- }
492
- SmallVector<int64_t > defaultPadding (2 * (rank - 2 ), 0 );
493
- if (binder.s64IntegerArrayAttr (padding, " pads" , defaultPadding)) {
494
- return failure ();
495
- }
496
- if (padding.size () != 2 * (rank - 2 )) {
488
+
489
+ if (binder.s64IntegerArrayAttr (padding, " pads" , {}))
490
+ return rewriter.notifyMatchFailure (binder.op , " pads bind failure" );
491
+ if (!padding.empty () &&
492
+ padding.size () != static_cast <size_t >(2 * spatialRank))
497
493
return rewriter.notifyMatchFailure (
498
- binder.op ,
499
- " padding list size does not match twice the number of axes" );
500
- }
501
- if (binder.s64IntegerArrayAttr (
502
- strides, " strides" , llvm::SmallVector<int64_t >(rank - 2 , 1 ))) {
503
- return failure ();
504
- }
505
- if (strides.size () != 1 && strides.size () != rank - 2 ) {
494
+ binder.op , " padding list must contain (begin,end) pair for each "
495
+ " spatial axis" );
496
+
497
+ if (binder.s64IntegerArrayAttr (strides, " strides" , {}))
498
+ return rewriter.notifyMatchFailure (binder.op , " strides bind failure" );
499
+ if (!strides.empty () &&
500
+ strides.size () != static_cast <size_t >(spatialRank))
506
501
return rewriter.notifyMatchFailure (
507
502
binder.op , " strides list size does not match the number of axes" );
508
- }
509
503
510
- SmallVector<Value> cstKernel, cstPadding, cstStridesDilations;
511
- for (int64_t i : kernel) {
512
- cstKernel.push_back (rewriter.create <Torch::ConstantIntOp>(
513
- binder.getLoc (), rewriter.getI64IntegerAttr (i)));
514
- }
515
- // Onnx pads format: [x1_begin, x2_begin…x1_end, x2_end,…]
516
- // Pytorch pads format: [x1, x2,...] or [x], assume begin==end for all
517
- // axes x.
518
- int64_t paddingSizeHalf = padding.size () / 2 ;
519
- for (int64_t i = 0 ; i < paddingSizeHalf; ++i) {
520
- // Check if onnx padding attribute is symmetric.
521
- if (padding[i] != padding[i + paddingSizeHalf])
522
- return rewriter.notifyMatchFailure (
523
- binder.op , " onnx padding attribute is not symmetric" );
524
- cstPadding.push_back (rewriter.create <Torch::ConstantIntOp>(
525
- binder.getLoc (), rewriter.getI64IntegerAttr (padding[i])));
504
+ if (binder.s64IntegerArrayAttr (dilations, " dilations" , {}))
505
+ return rewriter.notifyMatchFailure (binder.op ,
506
+ " dilations bind failure" );
507
+
508
+ // set default values for padding, strides, and dilations.
509
+ if (padding.empty ())
510
+ padding.resize (spatialRank, 0 );
511
+ if (strides.empty ())
512
+ strides.resize (spatialRank, 1 );
513
+ if (dilations.empty ())
514
+ dilations.resize (spatialRank, 1 );
515
+
516
+ // Padding for the beginning and ending along each spatial axis, it can
517
+ // take any value greater than or equal to 0. The value represent the
518
+ // number of pixels added to the beginning and end part of the
519
+ // corresponding axis. pads format should be as follow [x1_begin,
520
+ // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added
521
+ // at the beginning of axis i and xi_end, the number of pixels added at
522
+ // the end of axis i.
523
+ auto inputTensorType = cast<Torch::ValueTensorType>(operand.getType ());
524
+ if (autoPad != " NOTSET" && autoPad != " VALID" ) {
525
+ const bool isSameLower = autoPad == " SAME_LOWER" ;
526
+ ArrayRef<int64_t > inputShape = inputTensorType.getSizes ();
527
+ padding.resize_for_overwrite (2 * spatialRank);
528
+ for (unsigned dimIdx = 0 ; dimIdx < spatialRank; dimIdx++) {
529
+ const int64_t dilatedKernelSize =
530
+ dilations[dimIdx] * (kernel[dimIdx] - 1 ) + 1 ;
531
+ int64_t totalPad = ((inputShape[dimIdx + 2 ] + strides[dimIdx] - 1 ) /
532
+ strides[dimIdx] -
533
+ 1 ) *
534
+ strides[dimIdx] +
535
+ dilatedKernelSize - inputShape[dimIdx + 2 ];
536
+ totalPad = totalPad >= 0 ? totalPad : 0 ;
537
+ padding[dimIdx] =
538
+ isSameLower ? ((totalPad + 1 ) / 2 ) : (totalPad / 2 );
539
+ padding[spatialRank + dimIdx] = totalPad - padding[dimIdx];
540
+ }
526
541
}
527
- for (int64_t i : strides) {
528
- cstStridesDilations.push_back (rewriter.create <Torch::ConstantIntOp>(
529
- binder.getLoc (), rewriter.getI64IntegerAttr (i)));
542
+
543
+ // If the padding is symmetric then we don't need seperate low/high
544
+ // padding values.
545
+ if (padding.size () == static_cast <size_t >(2 * spatialRank)) {
546
+ bool equal = true ;
547
+ for (int i = 0 ; i < spatialRank; ++i) {
548
+ equal = equal && (padding[i] == padding[i + spatialRank]);
549
+ }
550
+ if (equal)
551
+ padding.resize (spatialRank);
530
552
}
531
553
532
- // No dilations attribute in pytorch avgpool op, so use this trick to
533
- // encode dilation into strides. Then in the following torchtolinalg
534
- // lowering, decode strides into strides + dilation.
554
+ // Since the PyTorch AvgPool op does not contain the `dilation` arg,
555
+ // hence we use the trick of encoding dilation into strides. Then,
556
+ // during the torch->linalg lowering of the `AvgPool` op we decode the
557
+ // `strides` arg into strides values followed by dilation like:
535
558
// [strideDim1,strideDim2,...,dilationDim1,dilationDim2,...]
536
- if (binder.s64IntegerArrayAttr (
537
- dilations, " dilations" ,
538
- llvm::SmallVector<int64_t >(rank - 2 , 1 ))) {
539
- return failure ();
540
- }
541
- for (auto dilation : dilations) {
542
- cstStridesDilations.push_back (rewriter.create <Torch::ConstantIntOp>(
543
- binder.getLoc (), rewriter.getI64IntegerAttr (dilation)));
544
- }
559
+ SmallVector<int64_t > stridesDilations = strides;
560
+ stridesDilations.append (dilations);
545
561
546
- Value kernelSizeList = rewriter.create <Torch::PrimListConstructOp>(
547
- binder.getLoc (),
548
- Torch::ListType::get (Torch::IntType::get (binder.op ->getContext ())),
549
- cstKernel);
550
- Value paddingList = rewriter.create <Torch::PrimListConstructOp>(
551
- binder.getLoc (),
552
- Torch::ListType::get (Torch::IntType::get (binder.op ->getContext ())),
553
- cstPadding);
562
+ Value kernelSizeList = createConstantIntList (binder, rewriter, kernel);
563
+ Value paddingList = createConstantIntList (binder, rewriter, padding);
554
564
Value stridesDilationsList =
555
- rewriter.create <Torch::PrimListConstructOp>(
556
- binder.getLoc (),
557
- Torch::ListType::get (
558
- Torch::IntType::get (binder.op ->getContext ())),
559
- cstStridesDilations);
565
+ createConstantIntList (binder, rewriter, stridesDilations);
560
566
Value cstCeilMode =
561
567
rewriter.create <Torch::ConstantBoolOp>(binder.getLoc (), ceilMode);
562
568
Value cstCountIncludePad = rewriter.create <Torch::ConstantBoolOp>(
0 commit comments