mirror of
https://github.com/ijaric/voice_assistant.git
synced 2025-05-24 14:33:26 +00:00
fix: work in progess: models & migrations & repos
This commit is contained in:
parent
b75af6034b
commit
2db4a87bf4
|
@ -6,7 +6,6 @@ from sqlalchemy.engine import Connection
|
||||||
from sqlalchemy.ext.asyncio import async_engine_from_config
|
from sqlalchemy.ext.asyncio import async_engine_from_config
|
||||||
|
|
||||||
import lib.app.settings as app_settings
|
import lib.app.settings as app_settings
|
||||||
import lib.models as models
|
|
||||||
import lib.orm_models as orm_models
|
import lib.orm_models as orm_models
|
||||||
from alembic import context
|
from alembic import context
|
||||||
|
|
||||||
|
@ -24,7 +23,7 @@ print("BASE: ", orm_models.Base.metadata.schema)
|
||||||
for t in orm_models.Base.metadata.sorted_tables:
|
for t in orm_models.Base.metadata.sorted_tables:
|
||||||
print(t.name)
|
print(t.name)
|
||||||
|
|
||||||
target_metadata = models.Base.metadata
|
target_metadata = orm_models.Base.metadata
|
||||||
|
|
||||||
|
|
||||||
def run_migrations_offline() -> None:
|
def run_migrations_offline() -> None:
|
||||||
|
|
|
@ -24,7 +24,7 @@ def upgrade() -> None:
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"chat_history",
|
"chat_history",
|
||||||
sa.Column("id", sa.Uuid(), nullable=False),
|
sa.Column("id", sa.Uuid(), nullable=False),
|
||||||
sa.Column("session_id", sa.String(), nullable=False),
|
sa.Column("session_id", sa.Uuid(), nullable=False),
|
||||||
sa.Column("channel", sa.String(), nullable=False),
|
sa.Column("channel", sa.String(), nullable=False),
|
||||||
sa.Column("user_id", sa.String(), nullable=False),
|
sa.Column("user_id", sa.String(), nullable=False),
|
||||||
sa.Column("content", sa.JSON(), nullable=False),
|
sa.Column("content", sa.JSON(), nullable=False),
|
||||||
|
@ -33,39 +33,6 @@ def upgrade() -> None:
|
||||||
sa.PrimaryKeyConstraint("id"),
|
sa.PrimaryKeyConstraint("id"),
|
||||||
schema="content",
|
schema="content",
|
||||||
)
|
)
|
||||||
op.drop_table("auth_group")
|
|
||||||
op.drop_table("auth_user_groups")
|
|
||||||
op.drop_table("auth_group_permissions")
|
|
||||||
op.drop_table("auth_user_user_permissions")
|
|
||||||
op.drop_table("auth_user")
|
|
||||||
op.drop_table("django_content_type")
|
|
||||||
op.drop_table("auth_permission")
|
|
||||||
op.drop_table("django_session")
|
|
||||||
op.drop_table("django_admin_log")
|
|
||||||
op.drop_table("django_migrations")
|
|
||||||
op.alter_column("film_work", "title", existing_type=sa.TEXT(), type_=sa.String(), existing_nullable=False)
|
|
||||||
op.alter_column("film_work", "description", existing_type=sa.TEXT(), type_=sa.String(), existing_nullable=True)
|
|
||||||
op.alter_column("film_work", "creation_date", existing_type=sa.DATE(), type_=sa.DateTime(), existing_nullable=True)
|
|
||||||
op.alter_column("film_work", "file_path", existing_type=sa.TEXT(), type_=sa.String(), existing_nullable=True)
|
|
||||||
op.alter_column("film_work", "type", existing_type=sa.TEXT(), type_=sa.String(), existing_nullable=False)
|
|
||||||
op.alter_column("film_work", "created", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=False)
|
|
||||||
op.alter_column("film_work", "modified", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=False)
|
|
||||||
op.alter_column("genre", "name", existing_type=sa.TEXT(), type_=sa.String(), existing_nullable=False)
|
|
||||||
op.alter_column("genre", "description", existing_type=sa.TEXT(), type_=sa.String(), existing_nullable=True)
|
|
||||||
op.alter_column("genre", "created", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=False)
|
|
||||||
op.alter_column("genre", "modified", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=False)
|
|
||||||
op.create_foreign_key(None, "genre_film_work", "genre", ["genre_id"], ["id"], referent_schema="content")
|
|
||||||
op.create_foreign_key(None, "genre_film_work", "film_work", ["film_work_id"], ["id"], referent_schema="content")
|
|
||||||
op.alter_column("person", "full_name", existing_type=sa.TEXT(), type_=sa.String(), existing_nullable=False)
|
|
||||||
op.alter_column("person", "created", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=False)
|
|
||||||
op.alter_column("person", "modified", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=False)
|
|
||||||
op.alter_column(
|
|
||||||
"person_film_work", "role", existing_type=sa.TEXT(), type_=sa.String(length=50), existing_nullable=False
|
|
||||||
)
|
|
||||||
op.create_foreign_key(None, "person_film_work", "film_work", ["film_work_id"], ["id"], referent_schema="content")
|
|
||||||
op.create_foreign_key(None, "person_film_work", "person", ["person_id"], ["id"], referent_schema="content")
|
|
||||||
op.drop_column("person_film_work", "id")
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
|
|
|
@ -15,7 +15,7 @@ class ChatHistoryRepository:
|
||||||
self.logger = logging.getLogger(__name__)
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
async def get_last_session_id(self, request: models.RequestLastSessionId) -> uuid.UUID | None:
|
async def get_last_session_id(self, request: models.RequestLastSessionId) -> uuid.UUID | None:
|
||||||
"""Get a new session ID."""
|
"""Get a current session ID if exists."""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with self.pg_async_session() as session:
|
async with self.pg_async_session() as session:
|
||||||
|
@ -23,10 +23,7 @@ class ChatHistoryRepository:
|
||||||
sa.select(orm_models.ChatHistory)
|
sa.select(orm_models.ChatHistory)
|
||||||
.filter_by(channel=request.channel, user_id=request.user_id)
|
.filter_by(channel=request.channel, user_id=request.user_id)
|
||||||
.filter(
|
.filter(
|
||||||
(
|
(sa.text("NOW()") - sa.func.extract("epoch", orm_models.ChatHistory.created)) / 60
|
||||||
sa.func.extract("epoch", orm_models.ChatHistory.created)
|
|
||||||
- sa.func.extract("epoch", orm_models.ChatHistory.modified) / 60
|
|
||||||
)
|
|
||||||
<= request.minutes_ago
|
<= request.minutes_ago
|
||||||
)
|
)
|
||||||
.order_by(orm_models.ChatHistory.created.desc())
|
.order_by(orm_models.ChatHistory.created.desc())
|
||||||
|
@ -39,3 +36,48 @@ class ChatHistoryRepository:
|
||||||
return chat_session.id
|
return chat_session.id
|
||||||
except sqlalchemy.exc.SQLAlchemyError as error:
|
except sqlalchemy.exc.SQLAlchemyError as error:
|
||||||
self.logger.exception("Error: %s", error)
|
self.logger.exception("Error: %s", error)
|
||||||
|
|
||||||
|
async def get_messages_by_sid(self, request: models.RequestChatHistory):
|
||||||
|
"""Get all messages of a chat by session ID."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with self.pg_async_session() as session:
|
||||||
|
statement = (
|
||||||
|
sa.select(orm_models.ChatHistory)
|
||||||
|
.filter_by(id=request.session_id)
|
||||||
|
.order_by(orm_models.ChatHistory.created.desc())
|
||||||
|
)
|
||||||
|
result = await session.execute(statement)
|
||||||
|
for row in result.scalars().all():
|
||||||
|
print("Row: ", row)
|
||||||
|
except sqlalchemy.exc.SQLAlchemyError as error:
|
||||||
|
self.logger.exception("Error: %s", error)
|
||||||
|
|
||||||
|
# async def get_all_by_session_id(self, request: models.RequestChatHistory) -> list[models.ChatHistory]:
|
||||||
|
# try:
|
||||||
|
# async with self.pg_async_session() as session:
|
||||||
|
# statement = (
|
||||||
|
# sa.select(orm_models.ChatHistory)
|
||||||
|
# .filter_by(id=request.session_id)
|
||||||
|
# .order_by(orm_models.ChatHistory.created.desc())
|
||||||
|
# )
|
||||||
|
# result = await session.execute(statement)
|
||||||
|
|
||||||
|
# return [models.ChatHistory.from_orm(chat_history) for chat_history in result.scalars().all()]
|
||||||
|
|
||||||
|
async def add_message(self, request: models.ChatMessage) -> None:
|
||||||
|
"""Add a message to the chat history."""
|
||||||
|
try:
|
||||||
|
async with self.pg_async_session() as session:
|
||||||
|
chat_history = orm_models.ChatHistory(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
session_id=request.session_id,
|
||||||
|
user_id=request.user_id,
|
||||||
|
channel=request.channel,
|
||||||
|
content=request.message,
|
||||||
|
)
|
||||||
|
session.add(chat_history)
|
||||||
|
await session.commit()
|
||||||
|
# TODO: Add refresh to session and return added object
|
||||||
|
except sqlalchemy.exc.SQLAlchemyError as error:
|
||||||
|
self.logger.exception("Error: %s", error)
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
import uuid
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
|
|
||||||
import lib.agent as agent
|
import lib.agent as agent
|
||||||
|
@ -15,9 +17,27 @@ class AgentHandler:
|
||||||
summary="Статус работоспособности",
|
summary="Статус работоспособности",
|
||||||
description="Проверяет доступность сервиса FastAPI.",
|
description="Проверяет доступность сервиса FastAPI.",
|
||||||
)
|
)
|
||||||
|
self.router.add_api_route(
|
||||||
|
"/add",
|
||||||
|
self.add_message,
|
||||||
|
methods=["GET"],
|
||||||
|
summary="Статус работоспособности",
|
||||||
|
description="Проверяет доступность сервиса FastAPI.",
|
||||||
|
)
|
||||||
|
|
||||||
async def get_agent(self):
|
async def get_agent(self):
|
||||||
request = models.RequestLastSessionId(channel="test", user_id="test", minutes_ago=3)
|
request = models.RequestLastSessionId(channel="test", user_id="user_id_1", minutes_ago=3)
|
||||||
response = await self.chat_history_repository.get_last_session_id(request=request)
|
response = await self.chat_history_repository.get_last_session_id(request=request)
|
||||||
print("RESPONSE: ", response)
|
print("RESPONSE: ", response)
|
||||||
return {"response": response}
|
return {"response": response}
|
||||||
|
|
||||||
|
async def add_message(self):
|
||||||
|
sid: uuid.UUID = uuid.UUID("0cd3c882-affd-4929-aff1-e1724f5b54f2")
|
||||||
|
import faker
|
||||||
|
fake = faker.Faker()
|
||||||
|
|
||||||
|
message = models.ChatMessage(
|
||||||
|
session_id=sid, user_id="user_id_1", channel="test", message={"role": "system", "content": fake.sentence()}
|
||||||
|
)
|
||||||
|
await self.chat_history_repository.add_message(request=message)
|
||||||
|
return {"response": "ok"}
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from .chat_history import RequestLastSessionId
|
from .chat_history import ChatMessage, RequestChatHistory, RequestLastSessionId
|
||||||
from .embedding import Embedding
|
from .embedding import Embedding
|
||||||
from .movies import Movie
|
from .movies import Movie
|
||||||
from .token import Token
|
from .token import Token
|
||||||
|
|
||||||
__all__ = ["Embedding", "Movie", "RequestLastSessionId", "Token"]
|
__all__ = ["ChatMessage", "Embedding", "Movie", "RequestChatHistory", "RequestLastSessionId", "Token"]
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
import uuid
|
||||||
|
|
||||||
import pydantic
|
import pydantic
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,3 +9,18 @@ class RequestLastSessionId(pydantic.BaseModel):
|
||||||
channel: str
|
channel: str
|
||||||
user_id: str
|
user_id: str
|
||||||
minutes_ago: int
|
minutes_ago: int
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMessage(pydantic.BaseModel):
|
||||||
|
"""A chat message."""
|
||||||
|
|
||||||
|
session_id: uuid.UUID
|
||||||
|
user_id: str
|
||||||
|
channel: str
|
||||||
|
message: dict[str, str]
|
||||||
|
|
||||||
|
|
||||||
|
class RequestChatHistory(pydantic.BaseModel):
|
||||||
|
"""Request for chat history."""
|
||||||
|
|
||||||
|
session_id: uuid.UUID
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from .base import Base, IdCreatedUpdatedBaseMixin
|
from .base import Base, IdCreatedUpdatedBaseMixin
|
||||||
from .movies import ChatHistory, FilmWork, Genre, GenreFilmWork, Person, PersonFilmWork
|
from .chat_history import ChatHistory
|
||||||
|
from .movies import FilmWork, Genre, GenreFilmWork, Person, PersonFilmWork
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Base",
|
"Base",
|
||||||
|
|
24
src/assistant/lib/orm_models/chat_history.py
Normal file
24
src/assistant/lib/orm_models/chat_history.py
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
import datetime
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
import sqlalchemy.orm as sa_orm
|
||||||
|
import sqlalchemy.sql as sa_sql
|
||||||
|
|
||||||
|
import lib.orm_models.base as base_models
|
||||||
|
|
||||||
|
|
||||||
|
class ChatHistory(base_models.Base):
|
||||||
|
__tablename__: str = "chat_history" # type: ignore[reportIncompatibleVariableOverride]
|
||||||
|
|
||||||
|
id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
|
||||||
|
session_id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(nullable=False, unique=True)
|
||||||
|
channel: sa_orm.Mapped[str] = sa_orm.mapped_column()
|
||||||
|
user_id: sa_orm.Mapped[str] = sa_orm.mapped_column()
|
||||||
|
content: sa_orm.Mapped[sa.JSON] = sa_orm.mapped_column(sa.JSON)
|
||||||
|
created: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
|
||||||
|
sa.DateTime(timezone=True), server_default=sa_sql.func.now()
|
||||||
|
)
|
||||||
|
modified: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
|
||||||
|
sa.DateTime(timezone=True), server_default=sa_sql.func.now(), onupdate=sa_sql.func.now()
|
||||||
|
)
|
|
@ -74,19 +74,3 @@ PersonFilmWork = sa.Table(
|
||||||
sa.Column("role", sa.String(50), nullable=False),
|
sa.Column("role", sa.String(50), nullable=False),
|
||||||
sa.Column("created", sa.DateTime(timezone=True), server_default=sa_sql.func.now()),
|
sa.Column("created", sa.DateTime(timezone=True), server_default=sa_sql.func.now()),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ChatHistory(base_models.Base):
|
|
||||||
__tablename__: str = "chat_history" # type: ignore[reportIncompatibleVariableOverride]
|
|
||||||
|
|
||||||
id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
|
|
||||||
session_id: sa_orm.Mapped[str] = sa_orm.mapped_column()
|
|
||||||
channel: sa_orm.Mapped[str] = sa_orm.mapped_column()
|
|
||||||
user_id: sa_orm.Mapped[str] = sa_orm.mapped_column()
|
|
||||||
content: sa_orm.Mapped[sa.JSON] = sa_orm.mapped_column(sa.JSON)
|
|
||||||
created: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
|
|
||||||
sa.DateTime(timezone=True), server_default=sa_sql.func.now()
|
|
||||||
)
|
|
||||||
modified: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
|
|
||||||
sa.DateTime(timezone=True), server_default=sa_sql.func.now(), onupdate=sa_sql.func.now()
|
|
||||||
)
|
|
||||||
|
|
30
src/assistant/poetry.lock
generated
30
src/assistant/poetry.lock
generated
|
@ -529,6 +529,20 @@ files = [
|
||||||
dnspython = ">=2.0.0"
|
dnspython = ">=2.0.0"
|
||||||
idna = ">=2.0.0"
|
idna = ">=2.0.0"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "faker"
|
||||||
|
version = "19.10.0"
|
||||||
|
description = "Faker is a Python package that generates fake data for you."
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
files = [
|
||||||
|
{file = "Faker-19.10.0-py3-none-any.whl", hash = "sha256:f321e657ed61616fbfe14dbb9ccc6b2e8282652bbcfcb503c1bd0231ff834df6"},
|
||||||
|
{file = "Faker-19.10.0.tar.gz", hash = "sha256:63da90512d0cb3acdb71bd833bb3071cb8a196020d08b8567a01d232954f1820"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
python-dateutil = ">=2.4"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "fastapi"
|
name = "fastapi"
|
||||||
version = "0.103.1"
|
version = "0.103.1"
|
||||||
|
@ -1622,6 +1636,20 @@ pluggy = ">=0.12,<2.0"
|
||||||
[package.extras]
|
[package.extras]
|
||||||
testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
|
testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "python-dateutil"
|
||||||
|
version = "2.8.2"
|
||||||
|
description = "Extensions to the standard Python datetime module"
|
||||||
|
optional = false
|
||||||
|
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
|
||||||
|
files = [
|
||||||
|
{file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"},
|
||||||
|
{file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
six = ">=1.5"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "python-dotenv"
|
name = "python-dotenv"
|
||||||
version = "1.0.0"
|
version = "1.0.0"
|
||||||
|
@ -2244,4 +2272,4 @@ multidict = ">=4.0"
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.11"
|
python-versions = "^3.11"
|
||||||
content-hash = "2d865a52e2e48b9700ca3f14caa3cbbc05c3ad3965866ab397cdd11e74958829"
|
content-hash = "5212b83adf6d4f20bc25f36c0f79039516153c470fce1975ed8a30596746a113"
|
||||||
|
|
|
@ -39,6 +39,7 @@ python-magic = "^0.4.27"
|
||||||
sqlalchemy = "^2.0.20"
|
sqlalchemy = "^2.0.20"
|
||||||
uvicorn = "^0.23.2"
|
uvicorn = "^0.23.2"
|
||||||
wrapt = "^1.15.0"
|
wrapt = "^1.15.0"
|
||||||
|
faker = "^19.10.0"
|
||||||
|
|
||||||
[tool.poetry.dev-dependencies]
|
[tool.poetry.dev-dependencies]
|
||||||
black = "^23.7.0"
|
black = "^23.7.0"
|
||||||
|
|
Loading…
Reference in New Issue
Block a user