
Models

models

Audio model implementations: source separation, pitch estimation, and beat tracking.

Modules:

Name Description
basic_pitch

ICASSP 2022 Basic Pitch. Raw multi-stream outputs only, no symbolic decoding.

beat_this

Beat This! Beat Tracker.

bs_roformer

Band-Split RoPE Transformer.

mdx23c

MDX23C.

pesto

PESTO: Pitch Estimation with Self-supervised Transposition-equivariant Objective.

utils

Classes:

Name Description
ModelParamsLike

A trait that must be implemented to be considered a model parameter.

StemSelectionPlan

Optional model-specific plan for selective stem inference.

SupportsStemSelection
ModelMetadata

Metadata about a model, including its type, parameter class, and model class.

Attributes:

Name Type Description
ModelT
ModelParamsLikeT
StateDictTransform TypeAlias

ModelParamsLike

Bases: Protocol

A trait that must be implemented to be considered a model parameter. Note that input_type and output_type belong to a model's definition and cannot be modified via the configuration dictionary.

Attributes:

Name Type Description
chunk_size ChunkSize
output_stem_names tuple[ModelOutputStemName, ...]
input_channels ModelInputChannels
input_type ModelInputType
output_type ModelOutputType
inference_archetype InferenceArchetype

chunk_size instance-attribute

chunk_size: ChunkSize

output_stem_names instance-attribute

output_stem_names: tuple[ModelOutputStemName, ...]

input_channels property

input_channels: ModelInputChannels

input_type property

input_type: ModelInputType

output_type property

output_type: ModelOutputType

inference_archetype property

inference_archetype: InferenceArchetype
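As a hedged illustration (not taken from the library), a minimal dataclass satisfying this protocol's shape could look like the sketch below; the plain int/str types stand in for splifft's ChunkSize, ModelInputType, and related aliases:

```python
# Hypothetical ModelParamsLike implementation; int/str stand in for
# splifft's ChunkSize / ModelInputType / ... types for illustration only.
from dataclasses import dataclass


@dataclass
class ToyParams:
    chunk_size: int
    output_stem_names: tuple[str, ...]

    @property
    def input_channels(self) -> int:
        return 2  # stereo input

    @property
    def input_type(self) -> str:
        return "waveform"

    @property
    def output_type(self) -> str:
        return "waveform"

    @property
    def inference_archetype(self) -> str:
        return "chunked"


params = ToyParams(chunk_size=441000, output_stem_names=("vocals", "other"))
```

The properties are read-only by design, matching the note above that input_type and output_type belong to the model definition rather than the configuration dictionary.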

ModelT module-attribute

ModelT = TypeVar('ModelT', bound=Module)

ModelParamsLikeT module-attribute

ModelParamsLikeT = TypeVar(
    "ModelParamsLikeT", bound=ModelParamsLike
)

StateDictTransform module-attribute

StateDictTransform: TypeAlias = Callable[
    [dict[str, Tensor]], dict[str, Tensor]
]
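For example, a StateDictTransform used for selective stem inference might drop checkpoint entries belonging to unselected heads. The sketch below is illustrative only: the `heads.<stem>.` key layout is a hypothetical naming scheme, and plain values stand in for tensors.

```python
# Hypothetical StateDictTransform: keeps shared weights plus only the
# selected stem heads. Key names are assumed for illustration; real
# checkpoints use model-specific naming (and torch.Tensor values).
from typing import TypeVar

V = TypeVar("V")


def make_stem_filter(keep_prefixes: tuple[str, ...]):
    def transform(state_dict: dict[str, V]) -> dict[str, V]:
        return {
            key: value
            for key, value in state_dict.items()
            # keep everything outside `heads.`, plus the selected heads
            if not key.startswith("heads.") or key.startswith(keep_prefixes)
        }

    return transform


transform = make_stem_filter(("heads.vocals.",))
checkpoint = {"encoder.weight": 1, "heads.vocals.weight": 2, "heads.bass.weight": 3}
filtered = transform(checkpoint)
```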

StemSelectionPlan dataclass

StemSelectionPlan(
    model_params: ModelParamsLikeT,
    output_stem_names: tuple[ModelOutputStemName, ...],
    state_dict_transform: StateDictTransform | None = None,
)

Bases: Generic[ModelParamsLikeT]

Optional model-specific plan for selective stem inference.

Models can provide this plan to:

- instantiate a stem-reduced parameter set, and/or
- return a checkpoint state-dict transformer that drops/remaps unrelated heads.

output_stem_names defines the output ordering produced by the instantiated model after applying the plan.

Attributes:

Name Type Description
model_params ModelParamsLikeT
output_stem_names tuple[ModelOutputStemName, ...]
state_dict_transform StateDictTransform | None

model_params instance-attribute

model_params: ModelParamsLikeT

output_stem_names instance-attribute

output_stem_names: tuple[ModelOutputStemName, ...]

state_dict_transform class-attribute instance-attribute

state_dict_transform: StateDictTransform | None = None

SupportsStemSelection

Bases: Protocol[ModelParamsLikeT]

Methods:

Name Description
__splifft_stem_selection_plan__

__splifft_stem_selection_plan__ classmethod

__splifft_stem_selection_plan__(
    model_params: ModelParamsLikeT,
    output_stem_names: tuple[ModelOutputStemName, ...],
) -> StemSelectionPlan[ModelParamsLikeT]
Source code in src/splifft/models/__init__.py
@classmethod
def __splifft_stem_selection_plan__(
    cls,
    model_params: ModelParamsLikeT,
    output_stem_names: tuple[t.ModelOutputStemName, ...],
) -> StemSelectionPlan[ModelParamsLikeT]: ...

ModelMetadata dataclass

ModelMetadata(
    model_type: ModelType,
    params: type[ModelParamsLikeT],
    model: type[ModelT],
)

Bases: Generic[ModelT, ModelParamsLikeT]

Metadata about a model, including its type, parameter class, and model class.

Methods:

Name Description
from_module

Dynamically import a model named X and its parameter dataclass XParams under a given module name.

Attributes:

Name Type Description
model_type ModelType
params type[ModelParamsLikeT]
model type[ModelT]

model_type instance-attribute

model_type: ModelType

params instance-attribute

model instance-attribute

model: type[ModelT]

from_module classmethod

from_module(
    module_name: str,
    model_cls_name: str,
    *,
    model_type: ModelType,
    package: str | None = None,
) -> ModelMetadata[Module, ModelParamsLike]

Dynamically import a model named X and its parameter dataclass XParams under a given module name (e.g. splifft.models.bs_roformer).

Parameters:

Name Type Description Default
model_cls_name str

The name of the model class to import, e.g. BSRoformer.

required
module_name str

The name of the module to import, e.g. splifft.models.bs_roformer.

required
model_type ModelType

The type of the model, e.g. bs_roformer.

required
package str | None

The package to use as the anchor point from which to resolve the relative import to an absolute import. This is only required when performing a relative import.

None
Source code in src/splifft/models/__init__.py
@classmethod
def from_module(
    cls,
    module_name: str,
    model_cls_name: str,
    *,
    model_type: t.ModelType,
    package: str | None = None,
) -> ModelMetadata[nn.Module, ModelParamsLike]:
    """
    Dynamically import a model named `X` and its parameter dataclass `XParams` under a
    given module name (e.g. `splifft.models.bs_roformer`).

    :param model_cls_name: The name of the model class to import, e.g. `BSRoformer`.
    :param module_name: The name of the module to import, e.g. `splifft.models.bs_roformer`.
    :param model_type: The type of the model, e.g. `bs_roformer`.
    :param package: The package to use as the anchor point from which to resolve the
        relative import to an absolute import. This is only required when performing a relative import.
    """
    _loc = f"{module_name=} under {package=}"
    try:
        module = importlib.import_module(module_name, package)
    except ImportError as e:
        raise ValueError(f"failed to find or import module for {_loc}") from e

    params_cls_name = f"{model_cls_name}Params"
    model_cls = getattr(module, model_cls_name, None)
    params_cls = getattr(module, params_cls_name, None)
    if model_cls is None or params_cls is None:
        raise AttributeError(
            f"expected to find classes named `{model_cls_name}` and `{params_cls_name}` "
            f"in {_loc}, but at least one was not found."
        )

    return ModelMetadata(
        model_type=model_type,
        model=model_cls,
        params=params_cls,
    )
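The X/XParams naming convention that from_module relies on can be demonstrated with importlib alone; here the stdlib collections module stands in for a model module:

```python
import importlib


def resolve_pair(module_name: str, model_cls_name: str):
    # mirrors from_module's lookup: import the module, then fetch the model
    # class and its `<Name>Params` companion by naming convention
    module = importlib.import_module(module_name)
    model_cls = getattr(module, model_cls_name, None)
    params_cls = getattr(module, f"{model_cls_name}Params", None)
    return model_cls, params_cls


model_cls, params_cls = resolve_pair("collections", "OrderedDict")
# collections.OrderedDict exists but OrderedDictParams does not, so
# from_module would raise AttributeError for this module
```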

pesto

PESTO: Pitch Estimation with Self-supervised Transposition-equivariant Objective.

See: https://github.com/SonyCSLParis/pesto, https://arxiv.org/abs/2309.02265

Classes:

Name Description
PestoParams
ToeplitzLinear
Resnet1d

Compact 1D CNN used by PESTO to decode HCQT frames into activations.

ConfidenceClassifier

Frame-level voiced/unvoiced confidence head.

Pesto

PESTO inference head over externally computed HCQT features.

Functions:

Name Description
reduce_activations

Reduce per-bin probabilities to scalar pitch per frame.

PestoParams dataclass

PestoParams(
    chunk_size: ChunkSize,
    output_stem_names: tuple[ModelOutputStemName, ...],
    reduction: Literal["argmax", "mean", "alwa"] = "alwa",
    convert_to_freq: bool = True,
    crop_freq_bins_bottom: Ge0[int] = 16,
    crop_freq_bins_top: Ge0[int] = 16,
    n_chan_input: Gt0[int] = 1,
    n_chan_layers: tuple[Gt0[int], ...] = (
        40,
        30,
        30,
        10,
        3,
    ),
    n_prefilt_layers: Gt0[int] = 3,
    prefilt_kernel_size: Gt0[int] = 39,
    residual: bool = True,
    n_bins_in: Gt0[int] = 219,
    output_dim: Gt0[int] = 384,
    activation_fn: Literal[
        "relu", "silu", "leaky"
    ] = "leaky",
    a_lrelu: Ge0[float] = 0.3,
    p_dropout: Dropout = 0.2,
    bins_per_semitone: Gt0[int] = 3,
)

Bases: ModelParamsLike

Attributes:

Name Type Description
chunk_size ChunkSize
output_stem_names tuple[ModelOutputStemName, ...]
reduction Literal['argmax', 'mean', 'alwa']
convert_to_freq bool
crop_freq_bins_bottom Ge0[int]
crop_freq_bins_top Ge0[int]
n_chan_input Gt0[int]
n_chan_layers tuple[Gt0[int], ...]
n_prefilt_layers Gt0[int]
prefilt_kernel_size Gt0[int]
residual bool
n_bins_in Gt0[int]
output_dim Gt0[int]
activation_fn Literal['relu', 'silu', 'leaky']
a_lrelu Ge0[float]
p_dropout Dropout
bins_per_semitone Gt0[int]
input_channels ModelInputChannels
input_type ModelInputType
output_type ModelOutputType
inference_archetype InferenceArchetype
chunk_size instance-attribute
chunk_size: ChunkSize
output_stem_names instance-attribute
output_stem_names: tuple[ModelOutputStemName, ...]
reduction class-attribute instance-attribute
reduction: Literal['argmax', 'mean', 'alwa'] = 'alwa'
convert_to_freq class-attribute instance-attribute
convert_to_freq: bool = True
crop_freq_bins_bottom class-attribute instance-attribute
crop_freq_bins_bottom: Ge0[int] = 16
crop_freq_bins_top class-attribute instance-attribute
crop_freq_bins_top: Ge0[int] = 16
n_chan_input class-attribute instance-attribute
n_chan_input: Gt0[int] = 1
n_chan_layers class-attribute instance-attribute
n_chan_layers: tuple[Gt0[int], ...] = (40, 30, 30, 10, 3)
n_prefilt_layers class-attribute instance-attribute
n_prefilt_layers: Gt0[int] = 3
prefilt_kernel_size class-attribute instance-attribute
prefilt_kernel_size: Gt0[int] = 39
residual class-attribute instance-attribute
residual: bool = True
n_bins_in class-attribute instance-attribute
n_bins_in: Gt0[int] = 219
output_dim class-attribute instance-attribute
output_dim: Gt0[int] = 384
activation_fn class-attribute instance-attribute
activation_fn: Literal['relu', 'silu', 'leaky'] = 'leaky'
a_lrelu class-attribute instance-attribute
a_lrelu: Ge0[float] = 0.3
p_dropout class-attribute instance-attribute
p_dropout: Dropout = 0.2
bins_per_semitone class-attribute instance-attribute
bins_per_semitone: Gt0[int] = 3
input_channels property
input_channels: ModelInputChannels
input_type property
input_type: ModelInputType
output_type property
output_type: ModelOutputType
inference_archetype property
inference_archetype: InferenceArchetype

ToeplitzLinear

ToeplitzLinear(in_features: int, out_features: int)

Bases: Conv1d

Methods:

Name Description
forward
Source code in src/splifft/models/pesto.py
def __init__(self, in_features: int, out_features: int):
    super().__init__(
        in_channels=1,
        out_channels=1,
        kernel_size=in_features + out_features - 1,
        padding=out_features - 1,
        bias=False,
    )
forward
forward(input: Tensor) -> Tensor
Source code in src/splifft/models/pesto.py
def forward(self, input: torch.Tensor) -> torch.Tensor:
    return super().forward(input.unsqueeze(-2)).squeeze(-2)
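The constructor's kernel-size and padding choices are what make this a Toeplitz-structured linear map: for an input of length in_features, the standard Conv1d output-length formula (stride 1, no dilation) yields exactly out_features. A quick check with PESTO's defaults (219 bins, 384 outputs):

```python
# Conv1d output length at stride=1, dilation=1: L_out = L_in + 2*padding - kernel_size + 1
def conv1d_out_len(l_in: int, kernel_size: int, padding: int) -> int:
    return l_in + 2 * padding - kernel_size + 1


in_features, out_features = 219, 384
kernel_size = in_features + out_features - 1  # 602, as in ToeplitzLinear.__init__
padding = out_features - 1  # 383
out_len = conv1d_out_len(in_features, kernel_size, padding)
# out_len == 384 == out_features
```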

Resnet1d

Resnet1d(
    *,
    n_chan_input: int = 1,
    n_chan_layers: tuple[int, ...] = (40, 30, 30, 10, 3),
    n_prefilt_layers: int = 3,
    prefilt_kernel_size: int = 39,
    residual: bool = True,
    n_bins_in: int = 219,
    output_dim: int = 384,
    activation_fn: Literal[
        "relu", "silu", "leaky"
    ] = "leaky",
    a_lrelu: float = 0.3,
    p_dropout: float = 0.2,
)

Bases: Module

Compact 1D CNN used by PESTO to decode HCQT frames into activations.

Methods:

Name Description
forward

Attributes:

Name Type Description
layernorm
conv1
n_prefilt_layers
prefilt_layers
residual
conv_layers
flatten
fc
final_norm
Source code in src/splifft/models/pesto.py
def __init__(
    self,
    *,
    n_chan_input: int = 1,
    n_chan_layers: tuple[int, ...] = (40, 30, 30, 10, 3),
    n_prefilt_layers: int = 3,
    prefilt_kernel_size: int = 39,
    residual: bool = True,
    n_bins_in: int = 219,
    output_dim: int = 384,
    activation_fn: Literal["relu", "silu", "leaky"] = "leaky",
    a_lrelu: float = 0.3,
    p_dropout: float = 0.2,
):
    super().__init__()

    activation_layer: Callable[[], nn.Module]
    if activation_fn == "relu":
        activation_layer = nn.ReLU
    elif activation_fn == "silu":
        activation_layer = nn.SiLU
    elif activation_fn == "leaky":
        activation_layer = partial(nn.LeakyReLU, negative_slope=a_lrelu)
    else:
        raise ValueError(f"unsupported activation_fn={activation_fn!r}")

    n_ch = list(n_chan_layers)
    if len(n_ch) < 5:
        n_ch.append(1)

    self.layernorm = nn.LayerNorm(normalized_shape=[n_chan_input, n_bins_in])

    prefilt_padding = prefilt_kernel_size // 2
    self.conv1 = nn.Sequential(
        nn.Conv1d(
            in_channels=n_chan_input,
            out_channels=n_ch[0],
            kernel_size=prefilt_kernel_size,
            padding=prefilt_padding,
            stride=1,
        ),
        activation_layer(),
        nn.Dropout(p=p_dropout),
    )
    self.n_prefilt_layers = n_prefilt_layers
    self.prefilt_layers = nn.ModuleList(
        [
            nn.Sequential(
                nn.Conv1d(
                    in_channels=n_ch[0],
                    out_channels=n_ch[0],
                    kernel_size=prefilt_kernel_size,
                    padding=prefilt_padding,
                    stride=1,
                ),
                activation_layer(),
                nn.Dropout(p=p_dropout),
            )
            for _ in range(n_prefilt_layers - 1)
        ]
    )
    self.residual = residual

    conv_layers: list[nn.Module] = []
    for i in range(len(n_ch) - 1):
        conv_layers.extend(
            [
                nn.Conv1d(
                    in_channels=n_ch[i],
                    out_channels=n_ch[i + 1],
                    kernel_size=1,
                    padding=0,
                    stride=1,
                ),
                activation_layer(),
                nn.Dropout(p=p_dropout),
            ]
        )
    self.conv_layers = nn.Sequential(*conv_layers)

    self.flatten = nn.Flatten(start_dim=1)
    self.fc = ToeplitzLinear(n_bins_in * n_ch[-1], output_dim)
    self.final_norm = nn.Softmax(dim=-1)
layernorm instance-attribute
layernorm = LayerNorm(
    normalized_shape=[n_chan_input, n_bins_in]
)
conv1 instance-attribute
conv1 = Sequential(
    Conv1d(
        in_channels=n_chan_input,
        out_channels=n_ch[0],
        kernel_size=prefilt_kernel_size,
        padding=prefilt_padding,
        stride=1,
    ),
    activation_layer(),
    Dropout(p=p_dropout),
)
n_prefilt_layers instance-attribute
n_prefilt_layers = n_prefilt_layers
prefilt_layers instance-attribute
prefilt_layers = ModuleList(
    [
        (
            Sequential(
                Conv1d(
                    in_channels=n_ch[0],
                    out_channels=n_ch[0],
                    kernel_size=prefilt_kernel_size,
                    padding=prefilt_padding,
                    stride=1,
                ),
                activation_layer(),
                Dropout(p=p_dropout),
            )
        )
        for _ in (range(n_prefilt_layers - 1))
    ]
)
residual instance-attribute
residual = residual
conv_layers instance-attribute
conv_layers = Sequential(*conv_layers)
flatten instance-attribute
flatten = Flatten(start_dim=1)
fc instance-attribute
fc = ToeplitzLinear(n_bins_in * n_ch[-1], output_dim)
final_norm instance-attribute
final_norm = Softmax(dim=-1)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/pesto.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    x = self.layernorm(x)

    x = self.conv1(x)
    for i in range(0, self.n_prefilt_layers - 1):
        prefilt_layer = self.prefilt_layers[i]
        if self.residual:
            x = prefilt_layer(x) + x
        else:
            x = prefilt_layer(x)

    x = self.conv_layers(x)
    x = self.flatten(x)
    y_pred = self.fc(x)
    return cast(torch.Tensor, self.final_norm(y_pred))

ConfidenceClassifier

ConfidenceClassifier()

Bases: Module

Frame-level voiced/unvoiced confidence head.

Methods:

Name Description
forward

Attributes:

Name Type Description
conv
linear
Source code in src/splifft/models/pesto.py
def __init__(self) -> None:
    super().__init__()
    self.conv = nn.Conv1d(1, 1, 39, stride=3)
    self.linear = nn.Linear(72, 1)
conv instance-attribute
conv = Conv1d(1, 1, 39, stride=3)
linear instance-attribute
linear = Linear(72, 1)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/pesto.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    geometric_mean = x.log().mean(dim=-1, keepdim=True).exp()
    arithmetic_mean = x.mean(dim=-1, keepdim=True).clamp_(min=1e-8)
    flatness = geometric_mean / arithmetic_mean

    x = F.relu(self.conv(x.unsqueeze(1)).squeeze(1))
    return torch.sigmoid(self.linear(torch.cat((x, flatness), dim=-1))).squeeze(-1)

reduce_activations

reduce_activations(
    activations: Tensor, reduction: str = "alwa"
) -> Tensor

Reduce per-bin probabilities to scalar pitch per frame.

Source code in src/splifft/models/pesto.py
def reduce_activations(activations: torch.Tensor, reduction: str = "alwa") -> torch.Tensor:
    """Reduce per-bin probabilities to scalar pitch per frame."""
    device = activations.device
    num_bins = int(activations.size(-1))

    bps, rem = divmod(num_bins, 128)
    if rem != 0:
        raise ValueError(f"expected output_dim to be divisible by 128, got {num_bins}")

    if reduction == "argmax":
        pred = activations.argmax(dim=-1)
        return pred.float() / bps

    all_pitches = torch.arange(num_bins, dtype=torch.float, device=device).div_(bps)
    if reduction == "mean":
        return torch.matmul(activations, all_pitches)

    if reduction == "alwa":
        center_bin = activations.argmax(dim=-1, keepdim=True)
        window = torch.arange(1, 2 * bps, device=device) - bps
        indices = (center_bin + window).clamp_(min=0, max=num_bins - 1)
        cropped_activations = activations.gather(-1, indices)
        cropped_pitches = all_pitches.unsqueeze(0).expand_as(activations).gather(-1, indices)
        return (cropped_activations * cropped_pitches).sum(dim=-1) / cropped_activations.sum(dim=-1)

    raise ValueError(f"unknown reduction={reduction!r}")
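The "argmax" branch can be sketched in pure Python: with output_dim = 384 there are 384 / 128 = 3 bins per semitone, so the index of the winning bin divided by 3 gives the pitch in semitones. (This is an illustrative re-implementation of that one branch, not the torch code above.)

```python
def argmax_pitch(activations: list[float], num_bins: int = 384) -> float:
    bps, rem = divmod(num_bins, 128)
    assert rem == 0, "output_dim must be divisible by 128"
    # index of the most probable bin, divided by bins-per-semitone
    best_bin = max(range(len(activations)), key=activations.__getitem__)
    return best_bin / bps


acts = [0.0] * 384
acts[207] = 1.0  # bin 207 / 3 bins-per-semitone -> pitch 69.0 semitones
```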

Pesto

Pesto(cfg: PestoParams)

Bases: Module

PESTO inference head over externally computed HCQT features.

Input contract: tensor of shape (batch, time, feature_dim) where feature_dim = harmonics * freq_bins in dB log-magnitude HCQT.
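With the default PestoParams, this contract works out to feature_dim = n_chan_input * (n_bins_in + crop_freq_bins_bottom + crop_freq_bins_top) = 1 * (219 + 16 + 16) = 251, which forward validates before reshaping:

```python
def pesto_feature_dim(
    n_chan_input: int = 1,
    n_bins_in: int = 219,
    crop_bottom: int = 16,
    crop_top: int = 16,
) -> int:
    # mirrors the expected_feature_dim check in Pesto.forward
    return n_chan_input * (n_bins_in + crop_bottom + crop_top)
```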

Methods:

Name Description
forward

Attributes:

Name Type Description
cfg
encoder
confidence
Source code in src/splifft/models/pesto.py
def __init__(self, cfg: PestoParams):
    super().__init__()
    self.cfg = cfg
    self.encoder = Resnet1d(
        n_chan_input=cfg.n_chan_input,
        n_chan_layers=cfg.n_chan_layers,
        n_prefilt_layers=cfg.n_prefilt_layers,
        prefilt_kernel_size=cfg.prefilt_kernel_size,
        residual=cfg.residual,
        n_bins_in=cfg.n_bins_in,
        output_dim=cfg.output_dim,
        activation_fn=cfg.activation_fn,
        a_lrelu=cfg.a_lrelu,
        p_dropout=cfg.p_dropout,
    )
    self.confidence = ConfidenceClassifier()

    self.register_buffer("shift", torch.zeros((), dtype=torch.float), persistent=True)
cfg instance-attribute
cfg = cfg
encoder instance-attribute
encoder = Resnet1d(
    n_chan_input=n_chan_input,
    n_chan_layers=n_chan_layers,
    n_prefilt_layers=n_prefilt_layers,
    prefilt_kernel_size=prefilt_kernel_size,
    residual=residual,
    n_bins_in=n_bins_in,
    output_dim=output_dim,
    activation_fn=activation_fn,
    a_lrelu=a_lrelu,
    p_dropout=p_dropout,
)
confidence instance-attribute
confidence = ConfidenceClassifier()
forward
forward(x: Tensor) -> dict[str, Tensor]
Source code in src/splifft/models/pesto.py
def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
    if x.ndim != 3:
        raise ValueError(
            f"expected `(batch,time,feature_dim)` input, got shape={tuple(x.shape)}"
        )

    total_bins = (
        self.cfg.n_bins_in + self.cfg.crop_freq_bins_bottom + self.cfg.crop_freq_bins_top
    )
    expected_feature_dim = self.cfg.n_chan_input * total_bins
    if x.shape[-1] != expected_feature_dim:
        raise ValueError(
            "invalid PESTO feature dimension: "
            f"expected {expected_feature_dim} (= n_chan_input * (n_bins_in + crop_bottom + crop_top)), "
            f"got {x.shape[-1]}"
        )

    batch_size, num_frames, _feature_dim = x.shape
    x = x.view(batch_size, num_frames, self.cfg.n_chan_input, total_bins)
    x = x.flatten(0, 1)  # (B*T, H, F)

    # match reference implementation: convert dB back to linear energy and derive
    # confidence + volume from pre-crop HCQT bins.
    energy = x.mul(torch.log(torch.tensor(10.0, device=x.device, dtype=x.dtype)) / 10.0).exp()
    confidence_energy = energy.mean(dim=1)
    volume = energy.sum(dim=-1).mean(dim=-1)
    confidence = self.confidence(confidence_energy)

    x = self._crop_cqt(x)
    activations = self.encoder(x)

    activations = activations.view(batch_size, num_frames, activations.size(-1))
    confidence = confidence.view(batch_size, num_frames)
    volume = volume.view(batch_size, num_frames)

    shift_tensor = cast(torch.Tensor, self.shift)
    shift_bins = int(torch.round(shift_tensor * self.cfg.bins_per_semitone).item())
    activations = activations.roll(-shift_bins, dims=-1)

    pitch = reduce_activations(activations, reduction=self.cfg.reduction)
    if self.cfg.convert_to_freq:
        pitch = 440 * 2 ** ((pitch - 69) / 12)

    return {
        "pitch": pitch,
        "confidence": confidence,
        "volume": volume,
        "activations": activations,
    }
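The final `440 * 2 ** ((pitch - 69) / 12)` step applied when convert_to_freq is enabled is the standard equal-tempered MIDI-to-frequency conversion (A4 = note 69 = 440 Hz, one octave per 12 semitones):

```python
def midi_to_hz(pitch: float) -> float:
    # equal temperament: doubling frequency per 12 semitones, anchored at A4
    return 440.0 * 2.0 ** ((pitch - 69.0) / 12.0)
```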

beat_this

Beat This! Beat Tracker.

Classes:

Name Description
BeatThisParams
PartialFTTransformer

Takes a (batch, channels, freqs, time) input, applies self-attention and

SumHead
Head
BeatThis

BeatThisParams dataclass

BeatThisParams(
    chunk_size: ChunkSize,
    output_stem_names: tuple[ModelOutputStemName, ...],
    spect_dim: Gt0[int] = 128,
    transformer_dim: Gt0[int] = 512,
    ff_mult: Gt0[int] = 4,
    n_layers: Gt0[int] = 6,
    head_dim: Gt0[int] = 32,
    stem_dim: Gt0[int] = 32,
    dropout_frontend: Dropout = 0.1,
    dropout_transformer: Dropout = 0.2,
    sum_head: bool = True,
    partial_transformers: bool = True,
    rotary_embed_dtype: TorchDtype | None = None,
    transformer_residual_dtype: TorchDtype | None = None,
    log_mel_hop_length: HopSize = 441,
)

Bases: ModelParamsLike

Attributes:

Name Type Description
chunk_size ChunkSize
output_stem_names tuple[ModelOutputStemName, ...]
spect_dim Gt0[int]
transformer_dim Gt0[int]
ff_mult Gt0[int]
n_layers Gt0[int]
head_dim Gt0[int]
stem_dim Gt0[int]
dropout_frontend Dropout
dropout_transformer Dropout
sum_head bool
partial_transformers bool
rotary_embed_dtype TorchDtype | None
transformer_residual_dtype TorchDtype | None
log_mel_hop_length HopSize

The hop length of the log mel spectrogram.

input_channels ModelInputChannels
input_type ModelInputType
output_type ModelOutputType
inference_archetype InferenceArchetype
chunk_size instance-attribute
chunk_size: ChunkSize
output_stem_names instance-attribute
output_stem_names: tuple[ModelOutputStemName, ...]
spect_dim class-attribute instance-attribute
spect_dim: Gt0[int] = 128
transformer_dim class-attribute instance-attribute
transformer_dim: Gt0[int] = 512
ff_mult class-attribute instance-attribute
ff_mult: Gt0[int] = 4
n_layers class-attribute instance-attribute
n_layers: Gt0[int] = 6
head_dim class-attribute instance-attribute
head_dim: Gt0[int] = 32
stem_dim class-attribute instance-attribute
stem_dim: Gt0[int] = 32
dropout_frontend class-attribute instance-attribute
dropout_frontend: Dropout = 0.1
dropout_transformer class-attribute instance-attribute
dropout_transformer: Dropout = 0.2
sum_head class-attribute instance-attribute
sum_head: bool = True
partial_transformers class-attribute instance-attribute
partial_transformers: bool = True
rotary_embed_dtype class-attribute instance-attribute
rotary_embed_dtype: TorchDtype | None = None
transformer_residual_dtype class-attribute instance-attribute
transformer_residual_dtype: TorchDtype | None = None
log_mel_hop_length class-attribute instance-attribute
log_mel_hop_length: HopSize = 441

The hop length of the log mel spectrogram.

Warning

This must match the hop_length in the LogMelConfig to ensure the rotary embeddings are sized correctly for the sequence length.
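The coupling comes from BeatThis.__init__, where the rotary embedding sequence length is derived as chunk_size // log_mel_hop_length; with the defaults noted in the source, 661500 samples / 441 hop = 1500 frames:

```python
def rotary_seq_len(chunk_size: int, log_mel_hop_length: int = 441) -> int:
    # mirrors `max_frames = cfg.chunk_size // cfg.log_mel_hop_length`
    return chunk_size // log_mel_hop_length
```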

input_channels property
input_channels: ModelInputChannels
input_type property
input_type: ModelInputType
output_type property
output_type: ModelOutputType
inference_archetype property
inference_archetype: InferenceArchetype

PartialFTTransformer

PartialFTTransformer(
    dim: int,
    dim_head: int,
    n_head: int,
    rotary_embed_f: RotaryEmbedding,
    rotary_embed_t: RotaryEmbedding,
    dropout: float,
)

Bases: Module

Takes a (batch, channels, freqs, time) input, applies self-attention and a feed-forward block once across frequencies and once across time.

Methods:

Name Description
forward

Attributes:

Name Type Description
attnF
ffF
attnT
ffT
Source code in src/splifft/models/beat_this.py
def __init__(
    self,
    dim: int,
    dim_head: int,
    n_head: int,
    rotary_embed_f: RotaryEmbedding,
    rotary_embed_t: RotaryEmbedding,
    dropout: float,
):
    super().__init__()
    self.attnF = Attention(
        dim, heads=n_head, dim_head=dim_head, dropout=dropout, rotary_embed=rotary_embed_f
    )
    self.ffF = FeedForward(dim, dropout=dropout)
    self.attnT = Attention(
        dim, heads=n_head, dim_head=dim_head, dropout=dropout, rotary_embed=rotary_embed_t
    )
    self.ffT = FeedForward(dim, dropout=dropout)
attnF instance-attribute
attnF = Attention(
    dim,
    heads=n_head,
    dim_head=dim_head,
    dropout=dropout,
    rotary_embed=rotary_embed_f,
)
ffF instance-attribute
ffF = FeedForward(dim, dropout=dropout)
attnT instance-attribute
attnT = Attention(
    dim,
    heads=n_head,
    dim_head=dim_head,
    dropout=dropout,
    rotary_embed=rotary_embed_t,
)
ffT instance-attribute
ffT = FeedForward(dim, dropout=dropout)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/beat_this.py
def forward(self, x: Tensor) -> Tensor:
    b = len(x)
    x = rearrange(x, "b c f t -> (b t) f c")
    x = x + self.attnF(x)
    x = x + self.ffF(x)
    x = rearrange(x, "(b t) f c -> (b f) t c", b=b)
    x = x + self.attnT(x)
    x = x + self.ffT(x)
    x = rearrange(x, "(b f) t c -> b c f t", b=b)
    return x

SumHead

SumHead(input_dim: int)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
beat_downbeat_lin
Source code in src/splifft/models/beat_this.py
def __init__(self, input_dim: int):
    super().__init__()
    self.beat_downbeat_lin = nn.Linear(input_dim, 2)
beat_downbeat_lin instance-attribute
beat_downbeat_lin = Linear(input_dim, 2)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/beat_this.py
def forward(self, x: Tensor) -> Tensor:
    beat_downbeat = self.beat_downbeat_lin(x)
    beat, downbeat = rearrange(beat_downbeat, "b t c -> c b t", c=2)

    # aggregate beats and downbeats prediction
    # autocast to float16 disabled to avoid numerical issues causing NaNs
    device_type = beat.device.type
    if device_type != "mps" and torch.amp.is_autocast_available(device_type):  # type: ignore
        disable_autocast = torch.autocast(device_type, enabled=False)
    else:
        disable_autocast = contextlib.nullcontext()

    with disable_autocast:
        beat = beat.float() + downbeat.float()
    return torch.stack([beat, downbeat], dim=0)  # (2, b, t)

Head

Head(input_dim: int)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
beat_downbeat_lin
Source code in src/splifft/models/beat_this.py
def __init__(self, input_dim: int):
    super().__init__()
    self.beat_downbeat_lin = nn.Linear(input_dim, 2)
beat_downbeat_lin instance-attribute
beat_downbeat_lin = Linear(input_dim, 2)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/beat_this.py
def forward(self, x: Tensor) -> Tensor:
    beat_downbeat = self.beat_downbeat_lin(x)
    beat, downbeat = rearrange(beat_downbeat, "b t c -> c b t", c=2)
    return torch.stack([beat, downbeat], dim=0)

BeatThis

BeatThis(cfg: BeatThisParams)

Bases: Module

Methods:

Name Description
forward

Forward pass over an input spectrogram of shape (B, T, F).

Attributes:

Name Type Description
spect_dim
frontend
transformer_blocks
task_heads
Source code in src/splifft/models/beat_this.py
def __init__(self, cfg: BeatThisParams):
    super().__init__()
    self.spect_dim = cfg.spect_dim

    max_frames = cfg.chunk_size // cfg.log_mel_hop_length
    rotary_embed_t = RotaryEmbedding(
        seq_len=max_frames,  # by default 1500 frames * 441 = 661500 samples
        dim_head=cfg.head_dim,
        dtype=cfg.rotary_embed_dtype,
    )

    # NOTE: removed rearrange from original impl to standardise input as (B, F, T)
    stem = nn.Sequential(
        OrderedDict(
            bn1d=nn.BatchNorm1d(cfg.spect_dim),
            add_channel=Rearrange("b f t -> b 1 f t"),
            conv2d=nn.Conv2d(
                in_channels=1,
                out_channels=cfg.stem_dim,
                kernel_size=(4, 3),
                stride=(4, 1),
                padding=(0, 1),
                bias=False,
            ),
            bn2d=nn.BatchNorm2d(cfg.stem_dim),
            activation=nn.GELU(),
        )
    )

    spect_dim = cfg.spect_dim // 4
    dim = cfg.stem_dim
    frontend_blocks = []
    for _ in range(3):
        rotary_embed_f = RotaryEmbedding(
            seq_len=spect_dim,
            dim_head=cfg.head_dim,
            dtype=cfg.rotary_embed_dtype,
        )
        frontend_blocks.append(
            nn.Sequential(
                OrderedDict(
                    partial=(
                        PartialFTTransformer(
                            dim=dim,
                            dim_head=cfg.head_dim,
                            n_head=dim // cfg.head_dim,
                            rotary_embed_f=rotary_embed_f,
                            rotary_embed_t=rotary_embed_t,
                            dropout=cfg.dropout_frontend,
                        )
                        if cfg.partial_transformers
                        else nn.Identity()
                    ),
                    conv2d=nn.Conv2d(
                        in_channels=dim,
                        out_channels=dim * 2,
                        kernel_size=(2, 3),
                        stride=(2, 1),
                        padding=(0, 1),
                        bias=False,
                    ),
                    norm=nn.BatchNorm2d(dim * 2),
                    activation=nn.GELU(),
                )
            )
        )
        dim *= 2
        spect_dim //= 2

    self.frontend = nn.Sequential(
        OrderedDict(
            stem=stem,
            blocks=nn.Sequential(*frontend_blocks),
            concat=Rearrange("b c f t -> b t (c f)"),
            linear=nn.Linear(dim * spect_dim, cfg.transformer_dim),
        )
    )

    n_heads = cfg.transformer_dim // cfg.head_dim
    # TODO check if this is really equivalent
    self.transformer_blocks = Transformer(
        dim=cfg.transformer_dim,
        depth=cfg.n_layers,
        dim_head=cfg.head_dim,
        heads=n_heads,
        attn_dropout=cfg.dropout_transformer,
        ff_dropout=cfg.dropout_transformer,
        ff_mult=cfg.ff_mult,
        norm_output=True,
        rotary_embed=rotary_embed_t,
        transformer_residual_dtype=cfg.transformer_residual_dtype,
    )

    if cfg.sum_head:
        self.task_heads = SumHead(cfg.transformer_dim)
    else:
        self.task_heads = Head(cfg.transformer_dim)
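The frontend's shape bookkeeping (a stride-4 stem followed by three stride-2 blocks that each double the channel count) can be sketched in plain Python. The `spect_dim` and `stem_dim` values below are illustrative, not library defaults:

```python
# Sketch (not library code): how the beat_this frontend trades frequency
# resolution for channels. The stem Conv2d has stride (4, 1) on the (F, T)
# axes, and each of the three frontend blocks halves F while doubling C.
def frontend_shapes(spect_dim: int, stem_dim: int) -> tuple[int, int]:
    f = spect_dim // 4   # stem: stride 4 along frequency
    c = stem_dim
    for _ in range(3):   # each block: Conv2d stride (2, 1), channels doubled
        c *= 2
        f //= 2
    return c, f          # flattened into a (c * f)-wide token per frame

channels, freq = frontend_shapes(spect_dim=128, stem_dim=32)
# the final nn.Linear then maps channels * freq -> cfg.transformer_dim
```

This is why the `linear` layer in `self.frontend` takes `dim * spect_dim` input features: `dim` has been doubled three times and `spect_dim` divided by 32 overall.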
spect_dim instance-attribute
spect_dim = spect_dim
frontend instance-attribute
frontend = Sequential(
    OrderedDict(
        stem=stem,
        blocks=Sequential(*frontend_blocks),
        concat=Rearrange("b c f t -> b t (c f)"),
        linear=Linear(dim * spect_dim, transformer_dim),
    )
)
transformer_blocks instance-attribute
transformer_blocks = Transformer(
    dim=transformer_dim,
    depth=n_layers,
    dim_head=head_dim,
    heads=n_heads,
    attn_dropout=dropout_transformer,
    ff_dropout=dropout_transformer,
    ff_mult=ff_mult,
    norm_output=True,
    rotary_embed=rotary_embed_t,
    transformer_residual_dtype=transformer_residual_dtype,
)
task_heads instance-attribute
task_heads = SumHead(transformer_dim)
forward
forward(x: Tensor) -> Tensor

Parameters:

Name Type Description Default
x Tensor

Input spectrogram (B, T, F)

required

Returns:

Type Description
Tensor

Logits (2, B, T) -> [Beats, Downbeats]

Source code in src/splifft/models/beat_this.py
def forward(self, x: Tensor) -> Tensor:
    """
    :param x: Input spectrogram (B, T, F)
    :return: Logits (2, B, T) -> [Beats, Downbeats]
    """
    if x.ndim != 3:
        raise ValueError(f"expected 3D spectrogram input, got shape={tuple(x.shape)}")
    if x.shape[-1] != self.spect_dim:
        raise ValueError(
            f"expected beat_this input shape `(B,T,{self.spect_dim})`, got {tuple(x.shape)}"
        )

    x = x.transpose(1, 2)
    x = self.frontend(x)
    x = self.transformer_blocks(x)
    x = self.task_heads(x)
    return x

bs_roformer

Band-Split RoPE Transformer

This implementation merges the two versions found in lucidrains' repository. However, there are several inconsistencies:

  • MLP was defined differently in each file, one that has depth - 1 hidden layers and one that has depth layers.
  • BSRoformer applies one final RMSNorm after the entire stack of transformer layers, while the MelBandRoformer applies an RMSNorm at the end of each axial transformer block (time_transformer, freq_transformer, etc.) and has no final normalization layer.

Since fixing these inconsistencies upstream would be too large a breaking change, we inherit them to maintain compatibility with community-trained models. See: https://web.archive.org/web/20260112010548/https://github.com/lucidrains/BS-RoFormer/issues/48.

To avoid dependency bloat, we do not:

Classes:

Name Description
FixedBandsConfig
MelBandsConfig
BaselineMaskEstimatorConfig
AxialRefinerLargeV2MaskEstimatorConfig

unwa large-v2 head. Adds a small axial transformer refiner inside the mask head.

HyperAceResidualV1MaskEstimatorConfig

unwa HyperACE v1 residual head compatibility config.

HyperAceResidualV2MaskEstimatorConfig

UNWA HyperACE v2 residual head compatibility config.

BSRoformerParams
RMSNorm
RMSNormWithEps
RotaryEmbedding

A performance-oriented version of RoPE.

FeedForward
Attention
LinearAttention

This flavor of linear attention was proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.

Transformer
BandSplit
MaskEstimator
AxialRefinerLargeV2MaskEstimator
HyperAceResidualMaskEstimator
BSRoformer

Functions:

Name Description
l2norm
rms_norm
mlp

Attributes:

Name Type Description
DEFAULT_FREQS_PER_BANDS
MaskEstimatorConfig

DEFAULT_FREQS_PER_BANDS module-attribute

DEFAULT_FREQS_PER_BANDS = (
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    2,
    4,
    4,
    4,
    4,
    4,
    4,
    4,
    4,
    4,
    4,
    4,
    4,
    12,
    12,
    12,
    12,
    12,
    12,
    12,
    12,
    24,
    24,
    24,
    24,
    24,
    24,
    24,
    24,
    48,
    48,
    48,
    48,
    48,
    48,
    48,
    48,
    128,
    129,
)
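The fixed band layout is designed to tile the full STFT spectrum without gaps or overlap. A quick pure-Python check (restating the tuple above in compact run-length form) confirms that the 62 bands cover exactly 1025 bins, i.e. `n_fft = 2048` → `2048 // 2 + 1` frequency bins:

```python
# Compact restatement of DEFAULT_FREQS_PER_BANDS above: 24 bands of width 2,
# 12 of width 4, 8 each of widths 12/24/48, then 128 and 129.
DEFAULT_FREQS_PER_BANDS = (2,) * 24 + (4,) * 12 + (12,) * 8 + (24,) * 8 + (48,) * 8 + (128, 129)

def band_offsets(freqs_per_bands):
    """Start index of each band when splitting the frequency axis."""
    offsets, start = [], 0
    for width in freqs_per_bands:
        offsets.append(start)
        start += width
    return offsets, start  # `start` ends up as the total bin count

offsets, total_bins = band_offsets(DEFAULT_FREQS_PER_BANDS)
assert total_bins == 1025  # = 2048 // 2 + 1, every STFT bin covered once
```

`BandSplit` below relies on exactly this tiling when it calls `torch.split` with the per-band widths.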

FixedBandsConfig dataclass

FixedBandsConfig(
    kind: Literal["fixed"],
    freqs_per_bands: tuple[Gt0[int], ...] = (
        lambda: DEFAULT_FREQS_PER_BANDS
    )(),
)

Attributes:

Name Type Description
kind Literal['fixed']
freqs_per_bands tuple[Gt0[int], ...]
kind instance-attribute
kind: Literal['fixed']
freqs_per_bands class-attribute instance-attribute
freqs_per_bands: tuple[Gt0[int], ...] = field(
    default_factory=lambda: DEFAULT_FREQS_PER_BANDS
)

MelBandsConfig dataclass

MelBandsConfig(
    kind: Literal["mel"],
    stft_n_fft: Gt0[int] = 2048,
    num_bands: Gt0[int] = 60,
    sample_rate: Gt0[int] = 44100,
)

Attributes:

Name Type Description
kind Literal['mel']
stft_n_fft Gt0[int]
num_bands Gt0[int]
sample_rate Gt0[int]
kind instance-attribute
kind: Literal['mel']
stft_n_fft class-attribute instance-attribute
stft_n_fft: Gt0[int] = 2048
num_bands class-attribute instance-attribute
num_bands: Gt0[int] = 60
sample_rate class-attribute instance-attribute
sample_rate: Gt0[int] = 44100

BaselineMaskEstimatorConfig dataclass

BaselineMaskEstimatorConfig(
    kind: Literal["baseline"] = "baseline",
)

Attributes:

Name Type Description
kind Literal['baseline']
kind class-attribute instance-attribute
kind: Literal['baseline'] = 'baseline'

AxialRefinerLargeV2MaskEstimatorConfig dataclass

AxialRefinerLargeV2MaskEstimatorConfig(
    kind: Literal["axial_refiner_large_v2"],
    axial_refiner_depth: Gt0[int] = 4,
)

unwa large-v2 head. Adds a small axial transformer refiner inside the mask head.

Attributes:

Name Type Description
kind Literal['axial_refiner_large_v2']
axial_refiner_depth Gt0[int]
kind instance-attribute
kind: Literal['axial_refiner_large_v2']
axial_refiner_depth class-attribute instance-attribute
axial_refiner_depth: Gt0[int] = 4

HyperAceResidualV1MaskEstimatorConfig dataclass

HyperAceResidualV1MaskEstimatorConfig(
    kind: Literal["hyperace_residual_v1"],
    num_hyperedges: Gt0[int] | None = None,
    num_heads: Gt0[int] = 8,
)

unwa HyperACE v1 residual head compatibility config.

Attributes:

Name Type Description
kind Literal['hyperace_residual_v1']
num_hyperedges Gt0[int] | None
num_heads Gt0[int]
kind instance-attribute
kind: Literal['hyperace_residual_v1']
num_hyperedges class-attribute instance-attribute
num_hyperedges: Gt0[int] | None = None
num_heads class-attribute instance-attribute
num_heads: Gt0[int] = 8

HyperAceResidualV2MaskEstimatorConfig dataclass

HyperAceResidualV2MaskEstimatorConfig(
    kind: Literal["hyperace_residual_v2"],
    num_hyperedges: Gt0[int] | None = None,
    num_heads: Gt0[int] = 8,
)

UNWA HyperACE v2 residual head compatibility config.

Attributes:

Name Type Description
kind Literal['hyperace_residual_v2']
num_hyperedges Gt0[int] | None
num_heads Gt0[int]
kind instance-attribute
kind: Literal['hyperace_residual_v2']
num_hyperedges class-attribute instance-attribute
num_hyperedges: Gt0[int] | None = None
num_heads class-attribute instance-attribute
num_heads: Gt0[int] = 8

BSRoformerParams dataclass

BSRoformerParams(
    chunk_size: ChunkSize,
    output_stem_names: tuple[ModelOutputStemName, ...],
    dim: Gt0[int],
    depth: Gt0[int],
    band_config: FixedBandsConfig | MelBandsConfig,
    stft_hop_length: HopSize = 512,
    stereo: bool = True,
    time_transformer_depth: Gt0[int] = 1,
    freq_transformer_depth: Gt0[int] = 1,
    linear_transformer_depth: Ge0[int] = 0,
    dim_head: int = 64,
    heads: Gt0[int] = 8,
    attn_dropout: Dropout = 0.0,
    ff_dropout: Dropout = 0.0,
    ff_mult: Gt0[int] = 4,
    flash_attn: bool = True,
    norm_output: bool = False,
    mask_estimator_depth: Gt0[int] = 2,
    mlp_expansion_factor: Gt0[int] = 4,
    mask_estimator: MaskEstimatorConfig = BaselineMaskEstimatorConfig(),
    use_torch_checkpoint: bool = False,
    sage_attention: bool = False,
    use_shared_bias: bool = False,
    skip_connection: bool = False,
    rms_norm_eps: Ge0[float] | None = None,
    rotary_embed_dtype: TorchDtype | None = None,
    transformer_residual_dtype: TorchDtype | None = None,
    debug: bool = False,
)

Bases: ModelParamsLike

Attributes:

Name Type Description
chunk_size ChunkSize
output_stem_names tuple[ModelOutputStemName, ...]
dim Gt0[int]
depth Gt0[int]
band_config FixedBandsConfig | MelBandsConfig
stft_hop_length HopSize
stereo bool
time_transformer_depth Gt0[int]
freq_transformer_depth Gt0[int]
linear_transformer_depth Ge0[int]
dim_head int
heads Gt0[int]
attn_dropout Dropout
ff_dropout Dropout
ff_mult Gt0[int]
flash_attn bool
norm_output bool

Note that in lucidrains' implementation, this is set to False for bs_roformer but True for mel_roformer.

mask_estimator_depth Gt0[int]

The number of hidden layers of the MLP is mask_estimator_depth - 1.

mlp_expansion_factor Gt0[int]
mask_estimator MaskEstimatorConfig
use_torch_checkpoint bool
sage_attention bool
use_shared_bias bool
skip_connection bool
rms_norm_eps Ge0[float] | None
rotary_embed_dtype TorchDtype | None
transformer_residual_dtype TorchDtype | None
debug bool

Whether to check for nan/inf in model outputs. Keep it off for torch.compile.

input_channels ModelInputChannels
input_type ModelInputType
output_type ModelOutputType
inference_archetype InferenceArchetype
chunk_size instance-attribute
chunk_size: ChunkSize
output_stem_names instance-attribute
output_stem_names: tuple[ModelOutputStemName, ...]
dim instance-attribute
dim: Gt0[int]
depth instance-attribute
depth: Gt0[int]
band_config instance-attribute
stft_hop_length class-attribute instance-attribute
stft_hop_length: HopSize = 512
stereo class-attribute instance-attribute
stereo: bool = True
time_transformer_depth class-attribute instance-attribute
time_transformer_depth: Gt0[int] = 1
freq_transformer_depth class-attribute instance-attribute
freq_transformer_depth: Gt0[int] = 1
linear_transformer_depth class-attribute instance-attribute
linear_transformer_depth: Ge0[int] = 0
dim_head class-attribute instance-attribute
dim_head: int = 64
heads class-attribute instance-attribute
heads: Gt0[int] = 8
attn_dropout class-attribute instance-attribute
attn_dropout: Dropout = 0.0
ff_dropout class-attribute instance-attribute
ff_dropout: Dropout = 0.0
ff_mult class-attribute instance-attribute
ff_mult: Gt0[int] = 4
flash_attn class-attribute instance-attribute
flash_attn: bool = True
norm_output class-attribute instance-attribute
norm_output: bool = False

Note that in lucidrains' implementation, this is set to False for bs_roformer but True for mel_roformer!!

mask_estimator_depth class-attribute instance-attribute
mask_estimator_depth: Gt0[int] = 2

The number of hidden layers of the MLP is mask_estimator_depth - 1, that is:

  • depth = 1: (dim_in, dim_out)
  • depth = 2: (dim_in, dim_hidden, dim_out)

Note that in lucidrains' implementation of mel-band roformers, the number of hidden layers is incorrectly set as mask_estimator_depth. This includes popular models like kim-vocals and all models that use zfturbo's music-source-separation training.

If you are migrating a mel-band roformer's zfturbo configuration, increment mask_estimator_depth by 1.
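The resulting layer layout for a given depth can be sketched by mirroring the `dims` tuple built inside the `mlp` helper (the concrete dimensions below are made-up examples):

```python
def mlp_layer_dims(dim_in, dim_out, dim_hidden, depth):
    # mirrors the `dims` tuple inside `mlp`: depth - 1 hidden layers
    num_hidden = depth - 1
    return (dim_in, *((dim_hidden,) * num_hidden), dim_out)

assert mlp_layer_dims(384, 4098, 1536, depth=1) == (384, 4098)
assert mlp_layer_dims(384, 4098, 1536, depth=2) == (384, 1536, 4098)
# lucidrains' mel-band roformer builds `depth` hidden layers instead,
# hence the +1 adjustment when porting zfturbo configs.
```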

mlp_expansion_factor class-attribute instance-attribute
mlp_expansion_factor: Gt0[int] = 4
mask_estimator class-attribute instance-attribute
mask_estimator: MaskEstimatorConfig = field(
    default_factory=BaselineMaskEstimatorConfig
)
use_torch_checkpoint class-attribute instance-attribute
use_torch_checkpoint: bool = False
sage_attention class-attribute instance-attribute
sage_attention: bool = False
use_shared_bias class-attribute instance-attribute
use_shared_bias: bool = False
skip_connection class-attribute instance-attribute
skip_connection: bool = False
rms_norm_eps class-attribute instance-attribute
rms_norm_eps: Ge0[float] | None = None
rotary_embed_dtype class-attribute instance-attribute
rotary_embed_dtype: TorchDtype | None = None
transformer_residual_dtype class-attribute instance-attribute
transformer_residual_dtype: TorchDtype | None = None
debug class-attribute instance-attribute
debug: bool = False

Whether to check for nan/inf in model outputs. Keep it off for torch.compile.

input_channels property
input_channels: ModelInputChannels
input_type property
input_type: ModelInputType
output_type property
output_type: ModelOutputType
inference_archetype property
inference_archetype: InferenceArchetype

l2norm

l2norm(t: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def l2norm(t: Tensor) -> Tensor:
    return F.normalize(t, dim=-1, p=2)

RMSNorm

RMSNorm(dim: int)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
scale
gamma
Source code in src/splifft/models/bs_roformer.py
def __init__(self, dim: int):
    super().__init__()
    self.scale = dim**0.5
    self.gamma = nn.Parameter(torch.ones(dim))
scale instance-attribute
scale = dim ** 0.5
gamma instance-attribute
gamma = Parameter(ones(dim))
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    return F.normalize(x, dim=-1) * self.scale * self.gamma  # type: ignore

RMSNormWithEps

RMSNormWithEps(
    dim: int, eps: float = 5.960464477539063e-08
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
scale
gamma
eps
Source code in src/splifft/models/bs_roformer.py
def __init__(self, dim: int, eps: float = 5.960464477539063e-08):
    super().__init__()
    self.scale = dim**0.5
    self.gamma = nn.Parameter(torch.ones(dim))
    self.eps = eps
scale instance-attribute
scale = dim ** 0.5
gamma instance-attribute
gamma = Parameter(ones(dim))
eps instance-attribute
eps = eps
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    l2_norm = torch.linalg.norm(x, dim=-1, keepdim=True)
    denom = torch.maximum(l2_norm, torch.full_like(l2_norm, self.eps))
    normalized_x = x / denom
    return normalized_x * self.scale * self.gamma  # type: ignore

rms_norm

rms_norm(
    dim: int, eps: float | None
) -> RMSNorm | RMSNormWithEps
Source code in src/splifft/models/bs_roformer.py
def rms_norm(dim: int, eps: float | None) -> RMSNorm | RMSNormWithEps:
    if eps is None:
        return RMSNorm(dim)
    return RMSNormWithEps(dim, eps)
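The difference between the two variants is easiest to see on a single feature vector. Below is a scalar sketch (not the library code) of what `RMSNormWithEps` computes: L2-normalize with a floor of `eps` on the norm, then rescale by `sqrt(dim)` and the learned `gamma`:

```python
import math

def rms_norm_with_eps(x, gamma, eps=5.960464477539063e-08):
    # scalar analogue of RMSNormWithEps.forward
    l2 = math.sqrt(sum(v * v for v in x))
    denom = max(l2, eps)           # torch.maximum(l2_norm, eps)
    scale = len(x) ** 0.5          # dim ** 0.5
    return [v / denom * scale * g for v, g in zip(x, gamma)]

out = rms_norm_with_eps([3.0, 4.0], gamma=[1.0, 1.0])
assert abs(out[0] - 0.6 * math.sqrt(2)) < 1e-12
assert abs(out[1] - 0.8 * math.sqrt(2)) < 1e-12
```

The `eps` floor only matters for near-zero vectors, where it prevents division blow-up; `RMSNorm` instead relies on `F.normalize`'s built-in clamping.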

RotaryEmbedding

RotaryEmbedding(
    seq_len: int,
    dim_head: int,
    *,
    dtype: dtype | None,
    theta: int = 10000,
)

Bases: Module

A performance-oriented version of RoPE.

Unlike lucidrains' implementation, which computes embeddings just-in-time during the forward pass and caches calls with the same or shorter sequence length, we simply compute them ahead of time as persistent buffers. To keep the computational graph clean, we do not support dynamic sequence lengths, learned frequencies or length extrapolation.
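The ahead-of-time tables can be re-derived without torch. This sketch reproduces the construction in `__init__` (`freqs = t * theta**(-2i/dim_head)`, each frequency repeated twice along the head dimension); it is illustrative only:

```python
import math

def rope_tables(seq_len, dim_head, theta=10000):
    half = dim_head // 2
    # inv_freq[i] = 1 / theta**(2i / dim_head), as in the torch.arange(0, dim_head, 2) line
    inv_freq = [theta ** (-(2 * i) / dim_head) for i in range(half)]
    # repeat "... d -> ... (d r)", r=2: each frequency appears twice, adjacently
    angles = [[t * f for f in inv_freq for _ in range(2)] for t in range(seq_len)]
    cos = [[math.cos(a) for a in row] for row in angles]
    sin = [[math.sin(a) for a in row] for row in angles]
    return cos, sin

cos, sin = rope_tables(seq_len=4, dim_head=8)
assert len(cos) == 4 and len(cos[0]) == 8
assert all(c == 1.0 for c in cos[0])  # position 0: cos(0) = 1
assert cos[1][0] == cos[1][1]         # duplicated-frequency layout
```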

Methods:

Name Description
rotate_half
forward

Attributes:

Name Type Description
cos_emb
sin_emb
Source code in src/splifft/models/bs_roformer.py
def __init__(
    self, seq_len: int, dim_head: int, *, dtype: torch.dtype | None, theta: int = 10000
):
    super().__init__()
    # COMPAT: the original implementation does not generate the embeddings
    # on the fly, but serialises them in fp16. there are some tiny
    # differences:
    # |                     |   from weights  |   generated    |
    # | ------------------- | --------------- | -------------- |
    # | cos_emb_time:971,22 | -0.99462890625  | -0.994140625   |
    # | cos_emb_time:971,23 | -0.99462890625  | -0.994140625   |
    # | sin_emb_time:727,12 | -0.457763671875 | -0.4580078125  |
    # | sin_emb_time:727,13 | -0.457763671875 | -0.4580078125  |
    # | sin_emb_time:825,4  | -0.8544921875   | -0.85400390625 |
    # | sin_emb_time:825,5  | -0.8544921875   | -0.85400390625 |
    freqs = 1.0 / (theta ** (torch.arange(0, dim_head, 2).float() / dim_head))
    t = torch.arange(seq_len)
    freqs = torch.einsum("i,j->ij", t, freqs)  # (seq_len, dim / 2)
    freqs = repeat(freqs, "... d -> ... (d r)", r=2)  # (seq_len, dim)
    self.cos_emb = freqs.cos().to(dtype)
    self.sin_emb = freqs.sin().to(dtype)
cos_emb instance-attribute
cos_emb = to(dtype)
sin_emb instance-attribute
sin_emb = to(dtype)
rotate_half
rotate_half(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def rotate_half(self, x: Tensor) -> Tensor:
    x = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, "... d r -> ... (d r)")
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    # x is (batch_eff, heads, seq_len_for_rotation, dim_head)
    cos_b = self.cos_emb.unsqueeze(0).unsqueeze(0).to(x.device, x.dtype)
    sin_b = self.sin_emb.unsqueeze(0).unsqueeze(0).to(x.device, x.dtype)

    term1 = x * cos_b
    term2 = self.rotate_half(x) * sin_b

    # NOTE: original impl performed addition between two f32s but it comes with 30% slowdown
    # we eliminate it so the addition is performed between two f16s (according to __init__).
    return term1 + term2
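The interleaved rotation used by `forward` can be illustrated without tensors: `rotate_half` maps each adjacent pair `(x1, x2)` to `(-x2, x1)`, i.e. a 90° rotation of each 2D sub-plane of the head dimension:

```python
# Pure-Python illustration of rotate_half; not the library code.
def rotate_half(x):
    out = []
    for i in range(0, len(x), 2):
        x1, x2 = x[i], x[i + 1]
        out += [-x2, x1]
    return out

assert rotate_half([1, 2, 3, 4]) == [-2, 1, -4, 3]
# rotating twice negates the vector, as expected for a 90° rotation
assert rotate_half(rotate_half([1, 2, 3, 4])) == [-1, -2, -3, -4]
```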

FeedForward

FeedForward(
    dim: int,
    mult: int = 4,
    dropout: float = 0.0,
    rms_norm_eps: float | None = None,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
net
Source code in src/splifft/models/bs_roformer.py
def __init__(
    self, dim: int, mult: int = 4, dropout: float = 0.0, rms_norm_eps: float | None = None
):
    super().__init__()
    dim_inner = int(dim * mult)
    # NOTE: in the paper: RMSNorm -> FC -> Tanh -> FC -> GLU
    self.net = nn.Sequential(
        rms_norm(dim, eps=rms_norm_eps),
        nn.Linear(dim, dim_inner),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(dim_inner, dim),
        nn.Dropout(dropout),
    )
net instance-attribute
net = Sequential(
    rms_norm(dim, eps=rms_norm_eps),
    Linear(dim, dim_inner),
    GELU(),
    Dropout(dropout),
    Linear(dim_inner, dim),
    Dropout(dropout),
)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    return cast(Tensor, self.net(x))

Attention

Attention(
    dim: int,
    heads: int = 8,
    dim_head: int = 64,
    dropout: float = 0.0,
    shared_qkv_bias: Parameter | None = None,
    shared_out_bias: Parameter | None = None,
    rotary_embed: RotaryEmbedding | None = None,
    flash: bool = True,
    sage_attention: bool = False,
    rms_norm_eps: float | None = None,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
heads
scale
rotary_embed
attend
norm
to_qkv
to_gates
to_out
Source code in src/splifft/models/bs_roformer.py
def __init__(
    self,
    dim: int,
    heads: int = 8,
    dim_head: int = 64,
    dropout: float = 0.0,
    shared_qkv_bias: nn.Parameter | None = None,
    shared_out_bias: nn.Parameter | None = None,
    rotary_embed: RotaryEmbedding | None = None,
    flash: bool = True,
    sage_attention: bool = False,
    rms_norm_eps: float | None = None,
):
    super().__init__()
    self.heads = heads
    self.scale = dim_head**-0.5
    dim_inner = heads * dim_head

    self.rotary_embed = rotary_embed

    if sage_attention:
        from .utils.attend_sage import AttendSage

        self.attend = AttendSage(flash=flash, dropout=dropout)
    else:
        from .utils.attend import Attend

        self.attend = Attend(flash=flash, dropout=dropout)  # type: ignore

    self.norm = rms_norm(dim, eps=rms_norm_eps)
    self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=(shared_qkv_bias is not None))
    if shared_qkv_bias is not None:
        self.to_qkv.bias = shared_qkv_bias

    self.to_gates = nn.Linear(dim, heads)

    self.to_out = nn.Sequential(
        nn.Linear(dim_inner, dim, bias=(shared_out_bias is not None)),
        nn.Dropout(dropout),
    )
    if shared_out_bias is not None:
        self.to_out[0].bias = shared_out_bias
heads instance-attribute
heads = heads
scale instance-attribute
scale = dim_head ** -0.5
rotary_embed instance-attribute
rotary_embed = rotary_embed
attend instance-attribute
attend = AttendSage(flash=flash, dropout=dropout)
norm instance-attribute
norm = rms_norm(dim, eps=rms_norm_eps)
to_qkv instance-attribute
to_qkv = Linear(
    dim, dim_inner * 3, bias=shared_qkv_bias is not None
)
to_gates instance-attribute
to_gates = Linear(dim, heads)
to_out instance-attribute
to_out = Sequential(
    Linear(
        dim_inner, dim, bias=shared_out_bias is not None
    ),
    Dropout(dropout),
)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    x = self.norm(x)

    qkv = self.to_qkv(x)
    q, k, v = rearrange(qkv, "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads)

    if self.rotary_embed is not None:
        q = self.rotary_embed(q)
        k = self.rotary_embed(k)

    out = self.attend(q, k, v)

    gates = self.to_gates(x)
    gate_act = gates.sigmoid()

    out = out * rearrange(gate_act, "b n h -> b h n 1")

    out = rearrange(out, "b h n d -> b n (h d)")
    out = self.to_out(out)
    return cast(Tensor, out)
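A toy illustration of the per-head output gating in `forward`: `to_gates` produces one logit per head, and each head's attention output is scaled by `sigmoid(logit)`. The numbers here are made up:

```python
import math

def apply_head_gates(head_outputs, gate_logits):
    # head_outputs: per-head lists of values; gate_logits: one logit per head
    sigmoid = lambda g: 1.0 / (1.0 + math.exp(-g))
    return [
        [v * sigmoid(g) for v in head]
        for head, g in zip(head_outputs, gate_logits)
    ]

gated = apply_head_gates([[1.0, 2.0], [1.0, 2.0]], gate_logits=[0.0, 100.0])
assert gated[0] == [0.5, 1.0]          # logit 0 -> gate 0.5
assert abs(gated[1][0] - 1.0) < 1e-6   # large logit -> gate ~= 1, head passes through
```

This lets the model learn to suppress individual attention heads entirely, which plain multi-head attention cannot do.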

LinearAttention

LinearAttention(
    *,
    dim: int,
    dim_head: int = 32,
    heads: int = 8,
    scale: int = 8,
    flash: bool = False,
    dropout: float = 0.0,
    sage_attention: bool = False,
    rms_norm_eps: float | None = None,
)

Bases: Module

This flavor of linear attention was proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.

Methods:

Name Description
forward

Attributes:

Name Type Description
norm
to_qkv
temperature
attend
to_out
Source code in src/splifft/models/bs_roformer.py
def __init__(
    self,
    *,
    dim: int,
    dim_head: int = 32,
    heads: int = 8,
    scale: int = 8,
    flash: bool = False,
    dropout: float = 0.0,
    sage_attention: bool = False,
    rms_norm_eps: float | None = None,
):
    super().__init__()
    dim_inner = dim_head * heads
    self.norm = rms_norm(dim, eps=rms_norm_eps)

    self.to_qkv = nn.Sequential(
        nn.Linear(dim, dim_inner * 3, bias=False),
        Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads),
    )

    self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

    if sage_attention:
        from .utils.attend_sage import AttendSage

        self.attend = AttendSage(scale=scale, dropout=dropout, flash=flash)
    else:
        from .utils.attend import Attend

        self.attend = Attend(scale=scale, dropout=dropout, flash=flash)  # type: ignore

    self.to_out = nn.Sequential(
        Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False)
    )
norm instance-attribute
norm = rms_norm(dim, eps=rms_norm_eps)
to_qkv instance-attribute
to_qkv = Sequential(
    Linear(dim, dim_inner * 3, bias=False),
    Rearrange(
        "b n (qkv h d) -> qkv b h d n", qkv=3, h=heads
    ),
)
temperature instance-attribute
temperature = Parameter(ones(heads, 1, 1))
attend instance-attribute
attend = AttendSage(
    scale=scale, dropout=dropout, flash=flash
)
to_out instance-attribute
to_out = Sequential(
    Rearrange("b h d n -> b n (h d)"),
    Linear(dim_inner, dim, bias=False),
)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    x = self.norm(x)

    q, k, v = self.to_qkv(x)

    q, k = map(l2norm, (q, k))
    q = q * self.temperature.exp()

    out = self.attend(q, k, v)

    return cast(Tensor, self.to_out(out))

Transformer

Transformer(
    *,
    dim: int,
    depth: int,
    dim_head: int = 64,
    heads: int = 8,
    attn_dropout: float = 0.0,
    ff_dropout: float = 0.0,
    ff_mult: int = 4,
    norm_output: bool = True,
    rotary_embed: RotaryEmbedding | None = None,
    flash_attn: bool = True,
    linear_attn: bool = False,
    sage_attention: bool = False,
    shared_qkv_bias: Parameter | None = None,
    shared_out_bias: Parameter | None = None,
    rms_norm_eps: float | None = None,
    transformer_residual_dtype: dtype | None = None,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
layers
transformer_residual_dtype
norm
Source code in src/splifft/models/bs_roformer.py
def __init__(
    self,
    *,
    dim: int,
    depth: int,
    dim_head: int = 64,
    heads: int = 8,
    attn_dropout: float = 0.0,
    ff_dropout: float = 0.0,
    ff_mult: int = 4,
    norm_output: bool = True,
    rotary_embed: RotaryEmbedding | None = None,
    flash_attn: bool = True,
    linear_attn: bool = False,
    sage_attention: bool = False,
    shared_qkv_bias: nn.Parameter | None = None,
    shared_out_bias: nn.Parameter | None = None,
    rms_norm_eps: float | None = None,
    transformer_residual_dtype: torch.dtype | None = None,  # COMPAT: float32, see 265
):
    super().__init__()
    self.layers = ModuleList([])

    for _ in range(depth):
        attn: LinearAttention | Attention
        if linear_attn:
            attn = LinearAttention(
                dim=dim,
                dim_head=dim_head,
                heads=heads,
                dropout=attn_dropout,
                flash=flash_attn,
                sage_attention=sage_attention,
                rms_norm_eps=rms_norm_eps,
            )
        else:
            attn = Attention(
                dim=dim,
                dim_head=dim_head,
                heads=heads,
                dropout=attn_dropout,
                shared_qkv_bias=shared_qkv_bias,
                shared_out_bias=shared_out_bias,
                rotary_embed=rotary_embed,
                flash=flash_attn,
                sage_attention=sage_attention,
                rms_norm_eps=rms_norm_eps,
            )

        ff = FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout, rms_norm_eps=rms_norm_eps)
        self.layers.append(ModuleList([attn, ff]))
    self.transformer_residual_dtype = transformer_residual_dtype

    self.norm = rms_norm(dim, eps=rms_norm_eps) if norm_output else nn.Identity()
layers instance-attribute
layers = ModuleList([])
transformer_residual_dtype instance-attribute
transformer_residual_dtype = transformer_residual_dtype
norm instance-attribute
norm = (
    rms_norm(dim, eps=rms_norm_eps)
    if norm_output
    else Identity()
)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    for layer in self.layers:
        block = cast(ModuleList, layer)
        attn = block[0]
        ff = block[1]
        attn_out = attn(x)
        if self.transformer_residual_dtype is not None:
            x = (
                attn_out.to(self.transformer_residual_dtype)
                + x.to(self.transformer_residual_dtype)
            ).to(x.dtype)
        else:
            x = attn_out + x

        ff_out = ff(x)
        if self.transformer_residual_dtype is not None:
            x = (
                ff_out.to(self.transformer_residual_dtype)
                + x.to(self.transformer_residual_dtype)
            ).to(x.dtype)
        else:
            x = ff_out + x
    return cast(Tensor, self.norm(x))

BandSplit

BandSplit(
    dim: int,
    dim_inputs: tuple[int, ...],
    rms_norm_eps: float | None = None,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
dim_inputs
to_features
Source code in src/splifft/models/bs_roformer.py
def __init__(self, dim: int, dim_inputs: tuple[int, ...], rms_norm_eps: float | None = None):
    super().__init__()
    self.dim_inputs = dim_inputs
    self.to_features = ModuleList([])

    for dim_in in dim_inputs:
        net = nn.Sequential(rms_norm(dim_in, rms_norm_eps), nn.Linear(dim_in, dim))
        self.to_features.append(net)
dim_inputs instance-attribute
dim_inputs = dim_inputs
to_features instance-attribute
to_features = ModuleList([])
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    x_split = torch.split(x, list(self.dim_inputs), dim=-1)
    outs = []
    for split_input, to_feature_net in zip(x_split, self.to_features):
        split_output = to_feature_net(split_input)
        outs.append(split_output)
    return torch.stack(outs, dim=-2)

mlp

mlp(
    dim_in: int,
    dim_out: int,
    dim_hidden: int | None = None,
    depth: int = 1,
    activation: type[Module] = Tanh,
) -> Sequential
Source code in src/splifft/models/bs_roformer.py
def mlp(
    dim_in: int,
    dim_out: int,
    dim_hidden: int | None = None,
    depth: int = 1,
    activation: type[Module] = nn.Tanh,
) -> nn.Sequential:
    dim_hidden_ = dim_hidden or dim_in

    net: list[Module] = []
    # NOTE: in lucidrain's impl, `bs_roformer` has `depth - 1` but `mel_roformer` has `depth`
    num_hidden_layers = depth - 1
    dims = (dim_in, *((dim_hidden_,) * num_hidden_layers), dim_out)

    for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
        is_last = ind == (len(dims) - 2)

        net.append(nn.Linear(layer_dim_in, layer_dim_out))

        if is_last:
            continue

        net.append(activation())

    return nn.Sequential(*net)

MaskEstimator

MaskEstimator(
    dim: int,
    dim_inputs: tuple[int, ...],
    depth: int,
    mlp_expansion_factor: int,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
dim_inputs
to_freqs
Source code in src/splifft/models/bs_roformer.py
def __init__(
    self,
    dim: int,
    dim_inputs: tuple[int, ...],
    depth: int,
    mlp_expansion_factor: int,
):
    super().__init__()
    self.dim_inputs = dim_inputs
    self.to_freqs = _build_band_to_freq_mlps(
        dim=dim,
        dim_inputs=dim_inputs,
        depth=depth,
        mlp_expansion_factor=mlp_expansion_factor,
    )
dim_inputs instance-attribute
dim_inputs = dim_inputs
to_freqs instance-attribute
to_freqs = _build_band_to_freq_mlps(
    dim=dim,
    dim_inputs=dim_inputs,
    depth=depth,
    mlp_expansion_factor=mlp_expansion_factor,
)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    x_unbound = x.unbind(dim=-2)

    outs = []

    for band_features, mlp_net in zip(x_unbound, self.to_freqs):
        freq_out = mlp_net(band_features)
        outs.append(freq_out)

    return torch.cat(outs, dim=-1)
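The unbind/concat in `forward` is the inverse of `BandSplit`: one MLP per band expands its feature vector back to that band's frequency bins, and concatenation restores the flattened frequency axis. A pure-Python sketch with hypothetical stand-in MLPs:

```python
# Shape bookkeeping of MaskEstimator.forward, without torch.
def estimate_mask_row(band_features, per_band_mlps):
    # x.unbind(dim=-2): one feature vector per band
    outs = [mlp(feats) for feats, mlp in zip(band_features, per_band_mlps)]
    # torch.cat(outs, dim=-1): per-band frequency bins laid end to end
    return [v for out in outs for v in out]

dim_inputs = (2, 3)  # widths each band's MLP emits (hypothetical)
mlps = [lambda f, w=w: [sum(f)] * w for w in dim_inputs]  # stand-in MLPs
row = estimate_mask_row([[1.0, 2.0], [3.0, 4.0]], mlps)
print(row)  # [3.0, 3.0, 7.0, 7.0, 7.0] -> width == sum(dim_inputs)
```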

AxialRefinerLargeV2MaskEstimator

AxialRefinerLargeV2MaskEstimator(
    dim: int,
    dim_inputs: tuple[int, ...],
    mlp_depth: int,
    mlp_expansion_factor: int,
    axial_refiner_depth: int,
    t_frames: int,
    num_bands: int,
    rotary_embed_dtype: dtype | None,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
dim_inputs
to_freqs
layers
norm
Source code in src/splifft/models/bs_roformer.py
def __init__(
    self,
    dim: int,
    dim_inputs: tuple[int, ...],
    mlp_depth: int,
    mlp_expansion_factor: int,
    axial_refiner_depth: int,
    t_frames: int,
    num_bands: int,
    rotary_embed_dtype: torch.dtype | None,
):
    super().__init__()
    self.dim_inputs = dim_inputs
    self.to_freqs = _build_band_to_freq_mlps(
        dim=dim,
        dim_inputs=dim_inputs,
        depth=mlp_depth,
        mlp_expansion_factor=mlp_expansion_factor,
    )

    self.layers = ModuleList([])

    heads = 8
    dim_head = 64

    time_rotary_embed = RotaryEmbedding(
        seq_len=t_frames,
        dim_head=dim_head,
        dtype=rotary_embed_dtype,
    )
    freq_rotary_embed = RotaryEmbedding(
        seq_len=num_bands,
        dim_head=dim_head,
        dtype=rotary_embed_dtype,
    )

    for _ in range(axial_refiner_depth):
        self.layers.append(
            nn.ModuleList(
                [
                    Transformer(
                        dim=dim,
                        depth=1,
                        heads=heads,
                        dim_head=dim_head,
                        attn_dropout=0.0,
                        ff_dropout=0.0,
                        flash_attn=True,
                        norm_output=False,
                        rotary_embed=time_rotary_embed,
                        sage_attention=False,
                    ),
                    Transformer(
                        dim=dim,
                        depth=1,
                        heads=heads,
                        dim_head=dim_head,
                        attn_dropout=0.0,
                        ff_dropout=0.0,
                        flash_attn=True,
                        norm_output=False,
                        rotary_embed=freq_rotary_embed,
                        sage_attention=False,
                    ),
                ]
            )
        )

    self.norm = RMSNorm(dim)
dim_inputs instance-attribute
dim_inputs = dim_inputs
to_freqs instance-attribute
to_freqs = _build_band_to_freq_mlps(
    dim=dim,
    dim_inputs=dim_inputs,
    depth=mlp_depth,
    mlp_expansion_factor=mlp_expansion_factor,
)
layers instance-attribute
layers = ModuleList([])
norm instance-attribute
norm = RMSNorm(dim)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    for transformer_block in self.layers:
        block = cast(ModuleList, transformer_block)
        time_transformer, freq_transformer = block

        x = rearrange(x, "b t f d -> b f t d")
        x, ps = pack([x], "* t d")

        x = time_transformer(x)

        (x,) = unpack(x, ps, "* t d")
        x = rearrange(x, "b f t d -> b t f d")
        x, ps = pack([x], "* f d")

        x = freq_transformer(x)

        (x,) = unpack(x, ps, "* f d")

    x = self.norm(x)

    x_unbound = x.unbind(dim=-2)

    outs = []

    for band_features, mlp_net in zip(x_unbound, self.to_freqs):
        freq_out = mlp_net(band_features)
        outs.append(freq_out)

    return torch.cat(outs, dim=-1)
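The rearrange/pack pattern in the refiner loop implements axial attention: time attention folds the band axis into the batch, freq attention folds the time axis. A shape-only sketch (the 1151/62 figures below are illustrative, matching the `t_frames` example elsewhere on this page):

```python
# Shape-only sketch of the axial attention reshapes above.
def axial_shapes(b, t, f, d):
    # rearrange "b t f d -> b f t d", then pack "* t d"
    time_attn_in = (b * f, t, d)
    # rearrange back, then pack "* f d"
    freq_attn_in = (b * t, f, d)
    return time_attn_in, freq_attn_in

print(axial_shapes(1, 1151, 62, 384))
# ((62, 1151, 384), (1151, 62, 384))
```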

HyperAceResidualMaskEstimator

HyperAceResidualMaskEstimator(
    dim: int,
    dim_inputs: tuple[int, ...],
    depth: int,
    mlp_expansion_factor: int,
    segm: Module,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
dim_inputs
to_freqs
segm
Source code in src/splifft/models/bs_roformer.py
def __init__(
    self,
    dim: int,
    dim_inputs: tuple[int, ...],
    depth: int,
    mlp_expansion_factor: int,
    segm: nn.Module,
):
    super().__init__()
    self.dim_inputs = dim_inputs
    self.to_freqs = _build_band_to_freq_mlps(
        dim=dim,
        dim_inputs=dim_inputs,
        depth=depth,
        mlp_expansion_factor=mlp_expansion_factor,
    )

    self.segm = segm
dim_inputs instance-attribute
dim_inputs = dim_inputs
to_freqs instance-attribute
to_freqs = _build_band_to_freq_mlps(
    dim=dim,
    dim_inputs=dim_inputs,
    depth=depth,
    mlp_expansion_factor=mlp_expansion_factor,
)
segm instance-attribute
segm = segm
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/bs_roformer.py
def forward(self, x: Tensor) -> Tensor:
    y = rearrange(x, "b t f c -> b c t f")
    y = self.segm(y)
    y = rearrange(y, "b c t f -> b t (f c)")

    x_unbound = x.unbind(dim=-2)
    outs = []
    for band_features, mlp_net in zip(x_unbound, self.to_freqs):
        freq_out = mlp_net(band_features)
        outs.append(freq_out)

    return cast(Tensor, torch.cat(outs, dim=-1) + y)

BSRoformer

BSRoformer(cfg: BSRoformerParams)

Bases: Module, SupportsStemSelection[BSRoformerParams]

Methods:

Name Description
forward

:param stft_repr: input spectrogram. shape (b, f*s, t, c)

__splifft_stem_selection_plan__

Remap mask_estimators.{i}.* state-dict entries to a compact

Attributes:

Name Type Description
stereo
audio_channels
num_stems
use_torch_checkpoint
skip_connection
layers
shared_qkv_bias Parameter | None
shared_out_bias Parameter | None
is_mel
final_norm
band_split
mask_estimators
debug
Source code in src/splifft/models/bs_roformer.py
def __init__(self, cfg: BSRoformerParams):
    super().__init__()
    self.stereo = cfg.stereo
    self.audio_channels = 2 if cfg.stereo else 1
    self.num_stems = len(cfg.output_stem_names)
    self.use_torch_checkpoint = cfg.use_torch_checkpoint
    self.skip_connection = cfg.skip_connection

    self.layers = ModuleList([])

    self.shared_qkv_bias: nn.Parameter | None = None
    self.shared_out_bias: nn.Parameter | None = None
    if cfg.use_shared_bias:
        dim_inner = cfg.heads * cfg.dim_head
        self.shared_qkv_bias = nn.Parameter(torch.ones(dim_inner * 3))
        self.shared_out_bias = nn.Parameter(torch.ones(cfg.dim))

    transformer = partial(
        Transformer,
        dim=cfg.dim,
        heads=cfg.heads,
        dim_head=cfg.dim_head,
        attn_dropout=cfg.attn_dropout,
        ff_dropout=cfg.ff_dropout,
        ff_mult=cfg.ff_mult,
        flash_attn=cfg.flash_attn,
        norm_output=cfg.norm_output,
        sage_attention=cfg.sage_attention,
        shared_qkv_bias=self.shared_qkv_bias,
        shared_out_bias=self.shared_out_bias,
        rms_norm_eps=cfg.rms_norm_eps,
        transformer_residual_dtype=cfg.transformer_residual_dtype,
    )

    t_frames = cfg.chunk_size // cfg.stft_hop_length + 1  # e.g. 588800 // 512 + 1 = 1151
    time_rotary_embed = RotaryEmbedding(
        seq_len=t_frames, dim_head=cfg.dim_head, dtype=cfg.rotary_embed_dtype
    )

    if is_mel := isinstance(cfg.band_config, MelBandsConfig):
        from torchaudio.functional import melscale_fbanks

        mel_cfg = cfg.band_config
        num_bands = mel_cfg.num_bands
        freqs = mel_cfg.stft_n_fft // 2 + 1
        mel_filter_bank = melscale_fbanks(
            n_freqs=freqs,
            f_min=0.0,
            f_max=float(mel_cfg.sample_rate / 2),
            n_mels=num_bands,
            sample_rate=mel_cfg.sample_rate,
            norm="slaney",
            mel_scale="slaney",
        ).T
        # TODO: adopt https://github.com/lucidrains/BS-RoFormer/issues/47
        mel_filter_bank[0, 0] = 1.0
        mel_filter_bank[-1, -1] = 1.0

        freqs_per_band_mask = mel_filter_bank > 0
        assert freqs_per_band_mask.any(dim=0).all(), (
            "all frequencies must be covered by at least one band"
        )

        repeated_freq_indices = repeat(torch.arange(freqs), "f -> b f", b=num_bands)
        freq_indices = repeated_freq_indices[freqs_per_band_mask]
        if self.stereo:
            freq_indices = repeat(freq_indices, "f -> f s", s=2)
            freq_indices = freq_indices * 2 + torch.arange(2)
            freq_indices = rearrange(freq_indices, "f s -> (f s)")
        self.register_buffer("freq_indices", freq_indices, persistent=False)
        self.register_buffer("freqs_per_band_mask", freqs_per_band_mask, persistent=False)

        num_freqs_per_band = reduce(freqs_per_band_mask, "b f -> b", "sum")
        num_bands_per_freq = reduce(freqs_per_band_mask, "b f -> f", "sum")

        self.register_buffer("num_freqs_per_band", num_freqs_per_band, persistent=False)
        self.register_buffer("num_bands_per_freq", num_bands_per_freq, persistent=False)

    elif isinstance(cfg.band_config, FixedBandsConfig):
        num_freqs_per_band = torch.tensor(cfg.band_config.freqs_per_bands)
        num_bands = len(cfg.band_config.freqs_per_bands)
    else:
        raise TypeError(f"unknown band config: {cfg.band_config}")
    self.is_mel = is_mel

    freq_rotary_embed = RotaryEmbedding(
        seq_len=num_bands, dim_head=cfg.dim_head, dtype=cfg.rotary_embed_dtype
    )

    for _ in range(cfg.depth):
        tran_modules = []
        if cfg.linear_transformer_depth > 0:
            tran_modules.append(
                transformer(depth=cfg.linear_transformer_depth, linear_attn=True)
            )
        tran_modules.append(
            transformer(depth=cfg.time_transformer_depth, rotary_embed=time_rotary_embed)
        )
        tran_modules.append(
            transformer(depth=cfg.freq_transformer_depth, rotary_embed=freq_rotary_embed)
        )
        self.layers.append(nn.ModuleList(tran_modules))

    self.final_norm = (
        rms_norm(cfg.dim, eps=cfg.rms_norm_eps) if not self.is_mel else nn.Identity()
    )

    freqs_per_bands_with_complex = tuple(
        2 * f * self.audio_channels for f in num_freqs_per_band.tolist()
    )

    self.band_split = BandSplit(
        dim=cfg.dim,
        dim_inputs=freqs_per_bands_with_complex,
        rms_norm_eps=cfg.rms_norm_eps,
    )

    self.mask_estimators = nn.ModuleList([])

    def build_hyperace(config: MaskEstimatorConfig) -> nn.Module:
        if isinstance(config, HyperAceResidualV1MaskEstimatorConfig):
            from .utils.hyperace import SegmModelHyperAceV1

            return SegmModelHyperAceV1(
                in_bands=len(freqs_per_bands_with_complex),
                in_dim=cfg.dim,
                out_bins=sum(freqs_per_bands_with_complex) // 4,
                num_hyperedges=config.num_hyperedges or 16,
                num_heads=config.num_heads,
            )
        if isinstance(config, HyperAceResidualV2MaskEstimatorConfig):
            from .utils.hyperace import SegmModelHyperAceV2

            return SegmModelHyperAceV2(
                in_bands=len(freqs_per_bands_with_complex),
                in_dim=cfg.dim,
                out_bins=sum(freqs_per_bands_with_complex) // 4,
                num_hyperedges=config.num_hyperedges or 32,
                num_heads=config.num_heads,
            )
        raise TypeError(f"mask estimator is not hyperace-based: {config}")

    def build_mask_estimator(config: MaskEstimatorConfig) -> nn.Module:
        if isinstance(config, BaselineMaskEstimatorConfig):
            return MaskEstimator(
                dim=cfg.dim,
                dim_inputs=freqs_per_bands_with_complex,
                depth=cfg.mask_estimator_depth,
                mlp_expansion_factor=cfg.mlp_expansion_factor,
            )

        if isinstance(config, AxialRefinerLargeV2MaskEstimatorConfig):
            return AxialRefinerLargeV2MaskEstimator(
                dim=cfg.dim,
                dim_inputs=freqs_per_bands_with_complex,
                mlp_depth=cfg.mask_estimator_depth,
                mlp_expansion_factor=cfg.mlp_expansion_factor,
                axial_refiner_depth=config.axial_refiner_depth,
                t_frames=t_frames,
                num_bands=num_bands,
                rotary_embed_dtype=cfg.rotary_embed_dtype,
            )

        if isinstance(
            config,
            HyperAceResidualV1MaskEstimatorConfig | HyperAceResidualV2MaskEstimatorConfig,
        ):
            return HyperAceResidualMaskEstimator(
                dim=cfg.dim,
                dim_inputs=freqs_per_bands_with_complex,
                depth=cfg.mask_estimator_depth,
                mlp_expansion_factor=cfg.mlp_expansion_factor,
                segm=build_hyperace(config),
            )

        raise TypeError(f"unknown mask_estimator config: {config}")

    for _ in range(len(cfg.output_stem_names)):
        self.mask_estimators.append(build_mask_estimator(cfg.mask_estimator))

    self.debug = cfg.debug
stereo instance-attribute
stereo = stereo
audio_channels instance-attribute
audio_channels = 2 if stereo else 1
num_stems instance-attribute
num_stems = len(output_stem_names)
use_torch_checkpoint instance-attribute
use_torch_checkpoint = use_torch_checkpoint
skip_connection instance-attribute
skip_connection = skip_connection
layers instance-attribute
layers = ModuleList([])
shared_qkv_bias instance-attribute
shared_qkv_bias: Parameter | None = None
shared_out_bias instance-attribute
shared_out_bias: Parameter | None = None
is_mel instance-attribute
is_mel = is_mel
final_norm instance-attribute
final_norm = (
    rms_norm(dim, eps=rms_norm_eps)
    if not is_mel
    else Identity()
)
band_split instance-attribute
band_split = BandSplit(
    dim=dim,
    dim_inputs=freqs_per_bands_with_complex,
    rms_norm_eps=rms_norm_eps,
)
mask_estimators instance-attribute
mask_estimators = ModuleList([])
debug instance-attribute
debug = debug
forward
forward(stft_repr: Tensor) -> Tensor

Parameters:

Name Type Description Default
stft_repr Tensor

input spectrogram. shape (b, f*s, t, c)

required

Returns:

Type Description
Tensor

estimated mask. shape (b, n, f*s, t, c)

Source code in src/splifft/models/bs_roformer.py
def forward(self, stft_repr: Tensor) -> Tensor:
    """
    :param stft_repr: input spectrogram. shape (b, f*s, t, c)
    :return: estimated mask. shape (b, n, f*s, t, c)
    """
    batch, _, t_frames, _ = stft_repr.shape
    device = stft_repr.device
    if self.is_mel:
        batch_arange = torch.arange(batch, device=device)[..., None]
        x = stft_repr[batch_arange, cast(Tensor, self.freq_indices)]
        x = rearrange(x, "b f t c -> b t (f c)")
    else:
        x = rearrange(stft_repr, "b f t c -> b t (f c)")

    if self.debug and (torch.isnan(x).any() or torch.isinf(x).any()):
        raise RuntimeError(
            f"nan/inf in x after rearrange: {x.isnan().sum()} nans, {x.isinf().sum()} infs"
        )

    if self.use_torch_checkpoint:
        x = cast(Tensor, checkpoint(self.band_split, x, use_reentrant=False))
    else:
        x = cast(Tensor, self.band_split(x))

    if self.debug and (torch.isnan(x).any() or torch.isinf(x).any()):
        raise RuntimeError(
            f"nan/inf in x after band_split: {x.isnan().sum()} nans, {x.isinf().sum()} infs"
        )

    # axial / hierarchical attention

    store: list[Tensor | None] = [None] * len(self.layers)
    for i, transformer_block in enumerate(self.layers):
        block = cast(ModuleList, transformer_block)
        if len(block) == 3:
            linear_transformer, time_transformer, freq_transformer = block

            x, ft_ps = pack([x], "b * d")
            if self.use_torch_checkpoint:
                x = checkpoint(linear_transformer, x, use_reentrant=False)
            else:
                x = linear_transformer(x)
            (x,) = unpack(x, ft_ps, "b * d")
        else:
            time_transformer, freq_transformer = block

        if self.skip_connection:
            for j in range(i):
                if store[j] is not None:
                    assert x is not None
                    x = x + cast(Tensor, store[j])

        x = rearrange(x, "b t f d -> b f t d")
        x, ps = pack([x], "* t d")

        if self.use_torch_checkpoint:
            x = checkpoint(time_transformer, x, use_reentrant=False)
        else:
            x = time_transformer(x)

        (x,) = unpack(x, ps, "* t d")
        x = rearrange(x, "b f t d -> b t f d")
        x, ps = pack([x], "* f d")

        if self.use_torch_checkpoint:
            x = checkpoint(freq_transformer, x, use_reentrant=False)
        else:
            x = freq_transformer(x)

        (x,) = unpack(x, ps, "* f d")

        if self.skip_connection:
            store[i] = x

    x = self.final_norm(x)

    if self.use_torch_checkpoint:
        mask = torch.stack(
            [
                cast(Tensor, checkpoint(fn, x, use_reentrant=False))
                for fn in self.mask_estimators
            ],
            dim=1,
        )
    else:
        mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
    mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2)

    if not self.is_mel:
        return mask

    stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")
    # stft_repr may be fp16 but complex32 support is experimental so we upcast it early
    stft_repr_complex = torch.view_as_complex(stft_repr.to(torch.float32))

    masks_per_band_complex = torch.view_as_complex(mask)
    masks_per_band_complex = masks_per_band_complex.type(stft_repr_complex.dtype)

    scatter_indices = repeat(
        cast(Tensor, self.freq_indices),
        "f -> b n f t",
        b=batch,
        n=self.num_stems,
        t=stft_repr_complex.shape[-1],
    )
    stft_repr_expanded_stems = repeat(stft_repr_complex, "b 1 ... -> b n ...", n=self.num_stems)

    masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(
        2, scatter_indices, masks_per_band_complex
    )

    denom = cast(Tensor, repeat(self.num_bands_per_freq, "f -> (f r) 1", r=self.audio_channels))
    masks_averaged = masks_summed / denom.clamp(min=1e-8)

    return torch.view_as_real(masks_averaged).to(stft_repr.dtype)
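In the mel path, one STFT bin can belong to several overlapping mel bands, so the `scatter_add_` plus division by `num_bands_per_freq` averages the per-band masks back onto the full frequency axis. A pure-Python sketch with a hypothetical two-band layout:

```python
# Sketch of the mel-band mask recombination above: scatter-add per-band
# masks onto the full frequency axis, then average by band coverage.
def recombine(freq_indices, masks_per_band, num_freqs):
    summed = [0.0] * num_freqs
    counts = [0] * num_freqs
    for f_idx, m in zip(freq_indices, masks_per_band):
        summed[f_idx] += m      # scatter_add_
        counts[f_idx] += 1      # num_bands_per_freq
    return [s / max(c, 1) for s, c in zip(summed, counts)]

# hypothetical layout where freq bin 1 is covered by two bands
freq_indices = [0, 1, 1, 2]
masks = [0.5, 1.0, 0.0, 0.25]
print(recombine(freq_indices, masks, 3))  # [0.5, 0.5, 0.25]
```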
__splifft_stem_selection_plan__ classmethod
__splifft_stem_selection_plan__(
    model_params: BSRoformerParams,
    output_stem_names: tuple[ModelOutputStemName, ...],
) -> StemSelectionPlan[BSRoformerParams]

Remap mask_estimators.{i}.* state-dict entries to a compact [0..k) index range so unrelated per-stem heads are never instantiated or loaded.

Source code in src/splifft/models/bs_roformer.py
@classmethod
def __splifft_stem_selection_plan__(
    cls,
    model_params: BSRoformerParams,
    output_stem_names: tuple[t.ModelOutputStemName, ...],
) -> StemSelectionPlan[BSRoformerParams]:
    """Remap `mask_estimators.{i}.*` state-dict entries to a compact
    `[0..k)` index range so unrelated per-stem heads are never instantiated
    or loaded.
    """

    full_stem_names = tuple(model_params.output_stem_names)
    requested_stem_names = tuple(output_stem_names)
    if requested_stem_names == full_stem_names:
        return StemSelectionPlan(
            model_params=model_params,
            output_stem_names=full_stem_names,
        )

    selected_stem_indices = tuple(full_stem_names.index(name) for name in requested_stem_names)
    key_re = re.compile(r"^mask_estimators\.(\d+)\.(.+)$")
    index_remap = {src: dst for dst, src in enumerate(selected_stem_indices)}

    def state_dict_transform(state_dict: dict[str, Tensor]) -> dict[str, Tensor]:
        remapped: dict[str, Tensor] = {}
        for key, value in state_dict.items():
            if (m := key_re.match(key)) is None:
                remapped[key] = value
                continue
            if (src_idx := int(m.group(1))) not in index_remap:
                continue
            suffix = m.group(2)
            remapped[f"mask_estimators.{index_remap[src_idx]}.{suffix}"] = value
        return remapped

    return StemSelectionPlan(
        model_params=replace(model_params, output_stem_names=requested_stem_names),
        output_stem_names=requested_stem_names,
        state_dict_transform=state_dict_transform,
    )
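The key-remapping half of the plan can be exercised standalone: head weights for unselected stems are dropped, and surviving heads are renumbered to a compact `[0..k)` range so the smaller `ModuleList` loads cleanly. A self-contained sketch (hypothetical key names):

```python
import re

# Standalone sketch of the state_dict_transform above.
key_re = re.compile(r"^mask_estimators\.(\d+)\.(.+)$")

def remap_keys(keys, selected_stem_indices):
    index_remap = {src: dst for dst, src in enumerate(selected_stem_indices)}
    out = []
    for key in keys:
        m = key_re.match(key)
        if m is None:
            out.append(key)  # non-head weights pass through untouched
        elif (src := int(m.group(1))) in index_remap:
            out.append(f"mask_estimators.{index_remap[src]}.{m.group(2)}")
        # else: head belongs to an unselected stem -> dropped
    return out

keys = ["band_split.weight", "mask_estimators.0.w", "mask_estimators.2.w"]
print(remap_keys(keys, selected_stem_indices=(2,)))
# ['band_split.weight', 'mask_estimators.0.w']
```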

mdx23c

MDX23C.

See: https://arxiv.org/pdf/2306.09382

Classes:

Name Description
MDX23CParams
Upscale
Downscale
MDX23C

Functions:

Name Description
get_norm
get_act
build_tfc_tdf

MDX23CParams dataclass

MDX23CParams(
    chunk_size: ChunkSize,
    output_stem_names: tuple[ModelOutputStemName, ...],
    dim_f: Gt0[int],
    num_subbands: Gt0[int],
    num_scales: Gt0[int],
    scale: tuple[Gt0[int], ...],
    num_blocks_per_scale: Gt0[int],
    hidden_channels: Gt0[int],
    growth: Gt0[int],
    bottleneck_factor: Gt0[int],
    norm_type: Literal["BatchNorm", "InstanceNorm"]
    | str = "InstanceNorm",
    act_type: Literal["gelu", "relu", "elu"] | str = "gelu",
    stereo: bool = True,
)

Bases: ModelParamsLike

Attributes:

Name Type Description
chunk_size ChunkSize
output_stem_names tuple[ModelOutputStemName, ...]
dim_f Gt0[int]

The size of the frequency dimension fed into the network.

num_subbands Gt0[int]
num_scales Gt0[int]
scale tuple[Gt0[int], ...]

Downscaling factor per scale.

num_blocks_per_scale Gt0[int]
hidden_channels Gt0[int]

Base number of channels.

growth Gt0[int]

Channel growth per scale.

bottleneck_factor Gt0[int]
norm_type Literal['BatchNorm', 'InstanceNorm'] | str
act_type Literal['gelu', 'relu', 'elu'] | str
stereo bool
input_channels ModelInputChannels
input_type ModelInputType
output_type ModelOutputType
inference_archetype InferenceArchetype
chunk_size instance-attribute
chunk_size: ChunkSize
output_stem_names instance-attribute
output_stem_names: tuple[ModelOutputStemName, ...]
dim_f instance-attribute
dim_f: Gt0[int]

The size of the frequency dimension fed into the network. Usually smaller than n_fft // 2 + 1.

num_subbands instance-attribute
num_subbands: Gt0[int]
num_scales instance-attribute
num_scales: Gt0[int]
scale instance-attribute
scale: tuple[Gt0[int], ...]

Downscaling factor per scale.

num_blocks_per_scale instance-attribute
num_blocks_per_scale: Gt0[int]
hidden_channels instance-attribute
hidden_channels: Gt0[int]

Base number of channels.

growth instance-attribute
growth: Gt0[int]

Channel growth per scale.

bottleneck_factor instance-attribute
bottleneck_factor: Gt0[int]
norm_type class-attribute instance-attribute
norm_type: Literal["BatchNorm", "InstanceNorm"] | str = (
    "InstanceNorm"
)
act_type class-attribute instance-attribute
act_type: Literal['gelu', 'relu', 'elu'] | str = 'gelu'
stereo class-attribute instance-attribute
stereo: bool = True
input_channels property
input_channels: ModelInputChannels
input_type property
input_type: ModelInputType
output_type property
output_type: ModelOutputType
inference_archetype property
inference_archetype: InferenceArchetype

get_norm

get_norm(norm_type: str, channels: int) -> Module
Source code in src/splifft/models/mdx23c.py
def get_norm(norm_type: str, channels: int) -> nn.Module:
    if norm_type == "BatchNorm":
        return nn.BatchNorm2d(channels)
    elif norm_type == "InstanceNorm":
        return nn.InstanceNorm2d(channels, affine=True)
    elif "GroupNorm" in norm_type:
        g = int(norm_type.replace("GroupNorm", ""))
        return nn.GroupNorm(num_groups=g, num_channels=channels)
    return nn.Identity()
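The `norm_type` string doubles as a mini-config: anything containing `GroupNorm` carries its group count inline, e.g. `"GroupNorm8"`. A pure-Python sketch of just the parsing (the real function returns `nn` modules; note that a bare `"GroupNorm"` with no digits raises `ValueError`, same as the original):

```python
# String-parsing sketch of get_norm's dispatch, without torch.
def parse_norm_type(norm_type: str):
    if norm_type in ("BatchNorm", "InstanceNorm"):
        return norm_type, None
    if "GroupNorm" in norm_type:
        return "GroupNorm", int(norm_type.replace("GroupNorm", ""))
    return "Identity", None  # unrecognised strings fall through to Identity

print(parse_norm_type("GroupNorm8"))   # ('GroupNorm', 8)
print(parse_norm_type("LayerScale"))   # ('Identity', None)
```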

get_act

get_act(act_type: str) -> Module
Source code in src/splifft/models/mdx23c.py
def get_act(act_type: str) -> nn.Module:
    if act_type == "gelu":
        return nn.GELU()
    elif act_type == "relu":
        return nn.ReLU()
    elif act_type.startswith("elu"):
        try:
            alpha = float(act_type.replace("elu", ""))
        except ValueError:
            alpha = 1.0
        return nn.ELU(alpha)
    raise ValueError(f"unknown activation: {act_type}")
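Likewise for activations: `"elu"` uses the default alpha of 1.0, while e.g. `"elu0.5"` carries its alpha inline. The try/except is what makes the bare `"elu"` spelling work:

```python
# Sketch of the alpha parsing inside get_act's elu branch.
def parse_act_alpha(act_type: str) -> float:
    try:
        return float(act_type.replace("elu", ""))
    except ValueError:  # "elu" -> float("") fails -> default alpha
        return 1.0

print(parse_act_alpha("elu"))     # 1.0
print(parse_act_alpha("elu0.5"))  # 0.5
```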

build_tfc_tdf

build_tfc_tdf(
    in_c: int,
    c: int,
    blocks_per_scale: int,
    f: int,
    bn: int,
    norm_type: str,
    act_type: str,
) -> TfcTdf
Source code in src/splifft/models/mdx23c.py
def build_tfc_tdf(
    in_c: int,
    c: int,
    blocks_per_scale: int,
    f: int,
    bn: int,
    norm_type: str,
    act_type: str,
) -> TfcTdf:
    return TfcTdf(
        in_channels=in_c,
        out_channels=c,
        num_blocks=blocks_per_scale,
        f_bins=f,
        bottleneck_factor=bn,
        norm_factory=lambda channels: get_norm(norm_type, channels),
        act_factory=lambda: get_act(act_type),
    )

Upscale

Upscale(
    in_c: int,
    out_c: int,
    scale: tuple[int, int],
    norm_type: str,
    act_type: str,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
conv
Source code in src/splifft/models/mdx23c.py
def __init__(
    self,
    in_c: int,
    out_c: int,
    scale: tuple[int, int],
    norm_type: str,
    act_type: str,
):
    super().__init__()
    self.conv = nn.Sequential(
        get_norm(norm_type, in_c),
        get_act(act_type),
        nn.ConvTranspose2d(
            in_channels=in_c,
            out_channels=out_c,
            kernel_size=scale,
            stride=scale,
            bias=False,
        ),
    )
conv instance-attribute
conv = Sequential(
    get_norm(norm_type, in_c),
    get_act(act_type),
    ConvTranspose2d(
        in_channels=in_c,
        out_channels=out_c,
        kernel_size=scale,
        stride=scale,
        bias=False,
    ),
)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/mdx23c.py
def forward(self, x: Tensor) -> Tensor:
    return self.conv(x)  # type: ignore

Downscale

Downscale(
    in_c: int,
    out_c: int,
    scale: tuple[int, int],
    norm_type: str,
    act_type: str,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
conv
Source code in src/splifft/models/mdx23c.py
def __init__(
    self,
    in_c: int,
    out_c: int,
    scale: tuple[int, int],
    norm_type: str,
    act_type: str,
):
    super().__init__()
    self.conv = nn.Sequential(
        get_norm(norm_type, in_c),
        get_act(act_type),
        nn.Conv2d(
            in_channels=in_c,
            out_channels=out_c,
            kernel_size=scale,
            stride=scale,
            bias=False,
        ),
    )
conv instance-attribute
conv = Sequential(
    get_norm(norm_type, in_c),
    get_act(act_type),
    Conv2d(
        in_channels=in_c,
        out_channels=out_c,
        kernel_size=scale,
        stride=scale,
        bias=False,
    ),
)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/mdx23c.py
def forward(self, x: Tensor) -> Tensor:
    return self.conv(x)  # type: ignore

MDX23C

MDX23C(cfg: MDX23CParams)

Bases: Module, SupportsStemSelection[MDX23CParams]

Methods:

Name Description
cac2cws
cws2cac
forward

:param x: input spectrogram (B, F*S, T, 2)

__splifft_stem_selection_plan__

Slice final_conv.2 per requested stem.

Attributes:

Name Type Description
cfg
num_target_instruments
audio_channels
num_subbands
first_conv
encoder_blocks
bottleneck_block
decoder_blocks
final_conv
Source code in src/splifft/models/mdx23c.py
def __init__(self, cfg: MDX23CParams):
    super().__init__()
    self.cfg = cfg
    if len(cfg.scale) != 2:
        raise ValueError(f"expected `scale` to have 2 elements, got {cfg.scale}")
    scale_2d = (cfg.scale[0], cfg.scale[1])
    self.num_target_instruments = len(cfg.output_stem_names)
    self.audio_channels = 2 if cfg.stereo else 1
    self.num_subbands = cfg.num_subbands

    dim_c = self.num_subbands * self.audio_channels * 2
    n = cfg.num_scales
    blocks_per_scale = cfg.num_blocks_per_scale
    c = cfg.hidden_channels
    g = cfg.growth
    bn = cfg.bottleneck_factor
    f = cfg.dim_f // self.num_subbands

    self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False)

    self.encoder_blocks = nn.ModuleList()
    for _ in range(n):
        block = nn.Module()
        block.tfc_tdf = build_tfc_tdf(
            c, c, blocks_per_scale, f, bn, cfg.norm_type, cfg.act_type
        )
        block.downscale = Downscale(c, c + g, scale_2d, cfg.norm_type, cfg.act_type)
        f = f // scale_2d[1]
        c += g
        self.encoder_blocks.append(block)

    self.bottleneck_block = build_tfc_tdf(
        c, c, blocks_per_scale, f, bn, cfg.norm_type, cfg.act_type
    )

    self.decoder_blocks = nn.ModuleList()
    for _ in range(n):
        block = nn.Module()
        block.upscale = Upscale(c, c - g, scale_2d, cfg.norm_type, cfg.act_type)
        f = f * scale_2d[1]
        c -= g
        block.tfc_tdf = build_tfc_tdf(
            2 * c, c, blocks_per_scale, f, bn, cfg.norm_type, cfg.act_type
        )
        self.decoder_blocks.append(block)

    self.final_conv = nn.Sequential(
        nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False),
        get_act(cfg.act_type),
        nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False),
    )
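The encoder loop above widens channels by `growth` and shrinks the frequency axis by `scale[1]` at every scale (the decoder mirrors it in reverse). A pure-Python sketch of that schedule, using hypothetical config values:

```python
# Channel/frequency schedule of MDX23C's encoder, mirroring __init__.
def encoder_schedule(hidden_channels, growth, dim_f, num_subbands,
                     num_scales, freq_scale):
    c, f = hidden_channels, dim_f // num_subbands
    schedule = []
    for _ in range(num_scales):
        schedule.append((c, f))  # (channels, freq bins) entering this scale
        f //= freq_scale         # Downscale shrinks the frequency axis
        c += growth              # and widens the channel axis
    return schedule, (c, f)      # trailing tuple: bottleneck shape

sched, bottleneck = encoder_schedule(
    hidden_channels=32, growth=32, dim_f=1024, num_subbands=4,
    num_scales=3, freq_scale=2)  # hypothetical config values
print(sched)       # [(32, 256), (64, 128), (96, 64)]
print(bottleneck)  # (128, 32)
```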
cfg instance-attribute
cfg = cfg
num_target_instruments instance-attribute
num_target_instruments = len(output_stem_names)
audio_channels instance-attribute
audio_channels = 2 if stereo else 1
num_subbands instance-attribute
num_subbands = num_subbands
first_conv instance-attribute
first_conv = Conv2d(dim_c, c, 1, 1, 0, bias=False)
encoder_blocks instance-attribute
encoder_blocks = ModuleList()
bottleneck_block instance-attribute
bottleneck_block = build_tfc_tdf(
    c, c, blocks_per_scale, f, bn, norm_type, act_type
)
decoder_blocks instance-attribute
decoder_blocks = ModuleList()
final_conv instance-attribute
final_conv = Sequential(
    Conv2d(c + dim_c, c, 1, 1, 0, bias=False),
    get_act(act_type),
    Conv2d(
        c,
        num_target_instruments * dim_c,
        1,
        1,
        0,
        bias=False,
    ),
)
cac2cws
cac2cws(x: Tensor) -> Tensor
Source code in src/splifft/models/mdx23c.py
def cac2cws(self, x: Tensor) -> Tensor:
    return rearrange(x, "b c (k f) t -> b (c k) f t", k=self.num_subbands)
cws2cac
cws2cac(x: Tensor) -> Tensor
Source code in src/splifft/models/mdx23c.py
def cws2cac(self, x: Tensor) -> Tensor:
    return rearrange(x, "b (c k) f t -> b c (k f) t", k=self.num_subbands)
forward
forward(x: Tensor) -> Tensor

Parameters:

Name Type Description Default
x Tensor

input spectrogram (B, F*S, T, 2)

required

Returns:

Type Description
Tensor

output spectrogram (B, N, F*S, T, 2)

Source code in src/splifft/models/mdx23c.py
def forward(self, x: Tensor) -> Tensor:
    """
    :param x: input spectrogram (B, F*S, T, 2)
    :return: output spectrogram (B, N, F*S, T, 2)
    """
    b, fs, t, ri = x.shape
    if ri != 2:
        raise ValueError(f"expected final complex axis of size 2, got {ri}")
    f_full = fs // self.audio_channels
    if fs % self.audio_channels != 0:
        raise ValueError(
            f"expected frequency-channel axis divisible by audio channels ({self.audio_channels}), "
            f"got {fs}"
        )

    x_in = rearrange(
        x,
        "b (f s) t ri -> b (s ri) f t",
        s=self.audio_channels,
        ri=2,
    )
    x_in = x_in[..., : self.cfg.dim_f, :]
    mix = x_in = self.cac2cws(x_in)
    first_conv_out = x_in = self.first_conv(x_in)
    x_in = rearrange(x_in, "b c f t -> b c t f")

    encoder_outputs = []
    for block in self.encoder_blocks:
        x_in = block.tfc_tdf(x_in)  # type: ignore
        encoder_outputs.append(x_in)
        x_in = block.downscale(x_in)  # type: ignore

    x_in = self.bottleneck_block(x_in)

    for block in self.decoder_blocks:
        x_in = block.upscale(x_in)  # type: ignore
        x_in = torch.cat([x_in, encoder_outputs.pop()], 1)
        x_in = block.tfc_tdf(x_in)  # type: ignore

    x_in = rearrange(x_in, "b c t f -> b c f t")
    x_in = x_in * first_conv_out
    x_in = self.final_conv(torch.cat([mix, x_in], 1))
    x_in = self.cws2cac(x_in)

    x_in = rearrange(
        x_in,
        "b (n c) f t -> b n c f t",
        n=self.num_target_instruments,
    )

    if f_full > self.cfg.dim_f:
        pad_size = f_full - self.cfg.dim_f
        x_in = torch.nn.functional.pad(x_in, (0, 0, 0, pad_size))

    x_in = rearrange(
        x_in,
        "b n (s ri) f t -> b n (f s) t ri",
        s=self.audio_channels,
        ri=2,
    )

    return x_in
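`forward` commits to a shape contract: `(B, F*S, T, 2)` in, `(B, N, F*S, T, 2)` out. The entry and exit `rearrange` calls are mutually inverse, which the following NumPy sketch checks by skipping the network entirely (hypothetical sizes; NumPy transposes stand in for `einops.rearrange`):

```python
import numpy as np

b, f_full, s, t = 1, 6, 2, 4  # batch, freq bins, audio channels, frames
rng = np.random.default_rng(0)
x = rng.standard_normal((b, f_full * s, t, 2)).astype(np.float32)

# entry: "b (f s) t ri -> b (s ri) f t"
x_in = x.reshape(b, f_full, s, t, 2).transpose(0, 2, 4, 1, 3).reshape(b, s * 2, f_full, t)

# exit (with n=1 stems): "b n (s ri) f t -> b n (f s) t ri"
y = x_in[:, None]  # pretend the network was the identity
out = y.reshape(b, 1, s, 2, f_full, t).transpose(0, 1, 4, 2, 5, 3).reshape(b, 1, f_full * s, t, 2)

assert out.shape == (1, 1, 12, 4, 2)
assert np.array_equal(out[:, 0], x)  # the two rearranges cancel
```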
__splifft_stem_selection_plan__ classmethod
__splifft_stem_selection_plan__(
    model_params: MDX23CParams,
    output_stem_names: tuple[ModelOutputStemName, ...],
) -> StemSelectionPlan[MDX23CParams]

Slice final_conv.2 per requested stem.

Source code in src/splifft/models/mdx23c.py
@classmethod
def __splifft_stem_selection_plan__(
    cls,
    model_params: MDX23CParams,
    output_stem_names: tuple[t.ModelOutputStemName, ...],
) -> StemSelectionPlan[MDX23CParams]:
    """Slice `final_conv.2` per requested stem."""

    full_stem_names = tuple(model_params.output_stem_names)
    requested_stem_names = tuple(output_stem_names)
    if requested_stem_names == full_stem_names:
        return StemSelectionPlan(
            model_params=model_params,
            output_stem_names=full_stem_names,
        )

    selected_stem_indices = tuple(full_stem_names.index(name) for name in requested_stem_names)
    dim_c = model_params.num_subbands * (2 if model_params.stereo else 1) * 2

    def state_dict_transform(state_dict: dict[str, Tensor]) -> dict[str, Tensor]:
        key = "final_conv.2.weight"
        if key not in state_dict:
            return state_dict

        transformed = dict(state_dict)
        weight = transformed[key]
        grouped = weight.reshape(len(full_stem_names), dim_c, *weight.shape[1:])
        selected = grouped.index_select(
            0,
            torch.tensor(selected_stem_indices, device=weight.device),
        )
        transformed[key] = selected.reshape(
            len(requested_stem_names) * dim_c, *weight.shape[1:]
        )
        return transformed

    return StemSelectionPlan(
        model_params=replace(model_params, output_stem_names=requested_stem_names),
        output_stem_names=requested_stem_names,
        state_dict_transform=state_dict_transform,
    )
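The transform above exploits that `final_conv.2` is a 1x1 conv whose output channels are laid out stem-major: rows `[i*dim_c : (i+1)*dim_c]` belong to stem `i`, so dropping a stem is just a row slice of the weight. A NumPy sketch with hypothetical sizes:

```python
import numpy as np

n_stems, dim_c, in_c = 4, 8, 16  # hypothetical channel counts
weight = np.arange(n_stems * dim_c * in_c, dtype=np.float32).reshape(
    n_stems * dim_c, in_c, 1, 1
)

selected = (1, 3)  # e.g. keep the 2nd and 4th stems
grouped = weight.reshape(n_stems, dim_c, in_c, 1, 1)
sliced = grouped[list(selected)].reshape(len(selected) * dim_c, in_c, 1, 1)

# rows of stem 1 now sit first in the sliced weight
assert np.array_equal(sliced[:dim_c], weight[dim_c : 2 * dim_c])
assert sliced.shape == (2 * dim_c, in_c, 1, 1)
```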

basic_pitch

ICASSP 2022 Basic Pitch. Raw multi-stream outputs only, no symbolic decoding.

See: https://github.com/spotify/basic-pitch, https://arxiv.org/abs/2203.09893

Classes:

Name Description
BasicPitchParams
HarmonicStacking
BasicPitch

BasicPitchParams dataclass

BasicPitchParams(
    chunk_size: ChunkSize,
    output_stem_names: tuple[ModelOutputStemName, ...],
    n_semitones: Gt0[int] = 88,
    contour_bins_per_semitone: Gt0[int] = 3,
    cqt_bins_per_semitone: Gt0[int] = 3,
    cqt_n_bins: Gt0[int] = 372,
    stack_harmonics: tuple[Gt0[float], ...] = (
        0.5,
        1.0,
        2.0,
        3.0,
        4.0,
        5.0,
        6.0,
        7.0,
    ),
)

Bases: ModelParamsLike

Attributes:

Name Type Description
chunk_size ChunkSize
output_stem_names tuple[ModelOutputStemName, ...]
n_semitones Gt0[int]
contour_bins_per_semitone Gt0[int]
cqt_bins_per_semitone Gt0[int]
cqt_n_bins Gt0[int]
stack_harmonics tuple[Gt0[float], ...]
input_channels ModelInputChannels
input_type ModelInputType
output_type ModelOutputType
inference_archetype InferenceArchetype
chunk_size instance-attribute
chunk_size: ChunkSize
output_stem_names instance-attribute
output_stem_names: tuple[ModelOutputStemName, ...]
n_semitones class-attribute instance-attribute
n_semitones: Gt0[int] = 88
contour_bins_per_semitone class-attribute instance-attribute
contour_bins_per_semitone: Gt0[int] = 3
cqt_bins_per_semitone class-attribute instance-attribute
cqt_bins_per_semitone: Gt0[int] = 3
cqt_n_bins class-attribute instance-attribute
cqt_n_bins: Gt0[int] = 372
stack_harmonics class-attribute instance-attribute
stack_harmonics: tuple[Gt0[float], ...] = (
    0.5,
    1.0,
    2.0,
    3.0,
    4.0,
    5.0,
    6.0,
    7.0,
)
input_channels property
input_channels: ModelInputChannels
input_type property
input_type: ModelInputType
output_type property
output_type: ModelOutputType
inference_archetype property
inference_archetype: InferenceArchetype

HarmonicStacking

HarmonicStacking(
    *,
    bins_per_semitone: int,
    harmonics: tuple[float, ...],
    n_output_freqs: int,
)

Bases: Module

Methods:

Name Description
forward

:param x: (B, T, F)

Attributes:

Name Type Description
n_output_freqs
shifts
Source code in src/splifft/models/basic_pitch.py
def __init__(
    self,
    *,
    bins_per_semitone: int,
    harmonics: tuple[float, ...],
    n_output_freqs: int,
):
    super().__init__()
    self.n_output_freqs = n_output_freqs
    self.shifts = [int(round(12.0 * bins_per_semitone * math.log2(h))) for h in harmonics]
n_output_freqs instance-attribute
n_output_freqs = n_output_freqs
shifts instance-attribute
shifts = [
    (int(round(12.0 * bins_per_semitone * log2(h))))
    for h in harmonics
]
forward
forward(x: Tensor) -> Tensor

Parameters:

Name Type Description Default
x Tensor

(B, T, F)

required

Returns:

Type Description
Tensor

(B, H, T, F_out)

Source code in src/splifft/models/basic_pitch.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    :param x: (B, T, F)
    :return: (B, H, T, F_out)
    """
    stacked: list[torch.Tensor] = []
    for shift in self.shifts:
        if shift == 0:
            shifted = x
        elif shift > 0:
            shifted = F.pad(x[:, :, shift:], (0, shift))
        else:
            shifted = F.pad(x[:, :, :shift], (-shift, 0))
        stacked.append(shifted[:, :, : self.n_output_freqs])
    return torch.stack(stacked, dim=1)
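Each harmonic `h` maps to a constant shift of `round(12 * bins_per_semitone * log2(h))` CQT bins, since harmonics are fixed intervals on a log-frequency axis. With the default `cqt_bins_per_semitone=3` this gives (pure-Python check):

```python
import math

bins_per_semitone = 3  # default cqt_bins_per_semitone
harmonics = (0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0)
shifts = [int(round(12.0 * bins_per_semitone * math.log2(h))) for h in harmonics]

# the sub-harmonic 0.5 shifts down one octave (36 bins); 2.0 shifts up one octave
assert shifts == [-36, 0, 36, 57, 72, 84, 93, 101]
```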

BasicPitch

BasicPitch(cfg: BasicPitchParams)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
cfg
n_contour_bins
hs
conv_contour
conv_note
conv_onset_pre
conv_onset_post
Source code in src/splifft/models/basic_pitch.py
def __init__(self, cfg: BasicPitchParams):
    super().__init__()
    self.cfg = cfg
    self.n_contour_bins = cfg.n_semitones * cfg.contour_bins_per_semitone

    self.hs = HarmonicStacking(
        bins_per_semitone=cfg.cqt_bins_per_semitone,
        harmonics=cfg.stack_harmonics,
        n_output_freqs=self.n_contour_bins,
    )

    num_in_channels = len(cfg.stack_harmonics)
    self.conv_contour = nn.Sequential(
        nn.Conv2d(num_in_channels, 8, kernel_size=(3, 39), padding="same"),
        nn.BatchNorm2d(8, eps=0.001),
        nn.ReLU(),
        nn.Conv2d(8, 1, kernel_size=5, padding="same"),
        nn.Sigmoid(),
    )
    self.conv_note = nn.Sequential(
        nn.Conv2d(1, 32, kernel_size=7, stride=(1, 3)),
        nn.ReLU(),
        nn.Conv2d(32, 1, kernel_size=(7, 3), padding="same"),
        nn.Sigmoid(),
    )
    self.conv_onset_pre = nn.Sequential(
        nn.Conv2d(num_in_channels, 32, kernel_size=5, stride=(1, 3)),
        nn.BatchNorm2d(32, eps=0.001),
        nn.ReLU(),
    )
    self.conv_onset_post = nn.Sequential(
        nn.Conv2d(33, 1, kernel_size=3, stride=1, padding="same"),
        nn.Sigmoid(),
    )
cfg instance-attribute
cfg = cfg
n_contour_bins instance-attribute
n_contour_bins = n_semitones * contour_bins_per_semitone
hs instance-attribute
hs = HarmonicStacking(
    bins_per_semitone=cqt_bins_per_semitone,
    harmonics=stack_harmonics,
    n_output_freqs=n_contour_bins,
)
conv_contour instance-attribute
conv_contour = Sequential(
    Conv2d(
        num_in_channels,
        8,
        kernel_size=(3, 39),
        padding="same",
    ),
    BatchNorm2d(8, eps=0.001),
    ReLU(),
    Conv2d(8, 1, kernel_size=5, padding="same"),
    Sigmoid(),
)
conv_note instance-attribute
conv_note = Sequential(
    Conv2d(1, 32, kernel_size=7, stride=(1, 3)),
    ReLU(),
    Conv2d(32, 1, kernel_size=(7, 3), padding="same"),
    Sigmoid(),
)
conv_onset_pre instance-attribute
conv_onset_pre = Sequential(
    Conv2d(
        num_in_channels, 32, kernel_size=5, stride=(1, 3)
    ),
    BatchNorm2d(32, eps=0.001),
    ReLU(),
)
conv_onset_post instance-attribute
conv_onset_post = Sequential(
    Conv2d(33, 1, kernel_size=3, stride=1, padding="same"),
    Sigmoid(),
)
forward
forward(x: Tensor) -> dict[str, Tensor]
Source code in src/splifft/models/basic_pitch.py
def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
    if x.ndim != 3:
        raise ValueError(f"expected `(B,T,F)` input, got {tuple(x.shape)}")
    if x.shape[-1] != self.cfg.cqt_n_bins:
        raise ValueError(f"expected feature dim {self.cfg.cqt_n_bins}, got {x.shape[-1]}")

    cqt = self.hs(x)

    contour = self.conv_contour(cqt)

    contour_for_note = F.pad(contour, (2, 2, 3, 3))
    note = self.conv_note(contour_for_note)

    cqt_for_onset = F.pad(cqt, (1, 1, 2, 2))
    onset_pre = self.conv_onset_pre(cqt_for_onset)
    onset_in = torch.cat((note, onset_pre), dim=1)
    onset = self.conv_onset_post(onset_in)

    contour_out = contour.squeeze(1)
    note_out = note.squeeze(1)
    onset_out = onset.squeeze(1)

    return {
        "onset": onset_out,
        "note": note_out,
        "contour": contour_out,
    }
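The asymmetric pads before `conv_note` and `conv_onset_pre` exist so the stride-3 frequency convolutions land exactly on one bin per semitone. A small arithmetic check under the default parameters (88 semitones x 3 contour bins; `conv_out` is the standard conv size formula, not part of the source):

```python
def conv_out(size: int, kernel: int, stride: int, pad_total: int) -> int:
    # standard convolution output-size formula
    return (size + pad_total - kernel) // stride + 1

n_contour_bins = 88 * 3  # n_semitones * contour_bins_per_semitone

# conv_note: F.pad(..., (2, 2, 3, 3)) pads freq by 2+2, then Conv2d(kernel=7, stride=(1, 3))
note_f = conv_out(n_contour_bins, kernel=7, stride=3, pad_total=4)
# conv_onset_pre: F.pad(..., (1, 1, 2, 2)) pads freq by 1+1, then Conv2d(kernel=5, stride=(1, 3))
onset_f = conv_out(n_contour_bins, kernel=5, stride=3, pad_total=2)

assert note_f == 88   # one bin per semitone
assert onset_f == 88  # matches `note` so they can be concatenated
```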

utils

Modules:

Name Description
attend
attend_sage
hyperace

HyperACE segmentation backbones for BS-RoFormer mask heads.

stft
tfc_tdf

Time-Frequency Convolutions and Time-Distributed Fully-connected (TFC-TDF)

Functions:

Name Description
parse_version
log_once

parse_version

parse_version(v_str: str) -> tuple[int, ...]
Source code in src/splifft/models/utils/__init__.py
def parse_version(v_str: str) -> tuple[int, ...]:
    # e.g. "2.1.0+cu118" -> (2, 1, 0)
    return tuple(map(int, v_str.split("+")[0].split(".")))
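The function strips any local-version suffix before splitting, so CUDA-tagged builds compare cleanly as integer tuples (self-contained copy):

```python
def parse_version(v_str: str) -> tuple[int, ...]:
    # drop the local segment ("+cu118") and parse the release numbers
    return tuple(map(int, v_str.split("+")[0].split(".")))

assert parse_version("2.1.0+cu118") == (2, 1, 0)
assert parse_version("1.13.1") >= (1, 13)  # tuple comparison is elementwise
```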

log_once cached

log_once(
    logger: Logger, msg: object, *, level: int = DEBUG
) -> None
Source code in src/splifft/models/utils/__init__.py
@lru_cache(10)
def log_once(logger: Logger, msg: object, *, level: int = logging.DEBUG) -> None:
    logger.log(level, msg)
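`lru_cache` keys on `(logger, msg, level)`, so a repeated message is a cache hit and never reaches the logger again. A quick self-contained check:

```python
import logging
from functools import lru_cache

@lru_cache(10)
def log_once(logger: logging.Logger, msg: object, *, level: int = logging.DEBUG) -> None:
    logger.log(level, msg)

logger = logging.getLogger("log_once_demo")
log_once(logger, "flash attention unavailable")
log_once(logger, "flash attention unavailable")  # cache hit: not logged twice

assert log_once.cache_info().hits == 1
assert log_once.cache_info().misses == 1
```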

tfc_tdf

Time-Frequency Convolutions and Time-Distributed Fully-connected (TFC-TDF)

See: https://arxiv.org/pdf/2306.09382

Classes:

Name Description
TfcTdfBlock
TfcTdf

Functions:

Name Description
instance_norm_factory
silu_factory

Attributes:

Name Type Description
NormFactory
ActFactory
NormFactory module-attribute
NormFactory = Callable[[int], Module]
ActFactory module-attribute
ActFactory = Callable[[], Module]
instance_norm_factory
instance_norm_factory(channels: int) -> Module
Source code in src/splifft/models/utils/tfc_tdf.py
def instance_norm_factory(channels: int) -> nn.Module:
    return nn.InstanceNorm2d(channels, affine=True, eps=1e-8)
silu_factory
silu_factory() -> Module
Source code in src/splifft/models/utils/tfc_tdf.py
def silu_factory() -> nn.Module:
    return nn.SiLU()
TfcTdfBlock
TfcTdfBlock(
    in_channels: int,
    out_channels: int,
    f_bins: int,
    bottleneck_factor: int,
    *,
    norm_factory: NormFactory,
    act_factory: ActFactory,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
tfc1
tdf
tfc2
shortcut
Source code in src/splifft/models/utils/tfc_tdf.py
def __init__(
    self,
    in_channels: int,
    out_channels: int,
    f_bins: int,
    bottleneck_factor: int,
    *,
    norm_factory: NormFactory,
    act_factory: ActFactory,
):
    super().__init__()

    self.tfc1 = nn.Sequential(
        norm_factory(in_channels),
        act_factory(),
        nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
    )
    self.tdf = nn.Sequential(
        norm_factory(out_channels),
        act_factory(),
        nn.Linear(f_bins, f_bins // bottleneck_factor, bias=False),
        norm_factory(out_channels),
        act_factory(),
        nn.Linear(f_bins // bottleneck_factor, f_bins, bias=False),
    )
    self.tfc2 = nn.Sequential(
        norm_factory(out_channels),
        act_factory(),
        nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
    )
    self.shortcut = nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False)
tfc1 instance-attribute
tfc1 = Sequential(
    norm_factory(in_channels),
    act_factory(),
    Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
)
tdf instance-attribute
tdf = Sequential(
    norm_factory(out_channels),
    act_factory(),
    Linear(f_bins, f_bins // bottleneck_factor, bias=False),
    norm_factory(out_channels),
    act_factory(),
    Linear(f_bins // bottleneck_factor, f_bins, bias=False),
)
tfc2 instance-attribute
tfc2 = Sequential(
    norm_factory(out_channels),
    act_factory(),
    Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
)
shortcut instance-attribute
shortcut = Conv2d(
    in_channels, out_channels, 1, 1, 0, bias=False
)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/utils/tfc_tdf.py
def forward(self, x: Tensor) -> Tensor:
    s = self.shortcut(x)
    x = self.tfc1(x)
    x = x + self.tdf(x)
    x = self.tfc2(x)
    x = x + s
    return x
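Inside the block the tensor is oriented `(B, C, T, F)`, so the `nn.Linear` layers of `tdf` act on the trailing frequency axis, squeezing it through a bottleneck of `f_bins // bottleneck_factor` and back. A NumPy sketch of just that shape contract (random weights, hypothetical sizes):

```python
import numpy as np

b, c, t, f_bins, bottleneck = 2, 4, 8, 16, 4
rng = np.random.default_rng(0)
x = rng.standard_normal((b, c, t, f_bins))

w_down = rng.standard_normal((f_bins // bottleneck, f_bins))  # Linear(f, f // bn)
w_up = rng.standard_normal((f_bins, f_bins // bottleneck))    # Linear(f // bn, f)

h = x @ w_down.T      # (B, C, T, F // bn): matmul broadcasts over leading axes
y = x + h @ w_up.T    # residual add, back at (B, C, T, F)
assert y.shape == x.shape
```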
TfcTdf
TfcTdf(
    in_channels: int,
    out_channels: int,
    num_blocks: int,
    f_bins: int,
    bottleneck_factor: int,
    *,
    norm_factory: NormFactory,
    act_factory: ActFactory,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
blocks
Source code in src/splifft/models/utils/tfc_tdf.py
def __init__(
    self,
    in_channels: int,
    out_channels: int,
    num_blocks: int,
    f_bins: int,
    bottleneck_factor: int,
    *,
    norm_factory: NormFactory,
    act_factory: ActFactory,
):
    super().__init__()

    self.blocks = nn.ModuleList(
        [
            TfcTdfBlock(
                in_channels=in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                f_bins=f_bins,
                bottleneck_factor=bottleneck_factor,
                norm_factory=norm_factory,
                act_factory=act_factory,
            )
            for i in range(num_blocks)
        ]
    )
blocks instance-attribute
blocks = ModuleList(
    [
        (
            TfcTdfBlock(
                in_channels=in_channels
                if i == 0
                else out_channels,
                out_channels=out_channels,
                f_bins=f_bins,
                bottleneck_factor=bottleneck_factor,
                norm_factory=norm_factory,
                act_factory=act_factory,
            )
        )
        for i in (range(num_blocks))
    ]
)
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/utils/tfc_tdf.py
def forward(self, x: Tensor) -> Tensor:
    for block in self.blocks:
        x = block(x)
    return x

attend_sage

Classes:

Name Description
AttendSage

Attributes:

Name Type Description
logger
logger module-attribute
logger = getLogger(__name__)
AttendSage
AttendSage(
    dropout: float = 0.0,
    flash: bool = False,
    scale: float | None = None,
)

Bases: Module

Parameters:

Name Type Description Default
flash bool

if True, attempts to use SageAttention or PyTorch SDPA.

False

Methods:

Name Description
forward

einstein notation

Attributes:

Name Type Description
scale
dropout
use_sage
use_pytorch_sdpa
attn_dropout
Source code in src/splifft/models/utils/attend_sage.py
def __init__(
    self,
    dropout: float = 0.0,
    flash: bool = False,
    scale: float | None = None,
):
    """
    :param flash: if True, attempts to use SageAttention or PyTorch SDPA.
    """
    super().__init__()
    self.scale = scale  # for einsum path
    self.dropout = dropout  # for einsum/SDPA path

    self.use_sage = flash and _has_sage_attention
    self.use_pytorch_sdpa = False
    self._sdpa_checked = False

    if flash and not self.use_sage:
        if not self._sdpa_checked:
            if parse_version(torch.__version__) >= (2, 0, 0):
                self.use_pytorch_sdpa = True
                log_once(
                    logger,
                    "Using PyTorch SDPA backend (FlashAttention-2, Memory-Efficient, or Math).",
                )
            else:
                log_once(
                    logger,
                    "Flash attention requested but Pytorch < 2.0 and SageAttention not found. Falling back to einsum.",
                )
            self._sdpa_checked = True

    # dropout layer for manual einsum implementation ONLY
    # SDPA and SageAttention handle dropout differently
    # (or not at all in Sage's base API)
    self.attn_dropout = nn.Dropout(dropout)
scale instance-attribute
scale = scale
dropout instance-attribute
dropout = dropout
use_sage instance-attribute
use_sage = flash and _has_sage_attention
use_pytorch_sdpa instance-attribute
use_pytorch_sdpa = False
attn_dropout instance-attribute
attn_dropout = Dropout(dropout)
forward
forward(q: Tensor, k: Tensor, v: Tensor) -> Tensor

einstein notation

- b: batch
- h: heads
- n, i, j: sequence length (base sequence length, source, target)
- d: feature dimension

Input tensors q, k, v expected in shape: (batch, heads, seq_len, dim_head) -> HND layout

Source code in src/splifft/models/utils/attend_sage.py
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
    """
    einstein notation

    - b: batch
    - h: heads
    - n, i, j: sequence length (base sequence length, source, target)
    - d: feature dimension

    Input tensors q, k, v expected in shape: (batch, heads, seq_len, dim_head) -> HND layout
    """
    _q_len, _k_len, _device = q.shape[-2], k.shape[-2], q.device

    # priority 1: SageAttention
    if self.use_sage:
        # assumes q, k, v are FP16/BF16 (handled by autocast upstream)
        # assumes scale is handled internally by sageattn
        # assumes dropout is NOT handled by sageattn kernel
        # is_causal=False based on how Attend is called in mel_band_roformer
        try:
            out = sageattn(q, k, v, tensor_layout="HND", is_causal=False)
            return out
        except Exception as e:
            logger.error(f"SageAttention failed with error: {e}. Falling back.")
            self.use_sage = False
            if not self._sdpa_checked:
                if parse_version(torch.__version__) >= (2, 0, 0):
                    self.use_pytorch_sdpa = True
                    log_once(logger, "falling back to PyTorch SDPA")
                else:
                    log_once(logger, "falling back to einsum.")

                self._sdpa_checked = True

    # priority 2: PyTorch SDPA
    if self.use_pytorch_sdpa:
        # it handles scaling and dropout internally.
        try:
            with sdpa_kernel(
                [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
            ):
                out = F.scaled_dot_product_attention(
                    q,
                    k,
                    v,
                    attn_mask=None,  # assuming no explicit mask needed here
                    dropout_p=self.dropout if self.training else 0.0,
                    is_causal=False,  # assuming not needed based on usage context
                )
            return out
        except Exception as e:
            log_once(
                logger,
                f"pytorch SDPA failed with error: {e}. falling back to einsum.",
                level=logging.ERROR,
            )
            self.use_pytorch_sdpa = False

    scale = self.scale or q.shape[-1] ** -0.5

    # similarity
    sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale

    # attention
    attn = sim.softmax(dim=-1)
    attn = self.attn_dropout(attn)  # ONLY in einsum path

    # aggregate values
    out = einsum("b h i j, b h j d -> b h i d", attn, v)

    return out

attend

Classes:

Name Description
Attend

Attributes:

Name Type Description
logger
logger module-attribute
logger = getLogger(__name__)
Attend
Attend(
    dropout: float = 0.0,
    flash: bool = False,
    scale: float | None = None,
)

Bases: Module

Methods:

Name Description
flash_attn
forward

einstein notation

Attributes:

Name Type Description
scale
dropout
attn_dropout
flash
cpu_backends
cuda_backends list[_SDPBackend] | None
Source code in src/splifft/models/utils/attend.py
def __init__(
    self, dropout: float = 0.0, flash: bool = False, scale: float | None = None
) -> None:
    super().__init__()
    self.scale = scale
    self.dropout = dropout
    self.attn_dropout = nn.Dropout(dropout)

    self.flash = flash
    assert not (flash and parse_version(torch.__version__) < (2, 0, 0)), (
        "expected pytorch >= 2.0.0 to use flash attention"
    )

    # determine efficient attention configs for cuda and cpu
    self.cpu_backends = [
        SDPBackend.FLASH_ATTENTION,
        SDPBackend.EFFICIENT_ATTENTION,
        SDPBackend.MATH,
    ]
    self.cuda_backends: list[_SDPBackend] | None = None

    if not torch.cuda.is_available() or not flash:
        return

    device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
    device_version = parse_version(f"{device_properties.major}.{device_properties.minor}")

    if device_version >= (8, 0):
        if os.name == "nt":
            cuda_backends = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
            log_once(logger, f"windows detected, using {cuda_backends=}")
        else:
            cuda_backends = [SDPBackend.FLASH_ATTENTION, SDPBackend.MATH]
            log_once(logger, f"gpu compute capability >= 8.0, using {cuda_backends=}")
    else:
        cuda_backends = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
        log_once(logger, f"gpu compute capability < 8.0, using {cuda_backends=}")

    self.cuda_backends = cuda_backends
scale instance-attribute
scale = scale
dropout instance-attribute
dropout = dropout
attn_dropout instance-attribute
attn_dropout = Dropout(dropout)
flash instance-attribute
flash = flash
cpu_backends instance-attribute
cpu_backends = [FLASH_ATTENTION, EFFICIENT_ATTENTION, MATH]
cuda_backends instance-attribute
cuda_backends: list[_SDPBackend] | None = cuda_backends
flash_attn
flash_attn(q: Tensor, k: Tensor, v: Tensor) -> Tensor
Source code in src/splifft/models/utils/attend.py
def flash_attn(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
    _, _heads, _q_len, _, _k_len, is_cuda, _device = (
        *q.shape,
        k.shape[-2],
        q.is_cuda,
        q.device,
    )  # type: ignore

    if self.scale is not None:
        default_scale = q.shape[-1] ** -0.5
        q = q * (self.scale / default_scale)

    backends = self.cuda_backends if is_cuda else self.cpu_backends
    # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
    with sdpa_kernel(backends=backends):  # type: ignore
        out = F.scaled_dot_product_attention(
            q, k, v, dropout_p=self.dropout if self.training else 0.0
        )

    return out
forward
forward(q: Tensor, k: Tensor, v: Tensor) -> Tensor

einstein notation

- b: batch
- h: heads
- n, i, j: sequence length (base sequence length, source, target)
- d: feature dimension
Source code in src/splifft/models/utils/attend.py
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
    """
    einstein notation

    - b: batch
    - h: heads
    - n, i, j: sequence length (base sequence length, source, target)
    - d: feature dimension
    """
    _q_len, _k_len, _device = q.shape[-2], k.shape[-2], q.device

    scale = self.scale or q.shape[-1] ** -0.5

    if self.flash:
        return self.flash_attn(q, k, v)

    # similarity
    sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale

    # attention
    attn = sim.softmax(dim=-1)
    attn = self.attn_dropout(attn)

    # aggregate values
    out = einsum("b h i j, b h j d -> b h i d", attn, v)

    return out
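The einsum fallback is plain scaled dot-product attention. A NumPy transcription of those three steps (numerically stable softmax added; shapes hypothetical):

```python
import numpy as np

b, h, n, d = 1, 2, 3, 4
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((b, h, n, d)) for _ in range(3))

scale = d ** -0.5
sim = np.einsum("bhid,bhjd->bhij", q, k) * scale   # similarity
attn = np.exp(sim - sim.max(axis=-1, keepdims=True))
attn /= attn.sum(axis=-1, keepdims=True)           # softmax over j
out = np.einsum("bhij,bhjd->bhid", attn, v)        # aggregate values

assert out.shape == (b, h, n, d)
assert np.allclose(attn.sum(axis=-1), 1.0)
```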

stft

Classes:

Name Description
Stft

A custom STFT implementation using 1D convolutions to ensure compatibility with CoreML.

IStft

A simple wrapper around torch.istft with a hacky workaround for MPS.

Stft
Stft(
    n_fft: int,
    hop_length: int,
    win_length: int,
    window_fn: Callable[[int], Tensor],
    conv_dtype: dtype | None,
)

Bases: Module

A custom STFT implementation using 1D convolutions to ensure compatibility with CoreML.

Methods:

Name Description
forward

Attributes:

Name Type Description
n_fft
hop_length
win_length
conv_dtype
Source code in src/splifft/models/utils/stft.py
def __init__(
    self,
    n_fft: int,
    hop_length: int,
    win_length: int,
    window_fn: Callable[[int], Tensor],
    conv_dtype: torch.dtype | None,
):
    super().__init__()
    self.n_fft = n_fft
    self.hop_length = hop_length
    self.win_length = win_length
    self.conv_dtype = conv_dtype

    window = window_fn(self.win_length)

    dft_mat = torch.fft.fft(torch.eye(self.n_fft, device=window.device))
    dft_mat_T = dft_mat.T

    real_kernels = dft_mat_T.real[
        : self.win_length, : (self.n_fft // 2 + 1)
    ] * window.unsqueeze(-1)
    imag_kernels = dft_mat_T.imag[
        : self.win_length, : (self.n_fft // 2 + 1)
    ] * window.unsqueeze(-1)

    # (out_channels, in_channels, kernel_size)
    self.register_buffer("real_conv_weight", real_kernels.T.unsqueeze(1).to(self.conv_dtype))
    self.register_buffer("imag_conv_weight", imag_kernels.T.unsqueeze(1).to(self.conv_dtype))
n_fft instance-attribute
n_fft = n_fft
hop_length instance-attribute
hop_length = hop_length
win_length instance-attribute
win_length = win_length
conv_dtype instance-attribute
conv_dtype = conv_dtype
forward
forward(x: Tensor) -> ComplexSpectrogram
Source code in src/splifft/models/utils/stft.py
def forward(self, x: Tensor) -> t.ComplexSpectrogram:
    b, s, t = x.shape
    x = x.reshape(b * s, 1, t).to(self.conv_dtype)

    padding = self.n_fft // 2
    x = F.pad(x, (padding, padding), "reflect")

    real_part = F.conv1d(x, self.real_conv_weight, stride=self.hop_length)  # type: ignore
    imag_part = F.conv1d(x, self.imag_conv_weight, stride=self.hop_length)  # type: ignore
    spec = torch.stack((real_part, imag_part), dim=-1)  # (b*s, f, t_frames, c=2)

    _bs, f, t_frames, c = spec.shape
    spec = spec.view(b, s, f, t_frames, c)

    return spec  # type: ignore
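The conv weights are just the rows of a windowed DFT matrix, so a single frame convolved with them must reproduce `np.fft.rfft` of the windowed frame. A NumPy check of that identity (using `np.hanning` as a stand-in window; `torch.hann_window` is periodic while `np.hanning` is symmetric, which does not affect the identity itself):

```python
import numpy as np

n_fft = 8
window = np.hanning(n_fft)  # stand-in for window_fn(win_length)
frame = np.random.default_rng(0).standard_normal(n_fft)

dft_mat = np.fft.fft(np.eye(n_fft))
kernels = dft_mat.T[:, : n_fft // 2 + 1] * window[:, None]  # (win_length, n_bins)

real = frame @ kernels.real  # what F.conv1d with real_conv_weight computes per hop
imag = frame @ kernels.imag

ref = np.fft.rfft(frame * window)
assert np.allclose(real, ref.real)
assert np.allclose(imag, ref.imag)
```

This is why the layer stays CoreML-exportable: the whole transform is two `conv1d` calls with fixed weights, no complex dtypes.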
IStft
IStft(
    n_fft: int,
    hop_length: int,
    win_length: int,
    window_fn: Callable[[int], Tensor] = hann_window,
)

Bases: Module

A simple wrapper around torch.istft with a hacky workaround for MPS.

TODO: implement a proper workaround.

Methods:

Name Description
forward

Attributes:

Name Type Description
n_fft
hop_length
win_length
window
Source code in src/splifft/models/utils/stft.py
def __init__(
    self,
    n_fft: int,
    hop_length: int,
    win_length: int,
    window_fn: Callable[[int], Tensor] = torch.hann_window,
):
    super().__init__()
    self.n_fft = n_fft
    self.hop_length = hop_length
    self.win_length = win_length
    self.window = window_fn(self.win_length)
n_fft instance-attribute
n_fft = n_fft
hop_length instance-attribute
hop_length = hop_length
win_length instance-attribute
win_length = win_length
window instance-attribute
window = window_fn(win_length)
forward
forward(
    spec: ComplexSpectrogram, length: int | None = None
) -> RawAudioTensor | NormalizedAudioTensor
Source code in src/splifft/models/utils/stft.py
def forward(
    self, spec: t.ComplexSpectrogram, length: int | None = None
) -> t.RawAudioTensor | t.NormalizedAudioTensor:
    device = spec.device
    is_mps = device.type == "mps"
    window = self.window.to(device)
    # see https://github.com/lucidrains/BS-RoFormer/issues/47
    # this would introduce a breaking change.
    # spec = spec.index_fill(1, torch.tensor(0, device=spec.device), 0.)  # type: ignore
    spec_complex = torch.view_as_complex(spec)

    try:
        audio = torch.istft(
            spec_complex,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=window,
            return_complex=False,
            length=length,
        )
    except RuntimeError:
        audio = torch.istft(
            spec_complex.cpu() if is_mps else spec_complex,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=window.cpu() if is_mps else window,
            return_complex=False,
            length=length,
        ).to(device)

    return audio  # type: ignore
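A minimal round trip with plain torch.stft/torch.istft showing why the `length` argument matters: without it, the inverse transform can return slightly fewer samples than the original signal. Sizes here are illustrative:

```python
import torch

x = torch.randn(2, 4096)  # (channels, samples)
window = torch.hann_window(1024)

spec = torch.stft(x, n_fft=1024, hop_length=256, window=window, return_complex=True)
# length=4096 trims/pads the overlap-add result back to the input size
y = torch.istft(spec, n_fft=1024, hop_length=256, window=window, length=4096)

assert y.shape == x.shape
assert torch.allclose(x, y, atol=1e-4)  # COLA-satisfying window -> near-exact
```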

hyperace

HyperACE segmentation backbones for BS-RoFormer mask heads.

These modules are compatibility shims for unwa variants trained in msst (hyperace_v1, hyperace_v2, and large_inst_v2 head behavior). They are kept separate from the core transformer stack because they are used only by a small subset of checkpoints.

See: https://huggingface.co/pcunwa/BS-Roformer-HyperACE and https://arxiv.org/abs/2506.17733

Classes:

Name Description
Conv
DSConv
DS_Bottleneck
DS_C3k
DS_C3k2
AdaptiveHyperedgeGeneration
HypergraphConvolution
AdaptiveHypergraphComputation
C3AH
HyperACE
GatedFusion
BackboneHyperAceV1
BackboneHyperAceV2
DecoderHyperAce
FreqPixelShuffleV1
ProgressiveUpsampleHeadV1
FreqPixelShuffleV2
ProgressiveUpsampleHeadV2
SegmModelHyperAceV1
SegmModelHyperAceV2

Functions:

Name Description
autopad
build_hyperace_tfc_tdf
autopad
autopad(
    k: int | tuple[int, int],
    p: int | tuple[int, int] | None = None,
) -> int | tuple[int, int]
Source code in src/splifft/models/utils/hyperace.py
def autopad(
    k: int | tuple[int, int],
    p: int | tuple[int, int] | None = None,
) -> int | tuple[int, int]:
    if p is None:
        p = k // 2 if isinstance(k, int) else (k[0] // 2, k[1] // 2)
    return p
build_hyperace_tfc_tdf
build_hyperace_tfc_tdf(
    in_c: int, c: int, l: int, f: int, bn: int = 4
) -> TfcTdf
Source code in src/splifft/models/utils/hyperace.py
def build_hyperace_tfc_tdf(
    in_c: int,
    c: int,
    l: int,
    f: int,
    bn: int = 4,
) -> TfcTdf:
    return TfcTdf(
        in_channels=in_c,
        out_channels=c,
        num_blocks=l,
        f_bins=f,
        bottleneck_factor=bn,
        norm_factory=instance_norm_factory,
        act_factory=silu_factory,
    )
Conv
Conv(
    c1: int,
    c2: int,
    k: int | tuple[int, int] = 1,
    s: int | tuple[int, int] = 1,
    p: int | tuple[int, int] | None = None,
    g: int = 1,
    act: bool = True,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
conv
bn
act
Source code in src/splifft/models/utils/hyperace.py
def __init__(
    self,
    c1: int,
    c2: int,
    k: int | tuple[int, int] = 1,
    s: int | tuple[int, int] = 1,
    p: int | tuple[int, int] | None = None,
    g: int = 1,
    act: bool = True,
):
    super().__init__()
    self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
    self.bn = nn.InstanceNorm2d(c2, affine=True, eps=1e-8)
    self.act = nn.SiLU() if act else nn.Identity()
conv instance-attribute
conv = Conv2d(
    c1, c2, k, s, autopad(k, p), groups=g, bias=False
)
bn instance-attribute
bn = InstanceNorm2d(c2, affine=True, eps=1e-08)
act instance-attribute
act = SiLU() if act else Identity()
forward
forward(x: Tensor) -> Any
Source code in src/splifft/models/utils/hyperace.py
def forward(self, x: Tensor) -> Any:
    return self.act(self.bn(self.conv(x)))
DSConv
DSConv(
    c1: int,
    c2: int,
    k: int | tuple[int, int] = 3,
    s: int | tuple[int, int] = 1,
    p: int | tuple[int, int] | None = None,
    act: bool = True,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
dwconv
pwconv
bn
act
Source code in src/splifft/models/utils/hyperace.py
def __init__(
    self,
    c1: int,
    c2: int,
    k: int | tuple[int, int] = 3,
    s: int | tuple[int, int] = 1,
    p: int | tuple[int, int] | None = None,
    act: bool = True,
):
    super().__init__()
    self.dwconv = nn.Conv2d(c1, c1, k, s, autopad(k, p), groups=c1, bias=False)
    self.pwconv = nn.Conv2d(c1, c2, 1, 1, 0, bias=False)
    self.bn = nn.InstanceNorm2d(c2, affine=True, eps=1e-8)
    self.act = nn.SiLU() if act else nn.Identity()
dwconv instance-attribute
dwconv = Conv2d(
    c1, c1, k, s, autopad(k, p), groups=c1, bias=False
)
pwconv instance-attribute
pwconv = Conv2d(c1, c2, 1, 1, 0, bias=False)
bn instance-attribute
bn = InstanceNorm2d(c2, affine=True, eps=1e-08)
act instance-attribute
act = SiLU() if act else Identity()
forward
forward(x: Tensor) -> Any
Source code in src/splifft/models/utils/hyperace.py
def forward(self, x: Tensor) -> Any:
    return self.act(self.bn(self.pwconv(self.dwconv(x))))
DS_Bottleneck
DS_Bottleneck(
    c1: int, c2: int, k: int = 3, shortcut: bool = True
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
dsconv1
dsconv2
shortcut
Source code in src/splifft/models/utils/hyperace.py
def __init__(self, c1: int, c2: int, k: int = 3, shortcut: bool = True):
    super().__init__()
    c_ = c1
    self.dsconv1 = DSConv(c1, c_, k=3, s=1)
    self.dsconv2 = DSConv(c_, c2, k=k, s=1)
    self.shortcut = shortcut and c1 == c2
dsconv1 instance-attribute
dsconv1 = DSConv(c1, c_, k=3, s=1)
dsconv2 instance-attribute
dsconv2 = DSConv(c_, c2, k=k, s=1)
shortcut instance-attribute
shortcut = shortcut and c1 == c2
forward
forward(x: Tensor) -> Any
Source code in src/splifft/models/utils/hyperace.py
def forward(self, x: Tensor) -> Any:
    if self.shortcut:
        return x + self.dsconv2(self.dsconv1(x))
    return self.dsconv2(self.dsconv1(x))
DS_C3k
DS_C3k(
    c1: int, c2: int, n: int = 1, k: int = 3, e: float = 0.5
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
cv1
cv2
cv3
m
Source code in src/splifft/models/utils/hyperace.py
def __init__(self, c1: int, c2: int, n: int = 1, k: int = 3, e: float = 0.5):
    super().__init__()
    c_ = int(c2 * e)
    self.cv1 = Conv(c1, c_, 1, 1)
    self.cv2 = Conv(c1, c_, 1, 1)
    self.cv3 = Conv(2 * c_, c2, 1, 1)
    self.m = nn.Sequential(*[DS_Bottleneck(c_, c_, k=k, shortcut=True) for _ in range(n)])
cv1 instance-attribute
cv1 = Conv(c1, c_, 1, 1)
cv2 instance-attribute
cv2 = Conv(c1, c_, 1, 1)
cv3 instance-attribute
cv3 = Conv(2 * c_, c2, 1, 1)
m instance-attribute
m = Sequential(
    *[
        (DS_Bottleneck(c_, c_, k=k, shortcut=True))
        for _ in (range(n))
    ]
)
forward
forward(x: Tensor) -> Any
Source code in src/splifft/models/utils/hyperace.py
def forward(self, x: Tensor) -> Any:
    return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
DS_C3k2
DS_C3k2(
    c1: int, c2: int, n: int = 1, k: int = 3, e: float = 0.5
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
cv1
m
cv2
Source code in src/splifft/models/utils/hyperace.py
def __init__(self, c1: int, c2: int, n: int = 1, k: int = 3, e: float = 0.5):
    super().__init__()
    c_ = int(c2 * e)
    self.cv1 = Conv(c1, c_, 1, 1)
    self.m = DS_C3k(c_, c_, n=n, k=k, e=1.0)
    self.cv2 = Conv(c_, c2, 1, 1)
cv1 instance-attribute
cv1 = Conv(c1, c_, 1, 1)
m instance-attribute
m = DS_C3k(c_, c_, n=n, k=k, e=1.0)
cv2 instance-attribute
cv2 = Conv(c_, c2, 1, 1)
forward
forward(x: Tensor) -> Any
Source code in src/splifft/models/utils/hyperace.py
def forward(self, x: Tensor) -> Any:
    x_ = self.cv1(x)
    x_ = self.m(x_)
    return self.cv2(x_)
AdaptiveHyperedgeGeneration
AdaptiveHyperedgeGeneration(
    in_channels: int,
    num_hyperedges: int,
    num_heads: int = 8,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
num_hyperedges
num_heads
head_dim
global_proto
context_mapper
query_proj
scale
Source code in src/splifft/models/utils/hyperace.py
def __init__(self, in_channels: int, num_hyperedges: int, num_heads: int = 8):
    super().__init__()
    self.num_hyperedges = num_hyperedges
    self.num_heads = num_heads
    self.head_dim = in_channels // num_heads

    self.global_proto = nn.Parameter(torch.randn(num_hyperedges, in_channels))

    self.context_mapper = nn.Linear(2 * in_channels, num_hyperedges * in_channels, bias=False)

    self.query_proj = nn.Linear(in_channels, in_channels, bias=False)

    self.scale = self.head_dim**-0.5
num_hyperedges instance-attribute
num_hyperedges = num_hyperedges
num_heads instance-attribute
num_heads = num_heads
head_dim instance-attribute
head_dim = in_channels // num_heads
global_proto instance-attribute
global_proto = Parameter(randn(num_hyperedges, in_channels))
context_mapper instance-attribute
context_mapper = Linear(
    2 * in_channels,
    num_hyperedges * in_channels,
    bias=False,
)
query_proj instance-attribute
query_proj = Linear(in_channels, in_channels, bias=False)
scale instance-attribute
scale = head_dim ** -0.5
forward
forward(x: Tensor) -> Tensor
Source code in src/splifft/models/utils/hyperace.py
def forward(self, x: Tensor) -> Tensor:
    b, n, c = x.shape

    f_avg = F.adaptive_avg_pool1d(x.permute(0, 2, 1), 1).squeeze(-1)
    f_max = F.adaptive_max_pool1d(x.permute(0, 2, 1), 1).squeeze(-1)
    f_ctx = torch.cat((f_avg, f_max), dim=1)

    delta_p = self.context_mapper(f_ctx).view(b, self.num_hyperedges, c)
    p = self.global_proto.unsqueeze(0) + delta_p

    z = self.query_proj(x)

    z = z.view(b, n, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

    p = p.view(b, self.num_hyperedges, self.num_heads, self.head_dim).permute(0, 2, 3, 1)

    sim = (z @ p) * self.scale

    s_bar = sim.mean(dim=1)

    a = F.softmax(s_bar.permute(0, 2, 1), dim=-1)

    return a
HypergraphConvolution
HypergraphConvolution(in_channels: int, out_channels: int)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
W_e
W_v
act
Source code in src/splifft/models/utils/hyperace.py
def __init__(self, in_channels: int, out_channels: int):
    super().__init__()
    self.W_e = nn.Linear(in_channels, in_channels, bias=False)
    self.W_v = nn.Linear(in_channels, out_channels, bias=False)
    self.act = nn.SiLU()
W_e instance-attribute
W_e = Linear(in_channels, in_channels, bias=False)
W_v instance-attribute
W_v = Linear(in_channels, out_channels, bias=False)
act instance-attribute
act = SiLU()
forward
forward(x: Tensor, a: Tensor) -> Any
Source code in src/splifft/models/utils/hyperace.py
def forward(self, x: Tensor, a: Tensor) -> Any:
    f_m = torch.bmm(a, x)
    f_m = self.act(self.W_e(f_m))

    x_out = torch.bmm(a.transpose(1, 2), f_m)
    x_out = self.act(self.W_v(x_out))

    return x + x_out
AdaptiveHypergraphComputation
AdaptiveHypergraphComputation(
    in_channels: int,
    out_channels: int,
    num_hyperedges: int = 8,
    num_heads: int = 8,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
adaptive_hyperedge_gen
hypergraph_conv
Source code in src/splifft/models/utils/hyperace.py
def __init__(
    self,
    in_channels: int,
    out_channels: int,
    num_hyperedges: int = 8,
    num_heads: int = 8,
):
    super().__init__()
    self.adaptive_hyperedge_gen = AdaptiveHyperedgeGeneration(
        in_channels, num_hyperedges, num_heads
    )
    self.hypergraph_conv = HypergraphConvolution(in_channels, out_channels)
adaptive_hyperedge_gen instance-attribute
adaptive_hyperedge_gen = AdaptiveHyperedgeGeneration(
    in_channels, num_hyperedges, num_heads
)
hypergraph_conv instance-attribute
hypergraph_conv = HypergraphConvolution(
    in_channels, out_channels
)
forward
forward(x: Tensor) -> Any
Source code in src/splifft/models/utils/hyperace.py
def forward(self, x: Tensor) -> Any:
    b, _, h, w = x.shape
    x_flat = x.flatten(2).permute(0, 2, 1)

    a = self.adaptive_hyperedge_gen(x_flat)

    x_out_flat = self.hypergraph_conv(x_flat, a)

    x_out = x_out_flat.permute(0, 2, 1).view(b, -1, h, w)
    return x_out
C3AH
C3AH(
    c1: int,
    c2: int,
    num_hyperedges: int = 8,
    num_heads: int = 8,
    e: float = 0.5,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
cv1
cv2
ahc
cv3
Source code in src/splifft/models/utils/hyperace.py
def __init__(
    self,
    c1: int,
    c2: int,
    num_hyperedges: int = 8,
    num_heads: int = 8,
    e: float = 0.5,
):
    super().__init__()
    c_ = int(c1 * e)
    self.cv1 = Conv(c1, c_, 1, 1)
    self.cv2 = Conv(c1, c_, 1, 1)
    self.ahc = AdaptiveHypergraphComputation(c_, c_, num_hyperedges, num_heads)
    self.cv3 = Conv(2 * c_, c2, 1, 1)
cv1 instance-attribute
cv1 = Conv(c1, c_, 1, 1)
cv2 instance-attribute
cv2 = Conv(c1, c_, 1, 1)
ahc instance-attribute
ahc = AdaptiveHypergraphComputation(
    c_, c_, num_hyperedges, num_heads
)
cv3 instance-attribute
cv3 = Conv(2 * c_, c2, 1, 1)
forward
forward(x: Tensor) -> Any
Source code in src/splifft/models/utils/hyperace.py
def forward(self, x: Tensor) -> Any:
    x_lateral = self.cv1(x)
    x_ahc = self.ahc(self.cv2(x))
    return self.cv3(torch.cat((x_ahc, x_lateral), dim=1))
HyperACE
HyperACE(
    in_channels: list[int],
    out_channels: int,
    num_hyperedges: int = 8,
    num_heads: int = 8,
    k: int = 2,
    l: int = 1,
    c_h: float = 0.5,
    c_l: float = 0.25,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
fuse_conv
c_h
c_l
c_s
high_order_branch
high_order_fuse
low_order_branch
final_fuse
Source code in src/splifft/models/utils/hyperace.py
def __init__(
    self,
    in_channels: list[int],
    out_channels: int,
    num_hyperedges: int = 8,
    num_heads: int = 8,
    k: int = 2,
    l: int = 1,
    c_h: float = 0.5,
    c_l: float = 0.25,
):
    super().__init__()

    c2, c3, c4, c5 = in_channels
    c_mid = c4

    self.fuse_conv = Conv(c2 + c3 + c4 + c5, c_mid, 1, 1)

    self.c_h = int(c_mid * c_h)
    self.c_l = int(c_mid * c_l)
    self.c_s = c_mid - self.c_h - self.c_l
    assert self.c_s > 0, "Channel split error"

    self.high_order_branch = nn.ModuleList(
        [C3AH(self.c_h, self.c_h, num_hyperedges, num_heads, e=1.0) for _ in range(k)]
    )
    self.high_order_fuse = Conv(self.c_h * k, self.c_h, 1, 1)

    self.low_order_branch = nn.Sequential(
        *[DS_C3k(self.c_l, self.c_l, n=1, k=3, e=1.0) for _ in range(l)]
    )

    self.final_fuse = Conv(self.c_h + self.c_l + self.c_s, out_channels, 1, 1)
fuse_conv instance-attribute
fuse_conv = Conv(c2 + c3 + c4 + c5, c_mid, 1, 1)
c_h instance-attribute
c_h = int(c_mid * c_h)
c_l instance-attribute
c_l = int(c_mid * c_l)
c_s instance-attribute
c_s = c_mid - c_h - c_l
high_order_branch instance-attribute
high_order_branch = ModuleList(
    [
        (C3AH(c_h, c_h, num_hyperedges, num_heads, e=1.0))
        for _ in (range(k))
    ]
)
high_order_fuse instance-attribute
high_order_fuse = Conv(c_h * k, c_h, 1, 1)
low_order_branch instance-attribute
low_order_branch = Sequential(
    *[
        (DS_C3k(c_l, c_l, n=1, k=3, e=1.0))
        for _ in (range(l))
    ]
)
final_fuse instance-attribute
final_fuse = Conv(c_h + c_l + c_s, out_channels, 1, 1)
forward
forward(x: list[Tensor]) -> Any
Source code in src/splifft/models/utils/hyperace.py
def forward(self, x: list[Tensor]) -> Any:
    b2, b3, b4, b5 = x

    _, _, h4, w4 = b4.shape

    b2_resized = F.interpolate(b2, size=(h4, w4), mode="bilinear", align_corners=False)
    b3_resized = F.interpolate(b3, size=(h4, w4), mode="bilinear", align_corners=False)
    b5_resized = F.interpolate(b5, size=(h4, w4), mode="bilinear", align_corners=False)

    x_b = self.fuse_conv(torch.cat((b2_resized, b3_resized, b4, b5_resized), dim=1))

    x_h, x_l, x_s = torch.split(x_b, [self.c_h, self.c_l, self.c_s], dim=1)

    x_h_outs = [m(x_h) for m in self.high_order_branch]
    x_h_fused = self.high_order_fuse(torch.cat(x_h_outs, dim=1))

    x_l_out = self.low_order_branch(x_l)

    y = self.final_fuse(torch.cat((x_h_fused, x_l_out, x_s), dim=1))

    return y
GatedFusion
GatedFusion(in_channels: int)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
gamma
Source code in src/splifft/models/utils/hyperace.py
def __init__(self, in_channels: int):
    super().__init__()
    self.gamma = nn.Parameter(torch.zeros(1, in_channels, 1, 1))
gamma instance-attribute
gamma = Parameter(zeros(1, in_channels, 1, 1))
forward
forward(f_in: Tensor, h: Tensor) -> Any
Source code in src/splifft/models/utils/hyperace.py
def forward(self, f_in: Tensor, h: Tensor) -> Any:
    if f_in.shape[1] != h.shape[1]:
        raise ValueError(f"Channel mismatch: f_in={f_in.shape}, h={h.shape}")
    return f_in + self.gamma * h
BackboneHyperAceV1
BackboneHyperAceV1(
    in_channels: int = 256,
    base_channels: int = 64,
    base_depth: int = 3,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
stem
p2
p3
p4
p5
out_channels
Source code in src/splifft/models/utils/hyperace.py
def __init__(self, in_channels: int = 256, base_channels: int = 64, base_depth: int = 3):
    super().__init__()
    c2 = base_channels
    c3 = 256
    c4 = 384
    c5 = 512
    c6 = 768

    self.stem = DSConv(in_channels, c2, k=3, s=(2, 1), p=1)

    self.p2 = nn.Sequential(
        DSConv(c2, c3, k=3, s=(2, 1), p=1),
        DS_C3k2(c3, c3, n=base_depth),
    )

    self.p3 = nn.Sequential(
        DSConv(c3, c4, k=3, s=(2, 1), p=1),
        DS_C3k2(c4, c4, n=base_depth * 2),
    )

    self.p4 = nn.Sequential(
        DSConv(c4, c5, k=3, s=(2, 1), p=1),
        DS_C3k2(c5, c5, n=base_depth * 2),
    )

    self.p5 = nn.Sequential(
        DSConv(c5, c6, k=3, s=(2, 1), p=1),
        DS_C3k2(c6, c6, n=base_depth),
    )

    self.out_channels = [c3, c4, c5, c6]
stem instance-attribute
stem = DSConv(in_channels, c2, k=3, s=(2, 1), p=1)
p2 instance-attribute
p2 = Sequential(
    DSConv(c2, c3, k=3, s=(2, 1), p=1),
    DS_C3k2(c3, c3, n=base_depth),
)
p3 instance-attribute
p3 = Sequential(
    DSConv(c3, c4, k=3, s=(2, 1), p=1),
    DS_C3k2(c4, c4, n=base_depth * 2),
)
p4 instance-attribute
p4 = Sequential(
    DSConv(c4, c5, k=3, s=(2, 1), p=1),
    DS_C3k2(c5, c5, n=base_depth * 2),
)
p5 instance-attribute
p5 = Sequential(
    DSConv(c5, c6, k=3, s=(2, 1), p=1),
    DS_C3k2(c6, c6, n=base_depth),
)
out_channels instance-attribute
out_channels = [c3, c4, c5, c6]
forward
forward(x: Tensor) -> list[Tensor]
Source code in src/splifft/models/utils/hyperace.py
def forward(self, x: Tensor) -> list[Tensor]:
    x = self.stem(x)
    x2 = self.p2(x)
    x3 = self.p3(x2)
    x4 = self.p4(x3)
    x5 = self.p5(x4)
    return [x2, x3, x4, x5]
BackboneHyperAceV2
BackboneHyperAceV2(
    in_channels: int = 256,
    base_channels: int = 64,
    base_depth: int = 3,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
stem
p2
p3
p4
p5
out_channels
Source code in src/splifft/models/utils/hyperace.py
def __init__(self, in_channels: int = 256, base_channels: int = 64, base_depth: int = 3):
    super().__init__()
    c2 = base_channels
    c3 = 256
    c4 = 384
    c5 = 512
    c6 = 768

    self.stem = DSConv(in_channels, c2, k=3, s=(2, 1), p=1)

    self.p2 = nn.Sequential(
        DSConv(c2, c3, k=3, s=(2, 1), p=1),
        DS_C3k2(c3, c3, n=base_depth),
    )

    self.p3 = nn.Sequential(
        DSConv(c3, c4, k=3, s=(2, 1), p=1),
        DS_C3k2(c4, c4, n=base_depth * 2),
    )

    self.p4 = nn.Sequential(
        DSConv(c4, c5, k=3, s=2, p=1),
        DS_C3k2(c5, c5, n=base_depth * 2),
    )

    self.p5 = nn.Sequential(
        DSConv(c5, c6, k=3, s=2, p=1),
        DS_C3k2(c6, c6, n=base_depth),
    )

    self.out_channels = [c3, c4, c5, c6]
stem instance-attribute
stem = DSConv(in_channels, c2, k=3, s=(2, 1), p=1)
p2 instance-attribute
p2 = Sequential(
    DSConv(c2, c3, k=3, s=(2, 1), p=1),
    DS_C3k2(c3, c3, n=base_depth),
)
p3 instance-attribute
p3 = Sequential(
    DSConv(c3, c4, k=3, s=(2, 1), p=1),
    DS_C3k2(c4, c4, n=base_depth * 2),
)
p4 instance-attribute
p4 = Sequential(
    DSConv(c4, c5, k=3, s=2, p=1),
    DS_C3k2(c5, c5, n=base_depth * 2),
)
p5 instance-attribute
p5 = Sequential(
    DSConv(c5, c6, k=3, s=2, p=1),
    DS_C3k2(c6, c6, n=base_depth),
)
out_channels instance-attribute
out_channels = [c3, c4, c5, c6]
forward
forward(x: Tensor) -> list[Tensor]
Source code in src/splifft/models/utils/hyperace.py
def forward(self, x: Tensor) -> list[Tensor]:
    x = self.stem(x)
    x2 = self.p2(x)
    x3 = self.p3(x2)
    x4 = self.p4(x3)
    x5 = self.p5(x4)
    return [x2, x3, x4, x5]
DecoderHyperAce
DecoderHyperAce(
    encoder_channels: list[int],
    hyperace_out_c: int,
    decoder_channels: list[int],
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
h_to_d5
h_to_d4
h_to_d3
h_to_d2
fusion_d5
fusion_d4
fusion_d3
fusion_d2
skip_p5
skip_p4
skip_p3
skip_p2
up_d5
up_d4
up_d3
final_d2
Source code in src/splifft/models/utils/hyperace.py
def __init__(
    self, encoder_channels: list[int], hyperace_out_c: int, decoder_channels: list[int]
):
    super().__init__()
    c_p2, c_p3, c_p4, c_p5 = encoder_channels
    c_d2, c_d3, c_d4, c_d5 = decoder_channels

    self.h_to_d5 = Conv(hyperace_out_c, c_d5, 1, 1)
    self.h_to_d4 = Conv(hyperace_out_c, c_d4, 1, 1)
    self.h_to_d3 = Conv(hyperace_out_c, c_d3, 1, 1)
    self.h_to_d2 = Conv(hyperace_out_c, c_d2, 1, 1)

    self.fusion_d5 = GatedFusion(c_d5)
    self.fusion_d4 = GatedFusion(c_d4)
    self.fusion_d3 = GatedFusion(c_d3)
    self.fusion_d2 = GatedFusion(c_d2)

    self.skip_p5 = Conv(c_p5, c_d5, 1, 1)
    self.skip_p4 = Conv(c_p4, c_d4, 1, 1)
    self.skip_p3 = Conv(c_p3, c_d3, 1, 1)
    self.skip_p2 = Conv(c_p2, c_d2, 1, 1)

    self.up_d5 = DS_C3k2(c_d5, c_d4, n=1)
    self.up_d4 = DS_C3k2(c_d4, c_d3, n=1)
    self.up_d3 = DS_C3k2(c_d3, c_d2, n=1)

    self.final_d2 = DS_C3k2(c_d2, c_d2, n=1)
h_to_d5 instance-attribute
h_to_d5 = Conv(hyperace_out_c, c_d5, 1, 1)
h_to_d4 instance-attribute
h_to_d4 = Conv(hyperace_out_c, c_d4, 1, 1)
h_to_d3 instance-attribute
h_to_d3 = Conv(hyperace_out_c, c_d3, 1, 1)
h_to_d2 instance-attribute
h_to_d2 = Conv(hyperace_out_c, c_d2, 1, 1)
fusion_d5 instance-attribute
fusion_d5 = GatedFusion(c_d5)
fusion_d4 instance-attribute
fusion_d4 = GatedFusion(c_d4)
fusion_d3 instance-attribute
fusion_d3 = GatedFusion(c_d3)
fusion_d2 instance-attribute
fusion_d2 = GatedFusion(c_d2)
skip_p5 instance-attribute
skip_p5 = Conv(c_p5, c_d5, 1, 1)
skip_p4 instance-attribute
skip_p4 = Conv(c_p4, c_d4, 1, 1)
skip_p3 instance-attribute
skip_p3 = Conv(c_p3, c_d3, 1, 1)
skip_p2 instance-attribute
skip_p2 = Conv(c_p2, c_d2, 1, 1)
up_d5 instance-attribute
up_d5 = DS_C3k2(c_d5, c_d4, n=1)
up_d4 instance-attribute
up_d4 = DS_C3k2(c_d4, c_d3, n=1)
up_d3 instance-attribute
up_d3 = DS_C3k2(c_d3, c_d2, n=1)
final_d2 instance-attribute
final_d2 = DS_C3k2(c_d2, c_d2, n=1)
forward
forward(enc_feats: list[Tensor], h_ace: Tensor) -> Any
Source code in src/splifft/models/utils/hyperace.py
def forward(self, enc_feats: list[Tensor], h_ace: Tensor) -> Any:
    p2, p3, p4, p5 = enc_feats

    d5 = self.skip_p5(p5)
    h_d5 = self.h_to_d5(F.interpolate(h_ace, size=d5.shape[2:], mode="bilinear"))
    d5 = self.fusion_d5(d5, h_d5)

    d5_up = F.interpolate(d5, size=p4.shape[2:], mode="bilinear")
    d4_skip = self.skip_p4(p4)
    d4 = self.up_d5(d5_up) + d4_skip

    h_d4 = self.h_to_d4(F.interpolate(h_ace, size=d4.shape[2:], mode="bilinear"))
    d4 = self.fusion_d4(d4, h_d4)

    d4_up = F.interpolate(d4, size=p3.shape[2:], mode="bilinear")
    d3_skip = self.skip_p3(p3)
    d3 = self.up_d4(d4_up) + d3_skip

    h_d3 = self.h_to_d3(F.interpolate(h_ace, size=d3.shape[2:], mode="bilinear"))
    d3 = self.fusion_d3(d3, h_d3)

    d3_up = F.interpolate(d3, size=p2.shape[2:], mode="bilinear")
    d2_skip = self.skip_p2(p2)
    d2 = self.up_d3(d3_up) + d2_skip

    h_d2 = self.h_to_d2(F.interpolate(h_ace, size=d2.shape[2:], mode="bilinear"))
    d2 = self.fusion_d2(d2, h_d2)

    d2_final = self.final_d2(d2)

    return d2_final
FreqPixelShuffleV1
FreqPixelShuffleV1(
    in_channels: int, out_channels: int, scale: int = 2
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
scale
conv
Source code in src/splifft/models/utils/hyperace.py
def __init__(self, in_channels: int, out_channels: int, scale: int = 2):
    super().__init__()
    self.scale = scale
    self.conv = DSConv(in_channels, out_channels * scale, k=3, s=1, p=1)
scale instance-attribute
scale = scale
conv instance-attribute
conv = DSConv(
    in_channels, out_channels * scale, k=3, s=1, p=1
)
forward
forward(x: Tensor) -> Any
Source code in src/splifft/models/utils/hyperace.py
def forward(self, x: Tensor) -> Any:
    x = self.conv(x)
    b, c_r, h, w = x.shape
    out_c = c_r // self.scale

    x = x.view(b, out_c, self.scale, h, w)

    x = x.permute(0, 1, 3, 4, 2).contiguous()
    x = x.view(b, out_c, h, w * self.scale)

    return x
ProgressiveUpsampleHeadV1
ProgressiveUpsampleHeadV1(
    in_channels: int,
    out_channels: int,
    target_bins: int = 1025,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
target_bins
block1
block2
block3
block4
final_conv
Source code in src/splifft/models/utils/hyperace.py
def __init__(self, in_channels: int, out_channels: int, target_bins: int = 1025):
    super().__init__()
    self.target_bins = target_bins

    c = in_channels

    self.block1 = FreqPixelShuffleV1(c, c, scale=2)
    self.block2 = FreqPixelShuffleV1(c, c // 2, scale=2)
    self.block3 = FreqPixelShuffleV1(c // 2, c // 2, scale=2)
    self.block4 = FreqPixelShuffleV1(c // 2, c // 4, scale=2)

    self.final_conv = nn.Conv2d(c // 4, out_channels, kernel_size=1, bias=False)
target_bins instance-attribute
target_bins = target_bins
block1 instance-attribute
block1 = FreqPixelShuffleV1(c, c, scale=2)
block2 instance-attribute
block2 = FreqPixelShuffleV1(c, c // 2, scale=2)
block3 instance-attribute
block3 = FreqPixelShuffleV1(c // 2, c // 2, scale=2)
block4 instance-attribute
block4 = FreqPixelShuffleV1(c // 2, c // 4, scale=2)
final_conv instance-attribute
final_conv = Conv2d(
    c // 4, out_channels, kernel_size=1, bias=False
)
forward
forward(x: Tensor) -> Any
Source code in src/splifft/models/utils/hyperace.py
def forward(self, x: Tensor) -> Any:
    x = self.block1(x)
    x = self.block2(x)
    x = self.block3(x)
    x = self.block4(x)

    if x.shape[-1] != self.target_bins:
        x = F.interpolate(
            x,
            size=(x.shape[2], self.target_bins),
            mode="bilinear",
            align_corners=False,
        )

    x = self.final_conv(x)
    return x
FreqPixelShuffleV2
FreqPixelShuffleV2(
    in_channels: int, out_channels: int, scale: int, f: int
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
scale
conv
out_conv
Source code in src/splifft/models/utils/hyperace.py
def __init__(self, in_channels: int, out_channels: int, scale: int, f: int):
    super().__init__()
    self.scale = scale
    self.conv = DSConv(in_channels, out_channels * scale)
    self.out_conv = build_hyperace_tfc_tdf(out_channels, out_channels, 2, f)
scale instance-attribute
scale = scale
conv instance-attribute
conv = DSConv(in_channels, out_channels * scale)
out_conv instance-attribute
out_conv = build_hyperace_tfc_tdf(
    out_channels, out_channels, 2, f
)
forward
forward(x: Tensor) -> Any
Source code in src/splifft/models/utils/hyperace.py
def forward(self, x: Tensor) -> Any:
    x = self.conv(x)
    b, c_r, h, w = x.shape
    out_c = c_r // self.scale

    x = x.view(b, out_c, self.scale, h, w)

    x = x.permute(0, 1, 3, 4, 2).contiguous()
    x = x.view(b, out_c, h, w * self.scale)

    return self.out_conv(x)
ProgressiveUpsampleHeadV2
ProgressiveUpsampleHeadV2(
    in_channels: int,
    out_channels: int,
    target_bins: int = 1025,
    in_bands: int = 62,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
target_bins
block1
block2
block3
block4
final_conv
Source code in src/splifft/models/utils/hyperace.py
def __init__(
    self, in_channels: int, out_channels: int, target_bins: int = 1025, in_bands: int = 62
):
    super().__init__()
    self.target_bins = target_bins

    c = in_channels

    self.block1 = FreqPixelShuffleV2(c, c // 2, scale=2, f=in_bands * 2)
    self.block2 = FreqPixelShuffleV2(c // 2, c // 4, scale=2, f=in_bands * 4)
    self.block3 = FreqPixelShuffleV2(c // 4, c // 8, scale=2, f=in_bands * 8)
    self.block4 = FreqPixelShuffleV2(c // 8, c // 16, scale=2, f=in_bands * 16)

    self.final_conv = nn.Conv2d(
        c // 16,
        out_channels,
        kernel_size=3,
        stride=1,
        padding="same",
        bias=False,
    )
target_bins instance-attribute
target_bins = target_bins
block1 instance-attribute
block1 = FreqPixelShuffleV2(
    c, c // 2, scale=2, f=in_bands * 2
)
block2 instance-attribute
block2 = FreqPixelShuffleV2(
    c // 2, c // 4, scale=2, f=in_bands * 4
)
block3 instance-attribute
block3 = FreqPixelShuffleV2(
    c // 4, c // 8, scale=2, f=in_bands * 8
)
block4 instance-attribute
block4 = FreqPixelShuffleV2(
    c // 8, c // 16, scale=2, f=in_bands * 16
)
final_conv instance-attribute
final_conv = Conv2d(
    c // 16,
    out_channels,
    kernel_size=3,
    stride=1,
    padding="same",
    bias=False,
)
forward
forward(x: Tensor) -> Any
Source code in src/splifft/models/utils/hyperace.py
def forward(self, x: Tensor) -> Any:
    x = self.block1(x)
    x = self.block2(x)
    x = self.block3(x)
    x = self.block4(x)

    if x.shape[-1] != self.target_bins:
        x = F.interpolate(
            x,
            size=(x.shape[2], self.target_bins),
            mode="bilinear",
            align_corners=False,
        )

    x = self.final_conv(x)
    return x
SegmModelHyperAceV1
SegmModelHyperAceV1(
    in_bands: int = 62,
    in_dim: int = 256,
    out_bins: int = 1025,
    out_channels: int = 4,
    base_channels: int = 64,
    base_depth: int = 2,
    num_hyperedges: int = 16,
    num_heads: int = 8,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
backbone
hyperace
decoder
upsample_head
Source code in src/splifft/models/utils/hyperace.py
def __init__(
    self,
    in_bands: int = 62,
    in_dim: int = 256,
    out_bins: int = 1025,
    out_channels: int = 4,
    base_channels: int = 64,
    base_depth: int = 2,
    num_hyperedges: int = 16,
    num_heads: int = 8,
):
    super().__init__()

    self.backbone = BackboneHyperAceV1(
        in_channels=in_dim,
        base_channels=base_channels,
        base_depth=base_depth,
    )
    enc_channels = self.backbone.out_channels
    c2, c3, c4, c5 = enc_channels

    hyperace_in_channels = enc_channels
    hyperace_out_channels = c4
    self.hyperace = HyperACE(
        hyperace_in_channels,
        hyperace_out_channels,
        num_hyperedges,
        num_heads,
        k=3,
        l=2,
    )

    decoder_channels = [c2, c3, c4, c5]
    self.decoder = DecoderHyperAce(enc_channels, hyperace_out_channels, decoder_channels)

    self.upsample_head = ProgressiveUpsampleHeadV1(
        in_channels=decoder_channels[0],
        out_channels=out_channels,
        target_bins=out_bins,
    )
backbone instance-attribute
backbone = BackboneHyperAceV1(
    in_channels=in_dim,
    base_channels=base_channels,
    base_depth=base_depth,
)
hyperace instance-attribute
hyperace = HyperACE(
    hyperace_in_channels,
    hyperace_out_channels,
    num_hyperedges,
    num_heads,
    k=3,
    l=2,
)
decoder instance-attribute
decoder = DecoderHyperAce(
    enc_channels, hyperace_out_channels, decoder_channels
)
upsample_head instance-attribute
upsample_head = ProgressiveUpsampleHeadV1(
    in_channels=decoder_channels[0],
    out_channels=out_channels,
    target_bins=out_bins,
)
forward
forward(x: Tensor) -> Any
Source code in src/splifft/models/utils/hyperace.py
def forward(self, x: Tensor) -> Any:
    h, _ = x.shape[2:]

    enc_feats = self.backbone(x)

    h_ace_feats = self.hyperace(enc_feats)

    dec_feat = self.decoder(enc_feats, h_ace_feats)

    feat_time_restored = F.interpolate(
        dec_feat,
        size=(h, dec_feat.shape[-1]),
        mode="bilinear",
        align_corners=False,
    )

    out = self.upsample_head(feat_time_restored)

    return out
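Putting the pieces together, the tensor shapes through `SegmModelHyperAceV1.forward` can be sketched in plain Python (a hypothetical helper for illustration, not library API): the decoder output is interpolated back to the input's time length `h`, and the upsample head maps it to `out_channels` streams of `out_bins` frequency bins.

```python
# Hypothetical shape helper (not part of splifft): for an input of shape
# (batch, in_dim, time, bands), forward() returns
# (batch, out_channels, time, out_bins), since the decoder features are
# restored to the input time length and the head emits out_bins.
def expected_output_shape(
    input_shape: tuple[int, int, int, int],
    out_channels: int = 4,
    out_bins: int = 1025,
) -> tuple[int, int, int, int]:
    batch, _in_dim, time, _bands = input_shape
    return (batch, out_channels, time, out_bins)

print(expected_output_shape((1, 256, 100, 62)))  # (1, 4, 100, 1025)
```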
SegmModelHyperAceV2
SegmModelHyperAceV2(
    in_bands: int = 62,
    in_dim: int = 256,
    out_bins: int = 1025,
    out_channels: int = 4,
    base_channels: int = 64,
    base_depth: int = 2,
    num_hyperedges: int = 32,
    num_heads: int = 8,
)

Bases: Module

Methods:

Name Description
forward

Attributes:

Name Type Description
backbone
hyperace
decoder
upsample_head
Source code in src/splifft/models/utils/hyperace.py
def __init__(
    self,
    in_bands: int = 62,
    in_dim: int = 256,
    out_bins: int = 1025,
    out_channels: int = 4,
    base_channels: int = 64,
    base_depth: int = 2,
    num_hyperedges: int = 32,
    num_heads: int = 8,
):
    super().__init__()

    self.backbone = BackboneHyperAceV2(
        in_channels=in_dim,
        base_channels=base_channels,
        base_depth=base_depth,
    )
    enc_channels = self.backbone.out_channels
    c2, c3, c4, c5 = enc_channels

    hyperace_in_channels = enc_channels
    hyperace_out_channels = c4
    self.hyperace = HyperACE(
        hyperace_in_channels,
        hyperace_out_channels,
        num_hyperedges,
        num_heads,
        k=2,
        l=1,
    )

    decoder_channels = [c2, c3, c4, c5]
    self.decoder = DecoderHyperAce(enc_channels, hyperace_out_channels, decoder_channels)

    self.upsample_head = ProgressiveUpsampleHeadV2(
        in_channels=decoder_channels[0],
        out_channels=out_channels,
        target_bins=out_bins,
        in_bands=in_bands,
    )
backbone instance-attribute
backbone = BackboneHyperAceV2(
    in_channels=in_dim,
    base_channels=base_channels,
    base_depth=base_depth,
)
hyperace instance-attribute
hyperace = HyperACE(
    hyperace_in_channels,
    hyperace_out_channels,
    num_hyperedges,
    num_heads,
    k=2,
    l=1,
)
decoder instance-attribute
decoder = DecoderHyperAce(
    enc_channels, hyperace_out_channels, decoder_channels
)
upsample_head instance-attribute
upsample_head = ProgressiveUpsampleHeadV2(
    in_channels=decoder_channels[0],
    out_channels=out_channels,
    target_bins=out_bins,
    in_bands=in_bands,
)
forward
forward(x: Tensor) -> Any
Source code in src/splifft/models/utils/hyperace.py
def forward(self, x: Tensor) -> Any:
    h, _ = x.shape[2:]

    enc_feats = self.backbone(x)

    h_ace_feats = self.hyperace(enc_feats)

    dec_feat = self.decoder(enc_feats, h_ace_feats)

    feat_time_restored = F.interpolate(
        dec_feat,
        size=(h, dec_feat.shape[-1]),
        mode="bilinear",
        align_corners=False,
    )

    out = self.upsample_head(feat_time_restored)

    return out
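For quick reference, the hyperparameters where the two variants diverge can be summarized as plain dicts (a summary sketch read off the signatures and `__init__` bodies above, not library code); V2's head additionally receives `in_bands`, which V1's does not.

```python
# Where SegmModelHyperAceV1 and SegmModelHyperAceV2 differ, per the
# definitions above (k and l are fixed inside __init__, not arguments).
v1 = {"num_hyperedges": 16, "k": 3, "l": 2, "head": "ProgressiveUpsampleHeadV1"}
v2 = {"num_hyperedges": 32, "k": 2, "l": 1, "head": "ProgressiveUpsampleHeadV2"}

diff = {key for key in v1 if v1[key] != v2[key]}
print(sorted(diff))  # every listed key differs between the variants
```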