mirror of
https://github.com/ijaric/voice_assistant.git
synced 2025-12-16 16:26:16 +00:00
build: in progress agent
This commit is contained in:
3
src/assistant/lib/agent/__init__.py
Normal file
3
src/assistant/lib/agent/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .chat_repository import ChatHistoryRepository
|
||||
|
||||
__all__ = ["ChatHistoryRepository"]
|
||||
41
src/assistant/lib/agent/chat_repository.py
Normal file
41
src/assistant/lib/agent/chat_repository.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.exc
|
||||
import sqlalchemy.ext.asyncio as sa_asyncio
|
||||
|
||||
import lib.models as models
|
||||
import lib.orm_models as orm_models
|
||||
|
||||
|
||||
class ChatHistoryRepository:
    """Read access to the ``chat_history`` table."""

    def __init__(self, pg_async_session: sa_asyncio.async_sessionmaker[sa_asyncio.AsyncSession]) -> None:
        """Store the async session factory and a module logger."""
        self.pg_async_session = pg_async_session
        self.logger = logging.getLogger(__name__)

    async def get_last_session_id(self, request: models.RequestLastSessionId) -> uuid.UUID | None:
        """Return the id of the most recent chat session for the requested
        channel/user that falls within ``request.minutes_ago``, or None if
        there is no such session or a database error occurs.
        """
        try:
            async with self.pg_async_session() as session:
                statement = (
                    sa.select(orm_models.ChatHistory)
                    .filter_by(channel=request.channel, user_id=request.user_id)
                    .filter(
                        # BUG FIX: the original divided only the `modified` epoch
                        # by 60 (precedence bug), comparing seconds against
                        # minutes. The whole difference must be converted to
                        # minutes before comparing with `minutes_ago`.
                        # NOTE(review): created - modified is usually <= 0 for an
                        # updated row; confirm the intended operands (possibly
                        # now() - modified) against the caller's expectations.
                        (
                            sa.func.extract("epoch", orm_models.ChatHistory.created)
                            - sa.func.extract("epoch", orm_models.ChatHistory.modified)
                        )
                        / 60
                        <= request.minutes_ago
                    )
                    .order_by(orm_models.ChatHistory.created.desc())
                    .limit(1)
                )
                result = await session.execute(statement)

                chat_session = result.scalars().first()
                # Explicit None keeps the declared return contract obvious.
                return chat_session.id if chat_session else None
        except sqlalchemy.exc.SQLAlchemyError as error:
            self.logger.exception("Error: %s", error)
            return None
|
||||
18
src/assistant/lib/agent/models.py
Normal file
18
src/assistant/lib/agent/models.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import datetime
|
||||
import uuid
|
||||
|
||||
import pydantic
|
||||
|
||||
|
||||
class Movie(pydantic.BaseModel):
    """Pydantic view of a film work row returned by the agent tools."""

    id: uuid.UUID
    title: str
    # Optional free-text description; nullable in the source table.
    description: str | None = None
    rating: float
    # Content type discriminator (e.g. movie vs. tv show); free-form string here.
    type: str
    created: datetime.datetime
    modified: datetime.datetime
|
||||
|
||||
|
||||
class Embedding(pydantic.RootModel[list[float]]):
    """Wrapper around a raw embedding vector (list of floats)."""

    root: list[float]
