Skip to content

Tutorial

This page targets two audiences:

  • Users shipping features: use the high-level engine and keep config-driven behavior.
  • Researchers/developers: drop to low-level APIs when you need full control over tensors, transforms, and post-processing.

Basic inference

Use splifft.inference.InferenceEngine.from_pretrained for a convenient, high-level API.

# Basic high-level inference: build an engine from an explicit
# config + checkpoint pair, then separate one mixture file end to end.
from splifft.inference import InferenceEngine

# the input mixture to separate
PATH_MIXTURE = "data/audio/input/3BFTio5296w.flac"

engine = InferenceEngine.from_pretrained(
    config="/path/to/config.json",
    checkpoint_path="/path/to/checkpoint.pt",
)
result = engine.run(PATH_MIXTURE)
print(result)

#
# or, if the model is already installed in the default user cache registry,
# load it by name instead of passing explicit paths:
#

engine = InferenceEngine.from_registry("bs_roformer-fruit-sw")

#
# and to track progress for long files or on slow hardware:
#

import logging

from splifft.inference import InferenceOutput

# timestamped log lines so each progress event shows when it happened
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s", datefmt="[%X]")
logger = logging.getLogger(__name__)

# `stream` yields progress events (stage markers, per-chunk updates) and
# finally an InferenceOutput; print the result and stop consuming there.
for event in engine.stream(PATH_MIXTURE):
    if isinstance(event, InferenceOutput):
        print(event)
        break
    logger.info(event)
[00:00:36] Stage.Started(stage='normalize', total_batches=None)
[00:00:36] Stage.Completed(stage='normalize')
[00:00:38] ChunkProcessed(batch_index=1, total_batches=9)
[00:00:39] ChunkProcessed(batch_index=2, total_batches=9)
[00:00:41] ChunkProcessed(batch_index=3, total_batches=9)
[00:00:42] ChunkProcessed(batch_index=4, total_batches=9)
[00:00:43] ChunkProcessed(batch_index=5, total_batches=9)
[00:00:44] ChunkProcessed(batch_index=6, total_batches=9)
[00:00:46] ChunkProcessed(batch_index=7, total_batches=9)
[00:00:47] ChunkProcessed(batch_index=8, total_batches=9)
[00:00:47] ChunkProcessed(batch_index=9, total_batches=9)
[00:00:47] Stage.Started(stage='stitch', total_batches=None)
[00:00:47] Stage.Completed(stage='stitch')
[00:00:47] Stage.Started(stage='collect_outputs', total_batches=None)
[00:00:47] Stage.Completed(stage='collect_outputs')
[00:00:47] Stage.Started(stage='derive_stems', total_batches=None)
[00:00:47] Stage.Completed(stage='derive_stems')
InferenceOutput(
    outputs={
        'bass': tensor([[-1.3643e-05, -1.3736e-05, -1.3643e-05,  ..., 
-1.3958e-05,
        -1.3730e-05, -1.3960e-05],
        [-1.3811e-05, -1.3586e-05, -1.3811e-05,  ..., -1.3738e-05,
        -1.3953e-05, -1.3736e-05]], device='cuda:0'),
        'drums': tensor([[-1.3493e-05, -1.4200e-05, -1.3493e-05,  ..., 
-1.2080e-05,
        -1.2848e-05, -1.2020e-05],
        [-1.3936e-05, -1.3758e-05, -1.3936e-05,  ..., -1.1843e-05,
        -1.2818e-05, -1.1989e-05]], device='cuda:0'),
        'other': tensor([[-7.5168e-07, -6.3413e-07, -7.5222e-07,  ...,  
1.9690e-05,
        -3.3400e-05,  2.5086e-05],
        [-7.4173e-07, -6.7063e-07, -7.4244e-07,  ...,  3.2220e-05,
        -3.7293e-05,  2.0826e-05]], device='cuda:0'),
        'vocals': tensor([[-1.3789e-05, -1.3904e-05, -1.3789e-05,  ..., 
-1.3930e-05,
        -1.3755e-05, -1.4037e-05],
        [-1.3860e-05, -1.3833e-05, -1.3860e-05,  ..., -1.3848e-05,
        -1.3747e-05, -1.3913e-05]], device='cuda:0'),
        'guitar': tensor([[-1.3846e-05, -1.3846e-05, -1.3846e-05,  ..., 
-1.3928e-05,
        -1.3760e-05, -1.3928e-05],
        [-1.3910e-05, -1.3782e-05, -1.3910e-05,  ..., -1.3871e-05,
        -1.3818e-05, -1.3871e-05]], device='cuda:0'),
        'piano': tensor([[-1.3789e-05, -1.3902e-05, -1.3789e-05,  ..., 
-1.3933e-05,
        -1.3759e-05, -1.3933e-05],
        [-1.3881e-05, -1.3810e-05, -1.3881e-05,  ..., -1.3849e-05,
        -1.3843e-05, -1.3849e-05]], device='cuda:0'),
        'instrum': tensor([[ 1.3789e-05,  1.3904e-05,  1.3789e-05,  ...,  
4.7834e-05,
        -2.5873e-05,  5.2345e-05],
        [ 1.3860e-05,  1.3833e-05,  1.3860e-05,  ...,  6.8972e-05,
        -3.0868e-05,  4.3241e-05]], device='cuda:0')
    },
    sample_rate=44100
)

