Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PyVista mesh support to Field #59

Merged
merged 5 commits into from
Dec 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 57 additions & 22 deletions gstools/field/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def mesh(

Parameters
----------
mesh : meshio.Mesh or ogs5py.MSH
The given meshio or ogs5py mesh
mesh : meshio.Mesh or ogs5py.MSH or PyVista mesh
The given meshio, ogs5py, or PyVista mesh
points : :class:`str`, optional
The points to evaluate the field at.
Either the "centroids" of the mesh cells
Expand All @@ -88,21 +88,39 @@ def mesh(
you have to pass "xz". By default, all directions are used.
One can also pass a list of indices.
Default: "all"
name : :class:`str`, optional
Name to store the field in the given mesh as point_data or
cell_data. Default: "field"
name : :class:`str` or :class:`list` of :class:`str`, optional
Name(s) to store the field(s) in the given mesh as point_data or
cell_data. If to few names are given, digits will be appended.
Default: "field"
**kwargs
Keyword arguments forwareded to `Field.__call__`.

Notes
-----
This will store the field in the given mesh under the given name,
if a meshio mesh was given.
if a meshio or PyVista mesh was given.

See: https://github.com/nschloe/meshio
See: https://github.com/pyvista/pyvista

See: :any:`Field.__call__`
"""
has_pyvista = False
has_ogs5py = False

try:
import pyvista as pv

has_pyvista = True
except ImportError:
pass
try:
import ogs5py as ogs

has_ogs5py = True
except ImportError:
pass

if isinstance(direction, str) and direction == "all":
select = list(range(self.model.dim))
elif isinstance(direction, str):
Expand All @@ -115,12 +133,25 @@ def mesh(
self.model.dim, direction
)
)
if hasattr(mesh, "centroids_flat"):
# convert pyvista mesh
if has_pyvista and pv.is_pyvista_dataset(mesh):
if points == "centroids":
pnts = mesh.cell_centers().points.T[select]
else:
pnts = mesh.points.T[select]
out = self.unstructured(pos=pnts, **kwargs)
# Deal with the output
fields = [out] if isinstance(out, np.ndarray) else out
for f_name, field in zip(_names(name, len(fields)), fields):
mesh[f_name] = field
# convert ogs5py mesh
elif has_ogs5py and isinstance(mesh, ogs.MSH):
if points == "centroids":
pnts = mesh.centroids_flat.T[select]
else:
pnts = mesh.NODES.T[select]
out = self.unstructured(pos=pnts, **kwargs)
# convert meshio mesh
elif isinstance(mesh, meshio.Mesh):
if points == "centroids":
# define unique order of cells
Expand All @@ -138,23 +169,20 @@ def mesh(
# generate pos for __call__
pnts = pnts.T[select]
out = self.unstructured(pos=pnts, **kwargs)
if isinstance(out, np.ndarray):
field = out
else:
# if multiple values are returned, take the first one
field = out[0]
field_list = []
for i in range(len(offset)):
field_list.append(field[offset[i] : offset[i] + length[i]])
mesh.cell_data[name] = field_list
fields = [out] if isinstance(out, np.ndarray) else out
f_lists = []
for field in fields:
f_list = []
for of, le in zip(offset, length):
f_list.append(field[of : of + le])
f_lists.append(f_list)
for f_name, f_list in zip(_names(name, len(f_lists)), f_lists):
mesh.cell_data[f_name] = f_list
else:
out = self.unstructured(pos=mesh.points.T[select], **kwargs)
if isinstance(out, np.ndarray):
field = out
else:
# if multiple values are returned, take the first one
field = out[0]
mesh.point_data[name] = field
fields = [out] if isinstance(out, np.ndarray) else out
for f_name, field in zip(_names(name, len(fields)), fields):
mesh.point_data[f_name] = field
else:
raise ValueError("Field.mesh: Unknown mesh format!")
return out
Expand Down Expand Up @@ -384,6 +412,13 @@ def __repr__(self):
)


def _names(name, cnt):
name = [name] if isinstance(name, str) else list(name)[:cnt]
if len(name) < cnt:
name += [name[-1] + str(i + 1) for i in range(cnt - len(name))]
return name


def _get_select(direction):
select = []
if not (0 < len(direction) < 4):
Expand Down
24 changes: 24 additions & 0 deletions tests/test_srf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@
from gstools import transform as tf
import meshio

HAS_PYVISTA = False
try:
import pyvista as pv

HAS_PYVISTA = True
except ImportError:
pass


class TestSRF(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -231,6 +239,22 @@ def test_calls(self):
self.assertAlmostEqual(field[0, 0], srf.field[0, 0])
self.assertAlmostEqual(field[0, 0], field2[0, 0])

@unittest.skipIf(not HAS_PYVISTA, "PyVista is not installed")
def test_mesh_pyvista(self):
"""Test the `.mesh` call with various PyVista meshes."""
# Create model
srf = SRF(self.cov_model, mean=self.mean, mode_no=self.mode_no)
# Get the field the normal way for comparison
field = srf((self.x_tuple, self.y_tuple, self.z_tuple), seed=self.seed)
# Create mesh space with PyVista
pv_mesh = pv.PolyData(np.c_[self.x_tuple, self.y_tuple, self.z_tuple])
# Run the helper
_ = srf.mesh(pv_mesh, seed=self.seed, points="centroids")
self.assertTrue(np.allclose(field, pv_mesh["field"]))
# points="centroids"
_ = srf.mesh(pv_mesh, seed=self.seed, points="points")
self.assertTrue(np.allclose(field, pv_mesh["field"]))

def test_transform(self):
self.cov_model.dim = 2
srf = SRF(self.cov_model, mean=self.mean, mode_no=self.mode_no)
Expand Down