1
0
mirror of https://github.com/ijaric/voice_assistant.git synced 2025-12-18 05:26:18 +00:00

build: in progress agent

This commit is contained in:
Artem Litvinov
2023-10-12 01:50:28 +01:00
parent a59a1dba2e
commit 98bb7cbb9c
22 changed files with 906 additions and 55 deletions

View File

@@ -0,0 +1,3 @@
from .chat_repository import ChatHistoryRepository
__all__ = ["ChatHistoryRepository"]

View File

@@ -0,0 +1,41 @@
import logging
import uuid
import sqlalchemy as sa
import sqlalchemy.exc
import sqlalchemy.ext.asyncio as sa_asyncio
import lib.models as models
import lib.orm_models as orm_models
class ChatHistoryRepository:
def __init__(self, pg_async_session: sa_asyncio.async_sessionmaker[sa_asyncio.AsyncSession]) -> None:
self.pg_async_session = pg_async_session
self.logger = logging.getLogger(__name__)
async def get_last_session_id(self, request: models.RequestLastSessionId) -> uuid.UUID | None:
"""Get a new session ID."""
try:
async with self.pg_async_session() as session:
statement = (
sa.select(orm_models.ChatHistory)
.filter_by(channel=request.channel, user_id=request.user_id)
.filter(
(
sa.func.extract("epoch", orm_models.ChatHistory.created)
- sa.func.extract("epoch", orm_models.ChatHistory.modified) / 60
)
<= request.minutes_ago
)
.order_by(orm_models.ChatHistory.created.desc())
.limit(1)
)
result = await session.execute(statement)
chat_session = result.scalars().first()
if chat_session:
return chat_session.id
except sqlalchemy.exc.SQLAlchemyError as error:
self.logger.exception("Error: %s", error)

View File

@@ -0,0 +1,18 @@
import datetime
import uuid
import pydantic
class Movie(pydantic.BaseModel):
id: uuid.UUID
title: str
description: str | None = None
rating: float
type: str
created: datetime.datetime
modified: datetime.datetime
class Embedding(pydantic.RootModel[list[float]]):
root: list[float]

View File

@@ -0,0 +1,58 @@
import logging
import uuid
import langchain.agents
import models
import orm_models
import sqlalchemy as sa
import sqlalchemy.exc
import sqlalchemy.ext.asyncio as sa_asyncio
import lib.agent.repositories as repositories
class OpenAIFunctions:
"""OpenAI Functions for langchain agents."""
def __init__(
self,
repository: repositories.EmbeddingRepository,
pg_async_session: sa_asyncio.async_sessionmaker[sa_asyncio.AsyncSession],
) -> None:
self.logger = logging.getLogger(__name__)
self.pg_async_session = pg_async_session
self.repository = repository
@langchain.agents.tool
async def get_movie_by_description(self, description: str) -> list[models.Movie] | None:
"""Returns a movie data by description."""
self.logger.info("Request to get movie by description: %s", description)
embedded_description = await self.repository.aget_embedding(description)
try:
async with self.pg_async_session() as session:
result: list[models.Movie] = []
stmt = (
sa.select(orm_models.FilmWork)
.order_by(orm_models.FilmWork.embedding.cosine_distance(embedded_description))
.limit(5)
)
neighbours = session.scalars(stmt)
for neighbour in await neighbours:
print(neighbour.title)
result.append(models.Movie(**neighbour.__dict__))
return result
except sqlalchemy.exc.SQLAlchemyError as error:
self.logger.exception("Error: %s", error)
@langchain.agents.tool
def get_movie_by_id(self, id: uuid.UUID) -> models.Movie | None:
"""Returns a movie data by movie id."""
self.logger.info("Request to get movie by id: %s", id)
return None
@langchain.agents.tool
def get_similar_movies(self, id: uuid.UUID) -> list[models.Movie] | None:
"""Returns a similar movies to movie with movie id."""
self.logger.info("Request to get movie by id: %s", id)
return None

View File

