@@ -1185,138 +1185,38 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
1185
1185
});
1186
1186
patterns.onOp (
1187
1187
" MaxPool" , 12 , [](OpBinder binder, ConversionPatternRewriter &rewriter) {
1188
- std::string autoPad;
1189
- if (binder.customOpNameStringAttr (autoPad, " auto_pad" , " NOTSET" ))
1190
- return rewriter.notifyMatchFailure (binder.op ,
1191
- " auto_pad bind failure" );
1192
-
1193
1188
Torch::ValueTensorType resultTypeOut;
1194
1189
Value operand;
1195
1190
int64_t ceilMode, storageOrder;
1196
- // TODO: Add support for indices output and storage_order
1191
+ std::string autoPad;
1197
1192
if (binder.tensorOperand (operand) ||
1198
1193
binder.s64IntegerAttr (ceilMode, " ceil_mode" , 0 ) ||
1199
1194
binder.s64IntegerAttr (storageOrder, " storage_order" , 0 ) ||
1195
+ binder.customOpNameStringAttr (autoPad, " auto_pad" , " NOTSET" ) ||
1200
1196
binder.tensorResultTypeAtIndex (resultTypeOut, 0 ))
1201
1197
return rewriter.notifyMatchFailure (
1202
- binder.op ,
1203
- " operand/ceil_mode/storage_order/resultType bind failure" );
1198
+ binder.op , " operand/ceil_mode/storage_order/auto_pad/resultType "
1199
+ " bind failure" );
1200
+ // TODO: Add support for storage_order
1204
1201
if (storageOrder != 0 )
1205
1202
return rewriter.notifyMatchFailure (
1206
1203
binder.op , " storage_order setting is not supported." );
1204
+
1207
1205
// Determine the rank of input tensor.
1208
1206
std::optional<unsigned > maybeRank = Torch::getTensorRank (operand);
1209
1207
if (!maybeRank)
1210
1208
return rewriter.notifyMatchFailure (binder.op ,
1211
1209
" Unimplemented: unranked tensor" );
1212
- int64_t rank = *maybeRank;
1213
- int64_t spatial = rank - 2 ;
1210
+ unsigned rank = *maybeRank;
1214
1211
1215
- SmallVector<int64_t > kernel, padding, strides, dilations;
1216
- if (binder.s64IntegerArrayAttr (kernel, " kernel_shape" , {}))
1212
+ SmallVector<int64_t > kernel, padding, strides, dilations,
1213
+ stridesDilations;
1214
+ if (failed (checkAndGetOnnxPoolingOpParameters (
1215
+ binder, rewriter, resultTypeOut.getDtype (), autoPad,
1216
+ /* spatialRank=*/ rank - 2 ,
1217
+ /* input=*/ operand, kernel, strides, padding, dilations)))
1217
1218
return rewriter.notifyMatchFailure (binder.op ,
1218
- " kernel_shape bind failure" );
1219
- if (kernel.size () != static_cast <size_t >(spatial))
1220
- return rewriter.notifyMatchFailure (
1221
- binder.op , " kernel list size does not match the number of axes" );
1222
- if (binder.s64IntegerArrayAttr (padding, " pads" , {}))
1223
- return rewriter.notifyMatchFailure (binder.op , " pads bind failure" );
1224
- if (!padding.empty () &&
1225
- padding.size () != static_cast <size_t >(2 * spatial))
1226
- return rewriter.notifyMatchFailure (
1227
- binder.op , " padding list must contain (begin,end) pair for each "
1228
- " spatial axis" );
1229
- if (binder.s64IntegerArrayAttr (strides, " strides" , {}))
1230
- return rewriter.notifyMatchFailure (binder.op , " strides bind failure" );
1231
- if (!strides.empty () && strides.size () != static_cast <size_t >(spatial))
1232
- return rewriter.notifyMatchFailure (
1233
- binder.op , " strides list size does not match the number of axes" );
1234
- if (binder.s64IntegerArrayAttr (dilations, " dilations" , {}))
1235
- return rewriter.notifyMatchFailure (binder.op ,
1236
- " dilations bind failure" );
1237
-
1238
- // set default padding
1239
- if (padding.empty ())
1240
- padding.resize (spatial, 0 );
1241
- if (strides.empty ())
1242
- strides.resize (spatial, 1 );
1243
- if (dilations.empty ())
1244
- dilations.resize (spatial, 1 );
1245
-
1246
- auto inputTensorType = cast<Torch::ValueTensorType>(operand.getType ());
1247
-
1248
- // Padding for the beginning and ending along each spatial axis, it can
1249
- // take any value greater than or equal to 0. The value represent the
1250
- // number of pixels added to the beginning and end part of the
1251
- // corresponding axis. pads format should be as follow [x1_begin,
1252
- // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added
1253
- // at the beginning of axis i and xi_end, the number of pixels added at
1254
- // the end of axis i.
1255
- if (autoPad != " NOTSET" && autoPad != " VALID" ) {
1256
- const bool isSameLower = autoPad == " SAME_LOWER" ;
1257
- ArrayRef<int64_t > inputShape = inputTensorType.getSizes ();
1258
- padding.resize_for_overwrite (2 * spatial);
1259
- for (unsigned dimIdx = 0 ; dimIdx < spatial; dimIdx++) {
1260
- const int64_t dilatedKernelSize =
1261
- dilations[dimIdx] * (kernel[dimIdx] - 1 ) + 1 ;
1262
- int64_t totalPad = ((inputShape[dimIdx + 2 ] + strides[dimIdx] - 1 ) /
1263
- strides[dimIdx] -
1264
- 1 ) *
1265
- strides[dimIdx] +
1266
- dilatedKernelSize - inputShape[dimIdx + 2 ];
1267
- totalPad = totalPad >= 0 ? totalPad : 0 ;
1268
- padding[dimIdx] =
1269
- isSameLower ? ((totalPad + 1 ) / 2 ) : (totalPad / 2 );
1270
- padding[spatial + dimIdx] = totalPad - padding[dimIdx];
1271
- }
1272
- }
1273
-
1274
- // If the padding is symmetric we can push the padding operation to the
1275
- // torch operator.
1276
- if (padding.size () == static_cast <size_t >(2 * spatial)) {
1277
- bool equal = true ;
1278
- for (int i = 0 ; i < spatial; ++i) {
1279
- equal = equal && (padding[i] == padding[i + spatial]);
1280
- }
1281
- if (equal)
1282
- padding.resize (spatial);
1283
- }
1284
-
1285
- // Torch pool operators require equal padding on each size of each
1286
- // dimension so we materialize the padding behavior explicitly and set
1287
- // the padding to 0.
1288
- if (padding.size () == static_cast <size_t >(2 * spatial)) {
1289
- auto operandTy = cast<Torch::ValueTensorType>(operand.getType ());
1290
- llvm::SmallVector<int64_t > shuffledPadding (spatial * 2 );
1291
- llvm::SmallVector<int64_t > paddedShape (operandTy.getSizes ());
1292
- for (int i = 0 ; i < spatial; ++i) {
1293
- paddedShape[i + 2 ] += padding[i] + padding[i + spatial];
1294
- shuffledPadding[2 * i] = padding[spatial - i - 1 ];
1295
- shuffledPadding[2 * i + 1 ] = padding[2 * spatial - i - 1 ];
1296
- }
1297
-
1298
- Value shuffledPaddingList =
1299
- createConstantIntList (binder, rewriter, shuffledPadding);
1300
- Value zero;
1301
- if (isa<FloatType>(resultTypeOut.getDtype ())) {
1302
- zero = rewriter.create <Torch::ConstantFloatOp>(
1303
- binder.getLoc (), rewriter.getType <Torch::FloatType>(),
1304
- rewriter.getF64FloatAttr (
1305
- std::numeric_limits<double >::lowest ()));
1306
- } else if (isa<IntegerType>(resultTypeOut.getDtype ())) {
1307
- zero = rewriter.create <Torch::ConstantIntOp>(
1308
- binder.getLoc (), rewriter.getI64IntegerAttr (
1309
- std::numeric_limits<int64_t >::lowest ()));
1310
- }
1311
-
1312
- auto paddedInputTy = rewriter.getType <Torch::ValueTensorType>(
1313
- paddedShape, operandTy.getDtype ());
1314
- operand = rewriter.create <Torch::AtenConstantPadNdOp>(
1315
- binder.getLoc (), paddedInputTy, operand, shuffledPaddingList,
1316
- zero);
1317
- padding.clear ();
1318
- padding.resize (spatial, 0 );
1319
- }
1219
+ " invalid pooling parameters" );
1320
1220
1321
1221
Value kernelSizeList = createConstantIntList (binder, rewriter, kernel);
1322
1222
Value paddingList = createConstantIntList (binder, rewriter, padding);
0 commit comments