diff --git a/base/redis.py b/base/redis.py index d5d4babd..0bab8c12 100644 --- a/base/redis.py +++ b/base/redis.py @@ -1,45 +1,61 @@ -from asyncio import sleep - -from aioredis import from_url +import redis.asyncio as aredis from settings import REDIS_URL +import logging + +logger = logging.getLogger("[services.redis] ") +logger.setLevel(logging.DEBUG) class RedisCache: def __init__(self, uri=REDIS_URL): self._uri: str = uri - self._instance = None + self.pubsub_channels = [] + self._client = None async def connect(self): - if self._instance is not None: - return - self._instance = await from_url(self._uri, encoding="utf-8") - # print(self._instance) + self._client = aredis.Redis.from_url(self._uri, decode_responses=True) async def disconnect(self): - if self._instance is None: - return - await self._instance.close() - # await self._instance.wait_closed() # deprecated - self._instance = None + if self._client: + await self._client.close() async def execute(self, command, *args, **kwargs): - while not self._instance: - await sleep(1) - try: - # print("[redis] " + command + ' ' + ' '.join(args)) - return await self._instance.execute_command(command, *args, **kwargs) - except Exception: - pass + if self._client: + try: + logger.debug(command + " " + " ".join(args)) + r = await self._client.execute_command(command, *args, **kwargs) + logger.debug(type(r)) + logger.debug(r) + return r + except Exception as e: + logger.error(e) + + async def subscribe(self, *channels): + if self._client: + async with self._client.pubsub() as pubsub: + for channel in channels: + await pubsub.subscribe(channel) + self.pubsub_channels.append(channel) + + async def unsubscribe(self, *channels): + if not self._client: + return + async with self._client.pubsub() as pubsub: + for channel in channels: + await pubsub.unsubscribe(channel) + self.pubsub_channels.remove(channel) + + async def publish(self, channel, data): + if not self._client: + return + await self._client.publish(channel, data) + + async def mget(self, *keys): + return await self.execute('MGET', *keys) async def lrange(self, key, start, stop): - # print(f"[redis] LRANGE {key} {start} {stop}") - return await self._instance.lrange(key, start, stop) - - async def mget(self, key, *keys): - # print(f"[redis] MGET {key} {keys}") - return await self._instance.mget(key, *keys) - + return await self.execute('LRANGE', key, start, stop) redis = RedisCache() diff --git a/requirements.txt b/requirements.txt index d25fc230..d480ee9e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,4 @@ aiohttp -aioredis~=2.0.1 alembic==1.11.3 ariadne>=0.17.0 asyncio~=3.4.3 @@ -34,3 +33,5 @@ sse-starlette==1.6.5 starlette~=0.23.1 transliterate~=1.10.2 uvicorn>=0.18.3 + +redis