From 0a61b4ada42f6285303771989f4d5042be1d39c6 Mon Sep 17 00:00:00 2001 From: jsdio Date: Sat, 7 Oct 2023 05:53:49 +0300 Subject: [PATCH] feat: review fixes --- src/assistant/.env.example | 1 + src/assistant/lib/app/app.py | 3 ++- src/assistant/lib/app/split_settings/openai.py | 1 + src/assistant/lib/models/__init__.py | 7 +++++-- src/assistant/lib/stt/__init__.py | 11 +++++++---- .../lib/{models/stt_voice.py => stt/models.py} | 2 +- .../stt/{openai_speech.py => repositories.py} | 18 +++++++++--------- src/assistant/lib/stt/services.py | 14 ++++++++++++++ src/assistant/lib/stt/stt_protocol.py | 6 ------ src/assistant/poetry.lock | 6 ++++++ 10 files changed, 46 insertions(+), 23 deletions(-) rename src/assistant/lib/{models/stt_voice.py => stt/models.py} (96%) rename src/assistant/lib/stt/{openai_speech.py => repositories.py} (70%) create mode 100644 src/assistant/lib/stt/services.py delete mode 100644 src/assistant/lib/stt/stt_protocol.py diff --git a/src/assistant/.env.example b/src/assistant/.env.example index 8e82b73..9a66582 100644 --- a/src/assistant/.env.example +++ b/src/assistant/.env.example @@ -19,3 +19,4 @@ VOICE_MAX_INPUT_SIZE=5120 # 5MB VOICE_MAX_INPUT_SECONDS=30 OPENAI_API_KEY=sk-1234567890 +OPENAI_STT_MODEL=whisper-1 diff --git a/src/assistant/lib/app/app.py b/src/assistant/lib/app/app.py index 5d5cff6..4953002 100644 --- a/src/assistant/lib/app/app.py +++ b/src/assistant/lib/app/app.py @@ -61,6 +61,7 @@ class Application: # Repositories logger.info("Initializing repositories") + stt_repository: stt.STTProtocol = stt.OpenaiSpeechRepository(settings=settings) # Caches @@ -69,7 +70,7 @@ class Application: # Services logger.info("Initializing services") - stt_service: stt.STTProtocol = stt.OpenaiSpeech(settings=settings) # type: ignore + stt_service: stt.SpeechService = stt.SpeechService(repository=stt_repository) # type: ignore # Handlers diff --git a/src/assistant/lib/app/split_settings/openai.py b/src/assistant/lib/app/split_settings/openai.py index be3af92..235b940 100644 --- a/src/assistant/lib/app/split_settings/openai.py +++ b/src/assistant/lib/app/split_settings/openai.py @@ -15,3 +15,4 @@ class OpenaiSettings(pydantic_settings.BaseSettings): api_key: pydantic.SecretStr = pydantic.Field( default=..., validation_alias=pydantic.AliasChoices("api_key", "openai_api_key") ) + stt_model: str = "whisper-1" diff --git a/src/assistant/lib/models/__init__.py b/src/assistant/lib/models/__init__.py index 122a9e5..54bd848 100644 --- a/src/assistant/lib/models/__init__.py +++ b/src/assistant/lib/models/__init__.py @@ -1,5 +1,8 @@ from .orm import Base, IdCreatedUpdatedBaseMixin -from .stt_voice import SttVoice from .token import Token -__all__ = ["Base", "IdCreatedUpdatedBaseMixin", "SttVoice", "Token"] +__all__ = [ + "Base", + "IdCreatedUpdatedBaseMixin", + "Token"] + diff --git a/src/assistant/lib/stt/__init__.py b/src/assistant/lib/stt/__init__.py index 8e82d5e..c5f2b9e 100644 --- a/src/assistant/lib/stt/__init__.py +++ b/src/assistant/lib/stt/__init__.py @@ -1,7 +1,10 @@ -from .openai_speech import * -from .stt_protocol import * +from .models import * +from .repositories import * +from .services import * __all__ = [ - "OpenaiSpeech", - "stt_protocol", + "OpenaiSpeechRepository", + "STTProtocol", + "SpeechService", + "SttVoice", ] diff --git a/src/assistant/lib/models/stt_voice.py b/src/assistant/lib/stt/models.py similarity index 96% rename from src/assistant/lib/models/stt_voice.py rename to src/assistant/lib/stt/models.py index a892a04..ca0604e 100644 --- a/src/assistant/lib/models/stt_voice.py +++ b/src/assistant/lib/stt/models.py @@ -8,7 +8,7 @@ import lib.app.split_settings as app_split_settings class SttVoice(pydantic.BaseModel): audio_size: int audio_format: str - audio_name: str = "123" + audio_name: str = "voice" audio_data: bytes voice_settings: app_split_settings.VoiceSettings diff --git a/src/assistant/lib/stt/openai_speech.py b/src/assistant/lib/stt/repositories.py similarity index 70% rename from src/assistant/lib/stt/openai_speech.py rename to src/assistant/lib/stt/repositories.py index bef5530..61006bf 100644 --- a/src/assistant/lib/stt/openai_speech.py +++ b/src/assistant/lib/stt/repositories.py @@ -5,10 +5,10 @@ import magic import openai import lib.app.settings as app_settings -import lib.models as models +import lib.stt as stt -class OpenaiSpeech: +class OpenaiSpeechRepository: def __init__(self, settings: app_settings.Settings): self.settings = settings openai.api_key = self.settings.openai.api_key.get_secret_value() @@ -22,13 +22,13 @@ class OpenaiSpeech: extension = extension.replace(".", "") return extension - async def recognize(self, audio: bytes) -> str: + async def speech_to_text(self, audio: bytes) -> str: file_extension: str | None = self.__get_file_extension_from_bytes(audio) if not file_extension: raise ValueError("File extension is not supported") - voice: models.SttVoice = models.SttVoice( - audio_size=int(len(audio) / 1024), + voice: stt.models.SttVoice = stt.models.SttVoice( + audio_size=len(audio) // 1024, # audio size in MB, audio_format=file_extension, audio_data=audio, voice_settings=self.settings.voice, @@ -38,10 +38,10 @@ class OpenaiSpeech: with tempfile.NamedTemporaryFile(suffix=f".{file_extension}") as temp_file: temp_file.write(voice.audio_data) temp_file.seek(0) - transcript = openai.Audio.transcribe("whisper-1", temp_file) # type: ignore - except openai.error.InvalidRequestError as e: # type: ignore + transcript = openai.Audio.transcribe(self.settings.openai.stt_model, temp_file) # type: ignore + except openai.error.InvalidRequestError as e: # type: ignore[reportGeneralTypeIssues] raise ValueError(f"OpenAI API error: {e}") - except openai.error.OpenAIError as e: # type: ignore + except openai.error.OpenAIError as e: # type: ignore[reportGeneralTypeIssues] raise ValueError(f"OpenAI API error: {e}") - return transcript.text # type: ignore + return transcript.text # type: ignore[reportUnknownVariableType] diff --git a/src/assistant/lib/stt/services.py b/src/assistant/lib/stt/services.py new file mode 100644 index 0000000..415de57 --- /dev/null +++ b/src/assistant/lib/stt/services.py @@ -0,0 +1,14 @@ +from typing import Protocol + + +class STTProtocol(Protocol): + async def speech_to_text(self, audio: bytes) -> str: + ... + + +class SpeechService: + def __init__(self, repository: STTProtocol): + self.repository = repository + + async def recognize(self, audio: bytes) -> str: + return await self.repository.speech_to_text(audio) diff --git a/src/assistant/lib/stt/stt_protocol.py b/src/assistant/lib/stt/stt_protocol.py deleted file mode 100644 index 87c6d75..0000000 --- a/src/assistant/lib/stt/stt_protocol.py +++ /dev/null @@ -1,6 +0,0 @@ -from typing import Protocol - - -class STTProtocol(Protocol): - async def recognize(self, audio: bytes) -> str: - ... diff --git a/src/assistant/poetry.lock b/src/assistant/poetry.lock index a3c3eea..fb1baf8 100644 --- a/src/assistant/poetry.lock +++ b/src/assistant/poetry.lock @@ -1713,6 +1713,12 @@ files = [ {file = "SQLAlchemy-2.0.21-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b69f1f754d92eb1cc6b50938359dead36b96a1dcf11a8670bff65fd9b21a4b09"}, {file = "SQLAlchemy-2.0.21-cp311-cp311-win32.whl", hash = "sha256:af520a730d523eab77d754f5cf44cc7dd7ad2d54907adeb3233177eeb22f271b"}, {file = "SQLAlchemy-2.0.21-cp311-cp311-win_amd64.whl", hash = "sha256:141675dae56522126986fa4ca713739d00ed3a6f08f3c2eb92c39c6dfec463ce"}, + {file = "SQLAlchemy-2.0.21-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:56628ca27aa17b5890391ded4e385bf0480209726f198799b7e980c6bd473bd7"}, + {file = "SQLAlchemy-2.0.21-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:db726be58837fe5ac39859e0fa40baafe54c6d54c02aba1d47d25536170b690f"}, + {file = "SQLAlchemy-2.0.21-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:632784f7a6f12cfa0e84bf2a5003b07660addccf5563c132cd23b7cc1d7371a9"}, + {file = "SQLAlchemy-2.0.21-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2a1f7ffac934bc0ea717fa1596f938483fb8c402233f9b26679b4f7b38d6ab6e"}, + {file = "SQLAlchemy-2.0.21-cp312-cp312-win32.whl", hash = "sha256:bfece2f7cec502ec5f759bbc09ce711445372deeac3628f6fa1c16b7fb45b682"}, + {file = "SQLAlchemy-2.0.21-cp312-cp312-win_amd64.whl", hash = "sha256:526b869a0f4f000d8d8ee3409d0becca30ae73f494cbb48801da0129601f72c6"}, {file = "SQLAlchemy-2.0.21-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:7614f1eab4336df7dd6bee05bc974f2b02c38d3d0c78060c5faa4cd1ca2af3b8"}, {file = "SQLAlchemy-2.0.21-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d59cb9e20d79686aa473e0302e4a82882d7118744d30bb1dfb62d3c47141b3ec"}, {file = "SQLAlchemy-2.0.21-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a95aa0672e3065d43c8aa80080cdd5cc40fe92dc873749e6c1cf23914c4b83af"},