Skip to content

Commit

Permalink
feat: Add image_node and vfolder_node fields to ComputeSession schema
Browse files Browse the repository at this point in the history
  • Loading branch information
fregataa committed Nov 7, 2024
1 parent 1ca70ca commit 737f7f5
Show file tree
Hide file tree
Showing 9 changed files with 154 additions and 44 deletions.
1 change: 1 addition & 0 deletions changes/2987.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add image_node and vfolder_node fields to ComputeSession schema
23 changes: 21 additions & 2 deletions src/ai/backend/manager/api/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -156,16 +156,21 @@ type Queries {
id: GlobalIDField!

"""Added in 24.09.0."""
project_id: UUID!
project_id: UUID @deprecated(reason: "Deprecated since 24.12.0.")

"""Added in 24.09.0. Default is read_attribute."""
permission: SessionPermissionValueField = "read_attribute"
): ComputeSessionNode

"""Added in 24.09.0."""
compute_session_nodes(
"""
Added in 24.12.0. Default value `system` queries across the entire system.
"""
scope_id: ScopeField

"""Added in 24.09.0."""
project_id: UUID!
project_id: UUID @deprecated(reason: "Deprecated since 24.12.0. Use `scope_id` instead.")

"""Added in 24.09.0. Default is read_attribute."""
permission: SessionPermissionValueField = "read_attribute"
Expand Down Expand Up @@ -591,6 +596,14 @@ type KernelNode implements Node {
cluster_hostname: String
session_id: UUID
image: ImageNode

"""Added in 24.12.0."""
image_reference: String

"""
Added in 24.12.0. The architecture that the image of this kernel requires
"""
architecture: String
status: String
status_changed: DateTime
status_info: String
Expand Down Expand Up @@ -1201,6 +1214,12 @@ type ComputeSessionNode implements Node {
vfolder_mounts: [String]
occupied_slots: JSONString
requested_slots: JSONString

"""Added in 24.12.0."""
image_references: [String]

"""Added in 24.12.0."""
vfolder_nodes(filter: String, order: String, offset: Int, before: String, after: String, first: Int, last: Int): VirtualFolderConnection
num_queries: BigInt
inference_metrics: JSONString
kernel_nodes(filter: String, order: String, offset: Int, before: String, after: String, first: Int, last: Int): KernelConnection
Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/manager/models/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,7 @@ async def from_row(
return cls(
endpoint_id=row.id,
# image="", # deprecated, row.image_object.name,
image_object=ImageNode.from_row(row.image_row),
image_object=ImageNode.from_row(ctx, row.image_row),
domain=row.domain,
project=row.project,
resource_group=row.resource_group,
Expand Down
40 changes: 33 additions & 7 deletions src/ai/backend/manager/models/gql.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@
LegacyComputeSessionList,
)
from .keypair import CreateKeyPair, DeleteKeyPair, KeyPair, KeyPairList, ModifyKeyPair
from .rbac import ScopeType, SystemScope
from .rbac import ProjectScope, ScopeType, SystemScope
from .rbac.permission_defs import AgentPermission, ComputeSessionPermission, DomainPermission
from .rbac.permission_defs import VFolderPermission as VFolderRBACPermission
from .resource_policy import (
Expand Down Expand Up @@ -726,7 +726,11 @@ class Queries(graphene.ObjectType):
ComputeSessionNode,
description="Added in 24.09.0.",
id=GlobalIDField(required=True),
project_id=graphene.UUID(required=True, description="Added in 24.09.0."),
project_id=graphene.UUID(
required=False,
description="Added in 24.09.0.",
deprecation_reason="Deprecated since 24.12.0.",
),
permission=SessionPermissionValueField(
default_value=ComputeSessionPermission.READ_ATTRIBUTE,
description=f"Added in 24.09.0. Default is {ComputeSessionPermission.READ_ATTRIBUTE.value}.",
Expand All @@ -736,7 +740,15 @@ class Queries(graphene.ObjectType):
compute_session_nodes = PaginatedConnectionField(
ComputeSessionConnection,
description="Added in 24.09.0.",
project_id=graphene.UUID(required=True, description="Added in 24.09.0."),
scope_id=ScopeField(
required=False,
description="Added in 24.12.0. Default value `system` queries across the entire system.",
),
project_id=graphene.UUID(
required=False,
description="Added in 24.09.0.",
deprecation_reason="Deprecated since 24.12.0. Use `scope_id` instead.",
),
permission=SessionPermissionValueField(
default_value=ComputeSessionPermission.READ_ATTRIBUTE,
description=f"Added in 24.09.0. Default is {ComputeSessionPermission.READ_ATTRIBUTE.value}.",
Expand Down Expand Up @@ -2043,17 +2055,23 @@ async def resolve_compute_session_node(
info: graphene.ResolveInfo,
*,
id: ResolvedGlobalID,
project_id: uuid.UUID,
project_id: Optional[uuid.UUID] = None,
permission: ComputeSessionPermission = ComputeSessionPermission.READ_ATTRIBUTE,
) -> ComputeSessionNode | None:
return await ComputeSessionNode.get_accessible_node(info, id, project_id, permission)
scope_id: ScopeType
if project_id is None:
scope_id = SystemScope()
else:
scope_id = ProjectScope(project_id=project_id)
return await ComputeSessionNode.get_accessible_node(info, id, scope_id, permission)

@staticmethod
async def resolve_compute_session_nodes(
root: Any,
info: graphene.ResolveInfo,
*,
project_id: uuid.UUID,
scope_id: Optional[ScopeType] = None,
project_id: Optional[uuid.UUID] = None,
permission: ComputeSessionPermission = ComputeSessionPermission.READ_ATTRIBUTE,
filter: str | None = None,
order: str | None = None,
Expand All @@ -2063,9 +2081,17 @@ async def resolve_compute_session_nodes(
before: str | None = None,
last: int | None = None,
) -> ConnectionResolverResult[ComputeSessionNode]:
_scope_id: ScopeType
if scope_id is not None:
_scope_id = scope_id
else:
if project_id is not None:
_scope_id = ProjectScope(project_id=project_id)
else:
_scope_id = SystemScope()
return await ComputeSessionNode.get_accessible_connection(
info,
project_id,
_scope_id,
permission,
filter,
order,
Expand Down
29 changes: 22 additions & 7 deletions src/ai/backend/manager/models/gql_models/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
AsyncIterator,
List,
Optional,
Self,
overload,
)
from uuid import UUID
Expand All @@ -27,12 +28,18 @@
ImageAlias,
)
from ai.backend.logging import BraceStyleAdapter
from ai.backend.manager.models.base import batch_multiresult_in_scalar_stream
from ai.backend.manager.models.container_registry import ContainerRegistryRow, ContainerRegistryType

from ...api.exceptions import ImageNotFound, ObjectNotFound
from ...defs import DEFAULT_IMAGE_ARCH
from ..base import batch_multiresult_in_scalar_stream, set_if_set
from ..gql_relay import AsyncNode, Connection
from ..image import (
ImageAliasRow,
ImageIdentifier,
Expand Down Expand Up @@ -374,16 +381,18 @@ async def batch_load_by_image_identifier(
name_and_arch_tuples = [(img.canonical, img.architecture) for img in image_ids]
return await cls.batch_load_by_name_and_arch(graph_ctx, name_and_arch_tuples)

@overload

@overload
@classmethod
def from_row(cls, row: ImageRow) -> ImageNode: ...
def from_row(cls, graph_ctx: GraphQueryContext, row: ImageRow) -> Self: ...

@overload
@classmethod
def from_row(cls, row: None) -> None: ...
def from_row(cls, graph_ctx: GraphQueryContext, row: None) -> None: ...

@classmethod
def from_row(cls, row: ImageRow | None) -> ImageNode | None:
def from_row(cls, graph_ctx: GraphQueryContext, row: ImageRow | None) -> Self | None:
if row is None:
return None
image_ref = row.image_ref
Expand Down Expand Up @@ -455,7 +464,13 @@ async def get_node(cls, info: graphene.ResolveInfo, id: str) -> ImageNode:
image_row = await db_session.scalar(query)
if image_row is None:
raise ValueError(f"Image not found (id: {image_id})")
return cls.from_row(image_row)
return cls.from_row(graph_ctx, image_row)


# Connection type whose node type is ImageNode (see Meta.node below).
class ImageConnection(Connection):
    class Meta:
        # The node class wrapped by this connection.
        node = ImageNode
        # Description string attached to this connection type.
        description = "Added in 24.12.0."


class ForgetImageById(graphene.Mutation):
Expand Down Expand Up @@ -507,7 +522,7 @@ async def mutate(
):
return ForgetImageById(ok=False, msg="Forbidden")
await session.delete(image_row)
return ForgetImageById(ok=True, msg="", image=ImageNode.from_row(image_row))
return ForgetImageById(ok=True, msg="", image=ImageNode.from_row(ctx, image_row))


class ForgetImage(graphene.Mutation):
Expand Down Expand Up @@ -554,7 +569,7 @@ async def mutate(
):
return ForgetImage(ok=False, msg="Forbidden")
await session.delete(image_row)
return ForgetImage(ok=True, msg="", image=ImageNode.from_row(image_row))
return ForgetImage(ok=True, msg="", image=ImageNode.from_row(ctx, image_row))


class UntagImageFromRegistry(graphene.Mutation):
Expand Down Expand Up @@ -620,7 +635,7 @@ async def mutate(
scanner = HarborRegistry_v2(ctx.db, image_row.image_ref.registry, registry_info)
await scanner.untag(image_row.image_ref)

return UntagImageFromRegistry(ok=True, msg="", image=ImageNode.from_row(image_row))
return UntagImageFromRegistry(ok=True, msg="", image=ImageNode.from_row(ctx, image_row))


class PreloadImage(graphene.Mutation):
Expand Down
30 changes: 22 additions & 8 deletions src/ai/backend/manager/models/gql_models/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from typing import (
TYPE_CHECKING,
Any,
Optional,
Self,
cast,
)

import graphene
Expand All @@ -14,10 +16,7 @@

from ai.backend.common import msgpack, redis_helper
from ai.backend.common.types import AgentId, KernelId, SessionId
from ai.backend.manager.models.base import (
batch_multiresult_in_scalar_stream,
batch_multiresult_in_session,
)
from ai.backend.manager.models.base import batch_multiresult_in_scalar_stream

from ..gql_relay import AsyncNode, Connection
from ..kernel import KernelRow, KernelStatus
Expand Down Expand Up @@ -48,6 +47,10 @@ class Meta:

# image
image = graphene.Field(ImageNode)
image_reference = graphene.String(description="Added in 24.12.0.")
architecture = graphene.String(
description="Added in 24.12.0. The architecture that the image of this kernel requires"
)

# status
status = graphene.String()
Expand Down Expand Up @@ -75,11 +78,9 @@ async def batch_load_by_session_id(
graph_ctx: GraphQueryContext,
session_ids: Sequence[SessionId],
) -> Sequence[Sequence[Self]]:
from ..kernel import kernels

async with graph_ctx.db.begin_readonly_session() as db_sess:
query = sa.select(kernels).where(kernels.c.session_id.in_(session_ids))
return await batch_multiresult_in_session(
query = sa.select(KernelRow).where(KernelRow.session_id.in_(session_ids))
return await batch_multiresult_in_scalar_stream(
graph_ctx,
db_sess,
query,
Expand Down Expand Up @@ -122,6 +123,8 @@ def from_row(cls, ctx: GraphQueryContext, row: KernelRow) -> Self:
local_rank=row.local_rank,
cluster_role=row.cluster_role,
session_id=row.session_id,
architecture=row.architecture,
image_reference=row.image,
status=row.status,
status_changed=row.status_changed,
status_info=row.status_info,
Expand All @@ -138,6 +141,17 @@ def from_row(cls, ctx: GraphQueryContext, row: KernelRow) -> Self:
preopen_ports=row.preopen_ports,
)

async def resolve_image(self, info: graphene.ResolveInfo) -> Optional[ImageNode]:
    """Resolve the ``image`` field of this kernel node.

    Looks the image up through the context's dataloader for
    ``ImageNode.batch_load_by_name_and_arch``, keyed by this node's
    ``(image_reference, architecture)`` pair.

    Returns:
        The first image returned by the loader, or ``None`` when the
        loader yields no images for the key.
    """
    ctx: GraphQueryContext = info.context
    image_loader = ctx.dataloader_manager.get_loader_by_func(
        ctx, ImageNode.batch_load_by_name_and_arch
    )
    lookup_key = (self.image_reference, self.architecture)
    candidates = cast(list[ImageNode], await image_loader.load(lookup_key))
    # An empty result means no image matched the key; report that as None.
    return next(iter(candidates), None)

async def resolve_live_stat(self, info: graphene.ResolveInfo) -> dict[str, Any] | None:
graph_ctx: GraphQueryContext = info.context
loader = graph_ctx.dataloader_manager.get_loader_by_func(
Expand Down
Loading

0 comments on commit 737f7f5

Please sign in to comment.