Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add audio narration (updated) #346

Closed
wants to merge 33 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
351d87b
added sounddevice to optionally record narration
angelala3252 May 26, 2023
f19a84a
added sounddevice to optionally record narration and initial whisper …
angelala3252 May 27, 2023
e143767
updated requirements for audio narration
angelala3252 May 29, 2023
6f07b93
small changes
angelala3252 May 31, 2023
d3ef09a
fixed issue with created audio file being really slow
angelala3252 May 31, 2023
9e86193
updated to save audio data and transcribed text in database
angelala3252 May 31, 2023
87a814f
pull from main
angelala3252 May 31, 2023
ce84a1b
new alembic migration
angelala3252 May 31, 2023
5c584b2
edited audio tables
angelala3252 Jun 1, 2023
802c8a2
convert audio array to required format for whisper
angelala3252 Jun 1, 2023
aca8cdc
visualize audio info
angelala3252 Jun 1, 2023
42b1007
FLAC compression before storing
angelala3252 Jun 1, 2023
9f4c280
store word by word timestamps
angelala3252 Jun 1, 2023
20d29e1
style changes
angelala3252 Jun 2, 2023
109ffe0
Merge branch 'main' into feat/audio_narration
angelala3252 Jun 14, 2023
8d27b4f
changed tiktoken version
angelala3252 Jun 16, 2023
d631b2d
removed unused tiktoken code
angelala3252 Jun 16, 2023
ab0805e
Merge branch 'main' into feat/audio_narration
angelala3252 Jun 16, 2023
e30538b
alphabetic order, removed redundant dependencies
angelala3252 Jun 18, 2023
9469043
merged AudioInfo and AudioFile
angelala3252 Jun 18, 2023
47bf845
Merge remote-tracking branch 'audio/feat/audio_narration' into feat/a…
angelala3252 Jun 18, 2023
e9f2d36
move audio recording into record_audio function
angelala3252 Jun 19, 2023
9293b0b
use thread-local scoped_session
angelala3252 Jun 19, 2023
a66acbc
Merge branch 'main' into feat/audio_narration
angelala3252 Jun 23, 2023
888d335
remove redundant requirement
angelala3252 Jun 23, 2023
e1a3a18
pull from main
angelala3252 Aug 31, 2023
d7c54f2
pull from main
angelala3252 Aug 31, 2023
3eaa3a8
remove unused tiktoken function
angelala3252 Aug 31, 2023
05834c4
add audio dependencies
angelala3252 Aug 31, 2023
a6e45bd
style changes
angelala3252 Aug 31, 2023
f23df51
new alembic file
angelala3252 Aug 31, 2023
f6cdbc0
delete old requirements.txt
angelala3252 Aug 31, 2023
873cf6d
added audio dependencies
angelala3252 Aug 31, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions alembic/versions/c176288cb508_add_audio_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Add audio info.

Revision ID: c176288cb508
Revises: 8713b142f5de
Create Date: 2023-08-31 00:25:04.889325

