Skip to content

Inference

inference

Public inference APIs.

Classes:

Name Description
ChunkProcessed
Stage
InferenceOutput
InferenceEngine

Functions:

Name Description
resolve_model_entrypoint

Attributes:

Name Type Description
SUPPORTED_MODELS dict[str, tuple[str, str]]
InferenceEvent TypeAlias

SUPPORTED_MODELS module-attribute

# Built-in model types mapped to their import entrypoints as
# (module path, class name). Consulted by `resolve_model_entrypoint` when the
# caller does not supply an explicit module/class pair.
# NOTE: "bs_roformer" and "mel_roformer" both resolve to the same class.
SUPPORTED_MODELS: dict[str, tuple[str, str]] = {
    "bs_roformer": (
        "splifft.models.bs_roformer",
        "BSRoformer",
    ),
    "mel_roformer": (
        "splifft.models.bs_roformer",
        "BSRoformer",
    ),
    "mdx23c": ("splifft.models.mdx23c", "MDX23C"),
    "beat_this": ("splifft.models.beat_this", "BeatThis"),
    "pesto": ("splifft.models.pesto", "Pesto"),
    "basic_pitch": (
        "splifft.models.basic_pitch",
        "BasicPitch",
    ),
}

resolve_model_entrypoint

resolve_model_entrypoint(
    model_type: ModelType,
    module_name: str | None,
    class_name: str | None,
) -> tuple[str, str]
Source code in src/splifft/inference.py
62
63
64
65
66
67
68
69
70
71
72
73
74
75
def resolve_model_entrypoint(
    model_type: t.ModelType,
    module_name: str | None,
    class_name: str | None,
) -> tuple[str, str]:
    """Resolve the ``(module_name, class_name)`` entrypoint for a model.

    An explicitly supplied module *and* class pair always wins; otherwise the
    built-in ``SUPPORTED_MODELS`` table is consulted by ``model_type``.

    Raises:
        ValueError: if the pair is incomplete and ``model_type`` is not a
            known built-in model type.
    """
    explicit = (module_name, class_name)
    if None not in explicit:
        # caller fully specified the entrypoint; no table lookup needed
        return explicit
    try:
        return SUPPORTED_MODELS[model_type]
    except KeyError as exc:
        raise ValueError(
            f"could not resolve model entrypoint for model_type={model_type!r}; "
            "provide both module and class explicitly"
        ) from exc

ChunkProcessed dataclass

ChunkProcessed(batch_index: int, total_batches: int)

Attributes:

Name Type Description
batch_index int
total_batches int

batch_index instance-attribute

batch_index: int

total_batches instance-attribute

total_batches: int

Stage dataclass

Stage(stage: str, *, total_batches: int | None = None)

Classes:

Name Description
Started
Completed

Methods:

Name Description
__enter__
__exit__

Attributes:

Name Type Description
stage str
total_batches int | None
started Started
completed Completed

stage instance-attribute

stage: str

total_batches class-attribute instance-attribute

total_batches: int | None = field(
    kw_only=True, default=None
)

Started dataclass

Started(stage: str, total_batches: int | None = None)

Attributes:

Name Type Description
stage str
total_batches int | None
stage instance-attribute
stage: str
total_batches class-attribute instance-attribute
total_batches: int | None = None

Completed dataclass

Completed(stage: str)

Attributes:

Name Type Description
stage str
stage instance-attribute
stage: str

started property

started: Started

completed property

completed: Completed

__enter__

__enter__() -> Stage
Source code in src/splifft/inference.py
106
107
def __enter__(self) -> Stage:
    """Enter the stage context; returns the stage itself unchanged."""
    return self

__exit__

__exit__(*_: object) -> None
Source code in src/splifft/inference.py
109
110
def __exit__(self, *_: object) -> None:
    """Exit the stage context; returns None, so exceptions always propagate."""
    return None

InferenceOutput dataclass

InferenceOutput(
    outputs: dict[StemName, RawAudioTensor]
    | dict[str, Tensor],
    sample_rate: SampleRate,
)

Attributes:

Name Type Description
outputs dict[StemName, RawAudioTensor] | dict[str, Tensor]
sample_rate SampleRate

