diff --git a/base/set.jl b/base/set.jl
index 41238f14dd4a4..a1686c7c46961 100644
--- a/base/set.jl
+++ b/base/set.jl
@@ -300,6 +300,11 @@ convert(::Type{T}, s::AbstractSet) where {T<:AbstractSet} = T(s)
 
 ## replace/replace! ##
 
+function check_count(count::Integer)
+    count < 0 && throw(DomainError(count, "`count` must not be negative (got $count)"))
+    return min(count, typemax(Int)) % Int
+end
+
 # TODO: use copy!, which is currently unavailable from here since it is defined in Future
 _copy_oftype(x, ::Type{T}) where {T} = copyto!(similar(x, T), x)
 # TODO: use similar() once deprecation is removed and it preserves keys
@@ -310,8 +315,14 @@ _copy_oftype(x::AbstractArray{T}, ::Type{T}) where {T} = copy(x)
 _copy_oftype(x::AbstractDict{K,V}, ::Type{Pair{K,V}}) where {K,V} = copy(x)
 _copy_oftype(x::AbstractSet{T}, ::Type{T}) where {T} = copy(x)
 
+_similar_or_copy(x::Any) = similar(x)
+_similar_or_copy(x::Any, ::Type{T}) where {T} = similar(x, T)
+# Make a copy on construction since it is faster than inserting elements separately
+_similar_or_copy(x::Union{AbstractDict,AbstractSet}) = copy(x)
+_similar_or_copy(x::Union{AbstractDict,AbstractSet}, ::Type{T}) where {T} = _copy_oftype(x, T)
+
 # to make replace/replace! work for a new container type Cont, only
-# replace!(new::Callable, A::Cont; count::Integer=typemax(Int))
+# _replace!(new::Callable, res::Cont, A::Cont, count::Int)
 # has to be implemented
 
 """
@@ -336,16 +347,17 @@ julia> replace!(Set([1, 2, 3]), 1=>0)
 Set([0, 2, 3])
 ```
 """
-replace!(A, old_new::Pair...; count::Integer=typemax(Int)) = _replace!(A, count, old_new)
+replace!(A, old_new::Pair...; count::Integer=typemax(Int)) =
+    replace_pairs!(A, A, check_count(count), old_new)
 
-function _replace!(A, count::Integer, old_new::Tuple{Vararg{Pair}})
+function replace_pairs!(res, A, count::Int, old_new::Tuple{Vararg{Pair}})
     @inline function new(x)
         for o_n in old_new
             isequal(first(o_n), x) && return last(o_n)
         end
         return x # no replace
     end
-    replace!(new, A, count=count)
+    _replace!(new, res, A, count)
 end
 
 """
@@ -367,7 +379,7 @@ julia> replace!(isodd, A, 0, count=2)
 ```
 """
 replace!(pred::Callable, A, new; count::Integer=typemax(Int)) =
-    replace!(x -> ifelse(pred(x), new, x), A, count=count)
+    replace!(x -> ifelse(pred(x), new, x), A, count=check_count(count))
 
 """
     replace!(new::Function, A; [count::Integer])
@@ -396,12 +408,8 @@ julia> replace!(x->2x, Set([3, 6]))
 Set([6, 12])
 ```
 """
