SpliFFT

Lightweight utilities for music source separation.

This library is a ground-up rewrite of zfturbo's MSST repo, with a strong focus on robustness, simplicity and extensibility. While the original is a fantastic collection of models and training scripts, this rewrite adopts a different architecture to address common pain points in research code.

Key principles:

  • Configuration as code: pydantic models are used instead of untyped dictionaries or ConfigDict. This provides static type safety, runtime data validation, IDE autocompletion, and a single, clear source of truth for all parameters.
  • Data-oriented and functional core: complex class hierarchies and inheritance are avoided. The codebase is built on plain data structures (like dataclasses) and pure, stateless functions.
  • Semantic typing as documentation: we leverage Python's type system to convey intent. types like RawAudioTensor vs. NormalizedAudioTensor make function signatures self-documenting, reducing the need for verbose comments and ensuring correctness.
  • Extensibility without modification: new models can be integrated from external packages without altering the core library. The dynamic model loading system allows easy plug-and-play, adhering to the open/closed principle.
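The second and third principles can be illustrated with a minimal, standard-library-only sketch. The type names RawAudioTensor and NormalizedAudioTensor come from the list above, but their definitions here (NewType wrappers around plain lists of floats) are assumptions made for illustration, as are NormalizationStats, normalize and denormalize; the library itself presumably works on torch tensors.

```python
from __future__ import annotations

import statistics
from dataclasses import dataclass
from typing import NewType

# Assumed definitions: distinct names for the same underlying data, so a type
# checker can reject passing raw audio where normalized audio is expected.
RawAudioTensor = NewType("RawAudioTensor", list)
NormalizedAudioTensor = NewType("NormalizedAudioTensor", list)


@dataclass(frozen=True)
class NormalizationStats:
    """Plain, immutable data needed to undo a normalization (hypothetical)."""

    mean: float
    std: float


def normalize(audio: RawAudioTensor) -> tuple[NormalizedAudioTensor, NormalizationStats]:
    """Pure function: returns new data and never mutates its input."""
    mean = statistics.fmean(audio)
    std = statistics.pstdev(audio) or 1.0  # avoid dividing by zero on silence
    stats = NormalizationStats(mean=mean, std=std)
    return NormalizedAudioTensor([(x - mean) / std for x in audio]), stats


def denormalize(audio: NormalizedAudioTensor, stats: NormalizationStats) -> RawAudioTensor:
    """Inverse of `normalize`, using the stats it returned."""
    return RawAudioTensor([x * stats.std + stats.mean for x in audio])


raw = RawAudioTensor([0.0, 2.0, 4.0, 6.0])
normalized, stats = normalize(raw)
restored = denormalize(normalized, stats)
```

The signatures alone say which representation each function expects and returns, which is the point of semantic typing: a call such as `denormalize(raw, stats)` would be flagged by a static type checker even though it runs.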

⚠️ This is pre-alpha software, expect significant breaking changes.

Features and Roadmap

Short term (high priority)

  • a robust, typed JSON configuration system powered by pydantic
  • inferencing:
    • normalization and denormalization
    • chunk generation: vectorized with unfold
    • chunk stitching: vectorized overlap-add with fold
    • flexible ruleset for stem deriving: add/subtract model outputs or any intermediate output (e.g., creating an instrumental track by subtracting vocals from the mixture).
  • web-based docs: generated with mkdocs with excellent crossrefs.
  • simple CLI for inferencing on a directory of audio files
  • BS-Roformer: ensure bit-for-bit equivalence in pytorch and strive for max perf.
  • initial fp16 support
  • support coremltools and torch.compile
    • handroll complex multiplication implementation
    • isolate/handroll istft in forward pass
  • evals: SDR, bleedless, fullness, etc.
  • datasets: MUSDB18-HQ, moises
  • proper benchmarking (MFU, memory...)
  • port additional SOTA models from MSST (e.g. Mel Roformer, SCNet)
  • directly support popular models (e.g. by @unwa, gabox, by @becruily)
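The chunking, stitching and stem derivation items above can be sketched in plain Python. This is a loop-based stand-in written for illustration, not the library's implementation: the roadmap describes vectorized versions built on unfold and fold, and every name below (make_chunks, positive_hann, overlap_add, derive_stem) is hypothetical.

```python
from __future__ import annotations

import math


def make_chunks(signal: list[float], chunk_size: int, hop: int) -> list[list[float]]:
    """Zero-pad the signal, then slice it into overlapping fixed-size chunks."""
    n_chunks = max(1, math.ceil((len(signal) - chunk_size) / hop) + 1)
    padded_len = (n_chunks - 1) * hop + chunk_size
    padded = signal + [0.0] * (padded_len - len(signal))
    return [padded[i * hop : i * hop + chunk_size] for i in range(n_chunks)]


def positive_hann(n: int) -> list[float]:
    """Hann-shaped window sampled at half-integer offsets, so no weight is zero."""
    return [0.5 - 0.5 * math.cos(2 * math.pi * (i + 0.5) / n) for i in range(n)]


def overlap_add(chunks: list[list[float]], hop: int, length: int) -> list[float]:
    """Windowed overlap-add, normalized by the summed window weights."""
    chunk_size = len(chunks[0])
    window = positive_hann(chunk_size)
    total = (len(chunks) - 1) * hop + chunk_size
    out = [0.0] * total
    weight = [0.0] * total
    for i, chunk in enumerate(chunks):
        start = i * hop
        for j, (x, w) in enumerate(zip(chunk, window)):
            out[start + j] += x * w
            weight[start + j] += w
    # Trim the padding that make_chunks added.
    return [o / w for o, w in zip(out, weight)][:length]


def derive_stem(stems: dict[str, list[float]], terms: list[tuple[float, str]]) -> list[float]:
    """Combine named stems as a signed sum, e.g. mixture minus vocals."""
    length = len(stems[terms[0][1]])
    out = [0.0] * length
    for sign, name in terms:
        for i, x in enumerate(stems[name]):
            out[i] += sign * x
    return out


mixture = [math.sin(0.1 * i) for i in range(50)]
chunks = make_chunks(mixture, chunk_size=16, hop=8)
# Stand-in for a model: pretend the vocals are a quarter of every chunk.
vocal_chunks = [[0.25 * x for x in chunk] for chunk in chunks]
vocals = overlap_add(vocal_chunks, hop=8, length=len(mixture))
stems = {"mixture": mixture, "vocals": vocals}
instrumental = derive_stem(stems, [(1.0, "mixture"), (-1.0, "vocals")])
```

Dividing by the accumulated window weight makes the stitching exact for any hop that leaves no sample uncovered, so the choice of window only affects how chunk-boundary artifacts from a real model are blended.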

Long term (low priority)

  • model registry with simple file-based cache
  • data augmentation
  • implement a complete, configurable training loop
  • max kernels
  • simple web-based GUI with FastAPI and Svelte.

Contributing: PRs are very welcome!

Installation & Usage

Documentation on the config (amongst other details) can be found here

CLI

There are three steps. You do not need to have Python installed.

  1. Install uv if you haven't already. It is an awesome Python package and library manager with pip compatibility.

    # Linux / MacOS
    wget -qO- https://astral.sh/uv/install.sh | sh
    # Windows
    powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | iex"
    

  2. Open a new terminal and install the latest stable PyPI release as a tool. It will install the Python interpreter, all necessary packages and add the splifft executable to your PATH:

    uv tool install "splifft[config,inference,cli]"
    
    I want the latest bleeding-edge version

    This directly pulls from the main branch, which may be unstable:

    uv tool install "git+https://github.com/undef13/splifft.git[config,inference,cli]"
    

  3. Go into a new directory and place the model checkpoint and configuration inside it. Assuming your current directory has this structure (doesn't have to be exactly this):

    Grab an example audio from YouTube

    uv tool install yt-dlp
    yt-dlp -f bestaudio -o data/audio/input/3BFTio5296w.flac 3BFTio5296w
    

    .
    └── data
        ├── audio
        │   ├── input
        │   │   └── 3BFTio5296w.flac
        │   └── output
        ├── config
        │   └── bs_roformer.json
        └── models
            └── roformer-fp16.pt
    

    Run:

    splifft separate data/audio/input/3BFTio5296w.flac --config data/config/bs_roformer.json --checkpoint data/models/roformer-fp16.pt
    
    Console output

    [00:00:41] INFO     using device=device(type='cuda')                                                 __main__.py:117
               INFO     loading configuration from                                                       __main__.py:119
                        config_path=PosixPath('data/config/bs_roformer.json')                                           
               INFO     loading model metadata `BSRoformer` from module `splifft.models.bs_roformer`     __main__.py:122
    [00:00:42] INFO     loading weights from checkpoint_path=PosixPath('data/models/roformer-fp16.pt')   __main__.py:131
               INFO     processing audio file:                                                           __main__.py:138
                        mixture_path=PosixPath('data/audio/input/3BFTio5296w.flac')                                     
    [00:00:56] INFO     wrote stem `bass` to data/audio/output/3BFTio5296w/bass.flac                     __main__.py:168
               INFO     wrote stem `drums` to data/audio/output/3BFTio5296w/drums.flac                   __main__.py:168
               INFO     wrote stem `other` to data/audio/output/3BFTio5296w/other.flac                   __main__.py:168
    [00:00:57] INFO     wrote stem `vocals` to data/audio/output/3BFTio5296w/vocals.flac                 __main__.py:168
               INFO     wrote stem `guitar` to data/audio/output/3BFTio5296w/guitar.flac                 __main__.py:168
               INFO     wrote stem `piano` to data/audio/output/3BFTio5296w/piano.flac                   __main__.py:168
    [00:00:58] INFO     wrote stem `instrumental` to data/audio/output/3BFTio5296w/instrumental.flac     __main__.py:168
               INFO     wrote stem `drums_and_bass` to data/audio/output/3BFTio5296w/drums_and_bass.flac __main__.py:168
    

    To update the tool:

    uv tool upgrade splifft --force-reinstall
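The `splifft separate` invocation from step 3 handles one file at a time. Until a directory mode exists, a small wrapper script can repeat it for every file in a folder. The sketch below is an assumption-laden example rather than part of the tool: build_commands and run_all are hypothetical helpers, and only the subcommand and the two flags shown in step 3 are used.

```python
from __future__ import annotations

import subprocess
from pathlib import Path


def build_commands(
    input_dir: Path, config: Path, checkpoint: Path, suffix: str = ".flac"
) -> list[list[str]]:
    """Build one `splifft separate` command per matching file, in sorted order."""
    return [
        [
            "splifft",
            "separate",
            str(path),
            "--config",
            str(config),
            "--checkpoint",
            str(checkpoint),
        ]
        for path in sorted(input_dir.glob(f"*{suffix}"))
    ]


def run_all(commands: list[list[str]], dry_run: bool = True) -> None:
    """Print each command, or execute it and stop at the first failure."""
    for command in commands:
        if dry_run:
            print(" ".join(command))
        else:
            subprocess.run(command, check=True)
```

With the layout from step 3, `run_all(build_commands(Path("data/audio/input"), Path("data/config/bs_roformer.json"), Path("data/models/roformer-fp16.pt")))` prints the commands it would run; passing `dry_run=False` executes them. Note that each invocation reloads the checkpoint, so this is convenient rather than fast.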
    

Library

Add splifft to your project:

# latest pypi version
uv add splifft
# latest bleeding edge
uv add git+https://github.com/undef13/splifft.git

This installs only the minimal core dependencies used under the src/splifft/models directory. Higher-level components (e.g. inference, training or the CLI) must be installed via optional dependencies, as specified in the project.optional-dependencies section of pyproject.toml, for example:

# enable the built-in configuration, inference and CLI
uv add "splifft[config,inference,cli]"

This will install splifft in your venv.
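To see which optional dependency groups an installed distribution declares, the standard library's importlib.metadata can read them from the package metadata. This is a generic sketch, not a splifft API: declared_extras is a hypothetical helper, and it returns an empty list when the distribution is not installed in the current environment.

```python
from __future__ import annotations

from importlib import metadata


def declared_extras(distribution: str) -> list[str]:
    """Return the extras a distribution declares, or [] if it is not installed."""
    try:
        info = metadata.metadata(distribution)
    except metadata.PackageNotFoundError:
        return []
    # `Provides-Extra` appears once per extra; it is absent if there are none.
    return sorted(info.get_all("Provides-Extra") or [])


print(declared_extras("splifft"))
```

This lists what the package offers, not what was actually installed; whether a given extra's dependencies are present still has to be checked by importing them.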

Development

If you'd like to make local changes, it is recommended to enable all optional and developer group dependencies:

git clone https://github.com/undef13/splifft.git
cd splifft
uv venv
uv sync --all-extras --all-groups

You may also want to use --editable with sync. Check your code:

# lint
uv run ruff check src tests
# format
uv run ruff format --check src tests
# build & host documentation
uv run mkdocs serve
# type check
uv run mypy src tests

This repo is no longer compatible with zfturbo's repo. The last compatible version is v0.0.1. To pin a specific version in uv, change your pyproject.toml:

[tool.uv.sources]
splifft = { git = "https://github.com/undef13/splifft.git", rev = "287235e520f3bb927b58f9f53749fe3ccc248fac" }

Mojo

While the primary goal is simply to have a minimalist PyTorch-based inference engine, I will be using this project as an opportunity to learn more about heterogeneous computing, particularly with the Mojo language. The ultimate goal is to understand to what extent its compile-time metaprogramming and explicit memory layout control can be used.

My approach will be incremental and bottom-up: I'll develop, test and benchmark small components against their PyTorch counterparts. The PyTorch implementation will always remain the "source of truth" and the fully functional baseline, and it will never be removed.
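The test-and-benchmark-against-a-baseline workflow can be sketched in a few lines. Here both sides are plain Python stand-ins chosen for illustration: the baseline is Python's built-in complex multiplication (in the project it would be the PyTorch implementation) and the candidate is a hand-rolled version operating on (real, imaginary) pairs (in the project it would be a Mojo kernel). All function names are hypothetical.

```python
from __future__ import annotations

import random
import time
from typing import Callable

Pair = tuple[float, float]  # (real, imaginary)
MulFn = Callable[[Pair, Pair], Pair]


def complex_mul_reference(a: Pair, b: Pair) -> Pair:
    """Baseline: defer to the built-in complex type."""
    z = complex(*a) * complex(*b)
    return (z.real, z.imag)


def complex_mul_handrolled(a: Pair, b: Pair) -> Pair:
    """Candidate: (ar + i*ai)(br + i*bi) = (ar*br - ai*bi) + i*(ar*bi + ai*br)."""
    (ar, ai), (br, bi) = a, b
    return (ar * br - ai * bi, ar * bi + ai * br)


def max_abs_error(candidate: MulFn, reference: MulFn, cases: list[tuple[Pair, Pair]]) -> float:
    """Largest componentwise deviation of the candidate from the baseline."""
    return max(
        abs(c - r)
        for a, b in cases
        for c, r in zip(candidate(a, b), reference(a, b))
    )


def best_time(fn: MulFn, cases: list[tuple[Pair, Pair]], repeats: int = 5) -> float:
    """Best-of-N wall-clock time in seconds for one pass over all cases."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        for a, b in cases:
            fn(a, b)
        best = min(best, time.perf_counter() - start)
    return best


rng = random.Random(0)
cases = [
    ((rng.uniform(-1, 1), rng.uniform(-1, 1)), (rng.uniform(-1, 1), rng.uniform(-1, 1)))
    for _ in range(1000)
]
error = max_abs_error(complex_mul_handrolled, complex_mul_reference, cases)
timings = {
    "reference": best_time(complex_mul_reference, cases),
    "handrolled": best_time(complex_mul_handrolled, cases),
}
print(f"max abs error: {error:.3e}", timings)
```

Checking correctness before timing keeps the baseline in its role as the source of truth: a faster candidate is only interesting once its error against the reference is within tolerance.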

TODO:

  • evaluate pixi in pyproject.toml.
  • use max.torch.CustomOpLibrary to provide a callable from the pytorch side
  • use DeviceContext to interact with the GPU
  • attention
  • use LayoutTensor for QKV
  • rotary embedding
  • feedforward
  • transformer
  • BandSplit & MaskEstimator
  • full graph compilation