diff --git a/auth/identity.py b/auth/identity.py index cfb4093a..6c446a79 100644 --- a/auth/identity.py +++ b/auth/identity.py @@ -1,7 +1,7 @@ from auth.password import Password from exceptions import InvalidPassword, ObjectNotExist from orm import User as OrmUser -from orm.base import global_session +from orm.base import local_session from auth.validations import User from sqlalchemy import or_ @@ -10,7 +10,8 @@ from sqlalchemy import or_ class Identity: @staticmethod def identity(user_id: int, password: str) -> User: - user = global_session.query(OrmUser).filter_by(id=user_id).first() + with local_session() as session: + user = session.query(OrmUser).filter_by(id=user_id).first() if not user: raise ObjectNotExist("User does not exist") user = User(**user.dict()) @@ -22,14 +23,15 @@ class Identity: @staticmethod def identity_oauth(input) -> User: - user = global_session.query(OrmUser).filter( - or_(OrmUser.oauth_id == input["oauth_id"], OrmUser.email == input["email"]) - ).first() - if not user: - user = OrmUser.create(**input) - if not user.oauth_id: - user.oauth_id = input["oauth_id"] - global_session.commit() + with local_session() as session: + user = session.query(OrmUser).filter( + or_(OrmUser.oauth_id == input["oauth_id"], OrmUser.email == input["email"]) + ).first() + if not user: + user = OrmUser.create(**input) + if not user.oauth_id: + user.oauth_id = input["oauth_id"] + session.commit() user = User(**user.dict()) return user diff --git a/orm/base.py b/orm/base.py index 1bda41b2..b8a8d8de 100644 --- a/orm/base.py +++ b/orm/base.py @@ -2,7 +2,7 @@ from typing import TypeVar, Any, Dict, Generic, Callable from sqlalchemy import create_engine, Column, Integer from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session from sqlalchemy.sql.schema import Table from settings import DB_URL @@ -10,48 +10,48 @@ from orm._retry import RetryingQuery # engine = create_engine(DB_URL, convert_unicode=True, echo=False) engine = create_engine(DB_URL, - convert_unicode=True, - echo=False, - #pool_size=10, - #max_overflow=2, - #pool_recycle=300, - pool_pre_ping=True, - #pool_use_lifo=True - ) - -Session = sessionmaker(autocommit=False, autoflush=False, bind=engine, query_cls=RetryingQuery) -#Session = sessionmaker(autocommit=False, autoflush=False, bind=engine) -global_session = Session() + convert_unicode=True, + echo=False, + #pool_size=10, + #max_overflow=2, + #pool_recycle=300, + pool_pre_ping=True, + #pool_use_lifo=True + future=True + ) T = TypeVar("T") REGISTRY: Dict[str, type] = {} +def local_session(): + return Session(bind=engine, expire_on_commit=False) + class Base(declarative_base()): - __table__: Table - __tablename__: str - __new__: Callable - __init__: Callable + __table__: Table + __tablename__: str + __new__: Callable + __init__: Callable - __abstract__: bool = True - __table_args__ = {"extend_existing": True} - id: int = Column(Integer, primary_key=True) - session = global_session + __abstract__: bool = True + __table_args__ = {"extend_existing": True} + id: int = Column(Integer, primary_key=True) - def __init_subclass__(cls, **kwargs): - REGISTRY[cls.__name__] = cls + def __init_subclass__(cls, **kwargs): + REGISTRY[cls.__name__] = cls - @classmethod - def create(cls: Generic[T], **kwargs) -> Generic[T]: - instance = cls(**kwargs) - return instance.save() + @classmethod + def create(cls: Generic[T], **kwargs) -> Generic[T]: + instance = cls(**kwargs) + return instance.save() - def save(self) -> Generic[T]: - self.session.add(self) - self.session.commit() - return self + def save(self) -> Generic[T]: + with local_session() as session: + session.add(self) + session.commit() + return self - def dict(self) -> Dict[str, Any]: - column_names = self.__table__.columns.keys() - return {c: getattr(self, c) for c in column_names} + def dict(self) -> Dict[str, Any]: + column_names = self.__table__.columns.keys() + return {c: getattr(self, c) for c in column_names} diff --git a/orm/rbac.py b/orm/rbac.py index c8924fa4..bf3659a5 100644 --- a/orm/rbac.py +++ b/orm/rbac.py @@ -4,7 +4,7 @@ from typing import Type from sqlalchemy import String, Column, ForeignKey, types, UniqueConstraint -from orm.base import Base, REGISTRY, engine, global_session +from orm.base import Base, REGISTRY, engine class ClassType(types.TypeDecorator): diff --git a/orm/user.py b/orm/user.py index 5ef9fb0b..217565fa 100644 --- a/orm/user.py +++ b/orm/user.py @@ -3,26 +3,27 @@ from typing import List from sqlalchemy import Column, Integer, String, ForeignKey #, relationship from orm import Permission -from orm.base import Base +from orm.base import Base, local_session class User(Base): - __tablename__ = 'user' + __tablename__ = 'user' - email: str = Column(String, nullable=False) - username: str = Column(String, nullable=False, comment="Name") - password: str = Column(String, nullable=True, comment="Password") + email: str = Column(String, nullable=False) + username: str = Column(String, nullable=False, comment="Name") + password: str = Column(String, nullable=True, comment="Password") - role_id: list = Column(ForeignKey("role.id"), nullable=True, comment="Role") - # roles = relationship("Role") TODO: one to many, see schema.graphql - oauth_id: str = Column(String, nullable=True) + role_id: list = Column(ForeignKey("role.id"), nullable=True, comment="Role") + # roles = relationship("Role") TODO: one to many, see schema.graphql + oauth_id: str = Column(String, nullable=True) - @classmethod - def get_permission(cls, user_id): - perms: List[Permission] = cls.session.query(Permission).join(User, User.role_id == Permission.role_id).filter( - User.id == user_id).all() - return {f"{p.operation_id}-{p.resource_id}" for p in perms} + @classmethod + def get_permission(cls, user_id): + with local_session() as session: + perms: List[Permission] = session.query(Permission).join(User, User.role_id == Permission.role_id).filter( + User.id == user_id).all() + return {f"{p.operation_id}-{p.resource_id}" for p in perms} if __name__ == '__main__': - print(User.get_permission(user_id=1)) + print(User.get_permission(user_id=1)) diff --git a/resolvers/auth.py b/resolvers/auth.py index 50aa3cbc..18b444a7 100644 --- a/resolvers/auth.py +++ b/resolvers/auth.py @@ -6,7 +6,7 @@ from auth.identity import Identity from auth.password import Password from auth.validations import CreateUser from orm import User -from orm.base import global_session +from orm.base import local_session from resolvers.base import mutation, query from exceptions import InvalidPassword @@ -44,7 +44,8 @@ async def register(*_, email: str, password: str = ""): @query.field("signIn") async def sign_in(_, info: GraphQLResolveInfo, email: str, password: str): - orm_user = global_session.query(User).filter(User.email == email).first() + with local_session() as session: + orm_user = session.query(User).filter(User.email == email).first() if orm_user is None: return {"error" : "invalid email"} @@ -75,10 +76,12 @@ async def sign_out(_, info: GraphQLResolveInfo): async def get_user(_, info): auth = info.context["request"].auth user_id = auth.user_id - user = global_session.query(User).filter(User.id == user_id).first() + with local_session() as session: + user = session.query(User).filter(User.id == user_id).first() return { "user": user } @query.field("isEmailFree") async def is_email_free(_, info, email): - user = global_session.query(User).filter(User.email == email).first() + with local_session() as session: + user = session.query(User).filter(User.email == email).first() return user is None diff --git a/resolvers/inbox.py b/resolvers/inbox.py index 9104ac24..a6f710cb 100644 --- a/resolvers/inbox.py +++ b/resolvers/inbox.py @@ -1,5 +1,5 @@ from orm import Message, User -from orm.base import global_session +from orm.base import local_session from resolvers.base import mutation, query, subscription @@ -37,12 +37,13 @@ async def get_messages(_, info, count, page): auth = info.context["request"].auth user_id = auth.user_id - messages = global_session.query(Message).filter(Message.author == user_id) + with local_session() as session: + messages = session.query(Message).filter(Message.author == user_id) return messages -def check_and_get_message(message_id, user_id) : - message = global_session.query(Message).filter(Message.id == message_id).first() +def check_and_get_message(message_id, user_id, session) : + message = session.query(Message).filter(Message.id == message_id).first() if not message : raise Exception("invalid id") @@ -58,13 +59,14 @@ async def update_message(_, info, id, body): auth = info.context["request"].auth user_id = auth.user_id - try: - message = check_and_get_message(id, user_id) - except Exception as err: - return {"error" : err} + with local_session() as session: + try: + message = check_and_get_message(id, user_id, session) + except Exception as err: + return {"error" : err} - message.body = body - global_session.commit() + message.body = body + session.commit() MessageQueue.updated_message.put_nowait(message) @@ -76,13 +78,14 @@ async def delete_message(_, info, id): auth = info.context["request"].auth user_id = auth.user_id - try: - message = check_and_get_message(id, user_id) - except Exception as err: - return {"error" : err} + with local_session() as session: + try: + message = check_and_get_message(id, user_id, session) + except Exception as err: + return {"error" : err} - global_session.delete(message) - global_session.commit() + session.delete(message) + session.commit() MessageQueue.deleted_message.put_nowait(message)