Tutorial
This page targets two audiences:
- Users shipping features: use the high-level engine and keep config-driven behavior.
- Researchers/developers: drop to low-level APIs when you need full control over tensors, transforms, and post-processing.
Basic inference
Use `splifft.inference.InferenceEngine.from_pretrained` for a convenient high-level API.
from splifft.inference import InferenceEngine

# path to the mixture (input) audio file to separate
PATH_MIXTURE = "data/audio/input/3BFTio5296w.flac"

engine = InferenceEngine.from_pretrained(
    config="/path/to/config.json",
    checkpoint_path="/path/to/checkpoint.pt",
)
result = engine.run(PATH_MIXTURE)
print(result)
#
# or, if you use the default user cache registry
#
engine = InferenceEngine.from_registry("bs_roformer-fruit-sw")
#
# and to track progress for long files or on slow hardware:
#
import logging

from splifft.inference import InferenceOutput

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s", datefmt="[%X]")
logger = logging.getLogger(__name__)

# `stream` yields progress events (stage markers, per-chunk updates) and
# finally the `InferenceOutput` itself, at which point we stop iterating.
for event in engine.stream(PATH_MIXTURE):
    if isinstance(event, InferenceOutput):
        print(event)
        break
    logger.info(event)
[00:00:36] Stage.Started(stage='normalize', total_batches=None)
[00:00:36] Stage.Completed(stage='normalize')
[00:00:38] ChunkProcessed(batch_index=1, total_batches=9)
[00:00:39] ChunkProcessed(batch_index=2, total_batches=9)
[00:00:41] ChunkProcessed(batch_index=3, total_batches=9)
[00:00:42] ChunkProcessed(batch_index=4, total_batches=9)
[00:00:43] ChunkProcessed(batch_index=5, total_batches=9)
[00:00:44] ChunkProcessed(batch_index=6, total_batches=9)
[00:00:46] ChunkProcessed(batch_index=7, total_batches=9)
[00:00:47] ChunkProcessed(batch_index=8, total_batches=9)
[00:00:47] ChunkProcessed(batch_index=9, total_batches=9)
[00:00:47] Stage.Started(stage='stitch', total_batches=None)
[00:00:47] Stage.Completed(stage='stitch')
[00:00:47] Stage.Started(stage='collect_outputs', total_batches=None)
[00:00:47] Stage.Completed(stage='collect_outputs')
[00:00:47] Stage.Started(stage='derive_stems', total_batches=None)
[00:00:47] Stage.Completed(stage='derive_stems')
InferenceOutput(
outputs={
'bass': tensor([[-1.3643e-05, -1.3736e-05, -1.3643e-05, ...,
-1.3958e-05,
-1.3730e-05, -1.3960e-05],
[-1.3811e-05, -1.3586e-05, -1.3811e-05, ..., -1.3738e-05,
-1.3953e-05, -1.3736e-05]], device='cuda:0'),
'drums': tensor([[-1.3493e-05, -1.4200e-05, -1.3493e-05, ...,
-1.2080e-05,
-1.2848e-05, -1.2020e-05],
[-1.3936e-05, -1.3758e-05, -1.3936e-05, ..., -1.1843e-05,
-1.2818e-05, -1.1989e-05]], device='cuda:0'),
'other': tensor([[-7.5168e-07, -6.3413e-07, -7.5222e-07, ...,
1.9690e-05,
-3.3400e-05, 2.5086e-05],
[-7.4173e-07, -6.7063e-07, -7.4244e-07, ..., 3.2220e-05,
-3.7293e-05, 2.0826e-05]], device='cuda:0'),
'vocals': tensor([[-1.3789e-05, -1.3904e-05, -1.3789e-05, ...,
-1.3930e-05,
-1.3755e-05, -1.4037e-05],
[-1.3860e-05, -1.3833e-05, -1.3860e-05, ..., -1.3848e-05,
-1.3747e-05, -1.3913e-05]], device='cuda:0'),
'guitar': tensor([[-1.3846e-05, -1.3846e-05, -1.3846e-05, ...,
-1.3928e-05,
-1.3760e-05, -1.3928e-05],
[-1.3910e-05, -1.3782e-05, -1.3910e-05, ..., -1.3871e-05,
-1.3818e-05, -1.3871e-05]], device='cuda:0'),
'piano': tensor([[-1.3789e-05, -1.3902e-05, -1.3789e-05, ...,
-1.3933e-05,
-1.3759e-05, -1.3933e-05],
[-1.3881e-05, -1.3810e-05, -1.3881e-05, ..., -1.3849e-05,
-1.3843e-05, -1.3849e-05]], device='cuda:0'),
'instrum': tensor([[ 1.3789e-05, 1.3904e-05, 1.3789e-05, ...,
4.7834e-05,
-2.5873e-05, 5.2345e-05],
[ 1.3860e-05, 1.3833e-05, 1.3860e-05, ..., 6.8972e-05,
-3.0868e-05, 4.3241e-05]], device='cuda:0')
},
sample_rate=44100
)
This outputs splifft.inference.InferenceOutput, containing:
- the dictionary of stem names to tensor (which can be audio or logits)
- the sample rate of the input tensor (so you can save the audio)
Low level inference
# ruff: noqa: E402
from pathlib import Path

import torch

PATH_CONFIG = Path("data/config/bs_roformer.json")
PATH_CKPT = Path("data/models/roformer-fp16.pt")
PATH_MIXTURE = Path("data/audio/input/3BFTio5296w.flac")

# 1. parse + validate a JSON *without* having to import a particular pytorch model.
from splifft.config import Config

config = Config.from_file(PATH_CONFIG)

