You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

93 lines
2.7 KiB

"""
安全模块 - JWT认证和密码处理
"""
from datetime import datetime, timedelta
from typing import Optional
from jose import JWTError, jwt
import bcrypt
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy.orm import Session
from app.config import settings
from app.db.session import get_db
from app.models.user import User
# HTTP Bearer scheme instance; get_current_user depends on it to obtain the
# credentials sent with the request.
security = HTTPBearer()
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plaintext password against a stored bcrypt hash.

    Args:
        plain_password: The password supplied by the user.
        hashed_password: The stored hash, as produced by ``get_password_hash``.

    Returns:
        ``True`` when the password matches the hash; ``False`` when it does
        not match, or when the stored value is empty or cannot be processed
        by bcrypt.
    """
    if not hashed_password:
        # No stored hash: nothing can match.
        return False
    try:
        return bcrypt.checkpw(
            plain_password.encode('utf-8'),
            hashed_password.encode('utf-8'),
        )
    except ValueError:
        # bcrypt rejects inputs it cannot process (e.g. a stored value that
        # is not a valid bcrypt hash). Treat that as a failed verification
        # instead of letting the error escape to the caller.
        return False
def get_password_hash(password: str) -> str:
    """Hash a password with bcrypt using a freshly generated salt.

    Args:
        password: The plaintext password to hash.

    Returns:
        The bcrypt hash, decoded to a UTF-8 string.
    """
    password_bytes = password.encode('utf-8')
    salt = bcrypt.gensalt()
    digest = bcrypt.hashpw(password_bytes, salt)
    return digest.decode('utf-8')
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token.

    Args:
        data: Claims to embed in the token. The mapping is shallow-copied,
            so the caller's dict is not modified.
        expires_delta: Lifetime of the token. When ``None``, the lifetime is
            ``settings.ACCESS_TOKEN_EXPIRE_HOURS`` hours. Any explicit value,
            including ``timedelta(0)``, is honoured as given.

    Returns:
        The JWT encoded with ``settings.SECRET_KEY`` and ``settings.ALGORITHM``.
    """
    # Local import: the module level only imports ``datetime`` and ``timedelta``.
    from datetime import timezone

    # Compare against None rather than truthiness: a zero timedelta is falsy
    # and would otherwise silently receive the default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(hours=settings.ACCESS_TOKEN_EXPIRE_HOURS)

    # Timezone-aware "now" replaces the deprecated datetime.utcnow().
    expire = datetime.now(timezone.utc) + expires_delta

    to_encode = data.copy()
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
def decode_token(token: str) -> Optional[dict]:
    """Decode a JWT with the configured secret key and algorithm.

    Args:
        token: The encoded JWT string.

    Returns:
        The decoded claims, or ``None`` when decoding raises ``JWTError``.
    """
    try:
        claims = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    except JWTError:
        return None
    else:
        return claims
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: Session = Depends(get_db),
) -> User:
    """Resolve the user identified by the request's bearer token.

    The token is decoded with ``decode_token``; its ``sub`` claim is used as
    the username to look up in the database.

    Args:
        credentials: Bearer credentials supplied by the ``security`` dependency.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        The matching, active ``User``.

    Raises:
        HTTPException: 401 when the token cannot be decoded, has no ``sub``
            claim, or names no existing user; 403 when the user is inactive.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="无效的认证凭据",
        headers={"WWW-Authenticate": "Bearer"},
    )

    claims = decode_token(credentials.credentials)
    subject = None if claims is None else claims.get("sub")
    if subject is None:
        raise unauthorized

    # NOTE(review): this query runs inside an ``async def``; if ``db`` is a
    # synchronous Session it blocks while it executes - confirm this is intended.
    account = db.query(User).filter(User.username == subject).first()
    if account is None:
        raise unauthorized

    if account.is_active:
        return account
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="用户已被禁用",
    )
async def get_current_active_superuser(current_user: User = Depends(get_current_user)) -> User:
    """Return the current user only if they are a superuser.

    Args:
        current_user: The user resolved by the ``get_current_user`` dependency.

    Returns:
        ``current_user`` unchanged when ``is_superuser`` is truthy.

    Raises:
        HTTPException: 403 when the user is not a superuser.
    """
    if current_user.is_superuser:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="需要超级用户权限",
    )