-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor apply methods to use Traits (#80)
- Loading branch information
Glenn Moynihan
authored
Apr 16, 2021
1 parent
f66d6a4
commit e95ea9c
Showing
5 changed files
with
73 additions
and
83 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,72 +1,32 @@ | ||
""" | ||
LinearCombination(coefficients) <: Transform | ||
Calculate the linear combination using the vector coefficients passed in. | ||
""" | ||
struct LinearCombination <: Transform | ||
coefficients::Vector{Real} | ||
end | ||
|
||
cardinality(::LinearCombination) = ManyToOne() | ||
Calculates the linear combination of a collection of terms weighted by some `coefficients`. | ||
""" | ||
apply( | ||
::AbstractArray{<:Real, N}, ::LinearCombination; dims=1, inds=: | ||
) -> AbstractArray{<:Real, N-1} | ||
Applies the [`LinearCombination`](@ref) to each of the specified indices in the N-dimensional | ||
array `A`, reducing along the `dim` provided. The result is an (N-1)-dimensional array. | ||
The default behaviour reduces along the column dimension. | ||
When applied to an N-dimensional array, `LinearCombination` reduces along the `dim` provided | ||
and returns an (N-1)-dimensional array. | ||
If no `inds` are specified, then the transform is applied to all elements. | ||
""" | ||
function apply( | ||
A::AbstractArray{<:Real, N}, LC::LinearCombination; dims=1, inds=: | ||
)::AbstractArray{<:Real, N-1} where N | ||
|
||
dims === Colon() && throw(ArgumentError("dims=: not supported, choose dims ∈ [1, $N]")) | ||
return _sum_terms(eachslice(selectdim(A, dims, inds); dims=dims), LC.coefficients) | ||
end | ||
|
||
""" | ||
apply(table, LC::LinearCombination; [cols], [header]) -> Table | ||
Applies the [`LinearCombination`](@ref) across the specified cols in `table`. If no `cols` | ||
are specified, then the [`LinearCombination`](@ref) is applied to all columns. | ||
Optionally provide a `header` for the output table. If none is provided the default in | ||
`Tables.table` is used. | ||
!!!note | ||
The current default is that `dims=1` but this behaviour will be deprecated in a future | ||
release and the `dims` keyword argument will have to be specified explicitly. | ||
https://github.com/invenia/FeatureTransforms.jl/issues/82 | ||
""" | ||
function apply(table, LC::LinearCombination; cols=_get_cols(table), header=nothing, kwargs...) | ||
Tables.istable(table) || throw(MethodError(apply, (table, LC))) | ||
|
||
# Extract a columns iterator that we should be able to use to mutate the data. | ||
# NOTE: Mutation is not guaranteed for all table types, but it avoid copying the data | ||
coltable = Tables.columntable(table) | ||
cols = _to_vec(cols) | ||
|
||
result = hcat(_sum_terms([getproperty(coltable, col) for col in cols], LC.coefficients)) | ||
return Tables.materializer(table)(_to_table(result, header)) | ||
struct LinearCombination <: Transform | ||
coefficients::Vector{Real} | ||
end | ||
|
||
function apply_append( | ||
A::AbstractArray{<:Real, N}, LC::LinearCombination; append_dim, kwargs... | ||
)::AbstractArray{<:Real, N} where N | ||
# A was reduced along the append_dim so we must reshape the result setting that dim to 1 | ||
new_size = collect(size(A)) | ||
setindex!(new_size, 1, dim(A, append_dim)) | ||
return cat(A, reshape(apply(A, LC; kwargs...), new_size...); dims=append_dim) | ||
end | ||
cardinality(::LinearCombination) = ManyToOne() | ||
|
||
function _sum_terms(terms, coeffs) | ||
function _apply(terms, LC::LinearCombination; kwargs...) | ||
# Need this check because map will work even if there are more/less terms than coeffs | ||
if length(terms) != length(coeffs) | ||
if length(terms) != length(LC.coefficients) | ||
throw(DimensionMismatch( | ||
"Number of terms $(length(terms)) does not match "* | ||
"number of coefficients $(length(coeffs))." | ||
"number of coefficients $(length(LC.coefficients))." | ||
)) | ||
end | ||
return sum(map(*, terms, coeffs)) | ||
|
||
return sum(map(*, terms, LC.coefficients)) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters