This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
/
test_ndarray.py
2066 lines (1803 loc) · 84.5 KB
/
test_ndarray.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import mxnet as mx
import numpy as np
from distutils.version import LooseVersion
from itertools import permutations, combinations_with_replacement
import os
import pickle as pkl
import random
import functools
import pytest
from common import assertRaises, TemporaryDirectory
from mxnet.test_utils import almost_equal
from mxnet.test_utils import assert_almost_equal, assert_exception
from mxnet.test_utils import default_device
from mxnet.test_utils import np_reduce
from mxnet.test_utils import same
from mxnet.test_utils import random_sample, rand_shape_nd, random_arrays
from mxnet import runtime
from numpy.testing import assert_allclose, assert_array_equal, assert_array_almost_equal
import mxnet.autograd
from mxnet.base import integer_types
from mxnet.ndarray.ndarray import py_slice
from mxnet.amp.amp import bfloat16
def check_with_uniform(uf, arg_shapes, dim=None, npuf=None, rmin=-10, type_list=None, rmax=10):
    """Check that ``uf`` on NDArrays agrees with ``npuf`` on NumPy arrays.

    Input values are drawn uniformly from ``[rmin, rmax)`` and cast to each
    dtype in ``type_list``; the identical values are fed to both functions
    and the outputs are compared.

    Parameters
    ----------
    uf : callable
        Function under test. It receives ``mx.nd.NDArray`` arguments and may
        return an NDArray or any object with a ``.shape`` (e.g. a NumPy scalar).
    arg_shapes : int or list of tuple
        Either the explicit list of argument shapes, or an int giving the
        number of arguments; in the latter case a single random shape of rank
        ``dim`` is drawn and used for every argument.
    dim : int, optional
        Rank of the random shape; required when ``arg_shapes`` is an int.
    npuf : callable, optional
        NumPy reference implementation; defaults to ``uf``.
    rmin : float
        Lower bound of the sampled values.
    type_list : list of dtype, optional
        Dtypes to test; defaults to ``[np.float32]``.
    rmax : float
        Exclusive upper bound of the sampled values (default 10).
    """
    if type_list is None:
        # Fresh list per call instead of a shared mutable default.
        type_list = [np.float32]
    if isinstance(arg_shapes, int):
        assert dim, "dim must be given when arg_shapes is an int"
        # Every axis is drawn below 1000 ** (1 / dim), keeping the array small.
        shape = tuple(np.random.randint(1, int(1000**(1.0/dim)), size=dim))
        arg_shapes = [shape] * arg_shapes
    if npuf is None:
        npuf = uf
    for dtype in type_list:
        ndarray_arg = []
        numpy_arg = []
        for s in arg_shapes:
            npy = np.random.uniform(rmin, rmax, s).astype(dtype)
            narr = mx.nd.array(npy, dtype=dtype)
            ndarray_arg.append(narr)
            numpy_arg.append(npy)
        out1 = uf(*ndarray_arg)
        out2 = npuf(*numpy_arg).astype(dtype)
        assert out1.shape == out2.shape
        if isinstance(out1, mx.nd.NDArray):
            out1 = out1.asnumpy()
        if dtype == np.float16:
            # float16 carries only ~3 significant decimal digits.
            assert_almost_equal(out1, out2, rtol=2e-3, atol=1e-5)
        else:
            assert_almost_equal(out1, out2, atol=1e-5)
def random_ndarray(dim):
    """Build an NDArray of rank ``dim`` filled with values from U(-10, 10).

    Each axis length is sampled from ``[1, int(1000 ** (1 / dim)))``.
    """
    upper = int(1000**(1.0/dim))
    dims = tuple(np.random.randint(1, upper, size=dim))
    values = np.random.uniform(-10, 10, dims)
    return mx.nd.array(values)
def test_ndarray_setitem():
    """Check NDArray.__setitem__ against NumPy for many key and value kinds."""
    def zeros_pair(shp):
        # A zero NDArray plus a zero NumPy reference of the same dtype.
        arr = mx.nd.zeros(shp)
        return arr, np.zeros(shp, dtype=arr.dtype)

    def assign_and_compare(arr, ref, key, value, ref_value=None):
        # Apply the same assignment to both arrays, then require equality.
        arr[key] = value
        ref[key] = value if ref_value is None else ref_value
        assert same(arr.asnumpy(), ref)

    shape = (3, 4, 2)
    ones_np = np.ones(shape)

    # Whole-array assignment from a scalar, an NDArray and a NumPy array.
    for value, ref_value in ((1, None),
                             (mx.nd.ones(shape), ones_np),
                             (ones_np, None)):
        x, x_np = zeros_pair(shape)
        assign_and_compare(x, x_np, np.s_[:], value, ref_value)

    # Integer sub-array indexing, positive then negative, on one array.
    x, x_np = zeros_pair(shape)
    for key in (1, -1):
        assign_and_compare(x, x_np, key, 1)

    # Ellipsis before and after an integer index.
    for key in (np.s_[2, ...], np.s_[..., 1]):
        x, x_np = zeros_pair(shape)
        assign_and_compare(x, x_np, key, 1)

    # `None` entries in the key should be ignored.
    x, x_np = zeros_pair(shape)
    assign_and_compare(x, x_np, np.s_[None, 0, None, None, 0, 0, None], 1)

    # Short all-dim indexing with an NDArray value.
    x, x_np = zeros_pair(shape)
    val = mx.nd.ones((3, 2))
    for key in (np.s_[:, 1:3, 1], np.s_[:, 1:3, -1]):
        assign_and_compare(x, x_np, key, val, val.asnumpy())

    # Slices on every axis, with positive and then negative bounds.
    x, x_np = zeros_pair(shape)
    for key in (np.s_[:, 1:3, 1:2], np.s_[:, -3:-1, -2:-1]):
        assign_and_compare(x, x_np, key, 1)

    # Assignments into arrays whose axes all have length one.
    for trivial_shape in [(1,), (1, 1), (1, 1, 1)]:
        x = mx.nd.zeros(trivial_shape)
        x[:] = np.ones(trivial_shape)
        assert x.shape == trivial_shape
        assert same(x.asnumpy(), np.ones(trivial_shape, dtype=x.dtype))

    # test https://github.com/apache/mxnet/issues/16647
    for src, value in (([1, 2, 3], [1, 2, 3]),
                       ([1, 2, 3], mx.nd.array([1, 2, 3])),
                       ([1, 2], [1, 2])):
        dst, dst_np = zeros_pair((1, 3, 1))
        assign_and_compare(dst, dst_np, np.s_[0, :len(src), 0], value, src)
