Skip to content

Commit

Permalink
Merge pull request #13 from Rouast-Labs/idxs-fix
Browse files Browse the repository at this point in the history
Improve handling of indices
  • Loading branch information
prouast authored Nov 15, 2024
2 parents c8c7b62 + 5dfccbc commit ce1d34f
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ dependencies = [
"importlib_resources",
"numpy",
"onnxruntime",
"prpy[ffmpeg,numpy_min]>=0.2.15",
"prpy[ffmpeg,numpy_min]>=0.2.17",
"python-dotenv",
"pyyaml",
"requests",
Expand Down
6 changes: 3 additions & 3 deletions vitallens/methods/vitallens.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ def __call__(
frames = inputs
# Longer videos are split up with small overlaps
n_splits = 1 if expected_ds_n <= API_MAX_FRAMES else math.ceil((expected_ds_n - API_MAX_FRAMES) / (API_MAX_FRAMES - API_OVERLAP)) + 1
split_len = expected_ds_n if n_splits == 1 else math.ceil((inputs_n + (n_splits-1) * API_OVERLAP * expected_ds_factor) / n_splits)
start_idxs = [i * (split_len - API_OVERLAP * expected_ds_factor) for i in range(n_splits)]
split_len = inputs_n if n_splits == 1 else math.ceil((inputs_n + (n_splits-1) * API_OVERLAP) / n_splits)
start_idxs = [i * (split_len - API_OVERLAP) for i in range(n_splits)]
end_idxs = [min(start + split_len, inputs_n) for start in start_idxs]
start_idxs = [max(0, end - split_len) for end in end_idxs]
logging.info("Running inference for {} frames using {} request(s)...".format(expected_ds_n, n_splits))
Expand Down Expand Up @@ -243,7 +243,7 @@ def process_api_batch(
assert self.input_buffer is not None
payload["state"] = base64.b64encode(self.state.astype(np.float32).tobytes()).decode('utf-8')
# Adjust idxs
idxs = idxs[3:] - 3
idxs = idxs[(self.n_inputs-1):] - (self.n_inputs-1)
# Ask API to process video
response = requests.post(API_URL, headers=headers, json=payload)
response_body = json.loads(response.text)
Expand Down

0 comments on commit ce1d34f

Please sign in to comment.