Skip to content

Commit

Permalink
Merge pull request #59 from banesullivan/patch-2
Browse files Browse the repository at this point in the history
Add PyVista mesh support to Field
  • Loading branch information
MuellerSeb authored Dec 10, 2020
2 parents 473ce01 + 8b8883d commit bcbb8db
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 22 deletions.
79 changes: 57 additions & 22 deletions gstools/field/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def mesh(
Parameters
----------
mesh : meshio.Mesh or ogs5py.MSH
The given meshio or ogs5py mesh
mesh : meshio.Mesh or ogs5py.MSH or PyVista mesh
The given meshio, ogs5py, or PyVista mesh
points : :class:`str`, optional
The points to evaluate the field at.
Either the "centroids" of the mesh cells
Expand All @@ -88,21 +88,39 @@ def mesh(
you have to pass "xz". By default, all directions are used.
One can also pass a list of indices.
Default: "all"
name : :class:`str`, optional
Name to store the field in the given mesh as point_data or
cell_data. Default: "field"
name : :class:`str` or :class:`list` of :class:`str`, optional
Name(s) to store the field(s) in the given mesh as point_data or
cell_data. If to few names are given, digits will be appended.
Default: "field"
**kwargs
Keyword arguments forwareded to `Field.__call__`.
Notes
-----
This will store the field in the given mesh under the given name,
if a meshio mesh was given.
if a meshio or PyVista mesh was given.
See: https://github.com/nschloe/meshio
See: https://github.com/pyvista/pyvista
See: :any:`Field.__call__`
"""
has_pyvista = False
has_ogs5py = False

try:
import pyvista as pv

has_pyvista = True
except ImportError:
pass
try:
import ogs5py as ogs

has_ogs5py = True
except ImportError:
pass

if isinstance(direction, str) and direction == "all":
select = list(range(self.model.dim))
elif isinstance(direction, str):
Expand All @@ -115,12 +133,25 @@ def mesh(
self.model.dim, direction
)
)
if hasattr(mesh, "centroids_flat"):
# convert pyvista mesh
if has_pyvista and pv.is_pyvista_dataset(mesh):
if points == "centroids":
pnts = mesh.cell_centers().points.T[select]
else:
pnts = mesh.points.T[select]
out = self.unstructured(pos=pnts, **kwargs)
# Deal with the output
fields = [out] if isinstance(out, np.ndarray) else out
for f_name, field in zip(_names(name, len(fields)), fields):
mesh[f_name] = field
# convert ogs5py mesh
elif has_ogs5py and isinstance(mesh, ogs.MSH):
if points == "centroids":
pnts = mesh.centroids_flat.T[select]
else:
pnts = mesh.NODES.T[select]
out = self.unstructured(pos=pnts, **kwargs)
# convert meshio mesh
elif isinstance(mesh, meshio.Mesh):
if points == "centroids":
# define unique order of cells
Expand All @@ -138,23 +169,20 @@ def mesh(
# generate pos for __call__
pnts = pnts.T[select]
out = self.unstructured(pos=pnts, **kwargs)
if isinstance(out, np.ndarray):
field = out
else:
# if multiple values are returned, take the first one
field = out[0]
field_list = []
for i in range(len(offset)):
field_list.append(field[offset[i] : offset[i] + length[i]])
mesh.cell_data[name] = field_list
fields = [out] if isinstance(out, np.ndarray) else out
f_lists = []
for field in fields:
f_list = []
for of, le in zip(offset, length):
f_list.append(field[of : of + le])
f_lists.append(f_list)
for f_name, f_list in zip(_names(name, len(f_lists)), f_lists):
mesh.cell_data[f_name] = f_list
else:
out = self.unstructured(pos=mesh.points.T[select], **kwargs)
if isinstance(out, np.ndarray):
field = out
else:
# if multiple values are returned, take the first one
field = out[0]
mesh.point_data[name] = field
fields = [out] if isinstance(out, np.ndarray) else out
for f_name, field in zip(_names(name, len(fields)), fields):
mesh.point_data[f_name] = field
else:
raise ValueError("Field.mesh: Unknown mesh format!")
return out
Expand Down Expand Up @@ -384,6 +412,13 @@ def __repr__(self):
)


def _names(name, cnt):
name = [name] if isinstance(name, str) else list(name)[:cnt]
if len(name) < cnt:
name += [name[-1] + str(i + 1) for i in range(cnt - len(name))]
return name


def _get_select(direction):
select = []
if not (0 < len(direction) < 4):
Expand Down
24 changes: 24 additions & 0 deletions tests/test_srf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@
from gstools import transform as tf
import meshio

HAS_PYVISTA = False
try:
import pyvista as pv

HAS_PYVISTA = True
except ImportError:
pass


class TestSRF(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -231,6 +239,22 @@ def test_calls(self):
self.assertAlmostEqual(field[0, 0], srf.field[0, 0])
self.assertAlmostEqual(field[0, 0], field2[0, 0])

@unittest.skipIf(not HAS_PYVISTA, "PyVista is not installed")
def test_mesh_pyvista(self):
"""Test the `.mesh` call with various PyVista meshes."""
# Create model
srf = SRF(self.cov_model, mean=self.mean, mode_no=self.mode_no)
# Get the field the normal way for comparison
field = srf((self.x_tuple, self.y_tuple, self.z_tuple), seed=self.seed)
# Create mesh space with PyVista
pv_mesh = pv.PolyData(np.c_[self.x_tuple, self.y_tuple, self.z_tuple])
# Run the helper
_ = srf.mesh(pv_mesh, seed=self.seed, points="centroids")
self.assertTrue(np.allclose(field, pv_mesh["field"]))
# points="centroids"
_ = srf.mesh(pv_mesh, seed=self.seed, points="points")
self.assertTrue(np.allclose(field, pv_mesh["field"]))

def test_transform(self):
self.cov_model.dim = 2
srf = SRF(self.cov_model, mean=self.mean, mode_no=self.mode_no)
Expand Down

0 comments on commit bcbb8db

Please sign in to comment.