
Config

config

Configuration

Classes:

Name Description
LazyModelConfig

A lazily validated model configuration.

StftConfig

Configuration for the short-time Fourier transform.

AudioIOConfig
TorchCompileConfig
InferenceConfig
ChunkingConfig
MaskingConfig
SubtractConfig
SumConfig
OutputConfig
Config
Model
Metrics
Resource
Comment
Registry

Attributes:

Name Type Description
TorchDtype TypeAlias
Tuple
NonEmptyUnique
ModelInputStemName TypeAlias
ModelOutputStemName TypeAlias
DerivedStemName TypeAlias

The name of a derived stem, e.g. vocals_minus_drums.

StemName TypeAlias

A name of a stem, either a model output stem or a derived stem.

DerivedStemRule TypeAlias
DerivedStemsConfig TypeAlias

TorchDtype module-attribute

TorchDtype: TypeAlias = Annotated[
    dtype, GetPydanticSchema(_get_torch_dtype_schema)
]

Tuple module-attribute

Tuple = Annotated[
    tuple[_Item, ...], BeforeValidator(_to_tuple)
]

NonEmptyUnique module-attribute

NonEmptyUnique = Annotated[
    _S,
    Len(min_length=1),
    AfterValidator(_validate_unique_sequence),
    Field(json_schema_extra={"unique_items": True}),
]
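`Tuple` presumably exists so that list inputs (for example JSON arrays or hand-written Python configs) are accepted as tuples even under strict validation, and `NonEmptyUnique` then rejects empty or duplicated sequences. The bodies of `_to_tuple` and `_validate_unique_sequence` are not shown on this page, so the plain-Python sketch below only approximates their combined effect; `to_tuple` and `validate_non_empty_unique` are hypothetical stand-ins, and they raise `ValueError` rather than a pydantic error.

```python
from collections.abc import Sequence
from typing import Any


def to_tuple(value: Any) -> Any:
    # hypothetical stand-in for `_to_tuple`: lists become tuples,
    # anything else is passed through to the normal type check
    return tuple(value) if isinstance(value, list) else value


def validate_non_empty_unique(value: Sequence[str]) -> Sequence[str]:
    # mirrors `Len(min_length=1)`: reject empty sequences
    if len(value) < 1:
        raise ValueError("sequence must contain at least one item")
    # hypothetical stand-in for `_validate_unique_sequence`: reject duplicates
    seen: set[str] = set()
    for item in value:
        if item in seen:
            raise ValueError(f"duplicate item: {item!r}")
        seen.add(item)
    return value


stems = validate_non_empty_unique(to_tuple(["vocals", "drums"]))
print(stems)  # ('vocals', 'drums')
```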

ModelInputStemName module-attribute

ModelInputStemName: TypeAlias = Literal['mixture']

ModelOutputStemName module-attribute

ModelOutputStemName: TypeAlias = Annotated[
    ModelOutputStemName, StringConstraints(min_length=1)
]

LazyModelConfig

Bases: BaseModel

A lazily validated model configuration.

Note that it is not guaranteed to be fully valid until to_concrete is called.

Methods:

Name Description
to_concrete

Validate against a real set of model parameters and convert to it.

Attributes:

Name Type Description
chunk_size ChunkSize
output_stem_names NonEmptyUnique[Tuple[ModelOutputStemName]]
stem_names tuple[ModelInputStemName | ModelOutputStemName, ...]

Returns the model's input and output stem names.

model_config

chunk_size instance-attribute

chunk_size: ChunkSize

output_stem_names instance-attribute

output_stem_names: NonEmptyUnique[
    Tuple[ModelOutputStemName]
]

to_concrete

to_concrete(
    model_params: type[ModelParamsLikeT],
    *,
    pydantic_config: ConfigDict = ConfigDict(
        extra="forbid"
    ),
) -> ModelParamsLikeT

Validate against a real set of model parameters and convert to it.

Raises:

Type Description
pydantic.ValidationError

If extra fields are present in the model parameters that do not exist in the concrete model parameters.

Source code in src/splifft/config.py
def to_concrete(
    self,
    model_params: type[ModelParamsLikeT],
    *,
    pydantic_config: ConfigDict = ConfigDict(extra="forbid"),
) -> ModelParamsLikeT:
    """Validate against a real set of model parameters and convert to it.

    :raises pydantic.ValidationError: if extra fields are present in the model parameters
        that do not exist in the concrete model parameters.
    """
    # input_type and output_type are not configurable anyway
    # TODO: use lru cache to avoid recreating the TypeAdapter in a hot loop but dict isn't hashable
    ta = TypeAdapter(
        type(
            f"{model_params.__name__}Validator",
            (model_params,),
            {"__pydantic_config__": pydantic_config},
        )  # needed for https://docs.pydantic.dev/latest/errors/usage_errors/#type-adapter-config-unused
    )  # type: ignore
    # types defined within `TYPE_CHECKING` blocks will be forward references, so we need rebuild
    ta.rebuild(_types_namespace={"TorchDtype": TorchDtype, "t": t})
    model_params_concrete: ModelParamsLikeT = ta.validate_python(self.model_dump())  # type: ignore
    return model_params_concrete
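The two-stage idea behind `LazyModelConfig` (accept arbitrary extra keys first, then check them against a concrete parameter class that forbids unknown fields) can be illustrated without pydantic. splifft itself does this with a `TypeAdapter` over a subclass carrying `__pydantic_config__`, as in the source above; in this sketch a dataclass stands in for the concrete model parameters instead. `LazyModel`, `ExampleParams`, and its `dim`/`depth` fields are hypothetical, and unknown keys raise `TypeError` rather than `pydantic.ValidationError`.

```python
from dataclasses import dataclass, field, fields
from typing import Any


@dataclass
class LazyModel:
    # known fields plus arbitrary extra keys, left unchecked until `to_concrete`
    chunk_size: int
    output_stem_names: tuple[str, ...]
    extra: dict[str, Any] = field(default_factory=dict)

    def to_concrete(self, params_cls: type) -> Any:
        data = {
            "chunk_size": self.chunk_size,
            "output_stem_names": self.output_stem_names,
            **self.extra,
        }
        allowed = {f.name for f in fields(params_cls)}
        unknown = sorted(set(data) - allowed)
        if unknown:
            # analogous to `extra="forbid"`: reject keys the concrete class lacks
            raise TypeError(f"unexpected fields for {params_cls.__name__}: {unknown}")
        return params_cls(**data)


@dataclass
class ExampleParams:  # hypothetical concrete parameter class
    chunk_size: int
    output_stem_names: tuple[str, ...]
    dim: int = 256
    depth: int = 8


lazy = LazyModel(131072, ("vocals",), {"dim": 384})
params = lazy.to_concrete(ExampleParams)
print(params.dim, params.depth)  # 384 8
```

A misspelled key such as `"dimm"` would only be caught at the `to_concrete` step, which is why the lazy config is not guaranteed to be valid before that call.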

stem_names property

Returns the model's input and output stem names.

model_config class-attribute instance-attribute

model_config = ConfigDict(strict=True, extra='allow')

StftConfig

Bases: BaseModel

Configuration for the short-time Fourier transform.

Attributes:

Name Type Description
n_fft FftSize
hop_length HopSize
win_length FftSize
window_shape WindowShape
normalized bool
conv_dtype TorchDtype | None

The data type used for the conv1d buffers.

model_config

n_fft instance-attribute

n_fft: FftSize

hop_length instance-attribute

hop_length: HopSize

win_length instance-attribute

win_length: FftSize

window_shape class-attribute instance-attribute

window_shape: WindowShape = 'hann'

normalized class-attribute instance-attribute

normalized: bool = False

conv_dtype class-attribute instance-attribute

conv_dtype: TorchDtype | None = None

The data type used for the conv1d buffers.

model_config class-attribute instance-attribute

model_config = _PYDANTIC_STRICT_CONFIG
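To get a feel for the shapes these fields produce, the number of frequency bins and frames can be computed directly. This sketch assumes the usual `torch.stft` conventions with `center=True` and `onesided=True` and an even `n_fft`; whether splifft calls the STFT with exactly these settings is not stated on this page, and the numbers below are arbitrary examples. Note also that `torch.stft` zero-pads a window shorter than `n_fft` (i.e. `win_length < n_fft`) on both sides to length `n_fft`.

```python
def stft_shape(num_samples: int, n_fft: int, hop_length: int) -> tuple[int, int]:
    # a onesided STFT of a real signal keeps n_fft // 2 + 1 frequency bins
    n_bins = n_fft // 2 + 1
    # center=True pads n_fft // 2 samples on both sides; for even n_fft
    # this yields 1 + num_samples // hop_length frames
    n_frames = 1 + num_samples // hop_length
    return n_bins, n_frames


# one second of 44.1 kHz audio, 2048-point FFT, hop of 512 samples
print(stft_shape(44100, n_fft=2048, hop_length=512))  # (1025, 87)
```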

AudioIOConfig

Bases: BaseModel

Attributes:

Name Type Description
target_sample_rate SampleRate
force_channels Channels | None

Force the input audio to mono or stereo. If None, keep the original number of channels.

model_config

target_sample_rate class-attribute instance-attribute

target_sample_rate: SampleRate = 44100

force_channels class-attribute instance-attribute

force_channels: Channels | None = 2

Force the input audio to mono or stereo. If None, keep the original number of channels.

model_config class-attribute instance-attribute

model_config = _PYDANTIC_STRICT_CONFIG

TorchCompileConfig

Bases: BaseModel

Attributes:

Name Type Description
fullgraph bool
dynamic bool
mode Literal['default', 'reduce-overhead', 'max-autotune', 'max-autotune-no-cudagraphs']

fullgraph class-attribute instance-attribute

fullgraph: bool = True

dynamic class-attribute instance-attribute

dynamic: bool = True

mode class-attribute instance-attribute

mode: Literal[
    "default",
    "reduce-overhead",
    "max-autotune",
    "max-autotune-no-cudagraphs",
] = "reduce-overhead"

InferenceConfig

Bases: BaseModel

Attributes:

Name Type Description
normalize_input_audio bool
batch_size BatchSize
force_weights_dtype TorchDtype | None
use_autocast_dtype TorchDtype | None
compile_model TorchCompileConfig | None
apply_tta bool
model_config

normalize_input_audio class-attribute instance-attribute

normalize_input_audio: bool = False

batch_size class-attribute instance-attribute

batch_size: BatchSize = 8

force_weights_dtype class-attribute instance-attribute

force_weights_dtype: TorchDtype | None = None

use_autocast_dtype class-attribute instance-attribute

use_autocast_dtype: TorchDtype | None = None

compile_model class-attribute instance-attribute

compile_model: TorchCompileConfig | None = None

apply_tta class-attribute instance-attribute

apply_tta: bool = False

model_config class-attribute instance-attribute

model_config = _PYDANTIC_STRICT_CONFIG
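`batch_size` presumably controls how many audio chunks go through the model in one forward pass. A minimal sketch of that grouping, where `batched_chunks` is a hypothetical helper and the last batch may be smaller than `batch_size` (on Python 3.12+, `itertools.batched` does the same job, yielding tuples):

```python
from collections.abc import Iterator, Sequence
from typing import TypeVar

T = TypeVar("T")


def batched_chunks(chunks: Sequence[T], batch_size: int = 8) -> Iterator[list[T]]:
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(chunks), batch_size):
        # each slice would be one forward pass; the last one may be shorter
        yield list(chunks[start:start + batch_size])


sizes = [len(batch) for batch in batched_chunks(list(range(20)), batch_size=8)]
print(sizes)  # [8, 8, 4]
```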

ChunkingConfig

Bases: BaseModel

Attributes:

Name Type Description
method Literal['overlap_add_windowed']
overlap_ratio OverlapRatio
window_shape WindowShape
padding_mode PaddingMode
model_config

method class-attribute instance-attribute

method: Literal["overlap_add_windowed"] = (
    "overlap_add_windowed"
)

overlap_ratio class-attribute instance-attribute

overlap_ratio: OverlapRatio = 0.5

window_shape class-attribute instance-attribute

window_shape: WindowShape = 'hann'

padding_mode class-attribute instance-attribute

padding_mode: PaddingMode = 'reflect'

model_config class-attribute instance-attribute

model_config = _PYDANTIC_STRICT_CONFIG
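The `overlap_add_windowed` method can be illustrated in plain Python: split the signal into chunks spaced `chunk_size * (1 - overlap_ratio)` samples apart, weight each processed chunk by a Hann window, sum the overlapping parts, and divide by the accumulated window weight. This is a sketch under stated assumptions, not splifft's implementation: it zero-pads the tail instead of applying `padding_mode`, uses a periodic Hann window, and `overlap_add` and its `process` callback are hypothetical names.

```python
import math
from collections.abc import Callable


def hann(n: int) -> list[float]:
    # periodic Hann window; the library's exact window variant is an assumption
    return [0.5 - 0.5 * math.cos(2 * math.pi * i / n) for i in range(n)]


def overlap_add(
    signal: list[float],
    chunk_size: int,
    overlap_ratio: float = 0.5,
    process: Callable[[list[float]], list[float]] = lambda chunk: chunk,
) -> list[float]:
    hop = max(1, int(chunk_size * (1 - overlap_ratio)))
    n = len(signal)
    n_chunks = max(1, math.ceil((n - chunk_size) / hop) + 1)
    total = (n_chunks - 1) * hop + chunk_size
    # zero-pad the tail so the last chunk is full (the config's `padding_mode`
    # presumably governs padding in the real implementation)
    padded = list(signal) + [0.0] * (total - n)
    window = hann(chunk_size)
    out = [0.0] * total
    weight = [0.0] * total
    for k in range(n_chunks):
        start = k * hop
        chunk = process(padded[start:start + chunk_size])
        for i, (x, w) in enumerate(zip(chunk, window)):
            out[start + i] += x * w
            weight[start + i] += w
    # normalize by the summed window weight; zero-weight samples stay at 0.0
    return [o / w if w > 1e-8 else 0.0 for o, w in zip(out, weight)][:n]


signal = [float(i) for i in range(20)]
restored = overlap_add(signal, chunk_size=8, overlap_ratio=0.5)
print(restored[:3])  # [0.0, 1.0, 2.0]
```

With an identity `process`, every sample is reconstructed except the first, where the Hann window is exactly zero. Padding the input before chunking (for example reflect padding of half a chunk) is one way to keep such edge samples covered.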

MaskingConfig

Bases: BaseModel

Attributes:

Name Type Description
add_sub_dtype TorchDtype | None
out_dtype TorchDtype | None
model_config

add_sub_dtype class-attribute instance-attribute

add_sub_dtype: TorchDtype | None = None

out_dtype class-attribute instance-attribute

out_dtype: TorchDtype | None = None

model_config class-attribute instance-attribute

model_config = _PYDANTIC_STRICT_CONFIG

DerivedStemName module-attribute

DerivedStemName: TypeAlias = Annotated[
    str, StringConstraints(min_length=1)
]

The name of a derived stem, e.g. vocals_minus_drums.

StemName module-attribute

A name of a stem, either a model output stem or a derived stem.

SubtractConfig

Bases: BaseModel

Attributes:

Name Type Description
operation Literal['subtract']
stem_name StemName
by_stem_name StemName
model_config

operation instance-attribute

operation: Literal['subtract']

stem_name instance-attribute

stem_name: StemName

by_stem_name instance-attribute

by_stem_name: StemName

model_config class-attribute instance-attribute

model_config = _PYDANTIC_STRICT_CONFIG

SumConfig

Bases: BaseModel

Attributes:

Name Type Description
operation Literal['sum']
stem_names NonEmptyUnique[Tuple[StemName]]
model_config

operation instance-attribute

operation: Literal['sum']

stem_names instance-attribute

stem_names: NonEmptyUnique[Tuple[StemName]]

model_config class-attribute instance-attribute

model_config = _PYDANTIC_STRICT_CONFIG

DerivedStemRule module-attribute

DerivedStemRule: TypeAlias = Annotated[
    Union[SubtractConfig, SumConfig],
    Discriminator("operation"),
]

DerivedStemsConfig module-attribute

DerivedStemsConfig: TypeAlias = dict[
    DerivedStemName, DerivedStemRule
]
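A `derived_stems` mapping is keyed by the new stem name, and each rule is selected by its `operation` field, which acts as the discriminator. The sketch below writes such a mapping as plain data and evaluates it sample by sample in insertion order, so a later rule could build on an earlier derived stem. The stem names, the `evaluate_derived_stems` helper, and the list-of-floats audio representation are illustrative assumptions, not splifft's processing code.

```python
derived_stems = {
    "instrumental": {"operation": "subtract", "stem_name": "mixture", "by_stem_name": "vocals"},
    "rhythm": {"operation": "sum", "stem_names": ["drums", "bass"]},
}


def evaluate_derived_stems(stems: dict[str, list[float]], rules: dict) -> dict[str, list[float]]:
    out = dict(stems)
    for name, rule in rules.items():  # dicts preserve insertion order
        if rule["operation"] == "subtract":
            a, b = out[rule["stem_name"]], out[rule["by_stem_name"]]
            out[name] = [x - y for x, y in zip(a, b)]
        elif rule["operation"] == "sum":
            parts = [out[s] for s in rule["stem_names"]]
            out[name] = [sum(samples) for samples in zip(*parts)]
        else:
            raise ValueError(f"unknown operation: {rule['operation']!r}")
    return out


stems = {"mixture": [1.0, 2.0], "vocals": [0.25, 0.5], "drums": [0.5, 0.5], "bass": [0.25, 1.0]}
result = evaluate_derived_stems(stems, derived_stems)
print(result["instrumental"], result["rhythm"])  # [0.75, 1.5] [0.75, 1.5]
```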

OutputConfig

Bases: BaseModel

Attributes:

Name Type Description
stem_names Literal['all'] | NonEmptyUnique[Tuple[StemName]]
file_format FileFormat
bit_rate BitRate | None

Output bit rate for lossy formats. The default is chosen by FFmpeg.

model_config

stem_names class-attribute instance-attribute

stem_names: (
    Literal["all"] | NonEmptyUnique[Tuple[StemName]]
) = "all"

file_format class-attribute instance-attribute

file_format: FileFormat = 'wav'

bit_rate class-attribute instance-attribute

bit_rate: BitRate | None = None

Output bit rate for lossy formats. The default is chosen by FFmpeg.

model_config class-attribute instance-attribute

model_config = _PYDANTIC_STRICT_CONFIG

Config

Bases: BaseModel

Methods:

Name Description
check_derived_stems
from_file

Attributes:

Name Type Description
identifier str

Unique identifier for this configuration

model_type ModelType
model LazyModelConfig
stft StftConfig | None
audio_io AudioIOConfig
inference InferenceConfig
chunking ChunkingConfig
masking MaskingConfig
derived_stems DerivedStemsConfig | None
output OutputConfig
experimental dict[str, Any] | None

Any extra experimental configurations outside of the splifft core.

model_config

identifier instance-attribute

identifier: str

Unique identifier for this configuration

model_type instance-attribute

model_type: ModelType

model instance-attribute

model: LazyModelConfig

stft class-attribute instance-attribute

stft: StftConfig | None = None

audio_io class-attribute instance-attribute

audio_io: AudioIOConfig = Field(
    default_factory=AudioIOConfig
)

inference class-attribute instance-attribute

inference: InferenceConfig = Field(
    default_factory=InferenceConfig
)

chunking class-attribute instance-attribute

chunking: ChunkingConfig = Field(
    default_factory=ChunkingConfig
)

masking class-attribute instance-attribute

masking: MaskingConfig = Field(
    default_factory=MaskingConfig
)

derived_stems class-attribute instance-attribute

derived_stems: DerivedStemsConfig | None = None

output class-attribute instance-attribute

output: OutputConfig = Field(default_factory=OutputConfig)

experimental class-attribute instance-attribute

experimental: dict[str, Any] | None = None

Any extra experimental configurations outside of the splifft core.

check_derived_stems

check_derived_stems() -> Self
Source code in src/splifft/config.py
@model_validator(mode="after")
def check_derived_stems(self) -> Self:
    if self.derived_stems is None:
        return self
    existing_stem_names: list[StemName] = list(self.model.stem_names)
    for derived_stem_name, definition in self.derived_stems.items():
        if derived_stem_name in existing_stem_names:
            raise PydanticCustomError(
                "derived_stem_name_conflict",
                "Derived stem `{derived_stem_name}` must not conflict with existing stem names: `{existing_stem_names}`",
                {
                    "derived_stem_name": derived_stem_name,
                    "existing_stem_names": existing_stem_names,
                },
            )
        required_stems: tuple[StemName, ...] = tuple()
        if isinstance(definition, SubtractConfig):
            required_stems = (definition.stem_name, definition.by_stem_name)
        elif isinstance(definition, SumConfig):
            required_stems = definition.stem_names
        for stem_name in required_stems:
            if stem_name not in existing_stem_names:
                raise PydanticCustomError(
                    "invalid_derived_stem",
                    "Derived stem `{derived_stem_name}` requires stem `{stem_name}` but is not found in `{existing_stem_names}`",
                    {
                        "derived_stem_name": derived_stem_name,
                        "stem_name": stem_name,
                        "existing_stem_names": existing_stem_names,
                    },
                )
        existing_stem_names.append(derived_stem_name)
    return self
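The same checks can be reproduced on plain data to see which mappings `check_derived_stems` rejects: a derived name may not shadow an existing stem, and every referenced stem must be a model stem (the `stem_names` property includes the `mixture` input) or a derived stem defined earlier in the mapping. This is a stdlib re-implementation for illustration; `check_derived_stems_plain` is a hypothetical name, and it raises `ValueError` where the validator raises `PydanticCustomError`.

```python
def check_derived_stems_plain(model_stem_names, derived_stems):
    existing = list(model_stem_names)
    for name, rule in derived_stems.items():
        if name in existing:
            raise ValueError(f"derived stem {name!r} conflicts with {existing}")
        if rule["operation"] == "subtract":
            required = (rule["stem_name"], rule["by_stem_name"])
        else:  # "sum"
            required = tuple(rule["stem_names"])
        for stem in required:
            if stem not in existing:
                raise ValueError(f"derived stem {name!r} requires {stem!r}, not found in {existing}")
        # later rules may reference this derived stem
        existing.append(name)


model_stems = ("mixture", "vocals", "drums")
# valid: the second rule builds on the first derived stem
check_derived_stems_plain(model_stems, {
    "no_vocals": {"operation": "subtract", "stem_name": "mixture", "by_stem_name": "vocals"},
    "no_vocals_no_drums": {"operation": "subtract", "stem_name": "no_vocals", "by_stem_name": "drums"},
})
```

Swapping the two rules would fail, because `no_vocals` would be referenced before it is defined.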

from_file classmethod

from_file(path: BytesPath) -> Config
Source code in src/splifft/config.py
@classmethod
def from_file(cls, path: t.BytesPath) -> Config:
    with open(path, "rb") as f:
        return Config.model_validate_json(f.read())
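`Config.from_file` opens the path in binary mode and validates the contents with `Config.model_validate_json`, under a strict config with `extra="forbid"`, so unknown top-level keys are rejected. Only `identifier`, `model_type`, and `model` lack defaults; every other section falls back to the defaults listed above. The sketch below writes such a minimal file with the standard library. The identifier, the `model_type` value (which must be a valid `ModelType` in practice), and the `chunk_size` are placeholders.

```python
import json
import tempfile
from pathlib import Path

minimal_config = {
    "identifier": "my_separator",  # placeholder
    "model_type": "bs_roformer",  # assumption: must be a valid ModelType
    "model": {
        "chunk_size": 131584,  # placeholder value
        "output_stem_names": ["vocals", "instrumental"],
        # extra model keys are allowed here and only checked by `to_concrete`
    },
    "output": {"stem_names": "all", "file_format": "wav"},
}

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "config.json"
    path.write_text(json.dumps(minimal_config, indent=2))
    loaded = json.loads(path.read_bytes())
    # with splifft installed, `Config.from_file(path)` would validate this file
```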

model_config class-attribute instance-attribute

model_config = ConfigDict(
    arbitrary_types_allowed=True,
    strict=True,
    extra="forbid",
)

Model

Bases: BaseModel

Attributes:

Name Type Description
authors list[str]
purpose Literal['separation', 'denoise', 'de-reverb', 'enhancement', 'crowd_removal', 'upscaler', 'phase_fixer'] | str
architecture Literal['bs_roformer', 'mel_roformer', 'mdx23c', 'scnet'] | str
release_date str | None

Release date in YYYY-MM-DD format (optional).

finetuned_from Identifier | None
output list[Instrument]
status Literal['alpha', 'beta', 'stable', 'deprecated'] | None
metrics list[Metrics]
description list[Comment]
approx_model_size_mb float | None

authors instance-attribute

authors: list[str]

purpose instance-attribute

purpose: (
    Literal[
        "separation",
        "denoise",
        "de-reverb",
        "enhancement",
        "crowd_removal",
        "upscaler",
        "phase_fixer",
    ]
    | str
)

architecture instance-attribute

architecture: (
    Literal[
        "bs_roformer", "mel_roformer", "mdx23c", "scnet"
    ]
    | str
)

release_date class-attribute instance-attribute

release_date: str | None = None

Release date in YYYY-MM-DD format (optional).

finetuned_from class-attribute instance-attribute

finetuned_from: Identifier | None = None

output class-attribute instance-attribute

output: list[Instrument] = Field(default_factory=list)

status class-attribute instance-attribute

status: (
    Literal["alpha", "beta", "stable", "deprecated"] | None
) = None

metrics class-attribute instance-attribute

metrics: list[Metrics] = Field(default_factory=list)

description class-attribute instance-attribute

description: list[Comment] = Field(default_factory=list)

approx_model_size_mb class-attribute instance-attribute

approx_model_size_mb: float | None = None

Metrics

Bases: BaseModel

Attributes:

Name Type Description
values dict[Instrument, dict[Metric, float]]
source Literal['mvsep'] | str | None

values class-attribute instance-attribute

values: dict[Instrument, dict[Metric, float]] = Field(
    default_factory=dict
)

source class-attribute instance-attribute

source: Literal['mvsep'] | str | None = None

Resource

Bases: BaseModel

Attributes:

Name Type Description
kind Literal['model_ckpt', 'huggingface', 'config_msst', 'colab', 'mvsep', 'other']
url str
digest str | None

kind instance-attribute

kind: Literal[
    "model_ckpt",
    "huggingface",
    "config_msst",
    "colab",
    "mvsep",
    "other",
]

url instance-attribute

url: str

digest class-attribute instance-attribute

digest: str | None = None

Comment

Bases: BaseModel

Attributes:

Name Type Description
content list[str]

Condensed, informative points about the model, written in lowercase.

author str | None

content instance-attribute

content: list[str]

Condensed, informative points about the model, written in lowercase.

author class-attribute instance-attribute

author: str | None = None

Registry

Bases: dict[Identifier, Model]

Methods:

Name Description
__get_pydantic_core_schema__
from_file

__get_pydantic_core_schema__ classmethod

__get_pydantic_core_schema__(
    source_type: Any, handler: GetCoreSchemaHandler
) -> CoreSchema
Source code in src/splifft/config.py
@classmethod
def __get_pydantic_core_schema__(
    cls, source_type: Any, handler: GetCoreSchemaHandler
) -> CoreSchema:
    return core_schema.no_info_after_validator_function(cls, handler(dict[t.Identifier, Model]))

from_file classmethod

from_file(path: StrOrBytesPath) -> Registry
Source code in src/splifft/config.py
@classmethod
def from_file(cls, path: t.StrOrBytesPath) -> Registry:
    with open(path, "r") as f:
        data = f.read()
    ta = TypeAdapter(cls)
    return ta.validate_json(data)
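`Registry` is a `dict` subclass: its core schema validates the input as `dict[Identifier, Model]` and then passes the result to `cls`, and `from_file` reads the file as text and calls `TypeAdapter(Registry).validate_json`. A registry file is therefore a JSON object mapping model identifiers to `Model` entries, where only `authors`, `purpose`, and `architecture` have no defaults. The example below builds one entry with the standard library; the identifier, author, and note text are placeholders, not a real registry record.

```python
import json

registry_json = json.dumps({
    "example_vocals_model": {  # placeholder identifier
        "authors": ["someone"],  # placeholder
        "purpose": "separation",
        "architecture": "bs_roformer",
        "status": "beta",
        "description": [{"content": ["placeholder note"], "author": None}],
    }
})

REQUIRED_MODEL_FIELDS = {"authors", "purpose", "architecture"}
entries = json.loads(registry_json)
for identifier, entry in entries.items():
    missing = REQUIRED_MODEL_FIELDS - entry.keys()
    assert not missing, f"{identifier}: missing {missing}"
# with splifft installed, `Registry.from_file(path)` would validate the same text
```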