215 lines
6.5 KiB
Python
215 lines
6.5 KiB
Python
from __future__ import annotations
|
|
|
|
import datetime as dt
|
|
import os
|
|
import secrets
|
|
import smtplib
|
|
from email.message import EmailMessage
|
|
from typing import TYPE_CHECKING
|
|
|
|
import bcrypt
|
|
from sqlalchemy import func, select
|
|
from sqlalchemy.orm import Session
|
|
|
|
from editor_app.db.models import AuthSession, InviteToken, User
|
|
from editor_app.services import user_workspace
|
|
|
|
if TYPE_CHECKING:
    # Placeholder: no typing-only imports are needed at the moment.
    pass

# Cookie name for the session token. Presumably consumed by the web layer;
# nothing in this module's visible code reads it.
SESSION_COOKIE_NAME: str = "editor_session"

# Session lifetime in days used by session_ttl_days() when AUTH_SESSION_DAYS
# is unset or cannot be parsed as an integer.
SESSION_DAYS_DEFAULT: int = 14
|
|
|
|
|
|
def auth_enabled() -> bool:
    """Report whether authentication is switched on via ``AUTH_ENABLED``.

    The variable is read on every call and defaults to ``"false"``. After
    trimming whitespace and lower-casing, the values ``1``, ``true``, ``yes``
    and ``on`` count as enabled; anything else counts as disabled.
    """
    raw = os.environ.get("AUTH_ENABLED", "false")
    flag = raw.strip().lower()
    return flag in ("1", "true", "yes", "on")
|
|
|
|
|
|
def register_open() -> bool:
    """Report whether ``AUTH_REGISTER_OPEN`` is set to a truthy value.

    The variable is read on every call and defaults to ``"true"``. After
    trimming whitespace and lower-casing, the values ``1``, ``true``, ``yes``
    and ``on`` yield ``True``; anything else yields ``False``.
    """
    raw = os.environ.get("AUTH_REGISTER_OPEN", "true")
    flag = raw.strip().lower()
    return flag in ("1", "true", "yes", "on")
|
|
|
|
|
|
def session_ttl_days() -> int:
    """Return the session lifetime in days, taken from ``AUTH_SESSION_DAYS``.

    The configured value is clamped to the range 1..365. When the variable is
    unset, ``SESSION_DAYS_DEFAULT`` is used as the value. If the value cannot
    be parsed as an integer, ``SESSION_DAYS_DEFAULT`` is returned unclamped.
    """
    raw = os.environ.get("AUTH_SESSION_DAYS", str(SESSION_DAYS_DEFAULT))
    try:
        days = int(raw)
    except ValueError:
        return SESSION_DAYS_DEFAULT
    # Clamp to the supported window.
    if days < 1:
        return 1
    if days > 365:
        return 365
    return days
|
|
|
|
|
|
def hash_password(password: str) -> str:
    """Hash *password* with bcrypt and return the hash as text.

    The password is UTF-8 encoded and hashed with a freshly generated salt
    (``bcrypt.gensalt()`` with its default settings). The resulting bytes are
    decoded as ASCII.
    """
    secret = password.encode("utf-8")
    salt = bcrypt.gensalt()
    digest = bcrypt.hashpw(secret, salt)
    return digest.decode("ascii")
|
|
|
|
|
|
def verify_password(plain: str, hashed: str) -> bool:
    """Check *plain* against a bcrypt hash stored as text.

    Returns the result of ``bcrypt.checkpw``. Any ``ValueError`` or
    ``TypeError`` raised while encoding the inputs or checking the hash is
    treated as a failed verification and yields ``False``. This includes
    ``UnicodeEncodeError``, which is a ``ValueError`` subclass.
    """
    try:
        candidate = plain.encode("utf-8")
        stored = hashed.encode("ascii")
        return bcrypt.checkpw(candidate, stored)
    except (ValueError, TypeError):
        return False
|
|
|
|
|
|
def get_user_by_username(db: Session, username: str) -> User | None:
    """Fetch the user whose ``username`` equals *username*, or ``None``.

    The value is compared as given; no trimming or case folding happens here.
    ``one_or_none()`` raises if more than one row matches.
    """
    stmt = select(User).where(User.username == username)
    result = db.scalars(stmt)
    return result.one_or_none()
