diff --git a/src/assistant/alembic/env.py b/src/assistant/alembic/env.py index b506519..9d6b168 100644 --- a/src/assistant/alembic/env.py +++ b/src/assistant/alembic/env.py @@ -6,7 +6,6 @@ from sqlalchemy.engine import Connection from sqlalchemy.ext.asyncio import async_engine_from_config import lib.app.settings as app_settings -import lib.models as models import lib.orm_models as orm_models from alembic import context @@ -24,7 +23,7 @@ print("BASE: ", orm_models.Base.metadata.schema) for t in orm_models.Base.metadata.sorted_tables: print(t.name) -target_metadata = models.Base.metadata +target_metadata = orm_models.Base.metadata def run_migrations_offline() -> None: diff --git a/src/assistant/alembic/versions/2023-10-12_3d448c6327cd_init_commit.py b/src/assistant/alembic/versions/2023-10-12_3d448c6327cd_init_commit.py index 71b2cac..d40648e 100644 --- a/src/assistant/alembic/versions/2023-10-12_3d448c6327cd_init_commit.py +++ b/src/assistant/alembic/versions/2023-10-12_3d448c6327cd_init_commit.py @@ -24,7 +24,7 @@ def upgrade() -> None: op.create_table( "chat_history", sa.Column("id", sa.Uuid(), nullable=False), - sa.Column("session_id", sa.String(), nullable=False), + sa.Column("session_id", sa.Uuid(), nullable=False), sa.Column("channel", sa.String(), nullable=False), sa.Column("user_id", sa.String(), nullable=False), sa.Column("content", sa.JSON(), nullable=False), @@ -33,39 +33,6 @@ def upgrade() -> None: sa.PrimaryKeyConstraint("id"), schema="content", ) - op.drop_table("auth_group") - op.drop_table("auth_user_groups") - op.drop_table("auth_group_permissions") - op.drop_table("auth_user_user_permissions") - op.drop_table("auth_user") - op.drop_table("django_content_type") - op.drop_table("auth_permission") - op.drop_table("django_session") - op.drop_table("django_admin_log") - op.drop_table("django_migrations") - op.alter_column("film_work", "title", existing_type=sa.TEXT(), type_=sa.String(), existing_nullable=False) - op.alter_column("film_work", "description", existing_type=sa.TEXT(), type_=sa.String(), existing_nullable=True) - op.alter_column("film_work", "creation_date", existing_type=sa.DATE(), type_=sa.DateTime(), existing_nullable=True) - op.alter_column("film_work", "file_path", existing_type=sa.TEXT(), type_=sa.String(), existing_nullable=True) - op.alter_column("film_work", "type", existing_type=sa.TEXT(), type_=sa.String(), existing_nullable=False) - op.alter_column("film_work", "created", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=False) - op.alter_column("film_work", "modified", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=False) - op.alter_column("genre", "name", existing_type=sa.TEXT(), type_=sa.String(), existing_nullable=False) - op.alter_column("genre", "description", existing_type=sa.TEXT(), type_=sa.String(), existing_nullable=True) - op.alter_column("genre", "created", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=False) - op.alter_column("genre", "modified", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=False) - op.create_foreign_key(None, "genre_film_work", "genre", ["genre_id"], ["id"], referent_schema="content") - op.create_foreign_key(None, "genre_film_work", "film_work", ["film_work_id"], ["id"], referent_schema="content") - op.alter_column("person", "full_name", existing_type=sa.TEXT(), type_=sa.String(), existing_nullable=False) - op.alter_column("person", "created", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=False) - op.alter_column("person", "modified", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=False) - op.alter_column( - "person_film_work", "role", existing_type=sa.TEXT(), type_=sa.String(length=50), existing_nullable=False - ) - op.create_foreign_key(None, "person_film_work", "film_work", ["film_work_id"], ["id"], referent_schema="content") - op.create_foreign_key(None, "person_film_work", "person", ["person_id"], ["id"], referent_schema="content") - op.drop_column("person_film_work", "id") - # ### end Alembic commands ### def downgrade() -> None: diff --git a/src/assistant/lib/agent/chat_repository.py b/src/assistant/lib/agent/chat_repository.py index 435b128..959a1f6 100644 --- a/src/assistant/lib/agent/chat_repository.py +++ b/src/assistant/lib/agent/chat_repository.py @@ -15,7 +15,7 @@ class ChatHistoryRepository: self.logger = logging.getLogger(__name__) async def get_last_session_id(self, request: models.RequestLastSessionId) -> uuid.UUID | None: - """Get a new session ID.""" + """Get a current session ID if exists.""" try: async with self.pg_async_session() as session: @@ -23,10 +23,7 @@ class ChatHistoryRepository: sa.select(orm_models.ChatHistory) .filter_by(channel=request.channel, user_id=request.user_id) .filter( - ( - sa.func.extract("epoch", orm_models.ChatHistory.created) - - sa.func.extract("epoch", orm_models.ChatHistory.modified) / 60 - ) + (sa.text("NOW()") - sa.func.extract("epoch", orm_models.ChatHistory.created)) / 60 <= request.minutes_ago ) .order_by(orm_models.ChatHistory.created.desc()) @@ -39,3 +36,48 @@ class ChatHistoryRepository: return chat_session.id except sqlalchemy.exc.SQLAlchemyError as error: self.logger.exception("Error: %s", error) + + async def get_messages_by_sid(self, request: models.RequestChatHistory): + """Get all messages of a chat by session ID.""" + + try: + async with self.pg_async_session() as session: + statement = ( + sa.select(orm_models.ChatHistory) + .filter_by(id=request.session_id) + .order_by(orm_models.ChatHistory.created.desc()) + ) + result = await session.execute(statement) + for row in result.scalars().all(): + print("Row: ", row) + except sqlalchemy.exc.SQLAlchemyError as error: + self.logger.exception("Error: %s", error) + + # async def get_all_by_session_id(self, request: models.RequestChatHistory) -> list[models.ChatHistory]: + # try: + # async with self.pg_async_session() as session: + # statement = ( + # sa.select(orm_models.ChatHistory) + # .filter_by(id=request.session_id) + # .order_by(orm_models.ChatHistory.created.desc()) + # ) + # result = await session.execute(statement) + + # return [models.ChatHistory.from_orm(chat_history) for chat_history in result.scalars().all()] + + async def add_message(self, request: models.ChatMessage) -> None: + """Add a message to the chat history.""" + try: + async with self.pg_async_session() as session: + chat_history = orm_models.ChatHistory( + id=uuid.uuid4(), + session_id=request.session_id, + user_id=request.user_id, + channel=request.channel, + content=request.message, + ) + session.add(chat_history) + await session.commit() + # TODO: Add refresh to session and return added object + except sqlalchemy.exc.SQLAlchemyError as error: + self.logger.exception("Error: %s", error) diff --git a/src/assistant/lib/api/v1/handlers/agent.py b/src/assistant/lib/api/v1/handlers/agent.py index 40d90be..0fce3a9 100644 --- a/src/assistant/lib/api/v1/handlers/agent.py +++ b/src/assistant/lib/api/v1/handlers/agent.py @@ -1,3 +1,5 @@ +import uuid + import fastapi import lib.agent as agent @@ -15,9 +17,27 @@ class AgentHandler: summary="Статус работоспособности", description="Проверяет доступность сервиса FastAPI.", ) + self.router.add_api_route( + "/add", + self.add_message, + methods=["GET"], + summary="Статус работоспособности", + description="Проверяет доступность сервиса FastAPI.", + ) async def get_agent(self): - request = models.RequestLastSessionId(channel="test", user_id="test", minutes_ago=3) + request = models.RequestLastSessionId(channel="test", user_id="user_id_1", minutes_ago=3) response = await self.chat_history_repository.get_last_session_id(request=request) print("RESPONSE: ", response) return {"response": response} + + async def add_message(self): + sid: uuid.UUID = uuid.UUID("0cd3c882-affd-4929-aff1-e1724f5b54f2") + import faker + fake = faker.Faker() + + message = models.ChatMessage( + session_id=sid, user_id="user_id_1", channel="test", message={"role": "system", "content": fake.sentence()} + ) + await self.chat_history_repository.add_message(request=message) + return {"response": "ok"} diff --git a/src/assistant/lib/models/__init__.py b/src/assistant/lib/models/__init__.py index 1995cc2..51dd416 100644 --- a/src/assistant/lib/models/__init__.py +++ b/src/assistant/lib/models/__init__.py @@ -1,6 +1,6 @@ -from .chat_history import RequestLastSessionId +from .chat_history import ChatMessage, RequestChatHistory, RequestLastSessionId from .embedding import Embedding from .movies import Movie from .token import Token -__all__ = ["Embedding", "Movie", "RequestLastSessionId", "Token"] +__all__ = ["ChatMessage", "Embedding", "Movie", "RequestChatHistory", "RequestLastSessionId", "Token"] diff --git a/src/assistant/lib/models/chat_history.py b/src/assistant/lib/models/chat_history.py index 264eaa5..20c60e2 100644 --- a/src/assistant/lib/models/chat_history.py +++ b/src/assistant/lib/models/chat_history.py @@ -1,3 +1,5 @@ +import uuid + import pydantic @@ -7,3 +9,18 @@ class RequestLastSessionId(pydantic.BaseModel): channel: str user_id: str minutes_ago: int + + +class ChatMessage(pydantic.BaseModel): + """A chat message.""" + + session_id: uuid.UUID + user_id: str + channel: str + message: dict[str, str] + + +class RequestChatHistory(pydantic.BaseModel): + """Request for chat history.""" + + session_id: uuid.UUID diff --git a/src/assistant/lib/orm_models/__init__.py b/src/assistant/lib/orm_models/__init__.py index afb551e..869102e 100644 --- a/src/assistant/lib/orm_models/__init__.py +++ b/src/assistant/lib/orm_models/__init__.py @@ -1,5 +1,6 @@ from .base import Base, IdCreatedUpdatedBaseMixin -from .movies import ChatHistory, FilmWork, Genre, GenreFilmWork, Person, PersonFilmWork +from .chat_history import ChatHistory +from .movies import FilmWork, Genre, GenreFilmWork, Person, PersonFilmWork __all__ = [ "Base", diff --git a/src/assistant/lib/orm_models/chat_history.py b/src/assistant/lib/orm_models/chat_history.py new file mode 100644 index 0000000..04742e9 --- /dev/null +++ b/src/assistant/lib/orm_models/chat_history.py @@ -0,0 +1,24 @@ +import datetime +import uuid + +import sqlalchemy as sa +import sqlalchemy.orm as sa_orm +import sqlalchemy.sql as sa_sql + +import lib.orm_models.base as base_models + + +class ChatHistory(base_models.Base): + __tablename__: str = "chat_history" # type: ignore[reportIncompatibleVariableOverride] + + id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4) + session_id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(nullable=False, unique=True) + channel: sa_orm.Mapped[str] = sa_orm.mapped_column() + user_id: sa_orm.Mapped[str] = sa_orm.mapped_column() + content: sa_orm.Mapped[sa.JSON] = sa_orm.mapped_column(sa.JSON) + created: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column( + sa.DateTime(timezone=True), server_default=sa_sql.func.now() + ) + modified: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column( + sa.DateTime(timezone=True), server_default=sa_sql.func.now(), onupdate=sa_sql.func.now() + ) diff --git a/src/assistant/lib/orm_models/movies.py b/src/assistant/lib/orm_models/movies.py index b9f8167..88082e2 100644 --- a/src/assistant/lib/orm_models/movies.py +++ b/src/assistant/lib/orm_models/movies.py @@ -74,19 +74,3 @@ PersonFilmWork = sa.Table( sa.Column("role", sa.String(50), nullable=False), sa.Column("created", sa.DateTime(timezone=True), server_default=sa_sql.func.now()), ) - - -class ChatHistory(base_models.Base): - __tablename__: str = "chat_history" # type: ignore[reportIncompatibleVariableOverride] - - id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4) - session_id: sa_orm.Mapped[str] = sa_orm.mapped_column() - channel: sa_orm.Mapped[str] = sa_orm.mapped_column() - user_id: sa_orm.Mapped[str] = sa_orm.mapped_column() - content: sa_orm.Mapped[sa.JSON] = sa_orm.mapped_column(sa.JSON) - created: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column( - sa.DateTime(timezone=True), server_default=sa_sql.func.now() - ) - modified: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column( - sa.DateTime(timezone=True), server_default=sa_sql.func.now(), onupdate=sa_sql.func.now() - ) diff --git a/src/assistant/poetry.lock b/src/assistant/poetry.lock index d5810e4..bcf762a 100644 --- a/src/assistant/poetry.lock +++ b/src/assistant/poetry.lock @@ -529,6 +529,20 @@ files = [ dnspython = ">=2.0.0" idna = ">=2.0.0" +[[package]] +name = "faker" +version = "19.10.0" +description = "Faker is a Python package that generates fake data for you." +optional = false +python-versions = ">=3.8" +files = [ + {file = "Faker-19.10.0-py3-none-any.whl", hash = "sha256:f321e657ed61616fbfe14dbb9ccc6b2e8282652bbcfcb503c1bd0231ff834df6"}, + {file = "Faker-19.10.0.tar.gz", hash = "sha256:63da90512d0cb3acdb71bd833bb3071cb8a196020d08b8567a01d232954f1820"}, +] + +[package.dependencies] +python-dateutil = ">=2.4" + [[package]] name = "fastapi" version = "0.103.1" @@ -1622,6 +1636,20 @@ pluggy = ">=0.12,<2.0" [package.extras] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "python-dateutil" +version = "2.8.2" +description = "Extensions to the standard Python datetime module" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +files = [ + {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, + {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, +] + +[package.dependencies] +six = ">=1.5" + [[package]] name = "python-dotenv" version = "1.0.0" @@ -2244,4 +2272,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "2d865a52e2e48b9700ca3f14caa3cbbc05c3ad3965866ab397cdd11e74958829" +content-hash = "5212b83adf6d4f20bc25f36c0f79039516153c470fce1975ed8a30596746a113" diff --git a/src/assistant/pyproject.toml b/src/assistant/pyproject.toml index dfdfe2e..9d4a180 100644 --- a/src/assistant/pyproject.toml +++ b/src/assistant/pyproject.toml @@ -39,6 +39,7 @@ python-magic = "^0.4.27" sqlalchemy = "^2.0.20" uvicorn = "^0.23.2" wrapt = "^1.15.0" +faker = "^19.10.0" [tool.poetry.dev-dependencies] black = "^23.7.0"