Skip to content

Commit

Permalink
Change vfolder_nodes and kernel_nodes to list
Browse files Browse the repository at this point in the history
  • Loading branch information
fregataa committed Oct 29, 2024
1 parent d608cd0 commit a1bf3e5
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 36 deletions.
22 changes: 22 additions & 0 deletions src/ai/backend/manager/models/gql_models/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
ImageAlias,
)
from ai.backend.logging import BraceStyleAdapter
from ai.backend.manager.models.base import batch_multiresult_in_scalar_stream
from ai.backend.manager.models.container_registry import ContainerRegistryRow, ContainerRegistryType

from ...api.exceptions import ImageNotFound, ObjectNotFound
Expand Down Expand Up @@ -331,6 +332,27 @@ class Meta:
graphene.String, description="Added in 24.03.4. The array of image aliases."
)

@classmethod
async def batch_load_by_name_and_arch(
cls,
graph_ctx: GraphQueryContext,
name_and_arch: Sequence[tuple[str, str]],
) -> Sequence[Sequence[ImageNode]]:
query = (
sa.select(ImageRow)
.where(sa.tuple_(ImageRow.name, ImageRow.architecture).in_(name_and_arch))
.options(selectinload(ImageRow.aliases))
)
async with graph_ctx.db.begin_readonly_session() as db_session:
return await batch_multiresult_in_scalar_stream(
graph_ctx,
db_session,
query,
cls,
name_and_arch,
lambda row: (row.name, row.architecture),
)

@overload
@classmethod
def from_row(cls, graph_ctx: GraphQueryContext, row: ImageRow) -> Self: ...
Expand Down
26 changes: 16 additions & 10 deletions src/ai/backend/manager/models/gql_models/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,21 @@
from typing import (
TYPE_CHECKING,
Any,
Optional,
Self,
cast,
)

import graphene
import sqlalchemy as sa
from graphene.types.datetime import DateTime as GQLDateTime
from redis.asyncio import Redis
from sqlalchemy.orm import joinedload, selectinload

from ai.backend.common import msgpack, redis_helper
from ai.backend.common.types import KernelId, SessionId
from ai.backend.manager.models.base import batch_multiresult_in_scalar_stream

from ..gql_relay import AsyncNode, Connection
from ..image import ImageRow
from ..kernel import KernelRow, KernelStatus
from ..user import UserRole
from .image import ImageNode
Expand Down Expand Up @@ -47,6 +47,7 @@ class Meta:

# image
image = graphene.Field(ImageNode)
image_reference = graphene.String(description="Added in 24.12.0.")
architecture = graphene.String(
description="Added in 24.12.0. The architecture that the image of this kernel requires"
)
Expand Down Expand Up @@ -78,13 +79,7 @@ async def batch_load_by_session_id(
session_ids: Sequence[SessionId],
) -> Sequence[Sequence[Self]]:
async with graph_ctx.db.begin_readonly_session() as db_sess:
query = (
sa.select(KernelRow)
.where(KernelRow.session_id.in_(session_ids))
.options(
joinedload(KernelRow.image_row).options(selectinload(ImageRow.aliases)),
)
)
query = sa.select(KernelRow).where(KernelRow.session_id.in_(session_ids))
return await batch_multiresult_in_scalar_stream(
graph_ctx,
db_sess,
Expand Down Expand Up @@ -112,7 +107,7 @@ def from_row(cls, ctx: GraphQueryContext, row: KernelRow) -> Self:
cluster_role=row.cluster_role,
session_id=row.session_id,
architecture=row.architecture,
image=ImageNode.from_row(ctx, row.image_row),
image_reference=row.image,
status=row.status,
status_changed=row.status_changed,
status_info=row.status_info,
Expand All @@ -129,6 +124,17 @@ def from_row(cls, ctx: GraphQueryContext, row: KernelRow) -> Self:
preopen_ports=row.preopen_ports,
)

async def resolve_image(self, info: graphene.ResolveInfo) -> Optional[ImageNode]:
graph_ctx: GraphQueryContext = info.context
loader = graph_ctx.dataloader_manager.get_loader_by_func(
graph_ctx, ImageNode.batch_load_by_name_and_arch
)
images = cast(list[ImageNode], await loader.load((self.image_reference, self.architecture)))
try:
return images[0]
except IndexError:
return None

