Mirror of https://github.com/ijaric/voice_assistant.git (synced 2025-12-17 11:46:20 +00:00)

Merge branch 'main' into tasks/#38_client_http

This commit is contained in:
Artem Litvinov
2023-10-10 20:09:02 +01:00
committed by GitHub
14 changed files with 845 additions and 90 deletions

View File

@@ -11,6 +11,7 @@ import lib.app.errors as app_errors
import lib.app.settings as app_settings
import lib.app.split_settings as app_split_settings
import lib.clients as clients
import lib.stt as stt
logger = logging.getLogger(__name__)
@@ -71,6 +72,7 @@ class Application:
# Repositories
logger.info("Initializing repositories")
stt_repository: stt.STTProtocol = stt.OpenaiSpeechRepository(settings=settings)
# Caches
@@ -79,6 +81,7 @@ class Application:
# Services
logger.info("Initializing services")
stt_service: stt.SpeechService = stt.SpeechService(repository=stt_repository) # type: ignore
# Handlers

View File

@@ -13,7 +13,12 @@ class Settings(pydantic_settings.BaseSettings):
logger: app_split_settings.LoggingSettings = pydantic.Field(
default_factory=lambda: app_split_settings.LoggingSettings()
)
openai: app_split_settings.OpenaiSettings = pydantic.Field(
default_factory=lambda: app_split_settings.OpenaiSettings()
)
project: app_split_settings.ProjectSettings = pydantic.Field(
default_factory=lambda: app_split_settings.ProjectSettings()
)
proxy: app_split_settings.ProxySettings = pydantic.Field(default_factory=lambda: app_split_settings.ProxySettings())
voice: app_split_settings.VoiceSettings = pydantic.Field(default_factory=lambda: app_split_settings.VoiceSettings())

View File

@@ -1,16 +1,21 @@
from .api import *
from .app import *
from .logger import *
from .openai import *
from .postgres import *
from .project import *
from .proxy import *
from .voice import *
# Public API of the split-settings package: the names re-exported by the
# star-imports above, one settings class per concern plus the logging helper.
__all__ = [
    "ApiSettings",
    "AppSettings",
    "LoggingSettings",
    "OpenaiSettings",
    "PostgresSettings",
    "ProjectSettings",
    "ProxySettings",
    "VoiceSettings",
    "get_logging_config",
]

View File

@@ -5,7 +5,9 @@ import lib.app.split_settings.utils as app_split_settings_utils
class LoggingSettings(pydantic_settings.BaseSettings):
model_config = pydantic_settings.SettingsConfigDict(
env_file=app_split_settings_utils.ENV_PATH, env_file_encoding="utf-8", extra="ignore"
env_file=app_split_settings_utils.ENV_PATH,
env_file_encoding="utf-8",
extra="ignore",
)
log_format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

View File

@@ -0,0 +1,18 @@
import pydantic
import pydantic_settings
import lib.app.split_settings.utils as app_split_settings_utils
class OpenaiSettings(pydantic_settings.BaseSettings):
    """OpenAI API configuration, loaded from the shared .env file."""

    model_config = pydantic_settings.SettingsConfigDict(
        env_file=app_split_settings_utils.ENV_PATH,
        env_prefix="OPENAI_",
        env_file_encoding="utf-8",
        extra="ignore",
    )

    # Required secret (default=... means no fallback).  The alias choices let
    # the key arrive as either "api_key" or "openai_api_key"; presumably this
    # accommodates both prefixed and unprefixed env names — TODO confirm which
    # env variable spelling is canonical in deployment.
    api_key: pydantic.SecretStr = pydantic.Field(
        default=..., validation_alias=pydantic.AliasChoices("api_key", "openai_api_key")
    )
    # Model name passed to the transcription endpoint.
    stt_model: str = "whisper-1"

View File

@@ -0,0 +1,21 @@
import pydantic
import pydantic_settings
import lib.app.split_settings.utils as app_split_settings_utils
class VoiceSettings(pydantic_settings.BaseSettings):
    """Constraints for user-supplied voice recordings (VOICE_* env vars)."""

    model_config = pydantic_settings.SettingsConfigDict(
        env_file=app_split_settings_utils.ENV_PATH,
        env_prefix="VOICE_",
        env_file_encoding="utf-8",
        extra="ignore",
    )

    max_input_seconds: int = 30
    max_input_size: int = 5120  # KiB, i.e. 5 MB
    # Comma-separated in the environment (e.g. "wav,mp3,ogg").  The field is a
    # list[str] after validation; validate_default=True forces the default
    # through the same normalization, so membership tests like
    # `fmt in settings.available_formats` compare whole formats instead of
    # being a substring check on the raw default string.
    available_formats: list[str] = pydantic.Field(default="wav,mp3,ogg", validate_default=True)

    @pydantic.field_validator("available_formats", mode="before")
    @classmethod
    def validate_available_formats(cls, v: str | list[str]) -> list[str]:
        """Normalize the comma-separated env value (or an existing list) to a list."""
        return v.split(",") if isinstance(v, str) else list(v)

