mirror of
				https://github.com/ijaric/voice_assistant.git
				synced 2025-10-31 01:33:25 +00:00 
			
		
		
		
	fix: work in progess: models & migrations & repos
This commit is contained in:
		
							parent
							
								
									b75af6034b
								
							
						
					
					
						commit
						2db4a87bf4
					
				|  | @ -6,7 +6,6 @@ from sqlalchemy.engine import Connection | ||||||
| from sqlalchemy.ext.asyncio import async_engine_from_config | from sqlalchemy.ext.asyncio import async_engine_from_config | ||||||
| 
 | 
 | ||||||
| import lib.app.settings as app_settings | import lib.app.settings as app_settings | ||||||
| import lib.models as models |  | ||||||
| import lib.orm_models as orm_models | import lib.orm_models as orm_models | ||||||
| from alembic import context | from alembic import context | ||||||
| 
 | 
 | ||||||
|  | @ -24,7 +23,7 @@ print("BASE: ", orm_models.Base.metadata.schema) | ||||||
| for t in orm_models.Base.metadata.sorted_tables: | for t in orm_models.Base.metadata.sorted_tables: | ||||||
|     print(t.name) |     print(t.name) | ||||||
| 
 | 
 | ||||||
| target_metadata = models.Base.metadata | target_metadata = orm_models.Base.metadata | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def run_migrations_offline() -> None: | def run_migrations_offline() -> None: | ||||||
|  |  | ||||||
|  | @ -24,7 +24,7 @@ def upgrade() -> None: | ||||||
|     op.create_table( |     op.create_table( | ||||||
|         "chat_history", |         "chat_history", | ||||||
|         sa.Column("id", sa.Uuid(), nullable=False), |         sa.Column("id", sa.Uuid(), nullable=False), | ||||||
|         sa.Column("session_id", sa.String(), nullable=False), |         sa.Column("session_id", sa.Uuid(), nullable=False), | ||||||
|         sa.Column("channel", sa.String(), nullable=False), |         sa.Column("channel", sa.String(), nullable=False), | ||||||
|         sa.Column("user_id", sa.String(), nullable=False), |         sa.Column("user_id", sa.String(), nullable=False), | ||||||
|         sa.Column("content", sa.JSON(), nullable=False), |         sa.Column("content", sa.JSON(), nullable=False), | ||||||
|  | @ -33,39 +33,6 @@ def upgrade() -> None: | ||||||
|         sa.PrimaryKeyConstraint("id"), |         sa.PrimaryKeyConstraint("id"), | ||||||
|         schema="content", |         schema="content", | ||||||
|     ) |     ) | ||||||
|     op.drop_table("auth_group") |  | ||||||
|     op.drop_table("auth_user_groups") |  | ||||||
|     op.drop_table("auth_group_permissions") |  | ||||||
|     op.drop_table("auth_user_user_permissions") |  | ||||||
|     op.drop_table("auth_user") |  | ||||||
|     op.drop_table("django_content_type") |  | ||||||
|     op.drop_table("auth_permission") |  | ||||||
|     op.drop_table("django_session") |  | ||||||
|     op.drop_table("django_admin_log") |  | ||||||
|     op.drop_table("django_migrations") |  | ||||||
|     op.alter_column("film_work", "title", existing_type=sa.TEXT(), type_=sa.String(), existing_nullable=False) |  | ||||||
|     op.alter_column("film_work", "description", existing_type=sa.TEXT(), type_=sa.String(), existing_nullable=True) |  | ||||||
|     op.alter_column("film_work", "creation_date", existing_type=sa.DATE(), type_=sa.DateTime(), existing_nullable=True) |  | ||||||
|     op.alter_column("film_work", "file_path", existing_type=sa.TEXT(), type_=sa.String(), existing_nullable=True) |  | ||||||
|     op.alter_column("film_work", "type", existing_type=sa.TEXT(), type_=sa.String(), existing_nullable=False) |  | ||||||
|     op.alter_column("film_work", "created", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=False) |  | ||||||
|     op.alter_column("film_work", "modified", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=False) |  | ||||||
|     op.alter_column("genre", "name", existing_type=sa.TEXT(), type_=sa.String(), existing_nullable=False) |  | ||||||
|     op.alter_column("genre", "description", existing_type=sa.TEXT(), type_=sa.String(), existing_nullable=True) |  | ||||||
|     op.alter_column("genre", "created", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=False) |  | ||||||
|     op.alter_column("genre", "modified", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=False) |  | ||||||
|     op.create_foreign_key(None, "genre_film_work", "genre", ["genre_id"], ["id"], referent_schema="content") |  | ||||||
|     op.create_foreign_key(None, "genre_film_work", "film_work", ["film_work_id"], ["id"], referent_schema="content") |  | ||||||
|     op.alter_column("person", "full_name", existing_type=sa.TEXT(), type_=sa.String(), existing_nullable=False) |  | ||||||
|     op.alter_column("person", "created", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=False) |  | ||||||
|     op.alter_column("person", "modified", existing_type=postgresql.TIMESTAMP(timezone=True), nullable=False) |  | ||||||
|     op.alter_column( |  | ||||||
|         "person_film_work", "role", existing_type=sa.TEXT(), type_=sa.String(length=50), existing_nullable=False |  | ||||||
|     ) |  | ||||||
|     op.create_foreign_key(None, "person_film_work", "film_work", ["film_work_id"], ["id"], referent_schema="content") |  | ||||||
|     op.create_foreign_key(None, "person_film_work", "person", ["person_id"], ["id"], referent_schema="content") |  | ||||||
|     op.drop_column("person_film_work", "id") |  | ||||||
|     # ### end Alembic commands ### |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def downgrade() -> None: | def downgrade() -> None: | ||||||
|  |  | ||||||
|  | @ -15,7 +15,7 @@ class ChatHistoryRepository: | ||||||
|         self.logger = logging.getLogger(__name__) |         self.logger = logging.getLogger(__name__) | ||||||
| 
 | 
 | ||||||
