-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
RFC: Composite
transform
#108
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
""" | ||
Composite <: Transform | ||
|
||
A `Composite` transform is a composition of `Transform`s, currently limited to `OneToOne()` | ||
cardinality. It can be fit and applied in a single step. | ||
|
||
The transforms in `Composite([t1, t2, t3])` are applied in `t1`, `t2`, `t3` order, where | ||
the output of `t1` is the input to `t2` etc. When using `∘` to create transforms, the order | ||
is `t3 ∘ t2 ∘ t1`, as in function composition. | ||
|
||
```jldoctest composite | ||
julia> id = IdentityScaling(); | ||
|
||
julia> power = Power(2.0); | ||
|
||
julia> id ∘ power == Composite([power, id]) | ||
true | ||
``` | ||
""" | ||
struct Composite <: Transform | ||
transforms::Tuple{Vararg{Transform}} | ||
|
||
function Composite(transforms::Tuple{Vararg{Transform}}) | ||
all(==(OneToOne()), map(cardinality, transforms)) && return new(transforms) | ||
throw(ArgumentError("Only OneToOne() transforms are supported.")) | ||
end | ||
end | ||
|
||
cardinality(c::Composite) = ∘(map(cardinality, c.transforms)...) | ||
|
||
function fit!(c::Composite, data; kwargs...) | ||
for t in c.transforms | ||
fit!(t, data; kwargs...) | ||
data = t(data) | ||
end | ||
return c | ||
end | ||
|
||
function _apply(x, c::Composite; kwargs...) | ||
data = deepcopy(x) | ||
for t in c.transforms | ||
data = _apply(data, t; kwargs...) | ||
end | ||
return data | ||
end | ||
|
||
# creating composite transforms: reverse the order so that c.transforms[1] is the first | ||
# transforms that gets applied | ||
Base.:(∘)(f::Transform, g::Transform) = Composite((g, f)) | ||
Base.:(∘)(c::Composite, t::Transform) = Composite((t, c.transforms...)) | ||
Base.:(∘)(t::Transform, c::Composite) = Composite((c.transforms..., t)) | ||
Base.:(∘)(c::Composite, c2::Composite) = Composite((c2.transforms..., c.transforms...)) | ||
|
||
Base.:(==)(c::Composite, d::Composite) = return all(map(==, c.transforms, d.transforms)) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -46,3 +46,15 @@ struct ManyToMany <: Cardinality end | |
Returns the [`Cardinality`](@ref) of the `transform`. | ||
""" | ||
function cardinality end | ||
|
||
Base.:(∘)(::OneToOne, ::OneToOne) = OneToOne() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure if If we do go for it then There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The alternative is to have an internal There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah making it internal would indicate it shouldn't be used publically |
||
Base.:(∘)(::OneToMany, ::OneToOne) = OneToMany() | ||
Base.:(∘)(::ManyToOne, ::OneToMany) = OneToOne() | ||
Base.:(∘)(::ManyToMany, ::OneToMany) = OneToMany() | ||
Base.:(∘)(::OneToOne, ::ManyToOne) = ManyToOne() | ||
Base.:(∘)(::OneToMany, ::ManyToOne) = ManyToMany() | ||
Base.:(∘)(::ManyToOne, ::ManyToMany) = ManyToOne() | ||
Base.:(∘)(::ManyToMany, ::ManyToMany) = ManyToMany() | ||
function Base.:(∘)(c2::Cardinality, c1::Cardinality) | ||
return throw(ArgumentError("Cannot compose cardinalities: $c2 ∘ $c1.")) | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
@testset "composite.jl" begin | ||
@testset "constructor" begin | ||
id = IdentityScaling() | ||
power = Power(3.0) | ||
logt = LogTransform() | ||
@test id ∘ id == Composite((id, id)) | ||
@test id ∘ id ∘ power == Composite((power, id, id)) | ||
@test power ∘ id ∘ power == Composite((power, id, power)) | ||
|
||
@test power ∘ (id ∘ logt) == Composite((logt, id, power)) | ||
@test (power ∘ id) ∘ logt == Composite((logt, id, power)) | ||
@test (power ∘ id) ∘ (logt ∘ id) == Composite((id, logt, id, power)) | ||
|
||
@test_throws ArgumentError id ∘ LinearCombination([1, 2, 3]) | ||
@test_throws ArgumentError OneHotEncoding([1, 2]) ∘ id | ||
Comment on lines
+14
to
+15
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. alternatively we could use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wanted to make sure there is an actual error, so possibly @test_broken is the right thing |
||
end | ||
|
||
@testset "apply" begin | ||
p = Power(4.0) | ||
c = Power(2.0) ∘ Power(2.0) ∘ IdentityScaling() | ||
x = [1, 2, 3] | ||
@test FeatureTransforms.apply(x, p) == FeatureTransforms.apply(x, c) | ||
@test p(x) == c(x) | ||
end | ||
|
||
@testset "fit!" begin | ||
s = StandardScaling() | ||
c = StandardScaling() ∘ IdentityScaling() ∘ StandardScaling() | ||
x = rand(10) | ||
x_copy = deepcopy(x) | ||
|
||
fit!(s, x) | ||
fit!(c, x) | ||
|
||
@test c(x) ≈ s(x) | ||
|
||
# did not change the input data | ||
@test x_copy == x | ||
|
||
# but make sure that it is fit and transformed on the already transformed data, in | ||
# this case leaving the second scaling redundant, i.e. centered at 0.0 and std = 1.0 | ||
@test isapprox(0.0, c.transforms[3].μ; atol=1e-15) | ||
@test isapprox(1.0, c.transforms[3].σ; atol=1e-15) | ||
end | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hmm...that might get confusing...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I had it the other way around first but the issue then is that the
c.transforms[1]
is the last transform that is applied, which is even more confusing I think.The only totally non-confusing way is to get rid of the
∘
syntactic sugar, which is what switches the order.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wouldn't be against this tbh, but it might be handy for building pipelines of transforms like
so I can see value in it being used that way