Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(BA-604): Add image_node and vfolder_node fields to ComputeSession schema #2987

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/2987.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `image_references` and `vfolder_nodes` fields to `ComputeSessionNode`, and `image_reference` and `architecture` fields to `KernelNode`
18 changes: 17 additions & 1 deletion docs/manager/graphql-reference/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,10 @@

"""Added in 24.09.0."""
compute_session_nodes(
"""Added in 24.12.0."""
"""
Added in 25.2.0. Default value `system` queries across the entire system.
"""
scope_id: ScopeField

Check warning on line 176 in docs/manager/graphql-reference/schema.graphql

View workflow job for this annotation

GitHub Actions / GraphQL Inspector

Description for argument 'scope_id' on field 'Queries.compute_session_nodes' changed from 'Added in 24.12.0.' to 'Added in 25.2.0. Default value `system` queries across the entire system.'

Description for argument 'scope_id' on field 'Queries.compute_session_nodes' changed from 'Added in 24.12.0.' to 'Added in 25.2.0. Default value `system` queries across the entire system.'

"""Added in 24.09.0."""
project_id: UUID @deprecated(reason: "Deprecated since 24.12.0. use `scope_id` instead.")
Expand Down Expand Up @@ -612,6 +614,14 @@
cluster_hostname: String
session_id: UUID
image: ImageNode

"""Added in 25.2.0."""
image_reference: String

Check notice on line 619 in docs/manager/graphql-reference/schema.graphql

View workflow job for this annotation

GitHub Actions / GraphQL Inspector

Field 'image_reference' was added to object type 'KernelNode'

Field 'image_reference' was added to object type 'KernelNode'

"""
Added in 24.12.0. The architecture that the image of this kernel requires
"""
architecture: String

Check notice on line 624 in docs/manager/graphql-reference/schema.graphql

View workflow job for this annotation

GitHub Actions / GraphQL Inspector

Field 'architecture' was added to object type 'KernelNode'

Field 'architecture' was added to object type 'KernelNode'
status: String
status_changed: DateTime
status_info: String
Expand Down Expand Up @@ -1227,6 +1237,12 @@
vfolder_mounts: [String]
occupied_slots: JSONString
requested_slots: JSONString

"""Added in 25.2.0."""
image_references: [String]

Check notice on line 1242 in docs/manager/graphql-reference/schema.graphql

View workflow job for this annotation

GitHub Actions / GraphQL Inspector

Field 'image_references' was added to object type 'ComputeSessionNode'

Field 'image_references' was added to object type 'ComputeSessionNode'

"""Added in 25.2.0."""
vfolder_nodes(filter: String, order: String, offset: Int, before: String, after: String, first: Int, last: Int): VirtualFolderConnection

Check notice on line 1245 in docs/manager/graphql-reference/schema.graphql

View workflow job for this annotation

GitHub Actions / GraphQL Inspector

Field 'vfolder_nodes' was added to object type 'ComputeSessionNode'

Field 'vfolder_nodes' was added to object type 'ComputeSessionNode'
num_queries: BigInt
inference_metrics: JSONString
kernel_nodes(filter: String, order: String, offset: Int, before: String, after: String, first: Int, last: Int): KernelConnection
ironAiken2 marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/manager/models/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1012,7 +1012,7 @@ async def from_row(
return cls(
endpoint_id=row.id,
# image="", # deprecated, row.image_object.name,
image_object=ImageNode.from_row(row.image_row),
image_object=ImageNode.from_row(ctx, row.image_row),
domain=row.domain,
project=row.project,
resource_group=row.resource_group,
Expand Down
4 changes: 3 additions & 1 deletion src/ai/backend/manager/models/gql.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,7 +769,9 @@ class Queries(graphene.ObjectType):
compute_session_nodes = PaginatedConnectionField(
ComputeSessionConnection,
description="Added in 24.09.0.",
scope_id=ScopeField(description="Added in 24.12.0."),
scope_id=ScopeField(
description="Added in 25.2.0. Default value `system` queries across the entire system."
),
project_id=graphene.UUID(
required=False,
description="Added in 24.09.0.",
Expand Down
23 changes: 15 additions & 8 deletions src/ai/backend/manager/models/gql_models/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
AsyncIterator,
List,
Optional,
Self,
overload,
)
from uuid import UUID
Expand All @@ -32,7 +33,7 @@
from ...api.exceptions import ImageNotFound, ObjectNotFound
from ...defs import DEFAULT_IMAGE_ARCH
from ..base import batch_multiresult_in_scalar_stream, set_if_set
from ..gql_relay import AsyncNode
from ..gql_relay import AsyncNode, Connection
from ..image import (
ImageAliasRow,
ImageIdentifier,
Expand Down Expand Up @@ -376,14 +377,14 @@ async def batch_load_by_image_identifier(

@overload
@classmethod
def from_row(cls, row: ImageRow) -> ImageNode: ...
def from_row(cls, graph_ctx: GraphQueryContext, row: ImageRow) -> Self: ...

@overload
@classmethod
def from_row(cls, row: None) -> None: ...
def from_row(cls, graph_ctx: GraphQueryContext, row: None) -> None: ...

@classmethod
def from_row(cls, row: ImageRow | None) -> ImageNode | None:
def from_row(cls, graph_ctx: GraphQueryContext, row: ImageRow | None) -> Self | None:
if row is None:
return None
image_ref = row.image_ref
Expand Down Expand Up @@ -455,7 +456,13 @@ async def get_node(cls, info: graphene.ResolveInfo, id: str) -> ImageNode:
image_row = await db_session.scalar(query)
if image_row is None:
raise ValueError(f"Image not found (id: {image_id})")
return cls.from_row(image_row)
return cls.from_row(graph_ctx, image_row)


# Connection type whose node type is ImageNode.
# The Meta description marks the type as added in 25.2.0.
class ImageConnection(Connection):
class Meta:
node = ImageNode
description = "Added in 25.2.0."


class ForgetImageById(graphene.Mutation):
Expand Down Expand Up @@ -507,7 +514,7 @@ async def mutate(
):
return ForgetImageById(ok=False, msg="Forbidden")
await session.delete(image_row)
return ForgetImageById(ok=True, msg="", image=ImageNode.from_row(image_row))
return ForgetImageById(ok=True, msg="", image=ImageNode.from_row(ctx, image_row))


class ForgetImage(graphene.Mutation):
Expand Down Expand Up @@ -554,7 +561,7 @@ async def mutate(
):
return ForgetImage(ok=False, msg="Forbidden")
await session.delete(image_row)
return ForgetImage(ok=True, msg="", image=ImageNode.from_row(image_row))
return ForgetImage(ok=True, msg="", image=ImageNode.from_row(ctx, image_row))


class UntagImageFromRegistry(graphene.Mutation):
Expand Down Expand Up @@ -620,7 +627,7 @@ async def mutate(
scanner = HarborRegistry_v2(ctx.db, image_row.image_ref.registry, registry_info)
await scanner.untag(image_row.image_ref)

return UntagImageFromRegistry(ok=True, msg="", image=ImageNode.from_row(image_row))
return UntagImageFromRegistry(ok=True, msg="", image=ImageNode.from_row(ctx, image_row))


class PreloadImage(graphene.Mutation):
Expand Down
30 changes: 22 additions & 8 deletions src/ai/backend/manager/models/gql_models/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from typing import (
TYPE_CHECKING,
Any,
Optional,
Self,
cast,
)

import graphene
Expand All @@ -14,10 +16,7 @@

from ai.backend.common import msgpack, redis_helper
from ai.backend.common.types import AgentId, KernelId, SessionId
from ai.backend.manager.models.base import (
batch_multiresult_in_scalar_stream,
batch_multiresult_in_session,
)
from ai.backend.manager.models.base import batch_multiresult_in_scalar_stream

from ..gql_relay import AsyncNode, Connection
from ..kernel import KernelRow, KernelStatus
Expand Down Expand Up @@ -48,6 +47,10 @@ class Meta:

# image
image = graphene.Field(ImageNode)
image_reference = graphene.String(description="Added in 25.2.0.")
architecture = graphene.String(
description="Added in 24.12.0. The architecture that the image of this kernel requires"
)

# status
status = graphene.String()
Expand Down Expand Up @@ -75,11 +78,9 @@ async def batch_load_by_session_id(
graph_ctx: GraphQueryContext,
session_ids: Sequence[SessionId],
) -> Sequence[Sequence[Self]]:
from ..kernel import kernels

async with graph_ctx.db.begin_readonly_session() as db_sess:
query = sa.select(kernels).where(kernels.c.session_id.in_(session_ids))
return await batch_multiresult_in_session(
query = sa.select(KernelRow).where(KernelRow.session_id.in_(session_ids))
return await batch_multiresult_in_scalar_stream(
graph_ctx,
db_sess,
query,
Expand Down Expand Up @@ -122,6 +123,8 @@ def from_row(cls, ctx: GraphQueryContext, row: KernelRow) -> Self:
local_rank=row.local_rank,
cluster_role=row.cluster_role,
session_id=row.session_id,
architecture=row.architecture,
image_reference=row.image,
status=row.status,
status_changed=row.status_changed,
status_info=row.status_info,
Expand All @@ -138,6 +141,17 @@ def from_row(cls, ctx: GraphQueryContext, row: KernelRow) -> Self:
preopen_ports=row.preopen_ports,
)

async def resolve_image(self, info: graphene.ResolveInfo) -> Optional[ImageNode]:
    """Resolve the ``image`` field of this kernel node.

    Looks the image up through the dataloader registered for
    ``ImageNode.batch_load_by_name_and_arch``, keyed by this kernel's
    ``(image_reference, architecture)`` pair.

    Returns:
        The first image node loaded for that key, or ``None`` when the
        loader yields no entry for it.
    """
    ctx: GraphQueryContext = info.context
    image_loader = ctx.dataloader_manager.get_loader_by_func(
        ctx, ImageNode.batch_load_by_name_and_arch
    )
    lookup_key = (self.image_reference, self.architecture)
    matched = cast(list[ImageNode], await image_loader.load(lookup_key))
    # Take the first match; an empty result resolves the field to null.
    return next(iter(matched), None)

async def resolve_live_stat(self, info: graphene.ResolveInfo) -> dict[str, Any] | None:
graph_ctx: GraphQueryContext = info.context
loader = graph_ctx.dataloader_manager.get_loader_by_func(
Expand Down
35 changes: 27 additions & 8 deletions src/ai/backend/manager/models/gql_models/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import graphene
import graphql
import more_itertools
import sqlalchemy as sa
import trafaret as t
from dateutil.parser import parse as dtparse
Expand All @@ -21,7 +22,7 @@
from sqlalchemy.orm import selectinload

from ai.backend.common import validators as tx
from ai.backend.common.types import ClusterMode, SessionId, SessionResult
from ai.backend.common.types import ClusterMode, SessionId, SessionResult, VFolderMount
from ai.backend.manager.idle import ReportInfo

from ..base import (
Expand Down Expand Up @@ -58,6 +59,7 @@
from ..user import UserRole
from ..utils import execute_with_txn_retry
from .kernel import KernelConnection, KernelNode
from .vfolder import VirtualFolderConnection, VirtualFolderNode

if TYPE_CHECKING:
from ..gql import GraphQueryContext
Expand Down Expand Up @@ -194,6 +196,14 @@ class Meta:
vfolder_mounts = graphene.List(lambda: graphene.String)
occupied_slots = graphene.JSONString()
requested_slots = graphene.JSONString()
image_references = graphene.List(
lambda: graphene.String,
description="Added in 25.2.0.",
)
vfolder_nodes = PaginatedConnectionField(
VirtualFolderConnection,
description="Added in 25.2.0.",
)

# statistics
num_queries = BigInt()
Expand Down Expand Up @@ -262,6 +272,7 @@ def from_row(
vfolder_mounts=[vf.vfid.folder_id for vf in row.vfolder_mounts],
occupied_slots=row.occupying_slots.to_json(),
requested_slots=row.requested_slots.to_json(),
image_references=row.images,
# statistics
num_queries=row.num_queries,
)
Expand All @@ -275,19 +286,28 @@ async def resolve_idle_checks(self, info: graphene.ResolveInfo) -> dict[str, Any
)
return await loader.load(self.row_id)

async def resolve_vfolder_nodes(
    self,
    info: graphene.ResolveInfo,
) -> ConnectionResolverResult[VirtualFolderNode]:
    """Resolve the virtual folders mounted to this session as a connection.

    ``vfolder_mounts`` on this node already holds plain folder IDs:
    ``from_row`` stores ``vf.vfid.folder_id`` for every mount of the row.
    The IDs are therefore passed to the dataloader as they are; treating
    them as ``VFolderMount`` objects and reading ``.vfid`` from them would
    raise ``AttributeError``.

    Args:
        info: GraphQL resolve info whose ``context`` is the
            ``GraphQueryContext`` of the running query.

    Returns:
        A ``ConnectionResolverResult`` carrying every loaded folder node.
        The three positional slots after the node list are passed as
        ``None`` and ``total_count`` is the number of loaded nodes.
    """
    ctx: GraphQueryContext = info.context
    # A node constructed without mounts may carry None; treat it as "no folders".
    folder_ids = list(getattr(self, "vfolder_mounts", None) or [])
    loader = ctx.dataloader_manager.get_loader_by_func(ctx, VirtualFolderNode.batch_load_by_id)
    # The loader returns one list of nodes per requested folder ID.
    result = cast(list[list[VirtualFolderNode]], await loader.load_many(folder_ids))

    vf_nodes = cast(list[VirtualFolderNode], list(more_itertools.flatten(result)))
    return ConnectionResolverResult(vf_nodes, None, None, None, total_count=len(vf_nodes))

async def resolve_kernel_nodes(
self,
info: graphene.ResolveInfo,
) -> ConnectionResolverResult[KernelNode]:
ctx: GraphQueryContext = info.context
loader = ctx.dataloader_manager.get_loader(ctx, "KernelNode.by_session_id")
kernels = await loader.load(self.row_id)
kernel_nodes = await loader.load(self.row_id)
return ConnectionResolverResult(
kernels,
None,
None,
None,
total_count=len(kernels),
kernel_nodes, None, None, None, total_count=len(kernel_nodes)
)

async def resolve_dependees(
Expand Down Expand Up @@ -490,7 +510,6 @@ async def get_accessible_connection(
before=before,
last=last,
)
query = query.options(selectinload(SessionRow.kernels))
async with graph_ctx.db.connect() as db_conn:
user = graph_ctx.user
client_ctx = ClientContext(
Expand Down
22 changes: 21 additions & 1 deletion src/ai/backend/manager/models/gql_models/vfolder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import uuid
from collections.abc import Iterable, Mapping
from collections.abc import Iterable, Mapping, Sequence
from datetime import datetime
from typing import (
TYPE_CHECKING,
Expand All @@ -25,6 +25,7 @@
VFolderID,
VFolderUsageMode,
)
from ai.backend.manager.models.base import batch_multiresult_in_scalar_stream

from ...api.exceptions import (
VFolderOperationFailed,
Expand Down Expand Up @@ -216,6 +217,25 @@ def from_row(
result.permissions = [] if permissions is None else permissions
return result

@classmethod
async def batch_load_by_id(
    cls,
    graph_ctx: GraphQueryContext,
    folder_ids: Sequence[uuid.UUID],
) -> Sequence[Sequence[Self]]:
    """Batch-load virtual folder nodes for the given folder IDs.

    Selects the ``VFolderRow`` entries whose ``id`` is in ``folder_ids``,
    joined-loading their ``user_row`` and ``group_row`` relationships, and
    hands the query to ``batch_multiresult_in_scalar_stream`` inside a
    read-only database session. Rows are keyed by ``row.id``.

    Args:
        graph_ctx: Query context providing the database handle.
        folder_ids: IDs of the virtual folders to load.

    Returns:
        Whatever ``batch_multiresult_in_scalar_stream`` produces for
        ``folder_ids`` -- presumably one inner sequence of nodes per
        requested ID, in request order (confirm against that helper).
    """
    eager_relations = (
        joinedload(VFolderRow.user_row),
        joinedload(VFolderRow.group_row),
    )
    stmt = sa.select(VFolderRow).where(VFolderRow.id.in_(folder_ids))
    stmt = stmt.options(*eager_relations)
    async with graph_ctx.db.begin_readonly_session() as readonly_sess:
        grouped_nodes = await batch_multiresult_in_scalar_stream(
            graph_ctx,
            readonly_sess,
            stmt,
            cls,
            folder_ids,
            lambda vfolder_row: vfolder_row.id,
        )
        return grouped_nodes

@classmethod
async def get_node(cls, info: graphene.ResolveInfo, id: str) -> Self:
graph_ctx: GraphQueryContext = info.context
Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/manager/models/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,7 +937,7 @@ def parse_row(cls, ctx: GraphQueryContext, row: KernelRow) -> Mapping[str, Any]:
"session_id": row.session_id,
# image
"image": row.image,
"image_object": ImageNode.from_row(row.image_row),
"image_object": ImageNode.from_row(ctx, row.image_row),
"architecture": row.architecture,
"registry": row.registry,
# status
Expand Down
Loading