def test_ndarray_elementwise():
    """Compare element-wise NDArray arithmetic and unary ops with NumPy."""
    all_type = [np.float32, np.float64, np.float16, np.uint8, np.int8, np.int32, np.int64]
    real_type = [np.float32, np.float64, np.float16]
    nrepeat, maxdim = 10, 4
    for _ in range(nrepeat):
        for rank in range(1, maxdim):
            check_with_uniform(lambda lhs, rhs: lhs + rhs, 2, rank, type_list=all_type)
            check_with_uniform(lambda lhs, rhs: lhs - rhs, 2, rank, type_list=all_type)
            check_with_uniform(lambda lhs, rhs: lhs * rhs, 2, rank, type_list=all_type)
            # Division: reals over the full range, every dtype with positive inputs.
            check_with_uniform(lambda lhs, rhs: lhs / rhs, 2, rank, type_list=real_type)
            check_with_uniform(lambda lhs, rhs: lhs / rhs, 2, rank, rmin=1, type_list=all_type)
            # Unary ops against their NumPy counterparts.
            check_with_uniform(mx.nd.sqrt, 1, rank, np.sqrt, rmin=0)
            check_with_uniform(mx.nd.square, 1, rank, np.square, rmin=0)
            check_with_uniform(lambda arr: mx.nd.norm(arr).asscalar(), 1, rank, np.linalg.norm)
def test_ndarray_elementwisesum():
    """ElementWiseSum of 1x, 2x, 4x and 8x a ones vector must equal 15x ones."""
    base = mx.nd.ones((10,), dtype=np.int32)
    summands = [base] + [base * factor for factor in (2, 4, 8)]
    total = mx.nd.ElementWiseSum(*summands)
    assert same(total.asnumpy(), base.asnumpy() * 15)
def test_ndarray_negate():
    """Unary minus negates values without modifying its operand."""
    reference = np.random.uniform(-10, 10, (2, 3, 4))
    arr = mx.nd.array(reference)
    assert_almost_equal(reference, arr.asnumpy())
    negated = -arr
    assert_almost_equal(-reference, negated.asnumpy())
    # Negation must not be done in place: `arr` still holds its original
    # contents after `-arr` has been computed.
    assert_almost_equal(reference, arr.asnumpy())
def test_ndarray_magic_abs():
    """``abs(arr)`` must match both ``arr.abs()`` and NumPy's ``np.abs``."""
    for dim in range(1, 7):
        shape = rand_shape_nd(dim)
        npy = np.random.uniform(-10, 10, shape)
        arr = mx.nd.array(npy)
        result = abs(arr).asnumpy()
        assert_almost_equal(result, arr.abs().asnumpy())
        # Comparing two MXNet code paths alone would not catch a bug they
        # share, so also check against NumPy. arr.asnumpy() (not npy) is used
        # so both sides carry the NDArray's dtype.
        assert_almost_equal(result, np.abs(arr.asnumpy()))
def test_ndarray_reshape():
    """Exercise NDArray.reshape, including MXNet's special codes 0/-1/-2/-3/-4."""
    tensor = (mx.nd.arange(30) + 1).reshape(2, 3, 5)
    flat = mx.nd.arange(30) + 1

    def check(reshaped, *expected_shape):
        # `reshaped` must equal the flat reference laid out as expected_shape.
        assert same(reshaped.asnumpy(), flat.reshape(*expected_shape).asnumpy())

    check(tensor.reshape((-1,)), 30)
    check(tensor.reshape((2, -1)), 2, 15)
    check(tensor.reshape((0, -1)), 2, 15)
    check(tensor.reshape((-1, 2)), 15, 2)
    check(tensor.reshape(6, 5), 6, 5)
    check(tensor.reshape(-1, 2), 15, 2)
    check(tensor.reshape(-1), 30)
    check(tensor.reshape(30), 30)
    check(tensor.reshape(0, -1), 2, 15)
    check(tensor.reshape(-1, 6), 5, 6)
    check(tensor.reshape(-2,), 2, 3, 5)
    check(tensor.reshape(-3, -1), 6, 5)
    check(tensor.reshape(-1, 15).reshape(0, -4, 3, -1), 2, 3, 5)
    check(tensor.reshape(-1, 0), 10, 3)
    check(tensor.reshape(-1, 0, reverse=True), 6, 5)
    # https://github.com/apache/mxnet/issues/18886
    assertRaises(ValueError, tensor.reshape, (2, 3))
def test_ndarray_flatten():
    """flatten() copies; flatten(inplace=True) tracks later writes to the source."""
    tensor = (mx.nd.arange(30) + 1).reshape(2, 3, 5)
    copied = tensor.flatten()
    shared = tensor.flatten(inplace=True)

    def expected():
        return tensor.reshape(2, 15).asnumpy()

    assert same(copied.asnumpy(), expected())
    assert same(shared.asnumpy(), expected())
    tensor[0] = -1
    # Only the in-place result should observe the mutation.
    assert not same(copied.asnumpy(), expected())
    assert same(shared.asnumpy(), expected())