View File

@@ -0,0 +1,10 @@
from .models import *
from .repositories import *
from .services import *
# Public names of the stt package, re-exported from the models, repositories,
# and services submodules star-imported above.
__all__ = [
    "OpenaiSpeechRepository",
    "STTProtocol",
    "SpeechService",
    "SttVoice",
]

View File

@@ -0,0 +1,25 @@
import typing
import pydantic
import lib.app.split_settings as app_split_settings
class SttVoice(pydantic.BaseModel):
    """A voice recording validated against the configured VoiceSettings limits."""

    audio_size: int  # size in KiB — callers pass len(audio) // 1024
    audio_format: str  # bare file extension, e.g. "wav"
    audio_name: str = "voice"
    audio_data: bytes
    voice_settings: app_split_settings.VoiceSettings

    @pydantic.model_validator(mode="before")
    @classmethod
    def validate_audio(cls, v: dict[str, typing.Any]) -> dict[str, typing.Any]:
        """Check size/format limits and default the audio file name.

        Raises ValueError (reported by pydantic as a ValidationError) when a
        required key is missing or a limit is exceeded.
        """
        # Raise ValueError for missing keys rather than letting a raw KeyError
        # escape the validator, so pydantic surfaces a normal validation error.
        for key in ("voice_settings", "audio_size", "audio_format"):
            if key not in v:
                raise ValueError(f"Missing required field: {key}")
        settings: app_split_settings.VoiceSettings = v["voice_settings"]
        if v["audio_size"] > settings.max_input_size:
            raise ValueError(f"Audio size is too big: {v['audio_size']}")
        if v["audio_format"] not in settings.available_formats:
            raise ValueError(f"Audio format is not supported: {v['audio_format']}")
        # Default the name from the format when absent or empty.
        if not v.get("audio_name"):
            v["audio_name"] = f"audio.{v['audio_format']}"
        return v

View File

@@ -0,0 +1,47 @@
import mimetypes
import tempfile
import magic
import openai
import lib.app.settings as app_settings
import lib.stt as stt
class OpenaiSpeechRepository:
    """Speech-to-text repository backed by the OpenAI transcription API."""

    def __init__(self, settings: app_settings.Settings):
        self.settings = settings
        # The legacy openai SDK is configured via a module-level API key.
        openai.api_key = self.settings.openai.api_key.get_secret_value()

    @staticmethod
    def __get_file_extension_from_bytes(audio: bytes) -> str | None:
        """Sniff the MIME type of *audio* and map it to a bare extension.

        Returns e.g. "wav", or None when the type cannot be mapped.
        """
        mime: magic.Magic = magic.Magic(mime=True)
        mime_type: str = mime.from_buffer(audio)
        extension: str | None = mimetypes.guess_extension(mime_type)
        # mimetypes yields ".wav"; strip the dot to match the settings' format names.
        return extension.replace(".", "") if extension else None

    async def speech_to_text(self, audio: bytes) -> str:
        """Transcribe raw audio bytes to text.

        Raises ValueError when the content type is unsupported, the recording
        violates the configured limits, or the OpenAI call fails.
        """
        file_extension = self.__get_file_extension_from_bytes(audio)
        if not file_extension:
            raise ValueError("File extension is not supported")
        voice: stt.models.SttVoice = stt.models.SttVoice(
            audio_size=len(audio) // 1024,  # size in KiB (max_input_size is in KiB)
            audio_format=file_extension,
            audio_data=audio,
            voice_settings=self.settings.voice,
        )
        try:
            # The API expects a named file object, so spool the bytes through a
            # temp file carrying the sniffed suffix.
            with tempfile.NamedTemporaryFile(suffix=f".{file_extension}") as temp_file:
                temp_file.write(voice.audio_data)
                temp_file.seek(0)
                transcript = openai.Audio.transcribe(self.settings.openai.stt_model, temp_file)  # type: ignore
        # InvalidRequestError subclasses OpenAIError, so a single handler covers
        # both; chain the cause for easier debugging.
        except openai.error.OpenAIError as e:  # type: ignore[reportGeneralTypeIssues]
            raise ValueError(f"OpenAI API error: {e}") from e
        return transcript.text  # type: ignore[reportUnknownVariableType]

View File

@@ -0,0 +1,14 @@
import typing
class STTProtocol(typing.Protocol):
    """Structural interface for speech-to-text backends."""

    async def speech_to_text(self, audio: bytes) -> str:
        """Transcribe raw audio bytes into text."""
        ...
class SpeechService:
    """Application-facing speech service that delegates to a pluggable STT backend."""

    def __init__(self, repository: STTProtocol):
        # Any object exposing an async speech_to_text(bytes) -> str works here.
        self.repository = repository

    async def recognize(self, audio: bytes) -> str:
        """Return the backend's transcription of *audio*."""
        text = await self.repository.speech_to_text(audio)
        return text