Merge pull request #11 from matnad/v2.1.5

a lot of work on scaling the bot
This commit is contained in:
Matthias Nadler 2019-02-21 19:57:07 +01:00 committed by GitHub
commit c7431d6b95
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 163 additions and 71 deletions

View File

@ -43,6 +43,8 @@ class Poll:
if channel is None:
channel = ctx.message.channel
self.id = None
self.author = ctx.message.author
self.server = server
@ -840,6 +842,7 @@ class Poll:
return None
async def from_dict(self, d):
self.id = d['_id']
self.server = self.bot.get_server(str(d['server_id']))
self.channel = self.bot.get_channel(str(d['channel_id']))
self.author = await self.bot.get_user_info(str(d['author']))
@ -1104,7 +1107,7 @@ class Poll:
else:
return sum([1 for c in [u for u in self.votes] if option in self.votes[c]['choices']])
async def vote(self, user, option, message):
async def vote(self, user, option, message, lock):
if not await self.is_open():
# refresh to show closed poll
await self.bot.edit_message(message, embed=await self.generate_embed())
@ -1114,7 +1117,7 @@ class Poll:
return
choice = 'invalid'
refresh_poll = True
# refresh_poll = True
# get weight
weight = 1
@ -1129,55 +1132,69 @@ class Poll:
else:
self.votes[user.id]['weight'] = weight
if self.reaction:
if self.options_reaction_default:
if option in self.options_reaction:
choice = self.options_reaction.index(option)
else:
if option in AZ_EMOJIS:
choice = AZ_EMOJIS.index(option)
if choice != 'invalid':
if self.multiple_choice != 1: # more than 1 choice (0 = no limit)
if choice in self.votes[user.id]['choices']:
if self.anonymous:
# anonymous multiple choice -> can't unreact so we toggle with react
await self.unvote(user, option, message)
return
refresh_poll = False
else:
if self.multiple_choice > 0 and self.votes[user.id]['choices'].__len__() >= self.multiple_choice:
say_text = f'You have reached the **maximum choices of {self.multiple_choice}** for this poll. ' \
f'Before you can vote again, you need to unvote one of your choices.'
embed = discord.Embed(title='', description=say_text, colour=SETTINGS.color)
embed.set_author(name='Pollmaster', icon_url=SETTINGS.author_icon)
await self.bot.send_message(user, embed=embed)
refresh_poll = False
else:
self.votes[user.id]['choices'].append(choice)
self.votes[user.id]['choices'] = list(set(self.votes[user.id]['choices']))
else:
if [choice] == self.votes[user.id]['choices']:
refresh_poll = False
if self.anonymous:
# undo anonymous vote
await self.unvote(user, option, message)
return
else:
self.votes[user.id]['choices'] = [choice]
if self.options_reaction_default:
if option in self.options_reaction:
choice = self.options_reaction.index(option)
else:
pass
if option in AZ_EMOJIS:
choice = AZ_EMOJIS.index(option)
if choice != 'invalid':
# if self.multiple_choice != 1: # more than 1 choice (0 = no limit)
if choice in self.votes[user.id]['choices']:
if self.anonymous:
# anonymous multiple choice -> can't unreact so we toggle with react
await self.unvote(user, option, message, lock)
# refresh_poll = False
else:
if self.multiple_choice > 0 and self.votes[user.id]['choices'].__len__() >= self.multiple_choice:
say_text = f'You have reached the **maximum choices of {self.multiple_choice}** for this poll. ' \
f'Before you can vote again, you need to unvote one of your choices.'
embed = discord.Embed(title='', description=say_text, colour=SETTINGS.color)
embed.set_author(name='Pollmaster', icon_url=SETTINGS.author_icon)
await self.bot.send_message(user, embed=embed)
# refresh_poll = False
else:
self.votes[user.id]['choices'].append(choice)
self.votes[user.id]['choices'] = list(set(self.votes[user.id]['choices']))
# else:
# if [choice] == self.votes[user.id]['choices']:
# # refresh_poll = False
# # if self.anonymous:
# # undo anonymous vote
# await self.unvote(user, option, message, lock)
# return
# else:
# self.votes[user.id]['choices'] = [choice]
else:
# unknow emoji
return
# commit
await self.save_to_db()
if lock._waiters.__len__() == 0:
# updating DB, clearing cache and refresh if necessary
await self.save_to_db()
await self.bot.poll_refresh_q.put_unique_id(
{'id': self.id, 'msg': message, 'sid': self.server.id, 'label': self.short, 'lock': lock})
if self.bot.poll_cache.get(str(self.server.id) + self.short):
del self.bot.poll_cache[str(self.server.id) + self.short]
# refresh
# if refresh_poll:
# edit message if there is a real change
# await self.bot.edit_message(message, embed=await self.generate_embed())
# self.bot.poll_refresh_q.append(str(self.id))
else:
# cache the poll until the queue is empty
self.bot.poll_cache[str(self.server.id)+self.short] = self
# refresh
if refresh_poll:
# edit message if there is a real change
await self.bot.edit_message(message, embed=await self.generate_embed())
# if refresh_poll:
# await self.bot.poll_refresh_q.put_unique_id({'id': self.id, 'msg': message, 'sid': self.server.id, 'label': self.short, 'lock': lock})
async def unvote(self, user, option, message):
async def unvote(self, user, option, message, lock):
if not await self.is_open():
# refresh to show closed poll
await self.bot.edit_message(message, embed=await self.generate_embed())
@ -1189,6 +1206,7 @@ class Poll:
if str(user.id) not in self.votes: return
choice = 'invalid'
if self.options_reaction_default:
if option in self.options_reaction:
choice = self.options_reaction.index(option)
@ -1199,8 +1217,17 @@ class Poll:
if choice != 'invalid' and choice in self.votes[user.id]['choices']:
try:
self.votes[user.id]['choices'].remove(choice)
await self.save_to_db()
await self.bot.edit_message(message, embed=await self.generate_embed())
if lock._waiters.__len__() == 0:
# updating DB, clearing cache and refreshing message
await self.save_to_db()
await self.bot.poll_refresh_q.put_unique_id(
{'id': self.id, 'msg': message, 'sid': self.server.id, 'label': self.short, 'lock': lock})
if self.bot.poll_cache.get(str(self.server.id) + self.short):
del self.bot.poll_cache[str(self.server.id) + self.short]
# await self.bot.edit_message(message, embed=await self.generate_embed())
else:
# cache the poll until the queue is empty
self.bot.poll_cache[str(self.server.id) + self.short] = self
except ValueError:
pass

