Skip to content

Commit

Permalink
Implemented directional variograms for unstructured (#87)
Browse files Browse the repository at this point in the history
* implemented a prototype for the unstructured function with angles.

* minor fix with variable name

* added docstring and arguments for angle estimation

* added a working version which tests the basic functionality

* docstring adaption

* changed for automatic testcase data generation and split up the test cases

* changed default angle tolerance to 25deg

* added option to also return the counts (number of pairs) from unstructured

* implemented a meaningful test for 2d variogram estimation

* bugfix for 3d case when elevation is 90° or 270°

* implemented some basic 3d test cases

* vario: cleanup cython routines; use greate-circle for tolerance in 3D; check both directions between point pairs

* vario: doc update; correct intervals for angles; general formatting of angles array

* vario: better handling of angle ranges

* vario: fix wrong assumption about hemisphere for angles

Co-authored-by: MuellerSeb <[email protected]>
  • Loading branch information
TobiasGlaubach and MuellerSeb authored Nov 6, 2020
1 parent 67a228d commit 95b1431
Show file tree
Hide file tree
Showing 3 changed files with 449 additions and 25 deletions.
97 changes: 93 additions & 4 deletions gstools/variogram/estimator.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import numpy as np
cimport cython
from cython.parallel import prange, parallel
from libcpp.vector cimport vector
from libc.math cimport fabs, sqrt
from libc.math cimport fabs, sqrt, atan2, acos, asin, sin, cos, M_PI
cimport numpy as np


Expand Down Expand Up @@ -47,6 +47,68 @@ cdef inline double _distance_3d(
(y[i] - y[j]) * (y[i] - y[j]) +
(z[i] - z[j]) * (z[i] - z[j]))

cdef inline bint _angle_test_1d(
const double[:] x,
const double[:] y,
const double[:] z,
const double[:] angles,
const double angles_tol,
const int i,
const int j
) nogil:
return True

cdef inline bint _angle_test_2d(
const double[:] x,
const double[:] y,
const double[:] z,
const double[:] angles,
const double angles_tol,
const int i,
const int j
) nogil:
cdef double dx = x[i] - x[j]
cdef double dy = y[i] - y[j]
# azimuth
cdef double phi1 = atan2(dy,dx) % (2.0 * M_PI)
cdef double phi2 = atan2(-dy,-dx) % (2.0 * M_PI)
# check both directions (+/-)
cdef bint dir1 = fabs(phi1 - angles[0]) <= angles_tol
cdef bint dir2 = fabs(phi2 - angles[0]) <= angles_tol
return dir1 or dir2

cdef inline bint _angle_test_3d(
const double[:] x,
const double[:] y,
const double[:] z,
const double[:] angles,
const double angles_tol,
const int i,
const int j
) nogil:
cdef double dx = x[i] - x[j]
cdef double dy = y[i] - y[j]
cdef double dz = z[i] - z[j]
cdef double dr = sqrt(dx**2 + dy**2 + dz**2)
# azimuth
cdef double phi1 = atan2(dy, dx) % (2.0 * M_PI)
cdef double phi2 = atan2(-dy, -dx) % (2.0 * M_PI)
# elevation
cdef double theta1 = acos(dz / dr)
cdef double theta2 = acos(-dz / dr)
# definitions for great-circle distance (for tolerance check)
cdef double dx1 = sin(theta1) * cos(phi1) - sin(angles[1]) * cos(angles[0])
cdef double dy1 = sin(theta1) * sin(phi1) - sin(angles[1]) * sin(angles[0])
cdef double dz1 = cos(theta1) - cos(angles[1])
cdef double dx2 = sin(theta2) * cos(phi2) - sin(angles[1]) * cos(angles[0])
cdef double dy2 = sin(theta2) * sin(phi2) - sin(angles[1]) * sin(angles[0])
cdef double dz2 = cos(theta2) - cos(angles[1])
cdef double dist1 = 2.0 * asin(sqrt(dx1**2 + dy1**2 + dz1**2) * 0.5)
cdef double dist2 = 2.0 * asin(sqrt(dx2**2 + dy2**2 + dz2**2) * 0.5)
# check both directions (+/-)
cdef bint dir1 = dist1 <= angles_tol
cdef bint dir2 = dist2 <= angles_tol
return dir1 or dir2

cdef inline double estimator_matheron(const double f_diff) nogil:
return f_diff * f_diff
Expand Down Expand Up @@ -110,13 +172,25 @@ ctypedef double (*_dist_func)(
const int
) nogil

ctypedef bint (*_angle_test_func)(
const double[:],
const double[:],
const double[:],
const double[:],
const double,
const int,
const int
) nogil


def unstructured(
const double[:] f,
const double[:] bin_edges,
const double[:] x,
const double[:] y=None,
const double[:] z=None,
const double[:] angles=None,
const double angles_tol=0.436332,
str estimator_type='m'
):
if x.shape[0] != f.shape[0]:
Expand All @@ -126,21 +200,35 @@ def unstructured(
raise ValueError('len(bin_edges) too small')

cdef _dist_func distance
cdef _angle_test_func angle_test

# 3d
if z is not None:
if z.shape[0] != f.shape[0]:
raise ValueError('len(z) = {0} != len(f) = {1} '.
format(z.shape[0], f.shape[0]))
distance = _distance_3d
angle_test = _angle_test_3d
# 2d
elif y is not None:
if y.shape[0] != f.shape[0]:
raise ValueError('len(y) = {0} != len(f) = {1} '.
format(y.shape[0], f.shape[0]))
distance = _distance_2d
angle_test = _angle_test_2d
# 1d
else:
distance = _distance_1d
angle_test = _angle_test_1d

if angles is not None:
if z is not None and angles.size < 2:
raise ValueError('3d requested but only one angle given')
if y is not None and angles.size < 1:
raise ValueError('2d with angle requested but no angle given')

if angles_tol <= 0:
raise ValueError('tolerance for angle search masks must be > 0')

cdef _estimator_func estimator_func = choose_estimator_func(estimator_type)
cdef _normalization_func normalization_func = (
Expand All @@ -160,11 +248,12 @@ def unstructured(
for k in range(j+1, k_max):
dist = distance(x, y, z, k, j)
if dist >= bin_edges[i] and dist < bin_edges[i+1]:
counts[i] += 1
variogram[i] += estimator_func(f[k] - f[j])
if angles is None or angle_test(x, y, z, angles, angles_tol, k, j):
counts[i] += 1
variogram[i] += estimator_func(f[k] - f[j])

normalization_func(variogram, counts)
return np.asarray(variogram)
return np.asarray(variogram), np.asarray(counts)


def structured(const double[:,:,:] f, str estimator_type='m'):
Expand Down
81 changes: 63 additions & 18 deletions gstools/variogram/variogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,12 @@ def vario_estimate_unstructured(
pos,
field,
bin_edges,
angles=None,
angles_tol=0.436332,
sampling_size=None,
sampling_seed=None,
estimator="matheron",
return_counts=False,
):
r"""
Estimates the variogram on a unstructured grid.
Expand All @@ -64,10 +67,6 @@ def vario_estimate_unstructured(
being the bins.
The Cressie estimator is more robust to outliers.
Notes
-----
Internally uses double precision and also returns doubles.
Parameters
----------
pos : :class:`list`
Expand All @@ -77,6 +76,18 @@ def vario_estimate_unstructured(
the spatially distributed data
bin_edges : :class:`numpy.ndarray`
the bins on which the variogram will be calculated
angles : :class:`numpy.ndarray`
the angles of the main axis to calculate the variogram for in radians
angle definitions from ISO standard 80000-2:2009
for 1d this parameter will have no effect at all
for 2d supply one angle which is azimuth φ (ccw from +x in xy plane)
for 3d supply two angles which are azimuth φ (ccw from +x in xy plane)
and inclination θ (cw from +z)
angles_tol : class:`float`
the tolerance around the variogram angle to count a point as being
within this direction from another point (the angular tolerance around
the directional vector given by angles)
Default: 25°≈0.436332
sampling_size : :class:`int` or :any:`None`, optional
for large input data, this method can take a long
time to compute the variogram, therefore this argument specifies
Expand All @@ -92,18 +103,43 @@ def vario_estimate_unstructured(
* "cressie": an estimator more robust to outliers
Default: "matheron"
return_counts: class:`bool`, optional
if set to true, this function will also return the number of data
points found at each lag distance as a third return value
Default: False
Returns
-------
:class:`tuple` of :class:`numpy.ndarray`
the estimated variogram and the bin centers
1. the bin centers
2. the estimated variogram values at bin centers
3. (optional) the number of points found at each bin center
(see argument return_counts)
Notes
-----
Internally uses double precision and also returns doubles.
"""
# TODO check_mesh
field = np.array(field, ndmin=1, dtype=np.double)
bin_edges = np.array(bin_edges, ndmin=1, dtype=np.double)
x, y, z, dim = pos2xyz(pos, calc_dim=True, dtype=np.double)

bin_centres = (bin_edges[:-1] + bin_edges[1:]) / 2.0

if angles is not None and dim > 1:
# there are at most (dim-1) angles
angles = np.array(angles, ndmin=1, dtype=np.double) # type convert
angles = angles.ravel()[: (dim - 1)] # cutoff
angles = np.pad(
angles, (0, dim - angles.size - 1), "constant", constant_values=0.0
) # fill with 0 if too less given
# correct intervalls for angles
angles[0] = angles[0] % (2 * np.pi)
angles[1:] = angles[1:] % np.pi
elif dim == 1:
angles = None

if sampling_size is not None and sampling_size < len(field):
sampled_idx = np.random.RandomState(sampling_seed).choice(
np.arange(len(field)), sampling_size, replace=False
Expand All @@ -117,13 +153,22 @@ def vario_estimate_unstructured(

cython_estimator = _set_estimator(estimator)

return (
bin_centres,
unstructured(
field, bin_edges, x, y, z, estimator_type=cython_estimator
),
estimates, counts = unstructured(
field,
bin_edges,
x,
y,
z,
angles,
angles_tol,
estimator_type=cython_estimator,
)

if return_counts:
return bin_centres, estimates, counts
else:
return bin_centres, estimates


def vario_estimate_structured(field, direction="x", estimator="matheron"):
r"""Estimates the variogram on a regular grid.
Expand All @@ -149,14 +194,6 @@ def vario_estimate_structured(field, direction="x", estimator="matheron"):
being the bins.
The Cressie estimator is more robust to outliers.
Warnings
--------
It is assumed that the field is defined on an equidistant Cartesian grid.
Notes
-----
Internally uses double precision and also returns doubles.
Parameters
----------
field : :class:`numpy.ndarray`
Expand All @@ -175,6 +212,14 @@ def vario_estimate_structured(field, direction="x", estimator="matheron"):
-------
:class:`numpy.ndarray`
the estimated variogram along the given direction.
Warnings
--------
It is assumed that the field is defined on an equidistant Cartesian grid.
Notes
-----
Internally uses double precision and also returns doubles.
"""
try:
mask = np.array(field.mask, dtype=np.int32)
Expand Down
Loading

0 comments on commit 95b1431

Please sign in to comment.