tests-passed

This commit is contained in:
2025-07-31 18:55:59 +03:00
parent b7abb8d8a1
commit e7230ba63c
126 changed files with 8326 additions and 3207 deletions

View File

@@ -1,25 +1,20 @@
import logging
import math
import time
import traceback
import warnings
from io import TextIOWrapper
from typing import Any, TypeVar
from typing import Any, Type, TypeVar
import sqlalchemy
from sqlalchemy import create_engine, event, exc, func, inspect
from sqlalchemy.dialects.sqlite import insert
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.orm import Session, configure_mappers, joinedload
from sqlalchemy.orm import DeclarativeBase, Session, configure_mappers
from sqlalchemy.pool import StaticPool
from orm.base import BaseModel
from settings import DB_URL
from utils.logger import root_logger as logger
# Global variables
logger = logging.getLogger(__name__)
# Database configuration
engine = create_engine(DB_URL, echo=False, poolclass=StaticPool if "sqlite" in DB_URL else None)
ENGINE = engine # Backward compatibility alias
@@ -64,8 +59,8 @@ def get_statement_from_context(context: Connection) -> str | None:
try:
# Безопасное форматирование параметров
query = compiled_statement % compiled_parameters
except Exception as e:
logger.exception(f"Error formatting query: {e}")
except Exception:
logger.exception("Error formatting query")
else:
query = compiled_statement
if query:
@@ -130,41 +125,28 @@ def get_json_builder() -> tuple[Any, Any, Any]:
# Используем их в коде
json_builder, json_array_builder, json_cast = get_json_builder()
# Fetch all shouts, with authors preloaded
# This function is used for search indexing
def fetch_all_shouts(session: Session | None = None) -> list[Any]:
"""Fetch all published shouts for search indexing with authors preloaded"""
from orm.shout import Shout
close_session = False
if session is None:
session = local_session()
close_session = True
def create_table_if_not_exists(connection_or_engine: Connection | Engine, model_cls: Type[DeclarativeBase]) -> None:
"""Creates table for the given model if it doesn't exist"""
# If an Engine is passed, get a connection from it
connection = connection_or_engine.connect() if isinstance(connection_or_engine, Engine) else connection_or_engine
try:
# Fetch only published and non-deleted shouts with authors preloaded
query = (
session.query(Shout)
.options(joinedload(Shout.authors))
.filter(Shout.published_at is not None, Shout.deleted_at is None)
)
return query.all()
except Exception as e:
logger.exception(f"Error fetching shouts for search indexing: {e}")
return []
inspector = inspect(connection)
if not inspector.has_table(model_cls.__tablename__):
# Use SQLAlchemy's built-in table creation instead of manual SQL generation
from sqlalchemy.schema import CreateTable
create_stmt = CreateTable(model_cls.__table__) # type: ignore[arg-type]
connection.execute(create_stmt)
logger.info(f"Created table: {model_cls.__tablename__}")
finally:
if close_session:
# Подавляем SQLAlchemy deprecated warning для синхронной сессии
import warnings
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
session.close()
# If we created a connection from an Engine, close it
if isinstance(connection_or_engine, Engine):
connection.close()
def get_column_names_without_virtual(model_cls: type[BaseModel]) -> list[str]:
def get_column_names_without_virtual(model_cls: Type[DeclarativeBase]) -> list[str]:
"""Получает имена колонок модели без виртуальных полей"""
try:
column_names: list[str] = [
@@ -175,23 +157,6 @@ def get_column_names_without_virtual(model_cls: type[BaseModel]) -> list[str]:
return []
def get_primary_key_columns(model_cls: type[BaseModel]) -> list[str]:
"""Получает имена первичных ключей модели"""
try:
return [col.name for col in model_cls.__table__.primary_key.columns]
except AttributeError:
return ["id"]
def create_table_if_not_exists(engine: Engine, model_cls: type[BaseModel]) -> None:
"""Creates table for the given model if it doesn't exist"""
if hasattr(model_cls, "__tablename__"):
inspector = inspect(engine)
if not inspector.has_table(model_cls.__tablename__):
model_cls.__table__.create(engine)
logger.info(f"Created table: {model_cls.__tablename__}")
def format_sql_warning(
message: str | Warning,
category: type[Warning],
@@ -207,19 +172,11 @@ def format_sql_warning(
# Apply the custom warning formatter
def _set_warning_formatter() -> None:
"""Set custom warning formatter"""
import warnings
original_formatwarning = warnings.formatwarning
def custom_formatwarning(
message: Warning | str,
category: type[Warning],
filename: str,
lineno: int,
file: TextIOWrapper | None = None,
line: str | None = None,
message: str, category: type[Warning], filename: str, lineno: int, line: str | None = None
) -> str:
return format_sql_warning(message, category, filename, lineno, file, line)
return f"{category.__name__}: {message}\n"
warnings.formatwarning = custom_formatwarning # type: ignore[assignment]