|     async def get_last_session_id(self, request: models.RequestLastSessionId) -> uuid.UUID | None: |     async def get_last_session_id(self, request: models.RequestLastSessionId) -> uuid.UUID | None: | ||||||
|         """Get a new session ID.""" |         """Get a current session ID if exists.""" | ||||||
| 
 | 
 | ||||||
|         try: |         try: | ||||||
|             async with self.pg_async_session() as session: |             async with self.pg_async_session() as session: | ||||||
|  | @ -23,10 +23,7 @@ class ChatHistoryRepository: | ||||||
|                     sa.select(orm_models.ChatHistory) |                     sa.select(orm_models.ChatHistory) | ||||||
|                     .filter_by(channel=request.channel, user_id=request.user_id) |                     .filter_by(channel=request.channel, user_id=request.user_id) | ||||||
|                     .filter( |                     .filter( | ||||||
|                         ( |                         (sa.text("NOW()") - sa.func.extract("epoch", orm_models.ChatHistory.created)) / 60 | ||||||
|                             sa.func.extract("epoch", orm_models.ChatHistory.created) |  | ||||||
|                             - sa.func.extract("epoch", orm_models.ChatHistory.modified) / 60 |  | ||||||
|                         ) |  | ||||||
|                         <= request.minutes_ago |                         <= request.minutes_ago | ||||||
|                     ) |                     ) | ||||||
|                     .order_by(orm_models.ChatHistory.created.desc()) |                     .order_by(orm_models.ChatHistory.created.desc()) | ||||||
|  | @ -39,3 +36,48 @@ class ChatHistoryRepository: | ||||||
|                     return chat_session.id |                     return chat_session.id | ||||||
|         except sqlalchemy.exc.SQLAlchemyError as error: |         except sqlalchemy.exc.SQLAlchemyError as error: | ||||||
|             self.logger.exception("Error: %s", error) |             self.logger.exception("Error: %s", error) | ||||||
|  | 
 | ||||||
|  |     async def get_messages_by_sid(self, request: models.RequestChatHistory): | ||||||
|  |         """Get all messages of a chat by session ID.""" | ||||||
|  | 
 | ||||||
|  |         try: | ||||||
|  |             async with self.pg_async_session() as session: | ||||||
|  |                 statement = ( | ||||||
|  |                     sa.select(orm_models.ChatHistory) | ||||||
|  |                     .filter_by(id=request.session_id) | ||||||
|  |                     .order_by(orm_models.ChatHistory.created.desc()) | ||||||
|  |                 ) | ||||||
|  |                 result = await session.execute(statement) | ||||||
|  |                 for row in result.scalars().all(): | ||||||
|  |                     print("Row: ", row) | ||||||
|  |         except sqlalchemy.exc.SQLAlchemyError as error: | ||||||
|  |             self.logger.exception("Error: %s", error) | ||||||
|  | 
 | ||||||