def test_ndarray_squeeze():
    """Check NDArray.squeeze (copy and in-place) against np.squeeze."""
    def check_squeeze(shape, axis=None):
        data = mx.random.uniform(low=-10.0, high=10.0, shape=shape)
        copy = data.squeeze(axis=axis)
        ref = data.squeeze(axis=axis, inplace=True)
        out_expected = np.squeeze(data.asnumpy(), axis=axis)
        if copy.shape == (1,):  # as an exception (1, 1, 1) will be squeezed to (1,)
            out_expected = np.squeeze(data.asnumpy(), axis=tuple(range(1, len(shape))))
        assert same(copy.asnumpy(), out_expected)
        assert same(ref.asnumpy(), out_expected)
        # Writing to the source must show through the in-place result only.
        data[0][0] = -1
        assert same(copy.asnumpy(), out_expected)
        assert not same(ref.asnumpy(), out_expected)

    # check forward: positive axes, negative axes, and tuples of each.
    # The negative tuples replace a duplicated (0, 4) / (0, 2, 4) pair so that
    # multi-axis squeezing with negative indices is actually exercised.
    shape = (1, 5, 1, 3, 1)
    for axis in (0, 2, 4, (0, 4), (0, 2, 4),
                 -5, -3, -1, (-5, -1), (-5, -3, -1),
                 None):
        check_squeeze(shape, axis)
    check_squeeze((1, 1, 1, 1))
def test_ndarray_expand_dims():
    """expand_dims matches NumPy for every valid axis; inplace=True shares data."""
    for ndim in range(1, 6):
        for axis in range(-ndim - 1, ndim + 1):
            src_shape = list(np.random.randint(1, 10, size=ndim))
            data = mx.random.normal(shape=src_shape)
            copied = data.expand_dims(axis=axis)
            shared = data.expand_dims(axis=axis, inplace=True)
            expected = np.expand_dims(data.asnumpy(), axis=axis)
            assert same(copied.asnumpy(), expected)
            assert same(shared.asnumpy(), expected), \
                (src_shape, axis, shared.asnumpy().shape, expected.shape)
            # After mutating the source, only the in-place result changes.
            data[0] = -1
            assert same(copied.asnumpy(), expected)
            assert not same(shared.asnumpy(), expected)
def test_ndarray_choose():
    """choose_element_0index picks one column per row, like NumPy fancy indexing."""
    shape = (100, 20)
    nrows, ncols = shape
    source = np.arange(np.prod(shape)).reshape(shape)
    arr = mx.nd.array(source)
    row_ids = np.arange(nrows)
    for _ in range(3):
        picks = np.random.randint(ncols, size=nrows)
        expected = source[row_ids, picks]
        actual = mx.nd.choose_element_0index(arr, mx.nd.array(picks)).asnumpy()
        assert same(expected, actual)
def test_ndarray_fill():
    """fill_element_0index writes one value per row at the given column."""
    shape = (100, 20)
    nrows, ncols = shape
    source = np.arange(np.prod(shape)).reshape(shape)
    arr = mx.nd.array(source)
    expected = source.copy()
    for _ in range(3):
        cols = np.random.randint(ncols, size=nrows)
        vals = np.random.randint(ncols, size=nrows)
        expected[:] = source
        expected[np.arange(nrows), cols] = vals
        filled = mx.nd.fill_element_0index(arr, mx.nd.array(vals), mx.nd.array(cols))
        assert same(expected, filled.asnumpy())
def test_ndarray_onehot():
    """onehot_encode writes a one-hot matrix into the ``out`` array."""
    shape = (100, 20)
    nrows, ncols = shape
    expected = np.arange(np.prod(shape)).reshape(shape)
    out = mx.nd.array(expected)
    for _ in range(3):
        hot = np.random.randint(ncols, size=nrows)
        expected[:] = 0.0
        expected[np.arange(nrows), hot] = 1.0
        mx.nd.onehot_encode(mx.nd.array(hot), out=out)
        assert same(expected, out.asnumpy())
def test_init_from_scalar():
    """A 0-d NumPy array converts to a 0-d NDArray with the same value."""
    scalar_np = np.ones([])
    scalar_nd = mx.nd.array(scalar_np)
    assert scalar_nd.shape == ()
    assert same(scalar_np, scalar_nd.asnumpy())
def test_ndarray_copy():
    """copyto() onto CPU device 0 produces an array with identical values."""
    src = mx.nd.array(np.random.uniform(-10, 10, (10, 10)))
    dst = src.copyto(mx.Context('cpu', 0))
    mismatches = np.abs(src.asnumpy() != dst.asnumpy())
    assert np.sum(mismatches) == 0.0
def test_ndarray_scalar():
    """In-place and out-of-place arithmetic between NDArrays and scalars.

    Each check bounds the absolute difference from the expected sum. A
    one-sided ``total - expected < eps`` check would also accept any result
    smaller than expected.
    """
    c = mx.nd.empty((10, 10))
    d = mx.nd.empty((10, 10))
    c[:] = 0.5
    d[:] = 1.0
    d -= c * 2 / 3 * 6.0  # 1.0 - 0.5 * 2 / 3 * 6.0 == -1.0 per element
    c += 0.5  # 0.5 + 0.5 == 1.0 per element
    assert abs(np.sum(c.asnumpy()) - 100) < 1e-5
    assert abs(np.sum(d.asnumpy()) + 100) < 1e-5
    c[:] = 2
    assert abs(np.sum(c.asnumpy()) - 200) < 1e-5
    d = -c + 2  # -2 + 2 == 0 per element
    assert abs(np.sum(d.asnumpy())) < 1e-5
def test_ndarray_pickle():
    """NDArrays of rank 1..4 survive a pickle round trip unchanged."""
    for rank in range(1, 5):
        lhs = random_ndarray(rank)
        rhs = mx.nd.empty(lhs.shape)
        lhs[:] = np.random.uniform(-10, 10, lhs.shape)
        rhs[:] = np.random.uniform(-10, 10, lhs.shape)
        summed = lhs + rhs
        restored = pkl.loads(pkl.dumps(summed))
        assert np.sum(summed.asnumpy() != restored.asnumpy()) == 0
