From ac1f3051efe6802e32860f58886cd3e7bec917b3 Mon Sep 17 00:00:00 2001 From: matnad Date: Wed, 20 Feb 2019 16:12:50 +0100 Subject: [PATCH] fix crash, improve !cmd help and errors, add database locks for high traffic servers, cache prefixes, add production/development setting --- cogs/config.py | 1 + cogs/db_api.py | 3 +- cogs/poll_controls.py | 217 +++++++++++++++++++++---------------- essentials/multi_server.py | 11 +- essentials/settings.py | 1 + pollmaster.py | 8 ++ 6 files changed, 139 insertions(+), 102 deletions(-) diff --git a/cogs/config.py b/cogs/config.py index 8cef470..835fa61 100644 --- a/cogs/config.py +++ b/cogs/config.py @@ -29,6 +29,7 @@ class Config: f'If you would like to add a trailing whitespace to the prefix, use `{pre}prefix {pre}\w`.' await self.bot.db.config.update_one({'_id': str(server.id)}, {'$set': {'prefix': str(pre)}}, upsert=True) + self.bot.pre[str(server.id)] = str(pre) await self.bot.say(msg) @commands.command(pass_context=True) diff --git a/cogs/db_api.py b/cogs/db_api.py index 51fe1fc..33ec75f 100644 --- a/cogs/db_api.py +++ b/cogs/db_api.py @@ -20,7 +20,8 @@ class DiscordBotsOrgAPI: while True: logger.info('attempting to post server count') try: - await self.dblpy.post_server_count() + if SETTINGS.mode == 'production': + await self.dblpy.post_server_count() logger.info('posted server count ({})'.format(len(self.bot.servers))) except Exception as e: logger.exception('Failed to post server count\n{}: {}'.format(type(e).__name__, e)) diff --git a/cogs/poll_controls.py b/cogs/poll_controls.py index 9e7a542..cb3d999 100644 --- a/cogs/poll_controls.py +++ b/cogs/poll_controls.py @@ -1,4 +1,5 @@ import argparse +import asyncio import copy import datetime import json @@ -24,6 +25,11 @@ class PollControls: self.bot = bot # General Methods + def get_lock(self, server_id): + if not self.bot.locks.get(str(server_id)): + self.bot.locks[server_id] = asyncio.Lock() + return self.bot.locks.get(str(server_id)) + async def is_admin_or_creator(self, ctx, server, owner_id, error_msg=None): member = server.get_member(ctx.message.author.id) if member.id == owner_id: @@ -288,11 +294,23 @@ class PollControls: return try: - args = parser.parse_args(cmds) + args, unknown_args = parser.parse_known_args(cmds) except SystemExit: await self.say_error(ctx, error_text=helpstring) return + if unknown_args: + error_text = f'**There was an error reading the command line options!**.\n' \ + f'Most likely this is because you didn\'t surround the arguments with double quotes like this: ' \ + f'`{pre}cmd -q "question of the poll" -o "yes, no, maybe"`' \ + f'\n\nHere are the arguments I could not understand:\n' + error_text += '`'+'\n'.join(unknown_args)+'`' + error_text += f'\n\nHere are the arguments which are ok:\n' + error_text += '`' + '\n'.join([f'{k}: {v}' for k, v in vars(args).items()]) + '`' + + await self.say_error(ctx, error_text=error_text, footer_text=f'type `{pre}cmd help` for details.') + return + # pass arguments to the wizard async def route(poll): await poll.set_name(force=args.question) @@ -306,7 +324,7 @@ class PollControls: poll = await self.wizard(ctx, route, server) if poll: - await poll.post_embed() + await poll.post_embed(destination=poll.channel) @commands.command(pass_context=True) @@ -328,7 +346,7 @@ class PollControls: poll = await self.wizard(ctx, route, server) if poll: - await poll.post_embed() + await poll.post_embed(destination=poll.channel) @commands.command(pass_context=True) async def prepare(self, ctx, *, cmd=None): @@ -371,7 +389,7 @@ class PollControls: poll = await self.wizard(ctx, route, server) if poll: - await poll.post_embed() + await poll.post_embed(destination=poll.channel) # The Wizard! async def wizard(self, ctx, route, server): @@ -460,14 +478,17 @@ class PollControls: user_msg = copy.deepcopy(message) user_msg.author = user server = await ask_for_server(self.bot, user_msg, label) - p = await Poll.load_from_db(self.bot, server.id, label) - if not isinstance(p, Poll): - return - if not p.anonymous: - # for anonymous polls we can't unvote because we need to hide reactions - member = server.get_member(user_id) - await p.unvote(member, emoji, message) + # this is exclusive + async with self.get_lock(server.id): + p = await Poll.load_from_db(self.bot, server.id, label) + if not isinstance(p, Poll): + return + + if not p.anonymous: + # for anonymous polls we can't unvote because we need to hide reactions + member = server.get_member(user_id) + await p.unvote(member, emoji, message) async def do_on_reaction_add(self, data): @@ -510,103 +531,107 @@ class PollControls: user_msg = copy.deepcopy(message) user_msg.author = user server = await ask_for_server(self.bot, user_msg, label) - p = await Poll.load_from_db(self.bot, server.id, label) - if not isinstance(p, Poll): - return - # export - if emoji == '📎': - # sending file - file = await p.export() - if file is not None: - await self.bot.send_file( - user, - file, - content='Sending you the requested export of "{}".'.format(p.short) - ) - return + # this is exclusive to keep database access sequential + # hopefully it will scale well enough or I need a different solution + async with self.get_lock(server.id): + p = await Poll.load_from_db(self.bot, server.id, label) + if not isinstance(p, Poll): + return - # info - member = server.get_member(user_id) - if emoji == '❔': - is_open = await p.is_open() - embed = discord.Embed(title=f"Info for the {'CLOSED ' if not is_open else ''}poll \"{p.name}\"", - description='', color=SETTINGS.color) - embed.set_author(name=f" >> {p.short}", icon_url=SETTINGS.author_icon) + # export + if emoji == '📎': + # sending file + file = await p.export() + if file is not None: + await self.bot.send_file( + user, + file, + content='Sending you the requested export of "{}".'.format(p.short) + ) + return - # vote rights - vote_rights = await p.has_required_role(member) - embed.add_field(name=f'{"Can you vote?" if is_open else "Could you vote?"}', - value=f'{"✅" if vote_rights else "❎"}', inline=False) + # info + member = server.get_member(user_id) + if emoji == '❔': + is_open = await p.is_open() + embed = discord.Embed(title=f"Info for the {'CLOSED ' if not is_open else ''}poll \"{p.name}\"", + description='', color=SETTINGS.color) + embed.set_author(name=f" >> {p.short}", icon_url=SETTINGS.author_icon) - # edit rights - edit_rights = False - if str(member.id) == str(p.author): - edit_rights = True - elif member.server_permissions.manage_server: - edit_rights = True - else: - result = await self.bot.db.config.find_one({'_id': str(server.id)}) - if result and result.get('admin_role') in [r.name for r in member.roles]: + # vote rights + vote_rights = await p.has_required_role(member) + embed.add_field(name=f'{"Can you vote?" if is_open else "Could you vote?"}', + value=f'{"✅" if vote_rights else "❎"}', inline=False) + + # edit rights + edit_rights = False + if str(member.id) == str(p.author): edit_rights = True - embed.add_field(name='Can you manage the poll?', value=f'{"✅" if edit_rights else "❎"}', inline=False) + elif member.server_permissions.manage_server: + edit_rights = True + else: + result = await self.bot.db.config.find_one({'_id': str(server.id)}) + if result and result.get('admin_role') in [r.name for r in member.roles]: + edit_rights = True + embed.add_field(name='Can you manage the poll?', value=f'{"✅" if edit_rights else "❎"}', inline=False) - # choices - choices = 'You have not voted yet.' if vote_rights else 'You can\'t vote in this poll.' - if user.id in p.votes: - if p.votes[user.id]['choices'].__len__() > 0: - choices = ', '.join([p.options_reaction[c] for c in p.votes[user.id]['choices']]) - embed.add_field(name=f'{"Your current votes (can be changed as long as the poll is open):" if is_open else "Your final votes:"}', - value=choices, inline=False) + # choices + choices = 'You have not voted yet.' if vote_rights else 'You can\'t vote in this poll.' + if user.id in p.votes: + if p.votes[user.id]['choices'].__len__() > 0: + choices = ', '.join([p.options_reaction[c] for c in p.votes[user.id]['choices']]) + embed.add_field(name=f'{"Your current votes (can be changed as long as the poll is open):" if is_open else "Your final votes:"}', + value=choices, inline=False) - # weight - if vote_rights: - weight = 1 - if p.weights_roles.__len__() > 0: - valid_weights = [p.weights_numbers[p.weights_roles.index(r)] for r in - list(set([n.name for n in member.roles]).intersection(set(p.weights_roles)))] - if valid_weights.__len__() > 0: - weight = max(valid_weights) - else: - weight = 'You can\'t vote in this poll.' - embed.add_field(name='Weight of your votes:', value=weight, inline=False) + # weight + if vote_rights: + weight = 1 + if p.weights_roles.__len__() > 0: + valid_weights = [p.weights_numbers[p.weights_roles.index(r)] for r in + list(set([n.name for n in member.roles]).intersection(set(p.weights_roles)))] + if valid_weights.__len__() > 0: + weight = max(valid_weights) + else: + weight = 'You can\'t vote in this poll.' + embed.add_field(name='Weight of your votes:', value=weight, inline=False) - # time left - deadline = p.get_duration_with_tz() - if not is_open: - time_left = 'This poll is closed.' - elif deadline == 0: - time_left = 'Until manually closed.' - else: - time_left = str(deadline-datetime.datetime.utcnow().replace(tzinfo=pytz.utc)).split('.', 2)[0] + # time left + deadline = p.get_duration_with_tz() + if not is_open: + time_left = 'This poll is closed.' + elif deadline == 0: + time_left = 'Until manually closed.' + else: + time_left = str(deadline-datetime.datetime.utcnow().replace(tzinfo=pytz.utc)).split('.', 2)[0] - embed.add_field(name='Time left in the poll:', value=time_left, inline=False) + embed.add_field(name='Time left in the poll:', value=time_left, inline=False) - await self.bot.send_message(user, embed=embed) - return + await self.bot.send_message(user, embed=embed) + return - # Assume: User wants to vote with reaction - # no rights, terminate function - if not await p.has_required_role(member): - await self.bot.remove_reaction(message, emoji, user) - await self.bot.send_message(user, f'You are not allowed to vote in this poll. Only users with ' - f'at least one of these roles can vote:\n{", ".join(p.roles)}') - return - - # order here is crucial since we can't determine if a reaction was removed by the bot or user - # update database with vote - await p.vote(member, emoji, message) - # - # check if we need to remove reactions (this will trigger on_reaction_remove) - if str(channel.type) != 'private': - if p.anonymous: - # immediately remove reaction and to be safe, remove all reactions + # Assume: User wants to vote with reaction + # no rights, terminate function + if not await p.has_required_role(member): await self.bot.remove_reaction(message, emoji, user) - elif p.multiple_choice == 1: - # remove all other reactions - for r in message.reactions: - if r.emoji and r.emoji != emoji: - await self.bot.remove_reaction(message, r.emoji, user) + await self.bot.send_message(user, f'You are not allowed to vote in this poll. Only users with ' + f'at least one of these roles can vote:\n{", ".join(p.roles)}') + return + + # order here is crucial since we can't determine if a reaction was removed by the bot or user + # update database with vote + await p.vote(member, emoji, message) + # + # check if we need to remove reactions (this will trigger on_reaction_remove) + if str(channel.type) != 'private': + if p.anonymous: + # immediately remove reaction and to be safe, remove all reactions + await self.bot.remove_reaction(message, emoji, user) + elif p.multiple_choice == 1: + # remove all other reactions + for r in message.reactions: + if r.emoji and r.emoji != emoji: + await self.bot.remove_reaction(message, r.emoji, user) def setup(bot): global logger diff --git a/essentials/multi_server.py b/essentials/multi_server.py index 8ff5a78..b345cfc 100644 --- a/essentials/multi_server.py +++ b/essentials/multi_server.py @@ -11,7 +11,7 @@ async def get_pre(bot, message): if str(message.channel.type) == 'private': shared_server_list = await get_servers(bot, message) if shared_server_list.__len__() == 0: - return 'pm!!' + return 'pm!' elif shared_server_list.__len__() == 1: return await get_server_pre(bot, shared_server_list[0]) else: @@ -24,12 +24,13 @@ async def get_pre(bot, message): async def get_server_pre(bot, server): '''Gets the prefix for a server.''' try: - result = await bot.db.config.find_one({'_id': str(server.id)}) + #result = await bot.db.config.find_one({'_id': str(server.id)}) + result = bot.pre[str(server.id)] except AttributeError: return 'pm!' - if not result or not result.get('prefix'): - return 'pm!!' - return result.get('prefix') + if not result: #or not result.get('prefix'): + return 'pm!' + return result #result.get('prefix') async def get_servers(bot, message, short=None): diff --git a/essentials/settings.py b/essentials/settings.py index 0bca921..424dc04 100644 --- a/essentials/settings.py +++ b/essentials/settings.py @@ -22,6 +22,7 @@ class Settings: self.dbl_token = SECRETS.dbl_token self.mongo_db = SECRETS.mongo_db self.bot_token = SECRETS.bot_token + self.mode = SECRETS.mode SETTINGS = Settings() diff --git a/pollmaster.py b/pollmaster.py index 562ba6c..6e9182a 100644 --- a/pollmaster.py +++ b/pollmaster.py @@ -1,3 +1,4 @@ +import asyncio import traceback import logging import aiohttp @@ -74,6 +75,10 @@ async def on_ready(): except: print("Problem verifying servers.") + # cache prefixes + bot.pre = {entry['_id']:entry['prefix'] async for entry in bot.db.config.find({}, {'_id', 'prefix'})} + bot.locks = {sid: asyncio.Lock() for sid in [s.id for s in bot.servers]} + print("Servers verified. Bot running.") @@ -97,6 +102,8 @@ async def on_command_error(e, ctx): # log error logger.error(f'{type(e).__name__}: {e}\n{"".join(traceback.format_tb(e.__traceback__))}') + if SETTINGS.mode == 'development': + raise e if SETTINGS.msg_errors: # send discord message for unexpected errors @@ -119,6 +126,7 @@ async def on_server_join(server): {'$set': {'prefix': 'pm!', 'admin_role': 'polladmin', 'user_role': 'polluser'}}, upsert=True ) + bot.pre[str(server.id)] = 'pm!' bot.run(SETTINGS.bot_token, reconnect=True) \ No newline at end of file