diff --git a/app/auth.py b/app/auth.py new file mode 100644 index 0000000..71036ee --- /dev/null +++ b/app/auth.py @@ -0,0 +1,81 @@ +from os import getenv +import json +from datetime import timedelta, datetime +from typing import Optional +from fastapi.security import OAuth2PasswordBearer +from fastapi import Depends, HTTPException, status +from jose import JWTError, jwt +from passlib.context import CryptContext +from models import User, UserInDB, TokenData + + +# to get a string like this run: +# openssl rand -hex 32 +SECRET_KEY = getenv("SECRET_KEY") +ALGORITHM = getenv("ALGORITHM") +ACCESS_TOKEN_EXPIRE_MINUTES = int(getenv("ACCESS_TOKEN_EXPIRE_MINUTES")) + +fake_users_db = json.load(open("app/users.json")) + +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + +def verify_password(plain_password, hashed_password): + return pwd_context.verify(plain_password, hashed_password) + + +def get_password_hash(password): + return pwd_context.hash(password) + + +def get_user(db, username: str): + if username in db: + user_dict = db[username] + return UserInDB(**user_dict) + + +def authenticate_user(fake_db, username: str, password: str): + user = get_user(fake_db, username) + if not user: + return False + if not verify_password(password, user.hashed_password): + return False + return user + + +def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): + to_encode = data.copy() + if expires_delta: + expire = datetime.utcnow() + expires_delta + else: + expire = datetime.utcnow() + timedelta(minutes=15) + to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + return encoded_jwt + + +async def get_current_user(token: str = Depends(oauth2_scheme)): + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + username: str = payload.get("sub") + if username is None: + raise credentials_exception + token_data = TokenData(username=username) + except JWTError: + raise credentials_exception + user = get_user(fake_users_db, username=token_data.username) + if user is None: + raise credentials_exception + return user + + +async def get_current_active_user(current_user: User = Depends(get_current_user)): + if current_user.disabled: + raise HTTPException(status_code=400, detail="Inactive user") + return current_user \ No newline at end of file diff --git a/app/main.py b/app/main.py index e37f914..1a45fd9 100755 --- a/app/main.py +++ b/app/main.py @@ -1,109 +1,14 @@ -from datetime import datetime, timedelta -from typing import Optional +from datetime import timedelta + from fastapi import Depends, FastAPI, HTTPException, status -from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm -from jose import JWTError, jwt -from passlib.context import CryptContext -from pydantic import BaseModel -from os import getenv -import json +from fastapi.security import OAuth2PasswordRequestForm -# to get a string like this run: -# openssl rand -hex 32 -SECRET_KEY = getenv("SECRET_KEY") -ALGORITHM = getenv("ALGORITHM") -ACCESS_TOKEN_EXPIRE_MINUTES = int(getenv("ACCESS_TOKEN_EXPIRE_MINUTES")) - -fake_users_db = json.load(open("app/users.json")) - -class Token(BaseModel): - access_token: str - token_type: str - - -class TokenData(BaseModel): - username: Optional[str] = None - - -class User(BaseModel): - username: str - email: Optional[str] = None - full_name: Optional[str] = None - disabled: Optional[bool] = None - - -class UserInDB(User): - hashed_password: str - - -pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") - -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") +from models import Token, User +from auth import get_current_active_user, create_access_token, authenticate_user, fake_users_db, ACCESS_TOKEN_EXPIRE_MINUTES app = FastAPI() - -def verify_password(plain_password, hashed_password): - return pwd_context.verify(plain_password, hashed_password) - - -def get_password_hash(password): - return pwd_context.hash(password) - - -def get_user(db, username: str): - if username in db: - user_dict = db[username] - return UserInDB(**user_dict) - - -def authenticate_user(fake_db, username: str, password: str): - user = get_user(fake_db, username) - if not user: - return False - if not verify_password(password, user.hashed_password): - return False - return user - - -def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): - to_encode = data.copy() - if expires_delta: - expire = datetime.utcnow() + expires_delta - else: - expire = datetime.utcnow() + timedelta(minutes=15) - to_encode.update({"exp": expire}) - encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) - return encoded_jwt - - -async def get_current_user(token: str = Depends(oauth2_scheme)): - credentials_exception = HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials", - headers={"WWW-Authenticate": "Bearer"}, - ) - try: - payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) - username: str = payload.get("sub") - if username is None: - raise credentials_exception - token_data = TokenData(username=username) - except JWTError: - raise credentials_exception - user = get_user(fake_users_db, username=token_data.username) - if user is None: - raise credentials_exception - return user - - -async def get_current_active_user(current_user: User = Depends(get_current_user)): - if current_user.disabled: - raise HTTPException(status_code=400, detail="Inactive user") - return current_user - - @app.post("/token", response_model=Token) async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()): user = authenticate_user(fake_users_db, form_data.username, form_data.password) diff --git a/app/models.py b/app/models.py new file mode 100644 index 0000000..a98b8e0 --- /dev/null +++ b/app/models.py @@ -0,0 +1,20 @@ +from pydantic import BaseModel +from typing import Optional + +class Token(BaseModel): + access_token: str + token_type: str + + +class TokenData(BaseModel): + username: Optional[str] = None + + +class User(BaseModel): + username: str + email: Optional[str] = None + full_name: Optional[str] = None + disabled: Optional[bool] = None + +class UserInDB(User): + hashed_password: str \ No newline at end of file