From f69b327b95856083d8a14da372c76740b2f18b31 Mon Sep 17 00:00:00 2001
From: Sam O'Connor
Date: Wed, 23 Mar 2016 10:22:48 +1100
Subject: [PATCH] [ci skip] WIP: Refactor pmap

New functions:
 - head_and_tail -- like take and rest but atomic
 - batchsplit -- like split, but aware of nworkers
 - generate -- shorthand for creating a Generator
 - asyncgenerate -- generate using tasks
 - asyncmap -- map using tasks
 - pgenerate -- generate using tasks and workers.

Reimplement pmap: pmap(f, c...) = collect(pgenerate(f, c...))
---
 base/generator.jl   |   2 +
 base/iterator.jl    |  19 ++++++
 base/mapiterator.jl |  23 ++++++-
 base/multi.jl       | 146 +++++++++++++++++++++-----------------------
 4 files changed, 112 insertions(+), 78 deletions(-)

diff --git a/base/generator.jl b/base/generator.jl
index dff08ea9d25422..c0c6f59ccfc164 100644
--- a/base/generator.jl
+++ b/base/generator.jl
@@ -20,4 +20,6 @@ function next(g::Generator, s)
     g.f(v), s2
 end
 
+generate(f, c...) = Generator(f, c...)
+
 collect(g::Generator) = map(g.f, g.iter)
diff --git a/base/iterator.jl b/base/iterator.jl
index d7080ac7c1c8da..b321f9bad740c2 100644
--- a/base/iterator.jl
+++ b/base/iterator.jl
@@ -123,6 +123,25 @@ done(i::Rest, st) = done(i.itr, st)
 
 eltype{I}(::Type{Rest{I}}) = eltype(I)
 
+
+"""
+    head_and_tail(c, n) -> head, tail
+
+Returns `head`: the first `n` elements of `c`;
+and `tail`: an iterator over the remaining elements.
+"""
+function head_and_tail(c, n)
+    head = Vector{eltype(c)}(n)
+    s = start(c)
+    i = 0
+    while i < n && !done(c, s)
+        i += 1
+        head[i], s = next(c, s)
+    end
+    return resize!(head, i), rest(c, s)
+end
+
+
 # Count -- infinite counting
 
 immutable Count{S<:Number}
diff --git a/base/mapiterator.jl b/base/mapiterator.jl
index 75bde41ed6c85f..3b4f2fc4d79c61 100644
--- a/base/mapiterator.jl
+++ b/base/mapiterator.jl
@@ -90,7 +90,7 @@ end
 Apply f to each element of c using at most 100 asynchronous tasks.
 For multiple collection arguments, apply f elementwise.
 
