1
0
mirror of https://github.com/ijaric/voice_assistant.git synced 2025-05-24 14:33:26 +00:00
This commit is contained in:
Aleksandr Sukharev 2023-10-15 04:29:48 +00:00 committed by GitHub
commit 3f9d5a526b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 90 additions and 5 deletions

View File

@ -1,7 +1,9 @@
from .health import basic_router from .health import basic_router
from .tts import *
from .voice_responce_handler import VoiceResponseHandler from .voice_responce_handler import VoiceResponseHandler
__all__ = [ __all__ = [
"TTSHandler",
"VoiceResponseHandler", "VoiceResponseHandler",
"basic_router", "basic_router",
] ]

View File

@ -0,0 +1,66 @@
import http
import fastapi
import lib.models as models
import lib.tts.services as tts_service
class TTSHandler:
def __init__(
self,
tts: tts_service.TTSService,
):
self.tts = tts
self.router = fastapi.APIRouter()
self.router.add_api_route(
"/fields",
self.get_by_fields,
methods=["POST"],
summary="Получение моделей по полю",
description="Возвращает список моделей с указанными полями",
)
self.router.add_api_route(
"/name",
self.get_by_name,
methods=["POST"],
summary="Получение модели по имени",
description="Позволяет получить одну модель по её имени",
)
self.router.add_api_route(
"/",
self.get_all,
methods=["GET"],
summary="Получение всех доступных моделей",
description="Возвращает список всех доступных моделей",
)
self.router.add_api_route(
"/languages",
self.get_languages,
methods=["GET"],
summary="Получение всех доступных языков",
description="Возвращает список всех доступных языков",
)
async def get_by_fields(
self,
data: models.TTSSearchVoiceRequestModel,
) -> list[models.AVAILABLE_MODELS_TYPE]:
response = await self.tts.get_list_voice_models_by_fields(data)
return response
async def get_by_name(
self,
model_name: str,
) -> models.BaseVoiceModel:
response = await self.tts.get_voice_model_by_name(model_name)
if not response:
raise fastapi.HTTPException(status_code=http.HTTPStatus.BAD_REQUEST, detail="Model not found")
return response
async def get_all(self) -> list[models.AVAILABLE_MODELS_TYPE]:
return await self.tts.get_all_models()
@classmethod
async def get_languages(cls) -> dict[str, str]:
return {language.name: language.value for language in models.BaseLanguageCodesEnum}

View File

@ -107,15 +107,15 @@ class Application:
# Services # Services
logger.info("Initializing services") logger.info("Initializing services")
stt_service: stt.SpeechService = stt.SpeechService(repository=stt_repository) # type: ignore stt_service: stt.SpeechService = stt.SpeechService(repository=stt_repository)
tts_service: tts.TTSService = tts.TTSService( # type: ignore tts_service: tts.TTSService = tts.TTSService(
repositories={ repositories={
models.VoiceModelProvidersEnum.YANDEX: tts_yandex_repository, models.VoiceModelProvidersEnum.YANDEX: tts_yandex_repository,
models.VoiceModelProvidersEnum.ELEVEN_LABS: tts_eleven_labs_repository, models.VoiceModelProvidersEnum.ELEVEN_LABS: tts_eleven_labs_repository,
}, },
) )
# Handlers # Handlers
logger.info("Initializing handlers") logger.info("Initializing handlers")
@ -127,6 +127,8 @@ class Application:
# tts=tts_service, # TODO # tts=tts_service, # TODO
).router ).router
tts_handler = api_v1_handlers.TTSHandler(tts=tts_service).router
logger.info("Creating application") logger.info("Creating application")
fastapi_app = fastapi.FastAPI( fastapi_app = fastapi.FastAPI(
@ -140,6 +142,7 @@ class Application:
# Routes # Routes
fastapi_app.include_router(liveness_probe_handler, prefix="/api/v1/health", tags=["health"]) fastapi_app.include_router(liveness_probe_handler, prefix="/api/v1/health", tags=["health"])
fastapi_app.include_router(voice_response_handler, prefix="/api/v1/voice", tags=["voice"]) fastapi_app.include_router(voice_response_handler, prefix="/api/v1/voice", tags=["voice"])
fastapi_app.include_router(tts_handler, prefix="/api/v1/tts", tags=["tts"])
application = Application( application = Application(
settings=settings, settings=settings,

View File

@ -7,7 +7,6 @@ import lib.models.tts.voice.languages as models_tts_languages
class ElevenLabsVoiceModel(models_tts_base.BaseVoiceModel): class ElevenLabsVoiceModel(models_tts_base.BaseVoiceModel):
model_config = pydantic.ConfigDict(use_enum_values=True)
voice_id: str voice_id: str
voice_name: str | None = None voice_name: str | None = None
languages: list[models_tts_languages.LANGUAGE_CODES_ENUM_TYPE] languages: list[models_tts_languages.LANGUAGE_CODES_ENUM_TYPE]

View File

@ -4,6 +4,10 @@ import lib.models as models
class TTSRepositoryProtocol(typing.Protocol): class TTSRepositoryProtocol(typing.Protocol):
@property
async def voice_models(self) -> models.LIST_VOICE_MODELS_TYPE:
raise NotImplementedError
async def get_audio_as_bytes(self, request: models.TTSCreateRequestModel) -> models.TTSCreateResponseModel: async def get_audio_as_bytes(self, request: models.TTSCreateRequestModel) -> models.TTSCreateResponseModel:
... ...

View File

@ -49,7 +49,7 @@ class TTSBaseRepository(abc.ABC):
continue continue
break break
voice_model_dump = voice_model.model_dump() voice_model_dump = voice_model.model_dump()
if voice_model_dump[field] != field_value.name: if voice_model_dump[field] != field_value:
break break
else: else:
voice_models_response.append(voice_model) voice_models_response.append(voice_model)

View File

@ -31,3 +31,14 @@ class TTSService:
if voice_models.models: if voice_models.models:
response_models.extend(voice_models.models) response_models.extend(voice_models.models)
return response_models return response_models
async def get_all_models(self) -> list[_models.AVAILABLE_MODELS_TYPE]:
response_models: list[_models.AVAILABLE_MODELS_TYPE] = []
for repository in self.repositories.values():
response = await repository.voice_models
for model in response.models:
model.languages = [ # type: ignore
_models.BaseLanguageCodesEnum[language.name] for language in model.languages
]
response_models.append(model)
return response_models