mirror of
https://github.com/ijaric/voice_assistant.git
synced 2025-05-24 22:43:26 +00:00
Merge pull request #48 from ijaric/tasks/#47_tts_handlers_repositories
[#47] TTS: Repositories
This commit is contained in:
commit
ad9d4fdb4d
5
.github/workflows/check-pr.yaml
vendored
5
.github/workflows/check-pr.yaml
vendored
|
@ -111,6 +111,11 @@ jobs:
|
||||||
API_HOST: ${{ vars.API_HOST }}
|
API_HOST: ${{ vars.API_HOST }}
|
||||||
API_PORT: ${{ vars.API_PORT }}
|
API_PORT: ${{ vars.API_PORT }}
|
||||||
APP_RELOAD: ${{ vars.APP_RELOAD }}
|
APP_RELOAD: ${{ vars.APP_RELOAD }}
|
||||||
|
TTS_YANDEX_API_KEY: ${{ secrets.TTS_YANDEX_API_KEY }}
|
||||||
|
TTS_ELEVEN_LABS_API_KEY: ${{ secrets.TTS_ELEVEN_LABS_API_KEY }}
|
||||||
|
TTS_YANDEX_AUDIO_FORMAT: ${{ vars.TTS_YANDEX_AUDIO_FORMAT }}
|
||||||
|
TTS_YANDEX_SAMPLE_RATE_HERTZ: ${{ vars.TTS_YANDEX_SAMPLE_RATE_HERTZ }}
|
||||||
|
TTS_ELEVEN_LABS_DEFAULT_VOICE_ID: ${{ vars.TTS_ELEVEN_LABS_DEFAULT_VOICE_ID }}
|
||||||
working-directory: src/${{ matrix.package }}
|
working-directory: src/${{ matrix.package }}
|
||||||
run: |
|
run: |
|
||||||
make ci-test
|
make ci-test
|
||||||
|
|
|
@ -30,3 +30,12 @@ VOICE_MAX_INPUT_SECONDS=30
|
||||||
|
|
||||||
OPENAI_API_KEY=sk-1234567890
|
OPENAI_API_KEY=sk-1234567890
|
||||||
OPENAI_STT_MODEL=whisper-1
|
OPENAI_STT_MODEL=whisper-1
|
||||||
|
|
||||||
|
TTS_YANDEX_API_KEY=
|
||||||
|
TTS_YANDEX_AUDIO_FORMAT=oggopus
|
||||||
|
TTS_YANDEX_SAMPLE_RATE_HERTZ=48000
|
||||||
|
TTS_YANDEX_TIMEOUT_SECONDS=30
|
||||||
|
|
||||||
|
TTS_ELEVEN_LABS_API_KEY=
|
||||||
|
TTS_ELEVEN_LABS_DEFAULT_VOICE_ID=EXAVITQu4vr4xnSDxMaL
|
||||||
|
TTS_ELEVEN_LABS_TIMEOUT_SECONDS=30
|
||||||
|
|
|
@ -11,7 +11,9 @@ import lib.app.errors as app_errors
|
||||||
import lib.app.settings as app_settings
|
import lib.app.settings as app_settings
|
||||||
import lib.app.split_settings as app_split_settings
|
import lib.app.split_settings as app_split_settings
|
||||||
import lib.clients as clients
|
import lib.clients as clients
|
||||||
|
import lib.models as models
|
||||||
import lib.stt as stt
|
import lib.stt as stt
|
||||||
|
import lib.tts as tts
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -59,21 +61,45 @@ class Application:
|
||||||
logger.info("Initializing clients")
|
logger.info("Initializing clients")
|
||||||
|
|
||||||
http_yandex_tts_client = clients.AsyncHttpClient(
|
http_yandex_tts_client = clients.AsyncHttpClient(
|
||||||
base_url="yandex", # todo add yandex api url from settings
|
|
||||||
proxy_settings=settings.proxy,
|
proxy_settings=settings.proxy,
|
||||||
|
base_url=settings.tts_yandex.base_url,
|
||||||
|
headers=settings.tts_yandex.base_headers,
|
||||||
|
timeout=settings.tts_yandex.timeout_seconds,
|
||||||
)
|
)
|
||||||
|
http_eleven_labs_tts_client = clients.AsyncHttpClient(
|
||||||
|
base_url=settings.tts_eleven_labs.base_url,
|
||||||
|
headers=settings.tts_eleven_labs.base_headers,
|
||||||
|
timeout=settings.tts_eleven_labs.timeout_seconds,
|
||||||
|
)
|
||||||
|
|
||||||
disposable_resources.append(
|
disposable_resources.append(
|
||||||
DisposableResource(
|
DisposableResource(
|
||||||
name="http_client yandex",
|
name="http_client yandex",
|
||||||
dispose_callback=http_yandex_tts_client.close(),
|
dispose_callback=http_yandex_tts_client.close(),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
disposable_resources.append(
|
||||||
|
DisposableResource(
|
||||||
|
name="http_client eleven labs",
|
||||||
|
dispose_callback=http_eleven_labs_tts_client.close(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Repositories
|
# Repositories
|
||||||
|
|
||||||
logger.info("Initializing repositories")
|
logger.info("Initializing repositories")
|
||||||
stt_repository: stt.STTProtocol = stt.OpenaiSpeechRepository(settings=settings)
|
stt_repository: stt.STTProtocol = stt.OpenaiSpeechRepository(settings=settings)
|
||||||
|
|
||||||
|
tts_yandex_repository = tts.TTSYandexRepository(
|
||||||
|
tts_settings=app_split_settings.TTSYandexSettings(),
|
||||||
|
client=http_yandex_tts_client,
|
||||||
|
)
|
||||||
|
tts_eleven_labs_repository = tts.TTSElevenLabsRepository(
|
||||||
|
tts_settings=app_split_settings.TTSElevenLabsSettings(),
|
||||||
|
client=http_eleven_labs_tts_client,
|
||||||
|
is_models_from_api=True,
|
||||||
|
)
|
||||||
|
|
||||||
# Caches
|
# Caches
|
||||||
|
|
||||||
logger.info("Initializing caches")
|
logger.info("Initializing caches")
|
||||||
|
@ -81,7 +107,15 @@ class Application:
|
||||||
# Services
|
# Services
|
||||||
|
|
||||||
logger.info("Initializing services")
|
logger.info("Initializing services")
|
||||||
stt_service: stt.SpeechService = stt.SpeechService(repository=stt_repository)
|
stt_service: stt.SpeechService = stt.SpeechService(repository=stt_repository) # type: ignore
|
||||||
|
|
||||||
|
tts_service: tts.TTSService = tts.TTSService( # type: ignore
|
||||||
|
repositories={
|
||||||
|
models.VoiceModelProvidersEnum.YANDEX: tts_yandex_repository,
|
||||||
|
models.VoiceModelProvidersEnum.ELEVEN_LABS: tts_eleven_labs_repository,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# Handlers
|
# Handlers
|
||||||
|
|
||||||
logger.info("Initializing handlers")
|
logger.info("Initializing handlers")
|
||||||
|
|
|
@ -12,3 +12,5 @@ class Settings(pydantic_settings.BaseSettings):
|
||||||
project: app_split_settings.ProjectSettings = app_split_settings.ProjectSettings()
|
project: app_split_settings.ProjectSettings = app_split_settings.ProjectSettings()
|
||||||
proxy: app_split_settings.ProxySettings = app_split_settings.ProxySettings()
|
proxy: app_split_settings.ProxySettings = app_split_settings.ProxySettings()
|
||||||
voice: app_split_settings.VoiceSettings = app_split_settings.VoiceSettings()
|
voice: app_split_settings.VoiceSettings = app_split_settings.VoiceSettings()
|
||||||
|
tts_yandex: app_split_settings.TTSYandexSettings = app_split_settings.TTSYandexSettings()
|
||||||
|
tts_eleven_labs: app_split_settings.TTSElevenLabsSettings = app_split_settings.TTSElevenLabsSettings()
|
||||||
|
|
|
@ -5,6 +5,7 @@ from .openai import *
|
||||||
from .postgres import *
|
from .postgres import *
|
||||||
from .project import *
|
from .project import *
|
||||||
from .proxy import *
|
from .proxy import *
|
||||||
|
from .tts import *
|
||||||
from .voice import *
|
from .voice import *
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
@ -15,6 +16,8 @@ __all__ = [
|
||||||
"PostgresSettings",
|
"PostgresSettings",
|
||||||
"ProjectSettings",
|
"ProjectSettings",
|
||||||
"ProxySettings",
|
"ProxySettings",
|
||||||
|
"TTSElevenLabsSettings",
|
||||||
|
"TTSYandexSettings",
|
||||||
"VoiceSettings",
|
"VoiceSettings",
|
||||||
"get_logging_config",
|
"get_logging_config",
|
||||||
]
|
]
|
||||||
|
|
7
src/assistant/lib/app/split_settings/tts/__init__.py
Normal file
7
src/assistant/lib/app/split_settings/tts/__init__.py
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
from .eleven_labs import *
|
||||||
|
from .yandex import *
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"TTSElevenLabsSettings",
|
||||||
|
"TTSYandexSettings",
|
||||||
|
]
|
26
src/assistant/lib/app/split_settings/tts/eleven_labs.py
Normal file
26
src/assistant/lib/app/split_settings/tts/eleven_labs.py
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
import pydantic
|
||||||
|
import pydantic_settings
|
||||||
|
|
||||||
|
import lib.app.split_settings.utils as app_split_settings_utils
|
||||||
|
|
||||||
|
|
||||||
|
class TTSElevenLabsSettings(pydantic_settings.BaseSettings):
|
||||||
|
model_config = pydantic_settings.SettingsConfigDict(
|
||||||
|
env_file=app_split_settings_utils.ENV_PATH,
|
||||||
|
env_prefix="TTS_ELEVEN_LABS_",
|
||||||
|
env_file_encoding="utf-8",
|
||||||
|
extra="ignore",
|
||||||
|
)
|
||||||
|
|
||||||
|
api_key: pydantic.SecretStr = pydantic.Field(default=...)
|
||||||
|
default_voice_id: str = "EXAVITQu4vr4xnSDxMaL"
|
||||||
|
base_url: str = "https://api.elevenlabs.io/v1/"
|
||||||
|
timeout_seconds: int = 30
|
||||||
|
|
||||||
|
@property
|
||||||
|
def base_headers(self) -> dict[str, str]:
|
||||||
|
return {
|
||||||
|
"Accept": "audio/mpeg",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"xi-api-key": self.api_key.get_secret_value(),
|
||||||
|
}
|
28
src/assistant/lib/app/split_settings/tts/yandex.py
Normal file
28
src/assistant/lib/app/split_settings/tts/yandex.py
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
import typing
|
||||||
|
|
||||||
|
import pydantic
|
||||||
|
import pydantic_settings
|
||||||
|
|
||||||
|
import lib.app.split_settings.utils as app_split_settings_utils
|
||||||
|
|
||||||
|
|
||||||
|
class TTSYandexSettings(pydantic_settings.BaseSettings):
|
||||||
|
model_config = pydantic_settings.SettingsConfigDict(
|
||||||
|
env_file=app_split_settings_utils.ENV_PATH,
|
||||||
|
env_prefix="TTS_YANDEX_",
|
||||||
|
env_file_encoding="utf-8",
|
||||||
|
extra="ignore",
|
||||||
|
)
|
||||||
|
|
||||||
|
audio_format: typing.Literal["oggopus", "mp3", "lpcm"] = "oggopus"
|
||||||
|
sample_rate_hertz: int = 48000
|
||||||
|
api_key: pydantic.SecretStr = pydantic.Field(default=...)
|
||||||
|
base_url: str = "https://tts.api.cloud.yandex.net/speech/v1/"
|
||||||
|
timeout_seconds: int = 30
|
||||||
|
|
||||||
|
@property
|
||||||
|
def base_headers(self) -> dict[str, str]:
|
||||||
|
return {
|
||||||
|
"Authorization": f"Api-Key {self.api_key.get_secret_value()}",
|
||||||
|
"Content-Type": "application/x-www-form-urlencoded",
|
||||||
|
}
|
|
@ -8,7 +8,7 @@ import lib.app.split_settings as app_split_settings
|
||||||
class AsyncHttpClient(httpx.AsyncClient):
|
class AsyncHttpClient(httpx.AsyncClient):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
proxy_settings: app_split_settings.ProxySettings,
|
proxy_settings: app_split_settings.ProxySettings | None = None,
|
||||||
base_url: str | None = None,
|
base_url: str | None = None,
|
||||||
**client_params: typing.Any,
|
**client_params: typing.Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -20,7 +20,7 @@ class AsyncHttpClient(httpx.AsyncClient):
|
||||||
super().__init__(base_url=self.base_url, proxies=self.proxies, **client_params) # type: ignore[reportGeneralTypeIssues]
|
super().__init__(base_url=self.base_url, proxies=self.proxies, **client_params) # type: ignore[reportGeneralTypeIssues]
|
||||||
|
|
||||||
def __get_proxies_from_settings(self) -> dict[str, str] | None:
|
def __get_proxies_from_settings(self) -> dict[str, str] | None:
|
||||||
if not self.proxy_settings.enable:
|
if not self.proxy_settings or not self.proxy_settings.enable:
|
||||||
return None
|
return None
|
||||||
proxies = {"all://": self.proxy_settings.dsn}
|
proxies = {"all://": self.proxy_settings.dsn}
|
||||||
return proxies
|
return proxies
|
||||||
|
|
|
@ -8,6 +8,8 @@ __all__ = [
|
||||||
"BaseLanguageCodesEnum",
|
"BaseLanguageCodesEnum",
|
||||||
"BaseVoiceModel",
|
"BaseVoiceModel",
|
||||||
"ElevenLabsLanguageCodesEnum",
|
"ElevenLabsLanguageCodesEnum",
|
||||||
|
"ElevenLabsListVoiceModelsModel",
|
||||||
|
"ElevenLabsVoiceModel",
|
||||||
"IdCreatedUpdatedBaseMixin",
|
"IdCreatedUpdatedBaseMixin",
|
||||||
"LANGUAGE_CODES_ENUM_TYPE",
|
"LANGUAGE_CODES_ENUM_TYPE",
|
||||||
"LIST_VOICE_MODELS_TYPE",
|
"LIST_VOICE_MODELS_TYPE",
|
||||||
|
@ -17,4 +19,6 @@ __all__ = [
|
||||||
"Token",
|
"Token",
|
||||||
"VoiceModelProvidersEnum",
|
"VoiceModelProvidersEnum",
|
||||||
"YandexLanguageCodesEnum",
|
"YandexLanguageCodesEnum",
|
||||||
|
"YandexListVoiceModelsModel",
|
||||||
|
"YandexVoiceModel",
|
||||||
]
|
]
|
||||||
|
|
|
@ -6,6 +6,8 @@ __all__ = [
|
||||||
"BaseLanguageCodesEnum",
|
"BaseLanguageCodesEnum",
|
||||||
"BaseVoiceModel",
|
"BaseVoiceModel",
|
||||||
"ElevenLabsLanguageCodesEnum",
|
"ElevenLabsLanguageCodesEnum",
|
||||||
|
"ElevenLabsListVoiceModelsModel",
|
||||||
|
"ElevenLabsVoiceModel",
|
||||||
"LANGUAGE_CODES_ENUM_TYPE",
|
"LANGUAGE_CODES_ENUM_TYPE",
|
||||||
"LIST_VOICE_MODELS_TYPE",
|
"LIST_VOICE_MODELS_TYPE",
|
||||||
"TTSCreateRequestModel",
|
"TTSCreateRequestModel",
|
||||||
|
@ -13,4 +15,6 @@ __all__ = [
|
||||||
"TTSSearchVoiceRequestModel",
|
"TTSSearchVoiceRequestModel",
|
||||||
"VoiceModelProvidersEnum",
|
"VoiceModelProvidersEnum",
|
||||||
"YandexLanguageCodesEnum",
|
"YandexLanguageCodesEnum",
|
||||||
|
"YandexListVoiceModelsModel",
|
||||||
|
"YandexVoiceModel",
|
||||||
]
|
]
|
||||||
|
|
|
@ -5,12 +5,45 @@ import lib.models.tts.voice.languages as models_tts_languages
|
||||||
|
|
||||||
AVAILABLE_MODELS_TYPE = models_tts_voice.YandexVoiceModel | models_tts_voice.ElevenLabsVoiceModel
|
AVAILABLE_MODELS_TYPE = models_tts_voice.YandexVoiceModel | models_tts_voice.ElevenLabsVoiceModel
|
||||||
LIST_VOICE_MODELS_TYPE = models_tts_voice.YandexListVoiceModelsModel | models_tts_voice.ElevenLabsListVoiceModelsModel
|
LIST_VOICE_MODELS_TYPE = models_tts_voice.YandexListVoiceModelsModel | models_tts_voice.ElevenLabsListVoiceModelsModel
|
||||||
|
DEFAULT_MODEL = models_tts_voice.ElevenLabsVoiceModel(
|
||||||
|
voice_id="eleven_multilingual_v2",
|
||||||
|
languages=[
|
||||||
|
models_tts_languages.ElevenLabsLanguageCodesEnum.ENGLISH,
|
||||||
|
models_tts_languages.ElevenLabsLanguageCodesEnum.JAPANESE,
|
||||||
|
models_tts_languages.ElevenLabsLanguageCodesEnum.CHINESE,
|
||||||
|
models_tts_languages.ElevenLabsLanguageCodesEnum.GERMAN,
|
||||||
|
models_tts_languages.ElevenLabsLanguageCodesEnum.HINDI,
|
||||||
|
models_tts_languages.ElevenLabsLanguageCodesEnum.FRENCH,
|
||||||
|
models_tts_languages.ElevenLabsLanguageCodesEnum.KOREAN,
|
||||||
|
models_tts_languages.ElevenLabsLanguageCodesEnum.PORTUGUESE,
|
||||||
|
models_tts_languages.ElevenLabsLanguageCodesEnum.ITALIAN,
|
||||||
|
models_tts_languages.ElevenLabsLanguageCodesEnum.SPANISH,
|
||||||
|
models_tts_languages.ElevenLabsLanguageCodesEnum.INDONESIAN,
|
||||||
|
models_tts_languages.ElevenLabsLanguageCodesEnum.DUTCH,
|
||||||
|
models_tts_languages.ElevenLabsLanguageCodesEnum.TURKISH,
|
||||||
|
models_tts_languages.ElevenLabsLanguageCodesEnum.FILIPINO,
|
||||||
|
models_tts_languages.ElevenLabsLanguageCodesEnum.POLISH,
|
||||||
|
models_tts_languages.ElevenLabsLanguageCodesEnum.SWEDISH,
|
||||||
|
models_tts_languages.ElevenLabsLanguageCodesEnum.BULGARIAN,
|
||||||
|
models_tts_languages.ElevenLabsLanguageCodesEnum.ROMANIAN,
|
||||||
|
models_tts_languages.ElevenLabsLanguageCodesEnum.ARABIC,
|
||||||
|
models_tts_languages.ElevenLabsLanguageCodesEnum.CZECH,
|
||||||
|
models_tts_languages.ElevenLabsLanguageCodesEnum.GREEK,
|
||||||
|
models_tts_languages.ElevenLabsLanguageCodesEnum.FINNISH,
|
||||||
|
models_tts_languages.ElevenLabsLanguageCodesEnum.CROATIAN,
|
||||||
|
models_tts_languages.ElevenLabsLanguageCodesEnum.MALAY,
|
||||||
|
models_tts_languages.ElevenLabsLanguageCodesEnum.SLOVAK,
|
||||||
|
models_tts_languages.ElevenLabsLanguageCodesEnum.DANISH,
|
||||||
|
models_tts_languages.ElevenLabsLanguageCodesEnum.TAMIL,
|
||||||
|
models_tts_languages.ElevenLabsLanguageCodesEnum.UKRAINIAN,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TTSCreateRequestModel(pydantic.BaseModel):
|
class TTSCreateRequestModel(pydantic.BaseModel):
|
||||||
model_config = pydantic.ConfigDict(use_enum_values=True)
|
model_config = pydantic.ConfigDict(use_enum_values=True)
|
||||||
|
|
||||||
voice_model: AVAILABLE_MODELS_TYPE
|
voice_model: AVAILABLE_MODELS_TYPE = DEFAULT_MODEL
|
||||||
text: str
|
text: str
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,8 @@ class BaseVoiceModel(pydantic.BaseModel):
|
||||||
@pydantic.model_validator(mode="before")
|
@pydantic.model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_voice_name_exists(cls, data: typing.Any) -> typing.Any:
|
def check_voice_name_exists(cls, data: typing.Any) -> typing.Any:
|
||||||
|
if not data:
|
||||||
|
return data
|
||||||
voice_id = data.get("voice_id")
|
voice_id = data.get("voice_id")
|
||||||
voice_name = data.get("voice_name")
|
voice_name = data.get("voice_name")
|
||||||
if not voice_name and voice_id:
|
if not voice_name and voice_id:
|
||||||
|
|
|
@ -71,5 +71,13 @@ class ElevenLabsListVoiceModelsModel(pydantic.BaseModel):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_api(cls, voice_models_from_api: list[dict[str, typing.Any]]) -> typing.Self:
|
def from_api(cls, voice_models_from_api: list[dict[str, typing.Any]]) -> typing.Self:
|
||||||
voice_models = [ElevenLabsVoiceModel.model_validate(voice_model) for voice_model in voice_models_from_api]
|
voice_models = []
|
||||||
|
for voice_model in voice_models_from_api:
|
||||||
|
voice_model["voice_id"] = voice_model.pop("model_id")
|
||||||
|
voice_model["voice_name"] = voice_model.pop("name")
|
||||||
|
voice_model["languages"] = [
|
||||||
|
models_tts_languages.ElevenLabsLanguageCodesEnum(item.get("language_id"))
|
||||||
|
for item in voice_model.pop("languages")
|
||||||
|
]
|
||||||
|
voice_models.append(ElevenLabsVoiceModel.model_validate(voice_model))
|
||||||
return ElevenLabsListVoiceModelsModel(models=voice_models)
|
return ElevenLabsListVoiceModelsModel(models=voice_models)
|
||||||
|
|
|
@ -16,6 +16,8 @@ class YandexVoiceModel(models_tts_base.BaseVoiceModel):
|
||||||
@pydantic.model_validator(mode="before")
|
@pydantic.model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_voice_name_exists(cls, data: typing.Any) -> typing.Any:
|
def check_voice_name_exists(cls, data: typing.Any) -> typing.Any:
|
||||||
|
if not data:
|
||||||
|
return data
|
||||||
voice_id = data.get("voice_id")
|
voice_id = data.get("voice_id")
|
||||||
voice_name = data.get("voice_name")
|
voice_name = data.get("voice_name")
|
||||||
role = data.get("role")
|
role = data.get("role")
|
||||||
|
|
|
@ -1,5 +1,9 @@
|
||||||
from .services import TTSService
|
from .repositories import *
|
||||||
|
from .services import *
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"TTSBaseRepository",
|
||||||
|
"TTSElevenLabsRepository",
|
||||||
"TTSService",
|
"TTSService",
|
||||||
|
"TTSYandexRepository",
|
||||||
]
|
]
|
||||||
|
|
|
@ -4,11 +4,13 @@ import lib.models as models
|
||||||
|
|
||||||
|
|
||||||
class TTSRepositoryProtocol(typing.Protocol):
|
class TTSRepositoryProtocol(typing.Protocol):
|
||||||
def get_audio_as_bytes(self, request: models.TTSCreateRequestModel) -> models.TTSCreateResponseModel:
|
async def get_audio_as_bytes(self, request: models.TTSCreateRequestModel) -> models.TTSCreateResponseModel:
|
||||||
...
|
...
|
||||||
|
|
||||||
def get_voice_model_by_name(self, voice_model_name: str) -> models.BaseVoiceModel | None:
|
async def get_voice_model_by_name(self, voice_model_name: str) -> models.BaseVoiceModel | None:
|
||||||
...
|
...
|
||||||
|
|
||||||
def get_voice_models_by_fields(self, fields: models.TTSSearchVoiceRequestModel) -> models.LIST_VOICE_MODELS_TYPE:
|
async def get_voice_models_by_fields(
|
||||||
|
self, fields: models.TTSSearchVoiceRequestModel
|
||||||
|
) -> models.LIST_VOICE_MODELS_TYPE:
|
||||||
...
|
...
|
||||||
|
|
|
@ -1,5 +1,9 @@
|
||||||
from .base import *
|
from .base import *
|
||||||
|
from .eleven_labs import *
|
||||||
|
from .yandex import *
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TTSBaseRepository",
|
"TTSBaseRepository",
|
||||||
|
"TTSElevenLabsRepository",
|
||||||
|
"TTSYandexRepository",
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,37 +1,35 @@
|
||||||
import abc
|
import abc
|
||||||
|
|
||||||
|
import lib.clients as clients
|
||||||
import lib.models as models
|
import lib.models as models
|
||||||
|
|
||||||
|
|
||||||
class HttpClient: # Mocked class todo remove and use real http client from lib.clients.http_client
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class TTSBaseRepository(abc.ABC):
|
class TTSBaseRepository(abc.ABC):
|
||||||
def __init__(self, client: HttpClient, is_models_from_api: bool = False):
|
def __init__(self, client: clients.AsyncHttpClient, is_models_from_api: bool = False):
|
||||||
self.http_client = client
|
self.http_client = client
|
||||||
self.is_models_from_api = is_models_from_api
|
self.is_models_from_api = is_models_from_api
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def voice_models(self) -> models.LIST_VOICE_MODELS_TYPE:
|
async def voice_models(self) -> models.LIST_VOICE_MODELS_TYPE:
|
||||||
...
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def get_audio_as_bytes(self, request: models.TTSCreateRequestModel) -> models.TTSCreateResponseModel:
|
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def get_voice_model_by_name(self, voice_model_name: str) -> models.BaseVoiceModel | None:
|
@abc.abstractmethod
|
||||||
|
async def get_audio_as_bytes(self, request: models.TTSCreateRequestModel) -> models.TTSCreateResponseModel:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def get_voice_model_by_name(self, voice_model_name: str) -> models.BaseVoiceModel | None:
|
||||||
"""
|
"""
|
||||||
Search voice model by name
|
Search voice model by name
|
||||||
:param voice_model_name: String name
|
:param voice_model_name: String name
|
||||||
:return: Voice model that match the name
|
:return: Voice model that match the name
|
||||||
"""
|
"""
|
||||||
for voice_model in self.voice_models.models:
|
voice_models = await self.voice_models
|
||||||
|
for voice_model in voice_models.models:
|
||||||
if voice_model.voice_name == voice_model_name:
|
if voice_model.voice_name == voice_model_name:
|
||||||
return voice_model
|
return voice_model
|
||||||
|
|
||||||
def get_list_voice_models_by_fields(
|
async def get_list_voice_models_by_fields(
|
||||||
self, fields: models.TTSSearchVoiceRequestModel
|
self, fields: models.TTSSearchVoiceRequestModel
|
||||||
) -> list[models.AVAILABLE_MODELS_TYPE]:
|
) -> list[models.AVAILABLE_MODELS_TYPE]:
|
||||||
"""
|
"""
|
||||||
|
@ -41,7 +39,8 @@ class TTSBaseRepository(abc.ABC):
|
||||||
"""
|
"""
|
||||||
fields_dump = fields.model_dump(exclude_none=True)
|
fields_dump = fields.model_dump(exclude_none=True)
|
||||||
voice_models_response = []
|
voice_models_response = []
|
||||||
for voice_model in self.voice_models.models:
|
voice_models = await self.voice_models
|
||||||
|
for voice_model in voice_models.models:
|
||||||
for field, field_value in fields_dump.items():
|
for field, field_value in fields_dump.items():
|
||||||
if field == "languages": # language is a list
|
if field == "languages": # language is a list
|
||||||
language_names: set[str] = {item.name for item in field_value}
|
language_names: set[str] = {item.name for item in field_value}
|
||||||
|
|
42
src/assistant/lib/tts/repositories/eleven_labs.py
Normal file
42
src/assistant/lib/tts/repositories/eleven_labs.py
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
import typing
|
||||||
|
|
||||||
|
import lib.app.split_settings as app_split_settings
|
||||||
|
import lib.clients as clients
|
||||||
|
import lib.models as models
|
||||||
|
import lib.tts.repositories.base as tts_repositories_base
|
||||||
|
|
||||||
|
|
||||||
|
class TTSElevenLabsRepository(tts_repositories_base.TTSBaseRepository):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tts_settings: app_split_settings.TTSElevenLabsSettings,
|
||||||
|
client: clients.AsyncHttpClient,
|
||||||
|
is_models_from_api: bool = False,
|
||||||
|
):
|
||||||
|
self.tts_settings = tts_settings
|
||||||
|
super().__init__(client, is_models_from_api)
|
||||||
|
|
||||||
|
@property
|
||||||
|
async def voice_models(self) -> models.ElevenLabsListVoiceModelsModel:
|
||||||
|
if self.is_models_from_api:
|
||||||
|
return models.ElevenLabsListVoiceModelsModel.from_api(await self.get_all_models_dict_from_api())
|
||||||
|
return models.ElevenLabsListVoiceModelsModel()
|
||||||
|
|
||||||
|
async def get_all_models_dict_from_api(self) -> list[dict[str, typing.Any]]:
|
||||||
|
response = await self.http_client.get("/models")
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
async def get_audio_as_bytes(self, request: models.TTSCreateRequestModel) -> models.TTSCreateResponseModel:
|
||||||
|
if not isinstance(request.voice_model, models.ElevenLabsVoiceModel):
|
||||||
|
raise ValueError("ElevenLabs TTS support only ElevenLabsVoiceModel")
|
||||||
|
response = await self.http_client.post(
|
||||||
|
f"/text-to-speech/{self.tts_settings.default_voice_id}",
|
||||||
|
json={"text": request.text, "model_id": request.voice_model.voice_id},
|
||||||
|
)
|
||||||
|
return models.TTSCreateResponseModel(audio_content=response.content)
|
||||||
|
|
||||||
|
async def get_voice_models_by_fields(
|
||||||
|
self, fields: models.TTSSearchVoiceRequestModel
|
||||||
|
) -> models.ElevenLabsListVoiceModelsModel:
|
||||||
|
list_voice_models = await self.get_list_voice_models_by_fields(fields)
|
||||||
|
return models.ElevenLabsListVoiceModelsModel(models=list_voice_models) # type: ignore
|
48
src/assistant/lib/tts/repositories/yandex.py
Normal file
48
src/assistant/lib/tts/repositories/yandex.py
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import lib.app.split_settings as app_split_settings
|
||||||
|
import lib.clients as clients
|
||||||
|
import lib.models as models
|
||||||
|
import lib.tts.repositories.base as tts_repositories_base
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TTSYandexRepository(tts_repositories_base.TTSBaseRepository):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tts_settings: app_split_settings.TTSYandexSettings,
|
||||||
|
client: clients.AsyncHttpClient,
|
||||||
|
is_models_from_api: bool = False,
|
||||||
|
):
|
||||||
|
self.tts_settings = tts_settings
|
||||||
|
if is_models_from_api:
|
||||||
|
logger.warning("Yandex TTS doesn't support getting models from API")
|
||||||
|
super().__init__(client, is_models_from_api=False)
|
||||||
|
|
||||||
|
@property
|
||||||
|
async def voice_models(self) -> models.YandexListVoiceModelsModel:
|
||||||
|
return models.YandexListVoiceModelsModel()
|
||||||
|
|
||||||
|
async def get_audio_as_bytes(self, request: models.TTSCreateRequestModel) -> models.TTSCreateResponseModel:
|
||||||
|
if not isinstance(request.voice_model, models.YandexVoiceModel):
|
||||||
|
raise ValueError("Yandex TTS support only YandexVoiceModel")
|
||||||
|
data = {
|
||||||
|
"text": request.text,
|
||||||
|
"lang": request.voice_model.languages[0].value,
|
||||||
|
"voice": request.voice_model.voice_id,
|
||||||
|
"emotion": request.voice_model.role,
|
||||||
|
"format": self.tts_settings.audio_format,
|
||||||
|
"sampleRateHertz": self.tts_settings.sample_rate_hertz,
|
||||||
|
}
|
||||||
|
response = await self.http_client.post(
|
||||||
|
"/tts:synthesize",
|
||||||
|
data=data,
|
||||||
|
)
|
||||||
|
return models.TTSCreateResponseModel(audio_content=response.content)
|
||||||
|
|
||||||
|
async def get_voice_models_by_fields(
|
||||||
|
self, fields: models.TTSSearchVoiceRequestModel
|
||||||
|
) -> models.YandexListVoiceModelsModel:
|
||||||
|
list_voice_models = await self.get_list_voice_models_by_fields(fields)
|
||||||
|
return models.YandexListVoiceModelsModel(models=list_voice_models) # type: ignore
|
|
@ -1,35 +1,33 @@
|
||||||
import lib.app.settings as app_settings
|
import lib.models as _models
|
||||||
import lib.models as models
|
|
||||||
import lib.tts.models as tts_models
|
import lib.tts.models as tts_models
|
||||||
|
|
||||||
|
|
||||||
class TTSService:
|
class TTSService:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
settings: app_settings.Settings,
|
repositories: dict[_models.VoiceModelProvidersEnum, tts_models.TTSRepositoryProtocol],
|
||||||
repositories: dict[models.VoiceModelProvidersEnum, tts_models.TTSRepositoryProtocol],
|
|
||||||
):
|
):
|
||||||
self.settings = settings
|
|
||||||
self.repositories = repositories
|
self.repositories = repositories
|
||||||
|
|
||||||
def get_audio_as_bytes(self, request: models.TTSCreateRequestModel) -> models.TTSCreateResponseModel:
|
async def get_audio_as_bytes(self, request: _models.TTSCreateRequestModel) -> _models.TTSCreateResponseModel:
|
||||||
model = request.voice_model
|
model = request.voice_model
|
||||||
repository = self.repositories[model.provider]
|
repository = self.repositories[model.provider]
|
||||||
audio_response = repository.get_audio_as_bytes(request)
|
audio_response = await repository.get_audio_as_bytes(request)
|
||||||
return audio_response
|
return audio_response
|
||||||
|
|
||||||
def get_voice_model_by_name(self, voice_model_name: str) -> models.BaseVoiceModel | None:
|
async def get_voice_model_by_name(self, voice_model_name: str) -> _models.BaseVoiceModel | None:
|
||||||
for repository in self.repositories.values():
|
for repository in self.repositories.values():
|
||||||
voice_model = repository.get_voice_model_by_name(voice_model_name)
|
voice_model = await repository.get_voice_model_by_name(voice_model_name)
|
||||||
if voice_model:
|
if voice_model:
|
||||||
return voice_model
|
return voice_model
|
||||||
|
raise ValueError("Voice model not found")
|
||||||
|
|
||||||
def get_list_voice_models_by_fields(
|
async def get_list_voice_models_by_fields(
|
||||||
self, fields: models.TTSSearchVoiceRequestModel
|
self, fields: _models.TTSSearchVoiceRequestModel
|
||||||
) -> list[models.AVAILABLE_MODELS_TYPE]:
|
) -> list[_models.AVAILABLE_MODELS_TYPE]:
|
||||||
response_models: list[models.AVAILABLE_MODELS_TYPE] = []
|
response_models: list[_models.AVAILABLE_MODELS_TYPE] = []
|
||||||
for repository in self.repositories.values():
|
for repository in self.repositories.values():
|
||||||
voice_models = repository.get_voice_models_by_fields(fields)
|
voice_models = await repository.get_voice_models_by_fields(fields)
|
||||||
if voice_models.models:
|
if voice_models.models:
|
||||||
response_models.extend(voice_models.models)
|
response_models.extend(voice_models.models)
|
||||||
return response_models
|
return response_models
|
||||||
|
|
Loading…
Reference in New Issue
Block a user