Files
core/services/db.py
Untone b60a314ddd
Some checks failed
Deploy on push / deploy (push) Failing after 5s
tested-auth-refactoring
2025-07-25 01:04:15 +03:00

263 lines
9.0 KiB
Python

import logging
import math
import time
import traceback
import warnings
from io import TextIOWrapper
from typing import Any, TypeVar
import sqlalchemy
from sqlalchemy import create_engine, event, exc, func, inspect
from sqlalchemy.dialects.sqlite import insert
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.orm import Session, configure_mappers, joinedload
from sqlalchemy.pool import StaticPool
from orm.base import BaseModel
from settings import DB_URL
from utils.logger import root_logger as logger
# Module-level logger.
# NOTE(review): this assignment rebinds `logger`, replacing the `root_logger` that was
# imported above under the same name - confirm which of the two is intended.
logger = logging.getLogger(__name__)
# Database configuration: a single engine created at import time. StaticPool is selected
# when DB_URL contains "sqlite"; otherwise poolclass=None is passed to create_engine.
engine = create_engine(DB_URL, echo=False, poolclass=StaticPool if "sqlite" in DB_URL else None)
ENGINE = engine  # Backward compatibility alias
# Inspector bound to the engine, also created at import time
inspector = inspect(engine)
# Session = sessionmaker(engine)
# Run SQLAlchemy mapper configuration eagerly, as an import-time side effect
configure_mappers()
# Generic type variable (no use of it is visible in this part of the file)
T = TypeVar("T")
# Attribute names to filter out.
# NOTE(review): the consumers are not visible here - presumably these are skipped when
# model instances are serialised; confirm against callers.
FILTERED_FIELDS = ["_sa_instance_state", "search_vector"]
# make_searchable(Base.metadata)
# Base.metadata.create_all(bind=engine)
# Функция для вывода полного трейсбека при предупреждениях
def warning_with_traceback(
message: Warning | str,
category: type[Warning],
filename: str,
lineno: int,
file: TextIOWrapper | None = None,
line: str | None = None,
) -> None:
tb = traceback.format_stack()
tb_str = "".join(tb)
print(f"{message} ({filename}, {lineno}): {category.__name__}\n{tb_str}")
# Install the traceback-printing handler. Replacing warnings.showwarning is process-wide,
# so it applies to every warning that gets shown, not only SQLAlchemy's.
warnings.showwarning = warning_with_traceback  # type: ignore[assignment]
# Always emit SQLAlchemy SAWarning instead of showing it only once per location
warnings.simplefilter("always", exc.SAWarning)
# Extract the SQL statement text from an execution context
def get_statement_from_context(context: Connection) -> str | None:
    """Return the compiled SQL of ``context`` as one whitespace-normalised line.

    The statement is read from ``context.compiled.string``. When
    ``context.compiled.params`` is non-empty, the parameters are interpolated
    with plain ``%`` formatting. No quoting or escaping is applied, so the
    result is suitable for display only and must not be executed.

    Args:
        context: Object expected to expose a ``compiled`` attribute carrying
            ``string`` and ``params``; missing attributes are tolerated.

    Returns:
        The statement with every run of whitespace (including newlines)
        collapsed to a single space and the ends stripped. An empty string is
        returned when there is no compiled statement or when parameter
        interpolation fails (the failure is logged).
    """
    compiled = getattr(context, "compiled", None)
    statement = getattr(compiled, "string", None) if compiled else None
    if not statement:
        return ""

    query = statement
    params = getattr(compiled, "params", None)
    if params:
        try:
            query = statement % params
        except Exception as e:
            # Best effort: a statement that cannot be rendered is reported and dropped.
            logger.exception(f"Error formatting query: {e}")
            return ""

    # split()/join collapses any whitespace run; the previous chained
    # .replace(" ", " ") calls replaced a space with itself and changed nothing.
    return " ".join(query.split())
# Event handler invoked before a statement is executed on a cursor
@event.listens_for(Engine, "before_cursor_execute")
def before_cursor_execute(
    conn: Connection,
    cursor: Any,
    statement: str,
    parameters: dict[str, Any] | None,
    context: Connection,
    executemany: bool,
) -> None:
    """Record when the statement started and which cursor is running it.

    Two attributes are stored on ``conn``: ``query_start_time`` (wall-clock
    time from ``time.time()``) and ``cursor_id`` (``id(cursor)``). The
    ``after_cursor_execute`` handler in this module reads them back.
    The remaining parameters are part of the event signature and are unused.
    """
    started_at = time.time()
    cursor_key = id(cursor)
    conn.query_start_time = started_at  # type: ignore[attr-defined]
    conn.cursor_id = cursor_key  # type: ignore[attr-defined]
