mirror of
https://github.com/ijaric/voice_assistant.git
synced 2025-05-24 14:33:26 +00:00
feat: [#47] add tts repositories
This commit is contained in:
parent
c9a9abb077
commit
0d5a2c8bae
|
@ -0,0 +1,9 @@
|
||||||
|
from .repositories import *
|
||||||
|
from .services import *
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"TTSBaseRepository",
|
||||||
|
"TTSElevenLabsRepository",
|
||||||
|
"TTSService",
|
||||||
|
"TTSYandexRepository",
|
||||||
|
]
|
|
@ -1,5 +1,9 @@
|
||||||
from .base import *
|
from .base import *
|
||||||
|
from .eleven_labs import *
|
||||||
|
from .yandex import *
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TTSBaseRepository",
|
"TTSBaseRepository",
|
||||||
|
"TTSElevenLabsRepository",
|
||||||
|
"TTSYandexRepository",
|
||||||
]
|
]
|
||||||
|
|
43
src/assistant/lib/tts/repositories/eleven_labs.py
Normal file
43
src/assistant/lib/tts/repositories/eleven_labs.py
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
import typing
|
||||||
|
|
||||||
|
import lib.app.split_settings as app_split_settings
|
||||||
|
import lib.clients as clients
|
||||||
|
import lib.models as models
|
||||||
|
import lib.tts.repositories.base as tts_repositories_base
|
||||||
|
|
||||||
|
|
||||||
|
class TTSElevenLabsRepository(tts_repositories_base.TTSBaseRepository):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tts_settings: app_split_settings.TTSElevenLabsSettings,
|
||||||
|
client: clients.AsyncHttpClient,
|
||||||
|
is_models_from_api: bool = False,
|
||||||
|
):
|
||||||
|
self.tts_settings = tts_settings
|
||||||
|
super().__init__(client, is_models_from_api)
|
||||||
|
|
||||||
|
@property
|
||||||
|
async def voice_models(self) -> models.ElevenLabsListVoiceModelsModel:
|
||||||
|
if self.is_models_from_api:
|
||||||
|
return models.ElevenLabsListVoiceModelsModel.from_api(await self.get_all_models_dict_from_api())
|
||||||
|
return models.ElevenLabsListVoiceModelsModel()
|
||||||
|
|
||||||
|
async def get_all_models_dict_from_api(self) -> list[dict[str, typing.Any]]:
|
||||||
|
response = await self.http_client.get("/models")
|
||||||
|
print(response)
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
async def get_audio_as_bytes(self, request: models.TTSCreateRequestModel) -> models.TTSCreateResponseModel:
|
||||||
|
if not isinstance(request.voice_model, models.ElevenLabsVoiceModel):
|
||||||
|
raise ValueError("ElevenLabs TTS support only ElevenLabsVoiceModel")
|
||||||
|
response = await self.http_client.post(
|
||||||
|
f"/text-to-speech/{self.tts_settings.default_voice_id}",
|
||||||
|
json={"text": request.text, "model_id": request.voice_model.voice_id},
|
||||||
|
)
|
||||||
|
return models.TTSCreateResponseModel(audio_content=response.content)
|
||||||
|
|
||||||
|
async def get_voice_models_by_fields(
|
||||||
|
self, fields: models.TTSSearchVoiceRequestModel
|
||||||
|
) -> models.ElevenLabsListVoiceModelsModel:
|
||||||
|
list_voice_models = await self.get_list_voice_models_by_fields(fields)
|
||||||
|
return models.ElevenLabsListVoiceModelsModel(models=list_voice_models) # type: ignore
|
48
src/assistant/lib/tts/repositories/yandex.py
Normal file
48
src/assistant/lib/tts/repositories/yandex.py
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import lib.app.split_settings as app_split_settings
|
||||||
|
import lib.clients as clients
|
||||||
|
import lib.models as models
|
||||||
|
import lib.tts.repositories.base as tts_repositories_base
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TTSYandexRepository(tts_repositories_base.TTSBaseRepository):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tts_settings: app_split_settings.TTSYandexSettings,
|
||||||
|
client: clients.AsyncHttpClient,
|
||||||
|
is_models_from_api: bool = False,
|
||||||
|
):
|
||||||
|
self.tts_settings = tts_settings
|
||||||
|
if is_models_from_api:
|
||||||
|
logger.warning("Yandex TTS doesn't support getting models from API")
|
||||||
|
super().__init__(client, is_models_from_api=False)
|
||||||
|
|
||||||
|
@property
|
||||||
|
async def voice_models(self) -> models.YandexListVoiceModelsModel:
|
||||||
|
return models.YandexListVoiceModelsModel()
|
||||||
|
|
||||||
|
async def get_audio_as_bytes(self, request: models.TTSCreateRequestModel) -> models.TTSCreateResponseModel:
|
||||||
|
if not isinstance(request.voice_model, models.YandexVoiceModel):
|
||||||
|
raise ValueError("Yandex TTS support only YandexVoiceModel")
|
||||||
|
data = {
|
||||||
|
"text": request.text,
|
||||||
|
"lang": request.voice_model.languages[0].value,
|
||||||
|
"voice": request.voice_model.voice_id,
|
||||||
|
"emotion": request.voice_model.role,
|
||||||
|
"format": self.tts_settings.audio_format,
|
||||||
|
"sampleRateHertz": self.tts_settings.sample_rate_hertz,
|
||||||
|
}
|
||||||
|
response = await self.http_client.post(
|
||||||
|
"/tts:synthesize",
|
||||||
|
data=data,
|
||||||
|
)
|
||||||
|
return models.TTSCreateResponseModel(audio_content=response.content)
|
||||||
|
|
||||||
|
async def get_voice_models_by_fields(
|
||||||
|
self, fields: models.TTSSearchVoiceRequestModel
|
||||||
|
) -> models.YandexListVoiceModelsModel:
|
||||||
|
list_voice_models = await self.get_list_voice_models_by_fields(fields)
|
||||||
|
return models.YandexListVoiceModelsModel(models=list_voice_models) # type: ignore
|
Loading…
Reference in New Issue
Block a user