41 lines
1.4 KiB
Python
41 lines
1.4 KiB
Python
from collections.abc import AsyncGenerator
|
|
from fastapi import Depends
|
|
from fastapi_users.db import SQLAlchemyUserDatabase
|
|
from fastapi_users.jwt import SecretType
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
|
|
|
from src.auth.user_manager import UserManager
|
|
from src.database.user import User, Base
|
|
|
|
|
|
class Database:
    """Async PostgreSQL access plus the FastAPI dependency chain used for auth.

    An instance owns one async SQLAlchemy engine and one session factory.
    Its ``get_async_session``, ``get_user_db`` and ``get_user_manager``
    attributes are intended to be passed to ``fastapi.Depends(...)``. The chain is:

    * ``get_async_session`` yields a session;
    * ``get_user_db`` wraps that session in a ``SQLAlchemyUserDatabase`` for ``User``;
    * ``get_user_manager`` yields a ``UserManager`` built from ``secret`` and that user database.
    """

    def __init__(
        self,
        db_user: str,
        db_pass: str,
        db_host: str,
        db_port: int,
        db_name: str,
        secret: SecretType,
    ):
        """Create the engine and session factory, then bind the dependencies.

        Args:
            db_user: Database user name (may contain URL-reserved characters).
            db_pass: Database password (may contain URL-reserved characters).
            db_host: Database host name or address.
            db_port: Database port.
            db_name: Database name.
            secret: Secret handed to every ``UserManager`` this instance yields.
        """
        from urllib.parse import quote

        # Percent-encode the credentials. Otherwise a raw "@", ":", "/" or "%"
        # in the user or password would be parsed as part of the URL structure.
        # Plain alphanumeric values are left unchanged.
        user = quote(db_user, safe="")
        password = quote(db_pass, safe="")
        self.DATABASE_URL = (
            f'postgresql+asyncpg://{user}:{password}@{db_host}:{db_port}/{db_name}'
        )
        self.engine = create_async_engine(self.DATABASE_URL)
        # expire_on_commit=False: loaded attributes are not expired on commit.
        self.async_session_maker = async_sessionmaker(self.engine, expire_on_commit=False)
        self.secret = secret

        # FastAPI builds its dependency graph by inspecting each Depends target's
        # signature. The class-body defaults of get_user_db/get_user_manager can
        # only capture the *unbound* functions. FastAPI would treat their untyped
        # `self` parameter as a required query parameter. Shadow both methods on
        # this instance with closures whose Depends defaults point at this
        # instance's bound callables. That makes Depends(db.get_user_manager)
        # resolve session -> user_db -> manager correctly.
        async def get_user_db(session: AsyncSession = Depends(self.get_async_session)):
            yield SQLAlchemyUserDatabase(session, User)

        async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)):
            yield UserManager(self.secret, user_db)

        self.get_user_db = get_user_db
        self.get_user_manager = get_user_manager

    async def create_db_and_tables(self) -> None:
        """Run ``Base.metadata.create_all`` inside one ``engine.begin()`` block."""
        async with self.engine.begin() as conn:
            # create_all is synchronous; run_sync executes it on the async connection.
            await conn.run_sync(Base.metadata.create_all)

    async def get_async_session(self) -> AsyncGenerator[AsyncSession, None]:
        """Yield a session from this instance's factory.

        The session is opened with ``async with``, so it is closed once the
        consumer of the generator finishes.
        """
        async with self.async_session_maker() as session:
            yield session

    async def get_user_db(self, session: AsyncSession = Depends(get_async_session)):
        """Yield a ``SQLAlchemyUserDatabase`` for ``User`` over *session*.

        Class-level form, kept for interface compatibility. Instances shadow it
        in ``__init__`` with a FastAPI-ready closure. The ``Depends`` default
        here captures the unbound ``get_async_session`` and cannot be resolved
        by FastAPI. When calling through the class, pass *session* explicitly.
        """
        yield SQLAlchemyUserDatabase(session, User)

    async def get_user_manager(self, user_db: SQLAlchemyUserDatabase = Depends(get_user_db)):
        """Yield a ``UserManager`` built from ``self.secret`` and *user_db*.

        Class-level form, kept for interface compatibility. Instances shadow it
        in ``__init__`` with a FastAPI-ready closure. The ``Depends`` default
        here captures the unbound ``get_user_db`` and cannot be resolved by
        FastAPI. When calling through the class, pass *user_db* explicitly.
        """
        yield UserManager(self.secret, user_db)
|