a lot of work on scaling the bot
This commit is contained in:
parent
654383b116
commit
6189aaf39a
71
cogs/poll.py
71
cogs/poll.py
@ -43,6 +43,8 @@ class Poll:
|
|||||||
if channel is None:
|
if channel is None:
|
||||||
channel = ctx.message.channel
|
channel = ctx.message.channel
|
||||||
|
|
||||||
|
self.id = None
|
||||||
|
|
||||||
self.author = ctx.message.author
|
self.author = ctx.message.author
|
||||||
|
|
||||||
self.server = server
|
self.server = server
|
||||||
@ -840,6 +842,7 @@ class Poll:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
async def from_dict(self, d):
|
async def from_dict(self, d):
|
||||||
|
self.id = d['_id']
|
||||||
self.server = self.bot.get_server(str(d['server_id']))
|
self.server = self.bot.get_server(str(d['server_id']))
|
||||||
self.channel = self.bot.get_channel(str(d['channel_id']))
|
self.channel = self.bot.get_channel(str(d['channel_id']))
|
||||||
self.author = await self.bot.get_user_info(str(d['author']))
|
self.author = await self.bot.get_user_info(str(d['author']))
|
||||||
@ -1104,7 +1107,7 @@ class Poll:
|
|||||||
else:
|
else:
|
||||||
return sum([1 for c in [u for u in self.votes] if option in self.votes[c]['choices']])
|
return sum([1 for c in [u for u in self.votes] if option in self.votes[c]['choices']])
|
||||||
|
|
||||||
async def vote(self, user, option, message):
|
async def vote(self, user, option, message, lock):
|
||||||
if not await self.is_open():
|
if not await self.is_open():
|
||||||
# refresh to show closed poll
|
# refresh to show closed poll
|
||||||
await self.bot.edit_message(message, embed=await self.generate_embed())
|
await self.bot.edit_message(message, embed=await self.generate_embed())
|
||||||
@ -1114,7 +1117,7 @@ class Poll:
|
|||||||
return
|
return
|
||||||
|
|
||||||
choice = 'invalid'
|
choice = 'invalid'
|
||||||
refresh_poll = True
|
# refresh_poll = True
|
||||||
|
|
||||||
# get weight
|
# get weight
|
||||||
weight = 1
|
weight = 1
|
||||||
@ -1129,7 +1132,6 @@ class Poll:
|
|||||||
else:
|
else:
|
||||||
self.votes[user.id]['weight'] = weight
|
self.votes[user.id]['weight'] = weight
|
||||||
|
|
||||||
if self.reaction:
|
|
||||||
if self.options_reaction_default:
|
if self.options_reaction_default:
|
||||||
if option in self.options_reaction:
|
if option in self.options_reaction:
|
||||||
choice = self.options_reaction.index(option)
|
choice = self.options_reaction.index(option)
|
||||||
@ -1138,13 +1140,12 @@ class Poll:
|
|||||||
choice = AZ_EMOJIS.index(option)
|
choice = AZ_EMOJIS.index(option)
|
||||||
|
|
||||||
if choice != 'invalid':
|
if choice != 'invalid':
|
||||||
if self.multiple_choice != 1: # more than 1 choice (0 = no limit)
|
# if self.multiple_choice != 1: # more than 1 choice (0 = no limit)
|
||||||
if choice in self.votes[user.id]['choices']:
|
if choice in self.votes[user.id]['choices']:
|
||||||
if self.anonymous:
|
if self.anonymous:
|
||||||
# anonymous multiple choice -> can't unreact so we toggle with react
|
# anonymous multiple choice -> can't unreact so we toggle with react
|
||||||
await self.unvote(user, option, message)
|
await self.unvote(user, option, message, lock)
|
||||||
return
|
# refresh_poll = False
|
||||||
refresh_poll = False
|
|
||||||
else:
|
else:
|
||||||
if self.multiple_choice > 0 and self.votes[user.id]['choices'].__len__() >= self.multiple_choice:
|
if self.multiple_choice > 0 and self.votes[user.id]['choices'].__len__() >= self.multiple_choice:
|
||||||
say_text = f'You have reached the **maximum choices of {self.multiple_choice}** for this poll. ' \
|
say_text = f'You have reached the **maximum choices of {self.multiple_choice}** for this poll. ' \
|
||||||
@ -1152,32 +1153,48 @@ class Poll:
|
|||||||
embed = discord.Embed(title='', description=say_text, colour=SETTINGS.color)
|
embed = discord.Embed(title='', description=say_text, colour=SETTINGS.color)
|
||||||
embed.set_author(name='Pollmaster', icon_url=SETTINGS.author_icon)
|
embed.set_author(name='Pollmaster', icon_url=SETTINGS.author_icon)
|
||||||
await self.bot.send_message(user, embed=embed)
|
await self.bot.send_message(user, embed=embed)
|
||||||
refresh_poll = False
|
# refresh_poll = False
|
||||||
else:
|
else:
|
||||||
self.votes[user.id]['choices'].append(choice)
|
self.votes[user.id]['choices'].append(choice)
|
||||||
self.votes[user.id]['choices'] = list(set(self.votes[user.id]['choices']))
|
self.votes[user.id]['choices'] = list(set(self.votes[user.id]['choices']))
|
||||||
|
# else:
|
||||||
|
# if [choice] == self.votes[user.id]['choices']:
|
||||||
|
# # refresh_poll = False
|
||||||
|
# # if self.anonymous:
|
||||||
|
# # undo anonymous vote
|
||||||
|
# await self.unvote(user, option, message, lock)
|
||||||
|
# return
|
||||||
|
# else:
|
||||||
|
# self.votes[user.id]['choices'] = [choice]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if [choice] == self.votes[user.id]['choices']:
|
# unknow emoji
|
||||||
refresh_poll = False
|
|
||||||
if self.anonymous:
|
|
||||||
# undo anonymous vote
|
|
||||||
await self.unvote(user, option, message)
|
|
||||||
return
|
return
|
||||||
else:
|
|
||||||
self.votes[user.id]['choices'] = [choice]
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# commit
|
# commit
|
||||||
|
if lock._waiters.__len__() == 0:
|
||||||
|
# updating DB, clearing cache and refresh if necessary
|
||||||
await self.save_to_db()
|
await self.save_to_db()
|
||||||
|
await self.bot.poll_refresh_q.put_unique_id(
|
||||||
|
{'id': self.id, 'msg': message, 'sid': self.server.id, 'label': self.short, 'lock': lock})
|
||||||
|
if self.bot.poll_cache.get(str(self.server.id) + self.short):
|
||||||
|
del self.bot.poll_cache[str(self.server.id) + self.short]
|
||||||
# refresh
|
# refresh
|
||||||
if refresh_poll:
|
# if refresh_poll:
|
||||||
# edit message if there is a real change
|
# edit message if there is a real change
|
||||||
await self.bot.edit_message(message, embed=await self.generate_embed())
|
# await self.bot.edit_message(message, embed=await self.generate_embed())
|
||||||
|
# self.bot.poll_refresh_q.append(str(self.id))
|
||||||
|
else:
|
||||||
|
# cache the poll until the queue is empty
|
||||||
|
self.bot.poll_cache[str(self.server.id)+self.short] = self
|
||||||
|
|
||||||
|
# if refresh_poll:
|
||||||
|
# await self.bot.poll_refresh_q.put_unique_id({'id': self.id, 'msg': message, 'sid': self.server.id, 'label': self.short, 'lock': lock})
|
||||||
|
|
||||||
|
|
||||||
async def unvote(self, user, option, message):
|
|
||||||
|
|
||||||
|
async def unvote(self, user, option, message, lock):
|
||||||
if not await self.is_open():
|
if not await self.is_open():
|
||||||
# refresh to show closed poll
|
# refresh to show closed poll
|
||||||
await self.bot.edit_message(message, embed=await self.generate_embed())
|
await self.bot.edit_message(message, embed=await self.generate_embed())
|
||||||
@ -1189,6 +1206,7 @@ class Poll:
|
|||||||
if str(user.id) not in self.votes: return
|
if str(user.id) not in self.votes: return
|
||||||
|
|
||||||
choice = 'invalid'
|
choice = 'invalid'
|
||||||
|
|
||||||
if self.options_reaction_default:
|
if self.options_reaction_default:
|
||||||
if option in self.options_reaction:
|
if option in self.options_reaction:
|
||||||
choice = self.options_reaction.index(option)
|
choice = self.options_reaction.index(option)
|
||||||
@ -1199,8 +1217,17 @@ class Poll:
|
|||||||
if choice != 'invalid' and choice in self.votes[user.id]['choices']:
|
if choice != 'invalid' and choice in self.votes[user.id]['choices']:
|
||||||
try:
|
try:
|
||||||
self.votes[user.id]['choices'].remove(choice)
|
self.votes[user.id]['choices'].remove(choice)
|
||||||
|
if lock._waiters.__len__() == 0:
|
||||||
|
# updating DB, clearing cache and refreshing message
|
||||||
await self.save_to_db()
|
await self.save_to_db()
|
||||||
await self.bot.edit_message(message, embed=await self.generate_embed())
|
await self.bot.poll_refresh_q.put_unique_id(
|
||||||
|
{'id': self.id, 'msg': message, 'sid': self.server.id, 'label': self.short, 'lock': lock})
|
||||||
|
if self.bot.poll_cache.get(str(self.server.id) + self.short):
|
||||||
|
del self.bot.poll_cache[str(self.server.id) + self.short]
|
||||||
|
# await self.bot.edit_message(message, embed=await self.generate_embed())
|
||||||
|
else:
|
||||||
|
# cache the poll until the queue is empty
|
||||||
|
self.bot.poll_cache[str(self.server.id) + self.short] = self
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import datetime
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import shlex
|
import shlex
|
||||||
|
import traceback
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
import pytz
|
import pytz
|
||||||
@ -23,8 +24,30 @@ from essentials.exceptions import StopWizard
|
|||||||
class PollControls:
|
class PollControls:
|
||||||
def __init__(self, bot):
|
def __init__(self, bot):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
|
self.bot.loop.create_task(self.refresh_polls())
|
||||||
|
self.ignore_next_removed_reaction = {}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# General Methods
|
# General Methods
|
||||||
|
async def refresh_polls(self):
|
||||||
|
"""This function runs every 5 seconds to refresh poll messages when needed"""
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
for i in range(self.bot.poll_refresh_q.qsize()):
|
||||||
|
values = await self.bot.poll_refresh_q.get()
|
||||||
|
if values.get('lock') and not values.get('lock')._waiters:
|
||||||
|
p = await Poll.load_from_db(self.bot, str(values.get('sid')), values.get('label'))
|
||||||
|
if p:
|
||||||
|
await self.bot.edit_message(values.get('msg'), embed=await p.generate_embed())
|
||||||
|
|
||||||
|
self.bot.poll_refresh_q.task_done()
|
||||||
|
else:
|
||||||
|
await self.bot.poll_refresh_q.put_unique_id(values)
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
await asyncio.sleep(5)
|
||||||
|
|
||||||
def get_lock(self, server_id):
|
def get_lock(self, server_id):
|
||||||
if not self.bot.locks.get(str(server_id)):
|
if not self.bot.locks.get(str(server_id)):
|
||||||
self.bot.locks[server_id] = asyncio.Lock()
|
self.bot.locks[server_id] = asyncio.Lock()
|
||||||
@ -455,10 +478,17 @@ class PollControls:
|
|||||||
if not emoji:
|
if not emoji:
|
||||||
return
|
return
|
||||||
|
|
||||||
# check if we can find a poll label
|
# check if removed by the bot.. this is a bit hacky but discord doesn't provide the correct info...
|
||||||
message_id = data.get('message_id')
|
message_id = data.get('message_id')
|
||||||
channel_id = data.get('channel_id')
|
|
||||||
user_id = data.get('user_id')
|
user_id = data.get('user_id')
|
||||||
|
if self.ignore_next_removed_reaction.get(str(message_id)+str(emoji)) == user_id:
|
||||||
|
del self.ignore_next_removed_reaction[str(message_id)+str(emoji)]
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
# check if we can find a poll label
|
||||||
|
channel_id = data.get('channel_id')
|
||||||
|
|
||||||
channel = self.bot.get_channel(channel_id)
|
channel = self.bot.get_channel(channel_id)
|
||||||
user = await self.bot.get_user_info(user_id) # only do this once
|
user = await self.bot.get_user_info(user_id) # only do this once
|
||||||
if not channel:
|
if not channel:
|
||||||
@ -485,7 +515,12 @@ class PollControls:
|
|||||||
server = await ask_for_server(self.bot, user_msg, label)
|
server = await ask_for_server(self.bot, user_msg, label)
|
||||||
|
|
||||||
# this is exclusive
|
# this is exclusive
|
||||||
async with self.get_lock(server.id):
|
|
||||||
|
lock = self.get_lock(server.id)
|
||||||
|
async with lock:
|
||||||
|
# try to load poll form cache
|
||||||
|
p = self.bot.poll_cache.get(str(server.id) + label)
|
||||||
|
if not p:
|
||||||
p = await Poll.load_from_db(self.bot, server.id, label)
|
p = await Poll.load_from_db(self.bot, server.id, label)
|
||||||
if not isinstance(p, Poll):
|
if not isinstance(p, Poll):
|
||||||
return
|
return
|
||||||
@ -493,7 +528,7 @@ class PollControls:
|
|||||||
if not p.anonymous:
|
if not p.anonymous:
|
||||||
# for anonymous polls we can't unvote because we need to hide reactions
|
# for anonymous polls we can't unvote because we need to hide reactions
|
||||||
member = server.get_member(user_id)
|
member = server.get_member(user_id)
|
||||||
await p.unvote(member, emoji, message)
|
await p.unvote(member, emoji, message, lock)
|
||||||
|
|
||||||
|
|
||||||
async def do_on_reaction_add(self, data):
|
async def do_on_reaction_add(self, data):
|
||||||
@ -539,7 +574,10 @@ class PollControls:
|
|||||||
|
|
||||||
# this is exclusive to keep database access sequential
|
# this is exclusive to keep database access sequential
|
||||||
# hopefully it will scale well enough or I need a different solution
|
# hopefully it will scale well enough or I need a different solution
|
||||||
async with self.get_lock(server.id):
|
lock = self.get_lock(server.id)
|
||||||
|
async with lock:
|
||||||
|
p = self.bot.poll_cache.get(str(server.id)+label)
|
||||||
|
if not p:
|
||||||
p = await Poll.load_from_db(self.bot, server.id, label)
|
p = await Poll.load_from_db(self.bot, server.id, label)
|
||||||
if not isinstance(p, Poll):
|
if not isinstance(p, Poll):
|
||||||
return
|
return
|
||||||
@ -623,20 +661,26 @@ class PollControls:
|
|||||||
f'at least one of these roles can vote:\n{", ".join(p.roles)}')
|
f'at least one of these roles can vote:\n{", ".join(p.roles)}')
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# check if we need to remove reactions (this will trigger on_reaction_remove)
|
||||||
|
if str(channel.type) != 'private' and p.anonymous:
|
||||||
|
# immediately remove reaction and to be safe, remove all reactions
|
||||||
|
self.ignore_next_removed_reaction[str(message.id)+str(emoji)] = user_id
|
||||||
|
await self.bot.remove_reaction(message, emoji, user)
|
||||||
|
|
||||||
# order here is crucial since we can't determine if a reaction was removed by the bot or user
|
# order here is crucial since we can't determine if a reaction was removed by the bot or user
|
||||||
# update database with vote
|
# update database with vote
|
||||||
await p.vote(member, emoji, message)
|
await p.vote(member, emoji, message, lock)
|
||||||
#
|
|
||||||
# check if we need to remove reactions (this will trigger on_reaction_remove)
|
# cant do this until we figure out how to see who removed the reaction?
|
||||||
if str(channel.type) != 'private':
|
# for now MC 1 is like MC x
|
||||||
if p.anonymous:
|
# if str(channel.type) != 'private' and p.multiple_choice == 1:
|
||||||
# immediately remove reaction and to be safe, remove all reactions
|
# # remove all other reactions
|
||||||
await self.bot.remove_reaction(message, emoji, user)
|
# # if lock._waiters.__len__() == 0:
|
||||||
elif p.multiple_choice == 1:
|
# for r in message.reactions:
|
||||||
# remove all other reactions
|
# if r.emoji and r.emoji != emoji:
|
||||||
for r in message.reactions:
|
# await self.bot.remove_reaction(message, r.emoji, user)
|
||||||
if r.emoji and r.emoji != emoji:
|
# pass
|
||||||
await self.bot.remove_reaction(message, r.emoji, user)
|
|
||||||
|
|
||||||
def setup(bot):
|
def setup(bot):
|
||||||
global logger
|
global logger
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
import logging
|
import logging
|
||||||
import aiohttp
|
import aiohttp
|
||||||
@ -9,6 +10,7 @@ from motor.motor_asyncio import AsyncIOMotorClient
|
|||||||
|
|
||||||
from essentials.multi_server import get_pre
|
from essentials.multi_server import get_pre
|
||||||
from essentials.settings import SETTINGS
|
from essentials.settings import SETTINGS
|
||||||
|
from utils.asyncio_unique_queue import UniqueQueue
|
||||||
from utils.import_old_database import import_old_database
|
from utils.import_old_database import import_old_database
|
||||||
|
|
||||||
bot_config = {
|
bot_config = {
|
||||||
@ -77,11 +79,18 @@ async def on_ready():
|
|||||||
|
|
||||||
# cache prefixes
|
# cache prefixes
|
||||||
bot.pre = {entry['_id']: entry['prefix'] async for entry in bot.db.config.find({}, {'_id', 'prefix'})}
|
bot.pre = {entry['_id']: entry['prefix'] async for entry in bot.db.config.find({}, {'_id', 'prefix'})}
|
||||||
bot.locks = {sid: asyncio.Lock() for sid in [s.id for s in bot.servers]}
|
|
||||||
|
# global locks and caches for performance when voting rapidly
|
||||||
|
bot.locks = {}
|
||||||
|
bot.poll_cache = {}
|
||||||
|
# bot.poll_refresh_q = {}
|
||||||
|
bot.poll_refresh_q = UniqueQueue()
|
||||||
|
|
||||||
print("Servers verified. Bot running.")
|
print("Servers verified. Bot running.")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@bot.event
|
@bot.event
|
||||||
async def on_command_error(e, ctx):
|
async def on_command_error(e, ctx):
|
||||||
if SETTINGS.log_errors:
|
if SETTINGS.log_errors:
|
||||||
@ -102,8 +111,6 @@ async def on_command_error(e, ctx):
|
|||||||
|
|
||||||
# log error
|
# log error
|
||||||
logger.error(f'{type(e).__name__}: {e}\n{"".join(traceback.format_tb(e.__traceback__))}')
|
logger.error(f'{type(e).__name__}: {e}\n{"".join(traceback.format_tb(e.__traceback__))}')
|
||||||
if SETTINGS.mode == 'development':
|
|
||||||
raise e
|
|
||||||
traceback.print_exception(type(e), e, e.__traceback__, file=sys.stderr)
|
traceback.print_exception(type(e), e, e.__traceback__, file=sys.stderr)
|
||||||
|
|
||||||
if SETTINGS.msg_errors:
|
if SETTINGS.msg_errors:
|
||||||
@ -117,6 +124,9 @@ async def on_command_error(e, ctx):
|
|||||||
)
|
)
|
||||||
await bot.send_message(bot.owner, embed=e)
|
await bot.send_message(bot.owner, embed=e)
|
||||||
|
|
||||||
|
# if SETTINGS.mode == 'development':
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
@bot.event
|
@bot.event
|
||||||
async def on_server_join(server):
|
async def on_server_join(server):
|
||||||
|
|||||||
11
utils/asyncio_unique_queue.py
Normal file
11
utils/asyncio_unique_queue.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
|
class UniqueQueue(asyncio.Queue):
|
||||||
|
|
||||||
|
async def put_unique_id(self, item):
|
||||||
|
if not item.get('id'):
|
||||||
|
return
|
||||||
|
|
||||||
|
if item.get('id') not in [v.get('id') for v in self._queue]:
|
||||||
|
await self.put(item)
|
||||||
Loading…
Reference in New Issue
Block a user