@pytest.mark.parametrize('save_fn', [mx.nd.save, mx.npx.savez])
def test_ndarray_saveload(save_fn):
    """Round-trip lists, dicts and single NDArrays through ``save_fn``/``mx.nd.load``.

    Files are written inside a private temporary directory, not to a fixed
    name in the working directory. The two parametrizations therefore cannot
    clobber each other when tests run in parallel, and nothing is left behind
    if an assertion fails.
    """
    nrepeat = 10
    with TemporaryDirectory(prefix='test_ndarray_saveload_') as tmpdir:
        fname = os.path.join(tmpdir, 'tmp_list')
        for _ in range(nrepeat):
            data = []
            # test save/load as list
            for _ in range(10):
                data.append(random_ndarray(np.random.randint(1, 5)))
            if save_fn is mx.nd.save:
                save_fn(fname, data)
            else:
                save_fn(fname, *data)
            data2 = mx.nd.load(fname)
            assert len(data) == len(data2)
            # For savez output the loaded container is iterated via .values().
            for x, y in zip(data, data2 if save_fn is mx.nd.save else data2.values()):
                assert np.sum(x.asnumpy() != y.asnumpy()) == 0
            # test save/load as dict
            dmap = {f'ndarray xx {i}' : x for i, x in enumerate(data)}
            if save_fn is mx.nd.save:
                save_fn(fname, dmap)
            else:
                save_fn(fname, **dmap)
            dmap2 = mx.nd.load(fname)
            assert len(dmap2) == len(dmap)
            for k, x in dmap.items():
                y = dmap2[k]
                assert np.sum(x.asnumpy() != y.asnumpy()) == 0
            # test save/load as ndarray
            # we expect the single ndarray to be converted into a list containing the ndarray
            single_ndarray = data[0]
            save_fn(fname, single_ndarray)
            # Test loading with numpy
            if save_fn is mx.npx.savez:
                with np.load(fname) as fname_np_loaded:
                    single_ndarray_loaded = fname_np_loaded['arr_0']
                    assert np.sum(single_ndarray.asnumpy() != single_ndarray_loaded) == 0
                mx.npx.save(fname, single_ndarray)
                single_ndarray_loaded = np.load(fname)
                assert np.sum(single_ndarray.asnumpy() != single_ndarray_loaded) == 0
            # Test loading with mxnet backend
            single_ndarray_loaded = mx.nd.load(fname)
            assert len(single_ndarray_loaded) == 1
            single_ndarray_loaded = single_ndarray_loaded[0]
            assert np.sum(single_ndarray.asnumpy() != single_ndarray_loaded.asnumpy()) == 0
            # Start each repetition without a file on disk.
            os.remove(fname)
@mx.util.use_np
def test_ndarray_load_fortran_order(tmp_path):
    """mx.npx.load of a Fortran-ordered .npy yields a C-ordered, equal array."""
    source = np.arange(20).reshape((2, 10)).T
    assert np.isfortran(source)
    npy_path = tmp_path / 'fortran_order.npy'
    np.save(npy_path, source)
    loaded = mx.npx.load(str(npy_path)).asnumpy()
    assert not np.isfortran(loaded)
    assert np.sum(loaded != source) == 0
def test_ndarray_legacy_load():
    """The checked-in legacy v0 file must load as six copies of arange(128)."""
    expected = [mx.nd.arange(128) for _ in range(6)]
    here = os.path.dirname(os.path.realpath(__file__))
    legacy_data = mx.nd.load(os.path.join(here, 'legacy_ndarray.v0'))
    assert len(expected) == len(legacy_data)
    for idx, want in enumerate(expected):
        assert same(want.asnumpy(), legacy_data[idx].asnumpy())
def test_buffer_load():
    """mx.nd.load_frombuffer round-trips saved data and rejects truncated buffers."""
    def save_and_read(path, payload):
        # Save `payload` with mx.nd.save and return the raw file bytes.
        mx.nd.save(path, payload)
        with open(path, 'rb') as handle:
            return handle.read()

    with TemporaryDirectory(prefix='test_buffer_load_') as tmpdir:
        for repeat in range(10):
            # A list of NDArrays.
            arrays = [random_ndarray(np.random.randint(1, 5)) for _ in range(10)]
            buf_list = save_and_read(os.path.join(tmpdir, 'list_{0}.param'.format(repeat)), arrays)
            loaded_list = mx.nd.load_frombuffer(buf_list)
            assert len(arrays) == len(loaded_list)
            for want, got in zip(arrays, loaded_list):
                assert np.sum(want.asnumpy() != got.asnumpy()) == 0
            # A truncated buffer must be rejected.
            assertRaises(mx.base.MXNetError, mx.nd.load_frombuffer, buf_list[:-10])

            # A dict of named NDArrays.
            named = {f'ndarray xx {i}' : x for i, x in enumerate(arrays)}
            buf_dict = save_and_read(os.path.join(tmpdir, 'dict_{0}.param'.format(repeat)), named)
            loaded_dict = mx.nd.load_frombuffer(buf_dict)
            assert len(loaded_dict) == len(named)
            for key, want in named.items():
                assert np.sum(want.asnumpy() != loaded_dict[key].asnumpy()) == 0
            assertRaises(mx.base.MXNetError, mx.nd.load_frombuffer, buf_dict[:-10])

            # A single NDArray comes back as a one-element list.
            single = arrays[0]
            buf_single = save_and_read(os.path.join(tmpdir, 'single_{0}.param'.format(repeat)), single)
            loaded_single = mx.nd.load_frombuffer(buf_single)
            assert len(loaded_single) == 1
            assert np.sum(single.asnumpy() != loaded_single[0].asnumpy()) == 0
            assertRaises(mx.base.MXNetError, mx.nd.load_frombuffer, buf_single[:-10])