View File

@ -5,6 +5,7 @@ import datetime
import json
import logging
import shlex
import traceback
import discord
import pytz
@ -23,8 +24,30 @@ from essentials.exceptions import StopWizard
class PollControls:
def __init__(self, bot):
self.bot = bot
self.bot.loop.create_task(self.refresh_polls())
self.ignore_next_removed_reaction = {}
# General Methods
async def refresh_polls(self):
"""This function runs every 5 seconds to refresh poll messages when needed"""
while True:
try:
for i in range(self.bot.poll_refresh_q.qsize()):
values = await self.bot.poll_refresh_q.get()
if values.get('lock') and not values.get('lock')._waiters:
p = await Poll.load_from_db(self.bot, str(values.get('sid')), values.get('label'))
if p:
await self.bot.edit_message(values.get('msg'), embed=await p.generate_embed())
self.bot.poll_refresh_q.task_done()
else:
await self.bot.poll_refresh_q.put_unique_id(values)
except AttributeError:
pass
await asyncio.sleep(5)
def get_lock(self, server_id):
if not self.bot.locks.get(str(server_id)):
self.bot.locks[server_id] = asyncio.Lock()
@ -455,10 +478,17 @@ class PollControls:
if not emoji:
return
# check if we can find a poll label
# check if removed by the bot.. this is a bit hacky but discord doesn't provide the correct info...
message_id = data.get('message_id')
channel_id = data.get('channel_id')
user_id = data.get('user_id')
if self.ignore_next_removed_reaction.get(str(message_id)+str(emoji)) == user_id:
del self.ignore_next_removed_reaction[str(message_id)+str(emoji)]
return
# check if we can find a poll label
channel_id = data.get('channel_id')
channel = self.bot.get_channel(channel_id)
user = await self.bot.get_user_info(user_id) # only do this once
if not channel:
@ -485,15 +515,20 @@ class PollControls:
server = await ask_for_server(self.bot, user_msg, label)
# this is exclusive
async with self.get_lock(server.id):
p = await Poll.load_from_db(self.bot, server.id, label)
lock = self.get_lock(server.id)
async with lock:
# try to load poll form cache
p = self.bot.poll_cache.get(str(server.id) + label)
if not p:
p = await Poll.load_from_db(self.bot, server.id, label)
if not isinstance(p, Poll):
return
if not p.anonymous:
# for anonymous polls we can't unvote because we need to hide reactions
member = server.get_member(user_id)
await p.unvote(member, emoji, message)
await p.unvote(member, emoji, message, lock)
async def do_on_reaction_add(self, data):
@ -539,8 +574,11 @@ class PollControls:
# this is exclusive to keep database access sequential
# hopefully it will scale well enough or I need a different solution
async with self.get_lock(server.id):
p = await Poll.load_from_db(self.bot, server.id, label)
lock = self.get_lock(server.id)
async with lock:
p = self.bot.poll_cache.get(str(server.id)+label)
if not p:
p = await Poll.load_from_db(self.bot, server.id, label)
if not isinstance(p, Poll):
return
@ -623,20 +661,26 @@ class PollControls:
f'at least one of these roles can vote:\n{", ".join(p.roles)}')
return
# check if we need to remove reactions (this will trigger on_reaction_remove)
if str(channel.type) != 'private' and p.anonymous:
# immediately remove reaction and to be safe, remove all reactions
self.ignore_next_removed_reaction[str(message.id)+str(emoji)] = user_id
await self.bot.remove_reaction(message, emoji, user)
# order here is crucial since we can't determine if a reaction was removed by the bot or user
# update database with vote
await p.vote(member, emoji, message)
#
# check if we need to remove reactions (this will trigger on_reaction_remove)
if str(channel.type) != 'private':
if p.anonymous:
# immediately remove reaction and to be safe, remove all reactions
await self.bot.remove_reaction(message, emoji, user)
elif p.multiple_choice == 1:
# remove all other reactions
for r in message.reactions:
if r.emoji and r.emoji != emoji:
await self.bot.remove_reaction(message, r.emoji, user)
await p.vote(member, emoji, message, lock)
# cant do this until we figure out how to see who removed the reaction?
# for now MC 1 is like MC x
# if str(channel.type) != 'private' and p.multiple_choice == 1:
# # remove all other reactions
# # if lock._waiters.__len__() == 0:
# for r in message.reactions:
# if r.emoji and r.emoji != emoji:
# await self.bot.remove_reaction(message, r.emoji, user)
# pass
def setup(bot):
global logger

