global session to local session

This commit is contained in:
knst-kotov 2021-08-05 16:49:08 +00:00
parent a0e99e5ba9
commit 93c6f88435
6 changed files with 88 additions and 79 deletions

View File

@ -1,7 +1,7 @@
from auth.password import Password from auth.password import Password
from exceptions import InvalidPassword, ObjectNotExist from exceptions import InvalidPassword, ObjectNotExist
from orm import User as OrmUser from orm import User as OrmUser
from orm.base import global_session from orm.base import local_session
from auth.validations import User from auth.validations import User
from sqlalchemy import or_ from sqlalchemy import or_
@ -10,7 +10,8 @@ from sqlalchemy import or_
class Identity: class Identity:
@staticmethod @staticmethod
def identity(user_id: int, password: str) -> User: def identity(user_id: int, password: str) -> User:
user = global_session.query(OrmUser).filter_by(id=user_id).first() with local_session() as session:
user = session.query(OrmUser).filter_by(id=user_id).first()
if not user: if not user:
raise ObjectNotExist("User does not exist") raise ObjectNotExist("User does not exist")
user = User(**user.dict()) user = User(**user.dict())
@ -22,14 +23,15 @@ class Identity:
@staticmethod @staticmethod
def identity_oauth(input) -> User: def identity_oauth(input) -> User:
user = global_session.query(OrmUser).filter( with local_session() as session:
user = session.query(OrmUser).filter(
or_(OrmUser.oauth_id == input["oauth_id"], OrmUser.email == input["email"]) or_(OrmUser.oauth_id == input["oauth_id"], OrmUser.email == input["email"])
).first() ).first()
if not user: if not user:
user = OrmUser.create(**input) user = OrmUser.create(**input)
if not user.oauth_id: if not user.oauth_id:
user.oauth_id = input["oauth_id"] user.oauth_id = input["oauth_id"]
global_session.commit() session.commit()
user = User(**user.dict()) user = User(**user.dict())
return user return user

View File

@ -2,7 +2,7 @@ from typing import TypeVar, Any, Dict, Generic, Callable
from sqlalchemy import create_engine, Column, Integer from sqlalchemy import create_engine, Column, Integer
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import Session
from sqlalchemy.sql.schema import Table from sqlalchemy.sql.schema import Table
from settings import DB_URL from settings import DB_URL
@ -17,16 +17,16 @@ engine = create_engine(DB_URL,
#pool_recycle=300, #pool_recycle=300,
pool_pre_ping=True, pool_pre_ping=True,
#pool_use_lifo=True #pool_use_lifo=True
future=True
) )
Session = sessionmaker(autocommit=False, autoflush=False, bind=engine, query_cls=RetryingQuery)
#Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
global_session = Session()
T = TypeVar("T") T = TypeVar("T")
REGISTRY: Dict[str, type] = {} REGISTRY: Dict[str, type] = {}
def local_session():
return Session(bind=engine, expire_on_commit=False)
class Base(declarative_base()): class Base(declarative_base()):
__table__: Table __table__: Table
@ -37,7 +37,6 @@ class Base(declarative_base()):
__abstract__: bool = True __abstract__: bool = True
__table_args__ = {"extend_existing": True} __table_args__ = {"extend_existing": True}
id: int = Column(Integer, primary_key=True) id: int = Column(Integer, primary_key=True)
session = global_session
def __init_subclass__(cls, **kwargs): def __init_subclass__(cls, **kwargs):
REGISTRY[cls.__name__] = cls REGISTRY[cls.__name__] = cls
@ -48,8 +47,9 @@ class Base(declarative_base()):
return instance.save() return instance.save()
def save(self) -> Generic[T]: def save(self) -> Generic[T]:
self.session.add(self) with local_session() as session:
self.session.commit() session.add(self)
session.commit()
return self return self
def dict(self) -> Dict[str, Any]: def dict(self) -> Dict[str, Any]:

View File

@ -4,7 +4,7 @@ from typing import Type
from sqlalchemy import String, Column, ForeignKey, types, UniqueConstraint from sqlalchemy import String, Column, ForeignKey, types, UniqueConstraint
from orm.base import Base, REGISTRY, engine, global_session from orm.base import Base, REGISTRY, engine
class ClassType(types.TypeDecorator): class ClassType(types.TypeDecorator):

View File

