Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: Composite transform #108

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/FeatureTransforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ export Transform, transform, transform!
export HoD, LinearCombination, OneHotEncoding, Periodic, Power
export AbstractScaling, IdentityScaling, MeanStdScaling, StandardScaling
export LogTransform, InverseHyperbolicSine
export Composite

include("utils.jl")
include("traits.jl")
Expand All @@ -24,6 +25,7 @@ include("periodic.jl")
include("power.jl")
include("scaling.jl")
include("temporal.jl")
include("composite.jl")

include("test_utils.jl")

Expand Down
54 changes: 54 additions & 0 deletions src/composite.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""
Composite <: Transform

A `Composite` transform is a composition of `Transform`s, currently limited to `OneToOne()`
cardinality. It can be fit and applied in a single step.

The transforms in `Composite([t1, t2, t3])` are applied in `t1`, `t2`, `t3` order, where
the output of `t1` is the input to `t2` etc. When using `∘` to create transforms, the order
is `t3 ∘ t2 ∘ t1`, as in function composition.
Comment on lines +7 to +9
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm...that might get confusing...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I had it the other way around first but the issue then is that the c.transforms[1] is the last transform that is applied, which is even more confusing I think.

The only totally non-confusing way is to get rid of the syntactic sugar, which is what switches the order.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only totally non-confusing way is to get rid of the ∘ syntactic sugar, which is what switches the order.

I wouldn't be against this tbh, but it might be handy for building pipelines of transforms like

tc = Composite(t1, t2)
...
...
tc = t3  tc
...
tc = t4  tc

so I can see value in it being used that way


```jldoctest composite
julia> id = IdentityScaling();

julia> power = Power(2.0);

julia> id ∘ power == Composite([power, id])
true
```
"""
struct Composite <: Transform
transforms::Tuple{Vararg{Transform}}

function Composite(transforms::Tuple{Vararg{Transform}})
all(==(OneToOne()), map(cardinality, transforms)) && return new(transforms)
throw(ArgumentError("Only OneToOne() transforms are supported."))
end
end

cardinality(c::Composite) = ∘(map(cardinality, c.transforms)...)

function fit!(c::Composite, data; kwargs...)
for t in c.transforms
fit!(t, data; kwargs...)
data = t(data)
end
return c
end

function _apply(x, c::Composite; kwargs...)
data = deepcopy(x)
for t in c.transforms
data = _apply(data, t; kwargs...)
end
return data
end

# creating composite transforms: reverse the order so that c.transforms[1] is the first
# transforms that gets applied
Base.:(∘)(f::Transform, g::Transform) = Composite((g, f))
Base.:(∘)(c::Composite, t::Transform) = Composite((t, c.transforms...))
Base.:(∘)(t::Transform, c::Composite) = Composite((c.transforms..., t))
Base.:(∘)(c::Composite, c2::Composite) = Composite((c2.transforms..., c.transforms...))

Base.:(==)(c::Composite, d::Composite) = return all(map(==, c.transforms, d.transforms))
12 changes: 12 additions & 0 deletions src/traits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,15 @@ struct ManyToMany <: Cardinality end
Returns the [`Cardinality`](@ref) of the `transform`.
"""
function cardinality end

Base.:(∘)(::OneToOne, ::OneToOne) = OneToOne()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if is the right syntax here... it's used to compose functions whereas here we're sort of "reducing" over the cardinalities.
I guess it makes sense in the context of composing the equivalent transforms?

If we do go for it then becomes part of the Cardinality API and should be documented.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The alternative is to have an internal _compose function which would do this instead? I don't think users need to use this at all.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah making it internal would indicate it shouldn't be used publically

