diff --git a/bot.py b/bot.py index 97933f0..8560783 100644 --- a/bot.py +++ b/bot.py @@ -5,6 +5,7 @@ from loguru import logger import filters import handlers +import middlewares from data import config bot = Bot(token=config.BOT_TOKEN, parse_mode=types.ParseMode.HTML) @@ -15,6 +16,7 @@ dp = Dispatcher(bot, storage=storage) # noinspection PyUnusedLocal async def on_startup(web_app: web.Application): filters.setup(dp) + middlewares.setup(dp) handlers.user.setup(dp) await dp.bot.delete_webhook() await dp.bot.set_webhook(config.WEBHOOK_URL) diff --git a/handlers/user/__init__.py b/handlers/user/__init__.py index 7e7c090..037e614 100644 --- a/handlers/user/__init__.py +++ b/handlers/user/__init__.py @@ -1,8 +1,10 @@ from aiogram import Dispatcher -from aiogram.dispatcher.filters import CommandStart +from aiogram.dispatcher.filters import CommandStart, CommandHelp +from .help import bot_help from .start import bot_start def setup(dp: Dispatcher): - dp.register_message_handler(start, CommandStart()) + dp.register_message_handler(bot_start, CommandStart()) + dp.register_message_handler(bot_help, CommandHelp()) diff --git a/handlers/user/help.py b/handlers/user/help.py new file mode 100644 index 0000000..105fb38 --- /dev/null +++ b/handlers/user/help.py @@ -0,0 +1,13 @@ +from aiogram import types + +from utils.misc import rate_limit + + +@rate_limit(5, 'help') +async def bot_help(msg: types.Message): + text = [ + 'Список команд: ', + '/start - Начать диалог', + '/help - Получить справку' + ] + await msg.answer('\n'.join(text)) diff --git a/middlewares/__init__.py b/middlewares/__init__.py new file mode 100644 index 0000000..1ea7546 --- /dev/null +++ b/middlewares/__init__.py @@ -0,0 +1,7 @@ +from aiogram import Dispatcher + +from .throttling import ThrottlingMiddleware + + +def setup(dp: Dispatcher): + dp.middleware.setup(ThrottlingMiddleware) diff --git a/middlewares/throttling.py b/middlewares/throttling.py new file mode 100644 index 0000000..3202354 --- /dev/null +++ b/middlewares/throttling.py @@ -0,0 +1,49 @@ +import asyncio + +from aiogram import types, Dispatcher +from aiogram.dispatcher import DEFAULT_RATE_LIMIT +from aiogram.dispatcher.handler import CancelHandler, current_handler +from aiogram.dispatcher.middlewares import BaseMiddleware +from aiogram.utils.exceptions import Throttled + + +class ThrottlingMiddleware(BaseMiddleware): + """ + Simple middleware + """ + + def __init__(self, limit=DEFAULT_RATE_LIMIT, key_prefix='antiflood_'): + self.rate_limit = limit + self.prefix = key_prefix + super(ThrottlingMiddleware, self).__init__() + + # noinspection PyUnusedLocal + async def on_process_message(self, message: types.Message, data: dict): + handler = current_handler.get() + dispatcher = Dispatcher.get_current() + if handler: + limit = getattr(handler, 'throttling_rate_limit', self.rate_limit) + key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}") + else: + limit = self.rate_limit + key = f"{self.prefix}_message" + try: + await dispatcher.throttle(key, rate=limit) + except Throttled as t: + await self.message_throttled(message, t) + raise CancelHandler() + + async def message_throttled(self, message: types.Message, throttled: Throttled): + handler = current_handler.get() + dispatcher = Dispatcher.get_current() + if handler: + key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}") + else: + key = f"{self.prefix}_message" + delta = throttled.rate - throttled.delta + if throttled.exceeded_count <= 2: + await message.reply('Too many requests! ') + await asyncio.sleep(delta) + thr = await dispatcher.check_key(key) + if thr.exceeded_count == throttled.exceeded_count: + await message.reply('Unlocked.') diff --git a/states/__init__.py b/states/__init__.py index e69de29..f9b61db 100644 --- a/states/__init__.py +++ b/states/__init__.py @@ -0,0 +1 @@ +from . import user diff --git a/utils/__init__.py b/utils/__init__.py index e69de29..976134d 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -0,0 +1,3 @@ +from . import db_api +from . import misc +from . import redis diff --git a/utils/misc/__init__.py b/utils/misc/__init__.py new file mode 100644 index 0000000..05dba85 --- /dev/null +++ b/utils/misc/__init__.py @@ -0,0 +1 @@ +from .throttling import rate_limit diff --git a/utils/misc/throttling.py b/utils/misc/throttling.py new file mode 100644 index 0000000..c881c9e --- /dev/null +++ b/utils/misc/throttling.py @@ -0,0 +1,16 @@ +def rate_limit(limit: int, key=None): + """ + Decorator for configuring rate limit and key in different functions. + + :param limit: + :param key: + :return: + """ + + def decorator(func): + setattr(func, 'throttling_rate_limit', limit) + if key: + setattr(func, 'throttling_key', key) + return func + + return decorator