Skip to content

Commit

Permalink
Add support for interpolating more than one frame
Browse files Browse the repository at this point in the history
Signed-off-by: ArchieMeng <[email protected]>
  • Loading branch information
ArchieMeng committed May 4, 2021
1 parent c8734e0 commit 10c3aa6
Showing 1 changed file with 35 additions and 5 deletions.
40 changes: 35 additions & 5 deletions src/rife_ncnn_vulkan.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations
import sys
from math import floor
from pathlib import Path

from PIL import Image
Expand All @@ -13,9 +13,21 @@


class RIFE:
def __init__(self, gpuid: int = -1, model: str = "rife-HD", tta_mode: bool = False, uhd_mode: bool = False, num_threads: int = 1):
def __init__(self,
gpuid: int = -1,
model: str = "rife-HD",
scale: int = 2,
tta_mode: bool = False,
uhd_mode: bool = False,
num_threads: int = 1):
rife_v2 = "rife-v2" in model
self.model = model

if (scale & (scale -1)) == 0:
self.scale = scale
else:
raise ValueError("scale should be powers of 2")

self._raw_rife = raw.RIFEWrapper(gpuid, tta_mode, uhd_mode, num_threads, rife_v2)
self.load()

Expand All @@ -42,7 +54,24 @@ def load(self, model_dir: str = ""):
else:
raise FileNotFoundError(f"{model_dir} not found")

def process(self, im0: Image, im1: Image) -> Image:
def process(self, im0: Image, im1: Image) -> list[Image]:
"""
interpolate frames between im0 and im1
:param im0: First frame
:param im1: Second frame
:return: the interpolation frames between im0 and im1
"""
def _proc(im0: Image, im1: Image, level) -> list[Image]:
if level == 1:
return []
else:
im = self._process(im0, im1)
level /= 2
return _proc(im0, im, level) + [im] + _proc(im, im1, level)

return _proc(im0, im1, self.scale)

def _process(self, im0: Image, im1: Image) -> Image:
in_bytes0, in_bytes1 = bytearray(im0.tobytes()), bytearray(im1.tobytes())
channels = int(len(in_bytes0) / (im0.width * im0.height))
out_bytes = bytearray(len(in_bytes0))
Expand All @@ -62,6 +91,7 @@ def process(self, im0: Image, im1: Image) -> Image:
t = time()
im0, im1 = Image.open("../images/0.png"), Image.open("../images/1.png")
rife = RIFE(0)
im = rife.process(im0, im1)
im.save("../images/out_wrapper.png")
ims = rife.process(im0, im1)
for i, im in enumerate(ims):
im.save(f"../images/out_{i}.png")
print(f"Elapsed time: {time() - t}s")

0 comments on commit 10c3aa6

Please sign in to comment.