Skip to content

Commit

Permalink
chore(BA-440): Upgrade mypy to 1.14.1 and ruff to 0.8.5 (#3354)
Browse files Browse the repository at this point in the history
Backported-from: main (24.12)
Backported-to: 24.09
  • Loading branch information
achimnol committed Jan 2, 2025
1 parent 001c142 commit d321c7e
Show file tree
Hide file tree
Showing 95 changed files with 659 additions and 715 deletions.
9 changes: 6 additions & 3 deletions src/ai/backend/accelerator/cuda_open/nvidia.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import ctypes
import platform
from abc import ABCMeta, abstractmethod
from collections.abc import MutableMapping, Sequence
from itertools import groupby
from operator import itemgetter
from typing import Any, MutableMapping, NamedTuple, Tuple, TypeAlias
from typing import Any, NamedTuple, TypeAlias, cast

# ref: https://developer.nvidia.com/cuda-toolkit-archive
TARGET_CUDA_VERSIONS = (
Expand Down Expand Up @@ -487,7 +488,7 @@ def load_library(cls):
return None

@classmethod
def get_version(cls) -> Tuple[int, int]:
def get_version(cls) -> tuple[int, int]:
if cls._version == (0, 0):
raw_ver = ctypes.c_int()
cls.invoke("cudaRuntimeGetVersion", ctypes.byref(raw_ver))
Expand All @@ -513,7 +514,9 @@ def get_device_props(cls, device_idx: int):
props_struct = cudaDeviceProp()
cls.invoke("cudaGetDeviceProperties", ctypes.byref(props_struct), device_idx)
props: MutableMapping[str, Any] = {
k: getattr(props_struct, k) for k, _ in props_struct._fields_
# Treat each field as two-tuple assuming that we don't have bit-fields
k: getattr(props_struct, k)
for k, _ in cast(Sequence[tuple[str, Any]], props_struct._fields_)
}
pci_bus_id = b" " * 16
cls.invoke("cudaDeviceGetPCIBusId", ctypes.c_char_p(pci_bus_id), 16, device_idx)
Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/accelerator/cuda_open/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ async def list_devices(self) -> Collection[CUDADevice]:
if dev_id in self.device_mask:
continue
raw_info = libcudart.get_device_props(int(dev_id))
sysfs_node_path = f"/sys/bus/pci/devices/{raw_info["pciBusID_str"].lower()}/numa_node"
sysfs_node_path = f"/sys/bus/pci/devices/{raw_info['pciBusID_str'].lower()}/numa_node"
node: Optional[int]
try:
node = int(Path(sysfs_node_path).read_text().strip())
Expand Down
4 changes: 2 additions & 2 deletions src/ai/backend/accelerator/mock/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ async def list_devices(self) -> Collection[MockDevice]:
init_kwargs["is_mig_device"] = dev_info["is_mig_device"]
if dev_info["is_mig_device"]:
init_kwargs["device_id"] = DeviceId(
f"MIG-{dev_info["mother_uuid"]}/{idx}/0"
f"MIG-{dev_info['mother_uuid']}/{idx}/0"
)
device_cls = CUDADevice
case _:
Expand Down Expand Up @@ -810,7 +810,7 @@ def get_metadata(self) -> AcceleratorMetadata:

device_format = self.device_formats[format_key]
return {
"slot_name": f"{self.mock_config["slot_name"]}.{format_key}",
"slot_name": f"{self.mock_config['slot_name']}.{format_key}",
"human_readable_name": device_format["human_readable_name"],
"description": device_format["description"],
"display_unit": device_format["display_unit"],
Expand Down
6 changes: 3 additions & 3 deletions src/ai/backend/account_manager/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,9 +299,9 @@ async def server_main(
try:
ssl_ctx = None
if am_cfg.ssl_enabled:
assert (
am_cfg.ssl_cert is not None
), "Should set `account_manager.ssl-cert` in config file."
assert am_cfg.ssl_cert is not None, (
"Should set `account_manager.ssl-cert` in config file."
)
ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_ctx.load_cert_chain(
str(am_cfg.ssl_cert),
Expand Down
10 changes: 5 additions & 5 deletions src/ai/backend/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2016,7 +2016,7 @@ async def create_kernel(
if len(overlapping_services) > 0:
raise AgentError(
f"Port {port_no} overlaps with built-in service"
f" {overlapping_services[0]["name"]}"
f" {overlapping_services[0]['name']}"
)

preopen_sport: ServicePort = {
Expand Down Expand Up @@ -2340,7 +2340,7 @@ async def load_model_definition(

if not model_definition_path:
raise AgentError(
f"Model definition file ({" or ".join(model_definition_candidates)}) does not exist under vFolder"
f"Model definition file ({' or '.join(model_definition_candidates)}) does not exist under vFolder"
f" {model_folder.name} (ID {model_folder.vfid})",
)
try:
Expand Down Expand Up @@ -2371,11 +2371,11 @@ async def load_model_definition(
]
if len(overlapping_services) > 0:
raise AgentError(
f"Port {service["port"]} overlaps with built-in service"
f" {overlapping_services[0]["name"]}"
f"Port {service['port']} overlaps with built-in service"
f" {overlapping_services[0]['name']}"
)
service_ports.append({
"name": f"{model["name"]}-{service["port"]}",
"name": f"{model['name']}-{service['port']}",
"protocol": ServicePortProtocols.PREOPEN,
"container_ports": (service["port"],),
"host_ports": (None,),
Expand Down
6 changes: 3 additions & 3 deletions src/ai/backend/agent/docker/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,7 +932,7 @@ async def start_container(
label for label in service_ports_label if label
])
update_nested_dict(container_config, self.computer_docker_args)
kernel_name = f"kernel.{self.image_ref.name.split("/")[-1]}.{self.kernel_id}"
kernel_name = f"kernel.{self.image_ref.name.split('/')[-1]}.{self.kernel_id}"

# optional local override of docker config
extra_container_opts_name = "agent-docker-container-opts.json"
Expand Down Expand Up @@ -1169,7 +1169,7 @@ async def __ainit__(self) -> None:
{
"Cmd": [
f"UNIX-LISTEN:/ipc/{self.agent_sockpath.name},unlink-early,fork,mode=777",
f"TCP-CONNECT:127.0.0.1:{self.local_config["agent"]["agent-sock-port"]}",
f"TCP-CONNECT:127.0.0.1:{self.local_config['agent']['agent-sock-port']}",
],
"HostConfig": {
"Mounts": [
Expand Down Expand Up @@ -1409,7 +1409,7 @@ async def handle_agent_socket(self):
while True:
agent_sock = zmq_ctx.socket(zmq.REP)
try:
agent_sock.bind(f"tcp://127.0.0.1:{self.local_config["agent"]["agent-sock-port"]}")
agent_sock.bind(f"tcp://127.0.0.1:{self.local_config['agent']['agent-sock-port']}")
while True:
msg = await agent_sock.recv_multipart()
if not msg:
Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/agent/docker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ async def get_container_version_and_status(self) -> Tuple[int, bool]:
raise
if c["Config"].get("Labels", {}).get("ai.backend.system", "0") != "1":
raise RuntimeError(
f"An existing container named \"{c["Name"].lstrip("/")}\" is not a system container"
f'An existing container named "{c["Name"].lstrip("/")}" is not a system container'
" spawned by Backend.AI. Please check and remove it."
)
return (
Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/agent/kubernetes/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,7 +884,7 @@ async def check_krunner_pv_status(self):
new_pv.label("backend.ai/backend-ai-scratch-volume", "hostPath")
else:
raise NotImplementedError(
f'Scratch type {self.local_config["container"]["scratch-type"]} is not'
f"Scratch type {self.local_config['container']['scratch-type']} is not"
" supported",
)

Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/agent/kubernetes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ async def get_container_version_and_status(self) -> Tuple[int, bool]:
raise
if c["Config"].get("Labels", {}).get("ai.backend.system", "0") != "1":
raise RuntimeError(
f"An existing container named \"{c["Name"].lstrip("/")}\" is not a system container"
f'An existing container named "{c["Name"].lstrip("/")}" is not a system container'
" spawned by Backend.AI. Please check and remove it."
)
return (
Expand Down
13 changes: 7 additions & 6 deletions src/ai/backend/agent/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
import signal
import sys
from collections import OrderedDict, defaultdict
from ipaddress import _BaseAddress as BaseIPAddress
from ipaddress import ip_network
from ipaddress import IPv4Address, IPv6Address, ip_network
from pathlib import Path
from pprint import pformat, pprint
from typing import (
Expand Down Expand Up @@ -870,7 +869,7 @@ async def server_main(

log.info("Preparing kernel runner environments...")
kernel_mod = importlib.import_module(
f"ai.backend.agent.{local_config["agent"]["backend"].value}.kernel",
f"ai.backend.agent.{local_config['agent']['backend'].value}.kernel",
)
krunner_volumes = await kernel_mod.prepare_krunner_env(local_config) # type: ignore
# TODO: merge k8s branch: nfs_mount_path = local_config['baistatic']['mounted-at']
Expand All @@ -890,8 +889,8 @@ async def server_main(
}
scope_prefix_map = {
ConfigScopes.GLOBAL: "",
ConfigScopes.SGROUP: f"sgroup/{local_config["agent"]["scaling-group"]}",
ConfigScopes.NODE: f"nodes/agents/{local_config["agent"]["id"]}",
ConfigScopes.SGROUP: f"sgroup/{local_config['agent']['scaling-group']}",
ConfigScopes.NODE: f"nodes/agents/{local_config['agent']['id']}",
}
etcd = AsyncEtcd(
local_config["etcd"]["addr"],
Expand Down Expand Up @@ -1053,7 +1052,9 @@ def main(
raise click.Abort()

rpc_host = cfg["agent"]["rpc-listen-addr"].host
if isinstance(rpc_host, BaseIPAddress) and (rpc_host.is_unspecified or rpc_host.is_link_local):
if isinstance(rpc_host, (IPv4Address, IPv6Address)) and (
rpc_host.is_unspecified or rpc_host.is_link_local
):
print(
"ConfigurationError: "
"Cannot use link-local or unspecified IP address as the RPC listening host.",
Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/agent/watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def main(
fn = Path(cfg["logging"]["file"]["filename"])
cfg["logging"]["file"]["filename"] = f"{fn.stem}-watcher{fn.suffix}"

setproctitle(f"backend.ai: watcher {cfg["etcd"]["namespace"]}")
setproctitle(f"backend.ai: watcher {cfg['etcd']['namespace']}")
with logger:
log.info("Backend.AI Agent Watcher {0}", VERSION)
log.info("runtime: {0}", utils.env_info())
Expand Down
6 changes: 3 additions & 3 deletions src/ai/backend/cli/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,12 @@ def ask_string_in_array(prompt: str, choices: list, default: str) -> Optional[st

if default:
question = (
f"{prompt} (choices: {"/".join(choices)}, "
f"{prompt} (choices: {'/'.join(choices)}, "
f"if left empty, this will use default value: {default}): "
)
else:
question = (
f"{prompt} (choices: {"/".join(choices)}, if left empty, this will remove this key): "
f"{prompt} (choices: {'/'.join(choices)}, if left empty, this will remove this key): "
)

while True:
Expand All @@ -92,7 +92,7 @@ def ask_string_in_array(prompt: str, choices: list, default: str) -> Optional[st
elif user_reply.lower() in choices:
break
else:
print(f"Please answer in {"/".join(choices)}.")
print(f"Please answer in {'/'.join(choices)}.")
return user_reply


Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/client/cli/admin/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ async def rescan_images_impl(registry: str) -> None:
print_error(e)
sys.exit(ExitCode.FAILURE)
if not result["ok"]:
print_fail(f"Failed to begin registry scanning: {result["msg"]}")
print_fail(f"Failed to begin registry scanning: {result['msg']}")
sys.exit(ExitCode.FAILURE)
print_done("Started updating the image metadata from the configured registries.")
bgtask_id = result["task_id"]
Expand Down
4 changes: 2 additions & 2 deletions src/ai/backend/client/cli/pretty.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def format_error(exc: Exception):
if matches:
yield "\nCandidates (up to 10 recent entries):\n"
for item in matches:
yield f"- {item["id"]} ({item["name"]}, {item["status"]})\n"
yield f"- {item['id']} ({item['name']}, {item['status']})\n"
elif exc.data["type"].endswith("/session-already-exists"):
existing_session_id = exc.data["data"].get("existingSessionId", None)
if existing_session_id is not None:
Expand All @@ -144,7 +144,7 @@ def format_error(exc: Exception):
if exc.data["type"].endswith("/graphql-error"):
yield "\n\u279c Message:\n"
for err_item in exc.data.get("data", []):
yield f"{err_item["message"]}"
yield f"{err_item['message']}"
if err_path := err_item.get("path"):
yield f" (path: {_format_gql_path(err_path)})"
yield "\n"
Expand Down
4 changes: 2 additions & 2 deletions src/ai/backend/client/cli/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def info(ctx: CLIContext, service_name_or_id: str):
)
print()
for route in routes:
print(f"Route {route["routing_id"]}: ")
print(f"Route {route['routing_id']}: ")
ctx.output.print_item(
route,
_default_routing_fields,
Expand Down Expand Up @@ -645,7 +645,7 @@ def generate_token(ctx: CLIContext, service_name_or_id: str, duration: str, quie
if quiet:
print(resp["token"])
else:
print_done(f"Generated API token {resp["token"]}")
print_done(f"Generated API token {resp['token']}")
except Exception as e:
ctx.output.print_error(e)
sys.exit(ExitCode.FAILURE)
Expand Down
6 changes: 3 additions & 3 deletions src/ai/backend/client/cli/session/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,7 +870,7 @@ async def cmd_main() -> None:
session = api_sess.ComputeSession.from_session_id(session_id)
resp = await session.update(priority=priority)
item = resp["item"]
print_done(f"Session {item["name"]!r} priority is changed to {item["priority"]}.")
print_done(f"Session {item['name']!r} priority is changed to {item['priority']}.")

try:
asyncio.run(cmd_main())
Expand Down Expand Up @@ -1263,7 +1263,7 @@ def watch(
session_names = _fetch_session_names()
if not session_names:
if output == "json":
sys.stderr.write(f'{json.dumps({"ok": False, "reason": "No matching items."})}\n')
sys.stderr.write(f"{json.dumps({'ok': False, 'reason': 'No matching items.'})}\n")
else:
print_fail("No matching items.")
sys.exit(ExitCode.FAILURE)
Expand All @@ -1285,7 +1285,7 @@ def watch(
else:
if output == "json":
sys.stderr.write(
f'{json.dumps({"ok": False, "reason": "No matching items."})}\n'
f"{json.dumps({'ok': False, 'reason': 'No matching items.'})}\n"
)
else:
print_fail("No matching items.")
Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/client/cli/vfolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ def request_download(name, filename):
with Session() as session:
try:
response = json.loads(session.VFolder(name).request_download(filename))
print_done(f'Download token: {response["token"]}')
print_done(f"Download token: {response['token']}")
except Exception as e:
print_error(e)
sys.exit(ExitCode.FAILURE)
Expand Down
15 changes: 6 additions & 9 deletions src/ai/backend/client/func/acl.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import textwrap
from typing import Sequence

from ai.backend.client.output.fields import permission_fields
from ai.backend.client.output.types import FieldSpec

from ..output.fields import permission_fields
from ..output.types import FieldSpec
from ..session import api_session
from ..utils import dedent as _d
from .base import BaseFunction, api_function

__all__ = ("Permission",)
Expand All @@ -24,13 +23,11 @@ async def list(
:param fields: Additional permission query fields to fetch.
"""
query = textwrap.dedent(
"""\
query = _d("""
query {
vfolder_host_permissions {$fields}
vfolder_host_permissions { $fields }
}
"""
)
""")
query = query.replace("$fields", " ".join(f.field_ref for f in fields))
data = await api_session.get().Admin._query(query)
return data["vfolder_host_permissions"]
19 changes: 8 additions & 11 deletions src/ai/backend/client/func/agent.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from __future__ import annotations

import textwrap
from typing import Optional, Sequence

from ai.backend.client.output.fields import agent_fields
from ai.backend.client.output.types import FieldSpec, PaginatedResult
from ai.backend.client.pagination import fetch_paginated_result
from ai.backend.client.request import Request
from ai.backend.client.session import api_session

from ..output.fields import agent_fields
from ..output.types import FieldSpec, PaginatedResult
from ..pagination import fetch_paginated_result
from ..request import Request
from ..session import api_session
from ..utils import dedent as _d
from .base import BaseFunction, api_function

__all__ = (
Expand Down Expand Up @@ -88,13 +87,11 @@ async def detail(
agent_id: str,
fields: Sequence[FieldSpec] = _default_detail_fields,
) -> Sequence[dict]:
query = textwrap.dedent(
"""\
query = _d("""
query($agent_id: String!) {
agent(agent_id: $agent_id) {$fields}
}
"""
)
""")
query = query.replace("$fields", " ".join(f.field_ref for f in fields))
variables = {"agent_id": agent_id}
data = await api_session.get().Admin._query(query, variables)
Expand Down
Loading

0 comments on commit d321c7e

Please sign in to comment.