outputs instance-attribute

sample_rate instance-attribute

sample_rate: SampleRate

InferenceEvent module-attribute

InferenceEngine dataclass

InferenceEngine(
    config: Config,
    model: Module,
    model_params_concrete: ModelParamsLike,
    model_device: device,
    io_device: device,
    model_input_dtype: dtype | None,
)

Methods:

Name Description
from_pretrained
from_registry
to_audio_tensor
run
stream

Attributes:

Name Type Description
config Config
model Module
model_params_concrete ModelParamsLike
model_device device
io_device device
model_input_dtype dtype | None

config instance-attribute

config: Config

model instance-attribute

model: Module

model_params_concrete instance-attribute

model_params_concrete: ModelParamsLike

model_device instance-attribute

model_device: device

io_device instance-attribute

io_device: device

model_input_dtype instance-attribute

model_input_dtype: dtype | None

from_pretrained classmethod

from_pretrained(
    *,
    config: IntoConfig,
    checkpoint_path: StrPath,
    overrides: ConfigOverrides = (),
    model_device: device | str | None = None,
    io_device: device | str | None = None,
    module_name: str | None = None,
    class_name: str | None = None,
    package_name: str | None = None,
) -> InferenceEngine
Source code in src/splifft/inference.py
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
@classmethod
def from_pretrained(
    cls,
    *,
    config: IntoConfig,
    checkpoint_path: t.StrPath,
    overrides: ConfigOverrides = (),
    model_device: torch.device | str | None = None,
    io_device: torch.device | str | None = None,
    module_name: str | None = None,
    class_name: str | None = None,
    package_name: str | None = None,
) -> InferenceEngine:
    """Build an engine from a config plus a checkpoint file.

    Resolves devices, imports the model class, optionally narrows the model to
    the requested output stems, loads weights, and optionally applies a forced
    weight dtype and ``torch.compile`` — all driven by ``config.inference``.

    Args:
        config: anything accepted by ``into_config`` (path, mapping, Config).
        checkpoint_path: path to the weights file passed to ``load_weights``.
        overrides: config overrides forwarded to ``into_config``.
        model_device: overrides ``config.inference.model_device`` when given.
        io_device: overrides ``config.inference.io_device`` when given.
        module_name: explicit model module; must be paired with ``class_name``.
        class_name: explicit model class; must be paired with ``module_name``.
        package_name: package anchor forwarded to ``ModelMetadata.from_module``.

    Raises:
        ValueError: from ``resolve_model_entrypoint`` when the model type is
            unknown and no explicit module/class pair is given.
    """
    from .config import into_config
    from .io import load_weights
    from .models import ModelMetadata

    config = into_config(config, overrides=overrides)

    # explicit arguments take precedence over the config's device fields
    model_device_resolved = _resolve_device(
        model_device or config.inference.model_device,
        field_name="inference.model_device",
    )
    io_device_resolved = _resolve_device(
        io_device or config.inference.io_device,
        field_name="inference.io_device",
    )
    resolved_module, resolved_class = resolve_model_entrypoint(
        config.model_type, module_name, class_name
    )
    metadata = ModelMetadata.from_module(
        module_name=resolved_module,
        model_cls_name=resolved_class,
        model_type=config.model_type,
        package=package_name,
    )

    model_params = config.model.to_concrete(metadata.params)
    full_output_stems = tuple(config.model.output_stem_names)
    requested_stems = tuple(config.inference.requested_stems or full_output_stems)

    state_dict_transform = None
    if requested_stems != full_output_stems:
        # optional model-level optimization contract: models can choose to
        # provide a stem-selection plan that may mutate params and checkpoint
        # loading. if absent, we keep full model outputs and discard
        # unrelated stems immediately after each forward pass.
        from .models import SupportsStemSelection

        if isinstance(metadata.model, SupportsStemSelection):
            plan = metadata.model.__splifft_stem_selection_plan__(model_params, requested_stems)
            model_params = plan.model_params
            state_dict_transform = plan.state_dict_transform

    model = metadata.model(model_params)
    # dtype cast happens BEFORE load_weights so the checkpoint is loaded into
    # already-converted parameters
    if (forced_dtype := config.inference.force_weights_dtype) is not None:
        model = model.to(_resolve_runtime_dtype(forced_dtype, device=model_device_resolved))
    model = load_weights(
        model,
        checkpoint_path,
        device=model_device_resolved,
        state_dict_transform=state_dict_transform,
    ).eval()

    # TODO: consider an explicit try_compile() method that emits progress events;
    # at minimum this should be logged, since compilation can take extremely long.
    if (compile_cfg := config.inference.compile_model) is not None:
        compiled_model = torch.compile(
            model,
            fullgraph=compile_cfg.fullgraph,
            dynamic=compile_cfg.dynamic,
            mode=compile_cfg.mode,
        )
        # cast: torch.compile's return type is not nn.Module to the type checker
        model = cast(nn.Module, compiled_model)

    return cls(
        config=config,
        model=model,
        model_params_concrete=model_params,
        model_device=model_device_resolved,
        io_device=io_device_resolved,
        model_input_dtype=core.get_model_floating_dtype(model),
    )

