From 10c3aa6ed70b56020dbf9e75dd7e6f23add9cedc Mon Sep 17 00:00:00 2001 From: ArchieMeng Date: Tue, 4 May 2021 21:55:07 +0800 Subject: [PATCH] Add support for interpolating more than one frame Signed-off-by: ArchieMeng --- src/rife_ncnn_vulkan.py | 40 +++++++++++++++++++++++++++++++++++----- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/src/rife_ncnn_vulkan.py b/src/rife_ncnn_vulkan.py index b4c78e9..983f0c9 100644 --- a/src/rife_ncnn_vulkan.py +++ b/src/rife_ncnn_vulkan.py @@ -1,5 +1,5 @@ +from __future__ import annotations import sys -from math import floor from pathlib import Path from PIL import Image @@ -13,9 +13,21 @@ class RIFE: - def __init__(self, gpuid: int = -1, model: str = "rife-HD", tta_mode: bool = False, uhd_mode: bool = False, num_threads: int = 1): + def __init__(self, + gpuid: int = -1, + model: str = "rife-HD", + scale: int = 2, + tta_mode: bool = False, + uhd_mode: bool = False, + num_threads: int = 1): rife_v2 = "rife-v2" in model self.model = model + + if (scale & (scale -1)) == 0: + self.scale = scale + else: + raise ValueError("scale should be powers of 2") + self._raw_rife = raw.RIFEWrapper(gpuid, tta_mode, uhd_mode, num_threads, rife_v2) self.load() @@ -42,7 +54,24 @@ def load(self, model_dir: str = ""): else: raise FileNotFoundError(f"{model_dir} not found") - def process(self, im0: Image, im1: Image) -> Image: + def process(self, im0: Image, im1: Image) -> list[Image]: + """ + interpolate frames between im0 and im1 + :param im0: First frame + :param im1: Second frame + :return: the interpolation frames between im0 and im1 + """ + def _proc(im0: Image, im1: Image, level) -> list[Image]: + if level == 1: + return [] + else: + im = self._process(im0, im1) + level /= 2 + return _proc(im0, im, level) + [im] + _proc(im, im1, level) + + return _proc(im0, im1, self.scale) + + def _process(self, im0: Image, im1: Image) -> Image: in_bytes0, in_bytes1 = bytearray(im0.tobytes()), bytearray(im1.tobytes()) channels = int(len(in_bytes0) / (im0.width * im0.height)) out_bytes = bytearray(len(in_bytes0)) @@ -62,6 +91,7 @@ def process(self, im0: Image, im1: Image) -> Image: t = time() im0, im1 = Image.open("../images/0.png"), Image.open("../images/1.png") rife = RIFE(0) - im = rife.process(im0, im1) - im.save("../images/out_wrapper.png") + ims = rife.process(im0, im1) + for i, im in enumerate(ims): + im.save(f"../images/out_{i}.png") print(f"Elapsed time: {time() - t}s")