-The iterator returns results as the become available.
+Results are returned by the iterator as they become available.
 
 Note: `collect(StreamMapIterator(f, c...; ntasks=1))` is equivalent to `map(f, c...)`.
 """
@@ -144,3 +144,24 @@ function next(itr::StreamMapIterator, state::StreamMapState)
 
     return (r, state)
 end
+
+
+
+"""
+    asyncgenerate(f, c...) -> iterator
+
+Apply `@async f` to each element of `c`.
+For multiple collection arguments, apply f elementwise.
+Results are returned by the iterator as they become available.
+"""
+asyncgenerate(f, c...) = StreamMapIterator(f, c...)
+
+
+
+"""
+    asyncmap(f, c...) -> collection
+
+Transform collection `c` by applying `@async f` to each element.
+For multiple collection arguments, apply f elementwise.
+"""
+asyncmap(f, c...) = collect(asyncgenerate(f, c...))
diff --git a/base/multi.jl b/base/multi.jl
index 11b01c50e916b1..18aa4b332a6126 100644
--- a/base/multi.jl
+++ b/base/multi.jl
@@ -1536,93 +1536,85 @@ end
 
 pmap(f) = f()
 
-# dynamic scheduling by creating a local task to feed work to each processor
-# as it finishes.
-# example unbalanced workload:
-# rsym(n) = (a=rand(n,n);a*a')
-# L = {rsym(200),rsym(1000),rsym(200),rsym(1000),rsym(200),rsym(1000),rsym(200),rsym(1000)};
-# pmap(eig, L);
-function pmap(f, lsts...; err_retry=true, err_stop=false, pids = workers())
-    len = length(lsts)
-
-    results = Dict{Int,Any}()
-
-    busy_workers = fill(false, length(pids))
-    busy_workers_ntfy = Condition()
-
-    retryqueue = []
-    task_in_err = false
-    is_task_in_error() = task_in_err
-    set_task_in_error() = (task_in_err = true)
-
-    nextidx = 0
-    getnextidx() = (nextidx += 1)
-
-    states = [start(lsts[idx]) for idx in 1:len]
-    function getnext_tasklet()
-        if is_task_in_error() && err_stop
-            return nothing
-        elseif !any(idx->done(lsts[idx],states[idx]), 1:len)
-            nxts = [next(lsts[idx],states[idx]) for idx in 1:len]
-            for idx in 1:len; states[idx] = nxts[idx][2]; end
-            nxtvals = [x[1] for x in nxts]
-            return (getnextidx(), nxtvals)
-        elseif !isempty(retryqueue)
-            return shift!(retryqueue)
-        elseif err_retry
-            # Handles the condition where we have finished processing the requested lsts as well
-            # as any retryqueue entries, but there are still some jobs active that may result
-            # in an error and have to be retried.
-            while any(busy_workers)
-                wait(busy_workers_ntfy)
-                if !isempty(retryqueue)
-                    return shift!(retryqueue)
-                end
-            end
-            return nothing
-        else
-            return nothing
-        end
+
+"""
+    pgenerate(f, c...) -> iterator
+
+
+Apply `f` to each element of `c` in parallel using available workers and tasks.
+For multiple collection arguments, apply f elementwise.
+Results are returned by the iterator as they become available.
+"""
+function pgenerate(f, c)
+    if nworkers() == 1
+        return asyncgenerate(f, c)
+    end
+    return flatten(asyncgenerate(remote(b -> asyncmap(f, b)), batchsplit(c)))
+end
 
-    @sync begin
-        for (pididx, wpid) in enumerate(pids)
-            @async begin
-                tasklet = getnext_tasklet()
-                while (tasklet !== nothing)
-                    (idx, fvals) = tasklet
-                    busy_workers[pididx] = true
-                    try
-                        results[idx] = remotecall_fetch(f, wpid, fvals...)
-                    catch ex
-                        if err_retry
-                            push!(retryqueue, (idx,fvals, ex))
-                        else
-                            results[idx] = ex
-                        end
-                        set_task_in_error()
-
-                        busy_workers[pididx] = false
-                        notify(busy_workers_ntfy; all=true)
-
-                        break # remove this worker from accepting any more tasks
-                    end
+pgenerate(f, c1, c...) = pgenerate(a->f(a...), zip(c1, c...))
 
-                    busy_workers[pididx] = false
-                    notify(busy_workers_ntfy; all=true)
-                    tasklet = getnext_tasklet()
-                end
-            end
+function pmap(f, c...; err_retry=nothing, err_stop=nothing, pids=nothing)
+
+    if err_retry != nothing
+        depwarn("`err_retry` is deprecated, use `pmap(retry(f), c...)`.", :pmap)
+        if err_retry == true
+            f = retry(f)
+        end
+    end
+
+    if err_stop != nothing
+        depwarn("`err_stop` is deprecated, use `pmap(@catch(f), c...)`.", :pmap)
+        if err_stop == false
+            f = @catch(f)
         end
     end
-    for failure in retryqueue
-        results[failure[1]] = failure[3]
+    if pids != nothing
+        depwarn("`pids` is deprecated. It no longer has any effect.", :pmap)
     end
-
-    [results[x] for x in 1:nextidx]
+
+    return collect(pgenerate(f, c...))
 end
+
+
+"""
+    batchsplit(c; min_batch_count=0, max_batch_size=100)
+
+Split a collection into at least `min_batch_count` batches.
+If `min_batch_count` is not specified `batchsplit` attempts
+to produce enough batches to allow work to be distributed amongst
+available workers.
+
+Equivalent to `split(c, batch_size)` when `length(c) >> max_batch_size`.
+"""
+function batchsplit(c; min_batch_count=0, max_batch_size=100)
+
+    if min_batch_count == 0
+        min_batch_count = nworkers() * 3
+    end
+
+    @assert min_batch_count > 0
+    @assert max_batch_size > 1
+
+    # Split collection into batches, then peek at the first few batches...
+    batches = split(c, max_batch_size)
+    head, tail = head_and_tail(batches, min_batch_count)
+
+    # If there are not enough batches, use a smaller batch size...
+    if length(head) < min_batch_count
+        head = vcat(head...)
+        batch_size = max(1, div(length(head), min_batch_count))
+        return split(head, batch_size)
+    end
+
+    return flatten((head, tail))
+end
+
+
 # Statically split range [1,N] into equal sized chunks for np processors
 function splitrange(N::Int, np::Int)
     each = div(N,np)