From 8b39b47714fa85ecf1c0c604ab8628d59bc013f2 Mon Sep 17 00:00:00 2001 From: Untone Date: Sat, 17 Feb 2024 02:56:15 +0300 Subject: [PATCH] fixed-fmt-linted --- .gitignore | 1 + Dockerfile | 9 +- main.py | 12 +- orm/author.py | 29 ++-- orm/notification.py | 35 +++-- pyproject.toml | 122 ++++++++--------- resolvers/listener.py | 22 +-- resolvers/load.py | 295 ++++++++++++++++++++--------------------- resolvers/model.py | 5 +- resolvers/schema.py | 3 +- resolvers/seen.py | 40 +++--- server.py | 71 +++++----- services/auth.py | 51 +++---- services/core.py | 23 ++-- services/db.py | 13 +- services/rediscache.py | 27 ++-- settings.py | 17 +-- 17 files changed, 384 insertions(+), 391 deletions(-) diff --git a/.gitignore b/.gitignore index ba134b3..55ddd88 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ __pycache__ poetry.lock .venv .ruff_cache +.pytest_cache diff --git a/Dockerfile b/Dockerfile index 558db76..3e36460 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,20 +1,19 @@ # Use an official Python runtime as a parent image -FROM python:3.12-slim +FROM python:3.12-alpine # Set the working directory in the container to /app WORKDIR /app -# Add metadata to the image to describe that the container is listening on port 80 +# Add metadata to the image to describe that the container is listening on port 8000 EXPOSE 8000 # Copy the current directory contents into the container at /app COPY . /app # Install any needed packages specified in pyproject.toml -RUN apt-get update && apt-get install -y gcc curl && \ +RUN apk update && apk add --no-cache gcc curl && \ curl -sSL https://install.python-poetry.org | python - && \ - echo "export PATH=$PATH:/root/.local/bin" >> ~/.bashrc && \ - . ~/.bashrc && \ + export PATH=$PATH:/root/.local/bin && \ poetry config virtualenvs.create false && \ poetry install --no-dev diff --git a/main.py b/main.py index 9b7f841..440c7d3 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,5 @@ import asyncio +import logging import os from os.path import exists @@ -13,11 +14,10 @@ from resolvers.listener import notifications_worker from resolvers.schema import schema from services.rediscache import redis from settings import DEV_SERVER_PID_FILE_NAME, MODE, SENTRY_DSN -import logging logging.basicConfig(level=logging.DEBUG) -logger = logging.getLogger("\t[main]\t") +logger = logging.getLogger('\t[main]\t') logger.setLevel(logging.DEBUG) @@ -27,9 +27,9 @@ async def start_up(): task = asyncio.create_task(notifications_worker()) logger.info(task) - if MODE == "dev": + if MODE == 'dev': if exists(DEV_SERVER_PID_FILE_NAME): - with open(DEV_SERVER_PID_FILE_NAME, "w", encoding="utf-8") as f: + with open(DEV_SERVER_PID_FILE_NAME, 'w', encoding='utf-8') as f: f.write(str(os.getpid())) else: try: @@ -46,7 +46,7 @@ async def start_up(): ], ) except Exception as e: - logger.error("sentry init error", e) + logger.error('sentry init error', e) async def shutdown(): @@ -54,4 +54,4 @@ async def shutdown(): app = Starlette(debug=True, on_startup=[start_up], on_shutdown=[shutdown]) -app.mount("/", GraphQL(schema, debug=True)) +app.mount('/', GraphQL(schema, debug=True)) diff --git a/orm/author.py b/orm/author.py index ba9fcc5..49bf6d1 100644 --- a/orm/author.py +++ b/orm/author.py @@ -1,46 +1,45 @@ import time -from sqlalchemy import JSON as JSONType -from sqlalchemy import Boolean, Column, ForeignKey, Integer, String +from sqlalchemy import JSON, Boolean, Column, ForeignKey, Integer, String from sqlalchemy.orm import relationship from services.db import Base class AuthorRating(Base): - __tablename__ = "author_rating" + __tablename__ = 'author_rating' id = None # type: ignore - rater = Column(ForeignKey("author.id"), primary_key=True, index=True) - author = Column(ForeignKey("author.id"), primary_key=True, index=True) + rater = Column(ForeignKey('author.id'), primary_key=True, index=True) + author = Column(ForeignKey('author.id'), primary_key=True, index=True) plus = Column(Boolean) class AuthorFollower(Base): - __tablename__ = "author_follower" + __tablename__ = 'author_follower' id = None # type: ignore - follower = Column(ForeignKey("author.id"), primary_key=True, index=True) - author = Column(ForeignKey("author.id"), primary_key=True, index=True) + follower = Column(ForeignKey('author.id'), primary_key=True, index=True) + author = Column(ForeignKey('author.id'), primary_key=True, index=True) created_at = Column(Integer, nullable=False, default=lambda: int(time.time())) auto = Column(Boolean, nullable=False, default=False) class Author(Base): - __tablename__ = "author" + __tablename__ = 'author' user = Column(String, unique=True) # unbounded link with authorizer's User type - name = Column(String, nullable=True, comment="Display name") + name = Column(String, nullable=True, comment='Display name') slug = Column(String, unique=True, comment="Author's slug") - bio = Column(String, nullable=True, comment="Bio") # status description - about = Column(String, nullable=True, comment="About") # long and formatted - pic = Column(String, nullable=True, comment="Picture") - links = Column(JSONType, nullable=True, comment="Links") + bio = Column(String, nullable=True, comment='Bio') # status description + about = Column(String, nullable=True, comment='About') # long and formatted + pic = Column(String, nullable=True, comment='Picture') + links = Column(JSON, nullable=True, comment='Links') ratings = relationship(AuthorRating, foreign_keys=AuthorRating.author) created_at = Column(Integer, nullable=False, default=lambda: int(time.time())) last_seen = Column(Integer, nullable=False, default=lambda: int(time.time())) updated_at = Column(Integer, nullable=False, default=lambda: int(time.time())) - deleted_at = Column(Integer, nullable=True, comment="Deleted at") + deleted_at = Column(Integer, nullable=True, comment='Deleted at') diff --git a/orm/notification.py b/orm/notification.py index 6fc08df..2b09ea1 100644 --- a/orm/notification.py +++ b/orm/notification.py @@ -1,42 +1,41 @@ +import time from enum import Enum as Enumeration -from sqlalchemy import JSON as JSONType -from sqlalchemy import Column, ForeignKey, Integer, String +from sqlalchemy import JSON, Column, ForeignKey, Integer, String from sqlalchemy.orm import relationship from orm.author import Author from services.db import Base -import time class NotificationEntity(Enumeration): - REACTION = "reaction" - SHOUT = "shout" - FOLLOWER = "follower" + REACTION = 'reaction' + SHOUT = 'shout' + FOLLOWER = 'follower' class NotificationAction(Enumeration): - CREATE = "create" - UPDATE = "update" - DELETE = "delete" - SEEN = "seen" - FOLLOW = "follow" - UNFOLLOW = "unfollow" + CREATE = 'create' + UPDATE = 'update' + DELETE = 'delete' + SEEN = 'seen' + FOLLOW = 'follow' + UNFOLLOW = 'unfollow' class NotificationSeen(Base): - __tablename__ = "notification_seen" + __tablename__ = 'notification_seen' - viewer = Column(ForeignKey("author.id")) - notification = Column(ForeignKey("notification.id")) + viewer = Column(ForeignKey('author.id')) + notification = Column(ForeignKey('notification.id')) class Notification(Base): - __tablename__ = "notification" + __tablename__ = 'notification' created_at = Column(Integer, server_default=str(int(time.time()))) entity = Column(String, nullable=False) action = Column(String, nullable=False) - payload = Column(JSONType, nullable=True) + payload = Column(JSON, nullable=True) - seen = relationship(lambda: Author, secondary="notification_seen") + seen = relationship(lambda: Author, secondary='notification_seen') diff --git a/pyproject.toml b/pyproject.toml index 191176b..4daf5fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,12 +1,8 @@ -[build-system] -requires = ["poetry-core>=1.0.0"] -build-backend = "poetry.core.masonry.api" - [tool.poetry] name = "discoursio-notifier" -version = "0.2.19" +version = "0.3.0" description = "notifier server for discours.io" -authors = ["discours.io devteam"] +authors = ["Tony Rewin "] [tool.poetry.dependencies] python = "^3.12" @@ -21,48 +17,68 @@ granian = "^1.0.2" [tool.poetry.group.dev.dependencies] setuptools = "^69.0.2" -pytest = "^7.4.2" -black = { version = "^23.12.0", python = ">=3.12" } -ruff = { version = "^0.1.15", python = ">=3.12" } -mypy = { version = "^1.7", python = ">=3.12" } isort = "^5.13.2" pyright = "^1.1.341" -pre-commit = "^3.6.0" -pytest-asyncio = "^0.23.4" -pytest-cov = "^4.1.0" +mypy = "^1.7.1" +ruff = "^0.1.15" +black = "^23.12.0" +pytest = "^7.4.3" + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" + +[tool.ruff] +line-length = 120 +extend-select = [ + # E and F are enabled by default + 'B', # flake8-bugbear + 'C4', # flake8-comprehensions + 'C90', # mccabe + 'I', # isort + 'N', # pep8-naming + 'Q', # flake8-quotes + 'RUF100', # ruff (unused noqa) + 'S', # flake8-bandit + 'W', # pycodestyle +] +extend-ignore = [ + 'B008', # function calls in args defaults are fine + 'B009', # getattr with constants is fine + 'B034', # re.split won't confuse us + 'B904', # rising without from is fine + 'E501', # leave line length to black + 'N818', # leave to us exceptions naming + 'S101', # assert is fine + 'RUF100', # black's noqa +] +flake8-quotes = { inline-quotes = 'single', multiline-quotes = 'double' } +mccabe = { max-complexity = 13 } +target-version = "py312" + +[tool.ruff.format] +quote-style = 'single' [tool.black] -line-length = 120 -target-version = ['py312'] -include = '\.pyi?$' -exclude = ''' +skip-string-normalization = true -( - /( - \.eggs # exclude a few common directories in the - | \.git # root of the project - | \.hg - | \.mypy_cache - | \.tox - | \.venv - | _build - | buck-out - | build - | dist - )/ - | foo.py # also separately exclude a file named foo.py in - # the root of the project -) -''' +[tool.ruff.isort] +combine-as-imports = true +lines-after-imports = 2 +known-first-party = ['resolvers', 'services', 'orm', 'tests'] -[tool.isort] -multi_line_output = 3 -include_trailing_comma = true -force_grid_wrap = 0 -use_parentheses = true -ensure_newline_before_comments = true -line_length = 120 +[tool.ruff.per-file-ignores] +'tests/**' = ['B018', 'S110', 'S501'] +[tool.mypy] +python_version = "3.12" +warn_return_any = true +warn_unused_configs = true +ignore_missing_imports = true +exclude = ["nb"] + +[tool.pytest.ini_options] +asyncio_mode = 'auto' [tool.pyright] venvPath = "." @@ -90,27 +106,3 @@ logLevel = "Information" pluginSearchPaths = [] typings = {} mergeTypeStubPackages = false - -[tool.mypy] -python_version = "3.12" -warn_unused_configs = true -plugins = ["mypy_sqlalchemy.plugin", "strawberry.ext.mypy_plugin"] - -[tool.ruff] -# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. -# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or -# McCabe complexity (`C901`) by default. -select = ["E4", "E7", "E9", "F"] -ignore = [] -line-length = 120 -target-version = "py312" - - - -[tool.pytest.ini_options] -pythonpath = [ - "." -] - -[tool.pytest] -python_files = "*_test.py" diff --git a/resolvers/listener.py b/resolvers/listener.py index a7b70bc..95b762c 100644 --- a/resolvers/listener.py +++ b/resolvers/listener.py @@ -1,11 +1,13 @@ -from orm.notification import Notification -from resolvers.model import NotificationReaction, NotificationAuthor, NotificationShout -from services.db import local_session -from services.rediscache import redis import asyncio import logging -logger = logging.getLogger("[listener.listen_task] ") +from orm.notification import Notification +from resolvers.model import NotificationAuthor, NotificationReaction, NotificationShout +from services.db import local_session +from services.rediscache import redis + + +logger = logging.getLogger('[listener.listen_task] ') logger.setLevel(logging.DEBUG) @@ -19,8 +21,8 @@ async def handle_notification(n: ServiceMessage, channel: str): """создаеёт новое хранимое уведомление""" with local_session() as session: try: - if channel.startswith("follower:"): - author_id = int(channel.split(":")[1]) + if channel.startswith('follower:'): + author_id = int(channel.split(':')[1]) if isinstance(n.payload, NotificationAuthor): n.payload.following_id = author_id n = Notification(action=n.action, entity=n.entity, payload=n.payload) @@ -28,7 +30,7 @@ async def handle_notification(n: ServiceMessage, channel: str): session.commit() except Exception as e: session.rollback() - logger.error(f"[listener.handle_reaction] error: {str(e)}") + logger.error(f'[listener.handle_reaction] error: {str(e)}') async def listen_task(pattern): @@ -38,9 +40,9 @@ async def listen_task(pattern): notification_message = ServiceMessage(**message_data) await handle_notification(notification_message, str(channel)) except Exception as e: - logger.error(f"Error processing notification: {str(e)}") + logger.error(f'Error processing notification: {str(e)}') async def notifications_worker(): # Use asyncio.gather to run tasks concurrently - await asyncio.gather(listen_task("follower:*"), listen_task("reaction"), listen_task("shout")) + await asyncio.gather(listen_task('follower:*'), listen_task('reaction'), listen_task('shout')) diff --git a/resolvers/load.py b/resolvers/load.py index f7ade53..04de03b 100644 --- a/resolvers/load.py +++ b/resolvers/load.py @@ -1,53 +1,31 @@ +import json +import logging +import time +from typing import Dict, List, Tuple, Union + +import strawberry +from sqlalchemy import and_, select +from sqlalchemy.orm import aliased from sqlalchemy.sql import not_ -from services.db import local_session + +from orm.notification import Notification, NotificationAction, NotificationEntity, NotificationSeen from resolvers.model import ( - NotificationReaction, - NotificationGroup, - NotificationShout, NotificationAuthor, + NotificationGroup, + NotificationReaction, + NotificationShout, NotificationsResult, ) -from orm.notification import NotificationAction, NotificationEntity, NotificationSeen, Notification -from typing import Dict, List -import time -import json -import strawberry -from sqlalchemy.orm import aliased -from sqlalchemy import select, and_ -import logging +from services.db import local_session -logger = logging.getLogger("[resolvers.schema] ") + +logger = logging.getLogger('[resolvers.schema]') logger.setLevel(logging.DEBUG) -async def get_notifications_grouped(author_id: int, after: int = 0, limit: int = 10, offset: int = 0): - """ - Retrieves notifications for a given author. - - Args: - author_id (int): The ID of the author for whom notifications are retrieved. - after (int, optional): If provided, selects only notifications created after this timestamp will be considered. - limit (int, optional): The maximum number of groupa to retrieve. - offset (int, optional): Offset for pagination - - Returns: - Dict[str, NotificationGroup], int, int: A dictionary where keys are thread IDs and values are NotificationGroup objects, unread and total amounts. - - This function queries the database to retrieve notifications for the specified author, considering optional filters. - The result is a dictionary where each key is a thread ID, and the corresponding value is a NotificationGroup - containing information about the notifications within that thread. - - NotificationGroup structure: - { - entity: str, # Type of entity (e.g., 'reaction', 'shout', 'follower'). - updated_at: int, # Timestamp of the latest update in the thread. - shout: Optional[NotificationShout] - reactions: List[int], # List of reaction ids within the thread. - authors: List[NotificationAuthor], # List of authors involved in the thread. - } - """ - NotificationSeenAlias = aliased(NotificationSeen) - query = select(Notification, NotificationSeenAlias.viewer.label("seen")).outerjoin( +def query_notifications(author_id: int, after: int = 0) -> Tuple[int, int, List[Tuple[Notification, bool]]]: + notification_seen_alias = aliased(NotificationSeen) + query = select(Notification, notification_seen_alias.viewer.label('seen')).outerjoin( NotificationSeen, and_(NotificationSeen.viewer == author_id, NotificationSeen.notification == Notification.id), ) @@ -55,17 +33,13 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int = query = query.filter(Notification.created_at > after) query = query.group_by(NotificationSeen.notification, Notification.created_at) - groups_amount = 0 - unread = 0 - total = 0 - notifications_by_thread: Dict[str, List[Notification]] = {} - groups_by_thread: Dict[str, NotificationGroup] = {} with local_session() as session: total = ( session.query(Notification) .filter(and_(Notification.action == NotificationAction.CREATE.value, Notification.created_at > after)) .count() ) + unread = ( session.query(Notification) .filter( @@ -77,123 +51,140 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int = ) .count() ) + notifications_result = session.execute(query) + notifications = [] for n, seen in notifications_result: - thread_id = "" - payload = json.loads(n.payload) - logger.debug(f"[resolvers.schema] {n.action} {n.entity}: {payload}") - if n.entity == "shout" and n.action == "create": - shout: NotificationShout = payload - thread_id += f"{shout.id}" - logger.debug(f"create shout: {shout}") - group = groups_by_thread.get(thread_id) or NotificationGroup( - id=thread_id, - entity=n.entity, - shout=shout, - authors=shout.authors, - updated_at=shout.created_at, - reactions=[], - action="create", - seen=author_id in n.seen, - ) - # store group in result - groups_by_thread[thread_id] = group - notifications = notifications_by_thread.get(thread_id, []) - if n not in notifications: - notifications.append(n) - notifications_by_thread[thread_id] = notifications - groups_amount += 1 - elif n.entity == NotificationEntity.REACTION.value and n.action == NotificationAction.CREATE.value: - reaction: NotificationReaction = payload - shout: NotificationShout = reaction.shout - thread_id += f"{reaction.shout}" - if reaction.kind == "LIKE" or reaction.kind == "DISLIKE": - # TODO: making published reaction vote announce - pass - elif reaction.kind == "COMMENT": - if reaction.reply_to: - thread_id += f"{'::' + str(reaction.reply_to)}" - group: NotificationGroup | None = groups_by_thread.get(thread_id) - notifications: List[Notification] = notifications_by_thread.get(thread_id, []) - if group and notifications: - group.seen = False # any not seen notification make it false - group.shout = shout - group.authors.append(reaction.created_by) - if not group.reactions: - group.reactions = [] - group.reactions.append(reaction.id) - # store group in result - groups_by_thread[thread_id] = group - notifications = notifications_by_thread.get(thread_id, []) - if n not in notifications: - notifications.append(n) - notifications_by_thread[thread_id] = notifications - groups_amount += 1 - else: - groups_amount += 1 - if groups_amount > limit: - break - else: - # init notification group - reactions = [] - reactions.append(reaction.id) - group = NotificationGroup( - id=thread_id, - action=n.action, - entity=n.entity, - updated_at=reaction.created_at, - reactions=reactions, - shout=shout, - authors=[ - reaction.created_by, - ], - seen=author_id in n.seen, - ) - # store group in result - groups_by_thread[thread_id] = group - notifications = notifications_by_thread.get(thread_id, []) - if n not in notifications: - notifications.append(n) - notifications_by_thread[thread_id] = notifications + notifications.append((n, seen)) - elif n.entity == "follower": - thread_id = "followers" - follower: NotificationAuthor = payload - group = groups_by_thread.get(thread_id) or NotificationGroup( - id=thread_id, - authors=[follower], - updated_at=int(time.time()), - shout=None, - reactions=[], - entity="follower", - action="follow", - seen=author_id in n.seen, - ) - group.authors = [ - follower, - ] - group.updated_at = int(time.time()) - # store group in result + return total, unread, notifications + + +def process_shout_notification( + notification: Notification, seen: bool +) -> Union[Tuple[str, NotificationGroup], None] | None: + if not isinstance(notification.payload, str) or not isinstance(notification.entity, str): + return + payload = json.loads(notification.payload) + shout: NotificationShout = payload + thread_id = str(shout.id) + group = NotificationGroup( + id=thread_id, + entity=notification.entity, + shout=shout, + authors=shout.authors, + updated_at=shout.created_at, + reactions=[], + action='create', + seen=seen, + ) + return thread_id, group + + +def process_reaction_notification( + notification: Notification, seen: bool +) -> Union[Tuple[str, NotificationGroup], None] | None: + if ( + not isinstance(notification, Notification) + or not isinstance(notification.payload, str) + or not isinstance(notification.entity, str) + ): + return + payload = json.loads(notification.payload) + reaction: NotificationReaction = payload + shout: NotificationShout = reaction.shout + thread_id = str(reaction.shout) + if reaction.kind == 'COMMENT' and reaction.reply_to: + thread_id += f'::{reaction.reply_to}' + group = NotificationGroup( + id=thread_id, + action=str(notification.action), + entity=notification.entity, + updated_at=reaction.created_at, + reactions=[reaction.id], + shout=shout, + authors=[reaction.created_by], + seen=seen, + ) + return thread_id, group + + +def process_follower_notification( + notification: Notification, seen: bool +) -> Union[Tuple[str, NotificationGroup], None] | None: + if not isinstance(notification.payload, str): + return + payload = json.loads(notification.payload) + follower: NotificationAuthor = payload + thread_id = 'followers' + group = NotificationGroup( + id=thread_id, + authors=[follower], + updated_at=int(time.time()), + shout=None, + reactions=[], + entity='follower', + action='follow', + seen=seen, + ) + return thread_id, group + + +async def get_notifications_grouped( + author_id: int, after: int = 0, limit: int = 10 +) -> Tuple[Dict[str, NotificationGroup], int, int]: + total, unread, notifications = query_notifications(author_id, after) + groups_by_thread: Dict[str, NotificationGroup] = {} + groups_amount = 0 + + for notification, seen in notifications: + if groups_amount >= limit: + break + + if str(notification.entity) == 'shout' and str(notification.action) == 'create': + result = process_shout_notification(notification, seen) + if result: + thread_id, group = result groups_by_thread[thread_id] = group - notifications = notifications_by_thread.get(thread_id, []) - if n not in notifications: - notifications.append(n) - notifications_by_thread[thread_id] = notifications groups_amount += 1 - if groups_amount > limit: - break + elif ( + str(notification.entity) == NotificationEntity.REACTION.value + and str(notification.action) == NotificationAction.CREATE.value + ): + result = process_reaction_notification(notification, seen) + if result: + thread_id, group = result + existing_group = groups_by_thread.get(thread_id) + if existing_group: + existing_group.seen = False + existing_group.shout = group.shout + existing_group.authors.append(group.authors[0]) + if not existing_group.reactions: + existing_group.reactions = [] + existing_group.reactions.extend(group.reactions or []) + groups_by_thread[thread_id] = existing_group + else: + groups_by_thread[thread_id] = group + groups_amount += 1 - return groups_by_thread, notifications_by_thread, unread, total + elif str(notification.entity) == 'follower': + result = process_follower_notification(notification, seen) + if result: + thread_id, group = result + groups_by_thread[thread_id] = group + groups_amount += 1 + + return groups_by_thread, unread, total @strawberry.type class Query: @strawberry.field async def load_notifications(self, info, after: int, limit: int = 50, offset: int = 0) -> NotificationsResult: - author_id = info.context.get("author_id") - groups: Dict[str, NotificationGroup] = {} + author_id = info.context.get('author_id') if author_id: - groups, notifications, total, unread = await get_notifications_grouped(author_id, after, limit, offset) - notifications = sorted(groups.values(), key=lambda group: group.updated_at, reverse=True) - return NotificationsResult(notifications=notifications, total=0, unread=0, error=None) + groups, unread, total = await get_notifications_grouped(author_id, after, limit) + notifications = sorted(groups.values(), key=lambda group: group.updated_at, reverse=True) + return NotificationsResult(notifications=notifications, total=total, unread=unread, error=None) + return NotificationsResult(notifications=[], total=0, unread=0, error=None) diff --git a/resolvers/model.py b/resolvers/model.py index 357f2bf..0ccc9b4 100644 --- a/resolvers/model.py +++ b/resolvers/model.py @@ -1,8 +1,11 @@ -import strawberry from typing import List, Optional + +import strawberry from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper + from orm.notification import Notification as NotificationMessage + strawberry_sqlalchemy_mapper = StrawberrySQLAlchemyMapper() diff --git a/resolvers/schema.py b/resolvers/schema.py index da838d9..da45ad1 100644 --- a/resolvers/schema.py +++ b/resolvers/schema.py @@ -1,11 +1,12 @@ import strawberry from strawberry.schema.config import StrawberryConfig -from services.auth import LoginRequiredMiddleware from resolvers.load import Query from resolvers.seen import Mutation +from services.auth import LoginRequiredMiddleware from services.db import Base, engine + schema = strawberry.Schema( query=Query, mutation=Mutation, config=StrawberryConfig(auto_camel_case=False), extensions=[LoginRequiredMiddleware] ) diff --git a/resolvers/seen.py b/resolvers/seen.py index 0629a38..d1cce6f 100644 --- a/resolvers/seen.py +++ b/resolvers/seen.py @@ -1,14 +1,14 @@ -from sqlalchemy import and_ -from orm.notification import NotificationSeen -from services.db import local_session -from resolvers.model import Notification, NotificationSeenResult, NotificationReaction +import json +import logging import strawberry -import logging -import json - +from sqlalchemy import and_ from sqlalchemy.exc import SQLAlchemyError +from orm.notification import NotificationSeen +from resolvers.model import Notification, NotificationReaction, NotificationSeenResult +from services.db import local_session + logger = logging.getLogger(__name__) @@ -17,7 +17,7 @@ logger = logging.getLogger(__name__) class Mutation: @strawberry.mutation async def mark_seen(self, info, notification_id: int) -> NotificationSeenResult: - author_id = info.context.get("author_id") + author_id = info.context.get('author_id') if author_id: with local_session() as session: try: @@ -27,9 +27,9 @@ class Mutation: except SQLAlchemyError as e: session.rollback() logger.error( - f"[mark_notification_as_read] Ошибка при обновлении статуса прочтения уведомления: {str(e)}" + f'[mark_notification_as_read] Ошибка при обновлении статуса прочтения уведомления: {str(e)}' ) - return NotificationSeenResult(error="cant mark as read") + return NotificationSeenResult(error='cant mark as read') return NotificationSeenResult(error=None) @strawberry.mutation @@ -37,7 +37,7 @@ class Mutation: # TODO: use latest loaded notification_id as input offset parameter error = None try: - author_id = info.context.get("author_id") + author_id = info.context.get('author_id') if author_id: with local_session() as session: nnn = session.query(Notification).filter(and_(Notification.created_at > after)).all() @@ -50,22 +50,22 @@ class Mutation: session.rollback() except Exception as e: print(e) - error = "cant mark as read" + error = 'cant mark as read' return NotificationSeenResult(error=error) @strawberry.mutation async def mark_seen_thread(self, info, thread: str, after: int) -> NotificationSeenResult: error = None - author_id = info.context.get("author_id") + author_id = info.context.get('author_id') if author_id: - [shout_id, reply_to_id] = thread.split("::") + [shout_id, reply_to_id] = thread.split('::') with local_session() as session: # TODO: handle new follower and new shout notifications new_reaction_notifications = ( session.query(Notification) .filter( - Notification.action == "create", - Notification.entity == "reaction", + Notification.action == 'create', + Notification.entity == 'reaction', Notification.created_at > after, ) .all() @@ -73,13 +73,13 @@ class Mutation: removed_reaction_notifications = ( session.query(Notification) .filter( - Notification.action == "delete", - Notification.entity == "reaction", + Notification.action == 'delete', + Notification.entity == 'reaction', Notification.created_at > after, ) .all() ) - exclude = set([]) + exclude = set() for nr in removed_reaction_notifications: reaction: NotificationReaction = json.loads(nr.payload) exclude.add(reaction.id) @@ -97,5 +97,5 @@ class Mutation: except Exception: session.rollback() else: - error = "You are not logged in" + error = 'You are not logged in' return NotificationSeenResult(error=error) diff --git a/server.py b/server.py index 0304a0c..9591541 100644 --- a/server.py +++ b/server.py @@ -1,53 +1,54 @@ -import sys import logging +import sys from settings import PORT + log_settings = { - "version": 1, - "disable_existing_loggers": True, - "formatters": { - "default": { - "()": "uvicorn.logging.DefaultFormatter", - "fmt": "%(levelprefix)s %(message)s", - "use_colors": None, + 'version': 1, + 'disable_existing_loggers': True, + 'formatters': { + 'default': { + '()': 'uvicorn.logging.DefaultFormatter', + 'fmt': '%(levelprefix)s %(message)s', + 'use_colors': None, }, - "access": { - "()": "uvicorn.logging.AccessFormatter", - "fmt": '%(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s', + 'access': { + '()': 'uvicorn.logging.AccessFormatter', + 'fmt': '%(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s', }, }, - "handlers": { - "default": { - "formatter": "default", - "class": "logging.StreamHandler", - "stream": "ext://sys.stderr", + 'handlers': { + 'default': { + 'formatter': 'default', + 'class': 'logging.StreamHandler', + 'stream': 'ext://sys.stderr', }, - "access": { - "formatter": "access", - "class": "logging.StreamHandler", - "stream": "ext://sys.stdout", + 'access': { + 'formatter': 'access', + 'class': 'logging.StreamHandler', + 'stream': 'ext://sys.stdout', }, }, - "loggers": { - "uvicorn": {"handlers": ["default"], "level": "INFO"}, - "uvicorn.error": {"level": "INFO", "handlers": ["default"], "propagate": True}, - "uvicorn.access": {"handlers": ["access"], "level": "INFO", "propagate": False}, + 'loggers': { + 'uvicorn': {'handlers': ['default'], 'level': 'INFO'}, + 'uvicorn.error': {'level': 'INFO', 'handlers': ['default'], 'propagate': True}, + 'uvicorn.access': {'handlers': ['access'], 'level': 'INFO', 'propagate': False}, }, } local_headers = [ - ("Access-Control-Allow-Methods", "GET, POST, OPTIONS, HEAD"), - ("Access-Control-Allow-Origin", "https://localhost:3000"), + ('Access-Control-Allow-Methods', 'GET, POST, OPTIONS, HEAD'), + ('Access-Control-Allow-Origin', 'https://localhost:3000'), ( - "Access-Control-Allow-Headers", - "DNT,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range,Authorization", + 'Access-Control-Allow-Headers', + 'DNT,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range,Authorization', ), - ("Access-Control-Expose-Headers", "Content-Length,Content-Range"), - ("Access-Control-Allow-Credentials", "true"), + ('Access-Control-Expose-Headers', 'Content-Length,Content-Range'), + ('Access-Control-Allow-Credentials', 'true'), ] -logger = logging.getLogger("[server] ") +logger = logging.getLogger('[server] ') logger.setLevel(logging.DEBUG) @@ -55,16 +56,16 @@ def exception_handler(_et, exc, _tb): logger.error(..., exc_info=(type(exc), exc, exc.__traceback__)) -if __name__ == "__main__": +if __name__ == '__main__': sys.excepthook = exception_handler from granian.constants import Interfaces from granian.server import Granian - print("[server] started") + print('[server] started') granian_instance = Granian( - "main:app", - address="0.0.0.0", # noqa S104 + 'main:app', + address='0.0.0.0', # noqa S104 port=PORT, workers=2, threads=2, diff --git a/services/auth.py b/services/auth.py index a1851a6..b84a9bd 100644 --- a/services/auth.py +++ b/services/auth.py @@ -1,59 +1,60 @@ +import logging + from aiohttp import ClientSession from strawberry.extensions import Extension -from settings import AUTH_URL -from services.db import local_session from orm.author import Author +from services.db import local_session +from settings import AUTH_URL -import logging -logger = logging.getLogger("\t[services.auth]\t") +logger = logging.getLogger('\t[services.auth]\t') logger.setLevel(logging.DEBUG) async def check_auth(req) -> str | None: - token = req.headers.get("Authorization") - user_id = "" + token = req.headers.get('Authorization') + user_id = '' if token: - query_name = "validate_jwt_token" - operation = "ValidateToken" + query_name = 'validate_jwt_token' + operation = 'ValidateToken' headers = { - "Content-Type": "application/json", + 'Content-Type': 'application/json', } variables = { - "params": { - "token_type": "access_token", - "token": token, + 'params': { + 'token_type': 'access_token', + 'token': token, } } gql = { - "query": f"query {operation}($params: ValidateJWTTokenInput!) {{ {query_name}(params: $params) {{ is_valid claims }} }}", - "variables": variables, - "operationName": operation, + 'query': f'query {operation}($params: ValidateJWTTokenInput!) {{ {query_name}(params: $params) {{ is_valid claims }} }}', + 'variables': variables, + 'operationName': operation, } try: # Asynchronous HTTP request to the authentication server async with ClientSession() as session: async with session.post(AUTH_URL, json=gql, headers=headers) as response: - print(f"[services.auth] HTTP Response {response.status} {await response.text()}") + print(f'[services.auth] HTTP Response {response.status} {await response.text()}') if response.status == 200: data = await response.json() - errors = data.get("errors") + errors = data.get('errors') if errors: - print(f"[services.auth] errors: {errors}") + print(f'[services.auth] errors: {errors}') else: - user_id = data.get("data", {}).get(query_name, {}).get("claims", {}).get("sub") + user_id = data.get('data', {}).get(query_name, {}).get('claims', {}).get('sub') if user_id: - print(f"[services.auth] got user_id: {user_id}") + print(f'[services.auth] got user_id: {user_id}') return user_id except Exception as e: import traceback traceback.print_exc() # Handling and logging exceptions during authentication check - print(f"[services.auth] Error {e}") + print(f'[services.auth] Error {e}') return None @@ -61,14 +62,14 @@ async def check_auth(req) -> str | None: class LoginRequiredMiddleware(Extension): async def on_request_start(self): context = self.execution_context.context - req = context.get("request") + req = context.get('request') user_id = await check_auth(req) if user_id: - context["user_id"] = user_id.strip() + context['user_id'] = user_id.strip() with local_session() as session: author = session.query(Author).filter(Author.user == user_id).first() if author: - context["author_id"] = author.id - context["user_id"] = user_id or None + context['author_id'] = author.id + context['user_id'] = user_id or None self.execution_context.context = context diff --git a/services/core.py b/services/core.py index c0eb25b..148b393 100644 --- a/services/core.py +++ b/services/core.py @@ -4,7 +4,8 @@ import aiohttp from settings import API_BASE -headers = {"Content-Type": "application/json"} + +headers = {'Content-Type': 'application/json'} # TODO: rewrite to orm usage? @@ -13,39 +14,39 @@ headers = {"Content-Type": "application/json"} async def _request_endpoint(query_name, body) -> Any: async with aiohttp.ClientSession() as session: async with session.post(API_BASE, headers=headers, json=body) as response: - print(f"[services.core] {query_name} HTTP Response {response.status} {await response.text()}") + print(f'[services.core] {query_name} HTTP Response {response.status} {await response.text()}') if response.status == 200: r = await response.json() if r: - return r.get("data", {}).get(query_name, {}) + return r.get('data', {}).get(query_name, {}) return [] async def get_followed_shouts(author_id: int): - query_name = "load_shouts_followed" - operation = "GetFollowedShouts" + query_name = 'load_shouts_followed' + operation = 'GetFollowedShouts' query = f"""query {operation}($author_id: Int!, limit: Int, offset: Int) {{ {query_name}(author_id: $author_id, limit: $limit, offset: $offset) {{ id slug title }} }}""" gql = { - "query": query, - "operationName": operation, - "variables": {"author_id": author_id, "limit": 1000, "offset": 0}, # FIXME: too big limit + 'query': query, + 'operationName': operation, + 'variables': {'author_id': author_id, 'limit': 1000, 'offset': 0}, # FIXME: too big limit } return await _request_endpoint(query_name, gql) async def get_shout(shout_id): - query_name = "get_shout" - operation = "GetShout" + query_name = 'get_shout' + operation = 'GetShout' query = f"""query {operation}($slug: String, $shout_id: Int) {{ {query_name}(slug: $slug, shout_id: $shout_id) {{ id slug title authors {{ id slug name pic }} }} }}""" - gql = {"query": query, "operationName": operation, "variables": {"slug": None, "shout_id": shout_id}} + gql = {'query': query, 'operationName': operation, 'variables': {'slug': None, 'shout_id': shout_id}} return await _request_endpoint(query_name, gql) diff --git a/services/db.py b/services/db.py index 427254c..9475907 100644 --- a/services/db.py +++ b/services/db.py @@ -8,15 +8,16 @@ from sqlalchemy.sql.schema import Table from settings import DB_URL + engine = create_engine(DB_URL, echo=False, pool_size=10, max_overflow=20) -T = TypeVar("T") +T = TypeVar('T') REGISTRY: Dict[str, type] = {} # @contextmanager -def local_session(src=""): +def local_session(src=''): return Session(bind=engine, expire_on_commit=False) # try: @@ -44,7 +45,7 @@ class Base(declarative_base()): __init__: Callable __allow_unmapped__ = True __abstract__ = True - __table_args__ = {"extend_existing": True} + __table_args__ = {'extend_existing': True} id = Column(Integer, primary_key=True) @@ -53,12 +54,12 @@ class Base(declarative_base()): def dict(self) -> Dict[str, Any]: column_names = self.__table__.columns.keys() - if "_sa_instance_state" in column_names: - column_names.remove("_sa_instance_state") + if '_sa_instance_state' in column_names: + column_names.remove('_sa_instance_state') try: return {c: getattr(self, c) for c in column_names} except Exception as e: - print(f"[services.db] Error dict: {e}") + print(f'[services.db] Error dict: {e}') return {} def update(self, values: Dict[str, Any]) -> None: diff --git a/services/rediscache.py b/services/rediscache.py index 2cd3037..15a89b3 100644 --- a/services/rediscache.py +++ b/services/rediscache.py @@ -1,12 +1,13 @@ -import json - -import redis.asyncio as aredis import asyncio -from settings import REDIS_URL - +import json import logging -logger = logging.getLogger("\t[services.redis]\t") +import redis.asyncio as aredis + +from settings import REDIS_URL + + +logger = logging.getLogger('\t[services.redis]\t') logger.setLevel(logging.DEBUG) @@ -26,11 +27,11 @@ class RedisCache: async def execute(self, command, *args, **kwargs): if self._client: try: - logger.debug(command + " " + " ".join(args)) + logger.debug(command + ' ' + ' '.join(args)) r = await self._client.execute_command(command, *args, **kwargs) return r except Exception as e: - logger.error(f"{e}") + logger.error(f'{e}') return None async def subscribe(self, *channels): @@ -60,15 +61,15 @@ class RedisCache: while True: message = await pubsub.get_message() - if message and isinstance(message["data"], (str, bytes, bytearray)): - logger.debug("pubsub got msg") + if message and isinstance(message['data'], (str, bytes, bytearray)): + logger.debug('pubsub got msg') try: - yield json.loads(message["data"]), message.get("channel") + yield json.loads(message['data']), message.get('channel') except Exception as e: - logger.error(f"{e}") + logger.error(f'{e}') await asyncio.sleep(1) redis = RedisCache() -__all__ = ["redis"] +__all__ = ['redis'] diff --git a/settings.py b/settings.py index 58633a4..a2fde0d 100644 --- a/settings.py +++ b/settings.py @@ -1,13 +1,14 @@ from os import environ + PORT = 8000 DB_URL = ( - environ.get("DATABASE_URL", environ.get("DB_URL", "")).replace("postgres://", "postgresql://") - or "postgresql://postgres@localhost:5432/discoursio" + environ.get('DATABASE_URL', environ.get('DB_URL', '')).replace('postgres://', 'postgresql://') + or 'postgresql://postgres@localhost:5432/discoursio' ) -REDIS_URL = environ.get("REDIS_URL") or "redis://127.0.0.1" -API_BASE = environ.get("API_BASE") or "https://core.discours.io" -AUTH_URL = environ.get("AUTH_URL") or "https://auth.discours.io" -MODE = environ.get("MODE") or "production" -SENTRY_DSN = environ.get("SENTRY_DSN") -DEV_SERVER_PID_FILE_NAME = "dev-server.pid" +REDIS_URL = environ.get('REDIS_URL') or 'redis://127.0.0.1' +API_BASE = environ.get('API_BASE') or 'https://core.discours.io' +AUTH_URL = environ.get('AUTH_URL') or 'https://auth.discours.io' +MODE = environ.get('MODE') or 'production' +SENTRY_DSN = environ.get('SENTRY_DSN') +DEV_SERVER_PID_FILE_NAME = 'dev-server.pid'