Add in bot, personify for RBXLegacy l8r

Oh also it's pre installed with all requirements 😃
This commit is contained in:
Quacky 2017-07-22 22:54:59 -05:00
parent 4344ad3582
commit 8fda0bc62e
1417 changed files with 341934 additions and 0 deletions

8
RBXLegacyDiscordBot/.gitignore vendored Normal file
View File

@ -0,0 +1,8 @@
*.json
*.pyc
__pycache__
data
!data/trivia/*
!data/audio/playlists/*
*.exe
*.dll

View File

@ -0,0 +1,12 @@
language: python
python:
- "3.5.2"
install:
- pip install -r requirements.txt
script:
- python -m compileall ./red.py
- python -m compileall ./cogs
- python ./red.py --no-prompt --no-cogs --dry-run
cache: pip
notifications:
email: false

View File

@ -0,0 +1,191 @@
from discord.ext import commands
from .utils.chat_formatting import box
from .utils.dataIO import dataIO
from .utils import checks
from __main__ import user_allowed, send_cmd_help
from copy import deepcopy
import os
import discord
class Alias:
def __init__(self, bot):
self.bot = bot
self.file_path = "data/alias/aliases.json"
self.aliases = dataIO.load_json(self.file_path)
self.remove_old()
@commands.group(pass_context=True, no_pm=True)
async def alias(self, ctx):
"""Manage per-server aliases for commands"""
if ctx.invoked_subcommand is None:
await send_cmd_help(ctx)
@alias.command(name="add", pass_context=True, no_pm=True)
@checks.mod_or_permissions(manage_server=True)
async def _add_alias(self, ctx, command, *, to_execute):
"""Add an alias for a command
Example: !alias add test flip @Twentysix"""
server = ctx.message.server
command = command.lower()
if len(command.split(" ")) != 1:
await self.bot.say("I can't safely do multi-word aliases because"
" of the fact that I allow arguments to"
" aliases. It sucks, I know, deal with it.")
return
if self.part_of_existing_command(command, server.id):
await self.bot.say('I can\'t safely add an alias that starts with '
'an existing command or alias. Sry <3')
return
prefix = self.get_prefix(server, to_execute)
if prefix is not None:
to_execute = to_execute[len(prefix):]
if server.id not in self.aliases:
self.aliases[server.id] = {}
if command not in self.bot.commands:
self.aliases[server.id][command] = to_execute
dataIO.save_json(self.file_path, self.aliases)
await self.bot.say("Alias '{}' added.".format(command))
else:
await self.bot.say("Cannot add '{}' because it's a real bot "
"command.".format(command))
@alias.command(name="help", pass_context=True, no_pm=True)
async def _help_alias(self, ctx, command):
"""Tries to execute help for the base command of the alias"""
server = ctx.message.server
if server.id in self.aliases:
server_aliases = self.aliases[server.id]
if command in server_aliases:
help_cmd = server_aliases[command].split(" ")[0]
new_content = self.bot.settings.get_prefixes(server)[0]
new_content += "help "
new_content += help_cmd[len(self.get_prefix(server,
help_cmd)):]
message = ctx.message
message.content = new_content
await self.bot.process_commands(message)
else:
await self.bot.say("That alias doesn't exist.")
@alias.command(name="show", pass_context=True, no_pm=True)
async def _show_alias(self, ctx, command):
"""Shows what command the alias executes."""
server = ctx.message.server
if server.id in self.aliases:
server_aliases = self.aliases[server.id]
if command in server_aliases:
await self.bot.say(box(server_aliases[command]))
else:
await self.bot.say("That alias doesn't exist.")
@alias.command(name="del", pass_context=True, no_pm=True)
@checks.mod_or_permissions(manage_server=True)
async def _del_alias(self, ctx, command):
"""Deletes an alias"""
command = command.lower()
server = ctx.message.server
if server.id in self.aliases:
self.aliases[server.id].pop(command, None)
dataIO.save_json(self.file_path, self.aliases)
await self.bot.say("Alias '{}' deleted.".format(command))
@alias.command(name="list", pass_context=True, no_pm=True)
async def _alias_list(self, ctx):
"""Lists aliases available on this server
Responds in DM"""
server = ctx.message.server
if server.id in self.aliases:
message = "```Alias list:\n"
for alias in sorted(self.aliases[server.id]):
if len(message) + len(alias) + 3 > 2000:
await self.bot.whisper(message)
message = "```\n"
message += "\t{}\n".format(alias)
if message != "```Alias list:\n":
message += "```"
await self.bot.whisper(message)
else:
await self.bot.say("There are no aliases on this server.")
async def on_message(self, message):
if len(message.content) < 2 or message.channel.is_private:
return
msg = message.content
server = message.server
prefix = self.get_prefix(server, msg)
if not prefix:
return
if server.id in self.aliases and user_allowed(message):
alias = self.first_word(msg[len(prefix):]).lower()
if alias in self.aliases[server.id]:
new_command = self.aliases[server.id][alias]
args = message.content[len(prefix + alias):]
new_message = deepcopy(message)
new_message.content = prefix + new_command + args
await self.bot.process_commands(new_message)
def part_of_existing_command(self, alias, server):
'''Command or alias'''
for command in self.bot.commands:
if alias.lower() == command.lower():
return True
return False
def remove_old(self):
for sid in self.aliases:
to_delete = []
to_add = []
for aliasname, alias in self.aliases[sid].items():
lower = aliasname.lower()
if aliasname != lower:
to_delete.append(aliasname)
to_add.append((lower, alias))
if aliasname != self.first_word(aliasname):
to_delete.append(aliasname)
continue
server = discord.Object(id=sid)
prefix = self.get_prefix(server, alias)
if prefix is not None:
self.aliases[sid][aliasname] = alias[len(prefix):]
for alias in to_delete: # Fixes caps and bad prefixes
del self.aliases[sid][alias]
for alias, command in to_add: # For fixing caps
self.aliases[sid][alias] = command
dataIO.save_json(self.file_path, self.aliases)
def first_word(self, msg):
return msg.split(" ")[0]
def get_prefix(self, server, msg):
prefixes = self.bot.settings.get_prefixes(server)
for p in prefixes:
if msg.startswith(p):
return p
return None
def check_folder():
if not os.path.exists("data/alias"):
print("Creating data/alias folder...")
os.makedirs("data/alias")
def check_file():
aliases = {}
f = "data/alias/aliases.json"
if not dataIO.is_valid_json(f):
print("Creating default alias's aliases.json...")
dataIO.save_json(f, aliases)
def setup(bot):
check_folder()
check_file()
bot.add_cog(Alias(bot))

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,200 @@
from discord.ext import commands
from .utils.dataIO import dataIO
from .utils import checks
from .utils.chat_formatting import pagify, box
import os
import re
class CustomCommands:
"""Custom commands
Creates commands used to display text"""
def __init__(self, bot):
self.bot = bot
self.file_path = "data/customcom/commands.json"
self.c_commands = dataIO.load_json(self.file_path)
@commands.group(aliases=["cc"], pass_context=True, no_pm=True)
async def customcom(self, ctx):
"""Custom commands management"""
if ctx.invoked_subcommand is None:
await self.bot.send_cmd_help(ctx)
@customcom.command(name="add", pass_context=True)
@checks.mod_or_permissions(administrator=True)
async def cc_add(self, ctx, command : str, *, text):
"""Adds a custom command
Example:
[p]customcom add yourcommand Text you want
CCs can be enhanced with arguments:
https://twentysix26.github.io/Red-Docs/red_guide_command_args/
"""
server = ctx.message.server
command = command.lower()
if command in self.bot.commands:
await self.bot.say("That command is already a standard command.")
return
if server.id not in self.c_commands:
self.c_commands[server.id] = {}
cmdlist = self.c_commands[server.id]
if command not in cmdlist:
cmdlist[command] = text
self.c_commands[server.id] = cmdlist
dataIO.save_json(self.file_path, self.c_commands)
await self.bot.say("Custom command successfully added.")
else:
await self.bot.say("This command already exists. Use "
"`{}customcom edit` to edit it."
"".format(ctx.prefix))
@customcom.command(name="edit", pass_context=True)
@checks.mod_or_permissions(administrator=True)
async def cc_edit(self, ctx, command : str, *, text):
"""Edits a custom command
Example:
[p]customcom edit yourcommand Text you want
"""
server = ctx.message.server
command = command.lower()
if server.id in self.c_commands:
cmdlist = self.c_commands[server.id]
if command in cmdlist:
cmdlist[command] = text
self.c_commands[server.id] = cmdlist
dataIO.save_json(self.file_path, self.c_commands)
await self.bot.say("Custom command successfully edited.")
else:
await self.bot.say("That command doesn't exist. Use "
"`{}customcom add` to add it."
"".format(ctx.prefix))
else:
await self.bot.say("There are no custom commands in this server."
" Use `{}customcom add` to start adding some."
"".format(ctx.prefix))
@customcom.command(name="delete", pass_context=True)
@checks.mod_or_permissions(administrator=True)
async def cc_delete(self, ctx, command : str):
"""Deletes a custom command
Example:
[p]customcom delete yourcommand"""
server = ctx.message.server
command = command.lower()
if server.id in self.c_commands:
cmdlist = self.c_commands[server.id]
if command in cmdlist:
cmdlist.pop(command, None)
self.c_commands[server.id] = cmdlist
dataIO.save_json(self.file_path, self.c_commands)
await self.bot.say("Custom command successfully deleted.")
else:
await self.bot.say("That command doesn't exist.")
else:
await self.bot.say("There are no custom commands in this server."
" Use `{}customcom add` to start adding some."
"".format(ctx.prefix))
@customcom.command(name="list", pass_context=True)
async def cc_list(self, ctx):
"""Shows custom commands list"""
server = ctx.message.server
commands = self.c_commands.get(server.id, {})
if not commands:
await self.bot.say("There are no custom commands in this server."
" Use `{}customcom add` to start adding some."
"".format(ctx.prefix))
return
commands = ", ".join([ctx.prefix + c for c in sorted(commands)])
commands = "Custom commands:\n\n" + commands
if len(commands) < 1500:
await self.bot.say(box(commands))
else:
for page in pagify(commands, delims=[" ", "\n"]):
await self.bot.whisper(box(page))
async def on_message(self, message):
if len(message.content) < 2 or message.channel.is_private:
return
server = message.server
prefix = self.get_prefix(message)
if not prefix:
return
if server.id in self.c_commands and self.bot.user_allowed(message):
cmdlist = self.c_commands[server.id]
cmd = message.content[len(prefix):]
if cmd in cmdlist:
cmd = cmdlist[cmd]
cmd = self.format_cc(cmd, message)
await self.bot.send_message(message.channel, cmd)
elif cmd.lower() in cmdlist:
cmd = cmdlist[cmd.lower()]
cmd = self.format_cc(cmd, message)
await self.bot.send_message(message.channel, cmd)
def get_prefix(self, message):
for p in self.bot.settings.get_prefixes(message.server):
if message.content.startswith(p):
return p
return False
def format_cc(self, command, message):
results = re.findall("\{([^}]+)\}", command)
for result in results:
param = self.transform_parameter(result, message)
command = command.replace("{" + result + "}", param)
return command
def transform_parameter(self, result, message):
"""
For security reasons only specific objects are allowed
Internals are ignored
"""
raw_result = "{" + result + "}"
objects = {
"message" : message,
"author" : message.author,
"channel" : message.channel,
"server" : message.server
}
if result in objects:
return str(objects[result])
try:
first, second = result.split(".")
except ValueError:
return raw_result
if first in objects and not second.startswith("_"):
first = objects[first]
else:
return raw_result
return str(getattr(first, second, raw_result))
def check_folders():
if not os.path.exists("data/customcom"):
print("Creating data/customcom folder...")
os.makedirs("data/customcom")
def check_files():
f = "data/customcom/commands.json"
if not dataIO.is_valid_json(f):
print("Creating empty commands.json...")
dataIO.save_json(f, {})
def setup(bot):
check_folders()
check_files()
bot.add_cog(CustomCommands(bot))

View File

@ -0,0 +1,693 @@
from discord.ext import commands
from cogs.utils.dataIO import dataIO
from cogs.utils import checks
from cogs.utils.chat_formatting import pagify, box
from __main__ import send_cmd_help, set_cog
import os
from subprocess import run as sp_run, PIPE
import shutil
from asyncio import as_completed
from setuptools import distutils
import discord
from functools import partial
from concurrent.futures import ThreadPoolExecutor
from time import time
from importlib.util import find_spec
from copy import deepcopy
NUM_THREADS = 4
REPO_NONEX = 0x1
REPO_CLONE = 0x2
REPO_SAME = 0x4
REPOS_LIST = "https://twentysix26.github.io/Red-Docs/red_cog_approved_repos/"
DISCLAIMER = ("You're about to add a 3rd party repository. The creator of Red"
" and its community have no responsibility for any potential "
"damage that the content of 3rd party repositories might cause."
"\nBy typing 'I agree' you declare to have read and understand "
"the above message. This message won't be shown again until the"
" next reboot.")
class UpdateError(Exception):
pass
class CloningError(UpdateError):
pass
class RequirementFail(UpdateError):
pass
class Downloader:
"""Cog downloader/installer."""
def __init__(self, bot):
self.bot = bot
self.disclaimer_accepted = False
self.path = os.path.join("data", "downloader")
self.file_path = os.path.join(self.path, "repos.json")
# {name:{url,cog1:{installed},cog1:{installed}}}
self.repos = dataIO.load_json(self.file_path)
self.executor = ThreadPoolExecutor(NUM_THREADS)
self._do_first_run()
def save_repos(self):
dataIO.save_json(self.file_path, self.repos)
@commands.group(pass_context=True)
@checks.is_owner()
async def cog(self, ctx):
"""Additional cogs management"""
if ctx.invoked_subcommand is None:
await send_cmd_help(ctx)
@cog.group(pass_context=True)
async def repo(self, ctx):
"""Repo management commands"""
if ctx.invoked_subcommand is None or \
isinstance(ctx.invoked_subcommand, commands.Group):
await send_cmd_help(ctx)
return
@repo.command(name="add", pass_context=True)
async def _repo_add(self, ctx, repo_name: str, repo_url: str):
"""Adds repo to available repo lists
Warning: Adding 3RD Party Repositories is at your own
Risk."""
if not self.disclaimer_accepted:
await self.bot.say(DISCLAIMER)
answer = await self.bot.wait_for_message(timeout=30,
author=ctx.message.author)
if answer is None:
await self.bot.say('Not adding repo.')
return
elif "i agree" not in answer.content.lower():
await self.bot.say('Not adding repo.')
return
else:
self.disclaimer_accepted = True
self.repos[repo_name] = {}
self.repos[repo_name]['url'] = repo_url
try:
self.update_repo(repo_name)
except CloningError:
await self.bot.say("That repository link doesn't seem to be "
"valid.")
del self.repos[repo_name]
return
self.populate_list(repo_name)
self.save_repos()
data = self.get_info_data(repo_name)
if data:
msg = data.get("INSTALL_MSG")
if msg:
await self.bot.say(msg[:2000])
await self.bot.say("Repo '{}' added.".format(repo_name))
@repo.command(name="remove")
async def _repo_del(self, repo_name: str):
"""Removes repo from repo list. COGS ARE NOT REMOVED."""
def remove_readonly(func, path, excinfo):
os.chmod(path, 0o755)
func(path)
if repo_name not in self.repos:
await self.bot.say("That repo doesn't exist.")
return
del self.repos[repo_name]
try:
shutil.rmtree(os.path.join(self.path, repo_name), onerror=remove_readonly)
except FileNotFoundError:
pass
self.save_repos()
await self.bot.say("Repo '{}' removed.".format(repo_name))
@cog.command(name="list")
async def _send_list(self, repo_name=None):
"""Lists installable cogs
Repositories list:
https://twentysix26.github.io/Red-Docs/red_cog_approved_repos/"""
retlist = []
if repo_name and repo_name in self.repos:
msg = "Available cogs:\n"
for cog in sorted(self.repos[repo_name].keys()):
if 'url' == cog:
continue
data = self.get_info_data(repo_name, cog)
if data and data.get("HIDDEN") is True:
continue
if data:
retlist.append([cog, data.get("SHORT", "")])
else:
retlist.append([cog, ''])
else:
if self.repos:
msg = "Available repos:\n"
for repo_name in sorted(self.repos.keys()):
data = self.get_info_data(repo_name)
if data:
retlist.append([repo_name, data.get("SHORT", "")])
else:
retlist.append([repo_name, ""])
else:
await self.bot.say("You haven't added a repository yet.\n"
"Start now! {}".format(REPOS_LIST))
return
col_width = max(len(row[0]) for row in retlist) + 2
for row in retlist:
msg += "\t" + "".join(word.ljust(col_width) for word in row) + "\n"
msg += "\nRepositories list: {}".format(REPOS_LIST)
for page in pagify(msg, delims=['\n'], shorten_by=8):
await self.bot.say(box(page))
@cog.command()
async def info(self, repo_name: str, cog: str=None):
"""Shows info about the specified cog"""
if cog is not None:
cogs = self.list_cogs(repo_name)
if cog in cogs:
data = self.get_info_data(repo_name, cog)
if data:
msg = "{} by {}\n\n".format(cog, data["AUTHOR"])
msg += data["NAME"] + "\n\n" + data["DESCRIPTION"]
await self.bot.say(box(msg))
else:
await self.bot.say("The specified cog has no info file.")
else:
await self.bot.say("That cog doesn't exist."
" Use cog list to see the full list.")
else:
data = self.get_info_data(repo_name)
if data is None:
await self.bot.say("That repo does not exist or the"
" information file is missing for that repo"
".")
return
name = data.get("NAME", None)
name = repo_name if name is None else name
author = data.get("AUTHOR", "Unknown")
desc = data.get("DESCRIPTION", "")
msg = ("```{} by {}```\n\n{}".format(name, author, desc))
await self.bot.say(msg)
@cog.command(hidden=True)
async def search(self, *terms: str):
"""Search installable cogs"""
pass # TO DO
@cog.command(pass_context=True)
async def update(self, ctx):
"""Updates cogs"""
tasknum = 0
num_repos = len(self.repos)
min_dt = 0.5
burst_inc = 0.1/(NUM_THREADS)
touch_n = tasknum
touch_t = time()
def regulate(touch_t, touch_n):
dt = time() - touch_t
if dt + burst_inc*(touch_n) > min_dt:
touch_n = 0
touch_t = time()
return True, touch_t, touch_n
return False, touch_t, touch_n + 1
tasks = []
for r in self.repos:
task = partial(self.update_repo, r)
task = self.bot.loop.run_in_executor(self.executor, task)
tasks.append(task)
base_msg = "Downloading updated cogs, please wait... "
status = ' %d/%d repos updated' % (tasknum, num_repos)
msg = await self.bot.say(base_msg + status)
updated_cogs = []
new_cogs = []
deleted_cogs = []
failed_cogs = []
error_repos = {}
installed_updated_cogs = []
for f in as_completed(tasks):
tasknum += 1
try:
name, updates, oldhash = await f
if updates:
if type(updates) is dict:
for k, l in updates.items():
tl = [(name, c, oldhash) for c in l]
if k == 'A':
new_cogs.extend(tl)
elif k == 'D':
deleted_cogs.extend(tl)
elif k == 'M':
updated_cogs.extend(tl)
except UpdateError as e:
name, what = e.args
error_repos[name] = what
edit, touch_t, touch_n = regulate(touch_t, touch_n)
if edit:
status = ' %d/%d repos updated' % (tasknum, num_repos)
msg = await self._robust_edit(msg, base_msg + status)
status = 'done. '
for t in updated_cogs:
repo, cog, _ = t
if self.repos[repo][cog]['INSTALLED']:
try:
await self.install(repo, cog,
no_install_on_reqs_fail=False)
except RequirementFail:
failed_cogs.append(t)
else:
installed_updated_cogs.append(t)
for t in updated_cogs.copy():
if t in failed_cogs:
updated_cogs.remove(t)
if not any(self.repos[repo][cog]['INSTALLED'] for
repo, cog, _ in updated_cogs):
status += ' No updates to apply. '
if new_cogs:
status += '\nNew cogs: ' \
+ ', '.join('%s/%s' % c[:2] for c in new_cogs) + '.'
if deleted_cogs:
status += '\nDeleted cogs: ' \
+ ', '.join('%s/%s' % c[:2] for c in deleted_cogs) + '.'
if updated_cogs:
status += '\nUpdated cogs: ' \
+ ', '.join('%s/%s' % c[:2] for c in updated_cogs) + '.'
if failed_cogs:
status += '\nCogs that got new requirements which have ' + \
'failed to install: ' + \
', '.join('%s/%s' % c[:2] for c in failed_cogs) + '.'
if error_repos:
status += '\nThe following repos failed to update: '
for n, what in error_repos.items():
status += '\n%s: %s' % (n, what)
msg = await self._robust_edit(msg, base_msg + status)
if not installed_updated_cogs:
return
patchnote_lang = 'Prolog'
shorten_by = 8 + len(patchnote_lang)
for note in self.patch_notes_handler(installed_updated_cogs):
if note is None:
continue
for page in pagify(note, delims=['\n'], shorten_by=shorten_by):
await self.bot.say(box(page, patchnote_lang))
await self.bot.say("Cogs updated. Reload updated cogs? (yes/no)")
answer = await self.bot.wait_for_message(timeout=15,
author=ctx.message.author)
if answer is None:
await self.bot.say("Ok then, you can reload cogs with"
" `{}reload <cog_name>`".format(ctx.prefix))
elif answer.content.lower().strip() == "yes":
registry = dataIO.load_json(os.path.join("data", "red", "cogs.json"))
update_list = []
fail_list = []
for repo, cog, _ in installed_updated_cogs:
if not registry.get('cogs.' + cog, False):
continue
try:
self.bot.unload_extension("cogs." + cog)
self.bot.load_extension("cogs." + cog)
update_list.append(cog)
except:
fail_list.append(cog)
msg = 'Done.'
if update_list:
msg += " The following cogs were reloaded: "\
+ ', '.join(update_list) + "\n"
if fail_list:
msg += " The following cogs failed to reload: "\
+ ', '.join(fail_list)
await self.bot.say(msg)
else:
await self.bot.say("Ok then, you can reload cogs with"
" `{}reload <cog_name>`".format(ctx.prefix))
def patch_notes_handler(self, repo_cog_hash_pairs):
for repo, cog, oldhash in repo_cog_hash_pairs:
repo_path = os.path.join('data', 'downloader', repo)
cogfile = os.path.join(cog, cog + ".py")
cmd = ["git", "-C", repo_path, "log", "--relative-date",
"--reverse", oldhash + '..', cogfile
]
try:
log = sp_run(cmd, stdout=PIPE).stdout.decode().strip()
yield self.format_patch(repo, cog, log)
except:
pass
@cog.command(pass_context=True)
async def uninstall(self, ctx, repo_name, cog):
"""Uninstalls a cog"""
if repo_name not in self.repos:
await self.bot.say("That repo doesn't exist.")
return
if cog not in self.repos[repo_name]:
await self.bot.say("That cog isn't available from that repo.")
return
set_cog("cogs." + cog, False)
self.repos[repo_name][cog]['INSTALLED'] = False
self.save_repos()
os.remove(os.path.join("cogs", cog + ".py"))
owner = self.bot.get_cog('Owner')
await owner.unload.callback(owner, cog_name=cog)
await self.bot.say("Cog successfully uninstalled.")
@cog.command(name="install", pass_context=True)
async def _install(self, ctx, repo_name: str, cog: str):
"""Installs specified cog"""
if repo_name not in self.repos:
await self.bot.say("That repo doesn't exist.")
return
if cog not in self.repos[repo_name]:
await self.bot.say("That cog isn't available from that repo.")
return
data = self.get_info_data(repo_name, cog)
try:
install_cog = await self.install(repo_name, cog, notify_reqs=True)
except RequirementFail:
await self.bot.say("That cog has requirements that I could not "
"install. Check the console for more "
"informations.")
return
if data is not None:
install_msg = data.get("INSTALL_MSG", None)
if install_msg:
await self.bot.say(install_msg[:2000])
if install_cog:
await self.bot.say("Installation completed. Load it now? (yes/no)")
answer = await self.bot.wait_for_message(timeout=15,
author=ctx.message.author)
if answer is None:
await self.bot.say("Ok then, you can load it with"
" `{}load {}`".format(ctx.prefix, cog))
elif answer.content.lower().strip() == "yes":
set_cog("cogs." + cog, True)
owner = self.bot.get_cog('Owner')
await owner.load.callback(owner, cog_name=cog)
else:
await self.bot.say("Ok then, you can load it with"
" `{}load {}`".format(ctx.prefix, cog))
elif install_cog is False:
await self.bot.say("Invalid cog. Installation aborted.")
else:
await self.bot.say("That cog doesn't exist. Use cog list to see"
" the full list.")
async def install(self, repo_name, cog, *, notify_reqs=False,
no_install_on_reqs_fail=True):
# 'no_install_on_reqs_fail' will make the cog get installed anyway
# on requirements installation fail. This is necessary because due to
# how 'cog update' works right now, the user would have no way to
# reupdate the cog if the update fails, since 'cog update' only
# updates the cogs that get a new commit.
# This is not a great way to deal with the problem and a cog update
# rework would probably be the best course of action.
reqs_failed = False
if cog.endswith('.py'):
cog = cog[:-3]
path = self.repos[repo_name][cog]['file']
cog_folder_path = self.repos[repo_name][cog]['folder']
cog_data_path = os.path.join(cog_folder_path, 'data')
data = self.get_info_data(repo_name, cog)
if data is not None:
requirements = data.get("REQUIREMENTS", [])
requirements = [r for r in requirements
if not self.is_lib_installed(r)]
if requirements and notify_reqs:
await self.bot.say("Installing cog's requirements...")
for requirement in requirements:
if not self.is_lib_installed(requirement):
success = await self.bot.pip_install(requirement)
if not success:
if no_install_on_reqs_fail:
raise RequirementFail()
else:
reqs_failed = True
to_path = os.path.join("cogs", cog + ".py")
print("Copying {}...".format(cog))
shutil.copy(path, to_path)
if os.path.exists(cog_data_path):
print("Copying {}'s data folder...".format(cog))
distutils.dir_util.copy_tree(cog_data_path,
os.path.join('data', cog))
self.repos[repo_name][cog]['INSTALLED'] = True
self.save_repos()
if not reqs_failed:
return True
else:
raise RequirementFail()
def get_info_data(self, repo_name, cog=None):
if cog is not None:
cogs = self.list_cogs(repo_name)
if cog in cogs:
info_file = os.path.join(cogs[cog].get('folder'), "info.json")
if os.path.isfile(info_file):
try:
data = dataIO.load_json(info_file)
except:
return None
return data
else:
repo_info = os.path.join(self.path, repo_name, 'info.json')
if os.path.isfile(repo_info):
try:
data = dataIO.load_json(repo_info)
return data
except:
return None
return None
def list_cogs(self, repo_name):
valid_cogs = {}
repo_path = os.path.join(self.path, repo_name)
folders = [f for f in os.listdir(repo_path)
if os.path.isdir(os.path.join(repo_path, f))]
legacy_path = os.path.join(repo_path, "cogs")
legacy_folders = []
if os.path.exists(legacy_path):
for f in os.listdir(legacy_path):
if os.path.isdir(os.path.join(legacy_path, f)):
legacy_folders.append(os.path.join("cogs", f))
folders = folders + legacy_folders
for f in folders:
cog_folder_path = os.path.join(self.path, repo_name, f)
cog_folder = os.path.basename(cog_folder_path)
for cog in os.listdir(cog_folder_path):
cog_path = os.path.join(cog_folder_path, cog)
if os.path.isfile(cog_path) and cog_folder == cog[:-3]:
valid_cogs[cog[:-3]] = {'folder': cog_folder_path,
'file': cog_path}
return valid_cogs
def get_dir_name(self, url):
splitted = url.split("/")
git_name = splitted[-1]
return git_name[:-4]
def is_lib_installed(self, name):
return bool(find_spec(name))
def _do_first_run(self):
save = False
repos_copy = deepcopy(self.repos)
# Issue 725
for repo in repos_copy:
for cog in repos_copy[repo]:
cog_data = repos_copy[repo][cog]
if isinstance(cog_data, str): # ... url field
continue
for k, v in cog_data.items():
if k in ("file", "folder"):
repos_copy[repo][cog][k] = os.path.normpath(cog_data[k])
if self.repos != repos_copy:
self.repos = repos_copy
save = True
invalid = []
for repo in self.repos:
broken = 'url' in self.repos[repo] and len(self.repos[repo]) == 1
if broken:
save = True
try:
self.update_repo(repo)
self.populate_list(repo)
except CloningError:
invalid.append(repo)
continue
except Exception as e:
print(e) # TODO: Proper logging
continue
for repo in invalid:
del self.repos[repo]
if save:
self.save_repos()
def populate_list(self, name):
valid_cogs = self.list_cogs(name)
new = set(valid_cogs.keys())
old = set(self.repos[name].keys())
for cog in new - old:
self.repos[name][cog] = valid_cogs.get(cog, {})
self.repos[name][cog]['INSTALLED'] = False
for cog in new & old:
self.repos[name][cog].update(valid_cogs[cog])
for cog in old - new:
if cog != 'url':
del self.repos[name][cog]
def update_repo(self, name):
def run(*args, **kwargs):
env = os.environ.copy()
env['GIT_TERMINAL_PROMPT'] = '0'
kwargs['env'] = env
return sp_run(*args, **kwargs)
try:
dd = self.path
if name not in self.repos:
raise UpdateError("Repo does not exist in data, wtf")
folder = os.path.join(dd, name)
# Make sure we don't git reset the Red folder on accident
if not os.path.exists(os.path.join(folder, '.git')):
#if os.path.exists(folder):
#shutil.rmtree(folder)
url = self.repos[name].get('url')
if not url:
raise UpdateError("Need to clone but no URL set")
branch = None
if "@" in url: # Specific branch
url, branch = url.rsplit("@", maxsplit=1)
if branch is None:
p = run(["git", "clone", url, folder])
else:
p = run(["git", "clone", "-b", branch, url, folder])
if p.returncode != 0:
raise CloningError()
self.populate_list(name)
return name, REPO_CLONE, None
else:
rpbcmd = ["git", "-C", folder, "rev-parse", "--abbrev-ref", "HEAD"]
p = run(rpbcmd, stdout=PIPE)
branch = p.stdout.decode().strip()
rpcmd = ["git", "-C", folder, "rev-parse", branch]
p = run(["git", "-C", folder, "reset", "--hard",
"origin/%s" % branch, "-q"])
if p.returncode != 0:
raise UpdateError("Error resetting to origin/%s" % branch)
p = run(rpcmd, stdout=PIPE)
if p.returncode != 0:
raise UpdateError("Unable to determine old commit hash")
oldhash = p.stdout.decode().strip()
p = run(["git", "-C", folder, "pull", "-q", "--ff-only"])
if p.returncode != 0:
raise UpdateError("Error pulling updates")
p = run(rpcmd, stdout=PIPE)
if p.returncode != 0:
raise UpdateError("Unable to determine new commit hash")
newhash = p.stdout.decode().strip()
if oldhash == newhash:
return name, REPO_SAME, None
else:
self.populate_list(name)
self.save_repos()
ret = {}
cmd = ['git', '-C', folder, 'diff', '--no-commit-id',
'--name-status', oldhash, newhash]
p = run(cmd, stdout=PIPE)
if p.returncode != 0:
raise UpdateError("Error in git diff")
changed = p.stdout.strip().decode().split('\n')
for f in changed:
if not f.endswith('.py'):
continue
status, _, cogpath = f.partition('\t')
cogname = os.path.split(cogpath)[-1][:-3] # strip .py
if status not in ret:
ret[status] = []
ret[status].append(cogname)
return name, ret, oldhash
except CloningError as e:
raise CloningError(name, *e.args) from None
except UpdateError as e:
raise UpdateError(name, *e.args) from None
async def _robust_edit(self, msg, text):
try:
msg = await self.bot.edit_message(msg, text)
except discord.errors.NotFound:
msg = await self.bot.send_message(msg.channel, text)
except:
raise
return msg
@staticmethod
def format_patch(repo, cog, log):
header = "Patch Notes for %s/%s" % (repo, cog)
line = "=" * len(header)
if log:
return '\n'.join((header, line, log))
def check_folders():
if not os.path.exists(os.path.join("data", "downloader")):
print('Making repo downloads folder...')
os.mkdir(os.path.join("data", "downloader"))
def check_files():
f = os.path.join("data", "downloader", "repos.json")
if not dataIO.is_valid_json(f):
print("Creating default data/downloader/repos.json")
dataIO.save_json(f, {})
def setup(bot):
check_folders()
check_files()
n = Downloader(bot)
bot.add_cog(n)

View File

@ -0,0 +1,736 @@
import discord
from discord.ext import commands
from cogs.utils.dataIO import dataIO
from collections import namedtuple, defaultdict, deque
from datetime import datetime
from copy import deepcopy
from .utils import checks
from cogs.utils.chat_formatting import pagify, box
from enum import Enum
from __main__ import send_cmd_help
import os
import time
import logging
import random
default_settings = {"PAYDAY_TIME": 300, "PAYDAY_CREDITS": 120,
"SLOT_MIN": 5, "SLOT_MAX": 100, "SLOT_TIME": 0,
"REGISTER_CREDITS": 0}
class EconomyError(Exception):
pass
class OnCooldown(EconomyError):
pass
class InvalidBid(EconomyError):
pass
class BankError(Exception):
pass
class AccountAlreadyExists(BankError):
pass
class NoAccount(BankError):
pass
class InsufficientBalance(BankError):
pass
class NegativeValue(BankError):
pass
class SameSenderAndReceiver(BankError):
pass
NUM_ENC = "\N{COMBINING ENCLOSING KEYCAP}"
class SMReel(Enum):
cherries = "\N{CHERRIES}"
cookie = "\N{COOKIE}"
two = "\N{DIGIT TWO}" + NUM_ENC
flc = "\N{FOUR LEAF CLOVER}"
cyclone = "\N{CYCLONE}"
sunflower = "\N{SUNFLOWER}"
six = "\N{DIGIT SIX}" + NUM_ENC
mushroom = "\N{MUSHROOM}"
heart = "\N{HEAVY BLACK HEART}"
snowflake = "\N{SNOWFLAKE}"
PAYOUTS = {
(SMReel.two, SMReel.two, SMReel.six) : {
"payout" : lambda x: x * 2500 + x,
"phrase" : "JACKPOT! 226! Your bid has been multiplied * 2500!"
},
(SMReel.flc, SMReel.flc, SMReel.flc) : {
"payout" : lambda x: x + 1000,
"phrase" : "4LC! +1000!"
},
(SMReel.cherries, SMReel.cherries, SMReel.cherries) : {
"payout" : lambda x: x + 800,
"phrase" : "Three cherries! +800!"
},
(SMReel.two, SMReel.six) : {
"payout" : lambda x: x * 4 + x,
"phrase" : "2 6! Your bid has been multiplied * 4!"
},
(SMReel.cherries, SMReel.cherries) : {
"payout" : lambda x: x * 3 + x,
"phrase" : "Two cherries! Your bid has been multiplied * 3!"
},
"3 symbols" : {
"payout" : lambda x: x + 500,
"phrase" : "Three symbols! +500!"
},
"2 symbols" : {
"payout" : lambda x: x * 2 + x,
"phrase" : "Two consecutive symbols! Your bid has been multiplied * 2!"
},
}
SLOT_PAYOUTS_MSG = ("Slot machine payouts:\n"
"{two.value} {two.value} {six.value} Bet * 2500\n"
"{flc.value} {flc.value} {flc.value} +1000\n"
"{cherries.value} {cherries.value} {cherries.value} +800\n"
"{two.value} {six.value} Bet * 4\n"
"{cherries.value} {cherries.value} Bet * 3\n\n"
"Three symbols: +500\n"
"Two symbols: Bet * 2".format(**SMReel.__dict__))
class Bank:
def __init__(self, bot, file_path):
self.accounts = dataIO.load_json(file_path)
self.bot = bot
def create_account(self, user, *, initial_balance=0):
server = user.server
if not self.account_exists(user):
if server.id not in self.accounts:
self.accounts[server.id] = {}
if user.id in self.accounts: # Legacy account
balance = self.accounts[user.id]["balance"]
else:
balance = initial_balance
timestamp = datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")
account = {"name": user.name,
"balance": balance,
"created_at": timestamp
}
self.accounts[server.id][user.id] = account
self._save_bank()
return self.get_account(user)
else:
raise AccountAlreadyExists()
def account_exists(self, user):
try:
self._get_account(user)
except NoAccount:
return False
return True
def withdraw_credits(self, user, amount):
server = user.server
if amount < 0:
raise NegativeValue()
account = self._get_account(user)
if account["balance"] >= amount:
account["balance"] -= amount
self.accounts[server.id][user.id] = account
self._save_bank()
else:
raise InsufficientBalance()
def deposit_credits(self, user, amount):
server = user.server
if amount < 0:
raise NegativeValue()
account = self._get_account(user)
account["balance"] += amount
self.accounts[server.id][user.id] = account
self._save_bank()
def set_credits(self, user, amount):
server = user.server
if amount < 0:
raise NegativeValue()
account = self._get_account(user)
account["balance"] = amount
self.accounts[server.id][user.id] = account
self._save_bank()
def transfer_credits(self, sender, receiver, amount):
if amount < 0:
raise NegativeValue()
if sender is receiver:
raise SameSenderAndReceiver()
if self.account_exists(sender) and self.account_exists(receiver):
sender_acc = self._get_account(sender)
if sender_acc["balance"] < amount:
raise InsufficientBalance()
self.withdraw_credits(sender, amount)
self.deposit_credits(receiver, amount)
else:
raise NoAccount()
def can_spend(self, user, amount):
account = self._get_account(user)
if account["balance"] >= amount:
return True
else:
return False
def wipe_bank(self, server):
self.accounts[server.id] = {}
self._save_bank()
def get_server_accounts(self, server):
if server.id in self.accounts:
raw_server_accounts = deepcopy(self.accounts[server.id])
accounts = []
for k, v in raw_server_accounts.items():
v["id"] = k
v["server"] = server
acc = self._create_account_obj(v)
accounts.append(acc)
return accounts
else:
return []
def get_all_accounts(self):
accounts = []
for server_id, v in self.accounts.items():
server = self.bot.get_server(server_id)
if server is None:
# Servers that have since been left will be ignored
# Same for users_id from the old bank format
continue
raw_server_accounts = deepcopy(self.accounts[server.id])
for k, v in raw_server_accounts.items():
v["id"] = k
v["server"] = server
acc = self._create_account_obj(v)
accounts.append(acc)
return accounts
def get_balance(self, user):
account = self._get_account(user)
return account["balance"]
def get_account(self, user):
acc = self._get_account(user)
acc["id"] = user.id
acc["server"] = user.server
return self._create_account_obj(acc)
def _create_account_obj(self, account):
account["member"] = account["server"].get_member(account["id"])
account["created_at"] = datetime.strptime(account["created_at"],
"%Y-%m-%d %H:%M:%S")
Account = namedtuple("Account", "id name balance "
"created_at server member")
return Account(**account)
def _save_bank(self):
dataIO.save_json("data/economy/bank.json", self.accounts)
def _get_account(self, user):
server = user.server
try:
return deepcopy(self.accounts[server.id][user.id])
except KeyError:
raise NoAccount()
class SetParser:
def __init__(self, argument):
allowed = ("+", "-")
if argument and argument[0] in allowed:
try:
self.sum = int(argument)
except:
raise
if self.sum < 0:
self.operation = "withdraw"
elif self.sum > 0:
self.operation = "deposit"
else:
raise
self.sum = abs(self.sum)
elif argument.isdigit():
self.sum = int(argument)
self.operation = "set"
else:
raise
class Economy:
"""Economy
Get rich and have fun with imaginary currency!"""
def __init__(self, bot):
global default_settings
self.bot = bot
self.bank = Bank(bot, "data/economy/bank.json")
self.file_path = "data/economy/settings.json"
self.settings = dataIO.load_json(self.file_path)
if "PAYDAY_TIME" in self.settings: # old format
default_settings = self.settings
self.settings = {}
self.settings = defaultdict(lambda: default_settings, self.settings)
self.payday_register = defaultdict(dict)
self.slot_register = defaultdict(dict)
@commands.group(name="bank", pass_context=True)
async def _bank(self, ctx):
"""Bank operations"""
if ctx.invoked_subcommand is None:
await send_cmd_help(ctx)
@_bank.command(pass_context=True, no_pm=True)
async def register(self, ctx):
"""Registers an account at the Twentysix bank"""
settings = self.settings[ctx.message.server.id]
author = ctx.message.author
credits = 0
if ctx.message.server.id in self.settings:
credits = settings.get("REGISTER_CREDITS", 0)
try:
account = self.bank.create_account(author, initial_balance=credits)
await self.bot.say("{} Account opened. Current balance: {}"
"".format(author.mention, account.balance))
except AccountAlreadyExists:
await self.bot.say("{} You already have an account at the"
" Twentysix bank.".format(author.mention))
@_bank.command(pass_context=True)
async def balance(self, ctx, user: discord.Member=None):
"""Shows balance of user.
Defaults to yours."""
if not user:
user = ctx.message.author
try:
await self.bot.say("{} Your balance is: {}".format(
user.mention, self.bank.get_balance(user)))
except NoAccount:
await self.bot.say("{} You don't have an account at the"
" Twentysix bank. Type `{}bank register`"
" to open one.".format(user.mention,
ctx.prefix))
else:
try:
await self.bot.say("{}'s balance is {}".format(
user.name, self.bank.get_balance(user)))
except NoAccount:
await self.bot.say("That user has no bank account.")
@_bank.command(pass_context=True)
async def transfer(self, ctx, user: discord.Member, sum: int):
"""Transfer credits to other users"""
author = ctx.message.author
try:
self.bank.transfer_credits(author, user, sum)
logger.info("{}({}) transferred {} credits to {}({})".format(
author.name, author.id, sum, user.name, user.id))
await self.bot.say("{} credits have been transferred to {}'s"
" account.".format(sum, user.name))
except NegativeValue:
await self.bot.say("You need to transfer at least 1 credit.")
except SameSenderAndReceiver:
await self.bot.say("You can't transfer credits to yourself.")
except InsufficientBalance:
await self.bot.say("You don't have that sum in your bank account.")
except NoAccount:
await self.bot.say("That user has no bank account.")
@_bank.command(name="set", pass_context=True)
@checks.admin_or_permissions(manage_server=True)
async def _set(self, ctx, user: discord.Member, credits: SetParser):
"""Sets credits of user's bank account. See help for more operations
Passing positive and negative values will add/remove credits instead
Examples:
bank set @Twentysix 26 - Sets 26 credits
bank set @Twentysix +2 - Adds 2 credits
bank set @Twentysix -6 - Removes 6 credits"""
author = ctx.message.author
try:
if credits.operation == "deposit":
self.bank.deposit_credits(user, credits.sum)
logger.info("{}({}) added {} credits to {} ({})".format(
author.name, author.id, credits.sum, user.name, user.id))
await self.bot.say("{} credits have been added to {}"
"".format(credits.sum, user.name))
elif credits.operation == "withdraw":
self.bank.withdraw_credits(user, credits.sum)
logger.info("{}({}) removed {} credits to {} ({})".format(
author.name, author.id, credits.sum, user.name, user.id))
await self.bot.say("{} credits have been withdrawn from {}"
"".format(credits.sum, user.name))
elif credits.operation == "set":
self.bank.set_credits(user, credits.sum)
logger.info("{}({}) set {} credits to {} ({})"
"".format(author.name, author.id, credits.sum,
user.name, user.id))
await self.bot.say("{}'s credits have been set to {}".format(
user.name, credits.sum))
except InsufficientBalance:
await self.bot.say("User doesn't have enough credits.")
except NoAccount:
await self.bot.say("User has no bank account.")
@_bank.command(pass_context=True, no_pm=True)
@checks.serverowner_or_permissions(administrator=True)
async def reset(self, ctx, confirmation: bool=False):
"""Deletes all server's bank accounts"""
if confirmation is False:
await self.bot.say("This will delete all bank accounts on "
"this server.\nIf you're sure, type "
"{}bank reset yes".format(ctx.prefix))
else:
self.bank.wipe_bank(ctx.message.server)
await self.bot.say("All bank accounts of this server have been "
"deleted.")
@commands.command(pass_context=True, no_pm=True)
async def payday(self, ctx): # TODO
"""Get some free credits"""
author = ctx.message.author
server = author.server
id = author.id
if self.bank.account_exists(author):
if id in self.payday_register[server.id]:
seconds = abs(self.payday_register[server.id][
id] - int(time.perf_counter()))
if seconds >= self.settings[server.id]["PAYDAY_TIME"]:
self.bank.deposit_credits(author, self.settings[
server.id]["PAYDAY_CREDITS"])
self.payday_register[server.id][
id] = int(time.perf_counter())
await self.bot.say(
"{} Here, take some credits. Enjoy! (+{}"
" credits!)".format(
author.mention,
str(self.settings[server.id]["PAYDAY_CREDITS"])))
else:
dtime = self.display_time(
self.settings[server.id]["PAYDAY_TIME"] - seconds)
await self.bot.say(
"{} Too soon. For your next payday you have to"
" wait {}.".format(author.mention, dtime))
else:
self.payday_register[server.id][id] = int(time.perf_counter())
self.bank.deposit_credits(author, self.settings[
server.id]["PAYDAY_CREDITS"])
await self.bot.say(
"{} Here, take some credits. Enjoy! (+{} credits!)".format(
author.mention,
str(self.settings[server.id]["PAYDAY_CREDITS"])))
else:
await self.bot.say("{} You need an account to receive credits."
" Type `{}bank register` to open one.".format(
author.mention, ctx.prefix))
@commands.group(pass_context=True)
async def leaderboard(self, ctx):
"""Server / global leaderboard
Defaults to server"""
if ctx.invoked_subcommand is None:
await ctx.invoke(self._server_leaderboard)
@leaderboard.command(name="server", pass_context=True)
async def _server_leaderboard(self, ctx, top: int=10):
"""Prints out the server's leaderboard
Defaults to top 10"""
# Originally coded by Airenkun - edited by irdumb
server = ctx.message.server
if top < 1:
top = 10
bank_sorted = sorted(self.bank.get_server_accounts(server),
key=lambda x: x.balance, reverse=True)
bank_sorted = [a for a in bank_sorted if a.member] # exclude users who left
if len(bank_sorted) < top:
top = len(bank_sorted)
topten = bank_sorted[:top]
highscore = ""
place = 1
for acc in topten:
highscore += str(place).ljust(len(str(top)) + 1)
highscore += (str(acc.member.display_name) + " ").ljust(23 - len(str(acc.balance)))
highscore += str(acc.balance) + "\n"
place += 1
if highscore != "":
for page in pagify(highscore, shorten_by=12):
await self.bot.say(box(page, lang="py"))
else:
await self.bot.say("There are no accounts in the bank.")
@leaderboard.command(name="global")
async def _global_leaderboard(self, top: int=10):
"""Prints out the global leaderboard
Defaults to top 10"""
if top < 1:
top = 10
bank_sorted = sorted(self.bank.get_all_accounts(),
key=lambda x: x.balance, reverse=True)
bank_sorted = [a for a in bank_sorted if a.member] # exclude users who left
unique_accounts = []
for acc in bank_sorted:
if not self.already_in_list(unique_accounts, acc):
unique_accounts.append(acc)
if len(unique_accounts) < top:
top = len(unique_accounts)
topten = unique_accounts[:top]
highscore = ""
place = 1
for acc in topten:
highscore += str(place).ljust(len(str(top)) + 1)
highscore += ("{} |{}| ".format(acc.member, acc.server)
).ljust(23 - len(str(acc.balance)))
highscore += str(acc.balance) + "\n"
place += 1
if highscore != "":
for page in pagify(highscore, shorten_by=12):
await self.bot.say(box(page, lang="py"))
else:
await self.bot.say("There are no accounts in the bank.")
def already_in_list(self, accounts, user):
for acc in accounts:
if user.id == acc.id:
return True
return False
@commands.command()
async def payouts(self):
"""Shows slot machine payouts"""
await self.bot.whisper(SLOT_PAYOUTS_MSG)
@commands.command(pass_context=True, no_pm=True)
async def slot(self, ctx, bid: int):
"""Play the slot machine"""
author = ctx.message.author
server = author.server
settings = self.settings[server.id]
valid_bid = settings["SLOT_MIN"] <= bid and bid <= settings["SLOT_MAX"]
slot_time = settings["SLOT_TIME"]
last_slot = self.slot_register.get(author.id)
now = datetime.utcnow()
try:
if last_slot:
if (now - last_slot).seconds < slot_time:
raise OnCooldown()
if not valid_bid:
raise InvalidBid()
if not self.bank.can_spend(author, bid):
raise InsufficientBalance
await self.slot_machine(author, bid)
except NoAccount:
await self.bot.say("{} You need an account to use the slot "
"machine. Type `{}bank register` to open one."
"".format(author.mention, ctx.prefix))
except InsufficientBalance:
await self.bot.say("{} You need an account with enough funds to "
"play the slot machine.".format(author.mention))
except OnCooldown:
await self.bot.say("Slot machine is still cooling off! Wait {} "
"seconds between each pull".format(slot_time))
except InvalidBid:
await self.bot.say("Bid must be between {} and {}."
"".format(settings["SLOT_MIN"],
settings["SLOT_MAX"]))
async def slot_machine(self, author, bid):
default_reel = deque(SMReel)
reels = []
self.slot_register[author.id] = datetime.utcnow()
for i in range(3):
default_reel.rotate(random.randint(-999, 999)) # weeeeee
new_reel = deque(default_reel, maxlen=3) # we need only 3 symbols
reels.append(new_reel) # for each reel
rows = ((reels[0][0], reels[1][0], reels[2][0]),
(reels[0][1], reels[1][1], reels[2][1]),
(reels[0][2], reels[1][2], reels[2][2]))
slot = "~~\n~~" # Mobile friendly
for i, row in enumerate(rows): # Let's build the slot to show
sign = " "
if i == 1:
sign = ">"
slot += "{}{} {} {}\n".format(sign, *[c.value for c in row])
payout = PAYOUTS.get(rows[1])
if not payout:
# Checks for two-consecutive-symbols special rewards
payout = PAYOUTS.get((rows[1][0], rows[1][1]),
PAYOUTS.get((rows[1][1], rows[1][2]))
)
if not payout:
# Still nothing. Let's check for 3 generic same symbols
# or 2 consecutive symbols
has_three = rows[1][0] == rows[1][1] == rows[1][2]
has_two = (rows[1][0] == rows[1][1]) or (rows[1][1] == rows[1][2])
if has_three:
payout = PAYOUTS["3 symbols"]
elif has_two:
payout = PAYOUTS["2 symbols"]
if payout:
then = self.bank.get_balance(author)
pay = payout["payout"](bid)
now = then - bid + pay
self.bank.set_credits(author, now)
await self.bot.say("{}\n{} {}\n\nYour bid: {}\n{}{}!"
"".format(slot, author.mention,
payout["phrase"], bid, then, now))
else:
then = self.bank.get_balance(author)
self.bank.withdraw_credits(author, bid)
now = then - bid
await self.bot.say("{}\n{} Nothing!\nYour bid: {}\n{}{}!"
"".format(slot, author.mention, bid, then, now))
@commands.group(pass_context=True, no_pm=True)
@checks.admin_or_permissions(manage_server=True)
async def economyset(self, ctx):
"""Changes economy module settings"""
server = ctx.message.server
settings = self.settings[server.id]
if ctx.invoked_subcommand is None:
msg = "```"
for k, v in settings.items():
msg += "{}: {}\n".format(k, v)
msg += "```"
await send_cmd_help(ctx)
await self.bot.say(msg)
@economyset.command(pass_context=True)
async def slotmin(self, ctx, bid: int):
"""Minimum slot machine bid"""
server = ctx.message.server
self.settings[server.id]["SLOT_MIN"] = bid
await self.bot.say("Minimum bid is now {} credits.".format(bid))
dataIO.save_json(self.file_path, self.settings)
@economyset.command(pass_context=True)
async def slotmax(self, ctx, bid: int):
"""Maximum slot machine bid"""
server = ctx.message.server
self.settings[server.id]["SLOT_MAX"] = bid
await self.bot.say("Maximum bid is now {} credits.".format(bid))
dataIO.save_json(self.file_path, self.settings)
@economyset.command(pass_context=True)
async def slottime(self, ctx, seconds: int):
"""Seconds between each slots use"""
server = ctx.message.server
self.settings[server.id]["SLOT_TIME"] = seconds
await self.bot.say("Cooldown is now {} seconds.".format(seconds))
dataIO.save_json(self.file_path, self.settings)
@economyset.command(pass_context=True)
async def paydaytime(self, ctx, seconds: int):
"""Seconds between each payday"""
server = ctx.message.server
self.settings[server.id]["PAYDAY_TIME"] = seconds
await self.bot.say("Value modified. At least {} seconds must pass "
"between each payday.".format(seconds))
dataIO.save_json(self.file_path, self.settings)
@economyset.command(pass_context=True)
async def paydaycredits(self, ctx, credits: int):
"""Credits earned each payday"""
server = ctx.message.server
self.settings[server.id]["PAYDAY_CREDITS"] = credits
await self.bot.say("Every payday will now give {} credits."
"".format(credits))
dataIO.save_json(self.file_path, self.settings)
@economyset.command(pass_context=True)
async def registercredits(self, ctx, credits: int):
"""Credits given on registering an account"""
server = ctx.message.server
if credits < 0:
credits = 0
self.settings[server.id]["REGISTER_CREDITS"] = credits
await self.bot.say("Registering an account will now give {} credits."
"".format(credits))
dataIO.save_json(self.file_path, self.settings)
# What would I ever do without stackoverflow?
def display_time(self, seconds, granularity=2):
intervals = ( # Source: http://stackoverflow.com/a/24542445
('weeks', 604800), # 60 * 60 * 24 * 7
('days', 86400), # 60 * 60 * 24
('hours', 3600), # 60 * 60
('minutes', 60),
('seconds', 1),
)
result = []
for name, count in intervals:
value = seconds // count
if value:
seconds -= value * count
if value == 1:
name = name.rstrip('s')
result.append("{} {}".format(value, name))
return ', '.join(result[:granularity])
def check_folders():
if not os.path.exists("data/economy"):
print("Creating data/economy folder...")
os.makedirs("data/economy")
def check_files():
f = "data/economy/settings.json"
if not dataIO.is_valid_json(f):
print("Creating default economy's settings.json...")
dataIO.save_json(f, {})
f = "data/economy/bank.json"
if not dataIO.is_valid_json(f):
print("Creating empty bank.json...")
dataIO.save_json(f, {})
def setup(bot):
global logger
check_folders()
check_files()
logger = logging.getLogger("red.economy")
if logger.level == 0:
# Prevents the logger from being loaded again in case of module reload
logger.setLevel(logging.INFO)
handler = logging.FileHandler(
filename='data/economy/economy.log', encoding='utf-8', mode='a')
handler.setFormatter(logging.Formatter(
'%(asctime)s %(message)s', datefmt="[%d/%m/%Y %H:%M]"))
logger.addHandler(handler)
bot.add_cog(Economy(bot))

View File

@ -0,0 +1,433 @@
import discord
from discord.ext import commands
from .utils.chat_formatting import escape_mass_mentions, italics, pagify
from random import randint
from random import choice
from enum import Enum
from urllib.parse import quote_plus
import datetime
import time
import aiohttp
import asyncio
settings = {"POLL_DURATION" : 60}
class RPS(Enum):
rock = "\N{MOYAI}"
paper = "\N{PAGE FACING UP}"
scissors = "\N{BLACK SCISSORS}"
class RPSParser:
def __init__(self, argument):
argument = argument.lower()
if argument == "rock":
self.choice = RPS.rock
elif argument == "paper":
self.choice = RPS.paper
elif argument == "scissors":
self.choice = RPS.scissors
else:
raise
class General:
"""General commands."""
def __init__(self, bot):
self.bot = bot
self.stopwatches = {}
self.ball = ["As I see it, yes", "It is certain", "It is decidedly so", "Most likely", "Outlook good",
"Signs point to yes", "Without a doubt", "Yes", "Yes definitely", "You may rely on it", "Reply hazy, try again",
"Ask again later", "Better not tell you now", "Cannot predict now", "Concentrate and ask again",
"Don't count on it", "My reply is no", "My sources say no", "Outlook not so good", "Very doubtful"]
self.poll_sessions = []
@commands.command(hidden=True)
async def ping(self):
"""Pong."""
await self.bot.say("Pong.")
@commands.command()
async def choose(self, *choices):
"""Chooses between multiple choices.
To denote multiple choices, you should use double quotes.
"""
choices = [escape_mass_mentions(c) for c in choices]
if len(choices) < 2:
await self.bot.say('Not enough choices to pick from.')
else:
await self.bot.say(choice(choices))
@commands.command(pass_context=True)
async def roll(self, ctx, number : int = 100):
"""Rolls random number (between 1 and user choice)
Defaults to 100.
"""
author = ctx.message.author
if number > 1:
n = randint(1, number)
await self.bot.say("{} :game_die: {} :game_die:".format(author.mention, n))
else:
await self.bot.say("{} Maybe higher than 1? ;P".format(author.mention))
@commands.command(pass_context=True)
async def flip(self, ctx, user : discord.Member=None):
"""Flips a coin... or a user.
Defaults to coin.
"""
if user != None:
msg = ""
if user.id == self.bot.user.id:
user = ctx.message.author
msg = "Nice try. You think this is funny? How about *this* instead:\n\n"
char = "abcdefghijklmnopqrstuvwxyz"
tran = "ɐqɔpǝɟƃɥᴉɾʞlɯuodbɹsʇnʌʍxʎz"
table = str.maketrans(char, tran)
name = user.display_name.translate(table)
char = char.upper()
tran = "∀qƆpƎℲפHIſʞ˥WNOԀQᴚS┴∩ΛMX⅄Z"
table = str.maketrans(char, tran)
name = name.translate(table)
await self.bot.say(msg + "(╯°□°)╯︵ " + name[::-1])
else:
await self.bot.say("*flips a coin and... " + choice(["HEADS!*", "TAILS!*"]))
@commands.command(pass_context=True)
async def rps(self, ctx, your_choice : RPSParser):
"""Play rock paper scissors"""
author = ctx.message.author
player_choice = your_choice.choice
red_choice = choice((RPS.rock, RPS.paper, RPS.scissors))
cond = {
(RPS.rock, RPS.paper) : False,
(RPS.rock, RPS.scissors) : True,
(RPS.paper, RPS.rock) : True,
(RPS.paper, RPS.scissors) : False,
(RPS.scissors, RPS.rock) : False,
(RPS.scissors, RPS.paper) : True
}
if red_choice == player_choice:
outcome = None # Tie
else:
outcome = cond[(player_choice, red_choice)]
if outcome is True:
await self.bot.say("{} You win {}!"
"".format(red_choice.value, author.mention))
elif outcome is False:
await self.bot.say("{} You lose {}!"
"".format(red_choice.value, author.mention))
else:
await self.bot.say("{} We're square {}!"
"".format(red_choice.value, author.mention))
@commands.command(name="8", aliases=["8ball"])
async def _8ball(self, *, question : str):
"""Ask 8 ball a question
Question must end with a question mark.
"""
if question.endswith("?") and question != "?":
await self.bot.say("`" + choice(self.ball) + "`")
else:
await self.bot.say("That doesn't look like a question.")
@commands.command(aliases=["sw"], pass_context=True)
async def stopwatch(self, ctx):
"""Starts/stops stopwatch"""
author = ctx.message.author
if not author.id in self.stopwatches:
self.stopwatches[author.id] = int(time.perf_counter())
await self.bot.say(author.mention + " Stopwatch started!")
else:
tmp = abs(self.stopwatches[author.id] - int(time.perf_counter()))
tmp = str(datetime.timedelta(seconds=tmp))
await self.bot.say(author.mention + " Stopwatch stopped! Time: **" + tmp + "**")
self.stopwatches.pop(author.id, None)
@commands.command()
async def lmgtfy(self, *, search_terms : str):
"""Creates a lmgtfy link"""
search_terms = escape_mass_mentions(search_terms.replace(" ", "+"))
await self.bot.say("https://lmgtfy.com/?q={}".format(search_terms))
@commands.command(no_pm=True, hidden=True)
async def hug(self, user : discord.Member, intensity : int=1):
"""Because everyone likes hugs
Up to 10 intensity levels."""
name = italics(user.display_name)
if intensity <= 0:
msg = "(っ˘̩╭╮˘̩)っ" + name
elif intensity <= 3:
msg = "(っ´▽`)っ" + name
elif intensity <= 6:
msg = "╰(*´︶`*)╯" + name
elif intensity <= 9:
msg = "(つ≧▽≦)つ" + name
elif intensity >= 10:
msg = "(づ ̄ ³ ̄)づ{} ⊂(´・ω・`⊂)".format(name)
await self.bot.say(msg)
@commands.command(pass_context=True, no_pm=True)
async def userinfo(self, ctx, *, user: discord.Member=None):
"""Shows users's informations"""
author = ctx.message.author
server = ctx.message.server
if not user:
user = author
roles = [x.name for x in user.roles if x.name != "@everyone"]
joined_at = self.fetch_joined_at(user, server)
since_created = (ctx.message.timestamp - user.created_at).days
since_joined = (ctx.message.timestamp - joined_at).days
user_joined = joined_at.strftime("%d %b %Y %H:%M")
user_created = user.created_at.strftime("%d %b %Y %H:%M")
member_number = sorted(server.members,
key=lambda m: m.joined_at).index(user) + 1
created_on = "{}\n({} days ago)".format(user_created, since_created)
joined_on = "{}\n({} days ago)".format(user_joined, since_joined)
game = "Chilling in {} status".format(user.status)
if user.game is None:
pass
elif user.game.url is None:
game = "Playing {}".format(user.game)
else:
game = "Streaming: [{}]({})".format(user.game, user.game.url)
if roles:
roles = sorted(roles, key=[x.name for x in server.role_hierarchy
if x.name != "@everyone"].index)
roles = ", ".join(roles)
else:
roles = "None"
data = discord.Embed(description=game, colour=user.colour)
data.add_field(name="Joined Discord on", value=created_on)
data.add_field(name="Joined this server on", value=joined_on)
data.add_field(name="Roles", value=roles, inline=False)
data.set_footer(text="Member #{} | User ID:{}"
"".format(member_number, user.id))
name = str(user)
name = " ~ ".join((name, user.nick)) if user.nick else name
if user.avatar_url:
data.set_author(name=name, url=user.avatar_url)
data.set_thumbnail(url=user.avatar_url)
else:
data.set_author(name=name)
try:
await self.bot.say(embed=data)
except discord.HTTPException:
await self.bot.say("I need the `Embed links` permission "
"to send this")
@commands.command(pass_context=True, no_pm=True)
async def serverinfo(self, ctx):
"""Shows server's informations"""
server = ctx.message.server
online = len([m.status for m in server.members
if m.status == discord.Status.online or
m.status == discord.Status.idle])
total_users = len(server.members)
text_channels = len([x for x in server.channels
if x.type == discord.ChannelType.text])
voice_channels = len(server.channels) - text_channels
passed = (ctx.message.timestamp - server.created_at).days
created_at = ("Since {}. That's over {} days ago!"
"".format(server.created_at.strftime("%d %b %Y %H:%M"),
passed))
colour = ''.join([choice('0123456789ABCDEF') for x in range(6)])
colour = int(colour, 16)
data = discord.Embed(
description=created_at,
colour=discord.Colour(value=colour))
data.add_field(name="Region", value=str(server.region))
data.add_field(name="Users", value="{}/{}".format(online, total_users))
data.add_field(name="Text Channels", value=text_channels)
data.add_field(name="Voice Channels", value=voice_channels)
data.add_field(name="Roles", value=len(server.roles))
data.add_field(name="Owner", value=str(server.owner))
data.set_footer(text="Server ID: " + server.id)
if server.icon_url:
data.set_author(name=server.name, url=server.icon_url)
data.set_thumbnail(url=server.icon_url)
else:
data.set_author(name=server.name)
try:
await self.bot.say(embed=data)
except discord.HTTPException:
await self.bot.say("I need the `Embed links` permission "
"to send this")
@commands.command()
async def urban(self, *, search_terms : str, definition_number : int=1):
"""Urban Dictionary search
Definition number must be between 1 and 10"""
def encode(s):
return quote_plus(s, encoding='utf-8', errors='replace')
# definition_number is just there to show up in the help
# all this mess is to avoid forcing double quotes on the user
search_terms = search_terms.split(" ")
try:
if len(search_terms) > 1:
pos = int(search_terms[-1]) - 1
search_terms = search_terms[:-1]
else:
pos = 0
if pos not in range(0, 11): # API only provides the
pos = 0 # top 10 definitions
except ValueError:
pos = 0
search_terms = "+".join([encode(s) for s in search_terms])
url = "http://api.urbandictionary.com/v0/define?term=" + search_terms
try:
async with aiohttp.get(url) as r:
result = await r.json()
if result["list"]:
definition = result['list'][pos]['definition']
example = result['list'][pos]['example']
defs = len(result['list'])
msg = ("**Definition #{} out of {}:\n**{}\n\n"
"**Example:\n**{}".format(pos+1, defs, definition,
example))
msg = pagify(msg, ["\n"])
for page in msg:
await self.bot.say(page)
else:
await self.bot.say("Your search terms gave no results.")
except IndexError:
await self.bot.say("There is no definition #{}".format(pos+1))
except:
await self.bot.say("Error.")
@commands.command(pass_context=True, no_pm=True)
async def poll(self, ctx, *text):
"""Starts/stops a poll
Usage example:
poll Is this a poll?;Yes;No;Maybe
poll stop"""
message = ctx.message
if len(text) == 1:
if text[0].lower() == "stop":
await self.endpoll(message)
return
if not self.getPollByChannel(message):
check = " ".join(text).lower()
if "@everyone" in check or "@here" in check:
await self.bot.say("Nice try.")
return
p = NewPoll(message, " ".join(text), self)
if p.valid:
self.poll_sessions.append(p)
await p.start()
else:
await self.bot.say("poll question;option1;option2 (...)")
else:
await self.bot.say("A poll is already ongoing in this channel.")
async def endpoll(self, message):
if self.getPollByChannel(message):
p = self.getPollByChannel(message)
if p.author == message.author.id: # or isMemberAdmin(message)
await self.getPollByChannel(message).endPoll()
else:
await self.bot.say("Only admins and the author can stop the poll.")
else:
await self.bot.say("There's no poll ongoing in this channel.")
def getPollByChannel(self, message):
for poll in self.poll_sessions:
if poll.channel == message.channel:
return poll
return False
async def check_poll_votes(self, message):
if message.author.id != self.bot.user.id:
if self.getPollByChannel(message):
self.getPollByChannel(message).checkAnswer(message)
def fetch_joined_at(self, user, server):
"""Just a special case for someone special :^)"""
if user.id == "96130341705637888" and server.id == "133049272517001216":
return datetime.datetime(2016, 1, 10, 6, 8, 4, 443000)
else:
return user.joined_at
class NewPoll():
def __init__(self, message, text, main):
self.channel = message.channel
self.author = message.author.id
self.client = main.bot
self.poll_sessions = main.poll_sessions
msg = [ans.strip() for ans in text.split(";")]
if len(msg) < 2: # Needs at least one question and 2 choices
self.valid = False
return None
else:
self.valid = True
self.already_voted = []
self.question = msg[0]
msg.remove(self.question)
self.answers = {}
i = 1
for answer in msg: # {id : {answer, votes}}
self.answers[i] = {"ANSWER" : answer, "VOTES" : 0}
i += 1
async def start(self):
msg = "**POLL STARTED!**\n\n{}\n\n".format(self.question)
for id, data in self.answers.items():
msg += "{}. *{}*\n".format(id, data["ANSWER"])
msg += "\nType the number to vote!"
await self.client.send_message(self.channel, msg)
await asyncio.sleep(settings["POLL_DURATION"])
if self.valid:
await self.endPoll()
async def endPoll(self):
self.valid = False
msg = "**POLL ENDED!**\n\n{}\n\n".format(self.question)
for data in self.answers.values():
msg += "*{}* - {} votes\n".format(data["ANSWER"], str(data["VOTES"]))
await self.client.send_message(self.channel, msg)
self.poll_sessions.remove(self)
def checkAnswer(self, message):
try:
i = int(message.content)
if i in self.answers.keys():
if message.author.id not in self.already_voted:
data = self.answers[i]
data["VOTES"] += 1
self.answers[i] = data
self.already_voted.append(message.author.id)
except ValueError:
pass
def setup(bot):
n = General(bot)
bot.add_listener(n.check_poll_votes, "on_message")
bot.add_cog(n)

View File

@ -0,0 +1,168 @@
from discord.ext import commands
from random import choice, shuffle
import aiohttp
import functools
import asyncio
try:
from imgurpython import ImgurClient
except:
ImgurClient = False
CLIENT_ID = "1fd3ef04daf8cab"
CLIENT_SECRET = "f963e574e8e3c17993c933af4f0522e1dc01e230"
GIPHY_API_KEY = "dc6zaTOxFJmzC"
class Image:
"""Image related commands."""
def __init__(self, bot):
self.bot = bot
self.imgur = ImgurClient(CLIENT_ID, CLIENT_SECRET)
@commands.group(name="imgur", no_pm=True, pass_context=True)
async def _imgur(self, ctx):
"""Retrieves pictures from imgur"""
if ctx.invoked_subcommand is None:
await self.bot.send_cmd_help(ctx)
@_imgur.command(pass_context=True, name="random")
async def imgur_random(self, ctx, *, term: str=None):
"""Retrieves a random image from Imgur
Search terms can be specified"""
if term is None:
task = functools.partial(self.imgur.gallery_random, page=0)
else:
task = functools.partial(self.imgur.gallery_search, term,
advanced=None, sort='time',
window='all', page=0)
task = self.bot.loop.run_in_executor(None, task)
try:
results = await asyncio.wait_for(task, timeout=10)
except asyncio.TimeoutError:
await self.bot.say("Error: request timed out")
else:
if results:
item = choice(results)
link = item.gifv if hasattr(item, "gifv") else item.link
await self.bot.say(link)
else:
await self.bot.say("Your search terms gave no results.")
@_imgur.command(pass_context=True, name="search")
async def imgur_search(self, ctx, *, term: str):
"""Searches Imgur for the specified term and returns up to 3 results"""
task = functools.partial(self.imgur.gallery_search, term,
advanced=None, sort='time',
window='all', page=0)
task = self.bot.loop.run_in_executor(None, task)
try:
results = await asyncio.wait_for(task, timeout=10)
except asyncio.TimeoutError:
await self.bot.say("Error: request timed out")
else:
if results:
shuffle(results)
msg = "Search results...\n"
for r in results[:3]:
msg += r.gifv if hasattr(r, "gifv") else r.link
msg += "\n"
await self.bot.say(msg)
else:
await self.bot.say("Your search terms gave no results.")
@_imgur.command(pass_context=True, name="subreddit")
async def imgur_subreddit(self, ctx, subreddit: str, sort_type: str="top", window: str="day"):
"""Gets images from the specified subreddit section
Sort types: new, top
Time windows: day, week, month, year, all"""
sort_type = sort_type.lower()
if sort_type not in ("new", "top"):
await self.bot.say("Only 'new' and 'top' are a valid sort type.")
return
elif window not in ("day", "week", "month", "year", "all"):
await self.bot.send_cmd_help(ctx)
return
if sort_type == "new":
sort = "time"
elif sort_type == "top":
sort = "top"
links = []
task = functools.partial(self.imgur.subreddit_gallery, subreddit,
sort=sort, window=window, page=0)
task = self.bot.loop.run_in_executor(None, task)
try:
items = await asyncio.wait_for(task, timeout=10)
except asyncio.TimeoutError:
await self.bot.say("Error: request timed out")
return
for item in items[:3]:
link = item.gifv if hasattr(item, "gifv") else item.link
links.append("{}\n{}".format(item.title, link))
if links:
await self.bot.say("\n".join(links))
else:
await self.bot.say("No results found.")
@commands.command(pass_context=True, no_pm=True)
async def gif(self, ctx, *keywords):
"""Retrieves first search result from giphy"""
if keywords:
keywords = "+".join(keywords)
else:
await self.bot.send_cmd_help(ctx)
return
url = ("http://api.giphy.com/v1/gifs/search?&api_key={}&q={}"
"".format(GIPHY_API_KEY, keywords))
async with aiohttp.get(url) as r:
result = await r.json()
if r.status == 200:
if result["data"]:
await self.bot.say(result["data"][0]["url"])
else:
await self.bot.say("No results found.")
else:
await self.bot.say("Error contacting the API")
@commands.command(pass_context=True, no_pm=True)
async def gifr(self, ctx, *keywords):
"""Retrieves a random gif from a giphy search"""
if keywords:
keywords = "+".join(keywords)
else:
await self.bot.send_cmd_help(ctx)
return
url = ("http://api.giphy.com/v1/gifs/random?&api_key={}&tag={}"
"".format(GIPHY_API_KEY, keywords))
async with aiohttp.get(url) as r:
result = await r.json()
if r.status == 200:
if result["data"]:
await self.bot.say(result["data"]["url"])
else:
await self.bot.say("No results found.")
else:
await self.bot.say("Error contacting the API")
def setup(bot):
if ImgurClient is False:
raise RuntimeError("You need the imgurpython module to use this.\n"
"pip3 install imgurpython")
bot.add_cog(Image(bot))

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,689 @@
from discord.ext import commands
from .utils.dataIO import dataIO
from .utils.chat_formatting import escape_mass_mentions
from .utils import checks
from collections import defaultdict
from string import ascii_letters
from random import choice
import discord
import os
import re
import aiohttp
import asyncio
import logging
import json
class StreamsError(Exception):
pass
class StreamNotFound(StreamsError):
pass
class APIError(StreamsError):
pass
class InvalidCredentials(StreamsError):
pass
class OfflineStream(StreamsError):
pass
class Streams:
"""Streams
Alerts for a variety of streaming services"""
def __init__(self, bot):
self.bot = bot
self.twitch_streams = dataIO.load_json("data/streams/twitch.json")
self.hitbox_streams = dataIO.load_json("data/streams/hitbox.json")
self.mixer_streams = dataIO.load_json("data/streams/beam.json")
self.picarto_streams = dataIO.load_json("data/streams/picarto.json")
settings = dataIO.load_json("data/streams/settings.json")
self.settings = defaultdict(dict, settings)
self.messages_cache = defaultdict(list)
@commands.command()
async def hitbox(self, stream: str):
"""Checks if hitbox stream is online"""
stream = escape_mass_mentions(stream)
regex = r'^(https?\:\/\/)?(www\.)?(hitbox\.tv\/)'
stream = re.sub(regex, '', stream)
try:
embed = await self.hitbox_online(stream)
except OfflineStream:
await self.bot.say(stream + " is offline.")
except StreamNotFound:
await self.bot.say("That stream doesn't exist.")
except APIError:
await self.bot.say("Error contacting the API.")
else:
await self.bot.say(embed=embed)
@commands.command(pass_context=True)
async def twitch(self, ctx, stream: str):
"""Checks if twitch stream is online"""
stream = escape_mass_mentions(stream)
regex = r'^(https?\:\/\/)?(www\.)?(twitch\.tv\/)'
stream = re.sub(regex, '', stream)
try:
data = await self.fetch_twitch_ids(stream, raise_if_none=True)
embed = await self.twitch_online(data[0]["_id"])
except OfflineStream:
await self.bot.say(stream + " is offline.")
except StreamNotFound:
await self.bot.say("That stream doesn't exist.")
except APIError:
await self.bot.say("Error contacting the API.")
except InvalidCredentials:
await self.bot.say("Owner: Client-ID is invalid or not set. "
"See `{}streamset twitchtoken`"
"".format(ctx.prefix))
else:
await self.bot.say(embed=embed)
@commands.command()
async def mixer(self, stream: str):
"""Checks if mixer stream is online"""
stream = escape_mass_mentions(stream)
regex = r'^(https?\:\/\/)?(www\.)?(mixer\.com\/)'
stream = re.sub(regex, '', stream)
try:
embed = await self.mixer_online(stream)
except OfflineStream:
await self.bot.say(stream + " is offline.")
except StreamNotFound:
await self.bot.say("That stream doesn't exist.")
except APIError:
await self.bot.say("Error contacting the API.")
else:
await self.bot.say(embed=embed)
@commands.command()
async def picarto(self, stream: str):
"""Checks if picarto stream is online"""
stream = escape_mass_mentions(stream)
regex = r'^(https?\:\/\/)?(www\.)?(picarto\.tv\/)'
stream = re.sub(regex, '', stream)
try:
embed = await self.picarto_online(stream)
except OfflineStream:
await self.bot.say(stream + " is offline.")
except StreamNotFound:
await self.bot.say("That stream doesn't exist.")
except APIError:
await self.bot.say("Error contacting the API.")
else:
await self.bot.say(embed=embed)
@commands.group(pass_context=True, no_pm=True)
@checks.mod_or_permissions(manage_server=True)
async def streamalert(self, ctx):
"""Adds/removes stream alerts from the current channel"""
if ctx.invoked_subcommand is None:
await self.bot.send_cmd_help(ctx)
@streamalert.command(name="twitch", pass_context=True)
async def twitch_alert(self, ctx, stream: str):
"""Adds/removes twitch alerts from the current channel"""
stream = escape_mass_mentions(stream)
regex = r'^(https?\:\/\/)?(www\.)?(twitch\.tv\/)'
stream = re.sub(regex, '', stream)
channel = ctx.message.channel
try:
data = await self.fetch_twitch_ids(stream, raise_if_none=True)
except StreamNotFound:
await self.bot.say("That stream doesn't exist.")
return
except APIError:
await self.bot.say("Error contacting the API.")
return
except InvalidCredentials:
await self.bot.say("Owner: Client-ID is invalid or not set. "
"See `{}streamset twitchtoken`"
"".format(ctx.prefix))
return
enabled = self.enable_or_disable_if_active(self.twitch_streams,
stream,
channel,
_id=data[0]["_id"])
if enabled:
await self.bot.say("Alert activated. I will notify this channel "
"when {} is live.".format(stream))
else:
await self.bot.say("Alert has been removed from this channel.")
dataIO.save_json("data/streams/twitch.json", self.twitch_streams)
@streamalert.command(name="hitbox", pass_context=True)
async def hitbox_alert(self, ctx, stream: str):
"""Adds/removes hitbox alerts from the current channel"""
stream = escape_mass_mentions(stream)
regex = r'^(https?\:\/\/)?(www\.)?(hitbox\.tv\/)'
stream = re.sub(regex, '', stream)
channel = ctx.message.channel
try:
await self.hitbox_online(stream)
except StreamNotFound:
await self.bot.say("That stream doesn't exist.")
return
except APIError:
await self.bot.say("Error contacting the API.")
return
except OfflineStream:
pass
enabled = self.enable_or_disable_if_active(self.hitbox_streams,
stream,
channel)
if enabled:
await self.bot.say("Alert activated. I will notify this channel "
"when {} is live.".format(stream))
else:
await self.bot.say("Alert has been removed from this channel.")
dataIO.save_json("data/streams/hitbox.json", self.hitbox_streams)
@streamalert.command(name="mixer", pass_context=True)
async def mixer_alert(self, ctx, stream: str):
"""Adds/removes mixer alerts from the current channel"""
stream = escape_mass_mentions(stream)
regex = r'^(https?\:\/\/)?(www\.)?(mixer\.com\/)'
stream = re.sub(regex, '', stream)
channel = ctx.message.channel
try:
await self.mixer_online(stream)
except StreamNotFound:
await self.bot.say("That stream doesn't exist.")
return
except APIError:
await self.bot.say("Error contacting the API.")
return
except OfflineStream:
pass
enabled = self.enable_or_disable_if_active(self.mixer_streams,
stream,
channel)
if enabled:
await self.bot.say("Alert activated. I will notify this channel "
"when {} is live.".format(stream))
else:
await self.bot.say("Alert has been removed from this channel.")
dataIO.save_json("data/streams/beam.json", self.mixer_streams)
@streamalert.command(name="picarto", pass_context=True)
async def picarto_alert(self, ctx, stream: str):
"""Adds/removes picarto alerts from the current channel"""
stream = escape_mass_mentions(stream)
regex = r'^(https?\:\/\/)?(www\.)?(picarto\.tv\/)'
stream = re.sub(regex, '', stream)
channel = ctx.message.channel
try:
await self.picarto_online(stream)
except StreamNotFound:
await self.bot.say("That stream doesn't exist.")
return
except APIError:
await self.bot.say("Error contacting the API.")
return
except OfflineStream:
pass
enabled = self.enable_or_disable_if_active(self.picarto_streams,
stream,
channel)
if enabled:
await self.bot.say("Alert activated. I will notify this channel "
"when {} is live.".format(stream))
else:
await self.bot.say("Alert has been removed from this channel.")
dataIO.save_json("data/streams/picarto.json", self.picarto_streams)
@streamalert.command(name="stop", pass_context=True)
async def stop_alert(self, ctx):
"""Stops all streams alerts in the current channel"""
channel = ctx.message.channel
streams = (
self.hitbox_streams,
self.twitch_streams,
self.mixer_streams,
self.picarto_streams
)
for stream_type in streams:
to_delete = []
for s in stream_type:
if channel.id in s["CHANNELS"]:
s["CHANNELS"].remove(channel.id)
if not s["CHANNELS"]:
to_delete.append(s)
for s in to_delete:
stream_type.remove(s)
dataIO.save_json("data/streams/twitch.json", self.twitch_streams)
dataIO.save_json("data/streams/hitbox.json", self.hitbox_streams)
dataIO.save_json("data/streams/beam.json", self.mixer_streams)
dataIO.save_json("data/streams/picarto.json", self.picarto_streams)
await self.bot.say("There will be no more stream alerts in this "
"channel.")
@commands.group(pass_context=True)
async def streamset(self, ctx):
"""Stream settings"""
if ctx.invoked_subcommand is None:
await self.bot.send_cmd_help(ctx)
@streamset.command()
@checks.is_owner()
async def twitchtoken(self, token : str):
"""Sets the Client-ID for Twitch
https://blog.twitch.tv/client-id-required-for-kraken-api-calls-afbb8e95f843"""
self.settings["TWITCH_TOKEN"] = token
dataIO.save_json("data/streams/settings.json", self.settings)
await self.bot.say('Twitch Client-ID set.')
@streamset.command(pass_context=True, no_pm=True)
@checks.admin()
async def mention(self, ctx, *, mention_type : str):
"""Sets mentions for stream alerts
Types: everyone, here, none"""
server = ctx.message.server
mention_type = mention_type.lower()
if mention_type in ("everyone", "here"):
self.settings[server.id]["MENTION"] = "@" + mention_type
await self.bot.say("When a stream is online @\u200b{} will be "
"mentioned.".format(mention_type))
elif mention_type == "none":
self.settings[server.id]["MENTION"] = ""
await self.bot.say("Mentions disabled.")
else:
await self.bot.send_cmd_help(ctx)
dataIO.save_json("data/streams/settings.json", self.settings)
@streamset.command(pass_context=True, no_pm=True)
@checks.admin()
async def autodelete(self, ctx):
"""Toggles automatic notification deletion for streams that go offline"""
server = ctx.message.server
settings = self.settings[server.id]
current = settings.get("AUTODELETE", True)
settings["AUTODELETE"] = not current
if settings["AUTODELETE"]:
await self.bot.say("Notifications will be automatically deleted "
"once the stream goes offline.")
else:
await self.bot.say("Notifications won't be deleted anymore.")
dataIO.save_json("data/streams/settings.json", self.settings)
async def hitbox_online(self, stream):
url = "https://api.hitbox.tv/media/live/" + stream
async with aiohttp.get(url) as r:
data = await r.json(encoding='utf-8')
if "livestream" not in data:
raise StreamNotFound()
elif data["livestream"][0]["media_is_live"] == "0":
raise OfflineStream()
elif data["livestream"][0]["media_is_live"] == "1":
return self.hitbox_embed(data)
raise APIError()
async def twitch_online(self, stream):
session = aiohttp.ClientSession()
url = "https://api.twitch.tv/kraken/streams/" + stream
header = {
'Client-ID': self.settings.get("TWITCH_TOKEN", ""),
'Accept': 'application/vnd.twitchtv.v5+json'
}
async with session.get(url, headers=header) as r:
data = await r.json(encoding='utf-8')
await session.close()
if r.status == 200:
if data["stream"] is None:
raise OfflineStream()
return self.twitch_embed(data)
elif r.status == 400:
raise InvalidCredentials()
elif r.status == 404:
raise StreamNotFound()
else:
raise APIError()
async def mixer_online(self, stream):
url = "https://mixer.com/api/v1/channels/" + stream
async with aiohttp.get(url) as r:
data = await r.json(encoding='utf-8')
if r.status == 200:
if data["online"] is True:
return self.mixer_embed(data)
else:
raise OfflineStream()
elif r.status == 404:
raise StreamNotFound()
else:
raise APIError()
async def picarto_online(self, stream):
url = "https://api.picarto.tv/v1/channel/name/" + stream
async with aiohttp.get(url) as r:
data = await r.text(encoding='utf-8')
if r.status == 200:
data = json.loads(data)
if data["online"] is True:
return self.picarto_embed(data)
else:
raise OfflineStream()
elif r.status == 404:
raise StreamNotFound()
else:
raise APIError()
async def fetch_twitch_ids(self, *streams, raise_if_none=False):
def chunks(l):
for i in range(0, len(l), 100):
yield l[i:i + 100]
base_url = "https://api.twitch.tv/kraken/users?login="
header = {
'Client-ID': self.settings.get("TWITCH_TOKEN", ""),
'Accept': 'application/vnd.twitchtv.v5+json'
}
results = []
for streams_list in chunks(streams):
session = aiohttp.ClientSession()
url = base_url + ",".join(streams_list)
async with session.get(url, headers=header) as r:
data = await r.json()
if r.status == 200:
results.extend(data["users"])
elif r.status == 400:
raise InvalidCredentials()
else:
raise APIError()
await session.close()
if not results and raise_if_none:
raise StreamNotFound()
return results
def twitch_embed(self, data):
channel = data["stream"]["channel"]
url = channel["url"]
logo = channel["logo"]
if logo is None:
logo = "https://static-cdn.jtvnw.net/jtv_user_pictures/xarth/404_user_70x70.png"
status = channel["status"]
if not status:
status = "Untitled broadcast"
embed = discord.Embed(title=status, url=url)
embed.set_author(name=channel["display_name"])
embed.add_field(name="Followers", value=channel["followers"])
embed.add_field(name="Total views", value=channel["views"])
embed.set_thumbnail(url=logo)
if data["stream"]["preview"]["medium"]:
embed.set_image(url=data["stream"]["preview"]["medium"] + self.rnd_attr())
if channel["game"]:
embed.set_footer(text="Playing: " + channel["game"])
embed.color = 0x6441A4
return embed
def hitbox_embed(self, data):
base_url = "https://edge.sf.hitbox.tv"
livestream = data["livestream"][0]
channel = livestream["channel"]
url = channel["channel_link"]
embed = discord.Embed(title=livestream["media_status"], url=url)
embed.set_author(name=livestream["media_name"])
embed.add_field(name="Followers", value=channel["followers"])
#embed.add_field(name="Views", value=channel["views"])
embed.set_thumbnail(url=base_url + channel["user_logo"])
if livestream["media_thumbnail"]:
embed.set_image(url=base_url + livestream["media_thumbnail"] + self.rnd_attr())
embed.set_footer(text="Playing: " + livestream["category_name"])
embed.color = 0x98CB00
return embed
def mixer_embed(self, data):
default_avatar = ("https://mixer.com/_latest/assets/images/main/"
"avatars/default.jpg")
user = data["user"]
url = "https://mixer.com/" + data["token"]
embed = discord.Embed(title=data["name"], url=url)
embed.set_author(name=user["username"])
embed.add_field(name="Followers", value=data["numFollowers"])
embed.add_field(name="Total views", value=data["viewersTotal"])
if user["avatarUrl"]:
embed.set_thumbnail(url=user["avatarUrl"])
else:
embed.set_thumbnail(url=default_avatar)
if data["thumbnail"]:
embed.set_image(url=data["thumbnail"]["url"] + self.rnd_attr())
embed.color = 0x4C90F3
if data["type"] is not None:
embed.set_footer(text="Playing: " + data["type"]["name"])
return embed
def picarto_embed(self, data):
avatar = ("https://picarto.tv/user_data/usrimg/{}/dsdefault.jpg{}"
"".format(data["name"].lower(), self.rnd_attr()))
url = "https://picarto.tv/" + data["name"]
thumbnail = ("https://thumb.picarto.tv/thumbnail/{}.jpg"
"".format(data["name"]))
embed = discord.Embed(title=data["title"], url=url)
embed.set_author(name=data["name"])
embed.set_image(url=thumbnail + self.rnd_attr())
embed.add_field(name="Followers", value=data["followers"])
embed.add_field(name="Total views", value=data["viewers_total"])
embed.set_thumbnail(url=avatar)
embed.color = 0x132332
data["tags"] = ", ".join(data["tags"])
if not data["tags"]:
data["tags"] = "None"
if data["adult"]:
data["adult"] = "NSFW | "
else:
data["adult"] = ""
embed.color = 0x4C90F3
embed.set_footer(text="{adult}Category: {category} | Tags: {tags}"
"".format(**data))
return embed
def enable_or_disable_if_active(self, streams, stream, channel, _id=None):
"""Returns True if enabled or False if disabled"""
for i, s in enumerate(streams):
if s["NAME"] != stream:
continue
if channel.id in s["CHANNELS"]:
streams[i]["CHANNELS"].remove(channel.id)
if not s["CHANNELS"]:
streams.remove(s)
return False
else:
streams[i]["CHANNELS"].append(channel.id)
return True
data = {"CHANNELS": [channel.id],
"NAME": stream,
"ALREADY_ONLINE": False}
if _id:
data["ID"] = _id
streams.append(data)
return True
async def stream_checker(self):
CHECK_DELAY = 60
try:
await self._migration_twitch_v5()
except InvalidCredentials:
print("Error during convertion of twitch usernames to IDs: "
"invalid token")
except Exception as e:
print("Error during convertion of twitch usernames to IDs: "
"{}".format(e))
while self == self.bot.get_cog("Streams"):
save = False
streams = ((self.twitch_streams, self.twitch_online),
(self.hitbox_streams, self.hitbox_online),
(self.mixer_streams, self.mixer_online),
(self.picarto_streams, self.picarto_online))
for streams_list, parser in streams:
if parser == self.twitch_online:
_type = "ID"
else:
_type = "NAME"
for stream in streams_list:
if _type not in stream:
continue
key = (parser, stream[_type])
try:
embed = await parser(stream[_type])
except OfflineStream:
if stream["ALREADY_ONLINE"]:
stream["ALREADY_ONLINE"] = False
save = True
await self.delete_old_notifications(key)
except: # We don't want our task to die
continue
else:
if stream["ALREADY_ONLINE"]:
continue
save = True
stream["ALREADY_ONLINE"] = True
messages_sent = []
for channel_id in stream["CHANNELS"]:
channel = self.bot.get_channel(channel_id)
if channel is None:
continue
mention = self.settings.get(channel.server.id, {}).get("MENTION", "")
can_speak = channel.permissions_for(channel.server.me).send_messages
message = mention + " {} is live!".format(stream["NAME"])
if channel and can_speak:
m = await self.bot.send_message(channel, message, embed=embed)
messages_sent.append(m)
self.messages_cache[key] = messages_sent
await asyncio.sleep(0.5)
if save:
dataIO.save_json("data/streams/twitch.json", self.twitch_streams)
dataIO.save_json("data/streams/hitbox.json", self.hitbox_streams)
dataIO.save_json("data/streams/beam.json", self.mixer_streams)
dataIO.save_json("data/streams/picarto.json", self.picarto_streams)
await asyncio.sleep(CHECK_DELAY)
async def delete_old_notifications(self, key):
for message in self.messages_cache[key]:
server = message.server
settings = self.settings.get(server.id, {})
is_enabled = settings.get("AUTODELETE", True)
try:
if is_enabled:
await self.bot.delete_message(message)
except:
pass
del self.messages_cache[key]
def rnd_attr(self):
"""Avoids Discord's caching"""
return "?rnd=" + "".join([choice(ascii_letters) for i in range(6)])
async def _migration_twitch_v5(self):
# Migration of old twitch streams to API v5
to_convert = []
for stream in self.twitch_streams:
if "ID" not in stream:
to_convert.append(stream["NAME"])
if not to_convert:
return
results = await self.fetch_twitch_ids(*to_convert)
for stream in self.twitch_streams:
for result in results:
if stream["NAME"].lower() == result["name"].lower():
stream["ID"] = result["_id"]
# We might as well delete the invalid / renamed ones
self.twitch_streams = [s for s in self.twitch_streams if "ID" in s]
dataIO.save_json("data/streams/twitch.json", self.twitch_streams)
def check_folders():
if not os.path.exists("data/streams"):
print("Creating data/streams folder...")
os.makedirs("data/streams")
def check_files():
stream_files = (
"twitch.json",
"hitbox.json",
"beam.json",
"picarto.json"
)
for filename in stream_files:
if not dataIO.is_valid_json("data/streams/" + filename):
print("Creating empty {}...".format(filename))
dataIO.save_json("data/streams/" + filename, [])
f = "data/streams/settings.json"
if not dataIO.is_valid_json(f):
print("Creating empty settings.json...")
dataIO.save_json(f, {})
def setup(bot):
logger = logging.getLogger('aiohttp.client')
logger.setLevel(50) # Stops warning spam
check_folders()
check_files()
n = Streams(bot)
loop = asyncio.get_event_loop()
loop.create_task(n.stream_checker())
bot.add_cog(n)

View File

@ -0,0 +1,332 @@
from discord.ext import commands
from random import choice
from .utils.dataIO import dataIO
from .utils import checks
from .utils.chat_formatting import box
from collections import Counter, defaultdict, namedtuple
import discord
import time
import os
import asyncio
import chardet
DEFAULTS = {"MAX_SCORE" : 10,
"TIMEOUT" : 120,
"DELAY" : 15,
"BOT_PLAYS" : False,
"REVEAL_ANSWER": True}
TriviaLine = namedtuple("TriviaLine", "question answers")
class Trivia:
"""General commands."""
def __init__(self, bot):
self.bot = bot
self.trivia_sessions = []
self.file_path = "data/trivia/settings.json"
settings = dataIO.load_json(self.file_path)
self.settings = defaultdict(lambda: DEFAULTS.copy(), settings)
@commands.group(pass_context=True, no_pm=True)
@checks.mod_or_permissions(administrator=True)
async def triviaset(self, ctx):
"""Change trivia settings"""
server = ctx.message.server
if ctx.invoked_subcommand is None:
settings = self.settings[server.id]
msg = box("Red gains points: {BOT_PLAYS}\n"
"Seconds to answer: {DELAY}\n"
"Points to win: {MAX_SCORE}\n"
"Reveal answer on timeout: {REVEAL_ANSWER}\n"
"".format(**settings))
msg += "\nSee {}help triviaset to edit the settings".format(ctx.prefix)
await self.bot.say(msg)
@triviaset.command(pass_context=True)
async def maxscore(self, ctx, score : int):
"""Points required to win"""
server = ctx.message.server
if score > 0:
self.settings[server.id]["MAX_SCORE"] = score
self.save_settings()
await self.bot.say("Points required to win set to {}".format(score))
else:
await self.bot.say("Score must be superior to 0.")
@triviaset.command(pass_context=True)
async def timelimit(self, ctx, seconds : int):
"""Maximum seconds to answer"""
server = ctx.message.server
if seconds > 4:
self.settings[server.id]["DELAY"] = seconds
self.save_settings()
await self.bot.say("Maximum seconds to answer set to {}".format(seconds))
else:
await self.bot.say("Seconds must be at least 5.")
@triviaset.command(pass_context=True)
async def botplays(self, ctx):
"""Red gains points"""
server = ctx.message.server
if self.settings[server.id]["BOT_PLAYS"]:
self.settings[server.id]["BOT_PLAYS"] = False
await self.bot.say("Alright, I won't embarass you at trivia anymore.")
else:
self.settings[server.id]["BOT_PLAYS"] = True
await self.bot.say("I'll gain a point everytime you don't answer in time.")
self.save_settings()
@triviaset.command(pass_context=True)
async def revealanswer(self, ctx):
"""Reveals answer to the question on timeout"""
server = ctx.message.server
if self.settings[server.id]["REVEAL_ANSWER"]:
self.settings[server.id]["REVEAL_ANSWER"] = False
await self.bot.say("I won't reveal the answer to the questions anymore.")
else:
self.settings[server.id]["REVEAL_ANSWER"] = True
await self.bot.say("I'll reveal the answer if no one knows it.")
self.save_settings()
@commands.group(pass_context=True, invoke_without_command=True, no_pm=True)
async def trivia(self, ctx, list_name: str):
"""Start a trivia session with the specified list"""
message = ctx.message
server = message.server
session = self.get_trivia_by_channel(message.channel)
if not session:
try:
trivia_list = self.parse_trivia_list(list_name)
except FileNotFoundError:
await self.bot.say("That trivia list doesn't exist.")
except Exception as e:
print(e)
await self.bot.say("Error loading the trivia list.")
else:
settings = self.settings[server.id]
t = TriviaSession(self.bot, trivia_list, message, settings)
self.trivia_sessions.append(t)
await t.new_question()
else:
await self.bot.say("A trivia session is already ongoing in this channel.")
@trivia.group(name="stop", pass_context=True, no_pm=True)
async def trivia_stop(self, ctx):
"""Stops an ongoing trivia session"""
author = ctx.message.author
server = author.server
admin_role = self.bot.settings.get_server_admin(server)
mod_role = self.bot.settings.get_server_mod(server)
is_admin = discord.utils.get(author.roles, name=admin_role)
is_mod = discord.utils.get(author.roles, name=mod_role)
is_owner = author.id == self.bot.settings.owner
is_server_owner = author == server.owner
is_authorized = is_admin or is_mod or is_owner or is_server_owner
session = self.get_trivia_by_channel(ctx.message.channel)
if session:
if author == session.starter or is_authorized:
await session.end_game()
await self.bot.say("Trivia stopped.")
else:
await self.bot.say("You are not allowed to do that.")
else:
await self.bot.say("There's no trivia session ongoing in this channel.")
@trivia.group(name="list")
async def trivia_list(self):
"""Shows available trivia lists"""
lists = os.listdir("data/trivia/")
lists = [l for l in lists if l.endswith(".txt") and " " not in l]
lists = [l.replace(".txt", "") for l in lists]
if lists:
msg = "+ Available trivia lists\n\n" + ", ".join(sorted(lists))
msg = box(msg, lang="diff")
if len(lists) < 100:
await self.bot.say(msg)
else:
await self.bot.whisper(msg)
else:
await self.bot.say("There are no trivia lists available.")
def parse_trivia_list(self, filename):
path = "data/trivia/{}.txt".format(filename)
parsed_list = []
with open(path, "rb") as f:
try:
encoding = chardet.detect(f.read())["encoding"]
except:
encoding = "ISO-8859-1"
with open(path, "r", encoding=encoding) as f:
trivia_list = f.readlines()
for line in trivia_list:
if "`" not in line:
continue
line = line.replace("\n", "")
line = line.split("`")
question = line[0]
answers = []
for l in line[1:]:
answers.append(l.strip())
if len(line) >= 2 and question and answers:
line = TriviaLine(question=question, answers=answers)
parsed_list.append(line)
if not parsed_list:
raise ValueError("Empty trivia list")
return parsed_list
def get_trivia_by_channel(self, channel):
for t in self.trivia_sessions:
if t.channel == channel:
return t
return None
async def on_message(self, message):
if message.author != self.bot.user:
session = self.get_trivia_by_channel(message.channel)
if session:
await session.check_answer(message)
async def on_trivia_end(self, instance):
if instance in self.trivia_sessions:
self.trivia_sessions.remove(instance)
def save_settings(self):
dataIO.save_json(self.file_path, self.settings)
class TriviaSession():
def __init__(self, bot, trivia_list, message, settings):
self.bot = bot
self.reveal_messages = ("I know this one! {}!",
"Easy: {}.",
"Oh really? It's {} of course.")
self.fail_messages = ("To the next one I guess...",
"Moving on...",
"I'm sure you'll know the answer of the next one.",
"\N{PENSIVE FACE} Next one.")
self.current_line = None # {"QUESTION" : "String", "ANSWERS" : []}
self.question_list = trivia_list
self.channel = message.channel
self.starter = message.author
self.scores = Counter()
self.status = "new question"
self.timer = None
self.timeout = time.perf_counter()
self.count = 0
self.settings = settings
async def stop_trivia(self):
self.status = "stop"
self.bot.dispatch("trivia_end", self)
async def end_game(self):
self.status = "stop"
if self.scores:
await self.send_table()
self.bot.dispatch("trivia_end", self)
async def new_question(self):
for score in self.scores.values():
if score == self.settings["MAX_SCORE"]:
await self.end_game()
return True
if self.question_list == []:
await self.end_game()
return True
self.current_line = choice(self.question_list)
self.question_list.remove(self.current_line)
self.status = "waiting for answer"
self.count += 1
self.timer = int(time.perf_counter())
msg = "**Question number {}!**\n\n{}".format(self.count, self.current_line.question)
await self.bot.say(msg)
while self.status != "correct answer" and abs(self.timer - int(time.perf_counter())) <= self.settings["DELAY"]:
if abs(self.timeout - int(time.perf_counter())) >= self.settings["TIMEOUT"]:
await self.bot.say("Guys...? Well, I guess I'll stop then.")
await self.stop_trivia()
return True
await asyncio.sleep(1) #Waiting for an answer or for the time limit
if self.status == "correct answer":
self.status = "new question"
await asyncio.sleep(3)
if not self.status == "stop":
await self.new_question()
elif self.status == "stop":
return True
else:
if self.settings["REVEAL_ANSWER"]:
msg = choice(self.reveal_messages).format(self.current_line.answers[0])
else:
msg = choice(self.fail_messages)
if self.settings["BOT_PLAYS"]:
msg += " **+1** for me!"
self.scores[self.bot.user] += 1
self.current_line = None
await self.bot.say(msg)
await self.bot.type()
await asyncio.sleep(3)
if not self.status == "stop":
await self.new_question()
async def send_table(self):
t = "+ Results: \n\n"
for user, score in self.scores.most_common():
t += "+ {}\t{}\n".format(user, score)
await self.bot.say(box(t, lang="diff"))
async def check_answer(self, message):
if message.author == self.bot.user:
return
elif self.current_line is None:
return
self.timeout = time.perf_counter()
has_guessed = False
for answer in self.current_line.answers:
answer = answer.lower()
guess = message.content.lower()
if " " not in answer: # Exact matching, issue #331
guess = guess.split(" ")
for word in guess:
if word == answer:
has_guessed = True
else: # The answer has spaces, we can't be as strict
if answer in guess:
has_guessed = True
if has_guessed:
self.current_line = None
self.status = "correct answer"
self.scores[message.author] += 1
msg = "You got it {}! **+1** to you!".format(message.author.name)
await self.bot.send_message(message.channel, msg)
def check_folders():
folders = ("data", "data/trivia/")
for folder in folders:
if not os.path.exists(folder):
print("Creating " + folder + " folder...")
os.makedirs(folder)
def check_files():
if not os.path.isfile("data/trivia/settings.json"):
print("Creating empty settings.json...")
dataIO.save_json("data/trivia/settings.json", {})
def setup(bot):
check_folders()
check_files()
bot.add_cog(Trivia(bot))

View File

@ -0,0 +1,80 @@
def error(text):
return "\N{NO ENTRY SIGN} {}".format(text)
def warning(text):
return "\N{WARNING SIGN} {}".format(text)
def info(text):
return "\N{INFORMATION SOURCE} {}".format(text)
def question(text):
return "\N{BLACK QUESTION MARK ORNAMENT} {}".format(text)
def bold(text):
return "**{}**".format(text)
def box(text, lang=""):
ret = "```{}\n{}\n```".format(lang, text)
return ret
def inline(text):
return "`{}`".format(text)
def italics(text):
return "*{}*".format(text)
def pagify(text, delims=["\n"], *, escape=True, shorten_by=8,
page_length=2000):
"""DOES NOT RESPECT MARKDOWN BOXES OR INLINE CODE"""
in_text = text
if escape:
num_mentions = text.count("@here") + text.count("@everyone")
shorten_by += num_mentions
page_length -= shorten_by
while len(in_text) > page_length:
closest_delim = max([in_text.rfind(d, 0, page_length)
for d in delims])
closest_delim = closest_delim if closest_delim != -1 else page_length
if escape:
to_send = escape_mass_mentions(in_text[:closest_delim])
else:
to_send = in_text[:closest_delim]
yield to_send
in_text = in_text[closest_delim:]
if escape:
yield escape_mass_mentions(in_text)
else:
yield in_text
def strikethrough(text):
return "~~{}~~".format(text)
def underline(text):
return "__{}__".format(text)
def escape(text, *, mass_mentions=False, formatting=False):
if mass_mentions:
text = text.replace("@everyone", "@\u200beveryone")
text = text.replace("@here", "@\u200bhere")
if formatting:
text = (text.replace("`", "\\`")
.replace("*", "\\*")
.replace("_", "\\_")
.replace("~", "\\~"))
return text
def escape_mass_mentions(text):
return escape(text, mass_mentions=True)

View File

@ -0,0 +1,88 @@
from discord.ext import commands
import discord.utils
from __main__ import settings
#
# This is a modified version of checks.py, originally made by Rapptz
#
# https://github.com/Rapptz
# https://github.com/Rapptz/RoboDanny/tree/async
#
def is_owner_check(ctx):
return ctx.message.author.id == settings.owner
def is_owner():
return commands.check(is_owner_check)
# The permission system of the bot is based on a "just works" basis
# You have permissions and the bot has permissions. If you meet the permissions
# required to execute the command (and the bot does as well) then it goes through
# and you can execute the command.
# If these checks fail, then there are two fallbacks.
# A role with the name of Bot Mod and a role with the name of Bot Admin.
# Having these roles provides you access to certain commands without actually having
# the permissions required for them.
# Of course, the owner will always be able to execute commands.
def check_permissions(ctx, perms):
if is_owner_check(ctx):
return True
elif not perms:
return False
ch = ctx.message.channel
author = ctx.message.author
resolved = ch.permissions_for(author)
return all(getattr(resolved, name, None) == value for name, value in perms.items())
def role_or_permissions(ctx, check, **perms):
if check_permissions(ctx, perms):
return True
ch = ctx.message.channel
author = ctx.message.author
if ch.is_private:
return False # can't have roles in PMs
role = discord.utils.find(check, author.roles)
return role is not None
def mod_or_permissions(**perms):
def predicate(ctx):
server = ctx.message.server
mod_role = settings.get_server_mod(server).lower()
admin_role = settings.get_server_admin(server).lower()
return role_or_permissions(ctx, lambda r: r.name.lower() in (mod_role,admin_role), **perms)
return commands.check(predicate)
def admin_or_permissions(**perms):
def predicate(ctx):
server = ctx.message.server
admin_role = settings.get_server_admin(server)
return role_or_permissions(ctx, lambda r: r.name.lower() == admin_role.lower(), **perms)
return commands.check(predicate)
def serverowner_or_permissions(**perms):
def predicate(ctx):
if ctx.message.server is None:
return False
server = ctx.message.server
owner = server.owner
if ctx.message.author.id == owner.id:
return True
return check_permissions(ctx,perms)
return commands.check(predicate)
def serverowner():
return serverowner_or_permissions()
def admin():
return admin_or_permissions()
def mod():
return mod_or_permissions()

View File

@ -0,0 +1,79 @@
import json
import os
import logging
from random import randint
class InvalidFileIO(Exception):
pass
class DataIO():
def __init__(self):
self.logger = logging.getLogger("red")
def save_json(self, filename, data):
"""Atomically saves json file"""
rnd = randint(1000, 9999)
path, ext = os.path.splitext(filename)
tmp_file = "{}-{}.tmp".format(path, rnd)
self._save_json(tmp_file, data)
try:
self._read_json(tmp_file)
except json.decoder.JSONDecodeError:
self.logger.exception("Attempted to write file {} but JSON "
"integrity check on tmp file has failed. "
"The original file is unaltered."
"".format(filename))
return False
os.replace(tmp_file, filename)
return True
def load_json(self, filename):
"""Loads json file"""
return self._read_json(filename)
def is_valid_json(self, filename):
"""Verifies if json file exists / is readable"""
try:
self._read_json(filename)
return True
except FileNotFoundError:
return False
except json.decoder.JSONDecodeError:
return False
def _read_json(self, filename):
with open(filename, encoding='utf-8', mode="r") as f:
data = json.load(f)
return data
def _save_json(self, filename, data):
with open(filename, encoding='utf-8', mode="w") as f:
json.dump(data, f, indent=4,sort_keys=True,
separators=(',',' : '))
return data
def _legacy_fileio(self, filename, IO, data=None):
"""Old fileIO provided for backwards compatibility"""
if IO == "save" and data != None:
return self.save_json(filename, data)
elif IO == "load" and data == None:
return self.load_json(filename)
elif IO == "check" and data == None:
return self.is_valid_json(filename)
else:
raise InvalidFileIO("FileIO was called with invalid"
" parameters")
def get_value(filename, key):
with open(filename, encoding='utf-8', mode="r") as f:
data = json.load(f)
return data[key]
def set_value(filename, key, value):
data = fileIO(filename, "load")
data[key] = value
fileIO(filename, "save", data)
return True
dataIO = DataIO()
fileIO = dataIO._legacy_fileio # backwards compatibility

View File

@ -0,0 +1,291 @@
from .dataIO import dataIO
from copy import deepcopy
import discord
import os
import argparse
default_path = "data/red/settings.json"
class Settings:
def __init__(self, path=default_path, parse_args=True):
self.path = path
self.check_folders()
self.default_settings = {
"TOKEN": None,
"EMAIL": None,
"PASSWORD": None,
"OWNER": None,
"PREFIXES": [],
"default": {"ADMIN_ROLE": "Transistor",
"MOD_ROLE": "Process",
"PREFIXES": []}
}
self._memory_only = False
if not dataIO.is_valid_json(self.path):
self.bot_settings = deepcopy(self.default_settings)
self.save_settings()
else:
current = dataIO.load_json(self.path)
if current.keys() != self.default_settings.keys():
for key in self.default_settings.keys():
if key not in current.keys():
current[key] = self.default_settings[key]
print("Adding " + str(key) +
" field to red settings.json")
dataIO.save_json(self.path, current)
self.bot_settings = dataIO.load_json(self.path)
if "default" not in self.bot_settings:
self.update_old_settings_v1()
if "LOGIN_TYPE" in self.bot_settings:
self.update_old_settings_v2()
if parse_args:
self.parse_cmd_arguments()
def parse_cmd_arguments(self):
parser = argparse.ArgumentParser(description="Red - Discord Bot")
parser.add_argument("--owner", help="ID of the owner. Only who hosts "
"Red should be owner, this has "
"security implications")
parser.add_argument("--prefix", "-p", action="append",
help="Global prefix. Can be multiple")
parser.add_argument("--admin-role", help="Role seen as admin role by "
"Red")
parser.add_argument("--mod-role", help="Role seen as mod role by Red")
parser.add_argument("--no-prompt",
action="store_true",
help="Disables console inputs. Features requiring "
"console interaction could be disabled as a "
"result")
parser.add_argument("--no-cogs",
action="store_true",
help="Starts Red with no cogs loaded, only core")
parser.add_argument("--self-bot",
action='store_true',
help="Specifies if Red should log in as selfbot")
parser.add_argument("--memory-only",
action="store_true",
help="Arguments passed and future edits to the "
"settings will not be saved to disk")
parser.add_argument("--dry-run",
action="store_true",
help="Makes Red quit with code 0 just before the "
"login. This is useful for testing the boot "
"process.")
parser.add_argument("--debug",
action="store_true",
help="Enables debug mode")
args = parser.parse_args()
if args.owner:
self.owner = args.owner
if args.prefix:
self.prefixes = sorted(args.prefix, reverse=True)
if args.admin_role:
self.default_admin = args.admin_role
if args.mod_role:
self.default_mod = args.mod_role
self.no_prompt = args.no_prompt
self.self_bot = args.self_bot
self._memory_only = args.memory_only
self._no_cogs = args.no_cogs
self.debug = args.debug
self._dry_run = args.dry_run
self.save_settings()
def check_folders(self):
folders = ("data", os.path.dirname(self.path), "cogs", "cogs/utils")
for folder in folders:
if not os.path.exists(folder):
print("Creating " + folder + " folder...")
os.makedirs(folder)
def save_settings(self):
if not self._memory_only:
dataIO.save_json(self.path, self.bot_settings)
def update_old_settings_v1(self):
# This converts the old settings format
mod = self.bot_settings["MOD_ROLE"]
admin = self.bot_settings["ADMIN_ROLE"]
del self.bot_settings["MOD_ROLE"]
del self.bot_settings["ADMIN_ROLE"]
self.bot_settings["default"] = {"MOD_ROLE": mod,
"ADMIN_ROLE": admin,
"PREFIXES": []
}
self.save_settings()
def update_old_settings_v2(self):
# The joys of backwards compatibility
settings = self.bot_settings
if settings["EMAIL"] == "EmailHere":
settings["EMAIL"] = None
if settings["PASSWORD"] == "":
settings["PASSWORD"] = None
if settings["LOGIN_TYPE"] == "token":
settings["TOKEN"] = settings["EMAIL"]
settings["EMAIL"] = None
settings["PASSWORD"] = None
else:
settings["TOKEN"] = None
del settings["LOGIN_TYPE"]
self.save_settings()
@property
def owner(self):
return self.bot_settings["OWNER"]
@owner.setter
def owner(self, value):
self.bot_settings["OWNER"] = value
@property
def token(self):
return os.environ.get("RED_TOKEN", self.bot_settings["TOKEN"])
@token.setter
def token(self, value):
self.bot_settings["TOKEN"] = value
self.bot_settings["EMAIL"] = None
self.bot_settings["PASSWORD"] = None
@property
def email(self):
return os.environ.get("RED_EMAIL", self.bot_settings["EMAIL"])
@email.setter
def email(self, value):
self.bot_settings["EMAIL"] = value
self.bot_settings["TOKEN"] = None
@property
def password(self):
return os.environ.get("RED_PASSWORD", self.bot_settings["PASSWORD"])
@password.setter
def password(self, value):
self.bot_settings["PASSWORD"] = value
@property
def login_credentials(self):
if self.token:
return (self.token,)
elif self.email and self.password:
return (self.email, self.password)
else:
return tuple()
@property
def prefixes(self):
return self.bot_settings["PREFIXES"]
@prefixes.setter
def prefixes(self, value):
assert isinstance(value, list)
self.bot_settings["PREFIXES"] = value
@property
def default_admin(self):
if "default" not in self.bot_settings:
self.update_old_settings()
return self.bot_settings["default"].get("ADMIN_ROLE", "")
@default_admin.setter
def default_admin(self, value):
if "default" not in self.bot_settings:
self.update_old_settings()
self.bot_settings["default"]["ADMIN_ROLE"] = value
@property
def default_mod(self):
if "default" not in self.bot_settings:
self.update_old_settings_v1()
return self.bot_settings["default"].get("MOD_ROLE", "")
@default_mod.setter
def default_mod(self, value):
if "default" not in self.bot_settings:
self.update_old_settings_v1()
self.bot_settings["default"]["MOD_ROLE"] = value
@property
def servers(self):
ret = {}
server_ids = list(
filter(lambda x: str(x).isdigit(), self.bot_settings))
for server in server_ids:
ret.update({server: self.bot_settings[server]})
return ret
def get_server(self, server):
if server is None:
return self.bot_settings["default"].copy()
assert isinstance(server, discord.Server)
return self.bot_settings.get(server.id,
self.bot_settings["default"]).copy()
def get_server_admin(self, server):
if server is None:
return self.default_admin
assert isinstance(server, discord.Server)
if server.id not in self.bot_settings:
return self.default_admin
return self.bot_settings[server.id].get("ADMIN_ROLE", "")
def set_server_admin(self, server, value):
if server is None:
return
assert isinstance(server, discord.Server)
if server.id not in self.bot_settings:
self.add_server(server.id)
self.bot_settings[server.id]["ADMIN_ROLE"] = value
self.save_settings()
def get_server_mod(self, server):
if server is None:
return self.default_mod
assert isinstance(server, discord.Server)
if server.id not in self.bot_settings:
return self.default_mod
return self.bot_settings[server.id].get("MOD_ROLE", "")
def set_server_mod(self, server, value):
if server is None:
return
assert isinstance(server, discord.Server)
if server.id not in self.bot_settings:
self.add_server(server.id)
self.bot_settings[server.id]["MOD_ROLE"] = value
self.save_settings()
def get_server_prefixes(self, server):
if server is None or server.id not in self.bot_settings:
return self.prefixes
return self.bot_settings[server.id].get("PREFIXES", [])
def set_server_prefixes(self, server, prefixes):
if server is None:
return
assert isinstance(server, discord.Server)
if server.id not in self.bot_settings:
self.add_server(server.id)
self.bot_settings[server.id]["PREFIXES"] = prefixes
self.save_settings()
def get_prefixes(self, server):
"""Returns server's prefixes if set, otherwise global ones"""
p = self.get_server_prefixes(server)
return p if p else self.prefixes
def add_server(self, sid):
self.bot_settings[sid] = self.bot_settings["default"].copy()
self.save_settings()

View File

@ -0,0 +1,573 @@
from __future__ import print_function
import os
import sys
import subprocess
try: # Older Pythons lack this
import urllib.request # We'll let them reach the Python
from importlib.util import find_spec # check anyway
except ImportError:
pass
import platform
import webbrowser
import hashlib
import argparse
import shutil
import stat
import time
try:
import pip
except ImportError:
pip = None
REQS_DIR = "lib"
sys.path.insert(0, REQS_DIR)
REQS_TXT = "requirements.txt"
REQS_NO_AUDIO_TXT = "requirements_no_audio.txt"
FFMPEG_BUILDS_URL = "https://ffmpeg.zeranoe.com/builds/"
INTRO = ("==========================\n"
"Red Discord Bot - Launcher\n"
"==========================\n")
IS_WINDOWS = os.name == "nt"
IS_MAC = sys.platform == "darwin"
IS_64BIT = platform.machine().endswith("64")
INTERACTIVE_MODE = not len(sys.argv) > 1 # CLI flags = non-interactive
PYTHON_OK = sys.version_info >= (3, 5)
FFMPEG_FILES = {
"ffmpeg.exe" : "e0d60f7c0d27ad9d7472ddf13e78dc89",
"ffplay.exe" : "d100abe8281cbcc3e6aebe550c675e09",
"ffprobe.exe" : "0e84b782c0346a98434ed476e937764f"
}
def parse_cli_arguments():
parser = argparse.ArgumentParser(description="Red - Discord Bot's launcher")
parser.add_argument("--start", "-s",
help="Starts Red",
action="store_true")
parser.add_argument("--auto-restart",
help="Autorestarts Red in case of issues",
action="store_true")
parser.add_argument("--update-red",
help="Updates Red (git)",
action="store_true")
parser.add_argument("--update-reqs",
help="Updates requirements (w/ audio)",
action="store_true")
parser.add_argument("--update-reqs-no-audio",
help="Updates requirements (w/o audio)",
action="store_true")
parser.add_argument("--repair",
help="Issues a git reset --hard",
action="store_true")
return parser.parse_args()
def install_reqs(audio):
remove_reqs_readonly()
interpreter = sys.executable
if interpreter is None:
print("Python interpreter not found.")
return
txt = REQS_TXT if audio else REQS_NO_AUDIO_TXT
args = [
interpreter, "-m",
"pip", "install",
"--upgrade",
"--target", REQS_DIR,
"-r", txt
]
if IS_MAC: # --target is a problem on Homebrew. See PR #552
args.remove("--target")
args.remove(REQS_DIR)
code = subprocess.call(args)
if code == 0:
print("\nRequirements setup completed.")
else:
print("\nAn error occurred and the requirements setup might "
"not be completed. Consult the docs.\n")
def update_pip():
interpreter = sys.executable
if interpreter is None:
print("Python interpreter not found.")
return
args = [
interpreter, "-m",
"pip", "install",
"--upgrade", "pip"
]
code = subprocess.call(args)
if code == 0:
print("\nPip has been updated.")
else:
print("\nAn error occurred and pip might not have been updated.")
def update_red():
try:
code = subprocess.call(("git", "pull", "--ff-only"))
except FileNotFoundError:
print("\nError: Git not found. It's either not installed or not in "
"the PATH environment variable like requested in the guide.")
return
if code == 0:
print("\nRed has been updated")
else:
print("\nRed could not update properly. If this is caused by edits "
"you have made to the code you can try the repair option from "
"the Maintenance submenu")
def reset_red(reqs=False, data=False, cogs=False, git_reset=False):
if reqs:
try:
shutil.rmtree(REQS_DIR, onerror=remove_readonly)
print("Installed local packages have been wiped.")
except FileNotFoundError:
pass
except Exception as e:
print("An error occurred when trying to remove installed "
"requirements: {}".format(e))
if data:
try:
shutil.rmtree("data", onerror=remove_readonly)
print("'data' folder has been wiped.")
except FileNotFoundError:
pass
except Exception as e:
print("An error occurred when trying to remove the 'data' folder: "
"{}".format(e))
if cogs:
try:
shutil.rmtree("cogs", onerror=remove_readonly)
print("'cogs' folder has been wiped.")
except FileNotFoundError:
pass
except Exception as e:
print("An error occurred when trying to remove the 'cogs' folder: "
"{}".format(e))
if git_reset:
code = subprocess.call(("git", "reset", "--hard"))
if code == 0:
print("Red has been restored to the last local commit.")
else:
print("The repair has failed.")
def download_ffmpeg(bitness):
clear_screen()
repo = "https://github.com/Twentysix26/Red-DiscordBot/raw/master/"
verified = []
if bitness == "32bit":
print("Please download 'ffmpeg 32bit static' from the page that "
"is about to open.\nOnce done, open the 'bin' folder located "
"inside the zip.\nThere should be 3 files: ffmpeg.exe, "
"ffplay.exe, ffprobe.exe.\nPut all three of them into the "
"bot's main folder.")
time.sleep(4)
webbrowser.open(FFMPEG_BUILDS_URL)
return
for filename in FFMPEG_FILES:
if os.path.isfile(filename):
print("{} already present. Verifying integrity... "
"".format(filename), end="")
_hash = calculate_md5(filename)
if _hash == FFMPEG_FILES[filename]:
verified.append(filename)
print("Ok")
continue
else:
print("Hash mismatch. Redownloading.")
print("Downloading {}... Please wait.".format(filename))
with urllib.request.urlopen(repo + filename) as data:
with open(filename, "wb") as f:
f.write(data.read())
print("Download completed.")
for filename, _hash in FFMPEG_FILES.items():
if filename in verified:
continue
print("Verifying {}... ".format(filename), end="")
if not calculate_md5(filename) != _hash:
print("Passed.")
else:
print("Hash mismatch. Please redownload.")
print("\nAll files have been downloaded.")
def verify_requirements():
sys.path_importer_cache = {} # I don't know if the cache reset has any
basic = find_spec("discord") # side effect. Without it, the lib folder
audio = find_spec("nacl") # wouldn't be seen if it didn't exist
if not basic: # when the launcher was started
return None
elif not audio:
return False
else:
return True
def is_git_installed():
try:
subprocess.call(["git", "--version"], stdout=subprocess.DEVNULL,
stdin =subprocess.DEVNULL,
stderr=subprocess.DEVNULL)
except FileNotFoundError:
return False
else:
return True
def requirements_menu():
clear_screen()
while True:
print(INTRO)
print("Main requirements:\n")
print("1. Install basic + audio requirements (recommended)")
print("2. Install basic requirements")
if IS_WINDOWS:
print("\nffmpeg (required for audio):")
print("3. Install ffmpeg 32bit")
if IS_64BIT:
print("4. Install ffmpeg 64bit (recommended on Windows 64bit)")
print("\n0. Go back")
choice = user_choice()
if choice == "1":
install_reqs(audio=True)
wait()
elif choice == "2":
install_reqs(audio=False)
wait()
elif choice == "3" and IS_WINDOWS:
download_ffmpeg(bitness="32bit")
wait()
elif choice == "4" and (IS_WINDOWS and IS_64BIT):
download_ffmpeg(bitness="64bit")
wait()
elif choice == "0":
break
clear_screen()
def update_menu():
clear_screen()
while True:
print(INTRO)
reqs = verify_requirements()
if reqs is None:
status = "No requirements installed"
elif reqs is False:
status = "Basic requirements installed (no audio)"
else:
status = "Basic + audio requirements installed"
print("Status: " + status + "\n")
print("Update:\n")
print("Red:")
print("1. Update Red + requirements (recommended)")
print("2. Update Red")
print("3. Update requirements")
print("\nOthers:")
print("4. Update pip (might require admin privileges)")
print("\n0. Go back")
choice = user_choice()
if choice == "1":
update_red()
print("Updating requirements...")
reqs = verify_requirements()
if reqs is not None:
install_reqs(audio=reqs)
else:
print("The requirements haven't been installed yet.")
wait()
elif choice == "2":
update_red()
wait()
elif choice == "3":
reqs = verify_requirements()
if reqs is not None:
install_reqs(audio=reqs)
else:
print("The requirements haven't been installed yet.")
wait()
elif choice == "4":
update_pip()
wait()
elif choice == "0":
break
clear_screen()
def maintenance_menu():
clear_screen()
while True:
print(INTRO)
print("Maintenance:\n")
print("1. Repair Red (discards code changes, keeps data intact)")
print("2. Wipe 'data' folder (all settings, cogs' data...)")
print("3. Wipe 'lib' folder (all local requirements / local installed"
" python packages)")
print("4. Factory reset")
print("\n0. Go back")
choice = user_choice()
if choice == "1":
print("Any code modification you have made will be lost. Data/"
"non-default cogs will be left intact. Are you sure?")
if user_pick_yes_no():
reset_red(git_reset=True)
wait()
elif choice == "2":
print("Are you sure? This will wipe the 'data' folder, which "
"contains all your settings and cogs' data.\nThe 'cogs' "
"folder, however, will be left intact.")
if user_pick_yes_no():
reset_red(data=True)
wait()
elif choice == "3":
reset_red(reqs=True)
wait()
elif choice == "4":
print("Are you sure? This will wipe ALL your Red's installation "
"data.\nYou'll lose all your settings, cogs and any "
"modification you have made.\nThere is no going back.")
if user_pick_yes_no():
reset_red(reqs=True, data=True, cogs=True, git_reset=True)
wait()
elif choice == "0":
break
clear_screen()
def run_red(autorestart):
interpreter = sys.executable
if interpreter is None: # This should never happen
raise RuntimeError("Couldn't find Python's interpreter")
if verify_requirements() is None:
print("You don't have the requirements to start Red. "
"Install them from the launcher.")
if not INTERACTIVE_MODE:
exit(1)
cmd = (interpreter, "red.py")
while True:
try:
code = subprocess.call(cmd)
except KeyboardInterrupt:
code = 0
break
else:
if code == 0:
break
elif code == 26:
print("Restarting Red...")
continue
else:
if not autorestart:
break
print("Red has been terminated. Exit code: %d" % code)
if INTERACTIVE_MODE:
wait()
def clear_screen():
if IS_WINDOWS:
os.system("cls")
else:
os.system("clear")
def wait():
if INTERACTIVE_MODE:
input("Press enter to continue.")
def user_choice():
return input("> ").lower().strip()
def user_pick_yes_no():
choice = None
yes = ("yes", "y")
no = ("no", "n")
while choice not in yes and choice not in no:
choice = input("Yes/No > ").lower().strip()
return choice in yes
def remove_readonly(func, path, excinfo):
os.chmod(path, 0o755)
func(path)
def remove_reqs_readonly():
"""Workaround for issue #569"""
if not os.path.isdir(REQS_DIR):
return
os.chmod(REQS_DIR, 0o755)
for root, dirs, files in os.walk(REQS_DIR):
for d in dirs:
os.chmod(os.path.join(root, d), 0o755)
for f in files:
os.chmod(os.path.join(root, f), 0o755)
def calculate_md5(filename):
hash_md5 = hashlib.md5()
with open(filename, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
return hash_md5.hexdigest()
def create_fast_start_scripts():
"""Creates scripts for fast boot of Red without going
through the launcher"""
interpreter = sys.executable
if not interpreter:
return
call = "\"{}\" launcher.py".format(interpreter)
start_red = "{} --start".format(call)
start_red_autorestart = "{} --start --auto-restart".format(call)
modified = False
if IS_WINDOWS:
ccd = "pushd %~dp0\n"
pause = "\npause"
ext = ".bat"
else:
ccd = 'cd "$(dirname "$0")"\n'
pause = "\nread -rsp $'Press enter to continue...\\n'"
if not IS_MAC:
ext = ".sh"
else:
ext = ".command"
start_red = ccd + start_red + pause
start_red_autorestart = ccd + start_red_autorestart + pause
files = {
"start_red" + ext : start_red,
"start_red_autorestart" + ext : start_red_autorestart
}
if not IS_WINDOWS:
files["start_launcher" + ext] = ccd + call
for filename, content in files.items():
if not os.path.isfile(filename):
print("Creating {}... (fast start scripts)".format(filename))
modified = True
with open(filename, "w") as f:
f.write(content)
if not IS_WINDOWS and modified: # Let's make them executable on Unix
for script in files:
st = os.stat(script)
os.chmod(script, st.st_mode | stat.S_IEXEC)
def main():
print("Verifying git installation...")
has_git = is_git_installed()
is_git_installation = os.path.isdir(".git")
if IS_WINDOWS:
os.system("TITLE Red Discord Bot - Launcher")
clear_screen()
try:
create_fast_start_scripts()
except Exception as e:
print("Failed making fast start scripts: {}\n".format(e))
while True:
print(INTRO)
if not is_git_installation:
print("WARNING: It doesn't look like Red has been "
"installed with git.\nThis means that you won't "
"be able to update and some features won't be working.\n"
"A reinstallation is recommended. Follow the guide "
"properly this time:\n"
"https://twentysix26.github.io/Red-Docs/\n")
if not has_git:
print("WARNING: Git not found. This means that it's either not "
"installed or not in the PATH environment variable like "
"requested in the guide.\n")
print("1. Run Red /w autorestart in case of issues")
print("2. Run Red")
print("3. Update")
print("4. Install requirements")
print("5. Maintenance (repair, reset...)")
print("\n0. Quit")
choice = user_choice()
if choice == "1":
run_red(autorestart=True)
elif choice == "2":
run_red(autorestart=False)
elif choice == "3":
update_menu()
elif choice == "4":
requirements_menu()
elif choice == "5":
maintenance_menu()
elif choice == "0":
break
clear_screen()
args = parse_cli_arguments()
if __name__ == '__main__':
abspath = os.path.abspath(__file__)
dirname = os.path.dirname(abspath)
# Sets current directory to the script's
os.chdir(dirname)
if not PYTHON_OK:
print("Red needs Python 3.5 or superior. Install the required "
"version.\nPress enter to continue.")
if INTERACTIVE_MODE:
wait()
exit(1)
if pip is None:
print("Red cannot work without the pip module. Please make sure to "
"install Python without unchecking any option during the setup")
wait()
exit(1)
if args.repair:
reset_red(git_reset=True)
if args.update_red:
update_red()
if args.update_reqs:
install_reqs(audio=True)
elif args.update_reqs_no_audio:
install_reqs(audio=False)
if INTERACTIVE_MODE:
main()
elif args.start:
print("Starting Red...")
run_red(autorestart=args.auto_restart)

View File

@ -0,0 +1,91 @@
PyNaCl
======
.. image:: https://pypip.in/version/PyNaCl/badge.svg?style=flat
:target: https://pypi.python.org/pypi/PyNaCl/
:alt: Latest Version
.. image:: https://travis-ci.org/pyca/pynacl.svg?branch=master
:target: https://travis-ci.org/pyca/pynacl
.. image:: https://coveralls.io/repos/pyca/pynacl/badge.svg?branch=master
:target: https://coveralls.io/r/pyca/pynacl?branch=master
PyNaCl is a Python binding to the `Networking and Cryptography library`_,
a crypto library with the stated goal of improving usability, security and
speed.
.. _Networking and Cryptography library: https://nacl.cr.yp.to/
Installation
------------
Linux
~~~~~
PyNaCl relies on libsodium_, a portable C library. A copy is bundled
with PyNaCl so to install you can run:
.. code-block:: console
$ pip install pynacl
If you'd prefer to use one provided by your distribution you can disable
the bundled copy during install by running:
.. code-block:: console
$ SODIUM_INSTALL=system pip install pynacl
.. _libsodium: https://github.com/jedisct1/libsodium
Mac OS X & Windows
~~~~~~~~~~~~~~~~~~
PyNaCl ships as a binary wheel on OS X and Windows so all dependencies
are included. Make sure you have an up-to-date pip and run:
.. code-block:: console
$ pip install pynacl
Features
--------
* Digital signatures
* Secret-key encryption
* Public-key encryption
Changes
-------
* 1.0.1:
* Fix an issue with absolute paths that prevented the creation of wheels.
* 1.0:
* PyNaCl has been ported to use the new APIs available in cffi 1.0+.
Due to this change we no longer support PyPy releases older than 2.6.
* Python 3.2 support has been dropped.
* Functions to convert between Ed25519 and Curve25519 keys have been added.
* 0.3.0:
* The low-level API (`nacl.c.*`) has been changed to match the
upstream NaCl C/C++ conventions (as well as those of other NaCl bindings).
The order of arguments and return values has changed significantly. To
avoid silent failures, `nacl.c` has been removed, and replaced with
`nacl.bindings` (with the new argument ordering). If you have code which
calls these functions (e.g. `nacl.c.crypto_box_keypair()`), you must review
the new docstrings and update your code/imports to match the new
conventions.

View File

@ -0,0 +1 @@
pip

View File

@ -0,0 +1,114 @@
Metadata-Version: 2.0
Name: PyNaCl
Version: 1.0.1
Summary: Python binding to the Networking and Cryptography (NaCl) library
Home-page: https://github.com/pyca/pynacl/
Author: The PyNaCl developers
Author-email: cryptography-dev@python.org
License: Apache License 2.0
Platform: UNKNOWN
Classifier: Programming Language :: Python :: Implementation :: CPython
Classifier: Programming Language :: Python :: Implementation :: PyPy
Classifier: Programming Language :: Python :: 2
Classifier: Programming Language :: Python :: 2.6
Classifier: Programming Language :: Python :: 2.7
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.3
Classifier: Programming Language :: Python :: 3.4
Classifier: Programming Language :: Python :: 3.5
Requires-Dist: cffi (>=1.1.0)
Requires-Dist: six
Provides-Extra: tests
Requires-Dist: pytest; extra == 'tests'
PyNaCl
======
.. image:: https://pypip.in/version/PyNaCl/badge.svg?style=flat
:target: https://pypi.python.org/pypi/PyNaCl/
:alt: Latest Version
.. image:: https://travis-ci.org/pyca/pynacl.svg?branch=master
:target: https://travis-ci.org/pyca/pynacl
.. image:: https://coveralls.io/repos/pyca/pynacl/badge.svg?branch=master
:target: https://coveralls.io/r/pyca/pynacl?branch=master
PyNaCl is a Python binding to the `Networking and Cryptography library`_,
a crypto library with the stated goal of improving usability, security and
speed.
.. _Networking and Cryptography library: https://nacl.cr.yp.to/
Installation
------------
Linux
~~~~~
PyNaCl relies on libsodium_, a portable C library. A copy is bundled
with PyNaCl so to install you can run:
.. code-block:: console
$ pip install pynacl
If you'd prefer to use one provided by your distribution you can disable
the bundled copy during install by running:
.. code-block:: console
$ SODIUM_INSTALL=system pip install pynacl
.. _libsodium: https://github.com/jedisct1/libsodium
Mac OS X & Windows
~~~~~~~~~~~~~~~~~~
PyNaCl ships as a binary wheel on OS X and Windows so all dependencies
are included. Make sure you have an up-to-date pip and run:
.. code-block:: console
$ pip install pynacl
Features
--------
* Digital signatures
* Secret-key encryption
* Public-key encryption
Changes
-------
* 1.0.1:
* Fix an issue with absolute paths that prevented the creation of wheels.
* 1.0:
* PyNaCl has been ported to use the new APIs available in cffi 1.0+.
Due to this change we no longer support PyPy releases older than 2.6.
* Python 3.2 support has been dropped.
* Functions to convert between Ed25519 and Curve25519 keys have been added.
* 0.3.0:
* The low-level API (`nacl.c.*`) has been changed to match the
upstream NaCl C/C++ conventions (as well as those of other NaCl bindings).
The order of arguments and return values has changed significantly. To
avoid silent failures, `nacl.c` has been removed, and replaced with
`nacl.bindings` (with the new argument ordering). If you have code which
calls these functions (e.g. `nacl.c.crypto_box_keypair()`), you must review
the new docstrings and update your code/imports to match the new
conventions.

View File

@ -0,0 +1,40 @@
PyNaCl-1.0.1.dist-info/DESCRIPTION.rst,sha256=3qnL8XwzGV1BtOJoieda8zEbuLvAY0wWIroa08MdjXY,2310
PyNaCl-1.0.1.dist-info/METADATA,sha256=05N1teI88uA_YyplqlutrzCcpWNADo_66_DDqH483Eo,3194
PyNaCl-1.0.1.dist-info/RECORD,,
PyNaCl-1.0.1.dist-info/WHEEL,sha256=xiHTm3JxoVljPSD6nSGhq3B4VY9iUqMNXwYQ259n1PI,102
PyNaCl-1.0.1.dist-info/metadata.json,sha256=a5yJUJb5gjQCsCaQGG4I-ZbQFD9z8LMYAstuJ8z8XZ0,1064
PyNaCl-1.0.1.dist-info/top_level.txt,sha256=wfdEOI_G2RIzmzsMyhpqP17HUh6Jcqi99to9aHLEslo,13
nacl/__init__.py,sha256=avK2K7KSLic4oYjD2vCcAYqKG5Ed277wMfZuwNwr2oc,1170
nacl/_sodium.cp36-win32.pyd,sha256=XdBkQdVqquHCCaULVcrij_XAtlGq3faEgYwe96huYT4,183296
nacl/encoding.py,sha256=tOiyIQVVpGU6A4Lzr0tMuqomhc_Aj0V_c1t56a-ZtPw,1928
nacl/exceptions.py,sha256=cY0MvWUHpa443Qi9ZjikX2bg2zWC4ko0vChO90JEaf4,881
nacl/hash.py,sha256=SP9wJIcs5bOg2l52JCpqe_p1BjwOA_NYdWsuHuJ9cFs,962
nacl/public.py,sha256=BVD0UMu26mYhMrjxuKBC2xzP7YwILO0eLAKN5ZMTMK0,7519
nacl/secret.py,sha256=704VLB1VR0FO8vuAILG2O3Idh8KYf2S3jAu_fidf5cU,4777
nacl/signing.py,sha256=X8I0AUhA5jZH0pZi9c4uRqg9LQvCXZ1uBmGHJ1QTdmY,6661
nacl/utils.py,sha256=E8TKyHN6g_xCOn7eB9KrmO7JIbayuLXbcpv4kWBu_rQ,1601
nacl/bindings/__init__.py,sha256=GklZvnvt_q9Mlo9XOIG490c-y-G8CVD6gaWqyGncXtQ,3164
nacl/bindings/crypto_box.py,sha256=k2hnwFH5nSyUBqaIYYWXpGJOlYvssC46XmIqxqH811k,5603
nacl/bindings/crypto_hash.py,sha256=kBA-JVoRt9WkkFcF--prKz5BnWyHzLmif-PLt7FM4MU,1942
nacl/bindings/crypto_scalarmult.py,sha256=VHTiWhkhbXxUiok9SyFxHzgEhBBYMZ7kuPpNZtW9nGs,1579
nacl/bindings/crypto_secretbox.py,sha256=Q_E3fpCfhyvaKkb7ndwRfaHrMZk0aTDWUGWpguEUXeA,2641
nacl/bindings/crypto_sign.py,sha256=fD_346rtF3CCtlXkjy30A-HGkXY1Z8GOSRv_0kUTsFg,4857
nacl/bindings/randombytes.py,sha256=eThts6s-9xBXOl3GNzT57fV1dZUhzPjjAmAVIUHfcrc,988
nacl/bindings/sodium_core.py,sha256=8B6CPFXlkmzPCRJ_Asvc-KFS5yaqe6OCYqsL5xOsvgE,950
PyNaCl-1.0.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
nacl/bindings/__pycache__/crypto_box.cpython-36.pyc,,
nacl/bindings/__pycache__/crypto_hash.cpython-36.pyc,,
nacl/bindings/__pycache__/crypto_scalarmult.cpython-36.pyc,,
nacl/bindings/__pycache__/crypto_secretbox.cpython-36.pyc,,
nacl/bindings/__pycache__/crypto_sign.cpython-36.pyc,,
nacl/bindings/__pycache__/randombytes.cpython-36.pyc,,
nacl/bindings/__pycache__/sodium_core.cpython-36.pyc,,
nacl/bindings/__pycache__/__init__.cpython-36.pyc,,
nacl/__pycache__/encoding.cpython-36.pyc,,
nacl/__pycache__/exceptions.cpython-36.pyc,,
nacl/__pycache__/hash.cpython-36.pyc,,
nacl/__pycache__/public.cpython-36.pyc,,
nacl/__pycache__/secret.cpython-36.pyc,,
nacl/__pycache__/signing.cpython-36.pyc,,
nacl/__pycache__/utils.cpython-36.pyc,,
nacl/__pycache__/__init__.cpython-36.pyc,,

View File

@ -0,0 +1,5 @@
Wheel-Version: 1.0
Generator: bdist_wheel (0.29.0)
Root-Is-Purelib: false
Tag: cp36-cp36m-win32

View File

@ -0,0 +1,2 @@
_sodium
nacl

Binary file not shown.

View File

@ -0,0 +1,352 @@
Metadata-Version: 1.1
Name: aiohttp
Version: 1.0.5
Summary: http client/server for asyncio
Home-page: https://github.com/KeepSafe/aiohttp/
Author: Andrew Svetlov
Author-email: andrew.svetlov@gmail.com
License: Apache 2
Description: http client/server for asyncio
==============================
.. image:: https://raw.github.com/KeepSafe/aiohttp/master/docs/_static/aiohttp-icon-128x128.png
:height: 64px
:width: 64px
:alt: aiohttp logo
.. image:: https://travis-ci.org/KeepSafe/aiohttp.svg?branch=master
:target: https://travis-ci.org/KeepSafe/aiohttp
:align: right
.. image:: https://codecov.io/gh/KeepSafe/aiohttp/branch/master/graph/badge.svg
:target: https://codecov.io/gh/KeepSafe/aiohttp
.. image:: https://badge.fury.io/py/aiohttp.svg
:target: https://badge.fury.io/py/aiohttp
Features
--------
- Supports both client and server side of HTTP protocol.
- Supports both client and server Web-Sockets out-of-the-box.
- Web-server has middlewares and pluggable routing.
Getting started
---------------
Client
^^^^^^
To retrieve something from the web:
.. code-block:: python
import aiohttp
import asyncio
async def fetch(session, url):
with aiohttp.Timeout(10, loop=session.loop):
async with session.get(url) as response:
return await response.text()
async def main(loop):
async with aiohttp.ClientSession(loop=loop) as session:
html = await fetch(session, 'http://python.org')
print(html)
if __name__ == '__main__':
loop = asyncio.get_event_loop()
loop.run_until_complete(main(loop))
Server
^^^^^^
This is simple usage example:
.. code-block:: python
from aiohttp import web
async def handle(request):
name = request.match_info.get('name', "Anonymous")
text = "Hello, " + name
return web.Response(text=text)
async def wshandler(request):
ws = web.WebSocketResponse()
await ws.prepare(request)
async for msg in ws:
if msg.type == web.MsgType.text:
ws.send_str("Hello, {}".format(msg.data))
elif msg.type == web.MsgType.binary:
ws.send_bytes(msg.data)
elif msg.type == web.MsgType.close:
break
return ws
app = web.Application()
app.router.add_get('/echo', wshandler)
app.router.add_get('/', handle)
app.router.add_get('/{name}', handle)
web.run_app(app)
Note: examples are written for Python 3.5+ and utilize PEP-492 aka
async/await. If you are using Python 3.4 please replace ``await`` with
``yield from`` and ``async def`` with ``@coroutine`` e.g.::
async def coro(...):
ret = await f()
should be replaced by::
@asyncio.coroutine
def coro(...):
ret = yield from f()
Documentation
-------------
https://aiohttp.readthedocs.io/
Discussion list
---------------
*aio-libs* google group: https://groups.google.com/forum/#!forum/aio-libs
Requirements
------------
- Python >= 3.4.2
- chardet_
- multidict_
Optionally you may install the cChardet_ and aiodns_ libraries (highly
recommended for sake of speed).
.. _chardet: https://pypi.python.org/pypi/chardet
.. _aiodns: https://pypi.python.org/pypi/aiodns
.. _multidict: https://pypi.python.org/pypi/multidict
.. _cChardet: https://pypi.python.org/pypi/cchardet
License
-------
``aiohttp`` is offered under the Apache 2 license.
Source code
------------
The latest developer version is available in a github repository:
https://github.com/KeepSafe/aiohttp
Benchmarks
----------
If you are interested in by efficiency, AsyncIO community maintains a
list of benchmarks on the official wiki:
https://github.com/python/asyncio/wiki/Benchmarks
CHANGES
=======
1.0.5 (2016-10-11)
------------------
- Fix StreamReader._read_nowait to return all available
data up to the requested amount #1297
1.0.4 (2016-09-22)
------------------
- Fix FlowControlStreamReader.read_nowait so that it checks
whether the transport is paused #1206
1.0.2 (2016-09-22)
------------------
- Make CookieJar compatible with 32-bit systems #1188
- Add missing `WSMsgType` to `web_ws.__all__`, see #1200
- Fix `CookieJar` ctor when called with `loop=None` #1203
- Fix broken upper-casing in wsgi support #1197
1.0.1 (2016-09-16)
------------------
- Restore `aiohttp.web.MsgType` alias for `aiohttp.WSMsgType` for sake
of backward compatibility #1178
- Tune alabaster schema.
- Use `text/html` content type for displaying index pages by static
file handler.
- Fix `AssertionError` in static file handling #1177
- Fix access log formats `%O` and `%b` for static file handling
- Remove `debug` setting of GunicornWorker, use `app.debug`
to control its debug-mode instead
1.0.0 (2016-09-16)
-------------------
- Change default size for client session's connection pool from
unlimited to 20 #977
- Add IE support for cookie deletion. #994
- Remove deprecated `WebSocketResponse.wait_closed` method (BACKWARD
INCOMPATIBLE)
- Remove deprecated `force` parameter for `ClientResponse.close`
method (BACKWARD INCOMPATIBLE)
- Avoid using of mutable CIMultiDict kw param in make_mocked_request
#997
- Make WebSocketResponse.close a little bit faster by avoiding new
task creating just for timeout measurement
- Add `proxy` and `proxy_auth` params to `client.get()` and family,
deprecate `ProxyConnector` #998
- Add support for websocket send_json and receive_json, synchronize
server and client API for websockets #984
- Implement router shourtcuts for most useful HTTP methods, use
`app.router.add_get()`, `app.router.add_post()` etc. instead of
`app.router.add_route()` #986
- Support SSL connections for gunicorn worker #1003
- Move obsolete examples to legacy folder
- Switch to multidict 2.0 and title-cased strings #1015
- `{FOO}e` logger format is case-sensitive now
- Fix logger report for unix socket 8e8469b
- Rename aiohttp.websocket to aiohttp._ws_impl
- Rename aiohttp.MsgType tp aiohttp.WSMsgType
- Introduce aiohttp.WSMessage officially
- Rename Message -> WSMessage
- Remove deprecated decode param from resp.read(decode=True)
- Use 5min default client timeout #1028
- Relax HTTP method validation in UrlDispatcher #1037
- Pin minimal supported asyncio version to 3.4.2+ (`loop.is_close()`
should be present)
- Remove aiohttp.websocket module (BACKWARD INCOMPATIBLE)
Please use high-level client and server approaches
- Link header for 451 status code is mandatory
- Fix test_client fixture to allow multiple clients per test #1072
- make_mocked_request now accepts dict as headers #1073
- Add Python 3.5.2/3.6+ compatibility patch for async generator
protocol change #1082
- Improvement test_client can accept instance object #1083
- Simplify ServerHttpProtocol implementation #1060
- Add a flag for optional showing directory index for static file
handling #921
- Define `web.Application.on_startup()` signal handler #1103
- Drop ChunkedParser and LinesParser #1111
- Call `Application.startup` in GunicornWebWorker #1105
- Fix client handling hostnames with 63 bytes when a port is given in
the url #1044
- Implement proxy support for ClientSession.ws_connect #1025
- Return named tuple from WebSocketResponse.can_prepare #1016
- Fix access_log_format in `GunicornWebWorker` #1117
- Setup Content-Type to application/octet-stream by default #1124
- Deprecate debug parameter from app.make_handler(), use
`Application(debug=True)` instead #1121
- Remove fragment string in request path #846
- Use aiodns.DNSResolver.gethostbyname() if available #1136
- Fix static file sending on uvloop when sendfile is available #1093
- Make prettier urls if query is empty dict #1143
- Fix redirects for HEAD requests #1147
- Default value for `StreamReader.read_nowait` is -1 from now #1150
- `aiohttp.StreamReader` is not inherited from `asyncio.StreamReader` from now
(BACKWARD INCOMPATIBLE) #1150
- Streams documentation added #1150
- Add `multipart` coroutine method for web Request object #1067
- Publish ClientSession.loop property #1149
- Fix static file with spaces #1140
- Fix piling up asyncio loop by cookie expiration callbacks #1061
- Drop `Timeout` class for sake of `async_timeout` external library.
`aiohttp.Timeout` is an alias for `async_timeout.timeout`
- `use_dns_cache` parameter of `aiohttp.TCPConnector` is `True` by
default (BACKWARD INCOMPATIBLE) #1152
- `aiohttp.TCPConnector` uses asynchronous DNS resolver if available by
default (BACKWARD INCOMPATIBLE) #1152
- Conform to RFC3986 - do not include url fragments in client requests #1174
- Drop `ClientSession.cookies` (BACKWARD INCOMPATIBLE) #1173
- Refactor `AbstractCookieJar` public API (BACKWARD INCOMPATIBLE) #1173
- Fix clashing cookies with have the same name but belong to different
domains (BACKWARD INCOMPATIBLE) #1125
- Support binary Content-Transfer-Encoding #1169
Platform: UNKNOWN
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Intended Audience :: Developers
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.4
Classifier: Programming Language :: Python :: 3.5
Classifier: Topic :: Internet :: WWW/HTTP

View File

@ -0,0 +1,162 @@
CHANGES.rst
CONTRIBUTORS.txt
LICENSE.txt
MANIFEST.in
Makefile
README.rst
setup.cfg
setup.py
aiohttp/__init__.py
aiohttp/_websocket.c
aiohttp/_websocket.pyx
aiohttp/_ws_impl.py
aiohttp/abc.py
aiohttp/client.py
aiohttp/client_reqrep.py
aiohttp/client_ws.py
aiohttp/connector.py
aiohttp/cookiejar.py
aiohttp/errors.py
aiohttp/file_sender.py
aiohttp/hdrs.py
aiohttp/helpers.py
aiohttp/log.py
aiohttp/multipart.py
aiohttp/parsers.py
aiohttp/protocol.py
aiohttp/pytest_plugin.py
aiohttp/resolver.py
aiohttp/server.py
aiohttp/signals.py
aiohttp/streams.py
aiohttp/test_utils.py
aiohttp/web.py
aiohttp/web_exceptions.py
aiohttp/web_reqrep.py
aiohttp/web_urldispatcher.py
aiohttp/web_ws.py
aiohttp/worker.py
aiohttp/wsgi.py
aiohttp.egg-info/PKG-INFO
aiohttp.egg-info/SOURCES.txt
aiohttp.egg-info/dependency_links.txt
aiohttp.egg-info/requires.txt
aiohttp.egg-info/top_level.txt
docs/Makefile
docs/abc.rst
docs/aiohttp-icon.ico
docs/aiohttp-icon.svg
docs/api.rst
docs/changes.rst
docs/client.rst
docs/client_reference.rst
docs/conf.py
docs/contributing.rst
docs/faq.rst
docs/glossary.rst
docs/gunicorn.rst
docs/index.rst
docs/logging.rst
docs/make.bat
docs/multipart.rst
docs/new_router.rst
docs/server.rst
docs/spelling_wordlist.txt
docs/streams.rst
docs/testing.rst
docs/tutorial.rst
docs/web.rst
docs/web_reference.rst
docs/_static/aiohttp-icon-128x128.png
docs/_static/aiohttp-icon-32x32.png
docs/_static/aiohttp-icon-64x64.png
docs/_static/aiohttp-icon-96x96.png
examples/background_tasks.py
examples/basic_srv.py
examples/cli_app.py
examples/client_auth.py
examples/client_json.py
examples/client_ws.py
examples/curl.py
examples/fake_server.py
examples/server.crt
examples/server.csr
examples/server.key
examples/static_files.py
examples/web_classview1.py
examples/web_cookies.py
examples/web_rewrite_headers_middleware.py
examples/web_srv.py
examples/web_ws.py
examples/websocket.html
examples/legacy/crawl.py
examples/legacy/srv.py
examples/legacy/tcp_protocol_parser.py
tests/conftest.py
tests/data.unknown_mime_type
tests/hello.txt.gz
tests/sample.crt
tests/sample.crt.der
tests/sample.key
tests/software_development_in_picture.jpg
tests/test_classbasedview.py
tests/test_client_connection.py
tests/test_client_functional.py
tests/test_client_functional_oldstyle.py
tests/test_client_request.py
tests/test_client_response.py
tests/test_client_session.py
tests/test_client_ws.py
tests/test_client_ws_functional.py
tests/test_connector.py
tests/test_cookiejar.py
tests/test_errors.py
tests/test_flowcontrol_streams.py
tests/test_helpers.py
tests/test_http_parser.py
tests/test_multipart.py
tests/test_parser_buffer.py
tests/test_protocol.py
tests/test_proxy.py
tests/test_pytest_plugin.py
tests/test_resolver.py
tests/test_run_app.py
tests/test_server.py
tests/test_signals.py
tests/test_stream_parser.py
tests/test_stream_protocol.py
tests/test_stream_writer.py
tests/test_streams.py
tests/test_test_utils.py
tests/test_urldispatch.py
tests/test_web_application.py
tests/test_web_cli.py
tests/test_web_exceptions.py
tests/test_web_functional.py
tests/test_web_middleware.py
tests/test_web_request.py
tests/test_web_request_handler.py
tests/test_web_response.py
tests/test_web_sendfile.py
tests/test_web_sendfile_functional.py
tests/test_web_urldispatcher.py
tests/test_web_websocket.py
tests/test_web_websocket_functional.py
tests/test_web_websocket_functional_oldstyle.py
tests/test_websocket_handshake.py
tests/test_websocket_parser.py
tests/test_websocket_writer.py
tests/test_worker.py
tests/test_wsgi.py
tests/autobahn/client.py
tests/autobahn/fuzzingclient.json
tests/autobahn/fuzzingserver.json
tests/autobahn/server.py
tests/test_py35/test_cbv35.py
tests/test_py35/test_client.py
tests/test_py35/test_client_websocket_35.py
tests/test_py35/test_multipart_35.py
tests/test_py35/test_resp.py
tests/test_py35/test_streams_35.py
tests/test_py35/test_test_utils_35.py
tests/test_py35/test_web_websocket_35.py

View File

@ -0,0 +1,66 @@
..\aiohttp\abc.py
..\aiohttp\client.py
..\aiohttp\client_reqrep.py
..\aiohttp\client_ws.py
..\aiohttp\connector.py
..\aiohttp\cookiejar.py
..\aiohttp\errors.py
..\aiohttp\file_sender.py
..\aiohttp\hdrs.py
..\aiohttp\helpers.py
..\aiohttp\log.py
..\aiohttp\multipart.py
..\aiohttp\parsers.py
..\aiohttp\protocol.py
..\aiohttp\pytest_plugin.py
..\aiohttp\resolver.py
..\aiohttp\server.py
..\aiohttp\signals.py
..\aiohttp\streams.py
..\aiohttp\test_utils.py
..\aiohttp\web.py
..\aiohttp\web_exceptions.py
..\aiohttp\web_reqrep.py
..\aiohttp\web_urldispatcher.py
..\aiohttp\web_ws.py
..\aiohttp\worker.py
..\aiohttp\wsgi.py
..\aiohttp\_ws_impl.py
..\aiohttp\__init__.py
..\aiohttp\_websocket.c
..\aiohttp\_websocket.pyx
..\aiohttp\__pycache__\abc.cpython-36.pyc
..\aiohttp\__pycache__\client.cpython-36.pyc
..\aiohttp\__pycache__\client_reqrep.cpython-36.pyc
..\aiohttp\__pycache__\client_ws.cpython-36.pyc
..\aiohttp\__pycache__\connector.cpython-36.pyc
..\aiohttp\__pycache__\cookiejar.cpython-36.pyc
..\aiohttp\__pycache__\errors.cpython-36.pyc
..\aiohttp\__pycache__\file_sender.cpython-36.pyc
..\aiohttp\__pycache__\hdrs.cpython-36.pyc
..\aiohttp\__pycache__\helpers.cpython-36.pyc
..\aiohttp\__pycache__\log.cpython-36.pyc
..\aiohttp\__pycache__\multipart.cpython-36.pyc
..\aiohttp\__pycache__\parsers.cpython-36.pyc
..\aiohttp\__pycache__\protocol.cpython-36.pyc
..\aiohttp\__pycache__\pytest_plugin.cpython-36.pyc
..\aiohttp\__pycache__\resolver.cpython-36.pyc
..\aiohttp\__pycache__\server.cpython-36.pyc
..\aiohttp\__pycache__\signals.cpython-36.pyc
..\aiohttp\__pycache__\streams.cpython-36.pyc
..\aiohttp\__pycache__\test_utils.cpython-36.pyc
..\aiohttp\__pycache__\web.cpython-36.pyc
..\aiohttp\__pycache__\web_exceptions.cpython-36.pyc
..\aiohttp\__pycache__\web_reqrep.cpython-36.pyc
..\aiohttp\__pycache__\web_urldispatcher.cpython-36.pyc
..\aiohttp\__pycache__\web_ws.cpython-36.pyc
..\aiohttp\__pycache__\worker.cpython-36.pyc
..\aiohttp\__pycache__\wsgi.cpython-36.pyc
..\aiohttp\__pycache__\_ws_impl.cpython-36.pyc
..\aiohttp\__pycache__\__init__.cpython-36.pyc
..\aiohttp\_websocket.cp36-win32.pyd
dependency_links.txt
PKG-INFO
requires.txt
SOURCES.txt
top_level.txt

View File

@ -0,0 +1,3 @@
chardet
multidict>=2.0
async_timeout

View File

@ -0,0 +1 @@
aiohttp

View File

@ -0,0 +1,41 @@
__version__ = '1.0.5'
# Deprecated, keep it here for a while for backward compatibility.
import multidict # noqa
# This relies on each of the submodules having an __all__ variable.
from multidict import * # noqa
from . import hdrs # noqa
from .protocol import * # noqa
from .connector import * # noqa
from .client import * # noqa
from .client_reqrep import * # noqa
from .errors import * # noqa
from .helpers import * # noqa
from .parsers import * # noqa
from .streams import * # noqa
from .multipart import * # noqa
from .client_ws import ClientWebSocketResponse # noqa
from ._ws_impl import WSMsgType, WSCloseCode, WSMessage, WebSocketError # noqa
from .file_sender import FileSender # noqa
from .cookiejar import CookieJar # noqa
from .resolver import * # noqa
MsgType = WSMsgType # backward compatibility
__all__ = (client.__all__ + # noqa
client_reqrep.__all__ + # noqa
errors.__all__ + # noqa
helpers.__all__ + # noqa
parsers.__all__ + # noqa
protocol.__all__ + # noqa
connector.__all__ + # noqa
streams.__all__ + # noqa
multidict.__all__ + # noqa
multipart.__all__ + # noqa
('hdrs', 'FileSender', 'WSMsgType', 'MsgType', 'WSCloseCode',
'WebSocketError', 'WSMessage',
'ClientWebSocketResponse', 'CookieJar'))

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,48 @@
from cpython cimport PyBytes_AsString
#from cpython cimport PyByteArray_AsString # cython still not exports that
cdef extern from "Python.h":
char* PyByteArray_AsString(bytearray ba) except NULL
from libc.stdint cimport uint32_t, uint64_t, uintmax_t
def _websocket_mask_cython(bytes mask, bytearray data):
"""Note, this function mutates it's `data` argument
"""
cdef:
Py_ssize_t data_len, i
# bit operations on signed integers are implementation-specific
unsigned char * in_buf
const unsigned char * mask_buf
uint32_t uint32_msk
uint64_t uint64_msk
assert len(mask) == 4
data_len = len(data)
in_buf = <unsigned char*>PyByteArray_AsString(data)
mask_buf = <const unsigned char*>PyBytes_AsString(mask)
uint32_msk = (<uint32_t*>mask_buf)[0]
# TODO: align in_data ptr to achieve even faster speeds
# does it need in python ?! malloc() always aligns to sizeof(long) bytes
if sizeof(size_t) >= 8:
uint64_msk = uint32_msk
uint64_msk = (uint64_msk << 32) | uint32_msk
while data_len >= 8:
(<uint64_t*>in_buf)[0] ^= uint64_msk
in_buf += 8
data_len -= 8
while data_len >= 4:
(<uint32_t*>in_buf)[0] ^= uint32_msk
in_buf += 4
data_len -= 4
for i in range(0, data_len):
in_buf[i] ^= mask_buf[i]
return data

View File

@ -0,0 +1,438 @@
"""WebSocket protocol versions 13 and 8."""
import base64
import binascii
import collections
import hashlib
import json
import os
import random
import sys
from enum import IntEnum
from struct import Struct
from aiohttp import errors, hdrs
from aiohttp.log import ws_logger
__all__ = ('WebSocketParser', 'WebSocketWriter', 'do_handshake',
'WSMessage', 'WebSocketError', 'WSMsgType', 'WSCloseCode')
class WSCloseCode(IntEnum):
OK = 1000
GOING_AWAY = 1001
PROTOCOL_ERROR = 1002
UNSUPPORTED_DATA = 1003
INVALID_TEXT = 1007
POLICY_VIOLATION = 1008
MESSAGE_TOO_BIG = 1009
MANDATORY_EXTENSION = 1010
INTERNAL_ERROR = 1011
SERVICE_RESTART = 1012
TRY_AGAIN_LATER = 1013
ALLOWED_CLOSE_CODES = {int(i) for i in WSCloseCode}
class WSMsgType(IntEnum):
CONTINUATION = 0x0
TEXT = 0x1
BINARY = 0x2
PING = 0x9
PONG = 0xa
CLOSE = 0x8
CLOSED = 0x101
ERROR = 0x102
text = TEXT
binary = BINARY
ping = PING
pong = PONG
close = CLOSE
closed = CLOSED
error = ERROR
WS_KEY = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
UNPACK_LEN2 = Struct('!H').unpack_from
UNPACK_LEN3 = Struct('!Q').unpack_from
UNPACK_CLOSE_CODE = Struct('!H').unpack
PACK_LEN1 = Struct('!BB').pack
PACK_LEN2 = Struct('!BBH').pack
PACK_LEN3 = Struct('!BBQ').pack
PACK_CLOSE_CODE = Struct('!H').pack
MSG_SIZE = 2 ** 14
_WSMessageBase = collections.namedtuple('_WSMessageBase',
['type', 'data', 'extra'])
class WSMessage(_WSMessageBase):
def json(self, *, loads=json.loads):
"""Return parsed JSON data.
.. versionadded:: 0.22
"""
return loads(self.data)
@property
def tp(self):
return self.type
CLOSED_MESSAGE = WSMessage(WSMsgType.CLOSED, None, None)
class WebSocketError(Exception):
"""WebSocket protocol parser error."""
def __init__(self, code, message):
self.code = code
super().__init__(message)
def WebSocketParser(out, buf):
while True:
fin, opcode, payload = yield from parse_frame(buf)
if opcode == WSMsgType.CLOSE:
if len(payload) >= 2:
close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
'Invalid close code: {}'.format(close_code))
try:
close_message = payload[2:].decode('utf-8')
except UnicodeDecodeError as exc:
raise WebSocketError(
WSCloseCode.INVALID_TEXT,
'Invalid UTF-8 text message') from exc
msg = WSMessage(WSMsgType.CLOSE, close_code, close_message)
elif payload:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
'Invalid close frame: {} {} {!r}'.format(
fin, opcode, payload))
else:
msg = WSMessage(WSMsgType.CLOSE, 0, '')
out.feed_data(msg, 0)
elif opcode == WSMsgType.PING:
out.feed_data(WSMessage(WSMsgType.PING, payload, ''), len(payload))
elif opcode == WSMsgType.PONG:
out.feed_data(WSMessage(WSMsgType.PONG, payload, ''), len(payload))
elif opcode not in (WSMsgType.TEXT, WSMsgType.BINARY):
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Unexpected opcode={!r}".format(opcode))
else:
# load text/binary
data = [payload]
while not fin:
fin, _opcode, payload = yield from parse_frame(buf, True)
# We can receive ping/close in the middle of
# text message, Case 5.*
if _opcode == WSMsgType.PING:
out.feed_data(
WSMessage(WSMsgType.PING, payload, ''), len(payload))
fin, _opcode, payload = yield from parse_frame(buf, True)
elif _opcode == WSMsgType.CLOSE:
if len(payload) >= 2:
close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
if (close_code not in ALLOWED_CLOSE_CODES and
close_code < 3000):
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
'Invalid close code: {}'.format(close_code))
try:
close_message = payload[2:].decode('utf-8')
except UnicodeDecodeError as exc:
raise WebSocketError(
WSCloseCode.INVALID_TEXT,
'Invalid UTF-8 text message') from exc
msg = WSMessage(WSMsgType.CLOSE, close_code,
close_message)
elif payload:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
'Invalid close frame: {} {} {!r}'.format(
fin, opcode, payload))
else:
msg = WSMessage(WSMsgType.CLOSE, 0, '')
out.feed_data(msg, 0)
fin, _opcode, payload = yield from parse_frame(buf, True)
if _opcode != WSMsgType.CONTINUATION:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
'The opcode in non-fin frame is expected '
'to be zero, got {!r}'.format(_opcode))
else:
data.append(payload)
if opcode == WSMsgType.TEXT:
try:
text = b''.join(data).decode('utf-8')
out.feed_data(WSMessage(WSMsgType.TEXT, text, ''),
len(text))
except UnicodeDecodeError as exc:
raise WebSocketError(
WSCloseCode.INVALID_TEXT,
'Invalid UTF-8 text message') from exc
else:
data = b''.join(data)
out.feed_data(
WSMessage(WSMsgType.BINARY, data, ''), len(data))
native_byteorder = sys.byteorder
def _websocket_mask_python(mask, data):
"""Websocket masking function.
`mask` is a `bytes` object of length 4; `data` is a `bytes` object
of any length. Returns a `bytes` object of the same length as
`data` with the mask applied as specified in section 5.3 of RFC
6455.
This pure-python implementation may be replaced by an optimized
version when available.
"""
assert isinstance(data, bytearray), data
assert len(mask) == 4, mask
datalen = len(data)
if datalen == 0:
# everything work without this, but may be changed later in Python.
return bytearray()
data = int.from_bytes(data, native_byteorder)
mask = int.from_bytes(mask * (datalen // 4) + mask[: datalen % 4],
native_byteorder)
return (data ^ mask).to_bytes(datalen, native_byteorder)
if bool(os.environ.get('AIOHTTP_NO_EXTENSIONS')):
_websocket_mask = _websocket_mask_python
else:
try:
from ._websocket import _websocket_mask_cython
_websocket_mask = _websocket_mask_cython
except ImportError: # pragma: no cover
_websocket_mask = _websocket_mask_python
def parse_frame(buf, continuation=False):
"""Return the next frame from the socket."""
# read header
data = yield from buf.read(2)
first_byte, second_byte = data
fin = (first_byte >> 7) & 1
rsv1 = (first_byte >> 6) & 1
rsv2 = (first_byte >> 5) & 1
rsv3 = (first_byte >> 4) & 1
opcode = first_byte & 0xf
# frame-fin = %x0 ; more frames of this message follow
# / %x1 ; final frame of this message
# frame-rsv1 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise
# frame-rsv2 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise
# frame-rsv3 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise
if rsv1 or rsv2 or rsv3:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
'Received frame with non-zero reserved bits')
if opcode > 0x7 and fin == 0:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
'Received fragmented control frame')
if fin == 0 and opcode == WSMsgType.CONTINUATION and not continuation:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
'Received new fragment frame with non-zero '
'opcode {!r}'.format(opcode))
has_mask = (second_byte >> 7) & 1
length = (second_byte) & 0x7f
# Control frames MUST have a payload length of 125 bytes or less
if opcode > 0x7 and length > 125:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Control frame payload cannot be larger than 125 bytes")
# read payload
if length == 126:
data = yield from buf.read(2)
length = UNPACK_LEN2(data)[0]
elif length > 126:
data = yield from buf.read(8)
length = UNPACK_LEN3(data)[0]
if has_mask:
mask = yield from buf.read(4)
if length:
payload = yield from buf.read(length)
else:
payload = bytearray()
if has_mask:
payload = _websocket_mask(bytes(mask), payload)
return fin, opcode, payload
class WebSocketWriter:
def __init__(self, writer, *, use_mask=False, random=random.Random()):
self.writer = writer
self.use_mask = use_mask
self.randrange = random.randrange
def _send_frame(self, message, opcode):
"""Send a frame over the websocket with message as its payload."""
msg_length = len(message)
use_mask = self.use_mask
if use_mask:
mask_bit = 0x80
else:
mask_bit = 0
if msg_length < 126:
header = PACK_LEN1(0x80 | opcode, msg_length | mask_bit)
elif msg_length < (1 << 16):
header = PACK_LEN2(0x80 | opcode, 126 | mask_bit, msg_length)
else:
header = PACK_LEN3(0x80 | opcode, 127 | mask_bit, msg_length)
if use_mask:
mask = self.randrange(0, 0xffffffff)
mask = mask.to_bytes(4, 'big')
message = _websocket_mask(mask, bytearray(message))
self.writer.write(header + mask + message)
else:
if len(message) > MSG_SIZE:
self.writer.write(header)
self.writer.write(message)
else:
self.writer.write(header + message)
def pong(self, message=b''):
"""Send pong message."""
if isinstance(message, str):
message = message.encode('utf-8')
self._send_frame(message, WSMsgType.PONG)
def ping(self, message=b''):
"""Send ping message."""
if isinstance(message, str):
message = message.encode('utf-8')
self._send_frame(message, WSMsgType.PING)
def send(self, message, binary=False):
"""Send a frame over the websocket with message as its payload."""
if isinstance(message, str):
message = message.encode('utf-8')
if binary:
self._send_frame(message, WSMsgType.BINARY)
else:
self._send_frame(message, WSMsgType.TEXT)
def close(self, code=1000, message=b''):
"""Close the websocket, sending the specified code and message."""
if isinstance(message, str):
message = message.encode('utf-8')
self._send_frame(
PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE)
def do_handshake(method, headers, transport, protocols=()):
"""Prepare WebSocket handshake.
It return HTTP response code, response headers, websocket parser,
websocket writer. It does not perform any IO.
`protocols` is a sequence of known protocols. On successful handshake,
the returned response headers contain the first protocol in this list
which the server also knows.
"""
# WebSocket accepts only GET
if method.upper() != hdrs.METH_GET:
raise errors.HttpProcessingError(
code=405, headers=((hdrs.ALLOW, hdrs.METH_GET),))
if 'websocket' != headers.get(hdrs.UPGRADE, '').lower().strip():
raise errors.HttpBadRequest(
message='No WebSocket UPGRADE hdr: {}\n Can '
'"Upgrade" only to "WebSocket".'.format(headers.get(hdrs.UPGRADE)))
if 'upgrade' not in headers.get(hdrs.CONNECTION, '').lower():
raise errors.HttpBadRequest(
message='No CONNECTION upgrade hdr: {}'.format(
headers.get(hdrs.CONNECTION)))
# find common sub-protocol between client and server
protocol = None
if hdrs.SEC_WEBSOCKET_PROTOCOL in headers:
req_protocols = [str(proto.strip()) for proto in
headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(',')]
for proto in req_protocols:
if proto in protocols:
protocol = proto
break
else:
# No overlap found: Return no protocol as per spec
ws_logger.warning(
'Client protocols %r dont overlap server-known ones %r',
req_protocols, protocols)
# check supported version
version = headers.get(hdrs.SEC_WEBSOCKET_VERSION, '')
if version not in ('13', '8', '7'):
raise errors.HttpBadRequest(
message='Unsupported version: {}'.format(version),
headers=((hdrs.SEC_WEBSOCKET_VERSION, '13'),))
# check client handshake for validity
key = headers.get(hdrs.SEC_WEBSOCKET_KEY)
try:
if not key or len(base64.b64decode(key)) != 16:
raise errors.HttpBadRequest(
message='Handshake error: {!r}'.format(key))
except binascii.Error:
raise errors.HttpBadRequest(
message='Handshake error: {!r}'.format(key)) from None
response_headers = [
(hdrs.UPGRADE, 'websocket'),
(hdrs.CONNECTION, 'upgrade'),
(hdrs.TRANSFER_ENCODING, 'chunked'),
(hdrs.SEC_WEBSOCKET_ACCEPT, base64.b64encode(
hashlib.sha1(key.encode() + WS_KEY).digest()).decode())]
if protocol:
response_headers.append((hdrs.SEC_WEBSOCKET_PROTOCOL, protocol))
# response code, headers, parser, writer, protocol
return (101,
response_headers,
WebSocketParser,
WebSocketWriter(transport),
protocol)

View File

@ -0,0 +1,88 @@
import asyncio
import sys
from abc import ABC, abstractmethod
from collections.abc import Iterable, Sized
PY_35 = sys.version_info >= (3, 5)
class AbstractRouter(ABC):
@asyncio.coroutine # pragma: no branch
@abstractmethod
def resolve(self, request):
"""Return MATCH_INFO for given request"""
class AbstractMatchInfo(ABC):
@asyncio.coroutine # pragma: no branch
@abstractmethod
def handler(self, request):
"""Execute matched request handler"""
@asyncio.coroutine # pragma: no branch
@abstractmethod
def expect_handler(self, request):
"""Expect handler for 100-continue processing"""
@property # pragma: no branch
@abstractmethod
def http_exception(self):
"""HTTPException instance raised on router's resolving, or None"""
@abstractmethod # pragma: no branch
def get_info(self):
"""Return a dict with additional info useful for introspection"""
class AbstractView(ABC):
def __init__(self, request):
self._request = request
@property
def request(self):
return self._request
@asyncio.coroutine # pragma: no branch
@abstractmethod
def __iter__(self):
while False: # pragma: no cover
yield None
if PY_35: # pragma: no branch
@abstractmethod
def __await__(self):
return # pragma: no cover
class AbstractResolver(ABC):
@asyncio.coroutine # pragma: no branch
@abstractmethod
def resolve(self, hostname):
"""Return IP address for given hostname"""
@asyncio.coroutine # pragma: no branch
@abstractmethod
def close(self):
"""Release resolver"""
class AbstractCookieJar(Sized, Iterable):
def __init__(self, *, loop=None):
self._loop = loop or asyncio.get_event_loop()
@abstractmethod
def clear(self):
"""Clear all cookies."""
@abstractmethod
def update_cookies(self, cookies, response_url=None):
"""Update cookies."""
@abstractmethod
def filter_cookies(self, request_url):
"""Return the jar's cookies filtered by their attributes."""

View File

@ -0,0 +1,786 @@
"""HTTP Client for asyncio."""
import asyncio
import base64
import hashlib
import os
import sys
import traceback
import urllib.parse
import warnings
from multidict import CIMultiDict, MultiDict, MultiDictProxy, istr
import aiohttp
from . import hdrs, helpers
from ._ws_impl import WS_KEY, WebSocketParser, WebSocketWriter
from .client_reqrep import ClientRequest, ClientResponse
from .client_ws import ClientWebSocketResponse
from .cookiejar import CookieJar
from .errors import WSServerHandshakeError
from .helpers import Timeout
__all__ = ('ClientSession', 'request', 'get', 'options', 'head',
'delete', 'post', 'put', 'patch', 'ws_connect')
PY_35 = sys.version_info >= (3, 5)
class ClientSession:
"""First-class interface for making HTTP requests."""
_source_traceback = None
_connector = None
def __init__(self, *, connector=None, loop=None, cookies=None,
headers=None, skip_auto_headers=None,
auth=None, request_class=ClientRequest,
response_class=ClientResponse,
ws_response_class=ClientWebSocketResponse,
version=aiohttp.HttpVersion11,
cookie_jar=None):
if connector is None:
connector = aiohttp.TCPConnector(loop=loop)
loop = connector._loop # never None
else:
if loop is None:
loop = connector._loop # never None
elif connector._loop is not loop:
raise ValueError("loop argument must agree with connector")
self._loop = loop
if loop.get_debug():
self._source_traceback = traceback.extract_stack(sys._getframe(1))
if cookie_jar is None:
cookie_jar = CookieJar(loop=loop)
self._cookie_jar = cookie_jar
if cookies is not None:
self._cookie_jar.update_cookies(cookies)
self._connector = connector
self._default_auth = auth
self._version = version
# Convert to list of tuples
if headers:
headers = CIMultiDict(headers)
else:
headers = CIMultiDict()
self._default_headers = headers
if skip_auto_headers is not None:
self._skip_auto_headers = frozenset([istr(i)
for i in skip_auto_headers])
else:
self._skip_auto_headers = frozenset()
self._request_class = request_class
self._response_class = response_class
self._ws_response_class = ws_response_class
def __del__(self, _warnings=warnings):
if not self.closed:
self.close()
_warnings.warn("Unclosed client session {!r}".format(self),
ResourceWarning)
context = {'client_session': self,
'message': 'Unclosed client session'}
if self._source_traceback is not None:
context['source_traceback'] = self._source_traceback
self._loop.call_exception_handler(context)
def request(self, method, url, *,
params=None,
data=None,
headers=None,
skip_auto_headers=None,
auth=None,
allow_redirects=True,
max_redirects=10,
encoding='utf-8',
version=None,
compress=None,
chunked=None,
expect100=False,
read_until_eof=True,
proxy=None,
proxy_auth=None,
timeout=5*60):
"""Perform HTTP request."""
return _RequestContextManager(
self._request(
method,
url,
params=params,
data=data,
headers=headers,
skip_auto_headers=skip_auto_headers,
auth=auth,
allow_redirects=allow_redirects,
max_redirects=max_redirects,
encoding=encoding,
version=version,
compress=compress,
chunked=chunked,
expect100=expect100,
read_until_eof=read_until_eof,
proxy=proxy,
proxy_auth=proxy_auth,
timeout=timeout))
@asyncio.coroutine
def _request(self, method, url, *,
params=None,
data=None,
headers=None,
skip_auto_headers=None,
auth=None,
allow_redirects=True,
max_redirects=10,
encoding='utf-8',
version=None,
compress=None,
chunked=None,
expect100=False,
read_until_eof=True,
proxy=None,
proxy_auth=None,
timeout=5*60):
if version is not None:
warnings.warn("HTTP version should be specified "
"by ClientSession constructor", DeprecationWarning)
else:
version = self._version
if self.closed:
raise RuntimeError('Session is closed')
redirects = 0
history = []
# Merge with default headers and transform to CIMultiDict
headers = self._prepare_headers(headers)
if auth is None:
auth = self._default_auth
# It would be confusing if we support explicit Authorization header
# with `auth` argument
if (headers is not None and
auth is not None and
hdrs.AUTHORIZATION in headers):
raise ValueError("Can't combine `Authorization` header with "
"`auth` argument")
skip_headers = set(self._skip_auto_headers)
if skip_auto_headers is not None:
for i in skip_auto_headers:
skip_headers.add(istr(i))
while True:
url, _ = urllib.parse.urldefrag(url)
cookies = self._cookie_jar.filter_cookies(url)
req = self._request_class(
method, url, params=params, headers=headers,
skip_auto_headers=skip_headers, data=data,
cookies=cookies, encoding=encoding,
auth=auth, version=version, compress=compress, chunked=chunked,
expect100=expect100,
loop=self._loop, response_class=self._response_class,
proxy=proxy, proxy_auth=proxy_auth, timeout=timeout)
with Timeout(timeout, loop=self._loop):
conn = yield from self._connector.connect(req)
try:
resp = req.send(conn.writer, conn.reader)
try:
yield from resp.start(conn, read_until_eof)
except:
resp.close()
conn.close()
raise
except (aiohttp.HttpProcessingError,
aiohttp.ServerDisconnectedError) as exc:
raise aiohttp.ClientResponseError() from exc
except OSError as exc:
raise aiohttp.ClientOSError(*exc.args) from exc
self._cookie_jar.update_cookies(resp.cookies, resp.url)
# redirects
if resp.status in (301, 302, 303, 307) and allow_redirects:
redirects += 1
history.append(resp)
if max_redirects and redirects >= max_redirects:
resp.close()
break
else:
# TODO: close the connection if BODY is large enough
# Redirect with big BODY is forbidden by HTTP protocol
# but malformed server may send illegal response.
# Small BODIES with text like "Not Found" are still
# perfectly fine and should be accepted.
yield from resp.release()
# For 301 and 302, mimic IE behaviour, now changed in RFC.
# Details: https://github.com/kennethreitz/requests/pull/269
if (resp.status == 303 and resp.method != hdrs.METH_HEAD) \
or (resp.status in (301, 302) and
resp.method == hdrs.METH_POST):
method = hdrs.METH_GET
data = None
if headers.get(hdrs.CONTENT_LENGTH):
headers.pop(hdrs.CONTENT_LENGTH)
r_url = (resp.headers.get(hdrs.LOCATION) or
resp.headers.get(hdrs.URI))
scheme = urllib.parse.urlsplit(r_url)[0]
if scheme not in ('http', 'https', ''):
resp.close()
raise ValueError('Can redirect only to http or https')
elif not scheme:
r_url = urllib.parse.urljoin(url, r_url)
url = r_url
params = None
yield from resp.release()
continue
break
resp._history = tuple(history)
return resp
def ws_connect(self, url, *,
protocols=(),
timeout=10.0,
autoclose=True,
autoping=True,
auth=None,
origin=None,
headers=None,
proxy=None,
proxy_auth=None):
"""Initiate websocket connection."""
return _WSRequestContextManager(
self._ws_connect(url,
protocols=protocols,
timeout=timeout,
autoclose=autoclose,
autoping=autoping,
auth=auth,
origin=origin,
headers=headers,
proxy=proxy,
proxy_auth=proxy_auth))
@asyncio.coroutine
def _ws_connect(self, url, *,
protocols=(),
timeout=10.0,
autoclose=True,
autoping=True,
auth=None,
origin=None,
headers=None,
proxy=None,
proxy_auth=None):
sec_key = base64.b64encode(os.urandom(16))
if headers is None:
headers = CIMultiDict()
default_headers = {
hdrs.UPGRADE: hdrs.WEBSOCKET,
hdrs.CONNECTION: hdrs.UPGRADE,
hdrs.SEC_WEBSOCKET_VERSION: '13',
hdrs.SEC_WEBSOCKET_KEY: sec_key.decode(),
}
for key, value in default_headers.items():
if key not in headers:
headers[key] = value
if protocols:
headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = ','.join(protocols)
if origin is not None:
headers[hdrs.ORIGIN] = origin
# send request
resp = yield from self.get(url, headers=headers,
read_until_eof=False,
auth=auth,
proxy=proxy,
proxy_auth=proxy_auth)
try:
# check handshake
if resp.status != 101:
raise WSServerHandshakeError(
message='Invalid response status',
code=resp.status,
headers=resp.headers)
if resp.headers.get(hdrs.UPGRADE, '').lower() != 'websocket':
raise WSServerHandshakeError(
message='Invalid upgrade header',
code=resp.status,
headers=resp.headers)
if resp.headers.get(hdrs.CONNECTION, '').lower() != 'upgrade':
raise WSServerHandshakeError(
message='Invalid connection header',
code=resp.status,
headers=resp.headers)
# key calculation
key = resp.headers.get(hdrs.SEC_WEBSOCKET_ACCEPT, '')
match = base64.b64encode(
hashlib.sha1(sec_key + WS_KEY).digest()).decode()
if key != match:
raise WSServerHandshakeError(
message='Invalid challenge response',
code=resp.status,
headers=resp.headers)
# websocket protocol
protocol = None
if protocols and hdrs.SEC_WEBSOCKET_PROTOCOL in resp.headers:
resp_protocols = [
proto.strip() for proto in
resp.headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(',')]
for proto in resp_protocols:
if proto in protocols:
protocol = proto
break
reader = resp.connection.reader.set_parser(WebSocketParser)
resp.connection.writer.set_tcp_nodelay(True)
writer = WebSocketWriter(resp.connection.writer, use_mask=True)
except Exception:
resp.close()
raise
else:
return self._ws_response_class(reader,
writer,
protocol,
resp,
timeout,
autoclose,
autoping,
self._loop)
def _prepare_headers(self, headers):
""" Add default headers and transform it to CIMultiDict
"""
# Convert headers to MultiDict
result = CIMultiDict(self._default_headers)
if headers:
if not isinstance(headers, (MultiDictProxy, MultiDict)):
headers = CIMultiDict(headers)
added_names = set()
for key, value in headers.items():
if key in added_names:
result.add(key, value)
else:
result[key] = value
added_names.add(key)
return result
def get(self, url, *, allow_redirects=True, **kwargs):
"""Perform HTTP GET request."""
return _RequestContextManager(
self._request(hdrs.METH_GET, url,
allow_redirects=allow_redirects,
**kwargs))
def options(self, url, *, allow_redirects=True, **kwargs):
"""Perform HTTP OPTIONS request."""
return _RequestContextManager(
self._request(hdrs.METH_OPTIONS, url,
allow_redirects=allow_redirects,
**kwargs))
def head(self, url, *, allow_redirects=False, **kwargs):
"""Perform HTTP HEAD request."""
return _RequestContextManager(
self._request(hdrs.METH_HEAD, url,
allow_redirects=allow_redirects,
**kwargs))
def post(self, url, *, data=None, **kwargs):
"""Perform HTTP POST request."""
return _RequestContextManager(
self._request(hdrs.METH_POST, url,
data=data,
**kwargs))
def put(self, url, *, data=None, **kwargs):
"""Perform HTTP PUT request."""
return _RequestContextManager(
self._request(hdrs.METH_PUT, url,
data=data,
**kwargs))
def patch(self, url, *, data=None, **kwargs):
"""Perform HTTP PATCH request."""
return _RequestContextManager(
self._request(hdrs.METH_PATCH, url,
data=data,
**kwargs))
def delete(self, url, **kwargs):
"""Perform HTTP DELETE request."""
return _RequestContextManager(
self._request(hdrs.METH_DELETE, url,
**kwargs))
def close(self):
"""Close underlying connector.
Release all acquired resources.
"""
if not self.closed:
self._connector.close()
self._connector = None
ret = helpers.create_future(self._loop)
ret.set_result(None)
return ret
@property
def closed(self):
"""Is client session closed.
A readonly property.
"""
return self._connector is None or self._connector.closed
@property
def connector(self):
"""Connector instance used for the session."""
return self._connector
@property
def cookie_jar(self):
"""The session cookies."""
return self._cookie_jar
@property
def version(self):
"""The session HTTP protocol version."""
return self._version
@property
def loop(self):
"""Session's loop."""
return self._loop
def detach(self):
"""Detach connector from session without closing the former.
Session is switched to closed state anyway.
"""
self._connector = None
def __enter__(self):
warnings.warn("Use async with instead", DeprecationWarning)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
if PY_35:
@asyncio.coroutine
def __aenter__(self):
return self
@asyncio.coroutine
def __aexit__(self, exc_type, exc_val, exc_tb):
yield from self.close()
if PY_35:
from collections.abc import Coroutine
base = Coroutine
else:
base = object
class _BaseRequestContextManager(base):
__slots__ = ('_coro', '_resp')
def __init__(self, coro):
self._coro = coro
self._resp = None
def send(self, value):
return self._coro.send(value)
def throw(self, typ, val=None, tb=None):
if val is None:
return self._coro.throw(typ)
elif tb is None:
return self._coro.throw(typ, val)
else:
return self._coro.throw(typ, val, tb)
def close(self):
return self._coro.close()
@property
def gi_frame(self):
return self._coro.gi_frame
@property
def gi_running(self):
return self._coro.gi_running
@property
def gi_code(self):
return self._coro.gi_code
def __next__(self):
return self.send(None)
@asyncio.coroutine
def __iter__(self):
resp = yield from self._coro
return resp
if PY_35:
def __await__(self):
resp = yield from self._coro
return resp
@asyncio.coroutine
def __aenter__(self):
self._resp = yield from self._coro
return self._resp
if not PY_35:
try:
from asyncio import coroutines
coroutines._COROUTINE_TYPES += (_BaseRequestContextManager,)
except: # pragma: no cover
pass # Python 3.4.2 and 3.4.3 has no coroutines._COROUTINE_TYPES
class _RequestContextManager(_BaseRequestContextManager):
if PY_35:
@asyncio.coroutine
def __aexit__(self, exc_type, exc, tb):
if exc_type is not None:
self._resp.close()
else:
yield from self._resp.release()
class _WSRequestContextManager(_BaseRequestContextManager):
if PY_35:
@asyncio.coroutine
def __aexit__(self, exc_type, exc, tb):
yield from self._resp.close()
class _DetachedRequestContextManager(_RequestContextManager):
__slots__ = _RequestContextManager.__slots__ + ('_session', )
def __init__(self, coro, session):
super().__init__(coro)
self._session = session
@asyncio.coroutine
def __iter__(self):
try:
return (yield from self._coro)
except:
yield from self._session.close()
raise
if PY_35:
def __await__(self):
try:
return (yield from self._coro)
except:
yield from self._session.close()
raise
def __del__(self):
self._session.detach()
class _DetachedWSRequestContextManager(_WSRequestContextManager):
__slots__ = _WSRequestContextManager.__slots__ + ('_session', )
def __init__(self, coro, session):
super().__init__(coro)
self._session = session
def __del__(self):
self._session.detach()
def request(method, url, *,
params=None,
data=None,
headers=None,
skip_auto_headers=None,
cookies=None,
auth=None,
allow_redirects=True,
max_redirects=10,
encoding='utf-8',
version=None,
compress=None,
chunked=None,
expect100=False,
connector=None,
loop=None,
read_until_eof=True,
request_class=None,
response_class=None,
proxy=None,
proxy_auth=None):
"""Constructs and sends a request. Returns response object.
method - HTTP method
url - request url
params - (optional) Dictionary or bytes to be sent in the query
string of the new request
data - (optional) Dictionary, bytes, or file-like object to
send in the body of the request
headers - (optional) Dictionary of HTTP Headers to send with
the request
cookies - (optional) Dict object to send with the request
auth - (optional) BasicAuth named tuple represent HTTP Basic Auth
auth - aiohttp.helpers.BasicAuth
allow_redirects - (optional) If set to False, do not follow
redirects
version - Request HTTP version.
compress - Set to True if request has to be compressed
with deflate encoding.
chunked - Set to chunk size for chunked transfer encoding.
expect100 - Expect 100-continue response from server.
connector - BaseConnector sub-class instance to support
connection pooling.
read_until_eof - Read response until eof if response
does not have Content-Length header.
request_class - (optional) Custom Request class implementation.
response_class - (optional) Custom Response class implementation.
loop - Optional event loop.
Usage::
>>> import aiohttp
>>> resp = yield from aiohttp.request('GET', 'http://python.org/')
>>> resp
<ClientResponse(python.org/) [200]>
>>> data = yield from resp.read()
"""
warnings.warn("Use ClientSession().request() instead", DeprecationWarning)
if connector is None:
connector = aiohttp.TCPConnector(loop=loop, force_close=True)
kwargs = {}
if request_class is not None:
kwargs['request_class'] = request_class
if response_class is not None:
kwargs['response_class'] = response_class
session = ClientSession(loop=loop,
cookies=cookies,
connector=connector,
**kwargs)
return _DetachedRequestContextManager(
session._request(method, url,
params=params,
data=data,
headers=headers,
skip_auto_headers=skip_auto_headers,
auth=auth,
allow_redirects=allow_redirects,
max_redirects=max_redirects,
encoding=encoding,
version=version,
compress=compress,
chunked=chunked,
expect100=expect100,
read_until_eof=read_until_eof,
proxy=proxy,
proxy_auth=proxy_auth,),
session=session)
def get(url, **kwargs):
warnings.warn("Use ClientSession().get() instead", DeprecationWarning)
return request(hdrs.METH_GET, url, **kwargs)
def options(url, **kwargs):
warnings.warn("Use ClientSession().options() instead", DeprecationWarning)
return request(hdrs.METH_OPTIONS, url, **kwargs)
def head(url, **kwargs):
warnings.warn("Use ClientSession().head() instead", DeprecationWarning)
return request(hdrs.METH_HEAD, url, **kwargs)
def post(url, **kwargs):
warnings.warn("Use ClientSession().post() instead", DeprecationWarning)
return request(hdrs.METH_POST, url, **kwargs)
def put(url, **kwargs):
warnings.warn("Use ClientSession().put() instead", DeprecationWarning)
return request(hdrs.METH_PUT, url, **kwargs)
def patch(url, **kwargs):
warnings.warn("Use ClientSession().patch() instead", DeprecationWarning)
return request(hdrs.METH_PATCH, url, **kwargs)
def delete(url, **kwargs):
warnings.warn("Use ClientSession().delete() instead", DeprecationWarning)
return request(hdrs.METH_DELETE, url, **kwargs)
def ws_connect(url, *, protocols=(), timeout=10.0, connector=None, auth=None,
ws_response_class=ClientWebSocketResponse, autoclose=True,
autoping=True, loop=None, origin=None, headers=None):
warnings.warn("Use ClientSession().ws_connect() instead",
DeprecationWarning)
if loop is None:
loop = asyncio.get_event_loop()
if connector is None:
connector = aiohttp.TCPConnector(loop=loop, force_close=True)
session = aiohttp.ClientSession(loop=loop, connector=connector, auth=auth,
ws_response_class=ws_response_class,
headers=headers)
return _DetachedWSRequestContextManager(
session._ws_connect(url,
protocols=protocols,
timeout=timeout,
autoclose=autoclose,
autoping=autoping,
origin=origin),
session=session)

View File

@ -0,0 +1,801 @@
import asyncio
import collections
import http.cookies
import io
import json
import mimetypes
import os
import sys
import traceback
import urllib.parse
import warnings
from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy
import aiohttp
from . import hdrs, helpers, streams
from .helpers import Timeout
from .log import client_logger
from .multipart import MultipartWriter
from .protocol import HttpMessage
from .streams import EOF_MARKER, FlowControlStreamReader
try:
import cchardet as chardet
except ImportError:
import chardet
__all__ = ('ClientRequest', 'ClientResponse')
PY_35 = sys.version_info >= (3, 5)
HTTP_PORT = 80
HTTPS_PORT = 443
class ClientRequest:
GET_METHODS = {hdrs.METH_GET, hdrs.METH_HEAD, hdrs.METH_OPTIONS}
POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT}
ALL_METHODS = GET_METHODS.union(POST_METHODS).union(
{hdrs.METH_DELETE, hdrs.METH_TRACE})
DEFAULT_HEADERS = {
hdrs.ACCEPT: '*/*',
hdrs.ACCEPT_ENCODING: 'gzip, deflate',
}
SERVER_SOFTWARE = HttpMessage.SERVER_SOFTWARE
body = b''
auth = None
response = None
response_class = None
_writer = None # async task for streaming data
_continue = None # waiter future for '100 Continue' response
# N.B.
# Adding __del__ method with self._writer closing doesn't make sense
# because _writer is instance method, thus it keeps a reference to self.
# Until writer has finished finalizer will not be called.
def __init__(self, method, url, *,
params=None, headers=None, skip_auto_headers=frozenset(),
data=None, cookies=None,
auth=None, encoding='utf-8',
version=aiohttp.HttpVersion11, compress=None,
chunked=None, expect100=False,
loop=None, response_class=None,
proxy=None, proxy_auth=None,
timeout=5*60):
if loop is None:
loop = asyncio.get_event_loop()
self.url = url
self.method = method.upper()
self.encoding = encoding
self.chunked = chunked
self.compress = compress
self.loop = loop
self.response_class = response_class or ClientResponse
self._timeout = timeout
if loop.get_debug():
self._source_traceback = traceback.extract_stack(sys._getframe(1))
self.update_version(version)
self.update_host(url)
self.update_path(params)
self.update_headers(headers)
self.update_auto_headers(skip_auto_headers)
self.update_cookies(cookies)
self.update_content_encoding(data)
self.update_auth(auth)
self.update_proxy(proxy, proxy_auth)
self.update_body_from_data(data, skip_auto_headers)
self.update_transfer_encoding()
self.update_expect_continue(expect100)
def update_host(self, url):
"""Update destination host, port and connection type (ssl)."""
url_parsed = urllib.parse.urlsplit(url)
# check for network location part
netloc = url_parsed.netloc
if not netloc:
raise ValueError('Host could not be detected.')
# get host/port
host = url_parsed.hostname
if not host:
raise ValueError('Host could not be detected.')
try:
port = url_parsed.port
except ValueError:
raise ValueError(
'Port number could not be converted.') from None
# check domain idna encoding
try:
host = host.encode('idna').decode('utf-8')
netloc = self.make_netloc(host, url_parsed.port)
except UnicodeError:
raise ValueError('URL has an invalid label.')
# basic auth info
username, password = url_parsed.username, url_parsed.password
if username:
self.auth = helpers.BasicAuth(username, password or '')
# Record entire netloc for usage in host header
self.netloc = netloc
scheme = url_parsed.scheme
self.ssl = scheme in ('https', 'wss')
# set port number if it isn't already set
if not port:
if self.ssl:
port = HTTPS_PORT
else:
port = HTTP_PORT
self.host, self.port, self.scheme = host, port, scheme
def make_netloc(self, host, port):
ret = host
if port:
ret = ret + ':' + str(port)
return ret
def update_version(self, version):
"""Convert request version to two elements tuple.
parser HTTP version '1.1' => (1, 1)
"""
if isinstance(version, str):
v = [l.strip() for l in version.split('.', 1)]
try:
version = int(v[0]), int(v[1])
except ValueError:
raise ValueError(
'Can not parse http version number: {}'
.format(version)) from None
self.version = version
def update_path(self, params):
"""Build path."""
# extract path
scheme, netloc, path, query, fragment = urllib.parse.urlsplit(self.url)
if not path:
path = '/'
if isinstance(params, collections.Mapping):
params = list(params.items())
if params:
if not isinstance(params, str):
params = urllib.parse.urlencode(params)
if query:
query = '%s&%s' % (query, params)
else:
query = params
self.path = urllib.parse.urlunsplit(('', '', helpers.requote_uri(path),
query, ''))
self.url = urllib.parse.urlunsplit(
(scheme, netloc, self.path, '', fragment))
def update_headers(self, headers):
"""Update request headers."""
self.headers = CIMultiDict()
if headers:
if isinstance(headers, dict):
headers = headers.items()
elif isinstance(headers, (MultiDictProxy, MultiDict)):
headers = headers.items()
for key, value in headers:
self.headers.add(key, value)
def update_auto_headers(self, skip_auto_headers):
self.skip_auto_headers = skip_auto_headers
used_headers = set(self.headers) | skip_auto_headers
for hdr, val in self.DEFAULT_HEADERS.items():
if hdr not in used_headers:
self.headers.add(hdr, val)
# add host
if hdrs.HOST not in used_headers:
self.headers[hdrs.HOST] = self.netloc
if hdrs.USER_AGENT not in used_headers:
self.headers[hdrs.USER_AGENT] = self.SERVER_SOFTWARE
def update_cookies(self, cookies):
"""Update request cookies header."""
if not cookies:
return
c = http.cookies.SimpleCookie()
if hdrs.COOKIE in self.headers:
c.load(self.headers.get(hdrs.COOKIE, ''))
del self.headers[hdrs.COOKIE]
if isinstance(cookies, dict):
cookies = cookies.items()
for name, value in cookies:
if isinstance(value, http.cookies.Morsel):
c[value.key] = value.value
else:
c[name] = value
self.headers[hdrs.COOKIE] = c.output(header='', sep=';').strip()
def update_content_encoding(self, data):
"""Set request content encoding."""
if not data:
return
enc = self.headers.get(hdrs.CONTENT_ENCODING, '').lower()
if enc:
if self.compress is not False:
self.compress = enc
# enable chunked, no need to deal with length
self.chunked = True
elif self.compress:
if not isinstance(self.compress, str):
self.compress = 'deflate'
self.headers[hdrs.CONTENT_ENCODING] = self.compress
self.chunked = True # enable chunked, no need to deal with length
def update_auth(self, auth):
"""Set basic auth."""
if auth is None:
auth = self.auth
if auth is None:
return
if not isinstance(auth, helpers.BasicAuth):
raise TypeError('BasicAuth() tuple is required instead')
self.headers[hdrs.AUTHORIZATION] = auth.encode()
def update_body_from_data(self, data, skip_auto_headers):
if not data:
return
if isinstance(data, str):
data = data.encode(self.encoding)
if isinstance(data, (bytes, bytearray)):
self.body = data
if (hdrs.CONTENT_TYPE not in self.headers and
hdrs.CONTENT_TYPE not in skip_auto_headers):
self.headers[hdrs.CONTENT_TYPE] = 'application/octet-stream'
if hdrs.CONTENT_LENGTH not in self.headers and not self.chunked:
self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))
elif isinstance(data, (asyncio.StreamReader, streams.StreamReader,
streams.DataQueue)):
self.body = data
elif asyncio.iscoroutine(data):
self.body = data
if (hdrs.CONTENT_LENGTH not in self.headers and
self.chunked is None):
self.chunked = True
elif isinstance(data, io.IOBase):
assert not isinstance(data, io.StringIO), \
'attempt to send text data instead of binary'
self.body = data
if not self.chunked and isinstance(data, io.BytesIO):
# Not chunking if content-length can be determined
size = len(data.getbuffer())
self.headers[hdrs.CONTENT_LENGTH] = str(size)
self.chunked = False
elif not self.chunked and isinstance(data, io.BufferedReader):
# Not chunking if content-length can be determined
try:
size = os.fstat(data.fileno()).st_size - data.tell()
self.headers[hdrs.CONTENT_LENGTH] = str(size)
self.chunked = False
except OSError:
# data.fileno() is not supported, e.g.
# io.BufferedReader(io.BytesIO(b'data'))
self.chunked = True
else:
self.chunked = True
if hasattr(data, 'mode'):
if data.mode == 'r':
raise ValueError('file {!r} should be open in binary mode'
''.format(data))
if (hdrs.CONTENT_TYPE not in self.headers and
hdrs.CONTENT_TYPE not in skip_auto_headers and
hasattr(data, 'name')):
mime = mimetypes.guess_type(data.name)[0]
mime = 'application/octet-stream' if mime is None else mime
self.headers[hdrs.CONTENT_TYPE] = mime
elif isinstance(data, MultipartWriter):
self.body = data.serialize()
self.headers.update(data.headers)
self.chunked = self.chunked or 8192
else:
if not isinstance(data, helpers.FormData):
data = helpers.FormData(data)
self.body = data(self.encoding)
if (hdrs.CONTENT_TYPE not in self.headers and
hdrs.CONTENT_TYPE not in skip_auto_headers):
self.headers[hdrs.CONTENT_TYPE] = data.content_type
if data.is_multipart:
self.chunked = self.chunked or 8192
else:
if (hdrs.CONTENT_LENGTH not in self.headers and
not self.chunked):
self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))
def update_transfer_encoding(self):
"""Analyze transfer-encoding header."""
te = self.headers.get(hdrs.TRANSFER_ENCODING, '').lower()
if self.chunked:
if hdrs.CONTENT_LENGTH in self.headers:
del self.headers[hdrs.CONTENT_LENGTH]
if 'chunked' not in te:
self.headers[hdrs.TRANSFER_ENCODING] = 'chunked'
self.chunked = self.chunked if type(self.chunked) is int else 8192
else:
if 'chunked' in te:
self.chunked = 8192
else:
self.chunked = None
if hdrs.CONTENT_LENGTH not in self.headers:
self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))
def update_expect_continue(self, expect=False):
if expect:
self.headers[hdrs.EXPECT] = '100-continue'
elif self.headers.get(hdrs.EXPECT, '').lower() == '100-continue':
expect = True
if expect:
self._continue = helpers.create_future(self.loop)
def update_proxy(self, proxy, proxy_auth):
if proxy and not proxy.startswith('http://'):
raise ValueError("Only http proxies are supported")
if proxy_auth and not isinstance(proxy_auth, helpers.BasicAuth):
raise ValueError("proxy_auth must be None or BasicAuth() tuple")
self.proxy = proxy
self.proxy_auth = proxy_auth
@asyncio.coroutine
def write_bytes(self, request, reader):
"""Support coroutines that yields bytes objects."""
# 100 response
if self._continue is not None:
yield from self._continue
try:
if asyncio.iscoroutine(self.body):
request.transport.set_tcp_nodelay(True)
exc = None
value = None
stream = self.body
while True:
try:
if exc is not None:
result = stream.throw(exc)
else:
result = stream.send(value)
except StopIteration as exc:
if isinstance(exc.value, bytes):
yield from request.write(exc.value, drain=True)
break
except:
self.response.close()
raise
if isinstance(result, asyncio.Future):
exc = None
value = None
try:
value = yield result
except Exception as err:
exc = err
elif isinstance(result, (bytes, bytearray)):
yield from request.write(result, drain=True)
value = None
else:
raise ValueError(
'Bytes object is expected, got: %s.' %
type(result))
elif isinstance(self.body, (asyncio.StreamReader,
streams.StreamReader)):
request.transport.set_tcp_nodelay(True)
chunk = yield from self.body.read(streams.DEFAULT_LIMIT)
while chunk:
yield from request.write(chunk, drain=True)
chunk = yield from self.body.read(streams.DEFAULT_LIMIT)
elif isinstance(self.body, streams.DataQueue):
request.transport.set_tcp_nodelay(True)
while True:
try:
chunk = yield from self.body.read()
if chunk is EOF_MARKER:
break
yield from request.write(chunk, drain=True)
except streams.EofStream:
break
elif isinstance(self.body, io.IOBase):
chunk = self.body.read(self.chunked)
while chunk:
request.write(chunk)
chunk = self.body.read(self.chunked)
request.transport.set_tcp_nodelay(True)
else:
if isinstance(self.body, (bytes, bytearray)):
self.body = (self.body,)
for chunk in self.body:
request.write(chunk)
request.transport.set_tcp_nodelay(True)
except Exception as exc:
new_exc = aiohttp.ClientRequestError(
'Can not write request body for %s' % self.url)
new_exc.__context__ = exc
new_exc.__cause__ = exc
reader.set_exception(new_exc)
else:
assert request.transport.tcp_nodelay
try:
ret = request.write_eof()
# NB: in asyncio 3.4.1+ StreamWriter.drain() is coroutine
# see bug #170
if (asyncio.iscoroutine(ret) or
isinstance(ret, asyncio.Future)):
yield from ret
except Exception as exc:
new_exc = aiohttp.ClientRequestError(
'Can not write request body for %s' % self.url)
new_exc.__context__ = exc
new_exc.__cause__ = exc
reader.set_exception(new_exc)
self._writer = None
def send(self, writer, reader):
writer.set_tcp_cork(True)
request = aiohttp.Request(writer, self.method, self.path, self.version)
if self.compress:
request.add_compression_filter(self.compress)
if self.chunked is not None:
request.enable_chunked_encoding()
request.add_chunking_filter(self.chunked)
# set default content-type
if (self.method in self.POST_METHODS and
hdrs.CONTENT_TYPE not in self.skip_auto_headers and
hdrs.CONTENT_TYPE not in self.headers):
self.headers[hdrs.CONTENT_TYPE] = 'application/octet-stream'
for k, value in self.headers.items():
request.add_header(k, value)
request.send_headers()
self._writer = helpers.ensure_future(
self.write_bytes(request, reader), loop=self.loop)
self.response = self.response_class(
self.method, self.url, self.host,
writer=self._writer, continue100=self._continue,
timeout=self._timeout)
self.response._post_init(self.loop)
return self.response
@asyncio.coroutine
def close(self):
if self._writer is not None:
try:
yield from self._writer
finally:
self._writer = None
def terminate(self):
if self._writer is not None:
if not self.loop.is_closed():
self._writer.cancel()
self._writer = None
class ClientResponse:
# from the Status-Line of the response
version = None # HTTP-Version
status = None # Status-Code
reason = None # Reason-Phrase
cookies = None # Response cookies (Set-Cookie)
content = None # Payload stream
headers = None # Response headers, CIMultiDictProxy
raw_headers = None # Response raw headers, a sequence of pairs
_connection = None # current connection
flow_control_class = FlowControlStreamReader # reader flow control
_reader = None # input stream
_response_parser = aiohttp.HttpResponseParser()
_source_traceback = None
# setted up by ClientRequest after ClientResponse object creation
# post-init stage allows to not change ctor signature
_loop = None
_closed = True # to allow __del__ for non-initialized properly response
def __init__(self, method, url, host='', *, writer=None, continue100=None,
timeout=5*60):
super().__init__()
self.method = method
self.url = url
self.host = host
self._content = None
self._writer = writer
self._continue = continue100
self._closed = False
self._should_close = True # override by message.should_close later
self._history = ()
self._timeout = timeout
def _post_init(self, loop):
self._loop = loop
if loop.get_debug():
self._source_traceback = traceback.extract_stack(sys._getframe(1))
def __del__(self, _warnings=warnings):
if self._loop is None:
return # not started
if self._closed:
return
self.close()
_warnings.warn("Unclosed response {!r}".format(self),
ResourceWarning)
context = {'client_response': self,
'message': 'Unclosed response'}
if self._source_traceback:
context['source_traceback'] = self._source_traceback
self._loop.call_exception_handler(context)
def __repr__(self):
out = io.StringIO()
ascii_encodable_url = self.url.encode('ascii', 'backslashreplace') \
.decode('ascii')
if self.reason:
ascii_encodable_reason = self.reason.encode('ascii',
'backslashreplace') \
.decode('ascii')
else:
ascii_encodable_reason = self.reason
print('<ClientResponse({}) [{} {}]>'.format(
ascii_encodable_url, self.status, ascii_encodable_reason),
file=out)
print(self.headers, file=out)
return out.getvalue()
@property
def connection(self):
return self._connection
@property
def history(self):
"""A sequence of of responses, if redirects occured."""
return self._history
def waiting_for_continue(self):
return self._continue is not None
def _setup_connection(self, connection):
self._reader = connection.reader
self._connection = connection
self.content = self.flow_control_class(
connection.reader, loop=connection.loop, timeout=self._timeout)
def _need_parse_response_body(self):
return (self.method.lower() != 'head' and
self.status not in [204, 304])
@asyncio.coroutine
def start(self, connection, read_until_eof=False):
"""Start response processing."""
self._setup_connection(connection)
while True:
httpstream = self._reader.set_parser(self._response_parser)
# read response
with Timeout(self._timeout, loop=self._loop):
message = yield from httpstream.read()
if message.code != 100:
break
if self._continue is not None and not self._continue.done():
self._continue.set_result(True)
self._continue = None
# response status
self.version = message.version
self.status = message.code
self.reason = message.reason
self._should_close = message.should_close
# headers
self.headers = CIMultiDictProxy(message.headers)
self.raw_headers = tuple(message.raw_headers)
# payload
rwb = self._need_parse_response_body()
self._reader.set_parser(
aiohttp.HttpPayloadParser(message,
readall=read_until_eof,
response_with_body=rwb),
self.content)
# cookies
self.cookies = http.cookies.SimpleCookie()
if hdrs.SET_COOKIE in self.headers:
for hdr in self.headers.getall(hdrs.SET_COOKIE):
try:
self.cookies.load(hdr)
except http.cookies.CookieError as exc:
client_logger.warning(
'Can not load response cookies: %s', exc)
return self
def close(self):
if self._closed:
return
self._closed = True
if self._loop is None or self._loop.is_closed():
return
if self._connection is not None:
self._connection.close()
self._connection = None
self._cleanup_writer()
@asyncio.coroutine
def release(self):
if self._closed:
return
try:
content = self.content
if content is not None and not content.at_eof():
chunk = yield from content.readany()
while chunk is not EOF_MARKER or chunk:
chunk = yield from content.readany()
except Exception:
self._connection.close()
self._connection = None
raise
finally:
self._closed = True
if self._connection is not None:
self._connection.release()
if self._reader is not None:
self._reader.unset_parser()
self._connection = None
self._cleanup_writer()
def raise_for_status(self):
if 400 <= self.status:
raise aiohttp.HttpProcessingError(
code=self.status,
message=self.reason)
def _cleanup_writer(self):
if self._writer is not None and not self._writer.done():
self._writer.cancel()
self._writer = None
@asyncio.coroutine
def wait_for_close(self):
if self._writer is not None:
try:
yield from self._writer
finally:
self._writer = None
yield from self.release()
@asyncio.coroutine
def read(self):
"""Read response payload."""
if self._content is None:
try:
self._content = yield from self.content.read()
except:
self.close()
raise
else:
yield from self.release()
return self._content
def _get_encoding(self):
ctype = self.headers.get(hdrs.CONTENT_TYPE, '').lower()
mtype, stype, _, params = helpers.parse_mimetype(ctype)
encoding = params.get('charset')
if not encoding:
encoding = chardet.detect(self._content)['encoding']
if not encoding:
encoding = 'utf-8'
return encoding
@asyncio.coroutine
def text(self, encoding=None):
"""Read response payload and decode."""
if self._content is None:
yield from self.read()
if encoding is None:
encoding = self._get_encoding()
return self._content.decode(encoding)
@asyncio.coroutine
def json(self, *, encoding=None, loads=json.loads):
"""Read and decodes JSON response."""
if self._content is None:
yield from self.read()
ctype = self.headers.get(hdrs.CONTENT_TYPE, '').lower()
if 'json' not in ctype:
client_logger.warning(
'Attempt to decode JSON with unexpected mimetype: %s', ctype)
stripped = self._content.strip()
if not stripped:
return None
if encoding is None:
encoding = self._get_encoding()
return loads(stripped.decode(encoding))
if PY_35:
@asyncio.coroutine
def __aenter__(self):
return self
@asyncio.coroutine
def __aexit__(self, exc_type, exc_val, exc_tb):
if exc_type is None:
yield from self.release()
else:
self.close()

View File

@ -0,0 +1,193 @@
"""WebSocket client for asyncio."""
import asyncio
import json
import sys
from ._ws_impl import CLOSED_MESSAGE, WebSocketError, WSMessage, WSMsgType
PY_35 = sys.version_info >= (3, 5)
PY_352 = sys.version_info >= (3, 5, 2)
class ClientWebSocketResponse:
def __init__(self, reader, writer, protocol,
response, timeout, autoclose, autoping, loop):
self._response = response
self._conn = response.connection
self._writer = writer
self._reader = reader
self._protocol = protocol
self._closed = False
self._closing = False
self._close_code = None
self._timeout = timeout
self._autoclose = autoclose
self._autoping = autoping
self._loop = loop
self._waiting = False
self._exception = None
@property
def closed(self):
return self._closed
@property
def close_code(self):
return self._close_code
@property
def protocol(self):
return self._protocol
def exception(self):
return self._exception
def ping(self, message='b'):
if self._closed:
raise RuntimeError('websocket connection is closed')
self._writer.ping(message)
def pong(self, message='b'):
if self._closed:
raise RuntimeError('websocket connection is closed')
self._writer.pong(message)
def send_str(self, data):
if self._closed:
raise RuntimeError('websocket connection is closed')
if not isinstance(data, str):
raise TypeError('data argument must be str (%r)' % type(data))
self._writer.send(data, binary=False)
def send_bytes(self, data):
if self._closed:
raise RuntimeError('websocket connection is closed')
if not isinstance(data, (bytes, bytearray, memoryview)):
raise TypeError('data argument must be byte-ish (%r)' %
type(data))
self._writer.send(data, binary=True)
def send_json(self, data, *, dumps=json.dumps):
self.send_str(dumps(data))
@asyncio.coroutine
def close(self, *, code=1000, message=b''):
if not self._closed:
self._closed = True
try:
self._writer.close(code, message)
except asyncio.CancelledError:
self._close_code = 1006
self._response.close()
raise
except Exception as exc:
self._close_code = 1006
self._exception = exc
self._response.close()
return True
if self._closing:
self._response.close()
return True
while True:
try:
msg = yield from asyncio.wait_for(
self._reader.read(), self._timeout, loop=self._loop)
except asyncio.CancelledError:
self._close_code = 1006
self._response.close()
raise
except Exception as exc:
self._close_code = 1006
self._exception = exc
self._response.close()
return True
if msg.type == WSMsgType.CLOSE:
self._close_code = msg.data
self._response.close()
return True
else:
return False
@asyncio.coroutine
def receive(self):
if self._waiting:
raise RuntimeError('Concurrent call to receive() is not allowed')
self._waiting = True
try:
while True:
if self._closed:
return CLOSED_MESSAGE
try:
msg = yield from self._reader.read()
except (asyncio.CancelledError, asyncio.TimeoutError):
raise
except WebSocketError as exc:
self._close_code = exc.code
yield from self.close(code=exc.code)
return WSMessage(WSMsgType.ERROR, exc, None)
except Exception as exc:
self._exception = exc
self._closing = True
self._close_code = 1006
yield from self.close()
return WSMessage(WSMsgType.ERROR, exc, None)
if msg.type == WSMsgType.CLOSE:
self._closing = True
self._close_code = msg.data
if not self._closed and self._autoclose:
yield from self.close()
return msg
if msg.type == WSMsgType.PING and self._autoping:
self.pong(msg.data)
elif msg.type == WSMsgType.PONG and self._autoping:
continue
else:
return msg
finally:
self._waiting = False
@asyncio.coroutine
def receive_str(self):
msg = yield from self.receive()
if msg.type != WSMsgType.TEXT:
raise TypeError(
"Received message {}:{!r} is not str".format(msg.type,
msg.data))
return msg.data
@asyncio.coroutine
def receive_bytes(self):
msg = yield from self.receive()
if msg.type != WSMsgType.BINARY:
raise TypeError(
"Received message {}:{!r} is not bytes".format(msg.type,
msg.data))
return msg.data
@asyncio.coroutine
def receive_json(self, *, loads=json.loads):
data = yield from self.receive_str()
return loads(data)
if PY_35:
def __aiter__(self):
return self
if not PY_352: # pragma: no cover
__aiter__ = asyncio.coroutine(__aiter__)
@asyncio.coroutine
def __anext__(self):
msg = yield from self.receive()
if msg.type == WSMsgType.CLOSE:
raise StopAsyncIteration # NOQA
return msg

View File

@ -0,0 +1,783 @@
import asyncio
import functools
import http.cookies
import ssl
import sys
import traceback
import warnings
from collections import defaultdict
from hashlib import md5, sha1, sha256
from itertools import chain
from math import ceil
from types import MappingProxyType
import aiohttp
from . import hdrs, helpers
from .client import ClientRequest
from .errors import (ClientOSError, ClientTimeoutError, FingerprintMismatch,
HttpProxyError, ProxyConnectionError,
ServerDisconnectedError)
from .helpers import is_ip_address, sentinel
from .resolver import DefaultResolver
__all__ = ('BaseConnector', 'TCPConnector', 'ProxyConnector', 'UnixConnector')
PY_343 = sys.version_info >= (3, 4, 3)
HASHFUNC_BY_DIGESTLEN = {
16: md5,
20: sha1,
32: sha256,
}
class Connection:
_source_traceback = None
_transport = None
def __init__(self, connector, key, request, transport, protocol, loop):
self._key = key
self._connector = connector
self._request = request
self._transport = transport
self._protocol = protocol
self._loop = loop
self.reader = protocol.reader
self.writer = protocol.writer
if loop.get_debug():
self._source_traceback = traceback.extract_stack(sys._getframe(1))
def __repr__(self):
return 'Connection<{}>'.format(self._key)
def __del__(self, _warnings=warnings):
if self._transport is not None:
_warnings.warn('Unclosed connection {!r}'.format(self),
ResourceWarning)
if self._loop.is_closed():
return
self._connector._release(
self._key, self._request, self._transport, self._protocol,
should_close=True)
context = {'client_connection': self,
'message': 'Unclosed connection'}
if self._source_traceback is not None:
context['source_traceback'] = self._source_traceback
self._loop.call_exception_handler(context)
@property
def loop(self):
return self._loop
def close(self):
if self._transport is not None:
self._connector._release(
self._key, self._request, self._transport, self._protocol,
should_close=True)
self._transport = None
def release(self):
if self._transport is not None:
self._connector._release(
self._key, self._request, self._transport, self._protocol,
should_close=False)
self._transport = None
def detach(self):
self._transport = None
@property
def closed(self):
return self._transport is None
class BaseConnector(object):
"""Base connector class.
conn_timeout - (optional) Connect timeout.
keepalive_timeout - (optional) Keep-alive timeout.
force_close - Set to True to force close and do reconnect
after each request (and between redirects).
limit - The limit of simultaneous connections to the same endpoint.
loop - Optional event loop.
"""
_closed = True # prevent AttributeError in __del__ if ctor was failed
_source_traceback = None
def __init__(self, *, conn_timeout=None, keepalive_timeout=sentinel,
force_close=False, limit=20,
loop=None):
if force_close:
if keepalive_timeout is not None and \
keepalive_timeout is not sentinel:
raise ValueError('keepalive_timeout cannot '
'be set if force_close is True')
else:
if keepalive_timeout is sentinel:
keepalive_timeout = 30
if loop is None:
loop = asyncio.get_event_loop()
self._closed = False
if loop.get_debug():
self._source_traceback = traceback.extract_stack(sys._getframe(1))
self._conns = {}
self._acquired = defaultdict(set)
self._conn_timeout = conn_timeout
self._keepalive_timeout = keepalive_timeout
self._cleanup_handle = None
self._force_close = force_close
self._limit = limit
self._waiters = defaultdict(list)
self._loop = loop
self._factory = functools.partial(
aiohttp.StreamProtocol, loop=loop,
disconnect_error=ServerDisconnectedError)
self.cookies = http.cookies.SimpleCookie()
def __del__(self, _warnings=warnings):
if self._closed:
return
if not self._conns:
return
conns = [repr(c) for c in self._conns.values()]
self.close()
_warnings.warn("Unclosed connector {!r}".format(self),
ResourceWarning)
context = {'connector': self,
'connections': conns,
'message': 'Unclosed connector'}
if self._source_traceback is not None:
context['source_traceback'] = self._source_traceback
self._loop.call_exception_handler(context)
def __enter__(self):
return self
def __exit__(self, *exc):
self.close()
@property
def force_close(self):
"""Ultimately close connection on releasing if True."""
return self._force_close
@property
def limit(self):
"""The limit for simultaneous connections to the same endpoint.
Endpoints are the same if they are have equal
(host, port, is_ssl) triple.
If limit is None the connector has no limit.
The default limit size is 20.
"""
return self._limit
def _cleanup(self):
"""Cleanup unused transports."""
if self._cleanup_handle:
self._cleanup_handle.cancel()
self._cleanup_handle = None
now = self._loop.time()
connections = {}
timeout = self._keepalive_timeout
for key, conns in self._conns.items():
alive = []
for transport, proto, t0 in conns:
if transport is not None:
if proto and not proto.is_connected():
transport = None
else:
delta = t0 + self._keepalive_timeout - now
if delta < 0:
transport.close()
transport = None
elif delta < timeout:
timeout = delta
if transport is not None:
alive.append((transport, proto, t0))
if alive:
connections[key] = alive
if connections:
self._cleanup_handle = self._loop.call_at(
ceil(now + timeout), self._cleanup)
self._conns = connections
def _start_cleanup_task(self):
if self._cleanup_handle is None:
now = self._loop.time()
self._cleanup_handle = self._loop.call_at(
ceil(now + self._keepalive_timeout), self._cleanup)
def close(self):
"""Close all opened transports."""
ret = helpers.create_future(self._loop)
ret.set_result(None)
if self._closed:
return ret
self._closed = True
try:
if self._loop.is_closed():
return ret
for key, data in self._conns.items():
for transport, proto, t0 in data:
transport.close()
for transport in chain(*self._acquired.values()):
transport.close()
if self._cleanup_handle:
self._cleanup_handle.cancel()
finally:
self._conns.clear()
self._acquired.clear()
self._cleanup_handle = None
return ret
@property
def closed(self):
"""Is connector closed.
A readonly property.
"""
return self._closed
@asyncio.coroutine
def connect(self, req):
"""Get from pool or create new connection."""
key = (req.host, req.port, req.ssl)
limit = self._limit
if limit is not None:
fut = helpers.create_future(self._loop)
waiters = self._waiters[key]
# The limit defines the maximum number of concurrent connections
# for a key. Waiters must be counted against the limit, even before
# the underlying connection is created.
available = limit - len(waiters) - len(self._acquired[key])
# Don't wait if there are connections available.
if available > 0:
fut.set_result(None)
# This connection will now count towards the limit.
waiters.append(fut)
try:
if limit is not None:
yield from fut
transport, proto = self._get(key)
if transport is None:
try:
if self._conn_timeout:
transport, proto = yield from asyncio.wait_for(
self._create_connection(req),
self._conn_timeout, loop=self._loop)
else:
transport, proto = \
yield from self._create_connection(req)
except asyncio.TimeoutError as exc:
raise ClientTimeoutError(
'Connection timeout to host {0[0]}:{0[1]} ssl:{0[2]}'
.format(key)) from exc
except OSError as exc:
raise ClientOSError(
exc.errno,
'Cannot connect to host {0[0]}:{0[1]} ssl:{0[2]} [{1}]'
.format(key, exc.strerror)) from exc
except:
self._release_waiter(key)
raise
self._acquired[key].add(transport)
conn = Connection(self, key, req, transport, proto, self._loop)
return conn
def _get(self, key):
try:
conns = self._conns[key]
except KeyError:
return None, None
t1 = self._loop.time()
while conns:
transport, proto, t0 = conns.pop()
if transport is not None and proto.is_connected():
if t1 - t0 > self._keepalive_timeout:
transport.close()
transport = None
else:
if not conns:
# The very last connection was reclaimed: drop the key
del self._conns[key]
return transport, proto
# No more connections: drop the key
del self._conns[key]
return None, None
def _release_waiter(self, key):
waiters = self._waiters[key]
while waiters:
waiter = waiters.pop(0)
if not waiter.done():
waiter.set_result(None)
break
def _release(self, key, req, transport, protocol, *, should_close=False):
if self._closed:
# acquired connection is already released on connector closing
return
acquired = self._acquired[key]
try:
acquired.remove(transport)
except KeyError: # pragma: no cover
# this may be result of undetermenistic order of objects
# finalization due garbage collection.
pass
else:
if self._limit is not None and len(acquired) < self._limit:
self._release_waiter(key)
resp = req.response
if not should_close:
if self._force_close:
should_close = True
elif resp is not None:
should_close = resp._should_close
reader = protocol.reader
if should_close or (reader.output and not reader.output.at_eof()):
transport.close()
else:
conns = self._conns.get(key)
if conns is None:
conns = self._conns[key] = []
conns.append((transport, protocol, self._loop.time()))
reader.unset_parser()
self._start_cleanup_task()
@asyncio.coroutine
def _create_connection(self, req):
raise NotImplementedError()
_SSL_OP_NO_COMPRESSION = getattr(ssl, "OP_NO_COMPRESSION", 0)
class TCPConnector(BaseConnector):
"""TCP connector.
verify_ssl - Set to True to check ssl certifications.
fingerprint - Pass the binary md5, sha1, or sha256
digest of the expected certificate in DER format to verify
that the certificate the server presents matches. See also
https://en.wikipedia.org/wiki/Transport_Layer_Security#Certificate_pinning
resolve - (Deprecated) Set to True to do DNS lookup for
host name.
resolver - Enable DNS lookups and use this
resolver
use_dns_cache - Use memory cache for DNS lookups.
family - socket address family
local_addr - local tuple of (host, port) to bind socket to
conn_timeout - (optional) Connect timeout.
keepalive_timeout - (optional) Keep-alive timeout.
force_close - Set to True to force close and do reconnect
after each request (and between redirects).
limit - The limit of simultaneous connections to the same endpoint.
loop - Optional event loop.
"""
def __init__(self, *, verify_ssl=True, fingerprint=None,
resolve=sentinel, use_dns_cache=sentinel,
family=0, ssl_context=None, local_addr=None, resolver=None,
conn_timeout=None, keepalive_timeout=sentinel,
force_close=False, limit=20,
loop=None):
super().__init__(conn_timeout=conn_timeout,
keepalive_timeout=keepalive_timeout,
force_close=force_close, limit=limit, loop=loop)
if not verify_ssl and ssl_context is not None:
raise ValueError(
"Either disable ssl certificate validation by "
"verify_ssl=False or specify ssl_context, not both.")
self._verify_ssl = verify_ssl
if fingerprint:
digestlen = len(fingerprint)
hashfunc = HASHFUNC_BY_DIGESTLEN.get(digestlen)
if not hashfunc:
raise ValueError('fingerprint has invalid length')
self._hashfunc = hashfunc
self._fingerprint = fingerprint
if resolve is not sentinel:
warnings.warn(("resolve parameter is deprecated, "
"use use_dns_cache instead"),
DeprecationWarning, stacklevel=2)
if use_dns_cache is not sentinel and resolve is not sentinel:
if use_dns_cache != resolve:
raise ValueError("use_dns_cache must agree with resolve")
_use_dns_cache = use_dns_cache
elif use_dns_cache is not sentinel:
_use_dns_cache = use_dns_cache
elif resolve is not sentinel:
_use_dns_cache = resolve
else:
_use_dns_cache = True
if resolver is None:
resolver = DefaultResolver(loop=self._loop)
self._resolver = resolver
self._use_dns_cache = _use_dns_cache
self._cached_hosts = {}
self._ssl_context = ssl_context
self._family = family
self._local_addr = local_addr
@property
def verify_ssl(self):
"""Do check for ssl certifications?"""
return self._verify_ssl
@property
def fingerprint(self):
"""Expected ssl certificate fingerprint."""
return self._fingerprint
@property
def ssl_context(self):
"""SSLContext instance for https requests.
Lazy property, creates context on demand.
"""
if self._ssl_context is None:
if not self._verify_ssl:
sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
sslcontext.options |= ssl.OP_NO_SSLv2
sslcontext.options |= ssl.OP_NO_SSLv3
sslcontext.options |= _SSL_OP_NO_COMPRESSION
sslcontext.set_default_verify_paths()
else:
sslcontext = ssl.create_default_context()
self._ssl_context = sslcontext
return self._ssl_context
@property
def family(self):
"""Socket family like AF_INET."""
return self._family
@property
def use_dns_cache(self):
"""True if local DNS caching is enabled."""
return self._use_dns_cache
@property
def cached_hosts(self):
"""Read-only dict of cached DNS record."""
return MappingProxyType(self._cached_hosts)
def clear_dns_cache(self, host=None, port=None):
"""Remove specified host/port or clear all dns local cache."""
if host is not None and port is not None:
self._cached_hosts.pop((host, port), None)
elif host is not None or port is not None:
raise ValueError("either both host and port "
"or none of them are allowed")
else:
self._cached_hosts.clear()
@property
def resolve(self):
"""Do DNS lookup for host name?"""
warnings.warn((".resolve property is deprecated, "
"use .dns_cache instead"),
DeprecationWarning, stacklevel=2)
return self.use_dns_cache
@property
def resolved_hosts(self):
"""The dict of (host, port) -> (ipaddr, port) pairs."""
warnings.warn((".resolved_hosts property is deprecated, "
"use .cached_hosts instead"),
DeprecationWarning, stacklevel=2)
return self.cached_hosts
def clear_resolved_hosts(self, host=None, port=None):
"""Remove specified host/port or clear all resolve cache."""
warnings.warn((".clear_resolved_hosts() is deprecated, "
"use .clear_dns_cache() instead"),
DeprecationWarning, stacklevel=2)
if host is not None and port is not None:
self.clear_dns_cache(host, port)
else:
self.clear_dns_cache()
@asyncio.coroutine
def _resolve_host(self, host, port):
if is_ip_address(host):
return [{'hostname': host, 'host': host, 'port': port,
'family': self._family, 'proto': 0, 'flags': 0}]
if self._use_dns_cache:
key = (host, port)
if key not in self._cached_hosts:
self._cached_hosts[key] = yield from \
self._resolver.resolve(host, port, family=self._family)
return self._cached_hosts[key]
else:
res = yield from self._resolver.resolve(
host, port, family=self._family)
return res
@asyncio.coroutine
def _create_connection(self, req):
"""Create connection.
Has same keyword arguments as BaseEventLoop.create_connection.
"""
if req.proxy:
transport, proto = yield from self._create_proxy_connection(req)
else:
transport, proto = yield from self._create_direct_connection(req)
return transport, proto
@asyncio.coroutine
def _create_direct_connection(self, req):
if req.ssl:
sslcontext = self.ssl_context
else:
sslcontext = None
hosts = yield from self._resolve_host(req.host, req.port)
exc = None
for hinfo in hosts:
try:
host = hinfo['host']
port = hinfo['port']
transp, proto = yield from self._loop.create_connection(
self._factory, host, port,
ssl=sslcontext, family=hinfo['family'],
proto=hinfo['proto'], flags=hinfo['flags'],
server_hostname=hinfo['hostname'] if sslcontext else None,
local_addr=self._local_addr)
has_cert = transp.get_extra_info('sslcontext')
if has_cert and self._fingerprint:
sock = transp.get_extra_info('socket')
if not hasattr(sock, 'getpeercert'):
# Workaround for asyncio 3.5.0
# Starting from 3.5.1 version
# there is 'ssl_object' extra info in transport
sock = transp._ssl_protocol._sslpipe.ssl_object
# gives DER-encoded cert as a sequence of bytes (or None)
cert = sock.getpeercert(binary_form=True)
assert cert
got = self._hashfunc(cert).digest()
expected = self._fingerprint
if got != expected:
transp.close()
raise FingerprintMismatch(expected, got, host, port)
return transp, proto
except OSError as e:
exc = e
else:
raise ClientOSError(exc.errno,
'Can not connect to %s:%s [%s]' %
(req.host, req.port, exc.strerror)) from exc
@asyncio.coroutine
def _create_proxy_connection(self, req):
proxy_req = ClientRequest(
hdrs.METH_GET, req.proxy,
headers={hdrs.HOST: req.host},
auth=req.proxy_auth,
loop=self._loop)
try:
# create connection to proxy server
transport, proto = yield from self._create_direct_connection(
proxy_req)
except OSError as exc:
raise ProxyConnectionError(*exc.args) from exc
if not req.ssl:
req.path = '{scheme}://{host}{path}'.format(scheme=req.scheme,
host=req.netloc,
path=req.path)
if hdrs.AUTHORIZATION in proxy_req.headers:
auth = proxy_req.headers[hdrs.AUTHORIZATION]
del proxy_req.headers[hdrs.AUTHORIZATION]
if not req.ssl:
req.headers[hdrs.PROXY_AUTHORIZATION] = auth
else:
proxy_req.headers[hdrs.PROXY_AUTHORIZATION] = auth
if req.ssl:
# For HTTPS requests over HTTP proxy
# we must notify proxy to tunnel connection
# so we send CONNECT command:
# CONNECT www.python.org:443 HTTP/1.1
# Host: www.python.org
#
# next we must do TLS handshake and so on
# to do this we must wrap raw socket into secure one
# asyncio handles this perfectly
proxy_req.method = hdrs.METH_CONNECT
proxy_req.path = '{}:{}'.format(req.host, req.port)
key = (req.host, req.port, req.ssl)
conn = Connection(self, key, proxy_req,
transport, proto, self._loop)
self._acquired[key].add(conn._transport)
proxy_resp = proxy_req.send(conn.writer, conn.reader)
try:
resp = yield from proxy_resp.start(conn, True)
except:
proxy_resp.close()
conn.close()
raise
else:
conn.detach()
if resp.status != 200:
raise HttpProxyError(code=resp.status, message=resp.reason)
rawsock = transport.get_extra_info('socket', default=None)
if rawsock is None:
raise RuntimeError(
"Transport does not expose socket instance")
transport.pause_reading()
transport, proto = yield from self._loop.create_connection(
self._factory, ssl=self.ssl_context, sock=rawsock,
server_hostname=req.host)
finally:
proxy_resp.close()
return transport, proto
class ProxyConnector(TCPConnector):
"""Http Proxy connector.
Deprecated, use ClientSession.request with proxy parameters.
Is still here for backward compatibility.
proxy - Proxy URL address. Only HTTP proxy supported.
proxy_auth - (optional) Proxy HTTP Basic Auth
proxy_auth - aiohttp.helpers.BasicAuth
conn_timeout - (optional) Connect timeout.
keepalive_timeout - (optional) Keep-alive timeout.
force_close - Set to True to force close and do reconnect
after each request (and between redirects).
limit - The limit of simultaneous connections to the same endpoint.
loop - Optional event loop.
Usage:
>>> conn = ProxyConnector(proxy="http://some.proxy.com")
>>> session = ClientSession(connector=conn)
>>> resp = yield from session.get('http://python.org')
"""
def __init__(self, proxy, *, proxy_auth=None, force_close=True,
conn_timeout=None, keepalive_timeout=sentinel,
limit=20, loop=None):
warnings.warn("ProxyConnector is deprecated, use "
"client.get(url, proxy=proxy_url) instead",
DeprecationWarning)
super().__init__(force_close=force_close,
conn_timeout=conn_timeout,
keepalive_timeout=keepalive_timeout,
limit=limit, loop=loop)
self._proxy = proxy
self._proxy_auth = proxy_auth
@property
def proxy(self):
return self._proxy
@property
def proxy_auth(self):
return self._proxy_auth
@asyncio.coroutine
def _create_connection(self, req):
"""
Use TCPConnector _create_connection, to emulate old ProxyConnector.
"""
req.update_proxy(self._proxy, self._proxy_auth)
transport, proto = yield from super()._create_connection(req)
return transport, proto
class UnixConnector(BaseConnector):
"""Unix socket connector.
path - Unix socket path.
conn_timeout - (optional) Connect timeout.
keepalive_timeout - (optional) Keep-alive timeout.
force_close - Set to True to force close and do reconnect
after each request (and between redirects).
limit - The limit of simultaneous connections to the same endpoint.
loop - Optional event loop.
Usage:
>>> conn = UnixConnector(path='/path/to/socket')
>>> session = ClientSession(connector=conn)
>>> resp = yield from session.get('http://python.org')
"""
def __init__(self, path, force_close=False, conn_timeout=None,
keepalive_timeout=sentinel, limit=20, loop=None):
super().__init__(force_close=force_close,
conn_timeout=conn_timeout,
keepalive_timeout=keepalive_timeout,
limit=limit, loop=loop)
self._path = path
@property
def path(self):
"""Path to unix socket."""
return self._path
@asyncio.coroutine
def _create_connection(self, req):
return (yield from self._loop.create_unix_connection(
self._factory, self._path))

View File

@ -0,0 +1,290 @@
import datetime
import re
from collections import defaultdict
from collections.abc import Mapping
from http.cookies import Morsel, SimpleCookie
from math import ceil
from urllib.parse import urlsplit
from .abc import AbstractCookieJar
from .helpers import is_ip_address
class CookieJar(AbstractCookieJar):
"""Implements cookie storage adhering to RFC 6265."""
DATE_TOKENS_RE = re.compile(
"[\x09\x20-\x2F\x3B-\x40\x5B-\x60\x7B-\x7E]*"
"(?P<token>[\x00-\x08\x0A-\x1F\d:a-zA-Z\x7F-\xFF]+)")
DATE_HMS_TIME_RE = re.compile("(\d{1,2}):(\d{1,2}):(\d{1,2})")
DATE_DAY_OF_MONTH_RE = re.compile("(\d{1,2})")
DATE_MONTH_RE = re.compile("(jan)|(feb)|(mar)|(apr)|(may)|(jun)|(jul)|"
"(aug)|(sep)|(oct)|(nov)|(dec)", re.I)
DATE_YEAR_RE = re.compile("(\d{2,4})")
MAX_TIME = 2051215261.0 # so far in future (2035-01-01)
def __init__(self, *, unsafe=False, loop=None):
super().__init__(loop=loop)
self._cookies = defaultdict(SimpleCookie)
self._host_only_cookies = set()
self._unsafe = unsafe
self._next_expiration = ceil(self._loop.time())
self._expirations = {}
def clear(self):
self._cookies.clear()
self._host_only_cookies.clear()
self._next_expiration = ceil(self._loop.time())
self._expirations.clear()
def __iter__(self):
self._do_expiration()
for val in self._cookies.values():
yield from val.values()
def __len__(self):
return sum(1 for i in self)
def _do_expiration(self):
now = self._loop.time()
if self._next_expiration > now:
return
if not self._expirations:
return
next_expiration = self.MAX_TIME
to_del = []
cookies = self._cookies
expirations = self._expirations
for (domain, name), when in expirations.items():
if when < now:
cookies[domain].pop(name, None)
to_del.append((domain, name))
self._host_only_cookies.discard((domain, name))
else:
next_expiration = min(next_expiration, when)
for key in to_del:
del expirations[key]
self._next_expiration = ceil(next_expiration)
def _expire_cookie(self, when, domain, name):
self._next_expiration = min(self._next_expiration, when)
self._expirations[(domain, name)] = when
def update_cookies(self, cookies, response_url=None):
"""Update cookies."""
url_parsed = urlsplit(response_url or "")
hostname = url_parsed.hostname
if not self._unsafe and is_ip_address(hostname):
# Don't accept cookies from IPs
return
if isinstance(cookies, Mapping):
cookies = cookies.items()
for name, cookie in cookies:
if not isinstance(cookie, Morsel):
tmp = SimpleCookie()
tmp[name] = cookie
cookie = tmp[name]
domain = cookie["domain"]
# ignore domains with trailing dots
if domain.endswith('.'):
domain = ""
del cookie["domain"]
if not domain and hostname is not None:
# Set the cookie's domain to the response hostname
# and set its host-only-flag
self._host_only_cookies.add((hostname, name))
domain = cookie["domain"] = hostname
if domain.startswith("."):
# Remove leading dot
domain = domain[1:]
cookie["domain"] = domain
if hostname and not self._is_domain_match(domain, hostname):
# Setting cookies for different domains is not allowed
continue
path = cookie["path"]
if not path or not path.startswith("/"):
# Set the cookie's path to the response path
path = url_parsed.path
if not path.startswith("/"):
path = "/"
else:
# Cut everything from the last slash to the end
path = "/" + path[1:path.rfind("/")]
cookie["path"] = path
max_age = cookie["max-age"]
if max_age:
try:
delta_seconds = int(max_age)
self._expire_cookie(self._loop.time() + delta_seconds,
domain, name)
except ValueError:
cookie["max-age"] = ""
else:
expires = cookie["expires"]
if expires:
expire_time = self._parse_date(expires)
if expire_time:
self._expire_cookie(expire_time.timestamp(),
domain, name)
else:
cookie["expires"] = ""
# use dict method because SimpleCookie class modifies value
# before Python 3.4.3
dict.__setitem__(self._cookies[domain], name, cookie)
self._do_expiration()
def filter_cookies(self, request_url):
"""Returns this jar's cookies filtered by their attributes."""
self._do_expiration()
url_parsed = urlsplit(request_url)
filtered = SimpleCookie()
hostname = url_parsed.hostname or ""
is_not_secure = url_parsed.scheme not in ("https", "wss")
for cookie in self:
name = cookie.key
domain = cookie["domain"]
# Send shared cookies
if not domain:
filtered[name] = cookie.value
continue
if not self._unsafe and is_ip_address(hostname):
continue
if (domain, name) in self._host_only_cookies:
if domain != hostname:
continue
elif not self._is_domain_match(domain, hostname):
continue
if not self._is_path_match(url_parsed.path, cookie["path"]):
continue
if is_not_secure and cookie["secure"]:
continue
filtered[name] = cookie.value
return filtered
@staticmethod
def _is_domain_match(domain, hostname):
"""Implements domain matching adhering to RFC 6265."""
if hostname == domain:
return True
if not hostname.endswith(domain):
return False
non_matching = hostname[:-len(domain)]
if not non_matching.endswith("."):
return False
return not is_ip_address(hostname)
@staticmethod
def _is_path_match(req_path, cookie_path):
"""Implements path matching adhering to RFC 6265."""
if not req_path.startswith("/"):
req_path = "/"
if req_path == cookie_path:
return True
if not req_path.startswith(cookie_path):
return False
if cookie_path.endswith("/"):
return True
non_matching = req_path[len(cookie_path):]
return non_matching.startswith("/")
@classmethod
def _parse_date(cls, date_str):
"""Implements date string parsing adhering to RFC 6265."""
if not date_str:
return
found_time = False
found_day = False
found_month = False
found_year = False
hour = minute = second = 0
day = 0
month = 0
year = 0
for token_match in cls.DATE_TOKENS_RE.finditer(date_str):
token = token_match.group("token")
if not found_time:
time_match = cls.DATE_HMS_TIME_RE.match(token)
if time_match:
found_time = True
hour, minute, second = [
int(s) for s in time_match.groups()]
continue
if not found_day:
day_match = cls.DATE_DAY_OF_MONTH_RE.match(token)
if day_match:
found_day = True
day = int(day_match.group())
continue
if not found_month:
month_match = cls.DATE_MONTH_RE.match(token)
if month_match:
found_month = True
month = month_match.lastindex
continue
if not found_year:
year_match = cls.DATE_YEAR_RE.match(token)
if year_match:
found_year = True
year = int(year_match.group())
if 70 <= year <= 99:
year += 1900
elif 0 <= year <= 69:
year += 2000
if False in (found_day, found_month, found_year, found_time):
return
if not 1 <= day <= 31:
return
if year < 1601 or hour > 23 or minute > 59 or second > 59:
return
return datetime.datetime(year, month, day,
hour, minute, second,
tzinfo=datetime.timezone.utc)

View File

@ -0,0 +1,186 @@
"""HTTP related errors."""
from asyncio import TimeoutError
__all__ = (
'DisconnectedError', 'ClientDisconnectedError', 'ServerDisconnectedError',
'HttpProcessingError', 'BadHttpMessage',
'HttpMethodNotAllowed', 'HttpBadRequest', 'HttpProxyError',
'BadStatusLine', 'LineTooLong', 'InvalidHeader',
'ClientError', 'ClientHttpProcessingError', 'ClientConnectionError',
'ClientOSError', 'ClientTimeoutError', 'ProxyConnectionError',
'ClientRequestError', 'ClientResponseError',
'FingerprintMismatch',
'WSServerHandshakeError', 'WSClientDisconnectedError')
class DisconnectedError(Exception):
"""Disconnected."""
class ClientDisconnectedError(DisconnectedError):
"""Client disconnected."""
class ServerDisconnectedError(DisconnectedError):
"""Server disconnected."""
class WSClientDisconnectedError(ClientDisconnectedError):
"""Deprecated."""
class ClientError(Exception):
"""Base class for client connection errors."""
class ClientHttpProcessingError(ClientError):
"""Base class for client HTTP processing errors."""
class ClientRequestError(ClientHttpProcessingError):
"""Connection error during sending request."""
class ClientResponseError(ClientHttpProcessingError):
"""Connection error during reading response."""
class ClientConnectionError(ClientError):
"""Base class for client socket errors."""
class ClientOSError(ClientConnectionError, OSError):
"""OSError error."""
class ClientTimeoutError(ClientConnectionError, TimeoutError):
"""Client connection timeout error."""
class ProxyConnectionError(ClientConnectionError):
"""Proxy connection error.
Raised in :class:`aiohttp.connector.ProxyConnector` if
connection to proxy can not be established.
"""
class HttpProcessingError(Exception):
"""HTTP error.
Shortcut for raising HTTP errors with custom code, message and headers.
:param int code: HTTP Error code.
:param str message: (optional) Error message.
:param list of [tuple] headers: (optional) Headers to be sent in response.
"""
code = 0
message = ''
headers = None
def __init__(self, *, code=None, message='', headers=None):
if code is not None:
self.code = code
self.headers = headers
self.message = message
super().__init__("%s, message='%s'" % (self.code, message))
class WSServerHandshakeError(HttpProcessingError):
"""websocket server handshake error."""
class HttpProxyError(HttpProcessingError):
"""HTTP proxy error.
Raised in :class:`aiohttp.connector.ProxyConnector` if
proxy responds with status other than ``200 OK``
on ``CONNECT`` request.
"""
class BadHttpMessage(HttpProcessingError):
code = 400
message = 'Bad Request'
def __init__(self, message, *, headers=None):
super().__init__(message=message, headers=headers)
class HttpMethodNotAllowed(HttpProcessingError):
code = 405
message = 'Method Not Allowed'
class HttpBadRequest(BadHttpMessage):
code = 400
message = 'Bad Request'
class ContentEncodingError(BadHttpMessage):
"""Content encoding error."""
class TransferEncodingError(BadHttpMessage):
"""transfer encoding error."""
class LineTooLong(BadHttpMessage):
def __init__(self, line, limit='Unknown'):
super().__init__(
"got more than %s bytes when reading %s" % (limit, line))
class InvalidHeader(BadHttpMessage):
def __init__(self, hdr):
if isinstance(hdr, bytes):
hdr = hdr.decode('utf-8', 'surrogateescape')
super().__init__('Invalid HTTP Header: {}'.format(hdr))
self.hdr = hdr
class BadStatusLine(BadHttpMessage):
def __init__(self, line=''):
if not line:
line = repr(line)
self.args = line,
self.line = line
class LineLimitExceededParserError(HttpBadRequest):
"""Line is too long."""
def __init__(self, msg, limit):
super().__init__(msg)
self.limit = limit
class FingerprintMismatch(ClientConnectionError):
"""SSL certificate does not match expected fingerprint."""
def __init__(self, expected, got, host, port):
self.expected = expected
self.got = got
self.host = host
self.port = port
def __repr__(self):
return '<{} expected={} got={} host={} port={}>'.format(
self.__class__.__name__, self.expected, self.got,
self.host, self.port)
class InvalidURL(Exception):
"""Invalid URL."""

View File

@ -0,0 +1,168 @@
import asyncio
import mimetypes
import os
from . import hdrs
from .helpers import create_future
from .web_reqrep import StreamResponse
class FileSender:
""""A helper that can be used to send files.
"""
def __init__(self, *, resp_factory=StreamResponse, chunk_size=256*1024):
self._response_factory = resp_factory
self._chunk_size = chunk_size
if bool(os.environ.get("AIOHTTP_NOSENDFILE")):
self._sendfile = self._sendfile_fallback
def _sendfile_cb(self, fut, out_fd, in_fd, offset,
count, loop, registered):
if registered:
loop.remove_writer(out_fd)
if fut.cancelled():
return
try:
n = os.sendfile(out_fd, in_fd, offset, count)
if n == 0: # EOF reached
n = count
except (BlockingIOError, InterruptedError):
n = 0
except Exception as exc:
fut.set_exception(exc)
return
if n < count:
loop.add_writer(out_fd, self._sendfile_cb, fut, out_fd, in_fd,
offset + n, count - n, loop, True)
else:
fut.set_result(None)
@asyncio.coroutine
def _sendfile_system(self, request, resp, fobj, count):
# Write count bytes of fobj to resp using
# the os.sendfile system call.
#
# request should be a aiohttp.web.Request instance.
#
# resp should be a aiohttp.web.StreamResponse instance.
#
# fobj should be an open file object.
#
# count should be an integer > 0.
transport = request.transport
if transport.get_extra_info("sslcontext"):
yield from self._sendfile_fallback(request, resp, fobj, count)
return
def _send_headers(resp_impl):
# Durty hack required for
# https://github.com/KeepSafe/aiohttp/issues/1093
# don't send headers in sendfile mode
pass
resp._send_headers = _send_headers
@asyncio.coroutine
def write_eof():
# Durty hack required for
# https://github.com/KeepSafe/aiohttp/issues/1177
# do nothing in write_eof
pass
resp.write_eof = write_eof
resp_impl = yield from resp.prepare(request)
loop = request.app.loop
# See https://github.com/KeepSafe/aiohttp/issues/958 for details
# send headers
headers = ['HTTP/{0.major}.{0.minor} 200 OK\r\n'.format(
request.version)]
for hdr, val in resp.headers.items():
headers.append('{}: {}\r\n'.format(hdr, val))
headers.append('\r\n')
out_socket = transport.get_extra_info("socket").dup()
out_socket.setblocking(False)
out_fd = out_socket.fileno()
in_fd = fobj.fileno()
bheaders = ''.join(headers).encode('utf-8')
headers_length = len(bheaders)
resp_impl.headers_length = headers_length
resp_impl.output_length = headers_length + count
try:
yield from loop.sock_sendall(out_socket, bheaders)
fut = create_future(loop)
self._sendfile_cb(fut, out_fd, in_fd, 0, count, loop, False)
yield from fut
finally:
out_socket.close()
@asyncio.coroutine
def _sendfile_fallback(self, request, resp, fobj, count):
# Mimic the _sendfile_system() method, but without using the
# os.sendfile() system call. This should be used on systems
# that don't support the os.sendfile().
# To avoid blocking the event loop & to keep memory usage low,
# fobj is transferred in chunks controlled by the
# constructor's chunk_size argument.
yield from resp.prepare(request)
chunk_size = self._chunk_size
chunk = fobj.read(chunk_size)
while True:
resp.write(chunk)
yield from resp.drain()
count = count - chunk_size
if count <= 0:
break
chunk = fobj.read(count)
if hasattr(os, "sendfile"): # pragma: no cover
_sendfile = _sendfile_system
else: # pragma: no cover
_sendfile = _sendfile_fallback
@asyncio.coroutine
def send(self, request, filepath):
"""Send filepath to client using request."""
st = filepath.stat()
modsince = request.if_modified_since
if modsince is not None and st.st_mtime <= modsince.timestamp():
from .web_exceptions import HTTPNotModified
raise HTTPNotModified()
ct, encoding = mimetypes.guess_type(str(filepath))
if not ct:
ct = 'application/octet-stream'
resp = self._response_factory()
resp.content_type = ct
if encoding:
resp.headers[hdrs.CONTENT_ENCODING] = encoding
resp.last_modified = st.st_mtime
file_size = st.st_size
resp.content_length = file_size
resp.set_tcp_cork(True)
try:
with filepath.open('rb') as f:
yield from self._sendfile(request, resp, f, file_size)
finally:
resp.set_tcp_nodelay(True)
return resp

View File

@ -0,0 +1,91 @@
"""HTTP Headers constants."""
from multidict import istr
METH_ANY = '*'
METH_CONNECT = 'CONNECT'
METH_HEAD = 'HEAD'
METH_GET = 'GET'
METH_DELETE = 'DELETE'
METH_OPTIONS = 'OPTIONS'
METH_PATCH = 'PATCH'
METH_POST = 'POST'
METH_PUT = 'PUT'
METH_TRACE = 'TRACE'
METH_ALL = {METH_CONNECT, METH_HEAD, METH_GET, METH_DELETE,
METH_OPTIONS, METH_PATCH, METH_POST, METH_PUT, METH_TRACE}
ACCEPT = istr('ACCEPT')
ACCEPT_CHARSET = istr('ACCEPT-CHARSET')
ACCEPT_ENCODING = istr('ACCEPT-ENCODING')
ACCEPT_LANGUAGE = istr('ACCEPT-LANGUAGE')
ACCEPT_RANGES = istr('ACCEPT-RANGES')
ACCESS_CONTROL_MAX_AGE = istr('ACCESS-CONTROL-MAX-AGE')
ACCESS_CONTROL_ALLOW_CREDENTIALS = istr('ACCESS-CONTROL-ALLOW-CREDENTIALS')
ACCESS_CONTROL_ALLOW_HEADERS = istr('ACCESS-CONTROL-ALLOW-HEADERS')
ACCESS_CONTROL_ALLOW_METHODS = istr('ACCESS-CONTROL-ALLOW-METHODS')
ACCESS_CONTROL_ALLOW_ORIGIN = istr('ACCESS-CONTROL-ALLOW-ORIGIN')
ACCESS_CONTROL_EXPOSE_HEADERS = istr('ACCESS-CONTROL-EXPOSE-HEADERS')
ACCESS_CONTROL_REQUEST_HEADERS = istr('ACCESS-CONTROL-REQUEST-HEADERS')
ACCESS_CONTROL_REQUEST_METHOD = istr('ACCESS-CONTROL-REQUEST-METHOD')
AGE = istr('AGE')
ALLOW = istr('ALLOW')
AUTHORIZATION = istr('AUTHORIZATION')
CACHE_CONTROL = istr('CACHE-CONTROL')
CONNECTION = istr('CONNECTION')
CONTENT_DISPOSITION = istr('CONTENT-DISPOSITION')
CONTENT_ENCODING = istr('CONTENT-ENCODING')
CONTENT_LANGUAGE = istr('CONTENT-LANGUAGE')
CONTENT_LENGTH = istr('CONTENT-LENGTH')
CONTENT_LOCATION = istr('CONTENT-LOCATION')
CONTENT_MD5 = istr('CONTENT-MD5')
CONTENT_RANGE = istr('CONTENT-RANGE')
CONTENT_TRANSFER_ENCODING = istr('CONTENT-TRANSFER-ENCODING')
CONTENT_TYPE = istr('CONTENT-TYPE')
COOKIE = istr('COOKIE')
DATE = istr('DATE')
DESTINATION = istr('DESTINATION')
DIGEST = istr('DIGEST')
ETAG = istr('ETAG')
EXPECT = istr('EXPECT')
EXPIRES = istr('EXPIRES')
FROM = istr('FROM')
HOST = istr('HOST')
IF_MATCH = istr('IF-MATCH')
IF_MODIFIED_SINCE = istr('IF-MODIFIED-SINCE')
IF_NONE_MATCH = istr('IF-NONE-MATCH')
IF_RANGE = istr('IF-RANGE')
IF_UNMODIFIED_SINCE = istr('IF-UNMODIFIED-SINCE')
KEEP_ALIVE = istr('KEEP-ALIVE')
LAST_EVENT_ID = istr('LAST-EVENT-ID')
LAST_MODIFIED = istr('LAST-MODIFIED')
LINK = istr('LINK')
LOCATION = istr('LOCATION')
MAX_FORWARDS = istr('MAX-FORWARDS')
ORIGIN = istr('ORIGIN')
PRAGMA = istr('PRAGMA')
PROXY_AUTHENTICATE = istr('PROXY_AUTHENTICATE')
PROXY_AUTHORIZATION = istr('PROXY-AUTHORIZATION')
RANGE = istr('RANGE')
REFERER = istr('REFERER')
RETRY_AFTER = istr('RETRY-AFTER')
SEC_WEBSOCKET_ACCEPT = istr('SEC-WEBSOCKET-ACCEPT')
SEC_WEBSOCKET_VERSION = istr('SEC-WEBSOCKET-VERSION')
SEC_WEBSOCKET_PROTOCOL = istr('SEC-WEBSOCKET-PROTOCOL')
SEC_WEBSOCKET_KEY = istr('SEC-WEBSOCKET-KEY')
SEC_WEBSOCKET_KEY1 = istr('SEC-WEBSOCKET-KEY1')
SERVER = istr('SERVER')
SET_COOKIE = istr('SET-COOKIE')
TE = istr('TE')
TRAILER = istr('TRAILER')
TRANSFER_ENCODING = istr('TRANSFER-ENCODING')
UPGRADE = istr('UPGRADE')
WEBSOCKET = istr('WEBSOCKET')
URI = istr('URI')
USER_AGENT = istr('USER-AGENT')
VARY = istr('VARY')
VIA = istr('VIA')
WANT_DIGEST = istr('WANT-DIGEST')
WARNING = istr('WARNING')
WWW_AUTHENTICATE = istr('WWW-AUTHENTICATE')

View File

@ -0,0 +1,534 @@
"""Various helper functions"""
import asyncio
import base64
import binascii
import datetime
import functools
import io
import os
import re
import warnings
from collections import namedtuple
from pathlib import Path
from urllib.parse import quote, urlencode
from async_timeout import timeout
from multidict import MultiDict, MultiDictProxy
from . import hdrs
from .errors import InvalidURL
try:
from asyncio import ensure_future
except ImportError:
ensure_future = asyncio.async
__all__ = ('BasicAuth', 'create_future', 'FormData', 'parse_mimetype',
'Timeout', 'ensure_future')
sentinel = object()
Timeout = timeout
class BasicAuth(namedtuple('BasicAuth', ['login', 'password', 'encoding'])):
"""Http basic authentication helper.
:param str login: Login
:param str password: Password
:param str encoding: (optional) encoding ('latin1' by default)
"""
def __new__(cls, login, password='', encoding='latin1'):
if login is None:
raise ValueError('None is not allowed as login value')
if password is None:
raise ValueError('None is not allowed as password value')
return super().__new__(cls, login, password, encoding)
@classmethod
def decode(cls, auth_header, encoding='latin1'):
"""Create a :class:`BasicAuth` object from an ``Authorization`` HTTP
header."""
split = auth_header.strip().split(' ')
if len(split) == 2:
if split[0].strip().lower() != 'basic':
raise ValueError('Unknown authorization method %s' % split[0])
to_decode = split[1]
else:
raise ValueError('Could not parse authorization header.')
try:
username, _, password = base64.b64decode(
to_decode.encode('ascii')
).decode(encoding).partition(':')
except binascii.Error:
raise ValueError('Invalid base64 encoding.')
return cls(username, password, encoding=encoding)
def encode(self):
"""Encode credentials."""
creds = ('%s:%s' % (self.login, self.password)).encode(self.encoding)
return 'Basic %s' % base64.b64encode(creds).decode(self.encoding)
def create_future(loop):
"""Compatibility wrapper for the loop.create_future() call introduced in
3.5.2."""
if hasattr(loop, 'create_future'):
return loop.create_future()
else:
return asyncio.Future(loop=loop)
class FormData:
"""Helper class for multipart/form-data and
application/x-www-form-urlencoded body generation."""
def __init__(self, fields=()):
from . import multipart
self._writer = multipart.MultipartWriter('form-data')
self._fields = []
self._is_multipart = False
if isinstance(fields, dict):
fields = list(fields.items())
elif not isinstance(fields, (list, tuple)):
fields = (fields,)
self.add_fields(*fields)
@property
def is_multipart(self):
return self._is_multipart
@property
def content_type(self):
if self._is_multipart:
return self._writer.headers[hdrs.CONTENT_TYPE]
else:
return 'application/x-www-form-urlencoded'
def add_field(self, name, value, *, content_type=None, filename=None,
content_transfer_encoding=None):
if isinstance(value, io.IOBase):
self._is_multipart = True
elif isinstance(value, (bytes, bytearray, memoryview)):
if filename is None and content_transfer_encoding is None:
filename = name
type_options = MultiDict({'name': name})
if filename is not None and not isinstance(filename, str):
raise TypeError('filename must be an instance of str. '
'Got: %s' % filename)
if filename is None and isinstance(value, io.IOBase):
filename = guess_filename(value, name)
if filename is not None:
type_options['filename'] = filename
self._is_multipart = True
headers = {}
if content_type is not None:
if not isinstance(content_type, str):
raise TypeError('content_type must be an instance of str. '
'Got: %s' % content_type)
headers[hdrs.CONTENT_TYPE] = content_type
self._is_multipart = True
if content_transfer_encoding is not None:
if not isinstance(content_transfer_encoding, str):
raise TypeError('content_transfer_encoding must be an instance'
' of str. Got: %s' % content_transfer_encoding)
headers[hdrs.CONTENT_TRANSFER_ENCODING] = content_transfer_encoding
self._is_multipart = True
self._fields.append((type_options, headers, value))
def add_fields(self, *fields):
to_add = list(fields)
while to_add:
rec = to_add.pop(0)
if isinstance(rec, io.IOBase):
k = guess_filename(rec, 'unknown')
self.add_field(k, rec)
elif isinstance(rec, (MultiDictProxy, MultiDict)):
to_add.extend(rec.items())
elif isinstance(rec, (list, tuple)) and len(rec) == 2:
k, fp = rec
self.add_field(k, fp)
else:
raise TypeError('Only io.IOBase, multidict and (name, file) '
'pairs allowed, use .add_field() for passing '
'more complex parameters')
def _gen_form_urlencoded(self, encoding):
# form data (x-www-form-urlencoded)
data = []
for type_options, _, value in self._fields:
data.append((type_options['name'], value))
data = urlencode(data, doseq=True)
return data.encode(encoding)
def _gen_form_data(self, *args, **kwargs):
"""Encode a list of fields using the multipart/form-data MIME format"""
for dispparams, headers, value in self._fields:
part = self._writer.append(value, headers)
if dispparams:
part.set_content_disposition('form-data', **dispparams)
# FIXME cgi.FieldStorage doesn't likes body parts with
# Content-Length which were sent via chunked transfer encoding
part.headers.pop(hdrs.CONTENT_LENGTH, None)
yield from self._writer.serialize()
def __call__(self, encoding):
if self._is_multipart:
return self._gen_form_data(encoding)
else:
return self._gen_form_urlencoded(encoding)
def parse_mimetype(mimetype):
"""Parses a MIME type into its components.
:param str mimetype: MIME type
:returns: 4 element tuple for MIME type, subtype, suffix and parameters
:rtype: tuple
Example:
>>> parse_mimetype('text/html; charset=utf-8')
('text', 'html', '', {'charset': 'utf-8'})
"""
if not mimetype:
return '', '', '', {}
parts = mimetype.split(';')
params = []
for item in parts[1:]:
if not item:
continue
key, value = item.split('=', 1) if '=' in item else (item, '')
params.append((key.lower().strip(), value.strip(' "')))
params = dict(params)
fulltype = parts[0].strip().lower()
if fulltype == '*':
fulltype = '*/*'
mtype, stype = fulltype.split('/', 1) \
if '/' in fulltype else (fulltype, '')
stype, suffix = stype.split('+', 1) if '+' in stype else (stype, '')
return mtype, stype, suffix, params
def guess_filename(obj, default=None):
name = getattr(obj, 'name', None)
if name and name[0] != '<' and name[-1] != '>':
return Path(name).name
return default
class AccessLogger:
"""Helper object to log access.
Usage:
log = logging.getLogger("spam")
log_format = "%a %{User-Agent}i"
access_logger = AccessLogger(log, log_format)
access_logger.log(message, environ, response, transport, time)
Format:
%% The percent sign
%a Remote IP-address (IP-address of proxy if using reverse proxy)
%t Time when the request was started to process
%P The process ID of the child that serviced the request
%r First line of request
%s Response status code
%b Size of response in bytes, excluding HTTP headers
%O Bytes sent, including headers
%T Time taken to serve the request, in seconds
%Tf Time taken to serve the request, in seconds with floating fraction
in .06f format
%D Time taken to serve the request, in microseconds
%{FOO}i request.headers['FOO']
%{FOO}o response.headers['FOO']
%{FOO}e os.environ['FOO']
"""
LOG_FORMAT = '%a %l %u %t "%r" %s %b "%{Referrer}i" "%{User-Agent}i"'
FORMAT_RE = re.compile(r'%(\{([A-Za-z\-]+)\}([ioe])|[atPrsbOD]|Tf?)')
CLEANUP_RE = re.compile(r'(%[^s])')
_FORMAT_CACHE = {}
def __init__(self, logger, log_format=LOG_FORMAT):
"""Initialise the logger.
:param logger: logger object to be used for logging
:param log_format: apache compatible log format
"""
self.logger = logger
_compiled_format = AccessLogger._FORMAT_CACHE.get(log_format)
if not _compiled_format:
_compiled_format = self.compile_format(log_format)
AccessLogger._FORMAT_CACHE[log_format] = _compiled_format
self._log_format, self._methods = _compiled_format
def compile_format(self, log_format):
"""Translate log_format into form usable by modulo formatting
All known atoms will be replaced with %s
Also methods for formatting of those atoms will be added to
_methods in apropriate order
For example we have log_format = "%a %t"
This format will be translated to "%s %s"
Also contents of _methods will be
[self._format_a, self._format_t]
These method will be called and results will be passed
to translated string format.
Each _format_* method receive 'args' which is list of arguments
given to self.log
Exceptions are _format_e, _format_i and _format_o methods which
also receive key name (by functools.partial)
"""
log_format = log_format.replace("%l", "-")
log_format = log_format.replace("%u", "-")
methods = []
for atom in self.FORMAT_RE.findall(log_format):
if atom[1] == '':
methods.append(getattr(AccessLogger, '_format_%s' % atom[0]))
else:
m = getattr(AccessLogger, '_format_%s' % atom[2])
methods.append(functools.partial(m, atom[1]))
log_format = self.FORMAT_RE.sub(r'%s', log_format)
log_format = self.CLEANUP_RE.sub(r'%\1', log_format)
return log_format, methods
@staticmethod
def _format_e(key, args):
return (args[1] or {}).get(key, '-')
@staticmethod
def _format_i(key, args):
if not args[0]:
return '(no headers)'
# suboptimal, make istr(key) once
return args[0].headers.get(key, '-')
@staticmethod
def _format_o(key, args):
# suboptimal, make istr(key) once
return args[2].headers.get(key, '-')
@staticmethod
def _format_a(args):
if args[3] is None:
return '-'
peername = args[3].get_extra_info('peername')
if isinstance(peername, (list, tuple)):
return peername[0]
else:
return peername
@staticmethod
def _format_t(args):
return datetime.datetime.utcnow().strftime('[%d/%b/%Y:%H:%M:%S +0000]')
@staticmethod
def _format_P(args):
return "<%s>" % os.getpid()
@staticmethod
def _format_r(args):
msg = args[0]
if not msg:
return '-'
return '%s %s HTTP/%s.%s' % tuple((msg.method,
msg.path) + msg.version)
@staticmethod
def _format_s(args):
return args[2].status
@staticmethod
def _format_b(args):
return args[2].body_length
@staticmethod
def _format_O(args):
return args[2].output_length
@staticmethod
def _format_T(args):
return round(args[4])
@staticmethod
def _format_Tf(args):
return '%06f' % args[4]
@staticmethod
def _format_D(args):
return round(args[4] * 1000000)
def _format_line(self, args):
return tuple(m(args) for m in self._methods)
def log(self, message, environ, response, transport, time):
"""Log access.
:param message: Request object. May be None.
:param environ: Environment dict. May be None.
:param response: Response object.
:param transport: Tansport object. May be None
:param float time: Time taken to serve the request.
"""
try:
self.logger.info(self._log_format % self._format_line(
[message, environ, response, transport, time]))
except Exception:
self.logger.exception("Error in logging")
class reify:
"""Use as a class method decorator. It operates almost exactly like
the Python `@property` decorator, but it puts the result of the
method it decorates into the instance dict after the first call,
effectively replacing the function it decorates with an instance
variable. It is, in Python parlance, a data descriptor.
"""
def __init__(self, wrapped):
self.wrapped = wrapped
try:
self.__doc__ = wrapped.__doc__
except: # pragma: no cover
self.__doc__ = ""
self.name = wrapped.__name__
def __get__(self, inst, owner, _sentinel=sentinel):
if inst is None:
return self
val = inst.__dict__.get(self.name, _sentinel)
if val is not _sentinel:
return val
val = self.wrapped(inst)
inst.__dict__[self.name] = val
return val
def __set__(self, inst, value):
raise AttributeError("reified property is read-only")
# The unreserved URI characters (RFC 3986)
UNRESERVED_SET = frozenset(
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" +
"0123456789-._~")
def unquote_unreserved(uri):
"""Un-escape any percent-escape sequences in a URI that are unreserved
characters. This leaves all reserved, illegal and non-ASCII bytes encoded.
"""
parts = uri.split('%')
for i in range(1, len(parts)):
h = parts[i][0:2]
if len(h) == 2 and h.isalnum():
try:
c = chr(int(h, 16))
except ValueError:
raise InvalidURL("Invalid percent-escape sequence: '%s'" % h)
if c in UNRESERVED_SET:
parts[i] = c + parts[i][2:]
else:
parts[i] = '%' + parts[i]
else:
parts[i] = '%' + parts[i]
return ''.join(parts)
def requote_uri(uri):
"""Re-quote the given URI.
This function passes the given URI through an unquote/quote cycle to
ensure that it is fully and consistently quoted.
"""
safe_with_percent = "!#$%&'()*+,/:;=?@[]~"
safe_without_percent = "!#$&'()*+,/:;=?@[]~"
try:
# Unquote only the unreserved characters
# Then quote only illegal characters (do not quote reserved,
# unreserved, or '%')
return quote(unquote_unreserved(uri), safe=safe_with_percent)
except InvalidURL:
# We couldn't unquote the given URI, so let's try quoting it, but
# there may be unquoted '%'s in the URI. We need to make sure they're
# properly quoted so they do not cause issues elsewhere.
return quote(uri, safe=safe_without_percent)
_ipv4_pattern = ('^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}'
'(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$')
_ipv6_pattern = (
'^(?:(?:(?:[A-F0-9]{1,4}:){6}|(?=(?:[A-F0-9]{0,4}:){0,6}'
'(?:[0-9]{1,3}\.){3}[0-9]{1,3}$)(([0-9A-F]{1,4}:){0,5}|:)'
'((:[0-9A-F]{1,4}){1,5}:|:)|::(?:[A-F0-9]{1,4}:){5})'
'(?:(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])\.){3}'
'(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])|(?:[A-F0-9]{1,4}:){7}'
'[A-F0-9]{1,4}|(?=(?:[A-F0-9]{0,4}:){0,7}[A-F0-9]{0,4}$)'
'(([0-9A-F]{1,4}:){1,7}|:)((:[0-9A-F]{1,4}){1,7}|:)|(?:[A-F0-9]{1,4}:){7}'
':|:(:[A-F0-9]{1,4}){7})$')
_ipv4_regex = re.compile(_ipv4_pattern)
_ipv6_regex = re.compile(_ipv6_pattern, flags=re.IGNORECASE)
_ipv4_regexb = re.compile(_ipv4_pattern.encode('ascii'))
_ipv6_regexb = re.compile(_ipv6_pattern.encode('ascii'), flags=re.IGNORECASE)
def is_ip_address(host):
if host is None:
return False
if isinstance(host, str):
if _ipv4_regex.match(host) or _ipv6_regex.match(host):
return True
else:
return False
elif isinstance(host, (bytes, bytearray, memoryview)):
if _ipv4_regexb.match(host) or _ipv6_regexb.match(host):
return True
else:
return False
else:
raise TypeError("{} [{}] is not a str or bytes"
.format(host, type(host)))
def _get_kwarg(kwargs, old, new, value):
val = kwargs.pop(old, sentinel)
if val is not sentinel:
warnings.warn("{} is deprecated, use {} instead".format(old, new),
DeprecationWarning,
stacklevel=3)
return val
else:
return value

View File

@ -0,0 +1,8 @@
import logging
access_logger = logging.getLogger('aiohttp.access')
client_logger = logging.getLogger('aiohttp.client')
internal_logger = logging.getLogger('aiohttp.internal')
server_logger = logging.getLogger('aiohttp.server')
web_logger = logging.getLogger('aiohttp.web')
ws_logger = logging.getLogger('aiohttp.websocket')

View File

@ -0,0 +1,973 @@
import asyncio
import base64
import binascii
import io
import json
import mimetypes
import os
import re
import sys
import uuid
import warnings
import zlib
from collections import Mapping, Sequence, deque
from pathlib import Path
from urllib.parse import parse_qsl, quote, unquote, urlencode
from multidict import CIMultiDict
from .hdrs import (CONTENT_DISPOSITION, CONTENT_ENCODING, CONTENT_LENGTH,
CONTENT_TRANSFER_ENCODING, CONTENT_TYPE)
from .helpers import parse_mimetype
from .protocol import HttpParser
__all__ = ('MultipartReader', 'MultipartWriter',
'BodyPartReader', 'BodyPartWriter',
'BadContentDispositionHeader', 'BadContentDispositionParam',
'parse_content_disposition', 'content_disposition_filename')
CHAR = set(chr(i) for i in range(0, 128))
CTL = set(chr(i) for i in range(0, 32)) | {chr(127), }
SEPARATORS = {'(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']',
'?', '=', '{', '}', ' ', chr(9)}
TOKEN = CHAR ^ CTL ^ SEPARATORS
PY_35 = sys.version_info >= (3, 5)
PY_352 = sys.version_info >= (3, 5, 2)
class BadContentDispositionHeader(RuntimeWarning):
pass
class BadContentDispositionParam(RuntimeWarning):
pass
def parse_content_disposition(header):
def is_token(string):
return string and TOKEN >= set(string)
def is_quoted(string):
return string[0] == string[-1] == '"'
def is_rfc5987(string):
return is_token(string) and string.count("'") == 2
def is_extended_param(string):
return string.endswith('*')
def is_continuous_param(string):
pos = string.find('*') + 1
if not pos:
return False
substring = string[pos:-1] if string.endswith('*') else string[pos:]
return substring.isdigit()
def unescape(text, *, chars=''.join(map(re.escape, CHAR))):
return re.sub('\\\\([{}])'.format(chars), '\\1', text)
if not header:
return None, {}
disptype, *parts = header.split(';')
if not is_token(disptype):
warnings.warn(BadContentDispositionHeader(header))
return None, {}
params = {}
for item in parts:
if '=' not in item:
warnings.warn(BadContentDispositionHeader(header))
return None, {}
key, value = item.split('=', 1)
key = key.lower().strip()
value = value.lstrip()
if key in params:
warnings.warn(BadContentDispositionHeader(header))
return None, {}
if not is_token(key):
warnings.warn(BadContentDispositionParam(item))
continue
elif is_continuous_param(key):
if is_quoted(value):
value = unescape(value[1:-1])
elif not is_token(value):
warnings.warn(BadContentDispositionParam(item))
continue
elif is_extended_param(key):
if is_rfc5987(value):
encoding, _, value = value.split("'", 2)
encoding = encoding or 'utf-8'
else:
warnings.warn(BadContentDispositionParam(item))
continue
try:
value = unquote(value, encoding, 'strict')
except UnicodeDecodeError: # pragma: nocover
warnings.warn(BadContentDispositionParam(item))
continue
else:
if is_quoted(value):
value = unescape(value[1:-1].lstrip('\\/'))
elif not is_token(value):
warnings.warn(BadContentDispositionHeader(header))
return None, {}
params[key] = value
return disptype.lower(), params
def content_disposition_filename(params):
if not params:
return None
elif 'filename*' in params:
return params['filename*']
elif 'filename' in params:
return params['filename']
else:
parts = []
fnparams = sorted((key, value)
for key, value in params.items()
if key.startswith('filename*'))
for num, (key, value) in enumerate(fnparams):
_, tail = key.split('*', 1)
if tail.endswith('*'):
tail = tail[:-1]
if tail == str(num):
parts.append(value)
else:
break
if not parts:
return None
value = ''.join(parts)
if "'" in value:
encoding, _, value = value.split("'", 2)
encoding = encoding or 'utf-8'
return unquote(value, encoding, 'strict')
return value
class MultipartResponseWrapper(object):
"""Wrapper around the :class:`MultipartBodyReader` to take care about
underlying connection and close it when it needs in."""
def __init__(self, resp, stream):
self.resp = resp
self.stream = stream
if PY_35:
def __aiter__(self):
return self
if not PY_352: # pragma: no cover
__aiter__ = asyncio.coroutine(__aiter__)
@asyncio.coroutine
def __anext__(self):
part = yield from self.next()
if part is None:
raise StopAsyncIteration # NOQA
return part
def at_eof(self):
"""Returns ``True`` when all response data had been read.
:rtype: bool
"""
return self.resp.content.at_eof()
@asyncio.coroutine
def next(self):
"""Emits next multipart reader object."""
item = yield from self.stream.next()
if self.stream.at_eof():
yield from self.release()
return item
@asyncio.coroutine
def release(self):
"""Releases the connection gracefully, reading all the content
to the void."""
yield from self.resp.release()
class BodyPartReader(object):
"""Multipart reader for single body part."""
chunk_size = 8192
def __init__(self, boundary, headers, content):
self.headers = headers
self._boundary = boundary
self._content = content
self._at_eof = False
length = self.headers.get(CONTENT_LENGTH, None)
self._length = int(length) if length is not None else None
self._read_bytes = 0
self._unread = deque()
self._prev_chunk = None
self._content_eof = 0
if PY_35:
def __aiter__(self):
return self
if not PY_352: # pragma: no cover
__aiter__ = asyncio.coroutine(__aiter__)
@asyncio.coroutine
def __anext__(self):
part = yield from self.next()
if part is None:
raise StopAsyncIteration # NOQA
return part
@asyncio.coroutine
def next(self):
item = yield from self.read()
if not item:
return None
return item
@asyncio.coroutine
def read(self, *, decode=False):
"""Reads body part data.
:param bool decode: Decodes data following by encoding
method from `Content-Encoding` header. If it missed
data remains untouched
:rtype: bytearray
"""
if self._at_eof:
return b''
data = bytearray()
if self._length is None:
while not self._at_eof:
data.extend((yield from self.readline()))
else:
while not self._at_eof:
data.extend((yield from self.read_chunk(self.chunk_size)))
if decode:
return self.decode(data)
return data
@asyncio.coroutine
def read_chunk(self, size=chunk_size):
"""Reads body part content chunk of the specified size.
:param int size: chunk size
:rtype: bytearray
"""
if self._at_eof:
return b''
if self._length:
chunk = yield from self._read_chunk_from_length(size)
else:
chunk = yield from self._read_chunk_from_stream(size)
self._read_bytes += len(chunk)
if self._read_bytes == self._length:
self._at_eof = True
if self._at_eof:
assert b'\r\n' == (yield from self._content.readline()), \
'reader did not read all the data or it is malformed'
return chunk
@asyncio.coroutine
def _read_chunk_from_length(self, size):
"""Reads body part content chunk of the specified size.
The body part must has `Content-Length` header with proper value.
:param int size: chunk size
:rtype: bytearray
"""
assert self._length is not None, \
'Content-Length required for chunked read'
chunk_size = min(size, self._length - self._read_bytes)
chunk = yield from self._content.read(chunk_size)
return chunk
@asyncio.coroutine
def _read_chunk_from_stream(self, size):
"""Reads content chunk of body part with unknown length.
The `Content-Length` header for body part is not necessary.
:param int size: chunk size
:rtype: bytearray
"""
assert size >= len(self._boundary) + 2, \
'Chunk size must be greater or equal than boundary length + 2'
first_chunk = self._prev_chunk is None
if first_chunk:
self._prev_chunk = yield from self._content.read(size)
chunk = yield from self._content.read(size)
self._content_eof += int(self._content.at_eof())
assert self._content_eof < 3, "Reading after EOF"
window = self._prev_chunk + chunk
sub = b'\r\n' + self._boundary
if first_chunk:
idx = window.find(sub)
else:
idx = window.find(sub, max(0, len(self._prev_chunk) - len(sub)))
if idx >= 0:
# pushing boundary back to content
self._content.unread_data(window[idx:])
if size > idx:
self._prev_chunk = self._prev_chunk[:idx]
chunk = window[len(self._prev_chunk):idx]
if not chunk:
self._at_eof = True
if 0 < len(chunk) < len(sub) and not self._content_eof:
self._prev_chunk += chunk
self._at_eof = False
return b''
result = self._prev_chunk
self._prev_chunk = chunk
return result
@asyncio.coroutine
def readline(self):
"""Reads body part by line by line.
:rtype: bytearray
"""
if self._at_eof:
return b''
if self._unread:
line = self._unread.popleft()
else:
line = yield from self._content.readline()
if line.startswith(self._boundary):
# the very last boundary may not come with \r\n,
# so set single rules for everyone
sline = line.rstrip(b'\r\n')
boundary = self._boundary
last_boundary = self._boundary + b'--'
# ensure that we read exactly the boundary, not something alike
if sline == boundary or sline == last_boundary:
self._at_eof = True
self._unread.append(line)
return b''
else:
next_line = yield from self._content.readline()
if next_line.startswith(self._boundary):
line = line[:-2] # strip CRLF but only once
self._unread.append(next_line)
return line
@asyncio.coroutine
def release(self):
"""Like :meth:`read`, but reads all the data to the void.
:rtype: None
"""
if self._at_eof:
return
if self._length is None:
while not self._at_eof:
yield from self.readline()
else:
while not self._at_eof:
yield from self.read_chunk(self.chunk_size)
@asyncio.coroutine
def text(self, *, encoding=None):
"""Like :meth:`read`, but assumes that body part contains text data.
:param str encoding: Custom text encoding. Overrides specified
in charset param of `Content-Type` header
:rtype: str
"""
data = yield from self.read(decode=True)
encoding = encoding or self.get_charset(default='latin1')
return data.decode(encoding)
@asyncio.coroutine
def json(self, *, encoding=None):
"""Like :meth:`read`, but assumes that body parts contains JSON data.
:param str encoding: Custom JSON encoding. Overrides specified
in charset param of `Content-Type` header
"""
data = yield from self.read(decode=True)
if not data:
return None
encoding = encoding or self.get_charset(default='utf-8')
return json.loads(data.decode(encoding))
@asyncio.coroutine
def form(self, *, encoding=None):
"""Like :meth:`read`, but assumes that body parts contains form
urlencoded data.
:param str encoding: Custom form encoding. Overrides specified
in charset param of `Content-Type` header
"""
data = yield from self.read(decode=True)
if not data:
return None
encoding = encoding or self.get_charset(default='utf-8')
return parse_qsl(data.rstrip().decode(encoding), encoding=encoding)
def at_eof(self):
"""Returns ``True`` if the boundary was reached or
``False`` otherwise.
:rtype: bool
"""
return self._at_eof
def decode(self, data):
"""Decodes data according the specified `Content-Encoding`
or `Content-Transfer-Encoding` headers value.
Supports ``gzip``, ``deflate`` and ``identity`` encodings for
`Content-Encoding` header.
Supports ``base64``, ``quoted-printable``, ``binary`` encodings for
`Content-Transfer-Encoding` header.
:param bytearray data: Data to decode.
:raises: :exc:`RuntimeError` - if encoding is unknown.
:rtype: bytes
"""
if CONTENT_TRANSFER_ENCODING in self.headers:
data = self._decode_content_transfer(data)
if CONTENT_ENCODING in self.headers:
return self._decode_content(data)
return data
def _decode_content(self, data):
encoding = self.headers[CONTENT_ENCODING].lower()
if encoding == 'deflate':
return zlib.decompress(data, -zlib.MAX_WBITS)
elif encoding == 'gzip':
return zlib.decompress(data, 16 + zlib.MAX_WBITS)
elif encoding == 'identity':
return data
else:
raise RuntimeError('unknown content encoding: {}'.format(encoding))
def _decode_content_transfer(self, data):
encoding = self.headers[CONTENT_TRANSFER_ENCODING].lower()
if encoding == 'base64':
return base64.b64decode(data)
elif encoding == 'quoted-printable':
return binascii.a2b_qp(data)
elif encoding == 'binary':
return data
else:
raise RuntimeError('unknown content transfer encoding: {}'
''.format(encoding))
def get_charset(self, default=None):
"""Returns charset parameter from ``Content-Type`` header or default.
"""
ctype = self.headers.get(CONTENT_TYPE, '')
*_, params = parse_mimetype(ctype)
return params.get('charset', default)
@property
def filename(self):
"""Returns filename specified in Content-Disposition header or ``None``
if missed or header is malformed."""
_, params = parse_content_disposition(
self.headers.get(CONTENT_DISPOSITION))
return content_disposition_filename(params)
class MultipartReader(object):
"""Multipart body reader."""
#: Response wrapper, used when multipart readers constructs from response.
response_wrapper_cls = MultipartResponseWrapper
#: Multipart reader class, used to handle multipart/* body parts.
#: None points to type(self)
multipart_reader_cls = None
#: Body part reader class for non multipart/* content types.
part_reader_cls = BodyPartReader
def __init__(self, headers, content):
self.headers = headers
self._boundary = ('--' + self._get_boundary()).encode()
self._content = content
self._last_part = None
self._at_eof = False
self._at_bof = True
self._unread = []
if PY_35:
def __aiter__(self):
return self
if not PY_352: # pragma: no cover
__aiter__ = asyncio.coroutine(__aiter__)
@asyncio.coroutine
def __anext__(self):
part = yield from self.next()
if part is None:
raise StopAsyncIteration # NOQA
return part
@classmethod
def from_response(cls, response):
"""Constructs reader instance from HTTP response.
:param response: :class:`~aiohttp.client.ClientResponse` instance
"""
obj = cls.response_wrapper_cls(response, cls(response.headers,
response.content))
return obj
def at_eof(self):
"""Returns ``True`` if the final boundary was reached or
``False`` otherwise.
:rtype: bool
"""
return self._at_eof
@asyncio.coroutine
def next(self):
"""Emits the next multipart body part."""
# So, if we're at BOF, we need to skip till the boundary.
if self._at_eof:
return
yield from self._maybe_release_last_part()
if self._at_bof:
yield from self._read_until_first_boundary()
self._at_bof = False
else:
yield from self._read_boundary()
if self._at_eof: # we just read the last boundary, nothing to do there
return
self._last_part = yield from self.fetch_next_part()
return self._last_part
@asyncio.coroutine
def release(self):
"""Reads all the body parts to the void till the final boundary."""
while not self._at_eof:
item = yield from self.next()
if item is None:
break
yield from item.release()
@asyncio.coroutine
def fetch_next_part(self):
"""Returns the next body part reader."""
headers = yield from self._read_headers()
return self._get_part_reader(headers)
def _get_part_reader(self, headers):
"""Dispatches the response by the `Content-Type` header, returning
suitable reader instance.
:param dict headers: Response headers
"""
ctype = headers.get(CONTENT_TYPE, '')
mtype, *_ = parse_mimetype(ctype)
if mtype == 'multipart':
if self.multipart_reader_cls is None:
return type(self)(headers, self._content)
return self.multipart_reader_cls(headers, self._content)
else:
return self.part_reader_cls(self._boundary, headers, self._content)
def _get_boundary(self):
mtype, *_, params = parse_mimetype(self.headers[CONTENT_TYPE])
assert mtype == 'multipart', 'multipart/* content type expected'
if 'boundary' not in params:
raise ValueError('boundary missed for Content-Type: %s'
% self.headers[CONTENT_TYPE])
boundary = params['boundary']
if len(boundary) > 70:
raise ValueError('boundary %r is too long (70 chars max)'
% boundary)
return boundary
@asyncio.coroutine
def _readline(self):
if self._unread:
return self._unread.pop()
return (yield from self._content.readline())
@asyncio.coroutine
def _read_until_first_boundary(self):
while True:
chunk = yield from self._readline()
if chunk == b'':
raise ValueError("Could not find starting boundary %r"
% (self._boundary))
chunk = chunk.rstrip()
if chunk == self._boundary:
return
elif chunk == self._boundary + b'--':
self._at_eof = True
return
@asyncio.coroutine
def _read_boundary(self):
chunk = (yield from self._readline()).rstrip()
if chunk == self._boundary:
pass
elif chunk == self._boundary + b'--':
self._at_eof = True
else:
raise ValueError('Invalid boundary %r, expected %r'
% (chunk, self._boundary))
@asyncio.coroutine
def _read_headers(self):
lines = [b'']
while True:
chunk = yield from self._content.readline()
chunk = chunk.strip()
lines.append(chunk)
if not chunk:
break
parser = HttpParser()
headers, *_ = parser.parse_headers(lines)
return headers
@asyncio.coroutine
def _maybe_release_last_part(self):
"""Ensures that the last read body part is read completely."""
if self._last_part is not None:
if not self._last_part.at_eof():
yield from self._last_part.release()
self._unread.extend(self._last_part._unread)
self._last_part = None
class BodyPartWriter(object):
"""Multipart writer for single body part."""
def __init__(self, obj, headers=None, *, chunk_size=8192):
if headers is None:
headers = CIMultiDict()
elif not isinstance(headers, CIMultiDict):
headers = CIMultiDict(headers)
self.obj = obj
self.headers = headers
self._chunk_size = chunk_size
self._fill_headers_with_defaults()
self._serialize_map = {
bytes: self._serialize_bytes,
str: self._serialize_str,
io.IOBase: self._serialize_io,
MultipartWriter: self._serialize_multipart,
('application', 'json'): self._serialize_json,
('application', 'x-www-form-urlencoded'): self._serialize_form
}
def _fill_headers_with_defaults(self):
if CONTENT_TYPE not in self.headers:
content_type = self._guess_content_type(self.obj)
if content_type is not None:
self.headers[CONTENT_TYPE] = content_type
if CONTENT_LENGTH not in self.headers:
content_length = self._guess_content_length(self.obj)
if content_length is not None:
self.headers[CONTENT_LENGTH] = str(content_length)
if CONTENT_DISPOSITION not in self.headers:
filename = self._guess_filename(self.obj)
if filename is not None:
self.set_content_disposition('attachment', filename=filename)
def _guess_content_length(self, obj):
if isinstance(obj, bytes):
return len(obj)
elif isinstance(obj, str):
*_, params = parse_mimetype(self.headers.get(CONTENT_TYPE))
charset = params.get('charset', 'us-ascii')
return len(obj.encode(charset))
elif isinstance(obj, io.StringIO):
*_, params = parse_mimetype(self.headers.get(CONTENT_TYPE))
charset = params.get('charset', 'us-ascii')
return len(obj.getvalue().encode(charset)) - obj.tell()
elif isinstance(obj, io.BytesIO):
return len(obj.getvalue()) - obj.tell()
elif isinstance(obj, io.IOBase):
try:
return os.fstat(obj.fileno()).st_size - obj.tell()
except (AttributeError, OSError):
return None
else:
return None
def _guess_content_type(self, obj, default='application/octet-stream'):
if hasattr(obj, 'name'):
name = getattr(obj, 'name')
return mimetypes.guess_type(name)[0]
elif isinstance(obj, (str, io.StringIO)):
return 'text/plain; charset=utf-8'
else:
return default
def _guess_filename(self, obj):
if isinstance(obj, io.IOBase):
name = getattr(obj, 'name', None)
if name is not None:
return Path(name).name
def serialize(self):
"""Yields byte chunks for body part."""
has_encoding = (
CONTENT_ENCODING in self.headers and
self.headers[CONTENT_ENCODING] != 'identity' or
CONTENT_TRANSFER_ENCODING in self.headers
)
if has_encoding:
# since we're following streaming approach which doesn't assumes
# any intermediate buffers, we cannot calculate real content length
# with the specified content encoding scheme. So, instead of lying
# about content length and cause reading issues, we have to strip
# this information.
self.headers.pop(CONTENT_LENGTH, None)
if self.headers:
yield b'\r\n'.join(
b': '.join(map(lambda i: i.encode('latin1'), item))
for item in self.headers.items()
)
yield b'\r\n\r\n'
yield from self._maybe_encode_stream(self._serialize_obj())
yield b'\r\n'
def _serialize_obj(self):
obj = self.obj
mtype, stype, *_ = parse_mimetype(self.headers.get(CONTENT_TYPE))
serializer = self._serialize_map.get((mtype, stype))
if serializer is not None:
return serializer(obj)
for key in self._serialize_map:
if not isinstance(key, tuple) and isinstance(obj, key):
return self._serialize_map[key](obj)
return self._serialize_default(obj)
def _serialize_bytes(self, obj):
yield obj
def _serialize_str(self, obj):
*_, params = parse_mimetype(self.headers.get(CONTENT_TYPE))
yield obj.encode(params.get('charset', 'us-ascii'))
def _serialize_io(self, obj):
while True:
chunk = obj.read(self._chunk_size)
if not chunk:
break
if isinstance(chunk, str):
yield from self._serialize_str(chunk)
else:
yield from self._serialize_bytes(chunk)
def _serialize_multipart(self, obj):
yield from obj.serialize()
def _serialize_json(self, obj):
*_, params = parse_mimetype(self.headers.get(CONTENT_TYPE))
yield json.dumps(obj).encode(params.get('charset', 'utf-8'))
def _serialize_form(self, obj):
if isinstance(obj, Mapping):
obj = list(obj.items())
return self._serialize_str(urlencode(obj, doseq=True))
def _serialize_default(self, obj):
raise TypeError('unknown body part type %r' % type(obj))
def _maybe_encode_stream(self, stream):
if CONTENT_ENCODING in self.headers:
stream = self._apply_content_encoding(stream)
if CONTENT_TRANSFER_ENCODING in self.headers:
stream = self._apply_content_transfer_encoding(stream)
yield from stream
def _apply_content_encoding(self, stream):
encoding = self.headers[CONTENT_ENCODING].lower()
if encoding == 'identity':
yield from stream
elif encoding in ('deflate', 'gzip'):
if encoding == 'gzip':
zlib_mode = 16 + zlib.MAX_WBITS
else:
zlib_mode = -zlib.MAX_WBITS
zcomp = zlib.compressobj(wbits=zlib_mode)
for chunk in stream:
yield zcomp.compress(chunk)
else:
yield zcomp.flush()
else:
raise RuntimeError('unknown content encoding: {}'
''.format(encoding))
def _apply_content_transfer_encoding(self, stream):
encoding = self.headers[CONTENT_TRANSFER_ENCODING].lower()
if encoding == 'base64':
buffer = bytearray()
while True:
if buffer:
div, mod = divmod(len(buffer), 3)
chunk, buffer = buffer[:div * 3], buffer[div * 3:]
if chunk:
yield base64.b64encode(chunk)
chunk = next(stream, None)
if not chunk:
if buffer:
yield base64.b64encode(buffer[:])
return
buffer.extend(chunk)
elif encoding == 'quoted-printable':
for chunk in stream:
yield binascii.b2a_qp(chunk)
elif encoding == 'binary':
yield from stream
else:
raise RuntimeError('unknown content transfer encoding: {}'
''.format(encoding))
def set_content_disposition(self, disptype, **params):
"""Sets ``Content-Disposition`` header.
:param str disptype: Disposition type: inline, attachment, form-data.
Should be valid extension token (see RFC 2183)
:param dict params: Disposition params
"""
if not disptype or not (TOKEN > set(disptype)):
raise ValueError('bad content disposition type {!r}'
''.format(disptype))
value = disptype
if params:
lparams = []
for key, val in params.items():
if not key or not (TOKEN > set(key)):
raise ValueError('bad content disposition parameter'
' {!r}={!r}'.format(key, val))
qval = quote(val, '')
lparams.append((key, '"%s"' % qval))
if key == 'filename':
lparams.append(('filename*', "utf-8''" + qval))
sparams = '; '.join('='.join(pair) for pair in lparams)
value = '; '.join((value, sparams))
self.headers[CONTENT_DISPOSITION] = value
@property
def filename(self):
"""Returns filename specified in Content-Disposition header or ``None``
if missed."""
_, params = parse_content_disposition(
self.headers.get(CONTENT_DISPOSITION))
return content_disposition_filename(params)
class MultipartWriter(object):
"""Multipart body writer."""
#: Body part reader class for non multipart/* content types.
part_writer_cls = BodyPartWriter
def __init__(self, subtype='mixed', boundary=None):
boundary = boundary if boundary is not None else uuid.uuid4().hex
try:
boundary.encode('us-ascii')
except UnicodeEncodeError:
raise ValueError('boundary should contains ASCII only chars')
self.headers = CIMultiDict()
self.headers[CONTENT_TYPE] = 'multipart/{}; boundary="{}"'.format(
subtype, boundary
)
self.parts = []
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
pass
def __iter__(self):
return iter(self.parts)
def __len__(self):
return len(self.parts)
@property
def boundary(self):
*_, params = parse_mimetype(self.headers.get(CONTENT_TYPE))
return params['boundary'].encode('us-ascii')
def append(self, obj, headers=None):
"""Adds a new body part to multipart writer."""
if isinstance(obj, self.part_writer_cls):
if headers:
obj.headers.update(headers)
self.parts.append(obj)
else:
if not headers:
headers = CIMultiDict()
self.parts.append(self.part_writer_cls(obj, headers))
return self.parts[-1]
def append_json(self, obj, headers=None):
"""Helper to append JSON part."""
if not headers:
headers = CIMultiDict()
headers[CONTENT_TYPE] = 'application/json'
return self.append(obj, headers)
def append_form(self, obj, headers=None):
"""Helper to append form urlencoded part."""
if not headers:
headers = CIMultiDict()
headers[CONTENT_TYPE] = 'application/x-www-form-urlencoded'
assert isinstance(obj, (Sequence, Mapping))
return self.append(obj, headers)
def serialize(self):
"""Yields multipart byte chunks."""
if not self.parts:
yield b''
return
for part in self.parts:
yield b'--' + self.boundary + b'\r\n'
yield from part.serialize()
else:
yield b'--' + self.boundary + b'--\r\n'
yield b''

View File

@ -0,0 +1,495 @@
"""Parser is a generator function (NOT coroutine).
Parser receives data with generator's send() method and sends data to
destination DataQueue. Parser receives ParserBuffer and DataQueue objects
as a parameters of the parser call, all subsequent send() calls should
send bytes objects. Parser sends parsed `term` to destination buffer with
DataQueue.feed_data() method. DataQueue object should implement two methods.
feed_data() - parser uses this method to send parsed protocol data.
feed_eof() - parser uses this method for indication of end of parsing stream.
To indicate end of incoming data stream EofStream exception should be sent
into parser. Parser could throw exceptions.
There are three stages:
* Data flow chain:
1. Application creates StreamParser object for storing incoming data.
2. StreamParser creates ParserBuffer as internal data buffer.
3. Application create parser and set it into stream buffer:
parser = HttpRequestParser()
data_queue = stream.set_parser(parser)
3. At this stage StreamParser creates DataQueue object and passes it
and internal buffer into parser as an arguments.
def set_parser(self, parser):
output = DataQueue()
self.p = parser(output, self._input)
return output
4. Application waits data on output.read()
while True:
msg = yield from output.read()
...
* Data flow:
1. asyncio's transport reads data from socket and sends data to protocol
with data_received() call.
2. Protocol sends data to StreamParser with feed_data() call.
3. StreamParser sends data into parser with generator's send() method.
4. Parser processes incoming data and sends parsed data
to DataQueue with feed_data()
5. Application received parsed data from DataQueue.read()
* Eof:
1. StreamParser receives eof with feed_eof() call.
2. StreamParser throws EofStream exception into parser.
3. Then it unsets parser.
_SocketSocketTransport ->
-> "protocol" -> StreamParser -> "parser" -> DataQueue <- "application"
"""
import asyncio
import asyncio.streams
import inspect
import socket
from . import errors
from .streams import EofStream, FlowControlDataQueue
__all__ = ('EofStream', 'StreamParser', 'StreamProtocol',
'ParserBuffer', 'StreamWriter')
DEFAULT_LIMIT = 2 ** 16
if hasattr(socket, 'TCP_CORK'): # pragma: no cover
CORK = socket.TCP_CORK
elif hasattr(socket, 'TCP_NOPUSH'): # pragma: no cover
CORK = socket.TCP_NOPUSH
else: # pragma: no cover
CORK = None
class StreamParser:
"""StreamParser manages incoming bytes stream and protocol parsers.
StreamParser uses ParserBuffer as internal buffer.
set_parser() sets current parser, it creates DataQueue object
and sends ParserBuffer and DataQueue into parser generator.
unset_parser() sends EofStream into parser and then removes it.
"""
def __init__(self, *, loop=None, buf=None,
limit=DEFAULT_LIMIT, eof_exc_class=RuntimeError, **kwargs):
self._loop = loop
self._eof = False
self._exception = None
self._parser = None
self._output = None
self._limit = limit
self._eof_exc_class = eof_exc_class
self._buffer = buf if buf is not None else ParserBuffer()
self.paused = False
self.transport = None
@property
def output(self):
return self._output
def set_transport(self, transport):
assert transport is None or self.transport is None, \
'Transport already set'
self.transport = transport
def at_eof(self):
return self._eof
def exception(self):
return self._exception
def set_exception(self, exc):
if isinstance(exc, ConnectionError):
exc, old_exc = self._eof_exc_class(), exc
exc.__cause__ = old_exc
exc.__context__ = old_exc
self._exception = exc
if self._output is not None:
self._output.set_exception(exc)
self._output = None
self._parser = None
def feed_data(self, data):
"""send data to current parser or store in buffer."""
if data is None:
return
if self._parser:
try:
self._parser.send(data)
except StopIteration:
self._output.feed_eof()
self._output = None
self._parser = None
except Exception as exc:
self._output.set_exception(exc)
self._output = None
self._parser = None
else:
self._buffer.feed_data(data)
def feed_eof(self):
"""send eof to all parsers, recursively."""
if self._parser:
try:
if self._buffer:
self._parser.send(b'')
self._parser.throw(EofStream())
except StopIteration:
self._output.feed_eof()
except EofStream:
self._output.set_exception(self._eof_exc_class())
except Exception as exc:
self._output.set_exception(exc)
self._parser = None
self._output = None
self._eof = True
def set_parser(self, parser, output=None):
"""set parser to stream. return parser's DataQueue."""
if self._parser:
self.unset_parser()
if output is None:
output = FlowControlDataQueue(
self, limit=self._limit, loop=self._loop)
if self._exception:
output.set_exception(self._exception)
return output
# init parser
p = parser(output, self._buffer)
assert inspect.isgenerator(p), 'Generator is required'
try:
# initialize parser with data and parser buffers
next(p)
except StopIteration:
pass
except Exception as exc:
output.set_exception(exc)
else:
# parser still require more data
self._parser = p
self._output = output
if self._eof:
self.unset_parser()
return output
def unset_parser(self):
"""unset parser, send eof to the parser and then remove it."""
if self._parser is None:
return
# TODO: write test
if self._loop.is_closed():
# TODO: log something
return
try:
self._parser.throw(EofStream())
except StopIteration:
self._output.feed_eof()
except EofStream:
self._output.set_exception(self._eof_exc_class())
except Exception as exc:
self._output.set_exception(exc)
finally:
self._output = None
self._parser = None
class StreamWriter(asyncio.streams.StreamWriter):
def __init__(self, transport, protocol, reader, loop):
self._transport = transport
self._protocol = protocol
self._reader = reader
self._loop = loop
self._tcp_nodelay = False
self._tcp_cork = False
self._socket = transport.get_extra_info('socket')
@property
def tcp_nodelay(self):
return self._tcp_nodelay
def set_tcp_nodelay(self, value):
value = bool(value)
if self._tcp_nodelay == value:
return
self._tcp_nodelay = value
if self._socket is None:
return
if self._socket.family not in (socket.AF_INET, socket.AF_INET6):
return
if self._tcp_cork:
self._tcp_cork = False
if CORK is not None: # pragma: no branch
self._socket.setsockopt(socket.IPPROTO_TCP, CORK, False)
self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, value)
@property
def tcp_cork(self):
return self._tcp_cork
def set_tcp_cork(self, value):
value = bool(value)
if self._tcp_cork == value:
return
self._tcp_cork = value
if self._socket is None:
return
if self._socket.family not in (socket.AF_INET, socket.AF_INET6):
return
if self._tcp_nodelay:
self._socket.setsockopt(socket.IPPROTO_TCP,
socket.TCP_NODELAY,
False)
self._tcp_nodelay = False
if CORK is not None: # pragma: no branch
self._socket.setsockopt(socket.IPPROTO_TCP, CORK, value)
class StreamProtocol(asyncio.streams.FlowControlMixin, asyncio.Protocol):
"""Helper class to adapt between Protocol and StreamReader."""
def __init__(self, *, loop=None, disconnect_error=RuntimeError, **kwargs):
super().__init__(loop=loop)
self.transport = None
self.writer = None
self.reader = StreamParser(
loop=loop, eof_exc_class=disconnect_error, **kwargs)
def is_connected(self):
return self.transport is not None
def connection_made(self, transport):
self.transport = transport
self.reader.set_transport(transport)
self.writer = StreamWriter(transport, self, self.reader, self._loop)
def connection_lost(self, exc):
self.transport = self.writer = None
self.reader.set_transport(None)
if exc is None:
self.reader.feed_eof()
else:
self.reader.set_exception(exc)
super().connection_lost(exc)
def data_received(self, data):
self.reader.feed_data(data)
def eof_received(self):
self.reader.feed_eof()
class _ParserBufferHelper:
__slots__ = ('exception', 'data')
def __init__(self, exception, data):
self.exception = exception
self.data = data
class ParserBuffer:
"""ParserBuffer is NOT a bytearray extension anymore.
ParserBuffer provides helper methods for parsers.
"""
__slots__ = ('_helper', '_writer', '_data')
def __init__(self, *args):
self._data = bytearray(*args)
self._helper = _ParserBufferHelper(None, self._data)
self._writer = self._feed_data(self._helper)
next(self._writer)
def exception(self):
return self._helper.exception
def set_exception(self, exc):
self._helper.exception = exc
@staticmethod
def _feed_data(helper):
while True:
chunk = yield
if chunk:
helper.data.extend(chunk)
if helper.exception:
raise helper.exception
def feed_data(self, data):
if not self._helper.exception:
self._writer.send(data)
def read(self, size):
"""read() reads specified amount of bytes."""
while True:
if self._helper.exception:
raise self._helper.exception
if len(self._data) >= size:
data = self._data[:size]
del self._data[:size]
return data
self._writer.send((yield))
def readsome(self, size=None):
"""reads size of less amount of bytes."""
while True:
if self._helper.exception:
raise self._helper.exception
length = len(self._data)
if length > 0:
if size is None or length < size:
size = length
data = self._data[:size]
del self._data[:size]
return data
self._writer.send((yield))
def readuntil(self, stop, limit=None):
assert isinstance(stop, bytes) and stop, \
'bytes is required: {!r}'.format(stop)
stop_len = len(stop)
while True:
if self._helper.exception:
raise self._helper.exception
pos = self._data.find(stop)
if pos >= 0:
end = pos + stop_len
size = end
if limit is not None and size > limit:
raise errors.LineLimitExceededParserError(
'Line is too long.', limit)
data = self._data[:size]
del self._data[:size]
return data
else:
if limit is not None and len(self._data) > limit:
raise errors.LineLimitExceededParserError(
'Line is too long.', limit)
self._writer.send((yield))
def wait(self, size):
"""wait() waits for specified amount of bytes
then returns data without changing internal buffer."""
while True:
if self._helper.exception:
raise self._helper.exception
if len(self._data) >= size:
return self._data[:size]
self._writer.send((yield))
def waituntil(self, stop, limit=None):
"""waituntil() reads until `stop` bytes sequence."""
assert isinstance(stop, bytes) and stop, \
'bytes is required: {!r}'.format(stop)
stop_len = len(stop)
while True:
if self._helper.exception:
raise self._helper.exception
pos = self._data.find(stop)
if pos >= 0:
size = pos + stop_len
if limit is not None and size > limit:
raise errors.LineLimitExceededParserError(
'Line is too long. %s' % bytes(self._data), limit)
return self._data[:size]
else:
if limit is not None and len(self._data) > limit:
raise errors.LineLimitExceededParserError(
'Line is too long. %s' % bytes(self._data), limit)
self._writer.send((yield))
def skip(self, size):
"""skip() skips specified amount of bytes."""
while len(self._data) < size:
if self._helper.exception:
raise self._helper.exception
self._writer.send((yield))
del self._data[:size]
def skipuntil(self, stop):
"""skipuntil() reads until `stop` bytes sequence."""
assert isinstance(stop, bytes) and stop, \
'bytes is required: {!r}'.format(stop)
stop_len = len(stop)
while True:
if self._helper.exception:
raise self._helper.exception
stop_line = self._data.find(stop)
if stop_line >= 0:
size = stop_line + stop_len
del self._data[:size]
return
self._writer.send((yield))
def extend(self, data):
self._data.extend(data)
def __len__(self):
return len(self._data)
def __bytes__(self):
return bytes(self._data)

View File

@ -0,0 +1,916 @@
"""Http related parsers and protocol."""
import collections
import functools
import http.server
import re
import string
import sys
import zlib
from abc import ABC, abstractmethod
from wsgiref.handlers import format_date_time
from multidict import CIMultiDict, istr
import aiohttp
from . import errors, hdrs
from .helpers import reify
from .log import internal_logger
__all__ = ('HttpMessage', 'Request', 'Response',
'HttpVersion', 'HttpVersion10', 'HttpVersion11',
'RawRequestMessage', 'RawResponseMessage',
'HttpPrefixParser', 'HttpRequestParser', 'HttpResponseParser',
'HttpPayloadParser')
ASCIISET = set(string.printable)
METHRE = re.compile('[A-Z0-9$-_.]+')
VERSRE = re.compile('HTTP/(\d+).(\d+)')
HDRRE = re.compile(b'[\x00-\x1F\x7F()<>@,;:\[\]={} \t\\\\\"]')
EOF_MARKER = object()
EOL_MARKER = object()
STATUS_LINE_READY = object()
RESPONSES = http.server.BaseHTTPRequestHandler.responses
HttpVersion = collections.namedtuple(
'HttpVersion', ['major', 'minor'])
HttpVersion10 = HttpVersion(1, 0)
HttpVersion11 = HttpVersion(1, 1)
RawStatusLineMessage = collections.namedtuple(
'RawStatusLineMessage', ['method', 'path', 'version'])
RawRequestMessage = collections.namedtuple(
'RawRequestMessage',
['method', 'path', 'version', 'headers', 'raw_headers',
'should_close', 'compression'])
RawResponseMessage = collections.namedtuple(
'RawResponseMessage',
['version', 'code', 'reason', 'headers', 'raw_headers',
'should_close', 'compression'])
class HttpParser:
def __init__(self, max_line_size=8190, max_headers=32768,
max_field_size=8190):
self.max_line_size = max_line_size
self.max_headers = max_headers
self.max_field_size = max_field_size
def parse_headers(self, lines):
"""Parses RFC 5322 headers from a stream.
Line continuations are supported. Returns list of header name
and value pairs. Header name is in upper case.
"""
close_conn = None
encoding = None
headers = CIMultiDict()
raw_headers = []
lines_idx = 1
line = lines[1]
while line:
header_length = len(line)
# Parse initial header name : value pair.
try:
bname, bvalue = line.split(b':', 1)
except ValueError:
raise errors.InvalidHeader(line) from None
bname = bname.strip(b' \t').upper()
if HDRRE.search(bname):
raise errors.InvalidHeader(bname)
# next line
lines_idx += 1
line = lines[lines_idx]
# consume continuation lines
continuation = line and line[0] in (32, 9) # (' ', '\t')
if continuation:
bvalue = [bvalue]
while continuation:
header_length += len(line)
if header_length > self.max_field_size:
raise errors.LineTooLong(
'limit request headers fields size')
bvalue.append(line)
# next line
lines_idx += 1
line = lines[lines_idx]
continuation = line[0] in (32, 9) # (' ', '\t')
bvalue = b'\r\n'.join(bvalue)
else:
if header_length > self.max_field_size:
raise errors.LineTooLong(
'limit request headers fields size')
bvalue = bvalue.strip()
name = istr(bname.decode('utf-8', 'surrogateescape'))
value = bvalue.decode('utf-8', 'surrogateescape')
# keep-alive and encoding
if name == hdrs.CONNECTION:
v = value.lower()
if v == 'close':
close_conn = True
elif v == 'keep-alive':
close_conn = False
elif name == hdrs.CONTENT_ENCODING:
enc = value.lower()
if enc in ('gzip', 'deflate'):
encoding = enc
headers.add(name, value)
raw_headers.append((bname, bvalue))
return headers, raw_headers, close_conn, encoding
class HttpPrefixParser:
"""Waits for 'HTTP' prefix (non destructive)"""
def __init__(self, allowed_methods=()):
self.allowed_methods = [m.upper() for m in allowed_methods]
def __call__(self, out, buf):
raw_data = yield from buf.waituntil(b' ', 12)
method = raw_data.decode('ascii', 'surrogateescape').strip()
# method
method = method.upper()
if not METHRE.match(method):
raise errors.BadStatusLine(method)
# allowed method
if self.allowed_methods and method not in self.allowed_methods:
raise errors.HttpMethodNotAllowed(message=method)
out.feed_data(method, len(method))
out.feed_eof()
class HttpRequestParser(HttpParser):
"""Read request status line. Exception errors.BadStatusLine
could be raised in case of any errors in status line.
Returns RawRequestMessage.
"""
def __call__(self, out, buf):
# read HTTP message (request line + headers)
try:
raw_data = yield from buf.readuntil(
b'\r\n\r\n', self.max_headers)
except errors.LineLimitExceededParserError as exc:
raise errors.LineTooLong(exc.limit) from None
lines = raw_data.split(b'\r\n')
# request line
line = lines[0].decode('utf-8', 'surrogateescape')
try:
method, path, version = line.split(None, 2)
except ValueError:
raise errors.BadStatusLine(line) from None
# method
method = method.upper()
if not METHRE.match(method):
raise errors.BadStatusLine(method)
# version
try:
if version.startswith('HTTP/'):
n1, n2 = version[5:].split('.', 1)
version = HttpVersion(int(n1), int(n2))
else:
raise errors.BadStatusLine(version)
except:
raise errors.BadStatusLine(version)
# read headers
headers, raw_headers, close, compression = self.parse_headers(lines)
if close is None: # then the headers weren't set in the request
if version <= HttpVersion10: # HTTP 1.0 must asks to not close
close = True
else: # HTTP 1.1 must ask to close.
close = False
out.feed_data(
RawRequestMessage(
method, path, version, headers, raw_headers,
close, compression),
len(raw_data))
out.feed_eof()
class HttpResponseParser(HttpParser):
"""Read response status line and headers.
BadStatusLine could be raised in case of any errors in status line.
Returns RawResponseMessage"""
def __call__(self, out, buf):
# read HTTP message (response line + headers)
try:
raw_data = yield from buf.readuntil(
b'\r\n\r\n', self.max_line_size + self.max_headers)
except errors.LineLimitExceededParserError as exc:
raise errors.LineTooLong(exc.limit) from None
lines = raw_data.split(b'\r\n')
line = lines[0].decode('utf-8', 'surrogateescape')
try:
version, status = line.split(None, 1)
except ValueError:
raise errors.BadStatusLine(line) from None
else:
try:
status, reason = status.split(None, 1)
except ValueError:
reason = ''
# version
match = VERSRE.match(version)
if match is None:
raise errors.BadStatusLine(line)
version = HttpVersion(int(match.group(1)), int(match.group(2)))
# The status code is a three-digit number
try:
status = int(status)
except ValueError:
raise errors.BadStatusLine(line) from None
if status < 100 or status > 999:
raise errors.BadStatusLine(line)
# read headers
headers, raw_headers, close, compression = self.parse_headers(lines)
if close is None:
close = version <= HttpVersion10
out.feed_data(
RawResponseMessage(
version, status, reason.strip(),
headers, raw_headers, close, compression),
len(raw_data))
out.feed_eof()
class HttpPayloadParser:
def __init__(self, message, length=None, compression=True,
readall=False, response_with_body=True):
self.message = message
self.length = length
self.compression = compression
self.readall = readall
self.response_with_body = response_with_body
def __call__(self, out, buf):
# payload params
length = self.message.headers.get(hdrs.CONTENT_LENGTH, self.length)
if hdrs.SEC_WEBSOCKET_KEY1 in self.message.headers:
length = 8
# payload decompression wrapper
if (self.response_with_body and
self.compression and self.message.compression):
out = DeflateBuffer(out, self.message.compression)
# payload parser
if not self.response_with_body:
# don't parse payload if it's not expected to be received
pass
elif 'chunked' in self.message.headers.get(
hdrs.TRANSFER_ENCODING, ''):
yield from self.parse_chunked_payload(out, buf)
elif length is not None:
try:
length = int(length)
except ValueError:
raise errors.InvalidHeader(hdrs.CONTENT_LENGTH) from None
if length < 0:
raise errors.InvalidHeader(hdrs.CONTENT_LENGTH)
elif length > 0:
yield from self.parse_length_payload(out, buf, length)
else:
if self.readall and getattr(self.message, 'code', 0) != 204:
yield from self.parse_eof_payload(out, buf)
elif getattr(self.message, 'method', None) in ('PUT', 'POST'):
internal_logger.warning( # pragma: no cover
'Content-Length or Transfer-Encoding header is required')
out.feed_eof()
def parse_chunked_payload(self, out, buf):
"""Chunked transfer encoding parser."""
while True:
# read next chunk size
line = yield from buf.readuntil(b'\r\n', 8192)
i = line.find(b';')
if i >= 0:
line = line[:i] # strip chunk-extensions
else:
line = line.strip()
try:
size = int(line, 16)
except ValueError:
raise errors.TransferEncodingError(line) from None
if size == 0: # eof marker
break
# read chunk and feed buffer
while size:
chunk = yield from buf.readsome(size)
out.feed_data(chunk, len(chunk))
size = size - len(chunk)
# toss the CRLF at the end of the chunk
yield from buf.skip(2)
# read and discard trailer up to the CRLF terminator
yield from buf.skipuntil(b'\r\n')
def parse_length_payload(self, out, buf, length=0):
"""Read specified amount of bytes."""
required = length
while required:
chunk = yield from buf.readsome(required)
out.feed_data(chunk, len(chunk))
required -= len(chunk)
def parse_eof_payload(self, out, buf):
"""Read all bytes until eof."""
try:
while True:
chunk = yield from buf.readsome()
out.feed_data(chunk, len(chunk))
except aiohttp.EofStream:
pass
class DeflateBuffer:
"""DeflateStream decompress stream and feed data into specified stream."""
def __init__(self, out, encoding):
self.out = out
zlib_mode = (16 + zlib.MAX_WBITS
if encoding == 'gzip' else -zlib.MAX_WBITS)
self.zlib = zlib.decompressobj(wbits=zlib_mode)
def feed_data(self, chunk, size):
try:
chunk = self.zlib.decompress(chunk)
except Exception:
raise errors.ContentEncodingError('deflate')
if chunk:
self.out.feed_data(chunk, len(chunk))
def feed_eof(self):
chunk = self.zlib.flush()
self.out.feed_data(chunk, len(chunk))
if not self.zlib.eof:
raise errors.ContentEncodingError('deflate')
self.out.feed_eof()
def wrap_payload_filter(func):
"""Wraps payload filter and piped filters.
Filter is a generator that accepts arbitrary chunks of data,
modify data and emit new stream of data.
For example we have stream of chunks: ['1', '2', '3', '4', '5'],
we can apply chunking filter to this stream:
['1', '2', '3', '4', '5']
|
response.add_chunking_filter(2)
|
['12', '34', '5']
It is possible to use different filters at the same time.
For a example to compress incoming stream with 'deflate' encoding
and then split data and emit chunks of 8192 bytes size chunks:
>>> response.add_compression_filter('deflate')
>>> response.add_chunking_filter(8192)
Filters do not alter transfer encoding.
Filter can receive types types of data, bytes object or EOF_MARKER.
1. If filter receives bytes object, it should process data
and yield processed data then yield EOL_MARKER object.
2. If Filter received EOF_MARKER, it should yield remaining
data (buffered) and then yield EOF_MARKER.
"""
@functools.wraps(func)
def wrapper(self, *args, **kw):
new_filter = func(self, *args, **kw)
filter = self.filter
if filter is not None:
next(new_filter)
self.filter = filter_pipe(filter, new_filter)
else:
self.filter = new_filter
next(self.filter)
return wrapper
def filter_pipe(filter, filter2, *,
EOF_MARKER=EOF_MARKER, EOL_MARKER=EOL_MARKER):
"""Creates pipe between two filters.
filter_pipe() feeds first filter with incoming data and then
send yielded from first filter data into filter2, results of
filter2 are being emitted.
1. If filter_pipe receives bytes object, it sends it to the first filter.
2. Reads yielded values from the first filter until it receives
EOF_MARKER or EOL_MARKER.
3. Each of this values is being send to second filter.
4. Reads yielded values from second filter until it receives EOF_MARKER
or EOL_MARKER. Each of this values yields to writer.
"""
chunk = yield
while True:
eof = chunk is EOF_MARKER
chunk = filter.send(chunk)
while chunk is not EOL_MARKER:
chunk = filter2.send(chunk)
while chunk not in (EOF_MARKER, EOL_MARKER):
yield chunk
chunk = next(filter2)
if chunk is not EOF_MARKER:
if eof:
chunk = EOF_MARKER
else:
chunk = next(filter)
else:
break
chunk = yield EOL_MARKER
class HttpMessage(ABC):
"""HttpMessage allows to write headers and payload to a stream.
For example, lets say we want to read file then compress it with deflate
compression and then send it with chunked transfer encoding, code may look
like this:
>>> response = aiohttp.Response(transport, 200)
We have to use deflate compression first:
>>> response.add_compression_filter('deflate')
Then we want to split output stream into chunks of 1024 bytes size:
>>> response.add_chunking_filter(1024)
We can add headers to response with add_headers() method. add_headers()
does not send data to transport, send_headers() sends request/response
line and then sends headers:
>>> response.add_headers(
... ('Content-Disposition', 'attachment; filename="..."'))
>>> response.send_headers()
Now we can use chunked writer to write stream to a network stream.
First call to write() method sends response status line and headers,
add_header() and add_headers() method unavailable at this stage:
>>> with open('...', 'rb') as f:
... chunk = fp.read(8192)
... while chunk:
... response.write(chunk)
... chunk = fp.read(8192)
>>> response.write_eof()
"""
writer = None
# 'filter' is being used for altering write() behaviour,
# add_chunking_filter adds deflate/gzip compression and
# add_compression_filter splits incoming data into a chunks.
filter = None
HOP_HEADERS = None # Must be set by subclass.
SERVER_SOFTWARE = 'Python/{0[0]}.{0[1]} aiohttp/{1}'.format(
sys.version_info, aiohttp.__version__)
upgrade = False # Connection: UPGRADE
websocket = False # Upgrade: WEBSOCKET
has_chunked_hdr = False # Transfer-encoding: chunked
# subclass can enable auto sending headers with write() call,
# this is useful for wsgi's start_response implementation.
_send_headers = False
def __init__(self, transport, version, close):
self.transport = transport
self._version = version
self.closing = close
self.keepalive = None
self.chunked = False
self.length = None
self.headers = CIMultiDict()
self.headers_sent = False
self.output_length = 0
self.headers_length = 0
self._output_size = 0
@property
@abstractmethod
def status_line(self):
return b''
@abstractmethod
def autochunked(self):
return False
@property
def version(self):
return self._version
@property
def body_length(self):
return self.output_length - self.headers_length
def force_close(self):
self.closing = True
self.keepalive = False
def enable_chunked_encoding(self):
self.chunked = True
def keep_alive(self):
if self.keepalive is None:
if self.version < HttpVersion10:
# keep alive not supported at all
return False
if self.version == HttpVersion10:
if self.headers.get(hdrs.CONNECTION) == 'keep-alive':
return True
else: # no headers means we close for Http 1.0
return False
else:
return not self.closing
else:
return self.keepalive
def is_headers_sent(self):
return self.headers_sent
def add_header(self, name, value):
"""Analyze headers. Calculate content length,
removes hop headers, etc."""
assert not self.headers_sent, 'headers have been sent already'
assert isinstance(name, str), \
'Header name should be a string, got {!r}'.format(name)
assert set(name).issubset(ASCIISET), \
'Header name should contain ASCII chars, got {!r}'.format(name)
assert isinstance(value, str), \
'Header {!r} should have string value, got {!r}'.format(
name, value)
name = istr(name)
value = value.strip()
if name == hdrs.CONTENT_LENGTH:
self.length = int(value)
if name == hdrs.TRANSFER_ENCODING:
self.has_chunked_hdr = value.lower().strip() == 'chunked'
if name == hdrs.CONNECTION:
val = value.lower()
# handle websocket
if 'upgrade' in val:
self.upgrade = True
# connection keep-alive
elif 'close' in val:
self.keepalive = False
elif 'keep-alive' in val:
self.keepalive = True
elif name == hdrs.UPGRADE:
if 'websocket' in value.lower():
self.websocket = True
self.headers[name] = value
elif name not in self.HOP_HEADERS:
# ignore hop-by-hop headers
self.headers.add(name, value)
def add_headers(self, *headers):
"""Adds headers to a HTTP message."""
for name, value in headers:
self.add_header(name, value)
def send_headers(self, _sep=': ', _end='\r\n'):
"""Writes headers to a stream. Constructs payload writer."""
# Chunked response is only for HTTP/1.1 clients or newer
# and there is no Content-Length header is set.
# Do not use chunked responses when the response is guaranteed to
# not have a response body (304, 204).
assert not self.headers_sent, 'headers have been sent already'
self.headers_sent = True
if self.chunked or self.autochunked():
self.writer = self._write_chunked_payload()
self.headers[hdrs.TRANSFER_ENCODING] = 'chunked'
elif self.length is not None:
self.writer = self._write_length_payload(self.length)
else:
self.writer = self._write_eof_payload()
next(self.writer)
self._add_default_headers()
# status + headers
headers = self.status_line + ''.join(
[k + _sep + v + _end for k, v in self.headers.items()])
headers = headers.encode('utf-8') + b'\r\n'
self.output_length += len(headers)
self.headers_length = len(headers)
self.transport.write(headers)
def _add_default_headers(self):
# set the connection header
connection = None
if self.upgrade:
connection = 'upgrade'
elif not self.closing if self.keepalive is None else self.keepalive:
if self.version == HttpVersion10:
connection = 'keep-alive'
else:
if self.version == HttpVersion11:
connection = 'close'
if connection is not None:
self.headers[hdrs.CONNECTION] = connection
def write(self, chunk, *,
drain=False, EOF_MARKER=EOF_MARKER, EOL_MARKER=EOL_MARKER):
"""Writes chunk of data to a stream by using different writers.
writer uses filter to modify chunk of data.
write_eof() indicates end of stream.
writer can't be used after write_eof() method being called.
write() return drain future.
"""
assert (isinstance(chunk, (bytes, bytearray)) or
chunk is EOF_MARKER), chunk
size = self.output_length
if self._send_headers and not self.headers_sent:
self.send_headers()
assert self.writer is not None, 'send_headers() is not called.'
if self.filter:
chunk = self.filter.send(chunk)
while chunk not in (EOF_MARKER, EOL_MARKER):
if chunk:
self.writer.send(chunk)
chunk = next(self.filter)
else:
if chunk is not EOF_MARKER:
self.writer.send(chunk)
self._output_size += self.output_length - size
if self._output_size > 64 * 1024:
if drain:
self._output_size = 0
return self.transport.drain()
return ()
def write_eof(self):
self.write(EOF_MARKER)
try:
self.writer.throw(aiohttp.EofStream())
except StopIteration:
pass
return self.transport.drain()
def _write_chunked_payload(self):
"""Write data in chunked transfer encoding."""
while True:
try:
chunk = yield
except aiohttp.EofStream:
self.transport.write(b'0\r\n\r\n')
self.output_length += 5
break
chunk = bytes(chunk)
chunk_len = '{:x}\r\n'.format(len(chunk)).encode('ascii')
self.transport.write(chunk_len + chunk + b'\r\n')
self.output_length += len(chunk_len) + len(chunk) + 2
def _write_length_payload(self, length):
"""Write specified number of bytes to a stream."""
while True:
try:
chunk = yield
except aiohttp.EofStream:
break
if length:
l = len(chunk)
if length >= l:
self.transport.write(chunk)
self.output_length += l
length = length-l
else:
self.transport.write(chunk[:length])
self.output_length += length
length = 0
def _write_eof_payload(self):
while True:
try:
chunk = yield
except aiohttp.EofStream:
break
self.transport.write(chunk)
self.output_length += len(chunk)
@wrap_payload_filter
def add_chunking_filter(self, chunk_size=16*1024, *,
EOF_MARKER=EOF_MARKER, EOL_MARKER=EOL_MARKER):
"""Split incoming stream into chunks."""
buf = bytearray()
chunk = yield
while True:
if chunk is EOF_MARKER:
if buf:
yield buf
yield EOF_MARKER
else:
buf.extend(chunk)
while len(buf) >= chunk_size:
chunk = bytes(buf[:chunk_size])
del buf[:chunk_size]
yield chunk
chunk = yield EOL_MARKER
@wrap_payload_filter
def add_compression_filter(self, encoding='deflate', *,
EOF_MARKER=EOF_MARKER, EOL_MARKER=EOL_MARKER):
"""Compress incoming stream with deflate or gzip encoding."""
zlib_mode = (16 + zlib.MAX_WBITS
if encoding == 'gzip' else -zlib.MAX_WBITS)
zcomp = zlib.compressobj(wbits=zlib_mode)
chunk = yield
while True:
if chunk is EOF_MARKER:
yield zcomp.flush()
chunk = yield EOF_MARKER
else:
yield zcomp.compress(chunk)
chunk = yield EOL_MARKER
class Response(HttpMessage):
"""Create HTTP response message.
Transport is a socket stream transport. status is a response status code,
status has to be integer value. http_version is a tuple that represents
HTTP version, (1, 0) stands for HTTP/1.0 and (1, 1) is for HTTP/1.1
"""
HOP_HEADERS = ()
@staticmethod
def calc_reason(status, *, _RESPONSES=RESPONSES):
record = _RESPONSES.get(status)
if record is not None:
reason = record[0]
else:
reason = str(status)
return reason
def __init__(self, transport, status,
http_version=HttpVersion11, close=False, reason=None):
super().__init__(transport, http_version, close)
self._status = status
if reason is None:
reason = self.calc_reason(status)
self._reason = reason
@property
def status(self):
return self._status
@property
def reason(self):
return self._reason
@reify
def status_line(self):
version = self.version
return 'HTTP/{}.{} {} {}\r\n'.format(
version[0], version[1], self.status, self.reason)
def autochunked(self):
return (self.length is None and
self.version >= HttpVersion11)
def _add_default_headers(self):
super()._add_default_headers()
if hdrs.DATE not in self.headers:
# format_date_time(None) is quite expensive
self.headers.setdefault(hdrs.DATE, format_date_time(None))
self.headers.setdefault(hdrs.SERVER, self.SERVER_SOFTWARE)
class Request(HttpMessage):
HOP_HEADERS = ()
def __init__(self, transport, method, path,
http_version=HttpVersion11, close=False):
# set the default for HTTP 0.9 to be different
# will only be overwritten with keep-alive header
if http_version < HttpVersion10:
close = True
super().__init__(transport, http_version, close)
self._method = method
self._path = path
@property
def method(self):
return self._method
@property
def path(self):
return self._path
@reify
def status_line(self):
return '{0} {1} HTTP/{2[0]}.{2[1]}\r\n'.format(
self.method, self.path, self.version)
def autochunked(self):
return (self.length is None and
self.version >= HttpVersion11 and
self.status not in (304, 204))

View File

@ -0,0 +1,113 @@
import asyncio
import contextlib
import pytest
from aiohttp.web import Application
from .test_utils import unused_port as _unused_port
from .test_utils import (TestClient, TestServer, loop_context, setup_test_loop,
teardown_test_loop)
@contextlib.contextmanager
def _passthrough_loop_context(loop):
if loop:
# loop already exists, pass it straight through
yield loop
else:
# this shadows loop_context's standard behavior
loop = setup_test_loop()
yield loop
teardown_test_loop(loop)
def pytest_pycollect_makeitem(collector, name, obj):
"""
Fix pytest collecting for coroutines.
"""
if collector.funcnamefilter(name) and asyncio.iscoroutinefunction(obj):
return list(collector._genfunctions(name, obj))
def pytest_pyfunc_call(pyfuncitem):
"""
Run coroutines in an event loop instead of a normal function call.
"""
if asyncio.iscoroutinefunction(pyfuncitem.function):
existing_loop = pyfuncitem.funcargs.get('loop', None)
with _passthrough_loop_context(existing_loop) as _loop:
testargs = {arg: pyfuncitem.funcargs[arg]
for arg in pyfuncitem._fixtureinfo.argnames}
task = _loop.create_task(pyfuncitem.obj(**testargs))
_loop.run_until_complete(task)
return True
@pytest.yield_fixture
def loop():
with loop_context() as _loop:
yield _loop
@pytest.fixture
def unused_port():
return _unused_port
@pytest.yield_fixture
def test_server(loop):
servers = []
@asyncio.coroutine
def go(app, **kwargs):
assert app.loop is loop, \
"Application is attached to other event loop"
server = TestServer(app)
yield from server.start_server(**kwargs)
servers.append(server)
return server
yield go
@asyncio.coroutine
def finalize():
while servers:
yield from servers.pop().close()
loop.run_until_complete(finalize())
@pytest.yield_fixture
def test_client(loop):
clients = []
@asyncio.coroutine
def go(__param, *args, **kwargs):
if isinstance(__param, Application):
assert not args, "args should be empty"
assert not kwargs, "kwargs should be empty"
assert __param.loop is loop, \
"Application is attached to other event loop"
elif isinstance(__param, TestServer):
assert __param.app.loop is loop, \
"TestServer is attached to other event loop"
else:
__param = __param(loop, *args, **kwargs)
client = TestClient(__param)
yield from client.start_server()
clients.append(client)
return client
yield go
@asyncio.coroutine
def finalize():
while clients:
yield from clients.pop().close()
loop.run_until_complete(finalize())

View File

@ -0,0 +1,100 @@
import asyncio
import socket
from .abc import AbstractResolver
__all__ = ('ThreadedResolver', 'AsyncResolver', 'DefaultResolver')
try:
import aiodns
aiodns_default = hasattr(aiodns.DNSResolver, 'gethostbyname')
except ImportError: # pragma: no cover
aiodns = None
aiodns_default = False
class ThreadedResolver(AbstractResolver):
"""Use Executor for synchronous getaddrinfo() calls, which defaults to
concurrent.futures.ThreadPoolExecutor.
"""
def __init__(self, loop=None):
if loop is None:
loop = asyncio.get_event_loop()
self._loop = loop
@asyncio.coroutine
def resolve(self, host, port=0, family=socket.AF_INET):
infos = yield from self._loop.getaddrinfo(
host, port, type=socket.SOCK_STREAM, family=family)
hosts = []
for family, _, proto, _, address in infos:
hosts.append(
{'hostname': host,
'host': address[0], 'port': address[1],
'family': family, 'proto': proto,
'flags': socket.AI_NUMERICHOST})
return hosts
@asyncio.coroutine
def close(self):
pass
class AsyncResolver(AbstractResolver):
"""Use the `aiodns` package to make asynchronous DNS lookups"""
def __init__(self, loop=None, *args, **kwargs):
if loop is None:
loop = asyncio.get_event_loop()
if aiodns is None:
raise RuntimeError("Resolver requires aiodns library")
self._loop = loop
self._resolver = aiodns.DNSResolver(*args, loop=loop, **kwargs)
if not hasattr(self._resolver, 'gethostbyname'):
# aiodns 1.1 is not available, fallback to DNSResolver.query
self.resolve = self.resolve_with_query
@asyncio.coroutine
def resolve(self, host, port=0, family=socket.AF_INET):
hosts = []
resp = yield from self._resolver.gethostbyname(host, family)
for address in resp.addresses:
hosts.append(
{'hostname': host,
'host': address, 'port': port,
'family': family, 'proto': 0,
'flags': socket.AI_NUMERICHOST})
return hosts
@asyncio.coroutine
def resolve_with_query(self, host, port=0, family=socket.AF_INET):
if family == socket.AF_INET6:
qtype = 'AAAA'
else:
qtype = 'A'
hosts = []
resp = yield from self._resolver.query(host, qtype)
for rr in resp:
hosts.append(
{'hostname': host,
'host': rr.host, 'port': port,
'family': family, 'proto': 0,
'flags': socket.AI_NUMERICHOST})
return hosts
@asyncio.coroutine
def close(self):
return self._resolver.cancel()
DefaultResolver = AsyncResolver if aiodns_default else ThreadedResolver

View File

@ -0,0 +1,376 @@
"""simple HTTP server."""
import asyncio
import http.server
import socket
import traceback
import warnings
from contextlib import suppress
from html import escape as html_escape
import aiohttp
from aiohttp import errors, hdrs, helpers, streams
from aiohttp.helpers import Timeout, _get_kwarg, ensure_future
from aiohttp.log import access_logger, server_logger
__all__ = ('ServerHttpProtocol',)
RESPONSES = http.server.BaseHTTPRequestHandler.responses
DEFAULT_ERROR_MESSAGE = """
<html>
<head>
<title>{status} {reason}</title>
</head>
<body>
<h1>{status} {reason}</h1>
{message}
</body>
</html>"""
if hasattr(socket, 'SO_KEEPALIVE'):
def tcp_keepalive(server, transport):
sock = transport.get_extra_info('socket')
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
else:
def tcp_keepalive(server, transport): # pragma: no cover
pass
EMPTY_PAYLOAD = streams.EmptyStreamReader()
class ServerHttpProtocol(aiohttp.StreamProtocol):
"""Simple HTTP protocol implementation.
ServerHttpProtocol handles incoming HTTP request. It reads request line,
request headers and request payload and calls handle_request() method.
By default it always returns with 404 response.
ServerHttpProtocol handles errors in incoming request, like bad
status line, bad headers or incomplete payload. If any error occurs,
connection gets closed.
:param keepalive_timeout: number of seconds before closing
keep-alive connection
:type keepalive_timeout: int or None
:param bool tcp_keepalive: TCP keep-alive is on, default is on
:param int slow_request_timeout: slow request timeout
:param bool debug: enable debug mode
:param logger: custom logger object
:type logger: aiohttp.log.server_logger
:param access_log: custom logging object
:type access_log: aiohttp.log.server_logger
:param str access_log_format: access log format string
:param loop: Optional event loop
:param int max_line_size: Optional maximum header line size
:param int max_field_size: Optional maximum header field size
:param int max_headers: Optional maximum header size
"""
_request_count = 0
_request_handler = None
_reading_request = False
_keepalive = False # keep transport open
def __init__(self, *, loop=None,
keepalive_timeout=75, # NGINX default value is 75 secs
tcp_keepalive=True,
slow_request_timeout=0,
logger=server_logger,
access_log=access_logger,
access_log_format=helpers.AccessLogger.LOG_FORMAT,
debug=False,
max_line_size=8190,
max_headers=32768,
max_field_size=8190,
**kwargs):
# process deprecated params
logger = _get_kwarg(kwargs, 'log', 'logger', logger)
tcp_keepalive = _get_kwarg(kwargs, 'keep_alive_on',
'tcp_keepalive', tcp_keepalive)
keepalive_timeout = _get_kwarg(kwargs, 'keep_alive',
'keepalive_timeout', keepalive_timeout)
slow_request_timeout = _get_kwarg(kwargs, 'timeout',
'slow_request_timeout',
slow_request_timeout)
super().__init__(
loop=loop,
disconnect_error=errors.ClientDisconnectedError, **kwargs)
self._tcp_keepalive = tcp_keepalive
self._keepalive_timeout = keepalive_timeout
self._slow_request_timeout = slow_request_timeout
self._loop = loop if loop is not None else asyncio.get_event_loop()
self._request_prefix = aiohttp.HttpPrefixParser()
self._request_parser = aiohttp.HttpRequestParser(
max_line_size=max_line_size,
max_field_size=max_field_size,
max_headers=max_headers)
self.logger = logger
self.debug = debug
self.access_log = access_log
if access_log:
self.access_logger = helpers.AccessLogger(access_log,
access_log_format)
else:
self.access_logger = None
self._closing = False
@property
def keep_alive_timeout(self):
warnings.warn("Use keepalive_timeout property instead",
DeprecationWarning,
stacklevel=2)
return self._keepalive_timeout
@property
def keepalive_timeout(self):
return self._keepalive_timeout
@asyncio.coroutine
def shutdown(self, timeout=15.0):
"""Worker process is about to exit, we need cleanup everything and
stop accepting requests. It is especially important for keep-alive
connections."""
if self._request_handler is None:
return
self._closing = True
if timeout:
canceller = self._loop.call_later(timeout,
self._request_handler.cancel)
with suppress(asyncio.CancelledError):
yield from self._request_handler
canceller.cancel()
else:
self._request_handler.cancel()
def connection_made(self, transport):
super().connection_made(transport)
self._request_handler = ensure_future(self.start(), loop=self._loop)
if self._tcp_keepalive:
tcp_keepalive(self, transport)
def connection_lost(self, exc):
super().connection_lost(exc)
self._closing = True
if self._request_handler is not None:
self._request_handler.cancel()
def data_received(self, data):
super().data_received(data)
# reading request
if not self._reading_request:
self._reading_request = True
def keep_alive(self, val):
"""Set keep-alive connection mode.
:param bool val: new state.
"""
self._keepalive = val
def log_access(self, message, environ, response, time):
if self.access_logger:
self.access_logger.log(message, environ, response,
self.transport, time)
def log_debug(self, *args, **kw):
if self.debug:
self.logger.debug(*args, **kw)
def log_exception(self, *args, **kw):
self.logger.exception(*args, **kw)
@asyncio.coroutine
def start(self):
"""Start processing of incoming requests.
It reads request line, request headers and request payload, then
calls handle_request() method. Subclass has to override
handle_request(). start() handles various exceptions in request
or response handling. Connection is being closed always unless
keep_alive(True) specified.
"""
reader = self.reader
try:
while not self._closing:
message = None
self._keepalive = False
self._request_count += 1
self._reading_request = False
payload = None
with Timeout(max(self._slow_request_timeout,
self._keepalive_timeout),
loop=self._loop):
# read HTTP request method
prefix = reader.set_parser(self._request_prefix)
yield from prefix.read()
# start reading request
self._reading_request = True
# start slow request timer
# read request headers
httpstream = reader.set_parser(self._request_parser)
message = yield from httpstream.read()
# request may not have payload
try:
content_length = int(
message.headers.get(hdrs.CONTENT_LENGTH, 0))
except ValueError:
raise errors.InvalidHeader(hdrs.CONTENT_LENGTH) from None
if (content_length > 0 or
message.method == 'CONNECT' or
hdrs.SEC_WEBSOCKET_KEY1 in message.headers or
'chunked' in message.headers.get(
hdrs.TRANSFER_ENCODING, '')):
payload = streams.FlowControlStreamReader(
reader, loop=self._loop)
reader.set_parser(
aiohttp.HttpPayloadParser(message), payload)
else:
payload = EMPTY_PAYLOAD
yield from self.handle_request(message, payload)
if payload and not payload.is_eof():
self.log_debug('Uncompleted request.')
self._closing = True
else:
reader.unset_parser()
if not self._keepalive or not self._keepalive_timeout:
self._closing = True
except asyncio.CancelledError:
self.log_debug(
'Request handler cancelled.')
return
except asyncio.TimeoutError:
self.log_debug(
'Request handler timed out.')
return
except errors.ClientDisconnectedError:
self.log_debug(
'Ignored premature client disconnection #1.')
return
except errors.HttpProcessingError as exc:
yield from self.handle_error(exc.code, message,
None, exc, exc.headers,
exc.message)
except Exception as exc:
yield from self.handle_error(500, message, None, exc)
finally:
self._request_handler = None
if self.transport is None:
self.log_debug(
'Ignored premature client disconnection #2.')
else:
self.transport.close()
def handle_error(self, status=500, message=None,
payload=None, exc=None, headers=None, reason=None):
"""Handle errors.
Returns HTTP response with specific status code. Logs additional
information. It always closes current connection."""
now = self._loop.time()
try:
if self.transport is None:
# client has been disconnected during writing.
return ()
if status == 500:
self.log_exception("Error handling request")
try:
if reason is None or reason == '':
reason, msg = RESPONSES[status]
else:
msg = reason
except KeyError:
status = 500
reason, msg = '???', ''
if self.debug and exc is not None:
try:
tb = traceback.format_exc()
tb = html_escape(tb)
msg += '<br><h2>Traceback:</h2>\n<pre>{}</pre>'.format(tb)
except:
pass
html = DEFAULT_ERROR_MESSAGE.format(
status=status, reason=reason, message=msg).encode('utf-8')
response = aiohttp.Response(self.writer, status, close=True)
response.add_header(hdrs.CONTENT_TYPE, 'text/html; charset=utf-8')
response.add_header(hdrs.CONTENT_LENGTH, str(len(html)))
if headers is not None:
for name, value in headers:
response.add_header(name, value)
response.send_headers()
response.write(html)
# disable CORK, enable NODELAY if needed
self.writer.set_tcp_nodelay(True)
drain = response.write_eof()
self.log_access(message, None, response, self._loop.time() - now)
return drain
finally:
self.keep_alive(False)
def handle_request(self, message, payload):
"""Handle a single HTTP request.
Subclass should override this method. By default it always
returns 404 response.
:param message: Request headers
:type message: aiohttp.protocol.HttpRequestParser
:param payload: Request payload
:type payload: aiohttp.streams.FlowControlStreamReader
"""
now = self._loop.time()
response = aiohttp.Response(
self.writer, 404, http_version=message.version, close=True)
body = b'Page Not Found!'
response.add_header(hdrs.CONTENT_TYPE, 'text/plain')
response.add_header(hdrs.CONTENT_LENGTH, str(len(body)))
response.send_headers()
response.write(body)
drain = response.write_eof()
self.keep_alive(False)
self.log_access(message, None, response, self._loop.time() - now)
return drain

View File

@ -0,0 +1,71 @@
import asyncio
from itertools import count
class BaseSignal(list):
@asyncio.coroutine
def _send(self, *args, **kwargs):
for receiver in self:
res = receiver(*args, **kwargs)
if asyncio.iscoroutine(res) or isinstance(res, asyncio.Future):
yield from res
def copy(self):
raise NotImplementedError("copy() is forbidden")
def sort(self):
raise NotImplementedError("sort() is forbidden")
class Signal(BaseSignal):
"""Coroutine-based signal implementation.
To connect a callback to a signal, use any list method.
Signals are fired using the :meth:`send` coroutine, which takes named
arguments.
"""
def __init__(self, app):
super().__init__()
self._app = app
klass = self.__class__
self._name = klass.__module__ + ':' + klass.__qualname__
self._pre = app.on_pre_signal
self._post = app.on_post_signal
@asyncio.coroutine
def send(self, *args, **kwargs):
"""
Sends data to all registered receivers.
"""
ordinal = None
debug = self._app._debug
if debug:
ordinal = self._pre.ordinal()
yield from self._pre.send(ordinal, self._name, *args, **kwargs)
yield from self._send(*args, **kwargs)
if debug:
yield from self._post.send(ordinal, self._name, *args, **kwargs)
class DebugSignal(BaseSignal):
@asyncio.coroutine
def send(self, ordinal, name, *args, **kwargs):
yield from self._send(ordinal, name, *args, **kwargs)
class PreSignal(DebugSignal):
def __init__(self):
super().__init__()
self._counter = count(1)
def ordinal(self):
return next(self._counter)
class PostSignal(DebugSignal):
pass

View File

@ -0,0 +1,672 @@
import asyncio
import collections
import functools
import sys
import traceback
from . import helpers
from .log import internal_logger
__all__ = (
'EofStream', 'StreamReader', 'DataQueue', 'ChunksQueue',
'FlowControlStreamReader',
'FlowControlDataQueue', 'FlowControlChunksQueue')
PY_35 = sys.version_info >= (3, 5)
PY_352 = sys.version_info >= (3, 5, 2)
EOF_MARKER = b''
DEFAULT_LIMIT = 2 ** 16
class EofStream(Exception):
"""eof stream indication."""
if PY_35:
class AsyncStreamIterator:
def __init__(self, read_func):
self.read_func = read_func
def __aiter__(self):
return self
if not PY_352: # pragma: no cover
__aiter__ = asyncio.coroutine(__aiter__)
@asyncio.coroutine
def __anext__(self):
try:
rv = yield from self.read_func()
except EofStream:
raise StopAsyncIteration # NOQA
if rv == EOF_MARKER:
raise StopAsyncIteration # NOQA
return rv
class AsyncStreamReaderMixin:
if PY_35:
def __aiter__(self):
return AsyncStreamIterator(self.readline)
if not PY_352: # pragma: no cover
__aiter__ = asyncio.coroutine(__aiter__)
def iter_chunked(self, n):
"""Returns an asynchronous iterator that yields chunks of size n.
Python-3.5 available for Python 3.5+ only
"""
return AsyncStreamIterator(lambda: self.read(n))
def iter_any(self):
"""Returns an asynchronous iterator that yields slices of data
as they come.
Python-3.5 available for Python 3.5+ only
"""
return AsyncStreamIterator(self.readany)
class StreamReader(AsyncStreamReaderMixin):
"""An enhancement of asyncio.StreamReader.
Supports asynchronous iteration by line, chunk or as available::
async for line in reader:
...
async for chunk in reader.iter_chunked(1024):
...
async for slice in reader.iter_any():
...
"""
total_bytes = 0
def __init__(self, limit=DEFAULT_LIMIT, timeout=None, loop=None):
self._limit = limit
if loop is None:
loop = asyncio.get_event_loop()
self._loop = loop
self._buffer = collections.deque()
self._buffer_size = 0
self._buffer_offset = 0
self._eof = False
self._waiter = None
self._canceller = None
self._eof_waiter = None
self._exception = None
self._timeout = timeout
def __repr__(self):
info = ['StreamReader']
if self._buffer_size:
info.append('%d bytes' % self._buffer_size)
if self._eof:
info.append('eof')
if self._limit != DEFAULT_LIMIT:
info.append('l=%d' % self._limit)
if self._waiter:
info.append('w=%r' % self._waiter)
if self._exception:
info.append('e=%r' % self._exception)
return '<%s>' % ' '.join(info)
def exception(self):
return self._exception
def set_exception(self, exc):
self._exception = exc
waiter = self._waiter
if waiter is not None:
self._waiter = None
if not waiter.cancelled():
waiter.set_exception(exc)
canceller = self._canceller
if canceller is not None:
self._canceller = None
canceller.cancel()
def feed_eof(self):
self._eof = True
waiter = self._waiter
if waiter is not None:
self._waiter = None
if not waiter.cancelled():
waiter.set_result(True)
canceller = self._canceller
if canceller is not None:
self._canceller = None
canceller.cancel()
waiter = self._eof_waiter
if waiter is not None:
self._eof_waiter = None
if not waiter.cancelled():
waiter.set_result(True)
def is_eof(self):
"""Return True if 'feed_eof' was called."""
return self._eof
def at_eof(self):
"""Return True if the buffer is empty and 'feed_eof' was called."""
return self._eof and not self._buffer
@asyncio.coroutine
def wait_eof(self):
if self._eof:
return
assert self._eof_waiter is None
self._eof_waiter = helpers.create_future(self._loop)
try:
yield from self._eof_waiter
finally:
self._eof_waiter = None
def unread_data(self, data):
""" rollback reading some data from stream, inserting it to buffer head.
"""
if not data:
return
if self._buffer_offset:
self._buffer[0] = self._buffer[0][self._buffer_offset:]
self._buffer_offset = 0
self._buffer.appendleft(data)
self._buffer_size += len(data)
def feed_data(self, data):
assert not self._eof, 'feed_data after feed_eof'
if not data:
return
self._buffer.append(data)
self._buffer_size += len(data)
self.total_bytes += len(data)
waiter = self._waiter
if waiter is not None:
self._waiter = None
if not waiter.cancelled():
waiter.set_result(False)
canceller = self._canceller
if canceller is not None:
self._canceller = None
canceller.cancel()
@asyncio.coroutine
def _wait(self, func_name):
# StreamReader uses a future to link the protocol feed_data() method
# to a read coroutine. Running two read coroutines at the same time
# would have an unexpected behaviour. It would not possible to know
# which coroutine would get the next data.
if self._waiter is not None:
raise RuntimeError('%s() called while another coroutine is '
'already waiting for incoming data' % func_name)
waiter = self._waiter = helpers.create_future(self._loop)
if self._timeout:
self._canceller = self._loop.call_later(self._timeout,
self.set_exception,
asyncio.TimeoutError())
try:
yield from waiter
finally:
self._waiter = None
if self._canceller is not None:
self._canceller.cancel()
self._canceller = None
@asyncio.coroutine
def readline(self):
if self._exception is not None:
raise self._exception
line = []
line_size = 0
not_enough = True
while not_enough:
while self._buffer and not_enough:
offset = self._buffer_offset
ichar = self._buffer[0].find(b'\n', offset) + 1
# Read from current offset to found b'\n' or to the end.
data = self._read_nowait_chunk(ichar - offset if ichar else -1)
line.append(data)
line_size += len(data)
if ichar:
not_enough = False
if line_size > self._limit:
raise ValueError('Line is too long')
if self._eof:
break
if not_enough:
yield from self._wait('readline')
return b''.join(line)
@asyncio.coroutine
def read(self, n=-1):
if self._exception is not None:
raise self._exception
# migration problem; with DataQueue you have to catch
# EofStream exception, so common way is to run payload.read() inside
# infinite loop. what can cause real infinite loop with StreamReader
# lets keep this code one major release.
if __debug__:
if self._eof and not self._buffer:
self._eof_counter = getattr(self, '_eof_counter', 0) + 1
if self._eof_counter > 5:
stack = traceback.format_stack()
internal_logger.warning(
'Multiple access to StreamReader in eof state, '
'might be infinite loop: \n%s', stack)
if not n:
return EOF_MARKER
if n < 0:
# This used to just loop creating a new waiter hoping to
# collect everything in self._buffer, but that would
# deadlock if the subprocess sends more than self.limit
# bytes. So just call self.readany() until EOF.
blocks = []
while True:
block = yield from self.readany()
if not block:
break
blocks.append(block)
return b''.join(blocks)
if not self._buffer and not self._eof:
yield from self._wait('read')
return self._read_nowait(n)
@asyncio.coroutine
def readany(self):
if self._exception is not None:
raise self._exception
if not self._buffer and not self._eof:
yield from self._wait('readany')
return self._read_nowait(-1)
@asyncio.coroutine
def readexactly(self, n):
if self._exception is not None:
raise self._exception
blocks = []
while n > 0:
block = yield from self.read(n)
if not block:
partial = b''.join(blocks)
raise asyncio.streams.IncompleteReadError(
partial, len(partial) + n)
blocks.append(block)
n -= len(block)
return b''.join(blocks)
def read_nowait(self, n=-1):
# default was changed to be consistent with .read(-1)
#
# I believe the most users don't know about the method and
# they are not affected.
assert n is not None, "n should be -1"
if self._exception is not None:
raise self._exception
if self._waiter and not self._waiter.done():
raise RuntimeError(
'Called while some coroutine is waiting for incoming data.')
return self._read_nowait(n)
def _read_nowait_chunk(self, n):
first_buffer = self._buffer[0]
offset = self._buffer_offset
if n != -1 and len(first_buffer) - offset > n:
data = first_buffer[offset:offset + n]
self._buffer_offset += n
elif offset:
self._buffer.popleft()
data = first_buffer[offset:]
self._buffer_offset = 0
else:
data = self._buffer.popleft()
self._buffer_size -= len(data)
return data
def _read_nowait(self, n):
chunks = []
while self._buffer:
chunk = self._read_nowait_chunk(n)
chunks.append(chunk)
if n != -1:
n -= len(chunk)
if n == 0:
break
return b''.join(chunks) if chunks else EOF_MARKER
class EmptyStreamReader(AsyncStreamReaderMixin):
def exception(self):
return None
def set_exception(self, exc):
pass
def feed_eof(self):
pass
def is_eof(self):
return True
def at_eof(self):
return True
@asyncio.coroutine
def wait_eof(self):
return
def feed_data(self, data):
pass
@asyncio.coroutine
def readline(self):
return EOF_MARKER
@asyncio.coroutine
def read(self, n=-1):
return EOF_MARKER
@asyncio.coroutine
def readany(self):
return EOF_MARKER
@asyncio.coroutine
def readexactly(self, n):
raise asyncio.streams.IncompleteReadError(b'', n)
def read_nowait(self):
return EOF_MARKER
class DataQueue:
"""DataQueue is a general-purpose blocking queue with one reader."""
def __init__(self, *, loop=None):
self._loop = loop
self._eof = False
self._waiter = None
self._exception = None
self._size = 0
self._buffer = collections.deque()
def is_eof(self):
return self._eof
def at_eof(self):
return self._eof and not self._buffer
def exception(self):
return self._exception
def set_exception(self, exc):
self._exception = exc
waiter = self._waiter
if waiter is not None:
self._waiter = None
if not waiter.done():
waiter.set_exception(exc)
def feed_data(self, data, size=0):
self._size += size
self._buffer.append((data, size))
waiter = self._waiter
if waiter is not None:
self._waiter = None
if not waiter.cancelled():
waiter.set_result(True)
def feed_eof(self):
self._eof = True
waiter = self._waiter
if waiter is not None:
self._waiter = None
if not waiter.cancelled():
waiter.set_result(False)
@asyncio.coroutine
def read(self):
if not self._buffer and not self._eof:
if self._exception is not None:
raise self._exception
assert not self._waiter
self._waiter = helpers.create_future(self._loop)
try:
yield from self._waiter
except (asyncio.CancelledError, asyncio.TimeoutError):
self._waiter = None
raise
if self._buffer:
data, size = self._buffer.popleft()
self._size -= size
return data
else:
if self._exception is not None:
raise self._exception
else:
raise EofStream
if PY_35:
def __aiter__(self):
return AsyncStreamIterator(self.read)
if not PY_352: # pragma: no cover
__aiter__ = asyncio.coroutine(__aiter__)
class ChunksQueue(DataQueue):
"""Like a :class:`DataQueue`, but for binary chunked data transfer."""
@asyncio.coroutine
def read(self):
try:
return (yield from super().read())
except EofStream:
return EOF_MARKER
readany = read
def maybe_resume(func):
if asyncio.iscoroutinefunction(func):
@asyncio.coroutine
@functools.wraps(func)
def wrapper(self, *args, **kw):
result = yield from func(self, *args, **kw)
self._check_buffer_size()
return result
else:
@functools.wraps(func)
def wrapper(self, *args, **kw):
result = func(self, *args, **kw)
self._check_buffer_size()
return result
return wrapper
class FlowControlStreamReader(StreamReader):
def __init__(self, stream, limit=DEFAULT_LIMIT, *args, **kwargs):
super().__init__(*args, **kwargs)
self._stream = stream
self._b_limit = limit * 2
# resume transport reading
if stream.paused:
try:
self._stream.transport.resume_reading()
except (AttributeError, NotImplementedError):
pass
else:
self._stream.paused = False
def _check_buffer_size(self):
if self._stream.paused:
if self._buffer_size < self._b_limit:
try:
self._stream.transport.resume_reading()
except (AttributeError, NotImplementedError):
pass
else:
self._stream.paused = False
else:
if self._buffer_size > self._b_limit:
try:
self._stream.transport.pause_reading()
except (AttributeError, NotImplementedError):
pass
else:
self._stream.paused = True
def feed_data(self, data, size=0):
has_waiter = self._waiter is not None and not self._waiter.cancelled()
super().feed_data(data)
if (not self._stream.paused and
not has_waiter and self._buffer_size > self._b_limit):
try:
self._stream.transport.pause_reading()
except (AttributeError, NotImplementedError):
pass
else:
self._stream.paused = True
@maybe_resume
@asyncio.coroutine
def read(self, n=-1):
return (yield from super().read(n))
@maybe_resume
@asyncio.coroutine
def readline(self):
return (yield from super().readline())
@maybe_resume
@asyncio.coroutine
def readany(self):
return (yield from super().readany())
@maybe_resume
@asyncio.coroutine
def readexactly(self, n):
return (yield from super().readexactly(n))
@maybe_resume
def read_nowait(self, n=-1):
return super().read_nowait(n)
class FlowControlDataQueue(DataQueue):
"""FlowControlDataQueue resumes and pauses an underlying stream.
It is a destination for parsed data."""
def __init__(self, stream, *, limit=DEFAULT_LIMIT, loop=None):
super().__init__(loop=loop)
self._stream = stream
self._limit = limit * 2
# resume transport reading
if stream.paused:
try:
self._stream.transport.resume_reading()
except (AttributeError, NotImplementedError):
pass
else:
self._stream.paused = False
def feed_data(self, data, size):
has_waiter = self._waiter is not None and not self._waiter.cancelled()
super().feed_data(data, size)
if (not self._stream.paused and
not has_waiter and self._size > self._limit):
try:
self._stream.transport.pause_reading()
except (AttributeError, NotImplementedError):
pass
else:
self._stream.paused = True
@asyncio.coroutine
def read(self):
result = yield from super().read()
if self._stream.paused:
if self._size < self._limit:
try:
self._stream.transport.resume_reading()
except (AttributeError, NotImplementedError):
pass
else:
self._stream.paused = False
else:
if self._size > self._limit:
try:
self._stream.transport.pause_reading()
except (AttributeError, NotImplementedError):
pass
else:
self._stream.paused = True
return result
class FlowControlChunksQueue(FlowControlDataQueue):
@asyncio.coroutine
def read(self):
try:
return (yield from super().read())
except EofStream:
return EOF_MARKER
readany = read

View File

@ -0,0 +1,485 @@
"""Utilities shared by tests."""
import asyncio
import contextlib
import functools
import gc
import socket
import sys
import unittest
from unittest import mock
from multidict import CIMultiDict
import aiohttp
from . import ClientSession, hdrs
from .helpers import sentinel
from .protocol import HttpVersion, RawRequestMessage
from .signals import Signal
from .web import Application, Request
PY_35 = sys.version_info >= (3, 5)
def run_briefly(loop):
@asyncio.coroutine
def once():
pass
t = asyncio.Task(once(), loop=loop)
loop.run_until_complete(t)
def unused_port():
"""Return a port that is unused on the current host."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('127.0.0.1', 0))
return s.getsockname()[1]
class TestServer:
def __init__(self, app, *, scheme="http", host='127.0.0.1'):
self.app = app
self._loop = app.loop
self.port = None
self.server = None
self.handler = None
self._root = None
self.host = host
self.scheme = scheme
self._closed = False
@asyncio.coroutine
def start_server(self, **kwargs):
if self.server:
return
self.port = unused_port()
self._root = '{}://{}:{}'.format(self.scheme, self.host, self.port)
self.handler = self.app.make_handler(**kwargs)
self.server = yield from self._loop.create_server(self.handler,
self.host,
self.port)
def make_url(self, path):
return self._root + path
@asyncio.coroutine
def close(self):
"""Close all fixtures created by the test client.
After that point, the TestClient is no longer usable.
This is an idempotent function: running close multiple times
will not have any additional effects.
close is also run when the object is garbage collected, and on
exit when used as a context manager.
"""
if self.server is not None and not self._closed:
self.server.close()
yield from self.server.wait_closed()
yield from self.app.shutdown()
yield from self.handler.finish_connections()
yield from self.app.cleanup()
self._root = None
self.port = None
self._closed = True
def __enter__(self):
self._loop.run_until_complete(self.start_server())
return self
def __exit__(self, exc_type, exc_value, traceback):
self._loop.run_until_complete(self.close())
if PY_35:
@asyncio.coroutine
def __aenter__(self):
yield from self.start_server()
return self
@asyncio.coroutine
def __aexit__(self, exc_type, exc_value, traceback):
yield from self.close()
class TestClient:
"""
A test client implementation, for a aiohttp.web.Application.
:param app: the aiohttp.web application passed to create_test_server
:type app: aiohttp.web.Application
:param protocol: http or https
:type protocol: str
TestClient can also be used as a contextmanager, returning
the instance of itself instantiated.
"""
def __init__(self, app_or_server, *, scheme=sentinel, host=sentinel):
if isinstance(app_or_server, TestServer):
if scheme is not sentinel or host is not sentinel:
raise ValueError("scheme and host are mutable exclusive "
"with TestServer parameter")
self._server = app_or_server
elif isinstance(app_or_server, Application):
scheme = "http" if scheme is sentinel else scheme
host = '127.0.0.1' if host is sentinel else host
self._server = TestServer(app_or_server,
scheme=scheme, host=host)
else:
raise TypeError("app_or_server should be either web.Application "
"or TestServer instance")
self._loop = self._server.app.loop
self._session = ClientSession(
loop=self._loop,
cookie_jar=aiohttp.CookieJar(unsafe=True,
loop=self._loop))
self._closed = False
self._responses = []
@asyncio.coroutine
def start_server(self):
yield from self._server.start_server()
@property
def app(self):
return self._server.app
@property
def host(self):
return self._server.host
@property
def port(self):
return self._server.port
@property
def handler(self):
return self._server.handler
@property
def server(self):
return self._server.server
@property
def session(self):
"""A raw handler to the aiohttp.ClientSession.
Unlike the methods on the TestClient, client session requests
do not automatically include the host in the url queried, and
will require an absolute path to the resource.
"""
return self._session
def make_url(self, path):
return self._server.make_url(path)
@asyncio.coroutine
def request(self, method, path, *args, **kwargs):
"""Routes a request to the http server.
The interface is identical to asyncio.ClientSession.request,
except the loop kwarg is overridden by the instance used by the
application.
"""
resp = yield from self._session.request(
method, self.make_url(path), *args, **kwargs
)
# save it to close later
self._responses.append(resp)
return resp
def get(self, path, *args, **kwargs):
"""Perform an HTTP GET request."""
return self.request(hdrs.METH_GET, path, *args, **kwargs)
def post(self, path, *args, **kwargs):
"""Perform an HTTP POST request."""
return self.request(hdrs.METH_POST, path, *args, **kwargs)
def options(self, path, *args, **kwargs):
"""Perform an HTTP OPTIONS request."""
return self.request(hdrs.METH_OPTIONS, path, *args, **kwargs)
def head(self, path, *args, **kwargs):
"""Perform an HTTP HEAD request."""
return self.request(hdrs.METH_HEAD, path, *args, **kwargs)
def put(self, path, *args, **kwargs):
"""Perform an HTTP PUT request."""
return self.request(hdrs.METH_PUT, path, *args, **kwargs)
def patch(self, path, *args, **kwargs):
"""Perform an HTTP PATCH request."""
return self.request(hdrs.METH_PATCH, path, *args, **kwargs)
def delete(self, path, *args, **kwargs):
"""Perform an HTTP PATCH request."""
return self.request(hdrs.METH_DELETE, path, *args, **kwargs)
def ws_connect(self, path, *args, **kwargs):
"""Initiate websocket connection.
The api is identical to aiohttp.ClientSession.ws_connect.
"""
return self._session.ws_connect(
self.make_url(path), *args, **kwargs
)
@asyncio.coroutine
def close(self):
"""Close all fixtures created by the test client.
After that point, the TestClient is no longer usable.
This is an idempotent function: running close multiple times
will not have any additional effects.
close is also run on exit when used as a(n) (asynchronous)
context manager.
"""
if not self._closed:
for resp in self._responses:
resp.close()
yield from self._session.close()
yield from self._server.close()
self._closed = True
def __enter__(self):
self._loop.run_until_complete(self.start_server())
return self
def __exit__(self, exc_type, exc_value, traceback):
self._loop.run_until_complete(self.close())
if PY_35:
@asyncio.coroutine
def __aenter__(self):
yield from self.start_server()
return self
@asyncio.coroutine
def __aexit__(self, exc_type, exc_value, traceback):
yield from self.close()
class AioHTTPTestCase(unittest.TestCase):
"""A base class to allow for unittest web applications using
aiohttp.
Provides the following:
* self.client (aiohttp.test_utils.TestClient): an aiohttp test client.
* self.loop (asyncio.BaseEventLoop): the event loop in which the
application and server are running.
* self.app (aiohttp.web.Application): the application returned by
self.get_app()
Note that the TestClient's methods are asynchronous: you have to
execute function on the test client using asynchronous methods.
"""
def get_app(self, loop):
"""
This method should be overridden
to return the aiohttp.web.Application
object to test.
:param loop: the event_loop to use
:type loop: asyncio.BaseEventLoop
"""
pass # pragma: no cover
def setUp(self):
self.loop = setup_test_loop()
self.app = self.get_app(self.loop)
self.client = TestClient(self.app)
self.loop.run_until_complete(self.client.start_server())
def tearDown(self):
self.loop.run_until_complete(self.client.close())
teardown_test_loop(self.loop)
def unittest_run_loop(func):
"""A decorator dedicated to use with asynchronous methods of an
AioHTTPTestCase.
Handles executing an asynchronous function, using
the self.loop of the AioHTTPTestCase.
"""
@functools.wraps(func)
def new_func(self):
return self.loop.run_until_complete(func(self))
return new_func
@contextlib.contextmanager
def loop_context(loop_factory=asyncio.new_event_loop):
"""A contextmanager that creates an event_loop, for test purposes.
Handles the creation and cleanup of a test loop.
"""
loop = setup_test_loop(loop_factory)
yield loop
teardown_test_loop(loop)
def setup_test_loop(loop_factory=asyncio.new_event_loop):
"""Create and return an asyncio.BaseEventLoop
instance.
The caller should also call teardown_test_loop,
once they are done with the loop.
"""
loop = loop_factory()
asyncio.set_event_loop(None)
return loop
def teardown_test_loop(loop):
"""Teardown and cleanup an event_loop created
by setup_test_loop.
:param loop: the loop to teardown
:type loop: asyncio.BaseEventLoop
"""
closed = loop.is_closed()
if not closed:
loop.call_soon(loop.stop)
loop.run_forever()
loop.close()
gc.collect()
asyncio.set_event_loop(None)
def _create_app_mock():
app = mock.Mock()
app._debug = False
app.on_response_prepare = Signal(app)
return app
def _create_transport(sslcontext=None):
transport = mock.Mock()
def get_extra_info(key):
if key == 'sslcontext':
return sslcontext
else:
return None
transport.get_extra_info.side_effect = get_extra_info
return transport
def make_mocked_request(method, path, headers=None, *,
version=HttpVersion(1, 1), closing=False,
app=None,
reader=sentinel,
writer=sentinel,
transport=sentinel,
payload=sentinel,
sslcontext=None,
secure_proxy_ssl_header=None):
"""Creates mocked web.Request testing purposes.
Useful in unit tests, when spinning full web server is overkill or
specific conditions and errors are hard to trigger.
:param method: str, that represents HTTP method, like; GET, POST.
:type method: str
:param path: str, The URL including *PATH INFO* without the host or scheme
:type path: str
:param headers: mapping containing the headers. Can be anything accepted
by the multidict.CIMultiDict constructor.
:type headers: dict, multidict.CIMultiDict, list of pairs
:param version: namedtuple with encoded HTTP version
:type version: aiohttp.protocol.HttpVersion
:param closing: flag indicates that connection should be closed after
response.
:type closing: bool
:param app: the aiohttp.web application attached for fake request
:type app: aiohttp.web.Application
:param reader: object for storing and managing incoming data
:type reader: aiohttp.parsers.StreamParser
:param writer: object for managing outcoming data
:type wirter: aiohttp.parsers.StreamWriter
:param transport: asyncio transport instance
:type transport: asyncio.transports.Transport
:param payload: raw payload reader object
:type payload: aiohttp.streams.FlowControlStreamReader
:param sslcontext: ssl.SSLContext object, for HTTPS connection
:type sslcontext: ssl.SSLContext
:param secure_proxy_ssl_header: A tuple representing a HTTP header/value
combination that signifies a request is secure.
:type secure_proxy_ssl_header: tuple
"""
if version < HttpVersion(1, 1):
closing = True
if headers:
hdrs = CIMultiDict(headers)
raw_hdrs = [
(k.encode('utf-8'), v.encode('utf-8')) for k, v in headers.items()]
else:
hdrs = CIMultiDict()
raw_hdrs = []
message = RawRequestMessage(method, path, version, hdrs,
raw_hdrs, closing, False)
if app is None:
app = _create_app_mock()
if reader is sentinel:
reader = mock.Mock()
if writer is sentinel:
writer = mock.Mock()
if transport is sentinel:
transport = _create_transport(sslcontext)
if payload is sentinel:
payload = mock.Mock()
req = Request(app, message, payload,
transport, reader, writer,
secure_proxy_ssl_header=secure_proxy_ssl_header)
return req
def make_mocked_coro(return_value=sentinel, raise_exception=sentinel):
"""Creates a coroutine mock."""
@asyncio.coroutine
def mock_coro(*args, **kwargs):
if raise_exception is not sentinel:
raise raise_exception
return return_value
return mock.Mock(wraps=mock_coro)

View File

@ -0,0 +1,376 @@
import asyncio
import sys
import warnings
from argparse import ArgumentParser
from importlib import import_module
from . import hdrs, web_exceptions, web_reqrep, web_urldispatcher, web_ws
from .abc import AbstractMatchInfo, AbstractRouter
from .helpers import sentinel
from .log import web_logger
from .protocol import HttpVersion # noqa
from .server import ServerHttpProtocol
from .signals import PostSignal, PreSignal, Signal
from .web_exceptions import * # noqa
from .web_reqrep import * # noqa
from .web_urldispatcher import * # noqa
from .web_ws import * # noqa
__all__ = (web_reqrep.__all__ +
web_exceptions.__all__ +
web_urldispatcher.__all__ +
web_ws.__all__ +
('Application', 'RequestHandler',
'RequestHandlerFactory', 'HttpVersion',
'MsgType'))
class RequestHandler(ServerHttpProtocol):
_meth = 'none'
_path = 'none'
def __init__(self, manager, app, router, *,
secure_proxy_ssl_header=None, **kwargs):
super().__init__(**kwargs)
self._manager = manager
self._app = app
self._router = router
self._middlewares = app.middlewares
self._secure_proxy_ssl_header = secure_proxy_ssl_header
def __repr__(self):
return "<{} {}:{} {}>".format(
self.__class__.__name__, self._meth, self._path,
'connected' if self.transport is not None else 'disconnected')
def connection_made(self, transport):
super().connection_made(transport)
self._manager.connection_made(self, transport)
def connection_lost(self, exc):
self._manager.connection_lost(self, exc)
super().connection_lost(exc)
@asyncio.coroutine
def handle_request(self, message, payload):
self._manager._requests_count += 1
if self.access_log:
now = self._loop.time()
app = self._app
request = web_reqrep.Request(
app, message, payload,
self.transport, self.reader, self.writer,
secure_proxy_ssl_header=self._secure_proxy_ssl_header)
self._meth = request.method
self._path = request.path
try:
match_info = yield from self._router.resolve(request)
assert isinstance(match_info, AbstractMatchInfo), match_info
resp = None
request._match_info = match_info
expect = request.headers.get(hdrs.EXPECT)
if expect:
resp = (
yield from match_info.expect_handler(request))
if resp is None:
handler = match_info.handler
for factory in reversed(self._middlewares):
handler = yield from factory(app, handler)
resp = yield from handler(request)
assert isinstance(resp, web_reqrep.StreamResponse), \
("Handler {!r} should return response instance, "
"got {!r} [middlewares {!r}]").format(
match_info.handler, type(resp), self._middlewares)
except web_exceptions.HTTPException as exc:
resp = exc
resp_msg = yield from resp.prepare(request)
yield from resp.write_eof()
# notify server about keep-alive
self.keep_alive(resp.keep_alive)
# log access
if self.access_log:
self.log_access(message, None, resp_msg, self._loop.time() - now)
# for repr
self._meth = 'none'
self._path = 'none'
class RequestHandlerFactory:
def __init__(self, app, router, *,
handler=RequestHandler, loop=None,
secure_proxy_ssl_header=None, **kwargs):
self._app = app
self._router = router
self._handler = handler
self._loop = loop
self._connections = {}
self._secure_proxy_ssl_header = secure_proxy_ssl_header
self._kwargs = kwargs
self._kwargs.setdefault('logger', app.logger)
self._requests_count = 0
@property
def requests_count(self):
"""Number of processed requests."""
return self._requests_count
@property
def secure_proxy_ssl_header(self):
return self._secure_proxy_ssl_header
@property
def connections(self):
return list(self._connections.keys())
def connection_made(self, handler, transport):
self._connections[handler] = transport
def connection_lost(self, handler, exc=None):
if handler in self._connections:
del self._connections[handler]
@asyncio.coroutine
def finish_connections(self, timeout=None):
coros = [conn.shutdown(timeout) for conn in self._connections]
yield from asyncio.gather(*coros, loop=self._loop)
self._connections.clear()
def __call__(self):
return self._handler(
self, self._app, self._router, loop=self._loop,
secure_proxy_ssl_header=self._secure_proxy_ssl_header,
**self._kwargs)
class Application(dict):
def __init__(self, *, logger=web_logger, loop=None,
router=None, handler_factory=RequestHandlerFactory,
middlewares=(), debug=False):
if loop is None:
loop = asyncio.get_event_loop()
if router is None:
router = web_urldispatcher.UrlDispatcher()
assert isinstance(router, AbstractRouter), router
self._debug = debug
self._router = router
self._handler_factory = handler_factory
self._loop = loop
self.logger = logger
self._middlewares = list(middlewares)
self._on_pre_signal = PreSignal()
self._on_post_signal = PostSignal()
self._on_response_prepare = Signal(self)
self._on_startup = Signal(self)
self._on_shutdown = Signal(self)
self._on_cleanup = Signal(self)
@property
def debug(self):
return self._debug
@property
def on_response_prepare(self):
return self._on_response_prepare
@property
def on_pre_signal(self):
return self._on_pre_signal
@property
def on_post_signal(self):
return self._on_post_signal
@property
def on_startup(self):
return self._on_startup
@property
def on_shutdown(self):
return self._on_shutdown
@property
def on_cleanup(self):
return self._on_cleanup
@property
def router(self):
return self._router
@property
def loop(self):
return self._loop
@property
def middlewares(self):
return self._middlewares
def make_handler(self, **kwargs):
debug = kwargs.pop('debug', sentinel)
if debug is not sentinel:
warnings.warn(
"`debug` parameter is deprecated. "
"Use Application's debug mode instead", DeprecationWarning)
if debug != self.debug:
raise ValueError(
"The value of `debug` parameter conflicts with the debug "
"settings of the `Application` instance. The "
"application's debug mode setting should be used instead "
"as a single point to setup a debug mode. For more "
"information please check "
"http://aiohttp.readthedocs.io/en/stable/"
"web_reference.html#aiohttp.web.Application"
)
return self._handler_factory(self, self.router, debug=self.debug,
loop=self.loop, **kwargs)
@asyncio.coroutine
def startup(self):
"""Causes on_startup signal
Should be called in the event loop along with the request handler.
"""
yield from self.on_startup.send(self)
@asyncio.coroutine
def shutdown(self):
"""Causes on_shutdown signal
Should be called before cleanup()
"""
yield from self.on_shutdown.send(self)
@asyncio.coroutine
def cleanup(self):
"""Causes on_cleanup signal
Should be called after shutdown()
"""
yield from self.on_cleanup.send(self)
@asyncio.coroutine
def finish(self):
"""Finalize an application.
Deprecated alias for .cleanup()
"""
warnings.warn("Use .cleanup() instead", DeprecationWarning)
yield from self.cleanup()
def register_on_finish(self, func, *args, **kwargs):
warnings.warn("Use .on_cleanup.append() instead", DeprecationWarning)
self.on_cleanup.append(lambda app: func(app, *args, **kwargs))
def copy(self):
raise NotImplementedError
def __call__(self):
"""gunicorn compatibility"""
return self
def __repr__(self):
return "<Application>"
def run_app(app, *, host='0.0.0.0', port=None,
shutdown_timeout=60.0, ssl_context=None,
print=print, backlog=128):
"""Run an app locally"""
if port is None:
if not ssl_context:
port = 8080
else:
port = 8443
loop = app.loop
handler = app.make_handler()
server = loop.create_server(handler, host, port, ssl=ssl_context,
backlog=backlog)
srv, startup_res = loop.run_until_complete(asyncio.gather(server,
app.startup(),
loop=loop))
scheme = 'https' if ssl_context else 'http'
print("======== Running on {scheme}://{host}:{port}/ ========\n"
"(Press CTRL+C to quit)".format(
scheme=scheme, host=host, port=port))
try:
loop.run_forever()
except KeyboardInterrupt: # pragma: no cover
pass
finally:
srv.close()
loop.run_until_complete(srv.wait_closed())
loop.run_until_complete(app.shutdown())
loop.run_until_complete(handler.finish_connections(shutdown_timeout))
loop.run_until_complete(app.cleanup())
loop.close()
def main(argv):
arg_parser = ArgumentParser(
description="aiohttp.web Application server",
prog="aiohttp.web"
)
arg_parser.add_argument(
"entry_func",
help=("Callable returning the `aiohttp.web.Application` instance to "
"run. Should be specified in the 'module:function' syntax."),
metavar="entry-func"
)
arg_parser.add_argument(
"-H", "--hostname",
help="TCP/IP hostname to serve on (default: %(default)r)",
default="localhost"
)
arg_parser.add_argument(
"-P", "--port",
help="TCP/IP port to serve on (default: %(default)r)",
type=int,
default="8080"
)
args, extra_argv = arg_parser.parse_known_args(argv)
# Import logic
mod_str, _, func_str = args.entry_func.partition(":")
if not func_str or not mod_str:
arg_parser.error(
"'entry-func' not in 'module:function' syntax"
)
if mod_str.startswith("."):
arg_parser.error("relative module names not supported")
try:
module = import_module(mod_str)
except ImportError:
arg_parser.error("module %r not found" % mod_str)
try:
func = getattr(module, func_str)
except AttributeError:
arg_parser.error("module %r has no attribute %r" % (mod_str, func_str))
app = func(extra_argv)
run_app(app, host=args.hostname, port=args.port)
arg_parser.exit(message="Stopped\n")
if __name__ == "__main__": # pragma: no branch
main(sys.argv[1:]) # pragma: no cover

View File

@ -0,0 +1,349 @@
from .web_reqrep import Response
__all__ = (
'HTTPException',
'HTTPError',
'HTTPRedirection',
'HTTPSuccessful',
'HTTPOk',
'HTTPCreated',
'HTTPAccepted',
'HTTPNonAuthoritativeInformation',
'HTTPNoContent',
'HTTPResetContent',
'HTTPPartialContent',
'HTTPMultipleChoices',
'HTTPMovedPermanently',
'HTTPFound',
'HTTPSeeOther',
'HTTPNotModified',
'HTTPUseProxy',
'HTTPTemporaryRedirect',
'HTTPPermanentRedirect',
'HTTPClientError',
'HTTPBadRequest',
'HTTPUnauthorized',
'HTTPPaymentRequired',
'HTTPForbidden',
'HTTPNotFound',
'HTTPMethodNotAllowed',
'HTTPNotAcceptable',
'HTTPProxyAuthenticationRequired',
'HTTPRequestTimeout',
'HTTPConflict',
'HTTPGone',
'HTTPLengthRequired',
'HTTPPreconditionFailed',
'HTTPRequestEntityTooLarge',
'HTTPRequestURITooLong',
'HTTPUnsupportedMediaType',
'HTTPRequestRangeNotSatisfiable',
'HTTPExpectationFailed',
'HTTPMisdirectedRequest',
'HTTPUpgradeRequired',
'HTTPPreconditionRequired',
'HTTPTooManyRequests',
'HTTPRequestHeaderFieldsTooLarge',
'HTTPUnavailableForLegalReasons',
'HTTPServerError',
'HTTPInternalServerError',
'HTTPNotImplemented',
'HTTPBadGateway',
'HTTPServiceUnavailable',
'HTTPGatewayTimeout',
'HTTPVersionNotSupported',
'HTTPVariantAlsoNegotiates',
'HTTPNotExtended',
'HTTPNetworkAuthenticationRequired',
)
############################################################
# HTTP Exceptions
############################################################
class HTTPException(Response, Exception):
# You should set in subclasses:
# status = 200
status_code = None
empty_body = False
def __init__(self, *, headers=None, reason=None,
body=None, text=None, content_type=None):
Response.__init__(self, status=self.status_code,
headers=headers, reason=reason,
body=body, text=text, content_type=content_type)
Exception.__init__(self, self.reason)
if self.body is None and not self.empty_body:
self.text = "{}: {}".format(self.status, self.reason)
class HTTPError(HTTPException):
"""Base class for exceptions with status codes in the 400s and 500s."""
class HTTPRedirection(HTTPException):
"""Base class for exceptions with status codes in the 300s."""
class HTTPSuccessful(HTTPException):
"""Base class for exceptions with status codes in the 200s."""
class HTTPOk(HTTPSuccessful):
status_code = 200
class HTTPCreated(HTTPSuccessful):
status_code = 201
class HTTPAccepted(HTTPSuccessful):
status_code = 202
class HTTPNonAuthoritativeInformation(HTTPSuccessful):
status_code = 203
class HTTPNoContent(HTTPSuccessful):
status_code = 204
empty_body = True
class HTTPResetContent(HTTPSuccessful):
status_code = 205
empty_body = True
class HTTPPartialContent(HTTPSuccessful):
status_code = 206
############################################################
# 3xx redirection
############################################################
class _HTTPMove(HTTPRedirection):
def __init__(self, location, *, headers=None, reason=None,
body=None, text=None, content_type=None):
if not location:
raise ValueError("HTTP redirects need a location to redirect to.")
super().__init__(headers=headers, reason=reason,
body=body, text=text, content_type=content_type)
self.headers['Location'] = location
self.location = location
class HTTPMultipleChoices(_HTTPMove):
status_code = 300
class HTTPMovedPermanently(_HTTPMove):
status_code = 301
class HTTPFound(_HTTPMove):
status_code = 302
# This one is safe after a POST (the redirected location will be
# retrieved with GET):
class HTTPSeeOther(_HTTPMove):
status_code = 303
class HTTPNotModified(HTTPRedirection):
# FIXME: this should include a date or etag header
status_code = 304
empty_body = True
class HTTPUseProxy(_HTTPMove):
# Not a move, but looks a little like one
status_code = 305
class HTTPTemporaryRedirect(_HTTPMove):
status_code = 307
class HTTPPermanentRedirect(_HTTPMove):
status_code = 308
############################################################
# 4xx client error
############################################################
class HTTPClientError(HTTPError):
pass
class HTTPBadRequest(HTTPClientError):
status_code = 400
class HTTPUnauthorized(HTTPClientError):
status_code = 401
class HTTPPaymentRequired(HTTPClientError):
status_code = 402
class HTTPForbidden(HTTPClientError):
status_code = 403
class HTTPNotFound(HTTPClientError):
status_code = 404
class HTTPMethodNotAllowed(HTTPClientError):
status_code = 405
def __init__(self, method, allowed_methods, *, headers=None, reason=None,
body=None, text=None, content_type=None):
allow = ','.join(sorted(allowed_methods))
super().__init__(headers=headers, reason=reason,
body=body, text=text, content_type=content_type)
self.headers['Allow'] = allow
self.allowed_methods = allowed_methods
self.method = method.upper()
class HTTPNotAcceptable(HTTPClientError):
status_code = 406
class HTTPProxyAuthenticationRequired(HTTPClientError):
status_code = 407
class HTTPRequestTimeout(HTTPClientError):
status_code = 408
class HTTPConflict(HTTPClientError):
status_code = 409
class HTTPGone(HTTPClientError):
status_code = 410
class HTTPLengthRequired(HTTPClientError):
status_code = 411
class HTTPPreconditionFailed(HTTPClientError):
status_code = 412
class HTTPRequestEntityTooLarge(HTTPClientError):
status_code = 413
class HTTPRequestURITooLong(HTTPClientError):
status_code = 414
class HTTPUnsupportedMediaType(HTTPClientError):
status_code = 415
class HTTPRequestRangeNotSatisfiable(HTTPClientError):
status_code = 416
class HTTPExpectationFailed(HTTPClientError):
status_code = 417
class HTTPMisdirectedRequest(HTTPClientError):
status_code = 421
class HTTPUpgradeRequired(HTTPClientError):
status_code = 426
class HTTPPreconditionRequired(HTTPClientError):
status_code = 428
class HTTPTooManyRequests(HTTPClientError):
status_code = 429
class HTTPRequestHeaderFieldsTooLarge(HTTPClientError):
status_code = 431
class HTTPUnavailableForLegalReasons(HTTPClientError):
status_code = 451
def __init__(self, link, *, headers=None, reason=None,
body=None, text=None, content_type=None):
super().__init__(headers=headers, reason=reason,
body=body, text=text, content_type=content_type)
self.headers['Link'] = '<%s>; rel="blocked-by"' % link
self.link = link
############################################################
# 5xx Server Error
############################################################
# Response status codes beginning with the digit "5" indicate cases in
# which the server is aware that it has erred or is incapable of
# performing the request. Except when responding to a HEAD request, the
# server SHOULD include an entity containing an explanation of the error
# situation, and whether it is a temporary or permanent condition. User
# agents SHOULD display any included entity to the user. These response
# codes are applicable to any request method.
class HTTPServerError(HTTPError):
pass
class HTTPInternalServerError(HTTPServerError):
status_code = 500
class HTTPNotImplemented(HTTPServerError):
status_code = 501
class HTTPBadGateway(HTTPServerError):
status_code = 502
class HTTPServiceUnavailable(HTTPServerError):
status_code = 503
class HTTPGatewayTimeout(HTTPServerError):
status_code = 504
class HTTPVersionNotSupported(HTTPServerError):
status_code = 505
class HTTPVariantAlsoNegotiates(HTTPServerError):
status_code = 506
class HTTPNotExtended(HTTPServerError):
status_code = 510
class HTTPNetworkAuthenticationRequired(HTTPServerError):
status_code = 511

View File

@ -0,0 +1,895 @@
import asyncio
import binascii
import cgi
import collections
import datetime
import enum
import http.cookies
import io
import json
import math
import time
import warnings
from email.utils import parsedate
from types import MappingProxyType
from urllib.parse import parse_qsl, unquote, urlsplit
from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy
from . import hdrs, multipart
from .helpers import reify, sentinel
from .protocol import Response as ResponseImpl
from .protocol import HttpVersion10, HttpVersion11
from .streams import EOF_MARKER
__all__ = (
'ContentCoding', 'Request', 'StreamResponse', 'Response',
'json_response'
)
class HeadersMixin:
_content_type = None
_content_dict = None
_stored_content_type = sentinel
def _parse_content_type(self, raw):
self._stored_content_type = raw
if raw is None:
# default value according to RFC 2616
self._content_type = 'application/octet-stream'
self._content_dict = {}
else:
self._content_type, self._content_dict = cgi.parse_header(raw)
@property
def content_type(self, _CONTENT_TYPE=hdrs.CONTENT_TYPE):
"""The value of content part for Content-Type HTTP header."""
raw = self.headers.get(_CONTENT_TYPE)
if self._stored_content_type != raw:
self._parse_content_type(raw)
return self._content_type
@property
def charset(self, _CONTENT_TYPE=hdrs.CONTENT_TYPE):
"""The value of charset part for Content-Type HTTP header."""
raw = self.headers.get(_CONTENT_TYPE)
if self._stored_content_type != raw:
self._parse_content_type(raw)
return self._content_dict.get('charset')
@property
def content_length(self, _CONTENT_LENGTH=hdrs.CONTENT_LENGTH):
"""The value of Content-Length HTTP header."""
l = self.headers.get(_CONTENT_LENGTH)
if l is None:
return None
else:
return int(l)
FileField = collections.namedtuple('Field', 'name filename file content_type')
class ContentCoding(enum.Enum):
# The content codings that we have support for.
#
# Additional registered codings are listed at:
# https://www.iana.org/assignments/http-parameters/http-parameters.xhtml#content-coding
deflate = 'deflate'
gzip = 'gzip'
identity = 'identity'
############################################################
# HTTP Request
############################################################
class Request(dict, HeadersMixin):
POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT,
hdrs.METH_TRACE, hdrs.METH_DELETE}
def __init__(self, app, message, payload, transport, reader, writer, *,
secure_proxy_ssl_header=None):
self._app = app
self._message = message
self._transport = transport
self._reader = reader
self._writer = writer
self._post = None
self._post_files_cache = None
# matchdict, route_name, handler
# or information about traversal lookup
self._match_info = None # initialized after route resolving
self._payload = payload
self._read_bytes = None
self._has_body = not payload.at_eof()
self._secure_proxy_ssl_header = secure_proxy_ssl_header
@reify
def scheme(self):
"""A string representing the scheme of the request.
'http' or 'https'.
"""
if self._transport.get_extra_info('sslcontext'):
return 'https'
secure_proxy_ssl_header = self._secure_proxy_ssl_header
if secure_proxy_ssl_header is not None:
header, value = secure_proxy_ssl_header
if self.headers.get(header) == value:
return 'https'
return 'http'
@reify
def method(self):
"""Read only property for getting HTTP method.
The value is upper-cased str like 'GET', 'POST', 'PUT' etc.
"""
return self._message.method
@reify
def version(self):
"""Read only property for getting HTTP version of request.
Returns aiohttp.protocol.HttpVersion instance.
"""
return self._message.version
@reify
def host(self):
"""Read only property for getting *HOST* header of request.
Returns str or None if HTTP request has no HOST header.
"""
return self._message.headers.get(hdrs.HOST)
@reify
def path_qs(self):
"""The URL including PATH_INFO and the query string.
E.g, /app/blog?id=10
"""
return self._message.path
@reify
def _splitted_path(self):
url = '{}://{}{}'.format(self.scheme, self.host, self.path_qs)
return urlsplit(url)
@reify
def raw_path(self):
""" The URL including raw *PATH INFO* without the host or scheme.
Warning, the path is unquoted and may contains non valid URL characters
E.g., ``/my%2Fpath%7Cwith%21some%25strange%24characters``
"""
return self._splitted_path.path
@reify
def path(self):
"""The URL including *PATH INFO* without the host or scheme.
E.g., ``/app/blog``
"""
return unquote(self.raw_path)
@reify
def query_string(self):
"""The query string in the URL.
E.g., id=10
"""
return self._splitted_path.query
@reify
def GET(self):
"""A multidict with all the variables in the query string.
Lazy property.
"""
return MultiDictProxy(MultiDict(parse_qsl(self.query_string,
keep_blank_values=True)))
@reify
def POST(self):
"""A multidict with all the variables in the POST parameters.
post() methods has to be called before using this attribute.
"""
if self._post is None:
raise RuntimeError("POST is not available before post()")
return self._post
@reify
def headers(self):
"""A case-insensitive multidict proxy with all headers."""
return CIMultiDictProxy(self._message.headers)
@reify
def raw_headers(self):
"""A sequence of pars for all headers."""
return tuple(self._message.raw_headers)
@reify
def if_modified_since(self, _IF_MODIFIED_SINCE=hdrs.IF_MODIFIED_SINCE):
"""The value of If-Modified-Since HTTP header, or None.
This header is represented as a `datetime` object.
"""
httpdate = self.headers.get(_IF_MODIFIED_SINCE)
if httpdate is not None:
timetuple = parsedate(httpdate)
if timetuple is not None:
return datetime.datetime(*timetuple[:6],
tzinfo=datetime.timezone.utc)
return None
@reify
def keep_alive(self):
"""Is keepalive enabled by client?"""
if self.version < HttpVersion10:
return False
else:
return not self._message.should_close
@property
def match_info(self):
"""Result of route resolving."""
return self._match_info
@property
def app(self):
"""Application instance."""
return self._app
@property
def transport(self):
"""Transport used for request processing."""
return self._transport
@reify
def cookies(self):
"""Return request cookies.
A read-only dictionary-like object.
"""
raw = self.headers.get(hdrs.COOKIE, '')
parsed = http.cookies.SimpleCookie(raw)
return MappingProxyType(
{key: val.value for key, val in parsed.items()})
@property
def content(self):
"""Return raw payload stream."""
return self._payload
@property
def has_body(self):
"""Return True if request has HTTP BODY, False otherwise."""
return self._has_body
@asyncio.coroutine
def release(self):
"""Release request.
Eat unread part of HTTP BODY if present.
"""
chunk = yield from self._payload.readany()
while chunk is not EOF_MARKER or chunk:
chunk = yield from self._payload.readany()
@asyncio.coroutine
def read(self):
"""Read request body if present.
Returns bytes object with full request content.
"""
if self._read_bytes is None:
body = bytearray()
while True:
chunk = yield from self._payload.readany()
body.extend(chunk)
if chunk is EOF_MARKER:
break
self._read_bytes = bytes(body)
return self._read_bytes
@asyncio.coroutine
def text(self):
"""Return BODY as text using encoding from .charset."""
bytes_body = yield from self.read()
encoding = self.charset or 'utf-8'
return bytes_body.decode(encoding)
@asyncio.coroutine
def json(self, *, loads=json.loads, loader=None):
"""Return BODY as JSON."""
if loader is not None:
warnings.warn(
"Using loader argument is deprecated, use loads instead",
DeprecationWarning)
loads = loader
body = yield from self.text()
return loads(body)
@asyncio.coroutine
def multipart(self, *, reader=multipart.MultipartReader):
"""Return async iterator to process BODY as multipart."""
return reader(self.headers, self.content)
@asyncio.coroutine
def post(self):
"""Return POST parameters."""
if self._post is not None:
return self._post
if self.method not in self.POST_METHODS:
self._post = MultiDictProxy(MultiDict())
return self._post
content_type = self.content_type
if (content_type not in ('',
'application/x-www-form-urlencoded',
'multipart/form-data')):
self._post = MultiDictProxy(MultiDict())
return self._post
if self.content_type.startswith('multipart/'):
warnings.warn('To process multipart requests use .multipart'
' coroutine instead.', DeprecationWarning)
body = yield from self.read()
content_charset = self.charset or 'utf-8'
environ = {'REQUEST_METHOD': self.method,
'CONTENT_LENGTH': str(len(body)),
'QUERY_STRING': '',
'CONTENT_TYPE': self.headers.get(hdrs.CONTENT_TYPE)}
fs = cgi.FieldStorage(fp=io.BytesIO(body),
environ=environ,
keep_blank_values=True,
encoding=content_charset)
supported_transfer_encoding = {
'base64': binascii.a2b_base64,
'quoted-printable': binascii.a2b_qp
}
out = MultiDict()
_count = 1
for field in fs.list or ():
transfer_encoding = field.headers.get(
hdrs.CONTENT_TRANSFER_ENCODING, None)
if field.filename:
ff = FileField(field.name,
field.filename,
field.file, # N.B. file closed error
field.type)
if self._post_files_cache is None:
self._post_files_cache = {}
self._post_files_cache[field.name+str(_count)] = field
_count += 1
out.add(field.name, ff)
else:
value = field.value
if transfer_encoding in supported_transfer_encoding:
# binascii accepts bytes
value = value.encode('utf-8')
value = supported_transfer_encoding[
transfer_encoding](value)
out.add(field.name, value)
self._post = MultiDictProxy(out)
return self._post
def copy(self):
raise NotImplementedError
def __repr__(self):
ascii_encodable_path = self.path.encode('ascii', 'backslashreplace') \
.decode('ascii')
return "<{} {} {} >".format(self.__class__.__name__,
self.method, ascii_encodable_path)
############################################################
# HTTP Response classes
############################################################
class StreamResponse(HeadersMixin):
def __init__(self, *, status=200, reason=None, headers=None):
self._body = None
self._keep_alive = None
self._chunked = False
self._chunk_size = None
self._compression = False
self._compression_force = False
self._headers = CIMultiDict()
self._cookies = http.cookies.SimpleCookie()
self.set_status(status, reason)
self._req = None
self._resp_impl = None
self._eof_sent = False
self._tcp_nodelay = True
self._tcp_cork = False
if headers is not None:
self._headers.extend(headers)
self._parse_content_type(self._headers.get(hdrs.CONTENT_TYPE))
self._generate_content_type_header()
def _copy_cookies(self):
for cookie in self._cookies.values():
value = cookie.output(header='')[1:]
self.headers.add(hdrs.SET_COOKIE, value)
@property
def prepared(self):
return self._resp_impl is not None
@property
def started(self):
warnings.warn('use Response.prepared instead', DeprecationWarning)
return self.prepared
@property
def status(self):
return self._status
@property
def chunked(self):
return self._chunked
@property
def compression(self):
return self._compression
@property
def reason(self):
return self._reason
def set_status(self, status, reason=None):
self._status = int(status)
if reason is None:
reason = ResponseImpl.calc_reason(status)
self._reason = reason
@property
def keep_alive(self):
return self._keep_alive
def force_close(self):
self._keep_alive = False
def enable_chunked_encoding(self, chunk_size=None):
"""Enables automatic chunked transfer encoding."""
self._chunked = True
self._chunk_size = chunk_size
def enable_compression(self, force=None):
"""Enables response compression encoding."""
# Backwards compatibility for when force was a bool <0.17.
if type(force) == bool:
force = ContentCoding.deflate if force else ContentCoding.identity
elif force is not None:
assert isinstance(force, ContentCoding), ("force should one of "
"None, bool or "
"ContentEncoding")
self._compression = True
self._compression_force = force
@property
def headers(self):
return self._headers
@property
def cookies(self):
return self._cookies
def set_cookie(self, name, value, *, expires=None,
domain=None, max_age=None, path='/',
secure=None, httponly=None, version=None):
"""Set or update response cookie.
Sets new cookie or updates existent with new value.
Also updates only those params which are not None.
"""
old = self._cookies.get(name)
if old is not None and old.coded_value == '':
# deleted cookie
self._cookies.pop(name, None)
self._cookies[name] = value
c = self._cookies[name]
if expires is not None:
c['expires'] = expires
elif c.get('expires') == 'Thu, 01 Jan 1970 00:00:00 GMT':
del c['expires']
if domain is not None:
c['domain'] = domain
if max_age is not None:
c['max-age'] = max_age
elif 'max-age' in c:
del c['max-age']
c['path'] = path
if secure is not None:
c['secure'] = secure
if httponly is not None:
c['httponly'] = httponly
if version is not None:
c['version'] = version
def del_cookie(self, name, *, domain=None, path='/'):
"""Delete cookie.
Creates new empty expired cookie.
"""
# TODO: do we need domain/path here?
self._cookies.pop(name, None)
self.set_cookie(name, '', max_age=0,
expires="Thu, 01 Jan 1970 00:00:00 GMT",
domain=domain, path=path)
@property
def content_length(self):
# Just a placeholder for adding setter
return super().content_length
@content_length.setter
def content_length(self, value):
if value is not None:
value = int(value)
# TODO: raise error if chunked enabled
self.headers[hdrs.CONTENT_LENGTH] = str(value)
else:
self.headers.pop(hdrs.CONTENT_LENGTH, None)
@property
def content_type(self):
# Just a placeholder for adding setter
return super().content_type
@content_type.setter
def content_type(self, value):
self.content_type # read header values if needed
self._content_type = str(value)
self._generate_content_type_header()
@property
def charset(self):
# Just a placeholder for adding setter
return super().charset
@charset.setter
def charset(self, value):
ctype = self.content_type # read header values if needed
if ctype == 'application/octet-stream':
raise RuntimeError("Setting charset for application/octet-stream "
"doesn't make sense, setup content_type first")
if value is None:
self._content_dict.pop('charset', None)
else:
self._content_dict['charset'] = str(value).lower()
self._generate_content_type_header()
@property
def last_modified(self, _LAST_MODIFIED=hdrs.LAST_MODIFIED):
"""The value of Last-Modified HTTP header, or None.
This header is represented as a `datetime` object.
"""
httpdate = self.headers.get(_LAST_MODIFIED)
if httpdate is not None:
timetuple = parsedate(httpdate)
if timetuple is not None:
return datetime.datetime(*timetuple[:6],
tzinfo=datetime.timezone.utc)
return None
@last_modified.setter
def last_modified(self, value):
if value is None:
self.headers.pop(hdrs.LAST_MODIFIED, None)
elif isinstance(value, (int, float)):
self.headers[hdrs.LAST_MODIFIED] = time.strftime(
"%a, %d %b %Y %H:%M:%S GMT", time.gmtime(math.ceil(value)))
elif isinstance(value, datetime.datetime):
self.headers[hdrs.LAST_MODIFIED] = time.strftime(
"%a, %d %b %Y %H:%M:%S GMT", value.utctimetuple())
elif isinstance(value, str):
self.headers[hdrs.LAST_MODIFIED] = value
@property
def tcp_nodelay(self):
return self._tcp_nodelay
def set_tcp_nodelay(self, value):
value = bool(value)
self._tcp_nodelay = value
if value:
self._tcp_cork = False
if self._resp_impl is None:
return
if value:
self._resp_impl.transport.set_tcp_cork(False)
self._resp_impl.transport.set_tcp_nodelay(value)
@property
def tcp_cork(self):
return self._tcp_cork
def set_tcp_cork(self, value):
value = bool(value)
self._tcp_cork = value
if value:
self._tcp_nodelay = False
if self._resp_impl is None:
return
if value:
self._resp_impl.transport.set_tcp_nodelay(False)
self._resp_impl.transport.set_tcp_cork(value)
def _generate_content_type_header(self, CONTENT_TYPE=hdrs.CONTENT_TYPE):
params = '; '.join("%s=%s" % i for i in self._content_dict.items())
if params:
ctype = self._content_type + '; ' + params
else:
ctype = self._content_type
self.headers[CONTENT_TYPE] = ctype
def _start_pre_check(self, request):
if self._resp_impl is not None:
if self._req is not request:
raise RuntimeError(
"Response has been started with different request.")
else:
return self._resp_impl
else:
return None
def _do_start_compression(self, coding):
if coding != ContentCoding.identity:
self.headers[hdrs.CONTENT_ENCODING] = coding.value
self._resp_impl.add_compression_filter(coding.value)
self.content_length = None
def _start_compression(self, request):
if self._compression_force:
self._do_start_compression(self._compression_force)
else:
accept_encoding = request.headers.get(
hdrs.ACCEPT_ENCODING, '').lower()
for coding in ContentCoding:
if coding.value in accept_encoding:
self._do_start_compression(coding)
return
def start(self, request):
warnings.warn('use .prepare(request) instead', DeprecationWarning)
resp_impl = self._start_pre_check(request)
if resp_impl is not None:
return resp_impl
return self._start(request)
@asyncio.coroutine
def prepare(self, request):
resp_impl = self._start_pre_check(request)
if resp_impl is not None:
return resp_impl
yield from request.app.on_response_prepare.send(request, self)
return self._start(request)
def _start(self, request):
self._req = request
keep_alive = self._keep_alive
if keep_alive is None:
keep_alive = request.keep_alive
self._keep_alive = keep_alive
resp_impl = self._resp_impl = ResponseImpl(
request._writer,
self._status,
request.version,
not keep_alive,
self._reason)
self._copy_cookies()
if self._compression:
self._start_compression(request)
if self._chunked:
if request.version != HttpVersion11:
raise RuntimeError("Using chunked encoding is forbidden "
"for HTTP/{0.major}.{0.minor}".format(
request.version))
resp_impl.enable_chunked_encoding()
if self._chunk_size:
resp_impl.add_chunking_filter(self._chunk_size)
headers = self.headers.items()
for key, val in headers:
resp_impl.add_header(key, val)
resp_impl.transport.set_tcp_nodelay(self._tcp_nodelay)
resp_impl.transport.set_tcp_cork(self._tcp_cork)
self._send_headers(resp_impl)
return resp_impl
def _send_headers(self, resp_impl):
# Durty hack required for
# https://github.com/KeepSafe/aiohttp/issues/1093
# File sender may override it
resp_impl.send_headers()
def write(self, data):
assert isinstance(data, (bytes, bytearray, memoryview)), \
"data argument must be byte-ish (%r)" % type(data)
if self._eof_sent:
raise RuntimeError("Cannot call write() after write_eof()")
if self._resp_impl is None:
raise RuntimeError("Cannot call write() before start()")
if data:
return self._resp_impl.write(data)
else:
return ()
@asyncio.coroutine
def drain(self):
if self._resp_impl is None:
raise RuntimeError("Response has not been started")
yield from self._resp_impl.transport.drain()
@asyncio.coroutine
def write_eof(self):
if self._eof_sent:
return
if self._resp_impl is None:
raise RuntimeError("Response has not been started")
yield from self._resp_impl.write_eof()
self._eof_sent = True
def __repr__(self):
if self.started:
info = "{} {} ".format(self._req.method, self._req.path)
else:
info = "not started"
return "<{} {} {}>".format(self.__class__.__name__,
self.reason, info)
class Response(StreamResponse):
def __init__(self, *, body=None, status=200,
reason=None, text=None, headers=None, content_type=None,
charset=None):
if body is not None and text is not None:
raise ValueError("body and text are not allowed together")
if headers is None:
headers = CIMultiDict()
elif not isinstance(headers, (CIMultiDict, CIMultiDictProxy)):
headers = CIMultiDict(headers)
if content_type is not None and ";" in content_type:
raise ValueError("charset must not be in content_type "
"argument")
if text is not None:
if hdrs.CONTENT_TYPE in headers:
if content_type or charset:
raise ValueError("passing both Content-Type header and "
"content_type or charset params "
"is forbidden")
else:
# fast path for filling headers
if not isinstance(text, str):
raise TypeError("text argument must be str (%r)" %
type(text))
if content_type is None:
content_type = 'text/plain'
if charset is None:
charset = 'utf-8'
headers[hdrs.CONTENT_TYPE] = (
content_type + '; charset=' + charset)
body = text.encode(charset)
text = None
else:
if hdrs.CONTENT_TYPE in headers:
if content_type is not None or charset is not None:
raise ValueError("passing both Content-Type header and "
"content_type or charset params "
"is forbidden")
else:
if content_type is not None:
if charset is not None:
content_type += '; charset=' + charset
headers[hdrs.CONTENT_TYPE] = content_type
super().__init__(status=status, reason=reason, headers=headers)
self.set_tcp_cork(True)
if text is not None:
self.text = text
else:
self.body = body
@property
def body(self):
return self._body
@body.setter
def body(self, body):
if body is not None and not isinstance(body, bytes):
raise TypeError("body argument must be bytes (%r)" % type(body))
self._body = body
if body is not None:
self.content_length = len(body)
else:
self.content_length = 0
@property
def text(self):
if self._body is None:
return None
return self._body.decode(self.charset or 'utf-8')
@text.setter
def text(self, text):
if text is not None and not isinstance(text, str):
raise TypeError("text argument must be str (%r)" % type(text))
if self.content_type == 'application/octet-stream':
self.content_type = 'text/plain'
if self.charset is None:
self.charset = 'utf-8'
self.body = text.encode(self.charset)
@asyncio.coroutine
def write_eof(self):
try:
body = self._body
if (body is not None and
self._req.method != hdrs.METH_HEAD and
self._status not in [204, 304]):
self.write(body)
finally:
self.set_tcp_nodelay(True)
yield from super().write_eof()
def json_response(data=sentinel, *, text=None, body=None, status=200,
reason=None, headers=None, content_type='application/json',
dumps=json.dumps):
if data is not sentinel:
if text or body:
raise ValueError(
"only one of data, text, or body should be specified"
)
else:
text = dumps(data)
return Response(text=text, body=body, status=status, reason=reason,
headers=headers, content_type=content_type)

View File

@ -0,0 +1,825 @@
import abc
import asyncio
import collections
import inspect
import keyword
import os
import re
import sys
import warnings
from collections.abc import Container, Iterable, Sized
from pathlib import Path
from types import MappingProxyType
from urllib.parse import unquote, urlencode
from . import hdrs
from .abc import AbstractMatchInfo, AbstractRouter, AbstractView
from .file_sender import FileSender
from .protocol import HttpVersion11
from .web_exceptions import (HTTPExpectationFailed, HTTPForbidden,
HTTPMethodNotAllowed, HTTPNotFound)
from .web_reqrep import Response, StreamResponse
__all__ = ('UrlDispatcher', 'UrlMappingMatchInfo',
'AbstractResource', 'Resource', 'PlainResource', 'DynamicResource',
'ResourceAdapter',
'AbstractRoute', 'ResourceRoute',
'Route', 'PlainRoute', 'DynamicRoute', 'StaticRoute', 'View')
PY_35 = sys.version_info >= (3, 5)
HTTP_METHOD_RE = re.compile(r"^[0-9A-Za-z!#\$%&'\*\+\-\.\^_`\|~]+$")
class AbstractResource(Sized, Iterable):
def __init__(self, *, name=None):
self._name = name
@property
def name(self):
return self._name
@abc.abstractmethod # pragma: no branch
def url(self, **kwargs):
"""Construct url for resource with additional params."""
@asyncio.coroutine
@abc.abstractmethod # pragma: no branch
def resolve(self, method, path):
"""Resolve resource
Return (UrlMappingMatchInfo, allowed_methods) pair."""
@abc.abstractmethod
def get_info(self):
"""Return a dict with additional info useful for introspection"""
@staticmethod
def _append_query(url, query):
if query:
return url + "?" + urlencode(query)
else:
return url
class AbstractRoute(abc.ABC):
def __init__(self, method, handler, *,
expect_handler=None,
resource=None):
if expect_handler is None:
expect_handler = _defaultExpectHandler
assert asyncio.iscoroutinefunction(expect_handler), \
'Coroutine is expected, got {!r}'.format(expect_handler)
method = method.upper()
if not HTTP_METHOD_RE.match(method):
raise ValueError("{} is not allowed HTTP method".format(method))
assert callable(handler), handler
if asyncio.iscoroutinefunction(handler):
pass
elif inspect.isgeneratorfunction(handler):
warnings.warn("Bare generators are deprecated, "
"use @coroutine wrapper", DeprecationWarning)
elif (isinstance(handler, type) and
issubclass(handler, AbstractView)):
pass
else:
@asyncio.coroutine
def handler_wrapper(*args, **kwargs):
result = old_handler(*args, **kwargs)
if asyncio.iscoroutine(result):
result = yield from result
return result
old_handler = handler
handler = handler_wrapper
self._method = method
self._handler = handler
self._expect_handler = expect_handler
self._resource = resource
@property
def method(self):
return self._method
@property
def handler(self):
return self._handler
@property
@abc.abstractmethod
def name(self):
"""Optional route's name, always equals to resource's name."""
@property
def resource(self):
return self._resource
@abc.abstractmethod
def get_info(self):
"""Return a dict with additional info useful for introspection"""
@abc.abstractmethod # pragma: no branch
def url(self, **kwargs):
"""Construct url for route with additional params."""
@asyncio.coroutine
def handle_expect_header(self, request):
return (yield from self._expect_handler(request))
class UrlMappingMatchInfo(dict, AbstractMatchInfo):
def __init__(self, match_dict, route):
super().__init__(match_dict)
self._route = route
@property
def handler(self):
return self._route.handler
@property
def route(self):
return self._route
@property
def expect_handler(self):
return self._route.handle_expect_header
@property
def http_exception(self):
return None
def get_info(self):
return self._route.get_info()
def __repr__(self):
return "<MatchInfo {}: {}>".format(super().__repr__(), self._route)
class MatchInfoError(UrlMappingMatchInfo):
def __init__(self, http_exception):
self._exception = http_exception
super().__init__({}, SystemRoute(self._exception))
@property
def http_exception(self):
return self._exception
def __repr__(self):
return "<MatchInfoError {}: {}>".format(self._exception.status,
self._exception.reason)
@asyncio.coroutine
def _defaultExpectHandler(request):
"""Default handler for Expect header.
Just send "100 Continue" to client.
raise HTTPExpectationFailed if value of header is not "100-continue"
"""
expect = request.headers.get(hdrs.EXPECT)
if request.version == HttpVersion11:
if expect.lower() == "100-continue":
request.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n")
else:
raise HTTPExpectationFailed(text="Unknown Expect: %s" % expect)
class ResourceAdapter(AbstractResource):
def __init__(self, route):
assert isinstance(route, Route), \
'Instance of Route class is required, got {!r}'.format(route)
super().__init__(name=route.name)
self._route = route
route._resource = self
def url(self, **kwargs):
return self._route.url(**kwargs)
@asyncio.coroutine
def resolve(self, method, path):
route_method = self._route.method
allowed_methods = set()
match_dict = self._route.match(path)
if match_dict is not None:
allowed_methods.add(route_method)
if route_method == hdrs.METH_ANY or route_method == method:
return (UrlMappingMatchInfo(match_dict, self._route),
allowed_methods)
return None, allowed_methods
def get_info(self):
return self._route.get_info()
def __len__(self):
return 1
def __iter__(self):
yield self._route
class Resource(AbstractResource):
def __init__(self, *, name=None):
super().__init__(name=name)
self._routes = []
def add_route(self, method, handler, *,
expect_handler=None):
for route in self._routes:
if route.method == method or route.method == hdrs.METH_ANY:
raise RuntimeError("Added route will never be executed, "
"method {route.method} is "
"already registered".format(route=route))
route = ResourceRoute(method, handler, self,
expect_handler=expect_handler)
self.register_route(route)
return route
def register_route(self, route):
assert isinstance(route, ResourceRoute), \
'Instance of Route class is required, got {!r}'.format(route)
self._routes.append(route)
@asyncio.coroutine
def resolve(self, method, path):
allowed_methods = set()
match_dict = self._match(path)
if match_dict is None:
return None, allowed_methods
for route in self._routes:
route_method = route.method
allowed_methods.add(route_method)
if route_method == method or route_method == hdrs.METH_ANY:
return UrlMappingMatchInfo(match_dict, route), allowed_methods
else:
return None, allowed_methods
def __len__(self):
return len(self._routes)
def __iter__(self):
return iter(self._routes)
class PlainResource(Resource):
def __init__(self, path, *, name=None):
super().__init__(name=name)
self._path = path
def _match(self, path):
# string comparison is about 10 times faster than regexp matching
if self._path == path:
return {}
else:
return None
def get_info(self):
return {'path': self._path}
def url(self, *, query=None):
return self._append_query(self._path, query)
def __repr__(self):
name = "'" + self.name + "' " if self.name is not None else ""
return "<PlainResource {name} {path}".format(name=name,
path=self._path)
class DynamicResource(Resource):
def __init__(self, pattern, formatter, *, name=None):
super().__init__(name=name)
self._pattern = pattern
self._formatter = formatter
def _match(self, path):
match = self._pattern.match(path)
if match is None:
return None
else:
return {key: unquote(value) for key, value in
match.groupdict().items()}
def get_info(self):
return {'formatter': self._formatter,
'pattern': self._pattern}
def url(self, *, parts, query=None):
url = self._formatter.format_map(parts)
return self._append_query(url, query)
def __repr__(self):
name = "'" + self.name + "' " if self.name is not None else ""
return ("<DynamicResource {name} {formatter}"
.format(name=name, formatter=self._formatter))
class ResourceRoute(AbstractRoute):
"""A route with resource"""
def __init__(self, method, handler, resource, *,
expect_handler=None):
super().__init__(method, handler, expect_handler=expect_handler,
resource=resource)
def __repr__(self):
return "<ResourceRoute [{method}] {resource} -> {handler!r}".format(
method=self.method, resource=self._resource,
handler=self.handler)
@property
def name(self):
return self._resource.name
def url(self, **kwargs):
"""Construct url for route with additional params."""
return self._resource.url(**kwargs)
def get_info(self):
return self._resource.get_info()
_append_query = staticmethod(Resource._append_query)
class Route(AbstractRoute):
"""Old fashion route"""
def __init__(self, method, handler, name, *, expect_handler=None):
super().__init__(method, handler, expect_handler=expect_handler)
self._name = name
@property
def name(self):
return self._name
@abc.abstractmethod
def match(self, path):
"""Return dict with info for given path or
None if route cannot process path."""
_append_query = staticmethod(Resource._append_query)
class PlainRoute(Route):
def __init__(self, method, handler, name, path, *, expect_handler=None):
super().__init__(method, handler, name, expect_handler=expect_handler)
self._path = path
def match(self, path):
# string comparison is about 10 times faster than regexp matching
if self._path == path:
return {}
else:
return None
def url(self, *, query=None):
return self._append_query(self._path, query)
def get_info(self):
return {'path': self._path}
def __repr__(self):
name = "'" + self.name + "' " if self.name is not None else ""
return "<PlainRoute {name}[{method}] {path} -> {handler!r}".format(
name=name, method=self.method, path=self._path,
handler=self.handler)
class DynamicRoute(Route):
def __init__(self, method, handler, name, pattern, formatter, *,
expect_handler=None):
super().__init__(method, handler, name, expect_handler=expect_handler)
self._pattern = pattern
self._formatter = formatter
def match(self, path):
match = self._pattern.match(path)
if match is None:
return None
else:
return match.groupdict()
def url(self, *, parts, query=None):
url = self._formatter.format_map(parts)
return self._append_query(url, query)
def get_info(self):
return {'formatter': self._formatter,
'pattern': self._pattern}
def __repr__(self):
name = "'" + self.name + "' " if self.name is not None else ""
return ("<DynamicRoute {name}[{method}] {formatter} -> {handler!r}"
.format(name=name, method=self.method,
formatter=self._formatter, handler=self.handler))
class StaticRoute(Route):
def __init__(self, name, prefix, directory, *,
expect_handler=None, chunk_size=256*1024,
response_factory=StreamResponse,
show_index=False):
assert prefix.startswith('/'), prefix
assert prefix.endswith('/'), prefix
super().__init__(
'GET', self.handle, name, expect_handler=expect_handler)
self._prefix = prefix
self._prefix_len = len(self._prefix)
try:
directory = Path(directory)
if str(directory).startswith('~'):
directory = Path(os.path.expanduser(str(directory)))
directory = directory.resolve()
if not directory.is_dir():
raise ValueError('Not a directory')
except (FileNotFoundError, ValueError) as error:
raise ValueError(
"No directory exists at '{}'".format(directory)) from error
self._directory = directory
self._file_sender = FileSender(resp_factory=response_factory,
chunk_size=chunk_size)
self._show_index = show_index
def match(self, path):
if not path.startswith(self._prefix):
return None
return {'filename': path[self._prefix_len:]}
def url(self, *, filename, query=None):
if isinstance(filename, Path):
filename = str(filename)
while filename.startswith('/'):
filename = filename[1:]
url = self._prefix + filename
return self._append_query(url, query)
def get_info(self):
return {'directory': self._directory,
'prefix': self._prefix}
@asyncio.coroutine
def handle(self, request):
filename = unquote(request.match_info['filename'])
try:
filepath = self._directory.joinpath(filename).resolve()
filepath.relative_to(self._directory)
except (ValueError, FileNotFoundError) as error:
# relatively safe
raise HTTPNotFound() from error
except Exception as error:
# perm error or other kind!
request.app.logger.exception(error)
raise HTTPNotFound() from error
# on opening a dir, load it's contents if allowed
if filepath.is_dir():
if self._show_index:
try:
ret = Response(text=self._directory_as_html(filepath),
content_type="text/html")
except PermissionError:
raise HTTPForbidden()
else:
raise HTTPForbidden()
elif filepath.is_file():
ret = yield from self._file_sender.send(request, filepath)
else:
raise HTTPNotFound
return ret
def _directory_as_html(self, filepath):
"returns directory's index as html"
# sanity check
assert filepath.is_dir()
posix_dir_len = len(self._directory.as_posix())
# remove the beginning of posix path, so it would be relative
# to our added static path
relative_path_to_dir = filepath.as_posix()[posix_dir_len:]
index_of = "Index of /{}".format(relative_path_to_dir)
head = "<head>\n<title>{}</title>\n</head>".format(index_of)
h1 = "<h1>{}</h1>".format(index_of)
index_list = []
dir_index = filepath.iterdir()
for _file in sorted(dir_index):
# show file url as relative to static path
file_url = _file.as_posix()[posix_dir_len:]
# if file is a directory, add '/' to the end of the name
if _file.is_dir():
file_name = "{}/".format(_file.name)
else:
file_name = _file.name
index_list.append(
'<li><a href="{url}">{name}</a></li>'.format(url=file_url,
name=file_name)
)
ul = "<ul>\n{}\n</ul>".format('\n'.join(index_list))
body = "<body>\n{}\n{}\n</body>".format(h1, ul)
html = "<html>\n{}\n{}\n</html>".format(head, body)
return html
def __repr__(self):
name = "'" + self.name + "' " if self.name is not None else ""
return "<StaticRoute {name}[{method}] {path} -> {directory!r}".format(
name=name, method=self.method, path=self._prefix,
directory=self._directory)
class SystemRoute(Route):
def __init__(self, http_exception):
super().__init__(hdrs.METH_ANY, self._handler, None)
self._http_exception = http_exception
def url(self, **kwargs):
raise RuntimeError(".url() is not allowed for SystemRoute")
def match(self, path):
return None
def get_info(self):
return {'http_exception': self._http_exception}
@asyncio.coroutine
def _handler(self, request):
raise self._http_exception
@property
def status(self):
return self._http_exception.status
@property
def reason(self):
return self._http_exception.reason
def __repr__(self):
return "<SystemRoute {self.status}: {self.reason}>".format(self=self)
class View(AbstractView):
@asyncio.coroutine
def __iter__(self):
if self.request.method not in hdrs.METH_ALL:
self._raise_allowed_methods()
method = getattr(self, self.request.method.lower(), None)
if method is None:
self._raise_allowed_methods()
resp = yield from method()
return resp
if PY_35:
def __await__(self):
return (yield from self.__iter__())
def _raise_allowed_methods(self):
allowed_methods = {
m for m in hdrs.METH_ALL if hasattr(self, m.lower())}
raise HTTPMethodNotAllowed(self.request.method, allowed_methods)
class ResourcesView(Sized, Iterable, Container):
def __init__(self, resources):
self._resources = resources
def __len__(self):
return len(self._resources)
def __iter__(self):
yield from self._resources
def __contains__(self, resource):
return resource in self._resources
class RoutesView(Sized, Iterable, Container):
def __init__(self, resources):
self._routes = []
for resource in resources:
for route in resource:
self._routes.append(route)
def __len__(self):
return len(self._routes)
def __iter__(self):
yield from self._routes
def __contains__(self, route):
return route in self._routes
class UrlDispatcher(AbstractRouter, collections.abc.Mapping):
DYN = re.compile(r'^\{(?P<var>[a-zA-Z][_a-zA-Z0-9]*)\}$')
DYN_WITH_RE = re.compile(
r'^\{(?P<var>[a-zA-Z][_a-zA-Z0-9]*):(?P<re>.+)\}$')
GOOD = r'[^{}/]+'
ROUTE_RE = re.compile(r'(\{[_a-zA-Z][^{}]*(?:\{[^{}]*\}[^{}]*)*\})')
NAME_SPLIT_RE = re.compile('[.:-]')
def __init__(self):
super().__init__()
self._resources = []
self._named_resources = {}
@asyncio.coroutine
def resolve(self, request):
path = request.raw_path
method = request.method
allowed_methods = set()
for resource in self._resources:
match_dict, allowed = yield from resource.resolve(method, path)
if match_dict is not None:
return match_dict
else:
allowed_methods |= allowed
else:
if allowed_methods:
return MatchInfoError(HTTPMethodNotAllowed(method,
allowed_methods))
else:
return MatchInfoError(HTTPNotFound())
def __iter__(self):
return iter(self._named_resources)
def __len__(self):
return len(self._named_resources)
def __contains__(self, name):
return name in self._named_resources
def __getitem__(self, name):
return self._named_resources[name]
def resources(self):
return ResourcesView(self._resources)
def routes(self):
return RoutesView(self._resources)
def named_resources(self):
return MappingProxyType(self._named_resources)
def named_routes(self):
# NB: it's ambiguous but it's really resources.
warnings.warn("Use .named_resources instead", DeprecationWarning)
return self.named_resources()
def register_route(self, route):
warnings.warn("Use resource-based interface", DeprecationWarning)
resource = ResourceAdapter(route)
self._reg_resource(resource)
def _reg_resource(self, resource):
assert isinstance(resource, AbstractResource), \
'Instance of AbstractResource class is required, got {!r}'.format(
resource)
name = resource.name
if name is not None:
parts = self.NAME_SPLIT_RE.split(name)
for part in parts:
if not part.isidentifier() or keyword.iskeyword(part):
raise ValueError('Incorrect route name {!r}, '
'the name should be a sequence of '
'python identifiers separated '
'by dash, dot or column'.format(name))
if name in self._named_resources:
raise ValueError('Duplicate {!r}, '
'already handled by {!r}'
.format(name, self._named_resources[name]))
self._named_resources[name] = resource
self._resources.append(resource)
def add_resource(self, path, *, name=None):
if not path.startswith('/'):
raise ValueError("path should be started with /")
if not ('{' in path or '}' in path or self.ROUTE_RE.search(path)):
resource = PlainResource(path, name=name)
self._reg_resource(resource)
return resource
pattern = ''
formatter = ''
for part in self.ROUTE_RE.split(path):
match = self.DYN.match(part)
if match:
pattern += '(?P<{}>{})'.format(match.group('var'), self.GOOD)
formatter += '{' + match.group('var') + '}'
continue
match = self.DYN_WITH_RE.match(part)
if match:
pattern += '(?P<{var}>{re})'.format(**match.groupdict())
formatter += '{' + match.group('var') + '}'
continue
if '{' in part or '}' in part:
raise ValueError("Invalid path '{}'['{}']".format(path, part))
formatter += part
pattern += re.escape(part)
try:
compiled = re.compile('^' + pattern + '$')
except re.error as exc:
raise ValueError(
"Bad pattern '{}': {}".format(pattern, exc)) from None
resource = DynamicResource(compiled, formatter, name=name)
self._reg_resource(resource)
return resource
def add_route(self, method, path, handler,
*, name=None, expect_handler=None):
resource = self.add_resource(path, name=name)
return resource.add_route(method, handler,
expect_handler=expect_handler)
def add_static(self, prefix, path, *, name=None, expect_handler=None,
chunk_size=256*1024, response_factory=StreamResponse,
show_index=False):
"""Add static files view.
prefix - url prefix
path - folder with files
"""
assert prefix.startswith('/')
if not prefix.endswith('/'):
prefix += '/'
route = StaticRoute(name, prefix, path,
expect_handler=expect_handler,
chunk_size=chunk_size,
response_factory=response_factory,
show_index=show_index)
self.register_route(route)
return route
def add_head(self, *args, **kwargs):
"""
Shortcut for add_route with method HEAD
"""
return self.add_route(hdrs.METH_HEAD, *args, **kwargs)
def add_get(self, *args, **kwargs):
"""
Shortcut for add_route with method GET
"""
return self.add_route(hdrs.METH_GET, *args, **kwargs)
def add_post(self, *args, **kwargs):
"""
Shortcut for add_route with method POST
"""
return self.add_route(hdrs.METH_POST, *args, **kwargs)
def add_put(self, *args, **kwargs):
"""
Shortcut for add_route with method PUT
"""
return self.add_route(hdrs.METH_PUT, *args, **kwargs)
def add_patch(self, *args, **kwargs):
"""
Shortcut for add_route with method PATCH
"""
return self.add_route(hdrs.METH_PATCH, *args, **kwargs)
def add_delete(self, *args, **kwargs):
"""
Shortcut for add_route with method DELETE
"""
return self.add_route(hdrs.METH_DELETE, *args, **kwargs)

View File

@ -0,0 +1,320 @@
import asyncio
import json
import sys
import warnings
from collections import namedtuple
from . import Timeout, hdrs
from ._ws_impl import (CLOSED_MESSAGE, WebSocketError, WSMessage, WSMsgType,
do_handshake)
from .errors import ClientDisconnectedError, HttpProcessingError
from .web_exceptions import (HTTPBadRequest, HTTPInternalServerError,
HTTPMethodNotAllowed)
from .web_reqrep import StreamResponse
__all__ = ('WebSocketResponse', 'WebSocketReady', 'MsgType', 'WSMsgType',)
PY_35 = sys.version_info >= (3, 5)
PY_352 = sys.version_info >= (3, 5, 2)
THRESHOLD_CONNLOST_ACCESS = 5
# deprecated since 1.0
MsgType = WSMsgType
class WebSocketReady(namedtuple('WebSocketReady', 'ok protocol')):
def __bool__(self):
return self.ok
class WebSocketResponse(StreamResponse):
def __init__(self, *,
timeout=10.0, autoclose=True, autoping=True, protocols=()):
super().__init__(status=101)
self._protocols = protocols
self._protocol = None
self._writer = None
self._reader = None
self._closed = False
self._closing = False
self._conn_lost = 0
self._close_code = None
self._loop = None
self._waiting = False
self._exception = None
self._timeout = timeout
self._autoclose = autoclose
self._autoping = autoping
@asyncio.coroutine
def prepare(self, request):
# make pre-check to don't hide it by do_handshake() exceptions
resp_impl = self._start_pre_check(request)
if resp_impl is not None:
return resp_impl
parser, protocol, writer = self._pre_start(request)
resp_impl = yield from super().prepare(request)
self._post_start(request, parser, protocol, writer)
return resp_impl
def _pre_start(self, request):
try:
status, headers, parser, writer, protocol = do_handshake(
request.method, request.headers, request.transport,
self._protocols)
except HttpProcessingError as err:
if err.code == 405:
raise HTTPMethodNotAllowed(
request.method, [hdrs.METH_GET], body=b'')
elif err.code == 400:
raise HTTPBadRequest(text=err.message, headers=err.headers)
else: # pragma: no cover
raise HTTPInternalServerError() from err
if self.status != status:
self.set_status(status)
for k, v in headers:
self.headers[k] = v
self.force_close()
return parser, protocol, writer
def _post_start(self, request, parser, protocol, writer):
self._reader = request._reader.set_parser(parser)
self._writer = writer
self._protocol = protocol
self._loop = request.app.loop
def start(self, request):
warnings.warn('use .prepare(request) instead', DeprecationWarning)
# make pre-check to don't hide it by do_handshake() exceptions
resp_impl = self._start_pre_check(request)
if resp_impl is not None:
return resp_impl
parser, protocol, writer = self._pre_start(request)
resp_impl = super().start(request)
self._post_start(request, parser, protocol, writer)
return resp_impl
def can_prepare(self, request):
if self._writer is not None:
raise RuntimeError('Already started')
try:
_, _, _, _, protocol = do_handshake(
request.method, request.headers, request.transport,
self._protocols)
except HttpProcessingError:
return WebSocketReady(False, None)
else:
return WebSocketReady(True, protocol)
def can_start(self, request):
warnings.warn('use .can_prepare(request) instead', DeprecationWarning)
return self.can_prepare(request)
@property
def closed(self):
return self._closed
@property
def close_code(self):
return self._close_code
@property
def protocol(self):
return self._protocol
def exception(self):
return self._exception
def ping(self, message='b'):
if self._writer is None:
raise RuntimeError('Call .prepare() first')
if self._closed:
raise RuntimeError('websocket connection is closing')
self._writer.ping(message)
def pong(self, message='b'):
# unsolicited pong
if self._writer is None:
raise RuntimeError('Call .prepare() first')
if self._closed:
raise RuntimeError('websocket connection is closing')
self._writer.pong(message)
def send_str(self, data):
if self._writer is None:
raise RuntimeError('Call .prepare() first')
if self._closed:
raise RuntimeError('websocket connection is closing')
if not isinstance(data, str):
raise TypeError('data argument must be str (%r)' % type(data))
self._writer.send(data, binary=False)
def send_bytes(self, data):
if self._writer is None:
raise RuntimeError('Call .prepare() first')
if self._closed:
raise RuntimeError('websocket connection is closing')
if not isinstance(data, (bytes, bytearray, memoryview)):
raise TypeError('data argument must be byte-ish (%r)' %
type(data))
self._writer.send(data, binary=True)
def send_json(self, data, *, dumps=json.dumps):
self.send_str(dumps(data))
@asyncio.coroutine
def write_eof(self):
if self._eof_sent:
return
if self._resp_impl is None:
raise RuntimeError("Response has not been started")
yield from self.close()
self._eof_sent = True
@asyncio.coroutine
def close(self, *, code=1000, message=b''):
if self._writer is None:
raise RuntimeError('Call .prepare() first')
if not self._closed:
self._closed = True
try:
self._writer.close(code, message)
except (asyncio.CancelledError, asyncio.TimeoutError):
self._close_code = 1006
raise
except Exception as exc:
self._close_code = 1006
self._exception = exc
return True
if self._closing:
return True
begin = self._loop.time()
while self._loop.time() - begin < self._timeout:
try:
with Timeout(timeout=self._timeout,
loop=self._loop):
msg = yield from self._reader.read()
except asyncio.CancelledError:
self._close_code = 1006
raise
except Exception as exc:
self._close_code = 1006
self._exception = exc
return True
if msg.type == WSMsgType.CLOSE:
self._close_code = msg.data
return True
self._close_code = 1006
self._exception = asyncio.TimeoutError()
return True
else:
return False
@asyncio.coroutine
def receive(self):
if self._reader is None:
raise RuntimeError('Call .prepare() first')
if self._waiting:
raise RuntimeError('Concurrent call to receive() is not allowed')
self._waiting = True
try:
while True:
if self._closed:
self._conn_lost += 1
if self._conn_lost >= THRESHOLD_CONNLOST_ACCESS:
raise RuntimeError('WebSocket connection is closed.')
return CLOSED_MESSAGE
try:
msg = yield from self._reader.read()
except (asyncio.CancelledError, asyncio.TimeoutError):
raise
except WebSocketError as exc:
self._close_code = exc.code
yield from self.close(code=exc.code)
return WSMessage(WSMsgType.ERROR, exc, None)
except ClientDisconnectedError:
self._closed = True
self._close_code = 1006
return WSMessage(WSMsgType.CLOSE, None, None)
except Exception as exc:
self._exception = exc
self._closing = True
self._close_code = 1006
yield from self.close()
return WSMessage(WSMsgType.ERROR, exc, None)
if msg.type == WSMsgType.CLOSE:
self._closing = True
self._close_code = msg.data
if not self._closed and self._autoclose:
yield from self.close()
return msg
if msg.type == WSMsgType.PING and self._autoping:
self.pong(msg.data)
elif msg.type == WSMsgType.PONG and self._autoping:
continue
else:
return msg
finally:
self._waiting = False
@asyncio.coroutine
def receive_msg(self):
warnings.warn(
'receive_msg() coroutine is deprecated. use receive() instead',
DeprecationWarning)
return (yield from self.receive())
@asyncio.coroutine
def receive_str(self):
msg = yield from self.receive()
if msg.type != WSMsgType.TEXT:
raise TypeError(
"Received message {}:{!r} is not str".format(msg.type,
msg.data))
return msg.data
@asyncio.coroutine
def receive_bytes(self):
msg = yield from self.receive()
if msg.type != WSMsgType.BINARY:
raise TypeError(
"Received message {}:{!r} is not bytes".format(msg.type,
msg.data))
return msg.data
@asyncio.coroutine
def receive_json(self, *, loads=json.loads):
data = yield from self.receive_str()
return loads(data)
def write(self, data):
raise RuntimeError("Cannot call .write() for websocket")
if PY_35:
def __aiter__(self):
return self
if not PY_352: # pragma: no cover
__aiter__ = asyncio.coroutine(__aiter__)
@asyncio.coroutine
def __anext__(self):
msg = yield from self.receive()
if msg.type == WSMsgType.CLOSE:
raise StopAsyncIteration # NOQA
return msg

View File

@ -0,0 +1,195 @@
"""Async gunicorn worker for aiohttp.web"""
import asyncio
import os
import re
import signal
import ssl
import sys
import gunicorn.workers.base as base
from gunicorn.config import AccessLogFormat as GunicornAccessLogFormat
from aiohttp.helpers import AccessLogger, ensure_future
__all__ = ('GunicornWebWorker', 'GunicornUVLoopWebWorker')
class GunicornWebWorker(base.Worker):
DEFAULT_AIOHTTP_LOG_FORMAT = AccessLogger.LOG_FORMAT
DEFAULT_GUNICORN_LOG_FORMAT = GunicornAccessLogFormat.default
def __init__(self, *args, **kw): # pragma: no cover
super().__init__(*args, **kw)
self.servers = {}
self.exit_code = 0
def init_process(self):
# create new event_loop after fork
asyncio.get_event_loop().close()
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
super().init_process()
def run(self):
self.loop.run_until_complete(self.wsgi.startup())
self._runner = ensure_future(self._run(), loop=self.loop)
try:
self.loop.run_until_complete(self._runner)
finally:
self.loop.close()
sys.exit(self.exit_code)
def make_handler(self, app):
return app.make_handler(
logger=self.log,
slow_request_timeout=self.cfg.timeout,
keepalive_timeout=self.cfg.keepalive,
access_log=self.log.access_log,
access_log_format=self._get_valid_log_format(
self.cfg.access_log_format))
@asyncio.coroutine
def close(self):
if self.servers:
servers = self.servers
self.servers = None
# stop accepting connections
for server, handler in servers.items():
self.log.info("Stopping server: %s, connections: %s",
self.pid, len(handler.connections))
server.close()
yield from server.wait_closed()
# send on_shutdown event
yield from self.wsgi.shutdown()
# stop alive connections
tasks = [
handler.finish_connections(
timeout=self.cfg.graceful_timeout / 100 * 95)
for handler in servers.values()]
yield from asyncio.gather(*tasks, loop=self.loop)
# cleanup application
yield from self.wsgi.cleanup()
@asyncio.coroutine
def _run(self):
ctx = self._create_ssl_context(self.cfg) if self.cfg.is_ssl else None
for sock in self.sockets:
handler = self.make_handler(self.wsgi)
srv = yield from self.loop.create_server(handler, sock=sock.sock,
ssl=ctx)
self.servers[srv] = handler
# If our parent changed then we shut down.
pid = os.getpid()
try:
while self.alive:
self.notify()
cnt = sum(handler.requests_count
for handler in self.servers.values())
if self.cfg.max_requests and cnt > self.cfg.max_requests:
self.alive = False
self.log.info("Max requests, shutting down: %s", self)
elif pid == os.getpid() and self.ppid != os.getppid():
self.alive = False
self.log.info("Parent changed, shutting down: %s", self)
else:
yield from asyncio.sleep(1.0, loop=self.loop)
except BaseException:
pass
yield from self.close()
def init_signals(self):
# Set up signals through the event loop API.
self.loop.add_signal_handler(signal.SIGQUIT, self.handle_quit,
signal.SIGQUIT, None)
self.loop.add_signal_handler(signal.SIGTERM, self.handle_exit,
signal.SIGTERM, None)
self.loop.add_signal_handler(signal.SIGINT, self.handle_quit,
signal.SIGINT, None)
self.loop.add_signal_handler(signal.SIGWINCH, self.handle_winch,
signal.SIGWINCH, None)
self.loop.add_signal_handler(signal.SIGUSR1, self.handle_usr1,
signal.SIGUSR1, None)
self.loop.add_signal_handler(signal.SIGABRT, self.handle_abort,
signal.SIGABRT, None)
# Don't let SIGTERM and SIGUSR1 disturb active requests
# by interrupting system calls
signal.siginterrupt(signal.SIGTERM, False)
signal.siginterrupt(signal.SIGUSR1, False)
def handle_quit(self, sig, frame):
self.alive = False
def handle_abort(self, sig, frame):
self.alive = False
self.exit_code = 1
@staticmethod
def _create_ssl_context(cfg):
""" Creates SSLContext instance for usage in asyncio.create_server.
See ssl.SSLSocket.__init__ for more details.
"""
ctx = ssl.SSLContext(cfg.ssl_version)
ctx.load_cert_chain(cfg.certfile, cfg.keyfile)
ctx.verify_mode = cfg.cert_reqs
if cfg.ca_certs:
ctx.load_verify_locations(cfg.ca_certs)
if cfg.ciphers:
ctx.set_ciphers(cfg.ciphers)
return ctx
def _get_valid_log_format(self, source_format):
if source_format == self.DEFAULT_GUNICORN_LOG_FORMAT:
return self.DEFAULT_AIOHTTP_LOG_FORMAT
elif re.search(r'%\([^\)]+\)', source_format):
raise ValueError(
"Gunicorn's style options in form of `%(name)s` are not "
"supported for the log formatting. Please use aiohttp's "
"format specification to configure access log formatting: "
"http://aiohttp.readthedocs.io/en/stable/logging.html"
"#format-specification"
)
else:
return source_format
class GunicornUVLoopWebWorker(GunicornWebWorker):
def init_process(self):
import uvloop
# Close any existing event loop before setting a
# new policy.
asyncio.get_event_loop().close()
# Setup uvloop policy, so that every
# asyncio.get_event_loop() will create an instance
# of uvloop event loop.
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
super().init_process()

View File

@ -0,0 +1,235 @@
"""wsgi server.
TODO:
* proxy protocol
* x-forward security
* wsgi file support (os.sendfile)
"""
import asyncio
import inspect
import io
import os
import socket
import sys
from urllib.parse import urlsplit
import aiohttp
from aiohttp import hdrs, server
__all__ = ('WSGIServerHttpProtocol',)
class WSGIServerHttpProtocol(server.ServerHttpProtocol):
"""HTTP Server that implements the Python WSGI protocol.
It uses 'wsgi.async' of 'True'. 'wsgi.input' can behave differently
depends on 'readpayload' constructor parameter. If readpayload is set to
True, wsgi server reads all incoming data into BytesIO object and
sends it as 'wsgi.input' environ var. If readpayload is set to false
'wsgi.input' is a StreamReader and application should read incoming
data with "yield from environ['wsgi.input'].read()". It defaults to False.
"""
SCRIPT_NAME = os.environ.get('SCRIPT_NAME', '')
def __init__(self, app, readpayload=False, is_ssl=False, *args, **kw):
super().__init__(*args, **kw)
self.wsgi = app
self.is_ssl = is_ssl
self.readpayload = readpayload
def create_wsgi_response(self, message):
return WsgiResponse(self.writer, message)
def create_wsgi_environ(self, message, payload):
uri_parts = urlsplit(message.path)
environ = {
'wsgi.input': payload,
'wsgi.errors': sys.stderr,
'wsgi.version': (1, 0),
'wsgi.async': True,
'wsgi.multithread': False,
'wsgi.multiprocess': False,
'wsgi.run_once': False,
'wsgi.file_wrapper': FileWrapper,
'SERVER_SOFTWARE': aiohttp.HttpMessage.SERVER_SOFTWARE,
'REQUEST_METHOD': message.method,
'QUERY_STRING': uri_parts.query or '',
'RAW_URI': message.path,
'SERVER_PROTOCOL': 'HTTP/%s.%s' % message.version
}
script_name = self.SCRIPT_NAME
for hdr_name, hdr_value in message.headers.items():
hdr_name = hdr_name.upper()
if hdr_name == 'SCRIPT_NAME':
script_name = hdr_value
elif hdr_name == 'CONTENT-TYPE':
environ['CONTENT_TYPE'] = hdr_value
continue
elif hdr_name == 'CONTENT-LENGTH':
environ['CONTENT_LENGTH'] = hdr_value
continue
key = 'HTTP_%s' % hdr_name.replace('-', '_')
if key in environ:
hdr_value = '%s,%s' % (environ[key], hdr_value)
environ[key] = hdr_value
url_scheme = environ.get('HTTP_X_FORWARDED_PROTO')
if url_scheme is None:
url_scheme = 'https' if self.is_ssl else 'http'
environ['wsgi.url_scheme'] = url_scheme
# authors should be aware that REMOTE_HOST and REMOTE_ADDR
# may not qualify the remote addr
# also SERVER_PORT variable MUST be set to the TCP/IP port number on
# which this request is received from the client.
# http://www.ietf.org/rfc/rfc3875
family = self.transport.get_extra_info('socket').family
if family in (socket.AF_INET, socket.AF_INET6):
peername = self.transport.get_extra_info('peername')
environ['REMOTE_ADDR'] = peername[0]
environ['REMOTE_PORT'] = str(peername[1])
http_host = message.headers.get("HOST", None)
if http_host:
hostport = http_host.split(":")
environ['SERVER_NAME'] = hostport[0]
if len(hostport) > 1:
environ['SERVER_PORT'] = str(hostport[1])
else:
environ['SERVER_PORT'] = '80'
else:
# SERVER_NAME should be set to value of Host header, but this
# header is not required. In this case we shoud set it to local
# address of socket
sockname = self.transport.get_extra_info('sockname')
environ['SERVER_NAME'] = sockname[0]
environ['SERVER_PORT'] = str(sockname[1])
else:
# We are behind reverse proxy, so get all vars from headers
for header in ('REMOTE_ADDR', 'REMOTE_PORT',
'SERVER_NAME', 'SERVER_PORT'):
environ[header] = message.headers.get(header, '')
path_info = uri_parts.path
if script_name:
path_info = path_info.split(script_name, 1)[-1]
environ['PATH_INFO'] = path_info
environ['SCRIPT_NAME'] = script_name
environ['async.reader'] = self.reader
environ['async.writer'] = self.writer
return environ
@asyncio.coroutine
def handle_request(self, message, payload):
"""Handle a single HTTP request"""
now = self._loop.time()
if self.readpayload:
wsgiinput = io.BytesIO()
wsgiinput.write((yield from payload.read()))
wsgiinput.seek(0)
payload = wsgiinput
environ = self.create_wsgi_environ(message, payload)
response = self.create_wsgi_response(message)
riter = self.wsgi(environ, response.start_response)
if isinstance(riter, asyncio.Future) or inspect.isgenerator(riter):
riter = yield from riter
resp = response.response
try:
for item in riter:
if isinstance(item, asyncio.Future):
item = yield from item
yield from resp.write(item)
yield from resp.write_eof()
finally:
if hasattr(riter, 'close'):
riter.close()
if resp.keep_alive():
self.keep_alive(True)
self.log_access(
message, environ, response.response, self._loop.time() - now)
class FileWrapper:
"""Custom file wrapper."""
def __init__(self, fobj, chunk_size=8192):
self.fobj = fobj
self.chunk_size = chunk_size
if hasattr(fobj, 'close'):
self.close = fobj.close
def __iter__(self):
return self
def __next__(self):
data = self.fobj.read(self.chunk_size)
if data:
return data
raise StopIteration
class WsgiResponse:
"""Implementation of start_response() callable as specified by PEP 3333"""
status = None
HOP_HEADERS = {
hdrs.CONNECTION,
hdrs.KEEP_ALIVE,
hdrs.PROXY_AUTHENTICATE,
hdrs.PROXY_AUTHORIZATION,
hdrs.TE,
hdrs.TRAILER,
hdrs.TRANSFER_ENCODING,
hdrs.UPGRADE,
}
def __init__(self, writer, message):
self.writer = writer
self.message = message
def start_response(self, status, headers, exc_info=None):
if exc_info:
try:
if self.status:
raise exc_info[1]
finally:
exc_info = None
status_code = int(status.split(' ', 1)[0])
self.status = status
resp = self.response = aiohttp.Response(
self.writer, status_code,
self.message.version, self.message.should_close)
resp.HOP_HEADERS = self.HOP_HEADERS
for name, value in headers:
resp.add_header(name, value)
if resp.has_chunked_hdr:
resp.enable_chunked_encoding()
# send headers immediately for websocket connection
if status_code == 101 and resp.upgrade and resp.websocket:
resp.send_headers()
else:
resp._send_headers = True
return self.response.write

View File

@ -0,0 +1,76 @@
async-timeout
=============
asyncio-compatible timeout context manager.
Usage example
-------------
The context manager is useful in cases when you want to apply timeout
logic around block of code or in cases when ``asyncio.wait_for()`` is
not suitable. Also it's much faster than ``asyncio.wait_for()``
because ``timeout`` doesn't create a new task.
The ``timeout(timeout, *, loop=None)`` call returns a context manager
that cancels a block on *timeout* expiring::
with timeout(1.5):
yield from inner()
1. If ``inner()`` is executed faster than in ``1.5`` seconds nothing
happens.
2. Otherwise ``inner()`` is cancelled internally by sending
``asyncio.CancelledError`` into but ``asyncio.TimeoutError`` is
raised outside of context manager scope.
*timeout* parameter could be ``None`` for skipping timeout functionality.
Installation
------------
::
$ pip install async-timeout
The library is Python 3 only!
Authors and License
-------------------
The module is written by Andrew Svetlov.
It's *Apache 2* licensed and freely available.
CHANGES
=======
1.2.1 (2017-05-02)
------------------
* Support unpublished event loop's "current_task" api.
1.2.0 (2017-03-11)
------------------
* Extra check on context manager exit
* 0 is no-op timeout
1.1.0 (2016-10-20)
------------------
* Rename to `async-timeout`
1.0.0 (2016-09-09)
------------------
* The first release.

View File

@ -0,0 +1 @@
pip

View File

@ -0,0 +1,95 @@
Metadata-Version: 2.0
Name: async-timeout
Version: 1.2.1
Summary: Timeout context manager for asyncio programs
Home-page: https://github.com/aio-libs/async_timeout/
Author: Andrew Svetlov
Author-email: andrew.svetlov@gmail.com
License: Apache 2
Platform: UNKNOWN
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Intended Audience :: Developers
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.4
Classifier: Programming Language :: Python :: 3.5
Classifier: Programming Language :: Python :: 3.6
Classifier: Topic :: Internet :: WWW/HTTP
Classifier: Framework :: AsyncIO
async-timeout
=============
asyncio-compatible timeout context manager.
Usage example
-------------
The context manager is useful in cases when you want to apply timeout
logic around block of code or in cases when ``asyncio.wait_for()`` is
not suitable. Also it's much faster than ``asyncio.wait_for()``
because ``timeout`` doesn't create a new task.
The ``timeout(timeout, *, loop=None)`` call returns a context manager
that cancels a block on *timeout* expiring::
with timeout(1.5):
yield from inner()
1. If ``inner()`` is executed faster than in ``1.5`` seconds nothing
happens.
2. Otherwise ``inner()`` is cancelled internally by sending
``asyncio.CancelledError`` into but ``asyncio.TimeoutError`` is
raised outside of context manager scope.
*timeout* parameter could be ``None`` for skipping timeout functionality.
Installation
------------
::
$ pip install async-timeout
The library is Python 3 only!
Authors and License
-------------------
The module is written by Andrew Svetlov.
It's *Apache 2* licensed and freely available.
CHANGES
=======
1.2.1 (2017-05-02)
------------------
* Support unpublished event loop's "current_task" api.
1.2.0 (2017-03-11)
------------------
* Extra check on context manager exit
* 0 is no-op timeout
1.1.0 (2016-10-20)
------------------
* Rename to `async-timeout`
1.0.0 (2016-09-09)
------------------
* The first release.

View File

@ -0,0 +1,9 @@
async_timeout/__init__.py,sha256=5ONrYCJMKAzWV1qcK-qhUkRKPGEQfEzCWoxVvwAAru4,1913
async_timeout-1.2.1.dist-info/DESCRIPTION.rst,sha256=IQuZGR3YfIcIGhWshP8gce8HXNCvMhxA-ov9oroqnI8,1430
async_timeout-1.2.1.dist-info/METADATA,sha256=cOJx0VKD1jtlzkp0JAlBu4nzZ5cRX35I8PCQW4CG5bE,2117
async_timeout-1.2.1.dist-info/RECORD,,
async_timeout-1.2.1.dist-info/WHEEL,sha256=rNo05PbNqwnXiIHFsYm0m22u4Zm6YJtugFG2THx4w3g,92
async_timeout-1.2.1.dist-info/metadata.json,sha256=FwV6Nc2u0faHG1tFFHTGglKZida7muCQA-9_jH0qq5E,889
async_timeout-1.2.1.dist-info/top_level.txt,sha256=9oM4e7Twq8iD_7_Q3Mz0E6GPIB6vJvRFo-UBwUQtBDU,14
async_timeout-1.2.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
async_timeout/__pycache__/__init__.cpython-36.pyc,,

View File

@ -0,0 +1,5 @@
Wheel-Version: 1.0
Generator: bdist_wheel (0.29.0)
Root-Is-Purelib: true
Tag: py3-none-any

View File

@ -0,0 +1 @@
async_timeout

View File

@ -0,0 +1,62 @@
import asyncio
__version__ = '1.2.1'
class timeout:
"""timeout context manager.
Useful in cases when you want to apply timeout logic around block
of code or in cases when asyncio.wait_for is not suitable. For example:
>>> with timeout(0.001):
... async with aiohttp.get('https://github.com') as r:
... await r.text()
timeout - value in seconds or None to disable timeout logic
loop - asyncio compatible event loop
"""
def __init__(self, timeout, *, loop=None):
if timeout is not None and timeout == 0:
timeout = None
self._timeout = timeout
if loop is None:
loop = asyncio.get_event_loop()
self._loop = loop
self._task = None
self._cancelled = False
self._cancel_handler = None
def __enter__(self):
if self._timeout is not None:
self._task = current_task(self._loop)
if self._task is None:
raise RuntimeError('Timeout context manager should be used '
'inside a task')
self._cancel_handler = self._loop.call_later(
self._timeout, self._cancel_task)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is asyncio.CancelledError and self._cancelled:
self._cancel_handler = None
self._task = None
raise asyncio.TimeoutError from None
if self._timeout is not None and self._cancel_handler is not None:
self._cancel_handler.cancel()
self._cancel_handler = None
self._task = None
def _cancel_task(self):
self._cancelled = self._task.cancel()
def current_task(loop):
task = asyncio.Task.current_task(loop=loop)
if task is None:
if hasattr(loop, 'current_task'):
task = loop.current_task()
return task

View File

@ -0,0 +1,49 @@
Certifi: Python SSL Certificates
================================
`Certifi`_ is a carefully curated collection of Root Certificates for
validating the trustworthiness of SSL certificates while verifying the identity
of TLS hosts. It has been extracted from the `Requests`_ project.
Installation
------------
``certifi`` is available on PyPI. Simply install it with ``pip``::
$ pip install certifi
Usage
-----
To reference the installed certificate authority (CA) bundle, you can use the
built-in function::
>>> import certifi
>>> certifi.where()
'/usr/local/lib/python2.7/site-packages/certifi/cacert.pem'
Enjoy!
1024-bit Root Certificates
~~~~~~~~~~~~~~~~~~~~~~~~~~
Browsers and certificate authorities have concluded that 1024-bit keys are
unacceptably weak for certificates, particularly root certificates. For this
reason, Mozilla has removed any weak (i.e. 1024-bit key) certificate from its
bundle, replacing it with an equivalent strong (i.e. 2048-bit or greater key)
certificate from the same CA. Because Mozilla removed these certificates from
its bundle, ``certifi`` removed them as well.
Unfortunately, old versions of OpenSSL (less than 1.0.2) sometimes fail to
validate certificate chains that use the strong roots. For this reason, if you
fail to validate a certificate using the ``certifi.where()`` mechanism, you can
intentionally re-add the 1024-bit roots back into your bundle by calling
``certifi.old_where()`` instead. This is not recommended in production: if at
all possible you should upgrade to a newer OpenSSL. However, if you have no
other option, this may work for you.
.. _`Certifi`: http://certifi.io/en/latest/
.. _`Requests`: http://docs.python-requests.org/en/latest/

View File

@ -0,0 +1 @@
pip

View File

@ -0,0 +1,68 @@
Metadata-Version: 2.0
Name: certifi
Version: 2017.4.17
Summary: Python package for providing Mozilla's CA Bundle.
Home-page: http://certifi.io/
Author: Kenneth Reitz
Author-email: me@kennethreitz.com
License: ISC
Platform: UNKNOWN
Classifier: Development Status :: 5 - Production/Stable
Classifier: Intended Audience :: Developers
Classifier: Natural Language :: English
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 2.6
Classifier: Programming Language :: Python :: 2.7
Classifier: Programming Language :: Python :: 3.3
Classifier: Programming Language :: Python :: 3.4
Classifier: Programming Language :: Python :: 3.5
Certifi: Python SSL Certificates
================================
`Certifi`_ is a carefully curated collection of Root Certificates for
validating the trustworthiness of SSL certificates while verifying the identity
of TLS hosts. It has been extracted from the `Requests`_ project.
Installation
------------
``certifi`` is available on PyPI. Simply install it with ``pip``::
$ pip install certifi
Usage
-----
To reference the installed certificate authority (CA) bundle, you can use the
built-in function::
>>> import certifi
>>> certifi.where()
'/usr/local/lib/python2.7/site-packages/certifi/cacert.pem'
Enjoy!
1024-bit Root Certificates
~~~~~~~~~~~~~~~~~~~~~~~~~~
Browsers and certificate authorities have concluded that 1024-bit keys are
unacceptably weak for certificates, particularly root certificates. For this
reason, Mozilla has removed any weak (i.e. 1024-bit key) certificate from its
bundle, replacing it with an equivalent strong (i.e. 2048-bit or greater key)
certificate from the same CA. Because Mozilla removed these certificates from
its bundle, ``certifi`` removed them as well.
Unfortunately, old versions of OpenSSL (less than 1.0.2) sometimes fail to
validate certificate chains that use the strong roots. For this reason, if you
fail to validate a certificate using the ``certifi.where()`` mechanism, you can
intentionally re-add the 1024-bit roots back into your bundle by calling
``certifi.old_where()`` instead. This is not recommended in production: if at
all possible you should upgrade to a newer OpenSSL. However, if you have no
other option, this may work for you.
.. _`Certifi`: http://certifi.io/en/latest/
.. _`Requests`: http://docs.python-requests.org/en/latest/

View File

@ -0,0 +1,16 @@
certifi/__init__.py,sha256=fygqpMx6KPrCIRMY4qcO5Zo60MDyQCYtNaI5BbgZyqE,63
certifi/__main__.py,sha256=FiOYt1Fltst7wk9DRa6GCoBr8qBUxlNQu_MKJf04E6s,41
certifi/cacert.pem,sha256=UgTuBXP5FC1mKK2skamUrJKyL8RVMmtUTKJZpTSFn_U,321422
certifi/core.py,sha256=DqvIINYNNXsp3Srlk_NRaiizaww8po3l8t8ksz-Xt6Q,716
certifi/old_root.pem,sha256=HT0KIfaM83q0XHFqGEesiGyfmlSWuD2RI0-AVIS2srY,25626
certifi/weak.pem,sha256=LGe1E3ewgvNAs_yRA9ZKBN6C5KV2Cx34iJFMPi8_hyo,347048
certifi-2017.4.17.dist-info/DESCRIPTION.rst,sha256=wVWYoH3eovdWFPZnYU2NT4itGRx3eN5C_s1IuNm4qF4,1731
certifi-2017.4.17.dist-info/METADATA,sha256=ZzDLL1LWFj2SDbZdV_QZ-ZNWd_UDw_NXGbARLwgLYdg,2396
certifi-2017.4.17.dist-info/RECORD,,
certifi-2017.4.17.dist-info/WHEEL,sha256=5wvfB7GvgZAbKBSE9uX9Zbi6LCL-_KgezgHblXhCRnM,113
certifi-2017.4.17.dist-info/metadata.json,sha256=8MYPZlDmqjogcs0bl7CuyQwwrSvqaFjpGMYsjFBQBlw,790
certifi-2017.4.17.dist-info/top_level.txt,sha256=KMu4vUCfsjLrkPbSNdgdekS-pVJzBAJFO__nI8NF6-U,8
certifi-2017.4.17.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
certifi/__pycache__/core.cpython-36.pyc,,
certifi/__pycache__/__init__.cpython-36.pyc,,
certifi/__pycache__/__main__.cpython-36.pyc,,

View File

@ -0,0 +1,6 @@
Wheel-Version: 1.0
Generator: bdist_wheel (0.30.0.a0)
Root-Is-Purelib: true
Tag: py2-none-any
Tag: py3-none-any

View File

@ -0,0 +1 @@
certifi

View File

@ -0,0 +1,3 @@
from .core import where, old_where
__version__ = "2017.04.17"

View File

@ -0,0 +1,2 @@
from certifi import where
print(where())

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,36 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
certifi.py
~~~~~~~~~~
This module returns the installation location of cacert.pem.
"""
import os
import warnings
class DeprecatedBundleWarning(DeprecationWarning):
"""
The weak security bundle is being deprecated. Please bother your service
provider to get them to stop using cross-signed roots.
"""
def where():
f = os.path.split(__file__)[0]
return os.path.join(f, 'cacert.pem')
def old_where():
warnings.warn(
"The weak security bundle is being deprecated.",
DeprecatedBundleWarning
)
f = os.path.split(__file__)[0]
return os.path.join(f, 'weak.pem')
if __name__ == '__main__':
print(where())

View File

@ -0,0 +1,414 @@
# Issuer: CN=Entrust.net Secure Server Certification Authority O=Entrust.net OU=www.entrust.net/CPS incorp. by ref. (limits liab.)/(c) 1999 Entrust.net Limited
# Subject: CN=Entrust.net Secure Server Certification Authority O=Entrust.net OU=www.entrust.net/CPS incorp. by ref. (limits liab.)/(c) 1999 Entrust.net Limited
# Label: "Entrust.net Secure Server CA"
# Serial: 927650371
# MD5 Fingerprint: df:f2:80:73:cc:f1:e6:61:73:fc:f5:42:e9:c5:7c:ee
# SHA1 Fingerprint: 99:a6:9b:e6:1a:fe:88:6b:4d:2b:82:00:7c:b8:54:fc:31:7e:15:39
# SHA256 Fingerprint: 62:f2:40:27:8c:56:4c:4d:d8:bf:7d:9d:4f:6f:36:6e:a8:94:d2:2f:5f:34:d9:89:a9:83:ac:ec:2f:ff:ed:50
-----BEGIN CERTIFICATE-----
MIIE2DCCBEGgAwIBAgIEN0rSQzANBgkqhkiG9w0BAQUFADCBwzELMAkGA1UEBhMC
VVMxFDASBgNVBAoTC0VudHJ1c3QubmV0MTswOQYDVQQLEzJ3d3cuZW50cnVzdC5u
ZXQvQ1BTIGluY29ycC4gYnkgcmVmLiAobGltaXRzIGxpYWIuKTElMCMGA1UECxMc
KGMpIDE5OTkgRW50cnVzdC5uZXQgTGltaXRlZDE6MDgGA1UEAxMxRW50cnVzdC5u
ZXQgU2VjdXJlIFNlcnZlciBDZXJ0aWZpY2F0aW9uIEF1dGhvcml0eTAeFw05OTA1
MjUxNjA5NDBaFw0xOTA1MjUxNjM5NDBaMIHDMQswCQYDVQQGEwJVUzEUMBIGA1UE
ChMLRW50cnVzdC5uZXQxOzA5BgNVBAsTMnd3dy5lbnRydXN0Lm5ldC9DUFMgaW5j
b3JwLiBieSByZWYuIChsaW1pdHMgbGlhYi4pMSUwIwYDVQQLExwoYykgMTk5OSBF
bnRydXN0Lm5ldCBMaW1pdGVkMTowOAYDVQQDEzFFbnRydXN0Lm5ldCBTZWN1cmUg
U2VydmVyIENlcnRpZmljYXRpb24gQXV0aG9yaXR5MIGdMA0GCSqGSIb3DQEBAQUA
A4GLADCBhwKBgQDNKIM0VBuJ8w+vN5Ex/68xYMmo6LIQaO2f55M28Qpku0f1BBc/
I0dNxScZgSYMVHINiC3ZH5oSn7yzcdOAGT9HZnuMNSjSuQrfJNqc1lB5gXpa0zf3
wkrYKZImZNHkmGw6AIr1NJtl+O3jEP/9uElY3KDegjlrgbEWGWG5VLbmQwIBA6OC
AdcwggHTMBEGCWCGSAGG+EIBAQQEAwIABzCCARkGA1UdHwSCARAwggEMMIHeoIHb
oIHYpIHVMIHSMQswCQYDVQQGEwJVUzEUMBIGA1UEChMLRW50cnVzdC5uZXQxOzA5
BgNVBAsTMnd3dy5lbnRydXN0Lm5ldC9DUFMgaW5jb3JwLiBieSByZWYuIChsaW1p
dHMgbGlhYi4pMSUwIwYDVQQLExwoYykgMTk5OSBFbnRydXN0Lm5ldCBMaW1pdGVk
MTowOAYDVQQDEzFFbnRydXN0Lm5ldCBTZWN1cmUgU2VydmVyIENlcnRpZmljYXRp
b24gQXV0aG9yaXR5MQ0wCwYDVQQDEwRDUkwxMCmgJ6AlhiNodHRwOi8vd3d3LmVu
dHJ1c3QubmV0L0NSTC9uZXQxLmNybDArBgNVHRAEJDAigA8xOTk5MDUyNTE2MDk0
MFqBDzIwMTkwNTI1MTYwOTQwWjALBgNVHQ8EBAMCAQYwHwYDVR0jBBgwFoAU8Bdi
E1U9s/8KAGv7UISX8+1i0BowHQYDVR0OBBYEFPAXYhNVPbP/CgBr+1CEl/PtYtAa
MAwGA1UdEwQFMAMBAf8wGQYJKoZIhvZ9B0EABAwwChsEVjQuMAMCBJAwDQYJKoZI
hvcNAQEFBQADgYEAkNwwAvpkdMKnCqV8IY00F6j7Rw7/JXyNEwr75Ji174z4xRAN
95K+8cPV1ZVqBLssziY2ZcgxxufuP+NXdYR6Ee9GTxj005i7qIcyunL2POI9n9cd
2cNgQ4xYDiKWL2KjLB+6rQXvqzJ4h6BUcxm1XAX5Uj5tLUUL9wqT6u0G+bI=
-----END CERTIFICATE-----
# Issuer: CN=http://www.valicert.com/ O=ValiCert, Inc. OU=ValiCert Class 2 Policy Validation Authority
# Subject: CN=http://www.valicert.com/ O=ValiCert, Inc. OU=ValiCert Class 2 Policy Validation Authority
# Label: "ValiCert Class 2 VA"
# Serial: 1
# MD5 Fingerprint: a9:23:75:9b:ba:49:36:6e:31:c2:db:f2:e7:66:ba:87
# SHA1 Fingerprint: 31:7a:2a:d0:7f:2b:33:5e:f5:a1:c3:4e:4b:57:e8:b7:d8:f1:fc:a6
# SHA256 Fingerprint: 58:d0:17:27:9c:d4:dc:63:ab:dd:b1:96:a6:c9:90:6c:30:c4:e0:87:83:ea:e8:c1:60:99:54:d6:93:55:59:6b
-----BEGIN CERTIFICATE-----
MIIC5zCCAlACAQEwDQYJKoZIhvcNAQEFBQAwgbsxJDAiBgNVBAcTG1ZhbGlDZXJ0
IFZhbGlkYXRpb24gTmV0d29yazEXMBUGA1UEChMOVmFsaUNlcnQsIEluYy4xNTAz
BgNVBAsTLFZhbGlDZXJ0IENsYXNzIDIgUG9saWN5IFZhbGlkYXRpb24gQXV0aG9y
aXR5MSEwHwYDVQQDExhodHRwOi8vd3d3LnZhbGljZXJ0LmNvbS8xIDAeBgkqhkiG
9w0BCQEWEWluZm9AdmFsaWNlcnQuY29tMB4XDTk5MDYyNjAwMTk1NFoXDTE5MDYy
NjAwMTk1NFowgbsxJDAiBgNVBAcTG1ZhbGlDZXJ0IFZhbGlkYXRpb24gTmV0d29y
azEXMBUGA1UEChMOVmFsaUNlcnQsIEluYy4xNTAzBgNVBAsTLFZhbGlDZXJ0IENs
YXNzIDIgUG9saWN5IFZhbGlkYXRpb24gQXV0aG9yaXR5MSEwHwYDVQQDExhodHRw
Oi8vd3d3LnZhbGljZXJ0LmNvbS8xIDAeBgkqhkiG9w0BCQEWEWluZm9AdmFsaWNl
cnQuY29tMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDOOnHK5avIWZJV16vY
dA757tn2VUdZZUcOBVXc65g2PFxTXdMwzzjsvUGJ7SVCCSRrCl6zfN1SLUzm1NZ9
WlmpZdRJEy0kTRxQb7XBhVQ7/nHk01xC+YDgkRoKWzk2Z/M/VXwbP7RfZHM047QS
v4dk+NoS/zcnwbNDu+97bi5p9wIDAQABMA0GCSqGSIb3DQEBBQUAA4GBADt/UG9v
UJSZSWI4OB9L+KXIPqeCgfYrx+jFzug6EILLGACOTb2oWH+heQC1u+mNr0HZDzTu
IYEZoDJJKPTEjlbVUjP9UNV+mWwD5MlM/Mtsq2azSiGM5bUMMj4QssxsodyamEwC
W/POuZ6lcg5Ktz885hZo+L7tdEy8W9ViH0Pd
-----END CERTIFICATE-----
# Issuer: CN=NetLock Expressz (Class C) Tanusitvanykiado O=NetLock Halozatbiztonsagi Kft. OU=Tanusitvanykiadok
# Subject: CN=NetLock Expressz (Class C) Tanusitvanykiado O=NetLock Halozatbiztonsagi Kft. OU=Tanusitvanykiadok
# Label: "NetLock Express (Class C) Root"
# Serial: 104
# MD5 Fingerprint: 4f:eb:f1:f0:70:c2:80:63:5d:58:9f:da:12:3c:a9:c4
# SHA1 Fingerprint: e3:92:51:2f:0a:cf:f5:05:df:f6:de:06:7f:75:37:e1:65:ea:57:4b
# SHA256 Fingerprint: 0b:5e:ed:4e:84:64:03:cf:55:e0:65:84:84:40:ed:2a:82:75:8b:f5:b9:aa:1f:25:3d:46:13:cf:a0:80:ff:3f
-----BEGIN CERTIFICATE-----
MIIFTzCCBLigAwIBAgIBaDANBgkqhkiG9w0BAQQFADCBmzELMAkGA1UEBhMCSFUx
ETAPBgNVBAcTCEJ1ZGFwZXN0MScwJQYDVQQKEx5OZXRMb2NrIEhhbG96YXRiaXp0
b25zYWdpIEtmdC4xGjAYBgNVBAsTEVRhbnVzaXR2YW55a2lhZG9rMTQwMgYDVQQD
EytOZXRMb2NrIEV4cHJlc3N6IChDbGFzcyBDKSBUYW51c2l0dmFueWtpYWRvMB4X
DTk5MDIyNTE0MDgxMVoXDTE5MDIyMDE0MDgxMVowgZsxCzAJBgNVBAYTAkhVMREw
DwYDVQQHEwhCdWRhcGVzdDEnMCUGA1UEChMeTmV0TG9jayBIYWxvemF0Yml6dG9u
c2FnaSBLZnQuMRowGAYDVQQLExFUYW51c2l0dmFueWtpYWRvazE0MDIGA1UEAxMr
TmV0TG9jayBFeHByZXNzeiAoQ2xhc3MgQykgVGFudXNpdHZhbnlraWFkbzCBnzAN
BgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEA6+ywbGGKIyWvYCDj2Z/8kwvbXY2wobNA
OoLO/XXgeDIDhlqGlZHtU/qdQPzm6N3ZW3oDvV3zOwzDUXmbrVWg6dADEK8KuhRC
2VImESLH0iDMgqSaqf64gXadarfSNnU+sYYJ9m5tfk63euyucYT2BDMIJTLrdKwW
RMbkQJMdf60CAwEAAaOCAp8wggKbMBIGA1UdEwEB/wQIMAYBAf8CAQQwDgYDVR0P
AQH/BAQDAgAGMBEGCWCGSAGG+EIBAQQEAwIABzCCAmAGCWCGSAGG+EIBDQSCAlEW
ggJNRklHWUVMRU0hIEV6ZW4gdGFudXNpdHZhbnkgYSBOZXRMb2NrIEtmdC4gQWx0
YWxhbm9zIFN6b2xnYWx0YXRhc2kgRmVsdGV0ZWxlaWJlbiBsZWlydCBlbGphcmFz
b2sgYWxhcGphbiBrZXN6dWx0LiBBIGhpdGVsZXNpdGVzIGZvbHlhbWF0YXQgYSBO
ZXRMb2NrIEtmdC4gdGVybWVrZmVsZWxvc3NlZy1iaXp0b3NpdGFzYSB2ZWRpLiBB
IGRpZ2l0YWxpcyBhbGFpcmFzIGVsZm9nYWRhc2FuYWsgZmVsdGV0ZWxlIGF6IGVs
b2lydCBlbGxlbm9yemVzaSBlbGphcmFzIG1lZ3RldGVsZS4gQXogZWxqYXJhcyBs
ZWlyYXNhIG1lZ3RhbGFsaGF0byBhIE5ldExvY2sgS2Z0LiBJbnRlcm5ldCBob25s
YXBqYW4gYSBodHRwczovL3d3dy5uZXRsb2NrLm5ldC9kb2NzIGNpbWVuIHZhZ3kg
a2VyaGV0byBheiBlbGxlbm9yemVzQG5ldGxvY2submV0IGUtbWFpbCBjaW1lbi4g
SU1QT1JUQU5UISBUaGUgaXNzdWFuY2UgYW5kIHRoZSB1c2Ugb2YgdGhpcyBjZXJ0
aWZpY2F0ZSBpcyBzdWJqZWN0IHRvIHRoZSBOZXRMb2NrIENQUyBhdmFpbGFibGUg
YXQgaHR0cHM6Ly93d3cubmV0bG9jay5uZXQvZG9jcyBvciBieSBlLW1haWwgYXQg
Y3BzQG5ldGxvY2submV0LjANBgkqhkiG9w0BAQQFAAOBgQAQrX/XDDKACtiG8XmY
ta3UzbM2xJZIwVzNmtkFLp++UOv0JhQQLdRmF/iewSf98e3ke0ugbLWrmldwpu2g
pO0u9f38vf5NNwgMvOOWgyL1SRt/Syu0VMGAfJlOHdCM7tCs5ZL6dVb+ZKATj7i4
Fp1hBWeAyNDYpQcCNJgEjTME1A==
-----END CERTIFICATE-----
# Issuer: CN=NetLock Uzleti (Class B) Tanusitvanykiado O=NetLock Halozatbiztonsagi Kft. OU=Tanusitvanykiadok
# Subject: CN=NetLock Uzleti (Class B) Tanusitvanykiado O=NetLock Halozatbiztonsagi Kft. OU=Tanusitvanykiadok
# Label: "NetLock Business (Class B) Root"
# Serial: 105
# MD5 Fingerprint: 39:16:aa:b9:6a:41:e1:14:69:df:9e:6c:3b:72:dc:b6
# SHA1 Fingerprint: 87:9f:4b:ee:05:df:98:58:3b:e3:60:d6:33:e7:0d:3f:fe:98:71:af
# SHA256 Fingerprint: 39:df:7b:68:2b:7b:93:8f:84:71:54:81:cc:de:8d:60:d8:f2:2e:c5:98:87:7d:0a:aa:c1:2b:59:18:2b:03:12
-----BEGIN CERTIFICATE-----
MIIFSzCCBLSgAwIBAgIBaTANBgkqhkiG9w0BAQQFADCBmTELMAkGA1UEBhMCSFUx
ETAPBgNVBAcTCEJ1ZGFwZXN0MScwJQYDVQQKEx5OZXRMb2NrIEhhbG96YXRiaXp0
b25zYWdpIEtmdC4xGjAYBgNVBAsTEVRhbnVzaXR2YW55a2lhZG9rMTIwMAYDVQQD
EylOZXRMb2NrIFV6bGV0aSAoQ2xhc3MgQikgVGFudXNpdHZhbnlraWFkbzAeFw05
OTAyMjUxNDEwMjJaFw0xOTAyMjAxNDEwMjJaMIGZMQswCQYDVQQGEwJIVTERMA8G
A1UEBxMIQnVkYXBlc3QxJzAlBgNVBAoTHk5ldExvY2sgSGFsb3phdGJpenRvbnNh
Z2kgS2Z0LjEaMBgGA1UECxMRVGFudXNpdHZhbnlraWFkb2sxMjAwBgNVBAMTKU5l
dExvY2sgVXpsZXRpIChDbGFzcyBCKSBUYW51c2l0dmFueWtpYWRvMIGfMA0GCSqG
SIb3DQEBAQUAA4GNADCBiQKBgQCx6gTsIKAjwo84YM/HRrPVG/77uZmeBNwcf4xK
gZjupNTKihe5In+DCnVMm8Bp2GQ5o+2So/1bXHQawEfKOml2mrriRBf8TKPV/riX
iK+IA4kfpPIEPsgHC+b5sy96YhQJRhTKZPWLgLViqNhr1nGTLbO/CVRY7QbrqHvc
Q7GhaQIDAQABo4ICnzCCApswEgYDVR0TAQH/BAgwBgEB/wIBBDAOBgNVHQ8BAf8E
BAMCAAYwEQYJYIZIAYb4QgEBBAQDAgAHMIICYAYJYIZIAYb4QgENBIICURaCAk1G
SUdZRUxFTSEgRXplbiB0YW51c2l0dmFueSBhIE5ldExvY2sgS2Z0LiBBbHRhbGFu
b3MgU3pvbGdhbHRhdGFzaSBGZWx0ZXRlbGVpYmVuIGxlaXJ0IGVsamFyYXNvayBh
bGFwamFuIGtlc3p1bHQuIEEgaGl0ZWxlc2l0ZXMgZm9seWFtYXRhdCBhIE5ldExv
Y2sgS2Z0LiB0ZXJtZWtmZWxlbG9zc2VnLWJpenRvc2l0YXNhIHZlZGkuIEEgZGln
aXRhbGlzIGFsYWlyYXMgZWxmb2dhZGFzYW5hayBmZWx0ZXRlbGUgYXogZWxvaXJ0
IGVsbGVub3J6ZXNpIGVsamFyYXMgbWVndGV0ZWxlLiBBeiBlbGphcmFzIGxlaXJh
c2EgbWVndGFsYWxoYXRvIGEgTmV0TG9jayBLZnQuIEludGVybmV0IGhvbmxhcGph
biBhIGh0dHBzOi8vd3d3Lm5ldGxvY2submV0L2RvY3MgY2ltZW4gdmFneSBrZXJo
ZXRvIGF6IGVsbGVub3J6ZXNAbmV0bG9jay5uZXQgZS1tYWlsIGNpbWVuLiBJTVBP
UlRBTlQhIFRoZSBpc3N1YW5jZSBhbmQgdGhlIHVzZSBvZiB0aGlzIGNlcnRpZmlj
YXRlIGlzIHN1YmplY3QgdG8gdGhlIE5ldExvY2sgQ1BTIGF2YWlsYWJsZSBhdCBo
dHRwczovL3d3dy5uZXRsb2NrLm5ldC9kb2NzIG9yIGJ5IGUtbWFpbCBhdCBjcHNA
bmV0bG9jay5uZXQuMA0GCSqGSIb3DQEBBAUAA4GBAATbrowXr/gOkDFOzT4JwG06
sPgzTEdM43WIEJessDgVkcYplswhwG08pXTP2IKlOcNl40JwuyKQ433bNXbhoLXa
n3BukxowOR0w2y7jfLKRstE3Kfq51hdcR0/jHTjrn9V7lagonhVK0dHQKwCXoOKS
NitjrFgBazMpUIaD8QFI
-----END CERTIFICATE-----
# Issuer: CN=http://www.valicert.com/ O=ValiCert, Inc. OU=ValiCert Class 3 Policy Validation Authority
# Subject: CN=http://www.valicert.com/ O=ValiCert, Inc. OU=ValiCert Class 3 Policy Validation Authority
# Label: "RSA Root Certificate 1"
# Serial: 1
# MD5 Fingerprint: a2:6f:53:b7:ee:40:db:4a:68:e7:fa:18:d9:10:4b:72
# SHA1 Fingerprint: 69:bd:8c:f4:9c:d3:00:fb:59:2e:17:93:ca:55:6a:f3:ec:aa:35:fb
# SHA256 Fingerprint: bc:23:f9:8a:31:3c:b9:2d:e3:bb:fc:3a:5a:9f:44:61:ac:39:49:4c:4a:e1:5a:9e:9d:f1:31:e9:9b:73:01:9a
-----BEGIN CERTIFICATE-----
MIIC5zCCAlACAQEwDQYJKoZIhvcNAQEFBQAwgbsxJDAiBgNVBAcTG1ZhbGlDZXJ0
IFZhbGlkYXRpb24gTmV0d29yazEXMBUGA1UEChMOVmFsaUNlcnQsIEluYy4xNTAz
BgNVBAsTLFZhbGlDZXJ0IENsYXNzIDMgUG9saWN5IFZhbGlkYXRpb24gQXV0aG9y
aXR5MSEwHwYDVQQDExhodHRwOi8vd3d3LnZhbGljZXJ0LmNvbS8xIDAeBgkqhkiG
9w0BCQEWEWluZm9AdmFsaWNlcnQuY29tMB4XDTk5MDYyNjAwMjIzM1oXDTE5MDYy
NjAwMjIzM1owgbsxJDAiBgNVBAcTG1ZhbGlDZXJ0IFZhbGlkYXRpb24gTmV0d29y
azEXMBUGA1UEChMOVmFsaUNlcnQsIEluYy4xNTAzBgNVBAsTLFZhbGlDZXJ0IENs
YXNzIDMgUG9saWN5IFZhbGlkYXRpb24gQXV0aG9yaXR5MSEwHwYDVQQDExhodHRw
Oi8vd3d3LnZhbGljZXJ0LmNvbS8xIDAeBgkqhkiG9w0BCQEWEWluZm9AdmFsaWNl
cnQuY29tMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDjmFGWHOjVsQaBalfD
cnWTq8+epvzzFlLWLU2fNUSoLgRNB0mKOCn1dzfnt6td3zZxFJmP3MKS8edgkpfs
2Ejcv8ECIMYkpChMMFp2bbFc893enhBxoYjHW5tBbcqwuI4V7q0zK89HBFx1cQqY
JJgpp0lZpd34t0NiYfPT4tBVPwIDAQABMA0GCSqGSIb3DQEBBQUAA4GBAFa7AliE
Zwgs3x/be0kz9dNnnfS0ChCzycUs4pJqcXgn8nCDQtM+z6lU9PHYkhaM0QTLS6vJ
n0WuPIqpsHEzXcjFV9+vqDWzf4mH6eglkrh/hXqu1rweN1gqZ8mRzyqBPu3GOd/A
PhmcGcwTTYJBtYze4D1gCCAPRX5ron+jjBXu
-----END CERTIFICATE-----
# Issuer: CN=http://www.valicert.com/ O=ValiCert, Inc. OU=ValiCert Class 1 Policy Validation Authority
# Subject: CN=http://www.valicert.com/ O=ValiCert, Inc. OU=ValiCert Class 1 Policy Validation Authority
# Label: "ValiCert Class 1 VA"
# Serial: 1
# MD5 Fingerprint: 65:58:ab:15:ad:57:6c:1e:a8:a7:b5:69:ac:bf:ff:eb
# SHA1 Fingerprint: e5:df:74:3c:b6:01:c4:9b:98:43:dc:ab:8c:e8:6a:81:10:9f:e4:8e
# SHA256 Fingerprint: f4:c1:49:55:1a:30:13:a3:5b:c7:bf:fe:17:a7:f3:44:9b:c1:ab:5b:5a:0a:e7:4b:06:c2:3b:90:00:4c:01:04
-----BEGIN CERTIFICATE-----
MIIC5zCCAlACAQEwDQYJKoZIhvcNAQEFBQAwgbsxJDAiBgNVBAcTG1ZhbGlDZXJ0
IFZhbGlkYXRpb24gTmV0d29yazEXMBUGA1UEChMOVmFsaUNlcnQsIEluYy4xNTAz
BgNVBAsTLFZhbGlDZXJ0IENsYXNzIDEgUG9saWN5IFZhbGlkYXRpb24gQXV0aG9y
aXR5MSEwHwYDVQQDExhodHRwOi8vd3d3LnZhbGljZXJ0LmNvbS8xIDAeBgkqhkiG
9w0BCQEWEWluZm9AdmFsaWNlcnQuY29tMB4XDTk5MDYyNTIyMjM0OFoXDTE5MDYy
NTIyMjM0OFowgbsxJDAiBgNVBAcTG1ZhbGlDZXJ0IFZhbGlkYXRpb24gTmV0d29y
azEXMBUGA1UEChMOVmFsaUNlcnQsIEluYy4xNTAzBgNVBAsTLFZhbGlDZXJ0IENs
YXNzIDEgUG9saWN5IFZhbGlkYXRpb24gQXV0aG9yaXR5MSEwHwYDVQQDExhodHRw
Oi8vd3d3LnZhbGljZXJ0LmNvbS8xIDAeBgkqhkiG9w0BCQEWEWluZm9AdmFsaWNl
cnQuY29tMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDYWYJ6ibiWuqYvaG9Y
LqdUHAZu9OqNSLwxlBfw8068srg1knaw0KWlAdcAAxIiGQj4/xEjm84H9b9pGib+
TunRf50sQB1ZaG6m+FiwnRqP0z/x3BkGgagO4DrdyFNFCQbmD3DD+kCmDuJWBQ8Y
TfwggtFzVXSNdnKgHZ0dwN0/cQIDAQABMA0GCSqGSIb3DQEBBQUAA4GBAFBoPUn0
LBwGlN+VYH+Wexf+T3GtZMjdd9LvWVXoP+iOBSoh8gfStadS/pyxtuJbdxdA6nLW
I8sogTLDAHkY7FkXicnGah5xyf23dKUlRWnFSKsZ4UWKJWsZ7uW7EvV/96aNUcPw
nXS3qT6gpf+2SQMT2iLM7XGCK5nPOrf1LXLI
-----END CERTIFICATE-----
# Issuer: CN=Equifax Secure eBusiness CA-1 O=Equifax Secure Inc.
# Subject: CN=Equifax Secure eBusiness CA-1 O=Equifax Secure Inc.
# Label: "Equifax Secure eBusiness CA 1"
# Serial: 4
# MD5 Fingerprint: 64:9c:ef:2e:44:fc:c6:8f:52:07:d0:51:73:8f:cb:3d
# SHA1 Fingerprint: da:40:18:8b:91:89:a3:ed:ee:ae:da:97:fe:2f:9d:f5:b7:d1:8a:41
# SHA256 Fingerprint: cf:56:ff:46:a4:a1:86:10:9d:d9:65:84:b5:ee:b5:8a:51:0c:42:75:b0:e5:f9:4f:40:bb:ae:86:5e:19:f6:73
-----BEGIN CERTIFICATE-----
MIICgjCCAeugAwIBAgIBBDANBgkqhkiG9w0BAQQFADBTMQswCQYDVQQGEwJVUzEc
MBoGA1UEChMTRXF1aWZheCBTZWN1cmUgSW5jLjEmMCQGA1UEAxMdRXF1aWZheCBT
ZWN1cmUgZUJ1c2luZXNzIENBLTEwHhcNOTkwNjIxMDQwMDAwWhcNMjAwNjIxMDQw
MDAwWjBTMQswCQYDVQQGEwJVUzEcMBoGA1UEChMTRXF1aWZheCBTZWN1cmUgSW5j
LjEmMCQGA1UEAxMdRXF1aWZheCBTZWN1cmUgZUJ1c2luZXNzIENBLTEwgZ8wDQYJ
KoZIhvcNAQEBBQADgY0AMIGJAoGBAM4vGbwXt3fek6lfWg0XTzQaDJj0ItlZ1MRo
RvC0NcWFAyDGr0WlIVFFQesWWDYyb+JQYmT5/VGcqiTZ9J2DKocKIdMSODRsjQBu
WqDZQu4aIZX5UkxVWsUPOE9G+m34LjXWHXzr4vCwdYDIqROsvojvOm6rXyo4YgKw
Env+j6YDAgMBAAGjZjBkMBEGCWCGSAGG+EIBAQQEAwIABzAPBgNVHRMBAf8EBTAD
AQH/MB8GA1UdIwQYMBaAFEp4MlIR21kWNl7fwRQ2QGpHfEyhMB0GA1UdDgQWBBRK
eDJSEdtZFjZe38EUNkBqR3xMoTANBgkqhkiG9w0BAQQFAAOBgQB1W6ibAxHm6VZM
zfmpTMANmvPMZWnmJXbMWbfWVMMdzZmsGd20hdXgPfxiIKeES1hl8eL5lSE/9dR+
WB5Hh1Q+WKG1tfgq73HnvMP2sUlG4tega+VWeponmHxGYhTnyfxuAxJ5gDgdSIKN
/Bf+KpYrtWKmpj29f5JZzVoqgrI3eQ==
-----END CERTIFICATE-----
# Issuer: CN=Equifax Secure Global eBusiness CA-1 O=Equifax Secure Inc.
# Subject: CN=Equifax Secure Global eBusiness CA-1 O=Equifax Secure Inc.
# Label: "Equifax Secure Global eBusiness CA"
# Serial: 1
# MD5 Fingerprint: 8f:5d:77:06:27:c4:98:3c:5b:93:78:e7:d7:7d:9b:cc
# SHA1 Fingerprint: 7e:78:4a:10:1c:82:65:cc:2d:e1:f1:6d:47:b4:40:ca:d9:0a:19:45
# SHA256 Fingerprint: 5f:0b:62:ea:b5:e3:53:ea:65:21:65:16:58:fb:b6:53:59:f4:43:28:0a:4a:fb:d1:04:d7:7d:10:f9:f0:4c:07
-----BEGIN CERTIFICATE-----
MIICkDCCAfmgAwIBAgIBATANBgkqhkiG9w0BAQQFADBaMQswCQYDVQQGEwJVUzEc
MBoGA1UEChMTRXF1aWZheCBTZWN1cmUgSW5jLjEtMCsGA1UEAxMkRXF1aWZheCBT
ZWN1cmUgR2xvYmFsIGVCdXNpbmVzcyBDQS0xMB4XDTk5MDYyMTA0MDAwMFoXDTIw
MDYyMTA0MDAwMFowWjELMAkGA1UEBhMCVVMxHDAaBgNVBAoTE0VxdWlmYXggU2Vj
dXJlIEluYy4xLTArBgNVBAMTJEVxdWlmYXggU2VjdXJlIEdsb2JhbCBlQnVzaW5l
c3MgQ0EtMTCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEAuucXkAJlsTRVPEnC
UdXfp9E3j9HngXNBUmCbnaEXJnitx7HoJpQytd4zjTov2/KaelpzmKNc6fuKcxtc
58O/gGzNqfTWK8D3+ZmqY6KxRwIP1ORROhI8bIpaVIRw28HFkM9yRcuoWcDNM50/
o5brhTMhHD4ePmBudpxnhcXIw2ECAwEAAaNmMGQwEQYJYIZIAYb4QgEBBAQDAgAH
MA8GA1UdEwEB/wQFMAMBAf8wHwYDVR0jBBgwFoAUvqigdHJQa0S3ySPY+6j/s1dr
aGwwHQYDVR0OBBYEFL6ooHRyUGtEt8kj2Puo/7NXa2hsMA0GCSqGSIb3DQEBBAUA
A4GBADDiAVGqx+pf2rnQZQ8w1j7aDRRJbpGTJxQx78T3LUX47Me/okENI7SS+RkA
Z70Br83gcfxaz2TE4JaY0KNA4gGK7ycH8WUBikQtBmV1UsCGECAhX2xrD2yuCRyv
8qIYNMR1pHMc8Y3c7635s3a0kr/clRAevsvIO1qEYBlWlKlV
-----END CERTIFICATE-----
# Issuer: CN=Thawte Premium Server CA O=Thawte Consulting cc OU=Certification Services Division
# Subject: CN=Thawte Premium Server CA O=Thawte Consulting cc OU=Certification Services Division
# Label: "Thawte Premium Server CA"
# Serial: 1
# MD5 Fingerprint: 06:9f:69:79:16:66:90:02:1b:8c:8c:a2:c3:07:6f:3a
# SHA1 Fingerprint: 62:7f:8d:78:27:65:63:99:d2:7d:7f:90:44:c9:fe:b3:f3:3e:fa:9a
# SHA256 Fingerprint: ab:70:36:36:5c:71:54:aa:29:c2:c2:9f:5d:41:91:16:3b:16:2a:22:25:01:13:57:d5:6d:07:ff:a7:bc:1f:72
-----BEGIN CERTIFICATE-----
MIIDJzCCApCgAwIBAgIBATANBgkqhkiG9w0BAQQFADCBzjELMAkGA1UEBhMCWkEx
FTATBgNVBAgTDFdlc3Rlcm4gQ2FwZTESMBAGA1UEBxMJQ2FwZSBUb3duMR0wGwYD
VQQKExRUaGF3dGUgQ29uc3VsdGluZyBjYzEoMCYGA1UECxMfQ2VydGlmaWNhdGlv
biBTZXJ2aWNlcyBEaXZpc2lvbjEhMB8GA1UEAxMYVGhhd3RlIFByZW1pdW0gU2Vy
dmVyIENBMSgwJgYJKoZIhvcNAQkBFhlwcmVtaXVtLXNlcnZlckB0aGF3dGUuY29t
MB4XDTk2MDgwMTAwMDAwMFoXDTIwMTIzMTIzNTk1OVowgc4xCzAJBgNVBAYTAlpB
MRUwEwYDVQQIEwxXZXN0ZXJuIENhcGUxEjAQBgNVBAcTCUNhcGUgVG93bjEdMBsG
A1UEChMUVGhhd3RlIENvbnN1bHRpbmcgY2MxKDAmBgNVBAsTH0NlcnRpZmljYXRp
b24gU2VydmljZXMgRGl2aXNpb24xITAfBgNVBAMTGFRoYXd0ZSBQcmVtaXVtIFNl
cnZlciBDQTEoMCYGCSqGSIb3DQEJARYZcHJlbWl1bS1zZXJ2ZXJAdGhhd3RlLmNv
bTCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEA0jY2aovXwlue2oFBYo847kkE
VdbQ7xwblRZH7xhINTpS9CtqBo87L+pW46+GjZ4X9560ZXUCTe/LCaIhUdib0GfQ
ug2SBhRz1JPLlyoAnFxODLz6FVL88kRu2hFKbgifLy3j+ao6hnO2RlNYyIkFvYMR
uHM/qgeN9EJN50CdHDcCAwEAAaMTMBEwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG
9w0BAQQFAAOBgQAmSCwWwlj66BZ0DKqqX1Q/8tfJeGBeXm43YyJ3Nn6yF8Q0ufUI
hfzJATj/Tb7yFkJD57taRvvBxhEf8UqwKEbJw8RCfbz6q1lu1bdRiBHjpIUZa4JM
pAwSremkrj/xw0llmozFyD4lt5SZu5IycQfwhl7tUCemDaYj+bvLpgcUQg==
-----END CERTIFICATE-----
# Issuer: CN=Thawte Server CA O=Thawte Consulting cc OU=Certification Services Division
# Subject: CN=Thawte Server CA O=Thawte Consulting cc OU=Certification Services Division
# Label: "Thawte Server CA"
# Serial: 1
# MD5 Fingerprint: c5:70:c4:a2:ed:53:78:0c:c8:10:53:81:64:cb:d0:1d
# SHA1 Fingerprint: 23:e5:94:94:51:95:f2:41:48:03:b4:d5:64:d2:a3:a3:f5:d8:8b:8c
# SHA256 Fingerprint: b4:41:0b:73:e2:e6:ea:ca:47:fb:c4:2f:8f:a4:01:8a:f4:38:1d:c5:4c:fa:a8:44:50:46:1e:ed:09:45:4d:e9
-----BEGIN CERTIFICATE-----
MIIDEzCCAnygAwIBAgIBATANBgkqhkiG9w0BAQQFADCBxDELMAkGA1UEBhMCWkEx
FTATBgNVBAgTDFdlc3Rlcm4gQ2FwZTESMBAGA1UEBxMJQ2FwZSBUb3duMR0wGwYD
VQQKExRUaGF3dGUgQ29uc3VsdGluZyBjYzEoMCYGA1UECxMfQ2VydGlmaWNhdGlv
biBTZXJ2aWNlcyBEaXZpc2lvbjEZMBcGA1UEAxMQVGhhd3RlIFNlcnZlciBDQTEm
MCQGCSqGSIb3DQEJARYXc2VydmVyLWNlcnRzQHRoYXd0ZS5jb20wHhcNOTYwODAx
MDAwMDAwWhcNMjAxMjMxMjM1OTU5WjCBxDELMAkGA1UEBhMCWkExFTATBgNVBAgT
DFdlc3Rlcm4gQ2FwZTESMBAGA1UEBxMJQ2FwZSBUb3duMR0wGwYDVQQKExRUaGF3
dGUgQ29uc3VsdGluZyBjYzEoMCYGA1UECxMfQ2VydGlmaWNhdGlvbiBTZXJ2aWNl
cyBEaXZpc2lvbjEZMBcGA1UEAxMQVGhhd3RlIFNlcnZlciBDQTEmMCQGCSqGSIb3
DQEJARYXc2VydmVyLWNlcnRzQHRoYXd0ZS5jb20wgZ8wDQYJKoZIhvcNAQEBBQAD
gY0AMIGJAoGBANOkUG7I/1Zr5s9dtuoMaHVHoqrC2oQl/Kj0R1HahbUgdJSGHg91
yekIYfUGbTBuFRkC6VLAYttNmZ7iagxEOM3+vuNkCXDF/rFrKbYvScg71CcEJRCX
L+eQbcAoQpnXTEPew/UhbVSfXcNY4cDk2VuwuNy0e982OsK1ZiIS1ocNAgMBAAGj
EzARMA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQEEBQADgYEAB/pMaVz7lcxG
7oWDTSEwjsrZqG9JGubaUeNgcGyEYRGhGshIPllDfU+VPaGLtwtimHp1it2ITk6e
QNuozDJ0uW8NxuOzRAvZim+aKZuZGCg70eNAKJpaPNW15yAbi8qkq43pUdniTCxZ
qdq5snUb9kLy78fyGPmJvKP/iiMucEc=
-----END CERTIFICATE-----
# Issuer: O=VeriSign, Inc. OU=Class 3 Public Primary Certification Authority
# Subject: O=VeriSign, Inc. OU=Class 3 Public Primary Certification Authority
# Label: "Verisign Class 3 Public Primary Certification Authority"
# Serial: 149843929435818692848040365716851702463
# MD5 Fingerprint: 10:fc:63:5d:f6:26:3e:0d:f3:25:be:5f:79:cd:67:67
# SHA1 Fingerprint: 74:2c:31:92:e6:07:e4:24:eb:45:49:54:2b:e1:bb:c5:3e:61:74:e2
# SHA256 Fingerprint: e7:68:56:34:ef:ac:f6:9a:ce:93:9a:6b:25:5b:7b:4f:ab:ef:42:93:5b:50:a2:65:ac:b5:cb:60:27:e4:4e:70
-----BEGIN CERTIFICATE-----
MIICPDCCAaUCEHC65B0Q2Sk0tjjKewPMur8wDQYJKoZIhvcNAQECBQAwXzELMAkG
A1UEBhMCVVMxFzAVBgNVBAoTDlZlcmlTaWduLCBJbmMuMTcwNQYDVQQLEy5DbGFz
cyAzIFB1YmxpYyBQcmltYXJ5IENlcnRpZmljYXRpb24gQXV0aG9yaXR5MB4XDTk2
MDEyOTAwMDAwMFoXDTI4MDgwMTIzNTk1OVowXzELMAkGA1UEBhMCVVMxFzAVBgNV
BAoTDlZlcmlTaWduLCBJbmMuMTcwNQYDVQQLEy5DbGFzcyAzIFB1YmxpYyBQcmlt
YXJ5IENlcnRpZmljYXRpb24gQXV0aG9yaXR5MIGfMA0GCSqGSIb3DQEBAQUAA4GN
ADCBiQKBgQDJXFme8huKARS0EN8EQNvjV69qRUCPhAwL0TPZ2RHP7gJYHyX3KqhE
BarsAx94f56TuZoAqiN91qyFomNFx3InzPRMxnVx0jnvT0Lwdd8KkMaOIG+YD/is
I19wKTakyYbnsZogy1Olhec9vn2a/iRFM9x2Fe0PonFkTGUugWhFpwIDAQABMA0G
CSqGSIb3DQEBAgUAA4GBALtMEivPLCYATxQT3ab7/AoRhIzzKBxnki98tsX63/Do
lbwdj2wsqFHMc9ikwFPwTtYmwHYBV4GSXiHx0bH/59AhWM1pF+NEHJwZRDmJXNyc
AA9WjQKZ7aKQRUzkuxCkPfAyAw7xzvjoyVGM5mKf5p/AfbdynMk2OmufTqj/ZA1k
-----END CERTIFICATE-----
# Issuer: O=VeriSign, Inc. OU=Class 3 Public Primary Certification Authority
# Subject: O=VeriSign, Inc. OU=Class 3 Public Primary Certification Authority
# Label: "Verisign Class 3 Public Primary Certification Authority"
# Serial: 80507572722862485515306429940691309246
# MD5 Fingerprint: ef:5a:f1:33:ef:f1:cd:bb:51:02:ee:12:14:4b:96:c4
# SHA1 Fingerprint: a1:db:63:93:91:6f:17:e4:18:55:09:40:04:15:c7:02:40:b0:ae:6b
# SHA256 Fingerprint: a4:b6:b3:99:6f:c2:f3:06:b3:fd:86:81:bd:63:41:3d:8c:50:09:cc:4f:a3:29:c2:cc:f0:e2:fa:1b:14:03:05
-----BEGIN CERTIFICATE-----
MIICPDCCAaUCEDyRMcsf9tAbDpq40ES/Er4wDQYJKoZIhvcNAQEFBQAwXzELMAkG
A1UEBhMCVVMxFzAVBgNVBAoTDlZlcmlTaWduLCBJbmMuMTcwNQYDVQQLEy5DbGFz
cyAzIFB1YmxpYyBQcmltYXJ5IENlcnRpZmljYXRpb24gQXV0aG9yaXR5MB4XDTk2
MDEyOTAwMDAwMFoXDTI4MDgwMjIzNTk1OVowXzELMAkGA1UEBhMCVVMxFzAVBgNV
BAoTDlZlcmlTaWduLCBJbmMuMTcwNQYDVQQLEy5DbGFzcyAzIFB1YmxpYyBQcmlt
YXJ5IENlcnRpZmljYXRpb24gQXV0aG9yaXR5MIGfMA0GCSqGSIb3DQEBAQUAA4GN
ADCBiQKBgQDJXFme8huKARS0EN8EQNvjV69qRUCPhAwL0TPZ2RHP7gJYHyX3KqhE
BarsAx94f56TuZoAqiN91qyFomNFx3InzPRMxnVx0jnvT0Lwdd8KkMaOIG+YD/is
I19wKTakyYbnsZogy1Olhec9vn2a/iRFM9x2Fe0PonFkTGUugWhFpwIDAQABMA0G
CSqGSIb3DQEBBQUAA4GBABByUqkFFBkyCEHwxWsKzH4PIRnN5GfcX6kb5sroc50i
2JhucwNhkcV8sEVAbkSdjbCxlnRhLQ2pRdKkkirWmnWXbj9T/UWZYB2oK0z5XqcJ
2HUw19JlYD1n1khVdWk/kfVIC0dpImmClr7JyDiGSnoscxlIaU5rfGW/D/xwzoiQ
-----END CERTIFICATE-----
# Issuer: O=VeriSign, Inc. OU=Class 3 Public Primary Certification Authority - G2/(c) 1998 VeriSign, Inc. - For authorized use only/VeriSign Trust Network
# Subject: O=VeriSign, Inc. OU=Class 3 Public Primary Certification Authority - G2/(c) 1998 VeriSign, Inc. - For authorized use only/VeriSign Trust Network
# Label: "Verisign Class 3 Public Primary Certification Authority - G2"
# Serial: 167285380242319648451154478808036881606
# MD5 Fingerprint: a2:33:9b:4c:74:78:73:d4:6c:e7:c1:f3:8d:cb:5c:e9
# SHA1 Fingerprint: 85:37:1c:a6:e5:50:14:3d:ce:28:03:47:1b:de:3a:09:e8:f8:77:0f
# SHA256 Fingerprint: 83:ce:3c:12:29:68:8a:59:3d:48:5f:81:97:3c:0f:91:95:43:1e:da:37:cc:5e:36:43:0e:79:c7:a8:88:63:8b
-----BEGIN CERTIFICATE-----
MIIDAjCCAmsCEH3Z/gfPqB63EHln+6eJNMYwDQYJKoZIhvcNAQEFBQAwgcExCzAJ
BgNVBAYTAlVTMRcwFQYDVQQKEw5WZXJpU2lnbiwgSW5jLjE8MDoGA1UECxMzQ2xh
c3MgMyBQdWJsaWMgUHJpbWFyeSBDZXJ0aWZpY2F0aW9uIEF1dGhvcml0eSAtIEcy
MTowOAYDVQQLEzEoYykgMTk5OCBWZXJpU2lnbiwgSW5jLiAtIEZvciBhdXRob3Jp
emVkIHVzZSBvbmx5MR8wHQYDVQQLExZWZXJpU2lnbiBUcnVzdCBOZXR3b3JrMB4X
DTk4MDUxODAwMDAwMFoXDTI4MDgwMTIzNTk1OVowgcExCzAJBgNVBAYTAlVTMRcw
FQYDVQQKEw5WZXJpU2lnbiwgSW5jLjE8MDoGA1UECxMzQ2xhc3MgMyBQdWJsaWMg
UHJpbWFyeSBDZXJ0aWZpY2F0aW9uIEF1dGhvcml0eSAtIEcyMTowOAYDVQQLEzEo
YykgMTk5OCBWZXJpU2lnbiwgSW5jLiAtIEZvciBhdXRob3JpemVkIHVzZSBvbmx5
MR8wHQYDVQQLExZWZXJpU2lnbiBUcnVzdCBOZXR3b3JrMIGfMA0GCSqGSIb3DQEB
AQUAA4GNADCBiQKBgQDMXtERXVxp0KvTuWpMmR9ZmDCOFoUgRm1HP9SFIIThbbP4
pO0M8RcPO/mn+SXXwc+EY/J8Y8+iR/LGWzOOZEAEaMGAuWQcRXfH2G71lSk8UOg0
13gfqLptQ5GVj0VXXn7F+8qkBOvqlzdUMG+7AUcyM83cV5tkaWH4mx0ciU9cZwID
AQABMA0GCSqGSIb3DQEBBQUAA4GBAFFNzb5cy5gZnBWyATl4Lk0PZ3BwmcYQWpSk
U01UbSuvDV1Ai2TT1+7eVmGSX6bEHRBhNtMsJzzoKQm5EWR0zLVznxxIqbxhAe7i
F6YM40AIOw7n60RzKprxaZLvcRTDOaxxp5EJb+RxBrO6WVcmeQD2+A2iMzAo1KpY
oJ2daZH9
-----END CERTIFICATE-----
# Issuer: CN=GTE CyberTrust Global Root O=GTE Corporation OU=GTE CyberTrust Solutions, Inc.
# Subject: CN=GTE CyberTrust Global Root O=GTE Corporation OU=GTE CyberTrust Solutions, Inc.
# Label: "GTE CyberTrust Global Root"
# Serial: 421
# MD5 Fingerprint: ca:3d:d3:68:f1:03:5c:d0:32:fa:b8:2b:59:e8:5a:db
# SHA1 Fingerprint: 97:81:79:50:d8:1c:96:70:cc:34:d8:09:cf:79:44:31:36:7e:f4:74
# SHA256 Fingerprint: a5:31:25:18:8d:21:10:aa:96:4b:02:c7:b7:c6:da:32:03:17:08:94:e5:fb:71:ff:fb:66:67:d5:e6:81:0a:36
-----BEGIN CERTIFICATE-----
MIICWjCCAcMCAgGlMA0GCSqGSIb3DQEBBAUAMHUxCzAJBgNVBAYTAlVTMRgwFgYD
VQQKEw9HVEUgQ29ycG9yYXRpb24xJzAlBgNVBAsTHkdURSBDeWJlclRydXN0IFNv
bHV0aW9ucywgSW5jLjEjMCEGA1UEAxMaR1RFIEN5YmVyVHJ1c3QgR2xvYmFsIFJv
b3QwHhcNOTgwODEzMDAyOTAwWhcNMTgwODEzMjM1OTAwWjB1MQswCQYDVQQGEwJV
UzEYMBYGA1UEChMPR1RFIENvcnBvcmF0aW9uMScwJQYDVQQLEx5HVEUgQ3liZXJU
cnVzdCBTb2x1dGlvbnMsIEluYy4xIzAhBgNVBAMTGkdURSBDeWJlclRydXN0IEds
b2JhbCBSb290MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCVD6C28FCc6HrH
iM3dFw4usJTQGz0O9pTAipTHBsiQl8i4ZBp6fmw8U+E3KHNgf7KXUwefU/ltWJTS
r41tiGeA5u2ylc9yMcqlHHK6XALnZELn+aks1joNrI1CqiQBOeacPwGFVw1Yh0X4
04Wqk2kmhXBIgD8SFcd5tB8FLztimQIDAQABMA0GCSqGSIb3DQEBBAUAA4GBAG3r
GwnpXtlR22ciYaQqPEh346B8pt5zohQDhT37qw4wxYMWM4ETCJ57NE7fQMh017l9
3PR2VX2bY1QY6fDq81yx2YtCHrnAlU66+tXifPVoYb+O7AWXX1uw16OFNMQkpw0P
lZPvy5TYnh+dXIVtx6quTx8itc2VrbqnzPmrC3p/
-----END CERTIFICATE-----
# Issuer: C=US, O=Equifax, OU=Equifax Secure Certificate Authority
# Subject: C=US, O=Equifax, OU=Equifax Secure Certificate Authority
# Label: "Equifax Secure Certificate Authority"
# Serial: 903804111
# MD5 Fingerprint: 67:cb:9d:c0:13:24:8a:82:9b:b2:17:1e:d1:1b:ec:d4
# SHA1 Fingerprint: d2:32:09:ad:23:d3:14:23:21:74:e4:0d:7f:9d:62:13:97:86:63:3a
# SHA256 Fingerprint: 08:29:7a:40:47:db:a2:36:80:c7:31:db:6e:31:76:53:ca:78:48:e1:be:bd:3a:0b:01:79:a7:07:f9:2c:f1:78
-----BEGIN CERTIFICATE-----
MIIDIDCCAomgAwIBAgIENd70zzANBgkqhkiG9w0BAQUFADBOMQswCQYDVQQGEwJV
UzEQMA4GA1UEChMHRXF1aWZheDEtMCsGA1UECxMkRXF1aWZheCBTZWN1cmUgQ2Vy
dGlmaWNhdGUgQXV0aG9yaXR5MB4XDTk4MDgyMjE2NDE1MVoXDTE4MDgyMjE2NDE1
MVowTjELMAkGA1UEBhMCVVMxEDAOBgNVBAoTB0VxdWlmYXgxLTArBgNVBAsTJEVx
dWlmYXggU2VjdXJlIENlcnRpZmljYXRlIEF1dGhvcml0eTCBnzANBgkqhkiG9w0B
AQEFAAOBjQAwgYkCgYEAwV2xWGcIYu6gmi0fCG2RFGiYCh7+2gRvE4RiIcPRfM6f
BeC4AfBONOziipUEZKzxa1NfBbPLZ4C/QgKO/t0BCezhABRP/PvwDN1Dulsr4R+A
cJkVV5MW8Q+XarfCaCMczE1ZMKxRHjuvK9buY0V7xdlfUNLjUA86iOe/FP3gx7kC
AwEAAaOCAQkwggEFMHAGA1UdHwRpMGcwZaBjoGGkXzBdMQswCQYDVQQGEwJVUzEQ
MA4GA1UEChMHRXF1aWZheDEtMCsGA1UECxMkRXF1aWZheCBTZWN1cmUgQ2VydGlm
aWNhdGUgQXV0aG9yaXR5MQ0wCwYDVQQDEwRDUkwxMBoGA1UdEAQTMBGBDzIwMTgw
ODIyMTY0MTUxWjALBgNVHQ8EBAMCAQYwHwYDVR0jBBgwFoAUSOZo+SvSspXXR9gj
IBBPM5iQn9QwHQYDVR0OBBYEFEjmaPkr0rKV10fYIyAQTzOYkJ/UMAwGA1UdEwQF
MAMBAf8wGgYJKoZIhvZ9B0EABA0wCxsFVjMuMGMDAgbAMA0GCSqGSIb3DQEBBQUA
A4GBAFjOKer89961zgK5F7WF0bnj4JXMJTENAKaSbn+2kmOeUJXRmm/kEd5jhW6Y
7qj/WsjTVbJmcVfewCHrPSqnI0kBBIZCe/zuf6IWUrVnZ9NA2zsmWLIodz2uFHdh
1voqZiegDfqnc1zqcPGUIWVEX/r87yloqaKHee9570+sB3c4
-----END CERTIFICATE-----

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,13 @@
CFFI
====
Foreign Function Interface for Python calling C code.
Please see the `Documentation <http://cffi.readthedocs.org/>`_.
Contact
-------
`Mailing list <https://groups.google.com/forum/#!forum/python-cffi>`_

View File

@ -0,0 +1 @@
pip

View File

@ -0,0 +1,35 @@
Metadata-Version: 2.0
Name: cffi
Version: 1.10.0
Summary: Foreign Function Interface for Python calling C code.
Home-page: http://cffi.readthedocs.org
Author: Armin Rigo, Maciej Fijalkowski
Author-email: python-cffi@googlegroups.com
License: MIT
Platform: UNKNOWN
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 2
Classifier: Programming Language :: Python :: 2.6
Classifier: Programming Language :: Python :: 2.7
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.2
Classifier: Programming Language :: Python :: 3.3
Classifier: Programming Language :: Python :: 3.4
Classifier: Programming Language :: Python :: 3.5
Classifier: Programming Language :: Python :: Implementation :: CPython
Classifier: Programming Language :: Python :: Implementation :: PyPy
Requires-Dist: pycparser
CFFI
====
Foreign Function Interface for Python calling C code.
Please see the `Documentation <http://cffi.readthedocs.org/>`_.
Contact
-------
`Mailing list <https://groups.google.com/forum/#!forum/python-cffi>`_

View File

@ -0,0 +1,42 @@
_cffi_backend.cp36-win32.pyd,sha256=Go5Po3dC2iO2JU1oFvQuisiCbeIH8GyVYoi2_fJGg-Y,128512
cffi/__init__.py,sha256=QcAAIPcVY5LsX041WHzW-GYObh691LT5L5FKo0yVwDI,479
cffi/_cffi_include.h,sha256=8SuGAPe_N8n4Uv2B8aQOqp5JFG0un02SKGvy1e9tihQ,10238
cffi/_embedding.h,sha256=c_xb0Dw0k7gq9OP6x10yrLBy1offSq5KxJphiU4p3dw,17275
cffi/api.py,sha256=FHlxuRrwmZbheKY2HY3l1ScFHNuWUBjGu80YPhleICs,39647
cffi/backend_ctypes.py,sha256=CcGNp1XCa7QHBihowgHu5BojpeQZ32s7tpEjFrtFtME,42078
cffi/cffi_opcode.py,sha256=hbX3E-hmvcwzfqOhqSfH3ObpmRMNOHf_VZ--32flIEo,5459
cffi/commontypes.py,sha256=QS4uxCDI7JhtTyjh1hlnCA-gynmaszWxJaRRLGkJa1A,2689
cffi/cparser.py,sha256=AX4kk4BejnA8erNHzTEyeJWbX_He82MBmpIsZlkmdl8,38507
cffi/error.py,sha256=yNDakwm_HPJ_T53ivbB7hEts2N-oBGjMLw_25aNi2E8,536
cffi/ffiplatform.py,sha256=g-6tBT6R2aXkIDaAmni92fXI4rwiCegHU8AyD_wL3wo,3645
cffi/lock.py,sha256=l9TTdwMIMpi6jDkJGnQgE9cvTIR7CAntIJr8EGHt3pY,747
cffi/model.py,sha256=1daM9AYkCFmSMUrgbbe9KIK1jXsJPYZn_A828S-Qbv8,21103
cffi/parse_c_type.h,sha256=BBca7ODJCzlsK_1c4HT5MFhy1wdUbyHYvKuvCvxaQZ8,5835
cffi/recompiler.py,sha256=43UEvyl2mtigxa0laEWHLDtv2T8OTZEMj2NXComOoJU,61597
cffi/setuptools_ext.py,sha256=07n99TzG6QAsFDhf5-cE10pN2FIzKQ-DVFV5YbnN6eA,7463
cffi/vengine_cpy.py,sha256=Kw_Z38hrBJPUod5R517dRwAvRC7SjQ0qx8fEX1ZaFAM,41325
cffi/vengine_gen.py,sha256=dLmNdH0MmI_Jog3PlYuCG5RIdx8_xP_RuOjRa84q_N8,26597
cffi/verifier.py,sha256=Vk8v9fePaHkMmDc-wftJ4gowPErbacfC3soTw_rpT8U,11519
cffi-1.10.0.dist-info/DESCRIPTION.rst,sha256=9ijQLbcqTWNF-iV0RznFiBeBCNrjArA0P-eutKUPw98,220
cffi-1.10.0.dist-info/METADATA,sha256=c9-fyjmuNh52W-A4SeBTlgr35GZEsWW2Tskw_3nHCWM,1090
cffi-1.10.0.dist-info/RECORD,,
cffi-1.10.0.dist-info/WHEEL,sha256=xiHTm3JxoVljPSD6nSGhq3B4VY9iUqMNXwYQ259n1PI,102
cffi-1.10.0.dist-info/entry_points.txt,sha256=Q9f5C9IpjYxo0d2PK9eUcnkgxHc9pHWwjEMaANPKNCI,76
cffi-1.10.0.dist-info/metadata.json,sha256=fBsfmNhS5_P6IGaL1mdGMDj8o0NZPotHZeGIB3FprRI,1112
cffi-1.10.0.dist-info/top_level.txt,sha256=rE7WR3rZfNKxWI9-jn6hsHCAl7MDkB-FmuQbxWjFehQ,19
cffi-1.10.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
cffi/__pycache__/api.cpython-36.pyc,,
cffi/__pycache__/backend_ctypes.cpython-36.pyc,,
cffi/__pycache__/cffi_opcode.cpython-36.pyc,,
cffi/__pycache__/commontypes.cpython-36.pyc,,
cffi/__pycache__/cparser.cpython-36.pyc,,
cffi/__pycache__/error.cpython-36.pyc,,
cffi/__pycache__/ffiplatform.cpython-36.pyc,,
cffi/__pycache__/lock.cpython-36.pyc,,
cffi/__pycache__/model.cpython-36.pyc,,
cffi/__pycache__/recompiler.cpython-36.pyc,,
cffi/__pycache__/setuptools_ext.cpython-36.pyc,,
cffi/__pycache__/vengine_cpy.cpython-36.pyc,,
cffi/__pycache__/vengine_gen.cpython-36.pyc,,
cffi/__pycache__/verifier.cpython-36.pyc,,
cffi/__pycache__/__init__.cpython-36.pyc,,

View File

@ -0,0 +1,5 @@
Wheel-Version: 1.0
Generator: bdist_wheel (0.29.0)
Root-Is-Purelib: false
Tag: cp36-cp36m-win32

View File

@ -0,0 +1,3 @@
[distutils.setup_keywords]
cffi_modules = cffi.setuptools_ext:cffi_modules

View File

@ -0,0 +1,2 @@
_cffi_backend
cffi

View File

@ -0,0 +1,13 @@
__all__ = ['FFI', 'VerificationError', 'VerificationMissing', 'CDefError',
'FFIError']
from .api import FFI
from .error import CDefError, FFIError, VerificationError, VerificationMissing
__version__ = "1.10.0"
__version_info__ = (1, 10, 0)
# The verifier module file names are based on the CRC32 of a string that
# contains the following version number. It may be older than __version__
# if nothing is clearly incompatible.
__version_verifier_modules__ = "0.8.6"

View File

@ -0,0 +1,253 @@
#define _CFFI_
/* We try to define Py_LIMITED_API before including Python.h.
Mess: we can only define it if Py_DEBUG, Py_TRACE_REFS and
Py_REF_DEBUG are not defined. This is a best-effort approximation:
we can learn about Py_DEBUG from pyconfig.h, but it is unclear if
the same works for the other two macros. Py_DEBUG implies them,
but not the other way around.
*/
#ifndef _CFFI_USE_EMBEDDING
# include <pyconfig.h>
# if !defined(Py_DEBUG) && !defined(Py_TRACE_REFS) && !defined(Py_REF_DEBUG)
# define Py_LIMITED_API
# endif
#endif
#include <Python.h>
#ifdef __cplusplus
extern "C" {
#endif
#include <stddef.h>
#include "parse_c_type.h"
/* this block of #ifs should be kept exactly identical between
c/_cffi_backend.c, cffi/vengine_cpy.py, cffi/vengine_gen.py
and cffi/_cffi_include.h */
#if defined(_MSC_VER)
# include <malloc.h> /* for alloca() */
# if _MSC_VER < 1600 /* MSVC < 2010 */
typedef __int8 int8_t;
typedef __int16 int16_t;
typedef __int32 int32_t;
typedef __int64 int64_t;
typedef unsigned __int8 uint8_t;
typedef unsigned __int16 uint16_t;
typedef unsigned __int32 uint32_t;
typedef unsigned __int64 uint64_t;
typedef __int8 int_least8_t;
typedef __int16 int_least16_t;
typedef __int32 int_least32_t;
typedef __int64 int_least64_t;
typedef unsigned __int8 uint_least8_t;
typedef unsigned __int16 uint_least16_t;
typedef unsigned __int32 uint_least32_t;
typedef unsigned __int64 uint_least64_t;
typedef __int8 int_fast8_t;
typedef __int16 int_fast16_t;
typedef __int32 int_fast32_t;
typedef __int64 int_fast64_t;
typedef unsigned __int8 uint_fast8_t;
typedef unsigned __int16 uint_fast16_t;
typedef unsigned __int32 uint_fast32_t;
typedef unsigned __int64 uint_fast64_t;
typedef __int64 intmax_t;
typedef unsigned __int64 uintmax_t;
# else
# include <stdint.h>
# endif
# if _MSC_VER < 1800 /* MSVC < 2013 */
# ifndef __cplusplus
typedef unsigned char _Bool;
# endif
# endif
#else
# include <stdint.h>
# if (defined (__SVR4) && defined (__sun)) || defined(_AIX) || defined(__hpux)
# include <alloca.h>
# endif
#endif
#ifdef __GNUC__
# define _CFFI_UNUSED_FN __attribute__((unused))
#else
# define _CFFI_UNUSED_FN /* nothing */
#endif
#ifdef __cplusplus
# ifndef _Bool
typedef bool _Bool; /* semi-hackish: C++ has no _Bool; bool is builtin */
# endif
#endif
/********** CPython-specific section **********/
#ifndef PYPY_VERSION
#if PY_MAJOR_VERSION >= 3
# define PyInt_FromLong PyLong_FromLong
#endif
#define _cffi_from_c_double PyFloat_FromDouble
#define _cffi_from_c_float PyFloat_FromDouble
#define _cffi_from_c_long PyInt_FromLong
#define _cffi_from_c_ulong PyLong_FromUnsignedLong
#define _cffi_from_c_longlong PyLong_FromLongLong
#define _cffi_from_c_ulonglong PyLong_FromUnsignedLongLong
#define _cffi_to_c_double PyFloat_AsDouble
#define _cffi_to_c_float PyFloat_AsDouble
#define _cffi_from_c_int(x, type) \
(((type)-1) > 0 ? /* unsigned */ \
(sizeof(type) < sizeof(long) ? \
PyInt_FromLong((long)x) : \
sizeof(type) == sizeof(long) ? \
PyLong_FromUnsignedLong((unsigned long)x) : \
PyLong_FromUnsignedLongLong((unsigned long long)x)) : \
(sizeof(type) <= sizeof(long) ? \
PyInt_FromLong((long)x) : \
PyLong_FromLongLong((long long)x)))
#define _cffi_to_c_int(o, type) \
((type)( \
sizeof(type) == 1 ? (((type)-1) > 0 ? (type)_cffi_to_c_u8(o) \
: (type)_cffi_to_c_i8(o)) : \
sizeof(type) == 2 ? (((type)-1) > 0 ? (type)_cffi_to_c_u16(o) \
: (type)_cffi_to_c_i16(o)) : \
sizeof(type) == 4 ? (((type)-1) > 0 ? (type)_cffi_to_c_u32(o) \
: (type)_cffi_to_c_i32(o)) : \
sizeof(type) == 8 ? (((type)-1) > 0 ? (type)_cffi_to_c_u64(o) \
: (type)_cffi_to_c_i64(o)) : \
(Py_FatalError("unsupported size for type " #type), (type)0)))
#define _cffi_to_c_i8 \
((int(*)(PyObject *))_cffi_exports[1])
#define _cffi_to_c_u8 \
((int(*)(PyObject *))_cffi_exports[2])
#define _cffi_to_c_i16 \
((int(*)(PyObject *))_cffi_exports[3])
#define _cffi_to_c_u16 \
((int(*)(PyObject *))_cffi_exports[4])
#define _cffi_to_c_i32 \
((int(*)(PyObject *))_cffi_exports[5])
#define _cffi_to_c_u32 \
((unsigned int(*)(PyObject *))_cffi_exports[6])
#define _cffi_to_c_i64 \
((long long(*)(PyObject *))_cffi_exports[7])
#define _cffi_to_c_u64 \
((unsigned long long(*)(PyObject *))_cffi_exports[8])
#define _cffi_to_c_char \
((int(*)(PyObject *))_cffi_exports[9])
#define _cffi_from_c_pointer \
((PyObject *(*)(char *, struct _cffi_ctypedescr *))_cffi_exports[10])
#define _cffi_to_c_pointer \
((char *(*)(PyObject *, struct _cffi_ctypedescr *))_cffi_exports[11])
#define _cffi_get_struct_layout \
not used any more
#define _cffi_restore_errno \
((void(*)(void))_cffi_exports[13])
#define _cffi_save_errno \
((void(*)(void))_cffi_exports[14])
#define _cffi_from_c_char \
((PyObject *(*)(char))_cffi_exports[15])
#define _cffi_from_c_deref \
((PyObject *(*)(char *, struct _cffi_ctypedescr *))_cffi_exports[16])
#define _cffi_to_c \
((int(*)(char *, struct _cffi_ctypedescr *, PyObject *))_cffi_exports[17])
#define _cffi_from_c_struct \
((PyObject *(*)(char *, struct _cffi_ctypedescr *))_cffi_exports[18])
#define _cffi_to_c_wchar_t \
((wchar_t(*)(PyObject *))_cffi_exports[19])
#define _cffi_from_c_wchar_t \
((PyObject *(*)(wchar_t))_cffi_exports[20])
#define _cffi_to_c_long_double \
((long double(*)(PyObject *))_cffi_exports[21])
#define _cffi_to_c__Bool \
((_Bool(*)(PyObject *))_cffi_exports[22])
#define _cffi_prepare_pointer_call_argument \
((Py_ssize_t(*)(struct _cffi_ctypedescr *, \
PyObject *, char **))_cffi_exports[23])
#define _cffi_convert_array_from_object \
((int(*)(char *, struct _cffi_ctypedescr *, PyObject *))_cffi_exports[24])
#define _CFFI_CPIDX 25
#define _cffi_call_python \
((void(*)(struct _cffi_externpy_s *, char *))_cffi_exports[_CFFI_CPIDX])
#define _CFFI_NUM_EXPORTS 26
struct _cffi_ctypedescr;
static void *_cffi_exports[_CFFI_NUM_EXPORTS];
#define _cffi_type(index) ( \
assert((((uintptr_t)_cffi_types[index]) & 1) == 0), \
(struct _cffi_ctypedescr *)_cffi_types[index])
static PyObject *_cffi_init(const char *module_name, Py_ssize_t version,
const struct _cffi_type_context_s *ctx)
{
PyObject *module, *o_arg, *new_module;
void *raw[] = {
(void *)module_name,
(void *)version,
(void *)_cffi_exports,
(void *)ctx,
};
module = PyImport_ImportModule("_cffi_backend");
if (module == NULL)
goto failure;
o_arg = PyLong_FromVoidPtr((void *)raw);
if (o_arg == NULL)
goto failure;
new_module = PyObject_CallMethod(
module, (char *)"_init_cffi_1_0_external_module", (char *)"O", o_arg);
Py_DECREF(o_arg);
Py_DECREF(module);
return new_module;
failure:
Py_XDECREF(module);
return NULL;
}
/********** end CPython-specific section **********/
#else
_CFFI_UNUSED_FN
static void (*_cffi_call_python_org)(struct _cffi_externpy_s *, char *);
# define _cffi_call_python _cffi_call_python_org
#endif
#define _cffi_array_len(array) (sizeof(array) / sizeof((array)[0]))
#define _cffi_prim_int(size, sign) \
((size) == 1 ? ((sign) ? _CFFI_PRIM_INT8 : _CFFI_PRIM_UINT8) : \
(size) == 2 ? ((sign) ? _CFFI_PRIM_INT16 : _CFFI_PRIM_UINT16) : \
(size) == 4 ? ((sign) ? _CFFI_PRIM_INT32 : _CFFI_PRIM_UINT32) : \
(size) == 8 ? ((sign) ? _CFFI_PRIM_INT64 : _CFFI_PRIM_UINT64) : \
_CFFI__UNKNOWN_PRIM)
#define _cffi_prim_float(size) \
((size) == sizeof(float) ? _CFFI_PRIM_FLOAT : \
(size) == sizeof(double) ? _CFFI_PRIM_DOUBLE : \
(size) == sizeof(long double) ? _CFFI__UNKNOWN_LONG_DOUBLE : \
_CFFI__UNKNOWN_FLOAT_PRIM)
#define _cffi_check_int(got, got_nonpos, expected) \
((got_nonpos) == (expected <= 0) && \
(got) == (unsigned long long)expected)
#ifdef MS_WIN32
# define _cffi_stdcall __stdcall
#else
# define _cffi_stdcall /* nothing */
#endif
#ifdef __cplusplus
}
#endif

View File

@ -0,0 +1,517 @@
/***** Support code for embedding *****/
#if defined(_MSC_VER)
# define CFFI_DLLEXPORT __declspec(dllexport)
#elif defined(__GNUC__)
# define CFFI_DLLEXPORT __attribute__((visibility("default")))
#else
# define CFFI_DLLEXPORT /* nothing */
#endif
/* There are two global variables of type _cffi_call_python_fnptr:
* _cffi_call_python, which we declare just below, is the one called
by ``extern "Python"`` implementations.
* _cffi_call_python_org, which on CPython is actually part of the
_cffi_exports[] array, is the function pointer copied from
_cffi_backend.
After initialization is complete, both are equal. However, the
first one remains equal to &_cffi_start_and_call_python until the
very end of initialization, when we are (or should be) sure that
concurrent threads also see a completely initialized world, and
only then is it changed.
*/
#undef _cffi_call_python
typedef void (*_cffi_call_python_fnptr)(struct _cffi_externpy_s *, char *);
static void _cffi_start_and_call_python(struct _cffi_externpy_s *, char *);
static _cffi_call_python_fnptr _cffi_call_python = &_cffi_start_and_call_python;
#ifndef _MSC_VER
/* --- Assuming a GCC not infinitely old --- */
# define cffi_compare_and_swap(l,o,n) __sync_bool_compare_and_swap(l,o,n)
# define cffi_write_barrier() __sync_synchronize()
# if !defined(__amd64__) && !defined(__x86_64__) && \
!defined(__i386__) && !defined(__i386)
# define cffi_read_barrier() __sync_synchronize()
# else
# define cffi_read_barrier() (void)0
# endif
#else
/* --- Windows threads version --- */
# include <Windows.h>
# define cffi_compare_and_swap(l,o,n) \
(InterlockedCompareExchangePointer(l,n,o) == (o))
# define cffi_write_barrier() InterlockedCompareExchange(&_cffi_dummy,0,0)
# define cffi_read_barrier() (void)0
static volatile LONG _cffi_dummy;
#endif
#ifdef WITH_THREAD
# ifndef _MSC_VER
# include <pthread.h>
static pthread_mutex_t _cffi_embed_startup_lock;
# else
static CRITICAL_SECTION _cffi_embed_startup_lock;
# endif
static char _cffi_embed_startup_lock_ready = 0;
#endif
static void _cffi_acquire_reentrant_mutex(void)
{
static void *volatile lock = NULL;
while (!cffi_compare_and_swap(&lock, NULL, (void *)1)) {
/* should ideally do a spin loop instruction here, but
hard to do it portably and doesn't really matter I
think: pthread_mutex_init() should be very fast, and
this is only run at start-up anyway. */
}
#ifdef WITH_THREAD
if (!_cffi_embed_startup_lock_ready) {
# ifndef _MSC_VER
pthread_mutexattr_t attr;
pthread_mutexattr_init(&attr);
pthread_mutexattr_settype(&attr, PTHREAD_MUTEX_RECURSIVE);
pthread_mutex_init(&_cffi_embed_startup_lock, &attr);
# else
InitializeCriticalSection(&_cffi_embed_startup_lock);
# endif
_cffi_embed_startup_lock_ready = 1;
}
#endif
while (!cffi_compare_and_swap(&lock, (void *)1, NULL))
;
#ifndef _MSC_VER
pthread_mutex_lock(&_cffi_embed_startup_lock);
#else
EnterCriticalSection(&_cffi_embed_startup_lock);
#endif
}
static void _cffi_release_reentrant_mutex(void)
{
#ifndef _MSC_VER
pthread_mutex_unlock(&_cffi_embed_startup_lock);
#else
LeaveCriticalSection(&_cffi_embed_startup_lock);
#endif
}
/********** CPython-specific section **********/
#ifndef PYPY_VERSION
#define _cffi_call_python_org _cffi_exports[_CFFI_CPIDX]
PyMODINIT_FUNC _CFFI_PYTHON_STARTUP_FUNC(void); /* forward */
static void _cffi_py_initialize(void)
{
/* XXX use initsigs=0, which "skips initialization registration of
signal handlers, which might be useful when Python is
embedded" according to the Python docs. But review and think
if it should be a user-controllable setting.
XXX we should also give a way to write errors to a buffer
instead of to stderr.
XXX if importing 'site' fails, CPython (any version) calls
exit(). Should we try to work around this behavior here?
*/
Py_InitializeEx(0);
}
static int _cffi_initialize_python(void)
{
/* This initializes Python, imports _cffi_backend, and then the
present .dll/.so is set up as a CPython C extension module.
*/
int result;
PyGILState_STATE state;
PyObject *pycode=NULL, *global_dict=NULL, *x;
#if PY_MAJOR_VERSION >= 3
/* see comments in _cffi_carefully_make_gil() about the
Python2/Python3 difference
*/
#else
/* Acquire the GIL. We have no threadstate here. If Python is
already initialized, it is possible that there is already one
existing for this thread, but it is not made current now.
*/
PyEval_AcquireLock();
_cffi_py_initialize();
/* The Py_InitializeEx() sometimes made a threadstate for us, but
not always. Indeed Py_InitializeEx() could be called and do
nothing. So do we have a threadstate, or not? We don't know,
but we can replace it with NULL in all cases.
*/
(void)PyThreadState_Swap(NULL);
/* Now we can release the GIL and re-acquire immediately using the
logic of PyGILState(), which handles making or installing the
correct threadstate.
*/
PyEval_ReleaseLock();
#endif
state = PyGILState_Ensure();
/* Call the initxxx() function from the present module. It will
create and initialize us as a CPython extension module, instead
of letting the startup Python code do it---it might reimport
the same .dll/.so and get maybe confused on some platforms.
It might also have troubles locating the .dll/.so again for all
I know.
*/
(void)_CFFI_PYTHON_STARTUP_FUNC();
if (PyErr_Occurred())
goto error;
/* Now run the Python code provided to ffi.embedding_init_code().
*/
pycode = Py_CompileString(_CFFI_PYTHON_STARTUP_CODE,
"<init code for '" _CFFI_MODULE_NAME "'>",
Py_file_input);
if (pycode == NULL)
goto error;
global_dict = PyDict_New();
if (global_dict == NULL)
goto error;
if (PyDict_SetItemString(global_dict, "__builtins__",
PyThreadState_GET()->interp->builtins) < 0)
goto error;
x = PyEval_EvalCode(
#if PY_MAJOR_VERSION < 3
(PyCodeObject *)
#endif
pycode, global_dict, global_dict);
if (x == NULL)
goto error;
Py_DECREF(x);
/* Done! Now if we've been called from
_cffi_start_and_call_python() in an ``extern "Python"``, we can
only hope that the Python code did correctly set up the
corresponding @ffi.def_extern() function. Otherwise, the
general logic of ``extern "Python"`` functions (inside the
_cffi_backend module) will find that the reference is still
missing and print an error.
*/
result = 0;
done:
Py_XDECREF(pycode);
Py_XDECREF(global_dict);
PyGILState_Release(state);
return result;
error:;
{
/* Print as much information as potentially useful.
Debugging load-time failures with embedding is not fun
*/
PyObject *exception, *v, *tb, *f, *modules, *mod;
PyErr_Fetch(&exception, &v, &tb);
if (exception != NULL) {
PyErr_NormalizeException(&exception, &v, &tb);
PyErr_Display(exception, v, tb);
}
Py_XDECREF(exception);
Py_XDECREF(v);
Py_XDECREF(tb);
f = PySys_GetObject((char *)"stderr");
if (f != NULL && f != Py_None) {
PyFile_WriteString("\nFrom: " _CFFI_MODULE_NAME
"\ncompiled with cffi version: 1.10.0"
"\n_cffi_backend module: ", f);
modules = PyImport_GetModuleDict();
mod = PyDict_GetItemString(modules, "_cffi_backend");
if (mod == NULL) {
PyFile_WriteString("not loaded", f);
}
else {
v = PyObject_GetAttrString(mod, "__file__");
PyFile_WriteObject(v, f, 0);
Py_XDECREF(v);
}
PyFile_WriteString("\nsys.path: ", f);
PyFile_WriteObject(PySys_GetObject((char *)"path"), f, 0);
PyFile_WriteString("\n\n", f);
}
}
result = -1;
goto done;
}
PyAPI_DATA(char *) _PyParser_TokenNames[]; /* from CPython */
static int _cffi_carefully_make_gil(void)
{
/* This does the basic initialization of Python. It can be called
completely concurrently from unrelated threads. It assumes
that we don't hold the GIL before (if it exists), and we don't
hold it afterwards.
What it really does is completely different in Python 2 and
Python 3.
Python 2
========
Initialize the GIL, without initializing the rest of Python,
by calling PyEval_InitThreads().
PyEval_InitThreads() must not be called concurrently at all.
So we use a global variable as a simple spin lock. This global
variable must be from 'libpythonX.Y.so', not from this
cffi-based extension module, because it must be shared from
different cffi-based extension modules. We choose
_PyParser_TokenNames[0] as a completely arbitrary pointer value
that is never written to. The default is to point to the
string "ENDMARKER". We change it temporarily to point to the
next character in that string. (Yes, I know it's REALLY
obscure.)
Python 3
========
In Python 3, PyEval_InitThreads() cannot be called before
Py_InitializeEx() any more. So this function calls
Py_InitializeEx() first. It uses the same obscure logic to
make sure we never call it concurrently.
Arguably, this is less good on the spinlock, because
Py_InitializeEx() takes much longer to run than
PyEval_InitThreads(). But I didn't find a way around it.
*/
#ifdef WITH_THREAD
char *volatile *lock = (char *volatile *)_PyParser_TokenNames;
char *old_value;
while (1) { /* spin loop */
old_value = *lock;
if (old_value[0] == 'E') {
assert(old_value[1] == 'N');
if (cffi_compare_and_swap(lock, old_value, old_value + 1))
break;
}
else {
assert(old_value[0] == 'N');
/* should ideally do a spin loop instruction here, but
hard to do it portably and doesn't really matter I
think: PyEval_InitThreads() should be very fast, and
this is only run at start-up anyway. */
}
}
#endif
#if PY_MAJOR_VERSION >= 3
/* Python 3: call Py_InitializeEx() */
{
PyGILState_STATE state = PyGILState_UNLOCKED;
if (!Py_IsInitialized())
_cffi_py_initialize();
else
state = PyGILState_Ensure();
PyEval_InitThreads();
PyGILState_Release(state);
}
#else
/* Python 2: call PyEval_InitThreads() */
# ifdef WITH_THREAD
if (!PyEval_ThreadsInitialized()) {
PyEval_InitThreads(); /* makes the GIL */
PyEval_ReleaseLock(); /* then release it */
}
/* else: there is already a GIL, but we still needed to do the
spinlock dance to make sure that we see it as fully ready */
# endif
#endif
#ifdef WITH_THREAD
/* release the lock */
while (!cffi_compare_and_swap(lock, old_value + 1, old_value))
;
#endif
return 0;
}
/********** end CPython-specific section **********/
#else
/********** PyPy-specific section **********/
PyMODINIT_FUNC _CFFI_PYTHON_STARTUP_FUNC(const void *[]); /* forward */
static struct _cffi_pypy_init_s {
const char *name;
void (*func)(const void *[]);
const char *code;
} _cffi_pypy_init = {
_CFFI_MODULE_NAME,
(void(*)(const void *[]))_CFFI_PYTHON_STARTUP_FUNC,
_CFFI_PYTHON_STARTUP_CODE,
};
extern int pypy_carefully_make_gil(const char *);
extern int pypy_init_embedded_cffi_module(int, struct _cffi_pypy_init_s *);
static int _cffi_carefully_make_gil(void)
{
return pypy_carefully_make_gil(_CFFI_MODULE_NAME);
}
static int _cffi_initialize_python(void)
{
return pypy_init_embedded_cffi_module(0xB011, &_cffi_pypy_init);
}
/********** end PyPy-specific section **********/
#endif
#ifdef __GNUC__
__attribute__((noinline))
#endif
static _cffi_call_python_fnptr _cffi_start_python(void)
{
/* Delicate logic to initialize Python. This function can be
called multiple times concurrently, e.g. when the process calls
its first ``extern "Python"`` functions in multiple threads at
once. It can also be called recursively, in which case we must
ignore it. We also have to consider what occurs if several
different cffi-based extensions reach this code in parallel
threads---it is a different copy of the code, then, and we
can't have any shared global variable unless it comes from
'libpythonX.Y.so'.
Idea:
* _cffi_carefully_make_gil(): "carefully" call
PyEval_InitThreads() (possibly with Py_InitializeEx() first).
* then we use a (local) custom lock to make sure that a call to this
cffi-based extension will wait if another call to the *same*
extension is running the initialization in another thread.
It is reentrant, so that a recursive call will not block, but
only one from a different thread.
* then we grab the GIL and (Python 2) we call Py_InitializeEx().
At this point, concurrent calls to Py_InitializeEx() are not
possible: we have the GIL.
* do the rest of the specific initialization, which may
temporarily release the GIL but not the custom lock.
Only release the custom lock when we are done.
*/
static char called = 0;
if (_cffi_carefully_make_gil() != 0)
return NULL;
_cffi_acquire_reentrant_mutex();
/* Here the GIL exists, but we don't have it. We're only protected
from concurrency by the reentrant mutex. */
/* This file only initializes the embedded module once, the first
time this is called, even if there are subinterpreters. */
if (!called) {
called = 1; /* invoke _cffi_initialize_python() only once,
but don't set '_cffi_call_python' right now,
otherwise concurrent threads won't call
this function at all (we need them to wait) */
if (_cffi_initialize_python() == 0) {
/* now initialization is finished. Switch to the fast-path. */
/* We would like nobody to see the new value of
'_cffi_call_python' without also seeing the rest of the
data initialized. However, this is not possible. But
the new value of '_cffi_call_python' is the function
'cffi_call_python()' from _cffi_backend. So: */
cffi_write_barrier();
/* ^^^ we put a write barrier here, and a corresponding
read barrier at the start of cffi_call_python(). This
ensures that after that read barrier, we see everything
done here before the write barrier.
*/
assert(_cffi_call_python_org != NULL);
_cffi_call_python = (_cffi_call_python_fnptr)_cffi_call_python_org;
}
else {
/* initialization failed. Reset this to NULL, even if it was
already set to some other value. Future calls to
_cffi_start_python() are still forced to occur, and will
always return NULL from now on. */
_cffi_call_python_org = NULL;
}
}
_cffi_release_reentrant_mutex();
return (_cffi_call_python_fnptr)_cffi_call_python_org;
}
static
void _cffi_start_and_call_python(struct _cffi_externpy_s *externpy, char *args)
{
_cffi_call_python_fnptr fnptr;
int current_err = errno;
#ifdef _MSC_VER
int current_lasterr = GetLastError();
#endif
fnptr = _cffi_start_python();
if (fnptr == NULL) {
fprintf(stderr, "function %s() called, but initialization code "
"failed. Returning 0.\n", externpy->name);
memset(args, 0, externpy->size_of_result);
}
#ifdef _MSC_VER
SetLastError(current_lasterr);
#endif
errno = current_err;
if (fnptr != NULL)
fnptr(externpy, args);
}
/* The cffi_start_python() function makes sure Python is initialized
and our cffi module is set up. It can be called manually from the
user C code. The same effect is obtained automatically from any
dll-exported ``extern "Python"`` function. This function returns
-1 if initialization failed, 0 if all is OK. */
_CFFI_UNUSED_FN
static int cffi_start_python(void)
{
if (_cffi_call_python == &_cffi_start_and_call_python) {
if (_cffi_start_python() == NULL)
return -1;
}
cffi_read_barrier();
return 0;
}
#undef cffi_compare_and_swap
#undef cffi_write_barrier
#undef cffi_read_barrier

View File

@ -0,0 +1,916 @@
import sys, types
from .lock import allocate_lock
from .error import CDefError
from . import model
try:
callable
except NameError:
# Python 3.1
from collections import Callable
callable = lambda x: isinstance(x, Callable)
try:
basestring
except NameError:
# Python 3.x
basestring = str
class FFI(object):
r'''
The main top-level class that you instantiate once, or once per module.
Example usage:
ffi = FFI()
ffi.cdef("""
int printf(const char *, ...);
""")
C = ffi.dlopen(None) # standard library
-or-
C = ffi.verify() # use a C compiler: verify the decl above is right
C.printf("hello, %s!\n", ffi.new("char[]", "world"))
'''
def __init__(self, backend=None):
"""Create an FFI instance. The 'backend' argument is used to
select a non-default backend, mostly for tests.
"""
if backend is None:
# You need PyPy (>= 2.0 beta), or a CPython (>= 2.6) with
# _cffi_backend.so compiled.
import _cffi_backend as backend
from . import __version__
if backend.__version__ != __version__:
# bad version! Try to be as explicit as possible.
if hasattr(backend, '__file__'):
# CPython
raise Exception("Version mismatch: this is the 'cffi' package version %s, located in %r. When we import the top-level '_cffi_backend' extension module, we get version %s, located in %r. The two versions should be equal; check your installation." % (
__version__, __file__,
backend.__version__, backend.__file__))
else:
# PyPy
raise Exception("Version mismatch: this is the 'cffi' package version %s, located in %r. This interpreter comes with a built-in '_cffi_backend' module, which is version %s. The two versions should be equal; check your installation." % (
__version__, __file__, backend.__version__))
# (If you insist you can also try to pass the option
# 'backend=backend_ctypes.CTypesBackend()', but don't
# rely on it! It's probably not going to work well.)
from . import cparser
self._backend = backend
self._lock = allocate_lock()
self._parser = cparser.Parser()
self._cached_btypes = {}
self._parsed_types = types.ModuleType('parsed_types').__dict__
self._new_types = types.ModuleType('new_types').__dict__
self._function_caches = []
self._libraries = []
self._cdefsources = []
self._included_ffis = []
self._windows_unicode = None
self._init_once_cache = {}
self._cdef_version = None
self._embedding = None
if hasattr(backend, 'set_ffi'):
backend.set_ffi(self)
for name in backend.__dict__:
if name.startswith('RTLD_'):
setattr(self, name, getattr(backend, name))
#
with self._lock:
self.BVoidP = self._get_cached_btype(model.voidp_type)
self.BCharA = self._get_cached_btype(model.char_array_type)
if isinstance(backend, types.ModuleType):
# _cffi_backend: attach these constants to the class
if not hasattr(FFI, 'NULL'):
FFI.NULL = self.cast(self.BVoidP, 0)
FFI.CData, FFI.CType = backend._get_types()
else:
# ctypes backend: attach these constants to the instance
self.NULL = self.cast(self.BVoidP, 0)
self.CData, self.CType = backend._get_types()
self.buffer = backend.buffer
def cdef(self, csource, override=False, packed=False):
"""Parse the given C source. This registers all declared functions,
types, and global variables. The functions and global variables can
then be accessed via either 'ffi.dlopen()' or 'ffi.verify()'.
The types can be used in 'ffi.new()' and other functions.
If 'packed' is specified as True, all structs declared inside this
cdef are packed, i.e. laid out without any field alignment at all.
"""
self._cdef(csource, override=override, packed=packed)
def embedding_api(self, csource, packed=False):
self._cdef(csource, packed=packed, dllexport=True)
if self._embedding is None:
self._embedding = ''
def _cdef(self, csource, override=False, **options):
if not isinstance(csource, str): # unicode, on Python 2
if not isinstance(csource, basestring):
raise TypeError("cdef() argument must be a string")
csource = csource.encode('ascii')
with self._lock:
self._cdef_version = object()
self._parser.parse(csource, override=override, **options)
self._cdefsources.append(csource)
if override:
for cache in self._function_caches:
cache.clear()
finishlist = self._parser._recomplete
if finishlist:
self._parser._recomplete = []
for tp in finishlist:
tp.finish_backend_type(self, finishlist)
def dlopen(self, name, flags=0):
"""Load and return a dynamic library identified by 'name'.
The standard C library can be loaded by passing None.
Note that functions and types declared by 'ffi.cdef()' are not
linked to a particular library, just like C headers; in the
library we only look for the actual (untyped) symbols.
"""
assert isinstance(name, basestring) or name is None
with self._lock:
lib, function_cache = _make_ffi_library(self, name, flags)
self._function_caches.append(function_cache)
self._libraries.append(lib)
return lib
def _typeof_locked(self, cdecl):
# call me with the lock!
key = cdecl
if key in self._parsed_types:
return self._parsed_types[key]
#
if not isinstance(cdecl, str): # unicode, on Python 2
cdecl = cdecl.encode('ascii')
#
type = self._parser.parse_type(cdecl)
really_a_function_type = type.is_raw_function
if really_a_function_type:
type = type.as_function_pointer()
btype = self._get_cached_btype(type)
result = btype, really_a_function_type
self._parsed_types[key] = result
return result
def _typeof(self, cdecl, consider_function_as_funcptr=False):
# string -> ctype object
try:
result = self._parsed_types[cdecl]
except KeyError:
with self._lock:
result = self._typeof_locked(cdecl)
#
btype, really_a_function_type = result
if really_a_function_type and not consider_function_as_funcptr:
raise CDefError("the type %r is a function type, not a "
"pointer-to-function type" % (cdecl,))
return btype
def typeof(self, cdecl):
"""Parse the C type given as a string and return the
corresponding <ctype> object.
It can also be used on 'cdata' instance to get its C type.
"""
if isinstance(cdecl, basestring):
return self._typeof(cdecl)
if isinstance(cdecl, self.CData):
return self._backend.typeof(cdecl)
if isinstance(cdecl, types.BuiltinFunctionType):
res = _builtin_function_type(cdecl)
if res is not None:
return res
if (isinstance(cdecl, types.FunctionType)
and hasattr(cdecl, '_cffi_base_type')):
with self._lock:
return self._get_cached_btype(cdecl._cffi_base_type)
raise TypeError(type(cdecl))
def sizeof(self, cdecl):
"""Return the size in bytes of the argument. It can be a
string naming a C type, or a 'cdata' instance.
"""
if isinstance(cdecl, basestring):
BType = self._typeof(cdecl)
return self._backend.sizeof(BType)
else:
return self._backend.sizeof(cdecl)
def alignof(self, cdecl):
"""Return the natural alignment size in bytes of the C type
given as a string.
"""
if isinstance(cdecl, basestring):
cdecl = self._typeof(cdecl)
return self._backend.alignof(cdecl)
def offsetof(self, cdecl, *fields_or_indexes):
"""Return the offset of the named field inside the given
structure or array, which must be given as a C type name.
You can give several field names in case of nested structures.
You can also give numeric values which correspond to array
items, in case of an array type.
"""
if isinstance(cdecl, basestring):
cdecl = self._typeof(cdecl)
return self._typeoffsetof(cdecl, *fields_or_indexes)[1]
def new(self, cdecl, init=None):
"""Allocate an instance according to the specified C type and
return a pointer to it. The specified C type must be either a
pointer or an array: ``new('X *')`` allocates an X and returns
a pointer to it, whereas ``new('X[n]')`` allocates an array of
n X'es and returns an array referencing it (which works
mostly like a pointer, like in C). You can also use
``new('X[]', n)`` to allocate an array of a non-constant
length n.
The memory is initialized following the rules of declaring a
global variable in C: by default it is zero-initialized, but
an explicit initializer can be given which can be used to
fill all or part of the memory.
When the returned <cdata> object goes out of scope, the memory
is freed. In other words the returned <cdata> object has
ownership of the value of type 'cdecl' that it points to. This
means that the raw data can be used as long as this object is
kept alive, but must not be used for a longer time. Be careful
about that when copying the pointer to the memory somewhere
else, e.g. into another structure.
"""
if isinstance(cdecl, basestring):
cdecl = self._typeof(cdecl)
return self._backend.newp(cdecl, init)
def new_allocator(self, alloc=None, free=None,
should_clear_after_alloc=True):
"""Return a new allocator, i.e. a function that behaves like ffi.new()
but uses the provided low-level 'alloc' and 'free' functions.
'alloc' is called with the size as argument. If it returns NULL, a
MemoryError is raised. 'free' is called with the result of 'alloc'
as argument. Both can be either Python function or directly C
functions. If 'free' is None, then no free function is called.
If both 'alloc' and 'free' are None, the default is used.
If 'should_clear_after_alloc' is set to False, then the memory
returned by 'alloc' is assumed to be already cleared (or you are
fine with garbage); otherwise CFFI will clear it.
"""
compiled_ffi = self._backend.FFI()
allocator = compiled_ffi.new_allocator(alloc, free,
should_clear_after_alloc)
def allocate(cdecl, init=None):
if isinstance(cdecl, basestring):
cdecl = self._typeof(cdecl)
return allocator(cdecl, init)
return allocate
def cast(self, cdecl, source):
"""Similar to a C cast: returns an instance of the named C
type initialized with the given 'source'. The source is
casted between integers or pointers of any type.
"""
if isinstance(cdecl, basestring):
cdecl = self._typeof(cdecl)
return self._backend.cast(cdecl, source)
def string(self, cdata, maxlen=-1):
"""Return a Python string (or unicode string) from the 'cdata'.
If 'cdata' is a pointer or array of characters or bytes, returns
the null-terminated string. The returned string extends until
the first null character, or at most 'maxlen' characters. If
'cdata' is an array then 'maxlen' defaults to its length.
If 'cdata' is a pointer or array of wchar_t, returns a unicode
string following the same rules.
If 'cdata' is a single character or byte or a wchar_t, returns
it as a string or unicode string.
If 'cdata' is an enum, returns the value of the enumerator as a
string, or 'NUMBER' if the value is out of range.
"""
return self._backend.string(cdata, maxlen)
def unpack(self, cdata, length):
"""Unpack an array of C data of the given length,
returning a Python string/unicode/list.
If 'cdata' is a pointer to 'char', returns a byte string.
It does not stop at the first null. This is equivalent to:
ffi.buffer(cdata, length)[:]
If 'cdata' is a pointer to 'wchar_t', returns a unicode string.
'length' is measured in wchar_t's; it is not the size in bytes.
If 'cdata' is a pointer to anything else, returns a list of
'length' items. This is a faster equivalent to:
[cdata[i] for i in range(length)]
"""
return self._backend.unpack(cdata, length)
#def buffer(self, cdata, size=-1):
# """Return a read-write buffer object that references the raw C data
# pointed to by the given 'cdata'. The 'cdata' must be a pointer or
# an array. Can be passed to functions expecting a buffer, or directly
# manipulated with:
#
# buf[:] get a copy of it in a regular string, or
# buf[idx] as a single character
# buf[:] = ...
# buf[idx] = ... change the content
# """
# note that 'buffer' is a type, set on this instance by __init__
def from_buffer(self, python_buffer):
"""Return a <cdata 'char[]'> that points to the data of the
given Python object, which must support the buffer interface.
Note that this is not meant to be used on the built-in types
str or unicode (you can build 'char[]' arrays explicitly)
but only on objects containing large quantities of raw data
in some other format, like 'array.array' or numpy arrays.
"""
return self._backend.from_buffer(self.BCharA, python_buffer)
def memmove(self, dest, src, n):
"""ffi.memmove(dest, src, n) copies n bytes of memory from src to dest.
Like the C function memmove(), the memory areas may overlap;
apart from that it behaves like the C function memcpy().
'src' can be any cdata ptr or array, or any Python buffer object.
'dest' can be any cdata ptr or array, or a writable Python buffer
object. The size to copy, 'n', is always measured in bytes.
Unlike other methods, this one supports all Python buffer including
byte strings and bytearrays---but it still does not support
non-contiguous buffers.
"""
return self._backend.memmove(dest, src, n)
def callback(self, cdecl, python_callable=None, error=None, onerror=None):
"""Return a callback object or a decorator making such a
callback object. 'cdecl' must name a C function pointer type.
The callback invokes the specified 'python_callable' (which may
be provided either directly or via a decorator). Important: the
callback object must be manually kept alive for as long as the
callback may be invoked from the C level.
"""
def callback_decorator_wrap(python_callable):
if not callable(python_callable):
raise TypeError("the 'python_callable' argument "
"is not callable")
return self._backend.callback(cdecl, python_callable,
error, onerror)
if isinstance(cdecl, basestring):
cdecl = self._typeof(cdecl, consider_function_as_funcptr=True)
if python_callable is None:
return callback_decorator_wrap # decorator mode
else:
return callback_decorator_wrap(python_callable) # direct mode
def getctype(self, cdecl, replace_with=''):
"""Return a string giving the C type 'cdecl', which may be itself
a string or a <ctype> object. If 'replace_with' is given, it gives
extra text to append (or insert for more complicated C types), like
a variable name, or '*' to get actually the C type 'pointer-to-cdecl'.
"""
if isinstance(cdecl, basestring):
cdecl = self._typeof(cdecl)
replace_with = replace_with.strip()
if (replace_with.startswith('*')
and '&[' in self._backend.getcname(cdecl, '&')):
replace_with = '(%s)' % replace_with
elif replace_with and not replace_with[0] in '[(':
replace_with = ' ' + replace_with
return self._backend.getcname(cdecl, replace_with)
def gc(self, cdata, destructor):
"""Return a new cdata object that points to the same
data. Later, when this new cdata object is garbage-collected,
'destructor(old_cdata_object)' will be called.
"""
return self._backend.gcp(cdata, destructor)
def _get_cached_btype(self, type):
assert self._lock.acquire(False) is False
# call me with the lock!
try:
BType = self._cached_btypes[type]
except KeyError:
finishlist = []
BType = type.get_cached_btype(self, finishlist)
for type in finishlist:
type.finish_backend_type(self, finishlist)
return BType
def verify(self, source='', tmpdir=None, **kwargs):
"""Verify that the current ffi signatures compile on this
machine, and return a dynamic library object. The dynamic
library can be used to call functions and access global
variables declared in this 'ffi'. The library is compiled
by the C compiler: it gives you C-level API compatibility
(including calling macros). This is unlike 'ffi.dlopen()',
which requires binary compatibility in the signatures.
"""
from .verifier import Verifier, _caller_dir_pycache
#
# If set_unicode(True) was called, insert the UNICODE and
# _UNICODE macro declarations
if self._windows_unicode:
self._apply_windows_unicode(kwargs)
#
# Set the tmpdir here, and not in Verifier.__init__: it picks
# up the caller's directory, which we want to be the caller of
# ffi.verify(), as opposed to the caller of Veritier().
tmpdir = tmpdir or _caller_dir_pycache()
#
# Make a Verifier() and use it to load the library.
self.verifier = Verifier(self, source, tmpdir, **kwargs)
lib = self.verifier.load_library()
#
# Save the loaded library for keep-alive purposes, even
# if the caller doesn't keep it alive itself (it should).
self._libraries.append(lib)
return lib
def _get_errno(self):
return self._backend.get_errno()
def _set_errno(self, errno):
self._backend.set_errno(errno)
errno = property(_get_errno, _set_errno, None,
"the value of 'errno' from/to the C calls")
def getwinerror(self, code=-1):
return self._backend.getwinerror(code)
def _pointer_to(self, ctype):
with self._lock:
return model.pointer_cache(self, ctype)
def addressof(self, cdata, *fields_or_indexes):
"""Return the address of a <cdata 'struct-or-union'>.
If 'fields_or_indexes' are given, returns the address of that
field or array item in the structure or array, recursively in
case of nested structures.
"""
try:
ctype = self._backend.typeof(cdata)
except TypeError:
if '__addressof__' in type(cdata).__dict__:
return type(cdata).__addressof__(cdata, *fields_or_indexes)
raise
if fields_or_indexes:
ctype, offset = self._typeoffsetof(ctype, *fields_or_indexes)
else:
if ctype.kind == "pointer":
raise TypeError("addressof(pointer)")
offset = 0
ctypeptr = self._pointer_to(ctype)
return self._backend.rawaddressof(ctypeptr, cdata, offset)
def _typeoffsetof(self, ctype, field_or_index, *fields_or_indexes):
ctype, offset = self._backend.typeoffsetof(ctype, field_or_index)
for field1 in fields_or_indexes:
ctype, offset1 = self._backend.typeoffsetof(ctype, field1, 1)
offset += offset1
return ctype, offset
def include(self, ffi_to_include):
"""Includes the typedefs, structs, unions and enums defined
in another FFI instance. Usage is similar to a #include in C,
where a part of the program might include types defined in
another part for its own usage. Note that the include()
method has no effect on functions, constants and global
variables, which must anyway be accessed directly from the
lib object returned by the original FFI instance.
"""
if not isinstance(ffi_to_include, FFI):
raise TypeError("ffi.include() expects an argument that is also of"
" type cffi.FFI, not %r" % (
type(ffi_to_include).__name__,))
if ffi_to_include is self:
raise ValueError("self.include(self)")
with ffi_to_include._lock:
with self._lock:
self._parser.include(ffi_to_include._parser)
self._cdefsources.append('[')
self._cdefsources.extend(ffi_to_include._cdefsources)
self._cdefsources.append(']')
self._included_ffis.append(ffi_to_include)
def new_handle(self, x):
return self._backend.newp_handle(self.BVoidP, x)
def from_handle(self, x):
return self._backend.from_handle(x)
def set_unicode(self, enabled_flag):
"""Windows: if 'enabled_flag' is True, enable the UNICODE and
_UNICODE defines in C, and declare the types like TCHAR and LPTCSTR
to be (pointers to) wchar_t. If 'enabled_flag' is False,
declare these types to be (pointers to) plain 8-bit characters.
This is mostly for backward compatibility; you usually want True.
"""
if self._windows_unicode is not None:
raise ValueError("set_unicode() can only be called once")
enabled_flag = bool(enabled_flag)
if enabled_flag:
self.cdef("typedef wchar_t TBYTE;"
"typedef wchar_t TCHAR;"
"typedef const wchar_t *LPCTSTR;"
"typedef const wchar_t *PCTSTR;"
"typedef wchar_t *LPTSTR;"
"typedef wchar_t *PTSTR;"
"typedef TBYTE *PTBYTE;"
"typedef TCHAR *PTCHAR;")
else:
self.cdef("typedef char TBYTE;"
"typedef char TCHAR;"
"typedef const char *LPCTSTR;"
"typedef const char *PCTSTR;"
"typedef char *LPTSTR;"
"typedef char *PTSTR;"
"typedef TBYTE *PTBYTE;"
"typedef TCHAR *PTCHAR;")
self._windows_unicode = enabled_flag
def _apply_windows_unicode(self, kwds):
defmacros = kwds.get('define_macros', ())
if not isinstance(defmacros, (list, tuple)):
raise TypeError("'define_macros' must be a list or tuple")
defmacros = list(defmacros) + [('UNICODE', '1'),
('_UNICODE', '1')]
kwds['define_macros'] = defmacros
def _apply_embedding_fix(self, kwds):
# must include an argument like "-lpython2.7" for the compiler
def ensure(key, value):
lst = kwds.setdefault(key, [])
if value not in lst:
lst.append(value)
#
if '__pypy__' in sys.builtin_module_names:
import os
if sys.platform == "win32":
# we need 'libpypy-c.lib'. Current distributions of
# pypy (>= 4.1) contain it as 'libs/python27.lib'.
pythonlib = "python27"
if hasattr(sys, 'prefix'):
ensure('library_dirs', os.path.join(sys.prefix, 'libs'))
else:
# we need 'libpypy-c.{so,dylib}', which should be by
# default located in 'sys.prefix/bin' for installed
# systems.
if sys.version_info < (3,):
pythonlib = "pypy-c"
else:
pythonlib = "pypy3-c"
if hasattr(sys, 'prefix'):
ensure('library_dirs', os.path.join(sys.prefix, 'bin'))
# On uninstalled pypy's, the libpypy-c is typically found in
# .../pypy/goal/.
if hasattr(sys, 'prefix'):
ensure('library_dirs', os.path.join(sys.prefix, 'pypy', 'goal'))
else:
if sys.platform == "win32":
template = "python%d%d"
if hasattr(sys, 'gettotalrefcount'):
template += '_d'
else:
try:
import sysconfig
except ImportError: # 2.6
from distutils import sysconfig
template = "python%d.%d"
if sysconfig.get_config_var('DEBUG_EXT'):
template += sysconfig.get_config_var('DEBUG_EXT')
pythonlib = (template %
(sys.hexversion >> 24, (sys.hexversion >> 16) & 0xff))
if hasattr(sys, 'abiflags'):
pythonlib += sys.abiflags
ensure('libraries', pythonlib)
if sys.platform == "win32":
ensure('extra_link_args', '/MANIFEST')
def set_source(self, module_name, source, source_extension='.c', **kwds):
import os
if hasattr(self, '_assigned_source'):
raise ValueError("set_source() cannot be called several times "
"per ffi object")
if not isinstance(module_name, basestring):
raise TypeError("'module_name' must be a string")
if os.sep in module_name or (os.altsep and os.altsep in module_name):
raise ValueError("'module_name' must not contain '/': use a dotted "
"name to make a 'package.module' location")
self._assigned_source = (str(module_name), source,
source_extension, kwds)
def distutils_extension(self, tmpdir='build', verbose=True):
from distutils.dir_util import mkpath
from .recompiler import recompile
#
if not hasattr(self, '_assigned_source'):
if hasattr(self, 'verifier'): # fallback, 'tmpdir' ignored
return self.verifier.get_extension()
raise ValueError("set_source() must be called before"
" distutils_extension()")
module_name, source, source_extension, kwds = self._assigned_source
if source is None:
raise TypeError("distutils_extension() is only for C extension "
"modules, not for dlopen()-style pure Python "
"modules")
mkpath(tmpdir)
ext, updated = recompile(self, module_name,
source, tmpdir=tmpdir, extradir=tmpdir,
source_extension=source_extension,
call_c_compiler=False, **kwds)
if verbose:
if updated:
sys.stderr.write("regenerated: %r\n" % (ext.sources[0],))
else:
sys.stderr.write("not modified: %r\n" % (ext.sources[0],))
return ext
def emit_c_code(self, filename):
from .recompiler import recompile
#
if not hasattr(self, '_assigned_source'):
raise ValueError("set_source() must be called before emit_c_code()")
module_name, source, source_extension, kwds = self._assigned_source
if source is None:
raise TypeError("emit_c_code() is only for C extension modules, "
"not for dlopen()-style pure Python modules")
recompile(self, module_name, source,
c_file=filename, call_c_compiler=False, **kwds)
def emit_python_code(self, filename):
from .recompiler import recompile
#
if not hasattr(self, '_assigned_source'):
raise ValueError("set_source() must be called before emit_c_code()")
module_name, source, source_extension, kwds = self._assigned_source
if source is not None:
raise TypeError("emit_python_code() is only for dlopen()-style "
"pure Python modules, not for C extension modules")
recompile(self, module_name, source,
c_file=filename, call_c_compiler=False, **kwds)
def compile(self, tmpdir='.', verbose=0, target=None, debug=None):
"""The 'target' argument gives the final file name of the
compiled DLL. Use '*' to force distutils' choice, suitable for
regular CPython C API modules. Use a file name ending in '.*'
to ask for the system's default extension for dynamic libraries
(.so/.dll/.dylib).
The default is '*' when building a non-embedded C API extension,
and (module_name + '.*') when building an embedded library.
"""
from .recompiler import recompile
#
if not hasattr(self, '_assigned_source'):
raise ValueError("set_source() must be called before compile()")
module_name, source, source_extension, kwds = self._assigned_source
return recompile(self, module_name, source, tmpdir=tmpdir,
target=target, source_extension=source_extension,
compiler_verbose=verbose, debug=debug, **kwds)
def init_once(self, func, tag):
# Read _init_once_cache[tag], which is either (False, lock) if
# we're calling the function now in some thread, or (True, result).
# Don't call setdefault() in most cases, to avoid allocating and
# immediately freeing a lock; but still use setdefaut() to avoid
# races.
try:
x = self._init_once_cache[tag]
except KeyError:
x = self._init_once_cache.setdefault(tag, (False, allocate_lock()))
# Common case: we got (True, result), so we return the result.
if x[0]:
return x[1]
# Else, it's a lock. Acquire it to serialize the following tests.
with x[1]:
# Read again from _init_once_cache the current status.
x = self._init_once_cache[tag]
if x[0]:
return x[1]
# Call the function and store the result back.
result = func()
self._init_once_cache[tag] = (True, result)
return result
def embedding_init_code(self, pysource):
if self._embedding:
raise ValueError("embedding_init_code() can only be called once")
# fix 'pysource' before it gets dumped into the C file:
# - remove empty lines at the beginning, so it starts at "line 1"
# - dedent, if all non-empty lines are indented
# - check for SyntaxErrors
import re
match = re.match(r'\s*\n', pysource)
if match:
pysource = pysource[match.end():]
lines = pysource.splitlines() or ['']
prefix = re.match(r'\s*', lines[0]).group()
for i in range(1, len(lines)):
line = lines[i]
if line.rstrip():
while not line.startswith(prefix):
prefix = prefix[:-1]
i = len(prefix)
lines = [line[i:]+'\n' for line in lines]
pysource = ''.join(lines)
#
compile(pysource, "cffi_init", "exec")
#
self._embedding = pysource
def def_extern(self, *args, **kwds):
raise ValueError("ffi.def_extern() is only available on API-mode FFI "
"objects")
def list_types(self):
"""Returns the user type names known to this FFI instance.
This returns a tuple containing three lists of names:
(typedef_names, names_of_structs, names_of_unions)
"""
typedefs = []
structs = []
unions = []
for key in self._parser._declarations:
if key.startswith('typedef '):
typedefs.append(key[8:])
elif key.startswith('struct '):
structs.append(key[7:])
elif key.startswith('union '):
unions.append(key[6:])
typedefs.sort()
structs.sort()
unions.sort()
return (typedefs, structs, unions)
def _load_backend_lib(backend, name, flags):
import os
if name is None:
if sys.platform != "win32":
return backend.load_library(None, flags)
name = "c" # Windows: load_library(None) fails, but this works
# (backward compatibility hack only)
first_error = None
if '.' in name or '/' in name or os.sep in name:
try:
return backend.load_library(name, flags)
except OSError as e:
first_error = e
import ctypes.util
path = ctypes.util.find_library(name)
if path is None:
msg = ("ctypes.util.find_library() did not manage "
"to locate a library called %r" % (name,))
if first_error is not None:
msg = "%s. Additionally, %s" % (first_error, msg)
raise OSError(msg)
return backend.load_library(path, flags)
def _make_ffi_library(ffi, libname, flags):
backend = ffi._backend
backendlib = _load_backend_lib(backend, libname, flags)
#
def accessor_function(name):
key = 'function ' + name
tp, _ = ffi._parser._declarations[key]
BType = ffi._get_cached_btype(tp)
value = backendlib.load_function(BType, name)
library.__dict__[name] = value
#
def accessor_variable(name):
key = 'variable ' + name
tp, _ = ffi._parser._declarations[key]
BType = ffi._get_cached_btype(tp)
read_variable = backendlib.read_variable
write_variable = backendlib.write_variable
setattr(FFILibrary, name, property(
lambda self: read_variable(BType, name),
lambda self, value: write_variable(BType, name, value)))
#
def addressof_var(name):
try:
return addr_variables[name]
except KeyError:
with ffi._lock:
if name not in addr_variables:
key = 'variable ' + name
tp, _ = ffi._parser._declarations[key]
BType = ffi._get_cached_btype(tp)
if BType.kind != 'array':
BType = model.pointer_cache(ffi, BType)
p = backendlib.load_function(BType, name)
addr_variables[name] = p
return addr_variables[name]
#
def accessor_constant(name):
raise NotImplementedError("non-integer constant '%s' cannot be "
"accessed from a dlopen() library" % (name,))
#
def accessor_int_constant(name):
library.__dict__[name] = ffi._parser._int_constants[name]
#
accessors = {}
accessors_version = [False]
addr_variables = {}
#
def update_accessors():
if accessors_version[0] is ffi._cdef_version:
return
#
for key, (tp, _) in ffi._parser._declarations.items():
if not isinstance(tp, model.EnumType):
tag, name = key.split(' ', 1)
if tag == 'function':
accessors[name] = accessor_function
elif tag == 'variable':
accessors[name] = accessor_variable
elif tag == 'constant':
accessors[name] = accessor_constant
else:
for i, enumname in enumerate(tp.enumerators):
def accessor_enum(name, tp=tp, i=i):
tp.check_not_partial()
library.__dict__[name] = tp.enumvalues[i]
accessors[enumname] = accessor_enum
for name in ffi._parser._int_constants:
accessors.setdefault(name, accessor_int_constant)
accessors_version[0] = ffi._cdef_version
#
def make_accessor(name):
with ffi._lock:
if name in library.__dict__ or name in FFILibrary.__dict__:
return # added by another thread while waiting for the lock
if name not in accessors:
update_accessors()
if name not in accessors:
raise AttributeError(name)
accessors[name](name)
#
class FFILibrary(object):
def __getattr__(self, name):
make_accessor(name)
return getattr(self, name)
def __setattr__(self, name, value):
try:
property = getattr(self.__class__, name)
except AttributeError:
make_accessor(name)
setattr(self, name, value)
else:
property.__set__(self, value)
def __dir__(self):
with ffi._lock:
update_accessors()
return accessors.keys()
def __addressof__(self, name):
if name in library.__dict__:
return library.__dict__[name]
if name in FFILibrary.__dict__:
return addressof_var(name)
make_accessor(name)
if name in library.__dict__:
return library.__dict__[name]
if name in FFILibrary.__dict__:
return addressof_var(name)
raise AttributeError("cffi library has no function or "
"global variable named '%s'" % (name,))
#
if libname is not None:
try:
if not isinstance(libname, str): # unicode, on Python 2
libname = libname.encode('utf-8')
FFILibrary.__name__ = 'FFILibrary_%s' % libname
except UnicodeError:
pass
library = FFILibrary()
return library, library.__dict__
def _builtin_function_type(func):
# a hack to make at least ffi.typeof(builtin_function) work,
# if the builtin function was obtained by 'vengine_cpy'.
import sys
try:
module = sys.modules[func.__module__]
ffi = module._cffi_original_ffi
types_of_builtin_funcs = module._cffi_types_of_builtin_funcs
tp = types_of_builtin_funcs[func]
except (KeyError, AttributeError, TypeError):
return None
else:
with ffi._lock:
return ffi._get_cached_btype(tp)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,179 @@
from .error import VerificationError
class CffiOp(object):
def __init__(self, op, arg):
self.op = op
self.arg = arg
def as_c_expr(self):
if self.op is None:
assert isinstance(self.arg, str)
return '(_cffi_opcode_t)(%s)' % (self.arg,)
classname = CLASS_NAME[self.op]
return '_CFFI_OP(_CFFI_OP_%s, %s)' % (classname, self.arg)
def as_python_bytes(self):
if self.op is None and self.arg.isdigit():
value = int(self.arg) # non-negative: '-' not in self.arg
if value >= 2**31:
raise OverflowError("cannot emit %r: limited to 2**31-1"
% (self.arg,))
return format_four_bytes(value)
if isinstance(self.arg, str):
raise VerificationError("cannot emit to Python: %r" % (self.arg,))
return format_four_bytes((self.arg << 8) | self.op)
def __str__(self):
classname = CLASS_NAME.get(self.op, self.op)
return '(%s %s)' % (classname, self.arg)
def format_four_bytes(num):
return '\\x%02X\\x%02X\\x%02X\\x%02X' % (
(num >> 24) & 0xFF,
(num >> 16) & 0xFF,
(num >> 8) & 0xFF,
(num ) & 0xFF)
OP_PRIMITIVE = 1
OP_POINTER = 3
OP_ARRAY = 5
OP_OPEN_ARRAY = 7
OP_STRUCT_UNION = 9
OP_ENUM = 11
OP_FUNCTION = 13
OP_FUNCTION_END = 15
OP_NOOP = 17
OP_BITFIELD = 19
OP_TYPENAME = 21
OP_CPYTHON_BLTN_V = 23 # varargs
OP_CPYTHON_BLTN_N = 25 # noargs
OP_CPYTHON_BLTN_O = 27 # O (i.e. a single arg)
OP_CONSTANT = 29
OP_CONSTANT_INT = 31
OP_GLOBAL_VAR = 33
OP_DLOPEN_FUNC = 35
OP_DLOPEN_CONST = 37
OP_GLOBAL_VAR_F = 39
OP_EXTERN_PYTHON = 41
PRIM_VOID = 0
PRIM_BOOL = 1
PRIM_CHAR = 2
PRIM_SCHAR = 3
PRIM_UCHAR = 4
PRIM_SHORT = 5
PRIM_USHORT = 6
PRIM_INT = 7
PRIM_UINT = 8
PRIM_LONG = 9
PRIM_ULONG = 10
PRIM_LONGLONG = 11
PRIM_ULONGLONG = 12
PRIM_FLOAT = 13
PRIM_DOUBLE = 14
PRIM_LONGDOUBLE = 15
PRIM_WCHAR = 16
PRIM_INT8 = 17
PRIM_UINT8 = 18
PRIM_INT16 = 19
PRIM_UINT16 = 20
PRIM_INT32 = 21
PRIM_UINT32 = 22
PRIM_INT64 = 23
PRIM_UINT64 = 24
PRIM_INTPTR = 25
PRIM_UINTPTR = 26
PRIM_PTRDIFF = 27
PRIM_SIZE = 28
PRIM_SSIZE = 29
PRIM_INT_LEAST8 = 30
PRIM_UINT_LEAST8 = 31
PRIM_INT_LEAST16 = 32
PRIM_UINT_LEAST16 = 33
PRIM_INT_LEAST32 = 34
PRIM_UINT_LEAST32 = 35
PRIM_INT_LEAST64 = 36
PRIM_UINT_LEAST64 = 37
PRIM_INT_FAST8 = 38
PRIM_UINT_FAST8 = 39
PRIM_INT_FAST16 = 40
PRIM_UINT_FAST16 = 41
PRIM_INT_FAST32 = 42
PRIM_UINT_FAST32 = 43
PRIM_INT_FAST64 = 44
PRIM_UINT_FAST64 = 45
PRIM_INTMAX = 46
PRIM_UINTMAX = 47
_NUM_PRIM = 48
_UNKNOWN_PRIM = -1
_UNKNOWN_FLOAT_PRIM = -2
_UNKNOWN_LONG_DOUBLE = -3
_IO_FILE_STRUCT = -1
PRIMITIVE_TO_INDEX = {
'char': PRIM_CHAR,
'short': PRIM_SHORT,
'int': PRIM_INT,
'long': PRIM_LONG,
'long long': PRIM_LONGLONG,
'signed char': PRIM_SCHAR,
'unsigned char': PRIM_UCHAR,
'unsigned short': PRIM_USHORT,
'unsigned int': PRIM_UINT,
'unsigned long': PRIM_ULONG,
'unsigned long long': PRIM_ULONGLONG,
'float': PRIM_FLOAT,
'double': PRIM_DOUBLE,
'long double': PRIM_LONGDOUBLE,
'_Bool': PRIM_BOOL,
'wchar_t': PRIM_WCHAR,
'int8_t': PRIM_INT8,
'uint8_t': PRIM_UINT8,
'int16_t': PRIM_INT16,
'uint16_t': PRIM_UINT16,
'int32_t': PRIM_INT32,
'uint32_t': PRIM_UINT32,
'int64_t': PRIM_INT64,
'uint64_t': PRIM_UINT64,
'intptr_t': PRIM_INTPTR,
'uintptr_t': PRIM_UINTPTR,
'ptrdiff_t': PRIM_PTRDIFF,
'size_t': PRIM_SIZE,
'ssize_t': PRIM_SSIZE,
'int_least8_t': PRIM_INT_LEAST8,
'uint_least8_t': PRIM_UINT_LEAST8,
'int_least16_t': PRIM_INT_LEAST16,
'uint_least16_t': PRIM_UINT_LEAST16,
'int_least32_t': PRIM_INT_LEAST32,
'uint_least32_t': PRIM_UINT_LEAST32,
'int_least64_t': PRIM_INT_LEAST64,
'uint_least64_t': PRIM_UINT_LEAST64,
'int_fast8_t': PRIM_INT_FAST8,
'uint_fast8_t': PRIM_UINT_FAST8,
'int_fast16_t': PRIM_INT_FAST16,
'uint_fast16_t': PRIM_UINT_FAST16,
'int_fast32_t': PRIM_INT_FAST32,
'uint_fast32_t': PRIM_UINT_FAST32,
'int_fast64_t': PRIM_INT_FAST64,
'uint_fast64_t': PRIM_UINT_FAST64,
'intmax_t': PRIM_INTMAX,
'uintmax_t': PRIM_UINTMAX,
}
F_UNION = 0x01
F_CHECK_FIELDS = 0x02
F_PACKED = 0x04
F_EXTERNAL = 0x08
F_OPAQUE = 0x10
G_FLAGS = dict([('_CFFI_' + _key, globals()[_key])
for _key in ['F_UNION', 'F_CHECK_FIELDS', 'F_PACKED',
'F_EXTERNAL', 'F_OPAQUE']])
CLASS_NAME = {}
for _name, _value in list(globals().items()):
if _name.startswith('OP_') and isinstance(_value, int):
CLASS_NAME[_value] = _name[3:]

View File

@ -0,0 +1,80 @@
import sys
from . import model
from .error import FFIError
COMMON_TYPES = {}
try:
# fetch "bool" and all simple Windows types
from _cffi_backend import _get_common_types
_get_common_types(COMMON_TYPES)
except ImportError:
pass
COMMON_TYPES['FILE'] = model.unknown_type('FILE', '_IO_FILE')
COMMON_TYPES['bool'] = '_Bool' # in case we got ImportError above
for _type in model.PrimitiveType.ALL_PRIMITIVE_TYPES:
if _type.endswith('_t'):
COMMON_TYPES[_type] = _type
del _type
_CACHE = {}
def resolve_common_type(parser, commontype):
try:
return _CACHE[commontype]
except KeyError:
cdecl = COMMON_TYPES.get(commontype, commontype)
if not isinstance(cdecl, str):
result, quals = cdecl, 0 # cdecl is already a BaseType
elif cdecl in model.PrimitiveType.ALL_PRIMITIVE_TYPES:
result, quals = model.PrimitiveType(cdecl), 0
elif cdecl == 'set-unicode-needed':
raise FFIError("The Windows type %r is only available after "
"you call ffi.set_unicode()" % (commontype,))
else:
if commontype == cdecl:
raise FFIError(
"Unsupported type: %r. Please look at "
"http://cffi.readthedocs.io/en/latest/cdef.html#ffi-cdef-limitations "
"and file an issue if you think this type should really "
"be supported." % (commontype,))
result, quals = parser.parse_type_and_quals(cdecl) # recursive
assert isinstance(result, model.BaseTypeByIdentity)
_CACHE[commontype] = result, quals
return result, quals
# ____________________________________________________________
# extra types for Windows (most of them are in commontypes.c)
def win_common_types():
return {
"UNICODE_STRING": model.StructType(
"_UNICODE_STRING",
["Length",
"MaximumLength",
"Buffer"],
[model.PrimitiveType("unsigned short"),
model.PrimitiveType("unsigned short"),
model.PointerType(model.PrimitiveType("wchar_t"))],
[-1, -1, -1]),
"PUNICODE_STRING": "UNICODE_STRING *",
"PCUNICODE_STRING": "const UNICODE_STRING *",
"TBYTE": "set-unicode-needed",
"TCHAR": "set-unicode-needed",
"LPCTSTR": "set-unicode-needed",
"PCTSTR": "set-unicode-needed",
"LPTSTR": "set-unicode-needed",
"PTSTR": "set-unicode-needed",
"PTBYTE": "set-unicode-needed",
"PTCHAR": "set-unicode-needed",
}
if sys.platform == 'win32':
COMMON_TYPES.update(win_common_types())

View File

@ -0,0 +1,876 @@
from . import model
from .commontypes import COMMON_TYPES, resolve_common_type
from .error import FFIError, CDefError
try:
from . import _pycparser as pycparser
except ImportError:
import pycparser
import weakref, re, sys
try:
if sys.version_info < (3,):
import thread as _thread
else:
import _thread
lock = _thread.allocate_lock()
except ImportError:
lock = None
_r_comment = re.compile(r"/\*.*?\*/|//([^\n\\]|\\.)*?$",
re.DOTALL | re.MULTILINE)
_r_define = re.compile(r"^\s*#\s*define\s+([A-Za-z_][A-Za-z_0-9]*)"
r"\b((?:[^\n\\]|\\.)*?)$",
re.DOTALL | re.MULTILINE)
_r_partial_enum = re.compile(r"=\s*\.\.\.\s*[,}]|\.\.\.\s*\}")
_r_enum_dotdotdot = re.compile(r"__dotdotdot\d+__$")
_r_partial_array = re.compile(r"\[\s*\.\.\.\s*\]")
_r_words = re.compile(r"\w+|\S")
_parser_cache = None
_r_int_literal = re.compile(r"-?0?x?[0-9a-f]+[lu]*$", re.IGNORECASE)
_r_stdcall1 = re.compile(r"\b(__stdcall|WINAPI)\b")
_r_stdcall2 = re.compile(r"[(]\s*(__stdcall|WINAPI)\b")
_r_cdecl = re.compile(r"\b__cdecl\b")
_r_extern_python = re.compile(r'\bextern\s*"'
r'(Python|Python\s*\+\s*C|C\s*\+\s*Python)"\s*.')
_r_star_const_space = re.compile( # matches "* const "
r"[*]\s*((const|volatile|restrict)\b\s*)+")
_r_int_dotdotdot = re.compile(r"(\b(int|long|short|signed|unsigned|char)\s*)+"
r"\.\.\.")
_r_float_dotdotdot = re.compile(r"\b(double|float)\s*\.\.\.")
def _get_parser():
global _parser_cache
if _parser_cache is None:
_parser_cache = pycparser.CParser()
return _parser_cache
def _workaround_for_old_pycparser(csource):
# Workaround for a pycparser issue (fixed between pycparser 2.10 and
# 2.14): "char*const***" gives us a wrong syntax tree, the same as
# for "char***(*const)". This means we can't tell the difference
# afterwards. But "char(*const(***))" gives us the right syntax
# tree. The issue only occurs if there are several stars in
# sequence with no parenthesis inbetween, just possibly qualifiers.
# Attempt to fix it by adding some parentheses in the source: each
# time we see "* const" or "* const *", we add an opening
# parenthesis before each star---the hard part is figuring out where
# to close them.
parts = []
while True:
match = _r_star_const_space.search(csource)
if not match:
break
#print repr(''.join(parts)+csource), '=>',
parts.append(csource[:match.start()])
parts.append('('); closing = ')'
parts.append(match.group()) # e.g. "* const "
endpos = match.end()
if csource.startswith('*', endpos):
parts.append('('); closing += ')'
level = 0
i = endpos
while i < len(csource):
c = csource[i]
if c == '(':
level += 1
elif c == ')':
if level == 0:
break
level -= 1
elif c in ',;=':
if level == 0:
break
i += 1
csource = csource[endpos:i] + closing + csource[i:]
#print repr(''.join(parts)+csource)
parts.append(csource)
return ''.join(parts)
def _preprocess_extern_python(csource):
# input: `extern "Python" int foo(int);` or
# `extern "Python" { int foo(int); }`
# output:
# void __cffi_extern_python_start;
# int foo(int);
# void __cffi_extern_python_stop;
#
# input: `extern "Python+C" int foo(int);`
# output:
# void __cffi_extern_python_plus_c_start;
# int foo(int);
# void __cffi_extern_python_stop;
parts = []
while True:
match = _r_extern_python.search(csource)
if not match:
break
endpos = match.end() - 1
#print
#print ''.join(parts)+csource
#print '=>'
parts.append(csource[:match.start()])
if 'C' in match.group(1):
parts.append('void __cffi_extern_python_plus_c_start; ')
else:
parts.append('void __cffi_extern_python_start; ')
if csource[endpos] == '{':
# grouping variant
closing = csource.find('}', endpos)
if closing < 0:
raise CDefError("'extern \"Python\" {': no '}' found")
if csource.find('{', endpos + 1, closing) >= 0:
raise NotImplementedError("cannot use { } inside a block "
"'extern \"Python\" { ... }'")
parts.append(csource[endpos+1:closing])
csource = csource[closing+1:]
else:
# non-grouping variant
semicolon = csource.find(';', endpos)
if semicolon < 0:
raise CDefError("'extern \"Python\": no ';' found")
parts.append(csource[endpos:semicolon+1])
csource = csource[semicolon+1:]
parts.append(' void __cffi_extern_python_stop;')
#print ''.join(parts)+csource
#print
parts.append(csource)
return ''.join(parts)
def _preprocess(csource):
# Remove comments. NOTE: this only work because the cdef() section
# should not contain any string literal!
csource = _r_comment.sub(' ', csource)
# Remove the "#define FOO x" lines
macros = {}
for match in _r_define.finditer(csource):
macroname, macrovalue = match.groups()
macrovalue = macrovalue.replace('\\\n', '').strip()
macros[macroname] = macrovalue
csource = _r_define.sub('', csource)
#
if pycparser.__version__ < '2.14':
csource = _workaround_for_old_pycparser(csource)
#
# BIG HACK: replace WINAPI or __stdcall with "volatile const".
# It doesn't make sense for the return type of a function to be
# "volatile volatile const", so we abuse it to detect __stdcall...
# Hack number 2 is that "int(volatile *fptr)();" is not valid C
# syntax, so we place the "volatile" before the opening parenthesis.
csource = _r_stdcall2.sub(' volatile volatile const(', csource)
csource = _r_stdcall1.sub(' volatile volatile const ', csource)
csource = _r_cdecl.sub(' ', csource)
#
# Replace `extern "Python"` with start/end markers
csource = _preprocess_extern_python(csource)
#
# Replace "[...]" with "[__dotdotdotarray__]"
csource = _r_partial_array.sub('[__dotdotdotarray__]', csource)
#
# Replace "...}" with "__dotdotdotNUM__}". This construction should
# occur only at the end of enums; at the end of structs we have "...;}"
# and at the end of vararg functions "...);". Also replace "=...[,}]"
# with ",__dotdotdotNUM__[,}]": this occurs in the enums too, when
# giving an unknown value.
matches = list(_r_partial_enum.finditer(csource))
for number, match in enumerate(reversed(matches)):
p = match.start()
if csource[p] == '=':
p2 = csource.find('...', p, match.end())
assert p2 > p
csource = '%s,__dotdotdot%d__ %s' % (csource[:p], number,
csource[p2+3:])
else:
assert csource[p:p+3] == '...'
csource = '%s __dotdotdot%d__ %s' % (csource[:p], number,
csource[p+3:])
# Replace "int ..." or "unsigned long int..." with "__dotdotdotint__"
csource = _r_int_dotdotdot.sub(' __dotdotdotint__ ', csource)
# Replace "float ..." or "double..." with "__dotdotdotfloat__"
csource = _r_float_dotdotdot.sub(' __dotdotdotfloat__ ', csource)
# Replace all remaining "..." with the same name, "__dotdotdot__",
# which is declared with a typedef for the purpose of C parsing.
return csource.replace('...', ' __dotdotdot__ '), macros
def _common_type_names(csource):
# Look in the source for what looks like usages of types from the
# list of common types. A "usage" is approximated here as the
# appearance of the word, minus a "definition" of the type, which
# is the last word in a "typedef" statement. Approximative only
# but should be fine for all the common types.
look_for_words = set(COMMON_TYPES)
look_for_words.add(';')
look_for_words.add(',')
look_for_words.add('(')
look_for_words.add(')')
look_for_words.add('typedef')
words_used = set()
is_typedef = False
paren = 0
previous_word = ''
for word in _r_words.findall(csource):
if word in look_for_words:
if word == ';':
if is_typedef:
words_used.discard(previous_word)
look_for_words.discard(previous_word)
is_typedef = False
elif word == 'typedef':
is_typedef = True
paren = 0
elif word == '(':
paren += 1
elif word == ')':
paren -= 1
elif word == ',':
if is_typedef and paren == 0:
words_used.discard(previous_word)
look_for_words.discard(previous_word)
else: # word in COMMON_TYPES
words_used.add(word)
previous_word = word
return words_used
class Parser(object):
def __init__(self):
self._declarations = {}
self._included_declarations = set()
self._anonymous_counter = 0
self._structnode2type = weakref.WeakKeyDictionary()
self._options = {}
self._int_constants = {}
self._recomplete = []
self._uses_new_feature = None
def _parse(self, csource):
csource, macros = _preprocess(csource)
# XXX: for more efficiency we would need to poke into the
# internals of CParser... the following registers the
# typedefs, because their presence or absence influences the
# parsing itself (but what they are typedef'ed to plays no role)
ctn = _common_type_names(csource)
typenames = []
for name in sorted(self._declarations):
if name.startswith('typedef '):
name = name[8:]
typenames.append(name)
ctn.discard(name)
typenames += sorted(ctn)
#
csourcelines = ['typedef int %s;' % typename for typename in typenames]
csourcelines.append('typedef int __dotdotdotint__, __dotdotdotfloat__,'
' __dotdotdot__;')
csourcelines.append(csource)
csource = '\n'.join(csourcelines)
if lock is not None:
lock.acquire() # pycparser is not thread-safe...
try:
ast = _get_parser().parse(csource)
except pycparser.c_parser.ParseError as e:
self.convert_pycparser_error(e, csource)
finally:
if lock is not None:
lock.release()
# csource will be used to find buggy source text
return ast, macros, csource
def _convert_pycparser_error(self, e, csource):
# xxx look for ":NUM:" at the start of str(e) and try to interpret
# it as a line number
line = None
msg = str(e)
if msg.startswith(':') and ':' in msg[1:]:
linenum = msg[1:msg.find(':',1)]
if linenum.isdigit():
linenum = int(linenum, 10)
csourcelines = csource.splitlines()
if 1 <= linenum <= len(csourcelines):
line = csourcelines[linenum-1]
return line
def convert_pycparser_error(self, e, csource):
line = self._convert_pycparser_error(e, csource)
msg = str(e)
if line:
msg = 'cannot parse "%s"\n%s' % (line.strip(), msg)
else:
msg = 'parse error\n%s' % (msg,)
raise CDefError(msg)
def parse(self, csource, override=False, packed=False, dllexport=False):
prev_options = self._options
try:
self._options = {'override': override,
'packed': packed,
'dllexport': dllexport}
self._internal_parse(csource)
finally:
self._options = prev_options
def _internal_parse(self, csource):
ast, macros, csource = self._parse(csource)
# add the macros
self._process_macros(macros)
# find the first "__dotdotdot__" and use that as a separator
# between the repeated typedefs and the real csource
iterator = iter(ast.ext)
for decl in iterator:
if decl.name == '__dotdotdot__':
break
else:
assert 0
#
try:
self._inside_extern_python = '__cffi_extern_python_stop'
for decl in iterator:
if isinstance(decl, pycparser.c_ast.Decl):
self._parse_decl(decl)
elif isinstance(decl, pycparser.c_ast.Typedef):
if not decl.name:
raise CDefError("typedef does not declare any name",
decl)
quals = 0
if (isinstance(decl.type.type, pycparser.c_ast.IdentifierType) and
decl.type.type.names[-1].startswith('__dotdotdot')):
realtype = self._get_unknown_type(decl)
elif (isinstance(decl.type, pycparser.c_ast.PtrDecl) and
isinstance(decl.type.type, pycparser.c_ast.TypeDecl) and
isinstance(decl.type.type.type,
pycparser.c_ast.IdentifierType) and
decl.type.type.type.names[-1].startswith('__dotdotdot')):
realtype = self._get_unknown_ptr_type(decl)
else:
realtype, quals = self._get_type_and_quals(
decl.type, name=decl.name, partial_length_ok=True)
self._declare('typedef ' + decl.name, realtype, quals=quals)
elif decl.__class__.__name__ == 'Pragma':
pass # skip pragma, only in pycparser 2.15
else:
raise CDefError("unrecognized construct", decl)
except FFIError as e:
msg = self._convert_pycparser_error(e, csource)
if msg:
e.args = (e.args[0] + "\n *** Err: %s" % msg,)
raise
def _add_constants(self, key, val):
if key in self._int_constants:
if self._int_constants[key] == val:
return # ignore identical double declarations
raise FFIError(
"multiple declarations of constant: %s" % (key,))
self._int_constants[key] = val
def _add_integer_constant(self, name, int_str):
int_str = int_str.lower().rstrip("ul")
neg = int_str.startswith('-')
if neg:
int_str = int_str[1:]
# "010" is not valid oct in py3
if (int_str.startswith("0") and int_str != '0'
and not int_str.startswith("0x")):
int_str = "0o" + int_str[1:]
pyvalue = int(int_str, 0)
if neg:
pyvalue = -pyvalue
self._add_constants(name, pyvalue)
self._declare('macro ' + name, pyvalue)
def _process_macros(self, macros):
for key, value in macros.items():
value = value.strip()
if _r_int_literal.match(value):
self._add_integer_constant(key, value)
elif value == '...':
self._declare('macro ' + key, value)
else:
raise CDefError(
'only supports one of the following syntax:\n'
' #define %s ... (literally dot-dot-dot)\n'
' #define %s NUMBER (with NUMBER an integer'
' constant, decimal/hex/octal)\n'
'got:\n'
' #define %s %s'
% (key, key, key, value))
def _declare_function(self, tp, quals, decl):
tp = self._get_type_pointer(tp, quals)
if self._options.get('dllexport'):
tag = 'dllexport_python '
elif self._inside_extern_python == '__cffi_extern_python_start':
tag = 'extern_python '
elif self._inside_extern_python == '__cffi_extern_python_plus_c_start':
tag = 'extern_python_plus_c '
else:
tag = 'function '
self._declare(tag + decl.name, tp)
def _parse_decl(self, decl):
node = decl.type
if isinstance(node, pycparser.c_ast.FuncDecl):
tp, quals = self._get_type_and_quals(node, name=decl.name)
assert isinstance(tp, model.RawFunctionType)
self._declare_function(tp, quals, decl)
else:
if isinstance(node, pycparser.c_ast.Struct):
self._get_struct_union_enum_type('struct', node)
elif isinstance(node, pycparser.c_ast.Union):
self._get_struct_union_enum_type('union', node)
elif isinstance(node, pycparser.c_ast.Enum):
self._get_struct_union_enum_type('enum', node)
elif not decl.name:
raise CDefError("construct does not declare any variable",
decl)
#
if decl.name:
tp, quals = self._get_type_and_quals(node,
partial_length_ok=True)
if tp.is_raw_function:
self._declare_function(tp, quals, decl)
elif (tp.is_integer_type() and
hasattr(decl, 'init') and
hasattr(decl.init, 'value') and
_r_int_literal.match(decl.init.value)):
self._add_integer_constant(decl.name, decl.init.value)
elif (tp.is_integer_type() and
isinstance(decl.init, pycparser.c_ast.UnaryOp) and
decl.init.op == '-' and
hasattr(decl.init.expr, 'value') and
_r_int_literal.match(decl.init.expr.value)):
self._add_integer_constant(decl.name,
'-' + decl.init.expr.value)
elif (tp is model.void_type and
decl.name.startswith('__cffi_extern_python_')):
# hack: `extern "Python"` in the C source is replaced
# with "void __cffi_extern_python_start;" and
# "void __cffi_extern_python_stop;"
self._inside_extern_python = decl.name
else:
if self._inside_extern_python !='__cffi_extern_python_stop':
raise CDefError(
"cannot declare constants or "
"variables with 'extern \"Python\"'")
if (quals & model.Q_CONST) and not tp.is_array_type:
self._declare('constant ' + decl.name, tp, quals=quals)
else:
self._declare('variable ' + decl.name, tp, quals=quals)
def parse_type(self, cdecl):
return self.parse_type_and_quals(cdecl)[0]
def parse_type_and_quals(self, cdecl):
ast, macros = self._parse('void __dummy(\n%s\n);' % cdecl)[:2]
assert not macros
exprnode = ast.ext[-1].type.args.params[0]
if isinstance(exprnode, pycparser.c_ast.ID):
raise CDefError("unknown identifier '%s'" % (exprnode.name,))
return self._get_type_and_quals(exprnode.type)
def _declare(self, name, obj, included=False, quals=0):
if name in self._declarations:
prevobj, prevquals = self._declarations[name]
if prevobj is obj and prevquals == quals:
return
if not self._options.get('override'):
raise FFIError(
"multiple declarations of %s (for interactive usage, "
"try cdef(xx, override=True))" % (name,))
assert '__dotdotdot__' not in name.split()
self._declarations[name] = (obj, quals)
if included:
self._included_declarations.add(obj)
def _extract_quals(self, type):
quals = 0
if isinstance(type, (pycparser.c_ast.TypeDecl,
pycparser.c_ast.PtrDecl)):
if 'const' in type.quals:
quals |= model.Q_CONST
if 'volatile' in type.quals:
quals |= model.Q_VOLATILE
if 'restrict' in type.quals:
quals |= model.Q_RESTRICT
return quals
def _get_type_pointer(self, type, quals, declname=None):
if isinstance(type, model.RawFunctionType):
return type.as_function_pointer()
if (isinstance(type, model.StructOrUnionOrEnum) and
type.name.startswith('$') and type.name[1:].isdigit() and
type.forcename is None and declname is not None):
return model.NamedPointerType(type, declname, quals)
return model.PointerType(type, quals)
def _get_type_and_quals(self, typenode, name=None, partial_length_ok=False):
# first, dereference typedefs, if we have it already parsed, we're good
if (isinstance(typenode, pycparser.c_ast.TypeDecl) and
isinstance(typenode.type, pycparser.c_ast.IdentifierType) and
len(typenode.type.names) == 1 and
('typedef ' + typenode.type.names[0]) in self._declarations):
tp, quals = self._declarations['typedef ' + typenode.type.names[0]]
quals |= self._extract_quals(typenode)
return tp, quals
#
if isinstance(typenode, pycparser.c_ast.ArrayDecl):
# array type
if typenode.dim is None:
length = None
else:
length = self._parse_constant(
typenode.dim, partial_length_ok=partial_length_ok)
tp, quals = self._get_type_and_quals(typenode.type,
partial_length_ok=partial_length_ok)
return model.ArrayType(tp, length), quals
#
if isinstance(typenode, pycparser.c_ast.PtrDecl):
# pointer type
itemtype, itemquals = self._get_type_and_quals(typenode.type)
tp = self._get_type_pointer(itemtype, itemquals, declname=name)
quals = self._extract_quals(typenode)
return tp, quals
#
if isinstance(typenode, pycparser.c_ast.TypeDecl):
quals = self._extract_quals(typenode)
type = typenode.type
if isinstance(type, pycparser.c_ast.IdentifierType):
# assume a primitive type. get it from .names, but reduce
# synonyms to a single chosen combination
names = list(type.names)
if names != ['signed', 'char']: # keep this unmodified
prefixes = {}
while names:
name = names[0]
if name in ('short', 'long', 'signed', 'unsigned'):
prefixes[name] = prefixes.get(name, 0) + 1
del names[0]
else:
break
# ignore the 'signed' prefix below, and reorder the others
newnames = []
for prefix in ('unsigned', 'short', 'long'):
for i in range(prefixes.get(prefix, 0)):
newnames.append(prefix)
if not names:
names = ['int'] # implicitly
if names == ['int']: # but kill it if 'short' or 'long'
if 'short' in prefixes or 'long' in prefixes:
names = []
names = newnames + names
ident = ' '.join(names)
if ident == 'void':
return model.void_type, quals
if ident == '__dotdotdot__':
raise FFIError(':%d: bad usage of "..."' %
typenode.coord.line)
tp0, quals0 = resolve_common_type(self, ident)
return tp0, (quals | quals0)
#
if isinstance(type, pycparser.c_ast.Struct):
# 'struct foobar'
tp = self._get_struct_union_enum_type('struct', type, name)
return tp, quals
#
if isinstance(type, pycparser.c_ast.Union):
# 'union foobar'
tp = self._get_struct_union_enum_type('union', type, name)
return tp, quals
#
if isinstance(type, pycparser.c_ast.Enum):
# 'enum foobar'
tp = self._get_struct_union_enum_type('enum', type, name)
return tp, quals
#
if isinstance(typenode, pycparser.c_ast.FuncDecl):
# a function type
return self._parse_function_type(typenode, name), 0
#
# nested anonymous structs or unions end up here
if isinstance(typenode, pycparser.c_ast.Struct):
return self._get_struct_union_enum_type('struct', typenode, name,
nested=True), 0
if isinstance(typenode, pycparser.c_ast.Union):
return self._get_struct_union_enum_type('union', typenode, name,
nested=True), 0
#
raise FFIError(":%d: bad or unsupported type declaration" %
typenode.coord.line)
def _parse_function_type(self, typenode, funcname=None):
params = list(getattr(typenode.args, 'params', []))
for i, arg in enumerate(params):
if not hasattr(arg, 'type'):
raise CDefError("%s arg %d: unknown type '%s'"
" (if you meant to use the old C syntax of giving"
" untyped arguments, it is not supported)"
% (funcname or 'in expression', i + 1,
getattr(arg, 'name', '?')))
ellipsis = (
len(params) > 0 and
isinstance(params[-1].type, pycparser.c_ast.TypeDecl) and
isinstance(params[-1].type.type,
pycparser.c_ast.IdentifierType) and
params[-1].type.type.names == ['__dotdotdot__'])
if ellipsis:
params.pop()
if not params:
raise CDefError(
"%s: a function with only '(...)' as argument"
" is not correct C" % (funcname or 'in expression'))
args = [self._as_func_arg(*self._get_type_and_quals(argdeclnode.type))
for argdeclnode in params]
if not ellipsis and args == [model.void_type]:
args = []
result, quals = self._get_type_and_quals(typenode.type)
# the 'quals' on the result type are ignored. HACK: we absure them
# to detect __stdcall functions: we textually replace "__stdcall"
# with "volatile volatile const" above.
abi = None
if hasattr(typenode.type, 'quals'): # else, probable syntax error anyway
if typenode.type.quals[-3:] == ['volatile', 'volatile', 'const']:
abi = '__stdcall'
return model.RawFunctionType(tuple(args), result, ellipsis, abi)
def _as_func_arg(self, type, quals):
if isinstance(type, model.ArrayType):
return model.PointerType(type.item, quals)
elif isinstance(type, model.RawFunctionType):
return type.as_function_pointer()
else:
return type
def _get_struct_union_enum_type(self, kind, type, name=None, nested=False):
# First, a level of caching on the exact 'type' node of the AST.
# This is obscure, but needed because pycparser "unrolls" declarations
# such as "typedef struct { } foo_t, *foo_p" and we end up with
# an AST that is not a tree, but a DAG, with the "type" node of the
# two branches foo_t and foo_p of the trees being the same node.
# It's a bit silly but detecting "DAG-ness" in the AST tree seems
# to be the only way to distinguish this case from two independent
# structs. See test_struct_with_two_usages.
try:
return self._structnode2type[type]
except KeyError:
pass
#
# Note that this must handle parsing "struct foo" any number of
# times and always return the same StructType object. Additionally,
# one of these times (not necessarily the first), the fields of
# the struct can be specified with "struct foo { ...fields... }".
# If no name is given, then we have to create a new anonymous struct
# with no caching; in this case, the fields are either specified
# right now or never.
#
force_name = name
name = type.name
#
# get the type or create it if needed
if name is None:
# 'force_name' is used to guess a more readable name for
# anonymous structs, for the common case "typedef struct { } foo".
if force_name is not None:
explicit_name = '$%s' % force_name
else:
self._anonymous_counter += 1
explicit_name = '$%d' % self._anonymous_counter
tp = None
else:
explicit_name = name
key = '%s %s' % (kind, name)
tp, _ = self._declarations.get(key, (None, None))
#
if tp is None:
if kind == 'struct':
tp = model.StructType(explicit_name, None, None, None)
elif kind == 'union':
tp = model.UnionType(explicit_name, None, None, None)
elif kind == 'enum':
if explicit_name == '__dotdotdot__':
raise CDefError("Enums cannot be declared with ...")
tp = self._build_enum_type(explicit_name, type.values)
else:
raise AssertionError("kind = %r" % (kind,))
if name is not None:
self._declare(key, tp)
else:
if kind == 'enum' and type.values is not None:
raise NotImplementedError(
"enum %s: the '{}' declaration should appear on the first "
"time the enum is mentioned, not later" % explicit_name)
if not tp.forcename:
tp.force_the_name(force_name)
if tp.forcename and '$' in tp.name:
self._declare('anonymous %s' % tp.forcename, tp)
#
self._structnode2type[type] = tp
#
# enums: done here
if kind == 'enum':
return tp
#
# is there a 'type.decls'? If yes, then this is the place in the
# C sources that declare the fields. If no, then just return the
# existing type, possibly still incomplete.
if type.decls is None:
return tp
#
if tp.fldnames is not None:
raise CDefError("duplicate declaration of struct %s" % name)
fldnames = []
fldtypes = []
fldbitsize = []
fldquals = []
for decl in type.decls:
if (isinstance(decl.type, pycparser.c_ast.IdentifierType) and
''.join(decl.type.names) == '__dotdotdot__'):
# XXX pycparser is inconsistent: 'names' should be a list
# of strings, but is sometimes just one string. Use
# str.join() as a way to cope with both.
self._make_partial(tp, nested)
continue
if decl.bitsize is None:
bitsize = -1
else:
bitsize = self._parse_constant(decl.bitsize)
self._partial_length = False
type, fqual = self._get_type_and_quals(decl.type,
partial_length_ok=True)
if self._partial_length:
self._make_partial(tp, nested)
if isinstance(type, model.StructType) and type.partial:
self._make_partial(tp, nested)
fldnames.append(decl.name or '')
fldtypes.append(type)
fldbitsize.append(bitsize)
fldquals.append(fqual)
tp.fldnames = tuple(fldnames)
tp.fldtypes = tuple(fldtypes)
tp.fldbitsize = tuple(fldbitsize)
tp.fldquals = tuple(fldquals)
if fldbitsize != [-1] * len(fldbitsize):
if isinstance(tp, model.StructType) and tp.partial:
raise NotImplementedError("%s: using both bitfields and '...;'"
% (tp,))
tp.packed = self._options.get('packed')
if tp.completed: # must be re-completed: it is not opaque any more
tp.completed = 0
self._recomplete.append(tp)
return tp
def _make_partial(self, tp, nested):
if not isinstance(tp, model.StructOrUnion):
raise CDefError("%s cannot be partial" % (tp,))
if not tp.has_c_name() and not nested:
raise NotImplementedError("%s is partial but has no C name" %(tp,))
tp.partial = True
def _parse_constant(self, exprnode, partial_length_ok=False):
# for now, limited to expressions that are an immediate number
# or positive/negative number
if isinstance(exprnode, pycparser.c_ast.Constant):
s = exprnode.value
if s.startswith('0'):
if s.startswith('0x') or s.startswith('0X'):
return int(s, 16)
return int(s, 8)
elif '1' <= s[0] <= '9':
return int(s, 10)
elif s[0] == "'" and s[-1] == "'" and (
len(s) == 3 or (len(s) == 4 and s[1] == "\\")):
return ord(s[-2])
else:
raise CDefError("invalid constant %r" % (s,))
#
if (isinstance(exprnode, pycparser.c_ast.UnaryOp) and
exprnode.op == '+'):
return self._parse_constant(exprnode.expr)
#
if (isinstance(exprnode, pycparser.c_ast.UnaryOp) and
exprnode.op == '-'):
return -self._parse_constant(exprnode.expr)
# load previously defined int constant
if (isinstance(exprnode, pycparser.c_ast.ID) and
exprnode.name in self._int_constants):
return self._int_constants[exprnode.name]
#
if (isinstance(exprnode, pycparser.c_ast.ID) and
exprnode.name == '__dotdotdotarray__'):
if partial_length_ok:
self._partial_length = True
return '...'
raise FFIError(":%d: unsupported '[...]' here, cannot derive "
"the actual array length in this context"
% exprnode.coord.line)
#
if (isinstance(exprnode, pycparser.c_ast.BinaryOp) and
exprnode.op == '+'):
return (self._parse_constant(exprnode.left) +
self._parse_constant(exprnode.right))
#
if (isinstance(exprnode, pycparser.c_ast.BinaryOp) and
exprnode.op == '-'):
return (self._parse_constant(exprnode.left) -
self._parse_constant(exprnode.right))
#
raise FFIError(":%d: unsupported expression: expected a "
"simple numeric constant" % exprnode.coord.line)
def _build_enum_type(self, explicit_name, decls):
if decls is not None:
partial = False
enumerators = []
enumvalues = []
nextenumvalue = 0
for enum in decls.enumerators:
if _r_enum_dotdotdot.match(enum.name):
partial = True
continue
if enum.value is not None:
nextenumvalue = self._parse_constant(enum.value)
enumerators.append(enum.name)
enumvalues.append(nextenumvalue)
self._add_constants(enum.name, nextenumvalue)
nextenumvalue += 1
enumerators = tuple(enumerators)
enumvalues = tuple(enumvalues)
tp = model.EnumType(explicit_name, enumerators, enumvalues)
tp.partial = partial
else: # opaque enum
tp = model.EnumType(explicit_name, (), ())
return tp
def include(self, other):
for name, (tp, quals) in other._declarations.items():
if name.startswith('anonymous $enum_$'):
continue # fix for test_anonymous_enum_include
kind = name.split(' ', 1)[0]
if kind in ('struct', 'union', 'enum', 'anonymous', 'typedef'):
self._declare(name, tp, included=True, quals=quals)
for k, v in other._int_constants.items():
self._add_constants(k, v)
def _get_unknown_type(self, decl):
typenames = decl.type.type.names
if typenames == ['__dotdotdot__']:
return model.unknown_type(decl.name)
if typenames == ['__dotdotdotint__']:
if self._uses_new_feature is None:
self._uses_new_feature = "'typedef int... %s'" % decl.name
return model.UnknownIntegerType(decl.name)
if typenames == ['__dotdotdotfloat__']:
# note: not for 'long double' so far
if self._uses_new_feature is None:
self._uses_new_feature = "'typedef float... %s'" % decl.name
return model.UnknownFloatType(decl.name)
raise FFIError(':%d: unsupported usage of "..." in typedef'
% decl.coord.line)
def _get_unknown_ptr_type(self, decl):
if decl.type.type.type.names == ['__dotdotdot__']:
return model.unknown_ptr_type(decl.name)
raise FFIError(':%d: unsupported usage of "..." in typedef'
% decl.coord.line)

View File

@ -0,0 +1,20 @@
class FFIError(Exception):
pass
class CDefError(Exception):
def __str__(self):
try:
line = 'line %d: ' % (self.args[1].coord.line,)
except (AttributeError, TypeError, IndexError):
line = ''
return '%s%s' % (line, self.args[0])
class VerificationError(Exception):
""" An error raised when verification fails
"""
class VerificationMissing(Exception):
""" An error raised when incomplete structures are passed into
cdef, but no verification has been done
"""

View File

@ -0,0 +1,115 @@
import sys, os
from .error import VerificationError
LIST_OF_FILE_NAMES = ['sources', 'include_dirs', 'library_dirs',
'extra_objects', 'depends']
def get_extension(srcfilename, modname, sources=(), **kwds):
from distutils.core import Extension
allsources = [srcfilename]
for src in sources:
allsources.append(os.path.normpath(src))
return Extension(name=modname, sources=allsources, **kwds)
def compile(tmpdir, ext, compiler_verbose=0, debug=None):
"""Compile a C extension module using distutils."""
saved_environ = os.environ.copy()
try:
outputfilename = _build(tmpdir, ext, compiler_verbose, debug)
outputfilename = os.path.abspath(outputfilename)
finally:
# workaround for a distutils bugs where some env vars can
# become longer and longer every time it is used
for key, value in saved_environ.items():
if os.environ.get(key) != value:
os.environ[key] = value
return outputfilename
def _build(tmpdir, ext, compiler_verbose=0, debug=None):
# XXX compact but horrible :-(
from distutils.core import Distribution
import distutils.errors, distutils.log
#
dist = Distribution({'ext_modules': [ext]})
dist.parse_config_files()
options = dist.get_option_dict('build_ext')
if debug is None:
debug = sys.flags.debug
options['debug'] = ('ffiplatform', debug)
options['force'] = ('ffiplatform', True)
options['build_lib'] = ('ffiplatform', tmpdir)
options['build_temp'] = ('ffiplatform', tmpdir)
#
try:
old_level = distutils.log.set_threshold(0) or 0
try:
distutils.log.set_verbosity(compiler_verbose)
dist.run_command('build_ext')
cmd_obj = dist.get_command_obj('build_ext')
[soname] = cmd_obj.get_outputs()
finally:
distutils.log.set_threshold(old_level)
except (distutils.errors.CompileError,
distutils.errors.LinkError) as e:
raise VerificationError('%s: %s' % (e.__class__.__name__, e))
#
return soname
try:
from os.path import samefile
except ImportError:
def samefile(f1, f2):
return os.path.abspath(f1) == os.path.abspath(f2)
def maybe_relative_path(path):
if not os.path.isabs(path):
return path # already relative
dir = path
names = []
while True:
prevdir = dir
dir, name = os.path.split(prevdir)
if dir == prevdir or not dir:
return path # failed to make it relative
names.append(name)
try:
if samefile(dir, os.curdir):
names.reverse()
return os.path.join(*names)
except OSError:
pass
# ____________________________________________________________
try:
int_or_long = (int, long)
import cStringIO
except NameError:
int_or_long = int # Python 3
import io as cStringIO
def _flatten(x, f):
if isinstance(x, str):
f.write('%ds%s' % (len(x), x))
elif isinstance(x, dict):
keys = sorted(x.keys())
f.write('%dd' % len(keys))
for key in keys:
_flatten(key, f)
_flatten(x[key], f)
elif isinstance(x, (list, tuple)):
f.write('%dl' % len(x))
for value in x:
_flatten(value, f)
elif isinstance(x, int_or_long):
f.write('%di' % (x,))
else:
raise TypeError(
"the keywords to verify() contains unsupported object %r" % (x,))
def flatten(x):
f = cStringIO.StringIO()
_flatten(x, f)
return f.getvalue()

Some files were not shown because too many files have changed in this diff Show More