diff --git a/Cargo.toml b/Cargo.toml index 9b21d254..968ba365 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,5 +25,7 @@ env_logger = "0.7" linked-hash-map = "0.5.2" rand = "0.7" rand_distr = "0.2.1" +test-env-log = "0.2.7" +insta = "1.8.0" [workspace] diff --git a/book/src/SUMMARY.md b/book/src/SUMMARY.md index af54da7b..6bc5c58b 100644 --- a/book/src/SUMMARY.md +++ b/book/src/SUMMARY.md @@ -1,16 +1,48 @@ # Summary - [About salsa](./about_salsa.md) + +# How to use Salsa + - [How to use Salsa](./how_to_use.md) -- [How Salsa works](./how_salsa_works.md) - [Common patterns](./common_patterns.md) - [Selection](./common_patterns/selection.md) - [On-demand (Lazy) inputs](./common_patterns/on_demand_inputs.md) +- [Cycle handling](./cycles.md) + - [Recovering via fallback](./cycles/fallback.md) + +# How Salsa works internally + +- [How Salsa works](./how_salsa_works.md) - [Videos](./videos.md) - [Plumbing](./plumbing.md) - - [Diagram](./plumbing/diagram.md) - - [Query groups](./plumbing/query_groups.md) - - [Database](./plumbing/database.md) + - [Generated code](./plumbing/generated_code.md) + - [Diagram](./plumbing/diagram.md) + - [Query groups](./plumbing/query_groups.md) + - [Database](./plumbing/database.md) + - [The `salsa` crate](./plumbing/salsa_crate.md) + - [Query operations](./plumbing/query_ops.md) + - [maybe changed after](./plumbing/maybe_changed_after.md) + - [Fetch](./plumbing/fetch.md) + - [Derived queries flowchart](./plumbing/derived_flowchart.md) + - [Cycle handling](./plumbing/cycles.md) + - [Terminology](./plumbing/terminology.md) + - [Backdate](./plumbing/terminology/backdate.md) + - [Changed at](./plumbing/terminology/changed_at.md) + - [Dependency](./plumbing/terminology/dependency.md) + - [Derived query](./plumbing/terminology/derived_query.md) + - [Durability](./plumbing/terminology/durability.md) + - [Input query](./plumbing/terminology/input_query.md) + - [LRU](./plumbing/terminology/LRU.md) + - [Memo](./plumbing/terminology/memo.md) + - [Query](./plumbing/terminology/query.md) + - [Query function](./plumbing/terminology/query_function.md) + - [Revision](./plumbing/terminology/revision.md) + - [Untracked dependency](./plumbing/terminology/untracked.md) + - [Verified](./plumbing/terminology/verified.md) + +# Salsa RFCs + - [RFCs](./rfcs.md) - [Template](./rfcs/template.md) - [RFC 0001: Query group traits](./rfcs/RFC0001-Query-Group-Traits.md) @@ -20,4 +52,10 @@ - [RFC 0005: Durability](./rfcs/RFC0005-Durability.md) - [RFC 0006: Dynamic database](./rfcs/RFC0006-Dynamic-Databases.md) - [RFC 0007: Opinionated cancelation](./rfcs/RFC0007-Opinionated-Cancelation.md) - - [RFC 0008: Remove garbage collection](./rfcs/RFC0008-Remove-Garbage-Collection.md) \ No newline at end of file + - [RFC 0008: Remove garbage collection](./rfcs/RFC0008-Remove-Garbage-Collection.md) + - [RFC 0009: Cycle recovery](./rfcs/RFC0009-Cycle-recovery.md) + +# Appendices + +- [Meta: about the book itself](./meta.md) + diff --git a/book/src/cycles.md b/book/src/cycles.md new file mode 100644 index 00000000..64ea5e73 --- /dev/null +++ b/book/src/cycles.md @@ -0,0 +1,3 @@ +# Cycle handling + +By default, when Salsa detects a cycle in the computation graph, Salsa will panic with a [`salsa::Cycle`] as the panic value. The [`salsa::Cycle`] structure that describes the cycle, which can be useful for diagnosing what went wrong. 
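For illustration, here is a sketch of how the outermost caller might observe that panic and inspect the `salsa::Cycle` payload. The `MyDatabase` trait, `my_query`, `MyKey`, and `MyValue` below are hypothetical placeholders for your own query group; only the fact that the panic payload is a `salsa::Cycle` (which implements `Debug`) comes from the documentation above.

```rust
use std::panic::{self, AssertUnwindSafe};

// Hypothetical wrapper: run a query and report a cycle instead of crashing.
// `MyDatabase`, `my_query`, `MyKey`, and `MyValue` stand in for your own query group.
fn run_guarded(db: &dyn MyDatabase, key: MyKey) -> Option<MyValue> {
    match panic::catch_unwind(AssertUnwindSafe(|| db.my_query(key))) {
        Ok(value) => Some(value),
        Err(payload) => {
            if let Some(cycle) = payload.downcast_ref::<salsa::Cycle>() {
                // The panic payload is a `salsa::Cycle`: a query cycle was detected.
                eprintln!("cycle detected: {:?}", cycle);
                None
            } else {
                // Some other panic: propagate it unchanged.
                panic::resume_unwind(payload)
            }
        }
    }
}
```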
diff --git a/book/src/cycles/fallback.md b/book/src/cycles/fallback.md new file mode 100644 index 00000000..a69d9071 --- /dev/null +++ b/book/src/cycles/fallback.md @@ -0,0 +1,22 @@ +# Recovering via fallback + +Panicking when a cycle occurs is ok for situations where you believe a cycle is impossible. But sometimes cycles can result from illegal user input and cannot be statically prevented. In these cases, you might prefer to gracefully recover from a cycle rather than panicking the entire query. Salsa supports that with the idea of *cycle recovery*. + +To use cycle recovery, you annotate potential participants in the cycle with a `#[salsa::recover(my_recover_fn)]` attribute. When a cycle occurs, if any participant P has recovery information, then no panic occurs. Instead, the execution of P is aborted and P will execute the recovery function to generate its result. Participants in the cycle that do not have recovery information continue executing as normal, using this recovery result. + +The recovery function has a similar signature to a query function. It is given a reference to your database along with a `salsa::Cycle` describing the cycle that occurred; it returns the result of the query. Example: + +```rust +fn my_recover_fn( + db: &dyn MyDatabase, + cycle: &salsa::Cycle, +) -> MyResultValue +``` + +The `db` and `cycle` argument can be used to prepare a useful error message for your users. + +**Important:** Although the recovery function is given a `db` handle, you should be careful to avoid creating a cycle from within recovery or invoking queries that may be participating in the current cycle. Attempting to do so can result in inconsistent results. + +## Figuring out why recovery did not work + +If a cycle occurs and *some* of the participant queries have `#[salsa::recover]` annotations and others do not, then the query will be treated as irrecoverable and will simply panic. You can use the `Cycle::unexpected_participants` method to figure out why recovery did not succeed and add the appropriate `#[salsa::recover]` annotations. diff --git a/book/src/derived-query-read.drawio.svg b/book/src/derived-query-read.drawio.svg new file mode 100644 index 00000000..ccde18c8 --- /dev/null +++ b/book/src/derived-query-read.drawio.svg @@ -0,0 +1,4 @@ + + + +
[drawio SVG markup elided. The flowchart shows the derived-query read path: acquiring read, read-upgrade, or write locks on query Q; blocking on another thread T, executing cycle recovery, or panicking when a recoverable or irrecoverable cycle is detected; checking whether Q has a memoized value (Salsa events WillExecute / DidValidateMemoizedValue); verifying each input `Qin` via the durability check and per-input "maybe changed after" checks; executing the query function; backdating the new memo M1 to the "changed at" of the old memo M if the value is the same; and storing the memo and returning the value.]
\ No newline at end of file diff --git a/book/src/meta.md b/book/src/meta.md new file mode 100644 index 00000000..19b458b6 --- /dev/null +++ b/book/src/meta.md @@ -0,0 +1,17 @@ +# Meta: about the book itself + +## Linking policy + +We try to avoid links that easily become fragile. + +**Do:** + +* Link to `docs.rs` types to document the public API, but modify the link to use `latest` as the version. +* Link to modules in the source code. +* Create ["named anchors"] and embed source code directly. + +["named anchors"]: https://rust-lang.github.io/mdBook/format/mdbook.html?highlight=ANCHOR#including-portions-of-a-file + +**Don't:** + +* Link to direct lines on github, even within a specific commit, unless you are trying to reference a historical piece of code ("how things were at the time"). \ No newline at end of file diff --git a/book/src/plumbing.md b/book/src/plumbing.md index 5d4b1a72..197b2c6e 100644 --- a/book/src/plumbing.md +++ b/book/src/plumbing.md @@ -3,17 +3,6 @@ This chapter documents the code that salsa generates and its "inner workings". We refer to this as the "plumbing". -This page walks through the ["Hello, World!"] example and explains the code that -it generates. Please take it with a grain of salt: while we make an effort to -keep this documentation up to date, this sort of thing can fall out of date -easily. See the page history below for major updates. - -["Hello, World!"]: https://github.com/salsa-rs/salsa/blob/master/examples/hello_world/main.rs - -If you'd like to see for yourself, you can set the environment variable -`SALSA_DUMP` to 1 while the procedural macro runs, and it will dump the full -output to stdout. I recommend piping the output through rustfmt. - ## History * 2020-07-05: Updated to take [RFC 6](rfcs/RFC0006-Dynamic-Databases.md) into account. diff --git a/book/src/plumbing/cycles.md b/book/src/plumbing/cycles.md new file mode 100644 index 00000000..e98a8b66 --- /dev/null +++ b/book/src/plumbing/cycles.md @@ -0,0 +1,65 @@ +# Cycles + +## Cross-thread blocking + +The interface for blocking across threads now works as follows: + +* When one thread `T1` wishes to block on a query `Q` being executed by another thread `T2`, it invokes `Runtime::try_block_on`. This will check for cycles. Assuming no cycle is detected, it will block `T1` until `T2` has completed with `Q`. At that point, `T1` reawakens. However, we don't know the result of executing `Q`, so `T1` now has to "retry". Typically, this will result in successfully reading the cached value. +* While `T1` is blocking, the runtime moves its query stack (a `Vec`) into the shared dependency graph data structure. When `T1` reawakens, it recovers ownership of its query stack before returning from `try_block_on`. + +## Cycle detection + +When a thread `T1` attempts to execute a query `Q`, it will try to load the value for `Q` from the memoization tables. If it finds an `InProgress` marker, that indicates that `Q` is currently being computed. This indicates a potential cycle. `T1` will then try to block on the query `Q`: + +* If `Q` is also being computed by `T1`, then there is a cycle. +* Otherwise, if `Q` is being computed by some other thread `T2`, we have to check whether `T2` is (transitively) blocked on `T1`. If so, there is a cycle. + +These two cases are handled internally by the `Runtime::try_block_on` function. Detecting the intra-thread cycle case is easy; to detect cross-thread cycles, the runtime maintains a dependency DAG between threads (identified by `RuntimeId`). 
Before adding an edge `T1 -> T2` (i.e., `T1` is blocked waiting for `T2`) into the DAG, it checks whether a path exists from `T2` to `T1`. If so, we have a cycle and the edge cannot be added (then the DAG would no longer be acyclic). + +When a cycle is detected, the current thread `T1` has full access to the query stacks that are participating in the cycle. Consider: naturally, `T1` has access to its own stack. There is also a path `T2 -> ... -> Tn -> T1` of blocked threads. Each of the blocked threads `T2 ..= Tn` will have moved their query stacks into the dependency graph, so those query stacks are available for inspection. + +Using the available stacks, we can create a list of cycle participants `Q0 ... Qn` and store that into a `Cycle` struct. If none of the participants `Q0 ... Qn` have cycle recovery enabled, we panic with the `Cycle` struct, which will trigger all the queries on this thread to panic. + +## Cycle recovery via fallback + +If any of the cycle participants `Q0 ... Qn` has cycle recovery set, we recover from the cycle. To help explain how this works, we will use this example cycle which contains three threads. Beginning with the current query, the cycle participants are `QA3`, `QB2`, `QB3`, `QC2`, `QC3`, and `QA2`. + +``` + The cyclic + edge we have + failed to add. + : + A : B C + : + QA1 v QB1 QC1 +┌► QA2 ┌──► QB2 ┌─► QC2 +│ QA3 ───┘ QB3 ──┘ QC3 ───┐ +│ │ +└───────────────────────────────┘ +``` + +Recovery works in phases: + +* **Analyze:** As we enumerate the query participants, we collect their collective inputs (all queries invoked so far by any cycle participant) and the max changed-at and min durability. We then remove the cycle participants themselves from this list of inputs, leaving only the queries external to the cycle. +* **Mark**: For each query Q that is annotated with `#[salsa::recover]`, we mark it and all of its successors on the same thread by setting its `cycle` flag to the `c: Cycle` we constructed earlier; we also reset its inputs to the collective inputs gathered during analysis. If those queries resume execution later, those marks will trigger them to immediately unwind and use cycle recovery, and the inputs will be used as the inputs to the recovery value. + * Note that we mark *all* the successors of Q on the same thread, whether or not they have recovery set. We'll discuss later how this is important in the case where the active thread (A, here) doesn't have any recovery set. +* **Unblock**: Each blocked thread T that has a recovering query is forcibly reawoken; the outgoing edge from that thread to its successor in the cycle is removed. Its condvar is signalled with a `WaitResult::Cycle(c)`. When the thread reawakens, it will see that and start unwinding with the cycle `c`. +* **Handle the current thread:** Finally, we have to choose how to have the current thread proceed. If the current thread includes any queries with recovery information, then we can begin unwinding. Otherwise, the current thread simply continues as if there had been no cycle, and so the cyclic edge is added to the graph and the current thread blocks. This is possible because some other thread had recovery information and therefore has been awoken. + +Let's walk through the process with a few examples. + +### Example 1: Recovery on the detecting thread + +Consider the case where only the query QA2 has recovery set. It and QA3 will be marked with their `cycle` flag set to `c: Cycle`. Threads B and C will not be unblocked, as they do not have any cycle recovery nodes.
The current thread (Thread A) will initiate unwinding with the cycle `c` as the value. Unwinding will pass through QA3 and be caught by QA2. QA2 will substitute the recovery value and return normally. QA1 and QC3 will then complete normally and so forth, on up until all queries have completed. + +### Example 2: Recovery in two queries on the detecting thread + +Consider the case where both queries QA2 and QA3 have recovery set. It proceeds the same as Example 1 until the current thread initiates unwinding. When QA3 receives the cycle, it stores its recovery value and completes normally. QA2 then adds QA3 as an input dependency: at that point, QA2 observes that it too has the cycle mark set, and so it initiates unwinding. The rest of QA2 therefore never executes. This unwinding is caught by QA2's entry point and it stores the recovery value and returns normally. QA1 and QC3 then continue normally, as they have not had their `cycle` flag set. + +### Example 3: Recovery on another thread + +Now consider the case where only the query QB2 has recovery set. It and QB3 will be marked with the cycle `c: Cycle` and thread B will be unblocked; the edge `QB3 -> QC2` will be removed from the dependency graph. Thread A will then add an edge `QA3 -> QB2` and block on thread B. At that point, thread A releases the lock on the dependency graph, and so thread B is re-awoken. It observes the `WaitResult::Cycle` and initiates unwinding. Unwinding proceeds through QB3 and into QB2, which recovers. QB1 is then able to execute normally, as is QA3, and execution proceeds from there. + +### Example 4: Recovery on all queries + +Now consider the case where all the queries have recovery set. In that case, they are all marked with the cycle, and all the cross-thread edges are removed from the graph. Each thread will independently awaken and initiate unwinding. Each query will recover. diff --git a/book/src/plumbing/derived_flowchart.md b/book/src/plumbing/derived_flowchart.md new file mode 100644 index 00000000..8423c173 --- /dev/null +++ b/book/src/plumbing/derived_flowchart.md @@ -0,0 +1,14 @@ +# Derived queries flowchart + +Derived queries are by far the most complex. This flowchart documents the flow of the [maybe changed after] and [fetch] operations. This flowchart can be edited on [draw.io]: + +[draw.io]: https://draw.io +[fetch]: ./fetch.md +[maybe changed after]: ./maybe_changed_after.md + + +
+ +![Flowchart](../derived-query-read.drawio.svg) + +
diff --git a/book/src/plumbing/diagram.md b/book/src/plumbing/diagram.md index 15050897..17ab7f90 100644 --- a/book/src/plumbing/diagram.md +++ b/book/src/plumbing/diagram.md @@ -1,21 +1,13 @@ # Diagram -Based on the hello world example: - -```rust,ignore -{{#include ../../../examples/hello_world/main.rs:trait}} -``` - -```rust,ignore -{{#include ../../../examples/hello_world/main.rs:database}} -``` +This diagram shows the items that get generated from the Hello World query group and database struct. You can click on each item to be taken to the explanation of its purpose. The diagram is wide so be sure to scroll over! ```mermaid graph LR classDef diagramNode text-align:left; subgraph query group HelloWorldTrait["trait HelloWorld: Database + HasQueryGroup(HelloWorldStroage)"] - HelloWorldImpl["impl(DB) HelloWorld for DB
where DB: HasQueryGroup(HelloWorldStorage)"] + HelloWorldImpl["impl<DB> HelloWorld for DB
where DB: HasQueryGroup(HelloWorldStorage)"] click HelloWorldImpl "http:query_groups.html#impl-of-the-hello-world-trait" "more info" HelloWorldStorage["struct HelloWorldStorage"] click HelloWorldStorage "http:query_groups.html#the-group-struct-and-querygroup-trait" "more info" @@ -53,6 +45,6 @@ graph LR class DerivedStorage diagramNode; end LengthQueryImpl --> DerivedStorage; - DatabaseStruct --> HelloWorldImpl - HasQueryGroup --> HelloWorldImpl + DatabaseStruct -- "used by" --> HelloWorldImpl + HasQueryGroup -- "used by" --> HelloWorldImpl ``` \ No newline at end of file diff --git a/book/src/plumbing/fetch.md b/book/src/plumbing/fetch.md new file mode 100644 index 00000000..e0319d1c --- /dev/null +++ b/book/src/plumbing/fetch.md @@ -0,0 +1,42 @@ +# Fetch + +```rust,no_run,noplayground +{{#include ../../../src/plumbing.rs:fetch}} +``` + +The `fetch` operation computes the value of a query. It prefers to reuse memoized values when it can. + +## Input queries + +Input queries simply load the result from the table. + +## Interned queries + +Interned queries map the input into a hashmap to find an existing integer. If none is present, a new value is created. + +## Derived queries + +The logic for derived queries is more complex. We summarize the high-level ideas here, but you may find the [flowchart](./derived_flowchart.md) useful to dig deeper. The [terminology](./terminology.md) section may also be useful; in some cases, we link to that section on the first usage of a word. + +* If an existing [memo] is found, then we check if the memo was [verified] in the current [revision]. If so, we can directly return the memoized value. +* Otherwise, if the memo contains a memoized value, we must check whether [dependencies] have been modified: + * Let R be the revision in which the memo was last verified; we wish to know if any of the dependencies have changed since revision R. + * First, we check the [durability]. For each memo, we track the minimum durability of the memo's dependencies. If the memo has durability D, and there have been no changes to an input with durability D since the last time the memo was verified, then we can consider the memo verified without any further work. + * If the durability check is not sufficient, then we must check the dependencies individually. For this, we iterate over each dependency D and invoke the [maybe changed after](./maybe_changed_after.md) operation to check whether D has changed since the revision R. + * If no dependency was modified: + * We can mark the memo as verified and return its memoized value. +* Assuming dependencies have been modified or the memo does not contain a memoized value: + * Then we execute the user's query function. + * Next, we compute the revision in which the memoized value last changed: + * *Backdate:* If there was a previous memoized value, and the new value is equal to that old value, then we can *backdate* the memo, which means to use the 'changed at' revision from before. + * Thanks to backdating, it is possible for a dependency of the query to have changed in some revision R1 but for the *output* of the query to have changed in some revision R2 where R2 predates R1. + * Otherwise, we use the current revision. + * Construct a memo for the new value and return it. 
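To make those steps concrete, here is a small self-contained sketch of the same decision logic. This is an illustrative model only: the `Revision`, `Durability`, `Memo`, and `Db` types below are simplified stand-ins invented for this example, not salsa's actual data structures.

```rust
/// Simplified stand-ins used only for this sketch (not salsa's real types).
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub struct Revision(u64);

#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub struct Durability(u8);

pub struct Memo<V> {
    value: Option<V>,       // may be absent (e.g. dropped by LRU)
    verified_at: Revision,  // last revision in which this memo was verified
    changed_at: Revision,   // revision in which the value last changed
    durability: Durability, // minimum durability of the dependencies
    deps: Vec<u32>,         // indices of the queries this memo depends on
}

pub trait Db<V> {
    fn current_revision(&self) -> Revision;
    /// Last revision in which an input of durability `d` (or higher) was changed.
    fn last_changed(&self, d: Durability) -> Revision;
    /// The "maybe changed after" check for a single dependency.
    fn dep_maybe_changed_after(&self, dep: u32, rev: Revision) -> bool;
    /// Runs the user's query function, returning the new value plus the
    /// dependencies and minimum durability recorded during execution.
    fn execute_query(&self) -> (V, Vec<u32>, Durability);
}

/// Sketch of the derived-query `fetch` logic described in the list above.
pub fn fetch<V: Clone + PartialEq>(db: &dyn Db<V>, memo: &mut Option<Memo<V>>) -> V {
    let now = db.current_revision();

    if let Some(m) = memo.as_mut() {
        // Memo already verified in this revision: return the memoized value directly.
        if m.verified_at == now {
            if let Some(value) = &m.value {
                return value.clone();
            }
        } else if m.value.is_some() {
            // Durability shortcut first, then per-dependency verification.
            let unchanged = db.last_changed(m.durability) <= m.verified_at
                || m.deps.iter().all(|&d| !db.dep_maybe_changed_after(d, m.verified_at));
            if unchanged {
                m.verified_at = now; // mark as verified and reuse the value
                return m.value.clone().unwrap();
            }
        }
    }

    // Dependencies may have changed, or there was no memoized value: execute the query.
    let (new_value, deps, durability) = db.execute_query();

    // Backdate: if the new value equals the old one, keep the old `changed_at` revision.
    let changed_at = match &*memo {
        Some(old) if old.value.as_ref() == Some(&new_value) => old.changed_at,
        _ => now,
    };

    *memo = Some(Memo {
        value: Some(new_value.clone()),
        verified_at: now,
        changed_at,
        durability,
        deps,
    });
    new_value
}
```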
+ +[durability]: ./terminology/durability.md +[backdate]: ./terminology/backdate.md +[dependency]: ./terminology/dependency.md +[dependencies]: ./terminology/dependency.md +[memo]: ./terminology/memo.md +[revision]: ./terminology/revision.md +[verified]: ./terminology/verified.md \ No newline at end of file diff --git a/book/src/plumbing/generated_code.md b/book/src/plumbing/generated_code.md new file mode 100644 index 00000000..ae7742f5 --- /dev/null +++ b/book/src/plumbing/generated_code.md @@ -0,0 +1,28 @@ +# Generated code + +This page walks through the ["Hello, World!"] example and explains the code that +it generates. Please take it with a grain of salt: while we make an effort to +keep this documentation up to date, this sort of thing can fall out of date +easily. See the page history below for major updates. + +["Hello, World!"]: https://github.com/salsa-rs/salsa/blob/master/examples/hello_world/main.rs + +If you'd like to see for yourself, you can set the environment variable +`SALSA_DUMP` to 1 while the procedural macro runs, and it will dump the full +output to stdout. I recommend piping the output through rustfmt. + +## Sources + +The main parts of the source that we are focused on are as follows. + +### Query group + +```rust,ignore +{{#include ../../../examples/hello_world/main.rs:trait}} +``` + +### Database + +```rust,ignore +{{#include ../../../examples/hello_world/main.rs:database}} +``` diff --git a/book/src/plumbing/maybe_changed_after.md b/book/src/plumbing/maybe_changed_after.md new file mode 100644 index 00000000..754ebd26 --- /dev/null +++ b/book/src/plumbing/maybe_changed_after.md @@ -0,0 +1,40 @@ +# Maybe changed after + +```rust,no_run,noplayground +{{#include ../../../src/plumbing.rs:maybe_changed_after}} +``` + +The `maybe_changed_after` operation computes whether a query's value *may have changed* **after** the given revision. In other words, `Q.maybe_changed_after(R)` is true if the value of the query `Q` may have changed in the revisions `(R+1)..R_now`, where `R_now` is the current revision. Note that it doesn't make sense to ask `maybe_changed_after(R_now)`. + +## Input queries + +Input queries are set explicitly by the user. `maybe_changed_after` can therefore just check when the value was last set and compare. + +## Interned queries + +## Derived queries + +The logic for derived queries is more complex. We summarize the high-level ideas here, but you may find the [flowchart](./derived_flowchart.md) useful to dig deeper. The [terminology](./terminology.md) section may also be useful; in some cases, we link to that section on the first usage of a word. + +* If an existing [memo] is found, then we check if the memo was [verified] in the current [revision]. If so, we can compare its [changed at] revision and return true or false appropriately. +* Otherwise, we must check whether [dependencies] have been modified: + * Let R be the revision in which the memo was last verified; we wish to know if any of the dependencies have changed since revision R. + * First, we check the [durability]. For each memo, we track the minimum durability of the memo's dependencies. If the memo has durability D, and there have been no changes to an input with durability D since the last time the memo was verified, then we can consider the memo verified without any further work. + * If the durability check is not sufficient, then we must check the dependencies individually. 
For this, we iterate over each dependency D and invoke the [maybe changed after](./maybe_changed_after.md) operation to check whether D has changed since the revision R. + * If no dependency was modified: + * We can mark the memo as verified and use its [changed at] revision to return true or false. +* Assuming dependencies have been modified: + * Then we execute the user's query function (same as in [fetch]), which potentially [backdates] the resulting value. + * Compare the [changed at] revision in the resulting memo and return true or false. + +[changed at]: ./terminology/changed_at.md +[durability]: ./terminology/durability.md +[backdate]: ./terminology/backdate.md +[backdates]: ./terminology/backdate.md +[dependency]: ./terminology/dependency.md +[dependencies]: ./terminology/dependency.md +[memo]: ./terminology/memo.md +[revision]: ./terminology/revision.md +[verified]: ./terminology/verified.md +[fetch]: ./fetch.md +[LRU]: ./terminology/LRU.md \ No newline at end of file diff --git a/book/src/plumbing/query_ops.md b/book/src/plumbing/query_ops.md new file mode 100644 index 00000000..425209ad --- /dev/null +++ b/book/src/plumbing/query_ops.md @@ -0,0 +1,14 @@ +# Query operations + +Each of the query storage structs implements the `QueryStorageOps` trait found in the [`plumbing`] module: + +```rust,no_run,noplayground +{{#include ../../../src/plumbing.rs:QueryStorageOps}} +``` + + which defines the basic operations that all queries support. The most important are these two: + +* [maybe changed after](./maybe_changed_after.md): Returns true if the value of the query (for the given key) may have changed since the given revision. +* [Fetch](./fetch.md): Returns the up-to-date value for the given K (or an error in the case of an "unrecovered" cycle). + +[`plumbing`]: https://github.com/salsa-rs/salsa/blob/master/src/plumbing.rs \ No newline at end of file diff --git a/book/src/plumbing/salsa_crate.md b/book/src/plumbing/salsa_crate.md new file mode 100644 index 00000000..373a098e --- /dev/null +++ b/book/src/plumbing/salsa_crate.md @@ -0,0 +1,45 @@ +# Runtime + +This section documents the contents of the salsa crate. The salsa crate contains code that interacts with the [generated code] to create the complete "salsa experience". + +[generated code]: ./generated_code.md + +## Major types + +The crate has a few major types. + +### The [`salsa::Storage`] struct + +The [`salsa::Storage`] struct is what users embed into their database. It consists of two main parts: + +* The "query store", which is the [generated storage struct](./database.md#the-database-storage-struct). +* The [`salsa::Runtime`]. + +### The [`salsa::Runtime`] struct + +The [`salsa::Runtime`] struct stores the data that is used to track which queries are being executed and to coordinate between them. The `Runtime` is embedded within the [`salsa::Storage`] struct. + +**Important**. The `Runtime` does **not** store the actual data from the queries; they live alongside it in the [`salsa::Storage`] struct. This ensures that the type of `Runtime` is not generic, which is needed to ensure dyn safety. + +#### Threading + +There is one [`salsa::Runtime`] for each active thread, and each of them has a unique [`RuntimeId`]. The `Runtime` state itself is divided into: + +* `SharedState`, accessible from all runtimes; +* `LocalState`, accessible only from this runtime. 
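For orientation, the user-facing side of this arrangement is the usual salsa setup: the database struct embeds a `salsa::Storage<Self>`, which in turn carries the `Runtime` described above. The sketch below follows the Hello World example's names (`HelloWorldStorage`, `DatabaseStruct`); substitute your own query group.

```rust
// A database struct embedding `salsa::Storage` (names follow the Hello World example).
#[salsa::database(HelloWorldStorage)]
#[derive(Default)]
struct DatabaseStruct {
    storage: salsa::Storage<Self>,
}

impl salsa::Database for DatabaseStruct {}
```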
+ +[`salsa::Runtime`]: https://docs.rs/salsa/latest/salsa/struct.Runtime.html +[`salsa::Storage`]: https://docs.rs/salsa/latest/salsa/struct.Storage.html +[`RuntimeId`]: https://docs.rs/salsa/0.16.1/salsa/struct.RuntimeId.html + +### Query storage implementations and support code + +For each kind of query (input, derived, interned, etc) there is a corresponding "storage struct" that contains the code to implement it. For example, derived queries are implemented by the `DerivedStorage` struct found in the [`salsa::derived`] module. + +[`salsa::derived`]: https://github.com/salsa-rs/salsa/blob/master/src/derived.rs + +Storage structs like `DerivedStorage` are generic over a query type `Q`, which corresponds to the [query structs] in the generated code. The query structs implement the `Query` trait which gives basic info such as the key and value type of the query and its ability to recover from cycles. In some cases, the `Q` type is expected to implement additional traits: derived queries, for example, implement `QueryFunction`, which defines the code that will execute when the query is called. + +[query structs]: ./query_groups.md#for-each-query-a-query-struct + +The storage structs, in turn, implement key traits from the plumbing module. The most notable is the `QueryStorageOps`, which defines the [basic operations that can be done on a query](./query_ops.md). diff --git a/book/src/plumbing/terminology.md b/book/src/plumbing/terminology.md new file mode 100644 index 00000000..e2e2597f --- /dev/null +++ b/book/src/plumbing/terminology.md @@ -0,0 +1 @@ +# Terminology diff --git a/book/src/plumbing/terminology/LRU.md b/book/src/plumbing/terminology/LRU.md new file mode 100644 index 00000000..5e6dd319 --- /dev/null +++ b/book/src/plumbing/terminology/LRU.md @@ -0,0 +1,6 @@ +# LRU + +the [`set_lru_capacity`](https://docs.rs/salsa/0.16.1/salsa/struct.QueryTableMut.html#method.set_lru_capacity) method can be used to fix the maximum capacity for a query at a specific number of values. If more values are added after that point, then salsa will drop the values from older [memos] to conserve memory (we always retain the [dependency] information for those memos, however, so that we can still compute whether values may have changed, even if we don't know what that value is). The LRU mechanism was introduced in [RFC #4](../../rfcs/RFC0004-LRU.md). + +[memos]: ./memo.md +[dependency]: ./dependency.md \ No newline at end of file diff --git a/book/src/plumbing/terminology/backdate.md b/book/src/plumbing/terminology/backdate.md new file mode 100644 index 00000000..5d1a00c0 --- /dev/null +++ b/book/src/plumbing/terminology/backdate.md @@ -0,0 +1,7 @@ +# Backdate + +*Backdating* is when we mark a value that was computed in revision R as having last changed in some earlier revision. This is done when we have an older [memo] M and we can compare the two values to see that, while the [dependencies] to M may have changed, the result of the [query function] did not. + +[memo]: ./memo.md +[dependencies]: ./dependency.md +[query function]: ./query_function.md \ No newline at end of file diff --git a/book/src/plumbing/terminology/changed_at.md b/book/src/plumbing/terminology/changed_at.md new file mode 100644 index 00000000..e73fa626 --- /dev/null +++ b/book/src/plumbing/terminology/changed_at.md @@ -0,0 +1,8 @@ +# Changed at + +The *changed at* revision for a [memo] is the [revision] in which that memo's value last changed. 
Typically, this is the same as the revision in which the [query function] was last executed, but it may be an earlier revision if the memo was [backdated]. + +[query function]: ./query_function.md +[backdated]: ./backdate.md +[revision]: ./revision.md +[memo]: ./memo.md \ No newline at end of file diff --git a/book/src/plumbing/terminology/dependency.md b/book/src/plumbing/terminology/dependency.md new file mode 100644 index 00000000..2c921410 --- /dev/null +++ b/book/src/plumbing/terminology/dependency.md @@ -0,0 +1,6 @@ +# Dependency + +A *dependency* of a [query] Q is some other query Q1 that was invoked as part of computing the value for Q (typically, invoked by Q's [query function]). + +[query]: ./query.md +[query function]: ./query_function.md \ No newline at end of file diff --git a/book/src/plumbing/terminology/derived_query.md b/book/src/plumbing/terminology/derived_query.md new file mode 100644 index 00000000..158801f6 --- /dev/null +++ b/book/src/plumbing/terminology/derived_query.md @@ -0,0 +1,7 @@ +# Derived query + +A *derived query* is a [query] whose value is defined by the result of a user-provided [query function]. That function is executed to get the result of the query. Unlike [input queries], the result of a derived query can always be recomputed whenever needed simply by re-executing the function. + +[query]: ./query.md +[query function]: ./query_function.md +[input queries]: ./input_query.md \ No newline at end of file diff --git a/book/src/plumbing/terminology/durability.md b/book/src/plumbing/terminology/durability.md new file mode 100644 index 00000000..4527c69a --- /dev/null +++ b/book/src/plumbing/terminology/durability.md @@ -0,0 +1,6 @@ +# Durability + +*Durability* is an optimization that we use to avoid checking the [dependencies] of a [query] individually. It was introduced in [RFC #5](../../rfcs/RFC0005-Durability.md). + +[dependencies]: ./dependency.md +[query]: ./query.md \ No newline at end of file diff --git a/book/src/plumbing/terminology/input_query.md b/book/src/plumbing/terminology/input_query.md new file mode 100644 index 00000000..27665e24 --- /dev/null +++ b/book/src/plumbing/terminology/input_query.md @@ -0,0 +1,6 @@ +# Input query + +An *input query* is a [query] whose value is explicitly set by the user. When that value is set, a [durability] can also be provided. + +[query]: ./query.md +[durability]: ./durability.md \ No newline at end of file diff --git a/book/src/plumbing/terminology/memo.md b/book/src/plumbing/terminology/memo.md new file mode 100644 index 00000000..dce9d917 --- /dev/null +++ b/book/src/plumbing/terminology/memo.md @@ -0,0 +1,21 @@ +# Memo + +A *memo* stores information about the last time that a [query function] for some [query] Q was executed: + +* Typically, it contains the value that was returned from that function, so that we don't have to execute it again. + * However, this is not always true: some queries don't cache their result values, and values can also be dropped as a result of [LRU] collection. In those cases, the memo just stores [dependency] information, which can still be useful to determine if other queries that have Q as a [dependency] may have changed. +* The revision in which the memo was last [verified]. +* The [changed at] revision in which the memo's value last changed. (Note that it may be [backdated].) +* The minimum durability of the memo's [dependencies]. +* The complete set of [dependencies], if available, or a marker that the memo has an [untracked dependency]. 
+ +[revision]: ./revision.md +[backdated]: ./backdate.md +[dependencies]: ./dependency.md +[dependency]: ./dependency.md +[untracked dependency]: ./untracked.md +[verified]: ./verified.md +[query]: ./query.md +[query function]: ./query_function.md +[changed at]: ./changed_at.md +[LRU]: ./LRU.md \ No newline at end of file diff --git a/book/src/plumbing/terminology/query.md b/book/src/plumbing/terminology/query.md new file mode 100644 index 00000000..f0c52077 --- /dev/null +++ b/book/src/plumbing/terminology/query.md @@ -0,0 +1 @@ +# Query diff --git a/book/src/plumbing/terminology/query_function.md b/book/src/plumbing/terminology/query_function.md new file mode 100644 index 00000000..d320b279 --- /dev/null +++ b/book/src/plumbing/terminology/query_function.md @@ -0,0 +1,7 @@ +# Query function + +The *query function* is the user-provided function that we execute to compute the value of a [derived query]. Salsa assumes that all query functions are 'pure' functions of their [dependencies] unless the user reports an [untracked read]. Salsa always assumes that functions have no important side-effects (i.e., that they don't send messages over the network whose results you wish to observe) and thus that it doesn't have to re-execute functions unless it needs their return value. + +[derived query]: ./derived_query.md +[dependencies]: ./dependency.md +[untracked read]: ./untracked.md \ No newline at end of file diff --git a/book/src/plumbing/terminology/revision.md b/book/src/plumbing/terminology/revision.md new file mode 100644 index 00000000..dc922271 --- /dev/null +++ b/book/src/plumbing/terminology/revision.md @@ -0,0 +1,5 @@ +# Revision + +A *revision* is a monotonically increasing integer that we use to track the "version" of the database. Each time the value of an [input query] is modified, we create a new revision. + +[input query]: ./input_query.md \ No newline at end of file diff --git a/book/src/plumbing/terminology/untracked.md b/book/src/plumbing/terminology/untracked.md new file mode 100644 index 00000000..15da6ae2 --- /dev/null +++ b/book/src/plumbing/terminology/untracked.md @@ -0,0 +1,7 @@ +# Untracked dependency + +An *untracked dependency* is an indication that the result of a [derived query] depends on something not visible to the salsa database. Untracked dependencies are created by invoking [`report_untracked_read`](https://docs.rs/salsa/0.16.1/salsa/struct.Runtime.html#method.report_untracked_read) or [`report_synthetic_read`](https://docs.rs/salsa/0.16.1/salsa/struct.Runtime.html#method.report_synthetic_read). When an untracked dependency is present, [derived queries] are always re-executed if the durability check fails (see the description of the [fetch operation] for more details). + +[derived query]: ./derived_query.md +[derived queries]: ./derived_query.md +[fetch operation]: ../fetch.md#derived-queries diff --git a/book/src/plumbing/terminology/verified.md b/book/src/plumbing/terminology/verified.md new file mode 100644 index 00000000..bd6103b2 --- /dev/null +++ b/book/src/plumbing/terminology/verified.md @@ -0,0 +1,8 @@ +# Verified + +A [memo] is *verified* in a revision R if we have checked that its value is still up-to-date (i.e., if we were to reexecute the [query function], we are guaranteed to get the same result). Each memo tracks the revision in which it was last verified to avoid repeatedly checking whether dependencies have changed during the [fetch] and [maybe changed after] operations. 
+ +[query function]: ./query_function.md +[fetch]: ../fetch.md +[maybe changed after]: ../maybe_changed_after.md +[memo]: ./memo.md \ No newline at end of file diff --git a/book/src/rfcs/RFC0001-Query-Group-Traits.md b/book/src/rfcs/RFC0001-Query-Group-Traits.md index c28fac28..6db82eee 100644 --- a/book/src/rfcs/RFC0001-Query-Group-Traits.md +++ b/book/src/rfcs/RFC0001-Query-Group-Traits.md @@ -268,9 +268,9 @@ impl of `QueryGroup`. That involves generating the following things: but since we cannot, it is entitled `MyGroupGroupKey` - It is an enum which contains one variant per query with the value being the key: - `my_query(>::Key)` - - The group key enum offers a public, inherent method `maybe_changed_since`: - - `fn maybe_changed_since(db: &DB, db_descriptor: &DB::DatabaseKey, revision: Revision)` - - it is invoked when implementing `maybe_changed_since` for the database key + - The group key enum offers a public, inherent method `maybe_changed_after`: + - `fn maybe_changed_after(db: &DB, db_descriptor: &DB::DatabaseKey, revision: Revision)` + - it is invoked when implementing `maybe_changed_after` for the database key ### Lowering database storage @@ -298,7 +298,7 @@ It generates the following things: - This contains a `for_each_query` method, which is implemented by invoking, in turn, the inherent methods defined on each query group storage struct. - impl of `plumbing::DatabaseKey` for the database key enum - - This contains a method `maybe_changed_since`. We implement this by + - This contains a method `maybe_changed_after`. We implement this by matching to get a particular group key, and then invoking the inherent method on the group key struct. diff --git a/book/src/rfcs/RFC0006-Dynamic-Databases.md b/book/src/rfcs/RFC0006-Dynamic-Databases.md index 9b1fc2b5..2a3202d6 100644 --- a/book/src/rfcs/RFC0006-Dynamic-Databases.md +++ b/book/src/rfcs/RFC0006-Dynamic-Databases.md @@ -343,7 +343,7 @@ some type for dependencies that is independent of the dtabase type `DB`. There are a number of methods that can be dispatched through the database interface on a `DatabaseKeyIndex`. For example, we already mentioned `fmt_debug`, which emits a debug representation of the key, but there is also -`maybe_changed_since`, which checks whether the value for a given key may have +`maybe_changed_after`, which checks whether the value for a given key may have changed since the given revision. Each of these methods is a member of the `DatabaseOps` trait and they are dispatched as follows. diff --git a/book/src/rfcs/RFC0009-Cycle-recovery.md b/book/src/rfcs/RFC0009-Cycle-recovery.md new file mode 100644 index 00000000..d0fe7b73 --- /dev/null +++ b/book/src/rfcs/RFC0009-Cycle-recovery.md @@ -0,0 +1,157 @@ +# Description/title + +## Metadata + +* Author: nikomatsakis +* Date: 2021-10-31 +* Introduced in: https://github.com/salsa-rs/salsa/pull/285 + +## Summary + +* Permit cycle recovery as long as at least one participant has recovery enabled. +* Modify cycle recovery to take a `&Cycle`. +* Introduce `Cycle` type that carries information about a cycle and lists participants in a deterministic order. + +[RFC 7]: ./RFC0007-Opinionated-Cancelation.md + +## Motivation + +Cycle recovery has been found to have some subtle bugs that could lead to panics. Furthermore, the existing cycle recovery APIs require all participants in a cycle to have recovery enabled and give limited and non-deterministic information. This RFC tweaks the user exposed APIs to correct these shortcomings. 
It also describes a major overhaul of how cycles are handled internally. + +## User's guide + +By default, cycles in the computation graph are considered a "programmer bug" and result in a panic. Sometimes, though, cycles are outside of the programmer's control. Salsa provides mechanisms to recover from cycles that can help in those cases. + +### Default cycle handling: panic + +By default, when Salsa detects a cycle in the computation graph, Salsa will panic with a `salsa::Cycle` as the panic value. Your queries should not attempt to catch this value; rather, the `salsa::Cycle` is meant to be caught by the outermost thread, which can print out information from it to diagnose what went wrong. The `Cycle` type offers a few methods for inspecting the participants in the cycle: + +* `participant_keys` -- returns an iterator over the `DatabaseKeyIndex` for each participant in the cycle. +* `all_participants` -- returns an iterator over `String` values for each participant in the cycle (debug output). +* `unexpected_participants` -- returns an iterator over `String` values for each participant in the cycle that doesn't have recovery information (see next section). + +`Cycle` implements `Debug`, but because the standard trait doesn't provide access to the database, the output can be kind of inscrutable. To get more readable `Debug` values, use the method `cycle.debug(db)`, which returns an `impl Debug` that is more readable. + +### Cycle recovery + +Panicking when a cycle occurs is ok for situations where you believe a cycle is impossible. But sometimes cycles can result from illegal user input and cannot be statically prevented. In these cases, you might prefer to gracefully recover from a cycle rather than panicking the entire query. Salsa supports that with the idea of *cycle recovery*. + +To use cycle recovery, you annotate potential participants in the cycle with a `#[salsa::recover(my_recover_fn)]` attribute. When a cycle occurs, if any participant P has recovery information, then no panic occurs. Instead, the execution of P is aborted and P will execute the recovery function to generate its result. Participants in the cycle that do not have recovery information continue executing as normal, using this recovery result. + +The recovery function has a similar signature to a query function. It is given a reference to your database along with a `salsa::Cycle` describing the cycle that occurred; it returns the result of the query. Example: + +```rust +fn my_recover_fn( + db: &dyn MyDatabase, + cycle: &salsa::Cycle, +) -> MyResultValue +``` + +The `db` and `cycle` argument can be used to prepare a useful error message for your users. + +**Important:** Although the recovery function is given a `db` handle, you should be careful to avoid creating a cycle from within recovery or invoking queries that may be participating in the current cycle. Attempting to do so can result in inconsistent results. + +### Figuring out why recovery did not work + +If a cycle occurs and *some* of the participant queries have `#[salsa::recover]` annotations and others do not, then the query will be treated as irrecoverable and will simply panic. You can use the `Cycle::unexpected_participants` method to figure out why recovery did not succeed and add the appropriate `#[salsa::recover]` annotations. + +## Reference guide + +This RFC accompanies a rather long and complex PR with a number of changes to the implementation. We summarize the most important points here. 
+# Cycles + +## Cross-thread blocking + +The interface for blocking across threads now works as follows: + +* When one thread `T1` wishes to block on a query `Q` being executed by another thread `T2`, it invokes `Runtime::try_block_on`. This will check for cycles. Assuming no cycle is detected, it will block `T1` until `T2` has completed with `Q`. At that point, `T1` reawakens. However, we don't know the result of executing `Q`, so `T1` now has to "retry". Typically, this will result in successfully reading the cached value. +* While `T1` is blocking, the runtime moves its query stack (a `Vec`) into the shared dependency graph data structure. When `T1` reawakens, it recovers ownership of its query stack before returning from `try_block_on`. + +## Cycle detection + +When a thread `T1` attempts to execute a query `Q`, it will try to load the value for `Q` from the memoization tables. If it finds an `InProgress` marker, that indicates that `Q` is currently being computed. This indicates a potential cycle. `T1` will then try to block on the query `Q`: + +* If `Q` is also being computed by `T1`, then there is a cycle. +* Otherwise, if `Q` is being computed by some other thread `T2`, we have to check whether `T2` is (transitively) blocked on `T1`. If so, there is a cycle. + +These two cases are handled internally by the `Runtime::try_block_on` function. Detecting the intra-thread cycle case is easy; to detect cross-thread cycles, the runtime maintains a dependency DAG between threads (identified by `RuntimeId`). Before adding an edge `T1 -> T2` (i.e., `T1` is blocked waiting for `T2`) into the DAG, it checks whether a path exists from `T2` to `T1`. If so, we have a cycle and the edge cannot be added (then the DAG would not longer be acyclic). + +When a cycle is detected, the current thread `T1` has full access to the query stacks that are participating in the cycle. Consider: naturally, `T1` has access to its own stack. There is also a path `T2 -> ... -> Tn -> T1` of blocked threads. Each of the blocked threads `T2 ..= Tn` will have moved their query stacks into the dependency graph, so those query stacks are available for inspection. + +Using the available stacks, we can create a list of cycle participants `Q0 ... Qn` and store that into a `Cycle` struct. If none of the participants `Q0 ... Qn` have cycle recovery enabled, we panic with the `Cycle` struct, which will trigger all the queries on this thread to panic. + +## Cycle recovery via fallback + +If any of the cycle participants `Q0 ... Qn` has cycle recovery set, we recover from the cycle. To help explain how this works, we will use this example cycle which contains three threads. Beginning with the current query, the cycle participants are `QA3`, `QB2`, `QB3`, `QC2`, `QC3`, and `QA2`. + +``` + The cyclic + edge we have + failed to add. + : + A : B C + : + QA1 v QB1 QC1 +┌► QA2 ┌──► QB2 ┌─► QC2 +│ QA3 ───┘ QB3 ──┘ QC3 ───┐ +│ │ +└───────────────────────────────┘ +``` + +Recovery works in phases: + +* **Analyze:** As we enumerate the query participants, we collect their collective inputs (all queries invoked so far by any cycle participant) and the max changed-at and min duration. We then remove the cycle participants themselves from this list of inputs, leaving only the queries external to the cycle. 
+* **Mark**: For each query Q that is annotated with `#[salsa::recover]`, we mark it and all of its successors on the same thread by setting its `cycle` flag to the `c: Cycle` we constructed earlier; we also reset its inputs to the collective inputs gathering during analysis. If those queries resume execution later, those marks will trigger them to immediately unwind and use cycle recovery, and the inputs will be used as the inputs to the recovery value. + * Note that we mark *all* the successors of Q on the same thread, whether or not they have recovery set. We'll discuss later how this is important in the case where the active thread (A, here) doesn't have any recovery set. +* **Unblock**: Each blocked thread T that has a recovering query is forcibly reawoken; the outgoing edge from that thread to its successor in the cycle is removed. Its condvar is signalled with a `WaitResult::Cycle(c)`. When the thread reawakens, it will see that and start unwinding with the cycle `c`. +* **Handle the current thread:** Finally, we have to choose how to have the current thread proceed. If the current thread includes any cycle with recovery information, then we can begin unwinding. Otherwise, the current thread simply continues as if there had been no cycle, and so the cyclic edge is added to the graph and the current thread blocks. This is possible because some other thread had recovery information and therefore has been awoken. + +Let's walk through the process with a few examples. + +### Example 1: Recovery on the detecting thread + +Consider the case where only the query QA2 has recovery set. It and QA3 will be marked with their `cycle` flag set to `c: Cycle`. Threads B and C will not be unblocked, as they do not have any cycle recovery nodes. The current thread (Thread A) will initiate unwinding with the cycle `c` as the value. Unwinding will pass through QA3 and be caught by QA2. QA2 will substitute the recovery value and return normally. QA1 and QC3 will then complete normally and so forth, on up until all queries have completed. + +### Example 2: Recovery in two queries on the detecting thread + +Consider the case where both query QA2 and QA3 have recovery set. It proceeds the same Example 1 until the the current initiates unwinding, as described in Example 1. When QA3 receives the cycle, it stores its recovery value and completes normally. QA2 then adds QA3 as an input dependency: at that point, QA2 observes that it too has the cycle mark set, and so it initiates unwinding. The rest of QA2 therefore never executes. This unwinding is caught by QA2's entry point and it stores the recovery value and returns normally. QA1 and QC3 then continue normally, as they have not had their `cycle` flag set. + +### Example 3: Recovery on another thread + +Now consider the case where only the query QB2 has recovery set. It and QB3 will be marked with the cycle `c: Cycle` and thread B will be unblocked; the edge `QB3 -> QC2` will be removed from the dependency graph. Thread A will then add an edge `QA3 -> QB2` and block on thread B. At that point, thread A releases the lock on the dependency graph, and so thread B is re-awoken. It observes the `WaitResult::Cycle` and initiates unwinding. Unwinding proceeds through QB3 and into QB2, which recovers. QB1 is then able to execute normally, as is QA3, and execution proceeds from there. + +### Example 4: Recovery on all queries + +Now consider the case where all the queries have recovery set. 
In that case, they are all marked with the cycle, and all the cross-thread edges are removed from the graph. Each thread will independently awaken and initiate unwinding. Each query will recover. + +## Frequently asked questions + +### Why have other threads retry instead of giving them the value? + +In the past, when one thread T1 blocked on some query Q being executed by another thread T2, we would create a custom channel between the threads. T2 would then send the result of Q directly to T1, and T1 had no need to retry. This mechanism was simplified in this RFC because we don't always have a value available: sometimes the cycle results when T2 is just verifying whether a memoized value is still valid. In that case, the value may not have been computed, and so when T1 retries it will in fact go on to compute the value. (Previously, this case was overlooked by the cycle handling logic and resulted in a panic.) + +### Why do we use unwinding to manage cycle recovery? + +When a query Q participates in cycle recovery, we use unwinding to get from the point where the cycle is detected back to the query's execution function. This ensures that the rest of Q never runs. This is important because Q might otherwise go on to create new cycles even while recovery is proceeding. Consider an example like: + +```rust +#[salsa::recovery] +fn query_q1(db: &dyn Database) { + db.query_q2() + db.query_q3() // <-- this never runs, thanks to unwinding +} + +#[salsa::recovery] +fn query_q2(db: &dyn Database) { + db.query_q1() +} + +#[salsa::recovery] +fn query_q3(db: &dyn Database) { + db.query_q1() +} +``` + +### Why not invoke the recovery functions all at once? + +The code currently unwinds frame by frame and invokes recovery as it goes. Another option might be to invoke the recovery function for all participants in the cycle up-front. This would be fine, but it's a bit difficult to do, since the types for each cycle are different, and the `Runtime` code doesn't know what they are. We also don't have access to the memoization tables and so forth. \ No newline at end of file diff --git a/components/salsa-macros/src/database_storage.rs b/components/salsa-macros/src/database_storage.rs index e2909cb3..c7620a17 100644 --- a/components/salsa-macros/src/database_storage.rs +++ b/components/salsa-macros/src/database_storage.rs @@ -106,6 +106,7 @@ pub(crate) fn database(args: TokenStream, input: TokenStream) -> TokenStream { // ANCHOR:DatabaseOps let mut fmt_ops = proc_macro2::TokenStream::new(); let mut maybe_changed_ops = proc_macro2::TokenStream::new(); + let mut cycle_recovery_strategy_ops = proc_macro2::TokenStream::new(); let mut for_each_ops = proc_macro2::TokenStream::new(); for ((QueryGroup { group_path }, group_storage), group_index) in query_groups .iter() @@ -123,7 +124,14 @@ pub(crate) fn database(args: TokenStream, input: TokenStream) -> TokenStream { #group_index => { let storage: &#group_storage = >::group_storage(self); - storage.maybe_changed_since(self, input, revision) + storage.maybe_changed_after(self, input, revision) + } + }); + cycle_recovery_strategy_ops.extend(quote! { + #group_index => { + let storage: &#group_storage = + >::group_storage(self); + storage.cycle_recovery_strategy(self, input) } }); for_each_ops.extend(quote! 
{ @@ -157,7 +165,7 @@ pub(crate) fn database(args: TokenStream, input: TokenStream) -> TokenStream { } } - fn maybe_changed_since( + fn maybe_changed_after( &self, input: salsa::DatabaseKeyIndex, revision: salsa::Revision @@ -168,6 +176,16 @@ pub(crate) fn database(args: TokenStream, input: TokenStream) -> TokenStream { } } + fn cycle_recovery_strategy( + &self, + input: salsa::DatabaseKeyIndex, + ) -> salsa::plumbing::CycleRecoveryStrategy { + match input.group_index() { + #cycle_recovery_strategy_ops + i => panic!("salsa: invalid group index {}", i) + } + } + fn for_each_query( &self, mut op: &mut dyn FnMut(&dyn salsa::plumbing::QueryStorageMassOps), diff --git a/components/salsa-macros/src/query_group.rs b/components/salsa-macros/src/query_group.rs index 42c729d8..d8c5e4d5 100644 --- a/components/salsa-macros/src/query_group.rs +++ b/components/salsa-macros/src/query_group.rs @@ -483,17 +483,22 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream let recover = if let Some(cycle_recovery_fn) = &query.cycle { quote! { - fn recover(db: &>::DynDb, cycle: &[salsa::DatabaseKeyIndex], #key_pattern: &::Key) - -> Option<::Value> { - Some(#cycle_recovery_fn( + const CYCLE_STRATEGY: salsa::plumbing::CycleRecoveryStrategy = + salsa::plumbing::CycleRecoveryStrategy::Fallback; + fn cycle_fallback(db: &>::DynDb, cycle: &salsa::Cycle, #key_pattern: &::Key) + -> ::Value { + #cycle_recovery_fn( db, - &cycle.iter().map(|k| format!("{:?}", k.debug(db))).collect::>(), + cycle, #(#key_names),* - )) + ) } } } else { - quote! {} + quote! { + const CYCLE_STRATEGY: salsa::plumbing::CycleRecoveryStrategy = + salsa::plumbing::CycleRecoveryStrategy::Panic; + } }; output.extend(quote_spanned! {span=> @@ -527,13 +532,24 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream for (Query { fn_name, .. }, query_index) in non_transparent_queries().zip(0_u16..) { maybe_changed_ops.extend(quote! { #query_index => { - salsa::plumbing::QueryStorageOps::maybe_changed_since( + salsa::plumbing::QueryStorageOps::maybe_changed_after( &*self.#fn_name, db, input, revision ) } }); } + let mut cycle_recovery_strategy_ops = proc_macro2::TokenStream::new(); + for (Query { fn_name, .. }, query_index) in non_transparent_queries().zip(0_u16..) { + cycle_recovery_strategy_ops.extend(quote! { + #query_index => { + salsa::plumbing::QueryStorageOps::cycle_recovery_strategy( + &*self.#fn_name + ) + } + }); + } + let mut for_each_ops = proc_macro2::TokenStream::new(); for Query { fn_name, .. } in non_transparent_queries() { for_each_ops.extend(quote! 
{ @@ -574,7 +590,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream } } - #trait_vis fn maybe_changed_since( + #trait_vis fn maybe_changed_after( &self, db: &(#dyn_db + '_), input: salsa::DatabaseKeyIndex, @@ -586,6 +602,17 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream } } + #trait_vis fn cycle_recovery_strategy( + &self, + db: &(#dyn_db + '_), + input: salsa::DatabaseKeyIndex, + ) -> salsa::plumbing::CycleRecoveryStrategy { + match input.query_index() { + #cycle_recovery_strategy_ops + i => panic!("salsa: impossible query index {}", i), + } + } + #trait_vis fn for_each_query( &self, _runtime: &salsa::Runtime, diff --git a/src/blocking_future.rs b/src/blocking_future.rs deleted file mode 100644 index bbdde14e..00000000 --- a/src/blocking_future.rs +++ /dev/null @@ -1,84 +0,0 @@ -use parking_lot::{Condvar, Mutex}; -use std::mem; -use std::sync::Arc; - -pub(crate) struct BlockingFuture { - slot: Arc>, -} - -pub(crate) struct Promise { - fulfilled: bool, - slot: Arc>, -} - -impl BlockingFuture { - pub(crate) fn new() -> (BlockingFuture, Promise) { - let future = BlockingFuture { - slot: Default::default(), - }; - let promise = Promise { - fulfilled: false, - slot: Arc::clone(&future.slot), - }; - (future, promise) - } - - pub(crate) fn wait(self) -> Option { - let mut guard = self.slot.lock.lock(); - if guard.is_empty() { - // parking_lot guarantees absence of spurious wake ups - self.slot.cvar.wait(&mut guard); - } - match mem::replace(&mut *guard, State::Dead) { - State::Empty => unreachable!(), - State::Full(it) => Some(it), - State::Dead => None, - } - } -} - -impl Promise { - pub(crate) fn fulfil(mut self, value: T) { - self.fulfilled = true; - self.transition(State::Full(value)); - } - fn transition(&mut self, new_state: State) { - let mut guard = self.slot.lock.lock(); - *guard = new_state; - self.slot.cvar.notify_one(); - } -} - -impl Drop for Promise { - fn drop(&mut self) { - if !self.fulfilled { - self.transition(State::Dead); - } - } -} - -struct Slot { - lock: Mutex>, - cvar: Condvar, -} - -impl Default for Slot { - fn default() -> Slot { - Slot { - lock: Mutex::new(State::Empty), - cvar: Condvar::new(), - } - } -} - -enum State { - Empty, - Full(T), - Dead, -} - -impl State { - fn is_empty(&self) -> bool { - matches!(self, State::Empty) - } -} diff --git a/src/derived.rs b/src/derived.rs index 3d144fad..d06ab35d 100644 --- a/src/derived.rs +++ b/src/derived.rs @@ -7,7 +7,7 @@ use crate::plumbing::QueryFunction; use crate::plumbing::QueryStorageMassOps; use crate::plumbing::QueryStorageOps; use crate::runtime::{FxIndexMap, StampedValue}; -use crate::{CycleError, Database, DatabaseKeyIndex, QueryDb, Revision}; +use crate::{Database, DatabaseKeyIndex, QueryDb, Revision}; use parking_lot::RwLock; use std::borrow::Borrow; use std::convert::TryFrom; @@ -117,6 +117,8 @@ where Q: QueryFunction, MP: MemoizationPolicy, { + const CYCLE_STRATEGY: crate::plumbing::CycleRecoveryStrategy = Q::CYCLE_STRATEGY; + fn new(group_index: u16) -> Self { DerivedStorage { group_index, @@ -139,7 +141,7 @@ where write!(fmt, "{}({:?})", Q::QUERY_NAME, key) } - fn maybe_changed_since( + fn maybe_changed_after( &self, db: &>::DynDb, input: DatabaseKeyIndex, @@ -147,6 +149,7 @@ where ) -> bool { assert_eq!(input.group_index, self.group_index); assert_eq!(input.query_index, Q::QUERY_INDEX); + debug_assert!(revision < db.salsa_runtime().current_revision()); let slot = self .slot_map .read() @@ -154,14 +157,10 @@ where .unwrap() .1 .clone(); - 
slot.maybe_changed_since(db, revision) + slot.maybe_changed_after(db, revision) } - fn try_fetch( - &self, - db: &>::DynDb, - key: &Q::Key, - ) -> Result> { + fn fetch(&self, db: &>::DynDb, key: &Q::Key) -> Q::Value { db.unwind_if_cancelled(); let slot = self.slot(key); @@ -169,16 +168,20 @@ where value, durability, changed_at, - } = slot.read(db)?; + } = slot.read(db); if let Some(evicted) = self.lru_list.record_use(&slot) { evicted.evict(); } db.salsa_runtime() - .report_query_read(slot.database_key_index(), durability, changed_at); + .report_query_read_and_unwind_if_cycle_resulted( + slot.database_key_index(), + durability, + changed_at, + ); - Ok(value) + value } fn durability(&self, db: &>::DynDb, key: &Q::Key) -> Durability { diff --git a/src/derived/slot.rs b/src/derived/slot.rs index 95443c99..41e788bb 100644 --- a/src/derived/slot.rs +++ b/src/derived/slot.rs @@ -1,24 +1,24 @@ -use crate::blocking_future::{BlockingFuture, Promise}; use crate::debug::TableEntry; use crate::derived::MemoizationPolicy; use crate::durability::Durability; use crate::lru::LruIndex; use crate::lru::LruNode; -use crate::plumbing::CycleDetected; use crate::plumbing::{DatabaseOps, QueryFunction}; use crate::revision::Revision; +use crate::runtime::local_state::ActiveQueryGuard; +use crate::runtime::local_state::QueryInputs; +use crate::runtime::local_state::QueryRevisions; use crate::runtime::Runtime; use crate::runtime::RuntimeId; use crate::runtime::StampedValue; -use crate::Cancelled; -use crate::{CycleError, Database, DatabaseKeyIndex, Event, EventKind, QueryDb}; +use crate::runtime::WaitResult; +use crate::Cycle; +use crate::{Database, DatabaseKeyIndex, Event, EventKind, QueryDb}; use log::{debug, info}; -use parking_lot::Mutex; use parking_lot::{RawRwLock, RwLock}; -use smallvec::SmallVec; use std::marker::PhantomData; use std::ops::Deref; -use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; pub(super) struct Slot where @@ -32,12 +32,6 @@ where lru_index: LruIndex, } -#[derive(Clone)] -struct WaitResult { - value: StampedValue, - cycle: Vec, -} - /// Defines the "current state" of query's memoized results. enum QueryState where @@ -46,61 +40,68 @@ where NotComputed, /// The runtime with the given id is currently computing the - /// result of this query; if we see this value in the table, it - /// indeeds a cycle. + /// result of this query. InProgress { id: RuntimeId, - waiting: Mutex>; 2]>>, + + /// Set to true if any other queries are blocked, + /// waiting for this query to complete. + anyone_waiting: AtomicBool, }, /// We have computed the query already, and here is the result. - Memoized(Memo), + Memoized(Memo), } -struct Memo -where - Q: QueryFunction, -{ +struct Memo { /// The result of the query, if we decide to memoize it. - value: Option, + value: Option, + + /// Last revision when this memo was verified; this begins + /// as the current revision. + pub(crate) verified_at: Revision, /// Revision information - revisions: MemoRevisions, + revisions: QueryRevisions, } -struct MemoRevisions { - /// Last revision when this memo was verified (if there are - /// untracked inputs, this will also be when the memo was - /// created). - verified_at: Revision, - - /// Last revision when the memoized value was observed to change. - changed_at: Revision, - - /// Minimum durability of the inputs to this query. - durability: Durability, - - /// The inputs that went into our query, if we are tracking them. - inputs: MemoInputs, +/// Return value of `probe` helper. 
+enum ProbeState { + /// Another thread was active but has completed. + /// Try again! + Retry, + + /// No entry for this key at all. + NotComputed(G), + + /// There is an entry, but its contents have not been + /// verified in this revision. + Stale(G), + + /// There is an entry, and it has been verified + /// in this revision, but it has no cached + /// value. The `Revision` is the revision where the + /// value last changed (if we were to recompute it). + NoValue(G, Revision), + + /// There is an entry which has been verified, + /// and it has the following value-- or, we blocked + /// on another thread, and that resulted in a cycle. + UpToDate(V), } -/// An insertion-order-preserving set of queries. Used to track the -/// inputs accessed during query execution. -pub(super) enum MemoInputs { - /// Non-empty set of inputs, fully known - Tracked { inputs: Arc<[DatabaseKeyIndex]> }, +/// Return value of `maybe_changed_after_probe` helper. +enum MaybeChangedSinceProbeState { + /// Another thread was active but has completed. + /// Try again! + Retry, - /// Empty set of inputs, fully known. - NoInputs, + /// Value may have changed in the given revision. + ChangedAt(Revision), - /// Unknown quantity of inputs - Untracked, -} - -/// Return value of `probe` helper. -enum ProbeState { - UpToDate(Result>), - StaleOrAbsent(G), + /// There is a stale cache entry that has not been + /// verified in this revision, so we can't say. + Stale(G), } impl Slot @@ -122,10 +123,7 @@ where self.database_key_index } - pub(super) fn read( - &self, - db: &>::DynDb, - ) -> Result, CycleError> { + pub(super) fn read(&self, db: &>::DynDb) -> StampedValue { let runtime = db.salsa_runtime(); // NB: We don't need to worry about people modifying the @@ -138,9 +136,14 @@ where info!("{:?}: invoked at {:?}", self, revision_now,); // First, do a check with a read-lock. - match self.probe(db, self.state.read(), runtime, revision_now) { - ProbeState::UpToDate(v) => return v, - ProbeState::StaleOrAbsent(_guard) => (), + loop { + match self.probe(db, self.state.read(), runtime, revision_now) { + ProbeState::UpToDate(v) => return v, + ProbeState::Stale(..) | ProbeState::NoValue(..) | ProbeState::NotComputed(..) => { + break + } + ProbeState::Retry => continue, + } } self.read_upgrade(db, revision_now) @@ -154,7 +157,7 @@ where &self, db: &>::DynDb, revision_now: Revision, - ) -> Result, CycleError> { + ) -> StampedValue { let runtime = db.salsa_runtime(); debug!("{:?}: read_upgrade(revision_now={:?})", self, revision_now,); @@ -162,29 +165,35 @@ where // Check with an upgradable read to see if there is a value // already. (This permits other readers but prevents anyone // else from running `read_upgrade` at the same time.) - let old_memo = match self.probe(db, self.state.upgradable_read(), runtime, revision_now) { - ProbeState::UpToDate(v) => return v, - ProbeState::StaleOrAbsent(state) => { - type RwLockUpgradableReadGuard<'a, T> = - lock_api::RwLockUpgradableReadGuard<'a, RawRwLock, T>; - - let mut state = RwLockUpgradableReadGuard::upgrade(state); - match std::mem::replace(&mut *state, QueryState::in_progress(runtime.id())) { - QueryState::Memoized(old_memo) => Some(old_memo), - QueryState::InProgress { .. 
} => unreachable!(), - QueryState::NotComputed => None, + let mut old_memo = loop { + match self.probe(db, self.state.upgradable_read(), runtime, revision_now) { + ProbeState::UpToDate(v) => return v, + ProbeState::Stale(state) + | ProbeState::NotComputed(state) + | ProbeState::NoValue(state, _) => { + type RwLockUpgradableReadGuard<'a, T> = + lock_api::RwLockUpgradableReadGuard<'a, RawRwLock, T>; + + let mut state = RwLockUpgradableReadGuard::upgrade(state); + match std::mem::replace(&mut *state, QueryState::in_progress(runtime.id())) { + QueryState::Memoized(old_memo) => break Some(old_memo), + QueryState::InProgress { .. } => unreachable!(), + QueryState::NotComputed => break None, + } } + ProbeState::Retry => continue, } }; - let mut panic_guard = PanicGuard::new(self.database_key_index, self, old_memo, runtime); + let panic_guard = PanicGuard::new(self.database_key_index, self, runtime); + let active_query = runtime.push_query(self.database_key_index); // If we have an old-value, it *may* now be stale, since there // has been a new revision since the last time we checked. So, // first things first, let's walk over each of our previous // inputs and check whether they are out of date. - if let Some(memo) = &mut panic_guard.memo { - if let Some(value) = memo.validate_memoized_value(db, revision_now) { + if let Some(memo) = &mut old_memo { + if let Some(value) = memo.verify_value(db.ops_database(), revision_now, &active_query) { info!("{:?}: validated old memoized value", self,); db.salsa_event(Event { @@ -194,42 +203,73 @@ where }, }); - panic_guard.proceed( - &value, - // The returned value could have been produced as part of a cycle but since - // we returned the memoized value we know we short-circuited the execution - // just as we entered the cycle. Therefore there is no values to invalidate - // and no need to call a cycle handler so we do not need to return the - // actual cycle - Vec::new(), - ); + panic_guard.proceed(old_memo); - return Ok(value); + return value; } } - // Query was not previously executed, or value is potentially - // stale, or value is absent. Let's execute! - let mut result = runtime.execute_query_implementation(db, self.database_key_index, || { - info!("{:?}: executing query", self); + self.execute( + db, + runtime, + revision_now, + active_query, + panic_guard, + old_memo, + ) + } - Q::execute(db, self.key.clone()) + fn execute( + &self, + db: &>::DynDb, + runtime: &Runtime, + revision_now: Revision, + active_query: ActiveQueryGuard<'_>, + panic_guard: PanicGuard<'_, Q, MP>, + old_memo: Option>, + ) -> StampedValue { + log::info!("{:?}: executing query", self.database_key_index.debug(db)); + + db.salsa_event(Event { + runtime_id: db.salsa_runtime().id(), + kind: EventKind::WillExecute { + database_key: self.database_key_index, + }, }); - if !result.cycle.is_empty() { - result.value = match Q::recover(db, &result.cycle, &self.key) { - Some(v) => v, - None => { - let err = CycleError { - cycle: result.cycle, - durability: result.durability, - changed_at: result.changed_at, - }; - panic_guard.report_unexpected_cycle(); - return Err(err); + // Query was not previously executed, or value is potentially + // stale, or value is absent. Let's execute! 
+ let value = match Cycle::catch(|| Q::execute(db, self.key.clone())) { + Ok(v) => v, + Err(cycle) => { + log::debug!( + "{:?}: caught cycle {:?}, have strategy {:?}", + self.database_key_index.debug(db), + cycle, + Q::CYCLE_STRATEGY, + ); + match Q::CYCLE_STRATEGY { + crate::plumbing::CycleRecoveryStrategy::Panic => { + panic_guard.proceed(None); + cycle.throw() + } + crate::plumbing::CycleRecoveryStrategy::Fallback => { + if let Some(c) = active_query.take_cycle() { + assert!(c.is(&cycle)); + Q::cycle_fallback(db, &cycle, &self.key) + } else { + // we are not a participant in this cycle + debug_assert!(!cycle + .participant_keys() + .any(|k| k == self.database_key_index)); + cycle.throw() + } + } } - }; - } + } + }; + + let mut revisions = active_query.pop(); // We assume that query is side-effect free -- that is, does // not mutate the "inputs" to the query system. Sanity check @@ -244,72 +284,50 @@ where // really change, even if some of its inputs have. So we can // "backdate" its `changed_at` revision to be the same as the // old value. - if let Some(old_memo) = &panic_guard.memo { + if let Some(old_memo) = &old_memo { if let Some(old_value) = &old_memo.value { // Careful: if the value became less durable than it // used to be, that is a "breaking change" that our // consumers must be aware of. Becoming *more* durable // is not. See the test `constant_to_non_constant`. - if result.durability >= old_memo.revisions.durability - && MP::memoized_value_eq(&old_value, &result.value) + if revisions.durability >= old_memo.revisions.durability + && MP::memoized_value_eq(old_value, &value) { debug!( "read_upgrade({:?}): value is equal, back-dating to {:?}", self, old_memo.revisions.changed_at, ); - assert!(old_memo.revisions.changed_at <= result.changed_at); - result.changed_at = old_memo.revisions.changed_at; + assert!(old_memo.revisions.changed_at <= revisions.changed_at); + revisions.changed_at = old_memo.revisions.changed_at; } } } let new_value = StampedValue { - value: result.value, - durability: result.durability, - changed_at: result.changed_at, + value, + durability: revisions.durability, + changed_at: revisions.changed_at, }; - let value = if self.should_memoize_value(&self.key) { + let memo_value = if self.should_memoize_value(&self.key) { Some(new_value.value.clone()) } else { None }; debug!( - "read_upgrade({:?}): result.changed_at={:?}, \ - result.durability={:?}, result.dependencies = {:?}", - self, result.changed_at, result.durability, result.dependencies, + "read_upgrade({:?}): result.revisions = {:#?}", + self, revisions, ); - let inputs = match result.dependencies { - None => MemoInputs::Untracked, - - Some(dependencies) => { - if dependencies.is_empty() { - MemoInputs::NoInputs - } else { - MemoInputs::Tracked { - inputs: dependencies.into_iter().collect(), - } - } - } - }; - debug!("read_upgrade({:?}): inputs={:#?}", self, inputs.debug(db)); - - panic_guard.memo = Some(Memo { - value, - revisions: MemoRevisions { - changed_at: result.changed_at, - verified_at: revision_now, - inputs, - durability: result.durability, - }, - }); - - panic_guard.proceed(&new_value, result.cycle); + panic_guard.proceed(Some(Memo { + value: memo_value, + verified_at: revision_now, + revisions, + })); - Ok(new_value) + new_value } /// Helper for `read` that does a shallow check (not recursive) if we have an up-to-date value. 
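The hunk above is the heart of the new recovery flow: `Q::execute` now runs inside `Cycle::catch`, and a caught cycle is either re-thrown (`CycleRecoveryStrategy::Panic`) or answered by `Q::cycle_fallback` (`CycleRecoveryStrategy::Fallback`). The self-contained sketch below models just that decision using plain `std::panic` machinery; the `Cycle`, `catch_cycle`, `run_query`, and `cyclic_query` items here are illustrative stand-ins, not salsa's actual types or API.

```rust
use std::panic::{self, AssertUnwindSafe};

/// Stand-in for `salsa::Cycle`: just the names of the participating queries.
#[derive(Debug)]
struct Cycle(Vec<&'static str>);

/// Mirrors the two recovery strategies introduced by this change.
#[derive(Copy, Clone, Debug)]
enum CycleRecoveryStrategy {
    Panic,
    Fallback,
}

/// Models `Cycle::catch(|| Q::execute(..))`: run the query body and convert a
/// `Cycle` panic payload into `Err(cycle)`, re-raising any other panic.
fn catch_cycle<T>(execute: impl FnOnce() -> T) -> Result<T, Cycle> {
    match panic::catch_unwind(AssertUnwindSafe(execute)) {
        Ok(v) => Ok(v),
        Err(payload) => match payload.downcast::<Cycle>() {
            Ok(cycle) => Err(*cycle),
            Err(other) => panic::resume_unwind(other),
        },
    }
}

/// Models the strategy match in `execute`: a `Panic` query re-throws the
/// cycle to its caller, while a `Fallback` query substitutes a recovery value.
fn run_query(
    strategy: CycleRecoveryStrategy,
    execute: impl FnOnce() -> i64,
    fallback: impl FnOnce(&Cycle) -> i64,
) -> i64 {
    match catch_cycle(execute) {
        Ok(v) => v,
        Err(cycle) => match strategy {
            CycleRecoveryStrategy::Panic => panic::resume_unwind(Box::new(cycle)),
            CycleRecoveryStrategy::Fallback => fallback(&cycle),
        },
    }
}

/// A query body that "hits a cycle"; models `Cycle::throw()`, which unwinds
/// with the cycle itself as the panic payload.
fn cyclic_query() -> i64 {
    panic::resume_unwind(Box::new(Cycle(vec!["query_a", "query_b"])))
}

fn main() {
    // Fallback strategy: the cycle is swallowed and a sentinel value is used.
    let v = run_query(CycleRecoveryStrategy::Fallback, cyclic_query, |cycle| {
        eprintln!("recovered from cycle {:?}", cycle.0);
        -1
    });
    assert_eq!(v, -1);

    // Panic strategy: the cycle keeps unwinding, so the caller observes a
    // panic whose payload is the `Cycle` value.
    let err = panic::catch_unwind(AssertUnwindSafe(|| {
        run_query(CycleRecoveryStrategy::Panic, cyclic_query, |_| unreachable!())
    }))
    .unwrap_err();
    assert!(err.downcast_ref::<Cycle>().is_some());
}
```

As in the change itself, a panic whose payload is not a cycle is re-raised untouched, so ordinary panics and cancellation continue to propagate past the catch point.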
@@ -329,98 +347,59 @@ where state: StateGuard, runtime: &Runtime, revision_now: Revision, - ) -> ProbeState, DatabaseKeyIndex, StateGuard> + ) -> ProbeState, StateGuard> where StateGuard: Deref>, { match &*state { - QueryState::NotComputed => { /* fall through */ } + QueryState::NotComputed => ProbeState::NotComputed(state), - QueryState::InProgress { id, waiting } => { + QueryState::InProgress { id, anyone_waiting } => { let other_id = *id; - return match self.register_with_in_progress_thread(db, runtime, other_id, waiting) { - Ok(future) => { - // Release our lock on `self.state`, so other thread can complete. - std::mem::drop(state); - - db.salsa_event(Event { - runtime_id: runtime.id(), - kind: EventKind::WillBlockOn { - other_runtime_id: other_id, - database_key: self.database_key_index, - }, - }); - - let result = future.wait().unwrap_or_else(|| { - // If the other thread panics, we treat this as cancellation: there is no - // need to panic ourselves, since the original panic will already invoke - // the panic hook and bubble up to the thread boundary (or be caught). - Cancelled::throw() - }); - ProbeState::UpToDate(if result.cycle.is_empty() { - Ok(result.value) - } else { - let err = CycleError { - cycle: result.cycle, - changed_at: result.value.changed_at, - durability: result.value.durability, - }; - runtime.mark_cycle_participants(&err); - Q::recover(db, &err.cycle, &self.key) - .map(|value| StampedValue { - value, - durability: err.durability, - changed_at: err.changed_at, - }) - .ok_or(err) - }) - } - Err(err) => { - let err = runtime.report_unexpected_cycle( - self.database_key_index, - err, - revision_now, - ); - ProbeState::UpToDate( - Q::recover(db, &err.cycle, &self.key) - .map(|value| StampedValue { - value, - changed_at: err.changed_at, - durability: err.durability, - }) - .ok_or(err), - ) - } - }; + // NB: `Ordering::Relaxed` is sufficient here, + // as there are no loads that are "gated" on this + // value. Everything that is written is also protected + // by a lock that must be acquired. The role of this + // boolean is to decide *whether* to acquire the lock, + // not to gate future atomic reads. + anyone_waiting.store(true, Ordering::Relaxed); + + self.block_on_or_unwind(db, runtime, other_id, state); + + // Other thread completely normally, so our value may be available now. 
+ ProbeState::Retry } QueryState::Memoized(memo) => { debug!( "{:?}: found memoized value, verified_at={:?}, changed_at={:?}", - self, memo.revisions.verified_at, memo.revisions.changed_at, + self, memo.verified_at, memo.revisions.changed_at, ); + if memo.verified_at < revision_now { + return ProbeState::Stale(state); + } + if let Some(value) = &memo.value { - if memo.revisions.verified_at == revision_now { - let value = StampedValue { - durability: memo.revisions.durability, - changed_at: memo.revisions.changed_at, - value: value.clone(), - }; - - info!( - "{:?}: returning memoized value changed at {:?}", - self, value.changed_at - ); - - return ProbeState::UpToDate(Ok(value)); - } + let value = StampedValue { + durability: memo.revisions.durability, + changed_at: memo.revisions.changed_at, + value: value.clone(), + }; + + info!( + "{:?}: returning memoized value changed at {:?}", + self, value.changed_at + ); + + ProbeState::UpToDate(value) + } else { + let changed_at = memo.revisions.changed_at; + ProbeState::NoValue(state, changed_at) } } } - - ProbeState::StaleOrAbsent(state) } pub(super) fn durability(&self, db: &>::DynDb) -> Durability { @@ -428,7 +407,7 @@ where QueryState::NotComputed => Durability::LOW, QueryState::InProgress { .. } => panic!("query in progress"), QueryState::Memoized(memo) => { - if memo.revisions.check_durability(db.salsa_runtime()) { + if memo.check_durability(db.salsa_runtime()) { memo.revisions.durability } else { Durability::LOW @@ -454,7 +433,7 @@ where // lead to inconsistencies. Note that we can't check // `has_untracked_input` when we add the value to the cache, // because inputs can become untracked in the next revision. - if memo.revisions.has_untracked_input() { + if memo.has_untracked_input() { return; } memo.value = None; @@ -465,7 +444,7 @@ where log::debug!("Slot::invalidate(new_revision = {:?})", new_revision); match &mut *self.state.write() { QueryState::Memoized(memo) => { - memo.revisions.inputs = MemoInputs::Untracked; + memo.revisions.inputs = QueryInputs::Untracked; memo.revisions.changed_at = new_revision; Some(memo.revisions.durability) } @@ -474,7 +453,7 @@ where } } - pub(super) fn maybe_changed_since( + pub(super) fn maybe_changed_after( &self, db: &>::DynDb, revision: Revision, @@ -485,214 +464,135 @@ where db.unwind_if_cancelled(); debug!( - "maybe_changed_since({:?}) called with revision={:?}, revision_now={:?}", + "maybe_changed_after({:?}) called with revision={:?}, revision_now={:?}", self, revision, revision_now, ); - // Acquire read lock to start. In some of the arms below, we - // drop this explicitly. - let state = self.state.read(); - - // Look for a memoized value. - let memo = match &*state { - // If somebody depends on us, but we have no map - // entry, that must mean that it was found to be out - // of date and removed. - QueryState::NotComputed => { - debug!("maybe_changed_since({:?}): no value", self); - return true; - } - - // This value is being actively recomputed. Wait for - // that thread to finish (assuming it's not dependent - // on us...) and check its associated revision. - QueryState::InProgress { id, waiting } => { - let other_id = *id; - debug!( - "maybe_changed_since({:?}): blocking on thread `{:?}`", - self, other_id, - ); - match self.register_with_in_progress_thread(db, runtime, other_id, waiting) { - Ok(future) => { - // Release our lock on `self.state`, so other thread can complete. 
- std::mem::drop(state); - - let result = future.wait().unwrap_or_else(|| Cancelled::throw()); - return !result.cycle.is_empty() || result.value.changed_at > revision; - } - - // Consider a cycle to have changed. - Err(_) => return true, + // Do an initial probe with just the read-lock. + // + // If we find that a cache entry for the value is present + // but hasn't been verified in this revision, we'll have to + // do more. + loop { + match self.maybe_changed_after_probe(db, self.state.read(), runtime, revision_now) { + MaybeChangedSinceProbeState::Retry => continue, + MaybeChangedSinceProbeState::ChangedAt(changed_at) => return changed_at > revision, + MaybeChangedSinceProbeState::Stale(state) => { + drop(state); + return self.maybe_changed_after_upgrade(db, revision); } } + } + } - QueryState::Memoized(memo) => memo, - }; - - if memo.revisions.verified_at == revision_now { - debug!( - "maybe_changed_since({:?}): {:?} since up-to-date memo that changed at {:?}", - self, - memo.revisions.changed_at > revision, - memo.revisions.changed_at, - ); - return memo.revisions.changed_at > revision; + fn maybe_changed_after_probe( + &self, + db: &>::DynDb, + state: StateGuard, + runtime: &Runtime, + revision_now: Revision, + ) -> MaybeChangedSinceProbeState + where + StateGuard: Deref>, + { + match self.probe(db, state, runtime, revision_now) { + ProbeState::Retry => MaybeChangedSinceProbeState::Retry, + + ProbeState::Stale(state) => MaybeChangedSinceProbeState::Stale(state), + + // If we know when value last changed, we can return right away. + // Note that we don't need the actual value to be available. + ProbeState::NoValue(_, changed_at) + | ProbeState::UpToDate(StampedValue { + value: _, + durability: _, + changed_at, + }) => MaybeChangedSinceProbeState::ChangedAt(changed_at), + + // If we have nothing cached, then value may have changed. + ProbeState::NotComputed(_) => MaybeChangedSinceProbeState::ChangedAt(revision_now), } + } - let maybe_changed; + fn maybe_changed_after_upgrade( + &self, + db: &>::DynDb, + revision: Revision, + ) -> bool { + let runtime = db.salsa_runtime(); + let revision_now = runtime.current_revision(); - // If we only depended on constants, and no constant has been - // modified since then, we cannot have changed; no need to - // trace our inputs. - if memo.revisions.check_durability(runtime) { - std::mem::drop(state); - maybe_changed = false; - } else { - match &memo.revisions.inputs { - MemoInputs::Untracked => { - // we don't know the full set of - // inputs, so if there is a new - // revision, we must assume it is - // dirty - debug!( - "maybe_changed_since({:?}: true since untracked inputs", - self, - ); - return true; - } + // Get an upgradable read lock, which permits other reads but no writers. + // Probe again. If the value is stale (needs to be verified), then upgrade + // to a write lock and swap it with InProgress while we work. + let mut old_memo = match self.maybe_changed_after_probe( + db, + self.state.upgradable_read(), + runtime, + revision_now, + ) { + MaybeChangedSinceProbeState::ChangedAt(changed_at) => return changed_at > revision, - MemoInputs::NoInputs => { - std::mem::drop(state); - maybe_changed = false; - } + // If another thread was active, then the cache line is going to be + // either verified or cleared out. Just recurse to figure out which. + // Note that we don't need an upgradable read. 
+ MaybeChangedSinceProbeState::Retry => return self.maybe_changed_after(db, revision), - MemoInputs::Tracked { inputs } => { - // At this point, the value may be dirty (we have - // to check the database-keys). If we have a cached - // value, we'll just fall back to invoking `read`, - // which will do that checking (and a bit more) -- - // note that we skip the "pure read" part as we - // already know the result. - assert!(inputs.len() > 0); - if memo.value.is_some() { - std::mem::drop(state); - return match self.read_upgrade(db, revision_now) { - Ok(v) => { - debug!( - "maybe_changed_since({:?}: {:?} since (recomputed) value changed at {:?}", - self, - v.changed_at > revision, - v.changed_at, - ); - v.changed_at > revision - } - Err(_) => true, - }; - } + MaybeChangedSinceProbeState::Stale(state) => { + type RwLockUpgradableReadGuard<'a, T> = + lock_api::RwLockUpgradableReadGuard<'a, RawRwLock, T>; - // We have a **tracked set of inputs** that need to be validated. - let inputs = inputs.clone(); - // We'll need to update the state anyway (see below), so release the read-lock. - std::mem::drop(state); - - // Iterate the inputs and see if any have maybe changed. - maybe_changed = inputs - .iter() - .filter(|&&input| db.maybe_changed_since(input, revision)) - .inspect(|input| debug!("{:?}: input `{:?}` may have changed", self, input)) - .next() - .is_some(); + let mut state = RwLockUpgradableReadGuard::upgrade(state); + match std::mem::replace(&mut *state, QueryState::in_progress(runtime.id())) { + QueryState::Memoized(old_memo) => old_memo, + QueryState::NotComputed | QueryState::InProgress { .. } => unreachable!(), } } - } - - // Either way, we have to update our entry. - // - // Keep in mind, though, that we released the lock before checking the ipnuts and a lot - // could have happened in the interim. =) Therefore, we have to probe the current - // `self.state` again and in some cases we ought to do nothing. - { - let mut state = self.state.write(); - match &mut *state { - QueryState::Memoized(memo) => { - if memo.revisions.verified_at == revision_now { - // Since we started verifying inputs, somebody - // else has come along and updated this value - // (they may even have recomputed - // it). Therefore, we should not touch this - // memo. - // - // FIXME: Should we still return whatever - // `maybe_changed` value we computed, - // however..? It seems .. harmless to indicate - // that the value has changed, but possibly - // less efficient? (It may cause some - // downstream value to be recomputed that - // wouldn't otherwise have to be?) - } else if maybe_changed { - // We found this entry is out of date and - // nobody touch it in the meantime. Just - // remove it. - *state = QueryState::NotComputed; - } else { - // We found this entry is valid. Update the - // `verified_at` to reflect the current - // revision. - memo.revisions.verified_at = revision_now; - } - } - - QueryState::InProgress { .. } => { - // Since we started verifying inputs, somebody - // else has come along and started updated this - // value. Just leave their marker alone and return - // whatever `maybe_changed` value we computed. - } + }; - QueryState::NotComputed => { - // Since we started verifying inputs, somebody - // else has come along and removed this value. The - // GC can do this, for example. That's fine. 
- } - } + let panic_guard = PanicGuard::new(self.database_key_index, self, runtime); + let active_query = runtime.push_query(self.database_key_index); + + if old_memo.verify_revisions(db.ops_database(), revision_now, &active_query) { + let maybe_changed = old_memo.revisions.changed_at > revision; + panic_guard.proceed(Some(old_memo)); + maybe_changed + } else if old_memo.value.is_some() { + // We found that this memoized value may have changed + // but we have an old value. We can re-run the code and + // actually *check* if it has changed. + let StampedValue { changed_at, .. } = self.execute( + db, + runtime, + revision_now, + active_query, + panic_guard, + Some(old_memo), + ); + changed_at > revision + } else { + // We found that inputs to this memoized value may have chanced + // but we don't have an old value to compare against or re-use. + // No choice but to drop the memo and say that its value may have changed. + panic_guard.proceed(None); + true } - - maybe_changed } - /// Helper: - /// - /// When we encounter an `InProgress` indicator, we need to either - /// report a cycle or else register ourselves to be notified when - /// that work completes. This helper does that; it returns a port - /// where you can wait for the final value that wound up being - /// computed (but first drop the lock on the map). - fn register_with_in_progress_thread( + /// Helper: see [`Runtime::try_block_on_or_unwind`]. + fn block_on_or_unwind( &self, - _db: &>::DynDb, + db: &>::DynDb, runtime: &Runtime, other_id: RuntimeId, - waiting: &Mutex>; 2]>>, - ) -> Result>, CycleDetected> { - let id = runtime.id(); - if other_id == id { - Err(CycleDetected { from: id, to: id }) - } else { - if !runtime.try_block_on(self.database_key_index, other_id) { - return Err(CycleDetected { - from: id, - to: other_id, - }); - } - - let (future, promise) = BlockingFuture::new(); - - // The reader of this will have to acquire map - // lock, we don't need any particular ordering. - waiting.lock().push(promise); - - Ok(future) - } + mutex_guard: MutexGuard, + ) { + runtime.block_on_or_unwind( + db.ops_database(), + self.database_key_index, + other_id, + mutex_guard, + ) } fn should_memoize_value(&self, key: &Q::Key) -> bool { @@ -707,7 +607,7 @@ where fn in_progress(id: RuntimeId) -> Self { QueryState::InProgress { id, - waiting: Default::default(), + anyone_waiting: Default::default(), } } } @@ -719,7 +619,6 @@ where { database_key_index: DatabaseKeyIndex, slot: &'me Slot, - memo: Option>, runtime: &'me Runtime, } @@ -731,40 +630,31 @@ where fn new( database_key_index: DatabaseKeyIndex, slot: &'me Slot, - memo: Option>, runtime: &'me Runtime, ) -> Self { Self { database_key_index, slot, - memo, runtime, } } - /// Proceed with our panic guard by overwriting the placeholder for `key`. - /// Once that completes, ensure that our deconstructor is not run once we - /// are out of scope. - fn proceed(mut self, new_value: &StampedValue, cycle: Vec) { - self.overwrite_placeholder(Some((new_value, cycle))); - std::mem::forget(self) - } - - fn report_unexpected_cycle(mut self) { - self.overwrite_placeholder(None); + /// Indicates that we have concluded normally (without panicking). + /// If `opt_memo` is some, then this memo is installed as the new + /// memoized value. If `opt_memo` is `None`, then the slot is cleared + /// and has no value. 
+ fn proceed(mut self, opt_memo: Option>) { + self.overwrite_placeholder(WaitResult::Completed, opt_memo); std::mem::forget(self) } /// Overwrites the `InProgress` placeholder for `key` that we /// inserted; if others were blocked, waiting for us to finish, /// then notify them. - fn overwrite_placeholder( - &mut self, - new_value: Option<(&StampedValue, Vec)>, - ) { + fn overwrite_placeholder(&mut self, wait_result: WaitResult, opt_memo: Option>) { let mut write = self.slot.state.write(); - let old_value = match self.memo.take() { + let old_value = match opt_memo { // Replace the `InProgress` marker that we installed with the new // memo, thus releasing our unique access to this key. Some(memo) => std::mem::replace(&mut *write, QueryState::Memoized(memo)), @@ -776,29 +666,16 @@ where }; match old_value { - QueryState::InProgress { id, waiting } => { + QueryState::InProgress { id, anyone_waiting } => { assert_eq!(id, self.runtime.id()); - self.runtime - .unblock_queries_blocked_on_self(self.database_key_index); - - match new_value { - // If anybody has installed themselves in our "waiting" - // list, notify them that the value is available. - Some((new_value, ref cycle)) => { - for promise in waiting.into_inner() { - promise.fulfil(WaitResult { - value: new_value.clone(), - cycle: cycle.clone(), - }); - } - } - - // We have no value to send when we are panicking. - // Therefore, we need to drop the sending half of the - // channel so that our panic propagates to those waiting - // on the receiving half. - None => std::mem::drop(waiting), + // NB: As noted on the `store`, `Ordering::Relaxed` is + // sufficient here. This boolean signals us on whether to + // acquire a mutex; the mutex will guarantee that all writes + // we are interested in are visible. + if anyone_waiting.load(Ordering::Relaxed) { + self.runtime + .unblock_queries_blocked_on(self.database_key_index, wait_result); } } _ => panic!( @@ -819,7 +696,7 @@ where fn drop(&mut self) { if std::thread::panicking() { // We panicked before we could proceed and need to remove `key`. - self.overwrite_placeholder(None) + self.overwrite_placeholder(WaitResult::Panicked, None) } else { // If no panic occurred, then panic guard ought to be // "forgotten" and so this Drop code should never run. @@ -828,56 +705,74 @@ where } } -impl Memo +impl Memo where - Q: QueryFunction, + V: Clone, { - fn validate_memoized_value( + /// Determines whether the value stored in this memo (if any) is still + /// valid in the current revision. If so, returns a stamped value. + /// + /// If needed, this will walk each dependency and + /// recursively invoke `maybe_changed_after`, which may in turn + /// re-execute the dependency. This can cause cycles to occur, + /// so the current query must be pushed onto the + /// stack to permit cycle detection and recovery: therefore, + /// takes the `active_query` argument as evidence. + fn verify_value( &mut self, - db: &>::DynDb, + db: &dyn Database, revision_now: Revision, - ) -> Option> { + active_query: &ActiveQueryGuard<'_>, + ) -> Option> { // If we don't have a memoized value, nothing to validate. 
- let value = match &self.value { - None => return None, - Some(v) => v, - }; - - let dyn_db = db.ops_database(); - if self.revisions.validate_memoized_value(dyn_db, revision_now) { + if self.value.is_none() { + return None; + } + if self.verify_revisions(db, revision_now, active_query) { Some(StampedValue { durability: self.revisions.durability, changed_at: self.revisions.changed_at, - value: value.clone(), + value: self.value.as_ref().unwrap().clone(), }) } else { None } } -} -impl MemoRevisions { - fn validate_memoized_value(&mut self, db: &dyn Database, revision_now: Revision) -> bool { + /// Determines whether the value represented by this memo is still + /// valid in the current revision; note that the value itself is + /// not needed for this check. If needed, this will walk each + /// dependency and recursively invoke `maybe_changed_after`, which + /// may in turn re-execute the dependency. This can cause cycles to occur, + /// so the current query must be pushed onto the + /// stack to permit cycle detection and recovery: therefore, + /// takes the `active_query` argument as evidence. + fn verify_revisions( + &mut self, + db: &dyn Database, + revision_now: Revision, + _active_query: &ActiveQueryGuard<'_>, + ) -> bool { assert!(self.verified_at != revision_now); let verified_at = self.verified_at; debug!( - "validate_memoized_value: verified_at={:?}, revision_now={:?}, inputs={:#?}", - verified_at, revision_now, self.inputs + "verify_revisions: verified_at={:?}, revision_now={:?}, inputs={:#?}", + verified_at, revision_now, self.revisions.inputs ); if self.check_durability(db.salsa_runtime()) { return self.mark_value_as_verified(revision_now); } - match &self.inputs { + match &self.revisions.inputs { // We can't validate values that had untracked inputs; just have to // re-execute. - MemoInputs::Untracked => { + QueryInputs::Untracked => { return false; } - MemoInputs::NoInputs => {} + QueryInputs::NoInputs => {} // Check whether any of our inputs changed since the // **last point where we were verified** (not since we @@ -888,10 +783,10 @@ impl MemoRevisions { // R1. But our *verification* date will be R2, and we // are only interested in finding out whether the // input changed *again*. - MemoInputs::Tracked { inputs } => { + QueryInputs::Tracked { inputs } => { let changed_input = inputs .iter() - .find(|&&input| db.maybe_changed_since(input, verified_at)); + .find(|&&input| db.maybe_changed_after(input, verified_at)); if let Some(input) = changed_input { debug!("validate_memoized_value: `{:?}` may have changed", input); @@ -905,7 +800,7 @@ impl MemoRevisions { /// True if this memo is known not to have changed based on its durability. 
fn check_durability(&self, runtime: &Runtime) -> bool { - let last_changed = runtime.last_changed_revision(self.durability); + let last_changed = runtime.last_changed_revision(self.revisions.durability); debug!( "check_durability(last_changed={:?} <= verified_at={:?}) = {:?}", last_changed, @@ -921,7 +816,7 @@ impl MemoRevisions { } fn has_untracked_input(&self) -> bool { - matches!(self.inputs, MemoInputs::Untracked) + matches!(self.revisions.inputs, QueryInputs::Untracked) } } @@ -935,59 +830,6 @@ where } } -impl MemoInputs { - fn debug<'a, D: ?Sized>(&'a self, db: &'a D) -> impl std::fmt::Debug + 'a - where - D: DatabaseOps, - { - enum DebugMemoInputs<'a, D: ?Sized> { - Tracked { - inputs: &'a [DatabaseKeyIndex], - db: &'a D, - }, - NoInputs, - Untracked, - } - - impl std::fmt::Debug for DebugMemoInputs<'_, D> { - fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - DebugMemoInputs::Tracked { inputs, db } => fmt - .debug_struct("Tracked") - .field( - "inputs", - &inputs.iter().map(|key| key.debug(*db)).collect::>(), - ) - .finish(), - DebugMemoInputs::NoInputs => fmt.debug_struct("NoInputs").finish(), - DebugMemoInputs::Untracked => fmt.debug_struct("Untracked").finish(), - } - } - } - - match self { - MemoInputs::Tracked { inputs } => DebugMemoInputs::Tracked { - inputs: &inputs, - db, - }, - MemoInputs::NoInputs => DebugMemoInputs::NoInputs, - MemoInputs::Untracked => DebugMemoInputs::Untracked, - } - } -} - -impl std::fmt::Debug for MemoInputs { - fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - MemoInputs::Tracked { inputs } => { - fmt.debug_struct("Tracked").field("inputs", inputs).finish() - } - MemoInputs::NoInputs => fmt.debug_struct("NoInputs").finish(), - MemoInputs::Untracked => fmt.debug_struct("Untracked").finish(), - } - } -} - impl LruNode for Slot where Q: QueryFunction, diff --git a/src/input.rs b/src/input.rs index dea49843..191f44a6 100644 --- a/src/input.rs +++ b/src/input.rs @@ -1,11 +1,11 @@ use crate::debug::TableEntry; use crate::durability::Durability; +use crate::plumbing::CycleRecoveryStrategy; use crate::plumbing::InputQueryStorageOps; use crate::plumbing::QueryStorageMassOps; use crate::plumbing::QueryStorageOps; use crate::revision::Revision; use crate::runtime::{FxIndexMap, StampedValue}; -use crate::CycleError; use crate::Database; use crate::Query; use crate::{DatabaseKeyIndex, QueryDb}; @@ -56,6 +56,8 @@ impl QueryStorageOps for InputStorage where Q: Query, { + const CYCLE_STRATEGY: crate::plumbing::CycleRecoveryStrategy = CycleRecoveryStrategy::Panic; + fn new(group_index: u16) -> Self { InputStorage { group_index, @@ -76,7 +78,7 @@ where write!(fmt, "{}({:?})", Q::QUERY_NAME, key) } - fn maybe_changed_since( + fn maybe_changed_after( &self, db: &>::DynDb, input: DatabaseKeyIndex, @@ -84,6 +86,7 @@ where ) -> bool { assert_eq!(input.group_index, self.group_index); assert_eq!(input.query_index, Q::QUERY_INDEX); + debug_assert!(revision < db.salsa_runtime().current_revision()); let slot = self .slots .read() @@ -91,14 +94,10 @@ where .unwrap() .1 .clone(); - slot.maybe_changed_since(db, revision) + slot.maybe_changed_after(db, revision) } - fn try_fetch( - &self, - db: &>::DynDb, - key: &Q::Key, - ) -> Result> { + fn fetch(&self, db: &>::DynDb, key: &Q::Key) -> Q::Value { db.unwind_if_cancelled(); let slot = self @@ -112,9 +111,13 @@ where } = slot.stamped_value.read().clone(); db.salsa_runtime() - .report_query_read(slot.database_key_index, durability, changed_at); + 
.report_query_read_and_unwind_if_cycle_resulted( + slot.database_key_index, + durability, + changed_at, + ); - Ok(value) + value } fn durability(&self, _db: &>::DynDb, key: &Q::Key) -> Durability { @@ -145,15 +148,15 @@ impl Slot where Q: Query, { - fn maybe_changed_since(&self, _db: &>::DynDb, revision: Revision) -> bool { + fn maybe_changed_after(&self, _db: &>::DynDb, revision: Revision) -> bool { debug!( - "maybe_changed_since(slot={:?}, revision={:?})", + "maybe_changed_after(slot={:?}, revision={:?})", self, revision, ); let changed_at = self.stamped_value.read().changed_at; - debug!("maybe_changed_since: changed_at = {:?}", changed_at); + debug!("maybe_changed_after: changed_at = {:?}", changed_at); changed_at > revision } diff --git a/src/interned.rs b/src/interned.rs index 6ad59998..a7220251 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -1,12 +1,13 @@ use crate::debug::TableEntry; use crate::durability::Durability; use crate::intern_id::InternId; +use crate::plumbing::CycleRecoveryStrategy; use crate::plumbing::HasQueryGroup; use crate::plumbing::QueryStorageMassOps; use crate::plumbing::QueryStorageOps; use crate::revision::Revision; use crate::Query; -use crate::{CycleError, Database, DatabaseKeyIndex, QueryDb}; +use crate::{Database, DatabaseKeyIndex, QueryDb}; use parking_lot::RwLock; use rustc_hash::FxHashMap; use std::collections::hash_map::Entry; @@ -191,6 +192,8 @@ where Q: Query, Q::Value: InternKey, { + const CYCLE_STRATEGY: crate::plumbing::CycleRecoveryStrategy = CycleRecoveryStrategy::Panic; + fn new(group_index: u16) -> Self { InternedStorage { group_index, @@ -211,35 +214,32 @@ where write!(fmt, "{}({:?})", Q::QUERY_NAME, slot.value) } - fn maybe_changed_since( + fn maybe_changed_after( &self, - _db: &>::DynDb, + db: &>::DynDb, input: DatabaseKeyIndex, revision: Revision, ) -> bool { assert_eq!(input.group_index, self.group_index); assert_eq!(input.query_index, Q::QUERY_INDEX); + debug_assert!(revision < db.salsa_runtime().current_revision()); let intern_id = InternId::from(input.key_index); let slot = self.lookup_value(intern_id); - slot.maybe_changed_since(revision) + slot.maybe_changed_after(revision) } - fn try_fetch( - &self, - db: &>::DynDb, - key: &Q::Key, - ) -> Result> { + fn fetch(&self, db: &>::DynDb, key: &Q::Key) -> Q::Value { db.unwind_if_cancelled(); - let slot = self.intern_index(db, key); let changed_at = slot.interned_at; let index = slot.index; - db.salsa_runtime().report_query_read( - slot.database_key_index, - INTERN_DURABILITY, - changed_at, - ); - Ok(::from_intern_id(index)) + db.salsa_runtime() + .report_query_read_and_unwind_if_cycle_resulted( + slot.database_key_index, + INTERN_DURABILITY, + changed_at, + ); + ::from_intern_id(index) } fn durability(&self, _db: &>::DynDb, _key: &Q::Key) -> Durability { @@ -312,6 +312,8 @@ where IQ: Query>, for<'d> Q: EqualDynDb<'d, IQ>, { + const CYCLE_STRATEGY: CycleRecoveryStrategy = CycleRecoveryStrategy::Panic; + fn new(_group_index: u16) -> Self { LookupInternedStorage { phantom: std::marker::PhantomData, @@ -330,7 +332,7 @@ where interned_storage.fmt_index(Q::convert_db(db), index, fmt) } - fn maybe_changed_since( + fn maybe_changed_after( &self, db: &>::DynDb, input: DatabaseKeyIndex, @@ -339,14 +341,10 @@ where let group_storage = <>::DynDb as HasQueryGroup>::group_storage(db); let interned_storage = IQ::query_storage(Q::convert_group_storage(group_storage)); - interned_storage.maybe_changed_since(Q::convert_db(db), input, revision) + interned_storage.maybe_changed_after(Q::convert_db(db), 
input, revision) } - fn try_fetch( - &self, - db: &>::DynDb, - key: &Q::Key, - ) -> Result> { + fn fetch(&self, db: &>::DynDb, key: &Q::Key) -> Q::Value { let index = key.as_intern_id(); let group_storage = <>::DynDb as HasQueryGroup>::group_storage(db); @@ -354,12 +352,13 @@ where let slot = interned_storage.lookup_value(index); let value = slot.value.clone(); let interned_at = slot.interned_at; - db.salsa_runtime().report_query_read( - slot.database_key_index, - INTERN_DURABILITY, - interned_at, - ); - Ok(value) + db.salsa_runtime() + .report_query_read_and_unwind_if_cycle_resulted( + slot.database_key_index, + INTERN_DURABILITY, + interned_at, + ); + value } fn durability(&self, _db: &>::DynDb, _key: &Q::Key) -> Durability { @@ -395,7 +394,7 @@ where } impl Slot { - fn maybe_changed_since(&self, revision: Revision) -> bool { + fn maybe_changed_after(&self, revision: Revision) -> bool { self.interned_at > revision } } diff --git a/src/lib.rs b/src/lib.rs index 45246eb6..a34df5ac 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ #![allow(clippy::type_complexity)] +#![allow(clippy::question_mark)] #![warn(rust_2018_idioms)] #![warn(missing_docs)] @@ -8,7 +9,6 @@ //! re-execute the derived queries and it will try to re-use results //! from previous invocations as appropriate. -mod blocking_future; mod derived; mod doctest; mod durability; @@ -26,7 +26,7 @@ pub mod debug; #[doc(hidden)] pub mod plumbing; -use crate::plumbing::DatabaseOps; +use crate::plumbing::CycleRecoveryStrategy; use crate::plumbing::DerivedQueryStorageOps; use crate::plumbing::InputQueryStorageOps; use crate::plumbing::LruQueryStorageOps; @@ -35,6 +35,7 @@ use crate::plumbing::QueryStorageOps; pub use crate::revision::Revision; use std::fmt::{self, Debug}; use std::hash::Hash; +use std::panic::AssertUnwindSafe; use std::panic::{self, UnwindSafe}; use std::sync::Arc; @@ -115,6 +116,17 @@ pub struct Event { pub kind: EventKind, } +impl Event { + /// Returns a type that gives a user-readable debug output. + /// Use like `println!("{:?}", index.debug(db))`. + pub fn debug<'me, D: ?Sized>(&'me self, db: &'me D) -> impl std::fmt::Debug + 'me + where + D: plumbing::DatabaseOps, + { + EventDebug { event: self, db } + } +} + impl fmt::Debug for Event { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt.debug_struct("Event") @@ -124,6 +136,26 @@ impl fmt::Debug for Event { } } +struct EventDebug<'me, D: ?Sized> +where + D: plumbing::DatabaseOps, +{ + event: &'me Event, + db: &'me D, +} + +impl<'me, D: ?Sized> fmt::Debug for EventDebug<'me, D> +where + D: plumbing::DatabaseOps, +{ + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Event") + .field("runtime_id", &self.event.runtime_id) + .field("kind", &self.event.kind.debug(self.db)) + .finish() + } +} + /// An enum identifying the various kinds of events that can occur. pub enum EventKind { /// Occurs when we found that all inputs to a memoized value are @@ -166,6 +198,17 @@ pub enum EventKind { WillCheckCancellation, } +impl EventKind { + /// Returns a type that gives a user-readable debug output. + /// Use like `println!("{:?}", index.debug(db))`. 
+ pub fn debug<'me, D: ?Sized>(&'me self, db: &'me D) -> impl std::fmt::Debug + 'me + where + D: plumbing::DatabaseOps, + { + EventKindDebug { kind: self, db } + } +} + impl fmt::Debug for EventKind { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -190,6 +233,41 @@ impl fmt::Debug for EventKind { } } +struct EventKindDebug<'me, D: ?Sized> +where + D: plumbing::DatabaseOps, +{ + kind: &'me EventKind, + db: &'me D, +} + +impl<'me, D: ?Sized> fmt::Debug for EventKindDebug<'me, D> +where + D: plumbing::DatabaseOps, +{ + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.kind { + EventKind::DidValidateMemoizedValue { database_key } => fmt + .debug_struct("DidValidateMemoizedValue") + .field("database_key", &database_key.debug(self.db)) + .finish(), + EventKind::WillBlockOn { + other_runtime_id, + database_key, + } => fmt + .debug_struct("WillBlockOn") + .field("other_runtime_id", &other_runtime_id) + .field("database_key", &database_key.debug(self.db)) + .finish(), + EventKind::WillExecute { database_key } => fmt + .debug_struct("WillExecute") + .field("database_key", &database_key.debug(self.db)) + .finish(), + EventKind::WillCheckCancellation => fmt.debug_struct("WillCheckCancellation").finish(), + } + } +} + /// Indicates a database that also supports parallel query /// evaluation. All of Salsa's base query support is capable of /// parallel execution, but for it to work, your query key/value types @@ -418,12 +496,7 @@ where /// queries (those with no inputs, or those with more than one /// input) the key will be a tuple. pub fn get(&self, key: Q::Key) -> Q::Value { - self.try_get(key) - .unwrap_or_else(|err| panic!("{:?}", err.debug(self.db))) - } - - fn try_get(&self, key: Q::Key) -> Result> { - self.storage.try_fetch(self.db, &key) + self.storage.fetch(self.db, &key) } /// Completely clears the storage for this query. @@ -520,62 +593,29 @@ where } } -/// The error returned when a query could not be resolved due to a cycle -#[derive(Eq, PartialEq, Clone, Debug)] -pub struct CycleError { - /// The queries that were part of the cycle - cycle: Vec, - changed_at: Revision, - durability: Durability, -} - -impl CycleError { - fn debug<'a, D: ?Sized>(&'a self, db: &'a D) -> impl Debug + 'a - where - D: DatabaseOps, - { - struct CycleErrorDebug<'a, D: ?Sized> { - db: &'a D, - error: &'a CycleError, - } - - impl<'a, D: ?Sized + DatabaseOps> Debug for CycleErrorDebug<'a, D> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!(f, "Internal error, cycle detected:\n")?; - for i in &self.error.cycle { - writeln!(f, "{:?}", i.debug(self.db))?; - } - Ok(()) - } - } - - CycleErrorDebug { db, error: self } - } -} - -impl fmt::Display for CycleError -where - K: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!(f, "Internal error, cycle detected:\n")?; - for i in &self.cycle { - writeln!(f, "{:?}", i)?; - } - Ok(()) - } -} - -/// A panic payload indicating that a salsa revision was cancelled. +/// A panic payload indicating that execution of a salsa query was cancelled. +/// +/// This can occur for a few reasons: +/// * +/// * +/// * #[derive(Debug)] #[non_exhaustive] -pub struct Cancelled; +pub enum Cancelled { + /// The query was operating on revision R, but there is a pending write to move to revision R+1. + #[non_exhaustive] + PendingWrite, + + /// The query was blocked on another thread, and that thread panicked. + #[non_exhaustive] + PropagatedPanic, +} impl Cancelled { - fn throw() -> ! 
{ + fn throw(self) -> ! { // We use resume and not panic here to avoid running the panic // hook (that is, to avoid collecting and printing backtrace). - std::panic::resume_unwind(Box::new(Self)); + std::panic::resume_unwind(Box::new(self)); } /// Runs `f`, and catches any salsa cancellation. @@ -595,12 +635,114 @@ impl Cancelled { impl std::fmt::Display for Cancelled { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str("cancelled") + let why = match self { + Cancelled::PendingWrite => "pending write", + Cancelled::PropagatedPanic => "propagated panic", + }; + f.write_str("cancelled because of ")?; + f.write_str(why) } } impl std::error::Error for Cancelled {} +/// Captures the participants of a cycle that occurred when executing a query. +/// +/// This type is meant to be used to help give meaningful error messages to the +/// user or to help salsa developers figure out why their program is resulting +/// in a computation cycle. +/// +/// It is used in a few ways: +/// +/// * During [cycle recovery](https://https://salsa-rs.github.io/salsa/cycles/fallback.html), +/// where it is given to the fallback function. +/// * As the panic value when an unexpected cycle (i.e., a cycle where one or more participants +/// lacks cycle recovery information) occurs. +/// +/// You can read more about cycle handling in +/// the [salsa book](https://https://salsa-rs.github.io/salsa/cycles.html). +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct Cycle { + participants: plumbing::CycleParticipants, +} + +impl Cycle { + pub(crate) fn new(participants: plumbing::CycleParticipants) -> Self { + Self { participants } + } + + /// True if two `Cycle` values represent the same cycle. + pub(crate) fn is(&self, cycle: &Cycle) -> bool { + Arc::ptr_eq(&self.participants, &cycle.participants) + } + + pub(crate) fn throw(self) -> ! { + log::debug!("throwing cycle {:?}", self); + std::panic::resume_unwind(Box::new(self)) + } + + pub(crate) fn catch(execute: impl FnOnce() -> T) -> Result { + match std::panic::catch_unwind(AssertUnwindSafe(execute)) { + Ok(v) => Ok(v), + Err(err) => match err.downcast::() { + Ok(cycle) => Err(*cycle), + Err(other) => std::panic::resume_unwind(other), + }, + } + } + + /// Iterate over the [`DatabaseKeyIndex`] for each query participating + /// in the cycle. The start point of this iteration within the cycle + /// is arbitrary but deterministic, but the ordering is otherwise determined + /// by the execution. + pub fn participant_keys(&self) -> impl Iterator + '_ { + self.participants.iter().copied() + } + + /// Returns a vector with the debug information for + /// all the participants in the cycle. + pub fn all_participants(&self, db: &DB) -> Vec { + self.participant_keys() + .map(|d| format!("{:?}", d.debug(db))) + .collect() + } + + /// Returns a vector with the debug information for + /// those participants in the cycle that lacked recovery + /// information. + pub fn unexpected_participants(&self, db: &DB) -> Vec { + self.participant_keys() + .filter(|&d| db.cycle_recovery_strategy(d) == CycleRecoveryStrategy::Panic) + .map(|d| format!("{:?}", d.debug(db))) + .collect() + } + + /// Returns a "debug" view onto this strict that can be used to print out information. 
+ pub fn debug<'me, DB: ?Sized + Database>(&'me self, db: &'me DB) -> impl std::fmt::Debug + 'me { + struct UnexpectedCycleDebug<'me> { + c: &'me Cycle, + db: &'me dyn Database, + } + + impl<'me> std::fmt::Debug for UnexpectedCycleDebug<'me> { + fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fmt.debug_struct("UnexpectedCycle") + .field("all_participants", &self.c.all_participants(self.db)) + .field( + "unexpected_participants", + &self.c.unexpected_participants(self.db), + ) + .finish() + } + } + + UnexpectedCycleDebug { + c: self, + db: db.ops_database(), + } + } +} + // Re-export the procedural macros. #[allow(unused_imports)] #[macro_use] diff --git a/src/plumbing.rs b/src/plumbing.rs index 2c9dd95e..cc5d4e55 100644 --- a/src/plumbing.rs +++ b/src/plumbing.rs @@ -2,15 +2,15 @@ use crate::debug::TableEntry; use crate::durability::Durability; -use crate::CycleError; +use crate::Cycle; use crate::Database; use crate::Query; use crate::QueryTable; use crate::QueryTableMut; -use crate::RuntimeId; use std::borrow::Borrow; use std::fmt::Debug; use std::hash::Hash; +use std::sync::Arc; pub use crate::derived::DependencyStorage; pub use crate::derived::MemoizedStorage; @@ -19,12 +19,6 @@ pub use crate::interned::InternedStorage; pub use crate::interned::LookupInternedStorage; pub use crate::{revision::Revision, DatabaseKeyIndex, QueryDb, Runtime}; -#[derive(Clone, Debug)] -pub struct CycleDetected { - pub(crate) from: RuntimeId, - pub(crate) to: RuntimeId, -} - /// Defines various associated types. An impl of this /// should be generated for your query-context type automatically by /// the `database_storage` macro, so you shouldn't need to mess @@ -54,7 +48,10 @@ pub trait DatabaseOps { ) -> std::fmt::Result; /// True if the computed value for `input` may have changed since `revision`. - fn maybe_changed_since(&self, input: DatabaseKeyIndex, revision: Revision) -> bool; + fn maybe_changed_after(&self, input: DatabaseKeyIndex, revision: Revision) -> bool; + + /// Find the `CycleRecoveryStrategy` for a given input. + fn cycle_recovery_strategy(&self, input: DatabaseKeyIndex) -> CycleRecoveryStrategy; /// Executes the callback for each kind of query. fn for_each_query(&self, op: &mut dyn FnMut(&dyn QueryStorageMassOps)); @@ -70,18 +67,45 @@ pub trait QueryStorageMassOps { pub trait DatabaseKey: Clone + Debug + Eq + Hash {} pub trait QueryFunction: Query { + /// See `CycleRecoveryStrategy` + const CYCLE_STRATEGY: CycleRecoveryStrategy; + fn execute(db: &>::DynDb, key: Self::Key) -> Self::Value; - fn recover( + fn cycle_fallback( db: &>::DynDb, - cycle: &[DatabaseKeyIndex], + cycle: &Cycle, key: &Self::Key, - ) -> Option { + ) -> Self::Value { let _ = (db, cycle, key); - None + panic!( + "query `{:?}` doesn't support cycle fallback", + Self::default() + ) } } +/// Cycle recovery strategy: Is this query capable of recovering from +/// a cycle that results from executing the function? If so, how? +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum CycleRecoveryStrategy { + /// Cannot recover from cycles: panic. + /// + /// This is the default. It is also what happens if a cycle + /// occurs and the queries involved have different recovery + /// strategies. + /// + /// In the case of a failure due to a cycle, the panic + /// value will be XXX (FIXME). + Panic, + + /// Recovers from cycles by storing a sentinel value. + /// + /// This value is computed by the `QueryFunction::cycle_fallback` + /// function. 
+ Fallback, +} + /// Create a query table, which has access to the storage for the query /// and offers methods like `get`. pub fn get_query_table<'me, Q>(db: &'me >::DynDb) -> QueryTable<'me, Q> @@ -122,11 +146,17 @@ where fn group_storage(&self) -> &G::GroupStorage; } +// ANCHOR:QueryStorageOps pub trait QueryStorageOps where Self: QueryStorageMassOps, Q: Query, { + // ANCHOR_END:QueryStorageOps + + /// See CycleRecoveryStrategy + const CYCLE_STRATEGY: CycleRecoveryStrategy; + fn new(group_index: u16) -> Self; /// Format a database key index in a suitable way. @@ -137,15 +167,25 @@ where fmt: &mut std::fmt::Formatter<'_>, ) -> std::fmt::Result; + // ANCHOR:maybe_changed_after /// True if the value of `input`, which must be from this query, may have - /// changed since the given revision. - fn maybe_changed_since( + /// changed after the given revision ended. + /// + /// This function should only be invoked with a revision less than the current + /// revision. + fn maybe_changed_after( &self, db: &>::DynDb, input: DatabaseKeyIndex, revision: Revision, ) -> bool; + // ANCHOR_END:maybe_changed_after + + fn cycle_recovery_strategy(&self) -> CycleRecoveryStrategy { + Self::CYCLE_STRATEGY + } + // ANCHOR:fetch /// Execute the query, returning the result (often, the result /// will be memoized). This is the "main method" for /// queries. @@ -153,11 +193,8 @@ where /// Returns `Err` in the event of a cycle, meaning that computing /// the value for this `key` is recursively attempting to fetch /// itself. - fn try_fetch( - &self, - db: &>::DynDb, - key: &Q::Key, - ) -> Result>; + fn fetch(&self, db: &>::DynDb, key: &Q::Key) -> Q::Value; + // ANCHOR_END:fetch /// Returns the durability associated with a given key. fn durability(&self, db: &>::DynDb, key: &Q::Key) -> Durability; @@ -200,3 +237,5 @@ where S: Eq + Hash, Q::Key: Borrow; } + +pub type CycleParticipants = Arc>; diff --git a/src/runtime.rs b/src/runtime.rs index 5743271b..b6350020 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -1,22 +1,27 @@ -use crate::plumbing::CycleDetected; +use crate::durability::Durability; +use crate::plumbing::CycleRecoveryStrategy; use crate::revision::{AtomicRevision, Revision}; -use crate::{durability::Durability, Cancelled}; -use crate::{CycleError, Database, DatabaseKeyIndex, Event, EventKind}; +use crate::{Cancelled, Cycle, Database, DatabaseKeyIndex, Event, EventKind}; use log::debug; use parking_lot::lock_api::{RawRwLock, RawRwLockRecursive}; use parking_lot::{Mutex, RwLock}; -use rustc_hash::{FxHashMap, FxHasher}; -use smallvec::SmallVec; +use rustc_hash::FxHasher; use std::hash::{BuildHasherDefault, Hash}; +use std::panic::panic_any; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; pub(crate) type FxIndexSet = indexmap::IndexSet>; pub(crate) type FxIndexMap = indexmap::IndexMap>; -mod local_state; +mod dependency_graph; +use dependency_graph::DependencyGraph; + +pub(crate) mod local_state; use local_state::LocalState; +use self::local_state::{ActiveQueryGuard, QueryInputs, QueryRevisions}; + /// The salsa runtime stores the storage for all queries as well as /// tracking the query stack and dependencies between cycles. 
/// @@ -40,6 +45,13 @@ pub struct Runtime { shared_state: Arc, } +#[derive(Clone, Debug)] +pub(crate) enum WaitResult { + Completed, + Panicked, + Cycle(Cycle), +} + impl Default for Runtime { fn default() -> Self { Runtime { @@ -61,6 +73,16 @@ impl std::fmt::Debug for Runtime { } } +#[derive(Clone, Debug)] +struct CycleDetected { + /// Common recovery strategy to all participants in the cycle, + /// or [`CycleRecoveryStrategy::Panic`] otherwise. + pub(crate) recovery_strategy: CycleRecoveryStrategy, + + /// Cycle participants. + pub(crate) cycle: Cycle, +} + impl Runtime { /// Create a new runtime; equivalent to `Self::default`. This is /// used when creating a new database. @@ -140,7 +162,7 @@ impl Runtime { #[cold] pub(crate) fn unwind_cancelled(&self) { self.report_untracked_read(); - Cancelled::throw(); + Cancelled::PendingWrite.throw(); } /// Acquires the **global query write lock** (ensuring that no queries are @@ -201,70 +223,31 @@ impl Runtime { self.revision_guard.is_none() && !self.local_state.query_in_progress() } - pub(crate) fn execute_query_implementation( - &self, - db: &DB, - database_key_index: DatabaseKeyIndex, - execute: impl FnOnce() -> V, - ) -> ComputedQueryResult - where - DB: ?Sized + Database, - { - debug!( - "{:?}: execute_query_implementation invoked", - database_key_index - ); - - db.salsa_event(Event { - runtime_id: self.id(), - kind: EventKind::WillExecute { - database_key: database_key_index, - }, - }); - - // Push the active query onto the stack. - let max_durability = Durability::MAX; - let active_query = self - .local_state - .push_query(database_key_index, max_durability); - - // Execute user's code, accumulating inputs etc. - let value = execute(); - - // Extract accumulated inputs. - let ActiveQuery { - dependencies, - changed_at, - durability, - cycle, - .. - } = active_query.complete(); - - ComputedQueryResult { - value, - durability, - changed_at, - dependencies, - cycle, - } + #[inline] + pub(crate) fn push_query(&self, database_key_index: DatabaseKeyIndex) -> ActiveQueryGuard<'_> { + self.local_state.push_query(database_key_index) } /// Reports that the currently active query read the result from /// another query. /// + /// Also checks whether the "cycle participant" flag is set on + /// the current stack frame -- if so, panics with `CycleParticipant` + /// value, which should be caught by the code executing the query. + /// /// # Parameters /// /// - `database_key`: the query whose result was read /// - `changed_revision`: the last revision in which the result of that /// query had changed - pub(crate) fn report_query_read( + pub(crate) fn report_query_read_and_unwind_if_cycle_resulted( &self, input: DatabaseKeyIndex, durability: Durability, changed_at: Revision, ) { self.local_state - .report_query_read(input, durability, changed_at); + .report_query_read_and_unwind_if_cycle_resulted(input, durability, changed_at); } /// Reports that the query depends on some state unknown to salsa. @@ -280,113 +263,205 @@ impl Runtime { /// /// This is mostly useful to control the durability level for [on-demand inputs](https://salsa-rs.github.io/salsa/common_patterns/on_demand_inputs.html). pub fn report_synthetic_read(&self, durability: Durability) { + let changed_at = self.last_changed_revision(durability); self.local_state - .report_synthetic_read(durability, self.current_revision()); - } - - /// Obviously, this should be user configurable at some point. 
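The `report_synthetic_read` change above records the last-changed revision for the given durability rather than the current revision. For context, a sketch of the on-demand input pattern this hook serves; the query-group trait `MyDatabase` and the `external_state` accessor are hypothetical stand-ins (see the linked chapter for the full pattern):

```rust
use salsa::Durability;

// A derived query that reads state salsa cannot see. It reports a synthetic
// read so its result is treated as depending on a low-durability input and is
// re-checked whenever such inputs may have changed.
fn external_value(db: &dyn MyDatabase, key: u32) -> u32 {
    db.salsa_runtime().report_synthetic_read(Durability::LOW);
    *db.external_state().get(&key).unwrap_or(&0)
}
```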
- pub(crate) fn report_unexpected_cycle( + .report_synthetic_read(durability, changed_at); + } + + /// Handles a cycle in the dependency graph that was detected when the + /// current thread tried to block on `database_key_index` which is being + /// executed by `to_id`. If this function returns, then `to_id` no longer + /// depends on the current thread, and so we should continue executing + /// as normal. Otherwise, the function will throw a `Cycle` which is expected + /// to be caught by some frame on our stack. This occurs either if there is + /// a frame on our stack with cycle recovery (possibly the top one!) or if there + /// is no cycle recovery at all. + fn unblock_cycle_and_maybe_throw( &self, + db: &dyn Database, + dg: &mut DependencyGraph, database_key_index: DatabaseKeyIndex, - error: CycleDetected, - changed_at: Revision, - ) -> crate::CycleError { + to_id: RuntimeId, + ) { debug!( - "report_unexpected_cycle(database_key={:?})", + "unblock_cycle_and_maybe_throw(database_key={:?})", database_key_index ); - let mut query_stack = self.local_state.borrow_query_stack_mut(); - - if error.from == error.to { - // All queries in the cycle is local - let start_index = query_stack - .iter() - .rposition(|active_query| active_query.database_key_index == database_key_index) - .unwrap(); - let mut cycle = Vec::new(); - let cycle_participants = &mut query_stack[start_index..]; - for active_query in &mut *cycle_participants { - cycle.push(active_query.database_key_index); - } + let mut from_stack = self.local_state.take_query_stack(); + let from_id = self.id(); - assert!(!cycle.is_empty()); + // Make a "dummy stack frame". As we iterate through the cycle, we will collect the + // inputs from each participant. Then, if we are participating in cycle recovery, we + // will propagate those results to all participants. + let mut cycle_query = ActiveQuery::new(database_key_index); - for active_query in cycle_participants { - active_query.cycle = cycle.clone(); - } - - crate::CycleError { - cycle, - changed_at, - durability: Durability::MAX, - } - } else { - // Part of the cycle is on another thread so we need to lock and inspect the shared - // state - let dependency_graph = self.shared_state.dependency_graph.lock(); - - let mut cycle = Vec::new(); - dependency_graph.push_cycle_path( + // Identify the cycle participants: + let cycle = { + let mut v = vec![]; + dg.for_each_cycle_participant( + from_id, + &mut from_stack, database_key_index, - error.to, - query_stack.iter().map(|query| query.database_key_index), - &mut cycle, + to_id, + |aqs| { + aqs.iter_mut().for_each(|aq| { + cycle_query.add_from(aq); + v.push(aq.database_key_index); + }); + }, ); - cycle.push(database_key_index); - assert!(!cycle.is_empty()); + // We want to give the participants in a deterministic order + // (at least for this execution, not necessarily across executions), + // no matter where it started on the stack. Find the minimum + // key and rotate it to the front. + let min = v.iter().min().unwrap(); + let index = v.iter().position(|p| p == min).unwrap(); + v.rotate_left(index); - for active_query in query_stack - .iter_mut() - .filter(|query| cycle.iter().any(|key| *key == query.database_key_index)) - { - active_query.cycle = cycle.clone(); - } + // No need to store extra memory. 
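The rotation above is what makes the participant list deterministic: whichever frame happens to discover the cycle, the list is rotated so the smallest `DatabaseKeyIndex` comes first. In isolation (a self-contained sketch, not the actual helper):

```rust
/// Rotate `v` so that its minimum element comes first, leaving the relative
/// cyclic order of the other elements unchanged.
fn rotate_to_min<T: Ord>(v: &mut Vec<T>) {
    if v.is_empty() {
        return;
    }
    let min = v.iter().min().unwrap();
    let index = v.iter().position(|p| p == min).unwrap();
    v.rotate_left(index);
}

fn main() {
    let mut participants = vec![30, 10, 20];
    rotate_to_min(&mut participants);
    assert_eq!(participants, vec![10, 20, 30]);
}
```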
+ v.shrink_to_fit(); - crate::CycleError { - cycle, - changed_at, - durability: Durability::MAX, - } + Cycle::new(Arc::new(v)) + }; + debug!( + "cycle {:?}, cycle_query {:#?}", + cycle.debug(db), + cycle_query, + ); + + // We can remove the cycle participants from the list of dependencies; + // they are a strongly connected component (SCC) and we only care about + // dependencies to things outside the SCC that control whether it will + // form again. + cycle_query.remove_cycle_participants(&cycle); + + // Mark each cycle participant that has recovery set, along with + // any frames that come after them on the same thread. Those frames + // are going to be unwound so that fallback can occur. + dg.for_each_cycle_participant(from_id, &mut from_stack, database_key_index, to_id, |aqs| { + aqs.iter_mut() + .skip_while( + |aq| match db.cycle_recovery_strategy(aq.database_key_index) { + CycleRecoveryStrategy::Panic => true, + CycleRecoveryStrategy::Fallback => false, + }, + ) + .for_each(|aq| { + debug!("marking {:?} for fallback", aq.database_key_index.debug(db)); + aq.take_inputs_from(&cycle_query); + assert!(aq.cycle.is_none()); + aq.cycle = Some(cycle.clone()); + }); + }); + + // Unblock every thread that has cycle recovery with a `WaitResult::Cycle`. + // They will throw the cycle, which will be caught by the frame that has + // cycle recovery so that it can execute that recovery. + let (me_recovered, others_recovered) = + dg.maybe_unblock_runtimes_in_cycle(from_id, &from_stack, database_key_index, to_id); + + self.local_state.restore_query_stack(from_stack); + + if me_recovered { + // If the current thread has recovery, we want to throw + // so that it can begin. + cycle.throw() + } else if others_recovered { + // If other threads have recovery but we didn't: return and we will block on them. + } else { + // if nobody has recover, then we panic + panic_any(cycle); } } - pub(crate) fn mark_cycle_participants(&self, err: &CycleError) { - for active_query in self - .local_state - .borrow_query_stack_mut() - .iter_mut() - .rev() - .take_while(|active_query| { - err.cycle - .iter() - .any(|e| *e == active_query.database_key_index) - }) - { - active_query.cycle = err.cycle.clone(); + /// Block until `other_id` completes executing `database_key`; + /// panic or unwind in the case of a cycle. + /// + /// `query_mutex_guard` is the guard for the current query's state; + /// it will be dropped after we have successfully registered the + /// dependency. + /// + /// # Propagating panics + /// + /// If the thread `other_id` panics, then our thread is considered + /// cancelled, so this function will panic with a `Cancelled` value. + /// + /// # Cycle handling + /// + /// If the thread `other_id` already depends on the current thread, + /// and hence there is a cycle in the query graph, then this function + /// will unwind instead of returning normally. The method of unwinding + /// depends on the [`Self::mutual_cycle_recovery_strategy`] + /// of the cycle participants: + /// + /// * [`CycleRecoveryStrategy::Panic`]: panic with the [`Cycle`] as the value. + /// * [`CycleRecoveryStrategy::Fallback`]: initiate unwinding with [`CycleParticipant::unwind`]. 
+ pub(crate) fn block_on_or_unwind( + &self, + db: &dyn Database, + database_key: DatabaseKeyIndex, + other_id: RuntimeId, + query_mutex_guard: QueryMutexGuard, + ) { + let mut dg = self.shared_state.dependency_graph.lock(); + + if dg.depends_on(other_id, self.id()) { + self.unblock_cycle_and_maybe_throw(db, &mut dg, database_key, other_id); + + // If the above fn returns, then (via cycle recovery) it has unblocked the + // cycle, so we can continue. + assert!(!dg.depends_on(other_id, self.id())); } - } - /// Try to make this runtime blocked on `other_id`. Returns true - /// upon success or false if `other_id` is already blocked on us. - pub(crate) fn try_block_on(&self, database_key: DatabaseKeyIndex, other_id: RuntimeId) -> bool { - self.shared_state.dependency_graph.lock().add_edge( + db.salsa_event(Event { + runtime_id: self.id(), + kind: EventKind::WillBlockOn { + other_runtime_id: other_id, + database_key, + }, + }); + + let stack = self.local_state.take_query_stack(); + + let (stack, result) = DependencyGraph::block_on( + dg, self.id(), database_key, other_id, - self.local_state - .borrow_query_stack() - .iter() - .map(|query| query.database_key_index), - ) + stack, + query_mutex_guard, + ); + + self.local_state.restore_query_stack(stack); + + match result { + WaitResult::Completed => (), + + // If the other thread panicked, then we consider this thread + // cancelled. The assumption is that the panic will be detected + // by the other thread and responded to appropriately. + WaitResult::Panicked => Cancelled::PropagatedPanic.throw(), + + WaitResult::Cycle(c) => c.throw(), + } } - pub(crate) fn unblock_queries_blocked_on_self(&self, database_key_index: DatabaseKeyIndex) { + /// Invoked when this runtime completed computing `database_key` with + /// the given result `wait_result` (`wait_result` should be `None` if + /// computing `database_key` panicked and could not complete). + /// This function unblocks any dependent queries and allows them + /// to continue executing. + pub(crate) fn unblock_queries_blocked_on( + &self, + database_key: DatabaseKeyIndex, + wait_result: WaitResult, + ) { self.shared_state .dependency_graph .lock() - .remove_edge(database_key_index, self.id()) + .unblock_runtimes_blocked_on(database_key, wait_result); } } @@ -423,7 +498,7 @@ struct SharedState { /// The dependency graph tracks which runtimes are blocked on one /// another, waiting for queries to terminate. - dependency_graph: Mutex>, + dependency_graph: Mutex, } impl SharedState { @@ -463,6 +538,7 @@ impl std::fmt::Debug for SharedState { } } +#[derive(Debug)] struct ActiveQuery { /// What query is executing database_key_index: DatabaseKeyIndex, @@ -479,36 +555,17 @@ struct ActiveQuery { dependencies: Option>, /// Stores the entire cycle, if one is found and this query is part of it. - cycle: Vec, -} - -pub(crate) struct ComputedQueryResult { - /// Final value produced - pub(crate) value: V, - - /// Minimum durability of inputs observed so far. - pub(crate) durability: Durability, - - /// Maximum revision of all inputs observed. If we observe an - /// untracked read, this will be set to the most recent revision. - pub(crate) changed_at: Revision, - - /// Complete set of subqueries that were accessed, or `None` if - /// there was an untracked read. 
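On the thread that was blocked, a panic in the other thread therefore surfaces as a `Cancelled` panic payload rather than a `Cycle` (this is what `parallel_cycle_none_recover` below checks). A sketch of catching and reporting that condition, essentially what the `catch` helper documented earlier does; the function name here is illustrative and it assumes `Cancelled` is reachable from the crate root:

```rust
use std::panic::{catch_unwind, AssertUnwindSafe};

fn run_reporting_cancellation<T>(f: impl FnOnce() -> T) -> Result<T, String> {
    match catch_unwind(AssertUnwindSafe(f)) {
        Ok(value) => Ok(value),
        Err(payload) => match payload.downcast::<salsa::Cancelled>() {
            // Displays as, e.g., "cancelled because of propagated panic".
            Ok(cancelled) => Err(cancelled.to_string()),
            Err(other) => std::panic::resume_unwind(other),
        },
    }
}
```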
- pub(crate) dependencies: Option>, - - /// The cycle if one occured while computing this value - pub(crate) cycle: Vec, + cycle: Option, } impl ActiveQuery { - fn new(database_key_index: DatabaseKeyIndex, max_durability: Durability) -> Self { + fn new(database_key_index: DatabaseKeyIndex) -> Self { ActiveQuery { database_key_index, - durability: max_durability, + durability: Durability::MAX, changed_at: Revision::start(), dependencies: Some(FxIndexSet::default()), - cycle: Vec::new(), + cycle: None, } } @@ -527,9 +584,64 @@ impl ActiveQuery { self.changed_at = changed_at; } - fn add_synthetic_read(&mut self, durability: Durability, current_revision: Revision) { + fn add_synthetic_read(&mut self, durability: Durability, revision: Revision) { + self.dependencies = None; self.durability = self.durability.min(durability); - self.changed_at = current_revision; + self.changed_at = self.changed_at.max(revision); + } + + pub(crate) fn revisions(&self) -> QueryRevisions { + let inputs = match &self.dependencies { + None => QueryInputs::Untracked, + + Some(dependencies) => { + if dependencies.is_empty() { + QueryInputs::NoInputs + } else { + QueryInputs::Tracked { + inputs: dependencies.iter().copied().collect(), + } + } + } + }; + + QueryRevisions { + changed_at: self.changed_at, + inputs, + durability: self.durability, + } + } + + /// Adds any dependencies from `other` into `self`. + /// Used during cycle recovery, see [`Runtime::create_cycle_error`]. + fn add_from(&mut self, other: &ActiveQuery) { + self.changed_at = self.changed_at.max(other.changed_at); + self.durability = self.durability.min(other.durability); + if let Some(other_dependencies) = &other.dependencies { + if let Some(my_dependencies) = &mut self.dependencies { + my_dependencies.extend(other_dependencies.iter().copied()); + } + } else { + self.dependencies = None; + } + } + + /// Removes the participants in `cycle` from my dependencies. + /// Used during cycle recovery, see [`Runtime::create_cycle_error`]. + fn remove_cycle_participants(&mut self, cycle: &Cycle) { + if let Some(my_dependencies) = &mut self.dependencies { + for p in cycle.participant_keys() { + my_dependencies.remove(&p); + } + } + } + + /// Copy the changed-at, durability, and dependencies from `cycle_query`. + /// Used during cycle recovery, see [`Runtime::create_cycle_error`]. + pub(crate) fn take_inputs_from(&mut self, cycle_query: &ActiveQuery) { + self.changed_at = cycle_query.changed_at; + self.durability = cycle_query.durability; + self.dependencies = cycle_query.dependencies.clone(); } } @@ -548,118 +660,6 @@ pub(crate) struct StampedValue { pub(crate) changed_at: Revision, } -#[derive(Debug)] -struct Edge { - id: RuntimeId, - path: Vec, -} - -#[derive(Debug)] -struct DependencyGraph { - /// A `(K -> V)` pair in this map indicates that the the runtime - /// `K` is blocked on some query executing in the runtime `V`. - /// This encodes a graph that must be acyclic (or else deadlock - /// will result). - edges: FxHashMap>, - labels: FxHashMap>, -} - -impl Default for DependencyGraph -where - K: Hash + Eq, -{ - fn default() -> Self { - DependencyGraph { - edges: Default::default(), - labels: Default::default(), - } - } -} - -impl DependencyGraph -where - K: Hash + Eq + Clone, -{ - /// Attempt to add an edge `from_id -> to_id` into the result graph. 
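The merge rules introduced above (`add_from`, `remove_cycle_participants`) in isolation: a cycle's combined result is only as durable as its least durable participant, is considered changed at the latest revision any participant changed at, and keeps only dependencies that point outside the strongly connected component. A standalone sketch with plain integers standing in for `Revision`, `Durability`, and `DatabaseKeyIndex` (the untracked-read case is omitted for brevity):

```rust
use std::collections::BTreeSet;

#[derive(Debug, Clone)]
struct Inputs {
    changed_at: u64,             // stand-in for Revision
    durability: u8,              // stand-in for Durability (smaller = less durable)
    dependencies: BTreeSet<u32>, // stand-in for the set of DatabaseKeyIndex
}

impl Inputs {
    fn add_from(&mut self, other: &Inputs) {
        self.changed_at = self.changed_at.max(other.changed_at);
        self.durability = self.durability.min(other.durability);
        self.dependencies.extend(other.dependencies.iter().copied());
    }

    fn remove_cycle_participants(&mut self, cycle: &[u32]) {
        for participant in cycle {
            self.dependencies.remove(participant);
        }
    }
}
```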
- fn add_edge( - &mut self, - from_id: RuntimeId, - database_key: K, - to_id: RuntimeId, - path: impl IntoIterator, - ) -> bool { - assert_ne!(from_id, to_id); - debug_assert!(!self.edges.contains_key(&from_id)); - - // First: walk the chain of things that `to_id` depends on, - // looking for us. - let mut p = to_id; - while let Some(q) = self.edges.get(&p).map(|edge| edge.id) { - if q == from_id { - return false; - } - - p = q; - } - - self.edges.insert( - from_id, - Edge { - id: to_id, - path: path.into_iter().chain(Some(database_key.clone())).collect(), - }, - ); - self.labels.entry(database_key).or_default().push(from_id); - true - } - - fn remove_edge(&mut self, database_key: K, to_id: RuntimeId) { - let vec = self.labels.remove(&database_key).unwrap_or_default(); - - for from_id in &vec { - let to_id1 = self.edges.remove(from_id).map(|edge| edge.id); - assert_eq!(Some(to_id), to_id1); - } - } - - fn push_cycle_path( - &self, - database_key: K, - to: RuntimeId, - local_path: impl IntoIterator, - output: &mut Vec, - ) where - K: std::fmt::Debug, - { - let mut current = Some((to, std::slice::from_ref(&database_key))); - let mut last = None; - let mut local_path = Some(local_path); - - while let Some((id, path)) = current.take() { - let link_key = path.last().unwrap(); - output.extend(path.iter().cloned()); - - current = self.edges.get(&id).map(|edge| { - let i = edge.path.iter().rposition(|p| p == link_key).unwrap(); - (edge.id, &edge.path[i + 1..]) - }); - - if current.is_none() { - last = local_path.take().map(|local_path| { - local_path - .into_iter() - .skip_while(move |p| *p != *link_key) - .skip(1) - }); - } - } - - if let Some(iter) = &mut last { - output.extend(iter); - } - } -} - struct RevisionGuard { shared_state: Arc, } @@ -701,34 +701,3 @@ impl Drop for RevisionGuard { } } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn dependency_graph_path1() { - let mut graph = DependencyGraph::default(); - let a = RuntimeId { counter: 0 }; - let b = RuntimeId { counter: 1 }; - assert!(graph.add_edge(a, 2, b, vec![1])); - // assert!(graph.add_edge(b, &1, a, vec![3, 2])); - let mut v = vec![]; - graph.push_cycle_path(1, a, vec![3, 2], &mut v); - assert_eq!(v, vec![1, 2]); - } - - #[test] - fn dependency_graph_path2() { - let mut graph = DependencyGraph::default(); - let a = RuntimeId { counter: 0 }; - let b = RuntimeId { counter: 1 }; - let c = RuntimeId { counter: 2 }; - assert!(graph.add_edge(a, 3, b, vec![1])); - assert!(graph.add_edge(b, 4, c, vec![2, 3])); - // assert!(graph.add_edge(c, &1, a, vec![5, 6, 4, 7])); - let mut v = vec![]; - graph.push_cycle_path(1, a, vec![5, 6, 4, 7], &mut v); - assert_eq!(v, vec![1, 3, 4, 7]); - } -} diff --git a/src/runtime/dependency_graph.rs b/src/runtime/dependency_graph.rs new file mode 100644 index 00000000..b7fdf3ef --- /dev/null +++ b/src/runtime/dependency_graph.rs @@ -0,0 +1,277 @@ +use std::sync::Arc; + +use crate::{DatabaseKeyIndex, RuntimeId}; +use parking_lot::{Condvar, MutexGuard}; +use rustc_hash::FxHashMap; +use smallvec::SmallVec; + +use super::{ActiveQuery, WaitResult}; + +type QueryStack = Vec; + +#[derive(Debug, Default)] +pub(super) struct DependencyGraph { + /// A `(K -> V)` pair in this map indicates that the the runtime + /// `K` is blocked on some query executing in the runtime `V`. + /// This encodes a graph that must be acyclic (or else deadlock + /// will result). + edges: FxHashMap, + + /// Encodes the `RuntimeId` that are blocked waiting for the result + /// of a given query. 
+ query_dependents: FxHashMap>, + + /// When a key K completes which had dependent queries Qs blocked on it, + /// it stores its `WaitResult` here. As they wake up, each query Q in Qs will + /// come here to fetch their results. + wait_results: FxHashMap, +} + +#[derive(Debug)] +struct Edge { + blocked_on_id: RuntimeId, + blocked_on_key: DatabaseKeyIndex, + stack: QueryStack, + + /// Signalled whenever a query with dependents completes. + /// Allows those dependents to check if they are ready to unblock. + condvar: Arc, +} + +impl DependencyGraph { + /// True if `from_id` depends on `to_id`. + /// + /// (i.e., there is a path from `from_id` to `to_id` in the graph.) + pub(super) fn depends_on(&mut self, from_id: RuntimeId, to_id: RuntimeId) -> bool { + let mut p = from_id; + while let Some(q) = self.edges.get(&p).map(|edge| edge.blocked_on_id) { + if q == to_id { + return true; + } + + p = q; + } + p == to_id + } + + /// Invokes `closure` with a `&mut ActiveQuery` for each query that participates in the cycle. + /// The cycle runs as follows: + /// + /// 1. The runtime `from_id`, which has the stack `from_stack`, would like to invoke `database_key`... + /// 2. ...but `database_key` is already being executed by `to_id`... + /// 3. ...and `to_id` is transitively dependent on something which is present on `from_stack`. + pub(super) fn for_each_cycle_participant( + &mut self, + from_id: RuntimeId, + from_stack: &mut QueryStack, + database_key: DatabaseKeyIndex, + to_id: RuntimeId, + mut closure: impl FnMut(&mut [ActiveQuery]), + ) { + debug_assert!(self.depends_on(to_id, from_id)); + + // To understand this algorithm, consider this [drawing](https://is.gd/TGLI9v): + // + // database_key = QB2 + // from_id = A + // to_id = B + // from_stack = [QA1, QA2, QA3] + // + // self.edges[B] = { C, QC2, [QB1..QB3] } + // self.edges[C] = { A, QA2, [QC1..QC3] } + // + // The cyclic + // edge we have + // failed to add. + // : + // A : B C + // : + // QA1 v QB1 QC1 + // ┌► QA2 ┌──► QB2 ┌─► QC2 + // │ QA3 ───┘ QB3 ──┘ QC3 ───┐ + // │ │ + // └───────────────────────────────┘ + // + // Final output: [QB2, QB3, QC2, QC3, QA2, QA3] + + let mut id = to_id; + let mut key = database_key; + while id != from_id { + // Looking at the diagram above, the idea is to + // take the edge from `to_id` starting at `key` + // (inclusive) and down to the end. We can then + // load up the next thread (i.e., we start at B/QB2, + // and then load up the dependency on C/QC2). + let edge = self.edges.get_mut(&id).unwrap(); + let prefix = edge + .stack + .iter_mut() + .take_while(|p| p.database_key_index != key) + .count(); + closure(&mut edge.stack[prefix..]); + id = edge.blocked_on_id; + key = edge.blocked_on_key; + } + + // Finally, we copy in the results from `from_stack`. + let prefix = from_stack + .iter_mut() + .take_while(|p| p.database_key_index != key) + .count(); + closure(&mut from_stack[prefix..]); + } + + /// Unblock each blocked runtime (excluding the current one) if some + /// query executing in that runtime is participating in cycle fallback. + /// + /// Returns a boolean (Current, Others) where: + /// * Current is true if the current runtime has cycle participants + /// with fallback; + /// * Others is true if other runtimes were unblocked. + pub(super) fn maybe_unblock_runtimes_in_cycle( + &mut self, + from_id: RuntimeId, + from_stack: &QueryStack, + database_key: DatabaseKeyIndex, + to_id: RuntimeId, + ) -> (bool, bool) { + // See diagram in `for_each_cycle_participant`. 
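Because each runtime can be blocked on at most one other runtime at a time, the `depends_on` check above is just a walk along a chain of edges. A self-contained model of that check:

```rust
use std::collections::HashMap;

type RuntimeId = u32;

/// True if `from` transitively blocks on `to` in the "blocked-on" map.
fn depends_on(edges: &HashMap<RuntimeId, RuntimeId>, from: RuntimeId, to: RuntimeId) -> bool {
    let mut p = from;
    while let Some(&q) = edges.get(&p) {
        if q == to {
            return true;
        }
        p = q;
    }
    p == to
}

fn main() {
    let mut edges = HashMap::new();
    edges.insert(1, 2); // runtime 1 is blocked on runtime 2
    edges.insert(2, 3); // runtime 2 is blocked on runtime 3
    assert!(depends_on(&edges, 1, 3));
    assert!(!depends_on(&edges, 3, 1));
}
```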
+ let mut id = to_id; + let mut key = database_key; + let mut others_unblocked = false; + while id != from_id { + let edge = self.edges.get(&id).unwrap(); + let prefix = edge + .stack + .iter() + .take_while(|p| p.database_key_index != key) + .count(); + let next_id = edge.blocked_on_id; + let next_key = edge.blocked_on_key; + + if let Some(cycle) = edge.stack[prefix..] + .iter() + .rev() + .find_map(|aq| aq.cycle.clone()) + { + // Remove `id` from the list of runtimes blocked on `next_key`: + self.query_dependents + .get_mut(&next_key) + .unwrap() + .retain(|r| *r != id); + + // Unblock runtime so that it can resume execution once lock is released: + self.unblock_runtime(id, WaitResult::Cycle(cycle)); + + others_unblocked = true; + } + + id = next_id; + key = next_key; + } + + let prefix = from_stack + .iter() + .take_while(|p| p.database_key_index != key) + .count(); + let this_unblocked = from_stack[prefix..].iter().any(|aq| aq.cycle.is_some()); + + (this_unblocked, others_unblocked) + } + + /// Modifies the graph so that `from_id` is blocked + /// on `database_key`, which is being computed by + /// `to_id`. + /// + /// For this to be reasonable, the lock on the + /// results table for `database_key` must be held. + /// This ensures that computing `database_key` doesn't + /// complete before `block_on` executes. + /// + /// Preconditions: + /// * No path from `to_id` to `from_id` + /// (i.e., `me.depends_on(to_id, from_id)` is false) + /// * `held_mutex` is a read lock (or stronger) on `database_key` + pub(super) fn block_on( + mut me: MutexGuard<'_, Self>, + from_id: RuntimeId, + database_key: DatabaseKeyIndex, + to_id: RuntimeId, + from_stack: QueryStack, + query_mutex_guard: QueryMutexGuard, + ) -> (QueryStack, WaitResult) { + let condvar = me.add_edge(from_id, database_key, to_id, from_stack); + + // Release the mutex that prevents `database_key` + // from completing, now that the edge has been added. + drop(query_mutex_guard); + + loop { + if let Some(stack_and_result) = me.wait_results.remove(&from_id) { + debug_assert!(!me.edges.contains_key(&from_id)); + return stack_and_result; + } + condvar.wait(&mut me); + } + } + + /// Helper for `block_on`: performs actual graph modification + /// to add a dependency edge from `from_id` to `to_id`, which is + /// computing `database_key`. + fn add_edge( + &mut self, + from_id: RuntimeId, + database_key: DatabaseKeyIndex, + to_id: RuntimeId, + from_stack: QueryStack, + ) -> Arc { + assert_ne!(from_id, to_id); + debug_assert!(!self.edges.contains_key(&from_id)); + debug_assert!(!self.depends_on(to_id, from_id)); + + let condvar = Arc::new(Condvar::new()); + self.edges.insert( + from_id, + Edge { + blocked_on_id: to_id, + blocked_on_key: database_key, + stack: from_stack, + condvar: condvar.clone(), + }, + ); + self.query_dependents + .entry(database_key) + .or_default() + .push(from_id); + condvar + } + + /// Invoked when runtime `to_id` completes executing + /// `database_key`. + pub(super) fn unblock_runtimes_blocked_on( + &mut self, + database_key: DatabaseKeyIndex, + wait_result: WaitResult, + ) { + let dependents = self + .query_dependents + .remove(&database_key) + .unwrap_or_default(); + + for from_id in dependents { + self.unblock_runtime(from_id, wait_result.clone()); + } + } + + /// Unblock the runtime with the given id with the given wait-result. + /// This will cause it resume execution (though it will have to grab + /// the lock on this data structure first, to recover the wait result). 
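The `block_on` handshake above follows the standard parking_lot condvar pattern: check the shared table under the mutex, and if the result is not there yet, `wait` atomically releases the lock and re-acquires it when notified. A simplified, self-contained sketch (the real code allocates one condvar per blocked edge; this one shares a single condvar for brevity):

```rust
use parking_lot::{Condvar, Mutex};
use std::collections::HashMap;

#[derive(Default)]
struct Results {
    map: Mutex<HashMap<u32, String>>,
    condvar: Condvar,
}

impl Results {
    /// Waiting side: loop until our result has been published.
    fn wait_for(&self, key: u32) -> String {
        let mut guard = self.map.lock();
        loop {
            if let Some(value) = guard.remove(&key) {
                return value;
            }
            self.condvar.wait(&mut guard);
        }
    }

    /// Completing side: publish the result, then wake any waiters.
    fn publish(&self, key: u32, value: String) {
        self.map.lock().insert(key, value);
        self.condvar.notify_all();
    }
}
```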
+ fn unblock_runtime(&mut self, id: RuntimeId, wait_result: WaitResult) { + let edge = self.edges.remove(&id).expect("not blocked"); + self.wait_results.insert(id, (edge.stack, wait_result)); + + // Now that we have inserted the `wait_results`, + // notify the thread. + edge.condvar.notify_one(); + } +} diff --git a/src/runtime/local_state.rs b/src/runtime/local_state.rs index 04fd6f8e..d84dafa1 100644 --- a/src/runtime/local_state.rs +++ b/src/runtime/local_state.rs @@ -1,8 +1,12 @@ +use log::debug; + use crate::durability::Durability; use crate::runtime::ActiveQuery; use crate::runtime::Revision; +use crate::Cycle; use crate::DatabaseKeyIndex; -use std::cell::{Ref, RefCell, RefMut}; +use std::cell::RefCell; +use std::sync::Arc; /// State that is specific to a single execution thread. /// @@ -13,78 +17,157 @@ use std::cell::{Ref, RefCell, RefMut}; pub(super) struct LocalState { /// Vector of active queries. /// + /// This is normally `Some`, but it is set to `None` + /// while the query is blocked waiting for a result. + /// /// Unwinding note: pushes onto this vector must be popped -- even /// during unwinding. - query_stack: RefCell>, + query_stack: RefCell>>, +} + +/// Summarizes "all the inputs that a query used" +#[derive(Debug, Clone)] +pub(crate) struct QueryRevisions { + /// The most revision in which some input changed. + pub(crate) changed_at: Revision, + + /// Minimum durability of the inputs to this query. + pub(crate) durability: Durability, + + /// The inputs that went into our query, if we are tracking them. + pub(crate) inputs: QueryInputs, +} + +/// Every input. +#[derive(Debug, Clone)] +pub(crate) enum QueryInputs { + /// Non-empty set of inputs, fully known + Tracked { inputs: Arc<[DatabaseKeyIndex]> }, + + /// Empty set of inputs, fully known. + NoInputs, + + /// Unknown quantity of inputs + Untracked, } impl Default for LocalState { fn default() -> Self { LocalState { - query_stack: Default::default(), + query_stack: RefCell::new(Some(Vec::new())), } } } impl LocalState { - pub(super) fn push_query( - &self, - database_key_index: DatabaseKeyIndex, - max_durability: Durability, - ) -> ActiveQueryGuard<'_> { + #[inline] + pub(super) fn push_query(&self, database_key_index: DatabaseKeyIndex) -> ActiveQueryGuard<'_> { let mut query_stack = self.query_stack.borrow_mut(); - query_stack.push(ActiveQuery::new(database_key_index, max_durability)); + let query_stack = query_stack.as_mut().expect("local stack taken"); + query_stack.push(ActiveQuery::new(database_key_index)); ActiveQueryGuard { local_state: self, + database_key_index, push_len: query_stack.len(), } } - /// Returns a reference to the active query stack. - /// - /// **Warning:** Because this reference holds the ref-cell lock, - /// you should not use any mutating methods of `LocalState` while - /// reading from it. 
- pub(super) fn borrow_query_stack(&self) -> Ref<'_, Vec> { - self.query_stack.borrow() - } - - pub(super) fn borrow_query_stack_mut(&self) -> RefMut<'_, Vec> { - self.query_stack.borrow_mut() + fn with_query_stack(&self, c: impl FnOnce(&mut Vec) -> R) -> R { + c(self + .query_stack + .borrow_mut() + .as_mut() + .expect("query stack taken")) } pub(super) fn query_in_progress(&self) -> bool { - !self.query_stack.borrow().is_empty() + self.with_query_stack(|stack| !stack.is_empty()) } pub(super) fn active_query(&self) -> Option { - self.query_stack - .borrow() - .last() - .map(|active_query| active_query.database_key_index) + self.with_query_stack(|stack| { + stack + .last() + .map(|active_query| active_query.database_key_index) + }) } - pub(super) fn report_query_read( + pub(super) fn report_query_read_and_unwind_if_cycle_resulted( &self, input: DatabaseKeyIndex, durability: Durability, changed_at: Revision, ) { - if let Some(top_query) = self.query_stack.borrow_mut().last_mut() { - top_query.add_read(input, durability, changed_at); - } + debug!( + "report_query_read_and_unwind_if_cycle_resulted(input={:?}, durability={:?}, changed_at={:?})", + input, durability, changed_at + ); + self.with_query_stack(|stack| { + if let Some(top_query) = stack.last_mut() { + top_query.add_read(input, durability, changed_at); + + // We are a cycle participant: + // + // C0 --> ... --> Ci --> Ci+1 -> ... -> Cn --> C0 + // ^ ^ + // : | + // This edge -----+ | + // | + // | + // N0 + // + // In this case, the value we have just read from `Ci+1` + // is actually the cycle fallback value and not especially + // interesting. We unwind now with `CycleParticipant` to avoid + // executing the rest of our query function. This unwinding + // will be caught and our own fallback value will be used. + // + // Note that `Ci+1` may` have *other* callers who are not + // participants in the cycle (e.g., N0 in the graph above). + // They will not have the `cycle` marker set in their + // stack frames, so they will just read the fallback value + // from `Ci+1` and continue on their merry way. + if let Some(cycle) = &top_query.cycle { + cycle.clone().throw() + } + } + }) } pub(super) fn report_untracked_read(&self, current_revision: Revision) { - if let Some(top_query) = self.query_stack.borrow_mut().last_mut() { - top_query.add_untracked_read(current_revision); - } + self.with_query_stack(|stack| { + if let Some(top_query) = stack.last_mut() { + top_query.add_untracked_read(current_revision); + } + }) } - pub(super) fn report_synthetic_read(&self, durability: Durability, current_revision: Revision) { - if let Some(top_query) = self.query_stack.borrow_mut().last_mut() { - top_query.add_synthetic_read(durability, current_revision); - } + /// Update the top query on the stack to act as though it read a value + /// of durability `durability` which changed in `revision`. + pub(super) fn report_synthetic_read(&self, durability: Durability, revision: Revision) { + self.with_query_stack(|stack| { + if let Some(top_query) = stack.last_mut() { + top_query.add_synthetic_read(durability, revision); + } + }) + } + + /// Takes the query stack and returns it. This is used when + /// the current thread is blocking. The stack must be restored + /// with [`Self::restore_query_stack`] when the thread unblocks. 
+ pub(super) fn take_query_stack(&self) -> Vec { + assert!( + self.query_stack.borrow().is_some(), + "query stack already taken" + ); + self.query_stack.take().unwrap() + } + + /// Restores a query stack taken with [`Self::take_query_stack`] once + /// the thread unblocks. + pub(super) fn restore_query_stack(&self, stack: Vec) { + assert!(self.query_stack.borrow().is_none(), "query stack not taken"); + self.query_stack.replace(Some(stack)); } } @@ -94,19 +177,23 @@ impl std::panic::RefUnwindSafe for LocalState {} /// is returned to represent its slot. The guard can be used to pop /// the query from the stack -- in the case of unwinding, the guard's /// destructor will also remove the query. -pub(super) struct ActiveQueryGuard<'me> { +pub(crate) struct ActiveQueryGuard<'me> { local_state: &'me LocalState, push_len: usize, + database_key_index: DatabaseKeyIndex, } impl ActiveQueryGuard<'_> { fn pop_helper(&self) -> ActiveQuery { - let mut query_stack = self.local_state.query_stack.borrow_mut(); - - // Sanity check: pushes and pops should be balanced. - assert_eq!(query_stack.len(), self.push_len); - - query_stack.pop().unwrap() + self.local_state.with_query_stack(|stack| { + // Sanity check: pushes and pops should be balanced. + assert_eq!(stack.len(), self.push_len); + debug_assert_eq!( + stack.last().unwrap().database_key_index, + self.database_key_index + ); + stack.pop().unwrap() + }) } /// Invoked when the query has successfully completed execution. @@ -115,6 +202,27 @@ impl ActiveQueryGuard<'_> { std::mem::forget(self); query } + + /// Pops an active query from the stack. Returns the [`QueryRevisions`] + /// which summarizes the other queries that were accessed during this + /// query's execution. + #[inline] + pub(crate) fn pop(self) -> QueryRevisions { + // Extract accumulated inputs. + let popped_query = self.complete(); + + // If this frame were a cycle participant, it would have unwound. + assert!(popped_query.cycle.is_none()); + + popped_query.revisions() + } + + /// If the active query is registered as a cycle participant, remove and + /// return that cycle. 
+ pub(crate) fn take_cycle(&self) -> Option { + self.local_state + .with_query_stack(|stack| stack.last_mut()?.cycle.take()) + } } impl Drop for ActiveQueryGuard<'_> { diff --git a/tests/cycles.rs b/tests/cycles.rs index 7e697fed..f3946630 100644 --- a/tests/cycles.rs +++ b/tests/cycles.rs @@ -1,4 +1,48 @@ -use salsa::{ParallelDatabase, Snapshot}; +use std::panic::UnwindSafe; + +use salsa::{Durability, ParallelDatabase, Snapshot}; +use test_env_log::test; + +// Axes: +// +// Threading +// * Intra-thread +// * Cross-thread -- part of cycle is on one thread, part on another +// +// Recovery strategies: +// * Panic +// * Fallback +// * Mixed -- multiple strategies within cycle participants +// +// Across revisions: +// * N/A -- only one revision +// * Present in new revision, not old +// * Present in old revision, not new +// * Present in both revisions +// +// Dependencies +// * Tracked +// * Untracked -- cycle participant(s) contain untracked reads +// +// Layers +// * Direct -- cycle participant is directly invoked from test +// * Indirect -- invoked a query that invokes the cycle +// +// +// | Thread | Recovery | Old, New | Dep style | Layers | Test Name | +// | ------ | -------- | -------- | --------- | ------ | --------- | +// | Intra | Panic | N/A | Tracked | direct | cycle_memoized | +// | Intra | Panic | N/A | Untracked | direct | cycle_volatile | +// | Intra | Fallback | N/A | Tracked | direct | cycle_cycle | +// | Intra | Fallback | N/A | Tracked | indirect | inner_cycle | +// | Intra | Fallback | Both | Tracked | direct | cycle_revalidate | +// | Intra | Fallback | New | Tracked | direct | cycle_appears | +// | Intra | Fallback | Old | Tracked | direct | cycle_disappears | +// | Intra | Fallback | Old | Tracked | direct | cycle_disappears_durability | +// | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_1 | +// | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_2 | +// | Cross | Fallback | N/A | Tracked | both | parallel/cycles.rs: recover_parallel_cycle | +// | Cross | Panic | N/A | Tracked | both | parallel/cycles.rs: panic_parallel_cycle | #[derive(PartialEq, Eq, Hash, Clone, Debug)] struct Error { @@ -6,7 +50,6 @@ struct Error { } #[salsa::database(GroupStruct)] -#[derive(Default)] struct DatabaseImpl { storage: salsa::Storage, } @@ -21,6 +64,28 @@ impl ParallelDatabase for DatabaseImpl { } } +impl Default for DatabaseImpl { + fn default() -> Self { + let res = DatabaseImpl { + storage: salsa::Storage::default(), + }; + + res + } +} + +/// The queries A, B, and C in `Database` can be configured +/// to invoke one another in arbitrary ways using this +/// enum. 
+#[derive(Debug, Copy, Clone, PartialEq, Eq)] +enum CycleQuery { + None, + A, + B, + C, + AthenC, +} + #[salsa::query_group(GroupStruct)] trait Database: salsa::Database { // `a` and `b` depend on each other and form a cycle @@ -29,25 +94,33 @@ trait Database: salsa::Database { fn volatile_a(&self) -> (); fn volatile_b(&self) -> (); - fn cycle_leaf(&self) -> (); + #[salsa::input] + fn a_invokes(&self) -> CycleQuery; + + #[salsa::input] + fn b_invokes(&self) -> CycleQuery; + + #[salsa::input] + fn c_invokes(&self) -> CycleQuery; #[salsa::cycle(recover_a)] fn cycle_a(&self) -> Result<(), Error>; + #[salsa::cycle(recover_b)] fn cycle_b(&self) -> Result<(), Error>; fn cycle_c(&self) -> Result<(), Error>; } -fn recover_a(_db: &dyn Database, cycle: &[String]) -> Result<(), Error> { +fn recover_a(db: &dyn Database, cycle: &salsa::Cycle) -> Result<(), Error> { Err(Error { - cycle: cycle.to_owned(), + cycle: cycle.all_participants(db), }) } -fn recover_b(_db: &dyn Database, cycle: &[String]) -> Result<(), Error> { +fn recover_b(db: &dyn Database, cycle: &salsa::Cycle) -> Result<(), Error> { Err(Error { - cycle: cycle.to_owned(), + cycle: cycle.all_participants(db), }) } @@ -69,94 +142,360 @@ fn volatile_b(db: &dyn Database) { db.volatile_a() } -fn cycle_leaf(_db: &dyn Database) {} +impl CycleQuery { + fn invoke(self, db: &dyn Database) -> Result<(), Error> { + match self { + CycleQuery::A => db.cycle_a(), + CycleQuery::B => db.cycle_b(), + CycleQuery::C => db.cycle_c(), + CycleQuery::AthenC => { + let _ = db.cycle_a(); + db.cycle_c() + } + CycleQuery::None => Ok(()), + } + } +} fn cycle_a(db: &dyn Database) -> Result<(), Error> { - let _ = db.cycle_b(); - Ok(()) + dbg!("cycle_a"); + db.a_invokes().invoke(db) } fn cycle_b(db: &dyn Database) -> Result<(), Error> { - db.cycle_leaf(); - let _ = db.cycle_a(); - Ok(()) + dbg!("cycle_b"); + db.b_invokes().invoke(db) } fn cycle_c(db: &dyn Database) -> Result<(), Error> { - db.cycle_b() + dbg!("cycle_c"); + db.c_invokes().invoke(db) +} + +#[track_caller] +fn extract_cycle(f: impl FnOnce() + UnwindSafe) -> salsa::Cycle { + let v = std::panic::catch_unwind(f); + if let Err(d) = &v { + if let Some(cycle) = d.downcast_ref::() { + return cycle.clone(); + } + } + panic!("unexpected value: {:?}", v) } #[test] -#[should_panic(expected = "cycle detected")] fn cycle_memoized() { - let query = DatabaseImpl::default(); - query.memoized_a(); + let db = DatabaseImpl::default(); + let cycle = extract_cycle(|| db.memoized_a()); + insta::assert_debug_snapshot!(cycle.unexpected_participants(&db), @r###" + [ + "memoized_a(())", + "memoized_b(())", + ] + "###); } #[test] -#[should_panic(expected = "cycle detected")] fn cycle_volatile() { - let query = DatabaseImpl::default(); - query.volatile_a(); + let db = DatabaseImpl::default(); + let cycle = extract_cycle(|| db.volatile_a()); + insta::assert_debug_snapshot!(cycle.unexpected_participants(&db), @r###" + [ + "volatile_a(())", + "volatile_b(())", + ] + "###); } #[test] fn cycle_cycle() { - let query = DatabaseImpl::default(); + let mut query = DatabaseImpl::default(); + + // A --> B + // ^ | + // +-----+ + + query.set_a_invokes(CycleQuery::B); + query.set_b_invokes(CycleQuery::A); + assert!(query.cycle_a().is_err()); } #[test] fn inner_cycle() { - let query = DatabaseImpl::default(); + let mut query = DatabaseImpl::default(); + + // A --> B <-- C + // ^ | + // +-----+ + + query.set_a_invokes(CycleQuery::B); + query.set_b_invokes(CycleQuery::A); + query.set_c_invokes(CycleQuery::B); + let err = query.cycle_c(); 
assert!(err.is_err()); let cycle = err.unwrap_err().cycle; - assert!( - cycle - .iter() - .zip(&["cycle_b", "cycle_a"]) - .all(|(l, r)| l.contains(r)), - "{:#?}", - cycle - ); + insta::assert_debug_snapshot!(cycle, @r###" + [ + "cycle_a(())", + "cycle_b(())", + ] + "###); } #[test] -fn parallel_cycle() { - let _ = env_logger::try_init(); +fn cycle_revalidate() { + let mut db = DatabaseImpl::default(); - let db = DatabaseImpl::default(); - let thread1 = std::thread::spawn({ - let db = db.snapshot(); - move || { - let result = db.cycle_a(); - assert!(result.is_err(), "Expected cycle error"); - let cycle = result.unwrap_err().cycle; - assert!( - cycle - .iter() - .all(|l| ["cycle_b", "cycle_a"].iter().any(|r| l.contains(r))), - "{:#?}", - cycle - ); - } - }); - - let thread2 = std::thread::spawn(move || { - let result = db.cycle_c(); - assert!(result.is_err(), "Expected cycle error"); - let cycle = result.unwrap_err().cycle; - assert!( - cycle - .iter() - .all(|l| ["cycle_b", "cycle_a"].iter().any(|r| l.contains(r))), - "{:#?}", - cycle - ); - }); - - thread1.join().unwrap(); - thread2.join().unwrap(); - eprintln!("OK"); + // A --> B + // ^ | + // +-----+ + db.set_a_invokes(CycleQuery::B); + db.set_b_invokes(CycleQuery::A); + + assert!(db.cycle_a().is_err()); + db.set_b_invokes(CycleQuery::A); // same value as default + assert!(db.cycle_a().is_err()); +} + +#[test] +fn cycle_revalidate_unchanged_twice() { + let mut db = DatabaseImpl::default(); + + // A --> B + // ^ | + // +-----+ + db.set_a_invokes(CycleQuery::B); + db.set_b_invokes(CycleQuery::A); + + assert!(db.cycle_a().is_err()); + db.set_c_invokes(CycleQuery::A); // force new revisi5on + + // on this run + insta::assert_debug_snapshot!(db.cycle_a(), @r###" + Err( + Error { + cycle: [ + "cycle_a(())", + "cycle_b(())", + ], + }, + ) + "###); +} + +#[test] +fn cycle_appears() { + let mut db = DatabaseImpl::default(); + + // A --> B + db.set_a_invokes(CycleQuery::B); + db.set_b_invokes(CycleQuery::None); + assert!(db.cycle_a().is_ok()); + + // A --> B + // ^ | + // +-----+ + db.set_b_invokes(CycleQuery::A); + log::debug!("Set Cycle Leaf"); + assert!(db.cycle_a().is_err()); +} + +#[test] +fn cycle_disappears() { + let mut db = DatabaseImpl::default(); + + // A --> B + // ^ | + // +-----+ + db.set_a_invokes(CycleQuery::B); + db.set_b_invokes(CycleQuery::A); + assert!(db.cycle_a().is_err()); + + // A --> B + db.set_b_invokes(CycleQuery::None); + assert!(db.cycle_a().is_ok()); +} + +/// A variant on `cycle_disappears` in which the values of +/// `a_invokes` and `b_invokes` are set with durability values. +/// If we are not careful, this could cause us to overlook +/// the fact that the cycle will no longer occur. +#[test] +fn cycle_disappears_durability() { + let mut db = DatabaseImpl::default(); + db.set_a_invokes_with_durability(CycleQuery::B, Durability::LOW); + db.set_b_invokes_with_durability(CycleQuery::A, Durability::HIGH); + + let res = db.cycle_a(); + assert!(res.is_err()); + + // At this point, `a` read `LOW` input, and `b` read `HIGH` input. However, + // because `b` participates in the same cycle as `a`, its final durability + // should be `LOW`. + // + // Check that setting a `LOW` input causes us to re-execute `b` query, and + // observe that the cycle goes away. 
+ db.set_a_invokes_with_durability(CycleQuery::None, Durability::LOW); + + let res = db.cycle_b(); + assert!(res.is_ok()); +} + +#[test] +fn cycle_mixed_1() { + let mut db = DatabaseImpl::default(); + // A --> B <-- C + // | ^ + // +-----+ + db.set_a_invokes(CycleQuery::B); + db.set_b_invokes(CycleQuery::C); + db.set_c_invokes(CycleQuery::B); + + let u = db.cycle_c(); + insta::assert_debug_snapshot!(u, @r###" + Err( + Error { + cycle: [ + "cycle_b(())", + "cycle_c(())", + ], + }, + ) + "###); +} + +#[test] +fn cycle_mixed_2() { + let mut db = DatabaseImpl::default(); + + // Configuration: + // + // A --> B --> C + // ^ | + // +-----------+ + db.set_a_invokes(CycleQuery::B); + db.set_b_invokes(CycleQuery::C); + db.set_c_invokes(CycleQuery::A); + + let u = db.cycle_a(); + insta::assert_debug_snapshot!(u, @r###" + Err( + Error { + cycle: [ + "cycle_a(())", + "cycle_b(())", + "cycle_c(())", + ], + }, + ) + "###); +} + +#[test] +fn cycle_deterministic_order() { + // No matter whether we start from A or B, we get the same set of participants: + let db = || { + let mut db = DatabaseImpl::default(); + // A --> B + // ^ | + // +-----+ + db.set_a_invokes(CycleQuery::B); + db.set_b_invokes(CycleQuery::A); + db + }; + let a = db().cycle_a(); + let b = db().cycle_b(); + insta::assert_debug_snapshot!((a, b), @r###" + ( + Err( + Error { + cycle: [ + "cycle_a(())", + "cycle_b(())", + ], + }, + ), + Err( + Error { + cycle: [ + "cycle_a(())", + "cycle_b(())", + ], + }, + ), + ) + "###); +} + +#[test] +fn cycle_multiple() { + // No matter whether we start from A or B, we get the same set of participants: + let mut db = DatabaseImpl::default(); + + // Configuration: + // + // A --> B <-- C + // ^ | ^ + // +-----+ | + // | | + // +-----+ + // + // Here, conceptually, B encounters a cycle with A and then + // recovers. 
+ db.set_a_invokes(CycleQuery::B); + db.set_b_invokes(CycleQuery::AthenC); + db.set_c_invokes(CycleQuery::B); + + let c = db.cycle_c(); + let b = db.cycle_b(); + let a = db.cycle_a(); + insta::assert_debug_snapshot!((a, b, c), @r###" + ( + Err( + Error { + cycle: [ + "cycle_a(())", + "cycle_b(())", + ], + }, + ), + Err( + Error { + cycle: [ + "cycle_a(())", + "cycle_b(())", + ], + }, + ), + Err( + Error { + cycle: [ + "cycle_a(())", + "cycle_b(())", + ], + }, + ), + ) + "###); +} + +#[test] +fn cycle_recovery_set_but_not_participating() { + let mut db = DatabaseImpl::default(); + + // A --> C -+ + // ^ | + // +--+ + db.set_a_invokes(CycleQuery::C); + db.set_c_invokes(CycleQuery::C); + + // Here we expect C to panic and A not to recover: + let r = extract_cycle(|| drop(db.cycle_a())); + insta::assert_debug_snapshot!(r.all_participants(&db), @r###" + [ + "cycle_c(())", + ] + "###); } diff --git a/tests/incremental/implementation.rs b/tests/incremental/implementation.rs index 23c36b5b..a9c0a4a0 100644 --- a/tests/incremental/implementation.rs +++ b/tests/incremental/implementation.rs @@ -24,6 +24,7 @@ pub(crate) struct TestContextImpl { } impl TestContextImpl { + #[track_caller] pub(crate) fn assert_log(&self, expected_log: &[&str]) { let expected_text = &format!("{:#?}", expected_log); let actual_text = &format!("{:#?}", self.log().take()); diff --git a/tests/incremental/log.rs b/tests/incremental/log.rs index a27a60b2..1ee57fe6 100644 --- a/tests/incremental/log.rs +++ b/tests/incremental/log.rs @@ -11,6 +11,6 @@ impl Log { } pub(crate) fn take(&self) -> Vec { - std::mem::replace(&mut *self.data.borrow_mut(), vec![]) + self.data.take() } } diff --git a/tests/incremental/memoized_volatile.rs b/tests/incremental/memoized_volatile.rs index 203c441a..6dc50300 100644 --- a/tests/incremental/memoized_volatile.rs +++ b/tests/incremental/memoized_volatile.rs @@ -60,7 +60,7 @@ fn revalidate() { // will not (still 0, as 1/2 = 0) query.salsa_runtime_mut().synthetic_write(Durability::LOW); query.memoized2(); - query.assert_log(&["Memoized1 invoked", "Volatile invoked"]); + query.assert_log(&["Volatile invoked", "Memoized1 invoked"]); query.memoized2(); query.assert_log(&[]); @@ -70,7 +70,7 @@ fn revalidate() { query.salsa_runtime_mut().synthetic_write(Durability::LOW); query.memoized2(); - query.assert_log(&["Memoized1 invoked", "Volatile invoked", "Memoized2 invoked"]); + query.assert_log(&["Volatile invoked", "Memoized1 invoked", "Memoized2 invoked"]); query.memoized2(); query.assert_log(&[]); diff --git a/tests/no_send_sync.rs b/tests/no_send_sync.rs index 4d4555ad..2648f2b7 100644 --- a/tests/no_send_sync.rs +++ b/tests/no_send_sync.rs @@ -29,5 +29,5 @@ fn no_send_sync() { let db = DatabaseImpl::default(); assert_eq!(db.no_send_sync_value(true), Rc::new(true)); - assert_eq!(db.no_send_sync_key(Rc::new(false)), false); + assert!(!db.no_send_sync_key(Rc::new(false))); } diff --git a/tests/on_demand_inputs.rs b/tests/on_demand_inputs.rs index 0d4b1987..99092025 100644 --- a/tests/on_demand_inputs.rs +++ b/tests/on_demand_inputs.rs @@ -4,9 +4,9 @@ //! via a b query with zero inputs, which uses `add_synthetic_read` to //! tweak durability and `invalidate` to clear the input. 
-use std::{cell::Cell, collections::HashMap, rc::Rc}; +use std::{cell::RefCell, collections::HashMap, rc::Rc}; -use salsa::{Database as _, Durability}; +use salsa::{Database as _, Durability, EventKind}; #[salsa::query_group(QueryGroupStorage)] trait QueryGroup: salsa::Database + AsRef> { @@ -39,13 +39,15 @@ fn c(db: &dyn QueryGroup, x: u32) -> u32 { struct Database { storage: salsa::Storage, external_state: HashMap, - on_event: Option>, + on_event: Option>, } impl salsa::Database for Database { fn salsa_event(&self, event: salsa::Event) { + dbg!(event.debug(self)); + if let Some(cb) = &self.on_event { - cb(event) + cb(self, event) } } } @@ -84,30 +86,68 @@ fn on_demand_input_works() { #[test] fn on_demand_input_durability() { let mut db = Database::default(); - db.external_state.insert(1, 10); - db.external_state.insert(2, 20); - assert_eq!(db.b(1), 10); - assert_eq!(db.b(2), 20); - let validated = Rc::new(Cell::new(0)); + let events = Rc::new(RefCell::new(vec![])); db.on_event = Some(Box::new({ - let validated = Rc::clone(&validated); - move |event| { - if let salsa::EventKind::DidValidateMemoizedValue { .. } = event.kind { - validated.set(validated.get() + 1) + let events = events.clone(); + move |db, event| { + if let EventKind::WillCheckCancellation = event.kind { + // these events are not interesting + } else { + events.borrow_mut().push(format!("{:?}", event.debug(db))) } } })); + events.replace(vec![]); + db.external_state.insert(1, 10); + db.external_state.insert(2, 20); + assert_eq!(db.b(1), 10); + assert_eq!(db.b(2), 20); + insta::assert_debug_snapshot!(events, @r###" + RefCell { + value: [ + "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: b(1) } }", + "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(1) } }", + "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: b(2) } }", + "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(2) } }", + ], + } + "###); + + eprintln!("------------------"); db.salsa_runtime_mut().synthetic_write(Durability::LOW); - validated.set(0); + events.replace(vec![]); assert_eq!(db.c(1), 10); assert_eq!(db.c(2), 20); - assert_eq!(validated.get(), 2); + // Re-execute `a(2)` because that has low durability, but not `a(1)` + insta::assert_debug_snapshot!(events, @r###" + RefCell { + value: [ + "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: c(1) } }", + "Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: b(1) } }", + "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: c(2) } }", + "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(2) } }", + "Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: b(2) } }", + ], + } + "###); + eprintln!("------------------"); db.salsa_runtime_mut().synthetic_write(Durability::HIGH); - validated.set(0); + events.replace(vec![]); assert_eq!(db.c(1), 10); assert_eq!(db.c(2), 20); - assert_eq!(validated.get(), 4); + // Re-execute both `a(1)` and `a(2)`, but we don't re-execute any `b` queries as the + // result didn't actually change. 
+ insta::assert_debug_snapshot!(events, @r###" + RefCell { + value: [ + "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(1) } }", + "Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: c(1) } }", + "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(2) } }", + "Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: c(2) } }", + ], + } + "###); } diff --git a/tests/panic_safely.rs b/tests/panic_safely.rs index 551e6c7a..e51a74e1 100644 --- a/tests/panic_safely.rs +++ b/tests/panic_safely.rs @@ -70,7 +70,7 @@ fn should_panic_safely() { db.set_one(1); db.outer(); - assert_eq!(OUTER_CALLS.load(SeqCst), 1); + assert_eq!(OUTER_CALLS.load(SeqCst), 2); } } diff --git a/tests/parallel/main.rs b/tests/parallel/main.rs index 9c6e5360..31c0da18 100644 --- a/tests/parallel/main.rs +++ b/tests/parallel/main.rs @@ -3,6 +3,10 @@ mod setup; mod cancellation; mod frozen; mod independent; +mod parallel_cycle_all_recover; +mod parallel_cycle_mid_recover; +mod parallel_cycle_none_recover; +mod parallel_cycle_one_recovers; mod race; mod signal; mod stress; diff --git a/tests/parallel/parallel_cycle_all_recover.rs b/tests/parallel/parallel_cycle_all_recover.rs new file mode 100644 index 00000000..f3e5609d --- /dev/null +++ b/tests/parallel/parallel_cycle_all_recover.rs @@ -0,0 +1,110 @@ +//! Test for cycle recover spread across two threads. +//! See `../cycles.rs` for a complete listing of cycle tests, +//! both intra and cross thread. + +use crate::setup::{Knobs, ParDatabaseImpl}; +use salsa::ParallelDatabase; +use test_env_log::test; + +// Recover cycle test: +// +// The pattern is as follows. +// +// Thread A Thread B +// -------- -------- +// a1 b1 +// | wait for stage 1 (blocks) +// signal stage 1 | +// wait for stage 2 (blocks) (unblocked) +// | signal stage 2 +// (unblocked) wait for stage 3 (blocks) +// a2 | +// b1 (blocks -> stage 3) | +// | (unblocked) +// | b2 +// | a1 (cycle detected, recovers) +// | b2 completes, recovers +// | b1 completes, recovers +// a2 sees cycle, recovers +// a1 completes, recovers + +#[test] +fn parallel_cycle_all_recover() { + let db = ParDatabaseImpl::default(); + db.knobs().signal_on_will_block.set(3); + + let thread_a = std::thread::spawn({ + let db = db.snapshot(); + move || db.a1(1) + }); + + let thread_b = std::thread::spawn({ + let db = db.snapshot(); + move || db.b1(1) + }); + + assert_eq!(thread_a.join().unwrap(), 11); + assert_eq!(thread_b.join().unwrap(), 21); +} + +#[salsa::query_group(ParallelCycleAllRecover)] +pub(crate) trait TestDatabase: Knobs { + #[salsa::cycle(recover_a1)] + fn a1(&self, key: i32) -> i32; + + #[salsa::cycle(recover_a2)] + fn a2(&self, key: i32) -> i32; + + #[salsa::cycle(recover_b1)] + fn b1(&self, key: i32) -> i32; + + #[salsa::cycle(recover_b2)] + fn b2(&self, key: i32) -> i32; +} + +fn recover_a1(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 { + log::debug!("recover_a1"); + key * 10 + 1 +} + +fn recover_a2(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 { + log::debug!("recover_a2"); + key * 10 + 2 +} + +fn recover_b1(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 { + log::debug!("recover_b1"); + key * 20 + 1 +} + +fn recover_b2(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 { + log::debug!("recover_b2"); + key * 20 + 2 +} + +fn a1(db: &dyn TestDatabase, key: i32) -> i32 { + // Wait to create the cycle until both threads 
+    db.signal(1);
+    db.wait_for(2);
+
+    db.a2(key)
+}
+
+fn a2(db: &dyn TestDatabase, key: i32) -> i32 {
+    db.b1(key)
+}
+
+fn b1(db: &dyn TestDatabase, key: i32) -> i32 {
+    // Wait to create the cycle until both threads have entered
+    db.wait_for(1);
+    db.signal(2);
+
+    // Wait for thread A to block on this thread
+    db.wait_for(3);
+
+    db.b2(key)
+}
+
+fn b2(db: &dyn TestDatabase, key: i32) -> i32 {
+    db.a1(key)
+}
diff --git a/tests/parallel/parallel_cycle_mid_recover.rs b/tests/parallel/parallel_cycle_mid_recover.rs
new file mode 100644
index 00000000..80edba1c
--- /dev/null
+++ b/tests/parallel/parallel_cycle_mid_recover.rs
@@ -0,0 +1,110 @@
+//! Test for cycle recover spread across two threads.
+//! See `../cycles.rs` for a complete listing of cycle tests,
+//! both intra and cross thread.
+
+use crate::setup::{Knobs, ParDatabaseImpl};
+use salsa::ParallelDatabase;
+use test_env_log::test;
+
+// Recover cycle test:
+//
+// The pattern is as follows.
+//
+// Thread A                        Thread B
+// --------                        --------
+// a1                              b1
+// |                               wait for stage 1 (blocks)
+// signal stage 1                  |
+// wait for stage 2 (blocks)       (unblocked)
+// |                               |
+// |                               b2
+// |                               b3
+// |                               a1 (blocks -> stage 2)
+// (unblocked)                     |
+// a2 (cycle detected)             |
+//                                 b3 recovers
+//                                 b2 resumes
+//                                 b1 panics because bug
+
+#[test]
+fn parallel_cycle_mid_recovers() {
+    let db = ParDatabaseImpl::default();
+    db.knobs().signal_on_will_block.set(2);
+
+    let thread_a = std::thread::spawn({
+        let db = db.snapshot();
+        move || db.a1(1)
+    });
+
+    let thread_b = std::thread::spawn({
+        let db = db.snapshot();
+        move || db.b1(1)
+    });
+
+    // We expect that the recovery function yields
+    // `1 * 20 + 2`, which is returned (and forwarded)
+    // to b1, and from there to a2 and a1.
+    assert_eq!(thread_a.join().unwrap(), 22);
+    assert_eq!(thread_b.join().unwrap(), 22);
+}
+
+#[salsa::query_group(ParallelCycleMidRecovers)]
+pub(crate) trait TestDatabase: Knobs {
+    fn a1(&self, key: i32) -> i32;
+
+    fn a2(&self, key: i32) -> i32;
+
+    #[salsa::cycle(recover_b1)]
+    fn b1(&self, key: i32) -> i32;
+
+    fn b2(&self, key: i32) -> i32;
+
+    #[salsa::cycle(recover_b3)]
+    fn b3(&self, key: i32) -> i32;
+}
+
+fn recover_b1(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
+    log::debug!("recover_b1");
+    key * 20 + 2
+}
+
+fn recover_b3(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
+    log::debug!("recover_b3");
+    key * 200 + 2
+}
+
+fn a1(db: &dyn TestDatabase, key: i32) -> i32 {
+    // tell thread b we have started
+    db.signal(1);
+
+    // wait for thread b to block on a1
+    db.wait_for(2);
+
+    db.a2(key)
+}
+
+fn a2(db: &dyn TestDatabase, key: i32) -> i32 {
+    // create the cycle
+    db.b1(key)
+}
+
+fn b1(db: &dyn TestDatabase, key: i32) -> i32 {
+    // wait for thread a to have started
+    db.wait_for(1);
+
+    db.b2(key);
+
+    0
+}
+
+fn b2(db: &dyn TestDatabase, key: i32) -> i32 {
+    // will encounter a cycle but recover
+    db.b3(key);
+    db.b1(key); // hasn't recovered yet
+    0
+}
+
+fn b3(db: &dyn TestDatabase, key: i32) -> i32 {
+    // will block on thread a, signaling stage 2
+    db.a1(key)
+}
diff --git a/tests/parallel/parallel_cycle_none_recover.rs b/tests/parallel/parallel_cycle_none_recover.rs
new file mode 100644
index 00000000..9cb8f2f9
--- /dev/null
+++ b/tests/parallel/parallel_cycle_none_recover.rs
@@ -0,0 +1,71 @@
+//! Test a cross-thread cycle in which none of the participating queries recover.
+//! See `../cycles.rs` for a complete listing of cycle tests,
+//! both intra and cross thread.
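+//!
+//! Neither `a` nor `b` carries a `#[salsa::recover]` annotation, so the thread
+//! that closes the cycle panics with a `salsa::Cycle` value and the other
+//! thread unwinds with the `Canceled` sentinel.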
+
+use crate::setup::{Knobs, ParDatabaseImpl};
+use salsa::ParallelDatabase;
+use test_env_log::test;
+
+#[test]
+fn parallel_cycle_none_recover() {
+    let db = ParDatabaseImpl::default();
+    db.knobs().signal_on_will_block.set(3);
+
+    let thread_a = std::thread::spawn({
+        let db = db.snapshot();
+        move || db.a(-1)
+    });
+
+    let thread_b = std::thread::spawn({
+        let db = db.snapshot();
+        move || db.b(-1)
+    });
+
+    // We expect B to panic because it detects a cycle (it is the one that calls A, ultimately).
+    // The panic value is a `salsa::Cycle`.
+    let err_b = thread_b.join().unwrap_err();
+    if let Some(c) = err_b.downcast_ref::<salsa::Cycle>() {
+        insta::assert_debug_snapshot!(c.unexpected_participants(&db), @r###"
+        [
+            "a(-1)",
+            "b(-1)",
+        ]
+        "###);
+    } else {
+        panic!("b failed in an unexpected way: {:?}", err_b);
+    }
+
+    // We expect A to propagate a panic, which causes us to use the sentinel
+    // type `Canceled`.
+    assert!(thread_a
+        .join()
+        .unwrap_err()
+        .downcast_ref::<salsa::Canceled>()
+        .is_some());
+}
+
+#[salsa::query_group(ParallelCycleNoneRecover)]
+pub(crate) trait TestDatabase: Knobs {
+    fn a(&self, key: i32) -> i32;
+    fn b(&self, key: i32) -> i32;
+}
+
+fn a(db: &dyn TestDatabase, key: i32) -> i32 {
+    // Wait to create the cycle until both threads have entered
+    db.signal(1);
+    db.wait_for(2);
+
+    db.b(key)
+}
+
+fn b(db: &dyn TestDatabase, key: i32) -> i32 {
+    // Wait to create the cycle until both threads have entered
+    db.wait_for(1);
+    db.signal(2);
+
+    // Wait for thread A to block on this thread
+    db.wait_for(3);
+
+    // Now try to execute A
+    db.a(key)
+}
diff --git a/tests/parallel/parallel_cycle_one_recovers.rs b/tests/parallel/parallel_cycle_one_recovers.rs
new file mode 100644
index 00000000..d7ae78b8
--- /dev/null
+++ b/tests/parallel/parallel_cycle_one_recovers.rs
@@ -0,0 +1,95 @@
+//! Test for cycle recover spread across two threads.
+//! See `../cycles.rs` for a complete listing of cycle tests,
+//! both intra and cross thread.
+
+use crate::setup::{Knobs, ParDatabaseImpl};
+use salsa::ParallelDatabase;
+use test_env_log::test;
+
+// Recover cycle test:
+//
+// The pattern is as follows.
+//
+// Thread A                        Thread B
+// --------                        --------
+// a1                              b1
+// |                               wait for stage 1 (blocks)
+// signal stage 1                  |
+// wait for stage 2 (blocks)       (unblocked)
+// |                               signal stage 2
+// (unblocked)                     wait for stage 3 (blocks)
+// a2                              |
+// b1 (blocks -> stage 3)          |
+// |                               (unblocked)
+// |                               b2
+// |                               a1 (cycle detected)
+// a2 recovery fn executes         |
+// a1 completes normally           |
+//                                 b2 completes, recovers
+//                                 b1 completes, recovers
+
+#[test]
+fn parallel_cycle_one_recovers() {
+    let db = ParDatabaseImpl::default();
+    db.knobs().signal_on_will_block.set(3);
+
+    let thread_a = std::thread::spawn({
+        let db = db.snapshot();
+        move || db.a1(1)
+    });
+
+    let thread_b = std::thread::spawn({
+        let db = db.snapshot();
+        move || db.b1(1)
+    });
+
+    // We expect that the recovery function on `a2` yields
+    // `1 * 20 + 2`, which is returned to a1, and from there
+    // forwarded to b2 and b1.
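+    // (Note that `b1` and `b2` carry no `#[salsa::recover]` annotation: they are
+    // not aborted, they simply complete using the value recovered by `a2`.)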
+    assert_eq!(thread_a.join().unwrap(), 22);
+    assert_eq!(thread_b.join().unwrap(), 22);
+}
+
+#[salsa::query_group(ParallelCycleOneRecovers)]
+pub(crate) trait TestDatabase: Knobs {
+    fn a1(&self, key: i32) -> i32;
+
+    #[salsa::cycle(recover)]
+    fn a2(&self, key: i32) -> i32;
+
+    fn b1(&self, key: i32) -> i32;
+
+    fn b2(&self, key: i32) -> i32;
+}
+
+fn recover(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
+    log::debug!("recover");
+    key * 20 + 2
+}
+
+fn a1(db: &dyn TestDatabase, key: i32) -> i32 {
+    // Wait to create the cycle until both threads have entered
+    db.signal(1);
+    db.wait_for(2);
+
+    db.a2(key)
+}
+
+fn a2(db: &dyn TestDatabase, key: i32) -> i32 {
+    db.b1(key)
+}
+
+fn b1(db: &dyn TestDatabase, key: i32) -> i32 {
+    // Wait to create the cycle until both threads have entered
+    db.wait_for(1);
+    db.signal(2);
+
+    // Wait for thread A to block on this thread
+    db.wait_for(3);
+
+    db.b2(key)
+}
+
+fn b2(db: &dyn TestDatabase, key: i32) -> i32 {
+    db.a1(key)
+}
diff --git a/tests/parallel/setup.rs b/tests/parallel/setup.rs
index cc3be434..963c93cd 100644
--- a/tests/parallel/setup.rs
+++ b/tests/parallel/setup.rs
@@ -157,7 +157,13 @@ fn sum3_drop_sum(db: &dyn ParDatabase, key: &'static str) -> usize {
     db.sum2_drop_sum(key)
 }
 
-#[salsa::database(Par)]
+#[salsa::database(
+    Par,
+    crate::parallel_cycle_all_recover::ParallelCycleAllRecover,
+    crate::parallel_cycle_none_recover::ParallelCycleNoneRecover,
+    crate::parallel_cycle_mid_recover::ParallelCycleMidRecovers,
+    crate::parallel_cycle_one_recovers::ParallelCycleOneRecovers
+)]
 #[derive(Default)]
 pub(crate) struct ParDatabaseImpl {
     storage: salsa::Storage<Self>,