mirror of
https://github.com/ijaric/voice_assistant.git
synced 2025-05-24 14:33:26 +00:00
84 lines
3.4 KiB
Python
84 lines
3.4 KiB
Python
import logging
|
|
import uuid
|
|
|
|
import sqlalchemy as sa
|
|
import sqlalchemy.exc
|
|
import sqlalchemy.ext.asyncio as sa_asyncio
|
|
|
|
import lib.models as models
|
|
import lib.orm_models as orm_models
|
|
|
|
|
|
class ChatHistoryRepository:
    """Read/write access to chat messages stored in the ``ChatHistory`` table.

    Every method opens its own short-lived session from the injected session
    factory. ``SQLAlchemyError`` raised while talking to the database is logged
    and swallowed, so callers get ``None`` (or nothing) instead of an exception.
    """

    def __init__(self, pg_async_session: sa_asyncio.async_sessionmaker[sa_asyncio.AsyncSession]) -> None:
        """Store the session factory used by all repository methods.

        Args:
            pg_async_session: Factory producing ``AsyncSession`` objects; each
                repository call opens one session and closes it on exit.
        """
        self.pg_async_session = pg_async_session
        self.logger = logging.getLogger(__name__)

    async def get_last_session_id(self, request: models.RequestLastSessionId) -> uuid.UUID | None:
        """Return the session ID of the user's latest recent message, if any.

        Selects the newest ``ChatHistory`` row for ``request.user_id`` on
        ``request.channel`` whose ``created`` timestamp is at most
        ``request.minutes_ago`` minutes older than the database's ``now()``.

        Args:
            request: Provides ``channel``, ``user_id`` and ``minutes_ago``.

        Returns:
            The ``session_id`` of the matching row, or ``None`` when no row
            matches or a database error occurred (the error is logged).
        """
        try:
            async with self.pg_async_session() as session:
                # Convert both instants to epoch seconds so the subtraction is
                # numeric; subtracting an epoch number from a raw timestamp is
                # not valid SQL arithmetic. ``sa.extract`` renders the proper
                # ``EXTRACT(epoch FROM ...)`` syntax.
                age_minutes = (
                    sa.extract("epoch", sa.func.now()) - sa.extract("epoch", orm_models.ChatHistory.created)
                ) / 60
                statement = (
                    sa.select(orm_models.ChatHistory)
                    .filter_by(channel=request.channel, user_id=request.user_id)
                    .filter(age_minutes <= request.minutes_ago)
                    .order_by(orm_models.ChatHistory.created.desc())
                    .limit(1)
                )
                result = await session.execute(statement)
                latest_message = result.scalars().first()
                if latest_message is None:
                    return None
                # ``id`` is each message row's own key (see ``add_message``);
                # the conversation the row belongs to is ``session_id``.
                return latest_message.session_id
        except sqlalchemy.exc.SQLAlchemyError as error:
            self.logger.exception("Error: %s", error)
            return None

    async def get_messages_by_sid(
        self, request: models.RequestChatHistory
    ) -> list[orm_models.ChatHistory] | None:
        """Return all messages recorded under ``request.session_id``.

        Args:
            request: Provides ``session_id``.

        Returns:
            The matching ``ChatHistory`` rows ordered newest first (an empty
            list when the session has no messages), or ``None`` when a database
            error occurred (the error is logged). The rows are returned after
            their session has been closed.
        """
        try:
            async with self.pg_async_session() as session:
                statement = (
                    sa.select(orm_models.ChatHistory)
                    # Rows of one conversation share ``session_id``; filtering on
                    # ``id`` would compare against each row's own random UUID.
                    .filter_by(session_id=request.session_id)
                    .order_by(orm_models.ChatHistory.created.desc())
                )
                result = await session.execute(statement)
                messages = list(result.scalars().all())
        except sqlalchemy.exc.SQLAlchemyError as error:
            self.logger.exception("Error: %s", error)
            return None

        self.logger.debug("Fetched %d message(s) for session %s", len(messages), request.session_id)
        return messages

    async def add_message(self, request: models.ChatMessage) -> None:
        """Persist one chat message as a new ``ChatHistory`` row.

        A fresh UUID is generated for the row's ``id``; ``session_id``,
        ``user_id`` and ``channel`` are copied from ``request`` and ``content``
        is taken from ``request.message``. A database error is logged and
        swallowed, so a failed insert is not reported to the caller.

        Args:
            request: The message to store.
        """
        try:
            async with self.pg_async_session() as session:
                chat_history = orm_models.ChatHistory(
                    id=uuid.uuid4(),
                    session_id=request.session_id,
                    user_id=request.user_id,
                    channel=request.channel,
                    content=request.message,
                )
                session.add(chat_history)
                await session.commit()
                # TODO: Add refresh to session and return added object
        except sqlalchemy.exc.SQLAlchemyError as error:
            self.logger.exception("Error: %s", error)
|