mirror of
https://github.com/ijaric/voice_assistant.git
synced 2025-12-18 05:26:18 +00:00
fix: work in progess: models & migrations & repos
This commit is contained in:
@@ -15,7 +15,7 @@ class ChatHistoryRepository:
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
async def get_last_session_id(self, request: models.RequestLastSessionId) -> uuid.UUID | None:
|
||||
"""Get a new session ID."""
|
||||
"""Get a current session ID if exists."""
|
||||
|
||||
try:
|
||||
async with self.pg_async_session() as session:
|
||||
@@ -23,10 +23,7 @@ class ChatHistoryRepository:
|
||||
sa.select(orm_models.ChatHistory)
|
||||
.filter_by(channel=request.channel, user_id=request.user_id)
|
||||
.filter(
|
||||
(
|
||||
sa.func.extract("epoch", orm_models.ChatHistory.created)
|
||||
- sa.func.extract("epoch", orm_models.ChatHistory.modified) / 60
|
||||
)
|
||||
(sa.text("NOW()") - sa.func.extract("epoch", orm_models.ChatHistory.created)) / 60
|
||||
<= request.minutes_ago
|
||||
)
|
||||
.order_by(orm_models.ChatHistory.created.desc())
|
||||
@@ -39,3 +36,48 @@ class ChatHistoryRepository:
|
||||
return chat_session.id
|
||||
except sqlalchemy.exc.SQLAlchemyError as error:
|
||||
self.logger.exception("Error: %s", error)
|
||||
|
||||
async def get_messages_by_sid(self, request: models.RequestChatHistory):
|
||||
"""Get all messages of a chat by session ID."""
|
||||
|
||||
try:
|
||||
async with self.pg_async_session() as session:
|
||||
statement = (
|
||||
sa.select(orm_models.ChatHistory)
|
||||
.filter_by(id=request.session_id)
|
||||
.order_by(orm_models.ChatHistory.created.desc())
|
||||
)
|
||||
result = await session.execute(statement)
|
||||
for row in result.scalars().all():
|
||||
print("Row: ", row)
|
||||
except sqlalchemy.exc.SQLAlchemyError as error:
|
||||
self.logger.exception("Error: %s", error)
|
||||
|
||||
# async def get_all_by_session_id(self, request: models.RequestChatHistory) -> list[models.ChatHistory]:
|
||||
# try:
|
||||
# async with self.pg_async_session() as session:
|
||||
# statement = (
|
||||
# sa.select(orm_models.ChatHistory)
|
||||
# .filter_by(id=request.session_id)
|
||||
# .order_by(orm_models.ChatHistory.created.desc())
|
||||
# )
|
||||
# result = await session.execute(statement)
|
||||
|
||||
# return [models.ChatHistory.from_orm(chat_history) for chat_history in result.scalars().all()]
|
||||
|
||||
async def add_message(self, request: models.ChatMessage) -> None:
|
||||
"""Add a message to the chat history."""
|
||||
try:
|
||||
async with self.pg_async_session() as session:
|
||||
chat_history = orm_models.ChatHistory(
|
||||
id=uuid.uuid4(),
|
||||
session_id=request.session_id,
|
||||
user_id=request.user_id,
|
||||
channel=request.channel,
|
||||
content=request.message,
|
||||
)
|
||||
session.add(chat_history)
|
||||
await session.commit()
|
||||
# TODO: Add refresh to session and return added object
|
||||
except sqlalchemy.exc.SQLAlchemyError as error:
|
||||
self.logger.exception("Error: %s", error)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import uuid
|
||||
|
||||
import fastapi
|
||||
|
||||
import lib.agent as agent
|
||||
@@ -15,9 +17,27 @@ class AgentHandler:
|
||||
summary="Статус работоспособности",
|
||||
description="Проверяет доступность сервиса FastAPI.",
|
||||
)
|
||||
self.router.add_api_route(
|
||||
"/add",
|
||||
self.add_message,
|
||||
methods=["GET"],
|
||||
summary="Статус работоспособности",
|
||||
description="Проверяет доступность сервиса FastAPI.",
|
||||
)
|
||||
|
||||
async def get_agent(self):
|
||||
request = models.RequestLastSessionId(channel="test", user_id="test", minutes_ago=3)
|
||||
request = models.RequestLastSessionId(channel="test", user_id="user_id_1", minutes_ago=3)
|
||||
response = await self.chat_history_repository.get_last_session_id(request=request)
|
||||
print("RESPONSE: ", response)
|
||||
return {"response": response}
|
||||
|
||||
async def add_message(self):
|
||||
sid: uuid.UUID = uuid.UUID("0cd3c882-affd-4929-aff1-e1724f5b54f2")
|
||||
import faker
|
||||
fake = faker.Faker()
|
||||
|
||||
message = models.ChatMessage(
|
||||
session_id=sid, user_id="user_id_1", channel="test", message={"role": "system", "content": fake.sentence()}
|
||||
)
|
||||
await self.chat_history_repository.add_message(request=message)
|
||||
return {"response": "ok"}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from .chat_history import RequestLastSessionId
|
||||
from .chat_history import ChatMessage, RequestChatHistory, RequestLastSessionId
|
||||
from .embedding import Embedding
|
||||
from .movies import Movie
|
||||
from .token import Token
|
||||
|
||||
__all__ = ["Embedding", "Movie", "RequestLastSessionId", "Token"]
|
||||
__all__ = ["ChatMessage", "Embedding", "Movie", "RequestChatHistory", "RequestLastSessionId", "Token"]
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import uuid
|
||||
|
||||
import pydantic
|
||||
|
||||
|
||||
@@ -7,3 +9,18 @@ class RequestLastSessionId(pydantic.BaseModel):
|
||||
channel: str
|
||||
user_id: str
|
||||
minutes_ago: int
|
||||
|
||||
|
||||
class ChatMessage(pydantic.BaseModel):
|
||||
"""A chat message."""
|
||||
|
||||
session_id: uuid.UUID
|
||||
user_id: str
|
||||
channel: str
|
||||
message: dict[str, str]
|
||||
|
||||
|
||||
class RequestChatHistory(pydantic.BaseModel):
|
||||
"""Request for chat history."""
|
||||
|
||||
session_id: uuid.UUID
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from .base import Base, IdCreatedUpdatedBaseMixin
|
||||
from .movies import ChatHistory, FilmWork, Genre, GenreFilmWork, Person, PersonFilmWork
|
||||
from .chat_history import ChatHistory
|
||||
from .movies import FilmWork, Genre, GenreFilmWork, Person, PersonFilmWork
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
|
||||
24
src/assistant/lib/orm_models/chat_history.py
Normal file
24
src/assistant/lib/orm_models/chat_history.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import datetime
|
||||
import uuid
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.orm as sa_orm
|
||||
import sqlalchemy.sql as sa_sql
|
||||
|
||||
import lib.orm_models.base as base_models
|
||||
|
||||
|
||||
class ChatHistory(base_models.Base):
|
||||
__tablename__: str = "chat_history" # type: ignore[reportIncompatibleVariableOverride]
|
||||
|
||||
id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
|
||||
session_id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(nullable=False, unique=True)
|
||||
channel: sa_orm.Mapped[str] = sa_orm.mapped_column()
|
||||
user_id: sa_orm.Mapped[str] = sa_orm.mapped_column()
|
||||
content: sa_orm.Mapped[sa.JSON] = sa_orm.mapped_column(sa.JSON)
|
||||
created: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
|
||||
sa.DateTime(timezone=True), server_default=sa_sql.func.now()
|
||||
)
|
||||
modified: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
|
||||
sa.DateTime(timezone=True), server_default=sa_sql.func.now(), onupdate=sa_sql.func.now()
|
||||
)
|
||||
@@ -74,19 +74,3 @@ PersonFilmWork = sa.Table(
|
||||
sa.Column("role", sa.String(50), nullable=False),
|
||||
sa.Column("created", sa.DateTime(timezone=True), server_default=sa_sql.func.now()),
|
||||
)
|
||||
|
||||
|
||||
class ChatHistory(base_models.Base):
|
||||
__tablename__: str = "chat_history" # type: ignore[reportIncompatibleVariableOverride]
|
||||
|
||||
id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
|
||||
session_id: sa_orm.Mapped[str] = sa_orm.mapped_column()
|
||||
channel: sa_orm.Mapped[str] = sa_orm.mapped_column()
|
||||
user_id: sa_orm.Mapped[str] = sa_orm.mapped_column()
|
||||
content: sa_orm.Mapped[sa.JSON] = sa_orm.mapped_column(sa.JSON)
|
||||
created: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
|
||||
sa.DateTime(timezone=True), server_default=sa_sql.func.now()
|
||||
)
|
||||
modified: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
|
||||
sa.DateTime(timezone=True), server_default=sa_sql.func.now(), onupdate=sa_sql.func.now()
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user