Models

models

Source separation models.

Modules:

Name Description
basic_pitch

ICASSP 2022 Basic Pitch. Raw multi-stream outputs only, no symbolic decoding.

beat_this

Beat This! Beat Tracker.

bs_roformer

Band-Split RoPE Transformer

mdx23c

MDX23C.

pesto

PESTO: Pitch Estimation with Self-supervised Transposition-equivariant Objective.

utils

Classes:

Name Description
ModelParamsLike

A trait that must be implemented to be considered a model parameter.

HasSequenceWindowFrames

A trait for spectrogram sequence-labeling model parameters that declare the exact number of sequence frames produced by one inference waveform window.

StemSelectionPlan

Optional model-specific plan for selective stem inference.

SupportsStemSelection
ModelMetadata

Metadata about a model, including its type, parameter class, and model class.

Attributes:

Name Type Description
ModelT
ModelParamsLikeT
StateDictTransform TypeAlias

ModelParamsLike

Bases: Protocol

A trait that must be implemented to be considered a model parameter. Note that input_type and output_type belong to a model's definition and cannot be modified via the configuration dictionary.

Attributes:

Name Type Description
chunk_size ChunkSize
output_stem_names tuple[ModelOutputStemName, ...]
input_channels ModelInputChannels
input_type ModelInputType
output_type ModelOutputType
inference_archetype InferenceArchetype

chunk_size instance-attribute

chunk_size: ChunkSize

output_stem_names instance-attribute

output_stem_names: tuple[ModelOutputStemName, ...]

input_channels property

input_channels: ModelInputChannels

input_type property

input_type: ModelInputType

output_type property

output_type: ModelOutputType

inference_archetype property

inference_archetype: InferenceArchetype
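
As a rough illustration of the split between configurable fields and definition-level properties, the sketch below uses stand-in types (plain `int` and `str` rather than `ChunkSize`, `ModelInputType`, and so on). `ParamsLike`, `ToyParams`, and the returned strings are hypothetical names chosen for this example and are not part of splifft.

```python
from dataclasses import dataclass
from typing import Protocol, runtime_checkable


@runtime_checkable
class ParamsLike(Protocol):
    """Simplified stand-in for ModelParamsLike (assumed shape, not the real one)."""

    chunk_size: int
    output_stem_names: tuple[str, ...]

    @property
    def input_type(self) -> str: ...

    @property
    def output_type(self) -> str: ...


@dataclass(frozen=True)
class ToyParams:
    """Hypothetical parameter class: fields are configurable, properties are not."""

    chunk_size: int
    output_stem_names: tuple[str, ...]

    @property
    def input_type(self) -> str:
        # Fixed by the model definition, so it is deliberately not a dataclass field.
        return "spectrogram"  # illustrative value

    @property
    def output_type(self) -> str:
        return "spectrogram"  # illustrative value


params = ToyParams(chunk_size=4096, output_stem_names=("vocals", "other"))
```

Because `input_type` is a property and not a field, passing it as a constructor argument (as a configuration dictionary would) fails with a `TypeError`.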

HasSequenceWindowFrames

Bases: Protocol

A trait for spectrogram sequence-labeling model parameters that declare the exact number of sequence frames produced by one inference waveform window.

Attributes:

Name Type Description
sequence_window_frames Gt0[int]

sequence_window_frames instance-attribute

sequence_window_frames: Gt0[int]

ModelT module-attribute

ModelT = TypeVar('ModelT', bound=Module)

ModelParamsLikeT module-attribute

ModelParamsLikeT = TypeVar(
    "ModelParamsLikeT", bound=ModelParamsLike
)

StateDictTransform module-attribute

StateDictTransform: TypeAlias = Callable[
    [dict[str, Tensor]], dict[str, Tensor]
]

StemSelectionPlan dataclass

StemSelectionPlan(
    model_params: ModelParamsLikeT,
    output_stem_names: tuple[ModelOutputStemName, ...],
    state_dict_transform: StateDictTransform | None = None,
)

Bases: Generic[ModelParamsLikeT]

Optional model-specific plan for selective stem inference.

Models can provide this plan to:

- instantiate a stem-reduced parameter set, and/or
- return a checkpoint state-dict transformer that drops/remaps unrelated heads.

output_stem_names defines the output ordering produced by the instantiated model after applying the plan.

Attributes:

Name Type Description
model_params ModelParamsLikeT
output_stem_names tuple[ModelOutputStemName, ...]
state_dict_transform StateDictTransform | None

model_params instance-attribute

model_params: ModelParamsLikeT

output_stem_names instance-attribute

output_stem_names: tuple[ModelOutputStemName, ...]

state_dict_transform class-attribute instance-attribute

state_dict_transform: StateDictTransform | None = None

SupportsStemSelection

Bases: Protocol[ModelParamsLikeT]

Methods:

Name Description
__splifft_stem_selection_plan__

__splifft_stem_selection_plan__ classmethod

__splifft_stem_selection_plan__(
    model_params: ModelParamsLikeT,
    output_stem_names: tuple[ModelOutputStemName, ...],
) -> StemSelectionPlan[ModelParamsLikeT]
Source code in src/splifft/models/__init__.py
@classmethod
def __splifft_stem_selection_plan__(
    cls,
    model_params: ModelParamsLikeT,
    output_stem_names: tuple[t.ModelOutputStemName, ...],
) -> StemSelectionPlan[ModelParamsLikeT]: ...

ModelMetadata dataclass

ModelMetadata(
    model_type: ModelType,
    params: type[ModelParamsLikeT],
    model: type[ModelT],
)

Bases: Generic[ModelT, ModelParamsLikeT]

Metadata about a model, including its type, parameter class, and model class.

Methods:

Name Description
from_module

Dynamically import a model named X and its parameter dataclass XParams under a given module name (e.g. splifft.models.bs_roformer).

Attributes:

Name Type Description
model_type ModelType
params type[ModelParamsLikeT]
model type[ModelT]

model_type instance-attribute

model_type: ModelType

params instance-attribute

params: type[ModelParamsLikeT]

model instance-attribute

model: type[ModelT]

from_module classmethod

from_module(
    module_name: str,
    model_cls_name: str,
    *,
    model_type: ModelType,
    package: str | None = None,
) -> ModelMetadata[Module, ModelParamsLike]

Dynamically import a model named X and its parameter dataclass XParams under a given module name (e.g. splifft.models.bs_roformer).

Parameters:

Name Type Description Default
model_cls_name str

The name of the model class to import, e.g. BSRoformer.

required
module_name str

The name of the module to import, e.g. splifft.models.bs_roformer.

required
model_type ModelType

The type of the model, e.g. bs_roformer.

required
package str | None

The package to use as the anchor point from which to resolve the relative import to an absolute import. This is only required when performing a relative import.

None
Source code in src/splifft/models/__init__.py
@classmethod
def from_module(
    cls,
    module_name: str,
    model_cls_name: str,
    *,
    model_type: t.ModelType,
    package: str | None = None,
) -> ModelMetadata[nn.Module, ModelParamsLike]:
    """
    Dynamically import a model named `X` and its parameter dataclass `XParams` under a
    given module name (e.g. `splifft.models.bs_roformer`).

    :param model_cls_name: The name of the model class to import, e.g. `BSRoformer`.
    :param module_name: The name of the module to import, e.g. `splifft.models.bs_roformer`.
    :param model_type: The type of the model, e.g. `bs_roformer`.
    :param package: The package to use as the anchor point from which to resolve the relative import
    to an absolute import. This is only required when performing a relative import.
    """
    _loc = f"{module_name=} under {package=}"
    try:
        module = importlib.import_module(module_name, package)
    except ImportError as e:
        raise ValueError(f"failed to find or import module for {_loc}") from e

    params_cls_name = f"{model_cls_name}Params"
    model_cls = getattr(module, model_cls_name, None)
    params_cls = getattr(module, params_cls_name, None)
    if model_cls is None or params_cls is None:
        raise AttributeError(
            f"expected to find a class named `{params_cls_name}` in {_loc}, but it was not found."
        )

    return ModelMetadata(
        model_type=model_type,
        model=model_cls,
        params=params_cls,
    )
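
The `X` / `XParams` naming convention used by from_module can be tried out without splifft installed. The following is a minimal sketch, not the library's implementation: `load_model_and_params` and the in-memory `toy_models` module are hypothetical, and unlike the listing above this variant reports exactly which name was missing.

```python
import importlib
import sys
import types


def load_model_and_params(module_name: str, model_cls_name: str) -> tuple[type, type]:
    """Resolve `X` and `XParams` from a module, following the naming convention."""
    module = importlib.import_module(module_name)
    params_cls_name = f"{model_cls_name}Params"
    model_cls = getattr(module, model_cls_name, None)
    params_cls = getattr(module, params_cls_name, None)
    missing = [
        name
        for name, obj in ((model_cls_name, model_cls), (params_cls_name, params_cls))
        if obj is None
    ]
    if missing:
        raise AttributeError(f"missing {missing} in module {module_name!r}")
    return model_cls, params_cls


# Hypothetical in-memory module so the sketch runs on its own.
toy = types.ModuleType("toy_models")
toy.Toy = type("Toy", (), {})
toy.ToyParams = type("ToyParams", (), {})
sys.modules["toy_models"] = toy

model_cls, params_cls = load_model_and_params("toy_models", "Toy")
```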

mdx23c

MDX23C.

See: https://arxiv.org/pdf/2306.09382

Classes:

Name Description
MDX23CParams
Upscale
Downscale
MDX23C

Functions:

Name Description
get_norm
get_act
build_tfc_tdf

MDX23CParams dataclass

MDX23CParams(
    chunk_size: ChunkSize,
    output_stem_names: tuple[ModelOutputStemName, ...],
    dim_f: Gt0[int],
    num_subbands: Gt0[int],
    num_scales: Gt0[int],
    scale: tuple[Gt0[int], ...],
    num_blocks_per_scale: Gt0[int],
    hidden_channels: Gt0[int],
    growth: Gt0[int],
    bottleneck_factor: Gt0[int],
    norm_type: Literal["BatchNorm", "InstanceNorm"]
    | str = "InstanceNorm",
    act_type: Literal["gelu", "relu", "elu"] | str = "gelu",
    stereo: bool = True,
)

Bases: ModelParamsLike

Attributes:

Name Type Description
chunk_size ChunkSize
output_stem_names tuple[ModelOutputStemName, ...]
dim_f Gt0[int]

The size of the frequency dimension fed into the network.

num_subbands Gt0[int]
num_scales Gt0[int]
scale tuple[Gt0[int], ...]

Downscaling factor per scale.

num_blocks_per_scale Gt0[int]
hidden_channels Gt0[int]

Base number of channels.

growth Gt0[int]

Channel growth per scale.

bottleneck_factor Gt0[int]
norm_type Literal['BatchNorm', 'InstanceNorm'] | str
act_type Literal['gelu', 'relu', 'elu'] | str
stereo bool
input_channels ModelInputChannels
input_type ModelInputType
output_type ModelOutputType
inference_archetype InferenceArchetype
chunk_size instance-attribute
chunk_size: ChunkSize
output_stem_names instance-attribute
output_stem_names: tuple[ModelOutputStemName, ...]
dim_f instance-attribute
dim_f: Gt0[int]

The size of the frequency dimension fed into the network. Usually smaller than n_fft // 2 + 1.

num_subbands instance-attribute
num_subbands: Gt0[int]
num_scales instance-attribute
num_scales: Gt0[int]
scale instance-attribute
scale: tuple[Gt0[int], ...]

Downscaling factor per scale.

num_blocks_per_scale instance-attribute
num_blocks_per_scale: Gt0[int]
hidden_channels instance-attribute
hidden_channels: Gt0[int]

Base number of channels.

growth instance-attribute
growth: Gt0[int]

Channel growth per scale.

bottleneck_factor instance-attribute
bottleneck_factor: Gt0[int]
norm_type class-attribute instance-attribute
norm_type: Literal["BatchNorm", "InstanceNorm"] | str = (
    "InstanceNorm"
)
act_type class-attribute instance-attribute
act_type: Literal['gelu', 'relu', 'elu'] | str = 'gelu'
stereo class-attribute instance-attribute
stereo: bool = True
input_channels property
input_channels: ModelInputChannels
input_type property
input_type: ModelInputType
output_type property
output_type: ModelOutputType
inference_archetype property
inference_archetype: InferenceArchetype

get_norm

get_norm(norm_type: str, channels: int) -> Module
Source code in src/splifft/models/mdx23c.py
def get_norm(norm_type: str, channels: int) -> nn.Module:
    if norm_type == "BatchNorm":
        return nn.BatchNorm2d(channels)
    elif norm_type == "InstanceNorm":
        return nn.InstanceNorm2d(channels, affine=True)
    elif "GroupNorm" in norm_type:
        g = int(norm_type.replace("GroupNorm", ""))
        return nn.GroupNorm(num_groups=g, num_channels=channels)
    return nn.Identity()
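
The string handling in get_norm has two edge cases that are easy to miss: a bare "GroupNorm" has no group count to parse, and any unrecognised name falls through to an identity layer. The sketch below mirrors only the branching, without PyTorch; `describe_norm` is a hypothetical helper that returns a label instead of a module.

```python
from typing import Optional


def describe_norm(norm_type: str) -> tuple[str, Optional[int]]:
    """Return (layer kind, group count) using the same branch order as get_norm."""
    if norm_type == "BatchNorm":
        return ("BatchNorm2d", None)
    if norm_type == "InstanceNorm":
        return ("InstanceNorm2d", None)
    if "GroupNorm" in norm_type:
        # "GroupNorm8" -> 8; a bare "GroupNorm" leaves "" and int("") raises ValueError.
        return ("GroupNorm", int(norm_type.replace("GroupNorm", "")))
    # Anything else silently becomes a no-op layer.
    return ("Identity", None)


kind, groups = describe_norm("GroupNorm8")
```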

get_act

get_act(act_type: str) -> Module
Source code in src/splifft/models/mdx23c.py
def get_act(act_type: str) -> nn.Module:
    if act_type == "gelu":
        return nn.GELU()
    elif act_type == "relu":
        return nn.ReLU()
    elif act_type.startswith("elu"):
        try:
            alpha = float(act_type.replace("elu", ""))
        except ValueError:
            alpha = 1.0
        return nn.ELU(alpha)
    raise ValueError(f"unknown activation: {act_type}")
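
The "elu" branch of get_act accepts an optional numeric suffix for alpha and falls back to 1.0 when the suffix does not parse. A PyTorch-free sketch of that parsing, with `describe_act` as a hypothetical stand-in that returns a label and alpha instead of a module:

```python
from typing import Optional


def describe_act(act_type: str) -> tuple[str, Optional[float]]:
    """Return (activation kind, alpha) using the same branch order as get_act."""
    if act_type == "gelu":
        return ("GELU", None)
    if act_type == "relu":
        return ("ReLU", None)
    if act_type.startswith("elu"):
        try:
            # "elu0.5" -> 0.5; plain "elu" leaves "" which float() rejects.
            alpha = float(act_type.replace("elu", ""))
        except ValueError:
            alpha = 1.0
        return ("ELU", alpha)
    raise ValueError(f"unknown activation: {act_type}")


elu_half = describe_act("elu0.5")
```

Note that, unlike get_norm, an unknown activation name raises instead of falling back silently.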

build_tfc_tdf

build_tfc_tdf(
    in_c: int,
    c: int,
    blocks_per_scale: int,
    f: int,
    bn: int,
    norm_type: str,
    act_type: str,
) -> TfcTdf
Source code in src/splifft/models/mdx23c.py
def build_tfc_tdf(
    in_c: int,
    c: int,
    blocks_per_scale: int,
    f: int,
    bn: int,
    norm_type: str,
    act_type: str,
) -> TfcTdf:
    return TfcTdf(
        in_channels=in_c,
        out_channels=c,
        num_blocks=blocks_per_scale,
        f_bins=f,
        bottleneck_factor=bn,
        norm_factory=lambda channels: get_norm(norm_type, channels),
        act_factory=lambda: get_act(act_type),
    )

Upscale

Upscale(
    in_c: int,
    out_c: int,
    scale: tuple[int, int],
    norm_type: str,
    act_type: str,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
conv
Source code in src/splifft/models/mdx23c.py
def __init__(
    self,
    in_c: int,
    out_c: int,
    scale: tuple[int, int],
    norm_type: str,
    act_type: str,
):
    super().__init__()
    self.conv = nn.Sequential(
        get_norm(norm_type, in_c),
        get_act(act_type),
        nn.ConvTranspose2d(
            in_channels=in_c,
            out_channels=out_c,
            kernel_size=scale,
            stride=scale,
            bias=False,
        ),
    )
conv instance-attribute
conv = Sequential(
    get_norm(norm_type, in_c),
    get_act(act_type),
    ConvTranspose2d(
        in_channels=in_c,
        out_channels=out_c,
        kernel_size=scale,
        stride=scale,
        bias=False,
    ),
)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/mdx23c.py
def forward(self, x: Tensor) -> Tensor:
    return self.conv(x)  # type: ignore

Downscale

Downscale(
    in_c: int,
    out_c: int,
    scale: tuple[int, int],
    norm_type: str,
    act_type: str,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
conv
Source code in src/splifft/models/mdx23c.py
def __init__(
    self,
    in_c: int,
    out_c: int,
    scale: tuple[int, int],
    norm_type: str,
    act_type: str,
):
    super().__init__()
    self.conv = nn.Sequential(
        get_norm(norm_type, in_c),
        get_act(act_type),
        nn.Conv2d(
            in_channels=in_c,
            out_channels=out_c,
            kernel_size=scale,
            stride=scale,
            bias=False,
        ),
    )
conv instance-attribute
conv = Sequential(
    get_norm(norm_type, in_c),
    get_act(act_type),
    Conv2d(
        in_channels=in_c,
        out_channels=out_c,
        kernel_size=scale,
        stride=scale,
        bias=False,
    ),
)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/mdx23c.py
def forward(self, x: Tensor) -> Tensor:
    return self.conv(x)  # type: ignore

MDX23C

MDX23C(cfg: MDX23CParams)

Bases: Module, SupportsStemSelection[MDX23CParams]

Methods:

Name Description
cac2cws
cws2cac
forward

Maps an input spectrogram (B, F*S, T, 2) to an output spectrogram (B, N, F*S, T, 2).

__splifft_stem_selection_plan__

Slice final_conv.2 per requested stem.

Attributes:

Name Type Description
cfg
num_target_instruments
audio_channels
num_subbands
first_conv
encoder_blocks
bottleneck_block
decoder_blocks
final_conv
Source code in src/splifft/models/mdx23c.py
def __init__(self, cfg: MDX23CParams):
    super().__init__()
    self.cfg = cfg
    if len(cfg.scale) != 2:
        raise ValueError(f"expected `scale` to have 2 elements, got {cfg.scale}")
    scale_2d = (cfg.scale[0], cfg.scale[1])
    self.num_target_instruments = len(cfg.output_stem_names)
    self.audio_channels = 2 if cfg.stereo else 1
    self.num_subbands = cfg.num_subbands

    dim_c = self.num_subbands * self.audio_channels * 2
    n = cfg.num_scales
    blocks_per_scale = cfg.num_blocks_per_scale
    c = cfg.hidden_channels
    g = cfg.growth
    bn = cfg.bottleneck_factor
    f = cfg.dim_f // self.num_subbands

    self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False)

    self.encoder_blocks = nn.ModuleList()
    for _ in range(n):
        block = nn.Module()
        block.tfc_tdf = build_tfc_tdf(
            c, c, blocks_per_scale, f, bn, cfg.norm_type, cfg.act_type
        )
        block.downscale = Downscale(c, c + g, scale_2d, cfg.norm_type, cfg.act_type)
        f = f // scale_2d[1]
        c += g
        self.encoder_blocks.append(block)

    self.bottleneck_block = build_tfc_tdf(
        c, c, blocks_per_scale, f, bn, cfg.norm_type, cfg.act_type
    )

    self.decoder_blocks = nn.ModuleList()
    for _ in range(n):
        block = nn.Module()
        block.upscale = Upscale(c, c - g, scale_2d, cfg.norm_type, cfg.act_type)
        f = f * scale_2d[1]
        c -= g
        block.tfc_tdf = build_tfc_tdf(
            2 * c, c, blocks_per_scale, f, bn, cfg.norm_type, cfg.act_type
        )
        self.decoder_blocks.append(block)

    self.final_conv = nn.Sequential(
        nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False),
        get_act(cfg.act_type),
        nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False),
    )
cfg instance-attribute
cfg = cfg
num_target_instruments instance-attribute
num_target_instruments = len(output_stem_names)
audio_channels instance-attribute
audio_channels = 2 if stereo else 1
num_subbands instance-attribute
num_subbands = num_subbands
first_conv instance-attribute
first_conv = Conv2d(dim_c, c, 1, 1, 0, bias=False)
encoder_blocks instance-attribute
encoder_blocks = ModuleList()
bottleneck_block instance-attribute
bottleneck_block = build_tfc_tdf(
    c, c, blocks_per_scale, f, bn, norm_type, act_type
)
decoder_blocks instance-attribute
decoder_blocks = ModuleList()
final_conv instance-attribute
final_conv = Sequential(
    Conv2d(c + dim_c, c, 1, 1, 0, bias=False),
    get_act(act_type),
    Conv2d(
        c,
        num_target_instruments * dim_c,
        1,
        1,
        0,
        bias=False,
    ),
)
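
The constructor's bookkeeping of channel counts and frequency bins can be followed without building any layers. The sketch below reproduces only that arithmetic; `mdx23c_layout` is a hypothetical helper, `freq_scale` stands for `scale[1]`, and the numbers in the usage line are made-up values, not a shipped configuration.

```python
def mdx23c_layout(
    num_subbands: int,
    stereo: bool,
    dim_f: int,
    num_scales: int,
    hidden_channels: int,
    growth: int,
    freq_scale: int,
):
    """Track (channels, frequency bins) through encoder, bottleneck and decoder."""
    audio_channels = 2 if stereo else 1
    dim_c = num_subbands * audio_channels * 2  # subbands x channels x (real, imag)
    c, f = hidden_channels, dim_f // num_subbands

    encoder = []
    for _ in range(num_scales):
        encoder.append((c, f))  # TFC-TDF block runs at (c, f) before downscaling
        c, f = c + growth, f // freq_scale

    bottleneck = (c, f)

    decoder = []
    for _ in range(num_scales):
        c, f = c - growth, f * freq_scale
        # Input has 2 * c channels: upscaled features concatenated with the skip.
        decoder.append((2 * c, c, f))

    return dim_c, encoder, bottleneck, decoder


dim_c, encoder, bottleneck, decoder = mdx23c_layout(
    num_subbands=4, stereo=True, dim_f=2048,
    num_scales=3, hidden_channels=32, growth=16, freq_scale=2,
)
```

Because the frequency count uses floor division on the way down, the decoder only returns to the original number of bins when `dim_f // num_subbands` is divisible by `freq_scale ** num_scales`.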
cac2cws
cac2cws(x: Tensor) -> Tensor
Source code in src/splifft/models/mdx23c.py
def cac2cws(self, x: Tensor) -> Tensor:
    return rearrange(x, "b c (k f) t -> b (c k) f t", k=self.num_subbands)
cws2cac
cws2cac(x: Tensor) -> Tensor
Source code in src/splifft/models/mdx23c.py
def cws2cac(self, x: Tensor) -> Tensor:
    return rearrange(x, "b (c k) f t -> b c (k f) t", k=self.num_subbands)
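
The two patterns fold frequency subbands into the channel axis and back again. One way to read them is as an index mapping, sketched here in plain Python under the assumption that the grouped axes follow einops ordering (the first name in a parenthesised group varies slowest). The helper names are hypothetical.

```python
def cac2cws_index(c: int, freq: int, num_subbands: int, f_per_band: int) -> tuple[int, int]:
    """Map (channel, full-band frequency) to (channel-with-subband, in-band frequency)."""
    k, f = divmod(freq, f_per_band)  # "(k f)": subband index k is the outer factor
    return c * num_subbands + k, f   # "(c k)": channel c is the outer factor


def cws2cac_index(channel: int, f: int, num_subbands: int, f_per_band: int) -> tuple[int, int]:
    """Inverse mapping, back to (channel, full-band frequency)."""
    c, k = divmod(channel, num_subbands)
    return c, k * f_per_band + f


folded = cac2cws_index(c=1, freq=5, num_subbands=2, f_per_band=4)
restored = cws2cac_index(*folded, num_subbands=2, f_per_band=4)
```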
forward
forward(x: Tensor) -> Tensor

Parameters:

Name Type Description Default
x Tensor

input spectrogram (B, F*S, T, 2)

required

Returns:

Type Description
Tensor

output spectrogram (B, N, F*S, T, 2)

Source code in src/splifft/models/mdx23c.py
def forward(self, x: Tensor) -> Tensor:
    """
    :param x: input spectrogram (B, F*S, T, 2)
    :return: output spectrogram (B, N, F*S, T, 2)
    """
    b, fs, t, ri = x.shape
    if ri != 2:
        raise ValueError(f"expected final complex axis of size 2, got {ri}")
    f_full = fs // self.audio_channels
    if fs % self.audio_channels != 0:
        raise ValueError(
            f"expected frequency-channel axis divisible by audio channels ({self.audio_channels}), "
            f"got {fs}"
        )

    x_in = rearrange(
        x,
        "b (f s) t ri -> b (s ri) f t",
        s=self.audio_channels,
        ri=2,
    )
    x_in = x_in[..., : self.cfg.dim_f, :]
    mix = x_in = self.cac2cws(x_in)
    first_conv_out = x_in = self.first_conv(x_in)
    x_in = rearrange(x_in, "b c f t -> b c t f")

    encoder_outputs = []
    for block in self.encoder_blocks:
        x_in = block.tfc_tdf(x_in)  # type: ignore
        encoder_outputs.append(x_in)
        x_in = block.downscale(x_in)  # type: ignore

    x_in = self.bottleneck_block(x_in)

    for block in self.decoder_blocks:
        x_in = block.upscale(x_in)  # type: ignore
        x_in = torch.cat([x_in, encoder_outputs.pop()], 1)
        x_in = block.tfc_tdf(x_in)  # type: ignore

    x_in = rearrange(x_in, "b c t f -> b c f t")
    x_in = x_in * first_conv_out
    x_in = self.final_conv(torch.cat([mix, x_in], 1))
    x_in = self.cws2cac(x_in)

    x_in = rearrange(
        x_in,
        "b (n c) f t -> b n c f t",
        n=self.num_target_instruments,
    )

    if f_full > self.cfg.dim_f:
        pad_size = f_full - self.cfg.dim_f
        x_in = torch.nn.functional.pad(x_in, (0, 0, 0, pad_size))

    x_in = rearrange(
        x_in,
        "b n (s ri) f t -> b n (f s) t ri",
        s=self.audio_channels,
        ri=2,
    )

    return x_in
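
To see how the frequency axis is truncated to `dim_f` on the way in and zero-padded back on the way out, the shapes can be traced without running the network. This is a sketch only: `trace_forward_shapes` is a hypothetical helper that skips the encoder and decoder (which preserve the folded shape end to end), and the example sizes are invented.

```python
def trace_forward_shapes(
    b: int, fs: int, t: int,
    audio_channels: int, dim_f: int, num_subbands: int, num_stems: int,
) -> dict[str, tuple[int, ...]]:
    """Return the tensor shape at the main stages of MDX23C.forward."""
    if fs % audio_channels != 0:
        raise ValueError(f"frequency-channel axis {fs} not divisible by {audio_channels}")
    f_full = fs // audio_channels
    c = audio_channels * 2  # (s ri): audio channels x (real, imag)
    f_kept = min(f_full, dim_f)  # slicing with [: dim_f] keeps at most dim_f bins
    pad = max(f_full - dim_f, 0)  # bins restored as zeros at the end
    return {
        "input": (b, fs, t, 2),
        "channels_first": (b, c, f_full, t),
        "truncated": (b, c, f_kept, t),
        "subbands_folded": (b, c * num_subbands, f_kept // num_subbands, t),
        "per_stem": (b, num_stems, c, f_kept, t),
        "output": (b, num_stems, (f_kept + pad) * audio_channels, t, 2),
    }


shapes = trace_forward_shapes(
    b=1, fs=2050, t=256, audio_channels=2, dim_f=1024, num_subbands=4, num_stems=2
)
```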
__splifft_stem_selection_plan__ classmethod
__splifft_stem_selection_plan__(
    model_params: MDX23CParams,
    output_stem_names: tuple[ModelOutputStemName, ...],
) -> StemSelectionPlan[MDX23CParams]

Slice final_conv.2 per requested stem.

Source code in src/splifft/models/mdx23c.py
@classmethod
def __splifft_stem_selection_plan__(
    cls,
    model_params: MDX23CParams,
    output_stem_names: tuple[t.ModelOutputStemName, ...],
) -> StemSelectionPlan[MDX23CParams]:
    """Slice `final_conv.2` per requested stem."""

    full_stem_names = tuple(model_params.output_stem_names)
    requested_stem_names = tuple(output_stem_names)
    if requested_stem_names == full_stem_names:
        return StemSelectionPlan(
            model_params=model_params,
            output_stem_names=full_stem_names,
        )

    selected_stem_indices = tuple(full_stem_names.index(name) for name in requested_stem_names)
    dim_c = model_params.num_subbands * (2 if model_params.stereo else 1) * 2

    def state_dict_transform(state_dict: dict[str, Tensor]) -> dict[str, Tensor]:
        key = "final_conv.2.weight"
        if key not in state_dict:
            return state_dict

        transformed = dict(state_dict)
        weight = transformed[key]
        grouped = weight.reshape(len(full_stem_names), dim_c, *weight.shape[1:])
        selected = grouped.index_select(
            0,
            torch.tensor(selected_stem_indices, device=weight.device),
        )
        transformed[key] = selected.reshape(
            len(requested_stem_names) * dim_c, *weight.shape[1:]
        )
        return transformed

    return StemSelectionPlan(
        model_params=replace(model_params, output_stem_names=requested_stem_names),
        output_stem_names=requested_stem_names,
        state_dict_transform=state_dict_transform,
    )
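
The state-dict transform treats the output rows of `final_conv.2.weight` as contiguous groups of `dim_c` rows per stem and keeps only the requested groups, in the requested order. A list-based sketch of that selection, with a hypothetical `select_stem_rows` helper and integers standing in for weight rows:

```python
def select_stem_rows(
    rows: list,
    full_stem_names: tuple[str, ...],
    requested_stem_names: tuple[str, ...],
    dim_c: int,
) -> list:
    """Keep the dim_c-row group of each requested stem, in request order."""
    if len(rows) != len(full_stem_names) * dim_c:
        raise ValueError("row count must equal number of stems times dim_c")
    selected = []
    for name in requested_stem_names:
        # tuple.index raises ValueError if the stem is not one the model produces.
        start = full_stem_names.index(name) * dim_c
        selected.extend(rows[start : start + dim_c])
    return selected


kept = select_stem_rows(
    rows=list(range(8)),
    full_stem_names=("vocals", "drums", "bass", "other"),
    requested_stem_names=("bass", "vocals"),
    dim_c=2,
)
```

The request order matters: the reduced model emits stems in the order given by `requested_stem_names`, which is why the plan also returns that tuple as `output_stem_names`.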

beat_this

Beat This! Beat Tracker.

Classes:

Name Description
BeatThisParams
PartialFTTransformer
SumHead
Head
BeatThis

BeatThisParams dataclass

BeatThisParams(
    chunk_size: ChunkSize,
    output_stem_names: tuple[ModelOutputStemName, ...],
    sequence_window_frames: Gt0[int],
    spect_dim: Gt0[int] = 128,
    transformer_dim: Gt0[int] = 512,
    ff_mult: Gt0[int] = 4,
    n_layers: Gt0[int] = 6,
    head_dim: Gt0[int] = 32,
    stem_dim: Gt0[int] = 32,
    dropout_frontend: Dropout = 0.1,
    dropout_transformer: Dropout = 0.2,
    sum_head: bool = True,
    partial_transformers: bool = True,
    rotary_embed_dtype: TorchDtype | None = None,
    transformer_residual_dtype: TorchDtype | None = None,
    log_mel_hop_length: HopSize = 441,
)

Bases: ModelParamsLike, HasSequenceWindowFrames

Attributes:

Name Type Description
chunk_size ChunkSize
output_stem_names tuple[ModelOutputStemName, ...]
sequence_window_frames Gt0[int]
spect_dim Gt0[int]
transformer_dim Gt0[int]
ff_mult Gt0[int]
n_layers Gt0[int]
head_dim Gt0[int]
stem_dim Gt0[int]
dropout_frontend Dropout
dropout_transformer Dropout
sum_head bool
partial_transformers bool
rotary_embed_dtype TorchDtype | None
transformer_residual_dtype TorchDtype | None
log_mel_hop_length HopSize

The hop length of the log mel spectrogram.

input_channels ModelInputChannels
input_type ModelInputType
output_type ModelOutputType
inference_archetype InferenceArchetype
chunk_size instance-attribute
chunk_size: ChunkSize
output_stem_names instance-attribute
output_stem_names: tuple[ModelOutputStemName, ...]
sequence_window_frames instance-attribute
sequence_window_frames: Gt0[int]
spect_dim class-attribute instance-attribute
spect_dim: Gt0[int] = 128
transformer_dim class-attribute instance-attribute
transformer_dim: Gt0[int] = 512
ff_mult class-attribute instance-attribute
ff_mult: Gt0[int] = 4
n_layers class-attribute instance-attribute
n_layers: Gt0[int] = 6
head_dim class-attribute instance-attribute
head_dim: Gt0[int] = 32
stem_dim class-attribute instance-attribute
stem_dim: Gt0[int] = 32
dropout_frontend class-attribute instance-attribute
dropout_frontend: Dropout = 0.1
dropout_transformer class-attribute instance-attribute
dropout_transformer: Dropout = 0.2
sum_head class-attribute instance-attribute
sum_head: bool = True
partial_transformers class-attribute instance-attribute
partial_transformers: bool = True
rotary_embed_dtype class-attribute instance-attribute
rotary_embed_dtype: TorchDtype | None = None
transformer_residual_dtype class-attribute instance-attribute
transformer_residual_dtype: TorchDtype | None = None
log_mel_hop_length class-attribute instance-attribute
log_mel_hop_length: HopSize = 441

The hop length of the log mel spectrogram.

Warning

This must match the hop_length in the LogMelConfig to ensure the rotary embeddings are sized correctly for the sequence length.
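
The reason for this constraint is that the time-axis rotary table is sized as `chunk_size // log_mel_hop_length`. A small arithmetic sketch, using the 661500-sample chunk mentioned in the constructor comment and a hypothetical helper name:

```python
def max_sequence_frames(chunk_size: int, log_mel_hop_length: int = 441) -> int:
    """Number of time frames the rotary embedding is sized for."""
    return chunk_size // log_mel_hop_length


frames_default = max_sequence_frames(661_500)        # hop length 441
frames_other = max_sequence_frames(661_500, 512)     # a different, mismatched hop length
```

With the default hop length the table covers 1500 frames; a spectrogram computed with another hop length would produce a different frame count for the same chunk.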

input_channels property
input_channels: ModelInputChannels
input_type property
input_type: ModelInputType
output_type property
output_type: ModelOutputType
inference_archetype property
inference_archetype: InferenceArchetype

PartialFTTransformer

PartialFTTransformer(
    dim: int,
    dim_head: int,
    n_head: int,
    rotary_embed_f: RotaryEmbedding,
    rotary_embed_t: RotaryEmbedding,
    dropout: float,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
attnF
ffF
attnT
ffT
Source code in src/splifft/models/beat_this.py
def __init__(
    self,
    dim: int,
    dim_head: int,
    n_head: int,
    rotary_embed_f: RotaryEmbedding,
    rotary_embed_t: RotaryEmbedding,
    dropout: float,
):
    super().__init__()
    self.attnF = Attention(
        dim, heads=n_head, dim_head=dim_head, dropout=dropout, rotary_embed=rotary_embed_f
    )
    self.ffF = FeedForward(dim, dropout=dropout)
    self.attnT = Attention(
        dim, heads=n_head, dim_head=dim_head, dropout=dropout, rotary_embed=rotary_embed_t
    )
    self.ffT = FeedForward(dim, dropout=dropout)
attnF instance-attribute
attnF = Attention(
    dim,
    heads=n_head,
    dim_head=dim_head,
    dropout=dropout,
    rotary_embed=rotary_embed_f,
)
ffF instance-attribute
ffF = FeedForward(dim, dropout=dropout)
attnT instance-attribute
attnT = Attention(
    dim,
    heads=n_head,
    dim_head=dim_head,
    dropout=dropout,
    rotary_embed=rotary_embed_t,
)
ffT instance-attribute
ffT = FeedForward(dim, dropout=dropout)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/beat_this.py
def forward(self, x: Tensor) -> Tensor:
    b = len(x)
    x = rearrange(x, "b c f t -> (b t) f c")
    x = x + self.attnF(x)
    x = x + self.ffF(x)
    x = rearrange(x, "(b t) f c -> (b f) t c", b=b)
    x = x + self.attnT(x)
    x = x + self.ffT(x)
    x = rearrange(x, "(b f) t c -> b c f t", b=b)
    return x

SumHead

SumHead(input_dim: int)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
beat_downbeat_lin
Source code in src/splifft/models/beat_this.py
def __init__(self, input_dim: int):
    super().__init__()
    self.beat_downbeat_lin = nn.Linear(input_dim, 2)
beat_downbeat_lin instance-attribute
beat_downbeat_lin = Linear(input_dim, 2)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/beat_this.py
def forward(self, x: Tensor) -> Tensor:
    beat_downbeat = self.beat_downbeat_lin(x)
    beat, downbeat = rearrange(beat_downbeat, "b t c -> c b t", c=2)

    # aggregate beats and downbeats prediction
    # autocast to float16 disabled to avoid numerical issues causing NaNs
    device_type = beat.device.type
    if device_type != "mps" and torch.amp.is_autocast_available(device_type):  # type: ignore
        disable_autocast = torch.autocast(device_type, enabled=False)
    else:
        disable_autocast = contextlib.nullcontext()

    with disable_autocast:
        beat = beat.float() + downbeat.float()
    return torch.stack([beat, downbeat], dim=0)  # (2, b, t)
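
The effect of SumHead on the logits is small enough to show with plain lists: the beat row becomes beat plus downbeat, while the downbeat row is passed through unchanged. `sum_head` below is a hypothetical, single-sequence stand-in that ignores the batch axis and the autocast handling.

```python
def sum_head(beat_logits: list[float], downbeat_logits: list[float]) -> list[list[float]]:
    """Return [beat + downbeat, downbeat], mirroring SumHead's output order."""
    summed = [b + d for b, d in zip(beat_logits, downbeat_logits)]
    return [summed, list(downbeat_logits)]


out = sum_head([0.5, -1.0], [2.0, 0.25])
```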

Head

Head(input_dim: int)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
beat_downbeat_lin
Source code in src/splifft/models/beat_this.py
def __init__(self, input_dim: int):
    super().__init__()
    self.beat_downbeat_lin = nn.Linear(input_dim, 2)
beat_downbeat_lin instance-attribute
beat_downbeat_lin = Linear(input_dim, 2)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/beat_this.py
def forward(self, x: Tensor) -> Tensor:
    beat_downbeat = self.beat_downbeat_lin(x)
    beat, downbeat = rearrange(beat_downbeat, "b t c -> c b t", c=2)
    return torch.stack([beat, downbeat], dim=0)

BeatThis

BeatThis(cfg: BeatThisParams)

Bases: Module

Methods:

Name Description
forward

Maps an input spectrogram (B, T, F) to logits (2, B, T), ordered as [Beats, Downbeats].

Attributes:

Name Type Description
spect_dim
frontend
transformer_blocks
task_heads
Source code in src/splifft/models/beat_this.py
def __init__(self, cfg: BeatThisParams):
    super().__init__()
    self.spect_dim = cfg.spect_dim

    max_frames = cfg.chunk_size // cfg.log_mel_hop_length
    rotary_embed_t = RotaryEmbedding(
        seq_len=max_frames,  # by default 1500 frames * 441 = 661500 samples
        dim_head=cfg.head_dim,
        dtype=cfg.rotary_embed_dtype,
    )

    # NOTE: removed rearrange from original impl to standardise input as (B, F, T)
    stem = nn.Sequential(
        OrderedDict(
            bn1d=nn.BatchNorm1d(cfg.spect_dim),
            add_channel=Rearrange("b f t -> b 1 f t"),
            conv2d=nn.Conv2d(
                in_channels=1,
                out_channels=cfg.stem_dim,
                kernel_size=(4, 3),
                stride=(4, 1),
                padding=(0, 1),
                bias=False,
            ),
            bn2d=nn.BatchNorm2d(cfg.stem_dim),
            activation=nn.GELU(),
        )
    )

    spect_dim = cfg.spect_dim // 4
    dim = cfg.stem_dim
    frontend_blocks = []
    for _ in range(3):
        rotary_embed_f = RotaryEmbedding(
            seq_len=spect_dim,
            dim_head=cfg.head_dim,
            dtype=cfg.rotary_embed_dtype,
        )
        frontend_blocks.append(
            nn.Sequential(
                OrderedDict(
                    partial=(
                        PartialFTTransformer(
                            dim=dim,
                            dim_head=cfg.head_dim,
                            n_head=dim // cfg.head_dim,
                            rotary_embed_f=rotary_embed_f,
                            rotary_embed_t=rotary_embed_t,
                            dropout=cfg.dropout_frontend,
                        )
                        if cfg.partial_transformers
                        else nn.Identity()
                    ),
                    conv2d=nn.Conv2d(
                        in_channels=dim,
                        out_channels=dim * 2,
                        kernel_size=(2, 3),
                        stride=(2, 1),
                        padding=(0, 1),
                        bias=False,
                    ),
                    norm=nn.BatchNorm2d(dim * 2),
                    activation=nn.GELU(),
                )
            )
        )
        dim *= 2
        spect_dim //= 2

    self.frontend = nn.Sequential(
        OrderedDict(
            stem=stem,
            blocks=nn.Sequential(*frontend_blocks),
            concat=Rearrange("b c f t -> b t (c f)"),
            linear=nn.Linear(dim * spect_dim, cfg.transformer_dim),
        )
    )

    n_heads = cfg.transformer_dim // cfg.head_dim
    # TODO check if this is really equivalent
    self.transformer_blocks = Transformer(
        dim=cfg.transformer_dim,
        depth=cfg.n_layers,
        dim_head=cfg.head_dim,
        heads=n_heads,
        attn_dropout=cfg.dropout_transformer,
        ff_dropout=cfg.dropout_transformer,
        ff_mult=cfg.ff_mult,
        norm_output=True,
        rotary_embed=rotary_embed_t,
        transformer_residual_dtype=cfg.transformer_residual_dtype,
    )

    if cfg.sum_head:
        self.task_heads = SumHead(cfg.transformer_dim)
    else:
        self.task_heads = Head(cfg.transformer_dim)
spect_dim instance-attribute
spect_dim = spect_dim
frontend instance-attribute
frontend = Sequential(
    OrderedDict(
        stem=stem,
        blocks=Sequential(*frontend_blocks),
        concat=Rearrange("b c f t -> b t (c f)"),
        linear=Linear(dim * spect_dim, transformer_dim),
    )
)
transformer_blocks instance-attribute
transformer_blocks = Transformer(
    dim=transformer_dim,
    depth=n_layers,
    dim_head=head_dim,
    heads=n_heads,
    attn_dropout=dropout_transformer,
    ff_dropout=dropout_transformer,
    ff_mult=ff_mult,
    norm_output=True,
    rotary_embed=rotary_embed_t,
    transformer_residual_dtype=transformer_residual_dtype,
)
task_heads instance-attribute
task_heads = SumHead(transformer_dim)
forward
forward(x: Tensor) -> Tensor

Parameters:

Name Type Description Default
x Tensor

Input spectrogram (B, T, F)

required

Returns:

Type Description
Tensor

Logits (2, B, T) -> [Beats, Downbeats]

Source code in src/splifft/models/beat_this.py
def forward(self, x: Tensor) -> Tensor:
    """
    :param x: Input spectrogram (B, T, F)
    :return: Logits (2, B, T) -> [Beats, Downbeats]
    """
    if x.ndim != 3:
        raise ValueError(f"expected 3D spectrogram input, got shape={tuple(x.shape)}")
    if x.shape[-1] != self.spect_dim:
        raise ValueError(
            f"expected beat_this input shape `(B,T,{self.spect_dim})`, got {tuple(x.shape)}"
        )

    x = x.transpose(1, 2)
    x = self.frontend(x)
    x = self.transformer_blocks(x)
    x = self.task_heads(x)
    return x
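
The frontend's final `nn.Linear` takes `dim * spect_dim` input features, where both values are the ones left after the stem and the three downsampling blocks. The shape bookkeeping can be sketched in plain Python as follows. The helper name `frontend_shapes` and the values `spect_dim=128`, `stem_dim=32` are assumptions chosen for illustration, not values confirmed by this page.

```python
def frontend_shapes(spect_dim: int, stem_dim: int, n_blocks: int = 3) -> list[tuple[int, int]]:
    """Return (channels, freq_bins) after the stem and after each frontend block."""
    # Stem: Conv2d with kernel (4, 3), stride (4, 1), padding (0, 1) divides the
    # frequency axis by 4 and leaves the time axis unchanged.
    channels, freq_bins = stem_dim, spect_dim // 4
    shapes = [(channels, freq_bins)]
    for _ in range(n_blocks):
        # Each block: Conv2d with kernel (2, 3), stride (2, 1) doubles the
        # channels and halves the frequency axis.
        channels, freq_bins = channels * 2, freq_bins // 2
        shapes.append((channels, freq_bins))
    return shapes


# Hypothetical configuration values, used only for illustration.
shapes = frontend_shapes(spect_dim=128, stem_dim=32)
final_channels, final_freq_bins = shapes[-1]
# Width of the flattened "(c f)" axis that feeds the final nn.Linear.
linear_in_features = final_channels * final_freq_bins
print(shapes, linear_in_features)
```

Because every convolution uses stride 1 and padding 1 with a kernel of width 3 along time, the number of frames `T` is preserved, which is why the head can emit one logit per input frame.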

pesto

PESTO: Pitch Estimation with Self-supervised Transposition-equivariant Objective.

See: https://github.com/SonyCSLParis/pesto, https://arxiv.org/abs/2309.02265

Classes:

Name Description
PestoParams
ToeplitzLinear
Resnet1d
ConfidenceClassifier

Frame-level voiced/unvoiced confidence head.

Pesto

PESTO inference head over externally computed HCQT features.

Functions:

Name Description
reduce_activations

Reduce per-bin probabilities to scalar pitch per frame.

Attributes:

Name Type Description
PestoReduction

PestoReduction module-attribute

PestoReduction = Literal['argmax', 'mean', 'alwa']

PestoParams dataclass

PestoParams(
    chunk_size: ChunkSize,
    output_stem_names: tuple[ModelOutputStemName, ...],
    sequence_window_frames: Gt0[int],
    reduction: PestoReduction = "alwa",
    convert_to_freq: bool = True,
    crop_freq_bins_bottom: Ge0[int] = 16,
    crop_freq_bins_top: Ge0[int] = 16,
    n_chan_input: Gt0[int] = 1,
    n_chan_layers: tuple[Gt0[int], ...] = (
        40,
        30,
        30,
        10,
        3,
    ),
    n_prefilt_layers: Gt0[int] = 3,
    prefilt_kernel_size: Gt0[int] = 39,
    residual: bool = True,
    n_bins_in: Gt0[int] = 219,
    output_dim: Gt0[int] = 384,
    activation_fn: Literal[
        "relu", "silu", "leaky"
    ] = "leaky",
    a_lrelu: Ge0[float] = 0.3,
    p_dropout: Dropout = 0.2,
    bins_per_semitone: Gt0[int] = 3,
)

Bases: ModelParamsLike, HasSequenceWindowFrames

Attributes:

Name Type Description
chunk_size ChunkSize
output_stem_names tuple[ModelOutputStemName, ...]
sequence_window_frames Gt0[int]
reduction PestoReduction
convert_to_freq bool
crop_freq_bins_bottom Ge0[int]
crop_freq_bins_top Ge0[int]
n_chan_input Gt0[int]
n_chan_layers tuple[Gt0[int], ...]
n_prefilt_layers Gt0[int]
prefilt_kernel_size Gt0[int]
residual bool
n_bins_in Gt0[int]
output_dim Gt0[int]
activation_fn Literal['relu', 'silu', 'leaky']
a_lrelu Ge0[float]
p_dropout Dropout
bins_per_semitone Gt0[int]
input_channels ModelInputChannels
input_type ModelInputType
output_type ModelOutputType
inference_archetype InferenceArchetype
chunk_size instance-attribute
chunk_size: ChunkSize
output_stem_names instance-attribute
output_stem_names: tuple[ModelOutputStemName, ...]
sequence_window_frames instance-attribute
sequence_window_frames: Gt0[int]
reduction class-attribute instance-attribute
reduction: PestoReduction = 'alwa'
convert_to_freq class-attribute instance-attribute
convert_to_freq: bool = True
crop_freq_bins_bottom class-attribute instance-attribute
crop_freq_bins_bottom: Ge0[int] = 16
crop_freq_bins_top class-attribute instance-attribute
crop_freq_bins_top: Ge0[int] = 16
n_chan_input class-attribute instance-attribute
n_chan_input: Gt0[int] = 1
n_chan_layers class-attribute instance-attribute
n_chan_layers: tuple[Gt0[int], ...] = (40, 30, 30, 10, 3)
n_prefilt_layers class-attribute instance-attribute
n_prefilt_layers: Gt0[int] = 3
prefilt_kernel_size class-attribute instance-attribute
prefilt_kernel_size: Gt0[int] = 39
residual class-attribute instance-attribute
residual: bool = True
n_bins_in class-attribute instance-attribute
n_bins_in: Gt0[int] = 219
output_dim class-attribute instance-attribute
output_dim: Gt0[int] = 384
activation_fn class-attribute instance-attribute
activation_fn: Literal['relu', 'silu', 'leaky'] = 'leaky'
a_lrelu class-attribute instance-attribute
a_lrelu: Ge0[float] = 0.3
p_dropout class-attribute instance-attribute
p_dropout: Dropout = 0.2
bins_per_semitone class-attribute instance-attribute
bins_per_semitone: Gt0[int] = 3
input_channels property
input_channels: ModelInputChannels
input_type property
input_type: ModelInputType
output_type property
output_type: ModelOutputType
inference_archetype property
inference_archetype: InferenceArchetype

ToeplitzLinear

ToeplitzLinear(in_features: int, out_features: int)

Bases: Conv1d

Methods:

Name Description
forward
Source code in src/splifft/models/pesto.py
def __init__(self, in_features: int, out_features: int):
    super().__init__(
        in_channels=1,
        out_channels=1,
        kernel_size=in_features + out_features - 1,
        padding=out_features - 1,
        bias=False,
    )
forward
forward(input: Tensor) -> Tensor
Source code in src/splifft/models/pesto.py
def forward(self, input: torch.Tensor) -> torch.Tensor:
    return super().forward(input.unsqueeze(-2)).squeeze(-2)
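
With `kernel_size = in_features + out_features - 1` and `padding = out_features - 1`, the convolution produces exactly `out_features` outputs, and each output is a dot product of the input with a shifted slice of one shared kernel. That is the same as multiplying by a Toeplitz matrix (constant along each diagonal). The following is a minimal pure-Python sketch of that equivalence. The helper names `toeplitz_linear` and `toeplitz_matrix` are hypothetical, and the sketch assumes the cross-correlation convention that `Conv1d` uses.

```python
def toeplitz_linear(x: list[float], kernel: list[float], out_features: int) -> list[float]:
    """Cross-correlate x with one shared kernel, zero-padding out_features - 1 per side."""
    in_features = len(x)
    assert len(kernel) == in_features + out_features - 1
    pad = out_features - 1
    padded = [0.0] * pad + list(x) + [0.0] * pad
    # Output length: in_features + 2 * pad - len(kernel) + 1 == out_features.
    return [
        sum(kernel[k] * padded[j + k] for k in range(len(kernel)))
        for j in range(out_features)
    ]


def toeplitz_matrix(kernel: list[float], in_features: int, out_features: int) -> list[list[float]]:
    """Dense (out_features x in_features) matrix equivalent to toeplitz_linear."""
    pad = out_features - 1
    # Entry (j, i) depends only on i - j, so the matrix is constant along diagonals.
    return [
        [kernel[i + pad - j] for i in range(in_features)]
        for j in range(out_features)
    ]


x = [1.0, 2.0, 3.0]
kernel = [1.0, 2.0, 3.0, 4.0]  # length 3 + 2 - 1
y_conv = toeplitz_linear(x, kernel, out_features=2)
weight = toeplitz_matrix(kernel, in_features=3, out_features=2)
y_dense = [sum(w * v for w, v in zip(row, x)) for row in weight]
print(y_conv, y_dense)
```

The practical consequence is parameter sharing: the layer stores `in_features + out_features - 1` weights instead of `in_features * out_features`, and shifting the input along the bin axis shifts the output accordingly.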

Resnet1d

Resnet1d(
    *,
    n_chan_input: int = 1,
    n_chan_layers: tuple[int, ...] = (40, 30, 30, 10, 3),
    n_prefilt_layers: int = 3,
    prefilt_kernel_size: int = 39,
    residual: bool = True,
    n_bins_in: int = 219,
    output_dim: int = 384,
    activation_fn: Literal[
        "relu", "silu", "leaky"
    ] = "leaky",
    a_lrelu: float = 0.3,
    p_dropout: float = 0.2,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
layernorm
conv1
n_prefilt_layers
prefilt_layers
residual
conv_layers
flatten
fc
final_norm
Source code in src/splifft/models/pesto.py
def __init__(
    self,
    *,
    n_chan_input: int = 1,
    n_chan_layers: tuple[int, ...] = (40, 30, 30, 10, 3),
    n_prefilt_layers: int = 3,
    prefilt_kernel_size: int = 39,
    residual: bool = True,
    n_bins_in: int = 219,
    output_dim: int = 384,
    activation_fn: Literal["relu", "silu", "leaky"] = "leaky",
    a_lrelu: float = 0.3,
    p_dropout: float = 0.2,
):
    super().__init__()

    activation_layer: Callable[[], nn.Module]
    if activation_fn == "relu":
        activation_layer = nn.ReLU
    elif activation_fn == "silu":
        activation_layer = nn.SiLU
    elif activation_fn == "leaky":
        activation_layer = partial(nn.LeakyReLU, negative_slope=a_lrelu)
    else:
        assert_never(activation_fn)

    n_ch = list(n_chan_layers)
    if len(n_ch) < 5:
        n_ch.append(1)

    self.layernorm = nn.LayerNorm(normalized_shape=[n_chan_input, n_bins_in])

    prefilt_padding = prefilt_kernel_size // 2
    self.conv1 = nn.Sequential(
        nn.Conv1d(
            in_channels=n_chan_input,
            out_channels=n_ch[0],
            kernel_size=prefilt_kernel_size,
            padding=prefilt_padding,
            stride=1,
        ),
        activation_layer(),
        nn.Dropout(p=p_dropout),
    )
    self.n_prefilt_layers = n_prefilt_layers
    self.prefilt_layers = nn.ModuleList(
        [
            nn.Sequential(
                nn.Conv1d(
                    in_channels=n_ch[0],
                    out_channels=n_ch[0],
                    kernel_size=prefilt_kernel_size,
                    padding=prefilt_padding,
                    stride=1,
                ),
                activation_layer(),
                nn.Dropout(p=p_dropout),
            )
            for _ in range(n_prefilt_layers - 1)
        ]
    )
    self.residual = residual

    conv_layers: list[nn.Module] = []
    for i in range(len(n_ch) - 1):
        conv_layers.extend(
            [
                nn.Conv1d(
                    in_channels=n_ch[i],
                    out_channels=n_ch[i + 1],
                    kernel_size=1,
                    padding=0,
                    stride=1,
                ),
                activation_layer(),
                nn.Dropout(p=p_dropout),
            ]
        )
    self.conv_layers = nn.Sequential(*conv_layers)

    self.flatten = nn.Flatten(start_dim=1)
    self.fc = ToeplitzLinear(n_bins_in * n_ch[-1], output_dim)
    self.final_norm = nn.Softmax(dim=-1)
layernorm instance-attribute
layernorm = LayerNorm(
    normalized_shape=[n_chan_input, n_bins_in]
)
conv1 instance-attribute
conv1 = Sequential(
    Conv1d(
        in_channels=n_chan_input,
        out_channels=n_ch[0],
        kernel_size=prefilt_kernel_size,
        padding=prefilt_padding,
        stride=1,
    ),
    activation_layer(),
    Dropout(p=p_dropout),
)
n_prefilt_layers instance-attribute
n_prefilt_layers = n_prefilt_layers
prefilt_layers instance-attribute
prefilt_layers = ModuleList(
    [
        (
            Sequential(
                Conv1d(
                    in_channels=n_ch[0],
                    out_channels=n_ch[0],
                    kernel_size=prefilt_kernel_size,
                    padding=prefilt_padding,
                    stride=1,
                ),
                activation_layer(),
                Dropout(p=p_dropout),
            )
        )
        for _ in (range(n_prefilt_layers - 1))
    ]
)
residual instance-attribute
residual = residual
conv_layers instance-attribute
conv_layers = Sequential(*conv_layers)
flatten instance-attribute
flatten = Flatten(start_dim=1)
fc instance-attribute
fc = ToeplitzLinear(n_bins_in * n_ch[-1], output_dim)
final_norm instance-attribute
final_norm = Softmax(dim=-1)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/pesto.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    x = self.layernorm(x)

    x = self.conv1(x)
    for i in range(0, self.n_prefilt_layers - 1):
        prefilt_layer = self.prefilt_layers[i]
        if self.residual:
            x = prefilt_layer(x) + x
        else:
            x = prefilt_layer(x)

    x = self.conv_layers(x)
    x = self.flatten(x)
    y_pred = self.fc(x)
    return cast(torch.Tensor, self.final_norm(y_pred))

ConfidenceClassifier

ConfidenceClassifier()

Bases: Module

Frame-level voiced/unvoiced confidence head.

Methods:

Name Description
forward

Attributes:

Name Type Description
conv
linear
Source code in src/splifft/models/pesto.py
def __init__(self) -> None:
    super().__init__()
    self.conv = nn.Conv1d(1, 1, 39, stride=3)
    self.linear = nn.Linear(72, 1)
conv instance-attribute
conv = Conv1d(1, 1, 39, stride=3)
linear instance-attribute
linear = Linear(72, 1)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/pesto.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    geometric_mean = x.log().mean(dim=-1, keepdim=True).exp()
    arithmetic_mean = x.mean(dim=-1, keepdim=True).clamp_(min=1e-8)
    flatness = geometric_mean / arithmetic_mean

    x = F.relu(self.conv(x.unsqueeze(1)).squeeze(1))
    return torch.sigmoid(self.linear(torch.cat((x, flatness), dim=-1))).squeeze(-1)
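
The confidence head combines two signals: a strided convolution over the energy bins and a spectral flatness term (geometric mean divided by arithmetic mean). The sketch below reproduces the flatness computation in plain Python and checks where the `72` input features of `linear` could come from. It assumes the head receives the pre-crop bins with the default `PestoParams` values (`219 + 16 + 16`); the helper names are hypothetical.

```python
import math


def spectral_flatness(energy: list[float], floor: float = 1e-8) -> float:
    """Geometric mean over arithmetic mean; 1.0 for a flat spectrum, near 0 for a peaky one."""
    geometric_mean = math.exp(sum(math.log(e) for e in energy) / len(energy))
    arithmetic_mean = max(sum(energy) / len(energy), floor)
    return geometric_mean / arithmetic_mean


def conv1d_out_len(n_in: int, kernel_size: int, stride: int) -> int:
    """Output length of an unpadded 1D convolution."""
    return (n_in - kernel_size) // stride + 1


# Assumption: default n_bins_in plus the bottom and top crop margins.
total_bins = 219 + 16 + 16
conv_features = conv1d_out_len(total_bins, kernel_size=39, stride=3)
# The flatness scalar is concatenated to the convolution output.
linear_in_features = conv_features + 1

flat = spectral_flatness([2.0] * 8)
peaky = spectral_flatness([1e-6] * 7 + [1.0])
print(conv_features, linear_in_features, flat, peaky)
```

Note that `x.log()` in the original requires strictly positive energies. That holds in `Pesto.forward` because the energy is obtained by exponentiating dB values.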

reduce_activations

reduce_activations(
    activations: Tensor, reduction: PestoReduction = "alwa"
) -> Tensor

Reduce per-bin probabilities to scalar pitch per frame.

Source code in src/splifft/models/pesto.py
def reduce_activations(
    activations: torch.Tensor,
    reduction: PestoReduction = "alwa",
) -> torch.Tensor:
    """Reduce per-bin probabilities to scalar pitch per frame."""
    device = activations.device
    num_bins = int(activations.size(-1))

    bps, rem = divmod(num_bins, 128)
    if rem != 0:
        raise ValueError(f"expected output_dim to be divisible by 128, got {num_bins}")

    if reduction == "argmax":
        pred = activations.argmax(dim=-1)
        return pred.float() / bps

    all_pitches = torch.arange(num_bins, dtype=torch.float, device=device).div_(bps)
    if reduction == "mean":
        return torch.matmul(activations, all_pitches)

    if reduction == "alwa":
        center_bin = activations.argmax(dim=-1, keepdim=True)
        window = torch.arange(1, 2 * bps, device=device) - bps
        indices = (center_bin + window).clamp_(min=0, max=num_bins - 1)
        cropped_activations = activations.gather(-1, indices)
        cropped_pitches = all_pitches.unsqueeze(0).expand_as(activations).gather(-1, indices)
        return (cropped_activations * cropped_pitches).sum(dim=-1) / cropped_activations.sum(dim=-1)

    assert_never(reduction)
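
To make the three reductions concrete, here is a single-frame sketch in plain Python that mirrors the logic above without tensors. The function name `reduce_frame` is hypothetical, and it operates on one list of per-bin probabilities rather than a batch.

```python
def reduce_frame(activations: list[float], reduction: str = "alwa") -> float:
    """Reduce one frame of per-bin probabilities to a pitch in semitones."""
    num_bins = len(activations)
    bps, rem = divmod(num_bins, 128)  # bins per semitone
    if rem != 0:
        raise ValueError(f"expected a multiple of 128 bins, got {num_bins}")

    center = max(range(num_bins), key=activations.__getitem__)
    if reduction == "argmax":
        return center / bps

    pitches = [i / bps for i in range(num_bins)]
    if reduction == "mean":
        # Expected pitch under the full distribution.
        return sum(a * p for a, p in zip(activations, pitches))

    if reduction == "alwa":
        # Weighted average restricted to a window of 2 * bps - 1 bins
        # around the argmax, with indices clamped to the valid range.
        indices = [min(max(center + off, 0), num_bins - 1) for off in range(1 - bps, bps)]
        weights = [activations[i] for i in indices]
        return sum(w * pitches[i] for w, i in zip(weights, indices)) / sum(weights)

    raise ValueError(f"unknown reduction: {reduction}")


# 384 bins means 3 bins per semitone, so bin 207 corresponds to pitch 69.
one_hot = [0.0] * 384
one_hot[207] = 1.0

skewed = [0.0] * 384
skewed[207], skewed[208] = 0.75, 0.25
print(reduce_frame(one_hot), reduce_frame(skewed, "argmax"), reduce_frame(skewed, "alwa"))
```

The windowed average gives sub-bin resolution like `"mean"` does, but it ignores probability mass far from the peak, so a secondary peak elsewhere in the spectrum does not drag the estimate.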

Pesto

Pesto(cfg: PestoParams)

Bases: Module

PESTO inference head over externally computed HCQT features.

Input contract: tensor of shape (batch, time, feature_dim) where feature_dim = harmonics * freq_bins in dB log-magnitude HCQT.

Methods:

Name Description
forward

Attributes:

Name Type Description
cfg
encoder
confidence
Source code in src/splifft/models/pesto.py
def __init__(self, cfg: PestoParams):
    super().__init__()
    self.cfg = cfg
    self.encoder = Resnet1d(
        n_chan_input=cfg.n_chan_input,
        n_chan_layers=cfg.n_chan_layers,
        n_prefilt_layers=cfg.n_prefilt_layers,
        prefilt_kernel_size=cfg.prefilt_kernel_size,
        residual=cfg.residual,
        n_bins_in=cfg.n_bins_in,
        output_dim=cfg.output_dim,
        activation_fn=cfg.activation_fn,
        a_lrelu=cfg.a_lrelu,
        p_dropout=cfg.p_dropout,
    )
    self.confidence = ConfidenceClassifier()

    self.register_buffer("shift", torch.zeros((), dtype=torch.float), persistent=True)
cfg instance-attribute
cfg = cfg
encoder instance-attribute
encoder = Resnet1d(
    n_chan_input=n_chan_input,
    n_chan_layers=n_chan_layers,
    n_prefilt_layers=n_prefilt_layers,
    prefilt_kernel_size=prefilt_kernel_size,
    residual=residual,
    n_bins_in=n_bins_in,
    output_dim=output_dim,
    activation_fn=activation_fn,
    a_lrelu=a_lrelu,
    p_dropout=p_dropout,
)
confidence instance-attribute
confidence = ConfidenceClassifier()
forward
forward(x: Tensor) -> tuple[Tensor, ...]
Source code in src/splifft/models/pesto.py
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]:
    if x.ndim != 3:
        raise ValueError(
            f"expected `(batch,time,feature_dim)` input, got shape={tuple(x.shape)}"
        )

    total_bins = (
        self.cfg.n_bins_in + self.cfg.crop_freq_bins_bottom + self.cfg.crop_freq_bins_top
    )
    expected_feature_dim = self.cfg.n_chan_input * total_bins
    if x.shape[-1] != expected_feature_dim:
        raise ValueError(
            "invalid PESTO feature dimension: "
            f"expected {expected_feature_dim} (= n_chan_input * (n_bins_in + crop_bottom + crop_top)), "
            f"got {x.shape[-1]}"
        )

    batch_size, num_frames, _feature_dim = x.shape
    x = x.view(batch_size, num_frames, self.cfg.n_chan_input, total_bins)
    x = x.flatten(0, 1)  # (B*T, H, F)

    # match reference implementation: convert dB back to linear energy and derive
    # confidence + volume from pre-crop HCQT bins.
    energy = x.mul(torch.log(torch.tensor(10.0, device=x.device, dtype=x.dtype)) / 10.0).exp()
    confidence_energy = energy.mean(dim=1)
    volume = energy.sum(dim=-1).mean(dim=-1)
    confidence = self.confidence(confidence_energy)

    x = self._crop_cqt(x)
    activations = self.encoder(x)

    activations = activations.view(batch_size, num_frames, activations.size(-1))
    confidence = confidence.view(batch_size, num_frames)
    volume = volume.view(batch_size, num_frames)

    shift_tensor = cast(torch.Tensor, self.shift)
    shift_bins = int(torch.round(shift_tensor * self.cfg.bins_per_semitone).item())
    activations = activations.roll(-shift_bins, dims=-1)

    pitch = reduce_activations(activations, reduction=self.cfg.reduction)
    if self.cfg.convert_to_freq:
        pitch = 440 * 2 ** ((pitch - 69) / 12)

    return pitch, confidence, volume, activations
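
Two scalar conversions in `forward` are worth spelling out: the dB-to-linear step `exp(x * ln(10) / 10)`, which equals `10 ** (x / 10)`, and the pitch-to-frequency step, which is the standard MIDI mapping with A4 (pitch 69) at 440 Hz. The sketch below checks both in plain Python, along with the expected feature dimension for the default `PestoParams` values. The helper names are hypothetical.

```python
import math


def db_to_energy(db: float) -> float:
    """Invert 10 * log10(energy), written with exp as in the forward pass."""
    return math.exp(db * math.log(10.0) / 10.0)


def midi_to_hz(pitch: float) -> float:
    """Convert a (possibly fractional) MIDI pitch to frequency in Hz."""
    return 440 * 2 ** ((pitch - 69) / 12)


# Expected last dimension of the input for the default parameters:
# n_chan_input * (n_bins_in + crop_freq_bins_bottom + crop_freq_bins_top).
expected_feature_dim = 1 * (219 + 16 + 16)

print(db_to_energy(10.0), midi_to_hz(69), midi_to_hz(81), expected_feature_dim)
```

When `convert_to_freq` is `False`, the first element of the returned tuple stays in semitone units, so downstream code should check the flag before treating it as Hz.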

bs_roformer

Band-Split RoPE Transformer

This implementation merges the two versions found in lucidrains' implementation. However, the two versions are inconsistent in several ways:

  • The MLP is defined differently in each file: one has depth - 1 hidden layers and the other has depth hidden layers.
  • BSRoformer applies one final RMSNorm after the entire stack of transformer layers, while the MelBandRoformer applies an RMSNorm at the end of each axial transformer block (time_transformer, freq_transformer, etc.) and has no final normalization layer.

Since fixing these inconsistencies upstream would be too big a breaking change, we inherit them to maintain compatibility with community-trained models. See: https://web.archive.org/web/20260112010548/https://github.com/lucidrains/BS-RoFormer/issues/48.

To avoid dependency bloat, we do not:

Classes:

Name Description
FixedBandsConfig
MelBandsConfig
BaselineMaskEstimatorConfig
AxialRefinerLargeV2MaskEstimatorConfig

unwa large-v2 head. Adds a small axial transformer refiner inside the mask head.

HyperAceResidualV1MaskEstimatorConfig

unwa HyperACE v1 residual head compatibility config.

HyperAceResidualV2MaskEstimatorConfig

unwa HyperACE v2 residual head compatibility config.

BSRoformerParams
RMSNorm
RMSNormWithEps
FeedForward
Attention
LinearAttention

The flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.

Transformer
BandSplit
MaskEstimator
AxialRefinerLargeV2MaskEstimator
HyperAceResidualMaskEstimator
BSRoformer

Functions:

Name Description
l2norm
rms_norm
mlp

Attributes:

Name Type Description
DEFAULT_FREQS_PER_BANDS
MaskEstimatorConfig

DEFAULT_FREQS_PER_BANDS module-attribute

DEFAULT_FREQS_PER_BANDS = (
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    4,
    4,
    4,
    4,
    4,
    4,
    4,
    4,
    4,
    4,
    4,
    4,
    12,
    12,
    12,
    12,
    12,
    12,
    12,
    12,
    24,
    24,
    24,
    24,
    24,
    24,
    24,
    24,
    48,
    48,
    48,
    48,
    48,
    48,
    48,
    48,
    128,
    129,
)
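
The band widths must partition the STFT frequency axis exactly. As a sanity check, the sketch below rebuilds the tuple in compact form and confirms that it covers 1025 bins, which is `n_fft // 2 + 1` for `n_fft = 2048`. That FFT size is an assumption here (it matches the `stft_n_fft` default of `MelBandsConfig`), and `band_edges` is a hypothetical helper.

```python
# Compact reconstruction of the tuple listed above.
DEFAULT_FREQS_PER_BANDS = (
    (2,) * 24 + (4,) * 12 + (12,) * 8 + (24,) * 8 + (48,) * 8 + (128, 129)
)


def band_edges(freqs_per_bands: tuple[int, ...]) -> list[tuple[int, int]]:
    """Return half-open [start, stop) bin ranges for consecutive bands."""
    edges, start = [], 0
    for width in freqs_per_bands:
        edges.append((start, start + width))
        start += width
    return edges


num_bands = len(DEFAULT_FREQS_PER_BANDS)
num_freq_bins = sum(DEFAULT_FREQS_PER_BANDS)
edges = band_edges(DEFAULT_FREQS_PER_BANDS)

n_fft = 2048  # assumed FFT size
assert num_freq_bins == n_fft // 2 + 1
print(num_bands, num_freq_bins, edges[0], edges[-1])
```

Bands are narrow at low frequencies and wide at high frequencies, so a custom `freqs_per_bands` for a different FFT size needs to keep the same invariant: the widths must sum to the number of STFT bins.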

FixedBandsConfig dataclass

FixedBandsConfig(
    kind: Literal["fixed"],
    freqs_per_bands: tuple[Gt0[int], ...] = (
        lambda: DEFAULT_FREQS_PER_BANDS
    )(),
)

Attributes:

Name Type Description
kind Literal['fixed']
freqs_per_bands tuple[Gt0[int], ...]
kind instance-attribute
kind: Literal['fixed']
freqs_per_bands class-attribute instance-attribute
freqs_per_bands: tuple[Gt0[int], ...] = field(
    default_factory=lambda: DEFAULT_FREQS_PER_BANDS
)

MelBandsConfig dataclass

MelBandsConfig(
    kind: Literal["mel"],
    stft_n_fft: Gt0[int] = 2048,
    num_bands: Gt0[int] = 60,
    sample_rate: Gt0[int] = 44100,
)

Attributes:

Name Type Description
kind Literal['mel']
stft_n_fft Gt0[int]
num_bands Gt0[int]
sample_rate Gt0[int]
kind instance-attribute
kind: Literal['mel']
stft_n_fft class-attribute instance-attribute
stft_n_fft: Gt0[int] = 2048
num_bands class-attribute instance-attribute
num_bands: Gt0[int] = 60
sample_rate class-attribute instance-attribute
sample_rate: Gt0[int] = 44100

BaselineMaskEstimatorConfig dataclass

BaselineMaskEstimatorConfig(
    kind: Literal["baseline"] = "baseline",
)

Attributes:

Name Type Description
kind Literal['baseline']
kind class-attribute instance-attribute
kind: Literal['baseline'] = 'baseline'

AxialRefinerLargeV2MaskEstimatorConfig dataclass

AxialRefinerLargeV2MaskEstimatorConfig(
    kind: Literal["axial_refiner_large_v2"],
    axial_refiner_depth: Gt0[int] = 4,
)

unwa large-v2 head. Adds a small axial transformer refiner inside the mask head.

Attributes:

Name Type Description
kind Literal['axial_refiner_large_v2']
axial_refiner_depth Gt0[int]
kind instance-attribute
kind: Literal['axial_refiner_large_v2']
axial_refiner_depth class-attribute instance-attribute
axial_refiner_depth: Gt0[int] = 4

HyperAceResidualV1MaskEstimatorConfig dataclass

HyperAceResidualV1MaskEstimatorConfig(
    kind: Literal["hyperace_residual_v1"],
    num_hyperedges: Gt0[int] | None = None,
    num_heads: Gt0[int] = 8,
)

unwa HyperACE v1 residual head compatibility config.

Attributes:

Name Type Description
kind Literal['hyperace_residual_v1']
num_hyperedges Gt0[int] | None
num_heads Gt0[int]
kind instance-attribute
kind: Literal['hyperace_residual_v1']
num_hyperedges class-attribute instance-attribute
num_hyperedges: Gt0[int] | None = None
num_heads class-attribute instance-attribute
num_heads: Gt0[int] = 8

HyperAceResidualV2MaskEstimatorConfig dataclass

HyperAceResidualV2MaskEstimatorConfig(
    kind: Literal["hyperace_residual_v2"],
    num_hyperedges: Gt0[int] | None = None,
    num_heads: Gt0[int] = 8,
)

unwa HyperACE v2 residual head compatibility config.

Attributes:

Name Type Description
kind Literal['hyperace_residual_v2']
num_hyperedges Gt0[int] | None
num_heads Gt0[int]
kind instance-attribute
kind: Literal['hyperace_residual_v2']
num_hyperedges class-attribute instance-attribute
num_hyperedges: Gt0[int] | None = None
num_heads class-attribute instance-attribute
num_heads: Gt0[int] = 8

BSRoformerParams dataclass

BSRoformerParams(
    chunk_size: ChunkSize,
    output_stem_names: tuple[ModelOutputStemName, ...],
    dim: Gt0[int],
    depth: Gt0[int],
    band_config: FixedBandsConfig | MelBandsConfig,
    stft_hop_length: HopSize = 512,
    stereo: bool = True,
    time_transformer_depth: Gt0[int] = 1,
    freq_transformer_depth: Gt0[int] = 1,
    linear_transformer_depth: Ge0[int] = 0,
    dim_head: int = 64,
    heads: Gt0[int] = 8,
    attn_dropout: Dropout = 0.0,
    ff_dropout: Dropout = 0.0,
    ff_mult: Gt0[int] = 4,
    attention_backend: AttentionBackend = "sdpa",
    norm_output: bool = False,
    mask_estimator_depth: Gt0[int] = 2,
    mlp_expansion_factor: Gt0[int] = 4,
    mask_estimator: MaskEstimatorConfig = BaselineMaskEstimatorConfig(),
    use_torch_checkpoint: bool = False,
    use_shared_bias: bool = False,
    skip_connection: bool = False,
    rms_norm_eps: Ge0[float] | None = None,
    rotary_embed_dtype: TorchDtype | None = None,
    transformer_residual_dtype: TorchDtype | None = None,
    debug: bool = False,
)

Bases: ModelParamsLike

Attributes:

Name Type Description
chunk_size ChunkSize
output_stem_names tuple[ModelOutputStemName, ...]
dim Gt0[int]
depth Gt0[int]
band_config FixedBandsConfig | MelBandsConfig
stft_hop_length HopSize
stereo bool
time_transformer_depth Gt0[int]
freq_transformer_depth Gt0[int]
linear_transformer_depth Ge0[int]
dim_head int
heads Gt0[int]
attn_dropout Dropout
ff_dropout Dropout
ff_mult Gt0[int]
attention_backend AttentionBackend
norm_output bool

Note that in lucidrains' implementation, this is set to False for bs_roformer but True for mel_roformer.

mask_estimator_depth Gt0[int]

The number of hidden layers of the MLP is mask_estimator_depth - 1.

mlp_expansion_factor Gt0[int]
mask_estimator MaskEstimatorConfig
use_torch_checkpoint bool
use_shared_bias bool
skip_connection bool
rms_norm_eps Ge0[float] | None
rotary_embed_dtype TorchDtype | None
transformer_residual_dtype TorchDtype | None
debug bool

Whether to check for nan/inf in model outputs. Keep it off for torch.compile.

input_channels ModelInputChannels
input_type ModelInputType
output_type ModelOutputType
inference_archetype InferenceArchetype
chunk_size instance-attribute
chunk_size: ChunkSize
output_stem_names instance-attribute
output_stem_names: tuple[ModelOutputStemName, ...]
dim instance-attribute
dim: Gt0[int]
depth instance-attribute
depth: Gt0[int]
band_config instance-attribute
band_config: FixedBandsConfig | MelBandsConfig
stft_hop_length class-attribute instance-attribute
stft_hop_length: HopSize = 512
stereo class-attribute instance-attribute
stereo: bool = True
time_transformer_depth class-attribute instance-attribute
time_transformer_depth: Gt0[int] = 1
freq_transformer_depth class-attribute instance-attribute
freq_transformer_depth: Gt0[int] = 1
linear_transformer_depth class-attribute instance-attribute
linear_transformer_depth: Ge0[int] = 0
dim_head class-attribute instance-attribute
dim_head: int = 64
heads class-attribute instance-attribute
heads: Gt0[int] = 8
attn_dropout class-attribute instance-attribute
attn_dropout: Dropout = 0.0
ff_dropout class-attribute instance-attribute
ff_dropout: Dropout = 0.0
ff_mult class-attribute instance-attribute
ff_mult: Gt0[int] = 4
attention_backend class-attribute instance-attribute
attention_backend: AttentionBackend = 'sdpa'
norm_output class-attribute instance-attribute
norm_output: bool = False

Note that in lucidrains' implementation, this is set to False for bs_roformer but True for mel_roformer.

mask_estimator_depth class-attribute instance-attribute
mask_estimator_depth: Gt0[int] = 2

The number of hidden layers of the MLP is mask_estimator_depth - 1, that is:

  • depth = 1: (dim_in, dim_out)
  • depth = 2: (dim_in, dim_hidden, dim_out)

Note that in lucidrains' implementation of mel-band roformers, the number of hidden layers is instead mask_estimator_depth, one more than the convention used here. This affects popular models like kim-vocals and all models trained with zfturbo's music-source-separation training.

If you are migrating a mel-band roformer's zfturbo configuration, increment mask_estimator_depth by 1.
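
The depth convention and the migration rule can be sketched as follows. This is a minimal illustration of the layer widths implied by the description above, not the library's `mlp` function; `mlp_layer_dims` and `migrate_mel_band_depth` are hypothetical names.

```python
def mlp_layer_dims(dim_in: int, dim_out: int, dim_hidden: int, depth: int) -> tuple[int, ...]:
    """Layer widths for the convention used here: depth - 1 hidden layers."""
    return (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)


def migrate_mel_band_depth(upstream_depth: int) -> int:
    """Translate a mel-band depth (hidden layers == depth) to this convention."""
    return upstream_depth + 1


shallow = mlp_layer_dims(dim_in=8, dim_out=4, dim_hidden=32, depth=1)
default = mlp_layer_dims(dim_in=8, dim_out=4, dim_hidden=32, depth=2)

# An upstream mel-band config with depth 2 has two hidden layers,
# so it maps to mask_estimator_depth = 3 here.
migrated = mlp_layer_dims(8, 4, 32, depth=migrate_mel_band_depth(2))
print(shallow, default, migrated)
```

Getting this wrong typically surfaces as a state dict mismatch at load time, because the number of linear layers in the mask estimator differs from the checkpoint.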

mlp_expansion_factor class-attribute instance-attribute
mlp_expansion_factor: Gt0[int] = 4
mask_estimator class-attribute instance-attribute
mask_estimator: MaskEstimatorConfig = field(
    default_factory=BaselineMaskEstimatorConfig
)
use_torch_checkpoint class-attribute instance-attribute
use_torch_checkpoint: bool = False
use_shared_bias class-attribute instance-attribute
use_shared_bias: bool = False
skip_connection class-attribute instance-attribute
skip_connection: bool = False
rms_norm_eps class-attribute instance-attribute
rms_norm_eps: Ge0[float] | None = None
rotary_embed_dtype class-attribute instance-attribute
rotary_embed_dtype: TorchDtype | None = None
transformer_residual_dtype class-attribute instance-attribute
transformer_residual_dtype: TorchDtype | None = None
debug class-attribute instance-attribute
debug: bool = False

Whether to check for nan/inf in model outputs. Keep it off for torch.compile.

input_channels property
input_channels: ModelInputChannels
input_type property
input_type: ModelInputType
output_type property
output_type: ModelOutputType
inference_archetype property
inference_archetype: InferenceArchetype

l2norm

l2norm(t: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def l2norm(t: Tensor) -> Tensor:
    return F.normalize(t, dim=-1, p=2)

RMSNorm

RMSNorm(dim: int)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
scale
gamma
Source code in src/splifft/models/bs_roformer.py
def __init__(self, dim: int):
    super().__init__()
    self.scale = dim**0.5
    self.gamma = nn.Parameter(torch.ones(dim))
scale instance-attribute
scale = dim ** 0.5
gamma instance-attribute
gamma = Parameter(ones(dim))
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    return F.normalize(x, dim=-1) * self.scale * self.gamma  # type: ignore

RMSNormWithEps

RMSNormWithEps(
    dim: int, eps: float = 5.960464477539063e-08
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
scale
gamma
eps
Source code in src/splifft/models/bs_roformer.py
def __init__(self, dim: int, eps: float = 5.960464477539063e-08):
    super().__init__()
    self.scale = dim**0.5
    self.gamma = nn.Parameter(torch.ones(dim))
    self.eps = eps
scale instance-attribute
scale = dim ** 0.5
gamma instance-attribute
gamma = Parameter(ones(dim))
eps instance-attribute
eps = eps
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    l2_norm = torch.linalg.norm(x, dim=-1, keepdim=True)
    denom = torch.maximum(l2_norm, torch.full_like(l2_norm, self.eps))
    normalized_x = x / denom
    return normalized_x * self.scale * self.gamma  # type: ignore

rms_norm

rms_norm(
    dim: int, eps: float | None
) -> RMSNorm | RMSNormWithEps
Source code in src/splifft/models/bs_roformer.py
def rms_norm(dim: int, eps: float | None) -> RMSNorm | RMSNormWithEps:
    if eps is None:
        return RMSNorm(dim)
    return RMSNormWithEps(dim, eps)

FeedForward

FeedForward(
    dim: int,
    mult: int = 4,
    dropout: float = 0.0,
    rms_norm_eps: float | None = None,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
net
Source code in src/splifft/models/bs_roformer.py
def __init__(
    self, dim: int, mult: int = 4, dropout: float = 0.0, rms_norm_eps: float | None = None
):
    super().__init__()
    dim_inner = int(dim * mult)
    # NOTE: in the paper: RMSNorm -> FC -> Tanh -> FC -> GLU
    self.net = nn.Sequential(
        rms_norm(dim, eps=rms_norm_eps),
        nn.Linear(dim, dim_inner),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(dim_inner, dim),
        nn.Dropout(dropout),
    )
net instance-attribute
net = Sequential(
    rms_norm(dim, eps=rms_norm_eps),
    Linear(dim, dim_inner),
    GELU(),
    Dropout(dropout),
    Linear(dim_inner, dim),
    Dropout(dropout),
)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    return cast(Tensor, self.net(x))

Attention

Attention(
    dim: int,
    heads: int = 8,
    dim_head: int = 64,
    dropout: float = 0.0,
    shared_qkv_bias: Parameter | None = None,
    shared_out_bias: Parameter | None = None,
    rotary_embed: RotaryEmbedding | None = None,
    attention_backend: AttentionBackend = "sdpa",
    rms_norm_eps: float | None = None,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
heads
scale
rotary_embed
attend
norm
to_qkv
to_gates
to_out
Source code in src/splifft/models/bs_roformer.py
def __init__(
    self,
    dim: int,
    heads: int = 8,
    dim_head: int = 64,
    dropout: float = 0.0,
    shared_qkv_bias: nn.Parameter | None = None,
    shared_out_bias: nn.Parameter | None = None,
    rotary_embed: RotaryEmbedding | None = None,
    attention_backend: AttentionBackend = "sdpa",
    rms_norm_eps: float | None = None,
):
    super().__init__()
    self.heads = heads
    self.scale = dim_head**-0.5
    dim_inner = heads * dim_head

    self.rotary_embed = rotary_embed

    self.attend = Attend(backend=attention_backend, dropout=dropout)

    self.norm = rms_norm(dim, eps=rms_norm_eps)
    self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=(shared_qkv_bias is not None))
    if shared_qkv_bias is not None:
        self.to_qkv.bias = shared_qkv_bias

    self.to_gates = nn.Linear(dim, heads)

    self.to_out = nn.Sequential(
        nn.Linear(dim_inner, dim, bias=(shared_out_bias is not None)),
        nn.Dropout(dropout),
    )
    if shared_out_bias is not None:
        self.to_out[0].bias = shared_out_bias
heads instance-attribute
heads = heads
scale instance-attribute
scale = dim_head ** -0.5
rotary_embed instance-attribute
rotary_embed = rotary_embed
attend instance-attribute
attend = Attend(backend=attention_backend, dropout=dropout)
norm instance-attribute
norm = rms_norm(dim, eps=rms_norm_eps)
to_qkv instance-attribute
to_qkv = Linear(
    dim, dim_inner * 3, bias=shared_qkv_bias is not None
)
to_gates instance-attribute
to_gates = Linear(dim, heads)
to_out instance-attribute
to_out = Sequential(
    Linear(
        dim_inner, dim, bias=shared_out_bias is not None
    ),
    Dropout(dropout),
)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    x = self.norm(x)

    qkv = self.to_qkv(x)
    q, k, v = rearrange(qkv, "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads)

    if self.rotary_embed is not None:
        q = self.rotary_embed(q)
        k = self.rotary_embed(k)

    out = self.attend(q, k, v)

    gates = self.to_gates(x)
    gate_act = gates.sigmoid()

    out = out * rearrange(gate_act, "b n h -> b h n 1")

    out = rearrange(out, "b h n d -> b n (h d)")
    out = self.to_out(out)
    return cast(Tensor, out)
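
The pattern `"b n (qkv h d) -> qkv b h n d"` in `forward` splits the last axis of the fused projection into three nested axes, with `qkv` varying slowest and `d` fastest. The sketch below reproduces only that index arithmetic in plain Python, to show which slice of the `to_qkv` output ends up in which head. The function name `split_fused_qkv_index` is hypothetical.

```python
def split_fused_qkv_index(
    flat_index: int, heads: int, dim_head: int
) -> tuple[int, int, int]:
    """Map a position on the fused last axis to (qkv, head, channel).

    Mirrors the grouping "(qkv h d)": qkv is the outermost factor,
    d the innermost.
    """
    if not 0 <= flat_index < 3 * heads * dim_head:
        raise IndexError("index lies outside the fused q/k/v axis")
    qkv_index, rest = divmod(flat_index, heads * dim_head)
    head_index, channel_index = divmod(rest, dim_head)
    return qkv_index, head_index, channel_index


for position in (0, 512, 1221):
    print(position, split_fused_qkv_index(position, heads=8, dim_head=64))
```

With the defaults `heads=8` and `dim_head=64`, the first 512 features belong to the queries, the next 512 to the keys, and the last 512 to the values.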

LinearAttention

LinearAttention(
    *,
    dim: int,
    dim_head: int = 32,
    heads: int = 8,
    scale: int = 8,
    dropout: float = 0.0,
    attention_backend: AttentionBackend = "sdpa",
    rms_norm_eps: float | None = None,
)

Bases: Module

This flavor of linear attention was proposed by El-Nouby et al. in https://arxiv.org/abs/2106.09681.

Methods:

Name Description
forward

Attributes:

Name Type Description
norm
to_qkv
temperature
attend
to_out
Source code in src/splifft/models/bs_roformer.py
def __init__(
    self,
    *,
    dim: int,
    dim_head: int = 32,
    heads: int = 8,
    scale: int = 8,
    dropout: float = 0.0,
    attention_backend: AttentionBackend = "sdpa",
    rms_norm_eps: float | None = None,
):
    super().__init__()
    dim_inner = dim_head * heads
    self.norm = rms_norm(dim, eps=rms_norm_eps)

    self.to_qkv = nn.Sequential(
        nn.Linear(dim, dim_inner * 3, bias=False),
        Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads),
    )

    self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

    self.attend = Attend(scale=scale, dropout=dropout, backend=attention_backend)

    self.to_out = nn.Sequential(
        Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False)
    )
norm instance-attribute
norm = rms_norm(dim, eps=rms_norm_eps)
to_qkv instance-attribute
to_qkv = Sequential(
    Linear(dim, dim_inner * 3, bias=False),
    Rearrange(
        "b n (qkv h d) -> qkv b h d n", qkv=3, h=heads
    ),
)
temperature instance-attribute
temperature = Parameter(ones(heads, 1, 1))
attend instance-attribute
attend = Attend(
    scale=scale, dropout=dropout, backend=attention_backend
)
to_out instance-attribute
to_out = Sequential(
    Rearrange("b h d n -> b n (h d)"),
    Linear(dim_inner, dim, bias=False),
)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    x = self.norm(x)

    q, k, v = self.to_qkv(x)

    q, k = map(l2norm, (q, k))
    q = q * self.temperature.exp()

    out = self.attend(q, k, v)

    return cast(Tensor, self.to_out(out))
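
The query/key preprocessing in `forward` can be illustrated on a single vector. The sketch below assumes that `l2norm` (defined elsewhere in the module and not shown here) rescales its input to unit Euclidean length along the last axis; that behavior, and the helper names, are assumptions made for illustration only.

```python
import math


def l2_normalize(values: list[float], eps: float = 1e-12) -> list[float]:
    """Rescale a vector to unit Euclidean length (assumed l2norm behavior)."""
    norm = math.sqrt(sum(v * v for v in values))
    return [v / max(norm, eps) for v in values]


def scaled_query(values: list[float], temperature: float = 1.0) -> list[float]:
    """Normalize, then multiply by exp(temperature) as forward does for q."""
    factor = math.exp(temperature)
    return [v * factor for v in l2_normalize(values)]


print(l2_normalize([3.0, 4.0]))
print(scaled_query([3.0, 4.0]))
```

Because `temperature` is initialized with ones, each normalized query starts out scaled by `e` before training adjusts the per-head value.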

Transformer

Transformer(
    *,
    dim: int,
    depth: int,
    dim_head: int = 64,
    heads: int = 8,
    attn_dropout: float = 0.0,
    ff_dropout: float = 0.0,
    ff_mult: int = 4,
    norm_output: bool = True,
    rotary_embed: RotaryEmbedding | None = None,
    attention_backend: AttentionBackend = "sdpa",
    linear_attn: bool = False,
    shared_qkv_bias: Parameter | None = None,
    shared_out_bias: Parameter | None = None,
    rms_norm_eps: float | None = None,
    transformer_residual_dtype: dtype | None = None,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
layers
transformer_residual_dtype
norm
Source code in src/splifft/models/bs_roformer.py
def __init__(
    self,
    *,
    dim: int,
    depth: int,
    dim_head: int = 64,
    heads: int = 8,
    attn_dropout: float = 0.0,
    ff_dropout: float = 0.0,
    ff_mult: int = 4,
    norm_output: bool = True,
    rotary_embed: RotaryEmbedding | None = None,
    attention_backend: AttentionBackend = "sdpa",
    linear_attn: bool = False,
    shared_qkv_bias: nn.Parameter | None = None,
    shared_out_bias: nn.Parameter | None = None,
    rms_norm_eps: float | None = None,
    transformer_residual_dtype: torch.dtype | None = None,  # COMPAT: float32, see 265
):
    super().__init__()
    self.layers = ModuleList([])

    for _ in range(depth):
        attn: LinearAttention | Attention
        if linear_attn:
            attn = LinearAttention(
                dim=dim,
                dim_head=dim_head,
                heads=heads,
                dropout=attn_dropout,
                attention_backend=attention_backend,
                rms_norm_eps=rms_norm_eps,
            )
        else:
            attn = Attention(
                dim=dim,
                dim_head=dim_head,
                heads=heads,
                dropout=attn_dropout,
                shared_qkv_bias=shared_qkv_bias,
                shared_out_bias=shared_out_bias,
                rotary_embed=rotary_embed,
                attention_backend=attention_backend,
                rms_norm_eps=rms_norm_eps,
            )

        ff = FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout, rms_norm_eps=rms_norm_eps)
        self.layers.append(ModuleList([attn, ff]))
    self.transformer_residual_dtype = transformer_residual_dtype

    self.norm = rms_norm(dim, eps=rms_norm_eps) if norm_output else nn.Identity()
layers instance-attribute
layers = ModuleList([])
transformer_residual_dtype instance-attribute
transformer_residual_dtype = transformer_residual_dtype
norm instance-attribute
norm = (
    rms_norm(dim, eps=rms_norm_eps)
    if norm_output
    else Identity()
)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    for layer in self.layers:
        block = cast(ModuleList, layer)
        attn = block[0]
        ff = block[1]
        attn_out = attn(x)
        if self.transformer_residual_dtype is not None:
            x = (
                attn_out.to(self.transformer_residual_dtype)
                + x.to(self.transformer_residual_dtype)
            ).to(x.dtype)
        else:
            x = attn_out + x

        ff_out = ff(x)
        if self.transformer_residual_dtype is not None:
            x = (
                ff_out.to(self.transformer_residual_dtype)
                + x.to(self.transformer_residual_dtype)
            ).to(x.dtype)
        else:
            x = ff_out + x
    return cast(Tensor, self.norm(x))

BandSplit

BandSplit(
    dim: int,
    dim_inputs: tuple[int, ...],
    rms_norm_eps: float | None = None,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
dim_inputs
to_features
Source code in src/splifft/models/bs_roformer.py
def __init__(self, dim: int, dim_inputs: tuple[int, ...], rms_norm_eps: float | None = None):
    super().__init__()
    self.dim_inputs = dim_inputs
    self.to_features = ModuleList([])

    for dim_in in dim_inputs:
        net = nn.Sequential(rms_norm(dim_in, rms_norm_eps), nn.Linear(dim_in, dim))
        self.to_features.append(net)
dim_inputs instance-attribute
dim_inputs = dim_inputs
to_features instance-attribute
to_features = ModuleList([])
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    x_split = torch.split(x, list(self.dim_inputs), dim=-1)
    outs = []
    for split_input, to_feature_net in zip(x_split, self.to_features):
        split_output = to_feature_net(split_input)
        outs.append(split_output)
    return torch.stack(outs, dim=-2)
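
To make the shape flow of `forward` concrete, the sketch below mimics the split step on one feature row in plain Python and computes the resulting output shape. It does not use PyTorch, and both helper names are hypothetical.

```python
def split_by_sizes(row: list[float], sizes: tuple[int, ...]) -> list[list[float]]:
    """Cut one feature row into consecutive per-band chunks."""
    if sum(sizes) != len(row):
        raise ValueError("band sizes must add up to the feature width")
    chunks, start = [], 0
    for size in sizes:
        chunks.append(row[start : start + size])
        start += size
    return chunks


def band_split_output_shape(
    batch: int, frames: int, dim_inputs: tuple[int, ...], dim: int
) -> tuple[int, int, int, int]:
    """Each band is projected to `dim`; bands are stacked on a new axis."""
    return (batch, frames, len(dim_inputs), dim)


print(split_by_sizes([0, 1, 2, 3, 4, 5], (2, 1, 3)))
print(band_split_output_shape(1, 10, (8, 8, 16), 384))
```

The input of shape `(b, t, sum(dim_inputs))` therefore becomes `(b, t, num_bands, dim)`, which is the layout the time and frequency transformers operate on.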

mlp

mlp(
    dim_in: int,
    dim_out: int,
    dim_hidden: int | None = None,
    depth: int = 1,
    activation: type[Module] = Tanh,
) -> Sequential
Source code in src/splifft/models/bs_roformer.py
def mlp(
    dim_in: int,
    dim_out: int,
    dim_hidden: int | None = None,
    depth: int = 1,
    activation: type[Module] = nn.Tanh,
) -> nn.Sequential:
    dim_hidden_ = dim_hidden or dim_in

    net: list[Module] = []
    # NOTE: in lucidrains' impl, `bs_roformer` has `depth - 1` but `mel_roformer` has `depth`
    num_hidden_layers = depth - 1
    dims = (dim_in, *((dim_hidden_,) * num_hidden_layers), dim_out)

    for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
        is_last = ind == (len(dims) - 2)

        net.append(nn.Linear(layer_dim_in, layer_dim_out))

        if is_last:
            continue

        net.append(activation())

    return nn.Sequential(*net)

MaskEstimator

MaskEstimator(
    dim: int,
    dim_inputs: tuple[int, ...],
    depth: int,
    mlp_expansion_factor: int,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
dim_inputs
to_freqs
Source code in src/splifft/models/bs_roformer.py
def __init__(
    self,
    dim: int,
    dim_inputs: tuple[int, ...],
    depth: int,
    mlp_expansion_factor: int,
):
    super().__init__()
    self.dim_inputs = dim_inputs
    self.to_freqs = _build_band_to_freq_mlps(
        dim=dim,
        dim_inputs=dim_inputs,
        depth=depth,
        mlp_expansion_factor=mlp_expansion_factor,
    )
dim_inputs instance-attribute
dim_inputs = dim_inputs
to_freqs instance-attribute
to_freqs = _build_band_to_freq_mlps(
    dim=dim,
    dim_inputs=dim_inputs,
    depth=depth,
    mlp_expansion_factor=mlp_expansion_factor,
)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    x_unbound = x.unbind(dim=-2)

    outs = []

    for band_features, mlp_net in zip(x_unbound, self.to_freqs):
        freq_out = mlp_net(band_features)
        outs.append(freq_out)

    return torch.cat(outs, dim=-1)

AxialRefinerLargeV2MaskEstimator

AxialRefinerLargeV2MaskEstimator(
    dim: int,
    dim_inputs: tuple[int, ...],
    mlp_depth: int,
    mlp_expansion_factor: int,
    axial_refiner_depth: int,
    t_frames: int,
    num_bands: int,
    rotary_embed_dtype: dtype | None,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
dim_inputs
to_freqs
layers
norm
Source code in src/splifft/models/bs_roformer.py
def __init__(
    self,
    dim: int,
    dim_inputs: tuple[int, ...],
    mlp_depth: int,
    mlp_expansion_factor: int,
    axial_refiner_depth: int,
    t_frames: int,
    num_bands: int,
    rotary_embed_dtype: torch.dtype | None,
):
    super().__init__()
    self.dim_inputs = dim_inputs
    self.to_freqs = _build_band_to_freq_mlps(
        dim=dim,
        dim_inputs=dim_inputs,
        depth=mlp_depth,
        mlp_expansion_factor=mlp_expansion_factor,
    )

    self.layers = ModuleList([])

    heads = 8
    dim_head = 64

    time_rotary_embed = RotaryEmbedding(
        seq_len=t_frames,
        dim_head=dim_head,
        dtype=rotary_embed_dtype,
    )
    freq_rotary_embed = RotaryEmbedding(
        seq_len=num_bands,
        dim_head=dim_head,
        dtype=rotary_embed_dtype,
    )

    for _ in range(axial_refiner_depth):
        self.layers.append(
            nn.ModuleList(
                [
                    Transformer(
                        dim=dim,
                        depth=1,
                        heads=heads,
                        dim_head=dim_head,
                        attn_dropout=0.0,
                        ff_dropout=0.0,
                        attention_backend="sdpa",
                        norm_output=False,
                        rotary_embed=time_rotary_embed,
                    ),
                    Transformer(
                        dim=dim,
                        depth=1,
                        heads=heads,
                        dim_head=dim_head,
                        attn_dropout=0.0,
                        ff_dropout=0.0,
                        attention_backend="sdpa",
                        norm_output=False,
                        rotary_embed=freq_rotary_embed,
                    ),
                ]
            )
        )

    self.norm = RMSNorm(dim)
dim_inputs instance-attribute
dim_inputs = dim_inputs
to_freqs instance-attribute
to_freqs = _build_band_to_freq_mlps(
    dim=dim,
    dim_inputs=dim_inputs,
    depth=mlp_depth,
    mlp_expansion_factor=mlp_expansion_factor,
)
layers instance-attribute
layers = ModuleList([])
norm instance-attribute
norm = RMSNorm(dim)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    for transformer_block in self.layers:
        block = cast(ModuleList, transformer_block)
        time_transformer, freq_transformer = block

        x = rearrange(x, "b t f d -> b f t d")
        x, ps = pack([x], "* t d")

        x = time_transformer(x)

        (x,) = unpack(x, ps, "* t d")
        x = rearrange(x, "b f t d -> b t f d")
        x, ps = pack([x], "* f d")

        x = freq_transformer(x)

        (x,) = unpack(x, ps, "* f d")

    x = self.norm(x)

    x_unbound = x.unbind(dim=-2)

    outs = []

    for band_features, mlp_net in zip(x_unbound, self.to_freqs):
        freq_out = mlp_net(band_features)
        outs.append(freq_out)

    return torch.cat(outs, dim=-1)
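
The `rearrange`/`pack` pairs in the loop above fold one axis into the batch so that attention runs along the other. The sketch below only tracks the resulting shapes, in plain Python. The function name and the example sizes are hypothetical.

```python
def axial_attention_shapes(b: int, t: int, f: int, d: int) -> dict[str, tuple[int, int, int]]:
    """Shapes seen by the two transformers for an input of shape (b, t, f, d)."""
    # "b t f d -> b f t d", then packed as "* t d": bands join the batch axis
    time_pass = (b * f, t, d)
    # back to "b t f d", then packed as "* f d": frames join the batch axis
    freq_pass = (b * t, f, d)
    return {"time": time_pass, "freq": freq_pass}


print(axial_attention_shapes(b=2, t=1151, f=62, d=384))
```

Each pass attends over one axis only, which keeps the attention sequence length at `t` or `f` rather than `t * f`.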

HyperAceResidualMaskEstimator

HyperAceResidualMaskEstimator(
    dim: int,
    dim_inputs: tuple[int, ...],
    depth: int,
    mlp_expansion_factor: int,
    segm: Module,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
dim_inputs
to_freqs
segm
Source code in src/splifft/models/bs_roformer.py
def __init__(
    self,
    dim: int,
    dim_inputs: tuple[int, ...],
    depth: int,
    mlp_expansion_factor: int,
    segm: nn.Module,
):
    super().__init__()
    self.dim_inputs = dim_inputs
    self.to_freqs = _build_band_to_freq_mlps(
        dim=dim,
        dim_inputs=dim_inputs,
        depth=depth,
        mlp_expansion_factor=mlp_expansion_factor,
    )

    self.segm = segm
dim_inputs instance-attribute
dim_inputs = dim_inputs
to_freqs instance-attribute
to_freqs = _build_band_to_freq_mlps(
    dim=dim,
    dim_inputs=dim_inputs,
    depth=depth,
    mlp_expansion_factor=mlp_expansion_factor,
)
segm instance-attribute
segm = segm
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    y = rearrange(x, "b t f c -> b c t f")
    y = self.segm(y)
    y = rearrange(y, "b c t f -> b t (f c)")

    x_unbound = x.unbind(dim=-2)
    outs = []
    for band_features, mlp_net in zip(x_unbound, self.to_freqs):
        freq_out = mlp_net(band_features)
        outs.append(freq_out)

    return cast(Tensor, torch.cat(outs, dim=-1) + y)

BSRoformer

BSRoformer(cfg: BSRoformerParams)

Bases: Module, SupportsStemSelection[BSRoformerParams]

Methods:

Name Description
forward

Estimate per-stem masks from the input spectrogram stft_repr, shape (b, f*s, t, c).

__splifft_stem_selection_plan__

Remap mask_estimators.{i}.* state-dict entries to a compact [0..k) index range so unrelated per-stem heads are never instantiated or loaded.

Attributes:

Name Type Description
stereo
audio_channels
num_stems
use_torch_checkpoint
skip_connection
layers
shared_qkv_bias Parameter | None
shared_out_bias Parameter | None
is_mel
final_norm
band_split
mask_estimators
debug
Source code in src/splifft/models/bs_roformer.py
def __init__(self, cfg: BSRoformerParams):
    super().__init__()
    self.stereo = cfg.stereo
    self.audio_channels = 2 if cfg.stereo else 1
    self.num_stems = len(cfg.output_stem_names)
    self.use_torch_checkpoint = cfg.use_torch_checkpoint
    self.skip_connection = cfg.skip_connection

    self.layers = ModuleList([])

    self.shared_qkv_bias: nn.Parameter | None = None
    self.shared_out_bias: nn.Parameter | None = None
    if cfg.use_shared_bias:
        dim_inner = cfg.heads * cfg.dim_head
        self.shared_qkv_bias = nn.Parameter(torch.ones(dim_inner * 3))
        self.shared_out_bias = nn.Parameter(torch.ones(cfg.dim))

    transformer = partial(
        Transformer,
        dim=cfg.dim,
        heads=cfg.heads,
        dim_head=cfg.dim_head,
        attn_dropout=cfg.attn_dropout,
        ff_dropout=cfg.ff_dropout,
        ff_mult=cfg.ff_mult,
        attention_backend=cfg.attention_backend,
        norm_output=cfg.norm_output,
        shared_qkv_bias=self.shared_qkv_bias,
        shared_out_bias=self.shared_out_bias,
        rms_norm_eps=cfg.rms_norm_eps,
        transformer_residual_dtype=cfg.transformer_residual_dtype,
    )

    t_frames = cfg.chunk_size // cfg.stft_hop_length + 1  # e.g. 588800 // 512 + 1 = 1151
    time_rotary_embed = RotaryEmbedding(
        seq_len=t_frames, dim_head=cfg.dim_head, dtype=cfg.rotary_embed_dtype
    )

    if is_mel := isinstance(cfg.band_config, MelBandsConfig):
        from torchaudio.functional import melscale_fbanks

        mel_cfg = cfg.band_config
        num_bands = mel_cfg.num_bands
        freqs = mel_cfg.stft_n_fft // 2 + 1
        mel_filter_bank = melscale_fbanks(
            n_freqs=freqs,
            f_min=0.0,
            f_max=float(mel_cfg.sample_rate / 2),
            n_mels=num_bands,
            sample_rate=mel_cfg.sample_rate,
            norm="slaney",
            mel_scale="slaney",
        ).T
        # TODO: adopt https://github.com/lucidrains/BS-RoFormer/issues/47
        mel_filter_bank[0, 0] = 1.0
        mel_filter_bank[-1, -1] = 1.0

        freqs_per_band_mask = mel_filter_bank > 0
        assert freqs_per_band_mask.any(dim=0).all(), (
            "all frequencies must be covered by at least one band"
        )

        repeated_freq_indices = repeat(torch.arange(freqs), "f -> b f", b=num_bands)
        freq_indices = repeated_freq_indices[freqs_per_band_mask]
        if self.stereo:
            freq_indices = repeat(freq_indices, "f -> f s", s=2)
            freq_indices = freq_indices * 2 + torch.arange(2)
            freq_indices = rearrange(freq_indices, "f s -> (f s)")
        self.register_buffer("freq_indices", freq_indices, persistent=False)
        self.register_buffer("freqs_per_band_mask", freqs_per_band_mask, persistent=False)

        num_freqs_per_band = reduce(freqs_per_band_mask, "b f -> b", "sum")
        num_bands_per_freq = reduce(freqs_per_band_mask, "b f -> f", "sum")

        self.register_buffer("num_freqs_per_band", num_freqs_per_band, persistent=False)
        self.register_buffer("num_bands_per_freq", num_bands_per_freq, persistent=False)

    elif isinstance(cfg.band_config, FixedBandsConfig):
        num_freqs_per_band = torch.tensor(cfg.band_config.freqs_per_bands)
        num_bands = len(cfg.band_config.freqs_per_bands)
    else:
        raise TypeError(f"unknown band config: {cfg.band_config}")
    self.is_mel = is_mel

    freq_rotary_embed = RotaryEmbedding(
        seq_len=num_bands, dim_head=cfg.dim_head, dtype=cfg.rotary_embed_dtype
    )

    for _ in range(cfg.depth):
        tran_modules = []
        if cfg.linear_transformer_depth > 0:
            tran_modules.append(
                transformer(depth=cfg.linear_transformer_depth, linear_attn=True)
            )
        tran_modules.append(
            transformer(depth=cfg.time_transformer_depth, rotary_embed=time_rotary_embed)
        )
        tran_modules.append(
            transformer(depth=cfg.freq_transformer_depth, rotary_embed=freq_rotary_embed)
        )
        self.layers.append(nn.ModuleList(tran_modules))

    self.final_norm = (
        rms_norm(cfg.dim, eps=cfg.rms_norm_eps) if not self.is_mel else nn.Identity()
    )

    freqs_per_bands_with_complex = tuple(
        2 * f * self.audio_channels for f in num_freqs_per_band.tolist()
    )

    self.band_split = BandSplit(
        dim=cfg.dim,
        dim_inputs=freqs_per_bands_with_complex,
        rms_norm_eps=cfg.rms_norm_eps,
    )

    self.mask_estimators = nn.ModuleList([])

    def build_hyperace(config: MaskEstimatorConfig) -> nn.Module:
        if isinstance(config, HyperAceResidualV1MaskEstimatorConfig):
            from .utils.hyperace import SegmModelHyperAceV1

            return SegmModelHyperAceV1(
                in_bands=len(freqs_per_bands_with_complex),
                in_dim=cfg.dim,
                out_bins=sum(freqs_per_bands_with_complex) // 4,
                num_hyperedges=config.num_hyperedges or 16,
                num_heads=config.num_heads,
            )
        if isinstance(config, HyperAceResidualV2MaskEstimatorConfig):
            from .utils.hyperace import SegmModelHyperAceV2

            return SegmModelHyperAceV2(
                in_bands=len(freqs_per_bands_with_complex),
                in_dim=cfg.dim,
                out_bins=sum(freqs_per_bands_with_complex) // 4,
                num_hyperedges=config.num_hyperedges or 32,
                num_heads=config.num_heads,
            )
        raise TypeError(f"mask estimator is not hyperace-based: {config}")

    def build_mask_estimator(config: MaskEstimatorConfig) -> nn.Module:
        if isinstance(config, BaselineMaskEstimatorConfig):
            return MaskEstimator(
                dim=cfg.dim,
                dim_inputs=freqs_per_bands_with_complex,
                depth=cfg.mask_estimator_depth,
                mlp_expansion_factor=cfg.mlp_expansion_factor,
            )

        if isinstance(config, AxialRefinerLargeV2MaskEstimatorConfig):
            return AxialRefinerLargeV2MaskEstimator(
                dim=cfg.dim,
                dim_inputs=freqs_per_bands_with_complex,
                mlp_depth=cfg.mask_estimator_depth,
                mlp_expansion_factor=cfg.mlp_expansion_factor,
                axial_refiner_depth=config.axial_refiner_depth,
                t_frames=t_frames,
                num_bands=num_bands,
                rotary_embed_dtype=cfg.rotary_embed_dtype,
            )

        if isinstance(
            config,
            HyperAceResidualV1MaskEstimatorConfig | HyperAceResidualV2MaskEstimatorConfig,
        ):
            return HyperAceResidualMaskEstimator(
                dim=cfg.dim,
                dim_inputs=freqs_per_bands_with_complex,
                depth=cfg.mask_estimator_depth,
                mlp_expansion_factor=cfg.mlp_expansion_factor,
                segm=build_hyperace(config),
            )

        raise TypeError(f"unknown mask_estimator config: {config}")

    for _ in range(len(cfg.output_stem_names)):
        self.mask_estimators.append(build_mask_estimator(cfg.mask_estimator))

    self.debug = cfg.debug
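
Three pieces of integer bookkeeping in `__init__` are worth checking by hand: the number of STFT frames per chunk, the per-band input widths passed to `BandSplit`, and the interleaving of stereo frequency indices on the mel path. The sketch below restates them in plain Python. The function names are hypothetical, and reading the factor 2 in the band widths as the real and imaginary parts is an interpretation of the name `freqs_per_bands_with_complex`.

```python
def num_time_frames(chunk_size: int, stft_hop_length: int) -> int:
    """Frames per chunk, as computed for the time rotary embedding."""
    return chunk_size // stft_hop_length + 1


def band_input_widths(freqs_per_band: tuple[int, ...], stereo: bool) -> tuple[int, ...]:
    """Width of each band: 2 values per bin, times the audio channels."""
    audio_channels = 2 if stereo else 1
    return tuple(2 * f * audio_channels for f in freqs_per_band)


def interleave_stereo(freq_indices: list[int]) -> list[int]:
    """Equivalent of `freq_indices * 2 + arange(2)` flattened as "(f s)"."""
    return [2 * f + s for f in freq_indices for s in range(2)]


print(num_time_frames(588800, 512))
print(band_input_widths((2, 2, 4), stereo=True))
print(interleave_stereo([3, 7]))
```

The first call reproduces the example from the source comment (`588800 // 512 + 1`).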
stereo instance-attribute
stereo = stereo
audio_channels instance-attribute
audio_channels = 2 if stereo else 1
num_stems instance-attribute
num_stems = len(output_stem_names)
use_torch_checkpoint instance-attribute
use_torch_checkpoint = use_torch_checkpoint
skip_connection instance-attribute
skip_connection = skip_connection
layers instance-attribute
layers = ModuleList([])
shared_qkv_bias instance-attribute
shared_qkv_bias: Parameter | None = None
shared_out_bias instance-attribute
shared_out_bias: Parameter | None = None
is_mel instance-attribute
is_mel = is_mel
final_norm instance-attribute
final_norm = (
    rms_norm(dim, eps=rms_norm_eps)
    if not is_mel
    else Identity()
)
band_split instance-attribute
band_split = BandSplit(
    dim=dim,
    dim_inputs=freqs_per_bands_with_complex,
    rms_norm_eps=rms_norm_eps,
)
mask_estimators instance-attribute
mask_estimators = ModuleList([])
debug instance-attribute
debug = debug
forward
forward(stft_repr: Tensor) -> Tensor

Parameters:

Name Type Description Default
stft_repr Tensor Input spectrogram, shape (b, f*s, t, c). required

Returns:

Type Description
Tensor Estimated mask, shape (b, n, f*s, t, c).

Source code in src/splifft/models/bs_roformer.py
def forward(self, stft_repr: Tensor) -> Tensor:
    """
    :param stft_repr: input spectrogram. shape (b, f*s, t, c)
    :return: estimated mask. shape (b, n, f*s, t, c)
    """
    batch, _, t_frames, _ = stft_repr.shape
    device = stft_repr.device
    if self.is_mel:
        batch_arange = torch.arange(batch, device=device)[..., None]
        x = stft_repr[batch_arange, cast(Tensor, self.freq_indices)]
        x = rearrange(x, "b f t c -> b t (f c)")
    else:
        x = rearrange(stft_repr, "b f t c -> b t (f c)")

    if self.debug and (torch.isnan(x).any() or torch.isinf(x).any()):
        raise RuntimeError(
            f"nan/inf in x after rearrange: {x.isnan().sum()} nans, {x.isinf().sum()} infs"
        )

    if self.use_torch_checkpoint:
        x = cast(Tensor, checkpoint(self.band_split, x, use_reentrant=False))
    else:
        x = cast(Tensor, self.band_split(x))

    if self.debug and (torch.isnan(x).any() or torch.isinf(x).any()):
        raise RuntimeError(
            f"nan/inf in x after band_split: {x.isnan().sum()} nans, {x.isinf().sum()} infs"
        )

    # axial / hierarchical attention

    store: list[Tensor | None] = [None] * len(self.layers)
    for i, transformer_block in enumerate(self.layers):
        block = cast(ModuleList, transformer_block)
        if len(block) == 3:
            linear_transformer, time_transformer, freq_transformer = block

            x, ft_ps = pack([x], "b * d")
            if self.use_torch_checkpoint:
                x = checkpoint(linear_transformer, x, use_reentrant=False)
            else:
                x = linear_transformer(x)
            (x,) = unpack(x, ft_ps, "b * d")
        else:
            time_transformer, freq_transformer = block

        if self.skip_connection:
            for j in range(i):
                if store[j] is not None:
                    assert x is not None
                    x = x + cast(Tensor, store[j])

        x = rearrange(x, "b t f d -> b f t d")
        x, ps = pack([x], "* t d")

        if self.use_torch_checkpoint:
            x = checkpoint(time_transformer, x, use_reentrant=False)
        else:
            x = time_transformer(x)

        (x,) = unpack(x, ps, "* t d")
        x = rearrange(x, "b f t d -> b t f d")
        x, ps = pack([x], "* f d")

        if self.use_torch_checkpoint:
            x = checkpoint(freq_transformer, x, use_reentrant=False)
        else:
            x = freq_transformer(x)

        (x,) = unpack(x, ps, "* f d")

        if self.skip_connection:
            store[i] = x

    x = self.final_norm(x)

    if self.use_torch_checkpoint:
        mask = torch.stack(
            [
                cast(Tensor, checkpoint(fn, x, use_reentrant=False))
                for fn in self.mask_estimators
            ],
            dim=1,
        )
    else:
        mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
    mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2)

    if not self.is_mel:
        return mask

    stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")
    # stft_repr may be fp16 but complex32 support is experimental so we upcast it early
    stft_repr_complex = torch.view_as_complex(stft_repr.to(torch.float32))

    masks_per_band_complex = torch.view_as_complex(mask)
    masks_per_band_complex = masks_per_band_complex.type(stft_repr_complex.dtype)

    scatter_indices = repeat(
        cast(Tensor, self.freq_indices),
        "f -> b n f t",
        b=batch,
        n=self.num_stems,
        t=stft_repr_complex.shape[-1],
    )
    stft_repr_expanded_stems = repeat(stft_repr_complex, "b 1 ... -> b n ...", n=self.num_stems)

    masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(
        2, scatter_indices, masks_per_band_complex
    )

    denom = cast(Tensor, repeat(self.num_bands_per_freq, "f -> (f r) 1", r=self.audio_channels))
    masks_averaged = masks_summed / denom.clamp(min=1e-8)

    return torch.view_as_real(masks_averaged).to(stft_repr.dtype)
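
On the mel path, bands overlap, so a frequency bin can receive a mask value from more than one band. `forward` sums these contributions with `scatter_add_` and divides by the number of bands covering each bin. The sketch below shows the same idea for a single frame of real-valued, mono data in plain Python. It counts contributions on the fly instead of using the precomputed `num_bands_per_freq` buffer, and the function name is hypothetical.

```python
def average_overlapping_band_masks(
    freq_indices: list[int], band_values: list[float], num_freqs: int
) -> list[float]:
    """Scatter-add band values into frequency bins, then average per bin."""
    summed = [0.0] * num_freqs
    counts = [0] * num_freqs
    for index, value in zip(freq_indices, band_values):
        summed[index] += value  # analogous to scatter_add_ along the bin axis
        counts[index] += 1
    # bins covered by no band keep a zero mask; the clamp avoids dividing by 0
    return [total / max(count, 1e-8) for total, count in zip(summed, counts)]


# bin 1 is covered by two bands, bins 0 and 2 by one band each
print(average_overlapping_band_masks([0, 1, 1, 2], [1.0, 0.25, 0.75, 0.5], 3))
```

Averaging rather than summing keeps the mask magnitude comparable between bins that sit in one band and bins that sit in several.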
__splifft_stem_selection_plan__ classmethod
__splifft_stem_selection_plan__(
    model_params: BSRoformerParams,
    output_stem_names: tuple[ModelOutputStemName, ...],
) -> StemSelectionPlan[BSRoformerParams]

Remap mask_estimators.{i}.* state-dict entries to a compact [0..k) index range so unrelated per-stem heads are never instantiated or loaded.

Source code in src/splifft/models/bs_roformer.py
@classmethod
def __splifft_stem_selection_plan__(
    cls,
    model_params: BSRoformerParams,
    output_stem_names: tuple[t.ModelOutputStemName, ...],
) -> StemSelectionPlan[BSRoformerParams]:
    """Remap `mask_estimators.{i}.*` state-dict entries to a compact
    `[0..k)` index range so unrelated per-stem heads are never instantiated
    or loaded.
    """

    full_stem_names = tuple(model_params.output_stem_names)
    requested_stem_names = tuple(output_stem_names)
    if requested_stem_names == full_stem_names:
        return StemSelectionPlan(
            model_params=model_params,
            output_stem_names=full_stem_names,
        )

    selected_stem_indices = tuple(full_stem_names.index(name) for name in requested_stem_names)
    key_re = re.compile(r"^mask_estimators\.(\d+)\.(.+)$")
    index_remap = {src: dst for dst, src in enumerate(selected_stem_indices)}

    def state_dict_transform(state_dict: dict[str, Tensor]) -> dict[str, Tensor]:
        remapped: dict[str, Tensor] = {}
        for key, value in state_dict.items():
            if (m := key_re.match(key)) is None:
                remapped[key] = value
                continue
            if (src_idx := int(m.group(1))) not in index_remap:
                continue
            suffix = m.group(2)
            remapped[f"mask_estimators.{index_remap[src_idx]}.{suffix}"] = value
        return remapped

    return StemSelectionPlan(
        model_params=replace(model_params, output_stem_names=requested_stem_names),
        output_stem_names=requested_stem_names,
        state_dict_transform=state_dict_transform,
    )
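
To illustrate the remapping, here is a small standalone sketch that applies the same regular expression and index map to a plain dictionary. The stem names and checkpoint keys are hypothetical, and strings stand in for tensors so the example runs without PyTorch.

```python
import re

# Hypothetical stem layout and request, for illustration only.
full_stem_names = ("vocals", "drums", "bass", "other")
requested_stem_names = ("bass", "vocals")

selected = tuple(full_stem_names.index(name) for name in requested_stem_names)
index_remap = {src: dst for dst, src in enumerate(selected)}  # {2: 0, 0: 1}
key_re = re.compile(r"^mask_estimators\.(\d+)\.(.+)$")


def remap_keys(state_dict: dict[str, object]) -> dict[str, object]:
    remapped: dict[str, object] = {}
    for key, value in state_dict.items():
        m = key_re.match(key)
        if m is None:
            remapped[key] = value  # shared weights pass through untouched
            continue
        src_idx = int(m.group(1))
        if src_idx not in index_remap:
            continue  # head of an unrequested stem is dropped
        remapped[f"mask_estimators.{index_remap[src_idx]}.{m.group(2)}"] = value
    return remapped


# Hypothetical checkpoint keys; the values would normally be tensors.
checkpoint = {
    "final_norm.gamma": "w_norm",
    "mask_estimators.0.to_freqs.0.weight": "w_vocals",
    "mask_estimators.1.to_freqs.0.weight": "w_drums",
    "mask_estimators.2.to_freqs.0.weight": "w_bass",
    "mask_estimators.3.to_freqs.0.weight": "w_other",
}
compact = remap_keys(checkpoint)
```

The compact indices follow the order of the request, not the original order, so `bass` becomes head 0 and `vocals` becomes head 1. A requested name that is not in `full_stem_names` makes `tuple.index` raise `ValueError`.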

basic_pitch

ICASSP 2022 Basic Pitch. Raw multi-stream outputs only, no symbolic decoding.

See: https://github.com/spotify/basic-pitch, https://arxiv.org/abs/2203.09893

Classes:

Name Description
BasicPitchParams
HarmonicStacking
BasicPitch

BasicPitchParams dataclass

BasicPitchParams(
    chunk_size: ChunkSize,
    output_stem_names: tuple[ModelOutputStemName, ...],
    sequence_window_frames: Gt0[int],
    n_semitones: Gt0[int] = 88,
    n_harmonics: Gt0[int] = 8,
    contour_bins_per_semitone: Gt0[int] = 3,
    cqt_bins_per_semitone: Gt0[int] = 3,
)

Bases: ModelParamsLike, HasSequenceWindowFrames

Attributes:

Name Type Description
chunk_size ChunkSize
output_stem_names tuple[ModelOutputStemName, ...]
sequence_window_frames Gt0[int]
n_semitones Gt0[int]
n_harmonics Gt0[int]
contour_bins_per_semitone Gt0[int]
cqt_bins_per_semitone Gt0[int]
input_channels ModelInputChannels
input_type ModelInputType
output_type ModelOutputType
inference_archetype InferenceArchetype
chunk_size instance-attribute
chunk_size: ChunkSize
output_stem_names instance-attribute
output_stem_names: tuple[ModelOutputStemName, ...]
sequence_window_frames instance-attribute
sequence_window_frames: Gt0[int]
n_semitones class-attribute instance-attribute
n_semitones: Gt0[int] = 88
n_harmonics class-attribute instance-attribute
n_harmonics: Gt0[int] = 8
contour_bins_per_semitone class-attribute instance-attribute
contour_bins_per_semitone: Gt0[int] = 3
cqt_bins_per_semitone class-attribute instance-attribute
cqt_bins_per_semitone: Gt0[int] = 3
input_channels property
input_channels: ModelInputChannels
input_type property
input_type: ModelInputType
output_type property
output_type: ModelOutputType
inference_archetype property
inference_archetype: InferenceArchetype

HarmonicStacking

HarmonicStacking(
    *,
    bins_per_semitone: int,
    harmonics: tuple[float, ...],
    n_output_freqs: int,
)

Bases: Module

Methods:

Name Description
forward

:param x: (B, T, F)

Attributes:

Name Type Description
n_output_freqs
shifts
Source code in src/splifft/models/basic_pitch.py
def __init__(
    self,
    *,
    bins_per_semitone: int,
    harmonics: tuple[float, ...],
    n_output_freqs: int,
):
    super().__init__()
    self.n_output_freqs = n_output_freqs
    self.shifts = [int(round(12.0 * bins_per_semitone * math.log2(h))) for h in harmonics]
n_output_freqs instance-attribute
n_output_freqs = n_output_freqs
shifts instance-attribute
shifts = [
    (int(round(12.0 * bins_per_semitone * log2(h))))
    for h in harmonics
]
forward
forward(x: Tensor) -> Tensor

Parameters:

Name Type Description Default
x Tensor

(B, T, F)

required

Returns:

Type Description
Tensor

(B, H, T, F_out)

Source code in src/splifft/models/basic_pitch.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    :param x: (B, T, F)
    :return: (B, H, T, F_out)
    """
    stacked: list[torch.Tensor] = []
    for shift in self.shifts:
        if shift == 0:
            shifted = x
        elif shift > 0:
            shifted = F.pad(x[:, :, shift:], (0, shift))
        else:
            shifted = F.pad(x[:, :, :shift], (-shift, 0))
        stacked.append(shifted[:, :, : self.n_output_freqs])
    return torch.stack(stacked, dim=1)
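
As a rough illustration of the shift arithmetic, the sketch below reproduces the `shifts` formula and mimics the slice-and-pad step on a plain list. The helper names `harmonic_shifts` and `shift_bins` are hypothetical and only model the last (frequency) axis of the real tensor.

```python
import math


def harmonic_shifts(bins_per_semitone: int, harmonics: tuple[float, ...]) -> list[int]:
    # same formula as HarmonicStacking.__init__
    return [int(round(12.0 * bins_per_semitone * math.log2(h))) for h in harmonics]


def shift_bins(row: list[float], shift: int, n_output_freqs: int) -> list[float]:
    """List-based stand-in for the slice + F.pad on the frequency axis."""
    if shift == 0:
        shifted = row
    elif shift > 0:
        shifted = row[shift:] + [0.0] * shift      # drop low bins, zero-pad the top
    else:
        shifted = [0.0] * (-shift) + row[:shift]   # zero-pad the bottom, drop high bins
    return shifted[:n_output_freqs]


shifts = harmonic_shifts(3, (0.5, 1.0, 2.0, 3.0))
row = [float(i) for i in range(6)]
up = shift_bins(row, 2, 4)
down = shift_bins(row, -2, 4)
```

With 3 bins per semitone, an octave is 36 bins, so the sub-harmonic `0.5` shifts by `-36` and the second harmonic by `+36`. Non-octave harmonics are rounded to the nearest bin.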

BasicPitch

BasicPitch(cfg: BasicPitchParams)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
cfg
n_contour_bins
cqt_n_bins
stack_harmonics
hs
bn_layer
conv_contour
conv_note
conv_onset_pre
conv_onset_post
Source code in src/splifft/models/basic_pitch.py
def __init__(self, cfg: BasicPitchParams):
    super().__init__()
    self.cfg = cfg
    if cfg.n_harmonics == 1:
        stack_harmonics = (1.0,)
        harmonic_context_semitones = 0
    else:
        stack_harmonics = (0.5, *(float(harmonic) for harmonic in range(1, cfg.n_harmonics)))
        harmonic_context_semitones = math.ceil(12.0 * math.log2(cfg.n_harmonics))
    self.n_contour_bins = cfg.n_semitones * cfg.contour_bins_per_semitone
    self.cqt_n_bins = (cfg.n_semitones + harmonic_context_semitones) * cfg.cqt_bins_per_semitone
    self.stack_harmonics = stack_harmonics
    self.hs = HarmonicStacking(
        bins_per_semitone=cfg.cqt_bins_per_semitone,
        harmonics=self.stack_harmonics,
        n_output_freqs=self.n_contour_bins,
    )
    self.bn_layer = nn.BatchNorm2d(1, eps=0.001)

    num_in_channels = len(self.stack_harmonics)
    # NOTE: the first contour Conv-BN-ReLU block was accidentally skipped in the released
    # Basic Pitch checkpoints. we keep the reduced contour stack here to stay weight-compatible
    # with those shipped assets.
    # See: https://github.com/spotify/basic-pitch/issues/21, https://github.com/spotify/basic-pitch/pull/180
    self.conv_contour = nn.Sequential(
        nn.Conv2d(num_in_channels, 8, kernel_size=(3, 39), padding="same"),
        nn.BatchNorm2d(8, eps=0.001),
        nn.ReLU(),
        nn.Conv2d(8, 1, kernel_size=5, padding="same"),
        nn.Sigmoid(),
    )
    self.conv_note = nn.Sequential(
        nn.Conv2d(1, 32, kernel_size=7, stride=(1, 3)),
        nn.ReLU(),
        nn.Conv2d(32, 1, kernel_size=(7, 3), padding="same"),
        nn.Sigmoid(),
    )
    self.conv_onset_pre = nn.Sequential(
        nn.Conv2d(num_in_channels, 32, kernel_size=5, stride=(1, 3)),
        nn.BatchNorm2d(32, eps=0.001),
        nn.ReLU(),
    )
    self.conv_onset_post = nn.Sequential(
        nn.Conv2d(33, 1, kernel_size=3, stride=1, padding="same"),
        nn.Sigmoid(),
    )
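
The bin bookkeeping in the constructor can be checked by hand. The sketch below repeats that arithmetic as a standalone function (the name `basic_pitch_bin_counts` is hypothetical) using the `BasicPitchParams` defaults.

```python
import math


def basic_pitch_bin_counts(
    n_semitones: int = 88,
    n_harmonics: int = 8,
    contour_bins_per_semitone: int = 3,
    cqt_bins_per_semitone: int = 3,
) -> dict[str, object]:
    if n_harmonics == 1:
        stack_harmonics: tuple[float, ...] = (1.0,)
        harmonic_context_semitones = 0
    else:
        # one sub-harmonic (0.5) plus harmonics 1 .. n_harmonics - 1
        stack_harmonics = (0.5, *(float(h) for h in range(1, n_harmonics)))
        # extra semitones so the highest shifted copy still has data to read
        harmonic_context_semitones = math.ceil(12.0 * math.log2(n_harmonics))
    return {
        "stack_harmonics": stack_harmonics,
        "n_contour_bins": n_semitones * contour_bins_per_semitone,
        "cqt_n_bins": (n_semitones + harmonic_context_semitones) * cqt_bins_per_semitone,
    }


counts = basic_pitch_bin_counts()
```

With the defaults, the model expects 372 input CQT bins and produces 264 contour bins. The 33 input channels of `conv_onset_post` correspond to the single note channel concatenated with the 32 channels of `conv_onset_pre`.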
cfg instance-attribute
cfg = cfg
n_contour_bins instance-attribute
n_contour_bins = n_semitones * contour_bins_per_semitone
cqt_n_bins instance-attribute
cqt_n_bins = (
    n_semitones + harmonic_context_semitones
) * cqt_bins_per_semitone
stack_harmonics instance-attribute
stack_harmonics = stack_harmonics
hs instance-attribute
hs = HarmonicStacking(
    bins_per_semitone=cqt_bins_per_semitone,
    harmonics=stack_harmonics,
    n_output_freqs=n_contour_bins,
)
bn_layer instance-attribute
bn_layer = BatchNorm2d(1, eps=0.001)
conv_contour instance-attribute
conv_contour = Sequential(
    Conv2d(
        num_in_channels,
        8,
        kernel_size=(3, 39),
        padding="same",
    ),
    BatchNorm2d(8, eps=0.001),
    ReLU(),
    Conv2d(8, 1, kernel_size=5, padding="same"),
    Sigmoid(),
)
conv_note instance-attribute
conv_note = Sequential(
    Conv2d(1, 32, kernel_size=7, stride=(1, 3)),
    ReLU(),
    Conv2d(32, 1, kernel_size=(7, 3), padding="same"),
    Sigmoid(),
)
conv_onset_pre instance-attribute
conv_onset_pre = Sequential(
    Conv2d(
        num_in_channels, 32, kernel_size=5, stride=(1, 3)
    ),
    BatchNorm2d(32, eps=0.001),
    ReLU(),
)
conv_onset_post instance-attribute
conv_onset_post = Sequential(
    Conv2d(33, 1, kernel_size=3, stride=1, padding="same"),
    Sigmoid(),
)
forward
forward(x: Tensor) -> tuple[Tensor, ...]
Source code in src/splifft/models/basic_pitch.py
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]:
    if x.ndim != 3:
        raise ValueError(f"expected `(B,T,F)` input, got {tuple(x.shape)}")
    if x.shape[-1] != self.cqt_n_bins:
        raise ValueError(f"expected feature dim {self.cqt_n_bins}, got {x.shape[-1]}")

    x = self._normalize_log_features(x)
    x = self.bn_layer(x.unsqueeze(1)).squeeze(1)

    cqt = self.hs(x)

    contour = self.conv_contour(cqt)

    contour_for_note = F.pad(contour, (2, 2, 3, 3))
    note = self.conv_note(contour_for_note)

    cqt_for_onset = F.pad(cqt, (1, 1, 2, 2))
    onset_pre = self.conv_onset_pre(cqt_for_onset)
    onset_in = torch.cat((note, onset_pre), dim=1)
    onset = self.conv_onset_post(onset_in)

    contour_out = contour.squeeze(1)
    note_out = note.squeeze(1)
    onset_out = onset.squeeze(1)

    return onset_out, note_out, contour_out

utils

Modules:

Name Description
attend
cqt

Utilities for Constant-Q transform.

hyperace

HyperACE segmentation backbones for BS-RoFormer mask heads.

rotary
stft
tfc_tdf

Time-Frequency Convolutions and Time-Distributed Fully-connected (TFC-TDF)

Functions:

Name Description
parse_version
log_once

parse_version

parse_version(v_str: str) -> tuple[int, ...]
Source code in src/splifft/models/utils/__init__.py
def parse_version(v_str: str) -> tuple[int, ...]:
    # e.g. "2.1.0+cu118" -> (2, 1, 0)
    return tuple(map(int, v_str.split("+")[0].split(".")))
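
A short usage sketch follows, with the function copied verbatim so it runs standalone. The pre-release string is a hypothetical input chosen to show a limitation: any non-numeric component before the `+` makes `int()` raise `ValueError`.

```python
def parse_version(v_str: str) -> tuple[int, ...]:
    return tuple(map(int, v_str.split("+")[0].split(".")))


release = parse_version("2.1.0+cu118")
capability = parse_version("8.6")
# tuples compare element-wise, which is how Attend checks compute capability
is_sm80_or_newer = capability >= (8, 0)

try:
    parse_version("2.2.0a0+git1234")  # hypothetical pre-release tag
    prerelease_ok = True
except ValueError:
    prerelease_ok = False
```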

log_once cached

log_once(
    logger: Logger, msg: object, *, level: int = DEBUG
) -> None
Source code in src/splifft/models/utils/__init__.py
@lru_cache(10)
def log_once(logger: Logger, msg: object, *, level: int = logging.DEBUG) -> None:
    logger.log(level, msg)
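
Because the function is wrapped in `lru_cache`, a repeated call with the same `(logger, msg, level)` arguments is a cache hit and emits nothing. The sketch below demonstrates this with a hypothetical in-memory handler. The logger name and messages are illustrative only.

```python
import logging
from functools import lru_cache


@lru_cache(10)
def log_once(logger: logging.Logger, msg: object, *, level: int = logging.DEBUG) -> None:
    logger.log(level, msg)


class ListHandler(logging.Handler):
    """Hypothetical handler that collects formatted messages in a list."""

    def __init__(self) -> None:
        super().__init__()
        self.messages: list[str] = []

    def emit(self, record: logging.LogRecord) -> None:
        self.messages.append(record.getMessage())


logger = logging.getLogger("log_once_demo")
logger.setLevel(logging.DEBUG)
logger.propagate = False
handler = ListHandler()
logger.addHandler(handler)

log_once(logger, "falling back to sdpa")
log_once(logger, "falling back to sdpa")  # cache hit: not logged again
log_once(logger, "falling back to math")
```

Two caveats follow from the cache: `msg` must be hashable, and with a cache size of 10 the least recently used entry is evicted once more than ten distinct calls have been made, so an old message can be logged again.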

rotary

Classes:

Name Description
RotaryEmbedding

A performance-oriented version of RoPE for source-separation models.

Functions:

Name Description
rotate_half
apply_rotary_embedding
rotate_half
rotate_half(x: Tensor) -> Tensor
Source code in src/splifft/models/utils/rotary.py
def rotate_half(x: Tensor) -> Tensor:
    x = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, "... d r -> ... (d r)")
apply_rotary_embedding
apply_rotary_embedding(
    x: Tensor,
    *,
    cos: Tensor,
    sin: Tensor,
    sum_dtype: dtype | None = None,
) -> Tensor
Source code in src/splifft/models/utils/rotary.py
def apply_rotary_embedding(
    x: Tensor,
    *,
    cos: Tensor,
    sin: Tensor,
    sum_dtype: torch.dtype | None = None,
) -> Tensor:
    term1 = x * cos
    term2 = rotate_half(x) * sin
    if sum_dtype is None:
        return term1 + term2
    return (term1.to(sum_dtype) + term2.to(sum_dtype)).to(x.dtype)
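
To see why `x * cos + rotate_half(x) * sin` is a rotation, the following pure-Python sketch applies the same formula to a single feature vector. It is a list-based stand-in, not the tensor implementation. The helper `rope_angles` is hypothetical and follows the interleaved `(d r)` layout used above.

```python
import math


def rotate_half(x: list[float]) -> list[float]:
    out: list[float] = []
    for i in range(0, len(x), 2):
        x1, x2 = x[i], x[i + 1]
        out += [-x2, x1]  # each adjacent pair (x1, x2) becomes (-x2, x1)
    return out


def apply_rotary(x: list[float], cos: list[float], sin: list[float]) -> list[float]:
    rotated = rotate_half(x)
    return [xi * c + ri * s for xi, ri, c, s in zip(x, rotated, cos, sin)]


def rope_angles(position: int, dim_head: int, theta: float = 10000.0) -> list[float]:
    base_freqs = [1.0 / theta ** (i / dim_head) for i in range(0, dim_head, 2)]
    angles: list[float] = []
    for freq in base_freqs:
        angles += [position * freq, position * freq]  # repeat "d -> (d r)", r=2
    return angles


x = [1.0, 0.0, 0.0, 2.0]
angles = rope_angles(position=3, dim_head=4)
y = apply_rotary(x, [math.cos(a) for a in angles], [math.sin(a) for a in angles])
```

Each adjacent pair is rotated in its own 2-D plane by `position * base_freq`, so the length of every pair is preserved and position 0 leaves the vector unchanged.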
RotaryEmbedding
RotaryEmbedding(
    seq_len: int,
    dim_head: int,
    *,
    dtype: dtype | None = None,
    theta: int = 10000,
    device: device | None = None,
)

Bases: Module

A performance-oriented version of RoPE for source-separation models.

Unlike lucidrains' implementation, which computes embeddings just-in-time (JIT) during the forward pass and caches calls with the same or shorter sequence length, we simply compute them ahead-of-time (AOT) as persistent buffers. To keep the computational graph clean, we do not support dynamic sequence lengths, learned frequencies, or length extrapolation.

Methods:

Name Description
embeddings
forward

Attributes:

Name Type Description
cos_emb Tensor
sin_emb Tensor
Source code in src/splifft/models/utils/rotary.py
def __init__(
    self,
    seq_len: int,
    dim_head: int,
    *,
    dtype: torch.dtype | None = None,
    theta: int = 10000,
    device: torch.device | None = None,
):
    super().__init__()
    if dim_head % 2 != 0:
        raise ValueError(f"rotary dimension must be even, got {dim_head}")
    base_freqs = 1.0 / (
        theta ** (torch.arange(0, dim_head, 2, device=device).float() / dim_head)
    )
    positions = torch.arange(seq_len, dtype=base_freqs.dtype, device=base_freqs.device)
    freqs = torch.einsum("i,j->ij", positions, base_freqs)
    freqs = repeat(freqs, "... d -> ... (d r)", r=2)

    # COMPAT: the original implementation does not generate the embeddings
    # on the fly, but serialises them in fp16. there are some tiny
    # differences:
    # |                     |   from weights  |   generated    |
    # | ------------------- | --------------- | -------------- |
    # | cos_emb_time:971,22 | -0.99462890625  | -0.994140625   |
    # | cos_emb_time:971,23 | -0.99462890625  | -0.994140625   |
    # | sin_emb_time:727,12 | -0.457763671875 | -0.4580078125  |
    # | sin_emb_time:727,13 | -0.457763671875 | -0.4580078125  |
    # | sin_emb_time:825,4  | -0.8544921875   | -0.85400390625 |
    # | sin_emb_time:825,5  | -0.8544921875   | -0.85400390625 |
    target_dtype = torch.float16 if dtype is None else dtype
    self.register_buffer("cos_emb", freqs.cos().to(target_dtype), persistent=True)
    self.register_buffer("sin_emb", freqs.sin().to(target_dtype), persistent=True)
    self.cos_emb: torch.Tensor
    self.sin_emb: torch.Tensor
cos_emb instance-attribute
cos_emb: Tensor
sin_emb instance-attribute
sin_emb: Tensor
embeddings
embeddings(
    x: Tensor, *, start_index: int = 0
) -> tuple[Tensor, Tensor]
Source code in src/splifft/models/utils/rotary.py
def embeddings(self, x: Tensor, *, start_index: int = 0) -> tuple[Tensor, Tensor]:
    seq_len = x.shape[-2]
    end_index = start_index + seq_len
    if end_index > self.cos_emb.shape[0]:
        raise ValueError(
            f"requested rotary positions up to {end_index}, "
            f"but only {self.cos_emb.shape[0]} are available"
        )
    cos = self.cos_emb[start_index:end_index].unsqueeze(0).unsqueeze(0).to(x.device, x.dtype)
    sin = self.sin_emb[start_index:end_index].unsqueeze(0).unsqueeze(0).to(x.device, x.dtype)
    return cos, sin
forward
forward(x: Tensor, *, start_index: int = 0) -> Tensor
Source code in src/splifft/models/utils/rotary.py
def forward(self, x: Tensor, *, start_index: int = 0) -> Tensor:
    cos, sin = self.embeddings(x, start_index=start_index)
    # NOTE: original impl performed addition between two f32s but it comes
    # with 30% slowdown. we eliminate it so the addition is performed in the
    # query / key dtype instead.
    return apply_rotary_embedding(x, cos=cos, sin=sin)

attend

Classes:

Name Description
Attend

Attributes:

Name Type Description
logger
AttentionBackend
logger module-attribute
logger = getLogger(__name__)
AttentionBackend module-attribute
AttentionBackend = Literal['math', 'sdpa', 'sage']
Attend
Attend(
    dropout: float = 0.0,
    backend: AttentionBackend = "math",
    scale: float | None = None,
)

Bases: Module

Methods:

Name Description
forward

einstein notation

Attributes:

Name Type Description
scale
dropout
attn_dropout
backend
use_sage
use_sdpa
cpu_backends
cuda_backends list[_SDPBackend] | None
Source code in src/splifft/models/utils/attend.py
def __init__(
    self,
    dropout: float = 0.0,
    backend: AttentionBackend = "math",
    scale: float | None = None,
) -> None:
    super().__init__()
    self.scale = scale
    self.dropout = dropout
    self.attn_dropout = nn.Dropout(dropout)

    self.backend = backend

    if self.backend == "sage" and not _has_sage_attention:
        log_once(logger, "sageattention requested but not found. falling back to sdpa.")
        self.backend = "sdpa"

    if self.backend == "sdpa" and parse_version(torch.__version__) < (2, 0, 0):
        log_once(logger, "sdpa requested but pytorch < 2.0.0. falling back to math.")
        self.backend = "math"

    self.use_sage = self.backend == "sage"
    self.use_sdpa = self.backend == "sdpa"

    self.cpu_backends = [
        SDPBackend.FLASH_ATTENTION,
        SDPBackend.EFFICIENT_ATTENTION,
        SDPBackend.MATH,
    ]
    self.cuda_backends: list[_SDPBackend] | None = None

    if not torch.cuda.is_available() or not self.use_sdpa:
        return

    device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
    device_version = parse_version(f"{device_properties.major}.{device_properties.minor}")

    if device_version >= (8, 0):
        if os.name == "nt":
            cuda_backends = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
            log_once(logger, f"windows detected, using {cuda_backends=}")
        else:
            cuda_backends = [SDPBackend.FLASH_ATTENTION, SDPBackend.MATH]
            log_once(logger, f"gpu compute capability >= 8.0, using {cuda_backends=}")
    else:
        cuda_backends = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
        log_once(logger, f"gpu compute capability < 8.0, using {cuda_backends=}")

    self.cuda_backends = cuda_backends
scale instance-attribute
scale = scale
dropout instance-attribute
dropout = dropout
attn_dropout instance-attribute
attn_dropout = Dropout(dropout)
backend instance-attribute
backend = backend
use_sage instance-attribute
use_sage = backend == 'sage'
use_sdpa instance-attribute
use_sdpa = backend == 'sdpa'
cpu_backends instance-attribute
cpu_backends = [FLASH_ATTENTION, EFFICIENT_ATTENTION, MATH]
cuda_backends instance-attribute
cuda_backends: list[_SDPBackend] | None = cuda_backends
forward
forward(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    *,
    is_causal: bool = False,
) -> Tensor

einstein notation

  • b: batch
  • h: heads
  • i, j: sequence length (source, target)
  • d: feature dimension
Source code in src/splifft/models/utils/attend.py
def forward(self, q: Tensor, k: Tensor, v: Tensor, *, is_causal: bool = False) -> Tensor:
    """
    einstein notation

    - b: batch
    - h: heads
    - i, j: sequence length (source, target)
    - d: feature dimension
    """
    if self.use_sage:
        # NOTE: sage requires (is_causal and q.shape[-2] != k.shape[-2]) and q.shape[-1] in (64, 96, 128)
        # and q.dtype in (torch.float16, torch.bfloat16) but we just fail through for now.
        return sageattn(  # type: ignore
            q, k, v, tensor_layout="HND", is_causal=is_causal, sm_scale=self.scale
        )

    if self.use_sdpa:
        is_cuda = q.is_cuda
        backends = self.cuda_backends if is_cuda else self.cpu_backends
        with sdpa_kernel(backends=backends):  # type: ignore
            return F.scaled_dot_product_attention(
                q,
                k,
                v,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=is_causal,
                scale=self.scale,
            )

    scale = self.scale or q.shape[-1] ** -0.5

    # similarity
    sim = torch.einsum("b h i d, b h j d -> b h i j", q, k) * scale

    if is_causal:
        i, j = q.shape[-2], k.shape[-2]
        causal_mask = torch.ones((i, j), device=q.device, dtype=torch.bool).triu_(
            diagonal=j - i + 1
        )
        sim.masked_fill_(causal_mask, torch.finfo(sim.dtype).min)

    # attention
    attn = sim.softmax(dim=-1)
    attn = self.attn_dropout(attn)

    # aggregate values
    out = torch.einsum("b h i j, b h j d -> b h i d", attn, v)

    return out
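
The causal mask in the `math` path is built with `triu_(diagonal=j - i + 1)`, which also handles the case where there are more keys than queries. The sketch below reproduces that mask with nested lists (the function name `causal_mask` is hypothetical). `True` marks the positions that are filled with the most negative finite value before the softmax.

```python
def causal_mask(i: int, j: int) -> list[list[bool]]:
    """True marks key positions that a query may NOT attend to."""
    diagonal = j - i + 1
    # triu keeps entries on or above the given diagonal, i.e. col - row >= diagonal
    return [[col - row >= diagonal for col in range(j)] for row in range(i)]


square = causal_mask(3, 3)  # standard self-attention
offset = causal_mask(2, 4)  # 2 queries aligned to the END of 4 keys
```

In the square case each query sees itself and everything before it. When `j > i`, the queries are treated as the last `i` positions of the key sequence, so the final query can attend to every key.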

tfc_tdf

Time-Frequency Convolutions and Time-Distributed Fully-connected (TFC-TDF)

See: https://arxiv.org/pdf/2306.09382

Classes:

Name Description
TfcTdfBlock
TfcTdf

Functions:

Name Description
instance_norm_factory
silu_factory

Attributes:

Name Type Description
NormFactory
ActFactory
NormFactory module-attribute
NormFactory = Callable[[int], Module]
ActFactory module-attribute
ActFactory = Callable[[], Module]
instance_norm_factory
instance_norm_factory(channels: int) -> Module
Source code in src/splifft/models/utils/tfc_tdf.py
def instance_norm_factory(channels: int) -> nn.Module:
    return nn.InstanceNorm2d(channels, affine=True, eps=1e-8)
silu_factory
silu_factory() -> Module
Source code in src/splifft/models/utils/tfc_tdf.py
def silu_factory() -> nn.Module:
    return nn.SiLU()
TfcTdfBlock
TfcTdfBlock(
    in_channels: int,
    out_channels: int,
    f_bins: int,
    bottleneck_factor: int,
    *,
    norm_factory: NormFactory,
    act_factory: ActFactory,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
tfc1
tdf
tfc2
shortcut
Source code in src/splifft/models/utils/tfc_tdf.py
def __init__(
    self,
    in_channels: int,
    out_channels: int,
    f_bins: int,
    bottleneck_factor: int,
    *,
    norm_factory: NormFactory,
    act_factory: ActFactory,
):
    super().__init__()

    self.tfc1 = nn.Sequential(
        norm_factory(in_channels),
        act_factory(),
        nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
    )
    self.tdf = nn.Sequential(
        norm_factory(out_channels),
        act_factory(),
        nn.Linear(f_bins, f_bins // bottleneck_factor, bias=False),
        norm_factory(out_channels),
        act_factory(),
        nn.Linear(f_bins // bottleneck_factor, f_bins, bias=False),
    )
    self.tfc2 = nn.Sequential(
        norm_factory(out_channels),
        act_factory(),
        nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
    )
    self.shortcut = nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False)
tfc1 instance-attribute
tfc1 = Sequential(
    norm_factory(in_channels),
    act_factory(),
    Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
)
tdf instance-attribute
tdf = Sequential(
    norm_factory(out_channels),
    act_factory(),
    Linear(f_bins, f_bins // bottleneck_factor, bias=False),
    norm_factory(out_channels),
    act_factory(),
    Linear(f_bins // bottleneck_factor, f_bins, bias=False),
)
tfc2 instance-attribute
tfc2 = Sequential(
    norm_factory(out_channels),
    act_factory(),
    Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
)
shortcut instance-attribute
shortcut = Conv2d(
    in_channels, out_channels, 1, 1, 0, bias=False
)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/utils/tfc_tdf.py
def forward(self, x: Tensor) -> Tensor:
    s = self.shortcut(x)
    x = self.tfc1(x)
    x = x + self.tdf(x)
    x = self.tfc2(x)
    x = x + s
    return x
TfcTdf
TfcTdf(
    in_channels: int,
    out_channels: int,
    num_blocks: int,
    f_bins: int,
    bottleneck_factor: int,
    *,
    norm_factory: NormFactory,
    act_factory: ActFactory,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
blocks
Source code in src/splifft/models/utils/tfc_tdf.py
def __init__(
    self,
    in_channels: int,
    out_channels: int,
    num_blocks: int,
    f_bins: int,
    bottleneck_factor: int,
    *,
    norm_factory: NormFactory,
    act_factory: ActFactory,
):
    super().__init__()

    self.blocks = nn.ModuleList(
        [
            TfcTdfBlock(
                in_channels=in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                f_bins=f_bins,
                bottleneck_factor=bottleneck_factor,
                norm_factory=norm_factory,
                act_factory=act_factory,
            )
            for i in range(num_blocks)
        ]
    )
blocks instance-attribute
blocks = ModuleList(
    [
        (
            TfcTdfBlock(
                in_channels=in_channels
                if i == 0
                else out_channels,
                out_channels=out_channels,
                f_bins=f_bins,
                bottleneck_factor=bottleneck_factor,
                norm_factory=norm_factory,
                act_factory=act_factory,
            )
        )
        for i in (range(num_blocks))
    ]
)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/utils/tfc_tdf.py
def forward(self, x: Tensor) -> Tensor:
    for block in self.blocks:
        x = block(x)
    return x

cqt

Utilities for Constant-Q transform.

  1. Schörkhuber, C. and Klapuri, A., 2010. Constant-Q transform toolbox for music processing. In Proc. 7th Sound and Music Computing Conf.

Classes:

Name Description
RegularCqtKernels
RegularCQT
HarmonicRegularCQT
RecursiveCqtKernels
RecursiveCQT

Recursive octave-decimation CQT.

HarmonicRecursiveCQT
RegularCqtKernels

Bases: NamedTuple

Attributes:

Name Type Description
kernels Tensor
fft_length int
sqrt_lengths Tensor
frequencies Tensor
kernels instance-attribute
kernels: Tensor
fft_length instance-attribute
fft_length: int
sqrt_lengths instance-attribute
sqrt_lengths: Tensor
frequencies instance-attribute
frequencies: Tensor
RegularCQT
RegularCQT(
    *,
    sr: SampleRate,
    hop_length: HopSize,
    fmin: Gt0[float],
    n_bins: Gt0[int],
    bins_per_octave: Gt0[int],
    gamma: Ge0[float],
    center: bool,
    filter_scale: Gt0[float] = 1.0,
    window: WindowShape = "hann",
    norm: int = 1,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
n_bins
conv
Source code in src/splifft/models/utils/cqt.py
def __init__(
    self,
    *,
    sr: t.SampleRate,
    hop_length: t.HopSize,
    fmin: t.Gt0[float],
    n_bins: t.Gt0[int],
    bins_per_octave: t.Gt0[int],
    gamma: t.Ge0[float],
    center: bool,
    filter_scale: t.Gt0[float] = 1.0,
    window: t.WindowShape = "hann",
    norm: int = 1,
):
    super().__init__()
    self.n_bins = n_bins

    Q = filter_scale / (2 ** (1 / bins_per_octave) - 1)
    kernel_bank = _create_regular_cqt_kernels(
        Q=Q,
        fs=sr,
        fmin=fmin,
        n_bins=n_bins,
        bins_per_octave=bins_per_octave,
        norm=norm,
        window=window,
        gamma=gamma,
    )

    self.register_buffer("sqrt_lengths", kernel_bank.sqrt_lengths, persistent=False)
    self.register_buffer("kernels", kernel_bank.kernels, persistent=False)
    self.register_buffer("frequencies", kernel_bank.frequencies, persistent=False)

    padding = kernel_bank.fft_length // 2 if center else 0
    self.conv = nn.Conv1d(
        1,
        2 * n_bins,
        kernel_size=kernel_bank.fft_length,
        stride=hop_length,
        padding=padding,
        padding_mode="reflect",
        bias=False,
    )
    self._init_weights()
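
The quality factor `Q` computed in the constructor depends only on `filter_scale` and `bins_per_octave`. The sketch below evaluates the same expression for two resolutions. The function name `cqt_quality_factor` is hypothetical, and 36 bins per octave is chosen because it matches 3 bins per semitone.

```python
def cqt_quality_factor(bins_per_octave: int, filter_scale: float = 1.0) -> float:
    # same expression as in RegularCQT.__init__
    return filter_scale / (2 ** (1 / bins_per_octave) - 1)


q_semitone = cqt_quality_factor(12)  # one bin per semitone
q_third = cqt_quality_factor(36)     # three bins per semitone
q_scaled = cqt_quality_factor(36, filter_scale=0.5)
```

A finer frequency grid gives a larger `Q`, and `filter_scale` rescales it linearly. Separately, the `Conv1d` above has `2 * n_bins` output channels, which `forward` views as `(batch, 2, n_bins, frames)` and permutes so the size-2 axis comes last.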
n_bins instance-attribute
n_bins = n_bins
conv instance-attribute
conv = Conv1d(
    1,
    2 * n_bins,
    kernel_size=fft_length,
    stride=hop_length,
    padding=padding,
    padding_mode="reflect",
    bias=False,
)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/utils/cqt.py
def forward(self, x: Tensor) -> Tensor:
    if x.ndim == 1:
        x = x.unsqueeze(0)
    if x.ndim == 2:
        x = x.unsqueeze(1)
    if x.ndim != 3:
        raise ValueError(f"expected shape (batch,time) or (batch,1,time), got {tuple(x.shape)}")

    cqt = self.conv(x).view(x.size(0), 2, self.n_bins, -1)
    cqt = cqt * self.sqrt_lengths.to(cqt.device)
    return cast(Tensor, cqt.permute(0, 2, 3, 1))
HarmonicRegularCQT
HarmonicRegularCQT(
    *,
    harmonics: Sequence[Gt0[int]],
    sr: SampleRate,
    hop_length: HopSize,
    fmin: Gt0[float],
    bins_per_semitone: Gt0[int],
    n_bins: Gt0[int],
    center_bins: bool,
    gamma: Ge0[float],
    center: bool,
    filter_scale: Gt0[float] = 1.0,
    window: WindowShape = "hann",
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
cqt_layers
Source code in src/splifft/models/utils/cqt.py
def __init__(
    self,
    *,
    harmonics: Sequence[t.Gt0[int]],
    sr: t.SampleRate,
    hop_length: t.HopSize,
    fmin: t.Gt0[float],
    bins_per_semitone: t.Gt0[int],
    n_bins: t.Gt0[int],
    center_bins: bool,
    gamma: t.Ge0[float],
    center: bool,
    filter_scale: t.Gt0[float] = 1.0,
    window: t.WindowShape = "hann",
):
    super().__init__()
    fmin = _resolve_centered_fmin(
        fmin,
        bins_per_semitone=bins_per_semitone,
        center_bins=center_bins,
    )

    self.cqt_layers = nn.ModuleList(
        [
            RegularCQT(
                sr=sr,
                hop_length=hop_length,
                fmin=h * fmin,
                n_bins=n_bins,
                bins_per_octave=12 * bins_per_semitone,
                gamma=gamma,
                center=center,
                filter_scale=filter_scale,
                window=window,
            )
            for h in harmonics
        ]
    )
cqt_layers instance-attribute
cqt_layers = ModuleList(
    [
        (
            RegularCQT(
                sr=sr,
                hop_length=hop_length,
                fmin=h * fmin,
                n_bins=n_bins,
                bins_per_octave=12 * bins_per_semitone,
                gamma=gamma,
                center=center,
                filter_scale=filter_scale,
                window=window,
            )
        )
        for h in harmonics
    ]
)
forward
forward(audio_waveforms: Tensor) -> Tensor
Source code in src/splifft/models/utils/cqt.py
def forward(self, audio_waveforms: Tensor) -> Tensor:
    return torch.stack([cqt(audio_waveforms) for cqt in self.cqt_layers], dim=1)
RecursiveCqtKernels

Bases: NamedTuple

Attributes:

Name Type Description
kernels ndarray
fft_length int
kernels instance-attribute
kernels: ndarray
fft_length instance-attribute
fft_length: int
RecursiveCQT
RecursiveCQT(
    *,
    sr: SampleRate,
    hop_length: HopSize,
    fmin: Gt0[float],
    n_bins: Gt0[int],
    bins_per_octave: Gt0[int],
    center: bool,
    filter_scale: Gt0[float] = 1.0,
    basis_norm: int = 1,
)

Bases: Module

Recursive octave-decimation CQT.

This follows the octave structure in [1], but like nnAudio it uses FIR anti-alias filtering and per-frame reflect padding instead of the paper's forward-reverse Butterworth decimation filter and global zero-padding scheme.
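
As a rough illustration of the octave bookkeeping, the sketch below repeats the constructor's arithmetic for locating the top octave. Only that octave gets its own kernel bank (`n_filters` kernels). The function name `top_octave_layout` and the `fmin` value are hypothetical.

```python
import math


def top_octave_layout(fmin: float, n_bins: int, bins_per_octave: int) -> dict[str, float]:
    n_filters = min(bins_per_octave, n_bins)
    n_octaves = math.ceil(n_bins / bins_per_octave)

    fmin_top = fmin * 2 ** (n_octaves - 1)
    remainder = n_bins % bins_per_octave
    if remainder == 0:
        fmax_top = fmin_top * 2 ** ((bins_per_octave - 1) / bins_per_octave)
    else:
        fmax_top = fmin_top * 2 ** ((remainder - 1) / bins_per_octave)
    # slide the kernel window down so it ends exactly at the highest bin
    fmin_top = fmax_top / 2 ** (1 - 1 / bins_per_octave)
    return {
        "n_filters": n_filters,
        "n_octaves": n_octaves,
        "fmin_top": fmin_top,
        "fmax_top": fmax_top,
    }


full = top_octave_layout(fmin=32.7, n_bins=84, bins_per_octave=12)     # 7 whole octaves
partial = top_octave_layout(fmin=32.7, n_bins=88, bins_per_octave=12)  # 7 octaves + 4 bins
```

When `n_bins` is not a multiple of `bins_per_octave`, the top window no longer starts on an octave boundary of `fmin`. It is shifted down so that its highest kernel coincides with the last requested bin.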

Methods:

Name Description
forward

Attributes:

Name Type Description
sample_rate
hop_length
fmin
n_bins
bins_per_octave
center
n_octaves
n_fft
Source code in src/splifft/models/utils/cqt.py
def __init__(
    self,
    *,
    sr: t.SampleRate,
    hop_length: t.HopSize,
    fmin: t.Gt0[float],
    n_bins: t.Gt0[int],
    bins_per_octave: t.Gt0[int],
    center: bool,
    filter_scale: t.Gt0[float] = 1.0,
    basis_norm: int = 1,
):
    super().__init__()

    self.sample_rate = sr
    self.hop_length = hop_length
    self.fmin = fmin
    self.n_bins = n_bins
    self.bins_per_octave = bins_per_octave
    self.center = center

    Q = float(filter_scale) / (2 ** (1 / bins_per_octave) - 1)
    n_filters = min(bins_per_octave, n_bins)
    self.n_octaves = int(np.ceil(float(n_bins) / bins_per_octave))

    fmin_top = fmin * 2 ** (self.n_octaves - 1)
    remainder = n_bins % bins_per_octave
    if remainder == 0:
        fmax_top = fmin_top * 2 ** ((bins_per_octave - 1) / bins_per_octave)
    else:
        fmax_top = fmin_top * 2 ** ((remainder - 1) / bins_per_octave)
    fmin_top = fmax_top / 2 ** (1 - 1 / bins_per_octave)

    if fmax_top > sr / 2:
        raise ValueError(f"top CQT bin {fmax_top} exceeds Nyquist for sample_rate={sr}")

    window_bandwidth = 1.5
    filter_cutoff = fmax_top * (1 + 0.5 * window_bandwidth / Q)
    if (
        _early_downsample_count(
            nyquist_hz=sr // 2,
            filter_cutoff_hz=filter_cutoff,
            hop_length=hop_length,
            n_octaves=self.n_octaves,
        )
        != 0
    ):
        raise NotImplementedError(
            "RecursiveCQT early downsampling is not implemented for this frontend"
        )

    kernel_bank = _create_recursive_cqt_kernels(
        Q=Q,
        fs=sr,
        fmin=fmin_top,
        n_bins=n_filters,
        bins_per_octave=bins_per_octave,
        norm=basis_norm,
    )
    self.n_fft = kernel_bank.fft_length

    freqs = fmin * 2.0 ** (np.r_[0:n_bins] / float(bins_per_octave))
    lengths = np.ceil(Q * sr / freqs)
    self.register_buffer(
        "frequencies", torch.tensor(freqs, dtype=torch.float32), persistent=False
    )

    self.register_buffer(
        "lengths", torch.tensor(lengths, dtype=torch.float32), persistent=False
    )
    self.register_buffer(
        "cqt_kernels_real",
        torch.tensor(np.real(kernel_bank.kernels), dtype=torch.float32).unsqueeze(1),
        persistent=False,
    )
    self.register_buffer(
        "cqt_kernels_imag",
        torch.tensor(np.imag(kernel_bank.kernels), dtype=torch.float32).unsqueeze(1),
        persistent=False,
    )
    self.register_buffer(
        "lowpass_filter",
        _create_recursive_lowpass_filter(
            band_center=0.5, kernel_length=256, transition_bandwidth=0.001
        ),
        persistent=False,
    )
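The scalar setup in the constructor above can be checked by hand. The following is a minimal sketch that mirrors only the `Q`, `n_octaves` and top-bin arithmetic. `recursive_cqt_plan` is a hypothetical helper name, not part of splifft, and the kernel construction and early-downsampling check are left out.

```python
import math


def recursive_cqt_plan(sr, fmin, n_bins, bins_per_octave, filter_scale=1.0):
    """Hypothetical helper: reproduce the scalar arithmetic of the constructor."""
    # Quality factor shared by every bin.
    q = float(filter_scale) / (2 ** (1 / bins_per_octave) - 1)
    # Number of octave passes needed to cover n_bins.
    n_octaves = math.ceil(n_bins / bins_per_octave)
    # Lowest frequency of the top octave before the remainder adjustment.
    fmin_top = fmin * 2 ** (n_octaves - 1)
    remainder = n_bins % bins_per_octave
    # Index of the last bin inside the top octave.
    last_bin = (bins_per_octave if remainder == 0 else remainder) - 1
    fmax_top = fmin_top * 2 ** (last_bin / bins_per_octave)
    if fmax_top > sr / 2:
        raise ValueError(f"top CQT bin {fmax_top} exceeds Nyquist for sample_rate={sr}")
    return q, n_octaves, fmax_top


q, n_octaves, fmax_top = recursive_cqt_plan(
    sr=22050, fmin=32.7, n_bins=84, bins_per_octave=12
)
```

With these example values (chosen for illustration only) the plan has 7 octaves and a top bin near 3951 Hz, which is below the 11025 Hz Nyquist limit, so no error is raised.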
sample_rate instance-attribute
sample_rate = sr
hop_length instance-attribute
hop_length = hop_length
fmin instance-attribute
fmin = fmin
n_bins instance-attribute
n_bins = n_bins
bins_per_octave instance-attribute
bins_per_octave = bins_per_octave
center instance-attribute
center = center
n_octaves instance-attribute
n_octaves = int(ceil(float(n_bins) / bins_per_octave))
n_fft instance-attribute
n_fft = fft_length
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/utils/cqt.py
def forward(self, x: Tensor) -> Tensor:
    if x.ndim == 1:
        x = x.unsqueeze(0)
    if x.ndim == 2:
        x = x.unsqueeze(1)
    if x.ndim != 3:
        raise ValueError(f"expected shape (batch,time) or (batch,1,time), got {tuple(x.shape)}")

    cqt_kernels_real = cast(Tensor, self.cqt_kernels_real)
    cqt_kernels_imag = cast(Tensor, self.cqt_kernels_imag)
    lowpass_filter = cast(Tensor, self.lowpass_filter)
    lengths = cast(Tensor, self.lengths)

    padding = self.n_fft // 2 if self.center else 0
    hop = self.hop_length
    cqt = _get_recursive_cqt_complex(
        x,
        cqt_kernels_real,
        cqt_kernels_imag,
        hop_length=hop,
        padding=padding,
    )

    x_down = x
    for _ in range(self.n_octaves - 1):
        hop //= 2
        x_down = _downsampling_by_n(x_down, lowpass_filter, 2)
        next_octave = _get_recursive_cqt_complex(
            x_down,
            cqt_kernels_real,
            cqt_kernels_imag,
            hop_length=hop,
            padding=padding,
        )
        cqt = torch.cat((next_octave, cqt), dim=1)

    cqt = cqt[:, -self.n_bins :, :]
    cqt = cqt * torch.sqrt(lengths.view(1, -1, 1, 1).to(cqt.device))
    return cqt
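The loop in `forward` above halves the hop with integer division once per additional octave. A small sketch of that schedule follows. Both function names are hypothetical, and the divisibility check is only an arithmetic observation about floor division, not something this page states the library enforces.

```python
def octave_hops(hop_length, n_octaves):
    """Hop used for each octave pass, top octave first (hypothetical helper)."""
    hops = [hop_length]
    for _ in range(n_octaves - 1):
        hops.append(hops[-1] // 2)  # same `hop //= 2` step as in forward
    return hops


def halving_is_exact(hop_length, n_octaves):
    """True when no `//= 2` step rounds down."""
    return hop_length % (2 ** (n_octaves - 1)) == 0


hops_256 = octave_hops(256, 4)
hops_100 = octave_hops(100, 4)
```

A hop of 256 over 4 octaves halves cleanly down to 32, while a hop of 100 reaches 25 and then rounds down to 12 on the last pass.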
HarmonicRecursiveCQT
HarmonicRecursiveCQT(
    *,
    harmonics: Sequence[Gt0[int]],
    sr: SampleRate,
    hop_length: HopSize,
    fmin: Gt0[float],
    bins_per_semitone: Gt0[int],
    n_bins: Gt0[int],
    center_bins: bool,
    center: bool,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
cqt_layers
Source code in src/splifft/models/utils/cqt.py
def __init__(
    self,
    *,
    harmonics: Sequence[t.Gt0[int]],
    sr: t.SampleRate,
    hop_length: t.HopSize,
    fmin: t.Gt0[float],
    bins_per_semitone: t.Gt0[int],
    n_bins: t.Gt0[int],
    center_bins: bool,
    center: bool,
):
    super().__init__()
    fmin = _resolve_centered_fmin(
        fmin,
        bins_per_semitone=bins_per_semitone,
        center_bins=center_bins,
    )

    self.cqt_layers = nn.ModuleList(
        [
            RecursiveCQT(
                sr=sr,
                hop_length=hop_length,
                fmin=h * fmin,
                n_bins=n_bins,
                bins_per_octave=12 * bins_per_semitone,
                center=center,
            )
            for h in harmonics
        ]
    )
cqt_layers instance-attribute
cqt_layers = ModuleList(
    [
        (
            RecursiveCQT(
                sr=sr,
                hop_length=hop_length,
                fmin=h * fmin,
                n_bins=n_bins,
                bins_per_octave=12 * bins_per_semitone,
                center=center,
            )
        )
        for h in harmonics
    ]
)
forward
forward(audio_waveforms: Tensor) -> Tensor
Source code in src/splifft/models/utils/cqt.py
def forward(self, audio_waveforms: Tensor) -> Tensor:
    return torch.stack([cqt(audio_waveforms) for cqt in self.cqt_layers], dim=1)

hyperace

HyperACE segmentation backbones for BS-RoFormer mask heads.

These modules are compatibility shims for unwa variants trained in msst (hyperace_v1, hyperace_v2, and large_inst_v2 head behavior). They are kept separate from the core transformer stack because they are used only by a small subset of checkpoints.

See: https://huggingface.co/pcunwa/BS-Roformer-HyperACE and https://arxiv.org/abs/2506.17733

Classes:

Name Description
Conv
DSConv
DS_Bottleneck
DS_C3k
DS_C3k2
AdaptiveHyperedgeGeneration
HypergraphConvolution
AdaptiveHypergraphComputation
C3AH
HyperACE
GatedFusion
BackboneHyperAceV1
BackboneHyperAceV2
DecoderHyperAce
FreqPixelShuffleV1
ProgressiveUpsampleHeadV1
FreqPixelShuffleV2
ProgressiveUpsampleHeadV2
SegmModelHyperAceV1
SegmModelHyperAceV2

Functions:

Name Description
autopad
build_hyperace_tfc_tdf
autopad
autopad(
    k: int | tuple[int, int],
    p: int | tuple[int, int] | None = None,
) -> int | tuple[int, int]
Source code in src/splifft/models/utils/hyperace.py
def autopad(
    k: int | tuple[int, int],
    p: int | tuple[int, int] | None = None,
) -> int | tuple[int, int]:
    if p is None:
        p = k // 2 if isinstance(k, int) else (k[0] // 2, k[1] // 2)
    return p
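To see what the default padding from `autopad` does, the sketch below repeats its logic (type hints omitted) and pairs it with the standard output-size formula for a dilation-1 convolution. `conv_out_size` is a hypothetical helper added for illustration.

```python
def autopad(k, p=None):
    # Same logic as the function above: default to half the kernel size.
    if p is None:
        p = k // 2 if isinstance(k, int) else (k[0] // 2, k[1] // 2)
    return p


def conv_out_size(n, k, s, p):
    """Output length of a dilation-1 convolution along one axis."""
    return (n + 2 * p - k) // s + 1


same_for_odd = conv_out_size(10, 3, 1, autopad(3))    # odd kernel, stride 1
grows_for_even = conv_out_size(10, 4, 1, autopad(4))  # even kernel, stride 1
```

For odd kernels at stride 1 the input length is preserved. An even kernel gets one sample of extra length, and an explicit `p` always overrides the default.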
build_hyperace_tfc_tdf
build_hyperace_tfc_tdf(
    in_c: int, c: int, l: int, f: int, bn: int = 4
) -> TfcTdf
Source code in src/splifft/models/utils/hyperace.py
def build_hyperace_tfc_tdf(
    in_c: int,
    c: int,
    l: int,
    f: int,
    bn: int = 4,
) -> TfcTdf:
    return TfcTdf(
        in_channels=in_c,
        out_channels=c,
        num_blocks=l,
        f_bins=f,
        bottleneck_factor=bn,
        norm_factory=instance_norm_factory,
        act_factory=silu_factory,
    )
Conv
Conv(
    c1: int,
    c2: int,
    k: int | tuple[int, int] = 1,
    s: int | tuple[int, int] = 1,
    p: int | tuple[int, int] | None = None,
    g: int = 1,
    act: bool = True,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
conv
bn
act
Source code in src/splifft/models/utils/hyperace.py
def __init__(
    self,
    c1: int,
    c2: int,
    k: int | tuple[int, int] = 1,
    s: int | tuple[int, int] = 1,
    p: int | tuple[int, int] | None = None,
    g: int = 1,
    act: bool = True,
):
    super().__init__()
    self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
    self.bn = nn.InstanceNorm2d(c2, affine=True, eps=1e-8)
    self.act = nn.SiLU() if act else nn.Identity()
conv instance-attribute
conv = Conv2d(
    c1, c2, k, s, autopad(k, p), groups=g, bias=False
)
bn instance-attribute
bn = InstanceNorm2d(c2, affine=True, eps=1e-08)
act instance-attribute
act = SiLU() if act else Identity()
forward
forward(x: Tensor) -> Any
Source code in src/splifft/models/utils/hyperace.py
def forward(self, x: Tensor) -> Any:
    return self.act(self.bn(self.conv(x)))
DSConv
DSConv(
    c1: int,
    c2: int,
    k: int | tuple[int, int] = 3,
    s: int | tuple[int, int] = 1,
    p: int | tuple[int, int] | None = None,
    act: bool = True,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
dwconv
pwconv
bn
act
Source code in src/splifft/models/utils/hyperace.py
def __init__(
    self,
    c1: int,
    c2: int,
    k: int | tuple[int, int] = 3,
    s: int | tuple[int, int] = 1,
    p: int | tuple[int, int] | None = None,
    act: bool = True,
):
    super().__init__()
    self.dwconv = nn.Conv2d(c1, c1, k, s, autopad(k, p), groups=c1, bias=False)
    self.pwconv = nn.Conv2d(c1, c2, 1, 1, 0, bias=False)
    self.bn = nn.InstanceNorm2d(c2, affine=True, eps=1e-8)
    self.act = nn.SiLU() if act else nn.Identity()
dwconv instance-attribute
dwconv = Conv2d(
    c1, c1, k, s, autopad(k, p), groups=c1, bias=False
)
pwconv instance-attribute
pwconv = Conv2d(c1, c2, 1, 1, 0, bias=False)
bn instance-attribute
bn = InstanceNorm2d(c2, affine=True, eps=1e-08)
act instance-attribute
act = SiLU() if act else Identity()
forward
forward(x: Tensor) -> Any
Source code in src/splifft/models/utils/hyperace.py
def forward(self, x: Tensor) -> Any:
    return self.act(self.bn(self.pwconv(self.dwconv(x))))
DS_Bottleneck
DS_Bottleneck(
    c1: int, c2: int, k: int = 3, shortcut: bool = True
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
dsconv1
dsconv2
shortcut
Source code in src/splifft/models/utils/hyperace.py
def __init__(self, c1: int, c2: int, k: int = 3, shortcut: bool = True):
    super().__init__()
    c_ = c1
    self.dsconv1 = DSConv(c1, c_, k=3, s=1)
    self.dsconv2 = DSConv(c_, c2, k=k, s=1)
    self.shortcut = shortcut and c1 == c2
dsconv1 instance-attribute
dsconv1 = DSConv(c1, c_, k=3, s=1)
dsconv2 instance-attribute
dsconv2 = DSConv(c_, c2, k=k, s=1)
shortcut instance-attribute
shortcut = shortcut and c1 == c2
forward
forward(x: Tensor) -> Any
Source code in src/splifft/models/utils/hyperace.py
def forward(self, x: Tensor) -> Any:
    if self.shortcut:
        return x + self.dsconv2(self.dsconv1(x))
    return self.dsconv2(self.dsconv1(x))
DS_C3k
DS_C3k(
    c1: int, c2: int, n: int = 1, k: int = 3, e: float = 0.5
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
cv1
cv2
cv3
m
Source code in src/splifft/models/utils/hyperace.py
def __init__(self, c1: int, c2: int, n: int = 1, k: int = 3, e: float = 0.5):
    super().__init__()
    c_ = int(c2 * e)
    self.cv1 = Conv(c1, c_, 1, 1)
    self.cv2 = Conv(c1, c_, 1, 1)
    self.cv3 = Conv(2 * c_, c2, 1, 1)
    self.m = nn.Sequential(*[DS_Bottleneck(c_, c_, k=k, shortcut=True) for _ in range(n)])
cv1 instance-attribute
cv1 = Conv(c1, c_, 1, 1)
cv2 instance-attribute
cv2 = Conv(c1, c_, 1, 1)
cv3 instance-attribute
cv3 = Conv(2 * c_, c2, 1, 1)
m instance-attribute
m = Sequential(
    *[
        (DS_Bottleneck(c_, c_, k=k, shortcut=True))
        for _ in (range(n))
    ]
)
forward
forward(x: Tensor) -> Any
Source code in src/splifft/models/utils/hyperace.py
def forward(self, x: Tensor) -> Any:
    return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
DS_C3k2
DS_C3k2(
    c1: int, c2: int, n: int = 1, k: int = 3, e: float = 0.5
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
cv1
m
cv2
Source code in src/splifft/models/utils/hyperace.py
def __init__(self, c1: int, c2: int, n: int = 1, k: int = 3, e: float = 0.5):
    super().__init__()
    c_ = int(c2 * e)
    self.cv1 = Conv(c1, c_, 1, 1)
    self.m = DS_C3k(c_, c_, n=n, k=k, e=1.0)
    self.cv2 = Conv(c_, c2, 1, 1)
cv1 instance-attribute
cv1 = Conv(c1, c_, 1, 1)
m instance-attribute
m = DS_C3k(c_, c_, n=n, k=k, e=1.0)
cv2 instance-attribute
cv2 = Conv(c_, c2, 1, 1)
forward
forward(x: Tensor) -> Any
Source code in src/splifft/models/utils/hyperace.py
def forward(self, x: Tensor) -> Any:
    x_ = self.cv1(x)
    x_ = self.m(x_)
    return self.cv2(x_)
AdaptiveHyperedgeGeneration
AdaptiveHyperedgeGeneration(
    in_channels: int,
    num_hyperedges: int,
    num_heads: int = 8,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
num_hyperedges
num_heads
head_dim
global_proto
context_mapper
query_proj
scale
Source code in src/splifft/models/utils/hyperace.py
def __init__(self, in_channels: int, num_hyperedges: int, num_heads: int = 8):
    super().__init__()
    self.num_hyperedges = num_hyperedges
    self.num_heads = num_heads
    self.head_dim = in_channels // num_heads

    self.global_proto = nn.Parameter(torch.randn(num_hyperedges, in_channels))

    self.context_mapper = nn.Linear(2 * in_channels, num_hyperedges * in_channels, bias=False)

    self.query_proj = nn.Linear(in_channels, in_channels, bias=False)

    self.scale = self.head_dim**-0.5
num_hyperedges instance-attribute
num_hyperedges = num_hyperedges
num_heads instance-attribute
num_heads = num_heads
head_dim instance-attribute
head_dim = in_channels // num_heads
global_proto instance-attribute
global_proto = Parameter(randn(num_hyperedges, in_channels))
context_mapper instance-attribute
context_mapper = Linear(
    2 * in_channels,
    num_hyperedges * in_channels,
    bias=False,
)
query_proj instance-attribute
query_proj = Linear(in_channels, in_channels, bias=False)
scale instance-attribute
scale = head_dim ** -0.5
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/utils/hyperace.py
def forward(self, x: Tensor) -> Tensor:
    b, n, c = x.shape

    f_avg = F.adaptive_avg_pool1d(x.permute(0, 2, 1), 1).squeeze(-1)
    f_max = F.adaptive_max_pool1d(x.permute(0, 2, 1), 1).squeeze(-1)
    f_ctx = torch.cat((f_avg, f_max), dim=1)

    delta_p = self.context_mapper(f_ctx).view(b, self.num_hyperedges, c)
    p = self.global_proto.unsqueeze(0) + delta_p

    z = self.query_proj(x)

    z = z.view(b, n, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

    p = p.view(b, self.num_hyperedges, self.num_heads, self.head_dim).permute(0, 2, 3, 1)

    sim = (z @ p) * self.scale

    s_bar = sim.mean(dim=1)

    a = F.softmax(s_bar.permute(0, 2, 1), dim=-1)

    return a
HypergraphConvolution
HypergraphConvolution(in_channels: int, out_channels: int)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
W_e
W_v
act
Source code in src/splifft/models/utils/hyperace.py
def __init__(self, in_channels: int, out_channels: int):
    super().__init__()
    self.W_e = nn.Linear(in_channels, in_channels, bias=False)
    self.W_v = nn.Linear(in_channels, out_channels, bias=False)
    self.act = nn.SiLU()
W_e instance-attribute
W_e = Linear(in_channels, in_channels, bias=False)
W_v instance-attribute
W_v = Linear(in_channels, out_channels, bias=False)
act instance-attribute
act = SiLU()
forward
forward(x: Tensor, a: Tensor) -> Any
Source code in src/splifft/models/utils/hyperace.py
def forward(self, x: Tensor, a: Tensor) -> Any:
    f_m = torch.bmm(a, x)
    f_m = self.act(self.W_e(f_m))

    x_out = torch.bmm(a.transpose(1, 2), f_m)
    x_out = self.act(self.W_v(x_out))

    return x + x_out
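The two `bmm` calls above gather vertex features into hyperedges and then scatter them back to vertices. A simplified, dependency-free sketch of that data flow for a single batch item follows. It assumes identity weights in place of `W_e` and `W_v` and drops the SiLU activations, so it illustrates the aggregation pattern only, not the module's numerical output. All helper names are hypothetical.

```python
import math


def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    return [list(col) for col in zip(*m)]


def hypergraph_step(x, scores):
    """x: n vertices x c channels. scores: E hyperedges x n vertices (unnormalised)."""
    # As in AdaptiveHyperedgeGeneration.forward, normalise over the vertex axis,
    # so each hyperedge spreads a total weight of 1 across the vertices.
    a = [softmax(row) for row in scores]
    edge_feats = matmul(a, x)                 # E x c: vertex -> hyperedge
    back = matmul(transpose(a), edge_feats)   # n x c: hyperedge -> vertex
    # Residual connection, as in `return x + x_out`.
    out = [[xi + bi for xi, bi in zip(xr, br)] for xr, br in zip(x, back)]
    return a, out


x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]      # 3 vertices, 2 channels
scores = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]   # 2 hyperedges, uniform scores
a, out = hypergraph_step(x, scores)
```

Because the result is added back to `x`, the module's `out_channels` has to equal `in_channels` for the residual sum to be well defined.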
AdaptiveHypergraphComputation
AdaptiveHypergraphComputation(
    in_channels: int,
    out_channels: int,
    num_hyperedges: int = 8,
    num_heads: int = 8,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
adaptive_hyperedge_gen
hypergraph_conv
Source code in src/splifft/models/utils/hyperace.py
def __init__(
    self,
    in_channels: int,
    out_channels: int,
    num_hyperedges: int = 8,
    num_heads: int = 8,
):
    super().__init__()
    self.adaptive_hyperedge_gen = AdaptiveHyperedgeGeneration(
        in_channels, num_hyperedges, num_heads
    )
    self.hypergraph_conv = HypergraphConvolution(in_channels, out_channels)
adaptive_hyperedge_gen instance-attribute
adaptive_hyperedge_gen = AdaptiveHyperedgeGeneration(
    in_channels, num_hyperedges, num_heads
)
hypergraph_conv instance-attribute
hypergraph_conv = HypergraphConvolution(
    in_channels, out_channels
)
forward
forward(x: Tensor) -> Any
Source code in src/splifft/models/utils/hyperace.py
def forward(self, x: Tensor) -> Any:
    b, _, h, w = x.shape
    x_flat = x.flatten(2).permute(0, 2, 1)

    a = self.adaptive_hyperedge_gen(x_flat)

    x_out_flat = self.hypergraph_conv(x_flat, a)

    x_out = x_out_flat.permute(0, 2, 1).view(b, -1, h, w)
    return x_out
C3AH
C3AH(
    c1: int,
    c2: int,
    num_hyperedges: int = 8,
    num_heads: int = 8,
    e: float = 0.5,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
cv1
cv2
ahc
cv3
Source code in src/splifft/models/utils/hyperace.py
def __init__(
    self,
    c1: int,
    c2: int,
    num_hyperedges: int = 8,
    num_heads: int = 8,
    e: float = 0.5,
):
    super().__init__()
    c_ = int(c1 * e)
    self.cv1 = Conv(c1, c_, 1, 1)
    self.cv2 = Conv(c1, c_, 1, 1)
    self.ahc = AdaptiveHypergraphComputation(c_, c_, num_hyperedges, num_heads)
    self.cv3 = Conv(2 * c_, c2, 1, 1)
cv1 instance-attribute
cv1 = Conv(c1, c_, 1, 1)
cv2 instance-attribute
cv2 = Conv(c1, c_, 1, 1)
ahc instance-attribute
ahc = AdaptiveHypergraphComputation(
    c_, c_, num_hyperedges, num_heads
)
cv3 instance-attribute
cv3 = Conv(2 * c_, c2, 1, 1)
forward
forward(x: Tensor) -> Any
Source code in src/splifft/models/utils/hyperace.py
def forward(self, x: Tensor) -> Any:
    x_lateral = self.cv1(x)
    x_ahc = self.ahc(self.cv2(x))
    return self.cv3(torch.cat((x_ahc, x_lateral), dim=1))
HyperACE
HyperACE(
    in_channels: list[int],
    out_channels: int,
    num_hyperedges: int = 8,
    num_heads: int = 8,
    k: int = 2,
    l: int = 1,
    c_h: float = 0.5,
    c_l: float = 0.25,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
fuse_conv
c_h
c_l
c_s
high_order_branch
high_order_fuse
low_order_branch
final_fuse
Source code in src/splifft/models/utils/hyperace.py
def __init__(
    self,
    in_channels: list[int],
    out_channels: int,
    num_hyperedges: int = 8,
    num_heads: int = 8,
    k: int = 2,
    l: int = 1,
    c_h: float = 0.5,
    c_l: float = 0.25,
):
    super().__init__()

    c2, c3, c4, c5 = in_channels
    c_mid = c4

    self.fuse_conv = Conv(c2 + c3 + c4 + c5, c_mid, 1, 1)

    self.c_h = int(c_mid * c_h)
    self.c_l = int(c_mid * c_l)
    self.c_s = c_mid - self.c_h - self.c_l
    assert self.c_s > 0, "Channel split error"

    self.high_order_branch = nn.ModuleList(
        [C3AH(self.c_h, self.c_h, num_hyperedges, num_heads, e=1.0) for _ in range(k)]
    )
    self.high_order_fuse = Conv(self.c_h * k, self.c_h, 1, 1)

    self.low_order_branch = nn.Sequential(
        *[DS_C3k(self.c_l, self.c_l, n=1, k=3, e=1.0) for _ in range(l)]
    )

    self.final_fuse = Conv(self.c_h + self.c_l + self.c_s, out_channels, 1, 1)
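The constructor above splits the fused `c_mid` channels into a high-order branch, a low-order branch, and an untouched shortcut remainder. A minimal sketch of that arithmetic follows. `hyperace_channel_split` is a hypothetical helper, and it raises `ValueError` where the original uses an `assert`. The value 512 in the example assumes `in_channels` is the backbone's `out_channels` list `[256, 384, 512, 768]`, which this page does not confirm.

```python
def hyperace_channel_split(c_mid, c_h=0.5, c_l=0.25):
    """Return (high, low, shortcut) channel counts for a given c_mid."""
    high = int(c_mid * c_h)
    low = int(c_mid * c_l)
    shortcut = c_mid - high - low  # whatever is left bypasses both branches
    if shortcut <= 0:
        raise ValueError("Channel split error")
    return high, low, shortcut


split_512 = hyperace_channel_split(512)
split_10 = hyperace_channel_split(10)
```

Since `int()` truncates, the shortcut branch absorbs any rounding loss, and the three parts always sum to `c_mid`.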
fuse_conv instance-attribute
fuse_conv = Conv(c2 + c3 + c4 + c5, c_mid, 1, 1)
c_h instance-attribute
c_h = int(c_mid * c_h)
c_l instance-attribute
c_l = int(c_mid * c_l)
c_s instance-attribute
c_s = c_mid - c_h - c_l
high_order_branch instance-attribute
high_order_branch = ModuleList(
    [
        (C3AH(c_h, c_h, num_hyperedges, num_heads, e=1.0))
        for _ in (range(k))
    ]
)
high_order_fuse instance-attribute
high_order_fuse = Conv(c_h * k, c_h, 1, 1)
low_order_branch instance-attribute
low_order_branch = Sequential(
    *[
        (DS_C3k(c_l, c_l, n=1, k=3, e=1.0))
        for _ in (range(l))
    ]
)
final_fuse instance-attribute
final_fuse = Conv(c_h + c_l + c_s, out_channels, 1, 1)
forward
forward(x: list[Tensor]) -> Any
Source code in src/splifft/models/utils/hyperace.py
def forward(self, x: list[Tensor]) -> Any:
    b2, b3, b4, b5 = x

    _, _, h4, w4 = b4.shape

    b2_resized = F.interpolate(b2, size=(h4, w4), mode="bilinear", align_corners=False)
    b3_resized = F.interpolate(b3, size=(h4, w4), mode="bilinear", align_corners=False)
    b5_resized = F.interpolate(b5, size=(h4, w4), mode="bilinear", align_corners=False)

    x_b = self.fuse_conv(torch.cat((b2_resized, b3_resized, b4, b5_resized), dim=1))

    x_h, x_l, x_s = torch.split(x_b, [self.c_h, self.c_l, self.c_s], dim=1)

    x_h_outs = [m(x_h) for m in self.high_order_branch]
    x_h_fused = self.high_order_fuse(torch.cat(x_h_outs, dim=1))

    x_l_out = self.low_order_branch(x_l)

    y = self.final_fuse(torch.cat((x_h_fused, x_l_out, x_s), dim=1))

    return y
GatedFusion
GatedFusion(in_channels: int)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
gamma
Source code in src/splifft/models/utils/hyperace.py
def __init__(self, in_channels: int):
    super().__init__()
    self.gamma = nn.Parameter(torch.zeros(1, in_channels, 1, 1))
gamma instance-attribute
gamma = Parameter(zeros(1, in_channels, 1, 1))
forward
forward(f_in: Tensor, h: Tensor) -> Any
Source code in src/splifft/models/utils/hyperace.py
def forward(self, f_in: Tensor, h: Tensor) -> Any:
    if f_in.shape[1] != h.shape[1]:
        raise ValueError(f"Channel mismatch: f_in={f_in.shape}, h={h.shape}")
    return f_in + self.gamma * h
BackboneHyperAceV1
BackboneHyperAceV1(
    in_channels: int = 256,
    base_channels: int = 64,
    base_depth: int = 3,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
stem
p2
p3
p4
p5
out_channels
Source code in src/splifft/models/utils/hyperace.py
def __init__(self, in_channels: int = 256, base_channels: int = 64, base_depth: int = 3):
    super().__init__()
    c2 = base_channels
    c3 = 256
    c4 = 384
    c5 = 512
    c6 = 768

    self.stem = DSConv(in_channels, c2, k=3, s=(2, 1), p=1)

    self.p2 = nn.Sequential(
        DSConv(c2, c3, k=3, s=(2, 1), p=1),
        DS_C3k2(c3, c3, n=base_depth),
    )

    self.p3 = nn.Sequential(
        DSConv(c3, c4, k=3, s=(2, 1), p=1),
        DS_C3k2(c4, c4, n=base_depth * 2),
    )

    self.p4 = nn.Sequential(
        DSConv(c4, c5, k=3, s=(2, 1), p=1),
        DS_C3k2(c5, c5, n=base_depth * 2),
    )

    self.p5 = nn.Sequential(
        DSConv(c5, c6, k=3, s=(2, 1), p=1),
        DS_C3k2(c6, c6, n=base_depth),
    )

    self.out_channels = [c3, c4, c5, c6]
stem instance-attribute
stem = DSConv(in_channels, c2, k=3, s=(2, 1), p=1)
p2 instance-attribute
p2 = Sequential(
    DSConv(c2, c3, k=3, s=(2, 1), p=1),
    DS_C3k2(c3, c3, n=base_depth),
)
p3 instance-attribute
p3 = Sequential(
    DSConv(c3, c4, k=3, s=(2, 1), p=1),
    DS_C3k2(c4, c4, n=base_depth * 2),
)
p4 instance-attribute
p4 = Sequential(
    DSConv(c4, c5, k=3, s=(2, 1), p=1),
    DS_C3k2(c5, c5, n=base_depth * 2),
)
p5 instance-attribute
p5 = Sequential(
    DSConv(c5, c6, k=3, s=(2, 1), p=1),
    DS_C3k2(c6, c6, n=base_depth),
)
out_channels instance-attribute
out_channels = [c3, c4, c5, c6]
forward
forward(x: Tensor) -> list[Tensor]
Source code in src/splifft/models/utils/hyperace.py
def forward(self, x: Tensor) -> list[Tensor]:
    x = self.stem(x)
    x2 = self.p2(x)
    x3 = self.p3(x2)
    x4 = self.p4(x3)
    x5 = self.p5(x4)
    return [x2, x3, x4, x5]
BackboneHyperAceV2
BackboneHyperAceV2(
    in_channels: int = 256,
    base_channels: int = 64,
    base_depth: int = 3,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
stem
p2
p3
p4
p5
out_channels
Source code in src/splifft/models/utils/hyperace.py
def __init__(self, in_channels: int = 256, base_channels: int = 64, base_depth: int = 3):
    super().__init__()
    c2 = base_channels
    c3 = 256
    c4 = 384
    c5 = 512
    c6 = 768

    self.stem = DSConv(in_channels, c2, k=3, s=(2, 1), p=1)

    self.p2 = nn.Sequential(
        DSConv(c2, c3, k=3, s=(2, 1), p=1),
        DS_C3k2(c3, c3, n=base_depth),
    )

    self.p3 = nn.Sequential(
        DSConv(c3, c4, k=3, s=(2, 1), p=1),
        DS_C3k2(c4, c4, n=base_depth * 2),
    )

    self.p4 = nn.Sequential(
        DSConv(c4, c5, k=3, s=2, p=1),
        DS_C3k2(c5, c5, n=base_depth * 2),
    )

    self.p5 = nn.Sequential(
        DSConv(c5, c6, k=3, s=2, p=1),
        DS_C3k2(c6, c6, n=base_depth),
    )

    self.out_channels = [c3, c4, c5, c6]
stem instance-attribute
stem = DSConv(in_channels, c2, k=3, s=(2, 1), p=1)
p2 instance-attribute
p2 = Sequential(
    DSConv(c2, c3, k=3, s=(2, 1), p=1),
    DS_C3k2(c3, c3, n=base_depth),
)
p3 instance-attribute
p3 = Sequential(
    DSConv(c3, c4, k=3, s=(2, 1), p=1),
    DS_C3k2(c4, c4, n=base_depth * 2),
)
p4 instance-attribute
p4 = Sequential(
    DSConv(c4, c5, k=3, s=2, p=1),
    DS_C3k2(c5, c5, n=base_depth * 2),
)
p5 instance-attribute
p5 = Sequential(
    DSConv(c5, c6, k=3, s=2, p=1),
    DS_C3k2(c6, c6, n=base_depth),
)
out_channels instance-attribute
out_channels = [c3, c4, c5, c6]
forward
forward(x: Tensor) -> list[Tensor]
Source code in src/splifft/models/utils/hyperace.py
def forward(self, x: Tensor) -> list[Tensor]:
    x = self.stem(x)
    x2 = self.p2(x)
    x3 = self.p3(x2)
    x4 = self.p4(x3)
    x5 = self.p5(x4)
    return [x2, x3, x4, x5]
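The two backbones differ only in the strides of `p4` and `p5`: V1 uses `(2, 1)` throughout, while V2 switches to `2` (both axes) for the last two stages. The sketch below traces the resulting feature-map sizes with the standard convolution output formula for `k=3, p=1`. It relies on the `DS_C3k2` stages preserving spatial size, which follows from their stride-1 convolutions with `autopad` padding. The helper names and the 64 x 100 input are illustrative assumptions.

```python
def conv_out(n, k=3, s=1, p=1):
    """Output length of a dilation-1 convolution along one axis."""
    return (n + 2 * p - k) // s + 1


# (stride_h, stride_w) for stem, p2, p3, p4, p5.
V1_STRIDES = [(2, 1)] * 5
V2_STRIDES = [(2, 1)] * 3 + [(2, 2)] * 2


def pyramid_sizes(h, w, strides):
    """Spatial sizes of [x2, x3, x4, x5] as returned by forward."""
    sizes = []
    for i, (sh, sw) in enumerate(strides):
        h, w = conv_out(h, s=sh), conv_out(w, s=sw)
        if i > 0:  # the stem output is not part of the returned list
            sizes.append((h, w))
    return sizes


v1 = pyramid_sizes(64, 100, V1_STRIDES)
v2 = pyramid_sizes(64, 100, V2_STRIDES)
```

Under these assumptions V1 keeps the second spatial axis at full resolution in every level, whereas V2 also halves it twice in the deepest two levels.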
DecoderHyperAce
DecoderHyperAce(
    encoder_channels: list[int],
    hyperace_out_c: int,
    decoder_channels: list[int],
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
h_to_d5
h_to_d4
h_to_d3
h_to_d2
fusion_d5
fusion_d4
fusion_d3
fusion_d2
skip_p5
skip_p4
skip_p3
skip_p2
up_d5
up_d4
up_d3
final_d2
Source code in src/splifft/models/utils/hyperace.py
def __init__(
    self, encoder_channels: list[int], hyperace_out_c: int, decoder_channels: list[int]
):
    super().__init__()
    c_p2, c_p3, c_p4, c_p5 = encoder_channels
    c_d2, c_d3, c_d4, c_d5 = decoder_channels

    self.h_to_d5 = Conv(hyperace_out_c, c_d5, 1, 1)
    self.h_to_d4 = Conv(hyperace_out_c, c_d4, 1, 1)
    self.h_to_d3 = Conv(hyperace_out_c, c_d3, 1, 1)
    self.h_to_d2 = Conv(hyperace_out_c, c_d2, 1, 1)

    self.fusion_d5 = GatedFusion(c_d5)
    self.fusion_d4 = GatedFusion(c_d4)
    self.fusion_d3 = GatedFusion(c_d3)
    self.fusion_d2 = GatedFusion(c_d2)

    self.skip_p5 = Conv(c_p5, c_d5, 1, 1)
    self.skip_p4 = Conv(c_p4, c_d4, 1, 1)
    self.skip_p3 = Conv(c_p3, c_d3, 1, 1)
    self.skip_p2 = Conv(c_p2, c_d2, 1, 1)

    self.up_d5 = DS_C3k2(c_d5, c_d4, n=1)
    self.up_d4 = DS_C3k2(c_d4, c_d3, n=1)
    self.up_d3 = DS_C3k2(c_d3, c_d2, n=1)

    self.final_d2 = DS_C3k2(c_d2, c_d2, n=1)
h_to_d5 instance-attribute
h_to_d5 = Conv(hyperace_out_c, c_d5, 1, 1)
h_to_d4 instance-attribute
h_to_d4 = Conv(hyperace_out_c, c_d4, 1, 1)
h_to_d3 instance-attribute
h_to_d3 = Conv(hyperace_out_c, c_d3, 1, 1)
h_to_d2 instance-attribute
h_to_d2 = Conv(hyperace_out_c, c_d2, 1, 1)
fusion_d5 instance-attribute
fusion_d5 = GatedFusion(c_d5)
fusion_d4 instance-attribute
fusion_d4 = GatedFusion(c_d4)
fusion_d3 instance-attribute
fusion_d3 = GatedFusion(c_d3)
fusion_d2 instance-attribute
fusion_d2 = GatedFusion(c_d2)
skip_p5 instance-attribute
skip_p5 = Conv(c_p5, c_d5, 1, 1)
skip_p4 instance-attribute
skip_p4 = Conv(c_p4, c_d4, 1, 1)
skip_p3 instance-attribute
skip_p3 = Conv(c_p3, c_d3, 1, 1)
skip_p2 instance-attribute
skip_p2 = Conv(c_p2, c_d2, 1, 1)
up_d5 instance-attribute
up_d5 = DS_C3k2(c_d5, c_d4, n=1)
up_d4 instance-attribute
up_d4 = DS_C3k2(c_d4, c_d3, n=1)
up_d3 instance-attribute
up_d3 = DS_C3k2(c_d3, c_d2, n=1)
final_d2 instance-attribute
final_d2 = DS_C3k2(c_d2, c_d2, n=1)
forward
forward(enc_feats: list[Tensor], h_ace: Tensor) -> Any
Source code in src/splifft/models/utils/hyperace.py
def forward(self, enc_feats: list[Tensor], h_ace: Tensor) -> Any:
    p2, p3, p4, p5 = enc_feats

    d5 = self.skip_p5(p5)
    h_d5 = self.h_to_d5(F.interpolate(h_ace, size=d5.shape[2:], mode="bilinear"))
    d5 = self.fusion_d5(d5, h_d5)

    d5_up = F.interpolate(d5, size=p4.shape[2:], mode="bilinear")
    d4_skip = self.skip_p4(p4)
    d4 = self.up_d5(d5_up) + d4_skip

    h_d4 = self.h_to_d4(F.interpolate(h_ace, size=d4.shape[2:], mode="bilinear"))
    d4 = self.fusion_d4(d4, h_d4)

    d4_up = F.interpolate(d4, size=p3.shape[2:], mode="bilinear")
    d3_skip = self.skip_p3(p3)
    d3 = self.up_d4(d4_up) + d3_skip

    h_d3 = self.h_to_d3(F.interpolate(h_ace, size=d3.shape[2:], mode="bilinear"))
    d3 = self.fusion_d3(d3, h_d3)

    d3_up = F.interpolate(d3, size=p2.shape[2:], mode="bilinear")
    d2_skip = self.skip_p2(p2)
    d2 = self.up_d3(d3_up) + d2_skip

    h_d2 = self.h_to_d2(F.interpolate(h_ace, size=d2.shape[2:], mode="bilinear"))
    d2 = self.fusion_d2(d2, h_d2)

    d2_final = self.final_d2(d2)

    return d2_final
FreqPixelShuffleV1
FreqPixelShuffleV1(
    in_channels: int, out_channels: int, scale: int = 2
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
scale
conv
Source code in src/splifft/models/utils/hyperace.py
def __init__(self, in_channels: int, out_channels: int, scale: int = 2):
    super().__init__()
    self.scale = scale
    self.conv = DSConv(in_channels, out_channels * scale, k=3, s=1, p=1)
scale instance-attribute
scale = scale
conv instance-attribute
conv = DSConv(
    in_channels, out_channels * scale, k=3, s=1, p=1
)
forward
forward(x: Tensor) -> Any
Source code in src/splifft/models/utils/hyperace.py
def forward(self, x: Tensor) -> Any:
    x = self.conv(x)
    b, c_r, h, w = x.shape
    out_c = c_r // self.scale

    x = x.view(b, out_c, self.scale, h, w)

    x = x.permute(0, 1, 3, 4, 2).contiguous()
    x = x.view(b, out_c, h, w * self.scale)

    return x
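
The reshape in forward moves each group of `scale` channels onto the last axis, so the width grows by a factor of `scale` while the height is untouched. A minimal sketch of just that rearrangement, leaving out the DSConv step; the helper name `freq_pixel_shuffle` and the toy tensor are illustrative assumptions, not part of splifft:

```python
import torch


def freq_pixel_shuffle(x: torch.Tensor, scale: int) -> torch.Tensor:
    """Rearrange (b, c * scale, h, w) into (b, c, h, w * scale)."""
    b, c_r, h, w = x.shape
    out_c = c_r // scale
    # split the channel axis into (out_c, scale)
    x = x.view(b, out_c, scale, h, w)
    # move the `scale` axis next to the width axis, then merge the two
    x = x.permute(0, 1, 3, 4, 2).contiguous()
    return x.view(b, out_c, h, w * scale)


# toy input: 4 channels become 2 channels with twice the width
x = torch.arange(2 * 4 * 3 * 5).reshape(2, 4, 3, 5)
y = freq_pixel_shuffle(x, scale=2)
```

In this sketch, `y[b, c, h, w * scale + s]` holds the value of `x[b, c * scale + s, h, w]`, which means adjacent output columns are drawn from adjacent input channels.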
ProgressiveUpsampleHeadV1
ProgressiveUpsampleHeadV1(
    in_channels: int,
    out_channels: int,
    target_bins: int = 1025,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
target_bins
block1
block2
block3
block4
final_conv
Source code in src/splifft/models/utils/hyperace.py
def __init__(self, in_channels: int, out_channels: int, target_bins: int = 1025):
    super().__init__()
    self.target_bins = target_bins

    c = in_channels

    self.block1 = FreqPixelShuffleV1(c, c, scale=2)
    self.block2 = FreqPixelShuffleV1(c, c // 2, scale=2)
    self.block3 = FreqPixelShuffleV1(c // 2, c // 2, scale=2)
    self.block4 = FreqPixelShuffleV1(c // 2, c // 4, scale=2)

    self.final_conv = nn.Conv2d(c // 4, out_channels, kernel_size=1, bias=False)
target_bins instance-attribute
target_bins = target_bins
block1 instance-attribute
block1 = FreqPixelShuffleV1(c, c, scale=2)
block2 instance-attribute
block2 = FreqPixelShuffleV1(c, c // 2, scale=2)
block3 instance-attribute
block3 = FreqPixelShuffleV1(c // 2, c // 2, scale=2)
block4 instance-attribute
block4 = FreqPixelShuffleV1(c // 2, c // 4, scale=2)
final_conv instance-attribute
final_conv = Conv2d(
    c // 4, out_channels, kernel_size=1, bias=False
)
forward
forward(x: Tensor) -> Any
Source code in src/splifft/models/utils/hyperace.py
def forward(self, x: Tensor) -> Any:
    x = self.block1(x)
    x = self.block2(x)
    x = self.block3(x)
    x = self.block4(x)

    if x.shape[-1] != self.target_bins:
        x = F.interpolate(
            x,
            size=(x.shape[2], self.target_bins),
            mode="bilinear",
            align_corners=False,
        )

    x = self.final_conv(x)
    return x
FreqPixelShuffleV2
FreqPixelShuffleV2(
    in_channels: int, out_channels: int, scale: int, f: int
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
scale
conv
out_conv
Source code in src/splifft/models/utils/hyperace.py
def __init__(self, in_channels: int, out_channels: int, scale: int, f: int):
    super().__init__()
    self.scale = scale
    self.conv = DSConv(in_channels, out_channels * scale)
    self.out_conv = build_hyperace_tfc_tdf(out_channels, out_channels, 2, f)
scale instance-attribute
scale = scale
conv instance-attribute
conv = DSConv(in_channels, out_channels * scale)
out_conv instance-attribute
out_conv = build_hyperace_tfc_tdf(
    out_channels, out_channels, 2, f
)
forward
forward(x: Tensor) -> Any
Source code in src/splifft/models/utils/hyperace.py
def forward(self, x: Tensor) -> Any:
    x = self.conv(x)
    b, c_r, h, w = x.shape
    out_c = c_r // self.scale

    x = x.view(b, out_c, self.scale, h, w)

    x = x.permute(0, 1, 3, 4, 2).contiguous()
    x = x.view(b, out_c, h, w * self.scale)

    return self.out_conv(x)
ProgressiveUpsampleHeadV2
ProgressiveUpsampleHeadV2(
    in_channels: int,
    out_channels: int,
    target_bins: int = 1025,
    in_bands: int = 62,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
target_bins
block1
block2
block3
block4
final_conv
Source code in src/splifft/models/utils/hyperace.py
def __init__(
    self, in_channels: int, out_channels: int, target_bins: int = 1025, in_bands: int = 62
):
    super().__init__()
    self.target_bins = target_bins

    c = in_channels

    self.block1 = FreqPixelShuffleV2(c, c // 2, scale=2, f=in_bands * 2)
    self.block2 = FreqPixelShuffleV2(c // 2, c // 4, scale=2, f=in_bands * 4)
    self.block3 = FreqPixelShuffleV2(c // 4, c // 8, scale=2, f=in_bands * 8)
    self.block4 = FreqPixelShuffleV2(c // 8, c // 16, scale=2, f=in_bands * 16)

    self.final_conv = nn.Conv2d(
        c // 16,
        out_channels,
        kernel_size=3,
        stride=1,
        padding="same",
        bias=False,
    )
target_bins instance-attribute
target_bins = target_bins
block1 instance-attribute
block1 = FreqPixelShuffleV2(
    c, c // 2, scale=2, f=in_bands * 2
)
block2 instance-attribute
block2 = FreqPixelShuffleV2(
    c // 2, c // 4, scale=2, f=in_bands * 4
)
block3 instance-attribute
block3 = FreqPixelShuffleV2(
    c // 4, c // 8, scale=2, f=in_bands * 8
)
block4 instance-attribute
block4 = FreqPixelShuffleV2(
    c // 8, c // 16, scale=2, f=in_bands * 16
)
final_conv instance-attribute
final_conv = Conv2d(
    c // 16,
    out_channels,
    kernel_size=3,
    stride=1,
    padding="same",
    bias=False,
)
forward
forward(x: Tensor) -> Any
Source code in src/splifft/models/utils/hyperace.py
def forward(self, x: Tensor) -> Any:
    x = self.block1(x)
    x = self.block2(x)
    x = self.block3(x)
    x = self.block4(x)

    if x.shape[-1] != self.target_bins:
        x = F.interpolate(
            x,
            size=(x.shape[2], self.target_bins),
            mode="bilinear",
            align_corners=False,
        )

    x = self.final_conv(x)
    return x
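
The four blocks each double the last axis and halve the channel count, so the shapes can be traced with plain arithmetic. The sketch below assumes the head input has width `in_bands`, which is implied by `f=in_bands * 2` in block1 but not stated outright. The helper name `head_schedule` and the channel count of 64 are hypothetical.

```python
def head_schedule(in_bands: int, in_channels: int, target_bins: int = 1025):
    """Trace (channels, width) after each of the four V2 blocks.

    Assumption: the tensor entering the head has width `in_bands`.
    """
    stages = []
    width = in_bands
    for i in range(1, 5):
        width *= 2  # every block uses scale=2
        stages.append((in_channels // 2**i, width))
    # forward only interpolates when the final width misses target_bins
    needs_interpolation = width != target_bins
    return stages, needs_interpolation


stages, needs_interpolation = head_schedule(in_bands=62, in_channels=64)
```

With the default `in_bands=62`, the last width is 62 * 16 = 992, which differs from `target_bins=1025`, so under this assumption the bilinear interpolation branch would run before `final_conv`. The same width arithmetic applies to ProgressiveUpsampleHeadV1, whose blocks also use `scale=2`.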
SegmModelHyperAceV1
SegmModelHyperAceV1(
    in_bands: int = 62,
    in_dim: int = 256,
    out_bins: int = 1025,
    out_channels: int = 4,
    base_channels: int = 64,
    base_depth: int = 2,
    num_hyperedges: int = 16,
    num_heads: int = 8,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
backbone
hyperace
decoder
upsample_head
Source code in src/splifft/models/utils/hyperace.py
def __init__(
    self,
    in_bands: int = 62,
    in_dim: int = 256,
    out_bins: int = 1025,
    out_channels: int = 4,
    base_channels: int = 64,
    base_depth: int = 2,
    num_hyperedges: int = 16,
    num_heads: int = 8,
):
    super().__init__()

    self.backbone = BackboneHyperAceV1(
        in_channels=in_dim,
        base_channels=base_channels,
        base_depth=base_depth,
    )
    enc_channels = self.backbone.out_channels
    c2, c3, c4, c5 = enc_channels

    hyperace_in_channels = enc_channels
    hyperace_out_channels = c4
    self.hyperace = HyperACE(
        hyperace_in_channels,
        hyperace_out_channels,
        num_hyperedges,
        num_heads,
        k=3,
        l=2,
    )

    decoder_channels = [c2, c3, c4, c5]
    self.decoder = DecoderHyperAce(enc_channels, hyperace_out_channels, decoder_channels)

    self.upsample_head = ProgressiveUpsampleHeadV1(
        in_channels=decoder_channels[0],
        out_channels=out_channels,
        target_bins=out_bins,
    )
backbone instance-attribute
backbone = BackboneHyperAceV1(
    in_channels=in_dim,
    base_channels=base_channels,
    base_depth=base_depth,
)
hyperace instance-attribute
hyperace = HyperACE(
    hyperace_in_channels,
    hyperace_out_channels,
    num_hyperedges,
    num_heads,
    k=3,
    l=2,
)
decoder instance-attribute
decoder = DecoderHyperAce(
    enc_channels, hyperace_out_channels, decoder_channels
)
upsample_head instance-attribute
upsample_head = ProgressiveUpsampleHeadV1(
    in_channels=decoder_channels[0],
    out_channels=out_channels,
    target_bins=out_bins,
)
forward
forward(x: Tensor) -> Any
Source code in src/splifft/models/utils/hyperace.py
def forward(self, x: Tensor) -> Any:
    h, _ = x.shape[2:]

    enc_feats = self.backbone(x)

    h_ace_feats = self.hyperace(enc_feats)

    dec_feat = self.decoder(enc_feats, h_ace_feats)

    feat_time_restored = F.interpolate(
        dec_feat,
        size=(h, dec_feat.shape[-1]),
        mode="bilinear",
        align_corners=False,
    )

    out = self.upsample_head(feat_time_restored)

    return out
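
The interpolation step in forward resizes only axis 2 of the decoder output back to the input's length on that axis, and keeps the last axis at whatever width the decoder produced. A small sketch of that call in isolation; the helper name `restore_time` and the tensor sizes are illustrative assumptions:

```python
import torch
import torch.nn.functional as F


def restore_time(dec_feat: torch.Tensor, h: int) -> torch.Tensor:
    """Resize axis 2 to length h, leaving the last axis at its current size."""
    return F.interpolate(
        dec_feat,
        size=(h, dec_feat.shape[-1]),
        mode="bilinear",
        align_corners=False,
    )


# hypothetical decoder output whose axis 2 was reduced from 100 to 25
dec_feat = torch.randn(1, 8, 25, 62)
restored = restore_time(dec_feat, h=100)
```

Only after this step does the upsample head expand the last axis towards `out_bins`.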
SegmModelHyperAceV2
SegmModelHyperAceV2(
    in_bands: int = 62,
    in_dim: int = 256,
    out_bins: int = 1025,
    out_channels: int = 4,
    base_channels: int = 64,
    base_depth: int = 2,
    num_hyperedges: int = 32,
    num_heads: int = 8,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
backbone
hyperace
decoder
upsample_head
Source code in src/splifft/models/utils/hyperace.py
def __init__(
    self,
    in_bands: int = 62,
    in_dim: int = 256,
    out_bins: int = 1025,
    out_channels: int = 4,
    base_channels: int = 64,
    base_depth: int = 2,
    num_hyperedges: int = 32,
    num_heads: int = 8,
):
    super().__init__()

    self.backbone = BackboneHyperAceV2(
        in_channels=in_dim,
        base_channels=base_channels,
        base_depth=base_depth,
    )
    enc_channels = self.backbone.out_channels
    c2, c3, c4, c5 = enc_channels

    hyperace_in_channels = enc_channels
    hyperace_out_channels = c4
    self.hyperace = HyperACE(
        hyperace_in_channels,
        hyperace_out_channels,
        num_hyperedges,
        num_heads,
        k=2,
        l=1,
    )

    decoder_channels = [c2, c3, c4, c5]
    self.decoder = DecoderHyperAce(enc_channels, hyperace_out_channels, decoder_channels)

    self.upsample_head = ProgressiveUpsampleHeadV2(
        in_channels=decoder_channels[0],
        out_channels=out_channels,
        target_bins=out_bins,
        in_bands=in_bands,
    )
backbone instance-attribute
backbone = BackboneHyperAceV2(
    in_channels=in_dim,
    base_channels=base_channels,
    base_depth=base_depth,
)
hyperace instance-attribute
hyperace = HyperACE(
    hyperace_in_channels,
    hyperace_out_channels,
    num_hyperedges,
    num_heads,
    k=2,
    l=1,
)
decoder instance-attribute
decoder = DecoderHyperAce(
    enc_channels, hyperace_out_channels, decoder_channels
)
upsample_head instance-attribute
upsample_head = ProgressiveUpsampleHeadV2(
    in_channels=decoder_channels[0],
    out_channels=out_channels,
    target_bins=out_bins,
    in_bands=in_bands,
)
forward
forward(x: Tensor) -> Any
Source code in src/splifft/models/utils/hyperace.py
def forward(self, x: Tensor) -> Any:
    h, _ = x.shape[2:]

    enc_feats = self.backbone(x)

    h_ace_feats = self.hyperace(enc_feats)

    dec_feat = self.decoder(enc_feats, h_ace_feats)

    feat_time_restored = F.interpolate(
        dec_feat,
        size=(h, dec_feat.shape[-1]),
        mode="bilinear",
        align_corners=False,
    )

    out = self.upsample_head(feat_time_restored)

    return out

stft

Classes:

Name Description
Stft

A custom STFT implementation using 1D convolutions to ensure compatibility with CoreML.

IStft

A simple wrapper around torch.istft with a hacky workaround for MPS.

Stft
Stft(
    n_fft: int,
    hop_length: int,
    win_length: int,
    window_fn: Callable[[int], Tensor],
    conv_dtype: dtype | None,
)

Bases: Module

A custom STFT implementation using 1D convolutions to ensure compatibility with CoreML.

Methods:

Name Description
forward

Attributes:

Name Type Description
n_fft
hop_length
win_length
conv_dtype
Source code in src/splifft/models/utils/stft.py
def __init__(
    self,
    n_fft: int,
    hop_length: int,
    win_length: int,
    window_fn: Callable[[int], Tensor],
    conv_dtype: torch.dtype | None,
):
    super().__init__()
    self.n_fft = n_fft
    self.hop_length = hop_length
    self.win_length = win_length
    self.conv_dtype = conv_dtype

    window = window_fn(self.win_length)

    dft_mat = torch.fft.fft(torch.eye(self.n_fft, device=window.device))
    dft_mat_T = dft_mat.T

    real_kernels = dft_mat_T.real[
        : self.win_length, : (self.n_fft // 2 + 1)
    ] * window.unsqueeze(-1)
    imag_kernels = dft_mat_T.imag[
        : self.win_length, : (self.n_fft // 2 + 1)
    ] * window.unsqueeze(-1)

    # (out_channels, in_channels, kernel_size)
    self.register_buffer("real_conv_weight", real_kernels.T.unsqueeze(1).to(self.conv_dtype))
    self.register_buffer("imag_conv_weight", imag_kernels.T.unsqueeze(1).to(self.conv_dtype))
n_fft instance-attribute
n_fft = n_fft
hop_length instance-attribute
hop_length = hop_length
win_length instance-attribute
win_length = win_length
conv_dtype instance-attribute
conv_dtype = conv_dtype
forward
forward(x: Tensor) -> ComplexSpectrogram
Source code in src/splifft/models/utils/stft.py
def forward(self, x: Tensor) -> t.ComplexSpectrogram:
    b, s, t = x.shape
    x = x.reshape(b * s, 1, t).to(self.conv_dtype)

    padding = self.n_fft // 2
    x = F.pad(x, (padding, padding), "reflect")

    # NOTE: keeping the real / imag passes separate.
    # a single concatenated conv is exact in f16 on CUDA but diverges in f32
    real_part = F.conv1d(x, self.real_conv_weight, stride=self.hop_length)  # type: ignore
    imag_part = F.conv1d(x, self.imag_conv_weight, stride=self.hop_length)  # type: ignore
    spec = torch.stack((real_part, imag_part), dim=-1)  # (b*s, f, t_frames, c=2)

    _bs, f, t_frames, c = spec.shape
    spec = spec.view(b, s, f, t_frames, c)

    return spec  # type: ignore
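
The idea behind the convolutional STFT can be checked against `torch.stft` on a toy signal. The sketch below is not the Stft module itself: it is a standalone function (the name `conv_stft` is hypothetical) that assumes `win_length == n_fft`, a Hann window, and float64 inputs, and it rebuilds the windowed DFT kernels in the same way as the constructor above.

```python
import torch
import torch.nn.functional as F


def conv_stft(x: torch.Tensor, n_fft: int, hop_length: int) -> torch.Tensor:
    """x: (batch, time) float64 -> (batch, freq, frames, 2)."""
    window = torch.hann_window(n_fft, dtype=torch.float64)
    # row k of the DFT matrix is exp(-2j * pi * k * n / n_fft); the matrix is symmetric
    dft = torch.fft.fft(torch.eye(n_fft, dtype=torch.float64))
    n_bins = n_fft // 2 + 1
    # conv1d weights have shape (out_channels, in_channels, kernel_size)
    real_w = (dft.real[:, :n_bins] * window.unsqueeze(-1)).T.unsqueeze(1).contiguous()
    imag_w = (dft.imag[:, :n_bins] * window.unsqueeze(-1)).T.unsqueeze(1).contiguous()
    # reflect-pad by n_fft // 2 on both sides, as in forward
    x = F.pad(x.unsqueeze(1), (n_fft // 2, n_fft // 2), mode="reflect")
    real = F.conv1d(x, real_w, stride=hop_length)
    imag = F.conv1d(x, imag_w, stride=hop_length)
    return torch.stack((real, imag), dim=-1)


torch.manual_seed(0)
signal = torch.randn(2, 64, dtype=torch.float64)
ours = conv_stft(signal, n_fft=16, hop_length=4)

# reference: centered STFT with reflect padding and the same window
reference = torch.view_as_real(
    torch.stft(
        signal,
        n_fft=16,
        hop_length=4,
        win_length=16,
        window=torch.hann_window(16, dtype=torch.float64),
        center=True,
        pad_mode="reflect",
        return_complex=True,
    )
)
```

With these settings each output has `n_fft // 2 + 1` frequency bins and `time // hop_length + 1` frames. The comparison is only expected to hold when `win_length == n_fft`: the constructor above takes the first `win_length` rows of the DFT matrix, whereas `torch.stft` centers a shorter window inside the `n_fft` frame.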
IStft
IStft(
    n_fft: int,
    hop_length: int,
    win_length: int,
    window_fn: Callable[[int], Tensor] = hann_window,
)

Bases: Module

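A simple wrapper around torch.istft with a hacky workaround for MPS: if the call raises a RuntimeError, it is retried, moving the inputs to the CPU when the device is MPS, and the result is moved back to the original device.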

TODO: implement a proper workaround.

Methods:

Name Description
forward

Attributes:

Name Type Description
n_fft
hop_length
win_length
window
Source code in src/splifft/models/utils/stft.py
def __init__(
    self,
    n_fft: int,
    hop_length: int,
    win_length: int,
    window_fn: Callable[[int], Tensor] = torch.hann_window,
):
    super().__init__()
    self.n_fft = n_fft
    self.hop_length = hop_length
    self.win_length = win_length
    self.window = window_fn(self.win_length)
n_fft instance-attribute
n_fft = n_fft
hop_length instance-attribute
hop_length = hop_length
win_length instance-attribute
win_length = win_length
window instance-attribute
window = window_fn(win_length)
forward
forward(
    spec: ComplexSpectrogram, length: int | None = None
) -> RawAudioTensor | NormalizedAudioTensor
Source code in src/splifft/models/utils/stft.py
def forward(
    self, spec: t.ComplexSpectrogram, length: int | None = None
) -> t.RawAudioTensor | t.NormalizedAudioTensor:
    device = spec.device
    is_mps = device.type == "mps"
    window = self.window.to(device)
    # see https://github.com/lucidrains/BS-RoFormer/issues/47
    # this would introduce a breaking change.
    # spec = spec.index_fill(1, torch.tensor(0, device=spec.device), 0.)  # type: ignore
    spec_complex = torch.view_as_complex(spec)

    try:
        audio = torch.istft(
            spec_complex,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=window,
            return_complex=False,
            length=length,
        )
    except RuntimeError:
        audio = torch.istft(
            spec_complex.cpu() if is_mps else spec_complex,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=window.cpu() if is_mps else window,
            return_complex=False,
            length=length,
        ).to(device)

    return audio  # type: ignore
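
forward takes a real-valued tensor whose trailing axis holds the real and imaginary parts, turns it into a complex tensor with `torch.view_as_complex`, and passes `length` so the output is trimmed to the original number of samples. A round-trip sketch of that data layout using plain torch calls on the CPU; the helper name `istft_roundtrip` and the toy sizes are assumptions, and the MPS fallback is not exercised here:

```python
import torch


def istft_roundtrip(x: torch.Tensor, n_fft: int = 16, hop_length: int = 4) -> torch.Tensor:
    """STFT a (batch, time) signal, then invert it back to the same length."""
    window = torch.hann_window(n_fft, dtype=x.dtype)
    spec_complex = torch.stft(
        x,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=n_fft,
        window=window,
        return_complex=True,
    )
    # (batch, freq, frames, 2): real and imaginary parts on the trailing axis
    spec = torch.view_as_real(spec_complex)
    return torch.istft(
        torch.view_as_complex(spec),
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=n_fft,
        window=window,
        length=x.shape[-1],
    )


torch.manual_seed(0)
audio_in = torch.randn(2, 64, dtype=torch.float64)
audio_out = istft_roundtrip(audio_in)
```

The window passed to `torch.istft` has to live on the same device as the spectrogram, which is why forward calls `self.window.to(device)` before the first attempt.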