@ -3,7 +3,7 @@ from typing import List
from sqlalchemy import Column, Integer, String, ForeignKey #, relationship from sqlalchemy import Column, Integer, String, ForeignKey #, relationship
from orm import Permission from orm import Permission
from orm.base import Base from orm.base import Base, local_session
class User(Base): class User(Base):
@ -19,7 +19,8 @@ class User(Base):
@classmethod @classmethod
def get_permission(cls, user_id): def get_permission(cls, user_id):
perms: List[Permission] = cls.session.query(Permission).join(User, User.role_id == Permission.role_id).filter( with local_session() as session:
perms: List[Permission] = session.query(Permission).join(User, User.role_id == Permission.role_id).filter(
User.id == user_id).all() User.id == user_id).all()
return {f"{p.operation_id}-{p.resource_id}" for p in perms} return {f"{p.operation_id}-{p.resource_id}" for p in perms}

View File

@ -6,7 +6,7 @@ from auth.identity import Identity
from auth.password import Password from auth.password import Password
from auth.validations import CreateUser from auth.validations import CreateUser
from orm import User from orm import User
from orm.base import global_session from orm.base import local_session
from resolvers.base import mutation, query from resolvers.base import mutation, query
from exceptions import InvalidPassword from exceptions import InvalidPassword
@ -44,7 +44,8 @@ async def register(*_, email: str, password: str = ""):
@query.field("signIn") @query.field("signIn")
async def sign_in(_, info: GraphQLResolveInfo, email: str, password: str): async def sign_in(_, info: GraphQLResolveInfo, email: str, password: str):
orm_user = global_session.query(User).filter(User.email == email).first() with local_session() as session:
orm_user = session.query(User).filter(User.email == email).first()
if orm_user is None: if orm_user is None:
return {"error" : "invalid email"} return {"error" : "invalid email"}
@ -75,10 +76,12 @@ async def sign_out(_, info: GraphQLResolveInfo):
async def get_user(_, info): async def get_user(_, info):
auth = info.context["request"].auth auth = info.context["request"].auth
user_id = auth.user_id user_id = auth.user_id
user = global_session.query(User).filter(User.id == user_id).first() with local_session() as session:
user = session.query(User).filter(User.id == user_id).first()
return { "user": user } return { "user": user }
@query.field("isEmailFree") @query.field("isEmailFree")
async def is_email_free(_, info, email): async def is_email_free(_, info, email):
user = global_session.query(User).filter(User.email == email).first() with local_session() as session:
user = session.query(User).filter(User.email == email).first()
return user is None return user is None

View File

@ -1,5 +1,5 @@
from orm import Message, User from orm import Message, User
from orm.base import global_session from orm.base import local_session
from resolvers.base import mutation, query, subscription from resolvers.base import mutation, query, subscription
@ -37,12 +37,13 @@ async def get_messages(_, info, count, page):
auth = info.context["request"].auth auth = info.context["request"].auth
user_id = auth.user_id user_id = auth.user_id
messages = global_session.query(Message).filter(Message.author == user_id) with local_session() as session:
messages = session.query(Message).filter(Message.author == user_id)
return messages return messages
def check_and_get_message(message_id, user_id) : def check_and_get_message(message_id, user_id, session) :
message = global_session.query(Message).filter(Message.id == message_id).first() message = session.query(Message).filter(Message.id == message_id).first()
if not message : if not message :
raise Exception("invalid id") raise Exception("invalid id")
@ -58,13 +59,14 @@ async def update_message(_, info, id, body):
auth = info.context["request"].auth auth = info.context["request"].auth
user_id = auth.user_id user_id = auth.user_id
with local_session() as session:
try: try:
message = check_and_get_message(id, user_id) message = check_and_get_message(id, user_id, session)
except Exception as err: except Exception as err:
return {"error" : err} return {"error" : err}
message.body = body message.body = body
global_session.commit() session.commit()
MessageQueue.updated_message.put_nowait(message) MessageQueue.updated_message.put_nowait(message)
@ -76,13 +78,14 @@ async def delete_message(_, info, id):
auth = info.context["request"].auth auth = info.context["request"].auth
user_id = auth.user_id user_id = auth.user_id
with local_session() as session:
try: try:
message = check_and_get_message(id, user_id) message = check_and_get_message(id, user_id, session)
except Exception as err: except Exception as err:
return {"error" : err} return {"error" : err}
global_session.delete(message) session.delete(message)
global_session.commit() session.commit()
MessageQueue.deleted_message.put_nowait(message) MessageQueue.deleted_message.put_nowait(message)