|
|
|
|
|
|
def get_user_by_id(db: Session, user_id: int) -> User | None:
    """Fetch the user whose ``id`` equals *user_id*, or ``None`` if absent."""
    stmt = select(User).where(User.id == user_id)
    result = db.scalars(stmt)
    return result.one_or_none()
|
|
|
|
|
|
def count_users(db: Session) -> int:
    """Return the number of rows in the ``User`` table.

    A falsy scalar result (for example ``None``) is reported as ``0``.
    """
    stmt = select(func.count()).select_from(User)
    total = db.scalar(stmt)
    return int(total or 0)
|
|
|
|
|
|
def create_user(db: Session, username: str, password: str, *, is_superuser: bool = False) -> User:
    """Insert a new user with a hashed password, commit, and return it.

    The row is committed and refreshed before anything else happens. When
    authentication is enabled (see ``auth_enabled``), the user's workspace
    root is then resolved from the new id and username and passed to
    ``user_workspace.ensure_default_code_main``.

    Args:
        db: Session used for the insert; this function commits it.
        username: Stored exactly as given.
        password: Plain-text password; only its bcrypt hash is stored.
        is_superuser: Keyword-only flag stored on the new row.

    Returns:
        The refreshed ``User`` instance.
    """
    password_hash = hash_password(password)
    new_user = User(
        username=username,
        password_hash=password_hash,
        is_superuser=is_superuser,
    )
    db.add(new_user)
    db.commit()
    # Refresh so database-generated values such as the id are loaded.
    db.refresh(new_user)

    if not auth_enabled():
        return new_user

    workspace_root = user_workspace.user_workspace_root(new_user.id, new_user.username)
    user_workspace.ensure_default_code_main(workspace_root)
    return new_user
|
|
|
|
|
|
def register_user(db: Session, username: str, password: str) -> User:
    """Register a new account; the very first account becomes a superuser.

    The username is stripped of surrounding whitespace before the uniqueness
    check and before it is stored. ``authenticate`` strips the username it
    looks up, so an account stored with padding could never log in.

    Args:
        db: Session used for the lookups and the insert.
        username: Requested username; surrounding whitespace is ignored.
        password: Plain-text password, hashed by ``create_user``.

    Returns:
        The newly created ``User``.

    Raises:
        ValueError: If the username is empty after stripping, or if a user
            with that username already exists.
    """
    username = username.strip()
    if not username:
        raise ValueError("Username is required")
    if get_user_by_username(db, username):
        raise ValueError("Username already taken")
    # NOTE(review): the existence check and the user count are separate
    # queries from the insert, so concurrent registrations can race. Confirm
    # a unique constraint on the username column backs this up.
    is_first_account = count_users(db) == 0
    return create_user(db, username, password, is_superuser=is_first_account)
|
|
|
|
|
|
def register_user_with_invite(db: Session, username: str, password: str, invite_token: str) -> User:
    """Register a non-superuser account by consuming a valid invite token.

    The invite is validated first. The username is then stripped of
    surrounding whitespace before the uniqueness check and before it is
    stored. ``authenticate`` strips the username it looks up, so an account
    stored with padding could never log in.

    Args:
        db: Session used for the lookups, the insert and the invite update.
        username: Requested username; surrounding whitespace is ignored.
        password: Plain-text password, hashed by ``create_user``.
        invite_token: Token checked with ``get_valid_invite``.

    Returns:
        The newly created ``User``.

    Raises:
        ValueError: If the invite is unknown, already used or expired; if the
            username is empty after stripping; or if the username is taken.
    """
    invite = get_valid_invite(db, invite_token)
    if invite is None:
        raise ValueError("Invite is invalid or expired")

    username = username.strip()
    if not username:
        raise ValueError("Username is required")
    if get_user_by_username(db, username):
        raise ValueError("Username already taken")

    user = create_user(db, username, password, is_superuser=False)

    # NOTE(review): create_user has already committed the new account, so the
    # invite is marked as used in a second commit. If that commit fails, the
    # user exists while the invite remains usable.
    invite.used_at = _utc_naive()
    invite.consumed_by_user_id = user.id
    db.add(invite)
    db.commit()
    return user
