Skip to content

Commit

Permalink
Fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
prouast committed Nov 14, 2024
1 parent a69114d commit 15e7130
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 21 deletions.
13 changes: 8 additions & 5 deletions examples/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import concurrent.futures
import cv2
import numpy as np
from prpy.constants import SECONDS_PER_MINUTE
from prpy.numpy.face import get_upper_body_roi_from_det
from prpy.numpy.signal import estimate_freq
import sys
Expand All @@ -13,6 +14,7 @@
from vitallens import VitalLens, Mode, Method
from vitallens.buffer import SignalBuffer, MultiSignalBuffer
from vitallens.constants import API_MIN_FRAMES
from vitallens.constants import CALC_HR_MIN, CALC_HR_MAX, CALC_RR_MIN, CALC_RR_MAX

def draw_roi(frame, roi):
roi = np.asarray(roi).astype(np.int32)
Expand Down Expand Up @@ -49,9 +51,10 @@ def draw_fps(frame, fps, text, draw_area_bl_x, draw_area_bl_y):
cv2.putText(frame, text='{}: {:.1f}'.format(text, fps), org=(draw_area_bl_x, draw_area_bl_y),
fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.6, color=(0,255,0), thickness=1)

def draw_vital(frame, sig, text, sig_name, fps, mult, color, draw_area_bl_x, draw_area_bl_y):
def draw_vital(frame, sig, text, sig_name, fps, color, draw_area_bl_x, draw_area_bl_y):
if sig_name in sig:
val = estimate_freq(x=sig[sig_name], f_s=fps, f_res=0.0167, method='periodogram') * mult
f_range = (CALC_HR_MIN/SECONDS_PER_MINUTE, CALC_HR_MAX/SECONDS_PER_MINUTE) if 'heart' in sig_name else (CALC_RR_MIN/SECONDS_PER_MINUTE, CALC_RR_MAX/SECONDS_PER_MINUTE)
val = estimate_freq(x=sig[sig_name], f_s=fps, f_res=0.1/SECONDS_PER_MINUTE, f_range=f_range, method='periodogram') * SECONDS_PER_MINUTE
cv2.putText(frame, text='{}: {:.1f}'.format(text, val), org=(draw_area_bl_x, draw_area_bl_y),
fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.6, color=color, thickness=1)

