1
0
mirror of https://github.com/ijaric/voice_assistant.git synced 2025-05-24 14:33:26 +00:00

fix: fixing types partially

This commit is contained in:
Artem Litvinov 2023-10-12 02:02:39 +01:00
parent 98bb7cbb9c
commit b75af6034b
10 changed files with 25 additions and 114 deletions

View File

@ -31,9 +31,9 @@ class ChatHistoryRepository:
) )
.order_by(orm_models.ChatHistory.created.desc()) .order_by(orm_models.ChatHistory.created.desc())
.limit(1) .limit(1)
) )
result = await session.execute(statement) result = await session.execute(statement)
chat_session = result.scalars().first() chat_session = result.scalars().first()
if chat_session: if chat_session:
return chat_session.id return chat_session.id

View File

@ -2,13 +2,13 @@ import logging
import uuid import uuid
import langchain.agents import langchain.agents
import models
import orm_models import orm_models
import sqlalchemy as sa import sqlalchemy as sa
import sqlalchemy.exc import sqlalchemy.exc
import sqlalchemy.ext.asyncio as sa_asyncio import sqlalchemy.ext.asyncio as sa_asyncio
import lib.agent.repositories as repositories import lib.agent.repositories as repositories
import lib.models as models
class OpenAIFunctions: class OpenAIFunctions:
@ -25,7 +25,7 @@ class OpenAIFunctions:
@langchain.agents.tool @langchain.agents.tool
async def get_movie_by_description(self, description: str) -> list[models.Movie] | None: async def get_movie_by_description(self, description: str) -> list[models.Movie] | None:
"""Returns a movie data by description.""" """Provide a movie data by description."""
self.logger.info("Request to get movie by description: %s", description) self.logger.info("Request to get movie by description: %s", description)
embedded_description = await self.repository.aget_embedding(description) embedded_description = await self.repository.aget_embedding(description)
@ -39,7 +39,6 @@ class OpenAIFunctions:
) )
neighbours = session.scalars(stmt) neighbours = session.scalars(stmt)
for neighbour in await neighbours: for neighbour in await neighbours:
print(neighbour.title)
result.append(models.Movie(**neighbour.__dict__)) result.append(models.Movie(**neighbour.__dict__))
return result return result
except sqlalchemy.exc.SQLAlchemyError as error: except sqlalchemy.exc.SQLAlchemyError as error:
@ -47,12 +46,13 @@ class OpenAIFunctions:
@langchain.agents.tool @langchain.agents.tool
def get_movie_by_id(self, id: uuid.UUID) -> models.Movie | None: def get_movie_by_id(self, id: uuid.UUID) -> models.Movie | None:
"""Returns a movie data by movie id.""" """Provide a movie data by movie id."""
self.logger.info("Request to get movie by id: %s", id) self.logger.info("Request to get movie by id: %s", id)
return None return None
@langchain.agents.tool @langchain.agents.tool
def get_similar_movies(self, id: uuid.UUID) -> list[models.Movie] | None: def get_similar_movies(self, id: uuid.UUID) -> list[models.Movie] | None:
"""Returns a similar movies to movie with movie id.""" """Provide similar movies for the given movie ID."""
self.logger.info("Request to get movie by id: %s", id) self.logger.info("Request to get movie by id: %s", id)
return None return None

View File

@ -1,92 +0,0 @@
import datetime
import uuid
import pgvector.sqlalchemy
import sqlalchemy as sa
import sqlalchemy.orm as sa_orm
import sqlalchemy.sql as sa_sql
import lib.models.orm as base_models
class Genre(base_models.Base):
__tablename__: str = "genre"
id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
name: sa_orm.Mapped[str] = sa_orm.mapped_column()
description: sa_orm.Mapped[str] = sa_orm.mapped_column(nullable=True)
created: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
sa.DateTime(timezone=True), server_default=sa_sql.func.now()
)
modified: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
sa.DateTime(timezone=True), server_default=sa_sql.func.now(), onupdate=sa_sql.func.now()
)
class Person(base_models.Base):
__tablename__: str = "person"
id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
full_name: sa_orm.Mapped[str] = sa_orm.mapped_column()
created: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
sa.DateTime(timezone=True), server_default=sa_sql.func.now()
)
modified: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
sa.DateTime(timezone=True), server_default=sa_sql.func.now(), onupdate=sa_sql.func.now()
)
class FilmWork(base_models.Base):
__tablename__: str = "film_work"
id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
title: sa_orm.Mapped[str]
description: sa_orm.Mapped[str] = sa_orm.mapped_column(nullable=True)
creation_date: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(nullable=True)
file_path: sa_orm.Mapped[str] = sa_orm.mapped_column(nullable=True)
rating: sa_orm.Mapped[float] = sa_orm.mapped_column(nullable=True)
type: sa_orm.Mapped[str] = sa_orm.mapped_column()
created: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
sa.DateTime(timezone=True), server_default=sa_sql.func.now()
)
modified: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
sa.DateTime(timezone=True), server_default=sa_sql.func.now(), onupdate=sa_sql.func.now()
)
embedding: sa_orm.Mapped[list[float]] = sa_orm.mapped_column(pgvector.sqlalchemy.Vector(1536))
genres: sa_orm.Mapped[list[Genre]] = sa_orm.relationship(secondary="genre_film_work")
GenreFilmWork = sa.Table(
"genre_film_work",
base_models.Base.metadata,
sa.Column("id", sa.UUID, primary_key=True), # type: ignore[reportUnknownVariableType]
sa.Column("genre_id", sa.ForeignKey(Genre.id), primary_key=True), # type: ignore[reportUnknownVariableType]
sa.Column("film_work_id", sa.ForeignKey(FilmWork.id), primary_key=True), # type: ignore[reportUnknownVariableType]
sa.Column("created", sa.DateTime(timezone=True), server_default=sa_sql.func.now()),
)
PersonFilmWork = sa.Table(
"person_film_work",
base_models.Base.metadata,
sa.Column("person_id", sa.ForeignKey(Person.id), primary_key=True), # type: ignore[reportUnknownVariableType]
sa.Column("film_work_id", sa.ForeignKey(FilmWork.id), primary_key=True), # type: ignore[reportUnknownVariableType]
sa.Column("role", sa.String(50), nullable=False),
sa.Column("created", sa.DateTime(timezone=True), server_default=sa_sql.func.now()),
)
class ChatHistory(base_models.Base):
__tablename__: str = "chat_history"
id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
session_id: sa_orm.Mapped[str] = sa_orm.mapped_column()
channel: sa_orm.Mapped[str] = sa_orm.mapped_column()
user_id: sa_orm.Mapped[str] = sa_orm.mapped_column()
content: sa_orm.Mapped[sa.JSON] = sa_orm.mapped_column(sa.JSON)
created: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
sa.DateTime(timezone=True), server_default=sa_sql.func.now()
)
modified: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
sa.DateTime(timezone=True), server_default=sa_sql.func.now(), onupdate=sa_sql.func.now()
)

