Skip to content

Commit

Permalink
Add intersect(::AbstractRange, ::AbstractVector)
Browse files Browse the repository at this point in the history
Also adds:
 - `intersect(::AbstractRange, ::AbstractRange)`
 - `intersect(::AbstractRange, ::AbstractRange)`

Closes #41759

Co-authored-by: Ian Butterworth <[email protected]>
Co-authored-by: Jeff Bezanson <[email protected]>
  • Loading branch information
3 people committed Aug 27, 2021
1 parent 6814f2b commit b95c68e
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 12 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ Standard library changes
arithmetic to error if the result may be wrapping. Or use a package such as SaferIntegers.jl when
constructing the range. ([#40382])
* TCP socket objects now expose `closewrite` functionality and support half-open mode usage ([#40783]).
* Intersect returns a result with the eltype of the type-promoted eltypes of the two inputs ([#41769]).

#### InteractiveUtils
* A new macro `@time_imports` for reporting any time spent importing packages and their dependencies ([#41612])
Expand Down
15 changes: 10 additions & 5 deletions base/abstractset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ Set{Int64} with 3 elements:
"""
function union end

_in(itr) = x -> x in itr

union(s, sets...) = union!(emptymutable(s, promote_eltype(s, sets...)), s, sets...)
union(s::AbstractSet) = copy(s)

Expand Down Expand Up @@ -109,6 +107,10 @@ Maintain order with arrays.
See also: [`setdiff`](@ref), [`isdisjoint`](@ref), [`issubset`](@ref Base.issubset), [`issetequal`](@ref).
!!! compat "Julia 1.8"
As of Julia 1.8 intersect returns a result with the eltype of the
type-promoted eltypes of the two inputs
# Examples
```jldoctest
julia> intersect([1, 2, 3], [3, 4, 5])
Expand All @@ -125,9 +127,12 @@ Set{Int64} with 1 element:
2
```
"""
intersect(s::AbstractSet, itr, itrs...) = intersect!(intersect(s, itr), itrs...)
function intersect(s::AbstractSet, itr, itrs...)
T = promote_eltype(s, itr, itrs...)
return intersect!(Set{T}(s), itr, itrs...)
end
intersect(s) = union(s)
intersect(s::AbstractSet, itr) = mapfilter(_in(s), push!, itr, emptymutable(s))
intersect(s::AbstractSet, itr) = mapfilter(in(s), push!, itr, emptymutable(s, promote_eltype(s, itr)))

const = intersect

Expand All @@ -143,7 +148,7 @@ function intersect!(s::AbstractSet, itrs...)
end
return s
end
intersect!(s::AbstractSet, s2::AbstractSet) = filter!(_in(s2), s)
intersect!(s::AbstractSet, s2::AbstractSet) = filter!(in(s2), s)
intersect!(s::AbstractSet, itr) =
intersect!(s, union!(emptymutable(s, eltype(itr)), itr))

Expand Down
18 changes: 13 additions & 5 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2603,19 +2603,27 @@ function _shrink!(shrinker!, v::AbstractVector, itrs)
seen = Set{eltype(v)}()
filter!(_grow_filter!(seen), v)
shrinker!(seen, itrs...)
filter!(_in(seen), v)
filter!(in(seen), v)
end

intersect!(v::AbstractVector, itrs...) = _shrink!(intersect!, v, itrs)
setdiff!( v::AbstractVector, itrs...) = _shrink!(setdiff!, v, itrs)

vectorfilter(f, v::AbstractVector) = filter(f, v) # TODO: do we want this special case?
vectorfilter(f, v) = [x for x in v if f(x)]
vectorfilter(T::Type, f, v) = T[x for x in v if f(x)]

function _shrink(shrinker!, itr, itrs)
keep = shrinker!(Set(itr), itrs...)
vectorfilter(_shrink_filter!(keep), itr)
T = promote_eltype(itr, itrs...)
keep = shrinker!(Set{T}(itr), itrs...)
vectorfilter(T, _shrink_filter!(keep), itr)
end

intersect(itr, itrs...) = _shrink(intersect!, itr, itrs)
setdiff( itr, itrs...) = _shrink(setdiff!, itr, itrs)

function intersect(v::AbstractVector, r::AbstractRange)
T = promote_eltype(v, r)
common = Iterators.filter(in(r), v)
seen = Set{T}(common)
return vectorfilter(T, _shrink_filter!(seen), common)
end
intersect(r::AbstractRange, v::AbstractVector) = intersect(v, r)
10 changes: 10 additions & 0 deletions base/range.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1181,6 +1181,16 @@ function intersect(r::StepRange, s::StepRange)
step(r) < zero(step(r)) ? StepRange{T,S}(n, -a, m) : StepRange{T,S}(m, a, n)
end

function intersect(r1::AbstractRange, r2::AbstractRange)
# To iterate over the shorter range
length(r1) > length(r2) && return intersect(r2, r1)

r1 = unique(r1)
T = promote_eltype(r1, r2)

return T[x for x in r1 if x r2]
end

function intersect(r1::AbstractRange, r2::AbstractRange, r3::AbstractRange, r::AbstractRange...)
i = intersect(intersect(r1, r2), r3)
for t in r
Expand Down
5 changes: 5 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1106,6 +1106,11 @@ end
@test isequal(intersect([1,2,3], Float64[]), Float64[])
@test isequal(intersect(Int64[], [1,2,3]), Int64[])
@test isequal(intersect(Int64[]), Int64[])
@test isequal(intersect([1, 3], 1:typemax(Int)), [1, 3])
@test isequal(intersect(1:typemax(Int), [1, 3]), [1, 3])
@test isequal(intersect([1, 2, 3], 2:0.1:5), [2., 3.])
@test isequal(intersect([1.0, 2.0, 3.0], 2:5), [2., 3.])

@test isequal(setdiff([1,2,3,4], [2,5,4]), [1,3])
@test isequal(setdiff([1,2,3,4], [7,8,9]), [1,2,3,4])
@test isequal(setdiff([1,2,3,4], Int64[]), Int64[1,2,3,4])
Expand Down
18 changes: 18 additions & 0 deletions test/ranges.jl
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,9 @@ end
@test intersect(1:3, 2) === intersect(2, 1:3) === 2:2
@test intersect(1.0:3.0, 2) == intersect(2, 1.0:3.0) == [2.0]

@test intersect(1:typemax(Int), [1, 3]) === [1, 3]
@test intersect([1, 3], 1:typemax(Int)) === [1, 3]

@testset "Support StepRange with a non-numeric step" begin
start = Date(1914, 7, 28)
stop = Date(1918, 11, 11)
Expand All @@ -426,6 +429,21 @@ end
@test intersect(start-Day(10):Day(1):stop-Day(10), start:Day(5):stop) ==
start:Day(5):stop-Day(10)-mod(stop-start, Day(5))
end

@testset "Two AbstractRanges" begin
struct DummyRange{T} <: AbstractRange{T}
r
end
Base.iterate(dr::DummyRange) = iterate(dr.r)
Base.iterate(dr::DummyRange, state) = iterate(dr.r, state)
Base.length(dr::DummyRange) = length(dr.r)
Base.in(x::Int, dr::DummyRange) = in(x, dr.r)
Base.unique(dr::DummyRange) = unique(dr.r)
r1 = DummyRange{Int}([1, 2, 3, 3, 4, 5])
r2 = DummyRange{Int}([3, 3, 4, 5, 6])
@test intersect(r1, r2) == [3, 4, 5]
@test intersect(r2, r1) == [3, 4, 5]
end
end
@testset "issubset" begin
@test issubset(1:3, 1:typemax(Int)) #32461
Expand Down
25 changes: 23 additions & 2 deletions test/sets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,16 @@ end
s2 = Set([nothing])
union!(s2, [nothing])
@test s2 == Set([nothing])

@testset "promotion" begin
ints = [1:5, [1, 2], Set([1, 2])]
floats = [2:0.1:3, [2.0, 3.5], Set([2.0, 3.5])]

for a in ints, b in floats
@test eltype(union(a, b)) == Float64
@test eltype(union(b, a)) == Float64
end
end
end

@testset "intersect" begin
Expand All @@ -238,7 +248,7 @@ end
end
end
@test intersect(Set([1]), BitSet()) isa Set{Int}
@test intersect(BitSet([1]), Set()) isa BitSet
@test intersect(BitSet([1]), Set()) isa Set{Any}
@test intersect([1], BitSet()) isa Vector{Int}
# intersect must uniquify
@test intersect([1, 2, 1]) == intersect!([1, 2, 1]) == [1, 2]
Expand All @@ -249,7 +259,18 @@ end
y = () (42,)
@test isempty(x)
@test isempty(y)
@test eltype(x) == eltype(y) == Union{}

# Discussed in PR#41769
@testset "promotion" begin
ints = [1:5, [1, 2], Set([1, 2])]
floats = [2:0.1:3, [2.0, 3.5], Set([2.0, 3.5])]

for a in ints, b in floats
@test eltype(intersect(a, b)) == Float64
@test eltype(intersect(b, a)) == Float64
@test eltype(intersect(a, a, b)) == Float64
end
end
end

@testset "setdiff" begin
Expand Down

0 comments on commit b95c68e

Please sign in to comment.