@pytest.mark.serial
def test_ndarray_slice():
    """Basic, multi-axis, negative and advanced (list/NDArray) indexing."""
    # 1-D slice read followed by slice assignment.
    nd_vec = mx.nd.array(np.random.uniform(-10, 10, (10,)))
    np_vec = nd_vec.asnumpy()
    assert same(nd_vec[3:8].asnumpy(), np_vec[3:8])
    np_vec[3:8] *= 10
    nd_vec[3:8] = np_vec[3:8]
    assert same(nd_vec[3:8].asnumpy(), np_vec[3:8])

    # Mixed integer/slice indexing on a 5-D array.
    nd5 = mx.nd.random.uniform(shape=(3, 4, 5, 6, 7))
    np5 = nd5.asnumpy()
    assert same(nd5[1, 3:4, :, 1:5].asnumpy(), np5[1, 3:4, :, 1:5])
    for idx in ((1, 2, 3, 4, 5), (-1, -2, -3, -4, -5)):
        assert nd5[idx].asscalar() == np5[idx]

    # Advanced indexing with Python lists and with NDArrays.
    small = mx.nd.array([[0, 1], [2, 3]])
    assert (small[[1, 1, 0], [0, 1, 0]].asnumpy() == [2, 3, 0]).all()
    assert (small[mx.nd.array([1, 1, 0]), mx.nd.array([0, 1, 0])].asnumpy() == [2, 3, 0]).all()

    # Negative scalar indices on a square matrix.
    nd_sq = mx.nd.random.uniform(shape=(4, 4))
    np_sq = nd_sq.asnumpy()
    for i in range(-4, 0):
        assert nd_sq[i, i].asscalar() == np_sq[i, i]
        assert same(nd_sq[:, i].asnumpy(), np_sq[:, i])
        assert same(nd_sq[i, :].asnumpy(), np_sq[i, :])
def test_ndarray_crop():
    """crop, _crop_assign and _crop_assign_scalar over one fixed region."""
    begin, end = (0, 0, 0), (2, 1, 3)
    region = np.s_[0:2, 0:1, 0:3]

    # Reading a crop.
    src = mx.nd.ones((2, 3, 4))
    cropped = mx.nd.crop(src, begin=begin, end=end)
    assert same(cropped.asnumpy(), np.ones((2, 1, 3), dtype=cropped.dtype))

    # Assigning an array into the cropped region.
    patch = mx.nd.zeros((2, 1, 3))
    mx.nd._internal._crop_assign(src, patch, begin=begin, end=end, out=src)
    expected = np.ones(src.shape, dtype=src.dtype)
    expected[region] = 0
    assert same(src.asnumpy(), expected)

    # Assigning a scalar into the cropped region.
    src = mx.nd.ones((2, 3, 4))
    mx.nd._internal._crop_assign_scalar(src, scalar=5, begin=begin, end=end, out=src)
    expected = np.ones(src.shape, dtype=src.dtype)
    expected[region] = 5
    assert same(src.asnumpy(), expected)
@pytest.mark.serial
def test_ndarray_concatenate():
    """mx.nd.concatenate agrees with np.concatenate along axis 1."""
    axis = 1
    shapes = [(2, 3, 4, 2), (2, 2, 4, 2), (2, 1, 4, 2)]
    parts_np = [np.random.uniform(-10, 10, s).astype(np.float32) for s in shapes]
    joined_nd = mx.nd.concatenate([mx.nd.array(p) for p in parts_np], axis=axis)
    joined_np = np.concatenate(parts_np, axis=axis)
    assert same(joined_np, joined_nd.asnumpy())
def test_clip():
    """mx.nd.clip bounds values to [-2, 2] and otherwise matches np.clip exactly."""
    shape = (10,)
    A = mx.random.uniform(-10, 10, shape)
    B1 = mx.nd.clip(A, -2, 2).asnumpy()
    assert (B1 >= -2).all()
    assert (B1 <= 2).all()
    # The range checks alone would accept e.g. an all-zero output. Comparing
    # with NumPy also verifies that in-range values pass through unchanged.
    assert same(B1, np.clip(A.asnumpy(), -2, 2))
def test_dot():
    """mx.nd.dot for every transpose_a/transpose_b combination vs np.dot."""
    # Non-zero atol required, as exposed by seed 828791701
    atol = 1e-5
    # (lhs shape, rhs shape, transpose_a, transpose_b)
    cases = [((3, 4), (4, 5), False, False),
             ((3, 4), (3, 5), True, False),
             ((3, 4), (5, 4), False, True),
             ((4, 3), (5, 4), True, True)]
    for lhs_shape, rhs_shape, trans_a, trans_b in cases:
        lhs = np.random.uniform(-3, 3, lhs_shape)
        rhs = np.random.uniform(-3, 3, rhs_shape)
        expected = np.dot(lhs.T if trans_a else lhs, rhs.T if trans_b else rhs)
        # Only pass the transpose flags that are set, as the default call does.
        options = {name: True
                   for name, flag in (('transpose_a', trans_a), ('transpose_b', trans_b))
                   if flag}
        actual = mx.nd.dot(mx.nd.array(lhs), mx.nd.array(rhs), **options)
        assert_almost_equal(expected, actual.asnumpy(), atol=atol)