View File

@ -4,8 +4,8 @@ import typing
import openai import openai
import openai.error import openai.error
import lib.agent.models as models
import lib.app.settings as app_settings import lib.app.settings as app_settings
import lib.models as models
class EmbeddingRepository: class EmbeddingRepository:
@ -23,7 +23,7 @@ class EmbeddingRepository:
input=text, input=text,
model=model, model=model,
) # type: ignore[reportGeneralTypeIssues] ) # type: ignore[reportGeneralTypeIssues]
return response["data"][0]["embedding"] return models.Embedding(**response["data"][0]["embedding"])
except openai.error.OpenAIError: except openai.error.OpenAIError:
self.logger.exception("Failed to get async embedding for: %s", text) self.logger.exception("Failed to get async embedding for: %s", text)

View File

@ -10,9 +10,8 @@ import langchain.prompts
import langchain.schema.agent import langchain.schema.agent
import langchain.schema.messages import langchain.schema.messages
import langchain.tools.render import langchain.tools.render
import models
import varname
import assistant.lib.models.movies as movies
import lib.agent.openai_functions as openai_functions import lib.agent.openai_functions as openai_functions
import lib.app.settings as app_settings import lib.app.settings as app_settings

View File

@ -1,4 +1,6 @@
from .chat_history import RequestLastSessionId from .chat_history import RequestLastSessionId
from .embedding import Embedding
from .movies import Movie
from .token import Token from .token import Token
__all__ = ["RequestLastSessionId", "Token"] __all__ = ["Embedding", "Movie", "RequestLastSessionId", "Token"]

View File

@ -0,0 +1,5 @@
import pydantic
class Embedding(pydantic.RootModel[list[float]]):
root: list[float]

View File

@ -12,7 +12,3 @@ class Movie(pydantic.BaseModel):
type: str type: str
created: datetime.datetime created: datetime.datetime
modified: datetime.datetime modified: datetime.datetime
class Embedding(pydantic.RootModel[list[float]]):
root: list[float]

View File

@ -10,7 +10,7 @@ import lib.orm_models.base as base_models
class Genre(base_models.Base): class Genre(base_models.Base):
__tablename__: str = "genre" __tablename__: str = "genre" # type: ignore[reportIncompatibleVariableOverride]
id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4) id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
name: sa_orm.Mapped[str] = sa_orm.mapped_column() name: sa_orm.Mapped[str] = sa_orm.mapped_column()
@ -24,7 +24,7 @@ class Genre(base_models.Base):
class Person(base_models.Base): class Person(base_models.Base):
__tablename__: str = "person" __tablename__: str = "person" # type: ignore[reportIncompatibleVariableOverride]
id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4) id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
full_name: sa_orm.Mapped[str] = sa_orm.mapped_column() full_name: sa_orm.Mapped[str] = sa_orm.mapped_column()
@ -37,10 +37,10 @@ class Person(base_models.Base):
class FilmWork(base_models.Base): class FilmWork(base_models.Base):
__tablename__: str = "film_work" __tablename__: str = "film_work" # type: ignore[reportIncompatibleVariableOverride]
id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4) id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
title: sa_orm.Mapped[str] title: sa_orm.Mapped[str] = sa_orm.mapped_column()
description: sa_orm.Mapped[str] = sa_orm.mapped_column(nullable=True) description: sa_orm.Mapped[str] = sa_orm.mapped_column(nullable=True)
creation_date: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(nullable=True) creation_date: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(nullable=True)
file_path: sa_orm.Mapped[str] = sa_orm.mapped_column(nullable=True) file_path: sa_orm.Mapped[str] = sa_orm.mapped_column(nullable=True)
@ -77,7 +77,7 @@ PersonFilmWork = sa.Table(
class ChatHistory(base_models.Base): class ChatHistory(base_models.Base):
__tablename__: str = "chat_history" __tablename__: str = "chat_history" # type: ignore[reportIncompatibleVariableOverride]
id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4) id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
session_id: sa_orm.Mapped[str] = sa_orm.mapped_column() session_id: sa_orm.Mapped[str] = sa_orm.mapped_column()

View File

@ -92,7 +92,8 @@ variable-rgx = "^_{0,2}[a-z][a-z0-9_]*$"
[tool.pyright] [tool.pyright]
exclude = [ exclude = [
".venv" ".venv",
"alembic"
] ]
pythonPlatform = "All" pythonPlatform = "All"
pythonVersion = "3.11" pythonVersion = "3.11"