From 741a470d8921b9a98ed1301da815ac044164a018 Mon Sep 17 00:00:00 2001 From: ksieuk Date: Tue, 10 Oct 2023 18:33:49 +0300 Subject: [PATCH] feat: [#28] add base repository --- src/assistant/lib/tts/repositories/base.py | 59 ++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 src/assistant/lib/tts/repositories/base.py diff --git a/src/assistant/lib/tts/repositories/base.py b/src/assistant/lib/tts/repositories/base.py new file mode 100644 index 0000000..238b873 --- /dev/null +++ b/src/assistant/lib/tts/repositories/base.py @@ -0,0 +1,59 @@ +import abc + +import lib.models as models +import lib.tts.models as tts_models + + +class HttpClient: # Mocked class todo remove and use real http client from lib.clients.http_client + ... + + +class TTSBaseRepository(abc.ABC): + def __init__(self, client: HttpClient, is_models_from_api: bool = False): + self.http_client = client + self.is_models_from_api = is_models_from_api + + @property + @abc.abstractmethod + def voice_models(self) -> tts_models.LIST_VOICE_MODELS_TYPE: + ... + + @abc.abstractmethod + def get_audio_as_bytes_from_text(self, text: str) -> models.TTSCreateResponseModel: + raise NotImplementedError + + def get_voice_model_by_name(self, voice_model_name: str) -> models.BaseVoiceModel | None: + """ + Search voice model by name + :param voice_model_name: String name + :return: Voice model that match the name + """ + for voice_model in self.voice_models.models: + if voice_model.voice_name == voice_model_name: + return voice_model + + def get_list_voice_models_by_fields( + self, fields: models.TTSSearchVoiceRequestModel + ) -> list[tts_models.VOICE_MODELS_TYPE]: + """ + Search voice model by fields + :param fields: Any fields from TTSSearchVoiceRequestModel + :return: All voice models that match the fields + """ + fields_dump = fields.model_dump(exclude_none=True) + voice_models_response = [] + for voice_model in self.voice_models.models: + voice_model: tts_models.VOICE_MODELS_TYPE + for field, field_value in fields_dump.items(): + if field == "languages": # language is a list + language_names: set[str] = {item.name for item in field_value} + voice_model_language_names: set[str] = {item.name for item in voice_model.languages} + if language_names.issubset(voice_model_language_names): + continue + break + voice_model_dump = voice_model.model_dump() + if voice_model_dump[field] != field_value.name: + break + else: + voice_models_response.append(voice_model) + return voice_models_response # type: ignore[reportUnknownVariableType]