Files
led-controller/tests/beat_detect.py

546 lines
18 KiB
Python

#!/usr/bin/env python3
"""Live beat detection utility with custom/aubio/hybrid modes."""
from __future__ import annotations
import argparse
import collections
import queue
import sys
import time
from typing import Deque
# Hard dependencies: fail fast with an install hint instead of a traceback.
try:
    import numpy as np
except ImportError as exc:
    raise SystemExit(
        "Missing dependency: numpy. Install with `pip install numpy`."
    ) from exc
try:
    import sounddevice as sd
except ImportError as exc:
    raise SystemExit(
        "Missing dependency: sounddevice. Install with `pip install sounddevice`."
    ) from exc
# Optional dependency: only needed for --post-url. It is bound to None when
# missing so that main() can report the problem only if --post-url is used.
try:
    import requests
except ImportError:
    requests = None
def _device_arg(value: str) -> int | str:
    """Convert a ``--device`` value; plain digits select a device by index.

    sounddevice treats any ``str`` device as a name (substring) to match, so
    without this conversion ``--device 3`` would look for a device *named*
    "3" instead of selecting index 3 as the option's help text promises.
    Any other value is passed through unchanged as a device name.
    """
    text = value.strip()
    if text.isascii() and text.isdigit():
        return int(text)
    return value


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Build the beat-detector CLI parser and parse the command line.

    Args:
        argv: Arguments to parse; ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before this parameter existed.

    Returns:
        The parsed namespace. ``device`` is ``None``, an ``int`` index, or a
        ``str`` name; ``beats_per_bar`` defaults to 4.
    """
    parser = argparse.ArgumentParser(description="Beat detector utility")
    parser.add_argument(
        "--mode",
        choices=("custom", "aubio", "hybrid"),
        default="aubio",
        help="Detection mode",
    )
    parser.add_argument(
        "--device",
        type=_device_arg,
        default=None,
        help="Input device name or index",
    )
    parser.add_argument(
        "--sample-rate",
        type=int,
        default=0,
        help="Audio sample rate (0 = use selected device default)",
    )
    parser.add_argument("--hop-size", type=int, default=256, help="Frame hop size in samples")
    parser.add_argument("--win-mult", type=int, default=2, help="Aubio window size multiplier")
    parser.add_argument(
        "--min-band-hz",
        type=float,
        default=45.0,
        help="Low frequency bound used for beat energy",
    )
    parser.add_argument(
        "--max-band-hz",
        type=float,
        default=180.0,
        help="High frequency bound used for beat energy",
    )
    parser.add_argument(
        "--energy-weight",
        type=float,
        default=0.7,
        help="Weight for low-band energy component (0..1)",
    )
    parser.add_argument(
        "--flux-weight",
        type=float,
        default=0.3,
        help="Weight for spectral flux component (0..1)",
    )
    parser.add_argument(
        "--threshold-multiplier",
        type=float,
        default=1.35,
        help="Custom-mode threshold multiplier vs adaptive baseline",
    )
    parser.add_argument(
        "--ema-alpha",
        type=float,
        default=0.08,
        help="Adaptive baseline smoothing (higher reacts faster)",
    )
    parser.add_argument(
        "--min-ioi-ms",
        type=float,
        default=85.0,
        help="Minimum time between beats in milliseconds",
    )
    parser.add_argument(
        "--bpm-window",
        type=int,
        default=8,
        help="How many recent beat intervals to use for BPM estimate",
    )
    # The runtime reads args.beats_per_bar (falling back to 4); expose it so
    # non-4/4 material can be tracked from the CLI.
    parser.add_argument(
        "--beats-per-bar",
        type=int,
        default=4,
        help="Beats per bar used for bar-phase tracking",
    )
    parser.add_argument(
        "--post-url",
        default="",
        help="Optional HTTP URL to POST beat events",
    )
    parser.add_argument(
        "--aubio-method",
        default="default",
        choices=("default", "specdiff", "hfc", "complex", "phase", "energy"),
        help="Aubio tempo method",
    )
    parser.add_argument(
        "--aubio-threshold",
        type=float,
        default=0.12,
        help="Aubio detection threshold",
    )
    return parser.parse_args(argv)
def _estimate_bpm(beat_times: Deque[float]) -> float | None:
if len(beat_times) < 3:
return None
intervals = np.diff(np.array(beat_times, dtype=np.float64))
valid = intervals[(intervals > 0.2) & (intervals < 2.0)]
if valid.size == 0:
return None
return 60.0 / float(np.median(valid))
def _is_plausible_ioi(
last_trigger_s: float,
beat_times: Deque[float],
now_s: float,
*,
min_ratio: float = 0.42,
max_ratio: float = 2.5,
) -> bool:
"""Reject double-time / half-time false triggers vs recent median interval."""
if last_trigger_s <= 0 or len(beat_times) < 2:
return True
ioi = now_s - last_trigger_s
if ioi <= 0:
return False
intervals = np.diff(np.array(list(beat_times)[-8:], dtype=np.float64))
if intervals.size == 0:
return True
med = float(np.median(intervals))
if med < 0.05:
return True
return (ioi >= med * min_ratio) and (ioi <= med * max_ratio)
class BarPhaseTracker:
    """Track the beat position inside a bar, re-anchoring on kick downbeats.

    Each :meth:`on_beat` call advances the in-bar counter by one beat. A kick
    whose confidence is at least ``kick_conf_min`` resets the counter to
    beat 1 when any of these holds:

    - it is among the first two beats;
    - it lands near a whole number of bars after the last downbeat;
    - the previous beat was late in the bar.

    ``confidence`` is the ratio of anchored kicks to bars seen, capped at 1.0.
    """

    def __init__(self, beats_per_bar: int = 4, kick_conf_min: float = 1.15):
        # At least one beat per bar, so the modulo step in on_beat is defined.
        self.beats_per_bar = max(1, int(beats_per_bar))
        self.kick_conf_min = float(kick_conf_min)
        self.reset()

    def reset(self) -> None:
        """Forget all phase history and start again at beat 1."""
        self.bar_beat = 1
        self.is_downbeat = True
        self.confidence = 0.0
        self._last_downbeat_s = 0.0  # 0.0 means "no downbeat recorded yet"
        self._aligned_kicks = 0
        self._total_beats = 0

    def anchor_downbeat(self, now_s: float) -> None:
        """Mark ``now_s`` as a downbeat (manual anchor); confidence becomes >= 0.85."""
        self.bar_beat, self.is_downbeat = 1, True
        self._last_downbeat_s = float(now_s)
        self.confidence = max(self.confidence, 0.85)

    def _bar_duration_s(
        self, bpm: float | None, median_ioi: float | None
    ) -> float | None:
        """Return one bar's length in seconds, or None if no tempo info exists.

        A positive ``bpm`` takes precedence over a positive ``median_ioi``.
        """
        if bpm is not None and bpm > 0:
            beat_s = 60.0 / float(bpm)
        elif median_ioi is not None and median_ioi > 0:
            beat_s = float(median_ioi)
        else:
            return None
        return beat_s * self.beats_per_bar

    @staticmethod
    def _near_whole_bars(elapsed: float, bar_dur: float, tol: float = 0.14) -> bool:
        """Return True if ``elapsed`` is within ``tol`` bars of N >= 1 whole bars."""
        if bar_dur <= 0 or elapsed <= 0:
            return False
        bars = elapsed / bar_dur
        return abs(bars - max(1, round(bars))) <= tol

    def _kick_starts_bar(self, now_s: float, bar_dur: float | None) -> bool:
        """Decide whether a confident kick at ``now_s`` re-anchors beat 1.

        Called before ``bar_beat`` advances, so ``self.bar_beat`` still holds
        the previous beat's position.
        NOTE(review): the late-in-bar rule is applied only to kicks (it sits
        in the kick-only chain); confirm this is the intended behaviour.
        """
        if self._last_downbeat_s <= 0 or self._total_beats <= 2:
            return True
        if bar_dur and self._near_whole_bars(now_s - self._last_downbeat_s, bar_dur):
            return True
        return self.bar_beat >= max(2, self.beats_per_bar - 1)

    def on_beat(
        self,
        now_s: float,
        beat_type: str,
        beat_type_conf: float,
        *,
        bpm: float | None = None,
        median_ioi: float | None = None,
    ) -> dict[str, int | float | bool | str]:
        """Advance the bar position for a beat at ``now_s``.

        Args:
            now_s: Beat timestamp in seconds.
            beat_type: Hit label; only "kick" (case-insensitive) can anchor.
            beat_type_conf: Label confidence, compared to ``kick_conf_min``.
            bpm: Optional tempo used to size a bar.
            median_ioi: Optional median beat interval (seconds), used for the
                bar length when ``bpm`` is missing.

        Returns:
            A dict with ``bar_beat``, ``beats_per_bar``, ``is_downbeat``,
            ``phase_confidence`` (rounded to 3 places) and
            ``bar_phase_readout`` (e.g. "1/4").
        """
        self._total_beats += 1
        bar_dur = self._bar_duration_s(bpm, median_ioi)
        confident_kick = (
            str(beat_type or "").lower() == "kick"
            and float(beat_type_conf or 0.0) >= self.kick_conf_min
        )
        previous = int(self.bar_beat)

        if confident_kick and self._kick_starts_bar(now_s, bar_dur):
            self.bar_beat = 1
            self.is_downbeat = True
            self._last_downbeat_s = float(now_s)
            self._aligned_kicks += 1
        elif self._total_beats <= 1:
            # Very first beat without an anchoring kick: still call it beat 1.
            self.bar_beat = 1
            self.is_downbeat = True
        else:
            self.bar_beat = (previous % self.beats_per_bar) + 1
            self.is_downbeat = self.bar_beat == 1

        # Only score confidence once at least one full bar of beats was seen.
        if self._total_beats >= self.beats_per_bar:
            bars_seen = max(1, self._total_beats // self.beats_per_bar)
            self.confidence = min(1.0, self._aligned_kicks / bars_seen)

        return {
            "bar_beat": int(self.bar_beat),
            "beats_per_bar": int(self.beats_per_bar),
            "is_downbeat": bool(self.is_downbeat),
            "phase_confidence": round(float(self.confidence), 3),
            "bar_phase_readout": f"{int(self.bar_beat)}/{int(self.beats_per_bar)}",
        }
def _resolve_bpm(
    beat_times: Deque[float],
    aubio_bpm: float | None,
    *,
    prefer_aubio: bool = False,
    min_ratio: float = 0.57,
    max_ratio: float = 1.75,
) -> float | None:
    """Choose the tempo to report from the local estimate and aubio's reading.

    By default the interval-based estimate from ``_estimate_bpm`` wins
    whenever it exists. aubio's reading is used in two cases:

    - no local estimate exists yet;
    - ``prefer_aubio=True`` and aubio agrees with the estimate, i.e.
      ``aubio_bpm / estimate`` lies within ``[min_ratio, max_ratio]``.

    Outside that band aubio is presumed locked to double/half time, so the
    estimate is used instead.

    A non-positive ``aubio_bpm`` is treated as "no reading" and never
    returned.
    """
    aubio_valid = aubio_bpm if aubio_bpm is not None and aubio_bpm > 0 else None
    estimated = _estimate_bpm(beat_times)
    if estimated is None:
        return aubio_valid
    if aubio_valid is None or not prefer_aubio:
        return estimated
    ratio = float(aubio_valid) / estimated
    if min_ratio <= ratio <= max_ratio:
        return float(aubio_valid)
    return estimated
def _load_aubio_if_needed(mode: str):
if mode == "custom":
return None
try:
import aubio
return aubio
except ImportError:
dist_packages = "/usr/lib/python3/dist-packages"
if dist_packages not in sys.path:
sys.path.append(dist_packages)
try:
import aubio
return aubio
except ImportError:
raise SystemExit("aubio not installed; use --mode custom or install aubio")
class BeatDetectRuntime:
    """Reusable detector runtime so web and CLI can share logic.

    Usage:

    1. Construct it with CLI-style ``args``.
    2. Call :meth:`setup` with the stream sample rate.
    3. Pass each hop-sized mono frame to :meth:`process_frame`, which returns
       a beat-event dict or ``None``.

    Spectral features (low-band energy, spectral flux, hit classification)
    are computed over a sliding window of the last ``win_size`` samples.
    This is the same window length handed to aubio; a single hop is not
    enough. A hop-length FFT with the default 256-sample hop has bins about
    172 Hz (44.1 kHz) or 187.5 Hz (48 kHz) apart. Two problems follow:

    - The 40-140 Hz kick band contains no bin, so no hit is ever labelled
      "kick" and bar phase never re-anchors.
    - At 48 kHz the default 45-180 Hz beat band is empty too, so ``setup``
      would raise.
    """

    # Seconds without an accepted beat after which the interval history is
    # treated as stale and cleared. 2.0 s is also the longest interval that
    # _estimate_bpm accepts. last_trigger_s only advances on accepted beats,
    # so without this resync a pause longer than the plausibility window
    # (2.5x the median interval by default) would make _is_plausible_ioi
    # reject every later trigger forever.
    _RESYNC_GAP_S = 2.0

    def __init__(self, args):
        self.args = args
        self.aubio = _load_aubio_if_needed(args.mode)
        self.sample_rate = 0
        self.frame_size = 0  # hop size in samples (set in setup)
        self.win_size = 0  # analysis window in samples (set in setup)
        self.tempo = None
        self.band_mask = None
        self.freqs = None
        self.window = None
        self.prev_mag = None
        self.kick_mask = None
        self.snare_mask = None
        self.hat_mask = None
        # Rolling buffer holding the most recent win_size samples.
        self._analysis_buf = None
        self.baseline = 1e-6
        self.beat_times: Deque[float] = collections.deque(
            maxlen=max(2, args.bpm_window + 1)
        )
        self.last_trigger_s = 0.0
        self.debounce_s = float(args.min_ioi_ms) / 1000.0
        # beats_per_bar is optional on args; falls back to 4 when absent/0.
        bpb = int(getattr(args, "beats_per_bar", 4) or 4)
        self.bar_phase = BarPhaseTracker(beats_per_bar=bpb)

    def _window_size(self) -> int:
        """Analysis window length: at least 1024 samples and at least 2 hops."""
        return max(1024, self.frame_size * max(2, self.args.win_mult))

    def setup(self, sample_rate: int):
        """Prepare FFT masks, buffers and (optionally) aubio for ``sample_rate``.

        Raises:
            ValueError: no FFT bin falls inside
                ``[args.min_band_hz, args.max_band_hz]``. This is checked
                before any detector state is modified.
        """
        sample_rate = int(sample_rate)
        frame_size = max(128, int(self.args.hop_size))
        self.frame_size = frame_size
        win_size = self._window_size()
        freqs = np.fft.rfftfreq(win_size, d=1.0 / sample_rate)
        band_mask = (freqs >= self.args.min_band_hz) & (freqs <= self.args.max_band_hz)
        if not np.any(band_mask):
            raise ValueError("Invalid band range for current sample rate")

        self.sample_rate = sample_rate
        self.win_size = win_size
        self.freqs = freqs
        self.band_mask = band_mask
        self.kick_mask = (freqs >= 40.0) & (freqs <= 140.0)
        self.snare_mask = (freqs >= 140.0) & (freqs <= 3000.0)
        self.hat_mask = (freqs >= 5000.0) & (freqs <= 12000.0)
        self.window = np.hanning(win_size).astype(np.float32)
        self.prev_mag = np.zeros(freqs.shape[0], dtype=np.float32)
        self._analysis_buf = np.zeros(win_size, dtype=np.float32)
        self.baseline = 1e-6
        self.last_trigger_s = 0.0
        self.beat_times.clear()
        self.tempo = None
        if self.aubio is not None:
            self._init_aubio_tempo(win_size)

    def _init_aubio_tempo(self, win_size: int):
        """(Re)create the aubio tempo tracker with the CLI method/threshold/min-IOI."""
        self.tempo = self.aubio.tempo(
            self.args.aubio_method, win_size, self.frame_size, self.sample_rate
        )
        # Optional setters: only applied when this aubio build exposes them.
        if hasattr(self.tempo, "set_threshold"):
            self.tempo.set_threshold(float(self.args.aubio_threshold))
        if hasattr(self.tempo, "set_minioi_ms"):
            self.tempo.set_minioi_ms(float(self.args.min_ioi_ms))

    def reset_tempo_state(self) -> None:
        """Clear tempo/aubio history without losing bar phase."""
        self.baseline = 1e-6
        if self.prev_mag is not None:
            self.prev_mag[:] = 0.0
        if self._analysis_buf is not None:
            self._analysis_buf[:] = 0.0
        self.beat_times.clear()
        self.last_trigger_s = 0.0
        if self.aubio is not None and self.sample_rate > 0:
            self._init_aubio_tempo(self._window_size())

    def reset_state(self):
        """Full reset (manual): tempo history and bar phase."""
        self.reset_tempo_state()
        self.bar_phase.reset()

    def anchor_bar_phase(self, now_s: float | None = None) -> None:
        """Mark ``now_s`` (default: current wall-clock time) as a bar downbeat."""
        if now_s is None:
            now_s = time.time()
        self.bar_phase.anchor_downbeat(now_s)

    def _classify_hit(self, mag: np.ndarray):
        """Label a hit as "kick", "snare" or "hat" from band-to-spectrum ratios.

        Each score is the band's mean magnitude divided by the mean magnitude
        of the whole spectrum; empty bands score 0. Returns
        ``(label, score)``, or ``("unknown", score)`` when the best score is
        below 1.15.
        """
        total = float(np.mean(mag) + 1e-9)
        kick = float(np.mean(mag[self.kick_mask])) / total if np.any(self.kick_mask) else 0.0
        snare = float(np.mean(mag[self.snare_mask])) / total if np.any(self.snare_mask) else 0.0
        hat = float(np.mean(mag[self.hat_mask])) / total if np.any(self.hat_mask) else 0.0
        scores = {
            "kick": kick,
            "snare": snare,
            "hat": hat,
        }
        label, value = max(scores.items(), key=lambda kv: kv[1])
        if value < 1.15:
            return "unknown", value
        return label, value

    def _fit_frame(self, frame: np.ndarray) -> np.ndarray:
        """Return ``frame`` trimmed or zero-padded to ``frame_size``, as a float32 copy."""
        frame = np.asarray(frame)
        size = frame.shape[0]
        if size > self.frame_size:
            frame = frame[: self.frame_size]
        elif size < self.frame_size:
            frame = np.pad(frame, (0, self.frame_size - size))
        return frame.astype(np.float32)

    def process_frame(self, frame: np.ndarray, now_s: float | None = None):
        """Analyze one hop of mono audio and return a beat event or ``None``.

        Args:
            frame: 1-D samples; trimmed/zero-padded to ``frame_size``.
            now_s: Timestamp for this frame; defaults to ``time.time()``.

        Returns:
            ``None`` when no beat is accepted. Otherwise a dict with ``ts``,
            ``bpm``, ``src``, ``score``, ``threshold``, ``strength``,
            ``beat_type``, ``beat_type_confidence`` and ``db``, plus the
            bar-phase keys from ``BarPhaseTracker.on_beat``.

        Raises:
            RuntimeError: :meth:`setup` has not been called.
        """
        if self.window is None or self.band_mask is None or self._analysis_buf is None:
            raise RuntimeError("Runtime not setup")
        args = self.args
        f32 = self._fit_frame(frame)
        rms = float(np.sqrt(np.mean(f32 * f32) + 1e-12))
        db = 20.0 * np.log10(max(rms, 1e-12))

        # Slide the analysis window forward by one hop, then take its spectrum.
        self._analysis_buf = np.concatenate((self._analysis_buf[f32.shape[0]:], f32))
        mag = np.abs(np.fft.rfft(self._analysis_buf * self.window)).astype(np.float32)
        band_energy = float(np.mean(mag[self.band_mask]))
        flux = float(np.mean(np.maximum(0.0, mag - self.prev_mag)))
        self.prev_mag[:] = mag

        weight_sum = max(1e-6, args.energy_weight + args.flux_weight)
        score = ((args.energy_weight * band_energy) + (args.flux_weight * flux)) / weight_sum
        # The adaptive baseline is updated every frame, including debounced ones.
        self.baseline = ((1.0 - args.ema_alpha) * self.baseline) + (args.ema_alpha * score)
        threshold = self.baseline * args.threshold_multiplier
        custom_hit = score > threshold

        aubio_hit = False
        aubio_bpm = None
        if self.tempo is not None:
            # aubio is fed every hop, before the debounce early-return below.
            aubio_hit = bool(self.tempo(f32)[0])
            val = float(self.tempo.get_bpm())
            aubio_bpm = val if val > 0 else None

        if now_s is None:
            now_s = time.time()
        if (now_s - self.last_trigger_s) < self.debounce_s:
            return None

        if args.mode == "custom":
            should_trigger = custom_hit
        elif args.mode == "aubio":
            should_trigger = aubio_hit
        else:
            should_trigger = custom_hit or aubio_hit
        if not should_trigger:
            return None

        if self.last_trigger_s > 0 and (now_s - self.last_trigger_s) > self._RESYNC_GAP_S:
            # Long gap: the interval history is stale, so start it afresh
            # rather than rejecting this (and every later) trigger.
            self.beat_times.clear()
        if not _is_plausible_ioi(self.last_trigger_s, self.beat_times, now_s):
            return None

        self.last_trigger_s = now_s
        self.beat_times.append(now_s)
        bpm = _resolve_bpm(self.beat_times, aubio_bpm)
        strength = score / max(1e-9, self.baseline)
        beat_type, beat_type_conf = self._classify_hit(mag)

        median_ioi = None
        if len(self.beat_times) >= 2:
            intervals = np.diff(np.array(self.beat_times, dtype=np.float64))
            if intervals.size > 0:
                median_ioi = float(np.median(intervals))
        phase = self.bar_phase.on_beat(
            now_s,
            beat_type,
            beat_type_conf,
            bpm=bpm,
            median_ioi=median_ioi,
        )

        if args.mode in ("custom", "aubio"):
            src = args.mode
        elif custom_hit and aubio_hit:
            src = "both"
        elif custom_hit:
            src = "custom"
        else:
            src = "aubio"
        return {
            "ts": now_s,
            "bpm": bpm,
            "src": src,
            "score": score,
            "threshold": threshold,
            "strength": strength,
            "beat_type": beat_type,
            "beat_type_confidence": beat_type_conf,
            "db": db,
            **phase,
        }
def main() -> int:
    """Run the live beat detector from the command line.

    Opens an input stream on ``--device`` and feeds hop-sized mono blocks to
    ``BeatDetectRuntime``. Prints one line per detected beat and, when
    ``--post-url`` is set, POSTs each beat there.

    Returns:
        0 after Ctrl+C.

    Raises:
        SystemExit: ``--post-url`` is given without ``requests`` installed,
            or the detector cannot be set up for the chosen sample rate.
    """
    args = parse_args()
    runtime = BeatDetectRuntime(args)
    if args.post_url and requests is None:
        raise SystemExit("`requests` is required for --post-url (pip install requests)")
    if args.sample_rate > 0:
        sample_rate = args.sample_rate
    else:
        dev_info = sd.query_devices(args.device, "input")
        sample_rate = int(dev_info["default_samplerate"])
    try:
        runtime.setup(sample_rate=sample_rate)
    except ValueError as exc:
        # e.g. no FFT bin inside --min-band-hz/--max-band-hz at this rate.
        raise SystemExit(f"beat detector setup failed: {exc}") from exc
    frame_size = runtime.frame_size
    # Bounded hand-off from the stream callback to the processing loop.
    audio_q: "queue.Queue[np.ndarray]" = queue.Queue(maxsize=64)

    def audio_callback(indata, frames, _time_info, status):
        _ = frames  # block length is fixed by blocksize=frame_size below
        if status:
            print(f"audio status: {status}")
        # Copy the channel. indata[:, 0] is a view into sounddevice's callback
        # buffer, which is reused for later blocks, so queuing the view would
        # let the consumer read overwritten audio.
        mono = np.array(indata[:, 0], dtype=np.float32)
        try:
            audio_q.put_nowait(mono)
        except queue.Full:
            pass  # consumer is behind: drop this block instead of stalling audio

    print(
        "Listening... Ctrl+C to stop. "
        f"mode={args.mode} sr={sample_rate} hop={frame_size} "
        f"band={args.min_band_hz:.0f}-{args.max_band_hz:.0f}Hz "
        f"custom_th={args.threshold_multiplier:.2f} aubio_th={args.aubio_threshold:.2f} "
        f"min_ioi={args.min_ioi_ms:.0f}ms"
    )
    with sd.InputStream(
        device=args.device,
        channels=1,
        samplerate=sample_rate,
        blocksize=frame_size,
        callback=audio_callback,
    ):
        try:
            while True:
                try:
                    frame = audio_q.get(timeout=0.1)
                except queue.Empty:
                    continue
                # process_frame trims/zero-pads to frame_size itself.
                event = runtime.process_frame(frame, now_s=time.time())
                if event is None:
                    continue
                now_s = event["ts"]
                bpm = event["bpm"]
                bpm_text = f"{bpm:.1f}" if isinstance(bpm, (float, int)) else "--"
                src = event["src"]
                print(
                    f"[{args.mode}] BEAT bpm={bpm_text} src={src} type={event['beat_type']} "
                    f"type_conf={event['beat_type_confidence']:.2f} strength={event['strength']:.2f} "
                    f"db={event['db']:.1f} "
                    f"score={event['score']:.3e} threshold={event['threshold']:.3e}"
                )
                if args.post_url and requests is not None:
                    # NOTE: this POST is synchronous (up to 0.5 s), so frame
                    # processing stalls meanwhile. Stalls longer than the
                    # 64-block queue holds cause audio blocks to be dropped.
                    try:
                        requests.post(
                            args.post_url,
                            json={"beat": True, "source": src, "ts": now_s, "bpm": bpm},
                            timeout=0.5,
                        )
                    except Exception as exc:  # best-effort: never stop detecting over a POST
                        print(f"post failed: {exc}")
        except KeyboardInterrupt:
            print("\nStopped.")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())