|
||||
58
src/assistant/lib/agent/openai_functions.py
Normal file
58
src/assistant/lib/agent/openai_functions.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
import langchain.agents
|
||||
import models
|
||||
import orm_models
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.exc
|
||||
import sqlalchemy.ext.asyncio as sa_asyncio
|
||||
|
||||
import lib.agent.repositories as repositories
|
||||
|
||||
|
||||
class OpenAIFunctions:
    """OpenAI Functions for langchain agents."""

    def __init__(
        self,
        repository: repositories.EmbeddingRepository,
        pg_async_session: sa_asyncio.async_sessionmaker[sa_asyncio.AsyncSession],
    ) -> None:
        """Keep the embedding repository and the async DB session factory."""
        self.logger = logging.getLogger(__name__)
        self.pg_async_session = pg_async_session
        self.repository = repository

    @langchain.agents.tool
    async def get_movie_by_description(self, description: str) -> list[models.Movie] | None:
        """Returns a movie data by description."""

        self.logger.info("Request to get movie by description: %s", description)
        embedded_description = await self.repository.aget_embedding(description)
        try:
            async with self.pg_async_session() as session:
                result: list[models.Movie] = []
                # Nearest-neighbour search by cosine distance on the pgvector column.
                stmt = (
                    sa.select(orm_models.FilmWork)
                    .order_by(orm_models.FilmWork.embedding.cosine_distance(embedded_description))
                    .limit(5)
                )
                neighbours = session.scalars(stmt)
                for neighbour in await neighbours:
                    # BUG FIX: removed leftover debug print(neighbour.title).
                    result.append(models.Movie(**neighbour.__dict__))
                return result
        except sqlalchemy.exc.SQLAlchemyError as error:
            self.logger.exception("Error: %s", error)
            return None

    @langchain.agents.tool
    def get_movie_by_id(self, id: uuid.UUID) -> models.Movie | None:
        """Returns a movie data by movie id."""
        # TODO: stub — not implemented yet, always returns None.
        self.logger.info("Request to get movie by id: %s", id)
        return None

    @langchain.agents.tool
    def get_similar_movies(self, id: uuid.UUID) -> list[models.Movie] | None:
        """Returns a similar movies to movie with movie id."""
        # TODO: stub — not implemented yet, always returns None.
        # BUG FIX: the log message was copy-pasted from get_movie_by_id.
        self.logger.info("Request to get similar movies for id: %s", id)
        return None
|
||||
92
src/assistant/lib/agent/orm_models.py
Normal file
92
src/assistant/lib/agent/orm_models.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import datetime
|
||||
import uuid
|
||||
|
||||
import pgvector.sqlalchemy
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.orm as sa_orm
|
||||
import sqlalchemy.sql as sa_sql
|
||||
|
||||
import lib.models.orm as base_models
|
||||
|
||||
|
||||
class Genre(base_models.Base):
    """ORM model for the ``genre`` dictionary table."""

    __tablename__: str = "genre"

    id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
    name: sa_orm.Mapped[str] = sa_orm.mapped_column()
    description: sa_orm.Mapped[str] = sa_orm.mapped_column(nullable=True)
    # Row timestamps are set/updated by the database, not the application.
    created: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
        sa.DateTime(timezone=True), server_default=sa_sql.func.now()
    )
    modified: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
        sa.DateTime(timezone=True), server_default=sa_sql.func.now(), onupdate=sa_sql.func.now()
    )
|
||||
|
||||
|
||||
class Person(base_models.Base):
    """ORM model for the ``person`` table (actors, directors, writers, ...)."""

    __tablename__: str = "person"

    id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
    full_name: sa_orm.Mapped[str] = sa_orm.mapped_column()
    # Row timestamps are set/updated by the database, not the application.
    created: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
        sa.DateTime(timezone=True), server_default=sa_sql.func.now()
    )
    modified: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
        sa.DateTime(timezone=True), server_default=sa_sql.func.now(), onupdate=sa_sql.func.now()
    )
|
||||
|
||||
|
||||
class FilmWork(base_models.Base):
    """ORM model for the ``film_work`` table, including its embedding vector."""

    __tablename__: str = "film_work"

    id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
    title: sa_orm.Mapped[str]
    description: sa_orm.Mapped[str] = sa_orm.mapped_column(nullable=True)
    creation_date: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(nullable=True)
    file_path: sa_orm.Mapped[str] = sa_orm.mapped_column(nullable=True)
    rating: sa_orm.Mapped[float] = sa_orm.mapped_column(nullable=True)
    type: sa_orm.Mapped[str] = sa_orm.mapped_column()
    # Row timestamps are set/updated by the database, not the application.
    created: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
        sa.DateTime(timezone=True), server_default=sa_sql.func.now()
    )
    modified: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
        sa.DateTime(timezone=True), server_default=sa_sql.func.now(), onupdate=sa_sql.func.now()
    )
    # pgvector column; 1536 matches OpenAI text-embedding-ada-002 output size.
    embedding: sa_orm.Mapped[list[float]] = sa_orm.mapped_column(pgvector.sqlalchemy.Vector(1536))
    # Many-to-many via the genre_film_work association table below.
    genres: sa_orm.Mapped[list[Genre]] = sa_orm.relationship(secondary="genre_film_work")
