1
0
mirror of https://github.com/ijaric/voice_assistant.git synced 2025-12-19 14:16:18 +00:00

feat: review fixes

This commit is contained in:
2023-10-07 05:53:49 +03:00
parent b5191e7601
commit 0a61b4ada4
10 changed files with 46 additions and 23 deletions

View File

@@ -1,7 +1,10 @@
# Package public API re-exports.
# NOTE(review): diff artifact — the two star imports below appear to be the
# pre-commit lines; the three that follow look like their replacements. Only
# one set exists in the real file; confirm against the repository.
from .openai_speech import *
from .stt_protocol import *
from .models import *
from .repositories import *
from .services import *
# Explicit export list. NOTE(review): "OpenaiSpeech" and "stt_protocol" appear
# to be the removed (pre-rename) entries; the remaining four match the names
# introduced elsewhere in this commit.
__all__ = [
"OpenaiSpeech",
"stt_protocol",
"OpenaiSpeechRepository",
"STTProtocol",
"SpeechService",
"SttVoice",
]

View File

@@ -0,0 +1,25 @@
import typing
import pydantic
import lib.app.split_settings as app_split_settings
class SttVoice(pydantic.BaseModel):
    """Validated payload for one speech-to-text request.

    Pre-validation rejects audio that exceeds the configured size limit or
    uses an unsupported format, and derives a default file name when the
    caller did not supply one.
    """

    # Payload size as an integer; the caller computes it as len(bytes) // 1024,
    # so this presumably holds KiB — confirm intended units against settings.
    audio_size: int
    audio_format: str
    audio_name: str = "voice"
    audio_data: bytes
    voice_settings: app_split_settings.VoiceSettings

    @pydantic.model_validator(mode="before")
    @classmethod
    def validate_audio(cls, v: dict[str, typing.Any]) -> dict[str, typing.Any]:
        """Check size/format limits and backfill a default audio name.

        Raises:
            ValueError: when the audio is too large or its format is not in
                the configured ``available_formats``.
        """
        limits: app_split_settings.VoiceSettings = v["voice_settings"]
        too_big = v["audio_size"] > limits.max_input_size
        if too_big:
            raise ValueError(f"Audio size is too big: {v['audio_size']}")
        unsupported = v["audio_format"] not in limits.available_formats
        if unsupported:
            raise ValueError(f"Audio format is not supported: {v['audio_format']}")
        # Backfill a name derived from the format when none (or an empty one)
        # was provided.
        if not v.get("audio_name"):
            v["audio_name"] = f"audio.{v['audio_format']}"
        return v

View File

@@ -5,10 +5,10 @@ import magic
import openai
import lib.app.settings as app_settings
import lib.models as models
import lib.stt as stt
class OpenaiSpeech:
class OpenaiSpeechRepository:
# NOTE(review): diff artifact — the two class lines above are the old and the
# new name of the same class; only `OpenaiSpeechRepository` exists after this
# commit. The block below mixes removed and added lines from the diff view.
def __init__(self, settings: app_settings.Settings):
# Keep the settings object and configure the module-global OpenAI API key
# from its secret value.
self.settings = settings
openai.api_key = self.settings.openai.api_key.get_secret_value()
@@ -22,13 +22,13 @@ class OpenaiSpeech:
extension = extension.replace(".", "")
return extension
# NOTE(review): rename in this commit — old `recognize` became
# `speech_to_text`; only the latter exists in the final file.
async def recognize(self, audio: bytes) -> str:
async def speech_to_text(self, audio: bytes) -> str:
# Detect the container format from the raw bytes; reject unknown formats
# before contacting the API.
file_extension: str | None = self.__get_file_extension_from_bytes(audio)
if not file_extension:
raise ValueError("File extension is not supported")
# NOTE(review): old vs new construction of the SttVoice model. The new line
# switches to floor division, but `len(audio) // 1024` yields KiB, not MB —
# the trailing "audio size in MB" comment looks wrong; confirm the unit that
# `VoiceSettings.max_input_size` expects.
voice: models.SttVoice = models.SttVoice(
audio_size=int(len(audio) / 1024),
voice: stt.models.SttVoice = stt.models.SttVoice(
audio_size=len(audio) // 1024, # audio size in MB,
audio_format=file_extension,
audio_data=audio,
voice_settings=self.settings.voice,
@@ -38,10 +38,10 @@ class OpenaiSpeech:
# Whisper needs a real file with a correct suffix; the temp file is
# deleted automatically when the `with` block exits.
with tempfile.NamedTemporaryFile(suffix=f".{file_extension}") as temp_file:
temp_file.write(voice.audio_data)
temp_file.seek(0)
# NOTE(review): old call hard-coded "whisper-1"; the new call reads the
# model name from settings. OpenAI errors are wrapped into ValueError so
# callers need not depend on the openai package's exception types.
transcript = openai.Audio.transcribe("whisper-1", temp_file) # type: ignore
except openai.error.InvalidRequestError as e: # type: ignore
transcript = openai.Audio.transcribe(self.settings.openai.stt_model, temp_file) # type: ignore
except openai.error.InvalidRequestError as e: # type: ignore[reportGeneralTypeIssues]
raise ValueError(f"OpenAI API error: {e}")
except openai.error.OpenAIError as e: # type: ignore
except openai.error.OpenAIError as e: # type: ignore[reportGeneralTypeIssues]
raise ValueError(f"OpenAI API error: {e}")
return transcript.text # type: ignore
return transcript.text # type: ignore[reportUnknownVariableType]

View File

@@ -0,0 +1,14 @@
from typing import Protocol
class STTProtocol(Protocol):
    """Structural interface for any speech-to-text backend."""

    async def speech_to_text(self, audio: bytes) -> str:
        ...


class SpeechService:
    """Application-layer facade that delegates recognition to an STT backend.

    Any object satisfying ``STTProtocol`` (duck-typed, no inheritance needed)
    can serve as the repository.
    """

    def __init__(self, repository: STTProtocol):
        # The injected backend that performs the actual transcription.
        self.repository = repository

    async def recognize(self, audio: bytes) -> str:
        """Transcribe raw audio bytes via the configured backend."""
        text = await self.repository.speech_to_text(audio)
        return text

View File

@@ -1,6 +0,0 @@
from typing import Protocol
class STTProtocol(Protocol):
    """Structural interface for speech recognizers (pre-rename version).

    NOTE(review): this file is deleted by the commit; `recognize` was renamed
    to `speech_to_text` in the protocol that replaces it.
    """

    async def recognize(self, audio: bytes) -> str:
        ...