Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add extrapolation keyword #193

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/DataInterpolations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@
u
end

const EXTRAPOLATION_ERROR = "Cannot extrapolate as `extrapolate` keyword passed was `false`"
struct ExtrapolationError <: Exception end
function Base.showerror(io::IO, e::ExtrapolationError)
print(io, EXTRAPOLATION_ERROR)

Check warning on line 42 in src/DataInterpolations.jl

View check run for this annotation

Codecov / codecov/patch

src/DataInterpolations.jl#L41-L42

Added lines #L41 - L42 were not covered by tests
end

export LinearInterpolation, QuadraticInterpolation, LagrangeInterpolation,
AkimaInterpolation, ConstantInterpolation, QuadraticSpline, CubicSpline,
BSplineInterpolation, BSplineApprox
Expand Down
9 changes: 8 additions & 1 deletion src/derivatives.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
derivative(A, t) = derivative(A, t, firstindex(A.t) - 1)[1]
function derivative(A, t)
((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError())
derivative(A, t, firstindex(A.t) - 1)[1]
end

function derivative(A::LinearInterpolation{<:AbstractVector}, t::Number, iguess)
idx = searchsortedfirstcorrelated(A.t, t, iguess)
Expand Down Expand Up @@ -33,6 +36,7 @@
end

function derivative(A::LagrangeInterpolation{<:AbstractVector}, t::Number)
((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError())
idxs = findRequiredIdxs(A, t)
if A.t[idxs[1]] == t
return zero(A.u[idxs[1]])
Expand Down Expand Up @@ -68,6 +72,7 @@
end

function derivative(A::LagrangeInterpolation{<:AbstractMatrix}, t::Number)
((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError())
idxs = findRequiredIdxs(A, t)
if A.t[idxs[1]] == t
return zero(A.u[:, idxs[1]])
Expand Down Expand Up @@ -115,10 +120,12 @@
end

function derivative(A::ConstantInterpolation{<:AbstractVector}, t::Number)
((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError())

Check warning on line 123 in src/derivatives.jl

View check run for this annotation

Codecov / codecov/patch

src/derivatives.jl#L123

Added line #L123 was not covered by tests
return isempty(searchsorted(A.t, t)) ? zero(A.u[1]) : eltype(A.u)(NaN)
end

function derivative(A::ConstantInterpolation{<:AbstractMatrix}, t::Number)
((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError())

Check warning on line 128 in src/derivatives.jl

View check run for this annotation

Codecov / codecov/patch

src/derivatives.jl#L128

Added line #L128 was not covered by tests
return isempty(searchsorted(A.t, t)) ? zero(A.u[:, 1]) : eltype(A.u)(NaN) .* A.u[:, 1]
end

Expand Down
126 changes: 85 additions & 41 deletions src/interpolation_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,38 @@
struct LinearInterpolation{uType, tType, FT, T} <: AbstractInterpolation{FT, T}
u::uType
t::tType
function LinearInterpolation{FT}(u, t) where {FT}
new{typeof(u), typeof(t), FT, eltype(u)}(u, t)
extrapolate::Bool
function LinearInterpolation{FT}(u, t, extrapolate) where {FT}
new{typeof(u), typeof(t), FT, eltype(u)}(u, t, extrapolate)
end
end

function LinearInterpolation(u, t)
function LinearInterpolation(u, t; extrapolate = true)
u, t = munge_data(u, t)
LinearInterpolation{true}(u, t)
LinearInterpolation{true}(u, t, extrapolate)
end

### Quadratic Interpolation
struct QuadraticInterpolation{uType, tType, FT, T} <: AbstractInterpolation{FT, T}
u::uType
t::tType
mode::Symbol
function QuadraticInterpolation{FT}(u, t, mode) where {FT}
extrapolate::Bool
function QuadraticInterpolation{FT}(u, t, mode, extrapolate) where {FT}
mode ∈ (:Forward, :Backward) ||
error("mode should be :Forward or :Backward for QuadraticInterpolation")
new{typeof(u), typeof(t), FT, eltype(u)}(u, t, mode)
new{typeof(u), typeof(t), FT, eltype(u)}(u, t, mode, extrapolate)
end
end

function QuadraticInterpolation(u, t, mode)
function QuadraticInterpolation(u, t, mode; extrapolate = true)
u, t = munge_data(u, t)
QuadraticInterpolation{true}(u, t, mode)
QuadraticInterpolation{true}(u, t, mode, extrapolate)
end

QuadraticInterpolation(u, t) = QuadraticInterpolation(u, t, :Forward)
function QuadraticInterpolation(u, t; extrapolate = true)
QuadraticInterpolation(u, t, :Forward; extrapolate)
end

### Lagrange Interpolation
struct LagrangeInterpolation{uType, tType, FT, T, bcacheType} <:
Expand All @@ -38,22 +42,27 @@ struct LagrangeInterpolation{uType, tType, FT, T, bcacheType} <:
t::tType
n::Int
bcache::bcacheType
function LagrangeInterpolation{FT}(u, t, n) where {FT}
extrapolate::Bool
function LagrangeInterpolation{FT}(u, t, n, extrapolate) where {FT}
bcache = zeros(eltype(u[1]), n + 1)
fill!(bcache, NaN)
new{typeof(u), typeof(t), FT, eltype(u), typeof(bcache)}(u, t, n, bcache)
new{typeof(u), typeof(t), FT, eltype(u), typeof(bcache)}(u,
t,
n,
bcache,
extrapolate)
end
end

function LagrangeInterpolation(u, t, n = nothing)
function LagrangeInterpolation(u, t, n = nothing; extrapolate = true)
u, t = munge_data(u, t)
if isnothing(n)
n = length(t) - 1 # degree
end
if n != length(t) - 1
error("Currently only n=length(t) - 1 is supported")
end
LagrangeInterpolation{true}(u, t, n)
LagrangeInterpolation{true}(u, t, n, extrapolate)
end

### Akima Interpolation
Expand All @@ -64,17 +73,19 @@ struct AkimaInterpolation{uType, tType, bType, cType, dType, FT, T} <:
b::bType
c::cType
d::dType
function AkimaInterpolation{FT}(u, t, b, c, d) where {FT}
extrapolate::Bool
function AkimaInterpolation{FT}(u, t, b, c, d, extrapolate) where {FT}
new{typeof(u), typeof(t), typeof(b), typeof(c),
typeof(d), FT, eltype(u)}(u,
t,
b,
c,
d)
d,
extrapolate)
end
end

function AkimaInterpolation(u, t)
function AkimaInterpolation(u, t; extrapolate = true)
u, t = munge_data(u, t)
n = length(t)
dt = diff(t)
Expand All @@ -96,22 +107,23 @@ function AkimaInterpolation(u, t)
c = (3.0 .* m[3:(end - 2)] .- 2.0 .* b[1:(end - 1)] .- b[2:end]) ./ dt
d = (b[1:(end - 1)] .+ b[2:end] .- 2.0 .* m[3:(end - 2)]) ./ dt .^ 2

AkimaInterpolation{true}(u, t, b, c, d)
AkimaInterpolation{true}(u, t, b, c, d, extrapolate)
end

### ConstantInterpolation Interpolation
struct ConstantInterpolation{uType, tType, dirType, FT, T} <: AbstractInterpolation{FT, T}
u::uType
t::tType
dir::Symbol # indicates if value to the $dir should be used for the interpolation
function ConstantInterpolation{FT}(u, t, dir) where {FT}
new{typeof(u), typeof(t), typeof(dir), FT, eltype(u)}(u, t, dir)
extrapolate::Bool
function ConstantInterpolation{FT}(u, t, dir, extrapolate) where {FT}
new{typeof(u), typeof(t), typeof(dir), FT, eltype(u)}(u, t, dir, extrapolate)
end
end

function ConstantInterpolation(u, t; dir = :left)
function ConstantInterpolation(u, t; dir = :left, extrapolate = true)
u, t = munge_data(u, t)
ConstantInterpolation{true}(u, t, dir)
ConstantInterpolation{true}(u, t, dir, extrapolate)
end

Base.@deprecate_binding ZeroSpline ConstantInterpolation
Expand All @@ -124,17 +136,21 @@ struct QuadraticSpline{uType, tType, tAType, dType, zType, FT, T} <:
tA::tAType
d::dType
z::zType
function QuadraticSpline{FT}(u, t, tA, d, z) where {FT}
extrapolate::Bool
function QuadraticSpline{FT}(u, t, tA, d, z, extrapolate) where {FT}
new{typeof(u), typeof(t), typeof(tA),
typeof(d), typeof(z), FT, eltype(u)}(u,
t,
tA,
d,
z)
z,
extrapolate)
end
end

function QuadraticSpline(u::uType, t) where {uType <: AbstractVector{<:Number}}
function QuadraticSpline(u::uType,
t;
extrapolate = true) where {uType <: AbstractVector{<:Number}}
u, t = munge_data(u, t)
s = length(t)
dl = ones(eltype(t), s - 1)
Expand All @@ -147,10 +163,10 @@ function QuadraticSpline(u::uType, t) where {uType <: AbstractVector{<:Number}}

d = map(i -> i == 1 ? typed_zero : 2 // 1 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]), 1:s)
z = tA \ d
QuadraticSpline{true}(u, t, tA, d, z)
QuadraticSpline{true}(u, t, tA, d, z, extrapolate)
end

function QuadraticSpline(u::uType, t) where {uType <: AbstractVector}
function QuadraticSpline(u::uType, t; extrapolate = true) where {uType <: AbstractVector}
u, t = munge_data(u, t)
s = length(t)
dl = ones(eltype(t), s - 1)
Expand All @@ -163,7 +179,7 @@ function QuadraticSpline(u::uType, t) where {uType <: AbstractVector}
d = transpose(reshape(reduce(hcat, d_), :, s))
z_ = reshape(transpose(tA \ d), size(u[1])..., :)
z = [z_s for z_s in eachslice(z_, dims = ndims(z_))]
QuadraticSpline{true}(u, t, tA, d, z)
QuadraticSpline{true}(u, t, tA, d, z, extrapolate)
end

# Cubic Spline Interpolation
Expand All @@ -172,12 +188,19 @@ struct CubicSpline{uType, tType, hType, zType, FT, T} <: AbstractInterpolation{F
t::tType
h::hType
z::zType
function CubicSpline{FT}(u, t, h, z) where {FT}
new{typeof(u), typeof(t), typeof(h), typeof(z), FT, eltype(u)}(u, t, h, z)
extrapolate::Bool
function CubicSpline{FT}(u, t, h, z, extrapolate) where {FT}
new{typeof(u), typeof(t), typeof(h), typeof(z), FT, eltype(u)}(u,
t,
h,
z,
extrapolate)
end
end

function CubicSpline(u::uType, t) where {uType <: AbstractVector{<:Number}}
function CubicSpline(u::uType,
t;
extrapolate = true) where {uType <: AbstractVector{<:Number}}
u, t = munge_data(u, t)
n = length(t) - 1
h = vcat(0, map(k -> t[k + 1] - t[k], 1:(length(t) - 1)), 0)
Expand All @@ -194,10 +217,10 @@ function CubicSpline(u::uType, t) where {uType <: AbstractVector{<:Number}}
6(u[i + 1] - u[i]) / h[i + 1] - 6(u[i] - u[i - 1]) / h[i],
1:(n + 1))
z = tA \ d
CubicSpline{true}(u, t, h[1:(n + 1)], z)
CubicSpline{true}(u, t, h[1:(n + 1)], z, extrapolate)
end

function CubicSpline(u::uType, t) where {uType <: AbstractVector}
function CubicSpline(u::uType, t; extrapolate = true) where {uType <: AbstractVector}
u, t = munge_data(u, t)
n = length(t) - 1
h = vcat(0, map(k -> t[k + 1] - t[k], 1:(length(t) - 1)), 0)
Expand All @@ -211,7 +234,7 @@ function CubicSpline(u::uType, t) where {uType <: AbstractVector}
d = transpose(reshape(reduce(hcat, d_), :, n + 1))
z_ = reshape(transpose(tA \ d), size(u[1])..., :)
z = [z_s for z_s in eachslice(z_, dims = ndims(z_))]
CubicSpline{true}(u, t, h[1:(n + 1)], z)
CubicSpline{true}(u, t, h[1:(n + 1)], z, extrapolate)
end

### BSpline Curve Interpolation
Expand All @@ -225,19 +248,29 @@ struct BSplineInterpolation{uType, tType, pType, kType, cType, FT, T} <:
c::cType # control points
pVecType::Symbol
knotVecType::Symbol
function BSplineInterpolation{FT}(u, t, d, p, k, c, pVecType, knotVecType) where {FT}
extrapolate::Bool
function BSplineInterpolation{FT}(u,
t,
d,
p,
k,
c,
pVecType,
knotVecType,
extrapolate) where {FT}
new{typeof(u), typeof(t), typeof(p), typeof(k), typeof(c), FT, eltype(u)}(u,
t,
d,
p,
k,
c,
pVecType,
knotVecType)
knotVecType,
extrapolate)
end
end

function BSplineInterpolation(u, t, d, pVecType, knotVecType)
function BSplineInterpolation(u, t, d, pVecType, knotVecType; extrapolate = true)
u, t = munge_data(u, t)
n = length(t)
s = zero(eltype(u))
Expand Down Expand Up @@ -298,7 +331,7 @@ function BSplineInterpolation(u, t, d, pVecType, knotVecType)
# control points
N = spline_coefficients(n, d, k, p)
c = vec(N \ u[:, :])
BSplineInterpolation{true}(u, t, d, p, k, c, pVecType, knotVecType)
BSplineInterpolation{true}(u, t, d, p, k, c, pVecType, knotVecType, extrapolate)
end

### BSpline Curve Approx
Expand All @@ -313,7 +346,17 @@ struct BSplineApprox{uType, tType, pType, kType, cType, FT, T} <:
c::cType # control points
pVecType::Symbol
knotVecType::Symbol
function BSplineApprox{FT}(u, t, d, h, p, k, c, pVecType, knotVecType) where {FT}
extrapolate::Bool
function BSplineApprox{FT}(u,
t,
d,
h,
p,
k,
c,
pVecType,
knotVecType,
extrapolate) where {FT}
new{typeof(u), typeof(t), typeof(p), typeof(k), typeof(c), FT, eltype(u)}(u,
t,
d,
Expand All @@ -322,11 +365,12 @@ struct BSplineApprox{uType, tType, pType, kType, cType, FT, T} <:
k,
c,
pVecType,
knotVecType)
knotVecType,
extrapolate)
end
end

function BSplineApprox(u, t, d, h, pVecType, knotVecType)
function BSplineApprox(u, t, d, h, pVecType, knotVecType; extrapolate = true)
u, t = munge_data(u, t)
n = length(t)
s = zero(eltype(u))
Expand Down Expand Up @@ -409,5 +453,5 @@ function BSplineApprox(u, t, d, h, pVecType, knotVecType)
M = transpose(N) * N
P = M \ Q
c[2:(end - 1)] .= vec(P)
BSplineApprox{true}(u, t, d, h, p, k, c, pVecType, knotVecType)
BSplineApprox{true}(u, t, d, h, p, k, c, pVecType, knotVecType, extrapolate)
end
8 changes: 7 additions & 1 deletion src/interpolation_methods.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
_interpolate(interp, t) = _interpolate(interp, t, firstindex(interp.t) - 1)[1]
function _interpolate(interp, t)
((t < interp.t[1] || t > interp.t[end]) && !interp.extrapolate) &&
throw(ExtrapolationError())
_interpolate(interp, t, firstindex(interp.t) - 1)[1]
end

# Linear Interpolation
function _interpolate(A::LinearInterpolation{<:AbstractVector}, t::Number, iguess)
Expand Down Expand Up @@ -53,6 +57,7 @@ end

# Lagrange Interpolation
function _interpolate(A::LagrangeInterpolation{<:AbstractVector}, t::Number)
((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError())
idxs = findRequiredIdxs(A, t)
if A.t[idxs[1]] == t
return A.u[idxs[1]]
Expand Down Expand Up @@ -81,6 +86,7 @@ function _interpolate(A::LagrangeInterpolation{<:AbstractVector}, t::Number)
end

function _interpolate(A::LagrangeInterpolation{<:AbstractMatrix}, t::Number)
((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError())
idxs = findRequiredIdxs(A, t)
if A.t[idxs[1]] == t
return A.u[:, idxs[1]]
Expand Down
2 changes: 1 addition & 1 deletion src/interpolation_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,4 @@ function searchsortedlastcorrelated(v::AbstractVector, x, guess)
end

searchsortedfirstcorrelated(r::AbstractRange, x, _) = searchsortedfirst(r, x)
searchsortedlastcorrelated(r::AbstractRange, x, _) = searchsortedlast(r, x)
searchsortedlastcorrelated(r::AbstractRange, x, _) = searchsortedlast(r, x)
Loading