|
|
|
|
|
|
def authenticate(db: Session, username: str, password: str) -> User | None:
    """Return the user matching the credentials, or ``None`` when they fail.

    The username is stripped of surrounding whitespace before the lookup.
    ``None`` is returned both when no such user exists and when the password
    does not verify against the stored hash.
    """
    candidate = get_user_by_username(db, username.strip())
    if candidate and verify_password(password, candidate.password_hash):
        return candidate
    return None
|
|
|
|
|
|
def _utc_naive() -> dt.datetime:
    """Return the current UTC time as a naive ``datetime`` (``tzinfo`` removed)."""
    aware_now = dt.datetime.now(tz=dt.timezone.utc)
    return aware_now.replace(tzinfo=None)
|
|
|
|
|
|
def create_session(db: Session, user: User) -> AuthSession:
    """Create, commit and return a new session row for *user*.

    The token comes from ``secrets.token_urlsafe(48)``. The expiry is the
    current naive-UTC time plus ``session_ttl_days()`` days.
    """
    lifetime = dt.timedelta(days=session_ttl_days())
    session_row = AuthSession(
        user_id=user.id,
        token=secrets.token_urlsafe(48),
        expires_at=_utc_naive() + lifetime,
    )
    db.add(session_row)
    db.commit()
    db.refresh(session_row)
    return session_row
|
|
|
|
|
|
def get_session_user(db: Session, token: str | None) -> User | None:
    """Resolve a session token to its user, or ``None`` if it is not usable.

    ``None`` is returned for a missing or empty token, for a token with no
    matching session row, and for a session whose ``expires_at`` lies before
    the current naive-UTC time. Expired rows are not deleted here.
    """
    if not token:
        return None

    now = _utc_naive()
    stmt = select(AuthSession).where(AuthSession.token == token)
    session_row = db.scalars(stmt).one_or_none()
    if not session_row:
        return None
    if session_row.expires_at < now:
        return None
    return get_user_by_id(db, session_row.user_id)
|
|
|
|
|
|
def delete_session(db: Session, token: str | None) -> None:
    """Delete the session row identified by *token*, if there is one.

    A missing or empty token, or a token with no matching row, is a no-op.
    When a row is found it is deleted and the session is committed.
    """
    if not token:
        return

    stmt = select(AuthSession).where(AuthSession.token == token)
    session_row = db.scalars(stmt).one_or_none()
    if not session_row:
        return
    db.delete(session_row)
    db.commit()
|
|
|
|
|
|
def list_users(db: Session) -> list[User]:
    """Return all users as a list, ordered by username."""
    stmt = select(User).order_by(User.username)
    rows = db.scalars(stmt).all()
    return list(rows)
|
|
|
|
|
|
def delete_user(db: Session, user_id: int) -> bool:
    """Delete the user with id *user_id* and commit.

    Returns:
        ``True`` if a user was found and deleted, ``False`` if no user has
        that id.
    """
    target = get_user_by_id(db, user_id)
    if target:
        db.delete(target)
        db.commit()
        return True
    return False
|
|
|
|
|
|
def invite_required() -> bool:
    """Report whether ``AUTH_INVITE_ONLY`` is set to a truthy value.

    The variable is read on every call and defaults to ``"true"``. After
    trimming whitespace and lower-casing, the values ``1``, ``true``, ``yes``
    and ``on`` yield ``True``; anything else yields ``False``.
    """
    raw = os.environ.get("AUTH_INVITE_ONLY", "true")
    flag = raw.strip().lower()
    return flag in ("1", "true", "yes", "on")
|
|
|
|
|
|
def create_invite(db: Session, email: str, invited_by_user_id: int | None = None, expires_days: int = 7) -> InviteToken:
    """Create, commit and return an invite token row for *email*.

    Args:
        db: Session used for the insert; this function commits it.
        email: Recipient address; stored stripped and lower-cased.
        invited_by_user_id: Optional id of the inviting user, stored as given.
        expires_days: Requested lifetime in days; converted with ``int`` and
            clamped to the range 1..30.

    Returns:
        The refreshed ``InviteToken`` row. Its token comes from
        ``secrets.token_urlsafe(36)``, and its expiry is the current
        naive-UTC time plus the clamped lifetime.
    """
    lifetime_days = min(30, int(expires_days))
    lifetime_days = max(1, lifetime_days)

    invite = InviteToken(
        email=email.strip().lower(),
        token=secrets.token_urlsafe(36),
        expires_at=_utc_naive() + dt.timedelta(days=lifetime_days),
        invited_by_user_id=invited_by_user_id,
    )
    db.add(invite)
    db.commit()
    db.refresh(invite)
    return invite
