mirror of
https://github.com/ijaric/voice_assistant.git
synced 2025-05-24 14:33:26 +00:00
fix: fixing types partially
This commit is contained in:
parent
98bb7cbb9c
commit
b75af6034b
|
@ -31,9 +31,9 @@ class ChatHistoryRepository:
|
||||||
)
|
)
|
||||||
.order_by(orm_models.ChatHistory.created.desc())
|
.order_by(orm_models.ChatHistory.created.desc())
|
||||||
.limit(1)
|
.limit(1)
|
||||||
)
|
)
|
||||||
result = await session.execute(statement)
|
result = await session.execute(statement)
|
||||||
|
|
||||||
chat_session = result.scalars().first()
|
chat_session = result.scalars().first()
|
||||||
if chat_session:
|
if chat_session:
|
||||||
return chat_session.id
|
return chat_session.id
|
||||||
|
|
|
@ -2,13 +2,13 @@ import logging
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
import langchain.agents
|
import langchain.agents
|
||||||
import models
|
|
||||||
import orm_models
|
import orm_models
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
import sqlalchemy.exc
|
import sqlalchemy.exc
|
||||||
import sqlalchemy.ext.asyncio as sa_asyncio
|
import sqlalchemy.ext.asyncio as sa_asyncio
|
||||||
|
|
||||||
import lib.agent.repositories as repositories
|
import lib.agent.repositories as repositories
|
||||||
|
import lib.models as models
|
||||||
|
|
||||||
|
|
||||||
class OpenAIFunctions:
|
class OpenAIFunctions:
|
||||||
|
@ -25,7 +25,7 @@ class OpenAIFunctions:
|
||||||
|
|
||||||
@langchain.agents.tool
|
@langchain.agents.tool
|
||||||
async def get_movie_by_description(self, description: str) -> list[models.Movie] | None:
|
async def get_movie_by_description(self, description: str) -> list[models.Movie] | None:
|
||||||
"""Returns a movie data by description."""
|
"""Provide a movie data by description."""
|
||||||
|
|
||||||
self.logger.info("Request to get movie by description: %s", description)
|
self.logger.info("Request to get movie by description: %s", description)
|
||||||
embedded_description = await self.repository.aget_embedding(description)
|
embedded_description = await self.repository.aget_embedding(description)
|
||||||
|
@ -39,7 +39,6 @@ class OpenAIFunctions:
|
||||||
)
|
)
|
||||||
neighbours = session.scalars(stmt)
|
neighbours = session.scalars(stmt)
|
||||||
for neighbour in await neighbours:
|
for neighbour in await neighbours:
|
||||||
print(neighbour.title)
|
|
||||||
result.append(models.Movie(**neighbour.__dict__))
|
result.append(models.Movie(**neighbour.__dict__))
|
||||||
return result
|
return result
|
||||||
except sqlalchemy.exc.SQLAlchemyError as error:
|
except sqlalchemy.exc.SQLAlchemyError as error:
|
||||||
|
@ -47,12 +46,13 @@ class OpenAIFunctions:
|
||||||
|
|
||||||
@langchain.agents.tool
|
@langchain.agents.tool
|
||||||
def get_movie_by_id(self, id: uuid.UUID) -> models.Movie | None:
|
def get_movie_by_id(self, id: uuid.UUID) -> models.Movie | None:
|
||||||
"""Returns a movie data by movie id."""
|
"""Provide a movie data by movie id."""
|
||||||
self.logger.info("Request to get movie by id: %s", id)
|
self.logger.info("Request to get movie by id: %s", id)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@langchain.agents.tool
|
@langchain.agents.tool
|
||||||
def get_similar_movies(self, id: uuid.UUID) -> list[models.Movie] | None:
|
def get_similar_movies(self, id: uuid.UUID) -> list[models.Movie] | None:
|
||||||
"""Returns a similar movies to movie with movie id."""
|
"""Provide similar movies for the given movie ID."""
|
||||||
|
|
||||||
self.logger.info("Request to get movie by id: %s", id)
|
self.logger.info("Request to get movie by id: %s", id)
|
||||||
return None
|
return None
|
||||||
|
|
|
@ -1,92 +0,0 @@
|
||||||
import datetime
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
import pgvector.sqlalchemy
|
|
||||||
import sqlalchemy as sa
|
|
||||||
import sqlalchemy.orm as sa_orm
|
|
||||||
import sqlalchemy.sql as sa_sql
|
|
||||||
|
|
||||||
import lib.models.orm as base_models
|
|
||||||
|
|
||||||
|
|
||||||
class Genre(base_models.Base):
|
|
||||||
__tablename__: str = "genre"
|
|
||||||
|
|
||||||
id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
|
|
||||||
name: sa_orm.Mapped[str] = sa_orm.mapped_column()
|
|
||||||
description: sa_orm.Mapped[str] = sa_orm.mapped_column(nullable=True)
|
|
||||||
created: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
|
|
||||||
sa.DateTime(timezone=True), server_default=sa_sql.func.now()
|
|
||||||
)
|
|
||||||
modified: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
|
|
||||||
sa.DateTime(timezone=True), server_default=sa_sql.func.now(), onupdate=sa_sql.func.now()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Person(base_models.Base):
|
|
||||||
__tablename__: str = "person"
|
|
||||||
|
|
||||||
id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
|
|
||||||
full_name: sa_orm.Mapped[str] = sa_orm.mapped_column()
|
|
||||||
created: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
|
|
||||||
sa.DateTime(timezone=True), server_default=sa_sql.func.now()
|
|
||||||
)
|
|
||||||
modified: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
|
|
||||||
sa.DateTime(timezone=True), server_default=sa_sql.func.now(), onupdate=sa_sql.func.now()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class FilmWork(base_models.Base):
|
|
||||||
__tablename__: str = "film_work"
|
|
||||||
|
|
||||||
id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
|
|
||||||
title: sa_orm.Mapped[str]
|
|
||||||
description: sa_orm.Mapped[str] = sa_orm.mapped_column(nullable=True)
|
|
||||||
creation_date: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(nullable=True)
|
|
||||||
file_path: sa_orm.Mapped[str] = sa_orm.mapped_column(nullable=True)
|
|
||||||
rating: sa_orm.Mapped[float] = sa_orm.mapped_column(nullable=True)
|
|
||||||
type: sa_orm.Mapped[str] = sa_orm.mapped_column()
|
|
||||||
created: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
|
|
||||||
sa.DateTime(timezone=True), server_default=sa_sql.func.now()
|
|
||||||
)
|
|
||||||
modified: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
|
|
||||||
sa.DateTime(timezone=True), server_default=sa_sql.func.now(), onupdate=sa_sql.func.now()
|
|
||||||
)
|
|
||||||
embedding: sa_orm.Mapped[list[float]] = sa_orm.mapped_column(pgvector.sqlalchemy.Vector(1536))
|
|
||||||
genres: sa_orm.Mapped[list[Genre]] = sa_orm.relationship(secondary="genre_film_work")
|
|
||||||
|
|
||||||
|
|
||||||
GenreFilmWork = sa.Table(
|
|
||||||
"genre_film_work",
|
|
||||||
base_models.Base.metadata,
|
|
||||||
sa.Column("id", sa.UUID, primary_key=True), # type: ignore[reportUnknownVariableType]
|
|
||||||
sa.Column("genre_id", sa.ForeignKey(Genre.id), primary_key=True), # type: ignore[reportUnknownVariableType]
|
|
||||||
sa.Column("film_work_id", sa.ForeignKey(FilmWork.id), primary_key=True), # type: ignore[reportUnknownVariableType]
|
|
||||||
sa.Column("created", sa.DateTime(timezone=True), server_default=sa_sql.func.now()),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
PersonFilmWork = sa.Table(
|
|
||||||
"person_film_work",
|
|
||||||
base_models.Base.metadata,
|
|
||||||
sa.Column("person_id", sa.ForeignKey(Person.id), primary_key=True), # type: ignore[reportUnknownVariableType]
|
|
||||||
sa.Column("film_work_id", sa.ForeignKey(FilmWork.id), primary_key=True), # type: ignore[reportUnknownVariableType]
|
|
||||||
sa.Column("role", sa.String(50), nullable=False),
|
|
||||||
sa.Column("created", sa.DateTime(timezone=True), server_default=sa_sql.func.now()),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ChatHistory(base_models.Base):
|
|
||||||
__tablename__: str = "chat_history"
|
|
||||||
|
|
||||||
id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
|
|
||||||
session_id: sa_orm.Mapped[str] = sa_orm.mapped_column()
|
|
||||||
channel: sa_orm.Mapped[str] = sa_orm.mapped_column()
|
|
||||||
user_id: sa_orm.Mapped[str] = sa_orm.mapped_column()
|
|
||||||
content: sa_orm.Mapped[sa.JSON] = sa_orm.mapped_column(sa.JSON)
|
|
||||||
created: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
|
|
||||||
sa.DateTime(timezone=True), server_default=sa_sql.func.now()
|
|
||||||
)
|
|
||||||
modified: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
|
|
||||||
sa.DateTime(timezone=True), server_default=sa_sql.func.now(), onupdate=sa_sql.func.now()
|
|
||||||
)
|
|
|
@ -4,8 +4,8 @@ import typing
|
||||||
import openai
|
import openai
|
||||||
import openai.error
|
import openai.error
|
||||||
|
|
||||||
import lib.agent.models as models
|
|
||||||
import lib.app.settings as app_settings
|
import lib.app.settings as app_settings
|
||||||
|
import lib.models as models
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingRepository:
|
class EmbeddingRepository:
|
||||||
|
@ -23,7 +23,7 @@ class EmbeddingRepository:
|
||||||
input=text,
|
input=text,
|
||||||
model=model,
|
model=model,
|
||||||
) # type: ignore[reportGeneralTypeIssues]
|
) # type: ignore[reportGeneralTypeIssues]
|
||||||
return response["data"][0]["embedding"]
|
return models.Embedding(**response["data"][0]["embedding"])
|
||||||
except openai.error.OpenAIError:
|
except openai.error.OpenAIError:
|
||||||
self.logger.exception("Failed to get async embedding for: %s", text)
|
self.logger.exception("Failed to get async embedding for: %s", text)
|
||||||
|
|
||||||
|
|
|
@ -10,9 +10,8 @@ import langchain.prompts
|
||||||
import langchain.schema.agent
|
import langchain.schema.agent
|
||||||
import langchain.schema.messages
|
import langchain.schema.messages
|
||||||
import langchain.tools.render
|
import langchain.tools.render
|
||||||
import models
|
|
||||||
import varname
|
|
||||||
|
|
||||||
|
import assistant.lib.models.movies as movies
|
||||||
import lib.agent.openai_functions as openai_functions
|
import lib.agent.openai_functions as openai_functions
|
||||||
import lib.app.settings as app_settings
|
import lib.app.settings as app_settings
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,6 @@
|
||||||
from .chat_history import RequestLastSessionId
|
from .chat_history import RequestLastSessionId
|
||||||
|
from .embedding import Embedding
|
||||||
|
from .movies import Movie
|
||||||
from .token import Token
|
from .token import Token
|
||||||
|
|
||||||
__all__ = ["RequestLastSessionId", "Token"]
|
__all__ = ["Embedding", "Movie", "RequestLastSessionId", "Token"]
|
||||||
|
|
5
src/assistant/lib/models/embedding.py
Normal file
5
src/assistant/lib/models/embedding.py
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
import pydantic
|
||||||
|
|
||||||
|
|
||||||
|
class Embedding(pydantic.RootModel[list[float]]):
|
||||||
|
root: list[float]
|
|
@ -12,7 +12,3 @@ class Movie(pydantic.BaseModel):
|
||||||
type: str
|
type: str
|
||||||
created: datetime.datetime
|
created: datetime.datetime
|
||||||
modified: datetime.datetime
|
modified: datetime.datetime
|
||||||
|
|
||||||
|
|
||||||
class Embedding(pydantic.RootModel[list[float]]):
|
|
||||||
root: list[float]
|
|
|
@ -10,7 +10,7 @@ import lib.orm_models.base as base_models
|
||||||
|
|
||||||
|
|
||||||
class Genre(base_models.Base):
|
class Genre(base_models.Base):
|
||||||
__tablename__: str = "genre"
|
__tablename__: str = "genre" # type: ignore[reportIncompatibleVariableOverride]
|
||||||
|
|
||||||
id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
|
id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
|
||||||
name: sa_orm.Mapped[str] = sa_orm.mapped_column()
|
name: sa_orm.Mapped[str] = sa_orm.mapped_column()
|
||||||
|
@ -24,7 +24,7 @@ class Genre(base_models.Base):
|
||||||
|
|
||||||
|
|
||||||
class Person(base_models.Base):
|
class Person(base_models.Base):
|
||||||
__tablename__: str = "person"
|
__tablename__: str = "person" # type: ignore[reportIncompatibleVariableOverride]
|
||||||
|
|
||||||
id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
|
id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
|
||||||
full_name: sa_orm.Mapped[str] = sa_orm.mapped_column()
|
full_name: sa_orm.Mapped[str] = sa_orm.mapped_column()
|
||||||
|
@ -37,10 +37,10 @@ class Person(base_models.Base):
|
||||||
|
|
||||||
|
|
||||||
class FilmWork(base_models.Base):
|
class FilmWork(base_models.Base):
|
||||||
__tablename__: str = "film_work"
|
__tablename__: str = "film_work" # type: ignore[reportIncompatibleVariableOverride]
|
||||||
|
|
||||||
id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
|
id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
|
||||||
title: sa_orm.Mapped[str]
|
title: sa_orm.Mapped[str] = sa_orm.mapped_column()
|
||||||
description: sa_orm.Mapped[str] = sa_orm.mapped_column(nullable=True)
|
description: sa_orm.Mapped[str] = sa_orm.mapped_column(nullable=True)
|
||||||
creation_date: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(nullable=True)
|
creation_date: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(nullable=True)
|
||||||
file_path: sa_orm.Mapped[str] = sa_orm.mapped_column(nullable=True)
|
file_path: sa_orm.Mapped[str] = sa_orm.mapped_column(nullable=True)
|
||||||
|
@ -77,7 +77,7 @@ PersonFilmWork = sa.Table(
|
||||||
|
|
||||||
|
|
||||||
class ChatHistory(base_models.Base):
|
class ChatHistory(base_models.Base):
|
||||||
__tablename__: str = "chat_history"
|
__tablename__: str = "chat_history" # type: ignore[reportIncompatibleVariableOverride]
|
||||||
|
|
||||||
id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
|
id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
|
||||||
session_id: sa_orm.Mapped[str] = sa_orm.mapped_column()
|
session_id: sa_orm.Mapped[str] = sa_orm.mapped_column()
|
||||||
|
|
|
@ -92,7 +92,8 @@ variable-rgx = "^_{0,2}[a-z][a-z0-9_]*$"
|
||||||
|
|
||||||
[tool.pyright]
|
[tool.pyright]
|
||||||
exclude = [
|
exclude = [
|
||||||
".venv"
|
".venv",
|
||||||
|
"alembic"
|
||||||
]
|
]
|
||||||
pythonPlatform = "All"
|
pythonPlatform = "All"
|
||||||
pythonVersion = "3.11"
|
pythonVersion = "3.11"
|
||||||
|
|
Loading…
Reference in New Issue
Block a user