Skip to content

Commit

Permalink
Let kernel nodes resolve image node
Browse files Browse the repository at this point in the history
  • Loading branch information
fregataa committed Oct 29, 2024
1 parent a92696c commit 3aab734
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 49 deletions.
17 changes: 0 additions & 17 deletions src/ai/backend/manager/models/gql_models/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,23 +404,6 @@ async def get_node(cls, info: graphene.ResolveInfo, id: str) -> ImageNode:
raise ValueError(f"Image not found (id: {image_id})")
return cls.from_row(graph_ctx, image_row)

@classmethod
async def batch_load_by_canonical(
cls,
graph_ctx: GraphQueryContext,
image_names: Sequence[str],
) -> list[Self]:
query = (
sa.select(ImageRow)
.where(ImageRow.name.in_(image_names))
.options(selectinload(ImageRow.aliases))
)
ret: list[Self] = []
async with graph_ctx.db.begin_readonly_session() as db_session:
async for row in await db_session.stream_scalars(query):
ret.append(cls.from_row(graph_ctx, row))
return ret


class ImageConnection(Connection):
class Meta:
Expand Down
21 changes: 16 additions & 5 deletions src/ai/backend/manager/models/gql_models/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
import sqlalchemy as sa
from graphene.types.datetime import DateTime as GQLDateTime
from redis.asyncio import Redis
from sqlalchemy.orm import joinedload, selectinload

from ai.backend.common import msgpack, redis_helper
from ai.backend.common.types import KernelId, SessionId
from ai.backend.manager.models.base import batch_multiresult_in_session
from ai.backend.manager.models.base import batch_multiresult_in_scalar_stream

from ..gql_relay import AsyncNode, Connection
from ..image import ImageRow
from ..kernel import KernelRow, KernelStatus
from ..user import UserRole
from .image import ImageNode
Expand Down Expand Up @@ -45,6 +47,9 @@ class Meta:

# image
image = graphene.Field(ImageNode)
architecture = graphene.String(
description="Added in 24.12.0. The architecture that the image of this kernel requires"
)

# status
status = graphene.String()
Expand Down Expand Up @@ -72,11 +77,15 @@ async def batch_load_by_session_id(
graph_ctx: GraphQueryContext,
session_ids: Sequence[SessionId],
) -> Sequence[Sequence[Self]]:
from ..kernel import kernels

async with graph_ctx.db.begin_readonly_session() as db_sess:
query = sa.select(kernels).where(kernels.c.session_id.in_(session_ids))
return await batch_multiresult_in_session(
query = (
sa.select(KernelRow)
.where(KernelRow.session_id.in_(session_ids))
.options(
joinedload(KernelRow.image_row).options(selectinload(ImageRow.aliases)),
)
)
return await batch_multiresult_in_scalar_stream(
graph_ctx,
db_sess,
query,
Expand All @@ -102,6 +111,8 @@ def from_row(cls, ctx: GraphQueryContext, row: KernelRow) -> Self:
local_rank=row.local_rank,
cluster_role=row.cluster_role,
session_id=row.session_id,
architecture=row.architecture,
image=ImageNode.from_row(ctx, row.image_row),
status=row.status,
status_changed=row.status_changed,
status_info=row.status_info,
Expand Down
27 changes: 4 additions & 23 deletions src/ai/backend/manager/models/gql_models/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
)
from ..user import UserRole
from ..utils import execute_with_txn_retry
from .image import ImageConnection, ImageNode
from .kernel import KernelConnection, KernelNode
from .vfolder import VirtualFolderConnection, VirtualFolderNode

Expand Down Expand Up @@ -197,10 +196,6 @@ class Meta:
lambda: graphene.String,
description="Added in 24.12.0.",
)
image_nodes = PaginatedConnectionField(
ImageConnection,
description="Added in 24.12.0.",
)
vfolder_nodes = PaginatedConnectionField(
VirtualFolderConnection,
description="Added in 24.12.0.",
Expand Down Expand Up @@ -287,14 +282,15 @@ async def resolve_idle_checks(self, info: graphene.ResolveInfo) -> dict[str, Any
)
return await loader.load(self.row_id)

async def resolve_image_nodes(
async def resolve_vfolder_nodes(
self,
info: graphene.ResolveInfo,
) -> ConnectionResolverResult[KernelNode]:
) -> ConnectionResolverResult[VirtualFolderNode]:
ctx: GraphQueryContext = info.context
loader = ctx.dataloader_manager.get_loader_by_func(ctx, VirtualFolderNode.batch_load_by_id)
vfolder_mounts = cast(list[VFolderMount], self.vfolder_mounts)
folders = await loader.load_many([vf_mount.vfid.folder_id for vf_mount in vfolder_mounts])
_folder_ids = [vf_mount.vfid.folder_id for vf_mount in vfolder_mounts]
folders = cast(list[VirtualFolderNode], await loader.load_many(_folder_ids))
return ConnectionResolverResult(
folders,
None,
Expand All @@ -303,21 +299,6 @@ async def resolve_image_nodes(
total_count=len(folders),
)

async def resolve_vfolder_nodes(
self,
info: graphene.ResolveInfo,
) -> ConnectionResolverResult[KernelNode]:
ctx: GraphQueryContext = info.context
loader = ctx.dataloader_manager.get_loader_by_func(ctx, ImageNode.batch_load_by_canonical)
images = await loader.load_many(self.image_references)
return ConnectionResolverResult(
images,
None,
None,
None,
total_count=len(images),
)

async def resolve_kernel_nodes(
self,
info: graphene.ResolveInfo,
Expand Down
5 changes: 1 addition & 4 deletions src/ai/backend/manager/models/gql_models/vfolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,11 +230,8 @@ async def batch_load_by_id(
joinedload(VFolderRow.group_row),
)
)
ret: list[Self] = []
async with graph_ctx.db.begin_readonly_session() as db_session:
async for row in await db_session.stream_scalars(query):
ret.append(cls.from_row(graph_ctx, row))
return ret
return [cls.from_row(graph_ctx, row) for row in await db_session.scalars(query)]

@classmethod
async def get_node(cls, info: graphene.ResolveInfo, id: str) -> Self:
Expand Down

0 comments on commit 3aab734

Please sign in to comment.