Skip to content

Commit

Permalink
fix: don't ignore empty dirs when unpacking model and bento (#5073)
Browse files Browse the repository at this point in the history
* fix: don't ignore empty dirs when unpacking model and bento

Signed-off-by: Frost Ming <[email protected]>
  • Loading branch information
frostming authored Nov 11, 2024
1 parent b720a05 commit 5675be9
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 64 deletions.
11 changes: 2 additions & 9 deletions src/bentoml/_internal/cloud/bento.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import tarfile
import typing as t
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from tempfile import NamedTemporaryFile

import attrs
Expand All @@ -19,6 +18,7 @@
from ..bento import BentoStore
from ..configuration.containers import BentoMLContainer
from ..tag import Tag
from ..utils.filesystem import safe_extract_tarfile
from .base import FILE_CHUNK_SIZE
from .base import UPLOAD_RETRY_COUNT
from .base import CallbackIOWrapper
Expand Down Expand Up @@ -520,14 +520,7 @@ def _do_pull_bento(
tar = tarfile.open(fileobj=tar_file, mode="r")
with self.spinner.spin(text=f'Extracting bento "{_tag}" tar file'):
with fs.open_fs("temp://") as temp_fs:
for member in tar.getmembers():
f = tar.extractfile(member)
if f is None:
continue
p = Path(member.name)
if p.parent != Path("."):
temp_fs.makedirs(p.parent.as_posix(), recreate=True)
temp_fs.writebytes(member.name, f.read())
safe_extract_tarfile(tar, temp_fs.getsyspath("/"))
bento = Bento.from_fs(temp_fs)
bento = bento.save(bento_store)
self.spinner.log(f'[bold green]Successfully pulled bento "{_tag}"')
Expand Down
11 changes: 2 additions & 9 deletions src/bentoml/_internal/cloud/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import typing as t
import warnings
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from tempfile import NamedTemporaryFile

import attrs
Expand All @@ -20,6 +19,7 @@
from ..models import Model as StoredModel
from ..models import ModelStore
from ..tag import Tag
from ..utils.filesystem import safe_extract_tarfile
from .base import FILE_CHUNK_SIZE
from .base import UPLOAD_RETRY_COUNT
from .base import CallbackIOWrapper
Expand Down Expand Up @@ -482,14 +482,7 @@ def _do_pull_model(
tar = tarfile.open(fileobj=tar_file, mode="r")
with self.spinner.spin(text=f'Extracting model "{_tag}" tar file'):
with fs.open_fs("temp://") as temp_fs:
for member in tar.getmembers():
f = tar.extractfile(member)
if f is None:
continue
p = Path(member.name)
if p.parent != Path("."):
temp_fs.makedirs(str(p.parent), recreate=True)
temp_fs.writebytes(member.name, f.read())
safe_extract_tarfile(tar, temp_fs.getsyspath("/"))
model = StoredModel.from_fs(temp_fs).save(model_store)
self.spinner.log(f'[bold green]Successfully pulled model "{_tag}"')
return model
Expand Down
2 changes: 1 addition & 1 deletion src/bentoml/_internal/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
import starlette.datastructures
from starlette.background import BackgroundTasks

from .utils.filesystem import TempfilePool
from .utils.http import Cookie
from .utils.temp import TempfilePool

if TYPE_CHECKING:
import starlette.requests
Expand Down
93 changes: 93 additions & 0 deletions src/bentoml/_internal/utils/filesystem.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from __future__ import annotations

import logging
import os
import shutil
import tarfile
import tempfile
from collections import deque
from functools import partial
from pathlib import Path
from threading import Lock

import fs

logger = logging.getLogger(__name__)


class TempfilePool:
"""A simple pool to get temp directories,
so they are reused as much as possible.
"""

def __init__(
self,
suffix: str | None = None,
prefix: str | None = None,
dir: str | None = None,
) -> None:
self._pool: deque[str] = deque([])
self._lock = Lock()
self._new = partial(tempfile.mkdtemp, suffix=suffix, prefix=prefix, dir=dir)

def cleanup(self) -> None:
while len(self._pool):
dir = self._pool.popleft()
shutil.rmtree(dir, ignore_errors=True)

def acquire(self) -> str:
with self._lock:
if not len(self._pool):
return self._new()
else:
return self._pool.popleft()

def release(self, dir: str) -> None:
for child in Path(dir).iterdir():
if child.is_dir():
shutil.rmtree(child)
else:
child.unlink()
with self._lock:
self._pool.append(dir)


def safe_extract_tarfile(tar: tarfile.TarFile, destination: str) -> None:
# Borrowed from pip but continue on error
os.makedirs(destination, exist_ok=True)
for member in tar.getmembers():
fn = member.name
path = os.path.join(destination, fn)
if not fs.path.relativefrom(destination, path):
logger.warning(
"The tar file has a file (%s) trying to unpack to"
"outside target directory",
fn,
)
continue
if member.isdir():
os.makedirs(path, exist_ok=True)
elif member.issym():
try:
tar._extract_member(member, path)
except Exception as exc:
# Some corrupt tar files seem to produce this
# (specifically bad symlinks)
logger.warning("In the tar file the member %s is invalid: %s", fn, exc)
continue
else:
try:
fp = tar.extractfile(member)
except (KeyError, AttributeError) as exc:
# Some corrupt tar files seem to produce this
# (specifically bad symlinks)
logger.warning("In the tar file the member %s is invalid: %s", fn, exc)
continue
os.makedirs(os.path.dirname(path), exist_ok=True)
if fp is None:
continue
with open(path, "wb") as destfp:
shutil.copyfileobj(fp, destfp)
fp.close()
# Update the timestamp (useful for cython compiled files)
tar.utime(member, path)
45 changes: 0 additions & 45 deletions src/bentoml/_internal/utils/temp.py

This file was deleted.

0 comments on commit 5675be9

Please sign in to comment.