1
0
mirror of https://github.com/ijaric/voice_assistant.git synced 2025-05-24 14:33:26 +00:00

Merge pull request #48 from ijaric/tasks/#47_tts_handlers_repositories

[#47] TTS: Repositories
This commit is contained in:
Aleksandr Sukharev 2023-10-14 22:51:35 +03:00 committed by GitHub
commit ad9d4fdb4d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 302 additions and 38 deletions

View File

@ -111,6 +111,11 @@ jobs:
API_HOST: ${{ vars.API_HOST }} API_HOST: ${{ vars.API_HOST }}
API_PORT: ${{ vars.API_PORT }} API_PORT: ${{ vars.API_PORT }}
APP_RELOAD: ${{ vars.APP_RELOAD }} APP_RELOAD: ${{ vars.APP_RELOAD }}
TTS_YANDEX_API_KEY: ${{ secrets.TTS_YANDEX_API_KEY }}
TTS_ELEVEN_LABS_API_KEY: ${{ secrets.TTS_ELEVEN_LABS_API_KEY }}
TTS_YANDEX_AUDIO_FORMAT: ${{ vars.TTS_YANDEX_AUDIO_FORMAT }}
TTS_YANDEX_SAMPLE_RATE_HERTZ: ${{ vars.TTS_YANDEX_SAMPLE_RATE_HERTZ }}
TTS_ELEVEN_LABS_DEFAULT_VOICE_ID: ${{ vars.TTS_ELEVEN_LABS_DEFAULT_VOICE_ID }}
working-directory: src/${{ matrix.package }} working-directory: src/${{ matrix.package }}
run: | run: |
make ci-test make ci-test

View File

@ -30,3 +30,12 @@ VOICE_MAX_INPUT_SECONDS=30
OPENAI_API_KEY=sk-1234567890 OPENAI_API_KEY=sk-1234567890
OPENAI_STT_MODEL=whisper-1 OPENAI_STT_MODEL=whisper-1
TTS_YANDEX_API_KEY=
TTS_YANDEX_AUDIO_FORMAT=oggopus
TTS_YANDEX_SAMPLE_RATE_HERTZ=48000
TTS_YANDEX_TIMEOUT_SECONDS=30
TTS_ELEVEN_LABS_API_KEY=
TTS_ELEVEN_LABS_DEFAULT_VOICE_ID=EXAVITQu4vr4xnSDxMaL
TTS_ELEVEN_LABS_TIMEOUT_SECONDS=30

View File

@ -11,7 +11,9 @@ import lib.app.errors as app_errors
import lib.app.settings as app_settings import lib.app.settings as app_settings
import lib.app.split_settings as app_split_settings import lib.app.split_settings as app_split_settings
import lib.clients as clients import lib.clients as clients
import lib.models as models
import lib.stt as stt import lib.stt as stt
import lib.tts as tts
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -59,21 +61,45 @@ class Application:
logger.info("Initializing clients") logger.info("Initializing clients")
http_yandex_tts_client = clients.AsyncHttpClient( http_yandex_tts_client = clients.AsyncHttpClient(
base_url="yandex", # todo add yandex api url from settings
proxy_settings=settings.proxy, proxy_settings=settings.proxy,
base_url=settings.tts_yandex.base_url,
headers=settings.tts_yandex.base_headers,
timeout=settings.tts_yandex.timeout_seconds,
) )
http_eleven_labs_tts_client = clients.AsyncHttpClient(
base_url=settings.tts_eleven_labs.base_url,
headers=settings.tts_eleven_labs.base_headers,
timeout=settings.tts_eleven_labs.timeout_seconds,
)
disposable_resources.append( disposable_resources.append(
DisposableResource( DisposableResource(
name="http_client yandex", name="http_client yandex",
dispose_callback=http_yandex_tts_client.close(), dispose_callback=http_yandex_tts_client.close(),
) )
) )
disposable_resources.append(
DisposableResource(
name="http_client eleven labs",
dispose_callback=http_eleven_labs_tts_client.close(),
)
)
# Repositories # Repositories
logger.info("Initializing repositories") logger.info("Initializing repositories")
stt_repository: stt.STTProtocol = stt.OpenaiSpeechRepository(settings=settings) stt_repository: stt.STTProtocol = stt.OpenaiSpeechRepository(settings=settings)
tts_yandex_repository = tts.TTSYandexRepository(
tts_settings=app_split_settings.TTSYandexSettings(),
client=http_yandex_tts_client,
)
tts_eleven_labs_repository = tts.TTSElevenLabsRepository(
tts_settings=app_split_settings.TTSElevenLabsSettings(),
client=http_eleven_labs_tts_client,
is_models_from_api=True,
)
# Caches # Caches
logger.info("Initializing caches") logger.info("Initializing caches")
@ -81,7 +107,15 @@ class Application:
# Services # Services
logger.info("Initializing services") logger.info("Initializing services")
stt_service: stt.SpeechService = stt.SpeechService(repository=stt_repository) stt_service: stt.SpeechService = stt.SpeechService(repository=stt_repository) # type: ignore
tts_service: tts.TTSService = tts.TTSService( # type: ignore
repositories={
models.VoiceModelProvidersEnum.YANDEX: tts_yandex_repository,
models.VoiceModelProvidersEnum.ELEVEN_LABS: tts_eleven_labs_repository,
},
)
# Handlers # Handlers
logger.info("Initializing handlers") logger.info("Initializing handlers")

View File

@ -12,3 +12,5 @@ class Settings(pydantic_settings.BaseSettings):
project: app_split_settings.ProjectSettings = app_split_settings.ProjectSettings() project: app_split_settings.ProjectSettings = app_split_settings.ProjectSettings()
proxy: app_split_settings.ProxySettings = app_split_settings.ProxySettings() proxy: app_split_settings.ProxySettings = app_split_settings.ProxySettings()
voice: app_split_settings.VoiceSettings = app_split_settings.VoiceSettings() voice: app_split_settings.VoiceSettings = app_split_settings.VoiceSettings()
tts_yandex: app_split_settings.TTSYandexSettings = app_split_settings.TTSYandexSettings()
tts_eleven_labs: app_split_settings.TTSElevenLabsSettings = app_split_settings.TTSElevenLabsSettings()

View File

@ -5,6 +5,7 @@ from .openai import *
from .postgres import * from .postgres import *
from .project import * from .project import *
from .proxy import * from .proxy import *
from .tts import *
from .voice import * from .voice import *
__all__ = [ __all__ = [
@ -15,6 +16,8 @@ __all__ = [
"PostgresSettings", "PostgresSettings",
"ProjectSettings", "ProjectSettings",
"ProxySettings", "ProxySettings",
"TTSElevenLabsSettings",
"TTSYandexSettings",
"VoiceSettings", "VoiceSettings",
"get_logging_config", "get_logging_config",
] ]

View File

@ -0,0 +1,7 @@
from .eleven_labs import *
from .yandex import *
__all__ = [
"TTSElevenLabsSettings",
"TTSYandexSettings",
]

View File

@ -0,0 +1,26 @@
import pydantic
import pydantic_settings
import lib.app.split_settings.utils as app_split_settings_utils
# Settings for the ElevenLabs text-to-speech integration.
# Values are loaded by pydantic-settings from the project env file and from
# environment variables carrying the ``TTS_ELEVEN_LABS_`` prefix; unknown
# variables are ignored.
class TTSElevenLabsSettings(pydantic_settings.BaseSettings):
    model_config = pydantic_settings.SettingsConfigDict(
        env_file=app_split_settings_utils.ENV_PATH,
        env_prefix="TTS_ELEVEN_LABS_",
        env_file_encoding="utf-8",
        extra="ignore",
    )

    # Required: there is no default, so construction fails without a key.
    api_key: pydantic.SecretStr = pydantic.Field(default=...)
    # Voice used for synthesis requests.
    default_voice_id: str = "EXAVITQu4vr4xnSDxMaL"
    base_url: str = "https://api.elevenlabs.io/v1/"
    timeout_seconds: int = 30

    @property
    def base_headers(self) -> dict[str, str]:
        """Return the headers attached to every ElevenLabs request.

        The API key is unwrapped from its ``SecretStr`` here, so the returned
        mapping contains the secret in plain text.
        """
        plain_api_key = self.api_key.get_secret_value()
        headers: dict[str, str] = {
            "Accept": "audio/mpeg",
            "Content-Type": "application/json",
            "xi-api-key": plain_api_key,
        }
        return headers

View File

@ -0,0 +1,28 @@
import typing
import pydantic
import pydantic_settings
import lib.app.split_settings.utils as app_split_settings_utils
# Settings for the Yandex text-to-speech integration.
# Values are loaded by pydantic-settings from the project env file and from
# environment variables carrying the ``TTS_YANDEX_`` prefix; unknown
# variables are ignored.
class TTSYandexSettings(pydantic_settings.BaseSettings):
    model_config = pydantic_settings.SettingsConfigDict(
        env_file=app_split_settings_utils.ENV_PATH,
        env_prefix="TTS_YANDEX_",
        env_file_encoding="utf-8",
        extra="ignore",
    )

    # Output container/codec requested from the synthesis endpoint.
    audio_format: typing.Literal["oggopus", "mp3", "lpcm"] = "oggopus"
    sample_rate_hertz: int = 48000
    # Required: there is no default, so construction fails without a key.
    api_key: pydantic.SecretStr = pydantic.Field(default=...)
    base_url: str = "https://tts.api.cloud.yandex.net/speech/v1/"
    timeout_seconds: int = 30

    @property
    def base_headers(self) -> dict[str, str]:
        """Return the headers attached to every Yandex TTS request.

        The API key is unwrapped from its ``SecretStr`` here, so the returned
        mapping contains the secret in plain text.
        """
        plain_api_key = self.api_key.get_secret_value()
        headers: dict[str, str] = {
            "Authorization": f"Api-Key {plain_api_key}",
            "Content-Type": "application/x-www-form-urlencoded",
        }
        return headers

View File

@ -8,7 +8,7 @@ import lib.app.split_settings as app_split_settings
class AsyncHttpClient(httpx.AsyncClient): class AsyncHttpClient(httpx.AsyncClient):
def __init__( def __init__(
self, self,
proxy_settings: app_split_settings.ProxySettings, proxy_settings: app_split_settings.ProxySettings | None = None,
base_url: str | None = None, base_url: str | None = None,
**client_params: typing.Any, **client_params: typing.Any,
) -> None: ) -> None:
@ -20,7 +20,7 @@ class AsyncHttpClient(httpx.AsyncClient):
super().__init__(base_url=self.base_url, proxies=self.proxies, **client_params) # type: ignore[reportGeneralTypeIssues] super().__init__(base_url=self.base_url, proxies=self.proxies, **client_params) # type: ignore[reportGeneralTypeIssues]
def __get_proxies_from_settings(self) -> dict[str, str] | None: def __get_proxies_from_settings(self) -> dict[str, str] | None:
if not self.proxy_settings.enable: if not self.proxy_settings or not self.proxy_settings.enable:
return None return None
proxies = {"all://": self.proxy_settings.dsn} proxies = {"all://": self.proxy_settings.dsn}
return proxies return proxies

View File

@ -8,6 +8,8 @@ __all__ = [
"BaseLanguageCodesEnum", "BaseLanguageCodesEnum",
"BaseVoiceModel", "BaseVoiceModel",
"ElevenLabsLanguageCodesEnum", "ElevenLabsLanguageCodesEnum",
"ElevenLabsListVoiceModelsModel",
"ElevenLabsVoiceModel",
"IdCreatedUpdatedBaseMixin", "IdCreatedUpdatedBaseMixin",
"LANGUAGE_CODES_ENUM_TYPE", "LANGUAGE_CODES_ENUM_TYPE",
"LIST_VOICE_MODELS_TYPE", "LIST_VOICE_MODELS_TYPE",
@ -17,4 +19,6 @@ __all__ = [
"Token", "Token",
"VoiceModelProvidersEnum", "VoiceModelProvidersEnum",
"YandexLanguageCodesEnum", "YandexLanguageCodesEnum",
"YandexListVoiceModelsModel",
"YandexVoiceModel",
] ]

View File

@ -6,6 +6,8 @@ __all__ = [
"BaseLanguageCodesEnum", "BaseLanguageCodesEnum",
"BaseVoiceModel", "BaseVoiceModel",
"ElevenLabsLanguageCodesEnum", "ElevenLabsLanguageCodesEnum",
"ElevenLabsListVoiceModelsModel",
"ElevenLabsVoiceModel",
"LANGUAGE_CODES_ENUM_TYPE", "LANGUAGE_CODES_ENUM_TYPE",
"LIST_VOICE_MODELS_TYPE", "LIST_VOICE_MODELS_TYPE",
"TTSCreateRequestModel", "TTSCreateRequestModel",
@ -13,4 +15,6 @@ __all__ = [
"TTSSearchVoiceRequestModel", "TTSSearchVoiceRequestModel",
"VoiceModelProvidersEnum", "VoiceModelProvidersEnum",
"YandexLanguageCodesEnum", "YandexLanguageCodesEnum",
"YandexListVoiceModelsModel",
"YandexVoiceModel",
] ]

View File

@ -5,12 +5,45 @@ import lib.models.tts.voice.languages as models_tts_languages
AVAILABLE_MODELS_TYPE = models_tts_voice.YandexVoiceModel | models_tts_voice.ElevenLabsVoiceModel AVAILABLE_MODELS_TYPE = models_tts_voice.YandexVoiceModel | models_tts_voice.ElevenLabsVoiceModel
LIST_VOICE_MODELS_TYPE = models_tts_voice.YandexListVoiceModelsModel | models_tts_voice.ElevenLabsListVoiceModelsModel LIST_VOICE_MODELS_TYPE = models_tts_voice.YandexListVoiceModelsModel | models_tts_voice.ElevenLabsListVoiceModelsModel
# Voice model used when a TTS request does not name one explicitly; it is the
# default value of ``TTSCreateRequestModel.voice_model``.
DEFAULT_MODEL = models_tts_voice.ElevenLabsVoiceModel(
    voice_id="eleven_multilingual_v2",
    languages=[
        models_tts_languages.ElevenLabsLanguageCodesEnum.ENGLISH,
        models_tts_languages.ElevenLabsLanguageCodesEnum.JAPANESE,
        models_tts_languages.ElevenLabsLanguageCodesEnum.CHINESE,
        models_tts_languages.ElevenLabsLanguageCodesEnum.GERMAN,
        models_tts_languages.ElevenLabsLanguageCodesEnum.HINDI,
        models_tts_languages.ElevenLabsLanguageCodesEnum.FRENCH,
        models_tts_languages.ElevenLabsLanguageCodesEnum.KOREAN,
        models_tts_languages.ElevenLabsLanguageCodesEnum.PORTUGUESE,
        models_tts_languages.ElevenLabsLanguageCodesEnum.ITALIAN,
        models_tts_languages.ElevenLabsLanguageCodesEnum.SPANISH,
        models_tts_languages.ElevenLabsLanguageCodesEnum.INDONESIAN,
        models_tts_languages.ElevenLabsLanguageCodesEnum.DUTCH,
        models_tts_languages.ElevenLabsLanguageCodesEnum.TURKISH,
        models_tts_languages.ElevenLabsLanguageCodesEnum.FILIPINO,
        models_tts_languages.ElevenLabsLanguageCodesEnum.POLISH,
        models_tts_languages.ElevenLabsLanguageCodesEnum.SWEDISH,
        models_tts_languages.ElevenLabsLanguageCodesEnum.BULGARIAN,
        models_tts_languages.ElevenLabsLanguageCodesEnum.ROMANIAN,
        models_tts_languages.ElevenLabsLanguageCodesEnum.ARABIC,
        models_tts_languages.ElevenLabsLanguageCodesEnum.CZECH,
        models_tts_languages.ElevenLabsLanguageCodesEnum.GREEK,
        models_tts_languages.ElevenLabsLanguageCodesEnum.FINNISH,
        models_tts_languages.ElevenLabsLanguageCodesEnum.CROATIAN,
        models_tts_languages.ElevenLabsLanguageCodesEnum.MALAY,
        models_tts_languages.ElevenLabsLanguageCodesEnum.SLOVAK,
        models_tts_languages.ElevenLabsLanguageCodesEnum.DANISH,
        models_tts_languages.ElevenLabsLanguageCodesEnum.TAMIL,
        models_tts_languages.ElevenLabsLanguageCodesEnum.UKRAINIAN,
    ],
)
class TTSCreateRequestModel(pydantic.BaseModel): class TTSCreateRequestModel(pydantic.BaseModel):
model_config = pydantic.ConfigDict(use_enum_values=True) model_config = pydantic.ConfigDict(use_enum_values=True)
voice_model: AVAILABLE_MODELS_TYPE voice_model: AVAILABLE_MODELS_TYPE = DEFAULT_MODEL
text: str text: str

View File

@ -20,6 +20,8 @@ class BaseVoiceModel(pydantic.BaseModel):
@pydantic.model_validator(mode="before") @pydantic.model_validator(mode="before")
@classmethod @classmethod
def check_voice_name_exists(cls, data: typing.Any) -> typing.Any: def check_voice_name_exists(cls, data: typing.Any) -> typing.Any:
if not data:
return data
voice_id = data.get("voice_id") voice_id = data.get("voice_id")
voice_name = data.get("voice_name") voice_name = data.get("voice_name")
if not voice_name and voice_id: if not voice_name and voice_id:

View File

@ -71,5 +71,13 @@ class ElevenLabsListVoiceModelsModel(pydantic.BaseModel):
@classmethod @classmethod
def from_api(cls, voice_models_from_api: list[dict[str, typing.Any]]) -> typing.Self: def from_api(cls, voice_models_from_api: list[dict[str, typing.Any]]) -> typing.Self:
voice_models = [ElevenLabsVoiceModel.model_validate(voice_model) for voice_model in voice_models_from_api] voice_models = []
for voice_model in voice_models_from_api:
voice_model["voice_id"] = voice_model.pop("model_id")
voice_model["voice_name"] = voice_model.pop("name")
voice_model["languages"] = [
models_tts_languages.ElevenLabsLanguageCodesEnum(item.get("language_id"))
for item in voice_model.pop("languages")
]
voice_models.append(ElevenLabsVoiceModel.model_validate(voice_model))
return ElevenLabsListVoiceModelsModel(models=voice_models) return ElevenLabsListVoiceModelsModel(models=voice_models)

View File

@ -16,6 +16,8 @@ class YandexVoiceModel(models_tts_base.BaseVoiceModel):
@pydantic.model_validator(mode="before") @pydantic.model_validator(mode="before")
@classmethod @classmethod
def check_voice_name_exists(cls, data: typing.Any) -> typing.Any: def check_voice_name_exists(cls, data: typing.Any) -> typing.Any:
if not data:
return data
voice_id = data.get("voice_id") voice_id = data.get("voice_id")
voice_name = data.get("voice_name") voice_name = data.get("voice_name")
role = data.get("role") role = data.get("role")

View File

@ -1,5 +1,9 @@
from .services import TTSService from .repositories import *
from .services import *
__all__ = [ __all__ = [
"TTSBaseRepository",
"TTSElevenLabsRepository",
"TTSService", "TTSService",
"TTSYandexRepository",
] ]

View File

@ -4,11 +4,13 @@ import lib.models as models
class TTSRepositoryProtocol(typing.Protocol): class TTSRepositoryProtocol(typing.Protocol):
def get_audio_as_bytes(self, request: models.TTSCreateRequestModel) -> models.TTSCreateResponseModel: async def get_audio_as_bytes(self, request: models.TTSCreateRequestModel) -> models.TTSCreateResponseModel:
... ...
def get_voice_model_by_name(self, voice_model_name: str) -> models.BaseVoiceModel | None: async def get_voice_model_by_name(self, voice_model_name: str) -> models.BaseVoiceModel | None:
... ...
def get_voice_models_by_fields(self, fields: models.TTSSearchVoiceRequestModel) -> models.LIST_VOICE_MODELS_TYPE: async def get_voice_models_by_fields(
self, fields: models.TTSSearchVoiceRequestModel
) -> models.LIST_VOICE_MODELS_TYPE:
... ...

View File

@ -1,5 +1,9 @@
from .base import * from .base import *
from .eleven_labs import *
from .yandex import *
__all__ = [ __all__ = [
"TTSBaseRepository", "TTSBaseRepository",
"TTSElevenLabsRepository",
"TTSYandexRepository",
] ]

View File

@ -1,37 +1,35 @@
import abc import abc
import lib.clients as clients
import lib.models as models import lib.models as models
class HttpClient: # Mocked class todo remove and use real http client from lib.clients.http_client
...
class TTSBaseRepository(abc.ABC): class TTSBaseRepository(abc.ABC):
def __init__(self, client: HttpClient, is_models_from_api: bool = False): def __init__(self, client: clients.AsyncHttpClient, is_models_from_api: bool = False):
self.http_client = client self.http_client = client
self.is_models_from_api = is_models_from_api self.is_models_from_api = is_models_from_api
@property @property
@abc.abstractmethod @abc.abstractmethod
def voice_models(self) -> models.LIST_VOICE_MODELS_TYPE: async def voice_models(self) -> models.LIST_VOICE_MODELS_TYPE:
...
@abc.abstractmethod
def get_audio_as_bytes(self, request: models.TTSCreateRequestModel) -> models.TTSCreateResponseModel:
raise NotImplementedError raise NotImplementedError
def get_voice_model_by_name(self, voice_model_name: str) -> models.BaseVoiceModel | None: @abc.abstractmethod
async def get_audio_as_bytes(self, request: models.TTSCreateRequestModel) -> models.TTSCreateResponseModel:
raise NotImplementedError
async def get_voice_model_by_name(self, voice_model_name: str) -> models.BaseVoiceModel | None:
""" """
Search voice model by name Search voice model by name
:param voice_model_name: String name :param voice_model_name: String name
:return: Voice model that match the name :return: Voice model that match the name
""" """
for voice_model in self.voice_models.models: voice_models = await self.voice_models
for voice_model in voice_models.models:
if voice_model.voice_name == voice_model_name: if voice_model.voice_name == voice_model_name:
return voice_model return voice_model
def get_list_voice_models_by_fields( async def get_list_voice_models_by_fields(
self, fields: models.TTSSearchVoiceRequestModel self, fields: models.TTSSearchVoiceRequestModel
) -> list[models.AVAILABLE_MODELS_TYPE]: ) -> list[models.AVAILABLE_MODELS_TYPE]:
""" """
@ -41,7 +39,8 @@ class TTSBaseRepository(abc.ABC):
""" """
fields_dump = fields.model_dump(exclude_none=True) fields_dump = fields.model_dump(exclude_none=True)
voice_models_response = [] voice_models_response = []
for voice_model in self.voice_models.models: voice_models = await self.voice_models
for voice_model in voice_models.models:
for field, field_value in fields_dump.items(): for field, field_value in fields_dump.items():
if field == "languages": # language is a list if field == "languages": # language is a list
language_names: set[str] = {item.name for item in field_value} language_names: set[str] = {item.name for item in field_value}

View File

@ -0,0 +1,42 @@
import typing
import lib.app.split_settings as app_split_settings
import lib.clients as clients
import lib.models as models
import lib.tts.repositories.base as tts_repositories_base
class TTSElevenLabsRepository(tts_repositories_base.TTSBaseRepository):
    """Text-to-speech repository that talks to the ElevenLabs HTTP API."""

    def __init__(
        self,
        tts_settings: app_split_settings.TTSElevenLabsSettings,
        client: clients.AsyncHttpClient,
        is_models_from_api: bool = False,
    ):
        """Store the settings and hand the HTTP client to the base repository.

        :param tts_settings: ElevenLabs settings (the default voice id is read from here)
        :param client: Async HTTP client used for every request
        :param is_models_from_api: When true, voice models are fetched from the API
        """
        self.tts_settings = tts_settings
        super().__init__(client, is_models_from_api)

    @property
    async def voice_models(self) -> models.ElevenLabsListVoiceModelsModel:
        """Return the known voice models.

        This is an async property: accessing it yields a coroutine that must be
        awaited. Models come from the API when ``is_models_from_api`` is set,
        otherwise from the list model's own defaults.
        """
        if self.is_models_from_api:
            models_from_api = await self.get_all_models_dict_from_api()
            return models.ElevenLabsListVoiceModelsModel.from_api(models_from_api)
        return models.ElevenLabsListVoiceModelsModel()

    async def get_all_models_dict_from_api(self) -> list[dict[str, typing.Any]]:
        """Fetch the raw model descriptions from the ``/models`` endpoint.

        :return: Decoded JSON body of the response
        :raises httpx.HTTPStatusError: If the API answers with a 4xx/5xx status
        """
        response = await self.http_client.get("/models")
        # Fail loudly on an error status instead of passing an error payload
        # on to the model parser.
        response.raise_for_status()
        return response.json()

    async def get_audio_as_bytes(self, request: models.TTSCreateRequestModel) -> models.TTSCreateResponseModel:
        """Synthesize ``request.text`` and return the audio bytes.

        The voice is always ``tts_settings.default_voice_id``; the request's
        ``voice_model.voice_id`` is sent as the ElevenLabs ``model_id``.

        :param request: Text and voice model to synthesize with
        :return: Response model wrapping the raw audio content
        :raises ValueError: If the request carries a non-ElevenLabs voice model
        :raises httpx.HTTPStatusError: If the API answers with a 4xx/5xx status
        """
        if not isinstance(request.voice_model, models.ElevenLabsVoiceModel):
            raise ValueError("ElevenLabs TTS support only ElevenLabsVoiceModel")
        response = await self.http_client.post(
            f"/text-to-speech/{self.tts_settings.default_voice_id}",
            json={"text": request.text, "model_id": request.voice_model.voice_id},
        )
        # Without this check an error body would be returned as audio content.
        response.raise_for_status()
        return models.TTSCreateResponseModel(audio_content=response.content)

    async def get_voice_models_by_fields(
        self, fields: models.TTSSearchVoiceRequestModel
    ) -> models.ElevenLabsListVoiceModelsModel:
        """Return the voice models matching ``fields``, wrapped in the list model.

        :param fields: Search criteria forwarded to the base-class filter
        :return: List model containing the matching voice models
        """
        list_voice_models = await self.get_list_voice_models_by_fields(fields)
        return models.ElevenLabsListVoiceModelsModel(models=list_voice_models)  # type: ignore

View File

@ -0,0 +1,48 @@
import logging
import lib.app.split_settings as app_split_settings
import lib.clients as clients
import lib.models as models
import lib.tts.repositories.base as tts_repositories_base
logger = logging.getLogger(__name__)
class TTSYandexRepository(tts_repositories_base.TTSBaseRepository):
    """Text-to-speech repository that talks to the Yandex HTTP API."""

    def __init__(
        self,
        tts_settings: app_split_settings.TTSYandexSettings,
        client: clients.AsyncHttpClient,
        is_models_from_api: bool = False,
    ):
        """Store the settings and hand the HTTP client to the base repository.

        :param tts_settings: Yandex settings (audio format and sample rate are read from here)
        :param client: Async HTTP client used for every request
        :param is_models_from_api: Accepted for signature parity; a warning is
            logged when it is true and the base class always receives ``False``
        """
        self.tts_settings = tts_settings
        if is_models_from_api:
            logger.warning("Yandex TTS doesn't support getting models from API")
        super().__init__(client, is_models_from_api=False)

    @property
    async def voice_models(self) -> models.YandexListVoiceModelsModel:
        """Return the known voice models from the list model's own defaults.

        This is an async property: accessing it yields a coroutine that must be
        awaited.
        """
        return models.YandexListVoiceModelsModel()

    async def get_audio_as_bytes(self, request: models.TTSCreateRequestModel) -> models.TTSCreateResponseModel:
        """Synthesize ``request.text`` and return the audio bytes.

        Only the first entry of the voice model's ``languages`` is sent as the
        request language.

        :param request: Text and voice model to synthesize with
        :return: Response model wrapping the raw audio content
        :raises ValueError: If the request carries a non-Yandex voice model
        :raises httpx.HTTPStatusError: If the API answers with a 4xx/5xx status
        """
        if not isinstance(request.voice_model, models.YandexVoiceModel):
            raise ValueError("Yandex TTS support only YandexVoiceModel")
        voice_model = request.voice_model
        data = {
            "text": request.text,
            "lang": voice_model.languages[0].value,
            "voice": voice_model.voice_id,
            "emotion": voice_model.role,
            "format": self.tts_settings.audio_format,
            "sampleRateHertz": self.tts_settings.sample_rate_hertz,
        }
        response = await self.http_client.post(
            "/tts:synthesize",
            data=data,
        )
        # Without this check an error body would be returned as audio content.
        response.raise_for_status()
        return models.TTSCreateResponseModel(audio_content=response.content)

    async def get_voice_models_by_fields(
        self, fields: models.TTSSearchVoiceRequestModel
    ) -> models.YandexListVoiceModelsModel:
        """Return the voice models matching ``fields``, wrapped in the list model.

        :param fields: Search criteria forwarded to the base-class filter
        :return: List model containing the matching voice models
        """
        list_voice_models = await self.get_list_voice_models_by_fields(fields)
        return models.YandexListVoiceModelsModel(models=list_voice_models)  # type: ignore

View File

@ -1,35 +1,33 @@
import lib.app.settings as app_settings import lib.models as _models
import lib.models as models
import lib.tts.models as tts_models import lib.tts.models as tts_models
class TTSService: class TTSService:
def __init__( def __init__(
self, self,
settings: app_settings.Settings, repositories: dict[_models.VoiceModelProvidersEnum, tts_models.TTSRepositoryProtocol],
repositories: dict[models.VoiceModelProvidersEnum, tts_models.TTSRepositoryProtocol],
): ):
self.settings = settings
self.repositories = repositories self.repositories = repositories
def get_audio_as_bytes(self, request: models.TTSCreateRequestModel) -> models.TTSCreateResponseModel: async def get_audio_as_bytes(self, request: _models.TTSCreateRequestModel) -> _models.TTSCreateResponseModel:
model = request.voice_model model = request.voice_model
repository = self.repositories[model.provider] repository = self.repositories[model.provider]
audio_response = repository.get_audio_as_bytes(request) audio_response = await repository.get_audio_as_bytes(request)
return audio_response return audio_response
def get_voice_model_by_name(self, voice_model_name: str) -> models.BaseVoiceModel | None: async def get_voice_model_by_name(self, voice_model_name: str) -> _models.BaseVoiceModel | None:
for repository in self.repositories.values(): for repository in self.repositories.values():
voice_model = repository.get_voice_model_by_name(voice_model_name) voice_model = await repository.get_voice_model_by_name(voice_model_name)
if voice_model: if voice_model:
return voice_model return voice_model
raise ValueError("Voice model not found")
def get_list_voice_models_by_fields( async def get_list_voice_models_by_fields(
self, fields: models.TTSSearchVoiceRequestModel self, fields: _models.TTSSearchVoiceRequestModel
) -> list[models.AVAILABLE_MODELS_TYPE]: ) -> list[_models.AVAILABLE_MODELS_TYPE]:
response_models: list[models.AVAILABLE_MODELS_TYPE] = [] response_models: list[_models.AVAILABLE_MODELS_TYPE] = []
for repository in self.repositories.values(): for repository in self.repositories.values():
voice_models = repository.get_voice_models_by_fields(fields) voice_models = await repository.get_voice_models_by_fields(fields)
if voice_models.models: if voice_models.models:
response_models.extend(voice_models.models) response_models.extend(voice_models.models)
return response_models return response_models