Skip to content

Commit

Permalink
Allowing savefig when vtk=false (#2143)
Browse files Browse the repository at this point in the history
* Fix having to plot twice to get a file with size bigger than 3kb

* changing var name

* Implementing savefig argument

* Fixing tests

* Apply suggestions from code review

* Adding more tests
  • Loading branch information
germa89 authored Jun 27, 2023
1 parent 6f22756 commit b51f3fc
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 51 deletions.
174 changes: 124 additions & 50 deletions src/ansys/mapdl/core/mapdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@
]

# test for png file
PNG_TEST = re.compile("WRITTEN TO FILE(.*).png")
PNG_IS_WRITTEN_TO_FILE = re.compile(
"WRITTEN TO FILE"
) # getting the file name is buggy.

VWRITE_REPLACEMENT = """
Cannot use *VWRITE directly as a command in MAPDL
Expand Down Expand Up @@ -1361,8 +1363,8 @@ def nplot(self, nnum="", vtk=None, **kwargs):
if isinstance(nnum, bool):
nnum = int(nnum)

self._enable_interactive_plotting()
return super().nplot(nnum, **kwargs)
with self._enable_interactive_plotting():
return super().nplot(nnum, **kwargs)

def eplot(self, show_node_numbering=False, vtk=None, **kwargs):
"""Plots the currently selected elements.
Expand Down Expand Up @@ -1491,8 +1493,8 @@ def eplot(self, show_node_numbering=False, vtk=None, **kwargs):
)

# otherwise, use MAPDL plotter
self._enable_interactive_plotting()
return self.run("EPLOT", **kwargs)
with self._enable_interactive_plotting():
return self.run("EPLOT", **kwargs)

