precommit

This commit is contained in:
Untone 2024-02-04 07:58:44 +03:00
parent 537b89dbaf
commit b98da839ed
10 changed files with 64 additions and 31 deletions

20
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,20 @@
# Pre-commit configuration: these hooks run automatically before each commit.
# fail_fast aborts the run at the first failing hook instead of running all of them.
fail_fast: true
repos:
# Generic hygiene checks from the upstream pre-commit project:
# YAML/AST syntax validation, whitespace and end-of-file fixes,
# large-file and private-key detection, merge-conflict markers.
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- id: check-added-large-files
- id: detect-private-key
- id: check-ast
- id: check-merge-conflict
# Ruff: Python linting (--fix auto-applies safe fixes) plus code formatting.
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.13
hooks:
- id: ruff
args: [--fix]
- id: ruff-format

View File

@@ -1,14 +1,14 @@
from enum import Enum as Enumeration
from sqlalchemy import JSON as JSONType, func, cast
from sqlalchemy import Column, Enum, ForeignKey, Integer, String
from sqlalchemy import JSON as JSONType
from sqlalchemy import Column, ForeignKey, Integer, String
from sqlalchemy.orm import relationship
from sqlalchemy.orm.session import engine
from orm.author import Author
from services.db import Base
import time
class NotificationEntity(Enumeration):
REACTION = "reaction"
SHOUT = "shout"

View File

@@ -13,7 +13,7 @@ python = "^3.12"
SQLAlchemy = "^2.0.22"
psycopg2-binary = "^2.9.9"
redis = {extras = ["hiredis"], version = "^5.0.1"}
uvicorn = "^0.24.0"
uvicorn = "^0.27.0"
strawberry-graphql = {extras = ["asgi", "debug-server"], version = "^0.216.1" }
strawberry-sqlalchemy-mapper = "^0.4.0"
sentry-sdk = "^1.37.1"
@@ -23,10 +23,11 @@ aiohttp = "^3.9.1"
setuptools = "^69.0.2"
pytest = "^7.4.2"
black = { version = "^23.12.0", python = ">=3.12" }
ruff = { version = "^0.1.8", python = ">=3.12" }
ruff = { version = "^0.1.15", python = ">=3.12" }
mypy = { version = "^1.7", python = ">=3.12" }
isort = "^5.13.2"
pyright = "^1.1.341"
pre-commit = "^3.6.0"
[tool.black]
line-length = 120

View File

@@ -1,11 +1,11 @@
from orm.notification import Notification, NotificationAction, NotificationEntity
from orm.notification import Notification
from resolvers.model import NotificationReaction, NotificationAuthor, NotificationShout
from services.db import local_session
from services.rediscache import redis
import asyncio
import logging
logger = logging.getLogger(f"[listener.listen_task] ")
logger = logging.getLogger("[listener.listen_task] ")
logger.setLevel(logging.DEBUG)

View File

@@ -1,4 +1,3 @@
from sqlalchemy.sql import not_
from services.db import local_session
from resolvers.model import (
@@ -10,10 +9,10 @@ from resolvers.model import (
)
from orm.notification import NotificationAction, NotificationEntity, NotificationSeen, Notification
from typing import Dict, List
import time, json
import time
import json
import strawberry
from sqlalchemy.orm import aliased
from sqlalchemy.sql.expression import or_
from sqlalchemy import select, and_
import logging
@@ -62,14 +61,22 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int =
notifications_by_thread: Dict[str, List[Notification]] = {}
groups_by_thread: Dict[str, NotificationGroup] = {}
with local_session() as session:
total = session.query(Notification).filter(and_(Notification.action == NotificationAction.CREATE.value, Notification.created_at > after)).count()
unread = session.query(Notification).filter(
and_(
Notification.action == NotificationAction.CREATE.value,
Notification.created_at > after,
not_(Notification.seen)
total = (
session.query(Notification)
.filter(and_(Notification.action == NotificationAction.CREATE.value, Notification.created_at > after))
.count()
)
unread = (
session.query(Notification)
.filter(
and_(
Notification.action == NotificationAction.CREATE.value,
Notification.created_at > after,
not_(Notification.seen),
)
)
).count()
.count()
)
notifications_result = session.execute(query)
for n, seen in notifications_result:
thread_id = ""
@@ -87,7 +94,7 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int =
updated_at=shout.created_at,
reactions=[],
action="create",
seen=author_id in n.seen
seen=author_id in n.seen,
)
# store group in result
groups_by_thread[thread_id] = group
@@ -140,7 +147,7 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int =
authors=[
reaction.created_by,
],
seen=author_id in n.seen
seen=author_id in n.seen,
)
# store group in result
groups_by_thread[thread_id] = group
@@ -153,16 +160,18 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int =
thread_id = "followers"
follower: NotificationAuthor = payload
group = groups_by_thread.get(thread_id) or NotificationGroup(
id=thread_id,
authors=[follower],
updated_at=int(time.time()),
shout=None,
reactions=[],
entity="follower",
action="follow",
seen=author_id in n.seen
)
group.authors = [follower, ]
id=thread_id,
authors=[follower],
updated_at=int(time.time()),
shout=None,
reactions=[],
entity="follower",
action="follow",
seen=author_id in n.seen,
)
group.authors = [
follower,
]
group.updated_at = int(time.time())
# store group in result
groups_by_thread[thread_id] = group

View File

@@ -1,4 +1,3 @@
import strawberry
from strawberry.schema.config import StrawberryConfig

View File

@@ -46,7 +46,7 @@ class Mutation:
ns = NotificationSeen(notification=n.id, viewer=author_id)
session.add(ns)
session.commit()
except SQLAlchemyError as e:
except SQLAlchemyError:
session.rollback()
except Exception as e:
print(e)

View File

@@ -10,6 +10,7 @@ import logging
logger = logging.getLogger("\t[services.auth]\t")
logger.setLevel(logging.DEBUG)
async def check_auth(req) -> str | None:
token = req.headers.get("Authorization")
user_id = ""
@@ -49,6 +50,7 @@ async def check_auth(req) -> str | None:
return user_id
except Exception as e:
import traceback
traceback.print_exc()
# Handling and logging exceptions during authentication check
print(f"[services.auth] Error {e}")

View File

@@ -9,6 +9,7 @@ headers = {"Content-Type": "application/json"}
# TODO: rewrite to orm usage?
async def _request_endpoint(query_name, body) -> Any:
async with aiohttp.ClientSession() as session:
async with session.post(API_BASE, headers=headers, json=body) as response:

View File

@@ -9,6 +9,7 @@ import logging
logger = logging.getLogger("\t[services.redis]\t")
logger.setLevel(logging.DEBUG)
class RedisCache:
def __init__(self, uri=REDIS_URL):
self._uri: str = uri