Extending splifft

splifft is designed to be easily extended without modifying its core.

Make sure you have added splifft as a dependency to your project. Assuming your project has this structure:

tree /path/to/ext_project
├── pyproject.toml
├── scripts
│   └── main.py
└── src
    └── my_library
        └── models
            ├── __init__.py
            └── my_model.py

1. Define a new model

Don't do this

A common pattern is to define a model with a huge list of parameters in its __init__ method:

src/my_library/models/my_model.py
from torch import nn
from beartype import beartype

class MyModel(nn.Module):
    @beartype
    def __init__(
        self,
        chunk_size: int,
        output_stem_names: tuple[str, ...],
        # ... a bunch of args here
    ):
        ...

The problem is that this tightly couples the model's implementation to its configuration: serializing it to and from a JSON file while also supporting static type checking becomes a headache.

Instead, define a stdlib dataclass separate from the model:

src/my_library/models/my_model.py
from dataclasses import dataclass

from torch import nn

from splifft.models import ChunkSize, ModelConfigLike, ModelOutputStemName


@dataclass
class MyModelConfig(ModelConfigLike):  # (1)!
    chunk_size: ChunkSize
    output_stem_names: tuple[ModelOutputStemName, ...]
    # ... any other parameters your model needs


class MyModel(nn.Module):
    def __init__(self, cfg: MyModelConfig):
        super().__init__()
        self.cfg = cfg
  1. ModelConfigLike is not a base class to inherit from, but rather a form of structural typing that signals that MyModelConfig is compatible with the splifft configuration system. You can remove it if you don't like it.
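
For a quick smoke test, you can already construct the config and model directly. This is only a sketch: the chunk size and stem names below are illustrative, not values splifft expects.

from my_library.models.my_model import MyModel, MyModelConfig

# illustrative values only; use whatever your model actually needs
cfg = MyModelConfig(chunk_size=441000, output_stem_names=("vocals", "instrumental"))
model = MyModel(cfg)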

2. Register the model

With the model and its config defined, the configuration system needs a way to discover and construct your model.

Don't do this

A common solution is to define a "global" dictionary of available models:

src/my_library/models/__init__.py
from my_library.models.my_model import MyModelConfig, MyModel

MODEL_REGISTRY = {
    "my_model": (MyModel, MyModelConfig),
    # every other model must be added here
}

To add a new model, you'd have to modify this central registry. It also forces every model, along with its dependencies, to be imported at once, whether you want them or not.

Instead, our configuration system uses a simple ModelMetadata wrapper to act as a "descriptor" for your model. Create a factory function that defers the imports until they're actually needed:

src/my_library/models/__init__.py
from splifft.models import ModelMetadata


def my_model_metadata():
    from .my_model import MyModel, MyModelConfig

    return ModelMetadata(model_type="my_model", config=MyModelConfig, model=MyModel)

I need to take a user's input string and dynamically import the model. How?

ModelMetadata.from_module is an alternative way to load the model metadata. It uses importlib under the hood. In fact, our CLI uses this exact approach.

from splifft.models import ModelMetadata

my_model_metadata = ModelMetadata.from_module(
    module_name="my_library.models.my_model",
    model_cls_name="MyModel",
    model_type="my_model"
)
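
For example, if the user supplies a string such as "my_library.models.my_model:MyModel" (the separator convention here is just an assumption, not something splifft mandates), you can split it yourself and hand the pieces to from_module:

from splifft.models import ModelMetadata

user_input = "my_library.models.my_model:MyModel"  # hypothetical user-supplied string
module_name, _, model_cls_name = user_input.partition(":")

metadata = ModelMetadata.from_module(
    module_name=module_name,
    model_cls_name=model_cls_name,
    model_type="my_model",
)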

3. Putting everything together

First, load in the configuration:

scripts/main.py
from pathlib import Path

from splifft.config import Config

config = Config.from_file(Path("path/to/my_model_config.json"))

This validates your JSON and returns a pydantic.BaseModel. Note that at this point, config.model is a lazy model configuration that is not yet fully validated.

Next, we need to create the PyTorch model: concretize the lazy model configuration into the dataclass we defined earlier, then instantiate the model:

scripts/main.py
from my_library.models import my_model_metadata

metadata = my_model_metadata()
my_model_config = config.model.to_concrete(metadata.config)
model = metadata.model(my_model_config)

Finally, load the weights and the input audio, then run inference:

scripts/main.py
from splifft.inference import run_inference_on_file
from splifft.io import load_weights, read_audio

checkpoint_path = Path("path/to/my_model.pt")
model = load_weights(model, checkpoint_path, device="cpu")

mixture = read_audio(
    Path("path/to/mixture.wav"), config.audio_io.target_sample_rate, config.audio_io.force_channels
)
stems = run_inference_on_file(mixture, config, model)

print(f"{list(stems.keys())=}")