mirror of
https://github.com/ijaric/voice_assistant.git
synced 2025-12-17 11:46:20 +00:00
feat: review fixes
This commit is contained in:
@@ -61,6 +61,7 @@ class Application:
|
||||
# Repositories
|
||||
|
||||
logger.info("Initializing repositories")
|
||||
stt_repository: stt.STTProtocol = stt.OpenaiSpeechRepository(settings=settings)
|
||||
|
||||
# Caches
|
||||
|
||||
@@ -69,7 +70,7 @@ class Application:
|
||||
# Services
|
||||
|
||||
logger.info("Initializing services")
|
||||
stt_service: stt.STTProtocol = stt.OpenaiSpeech(settings=settings) # type: ignore
|
||||
stt_service: stt.SpeechService = stt.SpeechService(repository=stt_repository) # type: ignore
|
||||
|
||||
# Handlers
|
||||
|
||||
|
||||
@@ -15,3 +15,4 @@ class OpenaiSettings(pydantic_settings.BaseSettings):
|
||||
api_key: pydantic.SecretStr = pydantic.Field(
|
||||
default=..., validation_alias=pydantic.AliasChoices("api_key", "openai_api_key")
|
||||
)
|
||||
stt_model: str = "whisper-1"
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
from .orm import Base, IdCreatedUpdatedBaseMixin
|
||||
from .stt_voice import SttVoice
|
||||
from .token import Token
|
||||
|
||||
__all__ = ["Base", "IdCreatedUpdatedBaseMixin", "SttVoice", "Token"]
|
||||
__all__ = [
|
||||
"Base",
|
||||
"IdCreatedUpdatedBaseMixin",
|
||||
"Token"]
|
||||
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
from .openai_speech import *
|
||||
from .stt_protocol import *
|
||||
from .models import *
|
||||
from .repositories import *
|
||||
from .services import *
|
||||
|
||||
__all__ = [
|
||||
"OpenaiSpeech",
|
||||
"stt_protocol",
|
||||
"OpenaiSpeechRepository",
|
||||
"STTProtocol",
|
||||
"SpeechService",
|
||||
"SttVoice",
|
||||
]
|
||||
|
||||
@@ -8,7 +8,7 @@ import lib.app.split_settings as app_split_settings
|
||||
class SttVoice(pydantic.BaseModel):
|
||||
audio_size: int
|
||||
audio_format: str
|
||||
audio_name: str = "123"
|
||||
audio_name: str = "voice"
|
||||
audio_data: bytes
|
||||
voice_settings: app_split_settings.VoiceSettings
|
||||
|
||||
@@ -5,10 +5,10 @@ import magic
|
||||
import openai
|
||||
|
||||
import lib.app.settings as app_settings
|
||||
import lib.models as models
|
||||
import lib.stt as stt
|
||||
|
||||
|
||||
class OpenaiSpeech:
|
||||
class OpenaiSpeechRepository:
|
||||
def __init__(self, settings: app_settings.Settings):
|
||||
self.settings = settings
|
||||
openai.api_key = self.settings.openai.api_key.get_secret_value()
|
||||
@@ -22,13 +22,13 @@ class OpenaiSpeech:
|
||||
extension = extension.replace(".", "")
|
||||
return extension
|
||||
|
||||
async def recognize(self, audio: bytes) -> str:
|
||||
async def speech_to_text(self, audio: bytes) -> str:
|
||||
file_extension: str | None = self.__get_file_extension_from_bytes(audio)
|
||||
if not file_extension:
|
||||
raise ValueError("File extension is not supported")
|
||||
|
||||
voice: models.SttVoice = models.SttVoice(
|
||||
audio_size=int(len(audio) / 1024),
|
||||
voice: stt.models.SttVoice = stt.models.SttVoice(
|
||||
audio_size=len(audio) // 1024, # audio size in MB,
|
||||
audio_format=file_extension,
|
||||
audio_data=audio,
|
||||
voice_settings=self.settings.voice,
|
||||
@@ -38,10 +38,10 @@ class OpenaiSpeech:
|
||||
with tempfile.NamedTemporaryFile(suffix=f".{file_extension}") as temp_file:
|
||||
temp_file.write(voice.audio_data)
|
||||
temp_file.seek(0)
|
||||
transcript = openai.Audio.transcribe("whisper-1", temp_file) # type: ignore
|
||||
except openai.error.InvalidRequestError as e: # type: ignore
|
||||
transcript = openai.Audio.transcribe(self.settings.openai.stt_model, temp_file) # type: ignore
|
||||
except openai.error.InvalidRequestError as e: # type: ignore[reportGeneralTypeIssues]
|
||||
raise ValueError(f"OpenAI API error: {e}")
|
||||
except openai.error.OpenAIError as e: # type: ignore
|
||||
except openai.error.OpenAIError as e: # type: ignore[reportGeneralTypeIssues]
|
||||
raise ValueError(f"OpenAI API error: {e}")
|
||||
|
||||
return transcript.text # type: ignore
|
||||
return transcript.text # type: ignore[reportUnknownVariableType]
|
||||
14
src/assistant/lib/stt/services.py
Normal file
14
src/assistant/lib/stt/services.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class STTProtocol(Protocol):
|
||||
async def speech_to_text(self, audio: bytes) -> str:
|
||||
...
|
||||
|
||||
|
||||
class SpeechService:
|
||||
def __init__(self, repository: STTProtocol):
|
||||
self.repository = repository
|
||||
|
||||
async def recognize(self, audio: bytes) -> str:
|
||||
return await self.repository.speech_to_text(audio)
|
||||
@@ -1,6 +0,0 @@
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class STTProtocol(Protocol):
|
||||
async def recognize(self, audio: bytes) -> str:
|
||||
...
|
||||
Reference in New Issue
Block a user