@@ -0,0 +1,92 @@
import datetime
import uuid
import pgvector.sqlalchemy
import sqlalchemy as sa
import sqlalchemy.orm as sa_orm
import sqlalchemy.sql as sa_sql
import lib.models.orm as base_models
class Genre(base_models.Base):
__tablename__: str = "genre"
id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
name: sa_orm.Mapped[str] = sa_orm.mapped_column()
description: sa_orm.Mapped[str] = sa_orm.mapped_column(nullable=True)
created: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
sa.DateTime(timezone=True), server_default=sa_sql.func.now()
)
modified: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
sa.DateTime(timezone=True), server_default=sa_sql.func.now(), onupdate=sa_sql.func.now()
)
class Person(base_models.Base):
__tablename__: str = "person"
id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
full_name: sa_orm.Mapped[str] = sa_orm.mapped_column()
created: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
sa.DateTime(timezone=True), server_default=sa_sql.func.now()
)
modified: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
sa.DateTime(timezone=True), server_default=sa_sql.func.now(), onupdate=sa_sql.func.now()
)
class FilmWork(base_models.Base):
__tablename__: str = "film_work"
id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
title: sa_orm.Mapped[str]
description: sa_orm.Mapped[str] = sa_orm.mapped_column(nullable=True)
creation_date: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(nullable=True)
file_path: sa_orm.Mapped[str] = sa_orm.mapped_column(nullable=True)
rating: sa_orm.Mapped[float] = sa_orm.mapped_column(nullable=True)
type: sa_orm.Mapped[str] = sa_orm.mapped_column()
created: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
sa.DateTime(timezone=True), server_default=sa_sql.func.now()
)
modified: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
sa.DateTime(timezone=True), server_default=sa_sql.func.now(), onupdate=sa_sql.func.now()
)
embedding: sa_orm.Mapped[list[float]] = sa_orm.mapped_column(pgvector.sqlalchemy.Vector(1536))
genres: sa_orm.Mapped[list[Genre]] = sa_orm.relationship(secondary="genre_film_work")
GenreFilmWork = sa.Table(
"genre_film_work",
base_models.Base.metadata,
sa.Column("id", sa.UUID, primary_key=True), # type: ignore[reportUnknownVariableType]
sa.Column("genre_id", sa.ForeignKey(Genre.id), primary_key=True), # type: ignore[reportUnknownVariableType]
sa.Column("film_work_id", sa.ForeignKey(FilmWork.id), primary_key=True), # type: ignore[reportUnknownVariableType]
sa.Column("created", sa.DateTime(timezone=True), server_default=sa_sql.func.now()),
)
PersonFilmWork = sa.Table(
"person_film_work",
base_models.Base.metadata,
sa.Column("person_id", sa.ForeignKey(Person.id), primary_key=True), # type: ignore[reportUnknownVariableType]
sa.Column("film_work_id", sa.ForeignKey(FilmWork.id), primary_key=True), # type: ignore[reportUnknownVariableType]
sa.Column("role", sa.String(50), nullable=False),
sa.Column("created", sa.DateTime(timezone=True), server_default=sa_sql.func.now()),
)
class ChatHistory(base_models.Base):
__tablename__: str = "chat_history"
id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
session_id: sa_orm.Mapped[str] = sa_orm.mapped_column()
channel: sa_orm.Mapped[str] = sa_orm.mapped_column()
user_id: sa_orm.Mapped[str] = sa_orm.mapped_column()
content: sa_orm.Mapped[sa.JSON] = sa_orm.mapped_column(sa.JSON)
created: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
sa.DateTime(timezone=True), server_default=sa_sql.func.now()
)
modified: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
sa.DateTime(timezone=True), server_default=sa_sql.func.now(), onupdate=sa_sql.func.now()
)

View File

