Core

core

Reusable, pure algorithmic components for inference and training.

Classes:

Name Description
Audio
NormalizationStats

Statistics for normalizing and denormalizing audio.

NormalizedAudio

Container for normalized audio and its original stats.

Functions:

Name Description
normalize_audio

Preprocess the raw audio in the time domain to have a mean of 0 and a std of 1 before passing it to the model.

denormalize_audio

Take the model output and restore it to its original loudness.

generate_chunks

Generates batches of overlapping chunks from an audio tensor.

stitch_chunks

Stitches processed audio chunks back together using the overlap-add method.

derive_stems

Derive new stems from the separated stems and the mixture by subtraction or summation.

get_dtype

Convert a Dtype string into the corresponding torch.dtype.

Attributes:

Name Type Description
Samples TypeAlias

Number of samples in the audio signal.

SampleRate TypeAlias

The number of samples of audio recorded per second (hertz).

Channels TypeAlias

Number of audio streams.

FileFormat TypeAlias
AudioEncoding TypeAlias
BitDepth TypeAlias

Number of bits of information in each sample.

RawAudioTensor

Time domain tensor of audio samples.

NormalizedAudioTensor

A mixture tensor that has been normalized using on-the-fly statistics.

ComplexSpectrogram

A complex-valued representation of audio's frequency content over time via the STFT.

HopSize TypeAlias

The step size, in samples, between the start of consecutive chunks.

WindowShape TypeAlias

The shape of the window function applied to each chunk before computing the STFT.

FftSize TypeAlias

The number of frequency bins in the STFT, controlling the frequency resolution.

Bands TypeAlias
BatchSize TypeAlias

The number of chunks processed simultaneously by the GPU.

Dtype TypeAlias
PaddingMode TypeAlias

The method used to pad the audio before chunking, crucial for handling the edges of the audio signal.

ChunkDuration TypeAlias

The length of an audio segment, in seconds, processed by the model at one time.

OverlapRatio TypeAlias

The fraction of a chunk that overlaps with the next one.

Padding TypeAlias

Samples to add to the beginning and end of each chunk.

PaddedChunkedAudioTensor

A batch of audio chunks from a padded source.

NumModelStems TypeAlias

The number of stems the model outputs. This should be the length of splifft.models.ModelConfigLike.output_stem_names.

SeparatedChunkedTensor

A batch of separated audio chunks from the model.

WindowTensor

A 1D tensor representing a window function.

RawSeparatedTensor

The final, stitched, raw-domain separated audio.

Sdr TypeAlias

Signal-to-Distortion Ratio (decibels). Higher is better.

SiSdr TypeAlias

Scale-Invariant SDR (SI-SDR) is invariant to scaling errors (decibels). Higher is better.

L1Norm TypeAlias

L1 norm (mean absolute error) between two signals (dimensionless). Lower is better.

DbDifferenceMel TypeAlias

Difference in the dB-scaled mel spectrogram.

Bleedless TypeAlias

A metric to quantify the amount of "bleeding" from other sources. Higher is better.

Fullness TypeAlias

A metric to quantify how much of the original source is missing. Higher is better.

Audio dataclass

Audio(data: _AudioTensorLike, sample_rate: SampleRate)

Bases: Generic[_AudioTensorLike]

Attributes:

Name Type Description
data _AudioTensorLike

This should either be a raw or a normalized audio tensor.

sample_rate SampleRate

data instance-attribute

data: _AudioTensorLike

This should either be a raw or a normalized audio tensor.

sample_rate instance-attribute

sample_rate: SampleRate

NormalizationStats dataclass

NormalizationStats(
    mean: float, std: Annotated[float, Gt(0)]
)

Statistics for normalizing and denormalizing audio.

Attributes:

Name Type Description
mean float

Mean \(\mu\) of the mixture

std Annotated[float, Gt(0)]

Standard deviation \(\sigma\) of the mixture

mean instance-attribute

mean: float

Mean \(\mu\) of the mixture

std instance-attribute

std: Annotated[float, Gt(0)]

Standard deviation \(\sigma\) of the mixture

NormalizedAudio dataclass

NormalizedAudio(
    audio: Audio[NormalizedAudioTensor],
    stats: NormalizationStats,
)

Container for normalized audio and its original stats.

Attributes:

Name Type Description
audio Audio[NormalizedAudioTensor]
stats NormalizationStats

audio instance-attribute

stats instance-attribute

normalize_audio

normalize_audio(
    audio: Audio[RawAudioTensor],
) -> NormalizedAudio

Preprocess the raw audio in the time domain to have a mean of 0 and a std of 1 before passing it to the model.

Operates on the mean of the channels.

Source code in src/splifft/core.py
def normalize_audio(audio: Audio[RawAudioTensor]) -> NormalizedAudio:
    """Preprocess the raw audio in the time domain to have a mean of 0 and a std of 1
    before passing it to the model.

    Operates on the mean of the [channels][splifft.core.Channels].
    """
    mono_audio = audio.data.mean(dim=0)
    mean = float(mono_audio.mean())
    std = float(mono_audio.std())

    if std <= 1e-8:  # silent audio
        return NormalizedAudio(
            audio=Audio(data=NormalizedAudioTensor(audio.data), sample_rate=audio.sample_rate),
            stats=NormalizationStats(mean, 1.0),
        )

    normalized_data = (audio.data - mean) / std
    return NormalizedAudio(
        audio=Audio(data=NormalizedAudioTensor(normalized_data), sample_rate=audio.sample_rate),
        stats=NormalizationStats(mean, std),
    )

denormalize_audio

denormalize_audio(
    audio_data: NormalizedAudioTensor,
    stats: NormalizationStats,
) -> RawAudioTensor

Take the model output and restore it to its original loudness.

Source code in src/splifft/core.py
def denormalize_audio(
    audio_data: NormalizedAudioTensor, stats: NormalizationStats
) -> RawAudioTensor:
    """Take the model output and restore them to their original loudness."""
    return RawAudioTensor((audio_data * stats.std) + stats.mean)
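
A minimal round-trip sketch, assuming the names documented on this page are importable from splifft.core; the synthetic noise tensor stands in for a real mixture.

import torch

from splifft.core import Audio, RawAudioTensor, denormalize_audio, normalize_audio

mixture = RawAudioTensor(torch.randn(2, 44_100))  # 1 s of stereo noise
normalized = normalize_audio(Audio(data=mixture, sample_rate=44_100))
# ... run the model on normalized.audio.data here ...
restored = denormalize_audio(normalized.audio.data, normalized.stats)
assert torch.allclose(restored, mixture, atol=1e-5)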

generate_chunks

generate_chunks(
    audio_data: RawAudioTensor | NormalizedAudioTensor,
    chunk_size: ChunkSize,
    hop_size: HopSize,
    batch_size: BatchSize,
    *,
    padding_mode: PaddingMode = "reflect",
) -> Iterator[PaddedChunkedAudioTensor]

Generates batches of overlapping chunks from an audio tensor.

Returns:

Type Description
Iterator[PaddedChunkedAudioTensor]

An iterator that yields batches of chunks of shape (B, C, chunk_T).

Source code in src/splifft/core.py
def generate_chunks(
    audio_data: RawAudioTensor | NormalizedAudioTensor,
    chunk_size: ChunkSize,
    hop_size: HopSize,
    batch_size: BatchSize,
    *,
    padding_mode: PaddingMode = "reflect",
) -> Iterator[PaddedChunkedAudioTensor]:
    """Generates batches of overlapping chunks from an audio tensor.

    :return: An iterator that yields batches of chunks of shape (B, C, chunk_T).
    """
    padding = chunk_size - hop_size
    padded_audio = F.pad(audio_data, (padding, padding), mode=padding_mode)

    unfolded = padded_audio.unfold(
        dimension=-1, size=chunk_size, step=hop_size
    )  # (C, num_chunks, chunk_size)

    num_chunks = unfolded.shape[1]
    unfolded = unfolded.permute(1, 0, 2)  # (num_chunks, C, chunk_size)

    for i in range(0, num_chunks, batch_size):
        yield PaddedChunkedAudioTensor(unfolded[i : i + batch_size])
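
A toy usage sketch with deliberately small sizes (real configurations use chunks of several seconds of audio):

import torch

from splifft.core import RawAudioTensor, generate_chunks

audio = RawAudioTensor(torch.randn(2, 32))  # (channels, samples)
for batch in generate_chunks(audio, chunk_size=8, hop_size=4, batch_size=3):
    print(batch.shape)  # up to (3, 2, 8): (batch, channels, chunk size)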

stitch_chunks

stitch_chunks(
    processed_chunks: Sequence[SeparatedChunkedTensor],
    num_stems: NumModelStems,
    chunk_size: ChunkSize,
    hop_size: HopSize,
    target_num_samples: Samples,
    *,
    window: WindowTensor,
) -> RawSeparatedTensor

Stitches processed audio chunks back together using the overlap-add method.

Reconstructs the full audio signal from a sequence of overlapping, processed chunks. Ensures that the sum of all overlapping windows is constant at every time step: \(\sum_{m=-\infty}^{\infty} w[n - mH] = C\) where \(H\) is the hop size.

Source code in src/splifft/core.py
def stitch_chunks(
    processed_chunks: Sequence[SeparatedChunkedTensor],
    num_stems: NumModelStems,
    chunk_size: ChunkSize,
    hop_size: HopSize,
    target_num_samples: Samples,
    *,
    window: WindowTensor,
) -> RawSeparatedTensor:
    r"""Stitches processed audio chunks back together using the [overlap-add method](https://en.wikipedia.org/wiki/Overlap%E2%80%93add_method).

    Reconstructs the full audio signal from a sequence of overlapping, processed chunks. Ensures
    that the sum of all overlapping windows is constant at every time step:
    $\sum_{m=-\infty}^{\infty} w[n - mH] = C$ where $H$ is the [hop size][splifft.core.HopSize].
    """
    all_chunks = torch.cat(tuple(processed_chunks), dim=0)
    total_chunks, _N, num_channels, _chunk_T = all_chunks.shape
    windowed_chunks = all_chunks * window.view(1, 1, 1, -1)

    # folding: (B, N * C * chunk_T) -> (1, N * C * chunk_T, total_chunks)
    reshaped_for_fold = windowed_chunks.permute(1, 2, 3, 0).reshape(
        1, num_stems * num_channels * chunk_size, total_chunks
    )

    total_length = (total_chunks - 1) * hop_size + chunk_size

    folded = F.fold(
        reshaped_for_fold,
        output_size=(1, total_length),
        kernel_size=(1, chunk_size),
        stride=(1, hop_size),
    )  # (1, N * C, 1, total_length)
    stitched = folded.view(num_stems, num_channels, total_length)

    # normalization for overlap-add
    ones_template = torch.ones(1, 1, total_length, device=window.device)
    unfolded_ones = ones_template.unfold(dimension=-1, size=chunk_size, step=hop_size)
    windowed_unfolded_ones = unfolded_ones * window
    reshaped_for_fold_norm = windowed_unfolded_ones.permute(0, 1, 3, 2).reshape(
        1, chunk_size, total_chunks
    )
    norm_window = F.fold(
        reshaped_for_fold_norm,
        output_size=(1, total_length),
        kernel_size=(1, chunk_size),
        stride=(1, hop_size),
    ).squeeze()

    norm_window.clamp_min_(1e-8)  # for edges where the window sum might be zero
    stitched /= norm_window

    padding = chunk_size - hop_size
    return RawSeparatedTensor(stitched[..., padding : padding + target_num_samples])
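
An end-to-end sketch pairing generate_chunks with stitch_chunks. The "model" here is an identity stand-in (the unsqueeze adds the stem dimension), and the periodic Hann window at 50% overlap is an assumption that satisfies the constant-overlap-add condition:

import torch

from splifft.core import (
    RawAudioTensor,
    SeparatedChunkedTensor,
    WindowTensor,
    generate_chunks,
    stitch_chunks,
)

chunk_size, hop_size = 1024, 512  # 50% overlap
audio = RawAudioTensor(torch.randn(2, 10_000))
window = WindowTensor(torch.hann_window(chunk_size))

separated = [
    SeparatedChunkedTensor(batch.unsqueeze(1))  # identity "model", one stem
    for batch in generate_chunks(audio, chunk_size, hop_size, batch_size=8)
]
stitched = stitch_chunks(
    separated,
    num_stems=1,
    chunk_size=chunk_size,
    hop_size=hop_size,
    target_num_samples=audio.shape[-1],
    window=window,
)
assert stitched.shape == (1, 2, 10_000)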

derive_stems

derive_stems(
    separated_stems: Mapping[
        ModelOutputStemName, RawAudioTensor
    ],
    mixture_input: RawAudioTensor,
    stem_rules: DerivedStemsConfig,
) -> dict[StemName, RawAudioTensor]

It is the caller's responsibility to ensure that all tensors are aligned and have the same shape.

Note

Mixture input and separated stems must first be denormalized.

Source code in src/splifft/core.py
def derive_stems(
    separated_stems: Mapping[ModelOutputStemName, RawAudioTensor],
    mixture_input: RawAudioTensor,
    stem_rules: DerivedStemsConfig,
) -> dict[StemName, RawAudioTensor]:
    """
    It is the caller's responsibility to ensure that all tensors are aligned and have the same shape.

    !!! note
        Mixture input and separated stems must first be [denormalized][splifft.core.denormalize_audio].
    """
    stems = {
        "mixture": RawAudioTensor(mixture_input),  # for subtraction
        **separated_stems,
    }

    for derived_name, rule in stem_rules.items():
        if rule.operation == "subtract":
            minuend = stems.get(rule.stem_name, mixture_input)
            subtrahend = stems.get(rule.by_stem_name, mixture_input)
            stems[derived_name] = RawAudioTensor(minuend - subtrahend)
        elif rule.operation == "sum":
            to_sum = tuple(stems[s] for s in rule.stem_names)
            stems[derived_name] = RawAudioTensor(torch.stack(to_sum).sum(dim=0))

    stems.pop("mixture", None)
    return stems
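
A usage sketch. The rule object below is a hypothetical stand-in exposing only the fields this function reads; in practice rules come from splifft's DerivedStemsConfig:

from dataclasses import dataclass

import torch

from splifft.core import RawAudioTensor, derive_stems

@dataclass
class _SubtractRule:  # hypothetical stand-in for a DerivedStemsConfig entry
    operation: str = "subtract"
    stem_name: str = "mixture"
    by_stem_name: str = "vocals"

mixture = RawAudioTensor(torch.randn(2, 44_100))
vocals = RawAudioTensor(torch.randn(2, 44_100))
stems = derive_stems({"vocals": vocals}, mixture, {"instrumental": _SubtractRule()})
assert torch.allclose(stems["instrumental"], mixture - vocals)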

get_dtype

get_dtype(dtype: Dtype) -> dtype
Source code in src/splifft/core.py
def get_dtype(dtype: Dtype) -> torch.dtype:
    if dtype == "float32":
        return torch.float32
    elif dtype == "float16":
        return torch.float16
    elif dtype == "bfloat16":
        return torch.bfloat16
    else:
        raise ValueError(f"unsupported {dtype=}")

Samples module-attribute

Samples: TypeAlias = Annotated[int, Gt(0)]

Number of samples in the audio signal.

SampleRate module-attribute

SampleRate: TypeAlias = Annotated[int, Gt(0)]

The number of samples of audio recorded per second (hertz).

See concepts for more details.

Channels module-attribute

Channels: TypeAlias = Annotated[int, Gt(0)]

Number of audio streams.

  • 1: Mono audio
  • 2: Stereo (left and right). Models are usually trained on stereo audio.

FileFormat module-attribute

FileFormat: TypeAlias = Literal['flac', 'wav', 'ogg']

AudioEncoding module-attribute

AudioEncoding: TypeAlias = Literal[
    "PCM_S", "PCM_U", "PCM_F", "ULAW", "ALAW"
]

Audio encoding

  • PCM_S: Signed integer linear pulse-code modulation
  • PCM_U: Unsigned integer linear pulse-code modulation
  • PCM_F: Floating-point pulse-code modulation
  • ULAW: μ-law
  • ALAW: a-law

BitDepth module-attribute

BitDepth: TypeAlias = Literal[8, 16, 24, 32, 64]

Number of bits of information in each sample.

It determines the dynamic range of the audio signal: the difference between the quietest and loudest possible sounds.

  • 16-bit: Standard for CD audio: ~96 dB dynamic range.
  • 24-bit: Common in professional audio, allowing for more headroom during mixing.
  • 32-bit float: Standard in digital audio workstations (DAWs) and deep learning models. The amplitude is represented by a floating-point number, which prevents clipping (distortion from exceeding the maximum value). This library primarily works with fp32 tensors.

RawAudioTensor module-attribute

RawAudioTensor = NewType('RawAudioTensor', Tensor)

Time domain tensor of audio samples. Shape (channels, samples)

NormalizedAudioTensor module-attribute

NormalizedAudioTensor = NewType(
    "NormalizedAudioTensor", Tensor
)

A mixture tensor that has been normalized using on-the-fly statistics. Shape (channels, samples)

ComplexSpectrogram module-attribute

ComplexSpectrogram = NewType('ComplexSpectrogram', Tensor)

A complex-valued representation of audio's frequency content over time via the STFT.

Shape (channels, frequency bins, time frames, 2)

See concepts for more details.

HopSize module-attribute

HopSize: TypeAlias = Annotated[int, Gt(0)]

The step size, in samples, between the start of consecutive chunks.

To avoid artifacts at the edges of chunks, we process them with overlap. The hop size is the distance we "slide" the chunking window forward. HopSize < ChunkSize implies overlap, and the overlap amount is ChunkSize - HopSize.

WindowShape module-attribute

WindowShape: TypeAlias = Literal[
    "hann", "hamming", "linear_fade"
]

The shape of the window function applied to each chunk before computing the STFT.

FftSize module-attribute

FftSize: TypeAlias = Annotated[int, Gt(0)]

The number of frequency bins in the STFT, controlling the frequency resolution.

Bands module-attribute

Bands: TypeAlias = Tensor

BatchSize module-attribute

BatchSize: TypeAlias = Annotated[int, Gt(0)]

The number of chunks processed simultaneously by the GPU.

Increasing the batch size can improve GPU utilisation and speed up training, but it requires more memory.

Dtype module-attribute

Dtype: TypeAlias = Literal["float32", "float16", "bfloat16"]

PaddingMode module-attribute

PaddingMode: TypeAlias = Literal[
    "reflect", "constant", "replicate"
]

The method used to pad the audio before chunking, crucial for handling the edges of the audio signal.

  • reflect: Pads the signal by reflecting the audio at the boundary. This creates a smooth continuation and often yields the best results for music.
  • constant: Pads with zeros. Simpler, but can introduce silence at the edges.
  • replicate: Repeats the last sample at the edge.
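
The three modes on a toy signal (F.pad needs a batch/channel-shaped input for reflect and replicate):

import torch
import torch.nn.functional as F

x = torch.arange(1.0, 5.0).view(1, 1, 4)  # [[[1., 2., 3., 4.]]]
print(F.pad(x, (2, 2), mode="reflect"))    # [[[3., 2., 1., 2., 3., 4., 3., 2.]]]
print(F.pad(x, (2, 2), mode="constant"))   # [[[0., 0., 1., 2., 3., 4., 0., 0.]]]
print(F.pad(x, (2, 2), mode="replicate"))  # [[[1., 1., 1., 2., 3., 4., 4., 4.]]]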

ChunkDuration module-attribute

ChunkDuration: TypeAlias = Annotated[float, Gt(0)]

The length of an audio segment, in seconds, processed by the model at one time.

Equivalent to chunk size divided by the sample rate.

OverlapRatio module-attribute

OverlapRatio: TypeAlias = Annotated[float, Ge(0), Lt(1)]

The fraction of a chunk that overlaps with the next one.

The relationship with hop size is: $$ \text{hop\_size} = \text{chunk\_size} \cdot (1 - \text{overlap\_ratio}) $$

  • A ratio of 0.0 means no overlap (hop_size = chunk_size).
  • A ratio of 0.5 means 50% overlap (hop_size = chunk_size / 2).
  • A higher overlap ratio increases computational cost as more chunks are processed, but it can lead to smoother results by averaging more predictions for each time frame.
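
Worked numbers for a hypothetical 10-second chunk at 44.1 kHz:

sample_rate = 44_100
chunk_size = int(10.0 * sample_rate)  # 441_000 samples (chunk duration of 10 s)
overlap_ratio = 0.5
hop_size = int(chunk_size * (1 - overlap_ratio))  # 220_500
overlap = chunk_size - hop_size  # 220_500 samples shared with the next chunk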

Padding module-attribute

Padding: TypeAlias = Annotated[int, Gt(0)]

Samples to add to the beginning and end of each chunk.

  • To ensure that the very beginning and end of a track can be centered within a chunk, we may add "reflection padding" or "zero padding" before chunking.
  • To ensure that the last chunk is full-size, we may pad the audio so its length is a multiple of the hop size.

PaddedChunkedAudioTensor module-attribute

PaddedChunkedAudioTensor = NewType(
    "PaddedChunkedAudioTensor", Tensor
)

A batch of audio chunks from a padded source. Shape (batch size, channels, chunk size)

NumModelStems module-attribute

NumModelStems: TypeAlias = Annotated[int, Gt(0)]

The number of stems the model outputs. This should be the length of splifft.models.ModelConfigLike.output_stem_names.

SeparatedChunkedTensor module-attribute

SeparatedChunkedTensor = NewType(
    "SeparatedChunkedTensor", Tensor
)

A batch of separated audio chunks from the model. Shape (batch size, number of stems, channels, chunk size)

WindowTensor module-attribute

WindowTensor = NewType('WindowTensor', Tensor)

A 1D tensor representing a window function. Shape (chunk size)

RawSeparatedTensor module-attribute

RawSeparatedTensor = NewType('RawSeparatedTensor', Tensor)

The final, stitched, raw-domain separated audio. Shape (number of stems, channels, samples)

Sdr module-attribute

Sdr: TypeAlias = float

Signal-to-Distortion Ratio (decibels). Higher is better.

Measures the ratio of the power of the clean reference signal to the power of all other error components (interference, artifacts, and spatial distortion).

Definition: $$ \text{SDR} = 10 \log_{10} \frac{||\mathbf{s}||^2}{||\mathbf{s} - \mathbf{\hat{s}}||^2}, $$ where:

  • \(\mathbf{s}\): ground truth source signal
  • \(\mathbf{\hat{s}}\): estimated source signal produced by the model
  • \(||\cdot||^2\): squared L2 norm (power) of the signal
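
A direct transcription of the formula in PyTorch; this is an illustrative sketch, not splifft's implementation, and the eps guard against division by zero is an assumption:

import torch

def sdr(reference: torch.Tensor, estimate: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # 10 * log10(||s||^2 / ||s - s_hat||^2), reduced over the last (time) dimension
    signal_power = reference.pow(2).sum(dim=-1)
    error_power = (reference - estimate).pow(2).sum(dim=-1)
    return 10 * torch.log10((signal_power + eps) / (error_power + eps))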

SiSdr module-attribute

SiSdr: TypeAlias = float

Scale-Invariant SDR (SI-SDR) is invariant to scaling errors (decibels). Higher is better.

It projects the estimate onto the reference to find the optimal scaling factor \(\alpha\), creating a scaled reference that best matches the estimate's amplitude.

  • Optimal scaling factor: \(\alpha = \frac{\langle\mathbf{\hat{s}}, \mathbf{s}\rangle}{||\mathbf{s}||^2}\)
  • Scaled reference: \(\mathbf{s}_\text{target} = \alpha \cdot \mathbf{s}\)
  • Error: \(\mathbf{e} = \mathbf{\hat{s}} - \mathbf{s}_\text{target}\)
  • \(\text{SI-SDR} = 10 \log_{10} \frac{||\mathbf{s}_\text{target}||^2}{||\mathbf{e}||^2}\)
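
The same steps in PyTorch, again as an illustrative sketch rather than splifft's implementation:

import torch

def si_sdr(reference: torch.Tensor, estimate: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # optimal scaling factor: alpha = <s_hat, s> / ||s||^2
    alpha = (estimate * reference).sum(dim=-1, keepdim=True) / (
        reference.pow(2).sum(dim=-1, keepdim=True) + eps
    )
    s_target = alpha * reference  # scaled reference
    error = estimate - s_target
    return 10 * torch.log10(
        (s_target.pow(2).sum(dim=-1) + eps) / (error.pow(2).sum(dim=-1) + eps)
    )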

L1Norm module-attribute

L1Norm: TypeAlias = float

L1 norm (mean absolute error) between two signals (dimensionless). Lower is better.

Measures the average absolute difference between the reference and estimated signals.

  • Time domain: \(\mathcal{L}_\text{L1} = \frac{1}{N} \sum_{n=1}^{N} |\mathbf{s}[n] - \mathbf{\hat{s}}[n]|\),
  • Frequency domain: \(\mathcal{L}_\text{L1Freq} = \frac{1}{\text{MK}}\sum_{m=1}^{M} \sum_{k=1}^{K} \left||S(m, k)| - |\hat{S}(m, k)|\right|\)
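
Both variants are short in PyTorch; the frequency-domain version assumes complex spectrograms have already been computed (e.g. with torch.stft):

import torch

def l1_time(reference: torch.Tensor, estimate: torch.Tensor) -> torch.Tensor:
    return (reference - estimate).abs().mean()  # mean absolute error over samples

def l1_freq(ref_spec: torch.Tensor, est_spec: torch.Tensor) -> torch.Tensor:
    return (ref_spec.abs() - est_spec.abs()).abs().mean()  # compares magnitudes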

DbDifferenceMel module-attribute

DbDifferenceMel: TypeAlias = float

Difference in the dB-scaled mel spectrogram. $$ \mathbf{D}(m, k) = \text{dB}(|\hat{S}_\text{mel}(m, k)|) - \text{dB}(|S_\text{mel}(m, k)|) $$

Bleedless module-attribute

Bleedless: TypeAlias = float

A metric to quantify the amount of "bleeding" from other sources. Higher is better.

Measures the average energy of the parts of the estimate's mel spectrogram that are louder than the reference; this excess energy is bleed from other sources, and a higher Bleedless score corresponds to less of it: $$ \text{Bleed} = \text{mean}(\mathbf{D}(m, k)) \quad \forall \quad \mathbf{D}(m, k) > 0 $$

Fullness module-attribute

Fullness: TypeAlias = float

A metric to quantify how much of the original source is missing. Higher is better.

Complementary to Bleedless. Measures the average energy of the parts of the mel spectrogram that are quieter than the reference, i.e. energy of the target source that was lost during separation; a higher Fullness score indicates that more of the original source's character is preserved. $$ \text{Fullness} = \text{mean}(|\mathbf{D}(m, k)|) \quad \forall \quad \mathbf{D}(m, k) < 0 $$
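
A sketch of both statistics, assuming dB-scaled mel spectrograms are precomputed (e.g. with torchaudio's MelSpectrogram followed by AmplitudeToDB); how these raw means map onto the final higher-is-better Bleedless and Fullness scores is not specified here:

import torch

def bleed_and_fullness(
    db_mel_ref: torch.Tensor, db_mel_est: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    d = db_mel_est - db_mel_ref  # D(m, k): positive where the estimate is louder
    bleed = d[d > 0].mean()  # average excess energy bled in from other sources
    fullness = d[d < 0].abs().mean()  # average energy missing from the target
    return bleed, fullness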