1
0
mirror of https://github.com/ijaric/voice_assistant.git synced 2025-05-24 06:23:28 +00:00

fix: [#47] to async

This commit is contained in:
ksieuk 2023-10-13 17:04:23 +03:00
parent 6ed4928ced
commit 1170532c93

View File

@ -1,35 +1,33 @@
import lib.app.settings as app_settings
import lib.models as models
import lib.models as _models
import lib.tts.models as tts_models
class TTSService:
def __init__(
self,
settings: app_settings.Settings,
repositories: dict[models.VoiceModelProvidersEnum, tts_models.TTSRepositoryProtocol],
repositories: dict[_models.VoiceModelProvidersEnum, tts_models.TTSRepositoryProtocol],
):
self.settings = settings
self.repositories = repositories
def get_audio_as_bytes(self, request: models.TTSCreateRequestModel) -> models.TTSCreateResponseModel:
async def get_audio_as_bytes(self, request: _models.TTSCreateRequestModel) -> _models.TTSCreateResponseModel:
model = request.voice_model
repository = self.repositories[model.provider]
audio_response = repository.get_audio_as_bytes(request)
audio_response = await repository.get_audio_as_bytes(request)
return audio_response
def get_voice_model_by_name(self, voice_model_name: str) -> models.BaseVoiceModel | None:
async def get_voice_model_by_name(self, voice_model_name: str) -> _models.BaseVoiceModel | None:
for repository in self.repositories.values():
voice_model = repository.get_voice_model_by_name(voice_model_name)
voice_model = await repository.get_voice_model_by_name(voice_model_name)
if voice_model:
return voice_model
raise ValueError("Voice model not found")
def get_list_voice_models_by_fields(
self, fields: models.TTSSearchVoiceRequestModel
) -> list[models.AVAILABLE_MODELS_TYPE]:
response_models: list[models.AVAILABLE_MODELS_TYPE] = []
async def get_list_voice_models_by_fields(
self, fields: _models.TTSSearchVoiceRequestModel
) -> list[_models.AVAILABLE_MODELS_TYPE]:
response_models: list[_models.AVAILABLE_MODELS_TYPE] = []
for repository in self.repositories.values():
voice_models = repository.get_voice_models_by_fields(fields)
voice_models = await repository.get_voice_models_by_fields(fields)
if voice_models.models:
response_models.extend(voice_models.models)
return response_models