Skip to content
This repository has been archived by the owner on Dec 6, 2024. It is now read-only.

extension to enable/disable remote signals #2

Merged
merged 2 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 41 additions & 6 deletions src/neptune_experimental/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,51 @@
from typing import Any
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have CHANGELOG here as well :D. You can just copy-paste what we had earlier in regular package

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My bad, on it! :D

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


import neptune
from neptune.internal.backgroud_job_list import BackgroundJobList
from neptune.internal.hardware.hardware_metric_reporting_job import HardwareMetricReportingJob
from neptune.internal.streams.std_capture_background_job import (
StderrCaptureBackgroundJob,
StdoutCaptureBackgroundJob,
)
from neptune.internal.utils import verify_type
from neptune.internal.utils.ping_background_job import PingBackgroundJob
from neptune.internal.utils.traceback_job import TracebackJob
from neptune.internal.websockets.websocket_signals_background_job import WebsocketSignalsBackgroundJob


# That's just a boilerplate code to make sure that the extension is loaded
class CustomRun(neptune.Run):
def __init__(self, *args: Any, **kwargs: Any) -> None:
print("That's custom class")
enable_remote_signals = kwargs.pop("enable_remote_signals", None)

kwargs["capture_hardware_metrics"] = False
kwargs["capture_stdout"] = False
kwargs["capture_stderr"] = False
kwargs["capture_traceback"] = False
if enable_remote_signals is None:
self._enable_remote_signals = True # user did not pass this param in kwargs -> default value
else:

verify_type("enable_remote_signals", enable_remote_signals, bool)
self._enable_remote_signals = enable_remote_signals

super().__init__(*args, **kwargs)

def _prepare_background_jobs(self) -> BackgroundJobList:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're probably able to get rid of this duplicated logic with calling the parent _prepare_background_jobs and filtering out the WebsocketsSignalsBackgroundJob? What do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

background_jobs = [PingBackgroundJob()]

if self._enable_remote_signals:
websockets_factory = self._backend.websockets_factory(self._project_api_object.id, self._id)
if websockets_factory:
background_jobs.append(WebsocketSignalsBackgroundJob(websockets_factory))

if self._capture_stdout:
background_jobs.append(StdoutCaptureBackgroundJob(attribute_name=self._stdout_path))

if self._capture_stderr:
background_jobs.append(StderrCaptureBackgroundJob(attribute_name=self._stderr_path))

if self._capture_hardware_metrics:
background_jobs.append(HardwareMetricReportingJob(attribute_namespace=self._monitoring_namespace))

if self._capture_traceback:
background_jobs.append(
TracebackJob(path=f"{self._monitoring_namespace}/traceback", fail_on_exception=self._fail_on_exception)
)

return BackgroundJobList(background_jobs)
10 changes: 10 additions & 0 deletions tests/unit/test_custom_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from unittest.mock import patch

from neptune import Run

from neptune_experimental.run import CustomRun
Expand All @@ -21,3 +23,11 @@
def test_custom_run():
with Run(mode="debug") as run:
assert isinstance(run, CustomRun)


@patch("neptune.internal.backends.neptune_backend_mock.NeptuneBackendMock.websockets_factory")
@patch("neptune.internal.websockets.websocket_signals_background_job.WebsocketSignalsBackgroundJob")
def test_disabled_remote_signals(ws_factory, signals_job):
with Run(mode="debug", enable_remote_signals=False):
assert not ws_factory.called
assert not signals_job.called
Loading