
Models

models

Source separation models.

Modules:

Name Description
bs_roformer
utils

Classes:

Name Description
ModelConfigLike

A trait that must be implemented to be considered a model configuration.

ModelMetadata

Metadata about a model, including its type, configuration class, and model class.

Attributes:

Name Type Description
ModelType TypeAlias

The type of the model, e.g. bs_roformer, demucs

ChunkSize TypeAlias

The length of an audio segment, in samples, processed by the model at one time.

ModelOutputStemName TypeAlias

The output stem name, e.g. vocals, drums, bass, etc.

ModelConfigLikeT
ModelT

ModelType module-attribute

ModelType: TypeAlias = str

The type of the model, e.g. bs_roformer, demucs

ModelConfigLike

Bases: Protocol

A trait that must be implemented to be considered a model configuration.

Attributes:

Name Type Description
chunk_size ChunkSize
output_stem_names tuple[ModelOutputStemName, ...]

chunk_size instance-attribute

chunk_size: ChunkSize

output_stem_names instance-attribute

output_stem_names: tuple[ModelOutputStemName, ...]
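
For example, a minimal configuration class satisfying this protocol only needs the two attributes above. The TwoStemConfig name below is hypothetical and shown for illustration only:

from dataclasses import dataclass

from splifft.models import ChunkSize, ModelConfigLike, ModelOutputStemName


@dataclass
class TwoStemConfig:  # hypothetical example, not part of the library
    chunk_size: ChunkSize
    output_stem_names: tuple[ModelOutputStemName, ...]


cfg: ModelConfigLike = TwoStemConfig(
    chunk_size=588800, output_stem_names=("vocals", "instrumental")
)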

ChunkSize module-attribute

ChunkSize: TypeAlias = Annotated[int, Gt(0)]

The length of an audio segment, in samples, processed by the model at one time.

A full audio track is often too long to fit into GPU memory, so we instead process it in fixed-size chunks. A larger chunk size may allow the model to capture more temporal context at the cost of increased memory usage.
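
A rough sketch of what chunked processing looks like (illustrative only, not splifft's actual inference pipeline, which may pad and overlap chunks):

import torch

chunk_size = 588800  # samples per chunk, taken from a model config
audio = torch.randn(2, 3_000_000)  # (channels, samples)

num_chunks = audio.shape[-1] // chunk_size
chunks = audio[:, : num_chunks * chunk_size].split(chunk_size, dim=-1)
for chunk in chunks:  # each (channels, chunk_size) chunk is fed to the model separately
    ...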

ModelOutputStemName module-attribute

ModelOutputStemName: TypeAlias = str

The output stem name, e.g. vocals, drums, bass, etc.

ModelConfigLikeT module-attribute

ModelConfigLikeT = TypeVar(
    "ModelConfigLikeT", bound=ModelConfigLike
)

ModelT module-attribute

ModelT = TypeVar('ModelT', bound=Module)

ModelMetadata dataclass

ModelMetadata(
    model_type: ModelType,
    config: type[ModelConfigLikeT],
    model: type[ModelT],
)

Bases: Generic[ModelT, ModelConfigLikeT]

Metadata about a model, including its type, configuration class, and model class.

Methods:

Name Description
from_module

Dynamically import a model named X and its configuration dataclass XConfig under a given module name.

Attributes:

Name Type Description
model_type ModelType
config type[ModelConfigLikeT]
model type[ModelT]

model_type instance-attribute

model_type: ModelType

config instance-attribute

model instance-attribute

model: type[ModelT]

from_module classmethod

from_module(
    module_name: str,
    model_cls_name: str,
    *,
    model_type: ModelType,
    package: str | None = None,
) -> ModelMetadata[Module, ModelConfigLike]

Dynamically import a model named X and its configuration dataclass XConfig under a given module name (e.g. splifft.models.bs_roformer).

Parameters:

Name Type Description Default
model_cls_name str

The name of the model class to import, e.g. BSRoformer.

required
module_name str

The name of the module to import, e.g. splifft.models.bs_roformer.

required
package str | None

The package to use as the anchor point from which to resolve the relative import to an absolute import. This is only required when performing a relative import.

None
Source code in src/splifft/models/__init__.py
@classmethod
def from_module(
    cls,
    module_name: str,
    model_cls_name: str,
    *,
    model_type: ModelType,
    package: str | None = None,
) -> ModelMetadata[nn.Module, ModelConfigLike]:
    """
    Dynamically import a model named `X` and its configuration dataclass `XConfig` under a
    given module name (e.g. `splifft.models.bs_roformer`).

    :param model_cls_name: The name of the model class to import, e.g. `BSRoformer`.
    :param module_name: The name of the module to import, e.g. `splifft.models.bs_roformer`.
    :param package: The package to use as the anchor point from which to resolve the relative
        import to an absolute import. This is only required when performing a relative import.
    """
    _loc = f"{module_name=} under {package=}"
    try:
        module = importlib.import_module(module_name, package)
    except ImportError as e:
        raise ValueError(f"failed to find or import module for {_loc}") from e

    config_cls_name = f"{model_cls_name}Config"
    model_cls = getattr(module, model_cls_name, None)
    config_cls = getattr(module, config_cls_name, None)
    if model_cls is None or config_cls is None:
        raise AttributeError(
            f"expected to find classes named `{model_cls_name}` and `{config_cls_name}` "
            f"in {_loc}, but at least one of them was not found."
        )

    return ModelMetadata(
        model_type=model_type,
        model=model_cls,
        config=config_cls,  # type: ignore
    )
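
For example, resolving the bundled BS-RoFormer implementation:

from splifft.models import ModelMetadata

metadata = ModelMetadata.from_module(
    "splifft.models.bs_roformer",
    "BSRoformer",
    model_type="bs_roformer",
)
assert metadata.model.__name__ == "BSRoformer"
assert metadata.config.__name__ == "BSRoformerConfig"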

bs_roformer

Classes:

Name Description
BSRoformerConfig
CustomNorm
RotaryEmbedding
FeedForward
Attention
LinearAttention

This flavor of linear attention was proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.

Transformer
BandSplit
MaskEstimator
BSRoformer

Functions:

Name Description
l2norm
MLP

Attributes:

Name Type Description
DEFAULT_FREQS_PER_BANDS

DEFAULT_FREQS_PER_BANDS module-attribute

DEFAULT_FREQS_PER_BANDS = (
    *(2,) * 24,
    *(4,) * 12,
    *(12,) * 8,
    *(24,) * 8,
    *(48,) * 8,
    128,
    129,
)

BSRoformerConfig dataclass

BSRoformerConfig(
    chunk_size: ChunkSize,
    output_stem_names: tuple[ModelOutputStemName, ...],
    dim: int,
    depth: int,
    stereo: bool = False,
    time_transformer_depth: int = 2,
    freq_transformer_depth: int = 2,
    linear_transformer_depth: int = 0,
    freqs_per_bands: tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
    dim_head: int = 64,
    heads: int = 8,
    attn_dropout: float = 0.0,
    ff_dropout: float = 0.0,
    flash_attn: bool = True,
    stft_n_fft: int = 2048,
    stft_hop_length: int = 512,
    stft_win_length: int = 2048,
    stft_normalized: bool = False,
    stft_window_fn_name: str = "torch.hann_window",
    mask_estimator_depth: int = 2,
    mlp_expansion_factor: int = 4,
    use_torch_checkpoint: bool = False,
    sage_attention: bool = False,
    use_shared_bias: bool = True,
    skip_connection: bool = False,
    multi_stft_resolution_loss_weight: float = 1.0,
    multi_stft_resolutions_window_sizes: tuple[int, ...] = (4096, 2048, 1024, 512, 256),
    multi_stft_hop_size: int = 147,
    multi_stft_normalized: bool = False,
    multi_stft_window_fn_name: str = "torch.hann_window",
    debug: bool = False,
)

Bases: ModelConfigLike

Attributes:

Name Type Description
chunk_size ChunkSize
output_stem_names tuple[ModelOutputStemName, ...]
dim int
depth int
stereo bool
time_transformer_depth int
freq_transformer_depth int
linear_transformer_depth int
freqs_per_bands tuple[int, ...]
dim_head int
heads int
attn_dropout float
ff_dropout float
flash_attn bool
stft_n_fft int
stft_hop_length int
stft_win_length int
stft_normalized bool
stft_window_fn_name str
mask_estimator_depth int
mlp_expansion_factor int
use_torch_checkpoint bool
sage_attention bool
use_shared_bias bool
skip_connection bool
multi_stft_resolution_loss_weight float
multi_stft_resolutions_window_sizes tuple[int, ...]
multi_stft_hop_size int
multi_stft_normalized bool
multi_stft_window_fn_name str
debug bool

Whether to check for nan/inf in model outputs. Keep it off for torch.compile.

stft_window_fn Callable[[int], Tensor]
multi_stft_window_fn Callable[[int], Tensor]
chunk_size instance-attribute
chunk_size: ChunkSize
output_stem_names instance-attribute
output_stem_names: tuple[ModelOutputStemName, ...]
dim instance-attribute
dim: int
depth instance-attribute
depth: int
stereo class-attribute instance-attribute
stereo: bool = False
time_transformer_depth class-attribute instance-attribute
time_transformer_depth: int = 2
freq_transformer_depth class-attribute instance-attribute
freq_transformer_depth: int = 2
linear_transformer_depth class-attribute instance-attribute
linear_transformer_depth: int = 0
freqs_per_bands class-attribute instance-attribute
freqs_per_bands: tuple[int, ...] = field(
    default_factory=lambda: DEFAULT_FREQS_PER_BANDS
)
dim_head class-attribute instance-attribute
dim_head: int = 64
heads class-attribute instance-attribute
heads: int = 8
attn_dropout class-attribute instance-attribute
attn_dropout: float = 0.0
ff_dropout class-attribute instance-attribute
ff_dropout: float = 0.0
flash_attn class-attribute instance-attribute
flash_attn: bool = True
stft_n_fft class-attribute instance-attribute
stft_n_fft: int = 2048
stft_hop_length class-attribute instance-attribute
stft_hop_length: int = 512
stft_win_length class-attribute instance-attribute
stft_win_length: int = 2048
stft_normalized class-attribute instance-attribute
stft_normalized: bool = False
stft_window_fn_name class-attribute instance-attribute
stft_window_fn_name: str = 'torch.hann_window'
mask_estimator_depth class-attribute instance-attribute
mask_estimator_depth: int = 2
mlp_expansion_factor class-attribute instance-attribute
mlp_expansion_factor: int = 4
use_torch_checkpoint class-attribute instance-attribute
use_torch_checkpoint: bool = False
sage_attention class-attribute instance-attribute
sage_attention: bool = False
use_shared_bias class-attribute instance-attribute
use_shared_bias: bool = True
skip_connection class-attribute instance-attribute
skip_connection: bool = False
multi_stft_resolution_loss_weight class-attribute instance-attribute
multi_stft_resolution_loss_weight: float = 1.0
multi_stft_resolutions_window_sizes class-attribute instance-attribute
multi_stft_resolutions_window_sizes: tuple[int, ...] = (
    field(
        default_factory=lambda: (4096, 2048, 1024, 512, 256)
    )
)
multi_stft_hop_size class-attribute instance-attribute
multi_stft_hop_size: int = 147
multi_stft_normalized class-attribute instance-attribute
multi_stft_normalized: bool = False
multi_stft_window_fn_name class-attribute instance-attribute
multi_stft_window_fn_name: str = 'torch.hann_window'
debug class-attribute instance-attribute
debug: bool = False

Whether to check for nan/inf in model outputs. Keep it off for torch.compile.

stft_window_fn property
stft_window_fn: Callable[[int], Tensor]
multi_stft_window_fn property
multi_stft_window_fn: Callable[[int], Tensor]
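
Only chunk_size, output_stem_names, dim and depth are required; every other field falls back to the defaults listed above. A minimal sketch (the dim and depth values here are illustrative, not taken from any particular checkpoint):

from splifft.models.bs_roformer import BSRoformerConfig

cfg = BSRoformerConfig(
    chunk_size=588800,
    output_stem_names=("vocals", "instrumental"),
    dim=384,
    depth=6,
)
assert cfg.stft_n_fft == 2048  # defaults are filled in
window = cfg.stft_window_fn(cfg.stft_win_length)  # resolved from stft_window_fn_name
assert window.shape == (2048,)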

l2norm

l2norm(t: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def l2norm(t: Tensor) -> Tensor:
    return F.normalize(t, dim=-1, p=2)

CustomNorm

CustomNorm(dim: int, eps: float = 5.960464477539063e-08)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
scale
gamma
eps
Source code in src/splifft/models/bs_roformer.py
def __init__(self, dim: int, eps: float = 5.960464477539063e-08):
    super().__init__()
    self.scale = dim**0.5
    self.gamma = nn.Parameter(torch.ones(dim))
    self.eps = eps
scale instance-attribute
scale = dim ** 0.5
gamma instance-attribute
gamma = Parameter(ones(dim))
eps instance-attribute
eps = eps
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    l2_norm = torch.linalg.norm(x, dim=-1, keepdim=True)
    denom = torch.maximum(l2_norm, torch.full_like(l2_norm, self.eps))
    normalized_x = x / denom
    return normalized_x * self.scale * self.gamma
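
In other words, each feature vector is L2-normalized and then rescaled by sqrt(dim) and a learned gamma. A quick sanity check (illustrative):

import torch
from splifft.models.bs_roformer import CustomNorm

norm = CustomNorm(dim=4)
x = torch.randn(2, 3, 4)
expected = torch.nn.functional.normalize(x, dim=-1) * norm.scale * norm.gamma
assert torch.allclose(norm(x), expected, atol=1e-6)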

RotaryEmbedding

RotaryEmbedding(cos_emb: Parameter, sin_emb: Parameter)

Bases: Module

Methods:

Name Description
rotate_half
forward

Attributes:

Name Type Description
cos_emb
sin_emb
Source code in src/splifft/models/bs_roformer.py
def __init__(self, cos_emb: nn.Parameter, sin_emb: nn.Parameter):
    super().__init__()
    # both (seq_len_for_rotation, dim_head)
    self.cos_emb = cos_emb
    self.sin_emb = sin_emb
cos_emb instance-attribute
cos_emb = cos_emb
sin_emb instance-attribute
sin_emb = sin_emb
rotate_half
rotate_half(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def rotate_half(self, x: Tensor) -> Tensor:
    x = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, "... d r -> ... (d r)")
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    # x is (batch_eff, heads, seq_len_for_rotation, dim_head)
    cos_b = self.cos_emb.unsqueeze(0).unsqueeze(0).to(x.device, x.dtype)
    sin_b = self.sin_emb.unsqueeze(0).unsqueeze(0).to(x.device, x.dtype)

    term1 = x * cos_b
    term2 = self.rotate_half(x) * sin_b

    sum_val = term1.to(torch.float32) + term2.to(torch.float32)
    return sum_val.to(x.dtype)

FeedForward

FeedForward(dim: int, mult: int = 4, dropout: float = 0.0)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
net
Source code in src/splifft/models/bs_roformer.py
def __init__(self, dim: int, mult: int = 4, dropout: float = 0.0):
    super().__init__()
    dim_inner = int(dim * mult)
    self.net = nn.Sequential(
        CustomNorm(dim),
        nn.Linear(dim, dim_inner),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(dim_inner, dim),
        nn.Dropout(dropout),
    )
net instance-attribute
net = Sequential(
    CustomNorm(dim),
    Linear(dim, dim_inner),
    GELU(),
    Dropout(dropout),
    Linear(dim_inner, dim),
    Dropout(dropout),
)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    return self.net(x)

Attention

Attention(
    dim: int,
    heads: int = 8,
    dim_head: int = 64,
    dropout: float = 0.0,
    shared_qkv_bias: Parameter | None = None,
    shared_out_bias: Parameter | None = None,
    rotary_embed: RotaryEmbedding | None = None,
    flash: bool = True,
    sage_attention: bool = False,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
heads
scale
rotary_embed
attend
norm
to_qkv
to_gates
to_out
Source code in src/splifft/models/bs_roformer.py
def __init__(
    self,
    dim: int,
    heads: int = 8,
    dim_head: int = 64,
    dropout: float = 0.0,
    shared_qkv_bias: nn.Parameter | None = None,
    shared_out_bias: nn.Parameter | None = None,
    rotary_embed: RotaryEmbedding | None = None,
    flash: bool = True,
    sage_attention: bool = False,
):
    super().__init__()
    self.heads = heads
    self.scale = dim_head**-0.5
    dim_inner = heads * dim_head

    self.rotary_embed = rotary_embed

    if sage_attention:
        self.attend = AttendSage(flash=flash, dropout=dropout)
    else:
        self.attend = Attend(flash=flash, dropout=dropout)

    self.norm = CustomNorm(dim)
    self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=(shared_qkv_bias is not None))
    if shared_qkv_bias is not None:
        self.to_qkv.bias = shared_qkv_bias

    self.to_gates = nn.Linear(dim, heads)

    self.to_out = nn.Sequential(
        nn.Linear(dim_inner, dim, bias=(shared_out_bias is not None)),
        nn.Dropout(dropout),
    )
    if shared_out_bias is not None:
        self.to_out[0].bias = shared_out_bias
heads instance-attribute
heads = heads
scale instance-attribute
scale = dim_head ** -0.5
rotary_embed instance-attribute
rotary_embed = rotary_embed
attend instance-attribute
attend = AttendSage(flash=flash, dropout=dropout)
norm instance-attribute
norm = CustomNorm(dim)
to_qkv instance-attribute
to_qkv = Linear(
    dim, dim_inner * 3, bias=shared_qkv_bias is not None
)
to_gates instance-attribute
to_gates = Linear(dim, heads)
to_out instance-attribute
to_out = Sequential(
    Linear(
        dim_inner, dim, bias=shared_out_bias is not None
    ),
    Dropout(dropout),
)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    x = self.norm(x)

    qkv = self.to_qkv(x)
    q, k, v = rearrange(qkv, "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads)

    if self.rotary_embed is not None:
        q = self.rotary_embed(q)
        k = self.rotary_embed(k)

    out = self.attend(q, k, v)

    gates = self.to_gates(x)
    gate_act = gates.sigmoid()

    out = out * rearrange(gate_act, "b n h -> b h n 1")

    out = rearrange(out, "b h n d -> b n (h d)")
    out = self.to_out(out)
    return out

LinearAttention

LinearAttention(
    *,
    dim: int,
    dim_head: int = 32,
    heads: int = 8,
    scale: int = 8,
    flash: bool = False,
    dropout: float = 0.0,
    sage_attention: bool = False,
)

Bases: Module

This flavor of linear attention was proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.

Methods:

Name Description
forward

Attributes:

Name Type Description
norm
to_qkv
temperature
attend
to_out
Source code in src/splifft/models/bs_roformer.py
def __init__(
    self,
    *,
    dim: int,
    dim_head: int = 32,
    heads: int = 8,
    scale: int = 8,
    flash: bool = False,
    dropout: float = 0.0,
    sage_attention: bool = False,
):
    super().__init__()
    dim_inner = dim_head * heads
    self.norm = CustomNorm(dim)

    self.to_qkv = nn.Sequential(
        nn.Linear(dim, dim_inner * 3, bias=False),
        Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads),
    )

    self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

    if sage_attention:
        self.attend = AttendSage(scale=scale, dropout=dropout, flash=flash)  # type: ignore
    else:
        self.attend = Attend(scale=scale, dropout=dropout, flash=flash)  # type: ignore

    self.to_out = nn.Sequential(
        Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False)
    )
norm instance-attribute
norm = CustomNorm(dim)
to_qkv instance-attribute
to_qkv = Sequential(
    Linear(dim, dim_inner * 3, bias=False),
    Rearrange(
        "b n (qkv h d) -> qkv b h d n", qkv=3, h=heads
    ),
)
temperature instance-attribute
temperature = Parameter(ones(heads, 1, 1))
attend instance-attribute
attend = AttendSage(
    scale=scale, dropout=dropout, flash=flash
)
to_out instance-attribute
to_out = Sequential(
    Rearrange("b h d n -> b n (h d)"),
    Linear(dim_inner, dim, bias=False),
)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    x = self.norm(x)

    q, k, v = self.to_qkv(x)

    q, k = map(l2norm, (q, k))
    q = q * self.temperature.exp()

    out = self.attend(q, k, v)

    return self.to_out(out)

Transformer

Transformer(
    *,
    dim: int,
    depth: int,
    dim_head: int = 64,
    heads: int = 8,
    attn_dropout: float = 0.0,
    ff_dropout: float = 0.0,
    ff_mult: int = 4,
    norm_output: bool = True,
    rotary_embed: RotaryEmbedding | None = None,
    flash_attn: bool = True,
    linear_attn: bool = False,
    sage_attention: bool = False,
    shared_qkv_bias: Parameter | None = None,
    shared_out_bias: Parameter | None = None,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
layers
norm
Source code in src/splifft/models/bs_roformer.py
def __init__(
    self,
    *,
    dim: int,
    depth: int,
    dim_head: int = 64,
    heads: int = 8,
    attn_dropout: float = 0.0,
    ff_dropout: float = 0.0,
    ff_mult: int = 4,
    norm_output: bool = True,
    rotary_embed: RotaryEmbedding | None = None,
    flash_attn: bool = True,
    linear_attn: bool = False,
    sage_attention: bool = False,
    shared_qkv_bias: nn.Parameter | None = None,
    shared_out_bias: nn.Parameter | None = None,
):
    super().__init__()
    self.layers = ModuleList([])

    for _ in range(depth):
        attn: LinearAttention | Attention
        if linear_attn:
            attn = LinearAttention(
                dim=dim,
                dim_head=dim_head,
                heads=heads,
                dropout=attn_dropout,
                flash=flash_attn,
                sage_attention=sage_attention,
            )
        else:
            attn = Attention(
                dim=dim,
                dim_head=dim_head,
                heads=heads,
                dropout=attn_dropout,
                shared_qkv_bias=shared_qkv_bias,
                shared_out_bias=shared_out_bias,
                rotary_embed=rotary_embed,
                flash=flash_attn,
                sage_attention=sage_attention,
            )

        self.layers.append(
            ModuleList([attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)])
        )

    self.norm = CustomNorm(dim) if norm_output else nn.Identity()
layers instance-attribute
layers = ModuleList([])
norm instance-attribute
norm = CustomNorm(dim) if norm_output else Identity()
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    for attn, ff in self.layers:  # type: ignore
        x = attn(x) + x
        x = ff(x) + x
    return self.norm(x)
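
A minimal usage sketch with small, illustrative hyperparameters:

import torch
from splifft.models.bs_roformer import Transformer

block = Transformer(dim=32, depth=2, dim_head=16, heads=4, flash_attn=False)
x = torch.randn(2, 100, 32)  # (batch, sequence, dim)
assert block(x).shape == x.shape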

BandSplit

BandSplit(dim: int, dim_inputs: tuple[int, ...])

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
dim_inputs
to_features
Source code in src/splifft/models/bs_roformer.py
def __init__(self, dim: int, dim_inputs: tuple[int, ...]):
    super().__init__()
    self.dim_inputs = dim_inputs
    self.to_features = ModuleList([])

    for dim_in in dim_inputs:
        net = nn.Sequential(CustomNorm(dim_in), nn.Linear(dim_in, dim))
        self.to_features.append(net)
dim_inputs instance-attribute
dim_inputs = dim_inputs
to_features instance-attribute
to_features = ModuleList([])
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    x_split = x.split(self.dim_inputs, dim=-1)
    outs = []
    for split_input, to_feature_net in zip(x_split, self.to_features):
        split_output = to_feature_net(split_input)
        outs.append(split_output)
    return torch.stack(outs, dim=-2)
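
Each band of width dim_in is projected to a common dim, so the output gains a band axis. A toy example with three bands (the real model passes 62 band widths derived from freqs_per_bands):

import torch
from splifft.models.bs_roformer import BandSplit

band_split = BandSplit(dim=32, dim_inputs=(4, 8, 16))  # three toy bands
x = torch.randn(2, 100, 28)  # (batch, time, sum of band widths)
assert band_split(x).shape == (2, 100, 3, 32)  # (batch, time, bands, dim)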

MLP

MLP(
    dim_in: int,
    dim_out: int,
    dim_hidden: int | None = None,
    depth: int = 1,
    activation: type[Module] = Tanh,
) -> Sequential
Source code in src/splifft/models/bs_roformer.py
def MLP(
    dim_in: int,
    dim_out: int,
    dim_hidden: int | None = None,
    depth: int = 1,
    activation: type[Module] = nn.Tanh,
) -> nn.Sequential:
    dim_hidden_ = dim_hidden or dim_in

    net: list[Module] = []
    dims = (dim_in, *((dim_hidden_,) * (depth - 1)), dim_out)

    for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
        is_last = ind == (len(dims) - 2)

        net.append(nn.Linear(layer_dim_in, layer_dim_out))

        if is_last:
            continue

        net.append(activation())

    return nn.Sequential(*net)
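
For instance, MLP(dim_in=8, dim_out=4, depth=2) builds Linear(8, 8) -> Tanh -> Linear(8, 4):

import torch
from splifft.models.bs_roformer import MLP

net = MLP(dim_in=8, dim_out=4, depth=2)
assert net(torch.randn(2, 8)).shape == (2, 4)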

MaskEstimator

MaskEstimator(
    dim: int,
    dim_inputs: tuple[int, ...],
    depth: int,
    mlp_expansion_factor: int,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
dim_inputs
to_freqs
Source code in src/splifft/models/bs_roformer.py
def __init__(
    self, dim: int, dim_inputs: tuple[int, ...], depth: int, mlp_expansion_factor: int
):
    super().__init__()
    self.dim_inputs = dim_inputs
    self.to_freqs = ModuleList([])
    dim_hidden = dim * mlp_expansion_factor

    for dim_in in dim_inputs:
        mlp = nn.Sequential(
            MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1)
        )
        self.to_freqs.append(mlp)
dim_inputs instance-attribute
dim_inputs = dim_inputs
to_freqs instance-attribute
to_freqs = ModuleList([])
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    x_unbound = x.unbind(dim=-2)

    outs = []

    for band_features, mlp in zip(x_unbound, self.to_freqs):
        freq_out = mlp(band_features)
        outs.append(freq_out)

    return torch.cat(outs, dim=-1)
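
This is the inverse of BandSplit's shape transform: per-band features come back out as a flattened frequency mask (the GLU halves the 2x expansion produced by the MLP). A toy example:

import torch
from splifft.models.bs_roformer import MaskEstimator

estimator = MaskEstimator(dim=32, dim_inputs=(4, 8, 16), depth=2, mlp_expansion_factor=4)
x = torch.randn(2, 100, 3, 32)  # (batch, time, bands, dim), e.g. BandSplit output
assert estimator(x).shape == (2, 100, 28)  # (batch, time, sum of band widths)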

BSRoformer

BSRoformer(cfg: BSRoformerConfig)

Bases: Module

Methods:

Name Description
forward

einops

Attributes:

Name Type Description
stereo
audio_channels
num_stems
use_torch_checkpoint
skip_connection
layers
shared_qkv_bias Parameter | None
shared_out_bias Parameter | None
cos_emb_time
sin_emb_time
cos_emb_freq
sin_emb_freq
final_norm
stft_kwargs
stft_window_fn
band_split
mask_estimators
multi_stft_resolution_loss_weight
multi_stft_resolutions_window_sizes
multi_stft_n_fft
multi_stft_window_fn
multi_stft_kwargs
debug
Source code in src/splifft/models/bs_roformer.py
def __init__(self, cfg: BSRoformerConfig):
    super().__init__()
    self.stereo = cfg.stereo
    self.audio_channels = 2 if cfg.stereo else 1
    self.num_stems = len(cfg.output_stem_names)
    self.use_torch_checkpoint = cfg.use_torch_checkpoint
    self.skip_connection = cfg.skip_connection

    self.layers = ModuleList([])

    self.shared_qkv_bias: nn.Parameter | None = None
    self.shared_out_bias: nn.Parameter | None = None
    if cfg.use_shared_bias:
        dim_inner = cfg.heads * cfg.dim_head
        self.shared_qkv_bias = nn.Parameter(torch.ones(dim_inner * 3))
        self.shared_out_bias = nn.Parameter(torch.ones(cfg.dim))

    transformer_kwargs = dict(
        dim=cfg.dim,
        heads=cfg.heads,
        dim_head=cfg.dim_head,
        attn_dropout=cfg.attn_dropout,
        ff_dropout=cfg.ff_dropout,
        flash_attn=cfg.flash_attn,
        norm_output=False,
        sage_attention=cfg.sage_attention,
        shared_qkv_bias=self.shared_qkv_bias,
        shared_out_bias=self.shared_out_bias,
    )

    t_frames = cfg.chunk_size // cfg.stft_hop_length + 1  # e.g. 588800 // 512 + 1 = 1151
    self.cos_emb_time = nn.Parameter(torch.zeros(t_frames, cfg.dim_head))
    self.sin_emb_time = nn.Parameter(torch.zeros(t_frames, cfg.dim_head))
    time_rotary_embed = RotaryEmbedding(cos_emb=self.cos_emb_time, sin_emb=self.sin_emb_time)

    num_bands = len(cfg.freqs_per_bands)  # e.g. 62
    self.cos_emb_freq = nn.Parameter(torch.zeros(num_bands, cfg.dim_head))
    self.sin_emb_freq = nn.Parameter(torch.zeros(num_bands, cfg.dim_head))
    freq_rotary_embed = RotaryEmbedding(cos_emb=self.cos_emb_freq, sin_emb=self.sin_emb_freq)

    for _ in range(cfg.depth):
        tran_modules = []
        if cfg.linear_transformer_depth > 0:
            tran_modules.append(
                Transformer(
                    depth=cfg.linear_transformer_depth,
                    linear_attn=True,
                    **transformer_kwargs,  # type: ignore
                )
            )
        tran_modules.append(
            Transformer(
                depth=cfg.time_transformer_depth,
                rotary_embed=time_rotary_embed,
                **transformer_kwargs,  # type: ignore
            )
        )
        tran_modules.append(
            Transformer(
                depth=cfg.freq_transformer_depth,
                rotary_embed=freq_rotary_embed,
                **transformer_kwargs,  # type: ignore
            )
        )
        self.layers.append(nn.ModuleList(tran_modules))

    self.final_norm = CustomNorm(cfg.dim)

    self.stft_kwargs = dict(
        n_fft=cfg.stft_n_fft,
        hop_length=cfg.stft_hop_length,
        win_length=cfg.stft_win_length,
        normalized=cfg.stft_normalized,
    )

    self.stft_window_fn = partial(cfg.stft_window_fn, cfg.stft_win_length)

    freqs_per_bands_with_complex = tuple(
        2 * f * self.audio_channels for f in cfg.freqs_per_bands
    )

    self.band_split = BandSplit(dim=cfg.dim, dim_inputs=freqs_per_bands_with_complex)

    self.mask_estimators = nn.ModuleList([])

    for _ in range(len(cfg.output_stem_names)):
        mask_estimator = MaskEstimator(
            dim=cfg.dim,
            dim_inputs=freqs_per_bands_with_complex,
            depth=cfg.mask_estimator_depth,
            mlp_expansion_factor=cfg.mlp_expansion_factor,
        )

        self.mask_estimators.append(mask_estimator)

    # for the multi-resolution stft loss

    self.multi_stft_resolution_loss_weight = cfg.multi_stft_resolution_loss_weight
    self.multi_stft_resolutions_window_sizes = cfg.multi_stft_resolutions_window_sizes
    self.multi_stft_n_fft = cfg.stft_n_fft
    self.multi_stft_window_fn = cfg.multi_stft_window_fn

    self.multi_stft_kwargs = dict(
        hop_length=cfg.multi_stft_hop_size, normalized=cfg.multi_stft_normalized
    )
    self.debug = cfg.debug
stereo instance-attribute
stereo = stereo
audio_channels instance-attribute
audio_channels = 2 if stereo else 1
num_stems instance-attribute
num_stems = len(output_stem_names)
use_torch_checkpoint instance-attribute
use_torch_checkpoint = use_torch_checkpoint
skip_connection instance-attribute
skip_connection = skip_connection
layers instance-attribute
layers = ModuleList([])
shared_qkv_bias instance-attribute
shared_qkv_bias: Parameter | None = None
shared_out_bias instance-attribute
shared_out_bias: Parameter | None = None
cos_emb_time instance-attribute
cos_emb_time = Parameter(zeros(t_frames, dim_head))
sin_emb_time instance-attribute
sin_emb_time = Parameter(zeros(t_frames, dim_head))
cos_emb_freq instance-attribute
cos_emb_freq = Parameter(zeros(num_bands, dim_head))
sin_emb_freq instance-attribute
sin_emb_freq = Parameter(zeros(num_bands, dim_head))
final_norm instance-attribute
final_norm = CustomNorm(dim)
stft_kwargs instance-attribute
stft_kwargs = dict(
    n_fft=stft_n_fft,
    hop_length=stft_hop_length,
    win_length=stft_win_length,
    normalized=stft_normalized,
)
stft_window_fn instance-attribute
stft_window_fn = partial(stft_window_fn, stft_win_length)
band_split instance-attribute
band_split = BandSplit(
    dim=dim, dim_inputs=freqs_per_bands_with_complex
)
mask_estimators instance-attribute
mask_estimators = ModuleList([])
multi_stft_resolution_loss_weight instance-attribute
multi_stft_resolution_loss_weight = (
    multi_stft_resolution_loss_weight
)
multi_stft_resolutions_window_sizes instance-attribute
multi_stft_resolutions_window_sizes = (
    multi_stft_resolutions_window_sizes
)
multi_stft_n_fft instance-attribute
multi_stft_n_fft = stft_n_fft
multi_stft_window_fn instance-attribute
multi_stft_window_fn = multi_stft_window_fn
multi_stft_kwargs instance-attribute
multi_stft_kwargs = dict(
    hop_length=multi_stft_hop_size,
    normalized=multi_stft_normalized,
)
debug instance-attribute
debug = debug
forward
forward(
    raw_audio: Tensor,
    target: Tensor | None = None,
    return_loss_breakdown: bool = False,
) -> Tensor | tuple[Tensor, tuple[Tensor, float]]

einops

  • b: batch
  • f: frequency
  • t: time
  • s: audio channel (1 for mono, 2 for stereo)
  • n: number of stems
  • c: complex (2)
  • d: feature dimension
Source code in src/splifft/models/bs_roformer.py
def forward(
    self, raw_audio: Tensor, target: Tensor | None = None, return_loss_breakdown: bool = False
) -> Tensor | tuple[Tensor, tuple[Tensor, float]]:
    """
    einops

    - b: batch
    - f: frequency
    - t: time
    - s: audio channel (1 for mono, 2 for stereo)
    - n: number of stems
    - c: complex (2)
    - d: feature dimension
    """

    device = raw_audio.device

    # whether the input is on MPS (the macOS GPU backend)
    x_is_mps = device.type == "mps"

    if raw_audio.ndim == 2:
        raw_audio = rearrange(raw_audio, "b t -> b 1 t")

    channels = raw_audio.shape[1]
    assert (not self.stereo and channels == 1) or (self.stereo and channels == 2), (
        "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"
    )

    # to stft

    raw_audio, batch_audio_channel_packed_shape = pack([raw_audio], "* t")

    stft_window = self.stft_window_fn(device=device)

    # torch.stft on MPS raises "FFT operations are only supported on MacOS 14+";
    # rather than detecting the macOS version, simply fall back to the CPU on failure
    try:
        stft_repr_complex = torch.stft(
            raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True
        )
    except RuntimeError:
        stft_repr_complex = torch.stft(
            raw_audio.cpu() if x_is_mps else raw_audio,
            **self.stft_kwargs,
            window=stft_window.cpu() if x_is_mps else stft_window,
            return_complex=True,
        ).to(device)
    stft_repr = torch.view_as_real(stft_repr_complex)

    stft_repr = unpack(stft_repr, batch_audio_channel_packed_shape, "* f t c")[0]

    # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
    stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")

    x = rearrange(stft_repr, "b f t c -> b t (f c)")

    if self.debug and (torch.isnan(x).any() or torch.isinf(x).any()):
        raise RuntimeError(
            f"NaN/Inf in x after stft: {x.isnan().sum()} NaNs, {x.isinf().sum()} Infs"
        )

    if self.use_torch_checkpoint:
        x = checkpoint(self.band_split, x, use_reentrant=False)
    else:
        x = self.band_split(x)

    if self.debug and (torch.isnan(x).any() or torch.isinf(x).any()):
        raise RuntimeError(
            f"NaN/Inf in x after band_split: {x.isnan().sum()} NaNs, {x.isinf().sum()} Infs"
        )

    # axial / hierarchical attention

    store = [None] * len(self.layers)
    for i, transformer_block in enumerate(self.layers):
        if len(transformer_block) == 3:
            linear_transformer, time_transformer, freq_transformer = transformer_block

            x, ft_ps = pack([x], "b * d")
            if self.use_torch_checkpoint:
                x = checkpoint(linear_transformer, x, use_reentrant=False)
            else:
                x = linear_transformer(x)
            (x,) = unpack(x, ft_ps, "b * d")
        else:
            time_transformer, freq_transformer = transformer_block

        if self.skip_connection:
            # Sum all previous
            for j in range(i):
                x = x + store[j]

        x = rearrange(x, "b t f d -> b f t d")
        x, ps = pack([x], "* t d")

        if self.use_torch_checkpoint:
            x = checkpoint(time_transformer, x, use_reentrant=False)
        else:
            x = time_transformer(x)

        (x,) = unpack(x, ps, "* t d")
        x = rearrange(x, "b f t d -> b t f d")
        x, ps = pack([x], "* f d")

        if self.use_torch_checkpoint:
            x = checkpoint(freq_transformer, x, use_reentrant=False)
        else:
            x = freq_transformer(x)

        (x,) = unpack(x, ps, "* f d")

        if self.skip_connection:
            store[i] = x

    x = self.final_norm(x)

    num_stems = len(self.mask_estimators)

    if self.use_torch_checkpoint:
        mask = torch.stack(
            [checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators],
            dim=1,
        )
    else:
        mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
    mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2)

    # modulate frequency representation

    stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")

    # complex number multiplication

    stft_repr = torch.view_as_complex(stft_repr)
    mask = torch.view_as_complex(mask)

    stft_repr = stft_repr * mask

    # istft

    stft_repr = rearrange(stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels)

    # same as torch.stft() fix for MacOS MPS above
    try:
        recon_audio = torch.istft(
            stft_repr,
            **self.stft_kwargs,
            window=stft_window,
            return_complex=False,
            length=raw_audio.shape[-1],
        )
    except RuntimeError:
        recon_audio = torch.istft(
            stft_repr.cpu() if x_is_mps else stft_repr,
            **self.stft_kwargs,
            window=stft_window.cpu() if x_is_mps else stft_window,
            return_complex=False,
            length=raw_audio.shape[-1],
        ).to(device)

    recon_audio = rearrange(
        recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=num_stems
    )

    if num_stems == 1:
        recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")

    # if a target is passed in, calculate loss for learning

    if target is None:
        return recon_audio

    if self.num_stems > 1:
        assert target.ndim == 4 and target.shape[1] == self.num_stems

    if target.ndim == 2:
        target = rearrange(target, "... t -> ... 1 t")

    target = target[..., : recon_audio.shape[-1]]  # protect against lost length on istft

    loss = F.l1_loss(recon_audio, target)

    multi_stft_resolution_loss = 0.0

    for window_size in self.multi_stft_resolutions_window_sizes:
        res_stft_kwargs = dict(
            n_fft=max(
                window_size, self.multi_stft_n_fft
            ),  # not sure what n_fft is across multi resolution stft
            win_length=window_size,
            return_complex=True,
            window=self.multi_stft_window_fn(window_size, device=device),
            **self.multi_stft_kwargs,
        )

        recon_Y = torch.stft(rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs)
        target_Y = torch.stft(rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs)

        multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)

    weighted_multi_resolution_loss = (
        multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
    )

    total_loss = loss + weighted_multi_resolution_loss

    if not return_loss_breakdown:
        return total_loss

    return total_loss, (loss, multi_stft_resolution_loss)
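
A minimal inference sketch (dim and depth are illustrative; real checkpoints define their own values):

import torch
from splifft.models.bs_roformer import BSRoformer, BSRoformerConfig

cfg = BSRoformerConfig(
    chunk_size=588800,
    output_stem_names=("vocals", "instrumental"),
    dim=384,
    depth=6,
    stereo=True,
)
model = BSRoformer(cfg).eval()

chunk = torch.randn(1, 2, cfg.chunk_size)  # (b, s, t)
with torch.no_grad():
    stems = model(chunk)
assert stems.shape == (1, 2, 2, cfg.chunk_size)  # (b, n, s, t)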

utils

Modules:

Name Description
attend
attend_sage

Functions:

Name Description
parse_version
log_once

parse_version

parse_version(v_str: str) -> tuple[int, ...]
Source code in src/splifft/models/utils/__init__.py
def parse_version(v_str: str) -> tuple[int, ...]:
    # e.g "2.1.0+cu118" -> (2, 1, 0)
    return tuple(map(int, v_str.split("+")[0].split(".")))
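
For example:

from splifft.models.utils import parse_version

assert parse_version("2.1.0+cu118") == (2, 1, 0)
assert parse_version("2.1.0+cu118") >= (2, 0, 0)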

log_once cached

log_once(
    logger: Logger, msg: object, *, level: int = DEBUG
) -> None
Source code in src/splifft/models/utils/__init__.py
@lru_cache(10)
def log_once(logger: Logger, msg: object, *, level: int = logging.DEBUG) -> None:
    logger.log(level, msg)

attend_sage

Classes:

Name Description
AttendSage

Attributes:

Name Type Description
logger
logger module-attribute
logger = getLogger(__name__)
AttendSage
AttendSage(
    dropout: float = 0.0,
    flash: bool = False,
    scale: float | None = None,
)

Bases: Module

Parameters:

Name Type Description Default
flash bool

if True, attempts to use SageAttention or PyTorch SDPA.

False

Methods:

Name Description
forward

einstein notation

Attributes:

Name Type Description
scale
dropout
use_sage
use_pytorch_sdpa
attn_dropout
Source code in src/splifft/models/utils/attend_sage.py
def __init__(
    self,
    dropout: float = 0.0,
    flash: bool = False,
    scale: float | None = None,
):
    """
    :param flash: if True, attempts to use SageAttention or PyTorch SDPA.
    """
    super().__init__()
    self.scale = scale  # for einsum path
    self.dropout = dropout  # for einsum/SDPA path

    self.use_sage = flash and _has_sage_attention
    self.use_pytorch_sdpa = False
    self._sdpa_checked = False

    if flash and not self.use_sage:
        if not self._sdpa_checked:
            if parse_version(torch.__version__) >= (2, 0, 0):
                self.use_pytorch_sdpa = True
                log_once(
                    logger,
                    "Using PyTorch SDPA backend (FlashAttention-2, Memory-Efficient, or Math).",
                )
            else:
                log_once(
                    logger,
                    "Flash attention requested but Pytorch < 2.0 and SageAttention not found. Falling back to einsum.",
                )
            self._sdpa_checked = True

    # dropout layer for manual einsum implementation ONLY
    # SDPA and SageAttention handle dropout differently
    # (or not at all in Sage's base API)
    self.attn_dropout = nn.Dropout(dropout)
scale instance-attribute
scale = scale
dropout instance-attribute
dropout = dropout
use_sage instance-attribute
use_sage = flash and _has_sage_attention
use_pytorch_sdpa instance-attribute
use_pytorch_sdpa = False
attn_dropout instance-attribute
attn_dropout = Dropout(dropout)
forward
forward(q: Tensor, k: Tensor, v: Tensor) -> Tensor

einstein notation

  • b: batch
  • h: heads
  • n, i, j: sequence length (base sequence length, source, target)
  • d: feature dimension

Input tensors q, k, v expected in shape: (batch, heads, seq_len, dim_head) -> HND layout

Source code in src/splifft/models/utils/attend_sage.py
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
    """
    einstein notation

    - b: batch
    - h: heads
    - n, i, j: sequence length (base sequence length, source, target)
    - d: feature dimension

    Input tensors q, k, v expected in shape: (batch, heads, seq_len, dim_head) -> HND layout
    """
    _q_len, _k_len, _device = q.shape[-2], k.shape[-2], q.device

    # priority 1: SageAttention
    if self.use_sage:
        # assumes q, k, v are FP16/BF16 (handled by autocast upstream)
        # assumes scale is handled internally by sageattn
        # assumes dropout is NOT handled by sageattn kernel
        # is_causal=False based on how Attend is called in mel_band_roformer
        try:
            out = sageattn(q, k, v, tensor_layout="HND", is_causal=False)  # type: ignore
            return out  # type: ignore
        except Exception as e:
            logger.error(f"SageAttention failed with error: {e}. Falling back.")
            self.use_sage = False
            if not self._sdpa_checked:
                if parse_version(torch.__version__) >= (2, 0, 0):
                    self.use_pytorch_sdpa = True
                    log_once(logger, "falling back to PyTorch SDPA")
                else:
                    log_once(logger, "falling back to einsum.")

                self._sdpa_checked = True

    # priority 2: PyTorch SDPA
    if self.use_pytorch_sdpa:
        # it handles scaling and dropout internally.
        try:
            with sdpa_kernel(
                [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
            ):
                out = F.scaled_dot_product_attention(
                    q,
                    k,
                    v,
                    attn_mask=None,  # assuming no explicit mask needed here
                    dropout_p=self.dropout if self.training else 0.0,
                    is_causal=False,  # assuming not needed based on usage context
                )
            return out
        except Exception as e:
            log_once(
                logger,
                f"pytorch SDPA failed with error: {e}. falling back to einsum.",
                level=logging.ERROR,
            )
            self.use_pytorch_sdpa = False

    scale = self.scale or q.shape[-1] ** -0.5

    # similarity
    sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale

    # attention
    attn = sim.softmax(dim=-1)
    attn = self.attn_dropout(attn)  # ONLY in einsum path

    # aggregate values
    out = einsum("b h i j, b h j d -> b h i d", attn, v)

    return out
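
A minimal usage sketch exercising the einsum fallback path (flash=False):

import torch
from splifft.models.utils.attend_sage import AttendSage

attend = AttendSage(flash=False)
q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, dim_head)
assert attend(q, k, v).shape == (1, 8, 16, 64)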

attend

Classes:

Name Description
Attend

Attributes:

Name Type Description
logger
logger module-attribute
logger = getLogger(__name__)
Attend
Attend(
    dropout: float = 0.0,
    flash: bool = False,
    scale: float | None = None,
)

Bases: Module

Methods:

Name Description
flash_attn
forward

einstein notation

Attributes:

Name Type Description
scale
dropout
attn_dropout
flash
cpu_backends
cuda_backends list[_SDPBackend] | None
Source code in src/splifft/models/utils/attend.py
def __init__(
    self, dropout: float = 0.0, flash: bool = False, scale: float | None = None
) -> None:
    super().__init__()
    self.scale = scale
    self.dropout = dropout
    self.attn_dropout = nn.Dropout(dropout)

    self.flash = flash
    assert not (flash and parse_version(torch.__version__) < (2, 0, 0)), (
        "expected pytorch >= 2.0.0 to use flash attention"
    )

    # determine efficient attention configs for cuda and cpu
    self.cpu_backends = [
        SDPBackend.FLASH_ATTENTION,
        SDPBackend.EFFICIENT_ATTENTION,
        SDPBackend.MATH,
    ]
    self.cuda_backends: list[_SDPBackend] | None = None

    if not torch.cuda.is_available() or not flash:
        return

    device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
    device_version = parse_version(f"{device_properties.major}.{device_properties.minor}")

    if device_version >= (8, 0):
        if os.name == "nt":
            cuda_backends = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
            log_once(logger, f"windows detected, using {cuda_backends=}")
        else:
            cuda_backends = [SDPBackend.FLASH_ATTENTION]
            log_once(logger, f"gpu compute capability >= 8.0, using {cuda_backends=}")
    else:
        cuda_backends = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
        log_once(logger, f"gpu compute capability < 8.0, using {cuda_backends=}")

    self.cuda_backends = cuda_backends
scale instance-attribute
scale = scale
dropout instance-attribute
dropout = dropout
attn_dropout instance-attribute
attn_dropout = Dropout(dropout)
flash instance-attribute
flash = flash
cpu_backends instance-attribute
cpu_backends = [FLASH_ATTENTION, EFFICIENT_ATTENTION, MATH]
cuda_backends instance-attribute
cuda_backends: list[_SDPBackend] | None = cuda_backends
flash_attn
flash_attn(q: Tensor, k: Tensor, v: Tensor) -> Tensor
Source code in src/splifft/models/utils/attend.py
def flash_attn(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
    _, _heads, _q_len, _, _k_len, is_cuda, _device = (
        *q.shape,
        k.shape[-2],
        q.is_cuda,
        q.device,
    )  # type: ignore

    if self.scale is not None:
        default_scale = q.shape[-1] ** -0.5
        q = q * (self.scale / default_scale)

    backends = self.cuda_backends if is_cuda else self.cpu_backends
    # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
    with sdpa_kernel(backends=backends):  # type: ignore
        out = F.scaled_dot_product_attention(
            q, k, v, dropout_p=self.dropout if self.training else 0.0
        )

    return out
forward
forward(q: Tensor, k: Tensor, v: Tensor) -> Tensor

einstein notation

  • b: batch
  • h: heads
  • n, i, j: sequence length (base sequence length, source, target)
  • d: feature dimension
Source code in src/splifft/models/utils/attend.py
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
    """
    einstein notation

    - b: batch
    - h: heads
    - n, i, j: sequence length (base sequence length, source, target)
    - d: feature dimension
    """
    _q_len, _k_len, _device = q.shape[-2], k.shape[-2], q.device

    scale = self.scale or q.shape[-1] ** -0.5

    if self.flash:
        return self.flash_attn(q, k, v)

    # similarity
    sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale

    # attention
    attn = sim.softmax(dim=-1)
    attn = self.attn_dropout(attn)

    # aggregate values
    out = einsum("b h i j, b h j d -> b h i d", attn, v)

    return out