Quick Example of Blind Source Separation

In this notebook, we walk through a quick example of blind source separation using Gauss-ILRMA (independent low-rank matrix analysis with a Gaussian source model).

[1]:
!pip install ssspy
Requirement already satisfied: ssspy in /home/docs/checkouts/readthedocs.org/user_builds/sound-source-separation-python/envs/v0.1.7/lib/python3.10/site-packages (0.1.7)
Requirement already satisfied: numpy in /home/docs/checkouts/readthedocs.org/user_builds/sound-source-separation-python/envs/v0.1.7/lib/python3.10/site-packages (from ssspy) (1.26.0)
[2]:
import numpy as np
import scipy.signal as ss
import IPython.display as ipd
import matplotlib.pyplot as plt

Here, we use sample music data: reverberant violin and piano recordings sampled at 8 kHz.

[3]:
import os
import urllib.request
from typing import Tuple

import ssspy

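def download_sample_music_data(
    n_sources: int = 2,
    conv: bool = True,
    branch: str = "main"
) -> Tuple[np.ndarray, int]:
    """Download sample music and return the source images and the sampling rate.

    ``conv`` is accepted but not used in this function; the reverberant
    recordings are always loaded. The returned array has shape
    (n_channels, n_sources, n_samples).
    """
    instruments = ["violin", "piano"]
    root = ".data"
    url_template = "https://github.com/tky823/ssspy-data/raw/{branch}/{path}"
    path_template = "audio/{instrument}_8k_reverbed.wav"
    waveforms = []

    for src_idx in range(1, n_sources + 1):
        instrument = instruments[src_idx - 1]
        path = path_template.format(instrument=instrument)
        download_path = os.path.join(root, path)
        url = url_template.format(branch=branch, path=path)

        os.makedirs(os.path.dirname(download_path), exist_ok=True)

        # Download only if the file is not already cached locally.
        if not os.path.exists(download_path):
            urllib.request.urlretrieve(url, download_path)

        # waveform: (n_channels, n_samples)
        waveform, sample_rate = ssspy.wavread(download_path, channels_first=True)

        assert sample_rate == 8000, f"Sampling rate is expected 8000, but {sample_rate} is given."

        waveforms.append(waveform)

    # Stack the sources along axis 1: (n_channels, n_sources, n_samples)
    waveforms = np.stack(waveforms, axis=1)

    # Keep the first n_sources channels so that n_channels == n_sources.
    return waveforms[:n_sources], 8000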
[4]:
n_sources = 2
n_fft, hop_length = 2048, 512
window = "hann"
[5]:
waveform_src_img, sample_rate = download_sample_music_data(
    n_sources=n_sources,
    conv=True,
)  # (n_channels, n_sources, n_samples)
waveform_mix = np.sum(waveform_src_img, axis=1)  # (n_channels, n_samples)
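
The shape convention in the cell above can be checked without downloading anything. The following is a minimal sketch that uses random arrays as a stand-in for the source images; the sizes and the names `source_images` and `mixture` are assumptions chosen for illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)
n_channels, n_sources, n_samples = 2, 2, 16000  # arbitrary illustrative sizes

# Stand-in for waveform_src_img: each source as observed at each microphone.
source_images = rng.standard_normal((n_channels, n_sources, n_samples))

# Summing over the source axis (axis=1) gives the signal recorded by each microphone.
mixture = np.sum(source_images, axis=1)

print(mixture.shape)  # (2, 16000)
```

Each row of `mixture` is the sample-wise sum of all source images at that channel, which is the same operation applied to `waveform_src_img` above.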

Let’s listen to the mixtures.

[6]:
for idx, waveform in enumerate(waveform_mix):
    print("Mixture: {}".format(idx + 1))
    display(ipd.Audio(waveform, rate=sample_rate))
    print()
Mixture: 1

Mixture: 2

The ssspy.bss.ilrma.GaussILRMA class implements Gauss-ILRMA.

[7]:
from ssspy.bss.ilrma import GaussILRMA
[8]:
ilrma = GaussILRMA(n_basis=3, rng=np.random.default_rng(0))
print(ilrma)
GaussILRMA(n_basis=3, spatial_algorithm=IP, source_algorithm=MM, domain=2, partitioning=False, normalization=True, scale_restoration=True, record_loss=True, reference_id=0)
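
The `n_basis=3` argument sets how many basis vectors are used in the low-rank model of each source. As a conceptual sketch only (random matrices, hypothetical variable names, and not ssspy's internal code), the variance of each source spectrogram in ILRMA is modeled as a product of a nonnegative basis matrix and a nonnegative activation matrix:

```python
import numpy as np

rng = np.random.default_rng(0)
n_sources, n_bins, n_frames, n_basis = 2, 1025, 40, 3  # illustrative sizes

# Nonnegative bases (spectral patterns) and activations (time-varying gains),
# one set per source. These are random placeholders, not learned values.
basis = rng.random((n_sources, n_bins, n_basis))
activation = rng.random((n_sources, n_basis, n_frames))

# Modeled variance of each source: a sum of n_basis rank-1 spectrograms.
variance = basis @ activation  # (n_sources, n_bins, n_frames)

print(variance.shape)  # (2, 1025, 40)
print(np.linalg.matrix_rank(variance[0]))  # at most n_basis
```

A small `n_basis` forces each source to be explained by a few repeating spectral patterns, which suits music signals with recurring notes.
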
[9]:
_, _, spectrogram_mix = ss.stft(waveform_mix, window=window, nperseg=n_fft, noverlap=n_fft-hop_length)
[10]:
spectrogram_est = ilrma(spectrogram_mix, n_iter=500)
[11]:
_, waveform_est = ss.istft(spectrogram_est, window=window, nperseg=n_fft, noverlap=n_fft-hop_length)
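
The STFT and inverse STFT settings used above can be sanity-checked on their own. The sketch below applies the same `scipy.signal` calls to a synthetic two-channel signal (an assumption standing in for the real mixture) and checks that the round trip recovers the input.

```python
import numpy as np
import scipy.signal as ss

n_fft, hop_length = 2048, 512
window = "hann"

rng = np.random.default_rng(0)
waveform = rng.standard_normal((2, 16000))  # (n_channels, n_samples), synthetic stand-in

# The STFT is applied along the last axis: (n_channels, n_bins, n_frames)
_, _, spectrogram = ss.stft(
    waveform, window=window, nperseg=n_fft, noverlap=n_fft - hop_length
)

# The inverse transform may return extra padded samples, so trim to the input length.
_, waveform_rec = ss.istft(
    spectrogram, window=window, nperseg=n_fft, noverlap=n_fft - hop_length
)
waveform_rec = waveform_rec[..., : waveform.shape[-1]]

print(spectrogram.shape[:2])  # (2, 1025)
print(np.allclose(waveform, waveform_rec))
```

Because of this padding, `waveform_est` may be slightly longer than `waveform_mix`; trimming it the same way is advisable before comparing the two sample by sample.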

Let’s listen to the estimated sources.

[12]:
for idx, waveform in enumerate(waveform_est):
    print("Estimated source: {}".format(idx + 1))
    display(ipd.Audio(waveform, rate=sample_rate))
    print()
Estimated source: 1

Estimated source: 2

Plot the cost at each iteration.

[13]:
plt.figure()
plt.plot(range(len(ilrma.loss)), ilrma.loss)
plt.xlabel("Iteration")
plt.ylabel("Cost")
plt.show()
plt.close()
[Figure: cost versus iteration]

Because the cost drops sharply at the first update, let’s plot the cost again, starting from the first iteration.

[14]:
plt.figure()
plt.plot(range(1, len(ilrma.loss)), ilrma.loss[1:])
plt.xlabel("Iteration")
plt.ylabel("Cost")
plt.show()
plt.close()
[Figure: cost versus iteration, starting from the first iteration]
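
To quantify how much of the decrease comes from the first update, one option is to look at the differences between consecutive cost values. The sketch below uses made-up numbers (hypothetical values, not results from this notebook) and assumes, as the plots above suggest, that the first entry of the cost list is the value before any update.

```python
import numpy as np

# Hypothetical cost values for illustration only; not results from this notebook.
loss = np.array([100.0, 20.0, 18.0, 17.0, 16.5])

# Decrease achieved by each update: loss[i] - loss[i + 1].
decrease = -np.diff(loss)

# Fraction of the total decrease contributed by the first update.
first_update_share = decrease[0] / (loss[0] - loss[-1])

print(decrease)
print(first_update_share)
```

The same two lines can be applied to `np.asarray(ilrma.loss)` to inspect the actual run.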