# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tools for testing."""
# pylint: disable=too-many-lines
import time
import gzip
import struct
import traceback
import numbers
import sys
import os
import platform
import errno
import logging
import bz2
import zipfile
import json
from contextlib import contextmanager
from collections import OrderedDict
import numpy as np
import numpy.testing as npt
import numpy.random as rnd
try:
import scipy.stats as ss
except ImportError:
ss = None
try:
import requests
except ImportError:
    # in rare cases requests may not be installed
pass
import mxnet as mx
from .device import current_device
from .ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID, get_dtype_name
from .symbol import Symbol
from .symbol.numpy import _Symbol as np_symbol
from .util import use_np, use_np_default_dtype, getenv, setenv # pylint: disable=unused-import
from .util import get_max_supported_compute_capability, get_rtc_compile_opts # pylint: disable=unused-import
from .runtime import Features
from .numpy_extension import get_cuda_compute_capability
def default_device():
"""Get default device for regression test."""
# _TODO: get device from environment variable to support
# testing with GPUs
return current_device()
def set_default_device(device):
"""Set default device."""
mx.device._current.set(device)
def default_dtype():
"""Get default data type for regression test."""
# _TODO: get default dtype from environment variable
return np.float32
def default_rtols():
"""Get default relative tolerances for data comparisons involving each data type."""
return {np.dtype(np.float16): 1e-2,
np.dtype(np.float32): 1e-4,
np.dtype(np.float64): 1e-5,
            np.dtype(np.bool_): 0,
np.dtype(np.int8): 0,
np.dtype(np.uint8): 0,
np.dtype(np.int32): 0,
np.dtype(np.uint32): 0,
np.dtype(np.int64): 0,
np.dtype(np.uint64): 0}
def default_atols():
"""Get default absolute tolerances for data comparisons involving each data type."""
return {np.dtype(np.float16): 1e-1,
np.dtype(np.float32): 1e-3,
np.dtype(np.float64): 1e-20,
            np.dtype(np.bool_): 0,
np.dtype(np.int8): 0,
np.dtype(np.uint8): 0,
np.dtype(np.int32): 0,
np.dtype(np.uint32): 0,
np.dtype(np.int64): 0,
np.dtype(np.uint64): 0}
def default_numeric_eps():
"""Get default epsilon for finite difference gradient calculations with data type."""
# prefer a power-of-two eps, since no bits are dropped when serving as an input delta
return {np.dtype(np.float16): 1.0 / 2**6,
np.dtype(np.float32): 1.0 / 2**9,
np.dtype(np.float64): 1.0 / 2**14}
def effective_dtype(dat):
""" Return the most appropriate dtype for determining the tolerance used in dat comparisons
Parameters
----------
dat : np.ndarray or mx.nd.array or mx.np.ndarray
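    Examples
    --------
    A minimal illustrative sketch (hypothetical values; a plain numpy array
    carries no GPU device, so the TF32 demotion never applies and float32
    maps to itself):
    >>> effective_dtype(np.ones((2,), dtype=np.float32)) == np.dtype(np.float32)
    True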
"""
# On arch 80 gpus or later, a float32-io gemm or conv op will trim the mantissa of
# data inputs to be of comparable precision to a float16, so float16 becomes the
# 'effective dtype' for tolerance tests involving such op outputs.
# Is TF32 enabled in the device (the default on arch 80 GPUs)
def is_TF32_enabled(device):
try:
return (device.device_type == 'gpu' and
get_cuda_compute_capability(device) >= 80 and
os.environ.get('NVIDIA_TF32_OVERRIDE') != '0')
except: # pylint: disable=bare-except
return False
device = dat.device if hasattr(dat, 'device') else None
dtype = np.dtype(dat.dtype)
if dtype == np.dtype(np.float32) and is_TF32_enabled(device):
return np.dtype(np.float16)
else:
return dtype
def get_tolerance(dat, tol, default_tol):
""" Return the tolerance to be used for dat comparisons based on the given tol, datatype and device.
Parameters
----------
dat : np.ndarray or mx.nd.array or mx.np.ndarray
tol : float, or a dict of dtype->float
default_tol : default dict of dtype->float for all types
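    Examples
    --------
    Illustrative sketch (values come straight from the default tables above;
    a float64 numpy array has no GPU device, so no TF32 adjustment applies):
    >>> x = np.zeros((2,), dtype=np.float64)
    >>> get_tolerance(x, None, default_rtols())
    1e-05
    >>> get_tolerance(x, 1e-3, default_rtols())
    0.001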
"""
if isinstance(tol, numbers.Number):
return tol
# If the caller has supplied a tol dict, use that if it has an entry for dtype,
# else use the supplied default tol dict.
dtype = effective_dtype(dat)
tol = {} if tol is None else tol
return tol.get(dtype, default_tol[dtype])
def get_tols(x, y, rtol, atol):
"""For comparing two datasets 'x' and 'y', what relative and absolute tolerances should be used."""
# Tolerance analysis needs 'dtype' of 'x' and 'y', so convert numbers to numpy scalars as needed
if isinstance(x, numbers.Number):
x = np.array(x)
if isinstance(y, numbers.Number):
y = np.array(y)
# If tols are not specified, use the largest default tol for 'x' and 'y' based on their ctx and dtype.
rtol = max(get_tolerance(x, rtol, default_rtols()),
get_tolerance(y, rtol, default_rtols()))
atol = max(get_tolerance(x, atol, default_atols()),
get_tolerance(y, atol, default_atols()))
return rtol, atol
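# A minimal sketch of how the tolerance helpers compose (hypothetical values;
# CPU context assumed, so no TF32 dtype demotion). Mixed float16/float32
# inputs pick the looser float16 defaults:
#   >>> get_tols(np.zeros(1, np.float16), np.zeros(1, np.float32), None, None)
#   (0.01, 0.1)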
def get_atol(atol=None, dtype=np.dtype(np.float64)):
"""Get default numerical threshold for regression test."""
return default_atols()[dtype] if atol is None else atol
def get_rtol(rtol=None, dtype=np.dtype(np.float64)):
"""Get default numerical threshold for regression test."""
return default_rtols()[dtype] if rtol is None else rtol
def get_etol(etol=None):
"""Get default numerical threshold for regression test."""
# _TODO: get from env variable, different threshold might
# be needed for different device and dtype
return 0 if etol is None else etol
def random_arrays(*shapes):
"""Generate some random numpy arrays."""
arrays = [np.array(np.random.randn(), dtype=default_dtype())
if len(s) == 0 else np.random.randn(*s).astype(default_dtype())
for s in shapes]
if len(arrays) == 1:
return arrays[0]
return arrays
def random_uniform_arrays(*shapes, **kwargs):
"""Generate some random numpy arrays."""
low = kwargs.pop('low', 0.0)
high = kwargs.pop('high', 1.0)
dtype = kwargs.pop('dtype', default_dtype())
if len(kwargs) > 0:
        raise TypeError('Got unexpected argument(s): ' + str(kwargs.keys()))
arrays = [np.random.uniform(low, high, size=s).astype(dtype)
for s in shapes]
return arrays
def random_sample(population, k):
"""Return a k length list of the elements chosen from the population sequence."""
assert 0 <= k <= len(population)
population_copy = population[:]
np.random.shuffle(population_copy)
return population_copy[0:k]
def _sorted_items(d):
"""Return (key, value) pairs of dict 'd' in a deterministic order (sorted by key)."""
return sorted(d.items(), key=lambda t: t[0])
def _sorted_dict(d):
"""Return ordered dictionary containing items ordered by their keys."""
return OrderedDict(_sorted_items(d))
def _validate_csr_generation_inputs(num_rows, num_cols, density,
distribution="uniform"):
"""Validates inputs for csr generation helper functions
"""
total_nnz = int(num_rows * num_cols * density)
if density < 0 or density > 1:
raise ValueError("density has to be between 0 and 1")
if num_rows <= 0 or num_cols <= 0:
raise ValueError("num_rows or num_cols should be greater than 0")
if distribution == "powerlaw":
if total_nnz < 2 * num_rows:
raise ValueError(f"not supported for this density: {density}"
f" for this shape ({num_rows}, {num_cols})"
" Please keep :"
" num_rows * num_cols * density >= 2 * num_rows")
def shuffle_csr_column_indices(csr):
"""Shuffle CSR column indices per row
This allows validation of unordered column indices, which is not a requirement
for a valid CSR matrix
"""
row_count = len(csr.indptr) - 1
for i in range(row_count):
start_index = csr.indptr[i]
end_index = csr.indptr[i + 1]
sublist = np.array(csr.indices[start_index : end_index])
np.random.shuffle(sublist)
csr.indices[start_index : end_index] = sublist
def _get_uniform_dataset_csr(num_rows, num_cols, density=0.1, dtype=None,
data_init=None, shuffle_csr_indices=False):
"""Returns CSRNDArray with uniform distribution
    This generates a csr matrix with total_nnz unique randomly chosen numbers
    from num_rows*num_cols and arranges them in the 2d array in the
    following way:
    row_index = (random_number_generated // num_cols)
    col_index = random_number_generated - row_index * num_cols
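    Examples
    --------
    Hypothetical usage sketch (the data_init path requires scipy):
    >>> csr = _get_uniform_dataset_csr(4, 4, density=0.5, dtype=np.float32,
    ...                                data_init=2.0)
    >>> bool((csr.data.asnumpy() == 2.0).all())
    True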
"""
_validate_csr_generation_inputs(num_rows, num_cols, density,
distribution="uniform")
try:
from scipy import sparse as spsp
csr = spsp.rand(num_rows, num_cols, density, dtype=dtype, format="csr")
if data_init is not None:
csr.data.fill(data_init)
if shuffle_csr_indices is True:
shuffle_csr_column_indices(csr)
result = mx.nd.sparse.csr_matrix((csr.data, csr.indices, csr.indptr),
shape=(num_rows, num_cols), dtype=dtype)
except ImportError:
assert(data_init is None), \
"data_init option is not supported when scipy is absent"
assert(not shuffle_csr_indices), \
"shuffle_csr_indices option is not supported when scipy is absent"
# scipy not available. try to generate one from a dense array
dns = mx.nd.random.uniform(shape=(num_rows, num_cols), dtype=dtype)
masked_dns = dns * (dns < density)
result = masked_dns.tostype('csr')
return result
def _get_powerlaw_dataset_csr(num_rows, num_cols, density=0.1, dtype=None):
"""Returns CSRNDArray with powerlaw distribution
with exponentially increasing number of non zeros in each row.
    Not supported for cases where total_nnz < 2*num_rows. This is because
    the algorithm first ensures that no row is empty by putting a
    nonzero at the beginning of each row.
"""
_validate_csr_generation_inputs(num_rows, num_cols, density,
distribution="powerlaw")
total_nnz = int(num_rows * num_cols * density)
unused_nnz = total_nnz
output_arr = np.zeros((num_rows, num_cols), dtype=dtype)
# Start with ones on each row so that no row is empty
for row in range(num_rows):
output_arr[row][0] = 1 + rnd.uniform(0.001, 2)
unused_nnz = unused_nnz - 1
if unused_nnz <= 0:
return mx.nd.array(output_arr).tostype("csr")
# Populate rest of matrix with 2^i items in ith row.
    # if we have used all of total_nnz, return the sparse matrix;
    # otherwise, once the max column size is reached, fill up full columns until all nnz are used
col_max = 2
for row in range(num_rows):
col_limit = min(num_cols, col_max)
# In case col_limit reached assign same value to all elements, which is much faster
if col_limit == num_cols and unused_nnz > col_limit:
output_arr[row] = 1 + rnd.uniform(0.001, 2)
unused_nnz = unused_nnz - col_limit + 1
if unused_nnz <= 0:
return mx.nd.array(output_arr).tostype("csr")
else:
continue
for col_index in range(1, col_limit):
output_arr[row][col_index] = 1 + rnd.uniform(0.001, 2)
unused_nnz = unused_nnz - 1
if unused_nnz <= 0:
return mx.nd.array(output_arr).tostype("csr")
col_max = col_max * 2
if unused_nnz > 0:
raise ValueError(f"not supported for this density: {density}"
f" for this shape ({num_rows},{num_cols})")
return mx.nd.array(output_arr).tostype("csr")
def assign_each(the_input, function):
"""Return ndarray composed of passing each array value through some function"""
if function is None:
output = np.array(the_input)
else:
it_input = np.nditer(the_input, flags=['f_index'])
output = np.zeros(the_input.shape)
it_out = np.nditer(output, flags=['f_index'], op_flags=['writeonly'])
while not it_input.finished:
val_input = it_input[0]
it_out[0] = function(val_input)
it_input.iternext()
it_out.iternext()
return output
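# Illustrative use of assign_each (hypothetical values; the output dtype is
# float64 because the buffer is created with np.zeros):
#   >>> assign_each(np.array([1.0, -2.0]), abs)
#   array([1., 2.])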
def assign_each2(input1, input2, function):
"""Return ndarray composed of passing two array values through some function"""
if function is None:
output = np.array(input1)
else:
assert input1.shape == input2.shape
it_input1 = np.nditer(input1, flags=['f_index'])
it_input2 = np.nditer(input2, flags=['f_index'])
output = np.zeros(input1.shape)
it_out = np.nditer(output, flags=['f_index'], op_flags=['writeonly'])
while not it_input1.finished:
val_input1 = it_input1[0]
val_input2 = it_input2[0]
it_out[0] = function(val_input1, val_input2)
it_input1.iternext()
it_input2.iternext()
it_out.iternext()
return output
def create_2d_np_tensor(rows, columns, dtype=np.int64):
inp = mx.np.arange(0, rows, dtype=dtype).reshape(rows, 1)
inp = mx.np.broadcast_to(inp, shape=(inp.shape[0], columns))
return inp
# For testing Large Tensors having total size > 2^32 elements
def create_2d_tensor(rows, columns, dtype=np.int64):
a = mx.nd.arange(0, rows, dtype=dtype).reshape(rows, 1)
b = mx.nd.broadcast_to(a, shape=(a.shape[0], columns))
return b
# For testing Large Vectors having total size > 2^32 elements
def create_vector(size, dtype=np.int64):
a = mx.nd.arange(0, size, dtype=dtype)
return a
def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=None,
data_init=None, rsp_indices=None, modifier_func=None,
shuffle_csr_indices=False, ctx=None):
"""Generate a random sparse ndarray. Returns the ndarray, value(np) and indices(np)
Parameters
----------
shape: list or tuple
stype: str
valid values: "csr" or "row_sparse"
density: float, optional
should be between 0 and 1
distribution: str, optional
valid values: "uniform" or "powerlaw"
dtype: numpy.dtype, optional
default value is None
Returns
-------
Result of type CSRNDArray or RowSparseNDArray
Examples
--------
Below is an example of the powerlaw distribution with csr as the stype.
It calculates the nnz using the shape and density.
    It fills up the ndarray with an exponentially increasing number of elements.
    If there are enough unused_nnzs, the (n+1)th row will have twice as many nnzs as the nth row;
    otherwise, the remaining unused_nnzs will be used in the (n+1)th row.
    If the number of columns is too small and the maximum column size has been reached, it fills
    up all remaining columns in all following rows until the required density is reached.
>>> csr_arr, _ = rand_sparse_ndarray(shape=(5, 16), stype="csr",
density=0.50, distribution="powerlaw")
>>> indptr = csr_arr.indptr.asnumpy()
>>> indices = csr_arr.indices.asnumpy()
>>> data = csr_arr.data.asnumpy()
>>> row2nnz = len(data[indptr[1]:indptr[2]])
>>> row3nnz = len(data[indptr[2]:indptr[3]])
>>> assert(row3nnz == 2*row2nnz)
>>> row4nnz = len(data[indptr[3]:indptr[4]])
>>> assert(row4nnz == 2*row3nnz)
"""
ctx = ctx if ctx else default_device()
density = rnd.rand() if density is None else density
dtype = default_dtype() if dtype is None else dtype
distribution = "uniform" if distribution is None else distribution
if stype == 'row_sparse':
assert (distribution == "uniform"), \
f"Distribution {distribution} not supported for row_sparse"
# sample index
if rsp_indices is not None:
indices = rsp_indices
assert(len(indices) <= shape[0])
else:
idx_sample = rnd.rand(shape[0])
indices = np.argwhere(idx_sample < density).flatten()
if indices.shape[0] == 0:
result = mx.nd.zeros(shape, stype='row_sparse', dtype=dtype, ctx=ctx)
return result, (np.array([], dtype=dtype), np.array([]))
# generate random values
val = rnd.rand(indices.shape[0], *shape[1:]).astype(dtype)
# Allow caller to override or adjust random values
if data_init is not None:
val.fill(data_init)
if modifier_func is not None:
val = assign_each(val, modifier_func)
arr = mx.nd.sparse.row_sparse_array((val, indices), shape=shape, dtype=dtype, ctx=ctx)
return arr, (val, indices)
elif stype == 'csr':
assert len(shape) == 2
if distribution == "uniform":
csr = _get_uniform_dataset_csr(shape[0], shape[1], density,
data_init=data_init,
shuffle_csr_indices=shuffle_csr_indices, dtype=dtype).as_in_context(ctx)
return csr, (csr.indptr, csr.indices, csr.data)
elif distribution == "powerlaw":
csr = _get_powerlaw_dataset_csr(shape[0], shape[1], density=density, dtype=dtype).as_in_context(ctx)
return csr, (csr.indptr, csr.indices, csr.data)
else:
assert(False), f"Distribution not supported: {distribution}"
return False
else:
assert(False), "unknown storage type"
return False
def rand_ndarray(shape, stype='default', density=None, dtype=None, modifier_func=None,
shuffle_csr_indices=False, distribution=None, ctx=None):
"""Generate a random sparse ndarray. Returns the generated ndarray."""
ctx = ctx if ctx else default_device()
if stype == 'default':
arr = mx.nd.array(random_arrays(shape), dtype=dtype, ctx=ctx)
else:
arr, _ = rand_sparse_ndarray(shape, stype, density=density,
modifier_func=modifier_func, dtype=dtype,
shuffle_csr_indices=shuffle_csr_indices,
distribution=distribution, ctx=ctx)
return arr
def create_sparse_array(shape, stype, data_init=None, rsp_indices=None,
dtype=None, modifier_func=None, density=.5,
shuffle_csr_indices=False):
"""Create a sparse array, For Rsp, assure indices are in a canonical format"""
if stype == 'row_sparse':
if rsp_indices is not None:
arr_indices = np.asarray(rsp_indices)
arr_indices.sort()
else:
arr_indices = None
arr_data, (_, _) = rand_sparse_ndarray(shape, stype,
density=density,
data_init=data_init,
rsp_indices=arr_indices,
dtype=dtype,
modifier_func=modifier_func)
elif stype == 'csr':
arr_data, (_, _, _) = rand_sparse_ndarray(shape,
stype,
density=density,
data_init=data_init,
dtype=dtype,
modifier_func=modifier_func,
shuffle_csr_indices=shuffle_csr_indices)
else:
msg = "Unknown storage type: " + stype
raise AssertionError(msg)
return arr_data
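# Illustrative sketch (hypothetical values): a 4x4 row_sparse array whose only
# nonzero rows are 0 and 2; the indices are sorted into canonical order first:
#   >>> arr = create_sparse_array((4, 4), 'row_sparse', rsp_indices=[2, 0])
#   >>> arr.indices.asnumpy().tolist()
#   [0, 2]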
def create_sparse_array_zd(shape, stype, density, data_init=None,
rsp_indices=None, dtype=None, modifier_func=None,
shuffle_csr_indices=False):
"""Create sparse array, using only rsp_indices to determine density"""
if stype == 'row_sparse':
density = 0.0
if rsp_indices is not None:
assert len(rsp_indices) <= shape[0]
return create_sparse_array(shape, stype,
data_init=data_init,
rsp_indices=rsp_indices,
dtype=dtype,
modifier_func=modifier_func,
density=density,
shuffle_csr_indices=shuffle_csr_indices)
def rand_shape_2d(dim0=10, dim1=10, allow_zero_size=False):
low = 0 if allow_zero_size else 1
return rnd.randint(low, dim0 + 1), rnd.randint(low, dim1 + 1)
def rand_shape_3d(dim0=10, dim1=10, dim2=10, allow_zero_size=False):
low = 0 if allow_zero_size else 1
return rnd.randint(low, dim0 + 1), rnd.randint(low, dim1 + 1), rnd.randint(low, dim2 + 1)
def rand_shape_nd(num_dim, dim=10, allow_zero_size=False):
low = 0 if allow_zero_size else 1
return tuple(rnd.randint(low, dim+1, size=num_dim))
def rand_coord_2d(x_low, x_high, y_low, y_high):
x = np.random.randint(x_low, x_high, dtype=np.int64)
y = np.random.randint(y_low, y_high, dtype=np.int64)
return x, y
def np_reduce(dat, axis, keepdims, numpy_reduce_func):
"""Compatible reduce for old version of NumPy.
Parameters
----------
dat : np.ndarray
Same as NumPy.
axis : None or int or list-like
Same as NumPy.
keepdims : bool
Same as NumPy.
numpy_reduce_func : function
A NumPy reducing function like ``np.sum`` or ``np.max``.
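    Examples
    --------
    Illustrative sketch mirroring ``np.sum`` with ``keepdims``:
    >>> np_reduce(np.arange(6).reshape(2, 3), axis=1, keepdims=True,
    ...           numpy_reduce_func=np.sum)
    array([[ 3],
           [12]])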
"""
if isinstance(axis, int):
axis = [axis]
else:
axis = list(axis) if axis is not None else range(len(dat.shape))
ret = dat
for i in reversed(sorted(axis)):
ret = numpy_reduce_func(ret, axis=i)
if keepdims:
keepdims_shape = list(dat.shape)
for i in axis:
keepdims_shape[i] = 1
ret = ret.reshape(tuple(keepdims_shape))
return ret
def _find_max_violation(a, b, rtol, atol):
"""Finds and returns the location of maximum violation."""
# 'smart' absdiff that considers inf's as equals (to match np.allclose)
absdiff = np.where(np.equal(a, b), 0, np.abs(a-b))
tol = atol + rtol*np.abs(b)
violation = absdiff/(tol+1e-20)
loc = np.argmax(violation)
idx = np.unravel_index(loc, violation.shape)
return idx, np.max(violation)
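# Sketch of _find_max_violation (hypothetical values): with a=[1.0, 2.0],
# b=[1.0, 2.5], rtol=0 and atol=0.1, the scaled violations are [0, 5], so
# the call returns index (1,) with a maximum violation of about 5x tolerance:
#   >>> _find_max_violation(np.array([1.0, 2.0]), np.array([1.0, 2.5]), 0.0, 0.1)
#   ((1,), 5.0)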
def same(a, b):
"""Test if two NumPy arrays are the same.
Parameters
----------
a : np.ndarray
b : np.ndarray
"""
return np.array_equal(a, b)
def checkShapes(a, b):
if a.shape != b.shape:
msg = npt.build_err_msg([a, b],
err_msg="a.shape = {} and b.shape = {} are not equal"
.format(str(a.shape), str(b.shape)))
raise AssertionError(msg)
def almost_equal(a, b, rtol=None, atol=None, equal_nan=False, use_broadcast=True):
"""Test if two numpy arrays are almost equal."""
# pylint: disable=unexpected-keyword-arg
if not use_broadcast:
checkShapes(a, b)
return np.allclose(a, b, rtol=get_rtol(rtol), atol=get_atol(atol), equal_nan=equal_nan)
# pylint: enable=unexpected-keyword-arg
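# Illustrative use of almost_equal (hypothetical values; passing None selects
# the float64 defaults rtol=1e-5, atol=1e-20):
#   >>> almost_equal(np.array([1.0]), np.array([1.0 + 1e-8]))
#   True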
def locationError(a, b, index, names, maxError=False):
"""Create element mismatch comment
Parameters
----------
a, b : compared np.ndarray's
index : tuple of coordinate arrays
Location of violation
names : tuple of names
The names of compared arrays.
maxError: boolean, optional
        Flag indicating that the maximum error is being reported.
"""
maximum = "maximum " if maxError else ""
return f"Location of {maximum} error: {str(index)}, {names[0]}={a[index]:.8f}, {names[1]}={b[index]:.8f}"
def assert_almost_equal(a, b, rtol=None, atol=None, names=('a', 'b'), equal_nan=False,
use_broadcast=True, mismatches=(10, 10)):
"""Test that two numpy arrays are almost equal. Raise exception message if not.
Parameters
----------
a : np.ndarray or mx.nd.array
b : np.ndarray or mx.nd.array
rtol : None or float or dict of dtype -> float
The relative threshold. Default threshold will be used if set to ``None``.
atol : None or float or dict of dtype -> float
The absolute threshold. Default threshold will be used if set to ``None``.
names : tuple of names, optional
The names used in error message when an exception occurs
equal_nan : boolean, optional
The flag determining how to treat NAN values in comparison
mismatches : tuple of mismatches
        Maximum number of mismatches to be printed (mismatches[0]) and to be determined (mismatches[1])
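    Examples
    --------
    A minimal sketch (hypothetical values; default tolerances are derived
    from the input dtypes via ``get_tols``):
    >>> assert_almost_equal(np.array([1.0, 2.0]), np.array([1.0, 2.0 + 1e-7]))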
"""
if not use_broadcast:
checkShapes(a, b)
rtol, atol = get_tols(a, b, rtol, atol)
if isinstance(a, mx.numpy.ndarray):
a = a.asnumpy()
if isinstance(b, mx.numpy.ndarray):
b = b.asnumpy()
use_np_allclose = isinstance(a, np.ndarray) and isinstance(b, np.ndarray)
if not use_np_allclose:
if not (hasattr(a, 'ctx') and hasattr(b, 'ctx') and a.device == b.device and a.dtype == b.dtype):
use_np_allclose = True
if isinstance(a, mx.nd.NDArray):
a = a.asnumpy()
if isinstance(b, mx.nd.NDArray):
b = b.asnumpy()
if use_np_allclose:
if hasattr(a, 'dtype') and a.dtype == np.bool_ and hasattr(b, 'dtype') and b.dtype == np.bool_:
np.testing.assert_equal(a, b)
return
if almost_equal(a, b, rtol, atol, equal_nan=equal_nan):
return
else:
output = mx.nd.contrib.allclose(a, b, rtol, atol, equal_nan)
if output.asnumpy() == 1:
return
a = a.asnumpy()
b = b.asnumpy()
index, rel = _find_max_violation(a, b, rtol, atol)
if index != ():
# a, b are the numpy arrays
indexErr = index
relErr = rel
print('\n*** Maximum errors for vector of size {}: rtol={}, atol={}\n'.format(a.size, rtol, atol))
aTmp = a.copy()
bTmp = b.copy()
i = 1
while i <= a.size:
if i <= mismatches[0]:
print(f"{i:3d}: Error {rel} {locationError(a, b, index, names)}")
aTmp[index] = bTmp[index] = 0
if almost_equal(aTmp, bTmp, rtol, atol, equal_nan=equal_nan):
break
i += 1
if i <= mismatches[1] or mismatches[1] <= 0:
index, rel = _find_max_violation(aTmp, bTmp, rtol, atol)
else:
break
mismatchDegree = "at least " if mismatches[1] > 0 and i > mismatches[1] else ""
errMsg = f"Error {relErr} exceeds tolerance rtol={rtol:e}, atol={atol:e} " \
f"(mismatch {mismatchDegree}{100*i/a.size}%).\n" \
f"{locationError(a, b, indexErr, names, maxError=True)}"
else:
errMsg = f"Error {rel} exceeds tolerance rtol={rtol:e}, atol={atol:e}.\n"
np.set_printoptions(threshold=4, suppress=True)
msg = npt.build_err_msg([a, b], err_msg=errMsg)
raise AssertionError(msg)
def assert_allclose(a, b, rtol=1e-07, atol=0, equal_nan=True):
assert_almost_equal(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
def assert_almost_equal_with_err(a, b, rtol=None, atol=None, etol=None,
names=('a', 'b'), equal_nan=False, mismatches=(10, 10)):
"""Test that two numpy arrays are almost equal within given error rate. Raise exception message if not.
Parameters
----------
a : np.ndarray
b : np.ndarray
rtol : None or float or dict of dtype -> float
The relative threshold. Default threshold will be used if set to ``None``.
atol : None or float or dict of dtype -> float
The absolute threshold. Default threshold will be used if set to ``None``.
etol : None or float
        The error rate threshold. If etol is a float, the check passes as long as
        error_rate < etol, even if individual mismatches are found.
names : tuple of names, optional
The names used in error message when an exception occurs
equal_nan : boolean, optional
The flag determining how to treat NAN values in comparison
mismatches : tuple of mismatches
        Maximum number of mismatches to be printed (mismatches[0]) and to be determined (mismatches[1])
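    Examples
    --------
    Illustrative sketch (hypothetical values): one mismatch out of four
    elements is a 25% error rate, which ``etol=0.5`` tolerates:
    >>> a = np.array([1.0, 2.0, 3.0, 4.0])
    >>> b = np.array([1.0, 2.0, 3.0, 40.0])
    >>> assert_almost_equal_with_err(a, b, rtol=1e-5, atol=1e-8, etol=0.5)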
"""
etol = get_etol(etol)
if etol > 0:
rtol, atol = get_tols(a, b, rtol, atol)
if isinstance(a, mx.nd.NDArray):
a = a.asnumpy()
if isinstance(b, mx.nd.NDArray):
b = b.asnumpy()
equals = np.isclose(a, b, rtol=rtol, atol=atol)
err = 1 - np.count_nonzero(equals) / equals.size
if err > etol:
index, rel = _find_max_violation(a, b, rtol, atol)
indexErr = index
relErr = rel
print('\n*** Maximum errors for vector of size {}: rtol={}, atol={}\n'.format(a.size, rtol, atol))
aTmp = a.copy()
bTmp = b.copy()
i = 1
while i <= a.size:
if i <= mismatches[0]:
print(f"{i:3d}: Error {rel} {locationError(a, b, index, names)}")
aTmp[index] = bTmp[index] = 0
if almost_equal(aTmp, bTmp, rtol, atol, equal_nan=equal_nan):
break
i += 1
if i <= mismatches[1] or mismatches[1] <= 0:
index, rel = _find_max_violation(aTmp, bTmp, rtol, atol)
else:
break
mismatchDegree = "at least " if mismatches[1] > 0 and i > mismatches[1] else ""
errMsg = f"Error {relErr} exceeds tolerance rtol={rtol:e}, atol={atol:e} " \
f"(mismatch {mismatchDegree}{100*i/a.size}%).\n" \
f"{locationError(a, b, indexErr, names, maxError=True)}"
np.set_printoptions(threshold=4, suppress=True)
msg = npt.build_err_msg([a, b], err_msg=errMsg)
raise AssertionError(msg)
else:
assert_almost_equal(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
def assert_almost_equal_ignore_nan(a, b, rtol=None, atol=None, names=('a', 'b')):
"""Test that two NumPy arrays are almost equal (ignoring NaN in either array).
    Combines a relative and absolute measure of approximate equality.
If either the relative or absolute check passes, the arrays are considered equal.
Including an absolute check resolves issues with the relative check where all
array values are close to zero.
Parameters
----------
a : np.ndarray
b : np.ndarray
rtol : None or float
The relative threshold. Default threshold will be used if set to ``None``.
atol : None or float
The absolute threshold. Default threshold will be used if set to ``None``.
"""
a = np.copy(a)
b = np.copy(b)
nan_mask = np.logical_or(np.isnan(a), np.isnan(b))
a[nan_mask] = 0
b[nan_mask] = 0
assert_almost_equal(a, b, rtol, atol, names)
def assert_exception(f, exception_type, *args, **kwargs):
"""Test that function f will throw an exception of type given by `exception_type`"""
try:
f(*args, **kwargs)
assert(False)
except exception_type:
return
def _parse_location(sym, location, ctx, dtype=default_dtype()):
"""Parses the given location to a ordered dictionary.
Arguments of the provided op `sym` are used as dictionary keys
and elements of `location` are used as values.
Parameters
----------
sym : Symbol
Symbol containing op
location : list or tuple or dict
Argument values location
- if type is list or tuple of `np.ndarray`
            inner elements are arrays corresponding to
``sym.list_arguments()``.
- if type is dict of str -> `np.ndarray`
maps the name of arguments to the corresponding `np.ndarray`.
*In either case, value of all the arguments must be provided.*
ctx : Device
Device context.
dtype: "asnumpy" or np.float16 or np.float32 or np.float64
If dtype is "asnumpy" then the mx.nd.array created will have the same
        type as the numpy array from which it is copied.
Otherwise, dtype is the explicit datatype for all mx.nd.array objects
created in this function.
Returns
-------
dict
Dictionary with `sym` arguments as keys and `location` elements as
values.
Examples
-------
>>> a = mx.symbol.Variable('a')
>>> b = mx.symbol.Variable('b')
>>> l1 = np.ndarray([2,3])
>>> l2 = np.ndarray([3,4])
>>> _parse_location(a * b, [l1, l2], None)
{'a': <NDArray 2x3 @cpu(0)>, 'b': <NDArray 3x4 @cpu(0)>}
>>> _parse_location(a * b, {'a': l1, 'b': l2}, None)
{'a': <NDArray 2x3 @cpu(0)>, 'b': <NDArray 3x4 @cpu(0)>}
>>> _parse_location(a * b, {'a': l1}, None)
ValueError: Symbol arguments and keys of the given location do not match.
"""
assert isinstance(location, (dict, list, tuple))
assert dtype == "asnumpy" or dtype in (np.float16, np.float32, np.float64)
if isinstance(location, dict):
if set(location.keys()) != set(sym.list_arguments()):
raise ValueError("Symbol arguments and keys of the given location do not match."
f"symbol args:{str(set(sym.list_arguments()))}, location.keys():{str(set(location.keys()))}")
else:
location = {k: v for k, v in zip(sym.list_arguments(), location)}
location = {k: mx.nd.array(v, ctx=ctx, dtype=v.dtype if dtype == "asnumpy" else dtype) \
if isinstance(v, np.ndarray) else v for k, v in location.items()}
return _sorted_dict(location)
def _parse_aux_states(sym, aux_states, ctx, dtype=default_dtype()):
"""Parses the given auxiliary states to a dictionary.
Auxiliary states of the provided op `sym` are used as dictionary
keys and elements of `aux_states` are used as values.
Parameters
----------
sym : Symbol
Symbol containing op
aux_states : None or list or dict
Aux states
- if type is list or tuple of `np.ndarray`
            inner elements are arrays corresponding to
``sym.list_auxiliary_states()``.
- if type is dict of str -> `np.ndarray`
maps the name of arguments to the corresponding `np.ndarray`.
*In either case, all aux states of `sym` must be provided.*
ctx : Device
Device context.
dtype: "asnumpy" or np.float16 or np.float32 or np.float64
If dtype is "asnumpy" then the mx.nd.array created will have the same
        type as the numpy array from which it is copied.
Otherwise, dtype is the explicit datatype for all mx.nd.array objects
created in this function.
Returns
-------
dict
Dictionary with `sym` aux states as keys and `aux_states` elements
as values.
Examples
-------
>>> data = mx.symbol.Variable('data')
>>> weight = mx.sym.Variable(name='fc1_weight')
>>> fc1 = mx.symbol.FullyConnected(data = data, weight=weight, name='fc1', num_hidden=128)
>>> fc2 = mx.symbol.BatchNorm(fc1, name='batchnorm0')
>>> mean_states = np.ones(3)
>>> var_states = np.ones(3)
>>> _parse_aux_states(fc2, [mean_states, var_states], None)
{'batchnorm0_moving_var': <NDArray 3 @cpu(0)>, 'batchnorm0_moving_mean': <NDArray 3 @cpu(0)>}
>>> _parse_aux_states(fc2, {'batchnorm0_moving_var': mean_states,
... 'batchnorm0_moving_mean': var_states}, None)
{'batchnorm0_moving_var': <NDArray 3 @cpu(0)>, 'batchnorm0_moving_mean': <NDArray 3 @cpu(0)>}
>>> _parse_aux_states(fc2, {'batchnorm0_moving_var': mean_states}, None)
ValueError: Symbol aux_states names and given aux_states do not match.
"""
assert dtype == "asnumpy" or dtype in (np.float16, np.float32, np.float64)
if aux_states is not None:
if isinstance(aux_states, dict):
if set(aux_states.keys()) != set(sym.list_auxiliary_states()):
raise ValueError("Symbol aux_states names and given aux_states do not match."
f"symbol aux_names:{str(set(sym.list_auxiliary_states()))}, aux_states.keys:{str(set(aux_states.keys()))}")
elif isinstance(aux_states, (list, tuple)):
aux_names = sym.list_auxiliary_states()
aux_states = {k:v for k, v in zip(aux_names, aux_states)}
aux_states = {k: mx.nd.array(v, ctx=ctx, dtype=v.dtype if dtype == "asnumpy" else dtype) \
for k, v in aux_states.items()}
return aux_states
def numeric_grad(executor, location, aux_states=None, eps=1e-4,
use_forward_train=True, dtype=default_dtype()):
"""Calculates a numeric gradient via finite difference method.
Class based on Theano's `theano.gradient.numeric_grad` [1]
Parameters
----------
executor : Executor
Executor that computes the forward pass.
location : list of numpy.ndarray or dict of str to numpy.ndarray
Argument values used as location to compute gradient
Maps the name of arguments to the corresponding numpy.ndarray.
Value of all the arguments must be provided.
aux_states : None or list of numpy.ndarray or dict of str to numpy.ndarray, optional
Auxiliary states values used as location to compute gradient
Maps the name of aux_states to the corresponding numpy.ndarray.
Value of all the auxiliary arguments must be provided.
eps : float, optional
Epsilon for the finite-difference method.
use_forward_train : bool, optional
Whether to use `is_train=True` in testing.
dtype: np.float16 or np.float32 or np.float64
Datatype for mx.nd.array.
References
---------
    .. [1] https://github.com/Theano/Theano/blob/master/theano/gradient.py
"""
def as_stype(var, stype, dtype):
return mx.nd.cast_storage(mx.nd.array(var, dtype=dtype), stype=stype)
assert dtype in (np.float16, np.float32, np.float64)
approx_grads = {k: np.zeros(v.shape, dtype=dtype)
for k, v in location.items()}
for k, v in location.items():