1
0
mirror of https://github.com/ijaric/voice_assistant.git synced 2025-12-18 05:26:18 +00:00

feat: [#45] llm_agent

This commit is contained in:
2023-10-15 07:49:19 +03:00
parent fe7ebcc11b
commit 36089de163
15 changed files with 344 additions and 368 deletions

View File

@@ -2,13 +2,13 @@ import logging
import uuid
import langchain.agents
import orm_models
import sqlalchemy as sa
import sqlalchemy.exc
import sqlalchemy.ext.asyncio as sa_asyncio
import lib.agent.repositories as repositories
import lib.models as models
import lib.orm_models as orm_models
class OpenAIFunctions:
@@ -23,7 +23,6 @@ class OpenAIFunctions:
self.pg_async_session = pg_async_session
self.repository = repository
@langchain.agents.tool
async def get_movie_by_description(self, description: str) -> list[models.Movie] | None:
"""Provide a movie data by description."""
@@ -34,7 +33,7 @@ class OpenAIFunctions:
result: list[models.Movie] = []
stmt = (
sa.select(orm_models.FilmWork)
.order_by(orm_models.FilmWork.embedding.cosine_distance(embedded_description))
.order_by(orm_models.FilmWork.embeddings.cosine_distance(embedded_description.root))
.limit(5)
)
neighbours = session.scalars(stmt)

View File

@@ -1,6 +1,7 @@
import logging
import typing
import langchain.chat_models
import openai
import openai.error
@@ -13,7 +14,7 @@ class EmbeddingRepository:
def __init__(self, settings: app_settings.Settings) -> None:
"""Initialize the service with an OpenAI API key."""
self.llm = openai.api_key = settings.openai.api_key
self.llm = openai.api_key = settings.openai.api_key.get_secret_value()
self.logger = logging.getLogger(__name__)
def get_embedding(self, text: str, model: str = "text-embedding-ada-002") -> models.Embedding | None:
@@ -28,13 +29,36 @@ class EmbeddingRepository:
self.logger.exception("Failed to get async embedding for: %s", text)
async def aget_embedding(self, text: str, model: str = "text-embedding-ada-002") -> models.Embedding | None:
"""Get the embedding for a given text."""
"""Get the embedding for a given text.[Async]"""
try:
response: dict[str, typing.Any] = await openai.Embedding.acreate(
input=text,
model=model,
) # type: ignore[reportGeneralTypeIssues]
return models.Embedding(**response["data"][0]["embedding"])
# print(response["data"][0]["embedding"])
return models.Embedding(root=response["data"][0]["embedding"])
except openai.error.OpenAIError:
self.logger.exception("Failed to get async embedding for: %s", text)
class LlmRepository:
"""A service for getting embeddings from OpenAI."""
def __init__(self, settings: app_settings.Settings) -> None:
"""Initialize the service with an OpenAI API key."""
self.llm = langchain.chat_models.ChatOpenAI(
temperature=0.7,
openai_api_key=self.settings.openai.api_key.get_secret_value()
)
async def get_chat_response(self, request: str, prompt: str) -> str:
"""Get the embedding for a given text."""
prompt = langchain.prompts.ChatPromptTemplate.from_messages(
[
("system", prompt),
]
)
chain = prompt | self.llm
response = await chain.ainvoke({"input": request})
return response.content

View File

@@ -2,13 +2,13 @@ import asyncio
import logging
import uuid
import fastapi
import langchain.agents
import langchain.agents.format_scratchpad
import langchain.agents.output_parsers
import langchain.chat_models
import langchain.prompts
import langchain.schema.agent
import langchain.schema.messages
import langchain.schema
import langchain.tools.render
import assistant.lib.models.movies as movies
@@ -24,44 +24,36 @@ class AgentService:
self.settings = settings
self.tools = tools
async def process_request(self, request: str, chat_history: list[langchain.schema.messages.Message]) -> str:
llm = langchain.chat_models.ChatOpenAI(temperature=0.7, openai_api_key=self.settings.openai.api_key)
tools = [self.tools.get_movie_by_description, self.tools.get_movie_by_id, self.tools.get_similar_movies]
async def process_request(self, request: models.AgentCreateRequestModel) -> models.AgentCreateResponseModel:
result = await self.tools.get_movie_by_description(request.text)
if len(result) == 0:
raise fastapi.HTTPException(status_code=404, detail="Movies not found")
# llm = langchain.chat_models.ChatOpenAI(
# temperature=self.settings.openai.agent_temperature,
# openai_api_key=self.settings.openai.api_key.get_secret_value()
# )
content_films = "\n".join(film.get_movie_info_line() for film in result)
system_prompt = (
"You are a cinema expert. "
f"Here are the movies I found for you: {content_films}"
"Listen to the question and answer it based on the information above."
)
chat_history = []
chat_history_name = f"{chat_history=}".partition("=")[0]
prompt = langchain.prompts.ChatPromptTemplate.from_messages(
[
(
"system",
"You are very powerful assistant. If you are asked about movies you will you provided functions.",
),
langchain.prompts.MessagesPlaceholder(variable_name=chat_history_name),
("user", "{input}"),
langchain.prompts.MessagesPlaceholder(variable_name="agent_scratchpad"),
("system", system_prompt),
]
)
chain = prompt | self.llm
response = await chain.ainvoke({"input": request.text})
response_model = models.AgentCreateResponseModel(text=response.content)
return response_model
llm_with_tools = llm.bind(
functions=[langchain.tools.render.format_tool_to_openai_function(tool) for tool in tools]
)
chat_history = []
agent = (
{
"input": lambda _: _["input"],
"agent_scratchpad": lambda _: langchain.agents.format_scratchpad.format_to_openai_functions(
_["intermediate_steps"]
),
"chat_history": lambda _: _["chat_history"],
}
| prompt
| llm_with_tools
| langchain.agents.output_parsers.OpenAIFunctionsAgentOutputParser()
)
agent_executor = langchain.agents.AgentExecutor(agent=agent, tools=tools, verbose=True)
return await agent_executor.ainvoke({"input": first_question, "chat_history": chat_history})

View File

@@ -37,7 +37,7 @@ class PostgresSettings(pydantic_settings.BaseSettings):
@property
def dsn(self) -> str:
password = self.password.get_secret_value()
return f"{self.driver}://{self.user}:{password}@{self.host}:{self.port}"
return f"{self.driver}://{self.user}:{password}@{self.host}:{self.port}/{self.db_name}"
@property
def dsn_as_safe_url(self) -> str:

View File

@@ -8,7 +8,30 @@ class Movie(pydantic.BaseModel):
id: uuid.UUID
title: str
description: str | None = None
rating: float
type: str
rating: float | None = None
type: str | None = None
created: datetime.datetime
modified: datetime.datetime
creation_date: datetime.datetime | None = None
runtime: int | None = None
budget: int | None = None
imdb_id: str | None = None
@pydantic.computed_field
@property
def imdb_url(self) -> str:
return f"https://www.imdb.com/title/{self.imdb_id}"
def get_movie_info_line(self):
not_provided_value = "not provided"
content_film_info = {
"Title": self.title,
"Description": self.description or not_provided_value,
"Rating": self.rating or not_provided_value,
"Imdb_id": self.imdb_url or not_provided_value,
"Creation_date": self.creation_date or not_provided_value,
"Runtime": self.runtime or not_provided_value,
"Budget": self.budget or not_provided_value,
}
content_film_info_line = ", ".join(f"{k}: {v}" for k, v in content_film_info.items())
return content_film_info_line

View File

@@ -43,8 +43,10 @@ class FilmWork(base_models.Base):
title: sa_orm.Mapped[str] = sa_orm.mapped_column()
description: sa_orm.Mapped[str] = sa_orm.mapped_column(nullable=True)
creation_date: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(nullable=True)
file_path: sa_orm.Mapped[str] = sa_orm.mapped_column(nullable=True)
rating: sa_orm.Mapped[float] = sa_orm.mapped_column(nullable=True)
runtime: sa_orm.Mapped[int] = sa_orm.mapped_column(nullable=False)
budget: sa_orm.Mapped[int] = sa_orm.mapped_column(default=0)
imdb_id: sa_orm.Mapped[str] = sa_orm.mapped_column(nullable=False)
type: sa_orm.Mapped[str] = sa_orm.mapped_column()
created: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
sa.DateTime(timezone=True), server_default=sa_sql.func.now()
@@ -52,7 +54,7 @@ class FilmWork(base_models.Base):
modified: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
sa.DateTime(timezone=True), server_default=sa_sql.func.now(), onupdate=sa_sql.func.now()
)
embedding: sa_orm.Mapped[list[float]] = sa_orm.mapped_column(pgvector.sqlalchemy.Vector(1536))
embeddings: sa_orm.Mapped[list[float]] = sa_orm.mapped_column(pgvector.sqlalchemy.Vector(1536))
genres: sa_orm.Mapped[list[Genre]] = sa_orm.relationship(secondary="genre_film_work")