@@ -0,0 +1,40 @@
import logging
import typing
import openai
import openai.error
import lib.agent.models as models
import lib.app.settings as app_settings
class EmbeddingRepository:
"""A service for getting embeddings from OpenAI."""
def __init__(self, settings: app_settings.Settings) -> None:
"""Initialize the service with an OpenAI API key."""
self.llm = openai.api_key = settings.openai.api_key
self.logger = logging.getLogger(__name__)
def get_embedding(self, text: str, model: str = "text-embedding-ada-002") -> models.Embedding | None:
"""Get the embedding for a given text."""
try:
response: dict[str, typing.Any] = openai.Embedding.create(
input=text,
model=model,
) # type: ignore[reportGeneralTypeIssues]
return response["data"][0]["embedding"]
except openai.error.OpenAIError:
self.logger.exception("Failed to get async embedding for: %s", text)
async def aget_embedding(self, text: str, model: str = "text-embedding-ada-002") -> models.Embedding | None:
"""Get the embedding for a given text."""
try:
response: dict[str, typing.Any] = await openai.Embedding.acreate(
input=text,
model=model,
) # type: ignore[reportGeneralTypeIssues]
return models.Embedding(**response["data"][0]["embedding"])
except openai.error.OpenAIError:
self.logger.exception("Failed to get async embedding for: %s", text)

View File

@@ -0,0 +1,89 @@
import asyncio
import logging
import uuid
import langchain.agents
import langchain.agents.format_scratchpad
import langchain.agents.output_parsers
import langchain.chat_models
import langchain.prompts
import langchain.schema.agent
import langchain.schema.messages
import langchain.tools.render
import models
import varname
import lib.agent.openai_functions as openai_functions
import lib.app.settings as app_settings
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class AgentService:
def __init__(self, settings: app_settings.Settings, tools: openai_functions.OpenAIFunctions) -> None:
self.settings = settings
self.tools = tools
async def process_request(self, request: str, chat_history: list[langchain.schema.messages.Message]) -> str:
llm = langchain.chat_models.ChatOpenAI(temperature=0.7, openai_api_key=self.settings.openai.api_key)
tools = [self.tools.get_movie_by_description, self.tools.get_movie_by_id, self.tools.get_similar_movies]
chat_history = []
chat_history_name = f"{chat_history=}".partition("=")[0]
prompt = langchain.prompts.ChatPromptTemplate.from_messages(
[
(
"system",
"You are very powerful assistant. If you are asked about movies you will you provided functions.",
),
langchain.prompts.MessagesPlaceholder(variable_name=chat_history_name),
("user", "{input}"),
langchain.prompts.MessagesPlaceholder(variable_name="agent_scratchpad"),
]
)
llm_with_tools = llm.bind(
functions=[langchain.tools.render.format_tool_to_openai_function(tool) for tool in tools]
)
chat_history = []
agent = (
{
"input": lambda _: _["input"],
"agent_scratchpad": lambda _: langchain.agents.format_scratchpad.format_to_openai_functions(
_["intermediate_steps"]
),
"chat_history": lambda _: _["chat_history"],
}
| prompt
| llm_with_tools
| langchain.agents.output_parsers.OpenAIFunctionsAgentOutputParser()
)
agent_executor = langchain.agents.AgentExecutor(agent=agent, tools=tools, verbose=True)
return await agent_executor.ainvoke({"input": first_question, "chat_history": chat_history})
# async def main():
# agent_executor = langchain.agents.AgentExecutor(agent=agent, tools=tools, verbose=True)
# # first_question = "What is the movie where halfling bring the ring to the volcano?"
# first_question = (
# "What is the movie about a famous country singer meet a talented singer and songwriter who works as a waitress?"
# )
# second_question = "So what is the rating of the movie? Do you recommend it?"
# third_question = "What are the similar movies?"
# first_result = await agent_executor.ainvoke({"input": first_question, "chat_history": chat_history})
# chat_history.append(langchain.schema.messages.HumanMessage(content=first_question))
# chat_history.append(langchain.schema.messages.AIMessage(content=first_result["output"]))
# second_result = await agent_executor.ainvoke({"input": second_question, "chat_history": chat_history})
# chat_history.append(langchain.schema.messages.HumanMessage(content=second_question))
# chat_history.append(langchain.schema.messages.AIMessage(content=second_result["output"]))
# final_result = await agent_executor.ainvoke({"input": third_question, "chat_history": chat_history})
# if __name__ == "__main__":
# asyncio.run(main())