|  |     # async def get_all_by_session_id(self, request: models.RequestChatHistory) -> list[models.ChatHistory]: | ||||||
|  |     #     try: | ||||||
|  |     #         async with self.pg_async_session() as session: | ||||||
|  |     #             statement = ( | ||||||
|  |     #                 sa.select(orm_models.ChatHistory) | ||||||
|  |     #                 .filter_by(id=request.session_id) | ||||||
|  |     #                 .order_by(orm_models.ChatHistory.created.desc()) | ||||||
|  |     #             ) | ||||||
|  |     #             result = await session.execute(statement) | ||||||
|  | 
 | ||||||
|  |     #             return [models.ChatHistory.from_orm(chat_history) for chat_history in result.scalars().all()] | ||||||
|  | 
 | ||||||
|  |     async def add_message(self, request: models.ChatMessage) -> None: | ||||||
|  |         """Add a message to the chat history.""" | ||||||
|  |         try: | ||||||
|  |             async with self.pg_async_session() as session: | ||||||
|  |                 chat_history = orm_models.ChatHistory( | ||||||
|  |                     id=uuid.uuid4(), | ||||||
|  |                     session_id=request.session_id, | ||||||
|  |                     user_id=request.user_id, | ||||||
|  |                     channel=request.channel, | ||||||
|  |                     content=request.message, | ||||||
|  |                 ) | ||||||
|  |                 session.add(chat_history) | ||||||
|  |                 await session.commit() | ||||||
|  |                 # TODO: Add refresh to session and return added object | ||||||
|  |         except sqlalchemy.exc.SQLAlchemyError as error: | ||||||
|  |             self.logger.exception("Error: %s", error) | ||||||
|  |  | ||||||
|  | @ -1,3 +1,5 @@ | ||||||
|  | import uuid | ||||||
|  | 
 | ||||||
| import fastapi | import fastapi | ||||||
| 
 | 
 | ||||||
