Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: support unwrapping #17

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 27 additions & 9 deletions src/fallbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,23 +46,32 @@ haslength(L) = L isa Union{Base.HasShape, Base.HasLength}
"""
allocatecolumn(T, len) = Vector{T}(undef, len)

@inline function allocatecolumns(::Schema{names, types}, len) where {names, types}
function _allocatecolumn(::Type{T}, len; unwrap = t -> false) where {T}
if unwrap(T)
names = fieldnames(T)
types = map(t -> fieldtype(T, t), names)
allocatecolumns(Tables.Schema(names, types), len; unwrap = unwrap)
else
allocatecolumn(T, len)
end
end

@inline function allocatecolumns(::Schema{names, types}, len; unwrap = t -> false) where {names, types}
if @generated
vals = Tuple(:(allocatecolumn($(fieldtype(types, i)), len)) for i = 1:fieldcount(types))
vals = Tuple(:(_allocatecolumn($(fieldtype(types, i)), len; unwrap = unwrap)) for i = 1:fieldcount(types))
return :(NamedTuple{names}(($(vals...),)))
else
return NamedTuple{names}(Tuple(allocatecolumn(fieldtype(types, i), len) for i = 1:fieldcount(types)))
return NamedTuple{names}(Tuple(_allocatecolumn(fieldtype(types, i), len; unwrap = unwrap) for i = 1:fieldcount(types)))
end
end

# add! will push! or setindex! a value depending on if the row-iterator HasLength or not
@inline add!(val, col::Int, nm::Symbol, ::Union{Base.HasLength, Base.HasShape{1}}, nt, row) = setindex!(nt[col], val, row)
@inline add!(val, col::Int, nm::Symbol, T, nt, row) = push!(nt[col], val)
@inline add!(val, col::Int, nm::Symbol, T, nt, row) = add!(nt[col], val, T, row)

@inline function buildcolumns(schema, rowitr::T) where {T}
@inline function buildcolumns(schema, rowitr::T; unwrap = t -> false) where {T}
L = Base.IteratorSize(T)
len = haslength(L) ? length(rowitr) : 0
nt = allocatecolumns(schema, len)
nt = allocatecolumns(schema, len; unwrap = unwrap)
for (i, row) in enumerate(rowitr)
eachcolumn(add!, schema, row, L, nt, i)
end
Expand All @@ -71,6 +80,15 @@ end

@inline add!(dest::AbstractVector, val, ::Union{Base.HasLength, Base.HasShape{1}}, row) = setindex!(dest, val, row)
@inline add!(dest::AbstractVector, val, T, row) = push!(dest, val)
@generated function add!(dest::NamedTuple{names}, val, T, row) where names
exprs = Expr[]
for sym in names
quot = Expr(:quote, sym)
push!(exprs, :(add!(getproperty(dest, $quot), getproperty(val, $quot), T, row)))
end
Expr(:block, exprs...)
end


@inline function add_or_widen!(dest::AbstractVector{T}, val::S, nm::Symbol, L, row, len, updated) where {T, S}
if S === T || val isa T
Expand Down Expand Up @@ -115,10 +133,10 @@ function _buildcolumns(rowitr, row, st, sch, L, columns, rownbr, len, updated)
return updated[]
end

@inline function columns(x::T) where {T}
@inline function columns(x::T; unwrap = t -> false) where {T}
if rowaccess(T)
r = rows(x)
return buildcolumns(schema(r), r)
return buildcolumns(schema(r), r; unwrap = unwrap)
else
throw(ArgumentError("no default `Tables.columns` implementation for type: $T"))
end
Expand Down
8 changes: 7 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -210,4 +210,10 @@ end
mt = Tables.DataValueUnwrapper(ei.x |> y->QueryOperators.map(y, x->(a=x.a, c=x.c), Expr(:block)))
@test (mt |> columntable) == (a = Real[1, 2.0, 3], c = ["7", "8", "9"])
@test length(mt |> rowtable) == 3
end
end

@testset "nesting" begin
x = (a = (b = 1,), c = 2)
@test Tables.columns([x]) == (a = [(b = 1,)], c = [2])
@test Tables.columns([x], unwrap = T -> T <: NamedTuple) == (a = (b = [1,],), c = [2])
end