"""
import sqlalchemy as sa

from alembic import op
from openadapt.models import ForceFloat

# revision identifiers, used by Alembic.
revision = "c176288cb508"
down_revision = "8713b142f5de"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"audio_info",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("flac_data", sa.LargeBinary(), nullable=True),
sa.Column("transcribed_text", sa.String(), nullable=True),
sa.Column(
"recording_timestamp",
ForceFloat(precision=10, scale=2, asdecimal=False),
nullable=True,
),
sa.Column("sample_rate", sa.Integer(), nullable=True),
sa.Column("words_with_timestamps", sa.Text(), nullable=True),
sa.ForeignKeyConstraint(
["recording_timestamp"],
["recording.timestamp"],
name=op.f("fk_audio_info_recording_timestamp_recording"),
),
sa.PrimaryKeyConstraint("id", name=op.f("pk_audio_info")),
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("audio_info")
# ### end Alembic commands ###
Empty file added openadapt/crud.py
Empty file.
21 changes: 21 additions & 0 deletions openadapt/db/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

from typing import Any
import json

from loguru import logger
import sqlalchemy as sa
Expand All @@ -12,6 +13,7 @@
from openadapt.db.db import BaseModel, Session
from openadapt.models import (
ActionEvent,
AudioInfo,
MemoryStat,
PerformanceStat,
Recording,
Expand Down Expand Up @@ -411,3 +413,22 @@ def new_session() -> None:
if db:
db.close()
db = Session()


def insert_audio_info(
audio_data: bytes,
transcribed_text: str,
recording_timestamp: float,
sample_rate: int,
word_list: list,
) -> None:
"""Create an AudioInfo entry in the database."""
audio_info = AudioInfo(
flac_data=audio_data,
transcribed_text=transcribed_text,
recording_timestamp=recording_timestamp,
sample_rate=sample_rate,
words_with_timestamps=json.dumps(word_list),
)
db.add(audio_info)
db.commit()
5 changes: 3 additions & 2 deletions openadapt/db/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from dictalchemy import DictableModel
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.schema import MetaData
import sqlalchemy as sa

Expand Down Expand Up @@ -67,4 +67,5 @@ def get_base(engine: sa.engine) -> sa.engine:

engine = get_engine()
Base = get_base(engine)
Session = sessionmaker(bind=engine)
session_factory = sessionmaker(bind=engine)
Session = scoped_session(session_factory)
17 changes: 17 additions & 0 deletions openadapt/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ class Recording(db.Base):
order_by="WindowEvent.timestamp",
)

audio_info = sa.orm.relationship("AudioInfo", back_populates="recording")

_processed_action_events = None

@property
Expand Down Expand Up @@ -378,6 +380,21 @@ def get_active_window_event(cls: "WindowEvent") -> "WindowEvent":
return WindowEvent(**window.get_active_window_data())


class AudioInfo(db.Base):
"""Class representing the audio from a recording in the database."""

__tablename__ = "audio_info"

id = sa.Column(sa.Integer, primary_key=True)
flac_data = sa.Column(sa.LargeBinary)
transcribed_text = sa.Column(sa.String)
recording_timestamp = sa.Column(sa.ForeignKey("recording.timestamp"))
sample_rate = sa.Column(sa.Integer)
words_with_timestamps = sa.Column(sa.Text)

recording = sa.orm.relationship("Recording", back_populates="audio_info")


class PerformanceStat(db.Base):
"""Class representing a performance statistic in the database."""

Expand Down
111 changes: 108 additions & 3 deletions openadapt/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,16 @@

$ python openadapt/record.py "<description of task to be recorded>"

To record audio:

$ python openadapt/record.py "<description of task to be recorded>" --enable_audio

"""

from collections import namedtuple
from functools import partial, wraps
from typing import Any, Callable, Union
import io
import multiprocessing
import os
import queue
Expand All @@ -24,7 +29,11 @@
from tqdm import tqdm
import fire
import mss.tools
import numpy as np
import psutil
import sounddevice
angelala3252 marked this conversation as resolved.
Show resolved Hide resolved
import soundfile
import whisper