| import lib.agent as agent | import lib.agent as agent | ||||||
|  | @ -15,9 +17,27 @@ class AgentHandler: | ||||||
|             summary="Статус работоспособности", |             summary="Статус работоспособности", | ||||||
|             description="Проверяет доступность сервиса FastAPI.", |             description="Проверяет доступность сервиса FastAPI.", | ||||||
|         ) |         ) | ||||||
|  |         self.router.add_api_route( | ||||||
|  |             "/add", | ||||||
|  |             self.add_message, | ||||||
|  |             methods=["GET"], | ||||||
|  |             summary="Статус работоспособности", | ||||||
|  |             description="Проверяет доступность сервиса FastAPI.", | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|     async def get_agent(self): |     async def get_agent(self): | ||||||
|         request = models.RequestLastSessionId(channel="test", user_id="test", minutes_ago=3) |         request = models.RequestLastSessionId(channel="test", user_id="user_id_1", minutes_ago=3) | ||||||
|         response = await self.chat_history_repository.get_last_session_id(request=request) |         response = await self.chat_history_repository.get_last_session_id(request=request) | ||||||
|         print("RESPONSE: ", response) |         print("RESPONSE: ", response) | ||||||
|         return {"response": response} |         return {"response": response} | ||||||
|  | 
 | ||||||
|  |     async def add_message(self): | ||||||
|  |         sid: uuid.UUID = uuid.UUID("0cd3c882-affd-4929-aff1-e1724f5b54f2") | ||||||
|  |         import faker | ||||||
|  |         fake = faker.Faker() | ||||||
|  |          | ||||||
|  |         message = models.ChatMessage( | ||||||
|  |             session_id=sid, user_id="user_id_1", channel="test", message={"role": "system", "content": fake.sentence()} | ||||||
|  |         ) | ||||||
|  |         await self.chat_history_repository.add_message(request=message) | ||||||
|  |         return {"response": "ok"} | ||||||
|  |  | ||||||
|  | @ -1,6 +1,6 @@ | ||||||
| from .chat_history import RequestLastSessionId | from .chat_history import ChatMessage, RequestChatHistory, RequestLastSessionId | ||||||
| from .embedding import Embedding | from .embedding import Embedding | ||||||
| from .movies import Movie | from .movies import Movie | ||||||
| from .token import Token | from .token import Token | ||||||
| 
 | 
 | ||||||
| __all__ = ["Embedding", "Movie", "RequestLastSessionId", "Token"] | __all__ = ["ChatMessage", "Embedding", "Movie", "RequestChatHistory", "RequestLastSessionId", "Token"] | ||||||
|  |  | ||||||
|  | @ -1,3 +1,5 @@ | ||||||
|  | import uuid | ||||||
|  | 
 | ||||||
| import pydantic | import pydantic | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -7,3 +9,18 @@ class RequestLastSessionId(pydantic.BaseModel): | ||||||
|     channel: str |     channel: str | ||||||
|     user_id: str |     user_id: str | ||||||
|     minutes_ago: int |     minutes_ago: int | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class ChatMessage(pydantic.BaseModel): | ||||||
|  |     """A chat message.""" | ||||||
|  | 
 | ||||||
|  |     session_id: uuid.UUID | ||||||
|  |     user_id: str | ||||||
|  |     channel: str | ||||||
|  |     message: dict[str, str] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class RequestChatHistory(pydantic.BaseModel): | ||||||
|  |     """Request for chat history.""" | ||||||
|  | 
 | ||||||
|  |     session_id: uuid.UUID | ||||||
|  |  | ||||||
|  | @ -1,5 +1,6 @@ | ||||||
| from .base import Base, IdCreatedUpdatedBaseMixin | from .base import Base, IdCreatedUpdatedBaseMixin | ||||||
| from .movies import ChatHistory, FilmWork, Genre, GenreFilmWork, Person, PersonFilmWork | from .chat_history import ChatHistory | ||||||
|  | from .movies import FilmWork, Genre, GenreFilmWork, Person, PersonFilmWork | ||||||
| 
 | 
 | ||||||
| __all__ = [ | __all__ = [ | ||||||
|     "Base", |     "Base", | ||||||
|  |  | ||||||
							
								
								
									
										24
									
								
								src/assistant/lib/orm_models/chat_history.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										24
									
								
								src/assistant/lib/orm_models/chat_history.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,24 @@ | ||||||
|  | import datetime | ||||||
|  | import uuid | ||||||
|  | 
 | ||||||
|  | import sqlalchemy as sa | ||||||
|  | import sqlalchemy.orm as sa_orm | ||||||
|  | import sqlalchemy.sql as sa_sql | ||||||
|  | 
 | ||||||
|  | import lib.orm_models.base as base_models | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class ChatHistory(base_models.Base): | ||||||
|  |     __tablename__: str = "chat_history"  # type: ignore[reportIncompatibleVariableOverride] | ||||||
|  | 
 | ||||||
|  |     id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4) | ||||||
|  |     session_id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(nullable=False, unique=True) | ||||||
|  |     channel: sa_orm.Mapped[str] = sa_orm.mapped_column() | ||||||
|  |     user_id: sa_orm.Mapped[str] = sa_orm.mapped_column() | ||||||
|  |     content: sa_orm.Mapped[sa.JSON] = sa_orm.mapped_column(sa.JSON) | ||||||
|  |     created: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column( | ||||||
|  |         sa.DateTime(timezone=True), server_default=sa_sql.func.now() | ||||||
|  |     ) | ||||||
|  |     modified: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column( | ||||||
|  |         sa.DateTime(timezone=True), server_default=sa_sql.func.now(), onupdate=sa_sql.func.now() | ||||||
|  |     ) | ||||||
|  | @ -74,19 +74,3 @@ PersonFilmWork = sa.Table( | ||||||
|     sa.Column("role", sa.String(50), nullable=False), |     sa.Column("role", sa.String(50), nullable=False), | ||||||
|     sa.Column("created", sa.DateTime(timezone=True), server_default=sa_sql.func.now()), |     sa.Column("created", sa.DateTime(timezone=True), server_default=sa_sql.func.now()), | ||||||
| ) | ) | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class ChatHistory(base_models.Base): |  | ||||||
|     __tablename__: str = "chat_history"  # type: ignore[reportIncompatibleVariableOverride] |  | ||||||
| 
 |  | ||||||
|     id: sa_orm.Mapped[uuid.UUID] = sa_orm.mapped_column(primary_key=True, default=uuid.uuid4) |  | ||||||
|     session_id: sa_orm.Mapped[str] = sa_orm.mapped_column() |  | ||||||
|     channel: sa_orm.Mapped[str] = sa_orm.mapped_column() |  | ||||||
|     user_id: sa_orm.Mapped[str] = sa_orm.mapped_column() |  | ||||||
|     content: sa_orm.Mapped[sa.JSON] = sa_orm.mapped_column(sa.JSON) |  | ||||||
|     created: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column( |  | ||||||
|         sa.DateTime(timezone=True), server_default=sa_sql.func.now() |  | ||||||
|     ) |  | ||||||
|     modified: sa_orm.Mapped[datetime.datetime] = sa_orm.mapped_column( |  | ||||||
|         sa.DateTime(timezone=True), server_default=sa_sql.func.now(), onupdate=sa_sql.func.now() |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
							
								
								
									
										30
									
								
								src/assistant/poetry.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										30
									
								
								src/assistant/poetry.lock
									
									
									
										generated
									
									
									
								
							|  | @ -529,6 +529,20 @@ files = [ | ||||||
| dnspython = ">=2.0.0" | dnspython = ">=2.0.0" | ||||||
| idna = ">=2.0.0" | idna = ">=2.0.0" | ||||||
| 
 | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "faker" | ||||||
|  | version = "19.10.0" | ||||||
|  | description = "Faker is a Python package that generates fake data for you." | ||||||
|  | optional = false | ||||||
|  | python-versions = ">=3.8" | ||||||
|  | files = [ | ||||||
|  |     {file = "Faker-19.10.0-py3-none-any.whl", hash = "sha256:f321e657ed61616fbfe14dbb9ccc6b2e8282652bbcfcb503c1bd0231ff834df6"}, | ||||||
|  |     {file = "Faker-19.10.0.tar.gz", hash = "sha256:63da90512d0cb3acdb71bd833bb3071cb8a196020d08b8567a01d232954f1820"}, | ||||||
|  | ] | ||||||
|  | 
 | ||||||
|  | [package.dependencies] | ||||||
|  | python-dateutil = ">=2.4" | ||||||
|  | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "fastapi" | name = "fastapi" | ||||||
| version = "0.103.1" | version = "0.103.1" | ||||||
|  | @ -1622,6 +1636,20 @@ pluggy = ">=0.12,<2.0" | ||||||
| [package.extras] | [package.extras] | ||||||
| testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] | testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] | ||||||
| 
 | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "python-dateutil" | ||||||
|  | version = "2.8.2" | ||||||
|  | description = "Extensions to the standard Python datetime module" | ||||||
|  | optional = false | ||||||
|  | python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" | ||||||
|  | files = [ | ||||||
|  |     {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, | ||||||
|  |     {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, | ||||||
|  | ] | ||||||
|  | 
 | ||||||
|  | [package.dependencies] | ||||||
|  | six = ">=1.5" | ||||||
|  | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "python-dotenv" | name = "python-dotenv" | ||||||
| version = "1.0.0" | version = "1.0.0" | ||||||
|  | @ -2244,4 +2272,4 @@ multidict = ">=4.0" | ||||||
| [metadata] | [metadata] | ||||||
| lock-version = "2.0" | lock-version = "2.0" | ||||||
| python-versions = "^3.11" | python-versions = "^3.11" | ||||||
| content-hash = "2d865a52e2e48b9700ca3f14caa3cbbc05c3ad3965866ab397cdd11e74958829" | content-hash = "5212b83adf6d4f20bc25f36c0f79039516153c470fce1975ed8a30596746a113" | ||||||
|  |  | ||||||
|  | @ -39,6 +39,7 @@ python-magic = "^0.4.27" | ||||||
| sqlalchemy = "^2.0.20" | sqlalchemy = "^2.0.20" | ||||||
| uvicorn = "^0.23.2" | uvicorn = "^0.23.2" | ||||||
| wrapt = "^1.15.0" | wrapt = "^1.15.0" | ||||||
|  | faker = "^19.10.0" | ||||||
| 
 | 
 | ||||||
| [tool.poetry.dev-dependencies] | [tool.poetry.dev-dependencies] | ||||||
| black = "^23.7.0" | black = "^23.7.0" | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user