From 6189aaf39a14d7c3fc57d26bb86b5aed47c9c94c Mon Sep 17 00:00:00 2001 From: matnad Date: Thu, 21 Feb 2019 19:56:03 +0100 Subject: [PATCH] a lot of work on scaling the bot --- cogs/poll.py | 121 +++++++++++++++++++++------------- cogs/poll_controls.py | 82 +++++++++++++++++------ pollmaster.py | 20 ++++-- utils/asyncio_unique_queue.py | 11 ++++ 4 files changed, 163 insertions(+), 71 deletions(-) create mode 100644 utils/asyncio_unique_queue.py diff --git a/cogs/poll.py b/cogs/poll.py index 461baed..7fd3d94 100644 --- a/cogs/poll.py +++ b/cogs/poll.py @@ -43,6 +43,8 @@ class Poll: if channel is None: channel = ctx.message.channel + self.id = None + self.author = ctx.message.author self.server = server @@ -840,6 +842,7 @@ class Poll: return None async def from_dict(self, d): + self.id = d['_id'] self.server = self.bot.get_server(str(d['server_id'])) self.channel = self.bot.get_channel(str(d['channel_id'])) self.author = await self.bot.get_user_info(str(d['author'])) @@ -1104,7 +1107,7 @@ class Poll: else: return sum([1 for c in [u for u in self.votes] if option in self.votes[c]['choices']]) - async def vote(self, user, option, message): + async def vote(self, user, option, message, lock): if not await self.is_open(): # refresh to show closed poll await self.bot.edit_message(message, embed=await self.generate_embed()) @@ -1114,7 +1117,7 @@ class Poll: return choice = 'invalid' - refresh_poll = True + # refresh_poll = True # get weight weight = 1 @@ -1129,55 +1132,69 @@ class Poll: else: self.votes[user.id]['weight'] = weight - if self.reaction: - if self.options_reaction_default: - if option in self.options_reaction: - choice = self.options_reaction.index(option) - else: - if option in AZ_EMOJIS: - choice = AZ_EMOJIS.index(option) - - if choice != 'invalid': - if self.multiple_choice != 1: # more than 1 choice (0 = no limit) - if choice in self.votes[user.id]['choices']: - if self.anonymous: - # anonymous multiple choice -> can't unreact so we toggle with react - await self.unvote(user, option, message) - return - refresh_poll = False - else: - if self.multiple_choice > 0 and self.votes[user.id]['choices'].__len__() >= self.multiple_choice: - say_text = f'You have reached the **maximum choices of {self.multiple_choice}** for this poll. ' \ - f'Before you can vote again, you need to unvote one of your choices.' - embed = discord.Embed(title='', description=say_text, colour=SETTINGS.color) - embed.set_author(name='Pollmaster', icon_url=SETTINGS.author_icon) - await self.bot.send_message(user, embed=embed) - refresh_poll = False - else: - self.votes[user.id]['choices'].append(choice) - self.votes[user.id]['choices'] = list(set(self.votes[user.id]['choices'])) - else: - if [choice] == self.votes[user.id]['choices']: - refresh_poll = False - if self.anonymous: - # undo anonymous vote - await self.unvote(user, option, message) - return - else: - self.votes[user.id]['choices'] = [choice] + if self.options_reaction_default: + if option in self.options_reaction: + choice = self.options_reaction.index(option) else: - pass + if option in AZ_EMOJIS: + choice = AZ_EMOJIS.index(option) + + if choice != 'invalid': + # if self.multiple_choice != 1: # more than 1 choice (0 = no limit) + if choice in self.votes[user.id]['choices']: + if self.anonymous: + # anonymous multiple choice -> can't unreact so we toggle with react + await self.unvote(user, option, message, lock) + # refresh_poll = False + else: + if self.multiple_choice > 0 and self.votes[user.id]['choices'].__len__() >= self.multiple_choice: + say_text = f'You have reached the **maximum choices of {self.multiple_choice}** for this poll. ' \ + f'Before you can vote again, you need to unvote one of your choices.' + embed = discord.Embed(title='', description=say_text, colour=SETTINGS.color) + embed.set_author(name='Pollmaster', icon_url=SETTINGS.author_icon) + await self.bot.send_message(user, embed=embed) + # refresh_poll = False + else: + self.votes[user.id]['choices'].append(choice) + self.votes[user.id]['choices'] = list(set(self.votes[user.id]['choices'])) + # else: + # if [choice] == self.votes[user.id]['choices']: + # # refresh_poll = False + # # if self.anonymous: + # # undo anonymous vote + # await self.unvote(user, option, message, lock) + # return + # else: + # self.votes[user.id]['choices'] = [choice] + + else: + # unknow emoji + return # commit - await self.save_to_db() + if lock._waiters.__len__() == 0: + # updating DB, clearing cache and refresh if necessary + await self.save_to_db() + await self.bot.poll_refresh_q.put_unique_id( + {'id': self.id, 'msg': message, 'sid': self.server.id, 'label': self.short, 'lock': lock}) + if self.bot.poll_cache.get(str(self.server.id) + self.short): + del self.bot.poll_cache[str(self.server.id) + self.short] + # refresh + # if refresh_poll: + # edit message if there is a real change + # await self.bot.edit_message(message, embed=await self.generate_embed()) + # self.bot.poll_refresh_q.append(str(self.id)) + else: + # cache the poll until the queue is empty + self.bot.poll_cache[str(self.server.id)+self.short] = self - # refresh - if refresh_poll: - # edit message if there is a real change - await self.bot.edit_message(message, embed=await self.generate_embed()) + # if refresh_poll: + # await self.bot.poll_refresh_q.put_unique_id({'id': self.id, 'msg': message, 'sid': self.server.id, 'label': self.short, 'lock': lock}) - async def unvote(self, user, option, message): + + + async def unvote(self, user, option, message, lock): if not await self.is_open(): # refresh to show closed poll await self.bot.edit_message(message, embed=await self.generate_embed()) @@ -1189,6 +1206,7 @@ class Poll: if str(user.id) not in self.votes: return choice = 'invalid' + if self.options_reaction_default: if option in self.options_reaction: choice = self.options_reaction.index(option) @@ -1199,8 +1217,17 @@ class Poll: if choice != 'invalid' and choice in self.votes[user.id]['choices']: try: self.votes[user.id]['choices'].remove(choice) - await self.save_to_db() - await self.bot.edit_message(message, embed=await self.generate_embed()) + if lock._waiters.__len__() == 0: + # updating DB, clearing cache and refreshing message + await self.save_to_db() + await self.bot.poll_refresh_q.put_unique_id( + {'id': self.id, 'msg': message, 'sid': self.server.id, 'label': self.short, 'lock': lock}) + if self.bot.poll_cache.get(str(self.server.id) + self.short): + del self.bot.poll_cache[str(self.server.id) + self.short] + # await self.bot.edit_message(message, embed=await self.generate_embed()) + else: + # cache the poll until the queue is empty + self.bot.poll_cache[str(self.server.id) + self.short] = self except ValueError: pass diff --git a/cogs/poll_controls.py b/cogs/poll_controls.py index 93e3e22..80ff1e5 100644 --- a/cogs/poll_controls.py +++ b/cogs/poll_controls.py @@ -5,6 +5,7 @@ import datetime import json import logging import shlex +import traceback import discord import pytz @@ -23,8 +24,30 @@ from essentials.exceptions import StopWizard class PollControls: def __init__(self, bot): self.bot = bot + self.bot.loop.create_task(self.refresh_polls()) + self.ignore_next_removed_reaction = {} + + # General Methods + async def refresh_polls(self): + """This function runs every 5 seconds to refresh poll messages when needed""" + while True: + try: + for i in range(self.bot.poll_refresh_q.qsize()): + values = await self.bot.poll_refresh_q.get() + if values.get('lock') and not values.get('lock')._waiters: + p = await Poll.load_from_db(self.bot, str(values.get('sid')), values.get('label')) + if p: + await self.bot.edit_message(values.get('msg'), embed=await p.generate_embed()) + + self.bot.poll_refresh_q.task_done() + else: + await self.bot.poll_refresh_q.put_unique_id(values) + except AttributeError: + pass + await asyncio.sleep(5) + def get_lock(self, server_id): if not self.bot.locks.get(str(server_id)): self.bot.locks[server_id] = asyncio.Lock() @@ -455,10 +478,17 @@ class PollControls: if not emoji: return - # check if we can find a poll label + # check if removed by the bot.. this is a bit hacky but discord doesn't provide the correct info... message_id = data.get('message_id') - channel_id = data.get('channel_id') user_id = data.get('user_id') + if self.ignore_next_removed_reaction.get(str(message_id)+str(emoji)) == user_id: + del self.ignore_next_removed_reaction[str(message_id)+str(emoji)] + return + + + # check if we can find a poll label + channel_id = data.get('channel_id') + channel = self.bot.get_channel(channel_id) user = await self.bot.get_user_info(user_id) # only do this once if not channel: @@ -485,15 +515,20 @@ class PollControls: server = await ask_for_server(self.bot, user_msg, label) # this is exclusive - async with self.get_lock(server.id): - p = await Poll.load_from_db(self.bot, server.id, label) + + lock = self.get_lock(server.id) + async with lock: + # try to load poll form cache + p = self.bot.poll_cache.get(str(server.id) + label) + if not p: + p = await Poll.load_from_db(self.bot, server.id, label) if not isinstance(p, Poll): return if not p.anonymous: # for anonymous polls we can't unvote because we need to hide reactions member = server.get_member(user_id) - await p.unvote(member, emoji, message) + await p.unvote(member, emoji, message, lock) async def do_on_reaction_add(self, data): @@ -539,8 +574,11 @@ class PollControls: # this is exclusive to keep database access sequential # hopefully it will scale well enough or I need a different solution - async with self.get_lock(server.id): - p = await Poll.load_from_db(self.bot, server.id, label) + lock = self.get_lock(server.id) + async with lock: + p = self.bot.poll_cache.get(str(server.id)+label) + if not p: + p = await Poll.load_from_db(self.bot, server.id, label) if not isinstance(p, Poll): return @@ -623,20 +661,26 @@ class PollControls: f'at least one of these roles can vote:\n{", ".join(p.roles)}') return + # check if we need to remove reactions (this will trigger on_reaction_remove) + if str(channel.type) != 'private' and p.anonymous: + # immediately remove reaction and to be safe, remove all reactions + self.ignore_next_removed_reaction[str(message.id)+str(emoji)] = user_id + await self.bot.remove_reaction(message, emoji, user) + # order here is crucial since we can't determine if a reaction was removed by the bot or user # update database with vote - await p.vote(member, emoji, message) - # - # check if we need to remove reactions (this will trigger on_reaction_remove) - if str(channel.type) != 'private': - if p.anonymous: - # immediately remove reaction and to be safe, remove all reactions - await self.bot.remove_reaction(message, emoji, user) - elif p.multiple_choice == 1: - # remove all other reactions - for r in message.reactions: - if r.emoji and r.emoji != emoji: - await self.bot.remove_reaction(message, r.emoji, user) + await p.vote(member, emoji, message, lock) + + # cant do this until we figure out how to see who removed the reaction? + # for now MC 1 is like MC x + # if str(channel.type) != 'private' and p.multiple_choice == 1: + # # remove all other reactions + # # if lock._waiters.__len__() == 0: + # for r in message.reactions: + # if r.emoji and r.emoji != emoji: + # await self.bot.remove_reaction(message, r.emoji, user) + # pass + def setup(bot): global logger diff --git a/pollmaster.py b/pollmaster.py index 8a02aba..fcd1600 100644 --- a/pollmaster.py +++ b/pollmaster.py @@ -1,4 +1,5 @@ import asyncio +import sys import traceback import logging import aiohttp @@ -9,6 +10,7 @@ from motor.motor_asyncio import AsyncIOMotorClient from essentials.multi_server import get_pre from essentials.settings import SETTINGS +from utils.asyncio_unique_queue import UniqueQueue from utils.import_old_database import import_old_database bot_config = { @@ -76,12 +78,19 @@ async def on_ready(): print("Problem verifying servers.") # cache prefixes - bot.pre = {entry['_id']:entry['prefix'] async for entry in bot.db.config.find({}, {'_id', 'prefix'})} - bot.locks = {sid: asyncio.Lock() for sid in [s.id for s in bot.servers]} + bot.pre = {entry['_id']: entry['prefix'] async for entry in bot.db.config.find({}, {'_id', 'prefix'})} + + # global locks and caches for performance when voting rapidly + bot.locks = {} + bot.poll_cache = {} + # bot.poll_refresh_q = {} + bot.poll_refresh_q = UniqueQueue() print("Servers verified. Bot running.") + + @bot.event async def on_command_error(e, ctx): if SETTINGS.log_errors: @@ -102,8 +111,6 @@ async def on_command_error(e, ctx): # log error logger.error(f'{type(e).__name__}: {e}\n{"".join(traceback.format_tb(e.__traceback__))}') - if SETTINGS.mode == 'development': - raise e traceback.print_exception(type(e), e, e.__traceback__, file=sys.stderr) if SETTINGS.msg_errors: @@ -117,6 +124,9 @@ async def on_command_error(e, ctx): ) await bot.send_message(bot.owner, embed=e) + # if SETTINGS.mode == 'development': + raise e + @bot.event async def on_server_join(server): @@ -130,4 +140,4 @@ async def on_server_join(server): bot.pre[str(server.id)] = 'pm!' -bot.run(SETTINGS.bot_token, reconnect=True) \ No newline at end of file +bot.run(SETTINGS.bot_token, reconnect=True) diff --git a/utils/asyncio_unique_queue.py b/utils/asyncio_unique_queue.py new file mode 100644 index 0000000..ee8aa36 --- /dev/null +++ b/utils/asyncio_unique_queue.py @@ -0,0 +1,11 @@ +import asyncio + + +class UniqueQueue(asyncio.Queue): + + async def put_unique_id(self, item): + if not item.get('id'): + return + + if item.get('id') not in [v.get('id') for v in self._queue]: + await self.put(item) \ No newline at end of file