mirror of
https://github.com/ijaric/voice_assistant.git
synced 2025-05-24 14:33:26 +00:00
feat: review fixes
This commit is contained in:
parent
b5191e7601
commit
0a61b4ada4
|
@ -19,3 +19,4 @@ VOICE_MAX_INPUT_SIZE=5120 # 5MB
|
||||||
VOICE_MAX_INPUT_SECONDS=30
|
VOICE_MAX_INPUT_SECONDS=30
|
||||||
|
|
||||||
OPENAI_API_KEY=sk-1234567890
|
OPENAI_API_KEY=sk-1234567890
|
||||||
|
OPENAI_STT_MODEL=whisper-1
|
||||||
|
|
|
@ -61,6 +61,7 @@ class Application:
|
||||||
# Repositories
|
# Repositories
|
||||||
|
|
||||||
logger.info("Initializing repositories")
|
logger.info("Initializing repositories")
|
||||||
|
stt_repository: stt.STTProtocol = stt.OpenaiSpeechRepository(settings=settings)
|
||||||
|
|
||||||
# Caches
|
# Caches
|
||||||
|
|
||||||
|
@ -69,7 +70,7 @@ class Application:
|
||||||
# Services
|
# Services
|
||||||
|
|
||||||
logger.info("Initializing services")
|
logger.info("Initializing services")
|
||||||
stt_service: stt.STTProtocol = stt.OpenaiSpeech(settings=settings) # type: ignore
|
stt_service: stt.SpeechService = stt.SpeechService(repository=stt_repository) # type: ignore
|
||||||
|
|
||||||
# Handlers
|
# Handlers
|
||||||
|
|
||||||
|
|
|
@ -15,3 +15,4 @@ class OpenaiSettings(pydantic_settings.BaseSettings):
|
||||||
api_key: pydantic.SecretStr = pydantic.Field(
|
api_key: pydantic.SecretStr = pydantic.Field(
|
||||||
default=..., validation_alias=pydantic.AliasChoices("api_key", "openai_api_key")
|
default=..., validation_alias=pydantic.AliasChoices("api_key", "openai_api_key")
|
||||||
)
|
)
|
||||||
|
stt_model: str = "whisper-1"
|
||||||
|
|
|
@ -1,5 +1,8 @@
|
||||||
from .orm import Base, IdCreatedUpdatedBaseMixin
|
from .orm import Base, IdCreatedUpdatedBaseMixin
|
||||||
from .stt_voice import SttVoice
|
|
||||||
from .token import Token
|
from .token import Token
|
||||||
|
|
||||||
__all__ = ["Base", "IdCreatedUpdatedBaseMixin", "SttVoice", "Token"]
|
__all__ = [
|
||||||
|
"Base",
|
||||||
|
"IdCreatedUpdatedBaseMixin",
|
||||||
|
"Token"]
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,10 @@
|
||||||
from .openai_speech import *
|
from .models import *
|
||||||
from .stt_protocol import *
|
from .repositories import *
|
||||||
|
from .services import *
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"OpenaiSpeech",
|
"OpenaiSpeechRepository",
|
||||||
"stt_protocol",
|
"STTProtocol",
|
||||||
|
"SpeechService",
|
||||||
|
"SttVoice",
|
||||||
]
|
]
|
||||||
|
|
|
@ -8,7 +8,7 @@ import lib.app.split_settings as app_split_settings
|
||||||
class SttVoice(pydantic.BaseModel):
|
class SttVoice(pydantic.BaseModel):
|
||||||
audio_size: int
|
audio_size: int
|
||||||
audio_format: str
|
audio_format: str
|
||||||
audio_name: str = "123"
|
audio_name: str = "voice"
|
||||||
audio_data: bytes
|
audio_data: bytes
|
||||||
voice_settings: app_split_settings.VoiceSettings
|
voice_settings: app_split_settings.VoiceSettings
|
||||||
|
|
|
@ -5,10 +5,10 @@ import magic
|
||||||
import openai
|
import openai
|
||||||
|
|
||||||
import lib.app.settings as app_settings
|
import lib.app.settings as app_settings
|
||||||
import lib.models as models
|
import lib.stt as stt
|
||||||
|
|
||||||
|
|
||||||
class OpenaiSpeech:
|
class OpenaiSpeechRepository:
|
||||||
def __init__(self, settings: app_settings.Settings):
|
def __init__(self, settings: app_settings.Settings):
|
||||||
self.settings = settings
|
self.settings = settings
|
||||||
openai.api_key = self.settings.openai.api_key.get_secret_value()
|
openai.api_key = self.settings.openai.api_key.get_secret_value()
|
||||||
|
@ -22,13 +22,13 @@ class OpenaiSpeech:
|
||||||
extension = extension.replace(".", "")
|
extension = extension.replace(".", "")
|
||||||
return extension
|
return extension
|
||||||
|
|
||||||
async def recognize(self, audio: bytes) -> str:
|
async def speech_to_text(self, audio: bytes) -> str:
|
||||||
file_extension: str | None = self.__get_file_extension_from_bytes(audio)
|
file_extension: str | None = self.__get_file_extension_from_bytes(audio)
|
||||||
if not file_extension:
|
if not file_extension:
|
||||||
raise ValueError("File extension is not supported")
|
raise ValueError("File extension is not supported")
|
||||||
|
|
||||||
voice: models.SttVoice = models.SttVoice(
|
voice: stt.models.SttVoice = stt.models.SttVoice(
|
||||||
audio_size=int(len(audio) / 1024),
|
audio_size=len(audio) // 1024, # audio size in KB (bytes // 1024)
|
||||||
audio_format=file_extension,
|
audio_format=file_extension,
|
||||||
audio_data=audio,
|
audio_data=audio,
|
||||||
voice_settings=self.settings.voice,
|
voice_settings=self.settings.voice,
|
||||||
|
@ -38,10 +38,10 @@ class OpenaiSpeech:
|
||||||
with tempfile.NamedTemporaryFile(suffix=f".{file_extension}") as temp_file:
|
with tempfile.NamedTemporaryFile(suffix=f".{file_extension}") as temp_file:
|
||||||
temp_file.write(voice.audio_data)
|
temp_file.write(voice.audio_data)
|
||||||
temp_file.seek(0)
|
temp_file.seek(0)
|
||||||
transcript = openai.Audio.transcribe("whisper-1", temp_file) # type: ignore
|
transcript = openai.Audio.transcribe(self.settings.openai.stt_model, temp_file) # type: ignore
|
||||||
except openai.error.InvalidRequestError as e: # type: ignore
|
except openai.error.InvalidRequestError as e: # type: ignore[reportGeneralTypeIssues]
|
||||||
raise ValueError(f"OpenAI API error: {e}")
|
raise ValueError(f"OpenAI API error: {e}")
|
||||||
except openai.error.OpenAIError as e: # type: ignore
|
except openai.error.OpenAIError as e: # type: ignore[reportGeneralTypeIssues]
|
||||||
raise ValueError(f"OpenAI API error: {e}")
|
raise ValueError(f"OpenAI API error: {e}")
|
||||||
|
|
||||||
return transcript.text # type: ignore
|
return transcript.text # type: ignore[reportUnknownVariableType]
|
14
src/assistant/lib/stt/services.py
Normal file
14
src/assistant/lib/stt/services.py
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
|
||||||
|
class STTProtocol(Protocol):
|
||||||
|
async def speech_to_text(self, audio: bytes) -> str:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class SpeechService:
|
||||||
|
def __init__(self, repository: STTProtocol):
|
||||||
|
self.repository = repository
|
||||||
|
|
||||||
|
async def recognize(self, audio: bytes) -> str:
|
||||||
|
return await self.repository.speech_to_text(audio)
|
|
@ -1,6 +0,0 @@
|
||||||
from typing import Protocol
|
|
||||||
|
|
||||||
|
|
||||||
class STTProtocol(Protocol):
|
|
||||||
async def recognize(self, audio: bytes) -> str:
|
|
||||||
...
|
|
6
src/assistant/poetry.lock
generated
6
src/assistant/poetry.lock
generated
|
@ -1713,6 +1713,12 @@ files = [
|
||||||
{file = "SQLAlchemy-2.0.21-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b69f1f754d92eb1cc6b50938359dead36b96a1dcf11a8670bff65fd9b21a4b09"},
|
{file = "SQLAlchemy-2.0.21-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b69f1f754d92eb1cc6b50938359dead36b96a1dcf11a8670bff65fd9b21a4b09"},
|
||||||
{file = "SQLAlchemy-2.0.21-cp311-cp311-win32.whl", hash = "sha256:af520a730d523eab77d754f5cf44cc7dd7ad2d54907adeb3233177eeb22f271b"},
|
{file = "SQLAlchemy-2.0.21-cp311-cp311-win32.whl", hash = "sha256:af520a730d523eab77d754f5cf44cc7dd7ad2d54907adeb3233177eeb22f271b"},
|
||||||
{file = "SQLAlchemy-2.0.21-cp311-cp311-win_amd64.whl", hash = "sha256:141675dae56522126986fa4ca713739d00ed3a6f08f3c2eb92c39c6dfec463ce"},
|
{file = "SQLAlchemy-2.0.21-cp311-cp311-win_amd64.whl", hash = "sha256:141675dae56522126986fa4ca713739d00ed3a6f08f3c2eb92c39c6dfec463ce"},
|
||||||
|
{file = "SQLAlchemy-2.0.21-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:56628ca27aa17b5890391ded4e385bf0480209726f198799b7e980c6bd473bd7"},
|
||||||
|
{file = "SQLAlchemy-2.0.21-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:db726be58837fe5ac39859e0fa40baafe54c6d54c02aba1d47d25536170b690f"},
|
||||||
|
{file = "SQLAlchemy-2.0.21-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:632784f7a6f12cfa0e84bf2a5003b07660addccf5563c132cd23b7cc1d7371a9"},
|
||||||
|
{file = "SQLAlchemy-2.0.21-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2a1f7ffac934bc0ea717fa1596f938483fb8c402233f9b26679b4f7b38d6ab6e"},
|
||||||
|
{file = "SQLAlchemy-2.0.21-cp312-cp312-win32.whl", hash = "sha256:bfece2f7cec502ec5f759bbc09ce711445372deeac3628f6fa1c16b7fb45b682"},
|
||||||
|
{file = "SQLAlchemy-2.0.21-cp312-cp312-win_amd64.whl", hash = "sha256:526b869a0f4f000d8d8ee3409d0becca30ae73f494cbb48801da0129601f72c6"},
|
||||||
{file = "SQLAlchemy-2.0.21-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:7614f1eab4336df7dd6bee05bc974f2b02c38d3d0c78060c5faa4cd1ca2af3b8"},
|
{file = "SQLAlchemy-2.0.21-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:7614f1eab4336df7dd6bee05bc974f2b02c38d3d0c78060c5faa4cd1ca2af3b8"},
|
||||||
{file = "SQLAlchemy-2.0.21-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d59cb9e20d79686aa473e0302e4a82882d7118744d30bb1dfb62d3c47141b3ec"},
|
{file = "SQLAlchemy-2.0.21-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d59cb9e20d79686aa473e0302e4a82882d7118744d30bb1dfb62d3c47141b3ec"},
|
||||||
{file = "SQLAlchemy-2.0.21-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a95aa0672e3065d43c8aa80080cdd5cc40fe92dc873749e6c1cf23914c4b83af"},
|
{file = "SQLAlchemy-2.0.21-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a95aa0672e3065d43c8aa80080cdd5cc40fe92dc873749e6c1cf23914c4b83af"},
|
||||||
|
|
Loading…
Reference in New Issue
Block a user