mirror of
https://github.com/ijaric/voice_assistant.git
synced 2025-05-24 14:33:26 +00:00
fix: start error & linting
This commit is contained in:
parent
288af57ab5
commit
2344b576d8
|
@ -1,5 +1,5 @@
|
|||
from .services import AgentService
|
||||
# from .services import AgentService
|
||||
|
||||
__all__ = [
|
||||
"AgentService",
|
||||
]
|
||||
# __all__ = [
|
||||
# "AgentService",
|
||||
# ]
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import logging
|
||||
import typing
|
||||
|
||||
import langchain.chat_models
|
||||
import openai
|
||||
import openai.error
|
||||
|
||||
|
@ -40,25 +39,3 @@ class EmbeddingRepository:
|
|||
|
||||
except openai.error.OpenAIError:
|
||||
self.logger.exception("Failed to get async embedding for: %s", text)
|
||||
|
||||
|
||||
class LlmRepository:
|
||||
"""A service for getting embeddings from OpenAI."""
|
||||
|
||||
def __init__(self, settings: app_settings.Settings) -> None:
|
||||
"""Initialize the service with an OpenAI API key."""
|
||||
self.llm = langchain.chat_models.ChatOpenAI(
|
||||
temperature=0.7,
|
||||
openai_api_key=self.settings.openai.api_key.get_secret_value()
|
||||
)
|
||||
|
||||
async def get_chat_response(self, request: str, prompt: str) -> str:
|
||||
"""Get the embedding for a given text."""
|
||||
prompt = langchain.prompts.ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", prompt),
|
||||
]
|
||||
)
|
||||
chain = prompt | self.llm
|
||||
response = await chain.ainvoke({"input": request})
|
||||
return response.content
|
||||
|
|
|
@ -5,38 +5,33 @@ import langchain.agents
|
|||
import langchain.agents.format_scratchpad
|
||||
import langchain.agents.output_parsers
|
||||
import langchain.chat_models
|
||||
import langchain.memory
|
||||
import langchain.memory.chat_memory
|
||||
import langchain.prompts
|
||||
import langchain.schema
|
||||
import langchain.tools.render
|
||||
import langchain.memory
|
||||
import langchain.memory.chat_memory
|
||||
|
||||
import lib.models as models
|
||||
import lib.agent.repositories as lib_agent_repositories
|
||||
import lib.agent.repositories.chat_repository as chat_repositories
|
||||
import lib.app.settings as app_settings
|
||||
import lib.agent.repositories.chat_repository as _chat_repository
|
||||
import lib.models as models
|
||||
|
||||
|
||||
class AgentService:
|
||||
def __init__(
|
||||
self,
|
||||
settings: app_settings.Settings,
|
||||
chat_repository: _chat_repository.ChatHistoryRepository,
|
||||
tools: lib_agent_repositories.OpenAIFunctions | None = None,
|
||||
chat_repository: chat_repositories.ChatHistoryRepository,
|
||||
tools: lib_agent_repositories.OpenAIFunctions,
|
||||
) -> None:
|
||||
self.settings = settings
|
||||
self.tools = tools
|
||||
self.chat_repository = chat_repository
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def process_request(self, request: models.AgentCreateRequestModel) -> models.AgentCreateResponseModel:
|
||||
# Get session ID
|
||||
session_request = models.RequestLastSessionId(
|
||||
channel=request.channel,
|
||||
user_id=request.user_id,
|
||||
minutes_ago=3
|
||||
)
|
||||
session_request = models.RequestLastSessionId(channel=request.channel, user_id=request.user_id, minutes_ago=3)
|
||||
session_id = await self.chat_repository.get_last_session_id(session_request)
|
||||
if not session_id:
|
||||
session_id = uuid.uuid4()
|
||||
|
@ -47,23 +42,26 @@ class AgentService:
|
|||
name="GetMovieByDescription",
|
||||
func=self.tools.get_movie_by_description,
|
||||
coroutine=self.tools.get_movie_by_description,
|
||||
description="Get a movie by description"
|
||||
description="Get a movie by description",
|
||||
),
|
||||
]
|
||||
|
||||
llm = langchain.chat_models.ChatOpenAI(temperature=self.settings.openai.agent_temperature, openai_api_key=self.settings.openai.api_key.get_secret_value())
|
||||
llm = langchain.chat_models.ChatOpenAI(
|
||||
temperature=self.settings.openai.agent_temperature,
|
||||
openai_api_key=self.settings.openai.api_key.get_secret_value(),
|
||||
)
|
||||
|
||||
chat_history = []
|
||||
chat_history_name = f"{chat_history=}".partition("=")[0]
|
||||
|
||||
request_chat_history = models.RequestChatHistory(session_id=session_id)
|
||||
chat_history_source = await self.chat_repository.get_messages_by_sid(request_chat_history)
|
||||
for entry in chat_history_source:
|
||||
if entry.role == "user":
|
||||
chat_history.append(langchain.schema.HumanMessage(content=entry.content))
|
||||
elif entry.role == "agent":
|
||||
chat_history.append(langchain.schema.AIMessage(content=entry.content))
|
||||
|
||||
if not chat_history_source:
|
||||
for entry in chat_history_source:
|
||||
if entry.role == "user":
|
||||
chat_history.append(langchain.schema.HumanMessage(content=entry.content))
|
||||
elif entry.role == "agent":
|
||||
chat_history.append(langchain.schema.AIMessage(content=entry.content))
|
||||
|
||||
prompt = langchain.prompts.ChatPromptTemplate.from_messages(
|
||||
[
|
||||
|
@ -94,7 +92,6 @@ class AgentService:
|
|||
| langchain.agents.output_parsers.OpenAIFunctionsAgentOutputParser()
|
||||
)
|
||||
|
||||
|
||||
agent_executor = langchain.agents.AgentExecutor(agent=agent, tools=tools, verbose=False)
|
||||
chat_history = [] # temporary disable chat_history
|
||||
response = await agent_executor.ainvoke({"input": request.text, "chat_history": chat_history})
|
||||
|
@ -103,16 +100,18 @@ class AgentService:
|
|||
session_id=session_id,
|
||||
user_id=request.user_id,
|
||||
channel=request.channel,
|
||||
message={"role": "user", "content": request.text}
|
||||
message={"role": "user", "content": request.text},
|
||||
)
|
||||
ai_response = models.RequestChatMessage(
|
||||
session_id=session_id,
|
||||
user_id=request.user_id,
|
||||
channel=request.channel,
|
||||
message={"role": "assistant", "content": response["output"]}
|
||||
message={"role": "assistant", "content": response["output"]},
|
||||
)
|
||||
|
||||
await self.chat_repository.add_message(user_request)
|
||||
await self.chat_repository.add_message(ai_response)
|
||||
|
||||
return response
|
||||
print("RES:", response)
|
||||
|
||||
return models.AgentCreateResponseModel(text="response")
|
||||
|
|
|
@ -1,10 +1,7 @@
|
|||
from .agent import AgentHandler
|
||||
from .health import basic_router
|
||||
from .voice_responce_handler import VoiceResponseHandler
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AgentHandler",
|
||||
"VoiceResponseHandler",
|
||||
"basic_router",
|
||||
]
|
||||
|
|
|
@ -1,59 +0,0 @@
|
|||
import uuid
|
||||
|
||||
import fastapi
|
||||
|
||||
import lib.agent.repositories as agent_repositories
|
||||
import lib.models as models
|
||||
|
||||
|
||||
class AgentHandler:
|
||||
def __init__(self, chat_history_repository: agent_repositories.ChatHistoryRepository):
|
||||
self.chat_history_repository = chat_history_repository
|
||||
self.router = fastapi.APIRouter()
|
||||
self.router.add_api_route(
|
||||
"/",
|
||||
self.get_agent,
|
||||
methods=["GET"],
|
||||
summary="Статус работоспособности",
|
||||
description="Проверяет доступность сервиса FastAPI.",
|
||||
)
|
||||
self.router.add_api_route(
|
||||
"/add",
|
||||
self.add_message,
|
||||
methods=["GET"],
|
||||
summary="Статус работоспособности",
|
||||
description="Проверяет доступность сервиса FastAPI.",
|
||||
)
|
||||
self.router.add_api_route(
|
||||
"/messages",
|
||||
self.get_messages,
|
||||
methods=["GET"],
|
||||
summary="Статус работоспособности",
|
||||
description="Проверяет доступность сервиса FastAPI.",
|
||||
)
|
||||
|
||||
async def get_agent(self):
|
||||
request = models.RequestLastSessionId(channel="test", user_id="user_id_1", minutes_ago=3)
|
||||
response = await self.chat_history_repository.get_last_session_id(request=request)
|
||||
print("RESPONSE: ", response)
|
||||
return {"response": response}
|
||||
|
||||
async def add_message(self):
|
||||
sid: uuid.UUID = uuid.UUID("0cd3c882-affd-4929-aff1-e1724f5b54f2")
|
||||
import faker
|
||||
|
||||
fake = faker.Faker()
|
||||
|
||||
message = models.RequestChatMessage(
|
||||
session_id=sid, user_id="user_id_1", channel="test", message={"role": "system", "content": fake.sentence()}
|
||||
)
|
||||
await self.chat_history_repository.add_message(request=message)
|
||||
return {"response": "ok"}
|
||||
|
||||
async def get_messages(self):
|
||||
sid: uuid.UUID = uuid.UUID("0cd3c882-affd-4929-aff1-e1724f5b54f2")
|
||||
|
||||
request = models.RequestChatHistory(session_id=sid)
|
||||
response = await self.chat_history_repository.get_messages_by_sid(request=request)
|
||||
print("RESPONSE: ", response)
|
||||
return {"response": response}
|
|
@ -3,10 +3,10 @@ import io
|
|||
|
||||
import fastapi
|
||||
|
||||
import lib.stt.services as stt_services
|
||||
import lib.agent.services as agent_service
|
||||
import lib.tts.services as tts_service
|
||||
import lib.models as models
|
||||
import lib.stt.services as stt_services
|
||||
import lib.tts.services as tts_service
|
||||
|
||||
|
||||
class VoiceResponseHandler:
|
||||
|
@ -14,11 +14,11 @@ class VoiceResponseHandler:
|
|||
self,
|
||||
stt: stt_services.SpeechService,
|
||||
tts: tts_service.TTSService,
|
||||
# agent: agent_service.AgentService,
|
||||
agent: agent_service.AgentService,
|
||||
):
|
||||
self.stt = stt
|
||||
self.tts = tts
|
||||
# self.agent = agent
|
||||
self.agent = agent
|
||||
self.router = fastapi.APIRouter()
|
||||
self.router.add_api_route(
|
||||
"/",
|
||||
|
@ -38,12 +38,12 @@ class VoiceResponseHandler:
|
|||
if voice_text == "":
|
||||
raise fastapi.HTTPException(status_code=http.HTTPStatus.BAD_REQUEST, detail="Speech recognition failed")
|
||||
|
||||
# agent_request = models.AgentCreateRequestModel(channel=channel, user_id=user_id, text=voice_text)
|
||||
# reply_text = await self.agent.process_request(agent_request)
|
||||
reply_text = "hi there"
|
||||
agent_request = models.AgentCreateRequestModel(channel=channel, user_id=user_id, text=voice_text)
|
||||
reply_text = await self.agent.process_request(agent_request)
|
||||
|
||||
response = await self.tts.get_audio_as_bytes(
|
||||
models.TTSCreateRequestModel(
|
||||
text=reply_text,
|
||||
text=reply_text.text,
|
||||
)
|
||||
)
|
||||
return fastapi.responses.StreamingResponse(io.BytesIO(response.audio_content), media_type="audio/ogg")
|
||||
|
|
|
@ -6,6 +6,9 @@ import typing
|
|||
import fastapi
|
||||
import uvicorn
|
||||
|
||||
import lib.agent.repositories as agent_repositories
|
||||
import lib.agent.repositories.openai_functions as agent_functions
|
||||
import lib.agent.services as agent_services
|
||||
import lib.api.v1.handlers as api_v1_handlers
|
||||
import lib.app.errors as app_errors
|
||||
import lib.app.settings as app_settings
|
||||
|
@ -14,9 +17,6 @@ import lib.clients as clients
|
|||
import lib.models as models
|
||||
import lib.stt as stt
|
||||
import lib.tts as tts
|
||||
import lib.agent.repositories as agent_repositories
|
||||
import lib.agent.repositories.openai_functions as agent_functions
|
||||
import lib.agent.services as agent_services
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -92,9 +92,14 @@ class Application:
|
|||
|
||||
logger.info("Initializing repositories")
|
||||
stt_repository: stt.STTProtocol = stt.OpenaiSpeechRepository(settings=settings)
|
||||
chat_history_repository = agent_repositories.ChatHistoryRepository(pg_async_session=postgres_client.get_async_session())
|
||||
embedding_repository = agent_repositories.EmbeddingRepository(settings)
|
||||
|
||||
chat_history_repository = agent_repositories.ChatHistoryRepository(
|
||||
pg_async_session=postgres_client.get_async_session()
|
||||
)
|
||||
embedding_repository = agent_repositories.EmbeddingRepository(settings=settings)
|
||||
agent_tools = agent_functions.OpenAIFunctions(
|
||||
repository=embedding_repository, pg_async_session=postgres_client.get_async_session()
|
||||
)
|
||||
agent_tools = None
|
||||
tts_yandex_repository = tts.TTSYandexRepository(
|
||||
tts_settings=app_split_settings.TTSYandexSettings(),
|
||||
client=http_yandex_tts_client,
|
||||
|
@ -105,43 +110,40 @@ class Application:
|
|||
is_models_from_api=True,
|
||||
)
|
||||
|
||||
|
||||
# Caches
|
||||
|
||||
logger.info("Initializing caches")
|
||||
|
||||
# Tools
|
||||
|
||||
agent_tools = agent_functions.OpenAIFunctions(repository=embedding_repository, pg_async_session=postgres_client.get_async_session())
|
||||
|
||||
# Services
|
||||
|
||||
logger.info("Initializing services")
|
||||
|
||||
stt_service: stt.SpeechService = stt.SpeechService(repository=stt_repository)
|
||||
|
||||
stt_service: stt.SpeechService = stt.SpeechService(repository=stt_repository) # type: ignore
|
||||
|
||||
tts_service: tts.TTSService = tts.TTSService( # type: ignore
|
||||
tts_service: tts.TTSService = tts.TTSService(
|
||||
repositories={
|
||||
models.VoiceModelProvidersEnum.YANDEX: tts_yandex_repository,
|
||||
models.VoiceModelProvidersEnum.ELEVEN_LABS: tts_eleven_labs_repository,
|
||||
},
|
||||
)
|
||||
|
||||
agent_service: agent_services.AgentService(settings=settings, chat_repository=chat_history_repository)
|
||||
# agent_service: agent_services.AgentService(settings=settings, chat_repository=chat_history_repository, tools=agent_tools)
|
||||
agent_service = agent_services.AgentService(
|
||||
settings=settings, chat_repository=chat_history_repository, tools=agent_tools
|
||||
)
|
||||
|
||||
|
||||
# Handlers
|
||||
|
||||
logger.info("Initializing handlers")
|
||||
liveness_probe_handler = api_v1_handlers.basic_router
|
||||
agent_handler = api_v1_handlers.AgentHandler(chat_history_repository=chat_history_repository).router
|
||||
|
||||
# TODO: объявить сервисы tts и openai и добавить их в voice_response_handler
|
||||
voice_response_handler = api_v1_handlers.VoiceResponseHandler(
|
||||
stt=stt_service,
|
||||
tts=tts_service,
|
||||
agent=agent_services,
|
||||
agent=agent_service,
|
||||
).router
|
||||
|
||||
logger.info("Creating application")
|
||||
|
@ -156,7 +158,6 @@ class Application:
|
|||
|
||||
# Routes
|
||||
fastapi_app.include_router(liveness_probe_handler, prefix="/api/v1/health", tags=["health"])
|
||||
fastapi_app.include_router(agent_handler, prefix="/api/v1/agent", tags=["testing"])
|
||||
fastapi_app.include_router(voice_response_handler, prefix="/api/v1/voice", tags=["voice"])
|
||||
|
||||
application = Application(
|
||||
|
|
|
@ -1,21 +1,19 @@
|
|||
from .agent import *
|
||||
from .chat_history import Message, RequestChatHistory, RequestChatMessage, RequestLastSessionId
|
||||
from .embedding import Embedding
|
||||
from .movies import Movie
|
||||
from .token import Token
|
||||
from .tts import *
|
||||
from .agent import *
|
||||
|
||||
|
||||
# __all__ = ["Embedding", "Message", "Movie", "RequestChatHistory", "RequestChatMessage", "RequestLastSessionId", "Token"]
|
||||
__all__ = [
|
||||
"AVAILABLE_MODELS_TYPE",
|
||||
# "Base",
|
||||
"AgentCreateRequestModel",
|
||||
"BaseLanguageCodesEnum",
|
||||
"BaseVoiceModel",
|
||||
"ElevenLabsLanguageCodesEnum",
|
||||
"ElevenLabsListVoiceModelsModel",
|
||||
"ElevenLabsVoiceModel",
|
||||
# "IdCreatedUpdatedBaseMixin",
|
||||
"LANGUAGE_CODES_ENUM_TYPE",
|
||||
"LIST_VOICE_MODELS_TYPE",
|
||||
"TTSCreateRequestModel",
|
||||
|
@ -26,5 +24,4 @@ __all__ = [
|
|||
"YandexLanguageCodesEnum",
|
||||
"YandexListVoiceModelsModel",
|
||||
"YandexVoiceModel",
|
||||
"AgentCreateRequestModel",
|
||||
]
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
import uuid
|
||||
|
||||
import pydantic
|
||||
|
||||
|
||||
|
|
|
@ -25,21 +25,21 @@ asyncpg = "^0.28.0"
|
|||
fastapi = "0.103.1"
|
||||
greenlet = "^2.0.2"
|
||||
httpx = "^0.25.0"
|
||||
langchain = "^0.0.314"
|
||||
multidict = "^6.0.4"
|
||||
openai = "^0.28.1"
|
||||
orjson = "^3.9.7"
|
||||
pgvector = "^0.2.3"
|
||||
psycopg2-binary = "^2.9.9"
|
||||
pydantic = {extras = ["email"], version = "^2.3.0"}
|
||||
pydantic-settings = "^2.0.3"
|
||||
pytest-asyncio = "^0.21.1"
|
||||
python = "^3.11"
|
||||
python-jose = "^3.3.0"
|
||||
python-magic = "^0.4.27"
|
||||
python-multipart = "^0.0.6"
|
||||
sqlalchemy = "^2.0.20"
|
||||
uvicorn = "^0.23.2"
|
||||
pgvector = "^0.2.3"
|
||||
python-magic = "^0.4.27"
|
||||
openai = "^0.28.1"
|
||||
python-multipart = "^0.0.6"
|
||||
langchain = "^0.0.314"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
black = "^23.7.0"
|
||||
|
|
Loading…
Reference in New Issue
Block a user