diff --git a/docs/pattern-contract.md b/docs/pattern-contract.md new file mode 100644 index 0000000..5e6cb77 --- /dev/null +++ b/docs/pattern-contract.md @@ -0,0 +1,51 @@ +# Pattern Contract (Important) + +Pattern classes are loaded dynamically by `Presets._load_dynamic_patterns()`. + +Patterns must follow this contract exactly. + +## Required class shape + +- File name is the pattern id (for example `blink.py` -> pattern name `blink`). +- Module exports a class with: + - `__init__(self, driver)` where `driver` is the `Presets` instance. + - `run(self, preset)` that returns a generator. + +`Presets` binds patterns like this: + +- `pattern_class(self).run` +- then calls `self.patterns[preset.p](preset)` and stores that generator. +- every frame, `Presets.tick()` does `next(self.generator)`. + +## `run()` generator rules + +- `run()` must `yield` frequently (normally once per tick loop). +- Do not block inside `run()`: + - no `sleep()` / `sleep_ms()` / long loops without `yield`. + - no network or file I/O. +- Use time checks (`utime.ticks_ms()` + `utime.ticks_diff(...)`) to schedule updates. +- Keep pattern state inside local variables in `run()` (or object fields if needed). + +## Drawing and brightness + +- Use `self.driver.apply_brightness(color, preset.b)` for per-preset brightness. +- Write pixels through `self.driver.n[...]` / `self.driver.n.fill(...)`. +- Flush frame with `self.driver.n.write()`. +- If a pattern needs to clear, use black `(0, 0, 0)`. + +## Step semantics + +- `self.driver.step` is shared pattern state managed by `Presets.select(...)` and patterns. +- Patterns that use step-based progression should update `self.driver.step` themselves. +- `select(..., step=...)` may set an explicit starting step. + +## Error handling + +- Let unexpected errors raise inside the generator. +- `Presets.tick()` catches exceptions, logs, and stops the active generator. +- Pattern code should not swallow broad exceptions unless there is a clear recovery path. + +## Built-ins + +- `off` and `on` are built-in methods on `Presets`, not loaded from this folder. +- `__init__.py` is ignored by dynamic loader. diff --git a/lib/microdot/__init__.py b/lib/microdot/__init__.py new file mode 100644 index 0000000..68cb381 --- /dev/null +++ b/lib/microdot/__init__.py @@ -0,0 +1,2 @@ +from microdot.microdot import Microdot, Request, Response, abort, redirect, \ + send_file # noqa: F401 \ No newline at end of file diff --git a/lib/microdot/helpers.py b/lib/microdot/helpers.py new file mode 100644 index 0000000..664e58c --- /dev/null +++ b/lib/microdot/helpers.py @@ -0,0 +1,8 @@ +try: + from functools import wraps +except ImportError: # pragma: no cover + # MicroPython does not currently implement functools.wraps + def wraps(wrapped): + def _(wrapper): + return wrapper + return _ diff --git a/lib/microdot/microdot.py b/lib/microdot/microdot.py new file mode 100644 index 0000000..0513f21 --- /dev/null +++ b/lib/microdot/microdot.py @@ -0,0 +1,1450 @@ +""" +microdot +-------- + +The ``microdot`` module defines a few classes that help implement HTTP-based +servers for MicroPython and standard Python. +""" +import asyncio +import io +import json +import time + +try: + from inspect import iscoroutinefunction, iscoroutine + from functools import partial + + async def invoke_handler(handler, *args, **kwargs): + """Invoke a handler and return the result. + + This method runs sync handlers in a thread pool executor. + """ + if iscoroutinefunction(handler): + ret = await handler(*args, **kwargs) + else: + ret = await asyncio.get_running_loop().run_in_executor( + None, partial(handler, *args, **kwargs)) + return ret +except ImportError: # pragma: no cover + def iscoroutine(coro): + return hasattr(coro, 'send') and hasattr(coro, 'throw') + + async def invoke_handler(handler, *args, **kwargs): + """Invoke a handler and return the result. + + This method runs sync handlers in the asyncio thread, which can + potentially cause blocking and performance issues. + """ + ret = handler(*args, **kwargs) + if iscoroutine(ret): + ret = await ret + return ret + +try: + from sys import print_exception +except ImportError: # pragma: no cover + import traceback + + def print_exception(exc): + traceback.print_exc() + +MUTED_SOCKET_ERRORS = [ + 32, # Broken pipe + 54, # Connection reset by peer + 104, # Connection reset by peer + 128, # Operation on closed socket +] + + +def urldecode_str(s): + s = s.replace('+', ' ') + parts = s.split('%') + if len(parts) == 1: + return s + result = [parts[0]] + for item in parts[1:]: + if item == '': + result.append('%') + else: + code = item[:2] + result.append(chr(int(code, 16))) + result.append(item[2:]) + return ''.join(result) + + +def urldecode_bytes(s): + s = s.replace(b'+', b' ') + parts = s.split(b'%') + if len(parts) == 1: + return s.decode() + result = [parts[0]] + for item in parts[1:]: + if item == b'': + result.append(b'%') + else: + code = item[:2] + result.append(bytes([int(code, 16)])) + result.append(item[2:]) + return b''.join(result).decode() + + +def urlencode(s): + return s.replace('+', '%2B').replace(' ', '+').replace( + '%', '%25').replace('?', '%3F').replace('#', '%23').replace( + '&', '%26').replace('=', '%3D') + + +class NoCaseDict(dict): + """A subclass of dictionary that holds case-insensitive keys. + + :param initial_dict: an initial dictionary of key/value pairs to + initialize this object with. + + Example:: + + >>> d = NoCaseDict() + >>> d['Content-Type'] = 'text/html' + >>> print(d['Content-Type']) + text/html + >>> print(d['content-type']) + text/html + >>> print(d['CONTENT-TYPE']) + text/html + >>> del d['cOnTeNt-TyPe'] + >>> print(d) + {} + """ + def __init__(self, initial_dict=None): + super().__init__(initial_dict or {}) + self.keymap = {k.lower(): k for k in self.keys() if k.lower() != k} + + def __setitem__(self, key, value): + kl = key.lower() + key = self.keymap.get(kl, key) + if kl != key: + self.keymap[kl] = key + super().__setitem__(key, value) + + def __getitem__(self, key): + kl = key.lower() + return super().__getitem__(self.keymap.get(kl, kl)) + + def __delitem__(self, key): + kl = key.lower() + super().__delitem__(self.keymap.get(kl, kl)) + + def __contains__(self, key): + kl = key.lower() + return self.keymap.get(kl, kl) in self.keys() + + def get(self, key, default=None): + kl = key.lower() + return super().get(self.keymap.get(kl, kl), default) + + def update(self, other_dict): + for key, value in other_dict.items(): + self[key] = value + + +def mro(cls): # pragma: no cover + """Return the method resolution order of a class. + + This is a helper function that returns the method resolution order of a + class. It is used by Microdot to find the best error handler to invoke for + the raised exception. + + In CPython, this function returns the ``__mro__`` attribute of the class. + In MicroPython, this function implements a recursive depth-first scanning + of the class hierarchy. + """ + if hasattr(cls, 'mro'): + return cls.__mro__ + + def _mro(cls): + m = [cls] + for base in cls.__bases__: + m += _mro(base) + return m + + mro_list = _mro(cls) + + # If a class appears multiple times (due to multiple inheritance) remove + # all but the last occurence. This matches the method resolution order + # of MicroPython, but not CPython. + mro_pruned = [] + for i in range(len(mro_list)): + base = mro_list.pop(0) + if base not in mro_list: + mro_pruned.append(base) + return mro_pruned + + +class MultiDict(dict): + """A subclass of dictionary that can hold multiple values for the same + key. It is used to hold key/value pairs decoded from query strings and + form submissions. + + :param initial_dict: an initial dictionary of key/value pairs to + initialize this object with. + + Example:: + + >>> d = MultiDict() + >>> d['sort'] = 'name' + >>> d['sort'] = 'email' + >>> print(d['sort']) + 'name' + >>> print(d.getlist('sort')) + ['name', 'email'] + """ + def __init__(self, initial_dict=None): + super().__init__() + if initial_dict: + for key, value in initial_dict.items(): + self[key] = value + + def __setitem__(self, key, value): + if key not in self: + super().__setitem__(key, []) + super().__getitem__(key).append(value) + + def __getitem__(self, key): + return super().__getitem__(key)[0] + + def get(self, key, default=None, type=None): + """Return the value for a given key. + + :param key: The key to retrieve. + :param default: A default value to use if the key does not exist. + :param type: A type conversion callable to apply to the value. + + If the multidict contains more than one value for the requested key, + this method returns the first value only. + + Example:: + + >>> d = MultiDict() + >>> d['age'] = '42' + >>> d.get('age') + '42' + >>> d.get('age', type=int) + 42 + >>> d.get('name', default='noname') + 'noname' + """ + if key not in self: + return default + value = self[key] + if type is not None: + value = type(value) + return value + + def getlist(self, key, type=None): + """Return all the values for a given key. + + :param key: The key to retrieve. + :param type: A type conversion callable to apply to the values. + + If the requested key does not exist in the dictionary, this method + returns an empty list. + + Example:: + + >>> d = MultiDict() + >>> d.getlist('items') + [] + >>> d['items'] = '3' + >>> d.getlist('items') + ['3'] + >>> d['items'] = '56' + >>> d.getlist('items') + ['3', '56'] + >>> d.getlist('items', type=int) + [3, 56] + """ + if key not in self: + return [] + values = super().__getitem__(key) + if type is not None: + values = [type(value) for value in values] + return values + + +class AsyncBytesIO: + """An async wrapper for BytesIO.""" + def __init__(self, data): + self.stream = io.BytesIO(data) + + async def read(self, n=-1): + return self.stream.read(n) + + async def readline(self): # pragma: no cover + return self.stream.readline() + + async def readexactly(self, n): # pragma: no cover + return self.stream.read(n) + + async def readuntil(self, separator=b'\n'): # pragma: no cover + return self.stream.readuntil(separator=separator) + + async def awrite(self, data): # pragma: no cover + return self.stream.write(data) + + async def aclose(self): # pragma: no cover + pass + + +class Request: + """An HTTP request.""" + #: Specify the maximum payload size that is accepted. Requests with larger + #: payloads will be rejected with a 413 status code. Applications can + #: change this maximum as necessary. + #: + #: Example:: + #: + #: Request.max_content_length = 1 * 1024 * 1024 # 1MB requests allowed + max_content_length = 16 * 1024 + + #: Specify the maximum payload size that can be stored in ``body``. + #: Requests with payloads that are larger than this size and up to + #: ``max_content_length`` bytes will be accepted, but the application will + #: only be able to access the body of the request by reading from + #: ``stream``. Set to 0 if you always access the body as a stream. + #: + #: Example:: + #: + #: Request.max_body_length = 4 * 1024 # up to 4KB bodies read + max_body_length = 16 * 1024 + + #: Specify the maximum length allowed for a line in the request. Requests + #: with longer lines will not be correctly interpreted. Applications can + #: change this maximum as necessary. + #: + #: Example:: + #: + #: Request.max_readline = 16 * 1024 # 16KB lines allowed + max_readline = 2 * 1024 + + class G: + pass + + def __init__(self, app, client_addr, method, url, http_version, headers, + body=None, stream=None, sock=None): + #: The application instance to which this request belongs. + self.app = app + #: The address of the client, as a tuple (host, port). + self.client_addr = client_addr + #: The HTTP method of the request. + self.method = method + #: The request URL, including the path and query string. + self.url = url + #: The path portion of the URL. + self.path = url + #: The query string portion of the URL. + self.query_string = None + #: The parsed query string, as a + #: :class:`MultiDict ` object. + self.args = {} + #: A dictionary with the headers included in the request. + self.headers = headers + #: A dictionary with the cookies included in the request. + self.cookies = {} + #: The parsed ``Content-Length`` header. + self.content_length = 0 + #: The parsed ``Content-Type`` header. + self.content_type = None + #: A general purpose container for applications to store data during + #: the life of the request. + self.g = Request.G() + + self.http_version = http_version + if '?' in self.path: + self.path, self.query_string = self.path.split('?', 1) + self.args = self._parse_urlencoded(self.query_string) + + if 'Content-Length' in self.headers: + self.content_length = int(self.headers['Content-Length']) + if 'Content-Type' in self.headers: + self.content_type = self.headers['Content-Type'] + if 'Cookie' in self.headers: + for cookie in self.headers['Cookie'].split(';'): + name, value = cookie.strip().split('=', 1) + self.cookies[name] = value + + self._body = body + self.body_used = False + self._stream = stream + self.sock = sock + self._json = None + self._form = None + self.after_request_handlers = [] + + @staticmethod + async def create(app, client_reader, client_writer, client_addr): + """Create a request object. + + :param app: The Microdot application instance. + :param client_reader: An input stream from where the request data can + be read. + :param client_writer: An output stream where the response data can be + written. + :param client_addr: The address of the client, as a tuple. + + This method is a coroutine. It returns a newly created ``Request`` + object. + """ + # request line + line = (await Request._safe_readline(client_reader)).strip().decode() + if not line: # pragma: no cover + return None + method, url, http_version = line.split() + http_version = http_version.split('/', 1)[1] + + # headers + headers = NoCaseDict() + content_length = 0 + while True: + line = (await Request._safe_readline( + client_reader)).strip().decode() + if line == '': + break + header, value = line.split(':', 1) + value = value.strip() + headers[header] = value + if header.lower() == 'content-length': + content_length = int(value) + + # body + body = b'' + if content_length and content_length <= Request.max_body_length: + body = await client_reader.readexactly(content_length) + stream = None + else: + body = b'' + stream = client_reader + + return Request(app, client_addr, method, url, http_version, headers, + body=body, stream=stream, + sock=(client_reader, client_writer)) + + def _parse_urlencoded(self, urlencoded): + data = MultiDict() + if len(urlencoded) > 0: # pragma: no branch + if isinstance(urlencoded, str): + for kv in [pair.split('=', 1) + for pair in urlencoded.split('&') if pair]: + data[urldecode_str(kv[0])] = urldecode_str(kv[1]) \ + if len(kv) > 1 else '' + elif isinstance(urlencoded, bytes): # pragma: no branch + for kv in [pair.split(b'=', 1) + for pair in urlencoded.split(b'&') if pair]: + data[urldecode_bytes(kv[0])] = urldecode_bytes(kv[1]) \ + if len(kv) > 1 else b'' + return data + + @property + def body(self): + """The body of the request, as bytes.""" + return self._body + + @property + def stream(self): + """The body of the request, as a bytes stream.""" + if self._stream is None: + self._stream = AsyncBytesIO(self._body) + return self._stream + + @property + def json(self): + """The parsed JSON body, or ``None`` if the request does not have a + JSON body.""" + if self._json is None: + if self.content_type is None: + return None + mime_type = self.content_type.split(';')[0] + if mime_type != 'application/json': + return None + self._json = json.loads(self.body.decode()) + return self._json + + @property + def form(self): + """The parsed form submission body, as a + :class:`MultiDict ` object, or ``None`` if the + request does not have a form submission.""" + if self._form is None: + if self.content_type is None: + return None + mime_type = self.content_type.split(';')[0] + if mime_type != 'application/x-www-form-urlencoded': + return None + self._form = self._parse_urlencoded(self.body) + return self._form + + def after_request(self, f): + """Register a request-specific function to run after the request is + handled. Request-specific after request handlers run at the very end, + after the application's own after request handlers. The function must + take two arguments, the request and response objects. The return value + of the function must be the updated response object. + + Example:: + + @app.route('/') + def index(request): + # register a request-specific after request handler + @req.after_request + def func(request, response): + # ... + return response + + return 'Hello, World!' + + Note that the function is not called if the request handler raises an + exception and an error response is returned instead. + """ + self.after_request_handlers.append(f) + return f + + @staticmethod + async def _safe_readline(stream): + line = (await stream.readline()) + if len(line) > Request.max_readline: + raise ValueError('line too long') + return line + + +class Response: + """An HTTP response class. + + :param body: The body of the response. If a dictionary or list is given, + a JSON formatter is used to generate the body. If a file-like + object or an async generator is given, a streaming response is + used. If a string is given, it is encoded from UTF-8. Else, + the body should be a byte sequence. + :param status_code: The numeric HTTP status code of the response. The + default is 200. + :param headers: A dictionary of headers to include in the response. + :param reason: A custom reason phrase to add after the status code. The + default is "OK" for responses with a 200 status code and + "N/A" for any other status codes. + """ + types_map = { + 'css': 'text/css', + 'gif': 'image/gif', + 'html': 'text/html', + 'jpg': 'image/jpeg', + 'js': 'application/javascript', + 'json': 'application/json', + 'png': 'image/png', + 'txt': 'text/plain', + } + + send_file_buffer_size = 1024 + + #: The content type to use for responses that do not explicitly define a + #: ``Content-Type`` header. + default_content_type = 'text/plain' + + #: The default cache control max age used by :meth:`send_file`. A value + #: of ``None`` means that no ``Cache-Control`` header is added. + default_send_file_max_age = None + + #: Special response used to signal that a response does not need to be + #: written to the client. Used to exit WebSocket connections cleanly. + already_handled = None + + def __init__(self, body='', status_code=200, headers=None, reason=None): + if body is None and status_code == 200: + body = '' + status_code = 204 + self.status_code = status_code + self.headers = NoCaseDict(headers or {}) + self.reason = reason + if isinstance(body, (dict, list)): + self.body = json.dumps(body).encode() + self.headers['Content-Type'] = 'application/json; charset=UTF-8' + elif isinstance(body, str): + self.body = body.encode() + else: + # this applies to bytes, file-like objects or generators + self.body = body + self.is_head = False + + def set_cookie(self, cookie, value, path=None, domain=None, expires=None, + max_age=None, secure=False, http_only=False, + partitioned=False): + """Add a cookie to the response. + + :param cookie: The cookie's name. + :param value: The cookie's value. + :param path: The cookie's path. + :param domain: The cookie's domain. + :param expires: The cookie expiration time, as a ``datetime`` object + or a correctly formatted string. + :param max_age: The cookie's ``Max-Age`` value. + :param secure: The cookie's ``secure`` flag. + :param http_only: The cookie's ``HttpOnly`` flag. + :param partitioned: Whether the cookie is partitioned. + """ + http_cookie = '{cookie}={value}'.format(cookie=cookie, value=value) + if path: + http_cookie += '; Path=' + path + if domain: + http_cookie += '; Domain=' + domain + if expires: + if isinstance(expires, str): + http_cookie += '; Expires=' + expires + else: # pragma: no cover + http_cookie += '; Expires=' + time.strftime( + '%a, %d %b %Y %H:%M:%S GMT', expires.timetuple()) + if max_age is not None: + http_cookie += '; Max-Age=' + str(max_age) + if secure: + http_cookie += '; Secure' + if http_only: + http_cookie += '; HttpOnly' + if partitioned: + http_cookie += '; Partitioned' + if 'Set-Cookie' in self.headers: + self.headers['Set-Cookie'].append(http_cookie) + else: + self.headers['Set-Cookie'] = [http_cookie] + + def delete_cookie(self, cookie, **kwargs): + """Delete a cookie. + + :param cookie: The cookie's name. + :param kwargs: Any cookie opens and flags supported by + ``set_cookie()`` except ``expires`` and ``max_age``. + """ + self.set_cookie(cookie, '', expires='Thu, 01 Jan 1970 00:00:01 GMT', + max_age=0, **kwargs) + + def complete(self): + if isinstance(self.body, bytes) and \ + 'Content-Length' not in self.headers: + self.headers['Content-Length'] = str(len(self.body)) + if 'Content-Type' not in self.headers: + self.headers['Content-Type'] = self.default_content_type + if 'charset=' not in self.headers['Content-Type']: + self.headers['Content-Type'] += '; charset=UTF-8' + + async def write(self, stream): + self.complete() + + try: + # status code + reason = self.reason if self.reason is not None else \ + ('OK' if self.status_code == 200 else 'N/A') + await stream.awrite('HTTP/1.0 {status_code} {reason}\r\n'.format( + status_code=self.status_code, reason=reason).encode()) + + # headers + for header, value in self.headers.items(): + values = value if isinstance(value, list) else [value] + for value in values: + await stream.awrite('{header}: {value}\r\n'.format( + header=header, value=value).encode()) + await stream.awrite(b'\r\n') + + # body + if not self.is_head: + iter = self.body_iter() + async for body in iter: + if isinstance(body, str): # pragma: no cover + body = body.encode() + try: + await stream.awrite(body) + except OSError as exc: # pragma: no cover + if exc.errno in MUTED_SOCKET_ERRORS or \ + exc.args[0] == 'Connection lost': + if hasattr(iter, 'aclose'): + await iter.aclose() + raise + if hasattr(iter, 'aclose'): # pragma: no branch + await iter.aclose() + + except OSError as exc: # pragma: no cover + if exc.errno in MUTED_SOCKET_ERRORS or \ + exc.args[0] == 'Connection lost': + pass + else: + raise + + def body_iter(self): + if hasattr(self.body, '__anext__'): + # response body is an async generator + return self.body + + response = self + + class iter: + ITER_UNKNOWN = 0 + ITER_SYNC_GEN = 1 + ITER_FILE_OBJ = 2 + ITER_NO_BODY = -1 + + def __aiter__(self): + if response.body: + self.i = self.ITER_UNKNOWN # need to determine type + else: + self.i = self.ITER_NO_BODY + return self + + async def __anext__(self): + if self.i == self.ITER_NO_BODY: + await self.aclose() + raise StopAsyncIteration + if self.i == self.ITER_UNKNOWN: + if hasattr(response.body, 'read'): + self.i = self.ITER_FILE_OBJ + elif hasattr(response.body, '__next__'): + self.i = self.ITER_SYNC_GEN + return next(response.body) + else: + self.i = self.ITER_NO_BODY + return response.body + elif self.i == self.ITER_SYNC_GEN: + try: + return next(response.body) + except StopIteration: + await self.aclose() + raise StopAsyncIteration + buf = response.body.read(response.send_file_buffer_size) + if iscoroutine(buf): # pragma: no cover + buf = await buf + if len(buf) < response.send_file_buffer_size: + self.i = self.ITER_NO_BODY + return buf + + async def aclose(self): + if hasattr(response.body, 'close'): + result = response.body.close() + if iscoroutine(result): # pragma: no cover + await result + + return iter() + + @classmethod + def redirect(cls, location, status_code=302): + """Return a redirect response. + + :param location: The URL to redirect to. + :param status_code: The 3xx status code to use for the redirect. The + default is 302. + """ + if '\x0d' in location or '\x0a' in location: + raise ValueError('invalid redirect URL') + return cls(status_code=status_code, headers={'Location': location}) + + @classmethod + def send_file(cls, filename, status_code=200, content_type=None, + stream=None, max_age=None, compressed=False, + file_extension=''): + """Send file contents in a response. + + :param filename: The filename of the file. + :param status_code: The 3xx status code to use for the redirect. The + default is 302. + :param content_type: The ``Content-Type`` header to use in the + response. If omitted, it is generated + automatically from the file extension of the + ``filename`` parameter. + :param stream: A file-like object to read the file contents from. If + a stream is given, the ``filename`` parameter is only + used when generating the ``Content-Type`` header. + :param max_age: The ``Cache-Control`` header's ``max-age`` value in + seconds. If omitted, the value of the + :attr:`Response.default_send_file_max_age` attribute is + used. + :param compressed: Whether the file is compressed. If ``True``, the + ``Content-Encoding`` header is set to ``gzip``. A + string with the header value can also be passed. + Note that when using this option the file must have + been compressed beforehand. This option only sets + the header. + :param file_extension: A file extension to append to the ``filename`` + parameter when opening the file, including the + dot. The extension given here is not considered + when generating the ``Content-Type`` header. + + Security note: The filename is assumed to be trusted. Never pass + filenames provided by the user without validating and sanitizing them + first. + """ + if content_type is None: + if compressed and filename.endswith('.gz'): + ext = filename[:-3].split('.')[-1] + else: + ext = filename.split('.')[-1] + if ext in Response.types_map: + content_type = Response.types_map[ext] + else: + content_type = 'application/octet-stream' + headers = {'Content-Type': content_type} + + if max_age is None: + max_age = cls.default_send_file_max_age + if max_age is not None: + headers['Cache-Control'] = 'max-age={}'.format(max_age) + + if compressed: + headers['Content-Encoding'] = compressed \ + if isinstance(compressed, str) else 'gzip' + + f = stream or open(filename + file_extension, 'rb') + return cls(body=f, status_code=status_code, headers=headers) + + +class URLPattern(): + def __init__(self, url_pattern): + self.url_pattern = url_pattern + self.segments = [] + self.regex = None + pattern = '' + use_regex = False + for segment in url_pattern.lstrip('/').split('/'): + if segment and segment[0] == '<': + if segment[-1] != '>': + raise ValueError('invalid URL pattern') + segment = segment[1:-1] + if ':' in segment: + type_, name = segment.rsplit(':', 1) + else: + type_ = 'string' + name = segment + parser = None + if type_ == 'string': + parser = self._string_segment + pattern += '/([^/]+)' + elif type_ == 'int': + parser = self._int_segment + pattern += '/(-?\\d+)' + elif type_ == 'path': + use_regex = True + pattern += '/(.+)' + elif type_.startswith('re:'): + use_regex = True + pattern += '/({pattern})'.format(pattern=type_[3:]) + else: + raise ValueError('invalid URL segment type') + self.segments.append({'parser': parser, 'name': name, + 'type': type_}) + else: + pattern += '/' + segment + self.segments.append({'parser': self._static_segment(segment)}) + if use_regex: + import re + self.regex = re.compile('^' + pattern + '$') + + def match(self, path): + args = {} + if self.regex: + g = self.regex.match(path) + if not g: + return + i = 1 + for segment in self.segments: + if 'name' not in segment: + continue + value = g.group(i) + if segment['type'] == 'int': + value = int(value) + args[segment['name']] = value + i += 1 + else: + if len(path) == 0 or path[0] != '/': + return + path = path[1:] + args = {} + for segment in self.segments: + if path is None: + return + arg, path = segment['parser'](path) + if arg is None: + return + if 'name' in segment: + args[segment['name']] = arg + if path is not None: + return + return args + + def _static_segment(self, segment): + def _static(value): + s = value.split('/', 1) + if s[0] == segment: + return '', s[1] if len(s) > 1 else None + return None, None + return _static + + def _string_segment(self, value): + s = value.split('/', 1) + if len(s[0]) == 0: + return None, None + return s[0], s[1] if len(s) > 1 else None + + def _int_segment(self, value): + s = value.split('/', 1) + try: + return int(s[0]), s[1] if len(s) > 1 else None + except ValueError: + return None, None + + +class HTTPException(Exception): + def __init__(self, status_code, reason=None): + self.status_code = status_code + self.reason = reason or str(status_code) + ' error' + + def __repr__(self): # pragma: no cover + return 'HTTPException: {}'.format(self.status_code) + + +class Microdot: + """An HTTP application class. + + This class implements an HTTP application instance and is heavily + influenced by the ``Flask`` class of the Flask framework. It is typically + declared near the start of the main application script. + + Example:: + + from microdot import Microdot + + app = Microdot() + """ + + def __init__(self): + self.url_map = [] + self.before_request_handlers = [] + self.after_request_handlers = [] + self.after_error_request_handlers = [] + self.error_handlers = {} + self.shutdown_requested = False + self.options_handler = self.default_options_handler + self.debug = False + self.server = None + + def route(self, url_pattern, methods=None): + """Decorator that is used to register a function as a request handler + for a given URL. + + :param url_pattern: The URL pattern that will be compared against + incoming requests. + :param methods: The list of HTTP methods to be handled by the + decorated function. If omitted, only ``GET`` requests + are handled. + + The URL pattern can be a static path (for example, ``/users`` or + ``/api/invoices/search``) or a path with dynamic components enclosed + in ``<`` and ``>`` (for example, ``/users/`` or + ``/invoices//products``). Dynamic path components can also + include a type prefix, separated from the name with a colon (for + example, ``/users/``). The type can be ``string`` (the + default), ``int``, ``path`` or ``re:[regular-expression]``. + + The first argument of the decorated function must be + the request object. Any path arguments that are specified in the URL + pattern are passed as keyword arguments. The return value of the + function must be a :class:`Response` instance, or the arguments to + be passed to this class. + + Example:: + + @app.route('/') + def index(request): + return 'Hello, world!' + """ + def decorated(f): + self.url_map.append( + ([m.upper() for m in (methods or ['GET'])], + URLPattern(url_pattern), f)) + return f + return decorated + + def get(self, url_pattern): + """Decorator that is used to register a function as a ``GET`` request + handler for a given URL. + + :param url_pattern: The URL pattern that will be compared against + incoming requests. + + This decorator can be used as an alias to the ``route`` decorator with + ``methods=['GET']``. + + Example:: + + @app.get('/users/') + def get_user(request, id): + # ... + """ + return self.route(url_pattern, methods=['GET']) + + def post(self, url_pattern): + """Decorator that is used to register a function as a ``POST`` request + handler for a given URL. + + :param url_pattern: The URL pattern that will be compared against + incoming requests. + + This decorator can be used as an alias to the``route`` decorator with + ``methods=['POST']``. + + Example:: + + @app.post('/users') + def create_user(request): + # ... + """ + return self.route(url_pattern, methods=['POST']) + + def put(self, url_pattern): + """Decorator that is used to register a function as a ``PUT`` request + handler for a given URL. + + :param url_pattern: The URL pattern that will be compared against + incoming requests. + + This decorator can be used as an alias to the ``route`` decorator with + ``methods=['PUT']``. + + Example:: + + @app.put('/users/') + def edit_user(request, id): + # ... + """ + return self.route(url_pattern, methods=['PUT']) + + def patch(self, url_pattern): + """Decorator that is used to register a function as a ``PATCH`` request + handler for a given URL. + + :param url_pattern: The URL pattern that will be compared against + incoming requests. + + This decorator can be used as an alias to the ``route`` decorator with + ``methods=['PATCH']``. + + Example:: + + @app.patch('/users/') + def edit_user(request, id): + # ... + """ + return self.route(url_pattern, methods=['PATCH']) + + def delete(self, url_pattern): + """Decorator that is used to register a function as a ``DELETE`` + request handler for a given URL. + + :param url_pattern: The URL pattern that will be compared against + incoming requests. + + This decorator can be used as an alias to the ``route`` decorator with + ``methods=['DELETE']``. + + Example:: + + @app.delete('/users/') + def delete_user(request, id): + # ... + """ + return self.route(url_pattern, methods=['DELETE']) + + def before_request(self, f): + """Decorator to register a function to run before each request is + handled. The decorated function must take a single argument, the + request object. + + Example:: + + @app.before_request + def func(request): + # ... + """ + self.before_request_handlers.append(f) + return f + + def after_request(self, f): + """Decorator to register a function to run after each request is + handled. The decorated function must take two arguments, the request + and response objects. The return value of the function must be an + updated response object. + + Example:: + + @app.after_request + def func(request, response): + # ... + return response + """ + self.after_request_handlers.append(f) + return f + + def after_error_request(self, f): + """Decorator to register a function to run after an error response is + generated. The decorated function must take two arguments, the request + and response objects. The return value of the function must be an + updated response object. The handler is invoked for error responses + generated by Microdot, as well as those returned by application-defined + error handlers. + + Example:: + + @app.after_error_request + def func(request, response): + # ... + return response + """ + self.after_error_request_handlers.append(f) + return f + + def errorhandler(self, status_code_or_exception_class): + """Decorator to register a function as an error handler. Error handler + functions for numeric HTTP status codes must accept a single argument, + the request object. Error handler functions for Python exceptions + must accept two arguments, the request object and the exception + object. + + :param status_code_or_exception_class: The numeric HTTP status code or + Python exception class to + handle. + + Examples:: + + @app.errorhandler(404) + def not_found(request): + return 'Not found' + + @app.errorhandler(RuntimeError) + def runtime_error(request, exception): + return 'Runtime error' + """ + def decorated(f): + self.error_handlers[status_code_or_exception_class] = f + return f + return decorated + + def mount(self, subapp, url_prefix=''): + """Mount a sub-application, optionally under the given URL prefix. + + :param subapp: The sub-application to mount. + :param url_prefix: The URL prefix to mount the application under. + """ + for methods, pattern, handler in subapp.url_map: + self.url_map.append( + (methods, URLPattern(url_prefix + pattern.url_pattern), + handler)) + for handler in subapp.before_request_handlers: + self.before_request_handlers.append(handler) + for handler in subapp.after_request_handlers: + self.after_request_handlers.append(handler) + for handler in subapp.after_error_request_handlers: + self.after_error_request_handlers.append(handler) + for status_code, handler in subapp.error_handlers.items(): + self.error_handlers[status_code] = handler + + @staticmethod + def abort(status_code, reason=None): + """Abort the current request and return an error response with the + given status code. + + :param status_code: The numeric status code of the response. + :param reason: The reason for the response, which is included in the + response body. + + Example:: + + from microdot import abort + + @app.route('/users/') + def get_user(id): + user = get_user_by_id(id) + if user is None: + abort(404) + return user.to_dict() + """ + raise HTTPException(status_code, reason) + + async def start_server(self, host='0.0.0.0', port=5000, debug=False, + ssl=None): + """Start the Microdot web server as a coroutine. This coroutine does + not normally return, as the server enters an endless listening loop. + The :func:`shutdown` function provides a method for terminating the + server gracefully. + + :param host: The hostname or IP address of the network interface that + will be listening for requests. A value of ``'0.0.0.0'`` + (the default) indicates that the server should listen for + requests on all the available interfaces, and a value of + ``127.0.0.1`` indicates that the server should listen + for requests only on the internal networking interface of + the host. + :param port: The port number to listen for requests. The default is + port 5000. + :param debug: If ``True``, the server logs debugging information. The + default is ``False``. + :param ssl: An ``SSLContext`` instance or ``None`` if the server should + not use TLS. The default is ``None``. + + This method is a coroutine. + + Example:: + + import asyncio + from microdot import Microdot + + app = Microdot() + + @app.route('/') + async def index(request): + return 'Hello, world!' + + async def main(): + await app.start_server(debug=True) + + asyncio.run(main()) + """ + self.debug = debug + + async def serve(reader, writer): + if not hasattr(writer, 'awrite'): # pragma: no cover + # CPython provides the awrite and aclose methods in 3.8+ + async def awrite(self, data): + self.write(data) + await self.drain() + + async def aclose(self): + self.close() + await self.wait_closed() + + from types import MethodType + writer.awrite = MethodType(awrite, writer) + writer.aclose = MethodType(aclose, writer) + + await self.handle_request(reader, writer) + + if self.debug: # pragma: no cover + print('Starting async server on {host}:{port}...'.format( + host=host, port=port)) + + try: + self.server = await asyncio.start_server(serve, host, port, + ssl=ssl) + except TypeError: # pragma: no cover + self.server = await asyncio.start_server(serve, host, port) + + while True: + try: + if hasattr(self.server, 'serve_forever'): # pragma: no cover + try: + await self.server.serve_forever() + except asyncio.CancelledError: + pass + await self.server.wait_closed() + break + except AttributeError: # pragma: no cover + # the task hasn't been initialized in the server object yet + # wait a bit and try again + await asyncio.sleep(0.1) + + def run(self, host='0.0.0.0', port=5000, debug=False, ssl=None): + """Start the web server. This function does not normally return, as + the server enters an endless listening loop. The :func:`shutdown` + function provides a method for terminating the server gracefully. + + :param host: The hostname or IP address of the network interface that + will be listening for requests. A value of ``'0.0.0.0'`` + (the default) indicates that the server should listen for + requests on all the available interfaces, and a value of + ``127.0.0.1`` indicates that the server should listen + for requests only on the internal networking interface of + the host. + :param port: The port number to listen for requests. The default is + port 5000. + :param debug: If ``True``, the server logs debugging information. The + default is ``False``. + :param ssl: An ``SSLContext`` instance or ``None`` if the server should + not use TLS. The default is ``None``. + + Example:: + + from microdot import Microdot + + app = Microdot() + + @app.route('/') + async def index(request): + return 'Hello, world!' + + app.run(debug=True) + """ + asyncio.run(self.start_server(host=host, port=port, debug=debug, + ssl=ssl)) # pragma: no cover + + def shutdown(self): + """Request a server shutdown. The server will then exit its request + listening loop and the :func:`run` function will return. This function + can be safely called from a route handler, as it only schedules the + server to terminate as soon as the request completes. + + Example:: + + @app.route('/shutdown') + def shutdown(request): + request.app.shutdown() + return 'The server is shutting down...' + """ + self.server.close() + + def find_route(self, req): + method = req.method.upper() + if method == 'OPTIONS' and self.options_handler: + return self.options_handler(req) + if method == 'HEAD': + method = 'GET' + f = 404 + for route_methods, route_pattern, route_handler in self.url_map: + req.url_args = route_pattern.match(req.path) + if req.url_args is not None: + if method in route_methods: + f = route_handler + break + else: + f = 405 + return f + + def default_options_handler(self, req): + allow = [] + for route_methods, route_pattern, route_handler in self.url_map: + if route_pattern.match(req.path) is not None: + allow.extend(route_methods) + if 'GET' in allow: + allow.append('HEAD') + allow.append('OPTIONS') + return {'Allow': ', '.join(allow)} + + async def handle_request(self, reader, writer): + req = None + try: + req = await Request.create(self, reader, writer, + writer.get_extra_info('peername')) + except Exception as exc: # pragma: no cover + print_exception(exc) + + res = await self.dispatch_request(req) + if res != Response.already_handled: # pragma: no branch + await res.write(writer) + try: + await writer.aclose() + except OSError as exc: # pragma: no cover + if exc.errno in MUTED_SOCKET_ERRORS: + pass + else: + raise + if self.debug and req: # pragma: no cover + print('{method} {path} {status_code}'.format( + method=req.method, path=req.path, + status_code=res.status_code)) + + async def dispatch_request(self, req): + after_request_handled = False + if req: + if req.content_length > req.max_content_length: + if 413 in self.error_handlers: + res = await invoke_handler(self.error_handlers[413], req) + else: + res = 'Payload too large', 413 + else: + f = self.find_route(req) + try: + res = None + if callable(f): + for handler in self.before_request_handlers: + res = await invoke_handler(handler, req) + if res: + break + if res is None: + res = await invoke_handler( + f, req, **req.url_args) + if isinstance(res, int): + res = '', res + if isinstance(res, tuple): + if isinstance(res[0], int): + res = ('', res[0], + res[1] if len(res) > 1 else {}) + body = res[0] + if isinstance(res[1], int): + status_code = res[1] + headers = res[2] if len(res) > 2 else {} + else: + status_code = 200 + headers = res[1] + res = Response(body, status_code, headers) + elif not isinstance(res, Response): + res = Response(res) + for handler in self.after_request_handlers: + res = await invoke_handler( + handler, req, res) or res + for handler in req.after_request_handlers: + res = await invoke_handler( + handler, req, res) or res + after_request_handled = True + elif isinstance(f, dict): + res = Response(headers=f) + elif f in self.error_handlers: + res = await invoke_handler(self.error_handlers[f], req) + else: + res = 'Not found', f + except HTTPException as exc: + if exc.status_code in self.error_handlers: + res = self.error_handlers[exc.status_code](req) + else: + res = exc.reason, exc.status_code + except Exception as exc: + print_exception(exc) + exc_class = None + res = None + if exc.__class__ in self.error_handlers: + exc_class = exc.__class__ + else: + for c in mro(exc.__class__)[1:]: + if c in self.error_handlers: + exc_class = c + break + if exc_class: + try: + res = await invoke_handler( + self.error_handlers[exc_class], req, exc) + except Exception as exc2: # pragma: no cover + print_exception(exc2) + if res is None: + if 500 in self.error_handlers: + res = await invoke_handler( + self.error_handlers[500], req) + else: + res = 'Internal server error', 500 + else: + if 400 in self.error_handlers: + res = await invoke_handler(self.error_handlers[400], req) + else: + res = 'Bad request', 400 + if isinstance(res, tuple): + res = Response(*res) + elif not isinstance(res, Response): + res = Response(res) + if not after_request_handled: + for handler in self.after_error_request_handlers: + res = await invoke_handler( + handler, req, res) or res + res.is_head = (req and req.method == 'HEAD') + return res + + +Response.already_handled = Response() + +abort = Microdot.abort +redirect = Response.redirect +send_file = Response.send_file \ No newline at end of file diff --git a/lib/microdot/session.py b/lib/microdot/session.py new file mode 100644 index 0000000..78ce2e6 --- /dev/null +++ b/lib/microdot/session.py @@ -0,0 +1,225 @@ +try: + import jwt + HAS_JWT = True +except ImportError: + HAS_JWT = False + try: + import ubinascii + except ImportError: + import binascii as ubinascii + try: + import uhashlib as hashlib + except ImportError: + import hashlib + try: + import uhmac as hmac + except ImportError: + try: + import hmac + except ImportError: + hmac = None + import json + +from microdot.microdot import invoke_handler +from microdot.helpers import wraps + + +class SessionDict(dict): + """A session dictionary. + + The session dictionary is a standard Python dictionary that has been + extended with convenience ``save()`` and ``delete()`` methods. + """ + def __init__(self, request, session_dict): + super().__init__(session_dict) + self.request = request + + def save(self): + """Update the session cookie.""" + self.request.app._session.update(self.request, self) + + def delete(self): + """Delete the session cookie.""" + self.request.app._session.delete(self.request) + + +class Session: + """Session handling + + :param app: The application instance. + :param secret_key: The secret key, as a string or bytes object. + :param cookie_options: A dictionary with cookie options to pass as + arguments to :meth:`Response.set_cookie() + `. + """ + secret_key = None + + def __init__(self, app=None, secret_key=None, cookie_options=None): + self.secret_key = secret_key + self.cookie_options = cookie_options or {} + if app is not None: + self.initialize(app) + + def initialize(self, app, secret_key=None, cookie_options=None): + if secret_key is not None: + self.secret_key = secret_key + if cookie_options is not None: + self.cookie_options = cookie_options + if 'path' not in self.cookie_options: + self.cookie_options['path'] = '/' + if 'http_only' not in self.cookie_options: + self.cookie_options['http_only'] = True + app._session = self + + def get(self, request): + """Retrieve the user session. + + :param request: The client request. + + The return value is a session dictionary with the data stored in the + user's session, or ``{}`` if the session data is not available or + invalid. + """ + if not self.secret_key: + raise ValueError('The session secret key is not configured') + if hasattr(request.g, '_session'): + return request.g._session + session = request.cookies.get('session') + if session is None: + request.g._session = SessionDict(request, {}) + return request.g._session + request.g._session = SessionDict(request, self.decode(session)) + return request.g._session + + def update(self, request, session): + """Update the user session. + + :param request: The client request. + :param session: A dictionary with the update session data for the user. + + Applications would normally not call this method directly, instead they + would use the :meth:`SessionDict.save` method on the session + dictionary, which calls this method. For example:: + + @app.route('/') + @with_session + def index(request, session): + session['foo'] = 'bar' + session.save() + return 'Hello, World!' + + Calling this method adds a cookie with the updated session to the + request currently being processed. + """ + if not self.secret_key: + raise ValueError('The session secret key is not configured') + + encoded_session = self.encode(session) + + @request.after_request + def _update_session(request, response): + response.set_cookie('session', encoded_session, + **self.cookie_options) + return response + + def delete(self, request): + """Remove the user session. + + :param request: The client request. + + Applications would normally not call this method directly, instead they + would use the :meth:`SessionDict.delete` method on the session + dictionary, which calls this method. For example:: + + @app.route('/') + @with_session + def index(request, session): + session.delete() + return 'Hello, World!' + + Calling this method adds a cookie removal header to the request + currently being processed. + """ + @request.after_request + def _delete_session(request, response): + response.delete_cookie('session', **self.cookie_options) + return response + + def encode(self, payload, secret_key=None): + """Encode session data using JWT if available, otherwise use simple HMAC.""" + if HAS_JWT: + return jwt.encode(payload, secret_key or self.secret_key, + algorithm='HS256') + else: + # Simple encoding for MicroPython: base64(json) + HMAC signature + key = (secret_key or self.secret_key).encode() if isinstance(secret_key or self.secret_key, str) else (secret_key or self.secret_key) + payload_json = json.dumps(payload) + payload_b64 = ubinascii.b2a_base64(payload_json.encode()).decode().strip() + + # Create HMAC signature + if hmac: + # Use hmac module if available + h = hmac.new(key, payload_json.encode(), hashlib.sha256) + else: + # Fallback: simple SHA256(key + message) + h = hashlib.sha256(key + payload_json.encode()) + signature = ubinascii.b2a_base64(h.digest()).decode().strip() + + return f"{payload_b64}.{signature}" + + def decode(self, session, secret_key=None): + """Decode session data using JWT if available, otherwise use simple HMAC.""" + if HAS_JWT: + try: + payload = jwt.decode(session, secret_key or self.secret_key, + algorithms=['HS256']) + except jwt.exceptions.PyJWTError: # pragma: no cover + return {} + return payload + else: + try: + # Simple decoding for MicroPython + if '.' not in session: + return {} + + payload_b64, signature = session.rsplit('.', 1) + payload_json = ubinascii.a2b_base64(payload_b64).decode() + + # Verify HMAC signature + key = (secret_key or self.secret_key).encode() if isinstance(secret_key or self.secret_key, str) else (secret_key or self.secret_key) + if hmac: + # Use hmac module if available + h = hmac.new(key, payload_json.encode(), hashlib.sha256) + else: + # Fallback: simple SHA256(key + message) + h = hashlib.sha256(key + payload_json.encode()) + expected_signature = ubinascii.b2a_base64(h.digest()).decode().strip() + + if signature != expected_signature: + return {} + + return json.loads(payload_json) + except Exception: + return {} + + +def with_session(f): + """Decorator that passes the user session to the route handler. + + The session dictionary is passed to the decorated function as an argument + after the request object. Example:: + + @app.route('/') + @with_session + def index(request, session): + return 'Hello, World!' + + Note that the decorator does not save the session. To update the session, + call the :func:`session.save() ` method. + """ + @wraps(f) + async def wrapper(request, *args, **kwargs): + return await invoke_handler( + f, request, request.app._session.get(request), *args, **kwargs) + + return wrapper diff --git a/lib/microdot/utemplate.py b/lib/microdot/utemplate.py new file mode 100644 index 0000000..16d0398 --- /dev/null +++ b/lib/microdot/utemplate.py @@ -0,0 +1,70 @@ +from utemplate import recompile + +_loader = None + + +class Template: + """A template object. + + :param template: The filename of the template to render, relative to the + configured template directory. + """ + @classmethod + def initialize(cls, template_dir='templates', + loader_class=recompile.Loader): + """Initialize the templating subsystem. + + :param template_dir: the directory where templates are stored. This + argument is optional. The default is to load + templates from a *templates* subdirectory. + :param loader_class: the ``utemplate.Loader`` class to use when loading + templates. This argument is optional. The default + is the ``recompile.Loader`` class, which + automatically recompiles templates when they + change. + """ + global _loader + _loader = loader_class(None, template_dir) + + def __init__(self, template): + if _loader is None: # pragma: no cover + self.initialize() + #: The name of the template + self.name = template + self.template = _loader.load(template) + + def generate(self, *args, **kwargs): + """Return a generator that renders the template in chunks, with the + given arguments.""" + return self.template(*args, **kwargs) + + def render(self, *args, **kwargs): + """Render the template with the given arguments and return it as a + string.""" + return ''.join(self.generate(*args, **kwargs)) + + def generate_async(self, *args, **kwargs): + """Return an asynchronous generator that renders the template in + chunks, using the given arguments.""" + class sync_to_async_iter(): + def __init__(self, iter): + self.iter = iter + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self.iter) + except StopIteration: + raise StopAsyncIteration + + return sync_to_async_iter(self.generate(*args, **kwargs)) + + async def render_async(self, *args, **kwargs): + """Render the template with the given arguments asynchronously and + return it as a string.""" + response = '' + async for chunk in self.generate_async(*args, **kwargs): + response += chunk + return response diff --git a/lib/microdot/websocket.py b/lib/microdot/websocket.py new file mode 100644 index 0000000..0fb6f7c --- /dev/null +++ b/lib/microdot/websocket.py @@ -0,0 +1,231 @@ +import binascii +import hashlib +from microdot import Request, Response +from microdot.microdot import MUTED_SOCKET_ERRORS, print_exception +from microdot.helpers import wraps + + +class WebSocketError(Exception): + """Exception raised when an error occurs in a WebSocket connection.""" + pass + + +class WebSocket: + """A WebSocket connection object. + + An instance of this class is sent to handler functions to manage the + WebSocket connection. + """ + CONT = 0 + TEXT = 1 + BINARY = 2 + CLOSE = 8 + PING = 9 + PONG = 10 + + #: Specify the maximum message size that can be received when calling the + #: ``receive()`` method. Messages with payloads that are larger than this + #: size will be rejected and the connection closed. Set to 0 to disable + #: the size check (be aware of potential security issues if you do this), + #: or to -1 to use the value set in + #: ``Request.max_body_length``. The default is -1. + #: + #: Example:: + #: + #: WebSocket.max_message_length = 4 * 1024 # up to 4KB messages + max_message_length = -1 + + def __init__(self, request): + self.request = request + self.closed = False + + async def handshake(self): + response = self._handshake_response() + await self.request.sock[1].awrite( + b'HTTP/1.1 101 Switching Protocols\r\n') + await self.request.sock[1].awrite(b'Upgrade: websocket\r\n') + await self.request.sock[1].awrite(b'Connection: Upgrade\r\n') + await self.request.sock[1].awrite( + b'Sec-WebSocket-Accept: ' + response + b'\r\n\r\n') + + async def receive(self): + """Receive a message from the client.""" + while True: + opcode, payload = await self._read_frame() + send_opcode, data = self._process_websocket_frame(opcode, payload) + if send_opcode: # pragma: no cover + await self.send(data, send_opcode) + elif data: # pragma: no branch + return data + + async def send(self, data, opcode=None): + """Send a message to the client. + + :param data: the data to send, given as a string or bytes. + :param opcode: a custom frame opcode to use. If not given, the opcode + is ``TEXT`` or ``BINARY`` depending on the type of the + data. + """ + frame = self._encode_websocket_frame( + opcode or (self.TEXT if isinstance(data, str) else self.BINARY), + data) + await self.request.sock[1].awrite(frame) + + async def close(self): + """Close the websocket connection.""" + if not self.closed: # pragma: no cover + self.closed = True + await self.send(b'', self.CLOSE) + + def _handshake_response(self): + connection = False + upgrade = False + websocket_key = None + for header, value in self.request.headers.items(): + h = header.lower() + if h == 'connection': + connection = True + if 'upgrade' not in value.lower(): + return self.request.app.abort(400) + elif h == 'upgrade': + upgrade = True + if not value.lower() == 'websocket': + return self.request.app.abort(400) + elif h == 'sec-websocket-key': + websocket_key = value + if not connection or not upgrade or not websocket_key: + return self.request.app.abort(400) + d = hashlib.sha1(websocket_key.encode()) + d.update(b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11') + return binascii.b2a_base64(d.digest())[:-1] + + @classmethod + def _parse_frame_header(cls, header): + fin = header[0] & 0x80 + opcode = header[0] & 0x0f + if fin == 0 or opcode == cls.CONT: # pragma: no cover + raise WebSocketError('Continuation frames not supported') + has_mask = header[1] & 0x80 + length = header[1] & 0x7f + if length == 126: + length = -2 + elif length == 127: + length = -8 + return fin, opcode, has_mask, length + + def _process_websocket_frame(self, opcode, payload): + if opcode == self.TEXT: + payload = payload.decode() + elif opcode == self.BINARY: + pass + elif opcode == self.CLOSE: + raise WebSocketError('Websocket connection closed') + elif opcode == self.PING: + return self.PONG, payload + elif opcode == self.PONG: # pragma: no branch + return None, None + return None, payload + + @classmethod + def _encode_websocket_frame(cls, opcode, payload): + frame = bytearray() + frame.append(0x80 | opcode) + if opcode == cls.TEXT: + payload = payload.encode() + if len(payload) < 126: + frame.append(len(payload)) + elif len(payload) < (1 << 16): + frame.append(126) + frame.extend(len(payload).to_bytes(2, 'big')) + else: + frame.append(127) + frame.extend(len(payload).to_bytes(8, 'big')) + frame.extend(payload) + return frame + + async def _read_frame(self): + header = await self.request.sock[0].read(2) + if len(header) != 2: # pragma: no cover + raise WebSocketError('Websocket connection closed') + fin, opcode, has_mask, length = self._parse_frame_header(header) + if length == -2: + length = await self.request.sock[0].read(2) + length = int.from_bytes(length, 'big') + elif length == -8: + length = await self.request.sock[0].read(8) + length = int.from_bytes(length, 'big') + max_allowed_length = Request.max_body_length \ + if self.max_message_length == -1 else self.max_message_length + if length > max_allowed_length: + raise WebSocketError('Message too large') + if has_mask: # pragma: no cover + mask = await self.request.sock[0].read(4) + payload = await self.request.sock[0].read(length) + if has_mask: # pragma: no cover + payload = bytes(x ^ mask[i % 4] for i, x in enumerate(payload)) + return opcode, payload + + +async def websocket_upgrade(request): + """Upgrade a request handler to a websocket connection. + + This function can be called directly inside a route function to process a + WebSocket upgrade handshake, for example after the user's credentials are + verified. The function returns the websocket object:: + + @app.route('/echo') + async def echo(request): + if not authenticate_user(request): + abort(401) + ws = await websocket_upgrade(request) + while True: + message = await ws.receive() + await ws.send(message) + """ + ws = WebSocket(request) + await ws.handshake() + + @request.after_request + async def after_request(request, response): + return Response.already_handled + + return ws + + +def websocket_wrapper(f, upgrade_function): + @wraps(f) + async def wrapper(request, *args, **kwargs): + ws = await upgrade_function(request) + try: + await f(request, ws, *args, **kwargs) + except OSError as exc: + if exc.errno not in MUTED_SOCKET_ERRORS: # pragma: no cover + raise + except WebSocketError: + pass + except Exception as exc: + print_exception(exc) + finally: # pragma: no cover + try: + await ws.close() + except Exception: + pass + return Response.already_handled + return wrapper + + +def with_websocket(f): + """Decorator to make a route a WebSocket endpoint. + + This decorator is used to define a route that accepts websocket + connections. The route then receives a websocket object as a second + argument that it can use to send and receive messages:: + + @app.route('/echo') + @with_websocket + async def echo(request, ws): + while True: + message = await ws.receive() + await ws.send(message) + """ + return websocket_wrapper(f, websocket_upgrade) diff --git a/lib/utemplate/__init__.py b/lib/utemplate/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lib/utemplate/compiled.py b/lib/utemplate/compiled.py new file mode 100644 index 0000000..006e6f5 --- /dev/null +++ b/lib/utemplate/compiled.py @@ -0,0 +1,14 @@ +class Loader: + + def __init__(self, pkg, dir): + if dir == ".": + dir = "" + else: + dir = dir.replace("/", ".") + "." + if pkg and pkg != "__main__": + dir = pkg + "." + dir + self.p = dir + + def load(self, name): + name = name.replace(".", "_") + return __import__(self.p + name, None, None, (name,)).render \ No newline at end of file diff --git a/lib/utemplate/recompile.py b/lib/utemplate/recompile.py new file mode 100644 index 0000000..b9bae4e --- /dev/null +++ b/lib/utemplate/recompile.py @@ -0,0 +1,21 @@ +# (c) 2014-2020 Paul Sokolovsky. MIT license. +try: + from uos import stat, remove +except: + from os import stat, remove +from . import source + + +class Loader(source.Loader): + + def load(self, name): + o_path = self.pkg_path + self.compiled_path(name) + i_path = self.pkg_path + self.dir + "/" + name + try: + o_stat = stat(o_path) + i_stat = stat(i_path) + if i_stat[8] > o_stat[8]: + # input file is newer, remove output to force recompile + remove(o_path) + finally: + return super().load(name) \ No newline at end of file diff --git a/lib/utemplate/source.py b/lib/utemplate/source.py new file mode 100644 index 0000000..0ff4651 --- /dev/null +++ b/lib/utemplate/source.py @@ -0,0 +1,188 @@ +# (c) 2014-2019 Paul Sokolovsky. MIT license. +from . import compiled + + +class Compiler: + + START_CHAR = "{" + STMNT = "%" + STMNT_END = "%}" + EXPR = "{" + EXPR_END = "}}" + + def __init__(self, file_in, file_out, indent=0, seq=0, loader=None): + self.file_in = file_in + self.file_out = file_out + self.loader = loader + self.seq = seq + self._indent = indent + self.stack = [] + self.in_literal = False + self.flushed_header = False + self.args = "*a, **d" + + def indent(self, adjust=0): + if not self.flushed_header: + self.flushed_header = True + self.indent() + self.file_out.write("def render%s(%s):\n" % (str(self.seq) if self.seq else "", self.args)) + self.stack.append("def") + self.file_out.write(" " * (len(self.stack) + self._indent + adjust)) + + def literal(self, s): + if not s: + return + if not self.in_literal: + self.indent() + self.file_out.write('yield """') + self.in_literal = True + self.file_out.write(s.replace('"', '\\"')) + + def close_literal(self): + if self.in_literal: + self.file_out.write('"""\n') + self.in_literal = False + + def render_expr(self, e): + self.indent() + self.file_out.write('yield str(' + e + ')\n') + + def parse_statement(self, stmt): + tokens = stmt.split(None, 1) + if tokens[0] == "args": + if len(tokens) > 1: + self.args = tokens[1] + else: + self.args = "" + elif tokens[0] == "set": + self.indent() + self.file_out.write(stmt[3:].strip() + "\n") + elif tokens[0] == "include": + if not self.flushed_header: + # If there was no other output, we still need a header now + self.indent() + tokens = tokens[1].split(None, 1) + args = "" + if len(tokens) > 1: + args = tokens[1] + if tokens[0][0] == "{": + self.indent() + # "1" as fromlist param is uPy hack + self.file_out.write('_ = __import__(%s.replace(".", "_"), None, None, 1)\n' % tokens[0][2:-2]) + self.indent() + self.file_out.write("yield from _.render(%s)\n" % args) + return + + with self.loader.input_open(tokens[0][1:-1]) as inc: + self.seq += 1 + c = Compiler(inc, self.file_out, len(self.stack) + self._indent, self.seq) + inc_id = self.seq + self.seq = c.compile() + self.indent() + self.file_out.write("yield from render%d(%s)\n" % (inc_id, args)) + elif len(tokens) > 1: + if tokens[0] == "elif": + assert self.stack[-1] == "if" + self.indent(-1) + self.file_out.write(stmt + ":\n") + else: + self.indent() + self.file_out.write(stmt + ":\n") + self.stack.append(tokens[0]) + else: + if stmt.startswith("end"): + assert self.stack[-1] == stmt[3:] + self.stack.pop(-1) + elif stmt == "else": + assert self.stack[-1] == "if" + self.indent(-1) + self.file_out.write("else:\n") + else: + assert False + + def parse_line(self, l): + while l: + start = l.find(self.START_CHAR) + if start == -1: + self.literal(l) + return + self.literal(l[:start]) + self.close_literal() + sel = l[start + 1] + #print("*%s=%s=" % (sel, EXPR)) + if sel == self.STMNT: + end = l.find(self.STMNT_END) + assert end > 0 + stmt = l[start + len(self.START_CHAR + self.STMNT):end].strip() + self.parse_statement(stmt) + end += len(self.STMNT_END) + l = l[end:] + if not self.in_literal and l == "\n": + break + elif sel == self.EXPR: + # print("EXPR") + end = l.find(self.EXPR_END) + assert end > 0 + expr = l[start + len(self.START_CHAR + self.EXPR):end].strip() + self.render_expr(expr) + end += len(self.EXPR_END) + l = l[end:] + else: + self.literal(l[start]) + l = l[start + 1:] + + def header(self): + self.file_out.write("# Autogenerated file\n") + + def compile(self): + self.header() + for l in self.file_in: + self.parse_line(l) + self.close_literal() + return self.seq + + +class Loader(compiled.Loader): + + def __init__(self, pkg, dir): + super().__init__(pkg, dir) + self.dir = dir + if pkg == "__main__": + # if pkg isn't really a package, don't bother to use it + # it means we're running from "filesystem directory", not + # from a package. + pkg = None + + self.pkg_path = "" + if pkg: + p = __import__(pkg) + if isinstance(p.__path__, str): + # uPy + self.pkg_path = p.__path__ + else: + # CPy + self.pkg_path = p.__path__[0] + self.pkg_path += "/" + + def input_open(self, template): + path = self.pkg_path + self.dir + "/" + template + return open(path) + + def compiled_path(self, template): + return self.dir + "/" + template.replace(".", "_") + ".py" + + def load(self, name): + try: + return super().load(name) + except (OSError, ImportError): + pass + + compiled_path = self.pkg_path + self.compiled_path(name) + + f_in = self.input_open(name) + f_out = open(compiled_path, "w") + c = Compiler(f_in, f_out, loader=self) + c.compile() + f_in.close() + f_out.close() + return super().load(name) \ No newline at end of file diff --git a/src/controller_messages.py b/src/controller_messages.py new file mode 100644 index 0000000..9187c0c --- /dev/null +++ b/src/controller_messages.py @@ -0,0 +1,218 @@ +"""Parse controller JSON (v1) and apply brightness, presets, OTA patterns, etc.""" + +import json +import socket + +from utils import convert_and_reorder_colors + +try: + import uos as os +except ImportError: + import os + + +def process_data(payload, settings, presets, controller_ip=None): + """Read one controller message; json.loads (bytes or str), then apply fields.""" + try: + data = json.loads(payload) + print(payload) + if data.get("v", "") != "1": + return + except (ValueError, TypeError): + return + if "b" in data: + apply_brightness(data, settings, presets) + if "presets" in data: + apply_presets(data, settings, presets) + if "select" in data: + apply_select(data, settings, presets) + if "default" in data: + apply_default(data, settings, presets) + if "manifest" in data: + apply_patterns_ota(data, presets, controller_ip=controller_ip) + if "save" in data and ("presets" in data or "default" in data): + presets.save() + + +def apply_brightness(data, settings, presets): + try: + presets.b = max(0, min(255, int(data["b"]))) + settings["brightness"] = presets.b + except (TypeError, ValueError): + pass + + +def apply_presets(data, settings, presets): + presets_map = data["presets"] + for id, preset_data in presets_map.items(): + if not preset_data: + continue + color_key = "c" if "c" in preset_data else ("colors" if "colors" in preset_data else None) + if color_key is not None: + try: + preset_data[color_key] = convert_and_reorder_colors( + preset_data[color_key], settings + ) + except (TypeError, ValueError, KeyError): + continue + presets.edit(id, preset_data) + print(f"Edited preset {id}: {preset_data.get('name', '')}") + + +def apply_select(data, settings, presets): + select_map = data["select"] + device_name = settings["name"] + select_list = select_map.get(device_name, []) + if not select_list: + return + preset_name = select_list[0] + step = select_list[1] if len(select_list) > 1 else None + presets.select(preset_name, step=step) + + +def apply_default(data, settings, presets): + targets = data.get("targets") or [] + default_name = data["default"] + if ( + settings["name"] in targets + and isinstance(default_name, str) + and default_name in presets.presets + ): + settings["default"] = default_name + + +def _parse_http_url(url): + """Parse http://host[:port]/path into (host, port, path).""" + if not isinstance(url, str): + raise ValueError("url must be a string") + if not url.startswith("http://"): + raise ValueError("only http:// URLs are supported") + remainder = url[7:] + slash_idx = remainder.find("/") + if slash_idx == -1: + host_port = remainder + path = "/" + else: + host_port = remainder[:slash_idx] + path = remainder[slash_idx:] + if ":" in host_port: + host, port_s = host_port.rsplit(":", 1) + port = int(port_s) + else: + host = host_port + port = 80 + if not host: + raise ValueError("missing host") + return host, port, path + + +def _http_get_raw(url, timeout_s=10.0): + host, port, path = _parse_http_url(url) + req = ( + "GET %s HTTP/1.1\r\nHost: %s\r\nConnection: close\r\n\r\n" % (path, host) + ).encode("utf-8") + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + try: + sock.settimeout(timeout_s) + sock.connect((host, int(port))) + sock.send(req) + data = b"" + while True: + chunk = sock.recv(1024) + if not chunk: + break + data += chunk + finally: + try: + sock.close() + except Exception: + pass + sep = b"\r\n\r\n" + if sep not in data: + raise OSError("invalid HTTP response") + head, body = data.split(sep, 1) + status_line = head.split(b"\r\n", 1)[0] + if b" 200 " not in status_line: + raise OSError("HTTP status not OK: %s" % status_line.decode("utf-8")) + return body + + +def _http_get_json(url, timeout_s=10.0): + body = _http_get_raw(url, timeout_s=timeout_s) + return json.loads(body.decode("utf-8")) + + +def _http_get_text(url, timeout_s=10.0, controller_ip=None): + # Support relative URLs from controller messages. + if isinstance(url, str) and url.startswith("/"): + if not controller_ip: + raise OSError("controller IP unavailable for relative URL") + url = "http://%s%s" % (controller_ip, url) + try: + body = _http_get_raw(url, timeout_s=timeout_s) + return body.decode("utf-8") + except Exception: + # Fallback for mDNS/unresolvable host: retry against current controller IP. + if not controller_ip or not isinstance(url, str) or not url.startswith("http://"): + raise + _host, _port, path = _parse_http_url(url) + fallback = "http://%s:%d%s" % (controller_ip, _port, path) + body = _http_get_raw(fallback, timeout_s=timeout_s) + return body.decode("utf-8") + + +def _safe_pattern_filename(name): + if not isinstance(name, str): + return False + if not name.endswith(".py"): + return False + if "/" in name or "\\" in name or ".." in name: + return False + return True + + +def apply_patterns_ota(data, presets, controller_ip=None): + manifest_payload = data.get("manifest") + if not manifest_payload: + return + try: + if isinstance(manifest_payload, dict): + manifest = manifest_payload + elif isinstance(manifest_payload, str): + manifest = _http_get_json(manifest_payload, timeout_s=20.0) + else: + print("patterns_ota: invalid manifest payload type") + return + files = manifest.get("files", []) + if not isinstance(files, list) or not files: + print("patterns_ota: no files in manifest") + return + try: + os.mkdir("patterns") + except OSError: + pass + updated = 0 + for item in files: + if not isinstance(item, dict): + continue + name = item.get("name") + url = item.get("url") + inline_code = item.get("code") + if not _safe_pattern_filename(name): + continue + if isinstance(inline_code, str): + code = inline_code + elif isinstance(url, str): + code = _http_get_text(url, timeout_s=20.0, controller_ip=controller_ip) + else: + continue + with open("patterns/" + name, "w") as f: + f.write(code) + updated += 1 + if updated > 0: + presets.reload_patterns() + print("patterns_ota: updated", updated, "pattern file(s)") + else: + print("patterns_ota: no valid files downloaded") + except Exception as e: + print("patterns_ota failed:", e) diff --git a/src/hello.py b/src/hello.py index 1c6e16e..155cec2 100644 --- a/src/hello.py +++ b/src/hello.py @@ -1,4 +1,8 @@ -"""LED hello payload and UDP broadcast discovery (controller IP via echo on port 8766). +"""LED hello JSON line and UDP broadcast on port 8766. + +Used so led-controller can register the device (name, MAC, IP) when ``wait_reply`` is +false; the controller may then connect to the device's WebSocket. With +``wait_reply`` true, blocks for an echo and returns the controller IP (legacy discovery). Wi-Fi must already be connected; this module does not use Settings or call connect(). """ @@ -40,7 +44,13 @@ def ipv4_broadcast(ip, netmask): im = [int(x) for x in netmask.split(".")] if len(ia) != 4 or len(im) != 4: return None - return ".".join(str(ia[i] | (255 - im[i])) for i in range(4)) + # STA often reports 255.255.255.255; "broadcast" would equal the host IP — useless for LAN. + if netmask == "255.255.255.255": + return None + bcast = ".".join(str(ia[i] | (255 - im[i])) for i in range(4)) + if bcast == ip: + return None + return bcast def udp_discovery_targets(ip, mask): @@ -52,6 +62,14 @@ def udp_discovery_targets(ip, mask): return out +def _udp_discovery_targets_single(ip, mask): + """One destination: subnet broadcast if known, else limited broadcast.""" + b = ipv4_broadcast(ip, mask) + if b: + return [(b, DISCOVERY_UDP_PORT)] + return [("255.255.255.255", DISCOVERY_UDP_PORT)] + + def broadcast_hello_udp( sta, device_name="", @@ -59,11 +77,17 @@ def broadcast_hello_udp( wait_reply=True, recv_timeout_s=DEFAULT_RECV_TIMEOUT_S, wdt=None, + dual_destinations=True, ): """ - Send pack_hello_line via directed then 255.255.255.255 on DISCOVERY_UDP_PORT. + Send pack_hello_line on DISCOVERY_UDP_PORT. STA must already be connected with a valid IPv4 (caller brings up Wi-Fi). + If dual_destinations (default), send subnet broadcast then 255.255.255.255 so + discovery works on awkward APs — the controller may receive two packets. + If dual_destinations is False, send only one (subnet broadcast or limited), + e.g. after TCP connect so the Pi does not run duplicate resync handlers. + If wait_reply, wait for first UDP echo. Returns controller IP string or None. """ ip, mask, _gw, _dns = sta.ifconfig() @@ -89,7 +113,12 @@ def broadcast_hello_udp( pass discovered = None - for dest_ip, dest_port in udp_discovery_targets(ip, mask): + targets = ( + udp_discovery_targets(ip, mask) + if dual_destinations + else _udp_discovery_targets_single(ip, mask) + ) + for dest_ip, dest_port in targets: if wdt is not None: wdt.feed() label = "%s:%s" % (dest_ip, dest_port) diff --git a/src/http_poll.py b/src/http_poll.py deleted file mode 100644 index 035f1f1..0000000 --- a/src/http_poll.py +++ /dev/null @@ -1,68 +0,0 @@ -"""Minimal HTTP/1.1 POST JSON client for driver long-poll (MicroPython).""" - -import json -import socket - - -def _send_all(sock, data): - n = 0 - while n < len(data): - m = sock.send(data[n:]) - if m <= 0: - raise OSError("socket send failed") - n += m - - -def _read_http_json_body(sock, max_headers=8192): - buf = b"" - while b"\r\n\r\n" not in buf: - chunk = sock.recv(256) - if not chunk: - break - buf += chunk - if len(buf) > max_headers: - raise OSError("response headers too large") - if b"\r\n\r\n" not in buf: - raise OSError("incomplete response headers") - head, rest = buf.split(b"\r\n\r\n", 1) - cl = None - for line in head.split(b"\r\n"): - if line.lower().startswith(b"content-length:"): - try: - cl = int(line.split(b":", 1)[1].strip()) - except (ValueError, IndexError): - cl = None - if cl is None: - body = rest - else: - body = rest - while len(body) < cl: - chunk = sock.recv(min(2048, cl - len(body))) - if not chunk: - break - body += chunk - return json.loads(body.decode("utf-8")) - - -def http_driver_poll(host, port, payload_dict, timeout_s=40.0): - """ - POST ``/driver/v1/poll`` with JSON body; return parsed JSON (expects ``{"lines": [...]}``). - """ - path = "/driver/v1/poll" - body_bytes = json.dumps(payload_dict).encode("utf-8") - host_s = str(host) - req_head = ( - "POST %s HTTP/1.1\r\nHost: %s\r\nContent-Type: application/json\r\nContent-Length: %d\r\nConnection: close\r\n\r\n" - % (path, host_s, len(body_bytes)) - ).encode("utf-8") - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - try: - sock.settimeout(timeout_s) - sock.connect((host_s, int(port))) - _send_all(sock, req_head + body_bytes) - return _read_http_json_body(sock) - finally: - try: - sock.close() - except Exception: - pass diff --git a/src/main.py b/src/main.py index 1cde978..fa02cc3 100644 --- a/src/main.py +++ b/src/main.py @@ -1,23 +1,13 @@ from settings import Settings from machine import WDT -import utime import network +import utime +import asyncio +from microdot import Microdot +from microdot.websocket import WebSocketError, with_websocket from presets import Presets -from utils import convert_and_reorder_colors -import json -import time -import select -import socket -import ubinascii -from hello import discover_controller_udp -try: - import uos as os -except ImportError: - import os - -BROADCAST_MAC = b"\xff\xff\xff\xff\xff\xff" -CONTROLLER_TCP_PORT = 8765 -controller_ip = None +from controller_messages import process_data +from hello import broadcast_hello_udp settings = Settings() print(settings) @@ -27,364 +17,74 @@ presets.load(settings) presets.b = settings.get("brightness", 255) default_preset = settings.get("default", "") if default_preset and default_preset in presets.presets: - presets.select(default_preset) - print(f"Selected startup preset: {default_preset}") + if presets.select(default_preset): + print(f"Selected startup preset: {default_preset}") + else: + print("Startup preset failed (invalid pattern?):", default_preset) wdt = WDT(timeout=10000) wdt.feed() - -# --- Controller JSON (bytes or str): parse v1, then apply ------------------------- - - -def process_data(payload): - """Read one controller message; json.loads (bytes or str), then apply fields.""" - try: - data = json.loads(payload) - print(payload) - if data.get("v", "") != "1": - return - except (ValueError, TypeError): - return - if "b" in data: - apply_brightness(data) - if "presets" in data: - apply_presets(data) - if "select" in data: - apply_select(data) - if "default" in data: - apply_default(data) - if "manifest" in data: - apply_patterns_ota(data) - if "save" in data and ("presets" in data or "default" in data): - presets.save() - - -def apply_brightness(data): - try: - presets.b = max(0, min(255, int(data["b"]))) - settings["brightness"] = presets.b - except (TypeError, ValueError): - pass - - -def apply_presets(data): - presets_map = data["presets"] - for id, preset_data in presets_map.items(): - if not preset_data: - continue - color_key = "c" if "c" in preset_data else ("colors" if "colors" in preset_data else None) - if color_key is not None: - try: - preset_data[color_key] = convert_and_reorder_colors( - preset_data[color_key], settings - ) - except (TypeError, ValueError, KeyError): - continue - presets.edit(id, preset_data) - print(f"Edited preset {id}: {preset_data.get('name', '')}") - - -def apply_select(data): - select_map = data["select"] - device_name = settings["name"] - select_list = select_map.get(device_name, []) - if not select_list: - return - preset_name = select_list[0] - step = select_list[1] if len(select_list) > 1 else None - presets.select(preset_name, step=step) - - -def apply_default(data): - targets = data.get("targets") or [] - default_name = data["default"] - if ( - settings["name"] in targets - and isinstance(default_name, str) - and default_name in presets.presets - ): - settings["default"] = default_name - - -def _parse_http_url(url): - """Parse http://host[:port]/path into (host, port, path).""" - if not isinstance(url, str): - raise ValueError("url must be a string") - if not url.startswith("http://"): - raise ValueError("only http:// URLs are supported") - remainder = url[7:] - slash_idx = remainder.find("/") - if slash_idx == -1: - host_port = remainder - path = "/" - else: - host_port = remainder[:slash_idx] - path = remainder[slash_idx:] - if ":" in host_port: - host, port_s = host_port.rsplit(":", 1) - port = int(port_s) - else: - host = host_port - port = 80 - if not host: - raise ValueError("missing host") - return host, port, path - - -def _http_get_raw(url, timeout_s=10.0): - host, port, path = _parse_http_url(url) - req = ( - "GET %s HTTP/1.1\r\nHost: %s\r\nConnection: close\r\n\r\n" % (path, host) - ).encode("utf-8") - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - try: - sock.settimeout(timeout_s) - sock.connect((host, int(port))) - sock.send(req) - data = b"" - while True: - chunk = sock.recv(1024) - if not chunk: - break - data += chunk - finally: - try: - sock.close() - except Exception: - pass - sep = b"\r\n\r\n" - if sep not in data: - raise OSError("invalid HTTP response") - head, body = data.split(sep, 1) - status_line = head.split(b"\r\n", 1)[0] - if b" 200 " not in status_line: - raise OSError("HTTP status not OK: %s" % status_line.decode("utf-8")) - return body - - -def _http_get_json(url, timeout_s=10.0): - body = _http_get_raw(url, timeout_s=timeout_s) - return json.loads(body.decode("utf-8")) - - -def _http_get_text(url, timeout_s=10.0): - global controller_ip - # Support relative URLs from controller messages. - if isinstance(url, str) and url.startswith("/"): - if not controller_ip: - raise OSError("controller IP unavailable for relative URL") - url = "http://%s%s" % (controller_ip, url) - try: - body = _http_get_raw(url, timeout_s=timeout_s) - return body.decode("utf-8") - except Exception: - # Fallback for mDNS/unresolvable host: retry against current controller IP. - if not controller_ip or not isinstance(url, str) or not url.startswith("http://"): - raise - _host, _port, path = _parse_http_url(url) - fallback = "http://%s:%d%s" % (controller_ip, _port, path) - body = _http_get_raw(fallback, timeout_s=timeout_s) - return body.decode("utf-8") - - -def _safe_pattern_filename(name): - if not isinstance(name, str): - return False - if not name.endswith(".py"): - return False - if "/" in name or "\\" in name or ".." in name: - return False - return True - - -def apply_patterns_ota(data): - manifest_payload = data.get("manifest") - if not manifest_payload: - return - try: - if isinstance(manifest_payload, dict): - manifest = manifest_payload - elif isinstance(manifest_payload, str): - manifest = _http_get_json(manifest_payload, timeout_s=20.0) - else: - print("patterns_ota: invalid manifest payload type") - return - files = manifest.get("files", []) - if not isinstance(files, list) or not files: - print("patterns_ota: no files in manifest") - return - try: - os.mkdir("patterns") - except OSError: - pass - updated = 0 - for item in files: - if not isinstance(item, dict): - continue - name = item.get("name") - url = item.get("url") - inline_code = item.get("code") - if not _safe_pattern_filename(name): - continue - if isinstance(inline_code, str): - code = inline_code - elif isinstance(url, str): - code = _http_get_text(url, timeout_s=20.0) - else: - continue - with open("patterns/" + name, "w") as f: - f.write(code) - updated += 1 - if updated > 0: - presets.reload_patterns() - print("patterns_ota: updated", updated, "pattern file(s)") - else: - print("patterns_ota: no valid files downloaded") - except Exception as e: - print("patterns_ota failed:", e) - - -# --- TCP framing (bytes) → process_data ------------------------------------------- - - -def tcp_append_and_drain_lines(buf, chunk): - """Return (new_buf, list of non-empty stripped line byte strings).""" - buf += chunk - lines = [] - while b"\n" in buf: - line, buf = buf.split(b"\n", 1) - line = line.strip() - if line: - lines.append(line) - return buf, lines - - -# --- Network + hello -------------------------------------------------------------- - sta_if = network.WLAN(network.STA_IF) sta_if.active(True) sta_if.config(pm=network.WLAN.PM_NONE) +sta_if.connect(settings["ssid"], settings["password"]) +while not sta_if.isconnected(): + utime.sleep(1) + wdt.feed() -mac = sta_if.config("mac") -hello_payload = { - "v": "1", - "device_name": settings.get("name", ""), - "mac": ubinascii.hexlify(mac).decode().lower(), - "type": "led", -} -hello_bytes = json.dumps(hello_payload).encode("utf-8") +print(sta_if.ifconfig()) -if settings["transport_type"] == "espnow": - from espnow import ESPNow # import only in this branch (avoids load when using Wi-Fi) +app = Microdot() - sta_if.disconnect() - sta_if.config(channel=settings.get("wifi_channel", 1)) - e = ESPNow() - e.active(True) - e.add_peer(BROADCAST_MAC) - e.add_peer(mac) - e.send(BROADCAST_MAC, hello_bytes) + +@app.route("/ws") +@with_websocket +async def ws_handler(request, ws): + print("WS client connected") + try: + while True: + data = await ws.receive() + if not data: + print("WS client disconnected (closed)") + break + print(data) + process_data(data, settings, presets) + except WebSocketError as e: + print("WS client disconnected:", e) + except OSError as e: + print("WS client dropped (OSError):", e) + + +async def presets_loop(): while True: - if e.any(): - _peer, msg = e.recv() - if msg: - process_data(msg) - presets.tick() + await presets.tick() wdt.feed() + # tick() does not await; yield so UDP hello and HTTP/WebSocket can run. + await asyncio.sleep(0) -elif settings["transport_type"] == "wifi": - sta_if.connect(settings["ssid"], settings["password"]) - while not sta_if.isconnected(): - time.sleep(1) - print(f"WiFi connected {sta_if.ifconfig()[0]}") - controller_ip = discover_controller_udp( - device_name=settings.get("name", ""), - wdt=wdt, - ) - if not controller_ip: - raise SystemExit("No controller IP discovered for Wi-Fi transport") - def pick_controller_ip(current): - ip = discover_controller_udp( - device_name=settings.get("name", ""), +async def _udp_hello_after_http_ready(): + """Hello must run after the HTTP server binds, or discovery clients time out on /ws.""" + await asyncio.sleep(1) + print("UDP hello: broadcasting…") + try: + broadcast_hello_udp( + sta_if, + settings.get("name", ""), + wait_reply=False, wdt=wdt, + dual_destinations=True, ) - if ip and ip != current: - print("Controller IP updated to", ip) - return ip if ip else current + except Exception as ex: + print("UDP hello broadcast failed:", ex) - reconnect_ms = 1000 - next_connect_at = 0 - client = None - poller = None - buf = b"" - while True: - now = utime.ticks_ms() +async def main(port=80): + asyncio.create_task(presets_loop()) + asyncio.create_task(_udp_hello_after_http_ready()) + await app.start_server(host="0.0.0.0", port=port) - if client is None and utime.ticks_diff(now, next_connect_at) >= 0: - c = None - try: - c = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - c.connect((controller_ip, CONTROLLER_TCP_PORT)) - c.setblocking(False) - p = select.poll() - p.register(c, select.POLLIN) - client = c - poller = p - buf = b"" - print("TCP connected") - except Exception: - if c is not None: - try: - c.close() - except Exception: - pass - controller_ip = pick_controller_ip(controller_ip) - next_connect_at = utime.ticks_add(now, reconnect_ms) - if client is not None and poller is not None: - try: - events = poller.poll(0) - except Exception: - events = [] - - reconnect_needed = False - for fd, event in events: - if (event & select.POLLHUP) or (event & select.POLLERR): - reconnect_needed = True - break - if event & select.POLLIN: - try: - chunk = client.recv(512) - except OSError: - reconnect_needed = True - break - - if not chunk: - reconnect_needed = True - break - - buf, lines = tcp_append_and_drain_lines(buf, chunk) - for raw_line in lines: - process_data(raw_line) - - if reconnect_needed: - print("TCP disconnected, reconnecting...") - try: - poller.unregister(client) - except Exception: - pass - try: - client.close() - except Exception: - pass - client = None - poller = None - buf = b"" - controller_ip = pick_controller_ip(controller_ip) - next_connect_at = utime.ticks_add(now, reconnect_ms) - - presets.tick() - wdt.feed() +if __name__ == "__main__": + asyncio.run(main(port=80)) diff --git a/src/patterns/main.py b/src/patterns/main.py new file mode 100644 index 0000000..3dba628 --- /dev/null +++ b/src/patterns/main.py @@ -0,0 +1,136 @@ +import os +import sys + +from settings import Settings +from machine import WDT +import network +import utime +import asyncio +from microdot import Microdot +from microdot.websocket import WebSocketError, with_websocket +from presets import Presets +from controller_messages import process_data +from hello import broadcast_hello_udp + +settings = Settings() +print(settings) + +presets = Presets(settings["led_pin"], settings["num_leds"]) +presets.load(settings) +presets.b = settings.get("brightness", 255) +default_preset = settings.get("default", "") +if default_preset and default_preset in presets.presets: + if presets.select(default_preset): + print(f"Selected startup preset: {default_preset}") + else: + print("Startup preset failed (invalid pattern?):", default_preset) + +wdt = WDT(timeout=10000) +wdt.feed() + +sta_if = network.WLAN(network.STA_IF) +sta_if.active(True) +sta_if.config(pm=network.WLAN.PM_NONE) +sta_if.connect(settings["ssid"], settings["password"]) +while not sta_if.isconnected(): + utime.sleep(1) + wdt.feed() + +app = Microdot() + + +def _simulator_register_microdot_app(): + """led-simulator sets LED_SIM_ROOT so Stop can shutdown() the same Microdot instance.""" + root = os.environ.get("LED_SIM_ROOT") + if not root: + return + sys.path.insert(0, root) + try: + import led_driver_sim_hook as _sim_hook + except ImportError: + return + _sim_hook.register_app(app) + + +_simulator_register_microdot_app() + + +@app.route("/ws") +@with_websocket +async def ws_handler(request, ws): + print("WS client connected") + try: + while True: + data = await ws.receive() + if not data: + print("WS client disconnected (closed)") + break + print(data) + process_data(data, settings, presets) + except WebSocketError as e: + print("WS client disconnected:", e) + except OSError as e: + print("WS client dropped (OSError):", e) + + +async def presets_loop(): + while True: + await presets.tick() + wdt.feed() + # tick() does not await; yield so UDP hello and HTTP/WebSocket can run. + await asyncio.sleep(0) + + +async def _udp_hello_after_http_ready(): + """Hello must run after the HTTP server binds, or discovery clients time out on /ws.""" + await asyncio.sleep(1) + print("UDP hello: broadcasting…") + try: + broadcast_hello_udp( + sta_if, + settings.get("name", ""), + wait_reply=False, + wdt=wdt, + dual_destinations=True, + ) + except Exception as ex: + print("UDP hello broadcast failed:", ex) + + +async def main(port=80): + t_presets = asyncio.create_task(presets_loop()) + t_hello = asyncio.create_task(_udp_hello_after_http_ready()) + try: + await app.start_server(host="0.0.0.0", port=port) + finally: + for t in (t_presets, t_hello): + t.cancel() + try: + await t + except asyncio.CancelledError: + pass + + +def _simulator_apply_pattern_from_env(): + """led-simulator sets LED_SIM_PATTERN to a patterns/ module name (no .py).""" + mod = os.environ.get("LED_SIM_PATTERN", "").strip() + if not mod: + return + presets.reload_patterns() + presets.edit( + "_sim", + { + "p": mod, + "d": 200, + "b": 255, + "c": [(255, 0, 0), (0, 255, 0), (0, 0, 255)], + }, + ) + if not presets.select("_sim"): + print("LED_SIM_PATTERN: could not select pattern:", mod) + + +if __name__ == "__main__": + _simulator_apply_pattern_from_env() + _port = int(os.environ.get("LED_SIM_PORT", "80")) + asyncio.run(main(port=_port)) \ No newline at end of file diff --git a/src/settings.py b/src/settings.py index b2f15ea..40a3886 100644 --- a/src/settings.py +++ b/src/settings.py @@ -21,10 +21,9 @@ class Settings(dict): self["debug"] = False self["default"] = "on" self["brightness"] = 32 - self["transport_type"] = "wifi" + self["transport_type"] = "espnow" self["wifi_channel"] = 1 - # Wi-Fi + TCP to controller: set ssid and password. Use transport_type "espnow" - # for ESP-NOW (requires espnow firmware). + # ESP-NOW transport (requires espnow firmware; uses wifi_channel). self["ssid"] = "" self["password"] = "" diff --git a/test/all.py b/tests/all.py similarity index 99% rename from test/all.py rename to tests/all.py index f6d6040..a637f3e 100644 --- a/test/all.py +++ b/tests/all.py @@ -7,7 +7,7 @@ import utime from machine import WDT from settings import Settings -from presets import Presets +from presets import Presets, run_tick from utils import convert_and_reorder_colors @@ -23,7 +23,7 @@ class _TestContext: start = utime.ticks_ms() while utime.ticks_diff(utime.ticks_ms(), start) < duration_ms: self.wdt.feed() - self.presets.tick() + run_tick(self.presets) utime.sleep_ms(sleep_ms) diff --git a/test/patterns/auto_manual.py b/tests/patterns/auto_manual.py similarity index 97% rename from test/patterns/auto_manual.py rename to tests/patterns/auto_manual.py index a5ec964..e779b79 100644 --- a/test/patterns/auto_manual.py +++ b/tests/patterns/auto_manual.py @@ -2,7 +2,7 @@ import utime from machine import WDT from settings import Settings -from presets import Presets +from presets import Presets, run_tick def run_for(p, wdt, duration_ms): @@ -10,7 +10,7 @@ def run_for(p, wdt, duration_ms): start = utime.ticks_ms() while utime.ticks_diff(utime.ticks_ms(), start) < duration_ms: wdt.feed() - p.tick() + run_tick(p) utime.sleep_ms(10) @@ -52,7 +52,7 @@ def main(): p.select("rainbow_manual") print("Calling tick() 5 times (should advance 5 steps)...") for i in range(5): - p.tick() + run_tick(p) utime.sleep_ms(100) # Small delay to see changes print(f" Tick {i+1}: generator={'active' if p.generator is not None else 'stopped'}") @@ -96,7 +96,7 @@ def main(): tick_count = 0 max_ticks = 200 # Safety limit while p.generator is not None and tick_count < max_ticks: - p.tick() + run_tick(p) tick_count += 1 utime.sleep_ms(10) @@ -133,7 +133,7 @@ def main(): tick_count = 0 max_ticks = 200 while p.generator is not None and tick_count < max_ticks: - p.tick() + run_tick(p) tick_count += 1 utime.sleep_ms(10) @@ -162,7 +162,7 @@ def main(): print("Calling tick() 3 times in manual mode...") for i in range(3): - p.tick() + run_tick(p) utime.sleep_ms(100) print(f" Tick {i+1}: generator={'active' if p.generator is not None else 'stopped'}") @@ -178,7 +178,7 @@ def main(): print("\nCleaning up...") p.edit("cleanup_off", {"p": "off"}) p.select("cleanup_off") - p.tick() + run_tick(p) utime.sleep_ms(100) print("\n" + "=" * 50) diff --git a/test/patterns/blink.py b/tests/patterns/blink.py similarity index 92% rename from test/patterns/blink.py rename to tests/patterns/blink.py index 7291dde..124de8f 100644 --- a/test/patterns/blink.py +++ b/tests/patterns/blink.py @@ -2,7 +2,7 @@ import utime from machine import WDT from settings import Settings -from presets import Presets +from presets import Presets, run_tick def main(): @@ -25,7 +25,7 @@ def main(): start = utime.ticks_ms() while utime.ticks_diff(utime.ticks_ms(), start) < 1500: wdt.feed() - p.tick() + run_tick(p) utime.sleep_ms(10) diff --git a/test/patterns/chase.py b/tests/patterns/chase.py similarity index 97% rename from test/patterns/chase.py rename to tests/patterns/chase.py index 39c2618..2e79c49 100644 --- a/test/patterns/chase.py +++ b/tests/patterns/chase.py @@ -2,7 +2,7 @@ import utime from machine import WDT from settings import Settings -from presets import Presets +from presets import Presets, run_tick def run_for(p, wdt, ms): @@ -10,7 +10,7 @@ def run_for(p, wdt, ms): start = utime.ticks_ms() while utime.ticks_diff(utime.ticks_ms(), start) < ms: wdt.feed() - p.tick() + run_tick(p) utime.sleep_ms(10) @@ -123,7 +123,7 @@ def main(): print(" Advancing pattern with 10 beats (select + tick)...") for i in range(10): p.select("chase_manual") # Simulate beat - restarts generator - p.tick() # Advance one step + run_tick(p) # Advance one step utime.sleep_ms(500) # Pause to see the pattern wdt.feed() print(f" Beat {i+1}: step={p.step}") @@ -141,7 +141,7 @@ def main(): p.step = 0 initial_step = p.step p.select("chase_manual2") - p.tick() + run_tick(p) final_step = p.step print(f" Step updated from {initial_step} to {final_step} (expected: 1)") if final_step == 1: diff --git a/test/patterns/circle.py b/tests/patterns/circle.py similarity index 98% rename from test/patterns/circle.py rename to tests/patterns/circle.py index 2de9d8d..d84456f 100644 --- a/test/patterns/circle.py +++ b/tests/patterns/circle.py @@ -2,7 +2,7 @@ import utime from machine import WDT from settings import Settings -from presets import Presets +from presets import Presets, run_tick def run_for(p, wdt, ms): @@ -10,7 +10,7 @@ def run_for(p, wdt, ms): start = utime.ticks_ms() while utime.ticks_diff(utime.ticks_ms(), start) < ms: wdt.feed() - p.tick() + run_tick(p) utime.sleep_ms(10) diff --git a/test/patterns/off.py b/tests/patterns/off.py similarity index 90% rename from test/patterns/off.py rename to tests/patterns/off.py index e85a701..bf7574b 100644 --- a/test/patterns/off.py +++ b/tests/patterns/off.py @@ -2,7 +2,7 @@ import utime from machine import WDT from settings import Settings -from presets import Presets +from presets import Presets, run_tick def main(): @@ -20,7 +20,7 @@ def main(): start = utime.ticks_ms() while utime.ticks_diff(utime.ticks_ms(), start) < 200: wdt.feed() - p.tick() + run_tick(p) utime.sleep_ms(10) diff --git a/test/patterns/on.py b/tests/patterns/on.py similarity index 92% rename from test/patterns/on.py rename to tests/patterns/on.py index 44c82c1..1a3aea1 100644 --- a/test/patterns/on.py +++ b/tests/patterns/on.py @@ -2,7 +2,7 @@ import utime from machine import WDT from settings import Settings -from presets import Presets +from presets import Presets, run_tick def main(): @@ -29,7 +29,7 @@ def main(): start = utime.ticks_ms() while utime.ticks_diff(utime.ticks_ms(), start) < 800: wdt.feed() - p.tick() + run_tick(p) utime.sleep_ms(10) # OFF phase @@ -37,7 +37,7 @@ def main(): start = utime.ticks_ms() while utime.ticks_diff(utime.ticks_ms(), start) < 100: wdt.feed() - p.tick() + run_tick(p) utime.sleep_ms(10) diff --git a/test/patterns/pulse.py b/tests/patterns/pulse.py similarity index 97% rename from test/patterns/pulse.py rename to tests/patterns/pulse.py index 708b112..f1793dc 100644 --- a/test/patterns/pulse.py +++ b/tests/patterns/pulse.py @@ -2,7 +2,7 @@ import utime from machine import WDT from settings import Settings -from presets import Presets +from presets import Presets, run_tick def run_for(p, wdt, ms): @@ -10,7 +10,7 @@ def run_for(p, wdt, ms): start = utime.ticks_ms() while utime.ticks_diff(utime.ticks_ms(), start) < ms: wdt.feed() - p.tick() + run_tick(p) utime.sleep_ms(10) diff --git a/test/patterns/rainbow.py b/tests/patterns/rainbow.py similarity index 97% rename from test/patterns/rainbow.py rename to tests/patterns/rainbow.py index 7773371..e0d0c8a 100644 --- a/test/patterns/rainbow.py +++ b/tests/patterns/rainbow.py @@ -2,7 +2,7 @@ import utime from machine import WDT from settings import Settings -from presets import Presets +from presets import Presets, run_tick def run_for(p, wdt, ms): @@ -10,7 +10,7 @@ def run_for(p, wdt, ms): start = utime.ticks_ms() while utime.ticks_diff(utime.ticks_ms(), start) < ms: wdt.feed() - p.tick() + run_tick(p) utime.sleep_ms(10) @@ -81,7 +81,7 @@ def main(): for i in range(10): p.select("rainbow5") # One tick advances the generator one frame when auto=False - p.tick() + run_tick(p) utime.sleep_ms(100) wdt.feed() @@ -94,7 +94,7 @@ def main(): }) initial_step = p.step p.select("rainbow6") - p.tick() + run_tick(p) final_step = p.step print(f"Step updated from {initial_step} to {final_step} (expected increment: 1)") @@ -130,7 +130,7 @@ def main(): p.step = 0 initial_step = p.step p.select("rainbow9") - p.tick() + run_tick(p) final_step = p.step expected_step = (initial_step + 5) % 256 print(f"Step updated from {initial_step} to {final_step} (expected: {expected_step})") diff --git a/test/patterns/transition.py b/tests/patterns/transition.py similarity index 97% rename from test/patterns/transition.py rename to tests/patterns/transition.py index 00149c0..8ff981e 100644 --- a/test/patterns/transition.py +++ b/tests/patterns/transition.py @@ -2,7 +2,7 @@ import utime from machine import WDT from settings import Settings -from presets import Presets +from presets import Presets, run_tick def run_for(p, wdt, ms): @@ -10,7 +10,7 @@ def run_for(p, wdt, ms): start = utime.ticks_ms() while utime.ticks_diff(utime.ticks_ms(), start) < ms: wdt.feed() - p.tick() + run_tick(p) utime.sleep_ms(10) diff --git a/test/test_espnow_receive.py b/tests/test_espnow_receive.py similarity index 99% rename from test/test_espnow_receive.py rename to tests/test_espnow_receive.py index 10af390..4c3cb75 100644 --- a/test/test_espnow_receive.py +++ b/tests/test_espnow_receive.py @@ -4,7 +4,7 @@ import json import os import utime from settings import Settings -from presets import Presets +from presets import Presets, run_tick from utils import convert_and_reorder_colors @@ -54,7 +54,7 @@ def run_main_loop_iterations(espnow, patterns, settings, wdt, max_iterations=10) while iterations < max_iterations: wdt.feed() - patterns.tick() + run_tick(patterns) if espnow.any(): host, msg = espnow.recv() @@ -363,7 +363,7 @@ def test_switch_presets(): start = utime.ticks_ms() while utime.ticks_diff(utime.ticks_ms(), start) < 2000: wdt.feed() - patterns.tick() + run_tick(patterns) utime.sleep_ms(10) # Switch to second preset and run for 2 seconds @@ -381,7 +381,7 @@ def test_switch_presets(): start = utime.ticks_ms() while utime.ticks_diff(utime.ticks_ms(), start) < 2000: wdt.feed() - patterns.tick() + run_tick(patterns) utime.sleep_ms(10) # Switch to third preset and run for 2 seconds @@ -399,7 +399,7 @@ def test_switch_presets(): start = utime.ticks_ms() while utime.ticks_diff(utime.ticks_ms(), start) < 2000: wdt.feed() - patterns.tick() + run_tick(patterns) utime.sleep_ms(10) # Switch back to first preset and run for 2 seconds @@ -417,7 +417,7 @@ def test_switch_presets(): start = utime.ticks_ms() while utime.ticks_diff(utime.ticks_ms(), start) < 2000: wdt.feed() - patterns.tick() + run_tick(patterns) utime.sleep_ms(10) print(" ✓ Preset switching works correctly") @@ -577,7 +577,7 @@ def test_select_with_step(): mock_espnow.send_message(b"\xbb\xbb\xbb\xbb\xbb\xbb", msg2) run_main_loop_iterations(mock_espnow, patterns, settings, wdt, max_iterations=2) # Ensure tick() is called after select() to advance the step - patterns.tick() + run_tick(patterns) assert patterns.selected == "step_preset", "Should select step_preset" # Step is set to 10, then tick() advances it, so it should be 11 @@ -596,7 +596,7 @@ def test_select_with_step(): initial_step = patterns.step # Should be 11 run_main_loop_iterations(mock_espnow, patterns, settings, wdt, max_iterations=2) # Ensure tick() is called after select() to advance the step - patterns.tick() + run_tick(patterns) # Since it's the same preset, step should not be reset, but tick() will advance it # So step should be initial_step + 1 (one tick call) assert patterns.step == initial_step + 1, f"Step should advance from {initial_step} to {initial_step + 1} (not reset), got {patterns.step}" @@ -614,7 +614,7 @@ def test_select_with_step(): mock_espnow.send_message(b"\xdd\xdd\xdd\xdd\xdd\xdd", msg4) run_main_loop_iterations(mock_espnow, patterns, settings, wdt, max_iterations=2) # Ensure tick() is called after select() to advance the step - patterns.tick() + run_tick(patterns) assert patterns.selected == "other_preset", "Should select other_preset" # Step is set to 5, then tick() advances it, so it should be 6 diff --git a/tests/test_mdns.py b/tests/test_mdns.py index 701138c..1083d7c 100644 --- a/tests/test_mdns.py +++ b/tests/test_mdns.py @@ -15,7 +15,7 @@ Deploy src to the device (including utils.py with mdns_hostname), then from the mpremote connect PORT run tests/test_mdns.py -If ImportError: copy utils.py from src/ to the device, or rely on the built-in fallback below. +Copy ``utils.py`` from ``src/`` onto the device if imports fail. Or with cwd led-driver: @@ -30,26 +30,7 @@ import utime from machine import WDT from settings import Settings - -try: - from utils import mdns_hostname -except ImportError: - - def mdns_hostname(settings): - """Same as utils.mdns_hostname (fallback if device utils.py is older than host repo).""" - raw = settings.get("name") or "led" - suffix = [] - for c in str(raw).lower(): - o = ord(c) - if (48 <= o <= 57) or (97 <= o <= 122): - suffix.append(c) - s = "".join(suffix) - if not s: - s = "device" - h = "led" + s - if len(h) > 32: - h = h[:32] - return h +from utils import mdns_hostname CONNECT_TIMEOUT_S = 45 # ESP32 MicroPython WDT timeout is capped (typically 10000 ms). Longer blocking work @@ -213,16 +194,6 @@ def main(): "Set SELF_LOCAL_GETADDRINFO = True to attempt (may hang)." ) - # Optional: built-in mdns module (not present on all ESP32 builds) - _dbg(t0, "checking for optional 'mdns' module") - try: - import mdns # noqa: F401 - - print("Note: 'mdns' module is present; check your port's docs for Server/API.") - except ImportError: - print("No top-level 'mdns' module; relying on stack mDNS from hostname.") - _dbg(t0, "mdns import check done") - if HOLD_S != 0: forever = HOLD_S < 0 _dbg( diff --git a/tests/test_wifi.py b/tests/test_wifi.py new file mode 100644 index 0000000..4e742f0 --- /dev/null +++ b/tests/test_wifi.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +"""Wi-Fi connection smoke test for MicroPython on ESP32. + +Runs on-device via mpremote and uses /settings.json credentials. + +Usage: + mpremote connect /dev/ttyACM0 run tests/test_wifi.py +""" + +import time +import utime +import network +from machine import WDT + +from settings import Settings + +CONNECT_TIMEOUT_S = 30 +RETRY_DELAY_S = 2 +WDT_TIMEOUT_MS = 10000 + + +def _wifi_status_label(code): + names = { + getattr(network, "STAT_IDLE", 0): "idle", + getattr(network, "STAT_CONNECTING", 1): "connecting", + getattr(network, "STAT_WRONG_PASSWORD", -3): "wrong_password", + getattr(network, "STAT_NO_AP_FOUND", -2): "no_ap_found", + getattr(network, "STAT_CONNECT_FAIL", -1): "connect_fail", + getattr(network, "STAT_GOT_IP", 3): "got_ip", + } + return names.get(code, str(code)) + + +def connect_wifi_with_wdt(sta, ssid, password, wdt): + attempt = 0 + while not sta.isconnected(): + attempt += 1 + print("[wifi-test] attempt", attempt, "ssid=", repr(ssid)) + try: + sta.disconnect() + except Exception: + pass + sta.connect(ssid, password) + + start = utime.time() + last_status = None + while not sta.isconnected(): + status = sta.status() + if status != last_status: + print("[wifi-test] status:", status, _wifi_status_label(status)) + last_status = status + if status in ( + getattr(network, "STAT_WRONG_PASSWORD", -3), + getattr(network, "STAT_NO_AP_FOUND", -2), + getattr(network, "STAT_CONNECT_FAIL", -1), + ): + break + if utime.time() - start >= CONNECT_TIMEOUT_S: + print("[wifi-test] timeout after", CONNECT_TIMEOUT_S, "seconds") + break + time.sleep(1) + wdt.feed() + + if sta.isconnected(): + return True + + print("[wifi-test] retry in", RETRY_DELAY_S, "seconds") + for _ in range(RETRY_DELAY_S): + time.sleep(1) + wdt.feed() + return True + + +def main(): + settings = Settings() + ssid = settings.get("ssid") or "" + password = settings.get("password") or "" + + if not ssid: + print("[wifi-test] skipped: settings.ssid is empty") + raise SystemExit(0) + + wdt = WDT(timeout=WDT_TIMEOUT_MS) + wdt.feed() + + sta = network.WLAN(network.STA_IF) + sta.active(True) + try: + sta.config(pm=network.WLAN.PM_NONE) + except (AttributeError, ValueError, TypeError): + pass + + ok = connect_wifi_with_wdt(sta, ssid, password, wdt) + if not ok or not sta.isconnected(): + print("[wifi-test] FAILED: not connected") + raise SystemExit(1) + + print("[wifi-test] OK:", sta.ifconfig()) + + +if __name__ == "__main__": + main()