Expand Down Expand Up @@ -131,7 +134,7 @@ def run(args):
# Start next prediction
if len(frame_buffer) >= (API_MIN_FRAMES if args.method == Method.VITALLENS else 1):
n_frames = len(frame_buffer)
future = executor.submit(vl, frame_buffer.copy(), fps)
executor.submit(vl, frame_buffer.copy(), fps)
frame_buffer.clear()
# Sample frames
if i % ds_factor == 0:
Expand All @@ -149,8 +152,8 @@ def run(args):
draw_area_tl_x=roi[2]+20, draw_area_tl_y=int(roi[1]+(roi[3]-roi[1])/2.0), color=(255, 0, 0))
draw_fps(frame, fps=fps, text="fps", draw_area_bl_x=roi[0], draw_area_bl_y=roi[3]+20)
draw_fps(frame, fps=p_fps, text="p_fps", draw_area_bl_x=int(roi[0]+0.4*(roi[2]-roi[0])), draw_area_bl_y=roi[3]+20)
draw_vital(frame, sig=signals, text="hr [bpm]", sig_name='ppg_waveform_sig', fps=fps, mult=60., color=(0,0,255), draw_area_bl_x=roi[2]+20, draw_area_bl_y=int(roi[1]+(roi[3]-roi[1])/2.0))
draw_vital(frame, sig=signals, text="rr [rpm]", sig_name='respiratory_waveform_sig', fps=fps, mult=60., color=(255,0,0), draw_area_bl_x=roi[2]+20, draw_area_bl_y=roi[3])
draw_vital(frame, sig=signals, text="hr [bpm]", sig_name='ppg_waveform_sig', fps=fps, color=(0,0,255), draw_area_bl_x=roi[2]+20, draw_area_bl_y=int(roi[1]+(roi[3]-roi[1])/2.0))
draw_vital(frame, sig=signals, text="rr [rpm]", sig_name='respiratory_waveform_sig', fps=fps, color=(255,0,0), draw_area_bl_x=roi[2]+20, draw_area_bl_y=roi[3])
cv2.imshow('Live', frame)
c = cv2.waitKey(1)
if c == 27:
Expand Down
18 changes: 10 additions & 8 deletions tests/test_vitallens.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ def test_VitalLensRPPGMethod_mock(mock_post, request, file, long, override_fps_t
assert live.shape == (test_video_ndarray.shape[0],)

@pytest.mark.parametrize("process_signals", [True, False])
def test_VitalLens_API_valid_response(request, process_signals):
@pytest.mark.parametrize("n_frames", [16, 250])
def test_VitalLens_API_valid_response(request, process_signals, n_frames):
config = load_config("vitallens.yaml")
api_key = request.getfixturevalue('test_dev_api_key')
test_video_ndarray = request.getfixturevalue('test_video_ndarray')
Expand All @@ -145,7 +146,7 @@ def test_VitalLens_API_valid_response(request, process_signals):
inputs=test_video_ndarray, fps=test_video_fps, target_size=config['input_size'],
roi=test_video_faces[0].tolist(), library='prpy', scale_algorithm='bilinear')
headers = {"x-api-key": api_key}
payload = {"video": base64.b64encode(frames[:16].tobytes()).decode('utf-8')}
payload = {"video": base64.b64encode(frames[:n_frames].tobytes()).decode('utf-8')}
if process_signals: payload['fps'] = str(30)
response = requests.post(API_URL, headers=headers, json=payload)
response_body = json.loads(response.text)
Expand All @@ -157,13 +158,14 @@ def test_VitalLens_API_valid_response(request, process_signals):
ppg_waveform_conf = np.asarray(response_body["vital_signs"]["ppg_waveform"]["confidence"])
resp_waveform_data = np.asarray(response_body["vital_signs"]["respiratory_waveform"]["data"])
resp_waveform_conf = np.asarray(response_body["vital_signs"]["respiratory_waveform"]["confidence"])
assert ppg_waveform_data.shape == (16,)
assert ppg_waveform_conf.shape == (16,)
assert resp_waveform_data.shape == (16,)
assert resp_waveform_conf.shape == (16,)
assert all((key in vital_signs) if process_signals else (key not in vital_signs) for key in ["heart_rate", "respiratory_rate"])
assert ppg_waveform_data.shape == (n_frames,)
assert ppg_waveform_conf.shape == (n_frames,)
assert resp_waveform_data.shape == (n_frames,)
assert resp_waveform_conf.shape == (n_frames,)
t = n_frames/test_video_fps
assert all((key in vital_signs) if (process_signals and t > 8.) else (key not in vital_signs) for key in ["heart_rate", "respiratory_rate"])
live = np.asarray(response_body["face"]["confidence"])
assert live.shape == (16,)
assert live.shape == (n_frames,)
state = np.asarray(response_body["state"]["data"])
assert state.shape == (2, 128)