# 2. we now want to *lock in* the configuration to a specific model.
from splifft.models import ModelMetadata
from splifft.models.bs_roformer import BSRoformer, BSRoformerParams

metadata = ModelMetadata(model_type="bs_roformer", params=BSRoformerParams, model=BSRoformer)
model_params = config.model.to_concrete(metadata.params)

# 3. `metadata` acts as a model builder
from splifft.io import load_weights

model = metadata.model(model_params)
model = load_weights(model, PATH_CKPT, device="cpu")

# 4. load audio and run inference by passing dependencies explicitly.
from splifft.inference import InferenceEngine
from splifft.io import read_audio

mixture = read_audio(
    PATH_MIXTURE,
    config.audio_io.target_sample_rate,
    config.audio_io.force_channels,
)
engine = InferenceEngine(
    config=config,
    model=model,
    model_params_concrete=model_params,
    model_device=next(model.parameters()).device,
    io_device=torch.device("cpu"),
)
result = engine.run(mixture)
print(list(result.outputs.keys()))
Extending splifft
splifft is designed to be easily extended without modifying its core.
Make sure you have added splifft as a dependency. Assuming your library has this structure:
├── pyproject.toml
├── scripts
│ └── main.py
└── src
└── my_library
└── models
├── __init__.py
└── my_model.py
1. Define a new model
Don't do this
A common pattern is to define a model with a huge list of parameters in its __init__ method:
from torch import nn
from beartype import beartype
# NOTE: intentionally shown as the anti-pattern — config is baked into __init__.
class MyModel(nn.Module):
    @beartype
    def __init__(
        self,
        chunk_size: int,
        output_stem_names: tuple[str, ...],
        # ... a bunch of args here
    ):
        ...
The problem is that it tightly couples the model's implementation to its configuration. Serializing to/from a JSON file and simultaneously supporting static type checking is a headache.
Instead, define a stdlib dataclass separate from the model:
from dataclasses import dataclass
from torch import nn
from splifft.models import ModelParamsLike
from splifft.types import (
ChunkSize,
InferenceArchetype,
ModelInputType,
ModelOutputStemName,
ModelOutputType,
)
@dataclass
class MyModelParams(ModelParamsLike):  # (1)!
    chunk_size: ChunkSize
    output_stem_names: tuple[ModelOutputStemName, ...]
    # ... any other config your model needs

    @property
    def input_type(self) -> ModelInputType:
        return "waveform"

    @property
    def output_type(self) -> ModelOutputType:
        return "waveform"

    @property
    def inference_archetype(self) -> InferenceArchetype:
        return "standard_end_to_end"
class MyModel(nn.Module):
    def __init__(self, params: MyModelParams):
        super().__init__()
        # keep the whole parameter dataclass around instead of unpacking it
        self.params = params
`ModelParamsLike` is not a base class to inherit from, but rather a form of structural typing that signals that `MyModelParams` is compatible with the `splifft` configuration system. You can remove it if you don't like it.
2. Register the model
With the model and its config defined, our configuration system needs to understand your model.
Don't do this
A common solution is to define a "global" dictionary of available models:
from my_library.models.my_model import MyModelParams, MyModel

# NOTE: intentionally shown as the anti-pattern — a central registry that must
# import every model eagerly.
MODEL_REGISTRY = {
    "my_model": (MyModel, MyModelParams),
    # every other model must be added here
}
To add a new model, you'd have to modify this central registry. It also forces the import of all models and unwanted dependencies at once.
Instead, our configuration system uses a simple ModelMetadata wrapper struct to act as a "descriptor" for your model. Create a factory function that defers the imports until they are actually needed:
from splifft.models import ModelMetadata
def my_model_metadata():
    """Describe ``MyModel`` for splifft, deferring heavy imports until called."""
    from .my_model import MyModel, MyModelParams

    return ModelMetadata(model_type="my_model", params=MyModelParams, model=MyModel)
I need to take a user's input string and dynamically import the model. How?
ModelMetadata.from_module is an alternative way to load the model metadata. It uses importlib under the hood. In fact, our CLI uses this exact approach.
from splifft.models import ModelMetadata

# dynamically import the model class by dotted module path (uses importlib)
my_model_metadata = ModelMetadata.from_module(
    module_name="my_library.models.my_model",
    model_cls_name="MyModel",
    model_type="my_model",
)
3. Putting everything together
First, load in the configuration:
from pathlib import Path
import torch
from splifft.config import Config
# parse + validate the JSON; `config.model` stays lazy until concretized later
config = Config.from_file(Path("path/to/my_model_config.json"))
This validates your JSON and returns a pydantic.BaseModel. Note that at this point, config.model is a lazy model configuration that is not yet fully validated.
Next, we need to create the PyTorch model. Concretize the lazy model configuration into the dataclass we defined earlier then instantiate the model:
from my_library.models import my_model_metadata
metadata = my_model_metadata()
# concretize the lazy model config into the typed dataclass, then build the model
my_model_params = config.model.to_concrete(metadata.params)
model = metadata.model(my_model_params)
Finally, load the weights, input audio and run!
from splifft.inference import InferenceEngine
from splifft.io import load_weights, read_audio

checkpoint_path = Path("path/to/my_model.pt")
model = load_weights(model, checkpoint_path, device="cpu")
mixture = read_audio(
    Path("path/to/mixture.wav"), config.audio_io.target_sample_rate, config.audio_io.force_channels
)
engine = InferenceEngine(
    config=config,
    model=model,
    model_params_concrete=my_model_params,
    model_device=next(model.parameters()).device,
    io_device=torch.device("cpu"),
)
result = engine.run(mixture)
print(f"{list(result.outputs.keys())=}")