precommit

This commit is contained in:
Untone 2024-02-04 07:58:44 +03:00
parent 537b89dbaf
commit b98da839ed
10 changed files with 64 additions and 31 deletions

20
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,20 @@
fail_fast: true
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- id: check-added-large-files
- id: detect-private-key
- id: check-ast
- id: check-merge-conflict
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.13
hooks:
- id: ruff
args: [--fix]
- id: ruff-format

View File

@ -1,14 +1,14 @@
from enum import Enum as Enumeration from enum import Enum as Enumeration
from sqlalchemy import JSON as JSONType, func, cast from sqlalchemy import JSON as JSONType
from sqlalchemy import Column, Enum, ForeignKey, Integer, String from sqlalchemy import Column, ForeignKey, Integer, String
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from sqlalchemy.orm.session import engine
from orm.author import Author from orm.author import Author
from services.db import Base from services.db import Base
import time import time
class NotificationEntity(Enumeration): class NotificationEntity(Enumeration):
REACTION = "reaction" REACTION = "reaction"
SHOUT = "shout" SHOUT = "shout"

View File

@ -13,7 +13,7 @@ python = "^3.12"
SQLAlchemy = "^2.0.22" SQLAlchemy = "^2.0.22"
psycopg2-binary = "^2.9.9" psycopg2-binary = "^2.9.9"
redis = {extras = ["hiredis"], version = "^5.0.1"} redis = {extras = ["hiredis"], version = "^5.0.1"}
uvicorn = "^0.24.0" uvicorn = "^0.27.0"
strawberry-graphql = {extras = ["asgi", "debug-server"], version = "^0.216.1" } strawberry-graphql = {extras = ["asgi", "debug-server"], version = "^0.216.1" }
strawberry-sqlalchemy-mapper = "^0.4.0" strawberry-sqlalchemy-mapper = "^0.4.0"
sentry-sdk = "^1.37.1" sentry-sdk = "^1.37.1"
@ -23,10 +23,11 @@ aiohttp = "^3.9.1"
setuptools = "^69.0.2" setuptools = "^69.0.2"
pytest = "^7.4.2" pytest = "^7.4.2"
black = { version = "^23.12.0", python = ">=3.12" } black = { version = "^23.12.0", python = ">=3.12" }
ruff = { version = "^0.1.8", python = ">=3.12" } ruff = { version = "^0.1.15", python = ">=3.12" }
mypy = { version = "^1.7", python = ">=3.12" } mypy = { version = "^1.7", python = ">=3.12" }
isort = "^5.13.2" isort = "^5.13.2"
pyright = "^1.1.341" pyright = "^1.1.341"
pre-commit = "^3.6.0"
[tool.black] [tool.black]
line-length = 120 line-length = 120

View File

@ -1,11 +1,11 @@
from orm.notification import Notification, NotificationAction, NotificationEntity from orm.notification import Notification
from resolvers.model import NotificationReaction, NotificationAuthor, NotificationShout from resolvers.model import NotificationReaction, NotificationAuthor, NotificationShout
from services.db import local_session from services.db import local_session
from services.rediscache import redis from services.rediscache import redis
import asyncio import asyncio
import logging import logging
logger = logging.getLogger(f"[listener.listen_task] ") logger = logging.getLogger("[listener.listen_task] ")
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)

View File

