From 6f84918506cdd5454186e1a40236fe63b36a560c Mon Sep 17 00:00:00 2001 From: jsdio Date: Sun, 15 Oct 2023 07:59:29 +0300 Subject: [PATCH] Changes by ijaric --- src/assistant/lib/agent/openai_functions.py | 22 +++++++++++++++++++++ src/assistant/lib/agent/services.py | 13 ++++++++---- 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/src/assistant/lib/agent/openai_functions.py b/src/assistant/lib/agent/openai_functions.py index 1237148..811fd6b 100644 --- a/src/assistant/lib/agent/openai_functions.py +++ b/src/assistant/lib/agent/openai_functions.py @@ -23,6 +23,28 @@ class OpenAIFunctions: self.pg_async_session = pg_async_session self.repository = repository + @langchain.agents.tool + async def artem_get_movie_by_description(self, description: str) -> list[models.Movie] | None: + """Provide a movie data by description.""" + + self.logger.info("Request to get movie by description: %s", description) + embedded_description = await self.repository.aget_embedding(description) + try: + async with self.pg_async_session() as session: + result: list[models.Movie] = [] + stmt = ( + sa.select(orm_models.FilmWork) + .order_by(orm_models.FilmWork.embeddings.cosine_distance(embedded_description.root)) + .limit(5) + ) + response = await session.execute(stmt) + neighbours = response.scalars() + for neighbour in neighbours: + result.append(models.Movie(**neighbour.__dict__)) + return result + except sqlalchemy.exc.SQLAlchemyError as error: + self.logger.exception("Error: %s", error) + async def get_movie_by_description(self, description: str) -> list[models.Movie] | None: """Provide a movie data by description.""" diff --git a/src/assistant/lib/agent/services.py b/src/assistant/lib/agent/services.py index 52aa3b3..557a43b 100644 --- a/src/assistant/lib/agent/services.py +++ b/src/assistant/lib/agent/services.py @@ -11,17 +11,22 @@ import langchain.chat_models import langchain.prompts import langchain.schema import langchain.tools.render +import langchain.memory +import langchain.memory.chat_memory import lib.models as models import lib.agent.openai_functions as openai_functions import lib.app.settings as app_settings - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) +import lib.agent.chat_repository as _chat_repository class AgentService: - def __init__(self, settings: app_settings.Settings, tools: openai_functions.OpenAIFunctions) -> None: + def __init__( + self, + settings: app_settings.Settings, + tools: openai_functions.OpenAIFunctions, + chat_repository: _chat_repository.ChatHistoryRepository, + ) -> None: self.settings = settings self.tools = tools self.llm = langchain.chat_models.ChatOpenAI(