diff --git a/Project.toml b/Project.toml index 49a5876..6fc26fc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,11 +1,12 @@ name = "AxisSets" uuid = "a1a1544e-ba16-4f6d-8861-e833517b754e" authors = ["Invenia Technical Computing Corporation"] -version = "0.1.6" +version = "0.1.7" [deps] AutoHashEquals = "15f4f7f2-30c1-5605-9d31-71845cf9641f" AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5" +FeatureTransforms = "8fd68953-04b8-4117-ac19-158bf6de9782" Impute = "f7bf1975-0170-51b9-8c5f-a992d46b9575" NamedDims = "356022a1-0364-5f58-8944-0da4b18d706f" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" @@ -13,12 +14,13 @@ ReadOnlyArrays = "988b38a3-91fc-5605-94a2-ee2116b3bd83" [compat] AutoHashEquals = "0.2" -AxisKeys = "0.1" +AxisKeys = "0.1.16" +FeatureTransforms = "0.3.6" Impute = "0.6" NamedDims = "0.2" OrderedCollections = "1" ReadOnlyArrays = "0.1" -julia = "1.3" +julia = "1.5" [extras] Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" diff --git a/docs/Manifest.toml b/docs/Manifest.toml index 9759082..e568d7d 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -8,21 +8,21 @@ version = "0.5.0" [[Adapt]] deps = ["LinearAlgebra"] -git-tree-sha1 = "ffcfa2d345aaee0ef3d8346a073d5dd03c983ebe" +git-tree-sha1 = "f1b523983a58802c4695851926203b36e28f09db" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "3.2.0" +version = "3.3.0" + +[[ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" [[ArrayInterface]] deps = ["IfElse", "LinearAlgebra", "Requires", "SparseArrays", "Static"] -git-tree-sha1 = "e7edcc1ac140cce87b7442ff0fa88b5f19fb71fa" +git-tree-sha1 = "2fbfa5f372352f92191b63976d070dc7195f47a4" uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -version = "3.1.3" +version = "3.1.7" [[Artifacts]] -deps = ["Pkg"] -git-tree-sha1 = "c30985d8821e0cd73870b17b0ed0ce6dc44cb744" uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" -version = "1.3.0" [[AutoHashEquals]] git-tree-sha1 = "45bb6705d93be619b81451bb2006b7ee5d4e4453" @@ -31,20 +31,20 @@ version = "0.2.0" [[AxisKeys]] deps = ["AbstractFFTs", "CovarianceEstimation", "IntervalSets", "InvertedIndices", "LazyStack", "LinearAlgebra", "NamedDims", "OffsetArrays", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "063d295667b562a974c8f6a0c21458a3c89df08d" +git-tree-sha1 = "118c5c2c9f509f503efa05fa2385936bc2cad78d" uuid = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5" -version = "0.1.13" +version = "0.1.16" [[AxisSets]] -deps = ["AutoHashEquals", "AxisKeys", "Impute", "NamedDims", "OrderedCollections", "ReadOnlyArrays"] +deps = ["AutoHashEquals", "AxisKeys", "FeatureTransforms", "Impute", "NamedDims", "OrderedCollections", "ReadOnlyArrays"] path = ".." uuid = "a1a1544e-ba16-4f6d-8861-e833517b754e" -version = "0.1.0" +version = "0.1.7" [[BSON]] -git-tree-sha1 = "db18b5ea04686f73d269e10bdb241947c40d7d6f" +git-tree-sha1 = "92b8a8479128367aaab2620b8e73dff632f5ae69" uuid = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" -version = "0.3.2" +version = "0.3.3" [[Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" @@ -61,17 +61,11 @@ git-tree-sha1 = "6d4242ef4cb1539e7ede8e01a47a32365e0a34cd" uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" version = "0.8.4" -[[CategoricalArrays]] -deps = ["DataAPI", "Future", "JSON", "Missings", "Printf", "Statistics", "StructTypes", "Unicode"] -git-tree-sha1 = "dbfddfafb75fae5356e00529ce67454125935945" -uuid = "324d7699-5711-5eae-9e2f-1d82baa6b597" -version = "0.9.3" - [[Compat]] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] -git-tree-sha1 = "919c7f3151e79ff196add81d7f4e45d91bbf420b" +git-tree-sha1 = "ac4132ad78082518ec2037ae5770b6e796f7f956" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "3.25.0" +version = "3.27.0" [[CovarianceEstimation]] deps = ["LinearAlgebra", "Statistics", "StatsBase"] @@ -96,10 +90,10 @@ uuid = "124859b0-ceae-595e-8997-d05f6a7a8dfe" version = "0.7.7" [[DataFrames]] -deps = ["CategoricalArrays", "Compat", "DataAPI", "Future", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrettyTables", "Printf", "REPL", "Reexport", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] -git-tree-sha1 = "b0db5579803eabb33f1274ca7ca2f472fdfb7f2a" +deps = ["Compat", "DataAPI", "Future", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrettyTables", "Printf", "REPL", "Reexport", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] +git-tree-sha1 = "56ff5833e5b755d2db654479993e949e73606b64" uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" -version = "0.22.5" +version = "1.0.0" [[DataStructures]] deps = ["Compat", "InteractiveUtils", "OrderedCollections"] @@ -132,15 +126,19 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[DocStringExtensions]] deps = ["LibGit2", "Markdown", "Pkg", "Test"] -git-tree-sha1 = "50ddf44c53698f5e784bbebb3f4b21c5807401b1" +git-tree-sha1 = "9d4f64f79012636741cf01133158a54b24924c32" uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.8.3" +version = "0.8.4" [[Documenter]] deps = ["Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"] -git-tree-sha1 = "b7715ae18be02110a8cf9cc8ed2ccdb1e3e3aba2" +git-tree-sha1 = "3ebb967819b284dc1e3c0422229b58a40a255649" uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -version = "0.26.1" +version = "0.26.3" + +[[Downloads]] +deps = ["ArgTools", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" [[EllipsisNotation]] deps = ["ArrayInterface"] @@ -159,6 +157,12 @@ git-tree-sha1 = "0fa3b52a04a4e210aeb1626def9c90df3ae65268" uuid = "8f5d6c58-4d21-5cfd-889c-e3ad7ee6a615" version = "1.1.0" +[[FeatureTransforms]] +deps = ["Dates", "InteractiveUtils", "NamedDims", "Statistics", "Tables"] +git-tree-sha1 = "14aca9d7f91be3968c6d89b0bb1edd0d66d25f39" +uuid = "8fd68953-04b8-4117-ac19-158bf6de9782" +version = "0.3.6" + [[Formatting]] deps = ["Printf"] git-tree-sha1 = "8339d61043228fdd3eb658d86c926cb282ae72a8" @@ -188,9 +192,9 @@ version = "0.1.0" [[Impute]] deps = ["BSON", "CSV", "DataDeps", "Distances", "IterTools", "LinearAlgebra", "Missings", "NamedDims", "NearestNeighbors", "Random", "Statistics", "StatsBase", "TableOperations", "Tables"] -git-tree-sha1 = "0f2132d3e1438d930a05536c388f113991d64022" +git-tree-sha1 = "8ea049aaa69914ca9c2c51fb69990d2778def47a" uuid = "f7bf1975-0170-51b9-8c5f-a992d46b9575" -version = "0.6.3" +version = "0.6.5" [[IniFile]] deps = ["Test"] @@ -204,9 +208,9 @@ uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" [[IntervalSets]] deps = ["Dates", "EllipsisNotation", "Statistics"] -git-tree-sha1 = "93a6d78525feb0d3ee2a2ae83a7d04db1db5663f" +git-tree-sha1 = "3cc368af3f110a767ac786560045dceddfc16758" uuid = "8197267c-284f-5f27-9208-e0e47529a953" -version = "0.5.2" +version = "0.5.3" [[InvertedIndices]] deps = ["Test"] @@ -225,9 +229,10 @@ uuid = "82899510-4779-5014-852e-03e436cf321d" version = "1.0.0" [[JLLWrappers]] -git-tree-sha1 = "a431f5f2ca3f4feef3bd7a5e94b8b8d4f2f647a0" +deps = ["Preferences"] +git-tree-sha1 = "642a199af8b68253517b80bd3bfd17eb4e84df6e" uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.2.0" +version = "1.3.0" [[JSON]] deps = ["Dates", "Mmap", "Parsers", "Unicode"] @@ -241,10 +246,22 @@ git-tree-sha1 = "a8bf67afad3f1ee59d367267adb7c44ccac7fdee" uuid = "1fad7336-0346-5a1a-a56f-a06ba010965b" version = "0.0.7" +[[LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" + +[[LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" + [[LibGit2]] -deps = ["Printf"] +deps = ["Base64", "NetworkOptions", "Printf", "SHA"] uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" +[[LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" + [[Libdl]] uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" @@ -278,10 +295,8 @@ uuid = "739be429-bea8-5141-9913-cc70e7f3736d" version = "1.0.3" [[MbedTLS_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "0eef589dd1c26a3ac9d753fe1a8bcad63f956fa6" +deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.16.8+1" [[Missings]] deps = ["DataAPI"] @@ -298,11 +313,14 @@ git-tree-sha1 = "916b850daad0d46b8c71f65f719c49957e9513ed" uuid = "78c3b35d-d492-501b-9361-3d52fe80e533" version = "0.7.1" +[[MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" + [[NamedDims]] deps = ["AbstractFFTs", "LinearAlgebra", "Pkg", "Requires", "Statistics"] -git-tree-sha1 = "d60a8f176d28ed99ca0f63738fb021bfa0a69ba3" +git-tree-sha1 = "0838a2ee62194d1a4dbf3904dca75cf62374b701" uuid = "356022a1-0364-5f58-8944-0da4b18d706f" -version = "0.2.31" +version = "0.2.32" [[NearestNeighbors]] deps = ["Distances", "StaticArrays"] @@ -311,9 +329,7 @@ uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce" version = "0.4.8" [[NetworkOptions]] -git-tree-sha1 = "ed3157f48a05543cce9b241e1f2815f7e843d96e" uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" -version = "1.2.0" [[OffsetArrays]] deps = ["Adapt"] @@ -328,12 +344,12 @@ version = "1.4.0" [[Parsers]] deps = ["Dates"] -git-tree-sha1 = "50c9a9ed8c714945e01cd53a21007ed3865ed714" +git-tree-sha1 = "c8abc88faa3f7a3950832ac5d6e690881590d6dc" uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "1.0.15" +version = "1.1.0" [[Pkg]] -deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] +deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" [[PooledArrays]] @@ -342,18 +358,24 @@ git-tree-sha1 = "cde4ce9d6f33219465b55162811d8de8139c0414" uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" version = "1.2.1" +[[Preferences]] +deps = ["TOML"] +git-tree-sha1 = "ea79e4c9077208cd3bc5d29631a26bc0cff78902" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.2.1" + [[PrettyTables]] deps = ["Crayons", "Formatting", "Markdown", "Reexport", "Tables"] -git-tree-sha1 = "574a6b3ea95f04e8757c0280bb9c29f1a5e35138" +git-tree-sha1 = "a7162ad93a899333717481f448a235ffafeb5eba" uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" -version = "0.11.1" +version = "1.0.0" [[Printf]] deps = ["Unicode"] uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" [[REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets"] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" [[Random]] @@ -378,9 +400,9 @@ version = "1.0.0" [[Requires]] deps = ["UUIDs"] -git-tree-sha1 = "cfbac6c1ed70c002ec6361e7fd334f02820d6419" +git-tree-sha1 = "4036a3bd08ac7e968e27c203d45f5fff15020621" uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "1.1.2" +version = "1.1.3" [[SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" @@ -413,15 +435,15 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [[Static]] deps = ["IfElse"] -git-tree-sha1 = "1b0fdbbc15c5b13dcf52343ac681a3060ddb8ee4" +git-tree-sha1 = "ddec5466a1d2d7e58adf9a427ba69763661aacf6" uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" -version = "0.2.1" +version = "0.2.4" [[StaticArrays]] deps = ["LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "9da72ed50e94dbff92036da395275ed114e04d49" +git-tree-sha1 = "e8cd1b100d37f5b4cfd2c83f45becf61c762eaf7" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.0.1" +version = "1.1.1" [[Statistics]] deps = ["LinearAlgebra", "SparseArrays"] @@ -429,15 +451,13 @@ uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [[StatsBase]] deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"] -git-tree-sha1 = "400aa43f7de43aeccc5b2e39a76a79d262202b76" +git-tree-sha1 = "4bc58880426274277a066de306ef19ecc22a6863" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.33.3" +version = "0.33.5" -[[StructTypes]] -deps = ["Dates", "UUIDs"] -git-tree-sha1 = "d7f4287dbc1e590265f50ceda1b40ed2bb31bbbb" -uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" -version = "1.4.0" +[[TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" [[TableOperations]] deps = ["SentinelArrays", "Tables", "Test"] @@ -447,18 +467,22 @@ version = "1.0.0" [[TableTraits]] deps = ["IteratorInterfaceExtensions"] -git-tree-sha1 = "b1ad568ba658d8cbb3b892ed5380a6f3e781a81e" +git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" -version = "1.0.0" +version = "1.0.1" [[Tables]] deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "TableTraits", "Test"] -git-tree-sha1 = "a716dde43d57fa537a19058d044b495301ba6565" +git-tree-sha1 = "c9d2d262e9a327be1f35844df25fe4561d258dc9" uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -version = "1.3.2" +version = "1.4.2" + +[[Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" [[Test]] -deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [[TimeZones]] @@ -481,15 +505,13 @@ uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" [[XML2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libiconv_jll", "Pkg", "Zlib_jll"] -git-tree-sha1 = "be0db24f70aae7e2b89f2f3092e93b8606d659a6" +git-tree-sha1 = "afd2b541e8fd425cd3b7aa55932a257035ab4a70" uuid = "02c8fc9c-b97f-50b9-bbe4-9be30ff0a78a" -version = "2.9.10+3" +version = "2.9.11+0" [[Zlib_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "320228915c8debb12cb434c59057290f0834dbf6" +deps = ["Libdl"] uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.11+18" [[ZygoteRules]] deps = ["MacroTools"] @@ -497,8 +519,10 @@ git-tree-sha1 = "9e7a1e8ca60b742e508a315c17eef5211e7fbfd7" uuid = "700de1a5-db45-46bc-99cf-38207098b444" version = "0.2.1" +[[nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" + [[p7zip_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "ee65cfa19bea645698a0224bfa216f2b1c8b559f" +deps = ["Artifacts", "Libdl"] uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" -version = "16.2.0+3" diff --git a/docs/Project.toml b/docs/Project.toml index dbec0b9..f992886 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -4,6 +4,7 @@ AxisSets = "a1a1544e-ba16-4f6d-8861-e833517b754e" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +FeatureTransforms = "8fd68953-04b8-4117-ac19-158bf6de9782" Impute = "f7bf1975-0170-51b9-8c5f-a992d46b9575" NamedDims = "356022a1-0364-5f58-8944-0da4b18d706f" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" diff --git a/src/AxisSets.jl b/src/AxisSets.jl index 8621849..fffe089 100644 --- a/src/AxisSets.jl +++ b/src/AxisSets.jl @@ -2,6 +2,7 @@ module AxisSets using AutoHashEquals using AxisKeys +using FeatureTransforms using Impute using NamedDims using OrderedCollections @@ -88,5 +89,7 @@ include("dataset.jl") include("indexing.jl") include("functions.jl") include("impute.jl") +include("featuretransforms.jl") +include("utils.jl") end diff --git a/src/featuretransforms.jl b/src/featuretransforms.jl new file mode 100644 index 0000000..09fa100 --- /dev/null +++ b/src/featuretransforms.jl @@ -0,0 +1,46 @@ +""" + FeatureTransforms.apply(ds::KeyedDataset, t::Transform, [key]; dims=:, kwargs...) + +Apply the `Transform` to each component of the [`KeyedDataset`](@ref). +Returns a new dataset with the same constraints, but transformed components. + +The transform can be applied to a subselection of components via a [`Pattern`](@ref) `key`. +Otherwise, components are selected by the desired `dims`. + +Keyword arguments including `dims` are passed to the appropriate `FeatureTransforms` method +for a component. + +# Example +```jldoctest +julia> using AxisKeys, FeatureTransforms; using AxisSets: KeyedDataset, Pattern, flatten; + +julia> ds = KeyedDataset( + flatten([ + :train => [ + :load => KeyedArray([7.0 7.7; 8.0 8.2; 9.0 9.9]; time=1:3, loc=[:x, :y]), + :price => KeyedArray([-2.0 4.0; 3.0 2.0; -1.0 -1.0]; time=1:3, id=[:a, :b]), + ], + :predict => [ + :load => KeyedArray([7.0 7.7; 8.1 7.9; 9.0 9.9]; time=1:3, loc=[:x, :y]), + :price => KeyedArray([0.5 -1.0; -5.0 -2.0; 0.0 1.0]; time=1:3, id=[:a, :b]), + ] + ])... + ); + +julia> p = Power(2); + +julia> r = FeatureTransforms.apply(ds, p, (:_, :price, :_)); + +julia> [k => parent(parent(v)) for (k, v) in r.data] +4-element Vector{Pair{Tuple{Symbol, Symbol}, Matrix{Float64}}}: + (:train, :load) => [7.0 7.7; 8.0 8.2; 9.0 9.9] + (:train, :price) => [4.0 16.0; 9.0 4.0; 1.0 1.0] + (:predict, :load) => [7.0 7.7; 8.1 7.9; 9.0 9.9] + (:predict, :price) => [0.25 1.0; 25.0 4.0; 0.0 1.0] +``` +""" +function FeatureTransforms.apply(ds::KeyedDataset, t::Transform, key=Pattern((:__,)); kwargs...) + return map(ds, _pattern(key)) do a + FeatureTransforms.apply(a, t; kwargs...) + end +end diff --git a/src/impute.jl b/src/impute.jl index 9bd548d..04cdc9c 100644 --- a/src/impute.jl +++ b/src/impute.jl @@ -126,10 +126,6 @@ julia> [k => parent(parent(v)) for (k, v) in Impute.filter(ds; dims=:loc).data] """ Impute.apply(ds::KeyedDataset, f::Filter; dims) = Impute.apply!(deepcopy(ds), f; dims=dims) -_pattern(dims::Pattern) = dims -_pattern(dims::Tuple) = Pattern(dims) -_pattern(dims) = Pattern(:__, dims) - function Impute.apply!(ds::KeyedDataset, f::Filter; dims) pattern = _pattern(dims) dim = pattern.segments[end] diff --git a/src/utils.jl b/src/utils.jl new file mode 100644 index 0000000..d0f94ec --- /dev/null +++ b/src/utils.jl @@ -0,0 +1,5 @@ +# Convert a dims argument to a Pattern +_pattern(dims::Pattern) = dims +_pattern(dims::Tuple) = Pattern(dims) +_pattern(::Colon) = Pattern(:__) +_pattern(dims) = Pattern(:__, dims) diff --git a/test/featuretransforms.jl b/test/featuretransforms.jl new file mode 100644 index 0000000..cdb0ae4 --- /dev/null +++ b/test/featuretransforms.jl @@ -0,0 +1,360 @@ +@testset "FeatureTransforms" begin + + M1 = [0.0 1.0; 1.0 2.0; -0.5 0.0] + M2 = [-2.0 4.0; 3.0 2.0; -1.0 -1.0] + M3 = [0.0 1.0; -1.0 0.5; -0.5 0.0] + M4 = [0.5 -1.0; -5.0 -2.0; 0.0 1.0] + + ds = KeyedDataset( + flatten([ + :train => [ + :load => KeyedArray(M1; time=1:3, loc=[:x, :y]), + :price => KeyedArray(M2; time=1:3, id=[:a, :b]), + ], + :predict => [ + :load => KeyedArray(M3; time=1:3, loc=[:x, :y]), + :price => KeyedArray(M4; time=1:3, id=[:a, :b]), + ] + ])... + ) + + @testset "transform" begin + @test is_transformable(ds) + end + + @testset "OneToOne" begin + T = FakeOneToOneTransform() + + @testset "default applies to all components" begin + expected = KeyedDataset( + flatten([ + :train => [ + :load => KeyedArray(ones(3, 2); time=1:3, loc=[:x, :y]), + :price => KeyedArray(ones(3, 2); time=1:3, id=[:a, :b]), + ], + :predict => [ + :load => KeyedArray(ones(3, 2); time=1:3, loc=[:x, :y]), + :price => KeyedArray(ones(3, 2); time=1:3, id=[:a, :b]), + ] + ])... + ) + + r = FeatureTransforms.apply(ds, T) + + @test r isa KeyedDataset + @test isequal(r, expected) + @test !isequal(ds, expected) + end + + @testset "using pattern" begin + expected = KeyedDataset( + flatten([ + :train => [ + :load => KeyedArray(M1; time=1:3, loc=[:x, :y]), + :price => KeyedArray(ones(3, 2); time=1:3, id=[:a, :b]), + ], + :predict => [ + :load => KeyedArray(M3; time=1:3, loc=[:x, :y]), + :price => KeyedArray(ones(3, 2); time=1:3, id=[:a, :b]), + ] + ])... + ) + + r = FeatureTransforms.apply(ds, T, (:_, :price, :_)) + + @test r isa KeyedDataset + @test isequal(r, expected) + @test !isequal(ds, expected) + end + + @testset "using symbol" begin + expected = KeyedDataset( + flatten([ + :train => [ + :load => KeyedArray(ones(3, 2); time=1:3, loc=[:x, :y]), + :price => KeyedArray(M2; time=1:3, id=[:a, :b]), + ], + :predict => [ + :load => KeyedArray(ones(3, 2); time=1:3, loc=[:x, :y]), + :price => KeyedArray(M4; time=1:3, id=[:a, :b]), + ] + ])... + ) + + r = FeatureTransforms.apply(ds, T, :loc) + + @test r isa KeyedDataset + @test isequal(r, expected) + @test !isequal(ds, expected) + end + + @testset "using dims" begin + # replaces the first :loc column with ones(...) + expected = KeyedDataset( + flatten([ + :train => [ + :load => KeyedArray(ones(3, 1); time=1:3, loc=[:x]), + :price => KeyedArray(M2; time=1:3, id=[:a, :b]), + ], + :predict => [ + :load => KeyedArray(ones(3, 1); time=1:3, loc=[:x]), + :price => KeyedArray(M4; time=1:3, id=[:a, :b]), + ] + ])... + ) + + r = FeatureTransforms.apply(ds, T, :loc; dims=2, inds=[1]) + + @test r isa KeyedDataset + @test isequal(r, expected) + @test !isequal(ds, expected) + end + end + + @testset "OneToMany" begin + T = FakeOneToManyTransform() + + @testset "default applies to all components" begin + expected = KeyedDataset( + flatten([ + :train => [ + :load => KeyedArray(ones(3, 4); time=1:3, loc=[:x, :y, :x, :y]), + :price => KeyedArray(ones(3, 4); time=1:3, id=[:a, :b, :a, :b]), + ], + :predict => [ + :load => KeyedArray(ones(3, 4); time=1:3, loc=[:x, :y, :x, :y]), + :price => KeyedArray(ones(3, 4); time=1:3, id=[:a, :b, :a, :b]), + ] + ])... + ) + + r = FeatureTransforms.apply(ds, T) + + @test r isa KeyedDataset + @test isequal(r, expected) + @test !isequal(ds, expected) + end + + @testset "using pattern" begin + expected = KeyedDataset( + flatten([ + :train => [ + :load => KeyedArray(M1; time=1:3, loc=[:x, :y]), + :price => KeyedArray(ones(3, 4); time=1:3, id=[:a, :b, :a, :b]), + ], + :predict => [ + :load => KeyedArray(M3; time=1:3, loc=[:x, :y]), + :price => KeyedArray(ones(3, 4); time=1:3, id=[:a, :b, :a, :b]), + ] + ])... + ) + + r = FeatureTransforms.apply(ds, T, (:_, :price, :_)) + + @test r isa KeyedDataset + @test isequal(r, expected) + @test !isequal(ds, expected) + end + + @testset "using symbol" begin + expected = KeyedDataset( + flatten([ + :train => [ + :load => KeyedArray(ones(3, 4); time=1:3, loc=[:x, :y, :x, :y]), + :price => KeyedArray(M2; time=1:3, id=[:a, :b]), + ], + :predict => [ + :load => KeyedArray(ones(3, 4); time=1:3, loc=[:x, :y, :x, :y]), + :price => KeyedArray(M4; time=1:3, id=[:a, :b]), + ] + ])... + ) + + r = FeatureTransforms.apply(ds, T, :loc) + + @test r isa KeyedDataset + @test isequal(r, expected) + @test !isequal(ds, expected) + end + + @testset "using dims" begin + expected = KeyedDataset( + flatten([ + :train => [ + :load => KeyedArray(ones(3, 2); time=1:3, loc=[:x, :x]), + :price => KeyedArray(M2; time=1:3, id=[:a, :b]), + ], + :predict => [ + :load => KeyedArray(ones(3, 2); time=1:3, loc=[:x, :x]), + :price => KeyedArray(M4; time=1:3, id=[:a, :b]), + ] + ])... + ) + + r = FeatureTransforms.apply(ds, T, :loc; dims=2, inds=[1]) + + @test r isa KeyedDataset + @test isequal(r, expected) + @test !isequal(ds, expected) + end + end + + @testset "ManyToOne" begin + T = FakeManyToOneTransform() + + @testset "default applies to all components" begin + expected = KeyedDataset( + # ideally we would drop the :time constraint when it gets reduced + OrderedSet(Pattern[(:__, :time), (:__, :loc), (:__, :id)]), + LittleDict(flatten([ + :train => [ + :load => KeyedArray(ones(2); loc=[:x, :y]), + :price => KeyedArray(ones(2); id=[:a, :b]), + ], + :predict => [ + :load => KeyedArray(ones(2); loc=[:x, :y]), + :price => KeyedArray(ones(2); id=[:a, :b]), + ] + ])...) + ) + + r = FeatureTransforms.apply(ds, T; dims=:time) + + @test r isa KeyedDataset + @test isequal(r, expected) + @test !isequal(ds, expected) + end + + @testset "using pattern" begin + expected = KeyedDataset( + flatten([ + :train => [ + :load => KeyedArray(M1; time=1:3, loc=[:x, :y]), + :price => KeyedArray(ones(2); id=[:a, :b]), + ], + :predict => [ + :load => KeyedArray(M3; time=1:3, loc=[:x, :y]), + :price => KeyedArray(ones(2); id=[:a, :b]), + ] + ])... + ) + + r = FeatureTransforms.apply(ds, T, (:_, :price, :_); dims=:time) + + @test r isa KeyedDataset + @test isequal(r, expected) + @test !isequal(ds, expected) + end + + @testset "using symbol" begin + expected = KeyedDataset( + flatten([ + :train => [ + :load => KeyedArray(ones(2); loc=[:x, :y]), + :price => KeyedArray(M2; time=1:3, id=[:a, :b]), + ], + :predict => [ + :load => KeyedArray(ones(2); loc=[:x, :y]), + :price => KeyedArray(M4; time=1:3, id=[:a, :b]), + ] + ])... + ) + + r = FeatureTransforms.apply(ds, T, :loc; dims=:time) + + @test r isa KeyedDataset + @test isequal(r, expected) + @test !isequal(ds, expected) + end + end + + # Note: There are no ManyToMany transforms implemented just yet + @testset "ManyToMany" begin + T = FakeManyToManyTransform() + + @testset "default applies to all components" begin + expected = KeyedDataset( + flatten([ + :train => [ + :load => KeyedArray(ones(3, 4); time=1:3, loc=[:x, :y, :x, :y]), + :price => KeyedArray(ones(3, 4); time=1:3, id=[:a, :b, :a, :b]), + ], + :predict => [ + :load => KeyedArray(ones(3, 4); time=1:3, loc=[:x, :y, :x, :y]), + :price => KeyedArray(ones(3, 4); time=1:3, id=[:a, :b, :a, :b]), + ] + ])... + ) + + r = FeatureTransforms.apply(ds, T) + + @test r isa KeyedDataset + @test isequal(r, expected) + @test !isequal(ds, expected) + end + + @testset "using pattern" begin + expected = KeyedDataset( + flatten([ + :train => [ + :load => KeyedArray(M1; time=1:3, loc=[:x, :y]), + :price => KeyedArray(ones(3, 4); time=1:3, id=[:a, :b, :a, :b]), + ], + :predict => [ + :load => KeyedArray(M3; time=1:3, loc=[:x, :y]), + :price => KeyedArray(ones(3, 4); time=1:3, id=[:a, :b, :a, :b]), + ] + ])... + ) + + r = FeatureTransforms.apply(ds, T, (:_, :price, :_)) + + @test r isa KeyedDataset + @test isequal(r, expected) + @test !isequal(ds, expected) + end + + @testset "using symbol" begin + expected = KeyedDataset( + flatten([ + :train => [ + :load => KeyedArray(ones(3, 4); time=1:3, loc=[:x, :y, :x, :y]), + :price => KeyedArray(M2; time=1:3, id=[:a, :b]), + ], + :predict => [ + :load => KeyedArray(ones(3, 4); time=1:3, loc=[:x, :y, :x, :y]), + :price => KeyedArray(M4; time=1:3, id=[:a, :b]), + ] + ])... + ) + + r = FeatureTransforms.apply(ds, T, :loc) + + @test r isa KeyedDataset + @test isequal(r, expected) + @test !isequal(ds, expected) + end + + @testset "using dims" begin + expected = KeyedDataset( + flatten([ + :train => [ + :load => KeyedArray(ones(3, 2); time=1:3, loc=[:x, :x]), + :price => KeyedArray(M2; time=1:3, id=[:a, :b]), + ], + :predict => [ + :load => KeyedArray(ones(3, 2); time=1:3, loc=[:x, :x]), + :price => KeyedArray(M4; time=1:3, id=[:a, :b]), + ] + ])... + ) + + r = FeatureTransforms.apply(ds, T, :loc; dims=2, inds=[1]) + + @test r isa KeyedDataset + @test isequal(r, expected) + @test !isequal(ds, expected) + end + end + +end diff --git a/test/runtests.jl b/test/runtests.jl index fe3c839..7933500 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,6 +2,8 @@ using AxisKeys using AxisSets using Dates using Documenter +using FeatureTransforms +using FeatureTransforms.TestUtils using Impute using Missings using OrderedCollections @@ -27,6 +29,7 @@ using Impute: ThresholdError include("indexing.jl") include("functions.jl") include("impute.jl") + include("featuretransforms.jl") # The doctests fail on x86, so only run them on 64-bit hardware & Julia 1.6 Sys.WORD_SIZE == 64 && v"1.6" <= VERSION < v"1.7" && doctest(AxisSets)