This outputs splifft.inference.InferenceOutput, containing:

  • a dictionary mapping stem names to tensors (which can be audio waveforms or logits)
  • the sample rate of the input tensor (so you can save the audio)

Low level inference

inference_low_level.py
# ruff: noqa: E402
from pathlib import Path

import torch

PATH_CONFIG = Path("data/config/bs_roformer.json")
PATH_CKPT = Path("data/models/roformer-fp16.pt")
PATH_MIXTURE = Path("data/audio/input/3BFTio5296w.flac")

# 1. parse + validate a JSON *without* having to import a particular pytorch model.
from splifft.config import Config

config = Config.from_file(PATH_CONFIG)

# 2. we now want to *lock in* the configuration to a specific model.
from splifft.models import ModelMetadata
from splifft.models.bs_roformer import BSRoformer, BSRoformerParams

metadata = ModelMetadata(model_type="bs_roformer", params=BSRoformerParams, model=BSRoformer)
# concretize the lazy `config.model` into a typed BSRoformerParams instance
model_params = config.model.to_concrete(metadata.params)

# 3. `metadata` acts as a model builder
from splifft.io import load_weights

model = metadata.model(model_params)
# weights are loaded onto the cpu here; move the model if you want gpu inference
model = load_weights(model, PATH_CKPT, device="cpu")

# 4. load audio and run inference by passing dependencies explicitly.
from splifft.inference import InferenceEngine
from splifft.io import read_audio

# read the mixture at the configured sample rate and channel layout
mixture = read_audio(
    PATH_MIXTURE,
    config.audio_io.target_sample_rate,
    config.audio_io.force_channels,
)
engine = InferenceEngine(
    config=config,
    model=model,
    model_params_concrete=model_params,
    # run the model wherever its weights already live...
    model_device=next(model.parameters()).device,
    # ...but keep input/output audio tensors on the cpu
    io_device=torch.device("cpu"),
)
result = engine.run(mixture)

print(list(result.outputs.keys()))

Extending splifft

splifft is designed to be easily extended without modifying its core.

Make sure you have added splifft as a dependency. Assuming your library has this structure:

tree /path/to/ext_project
├── pyproject.toml
├── scripts
│   └── main.py
└── src
    └── my_library
        └── models
            ├── __init__.py
            └── my_model.py

1. Define a new model

Don't do this

A common pattern is to define a model with a huge list of parameters in its __init__ method:

src/my_library/models/my_model.py
from torch import nn
from beartype import beartype

class MyModel(nn.Module):
    # anti-pattern: the configuration is baked into the constructor's
    # signature, so every new option means editing this argument list and
    # hand-rolling JSON (de)serialization for it
    @beartype
    def __init__(
        self,
        chunk_size: int,
        output_stem_names: tuple[str, ...],
        # ... a bunch of args here
    ):
        ...

The problem is that it tightly couples the model's implementation to its configuration. Serializing to/from a JSON file and simultaneously supporting static type checking is a headache.

Instead, define a stdlib dataclass separate from the model:

src/my_library/models/my_model.py
from dataclasses import dataclass

from torch import nn

from splifft.models import ModelParamsLike
from splifft.types import (
    ChunkSize,
    InferenceArchetype,
    ModelInputType,
    ModelOutputStemName,
    ModelOutputType,
)


