38 lines
1.0 KiB
Python
38 lines
1.0 KiB
Python
from contextlib import (AbstractContextManager, asynccontextmanager,
|
|
contextmanager)
|
|
from typing import Callable
|
|
|
|
from sqlalchemy.ext.asyncio import (AsyncSession, async_sessionmaker,
|
|
create_async_engine)
|
|
from sqlalchemy.ext.declarative import declarative_base
|
|
from sqlalchemy.orm import Session
|
|
|
|
# Module-level declarative base created via SQLAlchemy's declarative_base();
# presumably the ORM model classes defined elsewhere inherit from it.
# NOTE(review): SQLAlchemy 2.0 deprecates importing declarative_base from
# sqlalchemy.ext.declarative in favor of sqlalchemy.orm -- confirm the
# installed version and consider migrating the import.
Base = declarative_base()
|
|
|
|
|
|
class Database:
    """Async SQLAlchemy engine wrapper exposing one shared working session.

    The instance owns an async engine and a session factory.  The working
    session is created lazily on first access of :attr:`session` and is then
    reused by :meth:`commit`, :meth:`rollback` and ``__aexit__`` so that all of
    them act on the same transaction.  Typical use::

        async with Database(url) as db:
            db.session.add(obj)
            await db.commit()

    Leaving the ``async with`` block rolls back anything uncommitted, closes
    the session and forgets it; the next access of :attr:`session` starts a
    fresh one.  The engine itself is not disposed here.
    """

    def __init__(self, db_url: str, echo: bool = True) -> None:
        """Create the engine and session factory.

        Args:
            db_url: Database URL passed to ``create_async_engine``.
            echo: Forwarded to ``create_async_engine``; defaults to ``True``
                to match the previous hard-coded behavior.
        """
        self._engine = create_async_engine(db_url, echo=echo)
        self._session_factory = async_sessionmaker(
            self._engine,
            class_=AsyncSession,
            expire_on_commit=False,
        )
        # Working session; created on demand by the ``session`` property.
        self._session = None

    @property
    def session(self) -> "AsyncSession":
        """Return the current working session, creating it on first access.

        Previously every access built a brand-new session, so work done on
        ``db.session`` could never be committed, rolled back or closed through
        this object.  The session is now cached until ``__aexit__`` runs.
        """
        if self._session is None:
            self._session = self._session_factory()
        return self._session

    async def __aenter__(self):
        """Enter the async context and return this ``Database`` instance."""
        return self

    async def __aexit__(self, *args) -> None:
        """Roll back uncommitted work and close the working session, if any.

        Exceptions raised inside the ``async with`` body are not suppressed.
        """
        # Detach first so a failure below cannot leave a dead session cached.
        active, self._session = self._session, None
        if active is None:
            return
        try:
            await active.rollback()
        finally:
            # Always release the connection, even if the rollback failed.
            await active.close()

    async def commit(self) -> None:
        """Commit the working session; a no-op when no session was opened."""
        if self._session is not None:
            await self._session.commit()

    async def rollback(self) -> None:
        """Roll back the working session; a no-op when no session was opened."""
        if self._session is not None:
            await self._session.rollback()
|