-function replace!(new::Callable, A::Union{AbstractArray,AbstractDict,AbstractSet};
-                  count::Integer=typemax(Int))
-    count < 0 && throw(DomainError(count, "`count` must not be negative"))
-    count != 0 && _replace!(new, A, min(count, typemax(Int)) % Int)
-    A
-end
+replace!(new::Callable, A; count::Integer=typemax(Int)) =
+    _replace!(new, A, A, check_count(count))
 
 """
     replace(A, old_new::Pair...; [count::Integer])
@@ -410,6 +418,14 @@ Return a copy of collection `A` where, for each pair `old=>new` in `old_new`,
 all occurrences of `old` are replaced by `new`.
 Equality is determined using [`isequal`](@ref).
 If `count` is specified, then replace at most `count` occurrences in total.
+
+The element type of the result is chosen using promotion (see [`promote_type`](@ref))
+based on the element type of `A` and on the types of the `new` values in pairs.
+If `count` is omitted and the element type of `A` is a `Union`, the element type
+of the result will not include singleton types which are replaced with values of
+a different type: for example, `Union{T,Missing}` will become `T` if `missing` is
+replaced.
+
 See also [`replace!`](@ref).
 
 # Examples
@@ -420,18 +436,41 @@ julia> replace([1, 2, 1, 3], 1=>0, 2=>4, count=2)
  4
  1
  3
+
+julia> replace([1, missing], missing=>0)
+2-element Array{Int64,1}:
+ 1
+ 0
 ```
 """
-function replace(A, old_new::Pair...; count::Integer=typemax(Int))
+function replace(A, old_new::Pair...; count::Union{Integer,Nothing}=nothing)
     V = promote_valuetype(old_new...)
-    T = promote_type(eltype(A), V)
-    _replace!(_copy_oftype(A, T), count, old_new)
+    if count isa Nothing
+        T = promote_type(subtract_singletontype(eltype(A), old_new...), V)
+        replace_pairs!(_similar_or_copy(A, T), A, typemax(Int), old_new)
+    else
+        U = promote_type(eltype(A), V)
+        replace_pairs!(_similar_or_copy(A, U), A, check_count(count), old_new)
+    end
 end
 
 promote_valuetype(x::Pair{K, V}) where {K, V} = V
 promote_valuetype(x::Pair{K, V}, y::Pair...) where {K, V} =
     promote_type(V, promote_valuetype(y...))
 
+# Subtract singleton types which are going to be replaced
+@pure issingletontype(T::DataType) = isdefined(T, :instance)
+issingletontype(::Type) = false
+function subtract_singletontype(::Type{T}, x::Pair{K}) where {T, K}
+    if issingletontype(K)
+        Core.Compiler.typesubtract(T, K)
+    else
+        T
+    end
+end
+subtract_singletontype(::Type{T}, x::Pair{K}, y::Pair...) where {T, K} =
+    subtract_singletontype(subtract_singletontype(T, y...), x)
+
 """
     replace(pred::Function, A, new; [count::Integer])
 
@@ -451,7 +490,7 @@ julia> replace(isodd, [1, 2, 3, 1], 0, count=2)
 """
 function replace(pred::Callable, A, new; count::Integer=typemax(Int))
     T = promote_type(eltype(A), typeof(new))
-    replace!(pred, _copy_oftype(A, T), new, count=count)
+    _replace!(x -> ifelse(pred(x), new, x), _similar_or_copy(A, T), A, check_count(count))
 end
 
 """
@@ -478,7 +517,8 @@ Dict{Int64,Int64} with 2 entries:
   1 => 3
 ```
 """
-replace(new::Callable, A; count::Integer=typemax(Int)) = replace!(new, copy(A), count=count)
+replace(new::Callable, A; count::Integer=typemax(Int)) =
+    _replace!(new, _similar_or_copy(A), A, check_count(count))
 
 # Handle ambiguities
 replace!(a::Callable, b::Pair; count::Integer=-1) = throw(MethodError(replace!, (a, b)))
