
Core

core

Reusable, pure algorithmic components for inference and training.

Classes:

Name Description
Audio

Container for an audio tensor and its sample rate.

NormalizationStats

Statistics for normalizing and denormalizing audio.

NormalizedAudio

Container for normalized audio and its original stats.

ModelWaveformToWaveform

Wraps a model with preprocessing and postprocessing into a waveform-to-waveform module.

Functions:

Name Description
normalize_audio

Preprocess the raw audio in the time domain to have a mean of 0 and a std of 1 before passing it to the model.

denormalize_audio

Take the model output and restore it to its original loudness.

generate_chunks

Generates batches of overlapping chunks from an audio tensor.

stitch_chunks

Stitches processed audio chunks back together using the overlap-add method.

apply_mask

Applies a complex mask to a spectrogram.

create_w2w_model

Wrap a model in a ModelWaveformToWaveform, adding STFT preprocessing and iSTFT/masking postprocessing as its input and output types require.

derive_stems

Derive additional stems from the separated stems and the mixture using sum or subtraction rules; the caller must ensure all tensors are aligned and share the same shape.

str_to_torch_dtype

Convert a dtype name such as "float32" into the corresponding torch.dtype.

Audio dataclass

Audio(data: _AudioTensorLike, sample_rate: SampleRate)

Bases: Generic[_AudioTensorLike]

Attributes:

Name Type Description
data _AudioTensorLike

This should either be a raw or a normalized audio tensor.

sample_rate SampleRate

data instance-attribute

data: _AudioTensorLike

This should either be a raw or a normalized audio tensor.

sample_rate instance-attribute

sample_rate: SampleRate

NormalizationStats dataclass

NormalizationStats(
    mean: float, std: Annotated[float, Gt(0)]
)

Statistics for normalizing and denormalizing audio.

Attributes:

Name Type Description
mean float

Mean \(\mu\) of the mixture

std Annotated[float, Gt(0)]

Standard deviation \(\sigma\) of the mixture

mean instance-attribute

mean: float

Mean \(\mu\) of the mixture

std instance-attribute

std: Annotated[float, Gt(0)]

Standard deviation \(\sigma\) of the mixture

NormalizedAudio dataclass

NormalizedAudio(
    audio: Audio[NormalizedAudioTensor],
    stats: NormalizationStats,
)

Container for normalized audio and its original stats.

Attributes:

Name Type Description
audio Audio[NormalizedAudioTensor]
stats NormalizationStats

audio instance-attribute

stats instance-attribute

normalize_audio

normalize_audio(
    audio: Audio[RawAudioTensor],
) -> NormalizedAudio

Preprocess the raw audio in the time domain to have a mean of 0 and a std of 1 before passing it to the model.

Operates on the mean of the channels.

Source code in src/splifft/core.py
def normalize_audio(audio: Audio[t.RawAudioTensor]) -> NormalizedAudio:
    """Preprocess the raw audio in the time domain to have a mean of 0 and a std of 1
    before passing it to the model.

    Operates on the mean of the [channels][splifft.types.Channels].
    """
    mono_audio = audio.data.mean(dim=0)
    mean = float(mono_audio.mean())
    std = float(mono_audio.std())

    if std <= 1e-8:  # silent audio
        return NormalizedAudio(
            audio=Audio(data=t.NormalizedAudioTensor(audio.data), sample_rate=audio.sample_rate),
            stats=NormalizationStats(mean, 1.0),
        )

    normalized_data = (audio.data - mean) / std
    return NormalizedAudio(
        audio=Audio(data=t.NormalizedAudioTensor(normalized_data), sample_rate=audio.sample_rate),
        stats=NormalizationStats(mean, std),
    )

denormalize_audio

denormalize_audio(
    audio_data: NormalizedAudioTensor,
    stats: NormalizationStats,
) -> RawAudioTensor

Take the model output and restore it to its original loudness.

Source code in src/splifft/core.py
def denormalize_audio(
    audio_data: t.NormalizedAudioTensor, stats: NormalizationStats
) -> t.RawAudioTensor:
    """Take the model output and restore them to their original loudness."""
    return t.RawAudioTensor((audio_data * stats.std) + stats.mean)
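
The following is a minimal round-trip sketch of normalize_audio and denormalize_audio. It assumes the splifft.types wrappers (RawAudioTensor, NormalizedAudioTensor, SampleRate) are identity NewTypes at runtime, so plain values can be passed for illustration.

import torch

from splifft.core import Audio, denormalize_audio, normalize_audio

raw = torch.randn(2, 44_100) * 0.1 + 0.02  # quiet, offset stereo audio of shape (C, T)
normalized = normalize_audio(Audio(data=raw, sample_rate=44_100))

mono = normalized.audio.data.mean(dim=0)
assert abs(float(mono.mean())) < 1e-4       # roughly zero mean on the channel average
assert abs(float(mono.std()) - 1.0) < 1e-3  # roughly unit standard deviation

restored = denormalize_audio(normalized.audio.data, normalized.stats)
assert torch.allclose(restored, raw, atol=1e-5)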

generate_chunks

generate_chunks(
    audio_data: RawAudioTensor | NormalizedAudioTensor,
    chunk_size: ChunkSize,
    hop_size: HopSize,
    batch_size: BatchSize,
    *,
    padding_mode: PaddingMode = "reflect",
) -> Iterator[PaddedChunkedAudioTensor]

Generates batches of overlapping chunks from an audio tensor.

Returns:

Type Description
Iterator[PaddedChunkedAudioTensor]

An iterator that yields batches of chunks of shape (B, C, chunk_T).

Source code in src/splifft/core.py
def generate_chunks(
    audio_data: t.RawAudioTensor | t.NormalizedAudioTensor,
    chunk_size: t.ChunkSize,
    hop_size: t.HopSize,
    batch_size: t.BatchSize,
    *,
    padding_mode: t.PaddingMode = "reflect",
) -> Iterator[t.PaddedChunkedAudioTensor]:
    """Generates batches of overlapping chunks from an audio tensor.

    :return: An iterator that yields batches of chunks of shape (B, C, chunk_T).
    """
    padding = chunk_size - hop_size
    padded_audio = F.pad(audio_data, (padding, padding), mode=padding_mode)

    padded_len = padded_audio.shape[-1]
    rem = (padded_len - chunk_size) % hop_size
    if rem != 0:
        final_pad = hop_size - rem
        padded_audio = F.pad(padded_audio, (0, final_pad), mode="constant", value=0)

    unfolded = padded_audio.unfold(
        dimension=-1, size=chunk_size, step=hop_size
    )  # (C, num_chunks, chunk_size)

    num_chunks = unfolded.shape[1]
    unfolded = unfolded.permute(1, 0, 2)  # (num_chunks, C, chunk_size)

    for i in range(0, num_chunks, batch_size):
        yield t.PaddedChunkedAudioTensor(unfolded[i : i + batch_size])

stitch_chunks

stitch_chunks(
    processed_chunks: Sequence[SeparatedChunkedTensor],
    num_stems: NumModelStems,
    chunk_size: ChunkSize,
    hop_size: HopSize,
    target_num_samples: Samples,
    *,
    window: WindowTensor,
) -> RawSeparatedTensor

Stitches processed audio chunks back together using the overlap-add method.

Reconstructs the full audio signal from a sequence of overlapping, processed chunks. Ensures that the sum of all overlapping windows is constant at every time step: \(\sum_{m=-\infty}^{\infty} w[n - mH] = C\) where \(H\) is the hop size.

Source code in src/splifft/core.py
def stitch_chunks(
    processed_chunks: Sequence[t.SeparatedChunkedTensor],
    num_stems: t.NumModelStems,
    chunk_size: t.ChunkSize,
    hop_size: t.HopSize,
    target_num_samples: t.Samples,
    *,
    window: t.WindowTensor,
) -> t.RawSeparatedTensor:
    r"""Stitches processed audio chunks back together using the [overlap-add method](https://en.wikipedia.org/wiki/Overlap%E2%80%93add_method).

    Reconstructs the full audio signal from a sequence of overlapping, processed chunks. Ensures
    that the sum of all overlapping windows is constant at every time step:
    $\sum_{m=-\infty}^{\infty} w[n - mH] = C$ where $H$ is the [hop size][splifft.types.HopSize].
    """
    all_chunks = torch.cat(tuple(processed_chunks), dim=0)
    total_chunks, _N, num_channels, _chunk_T = all_chunks.shape
    windowed_chunks = all_chunks * window.view(1, 1, 1, -1)

    # folding: (B, N * C * chunk_T) -> (1, N * C * chunk_T, total_chunks)
    reshaped_for_fold = windowed_chunks.permute(1, 2, 3, 0).reshape(
        1, num_stems * num_channels * chunk_size, total_chunks
    )

    total_length = (total_chunks - 1) * hop_size + chunk_size

    folded = F.fold(
        reshaped_for_fold,
        output_size=(1, total_length),
        kernel_size=(1, chunk_size),
        stride=(1, hop_size),
    )  # (1, N * C, 1, total_length)
    stitched = folded.view(num_stems, num_channels, total_length)

    # normalization for overlap-add
    windows_to_fold = window.expand(total_chunks, 1, chunk_size)
    reshaped_windows_for_fold = windows_to_fold.permute(1, 2, 0).reshape(
        1, chunk_size, total_chunks
    )
    norm_window = F.fold(
        reshaped_windows_for_fold,
        output_size=(1, total_length),
        kernel_size=(1, chunk_size),
        stride=(1, hop_size),
    ).squeeze(0)

    norm_window.clamp_min_(1e-8)  # for edges where the window sum might be zero
    stitched /= norm_window

    padding = chunk_size - hop_size
    if padding > 0:
        stitched = stitched[..., padding:-padding]

    return t.RawSeparatedTensor(stitched[..., :target_num_samples])
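
As a sanity check, chunking followed by stitching with a Hann window should reproduce the input up to numerical error, because each output sample is normalized by the folded window sum. A sketch under the same assumptions as above, with a stand-in "model" that returns each chunk unchanged as a single stem:

import torch

from splifft.core import generate_chunks, stitch_chunks

chunk_size, hop_size = 4096, 1024
audio = torch.randn(2, 44_100)  # (C, T)

# (B, C, chunk_T) -> (B, num_stems=1, C, chunk_T), mimicking a separation model's output
separated = [
    batch.unsqueeze(1)
    for batch in generate_chunks(audio, chunk_size, hop_size, batch_size=8)
]

stitched = stitch_chunks(
    separated,
    num_stems=1,
    chunk_size=chunk_size,
    hop_size=hop_size,
    target_num_samples=audio.shape[-1],
    window=torch.hann_window(chunk_size),
)
assert stitched.shape == (1, 2, 44_100)  # (num_stems, C, T)
assert torch.allclose(stitched[0], audio, atol=1e-4)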

apply_mask

apply_mask(
    spec_for_masking: ComplexSpectrogram,
    mask_batch: ComplexSpectrogram,
    mask_add_sub_dtype: dtype | None,
    mask_out_dtype: dtype | None,
) -> SeparatedSpectrogramTensor

Applies a complex mask to a spectrogram.

While this could simply be a complex multiplication via torch.view_as_complex, CoreML does not support that operator (https://github.com/apple/coremltools/issues/2003), so we hand-roll the arithmetic.

Source code in src/splifft/core.py
def apply_mask(
    spec_for_masking: t.ComplexSpectrogram,
    mask_batch: t.ComplexSpectrogram,
    mask_add_sub_dtype: torch.dtype | None,
    mask_out_dtype: torch.dtype | None,
) -> t.SeparatedSpectrogramTensor:
    """Applies a complex mask to a spectrogram.

    While this can be simply replaced by a complex multiplication and `torch.view_as_complex`,
    CoreML does not support it: https://github.com/apple/coremltools/issues/2003 so we handroll our
    own.
    """
    spec_real = spec_for_masking[..., 0]
    spec_imag = spec_for_masking[..., 1]
    mask_real = mask_batch[..., 0]
    mask_imag = mask_batch[..., 1]

    # see: 14385, 14401, 14392, 14408
    ac = spec_real * mask_real
    bd = spec_imag * mask_imag
    ad = spec_real * mask_imag
    bc = spec_imag * mask_real

    # see: 509, 506, 505, 504, 741, 747
    out_real = ac.to(mask_add_sub_dtype) - bd.to(mask_add_sub_dtype)
    out_imag = ad.to(mask_add_sub_dtype) + bc.to(mask_add_sub_dtype)

    # see: 503, 501
    separated_spec = torch.stack([out_real, out_imag], dim=-1).to(mask_out_dtype)
    return t.SeparatedSpectrogramTensor(separated_spec)
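
As a sanity check, the hand-rolled arithmetic should match the native complex product on CPU. The (B, C, F, T, 2) shape below is only an assumed layout for illustration; what the indexing above relies on is the trailing dimension of size 2 holding the real and imaginary parts.

import torch

from splifft.core import apply_mask

spec = torch.randn(1, 2, 1025, 256, 2)  # assumed (B, C, F, T, real/imag)
mask = torch.randn(1, 2, 1025, 256, 2)

ours = apply_mask(spec, mask, mask_add_sub_dtype=None, mask_out_dtype=None)
reference = torch.view_as_real(torch.view_as_complex(spec) * torch.view_as_complex(mask))
assert torch.allclose(ours, reference, atol=1e-5)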

ModelWaveformToWaveform

ModelWaveformToWaveform(
    model: Module,
    preprocess: PreprocessFn,
    postprocess: PostprocessFn,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
model
preprocess
postprocess
Source code in src/splifft/core.py
def __init__(
    self,
    model: nn.Module,
    preprocess: t.PreprocessFn,
    postprocess: t.PostprocessFn,
):
    super().__init__()
    self.model = model
    self.preprocess = preprocess
    self.postprocess = postprocess

model instance-attribute

model = model

preprocess instance-attribute

preprocess = preprocess

postprocess instance-attribute

postprocess = postprocess

forward

Source code in src/splifft/core.py
def forward(
    self, waveform_chunk: t.RawAudioTensor | t.NormalizedAudioTensor
) -> t.SeparatedChunkedTensor:
    preprocessed_input = self.preprocess(waveform_chunk)
    model_output = self.model(*preprocessed_input)
    return t.SeparatedChunkedTensor(self.postprocess(model_output, *preprocessed_input))
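
A minimal mechanics sketch using an identity model and the default pre/postprocessing lambdas shown in create_w2w_model below. A real separation model would return per-stem chunks of shape (B, N, C, chunk_T) suitable for stitch_chunks:

import torch
from torch import nn

from splifft.core import ModelWaveformToWaveform

w2w = ModelWaveformToWaveform(
    model=nn.Identity(),
    preprocess=lambda chunk: (chunk,),                  # the chunk becomes the model's only argument
    postprocess=lambda model_output, *_: model_output,  # pass the model output through unchanged
)
chunk = torch.randn(8, 2, 4096)  # (B, C, chunk_T)
assert w2w(chunk).shape == (8, 2, 4096)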

create_w2w_model

create_w2w_model(
    model: Module,
    model_input_type: ModelInputType,
    model_output_type: ModelOutputType,
    stft_cfg: StftConfig | None,
    num_channels: Channels,
    chunk_size: ChunkSize,
    masking_cfg: MaskingConfig,
) -> ModelWaveformToWaveform
Source code in src/splifft/core.py
def create_w2w_model(
    model: nn.Module,
    model_input_type: t.ModelInputType,
    model_output_type: t.ModelOutputType,
    stft_cfg: StftConfig | None,
    num_channels: t.Channels,
    chunk_size: t.ChunkSize,
    masking_cfg: MaskingConfig,
) -> ModelWaveformToWaveform:
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device("cpu")

    needs_stft = model_input_type == "spectrogram" or model_input_type == "waveform_and_spectrogram"
    needs_istft = model_output_type == "spectrogram_mask" or model_output_type == "spectrogram"

    if (needs_stft or needs_istft) and stft_cfg is None:
        raise ValueError(
            "expected stft config for models that operate on spectrograms, but found `None`."
        )

    preprocess: t.PreprocessFn = lambda chunk: (chunk,)  # noqa: E731
    postprocess: t.PostprocessFn = lambda model_output, *_: model_output  # noqa: E731

    if needs_stft:
        assert stft_cfg is not None
        stft_module = Stft(
            n_fft=stft_cfg.n_fft,
            hop_length=stft_cfg.hop_length,
            win_length=stft_cfg.win_length,
            window_fn=lambda win_len: _get_window_fn(stft_cfg.window_shape, win_len, device),
            conv_dtype=stft_cfg.conv_dtype,
        ).to(device)
        if model_input_type == "spectrogram":
            preprocess = _create_stft_preprocessor(stft_module)
        elif model_input_type == "waveform_and_spectrogram":
            preprocess = _create_hybrid_preprocessor(stft_module)
        else:
            raise NotImplementedError(f"unsupported input type for stft: {model_input_type}")

    if needs_istft:
        assert stft_cfg is not None
        istft_module = IStft(
            n_fft=stft_cfg.n_fft,
            hop_length=stft_cfg.hop_length,
            win_length=stft_cfg.win_length,
            window_fn=lambda win_len: _get_window_fn(stft_cfg.window_shape, win_len, device),
        ).to(device)
        postprocess = _create_spec_postprocessor(
            istft_module,
            num_channels,
            chunk_size,
            masking_cfg.add_sub_dtype,
            masking_cfg.out_dtype,
            model_output_type,
        )
    return ModelWaveformToWaveform(model, preprocess, postprocess)

derive_stems

derive_stems(
    separated_stems: Mapping[
        ModelOutputStemName, RawAudioTensor
    ],
    mixture_input: RawAudioTensor,
    stem_rules: DerivedStemsConfig,
) -> dict[StemName, RawAudioTensor]

It is the caller's responsibility to ensure that all tensors are aligned and have the same shape.

Note

Mixture input and separated stems must first be denormalized.

Source code in src/splifft/core.py
def derive_stems(
    separated_stems: Mapping[t.ModelOutputStemName, t.RawAudioTensor],
    mixture_input: t.RawAudioTensor,
    stem_rules: DerivedStemsConfig,
) -> dict[StemName, t.RawAudioTensor]:
    """
    It is the caller's responsibility to ensure that all tensors are aligned and have the same shape.

    !!! note
        Mixture input and separated stems must first be [denormalized][splifft.core.denormalize_audio].
    """
    stems = {
        "mixture": t.RawAudioTensor(mixture_input),  # for subtraction
        **separated_stems,
    }

    for derived_name, rule in stem_rules.items():
        if rule.operation == "subtract":
            minuend = stems.get(rule.stem_name, mixture_input)
            subtrahend = stems.get(rule.by_stem_name, mixture_input)
            stems[derived_name] = t.RawAudioTensor(minuend - subtrahend)
        elif rule.operation == "sum":
            to_sum = tuple(stems[s] for s in rule.stem_names)
            stems[derived_name] = t.RawAudioTensor(torch.stack(to_sum).sum(dim=0))

    stems.pop("mixture", None)
    return stems
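
An illustrative sketch of a "subtract" rule. In real use the rules come from a DerivedStemsConfig; the SimpleNamespace below merely stands in for a rule object exposing the attributes read above (operation, stem_name, by_stem_name):

import torch
from types import SimpleNamespace

from splifft.core import derive_stems

vocals = torch.randn(2, 44_100)
mixture = torch.randn(2, 44_100)
rules = {
    "instrumental": SimpleNamespace(
        operation="subtract", stem_name="mixture", by_stem_name="vocals"
    )
}

stems = derive_stems({"vocals": vocals}, mixture, rules)
assert torch.allclose(stems["instrumental"], mixture - vocals)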

str_to_torch_dtype

str_to_torch_dtype(value: Any) -> dtype
Source code in src/splifft/core.py
def str_to_torch_dtype(value: Any) -> torch.dtype:
    if not isinstance(value, str):
        raise TypeError(f"expected dtype to be a string, got {value} (type {type(value)})")
    try:
        dtype = getattr(torch, value)
    except AttributeError:
        raise ValueError(f"`{value}` cannot be found under the `torch` namespace")
    if not isinstance(dtype, torch.dtype):
        raise TypeError(f"expected {dtype} to be a dtype but it is a {type(dtype)}")
    return dtype
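
A usage sketch:

import torch

from splifft.core import str_to_torch_dtype

assert str_to_torch_dtype("float32") is torch.float32
assert str_to_torch_dtype("bfloat16") is torch.bfloat16
# str_to_torch_dtype("float99")  -> ValueError: not found under the `torch` namespace
# str_to_torch_dtype("nn")       -> TypeError: found, but not a torch.dtype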