@@ -1918,22 +1918,29 @@ func.func @torch.aten.diagonal$basic(%arg0: !torch.vtensor<[3,4,5,6], si32>) ->
1918
1918
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,5,6],f32> -> tensor<4x5x6xf32>
1919
1919
// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
1920
1920
// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_2]] : (tensor<2xi64>) -> tensor<2xi32>
1921
- // CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: 1, 1, 2>} : (tensor<2xi32>) -> tensor<1x1x2xi32>
1922
- // CHECK: %[[VAL_7:.*]] = tosa.tile %[[VAL_6]] {multiples = array<i64: 4, 5, 1>} : (tensor<1x1x2xi32>) -> tensor<4x5x2xi32>
1923
- // CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array<i64: 4, 5, 2, 1>} : (tensor<4x5x2xi32>) -> tensor<4x5x2x1xi32>
1924
- // CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]]], {{\[\[}}[1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]]], {{\[\[}}[2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]]], {{\[\[}}[3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32>
1925
- // CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32>
1926
- // CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_9]], %[[VAL_10]], %[[VAL_8]] {axis = 3 : i32} : (tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>) -> tensor<4x5x2x3xi32>
1927
- // CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 1, 120, 1>} : (tensor<4x5x6xf32>) -> tensor<1x120x1xf32>
1928
- // CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array<i64: 40, 3>} : (tensor<4x5x2x3xi32>) -> tensor<40x3xi32>
1929
- // CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[30, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
1930
- // CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_13]], %[[VAL_14]] {shift = 0 : i8} : (tensor<40x3xi32>, tensor<3xi32>) -> tensor<40x3xi32>
1931
- // CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<40x3xi32>) -> tensor<40x1xi32>
1932
- // CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array<i64: 1, 40>} : (tensor<40x1xi32>) -> tensor<1x40xi32>
1933
- // CHECK: %[[VAL_18:.*]] = tosa.gather %[[VAL_12]], %[[VAL_17]] : (tensor<1x120x1xf32>, tensor<1x40xi32>) -> tensor<1x40x1xf32>
1934
- // CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array<i64: 4, 5, 2>} : (tensor<1x40x1xf32>) -> tensor<4x5x2xf32>
1935
- // CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<4x5x2xf32> -> !torch.vtensor<[4,5,2],f32>
1936
- // CHECK: return %[[VAL_20]] : !torch.vtensor<[4,5,2],f32>
1921
+ // CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<120> : tensor<i32>}> : () -> tensor<i32>
1922
+ // CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<119> : tensor<i32>}> : () -> tensor<i32>
1923
+ // CHECK: %[[VAL_8:.*]] = tosa.greater %[[VAL_5]], %[[VAL_7]] : (tensor<2xi32>, tensor<i32>) -> tensor<2xi1>
1924
+ // CHECK: %[[VAL_9:.*]] = tosa.int_div %[[VAL_5]], %[[VAL_6]] : (tensor<2xi32>, tensor<i32>) -> tensor<2xi32>
1925
+ // CHECK: %[[VAL_10:.*]] = tosa.mul %[[VAL_9]], %[[VAL_6]] {shift = 0 : i8} : (tensor<2xi32>, tensor<i32>) -> tensor<2xi32>
1926
+ // CHECK: %[[VAL_11:.*]] = tosa.sub %[[VAL_5]], %[[VAL_10]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
1927
+ // CHECK: %[[VAL_12:.*]] = tosa.select %[[VAL_8]], %[[VAL_11]], %[[VAL_5]] : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
1928
+ // CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array<i64: 1, 1, 2>} : (tensor<2xi32>) -> tensor<1x1x2xi32>
1929
+ // CHECK: %[[VAL_14:.*]] = tosa.tile %[[VAL_13]] {multiples = array<i64: 4, 5, 1>} : (tensor<1x1x2xi32>) -> tensor<4x5x2xi32>
1930
+ // CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array<i64: 4, 5, 2, 1>} : (tensor<4x5x2xi32>) -> tensor<4x5x2x1xi32>
1931
+ // CHECK: %[[VAL_16:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]]], {{\[\[}}[1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]]], {{\[\[}}[2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]]], {{\[\[}}[3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32>
1932
+ // CHECK: %[[VAL_17:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32>
1933
+ // CHECK: %[[VAL_18:.*]] = tosa.concat %[[VAL_16]], %[[VAL_17]], %[[VAL_15]] {axis = 3 : i32} : (tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>) -> tensor<4x5x2x3xi32>
1934
+ // CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 1, 120, 1>} : (tensor<4x5x6xf32>) -> tensor<1x120x1xf32>
1935
+ // CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array<i64: 40, 3>} : (tensor<4x5x2x3xi32>) -> tensor<40x3xi32>
1936
+ // CHECK: %[[VAL_21:.*]] = "tosa.const"() <{value = dense<[30, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
1937
+ // CHECK: %[[VAL_22:.*]] = tosa.mul %[[VAL_20]], %[[VAL_21]] {shift = 0 : i8} : (tensor<40x3xi32>, tensor<3xi32>) -> tensor<40x3xi32>
1938
+ // CHECK: %[[VAL_23:.*]] = tosa.reduce_sum %[[VAL_22]] {axis = 1 : i32} : (tensor<40x3xi32>) -> tensor<40x1xi32>
1939
+ // CHECK: %[[VAL_24:.*]] = tosa.reshape %[[VAL_23]] {new_shape = array<i64: 1, 40>} : (tensor<40x1xi32>) -> tensor<1x40xi32>
1940
+ // CHECK: %[[VAL_25:.*]] = tosa.gather %[[VAL_19]], %[[VAL_24]] : (tensor<1x120x1xf32>, tensor<1x40xi32>) -> tensor<1x40x1xf32>
1941
+ // CHECK: %[[VAL_26:.*]] = tosa.reshape %[[VAL_25]] {new_shape = array<i64: 4, 5, 2>} : (tensor<1x40x1xf32>) -> tensor<4x5x2xf32>
1942
+ // CHECK: %[[VAL_27:.*]] = torch_c.from_builtin_tensor %[[VAL_26]] : tensor<4x5x2xf32> -> !torch.vtensor<[4,5,2],f32>
1943
+ // CHECK: return %[[VAL_27]] : !torch.vtensor<[4,5,2],f32>
1937
1944
// CHECK: }
1938
1945
func.func @torch.aten.index_select (%arg0: !torch.vtensor <[4 ,5 ,6 ],f32 >, %arg1: !torch.vtensor <[2 ],si64 >) -> !torch.vtensor <[4 ,5 ,2 ],f32 > {
1939
1946
%int2 = torch.constant.int 2
@@ -2331,6 +2338,42 @@ func.func @torch.aten.as_strided$basic(%arg0: !torch.vtensor<[5,5],f32>) -> !tor
2331
2338
return %2 : !torch.vtensor <[3 ,3 ],f32 >
2332
2339
}
2333
2340
2341
+ // -----
2342
+ // CHECK-LABEL: func.func @torch.aten.as_strided$offset(
2343
+ // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[3,3],f32> {
2344
+ // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32>
2345
+ // CHECK: %[[VAL_2:.*]] = torch.constant.int 30
2346
+ // CHECK: %[[VAL_3:.*]] = torch.constant.int 1
2347
+ // CHECK: %[[VAL_4:.*]] = torch.constant.int 2
2348
+ // CHECK: %[[VAL_5:.*]] = torch.constant.int 3
2349
+ // CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_5]], %[[VAL_5]] : (!torch.int, !torch.int) -> !torch.list<int>
2350
+ // CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
2351
+ // CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 25>} : (tensor<5x5xf32>) -> tensor<25xf32>
2352
+ // CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<[5, 6, 7, 7, 8, 9, 9, 10, 11]> : tensor<9xi32>}> : () -> tensor<9xi32>
2353
+ // CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array<i64: 9, 1>} : (tensor<9xi32>) -> tensor<9x1xi32>
2354
+ // CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_10]] {axis = 1 : i32} : (tensor<9x1xi32>) -> tensor<9x1xi32>
2355
+ // CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array<i64: 1, 25, 1>} : (tensor<25xf32>) -> tensor<1x25x1xf32>
2356
+ // CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array<i64: 9, 1>} : (tensor<9x1xi32>) -> tensor<9x1xi32>
2357
+ // CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
2358
+ // CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_13]], %[[VAL_14]] {shift = 0 : i8} : (tensor<9x1xi32>, tensor<1xi32>) -> tensor<9x1xi32>
2359
+ // CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<9x1xi32>) -> tensor<9x1xi32>
2360
+ // CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array<i64: 1, 9>} : (tensor<9x1xi32>) -> tensor<1x9xi32>
2361
+ // CHECK: %[[VAL_18:.*]] = tosa.gather %[[VAL_12]], %[[VAL_17]] : (tensor<1x25x1xf32>, tensor<1x9xi32>) -> tensor<1x9x1xf32>
2362
+ // CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array<i64: 9>} : (tensor<1x9x1xf32>) -> tensor<9xf32>
2363
+ // CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array<i64: 3, 3>} : (tensor<9xf32>) -> tensor<3x3xf32>
2364
+ // CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<3x3xf32> -> !torch.vtensor<[3,3],f32>
2365
+ // CHECK: return %[[VAL_21]] : !torch.vtensor<[3,3],f32>
2366
+ func.func @torch.aten.as_strided$offset (%arg0: !torch.vtensor <[5 ,5 ],f32 >) -> !torch.vtensor <[3 ,3 ],f32 > {
2367
+ %int30 = torch.constant.int 30
2368
+ %int1 = torch.constant.int 1
2369
+ %int2 = torch.constant.int 2
2370
+ %int3 = torch.constant.int 3
2371
+ %0 = torch.prim.ListConstruct %int3 , %int3 : (!torch.int , !torch.int ) -> !torch.list <int >
2372
+ %1 = torch.prim.ListConstruct %int2 , %int1 : (!torch.int , !torch.int ) -> !torch.list <int >
2373
+ %2 = torch.aten.as_strided %arg0 , %0 , %1 , %int30 : !torch.vtensor <[5 ,5 ],f32 >, !torch.list <int >, !torch.list <int >, !torch.int -> !torch.vtensor <[3 ,3 ],f32 >
2374
+ return %2 : !torch.vtensor <[3 ,3 ],f32 >
2375
+ }
2376
+
2334
2377
// -----
2335
2378
2336
2379
// CHECK-LABEL: func.func @torch.aten.max_pool1d$basic(
0 commit comments