-
Notifications
You must be signed in to change notification settings - Fork 18
/
forward.hpp
69 lines (60 loc) · 1.82 KB
/
forward.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#include <cmath>
#include <iostream>
#include <type_traits>
namespace autodiff {

/// Forward-mode dual number: a primal value paired with a tangent
/// (derivative) component. The operators below propagate the tangent
/// through each operation by the chain rule.
///
/// @tparam T type of the value and tangent (a built-in arithmetic type,
///           or potentially another dual<> for nested use).
template <typename T>
class dual {
  T val_;
  T tan_;

 public:
  /// Construct from a value and a tangent; both default to zero.
  /// Non-explicit, so a plain T converts implicitly to a dual whose
  /// tangent is zero (a constant).
  dual(const T& val = 0, const T& tan = 0)
      : val_(val), tan_(tan) { }

  /// @return the primal value.
  const T& val() const { return val_; }

  /// @return the tangent (derivative) component.
  const T& tan() const { return tan_; }
};

/// Exponential. Since d/dx exp(x) = exp(x), the result's tangent is
/// x.tan() * exp(x.val()).
template <typename T>
inline dual<T> exp(const dual<T>& x) {
  // Unqualified call: std::exp for built-in T; ADL can still find
  // autodiff::exp when T is itself a dual.
  using std::exp;
  const T y = exp(x.val());
  return dual<T>(y, x.tan() * y);
}

/// Natural logarithm. Since d/dx log(x) = 1 / x, the result's tangent is
/// x.tan() / x.val(). (Dividing by log(x.val()) would be a wrong derivative.)
template <typename T>
inline dual<T> log(const dual<T>& x) {
  using std::log;
  return dual<T>(log(x.val()), x.tan() / x.val());
}

/// Product of two duals (product rule):
/// (a, a') * (b, b') = (a*b, a'*b + a*b').
template <typename T>
inline dual<T> operator*(const dual<T>& x1, const dual<T>& x2) {
  return dual<T>(x1.val() * x2.val(),
                 x1.tan() * x2.val() + x1.val() * x2.tan());
}

/// Dual times an arithmetic scalar: both components are scaled.
/// The result is built as dual<T> explicitly so that a scalar of a
/// different type (e.g. double with dual<float>) still yields dual<T>.
template <typename T, typename U,
          typename = std::enable_if_t<std::is_arithmetic_v<U>>>
inline dual<T> operator*(const dual<T>& x1, const U& c2) {
  return dual<T>(x1.val() * c2, x1.tan() * c2);
}

/// Arithmetic scalar times a dual: both components are scaled.
template <typename T, typename U,
          typename = std::enable_if_t<std::is_arithmetic_v<U>>>
inline dual<T> operator*(const U& c1, const dual<T>& x2) {
  return dual<T>(c1 * x2.val(), c1 * x2.tan());
}

/// Sum of two duals: values and tangents add componentwise.
template <typename T>
inline dual<T> operator+(const dual<T>& x1, const dual<T>& x2) {
  return dual<T>(x1.val() + x2.val(), x1.tan() + x2.tan());
}

/// Dual plus an arithmetic scalar: adding a constant leaves the tangent
/// unchanged.
template <typename T, typename U,
          typename = std::enable_if_t<std::is_arithmetic_v<U>>>
inline dual<T> operator+(const dual<T>& x1, const U& c2) {
  return dual<T>(x1.val() + c2, x1.tan());
}

/// Arithmetic scalar plus a dual: the tangent is unchanged.
template <typename T, typename U,
          typename = std::enable_if_t<std::is_arithmetic_v<U>>>
inline dual<T> operator+(const U& c1, const dual<T>& x2) {
  return dual<T>(c1 + x2.val(), x2.tan());
}

/// Write a dual as "<value, tangent>".
template <typename T>
std::ostream& operator<<(std::ostream& o, const dual<T>& y) {
  o << "<" << y.val() << ", " << y.tan() << ">";
  return o;
}

}  // namespace autodiff