From 5d870358853e1830d03e26657ad5e0baa6443b53 Mon Sep 17 00:00:00 2001 From: Untone Date: Tue, 11 Feb 2025 12:00:35 +0300 Subject: [PATCH] 0.4.10-a --- CHANGELOG.md | 2 + auth/identity.py | 7 +- auth/tokenstorage.py | 5 +- auth/usermodel.py | 11 ++- auth/validations.py | 41 ++++++--- orm/draft.py | 24 +++--- orm/rbac.py | 176 -------------------------------------- orm/user.py | 105 ----------------------- resolvers/draft.py | 64 +++++++------- resolvers/editor.py | 13 ++- resolvers/feed.py | 7 +- resolvers/stat.py | 10 +-- resolvers/topic.py | 10 +-- schema/input.graphql | 52 ++++++++--- schema/mutation.graphql | 12 +-- schema/type.graphql | 31 ++++--- services/db.py | 12 ++- services/pretopic.py | 45 ++++------ services/schema.py | 10 +-- services/viewed.py | 7 +- settings.py | 2 +- tests/conftest.py | 12 ++- tests/test_drafts.py | 35 ++++---- tests/test_reactions.py | 28 +++--- tests/test_shouts.py | 26 +++--- tests/test_validations.py | 63 ++++---------- utils/logger.py | 25 ++++-- 27 files changed, 299 insertions(+), 536 deletions(-) delete mode 100644 orm/rbac.py delete mode 100644 orm/user.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 929a7f95..60e38b29 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,7 @@ #### [0.4.10] - 2025-02-10 - `add_author_stat_columns` fixed +- `Draft` orm and schema tuning and fixes +- `create_draft` and `update_draft` mutations and resolvers fixed #### [0.4.9] - 2025-02-09 diff --git a/auth/identity.py b/auth/identity.py index 3a096d9d..63dc4bc9 100644 --- a/auth/identity.py +++ b/auth/identity.py @@ -1,15 +1,16 @@ from binascii import hexlify from hashlib import sha256 -# from base.exceptions import InvalidPassword, InvalidToken -from services.db import local_session -from auth.exceptions import ExpiredToken, InvalidToken from passlib.hash import bcrypt +from auth.exceptions import ExpiredToken, InvalidToken from auth.jwtcodec import JWTCodec from auth.tokenstorage import TokenStorage from orm.user import User +# from base.exceptions import InvalidPassword, InvalidToken +from services.db import local_session + class Password: @staticmethod diff --git a/auth/tokenstorage.py b/auth/tokenstorage.py index 49fed14d..7e9fcaf8 100644 --- a/auth/tokenstorage.py +++ b/auth/tokenstorage.py @@ -1,9 +1,8 @@ from datetime import datetime, timedelta, timezone -from services.redis import redis -from auth.validations import AuthInput - from auth.jwtcodec import JWTCodec +from auth.validations import AuthInput +from services.redis import redis from settings import ONETIME_TOKEN_LIFE_SPAN, SESSION_TOKEN_LIFE_SPAN diff --git a/auth/usermodel.py b/auth/usermodel.py index 804e479c..8032543a 100644 --- a/auth/usermodel.py +++ b/auth/usermodel.py @@ -1,6 +1,15 @@ import time -from sqlalchemy import JSON, Boolean, Column, DateTime, ForeignKey, Integer, String, func +from sqlalchemy import ( + JSON, + Boolean, + Column, + DateTime, + ForeignKey, + Integer, + String, + func, +) from sqlalchemy.orm import relationship from services.db import Base diff --git a/auth/validations.py b/auth/validations.py index c4a7d253..f1b2a6a4 100644 --- a/auth/validations.py +++ b/auth/validations.py @@ -1,39 +1,44 @@ import re from datetime import datetime from typing import Dict, List, Optional, Union + from pydantic import BaseModel, Field, field_validator # RFC 5322 compliant email regex pattern EMAIL_PATTERN = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$" + class AuthInput(BaseModel): """Base model for authentication input validation""" + user_id: str = Field(description="Unique user identifier") username: str = Field(min_length=2, max_length=50) token: str = Field(min_length=32) - @field_validator('user_id') + @field_validator("user_id") @classmethod def validate_user_id(cls, v: str) -> str: if not v.strip(): raise ValueError("user_id cannot be empty") return v + class UserRegistrationInput(BaseModel): """Validation model for user registration""" + email: str = Field(max_length=254) # Max email length per RFC 5321 password: str = Field(min_length=8, max_length=100) name: str = Field(min_length=2, max_length=50) - - @field_validator('email') + + @field_validator("email") @classmethod def validate_email(cls, v: str) -> str: """Validate email format""" if not re.match(EMAIL_PATTERN, v): raise ValueError("Invalid email format") return v.lower() - - @field_validator('password') + + @field_validator("password") @classmethod def validate_password_strength(cls, v: str) -> str: """Validate password meets security requirements""" @@ -47,57 +52,65 @@ class UserRegistrationInput(BaseModel): raise ValueError("Password must contain at least one special character") return v + class UserLoginInput(BaseModel): """Validation model for user login""" + email: str = Field(max_length=254) password: str = Field(min_length=8, max_length=100) - @field_validator('email') + @field_validator("email") @classmethod def validate_email(cls, v: str) -> str: if not re.match(EMAIL_PATTERN, v): raise ValueError("Invalid email format") return v.lower() + class TokenPayload(BaseModel): """Validation model for JWT token payload""" + user_id: str username: str exp: datetime iat: datetime scopes: Optional[List[str]] = [] + class OAuthInput(BaseModel): """Validation model for OAuth input""" - provider: str = Field(pattern='^(google|github|facebook)$') + + provider: str = Field(pattern="^(google|github|facebook)$") code: str redirect_uri: Optional[str] = None - @field_validator('provider') + @field_validator("provider") @classmethod def validate_provider(cls, v: str) -> str: - valid_providers = ['google', 'github', 'facebook'] + valid_providers = ["google", "github", "facebook"] if v.lower() not in valid_providers: raise ValueError(f"Provider must be one of: {', '.join(valid_providers)}") return v.lower() + class AuthResponse(BaseModel): """Validation model for authentication responses""" + success: bool token: Optional[str] = None error: Optional[str] = None user: Optional[Dict[str, Union[str, int, bool]]] = None - @field_validator('error') + @field_validator("error") @classmethod def validate_error_if_not_success(cls, v: Optional[str], info) -> Optional[str]: - if not info.data.get('success') and not v: + if not info.data.get("success") and not v: raise ValueError("Error message required when success is False") return v - @field_validator('token') + @field_validator("token") @classmethod def validate_token_if_success(cls, v: Optional[str], info) -> Optional[str]: - if info.data.get('success') and not v: + if info.data.get("success") and not v: raise ValueError("Token required when success is True") - return v \ No newline at end of file + return v diff --git a/orm/draft.py b/orm/draft.py index 491c9319..c29794c5 100644 --- a/orm/draft.py +++ b/orm/draft.py @@ -28,27 +28,27 @@ class DraftAuthor(Base): class Draft(Base): __tablename__ = "draft" - + # required created_at: int = Column(Integer, nullable=False, default=lambda: int(time.time())) - updated_at: int | None = Column(Integer, nullable=True, index=True) - deleted_at: int | None = Column(Integer, nullable=True, index=True) + created_by: int = Column(ForeignKey("author.id"), nullable=False) - body: str = Column(String, nullable=False, comment="Body") + # optional + layout: str = Column(String, nullable=True, default="article") slug: str = Column(String, unique=True) - cover: str | None = Column(String, nullable=True, comment="Cover image url") - cover_caption: str | None = Column(String, nullable=True, comment="Cover image alt caption") + title: str = Column(String, nullable=True) + subtitle: str | None = Column(String, nullable=True) lead: str | None = Column(String, nullable=True) description: str | None = Column(String, nullable=True) - title: str = Column(String, nullable=False) - subtitle: str | None = Column(String, nullable=True) - layout: str = Column(String, nullable=False, default="article") + body: str = Column(String, nullable=False, comment="Body") media: dict | None = Column(JSON, nullable=True) - + cover: str | None = Column(String, nullable=True, comment="Cover image url") + cover_caption: str | None = Column(String, nullable=True, comment="Cover image alt caption") lang: str = Column(String, nullable=False, default="ru", comment="Language") - oid: str | None = Column(String, nullable=True) seo: str | None = Column(String, nullable=True) # JSON - created_by: int = Column(ForeignKey("author.id"), nullable=False) + # auto + updated_at: int | None = Column(Integer, nullable=True, index=True) + deleted_at: int | None = Column(Integer, nullable=True, index=True) updated_by: int | None = Column(ForeignKey("author.id"), nullable=True) deleted_by: int | None = Column(ForeignKey("author.id"), nullable=True) authors = relationship(Author, secondary="draft_author") diff --git a/orm/rbac.py b/orm/rbac.py deleted file mode 100644 index be22701b..00000000 --- a/orm/rbac.py +++ /dev/null @@ -1,176 +0,0 @@ -from services.db import REGISTRY, Base, local_session -from utils.logger import root_logger as logger - -from sqlalchemy.types import TypeDecorator -from sqlalchemy.types import String -from sqlalchemy import Column, ForeignKey, String, UniqueConstraint -from sqlalchemy.orm import relationship - -class ClassType(TypeDecorator): - impl = String - - @property - def python_type(self): - return NotImplemented - - def process_literal_param(self, value, dialect): - return NotImplemented - - def process_bind_param(self, value, dialect): - return value.__name__ if isinstance(value, type) else str(value) - - def process_result_value(self, value, dialect): - class_ = REGISTRY.get(value) - if class_ is None: - logger.warn(f"Can't find class <{value}>,find it yourself!", stacklevel=2) - return class_ - - -class Role(Base): - __tablename__ = "role" - - name = Column(String, nullable=False, comment="Role Name") - desc = Column(String, nullable=True, comment="Role Description") - community = Column( - ForeignKey("community.id", ondelete="CASCADE"), - nullable=False, - comment="Community", - ) - permissions = relationship(lambda: Permission) - - @staticmethod - def init_table(): - with local_session() as session: - r = session.query(Role).filter(Role.name == "author").first() - if r: - Role.default_role = r - return - - r1 = Role.create( - name="author", - desc="Role for an author", - community=1, - ) - - session.add(r1) - - Role.default_role = r1 - - r2 = Role.create( - name="reader", - desc="Role for a reader", - community=1, - ) - - session.add(r2) - - r3 = Role.create( - name="expert", - desc="Role for an expert", - community=1, - ) - - session.add(r3) - - r4 = Role.create( - name="editor", - desc="Role for an editor", - community=1, - ) - - session.add(r4) - - -class Operation(Base): - __tablename__ = "operation" - name = Column(String, nullable=False, unique=True, comment="Operation Name") - - @staticmethod - def init_table(): - with local_session() as session: - for name in ["create", "update", "delete", "load"]: - """ - * everyone can: - - load shouts - - load topics - - load reactions - - create an account to become a READER - * readers can: - - update and delete their account - - load chats - - load messages - - create reaction of some shout's author allowed kinds - - create shout to become an AUTHOR - * authors can: - - update and delete their shout - - invite other authors to edit shout and chat - - manage allowed reactions for their shout - * pros can: - - create/update/delete their community - - create/update/delete topics for their community - - """ - op = session.query(Operation).filter(Operation.name == name).first() - if not op: - op = Operation.create(name=name) - session.add(op) - session.commit() - - -class Resource(Base): - __tablename__ = "resource" - resourceClass = Column(String, nullable=False, unique=True, comment="Resource class") - name = Column(String, nullable=False, unique=True, comment="Resource name") - # TODO: community = Column(ForeignKey()) - - @staticmethod - def init_table(): - with local_session() as session: - for res in [ - "shout", - "topic", - "reaction", - "chat", - "message", - "invite", - "community", - "user", - ]: - r = session.query(Resource).filter(Resource.name == res).first() - if not r: - r = Resource.create(name=res, resourceClass=res) - session.add(r) - session.commit() - - -class Permission(Base): - __tablename__ = "permission" - __table_args__ = ( - UniqueConstraint("role", "operation", "resource"), - {"extend_existing": True}, - ) - - role: Column = Column(ForeignKey("role.id", ondelete="CASCADE"), nullable=False, comment="Role") - operation: Column = Column( - ForeignKey("operation.id", ondelete="CASCADE"), - nullable=False, - comment="Operation", - ) - resource: Column = Column( - ForeignKey("resource.id", ondelete="CASCADE"), - nullable=False, - comment="Resource", - ) - - -# if __name__ == "__main__": -# Base.metadata.create_all(engine) -# ops = [ -# Permission(role=1, operation=1, resource=1), -# Permission(role=1, operation=2, resource=1), -# Permission(role=1, operation=3, resource=1), -# Permission(role=1, operation=4, resource=1), -# Permission(role=2, operation=4, resource=1), -# ] -# global_session.add_all(ops) -# global_session.commit() \ No newline at end of file diff --git a/orm/user.py b/orm/user.py deleted file mode 100644 index c31c7913..00000000 --- a/orm/user.py +++ /dev/null @@ -1,105 +0,0 @@ -from sqlalchemy import JSON as JSONType -from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, func -from sqlalchemy.orm import relationship - -from services.db import Base, local_session -from orm.rbac import Role - - -class UserRating(Base): - __tablename__ = "user_rating" - - id = None - rater: Column = Column(ForeignKey("user.id"), primary_key=True, index=True) - user: Column = Column(ForeignKey("user.id"), primary_key=True, index=True) - value: Column = Column(Integer) - - @staticmethod - def init_table(): - pass - - -class UserRole(Base): - __tablename__ = "user_role" - - id = None - user = Column(ForeignKey("user.id"), primary_key=True, index=True) - role = Column(ForeignKey("role.id"), primary_key=True, index=True) - - -class AuthorFollower(Base): - __tablename__ = "author_follower" - - id = None - follower: Column = Column(ForeignKey("user.id"), primary_key=True, index=True) - author: Column = Column(ForeignKey("user.id"), primary_key=True, index=True) - createdAt = Column( - DateTime(timezone=True), nullable=False, server_default=func.now(), comment="Created at" - ) - auto = Column(Boolean, nullable=False, default=False) - - -class User(Base): - __tablename__ = "user" - default_user = None - - email = Column(String, unique=True, nullable=False, comment="Email") - username = Column(String, nullable=False, comment="Login") - password = Column(String, nullable=True, comment="Password") - bio = Column(String, nullable=True, comment="Bio") # status description - about = Column(String, nullable=True, comment="About") # long and formatted - userpic = Column(String, nullable=True, comment="Userpic") - name = Column(String, nullable=True, comment="Display name") - slug = Column(String, unique=True, comment="User's slug") - muted = Column(Boolean, default=False) - emailConfirmed = Column(Boolean, default=False) - createdAt = Column( - DateTime(timezone=True), nullable=False, server_default=func.now(), comment="Created at" - ) - lastSeen = Column( - DateTime(timezone=True), nullable=False, server_default=func.now(), comment="Was online at" - ) - deletedAt = Column(DateTime(timezone=True), nullable=True, comment="Deleted at") - links = Column(JSONType, nullable=True, comment="Links") - oauth = Column(String, nullable=True) - ratings = relationship(UserRating, foreign_keys=UserRating.user) - roles = relationship(lambda: Role, secondary=UserRole.__tablename__) - oid = Column(String, nullable=True) - - @staticmethod - def init_table(): - with local_session() as session: - default = session.query(User).filter(User.slug == "anonymous").first() - if not default: - default_dict = { - "email": "noreply@discours.io", - "username": "noreply@discours.io", - "name": "Аноним", - "slug": "anonymous", - } - default = User.create(**default_dict) - session.add(default) - discours_dict = { - "email": "welcome@discours.io", - "username": "welcome@discours.io", - "name": "Дискурс", - "slug": "discours", - } - discours = User.create(**discours_dict) - session.add(discours) - session.commit() - User.default_user = default - - def get_permission(self): - scope = {} - for role in self.roles: - for p in role.permissions: - if p.resource not in scope: - scope[p.resource] = set() - scope[p.resource].add(p.operation) - print(scope) - return scope - - -# if __name__ == "__main__": -# print(User.get_permission(user_id=1)) \ No newline at end of file diff --git a/resolvers/draft.py b/resolvers/draft.py index 1de987b3..523de3d3 100644 --- a/resolvers/draft.py +++ b/resolvers/draft.py @@ -1,22 +1,25 @@ import time -from orm.topic import Topic + from sqlalchemy import select from sqlalchemy.sql import and_ from cache.cache import ( - cache_author, cache_by_id, cache_topic, - invalidate_shout_related_cache, invalidate_shouts_cache + cache_author, + cache_by_id, + cache_topic, + invalidate_shout_related_cache, + invalidate_shouts_cache, ) from orm.author import Author from orm.draft import Draft from orm.shout import Shout, ShoutAuthor, ShoutTopic +from orm.topic import Topic from services.auth import login_required from services.db import local_session -from services.schema import mutation, query -from utils.logger import root_logger as logger from services.notify import notify_shout +from services.schema import mutation, query from services.search import search_service - +from utils.logger import root_logger as logger def create_shout_from_draft(session, draft, author_id): @@ -59,16 +62,19 @@ async def load_drafts(_, info): @mutation.field("create_draft") @login_required -async def create_draft(_, info, shout_id: int = 0): +async def create_draft(_, info, draft_input): user_id = info.context.get("user_id") author_dict = info.context.get("author", {}) author_id = author_dict.get("id") + draft_id = draft_input.get("id") + if not draft_id: + return {"error": "Draft ID is required"} if not user_id or not author_id: - return {"error": "User ID and author ID are required"} + return {"error": "Author ID are required"} with local_session() as session: - draft = Draft(created_by=author_id) + draft = Draft(created_by=author_id, **draft_input) session.add(draft) session.commit() return {"draft": draft} @@ -81,11 +87,14 @@ async def update_draft(_, info, draft_input): author_dict = info.context.get("author", {}) author_id = author_dict.get("id") draft_id = draft_input.get("id") + if not draft_id: + return {"error": "Draft ID is required"} if not user_id or not author_id: - return {"error": "User ID and author ID are required"} + return {"error": "Author ID are required"} with local_session() as session: draft = session.query(Draft).filter(Draft.id == draft_id).first() + del draft_input["id"] Draft.update(draft, {**draft_input}) if not draft: return {"error": "Draft not found"} @@ -129,7 +138,7 @@ async def publish_draft(_, info, draft_id: int): shout = create_shout_from_draft(session, draft, author_id) session.add(shout) session.commit() - return {"shout": shout} + return {"shout": shout, "draft": draft} @mutation.field("unpublish_draft") @@ -149,15 +158,15 @@ async def unpublish_draft(_, info, draft_id: int): if shout: shout.published_at = None session.commit() - return {"shout": shout} + return {"shout": shout, "draft": draft} return {"error": "Failed to unpublish draft"} @mutation.field("publish_shout") @login_required -async def publish_shout(_, info, shout_id: int, draft=None): +async def publish_shout(_, info, shout_id: int): """Publish draft as a shout or update existing shout. - + Args: shout_id: ID существующей публикации или 0 для новой draft: Объект черновика (опционально) @@ -205,11 +214,13 @@ async def publish_shout(_, info, shout_id: int, draft=None): # или публикация была ранее снята с публикации if not was_published: shout.published_at = now - + # Обрабатываем связи с авторами - if not session.query(ShoutAuthor).filter( - and_(ShoutAuthor.shout == shout.id, ShoutAuthor.author == author_id) - ).first(): + if ( + not session.query(ShoutAuthor) + .filter(and_(ShoutAuthor.shout == shout.id, ShoutAuthor.author == author_id)) + .first() + ): sa = ShoutAuthor(shout=shout.id, author=author_id) session.add(sa) @@ -217,9 +228,7 @@ async def publish_shout(_, info, shout_id: int, draft=None): if draft.topics: for topic in draft.topics: st = ShoutTopic( - topic=topic.id, - shout=shout.id, - main=topic.main if hasattr(topic, 'main') else False + topic=topic.id, shout=shout.id, main=topic.main if hasattr(topic, "main") else False ) session.add(st) @@ -229,13 +238,8 @@ async def publish_shout(_, info, shout_id: int, draft=None): # Инвалидируем кэш только если это новая публикация или была снята с публикации if not was_published: - cache_keys = [ - "feed", - f"author_{author_id}", - "random_top", - "unrated" - ] - + cache_keys = ["feed", f"author_{author_id}", "random_top", "unrated"] + # Добавляем ключи для тем for topic in shout.topics: cache_keys.append(f"topic_{topic.id}") @@ -264,7 +268,7 @@ async def publish_shout(_, info, shout_id: int, draft=None): except Exception as e: logger.error(f"Failed to publish shout: {e}", exc_info=True) - if 'session' in locals(): + if "session" in locals(): session.rollback() return {"error": f"Failed to publish shout: {str(e)}"} @@ -299,5 +303,3 @@ async def unpublish_shout(_, info, shout_id: int): return {"error": "Failed to unpublish shout"} return {"shout": shout} - - diff --git a/resolvers/editor.py b/resolvers/editor.py index 17774372..1175dbe2 100644 --- a/resolvers/editor.py +++ b/resolvers/editor.py @@ -5,7 +5,12 @@ from sqlalchemy import and_, desc, select from sqlalchemy.orm import joinedload from sqlalchemy.sql.functions import coalesce -from cache.cache import cache_author, cache_topic, invalidate_shout_related_cache, invalidate_shouts_cache +from cache.cache import ( + cache_author, + cache_topic, + invalidate_shout_related_cache, + invalidate_shouts_cache, +) from orm.author import Author from orm.draft import Draft from orm.shout import Shout, ShoutAuthor, ShoutTopic @@ -114,11 +119,11 @@ async def get_my_shout(_, info, shout_id: int): logger.debug(f"got {len(shout.authors)} shout authors, created by {shout.created_by}") is_editor = "editor" in roles - logger.debug(f'viewer is{'' if is_editor else ' not'} editor') + logger.debug(f"viewer is{'' if is_editor else ' not'} editor") is_creator = author_id == shout.created_by - logger.debug(f'viewer is{'' if is_creator else ' not'} creator') + logger.debug(f"viewer is{'' if is_creator else ' not'} creator") is_author = bool(list(filter(lambda x: x.id == int(author_id), [x for x in shout.authors]))) - logger.debug(f'viewer is{'' if is_creator else ' not'} author') + logger.debug(f"viewer is{'' if is_creator else ' not'} author") can_edit = is_editor or is_author or is_creator if not can_edit: diff --git a/resolvers/feed.py b/resolvers/feed.py index 831411f8..b745038f 100644 --- a/resolvers/feed.py +++ b/resolvers/feed.py @@ -5,7 +5,12 @@ from sqlalchemy import and_, select from orm.author import Author, AuthorFollower from orm.shout import Shout, ShoutAuthor, ShoutReactionsFollower, ShoutTopic from orm.topic import Topic, TopicFollower -from resolvers.reader import apply_options, get_shouts_with_links, has_field, query_with_stat +from resolvers.reader import ( + apply_options, + get_shouts_with_links, + has_field, + query_with_stat, +) from services.auth import login_required from services.db import local_session from services.schema import query diff --git a/resolvers/stat.py b/resolvers/stat.py index ad9132c0..85ad69b3 100644 --- a/resolvers/stat.py +++ b/resolvers/stat.py @@ -67,10 +67,7 @@ def add_author_stat_columns(q): shouts_subq = ( select(func.count(distinct(Shout.id))) .select_from(ShoutAuthor) - .join(Shout, and_( - Shout.id == ShoutAuthor.shout, - Shout.deleted_at.is_(None) - )) + .join(Shout, and_(Shout.id == ShoutAuthor.shout, Shout.deleted_at.is_(None))) .where(ShoutAuthor.author == Author.id) .scalar_subquery() ) @@ -85,10 +82,7 @@ def add_author_stat_columns(q): # Основной запрос q = ( q.select_from(Author) - .add_columns( - shouts_subq.label("shouts_stat"), - followers_subq.label("followers_stat") - ) + .add_columns(shouts_subq.label("shouts_stat"), followers_subq.label("followers_stat")) .group_by(Author.id) ) diff --git a/resolvers/topic.py b/resolvers/topic.py index 8a3b9036..d7460c36 100644 --- a/resolvers/topic.py +++ b/resolvers/topic.py @@ -66,11 +66,11 @@ async def get_topic(_, _info, slug: str): # Мутация для создания новой темы @mutation.field("create_topic") @login_required -async def create_topic(_, _info, inp): +async def create_topic(_, _info, topic_input): with local_session() as session: # TODO: проверить права пользователя на создание темы для конкретного сообщества # и разрешение на создание - new_topic = Topic(**inp) + new_topic = Topic(**topic_input) session.add(new_topic) session.commit() @@ -80,14 +80,14 @@ async def create_topic(_, _info, inp): # Мутация для обновления темы @mutation.field("update_topic") @login_required -async def update_topic(_, _info, inp): - slug = inp["slug"] +async def update_topic(_, _info, topic_input): + slug = topic_input["slug"] with local_session() as session: topic = session.query(Topic).filter(Topic.slug == slug).first() if not topic: return {"error": "topic not found"} else: - Topic.update(topic, inp) + Topic.update(topic, topic_input) session.add(topic) session.commit() diff --git a/schema/input.graphql b/schema/input.graphql index 07367fc2..ff3fa4dd 100644 --- a/schema/input.graphql +++ b/schema/input.graphql @@ -1,15 +1,47 @@ -input DraftInput { - slug: String +input MediaItemInput { + url: String title: String body: String + source: String + pic: String + date: String + genre: String + artist: String + lyrics: String +} + +input AuthorInput { + id: Int! + slug: String +} + +input TopicInput { + id: Int + slug: String! + title: String + body: String + pic: String +} + +input DraftInput { + id: Int + # no created_at, updated_at, deleted_at, updated_by, deleted_by + layout: String + shout_id: Int # Changed from shout: Shout + author_ids: [Int!] # Changed from authors: [Author] + topic_ids: [Int!] # Changed from topics: [Topic] + main_topic_id: Int # Changed from main_topic: Topic + media: [MediaItemInput] # Changed to use MediaItemInput lead: String description: String - layout: String - media: String - topics: [TopicInput] - community: Int subtitle: String + lang: String + seo: String + body: String + title: String + slug: String cover: String + cover_caption: String } input ProfileInput { @@ -21,14 +53,6 @@ input ProfileInput { about: String } -input TopicInput { - id: Int - slug: String! - title: String - body: String - pic: String -} - input ReactionInput { id: Int kind: ReactionKind! diff --git a/schema/mutation.graphql b/schema/mutation.graphql index df2074a5..c5f48fde 100644 --- a/schema/mutation.graphql +++ b/schema/mutation.graphql @@ -4,8 +4,8 @@ type Mutation { update_author(profile: ProfileInput!): CommonResult! # draft - create_draft(input: DraftInput!): CommonResult! - update_draft(draft_id: Int!, input: DraftInput!): CommonResult! + create_draft(draft_input: DraftInput!): CommonResult! + update_draft(draft_id: Int!, draft_input: DraftInput!): CommonResult! delete_draft(draft_id: Int!): CommonResult! # publication publish_shout(shout_id: Int!): CommonResult! @@ -18,8 +18,8 @@ type Mutation { unfollow(what: FollowingEntity!, slug: String!): AuthorFollowsResult! # topic - create_topic(input: TopicInput!): CommonResult! - update_topic(input: TopicInput!): CommonResult! + create_topic(topic_input: TopicInput!): CommonResult! + update_topic(topic_input: TopicInput!): CommonResult! delete_topic(slug: String!): CommonResult! # reaction @@ -45,7 +45,7 @@ type Mutation { # community join_community(slug: String!): CommonResult! leave_community(slug: String!): CommonResult! - create_community(input: CommunityInput!): CommonResult! - update_community(input: CommunityInput!): CommonResult! + create_community(community_input: CommunityInput!): CommonResult! + update_community(community_input: CommunityInput!): CommonResult! delete_community(slug: String!): CommonResult! } diff --git a/schema/type.graphql b/schema/type.graphql index a9941a62..7a8344da 100644 --- a/schema/type.graphql +++ b/schema/type.graphql @@ -108,27 +108,30 @@ type Shout { type Draft { id: Int! - shout: Shout created_at: Int! + created_by: Author! + + layout: String + slug: String + title: String + subtitle: String + lead: String + description: String + body: String + media: [MediaItem] + cover: String + cover_caption: String + lang: String + seo: String + + # auto updated_at: Int deleted_at: Int - created_by: Author! updated_by: Author deleted_by: Author authors: [Author] topics: [Topic] - media: [MediaItem] - lead: String - description: String - subtitle: String - layout: String - lang: String - seo: String - body: String - title: String - slug: String - cover: String - cover_caption: String + } type Stat { diff --git a/services/db.py b/services/db.py index 394de75d..bd3072e4 100644 --- a/services/db.py +++ b/services/db.py @@ -6,7 +6,17 @@ import warnings from typing import Any, Callable, Dict, TypeVar import sqlalchemy -from sqlalchemy import JSON, Column, Engine, Integer, create_engine, event, exc, func, inspect +from sqlalchemy import ( + JSON, + Column, + Engine, + Integer, + create_engine, + event, + exc, + func, + inspect, +) from sqlalchemy.orm import Session, configure_mappers, declarative_base from sqlalchemy.sql.schema import Table diff --git a/services/pretopic.py b/services/pretopic.py index b6cd60c8..87e10c2d 100644 --- a/services/pretopic.py +++ b/services/pretopic.py @@ -1,8 +1,11 @@ import concurrent.futures -from typing import Dict, Tuple, List +from typing import Dict, List, Tuple + from txtai.embeddings import Embeddings + from services.logger import root_logger as logger + class TopicClassifier: def __init__(self, shouts_by_topic: Dict[str, str], publications: List[Dict[str, str]]): """ @@ -32,27 +35,21 @@ class TopicClassifier: Подготавливает векторные представления для тем и поиска. """ logger.info("Начинается подготовка векторных представлений...") - + # Модель для русского языка # TODO: model local caching model_path = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2" - + # Инициализируем embeddings для классификации тем self.topic_embeddings = Embeddings(path=model_path) - topic_documents = [ - (topic, text) - for topic, text in self.shouts_by_topic.items() - ] + topic_documents = [(topic, text) for topic, text in self.shouts_by_topic.items()] self.topic_embeddings.index(topic_documents) - + # Инициализируем embeddings для поиска публикаций self.search_embeddings = Embeddings(path=model_path) - search_documents = [ - (str(pub['id']), f"{pub['title']} {pub['text']}") - for pub in self.publications - ] + search_documents = [(str(pub["id"]), f"{pub['title']} {pub['text']}") for pub in self.publications] self.search_embeddings.index(search_documents) - + logger.info("Подготовка векторных представлений завершена.") def predict_topic(self, text: str) -> Tuple[float, str]: @@ -66,13 +63,13 @@ class TopicClassifier: if not self.is_ready(): logger.error("Векторные представления не готовы. Вызовите initialize() и дождитесь завершения.") return 0.0, "unknown" - + try: # Ищем наиболее похожую тему results = self.topic_embeddings.search(text, 1) if not results: return 0.0, "unknown" - + score, topic = results[0] return float(score), topic @@ -92,25 +89,19 @@ class TopicClassifier: if not self.is_ready(): logger.error("Векторные представления не готовы. Вызовите initialize() и дождитесь завершения.") return [] - + try: # Ищем похожие публикации results = self.search_embeddings.search(query, limit) - + # Формируем результаты found_publications = [] for score, pub_id in results: # Находим публикацию по id - publication = next( - (pub for pub in self.publications if str(pub['id']) == pub_id), - None - ) + publication = next((pub for pub in self.publications if str(pub["id"]) == pub_id), None) if publication: - found_publications.append({ - **publication, - 'relevance': float(score) - }) - + found_publications.append({**publication, "relevance": float(score)}) + return found_publications except Exception as e: @@ -137,6 +128,7 @@ class TopicClassifier: if self._executor: self._executor.shutdown(wait=False) + # Пример использования: """ shouts_by_topic = { @@ -176,4 +168,3 @@ for pub in similar_publications: print(f"Заголовок: {pub['title']}") print(f"Текст: {pub['text'][:100]}...") """ - diff --git a/services/schema.py b/services/schema.py index dca8a467..06a30261 100644 --- a/services/schema.py +++ b/services/schema.py @@ -43,7 +43,6 @@ async def request_graphql_data(gql, url=AUTH_URL, headers=None): return None - def create_all_tables(): """Create all database tables in the correct order.""" from orm import author, community, draft, notification, reaction, shout, topic, user @@ -54,26 +53,21 @@ def create_all_tables(): author.Author, # Базовая таблица community.Community, # Базовая таблица topic.Topic, # Базовая таблица - # Связи для базовых таблиц author.AuthorFollower, # Зависит от Author community.CommunityFollower, # Зависит от Community topic.TopicFollower, # Зависит от Topic - # Черновики (теперь без зависимости от Shout) draft.Draft, # Зависит только от Author draft.DraftAuthor, # Зависит от Draft и Author draft.DraftTopic, # Зависит от Draft и Topic - # Основные таблицы контента shout.Shout, # Зависит от Author и Draft shout.ShoutAuthor, # Зависит от Shout и Author shout.ShoutTopic, # Зависит от Shout и Topic - # Реакции reaction.Reaction, # Зависит от Author и Shout shout.ShoutReactionsFollower, # Зависит от Shout и Reaction - # Дополнительные таблицы author.AuthorRating, # Зависит от Author notification.Notification, # Зависит от Author @@ -87,7 +81,7 @@ def create_all_tables(): for model in models_in_order: try: create_table_if_not_exists(session.get_bind(), model) - logger.info(f"Created or verified table: {model.__tablename__}") + # logger.info(f"Created or verified table: {model.__tablename__}") except Exception as e: logger.error(f"Error creating table {model.__tablename__}: {e}") - raise \ No newline at end of file + raise diff --git a/services/viewed.py b/services/viewed.py index 424bbb04..f1942de0 100644 --- a/services/viewed.py +++ b/services/viewed.py @@ -7,7 +7,12 @@ from typing import Dict # ga from google.analytics.data_v1beta import BetaAnalyticsDataClient -from google.analytics.data_v1beta.types import DateRange, Dimension, Metric, RunReportRequest +from google.analytics.data_v1beta.types import ( + DateRange, + Dimension, + Metric, + RunReportRequest, +) from google.analytics.data_v1beta.types import Filter as GAFilter from orm.author import Author diff --git a/settings.py b/settings.py index 2760140b..5567e60e 100644 --- a/settings.py +++ b/settings.py @@ -20,4 +20,4 @@ WEBHOOK_SECRET = environ.get("WEBHOOK_SECRET") or "nothing-else" ONETIME_TOKEN_LIFE_SPAN = 60 * 60 * 24 * 3 # 3 days SESSION_TOKEN_LIFE_SPAN = 60 * 60 * 24 * 30 # 30 days JWT_ALGORITHM = "HS256" -JWT_SECRET_KEY = environ.get("JWT_SECRET") or "nothing-else-jwt-secret-matters" \ No newline at end of file +JWT_SECRET_KEY = environ.get("JWT_SECRET") or "nothing-else-jwt-secret-matters" diff --git a/tests/conftest.py b/tests/conftest.py index b45ef526..7bd7f135 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,6 @@ import asyncio import os + import pytest from sqlalchemy import create_engine from sqlalchemy.orm import Session @@ -13,6 +14,7 @@ from settings import DB_URL # Use SQLite for testing TEST_DB_URL = "sqlite:///test.db" + @pytest.fixture(scope="session") def event_loop(): """Create an instance of the default event loop for the test session.""" @@ -20,6 +22,7 @@ def event_loop(): yield loop loop.close() + @pytest.fixture(scope="session") def test_engine(): """Create a test database engine.""" @@ -29,19 +32,21 @@ def test_engine(): Base.metadata.drop_all(engine) os.remove("test.db") + @pytest.fixture def db_session(test_engine): """Create a new database session for a test.""" connection = test_engine.connect() transaction = connection.begin() session = Session(bind=connection) - + yield session - + session.close() transaction.rollback() connection.close() + @pytest.fixture async def redis_client(): """Create a test Redis client.""" @@ -49,7 +54,8 @@ async def redis_client(): yield redis await redis.disconnect() + @pytest.fixture def test_client(): """Create a TestClient instance.""" - return TestClient(app) \ No newline at end of file + return TestClient(app) diff --git a/tests/test_drafts.py b/tests/test_drafts.py index 5cd44a76..a5f2b75e 100644 --- a/tests/test_drafts.py +++ b/tests/test_drafts.py @@ -1,19 +1,18 @@ import pytest -from orm.shout import Shout + from orm.author import Author +from orm.shout import Shout + @pytest.fixture def test_author(db_session): """Create a test author.""" - author = Author( - name="Test Author", - slug="test-author", - user="test-user-id" - ) + author = Author(name="Test Author", slug="test-author", user="test-user-id") db_session.add(author) db_session.commit() return author + @pytest.fixture def test_shout(db_session): """Create test shout with required fields.""" @@ -27,12 +26,13 @@ def test_shout(db_session): created_by=author.id, # Обязательное поле body="Test body", layout="article", - lang="ru" + lang="ru", ) db_session.add(shout) db_session.commit() return shout + @pytest.mark.asyncio async def test_create_shout(test_client, db_session, test_author): """Test creating a new shout.""" @@ -40,8 +40,8 @@ async def test_create_shout(test_client, db_session, test_author): "/", json={ "query": """ - mutation CreateDraft($input: DraftInput!) { - create_draft(input: $input) { + mutation CreateDraft($draft_input: DraftInput!) { + create_draft(draft_input: $draft_input) { error draft { id @@ -56,15 +56,16 @@ async def test_create_shout(test_client, db_session, test_author): "title": "Test Shout", "body": "This is a test shout", } - } - } + }, + }, ) - + assert response.status_code == 200 data = response.json() assert "errors" not in data assert data["data"]["create_draft"]["draft"]["title"] == "Test Shout" + @pytest.mark.asyncio async def test_load_drafts(test_client, db_session): """Test retrieving a shout.""" @@ -83,13 +84,11 @@ async def test_load_drafts(test_client, db_session): } } """, - "variables": { - "slug": "test-shout" - } - } + "variables": {"slug": "test-shout"}, + }, ) - + assert response.status_code == 200 data = response.json() assert "errors" not in data - assert data["data"]["load_drafts"]["drafts"] == [] \ No newline at end of file + assert data["data"]["load_drafts"]["drafts"] == [] diff --git a/tests/test_reactions.py b/tests/test_reactions.py index 0a37a20d..9b73e001 100644 --- a/tests/test_reactions.py +++ b/tests/test_reactions.py @@ -1,8 +1,11 @@ +from datetime import datetime + import pytest + +from orm.author import Author from orm.reaction import Reaction, ReactionKind from orm.shout import Shout -from orm.author import Author -from datetime import datetime + @pytest.fixture def test_setup(db_session): @@ -11,9 +14,9 @@ def test_setup(db_session): author = Author(name="Test Author", slug="test-author", user="test-user-id") db_session.add(author) db_session.flush() - + shout = Shout( - title="Test Shout", + title="Test Shout", slug="test-shout", created_by=author.id, body="This is a test shout", @@ -21,12 +24,13 @@ def test_setup(db_session): lang="ru", community=1, created_at=now, - updated_at=now + updated_at=now, ) db_session.add_all([author, shout]) db_session.commit() return {"author": author, "shout": shout} + @pytest.mark.asyncio async def test_create_reaction(test_client, db_session, test_setup): """Test creating a reaction on a shout.""" @@ -49,16 +53,12 @@ async def test_create_reaction(test_client, db_session, test_setup): } """, "variables": { - "reaction": { - "shout": test_setup["shout"].id, - "kind": ReactionKind.LIKE.value, - "body": "Great post!" - } - } - } + "reaction": {"shout": test_setup["shout"].id, "kind": ReactionKind.LIKE.value, "body": "Great post!"} + }, + }, ) - + assert response.status_code == 200 data = response.json() assert "error" not in data - assert data["data"]["create_reaction"]["reaction"]["kind"] == ReactionKind.LIKE.value \ No newline at end of file + assert data["data"]["create_reaction"]["reaction"]["kind"] == ReactionKind.LIKE.value diff --git a/tests/test_shouts.py b/tests/test_shouts.py index 66f3b1e9..8544b4d2 100644 --- a/tests/test_shouts.py +++ b/tests/test_shouts.py @@ -1,7 +1,10 @@ +from datetime import datetime + import pytest + from orm.author import Author from orm.shout import Shout -from datetime import datetime + @pytest.fixture def test_shout(db_session): @@ -10,7 +13,7 @@ def test_shout(db_session): author = Author(name="Test Author", slug="test-author", user="test-user-id") db_session.add(author) db_session.flush() - + now = int(datetime.now().timestamp()) shout = Shout( @@ -22,12 +25,13 @@ def test_shout(db_session): lang="ru", community=1, created_at=now, - updated_at=now + updated_at=now, ) db_session.add(shout) db_session.commit() return shout + @pytest.mark.asyncio async def test_get_shout(test_client, db_session): """Test retrieving a shout.""" @@ -36,7 +40,7 @@ async def test_get_shout(test_client, db_session): db_session.add(author) db_session.flush() now = int(datetime.now().timestamp()) - + # Создаем публикацию со всеми обязательными полями shout = Shout( title="Test Shout", @@ -47,11 +51,11 @@ async def test_get_shout(test_client, db_session): lang="ru", community=1, created_at=now, - updated_at=now + updated_at=now, ) db_session.add(shout) db_session.commit() - + response = test_client.post( "/", json={ @@ -71,13 +75,11 @@ async def test_get_shout(test_client, db_session): } } """, - "variables": { - "slug": "test-shout" - } - } + "variables": {"slug": "test-shout"}, + }, ) - + data = response.json() assert response.status_code == 200 assert "errors" not in data - assert data["data"]["get_shout"]["title"] == "Test Shout" \ No newline at end of file + assert data["data"]["get_shout"]["title"] == "Test Shout" diff --git a/tests/test_validations.py b/tests/test_validations.py index 2cffae7a..39fa7e24 100644 --- a/tests/test_validations.py +++ b/tests/test_validations.py @@ -1,75 +1,56 @@ -import pytest from datetime import datetime, timedelta + +import pytest from pydantic import ValidationError from auth.validations import ( AuthInput, - UserRegistrationInput, - UserLoginInput, - TokenPayload, + AuthResponse, OAuthInput, - AuthResponse + TokenPayload, + UserLoginInput, + UserRegistrationInput, ) + class TestAuthValidations: def test_auth_input(self): """Test basic auth input validation""" # Valid case - auth = AuthInput( - user_id="123", - username="testuser", - token="1234567890abcdef1234567890abcdef" - ) + auth = AuthInput(user_id="123", username="testuser", token="1234567890abcdef1234567890abcdef") assert auth.user_id == "123" assert auth.username == "testuser" # Invalid cases with pytest.raises(ValidationError): AuthInput(user_id="", username="test", token="x" * 32) - + with pytest.raises(ValidationError): AuthInput(user_id="123", username="t", token="x" * 32) def test_user_registration(self): """Test user registration validation""" # Valid case - user = UserRegistrationInput( - email="test@example.com", - password="SecurePass123!", - name="Test User" - ) + user = UserRegistrationInput(email="test@example.com", password="SecurePass123!", name="Test User") assert user.email == "test@example.com" assert user.name == "Test User" # Test email validation with pytest.raises(ValidationError) as exc: - UserRegistrationInput( - email="invalid-email", - password="SecurePass123!", - name="Test" - ) + UserRegistrationInput(email="invalid-email", password="SecurePass123!", name="Test") assert "Invalid email format" in str(exc.value) # Test password validation with pytest.raises(ValidationError) as exc: - UserRegistrationInput( - email="test@example.com", - password="weak", - name="Test" - ) + UserRegistrationInput(email="test@example.com", password="weak", name="Test") assert "String should have at least 8 characters" in str(exc.value) def test_token_payload(self): """Test token payload validation""" now = datetime.utcnow() exp = now + timedelta(hours=1) - - payload = TokenPayload( - user_id="123", - username="testuser", - exp=exp, - iat=now - ) + + payload = TokenPayload(user_id="123", username="testuser", exp=exp, iat=now) assert payload.user_id == "123" assert payload.username == "testuser" assert payload.scopes == [] # Default empty list @@ -77,25 +58,15 @@ class TestAuthValidations: def test_auth_response(self): """Test auth response validation""" # Success case - success_resp = AuthResponse( - success=True, - token="valid_token", - user={"id": "123", "name": "Test"} - ) + success_resp = AuthResponse(success=True, token="valid_token", user={"id": "123", "name": "Test"}) assert success_resp.success is True assert success_resp.token == "valid_token" # Error case - error_resp = AuthResponse( - success=False, - error="Invalid credentials" - ) + error_resp = AuthResponse(success=False, error="Invalid credentials") assert error_resp.success is False assert error_resp.error == "Invalid credentials" # Invalid case - отсутствует обязательное поле token при success=True with pytest.raises(ValidationError): - AuthResponse( - success=True, - user={"id": "123", "name": "Test"} - ) \ No newline at end of file + AuthResponse(success=True, user={"id": "123", "name": "Test"}) diff --git a/utils/logger.py b/utils/logger.py index 20d4e10f..3607e84d 100644 --- a/utils/logger.py +++ b/utils/logger.py @@ -6,17 +6,26 @@ import colorlog _lib_path = Path(__file__).parents[1] _leng_path = len(_lib_path.as_posix()) + def filter(record: logging.LogRecord): # Define `package` attribute with the relative path. - record.package = record.pathname[_leng_path+1:].replace(".py", "") - record.emoji = "🔍" if record.levelno == logging.DEBUG \ - else "🖊️" if record.levelno == logging.INFO \ - else "🚧" if record.levelno == logging.WARNING \ - else "❌" if record.levelno == logging.ERROR \ - else "🧨" if record.levelno == logging.CRITICAL \ + record.package = record.pathname[_leng_path + 1 :].replace(".py", "") + record.emoji = ( + "🔍" + if record.levelno == logging.DEBUG + else "🖊️" + if record.levelno == logging.INFO + else "🚧" + if record.levelno == logging.WARNING + else "❌" + if record.levelno == logging.ERROR + else "🧨" + if record.levelno == logging.CRITICAL else "" + ) return record + # Define the color scheme color_scheme = { "DEBUG": "light_black", @@ -55,9 +64,9 @@ class MultilineColoredFormatter(colorlog.ColoredFormatter): def format(self, record): # Add default emoji if not present - if not hasattr(record, 'emoji'): + if not hasattr(record, "emoji"): record = filter(record) - + message = record.getMessage() if "\n" in message: lines = message.split("\n")