from openadapt import config, utils, window
from openadapt.db import crud
Expand Down Expand Up @@ -804,15 +813,101 @@ def read_mouse_events(
mouse_listener.stop()


def record_audio(
terminate_event: multiprocessing.Event,
recording_timestamp: float,
) -> None:
angelala3252 marked this conversation as resolved.
Show resolved Hide resolved
"""Record audio narration during the recording and store data in database.

Args:
terminate_event: The event to signal termination of event reading.
recording_timestamp: The timestamp of the recording.
"""
utils.configure_logging(logger, LOG_LEVEL)
utils.set_start_time(recording_timestamp)

audio_frames = [] # to store audio frames

def audio_callback(
indata: np.ndarray, frames: int, time: Any, status: sounddevice.CallbackFlags
) -> None:
"""Callback function used when new audio frames are recorded.

Note: time is of type cffi.FFI.CData, but since we don't use this argument
and we also don't use the cffi library, the Any type annotation is used.
"""
# called whenever there is new audio frames
audio_frames.append(indata.copy())

# open InputStream and start recording while ActionEvents are recorded
audio_stream = sounddevice.InputStream(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@angelala3252 @0dm what is the easiest way to implement a MacOS-compatible analog of this? Can we re-use existing code in other PRs?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#362 has a good method of getting audio devices via Apple AVFoundation, I'm sure it can be used here with minimal issue. I'm not sure if it'll be plug & play with my PR though, would need some changes depending on the implementation.

callback=audio_callback, samplerate=16000, channels=1
)
logger.info("Audio recording started.")
audio_stream.start()
terminate_event.wait()
audio_stream.stop()
audio_stream.close()

# Concatenate into one Numpy array
concatenated_audio = np.concatenate(audio_frames, axis=0)
# convert concatenated_audio to format expected by whisper
converted_audio = concatenated_audio.flatten().astype(np.float32)

# Convert audio to text using OpenAI's Whisper
logger.info("Transcribing audio...")
model = whisper.load_model("base")
result_info = model.transcribe(converted_audio, word_timestamps=True, fp16=False)
logger.info(f"The narrated text is: {result_info['text']}")
# empty word_list if the user didn't say anything
word_list = []
# segments could be empty
if len(result_info["segments"]) > 0:
# there won't be a 'words' list if the user didn't say anything
if "words" in result_info["segments"][0]:
word_list = result_info["segments"][0]["words"]

# compress and convert to bytes to save to database
logger.info(
"Size of uncompressed audio data: {} bytes".format(converted_audio.nbytes)
)
# Create an in-memory file-like object
file_obj = io.BytesIO()
# Write the audio data using lossless compression
soundfile.write(
file_obj, converted_audio, int(audio_stream.samplerate), format="FLAC"
)
# Get the compressed audio data as bytes
compressed_audio_bytes = file_obj.getvalue()

logger.info(
"Size of compressed audio data: {} bytes".format(len(compressed_audio_bytes))
)

file_obj.close()

# To decompress the audio and restore it to its original form:
# restored_audio, restored_samplerate = sf.read(
# io.BytesIO(compressed_audio_bytes))

# Create AudioInfo entry
crud.insert_audio_info(
compressed_audio_bytes,
result_info["text"],
recording_timestamp,
int(audio_stream.samplerate),
word_list,
)


@logger.catch
@trace(logger)
def record(
task_description: str,
) -> None:
def record(task_description: str, enable_audio: bool = False) -> None:
"""Record Screenshots/ActionEvents/WindowEvents.

Args:
task_description: A text description of the task to be recorded.
enable_audio: a flag to enable or disable audio recording (default: False)
"""
logger.info(f"{task_description=}")

Expand Down Expand Up @@ -943,6 +1038,13 @@ def record(
)
mem_plotter.start()

if enable_audio:
audio_recorder = threading.Thread(
target=record_audio,
args=(terminate_event, recording_timestamp),
)
audio_recorder.start()

# TODO: discard events until everything is ready

collect_stats()
Expand Down Expand Up @@ -972,6 +1074,9 @@ def record(
screen_event_writer.join()
action_event_writer.join()
window_event_writer.join()
if enable_audio:
audio_recorder.join()

terminate_perf_event.set()

if PLOT_PERFORMANCE:
Expand Down
49 changes: 0 additions & 49 deletions openadapt/strategies/mixins/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ class MyReplayStrategy(OpenAIReplayStrategyMixin):

from loguru import logger
import openai
import tiktoken

from openadapt import cache, config, models
from openadapt.strategies.base import BaseReplayStrategy
Expand All @@ -29,7 +28,6 @@ class MyReplayStrategy(OpenAIReplayStrategyMixin):
MODEL_NAME = "gpt-4"

openai.api_key = config.OPENAI_API_KEY
encoding = tiktoken.get_encoding("cl100k_base")


class OpenAIReplayStrategyMixin(BaseReplayStrategy):
Expand Down Expand Up @@ -187,50 +185,3 @@ def _get_completion(prompt: str) -> str:
logger.debug(f"appending assistant_message=\n{pformat(assistant_message)}")
messages.append(assistant_message)
return messages


# XXX TODO not currently in use
# https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
def num_tokens_from_messages(messages: list, model: str = "gpt-3.5-turbo-0301") -> int:
"""Returns the number of tokens used by a list of messages."""
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
logger.info("Warning: model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
if model == "gpt-3.5-turbo":
logger.info(
"Warning: gpt-3.5-turbo may change over time. Returning num tokens "
"assuming gpt-3.5-turbo-0301."
)
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
elif model == "gpt-4":
logger.info(
"Warning: gpt-4 may change over time. Returning num tokens "
"assuming gpt-4-0314."
)
return num_tokens_from_messages(messages, model="gpt-4-0314")
elif model == "gpt-3.5-turbo-0301":
tokens_per_message = (
4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
)
tokens_per_name = -1 # if there's a name, the role is omitted
elif model == "gpt-4-0314":
tokens_per_message = 3
tokens_per_name = 1
else:
raise NotImplementedError(
f"""num_tokens_from_messages() is not implemented for model "
"{model}. See "
"https://github.com/openai/openai-python/blob/main/chatml.md for "
information on how messages are converted to tokens."""
)
num_tokens = 0
for message in messages:
num_tokens += tokens_per_message
for key, value in message.items():
num_tokens += len(encoding.encode(value))
if key == "name":
num_tokens += tokens_per_name
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
return num_tokens
6 changes: 5 additions & 1 deletion openadapt/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import click

from openadapt import config
from openadapt.db.crud import get_latest_recording, get_recording
from openadapt.db.crud import get_audio_info, get_latest_recording, get_recording
from openadapt.events import get_events
from openadapt.utils import (
EMPTY,
Expand Down Expand Up @@ -147,6 +147,10 @@ def main(timestamp: str, notify: bool = True) -> None:
scrub.scrub_text(recording.task_description)
logger.debug(f"{recording=}")

audio_info = row2dict(get_audio_info(recording))
# don't display the FLAC data
del audio_info["flac_data"]

meta = {}
action_events = get_events(recording, process=PROCESS_EVENTS, meta=meta)
event_dicts = rows2dicts(action_events)
Expand Down
Loading