-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathSingleEntryVector.jl
107 lines (90 loc) · 2.89 KB
/
SingleEntryVector.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""
    SingleEntryVector{N} <: AbstractVector{N}

A sparse vector of length `n` with a single stored entry `v` at index `i`; all
other entries are `zero(N)`. For `v = one(N)` it is the `i`-th canonical unit
vector.

### Fields

- `i` -- index of the stored entry
- `n` -- vector length
- `v` -- value of the stored entry

### Notes

The default constructor does not validate its arguments; nothing checks
`1 <= i <= n` at construction time.
"""
struct SingleEntryVector{N} <: AbstractVector{N}
    i::Int  # position of the stored entry
    n::Int  # length of the vector
    v::N    # value stored at position `i`
end
# Convenience constructor: the stored entry defaults to the multiplicative
# identity `one(N)`, i.e. the `i`-th canonical unit vector of length `n`.
function SingleEntryVector{N}(i::Int, n::Int) where {N}
    unit_value = one(N)
    return SingleEntryVector{N}(i, n, unit_value)
end
# Entry `i` of `e`: the stored value `e.v` at index `e.i`, `zero(N)` elsewhere.
# Out-of-range indices raise a `BoundsError` (the AbstractArray convention);
# `checkbounds` replaces an `@assert`, which raised `AssertionError` and may be
# disabled. `@inline` lets a caller's `@inbounds` elide the `@boundscheck`.
@inline function Base.getindex(e::SingleEntryVector{N}, i::Int) where {N}
    @boundscheck checkbounds(e, i)
    return i == e.i ? e.v : zero(N)
end
# A SingleEntryVector is one-dimensional with `e.n` entries.
function Base.size(e::SingleEntryVector)
    len = e.n
    return (len,)
end
# Matrix-vector multiplication with a SingleEntryVector: `A * e` is the
# `e.i`-th column of `A` scaled by `e.v`, so only that column is read.
# Due to type piracy in other packages, the matrix types are enumerated
# explicitly here instead of dispatching on `AbstractMatrix`.
for MT in (:Matrix, :AbstractSparseMatrix)
    @eval function Base.:(*)(A::$MT, e::SingleEntryVector)
        # without this check, a matrix with more columns than `e.n` would
        # silently return a column instead of raising a dimension error
        if size(A, 2) != e.n
            throw(DimensionMismatch("matrix has $(size(A, 2)) columns, but " *
                                    "the vector has length $(e.n)"))
        end
        return A[:, e.i] * e.v
    end
end
# Multiplication with a diagonal matrix: `D * e` is again a SingleEntryVector
# whose stored entry is scaled by the matching diagonal element `D.diag[e.i]`.
# Throws a `DimensionMismatch` if the size of `D` does not match `e.n`.
function Base.:(*)(D::Diagonal{N,V},
                   e::SingleEntryVector{N}) where {N,V<:AbstractVector{N}}
    # `D.diag[e.i]` succeeds for any diagonal with at least `e.i` elements, so
    # check explicitly instead of returning a vector of the wrong length
    if size(D, 2) != e.n
        throw(DimensionMismatch("matrix has $(size(D, 2)) columns, but " *
                                "the vector has length $(e.n)"))
    end
    return SingleEntryVector(e.i, e.n, D.diag[e.i] * e.v)
end
# Unary minus: flip the sign of the stored entry; index and length are kept.
Base.:(-)(e::SingleEntryVector{N}) where {N} = SingleEntryVector(e.i, e.n, -e.v)
# Addition and subtraction of two SingleEntryVectors of the same length.
# If both stored entries sit at the same index, the result is again a
# SingleEntryVector; otherwise the result has two stored entries and is
# returned as a sparse vector (the return type depends on the indices).

# shared implementation of `+` and `-`; `op` is the operator itself
function _single_entry_arith(op::F, e1::SingleEntryVector{N},
                             e2::SingleEntryVector{N}) where {F,N}
    e1.n == e2.n ||
        throw(DimensionMismatch("dimensions must match, but they are $(length(e1)) and $(length(e2)) respectively"))
    e1.i == e2.i && return SingleEntryVector(e1.i, e1.n, op(e1.v, e2.v))
    out = spzeros(N, e1.n)
    @inbounds out[e1.i] = e1.v
    # unary `op`: identity for `+`, negation for `-`
    @inbounds out[e2.i] = op(e2.v)
    return out
end

Base.:(+)(e1::SingleEntryVector{N}, e2::SingleEntryVector{N}) where {N} =
    _single_entry_arith(+, e1, e2)

Base.:(-)(e1::SingleEntryVector{N}, e2::SingleEntryVector{N}) where {N} =
    _single_entry_arith(-, e1, e2)
# Bilinear form `transpose(e1) * A * e2`: only the entry `A[e1.i, e2.i]`
# contributes, scaled by both stored values.
# Throws a `DimensionMismatch` unless `size(A) == (e1.n, e2.n)`.
function inner(e1::SingleEntryVector{N}, A::AbstractMatrix{N},
               e2::SingleEntryVector{N}) where {N}
    # indexing alone would accept any matrix large enough to contain
    # `(e1.i, e2.i)`; require the shapes to agree as for a dense product
    if size(A) != (e1.n, e2.n)
        throw(DimensionMismatch("matrix has size $(size(A)), but the vectors " *
                                "have lengths $(e1.n) and $(e2.n)"))
    end
    return A[e1.i, e2.i] * e1.v * e2.v
end
# p-norm of a SingleEntryVector. With at most one non-zero entry,
# ‖e‖_p = |e.v| for every p > 0 (including p = Inf).
# For p = 0, LinearAlgebra's `norm` counts the non-zero entries instead, which
# here is 1 or 0. That count is returned with the type of `abs(e.v)` so the
# method stays type-stable. Negative `p` is not treated specially.
function LinearAlgebra.norm(e::SingleEntryVector, p::Real=Inf)
    a = abs(e.v)
    if iszero(p)
        return iszero(e.v) ? zero(a) : one(a)
    end
    return a
end
# Pad `e` with `n` trailing zeros: the stored entry keeps its index and the
# length grows by `n`.
append_zeros(e::SingleEntryVector, n::Int) = SingleEntryVector(e.i, n + e.n, e.v)
# Pad `e` with `n` leading zeros: both the index of the stored entry and the
# length shift by `n`.
function prepend_zeros(e::SingleEntryVector, n::Int)
    shifted_index = e.i + n
    new_length = e.n + n
    return SingleEntryVector(shifted_index, new_length, e.v)
end
# Distance between two SingleEntryVectors of equal length, i.e. the p-norm of
# their difference ‖e1 - e2‖_p (keyword `p`, default `N(2)`).
#
# - same index: the difference has one entry, so the result is |e1.v - e2.v|
# - different indices: the difference has the entries e1.v and -e2.v, so the
#   result is max(|e1.v|, |e2.v|) if `p` is infinite (`isinf`) and
#   (|e1.v|^p + |e2.v|^p)^(1/p) otherwise
#
# Throws a `DimensionMismatch` if the lengths differ.
function distance(e1::SingleEntryVector{N}, e2::SingleEntryVector{N};
                  p::Real=N(2)) where {N}
    if e1.n != e2.n
        throw(DimensionMismatch("dimensions must match, but they are " *
                                "$(length(e1)) and $(length(e2)) respectively"))
    end
    if e1.i == e2.i
        return abs(e1.v - e2.v)
    end
    a = abs(e1.v)
    b = abs(e2.v)
    if isinf(p)
        return max(a, b)
    end
    m = max(a, b)
    if iszero(m)
        # both entries are zero: the unscaled formula cannot overflow and
        # yields a zero of the usual result type
        return (a^p + b^p)^(1 / p)
    end
    # factor out the larger magnitude before raising to the power `p`; the
    # unscaled `a^p + b^p` overflows for large entries (to `Inf` for floats,
    # and silently wraps for integer `N`, e.g. `(2^40)^2 == 0` for `Int64`)
    return m * ((a / m)^p + (b / m)^p)^(1 / p)
end