Skip to content

Commit 483fd92

Browse files
authored
Improve simplification of reshape/rearrange combinations (#2255)
This can be done with a fairly straightforward composition algebra. Currently restricted to rearranges that are (mapped) transpositions, but can be extended easily enough I think.
1 parent 35a42dd commit 483fd92

File tree

3 files changed

+107
-3
lines changed

3 files changed

+107
-3
lines changed

src/Futhark/IR/Prop/Reshape.hs

+54
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ module Futhark.IR.Prop.Reshape
99
reshapeInner,
1010

1111
-- * Simplification
12+
flipReshapeRearrange,
1213

1314
-- * Shape calculations
1415
reshapeIndex,
@@ -18,8 +19,11 @@ module Futhark.IR.Prop.Reshape
1819
)
1920
where
2021

22+
import Control.Monad (guard, mplus)
2123
import Data.Foldable
24+
import Futhark.IR.Prop.Rearrange (isMapTranspose)
2225
import Futhark.IR.Syntax
26+
import Futhark.Util (takeLast)
2327
import Futhark.Util.IntegralExp
2428
import Prelude hiding (product, quot, sum)
2529

@@ -101,3 +105,53 @@ sliceSizes (n : ns) =
101105
product (n : ns) : sliceSizes ns
102106

103107
{- HLINT ignore sliceSizes -}
108+
109+
-- | Interchange a reshape and rearrange. Essentially, rewrite composition
110+
--
111+
-- @
112+
-- let v1 = reshape(v1_shape, v0)
113+
-- let v2 = rearrange(perm, v1)
114+
-- @
115+
--
116+
-- into
117+
--
118+
-- @
119+
-- let v1' = rearrange(perm', v0)
120+
-- let v2' = reshape(v1_shape', v1')
121+
--
122+
-- The function is given the shape of @v0@, @v1@, and the @perm@, and returns
123+
-- @perm'@. This is a meaningful operation when @v2@ is itself reshaped, as the
124+
-- reshape-reshape can be fused. This can significantly simplify long chains of
125+
-- reshapes and rearranges.
126+
flipReshapeRearrange ::
127+
(Eq d) =>
128+
[d] ->
129+
[d] ->
130+
[Int] ->
131+
Maybe [Int]
132+
flipReshapeRearrange v0_shape v1_shape perm = do
133+
(num_map_dims, num_a_dims, num_b_dims) <- isMapTranspose perm
134+
guard $ num_a_dims == 1
135+
guard $ num_b_dims == 1
136+
let map_dims = take num_map_dims v0_shape
137+
num_b_dims_expanded = length v0_shape - num_map_dims - num_a_dims
138+
num_a_dims_expanded = length v0_shape - num_map_dims - num_b_dims
139+
caseA = do
140+
guard $ take num_a_dims v0_shape == take num_b_dims v1_shape
141+
let perm' =
142+
[0 .. num_map_dims - 1]
143+
++ map (+ num_map_dims) ([1 .. num_b_dims_expanded] ++ [0])
144+
Just perm'
145+
caseB = do
146+
guard $ takeLast num_b_dims v0_shape == takeLast num_b_dims v1_shape
147+
let perm' =
148+
[0 .. num_map_dims - 1]
149+
++ map
150+
(+ num_map_dims)
151+
(num_a_dims_expanded : [0 .. num_a_dims_expanded - 1])
152+
Just perm'
153+
154+
guard $ map_dims == take num_map_dims v1_shape
155+
156+
caseA `mplus` caseB
157+
{-# NOINLINE flipReshapeRearrange #-}

src/Futhark/Optimise/Simplify/Rules/BasicOp.hs

+10
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,16 @@ ruleBasicOp vtable pat aux (Manifest perm v1)
366366
ST.available v2 vtable =
367367
Simplify . auxing aux . certifying cs . letBind pat . BasicOp $
368368
Manifest perm v2
369+
ruleBasicOp vtable pat aux (Reshape ReshapeArbitrary v3_shape v2)
370+
| Just (Rearrange perm v1, v2_cs) <- ST.lookupBasicOp v2 vtable,
371+
Just (Reshape ReshapeArbitrary v1_shape v0, v1_cs) <- ST.lookupBasicOp v1 vtable,
372+
Just v0_shape <- arrayShape <$> ST.lookupType v0 vtable,
373+
Just perm' <-
374+
flipReshapeRearrange (shapeDims v0_shape) (shapeDims v1_shape) perm =
375+
Simplify $ do
376+
v1' <- letExp (baseString v1) $ BasicOp $ Rearrange perm' v0
377+
auxing aux . certifying (v1_cs <> v2_cs) . letBind pat $
378+
BasicOp (Reshape ReshapeArbitrary v3_shape v1')
369379
ruleBasicOp _ _ _ _ =
370380
Skip
371381

unittests/Futhark/IR/Prop/ReshapeTests.hs

+43-3
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ import Futhark.IR.Syntax
1111
import Test.Tasty
1212
import Test.Tasty.HUnit
1313

14+
intShape :: [Int] -> Shape
15+
intShape = Shape . map (intConst Int32 . toInteger)
16+
1417
reshapeOuterTests :: [TestTree]
1518
reshapeOuterTests =
1619
[ testCase (unwords ["reshapeOuter", show sc, show n, show shape, "==", show sc_res]) $
@@ -35,10 +38,47 @@ reshapeInnerTests =
3538
]
3639
]
3740

38-
intShape :: [Int] -> Shape
39-
intShape = Shape . map (intConst Int32 . toInteger)
41+
flipTests :: [TestTree]
42+
flipTests =
43+
[ testCase
44+
( unwords
45+
[ "flipReshapeRearrange",
46+
show v0_shape,
47+
show v1_shape,
48+
show perm
49+
]
50+
)
51+
$ flipReshapeRearrange v0_shape v1_shape perm @?= res
52+
| (v0_shape :: [String], v1_shape, perm, res) <-
53+
[ ( ["A", "B", "C"],
54+
["A", "BC"],
55+
[1, 0],
56+
Just [1, 2, 0]
57+
),
58+
( ["A", "B", "C", "D"],
59+
["A", "BCD"],
60+
[1, 0],
61+
Just [1, 2, 3, 0]
62+
),
63+
( ["A"],
64+
["B", "C"],
65+
[1, 0],
66+
Nothing
67+
),
68+
( ["A", "B", "C"],
69+
["AB", "C"],
70+
[1, 0],
71+
Just [2, 0, 1]
72+
),
73+
( ["A", "B", "C", "D"],
74+
["ABC", "D"],
75+
[1, 0],
76+
Just [3, 0, 1, 2]
77+
)
78+
]
79+
]
4080

4181
tests :: TestTree
4282
tests =
4383
testGroup "ReshapeTests" $
44-
reshapeOuterTests ++ reshapeInnerTests
84+
reshapeOuterTests ++ reshapeInnerTests ++ flipTests

0 commit comments

Comments
 (0)