Use Bool by default for OneHotEncoding Transform #31

Merged 3 commits on Feb 25, 2021
Fixup
nicoleepp committed Feb 25, 2021
commit 1f663951b82e812f43ea7d39278dba76cf504caa
10 changes: 5 additions & 5 deletions src/one_hot_encoding.jl
@@ -13,18 +13,18 @@ of results. It defaults to a `Matrix` of `Bool`s.
 Note that this Transform does not support specifying dims other than `:` (all dims) because
 it is a one-to-many transform (for example a `Vector` input produces a `Matrix` output).
 """
-struct OneHotEncoding{R<:Real, T} <: Transform
-    categories::Dict{T, Int}
+struct OneHotEncoding{R<:Real} <: Transform
+    categories::Dict
 
-    function OneHotEncoding{R}(possible_values::AbstractVector{T}) where {R<:Real, T}
+    function OneHotEncoding{R}(possible_values::AbstractVector) where {R<:Real}
         if length(unique(possible_values)) < length(possible_values)
             throw(ArgumentError("Expected a list of all unique possible values"))
         end
 
         # Create a dictionary that maps unique values in the input array to column positions
         # in the sparse matrix that results from applying the OneHotEncoding transform
         categories = Dict(value => i for (i, value) in enumerate(possible_values))
-        return new{R, T}(categories)
+        return new{R}(categories)
     end
 end
 
@@ -36,7 +36,7 @@ function _apply(x, encoding::OneHotEncoding{R}; kwargs...) where R <: Real
     n_categories = length(encoding.categories)
     results = zeros(R, length(x), n_categories)
 
-    for (i, value) in enumerate(x)
+    @views for (i, value) in enumerate(x)
         col_pos = encoding.categories[value]
         results[i, col_pos] = true
     end
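
For readers who want to try the post-change shape outside the package, here is a minimal standalone sketch in Julia. It mirrors the struct and loop from the diff above, but `SketchOneHot` and `encode` are hypothetical names, the type does not subtype the package's `Transform`, and the `Bool`-defaulting outer constructor is an assumption based on the PR title and docstring rather than code shown in this commit.

```julia
# Standalone sketch; `SketchOneHot` and `encode` are hypothetical names, not the package API.
struct SketchOneHot{R<:Real}
    categories::Dict

    function SketchOneHot{R}(possible_values::AbstractVector) where {R<:Real}
        if !allunique(possible_values)
            throw(ArgumentError("Expected a list of all unique possible values"))
        end
        # Map each possible value to its column position in the output matrix.
        categories = Dict(value => i for (i, value) in enumerate(possible_values))
        return new{R}(categories)
    end
end

# Assumed outer constructor: default the element type to Bool when none is given.
SketchOneHot(possible_values::AbstractVector) = SketchOneHot{Bool}(possible_values)

# One row per input element, one column per category.
function encode(x, encoding::SketchOneHot{R}) where {R<:Real}
    results = zeros(R, length(x), length(encoding.categories))
    for (i, value) in enumerate(x)
        # `true` is converted to one(R) when R is not Bool.
        results[i, encoding.categories[value]] = true
    end
    return results
end

bool_result = encode(["b", "a", "c", "a"], SketchOneHot(["a", "b", "c"]))
float_result = encode(["b", "a"], SketchOneHot{Float64}(["a", "b", "c"]))
```

In this sketch `bool_result` is a 4×3 `Matrix{Bool}` and `float_result` is a 2×3 `Matrix{Float64}`, so the element type is controlled only by `R`; the key type of `categories` is no longer a type parameter, matching the simplification in the diff.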