@@ -65,6 +65,7 @@ import Booster.Pattern.Match (
65
65
MatchResult (MatchFailed , MatchIndeterminate , MatchSuccess ),
66
66
MatchType (Rewrite ),
67
67
SortError ,
68
+ Substitution ,
68
69
matchTerms ,
69
70
)
70
71
import Booster.Pattern.Pretty
@@ -153,7 +154,7 @@ data RewriteStepResult a = OnlyTrivial | AppliedRules a deriving (Eq, Show, Func
153
154
rewriteStep ::
154
155
LoggerMIO io =>
155
156
Pattern ->
156
- RewriteT io (RewriteStepResult [(RewriteRule " Rewrite" , Pattern )])
157
+ RewriteT io (RewriteStepResult [(RewriteRule " Rewrite" , Pattern , Substitution )])
157
158
rewriteStep pat = do
158
159
def <- getDefinition
159
160
let getIndex =
@@ -175,18 +176,18 @@ rewriteStep pat = do
175
176
-- return `OnlyTrivial` if all elements of a list are `(r, Nothing)`. If the list is empty or contains at least one `(r, Just p)`,
176
177
-- return an `AppliedRules` list of `(r, p)` pairs.
177
178
filterOutTrivial ::
178
- [(RewriteRule " Rewrite" , Maybe Pattern )] ->
179
- RewriteStepResult [(RewriteRule " Rewrite" , Pattern )]
179
+ [(RewriteRule " Rewrite" , Maybe ( Pattern , Substitution ) )] ->
180
+ RewriteStepResult [(RewriteRule " Rewrite" , Pattern , Substitution )]
180
181
filterOutTrivial = \ case
181
182
[] -> AppliedRules []
182
183
[(_, Nothing )] -> OnlyTrivial
183
184
(_, Nothing ) : xs -> filterOutTrivial xs
184
- (rule, Just p) : xs -> AppliedRules $ (rule, p) : mapMaybe (\ (r, mp) -> (r, ) <$> mp) xs
185
+ (rule, Just (p, subst)) : xs -> AppliedRules $ (rule, p, subst ) : mapMaybe (\ (r, mp) -> (\ (x, y) -> (r, x, y) ) <$> mp) xs
185
186
186
187
processGroups ::
187
188
LoggerMIO io =>
188
189
[[RewriteRule " Rewrite" ]] ->
189
- RewriteT io [(RewriteRule " Rewrite" , Maybe Pattern )]
190
+ RewriteT io [(RewriteRule " Rewrite" , Maybe ( Pattern , Substitution ) )]
190
191
processGroups [] = pure []
191
192
processGroups (rules : lowerPriorityRules) = do
192
193
-- try all rules of the priority group. This will immediately
@@ -209,8 +210,10 @@ rewriteStep pat = do
209
210
results
210
211
-- compute remainder condition here from @nonTrivialResults@ and the remainder up to now.
211
212
-- If the new remainder is bottom, then no lower priority rules apply
212
- newRemainder = currentRemainder <> Set. fromList (mapMaybe (snd . snd ) nonTrivialResultsWithPartialRemainders)
213
- resultsWithoutRemainders = map (fmap (fmap fst )) results
213
+ newRemainder =
214
+ currentRemainder
215
+ <> Set. fromList (mapMaybe ((\ (_, r, _) -> r) . snd ) nonTrivialResultsWithPartialRemainders)
216
+ resultsWithoutRemainders = map (fmap (fmap (\ (p, _, s) -> (p, s)))) results
214
217
setRemainder newRemainder
215
218
ModifiersRep (_ :: FromModifiersT mods => Proxy mods ) <- getPrettyModifiers
216
219
withContext CtxRemainder $ logPretty' @ mods (collapseAndBools . Set. toList $ newRemainder)
@@ -272,7 +275,7 @@ applyRule ::
272
275
LoggerMIO io =>
273
276
Pattern ->
274
277
RewriteRule " Rewrite" ->
275
- RewriteT io (Maybe (Maybe (Pattern , Maybe Predicate )))
278
+ RewriteT io (Maybe (Maybe (Pattern , Maybe Predicate , Substitution )))
276
279
applyRule pat@ Pattern {ceilConditions} rule =
277
280
withRuleContext rule $
278
281
runRewriteRuleAppT $
@@ -436,11 +439,12 @@ applyRule pat@Pattern{ceilConditions} rule =
436
439
ceilConditions
437
440
withContext CtxSuccess $ do
438
441
case unclearRequiresAfterSmt of
439
- [] -> withPatternContext rewritten $ pure (rewritten, Nothing )
442
+ [] -> withPatternContext rewritten $ pure (rewritten, Nothing , subst )
440
443
_ ->
441
444
let rewritten' = rewritten{constraints = rewritten. constraints <> Set. fromList unclearRequiresAfterSmt}
442
445
in withPatternContext rewritten' $
443
- pure (rewritten', Just $ Predicate $ NotBool $ coerce $ collapseAndBools unclearRequiresAfterSmt)
446
+ pure
447
+ (rewritten', Just $ Predicate $ NotBool $ coerce $ collapseAndBools unclearRequiresAfterSmt, subst)
444
448
where
445
449
failRewrite :: RewriteFailed " Rewrite" -> RewriteRuleAppT (RewriteT io ) a
446
450
failRewrite = lift . (throw)
@@ -841,7 +845,7 @@ performRewrite doTracing def mLlvmLibrary mSolver mbMaxDepth cutLabels terminalL
841
845
-- We are stuck here not trivial because we didn't apply a single rule
842
846
logMessage (" Rewrite stuck after simplification." :: Text ) >> pure (RewriteStuck pat')
843
847
pat'@ Simplified {} -> logMessage (" Retrying with simplified pattern" :: Text ) >> doSteps pat'
844
- AppliedRules [(rule, nextPat)] -- applied single rule
848
+ AppliedRules [(rule, nextPat, _subst )] -- applied single rule
845
849
-- cut-point rule, stop
846
850
| labelOf rule `elem` cutLabels -> do
847
851
simplify pat >>= \ case
@@ -880,34 +884,46 @@ performRewrite doTracing def mLlvmLibrary mSolver mbMaxDepth cutLabels terminalL
880
884
logMessage $ " Previous state found to be bottom after " <> showCounter counter
881
885
pure $ RewriteTrivial pat'
882
886
Simplified pat' ->
883
- (catSimplified <$> mapM (\ (r, nextPat) -> fmap (r,) <$> simplify (Unsimplified nextPat)) nextPats) >>= \ case
884
- [] -> withPatternContext pat' $ do
885
- logMessage (" Rewrite trivial after pruning all branches" :: Text )
886
- pure $ RewriteTrivial pat'
887
- [(rule, nextPat')] -> withPatternContext pat' $ do
888
- logMessage (" All but one branch pruned, continuing" :: Text )
889
- emitRewriteTrace $ RewriteSingleStep (labelOf rule) (uniqueId rule) pat' nextPat'
890
- incrementCounter
891
- doSteps (Simplified nextPat')
892
- nextPats' -> do
893
- emitRewriteTrace $
894
- RewriteBranchingStep pat' $
895
- NE. fromList $
896
- map (\ (rule, _) -> (ruleLabelOrLocT rule, uniqueId rule)) nextPats'
897
- unless (Set. null remainderPredicates) $ do
898
- ModifiersRep (_ :: FromModifiersT mods => Proxy mods ) <- getPrettyModifiers
899
- withContext CtxRemainder . withContext CtxDetail $
900
- logMessage
901
- ( (" Uncovered remainder branch after rewriting with rules " :: Text )
902
- <> ( Text. intercalate " , " $ map (\ (r, _) -> getUniqueId $ uniqueId r) nextPats'
903
- )
904
- )
905
- pure $
906
- RewriteBranch pat' $
907
- NE. fromList $
908
- map
909
- (\ (r, n) -> (ruleLabelOrLocT r, uniqueId r, n, Just (collapseAndBools . Set. toList $ r. requires)))
910
- nextPats'
887
+ ( catSimplified
888
+ <$> mapM (\ (r, nextPat, subst) -> fmap (r,,subst) <$> simplify (Unsimplified nextPat)) nextPats
889
+ )
890
+ >>= \ case
891
+ [] -> withPatternContext pat' $ do
892
+ logMessage (" Rewrite trivial after pruning all branches" :: Text )
893
+ pure $ RewriteTrivial pat'
894
+ [(rule, nextPat', _subst)] -> withPatternContext pat' $ do
895
+ logMessage (" All but one branch pruned, continuing" :: Text )
896
+ emitRewriteTrace $ RewriteSingleStep (labelOf rule) (uniqueId rule) pat' nextPat'
897
+ incrementCounter
898
+ doSteps (Simplified nextPat')
899
+ nextPats' -> do
900
+ emitRewriteTrace $
901
+ RewriteBranchingStep pat' $
902
+ NE. fromList $
903
+ map (\ (rule, _, _subst) -> (ruleLabelOrLocT rule, uniqueId rule)) nextPats'
904
+ unless (Set. null remainderPredicates) $ do
905
+ ModifiersRep (_ :: FromModifiersT mods => Proxy mods ) <- getPrettyModifiers
906
+ withContext CtxRemainder . withContext CtxDetail $
907
+ logMessage
908
+ ( (" Uncovered remainder branch after rewriting with rules " :: Text )
909
+ <> ( Text. intercalate " , " $ map (\ (r, _, _subst) -> getUniqueId $ uniqueId r) nextPats'
910
+ )
911
+ )
912
+ pure $
913
+ RewriteBranch pat' $
914
+ NE. fromList $
915
+ map
916
+ ( \ (r, n, subst) ->
917
+ ( ruleLabelOrLocT r
918
+ , uniqueId r
919
+ , n
920
+ , Just
921
+ ( collapseAndBools $
922
+ concatMap (splitBoolPredicates . coerce . substituteInTerm subst . coerce) r. requires
923
+ )
924
+ )
925
+ )
926
+ nextPats'
911
927
912
928
data RewriteStepsState = RewriteStepsState
913
929
{ counter :: ! Natural
0 commit comments