Models

models

Source separation models.

Modules:

Name Description
bs_roformer

Band-Split RoPE Transformer

utils

Classes:

Name Description
ModelParamsLike

A trait that must be implemented to be considered a model parameter.

ModelMetadata

Metadata about a model, including its type, parameter class, and model class.

Attributes:

Name Type Description
ModelT
ModelParamsLikeT

ModelParamsLike

Bases: Protocol

A trait that must be implemented to be considered a model parameter. Note that input_type and output_type belong to a model's definition and cannot be modified via the configuration dictionary.

Attributes:

Name Type Description
chunk_size ChunkSize
output_stem_names tuple[ModelOutputStemName, ...]
input_type ModelInputType
output_type ModelOutputType

chunk_size instance-attribute

chunk_size: ChunkSize

output_stem_names instance-attribute

output_stem_names: tuple[ModelOutputStemName, ...]

input_type property

input_type: ModelInputType

output_type property

output_type: ModelOutputType
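
For illustration, a parameter dataclass only needs to expose these attributes and read-only properties to structurally satisfy the protocol. The sketch below is hypothetical: plain builtins (int, str) stand in for the ChunkSize, ModelOutputStemName, ModelInputType and ModelOutputType aliases, and the returned strings are placeholder values.

from dataclasses import dataclass

# hypothetical sketch of a ModelParamsLike implementation; plain builtins stand in
# for splifft's type aliases, and the property values below are placeholders.
@dataclass
class MyModelParams:
    chunk_size: int
    output_stem_names: tuple[str, ...]

    @property
    def input_type(self) -> str:
        return "waveform"  # placeholder; fixed by the model's definition

    @property
    def output_type(self) -> str:
        return "waveform"  # placeholder; fixed by the model's definition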

ModelT module-attribute

ModelT = TypeVar('ModelT', bound=Module)

ModelParamsLikeT module-attribute

ModelParamsLikeT = TypeVar(
    "ModelParamsLikeT", bound=ModelParamsLike
)

ModelMetadata dataclass

ModelMetadata(
    model_type: ModelType,
    params: type[ModelParamsLikeT],
    model: type[ModelT],
)

Bases: Generic[ModelT, ModelParamsLikeT]

Metadata about a model, including its type, parameter class, and model class.

Methods:

Name Description
from_module

Dynamically import a model named X and its parameter dataclass XParams under a given module name.

Attributes:

Name Type Description
model_type ModelType
params type[ModelParamsLikeT]
model type[ModelT]

model_type instance-attribute

model_type: ModelType

params instance-attribute

model instance-attribute

model: type[ModelT]

from_module classmethod

from_module(
    module_name: str,
    model_cls_name: str,
    *,
    model_type: ModelType,
    package: str | None = None,
) -> ModelMetadata[Module, ModelParamsLike]

Dynamically import a model named X and its parameter dataclass XParams under a given module name (e.g. splifft.models.bs_roformer).

Parameters:

Name Type Description Default
model_cls_name str

The name of the model class to import, e.g. BSRoformer.

required
module_name str

The name of the module to import, e.g. splifft.models.bs_roformer.

required
model_type ModelType

The type of the model, e.g. bs_roformer.

required
package str | None

The package to use as the anchor point from which to resolve a relative import to an absolute import. This is only required when performing a relative import.

None
Source code in src/splifft/models/__init__.py
@classmethod
def from_module(
    cls,
    module_name: str,
    model_cls_name: str,
    *,
    model_type: t.ModelType,
    package: str | None = None,
) -> ModelMetadata[nn.Module, ModelParamsLike]:
    """
    Dynamically import a model named `X` and its parameter dataclass `XParams` under a
    given module name (e.g. `splifft.models.bs_roformer`).

    :param model_cls_name: The name of the model class to import, e.g. `BSRoformer`.
    :param module_name: The name of the module to import, e.g. `splifft.models.bs_roformer`.
    :param model_type: The type of the model, e.g. `bs_roformer`.
    :param package: The package to use as the anchor point from which to resolve a relative
        import to an absolute import. This is only required when performing a relative import.
    """
    _loc = f"{module_name=} under {package=}"
    try:
        module = importlib.import_module(module_name, package)
    except ImportError as e:
        raise ValueError(f"failed to find or import module for {_loc}") from e

    params_cls_name = f"{model_cls_name}Params"
    model_cls = getattr(module, model_cls_name, None)
    params_cls = getattr(module, params_cls_name, None)
    if model_cls is None or params_cls is None:
        raise AttributeError(
            f"expected to find classes named `{model_cls_name}` and `{params_cls_name}` "
            f"in {_loc}, but at least one of them was not found."
        )

    return ModelMetadata(
        model_type=model_type,
        model=model_cls,
        params=params_cls,
    )
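
For example, the bundled BS-RoFormer can be resolved from its module path; the call below mirrors the documented parameters and assumes the module exposes both BSRoformer and BSRoformerParams (as shown later on this page).

from splifft.models import ModelMetadata

# resolve the model class and its `XParams` dataclass by name
metadata = ModelMetadata.from_module(
    "splifft.models.bs_roformer",
    "BSRoformer",
    model_type="bs_roformer",
)
assert metadata.model.__name__ == "BSRoformer"
assert metadata.params.__name__ == "BSRoformerParams"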

bs_roformer

Band-Split RoPE Transformer

This implementation merges the two versions found in lucidrains' repository. However, there are several inconsistencies:

  • MLP is defined differently in each file: one has depth - 1 hidden layers and the other has depth hidden layers.
  • BSRoformer applies one final RMSNorm after the entire stack of transformer layers, while the MelBandRoformer applies an RMSNorm at the end of each axial transformer block (time_transformer, freq_transformer, etc.) and has no final normalization layer.

Since fixing these inconsistencies upstream would be too big of a breaking change, we inherit them to maintain compatibility with community-trained models. See: https://github.com/lucidrains/BS-RoFormer/issues/48.

To avoid dependency bloat, we do not:

Classes:

Name Description
FixedBandsConfig
MelBandsConfig
BSRoformerParams
RMSNorm
RMSNormWithEps
RotaryEmbedding

A performance-oriented version of RoPE.

FeedForward
Attention
LinearAttention

This flavor of linear attention was proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.

Transformer
BandSplit
MaskEstimator
BSRoformer

Functions:

Name Description
l2norm
rms_norm
mlp

Attributes:

Name Type Description
DEFAULT_FREQS_PER_BANDS

DEFAULT_FREQS_PER_BANDS module-attribute

DEFAULT_FREQS_PER_BANDS = (
    2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
    4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
    12, 12, 12, 12, 12, 12, 12, 12,
    24, 24, 24, 24, 24, 24, 24, 24,
    48, 48, 48, 48, 48, 48, 48, 48,
    128, 129,
)

FixedBandsConfig dataclass

FixedBandsConfig(
    kind: Literal["fixed"] = "fixed",
    freqs_per_bands: tuple[Gt0[int], ...] = (
        lambda: DEFAULT_FREQS_PER_BANDS
    )(),
)

Attributes:

Name Type Description
kind Literal['fixed']
freqs_per_bands tuple[Gt0[int], ...]
kind class-attribute instance-attribute
kind: Literal['fixed'] = 'fixed'
freqs_per_bands class-attribute instance-attribute
freqs_per_bands: tuple[Gt0[int], ...] = field(
    default_factory=lambda: DEFAULT_FREQS_PER_BANDS
)

MelBandsConfig dataclass

MelBandsConfig(
    kind: Literal["mel"],
    num_bands: Gt0[int],
    sample_rate: Gt0[int],
    stft_n_fft: Gt0[int],
)

Attributes:

Name Type Description
kind Literal['mel']
num_bands Gt0[int]
sample_rate Gt0[int]
stft_n_fft Gt0[int]
kind instance-attribute
kind: Literal['mel']
num_bands instance-attribute
num_bands: Gt0[int]
sample_rate instance-attribute
sample_rate: Gt0[int]
stft_n_fft instance-attribute
stft_n_fft: Gt0[int]

BSRoformerParams dataclass

BSRoformerParams(
    chunk_size: ChunkSize,
    output_stem_names: tuple[ModelOutputStemName, ...],
    dim: Gt0[int],
    depth: Gt0[int],
    stft_hop_length: HopSize,
    stereo: bool = True,
    time_transformer_depth: Gt0[int] = 1,
    freq_transformer_depth: Gt0[int] = 1,
    linear_transformer_depth: Ge0[int] = 0,
    band_config: FixedBandsConfig
    | MelBandsConfig = FixedBandsConfig(),
    dim_head: int = 64,
    heads: Gt0[int] = 8,
    attn_dropout: Dropout = 0.0,
    ff_dropout: Dropout = 0.0,
    ff_mult: Gt0[int] = 4,
    flash_attn: bool = True,
    norm_output: bool = False,
    mask_estimator_depth: Gt0[int] = 2,
    mlp_expansion_factor: Gt0[int] = 4,
    use_torch_checkpoint: bool = False,
    sage_attention: bool = False,
    use_shared_bias: bool = False,
    skip_connection: bool = False,
    rms_norm_eps: Ge0[float] | None = None,
    rotary_embed_dtype: TorchDtype | None = None,
    transformer_residual_dtype: TorchDtype | None = None,
    debug: bool = False,
)

Bases: ModelParamsLike

Attributes:

Name Type Description
chunk_size ChunkSize
output_stem_names tuple[ModelOutputStemName, ...]
dim Gt0[int]
depth Gt0[int]
stft_hop_length HopSize
stereo bool
time_transformer_depth Gt0[int]
freq_transformer_depth Gt0[int]
linear_transformer_depth Ge0[int]
band_config FixedBandsConfig | MelBandsConfig
dim_head int
heads Gt0[int]
attn_dropout Dropout
ff_dropout Dropout
ff_mult Gt0[int]
flash_attn bool
norm_output bool

Note that in lucidrains' implementation, this is set to False for bs_roformer but True for mel_roformer.

mask_estimator_depth Gt0[int]

The number of hidden layers of the MLP is mask_estimator_depth - 1.

mlp_expansion_factor Gt0[int]
use_torch_checkpoint bool
sage_attention bool
use_shared_bias bool
skip_connection bool
rms_norm_eps Ge0[float] | None
rotary_embed_dtype TorchDtype | None
transformer_residual_dtype TorchDtype | None
debug bool

Whether to check for nan/inf in model outputs. Keep it off for torch.compile.

input_type ModelInputType
output_type ModelOutputType
chunk_size instance-attribute
chunk_size: ChunkSize
output_stem_names instance-attribute
output_stem_names: tuple[ModelOutputStemName, ...]
dim instance-attribute
dim: Gt0[int]
depth instance-attribute
depth: Gt0[int]
stft_hop_length instance-attribute
stft_hop_length: HopSize
stereo class-attribute instance-attribute
stereo: bool = True
time_transformer_depth class-attribute instance-attribute
time_transformer_depth: Gt0[int] = 1
freq_transformer_depth class-attribute instance-attribute
freq_transformer_depth: Gt0[int] = 1
linear_transformer_depth class-attribute instance-attribute
linear_transformer_depth: Ge0[int] = 0
band_config class-attribute instance-attribute
band_config: FixedBandsConfig | MelBandsConfig = field(
    default_factory=FixedBandsConfig
)
dim_head class-attribute instance-attribute
dim_head: int = 64
heads class-attribute instance-attribute
heads: Gt0[int] = 8
attn_dropout class-attribute instance-attribute
attn_dropout: Dropout = 0.0
ff_dropout class-attribute instance-attribute
ff_dropout: Dropout = 0.0
ff_mult class-attribute instance-attribute
ff_mult: Gt0[int] = 4
flash_attn class-attribute instance-attribute
flash_attn: bool = True
norm_output class-attribute instance-attribute
norm_output: bool = False

Note that in lucidrains' implementation, this is set to False for bs_roformer but True for mel_roformer!!

mask_estimator_depth class-attribute instance-attribute
mask_estimator_depth: Gt0[int] = 2

The number of hidden layers of the MLP is mask_estimator_depth - 1, that is:

  • depth = 1: (dim_in, dim_out)
  • depth = 2: (dim_in, dim_hidden, dim_out)

Note that in lucidrains' implementation of mel-band roformers, the number of hidden layers is incorrectly set to mask_estimator_depth. This includes popular models like kim-vocals and all models trained with ZFTurbo's Music-Source-Separation-Training.

If you are migrating a mel-band roformer configuration from ZFTurbo's repository, increment mask_estimator_depth by 1.

mlp_expansion_factor class-attribute instance-attribute
mlp_expansion_factor: Gt0[int] = 4
use_torch_checkpoint class-attribute instance-attribute
use_torch_checkpoint: bool = False
sage_attention class-attribute instance-attribute
sage_attention: bool = False
use_shared_bias class-attribute instance-attribute
use_shared_bias: bool = False
skip_connection class-attribute instance-attribute
skip_connection: bool = False
rms_norm_eps class-attribute instance-attribute
rms_norm_eps: Ge0[float] | None = None
rotary_embed_dtype class-attribute instance-attribute
rotary_embed_dtype: TorchDtype | None = None
transformer_residual_dtype class-attribute instance-attribute
transformer_residual_dtype: TorchDtype | None = None
debug class-attribute instance-attribute
debug: bool = False

Whether to check for nan/inf in model outputs. Keep it off for torch.compile.

input_type property
input_type: ModelInputType
output_type property
output_type: ModelOutputType
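
As a quick illustration, the dataclass can be constructed directly; the hyperparameters below are hypothetical (real values come from a checkpoint's configuration), and a mel-band variant is selected simply by passing a MelBandsConfig as band_config.

from splifft.models.bs_roformer import BSRoformerParams, MelBandsConfig

# hypothetical hyperparameters for a mel-band variant
params = BSRoformerParams(
    chunk_size=352800,
    output_stem_names=("vocals", "other"),
    dim=384,
    depth=6,
    stft_hop_length=441,
    band_config=MelBandsConfig(
        kind="mel", num_bands=60, sample_rate=44100, stft_n_fft=2048
    ),
)
# input_type/output_type are read-only properties fixed by the model definition
print(params.input_type, params.output_type)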

l2norm

l2norm(t: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def l2norm(t: Tensor) -> Tensor:
    return F.normalize(t, dim=-1, p=2)

RMSNorm

RMSNorm(dim: int)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
scale
gamma
Source code in src/splifft/models/bs_roformer.py
def __init__(self, dim: int):
    super().__init__()
    self.scale = dim**0.5
    self.gamma = nn.Parameter(torch.ones(dim))
scale instance-attribute
scale = dim ** 0.5
gamma instance-attribute
gamma = Parameter(ones(dim))
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    return F.normalize(x, dim=-1) * self.scale * self.gamma  # type: ignore

RMSNormWithEps

RMSNormWithEps(
    dim: int, eps: float = 5.960464477539063e-08
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
scale
gamma
eps
Source code in src/splifft/models/bs_roformer.py
def __init__(self, dim: int, eps: float = 5.960464477539063e-08):
    super().__init__()
    self.scale = dim**0.5
    self.gamma = nn.Parameter(torch.ones(dim))
    self.eps = eps
scale instance-attribute
scale = dim ** 0.5
gamma instance-attribute
gamma = Parameter(ones(dim))
eps instance-attribute
eps = eps
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    l2_norm = torch.linalg.norm(x, dim=-1, keepdim=True)
    denom = torch.maximum(l2_norm, torch.full_like(l2_norm, self.eps))
    normalized_x = x / denom
    return normalized_x * self.scale * self.gamma  # type: ignore

rms_norm

rms_norm(
    dim: int, eps: float | None
) -> RMSNorm | RMSNormWithEps
Source code in src/splifft/models/bs_roformer.py
def rms_norm(dim: int, eps: float | None) -> RMSNorm | RMSNormWithEps:
    if eps is None:
        return RMSNorm(dim)
    return RMSNormWithEps(dim, eps)

RotaryEmbedding

RotaryEmbedding(
    seq_len: int,
    dim_head: int,
    *,
    dtype: dtype | None,
    theta: int = 10000,
)

Bases: Module

A performance-oriented version of RoPE.

Unlike lucidrains' implementation, which computes embeddings JIT during the forward pass and caches calls with the same or shorter sequence length, we simply compute them AOT as persistent buffers. To keep the computational graph clean, we do not support dynamic sequence lengths, learned frequencies, or length extrapolation.
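
A minimal shape check under assumed sizes (1151 time frames, matching the example in BSRoformer.__init__, and dim_head=64): the rotation preserves both the input shape and the norm of each consecutive (even, odd) feature pair, a basic sanity property of RoPE.

import torch
from splifft.models.bs_roformer import RotaryEmbedding

rope = RotaryEmbedding(seq_len=1151, dim_head=64, dtype=torch.float32)
q = torch.randn(2, 8, 1151, 64)  # (batch_eff, heads, seq_len_for_rotation, dim_head)
q_rot = rope(q)
assert q_rot.shape == q.shape
pairs = q.reshape(2, 8, 1151, 32, 2)          # consecutive (even, odd) pairs
pairs_rot = q_rot.reshape(2, 8, 1151, 32, 2)
assert torch.allclose(pairs.norm(dim=-1), pairs_rot.norm(dim=-1), atol=1e-4)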

Methods:

Name Description
rotate_half
forward

Attributes:

Name Type Description
cos_emb
sin_emb
Source code in src/splifft/models/bs_roformer.py
def __init__(
    self, seq_len: int, dim_head: int, *, dtype: torch.dtype | None, theta: int = 10000
):
    super().__init__()
    # COMPAT: the original implementation does not generate the embeddings
    # on the fly, but serialises them in fp16. there are some tiny
    # differences:
    # |                     |   from weights  |   generated    |
    # | ------------------- | --------------- | -------------- |
    # | cos_emb_time:971,22 | -0.99462890625  | -0.994140625   |
    # | cos_emb_time:971,23 | -0.99462890625  | -0.994140625   |
    # | sin_emb_time:727,12 | -0.457763671875 | -0.4580078125  |
    # | sin_emb_time:727,13 | -0.457763671875 | -0.4580078125  |
    # | sin_emb_time:825,4  | -0.8544921875   | -0.85400390625 |
    # | sin_emb_time:825,5  | -0.8544921875   | -0.85400390625 |
    freqs = 1.0 / (theta ** (torch.arange(0, dim_head, 2).float() / dim_head))
    t = torch.arange(seq_len)
    freqs = torch.einsum("i,j->ij", t, freqs)  # (seq_len, dim / 2)
    freqs = repeat(freqs, "... d -> ... (d r)", r=2)  # (seq_len, dim)
    self.cos_emb = freqs.cos().to(dtype)
    self.sin_emb = freqs.sin().to(dtype)
cos_emb instance-attribute
cos_emb = to(dtype)
sin_emb instance-attribute
sin_emb = to(dtype)
rotate_half
rotate_half(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def rotate_half(self, x: Tensor) -> Tensor:
    x = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, "... d r -> ... (d r)")
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    # x is (batch_eff, heads, seq_len_for_rotation, dim_head)
    cos_b = self.cos_emb.unsqueeze(0).unsqueeze(0).to(x.device, x.dtype)
    sin_b = self.sin_emb.unsqueeze(0).unsqueeze(0).to(x.device, x.dtype)

    term1 = x * cos_b
    term2 = self.rotate_half(x) * sin_b

    # NOTE: the original impl performed this addition between two f32s, but that comes with a 30% slowdown.
    # we skip the upcast so the addition is performed in the dtype chosen in __init__ (e.g. two f16s).
    return term1 + term2

FeedForward

FeedForward(
    dim: int,
    mult: int = 4,
    dropout: float = 0.0,
    rms_norm_eps: float | None = None,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
net
Source code in src/splifft/models/bs_roformer.py
def __init__(
    self, dim: int, mult: int = 4, dropout: float = 0.0, rms_norm_eps: float | None = None
):
    super().__init__()
    dim_inner = int(dim * mult)
    # NOTE: in the paper: RMSNorm -> FC -> Tanh -> FC -> GLU
    self.net = nn.Sequential(
        rms_norm(dim, eps=rms_norm_eps),
        nn.Linear(dim, dim_inner),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(dim_inner, dim),
        nn.Dropout(dropout),
    )
net instance-attribute
net = Sequential(
    rms_norm(dim, eps=rms_norm_eps),
    Linear(dim, dim_inner),
    GELU(),
    Dropout(dropout),
    Linear(dim_inner, dim),
    Dropout(dropout),
)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    return self.net(x)

Attention

Attention(
    dim: int,
    heads: int = 8,
    dim_head: int = 64,
    dropout: float = 0.0,
    shared_qkv_bias: Parameter | None = None,
    shared_out_bias: Parameter | None = None,
    rotary_embed: RotaryEmbedding | None = None,
    flash: bool = True,
    sage_attention: bool = False,
    rms_norm_eps: float | None = None,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
heads
scale
rotary_embed
attend
norm
to_qkv
to_gates
to_out
Source code in src/splifft/models/bs_roformer.py
def __init__(
    self,
    dim: int,
    heads: int = 8,
    dim_head: int = 64,
    dropout: float = 0.0,
    shared_qkv_bias: nn.Parameter | None = None,
    shared_out_bias: nn.Parameter | None = None,
    rotary_embed: RotaryEmbedding | None = None,
    flash: bool = True,
    sage_attention: bool = False,
    rms_norm_eps: float | None = None,
):
    super().__init__()
    self.heads = heads
    self.scale = dim_head**-0.5
    dim_inner = heads * dim_head

    self.rotary_embed = rotary_embed

    if sage_attention:
        from .utils.attend_sage import AttendSage

        self.attend = AttendSage(flash=flash, dropout=dropout)
    else:
        from .utils.attend import Attend

        self.attend = Attend(flash=flash, dropout=dropout)  # type: ignore

    self.norm = rms_norm(dim, eps=rms_norm_eps)
    self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=(shared_qkv_bias is not None))
    if shared_qkv_bias is not None:
        self.to_qkv.bias = shared_qkv_bias

    self.to_gates = nn.Linear(dim, heads)

    self.to_out = nn.Sequential(
        nn.Linear(dim_inner, dim, bias=(shared_out_bias is not None)),
        nn.Dropout(dropout),
    )
    if shared_out_bias is not None:
        self.to_out[0].bias = shared_out_bias
heads instance-attribute
heads = heads
scale instance-attribute
scale = dim_head ** -0.5
rotary_embed instance-attribute
rotary_embed = rotary_embed
attend instance-attribute
attend = AttendSage(flash=flash, dropout=dropout)
norm instance-attribute
norm = rms_norm(dim, eps=rms_norm_eps)
to_qkv instance-attribute
to_qkv = Linear(
    dim, dim_inner * 3, bias=shared_qkv_bias is not None
)
to_gates instance-attribute
to_gates = Linear(dim, heads)
to_out instance-attribute
to_out = Sequential(
    Linear(
        dim_inner, dim, bias=shared_out_bias is not None
    ),
    Dropout(dropout),
)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    x = self.norm(x)

    qkv = self.to_qkv(x)
    q, k, v = rearrange(qkv, "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads)

    if self.rotary_embed is not None:
        q = self.rotary_embed(q)
        k = self.rotary_embed(k)

    out = self.attend(q, k, v)

    gates = self.to_gates(x)
    gate_act = gates.sigmoid()

    out = out * rearrange(gate_act, "b n h -> b h n 1")

    out = rearrange(out, "b h n d -> b n (h d)")
    out = self.to_out(out)
    return out

LinearAttention

LinearAttention(
    *,
    dim: int,
    dim_head: int = 32,
    heads: int = 8,
    scale: int = 8,
    flash: bool = False,
    dropout: float = 0.0,
    sage_attention: bool = False,
    rms_norm_eps: float | None = None,
)

Bases: Module

This flavor of linear attention was proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.

Methods:

Name Description
forward

Attributes:

Name Type Description
norm
to_qkv
temperature
attend
to_out
Source code in src/splifft/models/bs_roformer.py
def __init__(
    self,
    *,
    dim: int,
    dim_head: int = 32,
    heads: int = 8,
    scale: int = 8,
    flash: bool = False,
    dropout: float = 0.0,
    sage_attention: bool = False,
    rms_norm_eps: float | None = None,
):
    super().__init__()
    dim_inner = dim_head * heads
    self.norm = rms_norm(dim, eps=rms_norm_eps)

    self.to_qkv = nn.Sequential(
        nn.Linear(dim, dim_inner * 3, bias=False),
        Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads),
    )

    self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

    if sage_attention:
        from .utils.attend_sage import AttendSage

        self.attend = AttendSage(scale=scale, dropout=dropout, flash=flash)
    else:
        from .utils.attend import Attend

        self.attend = Attend(scale=scale, dropout=dropout, flash=flash)  # type: ignore

    self.to_out = nn.Sequential(
        Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False)
    )
norm instance-attribute
norm = rms_norm(dim, eps=rms_norm_eps)
to_qkv instance-attribute
to_qkv = Sequential(
    Linear(dim, dim_inner * 3, bias=False),
    Rearrange(
        "b n (qkv h d) -> qkv b h d n", qkv=3, h=heads
    ),
)
temperature instance-attribute
temperature = Parameter(ones(heads, 1, 1))
attend instance-attribute
attend = AttendSage(
    scale=scale, dropout=dropout, flash=flash
)
to_out instance-attribute
to_out = Sequential(
    Rearrange("b h d n -> b n (h d)"),
    Linear(dim_inner, dim, bias=False),
)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    x = self.norm(x)

    q, k, v = self.to_qkv(x)

    q, k = map(l2norm, (q, k))
    q = q * self.temperature.exp()

    out = self.attend(q, k, v)

    return self.to_out(out)

Transformer

Transformer(
    *,
    dim: int,
    depth: int,
    dim_head: int = 64,
    heads: int = 8,
    attn_dropout: float = 0.0,
    ff_dropout: float = 0.0,
    ff_mult: int = 4,
    norm_output: bool = True,
    rotary_embed: RotaryEmbedding | None = None,
    flash_attn: bool = True,
    linear_attn: bool = False,
    sage_attention: bool = False,
    shared_qkv_bias: Parameter | None = None,
    shared_out_bias: Parameter | None = None,
    rms_norm_eps: float | None = None,
    transformer_residual_dtype: dtype | None = None,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
layers
transformer_residual_dtype
norm
Source code in src/splifft/models/bs_roformer.py
def __init__(
    self,
    *,
    dim: int,
    depth: int,
    dim_head: int = 64,
    heads: int = 8,
    attn_dropout: float = 0.0,
    ff_dropout: float = 0.0,
    ff_mult: int = 4,
    norm_output: bool = True,
    rotary_embed: RotaryEmbedding | None = None,
    flash_attn: bool = True,
    linear_attn: bool = False,
    sage_attention: bool = False,
    shared_qkv_bias: nn.Parameter | None = None,
    shared_out_bias: nn.Parameter | None = None,
    rms_norm_eps: float | None = None,
    transformer_residual_dtype: torch.dtype | None = None,  # COMPAT: float32, see 265
):
    super().__init__()
    self.layers = ModuleList([])

    for _ in range(depth):
        attn: LinearAttention | Attention
        if linear_attn:
            attn = LinearAttention(
                dim=dim,
                dim_head=dim_head,
                heads=heads,
                dropout=attn_dropout,
                flash=flash_attn,
                sage_attention=sage_attention,
                rms_norm_eps=rms_norm_eps,
            )
        else:
            attn = Attention(
                dim=dim,
                dim_head=dim_head,
                heads=heads,
                dropout=attn_dropout,
                shared_qkv_bias=shared_qkv_bias,
                shared_out_bias=shared_out_bias,
                rotary_embed=rotary_embed,
                flash=flash_attn,
                sage_attention=sage_attention,
                rms_norm_eps=rms_norm_eps,
            )

        ff = FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout, rms_norm_eps=rms_norm_eps)
        self.layers.append(ModuleList([attn, ff]))
    self.transformer_residual_dtype = transformer_residual_dtype

    self.norm = rms_norm(dim, eps=rms_norm_eps) if norm_output else nn.Identity()
layers instance-attribute
layers = ModuleList([])
transformer_residual_dtype instance-attribute
transformer_residual_dtype = transformer_residual_dtype
norm instance-attribute
norm = (
    rms_norm(dim, eps=rms_norm_eps)
    if norm_output
    else Identity()
)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    for attn, ff in self.layers:  # type: ignore
        attn_out = attn(x)
        if self.transformer_residual_dtype is not None:
            x = (
                attn_out.to(self.transformer_residual_dtype)
                + x.to(self.transformer_residual_dtype)
            ).to(x.dtype)
        else:
            x = attn_out + x

        ff_out = ff(x)
        if self.transformer_residual_dtype is not None:
            x = (
                ff_out.to(self.transformer_residual_dtype)
                + x.to(self.transformer_residual_dtype)
            ).to(x.dtype)
        else:
            x = ff_out + x
    return self.norm(x)

BandSplit

BandSplit(
    dim: int,
    dim_inputs: tuple[int, ...],
    rms_norm_eps: float | None = None,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
dim_inputs
to_features
Source code in src/splifft/models/bs_roformer.py
def __init__(self, dim: int, dim_inputs: tuple[int, ...], rms_norm_eps: float | None = None):
    super().__init__()
    self.dim_inputs = dim_inputs
    self.to_features = ModuleList([])

    for dim_in in dim_inputs:
        net = nn.Sequential(rms_norm(dim_in, rms_norm_eps), nn.Linear(dim_in, dim))
        self.to_features.append(net)
dim_inputs instance-attribute
dim_inputs = dim_inputs
to_features instance-attribute
to_features = ModuleList([])
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    x_split = x.split(self.dim_inputs, dim=-1)
    outs = []
    for split_input, to_feature_net in zip(x_split, self.to_features):
        split_output = to_feature_net(split_input)
        outs.append(split_output)
    return torch.stack(outs, dim=-2)

mlp

mlp(
    dim_in: int,
    dim_out: int,
    dim_hidden: int | None = None,
    depth: int = 1,
    activation: type[Module] = Tanh,
) -> Sequential
Source code in src/splifft/models/bs_roformer.py
def mlp(
    dim_in: int,
    dim_out: int,
    dim_hidden: int | None = None,
    depth: int = 1,
    activation: type[Module] = nn.Tanh,
) -> nn.Sequential:
    dim_hidden_ = dim_hidden or dim_in

    net: list[Module] = []
    # NOTE: in lucidrain's impl, `bs_roformer` has `depth - 1` but `mel_roformer` has `depth`
    num_hidden_layers = depth - 1
    dims = (dim_in, *((dim_hidden_,) * num_hidden_layers), dim_out)

    for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
        is_last = ind == (len(dims) - 2)

        net.append(nn.Linear(layer_dim_in, layer_dim_out))

        if is_last:
            continue

        net.append(activation())

    return nn.Sequential(*net)
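
A short illustration of the depth semantics referenced by mask_estimator_depth above, using small hypothetical dimensions: depth counts the linear layers, so depth - 1 hidden layers are produced.

from splifft.models.bs_roformer import mlp

shallow = mlp(dim_in=8, dim_out=4, dim_hidden=16, depth=1)
deeper = mlp(dim_in=8, dim_out=4, dim_hidden=16, depth=2)
assert len(shallow) == 1  # Linear(8, 4) only: (dim_in, dim_out)
assert len(deeper) == 3   # Linear(8, 16) -> Tanh -> Linear(16, 4): (dim_in, dim_hidden, dim_out)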

MaskEstimator

MaskEstimator(
    dim: int,
    dim_inputs: tuple[int, ...],
    depth: int,
    mlp_expansion_factor: int,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
dim_inputs
to_freqs
Source code in src/splifft/models/bs_roformer.py
def __init__(
    self,
    dim: int,
    dim_inputs: tuple[int, ...],
    depth: int,
    mlp_expansion_factor: int,
):
    super().__init__()
    self.dim_inputs = dim_inputs
    self.to_freqs = ModuleList([])
    dim_hidden = dim * mlp_expansion_factor

    for dim_in in dim_inputs:
        self.to_freqs.append(
            nn.Sequential(
                mlp(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
                nn.GLU(dim=-1),
            )
        )
dim_inputs instance-attribute
dim_inputs = dim_inputs
to_freqs instance-attribute
to_freqs = ModuleList([])
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    x_unbound = x.unbind(dim=-2)

    outs = []

    for band_features, mlp_net in zip(x_unbound, self.to_freqs):
        freq_out = mlp_net(band_features)
        outs.append(freq_out)

    return torch.cat(outs, dim=-1)

BSRoformer

BSRoformer(cfg: BSRoformerParams)

Bases: Module

Methods:

Name Description
forward

Estimate per-stem masks from an input spectrogram of shape (b, f*s, t, c).

Attributes:

Name Type Description
stereo
audio_channels
num_stems
use_torch_checkpoint
skip_connection
layers
shared_qkv_bias Parameter | None
shared_out_bias Parameter | None
is_mel
final_norm
band_split
mask_estimators
debug
Source code in src/splifft/models/bs_roformer.py
def __init__(self, cfg: BSRoformerParams):
    super().__init__()
    self.stereo = cfg.stereo
    self.audio_channels = 2 if cfg.stereo else 1
    self.num_stems = len(cfg.output_stem_names)
    self.use_torch_checkpoint = cfg.use_torch_checkpoint
    self.skip_connection = cfg.skip_connection

    self.layers = ModuleList([])

    self.shared_qkv_bias: nn.Parameter | None = None
    self.shared_out_bias: nn.Parameter | None = None
    if cfg.use_shared_bias:
        dim_inner = cfg.heads * cfg.dim_head
        self.shared_qkv_bias = nn.Parameter(torch.ones(dim_inner * 3))
        self.shared_out_bias = nn.Parameter(torch.ones(cfg.dim))

    transformer = partial(
        Transformer,
        dim=cfg.dim,
        heads=cfg.heads,
        dim_head=cfg.dim_head,
        attn_dropout=cfg.attn_dropout,
        ff_dropout=cfg.ff_dropout,
        ff_mult=cfg.ff_mult,
        flash_attn=cfg.flash_attn,
        norm_output=cfg.norm_output,
        sage_attention=cfg.sage_attention,
        shared_qkv_bias=self.shared_qkv_bias,
        shared_out_bias=self.shared_out_bias,
        rms_norm_eps=cfg.rms_norm_eps,
        transformer_residual_dtype=cfg.transformer_residual_dtype,
    )

    t_frames = cfg.chunk_size // cfg.stft_hop_length + 1  # e.g. 588800 // 512 + 1 = 1151
    time_rotary_embed = RotaryEmbedding(
        seq_len=t_frames, dim_head=cfg.dim_head, dtype=cfg.rotary_embed_dtype
    )

    if is_mel := isinstance(cfg.band_config, MelBandsConfig):
        from torchaudio.functional import melscale_fbanks

        mel_cfg = cfg.band_config
        num_bands = mel_cfg.num_bands
        freqs = mel_cfg.stft_n_fft // 2 + 1
        mel_filter_bank = melscale_fbanks(
            n_freqs=freqs,
            f_min=0.0,
            f_max=float(mel_cfg.sample_rate / 2),
            n_mels=num_bands,
            sample_rate=mel_cfg.sample_rate,
            norm="slaney",
            mel_scale="slaney",
        ).T
        # TODO: adopt https://github.com/lucidrains/BS-RoFormer/issues/47
        mel_filter_bank[0, 0] = 1.0
        mel_filter_bank[-1, -1] = 1.0

        freqs_per_band_mask = mel_filter_bank > 0
        assert freqs_per_band_mask.any(dim=0).all(), (
            "all frequencies must be covered by at least one band"
        )

        repeated_freq_indices = repeat(torch.arange(freqs), "f -> b f", b=num_bands)
        freq_indices = repeated_freq_indices[freqs_per_band_mask]
        if self.stereo:
            freq_indices = repeat(freq_indices, "f -> f s", s=2)
            freq_indices = freq_indices * 2 + torch.arange(2)
            freq_indices = rearrange(freq_indices, "f s -> (f s)")
        self.register_buffer("freq_indices", freq_indices, persistent=False)
        self.register_buffer("freqs_per_band_mask", freqs_per_band_mask, persistent=False)

        num_freqs_per_band = reduce(freqs_per_band_mask, "b f -> b", "sum")
        num_bands_per_freq = reduce(freqs_per_band_mask, "b f -> f", "sum")

        self.register_buffer("num_freqs_per_band", num_freqs_per_band, persistent=False)
        self.register_buffer("num_bands_per_freq", num_bands_per_freq, persistent=False)

    elif isinstance(cfg.band_config, FixedBandsConfig):
        num_freqs_per_band = torch.tensor(cfg.band_config.freqs_per_bands)
        num_bands = len(cfg.band_config.freqs_per_bands)
    else:
        raise TypeError(f"unknown band config: {cfg.band_config}")
    self.is_mel = is_mel

    freq_rotary_embed = RotaryEmbedding(
        seq_len=num_bands, dim_head=cfg.dim_head, dtype=cfg.rotary_embed_dtype
    )

    for _ in range(cfg.depth):
        tran_modules = []
        if cfg.linear_transformer_depth > 0:
            tran_modules.append(
                transformer(depth=cfg.linear_transformer_depth, linear_attn=True)
            )
        tran_modules.append(
            transformer(depth=cfg.time_transformer_depth, rotary_embed=time_rotary_embed)
        )
        tran_modules.append(
            transformer(depth=cfg.freq_transformer_depth, rotary_embed=freq_rotary_embed)
        )
        self.layers.append(nn.ModuleList(tran_modules))

    self.final_norm = (
        rms_norm(cfg.dim, eps=cfg.rms_norm_eps) if not self.is_mel else nn.Identity()
    )

    freqs_per_bands_with_complex = tuple(
        2 * f * self.audio_channels for f in num_freqs_per_band.tolist()
    )

    self.band_split = BandSplit(
        dim=cfg.dim,
        dim_inputs=freqs_per_bands_with_complex,
        rms_norm_eps=cfg.rms_norm_eps,
    )

    self.mask_estimators = nn.ModuleList([])

    for _ in range(len(cfg.output_stem_names)):
        mask_estimator = MaskEstimator(
            dim=cfg.dim,
            dim_inputs=freqs_per_bands_with_complex,
            depth=cfg.mask_estimator_depth,
            mlp_expansion_factor=cfg.mlp_expansion_factor,
        )

        self.mask_estimators.append(mask_estimator)

    self.debug = cfg.debug
stereo instance-attribute
stereo = stereo
audio_channels instance-attribute
audio_channels = 2 if stereo else 1
num_stems instance-attribute
num_stems = len(output_stem_names)
use_torch_checkpoint instance-attribute
use_torch_checkpoint = use_torch_checkpoint
skip_connection instance-attribute
skip_connection = skip_connection
layers instance-attribute
layers = ModuleList([])
shared_qkv_bias instance-attribute
shared_qkv_bias: Parameter | None = None
shared_out_bias instance-attribute
shared_out_bias: Parameter | None = None
is_mel instance-attribute
is_mel = is_mel
final_norm instance-attribute
final_norm = (
    rms_norm(dim, eps=rms_norm_eps)
    if not is_mel
    else Identity()
)
band_split instance-attribute
band_split = BandSplit(
    dim=dim,
    dim_inputs=freqs_per_bands_with_complex,
    rms_norm_eps=rms_norm_eps,
)
mask_estimators instance-attribute
mask_estimators = ModuleList([])
debug instance-attribute
debug = debug
forward
forward(stft_repr: Tensor) -> Tensor

Parameters:

Name Type Description Default
stft_repr Tensor

input spectrogram. shape (b, f*s, t, c)

required

Returns:

Type Description
Tensor

estimated mask. shape (b, n, f*s, t, c)

Source code in src/splifft/models/bs_roformer.py
def forward(self, stft_repr: Tensor) -> Tensor:
    """
    :param stft_repr: input spectrogram. shape (b, f*s, t, c)
    :return: estimated mask. shape (b, n, f*s, t, c)
    """
    batch, _, t_frames, _ = stft_repr.shape
    device = stft_repr.device
    if self.is_mel:
        batch_arange = torch.arange(batch, device=device)[..., None]
        x = stft_repr[batch_arange, self.freq_indices]
        x = rearrange(x, "b f t c -> b t (f c)")
    else:
        x = rearrange(stft_repr, "b f t c -> b t (f c)")

    if self.debug and (torch.isnan(x).any() or torch.isinf(x).any()):
        raise RuntimeError(
            f"nan/inf in x after rearrange: {x.isnan().sum()} nans, {x.isinf().sum()} infs"
        )

    if self.use_torch_checkpoint:
        x = checkpoint(self.band_split, x, use_reentrant=False)
    else:
        x = self.band_split(x)

    if self.debug and (torch.isnan(x).any() or torch.isinf(x).any()):
        raise RuntimeError(
            f"nan/inf in x after band_split: {x.isnan().sum()} nans, {x.isinf().sum()} infs"
        )

    # axial / hierarchical attention

    store = [None] * len(self.layers)
    for i, transformer_block in enumerate(self.layers):
        if len(transformer_block) == 3:
            linear_transformer, time_transformer, freq_transformer = transformer_block

            x, ft_ps = pack([x], "b * d")
            if self.use_torch_checkpoint:
                x = checkpoint(linear_transformer, x, use_reentrant=False)
            else:
                x = linear_transformer(x)
            (x,) = unpack(x, ft_ps, "b * d")
        else:
            time_transformer, freq_transformer = transformer_block

        if self.skip_connection:
            for j in range(i):
                x = x + store[j]

        x = rearrange(x, "b t f d -> b f t d")
        x, ps = pack([x], "* t d")

        if self.use_torch_checkpoint:
            x = checkpoint(time_transformer, x, use_reentrant=False)
        else:
            x = time_transformer(x)

        (x,) = unpack(x, ps, "* t d")
        x = rearrange(x, "b f t d -> b t f d")
        x, ps = pack([x], "* f d")

        if self.use_torch_checkpoint:
            x = checkpoint(freq_transformer, x, use_reentrant=False)
        else:
            x = freq_transformer(x)

        (x,) = unpack(x, ps, "* f d")

        if self.skip_connection:
            store[i] = x

    x = self.final_norm(x)

    if self.use_torch_checkpoint:
        mask = torch.stack(
            [checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators],
            dim=1,
        )
    else:
        mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
    mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2)

    if not self.is_mel:
        return mask

    stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")
    # stft_repr may be fp16 but complex32 support is experimental so we upcast it early
    stft_repr_complex = torch.view_as_complex(stft_repr.to(torch.float32))

    masks_per_band_complex = torch.view_as_complex(mask)
    masks_per_band_complex = masks_per_band_complex.type(stft_repr_complex.dtype)

    scatter_indices = repeat(
        self.freq_indices,
        "f -> b n f t",
        b=batch,
        n=self.num_stems,
        t=stft_repr_complex.shape[-1],
    )
    stft_repr_expanded_stems = repeat(stft_repr_complex, "b 1 ... -> b n ...", n=self.num_stems)

    masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(
        2, scatter_indices, masks_per_band_complex
    )

    denom = repeat(self.num_bands_per_freq, "f -> (f r) 1", r=self.audio_channels)
    masks_averaged = masks_summed / denom.clamp(min=1e-8)

    return torch.view_as_real(masks_averaged).to(stft_repr.dtype)
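
An end-to-end shape sketch using the default fixed bands (which cover 1025 frequency bins, i.e. an n_fft of 2048) and deliberately small, hypothetical hyperparameters; real values come from a checkpoint's configuration.

import torch
from splifft.models.bs_roformer import BSRoformer, BSRoformerParams

cfg = BSRoformerParams(
    chunk_size=131072,
    output_stem_names=("vocals", "instrumental"),
    dim=192,
    depth=1,
    stft_hop_length=512,
)
model = BSRoformer(cfg).eval()

b, f, s, c = 1, 1025, 2, 2  # batch, freq bins, stereo channels, real/imag
t_frames = cfg.chunk_size // cfg.stft_hop_length + 1
stft_repr = torch.randn(b, f * s, t_frames, c)
with torch.no_grad():
    mask = model(stft_repr)
assert mask.shape == (b, len(cfg.output_stem_names), f * s, t_frames, c)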

utils

Modules:

Name Description
attend
attend_sage
stft

Functions:

Name Description
parse_version
log_once

parse_version

parse_version(v_str: str) -> tuple[int, ...]
Source code in src/splifft/models/utils/__init__.py
def parse_version(v_str: str) -> tuple[int, ...]:
    # e.g "2.1.0+cu118" -> (2, 1, 0)
    return tuple(map(int, v_str.split("+")[0].split(".")))

log_once cached

log_once(
    logger: Logger, msg: object, *, level: int = DEBUG
) -> None
Source code in src/splifft/models/utils/__init__.py
@lru_cache(10)
def log_once(logger: Logger, msg: object, *, level: int = logging.DEBUG) -> None:
    logger.log(level, msg)

attend_sage

Classes:

Name Description
AttendSage

Attributes:

Name Type Description
logger
logger module-attribute
logger = getLogger(__name__)
AttendSage
AttendSage(
    dropout: float = 0.0,
    flash: bool = False,
    scale: float | None = None,
)

Bases: Module

Parameters:

Name Type Description Default
flash bool

if True, attempts to use SageAttention or PyTorch SDPA.

False

Methods:

Name Description
forward

einstein notation

Attributes:

Name Type Description
scale
dropout
use_sage
use_pytorch_sdpa
attn_dropout
Source code in src/splifft/models/utils/attend_sage.py
def __init__(
    self,
    dropout: float = 0.0,
    flash: bool = False,
    scale: float | None = None,
):
    """
    :param flash: if True, attempts to use SageAttention or PyTorch SDPA.
    """
    super().__init__()
    self.scale = scale  # for einsum path
    self.dropout = dropout  # for einsum/SDPA path

    self.use_sage = flash and _has_sage_attention
    self.use_pytorch_sdpa = False
    self._sdpa_checked = False

    if flash and not self.use_sage:
        if not self._sdpa_checked:
            if parse_version(torch.__version__) >= (2, 0, 0):
                self.use_pytorch_sdpa = True
                log_once(
                    logger,
                    "Using PyTorch SDPA backend (FlashAttention-2, Memory-Efficient, or Math).",
                )
            else:
                log_once(
                    logger,
                    "Flash attention requested but Pytorch < 2.0 and SageAttention not found. Falling back to einsum.",
                )
            self._sdpa_checked = True

    # dropout layer for manual einsum implementation ONLY
    # SDPA and SageAttention handle dropout differently
    # (or not at all in Sage's base API)
    self.attn_dropout = nn.Dropout(dropout)
scale instance-attribute
scale = scale
dropout instance-attribute
dropout = dropout
use_sage instance-attribute
use_sage = flash and _has_sage_attention
use_pytorch_sdpa instance-attribute
use_pytorch_sdpa = False
attn_dropout instance-attribute
attn_dropout = Dropout(dropout)
forward
forward(q: Tensor, k: Tensor, v: Tensor) -> Tensor

einstein notation

  • b: batch
  • h: heads
  • n, i, j: sequence length (base sequence length, source, target)
  • d: feature dimension

Input tensors q, k, v expected in shape: (batch, heads, seq_len, dim_head) -> HND layout

Source code in src/splifft/models/utils/attend_sage.py
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
    """
    einstein notation

    - b: batch
    - h: heads
    - n, i, j: sequence length (base sequence length, source, target)
    - d: feature dimension

    Input tensors q, k, v expected in shape: (batch, heads, seq_len, dim_head) -> HND layout
    """
    _q_len, _k_len, _device = q.shape[-2], k.shape[-2], q.device

    # priority 1: SageAttention
    if self.use_sage:
        # assumes q, k, v are FP16/BF16 (handled by autocast upstream)
        # assumes scale is handled internally by sageattn
        # assumes dropout is NOT handled by sageattn kernel
        # is_causal=False based on how Attend is called in mel_band_roformer
        try:
            out = sageattn(q, k, v, tensor_layout="HND", is_causal=False)  # type: ignore
            return out  # type: ignore
        except Exception as e:
            logger.error(f"SageAttention failed with error: {e}. Falling back.")
            self.use_sage = False
            if not self._sdpa_checked:
                if parse_version(torch.__version__) >= (2, 0, 0):
                    self.use_pytorch_sdpa = True
                    log_once(logger, "falling back to PyTorch SDPA")
                else:
                    log_once(logger, "falling back to einsum.")

                self._sdpa_checked = True

    # priority 2: PyTorch SDPA
    if self.use_pytorch_sdpa:
        # it handles scaling and dropout internally.
        try:
            with sdpa_kernel(
                [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
            ):
                out = F.scaled_dot_product_attention(
                    q,
                    k,
                    v,
                    attn_mask=None,  # assuming no explicit mask needed here
                    dropout_p=self.dropout if self.training else 0.0,
                    is_causal=False,  # assuming not needed based on usage context
                )
            return out
        except Exception as e:
            log_once(
                logger,
                f"pytorch SDPA failed with error: {e}. falling back to einsum.",
                level=logging.ERROR,
            )
            self.use_pytorch_sdpa = False

    scale = self.scale or q.shape[-1] ** -0.5

    # similarity
    sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale

    # attention
    attn = sim.softmax(dim=-1)
    attn = self.attn_dropout(attn)  # ONLY in einsum path

    # aggregate values
    out = einsum("b h i j, b h j d -> b h i d", attn, v)

    return out

attend

Classes:

Name Description
Attend

Attributes:

Name Type Description
logger
logger module-attribute
logger = getLogger(__name__)
Attend
Attend(
    dropout: float = 0.0,
    flash: bool = False,
    scale: float | None = None,
)

Bases: Module

Methods:

Name Description
flash_attn
forward

einstein notation

Attributes:

Name Type Description
scale
dropout
attn_dropout
flash
cpu_backends
cuda_backends list[_SDPBackend] | None
Source code in src/splifft/models/utils/attend.py
def __init__(
    self, dropout: float = 0.0, flash: bool = False, scale: float | None = None
) -> None:
    super().__init__()
    self.scale = scale
    self.dropout = dropout
    self.attn_dropout = nn.Dropout(dropout)

    self.flash = flash
    assert not (flash and parse_version(torch.__version__) < (2, 0, 0)), (
        "expected pytorch >= 2.0.0 to use flash attention"
    )

    # determine efficient attention configs for cuda and cpu
    self.cpu_backends = [
        SDPBackend.FLASH_ATTENTION,
        SDPBackend.EFFICIENT_ATTENTION,
        SDPBackend.MATH,
    ]
    self.cuda_backends: list[_SDPBackend] | None = None

    if not torch.cuda.is_available() or not flash:
        return

    device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
    device_version = parse_version(f"{device_properties.major}.{device_properties.minor}")

    if device_version >= (8, 0):
        if os.name == "nt":
            cuda_backends = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
            log_once(logger, f"windows detected, using {cuda_backends=}")
        else:
            cuda_backends = [SDPBackend.FLASH_ATTENTION]
            log_once(logger, f"gpu compute capability >= 8.0, using {cuda_backends=}")
    else:
        cuda_backends = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
        log_once(logger, f"gpu compute capability < 8.0, using {cuda_backends=}")

    self.cuda_backends = cuda_backends
scale instance-attribute
scale = scale
dropout instance-attribute
dropout = dropout
attn_dropout instance-attribute
attn_dropout = Dropout(dropout)
flash instance-attribute
flash = flash
cpu_backends instance-attribute
cpu_backends = [FLASH_ATTENTION, EFFICIENT_ATTENTION, MATH]
cuda_backends instance-attribute
cuda_backends: list[_SDPBackend] | None = cuda_backends
flash_attn
flash_attn(q: Tensor, k: Tensor, v: Tensor) -> Tensor
Source code in src/splifft/models/utils/attend.py
def flash_attn(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
    _, _heads, _q_len, _, _k_len, is_cuda, _device = (
        *q.shape,
        k.shape[-2],
        q.is_cuda,
        q.device,
    )  # type: ignore

    if self.scale is not None:
        default_scale = q.shape[-1] ** -0.5
        q = q * (self.scale / default_scale)

    backends = self.cuda_backends if is_cuda else self.cpu_backends
    # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
    with sdpa_kernel(backends=backends):  # type: ignore
        out = F.scaled_dot_product_attention(
            q, k, v, dropout_p=self.dropout if self.training else 0.0
        )

    return out
forward
forward(q: Tensor, k: Tensor, v: Tensor) -> Tensor

einstein notation

  • b: batch
  • h: heads
  • n, i, j: sequence length (base sequence length, source, target)
  • d: feature dimension
Source code in src/splifft/models/utils/attend.py
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
    """
    einstein notation

    - b: batch
    - h: heads
    - n, i, j: sequence length (base sequence length, source, target)
    - d: feature dimension
    """
    _q_len, _k_len, _device = q.shape[-2], k.shape[-2], q.device

    scale = self.scale or q.shape[-1] ** -0.5

    if self.flash:
        return self.flash_attn(q, k, v)

    # similarity
    sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale

    # attention
    attn = sim.softmax(dim=-1)
    attn = self.attn_dropout(attn)

    # aggregate values
    out = einsum("b h i j, b h j d -> b h i d", attn, v)

    return out
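
A tiny usage sketch of the einsum path (flash=False), with q, k, v laid out as (batch, heads, seq_len, dim_head) per the notation above.

import torch
from splifft.models.utils.attend import Attend

attend = Attend(dropout=0.0, flash=False)  # flash=False exercises the plain einsum path
q = torch.randn(2, 8, 16, 64)  # (b, h, n, d)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)
out = attend(q, k, v)
assert out.shape == (2, 8, 16, 64)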

stft

Classes:

Name Description
Stft

A custom STFT implementation using 1D convolutions to ensure compatibility with CoreML.

IStft

A simple wrapper around torch.istft with a hacky workaround for MPS.

Stft
Stft(
    n_fft: int,
    hop_length: int,
    win_length: int,
    window_fn: Callable[[int], Tensor],
    conv_dtype: dtype | None,
)

Bases: Module

A custom STFT implementation using 1D convolutions to ensure compatibility with CoreML.

Methods:

Name Description
forward

Attributes:

Name Type Description
n_fft
hop_length
win_length
conv_dtype
Source code in src/splifft/models/utils/stft.py
def __init__(
    self,
    n_fft: int,
    hop_length: int,
    win_length: int,
    window_fn: Callable[[int], Tensor],
    conv_dtype: torch.dtype | None,
):
    super().__init__()
    self.n_fft = n_fft
    self.hop_length = hop_length
    self.win_length = win_length
    self.conv_dtype = conv_dtype

    window = window_fn(self.win_length)

    dft_mat = torch.fft.fft(torch.eye(self.n_fft, device=window.device))
    dft_mat_T = dft_mat.T

    real_kernels = dft_mat_T.real[
        : self.win_length, : (self.n_fft // 2 + 1)
    ] * window.unsqueeze(-1)
    imag_kernels = dft_mat_T.imag[
        : self.win_length, : (self.n_fft // 2 + 1)
    ] * window.unsqueeze(-1)

    # (out_channels, in_channels, kernel_size)
    self.register_buffer("real_conv_weight", real_kernels.T.unsqueeze(1).to(self.conv_dtype))
    self.register_buffer("imag_conv_weight", imag_kernels.T.unsqueeze(1).to(self.conv_dtype))
n_fft instance-attribute
n_fft = n_fft
hop_length instance-attribute
hop_length = hop_length
win_length instance-attribute
win_length = win_length
conv_dtype instance-attribute
conv_dtype = conv_dtype
forward
forward(x: Tensor) -> ComplexSpectrogram
Source code in src/splifft/models/utils/stft.py
def forward(self, x: Tensor) -> t.ComplexSpectrogram:
    b, s, t = x.shape
    x = x.reshape(b * s, 1, t).to(self.conv_dtype)

    padding = self.n_fft // 2
    x = F.pad(x, (padding, padding), "reflect")

    real_part = F.conv1d(x, self.real_conv_weight, stride=self.hop_length)  # type: ignore
    imag_part = F.conv1d(x, self.imag_conv_weight, stride=self.hop_length)  # type: ignore
    spec = torch.stack((real_part, imag_part), dim=-1)  # (b*s, f, t_frames, c=2)

    _bs, f, t_frames, c = spec.shape
    spec = spec.view(b, s, f, t_frames, c)

    return spec  # type: ignore
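
A shape-oriented usage sketch with hypothetical sizes. With win_length == n_fft and a Hann window, the construction above is intended to mirror a centered, reflect-padded torch.stft, so the difference printed below should be near zero if that assumption holds.

import torch
from splifft.models.utils.stft import Stft

stft = Stft(
    n_fft=2048, hop_length=512, win_length=2048,
    window_fn=torch.hann_window, conv_dtype=torch.float32,
)
x = torch.randn(1, 2, 8192)  # (batch, channels, samples)
spec = stft(x)               # (batch, channels, n_fft//2 + 1, frames, 2)
assert spec.shape == (1, 2, 1025, 8192 // 512 + 1, 2)

ref = torch.stft(
    x.reshape(-1, 8192), n_fft=2048, hop_length=512, win_length=2048,
    window=torch.hann_window(2048), center=True, pad_mode="reflect", return_complex=True,
)
diff = (torch.view_as_complex(spec.reshape(2, 1025, -1, 2)) - ref).abs().max()
print(f"max |conv stft - torch.stft| = {diff:.2e}")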
IStft
IStft(
    n_fft: int,
    hop_length: int,
    win_length: int,
    window_fn: Callable[[int], Tensor] = hann_window,
)

Bases: Module

A simple wrapper around torch.istft with a hacky workaround for MPS.

TODO: implement a proper workaround.

Methods:

Name Description
forward

Attributes:

Name Type Description
n_fft
hop_length
win_length
window
Source code in src/splifft/models/utils/stft.py
def __init__(
    self,
    n_fft: int,
    hop_length: int,
    win_length: int,
    window_fn: Callable[[int], Tensor] = torch.hann_window,
):
    super().__init__()
    self.n_fft = n_fft
    self.hop_length = hop_length
    self.win_length = win_length
    self.window = window_fn(self.win_length)
n_fft instance-attribute
n_fft = n_fft
hop_length instance-attribute
hop_length = hop_length
win_length instance-attribute
win_length = win_length
window instance-attribute
window = window_fn(win_length)
forward
forward(
    spec: ComplexSpectrogram, length: int | None = None
) -> RawAudioTensor | NormalizedAudioTensor
Source code in src/splifft/models/utils/stft.py
def forward(
    self, spec: t.ComplexSpectrogram, length: int | None = None
) -> t.RawAudioTensor | t.NormalizedAudioTensor:
    device = spec.device
    is_mps = device.type == "mps"
    window = self.window.to(device)
    # see https://github.com/lucidrains/BS-RoFormer/issues/47
    # this would introduce a breaking change.
    # spec = spec.index_fill(1, torch.tensor(0, device=spec.device), 0.)  # type: ignore
    spec_complex = torch.view_as_complex(spec)

    try:
        audio = torch.istft(
            spec_complex,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=window,
            return_complex=False,
            length=length,
        )
    except RuntimeError:
        audio = torch.istft(
            spec_complex.cpu() if is_mps else spec_complex,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=window.cpu() if is_mps else window,
            return_complex=False,
            length=length,
        ).to(device)

    return audio  # type: ignore
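
Finally, a hedged round-trip sketch combining Stft and IStft: the spectrogram is flattened to (batch * channels, freq, frames, 2) first, since torch.istft accepts at most three (complex) dimensions, and the reconstruction should closely match the input given the 75% Hann overlap used here.

import torch
from splifft.models.utils.stft import IStft, Stft

stft = Stft(n_fft=2048, hop_length=512, win_length=2048,
            window_fn=torch.hann_window, conv_dtype=torch.float32)
istft = IStft(n_fft=2048, hop_length=512, win_length=2048)

x = torch.randn(1, 2, 8192)                    # (batch, channels, samples)
spec = stft(x)                                 # (batch, channels, freq, frames, 2)
spec_flat = spec.reshape(-1, *spec.shape[2:])  # (batch * channels, freq, frames, 2)
y = istft(spec_flat, length=8192).reshape(1, 2, 8192)
print((x - y).abs().max())                     # expected to be small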