Expand Down
4 changes: 2 additions & 2 deletions vitallens/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,8 @@ def __call__(
# Probe inputs
if self.mode == Mode.BURST and not isinstance(video, np.ndarray):
raise ValueError("Must provide `np.ndarray` inputs for burst mode.")
if self.mode == Mode.BURST and video.shape[0] > API_MAX_FRAMES:
raise ValueError(f"Maximum number of frames in burst mode is {API_MAX_FRAMES}, but received {video.shape[0]}.")
if self.mode == Mode.BURST and video.shape[0] > (API_MAX_FRAMES - self.rppg.n_inputs + 1):
raise ValueError(f"Maximum number of frames in burst mode is {API_MAX_FRAMES - self.rppg.n_inputs + 1}, but received {video.shape[0]}.")
inputs_shape, fps, _ = probe_image_inputs(video, fps=fps, allow_image=False)
# TODO: Optimize performance of simple rPPG methods for long videos
# Warning if using long video
Expand Down
2 changes: 2 additions & 0 deletions vitallens/configs/vitallens.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
model: 'vitallens'
# Size of the input
input_size: 40
# Number of inputs
n_inputs: 4
# List estimated signals
signals: ['heart_rate', 'respiratory_rate', 'ppg_waveform', 'respiratory_waveform']

Expand Down
1 change: 1 addition & 0 deletions vitallens/methods/rppg_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(
"""
self.fps_target = config['fps_target']
self.op_mode = mode
self.n_inputs = 1
self.est_window_length = config['est_window_length']
self.est_window_overlap = config['est_window_overlap']
self.est_window_flexible = self.est_window_length == 0
Expand Down
32 changes: 26 additions & 6 deletions vitallens/methods/vitallens.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,12 @@ def __init__(
self.api_key = api_key
self.model = config['model']
self.input_size = config['input_size']
self.n_inputs = config['n_inputs']
self.roi_method = config['roi_method']
self.signals = config['signals']
if mode == Mode.BURST:
self.state = None
self.input_buffer = None
def __call__(
self,
inputs: Union[np.ndarray, str],
Expand Down Expand Up @@ -140,8 +142,9 @@ def __call__(
x=idxs, y=conf_ds, xs=np.arange(inputs_n), axis=1)
live = interpolate_cubic_spline(
x=idxs, y=live_ds, xs=np.arange(inputs_n), axis=0)
# Filter (2, n_frames)
sig = np.asarray([self.postprocess(p, fps, type=name) for p, name in zip(sig, ['ppg', 'resp'])])
# Filter only in batch mode (2, n_frames)
if self.op_mode == Mode.BATCH:
sig = np.asarray([self.postprocess(p, fps, type=name) for p, name in zip(sig, ['ppg', 'resp'])])
# Assemble and return the results
return assemble_results(sig=sig,
conf=conf,
Expand Down Expand Up @@ -209,21 +212,38 @@ def process_api_batch(
else:
idxs = list(range(0, inputs_shape[0], ds_factor))
else:
# Buffer inputs for burst mode
if self.op_mode == Mode.BURST:
# Inputs in burst mode are always np.ndarray
if self.state is not None:
# State has been initialized
assert self.input_buffer is not None
if inputs.shape[1:] != self.input_buffer.shape[1:]:
raise ValueError("In burst mode, input dimensions must be consistent.")
inputs = np.concatenate([self.input_buffer, inputs], axis=0)
self.input_buffer = inputs[-(self.n_inputs-1):]
# Inputs have not been parsed globally. Parse the inputs
frames_ds, _, _, ds_factor, idxs = parse_image_inputs(
inputs=inputs, fps=fps, roi=roi, target_size=self.input_size, target_fps=fps_target,
preserve_aspect_ratio=False, library='prpy', scale_algorithm='bilinear',
trim=(start, end) if start is not None and end is not None else None,
allow_image=False, videodims=True)
# Make sure we have the correct number of frames
idxs = np.asarray(idxs)
expected_n = math.ceil(((end-start) if start is not None and end is not None else inputs_shape[0]) / ds_factor)
if frames_ds.shape[0] != expected_n or len(idxs) != expected_n:
if (self.op_mode == Mode.BURST and self.state is not None): expected_n += (self.n_inputs - 1)
if frames_ds.shape[0] != expected_n or idxs.shape[0] != expected_n:
raise ValueError("Unexpected number of frames returned. Try to set `override_global_parse` to `True` or `False`.")
# Prepare API header and payload
headers = {"x-api-key": self.api_key}
payload = {"video": base64.b64encode(frames_ds.tobytes()).decode('utf-8')}
if self.op_mode == Mode.BURST and self.state is not None:
payload["state"] = base64.b64encode(self.state.astype(np.float32).tobytes()).decode('utf-8')
if self.op_mode == Mode.BURST:
if self.state is not None:
# State and frame buffer have been initialized
assert self.input_buffer is not None
payload["state"] = base64.b64encode(self.state.astype(np.float32).tobytes()).decode('utf-8')
# Adjust idxs
idxs = idxs[3:] - 3
# Ask API to process video
response = requests.post(API_URL, headers=headers, json=payload)
response_body = json.loads(response.text)
Expand All @@ -250,7 +270,6 @@ def process_api_batch(
live_ds = np.asarray(response_body["face"]["confidence"])
if self.op_mode == Mode.BURST:
self.state = np.asarray(response_body["state"]["data"], dtype=np.float32)
idxs = np.asarray(idxs)
return sig_ds, conf_ds, live_ds, idxs
def postprocess(
self,
Expand Down Expand Up @@ -293,3 +312,4 @@ def reset(self):
"""Reset"""
if self.op_mode == Mode.BURST:
self.state = None
self.input_buffer = None

0 comments on commit 15e7130

Please sign in to comment.