mirror of
https://github.com/ijaric/voice_assistant.git
synced 2025-05-24 14:33:26 +00:00
Changes by ijaric
This commit is contained in:
parent
b3e480886e
commit
288af57ab5
|
@ -1,9 +1,6 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import typing
|
||||
import uuid
|
||||
|
||||
import fastapi
|
||||
import langchain.agents
|
||||
import langchain.agents.format_scratchpad
|
||||
import langchain.agents.output_parsers
|
||||
|
@ -25,23 +22,15 @@ class AgentService:
|
|||
self,
|
||||
settings: app_settings.Settings,
|
||||
chat_repository: _chat_repository.ChatHistoryRepository,
|
||||
tools: lib_agent_repositories.OpenAIFunctions | None = None,
|
||||
) -> None:
|
||||
self.settings = settings
|
||||
self.tools = tools
|
||||
self.llm = langchain.chat_models.ChatOpenAI(
|
||||
temperature=self.settings.openai.agent_temperature,
|
||||
openai_api_key=self.settings.openai.api_key.get_secret_value()
|
||||
)
|
||||
self.chat_repository = chat_repository
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
async def get_chat_session_id(self, request: models.RequestLastSessionId) -> uuid.UUID:
|
||||
session_id = self.chat_repository.get_last_session_id(request)
|
||||
if not session_id:
|
||||
session_id = uuid.uuid4()
|
||||
return session_id
|
||||
|
||||
async def artem_process_request(self, request: models.AgentCreateRequestModel) -> models.AgentCreateResponseModel:
|
||||
async def process_request(self, request: models.AgentCreateRequestModel) -> models.AgentCreateResponseModel:
|
||||
# Get session ID
|
||||
session_request = models.RequestLastSessionId(
|
||||
channel=request.channel,
|
||||
|
@ -50,11 +39,9 @@ class AgentService:
|
|||
)
|
||||
session_id = await self.chat_repository.get_last_session_id(session_request)
|
||||
if not session_id:
|
||||
print("NO PREVIOUS CHATS")
|
||||
session_id = uuid.uuid4()
|
||||
print("FOUND CHAT:", )
|
||||
print("SID:", session_id)
|
||||
|
||||
# Declare tools (OpenAI functions)
|
||||
tools = [
|
||||
langchain.tools.Tool(
|
||||
name="GetMovieByDescription",
|
||||
|
@ -66,22 +53,17 @@ class AgentService:
|
|||
|
||||
llm = langchain.chat_models.ChatOpenAI(temperature=self.settings.openai.agent_temperature, openai_api_key=self.settings.openai.api_key.get_secret_value())
|
||||
|
||||
# chat_history = langchain.memory.ChatMessageHistory()
|
||||
chat_history = []
|
||||
chat_history_name = f"{chat_history=}".partition("=")[0]
|
||||
|
||||
request_chat_history = models.RequestChatHistory(session_id=session_id)
|
||||
chat_history_source = await self.chat_repository.get_messages_by_sid(request_chat_history)
|
||||
for entry in chat_history_source:
|
||||
print("ENTRY: ", entry)
|
||||
if entry.role == "user":
|
||||
chat_history.append(langchain.schema.HumanMessage(content=entry.content))
|
||||
elif entry.role == "agent":
|
||||
chat_history.append(langchain.schema.AIMessage(content=entry.content))
|
||||
|
||||
# memory = langchain.memory.ConversationBufferMemory(memory_key=chat_history_name,chat_memory=chat_history)
|
||||
|
||||
print("CHAT HISTORY:", chat_history)
|
||||
|
||||
prompt = langchain.prompts.ChatPromptTemplate.from_messages(
|
||||
[
|
||||
|
@ -112,13 +94,10 @@ class AgentService:
|
|||
| langchain.agents.output_parsers.OpenAIFunctionsAgentOutputParser()
|
||||
)
|
||||
|
||||
print("AGENT:", agent)
|
||||
|
||||
agent_executor = langchain.agents.AgentExecutor(agent=agent, tools=tools, verbose=True)
|
||||
print("CH:", type(chat_history), chat_history)
|
||||
agent_executor = langchain.agents.AgentExecutor(agent=agent, tools=tools, verbose=False)
|
||||
chat_history = [] # temporary disable chat_history
|
||||
response = await agent_executor.ainvoke({"input": request.text, "chat_history": chat_history})
|
||||
print("AI RESPONSE:", response)
|
||||
|
||||
user_request = models.RequestChatMessage(
|
||||
session_id=session_id,
|
||||
|
@ -137,108 +116,3 @@ class AgentService:
|
|||
await self.chat_repository.add_message(ai_response)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
# TODO: Добавить запрос для процессинга запроса с памятью+
|
||||
# TODO: Улучшить промпт+
|
||||
# TODO: Возможно, надо добавить Chain на перевод
|
||||
|
||||
|
||||
async def process_request(self, request: models.AgentCreateRequestModel) -> models.AgentCreateResponseModel:
|
||||
|
||||
result = await self.tools.get_movie_by_description(request.text)
|
||||
|
||||
if len(result) == 0:
|
||||
raise fastapi.HTTPException(status_code=404, detail="Movies not found")
|
||||
|
||||
# llm = langchain.chat_models.ChatOpenAI(
|
||||
# temperature=self.settings.openai.agent_temperature,
|
||||
# openai_api_key=self.settings.openai.api_key.get_secret_value()
|
||||
# )
|
||||
|
||||
content_films = "\n".join(film.get_movie_info_line() for film in result)
|
||||
|
||||
system_prompt = (
|
||||
"You are a cinema expert. "
|
||||
f"Here are the movies I found for you: {content_films}"
|
||||
"Listen to the question and answer it based on the information above."
|
||||
)
|
||||
|
||||
prompt = langchain.prompts.ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", system_prompt),
|
||||
]
|
||||
)
|
||||
chain = prompt | self.llm
|
||||
response = await chain.ainvoke({"input": request.text})
|
||||
response_model = models.AgentCreateResponseModel(text=response.content)
|
||||
return response_model
|
||||
|
||||
|
||||
async def main():
|
||||
import lib.agent.repositories as agent_repositories
|
||||
import lib.clients as clients
|
||||
|
||||
postgres_client = clients.AsyncPostgresClient(app_settings.Settings())
|
||||
embedding_repository = agent_repositories.EmbeddingRepository(app_settings.Settings())
|
||||
chat_repository = _chat_repository.ChatHistoryRepository(postgres_client.get_async_session())
|
||||
|
||||
agent_service = AgentService(
|
||||
settings=app_settings.Settings(),
|
||||
tools=openai_functions.OpenAIFunctions(
|
||||
repository=embedding_repository,
|
||||
pg_async_session=postgres_client.get_async_session(),
|
||||
),
|
||||
chat_repository=chat_repository
|
||||
)
|
||||
|
||||
# question = "What is the movie about a famous country singer meet a talented singer and songwriter who works as a waitress?"
|
||||
request_1 = models.AgentCreateRequestModel(
|
||||
channel="telegram",
|
||||
user_id="123",
|
||||
text="What is the movie about a famous country singer meet a talented singer and songwriter who works as a waitress?"
|
||||
)
|
||||
request_2 = models.AgentCreateRequestModel(
|
||||
channel="telegram",
|
||||
user_id="123",
|
||||
text="So what is the rating of the movie? Do you recommend it?"
|
||||
)
|
||||
request_3 = models.AgentCreateRequestModel(
|
||||
channel="telegram",
|
||||
user_id="123",
|
||||
text="What are the similar movies?"
|
||||
)
|
||||
|
||||
response = await agent_service.artem_process_request(request_1)
|
||||
response = await agent_service.artem_process_request(request_2)
|
||||
response = await agent_service.artem_process_request(request_3)
|
||||
|
||||
|
||||
|
||||
|
||||
# response = await agent_service.artem_process_request(question)
|
||||
# question = "Highly Rated Titanic Movies"
|
||||
# request = models.AgentCreateRequestModel(text=question)
|
||||
# film_results = await agent_service.process_request(request=request)
|
||||
|
||||
# result = [agent_service.tools.get_movie_by_id(id=film.id) for film in film_results]
|
||||
|
||||
# agent_executor = langchain.agents.AgentExecutor(agent=agent, tools=tools, verbose=True)
|
||||
#
|
||||
# # first_question = "What is the movie where halfling bring the ring to the volcano?"
|
||||
# first_question = (
|
||||
# "What is the movie about a famous country singer meet a talented singer and songwriter who works as a waitress?"
|
||||
# )
|
||||
# second_question = "So what is the rating of the movie? Do you recommend it?"
|
||||
# third_question = "What are the similar movies?"
|
||||
# first_result = await agent_executor.ainvoke({"input": first_question, "chat_history": chat_history})
|
||||
# chat_history.append(langchain.schema.messages.HumanMessage(content=first_question))
|
||||
# chat_history.append(langchain.schema.messages.AIMessage(content=first_result["output"]))
|
||||
# second_result = await agent_executor.ainvoke({"input": second_question, "chat_history": chat_history})
|
||||
# chat_history.append(langchain.schema.messages.HumanMessage(content=second_question))
|
||||
# chat_history.append(langchain.schema.messages.AIMessage(content=second_result["output"]))
|
||||
# final_result = await agent_executor.ainvoke({"input": third_question, "chat_history": chat_history})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
|
@ -2,7 +2,7 @@ import uuid
|
|||
|
||||
import fastapi
|
||||
|
||||
import lib.agent as agent
|
||||
import lib.agent.repositories as agent_repositories
|
||||
import lib.models as models
|
||||
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ import io
|
|||
import fastapi
|
||||
|
||||
import lib.stt.services as stt_services
|
||||
|
||||
import lib.agent.services as agent_service
|
||||
import lib.tts.services as tts_service
|
||||
import lib.models as models
|
||||
|
||||
|
@ -14,9 +14,11 @@ class VoiceResponseHandler:
|
|||
self,
|
||||
stt: stt_services.SpeechService,
|
||||
tts: tts_service.TTSService,
|
||||
# agent: agent_service.AgentService,
|
||||
):
|
||||
self.stt = stt
|
||||
self.tts = tts
|
||||
# self.agent = agent
|
||||
self.router = fastapi.APIRouter()
|
||||
self.router.add_api_route(
|
||||
"/",
|
||||
|
@ -35,12 +37,13 @@ class VoiceResponseHandler:
|
|||
voice_text: str = await self.stt.recognize(voice)
|
||||
if voice_text == "":
|
||||
raise fastapi.HTTPException(status_code=http.HTTPStatus.BAD_REQUEST, detail="Speech recognition failed")
|
||||
# TODO: Добавить обработку текста через клиента openai
|
||||
# TODO: Добавить синтез речи через клиента tts
|
||||
# TODO: Заменить заглушку на реальный ответ
|
||||
|
||||
# agent_request = models.AgentCreateRequestModel(channel=channel, user_id=user_id, text=voice_text)
|
||||
# reply_text = await self.agent.process_request(agent_request)
|
||||
reply_text = "hi there"
|
||||
response = await self.tts.get_audio_as_bytes(
|
||||
models.TTSCreateRequestModel(
|
||||
text=voice_text,
|
||||
text=reply_text,
|
||||
)
|
||||
)
|
||||
return fastapi.responses.StreamingResponse(io.BytesIO(response.audio_content), media_type="audio/ogg")
|
||||
|
|
|
@ -14,6 +14,9 @@ import lib.clients as clients
|
|||
import lib.models as models
|
||||
import lib.stt as stt
|
||||
import lib.tts as tts
|
||||
import lib.agent.repositories as agent_repositories
|
||||
import lib.agent.repositories.openai_functions as agent_functions
|
||||
import lib.agent.services as agent_services
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -89,7 +92,8 @@ class Application:
|
|||
|
||||
logger.info("Initializing repositories")
|
||||
stt_repository: stt.STTProtocol = stt.OpenaiSpeechRepository(settings=settings)
|
||||
chat_history_repository = agent.ChatHistoryRepository(pg_async_session=postgres_client.get_async_session())
|
||||
chat_history_repository = agent_repositories.ChatHistoryRepository(pg_async_session=postgres_client.get_async_session())
|
||||
embedding_repository = agent_repositories.EmbeddingRepository(settings)
|
||||
|
||||
tts_yandex_repository = tts.TTSYandexRepository(
|
||||
tts_settings=app_split_settings.TTSYandexSettings(),
|
||||
|
@ -101,14 +105,21 @@ class Application:
|
|||
is_models_from_api=True,
|
||||
)
|
||||
|
||||
|
||||
# Caches
|
||||
|
||||
logger.info("Initializing caches")
|
||||
|
||||
# Tools
|
||||
|
||||
agent_tools = agent_functions.OpenAIFunctions(repository=embedding_repository, pg_async_session=postgres_client.get_async_session())
|
||||
|
||||
# Services
|
||||
|
||||
logger.info("Initializing services")
|
||||
stt_service: stt.SpeechService = stt.SpeechService(repository=stt_repository) # type: ignore
|
||||
|
||||
|
||||
stt_service: stt.SpeechService = stt.SpeechService(repository=stt_repository) # type: ignore
|
||||
|
||||
tts_service: tts.TTSService = tts.TTSService( # type: ignore
|
||||
repositories={
|
||||
|
@ -117,6 +128,9 @@ class Application:
|
|||
},
|
||||
)
|
||||
|
||||
agent_service: agent_services.AgentService(settings=settings, chat_repository=chat_history_repository)
|
||||
# agent_service: agent_services.AgentService(settings=settings, chat_repository=chat_history_repository, tools=agent_tools)
|
||||
|
||||
# Handlers
|
||||
|
||||
logger.info("Initializing handlers")
|
||||
|
@ -127,6 +141,7 @@ class Application:
|
|||
voice_response_handler = api_v1_handlers.VoiceResponseHandler(
|
||||
stt=stt_service,
|
||||
tts=tts_service,
|
||||
agent=agent_services,
|
||||
).router
|
||||
|
||||
logger.info("Creating application")
|
||||
|
|
Loading…
Reference in New Issue
Block a user