Skip to content

Commit

Permalink
Rework FunctionCalls to allow multiple returned objects (#44)
Browse files Browse the repository at this point in the history
* Restructure FunctionCall

* Make closures work again

* Improve scheduler

* Change Function Calls to mostly use Vectors instead of Tuples

This makes small function calls slower, but large function calls much much faster to compile

* WIP fix input assignment code world age problem

* Remove trie workaround

* Remove call_fc and execute functions

* Remove superfluous _gen_access_expr and simplify _gen_local_init

* Renaming of member variables

* renaming/docs/removing unused interfaces

fix type stability in closures

* Work on performance problems with closures

(again...)

* Fix NodeSplit
  • Loading branch information
AntonReinhard authored Dec 4, 2024
1 parent 4a35700 commit 1c35bb9
Show file tree
Hide file tree
Showing 40 changed files with 641 additions and 619 deletions.
11 changes: 3 additions & 8 deletions ext/devices/cuda/function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,21 @@ function ComputableDAGs.kernel(
machine = cpu_st()
tape = ComputableDAGs.gen_tape(graph, instance, machine, context_module)

assign_inputs = Expr(:block, ComputableDAGs.expr_from_fc.(tape.inputAssignCode)...)
assign_inputs = Expr(:block, ComputableDAGs.expr_from_fc.(tape.input_assign_code)...)
# TODO: use gen_function_body here
code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.schedule)...)

function_id = ComputableDAGs.to_var_name(UUIDs.uuid1(ComputableDAGs.rng[1]))
res_sym = eval(
ComputableDAGs._gen_access_expr(
ComputableDAGs.entry_device(tape.machine), tape.outputSymbol
),
)
expr = Meta.parse(
"function compute_$(function_id)(input_vector, output_vector, n::Int64)
id = (blockIdx().x - 1) * blockDim().x + threadIdx().x
if (id > n)
return
end
@inline data_input = input_vector[id]
@inline input = input_vector[id]
$(assign_inputs)
$code
@inline output_vector[id] = $res_sym
@inline output_vector[id] = $(tape.output_symbol)
return nothing
end"
)
Expand Down
11 changes: 3 additions & 8 deletions ext/devices/rocm/function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,22 @@ function ComputableDAGs.kernel(
machine = cpu_st()
tape = ComputableDAGs.gen_tape(graph, instance, machine, context_module)

assign_inputs = Expr(:block, ComputableDAGs.expr_from_fc.(tape.inputAssignCode)...)
assign_inputs = Expr(:block, ComputableDAGs.expr_from_fc.(tape.input_assign_code)...)

# TODO use gen_function_body here
code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.schedule)...)

function_id = ComputableDAGs.to_var_name(UUIDs.uuid1(ComputableDAGs.rng[1]))
res_sym = eval(
ComputableDAGs._gen_access_expr(
ComputableDAGs.entry_device(tape.machine), tape.outputSymbol
),
)
expr = Meta.parse(
"function compute_$(function_id)(input_vector, output_vector, n::Int64)
id = (workgroupIdx().x - 1) * workgroupDim().x + workgroupIdx().x
if (id > n)
return
end
@inline data_input = input_vector[id]
@inline input = input_vector[id]
$(assign_inputs)
$code
@inline output_vector[id] = $res_sym
@inline output_vector[id] = $(tape.output_symbol)
return nothing
end"
)
Expand Down
8 changes: 2 additions & 6 deletions src/ComputableDAGs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,7 @@ export reset_graph!
export get_operations

# code generation related
export execute
export get_compute_function
export gen_tape, execute_tape
export unpack_identity

# estimator
export cost_type, graph_cost, operation_effect
Expand All @@ -50,8 +47,7 @@ export optimize_step!, optimize!
export fixpoint_reached, optimize_to_fixpoint!

# models
export AbstractModel, AbstractProblemInstance
export problem_instance, input_type, graph, input_expr
export graph, input_type, input_expr

# machine info
export Machine
Expand All @@ -69,6 +65,7 @@ include("properties/type.jl")
include("operation/type.jl")
include("graph/type.jl")
include("scheduler/type.jl")
include("code_gen/type.jl")

include("trie.jl")
include("utils.jl")
Expand Down Expand Up @@ -127,7 +124,6 @@ include("devices/ext.jl")
include("scheduler/interface.jl")
include("scheduler/greedy.jl")

include("code_gen/type.jl")
include("code_gen/utils.jl")
include("code_gen/tape_machine.jl")
include("code_gen/function.jl")
Expand Down
44 changes: 11 additions & 33 deletions src/code_gen/function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,50 +17,28 @@ in your top level.
## Keyword Arguments
`closures_size` (default=0 (off)): The size of closures to use in the main generated code. This specifies the size of code blocks across which the compiler cannot optimize. For sufficiently large functions, a larger value means longer compile times but potentially faster execution time.
`closures_size` (default=0 (off)): The size of closures to use in the main generated code. This specifies the size of code blocks across which the
compiler cannot optimize. For sufficiently large functions, a larger value means longer compile times but potentially faster execution time.
**Note** that the actually used closure size might be different than the one passed here, since the function automatically chooses a size that
is close to a n-th root of the total number of loc, based off the given size.
"""
function get_compute_function(
graph::DAG, instance, machine::Machine, context_module::Module; closures_size=0
)
tape = gen_tape(graph, instance, machine, context_module)

assignInputs = Expr(:block, expr_from_fc.(tape.inputAssignCode)...)
code = gen_function_body(tape; closures_size=closures_size)
code = gen_function_body(tape, context_module; closures_size=closures_size)
assign_inputs = Expr(:block, expr_from_fc.(tape.input_assign_code)...)

functionId = to_var_name(UUIDs.uuid1(rng[1]))
resSym = eval(_gen_access_expr(entry_device(tape.machine), tape.outputSymbol))
expr = #
Expr(
function_id = to_var_name(UUIDs.uuid1(rng[1]))
res_sym = tape.output_symbol
expr = Expr(
:function, # function definition
Expr(
:call,
Symbol("compute_$functionId"),
Expr(:(::), :data_input, input_type(instance)),
:call, Symbol("compute_$function_id"), Expr(:(::), :input, input_type(instance))
), # function name and parameters
Expr(:block, assignInputs, code, Expr(:return, resSym)), # function body
Expr(:block, assign_inputs, code, Expr(:return, res_sym)), # function body
)

return RuntimeGeneratedFunction(@__MODULE__, context_module, expr)
end

"""
execute(
graph::DAG,
instance,
machine::Machine,
input,
context_module::Module
)
Execute the code of the given `graph` on the given input values.
This is essentially shorthand for
```julia
tape = gen_tape(graph, instance, machine, context_module)
return execute_tape(tape, input)
```
"""
function execute(graph::DAG, instance, machine::Machine, input, context_module::Module)
tape = gen_tape(graph, instance, machine, context_module)
return execute_tape(tape, input)
end
Loading

0 comments on commit 1c35bb9

Please sign in to comment.