Skip to content

Commit

Permalink
Complete make-over for RBFModel
Browse files Browse the repository at this point in the history
This is a pretty massiv commit, sorry.
There is now a near-standalone backend providing `RBFSurrogate`,
similar to Flux/Lux models.
We use the same methods for `RBFModel`.
It still needs some interface definitions, but evaluation and
construction appears to work well.
Actually did some profiling.
  • Loading branch information
manuelbb-upb committed Feb 28, 2024
1 parent 2e59ff1 commit ab5cba8
Show file tree
Hide file tree
Showing 19 changed files with 3,030 additions and 901 deletions.
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@ ElasticArrays = "fdbdab4c-e67f-52f5-8c3f-e7b388dad3d4"
HiGHS = "87dc4568-4c63-4d18-b0c0-bb2238e4078b"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
NLopt = "76087f3c-5699-56af-9a33-bf431cd00edd"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StridedViews = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143"
StructHelpers = "4093c41a-2008-41fd-82b8-e3f9d02b504f"

[weakdeps]
Expand Down
3 changes: 2 additions & 1 deletion src/Compromise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import LinearAlgebra as LA
using DocStringExtensions

import Logging: @logmsg, LogLevel
import Printf: @sprintf

# Re-export symbols from important sub-modules
import Reexport: @reexport
Expand Down Expand Up @@ -93,7 +94,7 @@ end
export ForwardDiffBackend

# Import Radial Basis Function surrogates:
include("evaluators/RBFModels.jl")
include("evaluators/RBFModels/RBFModels.jl")
@reexport using .RBFModels

# Taylor Polynomial surrogates:
Expand Down
61 changes: 56 additions & 5 deletions src/CompromiseEvaluators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ abstract type AbstractNonlinearOperator end
abstract type AbstractNonlinearOperatorWithParams <: AbstractNonlinearOperator end
abstract type AbstractNonlinearOperatorNoParams <: AbstractNonlinearOperator end

const Vec = AbstractVector
const Mat = AbstractMatrix

# ### Common Methods
# Both, `AbstractNonlinearOperatorWithParams` and `AbstractNonlinearOperatorNoParams`
# have methods like `eval_op!`.
Expand Down Expand Up @@ -80,10 +83,19 @@ enforce_max_calls(op::AbstractNonlinearOperator)=true
Evaluate the operator `op` at variable vector `x` with parameters `p`
and mutate the target vector `y` to contain the result.
"""
function eval_op!(y, op::AbstractNonlinearOperatorWithParams, x, p)
function eval_op!(y::Vec, op::AbstractNonlinearOperatorWithParams, x::Vec, p)
return error("No implementation of `eval_op!` for operator $op.")
end

# Optional Parallel Evaluation:
function eval_op!(Y::AbstractMatrix, op::AbstractNonlinearOperatorWithParams, X::AbstractMatrix, p)
for (x, y) = zip(eachcol(X), eachcol(Y))
c = eval_op!(y, op, x)
!isnothing(c) && return c
end
return nothing
end

"""
eval_grads!(Dy, op, x, p)
Expand Down Expand Up @@ -175,16 +187,36 @@ function check_num_calls(op, ind; force::Bool=enforce_max_calls(op))
isnothing(max_call_tuple) && return nothing
ncalls = num_calls(op)
for i=ind
ni = ncalls[i]
mi = max_call_tuple[i]
isnothing(mi) && continue
ni = ncalls[i]
if ni >= mi
return "Maximum evaluation count reached, order=$(i-1), evals $(ni) >= $(mi)."
end
end
return nothing
end

# The same logic can be used to query the remaining evaluation budget:
function budget_num_calls(op, ind; force=enforce_max_calls(op))
!is_counted(op) && return nothing
!force && return nothing
max_call_tuple = max_calls(op)
isnothing(max_call_tuple) && return nothing
ncalls = num_calls(op)
budget_vec = Vector{Union{Int, Nothing}}(undef, length(ind))
for i=ind
mi = max_call_tuple[i]
if isnothing(mi)
budget_vec[i] = nothing
continue
end
ni = ncalls[i]
budget_vec[i] = mi-ni
end
return budget_vec
end

function func_vals!(y, op::AbstractNonlinearOperator, x, p; outputs=nothing)
@serve check_num_calls(op, 1)
if !isnothing(outputs) && supports_partial_evaluation(op)
Expand Down Expand Up @@ -232,12 +264,20 @@ end
# This also makes writing extensions a tiny bit easier.

