mirror of
https://github.com/ijaric/voice_assistant.git
synced 2025-05-24 14:33:26 +00:00
build: ready chat history repo
This commit is contained in:
parent
3306025640
commit
73f19d46c4
|
@ -17,13 +17,18 @@ class ChatHistoryRepository:
|
||||||
async def get_last_session_id(self, request: models.RequestLastSessionId) -> uuid.UUID | None:
|
async def get_last_session_id(self, request: models.RequestLastSessionId) -> uuid.UUID | None:
|
||||||
"""Get a current session ID if exists."""
|
"""Get a current session ID if exists."""
|
||||||
|
|
||||||
|
self.logger.debug("get_last_session_id: %s", request)
|
||||||
try:
|
try:
|
||||||
async with self.pg_async_session() as session:
|
async with self.pg_async_session() as session:
|
||||||
statement = (
|
statement = (
|
||||||
sa.select(orm_models.ChatHistory)
|
sa.select(orm_models.ChatHistory)
|
||||||
.filter_by(channel=request.channel, user_id=request.user_id)
|
.filter_by(channel=request.channel, user_id=request.user_id)
|
||||||
.filter(
|
.filter(
|
||||||
(sa.func.extract("epoch", sa.text("NOW()")) - sa.func.extract("epoch", orm_models.ChatHistory.created)) / 60
|
(
|
||||||
|
sa.func.extract("epoch", sa.text("NOW()"))
|
||||||
|
- sa.func.extract("epoch", orm_models.ChatHistory.created)
|
||||||
|
)
|
||||||
|
/ 60
|
||||||
<= request.minutes_ago
|
<= request.minutes_ago
|
||||||
)
|
)
|
||||||
.order_by(orm_models.ChatHistory.created.desc())
|
.order_by(orm_models.ChatHistory.created.desc())
|
||||||
|
@ -37,36 +42,31 @@ class ChatHistoryRepository:
|
||||||
except sqlalchemy.exc.SQLAlchemyError as error:
|
except sqlalchemy.exc.SQLAlchemyError as error:
|
||||||
self.logger.exception("Error: %s", error)
|
self.logger.exception("Error: %s", error)
|
||||||
|
|
||||||
async def get_messages_by_sid(self, request: models.RequestChatHistory):
|
async def get_messages_by_sid(self, request: models.RequestChatHistory) -> list[models.Message] | None:
|
||||||
"""Get all messages of a chat by session ID."""
|
"""Get all messages of a chat by session ID."""
|
||||||
|
|
||||||
|
self.logger.debug("get_messages_by_sid: %s", request)
|
||||||
try:
|
try:
|
||||||
async with self.pg_async_session() as session:
|
async with self.pg_async_session() as session:
|
||||||
|
messages: list[models.Message] = []
|
||||||
statement = (
|
statement = (
|
||||||
sa.select(orm_models.ChatHistory)
|
sa.select(orm_models.ChatHistory)
|
||||||
.filter_by(id=request.session_id)
|
.filter_by(session_id=request.session_id)
|
||||||
.order_by(orm_models.ChatHistory.created.desc())
|
.order_by(orm_models.ChatHistory.created.asc())
|
||||||
)
|
)
|
||||||
|
print("get_messages_by_sid:", statement)
|
||||||
result = await session.execute(statement)
|
result = await session.execute(statement)
|
||||||
for row in result.scalars().all():
|
for row in result.scalars().all():
|
||||||
print("Row: ", row)
|
# TODO: Было бы интересно понять почему pyright ругается ниже и как правильно вызывать компоненты
|
||||||
|
messages.append(models.Message(role=row.content["role"], content=row.content["content"])) # type: ignore[reportGeneralTypeIssues]
|
||||||
|
return messages
|
||||||
except sqlalchemy.exc.SQLAlchemyError as error:
|
except sqlalchemy.exc.SQLAlchemyError as error:
|
||||||
self.logger.exception("Error: %s", error)
|
self.logger.exception("Error: %s", error)
|
||||||
|
|
||||||
# async def get_all_by_session_id(self, request: models.RequestChatHistory) -> list[models.ChatHistory]:
|
async def add_message(self, request: models.RequestChatMessage) -> None:
|
||||||
# try:
|
|
||||||
# async with self.pg_async_session() as session:
|
|
||||||
# statement = (
|
|
||||||
# sa.select(orm_models.ChatHistory)
|
|
||||||
# .filter_by(id=request.session_id)
|
|
||||||
# .order_by(orm_models.ChatHistory.created.desc())
|
|
||||||
# )
|
|
||||||
# result = await session.execute(statement)
|
|
||||||
|
|
||||||
# return [models.ChatHistory.from_orm(chat_history) for chat_history in result.scalars().all()]
|
|
||||||
|
|
||||||
async def add_message(self, request: models.ChatMessage) -> None:
|
|
||||||
"""Add a message to the chat history."""
|
"""Add a message to the chat history."""
|
||||||
|
|
||||||
|
self.logger.debug("add_message: %s", request)
|
||||||
try:
|
try:
|
||||||
async with self.pg_async_session() as session:
|
async with self.pg_async_session() as session:
|
||||||
chat_history = orm_models.ChatHistory(
|
chat_history = orm_models.ChatHistory(
|
||||||
|
|
|
@ -24,6 +24,13 @@ class AgentHandler:
|
||||||
summary="Статус работоспособности",
|
summary="Статус работоспособности",
|
||||||
description="Проверяет доступность сервиса FastAPI.",
|
description="Проверяет доступность сервиса FastAPI.",
|
||||||
)
|
)
|
||||||
|
self.router.add_api_route(
|
||||||
|
"/messages",
|
||||||
|
self.get_messages,
|
||||||
|
methods=["GET"],
|
||||||
|
summary="Статус работоспособности",
|
||||||
|
description="Проверяет доступность сервиса FastAPI.",
|
||||||
|
)
|
||||||
|
|
||||||
async def get_agent(self):
|
async def get_agent(self):
|
||||||
request = models.RequestLastSessionId(channel="test", user_id="user_id_1", minutes_ago=3)
|
request = models.RequestLastSessionId(channel="test", user_id="user_id_1", minutes_ago=3)
|
||||||
|
@ -34,10 +41,19 @@ class AgentHandler:
|
||||||
async def add_message(self):
|
async def add_message(self):
|
||||||
sid: uuid.UUID = uuid.UUID("0cd3c882-affd-4929-aff1-e1724f5b54f2")
|
sid: uuid.UUID = uuid.UUID("0cd3c882-affd-4929-aff1-e1724f5b54f2")
|
||||||
import faker
|
import faker
|
||||||
|
|
||||||
fake = faker.Faker()
|
fake = faker.Faker()
|
||||||
|
|
||||||
message = models.ChatMessage(
|
message = models.RequestChatMessage(
|
||||||
session_id=sid, user_id="user_id_1", channel="test", message={"role": "system", "content": fake.sentence()}
|
session_id=sid, user_id="user_id_1", channel="test", message={"role": "system", "content": fake.sentence()}
|
||||||
)
|
)
|
||||||
await self.chat_history_repository.add_message(request=message)
|
await self.chat_history_repository.add_message(request=message)
|
||||||
return {"response": "ok"}
|
return {"response": "ok"}
|
||||||
|
|
||||||
|
async def get_messages(self):
|
||||||
|
sid: uuid.UUID = uuid.UUID("0cd3c882-affd-4929-aff1-e1724f5b54f2")
|
||||||
|
|
||||||
|
request = models.RequestChatHistory(session_id=sid)
|
||||||
|
response = await self.chat_history_repository.get_messages_by_sid(request=request)
|
||||||
|
print("RESPONSE: ", response)
|
||||||
|
return {"response": response}
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from .chat_history import ChatMessage, RequestChatHistory, RequestLastSessionId
|
from .chat_history import Message, RequestChatHistory, RequestChatMessage, RequestLastSessionId
|
||||||
from .embedding import Embedding
|
from .embedding import Embedding
|
||||||
from .movies import Movie
|
from .movies import Movie
|
||||||
from .token import Token
|
from .token import Token
|
||||||
|
|
||||||
__all__ = ["ChatMessage", "Embedding", "Movie", "RequestChatHistory", "RequestLastSessionId", "Token"]
|
__all__ = ["Embedding", "Message", "Movie", "RequestChatHistory", "RequestChatMessage", "RequestLastSessionId", "Token"]
|
||||||
|
|
|
@ -11,7 +11,7 @@ class RequestLastSessionId(pydantic.BaseModel):
|
||||||
minutes_ago: int
|
minutes_ago: int
|
||||||
|
|
||||||
|
|
||||||
class ChatMessage(pydantic.BaseModel):
|
class RequestChatMessage(pydantic.BaseModel):
|
||||||
"""A chat message."""
|
"""A chat message."""
|
||||||
|
|
||||||
session_id: uuid.UUID
|
session_id: uuid.UUID
|
||||||
|
@ -24,3 +24,10 @@ class RequestChatHistory(pydantic.BaseModel):
|
||||||
"""Request for chat history."""
|
"""Request for chat history."""
|
||||||
|
|
||||||
session_id: uuid.UUID
|
session_id: uuid.UUID
|
||||||
|
|
||||||
|
|
||||||
|
class Message(pydantic.BaseModel):
|
||||||
|
"""A chat message."""
|
||||||
|
|
||||||
|
role: str
|
||||||
|
content: str
|
||||||
|
|
|
@ -23,6 +23,7 @@ version = "0.1.0"
|
||||||
alembic = "^1.12.0"
|
alembic = "^1.12.0"
|
||||||
asyncpg = "^0.28.0"
|
asyncpg = "^0.28.0"
|
||||||
dill = "^0.3.7"
|
dill = "^0.3.7"
|
||||||
|
faker = "^19.10.0"
|
||||||
fastapi = "0.103.1"
|
fastapi = "0.103.1"
|
||||||
greenlet = "^2.0.2"
|
greenlet = "^2.0.2"
|
||||||
httpx = "^0.25.0"
|
httpx = "^0.25.0"
|
||||||
|
@ -39,7 +40,6 @@ python-magic = "^0.4.27"
|
||||||
sqlalchemy = "^2.0.20"
|
sqlalchemy = "^2.0.20"
|
||||||
uvicorn = "^0.23.2"
|
uvicorn = "^0.23.2"
|
||||||
wrapt = "^1.15.0"
|
wrapt = "^1.15.0"
|
||||||
faker = "^19.10.0"
|
|
||||||
|
|
||||||
[tool.poetry.dev-dependencies]
|
[tool.poetry.dev-dependencies]
|
||||||
black = "^23.7.0"
|
black = "^23.7.0"
|
||||||
|
|
Loading…
Reference in New Issue
Block a user