"""Outbound WebSocket clients to Wi-Fi LED drivers (firmware serves ``/ws`` on device).""" from __future__ import annotations import asyncio import errno import json import traceback import websockets from websockets.exceptions import ConnectionClosed _connections: dict[str, object] = {} _send_locks: dict[str, asyncio.Lock] = {} _tasks: dict[str, asyncio.Task] = {} _unreachable_counts: dict[str, int] = {} _settings = None _tcp_status_broadcast = None def set_settings(settings) -> None: global _settings _settings = settings def set_tcp_status_broadcaster(coro) -> None: global _tcp_status_broadcast _tcp_status_broadcast = coro def _schedule_status_broadcast(ip: str, connected: bool) -> None: fn = _tcp_status_broadcast if not fn: return try: loop = asyncio.get_running_loop() except RuntimeError: return try: loop.create_task(fn(ip, connected)) except Exception: pass def _benign_ws_connect_failure(exc: BaseException) -> bool: """True for common \"driver down / no route\" errors while dialling the WebSocket.""" if isinstance(exc, (asyncio.TimeoutError, TimeoutError)): return True if isinstance(exc, ConnectionRefusedError): return True if not isinstance(exc, OSError): return False en = exc.errno if en is None: return False codes = {errno.ECONNREFUSED, errno.ETIMEDOUT} for name in ("EHOSTUNREACH", "ENETUNREACH", "ENETDOWN", "EADDRNOTAVAIL"): if hasattr(errno, name): codes.add(getattr(errno, name)) return en in codes def normalize_tcp_peer_ip(ip: str) -> str: """Match peer addresses to registry IPs (strip IPv4-mapped IPv6 prefix).""" s = str(ip).strip() if s.lower().startswith("::ffff:"): s = s[7:] return s def _ws_open(ws) -> bool: try: return ws.close_code is None except Exception: return False def prune_stale_tcp_writers() -> None: """Drop closed WebSocket entries (name kept for callers).""" stale = [ip for ip, ws in list(_connections.items()) if not _ws_open(ws)] for ip in stale: _connections.pop(ip, None) _schedule_status_broadcast(ip, False) def _global_brightness_message_text() -> str | None: """v1 JSON line for saved zone UI brightness; works with shipping driver firmware (applies ``b`` in RAM).""" global _settings if _settings is None: return None try: b = int(_settings.get("global_brightness", 255)) except (TypeError, ValueError): b = 255 b = max(0, min(255, b)) return json.dumps({"v": "1", "b": b}) async def sync_global_brightness_to_driver(ip: str) -> bool: """Push Pi-stored global brightness to one Wi-Fi driver over the outbound WebSocket.""" text = _global_brightness_message_text() if not text: return False return await send_json_line_to_ip(ip, text) async def broadcast_global_brightness_to_tcp_drivers() -> None: """Push saved global brightness to every connected Wi-Fi driver.""" text = _global_brightness_message_text() if not text: return for ip in list_connected_ips(): await send_json_line_to_ip(ip, text) def _register_ws(ip: str, ws) -> None: key = normalize_tcp_peer_ip(ip) if not key: return _connections[key] = ws _unreachable_counts.pop(key, None) if key not in _send_locks: _send_locks[key] = asyncio.Lock() _schedule_status_broadcast(key, True) print(f"[WS] driver connected {key!r}") try: loop = asyncio.get_running_loop() except RuntimeError: return async def _apply_saved_brightness(): await sync_global_brightness_to_driver(key) loop.create_task(_apply_saved_brightness()) def unregister_tcp_writer(peer_ip: str, ws=None) -> str: """ Remove the WebSocket for peer_ip. If ``ws`` is given, only pop when it is still the registered instance. Returns ``removed``, ``noop``, or ``superseded`` (same contract as former TCP registry). """ if not peer_ip: return "noop" key = normalize_tcp_peer_ip(peer_ip) if not key: return "noop" current = _connections.get(key) if ws is not None: if current is None: return "noop" if current is not ws: return "superseded" had = key in _connections if had: _connections.pop(key, None) _schedule_status_broadcast(key, False) print(f"[WS] driver disconnected: {key}") return "removed" return "noop" def list_connected_ips(): """IPs with an active outbound WebSocket to the driver.""" prune_stale_tcp_writers() return list(_connections.keys()) def tcp_client_connected(ip: str) -> bool: """True if the controller has an outbound WebSocket to this driver IP.""" prune_stale_tcp_writers() key = normalize_tcp_peer_ip(ip) return bool(key and key in _connections) async def send_json_line_to_ip(ip: str, json_str: str) -> bool: """Send one JSON text frame (v1 line; trailing newline stripped for WebSocket).""" ip = normalize_tcp_peer_ip(ip) ws = _connections.get(ip) if ws is None or not _ws_open(ws): return False text = json_str.rstrip("\n") lock = _send_locks.setdefault(ip, asyncio.Lock()) try: async with lock: await ws.send(text) return True except Exception as exc: print(f"[WS] send to {ip} failed: {exc}") unregister_tcp_writer(ip, ws) return False async def _recv_forward_loop(ip: str, ws) -> None: from models.transport import get_current_sender sender = get_current_sender() async for message in ws: if isinstance(message, bytes): try: text = message.decode("utf-8") except UnicodeDecodeError: print(f"[WS] recv {ip} (non-UTF-8, {len(message)} bytes)") continue else: text = message text = text.strip() if not text: continue print(f"[WS] recv {ip}: {text}") if not sender: continue try: parsed = json.loads(text) except json.JSONDecodeError: try: await sender.send(text) except Exception: pass continue if isinstance(parsed, dict): addr = parsed.pop("to", None) payload = json.dumps(parsed) if parsed else "{}" try: await sender.send(payload, addr=addr) except Exception as e: print(f"[WS] forward to bridge failed: {e}") else: try: await sender.send(text) except Exception: pass def _stagger_delay_s_for_ip(ip: str) -> float: """0 .. wifi_driver_connect_stagger_max_s based on last IPv4 octet (deterministic spread).""" global _settings if _settings is None: return 0.0 try: max_s = float(_settings.get("wifi_driver_connect_stagger_max_s", 2.5)) except (TypeError, ValueError): max_s = 2.5 if max_s <= 0: return 0.0 parts = str(ip).strip().split(".") if len(parts) != 4: return 0.0 try: last = int(parts[3]) % 256 except ValueError: return 0.0 return (last / 255.0) * max_s async def _driver_connection_loop(ip: str) -> None: global _settings if _settings is None: return port = int(_settings.get("wifi_driver_ws_port", 80)) path = str(_settings.get("wifi_driver_ws_path", "/ws")) if not path.startswith("/"): path = "/" + path uri = f"ws://{ip}:{port}{path}" try: retry_interval_s = float(_settings.get("wifi_driver_connect_retry_interval_s", 2.0)) except (TypeError, ValueError): retry_interval_s = 2.0 retry_interval_s = max(0.2, retry_interval_s) try: retry_window_s = float(_settings.get("wifi_driver_connect_retry_window_s", 120.0)) except (TypeError, ValueError): retry_window_s = 120.0 retry_window_s = max(5.0, retry_window_s) try: open_timeout = float(_settings.get("wifi_driver_ws_open_timeout", 45.0)) except (TypeError, ValueError): open_timeout = 45.0 open_timeout = max(5.0, open_timeout) loop = asyncio.get_running_loop() stagger = _stagger_delay_s_for_ip(ip) if stagger > 0: await asyncio.sleep(stagger) # Only bound boot-time: after we have connected once, keep retrying (Wi-Fi drops, reboots). connected_once = False deadline = loop.time() + retry_window_s try: while True: now = loop.time() if not connected_once and now >= deadline: print( f"[WS] driver {ip} still unreachable after {int(retry_window_s)}s " f"(initial window); stopping until next UDP hello / registry prime" ) break try: print(f"[WS] connecting to {uri!r}") async with websockets.connect( uri, ping_interval=20, ping_timeout=15, open_timeout=open_timeout, ) as ws: connected_once = True _register_ws(ip, ws) try: await _recv_forward_loop(ip, ws) finally: unregister_tcp_writer(ip, ws) except asyncio.CancelledError: raise except ConnectionClosed as e: print(f"[WS] driver {ip} closed: {e}") unregister_tcp_writer(ip, None) except Exception as e: if _benign_ws_connect_failure(e): n = _unreachable_counts.get(ip, 0) + 1 _unreachable_counts[ip] = n if n == 1 or (n % 30) == 0: print( f"[WS] driver {ip} unreachable, retry in {retry_interval_s}s: {e} (x{n})" ) else: print(f"[WS] driver {ip} session error: {e!r}") traceback.print_exception(type(e), e, e.__traceback__) _unreachable_counts.pop(ip, None) unregister_tcp_writer(ip, None) await asyncio.sleep(retry_interval_s) except asyncio.CancelledError: unregister_tcp_writer(ip, None) raise finally: _tasks.pop(ip, None) def ensure_driver_connection(peer_ip: str) -> None: """Start (or keep) a background task that maintains ``ws://:port/ws``.""" key = normalize_tcp_peer_ip(peer_ip) if not key: return t = _tasks.get(key) if t is not None and not t.done(): return try: loop = asyncio.get_running_loop() except RuntimeError: return _tasks[key] = loop.create_task(_driver_connection_loop(key)) def cancel_all_driver_tasks() -> None: """Signal shutdown: cancel outbound driver connection tasks.""" for _ip, t in list(_tasks.items()): if not t.done(): t.cancel() _tasks.clear() for ip in list(_connections.keys()): _schedule_status_broadcast(ip, False) _connections.clear() _send_locks.clear() _unreachable_counts.clear()