71 lines
2.0 KiB
Python
71 lines
2.0 KiB
Python
"""Tests for bar-phase (beat-in-bar) tracking in audio beat detection."""
|
|
|
|
import os
|
|
import sys
|
|
|
|
# Put the directory two levels above this file on sys.path (once) so that
# the ``tests.beat_detect`` package import below resolves regardless of the
# working directory pytest is launched from.
_THIS_FILE = os.path.abspath(__file__)
PROJECT_ROOT = os.path.dirname(os.path.dirname(_THIS_FILE))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
|
|
|
|
from tests.beat_detect import BarPhaseTracker # noqa: E402
|
|
|
|
|
|
def test_bar_phase_increments_on_non_kick_beats():
    """Successive non-kick beats report bar positions 1, 2, then 3."""
    tracker = BarPhaseTracker(beats_per_bar=4)
    # (timestamp, label, strength, expected bar_beat) -- fed in order.
    script = (
        (1.0, "snare", 1.3, 1),
        (1.5, "snare", 1.2, 2),
        (2.0, "hat", 1.1, 3),
    )
    for when, label, strength, expected_beat in script:
        outcome = tracker.on_beat(when, label, strength, bpm=120.0)
        assert outcome["bar_beat"] == expected_beat
|
|
|
|
|
|
def test_kick_near_bar_boundary_resets_to_downbeat():
    """A kick arriving one bar after the opening kick is beat 1 and a downbeat."""
    tracker = BarPhaseTracker(beats_per_bar=4)
    tracker.on_beat(0.0, "kick", 1.4, bpm=120.0)
    # Snares one beat apart at 120 BPM (assuming timestamps are seconds)
    # fill out the remainder of the first bar.
    for when in (0.5, 1.0, 1.5):
        tracker.on_beat(when, "snare", 1.2, bpm=120.0)
    # NOTE(review): at these timings this kick is also the 5th beat overall,
    # so plain beat counting would presumably wrap to 1 as well; this test
    # may not isolate the kick-driven reset path -- confirm intent.
    closing = tracker.on_beat(2.0, "kick", 1.5, bpm=120.0)
    assert closing["bar_beat"] == 1
    assert closing["is_downbeat"] is True
|
|
|
|
|
|
def test_anchor_downbeat_sets_confidence():
    """Anchoring a downbeat puts the tracker on beat 1 with confidence >= 0.85."""
    tracker = BarPhaseTracker(beats_per_bar=4)
    tracker.anchor_downbeat(10.0)

    assert tracker.bar_beat == 1
    assert tracker.confidence >= 0.85
|
|
|
|
|
|
def _custom_runtime_args():
    """Return the argument namespace this module's runtime test builds BeatDetectRuntime from.

    The values are the fixed settings the test has always used; they are not
    asserted to be the runtime's defaults.
    """
    from argparse import Namespace

    return Namespace(
        mode="custom",
        hop_size=256,
        win_mult=2,
        min_band_hz=45.0,
        max_band_hz=180.0,
        energy_weight=0.7,
        flux_weight=0.3,
        threshold_multiplier=1.35,
        ema_alpha=0.08,
        min_ioi_ms=100.0,
        bpm_window=8,
        aubio_method="default",
        aubio_threshold=0.12,
        beats_per_bar=4,
    )


def test_reset_tempo_preserves_bar_phase():
    """reset_tempo_state() keeps the bar position; reset_state() rewinds it to 1."""
    from tests.beat_detect import BeatDetectRuntime

    rt = BeatDetectRuntime(_custom_runtime_args())
    # NOTE(review): 44100 is presumably the sample rate in Hz -- confirm
    # against BeatDetectRuntime.setup.
    rt.setup(44100)

    # Kick followed by a snare leaves the tracker on beat 2 of the bar.
    rt.bar_phase.on_beat(0.0, "kick", 1.5, bpm=120.0)
    rt.bar_phase.on_beat(0.5, "snare", 1.2, bpm=120.0)
    assert rt.bar_phase.bar_beat == 2

    # A tempo-only reset must leave the bar position untouched...
    rt.reset_tempo_state()
    assert rt.bar_phase.bar_beat == 2

    # ...while a full reset returns the tracker to the downbeat.
    rt.reset_state()
    assert rt.bar_phase.bar_beat == 1
|