Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/active_query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ impl ActiveQuery {
.mark_all_active(active_tracked_ids.iter().copied());
}

pub(super) fn take_cycle_heads(&mut self) -> CycleHeads {
std::mem::take(&mut self.cycle_heads)
}

pub(super) fn add_read(
&mut self,
input: DatabaseKeyIndex,
Expand Down
4 changes: 4 additions & 0 deletions src/cycle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -490,4 +490,8 @@ impl<'db> ProvisionalStatus<'db> {
_ => empty_cycle_heads(),
}
}

pub(crate) const fn is_provisional(&self) -> bool {
matches!(self, ProvisionalStatus::Provisional { .. })
}
}
61 changes: 43 additions & 18 deletions src/function/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,20 +56,25 @@ where
});

let (new_value, mut completed_query) = match C::CYCLE_STRATEGY {
CycleRecoveryStrategy::Panic => Self::execute_query(
db,
zalsa,
zalsa_local.push_query(database_key_index, IterationCount::initial()),
opt_old_memo,
),
CycleRecoveryStrategy::Panic => {
let (new_value, active_query) = Self::execute_query(
db,
zalsa,
zalsa_local.push_query(database_key_index, IterationCount::initial()),
opt_old_memo,
);
(new_value, active_query.pop())
}
CycleRecoveryStrategy::FallbackImmediate => {
let (mut new_value, mut completed_query) = Self::execute_query(
let (mut new_value, active_query) = Self::execute_query(
db,
zalsa,
zalsa_local.push_query(database_key_index, IterationCount::initial()),
opt_old_memo,
);

let mut completed_query = active_query.pop();

if let Some(cycle_heads) = completed_query.revisions.cycle_heads_mut() {
// Did the new result we got depend on our own provisional value, in a cycle?
if cycle_heads.contains(&database_key_index) {
Expand Down Expand Up @@ -198,9 +203,10 @@ where

let _poison_guard =
PoisonProvisionalIfPanicking::new(self, zalsa, id, memo_ingredient_index);
let mut active_query = zalsa_local.push_query(database_key_index, iteration_count);

let (new_value, completed_query) = loop {
let active_query = zalsa_local.push_query(database_key_index, iteration_count);

// Tracked struct ids that existed in the previous revision
// but weren't recreated in the last iteration. It's important that we seed the next
// query with these ids because the query might re-create them as part of the next iteration.
Expand All @@ -209,29 +215,32 @@ where
// if they aren't recreated when reaching the final iteration.
active_query.seed_tracked_struct_ids(&last_stale_tracked_ids);

let (mut new_value, mut completed_query) = Self::execute_query(
let (mut new_value, mut active_query) = Self::execute_query(
db,
zalsa,
active_query,
last_provisional_memo.or(opt_old_memo),
);

// If there are no cycle heads, break out of the loop (`cycle_heads_mut` returns `None` if the cycle head list is empty)
let Some(cycle_heads) = completed_query.revisions.cycle_heads_mut() else {
// Take the cycle heads to not-fight-rust's-borrow-checker.
let mut cycle_heads = active_query.take_cycle_heads();

// If there are no cycle heads, break out of the loop.
if cycle_heads.is_empty() {
iteration_count = iteration_count.increment().unwrap_or_else(|| {
tracing::warn!("{database_key_index:?}: execute: too many cycle iterations");
panic!("{database_key_index:?}: execute: too many cycle iterations")
});

let mut completed_query = active_query.pop();
completed_query
.revisions
.update_iteration_count_mut(database_key_index, iteration_count);

claim_guard.set_release_mode(ReleaseMode::SelfOnly);
break (new_value, completed_query);
};
}

// Take the cycle heads to not-fight-rust's-borrow-checker.
let mut cycle_heads = std::mem::take(cycle_heads);
let mut missing_heads: SmallVec<[(DatabaseKeyIndex, IterationCount); 1]> =
SmallVec::new_const();
let mut max_iteration_count = iteration_count;
Expand Down Expand Up @@ -262,6 +271,11 @@ where
.provisional_status(zalsa, head.database_key_index.key_index())
.expect("cycle head memo must have been created during the execution");

// A query should only ever depend on other heads that are provisional.
// If this invariant is violated, it means that this query participates in a cycle,
// but it wasn't executed in the last iteration of said cycle.
assert!(provisional_status.is_provisional());

for nested_head in provisional_status.cycle_heads() {
let nested_as_tuple = (
nested_head.database_key_index,
Expand Down Expand Up @@ -298,6 +312,8 @@ where
claim_guard.set_release_mode(ReleaseMode::SelfOnly);
}

let mut completed_query = active_query.pop();
*completed_query.revisions.verified_final.get_mut() = false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I understand in what case this would be true and we need to reset it to false.

Copy link
Contributor Author

@MichaReiser MichaReiser Nov 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QueryRevisions::new defaults to true if cycle heads are empty. The cycle heads are always empty when calling pop because of the preceding query_guard.take_cycle_heads

completed_query.revisions.set_cycle_heads(cycle_heads);

iteration_count = iteration_count.increment().unwrap_or_else(|| {
Expand Down Expand Up @@ -378,8 +394,17 @@ where
this_converged = C::values_equal(&new_value, last_provisional_value);
}
}

let new_cycle_heads = active_query.take_cycle_heads();
for head in new_cycle_heads {
if !cycle_heads.contains(&head.database_key_index) {
panic!("Cycle recovery function for {database_key_index:?} introduced a cycle, depending on {:?}. This is not allowed.", head.database_key_index);
}
}
}

let mut completed_query = active_query.pop();

if let Some(outer_cycle) = outer_cycle {
tracing::info!(
"Detected nested cycle {database_key_index:?}, iterate it as part of the outer cycle {outer_cycle:?}"
Expand All @@ -390,6 +415,7 @@ where
completed_query
.revisions
.set_cycle_converged(this_converged);
*completed_query.revisions.verified_final.get_mut() = false;

// Transfer ownership of this query to the outer cycle, so that it can claim it
// and other threads don't compete for the same lock.
Expand Down Expand Up @@ -428,9 +454,9 @@ where
}

*completed_query.revisions.verified_final.get_mut() = true;

break (new_value, completed_query);
}
*completed_query.revisions.verified_final.get_mut() = false;

// The fixpoint iteration hasn't converged. Iterate again...
iteration_count = iteration_count.increment().unwrap_or_else(|| {
Expand Down Expand Up @@ -484,7 +510,6 @@ where
last_provisional_memo = Some(new_memo);

last_stale_tracked_ids = completed_query.stale_tracked_structs;
active_query = zalsa_local.push_query(database_key_index, iteration_count);

continue;
};
Expand All @@ -503,7 +528,7 @@ where
zalsa: &'db Zalsa,
active_query: ActiveQueryGuard<'db>,
opt_old_memo: Option<&Memo<'db, C>>,
) -> (C::Output<'db>, CompletedQuery) {
) -> (C::Output<'db>, ActiveQueryGuard<'db>) {
if let Some(old_memo) = opt_old_memo {
// If we already executed this query once, then use the tracked-struct ids from the
// previous execution as the starting point for the new one.
Expand All @@ -528,7 +553,7 @@ where
C::id_to_input(zalsa, active_query.database_key_index.key_index()),
);

(new_value, active_query.pop())
(new_value, active_query)
}
}

Expand Down
5 changes: 4 additions & 1 deletion src/function/maybe_changed_after.rs
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,10 @@ where
cycle_heads.append_heads(&mut child_cycle_heads);

match input_result {
VerifyResult::Changed => return VerifyResult::changed(),
VerifyResult::Changed => {
cycle_heads.remove_head(database_key_index);
return VerifyResult::changed();
}
#[cfg(feature = "accumulator")]
VerifyResult::Unchanged { accumulated } => {
inputs |= accumulated;
Expand Down
12 changes: 12 additions & 0 deletions src/zalsa_local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1213,6 +1213,18 @@ impl ActiveQueryGuard<'_> {
}
}

pub(crate) fn take_cycle_heads(&mut self) -> CycleHeads {
// SAFETY: We do not access the query stack reentrantly.
unsafe {
self.local_state.with_query_stack_unchecked_mut(|stack| {
#[cfg(debug_assertions)]
assert_eq!(stack.len(), self.push_len);
let frame = stack.last_mut().unwrap();
frame.take_cycle_heads()
})
}
}

/// Invoked when the query has successfully completed execution.
fn complete(self) -> CompletedQuery {
// SAFETY: We do not access the query stack reentrantly.
Expand Down
82 changes: 82 additions & 0 deletions tests/cycle_recovery_dependencies.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#![cfg(feature = "inventory")]

//! Queries or inputs read within the cycle recovery function
//! are tracked on the cycle function and don't "leak" into the
//! function calling the query with cycle handling.
use expect_test::expect;
use salsa::Setter as _;

use crate::common::LogDatabase;

mod common;

#[salsa::input]
struct Input {
value: u32,
}

#[salsa::tracked]
fn entry(db: &dyn salsa::Database, input: Input) -> u32 {
query(db, input)
}

#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)]
fn query(db: &dyn salsa::Database, input: Input) -> u32 {
let val = query(db, input);
if val < 5 {
val + 1
} else {
val
}
}

fn cycle_initial(_db: &dyn salsa::Database, _id: salsa::Id, _input: Input) -> u32 {
0
}

fn cycle_fn(
db: &dyn salsa::Database,
_id: salsa::Id,
_last_provisional_value: &u32,
_value: &u32,
_count: u32,
input: Input,
) -> salsa::CycleRecoveryAction<u32> {
let _input = input.value(db);
salsa::CycleRecoveryAction::Iterate
}

#[test_log::test]
fn the_test() {
let mut db = common::EventLoggerDatabase::default();

let input = Input::new(&db, 1);
assert_eq!(entry(&db, input), 5);

db.assert_logs_len(15);

input.set_value(&mut db).to(2);

assert_eq!(entry(&db, input), 5);
db.assert_logs(expect![[r#"
[
"DidSetCancellationFlag",
"WillCheckCancellation",
"WillCheckCancellation",
"WillCheckCancellation",
"WillExecute { database_key: query(Id(0)) }",
"WillCheckCancellation",
"WillIterateCycle { database_key: query(Id(0)), iteration_count: IterationCount(1) }",
"WillCheckCancellation",
"WillIterateCycle { database_key: query(Id(0)), iteration_count: IterationCount(2) }",
"WillCheckCancellation",
"WillIterateCycle { database_key: query(Id(0)), iteration_count: IterationCount(3) }",
"WillCheckCancellation",
"WillIterateCycle { database_key: query(Id(0)), iteration_count: IterationCount(4) }",
"WillCheckCancellation",
"WillIterateCycle { database_key: query(Id(0)), iteration_count: IterationCount(5) }",
"WillCheckCancellation",
"DidValidateMemoizedValue { database_key: entry(Id(0)) }",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before the fix, this incorrectly re-executed entry but not the cycle (which is what depends on the value).

With the fix in execute_maybe_iterate, it did re-execute the cycle, but it also re-executed entry, which is incorrect because query returns the same value (it can be backdated). Now, the behavior is what we want. entry does not get re-executed but the cycle is

]"#]]);
}