View File

@ -1,4 +1,5 @@
import asyncio
import sys
import traceback
import logging
import aiohttp
@ -9,6 +10,7 @@ from motor.motor_asyncio import AsyncIOMotorClient
from essentials.multi_server import get_pre
from essentials.settings import SETTINGS
from utils.asyncio_unique_queue import UniqueQueue
from utils.import_old_database import import_old_database
bot_config = {
@ -76,12 +78,19 @@ async def on_ready():
print("Problem verifying servers.")
# cache prefixes
bot.pre = {entry['_id']:entry['prefix'] async for entry in bot.db.config.find({}, {'_id', 'prefix'})}
bot.locks = {sid: asyncio.Lock() for sid in [s.id for s in bot.servers]}
bot.pre = {entry['_id']: entry['prefix'] async for entry in bot.db.config.find({}, {'_id', 'prefix'})}
# global locks and caches for performance when voting rapidly
bot.locks = {}
bot.poll_cache = {}
# bot.poll_refresh_q = {}
bot.poll_refresh_q = UniqueQueue()
print("Servers verified. Bot running.")
@bot.event
async def on_command_error(e, ctx):
if SETTINGS.log_errors:
@ -102,8 +111,6 @@ async def on_command_error(e, ctx):
# log error
logger.error(f'{type(e).__name__}: {e}\n{"".join(traceback.format_tb(e.__traceback__))}')
if SETTINGS.mode == 'development':
raise e
traceback.print_exception(type(e), e, e.__traceback__, file=sys.stderr)
if SETTINGS.msg_errors:
@ -117,6 +124,9 @@ async def on_command_error(e, ctx):
)
await bot.send_message(bot.owner, embed=e)
# if SETTINGS.mode == 'development':
raise e
@bot.event
async def on_server_join(server):

View File

@ -0,0 +1,11 @@
import asyncio
class UniqueQueue(asyncio.Queue):
async def put_unique_id(self, item):
if not item.get('id'):
return
if item.get('id') not in [v.get('id') for v in self._queue]:
await self.put(item)