async def resolve_live_stat(self, info: graphene.ResolveInfo) -> dict[str, Any] | None:
graph_ctx: GraphQueryContext = info.context
loader = graph_ctx.dataloader_manager.get_loader_by_func(
Expand Down
35 changes: 11 additions & 24 deletions src/ai/backend/manager/models/gql_models/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@
)
from ..user import UserRole
from ..utils import execute_with_txn_retry
from .kernel import KernelConnection, KernelNode
from .vfolder import VirtualFolderConnection, VirtualFolderNode
from .kernel import KernelNode
from .vfolder import VirtualFolderNode

if TYPE_CHECKING:
from ..gql import GraphQueryContext
Expand Down Expand Up @@ -196,8 +196,8 @@ class Meta:
lambda: graphene.String,
description="Added in 24.12.0.",
)
vfolder_nodes = PaginatedConnectionField(
VirtualFolderConnection,
vfolder_nodes = graphene.List(
lambda: VirtualFolderNode,
description="Added in 24.12.0.",
)

Expand All @@ -206,8 +206,9 @@ class Meta:
inference_metrics = graphene.JSONString()

# relations
kernel_nodes = PaginatedConnectionField(
KernelConnection,
kernel_nodes = graphene.List(
lambda: KernelNode,
description="Added in 24.9.0.",
)
dependents = PaginatedConnectionField(
"ai.backend.manager.models.gql_models.session.ComputeSessionConnection",
Expand Down Expand Up @@ -285,34 +286,20 @@ async def resolve_idle_checks(self, info: graphene.ResolveInfo) -> dict[str, Any
async def resolve_vfolder_nodes(
self,
info: graphene.ResolveInfo,
) -> ConnectionResolverResult[VirtualFolderNode]:
) -> list[VirtualFolderNode]:
ctx: GraphQueryContext = info.context
loader = ctx.dataloader_manager.get_loader_by_func(ctx, VirtualFolderNode.batch_load_by_id)
vfolder_mounts = cast(list[VFolderMount], self.vfolder_mounts)
_folder_ids = [vf_mount.vfid.folder_id for vf_mount in vfolder_mounts]
folders = cast(list[VirtualFolderNode], await loader.load_many(_folder_ids))
return ConnectionResolverResult(
folders,
None,
None,
None,
total_count=len(folders),
)
return await loader.load_many(_folder_ids)

async def resolve_kernel_nodes(
self,
info: graphene.ResolveInfo,
) -> ConnectionResolverResult[KernelNode]:
) -> list[KernelNode]:
ctx: GraphQueryContext = info.context
loader = ctx.dataloader_manager.get_loader(ctx, "KernelNode.by_session_id")
kernels = await loader.load(self.row_id)
return ConnectionResolverResult(
kernels,
None,
None,
None,
total_count=len(kernels),
)
return await loader.load(self.row_id)

async def resolve_dependees(
self,
Expand Down
7 changes: 5 additions & 2 deletions src/ai/backend/manager/models/gql_models/vfolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
VFolderID,
VFolderUsageMode,
)
from ai.backend.manager.models.base import batch_multiresult_in_scalar_stream

from ...api.exceptions import (
VFolderOperationFailed,
Expand Down Expand Up @@ -221,7 +222,7 @@ async def batch_load_by_id(
cls,
graph_ctx: GraphQueryContext,
folder_ids: Sequence[uuid.UUID],
) -> list[Self]:
) -> Sequence[Sequence[Self]]:
query = (
sa.select(VFolderRow)
.where(VFolderRow.id.in_(folder_ids))
Expand All @@ -231,7 +232,9 @@ async def batch_load_by_id(
)
)
async with graph_ctx.db.begin_readonly_session() as db_session:
return [cls.from_row(graph_ctx, row) for row in await db_session.scalars(query)]
return await batch_multiresult_in_scalar_stream(
graph_ctx, db_session, query, cls, folder_ids, lambda row: row.id
)

@classmethod
async def get_node(cls, info: graphene.ResolveInfo, id: str) -> Self:
Expand Down

0 comments on commit a1bf3e5

Please sign in to comment.