1
0
mirror of https://github.com/ijaric/voice_assistant.git synced 2025-05-24 14:33:26 +00:00

feat: [#45] llm_agent

This commit is contained in:
Григорич 2023-10-15 11:35:53 +03:00
parent 2344b576d8
commit fbc0b4411b
5 changed files with 30 additions and 11 deletions

View File

@ -5,6 +5,7 @@ import langchain.agents
import langchain.agents.format_scratchpad
import langchain.agents.output_parsers
import langchain.chat_models
import langchain.chains
import langchain.memory
import langchain.memory.chat_memory
import langchain.prompts
@ -29,12 +30,28 @@ class AgentService:
self.chat_repository = chat_repository
self.logger = logging.getLogger(__name__)
async def send_message_request(self, request: str, system_prompt: str):
prompt = langchain.prompts.ChatPromptTemplate.from_messages([("system", system_prompt),])
llm = langchain.chat_models.ChatOpenAI(
temperature=self.settings.openai.agent_temperature,
openai_api_key=self.settings.openai.api_key.get_secret_value(),
)
chain = langchain.chains.LLMChain(llm=llm, prompt=prompt)
result = await chain.ainvoke({"input": request})
return result["text"]
async def process_request(self, request: models.AgentCreateRequestModel) -> models.AgentCreateResponseModel:
# Get session ID
request_text = request.text
translate_text = await self.send_message_request(request=request_text, system_prompt="Translation into English")
session_request = models.RequestLastSessionId(channel=request.channel, user_id=request.user_id, minutes_ago=3)
session_id = await self.chat_repository.get_last_session_id(session_request)
if not session_id:
session_id = uuid.uuid4()
await self.send_message_request(request='test', system_prompt="test")
# Declare tools (OpenAI functions)
tools = [
@ -94,7 +111,7 @@ class AgentService:
agent_executor = langchain.agents.AgentExecutor(agent=agent, tools=tools, verbose=False)
chat_history = [] # temporary disable chat_history
response = await agent_executor.ainvoke({"input": request.text, "chat_history": chat_history})
response = await agent_executor.ainvoke({"input": translate_text, "chat_history": chat_history})
user_request = models.RequestChatMessage(
session_id=session_id,
@ -112,6 +129,9 @@ class AgentService:
await self.chat_repository.add_message(user_request)
await self.chat_repository.add_message(ai_response)
print("RES:", response)
return models.AgentCreateResponseModel(text="response")
response_translate = await self.send_message_request(
request=f"Original text: {request_text}. Answer: {response['output']}",
system_prompt="Translate the answer into the language of the original text",
)
print(response_translate)
return models.AgentCreateResponseModel(text=response_translate)

View File

@ -30,8 +30,8 @@ class VoiceResponseHandler:
async def voice_response(
self,
channel: str,
user_id: str,
channel: str="tg",
user_id: str="1234",
voice: bytes = fastapi.File(...),
) -> fastapi.responses.StreamingResponse:
voice_text: str = await self.stt.recognize(voice)

View File

@ -99,7 +99,6 @@ class Application:
agent_tools = agent_functions.OpenAIFunctions(
repository=embedding_repository, pg_async_session=postgres_client.get_async_session()
)
agent_tools = None
tts_yandex_repository = tts.TTSYandexRepository(
tts_settings=app_split_settings.TTSYandexSettings(),
client=http_yandex_tts_client,

View File

@ -15,7 +15,7 @@ class TTSElevenLabsSettings(pydantic_settings.BaseSettings):
api_key: pydantic.SecretStr = pydantic.Field(default=...)
default_voice_id: str = "EXAVITQu4vr4xnSDxMaL"
base_url: str = "https://api.elevenlabs.io/v1/"
timeout_seconds: int = 30
timeout_seconds: int = 120
@property
def base_headers(self) -> dict[str, str]:

View File

@ -18,7 +18,7 @@ class TTSYandexSettings(pydantic_settings.BaseSettings):
sample_rate_hertz: int = 48000
api_key: pydantic.SecretStr = pydantic.Field(default=...)
base_url: str = "https://tts.api.cloud.yandex.net/speech/v1/"
timeout_seconds: int = 30
timeout_seconds: int = 120
@property
def base_headers(self) -> dict[str, str]: