diff --git a/src/assistant/lib/agent/openai_functions.py b/src/assistant/lib/agent/openai_functions.py index 1237148..811fd6b 100644 --- a/src/assistant/lib/agent/openai_functions.py +++ b/src/assistant/lib/agent/openai_functions.py @@ -23,6 +23,28 @@ class OpenAIFunctions: self.pg_async_session = pg_async_session self.repository = repository + @langchain.agents.tool + async def artem_get_movie_by_description(self, description: str) -> list[models.Movie] | None: + """Provide a movie data by description.""" + + self.logger.info("Request to get movie by description: %s", description) + embedded_description = await self.repository.aget_embedding(description) + try: + async with self.pg_async_session() as session: + result: list[models.Movie] = [] + stmt = ( + sa.select(orm_models.FilmWork) + .order_by(orm_models.FilmWork.embeddings.cosine_distance(embedded_description.root)) + .limit(5) + ) + response = await session.execute(stmt) + neighbours = response.scalars() + for neighbour in neighbours: + result.append(models.Movie(**neighbour.__dict__)) + return result + except sqlalchemy.exc.SQLAlchemyError as error: + self.logger.exception("Error: %s", error) + async def get_movie_by_description(self, description: str) -> list[models.Movie] | None: """Provide a movie data by description.""" diff --git a/src/assistant/lib/agent/services.py b/src/assistant/lib/agent/services.py index 52aa3b3..557a43b 100644 --- a/src/assistant/lib/agent/services.py +++ b/src/assistant/lib/agent/services.py @@ -11,17 +11,22 @@ import langchain.chat_models import langchain.prompts import langchain.schema import langchain.tools.render +import langchain.memory +import langchain.memory.chat_memory import lib.models as models import lib.agent.openai_functions as openai_functions import lib.app.settings as app_settings - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) +import lib.agent.chat_repository as _chat_repository class AgentService: - def __init__(self, settings: app_settings.Settings, tools: openai_functions.OpenAIFunctions) -> None: + def __init__( + self, + settings: app_settings.Settings, + tools: openai_functions.OpenAIFunctions, + chat_repository: _chat_repository.ChatHistoryRepository, + ) -> None: self.settings = settings self.tools = tools self.llm = langchain.chat_models.ChatOpenAI(