Wma/shard levels #697

Open
wants to merge 28 commits into main from wma/shard_levels
Changes from 15 commits

28 commits
de87c7a
this is the idea
willow-ahrens Nov 24, 2024
e0da0fe
fixing
willow-ahrens Nov 26, 2024
208cc18
trying out local memory concept
willow-ahrens Dec 13, 2024
7d9217c
add tasks
willow-ahrens Dec 13, 2024
364d5a3
stuff
willow-ahrens Jan 8, 2025
fa71b9b
fixes
willow-ahrens Jan 30, 2025
193064a
some fixes
willow-ahrens Jan 30, 2025
46f5652
sorta broken
willow-ahrens Jan 30, 2025
f5a5c87
rm println
willow-ahrens Jan 30, 2025
5f984f9
fix
willow-ahrens Feb 4, 2025
7c530d5
Merge remote-tracking branch 'origin/wma/formatter' into wma/shard_le…
willow-ahrens Feb 5, 2025
e8da20c
fix
willow-ahrens Feb 5, 2025
c054628
directions
willow-ahrens Feb 5, 2025
a8c2eb1
directions
willow-ahrens Feb 5, 2025
2289266
add a parallel manifesto
willow-ahrens Feb 6, 2025
ebc14cb
Update docs/src/docs/internals/parallel.md
willow-ahrens Feb 6, 2025
afae7e5
Update docs/src/docs/internals/parallel.md
willow-ahrens Feb 6, 2025
39d9356
Update docs/src/docs/internals/parallel.md
willow-ahrens Feb 6, 2025
1b98661
fix
willow-ahrens Feb 6, 2025
231f79a
Merge branch 'wma/shard_levels' of github.com:willow-ahrens/Finch.jl …
willow-ahrens Feb 6, 2025
a509212
no stylebot (awkward
willow-ahrens Feb 6, 2025
311c781
quick fix
willow-ahrens Feb 6, 2025
8617a20
rename
willow-ahrens Feb 6, 2025
553e57e
right direction
willow-ahrens Feb 6, 2025
f701473
rm shards
willow-ahrens Feb 6, 2025
a855754
keep commit
willow-ahrens Feb 6, 2025
2173409
Merge branch 'wma/transfers' into wma/shard_levels
willow-ahrens Feb 6, 2025
6aacac4
fix
willow-ahrens Feb 6, 2025
9 changes: 6 additions & 3 deletions CONTRIBUTING.md
@@ -118,6 +118,9 @@ though both are included as part of the test suite.

## Code Style

We use [Blue Style](https://github.com/JuliaDiff/BlueStyle) formatting, with a few tweaks
defined in `.JuliaFormatter.toml`. Running the tests in overwrite mode will
automatically reformat your code, but you can also add [`JuliaFormatter`](https://domluna.github.io/JuliaFormatter.jl/stable/#Editor-Plugins) to your editor to reformat as you go.
We use [Blue Style](https://github.com/JuliaDiff/BlueStyle) formatting, with a
few tweaks defined in `.JuliaFormatter.toml`. Running the tests in overwrite
mode will automatically reformat your code, but you can also add
[`JuliaFormatter`](https://domluna.github.io/JuliaFormatter.jl/stable/#Editor-Plugins)
to your editor to reformat as you go, or call
`julia -e 'using JuliaFormatter; format("path/to/Finch.jl")'` manually.
62 changes: 62 additions & 0 deletions docs/src/docs/internals/parallel.md
@@ -0,0 +1,62 @@
# Parallel Processing in Finch

## Modelling the Architecture

Finch uses a simple, hierarchical representation of devices and tasks to model
different kinds of parallel processing. An [`AbstractDevice`](@ref) is a physical or
virtual device on which we can execute tasks, which may each be represented by
an [`AbstractTask`](@ref).

```@docs
AbstractTask
AbstractDevice
```
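
For example, the `CPU` device defined in `src/architecture.jl` models a
multithreaded CPU. A minimal sketch of how the pieces fit together, using only
definitions that appear in this PR (`Finch.`-qualified, since these names may
not all be exported):

```julia
using Finch

dev = Finch.CPU(4)              # a CPU device with 4 child tasks
Finch.get_num_tasks(dev)        # returns 4

# `Finch.serial` is the default task, representing serial CPU execution
Finch.get_device(Finch.serial)  # returns CPU(1)
```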

The current task in a compilation context can be queried with
[`get_task`](@ref). Each device has a set of numbered child
tasks, and each task has a parent task.

```@docs
get_num_tasks
get_task_num
get_device
get_parent_task
```
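
As an illustrative sketch of how these queries compose (a `CPUThread` is
constructed by hand here purely for exposition; in normal use tasks are created
by Finch's lowering code):

```julia
using Finch

dev  = Finch.CPU(2)
task = Finch.CPUThread(1, dev, Finch.serial)  # child task 1 of dev, spawned from serial

Finch.get_task_num(task)     # returns 1
Finch.get_parent_task(task)  # returns Serial()
Finch.get_num_tasks(dev)     # returns 2
```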

## Data Movement

Before entering a parallel loop, a tensor may reside on a single task, or
represent a single view of data distributed across multiple tasks, or represent
multiple separate tensors local to multiple tasks. A tensor's data must be
resident in the current task to process operations on that tensor, such as loops
over the indices, accesses to the tensor, or `declare`, `freeze`, or `thaw`.
Upon entering a parallel loop, we must transfer the tensor to the tasks
where it is needed. Upon exiting the parallel loop, we may need to combine
the data from multiple tasks into a single tensor.

There are two cases, depending on whether the tensor is declared outside the
parallel loop or is a temporary tensor declared within the parallel loop.

If the tensor is a temporary tensor declared within the parallel loop, we call
`bcast` to broadcast the tensor to all tasks.

If the tensor is declared outside the parallel loop, we call `scatter` to
send it to the tasks where it is needed. Note that if the tensor is in `read` mode,
`scatter` may simply `bcast` the entire tensor to all tasks. If the device has global
memory, `scatter` may also be a no-op. When the parallel loop is exited, we call
`gather` to reconcile the data from multiple tasks back into a single tensor.

Each of these operations begins with a `_send` variant on one task, and
finishes with a `_recv` variant on the receiving task.

```@docs
bcast
bcast_send
bcast_recv
scatter
scatter_send
scatter_recv
gather
gather_send
gather_recv
```
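
Because the real `bcast`/`scatter`/`gather` methods are defined per level and
per device, the following is only a hypothetical, self-contained illustration
of the broadcast-then-reconcile pattern described above; every name and
signature in it is invented for exposition and is not part of the Finch API:

```julia
# Hypothetical sketch of the scatter/gather lifecycle, not the real interface.
function scatter_gather_demo(data::Vector{Int}, ntasks::Int)
    # "scatter": in read mode this may simply broadcast a copy to every task
    locals = [copy(data) for _ in 1:ntasks]
    # each task works on its own local copy
    for (t, loc) in enumerate(locals)
        loc .+= t
    end
    # "gather": reconcile the task-local results back into a single result
    return reduce((a, b) -> a .+ b, locals)
end

scatter_gather_demo([1, 2, 3], 2)  # returns [5, 7, 9]
```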
2 changes: 2 additions & 0 deletions src/Finch.jl
@@ -41,6 +41,7 @@ export Dense, DenseLevel
export Element, ElementLevel
export AtomicElement, AtomicElementLevel
export Separate, SeparateLevel
export Shard, ShardLevel
export Mutex, MutexLevel
export Pattern, PatternLevel
export Scalar, SparseScalar, ShortCircuitScalar, SparseShortCircuitScalar
@@ -141,6 +142,7 @@ include("tensors/levels/dense_rle_levels.jl")
include("tensors/levels/element_levels.jl")
include("tensors/levels/atomic_element_levels.jl")
include("tensors/levels/separate_levels.jl")
include("tensors/levels/shard_levels.jl")
include("tensors/levels/mutex_levels.jl")
include("tensors/levels/pattern_levels.jl")
include("tensors/masks.jl")
95 changes: 88 additions & 7 deletions src/architecture.jl
@@ -1,8 +1,45 @@
"""
AbstractDevice

A datatype representing a device on which tasks can be executed.
"""
abstract type AbstractDevice end
abstract type AbstractVirtualDevice end

"""
AbstractTask

An individual processing unit on a device, responsible for running code.
"""
abstract type AbstractTask end
abstract type AbstractVirtualTask end

"""
get_num_tasks(dev::AbstractDevice)

Return the number of tasks on the device dev.
"""
function get_num_tasks end
"""
get_task_num(task::AbstractTask)

Return the task number of `task`.
"""
function get_task_num end
"""
get_device(task::AbstractTask)

Return the device that `task` is running on.
"""
function get_device end

"""
get_parent_task(task::AbstractTask)

Return the task which spawned `task`.
"""
function get_parent_task end

"""
aquire_lock!(dev::AbstractDevice, val)

@@ -40,6 +77,7 @@ struct CPU <: AbstractDevice
n::Int
end
CPU() = CPU(Threads.nthreads())
get_num_tasks(dev::CPU) = dev.n
@kwdef struct VirtualCPU <: AbstractVirtualDevice
ex
n
@@ -57,6 +95,7 @@ end
function lower(ctx::AbstractCompiler, device::VirtualCPU, ::DefaultStyle)
something(device.ex, :(CPU($(ctx(device.n)))))
end
get_num_tasks(::VirtualCPU) = literal(1)

FinchNotation.finch_leaf(device::VirtualCPU) = virtual(device)

@@ -68,21 +107,24 @@ A device that represents a serial CPU execution.
struct Serial <: AbstractTask end
const serial = Serial()
get_device(::Serial) = CPU(1)
get_task(::Serial) = nothing
get_parent_task(::Serial) = nothing
get_task_num(::Serial) = 1
struct VirtualSerial <: AbstractVirtualTask end
virtualize(ctx, ex, ::Type{Serial}) = VirtualSerial()
lower(ctx::AbstractCompiler, task::VirtualSerial, ::DefaultStyle) = :(Serial())
FinchNotation.finch_leaf(device::VirtualSerial) = virtual(device)
virtual_get_device(::VirtualSerial) = VirtualCPU(nothing, 1)
virtual_get_task(::VirtualSerial) = nothing
get_device(::VirtualSerial) = VirtualCPU(nothing, 1)
get_parent_task(::VirtualSerial) = nothing
get_task_num(::VirtualSerial) = literal(1)

struct CPUThread{Parent} <: AbstractTask
tid::Int
dev::CPU
parent::Parent
end
get_device(task::CPUThread) = task.device
get_task(task::CPUThread) = task.parent
get_parent_task(task::CPUThread) = task.parent
get_task_num(task::CPUThread) = task.tid

@inline function make_lock(::Type{Threads.Atomic{T}}) where {T}
return Threads.Atomic{T}(zero(T))
@@ -139,8 +181,9 @@ function lower(ctx::AbstractCompiler, task::VirtualCPUThread, ::DefaultStyle)
:(CPUThread($(ctx(task.tid)), $(ctx(task.dev)), $(ctx(task.parent))))
end
FinchNotation.finch_leaf(device::VirtualCPUThread) = virtual(device)
virtual_get_device(task::VirtualCPUThread) = task.dev
virtual_get_task(task::VirtualCPUThread) = task.parent
get_device(task::VirtualCPUThread) = task.dev
get_parent_task(task::VirtualCPUThread) = task.parent
get_task_num(task::VirtualCPUThread) = task.tid

struct CPULocalMemory
device::CPU
@@ -174,7 +217,20 @@ function moveto(vec::CPULocalVector, task::CPUThread)
return temp
end

struct Converter{f,T} end
"""
local_memory(device)

Returns the local memory type for a given device.
"""
function local_memory(device::CPU)
return CPULocalMemory(device)
end

function local_memory(device::Serial)
return device
end

struct Converter{f, T} end

(::Converter{f,T})(x) where {f,T} = T(f(x))

@@ -238,4 +294,29 @@ for T in [
pointer(vec, idx), $T(Vf), x, UnsafeAtomics.seq_cst, UnsafeAtomics.seq_cst
)
end

end

function virtual_parallel_region(f, ctx, ::Serial)
contain(f, ctx)
end

function virtual_parallel_region(f, ctx, device::VirtualCPU)
tid = freshen(ctx, :tid)

code = contain(ctx) do ctx_2
subtask = VirtualCPUThread(value(tid, Int), device, ctx_2.code.task)
contain(f, ctx_2, task=subtask)

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
contain(f, ctx_2, task=subtask)
contain(f, ctx_2; task=subtask)

end

return quote
Threads.@threads for $tid = 1:$(ctx(device.n))

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
Threads.@threads for $tid = 1:$(ctx(device.n))
Threads.@threads for $tid in 1:($(ctx(device.n)))

Finch.@barrier begin
@inbounds @fastmath begin
$code
end
nothing
end
end
end
end
62 changes: 21 additions & 41 deletions src/lower.jl
@@ -319,19 +319,11 @@ end
function lower_parallel_loop(ctx, root, ext::ParallelDimension, device::VirtualCPU)
root = ensure_concurrent(root, ctx)

tid = index(freshen(ctx, :tid))
i = freshen(ctx, :i)

decl_in_scope = unique(
filter(
!isnothing,
map(node -> begin
if @capture(node, declare(~tns, ~init, ~op))
tns
end
end, PostOrderDFS(root.body)),
),
)
decl_in_scope = unique(filter(!isnothing, map(node-> begin
if @capture(node, declare(~tns, ~init, ~op))
tns
end
end, PostOrderDFS(root.body))))

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
decl_in_scope = unique(filter(!isnothing, map(node-> begin
if @capture(node, declare(~tns, ~init, ~op))
tns
end
end, PostOrderDFS(root.body))))
decl_in_scope = unique(
filter(
!isnothing,
map(node -> begin
if @capture(node, declare(~tns, ~init, ~op))
tns
end
end, PostOrderDFS(root.body)),
),
)


used_in_scope = unique(
filter(
@@ -344,39 +336,27 @@
),
)

root_2 = loop(tid, Extent(value(i, Int), value(i, Int)),
loop(root.idx, ext.ext,
sieve(access(VirtualSplitMask(device.n), reader(), root.idx, tid),
root.body,
),
),
)

for tns in setdiff(used_in_scope, decl_in_scope)
virtual_moveto(ctx, resolve(ctx, tns), device)
end

code = contain(ctx) do ctx_2
subtask = VirtualCPUThread(value(i, Int), device, ctx_2.code.task)
contain(ctx_2; task=subtask) do ctx_3
for tns in intersect(used_in_scope, decl_in_scope)
virtual_moveto(ctx_3, resolve(ctx_3, tns), subtask)
end
contain(ctx_3) do ctx_4
open_scope(ctx_4) do ctx_5
ctx_5(instantiate!(ctx_5, root_2))
end
end
virtual_parallel_region(ctx, device) do ctx_2
subtask = get_task(ctx_2)
tid = get_task_num(subtask)
for tns in intersect(used_in_scope, decl_in_scope)
virtual_moveto(ctx_2, resolve(ctx_2, tns), subtask)
end
end

return quote
Threads.@threads for $i in 1:($(ctx(device.n)))
Finch.@barrier begin
@inbounds @fastmath begin
$code
end
nothing
contain(ctx_2) do ctx_3
open_scope(ctx_3) do ctx_4
i = index(freshen(ctx, :i))
root_2 = loop(i, Extent(tid, tid),
loop(root.idx, ext.ext,
sieve(access(VirtualSplitMask(device.n), reader(), root.idx, i),
root.body
)
)

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
root.body
)
)
root.body,
),
),

)
ctx_4(instantiate!(ctx_4, root_2))
end
end
end
4 changes: 2 additions & 2 deletions src/tensors/levels/atomic_element_levels.jl
@@ -205,7 +205,7 @@ function lower_assign(ctx, fbr::VirtualSubFiber{VirtualAtomicElementLevel}, mode
(lvl, pos) = (fbr.lvl, fbr.pos)
op = ctx(op)
rhs = ctx(rhs)
device = ctx(virtual_get_device(get_task(ctx)))
device = ctx(get_device(get_task(ctx)))
:(Finch.atomic_modify!($device, $(lvl.val), $(ctx(pos)), $op, $rhs))
end

@@ -221,6 +221,6 @@ function lower_assign(
)
op = ctx(op)
rhs = ctx(rhs)
device = ctx(virtual_get_device(get_task(ctx)))
device = ctx(get_device(get_task(ctx)))
:(Finch.atomic_modify!($device, $(lvl.val), $(ctx(pos)), $op, $rhs))
end