Skip to content

Commit 391b47c

Browse files
committed
Apply substitution to the rule predicate
1 parent 5e03d41 commit 391b47c

File tree

1 file changed

+55
-39
lines changed

1 file changed

+55
-39
lines changed

booster/library/Booster/Pattern/Rewrite.hs

+55-39
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ import Booster.Pattern.Match (
6565
MatchResult (MatchFailed, MatchIndeterminate, MatchSuccess),
6666
MatchType (Rewrite),
6767
SortError,
68+
Substitution,
6869
matchTerms,
6970
)
7071
import Booster.Pattern.Pretty
@@ -153,7 +154,7 @@ data RewriteStepResult a = OnlyTrivial | AppliedRules a deriving (Eq, Show, Func
153154
rewriteStep ::
154155
LoggerMIO io =>
155156
Pattern ->
156-
RewriteT io (RewriteStepResult [(RewriteRule "Rewrite", Pattern)])
157+
RewriteT io (RewriteStepResult [(RewriteRule "Rewrite", Pattern, Substitution)])
157158
rewriteStep pat = do
158159
def <- getDefinition
159160
let getIndex =
@@ -175,18 +176,18 @@ rewriteStep pat = do
175176
-- return `OnlyTrivial` if all elements of a list are `(r, Nothing)`. If the list is empty or contains at least one `(r, Just p)`,
176177
-- return an `AppliedRules` list of `(r, p)` pairs.
177178
filterOutTrivial ::
178-
[(RewriteRule "Rewrite", Maybe Pattern)] ->
179-
RewriteStepResult [(RewriteRule "Rewrite", Pattern)]
179+
[(RewriteRule "Rewrite", Maybe (Pattern, Substitution))] ->
180+
RewriteStepResult [(RewriteRule "Rewrite", Pattern, Substitution)]
180181
filterOutTrivial = \case
181182
[] -> AppliedRules []
182183
[(_, Nothing)] -> OnlyTrivial
183184
(_, Nothing) : xs -> filterOutTrivial xs
184-
(rule, Just p) : xs -> AppliedRules $ (rule, p) : mapMaybe (\(r, mp) -> (r,) <$> mp) xs
185+
(rule, Just (p, subst)) : xs -> AppliedRules $ (rule, p, subst) : mapMaybe (\(r, mp) -> (\(x, y) -> (r, x, y)) <$> mp) xs
185186

186187
processGroups ::
187188
LoggerMIO io =>
188189
[[RewriteRule "Rewrite"]] ->
189-
RewriteT io [(RewriteRule "Rewrite", Maybe Pattern)]
190+
RewriteT io [(RewriteRule "Rewrite", Maybe (Pattern, Substitution))]
190191
processGroups [] = pure []
191192
processGroups (rules : lowerPriorityRules) = do
192193
-- try all rules of the priority group. This will immediately
@@ -209,8 +210,10 @@ rewriteStep pat = do
209210
results
210211
-- compute remainder condition here from @nonTrivialResults@ and the remainder up to now.
211212
-- If the new remainder is bottom, then no lower priority rules apply
212-
newRemainder = currentRemainder <> Set.fromList (mapMaybe (snd . snd) nonTrivialResultsWithPartialRemainders)
213-
resultsWithoutRemainders = map (fmap (fmap fst)) results
213+
newRemainder =
214+
currentRemainder
215+
<> Set.fromList (mapMaybe ((\(_, r, _) -> r) . snd) nonTrivialResultsWithPartialRemainders)
216+
resultsWithoutRemainders = map (fmap (fmap (\(p, _, s) -> (p, s)))) results
214217
setRemainder newRemainder
215218
ModifiersRep (_ :: FromModifiersT mods => Proxy mods) <- getPrettyModifiers
216219
withContext CtxRemainder $ logPretty' @mods (collapseAndBools . Set.toList $ newRemainder)
@@ -272,7 +275,7 @@ applyRule ::
272275
LoggerMIO io =>
273276
Pattern ->
274277
RewriteRule "Rewrite" ->
275-
RewriteT io (Maybe (Maybe (Pattern, Maybe Predicate)))
278+
RewriteT io (Maybe (Maybe (Pattern, Maybe Predicate, Substitution)))
276279
applyRule pat@Pattern{ceilConditions} rule =
277280
withRuleContext rule $
278281
runRewriteRuleAppT $
@@ -436,11 +439,12 @@ applyRule pat@Pattern{ceilConditions} rule =
436439
ceilConditions
437440
withContext CtxSuccess $ do
438441
case unclearRequiresAfterSmt of
439-
[] -> withPatternContext rewritten $ pure (rewritten, Nothing)
442+
[] -> withPatternContext rewritten $ pure (rewritten, Nothing, subst)
440443
_ ->
441444
let rewritten' = rewritten{constraints = rewritten.constraints <> Set.fromList unclearRequiresAfterSmt}
442445
in withPatternContext rewritten' $
443-
pure (rewritten', Just $ Predicate $ NotBool $ coerce $ collapseAndBools unclearRequiresAfterSmt)
446+
pure
447+
(rewritten', Just $ Predicate $ NotBool $ coerce $ collapseAndBools unclearRequiresAfterSmt, subst)
444448
where
445449
failRewrite :: RewriteFailed "Rewrite" -> RewriteRuleAppT (RewriteT io) a
446450
failRewrite = lift . (throw)
@@ -841,7 +845,7 @@ performRewrite doTracing def mLlvmLibrary mSolver mbMaxDepth cutLabels terminalL
841845
-- We are stuck here not trivial because we didn't apply a single rule
842846
logMessage ("Rewrite stuck after simplification." :: Text) >> pure (RewriteStuck pat')
843847
pat'@Simplified{} -> logMessage ("Retrying with simplified pattern" :: Text) >> doSteps pat'
844-
AppliedRules [(rule, nextPat)] -- applied single rule
848+
AppliedRules [(rule, nextPat, _subst)] -- applied single rule
845849
-- cut-point rule, stop
846850
| labelOf rule `elem` cutLabels -> do
847851
simplify pat >>= \case
@@ -880,34 +884,46 @@ performRewrite doTracing def mLlvmLibrary mSolver mbMaxDepth cutLabels terminalL
880884
logMessage $ "Previous state found to be bottom after " <> showCounter counter
881885
pure $ RewriteTrivial pat'
882886
Simplified pat' ->
883-
(catSimplified <$> mapM (\(r, nextPat) -> fmap (r,) <$> simplify (Unsimplified nextPat)) nextPats) >>= \case
884-
[] -> withPatternContext pat' $ do
885-
logMessage ("Rewrite trivial after pruning all branches" :: Text)
886-
pure $ RewriteTrivial pat'
887-
[(rule, nextPat')] -> withPatternContext pat' $ do
888-
logMessage ("All but one branch pruned, continuing" :: Text)
889-
emitRewriteTrace $ RewriteSingleStep (labelOf rule) (uniqueId rule) pat' nextPat'
890-
incrementCounter
891-
doSteps (Simplified nextPat')
892-
nextPats' -> do
893-
emitRewriteTrace $
894-
RewriteBranchingStep pat' $
895-
NE.fromList $
896-
map (\(rule, _) -> (ruleLabelOrLocT rule, uniqueId rule)) nextPats'
897-
unless (Set.null remainderPredicates) $ do
898-
ModifiersRep (_ :: FromModifiersT mods => Proxy mods) <- getPrettyModifiers
899-
withContext CtxRemainder . withContext CtxDetail $
900-
logMessage
901-
( ("Uncovered remainder branch after rewriting with rules " :: Text)
902-
<> ( Text.intercalate ", " $ map (\(r, _) -> getUniqueId $ uniqueId r) nextPats'
903-
)
904-
)
905-
pure $
906-
RewriteBranch pat' $
907-
NE.fromList $
908-
map
909-
(\(r, n) -> (ruleLabelOrLocT r, uniqueId r, n, Just (collapseAndBools . Set.toList $ r.requires)))
910-
nextPats'
887+
( catSimplified
888+
<$> mapM (\(r, nextPat, subst) -> fmap (r,,subst) <$> simplify (Unsimplified nextPat)) nextPats
889+
)
890+
>>= \case
891+
[] -> withPatternContext pat' $ do
892+
logMessage ("Rewrite trivial after pruning all branches" :: Text)
893+
pure $ RewriteTrivial pat'
894+
[(rule, nextPat', _subst)] -> withPatternContext pat' $ do
895+
logMessage ("All but one branch pruned, continuing" :: Text)
896+
emitRewriteTrace $ RewriteSingleStep (labelOf rule) (uniqueId rule) pat' nextPat'
897+
incrementCounter
898+
doSteps (Simplified nextPat')
899+
nextPats' -> do
900+
emitRewriteTrace $
901+
RewriteBranchingStep pat' $
902+
NE.fromList $
903+
map (\(rule, _, _subst) -> (ruleLabelOrLocT rule, uniqueId rule)) nextPats'
904+
unless (Set.null remainderPredicates) $ do
905+
ModifiersRep (_ :: FromModifiersT mods => Proxy mods) <- getPrettyModifiers
906+
withContext CtxRemainder . withContext CtxDetail $
907+
logMessage
908+
( ("Uncovered remainder branch after rewriting with rules " :: Text)
909+
<> ( Text.intercalate ", " $ map (\(r, _, _subst) -> getUniqueId $ uniqueId r) nextPats'
910+
)
911+
)
912+
pure $
913+
RewriteBranch pat' $
914+
NE.fromList $
915+
map
916+
( \(r, n, subst) ->
917+
( ruleLabelOrLocT r
918+
, uniqueId r
919+
, n
920+
, Just
921+
( collapseAndBools $
922+
concatMap (splitBoolPredicates . coerce . substituteInTerm subst . coerce) r.requires
923+
)
924+
)
925+
)
926+
nextPats'
911927

912928
data RewriteStepsState = RewriteStepsState
913929
{ counter :: !Natural

0 commit comments

Comments
 (0)