Skip to content

Commit

Permalink
fix: download with resume (#689)
Browse files Browse the repository at this point in the history
* fix: download with resume

* fix: pass a valid user-agent

* fix: unttest
  • Loading branch information
numb3r3 authored Apr 24, 2022
1 parent 6eadbfb commit 3bd7464
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 15 deletions.
54 changes: 42 additions & 12 deletions server/clip_server/model/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import io
import urllib
import shutil
import warnings
from typing import Union, List

Expand Down Expand Up @@ -36,7 +37,7 @@
}


def _download(url: str, root: str):
def _download(url: str, root: str, with_resume: bool = True):
os.makedirs(root, exist_ok=True)
filename = os.path.basename(url)

Expand Down Expand Up @@ -70,20 +71,48 @@ def _download(url: str, root: str):

task = progress.add_task('download', filename=url, start=False)

with urllib.request.urlopen(url) as source, open(
download_target, 'wb'
) as output:
tmp_file_path = download_target + '.part'
resume_byte_pos = (
os.path.getsize(tmp_file_path) if os.path.exists(tmp_file_path) else 0
)

total_bytes = -1
try:
# resolve the 403 error by passing a valid user-agent
req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'})

total_bytes = int(
urllib.request.urlopen(req).info().get('Content-Length', -1)
)

mode = 'ab' if (with_resume and resume_byte_pos) else 'wb'

with open(tmp_file_path, mode) as output:

progress.update(task, total=total_bytes)

progress.start_task(task)

progress.update(task, total=int(source.info().get('Content-Length')))
if resume_byte_pos and with_resume:
progress.update(task, advance=resume_byte_pos)
req.headers['Range'] = f'bytes={resume_byte_pos}-'

progress.start_task(task)
while True:
buffer = source.read(8192)
if not buffer:
break
with urllib.request.urlopen(req) as source:
while True:
buffer = source.read(8192)
if not buffer:
break

output.write(buffer)
progress.update(task, advance=len(buffer))
output.write(buffer)
progress.update(task, advance=len(buffer))
except Exception as ex:
raise ex
finally:
# rename the temp download file to the correct name if fully downloaded
if os.path.exists(tmp_file_path) and (
total_bytes == os.path.getsize(tmp_file_path)
):
shutil.move(tmp_file_path, download_target)

return download_target

Expand Down Expand Up @@ -165,6 +194,7 @@ def load(
model_path = _download(
_S3_BUCKET + _MODELS[name],
download_root or os.path.expanduser('~/.cache/clip'),
with_resume=True,
)
elif os.path.isfile(name):
model_path = name
Expand Down
8 changes: 6 additions & 2 deletions server/clip_server/model/clip_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,12 @@ class CLIPOnnxModel:
def __init__(self, name: str = None):
if name in _MODELS:
cache_dir = os.path.expanduser(f'~/.cache/clip/{name.replace("/", "-")}')
self._textual_path = _download(_S3_BUCKET + _MODELS[name][0], cache_dir)
self._visual_path = _download(_S3_BUCKET + _MODELS[name][1], cache_dir)
self._textual_path = _download(
_S3_BUCKET + _MODELS[name][0], cache_dir, with_resume=True
)
self._visual_path = _download(
_S3_BUCKET + _MODELS[name][1], cache_dir, with_resume=True
)
else:
raise RuntimeError(
f'Model {name} not found; available models = {available_models()}'
Expand Down
21 changes: 20 additions & 1 deletion tests/test_server.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,30 @@
import os

import pytest
from clip_server.model.clip import _transform_ndarray, _transform_blob
from clip_server.model.clip import _transform_ndarray, _transform_blob, _download
from docarray import Document
import numpy as np


def test_server_download(tmpdir):
_download('https://docarray.jina.ai/_static/favicon.png', tmpdir, with_resume=False)

target_path = os.path.join(tmpdir, 'favicon.png')
file_size = os.path.getsize(target_path)
assert file_size > 0

part_path = target_path + '.part'
with open(target_path, 'rb') as source, open(part_path, 'wb') as part_out:
buf = source.read(10)
part_out.write(buf)

os.remove(target_path)

_download('https://docarray.jina.ai/_static/favicon.png', tmpdir, with_resume=True)
assert os.path.getsize(target_path) == file_size
assert not os.path.exists(part_path)


@pytest.mark.parametrize(
'image_uri',
[
Expand Down

0 comments on commit 3bd7464

Please sign in to comment.