Compare commits

..

3 Commits

SHA1        Message                     Date                        Checks
9f16ee022b  cors-applevel-fix           2025-06-02 23:59:09 +03:00  Deploy on push / deploy (push): failing after 40s
7bbb847eb1  tests, maintainance fixes   2025-05-16 09:22:53 +03:00  Deploy on push / deploy (push): successful in 1m36s
8a60bec73a  tests upgrade               2025-05-16 09:11:39 +03:00
33 changed files with 655 additions and 1449 deletions


@@ -29,16 +29,7 @@ jobs:
         if: github.ref == 'refs/heads/dev'
         uses: dokku/github-action@master
         with:
-          branch: 'main'
+          branch: 'dev'
           force: true
           git_remote_url: 'ssh://dokku@v2.discours.io:22/core'
           ssh_private_key: ${{ secrets.SSH_PRIVATE_KEY }}
-      - name: Push to dokku for staging branch
-        if: github.ref == 'refs/heads/staging'
-        uses: dokku/github-action@master
-        with:
-          branch: 'dev'
-          git_remote_url: 'ssh://dokku@staging.discours.io:22/core'
-          ssh_private_key: ${{ secrets.SSH_PRIVATE_KEY }}
-          git_push_flags: '--force'

.gitignore (6 changes)

@@ -128,9 +128,6 @@ dmypy.json
 .idea
 temp.*
-# Debug
-DEBUG.log
 discours.key
 discours.crt
 discours.pem
@@ -165,4 +162,5 @@ views.json
 *.crt
 *cache.json
 .cursor
+.devcontainer/
+node_modules/


@@ -3,7 +3,6 @@ FROM python:slim
 RUN apt-get update && apt-get install -y \
     postgresql-client \
     curl \
-    build-essential \
     && rm -rf /var/lib/apt/lists/*

 WORKDIR /app


@@ -1 +0,0 @@


@@ -1,119 +0,0 @@
import time

from sqlalchemy import (
    JSON,
    Boolean,
    Column,
    DateTime,
    ForeignKey,
    Integer,
    String,
    func,
)
from sqlalchemy.orm import relationship

from services.db import Base


class Permission(Base):
    __tablename__ = "permission"

    id = Column(String, primary_key=True, unique=True, nullable=False, default=None)
    resource = Column(String, nullable=False)
    operation = Column(String, nullable=False)


class Role(Base):
    __tablename__ = "role"

    id = Column(String, primary_key=True, unique=True, nullable=False, default=None)
    name = Column(String, nullable=False)
    permissions = relationship(Permission)


class AuthorizerUser(Base):
    __tablename__ = "authorizer_users"

    id = Column(String, primary_key=True, unique=True, nullable=False, default=None)
    key = Column(String)
    email = Column(String, unique=True)
    email_verified_at = Column(Integer)
    family_name = Column(String)
    gender = Column(String)
    given_name = Column(String)
    is_multi_factor_auth_enabled = Column(Boolean)
    middle_name = Column(String)
    nickname = Column(String)
    password = Column(String)
    phone_number = Column(String, unique=True)
    phone_number_verified_at = Column(Integer)
    # preferred_username = Column(String, nullable=False)
    picture = Column(String)
    revoked_timestamp = Column(Integer)
    roles = Column(String, default="author,reader")
    signup_methods = Column(String, default="magic_link_login")
    created_at = Column(Integer, default=lambda: int(time.time()))
    updated_at = Column(Integer, default=lambda: int(time.time()))


class UserRating(Base):
    __tablename__ = "user_rating"

    id = None
    rater: Column = Column(ForeignKey("user.id"), primary_key=True, index=True)
    user: Column = Column(ForeignKey("user.id"), primary_key=True, index=True)
    value: Column = Column(Integer)

    @staticmethod
    def init_table():
        pass


class UserRole(Base):
    __tablename__ = "user_role"

    id = None
    user = Column(ForeignKey("user.id"), primary_key=True, index=True)
    role = Column(ForeignKey("role.id"), primary_key=True, index=True)


class User(Base):
    __tablename__ = "user"

    default_user = None

    email = Column(String, unique=True, nullable=False, comment="Email")
    username = Column(String, nullable=False, comment="Login")
    password = Column(String, nullable=True, comment="Password")
    bio = Column(String, nullable=True, comment="Bio")  # status description
    about = Column(String, nullable=True, comment="About")  # long and formatted
    userpic = Column(String, nullable=True, comment="Userpic")
    name = Column(String, nullable=True, comment="Display name")
    slug = Column(String, unique=True, comment="User's slug")
    links = Column(JSON, nullable=True, comment="Links")
    oauth = Column(String, nullable=True)
    oid = Column(String, nullable=True)
    muted = Column(Boolean, default=False)
    confirmed = Column(Boolean, default=False)

    created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now(), comment="Created at")
    updated_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now(), comment="Updated at")
    last_seen = Column(DateTime(timezone=True), nullable=False, server_default=func.now(), comment="Was online at")
    deleted_at = Column(DateTime(timezone=True), nullable=True, comment="Deleted at")

    ratings = relationship(UserRating, foreign_keys=UserRating.user)
    roles = relationship(lambda: Role, secondary=UserRole.__tablename__)

    def get_permission(self):
        scope = {}
        for role in self.roles:
            for p in role.permissions:
                if p.resource not in scope:
                    scope[p.resource] = set()
                scope[p.resource].add(p.operation)
        print(scope)
        return scope


# if __name__ == "__main__":
#     print(User.get_permission(user_id=1))

cache/precache.py (12 changes)

@@ -77,15 +77,11 @@ async def precache_topics_followers(topic_id: int, session):

 async def precache_data():
     logger.info("precaching...")
-    logger.debug("Entering precache_data")
     try:
         key = "authorizer_env"
-        logger.debug(f"Fetching existing hash for key '{key}' from Redis")
         # cache reset
         value = await redis.execute("HGETALL", key)
-        logger.debug(f"Fetched value for '{key}': {value}")
         await redis.execute("FLUSHDB")
-        logger.debug("Redis database flushed")
         logger.info("redis: FLUSHDB")

         # Convert the dict into a list of arguments for HSET
@@ -101,27 +97,21 @@ async def precache_data():
             await redis.execute("HSET", key, *value)
             logger.info(f"redis hash '{key}' was restored")

-        logger.info("Beginning topic precache phase")
         with local_session() as session:
             # topics
             q = select(Topic).where(Topic.community == 1)
             topics = get_with_stat(q)
-            logger.info(f"Found {len(topics)} topics to precache")
             for topic in topics:
                 topic_dict = topic.dict() if hasattr(topic, "dict") else topic
-                logger.debug(f"Precaching topic id={topic_dict.get('id')}")
                 await cache_topic(topic_dict)
-                logger.debug(f"Cached topic id={topic_dict.get('id')}")
                 await asyncio.gather(
                     precache_topics_followers(topic_dict["id"], session),
                     precache_topics_authors(topic_dict["id"], session),
                 )
-                logger.debug(f"Finished precaching followers and authors for topic id={topic_dict.get('id')}")
             logger.info(f"{len(topics)} topics and their followings precached")

             # authors
             authors = get_with_stat(select(Author).where(Author.user.is_not(None)))
-            logger.info(f"Found {len(authors)} authors to precache")
             logger.info(f"{len(authors)} authors found in database")
             for author in authors:
                 if isinstance(author, Author):
@@ -129,12 +119,10 @@ async def precache_data():
                     author_id = profile.get("id")
                     user_id = profile.get("user", "").strip()
                     if author_id and user_id:
-                        logger.debug(f"Precaching author id={author_id}")
                         await cache_author(profile)
                         await asyncio.gather(
                             precache_authors_followers(author_id, session), precache_authors_follows(author_id, session)
                         )
-                        logger.debug(f"Finished precaching followers and follows for author id={author_id}")
                     else:
                         logger.error(f"fail caching {author}")
             logger.info(f"{len(authors)} authors and their followings precached")

cache/triggers.py (24 changes)

@@ -88,7 +88,11 @@ def after_reaction_handler(mapper, connection, target):
         with local_session() as session:
             shout = (
                 session.query(Shout)
-                .filter(Shout.id == shout_id, Shout.published_at.is_not(None), Shout.deleted_at.is_(None))
+                .filter(
+                    Shout.id == shout_id,
+                    Shout.published_at.is_not(None),
+                    Shout.deleted_at.is_(None),
+                )
                 .first()
             )
@@ -108,15 +112,27 @@ def events_register():
     event.listen(AuthorFollower, "after_insert", after_follower_handler)
     event.listen(AuthorFollower, "after_update", after_follower_handler)
-    event.listen(AuthorFollower, "after_delete", lambda *args: after_follower_handler(*args, is_delete=True))
+    event.listen(
+        AuthorFollower,
+        "after_delete",
+        lambda *args: after_follower_handler(*args, is_delete=True),
+    )

     event.listen(TopicFollower, "after_insert", after_follower_handler)
     event.listen(TopicFollower, "after_update", after_follower_handler)
-    event.listen(TopicFollower, "after_delete", lambda *args: after_follower_handler(*args, is_delete=True))
+    event.listen(
+        TopicFollower,
+        "after_delete",
+        lambda *args: after_follower_handler(*args, is_delete=True),
+    )

     event.listen(ShoutReactionsFollower, "after_insert", after_follower_handler)
     event.listen(ShoutReactionsFollower, "after_update", after_follower_handler)
-    event.listen(ShoutReactionsFollower, "after_delete", lambda *args: after_follower_handler(*args, is_delete=True))
+    event.listen(
+        ShoutReactionsFollower,
+        "after_delete",
+        lambda *args: after_follower_handler(*args, is_delete=True),
+    )

     event.listen(Reaction, "after_update", mark_for_revalidation)
     event.listen(Author, "after_update", mark_for_revalidation)
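The after_delete listeners above go through a lambda because SQLAlchemy invokes mapper-event handlers as fn(mapper, connection, target), so the is_delete flag has to be bound at registration time. A minimal runnable sketch of that pattern; the model and handler here are stand-ins, not the project's real classes:

import time

from sqlalchemy import Column, Integer, create_engine, event
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class FollowerDemo(Base):  # stand-in for AuthorFollower / TopicFollower
    __tablename__ = "follower_demo"
    id = Column(Integer, primary_key=True)


def after_follower_handler(mapper, connection, target, is_delete=False):
    # SQLAlchemy calls listeners with (mapper, connection, target); the extra
    # keyword lets one handler distinguish deletion (cache eviction) from
    # insert/update (cache refresh).
    print(f"follower {target.id} changed at {int(time.time())}, is_delete={is_delete}")


event.listen(FollowerDemo, "after_insert", after_follower_handler)
event.listen(FollowerDemo, "after_update", after_follower_handler)
# The lambda forwards the positional arguments while fixing is_delete=True:
event.listen(FollowerDemo, "after_delete", lambda *args: after_follower_handler(*args, is_delete=True))

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    row = FollowerDemo(id=1)
    session.add(row)
    session.commit()  # fires after_insert
    session.delete(row)
    session.commit()  # fires after_delete with is_delete=True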

main.py (81 changes)

@@ -7,6 +7,7 @@ from os.path import exists
 from ariadne import load_schema_from_path, make_executable_schema
 from ariadne.asgi import GraphQL
 from starlette.applications import Starlette
+from starlette.middleware import Middleware
 from starlette.middleware.cors import CORSMiddleware
 from starlette.requests import Request
 from starlette.responses import JSONResponse, Response
@@ -17,7 +18,7 @@ from cache.revalidator import revalidation_manager
 from services.exception import ExceptionHandlerMiddleware
 from services.redis import redis
 from services.schema import create_all_tables, resolvers
-from services.search import search_service, initialize_search_index
+from services.search import search_service
 from services.viewed import ViewedStorage
 from services.webhook import WebhookEndpoint, create_webhook_endpoint
 from settings import DEV_SERVER_PID_FILE_NAME, MODE
@@ -34,79 +35,24 @@ async def start():
         f.write(str(os.getpid()))
     print(f"[main] process started in {MODE} mode")


-async def check_search_service():
-    """Check if search service is available and log result"""
-    info = await search_service.info()
-    if info.get("status") in ["error", "unavailable"]:
-        print(f"[WARNING] Search service unavailable: {info.get('message', 'unknown reason')}")
-    else:
-        print(f"[INFO] Search service is available: {info}")
-
-
-# Helper to run precache with timeout and catch errors
-async def precache_with_timeout():
-    try:
-        await asyncio.wait_for(precache_data(), timeout=60)
-    except asyncio.TimeoutError:
-        print("[precache] Precache timed out after 60 seconds")
-    except Exception as e:
-        print(f"[precache] Error during precache: {e}")
-
-
-# indexing DB data
-# async def indexing():
-#     from services.db import fetch_all_shouts
-#     all_shouts = await fetch_all_shouts()
-#     await initialize_search_index(all_shouts)
-
-
 async def lifespan(_app):
     try:
-        print("[lifespan] Starting application initialization")
         create_all_tables()
-        # schedule precaching in background with timeout and error handling
-        asyncio.create_task(precache_with_timeout())
         await asyncio.gather(
             redis.connect(),
+            precache_data(),
             ViewedStorage.init(),
             create_webhook_endpoint(),
-            check_search_service(),
+            search_service.info(),
             start(),
             revalidation_manager.start(),
         )
-        print("[lifespan] Basic initialization complete")
-
-        # Add a delay before starting the intensive search indexing
-        print("[lifespan] Waiting for system stabilization before search indexing...")
-        await asyncio.sleep(10)  # 10-second delay to let the system stabilize
-
-        # Start search indexing as a background task with lower priority
-        asyncio.create_task(initialize_search_index_background())
         yield
     finally:
-        print("[lifespan] Shutting down application services")
         tasks = [redis.disconnect(), ViewedStorage.stop(), revalidation_manager.stop()]
         await asyncio.gather(*tasks, return_exceptions=True)
-        print("[lifespan] Shutdown complete")
-
-
-# Initialize search index in the background
-async def initialize_search_index_background():
-    """Run search indexing as a background task with low priority"""
-    try:
-        print("[search] Starting background search indexing process")
-        from services.db import fetch_all_shouts
-
-        # Get total count first (optional)
-        all_shouts = await fetch_all_shouts()
-        total_count = len(all_shouts) if all_shouts else 0
-        print(f"[search] Fetched {total_count} shouts for background indexing")
-
-        # Start the indexing process with the fetched shouts
-        print("[search] Beginning background search index initialization...")
-        await initialize_search_index(all_shouts)
-        print("[search] Background search index initialization complete")
-    except Exception as e:
-        print(f"[search] Error in background search indexing: {str(e)}")


 # Create the GraphQL instance
 graphql_app = GraphQL(schema, debug=True)
@@ -128,6 +74,24 @@ async def graphql_handler(request: Request):
         print(f"GraphQL error: {str(e)}")
         return JSONResponse({"error": str(e)}, status_code=500)

+
+middleware = [
+    # Start with error handling
+    Middleware(ExceptionHandlerMiddleware),
+    # CORS must come before the other middleware to handle preflight requests correctly
+    Middleware(
+        CORSMiddleware,
+        allow_origins=[
+            "https://localhost:3000",
+            "https://testing.discours.io",
+            "https://testing3.discours.io",
+            "https://discours.io",
+            "https://new.discours.io"
+        ],
+        allow_methods=["GET", "POST", "OPTIONS"],  # Explicitly include OPTIONS
+        allow_headers=["*"],
+        allow_credentials=True,
+    ),
+]

 # Update the route in Starlette
 app = Starlette(
@@ -135,6 +99,7 @@ app = Starlette(
         Route("/", graphql_handler, methods=["GET", "POST"]),
         Route("/new-author", WebhookEndpoint),
     ],
+    middleware=middleware,
     lifespan=lifespan,
     debug=True,
 )
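Since CORS handling now lives in the ASGI app (the middleware list above) rather than in nginx, preflight behavior can be smoke-tested in-process. A minimal sketch, assuming main.py is importable as shown; the expected values follow from the allow_origins and allow_methods configured in the diff, and httpx's ASGITransport drives the app without a server (it also skips lifespan, so no Redis is needed):

import asyncio

import httpx

from main import app  # assumption: the module imports cleanly outside the server


async def check_preflight() -> None:
    transport = httpx.ASGITransport(app=app)
    async with httpx.AsyncClient(transport=transport, base_url="https://discours.io") as client:
        response = await client.request(
            "OPTIONS",
            "/",
            headers={
                "Origin": "https://new.discours.io",      # listed in allow_origins
                "Access-Control-Request-Method": "POST",  # listed in allow_methods
            },
        )
        # CORSMiddleware answers the preflight itself with 200 and echoes the
        # specific origin because allow_credentials=True forbids a wildcard.
        assert response.status_code == 200
        assert response.headers["access-control-allow-origin"] == "https://new.discours.io"


if __name__ == "__main__":
    asyncio.run(check_preflight())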


@@ -1,5 +1,5 @@
 log_format custom '$remote_addr - $remote_user [$time_local] "$request" '
-                  'origin=$http_origin allow_origin=$allow_origin status=$status '
+                  'origin=$http_origin status=$status '
                   '"$http_referer" "$http_user_agent"';

 {{ $proxy_settings := "proxy_http_version 1.1; proxy_set_header Upgrade $http_upgrade; proxy_set_header Connection $http_connection; proxy_set_header Host $http_host; proxy_set_header X-Request-Start $msec;" }}
@@ -49,34 +49,6 @@ server {
     {{ $proxy_settings }}
     {{ $gzip_settings }}

-    # Handle CORS for OPTIONS method
-    if ($request_method = 'OPTIONS') {
-        add_header 'Access-Control-Allow-Origin' $allow_origin always;
-        add_header 'Access-Control-Allow-Methods' 'POST, GET, OPTIONS';
-        add_header 'Access-Control-Allow-Headers' 'Content-Type, Authorization' always;
-        add_header 'Access-Control-Allow-Credentials' 'true' always;
-        add_header 'Access-Control-Max-Age' 1728000;
-        add_header 'Content-Type' 'text/plain; charset=utf-8';
-        add_header 'Content-Length' 0;
-        return 204;
-    }
-
-    # Handle CORS for POST method
-    if ($request_method = 'POST') {
-        add_header 'Access-Control-Allow-Origin' $allow_origin always;
-        add_header 'Access-Control-Allow-Methods' 'POST, GET, OPTIONS' always;
-        add_header 'Access-Control-Allow-Headers' 'Content-Type, Authorization' always;
-        add_header 'Access-Control-Allow-Credentials' 'true' always;
-    }
-
-    # Handle CORS for GET method
-    if ($request_method = 'GET') {
-        add_header 'Access-Control-Allow-Origin' $allow_origin always;
-        add_header 'Access-Control-Allow-Methods' 'POST, GET, OPTIONS' always;
-        add_header 'Access-Control-Allow-Headers' 'Content-Type, Authorization' always;
-        add_header 'Access-Control-Allow-Credentials' 'true' always;
-    }
-
     proxy_cache my_cache;
     proxy_cache_revalidate on;
     proxy_cache_min_uses 2;
@@ -96,13 +68,6 @@ server {

     location ~* \.(mp3|wav|ogg|flac|aac|aif|webm)$ {
         proxy_pass http://{{ $.APP }}-{{ $upstream_port }};
-
-        if ($request_method = 'GET') {
-            add_header 'Access-Control-Allow-Origin' $allow_origin always;
-            add_header 'Access-Control-Allow-Methods' 'GET, POST, OPTIONS' always;
-            add_header 'Access-Control-Allow-Headers' 'DNT,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range,Authorization' always;
-            add_header 'Access-Control-Expose-Headers' 'Content-Length,Content-Range' always;
-            add_header 'Access-Control-Allow-Credentials' 'true' always;
-        }
     }


@@ -90,7 +90,6 @@ class Author(Base):
     Author model in the system.

     Attributes:
-        user (str): User identifier in the authorization system
         name (str): Display name
        slug (str): Unique string identifier
        bio (str): Short biography / status
@@ -105,8 +104,6 @@ class Author(Base):
     __tablename__ = "author"

-    user = Column(String)  # unbounded link with authorizer's User type
-
     name = Column(String, nullable=True, comment="Display name")
     slug = Column(String, unique=True, comment="Author's slug")
     bio = Column(String, nullable=True, comment="Bio")  # status description
@@ -124,12 +121,14 @@ class Author(Base):
     # Define the indexes
     __table_args__ = (
+        # Index for fast lookups by name
+        Index("idx_author_name", "name"),
         # Index for fast lookups by slug
         Index("idx_author_slug", "slug"),
-        # Index for fast lookups by user identifier
-        Index("idx_author_user", "user"),
         # Index for filtering non-deleted authors
-        Index("idx_author_deleted_at", "deleted_at", postgresql_where=deleted_at.is_(None)),
+        Index(
+            "idx_author_deleted_at", "deleted_at", postgresql_where=deleted_at.is_(None)
+        ),
         # Index for sorting by creation time (for new authors)
         Index("idx_author_created_at", "created_at"),
         # Index for sorting by last-seen time


@@ -6,7 +6,6 @@ from sqlalchemy.orm import relationship
 from orm.author import Author
 from orm.topic import Topic
 from services.db import Base
-from orm.shout import Shout


 class DraftTopic(Base):


@@ -71,34 +71,6 @@ class ShoutAuthor(Base):

 class Shout(Base):
     """
     Publication in the system.
-
-    Attributes:
-        body (str)
-        slug (str)
-        cover (str): "Cover image url"
-        cover_caption (str): "Cover image alt caption"
-        lead (str)
-        title (str)
-        subtitle (str)
-        layout (str)
-        media (dict)
-        authors (list[Author])
-        topics (list[Topic])
-        reactions (list[Reaction])
-        lang (str)
-        version_of (int)
-        oid (str)
-        seo (str): JSON
-        draft (int)
-        created_at (int)
-        updated_at (int)
-        published_at (int)
-        featured_at (int)
-        deleted_at (int)
-        created_by (int)
-        updated_by (int)
-        deleted_by (int)
-        community (int)
     """

     __tablename__ = "shout"

pyproject.toml (new file, 2 changes)

@@ -0,0 +1,2 @@
[tool.ruff]
line-length = 108


@@ -13,10 +13,6 @@ starlette
 gql
 ariadne
 granian
-# NLP and search
-httpx
 orjson
 pydantic
 trafilatura


@@ -8,7 +8,6 @@ from resolvers.author import (  # search_authors,
     get_author_id,
     get_authors_all,
     load_authors_by,
-    load_authors_search,
     update_author,
 )
 from resolvers.community import get_communities_all, get_community
@@ -74,7 +73,6 @@ __all__ = [
     "get_author_follows_authors",
     "get_authors_all",
     "load_authors_by",
-    "load_authors_search",
     "update_author",
     ## "search_authors",
     # community


@@ -20,7 +20,6 @@ from services.auth import login_required
 from services.db import local_session
 from services.redis import redis
 from services.schema import mutation, query
-from services.search import search_service
 from utils.logger import root_logger as logger

 DEFAULT_COMMUNITIES = [1]
@@ -302,46 +301,6 @@ async def load_authors_by(_, _info, by, limit, offset):
     return await get_authors_with_stats(limit, offset, by)


-@query.field("load_authors_search")
-async def load_authors_search(_, info, text: str, limit: int = 10, offset: int = 0):
-    """
-    Resolver for searching authors by text. Works with the txt-ai search endpoint.
-
-    Args:
-        text: Search text
-        limit: Maximum number of authors to return
-        offset: Offset for pagination
-
-    Returns:
-        list: List of authors matching the search criteria
-    """
-    # Get author IDs from search engine (already sorted by relevance)
-    search_results = await search_service.search_authors(text, limit, offset)
-    if not search_results:
-        return []
-
-    author_ids = [result.get("id") for result in search_results if result.get("id")]
-    if not author_ids:
-        return []
-
-    # Fetch full author objects from DB
-    with local_session() as session:
-        # Simple query to get authors by IDs - no need for stats here
-        authors_query = select(Author).filter(Author.id.in_(author_ids))
-        db_authors = session.execute(authors_query).scalars().all()
-        if not db_authors:
-            return []
-
-        # Create a dictionary for quick lookup
-        authors_dict = {str(author.id): author for author in db_authors}
-
-        # Keep the order from search results (maintains the relevance sorting)
-        ordered_authors = [authors_dict[author_id] for author_id in author_ids if author_id in authors_dict]
-        return ordered_authors
-
-
 def get_author_id_from(slug="", user=None, author_id=None):
     try:
         author_id = None


@@ -1,25 +0,0 @@
{
    "include": [
        "."
    ],
    "exclude": [
        "**/node_modules",
        "**/__pycache__",
        "**/.*"
    ],
    "defineConstant": {
        "DEBUG": true
    },
    "venvPath": ".",
    "venv": ".venv",
    "pythonVersion": "3.11",
    "typeCheckingMode": "strict",
    "reportMissingImports": true,
    "reportMissingTypeStubs": false,
    "reportUnknownMemberType": false,
    "reportUnknownParameterType": false,
    "reportUnknownVariableType": false,
    "reportUnknownArgumentType": false,
    "reportPrivateUsage": false,
    "reportUntypedFunctionDecorator": false
}


@@ -10,7 +10,7 @@ from orm.shout import Shout, ShoutAuthor, ShoutTopic
 from orm.topic import Topic
 from services.db import json_array_builder, json_builder, local_session
 from services.schema import query
-from services.search import search_text, get_search_count
+from services.search import search_text
 from services.viewed import ViewedStorage
 from utils.logger import root_logger as logger
@@ -187,10 +187,12 @@ def get_shouts_with_links(info, q, limit=20, offset=0):
     """
     shouts = []
     try:
+        # logger.info(f"Starting get_shouts_with_links with limit={limit}, offset={offset}")
         q = q.limit(limit).offset(offset)

         with local_session() as session:
             shouts_result = session.execute(q).all()
+            # logger.info(f"Got {len(shouts_result) if shouts_result else 0} shouts from query")

             if not shouts_result:
                 logger.warning("No shouts found in query result")
@@ -201,6 +203,7 @@ def get_shouts_with_links(info, q, limit=20, offset=0):
                 shout = None
                 if hasattr(row, "Shout"):
                     shout = row.Shout
+                    # logger.debug(f"Processing shout#{shout.id} at index {idx}")
                 if shout:
                     shout_id = int(f"{shout.id}")
                     shout_dict = shout.dict()
@@ -228,16 +231,20 @@ def get_shouts_with_links(info, q, limit=20, offset=0):
                     topics = None
                     if has_field(info, "topics") and hasattr(row, "topics"):
                         topics = orjson.loads(row.topics) if isinstance(row.topics, str) else row.topics
+                        # logger.debug(f"Shout#{shout_id} topics: {topics}")
                         shout_dict["topics"] = topics

                     if has_field(info, "main_topic"):
                         main_topic = None
                         if hasattr(row, "main_topic"):
+                            # logger.debug(f"Raw main_topic for shout#{shout_id}: {row.main_topic}")
                             main_topic = (
                                 orjson.loads(row.main_topic) if isinstance(row.main_topic, str) else row.main_topic
                             )
+                            # logger.debug(f"Parsed main_topic for shout#{shout_id}: {main_topic}")

                         if not main_topic and topics and len(topics) > 0:
+                            # logger.info(f"No main_topic found for shout#{shout_id}, using first topic from list")
                             main_topic = {
                                 "id": topics[0]["id"],
                                 "title": topics[0]["title"],
@@ -245,8 +252,10 @@ def get_shouts_with_links(info, q, limit=20, offset=0):
                                 "is_main": True,
                             }
                         elif not main_topic:
+                            logger.warning(f"No main_topic and no topics found for shout#{shout_id}")
                             main_topic = {"id": 0, "title": "no topic", "slug": "notopic", "is_main": True}
                         shout_dict["main_topic"] = main_topic
+                        # logger.debug(f"Final main_topic for shout#{shout_id}: {main_topic}")

                     if has_field(info, "authors") and hasattr(row, "authors"):
                         shout_dict["authors"] = (
@@ -273,6 +282,7 @@ def get_shouts_with_links(info, q, limit=20, offset=0):
         logger.error(f"Fatal error in get_shouts_with_links: {e}", exc_info=True)
         raise
     finally:
+        logger.info(f"Returning {len(shouts)} shouts from get_shouts_with_links")
         return shouts
@@ -391,49 +401,33 @@ async def load_shouts_search(_, info, text, options):
     """
     limit = options.get("limit", 10)
     offset = options.get("offset", 0)
     if isinstance(text, str) and len(text) > 2:
-        # Get search results with pagination
         results = await search_text(text, limit, offset)
-        if not results:
-            logger.info(f"No search results found for '{text}'")
-            return []
-
-        # Extract IDs in the order from the search engine
-        hits_ids = [str(sr.get("id")) for sr in results if sr.get("id")]
-
-        # Query DB for only the IDs in the current page
-        q = query_with_stat(info)
+        scores = {}
+        hits_ids = []
+        for sr in results:
+            shout_id = sr.get("id")
+            if shout_id:
+                shout_id = str(shout_id)
+                scores[shout_id] = sr.get("score")
+                hits_ids.append(shout_id)
+
+        q = (
+            query_with_stat(info)
+            if has_field(info, "stat")
+            else select(Shout).filter(and_(Shout.published_at.is_not(None), Shout.deleted_at.is_(None)))
+        )
         q = q.filter(Shout.id.in_(hits_ids))
-        q = apply_filters(q, options.get("filters", {}))
-
-        shouts = get_shouts_with_links(info, q, len(hits_ids), 0)
-
-        # Reorder shouts to match the order from hits_ids
-        shouts_dict = {str(shout['id']): shout for shout in shouts}
-        ordered_shouts = [shouts_dict[shout_id] for shout_id in hits_ids if shout_id in shouts_dict]
-
-        return ordered_shouts
-
+        q = apply_filters(q, options)
+        q = apply_sorting(q, options)
+        shouts = get_shouts_with_links(info, q, limit, offset)
+        for shout in shouts:
+            shout.score = scores[f"{shout.id}"]
+        shouts.sort(key=lambda x: x.score, reverse=True)
+        return shouts
     return []


-@query.field("get_search_results_count")
-async def get_search_results_count(_, info, text):
-    """
-    Returns the total count of search results for a search query.
-
-    :param _: Root query object (unused)
-    :param info: GraphQL context information
-    :param text: Search query text
-    :return: Total count of results
-    """
-    if isinstance(text, str) and len(text) > 2:
-        count = await get_search_count(text)
-        return {"count": count}
-    return {"count": 0}
-
-
 @query.field("load_shouts_unrated")
 async def load_shouts_unrated(_, info, options):
     """


@@ -4,7 +4,7 @@ type Query {
     get_author_id(user: String!): Author
     get_authors_all: [Author]
     load_authors_by(by: AuthorsBy!, limit: Int, offset: Int): [Author]
-    load_authors_search(text: String!, limit: Int, offset: Int): [Author!] # Search for authors by name or bio
+    # search_authors(what: String!): [Author]

     # community
     get_community: Community
@@ -33,7 +33,6 @@ type Query {
     get_shout(slug: String, shout_id: Int): Shout
     load_shouts_by(options: LoadShoutsOptions): [Shout]
     load_shouts_search(text: String!, options: LoadShoutsOptions): [SearchResult]
-    get_search_results_count(text: String!): CountResult!
     load_shouts_bookmarked(options: LoadShoutsOptions): [Shout]

     # rating


@@ -213,7 +213,6 @@ type CommonResult {
 }

 type SearchResult {
-    id: Int!
     slug: String!
     title: String!
     cover: String
@@ -281,7 +280,3 @@ type MyRateComment {
     my_rate: ReactionKind
 }
-
-type CountResult {
-    count: Int!
-}

services/__init__.py (new file, 1 change)

@@ -0,0 +1 @@
# This file makes services a Python package


@@ -19,7 +19,7 @@ from sqlalchemy import (
     inspect,
     text,
 )
-from sqlalchemy.orm import Session, configure_mappers, declarative_base, joinedload
+from sqlalchemy.orm import Session, configure_mappers, declarative_base
 from sqlalchemy.sql.schema import Table

 from settings import DB_URL
@@ -259,32 +259,3 @@ def get_json_builder():
 # Use them in the code
 json_builder, json_array_builder, json_cast = get_json_builder()
-
-
-# Fetch all shouts, with authors preloaded
-# This function is used for search indexing
-async def fetch_all_shouts(session=None):
-    """Fetch all published shouts for search indexing with authors preloaded"""
-    from orm.shout import Shout
-
-    close_session = False
-    if session is None:
-        session = local_session()
-        close_session = True
-
-    try:
-        # Fetch only published and non-deleted shouts with authors preloaded
-        query = session.query(Shout).options(
-            joinedload(Shout.authors)
-        ).filter(
-            Shout.published_at.is_not(None),
-            Shout.deleted_at.is_(None)
-        )
-        shouts = query.all()
-        return shouts
-    except Exception as e:
-        logger.error(f"Error fetching shouts for search indexing: {e}")
-        return []
-    finally:
-        if close_session:
-            session.close()


@@ -88,7 +88,7 @@ async def notify_draft(draft_data, action: str = "publish"):
         "subtitle": getattr(draft_data, "subtitle", None),
         "media": getattr(draft_data, "media", None),
         "created_at": getattr(draft_data, "created_at", None),
-        "updated_at": getattr(draft_data, "updated_at", None)
+        "updated_at": getattr(draft_data, "updated_at", None),
     }

     # If related attributes were passed, add them
@@ -100,7 +100,12 @@ async def notify_draft(draft_data, action: str = "publish"):
     if hasattr(draft_data, "authors") and draft_data.authors is not None:
         draft_payload["authors"] = [
-            {"id": a.id, "name": a.name, "slug": a.slug, "pic": getattr(a, "pic", None)}
+            {
+                "id": a.id,
+                "name": a.name,
+                "slug": a.slug,
+                "pic": getattr(a, "pic", None),
+            }
             for a in draft_data.authors
         ]


@@ -0,0 +1,19 @@
{
"include": ["."],
"exclude": ["**/node_modules", "**/__pycache__", "**/.*"],
"defineConstant": {
"DEBUG": true
},
"venvPath": ".",
"venv": ".venv",
"pythonVersion": "3.11",
"typeCheckingMode": "strict",
"reportMissingImports": true,
"reportMissingTypeStubs": false,
"reportUnknownMemberType": false,
"reportUnknownParameterType": false,
"reportUnknownVariableType": false,
"reportUnknownArgumentType": false,
"reportPrivateUsage": false,
"reportUntypedFunctionDecorator": false
}


@@ -29,19 +29,12 @@ async def request_graphql_data(gql, url=AUTH_URL, headers=None):
         async with httpx.AsyncClient() as client:
             response = await client.post(url, json=gql, headers=headers)
             if response.status_code == 200:
-                # Check if the response has content before parsing
-                if response.content and len(response.content.strip()) > 0:
-                    try:
-                        data = response.json()
-                        errors = data.get("errors")
-                        if errors:
-                            logger.error(f"{url} response: {data}")
-                        else:
-                            return data
-                    except Exception as json_err:
-                        logger.error(f"JSON decode error: {json_err}, Response content: {response.text[:100]}")
-                else:
-                    logger.error(f"{url}: Response is empty")
+                data = response.json()
+                errors = data.get("errors")
+                if errors:
+                    logger.error(f"{url} response: {data}")
+                else:
+                    return data
             else:
                 logger.error(f"{url}: {response.status_code} {response.text}")
     except Exception as _e:

File diff suppressed because it is too large.

tests/auth/conftest.py (new file, 25 changes)

@@ -0,0 +1,25 @@
import pytest
from typing import Dict


@pytest.fixture
def oauth_settings() -> Dict[str, Dict[str, str]]:
    """Test OAuth settings"""
    return {
        "GOOGLE": {"id": "test_google_id", "key": "test_google_secret"},
        "GITHUB": {"id": "test_github_id", "key": "test_github_secret"},
        "FACEBOOK": {"id": "test_facebook_id", "key": "test_facebook_secret"},
    }


@pytest.fixture
def frontend_url() -> str:
    """Frontend URL for tests"""
    return "https://localhost:3000"


@pytest.fixture(autouse=True)
def mock_settings(monkeypatch, oauth_settings, frontend_url):
    """Override settings for tests"""
    monkeypatch.setattr("auth.oauth.OAUTH_CLIENTS", oauth_settings)
    monkeypatch.setattr("auth.oauth.FRONTEND_URL", frontend_url)

tests/auth/test_oauth.py (new file, 222 changes)

@@ -0,0 +1,222 @@
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from starlette.responses import JSONResponse, RedirectResponse

from auth.oauth import get_user_profile, oauth_login, oauth_callback

# Override settings for the tests
with (
    patch("auth.oauth.FRONTEND_URL", "https://localhost:3000"),
    patch(
        "auth.oauth.OAUTH_CLIENTS",
        {
            "GOOGLE": {"id": "test_google_id", "key": "test_google_secret"},
            "GITHUB": {"id": "test_github_id", "key": "test_github_secret"},
            "FACEBOOK": {"id": "test_facebook_id", "key": "test_facebook_secret"},
        },
    ),
):

    @pytest.fixture
    def mock_request():
        """Fixture for a mocked request"""
        request = MagicMock()
        request.session = {}
        request.path_params = {}
        request.query_params = {}
        return request

    @pytest.fixture
    def mock_oauth_client():
        """Fixture for a mocked OAuth client"""
        client = AsyncMock()
        client.authorize_redirect = AsyncMock()
        client.authorize_access_token = AsyncMock()
        client.get = AsyncMock()
        return client

    @pytest.mark.asyncio
    async def test_get_user_profile_google():
        """Test building a profile from Google"""
        client = AsyncMock()
        token = {
            "userinfo": {
                "sub": "123",
                "email": "test@gmail.com",
                "name": "Test User",
                "picture": "https://lh3.googleusercontent.com/photo=s96",
            }
        }

        profile = await get_user_profile("google", client, token)

        assert profile["id"] == "123"
        assert profile["email"] == "test@gmail.com"
        assert profile["name"] == "Test User"
        assert profile["picture"] == "https://lh3.googleusercontent.com/photo=s600"

    @pytest.mark.asyncio
    async def test_get_user_profile_github():
        """Test building a profile from GitHub"""
        client = AsyncMock()
        client.get.side_effect = [
            MagicMock(
                json=lambda: {
                    "id": 456,
                    "login": "testuser",
                    "name": "Test User",
                    "avatar_url": "https://github.com/avatar.jpg",
                }
            ),
            MagicMock(
                json=lambda: [
                    {"email": "other@github.com", "primary": False},
                    {"email": "test@github.com", "primary": True},
                ]
            ),
        ]

        profile = await get_user_profile("github", client, {})

        assert profile["id"] == "456"
        assert profile["email"] == "test@github.com"
        assert profile["name"] == "Test User"
        assert profile["picture"] == "https://github.com/avatar.jpg"

    @pytest.mark.asyncio
    async def test_get_user_profile_facebook():
        """Test building a profile from Facebook"""
        client = AsyncMock()
        client.get.return_value = MagicMock(
            json=lambda: {
                "id": "789",
                "name": "Test User",
                "email": "test@facebook.com",
                "picture": {"data": {"url": "https://facebook.com/photo.jpg"}},
            }
        )

        profile = await get_user_profile("facebook", client, {})

        assert profile["id"] == "789"
        assert profile["email"] == "test@facebook.com"
        assert profile["name"] == "Test User"
        assert profile["picture"] == "https://facebook.com/photo.jpg"

    @pytest.mark.asyncio
    async def test_oauth_login_success(mock_request, mock_oauth_client):
        """Test a successful start of OAuth authorization"""
        mock_request.path_params["provider"] = "google"

        # Configure the mock for authorize_redirect
        redirect_response = RedirectResponse(url="http://example.com")
        mock_oauth_client.authorize_redirect.return_value = redirect_response

        with patch("auth.oauth.oauth.create_client", return_value=mock_oauth_client):
            response = await oauth_login(mock_request)

            assert isinstance(response, RedirectResponse)
            assert mock_request.session["provider"] == "google"
            assert "code_verifier" in mock_request.session
            assert "state" in mock_request.session

            mock_oauth_client.authorize_redirect.assert_called_once()

    @pytest.mark.asyncio
    async def test_oauth_login_invalid_provider(mock_request):
        """Test with an invalid provider"""
        mock_request.path_params["provider"] = "invalid"

        response = await oauth_login(mock_request)

        assert isinstance(response, JSONResponse)
        assert response.status_code == 400
        assert "Invalid provider" in response.body.decode()

    @pytest.mark.asyncio
    async def test_oauth_callback_success(mock_request, mock_oauth_client):
        """Test a successful OAuth callback"""
        mock_request.session = {
            "provider": "google",
            "code_verifier": "test_verifier",
            "state": "test_state",
        }
        mock_request.query_params["state"] = "test_state"

        mock_oauth_client.authorize_access_token.return_value = {
            "userinfo": {"sub": "123", "email": "test@gmail.com", "name": "Test User"}
        }

        with (
            patch("auth.oauth.oauth.create_client", return_value=mock_oauth_client),
            patch("auth.oauth.local_session") as mock_session,
            patch("auth.oauth.TokenStorage.create_session", return_value="test_token"),
        ):
            # Mock the database session
            session = MagicMock()
            session.query.return_value.filter.return_value.first.return_value = None
            mock_session.return_value.__enter__.return_value = session

            response = await oauth_callback(mock_request)

            assert isinstance(response, RedirectResponse)
            assert response.status_code == 307
            assert "auth/success" in response.headers["location"]

            # Check the cookies
            cookies = response.headers.getlist("set-cookie")
            assert any("session_token=test_token" in cookie for cookie in cookies)
            assert any("httponly" in cookie.lower() for cookie in cookies)
            assert any("secure" in cookie.lower() for cookie in cookies)

            # Check that the session was cleaned up
            assert "code_verifier" not in mock_request.session
            assert "provider" not in mock_request.session
            assert "state" not in mock_request.session

    @pytest.mark.asyncio
    async def test_oauth_callback_invalid_state(mock_request):
        """Test with an invalid state parameter"""
        mock_request.session = {"provider": "google", "state": "correct_state"}
        mock_request.query_params["state"] = "wrong_state"

        response = await oauth_callback(mock_request)

        assert isinstance(response, JSONResponse)
        assert response.status_code == 400
        assert "Invalid state" in response.body.decode()

    @pytest.mark.asyncio
    async def test_oauth_callback_existing_user(mock_request, mock_oauth_client):
        """Test the OAuth callback with an existing user"""
        mock_request.session = {
            "provider": "google",
            "code_verifier": "test_verifier",
            "state": "test_state",
        }
        mock_request.query_params["state"] = "test_state"

        mock_oauth_client.authorize_access_token.return_value = {
            "userinfo": {"sub": "123", "email": "test@gmail.com", "name": "Test User"}
        }

        with (
            patch("auth.oauth.oauth.create_client", return_value=mock_oauth_client),
            patch("auth.oauth.local_session") as mock_session,
            patch("auth.oauth.TokenStorage.create_session", return_value="test_token"),
        ):
            # Mock an existing user
            existing_user = MagicMock()
            session = MagicMock()
            session.query.return_value.filter.return_value.first.return_value = existing_user
            mock_session.return_value.__enter__.return_value = session

            response = await oauth_callback(mock_request)

            assert isinstance(response, RedirectResponse)
            assert response.status_code == 307

            # Check that the existing user was updated
            assert existing_user.name == "Test User"
            assert existing_user.oauth == "google:123"
            assert existing_user.email_verified is True


@@ -0,0 +1,9 @@
"""Test settings for OAuth"""

FRONTEND_URL = "https://localhost:3000"

OAUTH_CLIENTS = {
    "GOOGLE": {"id": "test_google_id", "key": "test_google_secret"},
    "GITHUB": {"id": "test_github_id", "key": "test_github_secret"},
    "FACEBOOK": {"id": "test_facebook_id", "key": "test_facebook_secret"},
}


@@ -1,17 +1,7 @@
 import asyncio
-import os

 import pytest
-from sqlalchemy import create_engine
-from sqlalchemy.orm import Session
-from starlette.testclient import TestClient

-from main import app
-from services.db import Base
 from services.redis import redis
+from tests.test_config import get_test_client

-# Use SQLite for testing
-TEST_DB_URL = "sqlite:///test.db"


 @pytest.fixture(scope="session")
@@ -23,38 +13,36 @@ def event_loop():

 @pytest.fixture(scope="session")
-def test_engine():
-    """Create a test database engine."""
-    engine = create_engine(TEST_DB_URL)
-    Base.metadata.create_all(engine)
-    yield engine
-    Base.metadata.drop_all(engine)
-    os.remove("test.db")
+def test_app():
+    """Create a test client and session factory."""
+    client, SessionLocal = get_test_client()
+    return client, SessionLocal


 @pytest.fixture
-def db_session(test_engine):
+def db_session(test_app):
     """Create a new database session for a test."""
-    connection = test_engine.connect()
-    transaction = connection.begin()
-    session = Session(bind=connection)
+    _, SessionLocal = test_app
+    session = SessionLocal()
     yield session
-    session.rollback()
     session.close()
-    transaction.rollback()
-    connection.close()


+@pytest.fixture
+def test_client(test_app):
+    """Get the test client."""
+    client, _ = test_app
+    return client


 @pytest.fixture
 async def redis_client():
     """Create a test Redis client."""
     await redis.connect()
+    await redis.flushall()  # Flush Redis before each test
     yield redis
+    await redis.flushall()  # Flush after the test
     await redis.disconnect()
-
-
-@pytest.fixture
-def test_client():
-    """Create a TestClient instance."""
-    return TestClient(app)

tests/test_config.py (new file, 67 changes)

@@ -0,0 +1,67 @@
"""
Test configuration
"""

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.testclient import TestClient

# Use in-memory SQLite for tests
TEST_DB_URL = "sqlite:///:memory:"


class DatabaseMiddleware(BaseHTTPMiddleware):
    """Middleware for injecting a DB session"""

    def __init__(self, app, session_maker):
        super().__init__(app)
        self.session_maker = session_maker

    async def dispatch(self, request, call_next):
        session = self.session_maker()
        request.state.db = session
        try:
            response = await call_next(request)
        finally:
            session.close()
        return response


def create_test_app():
    """Create a test Starlette application."""
    from services.db import Base

    # Create the engine and tables
    engine = create_engine(
        TEST_DB_URL,
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
        echo=False,
    )
    Base.metadata.drop_all(bind=engine)
    Base.metadata.create_all(bind=engine)

    # Create the session factory
    SessionLocal = sessionmaker(bind=engine)

    # Create the session middleware
    middleware = [Middleware(DatabaseMiddleware, session_maker=SessionLocal)]

    # Create the test application
    app = Starlette(
        debug=True,
        middleware=middleware,
        routes=[],  # Test routes can be added here if needed
    )

    return app, SessionLocal


def get_test_client():
    """Get a test client with initialized database."""
    app, SessionLocal = create_test_app()
    return TestClient(app), SessionLocal
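A quick sketch of how these helpers compose in a test module; the test body is hypothetical, and since create_test_app registers no routes, any path should come back 404:

from tests.test_config import get_test_client


def test_smoke():
    client, SessionLocal = get_test_client()
    session = SessionLocal()
    try:
        # routes=[] in create_test_app, so every request should 404
        assert client.get("/anything").status_code == 404
    finally:
        session.close()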


@@ -53,7 +53,11 @@ async def test_create_reaction(test_client, db_session, test_setup):
                 }
             """,
             "variables": {
-                "reaction": {"shout": test_setup["shout"].id, "kind": ReactionKind.LIKE.value, "body": "Great post!"}
+                "reaction": {
+                    "shout": test_setup["shout"].id,
+                    "kind": ReactionKind.LIKE.value,
+                    "body": "Great post!",
+                }
             },
         },
     )


@@ -1,70 +0,0 @@
from datetime import datetime, timedelta

import pytest
from pydantic import ValidationError

from auth.validations import (
    AuthInput,
    AuthResponse,
    TokenPayload,
    UserRegistrationInput,
)


class TestAuthValidations:
    def test_auth_input(self):
        """Test basic auth input validation"""
        # Valid case
        auth = AuthInput(user_id="123", username="testuser", token="1234567890abcdef1234567890abcdef")
        assert auth.user_id == "123"
        assert auth.username == "testuser"

        # Invalid cases
        with pytest.raises(ValidationError):
            AuthInput(user_id="", username="test", token="x" * 32)

        with pytest.raises(ValidationError):
            AuthInput(user_id="123", username="t", token="x" * 32)

    def test_user_registration(self):
        """Test user registration validation"""
        # Valid case
        user = UserRegistrationInput(email="test@example.com", password="SecurePass123!", name="Test User")
        assert user.email == "test@example.com"
        assert user.name == "Test User"

        # Test email validation
        with pytest.raises(ValidationError) as exc:
            UserRegistrationInput(email="invalid-email", password="SecurePass123!", name="Test")
        assert "Invalid email format" in str(exc.value)

        # Test password validation
        with pytest.raises(ValidationError) as exc:
            UserRegistrationInput(email="test@example.com", password="weak", name="Test")
        assert "String should have at least 8 characters" in str(exc.value)

    def test_token_payload(self):
        """Test token payload validation"""
        now = datetime.utcnow()
        exp = now + timedelta(hours=1)
        payload = TokenPayload(user_id="123", username="testuser", exp=exp, iat=now)
        assert payload.user_id == "123"
        assert payload.username == "testuser"
        assert payload.scopes == []  # Default empty list

    def test_auth_response(self):
        """Test auth response validation"""
        # Success case
        success_resp = AuthResponse(success=True, token="valid_token", user={"id": "123", "name": "Test"})
        assert success_resp.success is True
        assert success_resp.token == "valid_token"

        # Error case
        error_resp = AuthResponse(success=False, error="Invalid credentials")
        assert error_resp.success is False
        assert error_resp.error == "Invalid credentials"

        # Invalid case: the required token field is missing when success=True
        with pytest.raises(ValidationError):
            AuthResponse(success=True, user={"id": "123", "name": "Test"})