Skip to content

Commit

Permalink
fix(video): fix video playback with force_key_frame (#726)
Browse files Browse the repository at this point in the history
* force_key_frame

* fix embedded video playback

* disable fix_moov

* black; flake8

* fix typo

* yuv420p -> yuv444p (lossless)
  • Loading branch information
abrichr authored Jun 8, 2024
1 parent 0adcc72 commit 1f67822
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 27 deletions.
1 change: 0 additions & 1 deletion openadapt/config.defaults.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
"RECORD_FULL_VIDEO": false,
"RECORD_IMAGES": false,
"LOG_MEMORY": false,
"VIDEO_PIXEL_FORMAT": "rgb24",
"STOP_SEQUENCES": [
[
"o",
Expand Down
3 changes: 2 additions & 1 deletion openadapt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ class SegmentationAdapter(str, Enum):
# useful for debugging but expensive computationally
LOG_MEMORY: bool
REPLAY_STRIP_ELEMENT_STATE: bool = True
VIDEO_PIXEL_FORMAT: str = "rgb24"
VIDEO_ENCODING: str = "libx264"
VIDEO_PIXEL_FORMAT: str = "yuv444p"
VIDEO_DIR_PATH: str = str(VIDEO_DIR_PATH)
# sequences that when typed, will stop the recording of ActionEvents in record.py
STOP_SEQUENCES: list[list[str]] = [
Expand Down
46 changes: 28 additions & 18 deletions openadapt/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,7 @@ def video_pre_callback(db: crud.SaSession, recording: Recording) -> dict[str, An
"video_stream": video_stream,
"video_start_timestamp": video_start_timestamp,
"last_pts": 0,
"video_file_path": video_file_path,
}


Expand All @@ -423,6 +424,11 @@ def video_post_callback(state: dict) -> None:
video.finalize_video_writer(
state["video_container"],
state["video_stream"],
state["video_start_timestamp"],
state["last_frame"],
state["last_frame_timestamp"],
state["last_pts"],
state["video_file_path"],
)


Expand All @@ -435,7 +441,7 @@ def write_video_event(
video_stream: av.stream.Stream,
video_start_timestamp: float,
last_pts: int = 0,
num_copies: int = 2,
**kwargs: dict,
) -> dict[str, Any]:
"""Write a screen event to the video file and update the performance queue.
Expand All @@ -450,29 +456,33 @@ def write_video_event(
video_start_timestamp (float): The base timestamp from which the video
recording started.
last_pts: The last presentation timestamp.
num_copies: The number of times to write the first each frame.
Returns:
dict containing state.
"""
if last_pts != 0:
num_copies = 1
# ensure that the first frame is available (otherwise occasionally it is not)
for _ in range(num_copies):
last_pts = video.write_video_frame(
video_container,
video_stream,
event.data,
event.timestamp,
video_start_timestamp,
last_pts,
)
screenshot_image = event.data
screenshot_timestamp = event.timestamp
force_key_frame = last_pts == 0
last_pts = video.write_video_frame(
video_container,
video_stream,
screenshot_image,
screenshot_timestamp,
video_start_timestamp,
last_pts,
force_key_frame,
)
perf_q.put((f"{event.type}(video)", event.timestamp, utils.get_timestamp()))
return {
"video_container": video_container,
"video_stream": video_stream,
"video_start_timestamp": video_start_timestamp,
"last_pts": last_pts,
**kwargs,
**{
"video_container": video_container,
"video_stream": video_stream,
"video_start_timestamp": video_start_timestamp,
"last_frame": screenshot_image,
"last_frame_timestamp": screenshot_timestamp,
"last_pts": last_pts,
},
}


Expand Down
97 changes: 90 additions & 7 deletions openadapt/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from fractions import Fraction
from pprint import pformat
import os
import subprocess
import tempfile
import threading

from loguru import logger
Expand Down Expand Up @@ -47,7 +49,7 @@ def initialize_video_writer(
width: int,
height: int,
fps: int = 24,
codec: str = "libx264rgb",
codec: str = config.VIDEO_ENCODING,
pix_fmt: str = config.VIDEO_PIXEL_FORMAT,
crf: int = 0,
preset: str = "veryslow",
Expand All @@ -60,8 +62,8 @@ def initialize_video_writer(
height (int): Height of the video.
fps (int, optional): Frames per second of the video. Defaults to 24.
codec (str, optional): Codec used for encoding the video.
Defaults to 'libx264rgb'.
pix_fmt (str, optional): Pixel format of the video. Defaults to 'rgb24'.
Defaults to 'libx264'.
pix_fmt (str, optional): Pixel format of the video. Defaults to 'yuv444p'.
crf (int, optional): Constant Rate Factor for encoding quality.
Defaults to 0 for lossless.
preset (str, optional): Encoding speed/quality trade-off.
Expand Down Expand Up @@ -91,6 +93,7 @@ def write_video_frame(
timestamp: float,
video_start_timestamp: float,
last_pts: int,
force_key_frame: bool = False,
) -> int:
"""Encodes and writes a video frame to the output container from a given screenshot.
Expand All @@ -108,6 +111,7 @@ def write_video_frame(
video_start_timestamp (float): The base timestamp from which the video
recording started.
last_pts (int): The PTS of the last written frame.
force_key_frame (bool): Whether to force this frame to be a key frame.
Returns:
int: The updated last_pts value, to be used for writing the next frame.
Expand All @@ -118,23 +122,28 @@ def write_video_frame(
- The function logs the current timestamp, base timestamp, and
calculated PTS values for debugging purposes.
"""
logger.debug(f"{timestamp=} {video_start_timestamp=}")

# Convert the PIL Image to an AVFrame
av_frame = av.VideoFrame.from_image(screenshot)

# Optionally force a key frame
# TODO: force key frames on active window change?
if force_key_frame:
av_frame.pict_type = "I"

# Calculate the time difference in seconds
time_diff = timestamp - video_start_timestamp

# Calculate PTS, taking into account the fractional average rate
pts = int(time_diff * float(Fraction(video_stream.average_rate)))

logger.debug(f"{time_diff=} {pts=} {video_stream.average_rate=}")
logger.debug(
f"{timestamp=} {video_start_timestamp=} {time_diff=} {pts=} {force_key_frame=}"
)

# Ensure monotonically increasing PTS
if pts <= last_pts:
pts = last_pts + 1
logger.debug("incremented {pts=}")
logger.debug(f"incremented {pts=}")
av_frame.pts = pts
last_pts = pts # Update the last_pts

Expand All @@ -149,16 +158,45 @@ def write_video_frame(
def finalize_video_writer(
video_container: av.container.OutputContainer,
video_stream: av.stream.Stream,
video_start_timestamp: float,
last_frame: Image.Image,
last_frame_timestamp: float,
last_pts: int,
video_file_path: str,
fix_moov: bool = False,
) -> None:
"""Finalizes the video writer, ensuring all buffered frames are encoded and written.
Args:
video_container (av.container.OutputContainer): The AV container to finalize.
video_stream (av.stream.Stream): The AV stream to finalize.
video_start_timestamp (float): The base timestamp from which the video
recording started.
last_frame (Image.Image): The last frame that was written (to be written again).
last_frame_timestamp (float): The timestamp of the last frame that was written.
last_pts (int): The last presentation timestamp.
video_file_path (str): The path to the video file.
fix_moov (bool): Whether to move the moov atom to the beginning of the file.
Setting this to True will fix a bug when displaying the video in Github
comments causing the video to appear to start a few seconds after 0:00.
However, this causes extract_frames to fail.
"""
# Closing the container in the main thread leads to a GIL deadlock.
# https://github.com/PyAV-Org/PyAV/issues/1053

# Write a final key frame
last_pts = write_video_frame(
video_container,
video_stream,
last_frame,
last_frame_timestamp,
video_start_timestamp,
last_pts,
force_key_frame=True,
)

# Closing in the same thread sometimes hangs, so do it in a different thread:

# Define a function to close the container
def close_container() -> None:
logger.info("closing video container...")
Expand All @@ -177,9 +215,54 @@ def close_container() -> None:

# Wait for the thread to finish execution
close_thread.join()

# Move moov atom to beginning of file
if fix_moov:
# TODO: fix this
logger.warning(f"{fix_moov=} will cause extract_frames() to fail!!!")
move_moov_atom(video_file_path)

logger.info("done")


def move_moov_atom(input_file: str, output_file: str = None) -> None:
    """Move the moov atom to the beginning of the video file using ffmpeg.

    If no output file is specified, modifies the input file in place.

    Args:
        input_file (str): The path to the input MP4 file.
        output_file (str, optional): The path to the output MP4 file where the moov
            atom is at the beginning. If None, modifies the input file in place.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    # Only set when rewriting in place. It must be bound on every path because
    # it is checked again after ffmpeg has run.
    temp_file = None
    if output_file is None:
        # Create the temporary file next to the input so the final os.replace()
        # never has to cross a filesystem boundary.
        fd, temp_file = tempfile.mkstemp(
            suffix=".mp4",
            dir=os.path.dirname(os.path.abspath(input_file)),
        )
        # ffmpeg opens the path itself; release our descriptor immediately.
        os.close(fd)
        output_file = temp_file

    command = [
        "ffmpeg",
        "-y",  # Automatically overwrite files without asking
        "-i",
        input_file,
        "-codec",
        "copy",  # Avoid re-encoding; just copy streams
        "-movflags",
        "faststart",  # Move the moov atom to the start
        output_file,
    ]
    logger.info(f"{command=}")
    try:
        subprocess.run(command, check=True)
    except Exception:
        # Do not leave a stray temporary file behind if ffmpeg is missing or fails.
        if temp_file is not None and os.path.exists(temp_file):
            os.remove(temp_file)
        raise

    if temp_file is not None:
        # Replace the original file with the modified one
        os.replace(temp_file, input_file)


def extract_frames(
video_filename: str,
timestamps: list[str],
Expand Down

0 comments on commit 1f67822

Please sign in to comment.