@pytest.mark.serial
def test_reduce():
    """Compare MXNet sum/max/min/argmax/argmin with NumPy on random data.

    The inputs randomly contain NaN and +/-Inf values.
    """
    sample_num = 300
    def test_reduce_inner(numpy_reduce_func, nd_reduce_func, multi_axes,
                          allow_almost_equal=False, check_dtype=True):
        """Run ``sample_num`` random trials of one reduction.

        numpy_reduce_func / nd_reduce_func are called as f(data, axis=..., keepdims=...).
        multi_axes: reduce over a random tuple of axes if true, else one random axis.
        allow_almost_equal: compare to ``decimal`` places instead of exactly.
        check_dtype: also require the result dtypes to match.
        """
        # (dtype, decimal places used when allow_almost_equal is set)
        dtypes = [(np.float16, 1),
                  (np.float32, 4),
                  (np.double, 6)]
        for _ in range(sample_num):
            dtype, decimal = random.choice(dtypes)
            ndim = np.random.randint(1, 6)
            shape = np.random.randint(1, 11, size=ndim)
            dat = (np.random.rand(*shape) - 0.5).astype(dtype)
            # 0 or 1, passed through as the keepdims flag.
            keepdims = np.random.randint(0, 2)
            # Randomly overwrite up to dat.size//10 entries with NaN.
            allow_nan = np.random.randint(0, 2)
            if allow_nan:
                total_nans = np.random.randint(0, dat.size//10+1)
                dat.ravel()[np.random.choice(
                    dat.size, total_nans, replace=False)] = np.nan
            # Randomly overwrite up to dat.size//20 entries with infinities:
            # all +Inf, all -Inf, or half of each. The +Inf and -Inf positions are
            # drawn independently, so they may overlap (and may overwrite NaNs).
            allow_inf = np.random.randint(0, 2)
            if allow_inf:
                r = np.random.randint(0, 3)
                total_infs = np.random.randint(0, dat.size//20+1)
                if r == 0:
                    total_pos_infs, total_neg_infs = total_infs, 0
                elif r == 1:
                    total_pos_infs, total_neg_infs = 0, total_infs
                else:
                    total_pos_infs = total_neg_infs = total_infs // 2
                dat.ravel()[np.random.choice(
                    dat.size, total_pos_infs, replace=False)] = np.inf
                dat.ravel()[np.random.choice(
                    dat.size, total_neg_infs, replace=False)] = -np.inf
            if multi_axes:
                # Pick a random subset of axes; an empty pick means "all axes".
                axis_flags = np.random.randint(0, 2, size=ndim)
                axes = []
                for (axis, flag) in enumerate(axis_flags):
                    if flag:
                        axes.append(axis)
                if 0 == len(axes):
                    axes = tuple(range(ndim))
                else:
                    axes = tuple(axes)
            else:
                axes = np.random.randint(0, ndim)
            numpy_ret = numpy_reduce_func(dat, axis=axes, keepdims=keepdims)
            mx_arr = mx.nd.array(dat, dtype=dtype)
            ndarray_ret = nd_reduce_func(mx_arr, axis=axes, keepdims=keepdims)
            if type(ndarray_ret) is mx.ndarray.NDArray:
                ndarray_ret = ndarray_ret.asnumpy()
            # A full reduction may come back from MXNet as shape (1,) where
            # NumPy gives a 0-d result; both are accepted.
            assert (ndarray_ret.shape == numpy_ret.shape) or \
                   (ndarray_ret.shape == (1,) and numpy_ret.shape == ()), \
                   f"nd:{ndarray_ret.shape}, numpy:{numpy_ret.shape}"
            if check_dtype:
                assert ndarray_ret.dtype == numpy_ret.dtype,\
                        (ndarray_ret.dtype, numpy_ret.dtype)
            if allow_almost_equal:
                assert_array_almost_equal(ndarray_ret, numpy_ret, decimal=decimal)
            else:
                assert_array_equal(ndarray_ret, numpy_ret)
    # np_reduce comes from mxnet.test_utils; presumably it adapts the NumPy
    # function to the (data, axis, keepdims) convention — TODO confirm.
    test_reduce_inner(lambda data, axis, keepdims:np_reduce(data, axis, keepdims, np.sum),
                      mx.nd.sum, True, allow_almost_equal=True)
    test_reduce_inner(lambda data, axis, keepdims:np_reduce(data, axis, keepdims, np.max),
                      mx.nd.max, True)
    test_reduce_inner(lambda data, axis, keepdims:np_reduce(data, axis, keepdims, np.min),
                      mx.nd.min, True)
    # argmax/argmin: single axis only, and dtype is not compared (the index
    # dtype presumably differs between MXNet and NumPy — TODO confirm).
    test_reduce_inner(lambda data, axis, keepdims:np_reduce(data, axis, keepdims, np.argmax),
                      mx.nd.argmax, False, check_dtype=False)
    test_reduce_inner(lambda data, axis, keepdims:np_reduce(data, axis, keepdims, np.argmin),
                      mx.nd.argmin, False, check_dtype=False)
@pytest.mark.serial
def test_broadcast():
    """broadcast_to, broadcast_like and per-axis broadcast_like vs NumPy."""
    sample_num = 1000

    def random_case():
        # Draw a target shape and data whose shape is the target with a
        # random subset of axes collapsed to length 1.
        ndim = np.random.randint(1, 6)
        target_shape = np.random.randint(1, 11, size=ndim)
        shape = target_shape.copy()
        axis_flags = np.random.randint(0, 2, size=ndim)
        shape[axis_flags.astype(bool)] = 1
        dat = np.random.rand(*shape) - 0.5
        return target_shape, dat

    def check(ndarray_ret, target_shape, dat):
        # The result must have the target shape and equal the NumPy-broadcast input.
        if type(ndarray_ret) is mx.ndarray.NDArray:
            ndarray_ret = ndarray_ret.asnumpy()
        assert (ndarray_ret.shape == target_shape).all()
        err = np.square(ndarray_ret - dat).mean()
        assert err < 1E-8

    # broadcast_to with an explicit target shape.
    for _ in range(sample_num):
        target_shape, dat = random_case()
        check(mx.nd.array(dat).broadcast_to(shape=target_shape), target_shape, dat)

    # broadcast_like against an array of the target shape.
    for _ in range(sample_num):
        target_shape, dat = random_case()
        target = mx.nd.ones(shape=tuple(target_shape))
        check(mx.nd.array(dat).broadcast_like(target), target_shape, dat)

    # broadcast_like restricted to selected axes.
    testcases = [
        # Lhs shape, rhs shape, lhs axis, rhs axis, result
        [(1, 2, 1, 3), (5, 6, 7, 8), (0, 2), (1, 3), (6, 2, 8, 3)],
        [(1,), (5,), (0,), (-1,), (5,)],
        [(1, 7, 9, 1, 1), (9,), (-2,), (0,), (1, 7, 9, 9, 1)],
        [(1, 7, 9, 1, 1), (9, 1), (-2, -1), (-2, -1), (1, 7, 9, 9, 1)],
        [(2, 1), (1, 7, 9, 1, 1), (1,), (-3,), (2, 9)],
    ]
    for lhs_shape, rhs_shape, lhs_axes, rhs_axes, expected in testcases:
        lhs = mx.nd.random.uniform(shape=lhs_shape)
        rhs = mx.nd.random.uniform(shape=rhs_shape)
        output = mx.nd.broadcast_like(lhs, rhs, lhs_axes=lhs_axes, rhs_axes=rhs_axes)
        # Empty axis tuples must be rejected.
        assert_exception(mx.nd.broadcast_like, mx.base.MXNetError, lhs, rhs, lhs_axes=(), rhs_axes=())
        assert output.shape == expected
@pytest.mark.serial
def test_broadcast_binary():
    """Binary operators on broadcast-compatible random NDArrays vs NumPy."""
    N = 100

    def random_shapes():
        # rhs takes the trailing `bdim` axes of the output shape; each of those
        # axes is then set to 1 in lhs, in rhs, or in neither (about 1/3 each).
        ndim = np.random.randint(1, 6)
        oshape = np.random.randint(1, 6, size=(ndim,))
        bdim = np.random.randint(1, ndim + 1)
        lshape = list(oshape)
        rshape = list(oshape[ndim - bdim:])
        for i in range(bdim):
            sep = np.random.uniform(0, 1)
            if sep < 0.33:
                lshape[ndim - i - 1] = 1
            elif sep < 0.66:
                rshape[bdim - i - 1] = 1
        return lshape, rshape

    def check_broadcast_binary(fn):
        for _ in range(N):
            lshape, rshape = random_shapes()
            lhs = np.random.normal(0, 1, size=lshape)
            rhs = np.random.normal(0, 1, size=rshape)
            expected = fn(lhs, rhs)
            actual = fn(mx.nd.array(lhs), mx.nd.array(rhs)).asnumpy()
            assert_allclose(expected, actual, rtol=1e-4, atol=1e-4)

    arithmetic = [lambda x, y: x + y,
                  lambda x, y: x - y,
                  lambda x, y: x * y,
                  lambda x, y: x / y]
    # The following ops are sensitive to the precision of the calculation.
    # Force numpy to match mxnet's float32.
    comparisons = [lambda x, y: x.astype(np.float32) > y.astype(np.float32),
                   lambda x, y: x.astype(np.float32) < y.astype(np.float32),
                   lambda x, y: x.astype(np.float32) >= y.astype(np.float32),
                   lambda x, y: x.astype(np.float32) <= y.astype(np.float32),
                   lambda x, y: x.astype(np.float32) == y.astype(np.float32)]
    for op in arithmetic + comparisons:
        check_broadcast_binary(op)
def test_moveaxis():
    """mx.nd.moveaxis: values, resulting shapes, and invalid-axis errors."""
    X = mx.nd.array([[[1, 2, 3], [4, 5, 6]],
                     [[7, 8, 9], [10, 11, 12]]])
    res = mx.nd.moveaxis(X, 0, 2).asnumpy()
    true_res = mx.nd.array([[[1., 7.], [2., 8.], [3., 9.]],
                            [[4., 10.], [5., 11.], [6., 12.]]])
    assert same(res, true_res.asnumpy())
    assert mx.nd.moveaxis(X, 2, 0).shape == (3, 2, 2)

    def expect_shapes(x, cases):
        # Each case is (source, destination, expected resulting shape).
        for source, destination, expected in cases:
            assert mx.nd.moveaxis(x, source, destination).shape == expected

    # Moving a single axis to the end.
    x = mx.nd.random.normal(0, 1, (5, 6, 7))
    expect_shapes(x, [(0, -1, (6, 7, 5)),
                      (1, -1, (5, 7, 6)),
                      (2, -1, (5, 6, 7)),
                      (-1, -1, (5, 6, 7))])

    # Moving a single axis to a new position.
    x = mx.nd.random.normal(0, 1, (1, 2, 3, 4))
    expect_shapes(x, [(0, 1, (2, 1, 3, 4)),
                      (1, 2, (1, 3, 2, 4)),
                      (1, -1, (1, 3, 4, 2))])

    # Moves that keep every axis where it is preserve the order.
    x = mx.nd.zeros((1, 2, 3, 4))
    identity_moves = [(0, 0), (3, -1), (-1, 3), ([0, -1], [0, -1]),
                      ([2, 0], [2, 0]), (range(4), range(4))]
    expect_shapes(x, [(src, dst, (1, 2, 3, 4)) for src, dst in identity_moves])

    # Moving several axes at once.
    x = mx.nd.zeros((4, 1, 2, 3))
    expect_shapes(x, [([0, 1], [2, 3], (2, 3, 4, 1)),
                      ([2, 3], [0, 1], (2, 3, 4, 1)),
                      ([0, 1, 2], [2, 3, 0], (2, 3, 4, 1)),
                      ([3, 0], [1, 0], (4, 3, 1, 2)),
                      ([0, 3], [0, 1], (4, 3, 1, 2))])

    # Out-of-range, repeated or mismatched axes raise ValueError.
    x = mx.nd.random.normal(0, 1, (1, 2, 3))
    for source, destination in [(3, 0), (-4, 0), (0, 5), ([0, 0], [0, 1]),
                                ([0, 1], [1, 1]), (0, [0, 1]), ([0, 1], [0])]:
        assert_exception(mx.nd.moveaxis, ValueError, x, source, destination)
def test_arange():
    """mx.nd.arange with float steps and ``repeat``, plus a large int32 range."""
    for _ in range(5):
        start = np.random.rand() * 10
        stop = start + np.random.rand() * 100
        step = np.random.rand() * 4
        repeat = int(np.random.rand() * 5) + 1
        base = np.arange(start=start, stop=stop, step=step)
        # Every element of `base` repeated `repeat` times consecutively.
        expected = np.broadcast_to(base[:, None], (base.shape[0], repeat)).ravel()
        actual = mx.nd.arange(start=start, stop=stop, step=step, repeat=repeat).asnumpy()
        assert_almost_equal(actual, expected)
    expected = np.arange(start=0, stop=10000**2, step=10001, dtype=np.int32)
    actual = mx.nd.arange(start=0, stop=10000**2, step=10001, dtype="int32").asnumpy()
    assert_almost_equal(actual, expected)
def test_linspace():
    """mx.nd.linspace vs np.linspace: default, endpoint=False and int32 output."""
    for _ in range(5):
        start = np.random.rand() * 100
        stop = np.random.rand() * 100
        num = np.random.randint(20)
        for extra in ({}, {'endpoint': False}, {'dtype': "int32"}):
            expected = np.linspace(start, stop, num, **extra)
            actual = mx.nd.linspace(start, stop, num, **extra).asnumpy()
            assert_almost_equal(actual, expected)
@pytest.mark.serial
def test_order():
ctx = default_device()
dat_size = 5
is_large_tensor_enabled = runtime.Features().is_enabled('INT64_TENSOR_SIZE')
def gt_topk(dat, axis, ret_typ, k, is_ascend):
if ret_typ == "indices":
if is_ascend:
indices = np.arange(k)
else:
indices = np.arange(-1, -k-1, -1)
ret = np.take(dat.argsort(axis=axis), axis=axis, indices=indices, mode='wrap')
elif ret_typ == "value":
if is_ascend:
indices = np.arange(k)
else:
indices = np.arange(-1, -k-1, -1)
ret = np.take(np.sort(dat, axis=axis), axis=axis, indices=indices, mode='wrap')
else:
assert dat.shape == (dat_size, dat_size, dat_size, dat_size)
assert axis is None or axis ==1
ret = np.zeros(dat.shape)
if is_ascend:
indices = np.arange(k)
else:
indices = np.arange(-1, -k-1, -1)
gt_argsort = np.take(dat.argsort(axis=axis), axis=axis, indices=indices, mode='wrap')
if axis is None:
ret.ravel()[gt_argsort] = 1
else:
for i in range(dat_size):
for j in range(dat_size):
for k in range(dat_size):
ret[i, gt_argsort[i, :, j, k], j, k] = 1
return ret
# Produce input data for the tests, including ensuring unique values if desired.
# Numpy's argsort does not consistently return lowest-index-first for matching
# values, making it hard to generate a numpy 'golden copy' to compare against
# the mxnet operator. The 'mask' function is particularly hard to test given that
# equal values might span the 'k' boundary. Issue exposed with seed 1405838964.
def get_values(ensure_unique, dtype):
if dtype == np.int16 or dtype == np.int32 or dtype == np.int64:
return np.arange(dat_size ** 4, dtype=dtype).reshape((dat_size, dat_size, dat_size, dat_size))
elif dtype == np.float32 or dtype == np.float64:
while True:
data = np.random.normal(size=(dat_size, dat_size, dat_size, dat_size)).astype(dtype)
if not ensure_unique:
return data
num_unique_values = len(set(data.flatten()))
if data.size == num_unique_values:
return data
else:
raise NotImplementedError
# Produce a large matrix of shape (100, 300096) as the input data. This covers the case
# where the total number of elements exceeds the range of integers that float32 can
# represent exactly, while the size of each individual dimension still fits in that range.
def get_large_matrix():
data = np.array([np.arange(300096).astype(np.float32)])
data = np.repeat(data, 100, axis=0)
np.apply_along_axis(np.random.shuffle, 1, data)
return data
large_matrix_npy = get_large_matrix()
large_matrix_nd = mx.nd.array(large_matrix_npy, ctx=ctx, dtype=large_matrix_npy.dtype)
nd_ret_topk = mx.nd.topk(large_matrix_nd, axis=1, ret_typ="indices", k=5, is_ascend=False).asnumpy()
gt = gt_topk(large_matrix_npy, axis=1, ret_typ="indices", k=5, is_ascend=False)
assert_almost_equal(nd_ret_topk, gt)
for dtype in [np.int32, np.int64, np.float32, np.float64]:
a_npy = get_values(ensure_unique=True, dtype=dtype)
a_nd = mx.nd.array(a_npy, ctx=ctx, dtype=dtype)
# test for ret_typ=indices
nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="indices", k=3, is_ascend=True).asnumpy()
# Test the default dtype
assert nd_ret_topk.dtype == np.float32
gt = gt_topk(a_npy, axis=1, ret_typ="indices", k=3, is_ascend=True)
assert_almost_equal(nd_ret_topk, gt)
nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="indices", k=2, is_ascend=False, dtype=np.float64).asnumpy()
assert nd_ret_topk.dtype == np.float64
gt = gt_topk(a_npy, axis=3, ret_typ="indices", k=2, is_ascend=False)
assert_almost_equal(nd_ret_topk, gt)
nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="indices", k=21, is_ascend=False, dtype=np.int32).asnumpy()
assert nd_ret_topk.dtype == np.int32
gt = gt_topk(a_npy, axis=None, ret_typ="indices", k=21, is_ascend=False)
assert_almost_equal(nd_ret_topk, gt)
# test for ret_typ=value
nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="value", k=3, is_ascend=True).asnumpy()
assert nd_ret_topk.dtype == dtype
gt = gt_topk(a_npy, axis=1, ret_typ="value", k=3, is_ascend=True)
assert_almost_equal(nd_ret_topk, gt)
nd_ret_topk = mx.nd.topk(a_nd, axis=3, ret_typ="value", k=2, is_ascend=False).asnumpy()
gt = gt_topk(a_npy, axis=3, ret_typ="value", k=2, is_ascend=False)
assert_almost_equal(nd_ret_topk, gt)
nd_ret_topk = mx.nd.topk(a_nd, axis=None, ret_typ="value", k=21, is_ascend=False).asnumpy()
gt = gt_topk(a_npy, axis=None, ret_typ="value", k=21, is_ascend=False)
assert_almost_equal(nd_ret_topk, gt)
# test for ret_typ=mask
nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="mask", k=3, is_ascend=True).asnumpy()
assert nd_ret_topk.dtype == dtype
gt = gt_topk(a_npy, axis=1, ret_typ="mask", k=3, is_ascend=True)
assert_almost_equal(nd_ret_topk, gt)
nd_ret_topk = mx.nd.topk(a_nd, axis=1, ret_typ="mask", k=2, is_ascend=False).asnumpy()
gt = gt_topk(a_npy, axis=1, ret_typ="mask", k=2, is_ascend=False)