fixed-fmt-linted
Some checks failed
deploy to v2 / test (push) Failing after 49s
deploy to v2 / deploy (push) Has been skipped

Untone 2024-02-17 02:56:15 +03:00
parent 2163e85b16
commit 8b39b47714
17 changed files with 384 additions and 391 deletions

.gitignore

@@ -6,3 +6,4 @@ __pycache__
 poetry.lock
 .venv
 .ruff_cache
+.pytest_cache

Dockerfile

@@ -1,20 +1,19 @@
 # Use an official Python runtime as a parent image
-FROM python:3.12-slim
+FROM python:3.12-alpine
 # Set the working directory in the container to /app
 WORKDIR /app
-# Add metadata to the image to describe that the container is listening on port 80
+# Add metadata to the image to describe that the container is listening on port 8000
 EXPOSE 8000
 # Copy the current directory contents into the container at /app
 COPY . /app
 # Install any needed packages specified in pyproject.toml
-RUN apt-get update && apt-get install -y gcc curl && \
+RUN apk update && apk add --no-cache gcc curl && \
     curl -sSL https://install.python-poetry.org | python - && \
-    echo "export PATH=$PATH:/root/.local/bin" >> ~/.bashrc && \
-    . ~/.bashrc && \
+    export PATH=$PATH:/root/.local/bin && \
     poetry config virtualenvs.create false && \
     poetry install --no-dev

main.py

@@ -1,4 +1,5 @@
 import asyncio
+import logging
 import os
 from os.path import exists
@@ -13,11 +14,10 @@ from resolvers.listener import notifications_worker
 from resolvers.schema import schema
 from services.rediscache import redis
 from settings import DEV_SERVER_PID_FILE_NAME, MODE, SENTRY_DSN
-import logging
 logging.basicConfig(level=logging.DEBUG)
-logger = logging.getLogger("\t[main]\t")
+logger = logging.getLogger('\t[main]\t')
 logger.setLevel(logging.DEBUG)
@@ -27,9 +27,9 @@ async def start_up():
     task = asyncio.create_task(notifications_worker())
     logger.info(task)
-    if MODE == "dev":
+    if MODE == 'dev':
         if exists(DEV_SERVER_PID_FILE_NAME):
-            with open(DEV_SERVER_PID_FILE_NAME, "w", encoding="utf-8") as f:
+            with open(DEV_SERVER_PID_FILE_NAME, 'w', encoding='utf-8') as f:
                 f.write(str(os.getpid()))
     else:
         try:
@@ -46,7 +46,7 @@ async def start_up():
                 ],
             )
         except Exception as e:
-            logger.error("sentry init error", e)
+            logger.error('sentry init error', e)
 async def shutdown():
@@ -54,4 +54,4 @@ async def shutdown():
 app = Starlette(debug=True, on_startup=[start_up], on_shutdown=[shutdown])
-app.mount("/", GraphQL(schema, debug=True))
+app.mount('/', GraphQL(schema, debug=True))
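A side note on the except branch above, unchanged by this commit: logger.error('sentry init error', e) hands the exception to logging as a %-format argument with no placeholder in the message, so at emit time the logging machinery reports an internal formatting error instead of the intended text. A minimal sketch of the conventional call (the RuntimeError is a stand-in for a failing sentry_sdk.init):

    import logging

    logger = logging.getLogger('\t[main]\t')

    try:
        raise RuntimeError('sentry unavailable')  # stand-in for sentry_sdk.init(...)
    except Exception as e:
        # '%s' supplies the placeholder the original call was missing;
        # exc_info=True attaches the traceback to the record.
        logger.error('sentry init error: %s', e, exc_info=True)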

orm/author.py

@@ -1,46 +1,45 @@
 import time
-from sqlalchemy import JSON as JSONType
-from sqlalchemy import Boolean, Column, ForeignKey, Integer, String
+from sqlalchemy import JSON, Boolean, Column, ForeignKey, Integer, String
 from sqlalchemy.orm import relationship
 from services.db import Base
 class AuthorRating(Base):
-    __tablename__ = "author_rating"
+    __tablename__ = 'author_rating'
     id = None  # type: ignore
-    rater = Column(ForeignKey("author.id"), primary_key=True, index=True)
-    author = Column(ForeignKey("author.id"), primary_key=True, index=True)
+    rater = Column(ForeignKey('author.id'), primary_key=True, index=True)
+    author = Column(ForeignKey('author.id'), primary_key=True, index=True)
     plus = Column(Boolean)
 class AuthorFollower(Base):
-    __tablename__ = "author_follower"
+    __tablename__ = 'author_follower'
     id = None  # type: ignore
-    follower = Column(ForeignKey("author.id"), primary_key=True, index=True)
-    author = Column(ForeignKey("author.id"), primary_key=True, index=True)
+    follower = Column(ForeignKey('author.id'), primary_key=True, index=True)
+    author = Column(ForeignKey('author.id'), primary_key=True, index=True)
     created_at = Column(Integer, nullable=False, default=lambda: int(time.time()))
     auto = Column(Boolean, nullable=False, default=False)
 class Author(Base):
-    __tablename__ = "author"
+    __tablename__ = 'author'
     user = Column(String, unique=True)  # unbounded link with authorizer's User type
-    name = Column(String, nullable=True, comment="Display name")
+    name = Column(String, nullable=True, comment='Display name')
     slug = Column(String, unique=True, comment="Author's slug")
-    bio = Column(String, nullable=True, comment="Bio")  # status description
-    about = Column(String, nullable=True, comment="About")  # long and formatted
-    pic = Column(String, nullable=True, comment="Picture")
-    links = Column(JSONType, nullable=True, comment="Links")
+    bio = Column(String, nullable=True, comment='Bio')  # status description
+    about = Column(String, nullable=True, comment='About')  # long and formatted
+    pic = Column(String, nullable=True, comment='Picture')
+    links = Column(JSON, nullable=True, comment='Links')
     ratings = relationship(AuthorRating, foreign_keys=AuthorRating.author)
     created_at = Column(Integer, nullable=False, default=lambda: int(time.time()))
     last_seen = Column(Integer, nullable=False, default=lambda: int(time.time()))
     updated_at = Column(Integer, nullable=False, default=lambda: int(time.time()))
-    deleted_at = Column(Integer, nullable=True, comment="Deleted at")
+    deleted_at = Column(Integer, nullable=True, comment='Deleted at')

orm/notification.py

@@ -1,42 +1,41 @@
+import time
 from enum import Enum as Enumeration
-from sqlalchemy import JSON as JSONType
-from sqlalchemy import Column, ForeignKey, Integer, String
+from sqlalchemy import JSON, Column, ForeignKey, Integer, String
 from sqlalchemy.orm import relationship
 from orm.author import Author
 from services.db import Base
-import time
 class NotificationEntity(Enumeration):
-    REACTION = "reaction"
-    SHOUT = "shout"
-    FOLLOWER = "follower"
+    REACTION = 'reaction'
+    SHOUT = 'shout'
+    FOLLOWER = 'follower'
 class NotificationAction(Enumeration):
-    CREATE = "create"
-    UPDATE = "update"
-    DELETE = "delete"
-    SEEN = "seen"
-    FOLLOW = "follow"
-    UNFOLLOW = "unfollow"
+    CREATE = 'create'
+    UPDATE = 'update'
+    DELETE = 'delete'
+    SEEN = 'seen'
+    FOLLOW = 'follow'
+    UNFOLLOW = 'unfollow'
 class NotificationSeen(Base):
-    __tablename__ = "notification_seen"
+    __tablename__ = 'notification_seen'
-    viewer = Column(ForeignKey("author.id"))
-    notification = Column(ForeignKey("notification.id"))
+    viewer = Column(ForeignKey('author.id'))
+    notification = Column(ForeignKey('notification.id'))
 class Notification(Base):
-    __tablename__ = "notification"
+    __tablename__ = 'notification'
     created_at = Column(Integer, server_default=str(int(time.time())))
     entity = Column(String, nullable=False)
     action = Column(String, nullable=False)
-    payload = Column(JSONType, nullable=True)
+    payload = Column(JSON, nullable=True)
-    seen = relationship(lambda: Author, secondary="notification_seen")
+    seen = relationship(lambda: Author, secondary='notification_seen')
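One subtlety preserved by this diff: Notification.created_at uses server_default=str(int(time.time())), which is evaluated once when the module is imported and baked into the table DDL, so every row that omits the column gets the same frozen timestamp. The default=lambda: int(time.time()) style used in the author model is evaluated on each insert. A small sketch of the contrast (standalone columns, for illustration only):

    import time
    from sqlalchemy import Column, Integer

    # Frozen at import time -- one constant shared by all rows:
    created_at_frozen = Column(Integer, server_default=str(int(time.time())))

    # Re-evaluated for every INSERT that omits the column:
    created_at_live = Column(Integer, nullable=False, default=lambda: int(time.time()))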

pyproject.toml

@@ -1,12 +1,8 @@
-[build-system]
-requires = ["poetry-core>=1.0.0"]
-build-backend = "poetry.core.masonry.api"
 [tool.poetry]
 name = "discoursio-notifier"
-version = "0.2.19"
+version = "0.3.0"
 description = "notifier server for discours.io"
-authors = ["discours.io devteam"]
+authors = ["Tony Rewin <anton.rewin@gmail.com>"]
 [tool.poetry.dependencies]
 python = "^3.12"
@@ -21,48 +17,68 @@ granian = "^1.0.2"
 [tool.poetry.group.dev.dependencies]
 setuptools = "^69.0.2"
-pytest = "^7.4.2"
-black = { version = "^23.12.0", python = ">=3.12" }
-ruff = { version = "^0.1.15", python = ">=3.12" }
-mypy = { version = "^1.7", python = ">=3.12" }
-isort = "^5.13.2"
-pyright = "^1.1.341"
 pre-commit = "^3.6.0"
+pytest-asyncio = "^0.23.4"
+pytest-cov = "^4.1.0"
+mypy = "^1.7.1"
+ruff = "^0.1.15"
+black = "^23.12.0"
+pytest = "^7.4.3"
+[build-system]
+requires = ["poetry-core"]
+build-backend = "poetry.core.masonry.api"
+[tool.ruff]
+line-length = 120
+extend-select = [
+    # E and F are enabled by default
+    'B',      # flake8-bugbear
+    'C4',     # flake8-comprehensions
+    'C90',    # mccabe
+    'I',      # isort
+    'N',      # pep8-naming
+    'Q',      # flake8-quotes
+    'RUF100', # ruff (unused noqa)
+    'S',      # flake8-bandit
+    'W',      # pycodestyle
+]
+extend-ignore = [
+    'B008',   # function calls in args defaults are fine
+    'B009',   # getattr with constants is fine
+    'B034',   # re.split won't confuse us
+    'B904',   # raising without from is fine
+    'E501',   # leave line length to black
+    'N818',   # leave exception naming to us
+    'S101',   # assert is fine
+    'RUF100', # black's noqa
+]
+flake8-quotes = { inline-quotes = 'single', multiline-quotes = 'double' }
+mccabe = { max-complexity = 13 }
+target-version = "py312"
+[tool.ruff.format]
+quote-style = 'single'
 [tool.black]
 line-length = 120
 target-version = ['py312']
 include = '\.pyi?$'
-exclude = '''
-(
-  /(
-      \.eggs   # exclude a few common directories in the
-    | \.git    # root of the project
-    | \.hg
-    | \.mypy_cache
-    | \.tox
-    | \.venv
-    | _build
-    | buck-out
-    | build
-    | dist
-  )/
-  | foo.py     # also separately exclude a file named foo.py in
-               # the root of the project
-)
-'''
+skip-string-normalization = true
+[tool.ruff.isort]
+combine-as-imports = true
+lines-after-imports = 2
+known-first-party = ['resolvers', 'services', 'orm', 'tests']
-[tool.isort]
-multi_line_output = 3
-include_trailing_comma = true
-force_grid_wrap = 0
-use_parentheses = true
-ensure_newline_before_comments = true
-line_length = 120
+[tool.ruff.per-file-ignores]
+'tests/**' = ['B018', 'S110', 'S501']
+[tool.mypy]
+python_version = "3.12"
+warn_return_any = true
+warn_unused_configs = true
+ignore_missing_imports = true
+exclude = ["nb"]
+[tool.pytest.ini_options]
+asyncio_mode = 'auto'
 [tool.pyright]
 venvPath = "."
@@ -90,27 +106,3 @@ logLevel = "Information"
 pluginSearchPaths = []
 typings = {}
 mergeTypeStubPackages = false
-[tool.mypy]
-python_version = "3.12"
-warn_unused_configs = true
-plugins = ["mypy_sqlalchemy.plugin", "strawberry.ext.mypy_plugin"]
-[tool.ruff]
-# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default.
-# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or
-# McCabe complexity (`C901`) by default.
-select = ["E4", "E7", "E9", "F"]
-ignore = []
-line-length = 120
-target-version = "py312"
-[tool.pytest.ini_options]
-pythonpath = [
-    "."
-]
-[tool.pytest]
-python_files = "*_test.py"
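With pytest-asyncio now in the dev group and asyncio_mode = 'auto' set under the new [tool.pytest.ini_options], async test functions are collected and awaited without per-test decorators. A minimal sketch (hypothetical test file, not part of this commit):

    # tests/test_smoke.py
    import asyncio

    async def test_event_loop_runs():
        # no @pytest.mark.asyncio needed in auto mode
        await asyncio.sleep(0)
        assert True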

resolvers/listener.py

@@ -1,11 +1,13 @@
-from orm.notification import Notification
-from resolvers.model import NotificationReaction, NotificationAuthor, NotificationShout
-from services.db import local_session
-from services.rediscache import redis
 import asyncio
 import logging
-logger = logging.getLogger("[listener.listen_task] ")
+from orm.notification import Notification
+from resolvers.model import NotificationAuthor, NotificationReaction, NotificationShout
+from services.db import local_session
+from services.rediscache import redis
+logger = logging.getLogger('[listener.listen_task] ')
 logger.setLevel(logging.DEBUG)
@@ -19,8 +21,8 @@ async def handle_notification(n: ServiceMessage, channel: str):
     """создаёт новое хранимое уведомление"""
     with local_session() as session:
         try:
-            if channel.startswith("follower:"):
-                author_id = int(channel.split(":")[1])
+            if channel.startswith('follower:'):
+                author_id = int(channel.split(':')[1])
                 if isinstance(n.payload, NotificationAuthor):
                     n.payload.following_id = author_id
                 n = Notification(action=n.action, entity=n.entity, payload=n.payload)
@@ -28,7 +30,7 @@ async def handle_notification(n: ServiceMessage, channel: str):
             session.commit()
         except Exception as e:
             session.rollback()
-            logger.error(f"[listener.handle_reaction] error: {str(e)}")
+            logger.error(f'[listener.handle_reaction] error: {str(e)}')
 async def listen_task(pattern):
@@ -38,9 +40,9 @@ async def listen_task(pattern):
             notification_message = ServiceMessage(**message_data)
             await handle_notification(notification_message, str(channel))
         except Exception as e:
-            logger.error(f"Error processing notification: {str(e)}")
+            logger.error(f'Error processing notification: {str(e)}')
 async def notifications_worker():
     # Use asyncio.gather to run tasks concurrently
-    await asyncio.gather(listen_task("follower:*"), listen_task("reaction"), listen_task("shout"))
+    await asyncio.gather(listen_task('follower:*'), listen_task('reaction'), listen_task('shout'))
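For context, the worker above only consumes. The producing side must publish JSON whose keys match ServiceMessage(action=..., entity=..., payload=...) to the follower:<author_id>, reaction, or shout channels. A hedged sketch of such a publisher, inferred from the handler above rather than taken from this commit:

    import json
    import redis.asyncio as aredis

    async def publish_follow(author_id: int, follower: dict) -> None:
        client = aredis.Redis.from_url('redis://127.0.0.1')
        message = {'action': 'follow', 'entity': 'follower', 'payload': follower}
        # listen_task('follower:*') pattern-matches this channel name
        await client.publish(f'follower:{author_id}', json.dumps(message))
        await client.aclose()  # redis-py >= 5; use close() on 4.x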

resolvers/load.py

@@ -1,53 +1,31 @@
+import json
+import logging
+import time
+from typing import Dict, List, Tuple, Union
+import strawberry
+from sqlalchemy import and_, select
+from sqlalchemy.orm import aliased
+from sqlalchemy.sql import not_
+from services.db import local_session
+from orm.notification import Notification, NotificationAction, NotificationEntity, NotificationSeen
 from resolvers.model import (
-    NotificationReaction,
-    NotificationGroup,
-    NotificationShout,
     NotificationAuthor,
+    NotificationGroup,
+    NotificationReaction,
+    NotificationShout,
     NotificationsResult,
 )
-from orm.notification import NotificationAction, NotificationEntity, NotificationSeen, Notification
-from typing import Dict, List
-import time
-import json
-import strawberry
-from sqlalchemy.orm import aliased
-from sqlalchemy import select, and_
-import logging
-from services.db import local_session
-logger = logging.getLogger("[resolvers.schema] ")
+logger = logging.getLogger('[resolvers.schema]')
 logger.setLevel(logging.DEBUG)
-async def get_notifications_grouped(author_id: int, after: int = 0, limit: int = 10, offset: int = 0):
-    """
-    Retrieves notifications for a given author.
-    Args:
-        author_id (int): The ID of the author for whom notifications are retrieved.
-        after (int, optional): If provided, only notifications created after this timestamp are considered.
-        limit (int, optional): The maximum number of groups to retrieve.
-        offset (int, optional): Offset for pagination.
-    Returns:
-        Dict[str, NotificationGroup], int, int: A dictionary where keys are thread IDs and values are
-        NotificationGroup objects, plus unread and total amounts.
-    This function queries the database to retrieve notifications for the specified author, considering optional filters.
-    The result is a dictionary where each key is a thread ID, and the corresponding value is a NotificationGroup
-    containing information about the notifications within that thread.
-    NotificationGroup structure:
-    {
-        entity: str,        # Type of entity (e.g., 'reaction', 'shout', 'follower').
-        updated_at: int,    # Timestamp of the latest update in the thread.
-        shout: Optional[NotificationShout],
-        reactions: List[int],              # List of reaction ids within the thread.
-        authors: List[NotificationAuthor], # List of authors involved in the thread.
-    }
-    """
-    NotificationSeenAlias = aliased(NotificationSeen)
-    query = select(Notification, NotificationSeenAlias.viewer.label("seen")).outerjoin(
+def query_notifications(author_id: int, after: int = 0) -> Tuple[int, int, List[Tuple[Notification, bool]]]:
+    notification_seen_alias = aliased(NotificationSeen)
+    query = select(Notification, notification_seen_alias.viewer.label('seen')).outerjoin(
         NotificationSeen,
         and_(NotificationSeen.viewer == author_id, NotificationSeen.notification == Notification.id),
     )
@@ -55,17 +33,13 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int = 10, offset: int = 0):
     query = query.filter(Notification.created_at > after)
     query = query.group_by(NotificationSeen.notification, Notification.created_at)
-    groups_amount = 0
     unread = 0
     total = 0
-    notifications_by_thread: Dict[str, List[Notification]] = {}
-    groups_by_thread: Dict[str, NotificationGroup] = {}
     with local_session() as session:
         total = (
             session.query(Notification)
             .filter(and_(Notification.action == NotificationAction.CREATE.value, Notification.created_at > after))
             .count()
         )
         unread = (
             session.query(Notification)
             .filter(
@@ -77,123 +51,140 @@ async def get_notifications_grouped(author_id: int, after: int = 0, limit: int = 10, offset: int = 0):
                 )
                 .count()
             )
         notifications_result = session.execute(query)
+        notifications = []
         for n, seen in notifications_result:
-            thread_id = ""
-            payload = json.loads(n.payload)
-            logger.debug(f"[resolvers.schema] {n.action} {n.entity}: {payload}")
-            if n.entity == "shout" and n.action == "create":
+            notifications.append((n, seen))
+    return total, unread, notifications
+def process_shout_notification(
+    notification: Notification, seen: bool
+) -> Union[Tuple[str, NotificationGroup], None] | None:
+    if not isinstance(notification.payload, str) or not isinstance(notification.entity, str):
+        return
+    payload = json.loads(notification.payload)
     shout: NotificationShout = payload
-                thread_id += f"{shout.id}"
-                logger.debug(f"create shout: {shout}")
-                group = groups_by_thread.get(thread_id) or NotificationGroup(
+    thread_id = str(shout.id)
+    group = NotificationGroup(
         id=thread_id,
-                    entity=n.entity,
+        entity=notification.entity,
         shout=shout,
         authors=shout.authors,
         updated_at=shout.created_at,
         reactions=[],
-                    action="create",
-                    seen=author_id in n.seen,
+        action='create',
+        seen=seen,
     )
-                # store group in result
-                groups_by_thread[thread_id] = group
-                notifications = notifications_by_thread.get(thread_id, [])
-                if n not in notifications:
-                    notifications.append(n)
-                notifications_by_thread[thread_id] = notifications
-                groups_amount += 1
-            elif n.entity == NotificationEntity.REACTION.value and n.action == NotificationAction.CREATE.value:
+    return thread_id, group
+def process_reaction_notification(
+    notification: Notification, seen: bool
+) -> Union[Tuple[str, NotificationGroup], None] | None:
+    if (
+        not isinstance(notification, Notification)
+        or not isinstance(notification.payload, str)
+        or not isinstance(notification.entity, str)
+    ):
+        return
+    payload = json.loads(notification.payload)
     reaction: NotificationReaction = payload
     shout: NotificationShout = reaction.shout
-            thread_id += f"{reaction.shout}"
-            if reaction.kind == "LIKE" or reaction.kind == "DISLIKE":
-                # TODO: making published reaction vote announce
-                pass
-            elif reaction.kind == "COMMENT":
-                if reaction.reply_to:
-                    thread_id += f"{'::' + str(reaction.reply_to)}"
-                group: NotificationGroup | None = groups_by_thread.get(thread_id)
-                notifications: List[Notification] = notifications_by_thread.get(thread_id, [])
-                if group and notifications:
-                    group.seen = False  # any not seen notification make it false
-                    group.shout = shout
-                    group.authors.append(reaction.created_by)
-                    if not group.reactions:
-                        group.reactions = []
-                    group.reactions.append(reaction.id)
-                    # store group in result
-                    groups_by_thread[thread_id] = group
-                    notifications = notifications_by_thread.get(thread_id, [])
-                    if n not in notifications:
-                        notifications.append(n)
-                    notifications_by_thread[thread_id] = notifications
-                    groups_amount += 1
-                else:
-                    groups_amount += 1
-                    if groups_amount > limit:
-                        break
-                    else:
-                        # init notification group
-                        reactions = []
-                        reactions.append(reaction.id)
+    thread_id = str(reaction.shout)
+    if reaction.kind == 'COMMENT' and reaction.reply_to:
+        thread_id += f'::{reaction.reply_to}'
     group = NotificationGroup(
         id=thread_id,
-                            action=n.action,
-                            entity=n.entity,
+        action=str(notification.action),
+        entity=notification.entity,
         updated_at=reaction.created_at,
-                            reactions=reactions,
+        reactions=[reaction.id],
         shout=shout,
-                            authors=[
-                                reaction.created_by,
-                            ],
-                            seen=author_id in n.seen,
+        authors=[reaction.created_by],
+        seen=seen,
     )
-                        # store group in result
-                        groups_by_thread[thread_id] = group
-                        notifications = notifications_by_thread.get(thread_id, [])
-                        if n not in notifications:
-                            notifications.append(n)
-                        notifications_by_thread[thread_id] = notifications
+    return thread_id, group
-        elif n.entity == "follower":
-            thread_id = "followers"
+def process_follower_notification(
+    notification: Notification, seen: bool
+) -> Union[Tuple[str, NotificationGroup], None] | None:
+    if not isinstance(notification.payload, str):
+        return
+    payload = json.loads(notification.payload)
     follower: NotificationAuthor = payload
-            group = groups_by_thread.get(thread_id) or NotificationGroup(
+    thread_id = 'followers'
+    group = NotificationGroup(
         id=thread_id,
         authors=[follower],
         updated_at=int(time.time()),
         shout=None,
         reactions=[],
-                entity="follower",
-                action="follow",
-                seen=author_id in n.seen,
+        entity='follower',
+        action='follow',
+        seen=seen,
     )
-            group.authors = [
-                follower,
-            ]
-            group.updated_at = int(time.time())
-            # store group in result
-            groups_by_thread[thread_id] = group
-            notifications = notifications_by_thread.get(thread_id, [])
-            if n not in notifications:
-                notifications.append(n)
-            notifications_by_thread[thread_id] = notifications
-            groups_amount += 1
+    return thread_id, group
-        if groups_amount > limit:
+async def get_notifications_grouped(
+    author_id: int, after: int = 0, limit: int = 10
+) -> Tuple[Dict[str, NotificationGroup], int, int]:
+    total, unread, notifications = query_notifications(author_id, after)
+    groups_by_thread: Dict[str, NotificationGroup] = {}
+    groups_amount = 0
+    for notification, seen in notifications:
+        if groups_amount >= limit:
             break
-    return groups_by_thread, notifications_by_thread, unread, total
+        if str(notification.entity) == 'shout' and str(notification.action) == 'create':
+            result = process_shout_notification(notification, seen)
+            if result:
+                thread_id, group = result
+                groups_by_thread[thread_id] = group
+                groups_amount += 1
+        elif (
+            str(notification.entity) == NotificationEntity.REACTION.value
+            and str(notification.action) == NotificationAction.CREATE.value
+        ):
+            result = process_reaction_notification(notification, seen)
+            if result:
+                thread_id, group = result
+                existing_group = groups_by_thread.get(thread_id)
+                if existing_group:
+                    existing_group.seen = False
+                    existing_group.shout = group.shout
+                    existing_group.authors.append(group.authors[0])
+                    if not existing_group.reactions:
+                        existing_group.reactions = []
+                    existing_group.reactions.extend(group.reactions or [])
+                    groups_by_thread[thread_id] = existing_group
+                else:
+                    groups_by_thread[thread_id] = group
+                    groups_amount += 1
+        elif str(notification.entity) == 'follower':
+            result = process_follower_notification(notification, seen)
+            if result:
+                thread_id, group = result
+                groups_by_thread[thread_id] = group
+                groups_amount += 1
+    return groups_by_thread, unread, total
 @strawberry.type
 class Query:
     @strawberry.field
     async def load_notifications(self, info, after: int, limit: int = 50, offset: int = 0) -> NotificationsResult:
-        author_id = info.context.get("author_id")
-        groups: Dict[str, NotificationGroup] = {}
+        author_id = info.context.get('author_id')
         if author_id:
-            groups, notifications, total, unread = await get_notifications_grouped(author_id, after, limit, offset)
+            groups, unread, total = await get_notifications_grouped(author_id, after, limit)
             notifications = sorted(groups.values(), key=lambda group: group.updated_at, reverse=True)
-            return NotificationsResult(notifications=notifications, total=0, unread=0, error=None)
+            return NotificationsResult(notifications=notifications, total=total, unread=unread, error=None)
         return NotificationsResult(notifications=[], total=0, unread=0, error=None)
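The refactor keeps the original thread-id scheme: shout notifications group under the shout id, comment reactions under <shout_id> or <shout_id>::<reply_to>, and all follower events under the single 'followers' thread. A compact sketch of that keying rule (a plain helper written here for illustration, with payload fields as used above):

    def thread_key(entity: str, payload: dict) -> str:
        if entity == 'follower':
            return 'followers'
        if entity == 'reaction':
            key = str(payload['shout'])
            if payload.get('kind') == 'COMMENT' and payload.get('reply_to'):
                key += f"::{payload['reply_to']}"
            return key
        return str(payload['id'])  # 'shout' notifications group by shout id

    assert thread_key('reaction', {'shout': 42, 'kind': 'COMMENT', 'reply_to': 7}) == '42::7'
    assert thread_key('follower', {}) == 'followers'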

resolvers/model.py

@@ -1,8 +1,11 @@
-import strawberry
 from typing import List, Optional
+import strawberry
 from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper
 from orm.notification import Notification as NotificationMessage
 strawberry_sqlalchemy_mapper = StrawberrySQLAlchemyMapper()

resolvers/schema.py

@@ -1,11 +1,12 @@
 import strawberry
 from strawberry.schema.config import StrawberryConfig
-from services.auth import LoginRequiredMiddleware
 from resolvers.load import Query
 from resolvers.seen import Mutation
+from services.auth import LoginRequiredMiddleware
 from services.db import Base, engine
 schema = strawberry.Schema(
     query=Query, mutation=Mutation, config=StrawberryConfig(auto_camel_case=False), extensions=[LoginRequiredMiddleware]
 )

resolvers/seen.py

@@ -1,14 +1,14 @@
-from sqlalchemy import and_
-from orm.notification import NotificationSeen
-from services.db import local_session
-from resolvers.model import Notification, NotificationSeenResult, NotificationReaction
+import json
+import logging
 import strawberry
-import logging
-import json
+from sqlalchemy import and_
+from sqlalchemy.exc import SQLAlchemyError
+from orm.notification import NotificationSeen
+from resolvers.model import Notification, NotificationReaction, NotificationSeenResult
+from services.db import local_session
 logger = logging.getLogger(__name__)
@@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
 class Mutation:
     @strawberry.mutation
     async def mark_seen(self, info, notification_id: int) -> NotificationSeenResult:
-        author_id = info.context.get("author_id")
+        author_id = info.context.get('author_id')
         if author_id:
             with local_session() as session:
                 try:
@@ -27,9 +27,9 @@ class Mutation:
                 except SQLAlchemyError as e:
                     session.rollback()
                     logger.error(
-                        f"[mark_notification_as_read] Ошибка при обновлении статуса прочтения уведомления: {str(e)}"
+                        f'[mark_notification_as_read] Ошибка при обновлении статуса прочтения уведомления: {str(e)}'
                     )
-                    return NotificationSeenResult(error="cant mark as read")
+                    return NotificationSeenResult(error='cant mark as read')
         return NotificationSeenResult(error=None)
     @strawberry.mutation
@@ -37,7 +37,7 @@ class Mutation:
         # TODO: use latest loaded notification_id as input offset parameter
         error = None
         try:
-            author_id = info.context.get("author_id")
+            author_id = info.context.get('author_id')
             if author_id:
                 with local_session() as session:
                     nnn = session.query(Notification).filter(and_(Notification.created_at > after)).all()
@@ -50,22 +50,22 @@ class Mutation:
                     session.rollback()
         except Exception as e:
             print(e)
-            error = "cant mark as read"
+            error = 'cant mark as read'
         return NotificationSeenResult(error=error)
     @strawberry.mutation
     async def mark_seen_thread(self, info, thread: str, after: int) -> NotificationSeenResult:
         error = None
-        author_id = info.context.get("author_id")
+        author_id = info.context.get('author_id')
         if author_id:
-            [shout_id, reply_to_id] = thread.split("::")
+            [shout_id, reply_to_id] = thread.split('::')
             with local_session() as session:
                 # TODO: handle new follower and new shout notifications
                 new_reaction_notifications = (
                     session.query(Notification)
                     .filter(
-                        Notification.action == "create",
-                        Notification.entity == "reaction",
+                        Notification.action == 'create',
+                        Notification.entity == 'reaction',
                         Notification.created_at > after,
                     )
                     .all()
@@ -73,13 +73,13 @@ class Mutation:
                 removed_reaction_notifications = (
                     session.query(Notification)
                     .filter(
-                        Notification.action == "delete",
-                        Notification.entity == "reaction",
+                        Notification.action == 'delete',
+                        Notification.entity == 'reaction',
                         Notification.created_at > after,
                     )
                     .all()
                 )
-                exclude = set([])
+                exclude = set()
                 for nr in removed_reaction_notifications:
                     reaction: NotificationReaction = json.loads(nr.payload)
                     exclude.add(reaction.id)
@@ -97,5 +97,5 @@ class Mutation:
             except Exception:
                 session.rollback()
         else:
-            error = "You are not logged in"
+            error = 'You are not logged in'
         return NotificationSeenResult(error=error)
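One caveat in mark_seen_thread survives the quote sweep: [shout_id, reply_to_id] = thread.split('::') raises ValueError for any thread id without exactly one '::', which includes the plain <shout_id> and 'followers' threads produced by the loader. A hedged sketch of a tolerant parse (hypothetical helper, not in this commit):

    def parse_thread(thread: str) -> tuple[str, str | None]:
        shout_id, sep, reply_to_id = thread.partition('::')
        return shout_id, (reply_to_id if sep else None)

    assert parse_thread('42::7') == ('42', '7')
    assert parse_thread('42') == ('42', None)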

server.py

@@ -1,53 +1,54 @@
-import sys
 import logging
+import sys
 from settings import PORT
 log_settings = {
-    "version": 1,
-    "disable_existing_loggers": True,
-    "formatters": {
-        "default": {
-            "()": "uvicorn.logging.DefaultFormatter",
-            "fmt": "%(levelprefix)s %(message)s",
-            "use_colors": None,
+    'version': 1,
+    'disable_existing_loggers': True,
+    'formatters': {
+        'default': {
+            '()': 'uvicorn.logging.DefaultFormatter',
+            'fmt': '%(levelprefix)s %(message)s',
+            'use_colors': None,
         },
-        "access": {
-            "()": "uvicorn.logging.AccessFormatter",
-            "fmt": '%(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s',
+        'access': {
+            '()': 'uvicorn.logging.AccessFormatter',
+            'fmt': '%(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s',
         },
     },
-    "handlers": {
-        "default": {
-            "formatter": "default",
-            "class": "logging.StreamHandler",
-            "stream": "ext://sys.stderr",
+    'handlers': {
+        'default': {
+            'formatter': 'default',
+            'class': 'logging.StreamHandler',
+            'stream': 'ext://sys.stderr',
         },
-        "access": {
-            "formatter": "access",
-            "class": "logging.StreamHandler",
-            "stream": "ext://sys.stdout",
+        'access': {
+            'formatter': 'access',
+            'class': 'logging.StreamHandler',
+            'stream': 'ext://sys.stdout',
         },
     },
-    "loggers": {
-        "uvicorn": {"handlers": ["default"], "level": "INFO"},
-        "uvicorn.error": {"level": "INFO", "handlers": ["default"], "propagate": True},
-        "uvicorn.access": {"handlers": ["access"], "level": "INFO", "propagate": False},
+    'loggers': {
+        'uvicorn': {'handlers': ['default'], 'level': 'INFO'},
+        'uvicorn.error': {'level': 'INFO', 'handlers': ['default'], 'propagate': True},
+        'uvicorn.access': {'handlers': ['access'], 'level': 'INFO', 'propagate': False},
     },
 }
 local_headers = [
-    ("Access-Control-Allow-Methods", "GET, POST, OPTIONS, HEAD"),
-    ("Access-Control-Allow-Origin", "https://localhost:3000"),
+    ('Access-Control-Allow-Methods', 'GET, POST, OPTIONS, HEAD'),
+    ('Access-Control-Allow-Origin', 'https://localhost:3000'),
     (
-        "Access-Control-Allow-Headers",
-        "DNT,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range,Authorization",
+        'Access-Control-Allow-Headers',
+        'DNT,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range,Authorization',
     ),
-    ("Access-Control-Expose-Headers", "Content-Length,Content-Range"),
-    ("Access-Control-Allow-Credentials", "true"),
+    ('Access-Control-Expose-Headers', 'Content-Length,Content-Range'),
+    ('Access-Control-Allow-Credentials', 'true'),
 ]
-logger = logging.getLogger("[server] ")
+logger = logging.getLogger('[server] ')
 logger.setLevel(logging.DEBUG)
@@ -55,16 +56,16 @@ def exception_handler(_et, exc, _tb):
     logger.error(..., exc_info=(type(exc), exc, exc.__traceback__))
-if __name__ == "__main__":
+if __name__ == '__main__':
     sys.excepthook = exception_handler
     from granian.constants import Interfaces
     from granian.server import Granian
-    print("[server] started")
+    print('[server] started')
     granian_instance = Granian(
-        "main:app",
-        address="0.0.0.0",  # noqa S104
+        'main:app',
+        address='0.0.0.0',  # noqa S104
         port=PORT,
         workers=2,
         threads=2,

services/auth.py

@@ -1,59 +1,60 @@
+import logging
 from aiohttp import ClientSession
 from strawberry.extensions import Extension
-from settings import AUTH_URL
-from services.db import local_session
 from orm.author import Author
+from services.db import local_session
+from settings import AUTH_URL
-import logging
-logger = logging.getLogger("\t[services.auth]\t")
+logger = logging.getLogger('\t[services.auth]\t')
 logger.setLevel(logging.DEBUG)
 async def check_auth(req) -> str | None:
-    token = req.headers.get("Authorization")
-    user_id = ""
+    token = req.headers.get('Authorization')
+    user_id = ''
     if token:
-        query_name = "validate_jwt_token"
-        operation = "ValidateToken"
+        query_name = 'validate_jwt_token'
+        operation = 'ValidateToken'
         headers = {
-            "Content-Type": "application/json",
+            'Content-Type': 'application/json',
         }
         variables = {
-            "params": {
-                "token_type": "access_token",
-                "token": token,
+            'params': {
+                'token_type': 'access_token',
+                'token': token,
             }
         }
         gql = {
-            "query": f"query {operation}($params: ValidateJWTTokenInput!) {{ {query_name}(params: $params) {{ is_valid claims }} }}",
-            "variables": variables,
-            "operationName": operation,
+            'query': f'query {operation}($params: ValidateJWTTokenInput!) {{ {query_name}(params: $params) {{ is_valid claims }} }}',
+            'variables': variables,
+            'operationName': operation,
         }
         try:
            # Asynchronous HTTP request to the authentication server
            async with ClientSession() as session:
                async with session.post(AUTH_URL, json=gql, headers=headers) as response:
-                    print(f"[services.auth] HTTP Response {response.status} {await response.text()}")
+                    print(f'[services.auth] HTTP Response {response.status} {await response.text()}')
                     if response.status == 200:
                         data = await response.json()
-                        errors = data.get("errors")
+                        errors = data.get('errors')
                         if errors:
-                            print(f"[services.auth] errors: {errors}")
+                            print(f'[services.auth] errors: {errors}')
                         else:
-                            user_id = data.get("data", {}).get(query_name, {}).get("claims", {}).get("sub")
+                            user_id = data.get('data', {}).get(query_name, {}).get('claims', {}).get('sub')
                             if user_id:
-                                print(f"[services.auth] got user_id: {user_id}")
+                                print(f'[services.auth] got user_id: {user_id}')
                                 return user_id
         except Exception as e:
             import traceback
             traceback.print_exc()
             # Handling and logging exceptions during authentication check
-            print(f"[services.auth] Error {e}")
+            print(f'[services.auth] Error {e}')
     return None
@@ -61,14 +62,14 @@ async def check_auth(req) -> str | None:
 class LoginRequiredMiddleware(Extension):
     async def on_request_start(self):
         context = self.execution_context.context
-        req = context.get("request")
+        req = context.get('request')
         user_id = await check_auth(req)
         if user_id:
-            context["user_id"] = user_id.strip()
+            context['user_id'] = user_id.strip()
             with local_session() as session:
                 author = session.query(Author).filter(Author.user == user_id).first()
                 if author:
-                    context["author_id"] = author.id
-                    context["user_id"] = user_id or None
+                    context['author_id'] = author.id
+                    context['user_id'] = user_id or None
         self.execution_context.context = context
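For reference, the nested .get() chain above implies the authorizer answers with roughly the following shape; this is inferred from the parsing code, not documented by this commit:

    # Sketch of the expected validate_jwt_token response body.
    response_json = {
        'data': {
            'validate_jwt_token': {
                'is_valid': True,
                'claims': {'sub': 'user-id-from-token'},
            }
        }
    }
    user_id = response_json.get('data', {}).get('validate_jwt_token', {}).get('claims', {}).get('sub')
    assert user_id == 'user-id-from-token'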

services/core.py

@@ -4,7 +4,8 @@ import aiohttp
 from settings import API_BASE
-headers = {"Content-Type": "application/json"}
+headers = {'Content-Type': 'application/json'}
 # TODO: rewrite to orm usage?
@@ -13,39 +14,39 @@ headers = {"Content-Type": "application/json"}
 async def _request_endpoint(query_name, body) -> Any:
     async with aiohttp.ClientSession() as session:
         async with session.post(API_BASE, headers=headers, json=body) as response:
-            print(f"[services.core] {query_name} HTTP Response {response.status} {await response.text()}")
+            print(f'[services.core] {query_name} HTTP Response {response.status} {await response.text()}')
             if response.status == 200:
                 r = await response.json()
                 if r:
-                    return r.get("data", {}).get(query_name, {})
+                    return r.get('data', {}).get(query_name, {})
     return []
 async def get_followed_shouts(author_id: int):
-    query_name = "load_shouts_followed"
-    operation = "GetFollowedShouts"
+    query_name = 'load_shouts_followed'
+    operation = 'GetFollowedShouts'
     query = f"""query {operation}($author_id: Int!, limit: Int, offset: Int) {{
         {query_name}(author_id: $author_id, limit: $limit, offset: $offset) {{ id slug title }}
     }}"""
     gql = {
-        "query": query,
-        "operationName": operation,
-        "variables": {"author_id": author_id, "limit": 1000, "offset": 0},  # FIXME: too big limit
+        'query': query,
+        'operationName': operation,
+        'variables': {'author_id': author_id, 'limit': 1000, 'offset': 0},  # FIXME: too big limit
     }
     return await _request_endpoint(query_name, gql)
 async def get_shout(shout_id):
-    query_name = "get_shout"
-    operation = "GetShout"
+    query_name = 'get_shout'
+    operation = 'GetShout'
     query = f"""query {operation}($slug: String, $shout_id: Int) {{
         {query_name}(slug: $slug, shout_id: $shout_id) {{ id slug title authors {{ id slug name pic }} }}
     }}"""
-    gql = {"query": query, "operationName": operation, "variables": {"slug": None, "shout_id": shout_id}}
+    gql = {'query': query, 'operationName': operation, 'variables': {'slug': None, 'shout_id': shout_id}}
     return await _request_endpoint(query_name, gql)
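A pre-existing bug worth flagging while this file is touched: in get_followed_shouts the operation header declares limit: Int, offset: Int without the $ prefix, so the document is invalid GraphQL even though $limit and $offset are referenced in the selection. A corrected sketch of the header:

    operation = 'GetFollowedShouts'
    query_name = 'load_shouts_followed'
    # GraphQL variable declarations need the '$' prefix:
    query = f"""query {operation}($author_id: Int!, $limit: Int, $offset: Int) {{
        {query_name}(author_id: $author_id, limit: $limit, offset: $offset) {{ id slug title }}
    }}"""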

services/db.py

@@ -8,15 +8,16 @@ from sqlalchemy.sql.schema import Table
 from settings import DB_URL
 engine = create_engine(DB_URL, echo=False, pool_size=10, max_overflow=20)
-T = TypeVar("T")
+T = TypeVar('T')
 REGISTRY: Dict[str, type] = {}
 # @contextmanager
-def local_session(src=""):
+def local_session(src=''):
     return Session(bind=engine, expire_on_commit=False)
 # try:
@@ -44,7 +45,7 @@ class Base(declarative_base()):
     __init__: Callable
     __allow_unmapped__ = True
     __abstract__ = True
-    __table_args__ = {"extend_existing": True}
+    __table_args__ = {'extend_existing': True}
     id = Column(Integer, primary_key=True)
@@ -53,12 +54,12 @@ class Base(declarative_base()):
     def dict(self) -> Dict[str, Any]:
         column_names = self.__table__.columns.keys()
-        if "_sa_instance_state" in column_names:
-            column_names.remove("_sa_instance_state")
+        if '_sa_instance_state' in column_names:
+            column_names.remove('_sa_instance_state')
         try:
             return {c: getattr(self, c) for c in column_names}
         except Exception as e:
-            print(f"[services.db] Error dict: {e}")
+            print(f'[services.db] Error dict: {e}')
             return {}
     def update(self, values: Dict[str, Any]) -> None:
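Although the @contextmanager decorator stays commented out, with local_session() as session: works throughout the codebase because SQLAlchemy's Session has been a context manager since 1.4: __enter__ returns the session and __exit__ closes it without committing. A tiny self-contained sketch (in-memory SQLite as a stand-in for DB_URL):

    from sqlalchemy import create_engine
    from sqlalchemy.orm import Session

    engine = create_engine('sqlite://')  # stand-in for the DB_URL engine above

    with Session(bind=engine, expire_on_commit=False) as session:
        session.commit()  # commits stay explicit; __exit__ only closes the session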

services/rediscache.py

@@ -1,12 +1,13 @@
-import json
-import redis.asyncio as aredis
 import asyncio
-from settings import REDIS_URL
+import json
 import logging
-logger = logging.getLogger("\t[services.redis]\t")
+import redis.asyncio as aredis
+from settings import REDIS_URL
+logger = logging.getLogger('\t[services.redis]\t')
 logger.setLevel(logging.DEBUG)
@@ -26,11 +27,11 @@ class RedisCache:
     async def execute(self, command, *args, **kwargs):
         if self._client:
             try:
-                logger.debug(command + " " + " ".join(args))
+                logger.debug(command + ' ' + ' '.join(args))
                 r = await self._client.execute_command(command, *args, **kwargs)
                 return r
             except Exception as e:
-                logger.error(f"{e}")
+                logger.error(f'{e}')
         return None
     async def subscribe(self, *channels):
@@ -60,15 +61,15 @@ class RedisCache:
         while True:
             message = await pubsub.get_message()
-            if message and isinstance(message["data"], (str, bytes, bytearray)):
-                logger.debug("pubsub got msg")
+            if message and isinstance(message['data'], (str, bytes, bytearray)):
+                logger.debug('pubsub got msg')
                 try:
-                    yield json.loads(message["data"]), message.get("channel")
+                    yield json.loads(message['data']), message.get('channel')
                 except Exception as e:
-                    logger.error(f"{e}")
+                    logger.error(f'{e}')
             await asyncio.sleep(1)
 redis = RedisCache()
-__all__ = ["redis"]
+__all__ = ['redis']
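Small caveat kept by this commit: logger.debug(command + ' ' + ' '.join(args)) raises TypeError as soon as any positional argument is not a string (for example await redis.execute('EXPIRE', key, 60)), and because the call sits inside the try block, the Redis command is silently skipped. A defensive sketch:

    import logging

    logger = logging.getLogger('\t[services.redis]\t')

    def debug_command(command: str, *args) -> None:
        # stringify every arg; %-style args are also formatted lazily
        logger.debug('%s %s', command, ' '.join(map(str, args)))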

settings.py

@@ -1,13 +1,14 @@
 from os import environ
 PORT = 8000
 DB_URL = (
-    environ.get("DATABASE_URL", environ.get("DB_URL", "")).replace("postgres://", "postgresql://")
-    or "postgresql://postgres@localhost:5432/discoursio"
+    environ.get('DATABASE_URL', environ.get('DB_URL', '')).replace('postgres://', 'postgresql://')
+    or 'postgresql://postgres@localhost:5432/discoursio'
 )
-REDIS_URL = environ.get("REDIS_URL") or "redis://127.0.0.1"
-API_BASE = environ.get("API_BASE") or "https://core.discours.io"
-AUTH_URL = environ.get("AUTH_URL") or "https://auth.discours.io"
-MODE = environ.get("MODE") or "production"
-SENTRY_DSN = environ.get("SENTRY_DSN")
-DEV_SERVER_PID_FILE_NAME = "dev-server.pid"
+REDIS_URL = environ.get('REDIS_URL') or 'redis://127.0.0.1'
+API_BASE = environ.get('API_BASE') or 'https://core.discours.io'
+AUTH_URL = environ.get('AUTH_URL') or 'https://auth.discours.io'
+MODE = environ.get('MODE') or 'production'
+SENTRY_DSN = environ.get('SENTRY_DSN')
+DEV_SERVER_PID_FILE_NAME = 'dev-server.pid'
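The .replace('postgres://', 'postgresql://') dance exists because SQLAlchemy 1.4 removed the legacy postgres:// dialect alias while some providers' DATABASE_URL values (Heroku-style) still emit it. A one-line check of the rewrite:

    url = 'postgres://postgres@localhost:5432/discoursio'.replace('postgres://', 'postgresql://')
    assert url == 'postgresql://postgres@localhost:5432/discoursio'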