From 1c35bb953863387abe418e19220567275bbdbfa9 Mon Sep 17 00:00:00 2001 From: Anton Reinhard Date: Wed, 4 Dec 2024 16:00:48 +0100 Subject: [PATCH] Rework FunctionCalls to allow multiple returned objects (#44) * Restructure FunctionCall * Make closures work again * Improve scheduler * Change Function Calls to mostly use Vectors instead of Tuples This makes small function calls slower, but large function calls much much faster to compile * WIP fix input assignment code world age problem * Remove trie workaround * Remove call_fc and execute functions * Remove superfluous _gen_access_expr and simplify _gen_local_init * Renaming of member variables * renaming/docs/removing unused interfaces fix type stability in closures * Work on performance problems with closures (again...) * Fix NodeSplit --- ext/devices/cuda/function.jl | 11 +- ext/devices/rocm/function.jl | 11 +- src/ComputableDAGs.jl | 8 +- src/code_gen/function.jl | 44 ++--- src/code_gen/tape_machine.jl | 313 +++++++++++++++----------------- src/code_gen/type.jl | 78 +++++++- src/code_gen/utils.jl | 41 +++-- src/devices/impl.jl | 67 ++++++- src/devices/interface.jl | 17 -- src/devices/numa/impl.jl | 23 --- src/estimator/global_metric.jl | 32 ++-- src/graph/interface.jl | 13 +- src/graph/mute.jl | 36 ++-- src/graph/print.jl | 10 +- src/graph/properties.jl | 4 +- src/graph/type.jl | 12 +- src/graph/validate.jl | 12 +- src/models/interface.jl | 40 ++-- src/node/type.jl | 16 +- src/node/validate.jl | 8 +- src/operation/apply.jl | 17 +- src/operation/clean.jl | 18 +- src/operation/find.jl | 46 ++--- src/operation/get.jl | 8 +- src/operation/iterate.jl | 10 +- src/operation/print.jl | 8 +- src/operation/utility.jl | 10 +- src/optimization/random_walk.jl | 10 +- src/optimization/reduce.jl | 4 +- src/optimization/split.jl | 4 +- src/properties/create.jl | 22 ++- src/properties/type.jl | 14 +- src/properties/utility.jl | 30 +-- src/scheduler/greedy.jl | 26 ++- src/scheduler/interface.jl | 8 - src/scheduler/type.jl | 16 +- src/task/compute.jl | 142 ++++++++++----- src/trie.jl | 32 +--- src/utils.jl | 35 +--- test/strassen_test.jl | 4 +- 40 files changed, 641 insertions(+), 619 deletions(-) diff --git a/ext/devices/cuda/function.jl b/ext/devices/cuda/function.jl index 7ccdf94..4c9caf8 100644 --- a/ext/devices/cuda/function.jl +++ b/ext/devices/cuda/function.jl @@ -4,26 +4,21 @@ function ComputableDAGs.kernel( machine = cpu_st() tape = ComputableDAGs.gen_tape(graph, instance, machine, context_module) - assign_inputs = Expr(:block, ComputableDAGs.expr_from_fc.(tape.inputAssignCode)...) + assign_inputs = Expr(:block, ComputableDAGs.expr_from_fc.(tape.input_assign_code)...) # TODO: use gen_function_body here code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.schedule)...) 
function_id = ComputableDAGs.to_var_name(UUIDs.uuid1(ComputableDAGs.rng[1]))
-    res_sym = eval(
-        ComputableDAGs._gen_access_expr(
-            ComputableDAGs.entry_device(tape.machine), tape.outputSymbol
-        ),
-    )
     expr = Meta.parse(
         "function compute_$(function_id)(input_vector, output_vector, n::Int64)
             id = (blockIdx().x - 1) * blockDim().x + threadIdx().x
             if (id > n)
                 return
             end
-            @inline data_input = input_vector[id]
+            @inline input = input_vector[id]
             $(assign_inputs)
             $code
-            @inline output_vector[id] = $res_sym
+            @inline output_vector[id] = $(tape.output_symbol)
             return nothing
         end"
     )
diff --git a/ext/devices/rocm/function.jl b/ext/devices/rocm/function.jl
index dc617e2..1567ce3 100644
--- a/ext/devices/rocm/function.jl
+++ b/ext/devices/rocm/function.jl
@@ -4,27 +4,22 @@ function ComputableDAGs.kernel(
     machine = cpu_st()
     tape = ComputableDAGs.gen_tape(graph, instance, machine, context_module)
 
-    assign_inputs = Expr(:block, ComputableDAGs.expr_from_fc.(tape.inputAssignCode)...)
+    assign_inputs = Expr(:block, ComputableDAGs.expr_from_fc.(tape.input_assign_code)...)
     # TODO use gen_function_body here
     code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.schedule)...)
 
     function_id = ComputableDAGs.to_var_name(UUIDs.uuid1(ComputableDAGs.rng[1]))
-    res_sym = eval(
-        ComputableDAGs._gen_access_expr(
-            ComputableDAGs.entry_device(tape.machine), tape.outputSymbol
-        ),
-    )
     expr = Meta.parse(
         "function compute_$(function_id)(input_vector, output_vector, n::Int64)
             id = (workgroupIdx().x - 1) * workgroupDim().x + workitemIdx().x
             if (id > n)
                 return
             end
-            @inline data_input = input_vector[id]
+            @inline input = input_vector[id]
             $(assign_inputs)
             $code
-            @inline output_vector[id] = $res_sym
+            @inline output_vector[id] = $(tape.output_symbol)
             return nothing
         end"
     )
diff --git a/src/ComputableDAGs.jl b/src/ComputableDAGs.jl
index a852e15..c0b8b7c 100644
--- a/src/ComputableDAGs.jl
+++ b/src/ComputableDAGs.jl
@@ -34,10 +34,7 @@ export reset_graph!
 export get_operations
 
 # code generation related
-export execute
 export get_compute_function
-export gen_tape, execute_tape
-export unpack_identity
 
 # estimator
 export cost_type, graph_cost, operation_effect
@@ -50,8 +47,7 @@ export optimize_step!, optimize!
 export fixpoint_reached, optimize_to_fixpoint!
 
 # models
-export AbstractModel, AbstractProblemInstance
-export problem_instance, input_type, graph, input_expr
+export graph, input_type, input_expr
 
 # machine info
 export Machine
@@ -69,6 +65,7 @@ include("properties/type.jl")
 include("operation/type.jl")
 include("graph/type.jl")
 include("scheduler/type.jl")
+include("code_gen/type.jl")
 
 include("trie.jl")
 include("utils.jl")
@@ -127,7 +124,6 @@ include("devices/ext.jl")
 include("scheduler/interface.jl")
 include("scheduler/greedy.jl")
 
-include("code_gen/type.jl")
 include("code_gen/utils.jl")
 include("code_gen/tape_machine.jl")
 include("code_gen/function.jl")
diff --git a/src/code_gen/function.jl b/src/code_gen/function.jl
index f8dbce0..6a957e8 100644
--- a/src/code_gen/function.jl
+++ b/src/code_gen/function.jl
@@ -17,50 +17,28 @@ in your top level.
 
 ## Keyword Arguments
 
-`closures_size` (default=0 (off)): The size of closures to use in the main generated code. This specifies the size of code blocks across which the compiler cannot optimize.
+`closures_size` (default=0 (off)): The size of closures to use in the main generated code. This specifies the size of code blocks across which the
+    compiler cannot optimize.
For sufficiently large functions, a larger value means longer compile times but potentially faster execution time.
+    **Note** that the closure size actually used may differ from the one passed here, since the function automatically chooses a size that
+    is close to an n-th root of the total number of lines of code, based on the given size. (For example, with 1000 function calls and
+    `closures_size=100`, a depth of `ceil(log(100, 1000)) == 2` is chosen, giving an actual closure size of `ceil(1000^(1/2)) == 32`.)
 """
 function get_compute_function(
     graph::DAG, instance, machine::Machine, context_module::Module; closures_size=0
 )
     tape = gen_tape(graph, instance, machine, context_module)
 
-    assignInputs = Expr(:block, expr_from_fc.(tape.inputAssignCode)...)
-    code = gen_function_body(tape; closures_size=closures_size)
+    code = gen_function_body(tape, context_module; closures_size=closures_size)
+    assign_inputs = Expr(:block, expr_from_fc.(tape.input_assign_code)...)
 
-    functionId = to_var_name(UUIDs.uuid1(rng[1]))
-    resSym = eval(_gen_access_expr(entry_device(tape.machine), tape.outputSymbol))
-    expr = #
-    Expr(
+    function_id = to_var_name(UUIDs.uuid1(rng[1]))
+    res_sym = tape.output_symbol
+    expr = Expr(
         :function, # function definition
         Expr(
-            :call,
-            Symbol("compute_$functionId"),
-            Expr(:(::), :data_input, input_type(instance)),
+            :call, Symbol("compute_$function_id"), Expr(:(::), :input, input_type(instance))
         ), # function name and parameters
-        Expr(:block, assignInputs, code, Expr(:return, resSym)), # function body
+        Expr(:block, assign_inputs, code, Expr(:return, res_sym)), # function body
     )
 
     return RuntimeGeneratedFunction(@__MODULE__, context_module, expr)
 end
-
-"""
-    execute(
-        graph::DAG,
-        instance,
-        machine::Machine,
-        input,
-        context_module::Module
-    )
-
-Execute the code of the given `graph` on the given input values.
-
-This is essentially shorthand for
-```julia
-tape = gen_tape(graph, instance, machine, context_module)
-return execute_tape(tape, input)
-```
-"""
-function execute(graph::DAG, instance, machine::Machine, input, context_module::Module)
-    tape = gen_tape(graph, instance, machine, context_module)
-    return execute_tape(tape, input)
-end
diff --git a/src/code_gen/tape_machine.jl b/src/code_gen/tape_machine.jl
index 7bdc584..4609dfb 100644
--- a/src/code_gen/tape_machine.jl
+++ b/src/code_gen/tape_machine.jl
@@ -1,110 +1,78 @@
-# TODO: do this with macros
-function call_fc(
-    fc::FunctionCall{VectorT,0}, cache::Dict{Symbol,Any}
-) where {VectorT<:SVector{1}}
-    cache[fc.return_symbol] = fc.func(cache[fc.arguments[1]])
-    return nothing
-end
-
-function call_fc(
-    fc::FunctionCall{VectorT,1}, cache::Dict{Symbol,Any}
-) where {VectorT<:SVector{1}}
-    cache[fc.return_symbol] = fc.func(fc.value_arguments[1], cache[fc.arguments[1]])
-    return nothing
-end
-
-function call_fc(
-    fc::FunctionCall{VectorT,0}, cache::Dict{Symbol,Any}
-) where {VectorT<:SVector{2}}
-    cache[fc.return_symbol] = fc.func(cache[fc.arguments[1]], cache[fc.arguments[2]])
-    return nothing
-end
-
-function call_fc(
-    fc::FunctionCall{VectorT,1}, cache::Dict{Symbol,Any}
-) where {VectorT<:SVector{2}}
-    cache[fc.return_symbol] = fc.func(
-        fc.value_arguments[1], cache[fc.arguments[1]], cache[fc.arguments[2]]
-    )
-    return nothing
-end
-
-function call_fc(fc::FunctionCall{VectorT,1}, cache::Dict{Symbol,Any}) where {VectorT}
-    cache[fc.return_symbol] = fc.func(
-        fc.value_arguments[1], getindex.(Ref(cache), fc.arguments)...
-    )
-    return nothing
-end
-
-"""
-    call_fc(fc::FunctionCall, cache::Dict{Symbol, Any})
-
-Execute the given [`FunctionCall`](@ref) on the dictionary.
-
-Several more specialized versions of this function exist to reduce vector unrolling work for common cases.
-""" -function call_fc(fc::FunctionCall{VectorT,M}, cache::Dict{Symbol,Any}) where {VectorT,M} - cache[fc.return_symbol] = fc.func( - fc.value_arguments..., getindex.(Ref(cache), fc.arguments)... - ) - return nothing -end - -function expr_from_fc(fc::FunctionCall{VectorT,0}) where {VectorT} - func_call = Expr( - :call, fc.func, eval.(_gen_access_expr.(Ref(fc.device), fc.arguments))... - ) - access_expr = eval(gen_access_expr(fc)) - +function expr_from_fc(fc::FunctionCall{VAL_T,<:Function}) where {VAL_T} + if length(fc) == 1 + func_call = Expr(:call, fc.func, fc.value_arguments[1]..., fc.arguments[1]...) + else + # TBW; dispatch to device specific vectorization + throw("unimplemented") + end + access_expr = gen_access_expr(fc) return Expr(:(=), access_expr, func_call) end -""" - expr_from_fc(fc::FunctionCall) +function expr_from_fc(fc::FunctionCall{VAL_T,Expr}) where {VAL_T} + @assert length(fc) == 1 && isempty(fc.value_arguments[1]) "function call assigning an expression cannot be vectorized and cannot contain value arguments\n$fc" + + fc_expr = Expr( + :block, + gen_local_init(fc), + fc.func, # anonymous function code block + Expr( # return the symbols + :return, + ( + if length(fc.return_symbols[1]) == 1 + fc.return_symbols[1][1] + else + Expr(:tuple, fc.return_symbols[1]...) + end + ), + ), + ) -For a given function call, return an expression evaluating it. -""" -function expr_from_fc(fc::FunctionCall{VectorT,M}) where {VectorT,M} func_call = Expr( - :call, - fc.func, - fc.value_arguments..., - eval.(_gen_access_expr.(Ref(fc.device), fc.arguments))..., + :call, # call + #wrap_in_let_statement( # wrap in let statement to prevent boxing of local variables + Expr( + :->, # anonymous function + Expr( + :tuple, # anonymous function arguments + #fc.arguments[1]..., + ), + fc_expr, # anonymous function code block + ), + #fc.arguments[1], + #), + #fc.arguments[1]..., # runtime arguments passed to the anonymous function call ) - access_expr = eval(gen_access_expr(fc)) + access_expr = gen_access_expr(fc) return Expr(:(=), access_expr, func_call) end """ gen_input_assignment_code( input_symbols::Dict{String, Vector{Symbol}}, - instance::AbstractProblemInstance, + instance::Any, machine::Machine, + input_type::Type, context_module::Module ) Return a `Vector{Expr}` doing the input assignments from the given `problem_input` onto the `input_symbols`. """ function gen_input_assignment_code( - input_symbols::Dict{String,Vector{Symbol}}, - instance, - machine::Machine, - context_module::Module, + input_symbols::Dict{String,Vector{Symbol}}, instance, machine::Machine ) assign_inputs = Vector{FunctionCall}() for (name, symbols) in input_symbols - # make a function for this, since we can't use anonymous functions in the FunctionCall - for symbol in symbols device = entry_device(machine) fc = FunctionCall( - context_module.eval(Expr(:->, :x, input_expr(instance, name, :x))), - SVector{0,Any}(), - SVector{1,Symbol}(:input), - symbol, - Nothing, + Expr(:(=), symbol, input_expr(instance, name, :input)), + (), + Symbol[:input], + Symbol[symbol], + Type[Nothing], device, ) @@ -116,100 +84,145 @@ function gen_input_assignment_code( end """ - gen_function_body(tape::Tape; closures_size) + gen_function_body(tape::Tape, context_module::Module; closures_size) Generate the function body from the given [`Tape`](@ref). ## Keyword Arguments -`closures_size`: The size of closures to generate (in lines of code). 
Closures introduce function barriers in the function body, preventing some optimizations by the compiler and therefore greatly reducing compile time. A value of 1 or less will disable the use of closures entirely.
+`closures_size`: The size of closures to generate (in lines of code). Closures introduce function barriers
+    in the function body, preventing some optimizations by the compiler and therefore greatly reducing
+    compile time. A value of 1 or less will disable the use of closures entirely.
 """
-function gen_function_body(tape::Tape; closures_size::Int)
+function gen_function_body(tape::Tape, context_module::Module; closures_size::Int)
+    # only need to annotate types later when using closures
+    types = infer_types!(tape, context_module)
+
     if closures_size > 1
-        # only need to annotate types later when using closures
-        infer_types!(tape)
+        s = log(closures_size, length(tape.schedule))
+        closures_depth = ceil(Int, s) # tend towards more levels/smaller closures
+        closures_size = ceil(Int, length(tape.schedule)^(1 / closures_depth))
     end
 
-    fc_vec = tape.schedule
+    @debug "generating function body with closure size $closures_size"
 
-    if (closures_size <= 1)
+    return _gen_function_body(
+        tape.schedule, types, tape.machine, context_module; closures_size=closures_size
+    )
+end
+
+function _gen_function_body(
+    fc_vec::AbstractVector{FunctionCall},
+    type_dict::Dict{Symbol,Type},
+    machine::Machine,
+    context_module::Module;
+    closures_size=0,
+)
+    @debug "generating function body from $(length(fc_vec)) function calls with closure size $closures_size"
+    if closures_size <= 1 || closures_size >= length(fc_vec)
         return Expr(:block, expr_from_fc.(fc_vec)...)
     end
-    closures = Vector{Expr}()
 
     # iterate from end to beginning
-    # this helps because we can collect all undefined arguments to the closures that have to be returned somewhere earlier
+    # this helps because we can collect all undefined arguments to the closures that have to be defined somewhere earlier
    undefined_argument_symbols = Set{Symbol}()
    # the final return symbol is the return of the entire generated function, it always has to be returned
-    push!(undefined_argument_symbols, eval(gen_access_expr(fc_vec[end])))
+    push!(undefined_argument_symbols, gen_access_expr(fc_vec[end]))
 
+    closured_fc_vec = FunctionCall[]
     for i in length(fc_vec):(-closures_size):1
         e = i
-        b = max(i - closures_size, 1)
+        b = max(i - closures_size + 1, 1)
         code_block = fc_vec[b:e]
 
-        # collect `local var` statements that need to exist before the closure starts
-        local_inits = gen_local_init.(code_block)
+        closure_fc = _closure_fc(
+            code_block, type_dict, machine, undefined_argument_symbols, context_module
+        )
 
-        return_symbols = eval.(gen_access_expr.(code_block))
+        pushfirst!(closured_fc_vec, closure_fc)
+    end
 
-        ret_symbols_set = Set(return_symbols)
-        for fc in code_block
-            for arg in fc.arguments
-                symbol = eval(_gen_access_expr(fc.device, arg))
-
-                # symbol won't be defined if it is first calculated in the closure
-                # so don't add it to the arguments in this case
-                if !(symbol in ret_symbols_set)
-                    push!(undefined_argument_symbols, symbol)
-                end
+    return _gen_function_body(
+        closured_fc_vec, type_dict, machine, context_module; closures_size=closures_size
+    )
+end

+"""
+    _closure_fc(
+        code_block::AbstractVector{FunctionCall},
+        types::Dict{Symbol,Type},
+        machine::Machine,
+        undefined_argument_symbols::Set{Symbol},
+        context_module::Module,
+    )
+
+From the given function calls, construct and return a single [`FunctionCall`](@ref) representing all of them together.
+`undefined_argument_symbols` is the set of all `Symbol`s that need to be returned if they are available inside the `code_block`; it is updated inside this function.
+"""
+function _closure_fc(
+    code_block::AbstractVector{FunctionCall},
+    types::Dict{Symbol,Type},
+    machine::Machine,
+    undefined_argument_symbols::Set{Symbol},
+    context_module::Module,
+)
+    return_symbols = Symbol[]
+    for s in
+        Iterators.flatten(Iterators.flatten(getfield.(code_block, Ref(:return_symbols))))
+        push!(return_symbols, s)
+    end
+
+    ret_symbols_set = Set(return_symbols)
+    arg_symbols_set = Set{Symbol}()
+    for fc in code_block
+        for symbol in Iterators.flatten(fc.arguments)
+            # symbol won't be defined if it is first calculated in the closure
+            # so don't add it to the arguments in this case
+            if !(symbol in ret_symbols_set)
+                push!(undefined_argument_symbols, symbol)
+
+                push!(arg_symbols_set, symbol)
             end
         end
     end
 
-    intersect!(ret_symbols_set, undefined_argument_symbols)
-    return_symbols = Symbol[ret_symbols_set...]
+    setdiff!(arg_symbols_set, ret_symbols_set)
+    intersect!(ret_symbols_set, undefined_argument_symbols)
 
-        closure = Expr(
-            :block,
-            Expr(
-                :(=),
-                Expr(:tuple, return_symbols...),
-                Expr(
-                    :call, # call to the following closure (no arguments)
-                    Expr( # create the closure: () -> code block; return (locals)
-                        :->,
-                        :(), # closure arguments (none)
-                        Expr( # actual function body of the closure
-                            :block,
-                            local_inits..., # declare local variables with type information inside the closure
-                            expr_from_fc.(code_block)...,
-                            Expr(:return, Expr(:tuple, return_symbols...)),
-                        ),
-                    ),
-                ),
-            ),
-        )
+    arg_symbols_t = [arg_symbols_set...]
+    ret_symbols_t = [ret_symbols_set...]
 
-        setdiff!(undefined_argument_symbols, ret_symbols_set)
+    ret_types = (getindex.(Ref(types), ret_symbols_t))
 
-        # combine to one closure call, including all the local inits and the actual call to the closure
-        pushfirst!(closures, closure)
-    end
+    fc_expr = Expr( # actual function body of the closure
+        :block,
+        expr_from_fc.(code_block)..., # no return statement necessary, will be done via capture and local init
+    )
+
+    fc = FunctionCall(
+        fc_expr,
+        (),
+        Symbol[arg_symbols_t...],
+        ret_symbols_t,
+        ret_types,
+        entry_device(machine),
+    )
 
-    return Expr(:block, closures...)
+    setdiff!(undefined_argument_symbols, ret_symbols_set)
+
+    return fc
 end
 
 """
     gen_tape(
         graph::DAG,
-        instance::AbstractProblemInstance,
+        instance::Any,
         machine::Machine,
         context_module::Module,
         scheduler::AbstractScheduler = GreedyScheduler()
    )
 
 Generate the code for a given graph. The return value is a [`Tape`](@ref).
-
-See also: [`execute`](@ref), [`execute_tape`](@ref)
 """
 function gen_tape(
     graph::DAG,
@@ -218,6 +231,7 @@
     context_module::Module,
     scheduler::AbstractScheduler=GreedyScheduler(),
 )
+    @debug "generating tape"
     schedule = schedule_dag(scheduler, graph, machine)
     function_body = lower(schedule, machine)
 
@@ -234,35 +248,8 @@
     # get outSymbol
     outSym = Symbol(to_var_name(get_exit_node(graph).id))
 
-    assign_inputs = gen_input_assignment_code(input_syms, instance, machine, context_module)
-
-    return Tape{input_type(instance)}(
-        assign_inputs, function_body, input_syms, outSym, instance, machine
-    )
-end
-
-"""
-    execute_tape(tape::Tape, input::Input) where {Input}
-
-Execute the given tape with the given input.
-
-!!! warning
-    This is very slow and might not work. This is to be majorly revamped.
-""" -function execute_tape(tape::Tape, input) - cache = Dict{Symbol,Any}() - cache[:input] = input - # simply execute all the code snippets here - @assert typeof(input) <: input_type(tape.instance) "expected tape input type to fit $(input_type(tape.instance)) but got $(typeof(input))" - - compute_code = tape.schedule - - for function_call in tape.inputAssignCode - call_fc(function_call, cache) - end - for function_call in compute_code - call_fc(function_call, cache) - end + assign_inputs = gen_input_assignment_code(input_syms, instance, machine) - return cache[tape.outputSymbol] + INPUT_T = input_type(instance) + return Tape{INPUT_T}(assign_inputs, function_body, outSym, instance, machine) end diff --git a/src/code_gen/type.jl b/src/code_gen/type.jl index d08bf46..45e653a 100644 --- a/src/code_gen/type.jl +++ b/src/code_gen/type.jl @@ -1,19 +1,81 @@ +""" + FunctionCall{VAL_T<:Tuple,FUNC_T<:Union{Function,Expr}} + +Representation of a function call. Contains the function to call (or an expression of a value to assign), +value arguments of type `VAL_T`, argument symbols, the return symbol(s) and type(s) and the device to execute on. + +To support vectorization, i.e., calling the same function on multiple inputs (SIMD), the value arguments, arguments, +and return symbols are each vectors of the actual inputs. In the non-vectorized case, these `Vector`s simply always +have length 1. For this common case, a special constructor exists which automatically wraps each of these arguments +in a `Vector`. + +## Type Arguments +- `VAL_T<:Tuple`: A tuple of all the value arguments that are passed to the function when it's called. +- `FUNC_T<:Union{Function, Expr}`: The type of the function. `Function` is the default, but in some cases, an `Expr` + of a value can be necessary to assign to the return symbol. In this case, no arguments are allowed. + +## Fields +- `func::FUNC_T`: The function to be called, or an expression containing a value to assign to the return_symbol. +- `value_arguments::Vector{VAL_T}`: The value arguments for the function call. These are passed *first* to the + function, in the order given here. The `Vector` contains the tuple of value arguments for each vectorization + member. +- `arguments::Vector{Vector{Symbol}}`: The first vector represents the vectorization, the second layer represents the + symbols that will be passed as arguments to the function call. +- `return_symbols::Vector{Vector{Symbol}}`: As with the arguments, the first vector level represents the vectorization, + the second represents the symbols that the results of the function call are assigned to. For most function calls, + there is only one return symbol. When using closures when generating a function body for a [`Tape`](@ref), the + option to have multiple return symbols is necessary. +- `return_types::Vector{<:Type}`: The types of the function call with the arguments provided. This field only contains + one level of Vector, because it is required that a `FunctionCall` is type stable, and therefore, the types of the + return symbols have to be equal for all members of a vectorization. The return type is initially set to `Nothing` + and later inferred and assigned by [`infer_types!`](@ref). +- `device::AbstractDevice`: The device that this function call is scheduled on. 
+""" +mutable struct FunctionCall{VAL_T<:Tuple,FUNC_T<:Union{Function,Expr}} + func::FUNC_T + value_arguments::Vector{VAL_T} # tuple of value arguments for the function call, will be prepended to the other arguments + arguments::Vector{Vector{Symbol}} # symbols of the inputs to the function call + return_symbols::Vector{Vector{Symbol}} # the return symbols + return_types::Vector{<:Type} # the return type of the function call(s); there can only be one return type since we require type stability + device::AbstractDevice +end +function FunctionCall( + func::Union{Function,Expr}, + value_arguments::VAL_T, + arguments::Vector{Symbol}, + return_symbol::Vector{Symbol}, + return_types::Vector{<:Type}, + device::AbstractDevice, +) where {VAL_T<:Tuple} + # convenience constructor for function calls that do not use vectorization, which is most of the use cases + @assert length(return_types) == 0 || length(return_types) == length(return_symbol) "number of return types '$(length(return_types))' does not match the number of return symbols '$(length(return_symbol))'" + @assert func isa Function || length(value_arguments) == 0 "no value arguments are allowed for a an Expr FunctionCall, but got '$value_arguments'" + return FunctionCall( + func, [value_arguments], [arguments], [return_symbol], return_types, device + ) +end """ Tape{INPUT} -TODO: update docs -- `INPUT` the input type of the problem instance +Lowered representation of a computation, generated from a [`DAG`](@ref) through [`gen_tape`](@ref). + +- `INPUT` the input type of the problem instance, see also the interface function [`input_type`](@ref) -- `code::Vector{Expr}`: The julia expression containing the code for the whole graph. -- `inputSymbols::Dict{String, Vector{Symbol}}`: A dictionary of symbols mapping the names of the input nodes of the graph to the symbols their inputs should be provided on. -- `outputSymbol::Symbol`: The symbol of the final calculated value +## Fields +- `input_assign_code::Vector{FunctionCall}`: The [`FunctionCall`](@ref)s representing the input assignments, + mapping part of the input of the computation to each DAG entry node. These functions are generated using + the interface function [`input_expr`](@ref). +- `schedule::Vector{FunctionCall}`: The [`FunctionCall`](@ref)s representing the function body of the computation. + There is one function call for each node in the [`DAG`](@ref). +- `output_symbol::Symbol`: The symbol of the final calculated value, which is returned. +- `instance::Any`: The instance that this tape is generated for. +- `machine::Machine`: The [`Machine`](@ref) that this tape is generated for. 
""" struct Tape{INPUT} - inputAssignCode::Vector{FunctionCall} + input_assign_code::Vector{FunctionCall} schedule::Vector{FunctionCall} - inputSymbols::Dict{String,Vector{Symbol}} - outputSymbol::Symbol + output_symbol::Symbol instance::Any machine::Machine end diff --git a/src/code_gen/utils.jl b/src/code_gen/utils.jl index 8379b16..e477c68 100644 --- a/src/code_gen/utils.jl +++ b/src/code_gen/utils.jl @@ -1,28 +1,43 @@ +function Base.length(fc::FunctionCall) + @assert length(fc.value_arguments) == length(fc.arguments) == length(fc.return_symbols) "function call length is undefined, got '$(length(fc.value_arguments))' tuples of value arguments, '$(length(fc.arguments))' tuples of arguments, and '$(length(return_symbols))' return symbols" + return length(fc.value_arguments) +end + """ - infer_types!(schedule::Vector{FunctionCall}) + infer_types!(tape::Tape, context_module::Module) -Infer the result type of each function call in the given schedule. Returns a dictionary with the result type for each [`Node`](@ref). This assumes that each node has only one statically inferrable return type and will throw an exceptin otherwise. -This also assumes that the given `Vector` contains a topological ordering of its nodes, such as returned by a call to [`schedule_dag`](@ref). +Infer the result type of each function call in the given tape. Returns a dictionary with the result type for each symbol and sets each function call's return_types. +This function assumes that each [`FunctionCall`](@ref) has only one statically inferrable return type and will throw an exception otherwise. """ -function infer_types!(tape::Tape) +function infer_types!(tape::Tape, context_module::Module) known_result_types = Dict{Symbol,Type}() # the only initially known type known_result_types[:input] = input_type(tape.instance) - for fc in tape.inputAssignCode - res_type = result_type(fc, known_result_types) - fc.return_type = res_type - known_result_types[fc.return_symbol] = res_type + for fc in tape.input_assign_code + res_types = result_types(fc, known_result_types, context_module) + fc.return_types = res_types + for (s, t) in Iterators.zip( + Iterators.flatten(fc.return_symbols), + Iterators.cycle(res_types, length(fc.return_symbols)), + ) + known_result_types[s] = t + end end for fc in tape.schedule - res_type = result_type(fc, known_result_types) - fc.return_type = res_type - known_result_types[fc.return_symbol] = res_type + res_types = result_types(fc, known_result_types, context_module) + fc.return_types = res_types + for (s, t) in Iterators.zip( + Iterators.flatten(fc.return_symbols), + Iterators.cycle(res_types, length(fc.return_symbols)), + ) + known_result_types[s] = t + end end - return nothing + return known_result_types end """ @@ -37,7 +52,7 @@ function lower(schedule::Vector{Node}, machine::Machine) if (node isa DataTaskNode && length(children(node)) == 0) push!(calls, get_init_function_call(node, entry_device(machine))) else - push!(calls, get_function_call(node)...) + push!(calls, get_function_call(node)) end end diff --git a/src/devices/impl.jl b/src/devices/impl.jl index 05c0049..7392074 100644 --- a/src/devices/impl.jl +++ b/src/devices/impl.jl @@ -31,17 +31,74 @@ end """ gen_access_expr(fc::FunctionCall) -Dispatch from the given [`FunctionCall`](@ref) to the interface function `_gen_access_expr`(@ref). 
Return the access expression for the (possibly vectorized) return symbols of the given [`FunctionCall`](@ref).
+A single return symbol is returned as-is; multiple return symbols are wrapped in a tuple expression.
 """
-function gen_access_expr(fc::FunctionCall)
-    return _gen_access_expr(fc.device, fc.return_symbol)
+function gen_access_expr(fc::FunctionCall{VAL_T}) where {VAL_T}
+    if length(fc.return_types) != 1
+        # general case
+        vec = Expr[]
+        for ret_symbols in fc.return_symbols
+            push!(vec, unroll_symbol_vector(ret_symbols))
+        end
+        if length(vec) > 1
+            return unroll_symbol_vector(vec)
+        else
+            return vec[1]
+        end
+    end
+
+    # single return value per function
+    vec = Symbol[]
+    for ret_symbols in fc.return_symbols
+        push!(vec, ret_symbols[1])
+    end
+    if length(vec) > 1
+        return unroll_symbol_vector(vec)
+    else
+        return vec[1]
+    end
 end
 
 """
     gen_local_init(fc::FunctionCall)
 
-Dispatch from the given [`FunctionCall`](@ref) to the interface function `_gen_local_init`(@ref).
+Dispatch from the given [`FunctionCall`](@ref) to the lower-level function [`_gen_local_init`](@ref).
 """
 function gen_local_init(fc::FunctionCall)
-    return _gen_local_init(fc, fc.device)
+    return Expr(
+        :block,
+        _gen_local_init.(
+            Iterators.flatten(fc.return_symbols),
+            Iterators.cycle(fc.return_types, length(fc.return_symbols)),
+        )...,
+    )
+end
+
+"""
+    _gen_local_init(symbol::Symbol, type::Type)
+
+Return an `Expr` that initializes the symbol in the local scope.
+The result looks like `local <symbol>::<type>`.
+"""
+function _gen_local_init(symbol::Symbol, type::Type)
+    return Expr(:local, Expr(:(::), symbol, Symbol(type)))
+end
+
+"""
+    wrap_in_let_statement(expr, symbols)
+
+For a given expression and a collection of symbols, wrap the expression in a let statement binding all the symbols, like
+`let <symbol> = <symbol>, ... end`.
+"""
+@inline function wrap_in_let_statement(expr, symbols)
+    return Expr(:let, Expr(:block, _gen_let_statement.(symbols)...), expr)
+end
+
+"""
+    _gen_let_statement(symbol::Symbol)
+
+Return a let-binding `Expr` like `<symbol> = <symbol>`.
+"""
+function _gen_let_statement(symbol::Symbol)
+    return Expr(:(=), symbol, symbol)
 end
diff --git a/src/devices/interface.jl b/src/devices/interface.jl
index a4a08be..04ab6e0 100644
--- a/src/devices/interface.jl
+++ b/src/devices/interface.jl
@@ -46,23 +46,6 @@ Interface function that must be implemented for every subtype of [`AbstractDevice`](@ref).
 """
 function measure_device! end
 
-"""
-    _gen_access_expr(device::AbstractDevice, symbol::Symbol)
-
-Interface function that must be implemented for every subtype of [`AbstractDevice`](@ref).
-Return an `Expr` or `QuoteNode` accessing the variable identified by [`symbol`].
-"""
-function _gen_access_expr end
-
-"""
-    _gen_local_init(fc::FunctionCall, device::AbstractDevice)
-
-Interface function that must be implemented for every subtype of [`AbstractDevice`](@ref).
-Return an `Expr` or `QuoteNode` that initializes the access expression returned by [`_gen_access_expr`](@ref) in the local scope.
-This expression may be empty. For local variables it should be `local <symbol>::<type>`.
-"""
-function _gen_local_init end
-
 """
     kernel(gpu_type::Type{<:AbstractGPU}, graph::DAG, instance)
diff --git a/src/devices/numa/impl.jl b/src/devices/numa/impl.jl
index e3a2ec8..9e65828 100644
--- a/src/devices/numa/impl.jl
+++ b/src/devices/numa/impl.jl
@@ -38,26 +38,3 @@ function get_devices(deviceType::Type{T}; verbose::Bool=false) where {T<:NumaNode
 
     return devices
 end
-
-"""
-    _gen_access_expr(device::NumaNode, symbol::Symbol)
-
-Interface implementation, dispatched to from [`gen_access_expr`](@ref).
-""" -function _gen_access_expr(::NumaNode, symbol::Symbol) - # TODO rewrite these with Expr instead of quote node - s = Symbol("data_$symbol") - quote_node = Meta.parse(":($s)") - return quote_node -end - -""" - _gen_local_init(fc::FunctionCall, device::NumaNode) - -Interface implementation, dispatched to from [`gen_local_init`](@ref). -""" -function _gen_local_init(fc::FunctionCall, ::NumaNode) - s = Symbol("data_$(fc.return_symbol)") - quote_node = Expr(:local, s, :(::), Symbol(fc.return_type)) # TODO: figure out how to get type info for this local variable - return quote_node -end diff --git a/src/estimator/global_metric.jl b/src/estimator/global_metric.jl index b9e817f..ac81d2d 100644 --- a/src/estimator/global_metric.jl +++ b/src/estimator/global_metric.jl @@ -5,40 +5,40 @@ Representation of a [`DAG`](@ref)'s cost as estimated by the [`GlobalMetricEstim # Fields: `.data`: The total data transfer.\\ -`.computeEffort`: The total compute effort.\\ -`.computeIntensity`: The compute intensity, will always equal `.computeEffort / .data`. +`.compute_effort`: The total compute effort.\\ +`.compute_intensity`: The compute intensity, will always equal `.compute_effort / .data`. !!! note - Note that the `computeIntensity` doesn't necessarily make sense in the context of only operation costs. + Note that the `compute_intensity` doesn't necessarily make sense in the context of only operation costs. It will still work as intended when adding/subtracting to/from a `graph_cost` estimate. """ const CDCost = NamedTuple{ - (:data, :computeEffort, :computeIntensity),Tuple{Float64,Float64,Float64} + (:data, :compute_effort, :compute_intensity),Tuple{Float64,Float64,Float64} } function Base.:+(cost1::CDCost, cost2::CDCost)::CDCost d = cost1.data + cost2.data - ce = computeEffort = cost1.computeEffort + cost2.computeEffort - return (data=d, computeEffort=ce, computeIntensity=ce / d)::CDCost + ce = compute_effort = cost1.compute_effort + cost2.compute_effort + return (data=d, compute_effort=ce, compute_intensity=ce / d)::CDCost end function Base.:-(cost1::CDCost, cost2::CDCost)::CDCost d = cost1.data - cost2.data - ce = computeEffort = cost1.computeEffort - cost2.computeEffort - return (data=d, computeEffort=ce, computeIntensity=ce / d)::CDCost + ce = compute_effort = cost1.compute_effort - cost2.compute_effort + return (data=d, compute_effort=ce, compute_intensity=ce / d)::CDCost end function Base.isless(cost1::CDCost, cost2::CDCost)::Bool - return cost1.data + cost1.computeEffort < cost2.data + cost2.computeEffort + return cost1.data + cost1.compute_effort < cost2.data + cost2.compute_effort end function Base.zero(type::Type{CDCost}) - return (data=0.0, computeEffort=0.0, computeIntensity=0.0)::CDCost + return (data=0.0, compute_effort=0.0, compute_intensity=0.0)::CDCost end function Base.typemax(type::Type{CDCost}) - return (data=Inf, computeEffort=Inf, computeIntensity=0.0)::CDCost + return (data=Inf, compute_effort=Inf, compute_intensity=0.0)::CDCost end """ @@ -56,8 +56,8 @@ function graph_cost(estimator::GlobalMetricEstimator, graph::DAG) properties = get_properties(graph) return ( data=properties.data, - computeEffort=properties.computeEffort, - computeIntensity=properties.computeIntensity, + compute_effort=properties.compute_effort, + compute_intensity=properties.compute_intensity, )::CDCost end @@ -67,8 +67,8 @@ function operation_effect( s = length(operation.input) - 1 return ( data=s * -data(task(operation.input[1])), - computeEffort=s * -compute_effort(task(operation.input[1])), - 
computeIntensity=typeof(operation.input) <: DataTaskNode ? 0.0 : Inf, + compute_effort=s * -compute_effort(task(operation.input[1])), + compute_intensity=typeof(operation.input) <: DataTaskNode ? 0.0 : Inf, )::CDCost end @@ -78,7 +78,7 @@ function operation_effect( s::Float64 = length(parents(operation.input)) - 1 d::Float64 = s * data(task(operation.input)) ce::Float64 = s * compute_effort(task(operation.input)) - return (data=d, computeEffort=ce, computeIntensity=ce / d)::CDCost + return (data=d, compute_effort=ce, compute_intensity=ce / d)::CDCost end function String(::GlobalMetricEstimator) diff --git a/src/graph/interface.jl b/src/graph/interface.jl index 0fa74cc..97d092d 100644 --- a/src/graph/interface.jl +++ b/src/graph/interface.jl @@ -7,7 +7,7 @@ See also: [`DAG`](@ref), [`pop_operation!`](@ref) """ function push_operation!(graph::DAG, operation::Operation) # 1.: Add the operation to the DAG - push!(graph.operationsToApply, operation) + push!(graph.operations_to_apply, operation) return nothing end @@ -21,10 +21,10 @@ See also: [`DAG`](@ref), [`push_operation!`](@ref) """ function pop_operation!(graph::DAG) # 1.: Remove the operation from the appliedChain of the DAG - if !isempty(graph.operationsToApply) - pop!(graph.operationsToApply) - elseif !isempty(graph.appliedOperations) - appliedOp = pop!(graph.appliedOperations) + if !isempty(graph.operations_to_apply) + pop!(graph.operations_to_apply) + elseif !isempty(graph.applied_operations) + appliedOp = pop!(graph.applied_operations) revert_operation!(graph, appliedOp) else error("No more operations to pop!") @@ -38,7 +38,8 @@ end Return `true` if [`pop_operation!`](@ref) is possible, `false` otherwise. """ -can_pop(graph::DAG) = !isempty(graph.operationsToApply) || !isempty(graph.appliedOperations) +can_pop(graph::DAG) = + !isempty(graph.operations_to_apply) || !isempty(graph.applied_operations) """ reset_graph!(graph::DAG) diff --git a/src/graph/mute.jl b/src/graph/mute.jl index f927967..2018274 100644 --- a/src/graph/mute.jl +++ b/src/graph/mute.jl @@ -56,7 +56,7 @@ function _insert_node!(graph::DAG, node::Node; track=true, invalidate_cache=true if (!invalidate_cache) return node end - push!(graph.dirtyNodes, node) + push!(graph.dirty_nodes, node) return node end @@ -110,8 +110,8 @@ function _insert_edge!( invalidate_operation_caches!(graph, node1) invalidate_operation_caches!(graph, node2) - push!(graph.dirtyNodes, node1) - push!(graph.dirtyNodes, node2) + push!(graph.dirty_nodes, node1) + push!(graph.dirty_nodes, node2) return nothing end @@ -145,7 +145,7 @@ function _remove_node!(graph::DAG, node::Node; track=true, invalidate_cache=true end invalidate_operation_caches!(graph, node) - delete!(graph.dirtyNodes, node) + delete!(graph.dirty_nodes, node) return nothing end @@ -207,10 +207,10 @@ function _remove_edge!( invalidate_operation_caches!(graph, node1) invalidate_operation_caches!(graph, node2) if (node1 in graph) - push!(graph.dirtyNodes, node1) + push!(graph.dirty_nodes, node1) end if (node2 in graph) - push!(graph.dirtyNodes, node2) + push!(graph.dirty_nodes, node2) end return removed_node_index @@ -235,10 +235,10 @@ Invalidate the operation caches for a given [`NodeReduction`](@ref). This deletes the operation from the graph's possible operations and from the involved nodes' own operation caches. 
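+
+A sketch with two hypothetical reducible nodes `n1` and `n2`:
+
+```julia
+nr = NodeReduction([n1, n2])
+invalidate_caches!(graph, nr)  # removes nr from the possible operations and clears both nodes' node_reduction fields
+```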
""" function invalidate_caches!(graph::DAG, operation::NodeReduction) - delete!(graph.possibleOperations, operation) + delete!(graph.possible_operations, operation) for node in operation.input - node.nodeReduction = missing + node.node_reduction = missing end return nothing @@ -252,11 +252,11 @@ Invalidate the operation caches for a given [`NodeSplit`](@ref). This deletes the operation from the graph's possible operations and from the involved nodes' own operation caches. """ function invalidate_caches!(graph::DAG, operation::NodeSplit) - delete!(graph.possibleOperations, operation) + delete!(graph.possible_operations, operation) # delete the operation from all caches of nodes involved in the operation # for node split there is only one node - operation.input.nodeSplit = missing + operation.input.node_split = missing return nothing end @@ -267,11 +267,11 @@ end Invalidate the operation caches of the given node through calls to the respective [`invalidate_caches!`](@ref) functions. """ function invalidate_operation_caches!(graph::DAG, node::ComputeTaskNode) - if !ismissing(node.nodeReduction) - invalidate_caches!(graph, node.nodeReduction) + if !ismissing(node.node_reduction) + invalidate_caches!(graph, node.node_reduction) end - if !ismissing(node.nodeSplit) - invalidate_caches!(graph, node.nodeSplit) + if !ismissing(node.node_split) + invalidate_caches!(graph, node.node_split) end return nothing end @@ -282,11 +282,11 @@ end Invalidate the operation caches of the given node through calls to the respective [`invalidate_caches!`](@ref) functions. """ function invalidate_operation_caches!(graph::DAG, node::DataTaskNode) - if !ismissing(node.nodeReduction) - invalidate_caches!(graph, node.nodeReduction) + if !ismissing(node.node_reduction) + invalidate_caches!(graph, node.node_reduction) end - if !ismissing(node.nodeSplit) - invalidate_caches!(graph, node.nodeSplit) + if !ismissing(node.node_split) + invalidate_caches!(graph, node.node_split) end return nothing end diff --git a/src/graph/print.jl b/src/graph/print.jl index 118679c..d1fc5af 100644 --- a/src/graph/print.jl +++ b/src/graph/print.jl @@ -28,14 +28,14 @@ function Base.show(io::IO, graph::DAG) print(io, " Nodes: ") nodeDict = Dict{Type,Int64}() - noEdges = 0 + number_of_edges = 0 for node in graph.nodes if haskey(nodeDict, typeof(task(node))) nodeDict[typeof(task(node))] = nodeDict[typeof(task(node))] + 1 else nodeDict[typeof(task(node))] = 1 end - noEdges += length(parents(node)) + number_of_edges += length(parents(node)) end if length(graph.nodes) <= 20 @@ -58,9 +58,9 @@ function Base.show(io::IO, graph::DAG) end end println(io) - println(io, " Edges: ", noEdges) + println(io, " Edges: ", number_of_edges) properties = get_properties(graph) - println(io, " Total Compute Effort: ", properties.computeEffort) + println(io, " Total Compute Effort: ", properties.compute_effort) println(io, " Total Data Transfer: ", properties.data) - return println(io, " Total Compute Intensity: ", properties.computeIntensity) + return println(io, " Total Compute Intensity: ", properties.compute_intensity) end diff --git a/src/graph/properties.jl b/src/graph/properties.jl index 9515820..0f52b12 100644 --- a/src/graph/properties.jl +++ b/src/graph/properties.jl @@ -8,7 +8,7 @@ function get_properties(graph::DAG) apply_all!(graph) # TODO: tests stop working without the if condition, which means there is probably a bug in the lazy evaluation and in the tests - if (graph.properties.computeEffort <= 0.0) + if (graph.properties.compute_effort <= 0.0) 
graph.properties = GraphProperties(graph)
     end
 
@@ -51,5 +51,5 @@ end
 Return the number of operations applied to the graph.
 """
 function operation_stack_length(graph::DAG)
-    return length(graph.appliedOperations) + length(graph.operationsToApply)
+    return length(graph.applied_operations) + length(graph.operations_to_apply)
 end
diff --git a/src/graph/type.jl b/src/graph/type.jl
index b5ceeb4..a5f8715 100644
--- a/src/graph/type.jl
+++ b/src/graph/type.jl
@@ -7,8 +7,8 @@ A struct storing all possible operations on a [`DAG`](@ref).
 To get the [`PossibleOperations`](@ref) on a [`DAG`](@ref), use [`get_operations`](@ref).
 """
 mutable struct PossibleOperations
-    nodeReductions::Set{NodeReduction}
-    nodeSplits::Set{NodeSplit}
+    node_reductions::Set{NodeReduction}
+    node_splits::Set{NodeSplit}
 end
 
 """
@@ -24,16 +24,16 @@ mutable struct DAG
     nodes::Set{Union{DataTaskNode,ComputeTaskNode}}
 
     # The operations currently applied to the set of nodes
-    appliedOperations::Stack{AppliedOperation}
+    applied_operations::Stack{AppliedOperation}
 
     # The operations not currently applied but part of the current state of the DAG
-    operationsToApply::Deque{Operation}
+    operations_to_apply::Deque{Operation}
 
     # The possible operations at the current state of the DAG
-    possibleOperations::PossibleOperations
+    possible_operations::PossibleOperations
 
     # The set of nodes whose possible operations need to be reevaluated
-    dirtyNodes::Set{Union{DataTaskNode,ComputeTaskNode}}
+    dirty_nodes::Set{Union{DataTaskNode,ComputeTaskNode}}
 
     # "snapshot" system: keep track of added/removed nodes/edges since last snapshot
     # these are muted in insert_node! etc.
diff --git a/src/graph/validate.jl b/src/graph/validate.jl
index ad95d19..c32a138 100644
--- a/src/graph/validate.jl
+++ b/src/graph/validate.jl
@@ -30,21 +30,21 @@ function is_valid(graph::DAG)
         @assert is_valid(graph, node)
     end
 
-    for op in graph.operationsToApply
+    for op in graph.operations_to_apply
         @assert is_valid(graph, op)
     end
 
-    for nr in graph.possibleOperations.nodeReductions
+    for nr in graph.possible_operations.node_reductions
         @assert is_valid(graph, nr)
     end
 
-    for ns in graph.possibleOperations.nodeSplits
+    for ns in graph.possible_operations.node_splits
         @assert is_valid(graph, ns)
     end
 
-    for node in graph.dirtyNodes
+    for node in graph.dirty_nodes
         @assert node in graph "Dirty Node is not part of the graph!"
-        @assert ismissing(node.nodeReduction) "Dirty Node has a NodeReduction!"
-        @assert ismissing(node.nodeSplit) "Dirty Node has a NodeSplit!"
+        @assert ismissing(node.node_reduction) "Dirty Node has a NodeReduction!"
+        @assert ismissing(node.node_split) "Dirty Node has a NodeSplit!"
     end
 
     @assert is_connected(graph) "Graph is not connected!"
diff --git a/src/models/interface.jl b/src/models/interface.jl
index 1672e40..81bf46c 100644
--- a/src/models/interface.jl
+++ b/src/models/interface.jl
@@ -1,44 +1,26 @@
-
-"""
-    AbstractModel
-
-Base type for all models. From this, [`AbstractProblemInstance`](@ref)s can be constructed.
-
-See also: [`problem_instance`](@ref)
 """
-abstract type AbstractModel end
+    input_type(problem_instance)
 
-"""
-    problem_instance(::AbstractModel, ::Vararg)
-
-Interface function that must be implemented for any implementation of [`AbstractModel`](@ref). This function should return a specific [`AbstractProblemInstance`](@ref) given some parameters.
-"""
-function problem_instance end
+Return the input type for a specific `problem_instance`. This can be a specific type or a supertype for which all subtypes are expected to work.
-""" - AbstractProblemInstance - -Base type for problem instances. An object of this type of a corresponding [`AbstractModel`](@ref) should uniquely identify a problem instance of that model. -""" -abstract type AbstractProblemInstance end - -""" - input_type(problem::AbstractProblemInstance) - -Return the input type for a specific [`AbstractProblemInstance`](@ref). This can be a specific type or a supertype for which all child types are expected to work. +For more details on the `problem_instance`, please refer to the documentation. """ function input_type end """ - graph(::AbstractProblemInstance) + graph(problem_instance) + +Generate the [`DAG`](@ref) for the given `problem_instance`. Every entry node (see [`get_entry_nodes`](@ref)) to the graph must have a name set. Implement [`input_expr`](@ref) to return a valid expression for each of those names. -Generate the [`DAG`](@ref) for the given [`AbstractProblemInstance`](@ref). Every entry node (see [`get_entry_nodes`](@ref)) to the graph must have a name set. Implement [`input_expr`](@ref) to return a valid expression for each of those names. +For more details on the `problem_instance`, please refer to the documentation. """ function graph end """ - input_expr(instance::AbstractProblemInstance, name::String, input_symbol::Symbol) + input_expr(problem_instance, name::String, input_symbol::Symbol) + +For the given `problem_instance`, the entry node name, and the symbol of the problem input (where a variable of type `input_type(...)` will exist), return an `Expr` that gets that specific input value from the input symbol. -For the given [`AbstractProblemInstance`](@ref), the entry node name, and the symbol of the problem input (where a variable of type `input_type(...)` will exist), return an `Expr` that gets that specific input value from the input symbol. +For more details on the `problem_instance`, please refer to the documentation. """ function input_expr end diff --git a/src/node/type.jl b/src/node/type.jl index fa0b400..8e608dd 100644 --- a/src/node/type.jl +++ b/src/node/type.jl @@ -27,8 +27,8 @@ Any node that transfers data and does no computation. `.parents`: A vector of the node's parents (i.e. nodes that depend on this one).\\ `.children`: A vector of tuples of the node's children (i.e. nodes that this one depends on) and their indices, indicating their order in the resulting function call passed to the task.\\ `.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.\\ -`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\ -`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\ +`.node_reduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\ +`.node_split`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. 
There can only be at most one.\\ `.name`: The name of this node for entry nodes into the graph ([`is_entry_node`](@ref)) to reliably assign the inputs to the correct nodes when executing.\\ """ mutable struct DataTaskNode{TaskType<:AbstractDataTask} <: Node @@ -44,10 +44,10 @@ mutable struct DataTaskNode{TaskType<:AbstractDataTask} <: Node # the NodeReduction involving this node, if it exists # Can't use the NodeReduction type here because it's not yet defined - nodeReduction::Union{Operation,Missing} + node_reduction::Union{Operation,Missing} # the NodeSplit involving this node, if it exists - nodeSplit::Union{Operation,Missing} + node_split::Union{Operation,Missing} # for input nodes we need a name for the node to distinguish between them name::String @@ -63,8 +63,8 @@ Any node that computes a result from inputs using an [`AbstractComputeTask`](@re `.parents`: A vector of the node's parents (i.e. nodes that depend on this one).\\ `.children`: A vector of tuples with the node's children (i.e. nodes that this one depends on) and their index, used to order the arguments for the [`AbstractComputeTask`](@ref).\\ `.id`: The node's id. Improves the speed of comparisons and is used as a unique identifier.\\ -`.nodeReduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\ -`.nodeSplit`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\ +`.node_reduction`: Either this node's [`NodeReduction`](@ref) or `missing`, if none. There can only be at most one.\\ +`.node_split`: Either this node's [`NodeSplit`](@ref) or `missing`, if none. There can only be at most one.\\ `.device`: The Device this node has been scheduled on by a [`Scheduler`](@ref). """ mutable struct ComputeTaskNode{TaskType<:AbstractComputeTask} <: Node @@ -73,8 +73,8 @@ mutable struct ComputeTaskNode{TaskType<:AbstractComputeTask} <: Node children::Vector{Tuple{Node,Int}} id::Base.UUID - nodeReduction::Union{Operation,Missing} - nodeSplit::Union{Operation,Missing} + node_reduction::Union{Operation,Missing} + node_split::Union{Operation,Missing} # the device this node is assigned to execute on device::Union{AbstractDevice,Missing} diff --git a/src/node/validate.jl b/src/node/validate.jl index 04bf03f..8676f0b 100644 --- a/src/node/validate.jl +++ b/src/node/validate.jl @@ -22,11 +22,11 @@ function is_valid_node(graph::DAG, node::Node) @assert node in child.parents "Node is not a parent of its child!" end - #=if !ismissing(node.nodeReduction) - @assert is_valid(graph, node.nodeReduction) + #=if !ismissing(node.node_reduction) + @assert is_valid(graph, node.node_reduction) end - if !ismissing(node.nodeSplit) - @assert is_valid(graph, node.nodeSplit) + if !ismissing(node.node_split) + @assert is_valid(graph, node.node_split) end=# return true diff --git a/src/operation/apply.jl b/src/operation/apply.jl index 0de5eea..2a17d19 100644 --- a/src/operation/apply.jl +++ b/src/operation/apply.jl @@ -4,15 +4,15 @@ Apply all unapplied operations in the DAG. Is automatically called in all functions that require the latest state of the [`DAG`](@ref). 
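+
+For example, operations pushed with [`push_operation!`](@ref) are only applied lazily, once a function
+that needs the graph's latest state calls this (a usage sketch):
+
+```julia
+ops = get_operations(graph)
+push_operation!(graph, first(ops.node_reductions))  # queued, not yet applied
+apply_all!(graph)                                   # applied here
+```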
""" function apply_all!(graph::DAG) - while !isempty(graph.operationsToApply) + while !isempty(graph.operations_to_apply) # get next operation to apply from front of the deque - op = popfirst!(graph.operationsToApply) + op = popfirst!(graph.operations_to_apply) # apply it appliedOp = apply_operation!(graph, op) - # push to the end of the appliedOperations deque - push!(graph.appliedOperations, appliedOp) + # push to the end of the applied_operations deque + push!(graph.applied_operations, appliedOp) end return nothing end @@ -182,13 +182,14 @@ function node_split!( get_snapshot_diff(graph) n1_parents = copy(parents(n1)) + local parent_indices = Dict() n1_children = copy(children(n1)) for parent in n1_parents - _remove_edge!(graph, n1, parent) + parent_indices[parent] = _remove_edge!(graph, n1, parent) end for (child, index) in n1_children - _remove_edge!(graph, child, n1) + @assert index == _remove_edge!(graph, child, n1) end _remove_node!(graph, n1) @@ -196,10 +197,10 @@ function node_split!( n_copy = copy(n1) _insert_node!(graph, n_copy) - _insert_edge!(graph, n_copy, parent) + _insert_edge!(graph, n_copy, parent, parent_indices[parent]) for (child, index) in n1_children - _insert_edge!(graph, child, n_copy) + _insert_edge!(graph, child, n_copy, index) end end diff --git a/src/operation/clean.jl b/src/operation/clean.jl index eec9d3f..8c415d8 100644 --- a/src/operation/clean.jl +++ b/src/operation/clean.jl @@ -7,7 +7,7 @@ Find node reductions involving the given node. The function pushes the found [`N """ function find_reductions!(graph::DAG, node::Node) # there can only be one reduction per node, avoid adding duplicates - if !ismissing(node.nodeReduction) + if !ismissing(node.node_reduction) return nothing end @@ -30,14 +30,14 @@ function find_reductions!(graph::DAG, node::Node) if reductionVector !== nothing nr = NodeReduction(reductionVector) - push!(graph.possibleOperations.nodeReductions, nr) + push!(graph.possible_operations.node_reductions, nr) for node in reductionVector - if !ismissing(node.nodeReduction) + if !ismissing(node.node_reduction) # it can happen that the dirty node becomes part of an existing NodeReduction and overrides those ones now - # this is only a problem insofar the existing NodeReduction has to be deleted and replaced also in the possibleOperations - invalidate_caches!(graph, node.nodeReduction) + # this is only a problem insofar the existing NodeReduction has to be deleted and replaced also in the possible_operations + invalidate_caches!(graph, node.node_reduction) end - node.nodeReduction = nr + node.node_reduction = nr end end @@ -50,14 +50,14 @@ end Find the node split of the given node. The function pushes the found [`NodeSplit`](@ref) (if any) everywhere it needs to be and returns nothing. """ function find_splits!(graph::DAG, node::Node) - if !ismissing(node.nodeSplit) + if !ismissing(node.node_split) return nothing end if (can_split(node)) ns = NodeSplit(node) - push!(graph.possibleOperations.nodeSplits, ns) - node.nodeSplit = ns + push!(graph.possible_operations.node_splits, ns) + node.node_split = ns end return nothing diff --git a/src/operation/find.jl b/src/operation/find.jl index 6179cec..977a86b 100644 --- a/src/operation/find.jl +++ b/src/operation/find.jl @@ -9,7 +9,7 @@ Insert the given node reduction into its input nodes' operation caches. 
This is thread-safe.
 """
 function insert_operation!(nr::NodeReduction)
     for n in nr.input
-        n.nodeReduction = nr
+        n.node_reduction = nr
     end
     return nothing
 end
 
@@ -20,30 +20,30 @@ end
 Insert the given node split into its input node's operation cache. This is thread-safe.
 """
 function insert_operation!(ns::NodeSplit)
-    ns.input.nodeSplit = ns
+    ns.input.node_split = ns
     return nothing
 end
 
 """
-    nr_insertion!(operations::PossibleOperations, nodeReductions::Vector{Vector{NodeReduction}})
+    nr_insertion!(operations::PossibleOperations, node_reductions::Vector{Vector{NodeReduction}})
 
 Insert the node reductions into the graph and the nodes' caches. Employs multithreading for speedup.
 """
 function nr_insertion!(
-    operations::PossibleOperations, nodeReductions::Vector{Vector{NodeReduction}}
+    operations::PossibleOperations, node_reductions::Vector{Vector{NodeReduction}}
 )
     total_len = 0
-    for vec in nodeReductions
+    for vec in node_reductions
         total_len += length(vec)
     end
-    sizehint!(operations.nodeReductions, total_len)
+    sizehint!(operations.node_reductions, total_len)
 
-    t = @task for vec in nodeReductions
-        union!(operations.nodeReductions, Set(vec))
+    t = @task for vec in node_reductions
+        union!(operations.node_reductions, Set(vec))
     end
     schedule(t)
 
-    @threads for vec in nodeReductions
+    @threads for vec in node_reductions
         for op in vec
             insert_operation!(op)
         end
     end
 
"""
-    ns_insertion!(operations::PossibleOperations, nodeSplits::Vector{Vector{NodeSplits}})
+    ns_insertion!(operations::PossibleOperations, node_splits::Vector{Vector{NodeSplit}})
 
 Insert the node splits into the graph and the nodes' caches. Employs multithreading for speedup.
 """
 function ns_insertion!(
-    operations::PossibleOperations, nodeSplits::Vector{Vector{NodeSplit}}
+    operations::PossibleOperations, node_splits::Vector{Vector{NodeSplit}}
 )
     total_len = 0
-    for vec in nodeSplits
+    for vec in node_splits
         total_len += length(vec)
     end
-    sizehint!(operations.nodeSplits, total_len)
+    sizehint!(operations.node_splits, total_len)
 
-    t = @task for vec in nodeSplits
-        union!(operations.nodeSplits, Set(vec))
+    t = @task for vec in node_splits
+        union!(operations.node_splits, Set(vec))
     end
     schedule(t)
 
-    @threads for vec in nodeSplits
+    @threads for vec in node_splits
         for op in vec
             insert_operation!(op)
         end
     end
 
@@ -124,11 +124,11 @@ function generate_operations(graph::DAG)
         insert!(trie, candidate)
     end
 
-    nodeReductions = collect(trie)
+    node_reductions = collect(trie)
 
-    for nrVec in nodeReductions
+    for nrVec in node_reductions
-        # parent sets are ordered and any node can only be part of one nodeReduction, so a NodeReduction is uniquely identifiable by its first element
-        # this prevents duplicate nodeReductions being generated
+        # parent sets are ordered and any node can only be part of one node_reduction, so a NodeReduction is uniquely identifiable by its first element
+        # this prevents duplicate node_reductions being generated
         lock(checkedNodesLock)
         if (nrVec[1] in checkedNodes)
             unlock(checkedNodesLock)
@@ -144,7 +144,7 @@
 
     # launch thread for node reduction insertion
     # remove duplicates
-    nr_task = @spawn nr_insertion!(graph.possibleOperations, generatedReductions)
+    nr_task = @spawn nr_insertion!(graph.possible_operations, generatedReductions)
 
     # find possible node splits
     @threads for node in nodeArray
@@ -154,9 +154,9 @@
     end
 
     # launch thread for node split insertion
-    ns_task = @spawn ns_insertion!(graph.possibleOperations,
diff --git a/src/operation/get.jl b/src/operation/get.jl
index 3294459..ff13398 100644
--- a/src/operation/get.jl
+++ b/src/operation/get.jl
@@ -10,12 +10,12 @@ Return the [`PossibleOperations`](@ref) of the graph at the current state.
 function get_operations(graph::DAG)
     apply_all!(graph)
 
-    if isempty(graph.possibleOperations)
+    if isempty(graph.possible_operations)
         generate_operations(graph)
     end
 
-    clean_node!.(Ref(graph), graph.dirtyNodes)
-    empty!(graph.dirtyNodes)
+    clean_node!.(Ref(graph), graph.dirty_nodes)
+    empty!(graph.dirty_nodes)
 
-    return graph.possibleOperations
+    return graph.possible_operations
 end
diff --git a/src/operation/iterate.jl b/src/operation/iterate.jl
index 1a9ece4..4eb6727 100644
--- a/src/operation/iterate.jl
+++ b/src/operation/iterate.jl
@@ -5,10 +5,10 @@ _POIteratorStateType = NamedTuple{
 }
 
 @inline function Base.iterate(
-    possibleOperations::PossibleOperations
+    possible_operations::PossibleOperations
 )::Union{Nothing,_POIteratorStateType}
     for fieldname in _POSSIBLE_OPERATIONS_FIELDS
-        iterator = iterate(getfield(possibleOperations, fieldname))
+        iterator = iterate(getfield(possible_operations, fieldname))
         if (!isnothing(iterator))
             return (result=iterator[1], state=(fieldname, iterator[2]))
         end
@@ -18,10 +18,10 @@ _POIteratorStateType = NamedTuple{
 end
 
 @inline function Base.iterate(
-    possibleOperations::PossibleOperations, state
+    possible_operations::PossibleOperations, state
 )::Union{Nothing,_POIteratorStateType}
     newStateSym = state[1]
-    newStateIt = iterate(getfield(possibleOperations, newStateSym), state[2])
+    newStateIt = iterate(getfield(possible_operations, newStateSym), state[2])
     if !isnothing(newStateIt)
         return (result=newStateIt[1], state=(newStateSym, newStateIt[2]))
     end
@@ -31,7 +31,7 @@ end
 
     while index <= length(_POSSIBLE_OPERATIONS_FIELDS)
         newStateSym = _POSSIBLE_OPERATIONS_FIELDS[index]
-        newStateIt = iterate(getfield(possibleOperations, newStateSym))
+        newStateIt = iterate(getfield(possible_operations, newStateSym))
         if !isnothing(newStateIt)
             return (result=newStateIt[1], state=(newStateSym, newStateIt[2]))
         end
diff --git a/src/operation/print.jl b/src/operation/print.jl
index 274531a..400e32d 100644
--- a/src/operation/print.jl
+++ b/src/operation/print.jl
@@ -4,14 +4,14 @@ Print a string representation of the set of possible operations to io.
 """
 function Base.show(io::IO, ops::PossibleOperations)
-    print(io, length(ops.nodeReductions))
+    print(io, length(ops.node_reductions))
     println(io, " Node Reductions: ")
-    for nr in ops.nodeReductions
+    for nr in ops.node_reductions
         println(io, " - ", nr)
     end
-    print(io, length(ops.nodeSplits))
+    print(io, length(ops.node_splits))
     println(io, " Node Splits: ")
-    for ns in ops.nodeSplits
+    for ns in ops.node_splits
         println(io, " - ", ns)
     end
 end
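# Aside: the custom `Base.iterate` methods above walk the fields of
# `PossibleOperations` one after the other, so a single loop sees all reductions
# first, then all splits. The same field-cycling idea over a plain NamedTuple of
# collections (made-up data, no package types):
toy_ops = (node_reductions=[:nr1, :nr2], node_splits=[:ns1])

for op in Iterators.flatten(values(toy_ops))
    println(op)   # :nr1, :nr2, :ns1 -- the same order the iterate methods produce
end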
""" function Base.isempty(operations::PossibleOperations) - return isempty(operations.nodeReductions) && isempty(operations.nodeSplits) + return isempty(operations.node_reductions) && isempty(operations.node_splits) end """ @@ -14,8 +14,8 @@ Return a named tuple with the number of each of the operation types as a named t """ function Base.length(operations::PossibleOperations) return ( - nodeReductions=length(operations.nodeReductions), - nodeSplits=length(operations.nodeSplits), + node_reductions=length(operations.node_reductions), + node_splits=length(operations.node_splits), ) end @@ -25,7 +25,7 @@ end Delete the given node reduction from the possible operations. """ function Base.delete!(operations::PossibleOperations, op::NodeReduction) - delete!(operations.nodeReductions, op) + delete!(operations.node_reductions, op) return operations end @@ -35,7 +35,7 @@ end Delete the given node split from the possible operations. """ function Base.delete!(operations::PossibleOperations, op::NodeSplit) - delete!(operations.nodeSplits, op) + delete!(operations.node_splits, op) return operations end diff --git a/src/optimization/random_walk.jl b/src/optimization/random_walk.jl index 62d02c2..ab381a7 100644 --- a/src/optimization/random_walk.jl +++ b/src/optimization/random_walk.jl @@ -15,7 +15,7 @@ function optimize_step!(optimizer::RandomWalkOptimizer, graph::DAG) operations = get_operations(graph) if sum(length(operations)) == 0 && - length(graph.appliedOperations) + length(graph.operationsToApply) == 0 + length(graph.applied_operations) + length(graph.operations_to_apply) == 0 # in case there are zero operations possible at all on the graph return false end @@ -29,11 +29,11 @@ function optimize_step!(optimizer::RandomWalkOptimizer, graph::DAG) # choose one of split/reduce option = rand(r, 1:2) - if option == 1 && !isempty(operations.nodeReductions) - push_operation!(graph, rand(r, collect(operations.nodeReductions))) + if option == 1 && !isempty(operations.node_reductions) + push_operation!(graph, rand(r, collect(operations.node_reductions))) return true - elseif option == 2 && !isempty(operations.nodeSplits) - push_operation!(graph, rand(r, collect(operations.nodeSplits))) + elseif option == 2 && !isempty(operations.node_splits) + push_operation!(graph, rand(r, collect(operations.node_splits))) return true end else diff --git a/src/optimization/reduce.jl b/src/optimization/reduce.jl index edd49d2..36bdb0e 100644 --- a/src/optimization/reduce.jl +++ b/src/optimization/reduce.jl @@ -14,14 +14,14 @@ function optimize_step!(optimizer::ReductionOptimizer, graph::DAG) return false end - push_operation!(graph, first(operations.nodeReductions)) + push_operation!(graph, first(operations.node_reductions)) return true end function fixpoint_reached(optimizer::ReductionOptimizer, graph::DAG) operations = get_operations(graph) - return isempty(operations.nodeReductions) + return isempty(operations.node_reductions) end function optimize_to_fixpoint!(optimizer::ReductionOptimizer, graph::DAG) diff --git a/src/optimization/split.jl b/src/optimization/split.jl index 0d38dae..b610494 100644 --- a/src/optimization/split.jl +++ b/src/optimization/split.jl @@ -14,14 +14,14 @@ function optimize_step!(optimizer::SplitOptimizer, graph::DAG) return false end - push_operation!(graph, first(operations.nodeSplits)) + push_operation!(graph, first(operations.node_splits)) return true end function fixpoint_reached(optimizer::SplitOptimizer, graph::DAG) operations = get_operations(graph) - return isempty(operations.nodeSplits) 
diff --git a/src/optimization/split.jl b/src/optimization/split.jl
index 0d38dae..b610494 100644
--- a/src/optimization/split.jl
+++ b/src/optimization/split.jl
@@ -14,14 +14,14 @@ function optimize_step!(optimizer::SplitOptimizer, graph::DAG)
         return false
     end
 
-    push_operation!(graph, first(operations.nodeSplits))
+    push_operation!(graph, first(operations.node_splits))
 
     return true
 end
 
 function fixpoint_reached(optimizer::SplitOptimizer, graph::DAG)
     operations = get_operations(graph)
-    return isempty(operations.nodeSplits)
+    return isempty(operations.node_splits)
 end
 
 function optimize_to_fixpoint!(optimizer::SplitOptimizer, graph::DAG)
diff --git a/src/properties/create.jl b/src/properties/create.jl
index 72f7b9a..5d0de61 100644
--- a/src/properties/create.jl
+++ b/src/properties/create.jl
@@ -5,7 +5,11 @@ Create an empty [`GraphProperties`](@ref) object.
 """
 function GraphProperties()
     return (
-        data=0.0, computeEffort=0.0, computeIntensity=0.0, noNodes=0, noEdges=0
+        data=0.0,
+        compute_effort=0.0,
+        compute_intensity=0.0,
+        number_of_nodes=0,
+        number_of_edges=0,
     )::GraphProperties
 end
 
@@ -41,10 +45,10 @@ function GraphProperties(graph::DAG)
 
     return (
         data=d,
-        computeEffort=ce,
-        computeIntensity=(d == 0) ? 0.0 : ce / d,
-        noNodes=length(graph.nodes),
-        noEdges=ed,
+        compute_effort=ce,
+        compute_intensity=(d == 0) ? 0.0 : ce / d,
+        number_of_nodes=length(graph.nodes),
+        number_of_edges=ed,
     )::GraphProperties
 end
 
@@ -66,9 +70,9 @@ function GraphProperties(diff::Diff)
 
     return (
         data=d,
-        computeEffort=ce,
-        computeIntensity=(d == 0) ? 0.0 : ce / d,
-        noNodes=length(diff.addedNodes) - length(diff.removedNodes),
-        noEdges=length(diff.addedEdges) - length(diff.removedEdges),
+        compute_effort=ce,
+        compute_intensity=(d == 0) ? 0.0 : ce / d,
+        number_of_nodes=length(diff.addedNodes) - length(diff.removedNodes),
+        number_of_edges=length(diff.addedEdges) - length(diff.removedEdges),
     )::GraphProperties
 end
diff --git a/src/properties/type.jl b/src/properties/type.jl
index 9fea9c6..afa9f19 100644
--- a/src/properties/type.jl
+++ b/src/properties/type.jl
@@ -3,14 +3,14 @@
 
 Representation of a [`DAG`](@ref)'s properties.
 
-# Fields:
-`.data`: The total data transfer.\\
-`.computeEffort`: The total compute effort.\\
-`.computeIntensity`: The compute intensity, will always equal `.computeEffort / .data`.\\
-`.noNodes`: Number of [`Node`](@ref)s.\\
-`.noEdges`: Number of [`Edge`](@ref)s.
+## Fields:
+- `data::Float64`: The total data transfer.
+- `compute_effort::Float64`: The total compute effort.
+- `compute_intensity::Float64`: The compute intensity, will always equal `compute_effort / data`.
+- `number_of_nodes::Int`: Number of [`Node`](@ref)s.
+- `number_of_edges::Int`: Number of [`Edge`](@ref)s.
 """
 const GraphProperties = NamedTuple{
-    (:data, :computeEffort, :computeIntensity, :noNodes, :noEdges),
+    (:data, :compute_effort, :compute_intensity, :number_of_nodes, :number_of_edges),
     Tuple{Float64,Float64,Float64,Int,Int},
 }
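# Aside: `GraphProperties` is a plain NamedTuple type, so instances can be built
# directly. A minimal sketch of the invariant kept by the constructors above
# (the values here are made up):
props = (
    data=16.0,
    compute_effort=64.0,
    compute_intensity=64.0 / 16.0,   # always compute_effort / data (0.0 if data == 0)
    number_of_nodes=7,
    number_of_edges=6,
)
@assert props.compute_intensity == props.compute_effort / props.data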
diff --git a/src/properties/utility.jl b/src/properties/utility.jl
index 1760d2a..ceddc76 100644
--- a/src/properties/utility.jl
+++ b/src/properties/utility.jl
@@ -7,14 +7,14 @@ Also take care to keep consistent compute intensity.
 function Base.:-(prop1::GraphProperties, prop2::GraphProperties)
     return (
         data=prop1.data - prop2.data,
-        computeEffort=prop1.computeEffort - prop2.computeEffort,
-        computeIntensity=if (prop1.data - prop2.data == 0)
+        compute_effort=prop1.compute_effort - prop2.compute_effort,
+        compute_intensity=if (prop1.data - prop2.data == 0)
             0.0
         else
-            (prop1.computeEffort - prop2.computeEffort) / (prop1.data - prop2.data)
+            (prop1.compute_effort - prop2.compute_effort) / (prop1.data - prop2.data)
         end,
-        noNodes=prop1.noNodes - prop2.noNodes,
-        noEdges=prop1.noEdges - prop2.noEdges,
+        number_of_nodes=prop1.number_of_nodes - prop2.number_of_nodes,
+        number_of_edges=prop1.number_of_edges - prop2.number_of_edges,
     )::GraphProperties
 end
 
@@ -27,28 +27,28 @@ Also take care to keep consistent compute intensity.
 function Base.:+(prop1::GraphProperties, prop2::GraphProperties)
     return (
         data=prop1.data + prop2.data,
-        computeEffort=prop1.computeEffort + prop2.computeEffort,
-        computeIntensity=if (prop1.data + prop2.data == 0)
+        compute_effort=prop1.compute_effort + prop2.compute_effort,
+        compute_intensity=if (prop1.data + prop2.data == 0)
             0.0
         else
-            (prop1.computeEffort + prop2.computeEffort) / (prop1.data + prop2.data)
+            (prop1.compute_effort + prop2.compute_effort) / (prop1.data + prop2.data)
         end,
-        noNodes=prop1.noNodes + prop2.noNodes,
-        noEdges=prop1.noEdges + prop2.noEdges,
+        number_of_nodes=prop1.number_of_nodes + prop2.number_of_nodes,
+        number_of_edges=prop1.number_of_edges + prop2.number_of_edges,
     )::GraphProperties
 end
 
 """
     -(prop::GraphProperties)
 
-Unary negation of the graph properties. `.computeIntensity` will not be negated because `.data` and `.computeEffort` both are.
+Unary negation of the graph properties. `.compute_intensity` will not be negated because `.data` and `.compute_effort` both are.
 """
 function Base.:-(prop::GraphProperties)
     return (
         data=-prop.data,
-        computeEffort=-prop.computeEffort,
-        computeIntensity=prop.computeIntensity, # no negation here!
-        noNodes=-prop.noNodes,
-        noEdges=-prop.noEdges,
+        compute_effort=-prop.compute_effort,
+        compute_intensity=prop.compute_intensity, # no negation here!
+        number_of_nodes=-prop.number_of_nodes,
+        number_of_edges=-prop.number_of_edges,
     )::GraphProperties
 end
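# Aside: because the arithmetic above recomputes `compute_intensity` from the summed
# fields, intensities do not add linearly. A small made-up example:
a = (data=2.0, compute_effort=8.0, compute_intensity=4.0, number_of_nodes=3, number_of_edges=2)
b = (data=6.0, compute_effort=6.0, compute_intensity=1.0, number_of_nodes=4, number_of_edges=3)
# with the `Base.:+` defined above, (a + b).compute_intensity == (8.0 + 6.0) / (2.0 + 6.0) == 1.75,
# not 4.0 + 1.0. Unary negation keeps the intensity unchanged, since both the
# numerator and the denominator flip sign.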
diff --git a/src/scheduler/greedy.jl b/src/scheduler/greedy.jl
index 63795f1..61077b1 100644
--- a/src/scheduler/greedy.jl
+++ b/src/scheduler/greedy.jl
@@ -7,14 +7,16 @@ A greedy implementation of a scheduler, creating a topological ordering of nodes
 struct GreedyScheduler <: AbstractScheduler end
 
 function schedule_dag(::GreedyScheduler, graph::DAG, machine::Machine)
-    node_queue = PriorityQueue{Node,Int}()
+    node_dict = Dict{Node,Int}() # dictionary of nodes with the number of not-yet-scheduled children
+    node_stack = Stack{Node}() # stack of currently schedulable nodes, i.e., nodes with all of their children already scheduled
+    # the stack makes sure that closely related nodes will be scheduled one after another
 
     # use a priority equal to the number of unseen children -> 0 are nodes that can be added
     for node in get_entry_nodes(graph)
-        enqueue!(node_queue, node => 0)
+        push!(node_stack, node)
     end
 
-    schedule = Vector{Node}()
+    schedule = Node[]
     sizehint!(schedule, length(graph.nodes))
 
     # keep an accumulated cost of things scheduled to this device so far
@@ -24,9 +26,8 @@ function schedule_dag(::GreedyScheduler, graph::DAG, machine::Machine)
     end
 
     local node
-    while !isempty(node_queue)
-        @assert peek(node_queue)[2] == 0
-        node = dequeue!(node_queue)
+    while !isempty(node_stack)
+        node = pop!(node_stack)
 
         # assign the device with lowest accumulated cost to the node (if it's a compute node)
         if (isa(node, ComputeTaskNode))
@@ -37,15 +38,20 @@ function schedule_dag(::GreedyScheduler, graph::DAG, machine::Machine)
 
         push!(schedule, node)
 
+        # decrement each parent's count of not-yet-scheduled children
+        # if it reaches zero, the parent becomes schedulable and is pushed onto node_stack
        for parent in parents(node)
-            # reduce the priority of all parents by one
-            if (!haskey(node_queue, parent))
-                enqueue!(node_queue, parent => length(children(parent)) - 1)
+            parents_prio = get(node_dict, parent, length(children(parent))) - 1
+            if parents_prio == 0
+                delete!(node_dict, parent)
+                push!(node_stack, parent)
             else
-                node_queue[parent] = node_queue[parent] - 1
+                node_dict[parent] = parents_prio
             end
         end
     end
 
+    @assert isempty(node_dict) "found unschedulable nodes, this most likely means the graph has a cycle"
+
     return schedule
 end
diff --git a/src/scheduler/interface.jl b/src/scheduler/interface.jl
index 1dcbfbf..b452a36 100644
--- a/src/scheduler/interface.jl
+++ b/src/scheduler/interface.jl
@@ -1,11 +1,3 @@
-
-"""
-    AbstractScheduler
-
-Abstract base type for scheduler implementations. The scheduler is used to assign each node to a device and create a topological ordering of tasks.
-"""
-abstract type AbstractScheduler end
-
 """
     schedule_dag(::Scheduler, ::DAG, ::Machine)
diff --git a/src/scheduler/type.jl b/src/scheduler/type.jl
index 008f677..7eb4062 100644
--- a/src/scheduler/type.jl
+++ b/src/scheduler/type.jl
@@ -1,16 +1,6 @@
-using StaticArrays
-
 """
-    FunctionCall{N}
+    AbstractScheduler
 
-Type representing a function call with `N` parameters. Contains the function to call, argument symbols, the return symbol and the device to execute on.
+Abstract base type for scheduler implementations. The scheduler is used to assign each node to a device and create a topological ordering of tasks.
 """
-mutable struct FunctionCall{VectorType<:AbstractVector,N}
-    func::Function
-    # TODO: this should be a tuple
-    value_arguments::SVector{N,Any} # value arguments for the function call, will be prepended to the other arguments
-    arguments::VectorType # symbols of the inputs to the function call
-    return_symbol::Symbol
-    return_type::Type
-    device::AbstractDevice
-end
+abstract type AbstractScheduler end
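# Aside: the new `schedule_dag` above is essentially Kahn's algorithm with a stack
# instead of a priority queue. A self-contained sketch on a toy adjacency
# (made-up data, no package types; entry nodes are nodes without children):
function toy_schedule(children_of::Dict{Symbol,Vector{Symbol}}, parents_of::Dict{Symbol,Vector{Symbol}})
    remaining = Dict{Symbol,Int}()   # counts of not-yet-scheduled children
    stack = [n for n in keys(parents_of) if isempty(get(children_of, n, Symbol[]))]
    order = Symbol[]
    while !isempty(stack)
        node = pop!(stack)           # depth-first: closely related nodes stay together
        push!(order, node)
        for parent in get(parents_of, node, Symbol[])
            prio = get(remaining, parent, length(children_of[parent])) - 1
            if prio == 0             # all of the parent's children are scheduled now
                delete!(remaining, parent)
                push!(stack, parent)
            else
                remaining[parent] = prio
            end
        end
    end
    @assert isempty(remaining) "unschedulable nodes, the graph likely has a cycle"
    return order
end

# usage, on a diamond a -> (b, c) -> d with entry node :a:
# toy_schedule(Dict(:b => [:a], :c => [:a], :d => [:b, :c]),
#              Dict(:a => [:b, :c], :b => [:d], :c => [:d], :d => Symbol[]))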
diff --git a/src/task/compute.jl b/src/task/compute.jl
index 4355142..b76e238 100644
--- a/src/task/compute.jl
+++ b/src/task/compute.jl
@@ -2,16 +2,17 @@ using StaticArrays
 
 """
     get_function_call(n::Node)
-    get_function_call(t::AbstractTask, device::AbstractDevice, in_symbols::AbstractVector, out_symbol::Symbol)
+    get_function_call(t::AbstractTask, device::AbstractDevice, in_symbols::NTuple{N,Symbol}, out_symbol::Symbol)
 
-For a node or a task together with necessary information, return a vector of [`FunctionCall`](@ref)s for the computation of the node or task.
-
-For ordinary compute or data tasks the vector will contain exactly one element.
+For a node or a task together with the necessary information, return a [`FunctionCall`](@ref) for the computation of the node or task.
 """
 function get_function_call(
-    t::CompTask, device::AbstractDevice, in_symbols::AbstractVector, out_symbol::Symbol
-) where {CompTask<:AbstractComputeTask}
-    return [FunctionCall(compute, SVector{1,Any}(t), in_symbols, out_symbol, Any, device)]
+    t::AbstractComputeTask,
+    device::AbstractDevice,
+    in_symbols::NTuple{N,Symbol},
+    out_symbol::Symbol,
+) where {N}
+    return FunctionCall(compute, (t,), [in_symbols...], [out_symbol], [Any], device)
 end
 
 function get_function_call(node::ComputeTaskNode)
@@ -21,74 +22,117 @@ function get_function_call(node::ComputeTaskNode)
     # make sure the node is sorted so the arguments keep their order
     sort_node!(node)
 
-    if (length(node.children) <= 800)
-        #only use an SVector when there are few children
-        return get_function_call(
-            node.task,
-            node.device,
-            SVector{length(node.children),Symbol}(
-                Symbol.(to_var_name.(getfield.(getindex.(children(node), 1), :id)))...
-            ),
-            Symbol(to_var_name(node.id)),
-        )
-    else
-        return get_function_call(
-            node.task,
-            node.device,
-            Symbol.(to_var_name.(getfield.(getindex.(children(node), 1), :id))),
-            Symbol(to_var_name(node.id)),
-        )
-    end
+    return get_function_call(
+        node.task,
+        node.device,
+        (Symbol.(to_var_name.(getfield.(getindex.(children(node), 1), :id)))...,),
+        Symbol(to_var_name(node.id)),
+    )
 end
 
 function get_function_call(node::DataTaskNode)
     @assert length(children(node)) == 1 "trying to call get_function_call on a data task node that has $(length(node.children)) children instead of 1\nchildren: $(node.children)"
 
     # TODO: dispatch to device implementations generating the copy commands
-    return [
-        FunctionCall(
-            unpack_identity,
-            SVector{0,Any}(),
-            SVector{1,Symbol}(Symbol(to_var_name(first(children(node))[1].id))),
-            Symbol(to_var_name(node.id)),
-            Any,
-            first(children(node))[1].device,
-        ),
-    ]
+    return FunctionCall(
+        identity,
+        (),
+        [Symbol(to_var_name(first(children(node))[1].id))],
+        [Symbol(to_var_name(node.id))],
+        [Any],
+        first(children(node))[1].device,
+    )
 end
 
 function get_init_function_call(node::DataTaskNode, device::AbstractDevice)
     @assert isempty(children(node)) "trying to call get_init_function_call on a data task node that is not an entry node."
 
     return FunctionCall(
-        unpack_identity,
-        SVector{0,Any}(),
-        SVector{1,Symbol}(Symbol("$(to_var_name(node.id))_in")),
-        Symbol(to_var_name(node.id)),
-        Any,
+        identity,
+        (),
+        [Symbol("$(to_var_name(node.id))_in")],
+        [Symbol(to_var_name(node.id))],
+        [Any],
         device,
     )
 end
 
-function result_type(fc::FunctionCall, known_res_types::Dict{Symbol,Type})
-    argument_types = (
-        typeof.(fc.value_arguments)..., getindex.(Ref(known_res_types), fc.arguments)...
-    )
-    types = Base.return_types(fc.func, argument_types)
+_value_argument_types(fc::FunctionCall) = typeof.(fc.value_arguments[1])
+function _argument_types(known_res_types::Dict{Symbol,Type}, fc::FunctionCall)
+    return getindex.(Ref(known_res_types), fc.arguments[1])
+end
 
+function _validate_result_types(fc::FunctionCall, types, arg_types)
+    N_RET = length(fc.return_types)
     if length(types) > 1
         throw(
-            "failure during type inference: function call $fc with argument types $(argument_types) is type unstable, possible return types: $types",
+            "failure during type inference: function call $(fc.func) with argument types $(arg_types) is type unstable, possible return types: $types",
         )
     end
     if isempty(types)
         throw(
-            "failure during type inference: function call $fc with argument types $(argument_types) has no return types, this is likely because no method matches the arguments",
+            "failure during type inference: function call $(fc.func) with argument types $(arg_types) has no return types, this is likely because no method matches the arguments",
        )
     end
     if types[1] == Any
-        @warn "inferred return type 'Any' in task $fc with argument types $(argument_types)"
+        @warn "inferred return type 'Any' in function call $(fc.func) with argument types $(arg_types)"
+    end
+
+    if (N_RET == 1)
+        return nothing
+    end
+
+    if !(types[1] <: Tuple) || length(types[1].parameters) != N_RET
+        throw(
+            "failure during type inference: function call $(fc.func) was expected to return a Tuple with $N_RET elements, but returns $(types[1])",
+        )
     end
+    return nothing
+end
+
+function result_types(
+    fc::FunctionCall{VAL_T,F_T}, known_res_types::Dict{Symbol,Type}, context_module::Module
+) where {VAL_T,F_T<:Function}
+    arg_types = (_value_argument_types(fc)..., _argument_types(known_res_types, fc)...)
+    types = Base.return_types(fc.func, arg_types)
+
+    _validate_result_types(fc, types, arg_types)
+
+    N_RET = length(fc.return_types)
+    if (N_RET == 1)
+        return [types[1]]
+    end
+
+    return [types[1].parameters...]
+end
+
+function result_types(
+    fc::FunctionCall{VAL_T,Expr}, known_res_types::Dict{Symbol,Type}, context_module::Module
+) where {VAL_T}
+    arg_types = _argument_types(known_res_types, fc)
+    ret_expr = Expr(
+        :call,
+        Base.return_types, # return types call
+        Expr( # function argument to return_types
+            :->, # anonymous function
+            Expr(
+                :tuple, # anonymous function arguments
+                fc.arguments[1]...,
+            ),
+            fc.func, # anonymous function code block
+        ),
+        Expr(:tuple, arg_types...), # types arguments to return_types
+    )
+    types = context_module.eval(ret_expr)
+
+    #@info "evaluation of expression\n$ret_expr\ngives\n$types"
+
+    _validate_result_types(fc, types, arg_types)
+
+    N_RET = length(fc.return_types)
+    if (N_RET == 1)
+        return [types[1]]
+    end
-    return types[1]
+    return [types[1].parameters...]
 end
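# Aside: `result_types` above leans on `Base.return_types` for the inferred return
# type of each call. A standalone illustration of the checks it performs
# (hypothetical toy functions, not part of the package):
f_stable(a::Int, b::Float64) = a * b   # infers to Float64
f_pair(a::Int) = (a, float(a))         # infers to Tuple{Int, Float64}

types = Base.return_types(f_stable, (Int, Float64))
@assert length(types) == 1 && types[1] == Float64   # exactly one inferred type: stable

tuple_type = Base.return_types(f_pair, (Int,))[1]
@assert tuple_type <: Tuple                           # the multi-return case unpacks...
@assert [tuple_type.parameters...] == [Int, Float64]  # ...these element types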
diff --git a/src/trie.jl b/src/trie.jl
index 23cf7ee..e949124 100644
--- a/src/trie.jl
+++ b/src/trie.jl
@@ -46,24 +46,7 @@ Insert the given node into the trie. The depth is used to iterate through the tr
 """
 function insert_helper!(
     trie::NodeIdTrie{NodeType}, node::NodeType, depth::Int
-) where {TaskType<:AbstractDataTask,NodeType<:DataTaskNode{TaskType}}
-    if (length(children(node)) == depth)
-        push!(trie.value, node)
-        return nothing
-    end
-
-    depth = depth + 1
-    id = node.children[depth][1].id
-
-    if (!haskey(trie.children, id))
-        trie.children[id] = NodeIdTrie{NodeType}()
-    end
-    return insert_helper!(trie.children[id], node, depth)
-end
-# TODO: Remove this workaround once https://github.com/JuliaLang/julia/issues/54404 is fixed in julia 1.10+
-function insert_helper!(
-    trie::NodeIdTrie{NodeType}, node::NodeType, depth::Int
-) where {TaskType<:AbstractComputeTask,NodeType<:ComputeTaskNode{TaskType}}
+) where {NodeType<:Node}
     if (length(children(node)) == depth)
         push!(trie.value, node)
         return nothing
@@ -83,18 +66,7 @@ end
 
 Insert the given node into the trie. It's sorted by its type in the first layer, then by its children in the following layers.
 """
-function Base.insert!(
-    trie::NodeTrie, node::NodeType
-) where {TaskType<:AbstractDataTask,NodeType<:DataTaskNode{TaskType}}
-    if (!haskey(trie.children, NodeType))
-        trie.children[NodeType] = NodeIdTrie{NodeType}()
-    end
-    return insert_helper!(trie.children[NodeType], node, 0)
-end
-# TODO: Remove this workaround once https://github.com/JuliaLang/julia/issues/54404 is fixed in julia 1.10+
-function Base.insert!(
-    trie::NodeTrie, node::NodeType
-) where {TaskType<:AbstractComputeTask,NodeType<:ComputeTaskNode{TaskType}}
+function Base.insert!(trie::NodeTrie, node::NodeType) where {NodeType<:Node}
     if (!haskey(trie.children, NodeType))
         trie.children[NodeType] = NodeIdTrie{NodeType}()
     end
diff --git a/src/utils.jl b/src/utils.jl
index 862eda9..b8593e0 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -5,14 +5,6 @@ Function with no arguments, returns nothing, does nothing. Useful for noop [`Fun
 """
 @inline noop() = nothing
 
-"""
-    unpack_identity(x::SVector)
-
-Function taking an `SVector`, returning it unpacked.
-"""
-@inline unpack_identity(x::SVector{1,<:Any}) = x[1]
-@inline unpack_identity(x) = x
-
 """
     bytes_to_human_readable(bytes)
 
@@ -80,18 +72,18 @@ function mem(graph::DAG)
         size += mem(n)
     end
 
-    size += sizeof(graph.appliedOperations)
-    size += sizeof(graph.operationsToApply)
+    size += sizeof(graph.applied_operations)
+    size += sizeof(graph.operations_to_apply)
 
-    size += sizeof(graph.possibleOperations)
-    for op in graph.possibleOperations.nodeReductions
+    size += sizeof(graph.possible_operations)
+    for op in graph.possible_operations.node_reductions
         size += mem(op)
     end
-    for op in graph.possibleOperations.nodeSplits
+    for op in graph.possible_operations.node_splits
         size += mem(op)
     end
 
-    size += Base.summarysize(graph.dirtyNodes; exclude=Union{Node})
+    size += Base.summarysize(graph.dirty_nodes; exclude=Union{Node})
 
     return size += sizeof(diff)
 end
@@ -118,17 +110,10 @@ end
 
 Return the given vector as single String without quotation marks or brackets.
 """
-function unroll_symbol_vector(vec::Vector)
-    result = ""
-    for s in vec
-        if (result != "")
-            result *= ", "
-        end
-        result *= "$s"
-    end
-    return result
+function unroll_symbol_vector(vec::VEC) where {VEC<:Union{AbstractVector,Tuple}}
+    return Expr(:tuple, vec...)
 end
 
-function unroll_symbol_vector(vec::SVector)
-    return unroll_symbol_vector(Vector(vec))
+@inline function _call(f, args::Vararg)
+    return f(args...)
 end
diff --git a/test/strassen_test.jl b/test/strassen_test.jl
index e67491e..61c927d 100644
--- a/test/strassen_test.jl
+++ b/test/strassen_test.jl
@@ -38,8 +38,8 @@ EDGE_NUMBERS = (3, 96, 747, 5304) #, 37203
         @test get_exit_node(g) isa DataTaskNode
 
         props = get_properties(g)
-        @test NODE_NUM_EXPECTED == props.noNodes
-        @test EDGE_NUM_EXPECTED == props.noEdges
+        @test NODE_NUM_EXPECTED == props.number_of_nodes
+        @test EDGE_NUM_EXPECTED == props.number_of_edges
     end
 
     f = get_compute_function(g, mm, cpu_st(), @__MODULE__)
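# Aside: `unroll_symbol_vector` in src/utils.jl above now builds a tuple
# *expression* instead of a string, which the code generator can splice directly
# into generated code. A short illustration (made-up symbols):
syms = [:a, :b, :c]
ex = Expr(:tuple, syms...)   # what unroll_symbol_vector([:a, :b, :c]) returns
@assert ex == :((a, b, c))
# spliced into generated code, e.g. :($ex = some_call(x)), this destructures a
# multi-valued result into a, b and c in a single assignment -- the mechanism
# behind the multiple returned objects this PR introduces.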