Base.:(∘)(::OneToMany, ::OneToOne) = OneToMany()
Base.:(∘)(::ManyToOne, ::OneToMany) = OneToOne()
Base.:(∘)(::ManyToMany, ::OneToMany) = OneToMany()
Base.:(∘)(::OneToOne, ::ManyToOne) = ManyToOne()
Base.:(∘)(::OneToMany, ::ManyToOne) = ManyToMany()
Base.:(∘)(::ManyToOne, ::ManyToMany) = ManyToOne()
Base.:(∘)(::ManyToMany, ::ManyToMany) = ManyToMany()
function Base.:(∘)(c2::Cardinality, c1::Cardinality)
return throw(ArgumentError("Cannot compose cardinalities: $c2 ∘ $c1."))
end
45 changes: 45 additions & 0 deletions test/composite.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
@testset "composite.jl" begin
@testset "constructor" begin
id = IdentityScaling()
power = Power(3.0)
logt = LogTransform()
@test id ∘ id == Composite((id, id))
@test id ∘ id ∘ power == Composite((power, id, id))
@test power ∘ id ∘ power == Composite((power, id, power))

@test power ∘ (id ∘ logt) == Composite((logt, id, power))
@test (power ∘ id) ∘ logt == Composite((logt, id, power))
@test (power ∘ id) ∘ (logt ∘ id) == Composite((id, logt, id, power))

@test_throws ArgumentError id ∘ LinearCombination([1, 2, 3])
@test_throws ArgumentError OneHotEncoding([1, 2]) ∘ id
Comment on lines +14 to +15
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

alternatively we could use test_skip until these are implemented?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wanted to make sure there is an actual error, so possibly @test_broken is the right thing

end

@testset "apply" begin
p = Power(4.0)
c = Power(2.0) ∘ Power(2.0) ∘ IdentityScaling()
x = [1, 2, 3]
@test FeatureTransforms.apply(x, p) == FeatureTransforms.apply(x, c)
@test p(x) == c(x)
end

@testset "fit!" begin
s = StandardScaling()
c = StandardScaling() ∘ IdentityScaling() ∘ StandardScaling()
x = rand(10)
x_copy = deepcopy(x)

fit!(s, x)
fit!(c, x)

@test c(x) ≈ s(x)

# did not change the input data
@test x_copy == x

# but make sure that it is fit and transformed on the already transformed data, in
# this case leaving the second scaling redundant, i.e. centered at 0.0 and std = 1.0
@test isapprox(0.0, c.transforms[3].μ; atol=1e-15)
@test isapprox(1.0, c.transforms[3].σ; atol=1e-15)
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ using TimeZones
include("scaling.jl")
include("temporal.jl")
include("traits.jl")
include("composite.jl")
include("test_utils.jl")

include("types/tables.jl")
Expand Down
22 changes: 22 additions & 0 deletions test/traits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,26 @@
for t in (OneToOne(), OneToMany(), ManyToOne(), ManyToMany())
@test t isa FeatureTransforms.Cardinality
end

@testset "composite" begin
@test OneToOne() == OneToOne() ∘ OneToOne()
@test OneToMany() == OneToMany() ∘ OneToOne()
@test_throws ArgumentError ManyToOne() ∘ OneToOne()
@test_throws ArgumentError ManyToMany() ∘ OneToOne()

@test ManyToOne() == OneToOne() ∘ ManyToOne()
@test ManyToMany() == OneToMany() ∘ ManyToOne()
@test_throws ArgumentError ManyToOne() ∘ ManyToOne()
@test_throws ArgumentError ManyToMany() ∘ ManyToOne()

@test_throws ArgumentError OneToOne() ∘ OneToMany()
@test_throws ArgumentError OneToMany() ∘ OneToMany()
@test OneToOne() == ManyToOne() ∘ OneToMany()
@test OneToMany() == ManyToMany() ∘ OneToMany()

@test_throws ArgumentError OneToOne() ∘ ManyToMany()
@test_throws ArgumentError OneToMany() ∘ ManyToMany()
@test ManyToOne() == ManyToOne() ∘ ManyToMany()
@test ManyToMany() == ManyToMany() ∘ ManyToMany()
end
end