Library Tutorial
Basic inference
This example demonstrates the lower-level API for inference use cases. In the future, we will provide a higher-level API for convenience.
# ruff: noqa: E402
from pathlib import Path
PATH_CONFIG = Path("data/config/bs_roformer.json")
PATH_CKPT = Path("data/models/roformer-fp16.pt")
PATH_MIXTURE = Path("data/audio/input/3BFTio5296w.flac")
# 1. parse + validate the JSON config *without* having to import a particular PyTorch model.
from splifft.config import Config
config = Config.from_file(PATH_CONFIG)
# 2. we now want to *lock in* the configuration to a specific model.
from splifft.models import ModelMetadata
from splifft.models.bs_roformer import BSRoformer, BSRoformerParams
metadata = ModelMetadata(model_type="bs_roformer", params=BSRoformerParams, model=BSRoformer)
model_params = config.model.to_concrete(metadata.params)
# 3. `metadata` acts as a model builder
from splifft.io import load_weights
model = metadata.model(model_params)
model = load_weights(model, PATH_CKPT, device="cpu")
# 4. load audio and run inference by passing dependencies explicitly.
from splifft.inference import run_inference_on_file
from splifft.io import read_audio
mixture = read_audio(
PATH_MIXTURE,
config.audio_io.target_sample_rate,
config.audio_io.force_channels,
)
stems = run_inference_on_file(mixture, config, model, model_params)
print(list(stems.keys()))
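Each entry in stems maps an output stem name to its separated audio. As a minimal sketch of what you might do next (assuming each value is a (channels, samples) float tensor, which is the shape torchaudio expects; the output directory below is hypothetical), you could write every stem to disk:
from pathlib import Path
import torchaudio
output_dir = Path("data/audio/output")  # hypothetical output directory
output_dir.mkdir(parents=True, exist_ok=True)
for stem_name, waveform in stems.items():
    # save each stem at the same sample rate the mixture was loaded at
    torchaudio.save(
        str(output_dir / f"{stem_name}.wav"),
        waveform.cpu(),
        config.audio_io.target_sample_rate,
    )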
Extending splifft
splifft is designed to be easily extended without modifying its core. Make sure you have added splifft as a dependency. Assuming your library has this structure:
├── pyproject.toml
├── scripts
│ └── main.py
└── src
└── my_library
└── models
├── __init__.py
└── my_model.py
1. Define a new model
Don't do this
A common pattern is to define a model with a huge list of parameters in its __init__ method:
from torch import nn
from beartype import beartype
class MyModel(nn.Module):
@beartype
def __init__(
self,
chunk_size: int,
output_stem_names: tuple[str, ...],
# ... a bunch of args here
):
...
The problem is that it tightly couples the model's implementation to its configuration. Serializing to/from a JSON file and simultaneously supporting static type checking is a headache.
Instead, define a stdlib dataclass separate from the model:
from dataclasses import dataclass
from torch import nn
from splifft.models import ModelParamsLike
from splifft.types import ChunkSize, ModelInputType, ModelOutputStemName, ModelOutputType
@dataclass
class MyModelParams(ModelParamsLike): # (1)!
chunk_size: ChunkSize
output_stem_names: tuple[ModelOutputStemName, ...]
# ... any other config your model needs
@property
def input_type(self) -> ModelInputType:
return "waveform"
@property
def output_type(self) -> ModelOutputType:
return "waveform"
class MyModel(nn.Module):
def __init__(self, params: MyModelParams):
super().__init__()
self.params = params
ModelParamsLike is not a base class to inherit from, but rather a form of structural typing that signals that MyModelParams is compatible with the splifft configuration system. You can remove it if you don't like it.
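Because the check is structural, nothing special is needed to construct the params or the model by hand; the dataclass behaves like any other. A quick sketch with made-up values:
from my_library.models.my_model import MyModel, MyModelParams
# illustrative values; use whatever your model actually needs
params = MyModelParams(chunk_size=485100, output_stem_names=("vocals", "other"))
model = MyModel(params)
print(model.params.output_stem_names)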
2. Register the model
With the model and its parameter dataclass defined, the configuration system needs a way to discover and build your model.
Don't do this
A common solution is to define a "global" dictionary of available models:
from my_library.models.my_model import MyModelParams, MyModel
MODEL_REGISTRY = {
"my_model": (MyModel, MyModelParams),
# every other model must be added here
}
To add a new model, you'd have to modify this central registry. It also forces all models (and their dependencies) to be imported at once, even when you only need one.
Instead, our configuration system uses a simple ModelMetadata wrapper struct that acts as a "descriptor" for your model. Create a factory function that defers the imports until the model is actually needed:
from splifft.models import ModelMetadata
def my_model_metadata():
from .my_model import MyModel, MyModelParams
return ModelMetadata(model_type="my_model", params=MyModelParams, model=MyModel)
I need to take a user's input string and dynamically import the model. How?
ModelMetadata.from_module is an alternative way to load the model metadata. It uses importlib under the hood; in fact, our CLI uses this exact approach.
from splifft.models import ModelMetadata
my_model_metadata = ModelMetadata.from_module(
module_name="my_library.models.my_model",
model_cls_name="MyModel",
model_type="my_model"
)
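For example, if your tool accepts a single string identifying the model class, you can split it yourself before handing the pieces to from_module. The "module:ClassName" format here is purely illustrative, not something splifft prescribes:
from splifft.models import ModelMetadata
user_spec = "my_library.models.my_model:MyModel"  # e.g. taken from a CLI flag
module_name, model_cls_name = user_spec.split(":")
metadata = ModelMetadata.from_module(
    module_name=module_name,
    model_cls_name=model_cls_name,
    model_type="my_model",
)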
3. Putting everything together
First, load in the configuration:
from pathlib import Path
from splifft.config import Config
config = Config.from_file(Path("path/to/my_model_config.json"))
config.model is a lazy model configuration that is not yet fully validated.
Next, we need to create the PyTorch model. Concretize the lazy model configuration into the dataclass we defined earlier, then instantiate the model:
from my_library.models import my_model_metadata
metadata = my_model_metadata()
my_model_params = config.model.to_concrete(metadata.params)
model = metadata.model(my_model_params)
Finally, load the weights, read the input audio and run inference:
from splifft.inference import run_inference_on_file
from splifft.io import load_weights, read_audio
checkpoint_path = Path("path/to/my_model.pt")
model = load_weights(model, checkpoint_path, device="cpu")
mixture = read_audio(
Path("path/to/mixture.wav"), config.audio_io.target_sample_rate, config.audio_io.force_channels
)
stems = run_inference_on_file(mixture, config, model, my_model_params)
print(f"{list(stems.keys())=}")