From 9618802c6bde5c84413c78c9a258489fedcb3360 Mon Sep 17 00:00:00 2001 From: knst-kotov Date: Mon, 16 Aug 2021 15:09:04 +0300 Subject: [PATCH 1/3] fix redis client --- redis/client.py | 99 ++++++++++++++++++++++++------------------------- 1 file changed, 49 insertions(+), 50 deletions(-) diff --git a/redis/client.py b/redis/client.py index b11475c4..ae5372f7 100644 --- a/redis/client.py +++ b/redis/client.py @@ -1,50 +1,49 @@ -from typing import Optional - -import aioredis -# from aioredis import ConnectionsPool - -from settings import REDIS_URL - - -class Redis: - def __init__(self, uri=REDIS_URL): - self._uri: str = uri - self._instance = None - - async def connect(self): - if self._instance is not None: - return - self._instance = await aioredis.from_url(self._uri)# .create_pool(self._uri) - - async def disconnect(self): - if self._instance is None: - return - self._instance.close() - await self._instance.wait_closed() - self._instance = None - - async def execute(self, command, *args, **kwargs): - return await self._instance.execute(command, *args, **kwargs, encoding="UTF-8") - - -async def test(): - redis = Redis() - from datetime import datetime - - await redis.connect() - await redis.execute("SET", "1-KEY1", 1) - await redis.execute("SET", "1-KEY2", 1) - await redis.execute("SET", "1-KEY3", 1) - await redis.execute("SET", "1-KEY4", 1) - await redis.execute("EXPIREAT", "1-KEY4", int(datetime.utcnow().timestamp())) - v = await redis.execute("KEYS", "1-*") - print(v) - await redis.execute("DEL", *v) - v = await redis.execute("KEYS", "1-*") - print(v) - - -if __name__ == '__main__': - import asyncio - - asyncio.run(test()) +from typing import Optional + +import aioredis + +from settings import REDIS_URL + + +class Redis: + def __init__(self, uri=REDIS_URL): + self._uri: str = uri + self._instance = None + + async def connect(self): + if self._instance is not None: + return + self._instance = aioredis.from_url(self._uri, encoding="utf-8") + + async def disconnect(self): + if self._instance is None: + return + self._instance.close() + await self._instance.wait_closed() + self._instance = None + + async def execute(self, command, *args, **kwargs): + return await self._instance.execute_command(command, *args, **kwargs) + + +async def test(): + redis = Redis() + from datetime import datetime + + await redis.connect() + await redis.execute("SET", "1-KEY1", 1) + await redis.execute("SET", "1-KEY2", 1) + await redis.execute("SET", "1-KEY3", 1) + await redis.execute("SET", "1-KEY4", 1) + await redis.execute("EXPIREAT", "1-KEY4", int(datetime.utcnow().timestamp())) + v = await redis.execute("KEYS", "1-*") + print(v) + await redis.execute("DEL", *v) + v = await redis.execute("KEYS", "1-*") + print(v) + + +if __name__ == '__main__': + import asyncio + + asyncio.run(test()) From b8b7854c4cf31298515d87cb4bf9fea65617efe6 Mon Sep 17 00:00:00 2001 From: knst-kotov Date: Tue, 17 Aug 2021 12:14:26 +0300 Subject: [PATCH 2/3] improve rbac --- auth/authenticate.py | 1 + auth/credentials.py | 2 +- create_crt.sh | 20 +++++----- orm/__init__.py | 4 +- orm/rbac.py | 94 +++++++++++++++++++++++++++----------------- orm/shout.py | 2 +- orm/user.py | 23 ++++++++--- resolvers/zine.py | 41 +++++++++++++++++-- schema.graphql | 7 ++-- 9 files changed, 132 insertions(+), 62 deletions(-) mode change 100755 => 100644 create_crt.sh diff --git a/auth/authenticate.py b/auth/authenticate.py index 15b92e74..354fc37a 100644 --- a/auth/authenticate.py +++ b/auth/authenticate.py @@ -63,6 +63,7 @@ class JWTAuthenticate(AuthenticationBackend): return AuthCredentials(scopes=[], error_message=str(exc)), AuthUser(user_id=None) scopes = User.get_permission(user_id=payload.user_id) + print(scopes) return AuthCredentials(user_id=payload.user_id, scopes=scopes, logged_in=True), AuthUser(user_id=payload.user_id) diff --git a/auth/credentials.py b/auth/credentials.py index 47300b34..097b5936 100644 --- a/auth/credentials.py +++ b/auth/credentials.py @@ -9,7 +9,7 @@ class Permission(BaseModel): class AuthCredentials(BaseModel): user_id: Optional[int] = None - scopes: Optional[set] = {} + scopes: Optional[dict] = {} logged_in: bool = False error_message: str = "" diff --git a/create_crt.sh b/create_crt.sh old mode 100755 new mode 100644 index d36eec87..3867257a --- a/create_crt.sh +++ b/create_crt.sh @@ -1,10 +1,10 @@ -#!/bin/bash - -openssl req -newkey rsa:4096 \ - -x509 \ - -sha256 \ - -days 3650 \ - -nodes \ - -out discours.crt \ - -keyout discours.key \ - -subj "/C=RU/ST=Moscow/L=Moscow/O=Discours/OU=Site/CN=test-api.discours.io" +#!/bin/bash + +openssl req -newkey rsa:4096 \ + -x509 \ + -sha256 \ + -days 3650 \ + -nodes \ + -out discours.crt \ + -keyout discours.key \ + -subj "/C=RU/ST=Moscow/L=Moscow/O=Discours/OU=Site/CN=test-api.discours.io" diff --git a/orm/__init__.py b/orm/__init__.py index 09586bf8..864b432a 100644 --- a/orm/__init__.py +++ b/orm/__init__.py @@ -1,4 +1,4 @@ -from orm.rbac import Operation, Permission, Role +from orm.rbac import Organization, Operation, Resource, Permission, Role from orm.user import User from orm.message import Message from orm.shout import Shout @@ -7,3 +7,5 @@ from orm.base import Base, engine __all__ = ["User", "Role", "Operation", "Permission", "Message", "Shout"] Base.metadata.create_all(engine) +Operation.init_table() +Resource.init_table() diff --git a/orm/rbac.py b/orm/rbac.py index bf3659a5..5a19438d 100644 --- a/orm/rbac.py +++ b/orm/rbac.py @@ -3,63 +3,85 @@ import warnings from typing import Type from sqlalchemy import String, Column, ForeignKey, types, UniqueConstraint +from sqlalchemy.orm import relationship -from orm.base import Base, REGISTRY, engine +from orm.base import Base, REGISTRY, engine, local_session class ClassType(types.TypeDecorator): - impl = types.String + impl = types.String - @property - def python_type(self): - return NotImplemented + @property + def python_type(self): + return NotImplemented - def process_literal_param(self, value, dialect): - return NotImplemented + def process_literal_param(self, value, dialect): + return NotImplemented - def process_bind_param(self, value, dialect): - return value.__name__ if isinstance(value, type) else str(value) + def process_bind_param(self, value, dialect): + return value.__name__ if isinstance(value, type) else str(value) - def process_result_value(self, value, dialect): - class_ = REGISTRY.get(value) - if class_ is None: - warnings.warn(f"Can't find class <{value}>,find it yourself 😊", stacklevel=2) - return class_ + def process_result_value(self, value, dialect): + class_ = REGISTRY.get(value) + if class_ is None: + warnings.warn(f"Can't find class <{value}>,find it yourself 😊", stacklevel=2) + return class_ +class Organization(Base): + __tablename__ = 'organization' + name: str = Column(String, nullable=False, unique=True, comment="Organization Name") class Role(Base): - __tablename__ = 'role' - name: str = Column(String, nullable=False, unique=True, comment="Role Name") + __tablename__ = 'role' + name: str = Column(String, nullable=False, unique=True, comment="Role Name") + org_id: int = Column(ForeignKey("organization.id", ondelete="CASCADE"), nullable=False, comment="Organization") + permissions = relationship("Permission") class Operation(Base): - __tablename__ = 'operation' - name: str = Column(String, nullable=False, unique=True, comment="Operation Name") + __tablename__ = 'operation' + name: str = Column(String, nullable=False, unique=True, comment="Operation Name") + + @staticmethod + def init_table(): + with local_session() as session: + edit_op = session.query(Operation).filter(Operation.name == "edit").first() + if not edit_op: + edit_op = Operation.create(name = "edit") + Operation.edit_id = edit_op.id class Resource(Base): - __tablename__ = "resource" - resource_class: Type[Base] = Column(ClassType, nullable=False, unique=True, comment="Resource class") - name: str = Column(String, nullable=False, unique=True, comment="Resource name") + __tablename__ = "resource" + resource_class: Type[Base] = Column(ClassType, nullable=False, unique=True, comment="Resource class") + name: str = Column(String, nullable=False, unique=True, comment="Resource name") + + @staticmethod + def init_table(): + with local_session() as session: + shout_res = session.query(Resource).filter(Resource.name == "shout").first() + if not shout_res: + shout_res = Resource.create(name = "shout", resource_class = "shout") + Resource.shout_id = shout_res.id class Permission(Base): - __tablename__ = "permission" - __table_args__ = (UniqueConstraint("role_id", "operation_id", "resource_id"), {"extend_existing": True}) + __tablename__ = "permission" + __table_args__ = (UniqueConstraint("role_id", "operation_id", "resource_id"), {"extend_existing": True}) - role_id: int = Column(ForeignKey("role.id", ondelete="CASCADE"), nullable=False, comment="Role") - operation_id: int = Column(ForeignKey("operation.id", ondelete="CASCADE"), nullable=False, comment="Operation") - resource_id: int = Column(ForeignKey("operation.id", ondelete="CASCADE"), nullable=False, comment="Resource") + role_id: int = Column(ForeignKey("role.id", ondelete="CASCADE"), nullable=False, comment="Role") + operation_id: int = Column(ForeignKey("operation.id", ondelete="CASCADE"), nullable=False, comment="Operation") + resource_id: int = Column(ForeignKey("operation.id", ondelete="CASCADE"), nullable=False, comment="Resource") if __name__ == '__main__': - Base.metadata.create_all(engine) - ops = [ - Permission(role_id=1, operation_id=1, resource_id=1), - Permission(role_id=1, operation_id=2, resource_id=1), - Permission(role_id=1, operation_id=3, resource_id=1), - Permission(role_id=1, operation_id=4, resource_id=1), - Permission(role_id=2, operation_id=4, resource_id=1) - ] - global_session.add_all(ops) - global_session.commit() + Base.metadata.create_all(engine) + ops = [ + Permission(role_id=1, operation_id=1, resource_id=1), + Permission(role_id=1, operation_id=2, resource_id=1), + Permission(role_id=1, operation_id=3, resource_id=1), + Permission(role_id=1, operation_id=4, resource_id=1), + Permission(role_id=2, operation_id=4, resource_id=1) + ] + global_session.add_all(ops) + global_session.commit() diff --git a/orm/shout.py b/orm/shout.py index aecef27b..dd43f5ec 100644 --- a/orm/shout.py +++ b/orm/shout.py @@ -12,7 +12,7 @@ class Shout(Base): id = None slug: str = Column(String, primary_key=True) - org: str = Column(String, nullable=False) + org_id: str = Column(ForeignKey("organization.id"), nullable=False) author_id: str = Column(ForeignKey("user.id"), nullable=False, comment="Author") body: str = Column(String, nullable=False, comment="Body") createdAt: str = Column(DateTime, nullable=False, default = datetime.now, comment="Created at") diff --git a/orm/user.py b/orm/user.py index 217565fa..7fec296f 100644 --- a/orm/user.py +++ b/orm/user.py @@ -1,11 +1,19 @@ from typing import List -from sqlalchemy import Column, Integer, String, ForeignKey #, relationship +from sqlalchemy import Column, Integer, String, ForeignKey +from sqlalchemy.orm import relationship from orm import Permission from orm.base import Base, local_session +class UserRole(Base): + __tablename__ = 'user_role' + + id = None + user_id: int = Column(ForeignKey("user.id"), primary_key = True) + role_id: int = Column(ForeignKey("role.id"), primary_key = True) + class User(Base): __tablename__ = 'user' @@ -13,16 +21,19 @@ class User(Base): username: str = Column(String, nullable=False, comment="Name") password: str = Column(String, nullable=True, comment="Password") - role_id: list = Column(ForeignKey("role.id"), nullable=True, comment="Role") - # roles = relationship("Role") TODO: one to many, see schema.graphql oauth_id: str = Column(String, nullable=True) + roles = relationship("Role", secondary=UserRole.__table__) + @classmethod def get_permission(cls, user_id): + scope = {} with local_session() as session: - perms: List[Permission] = session.query(Permission).join(User, User.role_id == Permission.role_id).filter( - User.id == user_id).all() - return {f"{p.operation_id}-{p.resource_id}" for p in perms} + user = session.query(User).filter(User.id == user_id).first() + for role in user.roles: + for p in role.permissions: + scope[p.resource_id] = p.operation_id + return scope if __name__ == '__main__': diff --git a/resolvers/zine.py b/resolvers/zine.py index aa3bcdbd..97d311b3 100644 --- a/resolvers/zine.py +++ b/resolvers/zine.py @@ -1,4 +1,4 @@ -from orm import Shout, User +from orm import Shout, User, Organization from orm.base import local_session from resolvers.base import mutation, query @@ -15,10 +15,10 @@ class GitTask: queue = asyncio.Queue() - def __init__(self, input, username, user_email, comment): + def __init__(self, input, org, username, user_email, comment): self.slug = input["slug"]; - self.org = input["org"]; self.shout_body = input["body"]; + self.org = org; self.username = username; self.user_email = user_email; self.comment = comment; @@ -84,12 +84,19 @@ async def create_shout(_, info, input): auth = info.context["request"].auth user_id = auth.user_id + org_id = org = input["org_id"] with local_session() as session: user = session.query(User).filter(User.id == user_id).first() + org = session.query(Organization).filter(Organization.id == org_id).first() + + if not org: + return { + "error" : "invalid organization" + } new_shout = Shout.create( slug = input["slug"], - org = input["org"], + org_id = org_id, author_id = user_id, body = input["body"], replyTo = input.get("replyTo"), @@ -100,6 +107,7 @@ async def create_shout(_, info, input): task = GitTask( input, + org.name, user.username, user.email, "new shout %s" % (new_shout.slug) @@ -109,5 +117,30 @@ async def create_shout(_, info, input): "shout" : new_shout } +@mutation.field("updateShout") +@login_required +async def update_shout(_, info, shout_id, input): + auth = info.context["request"].auth + user_id = auth.user_id + + with local_session() as session: + user = session.query(User).filter(User.id == user_id).first() + shout = session.query(Shout).filter(Shout.id == shout_id).first() + + if not shout: + return { + "error" : "shout not found" + } + + if shout.author_id != user_id: + scope = info.context["request"].scope + if not Resource.shout_id in scope: + return { + "error" : "access denied" + } + + return { + "shout" : shout + } # TODO: paginate, get, update, delete diff --git a/schema.graphql b/schema.graphql index 2cb59bf7..7945ec42 100644 --- a/schema.graphql +++ b/schema.graphql @@ -23,7 +23,7 @@ type MessageResult { } input ShoutInput { - org: String! + org_id: Int! slug: String! body: String! replyTo: String # another shout @@ -61,10 +61,11 @@ type Mutation { # invalidateTokenById(id: Int!): Boolean! # requestEmailConfirmation: User! # requestPasswordReset(email: String!): Boolean! - registerUser(email: String!, password: String!): AuthResult! + registerUser(email: String!, password: String!): AuthResult! # shout createShout(input: ShoutInput!): ShoutResult! + updateShout(input: ShoutInput!): ShoutResult! deleteShout(slug: String!): Result! rateShout(slug: String!, value: Int!): Result! @@ -151,7 +152,7 @@ type Message { # is publication type Shout { - org: String! + org_id: Int! slug: String! author: Int! body: String! From 1ce88a351589a28d4261d6a5e4227ba4117dbfe7 Mon Sep 17 00:00:00 2001 From: knst-kotov Date: Wed, 18 Aug 2021 19:53:55 +0300 Subject: [PATCH 3/3] updateShout --- auth/authenticate.py | 104 ++++++++++++++++++++++--------------------- orm/user.py | 6 ++- resolvers/zine.py | 31 ++++++++++--- 3 files changed, 83 insertions(+), 58 deletions(-) diff --git a/auth/authenticate.py b/auth/authenticate.py index 354fc37a..5c4764c8 100644 --- a/auth/authenticate.py +++ b/auth/authenticate.py @@ -15,63 +15,65 @@ from settings import JWT_AUTH_HEADER class _Authenticate: - @classmethod - async def verify(cls, token: str): - """ - Rules for a token to be valid. - 1. token format is legal && - token exists in redis database && - token is not expired - 2. token format is legal && - token exists in redis database && - token is expired && - token is of specified type - """ - try: - payload = Token.decode(token) - except ExpiredSignatureError: - payload = Token.decode(token, verify_exp=False) - if not await cls.exists(payload.user_id, token): - raise InvalidToken("Login expired, please login again") - if payload.device == "mobile": # noqa - "we cat set mobile token to be valid forever" - return payload - except DecodeError as e: - raise InvalidToken("token format error") from e - else: - if not await cls.exists(payload.user_id, token): - raise InvalidToken("Login expired, please login again") - return payload + @classmethod + async def verify(cls, token: str): + """ + Rules for a token to be valid. + 1. token format is legal && + token exists in redis database && + token is not expired + 2. token format is legal && + token exists in redis database && + token is expired && + token is of specified type + """ + try: + payload = Token.decode(token) + except ExpiredSignatureError: + payload = Token.decode(token, verify_exp=False) + if not await cls.exists(payload.user_id, token): + raise InvalidToken("Login expired, please login again") + if payload.device == "mobile": # noqa + "we cat set mobile token to be valid forever" + return payload + except DecodeError as e: + raise InvalidToken("token format error") from e + else: + if not await cls.exists(payload.user_id, token): + raise InvalidToken("Login expired, please login again") + return payload - @classmethod - async def exists(cls, user_id, token): - token = await redis.execute("GET", f"{user_id}-{token}") - return token is not None + @classmethod + async def exists(cls, user_id, token): + token = await redis.execute("GET", f"{user_id}-{token}") + return token is not None class JWTAuthenticate(AuthenticationBackend): - async def authenticate( - self, request: HTTPConnection - ) -> Optional[Tuple[AuthCredentials, AuthUser]]: - if JWT_AUTH_HEADER not in request.headers: - return AuthCredentials(scopes=[]), AuthUser(user_id=None) + async def authenticate( + self, request: HTTPConnection + ) -> Optional[Tuple[AuthCredentials, AuthUser]]: + if JWT_AUTH_HEADER not in request.headers: + return AuthCredentials(scopes=[]), AuthUser(user_id=None) - token = request.headers[JWT_AUTH_HEADER] - try: - payload = await _Authenticate.verify(token) - except Exception as exc: - return AuthCredentials(scopes=[], error_message=str(exc)), AuthUser(user_id=None) + token = request.headers[JWT_AUTH_HEADER] + try: + payload = await _Authenticate.verify(token) + except Exception as exc: + return AuthCredentials(scopes=[], error_message=str(exc)), AuthUser(user_id=None) + + if payload is None: + return AuthCredentials(scopes=[]), AuthUser(user_id=None) - scopes = User.get_permission(user_id=payload.user_id) - print(scopes) - return AuthCredentials(user_id=payload.user_id, scopes=scopes, logged_in=True), AuthUser(user_id=payload.user_id) + scopes = User.get_permission(user_id=payload.user_id) + return AuthCredentials(user_id=payload.user_id, scopes=scopes, logged_in=True), AuthUser(user_id=payload.user_id) def login_required(func): - @wraps(func) - async def wrap(parent, info: GraphQLResolveInfo, *args, **kwargs): - auth: AuthCredentials = info.context["request"].auth - if not auth.logged_in: - return {"error" : auth.error_message or "Please login"} - return await func(parent, info, *args, **kwargs) - return wrap + @wraps(func) + async def wrap(parent, info: GraphQLResolveInfo, *args, **kwargs): + auth: AuthCredentials = info.context["request"].auth + if not auth.logged_in: + return {"error" : auth.error_message or "Please login"} + return await func(parent, info, *args, **kwargs) + return wrap diff --git a/orm/user.py b/orm/user.py index 7fec296f..804c7b2b 100644 --- a/orm/user.py +++ b/orm/user.py @@ -17,7 +17,7 @@ class UserRole(Base): class User(Base): __tablename__ = 'user' - email: str = Column(String, nullable=False) + email: str = Column(String, unique=True, nullable=False) username: str = Column(String, nullable=False, comment="Name") password: str = Column(String, nullable=True, comment="Password") @@ -32,7 +32,9 @@ class User(Base): user = session.query(User).filter(User.id == user_id).first() for role in user.roles: for p in role.permissions: - scope[p.resource_id] = p.operation_id + if not p.resource_id in scope: + scope[p.resource_id] = set() + scope[p.resource_id].add(p.operation_id) return scope diff --git a/resolvers/zine.py b/resolvers/zine.py index 97d311b3..581f820b 100644 --- a/resolvers/zine.py +++ b/resolvers/zine.py @@ -1,4 +1,4 @@ -from orm import Shout, User, Organization +from orm import Shout, User, Organization, Resource from orm.base import local_session from resolvers.base import mutation, query @@ -119,13 +119,16 @@ async def create_shout(_, info, input): @mutation.field("updateShout") @login_required -async def update_shout(_, info, shout_id, input): +async def update_shout(_, info, input): auth = info.context["request"].auth user_id = auth.user_id + slug = input["slug"] + org_id = org = input["org_id"] with local_session() as session: user = session.query(User).filter(User.id == user_id).first() - shout = session.query(Shout).filter(Shout.id == shout_id).first() + shout = session.query(Shout).filter(Shout.slug == slug).first() + org = session.query(Organization).filter(Organization.id == org_id).first() if not shout: return { @@ -133,12 +136,30 @@ async def update_shout(_, info, shout_id, input): } if shout.author_id != user_id: - scope = info.context["request"].scope - if not Resource.shout_id in scope: + scopes = auth.scopes + print(scopes) + if not Resource.shout_id in scopes: return { "error" : "access denied" } + shout.body = input["body"], + shout.replyTo = input.get("replyTo"), + shout.versionOf = input.get("versionOf"), + shout.tags = input.get("tags"), + shout.topics = input.get("topics") + + with local_session() as session: + session.commit() + + task = GitTask( + input, + org.name, + user.username, + user.email, + "update shout %s" % (shout.slug) + ) + return { "shout" : shout }