Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Changed diag operator tests to use np.diag() as comparison
Browse files Browse the repository at this point in the history
  • Loading branch information
ifeherva committed Jul 12, 2018
1 parent 5d99a8b commit b764ef9
Showing 1 changed file with 5 additions and 39 deletions.
44 changes: 5 additions & 39 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7006,36 +7006,22 @@ def test_diag():

# k == 0
r = mx.nd.diag(a)

for i in range(r.shape[0]):
assert r[i] == a[i][i]
assert_almost_equal(r.asnumpy(), np.diag(a_np))

# k == 1
k = 1
r = mx.nd.diag(a, k=k)

for i in range(max(r.shape[0]-k, 0)):
assert r[i] == a[i][i+k]
assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k))

# k == -1
k = -1
r = mx.nd.diag(a, k=k)

for i in range(max(r.shape[0]+k, 0)):
assert r[i] == a[i-k][i]
assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k))

# random k
k = np.random.randint(-min(h,w),max(h,w))
r = mx.nd.diag(a, k=k)
if k > 0:
for i in range(max(r.shape[0]-k, 0)):
assert r[i] == a[i][i+k]
elif k < 0:
for i in range(max(r.shape[0]+k, 0)):
assert r[i] == a[i-k][i]
else:
for i in range(max(r.shape[0], 0)):
assert r[i] == a[i][i+k]
assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k))

# Test 2d backward, k=0
data = mx.sym.Variable('data')
Expand All @@ -7061,27 +7047,7 @@ def test_diag():
k = np.random.randint(-d,d)
r = mx.nd.diag(a, k=k)

if k > 0:
for i in range(r.shape[0]):
for j in range(r.shape[1]):
if i + k == j:
assert r[i][j] == a[i]
else:
assert r[i][j] == 0
elif k < 0:
for i in range(r.shape[0]):
for j in range(r.shape[1]):
if i == j - k:
assert r[i][j] == a[j]
else:
assert r[i][j] == 0
else:
for i in range(r.shape[0]):
for j in range(r.shape[1]):
if i == j:
assert r[i][j] == a[i]
else:
assert r[i][j] == 0
assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k))

# Test 2d backward, k=0
data = mx.sym.Variable('data')
Expand Down

0 comments on commit b764ef9

Please sign in to comment.