Commit 8bf5c66 (1 parent: 6a85b98)

submit the llvm#3902 to local repo (#5)

* Decompose lstm and gru.
* Add tests and update xfail_sets.py
* Rebase main
* Fix casting for arith.cmpi operands to be of same type.

5 files changed: +151 −19

lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp (+19 −3)

@@ -417,6 +417,21 @@ class ConvertAtenEmbeddingBagPaddingIdxOp
 };
 } // namespace
 
+static Value wrapIndicesAroundMax(OpBuilder &b, Location loc, Value index,
+                                  Value input, int64_t dim) {
+  // performs the operation : index = index % maxIndex to wrap index around
+  // maxIndex
+  Value maxIndexValue = getDimOp(b, loc, input, dim);
+  maxIndexValue =
+      b.createOrFold<arith::IndexCastOp>(loc, index.getType(), maxIndexValue);
+  Value isBeyondMaxIndices = b.createOrFold<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::sge, index, maxIndexValue);
+  Value wrappedIndices =
+      b.createOrFold<arith::RemSIOp>(loc, index, maxIndexValue);
+  return b.createOrFold<arith::SelectOp>(loc, isBeyondMaxIndices,
+                                         wrappedIndices, index);
+}
+
 namespace {
 // Let's say we have an input tensor: initialized with some random values of
 // size [4, 5, 6]. An index tensor (always 1-d): [0, 2] of size [2], and an
@@ -478,16 +493,17 @@ class ConvertAtenIndexSelectOp : public OpConversionPattern<AtenIndexSelectOp> {
 
     auto indexingMaps = AffineMap::inferFromExprList({indicesExpr, resultExpr},
                                                      rewriter.getContext());
-
     Value finalRes =
         rewriter
             .create<linalg::GenericOp>(
                 loc, initTensor.getType(), ValueRange{indices}, initTensor,
                 /*indexingMaps=*/indexingMaps,
                 /*iteratorTypes=*/iteratorTypes,
                 [&](OpBuilder &b, Location loc, ValueRange args) {
-                  Value index = rewriter.create<arith::IndexCastOp>(
-                      loc, rewriter.getIndexType(), args[0]);
+                  Value index =
+                      wrapIndicesAroundMax(b, loc, args[0], input, dimInt);
+                  index = rewriter.create<arith::IndexCastOp>(
+                      loc, rewriter.getIndexType(), index);
                   SmallVector<Value> indexTarget;
                   for (unsigned i = 0; i < inputRank; i++)
                     indexTarget.push_back(b.create<linalg::IndexOp>(loc, i));
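Note for reviewers: the helper only guards the upper bound. An index greater than or equal to the dimension size is reduced modulo that size (arith.remsi), while in-range indices pass through the arith.select unchanged; negative indices are not touched. A minimal Python sketch of the same select logic (the function name is illustrative, not part of the patch):

def wrap_indices_around_max(index: int, max_index: int) -> int:
    # Mirrors the lowering: arith.cmpi sge -> arith.remsi -> arith.select.
    is_beyond_max = index >= max_index          # arith.cmpi sge
    wrapped = index % max_index                 # arith.remsi
    return wrapped if is_beyond_max else index  # arith.select

assert wrap_indices_around_max(2, 6) == 2  # in range: passes through
assert wrap_indices_around_max(7, 6) == 1  # out of range: wraps to 7 % 6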

lib/Conversion/TorchToTosa/TorchToTosa.cpp (+41 −0)

@@ -4023,6 +4023,41 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
   return success();
 }
 
+Value wrapIndicesAroundMax(Value index, int maxIndex, Operation *op,
+                           ConversionPatternRewriter &rewriter) {
+  // performs the operation : index = index % maxIndex to wrap index around
+  // maxIndex
+
+  auto maxIndexValue =
+      tosa::getConstTensor<int32_t>(rewriter, op, maxIndex, {}).value();
+  auto maxIndexValueMinusOne =
+      tosa::getConstTensor<int32_t>(rewriter, op, maxIndex - 1, {}).value();
+
+  auto indexType = dyn_cast<RankedTensorType>(index.getType());
+  auto boolType = indexType.clone(rewriter.getIntegerType(1));
+
+  auto isBeyondMaxIndices = tosa::CreateOpAndInfer<tosa::GreaterOp>(
+      rewriter, op->getLoc(), boolType, index, maxIndexValueMinusOne);
+  auto wrappedBeyondMaxIndicesQuotient =
+      tosa::CreateOpAndInfer<tosa::IntDivOp>(rewriter, op->getLoc(), indexType,
+                                             index, maxIndexValue)
+          .getResult();
+  auto wrappedBeyondMaxIndicesQuotientTimesIndices =
+      tosa::CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(), indexType,
+                                          wrappedBeyondMaxIndicesQuotient,
+                                          maxIndexValue, /*shift=*/0)
+          .getResult();
+  auto wrappedBeyondMaxIndices =
+      tosa::CreateOpAndInfer<tosa::SubOp>(
+          rewriter, op->getLoc(), indexType, index,
+          wrappedBeyondMaxIndicesQuotientTimesIndices)
+          .getResult();
+
+  return tosa::CreateOpAndInfer<tosa::SelectOp>(rewriter, op->getLoc(),
+                                                indexType, isBeyondMaxIndices,
+                                                wrappedBeyondMaxIndices, index);
+}
+
 template <>
 LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
     AtenIndexSelectOp op, OpAdaptor adaptor,
@@ -4066,6 +4101,10 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
         RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index);
   }
 
+  int64_t selfNumElems = std::accumulate(inputShape.begin(), inputShape.end(),
+                                         1, std::multiplies<int64_t>());
+  index = wrapIndicesAroundMax(index, selfNumElems, op, rewriter);
+
   // Get positive dim
   int64_t dim;
   if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
@@ -7266,10 +7305,12 @@ LogicalResult ConvertAtenOp<AtenAsStridedOp>::matchAndRewrite(
     // coord_i_n * stride[n]
     int32_t index = offset;
     int64_t coordFinder = i;
+
     for (int64_t dim = 0; dim < outputRank; dim++) {
      int64_t indexCoord = coordFinder % outputSize[outputRank - dim - 1];
      index += indexCoord * stride[outputRank - dim - 1];
      coordFinder /= outputSize[outputRank - dim - 1];
+      index = (index % selfNumElems);
    }
    targetIndicesVec.push_back(index);
  }
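The TOSA variant builds the remainder out of tosa.int_div, tosa.mul, and tosa.sub (rem(a, b) = a - (a / b) * b), presumably because no dedicated remainder op is available at this level, and then selects the wrapped value only where index > maxIndex - 1. For AtenIndexSelectOp the wrap bound is the flattened element count (selfNumElems), which matches the dense<120> constant in the updated basic.mlir test below. A small sketch of the arithmetic (illustrative names; for the non-negative indices used here Python's // agrees with the truncating int_div):

def tosa_style_wrap(index: int, max_index: int) -> int:
    quotient = index // max_index           # tosa.int_div
    wrapped = index - quotient * max_index  # tosa.mul + tosa.sub
    is_beyond = index > max_index - 1       # tosa.greater
    return wrapped if is_beyond else index  # tosa.select

assert tosa_style_wrap(119, 120) == 119  # largest in-range index
assert tosa_style_wrap(125, 120) == 5    # 125 - (125 // 120) * 120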

projects/pt1/e2e_testing/xfail_sets.py (+3 −0)

@@ -484,6 +484,7 @@
     "SplitTensorNegativeDimModule_basic",
     "SplitWithSizesListUnpackModule_basic",
     "SplitWithSizes_Module_basic",
+    "AsStridedWithOffsetModule_basic",
     "AdaptiveAvgPool1dGeneralDynamic_basic",
     "AdaptiveAvgPool1dStaticEvenMultiple_basic",
     "AdaptiveAvgPool1dStaticLargerOutput_basic",
@@ -895,6 +896,7 @@
     "SplitTensorNegativeDimModule_basic",
     "SplitWithSizesListUnpackModule_basic",
     "SplitWithSizes_Module_basic",
+    "AsStridedWithOffsetModule_basic",
     "Unfold_Module_basic",
     "Unfold_Module_Rank_4",
     "Unfold_Module_Rank_Zero_basic",
@@ -1784,6 +1786,7 @@
     "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
     "AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
     "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
+    "AsStridedWithOffsetModule_basic",
     "ElementwiseAtenLogicalNotOpPromoteModule_basic",
     "ElementwiseCosIntModule_basic",
     "ElementwiseReciprocalIntModule_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py (+29 −0)

@@ -1144,3 +1144,32 @@ def forward(self, x):
 @register_test_case(module_factory=lambda: TensorSplitSections_ListUnpackModule())
 def TensorSplitSections_ListUnpackModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 5))
+
+
+# ==============================================================================
+
+
+class AsStridedWithOffsetModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 6, 60], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        output_size = [6, 20]
+        stride = [60, 1]
+        slice = torch.ops.aten.slice.Tensor(x, 0, 1, 2)
+        squeeze = torch.ops.aten.squeeze.dim(slice, 0)
+        return torch.ops.aten.as_strided(
+            squeeze, size=output_size, stride=stride, storage_offset=360
+        )
+
+
+@register_test_case(module_factory=lambda: AsStridedWithOffsetModule())
+def AsStridedWithOffsetModule_basic(module, tu: TestUtils):
+    module.forward(torch.rand(2, 6, 60))
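For context on what this test exercises (a sketch, not part of the commit): aten.as_strided takes storage_offset as an absolute offset into the underlying storage, and the sliced-and-squeezed view of row 1 already starts at element 360 of the [2, 6, 60] input, so element (i, j) of the result reads storage slot 360 + 60*i + j, i.e. the first 20 columns of x[1]:

import torch

x = torch.rand(2, 6, 60)                          # 720 elements of storage
sliced = torch.ops.aten.slice.Tensor(x, 0, 1, 2)  # view of shape [1, 6, 60]
squeezed = torch.ops.aten.squeeze.dim(sliced, 0)  # view of shape [6, 60]
out = torch.ops.aten.as_strided(squeezed, [6, 20], [60, 1], 360)
# (i, j) -> storage[360 + 60 * i + j], i.e. x[1, i, j] for j < 20.
assert torch.equal(out, x[1, :, :20])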

test/Conversion/TorchToTosa/basic.mlir (+59 −16)

@@ -1918,22 +1918,29 @@ func.func @torch.aten.diagonal$basic(%arg0: !torch.vtensor<[3,4,5,6], si32>) ->
 // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,5,6],f32> -> tensor<4x5x6xf32>
 // CHECK: %[[VAL_4:.*]] = torch.constant.int 2
 // CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_2]] : (tensor<2xi64>) -> tensor<2xi32>
-// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: 1, 1, 2>} : (tensor<2xi32>) -> tensor<1x1x2xi32>
-// CHECK: %[[VAL_7:.*]] = tosa.tile %[[VAL_6]] {multiples = array<i64: 4, 5, 1>} : (tensor<1x1x2xi32>) -> tensor<4x5x2xi32>
-// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 4, 5, 2, 1>} : (tensor<4x5x2xi32>) -> tensor<4x5x2x1xi32>
-// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]]], {{\[\[}}[1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]]], {{\[\[}}[2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]]], {{\[\[}}[3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32>
-// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32>
-// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_9]], %[[VAL_10]], %[[VAL_8]] {axis = 3 : i32} : (tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>) -> tensor<4x5x2x3xi32>
-// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 1, 120, 1>} : (tensor<4x5x6xf32>) -> tensor<1x120x1xf32>
-// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array<i64: 40, 3>} : (tensor<4x5x2x3xi32>) -> tensor<40x3xi32>
-// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[30, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
-// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_13]], %[[VAL_14]] {shift = 0 : i8} : (tensor<40x3xi32>, tensor<3xi32>) -> tensor<40x3xi32>
-// CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<40x3xi32>) -> tensor<40x1xi32>
-// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array<i64: 1, 40>} : (tensor<40x1xi32>) -> tensor<1x40xi32>
-// CHECK: %[[VAL_18:.*]] = tosa.gather %[[VAL_12]], %[[VAL_17]] : (tensor<1x120x1xf32>, tensor<1x40xi32>) -> tensor<1x40x1xf32>
-// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array<i64: 4, 5, 2>} : (tensor<1x40x1xf32>) -> tensor<4x5x2xf32>
-// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<4x5x2xf32> -> !torch.vtensor<[4,5,2],f32>
-// CHECK: return %[[VAL_20]] : !torch.vtensor<[4,5,2],f32>
+// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<120> : tensor<i32>}> : () -> tensor<i32>
+// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<119> : tensor<i32>}> : () -> tensor<i32>
+// CHECK: %[[VAL_8:.*]] = tosa.greater %[[VAL_5]], %[[VAL_7]] : (tensor<2xi32>, tensor<i32>) -> tensor<2xi1>
+// CHECK: %[[VAL_9:.*]] = tosa.int_div %[[VAL_5]], %[[VAL_6]] : (tensor<2xi32>, tensor<i32>) -> tensor<2xi32>
+// CHECK: %[[VAL_10:.*]] = tosa.mul %[[VAL_9]], %[[VAL_6]] {shift = 0 : i8} : (tensor<2xi32>, tensor<i32>) -> tensor<2xi32>
+// CHECK: %[[VAL_11:.*]] = tosa.sub %[[VAL_5]], %[[VAL_10]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+// CHECK: %[[VAL_12:.*]] = tosa.select %[[VAL_8]], %[[VAL_11]], %[[VAL_5]] : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array<i64: 1, 1, 2>} : (tensor<2xi32>) -> tensor<1x1x2xi32>
+// CHECK: %[[VAL_14:.*]] = tosa.tile %[[VAL_13]] {multiples = array<i64: 4, 5, 1>} : (tensor<1x1x2xi32>) -> tensor<4x5x2xi32>
+// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array<i64: 4, 5, 2, 1>} : (tensor<4x5x2xi32>) -> tensor<4x5x2x1xi32>
+// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]]], {{\[\[}}[1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]]], {{\[\[}}[2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]]], {{\[\[}}[3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32>
+// CHECK: %[[VAL_17:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32>
+// CHECK: %[[VAL_18:.*]] = tosa.concat %[[VAL_16]], %[[VAL_17]], %[[VAL_15]] {axis = 3 : i32} : (tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>) -> tensor<4x5x2x3xi32>
+// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 1, 120, 1>} : (tensor<4x5x6xf32>) -> tensor<1x120x1xf32>
+// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array<i64: 40, 3>} : (tensor<4x5x2x3xi32>) -> tensor<40x3xi32>
+// CHECK: %[[VAL_21:.*]] = "tosa.const"() <{value = dense<[30, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
+// CHECK: %[[VAL_22:.*]] = tosa.mul %[[VAL_20]], %[[VAL_21]] {shift = 0 : i8} : (tensor<40x3xi32>, tensor<3xi32>) -> tensor<40x3xi32>
+// CHECK: %[[VAL_23:.*]] = tosa.reduce_sum %[[VAL_22]] {axis = 1 : i32} : (tensor<40x3xi32>) -> tensor<40x1xi32>
+// CHECK: %[[VAL_24:.*]] = tosa.reshape %[[VAL_23]] {new_shape = array<i64: 1, 40>} : (tensor<40x1xi32>) -> tensor<1x40xi32>
+// CHECK: %[[VAL_25:.*]] = tosa.gather %[[VAL_19]], %[[VAL_24]] : (tensor<1x120x1xf32>, tensor<1x40xi32>) -> tensor<1x40x1xf32>
+// CHECK: %[[VAL_26:.*]] = tosa.reshape %[[VAL_25]] {new_shape = array<i64: 4, 5, 2>} : (tensor<1x40x1xf32>) -> tensor<4x5x2xf32>
+// CHECK: %[[VAL_27:.*]] = torch_c.from_builtin_tensor %[[VAL_26]] : tensor<4x5x2xf32> -> !torch.vtensor<[4,5,2],f32>
+// CHECK: return %[[VAL_27]] : !torch.vtensor<[4,5,2],f32>
 // CHECK: }
 func.func @torch.aten.index_select(%arg0: !torch.vtensor<[4,5,6],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,5,2],f32> {
   %int2 = torch.constant.int 2
@@ -2331,6 +2338,42 @@ func.func @torch.aten.as_strided$basic(%arg0: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[3,3],f32> {
   return %2 : !torch.vtensor<[3,3],f32>
 }
 
+// -----
+// CHECK-LABEL: func.func @torch.aten.as_strided$offset(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[3,3],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 30
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_5:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_5]], %[[VAL_5]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 25>} : (tensor<5x5xf32>) -> tensor<25xf32>
+// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<[5, 6, 7, 7, 8, 9, 9, 10, 11]> : tensor<9xi32>}> : () -> tensor<9xi32>
+// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array<i64: 9, 1>} : (tensor<9xi32>) -> tensor<9x1xi32>
+// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_10]] {axis = 1 : i32} : (tensor<9x1xi32>) -> tensor<9x1xi32>
+// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array<i64: 1, 25, 1>} : (tensor<25xf32>) -> tensor<1x25x1xf32>
+// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array<i64: 9, 1>} : (tensor<9x1xi32>) -> tensor<9x1xi32>
+// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
+// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_13]], %[[VAL_14]] {shift = 0 : i8} : (tensor<9x1xi32>, tensor<1xi32>) -> tensor<9x1xi32>
+// CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<9x1xi32>) -> tensor<9x1xi32>
+// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array<i64: 1, 9>} : (tensor<9x1xi32>) -> tensor<1x9xi32>
+// CHECK: %[[VAL_18:.*]] = tosa.gather %[[VAL_12]], %[[VAL_17]] : (tensor<1x25x1xf32>, tensor<1x9xi32>) -> tensor<1x9x1xf32>
+// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array<i64: 9>} : (tensor<1x9x1xf32>) -> tensor<9xf32>
+// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array<i64: 3, 3>} : (tensor<9xf32>) -> tensor<3x3xf32>
+// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<3x3xf32> -> !torch.vtensor<[3,3],f32>
+// CHECK: return %[[VAL_21]] : !torch.vtensor<[3,3],f32>
+func.func @torch.aten.as_strided$offset(%arg0: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[3,3],f32> {
+  %int30 = torch.constant.int 30
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %int3 = torch.constant.int 3
+  %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.aten.as_strided %arg0, %0, %1, %int30 : !torch.vtensor<[5,5],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[3,3],f32>
+  return %2 : !torch.vtensor<[3,3],f32>
+}
+
 // -----
 
 // CHECK-LABEL: func.func @torch.aten.max_pool1d$basic(
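The dense<[5, 6, 7, 7, 8, 9, 9, 10, 11]> constant in %[[VAL_9]] of the new test can be reproduced by hand from the AtenAsStridedOp lowering loop above. A short sketch with the test's parameters (offset 30, size [3, 3], stride [2, 1], 25 input elements), applying the wrap added by this commit inside the loop:

offset, size, stride, numel = 30, [3, 3], [2, 1], 25
indices = []
for i in range(size[0] * size[1]):
    index, coord = offset, i
    # Walk dims innermost-first, as the lowering does.
    for dim in reversed(range(len(size))):
        index += (coord % size[dim]) * stride[dim]
        coord //= size[dim]
        index %= numel  # wrap around the element count
    indices.append(index)

assert indices == [5, 6, 7, 7, 8, 9, 9, 10, 11]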
