Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(dfir_lang): add repeat_n(n) windowing operator, modify scheduler #1596

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion dfir_lang/src/graph/ops/batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ pub const BATCH: OperatorConstraints = OperatorConstraints {
let input = &inputs[0];
quote_spanned! {op_span=>
let mut #vec_ident = #context.state_ref(#singleton_output_ident).borrow_mut();
*#vec_ident = #input.collect::<::std::vec::Vec<_>>();
if #context.is_first_run_this_tick() {
*#vec_ident = #input.collect::<::std::vec::Vec<_>>();
}
let #ident = ::std::iter::once(::std::clone::Clone::clone(&*#vec_ident));
}
} else if let Some(_output) = outputs.first() {
Expand Down
1 change: 1 addition & 0 deletions dfir_lang/src/graph/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ declare_ops![
persist_mut_keyed::PERSIST_MUT_KEYED,
py_udf::PY_UDF,
reduce::REDUCE,
repeat_n::REPEAT_N,
spin::SPIN,
sort::SORT,
sort_by_key::SORT_BY_KEY,
Expand Down
54 changes: 54 additions & 0 deletions dfir_lang/src/graph/ops/repeat_n.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
use quote::quote_spanned;

use super::{OperatorConstraints, OperatorWriteOutput, WriteContextArgs};

/// Windowing operator that makes its subgraph run repeatedly within a single tick.
///
/// Takes exactly one argument: the repeat count. The generated code is whatever `all_once`
/// generates (all other operator constraints are inherited from `ALL_ONCE` too), extended with:
///
/// * a `Cell<usize>` counter registered as operator state, starting at `0`, with a state tick
///   hook that `take()`s the cell (leaving it at `0` again);
/// * a trailing snippet that computes `counter + 1` and, while that value is still below the
///   argument, stores it and asks the context to reschedule the current subgraph.
///
/// NOTE(review): the comparison is emitted in `write_iterator_after`, so an argument of `0`
/// presumably still results in one run -- confirm that this is intended.
///
/// TODO(mingwei): user-facing docs and an example.
pub const REPEAT_N: OperatorConstraints = OperatorConstraints {
name: "repeat_n",
num_args: 1,
write_fn: |wc @ &WriteContextArgs {
context,
hydroflow,
op_span,
arguments,
..
},
diagnostics| {
// Start from the code `all_once` would emit; the pieces are wrapped/extended below.
let OperatorWriteOutput {
write_prologue,
write_iterator,
write_iterator_after,
} = (super::all_once::ALL_ONCE.write_fn)(wc, diagnostics)?;

// Identifier for the state handle of the repetition counter.
let count_ident = wc.make_ident("count");

// Prologue: keep the `all_once` prologue, then register the counter state and a tick hook
// that resets it (`Cell::take` leaves the default value, `0`, behind).
let write_prologue = quote_spanned! {op_span=>
#write_prologue

let #count_ident = #hydroflow.add_state(::std::cell::Cell::new(0_usize));
#hydroflow.set_state_tick_hook(#count_ident, move |cell| { cell.take(); });
};

// Reschedule, to repeat.
let count_arg = &arguments[0];
// Appended after the `all_once` trailing code. The counter is only written back when another
// run is requested; on the last run it is left as-is and relies on the tick hook for the reset.
let write_iterator_after = quote_spanned! {op_span=>
#write_iterator_after

{
let count_ref = #context.state_ref(#count_ident);
let count = count_ref.get() + 1;
if count < #count_arg {
count_ref.set(count);
#context.reschedule_current_subgraph();
}
}
};

// `write_iterator` is passed through from `all_once` unchanged.
Ok(OperatorWriteOutput {
write_prologue,
write_iterator,
write_iterator_after,
})
},
..super::all_once::ALL_ONCE
};
21 changes: 17 additions & 4 deletions dfir_rs/src/scheduled/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
//! Provides APIs for state and scheduling.

use std::any::Any;
use std::cell::Cell;
use std::collections::VecDeque;
use std::future::Future;
use std::marker::PhantomData;
Expand Down Expand Up @@ -36,10 +37,13 @@ pub struct Context {
/// If the events have been received for this tick.
pub(super) events_received_tick: bool,

// TODO(mingwei): as long as this is here, it's impossible to know when all work is done.
// Second field (bool) is for if the event is an external "important" event (true).
// TODO(mingwei): as long as this is unclosed, it's impossible to know when all work is done.
/// Second field (bool) is for if the event is an external "important" event (true).
pub(super) event_queue_send: UnboundedSender<(SubgraphId, bool)>,

/// If the current subgraph wants to reschedule in the current tick+stratum.
pub(super) reschedule_current_subgraph: Cell<bool>,

pub(super) current_tick: TickInstant,
pub(super) current_stratum: usize,

Expand All @@ -51,7 +55,6 @@ pub struct Context {
pub(super) subgraph_id: SubgraphId,

tasks_to_spawn: Vec<Pin<Box<dyn Future<Output = ()> + 'static>>>,

/// Join handles for spawned tasks.
task_join_handles: Vec<JoinHandle<()>>,
}
Expand Down Expand Up @@ -85,11 +88,20 @@ impl Context {
self.subgraph_id
}

/// Schedules a subgraph.
/// Schedules a subgraph for the next tick.
///
/// If `is_external` is `true`, the scheduling will trigger the next tick to begin. If it is
/// `false` then scheduling will be lazy and the next tick will not begin unless there is other
/// reason to.
///
/// # Panics
///
/// Panics if sending on the event queue returns an error (the result is unwrapped) --
/// presumably only when the receiving half has been dropped; verify against the scheduler.
pub fn schedule_subgraph(&self, sg_id: SubgraphId, is_external: bool) {
// Enqueue the `(subgraph, is_external)` pair on the event channel.
self.event_queue_send.send((sg_id, is_external)).unwrap()
}

/// Schedules the current subgraph to run again _this tick_.
///
/// This only raises the `reschedule_current_subgraph` flag on the context; nothing is enqueued
/// here. Calling it more than once before the flag is consumed has no additional effect.
pub fn reschedule_current_subgraph(&self) {
// Interior mutability (`Cell`) lets this be requested through a shared `&self`.
self.reschedule_current_subgraph.set(true);
}

/// Returns a `Waker` for interacting with async Rust.
/// Waker events are considered to be external.
pub fn waker(&self) -> std::task::Waker {
Expand Down Expand Up @@ -231,6 +243,7 @@ impl Default for Context {
events_received_tick: false,

event_queue_send,
reschedule_current_subgraph: Cell::new(false),

current_stratum: 0,
current_tick: TickInstant::default(),
Expand Down
9 changes: 9 additions & 0 deletions dfir_rs/src/scheduled/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ impl<'a> Dfir<'a> {
}

let sg_data = &self.subgraphs[sg_id.0];
debug_assert_eq!(self.context.current_stratum, sg_data.stratum);
for &handoff_id in sg_data.succs.iter() {
let handoff = &self.handoffs[handoff_id.0];
if !handoff.handoff.is_bottom() {
Expand All @@ -298,6 +299,14 @@ impl<'a> Dfir<'a> {
}
}
}

// Check if subgraph wants rescheduling
if self.context.reschedule_current_subgraph.take() {
// Add subgraph to stratum queue if it is not already scheduled.
if !sg_data.is_scheduled.replace(true) {
self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
}
}
}
work_done
}
Expand Down
19 changes: 11 additions & 8 deletions dfir_rs/tests/snapshots/surface_loop__flo_nested@graphvis_dot.snap
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,23 @@ digraph {
n6v1 [label="(n6v1) flatten()", shape=invhouse, fillcolor="#88aaff"]
n7v1 [label="(n7v1) cross_join::<'static, 'tick>()", shape=invhouse, fillcolor="#88aaff"]
n8v1 [label="(n8v1) all_once()", shape=invhouse, fillcolor="#88aaff"]
n9v1 [label="(n9v1) for_each(|all| println!(\"{}: {:?}\", context.current_tick(), all))", shape=house, fillcolor="#ffff88"]
n10v1 [label="(n10v1) handoff", shape=parallelogram, fillcolor="#ddddff"]
n9v1 [label="(n9v1) map(|vec| (context.current_tick().0, vec))", shape=invhouse, fillcolor="#88aaff"]
n10v1 [label="(n10v1) assert_eq([\l (\l 0,\l vec![\l (\"alice\", 0),\l (\"alice\", 1),\l (\"alice\", 2),\l (\"bob\", 0),\l (\"bob\", 1),\l (\"bob\", 2),\l ],\l ),\l (\l 1,\l vec![\l (\"alice\", 3),\l (\"alice\", 4),\l (\"alice\", 5),\l (\"bob\", 3),\l (\"bob\", 4),\l (\"bob\", 5),\l ],\l ),\l (\l 2,\l vec![\l (\"alice\", 6),\l (\"alice\", 7),\l (\"alice\", 8),\l (\"bob\", 6),\l (\"bob\", 7),\l (\"bob\", 8),\l ],\l ),\l (\l 3,\l vec![\l (\"alice\", 9),\l (\"alice\", 10),\l (\"alice\", 11),\l (\"bob\", 9),\l (\"bob\", 10),\l (\"bob\", 11),\l ],\l ),\l])\l", shape=house, fillcolor="#ffff88"]
n11v1 [label="(n11v1) handoff", shape=parallelogram, fillcolor="#ddddff"]
n12v1 [label="(n12v1) handoff", shape=parallelogram, fillcolor="#ddddff"]
n13v1 [label="(n13v1) handoff", shape=parallelogram, fillcolor="#ddddff"]
n4v1 -> n7v1 [label="0"]
n3v1 -> n4v1
n1v1 -> n10v1
n1v1 -> n11v1
n6v1 -> n7v1 [label="1"]
n5v1 -> n6v1
n2v1 -> n11v1
n2v1 -> n12v1
n9v1 -> n10v1
n8v1 -> n9v1
n7v1 -> n12v1
n10v1 -> n3v1
n11v1 -> n5v1
n12v1 -> n8v1 [color=red]
n7v1 -> n13v1
n11v1 -> n3v1
n12v1 -> n5v1
n13v1 -> n8v1 [color=red]
subgraph "cluster n1v1" {
fillcolor="#dddddd"
style=filled
Expand Down Expand Up @@ -68,5 +70,6 @@ digraph {
label = "sg_4v1\nstratum 1"
n8v1
n9v1
n10v1
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,23 @@ linkStyle default stroke:#aaa
6v1[\"(6v1) <code>flatten()</code>"/]:::pullClass
7v1[\"(7v1) <code>cross_join::&lt;'static, 'tick&gt;()</code>"/]:::pullClass
8v1[\"(8v1) <code>all_once()</code>"/]:::pullClass
9v1[/"(9v1) <code>for_each(|all| println!(&quot;{}: {:?}&quot;, context.current_tick(), all))</code>"\]:::pushClass
10v1["(10v1) <code>handoff</code>"]:::otherClass
9v1[\"(9v1) <code>map(|vec| (context.current_tick().0, vec))</code>"/]:::pullClass
10v1[/"<div style=text-align:center>(10v1)</div> <code>assert_eq([<br> (<br> 0,<br> vec![<br> (&quot;alice&quot;, 0),<br> (&quot;alice&quot;, 1),<br> (&quot;alice&quot;, 2),<br> (&quot;bob&quot;, 0),<br> (&quot;bob&quot;, 1),<br> (&quot;bob&quot;, 2),<br> ],<br> ),<br> (<br> 1,<br> vec![<br> (&quot;alice&quot;, 3),<br> (&quot;alice&quot;, 4),<br> (&quot;alice&quot;, 5),<br> (&quot;bob&quot;, 3),<br> (&quot;bob&quot;, 4),<br> (&quot;bob&quot;, 5),<br> ],<br> ),<br> (<br> 2,<br> vec![<br> (&quot;alice&quot;, 6),<br> (&quot;alice&quot;, 7),<br> (&quot;alice&quot;, 8),<br> (&quot;bob&quot;, 6),<br> (&quot;bob&quot;, 7),<br> (&quot;bob&quot;, 8),<br> ],<br> ),<br> (<br> 3,<br> vec![<br> (&quot;alice&quot;, 9),<br> (&quot;alice&quot;, 10),<br> (&quot;alice&quot;, 11),<br> (&quot;bob&quot;, 9),<br> (&quot;bob&quot;, 10),<br> (&quot;bob&quot;, 11),<br> ],<br> ),<br>])</code>"\]:::pushClass
11v1["(11v1) <code>handoff</code>"]:::otherClass
12v1["(12v1) <code>handoff</code>"]:::otherClass
13v1["(13v1) <code>handoff</code>"]:::otherClass
4v1-->|0|7v1
3v1-->4v1
1v1-->10v1
1v1-->11v1
6v1-->|1|7v1
5v1-->6v1
2v1-->11v1
2v1-->12v1
9v1-->10v1
8v1-->9v1
7v1-->12v1
10v1-->3v1
11v1-->5v1
12v1--x8v1; linkStyle 10 stroke:red
7v1-->13v1
11v1-->3v1
12v1-->5v1
13v1--x8v1; linkStyle 11 stroke:red
subgraph sg_1v1 ["sg_1v1 stratum 0"]
1v1
subgraph sg_1v1_var_users ["var <tt>users</tt>"]
Expand All @@ -56,4 +58,5 @@ end
subgraph sg_4v1 ["sg_4v1 stratum 1"]
8v1
9v1
10v1
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
---
source: dfir_rs/tests/surface_loop.rs
expression: "df.meta_graph().unwrap().to_dot(& Default :: default())"
---
digraph {
node [fontname="Monaco,Menlo,Consolas,&quot;Droid Sans Mono&quot;,Inconsolata,&quot;Courier New&quot;,monospace", style=filled];
edge [fontname="Monaco,Menlo,Consolas,&quot;Droid Sans Mono&quot;,Inconsolata,&quot;Courier New&quot;,monospace"];
n1v1 [label="(n1v1) source_iter([\"alice\", \"bob\"])", shape=invhouse, fillcolor="#88aaff"]
n2v1 [label="(n2v1) source_stream(iter_batches_stream(0..12, 3))", shape=invhouse, fillcolor="#88aaff"]
n3v1 [label="(n3v1) batch()", shape=invhouse, fillcolor="#88aaff"]
n4v1 [label="(n4v1) flatten()", shape=invhouse, fillcolor="#88aaff"]
n5v1 [label="(n5v1) batch()", shape=invhouse, fillcolor="#88aaff"]
n6v1 [label="(n6v1) flatten()", shape=invhouse, fillcolor="#88aaff"]
n7v1 [label="(n7v1) cross_join::<'static, 'tick>()", shape=invhouse, fillcolor="#88aaff"]
n8v1 [label="(n8v1) repeat_n(3)", shape=invhouse, fillcolor="#88aaff"]
n9v1 [label="(n9v1) map(|vec| (context.current_tick().0, vec))", shape=invhouse, fillcolor="#88aaff"]
n10v1 [label="(n10v1) inspect(|x| println!(\"{:?}\", x))", shape=invhouse, fillcolor="#88aaff"]
n11v1 [label="(n11v1) assert_eq([\l (\l 0,\l vec![\l (\"alice\", 0),\l (\"alice\", 1),\l (\"alice\", 2),\l (\"bob\", 0),\l (\"bob\", 1),\l (\"bob\", 2),\l ],\l ),\l (\l 0,\l vec![\l (\"alice\", 0),\l (\"alice\", 1),\l (\"alice\", 2),\l (\"bob\", 0),\l (\"bob\", 1),\l (\"bob\", 2),\l ],\l ),\l (\l 0,\l vec![\l (\"alice\", 0),\l (\"alice\", 1),\l (\"alice\", 2),\l (\"bob\", 0),\l (\"bob\", 1),\l (\"bob\", 2),\l ],\l ),\l (\l 1,\l vec![\l (\"alice\", 3),\l (\"alice\", 4),\l (\"alice\", 5),\l (\"bob\", 3),\l (\"bob\", 4),\l (\"bob\", 5),\l ],\l ),\l (\l 1,\l vec![\l (\"alice\", 3),\l (\"alice\", 4),\l (\"alice\", 5),\l (\"bob\", 3),\l (\"bob\", 4),\l (\"bob\", 5),\l ],\l ),\l (\l 1,\l vec![\l (\"alice\", 3),\l (\"alice\", 4),\l (\"alice\", 5),\l (\"bob\", 3),\l (\"bob\", 4),\l (\"bob\", 5),\l ],\l ),\l (\l 2,\l vec![\l (\"alice\", 6),\l (\"alice\", 7),\l (\"alice\", 8),\l (\"bob\", 6),\l (\"bob\", 7),\l (\"bob\", 8),\l ],\l ),\l (\l 2,\l vec![\l (\"alice\", 6),\l (\"alice\", 7),\l (\"alice\", 8),\l (\"bob\", 6),\l (\"bob\", 7),\l (\"bob\", 8),\l ],\l ),\l (\l 2,\l vec![\l (\"alice\", 6),\l (\"alice\", 7),\l (\"alice\", 8),\l (\"bob\", 6),\l (\"bob\", 7),\l (\"bob\", 8),\l ],\l ),\l (\l 3,\l vec![\l (\"alice\", 9),\l (\"alice\", 10),\l (\"alice\", 11),\l (\"bob\", 9),\l (\"bob\", 10),\l (\"bob\", 11),\l ],\l ),\l (\l 3,\l vec![\l (\"alice\", 9),\l (\"alice\", 10),\l (\"alice\", 11),\l (\"bob\", 9),\l (\"bob\", 10),\l (\"bob\", 11),\l ],\l ),\l (\l 3,\l vec![\l (\"alice\", 9),\l (\"alice\", 10),\l (\"alice\", 11),\l (\"bob\", 9),\l (\"bob\", 10),\l (\"bob\", 11),\l ],\l ),\l])\l", shape=house, fillcolor="#ffff88"]
n12v1 [label="(n12v1) handoff", shape=parallelogram, fillcolor="#ddddff"]
n13v1 [label="(n13v1) handoff", shape=parallelogram, fillcolor="#ddddff"]
n14v1 [label="(n14v1) handoff", shape=parallelogram, fillcolor="#ddddff"]
n4v1 -> n7v1 [label="0"]
n3v1 -> n4v1
n1v1 -> n12v1
n6v1 -> n7v1 [label="1"]
n5v1 -> n6v1
n2v1 -> n13v1
n10v1 -> n11v1
n9v1 -> n10v1
n8v1 -> n9v1
n7v1 -> n14v1
n12v1 -> n3v1
n13v1 -> n5v1
n14v1 -> n8v1 [color=red]
subgraph "cluster n1v1" {
fillcolor="#dddddd"
style=filled
label = "sg_1v1\nstratum 0"
n1v1
subgraph "cluster_sg_1v1_var_users" {
label="var users"
n1v1
}
}
subgraph "cluster n2v1" {
fillcolor="#dddddd"
style=filled
label = "sg_2v1\nstratum 0"
n2v1
subgraph "cluster_sg_2v1_var_messages" {
label="var messages"
n2v1
}
}
subgraph "cluster n3v1" {
fillcolor="#dddddd"
style=filled
label = "sg_3v1\nstratum 0"
n3v1
n4v1
n5v1
n6v1
n7v1
subgraph "cluster_sg_3v1_var_cp" {
label="var cp"
n7v1
}
}
subgraph "cluster n4v1" {
fillcolor="#dddddd"
style=filled
label = "sg_4v1\nstratum 1"
n8v1
n9v1
n10v1
n11v1
}
}
Loading
Loading