Function inputs & outputs can be annotated to help the reader better understand intended shape/dtype.
import numpy as np
from etils.array_types import Array, FloatArray, f32, ui8

def _normalize_image(img: ui8['h w c']) -> f32['h w c']:
  # Linearly map pixel values from [0, 255] to [-1, 1].
  return np.interp(img, (0, 255), (-1, 1)).astype(np.float32)
This tells the reader that the function takes a 3D uint8 array and returns a 3D float32 array of the same shape.
Note: These typing annotations are not (yet) detected by static type checking tools. However, they are already helpful as documentation.
Shape annotations follow these conventions:
- Valid symbols:
  - `str`: Named axis (e.g. `f32['batch height width']`)
  - `int`: Static axis (e.g. `f32[28, 28]`, `f32['h w 3']`)
  - `_`: Anonymous axis (e.g. `f32['batch _ _ c']`, `f32[None, 3]`)
  - `...`: Anonymous zero or more axes (e.g. `f32['... h w c']`, `f32[..., 3]`)
  - `*name`: Named zero or more axes (e.g. `f32['*batch_dims h w c']`)
  - `+`, `-`, `/`, `*` operators (e.g. `f32['h/2 w/2 c1+c2']`)
- Typing annotations are only considered to be consistent per function call, so a function `f32['h w'] -> f32['h w']` can be called twice with 2 different image sizes.
- Passing multiple values is the same as concatenating the string (e.g. `f32[..., 'h', 'w', 3] == f32['... h w 3']`)
- DType can be:
  - `Array[...]`: Any dtype accepted
  - `FloatArray` (accepts `f32`, `bf16`, ...), `IntArray` (accepts `ui8`, `i32`, `i64`, ...): Respectively accept a union of multiple types
  - `f32`, `ui8`, ...: Specific type
`ArrayLike[f32[...]]` indicates that any array-convertible value is accepted (`list`, `tuple`, ...).
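The conventions above can be made concrete with a small matcher. This is only a sketch in plain Python, not the library's implementation: `match_shape` and `_bind` are hypothetical names, operators such as `h/2` are not handled, and at most one `...` or `*name` token is supported. It shows how named axes could be bound within one call and why a second call is free to use different sizes.

```python
from typing import Dict, Optional, Sequence, Tuple, Union

# A named axis binds to one size; a '*name' axis binds to a tuple of sizes.
Binding = Union[int, Tuple[int, ...]]


def _bind(bindings: Dict[str, Binding], name: str, value: Binding) -> None:
  """Records `name -> value`, or raises if it conflicts with an earlier use."""
  if bindings.setdefault(name, value) != value:
    raise ValueError(f'Axis {name!r}: expected {bindings[name]}, got {value}')


def match_shape(
    spec: str,
    shape: Sequence[int],
    bindings: Optional[Dict[str, Binding]] = None,
) -> Dict[str, Binding]:
  """Hypothetical helper: checks `shape` against a spec like '... h w 3'."""
  bindings = {} if bindings is None else bindings
  tokens = spec.split()
  dims = tuple(shape)
  variadic = [i for i, t in enumerate(tokens) if t == '...' or t.startswith('*')]
  if len(variadic) > 1:
    raise ValueError(f'This sketch supports one variadic axis: {spec!r}')
  if variadic:
    i = variadic[0]
    n_var = len(dims) - (len(tokens) - 1)  # Dims absorbed by '...' / '*name'.
    if n_var < 0:
      raise ValueError(f'Shape {dims} has too few axes for {spec!r}')
    if tokens[i].startswith('*'):
      _bind(bindings, tokens[i], dims[i:i + n_var])
    tokens = tokens[:i] + tokens[i + 1:]
    dims = dims[:i] + dims[i + n_var:]
  elif len(dims) != len(tokens):
    raise ValueError(f'Shape {dims} does not match rank of {spec!r}')
  for token, dim in zip(tokens, dims):
    if token == '_':
      continue  # Anonymous axis: any size is accepted.
    if token.isdigit():
      if int(token) != dim:
        raise ValueError(f'Static axis: expected {token}, got {dim}')
    else:
      _bind(bindings, token, dim)
  return bindings


# One "call": input and output share bindings, so 'h w c' must agree.
call = match_shape('h w c', (28, 28, 3))
match_shape('h w c', (28, 28, 3), call)

# A second call starts from fresh bindings, so another image size is fine.
match_shape('h w c', (64, 48, 3))

print(match_shape('*batch h w 3', (8, 4, 32, 32, 3)))
# {'*batch': (8, 4), 'h': 32, 'w': 32}
```

Passing the same `bindings` dict for every argument of one call, and a fresh one for the next call, is what makes the annotations consistent per call rather than globally.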
You can decorate your function with `@enp.check_and_normalize_arrays` so that array shape/dtype are dynamically validated at runtime:
from etils import enp
from etils.array_types import FloatArray, IntArray

@enp.check_and_normalize_arrays
def add(x: IntArray, y: IntArray) -> IntArray:
  return x + y
Functions decorated with `enp.check_and_normalize_arrays` support `np`, `jnp`, and `tnp`:
- If args are mixed between `jnp` and `tnp`, an error is raised.
- If args are `xnp` with `np`, the `np` array is auto-cast to `xnp`.
- You can force usage of TF / Jax / Numpy by passing an `xnp=` kwarg (automatically added).
add(np.array(1), jnp.array(2)) # np auto-casted to jnp
add(tf.constant(1), jnp.array(2)) # Error jnp / TF conflict
add(tf.constant(1), jnp.array(2), xnp=jnp) # Force jnp usage
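The resolution rules above can be summarised as a small decision function. The following is a sketch only, with backends represented by their names as strings rather than real modules; `infer_xnp` is a hypothetical name and not part of `etils`, and returning `'np'` when no arguments are given is an assumption of this sketch.

```python
from typing import Iterable, Optional

_KNOWN = frozenset({'np', 'jnp', 'tnp'})


def infer_xnp(arg_backends: Iterable[str], xnp: Optional[str] = None) -> str:
  """Hypothetical sketch: picks the backend used inside the function."""
  if xnp is not None:
    return xnp  # An explicit `xnp=` always wins.
  backends = set(arg_backends)
  unknown = backends - _KNOWN
  if unknown:
    raise ValueError(f'Unknown backends: {sorted(unknown)}')
  non_np = backends - {'np'}
  if len(non_np) > 1:
    raise ValueError(f'Conflicting backends: {sorted(non_np)}')
  # Plain `np` arrays are cast to the single other backend, if there is one.
  return non_np.pop() if non_np else 'np'


print(infer_xnp(['np', 'jnp']))              # jnp
print(infer_xnp(['tnp', 'jnp'], xnp='jnp'))  # jnp
```

Calling `infer_xnp(['tnp', 'jnp'])` without the `xnp` override raises a `ValueError`, mirroring the jnp / TF conflict above.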
Using `strict=False` makes your function auto-convert `list`, `int`, ... to `xnp.ndarray`:
@enp.check_and_normalize_arrays(strict=False)
def add(x: IntArray, y: IntArray):
  return x + y

add([1, 2, 3], 10) # == np.array([11, 12, 13])
add([1, 2, 3], 10, xnp=jnp) # == jnp.array([11, 12, 13])
add([1, 2, 3], tf.constant(10)) # == tnp.array([11, 12, 13])
You can add an `xnp: enp.NpModule = ...` kwarg to your function, which will be automatically assigned to the auto-inferred `xnp`:
@enp.check_and_normalize_arrays(strict=False)
def add(x: IntArray, y: IntArray, *, xnp: enp.NpModule = ...):
  return xnp.add(x, y)

add(1, [1, 2, 3]) # Inside the function, `xnp=np`
add(tf.constant(1), tf.constant(2)) # Inside the function, `xnp=tnp`
There are 2 levels of checking:
- Using type union: `IntArray` (accepts `ui8`, `i32`, `i64`, ...), `FloatArray` (accepts `f32`, `bf16`, ...)
- Using specific type: `f32`, `ui8`, ...
Using type unions allows your functions to support quantization, ...
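The difference between the two levels can be sketched with dtype names as plain strings. This is an illustration only: `accepts` is a hypothetical helper, and the union members listed are just the ones named in this document, so the library's real unions may contain more dtypes.

```python
# Assumed membership, limited to the dtypes named above; the real unions
# are larger (the "..." in the text).
UNIONS = {
    'FloatArray': frozenset({'float32', 'bfloat16'}),
    'IntArray': frozenset({'uint8', 'int32', 'int64'}),
}
SPECIFIC = {
    'f32': 'float32',
    'bf16': 'bfloat16',
    'ui8': 'uint8',
    'i32': 'int32',
    'i64': 'int64',
}


def accepts(annotation: str, dtype_name: str) -> bool:
  """Hypothetical check: does `annotation` allow an array of `dtype_name`?"""
  if annotation == 'Array':
    return True  # Any dtype is accepted.
  if annotation in UNIONS:
    return dtype_name in UNIONS[annotation]  # Union-level check.
  return SPECIFIC[annotation] == dtype_name  # Exact-dtype check.


print(accepts('FloatArray', 'bfloat16'))  # True
print(accepts('f32', 'bfloat16'))         # False
```

A function annotated with `FloatArray` keeps working when its inputs move from `f32` to `bf16`, whereas one annotated with `f32` would reject them.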
Shape checking is not yet supported, but it is planned.