@@ -487,42 +527,69 @@ replace(a::Callable, b::Pair; count::Integer=-1) = throw(MethodError(replace, (a
 replace(a::Callable, b::Pair, c::Pair; count::Integer=-1) = throw(MethodError(replace, (a, b, c)))
 replace(a::AbstractString, b::Pair, c::Pair) = throw(MethodError(replace, (a, b, c)))
 
-
 ### replace! for AbstractDict/AbstractSet
 
 askey(k, ::AbstractDict) = k.first
 askey(k, ::AbstractSet) = k
 
-function _replace!(new::Callable, A::Union{AbstractDict,AbstractSet}, count::Int)
-    repl = Pair{eltype(A),eltype(A)}[]
+function _replace!(new::Callable, res::Union{AbstractDict,AbstractSet},
+                   A::Union{AbstractDict,AbstractSet}, count::Int)
+    # `res` and `A` are annotated separately (not with a shared type parameter)
+    # because `replace` may pass a `res` whose element type is wider than that of `A`
+    # `c == count` is only tested after an element was processed: handle 0 up front
+    count == 0 && return res
     c = 0
-    for x in A
-        y = new(x)
-        if x !== y
-            push!(repl, x => y)
-            c += 1
+    if res === A # cannot replace elements while iterating over A
+        repl = Pair{eltype(A),eltype(A)}[]
+        for x in A
+            y = new(x)
+            if x !== y
+                push!(repl, x => y)
+                c += 1
+            end
+            c == count && break
+        end
+        for oldnew in repl
+            pop!(res, askey(first(oldnew), res))
+        end
+        for oldnew in repl
+            push!(res, last(oldnew))
+        end
+    else
+        for x in A
+            y = new(x)
+            if x !== y
+                pop!(res, askey(x, res))
+                push!(res, y)
+                c += 1
+            end
+            c == count && break
         end
-        c == count && break
-    end
-    for oldnew in repl
-        pop!(A, askey(first(oldnew), A))
-    end
-    for oldnew in repl
-        push!(A, last(oldnew))
     end
+    res
 end
 
-### AbstractArray
+### replace! for AbstractArray
 
-function _replace!(new::Callable, A::AbstractArray, count::Int)
+function _replace!(new::Callable, res::AbstractArray, A::AbstractArray, count::Int)
     c = 0
-    for i in eachindex(A)
-        @inbounds Ai = A[i]
-        y = new(Ai)
-        if Ai !== y
-            @inbounds A[i] = y
-            c += 1
+    if count >= length(A) # simpler loop allows for SIMD
+        for i in eachindex(A)
+            @inbounds Ai = A[i]
+            y = new(Ai)
+            @inbounds res[i] = y
+        end
+    else
+        for i in eachindex(A)
+            @inbounds Ai = A[i]
+            if c < count
+                y = new(Ai)
+                @inbounds res[i] = y
+                c += (Ai !== y)
+            else
+                @inbounds res[i] = Ai
+            end
         end
-        c == count && break
     end
+    res
 end
diff --git a/test/sets.jl b/test/sets.jl
index 31012363a90f5..ddd64444d572d 100644
--- a/test/sets.jl
+++ b/test/sets.jl
@@ -555,6 +555,21 @@ end
     x = @inferred replace(x -> x > 1, [1, 2], missing)
     @test isequal(x, [1, missing]) && x isa Vector{Union{Int, Missing}}
 
+    x = @inferred replace([1, missing], missing=>2)
+    @test x == [1, 2] && x isa Vector{Int}
+    x = @inferred replace([1, missing], missing=>2, count=1)
+    @test x == [1, 2] && x isa Vector{Union{Int, Missing}}
+    x = @inferred replace([1, missing], missing=>missing)
+    @test isequal(x, [1, missing]) && x isa Vector{Union{Int, Missing}}
+    x = @inferred replace([1, missing], missing=>2, 1=>missing)
+    @test isequal(x, [missing, 2]) && x isa Vector{Union{Int, Missing}}
+    # `count=0` must leave sets and dicts untouched
+    @test replace(Set([1]), 1=>3, count=0) == Set([1])
+    @test replace(Dict(1=>2), (1=>2)=>(1=>3), count=0) == Dict(1=>2)
+    # `replace` on a set may widen the element type of the result
+    x = replace(Set([1, 2]), 1=>2.5)
+    @test x == Set([2.0, 2.5]) && x isa Set{Float64}
+
     # test that isequal is used
     @test replace([NaN, 1.0], NaN=>0.0) == [0.0, 1.0]
     @test replace([1, missing], missing=>0) == [1, 0]