Skip to content

Commit

Permalink
Merge pull request #13 from JuliaData/jq/makeparts
Browse files Browse the repository at this point in the history
Add TableOperations.makepartitions functionality
  • Loading branch information
quinnj authored Nov 10, 2020
2 parents ffd23f6 + 58ccbf0 commit b3703ba
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
SentinelArrays = "1"
Tables = "1.1"
Tables = "1.2"
julia = "1"

[extras]
Expand Down
107 changes: 107 additions & 0 deletions src/TableOperations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -303,4 +303,111 @@ end

joinpartitions() = x -> joinpartitions(x)

struct ColumnPartitions{T}
source::T
nrows::Int
end

struct RowPartitions{T}
source::T
nrows::Int
end

"""
Tables.makepartitions(source, nrows) => Tables.MakePartitions
source |> Tables.makepartitions(nrows) => Tables.MakePartitions
Take a Tables.jl-compatible source and "partition" it into chunks with `nrows` rows.
Returns a `Tables.MakePartitions` object that implements `Tables.partitions`, where each
iteration returns `nrows` rows. For input tables that produce forward-only row iterators,
a `materializerows=true` keyword argument can be passed that will basically make a copy
of each row to ensure partitioning works as expected.
"""
function makepartitions end

makepartitions(nrows::Int; kw...) = x->makepartitions(x, nrows; kw...)

function makepartitions(x::T, nrows::Int; materializerows::Bool=false) where {T}
nrows < 1 && throw(ArgumentError("cannot create partitions of length $nrows"))
colaccess = Tables.columnaccess(T)
if colaccess
return ColumnPartitions(Tables.columns(x), nrows)
else
r = Tables.rows(x)
if Base.haslength(r) && !materializerows
return Iterators.partition(r, nrows)
else
return RowPartitions(r, nrows)
end
end
end

Tables.partitions(x::ColumnPartitions) = x
Tables.partitions(x::RowPartitions) = x

struct ColumnPartitionRows{T}
source::T
rng::UnitRange{Int}
end

Tables.istable(::Type{<:ColumnPartitionRows}) = true
Tables.rowaccess(::Type{<:ColumnPartitionRows}) = true
Tables.rows(x::ColumnPartitionRows) = x

Base.IteratorSize(::Type{<:ColumnPartitions}) = Base.HasLength()
Base.length(x::ColumnPartitions) = cld(Tables.rowcount(x.source), x.nrows)
Base.eltype(x::ColumnPartitions{T}) where {T} = ColumnPartitionRows{T}

@inline function Base.iterate(x::ColumnPartitions, st=0)
st >= length(x) && return nothing
len = min(x.nrows * (st + 1), Tables.rowcount(x.source))
return ColumnPartitionRows(x.source, (x.nrows * st + 1):len), st + 1
end

Base.IteratorSize(::Type{<:ColumnPartitionRows}) = Base.HasLength()
Base.length(x::ColumnPartitionRows) = length(getfield(x, :rng))
Base.eltype(x::ColumnPartitionRows{T}) where {T} = Tables.ColumnsRow{T}

@inline function Base.iterate(x::ColumnPartitionRows, st=1)
st > length(x) && return nothing
return Tables.ColumnsRow(getfield(x, :source), getfield(x, :rng)[st]), st + 1
end

struct MaterializedRow <: Tables.AbstractRow
names::Vector{Symbol}
values::Vector{Any}
lookup::Dict{Symbol, Any}
end

Tables.getcolumn(x::MaterializedRow, i::Int) = getfield(x, :values)[i]
Tables.getcolumn(x::MaterializedRow, nm::Symbol) = getfield(x, :lookup)[nm]
Tables.getcolumn(x::MaterializedRow, ::Type{T}, i::Int, nm::Symbol) where {T} = getfield(x, :values)[i]
Tables.columnnames(x::MaterializedRow) = getfield(x, :names)

function MaterializedRow(x)
names = collect(Tables.columnnames(x))
values = Any[Tables.getcolumn(x, i) for i = 1:length(names)]
return MaterializedRow(names, values, Dict{Symbol, Any}(nm => val for (nm, val) in zip(names, values)))
end

Base.IteratorSize(::Type{<:RowPartitions}) = Base.SizeUnknown()
Base.eltype(x::RowPartitions) = Vector{MaterializedRow}

@inline function Base.iterate(x::RowPartitions, st=())
st === nothing && return nothing
v = Vector{MaterializedRow}(undef, x.nrows)
y = iterate(x.source, st...)
i = 0
while y !== nothing
i += 1
v[i] = MaterializedRow(y[1])
if i >= x.nrows
break
end
y = iterate(x.source, y[2])
end
i == 0 && return nothing
return resize!(v, i), y === nothing ? nothing : (y[2],)
end

end # module
25 changes: 25 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using TableOperations, Tables, Test

ctable = (A=[1, missing, 3], B=[1.0, 2.0, 3.0], C=["hey", "there", "sailor"])
rtable = Tables.rowtable(ctable)
rtable2 = Iterators.filter(i -> i.a % 2 == 0, [(a=x, b=y) for (x, y) in zip(1:20, 21:40)])

@testset "TableOperations.transform" begin

Expand Down Expand Up @@ -290,4 +291,28 @@ j = TableOperations.joinpartitions(p)
@test isequal(Tables.getcolumn(j, 1), vcat(Tables.getcolumn(ctable, 1), Tables.getcolumn(ctable, 1)))
@test isequal(Tables.getcolumn(j, :A), vcat(Tables.getcolumn(ctable, :A), Tables.getcolumn(ctable, :A)))

end

@testset "TableOperations.makepartitions" begin

# columns
@test_throws ArgumentError TableOperations.makepartitions(ctable, 0)
parts = collect.(collect(Tables.partitions(TableOperations.makepartitions(ctable, 2))))
@test length(parts[1]) == 2
@test length(parts[2]) == 1
@test parts[2][1].A == 3

# rows
parts = collect(Tables.partitions(TableOperations.makepartitions(rtable, 2)))
@test length(parts[1]) == 2
@test length(parts[2]) == 1
@test parts[2][1].A == 3

# forward-only row iterator
parts = collect(Tables.partitions(TableOperations.makepartitions(rtable2, 3)))
@test length(parts) == 4
@test length(parts[1]) == 3
@test length(parts[end]) == 1
@test parts[end][1].a == 20

end

0 comments on commit b3703ba

Please sign in to comment.