#src The below methods have been written by ChatGPT according to what is above:
function eval_op!(y, op::AbstractNonlinearOperatorNoParams, x)
function eval_op!(y::Vec, op::AbstractNonlinearOperatorNoParams, x::Vec)
return error("No implementation of `eval_op!` for operator $op.")
end
function eval_grads!(Dy, op::AbstractNonlinearOperatorNoParams, x)
return error("No implementation of `eval_grads!` for operator $op.")
end
# Optional Parallel Evaluation:
function eval_op!(Y::Mat, op::AbstractNonlinearOperatorWithParams, X::Mat)
for (x, y) = zip(eachcol(X), eachcol(Y))
c = eval_op!(y, op, x)
!isnothing(c) && return c
end
return nothing
end
# Optional, derived method for values and gradients:
function eval_op_and_grads!(y, Dy, op::AbstractNonlinearOperatorNoParams, x)
@serve eval_op!(y, op, x)
Expand Down Expand Up @@ -279,7 +319,7 @@ end
# To also be able to use non-parametric operators in the more general setting,
# implement the parametric-interface:

function eval_op!(y, op::AbstractNonlinearOperatorNoParams, x, p)
function eval_op!(y::Vec, op::AbstractNonlinearOperatorNoParams, x::Vec, p)
return eval_op!(y, op, x)
end
function eval_grads!(Dy, op::AbstractNonlinearOperatorNoParams, x, p)
Expand Down Expand Up @@ -463,7 +503,7 @@ end
# ### Evaluation
# In place evaluation and differentiation, similar to `AbstractNonlinearOperatorNoParams`.
# Mandatory:
function model_op!(y, surr::AbstractSurrogateModel, x)
function model_op!(y::Vec, surr::AbstractSurrogateModel, x::Vec)
return nothing
end
# Mandatory:
Expand All @@ -490,6 +530,17 @@ function model_op_and_grads!(y, Dy, surr::AbstractSurrogateModel, x, outputs)
return nothing
end

# #### Optional Parallel Evaluation
function model_op!(Y::Mat, surr::AbstractSurrogateModel, X::Mat)
for (x, y) = zip(eachcol(X), eachcol(Y))
c = model_op!(y, surr, x)
if !isnothing(c)
return c
end
end
return nothing
end

# #### Safe-guarded, internal Methods
# The methods below are used in the algorithm and have the same signature as
# the corresponding methods for `AbstractNonlinearOperator`.
Expand Down
2 changes: 1 addition & 1 deletion src/evaluators/ExactModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ function CE.init_surrogate(::ExactModelConfig, op, dim_in, dim_out, params, T)
return ExactModel(op, params)
end

function CE.model_op!(y, surr::ExactModel, x)
function CE.model_op!(y::AbstractVector, surr::ExactModel, x::AbstractVector)
#eval_op!(y, surr.op, x, surr.params)
# if `surr.op` has enforce_max_calls==true then func_vals checks for max_calls
# if it does not, we could/should do it here...
Expand Down
4 changes: 2 additions & 2 deletions src/evaluators/NonlinearFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ end

CE.enforce_max_calls(op::NonlinearParametricFunction)=op.enforce_max_calls

function CE.eval_op!(y, op::NonlinearParametricFunction, x, p)
function CE.eval_op!(y::AbstractVector, op::NonlinearParametricFunction, x::AbstractVector, p)
if op.func_iip
op.func(y, x, p)
else
Expand Down Expand Up @@ -322,7 +322,7 @@ end
@forward CE.set_num_calls!(op::NonlinearFunction)
@forward CE.provides_grads(op::NonlinearFunction)
@forward CE.provides_hessians(op::NonlinearFunction)
@forward CE.eval_op!(y, op::NonlinearFunction, x)
@forward CE.eval_op!(y::VecOrMat, op::NonlinearFunction, x::VecOrMat)
@forward CE.eval_grads!(Dy, op::NonlinearFunction, x)
@forward CE.eval_hessians!(H, op::NonlinearFunction, x)
@forward CE.eval_op_and_grads!(y, Dy, op::NonlinearFunction, x)
Expand Down
Loading

0 comments on commit ab5cba8

Please sign in to comment.