Skip to content

Commit

Permalink
Rework replace and replace! (#26206)
Browse files Browse the repository at this point in the history
Introduce a new _replace!(new::Callable, res::T, A::T, count::Int) method
which custom types can implement to support all replace and replace! methods automatically,
instead of the current replace!(new::Callable, A::T, count::Int). This offers several advantages:
- For arrays, instead of copying the input and then replacing elements, we can do the copy and replace
operations at the same time, which is considerably faster for arrays when count=nothing.
- For dicts and sets, copying up-front is still faster as long as most original elements are preserved,
but for replace(), we can apply replacements directly instead of storing them in a temporary vector.
- When the LHS of a pair contains a singleton type, we can subtract it from the element type
of the result, e.g. Union{T,Missing} becomes T.

Also simplify the dispatch logic by introducing the replace_pairs! function.
  • Loading branch information
nalimilan authored Apr 13, 2018
1 parent cb49bef commit 85c341c
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 40 deletions.
143 changes: 103 additions & 40 deletions base/set.jl
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,11 @@ convert(::Type{T}, s::AbstractSet) where {T<:AbstractSet} = T(s)

## replace/replace! ##

# Validate the `count` keyword of `replace`/`replace!` and normalize it to an `Int`.
# Negative values raise a `DomainError`; values above `typemax(Int)` are clamped to it.
function check_count(count::Integer)
    if count < 0
        throw(DomainError(count, "`count` must not be negative (got $count)"))
    end
    capped = min(count, typemax(Int))
    return capped % Int
end

# TODO: use copy!, which is currently unavailable from here since it is defined in Future
function _copy_oftype(x, ::Type{T}) where {T}
    # Build a container like `x` with element type `T`, then fill it from `x`.
    dest = similar(x, T)
    return copyto!(dest, x)
end
# TODO: use similar() once deprecation is removed and it preserves keys
Expand All @@ -310,8 +315,14 @@ _copy_oftype(x::AbstractArray{T}, ::Type{T}) where {T} = copy(x)
# When the requested element type is already the container's own element type,
# a plain copy is all that is needed.
function _copy_oftype(x::AbstractDict{K,V}, ::Type{Pair{K,V}}) where {K,V}
    return copy(x)
end
function _copy_oftype(x::AbstractSet{T}, ::Type{T}) where {T}
    return copy(x)
end

# Create the destination container for a non-mutating `replace`.
# Generic fallback: a container obtained from `similar`.
function _similar_or_copy(x::Any)
    return similar(x)
end
function _similar_or_copy(x::Any, ::Type{T}) where {T}
    return similar(x, T)
end
# Dicts and sets start from a full copy of `x` instead, since copying on
# construction is faster than inserting the elements one at a time.
function _similar_or_copy(x::Union{AbstractDict,AbstractSet})
    return copy(x)
end
function _similar_or_copy(x::Union{AbstractDict,AbstractSet}, ::Type{T}) where {T}
    return _copy_oftype(x, T)
end

# to make replace/replace! work for a new container type Cont, only
# replace!(new::Callable, A::Cont; count::Integer=typemax(Int))
# _replace!(new::Callable, res::Cont, A::Cont, count::Int)
# has to be implemented

"""
Expand All @@ -336,16 +347,17 @@ julia> replace!(Set([1, 2, 3]), 1=>0)
Set([0, 2, 3])
```
"""
replace!(A, old_new::Pair...; count::Integer=typemax(Int)) = _replace!(A, count, old_new) # NOTE(review): presumably the pre-change definition shown by the diff — confirm
replace!(A, old_new::Pair...; count::Integer=typemax(Int)) =
replace_pairs!(A, A, check_count(count), old_new) # `A` is passed as both result and source; `count` goes through `check_count` first

# Pair-based front end: wraps the `old => new` pairs in `old_new` into a single
# function and hands it to the callable-based implementation.
# NOTE(review): two headers and two final calls are shown here; the
# `_replace!(A, count, old_new)` header and the `replace!(new, A, count=count)` call
# look like the pre-change lines of the diff — confirm.
function _replace!(A, count::Integer, old_new::Tuple{Vararg{Pair}})
function replace_pairs!(res, A, count::Int, old_new::Tuple{Vararg{Pair}})
@inline function new(x)
# The first pair whose `first` is `isequal` to `x` decides the replacement.
for o_n in old_new
isequal(first(o_n), x) && return last(o_n)
end
return x # no replace
end
replace!(new, A, count=count)
_replace!(new, res, A, count)
end

"""
Expand All @@ -367,7 +379,7 @@ julia> replace!(isodd, A, 0, count=2)
```
"""
replace!(pred::Callable, A, new; count::Integer=typemax(Int)) =
replace!(x -> ifelse(pred(x), new, x), A, count=count) # NOTE(review): presumably the pre-change body shown by the diff — confirm
replace!(x -> ifelse(pred(x), new, x), A, count=check_count(count)) # turns `pred` into a function yielding `new` where `pred(x)` holds and `x` otherwise

"""
replace!(new::Function, A; [count::Integer])
Expand Down Expand Up @@ -396,12 +408,8 @@ julia> replace!(x->2x, Set([3, 6]))
Set([6, 12])
```
"""
function replace!(new::Callable, A::Union{AbstractArray,AbstractDict,AbstractSet}; # NOTE(review): this `function ... end` form looks like the pre-change method shown by the diff — confirm
count::Integer=typemax(Int))
count < 0 && throw(DomainError(count, "`count` must not be negative"))
# `count == 0` skips the call entirely; larger values are clamped to `typemax(Int)`.
count != 0 && _replace!(new, A, min(count, typemax(Int)) % Int)
A
end
# One-line form: validates `count` with `check_count` and passes `A` as both the
# result and the source container.
replace!(new::Callable, A; count::Integer=typemax(Int)) =
_replace!(new, A, A, check_count(count))

"""
replace(A, old_new::Pair...; [count::Integer])
Expand All @@ -410,6 +418,14 @@ Return a copy of collection `A` where, for each pair `old=>new` in `old_new`,
all occurrences of `old` are replaced by `new`.
Equality is determined using [`isequal`](@ref).
If `count` is specified, then replace at most `count` occurrences in total.
The element type of the result is chosen using promotion (see [`promote_type`](@ref))
based on the element type of `A` and on the types of the `new` values in pairs.
If `count` is omitted and the element type of `A` is a `Union`, the element type
of the result will not include singleton types which are replaced with values of
a different type: for example, `Union{T,Missing}` will become `T` if `missing` is
replaced.
See also [`replace!`](@ref).
# Examples
Expand All @@ -420,18 +436,41 @@ julia> replace([1, 2, 1, 3], 1=>0, 2=>4, count=2)
4
1
3
julia> replace([1, missing], missing=>0)
2-element Array{Int64,1}:
1
0
```
"""
function replace(A, old_new::Pair...; count::Integer=typemax(Int)) # NOTE(review): presumably the pre-change header shown by the diff — confirm
function replace(A, old_new::Pair...; count::Union{Integer,Nothing}=nothing)
V = promote_valuetype(old_new...) # promoted type of the replacement values of all pairs
T = promote_type(eltype(A), V) # NOTE(review): this line and the next look like pre-change lines — confirm
_replace!(_copy_oftype(A, T), count, old_new)
if count isa Nothing
# No limit given: singleton key types are subtracted from `eltype(A)` before
# promoting with `V`, and `typemax(Int)` is used as the count.
T = promote_type(subtract_singletontype(eltype(A), old_new...), V)
replace_pairs!(_similar_or_copy(A, T), A, typemax(Int), old_new)
else
# Explicit limit: the full `eltype(A)` takes part in the promotion.
U = promote_type(eltype(A), V)
replace_pairs!(_similar_or_copy(A, U), A, check_count(count), old_new)
end
end

# Promoted type of the value (right-hand) types of one or more pairs.
function promote_valuetype(x::Pair{K, V}) where {K, V}
    return V
end
function promote_valuetype(x::Pair{K, V}, y::Pair...) where {K, V}
    tail_type = promote_valuetype(y...)
    return promote_type(V, tail_type)
end

# Subtract singleton types which are going to be replaced
# A `DataType` is treated as a singleton type when its `instance` field is defined.
@pure function issingletontype(T::DataType)
    return isdefined(T, :instance)
end
# Any other kind of type is reported as not being a singleton type.
function issingletontype(::Type)
    return false
end
# Remove from `T` the key type of each pair whose key type is a singleton type;
# pairs with any other key type leave `T` unchanged.
function subtract_singletontype(::Type{T}, x::Pair{K}) where {T, K}
    return issingletontype(K) ? Core.Compiler.typesubtract(T, K) : T
end
function subtract_singletontype(::Type{T}, x::Pair{K}, y::Pair...) where {T, K}
    # Process the trailing pairs first, then apply the leading pair to the result.
    remaining = subtract_singletontype(T, y...)
    return subtract_singletontype(remaining, x)
end

"""
replace(pred::Function, A, new; [count::Integer])
Expand All @@ -451,7 +490,7 @@ julia> replace(isodd, [1, 2, 3, 1], 0, count=2)
"""
function replace(pred::Callable, A, new; count::Integer=typemax(Int))
T = promote_type(eltype(A), typeof(new)) # result element type must be able to hold `new`
replace!(pred, _copy_oftype(A, T), new, count=count) # NOTE(review): presumably the pre-change line shown by the diff — confirm
# Reads from `A`, writes into a fresh container with element type `T`; `pred` is
# wrapped so that matching elements map to `new` and all others to themselves.
_replace!(x -> ifelse(pred(x), new, x), _similar_or_copy(A, T), A, check_count(count))
end

"""
Expand All @@ -478,7 +517,8 @@ Dict{Int64,Int64} with 2 entries:
1 => 3
```
"""
replace(new::Callable, A; count::Integer=typemax(Int)) = replace!(new, copy(A), count=count) # NOTE(review): presumably the pre-change definition shown by the diff — confirm
replace(new::Callable, A; count::Integer=typemax(Int)) =
_replace!(new, _similar_or_copy(A), A, check_count(count)) # reads from `A`, writes into the container returned by `_similar_or_copy(A)`

# Handle ambiguities
# A callable followed by a single pair is not a supported combination;
# report it as a missing method.
function replace!(a::Callable, b::Pair; count::Integer=-1)
    throw(MethodError(replace!, (a, b)))
end
Expand All @@ -487,42 +527,65 @@ replace(a::Callable, b::Pair; count::Integer=-1) = throw(MethodError(replace, (a
# A callable or a string followed by two pairs is not a supported combination;
# report it as a missing method.
function replace(a::Callable, b::Pair, c::Pair; count::Integer=-1)
    throw(MethodError(replace, (a, b, c)))
end
function replace(a::AbstractString, b::Pair, c::Pair)
    throw(MethodError(replace, (a, b, c)))
end


### replace! for AbstractDict/AbstractSet

# Key under which element `k` is looked up in a container:
# for a dict, the `first` field of the entry; for a set, the element itself.
function askey(k, ::AbstractDict)
    return k.first
end
function askey(k, ::AbstractSet)
    return k
end

# Replacement kernels for dictionaries and sets, used by `replace` and `replace!`.
#
# `new` is applied to the elements of `A`; an element `x` counts as replaced when
# `new(x) !== x`. Changed elements are written into `res`, at most `count` of them,
# and `res` is returned. `res` is either `A` itself (in-place replacement) or a
# separate container that is expected to already hold the elements of `A`, since
# each replaced element is popped from it.
#
# `res` and `A` get one method per container kind rather than one shared type
# parameter: a parameter used for both arguments only matches when their concrete
# types are identical, which fails as soon as the result container has a different
# element type than `A`.
_replace!(new::Callable, res::AbstractDict, A::AbstractDict, count::Int) =
    _replace_dictset!(new, res, A, count)
_replace!(new::Callable, res::AbstractSet, A::AbstractSet, count::Int) =
    _replace_dictset!(new, res, A, count)

# In-place form taking the container once, kept for callers of the older
# three-argument signature.
_replace!(new::Callable, A::Union{AbstractDict,AbstractSet}, count::Int) =
    _replace!(new, A, A, count)

# Loop shared by the dictionary and set methods above.
function _replace_dictset!(new, res, A, count::Int)
    # The loops below stop only when the running total equals `count` right after
    # an element has been examined. With `count == 0` that can no longer happen
    # once a first replacement has been made, so every element would be replaced.
    count == 0 && return res
    c = 0
    if res === A # cannot replace elements while iterating over A
        # Collect the replacements first, then apply them: all removals, then
        # all insertions.
        repl = Pair{eltype(A),eltype(A)}[]
        for x in A
            y = new(x)
            if x !== y
                push!(repl, x => y)
                c += 1
            end
            c == count && break
        end
        for oldnew in repl
            pop!(res, askey(first(oldnew), res))
        end
        for oldnew in repl
            push!(res, last(oldnew))
        end
    else
        # `res` is a different object, so it can be updated while `A` is iterated.
        for x in A
            y = new(x)
            if x !== y
                pop!(res, askey(x, res))
                push!(res, y)
                c += 1
            end
            c == count && break
        end
    end
    return res
end

### AbstractArray
### replace! for AbstractArray

# Array kernel: applies `new` to the elements of `A` and writes the outcome to the
# same index of the destination, counting an element as replaced when the value
# returned by `new` is not `===` to the original.
# NOTE(review): two headers are shown here; the three-argument header and the first
# loop (which writes to `A[i]` and has no closing `end`s), together with the later
# `c == count && break`, look like the pre-change lines of the diff — confirm.
function _replace!(new::Callable, A::AbstractArray, count::Int)
function _replace!(new::Callable, res::AbstractArray, A::AbstractArray, count::Int)
c = 0
for i in eachindex(A)
@inbounds Ai = A[i]
y = new(Ai)
if Ai !== y
@inbounds A[i] = y
c += 1
if count >= length(A) # simpler loop allows for SIMD
# The limit cannot be reached, so no counting is needed: map every element.
for i in eachindex(A)
@inbounds Ai = A[i]
y = new(Ai)
@inbounds res[i] = y
end
else
for i in eachindex(A)
@inbounds Ai = A[i]
if c < count
# Still below the limit: apply `new` and count the element only if it changed.
y = new(Ai)
@inbounds res[i] = y
c += (Ai !== y)
else
# Limit reached: remaining elements are copied over unchanged.
@inbounds res[i] = Ai
end
end
c == count && break
end
res
end
9 changes: 9 additions & 0 deletions test/sets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,15 @@ end
x = @inferred replace(x -> x > 1, [1, 2], missing)
@test isequal(x, [1, missing]) && x isa Vector{Union{Int, Missing}}

x = @inferred replace([1, missing], missing=>2)
@test x == [1, 2] && x isa Vector{Int}
x = @inferred replace([1, missing], missing=>2, count=1)
@test x == [1, 2] && x isa Vector{Union{Int, Missing}}
x = @inferred replace([1, missing], missing=>missing)
@test isequal(x, [1, missing]) && x isa Vector{Union{Int, Missing}}
x = @inferred replace([1, missing], missing=>2, 1=>missing)
@test isequal(x, [missing, 2]) && x isa Vector{Union{Int, Missing}}

# test that isequal is used
@test replace([NaN, 1.0], NaN=>0.0) == [0.0, 1.0]
@test replace([1, missing], missing=>0) == [1, 0]
Expand Down

0 comments on commit 85c341c

Please sign in to comment.