commit
ae553f2f36
@ -29,6 +29,7 @@ class Config:
|
|||||||
f'If you would like to add a trailing whitespace to the prefix, use `{pre}prefix {pre}\w`.'
|
f'If you would like to add a trailing whitespace to the prefix, use `{pre}prefix {pre}\w`.'
|
||||||
|
|
||||||
await self.bot.db.config.update_one({'_id': str(server.id)}, {'$set': {'prefix': str(pre)}}, upsert=True)
|
await self.bot.db.config.update_one({'_id': str(server.id)}, {'$set': {'prefix': str(pre)}}, upsert=True)
|
||||||
|
self.bot.pre[str(server.id)] = str(pre)
|
||||||
await self.bot.say(msg)
|
await self.bot.say(msg)
|
||||||
|
|
||||||
@commands.command(pass_context=True)
|
@commands.command(pass_context=True)
|
||||||
|
|||||||
@ -20,6 +20,7 @@ class DiscordBotsOrgAPI:
|
|||||||
while True:
|
while True:
|
||||||
logger.info('attempting to post server count')
|
logger.info('attempting to post server count')
|
||||||
try:
|
try:
|
||||||
|
if SETTINGS.mode == 'production':
|
||||||
await self.dblpy.post_server_count()
|
await self.dblpy.post_server_count()
|
||||||
logger.info('posted server count ({})'.format(len(self.bot.servers)))
|
logger.info('posted server count ({})'.format(len(self.bot.servers)))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
import asyncio
|
||||||
import copy
|
import copy
|
||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
@ -24,6 +25,11 @@ class PollControls:
|
|||||||
self.bot = bot
|
self.bot = bot
|
||||||
|
|
||||||
# General Methods
|
# General Methods
|
||||||
|
def get_lock(self, server_id):
|
||||||
|
if not self.bot.locks.get(str(server_id)):
|
||||||
|
self.bot.locks[server_id] = asyncio.Lock()
|
||||||
|
return self.bot.locks.get(str(server_id))
|
||||||
|
|
||||||
async def is_admin_or_creator(self, ctx, server, owner_id, error_msg=None):
|
async def is_admin_or_creator(self, ctx, server, owner_id, error_msg=None):
|
||||||
member = server.get_member(ctx.message.author.id)
|
member = server.get_member(ctx.message.author.id)
|
||||||
if member.id == owner_id:
|
if member.id == owner_id:
|
||||||
@ -288,11 +294,23 @@ class PollControls:
|
|||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
args = parser.parse_args(cmds)
|
args, unknown_args = parser.parse_known_args(cmds)
|
||||||
except SystemExit:
|
except SystemExit:
|
||||||
await self.say_error(ctx, error_text=helpstring)
|
await self.say_error(ctx, error_text=helpstring)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if unknown_args:
|
||||||
|
error_text = f'**There was an error reading the command line options!**.\n' \
|
||||||
|
f'Most likely this is because you didn\'t surround the arguments with double quotes like this: ' \
|
||||||
|
f'`{pre}cmd -q "question of the poll" -o "yes, no, maybe"`' \
|
||||||
|
f'\n\nHere are the arguments I could not understand:\n'
|
||||||
|
error_text += '`'+'\n'.join(unknown_args)+'`'
|
||||||
|
error_text += f'\n\nHere are the arguments which are ok:\n'
|
||||||
|
error_text += '`' + '\n'.join([f'{k}: {v}' for k, v in vars(args).items()]) + '`'
|
||||||
|
|
||||||
|
await self.say_error(ctx, error_text=error_text, footer_text=f'type `{pre}cmd help` for details.')
|
||||||
|
return
|
||||||
|
|
||||||
# pass arguments to the wizard
|
# pass arguments to the wizard
|
||||||
async def route(poll):
|
async def route(poll):
|
||||||
await poll.set_name(force=args.question)
|
await poll.set_name(force=args.question)
|
||||||
@ -306,7 +324,7 @@ class PollControls:
|
|||||||
|
|
||||||
poll = await self.wizard(ctx, route, server)
|
poll = await self.wizard(ctx, route, server)
|
||||||
if poll:
|
if poll:
|
||||||
await poll.post_embed()
|
await poll.post_embed(destination=poll.channel)
|
||||||
|
|
||||||
|
|
||||||
@commands.command(pass_context=True)
|
@commands.command(pass_context=True)
|
||||||
@ -328,7 +346,7 @@ class PollControls:
|
|||||||
|
|
||||||
poll = await self.wizard(ctx, route, server)
|
poll = await self.wizard(ctx, route, server)
|
||||||
if poll:
|
if poll:
|
||||||
await poll.post_embed()
|
await poll.post_embed(destination=poll.channel)
|
||||||
|
|
||||||
@commands.command(pass_context=True)
|
@commands.command(pass_context=True)
|
||||||
async def prepare(self, ctx, *, cmd=None):
|
async def prepare(self, ctx, *, cmd=None):
|
||||||
@ -371,7 +389,7 @@ class PollControls:
|
|||||||
|
|
||||||
poll = await self.wizard(ctx, route, server)
|
poll = await self.wizard(ctx, route, server)
|
||||||
if poll:
|
if poll:
|
||||||
await poll.post_embed()
|
await poll.post_embed(destination=poll.channel)
|
||||||
|
|
||||||
# The Wizard!
|
# The Wizard!
|
||||||
async def wizard(self, ctx, route, server):
|
async def wizard(self, ctx, route, server):
|
||||||
@ -460,6 +478,9 @@ class PollControls:
|
|||||||
user_msg = copy.deepcopy(message)
|
user_msg = copy.deepcopy(message)
|
||||||
user_msg.author = user
|
user_msg.author = user
|
||||||
server = await ask_for_server(self.bot, user_msg, label)
|
server = await ask_for_server(self.bot, user_msg, label)
|
||||||
|
|
||||||
|
# this is exclusive
|
||||||
|
async with self.get_lock(server.id):
|
||||||
p = await Poll.load_from_db(self.bot, server.id, label)
|
p = await Poll.load_from_db(self.bot, server.id, label)
|
||||||
if not isinstance(p, Poll):
|
if not isinstance(p, Poll):
|
||||||
return
|
return
|
||||||
@ -510,6 +531,10 @@ class PollControls:
|
|||||||
user_msg = copy.deepcopy(message)
|
user_msg = copy.deepcopy(message)
|
||||||
user_msg.author = user
|
user_msg.author = user
|
||||||
server = await ask_for_server(self.bot, user_msg, label)
|
server = await ask_for_server(self.bot, user_msg, label)
|
||||||
|
|
||||||
|
# this is exclusive to keep database access sequential
|
||||||
|
# hopefully it will scale well enough or I need a different solution
|
||||||
|
async with self.get_lock(server.id):
|
||||||
p = await Poll.load_from_db(self.bot, server.id, label)
|
p = await Poll.load_from_db(self.bot, server.id, label)
|
||||||
if not isinstance(p, Poll):
|
if not isinstance(p, Poll):
|
||||||
return
|
return
|
||||||
|
|||||||
@ -11,7 +11,7 @@ async def get_pre(bot, message):
|
|||||||
if str(message.channel.type) == 'private':
|
if str(message.channel.type) == 'private':
|
||||||
shared_server_list = await get_servers(bot, message)
|
shared_server_list = await get_servers(bot, message)
|
||||||
if shared_server_list.__len__() == 0:
|
if shared_server_list.__len__() == 0:
|
||||||
return 'pm!!'
|
return 'pm!'
|
||||||
elif shared_server_list.__len__() == 1:
|
elif shared_server_list.__len__() == 1:
|
||||||
return await get_server_pre(bot, shared_server_list[0])
|
return await get_server_pre(bot, shared_server_list[0])
|
||||||
else:
|
else:
|
||||||
@ -24,12 +24,13 @@ async def get_pre(bot, message):
|
|||||||
async def get_server_pre(bot, server):
|
async def get_server_pre(bot, server):
|
||||||
'''Gets the prefix for a server.'''
|
'''Gets the prefix for a server.'''
|
||||||
try:
|
try:
|
||||||
result = await bot.db.config.find_one({'_id': str(server.id)})
|
#result = await bot.db.config.find_one({'_id': str(server.id)})
|
||||||
|
result = bot.pre[str(server.id)]
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
return 'pm!'
|
return 'pm!'
|
||||||
if not result or not result.get('prefix'):
|
if not result: #or not result.get('prefix'):
|
||||||
return 'pm!!'
|
return 'pm!'
|
||||||
return result.get('prefix')
|
return result #result.get('prefix')
|
||||||
|
|
||||||
|
|
||||||
async def get_servers(bot, message, short=None):
|
async def get_servers(bot, message, short=None):
|
||||||
|
|||||||
@ -22,6 +22,7 @@ class Settings:
|
|||||||
self.dbl_token = SECRETS.dbl_token
|
self.dbl_token = SECRETS.dbl_token
|
||||||
self.mongo_db = SECRETS.mongo_db
|
self.mongo_db = SECRETS.mongo_db
|
||||||
self.bot_token = SECRETS.bot_token
|
self.bot_token = SECRETS.bot_token
|
||||||
|
self.mode = SECRETS.mode
|
||||||
|
|
||||||
|
|
||||||
SETTINGS = Settings()
|
SETTINGS = Settings()
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
import traceback
|
import traceback
|
||||||
import logging
|
import logging
|
||||||
import aiohttp
|
import aiohttp
|
||||||
@ -74,6 +75,10 @@ async def on_ready():
|
|||||||
except:
|
except:
|
||||||
print("Problem verifying servers.")
|
print("Problem verifying servers.")
|
||||||
|
|
||||||
|
# cache prefixes
|
||||||
|
bot.pre = {entry['_id']:entry['prefix'] async for entry in bot.db.config.find({}, {'_id', 'prefix'})}
|
||||||
|
bot.locks = {sid: asyncio.Lock() for sid in [s.id for s in bot.servers]}
|
||||||
|
|
||||||
print("Servers verified. Bot running.")
|
print("Servers verified. Bot running.")
|
||||||
|
|
||||||
|
|
||||||
@ -97,6 +102,8 @@ async def on_command_error(e, ctx):
|
|||||||
|
|
||||||
# log error
|
# log error
|
||||||
logger.error(f'{type(e).__name__}: {e}\n{"".join(traceback.format_tb(e.__traceback__))}')
|
logger.error(f'{type(e).__name__}: {e}\n{"".join(traceback.format_tb(e.__traceback__))}')
|
||||||
|
if SETTINGS.mode == 'development':
|
||||||
|
raise e
|
||||||
|
|
||||||
if SETTINGS.msg_errors:
|
if SETTINGS.msg_errors:
|
||||||
# send discord message for unexpected errors
|
# send discord message for unexpected errors
|
||||||
@ -119,6 +126,7 @@ async def on_server_join(server):
|
|||||||
{'$set': {'prefix': 'pm!', 'admin_role': 'polladmin', 'user_role': 'polluser'}},
|
{'$set': {'prefix': 'pm!', 'admin_role': 'polladmin', 'user_role': 'polluser'}},
|
||||||
upsert=True
|
upsert=True
|
||||||
)
|
)
|
||||||
|
bot.pre[str(server.id)] = 'pm!'
|
||||||
|
|
||||||
|
|
||||||
bot.run(SETTINGS.bot_token, reconnect=True)
|
bot.run(SETTINGS.bot_token, reconnect=True)
|
||||||
Loading…
Reference in New Issue
Block a user