@dataclass
class MyModelParams(ModelParamsLike):  # (1)!
    """Serializable configuration for ``MyModel``, kept separate from the module itself."""

    chunk_size: ChunkSize
    output_stem_names: tuple[ModelOutputStemName, ...]

    # ... any other config your model needs
    @property
    def input_type(self) -> ModelInputType:
        # the model consumes raw audio ("waveform") as input
        return "waveform"

    @property
    def output_type(self) -> ModelOutputType:
        # ...and emits raw audio for each stem
        return "waveform"

    @property
    def inference_archetype(self) -> InferenceArchetype:
        return "standard_end_to_end"


class MyModel(nn.Module):
    """A model that takes its entire configuration as a single dataclass."""

    def __init__(self, params: MyModelParams) -> None:
        super().__init__()
        # keep the full config on the instance so later code can introspect it
        self.params = params
  1. ModelParamsLike is not a base class to inherit from, but rather a form of structural typing that signals that MyModelParams is compatible with the splifft configuration system. You can remove it if you don't like it.

2. Register the model

With the model and its config defined, our configuration system needs to understand your model.

Don't do this

A common solution is to define a "global" dictionary of available models:

src/my_library/models/__init__.py
from my_library.models.my_model import MyModelParams, MyModel

# anti-pattern: a central registry forces eager imports of every model (and
# their dependencies) and must be edited for each new entry
MODEL_REGISTRY = {
    "my_model": (MyModel, MyModelParams),
    # every other model must be added here
}

To add a new model, you'd have to modify this central registry. It also forces the import of all models and unwanted dependencies at once.

Instead, our configuration system uses a simple ModelMetadata wrapper struct to act as a "descriptor" for your model. Create a factory function that defers the imports until they are actually needed:

src/my_library/models/__init__.py
from splifft.models import ModelMetadata


def my_model_metadata() -> ModelMetadata:
    """Describe ``my_model`` without importing it (or torch) at module load time."""
    # deferred import: the model module is only loaded when this factory runs
    from .my_model import MyModel, MyModelParams

    return ModelMetadata(model_type="my_model", params=MyModelParams, model=MyModel)
I need to take a user's input string and dynamically import the model. How?

ModelMetadata.from_module is an alternative way to load the model metadata. It uses importlib under the hood. In fact, our CLI uses this exact approach.

from splifft.models import ModelMetadata

# resolve "my_library.models.my_model:MyModel" dynamically (importlib under
# the hood) — useful when the model is chosen from a user-supplied string
my_model_metadata = ModelMetadata.from_module(
    module_name="my_library.models.my_model",
    model_cls_name="MyModel",
    model_type="my_model"
)

3. Putting everything together

First, load in the configuration:

scripts/main.py
from pathlib import Path

import torch

from splifft.config import Config

# validate the JSON; `config.model` remains a lazy, not-yet-concrete config
config = Config.from_file(Path("path/to/my_model_config.json"))

This validates your JSON and returns a pydantic.BaseModel. Note that at this point, config.model is a lazy model configuration that is not yet fully validated.

Next, we need to create the PyTorch model. Concretize the lazy model configuration into the dataclass we defined earlier then instantiate the model:

scripts/main.py
from my_library.models import my_model_metadata

# build the descriptor (this is the point where the model module is imported)
metadata = my_model_metadata()
# concretize the lazy config into MyModelParams, then instantiate the module
my_model_params = config.model.to_concrete(metadata.params)
model = metadata.model(my_model_params)

Finally, load the weights, input audio and run!

scripts/main.py
from splifft.inference import InferenceEngine
from splifft.io import load_weights, read_audio

checkpoint_path = Path("path/to/my_model.pt")
# load weights onto the cpu; move the model if you want gpu inference
model = load_weights(model, checkpoint_path, device="cpu")

# read the mixture at the configured sample rate and channel layout
mixture = read_audio(
    Path("path/to/mixture.wav"), config.audio_io.target_sample_rate, config.audio_io.force_channels
)
engine = InferenceEngine(
    config=config,
    model=model,
    model_params_concrete=my_model_params,
    # run the model wherever its weights already live...
    model_device=next(model.parameters()).device,
    # ...but keep input/output audio tensors on the cpu
    io_device=torch.device("cpu"),
)
result = engine.run(mixture)

print(f"{list(result.outputs.keys())=}")