diff --git a/CHANGELOG.md b/CHANGELOG.md index c499b46f..075c62c4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,11 @@ - `Shout.draft` field added - `Draft` entity added - `create_draft`, `update_draft`, `delete_draft` mutations and resolvers added -- `get_shout_drafts` resolver updated +- `create_shout`, `update_shout`, `delete_shout` mutations removed from GraphQL API +- `load_drafts` resolver implemented +- `publish_` and `unpublish_` mutations and resolvers added +- `create_`, `update_`, `delete_` mutations and resolvers added for `Draft` entity +- tests with pytest for auth, shouts, drafts #### [0.4.8] - 2025-02-03 - `Reaction.deleted_at` filter on `update_reaction` resolver added diff --git a/auth/identity.py b/auth/identity.py index 5bbb6030..3a096d9d 100644 --- a/auth/identity.py +++ b/auth/identity.py @@ -2,13 +2,13 @@ from binascii import hexlify from hashlib import sha256 # from base.exceptions import InvalidPassword, InvalidToken -from base.orm import local_session -from jwt import DecodeError, ExpiredSignatureError +from services.db import local_session +from auth.exceptions import ExpiredToken, InvalidToken from passlib.hash import bcrypt from auth.jwtcodec import JWTCodec from auth.tokenstorage import TokenStorage -from orm import User +from orm.user import User class Password: @@ -79,10 +79,10 @@ class Identity: if not await TokenStorage.exist(f"{payload.user_id}-{payload.username}-{token}"): # raise InvalidToken("Login token has expired, please login again") return {"error": "Token has expired"} - except ExpiredSignatureError: + except ExpiredToken: # raise InvalidToken("Login token has expired, please try again") return {"error": "Token has expired"} - except DecodeError: + except InvalidToken: # raise InvalidToken("token format error") from e return {"error": "Token format error"} with local_session() as session: diff --git a/auth/tokenstorage.py b/auth/tokenstorage.py index 3ee5e7fd..49fed14d 100644 --- a/auth/tokenstorage.py +++ b/auth/tokenstorage.py @@ -1,7 +1,7 @@ from datetime import datetime, timedelta, timezone -from base.redis import redis -from validations.auth import AuthInput +from services.redis import redis +from auth.validations import AuthInput from auth.jwtcodec import JWTCodec from settings import ONETIME_TOKEN_LIFE_SPAN, SESSION_TOKEN_LIFE_SPAN diff --git a/auth/validations.py b/auth/validations.py new file mode 100644 index 00000000..c4a7d253 --- /dev/null +++ b/auth/validations.py @@ -0,0 +1,103 @@ +import re +from datetime import datetime +from typing import Dict, List, Optional, Union +from pydantic import BaseModel, Field, field_validator + +# RFC 5322 compliant email regex pattern +EMAIL_PATTERN = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$" + +class AuthInput(BaseModel): + """Base model for authentication input validation""" + user_id: str = Field(description="Unique user identifier") + username: str = Field(min_length=2, max_length=50) + token: str = Field(min_length=32) + + @field_validator('user_id') + @classmethod + def validate_user_id(cls, v: str) -> str: + if not v.strip(): + raise ValueError("user_id cannot be empty") + return v + +class UserRegistrationInput(BaseModel): + """Validation model for user registration""" + email: str = Field(max_length=254) # Max email length per RFC 5321 + password: str = Field(min_length=8, max_length=100) + name: str = Field(min_length=2, max_length=50) + + @field_validator('email') + @classmethod + def validate_email(cls, v: str) -> str: + """Validate email format""" + if not re.match(EMAIL_PATTERN, v): + raise ValueError("Invalid email format") + return v.lower() + + @field_validator('password') + @classmethod + def validate_password_strength(cls, v: str) -> str: + """Validate password meets security requirements""" + if not any(c.isupper() for c in v): + raise ValueError("Password must contain at least one uppercase letter") + if not any(c.islower() for c in v): + raise ValueError("Password must contain at least one lowercase letter") + if not any(c.isdigit() for c in v): + raise ValueError("Password must contain at least one number") + if not any(c in "!@#$%^&*()_+-=[]{}|;:,.<>?" for c in v): + raise ValueError("Password must contain at least one special character") + return v + +class UserLoginInput(BaseModel): + """Validation model for user login""" + email: str = Field(max_length=254) + password: str = Field(min_length=8, max_length=100) + + @field_validator('email') + @classmethod + def validate_email(cls, v: str) -> str: + if not re.match(EMAIL_PATTERN, v): + raise ValueError("Invalid email format") + return v.lower() + +class TokenPayload(BaseModel): + """Validation model for JWT token payload""" + user_id: str + username: str + exp: datetime + iat: datetime + scopes: Optional[List[str]] = [] + +class OAuthInput(BaseModel): + """Validation model for OAuth input""" + provider: str = Field(pattern='^(google|github|facebook)$') + code: str + redirect_uri: Optional[str] = None + + @field_validator('provider') + @classmethod + def validate_provider(cls, v: str) -> str: + valid_providers = ['google', 'github', 'facebook'] + if v.lower() not in valid_providers: + raise ValueError(f"Provider must be one of: {', '.join(valid_providers)}") + return v.lower() + +class AuthResponse(BaseModel): + """Validation model for authentication responses""" + success: bool + token: Optional[str] = None + error: Optional[str] = None + user: Optional[Dict[str, Union[str, int, bool]]] = None + + @field_validator('error') + @classmethod + def validate_error_if_not_success(cls, v: Optional[str], info) -> Optional[str]: + if not info.data.get('success') and not v: + raise ValueError("Error message required when success is False") + return v + + @field_validator('token') + @classmethod + def validate_token_if_success(cls, v: Optional[str], info) -> Optional[str]: + if info.data.get('success') and not v: + raise ValueError("Token required when success is True") + return v \ No newline at end of file diff --git a/orm/rbac.py b/orm/rbac.py new file mode 100644 index 00000000..be22701b --- /dev/null +++ b/orm/rbac.py @@ -0,0 +1,176 @@ +from services.db import REGISTRY, Base, local_session +from utils.logger import root_logger as logger + +from sqlalchemy.types import TypeDecorator +from sqlalchemy.types import String +from sqlalchemy import Column, ForeignKey, String, UniqueConstraint +from sqlalchemy.orm import relationship + +class ClassType(TypeDecorator): + impl = String + + @property + def python_type(self): + return NotImplemented + + def process_literal_param(self, value, dialect): + return NotImplemented + + def process_bind_param(self, value, dialect): + return value.__name__ if isinstance(value, type) else str(value) + + def process_result_value(self, value, dialect): + class_ = REGISTRY.get(value) + if class_ is None: + logger.warn(f"Can't find class <{value}>,find it yourself!", stacklevel=2) + return class_ + + +class Role(Base): + __tablename__ = "role" + + name = Column(String, nullable=False, comment="Role Name") + desc = Column(String, nullable=True, comment="Role Description") + community = Column( + ForeignKey("community.id", ondelete="CASCADE"), + nullable=False, + comment="Community", + ) + permissions = relationship(lambda: Permission) + + @staticmethod + def init_table(): + with local_session() as session: + r = session.query(Role).filter(Role.name == "author").first() + if r: + Role.default_role = r + return + + r1 = Role.create( + name="author", + desc="Role for an author", + community=1, + ) + + session.add(r1) + + Role.default_role = r1 + + r2 = Role.create( + name="reader", + desc="Role for a reader", + community=1, + ) + + session.add(r2) + + r3 = Role.create( + name="expert", + desc="Role for an expert", + community=1, + ) + + session.add(r3) + + r4 = Role.create( + name="editor", + desc="Role for an editor", + community=1, + ) + + session.add(r4) + + +class Operation(Base): + __tablename__ = "operation" + name = Column(String, nullable=False, unique=True, comment="Operation Name") + + @staticmethod + def init_table(): + with local_session() as session: + for name in ["create", "update", "delete", "load"]: + """ + * everyone can: + - load shouts + - load topics + - load reactions + - create an account to become a READER + * readers can: + - update and delete their account + - load chats + - load messages + - create reaction of some shout's author allowed kinds + - create shout to become an AUTHOR + * authors can: + - update and delete their shout + - invite other authors to edit shout and chat + - manage allowed reactions for their shout + * pros can: + - create/update/delete their community + - create/update/delete topics for their community + + """ + op = session.query(Operation).filter(Operation.name == name).first() + if not op: + op = Operation.create(name=name) + session.add(op) + session.commit() + + +class Resource(Base): + __tablename__ = "resource" + resourceClass = Column(String, nullable=False, unique=True, comment="Resource class") + name = Column(String, nullable=False, unique=True, comment="Resource name") + # TODO: community = Column(ForeignKey()) + + @staticmethod + def init_table(): + with local_session() as session: + for res in [ + "shout", + "topic", + "reaction", + "chat", + "message", + "invite", + "community", + "user", + ]: + r = session.query(Resource).filter(Resource.name == res).first() + if not r: + r = Resource.create(name=res, resourceClass=res) + session.add(r) + session.commit() + + +class Permission(Base): + __tablename__ = "permission" + __table_args__ = ( + UniqueConstraint("role", "operation", "resource"), + {"extend_existing": True}, + ) + + role: Column = Column(ForeignKey("role.id", ondelete="CASCADE"), nullable=False, comment="Role") + operation: Column = Column( + ForeignKey("operation.id", ondelete="CASCADE"), + nullable=False, + comment="Operation", + ) + resource: Column = Column( + ForeignKey("resource.id", ondelete="CASCADE"), + nullable=False, + comment="Resource", + ) + + +# if __name__ == "__main__": +# Base.metadata.create_all(engine) +# ops = [ +# Permission(role=1, operation=1, resource=1), +# Permission(role=1, operation=2, resource=1), +# Permission(role=1, operation=3, resource=1), +# Permission(role=1, operation=4, resource=1), +# Permission(role=2, operation=4, resource=1), +# ] +# global_session.add_all(ops) +# global_session.commit() \ No newline at end of file diff --git a/orm/user.py b/orm/user.py new file mode 100644 index 00000000..c31c7913 --- /dev/null +++ b/orm/user.py @@ -0,0 +1,105 @@ +from sqlalchemy import JSON as JSONType +from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, func +from sqlalchemy.orm import relationship + +from services.db import Base, local_session +from orm.rbac import Role + + +class UserRating(Base): + __tablename__ = "user_rating" + + id = None + rater: Column = Column(ForeignKey("user.id"), primary_key=True, index=True) + user: Column = Column(ForeignKey("user.id"), primary_key=True, index=True) + value: Column = Column(Integer) + + @staticmethod + def init_table(): + pass + + +class UserRole(Base): + __tablename__ = "user_role" + + id = None + user = Column(ForeignKey("user.id"), primary_key=True, index=True) + role = Column(ForeignKey("role.id"), primary_key=True, index=True) + + +class AuthorFollower(Base): + __tablename__ = "author_follower" + + id = None + follower: Column = Column(ForeignKey("user.id"), primary_key=True, index=True) + author: Column = Column(ForeignKey("user.id"), primary_key=True, index=True) + createdAt = Column( + DateTime(timezone=True), nullable=False, server_default=func.now(), comment="Created at" + ) + auto = Column(Boolean, nullable=False, default=False) + + +class User(Base): + __tablename__ = "user" + default_user = None + + email = Column(String, unique=True, nullable=False, comment="Email") + username = Column(String, nullable=False, comment="Login") + password = Column(String, nullable=True, comment="Password") + bio = Column(String, nullable=True, comment="Bio") # status description + about = Column(String, nullable=True, comment="About") # long and formatted + userpic = Column(String, nullable=True, comment="Userpic") + name = Column(String, nullable=True, comment="Display name") + slug = Column(String, unique=True, comment="User's slug") + muted = Column(Boolean, default=False) + emailConfirmed = Column(Boolean, default=False) + createdAt = Column( + DateTime(timezone=True), nullable=False, server_default=func.now(), comment="Created at" + ) + lastSeen = Column( + DateTime(timezone=True), nullable=False, server_default=func.now(), comment="Was online at" + ) + deletedAt = Column(DateTime(timezone=True), nullable=True, comment="Deleted at") + links = Column(JSONType, nullable=True, comment="Links") + oauth = Column(String, nullable=True) + ratings = relationship(UserRating, foreign_keys=UserRating.user) + roles = relationship(lambda: Role, secondary=UserRole.__tablename__) + oid = Column(String, nullable=True) + + @staticmethod + def init_table(): + with local_session() as session: + default = session.query(User).filter(User.slug == "anonymous").first() + if not default: + default_dict = { + "email": "noreply@discours.io", + "username": "noreply@discours.io", + "name": "Аноним", + "slug": "anonymous", + } + default = User.create(**default_dict) + session.add(default) + discours_dict = { + "email": "welcome@discours.io", + "username": "welcome@discours.io", + "name": "Дискурс", + "slug": "discours", + } + discours = User.create(**discours_dict) + session.add(discours) + session.commit() + User.default_user = default + + def get_permission(self): + scope = {} + for role in self.roles: + for p in role.permissions: + if p.resource not in scope: + scope[p.resource] = set() + scope[p.resource].add(p.operation) + print(scope) + return scope + + +# if __name__ == "__main__": +# print(User.get_permission(user_id=1)) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index f46a3dc6..ccd0bcf0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,8 @@ fakeredis = "^2.25.1" pydantic = "^2.9.2" jwt = "^1.3.1" authlib = "^1.3.2" +passlib = "^1.7.4" +bcrypt = "^4.2.1" [tool.poetry.group.dev.dependencies] @@ -34,13 +36,15 @@ isort = "^5.13.2" pydantic = "^2.9.2" pytest = "^8.3.4" mypy = "^1.15.0" +pytest-asyncio = "^0.23.5" +pytest-cov = "^4.1.0" [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" [tool.pyright] -venvPath = "." +venvPath = "venv" venv = "venv" [tool.isort] @@ -49,5 +53,10 @@ include_trailing_comma = true force_grid_wrap = 0 line_length = 120 +[tool.pytest.ini_options] +testpaths = ["tests"] +pythonpath = ["."] +venv = "venv" + [tool.ruff] line-length = 120 diff --git a/resolvers/draft.py b/resolvers/draft.py index 9516e1fb..c780de85 100644 --- a/resolvers/draft.py +++ b/resolvers/draft.py @@ -1,16 +1,21 @@ import time -from importlib import invalidate_caches - +from orm.topic import Topic from sqlalchemy import select +from sqlalchemy.sql import and_ -from cache.cache import invalidate_shout_related_cache, invalidate_shouts_cache +from cache.cache import ( + cache_author, cache_by_id, cache_topic, + invalidate_shout_related_cache, invalidate_shouts_cache +) from orm.author import Author from orm.draft import Draft -from orm.shout import Shout +from orm.shout import Shout, ShoutAuthor, ShoutTopic from services.auth import login_required from services.db import local_session from services.schema import mutation, query from utils.logger import root_logger as logger +from services.notify import notify_shout +from services.search import search_service @query.field("load_drafts") @@ -119,9 +124,10 @@ async def unpublish_draft(_, info, draft_id: int): @login_required async def publish_shout(_, info, shout_id: int, draft=None): """Publish draft as a shout or update existing shout. - + Args: - session: SQLAlchemy session to use for database operations + shout_id: ID существующей публикации или 0 для новой + draft: Объект черновика (опционально) """ user_id = info.context.get("user_id") author_dict = info.context.get("author", {}) @@ -130,16 +136,25 @@ async def publish_shout(_, info, shout_id: int, draft=None): return {"error": "User ID and author ID are required"} try: - # Use proper SQLAlchemy query with local_session() as session: + # Находим черновик если не передан if not draft: find_draft_stmt = select(Draft).where(Draft.shout == shout_id) draft = session.execute(find_draft_stmt).scalar_one_or_none() + if not draft: + return {"error": "Draft not found"} now = int(time.time()) - + + # Находим существующую публикацию или создаем новую + shout = None + was_published = False + if shout_id: + shout = session.query(Shout).filter(Shout.id == shout_id).first() + was_published = shout and shout.published_at is not None + if not shout: - # Create new shout from draft + # Создаем новую публикацию shout = Shout( body=draft.body, slug=draft.slug, @@ -155,15 +170,11 @@ async def publish_shout(_, info, shout_id: int, draft=None): seo=draft.seo, created_by=author_id, community=draft.community, - authors=draft.authors.copy(), # Create copies of relationships - topics=draft.topics.copy(), draft=draft.id, deleted_at=None, ) else: - # Update existing shout - shout.authors = draft.authors.copy() - shout.topics = draft.topics.copy() + # Обновляем существующую публикацию shout.draft = draft.id shout.created_by = author_id shout.title = draft.title @@ -178,24 +189,78 @@ async def publish_shout(_, info, shout_id: int, draft=None): shout.lang = draft.lang shout.seo = draft.seo + # Обновляем временные метки shout.updated_at = now - shout.published_at = now + + # Устанавливаем published_at только если это новая публикация + # или публикация была ранее снята с публикации + if not was_published: + shout.published_at = now + draft.updated_at = now draft.published_at = now + + # Обрабатываем связи с авторами + if not session.query(ShoutAuthor).filter( + and_(ShoutAuthor.shout == shout.id, ShoutAuthor.author == author_id) + ).first(): + sa = ShoutAuthor(shout=shout.id, author=author_id) + session.add(sa) + + # Обрабатываем темы + if draft.topics: + for topic in draft.topics: + st = ShoutTopic( + topic=topic.id, + shout=shout.id, + main=topic.main if hasattr(topic, 'main') else False + ) + session.add(st) + session.add(shout) session.add(draft) + session.flush() + + # Инвалидируем кэш только если это новая публикация или была снята с публикации + if not was_published: + cache_keys = [ + "feed", + f"author_{author_id}", + "random_top", + "unrated" + ] + + # Добавляем ключи для тем + for topic in shout.topics: + cache_keys.append(f"topic_{topic.id}") + cache_keys.append(f"topic_shouts_{topic.id}") + await cache_by_id(Topic, topic.id, cache_topic) + + # Инвалидируем кэш + await invalidate_shouts_cache(cache_keys) + await invalidate_shout_related_cache(shout, author_id) + + # Обновляем кэш авторов + for author in shout.authors: + await cache_by_id(Author, author.id, cache_author) + + # Отправляем уведомление о публикации + await notify_shout(shout.dict(), "published") + + # Обновляем поисковый индекс + search_service.index(shout) + else: + # Для уже опубликованных материалов просто отправляем уведомление об обновлении + await notify_shout(shout.dict(), "update") + session.commit() + return {"shout": shout} - invalidate_shout_related_cache(shout) - invalidate_shouts_cache() - return {"shout": shout} except Exception as e: - import traceback - - logger.error(f"Failed to publish shout: {e}") - logger.error(traceback.format_exc()) - session.rollback() - return {"error": "Failed to publish shout"} + logger.error(f"Failed to publish shout: {e}", exc_info=True) + if 'session' in locals(): + session.rollback() + return {"error": f"Failed to publish shout: {str(e)}"} @mutation.field("unpublish_shout") diff --git a/schema/type.graphql b/schema/type.graphql index df4be71a..a9941a62 100644 --- a/schema/type.graphql +++ b/schema/type.graphql @@ -188,6 +188,8 @@ type Topic { type CommonResult { error: String + drafts: [Draft] + draft: Draft slugs: [String] shout: Shout shouts: [Shout] diff --git a/services/db.py b/services/db.py index 8b36c406..394de75d 100644 --- a/services/db.py +++ b/services/db.py @@ -7,8 +7,7 @@ from typing import Any, Callable, Dict, TypeVar import sqlalchemy from sqlalchemy import JSON, Column, Engine, Integer, create_engine, event, exc, func, inspect -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import Session, configure_mappers +from sqlalchemy.orm import Session, configure_mappers, declarative_base from sqlalchemy.sql.schema import Table from settings import DB_URL diff --git a/services/pretopic.py b/services/pretopic.py new file mode 100644 index 00000000..b6cd60c8 --- /dev/null +++ b/services/pretopic.py @@ -0,0 +1,179 @@ +import concurrent.futures +from typing import Dict, Tuple, List +from txtai.embeddings import Embeddings +from services.logger import root_logger as logger + +class TopicClassifier: + def __init__(self, shouts_by_topic: Dict[str, str], publications: List[Dict[str, str]]): + """ + Инициализация классификатора тем и поиска публикаций. + Args: + shouts_by_topic: Словарь {тема: текст_всех_публикаций} + publications: Список публикаций с полями 'id', 'title', 'text' + """ + self.shouts_by_topic = shouts_by_topic + self.topics = list(shouts_by_topic.keys()) + self.publications = publications + self.topic_embeddings = None # Для классификации тем + self.search_embeddings = None # Для поиска публикаций + self._initialization_future = None + self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + + def initialize(self) -> None: + """ + Асинхронная инициализация векторных представлений. + """ + if self._initialization_future is None: + self._initialization_future = self._executor.submit(self._prepare_embeddings) + logger.info("Векторизация текстов начата в фоновом режиме...") + + def _prepare_embeddings(self) -> None: + """ + Подготавливает векторные представления для тем и поиска. + """ + logger.info("Начинается подготовка векторных представлений...") + + # Модель для русского языка + # TODO: model local caching + model_path = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2" + + # Инициализируем embeddings для классификации тем + self.topic_embeddings = Embeddings(path=model_path) + topic_documents = [ + (topic, text) + for topic, text in self.shouts_by_topic.items() + ] + self.topic_embeddings.index(topic_documents) + + # Инициализируем embeddings для поиска публикаций + self.search_embeddings = Embeddings(path=model_path) + search_documents = [ + (str(pub['id']), f"{pub['title']} {pub['text']}") + for pub in self.publications + ] + self.search_embeddings.index(search_documents) + + logger.info("Подготовка векторных представлений завершена.") + + def predict_topic(self, text: str) -> Tuple[float, str]: + """ + Предсказывает тему для заданного текста из известного набора тем. + Args: + text: Текст для классификации + Returns: + Tuple[float, str]: (уверенность, тема) + """ + if not self.is_ready(): + logger.error("Векторные представления не готовы. Вызовите initialize() и дождитесь завершения.") + return 0.0, "unknown" + + try: + # Ищем наиболее похожую тему + results = self.topic_embeddings.search(text, 1) + if not results: + return 0.0, "unknown" + + score, topic = results[0] + return float(score), topic + + except Exception as e: + logger.error(f"Ошибка при определении темы: {str(e)}") + return 0.0, "unknown" + + def search_similar(self, query: str, limit: int = 5) -> List[Dict[str, any]]: + """ + Ищет публикации похожие на поисковый запрос. + Args: + query: Поисковый запрос + limit: Максимальное количество результатов + Returns: + List[Dict]: Список найденных публикаций с оценкой релевантности + """ + if not self.is_ready(): + logger.error("Векторные представления не готовы. Вызовите initialize() и дождитесь завершения.") + return [] + + try: + # Ищем похожие публикации + results = self.search_embeddings.search(query, limit) + + # Формируем результаты + found_publications = [] + for score, pub_id in results: + # Находим публикацию по id + publication = next( + (pub for pub in self.publications if str(pub['id']) == pub_id), + None + ) + if publication: + found_publications.append({ + **publication, + 'relevance': float(score) + }) + + return found_publications + + except Exception as e: + logger.error(f"Ошибка при поиске публикаций: {str(e)}") + return [] + + def is_ready(self) -> bool: + """ + Проверяет, готовы ли векторные представления. + """ + return self.topic_embeddings is not None and self.search_embeddings is not None + + def wait_until_ready(self) -> None: + """ + Ожидает завершения подготовки векторных представлений. + """ + if self._initialization_future: + self._initialization_future.result() + + def __del__(self): + """ + Очистка ресурсов при удалении объекта. + """ + if self._executor: + self._executor.shutdown(wait=False) + +# Пример использования: +""" +shouts_by_topic = { + "Спорт": "... большой текст со всеми спортивными публикациями ...", + "Технологии": "... большой текст со всеми технологическими публикациями ...", + "Политика": "... большой текст со всеми политическими публикациями ..." +} + +publications = [ + { + 'id': 1, + 'title': 'Новый процессор AMD', + 'text': 'Компания AMD представила новый процессор...' + }, + { + 'id': 2, + 'title': 'Футбольный матч', + 'text': 'Вчера состоялся решающий матч...' + } +] + +# Создание классификатора +classifier = TopicClassifier(shouts_by_topic, publications) +classifier.initialize() +classifier.wait_until_ready() + +# Определение темы текста +text = "Новый процессор показал высокую производительность" +score, topic = classifier.predict_topic(text) +print(f"Тема: {topic} (уверенность: {score:.4f})") + +# Поиск похожих публикаций +query = "процессор AMD производительность" +similar_publications = classifier.search_similar(query, limit=3) +for pub in similar_publications: + print(f"\nНайдена публикация (релевантность: {pub['relevance']:.4f}):") + print(f"Заголовок: {pub['title']}") + print(f"Текст: {pub['text'][:100]}...") +""" + diff --git a/settings.py b/settings.py index a2a18efb..2760140b 100644 --- a/settings.py +++ b/settings.py @@ -5,7 +5,7 @@ PORT = 8000 DB_URL = ( environ.get("DATABASE_URL", "").replace("postgres://", "postgresql://") or environ.get("DB_URL", "").replace("postgres://", "postgresql://") - or "sqlite:///discoursio-db.sqlite3" + or "sqlite:///discoursio.db" ) REDIS_URL = environ.get("REDIS_URL") or "redis://127.0.0.1" AUTH_URL = environ.get("AUTH_URL") or "" @@ -15,3 +15,9 @@ MODE = "development" if "dev" in sys.argv else "production" ADMIN_SECRET = environ.get("AUTH_SECRET") or "nothing" WEBHOOK_SECRET = environ.get("WEBHOOK_SECRET") or "nothing-else" + +# own auth +ONETIME_TOKEN_LIFE_SPAN = 60 * 60 * 24 * 3 # 3 days +SESSION_TOKEN_LIFE_SPAN = 60 * 60 * 24 * 30 # 30 days +JWT_ALGORITHM = "HS256" +JWT_SECRET_KEY = environ.get("JWT_SECRET") or "nothing-else-jwt-secret-matters" \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..b45ef526 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,55 @@ +import asyncio +import os +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import Session +from starlette.testclient import TestClient + +from main import app +from services.db import Base +from services.redis import redis +from settings import DB_URL + +# Use SQLite for testing +TEST_DB_URL = "sqlite:///test.db" + +@pytest.fixture(scope="session") +def event_loop(): + """Create an instance of the default event loop for the test session.""" + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + +@pytest.fixture(scope="session") +def test_engine(): + """Create a test database engine.""" + engine = create_engine(TEST_DB_URL) + Base.metadata.create_all(engine) + yield engine + Base.metadata.drop_all(engine) + os.remove("test.db") + +@pytest.fixture +def db_session(test_engine): + """Create a new database session for a test.""" + connection = test_engine.connect() + transaction = connection.begin() + session = Session(bind=connection) + + yield session + + session.close() + transaction.rollback() + connection.close() + +@pytest.fixture +async def redis_client(): + """Create a test Redis client.""" + await redis.connect() + yield redis + await redis.disconnect() + +@pytest.fixture +def test_client(): + """Create a TestClient instance.""" + return TestClient(app) \ No newline at end of file diff --git a/tests/test_drafts.py b/tests/test_drafts.py new file mode 100644 index 00000000..5cd44a76 --- /dev/null +++ b/tests/test_drafts.py @@ -0,0 +1,95 @@ +import pytest +from orm.shout import Shout +from orm.author import Author + +@pytest.fixture +def test_author(db_session): + """Create a test author.""" + author = Author( + name="Test Author", + slug="test-author", + user="test-user-id" + ) + db_session.add(author) + db_session.commit() + return author + +@pytest.fixture +def test_shout(db_session): + """Create test shout with required fields.""" + author = Author(name="Test Author", slug="test-author", user="test-user-id") + db_session.add(author) + db_session.flush() + + shout = Shout( + title="Test Shout", + slug="test-shout", + created_by=author.id, # Обязательное поле + body="Test body", + layout="article", + lang="ru" + ) + db_session.add(shout) + db_session.commit() + return shout + +@pytest.mark.asyncio +async def test_create_shout(test_client, db_session, test_author): + """Test creating a new shout.""" + response = test_client.post( + "/", + json={ + "query": """ + mutation CreateDraft($input: DraftInput!) { + create_draft(input: $input) { + error + draft { + id + title + body + } + } + } + """, + "variables": { + "input": { + "title": "Test Shout", + "body": "This is a test shout", + } + } + } + ) + + assert response.status_code == 200 + data = response.json() + assert "errors" not in data + assert data["data"]["create_draft"]["draft"]["title"] == "Test Shout" + +@pytest.mark.asyncio +async def test_load_drafts(test_client, db_session): + """Test retrieving a shout.""" + response = test_client.post( + "/", + json={ + "query": """ + query { + load_drafts { + error + drafts { + id + title + body + } + } + } + """, + "variables": { + "slug": "test-shout" + } + } + ) + + assert response.status_code == 200 + data = response.json() + assert "errors" not in data + assert data["data"]["load_drafts"]["drafts"] == [] \ No newline at end of file diff --git a/tests/test_reactions.py b/tests/test_reactions.py new file mode 100644 index 00000000..0a37a20d --- /dev/null +++ b/tests/test_reactions.py @@ -0,0 +1,64 @@ +import pytest +from orm.reaction import Reaction, ReactionKind +from orm.shout import Shout +from orm.author import Author +from datetime import datetime + +@pytest.fixture +def test_setup(db_session): + """Set up test data.""" + now = int(datetime.now().timestamp()) + author = Author(name="Test Author", slug="test-author", user="test-user-id") + db_session.add(author) + db_session.flush() + + shout = Shout( + title="Test Shout", + slug="test-shout", + created_by=author.id, + body="This is a test shout", + layout="article", + lang="ru", + community=1, + created_at=now, + updated_at=now + ) + db_session.add_all([author, shout]) + db_session.commit() + return {"author": author, "shout": shout} + +@pytest.mark.asyncio +async def test_create_reaction(test_client, db_session, test_setup): + """Test creating a reaction on a shout.""" + response = test_client.post( + "/", + json={ + "query": """ + mutation CreateReaction($reaction: ReactionInput!) { + create_reaction(reaction: $reaction) { + error + reaction { + id + kind + body + created_by { + name + } + } + } + } + """, + "variables": { + "reaction": { + "shout": test_setup["shout"].id, + "kind": ReactionKind.LIKE.value, + "body": "Great post!" + } + } + } + ) + + assert response.status_code == 200 + data = response.json() + assert "error" not in data + assert data["data"]["create_reaction"]["reaction"]["kind"] == ReactionKind.LIKE.value \ No newline at end of file diff --git a/tests/test_shouts.py b/tests/test_shouts.py new file mode 100644 index 00000000..66f3b1e9 --- /dev/null +++ b/tests/test_shouts.py @@ -0,0 +1,83 @@ +import pytest +from orm.author import Author +from orm.shout import Shout +from datetime import datetime + +@pytest.fixture +def test_shout(db_session): + """Create test shout with required fields.""" + now = int(datetime.now().timestamp()) + author = Author(name="Test Author", slug="test-author", user="test-user-id") + db_session.add(author) + db_session.flush() + + now = int(datetime.now().timestamp()) + + shout = Shout( + title="Test Shout", + slug="test-shout", + created_by=author.id, + body="Test body", + layout="article", + lang="ru", + community=1, + created_at=now, + updated_at=now + ) + db_session.add(shout) + db_session.commit() + return shout + +@pytest.mark.asyncio +async def test_get_shout(test_client, db_session): + """Test retrieving a shout.""" + # Создаем автора + author = Author(name="Test Author", slug="test-author", user="test-user-id") + db_session.add(author) + db_session.flush() + now = int(datetime.now().timestamp()) + + # Создаем публикацию со всеми обязательными полями + shout = Shout( + title="Test Shout", + body="This is a test shout", + slug="test-shout", + created_by=author.id, + layout="article", + lang="ru", + community=1, + created_at=now, + updated_at=now + ) + db_session.add(shout) + db_session.commit() + + response = test_client.post( + "/", + json={ + "query": """ + query GetShout($slug: String!) { + get_shout(slug: $slug) { + id + title + body + created_at + updated_at + created_by { + id + name + slug + } + } + } + """, + "variables": { + "slug": "test-shout" + } + } + ) + + data = response.json() + assert response.status_code == 200 + assert "errors" not in data + assert data["data"]["get_shout"]["title"] == "Test Shout" \ No newline at end of file diff --git a/tests/test_validations.py b/tests/test_validations.py new file mode 100644 index 00000000..2cffae7a --- /dev/null +++ b/tests/test_validations.py @@ -0,0 +1,101 @@ +import pytest +from datetime import datetime, timedelta +from pydantic import ValidationError + +from auth.validations import ( + AuthInput, + UserRegistrationInput, + UserLoginInput, + TokenPayload, + OAuthInput, + AuthResponse +) + +class TestAuthValidations: + def test_auth_input(self): + """Test basic auth input validation""" + # Valid case + auth = AuthInput( + user_id="123", + username="testuser", + token="1234567890abcdef1234567890abcdef" + ) + assert auth.user_id == "123" + assert auth.username == "testuser" + + # Invalid cases + with pytest.raises(ValidationError): + AuthInput(user_id="", username="test", token="x" * 32) + + with pytest.raises(ValidationError): + AuthInput(user_id="123", username="t", token="x" * 32) + + def test_user_registration(self): + """Test user registration validation""" + # Valid case + user = UserRegistrationInput( + email="test@example.com", + password="SecurePass123!", + name="Test User" + ) + assert user.email == "test@example.com" + assert user.name == "Test User" + + # Test email validation + with pytest.raises(ValidationError) as exc: + UserRegistrationInput( + email="invalid-email", + password="SecurePass123!", + name="Test" + ) + assert "Invalid email format" in str(exc.value) + + # Test password validation + with pytest.raises(ValidationError) as exc: + UserRegistrationInput( + email="test@example.com", + password="weak", + name="Test" + ) + assert "String should have at least 8 characters" in str(exc.value) + + def test_token_payload(self): + """Test token payload validation""" + now = datetime.utcnow() + exp = now + timedelta(hours=1) + + payload = TokenPayload( + user_id="123", + username="testuser", + exp=exp, + iat=now + ) + assert payload.user_id == "123" + assert payload.username == "testuser" + assert payload.scopes == [] # Default empty list + + def test_auth_response(self): + """Test auth response validation""" + # Success case + success_resp = AuthResponse( + success=True, + token="valid_token", + user={"id": "123", "name": "Test"} + ) + assert success_resp.success is True + assert success_resp.token == "valid_token" + + # Error case + error_resp = AuthResponse( + success=False, + error="Invalid credentials" + ) + assert error_resp.success is False + assert error_resp.error == "Invalid credentials" + + # Invalid case - отсутствует обязательное поле token при success=True + with pytest.raises(ValidationError): + AuthResponse( + success=True, + user={"id": "123", "name": "Test"} + ) \ No newline at end of file