mirror of
https://github.com/ijaric/voice_assistant.git
synced 2025-05-24 14:33:26 +00:00
Changes by ijaric
This commit is contained in:
parent
b3e480886e
commit
288af57ab5
|
@ -1,9 +1,6 @@
|
||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
import typing
|
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
import fastapi
|
|
||||||
import langchain.agents
|
import langchain.agents
|
||||||
import langchain.agents.format_scratchpad
|
import langchain.agents.format_scratchpad
|
||||||
import langchain.agents.output_parsers
|
import langchain.agents.output_parsers
|
||||||
|
@ -25,23 +22,15 @@ class AgentService:
|
||||||
self,
|
self,
|
||||||
settings: app_settings.Settings,
|
settings: app_settings.Settings,
|
||||||
chat_repository: _chat_repository.ChatHistoryRepository,
|
chat_repository: _chat_repository.ChatHistoryRepository,
|
||||||
|
tools: lib_agent_repositories.OpenAIFunctions | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.settings = settings
|
self.settings = settings
|
||||||
self.tools = tools
|
self.tools = tools
|
||||||
self.llm = langchain.chat_models.ChatOpenAI(
|
|
||||||
temperature=self.settings.openai.agent_temperature,
|
|
||||||
openai_api_key=self.settings.openai.api_key.get_secret_value()
|
|
||||||
)
|
|
||||||
self.chat_repository = chat_repository
|
self.chat_repository = chat_repository
|
||||||
self.logger = logging.getLogger(__name__)
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
async def get_chat_session_id(self, request: models.RequestLastSessionId) -> uuid.UUID:
|
|
||||||
session_id = self.chat_repository.get_last_session_id(request)
|
|
||||||
if not session_id:
|
|
||||||
session_id = uuid.uuid4()
|
|
||||||
return session_id
|
|
||||||
|
|
||||||
async def artem_process_request(self, request: models.AgentCreateRequestModel) -> models.AgentCreateResponseModel:
|
async def process_request(self, request: models.AgentCreateRequestModel) -> models.AgentCreateResponseModel:
|
||||||
# Get session ID
|
# Get session ID
|
||||||
session_request = models.RequestLastSessionId(
|
session_request = models.RequestLastSessionId(
|
||||||
channel=request.channel,
|
channel=request.channel,
|
||||||
|
@ -50,11 +39,9 @@ class AgentService:
|
||||||
)
|
)
|
||||||
session_id = await self.chat_repository.get_last_session_id(session_request)
|
session_id = await self.chat_repository.get_last_session_id(session_request)
|
||||||
if not session_id:
|
if not session_id:
|
||||||
print("NO PREVIOUS CHATS")
|
|
||||||
session_id = uuid.uuid4()
|
session_id = uuid.uuid4()
|
||||||
print("FOUND CHAT:", )
|
|
||||||
print("SID:", session_id)
|
|
||||||
|
|
||||||
|
# Declare tools (OpenAI functions)
|
||||||
tools = [
|
tools = [
|
||||||
langchain.tools.Tool(
|
langchain.tools.Tool(
|
||||||
name="GetMovieByDescription",
|
name="GetMovieByDescription",
|
||||||
|
@ -66,22 +53,17 @@ class AgentService:
|
||||||
|
|
||||||
llm = langchain.chat_models.ChatOpenAI(temperature=self.settings.openai.agent_temperature, openai_api_key=self.settings.openai.api_key.get_secret_value())
|
llm = langchain.chat_models.ChatOpenAI(temperature=self.settings.openai.agent_temperature, openai_api_key=self.settings.openai.api_key.get_secret_value())
|
||||||
|
|
||||||
# chat_history = langchain.memory.ChatMessageHistory()
|
|
||||||
chat_history = []
|
chat_history = []
|
||||||
chat_history_name = f"{chat_history=}".partition("=")[0]
|
chat_history_name = f"{chat_history=}".partition("=")[0]
|
||||||
|
|
||||||
request_chat_history = models.RequestChatHistory(session_id=session_id)
|
request_chat_history = models.RequestChatHistory(session_id=session_id)
|
||||||
chat_history_source = await self.chat_repository.get_messages_by_sid(request_chat_history)
|
chat_history_source = await self.chat_repository.get_messages_by_sid(request_chat_history)
|
||||||
for entry in chat_history_source:
|
for entry in chat_history_source:
|
||||||
print("ENTRY: ", entry)
|
|
||||||
if entry.role == "user":
|
if entry.role == "user":
|
||||||
chat_history.append(langchain.schema.HumanMessage(content=entry.content))
|
chat_history.append(langchain.schema.HumanMessage(content=entry.content))
|
||||||
elif entry.role == "agent":
|
elif entry.role == "agent":
|
||||||
chat_history.append(langchain.schema.AIMessage(content=entry.content))
|
chat_history.append(langchain.schema.AIMessage(content=entry.content))
|
||||||
|
|
||||||
# memory = langchain.memory.ConversationBufferMemory(memory_key=chat_history_name,chat_memory=chat_history)
|
|
||||||
|
|
||||||
print("CHAT HISTORY:", chat_history)
|
|
||||||
|
|
||||||
prompt = langchain.prompts.ChatPromptTemplate.from_messages(
|
prompt = langchain.prompts.ChatPromptTemplate.from_messages(
|
||||||
[
|
[
|
||||||
|
@ -112,13 +94,10 @@ class AgentService:
|
||||||
| langchain.agents.output_parsers.OpenAIFunctionsAgentOutputParser()
|
| langchain.agents.output_parsers.OpenAIFunctionsAgentOutputParser()
|
||||||
)
|
)
|
||||||
|
|
||||||
print("AGENT:", agent)
|
|
||||||
|
|
||||||
agent_executor = langchain.agents.AgentExecutor(agent=agent, tools=tools, verbose=True)
|
agent_executor = langchain.agents.AgentExecutor(agent=agent, tools=tools, verbose=False)
|
||||||
print("CH:", type(chat_history), chat_history)
|
|
||||||
chat_history = [] # temporary disable chat_history
|
chat_history = [] # temporary disable chat_history
|
||||||
response = await agent_executor.ainvoke({"input": request.text, "chat_history": chat_history})
|
response = await agent_executor.ainvoke({"input": request.text, "chat_history": chat_history})
|
||||||
print("AI RESPONSE:", response)
|
|
||||||
|
|
||||||
user_request = models.RequestChatMessage(
|
user_request = models.RequestChatMessage(
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
|
@ -137,108 +116,3 @@ class AgentService:
|
||||||
await self.chat_repository.add_message(ai_response)
|
await self.chat_repository.add_message(ai_response)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
# TODO: Добавить запрос для процессинга запроса с памятью+
|
|
||||||
# TODO: Улучшить промпт+
|
|
||||||
# TODO: Возможно, надо добавить Chain на перевод
|
|
||||||
|
|
||||||
|
|
||||||
async def process_request(self, request: models.AgentCreateRequestModel) -> models.AgentCreateResponseModel:
|
|
||||||
|
|
||||||
result = await self.tools.get_movie_by_description(request.text)
|
|
||||||
|
|
||||||
if len(result) == 0:
|
|
||||||
raise fastapi.HTTPException(status_code=404, detail="Movies not found")
|
|
||||||
|
|
||||||
# llm = langchain.chat_models.ChatOpenAI(
|
|
||||||
# temperature=self.settings.openai.agent_temperature,
|
|
||||||
# openai_api_key=self.settings.openai.api_key.get_secret_value()
|
|
||||||
# )
|
|
||||||
|
|
||||||
content_films = "\n".join(film.get_movie_info_line() for film in result)
|
|
||||||
|
|
||||||
system_prompt = (
|
|
||||||
"You are a cinema expert. "
|
|
||||||
f"Here are the movies I found for you: {content_films}"
|
|
||||||
"Listen to the question and answer it based on the information above."
|
|
||||||
)
|
|
||||||
|
|
||||||
prompt = langchain.prompts.ChatPromptTemplate.from_messages(
|
|
||||||
[
|
|
||||||
("system", system_prompt),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
chain = prompt | self.llm
|
|
||||||
response = await chain.ainvoke({"input": request.text})
|
|
||||||
response_model = models.AgentCreateResponseModel(text=response.content)
|
|
||||||
return response_model
|
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
import lib.agent.repositories as agent_repositories
|
|
||||||
import lib.clients as clients
|
|
||||||
|
|
||||||
postgres_client = clients.AsyncPostgresClient(app_settings.Settings())
|
|
||||||
embedding_repository = agent_repositories.EmbeddingRepository(app_settings.Settings())
|
|
||||||
chat_repository = _chat_repository.ChatHistoryRepository(postgres_client.get_async_session())
|
|
||||||
|
|
||||||
agent_service = AgentService(
|
|
||||||
settings=app_settings.Settings(),
|
|
||||||
tools=openai_functions.OpenAIFunctions(
|
|
||||||
repository=embedding_repository,
|
|
||||||
pg_async_session=postgres_client.get_async_session(),
|
|
||||||
),
|
|
||||||
chat_repository=chat_repository
|
|
||||||
)
|
|
||||||
|
|
||||||
# question = "What is the movie about a famous country singer meet a talented singer and songwriter who works as a waitress?"
|
|
||||||
request_1 = models.AgentCreateRequestModel(
|
|
||||||
channel="telegram",
|
|
||||||
user_id="123",
|
|
||||||
text="What is the movie about a famous country singer meet a talented singer and songwriter who works as a waitress?"
|
|
||||||
)
|
|
||||||
request_2 = models.AgentCreateRequestModel(
|
|
||||||
channel="telegram",
|
|
||||||
user_id="123",
|
|
||||||
text="So what is the rating of the movie? Do you recommend it?"
|
|
||||||
)
|
|
||||||
request_3 = models.AgentCreateRequestModel(
|
|
||||||
channel="telegram",
|
|
||||||
user_id="123",
|
|
||||||
text="What are the similar movies?"
|
|
||||||
)
|
|
||||||
|
|
||||||
response = await agent_service.artem_process_request(request_1)
|
|
||||||
response = await agent_service.artem_process_request(request_2)
|
|
||||||
response = await agent_service.artem_process_request(request_3)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# response = await agent_service.artem_process_request(question)
|
|
||||||
# question = "Highly Rated Titanic Movies"
|
|
||||||
# request = models.AgentCreateRequestModel(text=question)
|
|
||||||
# film_results = await agent_service.process_request(request=request)
|
|
||||||
|
|
||||||
# result = [agent_service.tools.get_movie_by_id(id=film.id) for film in film_results]
|
|
||||||
|
|
||||||
# agent_executor = langchain.agents.AgentExecutor(agent=agent, tools=tools, verbose=True)
|
|
||||||
#
|
|
||||||
# # first_question = "What is the movie where halfling bring the ring to the volcano?"
|
|
||||||
# first_question = (
|
|
||||||
# "What is the movie about a famous country singer meet a talented singer and songwriter who works as a waitress?"
|
|
||||||
# )
|
|
||||||
# second_question = "So what is the rating of the movie? Do you recommend it?"
|
|
||||||
# third_question = "What are the similar movies?"
|
|
||||||
# first_result = await agent_executor.ainvoke({"input": first_question, "chat_history": chat_history})
|
|
||||||
# chat_history.append(langchain.schema.messages.HumanMessage(content=first_question))
|
|
||||||
# chat_history.append(langchain.schema.messages.AIMessage(content=first_result["output"]))
|
|
||||||
# second_result = await agent_executor.ainvoke({"input": second_question, "chat_history": chat_history})
|
|
||||||
# chat_history.append(langchain.schema.messages.HumanMessage(content=second_question))
|
|
||||||
# chat_history.append(langchain.schema.messages.AIMessage(content=second_result["output"]))
|
|
||||||
# final_result = await agent_executor.ainvoke({"input": third_question, "chat_history": chat_history})
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
asyncio.run(main())
|
|
||||||
|
|
|
@ -2,7 +2,7 @@ import uuid
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
|
|
||||||
import lib.agent as agent
|
import lib.agent.repositories as agent_repositories
|
||||||
import lib.models as models
|
import lib.models as models
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@ import io
|
||||||
import fastapi
|
import fastapi
|
||||||
|
|
||||||
import lib.stt.services as stt_services
|
import lib.stt.services as stt_services
|
||||||
|
import lib.agent.services as agent_service
|
||||||
import lib.tts.services as tts_service
|
import lib.tts.services as tts_service
|
||||||
import lib.models as models
|
import lib.models as models
|
||||||
|
|
||||||
|
@ -14,9 +14,11 @@ class VoiceResponseHandler:
|
||||||
self,
|
self,
|
||||||
stt: stt_services.SpeechService,
|
stt: stt_services.SpeechService,
|
||||||
tts: tts_service.TTSService,
|
tts: tts_service.TTSService,
|
||||||
|
# agent: agent_service.AgentService,
|
||||||
):
|
):
|
||||||
self.stt = stt
|
self.stt = stt
|
||||||
self.tts = tts
|
self.tts = tts
|
||||||
|
# self.agent = agent
|
||||||
self.router = fastapi.APIRouter()
|
self.router = fastapi.APIRouter()
|
||||||
self.router.add_api_route(
|
self.router.add_api_route(
|
||||||
"/",
|
"/",
|
||||||
|
@ -35,12 +37,13 @@ class VoiceResponseHandler:
|
||||||
voice_text: str = await self.stt.recognize(voice)
|
voice_text: str = await self.stt.recognize(voice)
|
||||||
if voice_text == "":
|
if voice_text == "":
|
||||||
raise fastapi.HTTPException(status_code=http.HTTPStatus.BAD_REQUEST, detail="Speech recognition failed")
|
raise fastapi.HTTPException(status_code=http.HTTPStatus.BAD_REQUEST, detail="Speech recognition failed")
|
||||||
# TODO: Добавить обработку текста через клиента openai
|
|
||||||
# TODO: Добавить синтез речи через клиента tts
|
# agent_request = models.AgentCreateRequestModel(channel=channel, user_id=user_id, text=voice_text)
|
||||||
# TODO: Заменить заглушку на реальный ответ
|
# reply_text = await self.agent.process_request(agent_request)
|
||||||
|
reply_text = "hi there"
|
||||||
response = await self.tts.get_audio_as_bytes(
|
response = await self.tts.get_audio_as_bytes(
|
||||||
models.TTSCreateRequestModel(
|
models.TTSCreateRequestModel(
|
||||||
text=voice_text,
|
text=reply_text,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return fastapi.responses.StreamingResponse(io.BytesIO(response.audio_content), media_type="audio/ogg")
|
return fastapi.responses.StreamingResponse(io.BytesIO(response.audio_content), media_type="audio/ogg")
|
||||||
|
|
|
@ -14,6 +14,9 @@ import lib.clients as clients
|
||||||
import lib.models as models
|
import lib.models as models
|
||||||
import lib.stt as stt
|
import lib.stt as stt
|
||||||
import lib.tts as tts
|
import lib.tts as tts
|
||||||
|
import lib.agent.repositories as agent_repositories
|
||||||
|
import lib.agent.repositories.openai_functions as agent_functions
|
||||||
|
import lib.agent.services as agent_services
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -89,7 +92,8 @@ class Application:
|
||||||
|
|
||||||
logger.info("Initializing repositories")
|
logger.info("Initializing repositories")
|
||||||
stt_repository: stt.STTProtocol = stt.OpenaiSpeechRepository(settings=settings)
|
stt_repository: stt.STTProtocol = stt.OpenaiSpeechRepository(settings=settings)
|
||||||
chat_history_repository = agent.ChatHistoryRepository(pg_async_session=postgres_client.get_async_session())
|
chat_history_repository = agent_repositories.ChatHistoryRepository(pg_async_session=postgres_client.get_async_session())
|
||||||
|
embedding_repository = agent_repositories.EmbeddingRepository(settings)
|
||||||
|
|
||||||
tts_yandex_repository = tts.TTSYandexRepository(
|
tts_yandex_repository = tts.TTSYandexRepository(
|
||||||
tts_settings=app_split_settings.TTSYandexSettings(),
|
tts_settings=app_split_settings.TTSYandexSettings(),
|
||||||
|
@ -101,13 +105,20 @@ class Application:
|
||||||
is_models_from_api=True,
|
is_models_from_api=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Caches
|
# Caches
|
||||||
|
|
||||||
logger.info("Initializing caches")
|
logger.info("Initializing caches")
|
||||||
|
|
||||||
|
# Tools
|
||||||
|
|
||||||
|
agent_tools = agent_functions.OpenAIFunctions(repository=embedding_repository, pg_async_session=postgres_client.get_async_session())
|
||||||
|
|
||||||
# Services
|
# Services
|
||||||
|
|
||||||
logger.info("Initializing services")
|
logger.info("Initializing services")
|
||||||
|
|
||||||
|
|
||||||
stt_service: stt.SpeechService = stt.SpeechService(repository=stt_repository) # type: ignore
|
stt_service: stt.SpeechService = stt.SpeechService(repository=stt_repository) # type: ignore
|
||||||
|
|
||||||
tts_service: tts.TTSService = tts.TTSService( # type: ignore
|
tts_service: tts.TTSService = tts.TTSService( # type: ignore
|
||||||
|
@ -117,6 +128,9 @@ class Application:
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
agent_service: agent_services.AgentService(settings=settings, chat_repository=chat_history_repository)
|
||||||
|
# agent_service: agent_services.AgentService(settings=settings, chat_repository=chat_history_repository, tools=agent_tools)
|
||||||
|
|
||||||
# Handlers
|
# Handlers
|
||||||
|
|
||||||
logger.info("Initializing handlers")
|
logger.info("Initializing handlers")
|
||||||
|
@ -127,6 +141,7 @@ class Application:
|
||||||
voice_response_handler = api_v1_handlers.VoiceResponseHandler(
|
voice_response_handler = api_v1_handlers.VoiceResponseHandler(
|
||||||
stt=stt_service,
|
stt=stt_service,
|
||||||
tts=tts_service,
|
tts=tts_service,
|
||||||
|
agent=agent_services,
|
||||||
).router
|
).router
|
||||||
|
|
||||||
logger.info("Creating application")
|
logger.info("Creating application")
|
||||||
|
|
Loading…
Reference in New Issue
Block a user