mirror of
				https://github.com/ijaric/voice_assistant.git
				synced 2025-10-31 01:43:25 +00:00 
			
		
		
		
	Merge 16803a4bb6 into ad9d4fdb4d
				
					
				
			This commit is contained in:
		
						commit
						3f9d5a526b
					
				|  | @ -1,7 +1,9 @@ | ||||||
| from .health import basic_router | from .health import basic_router | ||||||
|  | from .tts import * | ||||||
| from .voice_responce_handler import VoiceResponseHandler | from .voice_responce_handler import VoiceResponseHandler | ||||||
| 
 | 
 | ||||||
| __all__ = [ | __all__ = [ | ||||||
|  |     "TTSHandler", | ||||||
|     "VoiceResponseHandler", |     "VoiceResponseHandler", | ||||||
|     "basic_router", |     "basic_router", | ||||||
| ] | ] | ||||||
|  |  | ||||||
							
								
								
									
										66
									
								
								src/assistant/lib/api/v1/handlers/tts.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										66
									
								
								src/assistant/lib/api/v1/handlers/tts.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,66 @@ | ||||||
|  | import http | ||||||
|  | 
 | ||||||
|  | import fastapi | ||||||
|  | 
 | ||||||
|  | import lib.models as models | ||||||
|  | import lib.tts.services as tts_service | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class TTSHandler: | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         tts: tts_service.TTSService, | ||||||
|  |     ): | ||||||
|  |         self.tts = tts | ||||||
|  |         self.router = fastapi.APIRouter() | ||||||
|  |         self.router.add_api_route( | ||||||
|  |             "/fields", | ||||||
|  |             self.get_by_fields, | ||||||
|  |             methods=["POST"], | ||||||
|  |             summary="Получение моделей по полю", | ||||||
|  |             description="Возвращает список моделей с указанными полями", | ||||||
|  |         ) | ||||||
|  |         self.router.add_api_route( | ||||||
|  |             "/name", | ||||||
|  |             self.get_by_name, | ||||||
|  |             methods=["POST"], | ||||||
|  |             summary="Получение модели по имени", | ||||||
|  |             description="Позволяет получить одну модель по её имени", | ||||||
|  |         ) | ||||||
|  |         self.router.add_api_route( | ||||||
|  |             "/", | ||||||
|  |             self.get_all, | ||||||
|  |             methods=["GET"], | ||||||
|  |             summary="Получение всех доступных моделей", | ||||||
|  |             description="Возвращает список всех доступных моделей", | ||||||
|  |         ) | ||||||
|  |         self.router.add_api_route( | ||||||
|  |             "/languages", | ||||||
|  |             self.get_languages, | ||||||
|  |             methods=["GET"], | ||||||
|  |             summary="Получение всех доступных языков", | ||||||
|  |             description="Возвращает список всех доступных языков", | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |     async def get_by_fields( | ||||||
|  |         self, | ||||||
|  |         data: models.TTSSearchVoiceRequestModel, | ||||||
|  |     ) -> list[models.AVAILABLE_MODELS_TYPE]: | ||||||
|  |         response = await self.tts.get_list_voice_models_by_fields(data) | ||||||
|  |         return response | ||||||
|  | 
 | ||||||
|  |     async def get_by_name( | ||||||
|  |         self, | ||||||
|  |         model_name: str, | ||||||
|  |     ) -> models.BaseVoiceModel: | ||||||
|  |         response = await self.tts.get_voice_model_by_name(model_name) | ||||||
|  |         if not response: | ||||||
|  |             raise fastapi.HTTPException(status_code=http.HTTPStatus.BAD_REQUEST, detail="Model not found") | ||||||
|  |         return response | ||||||
|  | 
 | ||||||
|  |     async def get_all(self) -> list[models.AVAILABLE_MODELS_TYPE]: | ||||||
|  |         return await self.tts.get_all_models() | ||||||
|  | 
 | ||||||
|  |     @classmethod | ||||||
|  |     async def get_languages(cls) -> dict[str, str]: | ||||||
|  |         return {language.name: language.value for language in models.BaseLanguageCodesEnum} | ||||||
|  | @ -107,9 +107,9 @@ class Application: | ||||||
|         # Services |         # Services | ||||||
| 
 | 
 | ||||||
|         logger.info("Initializing services") |         logger.info("Initializing services") | ||||||
|         stt_service: stt.SpeechService = stt.SpeechService(repository=stt_repository)  # type: ignore |         stt_service: stt.SpeechService = stt.SpeechService(repository=stt_repository) | ||||||
| 
 | 
 | ||||||
|         tts_service: tts.TTSService = tts.TTSService(  # type: ignore |         tts_service: tts.TTSService = tts.TTSService( | ||||||
|             repositories={ |             repositories={ | ||||||
|                 models.VoiceModelProvidersEnum.YANDEX: tts_yandex_repository, |                 models.VoiceModelProvidersEnum.YANDEX: tts_yandex_repository, | ||||||
|                 models.VoiceModelProvidersEnum.ELEVEN_LABS: tts_eleven_labs_repository, |                 models.VoiceModelProvidersEnum.ELEVEN_LABS: tts_eleven_labs_repository, | ||||||
|  | @ -127,6 +127,8 @@ class Application: | ||||||
|             # tts=tts_service,  # TODO |             # tts=tts_service,  # TODO | ||||||
|         ).router |         ).router | ||||||
| 
 | 
 | ||||||
|  |         tts_handler = api_v1_handlers.TTSHandler(tts=tts_service).router | ||||||
|  | 
 | ||||||
|         logger.info("Creating application") |         logger.info("Creating application") | ||||||
| 
 | 
 | ||||||
|         fastapi_app = fastapi.FastAPI( |         fastapi_app = fastapi.FastAPI( | ||||||
|  | @ -140,6 +142,7 @@ class Application: | ||||||
|         # Routes |         # Routes | ||||||
|         fastapi_app.include_router(liveness_probe_handler, prefix="/api/v1/health", tags=["health"]) |         fastapi_app.include_router(liveness_probe_handler, prefix="/api/v1/health", tags=["health"]) | ||||||
|         fastapi_app.include_router(voice_response_handler, prefix="/api/v1/voice", tags=["voice"]) |         fastapi_app.include_router(voice_response_handler, prefix="/api/v1/voice", tags=["voice"]) | ||||||
|  |         fastapi_app.include_router(tts_handler, prefix="/api/v1/tts", tags=["tts"]) | ||||||
| 
 | 
 | ||||||
|         application = Application( |         application = Application( | ||||||
|             settings=settings, |             settings=settings, | ||||||
|  |  | ||||||
|  | @ -7,7 +7,6 @@ import lib.models.tts.voice.languages as models_tts_languages | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class ElevenLabsVoiceModel(models_tts_base.BaseVoiceModel): | class ElevenLabsVoiceModel(models_tts_base.BaseVoiceModel): | ||||||
|     model_config = pydantic.ConfigDict(use_enum_values=True) |  | ||||||
|     voice_id: str |     voice_id: str | ||||||
|     voice_name: str | None = None |     voice_name: str | None = None | ||||||
|     languages: list[models_tts_languages.LANGUAGE_CODES_ENUM_TYPE] |     languages: list[models_tts_languages.LANGUAGE_CODES_ENUM_TYPE] | ||||||
|  |  | ||||||
|  | @ -4,6 +4,10 @@ import lib.models as models | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class TTSRepositoryProtocol(typing.Protocol): | class TTSRepositoryProtocol(typing.Protocol): | ||||||
|  |     @property | ||||||
|  |     async def voice_models(self) -> models.LIST_VOICE_MODELS_TYPE: | ||||||
|  |         raise NotImplementedError | ||||||
|  | 
 | ||||||
|     async def get_audio_as_bytes(self, request: models.TTSCreateRequestModel) -> models.TTSCreateResponseModel: |     async def get_audio_as_bytes(self, request: models.TTSCreateRequestModel) -> models.TTSCreateResponseModel: | ||||||
|         ... |         ... | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -49,7 +49,7 @@ class TTSBaseRepository(abc.ABC): | ||||||
|                         continue |                         continue | ||||||
|                     break |                     break | ||||||
|                 voice_model_dump = voice_model.model_dump() |                 voice_model_dump = voice_model.model_dump() | ||||||
|                 if voice_model_dump[field] != field_value.name: |                 if voice_model_dump[field] != field_value: | ||||||
|                     break |                     break | ||||||
|             else: |             else: | ||||||
|                 voice_models_response.append(voice_model) |                 voice_models_response.append(voice_model) | ||||||
|  |  | ||||||
|  | @ -31,3 +31,14 @@ class TTSService: | ||||||
|             if voice_models.models: |             if voice_models.models: | ||||||
|                 response_models.extend(voice_models.models) |                 response_models.extend(voice_models.models) | ||||||
|         return response_models |         return response_models | ||||||
|  | 
 | ||||||
|  |     async def get_all_models(self) -> list[_models.AVAILABLE_MODELS_TYPE]: | ||||||
|  |         response_models: list[_models.AVAILABLE_MODELS_TYPE] = [] | ||||||
|  |         for repository in self.repositories.values(): | ||||||
|  |             response = await repository.voice_models | ||||||
|  |             for model in response.models: | ||||||
|  |                 model.languages = [  # type: ignore | ||||||
|  |                     _models.BaseLanguageCodesEnum[language.name] for language in model.languages | ||||||
|  |                 ] | ||||||
|  |                 response_models.append(model) | ||||||
|  |         return response_models | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user