Streaming Inference of the FS-EEND system #16

Open
BNarayanaReddy opened this issue Jan 2, 2025 · 7 comments

Comments

@BNarayanaReddy

Could you please help me run inference with the FS-EEND system in real time?

  • I could not find any code for online/streaming inference.
  • Diarization results only seem to be produced after the provided dataset has been prepared in the kaldi format.
  • Please clarify whether this system can perform online, real-time diarization as described in the research paper.
  • How should we feed an input waveform to the system? The same kaldi-format data preparation appears to be required for both fine-tuning and inference.
  • Please clarify the claim that the system is an online/real-time diarization model.
@DiLiangWU
Member

DiLiangWU commented Jan 2, 2025

Hi, I think there may be a couple of areas that need some clarification.

About FS-EEND being an online/real-time diarization model:

Actually, the network architecture of FS-EEND is causal/online: the self-attention module is masked along the time dimension, so the output for each frame depends only on its previous context (plus a few look-ahead frames, which we ignore here for simplicity). Therefore, performing inference directly already yields online diarization results, because the time-dimension masking makes this essentially equivalent to performing inference iteratively over time steps.
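
For intuition, here is a minimal PyTorch sketch (not the FS-EEND code itself; the sizes are arbitrary) showing that with a causal mask along the time dimension, a frame's output does not change when later frames are appended, which is why a single masked forward pass already gives online results:

import torch
import torch.nn as nn

# Toy causal self-attention: True entries in the mask block attention to future frames.
T, D = 10, 32
attn = nn.MultiheadAttention(embed_dim=D, num_heads=4, batch_first=True).eval()
x = torch.randn(1, T, D)
causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)

full, _ = attn(x, x, x, attn_mask=causal_mask)          # masked parallel pass over all frames
prefix, _ = attn(x[:, :5], x[:, :5], x[:, :5],          # pass over the first 5 frames only
                 attn_mask=causal_mask[:5, :5])
print(torch.allclose(full[:, :5], prefix, atol=1e-5))   # True: earlier outputs are unaffected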

About the data format:

We follow the EEND recipe for data preparation, which generates kaldi-format data including wav.scp/utt2spk/spk2utt/segments/rttm. The wav.scp file records the waveform paths, and the other files are used to generate the ground-truth labels. If you want to use another dataset for fine-tuning or inference, you should prepare these kaldi-format files, following Kaldi's recipes. If you only want to feed a waveform for inference without reference labels, please modify KaldiDiarizationDataset and test_step. By the way, direct inference without fine-tuning usually yields poorer results due to domain differences.
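
For reference, here is a minimal sketch of what these kaldi-format files look like for a single recording (all paths, IDs, and the duration below are placeholders; when no reference labels exist, a dummy single-speaker label can be written just to keep the pipeline running):

import os

data_dir = "data/my_test"             # hypothetical data directory
rec, spk, wav, dur = "rec1", "spk1", "/path/to/rec1.wav", 60.0
utt = f"{spk}-{rec}-0000000-0060000"  # one utterance covering the whole recording
os.makedirs(data_dir, exist_ok=True)

with open(os.path.join(data_dir, "wav.scp"), "w") as f:
    f.write(f"{rec} {wav}\n")                                  # <recording-id> <wav-path>
with open(os.path.join(data_dir, "segments"), "w") as f:
    f.write(f"{utt} {rec} 0.00 {dur:.2f}\n")                   # <utt-id> <rec-id> <start> <end>
with open(os.path.join(data_dir, "utt2spk"), "w") as f:
    f.write(f"{utt} {spk}\n")                                  # <utt-id> <speaker-id>
with open(os.path.join(data_dir, "spk2utt"), "w") as f:
    f.write(f"{spk} {utt}\n")                                  # <speaker-id> <utt-id> ...
with open(os.path.join(data_dir, "rttm"), "w") as f:
    # dummy reference spanning the whole file; replace with real labels if available
    f.write(f"SPEAKER {rec} 1 0.00 {dur:.2f} <NA> <NA> {spk} <NA> <NA>\n")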

@BNarayanaReddy
Author

Yes sir, thank you for your previous reply. I understand the inference process. I have evaluated the pretrained model provided in the README on the AMI corpus test set. I would like to know whether streaming inference can be done: for example, if I run a particular code block, can diarization be performed on speech coming in through the microphone on the fly (real-time streaming)? I cannot find any scripts for online-diarization inference. Even though the system depends only on causal frames, is it possible to perform diarization in a streaming/real-time (low-latency) setting with audio coming from a microphone, just like the DIART system depicted in the picture below?

[image: DIART-style streaming diarization setup]

@DiLiangWU
Member

DiLiangWU commented Jan 6, 2025

Thank you for your comments, I understand now. FS-EEND can perform online/streaming inference by changing the masked parallel form into an iterative inference paradigm. The two paradigms are equivalent in terms of the output results. We have modified the TransformerEncoder with masked self-attention for streaming inference and validated the equivalence between the two paradigms. The code updates can be found in nnet/modules/streaming_tfm.py.
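
The equivalence can be illustrated with a toy example (a sketch of the idea only, not the streaming_tfm.py code): caching the previously seen frames and attending from each new frame to that cache reproduces the masked parallel outputs.

import torch
import torch.nn as nn

torch.manual_seed(0)
T, D = 10, 32
attn = nn.MultiheadAttention(embed_dim=D, num_heads=4, batch_first=True).eval()
x = torch.randn(1, T, D)

# Parallel paradigm: one pass with a causal mask.
mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
parallel, _ = attn(x, x, x, attn_mask=mask)

# Iterative paradigm: process one frame at a time against a growing cache.
cache, outs = [], []
for t in range(T):
    cache.append(x[:, t:t + 1])
    ctx = torch.cat(cache, dim=1)          # all frames seen so far
    o, _ = attn(x[:, t:t + 1], ctx, ctx)   # query is only the current frame
    outs.append(o)
iterative = torch.cat(outs, dim=1)

print(torch.allclose(parallel, iterative, atol=1e-5))  # True: the two paradigms match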

However, the entire system's streaming inference, like DIART, still requires additional engineering work, such as the decoder and Conv1D parts. This will take some extra time, and we will update the code later. It is important to emphasize that these are engineering implementation differences and do not affect the conclusion in the paper that FS-EEND can perform streaming inference.
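
As a small illustration of the kind of engineering involved (a sketch only, not the planned update), a Conv1D layer can be made streamable by caching the last kernel_size - 1 input frames between buffers, so that buffer-wise outputs match one full causal pass:

import torch
import torch.nn as nn

class StreamingConv1d(nn.Module):
    """Toy causal Conv1d that caches context so buffer-wise outputs
    match a single full-sequence pass (illustrative only)."""
    def __init__(self, channels, kernel_size):
        super().__init__()
        self.conv = nn.Conv1d(channels, channels, kernel_size)
        self.kernel_size = kernel_size
        self.cache = None                                     # (B, C, kernel_size - 1)

    def forward(self, x):                                     # x: (B, C, T_buffer)
        if self.cache is None:                                # zero left-context for the first buffer
            self.cache = x.new_zeros(x.size(0), x.size(1), self.kernel_size - 1)
        x = torch.cat([self.cache, x], dim=-1)
        self.cache = x[..., -(self.kernel_size - 1):]         # keep context for the next buffer
        return self.conv(x)

# Buffer-wise output equals one full causal pass over the concatenated input.
torch.manual_seed(0)
m = StreamingConv1d(channels=8, kernel_size=3).eval()
full_in = torch.randn(1, 8, 20)
streamed = torch.cat([m(full_in[..., i:i + 5]) for i in range(0, 20, 5)], dim=-1)

ref = nn.Conv1d(8, 8, 3)
ref.load_state_dict(m.conv.state_dict())
ref_out = ref(torch.nn.functional.pad(full_in, (2, 0)))       # same zero left-padding
print(torch.allclose(streamed, ref_out, atol=1e-5))           # True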

@BNarayanaReddy
Author

Thank you for your kind reply, sir. I understand the differences, and thank you for answering and for adding this new feature to the project. I was trying to implement something similar: not exact streaming inference, but buffer-wise diarization. Here is the code; please let me know if I am doing anything wrong so far.

import warnings
import torch
import hyperpyyaml
import librosa
import numpy as np
from datasets.feature import stft, transform, splice, subsample
from nnet.model.onl_tfm_enc_1dcnn_enc_linear_non_autoreg_pos_enc_l2norm import OnlineTransformerDADiarization
from train.oln_tfm_enc_dec_spk_pit import SpeakerDiarization

warnings.filterwarnings("ignore")

n_spks = 4  # maximum number of speakers expected in the audio

class SpeakerDiarizationWrapperForTest(SpeakerDiarization):
    def __init__(self, hparams, model, datasets, opt, scheduler, collate_func):
        # Call the grandparent (LightningModule) __init__ to skip the training setup;
        # only the model is needed for inference.
        super(SpeakerDiarization, self).__init__()
        self.model = model

    def forward(self, src, ilens=None, tgt=None):
        return self.model(src)

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        return y_hat
def load_wav(wav_file):
    audio, sr = librosa.load(wav_file, sr=None)
    return audio,sr

def extract_logmel_feats(samples):
    # stft -> transform (logmel23) -> splice (subsample skipped)
    features = stft(samples, frame_size=200, frame_shift=80)
    features = transform(features, transform_type="logmel23")
    features = splice(features, context_size=7)
    # dummy labels (zeros) just so the rest of the pipeline is untouched
    T = np.zeros((features.shape[0], n_spks), dtype=np.int32)  # n_spks = number of speakers
    # features, T = subsample(features, T, 10)
    return features, T  # features: (num_frames, (2*7+1)*23 = 345)
def main(config_file="/home/narayana/ML/spkr_diarize/FS_EEND_sys/FS-EEND/conf/spk_onl_tfm_enc_dec_nonautoreg_infer.yaml"):
    with open(config_file, "r") as f:
        configs = hyperpyyaml.load_hyperpyyaml(f)
    model = OnlineTransformerDADiarization(
        n_speakers=configs["data"]["num_speakers"],
        in_size=(2 * configs["data"]["context_recp"] + 1) * configs["data"]["feat"]["n_mels"],          # Transformer need to know maximum data length
        **configs["model"]["params"],
    )
    spk_dia_main = SpeakerDiarizationWrapperForTest(
        hparams=configs,
        model=model,
        datasets=None,
        opt=None,
        scheduler=None,
        collate_func=None
    )
    wav_file = "/home/narayana/ML/spkr_diarize/FS_EEND_sys/FS-EEND/test_wav_files/qadia.wav"
    audio, sr = load_wav(wav_file)
    # load the averaged checkpoint into the wrapper (and hence into `model`)
    avg_ckpt_path = "/home/narayana/ML/spkr_diarize/FS_EEND_sys/ckpts/ch/FS-EEND_ch_91_100epo_avg_model.ckpt"
    state_dict = torch.load(avg_ckpt_path, map_location="cuda")
    spk_dia_main.load_state_dict(state_dict)
    spk_dia_main.eval()
    # take a 0.5 s buffer from the audio each time
    for i in range(0, len(audio), int(sr / 2)):
        audio_0_5 = audio[i:i + int(sr / 2)]
        # skip the last buffer if it is shorter than 0.5 s
        if len(audio_0_5) < int(sr / 2):
            continue
        Y, T_ss = extract_logmel_feats(audio_0_5)
        Y = torch.tensor(Y).unsqueeze(0)  # (1, num_frames, feat_dim)
        with torch.no_grad():  # disable gradient computation for inference
            output, emb, attractors = model.test(Y, ilens=[Y.shape[1]])
        # keep only the n_spks speaker columns (dropping column 0)
        preds_realspk = [p[:, 1:nspk + 1] for p, nspk in zip(output, [n_spks])]
        print(preds_realspk)
if __name__ == "__main__":
    main()

@DiLiangWU
Member

DiLiangWU commented Jan 7, 2025

Thank you for your interest and suggestions. If you have any questions, please feel free to raise them at any time.
I have read the code. Performing local segment-wise inference every 0.5 seconds independently has a problem: the resulting speaker IDs are also local to the segment level and lack utterance-level meaning. Simply concatenating the outputs would lead to a speaker permutation ambiguity problem. Therefore, it is necessary to cache previous contextual information at each layer of the network so that each frame/buffer's inference can refer to the previous context, maintaining consistency of the speaker ID at the utterance level. This requires modifying the internal network implementation as mentioned earlier, and we will update the code accordingly.
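
For example, until the cached-context code is released, a rough workaround sketch based on your script above (it reuses your model, extract_logmel_feats, audio, sr, and n_spks, assumes model.test accepts the growing sequence, and ignores STFT edge effects at buffer boundaries) is to keep a growing feature cache, re-run inference over everything seen so far, and only emit the predictions for the newest frames; the per-layer internal cache mentioned earlier avoids this recomputation:

feat_cache = None
for i in range(0, len(audio), int(sr / 2)):
    buf = audio[i:i + int(sr / 2)]
    if len(buf) < int(sr / 2):
        continue
    Y, _ = extract_logmel_feats(buf)
    Y = torch.tensor(Y)
    # grow the cache so every pass sees the full previous context
    feat_cache = Y if feat_cache is None else torch.cat([feat_cache, Y], dim=0)
    inp = feat_cache.unsqueeze(0)                       # (1, total_frames, feat_dim)
    with torch.no_grad():
        output, emb, attractors = model.test(inp, ilens=[inp.shape[1]])
    new_frames = output[0][-Y.shape[0]:]                # predictions for this buffer only
    print(new_frames[:, 1:n_spks + 1])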

@BNarayanaReddy
Author

Can't we just use a post-clustering process?

@DiLiangWU
Member

Online speaker diarization can be achieved by extracting global speaker embeddings at the segment level and then performing online clustering. This requires pretraining a speaker verification network to extract global speaker embeddings. However, in the EEND framework, the learned speaker embeddings are local, meaning the embedding of the same speaker may vary across different utterances, as shown in Fig. 4 of EEND-EDA. Therefore, online diarization by clustering segment-level embeddings without attending to the previous context is not straightforward within the EEND framework.
