Merge pull request #83 from JuliaQuantumControl/print-iter-info
Allow configuring which info is printed per iteration
goerz authored Sep 24, 2024
2 parents 768f41f + f7dc23d commit b6b872b
Showing 7 changed files with 335 additions and 76 deletions.
3 changes: 2 additions & 1 deletion docs/make.jl
@@ -42,7 +42,8 @@ fallbacks = ExternalFallbacks(
"QuantumControl.QuantumPropagators.Controls.get_controls" => "@extref QuantumPropagators.Controls.get_controls",
"Trajectory" => "@extref QuantumControl.Trajectory",
"propagate" => "@extref QuantumPropagators.propagate",
"init_prop_trajectory" => "@extref QuantumControl.init_prop_trajectory"
"init_prop_trajectory" => "@extref QuantumControl.init_prop_trajectory",
"Functional" => "@extref QuantumControl :std:label:`Functional`",
)

println("Starting makedocs")
13 changes: 12 additions & 1 deletion ext/GRAPELBFGSBExt.jl
@@ -2,7 +2,7 @@ module GRAPELBFGSBExt

import LBFGSB
using GRAPE: GrapeWrk, update_result!
import GRAPE: run_optimizer, gradient, step_width, search_direction
import GRAPE: run_optimizer, gradient, step_width, search_direction, norm_search


function run_optimizer(optimizer::LBFGSB.L_BFGS_B, wrk, fg!, callback, check_convergence!)
@@ -197,4 +197,15 @@ function search_direction(wrk::GrapeWrk{O}) where {O<:LBFGSB.L_BFGS_B}
return wrk.optimizer.wa[n0:n0+n-1]
end


function norm_search(wrk::GrapeWrk{O}) where {O<:LBFGSB.L_BFGS_B}
n = length(wrk.pulsevals)
n0 = wrk.optimizer.isave[13]
r = 0.0
for i = n0:n0+n-1
r += wrk.optimizer.wa[i]^2
end
return sqrt(r)
end

end
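
For reference (not part of the diff itself): the new `norm_search` returns the same value as the ℓ²-norm of `search_direction`, but accumulates the sum of squares in place instead of allocating the slice of the work array. A minimal allocating equivalent, assuming both functions read the same segment of `wrk.optimizer.wa`:

    using LinearAlgebra: norm

    # Allocating equivalent of `norm_search`, for illustration only
    norm_search_alloc(wrk) = norm(search_direction(wrk))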
3 changes: 2 additions & 1 deletion ext/GRAPEOptimExt.jl
@@ -52,7 +52,8 @@ function run_optimizer(
if hasproperty(objective, :DF)
# DF is the *current* gradient, i.e., the gradient of the updated
# pulsevals, which (after the call to `callback`) is the gradient
# for the guess of the next iteration.
# for the guess of the next iteration. It's important that
# we're setting this after the GRAPE.jl `callback`.
wrk.gradient .= objective.DF
elseif (optimization_state.iteration == 1)
@error "Cannot determine guess gradient"
299 changes: 229 additions & 70 deletions src/optimize.jl
@@ -101,9 +101,53 @@ with explicit keyword arguments to `optimize`.
an unconstrained lower bound for that time interval,
* `print_iters=true`: Whether to print information after each iteration.
* `store_iter_info=Set()`: Which fields from `print_iters` to store in
`result.records`. A subset of
`Set(["iter.", "J_T", "|∇J_T|", "ΔJ_T", "FG(F)", "secs"])`.
* `print_iter_info=["iter.", "J_T", "ǁ∇Jǁ", "ǁΔϵǁ", "ΔJ", "FG(F)", "secs"]`:
Which fields to print if `print_iters=true`. If given, must be a list of
header labels (strings), which can be any of the following:
- `"iter."`: The iteration number
- `"J_T"`: The value of the final-time functional for the dynamics under the
optimized pulses
- `"J_a"`: The value of the pulse-dependent running cost for the optimized
pulses
- `"λ_a⋅J_a"`: The total contribution of `J_a` to the full functional `J`
- `"J"`: The value of the optimization functional for the optimized pulses
- `"ǁ∇J_Tǁ"`: The ℓ²-norm of the *current* gradient of the final-time
  functional. Note that this is usually the gradient of the optimized pulse,
not the guess pulse.
- `"ǁ∇J_aǁ"`: The ℓ²-norm of the the *current* gradient of the pulse-dependent
running cost. For comparison with `"ǁ∇J_Tǁ"`.
- `"λ_aǁ∇J_aǁ"`: The ℓ²-norm of the the *current* gradient of the complete
pulse-dependent running cost term. For comparison with `"ǁ∇J_Tǁ"`.
- `"ǁ∇Jǁ"`: The norm of the guess pulse gradient. Note that the *guess* pulse
  gradient is not the same as the *current* gradient.
- `"ǁΔϵǁ"`: The ℓ²-norm of the pulse update
- `"ǁϵǁ"`: The ℓ²-norm of optimized pulse values
- `"max|Δϵ|"` The maximum value of the pulse update (infinity norm)
- `"max|ϵ|"`: The maximum value of the pulse values (infinity norm)
- `"ǁΔϵǁ/ǁϵǁ"`: The ratio of the pulse update tothe optimized pulse values
- `"∫Δϵ²dt"`: The L²-norm of the pulse update, summed over all pulses. A
convergence measure comparable (proportional) to the running cost in
Krotov's method
- `"ǁsǁ"`: The norm of the search direction. Should be `ǁΔϵǁ` scaled by the
  step width `α`.
- `"∠°"`: The angle (in degrees) between the negative gradient `-∇J` and the
search direction `s`.
- `"α"`: The step width as determined by the line search (`Δϵ = α⋅s`)
- `"ΔJ_T"`: The change in the final time functional relative to the previous
iteration
- `"ΔJ_a"`: The change in the control-dependent running cost relative to the
previous iteration
- `"λ_a⋅ΔJ_a"`: The change in the control-dependent running cost term
relative to the previous iteration.
- `"ΔJ"`: The change in the total optimization functional relative to the
previous iteration.
- `"FG(F)"`: The number of functional/gradient evaluation (FG), or pure
functional (F) evaluations
- `"secs"`: The number of seconds of wallclock time spent on the iteration.
* `store_iter_info=[]`: Which fields to store in `result.records`, given as
a list of header labels, see `print_iter_info`.
* `callback`: A function (or tuple of functions) that receives the
[GRAPE workspace](@ref GrapeWrk) and the iteration number. The function
may return a tuple of values which are stored in the
@@ -113,7 +157,7 @@ similar manipulations. Note that `print_iters=true` (default) adds an
similar manipulations. Note that `print_iters=true` (default) adds an
automatic callback to print information after each iteration. With
`store_iter_info`, that callback automatically stores a subset of the
printed information.
available information.
* `check_convergence`: A function to check whether convergence has been
reached. Receives a [`GrapeResult`](@ref) object `result`, and should set
`result.converged` to `true` and `result.message` to an appropriate string in
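
As a usage sketch (not part of this commit's diff), the new keyword arguments might be passed to `optimize` as follows; here `problem` stands for a `ControlProblem` that is assumed to be set up elsewhere:

    using QuantumControl: optimize
    using GRAPE

    result = optimize(
        problem;
        method=GRAPE,
        print_iter_info=["iter.", "J_T", "ǁΔϵǁ", "ΔJ", "secs"],
        store_iter_info=["iter.", "J_T", "ǁ∇J_Tǁ"],
    )
    # Each iteration appends a tuple to `result.records`, with the values in
    # the order given by `store_iter_info`.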
@@ -464,6 +508,12 @@ function update_result!(wrk::GrapeWrk, i::Int64)
end
res.J_T_prev = res.J_T
res.J_T = wrk.J_parts[1]
res.J_a_prev = res.J_a
res.J_a = wrk.J_parts[2]
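# `wrk.J_parts[2]` holds λ_a⋅J_a (cf. `print_table` below); dividing out λ_a
# below stores J_a itself in `res.J_a`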
if res.J_a > 0.0
λₐ = get(wrk.kwargs, :lambda_a, 1.0)
res.J_a /= λₐ
end
(i > 0) && (res.iter = i)
if i >= res.iter_stop
res.converged = true
@@ -499,88 +549,197 @@ This function serves as the default `info_hook` for an optimization with
GRAPE.
"""
function make_grape_print_iters(; kwargs...)

headers = ["iter.", "J_T", "|∇J_T|", "ΔJ_T", "FG(F)", "secs"]
store_iter_info = Set(get(kwargs, :store_iter_info, Set()))
info_vals = Vector{Any}(undef, length(headers))
fill!(info_vals, nothing)
store_iter = false
store_J_T = false
store_grad_norm = false
store_ΔJ_T = false
store_counts = false
store_secs = false
for item in store_iter_info
if item == "iter."
store_iter = true
elseif item == "J_T"
store_J_T = true
elseif item == "|∇J_T|"
store_grad_norm = true
elseif item == "ΔJ_T"
store_ΔJ_T = true
elseif item == "FG(F)"
store_counts = true
elseif item == "secs"
store_secs = true
else
msg = "Item $(repr(item)) in `store_iter_info` is not one of $(repr(headers)))"
throw(ArgumentError(msg))
end
headers = [
"iter.",
"J_T",
"J_a",
# "J_b",
"λ_a⋅J_a",
# "λ_b⋅J_b",
"J",
"ǁ∇J_Tǁ",
"ǁ∇J_aǁ",
"λ_aǁ∇J_aǁ",
# "ǁ∇J_bǁ",
"λ_a⋅ΔJ_a",
# "λ_b⋅ΔJ_b",
"ǁ∇Jǁ",
"ǁΔϵǁ",
"ǁϵǁ",
"max|Δϵ|",
"max|ϵ|",
"ǁΔϵǁ/ǁϵǁ",
"∫Δϵ²dt",
"ǁsǁ",
"∠°",
"α",
"ΔJ_T",
"ΔJ_a",
# "ΔJ_b",
"λ_a⋅ΔJ_a",
# "λ_b⋅ΔJ_b",
"ΔJ",
"FG(F)",
"secs"
] # TODO: update docs when J_b is implemented
delta_headers = Set([
"ΔJ_T",
"λ_a⋅ΔJ_a",
"ΔJ_a",
"λ_b⋅ΔJ_b",
"ΔJ",
"ǁΔϵǁ",
"ǁΔϵǁ/ǁϵǁ",
"max|Δϵ|",
"∫Δϵ²dt",
"α",
"ǁsǁ"
])
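# Fields in `delta_headers` refer to the pulse update or to the change relative
# to the previous iteration; `print_table` below prints "n/a" for them at iteration 0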
store_iter_info = get(kwargs, :store_iter_info, String[])
if Set(store_iter_info) ⊈ Set(headers)
diff = [field for field in store_iter_info if field ∉ headers]
msg = "store_iter_info contains invalid elements $(diff)"
@warn "Invalid $(diff) not in allowed fields = [$(join(map(repr, headers), ", "))]"
throw(ArgumentError(msg))
end
print_iter_info = get(
kwargs,
:print_iter_info,
["iter.", "J_T", "ǁ∇Jǁ", "ǁΔϵǁ", "ΔJ", "FG(F)", "secs"]
)
if Set(print_iter_info) ⊈ Set(headers)
diff = [field for field in print_iter_info if field ∉ headers]
msg = "print_iter_info contains invalid elements $(diff)"
@warn "Invalid $(diff) not in allowed fields = [$(join(map(repr, headers), ", "))]"
throw(ArgumentError(msg))
end
needed_fields = Set(store_iter_info) ∪ Set(print_iter_info)
info_vals = Dict{String,Any}()

function print_table(wrk, iteration, args...)

J_T = wrk.result.J_T
ΔJ_T = J_T - wrk.result.J_T_prev
secs = wrk.result.secs
grad_norm = norm(wrk.grad_J_T)
counts = Tuple(wrk.fg_count)
λ_a = get(wrk.kwargs, :lambda_a, 1.0)
info_vals["iter."] = iteration
info_vals["J_T"] = wrk.result.J_T
info_vals["ΔJ_T"] = wrk.result.J_T - wrk.result.J_T_prev
info_vals["J_a"] = wrk.result.J_a
info_vals["λ_a⋅J_a"] = wrk.J_parts[2]
ΔJ_a = wrk.result.J_a - wrk.result.J_a_prev
info_vals["ΔJ_a"] = ΔJ_a
info_vals["λ_a⋅ΔJ_a"] = λ_a * ΔJ_a
info_vals["J"] = wrk.result.J_T + λ_a * wrk.result.J_a
if "ǁ∇J_Tǁ" needed_fields
info_vals["ǁ∇J_Tǁ"] = norm(wrk.grad_J_T)
end
if ("ǁ∇J_aǁ" needed_fields) || ("λ_aǁ∇J_aǁ" needed_fields)
nrm_∇J_a = norm(wrk.grad_J_a)
info_vals["ǁ∇J_aǁ"] = nrm_∇J_a
info_vals["λ_aǁ∇J_aǁ"] = λ_a * nrm_∇J_a
end
if "ǁ∇Jǁ" needed_fields
info_vals["ǁ∇Jǁ"] = norm(gradient(wrk; which=:initial))
end
if "ΔJ" needed_fields
J = wrk.result.J_T + λ_a * wrk.result.J_a
J_prev = wrk.result.J_T_prev + λ_a * wrk.result.J_a_prev
info_vals["ΔJ"] = J - J_prev
end
if ("ǁΔϵǁ/ǁϵǁ" needed_fields) ||
("ǁΔϵǁ" needed_fields) ||
("ǁϵǁ" needed_fields) ||
("max|ϵ|" needed_fields) ||
("max|Δϵ|" needed_fields) ||
("∫Δϵ²dt" needed_fields)
r = 0.0
= 0.0
∫Δϵ²dt = 0.0
max_ϵ = 0.0
max_Δϵ = 0.0
N = length(wrk.result.tlist) - 1
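# `wrk.pulsevals` concatenates the values of all pulses; each pulse contributes
# N values, one per time interval of `tlist`. The loop below accumulates ǁϵǁ²,
# ǁΔϵǁ², the maximum absolute values, and ∫Δϵ²dt = Σₗ Σₙ |Δϵ_l(tₙ)|²⋅dtₙ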
for i = 1:length(wrk.pulsevals)
n = ((i - 1) % N) + 1 # index of time interval 1…N
dt = wrk.result.tlist[n+1] - wrk.result.tlist[n]
r += wrk.pulsevals[i]^2
abs_Δϵᵢ = abs(wrk.pulsevals[i] - wrk.pulsevals_guess[i])
Δϵᵢ² = abs_Δϵᵢ^2
∫Δϵ²dt += Δϵᵢ² * dt
rΔ += Δϵᵢ²
abs_ϵᵢ = abs(wrk.pulsevals[i])
(abs_ϵᵢ > max_ϵ) && (max_ϵ = abs_ϵᵢ)
(abs_Δϵᵢ > max_Δϵ) && (max_Δϵ = abs_Δϵᵢ)
end
info_vals["ǁϵǁ"] = sqrt(r)
info_vals["ǁΔϵǁ"] = sqrt(rΔ)
info_vals["ǁΔϵǁ/ǁϵǁ"] = sqrt(rΔ) / sqrt(r)
info_vals["max|ϵ|"] = max_ϵ
info_vals["max|Δϵ|"] = max_Δϵ
info_vals["∫Δϵ²dt"] = ∫Δϵ²dt
end
if "ǁsǁ" needed_fields
info_vals["ǁsǁ"] = norm_search(wrk)
end
if "α" needed_fields
info_vals["α"] = step_width(wrk)
end
if "∠°" needed_fields
s_G = -1 * gradient(wrk; which=:initial)
s = search_direction(wrk)
info_vals["∠°"] = vec_angle(s_G, s; unit=:degree)
end
info_vals["FG(F)"] = Tuple(wrk.fg_count)
info_vals["secs"] = wrk.result.secs

iter_stop = "$(get(wrk.kwargs, :iter_stop, 5000))"
width = Dict(
width = Dict( # default width is 11
"iter." => max(length("$iter_stop"), 6),
"J_T" => 11,
"|∇J_T|" => 11,
"|∇J_a|" => 11,
"|∇J|" => 11,
"ΔJ" => 11,
"ΔJ_T" => 11,
"FG(F)" => 8,
"secs" => 8,
"∠°" => 7,
)

store_iter && (info_vals[1] = iteration)
store_J_T && (info_vals[2] = J_T)
store_grad_norm && (info_vals[3] = grad_norm)
store_ΔJ_T && (info_vals[4] = ΔJ_T)
store_counts && (info_vals[5] = counts)
store_secs && (info_vals[6] = secs)

if iteration == 0
for header in headers
w = width[header]
print(lpad(header, w))
if length(print_iter_info) > 0

if iteration == 0
for header in print_iter_info
w = get(width, header, 11)
print(lpad(header, w))
end
print("\n")
end

for header in print_iter_info
if header == "iter."
str = "$(info_vals[header])"
elseif header == "FG(F)"
counts = info_vals[header]
str = @sprintf("%d(%d)", counts[1], counts[2])
elseif header == "secs"
str = @sprintf("%.1f", info_vals[header])
elseif header ∈ delta_headers
if iteration > 0
str = @sprintf("%.2e", info_vals[header])
else
str = "n/a"
end
elseif header == "∠°"
if iteration > 0
str = @sprintf("%.1f", info_vals["∠°"])
else
str = "n/a"
end
else
str = @sprintf("%.2e", info_vals[header])
end
w = get(width, header, 11)
print(lpad(str, w))
end

print("\n")
end
flush(stdout)

strs = [
"$iteration",
@sprintf("%.2e", J_T),
@sprintf("%.2e", grad_norm),
(iteration > 0) ? @sprintf("%.2e", ΔJ_T) : "n/a",
@sprintf("%d(%d)", counts[1], counts[2]),
@sprintf("%.1f", secs),
]
for (str, header) in zip(strs, headers)
w = width[header]
print(lpad(str, w))
end
print("\n")
flush(stdout)

return Tuple((value for value in info_vals if (value !== nothing)))
return Tuple((info_vals[field] for field in store_iter_info))

end
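
The `∠°` field uses a `vec_angle` helper that is defined elsewhere in GRAPE.jl and does not appear in this diff. A generic sketch (not necessarily the actual implementation) of how the angle between the negative gradient and the search direction can be computed:

    using LinearAlgebra: dot, norm

    # Angle between two vectors, optionally in degrees. The `clamp` guards
    # against round-off pushing the cosine slightly outside [-1, 1].
    function angle_between(v, w; unit=:degree)
        cosθ = clamp(dot(v, w) / (norm(v) * norm(w)), -1.0, 1.0)
        θ = acos(cosθ)
        return unit == :degree ? θ * 180 / π : θ
    end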
