1
0
mirror of https://github.com/ijaric/voice_assistant.git synced 2025-05-24 22:43:26 +00:00

Исправлена логика отправки сообщений

This commit is contained in:
grucshetskyaleksei 2023-10-03 04:46:52 +03:00
parent ffd66aec70
commit 42dbbf72d3
3 changed files with 40 additions and 31 deletions

View File

@ -19,7 +19,3 @@ class BasePublisher(abc.ABC):
@abc.abstractmethod @abc.abstractmethod
async def publish_message(self, message_body: api_schemas.broker_message.BrokerMessage, routing_key: str) -> None: async def publish_message(self, message_body: api_schemas.broker_message.BrokerMessage, routing_key: str) -> None:
pass pass
@abc.abstractmethod
async def get_connection(self) -> contextlib.AbstractAsyncContextManager[T]:
pass

View File

@ -16,8 +16,3 @@ class BrokerPublisher:
async def publish_message(self, message_body: api_schemas.BrokerMessage, routing_key: str): async def publish_message(self, message_body: api_schemas.BrokerMessage, routing_key: str):
await self.broker.publish_message(message_body, routing_key) await self.broker.publish_message(message_body, routing_key)
@contextlib.asynccontextmanager
async def get_connection(self) -> typing.AsyncGenerator:
async with self.broker.get_connection() as conn:
yield conn

View File

@ -1,6 +1,7 @@
import asyncio import asyncio
import contextlib import contextlib
import json import json
import logging
import aio_pika import aio_pika
@ -13,37 +14,54 @@ class RabbitMQPublisher(db_brokers.base_broker.BasePublisher):
def __init__(self, settings: app_split_settings.RabbitMQSettings()): def __init__(self, settings: app_split_settings.RabbitMQSettings()):
self.settings = settings self.settings = settings
self.connection = None self.connection = None
self.channel = None self.logger = logging.getLogger(__name__)
self.pool = asyncio.Queue() self.pool = asyncio.Queue(maxsize=settings.max_pool_size)
self.pool_size = settings.max_pool_size
async def connect(self): async def connect(self):
self.connection = await aio_pika.connect(self.settings.amqp_url) try:
self.channel = await self.connection.channel() self.connection = await aio_pika.connect(self.settings.amqp_url)
await self.channel.set_qos(prefetch_count=1)
exchange = await self.channel.declare_exchange(self.settings.exchange, aio_pika.ExchangeType.DIRECT)
for attr_name, attr_value in vars(self.settings.Queues).items(): for _ in range(self.settings.max_pool_size):
if not attr_name.startswith("__"): channel = await self.connection.channel()
queue = await self.channel.declare_queue(attr_value.value) await channel.set_qos(prefetch_count=1)
await queue.bind(exchange, attr_value.value) await self.pool.put(channel)
except Exception as e:
self.logger.error(f"Failed to dispose resources: {e}")
raise
async def dispose(self): async def dispose(self):
await self.channel.close() try:
await self.connection.close() while not self.pool.empty():
channel = await self.pool.get()
await channel.close()
async def publish_message(self, message_body: api_schemas.BrokerMessage, routing_key: str): if self.connection:
message = aio_pika.Message(content_type="application/json", body=json.dumps(message_body).encode()) await self.connection.close()
await self.channel.default_exchange.publish(message, routing_key=routing_key) except Exception as e:
self.logger.error(f"Failed to dispose resources: {e}")
raise
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def get_connection(self): async def __get_channel(self):
if self.pool.empty() and self.pool.qsize() < self.pool_size: if not self.connection:
await self.connect() await self.connect()
await self.pool.put(self)
conn = await self.pool.get() if self.pool.empty():
channel = await self.connection.channel()
await self.pool.put(channel)
channel = await self.pool.get()
try: try:
yield conn yield channel
finally: finally:
await self.pool.put(conn) await self.pool.put(channel)
async def publish_message(self, message_body: api_schemas.BrokerMessage, routing_key: str):
try:
async with self.__get_channel() as channel:
message = aio_pika.Message(content_type="application/json", body=json.dumps(message_body).encode())
await channel.default_exchange.publish(message, routing_key=routing_key)
except Exception as e:
logging.error(f"Failed to publish message: {e}")
raise