Skip to content

IO

io

Operations for reading and writing to disk. All side effects should go here.

Functions:

Name Description
read_audio

Loads, resamples and converts channels.

load_weights

Load the weights from a checkpoint into the given model.

read_audio

read_audio(
    file: StrPath,
    target_sr: SampleRate,
    target_channels: int | None,
    device: device | None = None,
) -> Audio[RawAudioTensor]

Loads, resamples and converts channels.

Source code in src/splifft/io.py
18
19
20
21
22
23
24
25
26
27
28
29
def read_audio(
    file: t.StrPath,
    target_sr: t.SampleRate,
    target_channels: int | None,
    device: torch.device | None = None,
) -> Audio[t.RawAudioTensor]:
    """Decode an audio file, resampling and converting channels as requested.

    Args:
        file: Audio source handed to the decoder.
        target_sr: Sample rate requested from the decoder.
        target_channels: Channel count requested from the decoder (forwarded
            as `num_channels`); `None` is passed through unchanged.
        device: Device the decoded samples are moved to; `None` is passed
            straight to `Tensor.to`.

    Returns:
        The decoded waveform paired with the sample rate reported by the
        decoder.
    """
    # Resampling and channel conversion are delegated to the decoder itself.
    decoded = AudioDecoder(
        source=file,
        sample_rate=target_sr,
        num_channels=target_channels,
    ).get_all_samples()

    raw = t.RawAudioTensor(decoded.data.to(device))
    return Audio(raw, decoded.sample_rate)

load_weights

load_weights(
    model: ModelT,
    checkpoint_file: StrPath | bytes,
    device: device | str,
    *,
    strict: bool = False,
) -> ModelT

Load the weights from a checkpoint into the given model.

Source code in src/splifft/io.py
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
def load_weights(
    model: ModelT,
    checkpoint_file: t.StrPath | bytes,
    device: torch.device | str,
    *,
    strict: bool = False,
) -> ModelT:
    """Load the weights from a checkpoint into the given model.

    Checkpoints saved from a `torch.nn.DataParallel` or
    `DistributedDataParallel` wrapper prefix every key with `module.`. With
    `strict=False` such keys match nothing, so the model would silently keep
    its initial weights. The prefix is therefore stripped before loading,
    unless the model itself has keys under `module.`.

    Args:
        model: Model whose parameters and buffers are overwritten in place.
        checkpoint_file: Checkpoint location, passed to `torch.load` unchanged.
        device: Device the checkpoint tensors are mapped to and the model is
            moved to.
        strict: Forwarded to `load_state_dict`; when true, missing or
            unexpected keys raise instead of being ignored.

    Returns:
        The same model, moved to `device`.
    """
    state_dict = torch.load(checkpoint_file, map_location=device, weights_only=True)

    wrapper_prefix = "module."
    if (
        isinstance(state_dict, dict)
        and any(
            key.startswith(wrapper_prefix)
            for key in state_dict
            if isinstance(key, str)
        )
        # A model with a genuine `module` submodule expects the prefix; keep it.
        and not any(key.startswith(wrapper_prefix) for key in model.state_dict())
    ):
        # Strips the prefix in place and keeps the `_metadata` entries in sync.
        torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
            state_dict, wrapper_prefix
        )

    model.load_state_dict(state_dict, strict=strict)
    # NOTE: do not torch.compile here!

    return model.to(device)