Merge pull request #9 from matnad/v2.1

fixes and performance
This commit is contained in:
Matthias Nadler 2019-02-20 16:17:34 +01:00 committed by GitHub
commit ae553f2f36
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 139 additions and 102 deletions

View File

@ -29,6 +29,7 @@ class Config:
f'If you would like to add a trailing whitespace to the prefix, use `{pre}prefix {pre}\w`.' f'If you would like to add a trailing whitespace to the prefix, use `{pre}prefix {pre}\w`.'
await self.bot.db.config.update_one({'_id': str(server.id)}, {'$set': {'prefix': str(pre)}}, upsert=True) await self.bot.db.config.update_one({'_id': str(server.id)}, {'$set': {'prefix': str(pre)}}, upsert=True)
self.bot.pre[str(server.id)] = str(pre)
await self.bot.say(msg) await self.bot.say(msg)
@commands.command(pass_context=True) @commands.command(pass_context=True)

View File

@ -20,6 +20,7 @@ class DiscordBotsOrgAPI:
while True: while True:
logger.info('attempting to post server count') logger.info('attempting to post server count')
try: try:
if SETTINGS.mode == 'production':
await self.dblpy.post_server_count() await self.dblpy.post_server_count()
logger.info('posted server count ({})'.format(len(self.bot.servers))) logger.info('posted server count ({})'.format(len(self.bot.servers)))
except Exception as e: except Exception as e:

View File

@ -1,4 +1,5 @@
import argparse import argparse
import asyncio
import copy import copy
import datetime import datetime
import json import json
@ -24,6 +25,11 @@ class PollControls:
self.bot = bot self.bot = bot
# General Methods # General Methods
def get_lock(self, server_id):
if not self.bot.locks.get(str(server_id)):
self.bot.locks[server_id] = asyncio.Lock()
return self.bot.locks.get(str(server_id))
async def is_admin_or_creator(self, ctx, server, owner_id, error_msg=None): async def is_admin_or_creator(self, ctx, server, owner_id, error_msg=None):
member = server.get_member(ctx.message.author.id) member = server.get_member(ctx.message.author.id)
if member.id == owner_id: if member.id == owner_id:
@ -288,11 +294,23 @@ class PollControls:
return return
try: try:
args = parser.parse_args(cmds) args, unknown_args = parser.parse_known_args(cmds)
except SystemExit: except SystemExit:
await self.say_error(ctx, error_text=helpstring) await self.say_error(ctx, error_text=helpstring)
return return
if unknown_args:
error_text = f'**There was an error reading the command line options!**.\n' \
f'Most likely this is because you didn\'t surround the arguments with double quotes like this: ' \
f'`{pre}cmd -q "question of the poll" -o "yes, no, maybe"`' \
f'\n\nHere are the arguments I could not understand:\n'
error_text += '`'+'\n'.join(unknown_args)+'`'
error_text += f'\n\nHere are the arguments which are ok:\n'
error_text += '`' + '\n'.join([f'{k}: {v}' for k, v in vars(args).items()]) + '`'
await self.say_error(ctx, error_text=error_text, footer_text=f'type `{pre}cmd help` for details.')
return
# pass arguments to the wizard # pass arguments to the wizard
async def route(poll): async def route(poll):
await poll.set_name(force=args.question) await poll.set_name(force=args.question)
@ -306,7 +324,7 @@ class PollControls:
poll = await self.wizard(ctx, route, server) poll = await self.wizard(ctx, route, server)
if poll: if poll:
await poll.post_embed() await poll.post_embed(destination=poll.channel)
@commands.command(pass_context=True) @commands.command(pass_context=True)
@ -328,7 +346,7 @@ class PollControls:
poll = await self.wizard(ctx, route, server) poll = await self.wizard(ctx, route, server)
if poll: if poll:
await poll.post_embed() await poll.post_embed(destination=poll.channel)
@commands.command(pass_context=True) @commands.command(pass_context=True)
async def prepare(self, ctx, *, cmd=None): async def prepare(self, ctx, *, cmd=None):
@ -371,7 +389,7 @@ class PollControls:
poll = await self.wizard(ctx, route, server) poll = await self.wizard(ctx, route, server)
if poll: if poll:
await poll.post_embed() await poll.post_embed(destination=poll.channel)
# The Wizard! # The Wizard!
async def wizard(self, ctx, route, server): async def wizard(self, ctx, route, server):
@ -460,6 +478,9 @@ class PollControls:
user_msg = copy.deepcopy(message) user_msg = copy.deepcopy(message)
user_msg.author = user user_msg.author = user
server = await ask_for_server(self.bot, user_msg, label) server = await ask_for_server(self.bot, user_msg, label)
# this is exclusive
async with self.get_lock(server.id):
p = await Poll.load_from_db(self.bot, server.id, label) p = await Poll.load_from_db(self.bot, server.id, label)
if not isinstance(p, Poll): if not isinstance(p, Poll):
return return
@ -510,6 +531,10 @@ class PollControls:
user_msg = copy.deepcopy(message) user_msg = copy.deepcopy(message)
user_msg.author = user user_msg.author = user
server = await ask_for_server(self.bot, user_msg, label) server = await ask_for_server(self.bot, user_msg, label)
# this is exclusive to keep database access sequential
# hopefully it will scale well enough or I need a different solution
async with self.get_lock(server.id):
p = await Poll.load_from_db(self.bot, server.id, label) p = await Poll.load_from_db(self.bot, server.id, label)
if not isinstance(p, Poll): if not isinstance(p, Poll):
return return

