From c6bc594485b2461fbbff77f9c175f1bd8f16b722 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 22 Dec 2021 10:15:39 -0500 Subject: [PATCH] add p-norm --- stdlib/LinearAlgebra/src/generic.jl | 68 ++++++++++++++++------------- 1 file changed, 37 insertions(+), 31 deletions(-) diff --git a/stdlib/LinearAlgebra/src/generic.jl b/stdlib/LinearAlgebra/src/generic.jl index 775c096dc58b31..f94698cbb529ee 100644 --- a/stdlib/LinearAlgebra/src/generic.jl +++ b/stdlib/LinearAlgebra/src/generic.jl @@ -648,53 +648,59 @@ end norm(::Missing, p::Real=2) = missing # With dims keyword -norm0_dims!(B, A, dims) = count!(!iszero, B, A) -norm1_dims!(B, A, dims) = Base.mapreducedim!(norm, +, B, A) -normInf_dims!(B, A, dims) = Base.mapreducedim!(norm, max, B, A) -normMinusInf_dims!(B, A, dims) = Base.mapreducedim!(norm, min, B, A) +norm0_dims!(B, A) = count!(!iszero, B, A) +norm1_dims!(B, A) = Base.mapreducedim!(norm, +, B, A) +normInf_dims!(B, A) = Base.mapreducedim!(norm, max, B, A) +normMinusInf_dims!(B, A) = Base.mapreducedim!(norm, min, B, A) function norm2_dims!(B::AbstractArray, A::AbstractArray, dims) - sum!(norm_sqr, B, A) + sum!(LinearAlgebra.norm_sqr, B, A) map!(sqrt, B, B) - # Checking whether `A` is safe for the fast path is slower than taking it, check later: + # Checking whether `A` is safe for the fast path is slower than taking it, + # so check and fix any zero/infinite answers afterwards: + _norm_dims_check!(B, A, dims, LinearAlgebra.norm2) + B +end + +function normp_dims!(B::AbstractArray, A::AbstractArray, p::Real, dims) + if p == 0.5 + sum!(sqrt ∘ norm, B, A) + map!(abs2, B, B) + elseif p == 3 + sum!(x -> norm(x)^3, B, A) + map!(cbrt, B, B) + else + sum!(x -> norm(x)^p, B, A) + invp = inv(p) + map!(x -> x^invp, B, B) + end + _norm_dims_check!(B, A, dims, LinearAlgebra.normp, p) + B +end + +function _norm_dims_check!(B, A, dims, norm, args...) if A isa AbstractVecOrMat && dims == 1 for (i,x) in zip(eachindex(B), eachcol(A)) !iszero(B[i]) && isfinite(B[i]) && continue - B[i] = norm2(x) + B[i] = norm(x, args...) end elseif A isa AbstractVecOrMat && dims == 2 for (i,x) in zip(eachindex(B), eachrow(A)) !iszero(B[i]) && isfinite(B[i]) && continue - B[i] = norm2(x) + B[i] = norm(x, args...) end - # In general `eachslice(A; dims)` is not what we need here. elseif all(y -> !iszero(y) && isfinite(y), B) for I in CartesianIndices(B) !iszero(B[I]) && isfinite(B[I]) && continue - # This path is quite slow, but hopefully rare. + # Unfortunately `eachslice(A; dims)` is not what we need here. + # This path is not type-stable, so quite slow, but hopefully rare. J = ntuple(d -> d in dims ? Colon() : I[d], ndims(A)) - B[I] = norm2(view(A, J...)) + B[I] = norm(view(A, J...), args...) end end B end -function normp_dims!(B::AbstractArray, A::AbstractArray, p::Real, dims) - if A isa AbstractVecOrMat && dims == 1 - for (i,x) in zip(eachindex(B), eachcol(A)) - B[i] = normp(x, p) - end - elseif A isa AbstractVecOrMat && dims == 2 - for (i,x) in zip(eachindex(B), eachrow(A)) - B[i] = normp(x, p) - end - else - # This is slower, but doesn't affect type-stability of `norm` - copyto!(B, Base.mapslices(x -> normp(x,p), A; dims)) - end - B -end - """ norm(A::AbstractArray, [p]; dims) @@ -758,13 +764,13 @@ function norm(A::AbstractArray, p::Real=2; dims=:) if p == 2 norm2_dims!(B, A, dims) elseif p == 1 - norm1_dims!(B, A, dims) + norm1_dims!(B, A) elseif p == Inf - normInf_dims!(B, A, dims) + normInf_dims!(B, A) elseif p == 0 - norm0_dims!(B, A, dims) + norm0_dims!(B, A) elseif p == -Inf - normMinusInf_dims!(B, A, dims) + normMinusInf_dims!(B, A) else normp_dims!(B, A, p, dims) end