Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LinearCombination transform #8

Merged
merged 6 commits into from
Feb 8, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/Transforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ module Transforms

using Tables

export Transform, Power
export LinearCombination, Transform, Power
export transform, transform!

include("utils.jl")
include("transformers.jl")
include("linear_combination.jl")
include("power.jl")

end
77 changes: 77 additions & 0 deletions src/linear_combination.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""
LinearCombination(coefficients) <: Transform

Calculate the linear combination using the column weights passed in.
"""
struct LinearCombination <: Transform
coefficients::Vector{Real}
end

function _check_dimensions_match(LC::LinearCombination, num_inds)
num_coefficients = length(LC.coefficients)
if num_inds != num_coefficients
throw(DimensionMismatch(
"Size $num_inds doesn't match number of coefficients ($num_coefficients)"
))
end
end

_sum_row(row, coefficients) = sum(map(*, row, coefficients))

"""
apply(x::AbstractVector, LC::LinearCombination; inds=Colon())

Applies the [`LinearCombination`](@ref) to each of the specified indices in `x`.

If no `inds` are specified, then the [`LinearCombination`](@ref) is applied to all elements.
"""
function apply(x::AbstractVector, LC::LinearCombination; inds=Colon())
morris25 marked this conversation as resolved.
Show resolved Hide resolved
# Treat each element as it's own column
# Error if dimensions don't match
num_inds = inds isa Colon ? length(x) : length(inds)
_check_dimensions_match(LC, num_inds)

return [_sum_row(x[inds], LC.coefficients)]
end

"""
apply(x::AbstractArray, LC::LinearCombination; dims=1, inds=Colon())

Applies the [`LinearCombination`](@ref) to each of the specified indices in `x` along the
dimension specified, which defaults to applying across the columns of x.
nicoleepp marked this conversation as resolved.
Show resolved Hide resolved

