From 73f19d46c4eda4cddad3ff46188dce6059e945cd Mon Sep 17 00:00:00 2001 From: Artem Litvinov Date: Thu, 12 Oct 2023 03:00:40 +0100 Subject: [PATCH] build: ready chat history repo --- src/assistant/lib/agent/chat_repository.py | 36 +++++++++++----------- src/assistant/lib/api/v1/handlers/agent.py | 20 ++++++++++-- src/assistant/lib/models/__init__.py | 4 +-- src/assistant/lib/models/chat_history.py | 9 +++++- src/assistant/pyproject.toml | 2 +- 5 files changed, 47 insertions(+), 24 deletions(-) diff --git a/src/assistant/lib/agent/chat_repository.py b/src/assistant/lib/agent/chat_repository.py index 1bcc5dc..d632d4b 100644 --- a/src/assistant/lib/agent/chat_repository.py +++ b/src/assistant/lib/agent/chat_repository.py @@ -17,13 +17,18 @@ class ChatHistoryRepository: async def get_last_session_id(self, request: models.RequestLastSessionId) -> uuid.UUID | None: """Get a current session ID if exists.""" + self.logger.debug("get_last_session_id: %s", request) try: async with self.pg_async_session() as session: statement = ( sa.select(orm_models.ChatHistory) .filter_by(channel=request.channel, user_id=request.user_id) .filter( - (sa.func.extract("epoch", sa.text("NOW()")) - sa.func.extract("epoch", orm_models.ChatHistory.created)) / 60 + ( + sa.func.extract("epoch", sa.text("NOW()")) + - sa.func.extract("epoch", orm_models.ChatHistory.created) + ) + / 60 <= request.minutes_ago ) .order_by(orm_models.ChatHistory.created.desc()) @@ -37,36 +42,31 @@ class ChatHistoryRepository: except sqlalchemy.exc.SQLAlchemyError as error: self.logger.exception("Error: %s", error) - async def get_messages_by_sid(self, request: models.RequestChatHistory): + async def get_messages_by_sid(self, request: models.RequestChatHistory) -> list[models.Message] | None: """Get all messages of a chat by session ID.""" + self.logger.debug("get_messages_by_sid: %s", request) try: async with self.pg_async_session() as session: + messages: list[models.Message] = [] statement = ( sa.select(orm_models.ChatHistory) - .filter_by(id=request.session_id) - .order_by(orm_models.ChatHistory.created.desc()) + .filter_by(session_id=request.session_id) + .order_by(orm_models.ChatHistory.created.asc()) ) + print("get_messages_by_sid:", statement) result = await session.execute(statement) for row in result.scalars().all(): - print("Row: ", row) + # TODO: Было бы интересно понять почему pyright ругается ниже и как правильно вызывать компоненты + messages.append(models.Message(role=row.content["role"], content=row.content["content"])) # type: ignore[reportGeneralTypeIssues] + return messages except sqlalchemy.exc.SQLAlchemyError as error: self.logger.exception("Error: %s", error) - # async def get_all_by_session_id(self, request: models.RequestChatHistory) -> list[models.ChatHistory]: - # try: - # async with self.pg_async_session() as session: - # statement = ( - # sa.select(orm_models.ChatHistory) - # .filter_by(id=request.session_id) - # .order_by(orm_models.ChatHistory.created.desc()) - # ) - # result = await session.execute(statement) - - # return [models.ChatHistory.from_orm(chat_history) for chat_history in result.scalars().all()] - - async def add_message(self, request: models.ChatMessage) -> None: + async def add_message(self, request: models.RequestChatMessage) -> None: """Add a message to the chat history.""" + + self.logger.debug("add_message: %s", request) try: async with self.pg_async_session() as session: chat_history = orm_models.ChatHistory( diff --git a/src/assistant/lib/api/v1/handlers/agent.py b/src/assistant/lib/api/v1/handlers/agent.py index 0fce3a9..bd584e0 100644 --- a/src/assistant/lib/api/v1/handlers/agent.py +++ b/src/assistant/lib/api/v1/handlers/agent.py @@ -24,6 +24,13 @@ class AgentHandler: summary="Статус работоспособности", description="Проверяет доступность сервиса FastAPI.", ) + self.router.add_api_route( + "/messages", + self.get_messages, + methods=["GET"], + summary="Статус работоспособности", + description="Проверяет доступность сервиса FastAPI.", + ) async def get_agent(self): request = models.RequestLastSessionId(channel="test", user_id="user_id_1", minutes_ago=3) @@ -34,10 +41,19 @@ class AgentHandler: async def add_message(self): sid: uuid.UUID = uuid.UUID("0cd3c882-affd-4929-aff1-e1724f5b54f2") import faker + fake = faker.Faker() - - message = models.ChatMessage( + + message = models.RequestChatMessage( session_id=sid, user_id="user_id_1", channel="test", message={"role": "system", "content": fake.sentence()} ) await self.chat_history_repository.add_message(request=message) return {"response": "ok"} + + async def get_messages(self): + sid: uuid.UUID = uuid.UUID("0cd3c882-affd-4929-aff1-e1724f5b54f2") + + request = models.RequestChatHistory(session_id=sid) + response = await self.chat_history_repository.get_messages_by_sid(request=request) + print("RESPONSE: ", response) + return {"response": response} diff --git a/src/assistant/lib/models/__init__.py b/src/assistant/lib/models/__init__.py index 51dd416..a514cb1 100644 --- a/src/assistant/lib/models/__init__.py +++ b/src/assistant/lib/models/__init__.py @@ -1,6 +1,6 @@ -from .chat_history import ChatMessage, RequestChatHistory, RequestLastSessionId +from .chat_history import Message, RequestChatHistory, RequestChatMessage, RequestLastSessionId from .embedding import Embedding from .movies import Movie from .token import Token -__all__ = ["ChatMessage", "Embedding", "Movie", "RequestChatHistory", "RequestLastSessionId", "Token"] +__all__ = ["Embedding", "Message", "Movie", "RequestChatHistory", "RequestChatMessage", "RequestLastSessionId", "Token"] diff --git a/src/assistant/lib/models/chat_history.py b/src/assistant/lib/models/chat_history.py index 20c60e2..41a3dbe 100644 --- a/src/assistant/lib/models/chat_history.py +++ b/src/assistant/lib/models/chat_history.py @@ -11,7 +11,7 @@ class RequestLastSessionId(pydantic.BaseModel): minutes_ago: int -class ChatMessage(pydantic.BaseModel): +class RequestChatMessage(pydantic.BaseModel): """A chat message.""" session_id: uuid.UUID @@ -24,3 +24,10 @@ class RequestChatHistory(pydantic.BaseModel): """Request for chat history.""" session_id: uuid.UUID + + +class Message(pydantic.BaseModel): + """A chat message.""" + + role: str + content: str diff --git a/src/assistant/pyproject.toml b/src/assistant/pyproject.toml index 9d4a180..9f38127 100644 --- a/src/assistant/pyproject.toml +++ b/src/assistant/pyproject.toml @@ -23,6 +23,7 @@ version = "0.1.0" alembic = "^1.12.0" asyncpg = "^0.28.0" dill = "^0.3.7" +faker = "^19.10.0" fastapi = "0.103.1" greenlet = "^2.0.2" httpx = "^0.25.0" @@ -39,7 +40,6 @@ python-magic = "^0.4.27" sqlalchemy = "^2.0.20" uvicorn = "^0.23.2" wrapt = "^1.15.0" -faker = "^19.10.0" [tool.poetry.dev-dependencies] black = "^23.7.0"