diff --git a/orm/__init__.py b/orm/__init__.py index d2e94d21..a15a6561 100644 --- a/orm/__init__.py +++ b/orm/__init__.py @@ -16,5 +16,5 @@ Operation.init_table() Resource.init_table() with local_session() as session: - rating_storage = ShoutRatingStorage(session) + ShoutRatingStorage.init(session) ShoutViewStorage.init(session) diff --git a/orm/shout.py b/orm/shout.py index ce26bceb..56b03da7 100644 --- a/orm/shout.py +++ b/orm/shout.py @@ -43,21 +43,30 @@ class ShoutRating(Base): class ShoutRatingStorage: - def __init__(self, session): - self.ratings = session.query(ShoutRating).all() + ratings = [] - def get_rating(self, shout_id): - shout_ratings = list(filter(lambda x: x.shout_id == shout_id, self.ratings)) + lock = asyncio.Lock() + + @staticmethod + def init(session): + ShoutRatingStorage.ratings = session.query(ShoutRating).all() + + @staticmethod + async def get_rating(shout_id): + async with ShoutRatingStorage.lock: + shout_ratings = list(filter(lambda x: x.shout_id == shout_id, ShoutRatingStorage.ratings)) return reduce((lambda x, y: x + y.value), shout_ratings, 0) - def update_rating(self, new_rating): - rating = next((x for x in self.ratings \ - if x.rater_id == new_rating.rater_id and x.shout_id == new_rating.shout_id), None) - if rating: - rating.value = new_rating.value - rating.ts = new_rating.ts - else: - self.ratings.append(new_rating) + @staticmethod + async def update_rating(new_rating): + async with ShoutRatingStorage.lock: + rating = next((x for x in ShoutRatingStorage.ratings \ + if x.rater_id == new_rating.rater_id and x.shout_id == new_rating.shout_id), None) + if rating: + rating.value = new_rating.value + rating.ts = new_rating.ts + else: + ShoutRatingStorage.ratings.append(new_rating) class ShoutViewByDay(Base): diff --git a/resolvers/zine.py b/resolvers/zine.py index 7703bd24..0e9fa89e 100644 --- a/resolvers/zine.py +++ b/resolvers/zine.py @@ -1,5 +1,5 @@ from orm import Shout, ShoutAuthor, ShoutTopic, ShoutRating, ShoutViewByDay, User, Community, Resource,\ - rating_storage, ShoutViewStorage + ShoutRatingStorage, ShoutViewStorage from orm.base import local_session from resolvers.base import mutation, query @@ -280,7 +280,7 @@ async def rate_shout(_, info, shout_id, value): value = value ) - rating_storage.update_rating(rating) + await ShoutRatingStorage.update_rating(rating) return {"error" : ""} @@ -299,6 +299,6 @@ async def get_shout_by_slug(_, info, slug): shout = session.query(Shout).\ options(select_options).\ filter(Shout.slug == slug).first() - shout.rating = rating_storage.get_rating(shout.id) + shout.rating = await ShoutRatingStorage.get_rating(shout.id) shout.views = await ShoutViewStorage.get_view(shout.id) return shout