Skip to content

IO

io

Operations for reading from and writing to disk. All side effects should be confined to this module.

Functions:

| Name | Description |
| --- | --- |
| `read_audio` | Loads, resamples and converts channels. |
| `load_weights` | Load the weights from a checkpoint into the given model. |

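Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `FileLike` | `TypeAlias` | A filesystem path (`Path` or `str`) or an open binary file object. |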

FileLike module-attribute

# File sources accepted by this module's functions: a filesystem path
# (`Path` or `str`) or an already-open binary file object.
FileLike: TypeAlias = Path | str | BinaryIO

read_audio

read_audio(
    file: FileLike,
    target_sr: SampleRate,
    target_channels: int | None,
    device: device | None = None,
) -> Audio[RawAudioTensor]

Loads, resamples and converts channels.

Source code in `src/splifft/io.py` (lines 21–46):
def read_audio(
    file: FileLike,
    target_sr: SampleRate,
    target_channels: int | None,
    device: torch.device | None = None,
) -> Audio[RawAudioTensor]:
    """Loads, resamples and converts channels.

    Args:
        file: Path or binary file object, passed straight to `torchaudio.load`.
        target_sr: Sample rate of the returned audio. The waveform is resampled
            with `torchaudio.transforms.Resample` when the file's rate differs.
        target_channels: Desired channel count, or `None` to keep the file's
            channels. Any multi-channel input can be downmixed to mono (1) by
            averaging; only mono input can be upmixed to stereo (2).
        device: Device the waveform and the resampler are moved to.

    Returns:
        The channels-first waveform wrapped as `Audio` at `target_sr`.

    Raises:
        ValueError: If a channel conversion is needed and `target_channels` is
            neither 1 nor 2, or if non-mono input would have to become stereo.
    """
    waveform, sr = torchaudio.load(file, channels_first=True)
    waveform = waveform.to(device)

    if sr != target_sr:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr).to(device)
        waveform = resampler(waveform)

    current_channels = waveform.shape[0]
    if target_channels is not None and current_channels != target_channels:
        if target_channels == 1:  # downmix: average all channels into one
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        elif target_channels == 2:
            # Duplicating is only meaningful for mono input; repeating e.g. a
            # 6-channel tensor would silently yield 12 channels.
            if current_channels != 1:
                raise ValueError(
                    f"cannot convert to stereo: expected mono input, got {current_channels=}."
                )
            waveform = waveform.repeat(2, 1)  # mono -> stereo
        else:
            raise ValueError(
                f"expected target_channels to be 1 or 2, got {target_channels=} with {current_channels=}."
            )

    return Audio(RawAudioTensor(waveform), target_sr)

load_weights

load_weights(
    model: ModelT,
    checkpoint_file: FileLike,
    device: device | str,
) -> ModelT

Load the weights from a checkpoint into the given model.

Source code in `src/splifft/io.py` (lines 57–70):
def load_weights(
    model: ModelT,
    checkpoint_file: FileLike,
    device: torch.device | str,
) -> ModelT:
    """Load the weights from a checkpoint into the given model.

    The checkpoint is read with `torch.load(..., weights_only=True)`, so it
    must contain a plain state dict. Checkpoints saved from a model wrapped in
    `DataParallel`/`DistributedDataParallel` have every key prefixed with
    ``module.``. That prefix is stripped when the target model's own keys do
    not use it.

    Args:
        model: Module to load the weights into; it is updated in place.
        checkpoint_file: Path or binary file object of the checkpoint.
        device: Device the checkpoint tensors are mapped to and the model is
            moved to.

    Returns:
        `model`, moved to `device`, with the checkpoint weights loaded.

    Raises:
        RuntimeError: Propagated from `load_state_dict` when keys or shapes do
            not match the model.
    """
    state_dict = torch.load(checkpoint_file, map_location=device, weights_only=True)

    # (Distributed)DataParallel checkpoints prefix every key with `module.`.
    # Strip it only if *all* checkpoint keys have it and the model's keys do
    # not, so a model that genuinely has a `module` submodule is left alone.
    # The torch helper also updates the `_metadata` versions in place.
    prefix = "module."
    if (
        isinstance(state_dict, dict)
        and state_dict
        and all(isinstance(key, str) and key.startswith(prefix) for key in state_dict)
        and not any(key.startswith(prefix) for key in model.state_dict())
    ):
        torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, prefix)

    # Move the model first so `load_state_dict` copies straight into parameters
    # already on `device`. Otherwise every tensor would be copied to the model's
    # old device and then moved back again.
    model = model.to(device)
    model.load_state_dict(state_dict)
    # NOTE: do not torch.compile here!

    return model