From b75af6034b47877ad6e639e4aedf9d413d1bace6 Mon Sep 17 00:00:00 2001 From: Artem Litvinov Date: Thu, 12 Oct 2023 02:02:39 +0100 Subject: [PATCH] fix: fixing types partially --- src/assistant/lib/agent/chat_repository.py | 4 +- src/assistant/lib/agent/openai_functions.py | 10 +- src/assistant/lib/agent/orm_models.py | 92 ------------------- src/assistant/lib/agent/repositories.py | 4 +- src/assistant/lib/agent/services.py | 3 +- src/assistant/lib/models/__init__.py | 4 +- src/assistant/lib/models/embedding.py | 5 + .../lib/{agent/models.py => models/movies.py} | 4 - src/assistant/lib/orm_models/movies.py | 10 +- src/assistant/pyproject.toml | 3 +- 10 files changed, 25 insertions(+), 114 deletions(-) delete mode 100644 src/assistant/lib/agent/orm_models.py create mode 100644 src/assistant/lib/models/embedding.py rename src/assistant/lib/{agent/models.py => models/movies.py} (76%) diff --git a/src/assistant/lib/agent/chat_repository.py b/src/assistant/lib/agent/chat_repository.py index 4fd311b..435b128 100644 --- a/src/assistant/lib/agent/chat_repository.py +++ b/src/assistant/lib/agent/chat_repository.py @@ -31,9 +31,9 @@ class ChatHistoryRepository: ) .order_by(orm_models.ChatHistory.created.desc()) .limit(1) - ) + ) result = await session.execute(statement) - + chat_session = result.scalars().first() if chat_session: return chat_session.id diff --git a/src/assistant/lib/agent/openai_functions.py b/src/assistant/lib/agent/openai_functions.py index 4da4362..673a1ba 100644 --- a/src/assistant/lib/agent/openai_functions.py +++ b/src/assistant/lib/agent/openai_functions.py @@ -2,13 +2,13 @@ import logging import uuid import langchain.agents -import models import orm_models import sqlalchemy as sa import sqlalchemy.exc import sqlalchemy.ext.asyncio as sa_asyncio import lib.agent.repositories as repositories +import lib.models as models class OpenAIFunctions: @@ -25,7 +25,7 @@ class OpenAIFunctions: @langchain.agents.tool async def get_movie_by_description(self, description: str) -> list[models.Movie] | None: - """Returns a movie data by description.""" + """Provide a movie data by description.""" self.logger.info("Request to get movie by description: %s", description) embedded_description = await self.repository.aget_embedding(description) @@ -39,7 +39,6 @@ class OpenAIFunctions: ) neighbours = session.scalars(stmt) for neighbour in await neighbours: - print(neighbour.title) result.append(models.Movie(**neighbour.__dict__)) return result except sqlalchemy.exc.SQLAlchemyError as error: @@ -47,12 +46,13 @@ class OpenAIFunctions: @langchain.agents.tool def get_movie_by_id(self, id: uuid.UUID) -> models.Movie | None: - """Returns a movie data by movie id.""" + """Provide a movie data by movie id.""" self.logger.info("Request to get movie by id: %s", id) return None @langchain.agents.tool def get_similar_movies(self, id: uuid.UUID) -> list[models.Movie] | None: - """Returns a similar movies to movie with movie id.""" + """Provide similar movies for the given movie ID.""" + self.logger.info("Request to get movie by id: %s", id) return None diff --git a/src/assistant/lib/agent/orm_models.py b/src/assistant/lib/agent/orm_models.py deleted file mode 100644 index fe61e56..0000000 --- a/src/assistant/lib/agent/orm_models.py +++ /dev/null @@ -1,92 +0,0 @@ -import datetime -import uuid - -import pgvector.sqlalchemy -import sqlalchemy as sa -import sqlalchemy.orm as sa_orm -import sqlalchemy.sql as sa_sql - -import lib.models.orm as base_models - - -class Genre(base_models.Base): - __tablename__: str = "genre" - - id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4) - name: sa_orm.Mapped[str] = sa_orm.mapped_column() - description: sa_orm.Mapped[str] = sa_orm.mapped_column(nullable=True) - created: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column( - sa.DateTime(timezone=True), server_default=sa_sql.func.now() - ) - modified: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column( - sa.DateTime(timezone=True), server_default=sa_sql.func.now(), onupdate=sa_sql.func.now() - ) - - -class Person(base_models.Base): - __tablename__: str = "person" - - id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4) - full_name: sa_orm.Mapped[str] = sa_orm.mapped_column() - created: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column( - sa.DateTime(timezone=True), server_default=sa_sql.func.now() - ) - modified: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column( - sa.DateTime(timezone=True), server_default=sa_sql.func.now(), onupdate=sa_sql.func.now() - ) - - -class FilmWork(base_models.Base): - __tablename__: str = "film_work" - - id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4) - title: sa_orm.Mapped[str] - description: sa_orm.Mapped[str] = sa_orm.mapped_column(nullable=True) - creation_date: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(nullable=True) - file_path: sa_orm.Mapped[str] = sa_orm.mapped_column(nullable=True) - rating: sa_orm.Mapped[float] = sa_orm.mapped_column(nullable=True) - type: sa_orm.Mapped[str] = sa_orm.mapped_column() - created: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column( - sa.DateTime(timezone=True), server_default=sa_sql.func.now() - ) - modified: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column( - sa.DateTime(timezone=True), server_default=sa_sql.func.now(), onupdate=sa_sql.func.now() - ) - embedding: sa_orm.Mapped[list[float]] = sa_orm.mapped_column(pgvector.sqlalchemy.Vector(1536)) - genres: sa_orm.Mapped[list[Genre]] = sa_orm.relationship(secondary="genre_film_work") - - -GenreFilmWork = sa.Table( - "genre_film_work", - base_models.Base.metadata, - sa.Column("id", sa.UUID, primary_key=True), # type: ignore[reportUnknownVariableType] - sa.Column("genre_id", sa.ForeignKey(Genre.id), primary_key=True), # type: ignore[reportUnknownVariableType] - sa.Column("film_work_id", sa.ForeignKey(FilmWork.id), primary_key=True), # type: ignore[reportUnknownVariableType] - sa.Column("created", sa.DateTime(timezone=True), server_default=sa_sql.func.now()), -) - - -PersonFilmWork = sa.Table( - "person_film_work", - base_models.Base.metadata, - sa.Column("person_id", sa.ForeignKey(Person.id), primary_key=True), # type: ignore[reportUnknownVariableType] - sa.Column("film_work_id", sa.ForeignKey(FilmWork.id), primary_key=True), # type: ignore[reportUnknownVariableType] - sa.Column("role", sa.String(50), nullable=False), - sa.Column("created", sa.DateTime(timezone=True), server_default=sa_sql.func.now()), -) - - -class ChatHistory(base_models.Base): - __tablename__: str = "chat_history" - - id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4) - session_id: sa_orm.Mapped[str] = sa_orm.mapped_column() - channel: sa_orm.Mapped[str] = sa_orm.mapped_column() - user_id: sa_orm.Mapped[str] = sa_orm.mapped_column() - content: sa_orm.Mapped[sa.JSON] = sa_orm.mapped_column(sa.JSON) - created: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column( - sa.DateTime(timezone=True), server_default=sa_sql.func.now() - ) - modified: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column( - sa.DateTime(timezone=True), server_default=sa_sql.func.now(), onupdate=sa_sql.func.now() - ) diff --git a/src/assistant/lib/agent/repositories.py b/src/assistant/lib/agent/repositories.py index af25bbe..f8fe6f1 100644 --- a/src/assistant/lib/agent/repositories.py +++ b/src/assistant/lib/agent/repositories.py @@ -4,8 +4,8 @@ import typing import openai import openai.error -import lib.agent.models as models import lib.app.settings as app_settings +import lib.models as models class EmbeddingRepository: @@ -23,7 +23,7 @@ class EmbeddingRepository: input=text, model=model, ) # type: ignore[reportGeneralTypeIssues] - return response["data"][0]["embedding"] + return models.Embedding(**response["data"][0]["embedding"]) except openai.error.OpenAIError: self.logger.exception("Failed to get async embedding for: %s", text) diff --git a/src/assistant/lib/agent/services.py b/src/assistant/lib/agent/services.py index f86248a..83f89a4 100644 --- a/src/assistant/lib/agent/services.py +++ b/src/assistant/lib/agent/services.py @@ -10,9 +10,8 @@ import langchain.prompts import langchain.schema.agent import langchain.schema.messages import langchain.tools.render -import models -import varname +import assistant.lib.models.movies as movies import lib.agent.openai_functions as openai_functions import lib.app.settings as app_settings diff --git a/src/assistant/lib/models/__init__.py b/src/assistant/lib/models/__init__.py index a7cb462..1995cc2 100644 --- a/src/assistant/lib/models/__init__.py +++ b/src/assistant/lib/models/__init__.py @@ -1,4 +1,6 @@ from .chat_history import RequestLastSessionId +from .embedding import Embedding +from .movies import Movie from .token import Token -__all__ = ["RequestLastSessionId", "Token"] +__all__ = ["Embedding", "Movie", "RequestLastSessionId", "Token"] diff --git a/src/assistant/lib/models/embedding.py b/src/assistant/lib/models/embedding.py new file mode 100644 index 0000000..3978dc9 --- /dev/null +++ b/src/assistant/lib/models/embedding.py @@ -0,0 +1,5 @@ +import pydantic + + +class Embedding(pydantic.RootModel[list[float]]): + root: list[float] diff --git a/src/assistant/lib/agent/models.py b/src/assistant/lib/models/movies.py similarity index 76% rename from src/assistant/lib/agent/models.py rename to src/assistant/lib/models/movies.py index a1fc33c..e432111 100644 --- a/src/assistant/lib/agent/models.py +++ b/src/assistant/lib/models/movies.py @@ -12,7 +12,3 @@ class Movie(pydantic.BaseModel): type: str created: datetime.datetime modified: datetime.datetime - - -class Embedding(pydantic.RootModel[list[float]]): - root: list[float] diff --git a/src/assistant/lib/orm_models/movies.py b/src/assistant/lib/orm_models/movies.py index ee26811..b9f8167 100644 --- a/src/assistant/lib/orm_models/movies.py +++ b/src/assistant/lib/orm_models/movies.py @@ -10,7 +10,7 @@ import lib.orm_models.base as base_models class Genre(base_models.Base): - __tablename__: str = "genre" + __tablename__: str = "genre" # type: ignore[reportIncompatibleVariableOverride] id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4) name: sa_orm.Mapped[str] = sa_orm.mapped_column() @@ -24,7 +24,7 @@ class Genre(base_models.Base): class Person(base_models.Base): - __tablename__: str = "person" + __tablename__: str = "person" # type: ignore[reportIncompatibleVariableOverride] id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4) full_name: sa_orm.Mapped[str] = sa_orm.mapped_column() @@ -37,10 +37,10 @@ class Person(base_models.Base): class FilmWork(base_models.Base): - __tablename__: str = "film_work" + __tablename__: str = "film_work" # type: ignore[reportIncompatibleVariableOverride] id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4) - title: sa_orm.Mapped[str] + title: sa_orm.Mapped[str] = sa_orm.mapped_column() description: sa_orm.Mapped[str] = sa_orm.mapped_column(nullable=True) creation_date: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(nullable=True) file_path: sa_orm.Mapped[str] = sa_orm.mapped_column(nullable=True) @@ -77,7 +77,7 @@ PersonFilmWork = sa.Table( class ChatHistory(base_models.Base): - __tablename__: str = "chat_history" + __tablename__: str = "chat_history" # type: ignore[reportIncompatibleVariableOverride] id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4) session_id: sa_orm.Mapped[str] = sa_orm.mapped_column() diff --git a/src/assistant/pyproject.toml b/src/assistant/pyproject.toml index 66ace5b..dfdfe2e 100644 --- a/src/assistant/pyproject.toml +++ b/src/assistant/pyproject.toml @@ -92,7 +92,8 @@ variable-rgx = "^_{0,2}[a-z][a-z0-9_]*$" [tool.pyright] exclude = [ - ".venv" + ".venv", + "alembic" ] pythonPlatform = "All" pythonVersion = "3.11"