View File

@ -11,7 +11,7 @@ async def get_pre(bot, message):
if str(message.channel.type) == 'private': if str(message.channel.type) == 'private':
shared_server_list = await get_servers(bot, message) shared_server_list = await get_servers(bot, message)
if shared_server_list.__len__() == 0: if shared_server_list.__len__() == 0:
return 'pm!!' return 'pm!'
elif shared_server_list.__len__() == 1: elif shared_server_list.__len__() == 1:
return await get_server_pre(bot, shared_server_list[0]) return await get_server_pre(bot, shared_server_list[0])
else: else:
@ -24,12 +24,13 @@ async def get_pre(bot, message):
async def get_server_pre(bot, server): async def get_server_pre(bot, server):
'''Gets the prefix for a server.''' '''Gets the prefix for a server.'''
try: try:
result = await bot.db.config.find_one({'_id': str(server.id)}) #result = await bot.db.config.find_one({'_id': str(server.id)})
result = bot.pre[str(server.id)]
except AttributeError: except AttributeError:
return 'pm!' return 'pm!'
if not result or not result.get('prefix'): if not result: #or not result.get('prefix'):
return 'pm!!' return 'pm!'
return result.get('prefix') return result #result.get('prefix')
async def get_servers(bot, message, short=None): async def get_servers(bot, message, short=None):

View File

@ -22,6 +22,7 @@ class Settings:
self.dbl_token = SECRETS.dbl_token self.dbl_token = SECRETS.dbl_token
self.mongo_db = SECRETS.mongo_db self.mongo_db = SECRETS.mongo_db
self.bot_token = SECRETS.bot_token self.bot_token = SECRETS.bot_token
self.mode = SECRETS.mode
SETTINGS = Settings() SETTINGS = Settings()

View File

@ -1,3 +1,4 @@
import asyncio
import traceback import traceback
import logging import logging
import aiohttp import aiohttp
@ -74,6 +75,10 @@ async def on_ready():
except: except:
print("Problem verifying servers.") print("Problem verifying servers.")
# cache prefixes
bot.pre = {entry['_id']:entry['prefix'] async for entry in bot.db.config.find({}, {'_id', 'prefix'})}
bot.locks = {sid: asyncio.Lock() for sid in [s.id for s in bot.servers]}
print("Servers verified. Bot running.") print("Servers verified. Bot running.")
@ -97,6 +102,8 @@ async def on_command_error(e, ctx):
# log error # log error
logger.error(f'{type(e).__name__}: {e}\n{"".join(traceback.format_tb(e.__traceback__))}') logger.error(f'{type(e).__name__}: {e}\n{"".join(traceback.format_tb(e.__traceback__))}')
if SETTINGS.mode == 'development':
raise e
if SETTINGS.msg_errors: if SETTINGS.msg_errors:
# send discord message for unexpected errors # send discord message for unexpected errors
@ -119,6 +126,7 @@ async def on_server_join(server):
{'$set': {'prefix': 'pm!', 'admin_role': 'polladmin', 'user_role': 'polluser'}}, {'$set': {'prefix': 'pm!', 'admin_role': 'polladmin', 'user_role': 'polluser'}},
upsert=True upsert=True
) )
bot.pre[str(server.id)] = 'pm!'
bot.run(SETTINGS.bot_token, reconnect=True) bot.run(SETTINGS.bot_token, reconnect=True)