Skip to content

Commit

Permalink
fixes, but still broken
Browse files Browse the repository at this point in the history
  • Loading branch information
willow-ahrens committed Feb 6, 2025
1 parent 9033ce7 commit 1ff24a4
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 12 deletions.
8 changes: 4 additions & 4 deletions src/architecture.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ get_task_num(task::VirtualCPUThread) = task.tid

struct CPULocalArray{A}
device::CPU
data::arrtor{A}
data::Vector{A}
end

function CPULocalArray{A}(device::CPU) where {A}
Expand All @@ -227,13 +227,13 @@ end
Base.eltype(::Type{CPULocalArray{A}}) where {A} = eltype(A)
Base.ndims(::Type{CPULocalArray{A}}) where {A} = ndims(A)

function transfer(arr::A, device::CPU, style::BroadcastSend) where {A<:AbstractArray}
function transfer(arr::A, device::CPU, style::BcastSend) where {A<:AbstractArray}
CPULocalArray{A}(mem.device, [copy(arr) for _ in 1:(mem.device.n)])
end
function transfer(arr::CPULocalArray, device::CPU, style::BroadcastSend)
function transfer(arr::CPULocalArray, device::CPU, style::BcastSend)
return arr
end
function transfer(arr::CPULocalArray, task::CPUThread, style::BroadcastRecv)
function transfer(arr::CPULocalArray, task::CPUThread, style::BcastRecv)
if get_device(task) === arr.device
temp = arr.data[task.tid]
return temp
Expand Down
10 changes: 5 additions & 5 deletions src/lower.jl
Original file line number Diff line number Diff line change
Expand Up @@ -342,10 +342,10 @@ function lower_parallel_loop(ctx, root, ext::ParallelDimension, device::VirtualC
)

for tns in setdiff(used_in_scope, decl_in_scope)
set_binding(ctx, tns, virtual_transfer(ctx, resolve(ctx, tns), device, bcast_send))
set_binding!(ctx, tns, virtual_transfer(ctx, resolve(ctx, tns), device, bcast_send))
end
for tns in intersect(used_in_scope, decl_in_scope)
set_binding(
set_binding!(
ctx, tns, virtual_transfer(ctx, resolve(ctx, tns), device, scatter_send)
)
end
Expand All @@ -354,14 +354,14 @@ function lower_parallel_loop(ctx, root, ext::ParallelDimension, device::VirtualC
subtask = get_task(ctx_2)
tid = get_task_num(subtask)
for tns in setdiff(used_in_scope, decl_in_scope)
set_binding(
set_binding!(
ctx_2,
tns,
virtual_transfer(ctx_2, resolve(ctx_2, tns), subtask, bcast_recv),
)
end
for tns in intersect(used_in_scope, decl_in_scope)
set_binding(
set_binding!(
ctx_2,
tns,
virtual_transfer(ctx_2, resolve(ctx_2, tns), subtask, scatter_recv),
Expand All @@ -381,7 +381,7 @@ function lower_parallel_loop(ctx, root, ext::ParallelDimension, device::VirtualC
end
end
for tns in intersect(used_in_scope, decl_in_scope)
set_binding(
set_binding!(
ctx_2,
tns,
virtual_transfer(ctx_2, resolve(ctx_2, tns), subtask, gather_send),
Expand Down
4 changes: 2 additions & 2 deletions src/tensors/fibers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,11 @@ function unfurl(ctx::AbstractCompiler, arr::VirtualFiber, ext, mode, proto)
end

function virtual_transfer(ctx::AbstractCompiler, fbr::VirtualFiber, arch, style)
virtual_transfer_level(ctx, fbr.lvl, arch, style)
VirtualFiber(virtual_transfer_level(ctx, fbr.lvl, arch, style))
end

function virtual_transfer(ctx::AbstractCompiler, fbr::VirtualSubFiber, arch, style)
virtual_transfer_level(ctx, fbr.lvl, arch, style)
VirtualSubFiber(virtual_transfer_level(ctx, fbr.lvl, arch, style), fbr.pos)
end

struct HollowSubFiber{Lvl,Pos,Dirty} <: AbstractFiber{Lvl}
Expand Down
2 changes: 1 addition & 1 deletion src/tensors/levels/sparse_interval_levels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ function virtual_transfer_level(ctx, lvl::VirtualSparseIntervalLevel, arch, styl
lvl_2 = virtual_transfer_level(ctx, lvl.lvl, arch, style)
VirtualSparseIntervalLevel(
lvl_2, lvl.ex, lvl.Ti, left_2, right_2, lvl.shape, lvl.qos_fill, lvl.qos_stop,
lvl.prev_pos
lvl.prev_pos,
)
end

Expand Down

0 comments on commit 1ff24a4

Please sign in to comment.