|
||||
|
||||
|
||||
# Association table linking film works to genres (many-to-many).
GenreFilmWork = sa.Table(
    "genre_film_work",
    base_models.Base.metadata,
    sa.Column("id", sa.UUID, primary_key=True),  # type: ignore[reportUnknownVariableType]
    sa.Column("genre_id", sa.ForeignKey(Genre.id), primary_key=True),  # type: ignore[reportUnknownVariableType]
    sa.Column("film_work_id", sa.ForeignKey(FilmWork.id), primary_key=True),  # type: ignore[reportUnknownVariableType]
    sa.Column("created", sa.DateTime(timezone=True), server_default=sa_sql.func.now()),
)
|
||||
|
||||
|
||||
# Association table linking persons to film works with their role
# (many-to-many; note: no surrogate id column, unlike GenreFilmWork).
PersonFilmWork = sa.Table(
    "person_film_work",
    base_models.Base.metadata,
    sa.Column("person_id", sa.ForeignKey(Person.id), primary_key=True),  # type: ignore[reportUnknownVariableType]
    sa.Column("film_work_id", sa.ForeignKey(FilmWork.id), primary_key=True),  # type: ignore[reportUnknownVariableType]
    sa.Column("role", sa.String(50), nullable=False),
    sa.Column("created", sa.DateTime(timezone=True), server_default=sa_sql.func.now()),
)
|
||||
|
||||
|
||||
class ChatHistory(base_models.Base):
    """ORM model for the ``chat_history`` table storing per-session messages."""

    __tablename__: str = "chat_history"

    id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
    session_id: sa_orm.Mapped[str] = sa_orm.mapped_column()
    # Delivery channel identifier (string; exact values defined by callers).
    channel: sa_orm.Mapped[str] = sa_orm.mapped_column()
    user_id: sa_orm.Mapped[str] = sa_orm.mapped_column()
    # NOTE(review): Mapped[sa.JSON] annotates the *column type*, not the Python
    # value type — Mapped[dict] with the sa.JSON column is the usual spelling;
    # confirm before changing.
    content: sa_orm.Mapped[sa.JSON] = sa_orm.mapped_column(sa.JSON)
    # Row timestamps are set/updated by the database, not the application.
    created: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
        sa.DateTime(timezone=True), server_default=sa_sql.func.now()
    )
    modified: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
        sa.DateTime(timezone=True), server_default=sa_sql.func.now(), onupdate=sa_sql.func.now()
    )
|
||||
40
src/assistant/lib/agent/repositories.py
Normal file
40
src/assistant/lib/agent/repositories.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import logging
|
||||
import typing
|
||||
|
||||
import openai
|
||||
import openai.error
|
||||
|
||||
import lib.agent.models as models
|
||||
import lib.app.settings as app_settings
|
||||
|
||||
|
||||
class EmbeddingRepository:
    """A service for getting embeddings from OpenAI."""

    def __init__(self, settings: app_settings.Settings) -> None:
        """Initialize the service with an OpenAI API key."""
        # NOTE(review): the chained assignment also stores the key on self.llm
        # (kept for compatibility); setting the module-level openai.api_key is
        # what actually configures the client.
        self.llm = openai.api_key = settings.openai.api_key
        self.logger = logging.getLogger(__name__)

    def get_embedding(self, text: str, model: str = "text-embedding-ada-002") -> models.Embedding | None:
        """Get the embedding for a given text; None on API error."""
        try:
            response: dict[str, typing.Any] = openai.Embedding.create(
                input=text,
                model=model,
            )  # type: ignore[reportGeneralTypeIssues]
            # BUG FIX: wrap the raw vector so the declared return type
            # (models.Embedding | None) actually holds; the original returned
            # the bare list.
            return models.Embedding(response["data"][0]["embedding"])
        except openai.error.OpenAIError:
            # BUG FIX: message said "async" in the synchronous variant.
            self.logger.exception("Failed to get embedding for: %s", text)
            return None

    async def aget_embedding(self, text: str, model: str = "text-embedding-ada-002") -> models.Embedding | None:
        """Get the embedding for a given text; None on API error."""
        try:
            response: dict[str, typing.Any] = await openai.Embedding.acreate(
                input=text,
                model=model,
            )  # type: ignore[reportGeneralTypeIssues]
            # BUG FIX: the original called Embedding(**vector) which raises
            # TypeError — ** requires a mapping. A pydantic RootModel takes the
            # root value positionally.
            return models.Embedding(response["data"][0]["embedding"])
        except openai.error.OpenAIError:
            self.logger.exception("Failed to get async embedding for: %s", text)
            return None
