@@ -1124,138 +1124,38 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
1124
1124
});
1125
1125
patterns.onOp (
1126
1126
" MaxPool" , 12 , [](OpBinder binder, ConversionPatternRewriter &rewriter) {
1127
- std::string autoPad;
1128
- if (binder.customOpNameStringAttr (autoPad, " auto_pad" , " NOTSET" ))
1129
- return rewriter.notifyMatchFailure (binder.op ,
1130
- " auto_pad bind failure" );
1131
-
1132
1127
Torch::ValueTensorType resultTypeOut;
1133
1128
Value operand;
1134
1129
int64_t ceilMode, storageOrder;
1135
- // TODO: Add support for indices output and storage_order
1130
+ std::string autoPad;
1136
1131
if (binder.tensorOperand (operand) ||
1137
1132
binder.s64IntegerAttr (ceilMode, " ceil_mode" , 0 ) ||
1138
1133
binder.s64IntegerAttr (storageOrder, " storage_order" , 0 ) ||
1134
+ binder.customOpNameStringAttr (autoPad, " auto_pad" , " NOTSET" ) ||
1139
1135
binder.tensorResultTypeAtIndex (resultTypeOut, 0 ))
1140
1136
return rewriter.notifyMatchFailure (
1141
- binder.op ,
1142
- " operand/ceil_mode/storage_order/resultType bind failure" );
1137
+ binder.op , " operand/ceil_mode/storage_order/auto_pad/resultType "
1138
+ " bind failure" );
1139
+ // TODO: Add support for storage_order
1143
1140
if (storageOrder != 0 )
1144
1141
return rewriter.notifyMatchFailure (
1145
1142
binder.op , " storage_order setting is not supported." );
1143
+
1146
1144
// Determine the rank of input tensor.
1147
1145
std::optional<unsigned > maybeRank = Torch::getTensorRank (operand);
1148
1146
if (!maybeRank)
1149
1147
return rewriter.notifyMatchFailure (binder.op ,
1150
1148
" Unimplemented: unranked tensor" );
1151
- int64_t rank = *maybeRank;
1152
- int64_t spatial = rank - 2 ;
1149
+ unsigned rank = *maybeRank;
1153
1150
1154
- SmallVector<int64_t > kernel, padding, strides, dilations;
1155
- if (binder.s64IntegerArrayAttr (kernel, " kernel_shape" , {}))
1151
+ SmallVector<int64_t > kernel, padding, strides, dilations,
1152
+ stridesDilations;
1153
+ if (failed (checkAndGetOnnxPoolingOpParameters (
1154
+ binder, rewriter, resultTypeOut.getDtype (), autoPad,
1155
+ /* spatialRank=*/ rank - 2 ,
1156
+ /* input=*/ operand, kernel, strides, padding, dilations)))
1156
1157
return rewriter.notifyMatchFailure (binder.op ,
1157
- " kernel_shape bind failure" );
1158
- if (kernel.size () != static_cast <size_t >(spatial))
1159
- return rewriter.notifyMatchFailure (
1160
- binder.op , " kernel list size does not match the number of axes" );
1161
- if (binder.s64IntegerArrayAttr (padding, " pads" , {}))
1162
- return rewriter.notifyMatchFailure (binder.op , " pads bind failure" );
1163
- if (!padding.empty () &&
1164
- padding.size () != static_cast <size_t >(2 * spatial))
1165
- return rewriter.notifyMatchFailure (
1166
- binder.op , " padding list must contain (begin,end) pair for each "
1167
- " spatial axis" );
1168
- if (binder.s64IntegerArrayAttr (strides, " strides" , {}))
1169
- return rewriter.notifyMatchFailure (binder.op , " strides bind failure" );
1170
- if (!strides.empty () && strides.size () != static_cast <size_t >(spatial))
1171
- return rewriter.notifyMatchFailure (
1172
- binder.op , " strides list size does not match the number of axes" );
1173
- if (binder.s64IntegerArrayAttr (dilations, " dilations" , {}))
1174
- return rewriter.notifyMatchFailure (binder.op ,
1175
- " dilations bind failure" );
1176
-
1177
- // set default padding
1178
- if (padding.empty ())
1179
- padding.resize (spatial, 0 );
1180
- if (strides.empty ())
1181
- strides.resize (spatial, 1 );
1182
- if (dilations.empty ())
1183
- dilations.resize (spatial, 1 );
1184
-
1185
- auto inputTensorType = cast<Torch::ValueTensorType>(operand.getType ());
1186
-
1187
- // Padding for the beginning and ending along each spatial axis, it can
1188
- // take any value greater than or equal to 0. The value represent the
1189
- // number of pixels added to the beginning and end part of the
1190
- // corresponding axis. pads format should be as follow [x1_begin,
1191
- // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added
1192
- // at the beginning of axis i and xi_end, the number of pixels added at
1193
- // the end of axis i.
1194
- if (autoPad != " NOTSET" && autoPad != " VALID" ) {
1195
- const bool isSameLower = autoPad == " SAME_LOWER" ;
1196
- ArrayRef<int64_t > inputShape = inputTensorType.getSizes ();
1197
- padding.resize_for_overwrite (2 * spatial);
1198
- for (unsigned dimIdx = 0 ; dimIdx < spatial; dimIdx++) {
1199
- const int64_t dilatedKernelSize =
1200
- dilations[dimIdx] * (kernel[dimIdx] - 1 ) + 1 ;
1201
- int64_t totalPad = ((inputShape[dimIdx + 2 ] + strides[dimIdx] - 1 ) /
1202
- strides[dimIdx] -
1203
- 1 ) *
1204
- strides[dimIdx] +
1205
- dilatedKernelSize - inputShape[dimIdx + 2 ];
1206
- totalPad = totalPad >= 0 ? totalPad : 0 ;
1207
- padding[dimIdx] =
1208
- isSameLower ? ((totalPad + 1 ) / 2 ) : (totalPad / 2 );
1209
- padding[spatial + dimIdx] = totalPad - padding[dimIdx];
1210
- }
1211
- }
1212
-
1213
- // If the padding is symmetric we can push the padding operation to the
1214
- // torch operator.
1215
- if (padding.size () == static_cast <size_t >(2 * spatial)) {
1216
- bool equal = true ;
1217
- for (int i = 0 ; i < spatial; ++i) {
1218
- equal = equal && (padding[i] == padding[i + spatial]);
1219
- }
1220
- if (equal)
1221
- padding.resize (spatial);
1222
- }
1223
-
1224
- // Torch pool operators require equal padding on each size of each
1225
- // dimension so we materialize the padding behavior explicitly and set
1226
- // the padding to 0.
1227
- if (padding.size () == static_cast <size_t >(2 * spatial)) {
1228
- auto operandTy = cast<Torch::ValueTensorType>(operand.getType ());
1229
- llvm::SmallVector<int64_t > shuffledPadding (spatial * 2 );
1230
- llvm::SmallVector<int64_t > paddedShape (operandTy.getSizes ());
1231
- for (int i = 0 ; i < spatial; ++i) {
1232
- paddedShape[i + 2 ] += padding[i] + padding[i + spatial];
1233
- shuffledPadding[2 * i] = padding[spatial - i - 1 ];
1234
- shuffledPadding[2 * i + 1 ] = padding[2 * spatial - i - 1 ];
1235
- }
1236
-
1237
- Value shuffledPaddingList =
1238
- createConstantIntList (binder, rewriter, shuffledPadding);
1239
- Value zero;
1240
- if (isa<FloatType>(resultTypeOut.getDtype ())) {
1241
- zero = rewriter.create <Torch::ConstantFloatOp>(
1242
- binder.getLoc (), rewriter.getType <Torch::FloatType>(),
1243
- rewriter.getF64FloatAttr (
1244
- std::numeric_limits<double >::lowest ()));
1245
- } else if (isa<IntegerType>(resultTypeOut.getDtype ())) {
1246
- zero = rewriter.create <Torch::ConstantIntOp>(
1247
- binder.getLoc (), rewriter.getI64IntegerAttr (
1248
- std::numeric_limits<int64_t >::lowest ()));
1249
- }
1250
-
1251
- auto paddedInputTy = rewriter.getType <Torch::ValueTensorType>(
1252
- paddedShape, operandTy.getDtype ());
1253
- operand = rewriter.create <Torch::AtenConstantPadNdOp>(
1254
- binder.getLoc (), paddedInputTy, operand, shuffledPaddingList,
1255
- zero);
1256
- padding.clear ();
1257
- padding.resize (spatial, 0 );
1258
- }
1158
+ " invalid pooling parameters" );
1259
1159
1260
1160
Value kernelSizeList = createConstantIntList (binder, rewriter, kernel);
1261
1161
Value paddingList = createConstantIntList (binder, rewriter, padding);
0 commit comments