View File

@@ -1,3 +1,4 @@
from .agent import AgentHandler
from .health import basic_router
__all__ = ["basic_router"]
__all__ = ["AgentHandler", "basic_router"]

View File

@@ -0,0 +1,23 @@
import fastapi
import lib.agent as agent
import lib.models as models
class AgentHandler:
def __init__(self, chat_history_repository: agent.ChatHistoryRepository):
self.chat_history_repository = chat_history_repository
self.router = fastapi.APIRouter()
self.router.add_api_route(
"/",
self.get_agent,
methods=["GET"],
summary="Статус работоспособности",
description="Проверяет доступность сервиса FastAPI.",
)
async def get_agent(self):
request = models.RequestLastSessionId(channel="test", user_id="test", minutes_ago=3)
response = await self.chat_history_repository.get_last_session_id(request=request)
print("RESPONSE: ", response)
return {"response": response}

View File

@@ -6,6 +6,7 @@ import typing
import fastapi
import uvicorn
import lib.agent as agent
import lib.api.v1.handlers as api_v1_handlers
import lib.app.errors as app_errors
import lib.app.settings as app_settings
@@ -73,6 +74,7 @@ class Application:
logger.info("Initializing repositories")
stt_repository: stt.STTProtocol = stt.OpenaiSpeechRepository(settings=settings)
chat_history_repository = agent.ChatHistoryRepository(pg_async_session=postgres_client.get_async_session())
# Caches
@@ -87,6 +89,7 @@ class Application:
logger.info("Initializing handlers")
liveness_probe_handler = api_v1_handlers.basic_router
agent_handler = api_v1_handlers.AgentHandler(chat_history_repository=chat_history_repository).router
logger.info("Creating application")
@@ -100,6 +103,7 @@ class Application:
# Routes
fastapi_app.include_router(liveness_probe_handler, prefix="/api/v1/health", tags=["health"])
fastapi_app.include_router(agent_handler, prefix="/api/v1/agent", tags=["testing"])
application = Application(
settings=settings,

View File

@@ -7,7 +7,6 @@ from .project import *
from .proxy import *
from .voice import *
__all__ = [
"ApiSettings",
"AppSettings",

View File

@@ -1,4 +1,4 @@
from .orm import Base, IdCreatedUpdatedBaseMixin
from .chat_history import RequestLastSessionId
from .token import Token
__all__ = ["Base", "IdCreatedUpdatedBaseMixin", "Token"]
__all__ = ["RequestLastSessionId", "Token"]

View File

@@ -0,0 +1,9 @@
import pydantic
class RequestLastSessionId(pydantic.BaseModel):
"""Request for a new session ID."""
channel: str
user_id: str
minutes_ago: int

View File

@@ -1,3 +0,0 @@
from .base import Base, IdCreatedUpdatedBaseMixin
__all__ = ["Base", "IdCreatedUpdatedBaseMixin"]

View File

@@ -0,0 +1,13 @@
from .base import Base, IdCreatedUpdatedBaseMixin
from .movies import ChatHistory, FilmWork, Genre, GenreFilmWork, Person, PersonFilmWork
__all__ = [
"Base",
"ChatHistory",
"FilmWork",
"Genre",
"GenreFilmWork",
"IdCreatedUpdatedBaseMixin",
"Person",
"PersonFilmWork",
]

View File

@@ -16,20 +16,12 @@ class Base(sa_orm.DeclarativeBase):
return cls.__name__.lower()
__mapper_args__ = {"eager_defaults": True}
id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
__table_args__ = {"schema": "content"}
class IdCreatedUpdatedBaseMixin:
# id: sa_orm.Mapped[int] = sa_orm.mapped_column(primary_key=True)
# id_field: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(name="uuid", primary_key=True, unique=True, default=uuid.uuid4, nullable=False)
id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
created: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(server_default=sa_sql.func.now())
updated: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
server_default=sa_sql.func.now(), onupdate=sa_sql.func.now()
)
# __mapper_args__ = {"eager_defaults": True}
# @sqlalchemy.ext.declarative.declared_attr.directive
# def __tablename__(cls) -> str:
# return cls.__name__.lower()

View File

@@ -0,0 +1,92 @@
import datetime
import uuid
import pgvector.sqlalchemy
import sqlalchemy as sa
import sqlalchemy.orm as sa_orm
import sqlalchemy.sql as sa_sql
import lib.orm_models.base as base_models
class Genre(base_models.Base):
__tablename__: str = "genre"
id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
name: sa_orm.Mapped[str] = sa_orm.mapped_column()
description: sa_orm.Mapped[str] = sa_orm.mapped_column(nullable=True)
created: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
sa.DateTime(timezone=True), server_default=sa_sql.func.now()
)
modified: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
sa.DateTime(timezone=True), server_default=sa_sql.func.now(), onupdate=sa_sql.func.now()
)
class Person(base_models.Base):
__tablename__: str = "person"
id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
full_name: sa_orm.Mapped[str] = sa_orm.mapped_column()
created: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
sa.DateTime(timezone=True), server_default=sa_sql.func.now()
)
modified: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
sa.DateTime(timezone=True), server_default=sa_sql.func.now(), onupdate=sa_sql.func.now()
)
class FilmWork(base_models.Base):
__tablename__: str = "film_work"
id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
title: sa_orm.Mapped[str]
description: sa_orm.Mapped[str] = sa_orm.mapped_column(nullable=True)
creation_date: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(nullable=True)
file_path: sa_orm.Mapped[str] = sa_orm.mapped_column(nullable=True)
rating: sa_orm.Mapped[float] = sa_orm.mapped_column(nullable=True)
type: sa_orm.Mapped[str] = sa_orm.mapped_column()
created: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
sa.DateTime(timezone=True), server_default=sa_sql.func.now()
)
modified: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
sa.DateTime(timezone=True), server_default=sa_sql.func.now(), onupdate=sa_sql.func.now()
)
embedding: sa_orm.Mapped[list[float]] = sa_orm.mapped_column(pgvector.sqlalchemy.Vector(1536))
genres: sa_orm.Mapped[list[Genre]] = sa_orm.relationship(secondary="genre_film_work")
GenreFilmWork = sa.Table(
"genre_film_work",
base_models.Base.metadata,
sa.Column("id", sa.UUID, primary_key=True), # type: ignore[reportUnknownVariableType]
sa.Column("genre_id", sa.ForeignKey(Genre.id), primary_key=True), # type: ignore[reportUnknownVariableType]
sa.Column("film_work_id", sa.ForeignKey(FilmWork.id), primary_key=True), # type: ignore[reportUnknownVariableType]
sa.Column("created", sa.DateTime(timezone=True), server_default=sa_sql.func.now()),
)
PersonFilmWork = sa.Table(
"person_film_work",
base_models.Base.metadata,
sa.Column("person_id", sa.ForeignKey(Person.id), primary_key=True), # type: ignore[reportUnknownVariableType]
sa.Column("film_work_id", sa.ForeignKey(FilmWork.id), primary_key=True), # type: ignore[reportUnknownVariableType]
sa.Column("role", sa.String(50), nullable=False),
sa.Column("created", sa.DateTime(timezone=True), server_default=sa_sql.func.now()),
)
class ChatHistory(base_models.Base):
__tablename__: str = "chat_history"
id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4)
session_id: sa_orm.Mapped[str] = sa_orm.mapped_column()
channel: sa_orm.Mapped[str] = sa_orm.mapped_column()
user_id: sa_orm.Mapped[str] = sa_orm.mapped_column()
content: sa_orm.Mapped[sa.JSON] = sa_orm.mapped_column(sa.JSON)
created: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
sa.DateTime(timezone=True), server_default=sa_sql.func.now()
)
modified: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column(
sa.DateTime(timezone=True), server_default=sa_sql.func.now(), onupdate=sa_sql.func.now()
)