diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..43fa826 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,21 @@ +fail_fast: true + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + - id: check-added-large-files + - id: detect-private-key + - id: double-quote-string-fixer + - id: check-ast + - id: check-merge-conflict + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.1.13 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format diff --git a/main.py b/main.py index 9b7f841..440c7d3 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,5 @@ import asyncio +import logging import os from os.path import exists @@ -13,11 +14,10 @@ from resolvers.listener import notifications_worker from resolvers.schema import schema from services.rediscache import redis from settings import DEV_SERVER_PID_FILE_NAME, MODE, SENTRY_DSN -import logging logging.basicConfig(level=logging.DEBUG) -logger = logging.getLogger("\t[main]\t") +logger = logging.getLogger('\t[main]\t') logger.setLevel(logging.DEBUG) @@ -27,9 +27,9 @@ async def start_up(): task = asyncio.create_task(notifications_worker()) logger.info(task) - if MODE == "dev": + if MODE == 'dev': if exists(DEV_SERVER_PID_FILE_NAME): - with open(DEV_SERVER_PID_FILE_NAME, "w", encoding="utf-8") as f: + with open(DEV_SERVER_PID_FILE_NAME, 'w', encoding='utf-8') as f: f.write(str(os.getpid())) else: try: @@ -46,7 +46,7 @@ async def start_up(): ], ) except Exception as e: - logger.error("sentry init error", e) + logger.error('sentry init error', e) async def shutdown(): @@ -54,4 +54,4 @@ async def shutdown(): app = Starlette(debug=True, on_startup=[start_up], on_shutdown=[shutdown]) -app.mount("/", GraphQL(schema, debug=True)) +app.mount('/', GraphQL(schema, debug=True)) diff --git a/orm/author.py b/orm/author.py index ba9fcc5..49bf6d1 100644 --- a/orm/author.py +++ b/orm/author.py @@ -1,46 +1,45 @@ import time -from sqlalchemy import JSON as JSONType -from sqlalchemy import Boolean, Column, ForeignKey, Integer, String +from sqlalchemy import JSON, Boolean, Column, ForeignKey, Integer, String from sqlalchemy.orm import relationship from services.db import Base class AuthorRating(Base): - __tablename__ = "author_rating" + __tablename__ = 'author_rating' id = None # type: ignore - rater = Column(ForeignKey("author.id"), primary_key=True, index=True) - author = Column(ForeignKey("author.id"), primary_key=True, index=True) + rater = Column(ForeignKey('author.id'), primary_key=True, index=True) + author = Column(ForeignKey('author.id'), primary_key=True, index=True) plus = Column(Boolean) class AuthorFollower(Base): - __tablename__ = "author_follower" + __tablename__ = 'author_follower' id = None # type: ignore - follower = Column(ForeignKey("author.id"), primary_key=True, index=True) - author = Column(ForeignKey("author.id"), primary_key=True, index=True) + follower = Column(ForeignKey('author.id'), primary_key=True, index=True) + author = Column(ForeignKey('author.id'), primary_key=True, index=True) created_at = Column(Integer, nullable=False, default=lambda: int(time.time())) auto = Column(Boolean, nullable=False, default=False) class Author(Base): - __tablename__ = "author" + __tablename__ = 'author' user = Column(String, unique=True) # unbounded link with authorizer's User type - name = Column(String, nullable=True, comment="Display name") + name = Column(String, nullable=True, comment='Display name') slug = Column(String, unique=True, comment="Author's slug") - bio = Column(String, nullable=True, comment="Bio") # status description - about = Column(String, nullable=True, comment="About") # long and formatted - pic = Column(String, nullable=True, comment="Picture") - links = Column(JSONType, nullable=True, comment="Links") + bio = Column(String, nullable=True, comment='Bio') # status description + about = Column(String, nullable=True, comment='About') # long and formatted + pic = Column(String, nullable=True, comment='Picture') + links = Column(JSON, nullable=True, comment='Links') ratings = relationship(AuthorRating, foreign_keys=AuthorRating.author) created_at = Column(Integer, nullable=False, default=lambda: int(time.time())) last_seen = Column(Integer, nullable=False, default=lambda: int(time.time())) updated_at = Column(Integer, nullable=False, default=lambda: int(time.time())) - deleted_at = Column(Integer, nullable=True, comment="Deleted at") + deleted_at = Column(Integer, nullable=True, comment='Deleted at') diff --git a/orm/notification.py b/orm/notification.py index 0622ad4..2b09ea1 100644 --- a/orm/notification.py +++ b/orm/notification.py @@ -1,42 +1,41 @@ +import time from enum import Enum as Enumeration -from sqlalchemy import JSON as JSONType, func, cast -from sqlalchemy import Column, Enum, ForeignKey, Integer, String +from sqlalchemy import JSON, Column, ForeignKey, Integer, String from sqlalchemy.orm import relationship -from sqlalchemy.orm.session import engine from orm.author import Author from services.db import Base -import time + class NotificationEntity(Enumeration): - REACTION = "reaction" - SHOUT = "shout" - FOLLOWER = "follower" + REACTION = 'reaction' + SHOUT = 'shout' + FOLLOWER = 'follower' class NotificationAction(Enumeration): - CREATE = "create" - UPDATE = "update" - DELETE = "delete" - SEEN = "seen" - FOLLOW = "follow" - UNFOLLOW = "unfollow" + CREATE = 'create' + UPDATE = 'update' + DELETE = 'delete' + SEEN = 'seen' + FOLLOW = 'follow' + UNFOLLOW = 'unfollow' class NotificationSeen(Base): - __tablename__ = "notification_seen" + __tablename__ = 'notification_seen' - viewer = Column(ForeignKey("author.id")) - notification = Column(ForeignKey("notification.id")) + viewer = Column(ForeignKey('author.id')) + notification = Column(ForeignKey('notification.id')) class Notification(Base): - __tablename__ = "notification" + __tablename__ = 'notification' created_at = Column(Integer, server_default=str(int(time.time()))) entity = Column(String, nullable=False) action = Column(String, nullable=False) - payload = Column(JSONType, nullable=True) + payload = Column(JSON, nullable=True) - seen = relationship(lambda: Author, secondary="notification_seen") + seen = relationship(lambda: Author, secondary='notification_seen') diff --git a/resolvers/listener.py b/resolvers/listener.py index 39d3ee3..95b762c 100644 --- a/resolvers/listener.py +++ b/resolvers/listener.py @@ -1,11 +1,13 @@ -from orm.notification import Notification, NotificationAction, NotificationEntity -from resolvers.model import NotificationReaction, NotificationAuthor, NotificationShout -from services.db import local_session -from services.rediscache import redis import asyncio import logging -logger = logging.getLogger(f"[listener.listen_task] ") +from orm.notification import Notification +from resolvers.model import NotificationAuthor, NotificationReaction, NotificationShout +from services.db import local_session +from services.rediscache import redis + + +logger = logging.getLogger('[listener.listen_task] ') logger.setLevel(logging.DEBUG) @@ -19,8 +21,8 @@ async def handle_notification(n: ServiceMessage, channel: str): """создаеёт новое хранимое уведомление""" with local_session() as session: try: - if channel.startswith("follower:"): - author_id = int(channel.split(":")[1]) + if channel.startswith('follower:'): + author_id = int(channel.split(':')[1]) if isinstance(n.payload, NotificationAuthor): n.payload.following_id = author_id n = Notification(action=n.action, entity=n.entity, payload=n.payload) @@ -28,7 +30,7 @@ async def handle_notification(n: ServiceMessage, channel: str): session.commit() except Exception as e: session.rollback() - logger.error(f"[listener.handle_reaction] error: {str(e)}") + logger.error(f'[listener.handle_reaction] error: {str(e)}') async def listen_task(pattern): @@ -38,9 +40,9 @@ async def listen_task(pattern): notification_message = ServiceMessage(**message_data) await handle_notification(notification_message, str(channel)) except Exception as e: - logger.error(f"Error processing notification: {str(e)}") + logger.error(f'Error processing notification: {str(e)}') async def notifications_worker(): # Use asyncio.gather to run tasks concurrently - await asyncio.gather(listen_task("follower:*"), listen_task("reaction"), listen_task("shout")) + await asyncio.gather(listen_task('follower:*'), listen_task('reaction'), listen_task('shout')) diff --git a/resolvers/load.py b/resolvers/load.py index 65f3b1d..fce5671 100644 --- a/resolvers/load.py +++ b/resolvers/load.py @@ -1,27 +1,36 @@ +import json +import logging +import time +from typing import Dict, List +import strawberry +from sqlalchemy import and_, select +from sqlalchemy.orm import aliased from sqlalchemy.sql import not_ -from services.db import local_session + +from orm.notification import ( + Notification, + NotificationAction, + NotificationEntity, + NotificationSeen, +) from resolvers.model import ( - NotificationReaction, - NotificationGroup, - NotificationShout, NotificationAuthor, + NotificationGroup, + NotificationReaction, + NotificationShout, NotificationsResult, ) -from orm.notification import NotificationAction, NotificationEntity, NotificationSeen, Notification -from typing import Dict, List -import time, json -import strawberry -from sqlalchemy.orm import aliased -from sqlalchemy.sql.expression import or_ -from sqlalchemy import select, and_ -import logging +from services.db import local_session -logger = logging.getLogger("[resolvers.schema] ") + +logger = logging.getLogger('[resolvers.schema] ') logger.setLevel(logging.DEBUG) -async def get_notifications_grouped(author_id: int, after: int = 0, limit: int = 10, offset: int = 0): +async def get_notifications_grouped( # noqa: C901 + author_id: int, after: int = 0, limit: int = 10, offset: int = 0 +): """ Retrieves notifications for a given author. @@ -47,10 +56,13 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int = authors: List[NotificationAuthor], # List of authors involved in the thread. } """ - NotificationSeenAlias = aliased(NotificationSeen) - query = select(Notification, NotificationSeenAlias.viewer.label("seen")).outerjoin( + seen_alias = aliased(NotificationSeen) + query = select(Notification, seen_alias.viewer.label('seen')).outerjoin( NotificationSeen, - and_(NotificationSeen.viewer == author_id, NotificationSeen.notification == Notification.id), + and_( + NotificationSeen.viewer == author_id, + NotificationSeen.notification == Notification.id, + ), ) if after: query = query.filter(Notification.created_at > after) @@ -62,23 +74,36 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int = notifications_by_thread: Dict[str, List[Notification]] = {} groups_by_thread: Dict[str, NotificationGroup] = {} with local_session() as session: - total = session.query(Notification).filter(and_(Notification.action == NotificationAction.CREATE.value, Notification.created_at > after)).count() - unread = session.query(Notification).filter( - and_( - Notification.action == NotificationAction.CREATE.value, - Notification.created_at > after, - not_(Notification.seen) + total = ( + session.query(Notification) + .filter( + and_( + Notification.action == NotificationAction.CREATE.value, + Notification.created_at > after, + ) ) - ).count() + .count() + ) + unread = ( + session.query(Notification) + .filter( + and_( + Notification.action == NotificationAction.CREATE.value, + Notification.created_at > after, + not_(Notification.seen), + ) + ) + .count() + ) notifications_result = session.execute(query) - for n, seen in notifications_result: - thread_id = "" + for n, _seen in notifications_result: + thread_id = '' payload = json.loads(n.payload) - logger.debug(f"[resolvers.schema] {n.action} {n.entity}: {payload}") - if n.entity == "shout" and n.action == "create": + logger.debug(f'[resolvers.schema] {n.action} {n.entity}: {payload}') + if n.entity == 'shout' and n.action == 'create': shout: NotificationShout = payload - thread_id += f"{shout.id}" - logger.debug(f"create shout: {shout}") + thread_id += f'{shout.id}' + logger.debug(f'create shout: {shout}') group = groups_by_thread.get(thread_id) or NotificationGroup( id=thread_id, entity=n.entity, @@ -86,8 +111,8 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int = authors=shout.authors, updated_at=shout.created_at, reactions=[], - action="create", - seen=author_id in n.seen + action='create', + seen=author_id in n.seen, ) # store group in result groups_by_thread[thread_id] = group @@ -99,11 +124,11 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int = elif n.entity == NotificationEntity.REACTION.value and n.action == NotificationAction.CREATE.value: reaction: NotificationReaction = payload shout: NotificationShout = reaction.shout - thread_id += f"{reaction.shout}" - if reaction.kind == "LIKE" or reaction.kind == "DISLIKE": + thread_id += f'{reaction.shout}' + if not bool(reaction.reply_to) and (reaction.kind == 'LIKE' or reaction.kind == 'DISLIKE'): # TODO: making published reaction vote announce pass - elif reaction.kind == "COMMENT": + elif reaction.kind == 'COMMENT': if reaction.reply_to: thread_id += f"{'::' + str(reaction.reply_to)}" group: NotificationGroup | None = groups_by_thread.get(thread_id) @@ -128,8 +153,9 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int = break else: # init notification group - reactions = [] - reactions.append(reaction.id) + reactions = [ + reaction.id, + ] group = NotificationGroup( id=thread_id, action=n.action, @@ -140,7 +166,7 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int = authors=[ reaction.created_by, ], - seen=author_id in n.seen + seen=author_id in n.seen, ) # store group in result groups_by_thread[thread_id] = group @@ -149,20 +175,22 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int = notifications.append(n) notifications_by_thread[thread_id] = notifications - elif n.entity == "follower": - thread_id = "followers" + elif n.entity == 'follower': + thread_id = 'followers' follower: NotificationAuthor = payload group = groups_by_thread.get(thread_id) or NotificationGroup( - id=thread_id, - authors=[follower], - updated_at=int(time.time()), - shout=None, - reactions=[], - entity="follower", - action="follow", - seen=author_id in n.seen - ) - group.authors = [follower, ] + id=thread_id, + authors=[follower], + updated_at=int(time.time()), + shout=None, + reactions=[], + entity='follower', + action='follow', + seen=author_id in n.seen, + ) + group.authors = [ + follower, + ] group.updated_at = int(time.time()) # store group in result groups_by_thread[thread_id] = group @@ -182,7 +210,7 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int = class Query: @strawberry.field async def load_notifications(self, info, after: int, limit: int = 50, offset: int = 0) -> NotificationsResult: - author_id = info.context.get("author_id") + author_id = info.context.get('author_id') groups: Dict[str, NotificationGroup] = {} if author_id: groups, notifications, total, unread = await get_notifications_grouped(author_id, after, limit, offset) diff --git a/resolvers/model.py b/resolvers/model.py index 357f2bf..0ccc9b4 100644 --- a/resolvers/model.py +++ b/resolvers/model.py @@ -1,8 +1,11 @@ -import strawberry from typing import List, Optional + +import strawberry from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper + from orm.notification import Notification as NotificationMessage + strawberry_sqlalchemy_mapper = StrawberrySQLAlchemyMapper() diff --git a/resolvers/schema.py b/resolvers/schema.py index 7e87145..da45ad1 100644 --- a/resolvers/schema.py +++ b/resolvers/schema.py @@ -1,12 +1,12 @@ - import strawberry from strawberry.schema.config import StrawberryConfig -from services.auth import LoginRequiredMiddleware from resolvers.load import Query from resolvers.seen import Mutation +from services.auth import LoginRequiredMiddleware from services.db import Base, engine + schema = strawberry.Schema( query=Query, mutation=Mutation, config=StrawberryConfig(auto_camel_case=False), extensions=[LoginRequiredMiddleware] ) diff --git a/resolvers/seen.py b/resolvers/seen.py index bfbc260..d1cce6f 100644 --- a/resolvers/seen.py +++ b/resolvers/seen.py @@ -1,14 +1,14 @@ -from sqlalchemy import and_ -from orm.notification import NotificationSeen -from services.db import local_session -from resolvers.model import Notification, NotificationSeenResult, NotificationReaction +import json +import logging import strawberry -import logging -import json - +from sqlalchemy import and_ from sqlalchemy.exc import SQLAlchemyError +from orm.notification import NotificationSeen +from resolvers.model import Notification, NotificationReaction, NotificationSeenResult +from services.db import local_session + logger = logging.getLogger(__name__) @@ -17,7 +17,7 @@ logger = logging.getLogger(__name__) class Mutation: @strawberry.mutation async def mark_seen(self, info, notification_id: int) -> NotificationSeenResult: - author_id = info.context.get("author_id") + author_id = info.context.get('author_id') if author_id: with local_session() as session: try: @@ -27,9 +27,9 @@ class Mutation: except SQLAlchemyError as e: session.rollback() logger.error( - f"[mark_notification_as_read] Ошибка при обновлении статуса прочтения уведомления: {str(e)}" + f'[mark_notification_as_read] Ошибка при обновлении статуса прочтения уведомления: {str(e)}' ) - return NotificationSeenResult(error="cant mark as read") + return NotificationSeenResult(error='cant mark as read') return NotificationSeenResult(error=None) @strawberry.mutation @@ -37,7 +37,7 @@ class Mutation: # TODO: use latest loaded notification_id as input offset parameter error = None try: - author_id = info.context.get("author_id") + author_id = info.context.get('author_id') if author_id: with local_session() as session: nnn = session.query(Notification).filter(and_(Notification.created_at > after)).all() @@ -46,26 +46,26 @@ class Mutation: ns = NotificationSeen(notification=n.id, viewer=author_id) session.add(ns) session.commit() - except SQLAlchemyError as e: + except SQLAlchemyError: session.rollback() except Exception as e: print(e) - error = "cant mark as read" + error = 'cant mark as read' return NotificationSeenResult(error=error) @strawberry.mutation async def mark_seen_thread(self, info, thread: str, after: int) -> NotificationSeenResult: error = None - author_id = info.context.get("author_id") + author_id = info.context.get('author_id') if author_id: - [shout_id, reply_to_id] = thread.split("::") + [shout_id, reply_to_id] = thread.split('::') with local_session() as session: # TODO: handle new follower and new shout notifications new_reaction_notifications = ( session.query(Notification) .filter( - Notification.action == "create", - Notification.entity == "reaction", + Notification.action == 'create', + Notification.entity == 'reaction', Notification.created_at > after, ) .all() @@ -73,13 +73,13 @@ class Mutation: removed_reaction_notifications = ( session.query(Notification) .filter( - Notification.action == "delete", - Notification.entity == "reaction", + Notification.action == 'delete', + Notification.entity == 'reaction', Notification.created_at > after, ) .all() ) - exclude = set([]) + exclude = set() for nr in removed_reaction_notifications: reaction: NotificationReaction = json.loads(nr.payload) exclude.add(reaction.id) @@ -97,5 +97,5 @@ class Mutation: except Exception: session.rollback() else: - error = "You are not logged in" + error = 'You are not logged in' return NotificationSeenResult(error=error) diff --git a/services/auth.py b/services/auth.py index 6912382..b84a9bd 100644 --- a/services/auth.py +++ b/services/auth.py @@ -1,57 +1,60 @@ +import logging + from aiohttp import ClientSession from strawberry.extensions import Extension -from settings import AUTH_URL -from services.db import local_session from orm.author import Author +from services.db import local_session +from settings import AUTH_URL -import logging -logger = logging.getLogger("\t[services.auth]\t") +logger = logging.getLogger('\t[services.auth]\t') logger.setLevel(logging.DEBUG) + async def check_auth(req) -> str | None: - token = req.headers.get("Authorization") - user_id = "" + token = req.headers.get('Authorization') + user_id = '' if token: - query_name = "validate_jwt_token" - operation = "ValidateToken" + query_name = 'validate_jwt_token' + operation = 'ValidateToken' headers = { - "Content-Type": "application/json", + 'Content-Type': 'application/json', } variables = { - "params": { - "token_type": "access_token", - "token": token, + 'params': { + 'token_type': 'access_token', + 'token': token, } } gql = { - "query": f"query {operation}($params: ValidateJWTTokenInput!) {{ {query_name}(params: $params) {{ is_valid claims }} }}", - "variables": variables, - "operationName": operation, + 'query': f'query {operation}($params: ValidateJWTTokenInput!) {{ {query_name}(params: $params) {{ is_valid claims }} }}', + 'variables': variables, + 'operationName': operation, } try: # Asynchronous HTTP request to the authentication server async with ClientSession() as session: async with session.post(AUTH_URL, json=gql, headers=headers) as response: - print(f"[services.auth] HTTP Response {response.status} {await response.text()}") + print(f'[services.auth] HTTP Response {response.status} {await response.text()}') if response.status == 200: data = await response.json() - errors = data.get("errors") + errors = data.get('errors') if errors: - print(f"[services.auth] errors: {errors}") + print(f'[services.auth] errors: {errors}') else: - user_id = data.get("data", {}).get(query_name, {}).get("claims", {}).get("sub") + user_id = data.get('data', {}).get(query_name, {}).get('claims', {}).get('sub') if user_id: - print(f"[services.auth] got user_id: {user_id}") + print(f'[services.auth] got user_id: {user_id}') return user_id except Exception as e: import traceback + traceback.print_exc() # Handling and logging exceptions during authentication check - print(f"[services.auth] Error {e}") + print(f'[services.auth] Error {e}') return None @@ -59,14 +62,14 @@ async def check_auth(req) -> str | None: class LoginRequiredMiddleware(Extension): async def on_request_start(self): context = self.execution_context.context - req = context.get("request") + req = context.get('request') user_id = await check_auth(req) if user_id: - context["user_id"] = user_id.strip() + context['user_id'] = user_id.strip() with local_session() as session: author = session.query(Author).filter(Author.user == user_id).first() if author: - context["author_id"] = author.id - context["user_id"] = user_id or None + context['author_id'] = author.id + context['user_id'] = user_id or None self.execution_context.context = context diff --git a/services/core.py b/services/core.py index 8957f02..148b393 100644 --- a/services/core.py +++ b/services/core.py @@ -4,47 +4,49 @@ import aiohttp from settings import API_BASE -headers = {"Content-Type": "application/json"} + +headers = {'Content-Type': 'application/json'} # TODO: rewrite to orm usage? + async def _request_endpoint(query_name, body) -> Any: async with aiohttp.ClientSession() as session: async with session.post(API_BASE, headers=headers, json=body) as response: - print(f"[services.core] {query_name} HTTP Response {response.status} {await response.text()}") + print(f'[services.core] {query_name} HTTP Response {response.status} {await response.text()}') if response.status == 200: r = await response.json() if r: - return r.get("data", {}).get(query_name, {}) + return r.get('data', {}).get(query_name, {}) return [] async def get_followed_shouts(author_id: int): - query_name = "load_shouts_followed" - operation = "GetFollowedShouts" + query_name = 'load_shouts_followed' + operation = 'GetFollowedShouts' query = f"""query {operation}($author_id: Int!, limit: Int, offset: Int) {{ {query_name}(author_id: $author_id, limit: $limit, offset: $offset) {{ id slug title }} }}""" gql = { - "query": query, - "operationName": operation, - "variables": {"author_id": author_id, "limit": 1000, "offset": 0}, # FIXME: too big limit + 'query': query, + 'operationName': operation, + 'variables': {'author_id': author_id, 'limit': 1000, 'offset': 0}, # FIXME: too big limit } return await _request_endpoint(query_name, gql) async def get_shout(shout_id): - query_name = "get_shout" - operation = "GetShout" + query_name = 'get_shout' + operation = 'GetShout' query = f"""query {operation}($slug: String, $shout_id: Int) {{ {query_name}(slug: $slug, shout_id: $shout_id) {{ id slug title authors {{ id slug name pic }} }} }}""" - gql = {"query": query, "operationName": operation, "variables": {"slug": None, "shout_id": shout_id}} + gql = {'query': query, 'operationName': operation, 'variables': {'slug': None, 'shout_id': shout_id}} return await _request_endpoint(query_name, gql) diff --git a/services/db.py b/services/db.py index e9b15d6..c1e3bd6 100644 --- a/services/db.py +++ b/services/db.py @@ -9,15 +9,16 @@ from sqlalchemy.sql.schema import Table from settings import DB_URL + engine = create_engine(DB_URL, echo=False, pool_size=10, max_overflow=20) -T = TypeVar("T") +T = TypeVar('T') REGISTRY: Dict[str, type] = {} # @contextmanager -def local_session(src=""): +def local_session(src=''): return Session(bind=engine, expire_on_commit=False) # try: @@ -45,7 +46,7 @@ class Base(declarative_base()): __init__: Callable __allow_unmapped__ = True __abstract__ = True - __table_args__ = {"extend_existing": True} + __table_args__ = {'extend_existing': True} id = Column(Integer, primary_key=True) @@ -54,12 +55,12 @@ class Base(declarative_base()): def dict(self) -> Dict[str, Any]: column_names = self.__table__.columns.keys() - if "_sa_instance_state" in column_names: - column_names.remove("_sa_instance_state") + if '_sa_instance_state' in column_names: + column_names.remove('_sa_instance_state') try: return {c: getattr(self, c) for c in column_names} except Exception as e: - print(f"[services.db] Error dict: {e}") + print(f'[services.db] Error dict: {e}') return {} def update(self, values: Dict[str, Any]) -> None: diff --git a/services/rediscache.py b/services/rediscache.py index 3609a99..15a89b3 100644 --- a/services/rediscache.py +++ b/services/rediscache.py @@ -1,14 +1,16 @@ -import json - -import redis.asyncio as aredis import asyncio -from settings import REDIS_URL - +import json import logging -logger = logging.getLogger("\t[services.redis]\t") +import redis.asyncio as aredis + +from settings import REDIS_URL + + +logger = logging.getLogger('\t[services.redis]\t') logger.setLevel(logging.DEBUG) + class RedisCache: def __init__(self, uri=REDIS_URL): self._uri: str = uri @@ -25,11 +27,11 @@ class RedisCache: async def execute(self, command, *args, **kwargs): if self._client: try: - logger.debug(command + " " + " ".join(args)) + logger.debug(command + ' ' + ' '.join(args)) r = await self._client.execute_command(command, *args, **kwargs) return r except Exception as e: - logger.error(f"{e}") + logger.error(f'{e}') return None async def subscribe(self, *channels): @@ -59,15 +61,15 @@ class RedisCache: while True: message = await pubsub.get_message() - if message and isinstance(message["data"], (str, bytes, bytearray)): - logger.debug("pubsub got msg") + if message and isinstance(message['data'], (str, bytes, bytearray)): + logger.debug('pubsub got msg') try: - yield json.loads(message["data"]), message.get("channel") + yield json.loads(message['data']), message.get('channel') except Exception as e: - logger.error(f"{e}") + logger.error(f'{e}') await asyncio.sleep(1) redis = RedisCache() -__all__ = ["redis"] +__all__ = ['redis'] diff --git a/settings.py b/settings.py index eaae333..3eae789 100644 --- a/settings.py +++ b/settings.py @@ -1,13 +1,14 @@ from os import environ + PORT = 80 DB_URL = ( - environ.get("DATABASE_URL", environ.get("DB_URL", "")).replace("postgres://", "postgresql://") - or "postgresql://postgres@localhost:5432/discoursio" + environ.get('DATABASE_URL', environ.get('DB_URL', '')).replace('postgres://', 'postgresql://') + or 'postgresql://postgres@localhost:5432/discoursio' ) -REDIS_URL = environ.get("REDIS_URL") or "redis://127.0.0.1" -API_BASE = environ.get("API_BASE") or "https://core.discours.io" -AUTH_URL = environ.get("AUTH_URL") or "https://auth.discours.io" -MODE = environ.get("MODE") or "production" -SENTRY_DSN = environ.get("SENTRY_DSN") -DEV_SERVER_PID_FILE_NAME = "dev-server.pid" +REDIS_URL = environ.get('REDIS_URL') or 'redis://127.0.0.1' +API_BASE = environ.get('API_BASE') or 'https://core.discours.io' +AUTH_URL = environ.get('AUTH_URL') or 'https://auth.discours.io' +MODE = environ.get('MODE') or 'production' +SENTRY_DSN = environ.get('SENTRY_DSN') +DEV_SERVER_PID_FILE_NAME = 'dev-server.pid'