diff --git a/gstools/field/base.py b/gstools/field/base.py index cf74c73f..487a2b64 100755 --- a/gstools/field/base.py +++ b/gstools/field/base.py @@ -74,8 +74,8 @@ def mesh( Parameters ---------- - mesh : meshio.Mesh or ogs5py.MSH - The given meshio or ogs5py mesh + mesh : meshio.Mesh or ogs5py.MSH or PyVista mesh + The given meshio, ogs5py, or PyVista mesh points : :class:`str`, optional The points to evaluate the field at. Either the "centroids" of the mesh cells @@ -88,21 +88,39 @@ def mesh( you have to pass "xz". By default, all directions are used. One can also pass a list of indices. Default: "all" - name : :class:`str`, optional - Name to store the field in the given mesh as point_data or - cell_data. Default: "field" + name : :class:`str` or :class:`list` of :class:`str`, optional + Name(s) to store the field(s) in the given mesh as point_data or + cell_data. If to few names are given, digits will be appended. + Default: "field" **kwargs Keyword arguments forwareded to `Field.__call__`. Notes ----- This will store the field in the given mesh under the given name, - if a meshio mesh was given. + if a meshio or PyVista mesh was given. See: https://github.com/nschloe/meshio + See: https://github.com/pyvista/pyvista See: :any:`Field.__call__` """ + has_pyvista = False + has_ogs5py = False + + try: + import pyvista as pv + + has_pyvista = True + except ImportError: + pass + try: + import ogs5py as ogs + + has_ogs5py = True + except ImportError: + pass + if isinstance(direction, str) and direction == "all": select = list(range(self.model.dim)) elif isinstance(direction, str): @@ -115,12 +133,25 @@ def mesh( self.model.dim, direction ) ) - if hasattr(mesh, "centroids_flat"): + # convert pyvista mesh + if has_pyvista and pv.is_pyvista_dataset(mesh): + if points == "centroids": + pnts = mesh.cell_centers().points.T[select] + else: + pnts = mesh.points.T[select] + out = self.unstructured(pos=pnts, **kwargs) + # Deal with the output + fields = [out] if isinstance(out, np.ndarray) else out + for f_name, field in zip(_names(name, len(fields)), fields): + mesh[f_name] = field + # convert ogs5py mesh + elif has_ogs5py and isinstance(mesh, ogs.MSH): if points == "centroids": pnts = mesh.centroids_flat.T[select] else: pnts = mesh.NODES.T[select] out = self.unstructured(pos=pnts, **kwargs) + # convert meshio mesh elif isinstance(mesh, meshio.Mesh): if points == "centroids": # define unique order of cells @@ -138,23 +169,20 @@ def mesh( # generate pos for __call__ pnts = pnts.T[select] out = self.unstructured(pos=pnts, **kwargs) - if isinstance(out, np.ndarray): - field = out - else: - # if multiple values are returned, take the first one - field = out[0] - field_list = [] - for i in range(len(offset)): - field_list.append(field[offset[i] : offset[i] + length[i]]) - mesh.cell_data[name] = field_list + fields = [out] if isinstance(out, np.ndarray) else out + f_lists = [] + for field in fields: + f_list = [] + for of, le in zip(offset, length): + f_list.append(field[of : of + le]) + f_lists.append(f_list) + for f_name, f_list in zip(_names(name, len(f_lists)), f_lists): + mesh.cell_data[f_name] = f_list else: out = self.unstructured(pos=mesh.points.T[select], **kwargs) - if isinstance(out, np.ndarray): - field = out - else: - # if multiple values are returned, take the first one - field = out[0] - mesh.point_data[name] = field + fields = [out] if isinstance(out, np.ndarray) else out + for f_name, field in zip(_names(name, len(fields)), fields): + mesh.point_data[f_name] = field else: raise ValueError("Field.mesh: Unknown mesh format!") return out @@ -384,6 +412,13 @@ def __repr__(self): ) +def _names(name, cnt): + name = [name] if isinstance(name, str) else list(name)[:cnt] + if len(name) < cnt: + name += [name[-1] + str(i + 1) for i in range(cnt - len(name))] + return name + + def _get_select(direction): select = [] if not (0 < len(direction) < 4): diff --git a/tests/test_srf.py b/tests/test_srf.py index 4f9a6f26..d9606c61 100644 --- a/tests/test_srf.py +++ b/tests/test_srf.py @@ -10,6 +10,14 @@ from gstools import transform as tf import meshio +HAS_PYVISTA = False +try: + import pyvista as pv + + HAS_PYVISTA = True +except ImportError: + pass + class TestSRF(unittest.TestCase): def setUp(self): @@ -231,6 +239,22 @@ def test_calls(self): self.assertAlmostEqual(field[0, 0], srf.field[0, 0]) self.assertAlmostEqual(field[0, 0], field2[0, 0]) + @unittest.skipIf(not HAS_PYVISTA, "PyVista is not installed") + def test_mesh_pyvista(self): + """Test the `.mesh` call with various PyVista meshes.""" + # Create model + srf = SRF(self.cov_model, mean=self.mean, mode_no=self.mode_no) + # Get the field the normal way for comparison + field = srf((self.x_tuple, self.y_tuple, self.z_tuple), seed=self.seed) + # Create mesh space with PyVista + pv_mesh = pv.PolyData(np.c_[self.x_tuple, self.y_tuple, self.z_tuple]) + # Run the helper + _ = srf.mesh(pv_mesh, seed=self.seed, points="centroids") + self.assertTrue(np.allclose(field, pv_mesh["field"])) + # points="centroids" + _ = srf.mesh(pv_mesh, seed=self.seed, points="points") + self.assertTrue(np.allclose(field, pv_mesh["field"])) + def test_transform(self): self.cov_model.dim = 2 srf = SRF(self.cov_model, mean=self.mean, mode_no=self.mode_no)