-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmath.zig
83 lines (68 loc) · 3.27 KB
/
math.zig
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
const std = @import("std");
////////////////////////////////////////////////////////////////////////////////
// Utility Functions
////////////////////////////////////////////////////////////////////////////////
/// Return the additive identity (zero) of `T`.
///
/// Supported types are `f32` and `std.math.Complex(f32)`. Any other `T`
/// is rejected at compile time.
pub fn zero(comptime T: type) T {
    if (T == std.math.Complex(f32)) {
        return .{ .re = 0, .im = 0 };
    } else if (T == f32) {
        return 0;
    } else {
        // `T` is comptime-known, so this branch is only analyzed for
        // unsupported types: fail the build instead of reaching
        // `unreachable` at runtime (a panic in safe modes, UB in ReleaseFast).
        @compileError("zero: unsupported type " ++ @typeName(T));
    }
}
/// Return `x - y`.
///
/// Supported types are `f32` and `std.math.Complex(f32)` (the latter
/// delegates to `Complex.sub`). Any other `T` is rejected at compile time.
pub fn sub(comptime T: type, x: T, y: T) T {
    if (T == std.math.Complex(f32)) {
        return x.sub(y);
    } else if (T == f32) {
        return x - y;
    } else {
        // Only analyzed for unsupported `T`; reject at compile time rather
        // than via a runtime `unreachable`.
        @compileError("sub: unsupported type " ++ @typeName(T));
    }
}
/// Multiply `x` by the real `scalar`.
///
/// For `std.math.Complex(f32)` both the real and imaginary parts are
/// scaled. For `f32` this is plain multiplication. Any other `T` is
/// rejected at compile time.
pub fn scalarMul(comptime T: type, x: T, scalar: f32) T {
    if (T == std.math.Complex(f32)) {
        return .{ .re = x.re * scalar, .im = x.im * scalar };
    } else if (T == f32) {
        return x * scalar;
    } else {
        // Only analyzed for unsupported `T`; reject at compile time rather
        // than via a runtime `unreachable`.
        @compileError("scalarMul: unsupported type " ++ @typeName(T));
    }
}
/// Divide `x` by the real `scalar`.
///
/// For `std.math.Complex(f32)` both the real and imaginary parts are
/// divided. For `f32` this is plain division. Any other `T` is rejected
/// at compile time.
pub fn scalarDiv(comptime T: type, x: T, scalar: f32) T {
    if (T == std.math.Complex(f32)) {
        return .{ .re = x.re / scalar, .im = x.im / scalar };
    } else if (T == f32) {
        return x / scalar;
    } else {
        // Only analyzed for unsupported `T`; reject at compile time rather
        // than via a runtime `unreachable`.
        @compileError("scalarDiv: unsupported type " ++ @typeName(T));
    }
}
/// Return the sum of element-wise products `x[0]*y[0] + ... + x[n-1]*y[n-1]`.
///
/// No complex conjugation is applied to either operand. Supported
/// `(T, U)` pairs, with `C = std.math.Complex(f32)`:
///   - `(C, C)`:     complex multiply-accumulate
///   - `(C, f32)`:   each complex `x[i]` scaled by the real `y[i]`
///   - `(f32, f32)`: ordinary dot product
/// Any other combination is rejected at compile time.
///
/// `x` and `y` must have the same length (asserted). Empty slices yield
/// `zero(T)`.
pub fn innerProduct(comptime T: type, comptime U: type, x: []const T, y: []const U) T {
    const C32 = std.math.Complex(f32);
    std.debug.assert(x.len == y.len);
    var acc = zero(T);
    // Parallel iteration over both slices. Safe builds additionally check
    // that the lengths agree.
    if (T == C32 and U == C32) {
        for (x, y) |a, b| acc = acc.add(a.mul(b));
    } else if (T == C32 and U == f32) {
        for (x, y) |a, b| acc = acc.add(.{ .re = a.re * b, .im = a.im * b });
    } else if (T == f32 and U == f32) {
        for (x, y) |a, b| acc += a * b;
    } else {
        // `T`/`U` are comptime-known, so this branch is only analyzed for
        // unsupported combinations; fail the build instead of hitting a
        // runtime `unreachable`.
        @compileError("innerProduct: unsupported types " ++ @typeName(T) ++ ", " ++ @typeName(U));
    }
    return acc;
}
////////////////////////////////////////////////////////////////////////////////
// Tests
////////////////////////////////////////////////////////////////////////////////
test "zero" {
    // Additive identity for both supported element types.
    const C32 = std.math.Complex(f32);
    const testing = std.testing;
    try testing.expectEqual(C32.init(0, 0), zero(C32));
    try testing.expectEqual(@as(f32, 0), zero(f32));
}
test "sub" {
    // Subtraction for complex and real operands.
    const C32 = std.math.Complex(f32);
    const testing = std.testing;
    const minuend = C32.init(2, 3);
    const subtrahend = C32.init(1, 1);
    try testing.expectEqual(C32.init(1, 2), sub(C32, minuend, subtrahend));
    try testing.expectEqual(@as(f32, 2), sub(f32, 4, 2));
}
test "scalarMul" {
    // Scaling by a real factor for complex and real values.
    const C32 = std.math.Complex(f32);
    const testing = std.testing;
    const value = C32.init(1, 2);
    try testing.expectEqual(C32.init(3, 6), scalarMul(C32, value, 3));
    try testing.expectEqual(@as(f32, 6), scalarMul(f32, 2, 3));
}
test "scalarDiv" {
    // Division by a real divisor for complex and real values.
    const C32 = std.math.Complex(f32);
    const testing = std.testing;
    const value = C32.init(3, 6);
    try testing.expectEqual(C32.init(1, 2), scalarDiv(C32, value, 3));
    try testing.expectEqual(@as(f32, 2), scalarDiv(f32, 6, 3));
}
test "innerProduct" {
    // Covers every supported (T, U) pairing; note that no conjugation
    // is applied in the complex-complex case.
    const C32 = std.math.Complex(f32);
    const testing = std.testing;
    const lhs = [3]C32{ C32.init(1, 2), C32.init(2, 3), C32.init(3, 4) };
    const rhs = [3]C32{ C32.init(4, 5), C32.init(5, 6), C32.init(6, 7) };
    const reals = [3]f32{ 1, 2, 3 };
    const more_reals = [3]f32{ 4, 5, 6 };
    try testing.expectEqual(C32.init(-24, 85), innerProduct(C32, C32, &lhs, &rhs));
    try testing.expectEqual(C32.init(14, 20), innerProduct(C32, f32, &lhs, &reals));
    try testing.expectEqual(@as(f32, 32), innerProduct(f32, f32, &reals, &more_reals));
}