# Event handler invoked after a statement has been executed on a cursor
@event.listens_for(Engine, "after_cursor_execute")
def after_cursor_execute(
    conn: Connection,
    cursor: Any,
    statement: str,
    parameters: dict[str, Any] | None,
    context: Connection,
    executemany: bool,
) -> None:
    """Log statements that took longer than one second, then clear the cursor marker.

    The handler only acts when ``conn.cursor_id`` exists and equals
    ``id(cursor)``. For a slow statement it emits three DEBUG records: the
    statement text, a row of ``*`` (one per whole elapsed second) and the
    elapsed time with millisecond precision. Afterwards the ``cursor_id``
    attribute is removed from ``conn``.
    """
    if not hasattr(conn, "cursor_id") or conn.cursor_id != id(cursor):
        return

    sql_text = get_statement_from_context(context)
    if sql_text:
        finished_at = time.time()
        # Falls back to "now" when no start time was recorded.
        started_at = getattr(conn, "query_start_time", time.time())
        elapsed = finished_at - started_at
        if elapsed > 1:
            # Keep the text up to the first occurrence of the statement's last
            # 16 characters, then re-append those characters.
            tail = sql_text[-16:]
            sql_text = sql_text.partition(tail)[0] + tail
            logger.debug(sql_text)
            logger.debug("*" * math.floor(elapsed))
            logger.debug(f"{elapsed:.3f} s")

    # Remove the cursor identifier once the statement has been handled
    if hasattr(conn, "cursor_id"):
        delattr(conn, "cursor_id")
def get_json_builder() -> tuple[Any, Any, Any]:
    """Return the JSON-building SQL helpers that match the engine's dialect.

    Returns:
        A 3-tuple ``(object_builder, array_aggregate, cast)``:

        * PostgreSQL: ``json_build_object``, ``json_agg`` and a cast of the
          argument to ``Text``.
        * SQLite: ``json_object``, ``json_group_array`` and an identity cast.
        * MySQL: ``json_object``, ``json_arrayagg`` and an identity cast.

    Raises:
        NotImplementedError: If the dialect is none of the above.
    """
    dialect = engine.dialect.name

    if dialect.startswith("postgres"):
        def cast_to_text(x: Any) -> Any:
            return func.cast(x, sqlalchemy.Text)

        return func.json_build_object, func.json_agg, cast_to_text

    def identity(x: Any) -> Any:
        return x

    if dialect.startswith("sqlite"):
        return func.json_object, func.json_group_array, identity
    if dialect.startswith("mysql"):
        # json_group_array is SQLite-only; MySQL's JSON array aggregate is JSON_ARRAYAGG.
        return func.json_object, func.json_arrayagg, identity

    msg = f"JSON builder not implemented for dialect {dialect}"
    raise NotImplementedError(msg)
# Resolve the dialect-specific JSON helpers once, at import time, and expose them as module names
json_builder, json_array_builder, json_cast = get_json_builder()
# Fetch all shouts, with authors preloaded
# This function is used for search indexing
def fetch_all_shouts(session: Session | None = None) -> list[Any]:
    """Fetch all published, non-deleted shouts with their authors preloaded.

    Args:
        session: Session to query with. When omitted, a session is created
            via ``local_session()`` and closed before returning; a session
            supplied by the caller is left open.

    Returns:
        The matching ``Shout`` rows, or an empty list if the query raised
        (the error is logged and not propagated).
    """
    # Imported inside the function, as in the original code
    # (presumably to avoid a circular import - confirm).
    from orm.shout import Shout

    owns_session = session is None
    if owns_session:
        session = local_session()

    try:
        # NULL checks must be SQL expressions. The Python `is` / `is not`
        # operators compare the column object itself and yield plain booleans,
        # so they never reach the database as IS NULL / IS NOT NULL.
        query = (
            session.query(Shout)
            .options(joinedload(Shout.authors))
            .filter(Shout.published_at.is_not(None), Shout.deleted_at.is_(None))
        )
        return query.all()
    except Exception as e:
        logger.exception(f"Error fetching shouts for search indexing: {e}")
        return []
    finally:
        if owns_session:
            # Suppress the SQLAlchemy deprecation warning for the synchronous session
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", DeprecationWarning)
                session.close()
def get_column_names_without_virtual(model_cls: type[BaseModel]) -> list[str]:
    """Return the model's column names, leaving out virtual columns.

    A column counts as virtual when it has a truthy ``_is_virtual`` attribute.
    If the model (or one of its columns) lacks an expected attribute, an
    empty list is returned.
    """
    names: list[str] = []
    try:
        for column in model_cls.__table__.columns:
            if getattr(column, "_is_virtual", False):
                continue
            names.append(column.name)
    except AttributeError:
        return []
    return names
