You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

44 lines
1012 B

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""数据库连接管理"""
import os
from typing import Generator
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
# Database configuration.
# The connection string comes from the DATABASE_URL environment variable and
# falls back to a local MySQL instance when that variable is unset.
# Expected format: mysql+pymysql://user:password@host:port/database
# NOTE(review): the fallback URL embeds a root password in source control.
# Consider rotating that credential and always supplying DATABASE_URL through
# the environment instead of relying on this default.
DATABASE_URL = os.environ.get(
    "DATABASE_URL",
    "mysql+pymysql://root:1qazse42W3@localhost:3306/marketdata",
)

# Shared engine for the whole application.
engine = create_engine(
    DATABASE_URL,
    pool_size=10,  # connections kept in the pool
    max_overflow=20,  # additional connections permitted beyond pool_size
    pool_pre_ping=True,  # check a pooled connection is alive before reuse
)

# Factory for new Session objects bound to the engine above; autocommit and
# autoflush are both turned off.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)

# Declarative base class for ORM models; init_db() creates the tables
# registered on its metadata.
Base = declarative_base()
def get_db() -> Generator[Session, None, None]:
    """Yield a database session for use as a FastAPI dependency.

    Each call opens a fresh session from ``SessionLocal`` and hands it to
    the caller. The session is closed once the caller is finished with it,
    including when an exception propagates while it is in use.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        # Always release the session's resources, on success or failure.
        session.close()
def init_db():
    """Initialize the database by creating all tables registered on ``Base``.

    Calls ``create_all`` on ``Base.metadata`` using the module-level engine.

    NOTE(review): presumably only models whose modules were imported before
    this call are registered on ``Base`` and therefore created. Confirm that
    callers import all model modules first.
    """
    metadata = Base.metadata
    metadata.create_all(bind=engine)