|
||||
89
src/assistant/lib/agent/services.py
Normal file
89
src/assistant/lib/agent/services.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
import langchain.agents
|
||||
import langchain.agents.format_scratchpad
|
||||
import langchain.agents.output_parsers
|
||||
import langchain.chat_models
|
||||
import langchain.prompts
|
||||
import langchain.schema.agent
|
||||
import langchain.schema.messages
|
||||
import langchain.tools.render
|
||||
import models
|
||||
import varname
|
||||
|
||||
import lib.agent.openai_functions as openai_functions
|
||||
import lib.app.settings as app_settings
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentService:
    """Runs a langchain OpenAI-functions agent over the movie tools."""

    def __init__(self, settings: app_settings.Settings, tools: openai_functions.OpenAIFunctions) -> None:
        """Keep the application settings and the tool collection."""
        self.settings = settings
        self.tools = tools

    async def process_request(
        self, request: str, chat_history: list[langchain.schema.messages.BaseMessage]
    ) -> str:
        """Answer *request* with the agent, given the prior *chat_history*.

        Returns the agent's final text output.
        """
        # NOTE(review): parameter annotation changed from the non-existent
        # langchain.schema.messages.Message to BaseMessage — confirm against the
        # installed langchain version.
        llm = langchain.chat_models.ChatOpenAI(temperature=0.7, openai_api_key=self.settings.openai.api_key)
        tools = [self.tools.get_movie_by_description, self.tools.get_movie_by_id, self.tools.get_similar_movies]

        # BUG FIX: the original reset chat_history = [] (twice), discarding the
        # caller-supplied history; the varname trick always yielded the literal
        # string "chat_history" anyway, so the placeholder name is spelled out.
        prompt = langchain.prompts.ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    "You are very powerful assistant. If you are asked about movies you will you provided functions.",
                ),
                langchain.prompts.MessagesPlaceholder(variable_name="chat_history"),
                ("user", "{input}"),
                langchain.prompts.MessagesPlaceholder(variable_name="agent_scratchpad"),
            ]
        )

        llm_with_tools = llm.bind(
            functions=[langchain.tools.render.format_tool_to_openai_function(tool) for tool in tools]
        )

        agent = (
            {
                "input": lambda _: _["input"],
                "agent_scratchpad": lambda _: langchain.agents.format_scratchpad.format_to_openai_functions(
                    _["intermediate_steps"]
                ),
                "chat_history": lambda _: _["chat_history"],
            }
            | prompt
            | llm_with_tools
            | langchain.agents.output_parsers.OpenAIFunctionsAgentOutputParser()
        )

        agent_executor = langchain.agents.AgentExecutor(agent=agent, tools=tools, verbose=True)

        # BUG FIX: the original referenced the undefined name `first_question`
        # and returned the raw ainvoke() dict despite the declared `str` return.
        result = await agent_executor.ainvoke({"input": request, "chat_history": chat_history})
        return result["output"]
|
||||
|
||||
|
||||
# async def main():
|
||||
# agent_executor = langchain.agents.AgentExecutor(agent=agent, tools=tools, verbose=True)
|
||||
|
||||
# # first_question = "What is the movie where halfling bring the ring to the volcano?"
|
||||
# first_question = (
|
||||
# "What is the movie about a famous country singer meet a talented singer and songwriter who works as a waitress?"
|
||||
# )
|
||||
# second_question = "So what is the rating of the movie? Do you recommend it?"
|
||||
# third_question = "What are the similar movies?"
|
||||
# first_result = await agent_executor.ainvoke({"input": first_question, "chat_history": chat_history})
|
||||
# chat_history.append(langchain.schema.messages.HumanMessage(content=first_question))
|
||||
# chat_history.append(langchain.schema.messages.AIMessage(content=first_result["output"]))
|
||||
# second_result = await agent_executor.ainvoke({"input": second_question, "chat_history": chat_history})
|
||||
# chat_history.append(langchain.schema.messages.HumanMessage(content=second_question))
|
||||
# chat_history.append(langchain.schema.messages.AIMessage(content=second_result["output"]))
|
||||
# final_result = await agent_executor.ainvoke({"input": third_question, "chat_history": chat_history})
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# asyncio.run(main())
|
||||
Reference in New Issue
Block a user