from_registry classmethod

from_registry(
    model_id: str,
    *,
    model_device: device | str | None = None,
    io_device: device | str | None = None,
    overrides: ConfigOverrides = (),
    fetch_if_missing: bool = True,
    force_overwrite_config: bool = False,
    force_overwrite_model: bool = False,
    registry_path: Path = PATH_REGISTRY_DEFAULT,
) -> InferenceEngine
Source code in src/splifft/inference.py
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
@classmethod
def from_registry(
    cls,
    model_id: str,
    *,
    model_device: torch.device | str | None = None,
    io_device: torch.device | str | None = None,
    overrides: ConfigOverrides = (),
    fetch_if_missing: bool = True,
    force_overwrite_config: bool = False,
    force_overwrite_model: bool = False,
    registry_path: Path = PATH_REGISTRY_DEFAULT,
) -> InferenceEngine:
    """Build an engine for a registered model id.

    Looks the model up in the registry at ``registry_path`` (optionally
    fetching missing files), then delegates to :meth:`from_pretrained` with
    the resolved config and checkpoint paths.
    """
    from .config import Registry
    from .io import get_model_paths

    registry = Registry.from_file(registry_path)
    paths = get_model_paths(
        model_id,
        fetch_if_missing=fetch_if_missing,
        force_overwrite_config=force_overwrite_config,
        force_overwrite_model=force_overwrite_model,
        registry=registry,
    )
    return cls.from_pretrained(
        config=paths.path_config,
        checkpoint_path=paths.path_checkpoint,
        overrides=overrides,
        model_device=model_device,
        io_device=io_device,
    )

to_audio_tensor

to_audio_tensor(
    mixture: StrPath
    | BytesPath
    | RawAudioTensor
    | Audio[RawAudioTensor],
) -> Audio[RawAudioTensor]
Source code in src/splifft/inference.py
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
def to_audio_tensor(
    self,
    mixture: t.StrPath | t.BytesPath | t.RawAudioTensor | core.Audio[t.RawAudioTensor],
) -> core.Audio[t.RawAudioTensor]:
    """Coerce any supported mixture input into a ``core.Audio``.

    ``core.Audio`` passes through unchanged; a bare tensor is wrapped with the
    configured target sample rate; anything else is treated as a path/bytes
    source and decoded via ``read_audio`` onto the io device.
    """
    if isinstance(mixture, core.Audio):
        return mixture

    target_sr = self.config.audio_io.target_sample_rate
    if isinstance(mixture, torch.Tensor):
        # assumes the caller's tensor is already at the target sample rate
        return core.Audio(data=t.RawAudioTensor(mixture), sample_rate=target_sr)

    from .io import read_audio

    return read_audio(
        mixture,  # type: ignore[arg-type]
        target_sr,
        self.config.audio_io.force_channels,
        device=self.io_device,
    )

run

