Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(anywidget): Hoist static assets (_esm, _css) to share among front-end widget instances #628

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion anywidget/_file_contents.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from psygnal import Signal

__all__ = ["FileContents", "VirtualFileContents", "_VIRTUAL_FILES"]
__all__ = ["_VIRTUAL_FILES", "FileContents", "VirtualFileContents"]

_VIRTUAL_FILES: weakref.WeakValueDictionary[str, VirtualFileContents] = (
weakref.WeakValueDictionary()
Expand Down
58 changes: 58 additions & 0 deletions anywidget/_static_asset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from __future__ import annotations

import pathlib
import typing

from anywidget._file_contents import VirtualFileContents

from ._descriptor import open_comm
from ._util import try_file_contents

if typing.TYPE_CHECKING:
import pathlib

import comm


def send_asset_to_front_end(comm: comm.base_comm.BaseComm, contents: str) -> None:
"""Send the static asset to the front end."""
msg = {"method": "update", "state": {"data": contents}, "buffer_paths": []}
comm.send(data=msg, buffers=[])


class StaticAsset:
"""
Represents a static asset (e.g. a file) for the anywidget front end.

This class is used _internally_ to hoist static files (_esm, _css) into
the front end such that they can be shared across widget instances. This
implementation detail may change in the future, so this class is not
intended for direct use in user code.
"""

def __init__(self, data: str | pathlib.Path) -> None:
"""
Create a static asset for the anywidget front end.

Parameters
----------
data : str or pathlib.Path
The data to be shared with the front end.
"""
self._comm = open_comm()
self._file_contents = try_file_contents(data) or VirtualFileContents(str(data))
send_asset_to_front_end(self._comm, str(self))
self._file_contents.changed.connect(
lambda contents: send_asset_to_front_end(self._comm, contents)
)

def __str__(self) -> str:
"""Return the string representation of the asset."""
return str(self._file_contents)

def __del__(self) -> None:
"""Close the comm when the asset is deleted."""
self._comm.close()

def serialize(self) -> str:
return f"anywidget-static-asset:{self._comm.comm_id}"
69 changes: 45 additions & 24 deletions anywidget/widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@

from __future__ import annotations

import typing
from contextlib import contextmanager

import ipywidgets
import traitlets.traitlets as t

from ._file_contents import FileContents, VirtualFileContents
from ._static_asset import StaticAsset
from ._util import (
_ANYWIDGET_ID_KEY,
_CSS_KEY,
Expand All @@ -14,7 +17,6 @@
enable_custom_widget_manager_once,
in_colab,
repr_mimebundle,
try_file_contents,
)
from ._version import _ANYWIDGET_SEMVER_VERSION
from .experimental import _collect_anywidget_commands, _register_anywidget_commands
Expand All @@ -37,40 +39,28 @@ def __init__(self, *args: object, **kwargs: object) -> None:
if in_colab():
enable_custom_widget_manager_once()

anywidget_traits = {}
self._anywidget_internal_state = {}
for key in (_ESM_KEY, _CSS_KEY):
if hasattr(self, key) and not self.has_trait(key):
value = getattr(self, key)
anywidget_traits[key] = t.Unicode(str(value)).tag(sync=True)
if isinstance(value, (VirtualFileContents, FileContents)):
value.changed.connect(
lambda new_contents, key=key: setattr(self, key, new_contents),
)

self._anywidget_internal_state[key] = getattr(self, key)
# show default _esm if not defined
if not hasattr(self, _ESM_KEY):
anywidget_traits[_ESM_KEY] = t.Unicode(_DEFAULT_ESM).tag(sync=True)
self._anywidget_internal_state[_ESM_KEY] = _DEFAULT_ESM
self._anywidget_internal_state[_ANYWIDGET_ID_KEY] = _id_for(self)

# TODO(manzt): a better way to uniquely identify this subclasses? # noqa: TD003
# We use the fully-qualified name to get an id which we
# can use to update CSS if necessary.
anywidget_traits[_ANYWIDGET_ID_KEY] = t.Unicode(
f"{self.__class__.__module__}.{self.__class__.__name__}",
).tag(sync=True)
with _patch_get_state(self, self._anywidget_internal_state):
super().__init__(*args, **kwargs)

self.add_traits(**anywidget_traits)
super().__init__(*args, **kwargs)
_register_anywidget_commands(self)

def __init_subclass__(cls, **kwargs: dict) -> None:
"""Coerces _esm and _css to FileContents if they are files."""
super().__init_subclass__(**kwargs)
for key in (_ESM_KEY, _CSS_KEY) & cls.__dict__.keys():
# TODO(manzt): Upgrate to := when we drop Python 3.7
# https://github.com/manzt/anywidget/pull/167
file_contents = try_file_contents(getattr(cls, key))
if file_contents:
setattr(cls, key, file_contents)
# TODO: Upgrate to := when we drop Python 3.7 # noqa: TD002, TD003
value = getattr(cls, key)
if not isinstance(value, StaticAsset):
setattr(cls, key, StaticAsset(value))
_collect_anywidget_commands(cls)

def _repr_mimebundle_(self, **kwargs: dict) -> tuple[dict, dict] | None: # noqa: ARG002
Expand All @@ -80,3 +70,34 @@ def _repr_mimebundle_(self, **kwargs: dict) -> tuple[dict, dict] | None: # noqa
if self._view_name is None:
return None # type: ignore[unreachable]
return repr_mimebundle(model_id=self.model_id, repr_text=plaintext)


def _id_for(obj: object) -> str:
"""Return a unique identifier for an object."""
# TODO: a better way to uniquely identify this subclasses? # noqa: TD002, TD003
# We use the fully-qualified name to get an id which we
# can use to update CSS if necessary.
return f"{obj.__class__.__module__}.{obj.__class__.__name__}"


@contextmanager
def _patch_get_state(
widget: AnyWidget, extra_state: dict[str, str | StaticAsset]
) -> typing.Generator[None, None, None]:
"""Patch get_state to include anywidget-specific data."""
original_get_state = widget.get_state

def temp_get_state() -> dict:
return {
**original_get_state(),
**{
k: v.serialize() if isinstance(v, StaticAsset) else v
for k, v in extra_state.items()
},
}

widget.get_state = temp_get_state
try:
yield
finally:
widget.get_state = original_get_state
111 changes: 96 additions & 15 deletions packages/anywidget/src/widget.js
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,72 @@ function throw_anywidget_error(source) {
throw source;
}

/**
* @param {unknown} v
* @return {v is {}}
*/
function is_object(v) {
return typeof v === "object" && v !== null;
}

/**
* @param {unknown} v
* @return {v is import("@jupyter-widgets/base").DOMWidgetModel}
*/
function is_model(v) {
return is_object(v) && "on" in v && typeof v.on === "function";
}

/**
* @template {"_esm" | "_css"} T
* @param {import("@jupyter-widgets/base").DOMWidgetModel} model
* @param {T} asset_name
* @returns {{ get(name: T): string, on(event: `change:${T}`, callback: () => void): void, off(event: `change:${T}`): void }}
*/
function resolve_asset_model(model, asset_name) {
let value = model.get(asset_name);
if (is_model(value)) {
return {
/** @param {T} _name */
get(_name) {
return value.get("data");
},
/**
* @param {`change:${T}`} _event
* @param {() => void} callback
*/
on(_event, callback) {
value.on("change:data", callback);
},
/**
* @param {`change:${T}`} _event
*/
off(_event) {
return value.off("change:data");
},
};
}
return model;
}

/**
* @template {"_esm" | "_css"} T
* @param {import("@jupyter-widgets/base").DOMWidgetModel} base_model
* @param {T} asset_name
* @param {() => void} cb
*/
function create_asset_signal(base_model, asset_name, cb) {
let model = resolve_asset_model(base_model, asset_name);
/** @type {import("solid-js").Signal<string>} */
let [asset, set_asset] = solid.createSignal(model.get(asset_name));
model.on(`change:${asset_name}`, () => {
cb();
set_asset(model.get(asset_name));
});
solid.onCleanup(() => model.off(`change:${asset_name}`));
return asset;
}

/**
* @typedef InvokeOptions
* @prop {DataView[]} [buffers]
Expand Down Expand Up @@ -312,25 +378,18 @@ class Runtime {

/** @param {base.DOMWidgetModel} model */
constructor(model) {
let id = () => model.get("_anywidget_id");

this.#disposer = solid.createRoot((dispose) => {
let [css, set_css] = solid.createSignal(model.get("_css"));
model.on("change:_css", () => {
let id = model.get("_anywidget_id");
console.debug(`[anywidget] css hot updated: ${id}`);
set_css(model.get("_css"));
});
solid.createEffect(() => {
let id = model.get("_anywidget_id");
load_css(css(), id);
let css = create_asset_signal(model, "_css", () => {
console.debug(`[anywidget] css hot updated: ${id()}`);
});
solid.createEffect(() => load_css(css(), id()));

/** @type {import("solid-js").Signal<string>} */
let [esm, setEsm] = solid.createSignal(model.get("_esm"));
model.on("change:_esm", async () => {
let id = model.get("_anywidget_id");
console.debug(`[anywidget] esm hot updated: ${id}`);
setEsm(model.get("_esm"));
let esm = create_asset_signal(model, "_esm", () => {
console.debug(`[anywidget] esm hot updated: ${id()}`);
});

/** @type {void | (() => Awaitable<void>)} */
let cleanup;
this.#widget_result = solid.createResource(esm, async (update) => {
Expand Down Expand Up @@ -419,6 +478,22 @@ class Runtime {
}
}

let anywidget_static_asset = {
/** @param {{ model_id: string }} model */
serialize(model) {
return `anywidget-static-asset:${model.model_id}`;
},
/**
* @param {string} value
* @param {import("@jupyter-widgets/base").DOMWidgetModel["widget_manager"]} widget_manager
*/
async deserialize(value, widget_manager) {
let model_id = value.slice("anywidget-static-asset:".length);
let model = await widget_manager.get_model(model_id);
return model;
},
};

// @ts-expect-error - injected by bundler
let version = globalThis.VERSION;

Expand Down Expand Up @@ -453,6 +528,12 @@ export default function ({ DOMWidgetModel, DOMWidgetView }) {
RUNTIMES.set(this, runtime);
}

static serializers = {
...DOMWidgetModel.serializers,
_esm: anywidget_static_asset,
_css: anywidget_static_asset,
};

/**
* @param {Record<string, any>} state
*
Expand Down
Loading