diff --git a/src/assistant/lib/tts/repositories/base.py b/src/assistant/lib/tts/repositories/base.py index 6d1ce15..5978061 100644 --- a/src/assistant/lib/tts/repositories/base.py +++ b/src/assistant/lib/tts/repositories/base.py @@ -11,24 +11,25 @@ class TTSBaseRepository(abc.ABC): @property @abc.abstractmethod - def voice_models(self) -> models.LIST_VOICE_MODELS_TYPE: + async def voice_models(self) -> models.LIST_VOICE_MODELS_TYPE: raise NotImplementedError @abc.abstractmethod - def get_audio_as_bytes(self, request: models.TTSCreateRequestModel) -> models.TTSCreateResponseModel: + async def get_audio_as_bytes(self, request: models.TTSCreateRequestModel) -> models.TTSCreateResponseModel: raise NotImplementedError - def get_voice_model_by_name(self, voice_model_name: str) -> models.BaseVoiceModel | None: + async def get_voice_model_by_name(self, voice_model_name: str) -> models.BaseVoiceModel | None: """ Search voice model by name :param voice_model_name: String name :return: Voice model that match the name """ - for voice_model in self.voice_models.models: + voice_models = await self.voice_models + for voice_model in voice_models.models: if voice_model.voice_name == voice_model_name: return voice_model - def get_list_voice_models_by_fields( + async def get_list_voice_models_by_fields( self, fields: models.TTSSearchVoiceRequestModel ) -> list[models.AVAILABLE_MODELS_TYPE]: """ @@ -38,7 +39,8 @@ class TTSBaseRepository(abc.ABC): """ fields_dump = fields.model_dump(exclude_none=True) voice_models_response = [] - for voice_model in self.voice_models.models: + voice_models = await self.voice_models + for voice_model in voice_models.models: for field, field_value in fields_dump.items(): if field == "languages": # language is a list language_names: set[str] = {item.name for item in field_value}