@ -1,4 +1,3 @@
from sqlalchemy.sql import not_ from sqlalchemy.sql import not_
from services.db import local_session from services.db import local_session
from resolvers.model import ( from resolvers.model import (
@ -10,10 +9,10 @@ from resolvers.model import (
) )
from orm.notification import NotificationAction, NotificationEntity, NotificationSeen, Notification from orm.notification import NotificationAction, NotificationEntity, NotificationSeen, Notification
from typing import Dict, List from typing import Dict, List
import time, json import time
import json
import strawberry import strawberry
from sqlalchemy.orm import aliased from sqlalchemy.orm import aliased
from sqlalchemy.sql.expression import or_
from sqlalchemy import select, and_ from sqlalchemy import select, and_
import logging import logging
@ -62,14 +61,22 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int =
notifications_by_thread: Dict[str, List[Notification]] = {} notifications_by_thread: Dict[str, List[Notification]] = {}
groups_by_thread: Dict[str, NotificationGroup] = {} groups_by_thread: Dict[str, NotificationGroup] = {}
with local_session() as session: with local_session() as session:
total = session.query(Notification).filter(and_(Notification.action == NotificationAction.CREATE.value, Notification.created_at > after)).count() total = (
unread = session.query(Notification).filter( session.query(Notification)
.filter(and_(Notification.action == NotificationAction.CREATE.value, Notification.created_at > after))
.count()
)
unread = (
session.query(Notification)
.filter(
and_( and_(
Notification.action == NotificationAction.CREATE.value, Notification.action == NotificationAction.CREATE.value,
Notification.created_at > after, Notification.created_at > after,
not_(Notification.seen) not_(Notification.seen),
)
)
.count()
) )
).count()
notifications_result = session.execute(query) notifications_result = session.execute(query)
for n, seen in notifications_result: for n, seen in notifications_result:
thread_id = "" thread_id = ""
@ -87,7 +94,7 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int =
updated_at=shout.created_at, updated_at=shout.created_at,
reactions=[], reactions=[],
action="create", action="create",
seen=author_id in n.seen seen=author_id in n.seen,
) )
# store group in result # store group in result
groups_by_thread[thread_id] = group groups_by_thread[thread_id] = group
@ -140,7 +147,7 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int =
authors=[ authors=[
reaction.created_by, reaction.created_by,
], ],
seen=author_id in n.seen seen=author_id in n.seen,
) )
# store group in result # store group in result
groups_by_thread[thread_id] = group groups_by_thread[thread_id] = group
@ -160,9 +167,11 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int =
reactions=[], reactions=[],
entity="follower", entity="follower",
action="follow", action="follow",
seen=author_id in n.seen seen=author_id in n.seen,
) )
group.authors = [follower, ] group.authors = [
follower,
]
group.updated_at = int(time.time()) group.updated_at = int(time.time())
# store group in result # store group in result
groups_by_thread[thread_id] = group groups_by_thread[thread_id] = group

View File

@ -1,4 +1,3 @@
import strawberry import strawberry
from strawberry.schema.config import StrawberryConfig from strawberry.schema.config import StrawberryConfig

View File

@ -46,7 +46,7 @@ class Mutation:
ns = NotificationSeen(notification=n.id, viewer=author_id) ns = NotificationSeen(notification=n.id, viewer=author_id)
session.add(ns) session.add(ns)
session.commit() session.commit()
except SQLAlchemyError as e: except SQLAlchemyError:
session.rollback() session.rollback()
except Exception as e: except Exception as e:
print(e) print(e)

View File

@ -10,6 +10,7 @@ import logging
logger = logging.getLogger("\t[services.auth]\t") logger = logging.getLogger("\t[services.auth]\t")
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
async def check_auth(req) -> str | None: async def check_auth(req) -> str | None:
token = req.headers.get("Authorization") token = req.headers.get("Authorization")
user_id = "" user_id = ""
@ -49,6 +50,7 @@ async def check_auth(req) -> str | None:
return user_id return user_id
except Exception as e: except Exception as e:
import traceback import traceback
traceback.print_exc() traceback.print_exc()
# Handling and logging exceptions during authentication check # Handling and logging exceptions during authentication check
print(f"[services.auth] Error {e}") print(f"[services.auth] Error {e}")

View File

@ -9,6 +9,7 @@ headers = {"Content-Type": "application/json"}
# TODO: rewrite to orm usage? # TODO: rewrite to orm usage?
async def _request_endpoint(query_name, body) -> Any: async def _request_endpoint(query_name, body) -> Any:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.post(API_BASE, headers=headers, json=body) as response: async with session.post(API_BASE, headers=headers, json=body) as response:

View File

@ -9,6 +9,7 @@ import logging
logger = logging.getLogger("\t[services.redis]\t") logger = logging.getLogger("\t[services.redis]\t")
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
class RedisCache: class RedisCache:
def __init__(self, uri=REDIS_URL): def __init__(self, uri=REDIS_URL):
self._uri: str = uri self._uri: str = uri