# Library Tutorial

## Extending splifft
`splifft` is designed to be easily extended without modifying its core. Make sure you have added `splifft` as a dependency.
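For instance, with pip (assuming the package is published under the same name as the import):

```bash
pip install splifft
```

Assuming your library has this structure: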
```text
├── pyproject.toml
├── scripts
│   └── main.py
└── src
    └── my_library
        └── models
            ├── __init__.py
            └── my_model.py
```
### 1. Define a new model
**Don't do this.** A common pattern is to define a model with a huge list of parameters in its `__init__` method:
```python
from torch import nn
from beartype import beartype


class MyModel(nn.Module):
    @beartype
    def __init__(
        self,
        chunk_size: int,
        output_stem_names: tuple[str, ...],
        # ... a bunch of args here
    ):
        ...
```
The problem is that it tightly couples the model's implementation to its configuration. Serializing to/from a JSON file and simultaneously supporting static type checking is a headache.
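To see the coupling concretely, here is a sketch of what loading that model from a JSON file tends to look like, continuing from the snippet above (the file name is hypothetical):

```python
import json

with open("my_model_config.json") as f:
    params = json.load(f)  # dict[str, Any]: the type checker knows nothing about the keys

# a typo'd key or a wrong value type slips past static analysis
# and only fails at runtime
model = MyModel(**params)
```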
Instead, define a stdlib `dataclass` separate from the model:
```python
from dataclasses import dataclass

from torch import nn

from splifft.models import ChunkSize, ModelConfigLike, ModelOutputStemName


@dataclass
class MyModelConfig(ModelConfigLike):  # (1)!
    chunk_size: ChunkSize
    output_stem_names: tuple[ModelOutputStemName, ...]
    # ... any other parameters your model needs


class MyModel(nn.Module):
    def __init__(self, cfg: MyModelConfig):
        super().__init__()
        self.cfg = cfg
```
`ModelConfigLike` is not a base class to inherit from, but rather a form of structural typing that signals that `MyModelConfig` is compatible with the `splifft` configuration system. You can remove it if you don't like it.
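Because `MyModelConfig` is a plain dataclass, serializing it is now trivial. A minimal sketch using only the stdlib; the values and file name are made up:

```python
import json
from dataclasses import asdict

from my_library.models.my_model import MyModelConfig

cfg = MyModelConfig(chunk_size=44100, output_stem_names=("vocals", "instrumental"))

# serialize: a dataclass maps cleanly onto a JSON object
with open("my_model_config.json", "w") as f:
    json.dump(asdict(cfg), f)

# deserialize: the constructor is a single, statically visible entry point
# (note: JSON has no tuple type, so output_stem_names comes back as a list;
# convert or validate it if the distinction matters to you)
with open("my_model_config.json") as f:
    cfg = MyModelConfig(**json.load(f))
```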
### 2. Register the model
With the model and its config defined, our configuration system needs to know about your model.
**Don't do this.** A common solution is to define a "global" dictionary of available models:
```python
from my_library.models.my_model import MyModelConfig, MyModel

MODEL_REGISTRY = {
    "my_model": (MyModel, MyModelConfig),
    # every other model must be added here
}
```
To add a new model, you'd have to modify this central registry. It also forces all models, and their dependencies, to be imported at once, even the ones you never use.
Instead, our configuration system uses a simple `ModelMetadata` wrapper struct to act as a "descriptor" for your model. Create a factory function that defers the imports until they are actually needed:
```python
from splifft.models import ModelMetadata


def my_model_metadata():
    from .my_model import MyModel, MyModelConfig  # deferred until the factory is called

    return ModelMetadata(model_type="my_model", config=MyModelConfig, model=MyModel)
```
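A quick way to see the deferred import in action; this sketch assumes `my_model_metadata` is re-exported from `my_library/models/__init__.py` (as used in step 3 below) and that nothing else has imported the model module yet:

```python
import sys

from my_library.models import my_model_metadata

# the model module (and heavy dependencies like torch) are not loaded yet
assert "my_library.models.my_model" not in sys.modules

metadata = my_model_metadata()  # the factory performs the import on first call
assert "my_library.models.my_model" in sys.modules
```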
**I need to take a user's input string and dynamically import the model. How?**

`ModelMetadata.from_module` is an alternative way to load the model metadata. It uses `importlib` under the hood. In fact, our CLI uses this exact approach.
```python
from splifft.models import ModelMetadata

my_model_metadata = ModelMetadata.from_module(
    module_name="my_library.models.my_model",
    model_cls_name="MyModel",
    model_type="my_model",
)
```
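For example, a hypothetical CLI flag such as `--model my_library.models.my_model:MyModel` could be turned into metadata like so; the `module:Class` spec format is our own convention here, not something `splifft` prescribes:

```python
from splifft.models import ModelMetadata


def metadata_from_spec(spec: str, model_type: str) -> ModelMetadata:
    """Parse a hypothetical ``module.path:ClassName`` string, e.g. from a CLI flag."""
    module_name, _, model_cls_name = spec.partition(":")
    if not model_cls_name:
        raise ValueError(f"expected 'module.path:ClassName', got {spec!r}")
    return ModelMetadata.from_module(
        module_name=module_name,
        model_cls_name=model_cls_name,
        model_type=model_type,
    )


metadata = metadata_from_spec("my_library.models.my_model:MyModel", model_type="my_model")
```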
### 3. Putting everything together
First, load in the configuration:
```python
from pathlib import Path

from splifft.config import Config

config = Config.from_file(Path("path/to/my_model_config.json"))
```
`config.model` is a lazy model configuration that is not yet fully validated.
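For orientation, here is what such a file *might* look like for the dataclass from step 1. This is a hypothetical sketch: only `model_type`, `chunk_size` and `output_stem_names` come from the code above, and the surrounding schema (the `model` and `audio_io` sections, the `force_channels` value) is assumed:

```json
{
  "model": {
    "model_type": "my_model",
    "chunk_size": 44100,
    "output_stem_names": ["vocals", "instrumental"]
  },
  "audio_io": {
    "target_sample_rate": 44100,
    "force_channels": 2
  }
}
```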
Next, we need to create the PyTorch model. Concretize the lazy model configuration into the dataclass we defined earlier, then instantiate the model:
```python
from my_library.models import my_model_metadata

metadata = my_model_metadata()
my_model_config = config.model.to_concrete(metadata.config)
model = metadata.model(my_model_config)
```
Finally, load the weights and the input audio, then run inference:
```python
from splifft.inference import run_inference_on_file
from splifft.io import load_weights, read_audio

checkpoint_path = Path("path/to/my_model.pt")
model = load_weights(model, checkpoint_path, device="cpu")

mixture = read_audio(
    Path("path/to/mixture.wav"),
    config.audio_io.target_sample_rate,
    config.audio_io.force_channels,
)
stems = run_inference_on_file(mixture, config, model)
print(f"{list(stems.keys())=}")
```