If no `inds` are specified, then the [`LinearCombination`](@ref) is applied to all columns.
"""
function apply(x::AbstractArray, LC::LinearCombination; dims=1, inds=Colon())
# Get the number of vectors in the dimension not specified
other_dim = dims == 1 ? 2 : 1
nicoleepp marked this conversation as resolved.
Show resolved Hide resolved
num_inds = inds isa Colon ? size(x, other_dim) : length(inds)
# Error if dimensions don't match
_check_dimensions_match(LC, num_inds)

return [_sum_row(row[inds], LC.coefficients) for row in eachslice(x; dims=dims)]
end

"""
apply(x::Table, LC::LinearCombination; inds=nothing

Applies the [`LinearCombination`](@ref) to each of the specified indices (columns) in `x`.

If no `inds` are specified, then the [`LinearCombination`](@ref) is applied to all columns.
"""
function apply(x, LC::LinearCombination; inds=nothing)
nicoleepp marked this conversation as resolved.
Show resolved Hide resolved
# Error if dimensions don't match
num_inds = inds === nothing ? length(Tables.columnnames(x)) : length(inds)
_check_dimensions_match(LC, num_inds)

# Keep the generic form when not specifying column names
# because that is much more performant than selecting each col by name
if inds === nothing
return [_sum_row(row, LC.coefficients) for row in Tables.rows(x)]
else
return [
_sum_row([row[cname] for cname in inds], LC.coefficients)
for row in Tables.rows(x)
]
end
end
202 changes: 202 additions & 0 deletions test/linear_combination.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
@testset "linear combination" begin

lc = LinearCombination([1, -1])
@test lc isa Transform

@testset "Vector" begin
x = [1, 2]
expected = [-1]

@testset "all inds" begin
@test Transforms.apply(x, lc) == expected
@test lc(x) == expected
end

@testset "dims not supported" begin
@test_throws MethodError Transforms.apply(x, lc; dims=1)
end

@testset "dimension mismatch" begin
x = [1, 2, 3]
@test_throws DimensionMismatch Transforms.apply(x, lc)
end

@testset "specified inds" begin
x = [1, 2, 3]
inds = [2, 3]
expected = [-1]

@test Transforms.apply(x, lc; inds=inds) == expected
@test lc(x; inds=inds) == expected
end
end

@testset "Matrix" begin
M = [1 1; 2 2; 3 5]
expected = [0, 0, -2]

@testset "all inds" begin
@test Transforms.apply(M, lc) == expected
@test lc(M) == expected
end

# TODO: Colon() is not supported by eachslice(A, dims=:), do we care?
nicoleepp marked this conversation as resolved.
Show resolved Hide resolved
@testset "dims" begin
@testset "dims = 1" begin
@test Transforms.apply(M, lc; dims=1) == expected
@test lc(M; dims=1) == expected
end

@testset "dims = 2" begin
# There are 3 rows so trying to apply along dim 2 without specifying inds
# won't work
@test_throws DimensionMismatch Transforms.apply(M, lc; dims=2)

@test Transforms.apply(M, lc; dims=2, inds=[2, 3]) == [-1, -3]
@test lc(M; dims=2, inds=[1, 3]) == [-2, -4]
end
end

@testset "dimension mismatch" begin
M = [1 1 1; 2 2 2]
@test_throws DimensionMismatch Transforms.apply(M, lc)
end

@testset "specified inds" begin
M = [1 1 5; 2 2 4]
inds = [2, 3]
expected = [-4, -2]

@test Transforms.apply(M, lc; inds=inds) == expected
@test lc(M; inds=inds) == expected
end
end

@testset "AxisArray" begin
A = AxisArray([1 2; 4 5], foo=["a", "b"], bar=["x", "y"])
expected = [-1, -1]

@testset "all inds" begin
@test Transforms.apply(A, lc) == expected
@test lc(A) == expected
end

@testset "dims" begin
@testset "dims = 1" begin
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this also work if passing in the axis key? e.g. foo, bar?

Copy link
Member

@glennmoy glennmoy Feb 8, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I doubt it

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do you do the same with AxisArrays though? If I use (Colon(), :foo, :bar) like for KeyedArray in test/power.jl I get

dims = foo: Error During Test at /Users/bencottier/JuliaEnvs/Transform.jl/test/power.jl:65
  Got exception outside of a @test
  MethodError: no method matching length(::Symbol)
  Closest candidates are:
    length(!Matched::Base.Iterators.Flatten{Tuple{}}) at iterators.jl:1061
    length(!Matched::Base.MethodList) at reflection.jl:872
    length(!Matched::DataStructures.EnumerateAll) at /Users/bencottier/.julia/packages/DataStructures/ixwFs/src/multi_dict.jl:96
    ...
  Stacktrace:
   [1] #eachslice#196 at ./abstractarraymath.jl:496 [inlined]
   [2] apply!(::AxisArray{Int64,2,Array{Int64,2},Tuple{Axis{:foo,Array{String,1}},Axis{:bar,Array{String,1}}}}, ::Power; dims::Symbol, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /Users/bencottier/JuliaEnvs/Transform.jl/src/transformers.jl:57
   [3] apply(::AxisArray{Int64,2,Array{Int64,2},Tuple{Axis{:foo,Array{String,1}},Axis{:bar,Array{String,1}}}}, ::Power; kwargs::Base.Iterators.Pairs{Symbol,Symbol,Tuple{Symbol},NamedTuple{(:dims,),Tuple{Symbol}}}) at /Users/bencottier/JuliaEnvs/Transform.jl/src/transformers.jl:64
   [4] macro expansion at /Users/bencottier/JuliaEnvs/Transform.jl/test/power.jl:66 [inlined]
   [5] macro expansion at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.5/Test/src/Test.jl:1190 [inlined]
   [6] macro expansion at /Users/bencottier/JuliaEnvs/Transform.jl/test/power.jl:65 [inlined]
   [7] macro expansion at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.5/Test/src/Test.jl:1115 [inlined]
   [8] macro expansion at /Users/bencottier/JuliaEnvs/Transform.jl/test/power.jl:62 [inlined]
   [9] macro expansion at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.5/Test/src/Test.jl:1115 [inlined]
   [10] top-level scope at /Users/bencottier/JuliaEnvs/Transform.jl/test/power.jl:3
   [11] include(::String) at ./client.jl:457
   [12] top-level scope at /Users/bencottier/JuliaEnvs/Transform.jl/test/runtests.jl:12
   [13] top-level scope at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.5/Test/src/Test.jl:1115
   [14] top-level scope at /Users/bencottier/JuliaEnvs/Transform.jl/test/runtests.jl:11
   [15] include(::String) at ./client.jl:457
   [16] top-level scope at none:6
   [17] eval(::Module, ::Any) at ./boot.jl:331
   [18] exec_options(::Base.JLOptions) at ./client.jl:272
   [19] _start() at ./client.jl:506

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it looks like eachslice doesn't work on AxisArrays maybe. Nevermind it was only a passing thought.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah it doesn't work without some extra voodoo

julia> eachslice(A, dims=axisdim(A, Axis{:foo}))
Base.Generator{Base.OneTo{Int64},Base.var"#199#202"{AxisArray{Int64,2,Array{Int64,2},Tuple{Axis{:foo,Array{String,1}},Axis{:bar,Array{String,1}}}},Tuple{},Tuple{Colon}}}(Base.var"#199#202"{AxisArray{Int64,2,Array{Int64,2},Tuple{Axis{:foo,Array{String,1}},Axis{:bar,Array{String,1}}}},Tuple{},Tuple{Colon}}([1 2 3; 4 5 6], (), (Colon(),)), Base.OneTo(2))

And it's been on their todo list for 5 years 🙄
JuliaArrays/AxisArrays.jl#7

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can make this work if we want to, by checking if it's an AxisArray. I might prefer to do this in a follow up MR just to keep diffs smaller.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#13

@test Transforms.apply(A, lc; dims=1) == expected
@test lc(A; dims=1) == expected
end

@testset "dims = 2" begin
@test Transforms.apply(A, lc; dims=2) == [-3, -3]
@test lc(A; dims=2) == [-3, -3]
end
end

@testset "dimension mismatch" begin
A = AxisArray([1 2 3; 4 5 5], foo=["a", "b"], bar=["x", "y", "z"])
@test_throws DimensionMismatch Transforms.apply(A, lc)
end

@testset "specified inds" begin
A = AxisArray([1 2 3; 4 5 5], foo=["a", "b"], bar=["x", "y", "z"])
inds = [1, 2]
expected = [-1, -1]

@test Transforms.apply(A, lc; inds=inds) == expected
@test lc(A; inds=inds) == expected
end
end

@testset "AxisKey" begin
A = KeyedArray([1 2; 4 5], foo=["a", "b"], bar=["x", "y"])
expected = [-1, -1]

@testset "all inds" begin
@test Transforms.apply(A, lc) == expected
@test lc(A) == expected
end

@testset "dims" begin
@testset "dims = 1" begin
glennmoy marked this conversation as resolved.
Show resolved Hide resolved
@test Transforms.apply(A, lc; dims=1) == expected
@test lc(A; dims=1) == expected
end

@testset "dims = 2" begin
@test Transforms.apply(A, lc; dims=2) == [-3, -3]
@test lc(A; dims=2) == [-3, -3]
end
end

@testset "dimension mismatch" begin
A = KeyedArray([1 2 3; 4 5 6], foo=["a", "b"], bar=["x", "y", "z"])
@test_throws DimensionMismatch Transforms.apply(A, lc)
end

@testset "specified inds" begin
A = KeyedArray([1 2 3; 4 5 5], foo=["a", "b"], bar=["x", "y", "z"])
inds = [1, 2]
expected = [-1, -1]

@test Transforms.apply(A, lc; inds=inds) == expected
@test lc(A; inds=inds) == expected
end
end

@testset "NamedTuple" begin
nt = (a = [1, 2, 3], b = [4, 5, 6])
expected = [-3, -3, -3]

@testset "all cols" begin
@test Transforms.apply(nt, lc) == expected
@test lc(nt) == expected
end

@testset "dims not supported" begin
@test_throws MethodError Transforms.apply(nt, lc; dims=1)
end

@testset "dimension mismatch" begin
nt = (a = [1, 2, 3], b = [4, 5, 6], c = [1, 1, 1])
@test_throws DimensionMismatch Transforms.apply(nt, lc)
end

@testset "specified cols" begin
nt = (a = [1, 2, 3], b = [4, 5, 6], c = [1, 1, 1])
inds = [:a, :b]
expected = [-3, -3, -3]

@test Transforms.apply(nt, lc; inds=inds) == expected
@test lc(nt; inds=inds) == expected
end
end

@testset "DataFrame" begin
df = DataFrame(:a => [1, 2, 3], :b => [4, 5, 6])
expected = [-3, -3, -3]

@testset "all cols" begin
@test Transforms.apply(df, lc) == expected
@test lc(df) == expected
end

@testset "dims not supported" begin
@test_throws MethodError Transforms.apply(df, lc; dims=1)
end

@testset "dimension mismatch" begin
df = DataFrame(:a => [1, 2, 3], :b => [4, 5, 6], :c => [1, 1, 1])
@test_throws DimensionMismatch Transforms.apply(df, lc)
end

@testset "specified cols" begin
df = DataFrame(:a => [1, 2, 3], :b => [4, 5, 6], :c => [1, 1, 1])
inds = [:b, :c]
expected = [3, 4, 5]

@test Transforms.apply(df, lc; inds=inds) == expected
@test lc(df; inds=inds) == expected
end
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@ using Transforms: _try_copy
using Test

@testset "Transforms.jl" begin
include("linear_combination.jl")
include("power.jl")
end