-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- i'm especially proud of the new huggingface & langchain code - i need to create better usage-guidance documentation throughout this library
- Loading branch information
Showing
18 changed files
with
753 additions
and
163 deletions.
There are no files selected for viewing
File renamed without changes.
File renamed without changes.
Empty file.
Empty file.
4 changes: 4 additions & 0 deletions
4
HuggingFace/audio_transcription/MicrophoneTranscription/requirements.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
pyaudio | ||
numpy | ||
torch | ||
transformers |
45 changes: 45 additions & 0 deletions
45
HuggingFace/audio_transcription/MicrophoneTranscription/run.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import argparse | ||
import logging | ||
import sys | ||
|
||
from transcribe_microphone import RealTimeASR | ||
|
||
|
||
def main(args): | ||
# Initialize the RealTimeASR object | ||
asr_app = RealTimeASR(maxlen=args.maxlen) | ||
asr_app.initialize_audio() | ||
|
||
# Set up logging | ||
try: | ||
with open(args.log_file, "a") as f: | ||
logging.basicConfig(filename=args.log_file, level=logging.INFO) | ||
except (FileNotFoundError, PermissionError): | ||
print(f"Error opening log file: {args.log_file}", file=sys.stderr, flush=True) | ||
args.log_file = None | ||
|
||
try: | ||
# Capture and transcribe audio in real-time | ||
if asr_app.stream.is_active() and asr_app.asr_pipeline.is_running(): | ||
asr_app.capture_and_transcribe(log_file=args.log_file) | ||
else: | ||
print("Error: PyAudio stream or ASR pipeline is not active.", file=sys.stderr, flush=True) | ||
except KeyboardInterrupt: | ||
logging.info("Transcription stopped by user.") | ||
print("Stopping transcription.") | ||
except Exception as e: | ||
logging.exception("Error during transcription.") | ||
print(f"Error during transcription: {e}", file=sys.stderr, flush=True) | ||
finally: | ||
# Close the PyAudio stream and write final transcriptions to the log file | ||
asr_app.close_stream(log_file=args.log_file) | ||
|
||
|
||
if __name__ == "__main__": | ||
# Parse command-line arguments | ||
parser = argparse.ArgumentParser(description="Real-time ASR using the Transformers library.") | ||
parser.add_argument("--maxlen", type=int, default=300, help="Maximum number of transcriptions to store in the cache.") | ||
parser.add_argument("--log-file", type=str, default="transcription_log.txt", help="Path to the log file to write transcriptions to.") | ||
args = parser.parse_args() | ||
|
||
main(args) |
204 changes: 204 additions & 0 deletions
204
HuggingFace/audio_transcription/MicrophoneTranscription/transcribe_microphone.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,204 @@ | ||
""" | ||
Long-Form Transcription | ||
The Whisper model is intrinsically designed to work on audio samples of up to 30s in duration. However, by using a chunking algorithm, it can be used to transcribe audio samples of up to arbitrary length. This is possible through Transformers pipeline method. Chunking is enabled by setting chunk_length_s=30 when instantiating the pipeline. With chunking enabled, the pipeline can be run with batched inference. It can also be extended to predict sequence level timestamps by passing return_timestamps=True: | ||
""" | ||
|
||
import pyaudio | ||
import numpy as np | ||
import torch | ||
from transformers import pipeline | ||
from collections import deque | ||
import sys | ||
import os | ||
|
||
|
||
class RealTimeASR: | ||
""" | ||
This class demonstrates how to perform real-time ASR using the pipeline method of the Transformers library. | ||
""" | ||
def __init__(self, maxlen=300): | ||
self.device = "cuda:0" if torch.cuda.is_available() else "cpu" | ||
self.asr_pipeline = pipeline( | ||
"automatic-speech-recognition", | ||
model="openai/whisper-large-v2", | ||
chunk_length_s=30, | ||
device=self.device, | ||
return_timestamps=True | ||
) | ||
self.transcription_cache = deque(maxlen=maxlen) | ||
self.sliding_window = np.array([]) | ||
|
||
def initialize_audio(self): | ||
self.p = pyaudio.PyAudio() | ||
self.stream = self.p.open(format=pyaudio.paInt16, | ||
channels=1, | ||
rate=16000, | ||
input=True, | ||
frames_per_buffer=1024) | ||
|
||
def capture_and_transcribe(self, log_file=None): | ||
""" | ||
Continuously captures audio from the microphone, concatenates it to a sliding window, and transcribes the audio | ||
using the ASR pipeline. If the sliding window is longer than 30 seconds, the pipeline is run on the first 30 seconds | ||
of audio and the sliding window is shifted by 5 seconds. If there is a transcription in the cache, it is printed to | ||
stdout and written to the log file. | ||
Args: | ||
log_file (str): The path to the log file to write transcriptions to. If None, transcriptions will not be written | ||
to a log file. | ||
Returns: | ||
None | ||
""" | ||
# Check if the log file path is valid before writing to it | ||
if log_file is not None: | ||
try: | ||
with open(log_file, "a") as f: | ||
pass | ||
except (FileNotFoundError, PermissionError): | ||
print(f"Error opening log file: {log_file}", file=sys.stderr, flush=True) | ||
log_file = None | ||
|
||
while True: | ||
# Check if the PyAudio stream is active before attempting to read from it | ||
if self.stream.is_active(): | ||
# Capture audio from the microphone | ||
audio_data = np.frombuffer(self.stream.read(1024), dtype=np.int16) | ||
|
||
# Concatenate the audio data to the sliding window | ||
self.sliding_window = np.concatenate((self.sliding_window, audio_data)) | ||
|
||
# Check if the sliding window is shorter than the ASR pipeline chunk length before attempting to transcribe it | ||
if len(self.sliding_window) >= 16000 * self.asr_pipeline.task.config.chunk_size_ms / 1000: | ||
# Check if the sliding window is shorter than 30 seconds before attempting to transcribe it | ||
if len(self.sliding_window) < 16000 * 30: | ||
# Transcribe the sliding window and shift it by the shift length | ||
transcription = self.asr_pipeline(self.sliding_window) | ||
# Check if the ASR pipeline returns a transcription before appending it to the cache | ||
if "text" in transcription: | ||
# Check if the ASR pipeline returns timestamps before appending them to the cache | ||
if "timestamps" in transcription: | ||
# Check if the transcription cache is full before attempting to append a new transcription | ||
if len(self.transcription_cache) == self.transcription_cache.maxlen: | ||
self.transcription_cache.popleft() | ||
self.transcription_cache.append(transcription) | ||
self.sliding_window = np.array([]) | ||
else: | ||
print("Error: ASR pipeline does not return timestamps.", file=sys.stderr, flush=True) | ||
else: | ||
print("Error transcribing audio.", file=sys.stderr, flush=True) | ||
else: | ||
# Transcribe the first 30 seconds of audio and shift the sliding window by the shift length | ||
transcription = self.asr_pipeline(self.sliding_window[:16000 * 30]) | ||
# Check if the ASR pipeline returns a transcription before appending it to the cache | ||
if "text" in transcription: | ||
# Check if the ASR pipeline returns timestamps before appending them to the cache | ||
if "timestamps" in transcription: | ||
# Check if the transcription cache is full before attempting to append a new transcription | ||
if len(self.transcription_cache) == self.transcription_cache.maxlen: | ||
self.transcription_cache.popleft() | ||
self.transcription_cache.append(transcription) | ||
self.sliding_window = self.sliding_window[16000 * self.asr_pipeline.task.config.shift_ms / 1000:] | ||
else: | ||
print("Error: ASR pipeline does not return timestamps.", file=sys.stderr, flush=True) | ||
else: | ||
print("Error transcribing audio.", file=sys.stderr, flush=True) | ||
|
||
# If there is a transcription in the cache, print it to stdout and write it to the log file | ||
if len(self.transcription_cache) > 0: | ||
transcription = self.transcription_cache.popleft() | ||
# Check if the transcription cache is empty before attempting to pop a transcription | ||
if transcription is not None: | ||
# Check if the ASR pipeline returns timestamps before appending them to the cache | ||
if "timestamps" in transcription: | ||
print(transcription["text"], file=sys.stdout, flush=True) | ||
if log_file is not None: | ||
# Check if the log file is a file before writing to it | ||
if not os.path.isfile(log_file): | ||
print(f"Error writing to log file: {log_file}", file=sys.stderr, flush=True) | ||
else: | ||
# Check if the log file directory exists before writing to it | ||
log_dir = os.path.dirname(log_file) | ||
if not os.path.isdir(log_dir): | ||
print(f"Error writing to log file: {log_file}", file=sys.stderr, flush=True) | ||
else: | ||
# Check if the user has permission to write to the log file before writing to it | ||
if not os.access(log_file, os.W_OK): | ||
print(f"Error writing to log file: {log_file}", file=sys.stderr, flush=True) | ||
else: | ||
# Check if the log file is too large before writing to it | ||
if os.path.isfile(log_file) and os.path.getsize(log_file) > 1000000: | ||
log_file = create_new_log_file(log_file) | ||
try: | ||
with open(log_file, "a") as f: | ||
f.write(transcription["text"] + "\n") | ||
except (FileNotFoundError, PermissionError): | ||
print(f"Error writing to log file: {log_file}", file=sys.stderr, flush=True) | ||
else: | ||
print("Error: ASR pipeline does not return timestamps.", file=sys.stderr, flush=True) | ||
else: | ||
# Check if the PyAudio stream is stopped before closing it | ||
if self.stream.is_stopped(): | ||
self.stream.close() | ||
# Check if the PyAudio library is terminated before closing the stream | ||
if self.p.is_terminated(): | ||
self.p.terminate() | ||
# Write the final transcriptions to the log file | ||
if log_file is not None: | ||
# Check if the log file is writable before writing to it | ||
if not os.access(log_file, os.W_OK): | ||
print(f"Error writing to log file: {log_file}", file=sys.stderr, flush=True) | ||
else: | ||
try: | ||
with open(log_file, "a") as f: | ||
for transcription in self.transcription_cache: | ||
if transcription is not None: | ||
# Check if the ASR pipeline returns timestamps before appending them to the cache | ||
if "timestamps" in transcription: | ||
# Check if the transcription cache is full before attempting to append a new transcription | ||
if len(self.transcription_cache) == self.transcription_cache.maxlen: | ||
self.transcription_cache.popleft() | ||
self.transcription_cache.append(transcription) | ||
f.write(transcription["text"] + "\n") | ||
else: | ||
print("Error: ASR pipeline does not return timestamps.", file=sys.stderr, flush=True) | ||
except (FileNotFoundError, PermissionError): | ||
print(f"Error writing to log file: {log_file}", file=sys.stderr, flush=True) | ||
else: | ||
print("Error terminating PyAudio library.", file=sys.stderr, flush=True) | ||
else: | ||
print("Error stopping PyAudio stream.", file=sys.stderr, flush=True) | ||
break | ||
|
||
|
||
def close_stream(self, log_file=None): | ||
self.stream.stop_stream() | ||
self.stream.close() | ||
self.p.terminate() | ||
|
||
# Write the final transcription to the log file | ||
if log_file is not None: | ||
with open(log_file, "a") as f: | ||
for transcription in self.transcription_cache: | ||
f.write(transcription + "\n") | ||
|
||
|
||
def create_new_log_file(log_file): | ||
""" | ||
Creates a new log file with a different name if the original log file is too large. | ||
Args: | ||
log_file (str): The path to the original log file. | ||
Returns: | ||
str: The path to the new log file. | ||
""" | ||
log_dir = os.path.dirname(log_file) | ||
log_name, log_ext = os.path.splitext(os.path.basename(log_file)) | ||
i = 1 | ||
while True: | ||
new_log_name = f"{log_name}_{i}{log_ext}" | ||
new_log_file = os.path.join(log_dir, new_log_name) | ||
if not os.path.isfile(new_log_file): | ||
return new_log_file | ||
i += 1 |
Binary file not shown.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
import unittest | ||
import numpy as np | ||
from io import StringIO | ||
from contextlib import redirect_stdout | ||
from ..MicrophoneTranscription.transcribe_microphone import RealTimeASR | ||
|
||
|
||
import unittest | ||
import numpy as np | ||
from io import StringIO | ||
from contextlib import redirect_stdout | ||
from RealTimeASR import RealTimeASR | ||
|
||
class TestRealTimeASR(unittest.TestCase): | ||
""" | ||
This class contains unit tests for the RealTimeASR class. | ||
""" | ||
def setUp(self): | ||
""" | ||
This method sets up the test environment before each test case is run. | ||
""" | ||
self.asr_app = RealTimeASR() | ||
self.asr_app.initialize_audio() | ||
|
||
def test_sliding_window(self): | ||
""" | ||
This method tests that the sliding window is correctly updated with new audio data. | ||
""" | ||
audio_data = np.ones(16000, dtype=np.int16) | ||
self.asr_app.sliding_window = np.array([]) | ||
self.asr_app.sliding_window = np.concatenate((self.asr_app.sliding_window, audio_data)) | ||
self.assertEqual(len(self.asr_app.sliding_window), 16000) | ||
|
||
def test_transcription_cache(self): | ||
""" | ||
This method tests that the transcription cache is correctly updated with new transcriptions. | ||
""" | ||
transcription = {"text": "hello world"} | ||
self.asr_app.transcription_cache.append(transcription["text"]) | ||
self.assertEqual(len(self.asr_app.transcription_cache), 1) | ||
self.assertEqual(self.asr_app.transcription_cache[0], "hello world") | ||
|
||
def test_capture_and_transcribe(self): | ||
""" | ||
This method tests that the capture_and_transcribe method correctly transcribes audio. | ||
""" | ||
audio_data = np.ones(16000 * 30, dtype=np.int16) | ||
self.asr_app.sliding_window = np.array([]) | ||
self.asr_app.sliding_window = np.concatenate((self.asr_app.sliding_window, audio_data)) | ||
with redirect_stdout(StringIO()): | ||
self.asr_app.capture_and_transcribe() | ||
self.assertEqual(len(self.asr_app.transcription_cache), 1) | ||
self.assertTrue(isinstance(self.asr_app.transcription_cache[0], str)) | ||
|
||
def test_close_stream(self): | ||
""" | ||
This method tests that the stream is closed correctly. | ||
""" | ||
self.asr_app.close_stream() | ||
self.assertTrue(self.asr_app.stream.is_stopped()) | ||
self.assertTrue(self.asr_app.stream.is_closed()) | ||
|
||
def test_device(self): | ||
""" | ||
This method tests that the device is correctly set. | ||
""" | ||
self.assertTrue(self.asr_app.device in ["cuda:0", "cpu"]) | ||
|
||
def test_asr_pipeline(self): | ||
""" | ||
This method tests that the ASR pipeline is correctly set. | ||
""" | ||
self.assertTrue(isinstance(self.asr_app.asr_pipeline, RealTimeASR)) | ||
|
||
def test_sliding_window_shift(self): | ||
""" | ||
This method tests that the sliding window is correctly shifted. | ||
""" | ||
audio_data = np.ones(16000 * 30, dtype=np.int16) | ||
self.asr_app.sliding_window = np.array([]) | ||
self.asr_app.sliding_window = np.concatenate((self.asr_app.sliding_window, audio_data)) | ||
with redirect_stdout(StringIO()): | ||
self.asr_app.capture_and_transcribe() | ||
self.assertEqual(len(self.asr_app.transcription_cache), 1) | ||
self.assertTrue(isinstance(self.asr_app.transcription_cache[0], str)) | ||
self.assertEqual(len(self.asr_app.sliding_window), 16000 * 5) | ||
|
||
def tearDown(self): | ||
""" | ||
This method tears down the test environment after each test case is run. | ||
""" | ||
self.asr_app.close_stream() | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
Oops, something went wrong.