Fix using a manual grad_J_a
This fixes a bug where giving `grad_J_a` manually would still call the
`make_grad_J_a` function, which would then fail.
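
The underlying issue is that Julia's three-argument `get(dict, key, default)` evaluates `default` eagerly, even when `key` is present. A minimal, self-contained illustration (hypothetical names, not from the repository):

kwargs = Dict(:grad_J_a => "manual gradient")
expensive_default() = (println("default evaluated!"); "automatic gradient")
get(kwargs, :grad_J_a, expensive_default())
# prints "default evaluated!" and only then returns "manual gradient"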
goerz committed Aug 30, 2024
1 parent 644dfbe commit 1a2e34a
Showing 3 changed files with 84 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/optimize.jl
@@ -201,7 +201,13 @@ function optimize_grape(problem)
     chi = wrk.kwargs[:chi] # guaranteed to exist in `GrapeWrk` constructor
     grad_J_a! = nothing
     if !isnothing(J_a_func)
-        grad_J_a! = get(wrk.kwargs, :grad_J_a, make_grad_J_a(J_a_func, tlist))
+        if haskey(wrk.kwargs, :grad_J_a)
+            grad_J_a! = wrk.kwargs[:grad_J_a]
+        else
+            # With a manually given `grad_J_a`, the `make_grad_J_a` function
+            # should never be called. So we can't use `get` to set this.
+            grad_J_a! = make_grad_J_a(J_a_func, tlist)
+        end
     end
 
     τ = wrk.result.tau_vals
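
Julia also offers a lazy variant of `get` that takes a zero-argument function and only calls it when the key is missing, so the same fix could have been written in one line (a sketch, equivalent to the branch above):

grad_J_a! = get(() -> make_grad_J_a(J_a_func, tlist), wrk.kwargs, :grad_J_a)

The explicit `haskey` branch used in the commit makes the intent more obvious.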
5 changes: 5 additions & 0 deletions test/runtests.jl
@@ -22,6 +22,11 @@ unicodeplots()
     include("test_empty_optimization.jl")
 end
 
+println("\n* Pulse Running Cost (test_pulse_running_cost.jl)")
+@time @safetestset "Pulse Running Cost" begin
+    include("test_pulse_running_cost.jl")
+end
+
 println("\n* Taylor Gradient (test_taylor_grad.jl):")
 @time @safetestset "Taylor Gradient" begin
     include("test_taylor_grad.jl")
72 changes: 72 additions & 0 deletions test/test_pulse_running_cost.jl
@@ -0,0 +1,72 @@
using QuantumControl
using QuantumControl.Functionals: J_a_fluence
using Test
using StableRNGs
using QuantumControlTestUtils.DummyOptimization: dummy_control_problem
using QuantumControl.Functionals: J_T_re
using LinearAlgebra: norm
using GRAPE

@testset "running cost with manual gradient" begin

function _TEST_J_a_smoothness(pulsevals, tlist)
N = length(tlist) - 1 # number of intervals
L = length(pulsevals) ÷ N
@assert length(pulsevals) == N * L
J_a = 0.0
for l = 1:L
for n = 2:N
J_a += (pulsevals[(l-1)*N+n] - pulsevals[(l-1)*N+n-1])^2
end
end
return 0.5 * J_a
end
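    # Annotation (not in the original commit): the cost above is
    #
    #     J_a = ½ Σ_l Σ_{n=2}^{N} (u_{l,n} - u_{l,n-1})²,
    #
    # so each pulse value picks up a backward- and a forward-difference term,
    #
    #     ∂J_a/∂u_{l,n} = (u_{l,n} - u_{l,n-1}) + (u_{l,n} - u_{l,n+1}),
    #
    # where either term is dropped at the respective boundary (n = 1 or n = N).
    # The gradient function below implements exactly these two contributions.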

    function _TEST_grad_J_a_smoothness!(∇J_a, pulsevals, tlist)
        N = length(tlist) - 1 # number of intervals
        L = length(pulsevals) ÷ N
        for l = 1:L
            for n = 1:N
                ∇J_a[(l-1)*N+n] = 0.0
                uₙ = pulsevals[(l-1)*N+n]
                if n > 1
                    uₙ₋₁ = pulsevals[(l-1)*N+n-1]
                    ∇J_a[(l-1)*N+n] += (uₙ - uₙ₋₁)
                end
                if n < N
                    uₙ₊₁ = pulsevals[(l-1)*N+n+1]
                    ∇J_a[(l-1)*N+n] += (uₙ - uₙ₊₁)
                end
            end
        end
        return ∇J_a
    end

    rng = StableRNG(1244561944)
    problem = dummy_control_problem(; n_controls=2, rng)
    res = optimize(
        problem;
        method=GRAPE,
        J_a=_TEST_J_a_smoothness,
        grad_J_a=_TEST_grad_J_a_smoothness!,
        lambda_a=0.1,
        J_T=J_T_re,
        iter_stop=2
    )
    @test res.converged
    @test res.J_T < res.J_T_prev

end


@testset "J_a_fluence running cost" begin

rng = StableRNG(1244561944)
problem = dummy_control_problem(; n_controls=2, rng)
res0 = optimize(problem; method=GRAPE, J_T=J_T_re, iter_stop=2)
res = optimize(problem; method=GRAPE, J_a=J_a_fluence, J_T=J_T_re, iter_stop=2)
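    # Annotation (not in the original commit): `J_a_fluence` penalizes the
    # total pulse fluence, Σ |u|² dt. No manual `grad_J_a` is passed here, so
    # this second optimization exercises the `make_grad_J_a` fallback branch
    # in `optimize_grape` that the first testset bypasses.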
    @test res0.converged
    @test res.converged
    @test sum(norm.(res.optimized_controls)) < sum(norm.(res0.optimized_controls))

end
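
A manual gradient like `_TEST_grad_J_a_smoothness!` can also be sanity-checked against a central finite difference. A minimal sketch (not part of the commit; it assumes the two functions from the testset above are in scope):

using Test
using StableRNGs

tlist = collect(range(0, 1; length=6)) # N = 5 intervals
u = rand(StableRNG(42), 2 * 5)         # L = 2 controls × N = 5 intervals
∇ = zeros(length(u))
_TEST_grad_J_a_smoothness!(∇, u, tlist)
h = 1e-6
for k in eachindex(u)
    u₊ = copy(u); u₊[k] += h
    u₋ = copy(u); u₋[k] -= h
    # central difference approximation of ∂J_a/∂u_k
    fd = (_TEST_J_a_smoothness(u₊, tlist) - _TEST_J_a_smoothness(u₋, tlist)) / (2h)
    @test ∇[k] ≈ fd atol = 1e-8
end

Since the cost is quadratic in the pulse values, the central difference is exact up to floating-point roundoff.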
