Proj Refactor + New libs (#29)
- I'm especially proud of the new HuggingFace & LangChain code
- I need to create better usage-guidance documentation throughout this library
Daethyra authored Oct 11, 2023
2 parents bf1594f + 3873ac2 commit 6c8122d
Showing 18 changed files with 753 additions and 163 deletions.
Empty file added .github/.gitignore
@@ -0,0 +1,4 @@
pyaudio
numpy
torch
transformers
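# Install with: pip install -r requirements.txt
# (PyAudio additionally needs the PortAudio system library, e.g. the
# portaudio19-dev package on Debian/Ubuntu.)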
45 changes: 45 additions & 0 deletions HuggingFace/audio_transcription/MicrophoneTranscription/run.py
@@ -0,0 +1,45 @@
import argparse
import logging
import sys

from transcribe_microphone import RealTimeASR


def main(args):
# Initialize the RealTimeASR object
asr_app = RealTimeASR(maxlen=args.maxlen)
asr_app.initialize_audio()

    # Set up logging; fall back to stderr-only output if the log file
    # cannot be opened
    try:
        with open(args.log_file, "a"):
            pass
        logging.basicConfig(filename=args.log_file, level=logging.INFO)
except (FileNotFoundError, PermissionError):
print(f"Error opening log file: {args.log_file}", file=sys.stderr, flush=True)
args.log_file = None

    try:
        # Capture and transcribe audio in real-time; Transformers pipelines
        # expose no is_running() method, so only the stream state is checked
        if asr_app.stream.is_active():
            asr_app.capture_and_transcribe(log_file=args.log_file)
        else:
            print("Error: PyAudio stream is not active.", file=sys.stderr, flush=True)
except KeyboardInterrupt:
logging.info("Transcription stopped by user.")
print("Stopping transcription.")
except Exception as e:
logging.exception("Error during transcription.")
print(f"Error during transcription: {e}", file=sys.stderr, flush=True)
finally:
# Close the PyAudio stream and write final transcriptions to the log file
asr_app.close_stream(log_file=args.log_file)


if __name__ == "__main__":
# Parse command-line arguments
parser = argparse.ArgumentParser(description="Real-time ASR using the Transformers library.")
parser.add_argument("--maxlen", type=int, default=300, help="Maximum number of transcriptions to store in the cache.")
parser.add_argument("--log-file", type=str, default="transcription_log.txt", help="Path to the log file to write transcriptions to.")
args = parser.parse_args()

main(args)
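# Example invocation (the log path is illustrative; both flags are optional):
#   python run.py --maxlen 500 --log-file transcription_log.txt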
204 changes: 204 additions & 0 deletions HuggingFace/audio_transcription/MicrophoneTranscription/transcribe_microphone.py
@@ -0,0 +1,204 @@
"""
Long-Form Transcription
The Whisper model is intrinsically designed to work on audio samples of up to
30 seconds in duration. By using a chunking algorithm, however, it can
transcribe audio of arbitrary length via the Transformers pipeline() method.
Chunking is enabled by setting chunk_length_s=30 when instantiating the
pipeline. With chunking enabled, the pipeline can be run with batched
inference, and it can be extended to predict sequence-level timestamps by
passing return_timestamps=True.
"""

import pyaudio
import numpy as np
import torch
from transformers import pipeline
from collections import deque
import sys
import os


class RealTimeASR:
"""
This class demonstrates how to perform real-time ASR using the pipeline method of the Transformers library.
"""
    def __init__(self, maxlen=300):
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model="openai/whisper-large-v2",
            chunk_length_s=30,
            device=self.device,
            return_timestamps=True
        )
        self.transcription_cache = deque(maxlen=maxlen)
        self.sliding_window = np.array([], dtype=np.int16)
        # Seconds to shift the sliding window after transcribing a full 30-second
        # chunk (the 5-second figure comes from capture_and_transcribe's docstring)
        self.shift_s = 5

def initialize_audio(self):
self.p = pyaudio.PyAudio()
self.stream = self.p.open(format=pyaudio.paInt16,
channels=1,
rate=16000,
input=True,
frames_per_buffer=1024)

def capture_and_transcribe(self, log_file=None):
"""
Continuously captures audio from the microphone, concatenates it to a sliding window, and transcribes the audio
using the ASR pipeline. If the sliding window is longer than 30 seconds, the pipeline is run on the first 30 seconds
of audio and the sliding window is shifted by 5 seconds. If there is a transcription in the cache, it is printed to
stdout and written to the log file.
Args:
log_file (str): The path to the log file to write transcriptions to. If None, transcriptions will not be written
to a log file.
Returns:
None
"""
# Check if the log file path is valid before writing to it
if log_file is not None:
try:
with open(log_file, "a") as f:
pass
except (FileNotFoundError, PermissionError):
print(f"Error opening log file: {log_file}", file=sys.stderr, flush=True)
log_file = None

        while True:
            # Check if the PyAudio stream is active before attempting to read from it
            if self.stream.is_active():
                # Capture audio from the microphone; don't raise on input overflow,
                # since a slow transcription pass can outlast the 1024-frame buffer
                audio_data = np.frombuffer(
                    self.stream.read(1024, exception_on_overflow=False),
                    dtype=np.int16)

                # Concatenate the audio data to the sliding window
                self.sliding_window = np.concatenate((self.sliding_window, audio_data))

                # Wait until at least shift_s seconds of audio have accumulated
                # before attempting to transcribe
                if len(self.sliding_window) >= 16000 * self.shift_s:
                    if len(self.sliding_window) < 16000 * 30:
                        # Window is under 30 seconds: transcribe all of it and reset
                        chunk = self.sliding_window
                        self.sliding_window = np.array([], dtype=np.int16)
                    else:
                        # Window is 30 seconds or more: transcribe the first 30 seconds
                        # and shift the window forward by shift_s seconds
                        chunk = self.sliding_window[:16000 * 30]
                        self.sliding_window = self.sliding_window[16000 * self.shift_s:]

                    # Whisper expects float audio; normalize int16 samples to [-1, 1]
                    transcription = self.asr_pipeline(chunk.astype(np.float32) / 32768.0)

                    # With return_timestamps=True the pipeline returns "text" plus
                    # per-segment "chunks"; keep only well-formed results
                    if "text" in transcription and "chunks" in transcription:
                        # deque(maxlen=...) discards its oldest entry automatically,
                        # so no explicit popleft() is needed before appending
                        self.transcription_cache.append(transcription)
                    else:
                        print("Error transcribing audio.", file=sys.stderr, flush=True)

                # If there is a transcription in the cache, print it to stdout and
                # write it to the log file
                if len(self.transcription_cache) > 0:
                    transcription = self.transcription_cache.popleft()
                    print(transcription["text"], file=sys.stdout, flush=True)
                    if log_file is not None:
                        # Rotate the log file if it has grown past ~1 MB
                        if os.path.isfile(log_file) and os.path.getsize(log_file) > 1000000:
                            log_file = create_new_log_file(log_file)
                        try:
                            with open(log_file, "a") as f:
                                f.write(transcription["text"] + "\n")
                        except (FileNotFoundError, PermissionError):
                            print(f"Error writing to log file: {log_file}", file=sys.stderr, flush=True)
            else:
                # The stream is no longer active: close the stream, terminate
                # PyAudio, flush any remaining transcriptions, and stop the loop
                self.stream.close()
                self.p.terminate()
                if log_file is not None:
                    try:
                        with open(log_file, "a") as f:
                            for transcription in self.transcription_cache:
                                f.write(transcription["text"] + "\n")
                    except (FileNotFoundError, PermissionError):
                        print(f"Error writing to log file: {log_file}", file=sys.stderr, flush=True)
                break


    def close_stream(self, log_file=None):
        self.stream.stop_stream()
        self.stream.close()
        self.p.terminate()

        # Write the final transcriptions to the log file; cached entries are
        # pipeline result dicts, so write their "text" field
        if log_file is not None:
            try:
                with open(log_file, "a") as f:
                    for transcription in self.transcription_cache:
                        f.write(transcription["text"] + "\n")
            except (FileNotFoundError, PermissionError):
                print(f"Error writing to log file: {log_file}", file=sys.stderr, flush=True)


def create_new_log_file(log_file):
"""
Creates a new log file with a different name if the original log file is too large.
Args:
log_file (str): The path to the original log file.
Returns:
str: The path to the new log file.
"""
log_dir = os.path.dirname(log_file)
log_name, log_ext = os.path.splitext(os.path.basename(log_file))
i = 1
while True:
new_log_name = f"{log_name}_{i}{log_ext}"
new_log_file = os.path.join(log_dir, new_log_name)
if not os.path.isfile(new_log_file):
return new_log_file
i += 1
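# Illustration of the rotation scheme above (paths are hypothetical):
#   create_new_log_file("logs/transcription_log.txt")
#   -> "logs/transcription_log_1.txt", or _2, _3, ... if earlier
#      rotations already exist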
96 changes: 96 additions & 0 deletions HuggingFace/audio_transcription/test/test_at1.py
@@ -0,0 +1,96 @@
import unittest

import numpy as np
from io import StringIO
from contextlib import redirect_stdout
from transformers import Pipeline

from ..MicrophoneTranscription.transcribe_microphone import RealTimeASR
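# These tests can be run from the repository root, assuming each directory on
# the import path is a package (i.e. contains an __init__.py):
#   python -m unittest HuggingFace.audio_transcription.test.test_at1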

class TestRealTimeASR(unittest.TestCase):
"""
This class contains unit tests for the RealTimeASR class.
"""
def setUp(self):
"""
This method sets up the test environment before each test case is run.
"""
self.asr_app = RealTimeASR()
self.asr_app.initialize_audio()

def test_sliding_window(self):
"""
This method tests that the sliding window is correctly updated with new audio data.
"""
audio_data = np.ones(16000, dtype=np.int16)
self.asr_app.sliding_window = np.array([])
self.asr_app.sliding_window = np.concatenate((self.asr_app.sliding_window, audio_data))
self.assertEqual(len(self.asr_app.sliding_window), 16000)

def test_transcription_cache(self):
"""
This method tests that the transcription cache is correctly updated with new transcriptions.
"""
transcription = {"text": "hello world"}
self.asr_app.transcription_cache.append(transcription["text"])
self.assertEqual(len(self.asr_app.transcription_cache), 1)
self.assertEqual(self.asr_app.transcription_cache[0], "hello world")

def test_capture_and_transcribe(self):
"""
This method tests that the capture_and_transcribe method correctly transcribes audio.
"""
audio_data = np.ones(16000 * 30, dtype=np.int16)
self.asr_app.sliding_window = np.array([])
self.asr_app.sliding_window = np.concatenate((self.asr_app.sliding_window, audio_data))
with redirect_stdout(StringIO()):
self.asr_app.capture_and_transcribe()
self.assertEqual(len(self.asr_app.transcription_cache), 1)
self.assertTrue(isinstance(self.asr_app.transcription_cache[0], str))

    def test_close_stream(self):
        """
        This method tests that the stream is closed correctly.
        """
        # PyAudio streams expose is_stopped() but no is_closed() method, so it
        # is enough to check that close_stream() completes without raising
        self.asr_app.close_stream()

def test_device(self):
"""
This method tests that the device is correctly set.
"""
self.assertTrue(self.asr_app.device in ["cuda:0", "cpu"])

    def test_asr_pipeline(self):
        """
        This method tests that the ASR pipeline is correctly set.
        """
        # The pipeline attribute is a Transformers Pipeline, not a RealTimeASR
        self.assertTrue(isinstance(self.asr_app.asr_pipeline, Pipeline))

def test_sliding_window_shift(self):
"""
This method tests that the sliding window is correctly shifted.
"""
audio_data = np.ones(16000 * 30, dtype=np.int16)
self.asr_app.sliding_window = np.array([])
self.asr_app.sliding_window = np.concatenate((self.asr_app.sliding_window, audio_data))
with redirect_stdout(StringIO()):
self.asr_app.capture_and_transcribe()
self.assertEqual(len(self.asr_app.transcription_cache), 1)
self.assertTrue(isinstance(self.asr_app.transcription_cache[0], str))
self.assertEqual(len(self.asr_app.sliding_window), 16000 * 5)

    def tearDown(self):
        """
        This method tears down the test environment after each test case is run.
        """
        # Guard against double-closing: test_close_stream already closes the stream
        try:
            self.asr_app.close_stream()
        except OSError:
            pass


if __name__ == '__main__':
unittest.main()