From 1170532c9349235f46e7fe6870e42db9dc1927d6 Mon Sep 17 00:00:00 2001 From: ksieuk Date: Fri, 13 Oct 2023 17:04:23 +0300 Subject: [PATCH] fix: [#47] to async --- src/assistant/lib/tts/services.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/src/assistant/lib/tts/services.py b/src/assistant/lib/tts/services.py index 1d6a337..13fd340 100644 --- a/src/assistant/lib/tts/services.py +++ b/src/assistant/lib/tts/services.py @@ -1,35 +1,33 @@ -import lib.app.settings as app_settings -import lib.models as models +import lib.models as _models import lib.tts.models as tts_models class TTSService: def __init__( self, - settings: app_settings.Settings, - repositories: dict[models.VoiceModelProvidersEnum, tts_models.TTSRepositoryProtocol], + repositories: dict[_models.VoiceModelProvidersEnum, tts_models.TTSRepositoryProtocol], ): - self.settings = settings self.repositories = repositories - def get_audio_as_bytes(self, request: models.TTSCreateRequestModel) -> models.TTSCreateResponseModel: + async def get_audio_as_bytes(self, request: _models.TTSCreateRequestModel) -> _models.TTSCreateResponseModel: model = request.voice_model repository = self.repositories[model.provider] - audio_response = repository.get_audio_as_bytes(request) + audio_response = await repository.get_audio_as_bytes(request) return audio_response - def get_voice_model_by_name(self, voice_model_name: str) -> models.BaseVoiceModel | None: + async def get_voice_model_by_name(self, voice_model_name: str) -> _models.BaseVoiceModel | None: for repository in self.repositories.values(): - voice_model = repository.get_voice_model_by_name(voice_model_name) + voice_model = await repository.get_voice_model_by_name(voice_model_name) if voice_model: return voice_model + raise ValueError("Voice model not found") - def get_list_voice_models_by_fields( - self, fields: models.TTSSearchVoiceRequestModel - ) -> list[models.AVAILABLE_MODELS_TYPE]: - response_models: list[models.AVAILABLE_MODELS_TYPE] = [] + async def get_list_voice_models_by_fields( + self, fields: _models.TTSSearchVoiceRequestModel + ) -> list[_models.AVAILABLE_MODELS_TYPE]: + response_models: list[_models.AVAILABLE_MODELS_TYPE] = [] for repository in self.repositories.values(): - voice_models = repository.get_voice_models_by_fields(fields) + voice_models = await repository.get_voice_models_by_fields(fields) if voice_models.models: response_models.extend(voice_models.models) return response_models