diff --git a/src/fastapi_app/lib/db/brokers/base_broker.py b/src/fastapi_app/lib/db/brokers/base_broker.py index bf2a5b5..bf1a503 100644 --- a/src/fastapi_app/lib/db/brokers/base_broker.py +++ b/src/fastapi_app/lib/db/brokers/base_broker.py @@ -19,7 +19,3 @@ class BasePublisher(abc.ABC): @abc.abstractmethod async def publish_message(self, message_body: api_schemas.broker_message.BrokerMessage, routing_key: str) -> None: pass - - @abc.abstractmethod - async def get_connection(self) -> contextlib.AbstractAsyncContextManager[T]: - pass diff --git a/src/fastapi_app/lib/db/brokers/broker.py b/src/fastapi_app/lib/db/brokers/broker.py index 396f4cf..68f8d39 100644 --- a/src/fastapi_app/lib/db/brokers/broker.py +++ b/src/fastapi_app/lib/db/brokers/broker.py @@ -16,8 +16,3 @@ class BrokerPublisher: async def publish_message(self, message_body: api_schemas.BrokerMessage, routing_key: str): await self.broker.publish_message(message_body, routing_key) - - @contextlib.asynccontextmanager - async def get_connection(self) -> typing.AsyncGenerator: - async with self.broker.get_connection() as conn: - yield conn diff --git a/src/fastapi_app/lib/db/brokers/rabbitmq.py b/src/fastapi_app/lib/db/brokers/rabbitmq.py index 3a2bbeb..251b476 100644 --- a/src/fastapi_app/lib/db/brokers/rabbitmq.py +++ b/src/fastapi_app/lib/db/brokers/rabbitmq.py @@ -1,6 +1,7 @@ import asyncio import contextlib import json +import logging import aio_pika @@ -13,37 +14,54 @@ class RabbitMQPublisher(db_brokers.base_broker.BasePublisher): def __init__(self, settings: app_split_settings.RabbitMQSettings()): self.settings = settings self.connection = None - self.channel = None - self.pool = asyncio.Queue() - self.pool_size = settings.max_pool_size + self.logger = logging.getLogger(__name__) + self.pool = asyncio.Queue(maxsize=settings.max_pool_size) async def connect(self): - self.connection = await aio_pika.connect(self.settings.amqp_url) - self.channel = await self.connection.channel() - await self.channel.set_qos(prefetch_count=1) - exchange = await self.channel.declare_exchange(self.settings.exchange, aio_pika.ExchangeType.DIRECT) + try: + self.connection = await aio_pika.connect(self.settings.amqp_url) - for attr_name, attr_value in vars(self.settings.Queues).items(): - if not attr_name.startswith("__"): - queue = await self.channel.declare_queue(attr_value.value) - await queue.bind(exchange, attr_value.value) + for _ in range(self.settings.max_pool_size): + channel = await self.connection.channel() + await channel.set_qos(prefetch_count=1) + await self.pool.put(channel) + + except Exception as e: + self.logger.error(f"Failed to dispose resources: {e}") + raise async def dispose(self): - await self.channel.close() - await self.connection.close() + try: + while not self.pool.empty(): + channel = await self.pool.get() + await channel.close() - async def publish_message(self, message_body: api_schemas.BrokerMessage, routing_key: str): - message = aio_pika.Message(content_type="application/json", body=json.dumps(message_body).encode()) - await self.channel.default_exchange.publish(message, routing_key=routing_key) + if self.connection: + await self.connection.close() + except Exception as e: + self.logger.error(f"Failed to dispose resources: {e}") + raise @contextlib.asynccontextmanager - async def get_connection(self): - if self.pool.empty() and self.pool.qsize() < self.pool_size: + async def __get_channel(self): + if not self.connection: await self.connect() - await self.pool.put(self) - conn = await self.pool.get() + if self.pool.empty(): + channel = await self.connection.channel() + await self.pool.put(channel) + + channel = await self.pool.get() try: - yield conn + yield channel finally: - await self.pool.put(conn) + await self.pool.put(channel) + + async def publish_message(self, message_body: api_schemas.BrokerMessage, routing_key: str): + try: + async with self.__get_channel() as channel: + message = aio_pika.Message(content_type="application/json", body=json.dumps(message_body).encode()) + await channel.default_exchange.publish(message, routing_key=routing_key) + except Exception as e: + logging.error(f"Failed to publish message: {e}") + raise