def get_primary_key_columns(model_cls: type[BaseModel]) -> list[str]:
    """Return the names of the model's primary-key columns.

    Falls back to ``["id"]`` when the model does not expose the expected
    ``__table__.primary_key.columns`` structure.
    """
    try:
        pk_columns = model_cls.__table__.primary_key.columns
        names = []
        for column in pk_columns:
            names.append(column.name)
    except AttributeError:
        return ["id"]
    return names
def create_table_if_not_exists(engine: Engine, model_cls: type[BaseModel]) -> None:
    """Create the model's table unless the database already has it.

    Models without a ``__tablename__`` attribute are ignored. Existence is
    checked through an inspector created for ``engine``; when the table is
    missing it is created and the creation is logged at INFO level.
    """
    if not hasattr(model_cls, "__tablename__"):
        return

    db_inspector = inspect(engine)
    table_name = model_cls.__tablename__
    if db_inspector.has_table(table_name):
        return

    model_cls.__table__.create(engine)
    logger.info(f"Created table: {model_cls.__tablename__}")
def format_sql_warning(
message: str | Warning,
category: type[Warning],
filename: str,
lineno: int,
file: TextIOWrapper | None = None,
line: str | None = None,
) -> str:
"""Custom warning formatter for SQL warnings"""
return f"SQL Warning: {message}\n"
# Apply the custom warning formatter
def _set_warning_formatter() -> None:
"""Set custom warning formatter"""
import warnings
original_formatwarning = warnings.formatwarning
def custom_formatwarning(
message: Warning | str,
category: type[Warning],
filename: str,
lineno: int,
file: TextIOWrapper | None = None,
line: str | None = None,
) -> str:
return format_sql_warning(message, category, filename, lineno, file, line)
warnings.formatwarning = custom_formatwarning # type: ignore[assignment]
# Install the custom warning formatter as an import-time side effect of this module
_set_warning_formatter()
def upsert_on_duplicate(table: sqlalchemy.Table, **values: Any) -> sqlalchemy.sql.Insert:
    """Build an upsert statement (insert, or update on conflict on ``id``).

    Args:
        table: Target table.
        **values: Column values to insert; on a conflict on the ``id`` column
            the same values are written to the existing row.

    Returns:
        An ``INSERT ... ON CONFLICT (id) DO UPDATE`` statement on SQLite and
        PostgreSQL. A plain ``INSERT`` is returned for other dialects and
        when no values are given, because an empty update set is rejected by
        ``on_conflict_do_update``.
    """
    dialect = engine.dialect.name
    if values:
        if dialect == "sqlite":
            return insert(table).values(**values).on_conflict_do_update(index_elements=["id"], set_=values)
        if dialect == "postgresql":
            # PostgreSQL has its own Insert construct with the same ON CONFLICT API.
            from sqlalchemy.dialects.postgresql import insert as pg_insert

            return pg_insert(table).values(**values).on_conflict_do_update(index_elements=["id"], set_=values)
    # Other dialects: no conflict handling is implemented, fall back to a plain insert
    return table.insert().values(**values)
def get_sql_functions() -> dict[str, Any]:
    """Return database-specific SQL helpers keyed by name.

    Returns:
        A dict with three entries:

        * ``"now"``: an expression for the current timestamp
          (``datetime('now')`` on SQLite, ``now()`` otherwise).
        * ``"extract_epoch"``: a callable taking a column/expression and
          returning an epoch-seconds expression (``strftime('%s', x)`` on
          SQLite, ``EXTRACT(epoch FROM x)`` otherwise).
        * ``"coalesce"``: the ``coalesce`` function generator.
    """
    if engine.dialect.name == "sqlite":
        now = sqlalchemy.func.datetime("now")

        def extract_epoch(x: Any) -> Any:
            return sqlalchemy.func.strftime("%s", x)

    else:
        now = sqlalchemy.func.now()

        # Previously this entry was a fixed expression over text("?") rather
        # than a callable, so it could not be applied to a column the way the
        # SQLite variant can.
        def extract_epoch(x: Any) -> Any:
            return sqlalchemy.func.extract("epoch", x)

    return {
        "now": now,
        "extract_epoch": extract_epoch,
        "coalesce": sqlalchemy.func.coalesce,
    }
# noinspection PyUnusedLocal
def local_session(src: str = "") -> Session:
    """Return a new session bound to the module-level engine.

    The session is created with ``expire_on_commit=False``. ``src`` is
    accepted but not used.
    """
    session_options = {"bind": engine, "expire_on_commit": False}
    return Session(**session_options)
# Names exported by a star-import of this module
__all__ = ["engine", "local_session"]