Source code in src/splifft/inference.py
329
330
331
332
333
334
335
336
def run(
    self,
    mixture: t.StrPath | t.BytesPath | t.RawAudioTensor | core.Audio[t.RawAudioTensor],
) -> InferenceOutput:
    """Drive :meth:`stream` to completion and return its final output.

    Intermediate progress events are discarded; only the first
    ``InferenceOutput`` event is returned.
    """
    final = next(
        (event for event in self.stream(mixture) if isinstance(event, InferenceOutput)),
        None,
    )
    if final is None:
        raise RuntimeError("inference stream finished without outputs")
    return final

stream

stream(
    mixture: StrPath
    | BytesPath
    | RawAudioTensor
    | Audio[RawAudioTensor],
) -> Generator[InferenceEvent, None, None]
Source code in src/splifft/inference.py
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
def stream(
    self,
    mixture: t.StrPath | t.BytesPath | t.RawAudioTensor | core.Audio[t.RawAudioTensor],
) -> Generator[InferenceEvent, None, None]:
    """Run inference as a generator of progress events.

    Yields ``Stage.Started``/``Stage.Completed`` markers (and, via the
    delegated pipelines, presumably ``ChunkProcessed`` events — confirm in
    the helpers) and finishes by yielding a single ``InferenceOutput``.

    Pipeline: optional normalization -> channel adaptation -> either the
    sequence-labeling path or the waveform path -> per-stem denormalization
    -> optional derived stems.
    """
    # validates config against the concrete params and tells us which
    # processing path to take
    archetype = self.config.validate_inference_contract(self.model_params_concrete)

    audio_tensor = self.to_audio_tensor(mixture)
    raw_mixture_data = t.RawAudioTensor(audio_tensor.data.to(self.io_device))
    mixture_data: t.RawAudioTensor | t.NormalizedAudioTensor = raw_mixture_data
    # kept so stems can be denormalized back to the input's scale later
    mixture_stats: core.NormalizationStats | None = None

    if self.config.normalization.enabled:
        with Stage("normalize") as s:
            yield s.started
            normalized = core.normalize_audio(
                core.Audio(data=raw_mixture_data, sample_rate=audio_tensor.sample_rate)
            )
            mixture_data = normalized.audio.data
            mixture_stats = normalized.stats
            yield s.completed

    # helper is itself a generator: its events are re-yielded, its return
    # value becomes the adapted mixture
    mixture_data = yield from self._adapt_input_channels(mixture_data)

    if archetype == "sequence_labeling":
        # sequence-labeling models bypass the waveform post-processing below
        requested_stems = self._requested_output_stem_names()
        requested_stem_indices = self._requested_output_stem_indices()
        sequence_outputs = yield from self._stream_sequence_labeling(
            mixture_data,
            requested_stems=requested_stems,
            output_indices=requested_stem_indices,
        )
        yield InferenceOutput(outputs=sequence_outputs, sample_rate=audio_tensor.sample_rate)
        return

    requested_stems = self._requested_output_stem_names()
    requested_stem_indices = self._requested_output_stem_indices()
    separated_data = yield from self._stream_waveform_pipeline(
        mixture_data,
        archetype,
        output_indices=requested_stem_indices,
        num_stems=len(requested_stems),
    )

    denormalized_stems: dict[t.ModelOutputStemName, t.RawAudioTensor] = {}
    with Stage("collect_outputs") as s:
        yield s.started
        # separated_data is indexed by stem along dim 0, in requested order
        for i, stem_name in enumerate(requested_stems):
            stem_data = separated_data[i, ...]
            if mixture_stats is not None:
                # undo the earlier normalization so outputs match input scale
                stem_data = core.denormalize_audio(
                    audio_data=t.NormalizedAudioTensor(stem_data),
                    stats=mixture_stats,
                )
            denormalized_stems[stem_name] = t.RawAudioTensor(stem_data)
        yield s.completed

    output_stems: dict[StemName, t.RawAudioTensor] = denormalized_stems
    if derived_stems_cfg := self.config.derived_stems:
        with Stage("derive_stems") as s:
            yield s.started
            # derives extra stems (e.g. combinations with the raw mixture);
            # exact semantics live in core.derive_stems
            output_stems = core.derive_stems(
                denormalized_stems,
                raw_mixture_data,
                derived_stems_cfg,
            )
            yield s.completed

    yield InferenceOutput(outputs=output_stems, sample_rate=audio_tensor.sample_rate)