From 7abaeeb5614e8206ad1d3e2e4061d1af10733f21 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Sun, 25 Jun 2023 21:49:36 -0400 Subject: [PATCH 01/15] chainrulescore ext --- Project.toml | 11 +++++++ ext/FFTWChainRulesCoreExt.jl | 60 ++++++++++++++++++++++++++++++++++++ src/FFTW.jl | 10 ++++++ 3 files changed, 81 insertions(+) create mode 100644 ext/FFTWChainRulesCoreExt.jl diff --git a/Project.toml b/Project.toml index 79902d9..a5ee242 100644 --- a/Project.toml +++ b/Project.toml @@ -9,11 +9,22 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +Requires = "ae029012-a4dd-5104-9daa-d747884805df" + +[extensions] +FFTWChainRulesCoreExt = "ChainRulesCore" [compat] AbstractFFTs = "1.0" +ChainRulesCore = "1" FFTW_jll = "3.3.9" MKL_jll = "2019.0.117, 2020, 2021, 2022, 2023" Preferences = "1.2" Reexport = "0.2, 1.0" julia = "1.6" + +[extras] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + +[weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/ext/FFTWChainRulesCoreExt.jl b/ext/FFTWChainRulesCoreExt.jl new file mode 100644 index 0000000..64e8fe2 --- /dev/null +++ b/ext/FFTWChainRulesCoreExt.jl @@ -0,0 +1,60 @@ +module FFTWChainRulesCoreExt + +using FFTW +using ChainRulesCore + +# DCT + +function ChainRulesCore.frule(Δ, ::typeof(dct), x::AbstractArray, args...) + Δx = Δ[2] + y = dct(x, args...) + Δy = dct(Δx, args...) + return y, Δy +end + +function ChainRulesCore.rrule(::typeof(dct), x::AbstractArray, args...) + y = dct(x, args...) + project_x = ProjectTo(x) + function dct_pullback(ȳ) + f̄ = NoTangent() + x̄ = project_x(idct(unthunk(ȳ), args...)) + ā = NoTangent() + + if isempty(args) + return f̄, x̄ + else + return f̄, x̄, ā + end + end + + return y, dct_pullback +end + +# IDCT + +function ChainRulesCore.frule(Δ, ::typeof(idct), x::AbstractArray, args...) + Δx = Δ[2] + y = idct(x, args...) + Δy = idct(Δx, args...) + return y, Δy +end + +function ChainRulesCore.rrule(::typeof(idct), x::AbstractArray, args...) + y = idct(x, args...) + project_x = ChainRulesCore.ProjectTo(x) + function idct_pullback(ȳ) + f̄ = NoTangent() + x̄ = project_x(dct(unthunk(ȳ), args...)) + ā = NoTangent() + + if isempty(args) + return f̄, x̄ + else + return f̄, x̄, ā + end + end + + return y, idct_pullback +end + +end # module diff --git a/src/FFTW.jl b/src/FFTW.jl index 4366ee7..ba4b625 100644 --- a/src/FFTW.jl +++ b/src/FFTW.jl @@ -16,6 +16,10 @@ export dct, idct, dct!, idct!, plan_dct, plan_idct, plan_dct!, plan_idct! include("providers.jl") +@static if !isdefined(Base, :get_extension) + import Requires +end + function __init__() # If someone is trying to set the provider via the old environment variable, warn them that they # should instead use `set_provider!()` instead. @@ -35,6 +39,12 @@ function __init__() libfftw3[] = MKL_jll.libmkl_rt_path libfftw3f[] = MKL_jll.libmkl_rt_path end + + @static if !isdefined(Base, :get_extension) + Requires.@require ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" begin + include("../ext/FFTWChainRulesCoreExt.jl") + end + end end # most FFTW calls other than fftw_execute should be protected by a lock to be thread-safe From 1a9cd68d3a5be06ff3e4a3a0bbc7624319f49d34 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Sun, 25 Jun 2023 21:50:04 -0400 Subject: [PATCH 02/15] add dct chainrules tests --- test/Project.toml | 5 +---- test/runtests.jl | 30 ++++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index c46e7ba..06d99a7 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,8 +1,5 @@ -# A bug in Julia 1.6.0's Pkg causes Preferences to be dropped during `Pkg.test()`, so we work around -# it by explicitly creating a `test/Project.toml` which will correctly communicate any preferences -# through to the child Julia process. X-ref: https://github.com/JuliaLang/Pkg.jl/issues/2500 - [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/runtests.jl b/test/runtests.jl index 301194d..08208da 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -577,3 +577,33 @@ end end end end + +@testset "ChainRules" begin + + if isdefined(Base, :get_extension) + CRCEXT = Base.get_extension(FFTW, :FFTWChainRulesCoreExt) + @test isnothing(CRCEXT) + end + + using ChainRulesTestUtils + + if isdefined(Base, :get_extension) + CRCEXT = Base.get_extension(FFTW, :FFTWChainRulesCoreExt) + @test !isnothing(CRCEXT) + end + + @testset "DCT" begin + for f in (dct, idct) + for x in (randn(3), randn(3, 4), randn(3, 4, 5)) + test_frule(f, x) + test_rrule(f, x) + + N = ndims(x) + for dims in unique((1, 1:N, N)) + test_frule(f, x, dims) + test_rrule(f, x, dims) + end + end + end + end +end From d87c5503dcde184a51de1f250515ec7cc947dd63 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Mon, 26 Jun 2023 10:00:57 -0400 Subject: [PATCH 03/15] add CRC to test dependencies --- ext/FFTWChainRulesCoreExt.jl | 2 ++ test/Project.toml | 1 + 2 files changed, 3 insertions(+) diff --git a/ext/FFTWChainRulesCoreExt.jl b/ext/FFTWChainRulesCoreExt.jl index 64e8fe2..26ab7f7 100644 --- a/ext/FFTWChainRulesCoreExt.jl +++ b/ext/FFTWChainRulesCoreExt.jl @@ -15,6 +15,7 @@ end function ChainRulesCore.rrule(::typeof(dct), x::AbstractArray, args...) y = dct(x, args...) project_x = ProjectTo(x) + function dct_pullback(ȳ) f̄ = NoTangent() x̄ = project_x(idct(unthunk(ȳ), args...)) @@ -42,6 +43,7 @@ end function ChainRulesCore.rrule(::typeof(idct), x::AbstractArray, args...) y = idct(x, args...) project_x = ChainRulesCore.ProjectTo(x) + function idct_pullback(ȳ) f̄ = NoTangent() x̄ = project_x(dct(unthunk(ȳ), args...)) diff --git a/test/Project.toml b/test/Project.toml index 06d99a7..a477908 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" From 1f321f809ce0dbc63f22e9c98f87c76370bc8935 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Mon, 26 Jun 2023 19:12:21 -0400 Subject: [PATCH 04/15] rules for r2r WIP --- ext/FFTWChainRulesCoreExt.jl | 33 ++++++++++++++++++++++++++++++++- test/runtests.jl | 24 ++++++++++++++++++++---- 2 files changed, 52 insertions(+), 5 deletions(-) diff --git a/ext/FFTWChainRulesCoreExt.jl b/ext/FFTWChainRulesCoreExt.jl index 26ab7f7..e4be290 100644 --- a/ext/FFTWChainRulesCoreExt.jl +++ b/ext/FFTWChainRulesCoreExt.jl @@ -1,6 +1,7 @@ module FFTWChainRulesCoreExt using FFTW +using FFTW: r2r using ChainRulesCore # DCT @@ -42,7 +43,7 @@ end function ChainRulesCore.rrule(::typeof(idct), x::AbstractArray, args...) y = idct(x, args...) - project_x = ChainRulesCore.ProjectTo(x) + project_x = ProjectTo(x) function idct_pullback(ȳ) f̄ = NoTangent() @@ -59,4 +60,34 @@ function ChainRulesCore.rrule(::typeof(idct), x::AbstractArray, args...) return y, idct_pullback end +# R2R + +function ChainRulesCore.frule(Δ, ::typeof(r2r), x::AbstractArray, args...) + Δx = Δ[2] + y = r2r(x, args...) + Δy = r2r(Δx, args...) + return y, Δy +end + +function ChainRulesCore.rrule(::typeof(r2r), x::AbstractArray, kinds, args...) + y = r2r(x, kinds, args...) + kinvs = Tuple(FFTW.inv_kind[k] for k in kinds) + project_x = ProjectTo(x) + + function r2r_pullback(ȳ) + f̄ = NoTangent() + x̄ = project_x(r2r(unthunk(ȳ), kinvs, args...)) + k̄ = NoTangent() + ā = NoTangent() + + if isempty(args) + return f̄, x̄, k̄ + else + return f̄, x̄, k̄, ā + end + end + + return y, r2r_pullback +end + end # module diff --git a/test/runtests.jl b/test/runtests.jl index 08208da..8b8d4e7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,6 @@ # This file was formerly a part of Julia. License is MIT: https://julialang.org/license using FFTW -using FFTW: fftw_provider +using FFTW: fftw_provider, r2r using AbstractFFTs: Plan, plan_inv using Test using LinearAlgebra @@ -602,8 +602,24 @@ end for dims in unique((1, 1:N, N)) test_frule(f, x, dims) test_rrule(f, x, dims) - end - end - end + end # for dims + end # for x + end # for f + end + + @testset "r2r" begin + for k in 0:10 + for x in (randn(3), randn(3, 4), randn(3, 4, 5)) + test_frule(r2r, x, k) + test_rrule(r2r, x, k) + + N = ndims(x) + for dims in unique((1, 1:N, N)) + test_frule(r2r, x, k, dims) + test_rrule(r2r, x, k, dims) + end # for dims + end # for x + end # for f end + end From 8da249560e266ba903f66a48efa0b008f8c104d1 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Fri, 7 Jul 2023 20:18:48 -0400 Subject: [PATCH 05/15] rename args -> region --- ext/FFTWChainRulesCoreExt.jl | 60 +++++++++++++++++++----------------- 1 file changed, 32 insertions(+), 28 deletions(-) diff --git a/ext/FFTWChainRulesCoreExt.jl b/ext/FFTWChainRulesCoreExt.jl index e4be290..b81bb30 100644 --- a/ext/FFTWChainRulesCoreExt.jl +++ b/ext/FFTWChainRulesCoreExt.jl @@ -6,26 +6,26 @@ using ChainRulesCore # DCT -function ChainRulesCore.frule(Δ, ::typeof(dct), x::AbstractArray, args...) +function ChainRulesCore.frule(Δ, ::typeof(dct), x::AbstractArray, region...) Δx = Δ[2] - y = dct(x, args...) - Δy = dct(Δx, args...) + y = dct(x, region...) + Δy = dct(Δx, region...) return y, Δy end -function ChainRulesCore.rrule(::typeof(dct), x::AbstractArray, args...) - y = dct(x, args...) +function ChainRulesCore.rrule(::typeof(dct), x::AbstractArray, region...) + y = dct(x, region...) project_x = ProjectTo(x) function dct_pullback(ȳ) f̄ = NoTangent() - x̄ = project_x(idct(unthunk(ȳ), args...)) - ā = NoTangent() + x̄ = project_x(idct(unthunk(ȳ), region...)) + r̄ = NoTangent() - if isempty(args) + if isempty(region) return f̄, x̄ else - return f̄, x̄, ā + return f̄, x̄, r̄ end end @@ -34,26 +34,26 @@ end # IDCT -function ChainRulesCore.frule(Δ, ::typeof(idct), x::AbstractArray, args...) +function ChainRulesCore.frule(Δ, ::typeof(idct), x::AbstractArray, region...) Δx = Δ[2] - y = idct(x, args...) - Δy = idct(Δx, args...) + y = idct(x, region...) + Δy = idct(Δx, region...) return y, Δy end -function ChainRulesCore.rrule(::typeof(idct), x::AbstractArray, args...) - y = idct(x, args...) +function ChainRulesCore.rrule(::typeof(idct), x::AbstractArray, region...) + y = idct(x, region...) project_x = ProjectTo(x) function idct_pullback(ȳ) f̄ = NoTangent() - x̄ = project_x(dct(unthunk(ȳ), args...)) - ā = NoTangent() + x̄ = project_x(dct(unthunk(ȳ), region...)) + r̄ = NoTangent() - if isempty(args) + if isempty(region) return f̄, x̄ else - return f̄, x̄, ā + return f̄, x̄, r̄ end end @@ -62,28 +62,32 @@ end # R2R -function ChainRulesCore.frule(Δ, ::typeof(r2r), x::AbstractArray, args...) +function ChainRulesCore.frule(Δ, ::typeof(r2r), x::AbstractArray, region...) Δx = Δ[2] - y = r2r(x, args...) - Δy = r2r(Δx, args...) + y = r2r(x, region...) + Δy = r2r(Δx, region...) return y, Δy end -function ChainRulesCore.rrule(::typeof(r2r), x::AbstractArray, kinds, args...) - y = r2r(x, kinds, args...) - kinvs = Tuple(FFTW.inv_kind[k] for k in kinds) +function ChainRulesCore.rrule(::typeof(r2r), x::AbstractArray, kinds, region...) + y = r2r(x, kinds, region...) + kinvs = if kinds isa Integer + FFTW.inv_kind[kinds] + else + Tuple(FFTW.inv_kind[k] for k in kinds) + end project_x = ProjectTo(x) function r2r_pullback(ȳ) f̄ = NoTangent() - x̄ = project_x(r2r(unthunk(ȳ), kinvs, args...)) + x̄ = project_x(r2r(unthunk(ȳ), kinvs, region...)) k̄ = NoTangent() - ā = NoTangent() + r̄ = NoTangent() - if isempty(args) + if isempty(region) return f̄, x̄, k̄ else - return f̄, x̄, k̄, ā + return f̄, x̄, k̄, r̄ end end From 5a975ab5935eb0b45f3f266fc6034c4b6f0eaf4f Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Fri, 7 Jul 2023 20:20:06 -0400 Subject: [PATCH 06/15] rename dims -> region --- test/runtests.jl | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 8b8d4e7..76d91b6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -597,27 +597,27 @@ end for x in (randn(3), randn(3, 4), randn(3, 4, 5)) test_frule(f, x) test_rrule(f, x) - + N = ndims(x) - for dims in unique((1, 1:N, N)) - test_frule(f, x, dims) - test_rrule(f, x, dims) - end # for dims + for region in unique((1, 1:N, N)) + test_frule(f, x, region) + test_rrule(f, x, region) + end # for region end # for x end # for f end @testset "r2r" begin - for k in 0:10 - for x in (randn(3), randn(3, 4), randn(3, 4, 5)) + for k in 4 #0:10 + for x in (randn(3), )#randn(3, 4), randn(3, 4, 5)) test_frule(r2r, x, k) - test_rrule(r2r, x, k) - + # test_rrule(r2r, x, k) + N = ndims(x) - for dims in unique((1, 1:N, N)) - test_frule(r2r, x, k, dims) - test_rrule(r2r, x, k, dims) - end # for dims + for region in unique((1, 1:N, N)) + test_frule(r2r, x, k, region) + # test_rrule(r2r, x, k, region) + end # for region end # for x end # for f end From 39c7b6d00654af264cbdd1a83bfe6cc83faca3ed Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Fri, 7 Jul 2023 20:21:03 -0400 Subject: [PATCH 07/15] rm r2r rrule. needs more thinking --- ext/FFTWChainRulesCoreExt.jl | 25 ------------------------- test/runtests.jl | 2 -- 2 files changed, 27 deletions(-) diff --git a/ext/FFTWChainRulesCoreExt.jl b/ext/FFTWChainRulesCoreExt.jl index b81bb30..b940065 100644 --- a/ext/FFTWChainRulesCoreExt.jl +++ b/ext/FFTWChainRulesCoreExt.jl @@ -69,29 +69,4 @@ function ChainRulesCore.frule(Δ, ::typeof(r2r), x::AbstractArray, region...) return y, Δy end -function ChainRulesCore.rrule(::typeof(r2r), x::AbstractArray, kinds, region...) - y = r2r(x, kinds, region...) - kinvs = if kinds isa Integer - FFTW.inv_kind[kinds] - else - Tuple(FFTW.inv_kind[k] for k in kinds) - end - project_x = ProjectTo(x) - - function r2r_pullback(ȳ) - f̄ = NoTangent() - x̄ = project_x(r2r(unthunk(ȳ), kinvs, region...)) - k̄ = NoTangent() - r̄ = NoTangent() - - if isempty(region) - return f̄, x̄, k̄ - else - return f̄, x̄, k̄, r̄ - end - end - - return y, r2r_pullback -end - end # module diff --git a/test/runtests.jl b/test/runtests.jl index 76d91b6..027c7a4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -611,12 +611,10 @@ end for k in 4 #0:10 for x in (randn(3), )#randn(3, 4), randn(3, 4, 5)) test_frule(r2r, x, k) - # test_rrule(r2r, x, k) N = ndims(x) for region in unique((1, 1:N, N)) test_frule(r2r, x, k, region) - # test_rrule(r2r, x, k, region) end # for region end # for x end # for f From 10ac31c162cb43b20284eb094be3bfb684040032 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Sat, 8 Jul 2023 16:03:26 -0400 Subject: [PATCH 08/15] Update Project.toml Co-authored-by: David Widmann --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index a5ee242..008bccc 100644 --- a/Project.toml +++ b/Project.toml @@ -9,7 +9,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" -Requires = "ae029012-a4dd-5104-9daa-d747884805df" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" [extensions] FFTWChainRulesCoreExt = "ChainRulesCore" From 3ee37e75a8c4acbe42d6b78d9cc0b95630637091 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Sat, 8 Jul 2023 16:04:05 -0400 Subject: [PATCH 09/15] Update src/FFTW.jl Co-authored-by: David Widmann --- src/FFTW.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/FFTW.jl b/src/FFTW.jl index ba4b625..e205654 100644 --- a/src/FFTW.jl +++ b/src/FFTW.jl @@ -16,10 +16,6 @@ export dct, idct, dct!, idct!, plan_dct, plan_idct, plan_dct!, plan_idct! include("providers.jl") -@static if !isdefined(Base, :get_extension) - import Requires -end - function __init__() # If someone is trying to set the provider via the old environment variable, warn them that they # should instead use `set_provider!()` instead. From 74b8c75b606ca8f53d2f8f71258e0e1ac7140307 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Sat, 8 Jul 2023 16:09:13 -0400 Subject: [PATCH 10/15] rm requires --- src/FFTW.jl | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/FFTW.jl b/src/FFTW.jl index e205654..5abbba0 100644 --- a/src/FFTW.jl +++ b/src/FFTW.jl @@ -4,6 +4,10 @@ using LinearAlgebra, Reexport, Preferences @reexport using AbstractFFTs using Base.Threads +@static if !isdefined(Base, :get_extension) + include("../ext/FFTWChainRulesCoreExt.jl") +end + import AbstractFFTs: Plan, ScaledPlan, fft, ifft, bfft, fft!, ifft!, bfft!, plan_fft, plan_ifft, plan_bfft, plan_fft!, plan_ifft!, plan_bfft!, @@ -35,12 +39,6 @@ function __init__() libfftw3[] = MKL_jll.libmkl_rt_path libfftw3f[] = MKL_jll.libmkl_rt_path end - - @static if !isdefined(Base, :get_extension) - Requires.@require ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" begin - include("../ext/FFTWChainRulesCoreExt.jl") - end - end end # most FFTW calls other than fftw_execute should be protected by a lock to be thread-safe From a31b52dda91f78853557d64da5e113aef5d99299 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Sat, 8 Jul 2023 16:09:36 -0400 Subject: [PATCH 11/15] uncomment tests --- test/runtests.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 027c7a4..8acf182 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -608,8 +608,8 @@ end end @testset "r2r" begin - for k in 4 #0:10 - for x in (randn(3), )#randn(3, 4), randn(3, 4, 5)) + for k in 0:10 + for x in (randn(3), randn(3, 4), randn(3, 4, 5)) test_frule(r2r, x, k) N = ndims(x) From 3d6e160e377d3fec706f941b5fa97baacb45c893 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Sat, 8 Jul 2023 16:14:45 -0400 Subject: [PATCH 12/15] mv include(CRCext) to bottom to fix precomile failure in <1.9 --- src/FFTW.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/FFTW.jl b/src/FFTW.jl index 5abbba0..82d6bde 100644 --- a/src/FFTW.jl +++ b/src/FFTW.jl @@ -4,10 +4,6 @@ using LinearAlgebra, Reexport, Preferences @reexport using AbstractFFTs using Base.Threads -@static if !isdefined(Base, :get_extension) - include("../ext/FFTWChainRulesCoreExt.jl") -end - import AbstractFFTs: Plan, ScaledPlan, fft, ifft, bfft, fft!, ifft!, bfft!, plan_fft, plan_ifft, plan_bfft, plan_fft!, plan_ifft!, plan_bfft!, @@ -76,4 +72,8 @@ include("dct.jl") include("precompile.jl") _precompile_() +@static if !isdefined(Base, :get_extension) + include("../ext/FFTWChainRulesCoreExt.jl") +end + end # module From d0ef83839929be50cc9b8dd0faa4872489b44351 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Sat, 8 Jul 2023 16:39:03 -0400 Subject: [PATCH 13/15] default region to 1:ndims(x) in frules --- ext/FFTWChainRulesCoreExt.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/ext/FFTWChainRulesCoreExt.jl b/ext/FFTWChainRulesCoreExt.jl index b940065..01dff37 100644 --- a/ext/FFTWChainRulesCoreExt.jl +++ b/ext/FFTWChainRulesCoreExt.jl @@ -6,10 +6,10 @@ using ChainRulesCore # DCT -function ChainRulesCore.frule(Δ, ::typeof(dct), x::AbstractArray, region...) +function ChainRulesCore.frule(Δ, ::typeof(dct), x::AbstractArray, region = 1:ndims(x)) Δx = Δ[2] - y = dct(x, region...) - Δy = dct(Δx, region...) + y = dct(x, region) + Δy = dct(Δx, region) return y, Δy end @@ -34,10 +34,10 @@ end # IDCT -function ChainRulesCore.frule(Δ, ::typeof(idct), x::AbstractArray, region...) +function ChainRulesCore.frule(Δ, ::typeof(idct), x::AbstractArray, region = 1:ndims(x)) Δx = Δ[2] - y = idct(x, region...) - Δy = idct(Δx, region...) + y = idct(x, region) + Δy = idct(Δx, region) return y, Δy end @@ -62,10 +62,10 @@ end # R2R -function ChainRulesCore.frule(Δ, ::typeof(r2r), x::AbstractArray, region...) +function ChainRulesCore.frule(Δ, ::typeof(r2r), x::AbstractArray, kind, region = 1:ndims(x)) Δx = Δ[2] - y = r2r(x, region...) - Δy = r2r(Δx, region...) + y = r2r(x, kind, region) + Δy = r2r(Δx, kind, region) return y, Δy end From 6b8d178afd368f1b5e5334b76e8b9b617fc3e6d6 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Sat, 8 Jul 2023 17:38:30 -0400 Subject: [PATCH 14/15] different rrules for dct(x), dct(x, region) --- ext/FFTWChainRulesCoreExt.jl | 50 ++++++++++++++---------------------- test/runtests.jl | 2 +- 2 files changed, 20 insertions(+), 32 deletions(-) diff --git a/ext/FFTWChainRulesCoreExt.jl b/ext/FFTWChainRulesCoreExt.jl index 01dff37..c33f38a 100644 --- a/ext/FFTWChainRulesCoreExt.jl +++ b/ext/FFTWChainRulesCoreExt.jl @@ -4,7 +4,7 @@ using FFTW using FFTW: r2r using ChainRulesCore -# DCT +# DCT/IDCT function ChainRulesCore.frule(Δ, ::typeof(dct), x::AbstractArray, region = 1:ndims(x)) Δx = Δ[2] @@ -13,23 +13,17 @@ function ChainRulesCore.frule(Δ, ::typeof(dct), x::AbstractArray, region = 1:nd return y, Δy end -function ChainRulesCore.rrule(::typeof(dct), x::AbstractArray, region...) - y = dct(x, region...) +function ChainRulesCore.rrule(::typeof(dct), x::AbstractArray) project_x = ProjectTo(x) + region = 1:ndims(x) + dct_pb(Δ) = NoTangent(), project_x(idct(unthunk(Δ), region)) + return dct(x, region), dct_pb +end - function dct_pullback(ȳ) - f̄ = NoTangent() - x̄ = project_x(idct(unthunk(ȳ), region...)) - r̄ = NoTangent() - - if isempty(region) - return f̄, x̄ - else - return f̄, x̄, r̄ - end - end - - return y, dct_pullback +function ChainRulesCore.rrule(::typeof(dct), x::AbstractArray, region) + project_x = ProjectTo(x) + dct_pb(Δ) = NoTangent(), project_x(idct(unthunk(Δ), region)), NoTangent() + return dct(x, region), dct_pb end # IDCT @@ -41,23 +35,17 @@ function ChainRulesCore.frule(Δ, ::typeof(idct), x::AbstractArray, region = 1:n return y, Δy end -function ChainRulesCore.rrule(::typeof(idct), x::AbstractArray, region...) - y = idct(x, region...) +function ChainRulesCore.rrule(::typeof(idct), x::AbstractArray) project_x = ProjectTo(x) + region = 1:ndims(x) + dct_pb(Δ) = NoTangent(), project_x(dct(unthunk(Δ), region)) + return idct(x, region), dct_pb +end - function idct_pullback(ȳ) - f̄ = NoTangent() - x̄ = project_x(dct(unthunk(ȳ), region...)) - r̄ = NoTangent() - - if isempty(region) - return f̄, x̄ - else - return f̄, x̄, r̄ - end - end - - return y, idct_pullback +function ChainRulesCore.rrule(::typeof(idct), x::AbstractArray, region) + project_x = ProjectTo(x) + dct_pb(Δ) = NoTangent(), project_x(dct(unthunk(Δ), region)), NoTangent() + return idct(x, region), dct_pb end # R2R diff --git a/test/runtests.jl b/test/runtests.jl index 8acf182..90838d1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -606,7 +606,7 @@ end end # for x end # for f end - + @testset "r2r" begin for k in 0:10 for x in (randn(3), randn(3, 4), randn(3, 4, 5)) From 608aa023b4e43c9d58a57dd387f44ad82d581e36 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Sat, 8 Jul 2023 17:50:22 -0400 Subject: [PATCH 15/15] put DCT/IDCT rules in a loop --- ext/FFTWChainRulesCoreExt.jl | 62 +++++++++++++----------------------- 1 file changed, 22 insertions(+), 40 deletions(-) diff --git a/ext/FFTWChainRulesCoreExt.jl b/ext/FFTWChainRulesCoreExt.jl index c33f38a..789e4e9 100644 --- a/ext/FFTWChainRulesCoreExt.jl +++ b/ext/FFTWChainRulesCoreExt.jl @@ -6,46 +6,28 @@ using ChainRulesCore # DCT/IDCT -function ChainRulesCore.frule(Δ, ::typeof(dct), x::AbstractArray, region = 1:ndims(x)) - Δx = Δ[2] - y = dct(x, region) - Δy = dct(Δx, region) - return y, Δy -end - -function ChainRulesCore.rrule(::typeof(dct), x::AbstractArray) - project_x = ProjectTo(x) - region = 1:ndims(x) - dct_pb(Δ) = NoTangent(), project_x(idct(unthunk(Δ), region)) - return dct(x, region), dct_pb -end - -function ChainRulesCore.rrule(::typeof(dct), x::AbstractArray, region) - project_x = ProjectTo(x) - dct_pb(Δ) = NoTangent(), project_x(idct(unthunk(Δ), region)), NoTangent() - return dct(x, region), dct_pb -end - -# IDCT - -function ChainRulesCore.frule(Δ, ::typeof(idct), x::AbstractArray, region = 1:ndims(x)) - Δx = Δ[2] - y = idct(x, region) - Δy = idct(Δx, region) - return y, Δy -end - -function ChainRulesCore.rrule(::typeof(idct), x::AbstractArray) - project_x = ProjectTo(x) - region = 1:ndims(x) - dct_pb(Δ) = NoTangent(), project_x(dct(unthunk(Δ), region)) - return idct(x, region), dct_pb -end - -function ChainRulesCore.rrule(::typeof(idct), x::AbstractArray, region) - project_x = ProjectTo(x) - dct_pb(Δ) = NoTangent(), project_x(dct(unthunk(Δ), region)), NoTangent() - return idct(x, region), dct_pb +for (fwd, bwd) in ( + (dct, idct), + (idct, dct), +) + function ChainRulesCore.frule(Δ, ::typeof(fwd), x::AbstractArray, region = 1:ndims(x)) + Δx = Δ[2] + y = fwd(x, region) + Δy = fwd(Δx, region) + return y, Δy + end + + function ChainRulesCore.rrule(::typeof(fwd), x::AbstractArray) + project_x = ProjectTo(x) + dct_pb(Δ) = NoTangent(), project_x(bwd(unthunk(Δ))) + return fwd(x), dct_pb + end + + function ChainRulesCore.rrule(::typeof(fwd), x::AbstractArray, region) + project_x = ProjectTo(x) + dct_pb(Δ) = NoTangent(), project_x(bwd(unthunk(Δ), region)), NoTangent() + return fwd(x, region), dct_pb + end end # R2R