def vplot(
self,
Expand Down Expand Up @@ -1592,10 +1594,10 @@ def vplot(
self.cmsel("S", cm_name, "AREA", mute=True)
return out
else:
self._enable_interactive_plotting()
return super().vplot(
nv1=nv1, nv2=nv2, ninc=ninc, degen=degen, scale=scale, **kwargs
)
with self._enable_interactive_plotting():
return super().vplot(
nv1=nv1, nv2=nv2, ninc=ninc, degen=degen, scale=scale, **kwargs
)

def aplot(
self,
Expand Down Expand Up @@ -1782,13 +1784,13 @@ def aplot(

return general_plotter(meshes, [], labels, **kwargs)

self._enable_interactive_plotting()
return super().aplot(
na1=na1, na2=na2, ninc=ninc, degen=degen, scale=scale, **kwargs
)
with self._enable_interactive_plotting():
return super().aplot(
na1=na1, na2=na2, ninc=ninc, degen=degen, scale=scale, **kwargs
)

@supress_logging
def _enable_interactive_plotting(self, pixel_res=1600):
def _enable_interactive_plotting(self, pixel_res: int = 1600):
"""Enables interactive plotting. Requires matplotlib
Parameters
Expand All @@ -1799,16 +1801,32 @@ def _enable_interactive_plotting(self, pixel_res=1600):
Increasing the resolution produces a "sharper" image but
takes longer to render.
"""
if not self._has_matplotlib:
raise ImportError(
"Install matplotlib to display plots from MAPDL ,"
"from Python. Otherwise, plot with vtk with:\n"
"``vtk=True``"
)
return self.WithInterativePlotting(self, pixel_res)

class WithInterativePlotting:
"""Allows to redirect plots to MAPDL plots."""

def __init__(self, parent: "_MapdlCore", pixel_res: int) -> None:
self._parent = weakref.ref(parent)
self._pixel_res = pixel_res

def __enter__(self) -> None:
self._parent()._log.debug("Entering in 'WithInterativePlotting' mode")

if not self._parent()._has_matplotlib: # pragma: no cover
raise ImportError(
"Install matplotlib to display plots from MAPDL ,"
"from Python. Otherwise, plot with vtk with:\n"
"``vtk=True``"
)

if not self._png_mode:
self.show("PNG", mute=True)
self.gfile(pixel_res, mute=True)
if not self._parent()._png_mode:
self._parent().show("PNG", mute=True)
self._parent().gfile(self._pixel_res, mute=True)

def __exit__(self, *args) -> None:
self._parent()._log.debug("Exiting in 'WithInterativePlotting' mode")
self._parent().show("close", mute=True)

@property
def _has_matplotlib(self):
Expand Down Expand Up @@ -1957,8 +1975,8 @@ def lplot(

return general_plotter(meshes, [], labels, **kwargs)
else:
self._enable_interactive_plotting()
return super().lplot(nl1=nl1, nl2=nl2, ninc=ninc, **kwargs)
with self._enable_interactive_plotting():
return super().lplot(nl1=nl1, nl2=nl2, ninc=ninc, **kwargs)

def kplot(
self,
Expand Down Expand Up @@ -2033,8 +2051,8 @@ def kplot(
return general_plotter([], points, labels, **kwargs)

# otherwise, use the legacy plotter
self._enable_interactive_plotting()
return super().kplot(np1=np1, np2=np2, ninc=ninc, lab=lab, **kwargs)
with self._enable_interactive_plotting():
return super().kplot(np1=np1, np2=np2, ninc=ninc, lab=lab, **kwargs)

@property
@requires_package("ansys.mapdl.reader", softerror=True)
Expand Down Expand Up @@ -2551,7 +2569,7 @@ def jobname(self) -> str:
def jobname(self, new_jobname: str):
"""Set the jobname"""
self.finish(mute=True)
self.filname(new_jobname, mute=True)
self.filname(new_jobname)
self._jobname = new_jobname

def modal_analysis(
Expand Down Expand Up @@ -3008,7 +3026,13 @@ def run(self, command, write_to_log=True, mute=None, **kwargs) -> str:
short_cmd = parse_to_short_cmd(command)

if short_cmd in PLOT_COMMANDS:
return self._display_plot(self._response)
self._log.debug("It is a plot command.")
plot_path = self._get_plot_name(text)
save_fig = kwargs.get("savefig", False)
if save_fig:
self._download_plot(plot_path, save_fig)
else:
return self._display_plot(plot_path)

return self._response

Expand Down Expand Up @@ -3258,9 +3282,6 @@ def load_table(self, name, array, var1="", var2="", var3="", csysid=""):
else:
self.slashdelete(filename)

def _display_plot(self, *args, **kwargs): # pragma: no cover
raise NotImplementedError("Implemented by child class")

def _run(self, *args, **kwargs): # pragma: no cover
raise NotImplementedError("Implemented by child class")

Expand Down Expand Up @@ -3488,34 +3509,87 @@ def _get_array(
else:
return array

def _display_plot(self, text):
"""Display the last generated plot (*.png) from MAPDL"""
import scooby
def _get_plot_name(self, text: str) -> str:
""" "Obtain the plot filename. It also downloads it if in remote session."""
self._log.debug(text)
png_found = PNG_IS_WRITTEN_TO_FILE.findall(text)

self._enable_interactive_plotting()
png_found = PNG_TEST.findall(text)
if png_found:
# flush graphics writer
self.show("CLOSE", mute=True)
self.show("PNG", mute=True)

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
# self.show("PNG", mute=True)

filename = self._screenshot_path()
self._log.debug(f"Screenshot at: {filename}")

if os.path.isfile(filename):
img = mpimg.imread(filename)
plt.imshow(img)
plt.axis("off")
if self._show_matplotlib_figures: # pragma: no cover
plt.show() # consider in-line plotting
if scooby.in_ipython():
from IPython.display import display

display(plt.gcf())
return filename
else: # pragma: no cover
self._log.error("Unable to find screenshot at %s", filename)
else:
self._log.error("Unable to find file in MAPDL command output.")

def _display_plot(self, filename: str) -> None:
"""Display the last generated plot (*.png) from MAPDL"""
import matplotlib.image as mpimg
import matplotlib.pyplot as plt

def in_ipython():
# from scooby.in_ipython
# to avoid dependency here.
try:
__IPYTHON__
return True
except NameError: # pragma: no cover
return False

self._log.debug("A screenshot file has been found.")
img = mpimg.imread(filename)
plt.imshow(img)
plt.axis("off")

if self._show_matplotlib_figures: # pragma: no cover
self._log.debug("Using Matplotlib to plot")
plt.show() # consider in-line plotting

if in_ipython():
self._log.debug("Using ipython")
from IPython.display import display

display(plt.gcf())

def _download_plot(self, filename: str, plot_name: str) -> None:
"""Copy the temporary download plot to the working directory."""
if isinstance(plot_name, str):
provided = True
path_ = pathlib.Path(plot_name)
plot_name = path_.name
plot_stem = path_.stem
plot_ext = path_.suffix
plot_path = str(path_.parent)
if not plot_path or plot_path == ".":
plot_path = os.getcwd()

elif isinstance(plot_name, bool):
provided = False
plot_name = "plot.png"
plot_stem = "plot"
plot_ext = ".png"
plot_path = os.getcwd()
else: # pragma: no cover
raise ValueError("Only booleans and str are allowed.")

id_ = 0
plot_path_ = os.path.join(plot_path, plot_name)
while os.path.exists(plot_path_) and not provided:
id_ += 1
plot_path_ = os.path.join(plot_path, f"{plot_stem}_{id_}{plot_ext}")
else:
copyfile(filename, plot_path_)

self._log.debug(
f"Copy plot file from temp directory to working directory as: {plot_path}"
)

def _screenshot_path(self):
"""Return last filename based on the current jobname"""
Expand Down
2 changes: 1 addition & 1 deletion src/ansys/mapdl/core/mapdl_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2377,7 +2377,7 @@ def _screenshot_path(self):
all_filenames = self.list_files()
filenames = []
for filename in all_filenames:
if ".png" == filename[-4:]:
if filename.endswith(".png"):
filenames.append(filename)
filenames.sort()
filename = os.path.basename(filenames[-1])
Expand Down
47 changes: 47 additions & 0 deletions tests/test_mapdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1982,3 +1982,50 @@ def num_():
def test_download_results_non_local(mapdl, cube_solve):
assert mapdl.result is not None
assert isinstance(mapdl.result, Result)


def test_download_file_with_vkt_false(mapdl, cube_solve, tmpdir):
# Testing basic behaviour
mapdl.eplot(vtk=False, savefig="myfile.png")
assert os.path.exists("myfile.png")
ti_m = os.path.getmtime("myfile.png")

# Testing overwriting
mapdl.eplot(vtk=False, savefig="myfile.png")
assert not os.path.exists("myfile_1.png")
assert os.path.getmtime("myfile.png") != ti_m # file has been modified.

os.remove("myfile.png")

# Testing no extension
mapdl.eplot(vtk=False, savefig="myfile")
assert os.path.exists("myfile")
os.remove("myfile")

# Testing update name when file exists.
mapdl.eplot(vtk=False, savefig=True)
assert os.path.exists("plot.png")

mapdl.eplot(vtk=False, savefig=True)
assert os.path.exists("plot_1.png")

os.remove("plot.png")
os.remove("plot_1.png")

# Testing full path for downloading
plot_ = os.path.join(tmpdir, "myplot.png")
mapdl.eplot(vtk=False, savefig=plot_)
assert os.path.exists(plot_)

plot_ = os.path.join(tmpdir, "myplot")
mapdl.eplot(vtk=False, savefig=plot_)
assert os.path.exists(plot_)


def test_plots_no_vtk(mapdl):
mapdl.kplot(vtk=False)
mapdl.lplot(vtk=False)
mapdl.aplot(vtk=False)
mapdl.vplot(vtk=False)
mapdl.nplot(vtk=False)
mapdl.eplot(vtk=False)

0 comments on commit b51f3fc

Please sign in to comment.