refactor(api): complete fastapi migration and related features
Finish native FastAPI controllers, drop vendored microdot, and add Wi-Fi driver runtime, beat SSE, simulated BPM, sequence playback improvements, bridge ESP-NOW sources, UI updates, and tests. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -22,7 +22,7 @@ Tests for the LED Controller project live under **`tests/`** (pytest + legacy sc
|
||||
| `udp_server.py` | UDP discovery / hello test listener (port **8766**) |
|
||||
| `bridge_broadcast_test.py` | Manual bridge WebSocket broadcast script |
|
||||
| `ws.py` | WebSocket client checks |
|
||||
| `web.py` | Local dev static server (not the main app) |
|
||||
| `web.py` | Local dev server on port 5000 (`pipenv run web`) |
|
||||
| `conftest.py` | Pytest fixtures |
|
||||
| `models/` | Model unit tests (`run_all.py`, `test_zone.py`, …) |
|
||||
|
||||
@@ -50,6 +50,6 @@ Requires **Selenium**, Chrome/Chromium, and a matching **ChromeDriver**.
|
||||
python tests/models/run_all.py
|
||||
```
|
||||
|
||||
### Local static server
|
||||
### Local dev server (port 5000)
|
||||
|
||||
`tests/web.py` serves files for quick UI experiments; it is **not** the Microdot app. For the real server use **`pipenv run run`** from the repo root.
|
||||
`pipenv run web` runs the FastAPI app on **http://localhost:5000** (production-style default is **`pipenv run run`** on port 80).
|
||||
|
||||
@@ -14,9 +14,8 @@ from starlette.testclient import TestClient
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
SRC_PATH = PROJECT_ROOT / "src"
|
||||
LIB_PATH = PROJECT_ROOT / "lib"
|
||||
|
||||
for p in (str(PROJECT_ROOT), str(LIB_PATH), str(SRC_PATH)):
|
||||
for p in (str(PROJECT_ROOT), str(SRC_PATH)):
|
||||
if p in sys.path:
|
||||
sys.path.remove(p)
|
||||
sys.path.insert(0, p)
|
||||
@@ -81,7 +80,7 @@ def server(monkeypatch, tmp_path_factory):
|
||||
tmp_db_dir = tmp_root / "db"
|
||||
tmp_settings_file = tmp_root / "settings.json"
|
||||
|
||||
for p in (str(SRC_PATH), str(LIB_PATH), str(PROJECT_ROOT)):
|
||||
for p in (str(SRC_PATH), str(PROJECT_ROOT)):
|
||||
if p in sys.path:
|
||||
sys.path.remove(p)
|
||||
sys.path.insert(0, p)
|
||||
|
||||
@@ -115,14 +115,36 @@ def parse_args() -> argparse.Namespace:
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def _clamp_detected_bpm(bpm: float | None) -> float | None:
|
||||
if bpm is None:
|
||||
return None
|
||||
try:
|
||||
from util.bpm_limits import clamp_bpm_optional
|
||||
|
||||
return clamp_bpm_optional(bpm)
|
||||
except ImportError:
|
||||
v = float(bpm)
|
||||
if v <= 0:
|
||||
return None
|
||||
return max(60.0, min(200.0, v))
|
||||
|
||||
|
||||
def _estimate_bpm(beat_times: Deque[float]) -> float | None:
|
||||
if len(beat_times) < 3:
|
||||
return None
|
||||
try:
|
||||
from util.bpm_limits import max_beat_interval_s, min_beat_interval_s
|
||||
|
||||
ioi_min = min_beat_interval_s()
|
||||
ioi_max = max_beat_interval_s()
|
||||
except ImportError:
|
||||
ioi_min = 0.3
|
||||
ioi_max = 1.0
|
||||
intervals = np.diff(np.array(beat_times, dtype=np.float64))
|
||||
valid = intervals[(intervals > 0.2) & (intervals < 2.0)]
|
||||
valid = intervals[(intervals >= ioi_min) & (intervals <= ioi_max)]
|
||||
if valid.size == 0:
|
||||
return None
|
||||
return 60.0 / float(np.median(valid))
|
||||
return _clamp_detected_bpm(60.0 / float(np.median(valid)))
|
||||
|
||||
|
||||
def _is_plausible_ioi(
|
||||
@@ -131,7 +153,7 @@ def _is_plausible_ioi(
|
||||
now_s: float,
|
||||
*,
|
||||
min_ratio: float = 0.42,
|
||||
max_ratio: float = 2.5,
|
||||
max_ratio: float = 3.5,
|
||||
) -> bool:
|
||||
"""Reject double-time / half-time false triggers vs recent median interval."""
|
||||
if last_trigger_s <= 0 or len(beat_times) < 2:
|
||||
@@ -251,7 +273,7 @@ def _resolve_bpm(
|
||||
) -> float | None:
|
||||
estimated = _estimate_bpm(beat_times)
|
||||
if estimated is None:
|
||||
return aubio_bpm
|
||||
return _clamp_detected_bpm(aubio_bpm)
|
||||
if aubio_bpm is None or aubio_bpm <= 0:
|
||||
return estimated
|
||||
ratio = float(aubio_bpm) / estimated
|
||||
@@ -400,7 +422,7 @@ class BeatDetectRuntime:
|
||||
if self.tempo is not None:
|
||||
aubio_hit = bool(self.tempo(f32)[0])
|
||||
val = float(self.tempo.get_bpm())
|
||||
aubio_bpm = val if val > 0 else None
|
||||
aubio_bpm = _clamp_detected_bpm(val if val > 0 else None)
|
||||
|
||||
if now_s is None:
|
||||
now_s = time.time()
|
||||
|
||||
@@ -5,11 +5,10 @@ pytest_plugins = ["api_server"]
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
SRC_PATH = PROJECT_ROOT / "src"
|
||||
LIB_PATH = PROJECT_ROOT / "lib"
|
||||
|
||||
# Last insert(0) wins: order must be (root, lib, src) so src/models wins over
|
||||
# Last insert(0) wins: order must be (root, src) so src/models wins over
|
||||
# tests/models (same package name "models" on sys.path when pytest imports tests).
|
||||
for p in (str(PROJECT_ROOT), str(LIB_PATH), str(SRC_PATH)):
|
||||
for p in (str(PROJECT_ROOT), str(SRC_PATH)):
|
||||
if p in sys.path:
|
||||
sys.path.remove(p)
|
||||
sys.path.insert(0, p)
|
||||
|
||||
@@ -27,7 +27,6 @@ def test_sequence():
|
||||
assert sequence["lanes"] == [[]]
|
||||
assert sequence.get("lanes_group_ids") == [[]]
|
||||
assert sequence.get("advance_mode") == "beats"
|
||||
assert sequence.get("simulated_bpm") == 120
|
||||
assert sequence["step_duration_ms"] == 3000
|
||||
assert sequence["loop"] is True
|
||||
assert sequence.get("sequence_transition") == 500
|
||||
@@ -43,7 +42,6 @@ def test_sequence():
|
||||
"step_duration_ms": 5000,
|
||||
"loop": True,
|
||||
"advance_mode": "beats",
|
||||
"simulated_bpm": 128,
|
||||
}
|
||||
result = sequences.update(sequence_id, update_data)
|
||||
assert result is True
|
||||
@@ -58,7 +56,6 @@ def test_sequence():
|
||||
assert len(updated["lanes"][0]) == 2
|
||||
assert updated["lanes"][0][0]["beats"] == 2
|
||||
assert updated.get("advance_mode") == "beats"
|
||||
assert updated.get("simulated_bpm") == 128
|
||||
assert updated["step_duration_ms"] == 5000
|
||||
assert updated["loop"] is True
|
||||
|
||||
|
||||
@@ -1,23 +1,18 @@
|
||||
"""Audio input device_select persistence (Pulse name must survive start)."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from fastapi import FastAPI
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
SRC_PATH = PROJECT_ROOT / "src"
|
||||
if str(SRC_PATH) not in sys.path:
|
||||
sys.path.insert(0, str(SRC_PATH))
|
||||
|
||||
from microdot import Microdot # noqa: E402
|
||||
from util.audio_run_persist import read_audio_run_state, write_audio_run_state # noqa: E402
|
||||
|
||||
SNOWBALL = (
|
||||
@@ -25,28 +20,6 @@ SNOWBALL = (
|
||||
)
|
||||
|
||||
|
||||
def _start_app(app: Microdot, port: int = 0):
|
||||
def runner():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(app.start_server(host="127.0.0.1", port=port))
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
thread = threading.Thread(target=runner, daemon=True)
|
||||
thread.start()
|
||||
deadline = time.time() + 5.0
|
||||
while time.time() < deadline:
|
||||
server = getattr(app, "server", None)
|
||||
if server and getattr(server, "sockets", None):
|
||||
sockets = server.sockets or []
|
||||
if sockets:
|
||||
return thread, sockets[0].getsockname()[1]
|
||||
time.sleep(0.05)
|
||||
raise RuntimeError("server failed to start")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def audio_run_path(tmp_path, monkeypatch):
|
||||
path = tmp_path / "audio_run.json"
|
||||
@@ -69,12 +42,12 @@ def test_write_start_keeps_pulse_device_select_not_portaudio_index(audio_run_pat
|
||||
|
||||
|
||||
def test_put_device_saves_pulse_name(audio_run_path):
|
||||
app = Microdot()
|
||||
api = FastAPI()
|
||||
|
||||
@app.route("/api/audio/device", methods=["PUT"])
|
||||
async def audio_set_device(request):
|
||||
payload = request.json if isinstance(request.json, dict) else {}
|
||||
device_select = str(payload.get("device_select") or "").strip()
|
||||
@api.put("/api/audio/device")
|
||||
async def audio_set_device(payload: dict | None = None):
|
||||
body = payload if isinstance(payload, dict) else {}
|
||||
device_select = str(body.get("device_select") or "").strip()
|
||||
from util.audio_run_persist import read_audio_run_state, write_audio_run_state
|
||||
|
||||
prev = read_audio_run_state()
|
||||
@@ -86,13 +59,11 @@ def test_put_device_saves_pulse_name(audio_run_path):
|
||||
)
|
||||
return {"ok": True, "audio_run": read_audio_run_state()}
|
||||
|
||||
_, port = _start_app(app)
|
||||
base = f"http://127.0.0.1:{port}"
|
||||
resp = requests.put(
|
||||
f"{base}/api/audio/device",
|
||||
json={"device_select": SNOWBALL, "device_override": ""},
|
||||
timeout=5,
|
||||
)
|
||||
with TestClient(api) as client:
|
||||
resp = client.put(
|
||||
"/api/audio/device",
|
||||
json={"device_select": SNOWBALL, "device_override": ""},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["audio_run"]["device_select"] == SNOWBALL
|
||||
@@ -112,15 +83,15 @@ def test_start_preserves_device_select_in_status(audio_run_path, monkeypatch):
|
||||
fake_resolve,
|
||||
)
|
||||
|
||||
app = Microdot()
|
||||
api = FastAPI()
|
||||
|
||||
@app.route("/api/audio/start", methods=["POST"])
|
||||
async def audio_start(request):
|
||||
payload = request.json if isinstance(request.json, dict) else {}
|
||||
device = payload.get("device", None)
|
||||
@api.post("/api/audio/start")
|
||||
async def audio_start(payload: dict | None = None):
|
||||
body = payload if isinstance(payload, dict) else {}
|
||||
device = body.get("device", None)
|
||||
if device in ("", None):
|
||||
device = None
|
||||
device_select = str(payload.get("device_select") or "").strip()
|
||||
device_select = str(body.get("device_select") or "").strip()
|
||||
if not device_select and device not in ("", None):
|
||||
device_select = str(device).strip()
|
||||
from util.pulse_audio_devices import resolve_capture_device
|
||||
@@ -138,29 +109,26 @@ def test_start_preserves_device_select_in_status(audio_run_path, monkeypatch):
|
||||
st["audio_run"] = read_audio_run_state()
|
||||
return {"ok": True, "status": st}
|
||||
|
||||
@app.route("/api/audio/status")
|
||||
async def audio_status(request):
|
||||
_ = request
|
||||
@api.get("/api/audio/status")
|
||||
async def audio_status():
|
||||
from util.audio_run_persist import read_audio_run_state
|
||||
|
||||
st = detector.status()
|
||||
st["audio_run"] = read_audio_run_state()
|
||||
return {"status": st}
|
||||
|
||||
_, port = _start_app(app)
|
||||
base = f"http://127.0.0.1:{port}"
|
||||
start = requests.post(
|
||||
f"{base}/api/audio/start",
|
||||
json={"device": SNOWBALL, "device_select": SNOWBALL, "device_override": ""},
|
||||
timeout=5,
|
||||
)
|
||||
assert start.status_code == 200, start.text
|
||||
run = start.json()["status"]["audio_run"]
|
||||
assert run["device_select"] == SNOWBALL
|
||||
assert run["device"] == 2
|
||||
with TestClient(api) as client:
|
||||
start = client.post(
|
||||
"/api/audio/start",
|
||||
json={"device": SNOWBALL, "device_select": SNOWBALL, "device_override": ""},
|
||||
)
|
||||
assert start.status_code == 200, start.text
|
||||
run = start.json()["status"]["audio_run"]
|
||||
assert run["device_select"] == SNOWBALL
|
||||
assert run["device"] == 2
|
||||
|
||||
status = requests.get(f"{base}/api/audio/status", timeout=5).json()["status"]
|
||||
assert status["audio_run"]["device_select"] == SNOWBALL
|
||||
status = client.get("/api/audio/status").json()["status"]
|
||||
assert status["audio_run"]["device_select"] == SNOWBALL
|
||||
|
||||
|
||||
def test_pulse_device_list_uses_stable_pulse_ids():
|
||||
|
||||
@@ -9,7 +9,11 @@ SRC_PATH = os.path.join(PROJECT_ROOT, "src")
|
||||
if SRC_PATH not in sys.path:
|
||||
sys.path.insert(0, SRC_PATH)
|
||||
|
||||
from util.audio_detector import AudioBeatDetector # noqa: E402
|
||||
from util.audio_detector import ( # noqa: E402
|
||||
AudioBeatDetector,
|
||||
set_shared_beat_detector,
|
||||
shared_beat_detector_timing_sequences,
|
||||
)
|
||||
|
||||
|
||||
class _FakeRuntime:
|
||||
@@ -83,7 +87,47 @@ def test_silence_gap_starts_holdover_and_resets_tempo_once():
|
||||
det._maybe_recover_after_silence_gap(rt)
|
||||
assert rt.reset_tempo_calls == 1
|
||||
det._record_beat(120.0)
|
||||
assert det._holdover_active is False
|
||||
assert det._holdover_active is True
|
||||
|
||||
|
||||
def test_timing_sequences_true_while_holdover_active():
|
||||
det = AudioBeatDetector()
|
||||
set_shared_beat_detector(det)
|
||||
try:
|
||||
with det._lock:
|
||||
det._running = True
|
||||
det._status["running"] = True
|
||||
det._status["bpm"] = 120.0
|
||||
det._record_beat(120.0)
|
||||
assert det._holdover_active is True
|
||||
assert shared_beat_detector_timing_sequences() is True
|
||||
finally:
|
||||
set_shared_beat_detector(None)
|
||||
|
||||
|
||||
def test_timing_sequences_false_when_running_without_beats():
|
||||
det = AudioBeatDetector()
|
||||
set_shared_beat_detector(det)
|
||||
try:
|
||||
with det._lock:
|
||||
det._running = True
|
||||
det._status["running"] = True
|
||||
assert shared_beat_detector_timing_sequences() is False
|
||||
det._record_beat(120.0)
|
||||
assert shared_beat_detector_timing_sequences() is True
|
||||
det._stop_bpm_holdover()
|
||||
with det._lock:
|
||||
det._last_real_beat_ts = time.time() - 5.0
|
||||
assert shared_beat_detector_timing_sequences() is False
|
||||
finally:
|
||||
set_shared_beat_detector(None)
|
||||
|
||||
|
||||
def test_record_beat_keeps_previous_bpm_when_new_readout_invalid():
|
||||
det = AudioBeatDetector()
|
||||
det._record_beat(128.0)
|
||||
det._record_beat(None)
|
||||
assert det.status()["bpm"] == 128.0
|
||||
|
||||
|
||||
def test_holdover_last_beat_does_not_block_tempo_retry():
|
||||
|
||||
34
tests/test_audio_sse.py
Normal file
34
tests/test_audio_sse.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""Server-sent events for audio/beat status."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initial_sse_line_includes_status(monkeypatch):
|
||||
from util import beat_status_broadcaster as bsb
|
||||
|
||||
bsb.configure(
|
||||
loop=asyncio.get_running_loop(),
|
||||
status_builder=lambda: {"bpm_simulated": True, "beat_seq": 3},
|
||||
)
|
||||
line = await bsb.initial_sse_line()
|
||||
assert line.startswith("data: ")
|
||||
payload = json.loads(line[6:])
|
||||
assert payload["type"] == "status"
|
||||
assert payload["status"]["beat_seq"] == 3
|
||||
|
||||
|
||||
def test_audio_events_sse_first_chunk(server):
|
||||
c = server["client"]
|
||||
with c.stream("GET", "/api/audio/events") as resp:
|
||||
assert resp.status_code == 200
|
||||
assert "text/event-stream" in resp.headers.get("content-type", "")
|
||||
chunk = next(resp.iter_bytes())
|
||||
text = chunk.decode("utf-8")
|
||||
assert text.startswith("data: ")
|
||||
payload = json.loads(text.strip().removeprefix("data: "))
|
||||
assert payload.get("type") == "status"
|
||||
assert "bpm_simulated" in payload.get("status", {})
|
||||
@@ -26,3 +26,14 @@ def test_resolve_bpm_prefers_intervals_over_wrong_aubio():
|
||||
bpm = _resolve_bpm(times, 70.0)
|
||||
assert bpm is not None
|
||||
assert abs(bpm - 120.0) < 5.0
|
||||
|
||||
|
||||
def test_resolve_bpm_clamps_runaway_aubio():
|
||||
times = deque([0.0])
|
||||
assert _resolve_bpm(times, 400.0) == 200.0
|
||||
assert _resolve_bpm(times, 999.0) == 200.0
|
||||
|
||||
|
||||
def test_resolve_bpm_clamps_slow_aubio():
|
||||
times = deque([0.0])
|
||||
assert _resolve_bpm(times, 30.0) == 60.0
|
||||
|
||||
45
tests/test_bpm_limits.py
Normal file
45
tests/test_bpm_limits.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""BPM clamp helpers."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
SRC_PATH = os.path.join(PROJECT_ROOT, "src")
|
||||
if SRC_PATH not in sys.path:
|
||||
sys.path.insert(0, SRC_PATH)
|
||||
|
||||
from util.audio_detector import AudioBeatDetector # noqa: E402
|
||||
from util.bpm_limits import (
|
||||
BPM_MAX,
|
||||
BPM_MIN,
|
||||
clamp_bpm,
|
||||
clamp_bpm_optional,
|
||||
max_beat_interval_s,
|
||||
max_beat_min_ioi_ms,
|
||||
min_beat_interval_s,
|
||||
)
|
||||
|
||||
|
||||
def test_clamp_bpm_bounds():
|
||||
assert clamp_bpm(120) == 120.0
|
||||
assert clamp_bpm(400) == float(BPM_MAX)
|
||||
assert clamp_bpm(20) == float(BPM_MIN)
|
||||
|
||||
|
||||
def test_clamp_bpm_optional():
|
||||
assert clamp_bpm_optional(None) is None
|
||||
assert clamp_bpm_optional(0) is None
|
||||
assert clamp_bpm_optional(350) == float(BPM_MAX)
|
||||
|
||||
|
||||
def test_beat_interval_bounds():
|
||||
assert abs(min_beat_interval_s() - 60.0 / BPM_MAX) < 1e-9
|
||||
assert abs(max_beat_interval_s() - 60.0 / BPM_MIN) < 1e-9
|
||||
assert abs(max_beat_min_ioi_ms() - 60_000.0 / BPM_MAX) < 1e-6
|
||||
|
||||
|
||||
def test_status_clamps_high_bpm():
|
||||
det = AudioBeatDetector()
|
||||
with det._lock:
|
||||
det._status["bpm"] = 350.0
|
||||
assert det.status()["bpm"] == float(BPM_MAX)
|
||||
@@ -43,18 +43,32 @@ def test_deliver_json_messages_defaults_broadcast():
|
||||
def __init__(self):
|
||||
self.keys = []
|
||||
|
||||
async def send(self, envelope):
|
||||
async def send(self, envelope, addr=None):
|
||||
del addr
|
||||
devs = envelope.get("dv") or envelope.get("devices") or {}
|
||||
self.keys.extend(devs.keys())
|
||||
return True
|
||||
|
||||
class _Devices:
|
||||
def read(self, mac):
|
||||
return {
|
||||
"id": mac,
|
||||
"name": mac,
|
||||
"transport": "espnow",
|
||||
"address": mac,
|
||||
}
|
||||
|
||||
def items(self):
|
||||
return []
|
||||
|
||||
async def _run():
|
||||
bridge = _Bridge()
|
||||
await deliver_json_messages(
|
||||
bridge,
|
||||
[json.dumps({"v": "1", "select": ["2"]})],
|
||||
["188b0e1560a8", "e8f60a16ea10"],
|
||||
None,
|
||||
_Devices(),
|
||||
delay_s=0,
|
||||
)
|
||||
return bridge.keys
|
||||
|
||||
|
||||
@@ -607,39 +607,45 @@ def test_profiles_ui(browser: BrowserTest) -> bool:
|
||||
return passed == total
|
||||
|
||||
|
||||
def test_mobile_tab_presets_two_columns():
|
||||
def test_mobile_tab_presets_three_columns():
|
||||
"""
|
||||
Verify that the zone preset selecting area shows roughly two preset tiles per row
|
||||
on a phone-sized viewport.
|
||||
On a phone-sized viewport the zone strip is hidden; zones are chosen from the
|
||||
header Zones menu. Preset tiles use a 3-column grid (see style.css).
|
||||
"""
|
||||
bt = BrowserTest(base_url=BASE_URL, headless=True)
|
||||
if not bt.setup():
|
||||
assert False, "Failed to start browser"
|
||||
|
||||
try:
|
||||
# Simulate a mobile viewport
|
||||
bt.driver.set_window_size(400, 800)
|
||||
assert bt.navigate('/'), "Failed to load main page"
|
||||
|
||||
# Click the first zone button to load presets for that zone
|
||||
first_tab = bt.wait_for_element(By.CSS_SELECTOR, '.zone-button', timeout=10)
|
||||
assert first_tab is not None, "No zone buttons found"
|
||||
first_tab.click()
|
||||
# Desktop zone buttons live in .zones-container which is display:none on mobile.
|
||||
WebDriverWait(bt.driver, 10).until(
|
||||
EC.presence_of_element_located(
|
||||
(By.CSS_SELECTOR, '#zones-menu-dropdown .zones-menu-item')
|
||||
)
|
||||
)
|
||||
assert bt.click_element(By.ID, 'zones-menu-btn'), "Failed to open Zones menu"
|
||||
assert bt.click_element(
|
||||
By.CSS_SELECTOR, '#zones-menu-dropdown .zones-menu-item'
|
||||
), "Failed to select zone from mobile menu"
|
||||
_browser_sleep(1)
|
||||
|
||||
container = bt.wait_for_element(By.ID, 'presets-list-zone', timeout=10)
|
||||
assert container is not None, "presets-list-zone not found"
|
||||
|
||||
tiles = bt.driver.find_elements(By.CSS_SELECTOR, '#presets-list-zone .preset-tile-row')
|
||||
# Need at least 2 presets to make this meaningful
|
||||
tiles = bt.driver.find_elements(
|
||||
By.CSS_SELECTOR, '#presets-list-zone .preset-tile-row'
|
||||
)
|
||||
assert len(tiles) >= 2, "Fewer than 2 presets found for zone"
|
||||
|
||||
container_width = container.size['width']
|
||||
first_width = tiles[0].size['width']
|
||||
|
||||
# Each tile should be about half the container width (tolerate some margin)
|
||||
assert 0.4 * container_width <= first_width <= 0.6 * container_width, (
|
||||
f"Preset tile width {first_width} not ~half of container {container_width}"
|
||||
# Three columns on max-width 600px (~one third of the row, minus gaps).
|
||||
assert 0.22 * container_width <= first_width <= 0.42 * container_width, (
|
||||
f"Preset tile width {first_width} not ~third of container {container_width}"
|
||||
)
|
||||
finally:
|
||||
bt.teardown()
|
||||
|
||||
461
tests/test_driver_delivery_wifi.py
Normal file
461
tests/test_driver_delivery_wifi.py
Normal file
@@ -0,0 +1,461 @@
|
||||
"""Tests for dual-transport delivery (ESP-NOW bridge + Wi-Fi WebSocket) and Wi-Fi runtime."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import socket
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
SRC_PATH = PROJECT_ROOT / "src"
|
||||
|
||||
for p in (str(PROJECT_ROOT), str(SRC_PATH)):
|
||||
if p in sys.path:
|
||||
sys.path.remove(p)
|
||||
sys.path.insert(0, p)
|
||||
|
||||
_models = sys.modules.get("models")
|
||||
if _models is not None:
|
||||
_mf = (getattr(_models, "__file__", "") or "").replace("\\", "/")
|
||||
if "/tests/models" in _mf:
|
||||
for key in list(sys.modules):
|
||||
if key == "models" or key.startswith("models."):
|
||||
del sys.modules[key]
|
||||
|
||||
from util.bridge_envelope import BROADCAST_MAC # noqa: E402
|
||||
|
||||
|
||||
class FakeDevices:
|
||||
def __init__(self, docs: Dict[str, Dict[str, Any]]):
|
||||
self._docs = docs
|
||||
|
||||
def read(self, mac: str) -> Optional[Dict[str, Any]]:
|
||||
return self._docs.get(mac)
|
||||
|
||||
def items(self):
|
||||
return self._docs.items()
|
||||
|
||||
|
||||
class RecordingBridge:
|
||||
def __init__(self) -> None:
|
||||
self.envelopes: List[Dict[str, Any]] = []
|
||||
|
||||
async def send(self, data, addr=None):
|
||||
del addr
|
||||
if isinstance(data, dict):
|
||||
self.envelopes.append(data)
|
||||
elif isinstance(data, str):
|
||||
self.envelopes.append(json.loads(data))
|
||||
return True
|
||||
|
||||
def mac_keys(self) -> List[str]:
|
||||
keys: List[str] = []
|
||||
for env in self.envelopes:
|
||||
devs = env.get("dv") or env.get("devices") or {}
|
||||
keys.extend(devs.keys())
|
||||
return keys
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def bridge():
|
||||
return RecordingBridge()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def espnow_devices():
|
||||
return FakeDevices(
|
||||
{
|
||||
"188b0e1560a8": {
|
||||
"id": "188b0e1560a8",
|
||||
"name": "esp-a",
|
||||
"transport": "espnow",
|
||||
"address": "188b0e1560a8",
|
||||
},
|
||||
"e8f60a16ea10": {
|
||||
"id": "e8f60a16ea10",
|
||||
"name": "esp-b",
|
||||
"transport": "espnow",
|
||||
"address": "e8f60a16ea10",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mixed_devices():
|
||||
return FakeDevices(
|
||||
{
|
||||
"188b0e1560a8": {
|
||||
"id": "188b0e1560a8",
|
||||
"name": "esp-a",
|
||||
"transport": "espnow",
|
||||
"address": "188b0e1560a8",
|
||||
},
|
||||
"102030405060": {
|
||||
"id": "102030405060",
|
||||
"name": "wifi-a",
|
||||
"transport": "wifi",
|
||||
"address": "192.168.50.10",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_wifi_message_for_device_narrows_select():
|
||||
from util.driver_delivery import _wifi_message_for_device
|
||||
|
||||
msg = json.dumps(
|
||||
{"v": "1", "select": {"wifi-a": 0, "esp-a": 1}},
|
||||
separators=(",", ":"),
|
||||
)
|
||||
narrowed = _wifi_message_for_device(msg, "wifi-a")
|
||||
body = json.loads(narrowed)
|
||||
assert body["select"] == {"wifi-a": 0}
|
||||
|
||||
|
||||
def test_combine_preset_chunks_for_wifi():
|
||||
from util.driver_delivery import _combine_preset_chunks_for_wifi
|
||||
|
||||
chunks = [
|
||||
json.dumps({"v": "1", "presets": {"a": {"p": "on"}}}, separators=(",", ":")),
|
||||
json.dumps(
|
||||
{"v": "1", "presets": {"b": {"p": "blink"}}, "save": True, "default": "b"},
|
||||
separators=(",", ":"),
|
||||
),
|
||||
]
|
||||
combined = json.loads(_combine_preset_chunks_for_wifi(chunks))
|
||||
assert combined["presets"]["a"]["p"] == "on"
|
||||
assert combined["presets"]["b"]["p"] == "blink"
|
||||
assert combined["save"] is True
|
||||
assert combined["default"] == "b"
|
||||
|
||||
|
||||
def test_deliver_json_broadcast_espnow_only(bridge, espnow_devices, monkeypatch):
|
||||
from util import driver_delivery
|
||||
|
||||
wifi_sends: list[tuple[str, str]] = []
|
||||
|
||||
async def fake_wifi(ip, msg):
|
||||
wifi_sends.append((ip, msg))
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(driver_delivery, "send_json_line_to_ip", fake_wifi)
|
||||
|
||||
async def _run():
|
||||
return await driver_delivery.deliver_json_messages(
|
||||
bridge,
|
||||
[json.dumps({"v": "1", "select": ["off"]})],
|
||||
None,
|
||||
espnow_devices,
|
||||
delay_s=0,
|
||||
)
|
||||
|
||||
deliveries, n = asyncio.run(_run())
|
||||
assert n == 1
|
||||
assert deliveries >= 1
|
||||
assert bridge.mac_keys() == [BROADCAST_MAC]
|
||||
assert wifi_sends == []
|
||||
|
||||
|
||||
def test_deliver_json_broadcast_includes_wifi(bridge, mixed_devices, monkeypatch):
|
||||
from util import driver_delivery
|
||||
|
||||
wifi_sends: list[tuple[str, str]] = []
|
||||
|
||||
async def fake_wifi(ip, msg):
|
||||
wifi_sends.append((ip, msg))
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(driver_delivery, "send_json_line_to_ip", fake_wifi)
|
||||
|
||||
async def _run():
|
||||
return await driver_delivery.deliver_json_messages(
|
||||
bridge,
|
||||
[json.dumps({"v": "1", "select": ["off"]})],
|
||||
None,
|
||||
mixed_devices,
|
||||
delay_s=0,
|
||||
)
|
||||
|
||||
deliveries, _n = asyncio.run(_run())
|
||||
assert deliveries >= 2
|
||||
assert bridge.mac_keys() == [BROADCAST_MAC]
|
||||
assert len(wifi_sends) == 1
|
||||
assert wifi_sends[0][0] == "192.168.50.10"
|
||||
|
||||
|
||||
def test_deliver_json_targeted_espnow_unicasts(bridge, espnow_devices, monkeypatch):
|
||||
from util import driver_delivery
|
||||
|
||||
monkeypatch.setattr(
|
||||
driver_delivery,
|
||||
"send_json_line_to_ip",
|
||||
AsyncMock(return_value=True),
|
||||
)
|
||||
|
||||
async def _run():
|
||||
return await driver_delivery.deliver_json_messages(
|
||||
bridge,
|
||||
[json.dumps({"v": "1", "select": ["2"]})],
|
||||
["188b0e1560a8", "e8f60a16ea10"],
|
||||
espnow_devices,
|
||||
delay_s=0,
|
||||
)
|
||||
|
||||
asyncio.run(_run())
|
||||
keys = bridge.mac_keys()
|
||||
assert "18:8b:0e:15:60:a8" in keys
|
||||
assert "e8:f6:0a:16:ea:10" in keys
|
||||
assert BROADCAST_MAC not in keys
|
||||
|
||||
|
||||
def test_deliver_json_targeted_wifi_uses_websocket(bridge, mixed_devices, monkeypatch):
|
||||
from util import driver_delivery
|
||||
|
||||
wifi_sends: list[tuple[str, str]] = []
|
||||
|
||||
async def fake_wifi(ip, msg):
|
||||
wifi_sends.append((ip, msg))
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(driver_delivery, "send_json_line_to_ip", fake_wifi)
|
||||
|
||||
async def _run():
|
||||
await driver_delivery.deliver_json_messages(
|
||||
bridge,
|
||||
[json.dumps({"v": "1", "select": {"wifi-a": 0}})],
|
||||
["102030405060"],
|
||||
mixed_devices,
|
||||
delay_s=0,
|
||||
)
|
||||
|
||||
asyncio.run(_run())
|
||||
assert bridge.mac_keys() == []
|
||||
assert len(wifi_sends) == 1
|
||||
assert wifi_sends[0][0] == "192.168.50.10"
|
||||
body = json.loads(wifi_sends[0][1])
|
||||
assert body["select"] == {"wifi-a": 0}
|
||||
|
||||
|
||||
def test_deliver_json_unicast_flag_wifi(bridge, mixed_devices, monkeypatch):
|
||||
from util import driver_delivery
|
||||
|
||||
wifi_sends: list[str] = []
|
||||
|
||||
async def fake_wifi(ip, msg):
|
||||
wifi_sends.append(msg)
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(driver_delivery, "send_json_line_to_ip", fake_wifi)
|
||||
|
||||
async def _run():
|
||||
await driver_delivery.deliver_json_messages(
|
||||
bridge,
|
||||
[json.dumps({"v": "1", "b": 128})],
|
||||
["102030405060"],
|
||||
mixed_devices,
|
||||
delay_s=0,
|
||||
unicast=True,
|
||||
)
|
||||
|
||||
asyncio.run(_run())
|
||||
assert len(wifi_sends) == 1
|
||||
assert bridge.mac_keys() == []
|
||||
|
||||
|
||||
def test_deliver_preset_broadcast_then_per_device_wifi(
|
||||
bridge, mixed_devices, monkeypatch
|
||||
):
|
||||
from util import driver_delivery
|
||||
|
||||
wifi_sends: list[str] = []
|
||||
|
||||
async def fake_wifi(ip, msg):
|
||||
wifi_sends.append(msg)
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(driver_delivery, "send_json_line_to_ip", fake_wifi)
|
||||
|
||||
chunks = [
|
||||
json.dumps(
|
||||
{"v": "1", "presets": {"p1": {"p": "on"}}, "save": True},
|
||||
separators=(",", ":"),
|
||||
)
|
||||
]
|
||||
|
||||
async def _run():
|
||||
return await driver_delivery.deliver_preset_broadcast_then_per_device(
|
||||
bridge,
|
||||
chunks,
|
||||
None,
|
||||
mixed_devices,
|
||||
default_id=None,
|
||||
delay_s=0,
|
||||
)
|
||||
|
||||
count = asyncio.run(_run())
|
||||
assert count >= 2
|
||||
assert bridge.mac_keys() == [BROADCAST_MAC]
|
||||
assert len(wifi_sends) == 1
|
||||
combined = json.loads(wifi_sends[0])
|
||||
assert "p1" in combined["presets"]
|
||||
|
||||
|
||||
def test_deliver_json_requires_bridge(monkeypatch):
|
||||
from util import driver_delivery
|
||||
import models.transport as transport_mod
|
||||
|
||||
monkeypatch.setattr(transport_mod, "get_current_bridge", lambda: None)
|
||||
|
||||
async def _run():
|
||||
with pytest.raises(RuntimeError, match="Transport not configured"):
|
||||
await driver_delivery.deliver_json_messages(
|
||||
None, ["{}"], None, FakeDevices({}), delay_s=0
|
||||
)
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_device_status_broadcaster_send_text():
|
||||
from util.device_status_broadcaster import (
|
||||
_ws_send_text,
|
||||
broadcast_device_tcp_snapshot_to,
|
||||
broadcast_device_tcp_status,
|
||||
register_device_status_ws,
|
||||
unregister_device_status_ws,
|
||||
)
|
||||
|
||||
class StarletteLikeWS:
|
||||
def __init__(self):
|
||||
self.out: list[str] = []
|
||||
|
||||
async def send_text(self, msg: str):
|
||||
self.out.append(msg)
|
||||
|
||||
class SendTextOnlyWS:
|
||||
def __init__(self):
|
||||
self.out: list[str] = []
|
||||
|
||||
async def send(self, msg: str):
|
||||
self.out.append(msg)
|
||||
|
||||
async def _run():
|
||||
starlette = StarletteLikeWS()
|
||||
legacy_ws = SendTextOnlyWS()
|
||||
await _ws_send_text(starlette, '{"ok":true}')
|
||||
await _ws_send_text(legacy_ws, '{"ok":true}')
|
||||
assert starlette.out == ['{"ok":true}']
|
||||
assert legacy_ws.out == ['{"ok":true}']
|
||||
|
||||
await register_device_status_ws(starlette)
|
||||
await broadcast_device_tcp_status("192.168.1.5", True)
|
||||
assert len(starlette.out) == 2
|
||||
status = json.loads(starlette.out[1])
|
||||
assert status["type"] == "device_tcp"
|
||||
assert status["ip"] == "192.168.1.5"
|
||||
assert status["connected"] is True
|
||||
|
||||
await broadcast_device_tcp_snapshot_to(starlette)
|
||||
snapshot = json.loads(starlette.out[2])
|
||||
assert snapshot["type"] == "device_tcp_snapshot"
|
||||
assert "connected_ips" in snapshot
|
||||
|
||||
await unregister_device_status_ws(starlette)
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_process_udp_datagram_registers_and_connects(monkeypatch):
|
||||
from util import wifi_driver_runtime
|
||||
|
||||
registered: list[tuple[str, str, str]] = []
|
||||
connected: list[str] = []
|
||||
|
||||
def fake_register(device_name, peer_ip, mac, device_type=None):
|
||||
del device_type
|
||||
registered.append((device_name, peer_ip, str(mac)))
|
||||
|
||||
monkeypatch.setattr(
|
||||
wifi_driver_runtime,
|
||||
"_register_udp_device_sync",
|
||||
fake_register,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
wifi_driver_runtime.tcp_client_registry,
|
||||
"ensure_driver_connection",
|
||||
lambda ip: connected.append(ip),
|
||||
)
|
||||
|
||||
line = json.dumps(
|
||||
{"v": "1", "device_name": "strip-a", "mac": "aabbccddeeff", "type": "led"}
|
||||
).encode()
|
||||
wifi_driver_runtime._process_udp_datagram(line, "192.168.1.42")
|
||||
assert registered == [("strip-a", "192.168.1.42", "aabbccddeeff")]
|
||||
assert connected == ["192.168.1.42"]
|
||||
|
||||
|
||||
def test_process_udp_datagram_ignores_invalid():
|
||||
from util.wifi_driver_runtime import _process_udp_datagram
|
||||
|
||||
_process_udp_datagram(b"not-json\n", "10.0.0.1")
|
||||
_process_udp_datagram(b'{"v":"1"}\n', "10.0.0.1")
|
||||
|
||||
|
||||
def test_discovery_protocol_uses_datagram_endpoint(monkeypatch):
|
||||
pytest.importorskip("uvloop")
|
||||
import uvloop
|
||||
|
||||
from util.wifi_driver_runtime import _DiscoveryProtocol
|
||||
|
||||
async def _run():
|
||||
echoed: list[bytes] = []
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
holder: dict = {"closing": False}
|
||||
loop = asyncio.get_running_loop()
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
sock.bind(("127.0.0.1", 0))
|
||||
port = sock.getsockname()[1]
|
||||
transport, _protocol = await loop.create_datagram_endpoint(
|
||||
lambda: _DiscoveryProtocol(holder),
|
||||
sock=sock,
|
||||
)
|
||||
|
||||
class _EchoClient(asyncio.DatagramProtocol):
|
||||
def connection_made(self, t):
|
||||
self._transport = t
|
||||
|
||||
def datagram_received(self, data, addr):
|
||||
del addr
|
||||
echoed.append(data)
|
||||
|
||||
client_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
client_sock.bind(("127.0.0.1", 0))
|
||||
client_transport, _ = await loop.create_datagram_endpoint(
|
||||
_EchoClient,
|
||||
sock=client_sock,
|
||||
)
|
||||
payload = b'{"v":"1","device_name":"x","mac":"112233445566"}\n'
|
||||
client_transport.sendto(payload, ("127.0.0.1", port))
|
||||
await asyncio.sleep(0.05)
|
||||
holder["closing"] = True
|
||||
client_transport.close()
|
||||
transport.close()
|
||||
return echoed, payload
|
||||
|
||||
monkeypatch.setattr(
|
||||
"util.wifi_driver_runtime._register_udp_device_sync",
|
||||
lambda *a, **k: None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"util.wifi_driver_runtime.tcp_client_registry.ensure_driver_connection",
|
||||
lambda _ip: None,
|
||||
)
|
||||
echoed, payload = asyncio.run(_run())
|
||||
assert echoed == [payload]
|
||||
@@ -55,6 +55,9 @@ def test_main_routes(server):
|
||||
|
||||
with c.websocket_connect("/ws") as ws:
|
||||
ws.send_text('{"v":"1","select":["off"]}')
|
||||
snapshot = ws.receive_json()
|
||||
assert snapshot.get("type") == "device_tcp_snapshot"
|
||||
assert isinstance(snapshot.get("connected_ips"), list)
|
||||
|
||||
|
||||
def test_settings_controller(server):
|
||||
@@ -66,6 +69,9 @@ def test_settings_controller(server):
|
||||
data = resp.json()
|
||||
assert isinstance(data, dict)
|
||||
assert "wifi_channel" in data
|
||||
assert "wifi_driver_ws_port" in data
|
||||
assert "wifi_driver_ws_path" in data
|
||||
assert data.get("wifi_driver_ws_path") == "/ws"
|
||||
|
||||
resp = c.get(f"{base_url}/settings/wifi/ap")
|
||||
assert resp.status_code == 200
|
||||
@@ -183,6 +189,37 @@ def test_profiles_presets_zones_endpoints(server, monkeypatch):
|
||||
assert sent_result["presets_sent"] >= 1
|
||||
assert len(bridge.sent) >= 1
|
||||
|
||||
wifi_sends = []
|
||||
|
||||
async def _fake_wifi_send(ip, msg):
|
||||
wifi_sends.append((ip, msg))
|
||||
return True
|
||||
|
||||
import util.driver_delivery as driver_delivery_mod
|
||||
|
||||
monkeypatch.setattr(driver_delivery_mod, "send_json_line_to_ip", _fake_wifi_send)
|
||||
resp = c.post(
|
||||
f"{base_url}/devices",
|
||||
json={
|
||||
"name": "pytest-wifi-preset",
|
||||
"transport": "wifi",
|
||||
"address": "192.168.50.20",
|
||||
"mac": "203040506070",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
bridge.sent.clear()
|
||||
resp = c.post(
|
||||
f"{base_url}/presets/send",
|
||||
json={"preset_ids": [new_preset_id], "save": False},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert len(bridge.sent) >= 1
|
||||
assert len(wifi_sends) >= 1
|
||||
assert wifi_sends[0][0] == "192.168.50.20"
|
||||
resp = c.delete(f"{base_url}/devices/203040506070")
|
||||
assert resp.status_code == 200
|
||||
|
||||
resp = c.delete(f"{base_url}/presets/{new_preset_id}")
|
||||
assert resp.status_code == 200
|
||||
resp = c.get(f"{base_url}/presets/{new_preset_id}")
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
import pytest
|
||||
|
||||
pytest.skip("Legacy manual server script (not a pytest suite).", allow_module_level=True)
|
||||
|
||||
from microdot import Microdot
|
||||
from src.profile import profile_app
|
||||
|
||||
app = Microdot()
|
||||
|
||||
@app.route('/')
|
||||
async def index(request):
|
||||
return 'Hello, world!'
|
||||
|
||||
app.mount(profile_app, url_prefix="/profile")
|
||||
|
||||
app.run(port=8080, debug=True)
|
||||
@@ -13,6 +13,56 @@ if SRC_PATH not in sys.path:
|
||||
from util import sequence_playback as sp # noqa: E402
|
||||
|
||||
|
||||
def test_effective_switch_wait_ignores_saved_downbeat_when_audio_off(monkeypatch):
|
||||
class FakeSettings:
|
||||
def get(self, key, default=None):
|
||||
if key == "sequence_switch_wait":
|
||||
return "downbeat"
|
||||
return default
|
||||
|
||||
monkeypatch.setattr("settings.get_settings", lambda: FakeSettings())
|
||||
monkeypatch.setattr(
|
||||
"util.audio_detector.shared_beat_detector_timing_sequences", lambda: False
|
||||
)
|
||||
assert sp.effective_sequence_switch_wait() == "beat"
|
||||
|
||||
|
||||
def test_simulated_mode_forces_beat_switch_wait(monkeypatch):
|
||||
class FakeSettings:
|
||||
def get(self, key, default=None):
|
||||
if key == "sequence_switch_wait":
|
||||
return "downbeat"
|
||||
return default
|
||||
|
||||
monkeypatch.setattr("settings.get_settings", lambda: FakeSettings())
|
||||
monkeypatch.setattr(
|
||||
"util.audio_detector.shared_beat_detector_timing_sequences", lambda: False
|
||||
)
|
||||
assert sp._sequence_switch_wait_from_settings() == "beat"
|
||||
monkeypatch.setattr(
|
||||
"util.audio_detector.shared_beat_detector_timing_sequences", lambda: True
|
||||
)
|
||||
assert sp._sequence_switch_wait_from_settings() == "downbeat"
|
||||
|
||||
|
||||
def test_beat_switch_when_audio_running_but_sim_clocks(monkeypatch):
|
||||
"""Mic on without timing sequences: still beat-only (not downbeat)."""
|
||||
class FakeSettings:
|
||||
def get(self, key, default=None):
|
||||
if key == "sequence_switch_wait":
|
||||
return "downbeat"
|
||||
return default
|
||||
|
||||
monkeypatch.setattr("settings.get_settings", lambda: FakeSettings())
|
||||
monkeypatch.setattr(
|
||||
"util.audio_detector.shared_beat_detector_running", lambda: True
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"util.audio_detector.shared_beat_detector_timing_sequences", lambda: False
|
||||
)
|
||||
assert sp.effective_sequence_switch_wait() == "beat"
|
||||
|
||||
|
||||
def test_normalize_wait_for():
|
||||
assert sp._normalize_wait_for({"wait_for": "beat"}) == "beat"
|
||||
assert sp._normalize_wait_for({"start_on": "downbeat"}) == "downbeat"
|
||||
@@ -37,19 +87,49 @@ def test_queue_and_clear_pending():
|
||||
assert sp.pending_play_status()["pending"] is False
|
||||
|
||||
|
||||
def test_try_consume_pending_beat():
|
||||
def test_try_consume_pending_beat(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"util.audio_detector.shared_beat_detector_timing_sequences", lambda: False
|
||||
)
|
||||
sp.clear_pending_play()
|
||||
sp._queue_pending_start("z1", "s1", "p1", None, "beat", bpm=120.0)
|
||||
|
||||
async def fake_start(*_a, **_k):
|
||||
return None
|
||||
|
||||
sp._start_immediate = fake_start # type: ignore[method-assign]
|
||||
monkeypatch.setattr("util.sequence_playback._start_immediate", fake_start)
|
||||
assert asyncio.run(sp._try_consume_pending_play(is_downbeat=False)) is True
|
||||
assert sp.pending_play_status()["pending"] is False
|
||||
|
||||
|
||||
def test_try_consume_pending_downbeat_skips_upbeat():
|
||||
def test_try_consume_pending_beat_accepts_upbeat(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"util.audio_detector.shared_beat_detector_timing_sequences", lambda: False
|
||||
)
|
||||
sp.clear_pending_play()
|
||||
sp._queue_pending_start("z1", "s1", "p1", None, "beat", bpm=120.0)
|
||||
sp._mark_simulated_beat_phase()
|
||||
sp._last_thread_beat_phase = {"bar_beat": 3, "is_downbeat": False}
|
||||
|
||||
async def fake_start(*_a, **_k):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr("util.sequence_playback._start_immediate", fake_start)
|
||||
assert asyncio.run(sp._try_consume_pending_play(is_downbeat=False)) is True
|
||||
sp.clear_pending_play()
|
||||
|
||||
|
||||
def test_try_consume_pending_downbeat_skips_upbeat(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"util.audio_detector.shared_beat_detector_timing_sequences", lambda: True
|
||||
)
|
||||
class FakeSettings:
|
||||
def get(self, key, default=None):
|
||||
if key == "sequence_switch_wait":
|
||||
return "downbeat"
|
||||
return default
|
||||
|
||||
monkeypatch.setattr("settings.get_settings", lambda: FakeSettings())
|
||||
sp.clear_pending_play()
|
||||
sp._queue_pending_start("z1", "s1", "p1", None, "downbeat", bpm=120.0)
|
||||
assert asyncio.run(sp._try_consume_pending_play(is_downbeat=False)) is False
|
||||
@@ -63,12 +143,94 @@ def test_try_consume_pending_downbeat_skips_upbeat():
|
||||
sp.clear_pending_play()
|
||||
|
||||
|
||||
def test_sequence_pass_start_anchors_bar_phase_to_one():
|
||||
sp.stop()
|
||||
sp._sim_beat_counter = 7
|
||||
sp._last_thread_beat_phase = {"bar_beat": 3, "is_downbeat": False}
|
||||
ctx = {
|
||||
"lanes": [[{"preset_id": "1", "beats": 6}]],
|
||||
"lane_states": [{"stepIdx": 0, "beatCount": 0, "done": False}],
|
||||
"num_lanes": 1,
|
||||
"loop": False,
|
||||
"sequence_loop_beat": 0,
|
||||
"presets_map": {},
|
||||
}
|
||||
with sp._beat_run_lock:
|
||||
sp._beat_run = ctx
|
||||
assert sp._is_sequence_pass_start(ctx) is True
|
||||
sp._anchor_bar_phase_for_sequence_start()
|
||||
phase = sp.simulated_beat_phase_snapshot()
|
||||
assert phase["bar_beat"] == 1
|
||||
assert phase["is_downbeat"] is True
|
||||
assert phase["bar_phase_readout"] == "1/4"
|
||||
asyncio.run(sp.process_active_beat_advance())
|
||||
st = sp.playback_status()
|
||||
assert st["beat_readout"] == "1/6"
|
||||
assert sp.simulated_beat_phase_snapshot()["bar_beat"] == 1
|
||||
sp.stop()
|
||||
|
||||
|
||||
def test_sequence_pass_start_not_mid_pass():
|
||||
ctx = {
|
||||
"lanes": [[{"preset_id": "1", "beats": 2}, {"preset_id": "2", "beats": 2}]],
|
||||
"lane_states": [{"stepIdx": 1, "beatCount": 0, "done": False}],
|
||||
"num_lanes": 1,
|
||||
"loop": False,
|
||||
}
|
||||
assert sp._is_sequence_pass_start(ctx) is False
|
||||
|
||||
|
||||
def test_completed_beat_readout_survives_stop_playback():
|
||||
sp.stop()
|
||||
sp.clear_completed_beat_readout()
|
||||
ctx = {
|
||||
"lanes": [[{"preset_id": "1", "beats": 6}]],
|
||||
"lane_states": [{"stepIdx": 0, "beatCount": 6, "done": True}],
|
||||
"num_lanes": 1,
|
||||
"loop": False,
|
||||
"sequence_loop_beat": 6,
|
||||
"presets_map": {},
|
||||
}
|
||||
with sp._beat_run_lock:
|
||||
sp._beat_run = ctx
|
||||
sp.remember_completed_beat_readout(sp._beat_readout_for_ctx(ctx))
|
||||
asyncio.run(sp.stop_playback(clear_devices=False))
|
||||
assert sp.last_completed_beat_readout() == "6/6"
|
||||
assert sp.playback_status()["active"] is False
|
||||
sp.stop()
|
||||
|
||||
|
||||
def test_playback_beat_readout_six_beat_sequence():
|
||||
"""Beat readout is 1..tot with no duplicate 1 at start or missing final beat."""
|
||||
sp.stop()
|
||||
ctx = {
|
||||
"lanes": [[{"preset_id": "1", "beats": 6}]],
|
||||
"lane_states": [{"stepIdx": 0, "beatCount": 0, "done": False}],
|
||||
"num_lanes": 1,
|
||||
"loop": False,
|
||||
"sequence_loop_beat": 0,
|
||||
"presets_map": {},
|
||||
}
|
||||
with sp._beat_run_lock:
|
||||
sp._beat_run = ctx
|
||||
assert sp.playback_status()["beat_readout"] == ""
|
||||
for n in range(1, 5):
|
||||
ctx["lane_states"][0]["beatCount"] = n
|
||||
assert sp.playback_status()["beat_readout"] == f"{n}/6"
|
||||
ctx["lane_states"][0]["beatCount"] = 5
|
||||
assert sp.playback_status()["beat_readout"] == "6/6"
|
||||
ctx["lane_states"][0]["beatCount"] = 6
|
||||
ctx["lane_states"][0]["done"] = True
|
||||
assert sp.playback_status()["beat_readout"] == "6/6"
|
||||
sp.stop()
|
||||
|
||||
|
||||
def test_downbeat_start_counts_trigger_beat(monkeypatch):
|
||||
"""The downbeat that starts playback is beat 1 of the step, not beat 0."""
|
||||
sp.clear_pending_play()
|
||||
sp.stop()
|
||||
|
||||
async def fake_start(_z, _s, _p, _opts):
|
||||
async def fake_start(_z, _s, _p, _opts, **_kwargs):
|
||||
sp._beat_run = {
|
||||
"lanes": [[{"preset_id": "1", "beats": 4}]],
|
||||
"lane_states": [{"stepIdx": 0, "beatCount": 0, "done": False}],
|
||||
|
||||
156
tests/test_simulated_beat_continuity.py
Normal file
156
tests/test_simulated_beat_continuity.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""Background simulated beat clock vs live audio."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
SRC_PATH = os.path.join(PROJECT_ROOT, "src")
|
||||
if SRC_PATH not in sys.path:
|
||||
sys.path.insert(0, SRC_PATH)
|
||||
|
||||
from util import sequence_playback as sp # noqa: E402
|
||||
from util.audio_detector import AudioBeatDetector, set_shared_beat_detector # noqa: E402
|
||||
|
||||
|
||||
def _loop_ctx():
|
||||
return {
|
||||
"lanes": [[{"preset_id": "1", "beats": 99}]],
|
||||
"lane_states": [{"stepIdx": 0, "beatCount": 0, "done": False}],
|
||||
"num_lanes": 1,
|
||||
"loop": True,
|
||||
"sequence_loop_beat": 0,
|
||||
}
|
||||
|
||||
|
||||
async def _run_background_beats(*, bpm: float, seconds: float, audio_running: bool) -> int:
|
||||
sp.stop()
|
||||
sp.clear_pending_play()
|
||||
det = AudioBeatDetector()
|
||||
set_shared_beat_detector(det)
|
||||
try:
|
||||
with det._lock:
|
||||
det._running = bool(audio_running)
|
||||
det._status["running"] = bool(audio_running)
|
||||
if audio_running:
|
||||
det._status["bpm"] = float(bpm)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
ctx = _loop_ctx()
|
||||
with sp._beat_run_lock:
|
||||
sp._beat_run = ctx
|
||||
sp._beat_consumer_started = False
|
||||
sp._background_beat_task = None
|
||||
sp.ensure_beat_consumer_started()
|
||||
|
||||
monkeypatch_bpm = bpm
|
||||
|
||||
def fake_bpm():
|
||||
return monkeypatch_bpm
|
||||
|
||||
orig = sp._simulated_bpm_from_settings
|
||||
sp._simulated_bpm_from_settings = fake_bpm # type: ignore[method-assign]
|
||||
try:
|
||||
await asyncio.sleep(seconds)
|
||||
finally:
|
||||
sp._simulated_bpm_from_settings = orig # type: ignore[method-assign]
|
||||
|
||||
with sp._beat_run_lock:
|
||||
st = sp._beat_run["lane_states"][0] if sp._beat_run else {}
|
||||
beat_count = int(st.get("beatCount", 0))
|
||||
tick = sp.simulated_beat_tick()
|
||||
sp.stop()
|
||||
set_shared_beat_detector(None)
|
||||
return beat_count, tick
|
||||
|
||||
|
||||
def test_background_beats_continue_past_four_with_audio_off():
|
||||
beat_count, tick = asyncio.run(
|
||||
_run_background_beats(bpm=200.0, seconds=2.5, audio_running=False)
|
||||
)
|
||||
assert beat_count > 4, f"expected more than 4 beats, got {beat_count}"
|
||||
assert tick > 4, f"expected tick past 4, got {tick}"
|
||||
|
||||
|
||||
def test_background_advances_sequence_when_audio_on_without_beats():
|
||||
beat_count, tick = asyncio.run(
|
||||
_run_background_beats(bpm=200.0, seconds=2.5, audio_running=True)
|
||||
)
|
||||
assert beat_count > 4, f"sim should fill when audio is on but not clocking, got {beat_count}"
|
||||
assert tick > 4, f"background tick should still count, got {tick}"
|
||||
|
||||
|
||||
def test_holdover_fills_beats_between_sparse_real_detections():
|
||||
det = AudioBeatDetector()
|
||||
set_shared_beat_detector(det)
|
||||
try:
|
||||
with det._lock:
|
||||
det._running = True
|
||||
det._status["running"] = True
|
||||
|
||||
async def run():
|
||||
sp.stop()
|
||||
sp.clear_pending_play()
|
||||
ctx = _loop_ctx()
|
||||
with sp._beat_run_lock:
|
||||
sp._beat_run = ctx
|
||||
sp._beat_consumer_started = False
|
||||
sp._background_beat_task = None
|
||||
sp.ensure_beat_consumer_started()
|
||||
|
||||
det._record_beat(120.0)
|
||||
await asyncio.sleep(2.2)
|
||||
with sp._beat_run_lock:
|
||||
beat_count = int(ctx["lane_states"][0].get("beatCount", 0))
|
||||
sp.stop()
|
||||
return beat_count
|
||||
|
||||
beat_count = asyncio.run(run())
|
||||
assert beat_count > 2, f"holdover should advance between kicks, got {beat_count}"
|
||||
finally:
|
||||
set_shared_beat_detector(None)
|
||||
|
||||
|
||||
def test_live_audio_advances_sequence_when_running():
|
||||
det = AudioBeatDetector()
|
||||
set_shared_beat_detector(det)
|
||||
try:
|
||||
with det._lock:
|
||||
det._running = True
|
||||
det._status["running"] = True
|
||||
|
||||
async def run():
|
||||
sp.stop()
|
||||
sp.clear_pending_play()
|
||||
ctx = _loop_ctx()
|
||||
with sp._beat_run_lock:
|
||||
sp._beat_run = ctx
|
||||
sp._beat_consumer_started = False
|
||||
sp._background_beat_task = None
|
||||
sp.ensure_beat_consumer_started()
|
||||
|
||||
gap = sp._min_processed_beat_gap_s() + 0.01
|
||||
for _ in range(8):
|
||||
det._record_beat(400.0)
|
||||
await asyncio.sleep(gap)
|
||||
with sp._beat_run_lock:
|
||||
beat_count = int(ctx["lane_states"][0].get("beatCount", 0))
|
||||
sp.stop()
|
||||
return beat_count
|
||||
|
||||
beat_count = asyncio.run(run())
|
||||
assert beat_count > 4, f"audio should drive sequence, got {beat_count}"
|
||||
finally:
|
||||
set_shared_beat_detector(None)
|
||||
|
||||
|
||||
def test_beat_dedupe_drops_double_fire():
|
||||
sp.stop()
|
||||
sp.clear_pending_play()
|
||||
sp._accept_thread_beat_now()
|
||||
assert sp._accept_thread_beat_now() is False
|
||||
time.sleep(sp._min_processed_beat_gap_s() + 0.02)
|
||||
assert sp._accept_thread_beat_now() is True
|
||||
sp.stop()
|
||||
580
tests/test_simulated_sequence_switch.py
Normal file
580
tests/test_simulated_sequence_switch.py
Normal file
@@ -0,0 +1,580 @@
|
||||
"""Simulated BPM: sequence switching timing and beat regularity."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
SRC_PATH = os.path.join(PROJECT_ROOT, "src")
|
||||
if SRC_PATH not in sys.path:
|
||||
sys.path.insert(0, SRC_PATH)
|
||||
|
||||
from util import sequence_playback as sp # noqa: E402
|
||||
from util.audio_detector import AudioBeatDetector, set_shared_beat_detector # noqa: E402
|
||||
|
||||
|
||||
class _FakeSettings:
|
||||
def __init__(self, **values):
|
||||
self._values = values
|
||||
|
||||
def get(self, key, default=None):
|
||||
return self._values.get(key, default)
|
||||
|
||||
|
||||
def _install_simulated_bpm(monkeypatch, bpm: float, *, sequence_switch_wait: str = "beat"):
|
||||
monkeypatch.setattr(
|
||||
"settings.get_settings",
|
||||
lambda: _FakeSettings(
|
||||
audio_simulated_bpm=bpm,
|
||||
sequence_switch_wait=sequence_switch_wait,
|
||||
),
|
||||
)
|
||||
det = AudioBeatDetector()
|
||||
set_shared_beat_detector(det)
|
||||
monkeypatch.setattr(
|
||||
"util.audio_detector.shared_beat_detector_timing_sequences", lambda: False
|
||||
)
|
||||
|
||||
|
||||
def _beat_timestamps(seconds: float) -> List[float]:
|
||||
async def collect():
|
||||
sp.stop()
|
||||
sp.clear_pending_play()
|
||||
set_shared_beat_detector(None)
|
||||
sp._beat_consumer_started = False
|
||||
sp._background_beat_task = None
|
||||
sp.ensure_beat_consumer_started()
|
||||
|
||||
stamps: List[float] = []
|
||||
last = sp.simulated_beat_tick()
|
||||
deadline = time.monotonic() + seconds
|
||||
while time.monotonic() < deadline:
|
||||
tick = sp.simulated_beat_tick()
|
||||
if tick != last:
|
||||
stamps.append(time.monotonic())
|
||||
last = tick
|
||||
await asyncio.sleep(0.005)
|
||||
sp.stop()
|
||||
return stamps
|
||||
|
||||
return asyncio.run(collect())
|
||||
|
||||
|
||||
def _intervals(stamps: List[float]) -> List[float]:
|
||||
return [stamps[i + 1] - stamps[i] for i in range(len(stamps) - 1)]
|
||||
|
||||
|
||||
def test_effective_switch_wait_is_beat_when_audio_off_even_if_saved_downbeat(monkeypatch):
|
||||
_install_simulated_bpm(monkeypatch, 60.0, sequence_switch_wait="downbeat")
|
||||
assert sp.effective_sequence_switch_wait() == "beat"
|
||||
set_shared_beat_detector(None)
|
||||
|
||||
|
||||
def test_e2e_switch_on_next_beat_while_mic_running_sim_clocks(monkeypatch):
|
||||
"""End-to-end: audio running flag set, sim BPM ticks, switch on next beat not downbeat."""
|
||||
bpm = 120.0
|
||||
_install_simulated_bpm(monkeypatch, bpm, sequence_switch_wait="downbeat")
|
||||
det = AudioBeatDetector()
|
||||
set_shared_beat_detector(det)
|
||||
with det._lock:
|
||||
det._running = True
|
||||
det._status["running"] = True
|
||||
|
||||
sp.stop()
|
||||
sp.clear_pending_play()
|
||||
sp._beat_consumer_started = False
|
||||
sp._background_beat_task = None
|
||||
sp._sim_beat_counter = 0
|
||||
sp._last_thread_beat_phase = {"bar_beat": 1, "is_downbeat": True}
|
||||
|
||||
switch_events: List[tuple] = []
|
||||
|
||||
async def track_start(_z, seq_id, _p, _opts, **_kwargs):
|
||||
phase = sp._beat_phase_from_sources()
|
||||
switch_events.append((time.monotonic(), str(seq_id), int(phase.get("bar_beat") or 0)))
|
||||
|
||||
monkeypatch.setattr(sp, "_start_immediate", track_start)
|
||||
|
||||
async def run():
|
||||
sp.ensure_beat_consumer_started()
|
||||
with sp._beat_run_lock:
|
||||
sp._beat_run = {
|
||||
"lanes": [[{"preset_id": "1", "beats": 99}]],
|
||||
"lane_states": [{"stepIdx": 0, "beatCount": 0, "done": False}],
|
||||
"num_lanes": 1,
|
||||
"loop": True,
|
||||
"sequence_loop_beat": 0,
|
||||
"sequence_id": "1",
|
||||
}
|
||||
sp._mark_simulated_beat_phase()
|
||||
sp._mark_simulated_beat_phase()
|
||||
assert sp._beat_phase_from_sources()["bar_beat"] == 2
|
||||
|
||||
t_queue = time.monotonic()
|
||||
sp._queue_pending_start(
|
||||
"z1", "2", "1", None, sp.effective_sequence_switch_wait(), bpm=bpm
|
||||
)
|
||||
assert sp.pending_play_status()["wait_for"] == "beat"
|
||||
|
||||
beat_interval = 60.0 / bpm
|
||||
for _ in range(6):
|
||||
if not sp.pending_play_status()["pending"]:
|
||||
break
|
||||
sp._mark_simulated_beat_phase()
|
||||
sp.push_thread_beat()
|
||||
await asyncio.sleep(0.05)
|
||||
return t_queue, switch_events, beat_interval
|
||||
|
||||
t_queue, events, beat_interval = asyncio.run(run())
|
||||
assert len(events) == 1, f"expected one switch, got {events}"
|
||||
_t, seq_id, bar_beat = events[0]
|
||||
assert seq_id == "2"
|
||||
assert bar_beat == 3, f"expected switch on next beat (bar 3), got bar {bar_beat}"
|
||||
assert _t - t_queue < beat_interval * 1.1, (
|
||||
f"switch took too long ({_t - t_queue:.2f}s) for {bpm} BPM"
|
||||
)
|
||||
sp.stop()
|
||||
with det._lock:
|
||||
det._running = False
|
||||
det._status["running"] = False
|
||||
set_shared_beat_detector(None)
|
||||
|
||||
|
||||
def test_simulated_beat_intervals_steady_at_60_bpm(monkeypatch):
|
||||
bpm = 60.0
|
||||
_install_simulated_bpm(monkeypatch, bpm)
|
||||
expected = 60.0 / bpm
|
||||
stamps = _beat_timestamps(seconds=5.5)
|
||||
assert len(stamps) >= 4, f"expected several beats, got {len(stamps)}"
|
||||
for gap in _intervals(stamps):
|
||||
assert abs(gap - expected) < 0.12, f"beat gap {gap:.3f}s expected ~{expected:.3f}s"
|
||||
set_shared_beat_detector(None)
|
||||
|
||||
|
||||
def test_simulated_switch_consumes_on_upbeat_not_only_downbeat(monkeypatch):
|
||||
"""With downbeat saved but audio off, switch must happen on the next beat (e.g. bar 3), not bar 1."""
|
||||
_install_simulated_bpm(monkeypatch, 120.0, sequence_switch_wait="downbeat")
|
||||
assert sp.effective_sequence_switch_wait() == "beat"
|
||||
sp.stop()
|
||||
sp.clear_pending_play()
|
||||
sp._sim_beat_counter = 0
|
||||
sp._last_thread_beat_phase = {"bar_beat": 1, "is_downbeat": True}
|
||||
|
||||
consumed_bar_beats: List[int] = []
|
||||
|
||||
async def fake_start(_z, _s, _p, _opts, **_kwargs):
|
||||
phase = sp._beat_phase_from_sources()
|
||||
consumed_bar_beats.append(int(phase.get("bar_beat") or 0))
|
||||
with sp._beat_run_lock:
|
||||
sp._beat_run = {
|
||||
"lanes": [[{"preset_id": "1", "beats": 99}]],
|
||||
"lane_states": [{"stepIdx": 0, "beatCount": 0, "done": False}],
|
||||
"num_lanes": 1,
|
||||
"loop": True,
|
||||
"sequence_loop_beat": 0,
|
||||
"sequence_id": "2",
|
||||
}
|
||||
|
||||
monkeypatch.setattr("util.sequence_playback._start_immediate", fake_start)
|
||||
|
||||
async def run():
|
||||
with sp._beat_run_lock:
|
||||
sp._beat_run = {
|
||||
"lanes": [[{"preset_id": "1", "beats": 99}]],
|
||||
"lane_states": [{"stepIdx": 0, "beatCount": 0, "done": False}],
|
||||
"num_lanes": 1,
|
||||
"loop": True,
|
||||
"sequence_loop_beat": 0,
|
||||
"sequence_id": "1",
|
||||
}
|
||||
sp._mark_simulated_beat_phase()
|
||||
sp._mark_simulated_beat_phase()
|
||||
assert sp._beat_phase_from_sources()["bar_beat"] == 2
|
||||
|
||||
wait_for = sp.effective_sequence_switch_wait()
|
||||
assert wait_for == "beat"
|
||||
sp._queue_pending_start("z1", "2", "1", None, wait_for, bpm=120.0)
|
||||
assert sp.pending_play_status()["wait_for"] == "beat"
|
||||
|
||||
for _ in range(6):
|
||||
if not sp.pending_play_status()["pending"]:
|
||||
break
|
||||
sp._mark_simulated_beat_phase()
|
||||
phase = sp._beat_phase_from_sources()
|
||||
is_down = bool(phase.get("is_downbeat"))
|
||||
await sp._try_consume_pending_play(is_downbeat=is_down)
|
||||
return consumed_bar_beats
|
||||
|
||||
consumed = asyncio.run(run())
|
||||
assert consumed == [3], f"expected switch on bar beat 3 (next beat), got {consumed}"
|
||||
sp.stop()
|
||||
set_shared_beat_detector(None)
|
||||
|
||||
|
||||
def test_simulated_switch_waits_for_downbeat_only_when_pending_downbeat(monkeypatch):
|
||||
"""Control: downbeat pending with live audio must skip upbeats."""
|
||||
monkeypatch.setattr(
|
||||
"settings.get_settings",
|
||||
lambda: _FakeSettings(
|
||||
audio_simulated_bpm=120,
|
||||
sequence_switch_wait="downbeat",
|
||||
),
|
||||
)
|
||||
det = AudioBeatDetector()
|
||||
set_shared_beat_detector(det)
|
||||
with det._lock:
|
||||
det._running = True
|
||||
det._status["running"] = True
|
||||
sp.stop()
|
||||
sp.clear_pending_play()
|
||||
sp._sim_beat_counter = 0
|
||||
sp._last_thread_beat_phase = {"bar_beat": 1, "is_downbeat": True}
|
||||
|
||||
consumed_bar_beats: List[int] = []
|
||||
|
||||
async def fake_start(_z, _s, _p, _opts, **_kwargs):
|
||||
phase = sp._beat_phase_from_sources()
|
||||
consumed_bar_beats.append(int(phase.get("bar_beat") or 0))
|
||||
|
||||
monkeypatch.setattr("util.sequence_playback._start_immediate", fake_start)
|
||||
|
||||
async def run():
|
||||
sp._queue_pending_start("z1", "2", "1", None, "downbeat", bpm=120.0)
|
||||
for _ in range(6):
|
||||
if not sp.pending_play_status()["pending"]:
|
||||
break
|
||||
sp._mark_simulated_beat_phase()
|
||||
phase = sp._beat_phase_from_sources()
|
||||
await sp._try_consume_pending_play(
|
||||
is_downbeat=bool(phase.get("is_downbeat"))
|
||||
)
|
||||
return consumed_bar_beats
|
||||
|
||||
consumed = asyncio.run(run())
|
||||
assert consumed == [1], f"downbeat pending should wait for bar 1, got {consumed}"
|
||||
sp.stop()
|
||||
with det._lock:
|
||||
det._running = False
|
||||
det._status["running"] = False
|
||||
set_shared_beat_detector(None)
|
||||
|
||||
|
||||
def test_pending_switch_freezes_current_sequence(monkeypatch):
|
||||
"""While waiting for the next beat, the running sequence must not advance."""
|
||||
_install_simulated_bpm(monkeypatch, 120.0)
|
||||
sp.stop()
|
||||
sp.clear_pending_play()
|
||||
sp._beat_consumer_started = False
|
||||
sp._background_beat_task = None
|
||||
sp.ensure_beat_consumer_started()
|
||||
|
||||
ctx = {
|
||||
"lanes": [[{"preset_id": "1", "beats": 99}]],
|
||||
"lane_states": [{"stepIdx": 0, "beatCount": 0, "done": False}],
|
||||
"num_lanes": 1,
|
||||
"loop": True,
|
||||
"sequence_loop_beat": 0,
|
||||
"sequence_id": "1",
|
||||
}
|
||||
with sp._beat_run_lock:
|
||||
sp._beat_run = ctx
|
||||
|
||||
async def fake_start(_z, _s, _p, _opts, **_kwargs):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr("util.sequence_playback._start_immediate", fake_start)
|
||||
|
||||
async def run():
|
||||
sp._queue_pending_start(
|
||||
"z1", "2", "1", None, sp.effective_sequence_switch_wait(), bpm=120.0
|
||||
)
|
||||
assert ctx.get("_pending_switch") is True
|
||||
await asyncio.sleep(2.5)
|
||||
return int(ctx.get("lane_states", [{}])[0].get("beatCount", 0))
|
||||
|
||||
beat_count = asyncio.run(run())
|
||||
assert beat_count == 0, f"sequence should freeze while pending, got {beat_count}"
|
||||
sp.stop()
|
||||
set_shared_beat_detector(None)
|
||||
|
||||
|
||||
def test_pending_switch_drains_piled_beats_after_slow_start(monkeypatch):
|
||||
"""Beats queued during a slow handoff must not advance the new sequence twice."""
|
||||
_install_simulated_bpm(monkeypatch, 120.0)
|
||||
sp.stop()
|
||||
sp.clear_pending_play()
|
||||
|
||||
async def slow_start(_z, _s, _p, _opts, **_kwargs):
|
||||
sp.push_thread_beat()
|
||||
sp.push_thread_beat()
|
||||
await asyncio.sleep(0.05)
|
||||
with sp._beat_run_lock:
|
||||
sp._beat_run = {
|
||||
"lanes": [[{"preset_id": "1", "beats": 99}]],
|
||||
"lane_states": [{"stepIdx": 0, "beatCount": 0, "done": False}],
|
||||
"num_lanes": 1,
|
||||
"loop": True,
|
||||
"sequence_loop_beat": 0,
|
||||
"sequence_id": "2",
|
||||
}
|
||||
|
||||
monkeypatch.setattr("util.sequence_playback._start_immediate", slow_start)
|
||||
|
||||
async def run():
|
||||
with sp._beat_run_lock:
|
||||
sp._beat_run = {
|
||||
"lanes": [[{"preset_id": "1", "beats": 99}]],
|
||||
"lane_states": [{"stepIdx": 0, "beatCount": 0, "done": False}],
|
||||
"num_lanes": 1,
|
||||
"loop": True,
|
||||
"sequence_loop_beat": 0,
|
||||
"sequence_id": "1",
|
||||
}
|
||||
sp._queue_pending_start("z1", "2", "1", None, "beat", bpm=120.0)
|
||||
assert await sp._try_consume_pending_play(is_downbeat=False) is True
|
||||
piled = 0
|
||||
while True:
|
||||
try:
|
||||
sp._thread_beat_queue.get_nowait()
|
||||
piled += 1
|
||||
except Exception:
|
||||
break
|
||||
assert piled == 0, f"piled beats should be drained, found {piled}"
|
||||
await sp.process_active_beat_advance()
|
||||
with sp._beat_run_lock:
|
||||
ctx = sp._beat_run
|
||||
assert ctx is not None
|
||||
return int(ctx["lane_states"][0].get("beatCount", 0))
|
||||
|
||||
beat_count = asyncio.run(run())
|
||||
assert beat_count == 1, f"expected single advance after switch, got {beat_count}"
|
||||
sp.stop()
|
||||
set_shared_beat_detector(None)
|
||||
|
||||
|
||||
def test_handoff_rearm_blocks_immediate_double_advance(monkeypatch):
|
||||
"""After a switch, piled beats must not advance the new sequence twice in a row."""
|
||||
_install_simulated_bpm(monkeypatch, 120.0)
|
||||
sp.stop()
|
||||
sp.clear_pending_play()
|
||||
|
||||
async def slow_start(_z, _s, _p, _opts, **kwargs):
|
||||
sp.push_thread_beat()
|
||||
await asyncio.sleep(0.02)
|
||||
with sp._beat_run_lock:
|
||||
sp._beat_run = {
|
||||
"lanes": [[{"preset_id": "1", "beats": 99}]],
|
||||
"lane_states": [{"stepIdx": 0, "beatCount": 0, "done": False}],
|
||||
"num_lanes": 1,
|
||||
"loop": True,
|
||||
"sequence_loop_beat": 0,
|
||||
"sequence_id": "2",
|
||||
"_anchor_bar_on_pass_start": False,
|
||||
}
|
||||
|
||||
monkeypatch.setattr("util.sequence_playback._start_immediate", slow_start)
|
||||
monkeypatch.setattr("util.sequence_playback._restart_background_beat_clock", lambda: None)
|
||||
|
||||
async def run():
|
||||
with sp._beat_run_lock:
|
||||
sp._beat_run = {
|
||||
"lanes": [[{"preset_id": "1", "beats": 99}]],
|
||||
"lane_states": [{"stepIdx": 0, "beatCount": 0, "done": False}],
|
||||
"num_lanes": 1,
|
||||
"loop": True,
|
||||
"sequence_loop_beat": 0,
|
||||
"sequence_id": "1",
|
||||
}
|
||||
sp._queue_pending_start("z1", "2", "1", None, "beat", bpm=120.0)
|
||||
sp._accept_thread_beat_now()
|
||||
assert await sp._try_consume_pending_play(is_downbeat=False) is True
|
||||
assert sp._accept_thread_beat_now() is False
|
||||
await sp.process_active_beat_advance()
|
||||
sp.push_thread_beat()
|
||||
assert sp._accept_thread_beat_now() is False
|
||||
with sp._beat_run_lock:
|
||||
ctx = sp._beat_run
|
||||
return int(ctx["lane_states"][0].get("beatCount", 0))
|
||||
|
||||
beat_count = asyncio.run(run())
|
||||
assert beat_count == 1, f"handoff should advance once, got {beat_count}"
|
||||
sp.stop()
|
||||
set_shared_beat_detector(None)
|
||||
|
||||
|
||||
def test_mid_bar_handoff_keeps_bar_phase(monkeypatch):
|
||||
"""Switching on an upbeat must not snap the bar readout back to 1/4."""
|
||||
_install_simulated_bpm(monkeypatch, 120.0)
|
||||
sp.stop()
|
||||
sp.clear_pending_play()
|
||||
sp._sim_beat_counter = 3
|
||||
sp._last_thread_beat_phase = {"bar_beat": 3, "is_downbeat": False}
|
||||
|
||||
async def fake_start(_z, _s, _p, _opts, **kwargs):
|
||||
with sp._beat_run_lock:
|
||||
sp._beat_run = {
|
||||
"lanes": [[{"preset_id": "1", "beats": 99}]],
|
||||
"lane_states": [{"stepIdx": 0, "beatCount": 0, "done": False}],
|
||||
"num_lanes": 1,
|
||||
"loop": True,
|
||||
"sequence_loop_beat": 0,
|
||||
"sequence_id": "2",
|
||||
"_anchor_bar_on_pass_start": kwargs.get("handoff_is_downbeat", False),
|
||||
}
|
||||
|
||||
monkeypatch.setattr("util.sequence_playback._start_immediate", fake_start)
|
||||
monkeypatch.setattr("util.sequence_playback._restart_background_beat_clock", lambda: None)
|
||||
|
||||
async def run():
|
||||
sp._queue_pending_start("z1", "2", "1", None, "beat", bpm=120.0)
|
||||
assert await sp._try_consume_pending_play(is_downbeat=False) is True
|
||||
await sp.process_active_beat_advance()
|
||||
return sp._sim_beat_counter, sp._last_thread_beat_phase["bar_beat"]
|
||||
|
||||
counter, bar_beat = asyncio.run(run())
|
||||
assert counter == 3, f"mid-bar handoff should keep sim counter, got {counter}"
|
||||
assert bar_beat == 3, f"mid-bar handoff should keep bar beat, got {bar_beat}"
|
||||
sp.stop()
|
||||
set_shared_beat_detector(None)
|
||||
|
||||
|
||||
def test_idle_start_is_immediate_not_pending(monkeypatch):
|
||||
"""First sequence with nothing playing should not wait for the next beat."""
|
||||
_install_simulated_bpm(monkeypatch, 60.0)
|
||||
sp.stop()
|
||||
sp.clear_pending_play()
|
||||
|
||||
class FakeSeq:
|
||||
def read(self, _sid):
|
||||
return {"profile_id": "1", "lanes": [[{"preset_id": "1", "beats": 1}]]}
|
||||
|
||||
monkeypatch.setitem(sys.modules, "models.sequence", type(sys)("models.sequence"))
|
||||
sys.modules["models.sequence"].Sequence = FakeSeq # type: ignore[attr-defined]
|
||||
|
||||
started = []
|
||||
|
||||
async def fake_start(z, s, p, opts):
|
||||
started.append((z, s, p))
|
||||
|
||||
monkeypatch.setattr("util.sequence_playback._start_immediate", fake_start)
|
||||
|
||||
async def run():
|
||||
t0 = time.monotonic()
|
||||
await sp.start("z1", "1", "1", None)
|
||||
return time.monotonic() - t0
|
||||
|
||||
elapsed = asyncio.run(run())
|
||||
assert sp.pending_play_status()["pending"] is False
|
||||
assert started == [("z1", "1", "1")]
|
||||
assert elapsed < 0.05, f"idle start should be immediate, took {elapsed:.3f}s"
|
||||
sp.stop()
|
||||
set_shared_beat_detector(None)
|
||||
|
||||
|
||||
def test_active_switch_still_queues_pending(monkeypatch):
|
||||
_install_simulated_bpm(monkeypatch, 60.0, sequence_switch_wait="downbeat")
|
||||
sp.stop()
|
||||
with sp._beat_run_lock:
|
||||
sp._beat_run = {
|
||||
"lanes": [[{"preset_id": "1", "beats": 1}]],
|
||||
"lane_states": [{"stepIdx": 0, "beatCount": 0, "done": False}],
|
||||
"num_lanes": 1,
|
||||
"loop": False,
|
||||
"sequence_id": "1",
|
||||
}
|
||||
|
||||
class FakeSeq:
|
||||
def read(self, _sid):
|
||||
return {"profile_id": "1", "lanes": [[{"preset_id": "1", "beats": 1}]]}
|
||||
|
||||
monkeypatch.setitem(sys.modules, "models.sequence", type(sys)("models.sequence"))
|
||||
sys.modules["models.sequence"].Sequence = FakeSeq # type: ignore[attr-defined]
|
||||
monkeypatch.setattr("util.sequence_playback._start_immediate", lambda *a, **k: None)
|
||||
|
||||
async def run():
|
||||
await sp.start("z1", "2", "1", None)
|
||||
|
||||
asyncio.run(run())
|
||||
st = sp.pending_play_status()
|
||||
assert st["pending"] is True
|
||||
assert st["sequence_id"] == "2"
|
||||
assert st["wait_for"] == "beat", f"simulated switch must queue beat, got {st['wait_for']!r}"
|
||||
sp.stop()
|
||||
set_shared_beat_detector(None)
|
||||
|
||||
|
||||
def test_pending_switch_uses_beat_after_audio_stops(monkeypatch):
|
||||
"""Queued while live audio was timing (downbeat) must switch on beat once sim clocks."""
|
||||
monkeypatch.setattr(
|
||||
"settings.get_settings",
|
||||
lambda: _FakeSettings(
|
||||
audio_simulated_bpm=120,
|
||||
sequence_switch_wait="downbeat",
|
||||
),
|
||||
)
|
||||
det = AudioBeatDetector()
|
||||
set_shared_beat_detector(det)
|
||||
sp.stop()
|
||||
sp.clear_pending_play()
|
||||
sp._sim_beat_counter = 0
|
||||
sp._last_thread_beat_phase = {"bar_beat": 1, "is_downbeat": True}
|
||||
|
||||
consumed_bar_beats: List[int] = []
|
||||
|
||||
async def fake_start(_z, _s, _p, _opts, **_kwargs):
|
||||
phase = sp._beat_phase_from_sources()
|
||||
consumed_bar_beats.append(int(phase.get("bar_beat") or 0))
|
||||
|
||||
monkeypatch.setattr(sp, "_start_immediate", fake_start)
|
||||
|
||||
async def run():
|
||||
with det._lock:
|
||||
det._running = True
|
||||
det._status["running"] = True
|
||||
det._holdover_active = True
|
||||
monkeypatch.setattr(
|
||||
"util.audio_detector.shared_beat_detector_timing_sequences", lambda: True
|
||||
)
|
||||
sp._queue_pending_start("z1", "2", "1", None, "downbeat", bpm=120.0)
|
||||
assert sp.pending_play_status()["wait_for"] == "downbeat"
|
||||
|
||||
with det._lock:
|
||||
det._running = False
|
||||
det._status["running"] = False
|
||||
det._holdover_active = False
|
||||
monkeypatch.setattr(
|
||||
"util.audio_detector.shared_beat_detector_timing_sequences", lambda: False
|
||||
)
|
||||
|
||||
sp._mark_simulated_beat_phase()
|
||||
sp._mark_simulated_beat_phase()
|
||||
for _ in range(4):
|
||||
if not sp.pending_play_status()["pending"]:
|
||||
break
|
||||
sp._mark_simulated_beat_phase()
|
||||
phase = sp._beat_phase_from_sources()
|
||||
await sp._try_consume_pending_play(
|
||||
is_downbeat=bool(phase.get("is_downbeat"))
|
||||
)
|
||||
return consumed_bar_beats
|
||||
|
||||
consumed = asyncio.run(run())
|
||||
assert consumed == [3], (
|
||||
f"after audio stop, pending should consume on next beat (bar 3), got {consumed}"
|
||||
)
|
||||
sp.stop()
|
||||
set_shared_beat_detector(None)
|
||||
|
||||
|
||||
def test_audio_status_reports_beat_switch_when_simulated(server):
|
||||
"""API: audio off + saved downbeat still exposes beat-only switch wait."""
|
||||
c = server["client"]
|
||||
c.put("/settings", json={"sequence_switch_wait": "downbeat"})
|
||||
c.post("/api/audio/stop")
|
||||
|
||||
status = c.get("/api/audio/status").json()["status"]
|
||||
assert status.get("bpm_simulated") is True
|
||||
assert status.get("sequence_switch_wait") == "beat"
|
||||
assert status.get("sequence_switch_wait_saved") == "downbeat"
|
||||
355
tests/web.py
355
tests/web.py
@@ -1,342 +1,35 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Local development web server - imports and runs src.main with port 5000
|
||||
"""
|
||||
"""Local development server: FastAPI app on port 5000."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import os
|
||||
import asyncio
|
||||
import signal
|
||||
import sys
|
||||
|
||||
# Add project root, src, and lib to path
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
src_path = os.path.join(project_root, 'src')
|
||||
lib_path = os.path.join(project_root, 'lib')
|
||||
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
SRC_PATH = os.path.join(PROJECT_ROOT, "src")
|
||||
|
||||
# Add to path in the right order - src must be first so 'models' and 'controllers' can be imported
|
||||
# This ensures imports like 'from models.preset import Preset' work
|
||||
sys.path.insert(0, src_path)
|
||||
sys.path.insert(0, lib_path)
|
||||
sys.path.insert(0, project_root)
|
||||
sys.path.insert(0, SRC_PATH)
|
||||
os.chdir(SRC_PATH)
|
||||
|
||||
# Mock MicroPython modules before importing main
|
||||
class MockMachine:
|
||||
class WDT:
|
||||
def __init__(self, timeout):
|
||||
pass
|
||||
def feed(self):
|
||||
pass
|
||||
|
||||
class MockESPNow:
|
||||
def __init__(self):
|
||||
self.active_value = False
|
||||
self.peers = []
|
||||
self.websocket_client = None # Store single WebSocket connection
|
||||
def active(self, value):
|
||||
self.active_value = value
|
||||
print(f"[MOCK] ESPNow active: {value}")
|
||||
def add_peer(self, peer):
|
||||
self.peers.append(peer)
|
||||
if hasattr(peer, 'hex'):
|
||||
print(f"[MOCK] Added peer: {peer.hex()}")
|
||||
else:
|
||||
print(f"[MOCK] Added peer: {peer}")
|
||||
def register_websocket(self, ws):
|
||||
"""Register a WebSocket connection to forward ESPNow data to."""
|
||||
self.websocket_client = ws
|
||||
print(f"[MOCK] Registered WebSocket client")
|
||||
def unregister_websocket(self, ws):
|
||||
"""Unregister a WebSocket connection."""
|
||||
if self.websocket_client == ws:
|
||||
self.websocket_client = None
|
||||
print(f"[MOCK] Unregistered WebSocket client")
|
||||
async def asend(self, peer, data):
|
||||
if hasattr(peer, 'hex'):
|
||||
print(f"[MOCK] Would send to {peer.hex()}: {data}")
|
||||
else:
|
||||
print(f"[MOCK] Would send to {peer}: {data}")
|
||||
|
||||
# Forward data to the connected WebSocket client
|
||||
if self.websocket_client:
|
||||
try:
|
||||
await self.websocket_client.send(data)
|
||||
print(f"[MOCK] Forwarded to WebSocket client")
|
||||
except Exception as e:
|
||||
print(f"[MOCK] WebSocket client disconnected: {e}")
|
||||
self.websocket_client = None
|
||||
|
||||
class MockAIOESPNow:
|
||||
def __init__(self):
|
||||
self.espnow = MockESPNow()
|
||||
def active(self, value):
|
||||
self.espnow.active(value)
|
||||
return self.espnow
|
||||
def add_peer(self, peer):
|
||||
self.espnow.add_peer(peer)
|
||||
async def asend(self, peer, data):
|
||||
await self.espnow.asend(peer, data)
|
||||
|
||||
# Store reference to mock instance for WebSocket registration
|
||||
@property
|
||||
def mock_instance(self):
|
||||
return self.espnow
|
||||
|
||||
# Create mock ESPNow instance and store reference for WebSocket registration
|
||||
mock_espnow_instance = MockESPNow()
|
||||
mock_aioespnow = MockAIOESPNow()
|
||||
mock_aioespnow.espnow = mock_espnow_instance # Use the shared instance
|
||||
|
||||
# Create mock ESPNow instance and store reference for WebSocket registration
|
||||
mock_espnow_instance = MockESPNow()
|
||||
mock_aioespnow = MockAIOESPNow()
|
||||
mock_aioespnow.espnow = mock_espnow_instance # Use the shared instance
|
||||
|
||||
# Install mocks in sys.modules before any imports
|
||||
sys.modules['machine'] = MockMachine()
|
||||
# Store the mock instance in the module so it can be accessed
|
||||
aioespnow_module = type('module', (), {'AIOESPNow': MockAIOESPNow, '_mock_instance': mock_espnow_instance})()
|
||||
sys.modules['aioespnow'] = aioespnow_module
|
||||
class MockWLAN:
|
||||
def __init__(self, interface):
|
||||
self.interface = interface
|
||||
def active(self, value):
|
||||
print(f"[MOCK] WLAN({self.interface}) active: {value}")
|
||||
|
||||
sys.modules['network'] = type('module', (), {
|
||||
'WLAN': MockWLAN,
|
||||
'STA_IF': 0
|
||||
})()
|
||||
|
||||
# Mock asyncio.sleep_ms for regular Python
|
||||
_original_sleep = asyncio.sleep
|
||||
async def sleep_ms(ms):
|
||||
await _original_sleep(ms / 1000.0)
|
||||
|
||||
# Patch asyncio.sleep_ms
|
||||
asyncio.sleep_ms = sleep_ms
|
||||
|
||||
# Patch sys.print_exception for regular Python (MicroPython has this, regular Python doesn't)
|
||||
if not hasattr(sys, 'print_exception'):
|
||||
import traceback
|
||||
sys.print_exception = lambda e, file=None: traceback.print_exception(type(e), e, e.__traceback__, file=file)
|
||||
|
||||
# Patch builtins.open to redirect /db/ paths to project db directory
|
||||
import builtins
|
||||
_original_open = builtins.open
|
||||
def patched_open(file, mode='r', *args, **kwargs):
|
||||
if isinstance(file, str):
|
||||
if file.startswith('/db/'):
|
||||
# Redirect to project db directory
|
||||
filename = os.path.basename(file)
|
||||
file = os.path.join(project_root, 'db', filename)
|
||||
elif not os.path.isabs(file):
|
||||
# For relative paths starting with templates/ or static/,
|
||||
# always resolve to src/ directory
|
||||
if file.startswith('templates/') or file.startswith('static/'):
|
||||
file = os.path.join(src_path, file)
|
||||
# For other relative paths, check if they exist in current dir
|
||||
# If not, try src/ directory
|
||||
elif not os.path.exists(file):
|
||||
src_file = os.path.join(src_path, file)
|
||||
if os.path.exists(src_file):
|
||||
file = src_file
|
||||
return _original_open(file, mode, *args, **kwargs)
|
||||
builtins.open = patched_open
|
||||
|
||||
# Also patch os.mkdir to handle /db path
|
||||
original_mkdir = os.mkdir
|
||||
def patched_mkdir(path):
|
||||
if path == "/db":
|
||||
# Use project db directory instead
|
||||
db_path = os.path.join(project_root, "db")
|
||||
if not os.path.exists(db_path):
|
||||
os.makedirs(db_path, exist_ok=True)
|
||||
else:
|
||||
original_mkdir(path)
|
||||
os.mkdir = patched_mkdir
|
||||
|
||||
# Create a flag to stop the infinite loop
|
||||
_stop_flag = False
|
||||
|
||||
# Patch gc.collect to check stop flag
|
||||
import gc as gc_module
|
||||
_original_collect = gc_module.collect
|
||||
def collect():
|
||||
global _stop_flag
|
||||
if _stop_flag:
|
||||
raise KeyboardInterrupt("Stop requested")
|
||||
return _original_collect()
|
||||
gc_module.collect = collect
|
||||
|
||||
# Change to src directory for file paths (where templates and static are)
|
||||
# main.py expects templates/ and static/ to be relative to the working directory
|
||||
os.chdir(src_path)
|
||||
|
||||
# Override settings path for local development
|
||||
# Import settings module and patch the path before main imports it
|
||||
import importlib.util
|
||||
spec = importlib.util.spec_from_file_location("settings", os.path.join(src_path, "settings.py"))
|
||||
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"settings", os.path.join(SRC_PATH, "settings.py")
|
||||
)
|
||||
settings_module = importlib.util.module_from_spec(spec)
|
||||
sys.modules['settings'] = settings_module
|
||||
sys.modules["settings"] = settings_module
|
||||
spec.loader.exec_module(settings_module)
|
||||
settings_module.Settings.SETTINGS_FILE = os.path.join(project_root, 'settings.json')
|
||||
|
||||
# Patch the Model class file path before importing
|
||||
# We need to monkey-patch the model.py file's behavior
|
||||
import importlib.util
|
||||
model_spec = importlib.util.spec_from_file_location("models.model", os.path.join(src_path, "models", "model.py"))
|
||||
model_module = importlib.util.module_from_spec(model_spec)
|
||||
|
||||
# Patch os.mkdir in the model module's context
|
||||
original_mkdir = os.mkdir
|
||||
def patched_mkdir(path):
|
||||
if path == "/db":
|
||||
db_path = os.path.join(project_root, "db")
|
||||
if not os.path.exists(db_path):
|
||||
os.makedirs(db_path, exist_ok=True)
|
||||
else:
|
||||
original_mkdir(path)
|
||||
|
||||
# Set up the module's namespace with patched os
|
||||
model_module.__dict__['os'] = type('os', (), {'mkdir': patched_mkdir, 'path': os.path})()
|
||||
model_spec.loader.exec_module(model_module)
|
||||
sys.modules['models.model'] = model_module
|
||||
|
||||
# Now patch the Model class to fix file paths
|
||||
# The issue is that Model.__init__ sets self.file and immediately calls load()
|
||||
# before we can patch it. We need to replace __init__ completely.
|
||||
# Also clear any existing singleton instances
|
||||
Model = model_module.Model
|
||||
# Clear singleton instances for all Model subclasses
|
||||
for attr_name in dir(model_module):
|
||||
attr = getattr(model_module, attr_name)
|
||||
if isinstance(attr, type) and issubclass(attr, Model) and attr != Model:
|
||||
if hasattr(attr, '_instance'):
|
||||
delattr(attr, '_instance')
|
||||
|
||||
original_save = Model.save
|
||||
original_load = Model.load
|
||||
original_set_defaults = Model.set_defaults
|
||||
|
||||
def patched_init(self):
|
||||
# Only initialize once (check if already initialized)
|
||||
if hasattr(self, '_initialized'):
|
||||
return
|
||||
|
||||
# Create db directory if it doesn't exist (use project db, not /db)
|
||||
db_path = os.path.join(project_root, "db")
|
||||
if not os.path.exists(db_path):
|
||||
os.makedirs(db_path, exist_ok=True)
|
||||
|
||||
self.class_name = self.__class__.__name__
|
||||
# Set file path to project db directory from the start
|
||||
self.file = os.path.join(project_root, 'db', f"{self.class_name.lower()}.json")
|
||||
super(Model, self).__init__()
|
||||
|
||||
# Now call load with the correct path already set
|
||||
# Call the patched load method (defined below)
|
||||
Model.load(self)
|
||||
self._initialized = True
|
||||
|
||||
def patched_save(self):
|
||||
# Ensure file path is correct before saving (this will also fix print statements)
|
||||
if hasattr(self, 'file') and self.file.startswith('/db/'):
|
||||
filename = os.path.basename(self.file)
|
||||
self.file = os.path.join(project_root, 'db', filename)
|
||||
# Also ensure the directory exists
|
||||
db_dir = os.path.dirname(self.file)
|
||||
if not os.path.exists(db_dir):
|
||||
os.makedirs(db_dir, exist_ok=True)
|
||||
return original_save(self)
|
||||
|
||||
def patched_load(self):
|
||||
# Ensure file path is correct before loading
|
||||
if hasattr(self, 'file') and self.file.startswith('/db/'):
|
||||
filename = os.path.basename(self.file)
|
||||
self.file = os.path.join(project_root, 'db', filename)
|
||||
try:
|
||||
with open(self.file, 'r') as file:
|
||||
import json
|
||||
loaded_settings = json.load(file)
|
||||
# Use dict.update() directly, not the subclass's update() method
|
||||
dict.update(self, loaded_settings)
|
||||
print(f"{self.class_name} loaded successfully.")
|
||||
except FileNotFoundError:
|
||||
# File doesn't exist yet - this is normal on first run
|
||||
print(f"No existing {self.class_name} file found, creating defaults.")
|
||||
self.set_defaults()
|
||||
self.save()
|
||||
except Exception as e:
|
||||
# Other errors - log and create defaults
|
||||
print(f"Error loading {self.class_name}: {type(e).__name__}: {e}")
|
||||
self.set_defaults()
|
||||
self.save()
|
||||
|
||||
# Apply patches - load must be patched before init uses it
|
||||
Model.load = patched_load
|
||||
Model.__init__ = patched_init
|
||||
Model.save = patched_save
|
||||
|
||||
# Patch with_websocket decorator before importing main to register WebSocket connections
|
||||
from microdot.websocket import with_websocket as original_with_websocket
|
||||
|
||||
def patched_with_websocket(f):
|
||||
"""Patched with_websocket decorator that registers connections with mock ESPNow."""
|
||||
@original_with_websocket
|
||||
async def wrapped_handler(request, ws):
|
||||
# Register WebSocket connection with mock ESPNow
|
||||
mock_espnow_instance.register_websocket(ws)
|
||||
try:
|
||||
# Call original handler
|
||||
await f(request, ws)
|
||||
finally:
|
||||
# Unregister when connection closes
|
||||
mock_espnow_instance.unregister_websocket(ws)
|
||||
return wrapped_handler
|
||||
|
||||
# Now import main (which will use the patched settings module and model)
|
||||
# Import as a module file directly to avoid package import issues
|
||||
main_spec = importlib.util.spec_from_file_location("main", os.path.join(src_path, "main.py"))
|
||||
main_module = importlib.util.module_from_spec(main_spec)
|
||||
|
||||
# Patch with_websocket in the main module before executing it
|
||||
main_module.__dict__['with_websocket'] = patched_with_websocket
|
||||
|
||||
main_spec.loader.exec_module(main_module)
|
||||
main = main_module.main
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
"""Handle Ctrl+C gracefully."""
|
||||
global _stop_flag
|
||||
print("\nShutting down server...")
|
||||
_stop_flag = True
|
||||
# Force exit since main has an infinite loop
|
||||
sys.exit(0)
|
||||
|
||||
async def run_web():
|
||||
"""Run main with port 5000."""
|
||||
print("Starting LED Controller Web Server (Local Development)")
|
||||
print("=" * 60)
|
||||
print(f"Server will run on http://localhost:5000")
|
||||
print("Press Ctrl+C to stop")
|
||||
print("=" * 60)
|
||||
|
||||
# Set up signal handler
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
try:
|
||||
# Call main with port 5000
|
||||
await main(port=5000)
|
||||
except KeyboardInterrupt:
|
||||
print("\nShutting down server...")
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
raise
|
||||
settings_module.Settings.SETTINGS_FILE = os.path.join(PROJECT_ROOT, "settings.json")
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(run_web())
|
||||
except KeyboardInterrupt:
|
||||
print("\nExiting...")
|
||||
except SystemExit:
|
||||
pass
|
||||
import uvicorn
|
||||
|
||||
from fastapi_app import app
|
||||
|
||||
print("Starting LED Controller Web Server (Local Development)")
|
||||
print("=" * 60)
|
||||
print("Server will run on http://localhost:5000")
|
||||
print("Press Ctrl+C to stop")
|
||||
print("=" * 60)
|
||||
uvicorn.run(app, host="0.0.0.0", port=5000)
|
||||
|
||||
Reference in New Issue
Block a user