-
Notifications
You must be signed in to change notification settings - Fork 0
/
RbfKernel.cs
129 lines (98 loc) · 2.98 KB
/
RbfKernel.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Grammophone.Vectors;
namespace Grammophone.Kernels
{
/// <summary>
/// Gaussian RBF kernel for vectors.
/// </summary>
/// <remarks>
/// Besides pairwise evaluation via <see cref="Compute"/>, the kernel keeps a list of
/// weighted vectors ("components") added through <see cref="AddComponent"/>;
/// <see cref="ComputeSum"/> evaluates the weighted sum of the kernel between its
/// argument and every stored component.
/// </remarks>
[Serializable]
public class RbfKernel : Kernel<Vector>
{
    #region Auxiliary types

    /// <summary>
    /// A weighted vector stored in the kernel and used by <see cref="ComputeSum"/>.
    /// </summary>
    [Serializable]
    private struct Component
    {
        /// <summary>The multiplier applied to this component's kernel value.</summary>
        public double Weight;

        /// <summary>The vector of this component.</summary>
        public Vector Vector;
    }

    #endregion

    #region Private members

    /// <summary>
    /// The components accumulated via <see cref="AddComponent"/>.
    /// </summary>
    private IList<Component> components;

    #endregion

    #region Construction

    /// <summary>
    /// Create.
    /// </summary>
    /// <param name="σ2">The variance of the Gaussian. Must be a positive number.</param>
    /// <param name="dimensionality">
    /// The dimensionality of the vectors processed by the kernel.
    /// </param>
    /// <exception cref="ArgumentException">
    /// When <paramref name="σ2"/> is not positive (including NaN)
    /// or <paramref name="dimensionality"/> is negative.
    /// </exception>
    public RbfKernel(double σ2, int dimensionality)
    {
        // Written as a negated '>' so that NaN, for which every comparison is false,
        // is rejected as well instead of silently turning all kernel values into NaN.
        if (!(σ2 > 0.0)) throw new ArgumentException("The variance must be positive.", "σ2");

        if (dimensionality < 0)
            throw new ArgumentException("dimensionality must be non-negative.", "dimensionality");

        this.σ2 = σ2;
        this.Dimensionality = dimensionality;

        this.components = new List<Component>();
    }

    #endregion

    #region Public properties

    /// <summary>
    /// The variance of the Gaussian.
    /// </summary>
    public double σ2 { get; private set; }

    /// <summary>
    /// The dimensionality of the vectors processed by the kernel.
    /// </summary>
    public int Dimensionality { get; private set; }

    #endregion

    #region Kernel<Vector> implementation

    /// <summary>
    /// True when at least one component has been added and not yet cleared.
    /// </summary>
    public override bool HasComponents
    {
        get { return this.components.Count > 0; }
    }

    /// <summary>
    /// Evaluate the kernel between two vectors as
    /// exp(-(arg1.Norm2 + arg2.Norm2 - 2 * arg1 * arg2) / (2 * σ2)).
    /// </summary>
    /// <param name="arg1">The first vector.</param>
    /// <param name="arg2">The second vector.</param>
    /// <returns>The kernel value.</returns>
    /// <exception cref="ArgumentNullException">When either argument is null.</exception>
    /// <exception cref="ArgumentException">
    /// When the length of either argument differs from <see cref="Dimensionality"/>.
    /// </exception>
    public override double Compute(Vector arg1, Vector arg2)
    {
        // Null checks come first for both arguments, then the length checks.
        if (arg1 == null) throw new ArgumentNullException("arg1");
        if (arg2 == null) throw new ArgumentNullException("arg2");

        EnsureDimensionality(arg1, "arg1");
        // The second argument must be validated too, as a mismatched vector
        // would otherwise reach the product below unchecked.
        EnsureDimensionality(arg2, "arg2");

        // NOTE(review): presumably Norm2 is the squared Euclidean norm, making the
        // numerator the expanded squared distance between the vectors - confirm in Vector.
        return Math.Exp(-(arg1.Norm2 + arg2.Norm2 - 2 * arg1 * arg2) / (2 * σ2));
    }

    /// <summary>
    /// Evaluate the weighted sum of the kernel between <paramref name="arg"/>
    /// and every stored component.
    /// </summary>
    /// <param name="arg">The vector to evaluate against the components.</param>
    /// <returns>
    /// The sum over all components of weight times kernel value; zero when there are no components.
    /// </returns>
    /// <exception cref="ArgumentNullException">When <paramref name="arg"/> is null.</exception>
    /// <exception cref="ArgumentException">
    /// When the length of <paramref name="arg"/> differs from <see cref="Dimensionality"/>.
    /// </exception>
    public override double ComputeSum(Vector arg)
    {
        if (arg == null) throw new ArgumentNullException("arg");

        EnsureDimensionality(arg, "arg");

        // The argument's own norm does not depend on the component, so read it once.
        double innerArgDotArg = arg.Norm2;

        return this.components
            .Sum(c => c.Weight * Math.Exp(-(innerArgDotArg + c.Vector.Norm2 - 2 * arg * c.Vector) / (2 * σ2)));
    }

    /// <summary>
    /// Append a weighted vector to the components used by <see cref="ComputeSum"/>.
    /// </summary>
    /// <param name="weight">The weight of the component.</param>
    /// <param name="arg">The vector of the component. It is stored by reference, not copied.</param>
    /// <exception cref="ArgumentNullException">When <paramref name="arg"/> is null.</exception>
    /// <exception cref="ArgumentException">
    /// When the length of <paramref name="arg"/> differs from <see cref="Dimensionality"/>.
    /// </exception>
    public override void AddComponent(double weight, Vector arg)
    {
        if (arg == null) throw new ArgumentNullException("arg");

        EnsureDimensionality(arg, "arg");

        this.components.Add(new Component() { Weight = weight, Vector = arg });
    }

    /// <summary>
    /// Remove all stored components.
    /// </summary>
    public override void ClearComponents()
    {
        this.components.Clear();
    }

    /// <summary>
    /// Create a new kernel having the same variance and dimensionality
    /// as this one and no components.
    /// </summary>
    /// <returns>The new kernel.</returns>
    public override Kernel<Vector> ForkNew()
    {
        return new RbfKernel(this.σ2, this.Dimensionality);
    }

    #endregion

    #region Private methods

    /// <summary>
    /// Throw if the length of a (non-null) vector differs from the kernel's dimensionality.
    /// </summary>
    /// <param name="vector">The vector to check. Must not be null.</param>
    /// <param name="parameterName">The caller's parameter name, reported in the exception.</param>
    private void EnsureDimensionality(Vector vector, string parameterName)
    {
        if (vector.Length != this.Dimensionality)
            throw new ArgumentException(
                "The supplied vector is not compatible with the kernel's dimensionality",
                parameterName);
    }

    #endregion
}
}