This is a companion guide to Working With wav2vec2 Part 1 - Finetuning XLS-R for Automatic Speech Recognition (the "Part 1 guide") and Working With wav2vec2 Part 2 - Running Inference on Finetuned ASR Models (the "Part 2 guide"). In those guides, I outlined the steps to generate text transcriptions from audio using a finetuned wav2vec2 ASR model. Those readers who opted to train wav2vec2 on the Chilean Spanish dataset that I used for my model likely noticed that the audio samples were (generally) less than 10 seconds long. While the inference notebook in Part 2 can theoretically be used with longer audio inputs, realistically it will "choke" on large audio files (i.e. audio files more than a few seconds in length).
Chunking is a simple technique that we can employ to make wav2vec2 finetuned models work on long audio files. This guide walks through the steps to build a simple Python application that can run inference on long audio files.
You will build this:
It is assumed that you have completed the Part 1 and Part 2 guides and that you have generated your own finetuned wav2vec2 XLS-R model. This guide continues working with the Spanish language, but the Python application that you will build can be used with models finetuned on other languages.
Chunking
Basic logic tells us that it is more computationally expensive to process a longer audio file than a shorter one. In the case of the transformer architecture, the computational complexity of the attention mechanism is quadratic with respect to the length of the sequence fed into the transformer. So, doubling the sequence length roughly quadruples the cost of attention, and long audio inputs quickly become impractical to process in a single pass.
Chunking is a technique that we can use to get around this limitation. Simply put, we can:
- Divide a long sequence of audio into chunks of fixed lengths, e.g. 15 seconds.
- Run inference on each chunk - i.e. generate a text transcription for each individual chunk.
- Concatenate the chunk-specific text transcriptions to create a complete transcription for the long audio file.
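The cost argument behind the three steps above can be sketched with a rough calculation. This is only an illustration: it counts pairwise attention scores (sequence length squared) and assumes about 50 feature frames per second of audio, which is my approximation of the wav2vec2 frame rate rather than a measured value. The function name is hypothetical.

```python
import math
from typing import Optional

# Assumption: roughly one wav2vec2 feature frame per 20 ms of audio
FRAMES_PER_SECOND = 50

def attention_cost(total_seconds: float, chunk_seconds: Optional[float] = None) -> int:
    """Return a unitless count of pairwise attention scores (length squared)."""
    if chunk_seconds is None:
        # Whole file processed as a single sequence
        frames = int(total_seconds * FRAMES_PER_SECOND)
        return frames * frames
    # File split into fixed-length chunks, each processed independently
    n_chunks = math.ceil(total_seconds / chunk_seconds)
    frames_per_chunk = int(chunk_seconds * FRAMES_PER_SECOND)
    return n_chunks * frames_per_chunk * frames_per_chunk

whole = attention_cost(220)
chunked = attention_cost(220, chunk_seconds=10)
print(whole // chunked)  # 22
```

Under these assumptions, splitting a 220 second file into 10 second chunks reduces the attention work by a factor equal to the number of chunks.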
While the approach above will work, the beginning and end of each audio chunk will "see" poor inference performance. This is because wav2vec2 performs inference on a given section of audio using the context of that section - i.e. the audio "around" it. Since, by definition, the beginning and end of each audio chunk have no context, inference results are expected to be poor in those sections.
A solution to this problem is to add some amount of context, a stride, to the beginning and to the end of each audio chunk solely for the purpose of running inference. After the inference is complete, we drop the inference results for the added context so that we are left only with the inference results for the audio chunk. Just as before, we can concatenate the individual text transcriptions to create a complete transcription, but with the benefit of better inference at the beginning and end of each audio chunk.
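The stride idea can be sketched in a few lines of plain Python. This is a minimal illustration, not the Hugging Face implementation: the function name is hypothetical, and it assumes that each window's length already includes its strides, which is my understanding of how the pipeline interprets the chunk length. Each tuple gives the span that is fed to the model and the narrower span whose results are kept.

```python
def chunk_windows(total_s: float, chunk_s: float, stride_left_s: float, stride_right_s: float):
    """Yield (window_start, window_end, keep_start, keep_end) in seconds."""
    # Assumption: the window length includes both strides
    step = chunk_s - stride_left_s - stride_right_s
    if step <= 0:
        raise ValueError("chunk length must exceed the combined stride lengths")
    start = 0.0
    while True:
        end = min(start + chunk_s, total_s)
        is_first = start == 0
        is_last = start + chunk_s >= total_s
        # The first window has no left context to drop; the last has no right context
        keep_start = start if is_first else start + stride_left_s
        keep_end = end if is_last else end - stride_right_s
        yield (start, end, keep_start, keep_end)
        if is_last:
            break
        start += step

for window in chunk_windows(20, 8, 2, 2):
    print(window)
```

For a 20 second file with 8 second windows and 2 second strides, the kept spans are 0-6, 6-10, 10-14, and 14-20, so they tile the audio without gaps or overlap.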
As will be seen in the next section, configuring a Hugging Face pipeline to use chunking and a context stride is very simple. For additional discussion on this topic, please see this excellent Hugging Face blog post, which also includes a visual depiction of the chunking approach.
Configuring the pipeline Class for Chunking
As you might remember from the Part 2 guide, we configured an instance of the Hugging Face pipeline class for automatic speech recognition ("ASR"). Specifically, in Step 2.13 of that guide, we initialized a transcriber as follows:
transcriber = pipeline("automatic-speech-recognition", model = "YOUR_FINETUNED_MODEL_PATH")
The pipeline class for ASR accepts two additional initialization arguments that allow us to implement chunking and to add a context stride to each chunk:
- chunk_length_s: A number specifying the length in seconds of each chunk.
- stride_length_s: A tuple of numbers specifying the stride lengths in seconds at the beginning and end of each chunk.
For example, if we wanted to use a chunk size of 8 seconds, and a stride length of 2 seconds on each side of the chunk, we would initialize the pipeline using:
transcriber = pipeline("automatic-speech-recognition", chunk_length_s = 8, stride_length_s = (2,2), model = "YOUR_FINETUNED_MODEL_PATH")
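Since the strides have to fit inside the chunk, it can be handy to check the values before handing them to the pipeline. The helper below is a hypothetical sketch (the function name is my own, not part of transformers): it only builds the keyword arguments shown above and rejects stride values that leave no room for the chunk itself.

```python
def asr_pipeline_kwargs(model_path: str, chunk_length_s: float, stride_length_s: tuple) -> dict:
    """Build keyword arguments for an ASR pipeline after a basic sanity check."""
    left, right = stride_length_s
    # The combined strides must be shorter than the chunk
    if chunk_length_s <= left + right:
        raise ValueError("chunk_length_s must be greater than the sum of the stride lengths")
    return {
        "task": "automatic-speech-recognition",
        "model": model_path,
        "chunk_length_s": chunk_length_s,
        "stride_length_s": (left, right),
    }

kwargs = asr_pipeline_kwargs("YOUR_FINETUNED_MODEL_PATH", 8, (2, 2))
print(kwargs)
# With transformers installed and a real model path:
# transcriber = pipeline(**kwargs)
```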
With these new parameters in mind, you're now ready to build the Python application for long inference.
Prerequisites and Before You Get Started
To complete the guide, you will need to have:
- A finetuned wav2vec2 model.
- Intermediate knowledge of Python.
- Basic knowledge of the Python tkinter package.
- Intermediate knowledge of ML concepts.
- Basic knowledge of ASR concepts.
Building the Python Application for Long Inference
Step 1 - Setting Up Your Python Environment
Use pip to install the torchaudio and transformers packages in your Python environment if they are not already installed.
pip install torchaudio
pip install transformers
You might also need to install the soundfile package to use the torchaudio.load function, which is used to load audio files.
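If you are unsure which of these packages are already present, a quick check along the following lines can help. This is a small sketch using only the standard library; the helper name is hypothetical.

```python
import importlib.util

def missing_packages(names):
    """Return the package names that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]

# Packages used by the application in this guide
print(missing_packages(["torch", "torchaudio", "transformers", "soundfile", "tkinter"]))
```

An empty list means everything needed is importable; any names printed are candidates for pip install.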
Step 2 - Building the Application
The following sub-steps build the application progressively. Each sub-step includes an explanation that describes the particular section of code. If you prefer to view the finished application, feel free to jump to Step 3 which displays the complete program.
Step 2.1 - Creating a New Python Application and Adding Imports
Create a new Python application and add the following import statements:
import time
import torch
import torchaudio
from transformers import pipeline
import tkinter as tk
import tkinter.scrolledtext as scrolledtext
from tkinter import ttk, filedialog
- The torchaudio package will be used to load and resample audio data.
- The transformers package, and specifically the pipeline class, will be used to run inference.
- tkinter will be used to create the user interface for the application.
It is assumed you have an existing background working with tkinter GUIs. However, if you need a brief primer, you can take a look at my Hackernoon article Building Your First Python GUI With Tkinter.
Step 2.2 - Adding Constants
Add the following constants below the module imports.
GUI = {
    "title": "Long Inference with wav2vec2 ASR Models",
    "root_width": 800,
    "root_height": 500,
    "pad_x": 10,
    "pad_y": 10,
    "input_field_width": 110,
    "textbox_width": 125,
    "textbox_height": 15,
    "select_model_label": "Please select an ASR model",
    "browse_models_button_label": "Browse Models",
    "select_file_label": "Please select an audio file for inference",
    "browse_files_button_label": "Browse Files",
    "run_inference_button_label": "Run",
    "save_to_file_button_label": "Save To File",
    "reset": "Reset",
    "default_notification": "",
    "notification_select_model": "You need to select an ASR model directory",
    "notification_select_audio_file": "You need to select an audio file",
    "notification_select_model_and_audio_file": "You need to select an ASR model and audio file before you can run inference",
    "notification_running_inference": "Running inference...this might take awhile...",
    "notification_finished_inference": "Finished running inference in {} seconds",
    "notification_select_file_for_saving": "You need to select a file to save the transcription",
    "notification_finished_saving": "Finished saving transcription"
}
TGT_SAMPLING_RATE = 16000
CHUNK_LENGTH = 10
STRIDE_START = 3
STRIDE_END = 3
- The GUI dictionary captures window dimensions, labels, and other data used to create and update the user interface.
- TGT_SAMPLING_RATE is the target sampling rate, expressed in Hz, used when resampling audio data.
- CHUNK_LENGTH is the chunk length expressed in seconds and will be used when initializing the pipeline class for ASR. In this guide, each audio chunk will be 10 seconds in length.
- STRIDE_START and STRIDE_END represent the starting and ending context strides applied to each chunk. Both values are expressed in seconds and will be used when initializing the pipeline class for ASR. With both values equal to 3, an equivalent stride length of 3 seconds will be applied to both the start and end of each audio chunk.
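To get a feel for what these values imply, the arithmetic can be sketched as follows. This is an estimate under a stated assumption: that the 10 second chunk length includes the two 3 second strides, which is my understanding of the pipeline's behaviour and worth verifying against your installed transformers version. The helper names are hypothetical.

```python
import math

CHUNK_LENGTH = 10  # seconds, as in this guide
STRIDE_START = 3
STRIDE_END = 3

def kept_seconds_per_chunk() -> int:
    """Seconds of transcription kept from each interior chunk (assumes strides are inside the chunk)."""
    return CHUNK_LENGTH - STRIDE_START - STRIDE_END

def estimated_chunks(audio_seconds: float) -> int:
    """Estimate how many overlapping chunks are needed to cover the audio."""
    step = kept_seconds_per_chunk()
    return max(1, math.ceil((audio_seconds - CHUNK_LENGTH) / step) + 1)

print(kept_seconds_per_chunk())  # 4
print(estimated_chunks(221))     # 54
```

Under that assumption, larger strides give each chunk more context but also increase the number of chunks, and therefore the total inference time.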
Step 2.3 - Adding Globals
Add the following two global variables below the constant variables.
MODEL = None
AUDIO_FILE = None
- The MODEL and AUDIO_FILE globals will be assigned to the ASR model directory path and audio file path respectively.
- As you will see, the application enforces a selection order whereby an ASR model must be chosen before an audio file can be selected.
Step 2.4 - Adding Utility Methods
Add the following utility methods below the MODEL and AUDIO_FILE globals.
def read_audio_data(file: str) -> tuple[torch.Tensor, int]:
    # Load the waveform and its original sampling rate
    audio, sampling_rate = torchaudio.load(file, normalize = True)
    return audio, sampling_rate

def resample(waveform: torch.Tensor, orig_sampling_rate: int) -> torch.Tensor:
    transform = torchaudio.transforms.Resample(orig_sampling_rate, TGT_SAMPLING_RATE)
    waveform = transform(waveform)
    # Return the first channel only
    return waveform[0]
- The read_audio_data method is used to load audio files using the torchaudio.load method.
- The resample method is used to resample audio from its original sampling rate to the target sampling rate of 16000 as specified by TGT_SAMPLING_RATE.
Step 2.5 - Adding Callback Methods
There are five callback methods, each of which is bound to one of five user interface buttons:
- Browse Models button
- Browse Files button
- Run button
- Save To File button
- Reset button
When a given button is pressed, its respective callback method is called.
Step 2.5.1 - Adding callback_select_model
Add the following code for the callback_select_model method:
def callback_select_model(e: object, args: list) -> None:
    global MODEL
    # Unpack widgets
    select_model_field = args[0]
    notification_label = args[1]
    # Open file dialog
    model_dir = filedialog.askdirectory()
    # Add `MODEL` to the `select_model_field` widget
    if model_dir:
        MODEL = model_dir
        select_model_field.configure({"text": MODEL})
        select_model_field.update()
    else:
        notification_label.configure({"text": GUI["notification_select_model"]})
        notification_label.update()
- This method is called when the Browse Models button is pressed on the GUI.
- A file dialog is opened which the user can use to select an ASR model directory.
- If the user closes the dialog before selecting a valid directory, a warning notification is displayed.
- If the user selects a valid directory, the directory path is added to the GUI and assigned to the MODEL global.
Step 2.5.2 - Adding callback_select_file
Add the following code for the callback_select_file method:
def callback_select_file(e: object, args: list) -> None:
    global AUDIO_FILE
    # Unpack widgets
    select_file_field = args[0]
    notification_label = args[1]
    if not MODEL:
        notification_label.configure({"text": GUI["notification_select_model"]})
        notification_label.update()
    else:
        # Open file dialog
        file = filedialog.askopenfilename()
        # Add `AUDIO_FILE` filename to the `select_file_field`
        if file:
            AUDIO_FILE = file
            select_file_field.configure({"text": AUDIO_FILE})
            select_file_field.update()
        else:
            notification_label.configure({"text": GUI["notification_select_audio_file"]})
            notification_label.update()
- This method is called when the Browse Files button is pressed on the GUI.
- The logic first checks if an ASR model directory has been selected. If a model directory has not yet been chosen, a warning notification is displayed.
- If an ASR model directory has been chosen, a file dialog is opened which the user can use to select an audio file.
- If the user closes the dialog before selecting a valid file, a warning notification is displayed.
- If the user selects a valid file, the file path is added to the GUI and assigned to the AUDIO_FILE global.
Step 2.5.3 - Adding callback_run_inference
Add the following code for the callback_run_inference method:
def callback_run_inference(e: object, args: list) -> None:
    global MODEL
    global AUDIO_FILE
    # Unpack widgets
    textbox = args[0]
    notification_label = args[1]
    save_to_file_button = args[2]
    reset_button = args[3]
    select_model_button = args[4]
    select_file_button = args[5]
    if not MODEL or not AUDIO_FILE:
        notification_label.configure({"text": GUI["notification_select_model_and_audio_file"]})
        notification_label.update()
    else:
        # Disable select model and select file buttons when running inference
        select_model_button.unbind("<Button>")
        select_file_button.unbind("<Button>")
        # Set the input list for the ASR pipeline
        pipeline_input = []
        orig_audio, orig_sampling_rate = read_audio_data(AUDIO_FILE)
        resampled_audio = resample(orig_audio, orig_sampling_rate)
        pipeline_input.append({
            "raw": resampled_audio.numpy(),
            "sampling_rate": TGT_SAMPLING_RATE
        })
        # Initialize instance of ASR pipeline
        transcriber = pipeline("automatic-speech-recognition", chunk_length_s = CHUNK_LENGTH, stride_length_s = (STRIDE_START, STRIDE_END), model = MODEL)
        # Update notification
        notification_label.configure({"text": GUI["notification_running_inference"]})
        notification_label.update()
        # Set start time
        start_time = time.time()
        # Run inference
        transcription = transcriber(pipeline_input)
        # Set end time
        end_time = time.time()
        # Add transcription to text box
        textbox.insert("1.0", transcription[0]["text"])
        textbox.update()
        # Update notification
        notification_label.configure({"text": GUI["notification_finished_inference"].format(str(int(end_time - start_time)))})
        notification_label.update()
        # Bind `save_to_file_button`
        save_to_file_button.bind("<Button>", lambda e, args = [textbox, notification_label]: callback_save_file(e, args))
- This method is called when the Run button is pressed on the GUI and is the heart of the application.
- The logic first checks that an ASR model directory and audio file path have been selected. If either has not yet been selected, a warning notification is displayed.
- If a valid ASR model directory and audio file have been selected, the logic will:
- Disable the Browse Models and Browse Files buttons in preparation for running inference.
- Load the chosen audio file.
- Resample the audio data to the target sampling rate of 16,000 Hz.
- Initialize the pipeline class for automatic speech recognition using the ASR model specified by MODEL, along with the CHUNK_LENGTH, STRIDE_START, and STRIDE_END values set earlier.
- Run inference on the audio sample.
- Add the complete text transcription to the GUI textbox for review.
- If you completed the Part 2 guide, you will recognize that the inference logic mimics the inference logic in Step 2.12 through Step 2.14 of that guide, with the exception of initializing the pipeline class with the chunk_length_s and stride_length_s parameters.
- A success notification is displayed after inference is complete with the inference duration expressed in seconds.
- The Save To File button on the GUI is enabled after inference is complete.
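The input and output handling described above can be exercised without a GUI or a real model. The sketch below separates the inference step into a small function and uses a stand-in transcriber; both names are hypothetical, and the stand-in only imitates the list-of-dictionaries result that the callback reads with transcription[0]["text"].

```python
import time

def transcribe(transcriber, samples, sampling_rate: int):
    """Run a transcriber on one audio sample and return (text, elapsed_seconds)."""
    # Same input layout as in callback_run_inference: a list with one dictionary
    pipeline_input = [{"raw": samples, "sampling_rate": sampling_rate}]
    start_time = time.time()
    results = transcriber(pipeline_input)
    elapsed = int(time.time() - start_time)
    return results[0]["text"], elapsed

def fake_transcriber(batch):
    """Stand-in for a real ASR pipeline: returns one result per input item."""
    return [{"text": "hola mundo"} for _ in batch]

text, elapsed = transcribe(fake_transcriber, [0.0] * 16000, 16000)
print(text, elapsed)
```

Keeping the inference step in its own function like this also makes it easier to reuse outside of tkinter, for example in a command-line script.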
Step 2.5.4 - Adding callback_save_file
Add the following code for the callback_save_file method:
def callback_save_file(e: object, args: list) -> None:
    # Unpack widgets
    textbox = args[0]
    notification_label = args[1]
    # Ask user to select a file for saving
    save_file = filedialog.asksaveasfilename()
    if save_file:
        # Write transcription to file
        transcription = textbox.get("1.0", tk.END)
        with open(save_file, "w", encoding = "utf8") as handle:
            handle.write(transcription)
        # Update notification
        notification_label.configure({"text": GUI["notification_finished_saving"]})
        notification_label.update()
    else:
        # Use the key defined in the `GUI` dictionary
        notification_label.configure({"text": GUI["notification_select_file_for_saving"]})
        notification_label.update()
- This method is called when the Save To File button is pressed on the GUI.
- A file dialog is opened which the user can use to specify a filename and location for saving the generated text transcription.
- If the user closes the dialog before specifying a valid filename, a warning notification is displayed.
- If the user provides a valid filename, the transcription is written out to file using the specified filename. A success notification confirming the save is displayed for the user after the write is complete.
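One detail worth knowing: a tkinter text widget returns its contents with a trailing newline when read up to tk.END, so the saved file ends with an extra line break. If that matters for your use case, the write step could be adjusted along these lines. This is an optional sketch with a hypothetical helper name, shown here with a temporary file instead of a file dialog.

```python
import os
import tempfile

def save_transcription(text: str, path: str) -> None:
    """Write the transcription to `path`, dropping one trailing newline if present."""
    if text.endswith("\n"):
        text = text[:-1]
    with open(path, "w", encoding="utf8") as handle:
        handle.write(text)

# Usage with a temporary file standing in for the user's chosen path
with tempfile.TemporaryDirectory() as tmp_dir:
    target = os.path.join(tmp_dir, "transcription.txt")
    save_transcription("hola mundo\n", target)
    with open(target, encoding="utf8") as handle:
        saved = handle.read()
print(repr(saved))
```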
Step 2.5.5 - Adding callback_reset
Add the following code for the callback_reset method:
def callback_reset(e: object, args: list) -> None:
    global MODEL
    global AUDIO_FILE
    # Unpack widgets
    gui_root = args[0]
    select_model_field = args[1]
    select_model_button = args[2]
    select_file_field = args[3]
    select_file_button = args[4]
    textbox = args[5]
    save_to_file_button = args[6]
    notification_label = args[7]
    # Clear the selections and the GUI
    MODEL = None
    AUDIO_FILE = None
    select_model_field.configure({"text": ""})
    select_file_field.configure({"text": ""})
    textbox.delete("1.0", tk.END)
    notification_label.configure({"text": ""})
    # Re-enable the browse buttons and disable the save button
    select_model_button.bind("<Button>", lambda e, args = [select_model_field, notification_label]: callback_select_model(e, args))
    select_file_button.bind("<Button>", lambda e, args = [select_file_field, notification_label]: callback_select_file(e, args))
    save_to_file_button.unbind("<Button>")
    gui_root.update()
- This method is called when the Reset button is pressed on the GUI. It resets the application for a new inference run. The logic:
- Resets the MODEL and AUDIO_FILE globals to None.
- Clears the existing ASR model directory path and audio file path data from the GUI.
- Clears the transcription textbox.
- Clears any displayed notification.
- Re-enables the Browse Models and Browse Files buttons by binding those widgets to their respective callbacks.
- Disables the Save To File button by unbinding it from its callback.
Step 2.6 - Adding main Method
Add the following code for the main method below the callback methods:
def main():
    gui_root = tk.Tk()
    gui_root.title(GUI["title"])
    window_width = GUI["root_width"]
    window_height = GUI["root_height"]
    # Get the screen dimensions
    screen_width = gui_root.winfo_screenwidth()
    screen_height = gui_root.winfo_screenheight()
    # Find the center point
    center_x = int(screen_width/2 - window_width/2)
    center_y = int(screen_height/2 - window_height/2)
    # Set the position of the window to the center of the screen
    gui_root.geometry(f"{window_width}x{window_height}+{center_x}+{center_y}")
    # Not resizable
    gui_root.resizable(False, False)
    # Configure grid
    gui_root.configure(padx = GUI["pad_x"])
    gui_root.columnconfigure(0, weight = 1)
    # Widgets
    select_model_label = ttk.Label(gui_root, text = GUI["select_model_label"])
    select_model_frame = ttk.Frame(gui_root)
    select_model_field = ttk.Label(select_model_frame, text = "")
    select_model_button = ttk.Button(select_model_frame, text = GUI["browse_models_button_label"])
    select_file_label = ttk.Label(gui_root, text = GUI["select_file_label"])
    select_file_frame = ttk.Frame(gui_root)
    select_file_field = ttk.Label(select_file_frame, text = "")
    select_file_button = ttk.Button(select_file_frame, text = GUI["browse_files_button_label"])
    run_inference_button = ttk.Button(gui_root, text = GUI["run_inference_button_label"])
    textbox = scrolledtext.ScrolledText(gui_root, width = GUI["textbox_width"], height = GUI["textbox_height"])
    textbox_buttons_frame = ttk.Frame(gui_root)
    save_to_file_button = ttk.Button(textbox_buttons_frame, text = GUI["save_to_file_button_label"])
    reset_button = ttk.Button(textbox_buttons_frame, text = GUI["reset"])
    notification_label = ttk.Label(gui_root, text = GUI["default_notification"])
    # Place widgets
    # Row 0
    select_model_label.grid(column = 0, row = 0, columnspan = 1, sticky = tk.W, pady = (GUI["pad_y"], 0))
    # Row 1
    select_model_frame.grid(column = 0, row = 1, columnspan = 1, sticky = tk.W, pady = (GUI["pad_y"], 0))
    select_model_field.config(background = "white", width = GUI["input_field_width"])
    select_model_field.pack(side = "left", padx = (0, GUI["pad_x"]))
    select_model_button.pack()
    # Row 2
    select_file_label.grid(column = 0, row = 2, columnspan = 1, sticky = tk.W, pady = (GUI["pad_y"], 0))
    # Row 3
    select_file_frame.grid(column = 0, row = 3, columnspan = 1, sticky = tk.W, pady = (GUI["pad_y"], 0))
    select_file_field.config(background = "white", width = GUI["input_field_width"])
    select_file_field.pack(side = "left", padx = (0, GUI["pad_x"]))
    select_file_button.pack()
    # Row 4
    run_inference_button.grid(column = 0, row = 4, sticky = tk.W, pady = (GUI["pad_y"], 0))
    # Row 5
    textbox.grid(column = 0, row = 5, columnspan = 1, sticky = tk.W, pady = (GUI["pad_y"], 0))
    # Row 6
    textbox_buttons_frame.grid(column = 0, row = 6, sticky = tk.W, pady = (GUI["pad_y"], 0))
    save_to_file_button.pack(side = "left")
    reset_button.pack()
    # Row 7
    notification_label.config(foreground = "blue")
    notification_label.grid(column = 0, row = 7, columnspan = 1, pady = (GUI["pad_y"], 0))
    # Bind buttons
    select_model_button.bind("<Button>", lambda e, args = [select_model_field, notification_label]: callback_select_model(e, args))
    select_file_button.bind("<Button>", lambda e, args = [select_file_field, notification_label]: callback_select_file(e, args))
    run_inference_button.bind("<Button>", lambda e, args = [textbox, notification_label, save_to_file_button, reset_button, select_model_button, select_file_button]: callback_run_inference(e, args))
    # The argument order must match the unpacking order in `callback_reset`
    reset_button.bind("<Button>", lambda e, args = [gui_root, select_model_field, select_model_button, select_file_field, select_file_button, textbox, save_to_file_button, notification_label]: callback_reset(e, args))
    gui_root.mainloop()
- The main method creates the application GUI by initializing a root window, placing tkinter widgets within the window, and binding button widgets to their callback methods.
- The GUI layout uses a simple grid with 8 rows (rows 0 through 7).
- Frame widgets are used to manage the layout of sub-sections of the overall interface.
- You might notice that the select_model_field and select_file_field widgets are Label widgets with white backgrounds. In other words, they are styled to look like input entry fields but are not actual input entry fields. This was done to force users to use the file dialogs when specifying the ASR model directory and audio file path, and to eliminate the need to parse user inputs.
- All GUI buttons are bound when the interface is initialized, except for the Save To File button.
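The window-centering arithmetic in main can be checked on its own, without opening a window. The helper below is a hypothetical extraction of that calculation; it returns the geometry string in the same widthxheight+x+y form that main passes to gui_root.geometry.

```python
def centered_geometry(screen_width: int, screen_height: int, window_width: int, window_height: int) -> str:
    """Return a tkinter geometry string that centers the window on the screen."""
    center_x = int(screen_width / 2 - window_width / 2)
    center_y = int(screen_height / 2 - window_height / 2)
    return f"{window_width}x{window_height}+{center_x}+{center_y}"

# An 800x500 window on a 1920x1080 screen
print(centered_geometry(1920, 1080, 800, 500))  # 800x500+560+290
```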
Step 2.7 - Calling main Method
Finally, call the main method in the last line of the application to start/run the program:
main()
Step 3 - Reviewing the Complete Application
The complete application should be as follows:
import time
import torch
import torchaudio
from transformers import pipeline
import tkinter as tk
import tkinter.scrolledtext as scrolledtext
from tkinter import ttk, filedialog
### CONSTANTS ###
GUI = {
    "title": "Long Inference with wav2vec2 ASR Models",
    "root_width": 800,
    "root_height": 500,
    "pad_x": 10,
    "pad_y": 10,
    "input_field_width": 110,
    "textbox_width": 125,
    "textbox_height": 15,
    "select_model_label": "Please select an ASR model",
    "browse_models_button_label": "Browse Models",
    "select_file_label": "Please select an audio file for inference",
    "browse_files_button_label": "Browse Files",
    "run_inference_button_label": "Run",
    "save_to_file_button_label": "Save To File",
    "reset": "Reset",
    "default_notification": "",
    "notification_select_model": "You need to select an ASR model directory",
    "notification_select_audio_file": "You need to select an audio file",
    "notification_select_model_and_audio_file": "You need to select an ASR model and audio file before you can run inference",
    "notification_running_inference": "Running inference...this might take awhile...",
    "notification_finished_inference": "Finished running inference in {} seconds",
    "notification_select_file_for_saving": "You need to select a file to save the transcription",
    "notification_finished_saving": "Finished saving transcription"
}
TGT_SAMPLING_RATE = 16000
CHUNK_LENGTH = 10
STRIDE_START = 3
STRIDE_END = 3
### GLOBALS ###
MODEL = None
AUDIO_FILE = None
### UTILITY METHODS ###
def read_audio_data(file: str) -> tuple[torch.Tensor, int]:
    # Load the waveform and its original sampling rate
    audio, sampling_rate = torchaudio.load(file, normalize = True)
    return audio, sampling_rate

def resample(waveform: torch.Tensor, orig_sampling_rate: int) -> torch.Tensor:
    transform = torchaudio.transforms.Resample(orig_sampling_rate, TGT_SAMPLING_RATE)
    waveform = transform(waveform)
    # Return the first channel only
    return waveform[0]
### CALLBACK METHODS ###
def callback_select_model(e: object, args: list) -> None:
    global MODEL
    # Unpack widgets
    select_model_field = args[0]
    notification_label = args[1]
    # Open file dialog
    model_dir = filedialog.askdirectory()
    # Add `MODEL` to the `select_model_field` widget
    if model_dir:
        MODEL = model_dir
        select_model_field.configure({"text": MODEL})
        select_model_field.update()
    else:
        notification_label.configure({"text": GUI["notification_select_model"]})
        notification_label.update()
def callback_select_file(e: object, args: list) -> None:
    global AUDIO_FILE
    # Unpack widgets
    select_file_field = args[0]
    notification_label = args[1]
    if not MODEL:
        notification_label.configure({"text": GUI["notification_select_model"]})
        notification_label.update()
    else:
        # Open file dialog
        file = filedialog.askopenfilename()
        # Add `AUDIO_FILE` filename to the `select_file_field`
        if file:
            AUDIO_FILE = file
            select_file_field.configure({"text": AUDIO_FILE})
            select_file_field.update()
        else:
            notification_label.configure({"text": GUI["notification_select_audio_file"]})
            notification_label.update()
def callback_run_inference(e: object, args: list) -> None:
    global MODEL
    global AUDIO_FILE
    # Unpack widgets
    textbox = args[0]
    notification_label = args[1]
    save_to_file_button = args[2]
    reset_button = args[3]
    select_model_button = args[4]
    select_file_button = args[5]
    if not MODEL or not AUDIO_FILE:
        notification_label.configure({"text": GUI["notification_select_model_and_audio_file"]})
        notification_label.update()
    else:
        # Disable select model and select file buttons when running inference
        select_model_button.unbind("<Button>")
        select_file_button.unbind("<Button>")
        # Set the input list for the ASR pipeline
        pipeline_input = []
        orig_audio, orig_sampling_rate = read_audio_data(AUDIO_FILE)
        resampled_audio = resample(orig_audio, orig_sampling_rate)
        pipeline_input.append({
            "raw": resampled_audio.numpy(),
            "sampling_rate": TGT_SAMPLING_RATE
        })
        # Initialize instance of ASR pipeline
        transcriber = pipeline("automatic-speech-recognition", chunk_length_s = CHUNK_LENGTH, stride_length_s = (STRIDE_START, STRIDE_END), model = MODEL)
        # Update notification
        notification_label.configure({"text": GUI["notification_running_inference"]})
        notification_label.update()
        # Set start time
        start_time = time.time()
        # Run inference
        transcription = transcriber(pipeline_input)
        # Set end time
        end_time = time.time()
        # Add transcription to text box
        textbox.insert("1.0", transcription[0]["text"])
        textbox.update()
        # Update notification
        notification_label.configure({"text": GUI["notification_finished_inference"].format(str(int(end_time - start_time)))})
        notification_label.update()
        # Bind `save_to_file_button`
        save_to_file_button.bind("<Button>", lambda e, args = [textbox, notification_label]: callback_save_file(e, args))
def callback_save_file(e: object, args: list) -> None:
    # Unpack widgets
    textbox = args[0]
    notification_label = args[1]
    # Ask user to select a file for saving
    save_file = filedialog.asksaveasfilename()
    if save_file:
        # Write transcription to file
        transcription = textbox.get("1.0", tk.END)
        with open(save_file, "w", encoding = "utf8") as handle:
            handle.write(transcription)
        # Update notification
        notification_label.configure({"text": GUI["notification_finished_saving"]})
        notification_label.update()
    else:
        # Use the key defined in the `GUI` dictionary
        notification_label.configure({"text": GUI["notification_select_file_for_saving"]})
        notification_label.update()
def callback_reset(e: object, args: list) -> None:
    global MODEL
    global AUDIO_FILE
    # Unpack widgets
    gui_root = args[0]
    select_model_field = args[1]
    select_model_button = args[2]
    select_file_field = args[3]
    select_file_button = args[4]
    textbox = args[5]
    save_to_file_button = args[6]
    notification_label = args[7]
    # Clear the selections and the GUI
    MODEL = None
    AUDIO_FILE = None
    select_model_field.configure({"text": ""})
    select_file_field.configure({"text": ""})
    textbox.delete("1.0", tk.END)
    notification_label.configure({"text": ""})
    # Re-enable the browse buttons and disable the save button
    select_model_button.bind("<Button>", lambda e, args = [select_model_field, notification_label]: callback_select_model(e, args))
    select_file_button.bind("<Button>", lambda e, args = [select_file_field, notification_label]: callback_select_file(e, args))
    save_to_file_button.unbind("<Button>")
    gui_root.update()
def main():
    gui_root = tk.Tk()
    gui_root.title(GUI["title"])
    window_width = GUI["root_width"]
    window_height = GUI["root_height"]
    # Get the screen dimensions
    screen_width = gui_root.winfo_screenwidth()
    screen_height = gui_root.winfo_screenheight()
    # Find the center point
    center_x = int(screen_width/2 - window_width/2)
    center_y = int(screen_height/2 - window_height/2)
    # Set the position of the window to the center of the screen
    gui_root.geometry(f"{window_width}x{window_height}+{center_x}+{center_y}")
    # Not resizable
    gui_root.resizable(False, False)
    # Configure grid
    gui_root.configure(padx = GUI["pad_x"])
    gui_root.columnconfigure(0, weight = 1)
    # Widgets
    select_model_label = ttk.Label(gui_root, text = GUI["select_model_label"])
    select_model_frame = ttk.Frame(gui_root)
    select_model_field = ttk.Label(select_model_frame, text = "")
    select_model_button = ttk.Button(select_model_frame, text = GUI["browse_models_button_label"])
    select_file_label = ttk.Label(gui_root, text = GUI["select_file_label"])
    select_file_frame = ttk.Frame(gui_root)
    select_file_field = ttk.Label(select_file_frame, text = "")
    select_file_button = ttk.Button(select_file_frame, text = GUI["browse_files_button_label"])
    run_inference_button = ttk.Button(gui_root, text = GUI["run_inference_button_label"])
    textbox = scrolledtext.ScrolledText(gui_root, width = GUI["textbox_width"], height = GUI["textbox_height"])
    textbox_buttons_frame = ttk.Frame(gui_root)
    save_to_file_button = ttk.Button(textbox_buttons_frame, text = GUI["save_to_file_button_label"])
    reset_button = ttk.Button(textbox_buttons_frame, text = GUI["reset"])
    notification_label = ttk.Label(gui_root, text = GUI["default_notification"])
    # Place widgets
    # Row 0
    select_model_label.grid(column = 0, row = 0, columnspan = 1, sticky = tk.W, pady = (GUI["pad_y"], 0))
    # Row 1
    select_model_frame.grid(column = 0, row = 1, columnspan = 1, sticky = tk.W, pady = (GUI["pad_y"], 0))
    select_model_field.config(background = "white", width = GUI["input_field_width"])
    select_model_field.pack(side = "left", padx = (0, GUI["pad_x"]))
    select_model_button.pack()
    # Row 2
    select_file_label.grid(column = 0, row = 2, columnspan = 1, sticky = tk.W, pady = (GUI["pad_y"], 0))
    # Row 3
    select_file_frame.grid(column = 0, row = 3, columnspan = 1, sticky = tk.W, pady = (GUI["pad_y"], 0))
    select_file_field.config(background = "white", width = GUI["input_field_width"])
    select_file_field.pack(side = "left", padx = (0, GUI["pad_x"]))
    select_file_button.pack()
    # Row 4
    run_inference_button.grid(column = 0, row = 4, sticky = tk.W, pady = (GUI["pad_y"], 0))
    # Row 5
    textbox.grid(column = 0, row = 5, columnspan = 1, sticky = tk.W, pady = (GUI["pad_y"], 0))
    # Row 6
    textbox_buttons_frame.grid(column = 0, row = 6, sticky = tk.W, pady = (GUI["pad_y"], 0))
    save_to_file_button.pack(side = "left")
    reset_button.pack()
    # Row 7
    notification_label.config(foreground = "blue")
    notification_label.grid(column = 0, row = 7, columnspan = 1, pady = (GUI["pad_y"], 0))
    # Bind buttons
    select_model_button.bind("<Button>", lambda e, args = [select_model_field, notification_label]: callback_select_model(e, args))
    select_file_button.bind("<Button>", lambda e, args = [select_file_field, notification_label]: callback_select_file(e, args))
    run_inference_button.bind("<Button>", lambda e, args = [textbox, notification_label, save_to_file_button, reset_button, select_model_button, select_file_button]: callback_run_inference(e, args))
    # The argument order must match the unpacking order in `callback_reset`
    reset_button.bind("<Button>", lambda e, args = [gui_root, select_model_field, select_model_button, select_file_field, select_file_button, textbox, save_to_file_button, notification_label]: callback_reset(e, args))
    gui_root.mainloop()
main()
Using the Application
The application workflow is straightforward.
- Launch the application.
- Click on Browse Models. When the file dialog opens, navigate to the directory containing your ASR model and click Select Folder.
- Next, click on Browse Files. When the file dialog opens, navigate to the audio file that you want to run inference on and click Open.
- Now that you've selected your model and audio file, click Run. This will kick off the inference workflow. The GUI will display the notification "Running inference...this might take awhile...". Bear in mind that inference might take several minutes depending on the length of the audio sample, as well as the values chosen for CHUNK_LENGTH, STRIDE_START, and STRIDE_END.
- Once inference is complete, click on Save To File if you want to save the generated text transcription.
- Click on Reset to reset the application for a new inference run.
To reiterate, it can take several minutes to generate a text transcription. For example, the screenshot in the Introduction shows an inference run that took 208 seconds to complete, or ~3.5 minutes. I conducted the run on this Spanish language audio sample from the news channel DW. You will note that the audio sample itself has a duration of 221 seconds. You might consider experimenting with the chunk length and stride values to examine the relationship between the final inference result and the time required to generate the result with respect to your particular ASR model.
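When comparing runs with different chunk length and stride values, one simple way to normalize the timings is the ratio of inference time to audio duration. The sketch below uses the two figures quoted above; the helper name is hypothetical.

```python
def real_time_factor(inference_seconds: float, audio_seconds: float) -> float:
    """Ratio of processing time to audio duration (below 1.0 means faster than real time)."""
    return inference_seconds / audio_seconds

# Figures from the run described above: 208 s of inference for 221 s of audio
print(round(real_time_factor(208, 221), 2))  # 0.94
```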
Conclusion and Next Steps
I didn't originally plan on writing this third guide on working with wav2vec2. However, I think it is worthwhile to walk through how a practical ASR application can be built. There are any number of follow-up projects that you may wish to undertake following this guide, such as a web version of the application and/or modifying the logic to run live inference instead of waiting to display the complete transcription. As always, I hope you found this guide to be useful and happy building!