Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tidy C API to use @ccall #115

Merged
merged 6 commits into from
Apr 29, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 51 additions & 91 deletions src/C_API.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,7 @@ const INFINITY = 1e20
###

function c_api_License_SetString(license::String)
return ccall(
(:License_SetString, PATH_SOLVER),
Cint,
(Ptr{Cchar},),
license,
)
return @ccall PATH_SOLVER.License_SetString(license::Ptr{Cchar})::Cint
end

###
Expand Down Expand Up @@ -69,11 +64,8 @@ function OutputInterface(output_data)
end

function c_api_Output_SetInterface(o::OutputInterface)
return ccall(
(:Output_SetInterface, PATH_SOLVER),
Cvoid,
(Ref{OutputInterface},),
o,
return @ccall(
PATH_SOLVER.Output_SetInterface(o::Ref{OutputInterface})::Cvoid,
)
end

Expand All @@ -94,34 +86,29 @@ Base.cconvert(::Type{Ptr{Cvoid}}, x::Options) = x
Base.unsafe_convert(::Type{Ptr{Cvoid}}, x::Options) = x.ptr

function c_api_Options_Create()
ptr = ccall((:Options_Create, PATH_SOLVER), Ptr{Cvoid}, ())
return Options(ptr)
return Options(@ccall PATH_SOLVER.Options_Create()::Ptr{Cvoid})
end

function c_api_Options_Destroy(o::Options)
return ccall((:Options_Destroy, PATH_SOLVER), Cvoid, (Ptr{Cvoid},), o)
return @ccall PATH_SOLVER.Options_Destroy(o::Ptr{Cvoid})::Cvoid
end

function c_api_Options_Default(o::Options)
return ccall((:Options_Default, PATH_SOLVER), Cvoid, (Ptr{Cvoid},), o)
return @ccall PATH_SOLVER.Options_Default(o::Ptr{Cvoid})::Cvoid
end

function c_api_Options_Display(o::Options)
return ccall((:Options_Display, PATH_SOLVER), Cvoid, (Ptr{Cvoid},), o)
return @ccall PATH_SOLVER.Options_Display(o::Ptr{Cvoid})::Cvoid
end

function c_api_Options_Read(o::Options, filename::String)
return ccall(
(:Options_Read, PATH_SOLVER),
Cvoid,
(Ptr{Cvoid}, Ptr{Cchar}),
o,
filename,
return @ccall(
PATH_SOLVER.Options_Read(o::Ptr{Cvoid}, filename::Ptr{Cchar})::Cvoid,
)
end

function c_api_Path_AddOptions(o::Options)
return ccall((:Path_AddOptions, PATH_SOLVER), Cvoid, (Ptr{Cvoid},), o)
return @ccall PATH_SOLVER.Path_AddOptions(o::Ptr{Cvoid})::Cvoid
end

###
Expand All @@ -137,8 +124,7 @@ end

function _c_jac_typ(data_ptr::Ptr{Cvoid}, nnz::Cint, typ_ptr::Ptr{Cint})
data = unsafe_pointer_to_objref(data_ptr)::PresolveData
typ = unsafe_wrap(Array{Cint}, typ_ptr, nnz)
data.jac_typ(nnz, typ)
data.jac_typ(nnz, unsafe_wrap(Array{Cint}, typ_ptr, nnz))
return
end

Expand Down Expand Up @@ -189,10 +175,8 @@ function _c_problem_size(
nnz_ptr::Ptr{Cint},
)
id_data = unsafe_pointer_to_objref(id_ptr)::InterfaceData
n = unsafe_wrap(Array{Cint}, n_ptr, 1)
n[1] = id_data.n
nnz = unsafe_wrap(Array{Cint}, nnz_ptr, 1)
nnz[1] = id_data.nnz
unsafe_store!(n_ptr, id_data.n)
unsafe_store!(nnz_ptr, id_data.nnz)
return
end

Expand All @@ -204,14 +188,9 @@ function _c_bounds(
ub_ptr::Ptr{Cdouble},
)
id_data = unsafe_pointer_to_objref(id_ptr)::InterfaceData
z = unsafe_wrap(Array{Cdouble}, z_ptr, n)
lb = unsafe_wrap(Array{Cdouble}, lb_ptr, n)
ub = unsafe_wrap(Array{Cdouble}, ub_ptr, n)
for i in 1:n
z[i] = id_data.z[i]
lb[i] = id_data.lb[i]
ub[i] = id_data.ub[i]
end
copy!(unsafe_wrap(Array{Cdouble}, z_ptr, n), id_data.z)
copy!(unsafe_wrap(Array{Cdouble}, lb_ptr, n), id_data.lb)
copy!(unsafe_wrap(Array{Cdouble}, ub_ptr, n), id_data.ub)
return
end

Expand All @@ -224,8 +203,7 @@ function _c_function_evaluation(
id_data = unsafe_pointer_to_objref(id_ptr)::InterfaceData
x = unsafe_wrap(Array{Cdouble}, x_ptr, n)
f = unsafe_wrap(Array{Cdouble}, f_ptr, n)
err = id_data.F(n, x, f)
return err
return id_data.F(n, x, f)
end

function _c_jacobian_evaluation(
Expand All @@ -247,13 +225,13 @@ function _c_jacobian_evaluation(
f = unsafe_wrap(Array{Cdouble}, f_ptr, n)
err += id_data.F(n, x, f)
end
nnz = unsafe_wrap(Array{Cint}, nnz_ptr, 1)
nnz = unsafe_load(nnz_ptr)::Cint
col = unsafe_wrap(Array{Cint}, col_ptr, n)
len = unsafe_wrap(Array{Cint}, len_ptr, n)
row = unsafe_wrap(Array{Cint}, row_ptr, nnz[1])
data = unsafe_wrap(Array{Cdouble}, data_ptr, nnz[1])
err += id_data.J(n, nnz[1], x, col, len, row, data)
nnz[1] = sum(len)
row = unsafe_wrap(Array{Cint}, row_ptr, nnz)
data = unsafe_wrap(Array{Cdouble}, data_ptr, nnz)
err += id_data.J(n, nnz, x, col, len, row, data)
unsafe_store!(nnz_ptr, Cint(sum(len)))
return err
end

Expand Down Expand Up @@ -401,64 +379,52 @@ Base.cconvert(::Type{Ptr{Cvoid}}, x::MCP) = x
Base.unsafe_convert(::Type{Ptr{Cvoid}}, x::MCP) = x.ptr

function c_api_MCP_Create(n::Int, nnz::Int)
ptr = ccall((:MCP_Create, PATH_SOLVER), Ptr{Cvoid}, (Cint, Cint), n, nnz)
ptr = @ccall PATH_SOLVER.MCP_Create(n::Cint, nnz::Cint)::Ptr{Cvoid}
return MCP(n, ptr)
end

function c_api_MCP_Jacobian_Structure_Constant(m::MCP, flag::Bool)
ccall(
(:MCP_Jacobian_Structure_Constant, PATH_SOLVER),
Cvoid,
(Ptr{Cvoid}, Cint),
m,
flag,
)
@ccall PATH_SOLVER.MCP_Jacobian_Structure_Constant(
m::Ptr{Cvoid},
flag::Cint,
)::Cvoid
return
end

function c_api_MCP_Jacobian_Data_Contiguous(m::MCP, flag::Bool)
ccall(
(:MCP_Jacobian_Data_Contiguous, PATH_SOLVER),
Cvoid,
(Ptr{Cvoid}, Cint),
m,
flag,
)
@ccall PATH_SOLVER.MCP_Jacobian_Data_Contiguous(
m::Ptr{Cvoid},
flag::Cint,
)::Cvoid
return
end

function c_api_MCP_Destroy(m::MCP)
if m.ptr === C_NULL
return
end
ccall((:MCP_Destroy, PATH_SOLVER), Cvoid, (Ptr{Cvoid},), m)
@ccall PATH_SOLVER.MCP_Destroy(m::Ptr{Cvoid})::Cvoid
return
end

function c_api_MCP_SetInterface(m::MCP, interface::MCP_Interface)
ccall(
(:MCP_SetInterface, PATH_SOLVER),
Cvoid,
(Ptr{Cvoid}, Ref{MCP_Interface}),
m,
interface,
)
@ccall PATH_SOLVER.MCP_SetInterface(
m::Ptr{Cvoid},
interface::Ref{MCP_Interface},
)::Cvoid
return
end

function c_api_MCP_SetPresolveInterface(m::MCP, interface::Presolve_Interface)
ccall(
(:MCP_SetPresolveInterface, PATH_SOLVER),
Cvoid,
(Ptr{Cvoid}, Ref{Presolve_Interface}),
m,
interface,
)
@ccall PATH_SOLVER.MCP_SetPresolveInterface(
m::Ptr{Cvoid},
interface::Ref{Presolve_Interface},
)::Cvoid
return
end

function c_api_MCP_GetX(m::MCP)
ptr = ccall((:MCP_GetX, PATH_SOLVER), Ptr{Cdouble}, (Ptr{Cvoid},), m)
ptr = @ccall PATH_SOLVER.MCP_GetX(m::Ptr{Cvoid})::Ptr{Cdouble}
return copy(unsafe_wrap(Array{Cdouble}, ptr, m.n))
end

Expand Down Expand Up @@ -580,7 +546,7 @@ Check that the current license (stored in the environment variable
Returns a nonzero value on successful completion, and a zero value on failure.
"""
function c_api_Path_CheckLicense(n::Int, nnz::Int)
return ccall((:Path_CheckLicense, PATH_SOLVER), Cint, (Cint, Cint), n, nnz)
return @ccall PATH_SOLVER.Path_CheckLicense(n::Cint, nnz::Cint)::Cint
end

"""
Expand All @@ -589,8 +555,7 @@ end
Return a string of the PATH version.
"""
function c_api_Path_Version()
ptr = ccall((:Path_Version, PATH_SOLVER), Ptr{Cchar}, ())
return unsafe_string(ptr)
return unsafe_string(@ccall PATH_SOLVER.Path_Version()::Ptr{Cchar})
end

"""
Expand All @@ -599,12 +564,8 @@ end
Returns a MCP_Termination status.
"""
function c_api_Path_Solve(m::MCP, info::Information)
return ccall(
(:Path_Solve, PATH_SOLVER),
Cint,
(Ptr{Cvoid}, Ref{Information}),
m,
info,
return @ccall(
PATH_SOLVER.Path_Solve(m::Ptr{Cvoid}, info::Ref{Information})::Cint,
)
end

Expand Down Expand Up @@ -812,12 +773,11 @@ function solve_mcp(
gc_root[m_interface] = true
c_api_MCP_SetInterface(m, m_interface)
if jacobian_structure_constant && !isempty(jacobian_linear_elements)
presolve_data = PresolveData() do nnz, types
for i in jacobian_linear_elements
types[i] = PRESOLVE_LINEAR
end
function presolve_fn(::Cint, types::Vector{Cint})
types[jacobian_linear_elements] .= PRESOLVE_LINEAR
return
end
presolve_data = PresolveData(presolve_fn)
# We shouldn't GC presolve_data until we exit the GC.@preserve block.
gc_root[presolve_data] = true
presolve_interface = Presolve_Interface(presolve_data)
Expand Down Expand Up @@ -863,7 +823,7 @@ function _linear_function(M::AbstractMatrix, q::Vector)
elseif size(M, 1) != length(q)
error("q is wrong shape. Expected $(size(M, 1)), got $(length(q)).")
end
return (n::Cint, x::Vector{Cdouble}, f::Vector{Cdouble}) -> begin
return function F(n::Cint, x::Vector{Cdouble}, f::Vector{Cdouble})
f .= M * x .+ q
return Cint(0)
end
Expand All @@ -872,15 +832,15 @@ end
function _linear_jacobian(M::SparseArrays.SparseMatrixCSC{Cdouble,Cint})
# Size is checked with error message in _linear_function.
@assert size(M, 1) == size(M, 2)
return (
return function J(
n::Cint,
nnz::Cint,
x::Vector{Cdouble},
col::Vector{Cint},
len::Vector{Cint},
row::Vector{Cint},
data::Vector{Cdouble},
) -> begin
)
@assert n == length(x) == length(col) == length(len) == size(M, 1)
@assert nnz == length(row) == length(data)
@assert nnz >= SparseArrays.nnz(M)
Expand Down
Loading