-
Notifications
You must be signed in to change notification settings - Fork 33
/
Copy pathpairwise_distance.py
84 lines (62 loc) · 1.95 KB
/
pairwise_distance.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0
import argparse
from math import sqrt
from string import Template
from time import time
import dpctl
import dpctl.memory as dpctl_mem
import dpnp as np
import numba_dpex as ndpx
parser = argparse.ArgumentParser(
    description="Program to compute pairwise distance"
)
parser.add_argument("-n", type=int, default=10, help="Number of points")
parser.add_argument("-d", type=int, default=3, help="Dimensions")
parser.add_argument("-r", type=int, default=1, help="repeat")
parser.add_argument("-l", type=int, default=1, help="local_work_size")
args = parser.parse_args()

# Reject arguments the launch below cannot handle, with a normal CLI error
# instead of a failure inside the kernel launch:
#   * -d < 1: the kernel reads X[idx, 0], which does not exist for 0 columns;
#   * -n/-r/-l < 1: empty or invalid launch, or nothing to time/average;
#   * an nd-range launch requires the global size (-n) to be a multiple of
#     the work-group size (-l).
for flag, value in (
    ("-n", args.n),
    ("-d", args.d),
    ("-r", args.r),
    ("-l", args.l),
):
    if value < 1:
        parser.error(f"argument {flag}: must be a positive integer, got {value}")
if args.n % args.l != 0:
    parser.error(
        f"argument -l: local_work_size ({args.l}) must evenly divide "
        f"the number of points ({args.n})"
    )

# Global work size is equal to the number of points: work item i fills row i
# of D.
global_size = ndpx.Range(args.n)
# Work-group size for the nd-range launch (set with -l, default 1).
local_size = ndpx.Range(args.l)

# Input points (one per row) and the n x n output distance matrix, both
# single precision. D is left uninitialised; the kernel writes every entry.
X = np.random.random((args.n, args.d)).astype(np.single)
D = np.empty((args.n, args.n), dtype=np.single)
@ndpx.kernel
def pairwise_distance(X, D, xshape0, xshape1):
    """
    Euclidean pairwise distance, computed one output row per work item.

    The work item with global id ``row`` writes
    ``D[row, col] = sqrt(sum_k (X[row, k] - X[col, k]) ** 2)`` for every
    ``col`` in ``range(X.shape[0])``.

    Args:
        X: 2-D array of points, one point per row.
        D: 2-D output array; row ``row`` is overwritten.
        xshape0, xshape1: accepted but not read; the loop bounds are taken
            from ``X.shape`` instead.
    """
    row = ndpx.get_global_id(0)
    n_points, n_dims = X.shape

    # x - x yields a zero of X's element type. Presumably this is used instead
    # of a 0.0 literal (which numba types as float64) to keep the accumulator
    # in X's precision.
    zero = X[row, 0] - X[row, 0]
    for col in range(n_points):
        acc = zero
        for dim in range(n_dims):
            diff = X[row, dim] - X[col, dim]
            acc += diff * diff
        D[row, col] = sqrt(acc)
def driver():
    """
    Time ``args.r`` launches of the ``pairwise_distance`` kernel.

    Uses the module-level ``X``, ``D``, ``global_size`` and ``local_size``.
    One extra, untimed launch is made first so that the kernel's one-off
    compilation on first call is not counted in the first sample.

    NOTE(review): this assumes a launch returns only after the kernel has
    finished; if launches are asynchronous, only submission time is measured.

    Returns:
        list[float]: duration in seconds of each timed launch, in order.
    """
    # perf_counter is the clock intended for measuring short durations
    # (monotonic, highest available resolution), unlike time.time.
    from time import perf_counter

    def _launch():
        # One nd-range launch over all points with the configured work-group size.
        pairwise_distance[ndpx.NdRange(global_size, local_size)](
            X, D, X.shape[0], X.shape[1]
        )

    # Warm-up: the first call compiles the kernel; keep that out of the samples.
    _launch()

    times = []
    for _ in range(args.r):
        start = perf_counter()
        _launch()
        times.append(perf_counter() - start)

    return times
def main():
times = None
times = driver()
times = np.asarray(times, dtype=np.float32)
t = Template("Average time of $runs is = ${timing}")
tstr = t.substitute(runs=args.r, timing=times.mean())
print(tstr)
print("Done...")
if __name__ == "__main__":
main()