|
|
|
|
|
|
def get_valid_invite(db: Session, token: str | None) -> InviteToken | None:
    """Return the invite for *token* if it can still be used, else ``None``.

    The token is stripped of surrounding whitespace before the lookup.
    ``None`` is returned for a missing or empty token, when no row matches,
    when the invite already has ``used_at`` set, and when ``expires_at`` lies
    before the current naive-UTC time.
    """
    if not token:
        return None

    stmt = select(InviteToken).where(InviteToken.token == token.strip())
    invite = db.scalars(stmt).one_or_none()
    if invite is None or invite.used_at is not None:
        return None
    return None if invite.expires_at < _utc_naive() else invite
|
|
|
|
|
|
def build_invite_url(token: str) -> str:
    """Build the registration URL that carries *token* in its query string.

    The base comes from ``PUBLIC_BASE_URL`` with surrounding whitespace and
    trailing slashes removed. It falls back to ``http://127.0.0.1:8080`` when
    the variable is unset or blank.

    The token is percent-encoded, so characters such as ``&``, ``#``, ``=``
    or spaces cannot truncate or alter the query string. Tokens produced by
    ``secrets.token_urlsafe`` contain only unreserved characters and pass
    through unchanged.

    Args:
        token: Invite token to embed as the ``invite`` query parameter.

    Returns:
        ``<base>/register?invite=<percent-encoded token>``.
    """
    # Local import keeps this block self-contained; the module's import list
    # does not need to change.
    from urllib.parse import quote

    configured = (os.environ.get("PUBLIC_BASE_URL") or "").strip()
    base = (configured or "http://127.0.0.1:8080").rstrip("/")
    return f"{base}/register?invite={quote(token, safe='')}"
|
|
|
|
|
|
def send_invite_email(email: str, invite_url: str) -> bool:
    """Send the invite link to *email* over SMTP.

    Configuration is read from the environment on every call:

    - ``SMTP_HOST``: required; when unset or blank nothing is sent.
    - ``SMTP_PORT``: defaults to ``587``.
    - ``SMTP_USER`` / ``SMTP_PASSWORD``: login happens only when a user is set.
    - ``SMTP_FROM``: sender; falls back to ``SMTP_USER``, then to
      ``noreply@python-editor.local``.
    - ``SMTP_TLS``: defaults to ``true``; when truthy, ``STARTTLS`` is issued
      before login.

    Returns:
        ``False`` when no SMTP host is configured, ``True`` once the message
        has been passed to ``send_message``.

    Raises:
        ValueError: If ``SMTP_PORT`` is not an integer.
        Errors raised by ``smtplib`` or the underlying connection are not
        caught here.
    """
    env = os.environ

    host = (env.get("SMTP_HOST") or "").strip()
    if not host:
        return False

    port = int((env.get("SMTP_PORT") or "587").strip())
    user = (env.get("SMTP_USER") or "").strip()
    password = env.get("SMTP_PASSWORD") or ""
    sender = (env.get("SMTP_FROM") or user or "noreply@python-editor.local").strip()
    tls_flag = env.get("SMTP_TLS", "true").strip().lower()
    use_tls = tls_flag in ("1", "true", "yes", "on")

    body = (
        "You have been invited to Python Editor.\n\n"
        f"Use this link to sign up:\n{invite_url}\n\n"
        "If you did not expect this invite, you can ignore this message.\n"
    )
    message = EmailMessage()
    message["Subject"] = "Your Python Editor invite"
    message["From"] = sender
    message["To"] = email
    message.set_content(body)

    with smtplib.SMTP(host, port, timeout=10) as client:
        if use_tls:
            client.starttls()
        if user:
            client.login(user, password)
        client.send_message(message)
    return True
|