Add in bot, personify for RBXLegacy l8r
Oh also it's pre installed with all requirements 😃
This commit is contained in:
parent
4344ad3582
commit
8fda0bc62e
|
|
@ -0,0 +1,8 @@
|
|||
*.json
|
||||
*.pyc
|
||||
__pycache__
|
||||
data
|
||||
!data/trivia/*
|
||||
!data/audio/playlists/*
|
||||
*.exe
|
||||
*.dll
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
language: python
|
||||
python:
|
||||
- "3.5.2"
|
||||
install:
|
||||
- pip install -r requirements.txt
|
||||
script:
|
||||
- python -m compileall ./red.py
|
||||
- python -m compileall ./cogs
|
||||
- python ./red.py --no-prompt --no-cogs --dry-run
|
||||
cache: pip
|
||||
notifications:
|
||||
email: false
|
||||
|
|
@ -0,0 +1,191 @@
|
|||
from discord.ext import commands
|
||||
from .utils.chat_formatting import box
|
||||
from .utils.dataIO import dataIO
|
||||
from .utils import checks
|
||||
from __main__ import user_allowed, send_cmd_help
|
||||
from copy import deepcopy
|
||||
import os
|
||||
import discord
|
||||
|
||||
|
||||
class Alias:
|
||||
def __init__(self, bot):
|
||||
self.bot = bot
|
||||
self.file_path = "data/alias/aliases.json"
|
||||
self.aliases = dataIO.load_json(self.file_path)
|
||||
self.remove_old()
|
||||
|
||||
@commands.group(pass_context=True, no_pm=True)
|
||||
async def alias(self, ctx):
|
||||
"""Manage per-server aliases for commands"""
|
||||
if ctx.invoked_subcommand is None:
|
||||
await send_cmd_help(ctx)
|
||||
|
||||
@alias.command(name="add", pass_context=True, no_pm=True)
|
||||
@checks.mod_or_permissions(manage_server=True)
|
||||
async def _add_alias(self, ctx, command, *, to_execute):
|
||||
"""Add an alias for a command
|
||||
|
||||
Example: !alias add test flip @Twentysix"""
|
||||
server = ctx.message.server
|
||||
command = command.lower()
|
||||
if len(command.split(" ")) != 1:
|
||||
await self.bot.say("I can't safely do multi-word aliases because"
|
||||
" of the fact that I allow arguments to"
|
||||
" aliases. It sucks, I know, deal with it.")
|
||||
return
|
||||
if self.part_of_existing_command(command, server.id):
|
||||
await self.bot.say('I can\'t safely add an alias that starts with '
|
||||
'an existing command or alias. Sry <3')
|
||||
return
|
||||
prefix = self.get_prefix(server, to_execute)
|
||||
if prefix is not None:
|
||||
to_execute = to_execute[len(prefix):]
|
||||
if server.id not in self.aliases:
|
||||
self.aliases[server.id] = {}
|
||||
if command not in self.bot.commands:
|
||||
self.aliases[server.id][command] = to_execute
|
||||
dataIO.save_json(self.file_path, self.aliases)
|
||||
await self.bot.say("Alias '{}' added.".format(command))
|
||||
else:
|
||||
await self.bot.say("Cannot add '{}' because it's a real bot "
|
||||
"command.".format(command))
|
||||
|
||||
@alias.command(name="help", pass_context=True, no_pm=True)
|
||||
async def _help_alias(self, ctx, command):
|
||||
"""Tries to execute help for the base command of the alias"""
|
||||
server = ctx.message.server
|
||||
if server.id in self.aliases:
|
||||
server_aliases = self.aliases[server.id]
|
||||
if command in server_aliases:
|
||||
help_cmd = server_aliases[command].split(" ")[0]
|
||||
new_content = self.bot.settings.get_prefixes(server)[0]
|
||||
new_content += "help "
|
||||
new_content += help_cmd[len(self.get_prefix(server,
|
||||
help_cmd)):]
|
||||
message = ctx.message
|
||||
message.content = new_content
|
||||
await self.bot.process_commands(message)
|
||||
else:
|
||||
await self.bot.say("That alias doesn't exist.")
|
||||
|
||||
@alias.command(name="show", pass_context=True, no_pm=True)
|
||||
async def _show_alias(self, ctx, command):
|
||||
"""Shows what command the alias executes."""
|
||||
server = ctx.message.server
|
||||
if server.id in self.aliases:
|
||||
server_aliases = self.aliases[server.id]
|
||||
if command in server_aliases:
|
||||
await self.bot.say(box(server_aliases[command]))
|
||||
else:
|
||||
await self.bot.say("That alias doesn't exist.")
|
||||
|
||||
@alias.command(name="del", pass_context=True, no_pm=True)
|
||||
@checks.mod_or_permissions(manage_server=True)
|
||||
async def _del_alias(self, ctx, command):
|
||||
"""Deletes an alias"""
|
||||
command = command.lower()
|
||||
server = ctx.message.server
|
||||
if server.id in self.aliases:
|
||||
self.aliases[server.id].pop(command, None)
|
||||
dataIO.save_json(self.file_path, self.aliases)
|
||||
await self.bot.say("Alias '{}' deleted.".format(command))
|
||||
|
||||
@alias.command(name="list", pass_context=True, no_pm=True)
|
||||
async def _alias_list(self, ctx):
|
||||
"""Lists aliases available on this server
|
||||
|
||||
Responds in DM"""
|
||||
server = ctx.message.server
|
||||
if server.id in self.aliases:
|
||||
message = "```Alias list:\n"
|
||||
for alias in sorted(self.aliases[server.id]):
|
||||
if len(message) + len(alias) + 3 > 2000:
|
||||
await self.bot.whisper(message)
|
||||
message = "```\n"
|
||||
message += "\t{}\n".format(alias)
|
||||
if message != "```Alias list:\n":
|
||||
message += "```"
|
||||
await self.bot.whisper(message)
|
||||
else:
|
||||
await self.bot.say("There are no aliases on this server.")
|
||||
|
||||
async def on_message(self, message):
|
||||
if len(message.content) < 2 or message.channel.is_private:
|
||||
return
|
||||
|
||||
msg = message.content
|
||||
server = message.server
|
||||
prefix = self.get_prefix(server, msg)
|
||||
|
||||
if not prefix:
|
||||
return
|
||||
|
||||
if server.id in self.aliases and user_allowed(message):
|
||||
alias = self.first_word(msg[len(prefix):]).lower()
|
||||
if alias in self.aliases[server.id]:
|
||||
new_command = self.aliases[server.id][alias]
|
||||
args = message.content[len(prefix + alias):]
|
||||
new_message = deepcopy(message)
|
||||
new_message.content = prefix + new_command + args
|
||||
await self.bot.process_commands(new_message)
|
||||
|
||||
def part_of_existing_command(self, alias, server):
|
||||
'''Command or alias'''
|
||||
for command in self.bot.commands:
|
||||
if alias.lower() == command.lower():
|
||||
return True
|
||||
return False
|
||||
|
||||
def remove_old(self):
|
||||
for sid in self.aliases:
|
||||
to_delete = []
|
||||
to_add = []
|
||||
for aliasname, alias in self.aliases[sid].items():
|
||||
lower = aliasname.lower()
|
||||
if aliasname != lower:
|
||||
to_delete.append(aliasname)
|
||||
to_add.append((lower, alias))
|
||||
if aliasname != self.first_word(aliasname):
|
||||
to_delete.append(aliasname)
|
||||
continue
|
||||
server = discord.Object(id=sid)
|
||||
prefix = self.get_prefix(server, alias)
|
||||
if prefix is not None:
|
||||
self.aliases[sid][aliasname] = alias[len(prefix):]
|
||||
for alias in to_delete: # Fixes caps and bad prefixes
|
||||
del self.aliases[sid][alias]
|
||||
for alias, command in to_add: # For fixing caps
|
||||
self.aliases[sid][alias] = command
|
||||
dataIO.save_json(self.file_path, self.aliases)
|
||||
|
||||
def first_word(self, msg):
|
||||
return msg.split(" ")[0]
|
||||
|
||||
def get_prefix(self, server, msg):
|
||||
prefixes = self.bot.settings.get_prefixes(server)
|
||||
for p in prefixes:
|
||||
if msg.startswith(p):
|
||||
return p
|
||||
return None
|
||||
|
||||
|
||||
def check_folder():
|
||||
if not os.path.exists("data/alias"):
|
||||
print("Creating data/alias folder...")
|
||||
os.makedirs("data/alias")
|
||||
|
||||
|
||||
def check_file():
|
||||
aliases = {}
|
||||
|
||||
f = "data/alias/aliases.json"
|
||||
if not dataIO.is_valid_json(f):
|
||||
print("Creating default alias's aliases.json...")
|
||||
dataIO.save_json(f, aliases)
|
||||
|
||||
|
||||
def setup(bot):
|
||||
check_folder()
|
||||
check_file()
|
||||
bot.add_cog(Alias(bot))
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,200 @@
|
|||
from discord.ext import commands
|
||||
from .utils.dataIO import dataIO
|
||||
from .utils import checks
|
||||
from .utils.chat_formatting import pagify, box
|
||||
import os
|
||||
import re
|
||||
|
||||
|
||||
class CustomCommands:
|
||||
"""Custom commands
|
||||
|
||||
Creates commands used to display text"""
|
||||
|
||||
def __init__(self, bot):
|
||||
self.bot = bot
|
||||
self.file_path = "data/customcom/commands.json"
|
||||
self.c_commands = dataIO.load_json(self.file_path)
|
||||
|
||||
@commands.group(aliases=["cc"], pass_context=True, no_pm=True)
|
||||
async def customcom(self, ctx):
|
||||
"""Custom commands management"""
|
||||
if ctx.invoked_subcommand is None:
|
||||
await self.bot.send_cmd_help(ctx)
|
||||
|
||||
@customcom.command(name="add", pass_context=True)
|
||||
@checks.mod_or_permissions(administrator=True)
|
||||
async def cc_add(self, ctx, command : str, *, text):
|
||||
"""Adds a custom command
|
||||
|
||||
Example:
|
||||
[p]customcom add yourcommand Text you want
|
||||
|
||||
CCs can be enhanced with arguments:
|
||||
https://twentysix26.github.io/Red-Docs/red_guide_command_args/
|
||||
"""
|
||||
server = ctx.message.server
|
||||
command = command.lower()
|
||||
if command in self.bot.commands:
|
||||
await self.bot.say("That command is already a standard command.")
|
||||
return
|
||||
if server.id not in self.c_commands:
|
||||
self.c_commands[server.id] = {}
|
||||
cmdlist = self.c_commands[server.id]
|
||||
if command not in cmdlist:
|
||||
cmdlist[command] = text
|
||||
self.c_commands[server.id] = cmdlist
|
||||
dataIO.save_json(self.file_path, self.c_commands)
|
||||
await self.bot.say("Custom command successfully added.")
|
||||
else:
|
||||
await self.bot.say("This command already exists. Use "
|
||||
"`{}customcom edit` to edit it."
|
||||
"".format(ctx.prefix))
|
||||
|
||||
@customcom.command(name="edit", pass_context=True)
|
||||
@checks.mod_or_permissions(administrator=True)
|
||||
async def cc_edit(self, ctx, command : str, *, text):
|
||||
"""Edits a custom command
|
||||
|
||||
Example:
|
||||
[p]customcom edit yourcommand Text you want
|
||||
"""
|
||||
server = ctx.message.server
|
||||
command = command.lower()
|
||||
if server.id in self.c_commands:
|
||||
cmdlist = self.c_commands[server.id]
|
||||
if command in cmdlist:
|
||||
cmdlist[command] = text
|
||||
self.c_commands[server.id] = cmdlist
|
||||
dataIO.save_json(self.file_path, self.c_commands)
|
||||
await self.bot.say("Custom command successfully edited.")
|
||||
else:
|
||||
await self.bot.say("That command doesn't exist. Use "
|
||||
"`{}customcom add` to add it."
|
||||
"".format(ctx.prefix))
|
||||
else:
|
||||
await self.bot.say("There are no custom commands in this server."
|
||||
" Use `{}customcom add` to start adding some."
|
||||
"".format(ctx.prefix))
|
||||
|
||||
@customcom.command(name="delete", pass_context=True)
|
||||
@checks.mod_or_permissions(administrator=True)
|
||||
async def cc_delete(self, ctx, command : str):
|
||||
"""Deletes a custom command
|
||||
|
||||
Example:
|
||||
[p]customcom delete yourcommand"""
|
||||
server = ctx.message.server
|
||||
command = command.lower()
|
||||
if server.id in self.c_commands:
|
||||
cmdlist = self.c_commands[server.id]
|
||||
if command in cmdlist:
|
||||
cmdlist.pop(command, None)
|
||||
self.c_commands[server.id] = cmdlist
|
||||
dataIO.save_json(self.file_path, self.c_commands)
|
||||
await self.bot.say("Custom command successfully deleted.")
|
||||
else:
|
||||
await self.bot.say("That command doesn't exist.")
|
||||
else:
|
||||
await self.bot.say("There are no custom commands in this server."
|
||||
" Use `{}customcom add` to start adding some."
|
||||
"".format(ctx.prefix))
|
||||
|
||||
@customcom.command(name="list", pass_context=True)
|
||||
async def cc_list(self, ctx):
|
||||
"""Shows custom commands list"""
|
||||
server = ctx.message.server
|
||||
commands = self.c_commands.get(server.id, {})
|
||||
|
||||
if not commands:
|
||||
await self.bot.say("There are no custom commands in this server."
|
||||
" Use `{}customcom add` to start adding some."
|
||||
"".format(ctx.prefix))
|
||||
return
|
||||
|
||||
commands = ", ".join([ctx.prefix + c for c in sorted(commands)])
|
||||
commands = "Custom commands:\n\n" + commands
|
||||
|
||||
if len(commands) < 1500:
|
||||
await self.bot.say(box(commands))
|
||||
else:
|
||||
for page in pagify(commands, delims=[" ", "\n"]):
|
||||
await self.bot.whisper(box(page))
|
||||
|
||||
async def on_message(self, message):
|
||||
if len(message.content) < 2 or message.channel.is_private:
|
||||
return
|
||||
|
||||
server = message.server
|
||||
prefix = self.get_prefix(message)
|
||||
|
||||
if not prefix:
|
||||
return
|
||||
|
||||
if server.id in self.c_commands and self.bot.user_allowed(message):
|
||||
cmdlist = self.c_commands[server.id]
|
||||
cmd = message.content[len(prefix):]
|
||||
if cmd in cmdlist:
|
||||
cmd = cmdlist[cmd]
|
||||
cmd = self.format_cc(cmd, message)
|
||||
await self.bot.send_message(message.channel, cmd)
|
||||
elif cmd.lower() in cmdlist:
|
||||
cmd = cmdlist[cmd.lower()]
|
||||
cmd = self.format_cc(cmd, message)
|
||||
await self.bot.send_message(message.channel, cmd)
|
||||
|
||||
def get_prefix(self, message):
|
||||
for p in self.bot.settings.get_prefixes(message.server):
|
||||
if message.content.startswith(p):
|
||||
return p
|
||||
return False
|
||||
|
||||
def format_cc(self, command, message):
|
||||
results = re.findall("\{([^}]+)\}", command)
|
||||
for result in results:
|
||||
param = self.transform_parameter(result, message)
|
||||
command = command.replace("{" + result + "}", param)
|
||||
return command
|
||||
|
||||
def transform_parameter(self, result, message):
|
||||
"""
|
||||
For security reasons only specific objects are allowed
|
||||
Internals are ignored
|
||||
"""
|
||||
raw_result = "{" + result + "}"
|
||||
objects = {
|
||||
"message" : message,
|
||||
"author" : message.author,
|
||||
"channel" : message.channel,
|
||||
"server" : message.server
|
||||
}
|
||||
if result in objects:
|
||||
return str(objects[result])
|
||||
try:
|
||||
first, second = result.split(".")
|
||||
except ValueError:
|
||||
return raw_result
|
||||
if first in objects and not second.startswith("_"):
|
||||
first = objects[first]
|
||||
else:
|
||||
return raw_result
|
||||
return str(getattr(first, second, raw_result))
|
||||
|
||||
|
||||
def check_folders():
|
||||
if not os.path.exists("data/customcom"):
|
||||
print("Creating data/customcom folder...")
|
||||
os.makedirs("data/customcom")
|
||||
|
||||
|
||||
def check_files():
|
||||
f = "data/customcom/commands.json"
|
||||
if not dataIO.is_valid_json(f):
|
||||
print("Creating empty commands.json...")
|
||||
dataIO.save_json(f, {})
|
||||
|
||||
|
||||
def setup(bot):
|
||||
check_folders()
|
||||
check_files()
|
||||
bot.add_cog(CustomCommands(bot))
|
||||
|
|
@ -0,0 +1,693 @@
|
|||
from discord.ext import commands
|
||||
from cogs.utils.dataIO import dataIO
|
||||
from cogs.utils import checks
|
||||
from cogs.utils.chat_formatting import pagify, box
|
||||
from __main__ import send_cmd_help, set_cog
|
||||
import os
|
||||
from subprocess import run as sp_run, PIPE
|
||||
import shutil
|
||||
from asyncio import as_completed
|
||||
from setuptools import distutils
|
||||
import discord
|
||||
from functools import partial
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from time import time
|
||||
from importlib.util import find_spec
|
||||
from copy import deepcopy
|
||||
|
||||
NUM_THREADS = 4
|
||||
REPO_NONEX = 0x1
|
||||
REPO_CLONE = 0x2
|
||||
REPO_SAME = 0x4
|
||||
REPOS_LIST = "https://twentysix26.github.io/Red-Docs/red_cog_approved_repos/"
|
||||
|
||||
DISCLAIMER = ("You're about to add a 3rd party repository. The creator of Red"
|
||||
" and its community have no responsibility for any potential "
|
||||
"damage that the content of 3rd party repositories might cause."
|
||||
"\nBy typing 'I agree' you declare to have read and understand "
|
||||
"the above message. This message won't be shown again until the"
|
||||
" next reboot.")
|
||||
|
||||
|
||||
class UpdateError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class CloningError(UpdateError):
|
||||
pass
|
||||
|
||||
|
||||
class RequirementFail(UpdateError):
|
||||
pass
|
||||
|
||||
|
||||
class Downloader:
|
||||
"""Cog downloader/installer."""
|
||||
|
||||
def __init__(self, bot):
|
||||
self.bot = bot
|
||||
self.disclaimer_accepted = False
|
||||
self.path = os.path.join("data", "downloader")
|
||||
self.file_path = os.path.join(self.path, "repos.json")
|
||||
# {name:{url,cog1:{installed},cog1:{installed}}}
|
||||
self.repos = dataIO.load_json(self.file_path)
|
||||
self.executor = ThreadPoolExecutor(NUM_THREADS)
|
||||
self._do_first_run()
|
||||
|
||||
def save_repos(self):
|
||||
dataIO.save_json(self.file_path, self.repos)
|
||||
|
||||
@commands.group(pass_context=True)
|
||||
@checks.is_owner()
|
||||
async def cog(self, ctx):
|
||||
"""Additional cogs management"""
|
||||
if ctx.invoked_subcommand is None:
|
||||
await send_cmd_help(ctx)
|
||||
|
||||
@cog.group(pass_context=True)
|
||||
async def repo(self, ctx):
|
||||
"""Repo management commands"""
|
||||
if ctx.invoked_subcommand is None or \
|
||||
isinstance(ctx.invoked_subcommand, commands.Group):
|
||||
await send_cmd_help(ctx)
|
||||
return
|
||||
|
||||
@repo.command(name="add", pass_context=True)
|
||||
async def _repo_add(self, ctx, repo_name: str, repo_url: str):
|
||||
"""Adds repo to available repo lists
|
||||
|
||||
Warning: Adding 3RD Party Repositories is at your own
|
||||
Risk."""
|
||||
if not self.disclaimer_accepted:
|
||||
await self.bot.say(DISCLAIMER)
|
||||
answer = await self.bot.wait_for_message(timeout=30,
|
||||
author=ctx.message.author)
|
||||
if answer is None:
|
||||
await self.bot.say('Not adding repo.')
|
||||
return
|
||||
elif "i agree" not in answer.content.lower():
|
||||
await self.bot.say('Not adding repo.')
|
||||
return
|
||||
else:
|
||||
self.disclaimer_accepted = True
|
||||
self.repos[repo_name] = {}
|
||||
self.repos[repo_name]['url'] = repo_url
|
||||
try:
|
||||
self.update_repo(repo_name)
|
||||
except CloningError:
|
||||
await self.bot.say("That repository link doesn't seem to be "
|
||||
"valid.")
|
||||
del self.repos[repo_name]
|
||||
return
|
||||
self.populate_list(repo_name)
|
||||
self.save_repos()
|
||||
data = self.get_info_data(repo_name)
|
||||
if data:
|
||||
msg = data.get("INSTALL_MSG")
|
||||
if msg:
|
||||
await self.bot.say(msg[:2000])
|
||||
await self.bot.say("Repo '{}' added.".format(repo_name))
|
||||
|
||||
@repo.command(name="remove")
|
||||
async def _repo_del(self, repo_name: str):
|
||||
"""Removes repo from repo list. COGS ARE NOT REMOVED."""
|
||||
def remove_readonly(func, path, excinfo):
|
||||
os.chmod(path, 0o755)
|
||||
func(path)
|
||||
|
||||
if repo_name not in self.repos:
|
||||
await self.bot.say("That repo doesn't exist.")
|
||||
return
|
||||
del self.repos[repo_name]
|
||||
try:
|
||||
shutil.rmtree(os.path.join(self.path, repo_name), onerror=remove_readonly)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
self.save_repos()
|
||||
await self.bot.say("Repo '{}' removed.".format(repo_name))
|
||||
|
||||
@cog.command(name="list")
|
||||
async def _send_list(self, repo_name=None):
|
||||
"""Lists installable cogs
|
||||
|
||||
Repositories list:
|
||||
https://twentysix26.github.io/Red-Docs/red_cog_approved_repos/"""
|
||||
retlist = []
|
||||
if repo_name and repo_name in self.repos:
|
||||
msg = "Available cogs:\n"
|
||||
for cog in sorted(self.repos[repo_name].keys()):
|
||||
if 'url' == cog:
|
||||
continue
|
||||
data = self.get_info_data(repo_name, cog)
|
||||
if data and data.get("HIDDEN") is True:
|
||||
continue
|
||||
if data:
|
||||
retlist.append([cog, data.get("SHORT", "")])
|
||||
else:
|
||||
retlist.append([cog, ''])
|
||||
else:
|
||||
if self.repos:
|
||||
msg = "Available repos:\n"
|
||||
for repo_name in sorted(self.repos.keys()):
|
||||
data = self.get_info_data(repo_name)
|
||||
if data:
|
||||
retlist.append([repo_name, data.get("SHORT", "")])
|
||||
else:
|
||||
retlist.append([repo_name, ""])
|
||||
else:
|
||||
await self.bot.say("You haven't added a repository yet.\n"
|
||||
"Start now! {}".format(REPOS_LIST))
|
||||
return
|
||||
|
||||
col_width = max(len(row[0]) for row in retlist) + 2
|
||||
for row in retlist:
|
||||
msg += "\t" + "".join(word.ljust(col_width) for word in row) + "\n"
|
||||
msg += "\nRepositories list: {}".format(REPOS_LIST)
|
||||
for page in pagify(msg, delims=['\n'], shorten_by=8):
|
||||
await self.bot.say(box(page))
|
||||
|
||||
@cog.command()
|
||||
async def info(self, repo_name: str, cog: str=None):
|
||||
"""Shows info about the specified cog"""
|
||||
if cog is not None:
|
||||
cogs = self.list_cogs(repo_name)
|
||||
if cog in cogs:
|
||||
data = self.get_info_data(repo_name, cog)
|
||||
if data:
|
||||
msg = "{} by {}\n\n".format(cog, data["AUTHOR"])
|
||||
msg += data["NAME"] + "\n\n" + data["DESCRIPTION"]
|
||||
await self.bot.say(box(msg))
|
||||
else:
|
||||
await self.bot.say("The specified cog has no info file.")
|
||||
else:
|
||||
await self.bot.say("That cog doesn't exist."
|
||||
" Use cog list to see the full list.")
|
||||
else:
|
||||
data = self.get_info_data(repo_name)
|
||||
if data is None:
|
||||
await self.bot.say("That repo does not exist or the"
|
||||
" information file is missing for that repo"
|
||||
".")
|
||||
return
|
||||
name = data.get("NAME", None)
|
||||
name = repo_name if name is None else name
|
||||
author = data.get("AUTHOR", "Unknown")
|
||||
desc = data.get("DESCRIPTION", "")
|
||||
msg = ("```{} by {}```\n\n{}".format(name, author, desc))
|
||||
await self.bot.say(msg)
|
||||
|
||||
@cog.command(hidden=True)
|
||||
async def search(self, *terms: str):
|
||||
"""Search installable cogs"""
|
||||
pass # TO DO
|
||||
|
||||
@cog.command(pass_context=True)
|
||||
async def update(self, ctx):
|
||||
"""Updates cogs"""
|
||||
|
||||
tasknum = 0
|
||||
num_repos = len(self.repos)
|
||||
|
||||
min_dt = 0.5
|
||||
burst_inc = 0.1/(NUM_THREADS)
|
||||
touch_n = tasknum
|
||||
touch_t = time()
|
||||
|
||||
def regulate(touch_t, touch_n):
|
||||
dt = time() - touch_t
|
||||
if dt + burst_inc*(touch_n) > min_dt:
|
||||
touch_n = 0
|
||||
touch_t = time()
|
||||
return True, touch_t, touch_n
|
||||
return False, touch_t, touch_n + 1
|
||||
|
||||
tasks = []
|
||||
for r in self.repos:
|
||||
task = partial(self.update_repo, r)
|
||||
task = self.bot.loop.run_in_executor(self.executor, task)
|
||||
tasks.append(task)
|
||||
|
||||
base_msg = "Downloading updated cogs, please wait... "
|
||||
status = ' %d/%d repos updated' % (tasknum, num_repos)
|
||||
msg = await self.bot.say(base_msg + status)
|
||||
|
||||
updated_cogs = []
|
||||
new_cogs = []
|
||||
deleted_cogs = []
|
||||
failed_cogs = []
|
||||
error_repos = {}
|
||||
installed_updated_cogs = []
|
||||
|
||||
for f in as_completed(tasks):
|
||||
tasknum += 1
|
||||
try:
|
||||
name, updates, oldhash = await f
|
||||
if updates:
|
||||
if type(updates) is dict:
|
||||
for k, l in updates.items():
|
||||
tl = [(name, c, oldhash) for c in l]
|
||||
if k == 'A':
|
||||
new_cogs.extend(tl)
|
||||
elif k == 'D':
|
||||
deleted_cogs.extend(tl)
|
||||
elif k == 'M':
|
||||
updated_cogs.extend(tl)
|
||||
except UpdateError as e:
|
||||
name, what = e.args
|
||||
error_repos[name] = what
|
||||
edit, touch_t, touch_n = regulate(touch_t, touch_n)
|
||||
if edit:
|
||||
status = ' %d/%d repos updated' % (tasknum, num_repos)
|
||||
msg = await self._robust_edit(msg, base_msg + status)
|
||||
status = 'done. '
|
||||
|
||||
for t in updated_cogs:
|
||||
repo, cog, _ = t
|
||||
if self.repos[repo][cog]['INSTALLED']:
|
||||
try:
|
||||
await self.install(repo, cog,
|
||||
no_install_on_reqs_fail=False)
|
||||
except RequirementFail:
|
||||
failed_cogs.append(t)
|
||||
else:
|
||||
installed_updated_cogs.append(t)
|
||||
|
||||
for t in updated_cogs.copy():
|
||||
if t in failed_cogs:
|
||||
updated_cogs.remove(t)
|
||||
|
||||
if not any(self.repos[repo][cog]['INSTALLED'] for
|
||||
repo, cog, _ in updated_cogs):
|
||||
status += ' No updates to apply. '
|
||||
|
||||
if new_cogs:
|
||||
status += '\nNew cogs: ' \
|
||||
+ ', '.join('%s/%s' % c[:2] for c in new_cogs) + '.'
|
||||
if deleted_cogs:
|
||||
status += '\nDeleted cogs: ' \
|
||||
+ ', '.join('%s/%s' % c[:2] for c in deleted_cogs) + '.'
|
||||
if updated_cogs:
|
||||
status += '\nUpdated cogs: ' \
|
||||
+ ', '.join('%s/%s' % c[:2] for c in updated_cogs) + '.'
|
||||
if failed_cogs:
|
||||
status += '\nCogs that got new requirements which have ' + \
|
||||
'failed to install: ' + \
|
||||
', '.join('%s/%s' % c[:2] for c in failed_cogs) + '.'
|
||||
if error_repos:
|
||||
status += '\nThe following repos failed to update: '
|
||||
for n, what in error_repos.items():
|
||||
status += '\n%s: %s' % (n, what)
|
||||
|
||||
msg = await self._robust_edit(msg, base_msg + status)
|
||||
|
||||
if not installed_updated_cogs:
|
||||
return
|
||||
|
||||
patchnote_lang = 'Prolog'
|
||||
shorten_by = 8 + len(patchnote_lang)
|
||||
for note in self.patch_notes_handler(installed_updated_cogs):
|
||||
if note is None:
|
||||
continue
|
||||
for page in pagify(note, delims=['\n'], shorten_by=shorten_by):
|
||||
await self.bot.say(box(page, patchnote_lang))
|
||||
|
||||
await self.bot.say("Cogs updated. Reload updated cogs? (yes/no)")
|
||||
answer = await self.bot.wait_for_message(timeout=15,
|
||||
author=ctx.message.author)
|
||||
if answer is None:
|
||||
await self.bot.say("Ok then, you can reload cogs with"
|
||||
" `{}reload <cog_name>`".format(ctx.prefix))
|
||||
elif answer.content.lower().strip() == "yes":
|
||||
registry = dataIO.load_json(os.path.join("data", "red", "cogs.json"))
|
||||
update_list = []
|
||||
fail_list = []
|
||||
for repo, cog, _ in installed_updated_cogs:
|
||||
if not registry.get('cogs.' + cog, False):
|
||||
continue
|
||||
try:
|
||||
self.bot.unload_extension("cogs." + cog)
|
||||
self.bot.load_extension("cogs." + cog)
|
||||
update_list.append(cog)
|
||||
except:
|
||||
fail_list.append(cog)
|
||||
msg = 'Done.'
|
||||
if update_list:
|
||||
msg += " The following cogs were reloaded: "\
|
||||
+ ', '.join(update_list) + "\n"
|
||||
if fail_list:
|
||||
msg += " The following cogs failed to reload: "\
|
||||
+ ', '.join(fail_list)
|
||||
await self.bot.say(msg)
|
||||
|
||||
else:
|
||||
await self.bot.say("Ok then, you can reload cogs with"
|
||||
" `{}reload <cog_name>`".format(ctx.prefix))
|
||||
|
||||
def patch_notes_handler(self, repo_cog_hash_pairs):
|
||||
for repo, cog, oldhash in repo_cog_hash_pairs:
|
||||
repo_path = os.path.join('data', 'downloader', repo)
|
||||
cogfile = os.path.join(cog, cog + ".py")
|
||||
cmd = ["git", "-C", repo_path, "log", "--relative-date",
|
||||
"--reverse", oldhash + '..', cogfile
|
||||
]
|
||||
try:
|
||||
log = sp_run(cmd, stdout=PIPE).stdout.decode().strip()
|
||||
yield self.format_patch(repo, cog, log)
|
||||
except:
|
||||
pass
|
||||
|
||||
@cog.command(pass_context=True)
|
||||
async def uninstall(self, ctx, repo_name, cog):
|
||||
"""Uninstalls a cog"""
|
||||
if repo_name not in self.repos:
|
||||
await self.bot.say("That repo doesn't exist.")
|
||||
return
|
||||
if cog not in self.repos[repo_name]:
|
||||
await self.bot.say("That cog isn't available from that repo.")
|
||||
return
|
||||
set_cog("cogs." + cog, False)
|
||||
self.repos[repo_name][cog]['INSTALLED'] = False
|
||||
self.save_repos()
|
||||
os.remove(os.path.join("cogs", cog + ".py"))
|
||||
owner = self.bot.get_cog('Owner')
|
||||
await owner.unload.callback(owner, cog_name=cog)
|
||||
await self.bot.say("Cog successfully uninstalled.")
|
||||
|
||||
@cog.command(name="install", pass_context=True)
|
||||
async def _install(self, ctx, repo_name: str, cog: str):
|
||||
"""Installs specified cog"""
|
||||
if repo_name not in self.repos:
|
||||
await self.bot.say("That repo doesn't exist.")
|
||||
return
|
||||
if cog not in self.repos[repo_name]:
|
||||
await self.bot.say("That cog isn't available from that repo.")
|
||||
return
|
||||
data = self.get_info_data(repo_name, cog)
|
||||
try:
|
||||
install_cog = await self.install(repo_name, cog, notify_reqs=True)
|
||||
except RequirementFail:
|
||||
await self.bot.say("That cog has requirements that I could not "
|
||||
"install. Check the console for more "
|
||||
"informations.")
|
||||
return
|
||||
if data is not None:
|
||||
install_msg = data.get("INSTALL_MSG", None)
|
||||
if install_msg:
|
||||
await self.bot.say(install_msg[:2000])
|
||||
if install_cog:
|
||||
await self.bot.say("Installation completed. Load it now? (yes/no)")
|
||||
answer = await self.bot.wait_for_message(timeout=15,
|
||||
author=ctx.message.author)
|
||||
if answer is None:
|
||||
await self.bot.say("Ok then, you can load it with"
|
||||
" `{}load {}`".format(ctx.prefix, cog))
|
||||
elif answer.content.lower().strip() == "yes":
|
||||
set_cog("cogs." + cog, True)
|
||||
owner = self.bot.get_cog('Owner')
|
||||
await owner.load.callback(owner, cog_name=cog)
|
||||
else:
|
||||
await self.bot.say("Ok then, you can load it with"
|
||||
" `{}load {}`".format(ctx.prefix, cog))
|
||||
elif install_cog is False:
|
||||
await self.bot.say("Invalid cog. Installation aborted.")
|
||||
else:
|
||||
await self.bot.say("That cog doesn't exist. Use cog list to see"
|
||||
" the full list.")
|
||||
|
||||
async def install(self, repo_name, cog, *, notify_reqs=False,
|
||||
no_install_on_reqs_fail=True):
|
||||
# 'no_install_on_reqs_fail' will make the cog get installed anyway
|
||||
# on requirements installation fail. This is necessary because due to
|
||||
# how 'cog update' works right now, the user would have no way to
|
||||
# reupdate the cog if the update fails, since 'cog update' only
|
||||
# updates the cogs that get a new commit.
|
||||
# This is not a great way to deal with the problem and a cog update
|
||||
# rework would probably be the best course of action.
|
||||
reqs_failed = False
|
||||
if cog.endswith('.py'):
|
||||
cog = cog[:-3]
|
||||
|
||||
path = self.repos[repo_name][cog]['file']
|
||||
cog_folder_path = self.repos[repo_name][cog]['folder']
|
||||
cog_data_path = os.path.join(cog_folder_path, 'data')
|
||||
data = self.get_info_data(repo_name, cog)
|
||||
if data is not None:
|
||||
requirements = data.get("REQUIREMENTS", [])
|
||||
|
||||
requirements = [r for r in requirements
|
||||
if not self.is_lib_installed(r)]
|
||||
|
||||
if requirements and notify_reqs:
|
||||
await self.bot.say("Installing cog's requirements...")
|
||||
|
||||
for requirement in requirements:
|
||||
if not self.is_lib_installed(requirement):
|
||||
success = await self.bot.pip_install(requirement)
|
||||
if not success:
|
||||
if no_install_on_reqs_fail:
|
||||
raise RequirementFail()
|
||||
else:
|
||||
reqs_failed = True
|
||||
|
||||
to_path = os.path.join("cogs", cog + ".py")
|
||||
|
||||
print("Copying {}...".format(cog))
|
||||
shutil.copy(path, to_path)
|
||||
|
||||
if os.path.exists(cog_data_path):
|
||||
print("Copying {}'s data folder...".format(cog))
|
||||
distutils.dir_util.copy_tree(cog_data_path,
|
||||
os.path.join('data', cog))
|
||||
self.repos[repo_name][cog]['INSTALLED'] = True
|
||||
self.save_repos()
|
||||
if not reqs_failed:
|
||||
return True
|
||||
else:
|
||||
raise RequirementFail()
|
||||
|
||||
def get_info_data(self, repo_name, cog=None):
|
||||
if cog is not None:
|
||||
cogs = self.list_cogs(repo_name)
|
||||
if cog in cogs:
|
||||
info_file = os.path.join(cogs[cog].get('folder'), "info.json")
|
||||
if os.path.isfile(info_file):
|
||||
try:
|
||||
data = dataIO.load_json(info_file)
|
||||
except:
|
||||
return None
|
||||
return data
|
||||
else:
|
||||
repo_info = os.path.join(self.path, repo_name, 'info.json')
|
||||
if os.path.isfile(repo_info):
|
||||
try:
|
||||
data = dataIO.load_json(repo_info)
|
||||
return data
|
||||
except:
|
||||
return None
|
||||
return None
|
||||
|
||||
def list_cogs(self, repo_name):
|
||||
valid_cogs = {}
|
||||
|
||||
repo_path = os.path.join(self.path, repo_name)
|
||||
folders = [f for f in os.listdir(repo_path)
|
||||
if os.path.isdir(os.path.join(repo_path, f))]
|
||||
legacy_path = os.path.join(repo_path, "cogs")
|
||||
legacy_folders = []
|
||||
if os.path.exists(legacy_path):
|
||||
for f in os.listdir(legacy_path):
|
||||
if os.path.isdir(os.path.join(legacy_path, f)):
|
||||
legacy_folders.append(os.path.join("cogs", f))
|
||||
|
||||
folders = folders + legacy_folders
|
||||
|
||||
for f in folders:
|
||||
cog_folder_path = os.path.join(self.path, repo_name, f)
|
||||
cog_folder = os.path.basename(cog_folder_path)
|
||||
for cog in os.listdir(cog_folder_path):
|
||||
cog_path = os.path.join(cog_folder_path, cog)
|
||||
if os.path.isfile(cog_path) and cog_folder == cog[:-3]:
|
||||
valid_cogs[cog[:-3]] = {'folder': cog_folder_path,
|
||||
'file': cog_path}
|
||||
return valid_cogs
|
||||
|
||||
def get_dir_name(self, url):
|
||||
splitted = url.split("/")
|
||||
git_name = splitted[-1]
|
||||
return git_name[:-4]
|
||||
|
||||
def is_lib_installed(self, name):
|
||||
return bool(find_spec(name))
|
||||
|
||||
def _do_first_run(self):
|
||||
save = False
|
||||
repos_copy = deepcopy(self.repos)
|
||||
|
||||
# Issue 725
|
||||
for repo in repos_copy:
|
||||
for cog in repos_copy[repo]:
|
||||
cog_data = repos_copy[repo][cog]
|
||||
if isinstance(cog_data, str): # ... url field
|
||||
continue
|
||||
for k, v in cog_data.items():
|
||||
if k in ("file", "folder"):
|
||||
repos_copy[repo][cog][k] = os.path.normpath(cog_data[k])
|
||||
|
||||
if self.repos != repos_copy:
|
||||
self.repos = repos_copy
|
||||
save = True
|
||||
|
||||
invalid = []
|
||||
|
||||
for repo in self.repos:
|
||||
broken = 'url' in self.repos[repo] and len(self.repos[repo]) == 1
|
||||
if broken:
|
||||
save = True
|
||||
try:
|
||||
self.update_repo(repo)
|
||||
self.populate_list(repo)
|
||||
except CloningError:
|
||||
invalid.append(repo)
|
||||
continue
|
||||
except Exception as e:
|
||||
print(e) # TODO: Proper logging
|
||||
continue
|
||||
|
||||
for repo in invalid:
|
||||
del self.repos[repo]
|
||||
|
||||
if save:
|
||||
self.save_repos()
|
||||
|
||||
def populate_list(self, name):
|
||||
valid_cogs = self.list_cogs(name)
|
||||
new = set(valid_cogs.keys())
|
||||
old = set(self.repos[name].keys())
|
||||
for cog in new - old:
|
||||
self.repos[name][cog] = valid_cogs.get(cog, {})
|
||||
self.repos[name][cog]['INSTALLED'] = False
|
||||
for cog in new & old:
|
||||
self.repos[name][cog].update(valid_cogs[cog])
|
||||
for cog in old - new:
|
||||
if cog != 'url':
|
||||
del self.repos[name][cog]
|
||||
|
||||
def update_repo(self, name):
|
||||
|
||||
def run(*args, **kwargs):
|
||||
env = os.environ.copy()
|
||||
env['GIT_TERMINAL_PROMPT'] = '0'
|
||||
kwargs['env'] = env
|
||||
return sp_run(*args, **kwargs)
|
||||
|
||||
try:
|
||||
dd = self.path
|
||||
if name not in self.repos:
|
||||
raise UpdateError("Repo does not exist in data, wtf")
|
||||
folder = os.path.join(dd, name)
|
||||
# Make sure we don't git reset the Red folder on accident
|
||||
if not os.path.exists(os.path.join(folder, '.git')):
|
||||
#if os.path.exists(folder):
|
||||
#shutil.rmtree(folder)
|
||||
url = self.repos[name].get('url')
|
||||
if not url:
|
||||
raise UpdateError("Need to clone but no URL set")
|
||||
branch = None
|
||||
if "@" in url: # Specific branch
|
||||
url, branch = url.rsplit("@", maxsplit=1)
|
||||
if branch is None:
|
||||
p = run(["git", "clone", url, folder])
|
||||
else:
|
||||
p = run(["git", "clone", "-b", branch, url, folder])
|
||||
if p.returncode != 0:
|
||||
raise CloningError()
|
||||
self.populate_list(name)
|
||||
return name, REPO_CLONE, None
|
||||
else:
|
||||
rpbcmd = ["git", "-C", folder, "rev-parse", "--abbrev-ref", "HEAD"]
|
||||
p = run(rpbcmd, stdout=PIPE)
|
||||
branch = p.stdout.decode().strip()
|
||||
|
||||
rpcmd = ["git", "-C", folder, "rev-parse", branch]
|
||||
p = run(["git", "-C", folder, "reset", "--hard",
|
||||
"origin/%s" % branch, "-q"])
|
||||
if p.returncode != 0:
|
||||
raise UpdateError("Error resetting to origin/%s" % branch)
|
||||
p = run(rpcmd, stdout=PIPE)
|
||||
if p.returncode != 0:
|
||||
raise UpdateError("Unable to determine old commit hash")
|
||||
oldhash = p.stdout.decode().strip()
|
||||
p = run(["git", "-C", folder, "pull", "-q", "--ff-only"])
|
||||
if p.returncode != 0:
|
||||
raise UpdateError("Error pulling updates")
|
||||
p = run(rpcmd, stdout=PIPE)
|
||||
if p.returncode != 0:
|
||||
raise UpdateError("Unable to determine new commit hash")
|
||||
newhash = p.stdout.decode().strip()
|
||||
if oldhash == newhash:
|
||||
return name, REPO_SAME, None
|
||||
else:
|
||||
self.populate_list(name)
|
||||
self.save_repos()
|
||||
ret = {}
|
||||
cmd = ['git', '-C', folder, 'diff', '--no-commit-id',
|
||||
'--name-status', oldhash, newhash]
|
||||
p = run(cmd, stdout=PIPE)
|
||||
|
||||
if p.returncode != 0:
|
||||
raise UpdateError("Error in git diff")
|
||||
|
||||
changed = p.stdout.strip().decode().split('\n')
|
||||
|
||||
for f in changed:
|
||||
if not f.endswith('.py'):
|
||||
continue
|
||||
|
||||
status, _, cogpath = f.partition('\t')
|
||||
cogname = os.path.split(cogpath)[-1][:-3] # strip .py
|
||||
if status not in ret:
|
||||
ret[status] = []
|
||||
ret[status].append(cogname)
|
||||
|
||||
return name, ret, oldhash
|
||||
|
||||
except CloningError as e:
|
||||
raise CloningError(name, *e.args) from None
|
||||
except UpdateError as e:
|
||||
raise UpdateError(name, *e.args) from None
|
||||
|
||||
async def _robust_edit(self, msg, text):
|
||||
try:
|
||||
msg = await self.bot.edit_message(msg, text)
|
||||
except discord.errors.NotFound:
|
||||
msg = await self.bot.send_message(msg.channel, text)
|
||||
except:
|
||||
raise
|
||||
return msg
|
||||
|
||||
@staticmethod
|
||||
def format_patch(repo, cog, log):
|
||||
header = "Patch Notes for %s/%s" % (repo, cog)
|
||||
line = "=" * len(header)
|
||||
if log:
|
||||
return '\n'.join((header, line, log))
|
||||
|
||||
|
||||
def check_folders():
|
||||
if not os.path.exists(os.path.join("data", "downloader")):
|
||||
print('Making repo downloads folder...')
|
||||
os.mkdir(os.path.join("data", "downloader"))
|
||||
|
||||
|
||||
def check_files():
|
||||
f = os.path.join("data", "downloader", "repos.json")
|
||||
if not dataIO.is_valid_json(f):
|
||||
print("Creating default data/downloader/repos.json")
|
||||
dataIO.save_json(f, {})
|
||||
|
||||
|
||||
def setup(bot):
|
||||
check_folders()
|
||||
check_files()
|
||||
n = Downloader(bot)
|
||||
bot.add_cog(n)
|
||||
|
|
@ -0,0 +1,736 @@
|
|||
import discord
|
||||
from discord.ext import commands
|
||||
from cogs.utils.dataIO import dataIO
|
||||
from collections import namedtuple, defaultdict, deque
|
||||
from datetime import datetime
|
||||
from copy import deepcopy
|
||||
from .utils import checks
|
||||
from cogs.utils.chat_formatting import pagify, box
|
||||
from enum import Enum
|
||||
from __main__ import send_cmd_help
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
import random
|
||||
|
||||
default_settings = {"PAYDAY_TIME": 300, "PAYDAY_CREDITS": 120,
|
||||
"SLOT_MIN": 5, "SLOT_MAX": 100, "SLOT_TIME": 0,
|
||||
"REGISTER_CREDITS": 0}
|
||||
|
||||
|
||||
class EconomyError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class OnCooldown(EconomyError):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidBid(EconomyError):
|
||||
pass
|
||||
|
||||
|
||||
class BankError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class AccountAlreadyExists(BankError):
|
||||
pass
|
||||
|
||||
|
||||
class NoAccount(BankError):
|
||||
pass
|
||||
|
||||
|
||||
class InsufficientBalance(BankError):
|
||||
pass
|
||||
|
||||
|
||||
class NegativeValue(BankError):
|
||||
pass
|
||||
|
||||
|
||||
class SameSenderAndReceiver(BankError):
|
||||
pass
|
||||
|
||||
|
||||
NUM_ENC = "\N{COMBINING ENCLOSING KEYCAP}"
|
||||
|
||||
|
||||
class SMReel(Enum):
|
||||
cherries = "\N{CHERRIES}"
|
||||
cookie = "\N{COOKIE}"
|
||||
two = "\N{DIGIT TWO}" + NUM_ENC
|
||||
flc = "\N{FOUR LEAF CLOVER}"
|
||||
cyclone = "\N{CYCLONE}"
|
||||
sunflower = "\N{SUNFLOWER}"
|
||||
six = "\N{DIGIT SIX}" + NUM_ENC
|
||||
mushroom = "\N{MUSHROOM}"
|
||||
heart = "\N{HEAVY BLACK HEART}"
|
||||
snowflake = "\N{SNOWFLAKE}"
|
||||
|
||||
PAYOUTS = {
|
||||
(SMReel.two, SMReel.two, SMReel.six) : {
|
||||
"payout" : lambda x: x * 2500 + x,
|
||||
"phrase" : "JACKPOT! 226! Your bid has been multiplied * 2500!"
|
||||
},
|
||||
(SMReel.flc, SMReel.flc, SMReel.flc) : {
|
||||
"payout" : lambda x: x + 1000,
|
||||
"phrase" : "4LC! +1000!"
|
||||
},
|
||||
(SMReel.cherries, SMReel.cherries, SMReel.cherries) : {
|
||||
"payout" : lambda x: x + 800,
|
||||
"phrase" : "Three cherries! +800!"
|
||||
},
|
||||
(SMReel.two, SMReel.six) : {
|
||||
"payout" : lambda x: x * 4 + x,
|
||||
"phrase" : "2 6! Your bid has been multiplied * 4!"
|
||||
},
|
||||
(SMReel.cherries, SMReel.cherries) : {
|
||||
"payout" : lambda x: x * 3 + x,
|
||||
"phrase" : "Two cherries! Your bid has been multiplied * 3!"
|
||||
},
|
||||
"3 symbols" : {
|
||||
"payout" : lambda x: x + 500,
|
||||
"phrase" : "Three symbols! +500!"
|
||||
},
|
||||
"2 symbols" : {
|
||||
"payout" : lambda x: x * 2 + x,
|
||||
"phrase" : "Two consecutive symbols! Your bid has been multiplied * 2!"
|
||||
},
|
||||
}
|
||||
|
||||
SLOT_PAYOUTS_MSG = ("Slot machine payouts:\n"
|
||||
"{two.value} {two.value} {six.value} Bet * 2500\n"
|
||||
"{flc.value} {flc.value} {flc.value} +1000\n"
|
||||
"{cherries.value} {cherries.value} {cherries.value} +800\n"
|
||||
"{two.value} {six.value} Bet * 4\n"
|
||||
"{cherries.value} {cherries.value} Bet * 3\n\n"
|
||||
"Three symbols: +500\n"
|
||||
"Two symbols: Bet * 2".format(**SMReel.__dict__))
|
||||
|
||||
|
||||
class Bank:
|
||||
|
||||
def __init__(self, bot, file_path):
|
||||
self.accounts = dataIO.load_json(file_path)
|
||||
self.bot = bot
|
||||
|
||||
def create_account(self, user, *, initial_balance=0):
|
||||
server = user.server
|
||||
if not self.account_exists(user):
|
||||
if server.id not in self.accounts:
|
||||
self.accounts[server.id] = {}
|
||||
if user.id in self.accounts: # Legacy account
|
||||
balance = self.accounts[user.id]["balance"]
|
||||
else:
|
||||
balance = initial_balance
|
||||
timestamp = datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")
|
||||
account = {"name": user.name,
|
||||
"balance": balance,
|
||||
"created_at": timestamp
|
||||
}
|
||||
self.accounts[server.id][user.id] = account
|
||||
self._save_bank()
|
||||
return self.get_account(user)
|
||||
else:
|
||||
raise AccountAlreadyExists()
|
||||
|
||||
def account_exists(self, user):
|
||||
try:
|
||||
self._get_account(user)
|
||||
except NoAccount:
|
||||
return False
|
||||
return True
|
||||
|
||||
def withdraw_credits(self, user, amount):
|
||||
server = user.server
|
||||
|
||||
if amount < 0:
|
||||
raise NegativeValue()
|
||||
|
||||
account = self._get_account(user)
|
||||
if account["balance"] >= amount:
|
||||
account["balance"] -= amount
|
||||
self.accounts[server.id][user.id] = account
|
||||
self._save_bank()
|
||||
else:
|
||||
raise InsufficientBalance()
|
||||
|
||||
def deposit_credits(self, user, amount):
|
||||
server = user.server
|
||||
if amount < 0:
|
||||
raise NegativeValue()
|
||||
account = self._get_account(user)
|
||||
account["balance"] += amount
|
||||
self.accounts[server.id][user.id] = account
|
||||
self._save_bank()
|
||||
|
||||
def set_credits(self, user, amount):
|
||||
server = user.server
|
||||
if amount < 0:
|
||||
raise NegativeValue()
|
||||
account = self._get_account(user)
|
||||
account["balance"] = amount
|
||||
self.accounts[server.id][user.id] = account
|
||||
self._save_bank()
|
||||
|
||||
def transfer_credits(self, sender, receiver, amount):
|
||||
if amount < 0:
|
||||
raise NegativeValue()
|
||||
if sender is receiver:
|
||||
raise SameSenderAndReceiver()
|
||||
if self.account_exists(sender) and self.account_exists(receiver):
|
||||
sender_acc = self._get_account(sender)
|
||||
if sender_acc["balance"] < amount:
|
||||
raise InsufficientBalance()
|
||||
self.withdraw_credits(sender, amount)
|
||||
self.deposit_credits(receiver, amount)
|
||||
else:
|
||||
raise NoAccount()
|
||||
|
||||
def can_spend(self, user, amount):
|
||||
account = self._get_account(user)
|
||||
if account["balance"] >= amount:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def wipe_bank(self, server):
|
||||
self.accounts[server.id] = {}
|
||||
self._save_bank()
|
||||
|
||||
def get_server_accounts(self, server):
|
||||
if server.id in self.accounts:
|
||||
raw_server_accounts = deepcopy(self.accounts[server.id])
|
||||
accounts = []
|
||||
for k, v in raw_server_accounts.items():
|
||||
v["id"] = k
|
||||
v["server"] = server
|
||||
acc = self._create_account_obj(v)
|
||||
accounts.append(acc)
|
||||
return accounts
|
||||
else:
|
||||
return []
|
||||
|
||||
def get_all_accounts(self):
|
||||
accounts = []
|
||||
for server_id, v in self.accounts.items():
|
||||
server = self.bot.get_server(server_id)
|
||||
if server is None:
|
||||
# Servers that have since been left will be ignored
|
||||
# Same for users_id from the old bank format
|
||||
continue
|
||||
raw_server_accounts = deepcopy(self.accounts[server.id])
|
||||
for k, v in raw_server_accounts.items():
|
||||
v["id"] = k
|
||||
v["server"] = server
|
||||
acc = self._create_account_obj(v)
|
||||
accounts.append(acc)
|
||||
return accounts
|
||||
|
||||
def get_balance(self, user):
|
||||
account = self._get_account(user)
|
||||
return account["balance"]
|
||||
|
||||
def get_account(self, user):
|
||||
acc = self._get_account(user)
|
||||
acc["id"] = user.id
|
||||
acc["server"] = user.server
|
||||
return self._create_account_obj(acc)
|
||||
|
||||
def _create_account_obj(self, account):
|
||||
account["member"] = account["server"].get_member(account["id"])
|
||||
account["created_at"] = datetime.strptime(account["created_at"],
|
||||
"%Y-%m-%d %H:%M:%S")
|
||||
Account = namedtuple("Account", "id name balance "
|
||||
"created_at server member")
|
||||
return Account(**account)
|
||||
|
||||
def _save_bank(self):
|
||||
dataIO.save_json("data/economy/bank.json", self.accounts)
|
||||
|
||||
def _get_account(self, user):
|
||||
server = user.server
|
||||
try:
|
||||
return deepcopy(self.accounts[server.id][user.id])
|
||||
except KeyError:
|
||||
raise NoAccount()
|
||||
|
||||
|
||||
class SetParser:
|
||||
def __init__(self, argument):
|
||||
allowed = ("+", "-")
|
||||
if argument and argument[0] in allowed:
|
||||
try:
|
||||
self.sum = int(argument)
|
||||
except:
|
||||
raise
|
||||
if self.sum < 0:
|
||||
self.operation = "withdraw"
|
||||
elif self.sum > 0:
|
||||
self.operation = "deposit"
|
||||
else:
|
||||
raise
|
||||
self.sum = abs(self.sum)
|
||||
elif argument.isdigit():
|
||||
self.sum = int(argument)
|
||||
self.operation = "set"
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
class Economy:
|
||||
"""Economy
|
||||
|
||||
Get rich and have fun with imaginary currency!"""
|
||||
|
||||
def __init__(self, bot):
|
||||
global default_settings
|
||||
self.bot = bot
|
||||
self.bank = Bank(bot, "data/economy/bank.json")
|
||||
self.file_path = "data/economy/settings.json"
|
||||
self.settings = dataIO.load_json(self.file_path)
|
||||
if "PAYDAY_TIME" in self.settings: # old format
|
||||
default_settings = self.settings
|
||||
self.settings = {}
|
||||
self.settings = defaultdict(lambda: default_settings, self.settings)
|
||||
self.payday_register = defaultdict(dict)
|
||||
self.slot_register = defaultdict(dict)
|
||||
|
||||
@commands.group(name="bank", pass_context=True)
|
||||
async def _bank(self, ctx):
|
||||
"""Bank operations"""
|
||||
if ctx.invoked_subcommand is None:
|
||||
await send_cmd_help(ctx)
|
||||
|
||||
@_bank.command(pass_context=True, no_pm=True)
|
||||
async def register(self, ctx):
|
||||
"""Registers an account at the Twentysix bank"""
|
||||
settings = self.settings[ctx.message.server.id]
|
||||
author = ctx.message.author
|
||||
credits = 0
|
||||
if ctx.message.server.id in self.settings:
|
||||
credits = settings.get("REGISTER_CREDITS", 0)
|
||||
try:
|
||||
account = self.bank.create_account(author, initial_balance=credits)
|
||||
await self.bot.say("{} Account opened. Current balance: {}"
|
||||
"".format(author.mention, account.balance))
|
||||
except AccountAlreadyExists:
|
||||
await self.bot.say("{} You already have an account at the"
|
||||
" Twentysix bank.".format(author.mention))
|
||||
|
||||
@_bank.command(pass_context=True)
|
||||
async def balance(self, ctx, user: discord.Member=None):
|
||||
"""Shows balance of user.
|
||||
|
||||
Defaults to yours."""
|
||||
if not user:
|
||||
user = ctx.message.author
|
||||
try:
|
||||
await self.bot.say("{} Your balance is: {}".format(
|
||||
user.mention, self.bank.get_balance(user)))
|
||||
except NoAccount:
|
||||
await self.bot.say("{} You don't have an account at the"
|
||||
" Twentysix bank. Type `{}bank register`"
|
||||
" to open one.".format(user.mention,
|
||||
ctx.prefix))
|
||||
else:
|
||||
try:
|
||||
await self.bot.say("{}'s balance is {}".format(
|
||||
user.name, self.bank.get_balance(user)))
|
||||
except NoAccount:
|
||||
await self.bot.say("That user has no bank account.")
|
||||
|
||||
@_bank.command(pass_context=True)
|
||||
async def transfer(self, ctx, user: discord.Member, sum: int):
|
||||
"""Transfer credits to other users"""
|
||||
author = ctx.message.author
|
||||
try:
|
||||
self.bank.transfer_credits(author, user, sum)
|
||||
logger.info("{}({}) transferred {} credits to {}({})".format(
|
||||
author.name, author.id, sum, user.name, user.id))
|
||||
await self.bot.say("{} credits have been transferred to {}'s"
|
||||
" account.".format(sum, user.name))
|
||||
except NegativeValue:
|
||||
await self.bot.say("You need to transfer at least 1 credit.")
|
||||
except SameSenderAndReceiver:
|
||||
await self.bot.say("You can't transfer credits to yourself.")
|
||||
except InsufficientBalance:
|
||||
await self.bot.say("You don't have that sum in your bank account.")
|
||||
except NoAccount:
|
||||
await self.bot.say("That user has no bank account.")
|
||||
|
||||
@_bank.command(name="set", pass_context=True)
|
||||
@checks.admin_or_permissions(manage_server=True)
|
||||
async def _set(self, ctx, user: discord.Member, credits: SetParser):
|
||||
"""Sets credits of user's bank account. See help for more operations
|
||||
|
||||
Passing positive and negative values will add/remove credits instead
|
||||
|
||||
Examples:
|
||||
bank set @Twentysix 26 - Sets 26 credits
|
||||
bank set @Twentysix +2 - Adds 2 credits
|
||||
bank set @Twentysix -6 - Removes 6 credits"""
|
||||
author = ctx.message.author
|
||||
try:
|
||||
if credits.operation == "deposit":
|
||||
self.bank.deposit_credits(user, credits.sum)
|
||||
logger.info("{}({}) added {} credits to {} ({})".format(
|
||||
author.name, author.id, credits.sum, user.name, user.id))
|
||||
await self.bot.say("{} credits have been added to {}"
|
||||
"".format(credits.sum, user.name))
|
||||
elif credits.operation == "withdraw":
|
||||
self.bank.withdraw_credits(user, credits.sum)
|
||||
logger.info("{}({}) removed {} credits to {} ({})".format(
|
||||
author.name, author.id, credits.sum, user.name, user.id))
|
||||
await self.bot.say("{} credits have been withdrawn from {}"
|
||||
"".format(credits.sum, user.name))
|
||||
elif credits.operation == "set":
|
||||
self.bank.set_credits(user, credits.sum)
|
||||
logger.info("{}({}) set {} credits to {} ({})"
|
||||
"".format(author.name, author.id, credits.sum,
|
||||
user.name, user.id))
|
||||
await self.bot.say("{}'s credits have been set to {}".format(
|
||||
user.name, credits.sum))
|
||||
except InsufficientBalance:
|
||||
await self.bot.say("User doesn't have enough credits.")
|
||||
except NoAccount:
|
||||
await self.bot.say("User has no bank account.")
|
||||
|
||||
@_bank.command(pass_context=True, no_pm=True)
|
||||
@checks.serverowner_or_permissions(administrator=True)
|
||||
async def reset(self, ctx, confirmation: bool=False):
|
||||
"""Deletes all server's bank accounts"""
|
||||
if confirmation is False:
|
||||
await self.bot.say("This will delete all bank accounts on "
|
||||
"this server.\nIf you're sure, type "
|
||||
"{}bank reset yes".format(ctx.prefix))
|
||||
else:
|
||||
self.bank.wipe_bank(ctx.message.server)
|
||||
await self.bot.say("All bank accounts of this server have been "
|
||||
"deleted.")
|
||||
|
||||
@commands.command(pass_context=True, no_pm=True)
|
||||
async def payday(self, ctx): # TODO
|
||||
"""Get some free credits"""
|
||||
author = ctx.message.author
|
||||
server = author.server
|
||||
id = author.id
|
||||
if self.bank.account_exists(author):
|
||||
if id in self.payday_register[server.id]:
|
||||
seconds = abs(self.payday_register[server.id][
|
||||
id] - int(time.perf_counter()))
|
||||
if seconds >= self.settings[server.id]["PAYDAY_TIME"]:
|
||||
self.bank.deposit_credits(author, self.settings[
|
||||
server.id]["PAYDAY_CREDITS"])
|
||||
self.payday_register[server.id][
|
||||
id] = int(time.perf_counter())
|
||||
await self.bot.say(
|
||||
"{} Here, take some credits. Enjoy! (+{}"
|
||||
" credits!)".format(
|
||||
author.mention,
|
||||
str(self.settings[server.id]["PAYDAY_CREDITS"])))
|
||||
else:
|
||||
dtime = self.display_time(
|
||||
self.settings[server.id]["PAYDAY_TIME"] - seconds)
|
||||
await self.bot.say(
|
||||
"{} Too soon. For your next payday you have to"
|
||||
" wait {}.".format(author.mention, dtime))
|
||||
else:
|
||||
self.payday_register[server.id][id] = int(time.perf_counter())
|
||||
self.bank.deposit_credits(author, self.settings[
|
||||
server.id]["PAYDAY_CREDITS"])
|
||||
await self.bot.say(
|
||||
"{} Here, take some credits. Enjoy! (+{} credits!)".format(
|
||||
author.mention,
|
||||
str(self.settings[server.id]["PAYDAY_CREDITS"])))
|
||||
else:
|
||||
await self.bot.say("{} You need an account to receive credits."
|
||||
" Type `{}bank register` to open one.".format(
|
||||
author.mention, ctx.prefix))
|
||||
|
||||
@commands.group(pass_context=True)
|
||||
async def leaderboard(self, ctx):
|
||||
"""Server / global leaderboard
|
||||
|
||||
Defaults to server"""
|
||||
if ctx.invoked_subcommand is None:
|
||||
await ctx.invoke(self._server_leaderboard)
|
||||
|
||||
@leaderboard.command(name="server", pass_context=True)
|
||||
async def _server_leaderboard(self, ctx, top: int=10):
|
||||
"""Prints out the server's leaderboard
|
||||
|
||||
Defaults to top 10"""
|
||||
# Originally coded by Airenkun - edited by irdumb
|
||||
server = ctx.message.server
|
||||
if top < 1:
|
||||
top = 10
|
||||
bank_sorted = sorted(self.bank.get_server_accounts(server),
|
||||
key=lambda x: x.balance, reverse=True)
|
||||
bank_sorted = [a for a in bank_sorted if a.member] # exclude users who left
|
||||
if len(bank_sorted) < top:
|
||||
top = len(bank_sorted)
|
||||
topten = bank_sorted[:top]
|
||||
highscore = ""
|
||||
place = 1
|
||||
for acc in topten:
|
||||
highscore += str(place).ljust(len(str(top)) + 1)
|
||||
highscore += (str(acc.member.display_name) + " ").ljust(23 - len(str(acc.balance)))
|
||||
highscore += str(acc.balance) + "\n"
|
||||
place += 1
|
||||
if highscore != "":
|
||||
for page in pagify(highscore, shorten_by=12):
|
||||
await self.bot.say(box(page, lang="py"))
|
||||
else:
|
||||
await self.bot.say("There are no accounts in the bank.")
|
||||
|
||||
@leaderboard.command(name="global")
|
||||
async def _global_leaderboard(self, top: int=10):
|
||||
"""Prints out the global leaderboard
|
||||
|
||||
Defaults to top 10"""
|
||||
if top < 1:
|
||||
top = 10
|
||||
bank_sorted = sorted(self.bank.get_all_accounts(),
|
||||
key=lambda x: x.balance, reverse=True)
|
||||
bank_sorted = [a for a in bank_sorted if a.member] # exclude users who left
|
||||
unique_accounts = []
|
||||
for acc in bank_sorted:
|
||||
if not self.already_in_list(unique_accounts, acc):
|
||||
unique_accounts.append(acc)
|
||||
if len(unique_accounts) < top:
|
||||
top = len(unique_accounts)
|
||||
topten = unique_accounts[:top]
|
||||
highscore = ""
|
||||
place = 1
|
||||
for acc in topten:
|
||||
highscore += str(place).ljust(len(str(top)) + 1)
|
||||
highscore += ("{} |{}| ".format(acc.member, acc.server)
|
||||
).ljust(23 - len(str(acc.balance)))
|
||||
highscore += str(acc.balance) + "\n"
|
||||
place += 1
|
||||
if highscore != "":
|
||||
for page in pagify(highscore, shorten_by=12):
|
||||
await self.bot.say(box(page, lang="py"))
|
||||
else:
|
||||
await self.bot.say("There are no accounts in the bank.")
|
||||
|
||||
def already_in_list(self, accounts, user):
|
||||
for acc in accounts:
|
||||
if user.id == acc.id:
|
||||
return True
|
||||
return False
|
||||
|
||||
@commands.command()
|
||||
async def payouts(self):
|
||||
"""Shows slot machine payouts"""
|
||||
await self.bot.whisper(SLOT_PAYOUTS_MSG)
|
||||
|
||||
@commands.command(pass_context=True, no_pm=True)
|
||||
async def slot(self, ctx, bid: int):
|
||||
"""Play the slot machine"""
|
||||
author = ctx.message.author
|
||||
server = author.server
|
||||
settings = self.settings[server.id]
|
||||
valid_bid = settings["SLOT_MIN"] <= bid and bid <= settings["SLOT_MAX"]
|
||||
slot_time = settings["SLOT_TIME"]
|
||||
last_slot = self.slot_register.get(author.id)
|
||||
now = datetime.utcnow()
|
||||
try:
|
||||
if last_slot:
|
||||
if (now - last_slot).seconds < slot_time:
|
||||
raise OnCooldown()
|
||||
if not valid_bid:
|
||||
raise InvalidBid()
|
||||
if not self.bank.can_spend(author, bid):
|
||||
raise InsufficientBalance
|
||||
await self.slot_machine(author, bid)
|
||||
except NoAccount:
|
||||
await self.bot.say("{} You need an account to use the slot "
|
||||
"machine. Type `{}bank register` to open one."
|
||||
"".format(author.mention, ctx.prefix))
|
||||
except InsufficientBalance:
|
||||
await self.bot.say("{} You need an account with enough funds to "
|
||||
"play the slot machine.".format(author.mention))
|
||||
except OnCooldown:
|
||||
await self.bot.say("Slot machine is still cooling off! Wait {} "
|
||||
"seconds between each pull".format(slot_time))
|
||||
except InvalidBid:
|
||||
await self.bot.say("Bid must be between {} and {}."
|
||||
"".format(settings["SLOT_MIN"],
|
||||
settings["SLOT_MAX"]))
|
||||
|
||||
async def slot_machine(self, author, bid):
|
||||
default_reel = deque(SMReel)
|
||||
reels = []
|
||||
self.slot_register[author.id] = datetime.utcnow()
|
||||
for i in range(3):
|
||||
default_reel.rotate(random.randint(-999, 999)) # weeeeee
|
||||
new_reel = deque(default_reel, maxlen=3) # we need only 3 symbols
|
||||
reels.append(new_reel) # for each reel
|
||||
rows = ((reels[0][0], reels[1][0], reels[2][0]),
|
||||
(reels[0][1], reels[1][1], reels[2][1]),
|
||||
(reels[0][2], reels[1][2], reels[2][2]))
|
||||
|
||||
slot = "~~\n~~" # Mobile friendly
|
||||
for i, row in enumerate(rows): # Let's build the slot to show
|
||||
sign = " "
|
||||
if i == 1:
|
||||
sign = ">"
|
||||
slot += "{}{} {} {}\n".format(sign, *[c.value for c in row])
|
||||
|
||||
payout = PAYOUTS.get(rows[1])
|
||||
if not payout:
|
||||
# Checks for two-consecutive-symbols special rewards
|
||||
payout = PAYOUTS.get((rows[1][0], rows[1][1]),
|
||||
PAYOUTS.get((rows[1][1], rows[1][2]))
|
||||
)
|
||||
if not payout:
|
||||
# Still nothing. Let's check for 3 generic same symbols
|
||||
# or 2 consecutive symbols
|
||||
has_three = rows[1][0] == rows[1][1] == rows[1][2]
|
||||
has_two = (rows[1][0] == rows[1][1]) or (rows[1][1] == rows[1][2])
|
||||
if has_three:
|
||||
payout = PAYOUTS["3 symbols"]
|
||||
elif has_two:
|
||||
payout = PAYOUTS["2 symbols"]
|
||||
|
||||
if payout:
|
||||
then = self.bank.get_balance(author)
|
||||
pay = payout["payout"](bid)
|
||||
now = then - bid + pay
|
||||
self.bank.set_credits(author, now)
|
||||
await self.bot.say("{}\n{} {}\n\nYour bid: {}\n{} → {}!"
|
||||
"".format(slot, author.mention,
|
||||
payout["phrase"], bid, then, now))
|
||||
else:
|
||||
then = self.bank.get_balance(author)
|
||||
self.bank.withdraw_credits(author, bid)
|
||||
now = then - bid
|
||||
await self.bot.say("{}\n{} Nothing!\nYour bid: {}\n{} → {}!"
|
||||
"".format(slot, author.mention, bid, then, now))
|
||||
|
||||
@commands.group(pass_context=True, no_pm=True)
|
||||
@checks.admin_or_permissions(manage_server=True)
|
||||
async def economyset(self, ctx):
|
||||
"""Changes economy module settings"""
|
||||
server = ctx.message.server
|
||||
settings = self.settings[server.id]
|
||||
if ctx.invoked_subcommand is None:
|
||||
msg = "```"
|
||||
for k, v in settings.items():
|
||||
msg += "{}: {}\n".format(k, v)
|
||||
msg += "```"
|
||||
await send_cmd_help(ctx)
|
||||
await self.bot.say(msg)
|
||||
|
||||
@economyset.command(pass_context=True)
|
||||
async def slotmin(self, ctx, bid: int):
|
||||
"""Minimum slot machine bid"""
|
||||
server = ctx.message.server
|
||||
self.settings[server.id]["SLOT_MIN"] = bid
|
||||
await self.bot.say("Minimum bid is now {} credits.".format(bid))
|
||||
dataIO.save_json(self.file_path, self.settings)
|
||||
|
||||
@economyset.command(pass_context=True)
|
||||
async def slotmax(self, ctx, bid: int):
|
||||
"""Maximum slot machine bid"""
|
||||
server = ctx.message.server
|
||||
self.settings[server.id]["SLOT_MAX"] = bid
|
||||
await self.bot.say("Maximum bid is now {} credits.".format(bid))
|
||||
dataIO.save_json(self.file_path, self.settings)
|
||||
|
||||
@economyset.command(pass_context=True)
|
||||
async def slottime(self, ctx, seconds: int):
|
||||
"""Seconds between each slots use"""
|
||||
server = ctx.message.server
|
||||
self.settings[server.id]["SLOT_TIME"] = seconds
|
||||
await self.bot.say("Cooldown is now {} seconds.".format(seconds))
|
||||
dataIO.save_json(self.file_path, self.settings)
|
||||
|
||||
@economyset.command(pass_context=True)
|
||||
async def paydaytime(self, ctx, seconds: int):
|
||||
"""Seconds between each payday"""
|
||||
server = ctx.message.server
|
||||
self.settings[server.id]["PAYDAY_TIME"] = seconds
|
||||
await self.bot.say("Value modified. At least {} seconds must pass "
|
||||
"between each payday.".format(seconds))
|
||||
dataIO.save_json(self.file_path, self.settings)
|
||||
|
||||
@economyset.command(pass_context=True)
|
||||
async def paydaycredits(self, ctx, credits: int):
|
||||
"""Credits earned each payday"""
|
||||
server = ctx.message.server
|
||||
self.settings[server.id]["PAYDAY_CREDITS"] = credits
|
||||
await self.bot.say("Every payday will now give {} credits."
|
||||
"".format(credits))
|
||||
dataIO.save_json(self.file_path, self.settings)
|
||||
|
||||
@economyset.command(pass_context=True)
|
||||
async def registercredits(self, ctx, credits: int):
|
||||
"""Credits given on registering an account"""
|
||||
server = ctx.message.server
|
||||
if credits < 0:
|
||||
credits = 0
|
||||
self.settings[server.id]["REGISTER_CREDITS"] = credits
|
||||
await self.bot.say("Registering an account will now give {} credits."
|
||||
"".format(credits))
|
||||
dataIO.save_json(self.file_path, self.settings)
|
||||
|
||||
# What would I ever do without stackoverflow?
|
||||
def display_time(self, seconds, granularity=2):
|
||||
intervals = ( # Source: http://stackoverflow.com/a/24542445
|
||||
('weeks', 604800), # 60 * 60 * 24 * 7
|
||||
('days', 86400), # 60 * 60 * 24
|
||||
('hours', 3600), # 60 * 60
|
||||
('minutes', 60),
|
||||
('seconds', 1),
|
||||
)
|
||||
|
||||
result = []
|
||||
|
||||
for name, count in intervals:
|
||||
value = seconds // count
|
||||
if value:
|
||||
seconds -= value * count
|
||||
if value == 1:
|
||||
name = name.rstrip('s')
|
||||
result.append("{} {}".format(value, name))
|
||||
return ', '.join(result[:granularity])
|
||||
|
||||
|
||||
def check_folders():
|
||||
if not os.path.exists("data/economy"):
|
||||
print("Creating data/economy folder...")
|
||||
os.makedirs("data/economy")
|
||||
|
||||
|
||||
def check_files():
|
||||
|
||||
f = "data/economy/settings.json"
|
||||
if not dataIO.is_valid_json(f):
|
||||
print("Creating default economy's settings.json...")
|
||||
dataIO.save_json(f, {})
|
||||
|
||||
f = "data/economy/bank.json"
|
||||
if not dataIO.is_valid_json(f):
|
||||
print("Creating empty bank.json...")
|
||||
dataIO.save_json(f, {})
|
||||
|
||||
|
||||
def setup(bot):
|
||||
global logger
|
||||
check_folders()
|
||||
check_files()
|
||||
logger = logging.getLogger("red.economy")
|
||||
if logger.level == 0:
|
||||
# Prevents the logger from being loaded again in case of module reload
|
||||
logger.setLevel(logging.INFO)
|
||||
handler = logging.FileHandler(
|
||||
filename='data/economy/economy.log', encoding='utf-8', mode='a')
|
||||
handler.setFormatter(logging.Formatter(
|
||||
'%(asctime)s %(message)s', datefmt="[%d/%m/%Y %H:%M]"))
|
||||
logger.addHandler(handler)
|
||||
bot.add_cog(Economy(bot))
|
||||
|
|
@ -0,0 +1,433 @@
|
|||
import discord
|
||||
from discord.ext import commands
|
||||
from .utils.chat_formatting import escape_mass_mentions, italics, pagify
|
||||
from random import randint
|
||||
from random import choice
|
||||
from enum import Enum
|
||||
from urllib.parse import quote_plus
|
||||
import datetime
|
||||
import time
|
||||
import aiohttp
|
||||
import asyncio
|
||||
|
||||
settings = {"POLL_DURATION" : 60}
|
||||
|
||||
|
||||
class RPS(Enum):
|
||||
rock = "\N{MOYAI}"
|
||||
paper = "\N{PAGE FACING UP}"
|
||||
scissors = "\N{BLACK SCISSORS}"
|
||||
|
||||
|
||||
class RPSParser:
|
||||
def __init__(self, argument):
|
||||
argument = argument.lower()
|
||||
if argument == "rock":
|
||||
self.choice = RPS.rock
|
||||
elif argument == "paper":
|
||||
self.choice = RPS.paper
|
||||
elif argument == "scissors":
|
||||
self.choice = RPS.scissors
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
class General:
|
||||
"""General commands."""
|
||||
|
||||
def __init__(self, bot):
|
||||
self.bot = bot
|
||||
self.stopwatches = {}
|
||||
self.ball = ["As I see it, yes", "It is certain", "It is decidedly so", "Most likely", "Outlook good",
|
||||
"Signs point to yes", "Without a doubt", "Yes", "Yes – definitely", "You may rely on it", "Reply hazy, try again",
|
||||
"Ask again later", "Better not tell you now", "Cannot predict now", "Concentrate and ask again",
|
||||
"Don't count on it", "My reply is no", "My sources say no", "Outlook not so good", "Very doubtful"]
|
||||
self.poll_sessions = []
|
||||
|
||||
@commands.command(hidden=True)
|
||||
async def ping(self):
|
||||
"""Pong."""
|
||||
await self.bot.say("Pong.")
|
||||
|
||||
@commands.command()
|
||||
async def choose(self, *choices):
|
||||
"""Chooses between multiple choices.
|
||||
|
||||
To denote multiple choices, you should use double quotes.
|
||||
"""
|
||||
choices = [escape_mass_mentions(c) for c in choices]
|
||||
if len(choices) < 2:
|
||||
await self.bot.say('Not enough choices to pick from.')
|
||||
else:
|
||||
await self.bot.say(choice(choices))
|
||||
|
||||
@commands.command(pass_context=True)
|
||||
async def roll(self, ctx, number : int = 100):
|
||||
"""Rolls random number (between 1 and user choice)
|
||||
|
||||
Defaults to 100.
|
||||
"""
|
||||
author = ctx.message.author
|
||||
if number > 1:
|
||||
n = randint(1, number)
|
||||
await self.bot.say("{} :game_die: {} :game_die:".format(author.mention, n))
|
||||
else:
|
||||
await self.bot.say("{} Maybe higher than 1? ;P".format(author.mention))
|
||||
|
||||
@commands.command(pass_context=True)
|
||||
async def flip(self, ctx, user : discord.Member=None):
|
||||
"""Flips a coin... or a user.
|
||||
|
||||
Defaults to coin.
|
||||
"""
|
||||
if user != None:
|
||||
msg = ""
|
||||
if user.id == self.bot.user.id:
|
||||
user = ctx.message.author
|
||||
msg = "Nice try. You think this is funny? How about *this* instead:\n\n"
|
||||
char = "abcdefghijklmnopqrstuvwxyz"
|
||||
tran = "ɐqɔpǝɟƃɥᴉɾʞlɯuodbɹsʇnʌʍxʎz"
|
||||
table = str.maketrans(char, tran)
|
||||
name = user.display_name.translate(table)
|
||||
char = char.upper()
|
||||
tran = "∀qƆpƎℲפHIſʞ˥WNOԀQᴚS┴∩ΛMX⅄Z"
|
||||
table = str.maketrans(char, tran)
|
||||
name = name.translate(table)
|
||||
await self.bot.say(msg + "(╯°□°)╯︵ " + name[::-1])
|
||||
else:
|
||||
await self.bot.say("*flips a coin and... " + choice(["HEADS!*", "TAILS!*"]))
|
||||
|
||||
@commands.command(pass_context=True)
|
||||
async def rps(self, ctx, your_choice : RPSParser):
|
||||
"""Play rock paper scissors"""
|
||||
author = ctx.message.author
|
||||
player_choice = your_choice.choice
|
||||
red_choice = choice((RPS.rock, RPS.paper, RPS.scissors))
|
||||
cond = {
|
||||
(RPS.rock, RPS.paper) : False,
|
||||
(RPS.rock, RPS.scissors) : True,
|
||||
(RPS.paper, RPS.rock) : True,
|
||||
(RPS.paper, RPS.scissors) : False,
|
||||
(RPS.scissors, RPS.rock) : False,
|
||||
(RPS.scissors, RPS.paper) : True
|
||||
}
|
||||
|
||||
if red_choice == player_choice:
|
||||
outcome = None # Tie
|
||||
else:
|
||||
outcome = cond[(player_choice, red_choice)]
|
||||
|
||||
if outcome is True:
|
||||
await self.bot.say("{} You win {}!"
|
||||
"".format(red_choice.value, author.mention))
|
||||
elif outcome is False:
|
||||
await self.bot.say("{} You lose {}!"
|
||||
"".format(red_choice.value, author.mention))
|
||||
else:
|
||||
await self.bot.say("{} We're square {}!"
|
||||
"".format(red_choice.value, author.mention))
|
||||
|
||||
@commands.command(name="8", aliases=["8ball"])
|
||||
async def _8ball(self, *, question : str):
|
||||
"""Ask 8 ball a question
|
||||
|
||||
Question must end with a question mark.
|
||||
"""
|
||||
if question.endswith("?") and question != "?":
|
||||
await self.bot.say("`" + choice(self.ball) + "`")
|
||||
else:
|
||||
await self.bot.say("That doesn't look like a question.")
|
||||
|
||||
@commands.command(aliases=["sw"], pass_context=True)
|
||||
async def stopwatch(self, ctx):
|
||||
"""Starts/stops stopwatch"""
|
||||
author = ctx.message.author
|
||||
if not author.id in self.stopwatches:
|
||||
self.stopwatches[author.id] = int(time.perf_counter())
|
||||
await self.bot.say(author.mention + " Stopwatch started!")
|
||||
else:
|
||||
tmp = abs(self.stopwatches[author.id] - int(time.perf_counter()))
|
||||
tmp = str(datetime.timedelta(seconds=tmp))
|
||||
await self.bot.say(author.mention + " Stopwatch stopped! Time: **" + tmp + "**")
|
||||
self.stopwatches.pop(author.id, None)
|
||||
|
||||
@commands.command()
|
||||
async def lmgtfy(self, *, search_terms : str):
|
||||
"""Creates a lmgtfy link"""
|
||||
search_terms = escape_mass_mentions(search_terms.replace(" ", "+"))
|
||||
await self.bot.say("https://lmgtfy.com/?q={}".format(search_terms))
|
||||
|
||||
@commands.command(no_pm=True, hidden=True)
|
||||
async def hug(self, user : discord.Member, intensity : int=1):
|
||||
"""Because everyone likes hugs
|
||||
|
||||
Up to 10 intensity levels."""
|
||||
name = italics(user.display_name)
|
||||
if intensity <= 0:
|
||||
msg = "(っ˘̩╭╮˘̩)っ" + name
|
||||
elif intensity <= 3:
|
||||
msg = "(っ´▽`)っ" + name
|
||||
elif intensity <= 6:
|
||||
msg = "╰(*´︶`*)╯" + name
|
||||
elif intensity <= 9:
|
||||
msg = "(つ≧▽≦)つ" + name
|
||||
elif intensity >= 10:
|
||||
msg = "(づ ̄ ³ ̄)づ{} ⊂(´・ω・`⊂)".format(name)
|
||||
await self.bot.say(msg)
|
||||
|
||||
@commands.command(pass_context=True, no_pm=True)
|
||||
async def userinfo(self, ctx, *, user: discord.Member=None):
|
||||
"""Shows users's informations"""
|
||||
author = ctx.message.author
|
||||
server = ctx.message.server
|
||||
|
||||
if not user:
|
||||
user = author
|
||||
|
||||
roles = [x.name for x in user.roles if x.name != "@everyone"]
|
||||
|
||||
joined_at = self.fetch_joined_at(user, server)
|
||||
since_created = (ctx.message.timestamp - user.created_at).days
|
||||
since_joined = (ctx.message.timestamp - joined_at).days
|
||||
user_joined = joined_at.strftime("%d %b %Y %H:%M")
|
||||
user_created = user.created_at.strftime("%d %b %Y %H:%M")
|
||||
member_number = sorted(server.members,
|
||||
key=lambda m: m.joined_at).index(user) + 1
|
||||
|
||||
created_on = "{}\n({} days ago)".format(user_created, since_created)
|
||||
joined_on = "{}\n({} days ago)".format(user_joined, since_joined)
|
||||
|
||||
game = "Chilling in {} status".format(user.status)
|
||||
|
||||
if user.game is None:
|
||||
pass
|
||||
elif user.game.url is None:
|
||||
game = "Playing {}".format(user.game)
|
||||
else:
|
||||
game = "Streaming: [{}]({})".format(user.game, user.game.url)
|
||||
|
||||
if roles:
|
||||
roles = sorted(roles, key=[x.name for x in server.role_hierarchy
|
||||
if x.name != "@everyone"].index)
|
||||
roles = ", ".join(roles)
|
||||
else:
|
||||
roles = "None"
|
||||
|
||||
data = discord.Embed(description=game, colour=user.colour)
|
||||
data.add_field(name="Joined Discord on", value=created_on)
|
||||
data.add_field(name="Joined this server on", value=joined_on)
|
||||
data.add_field(name="Roles", value=roles, inline=False)
|
||||
data.set_footer(text="Member #{} | User ID:{}"
|
||||
"".format(member_number, user.id))
|
||||
|
||||
name = str(user)
|
||||
name = " ~ ".join((name, user.nick)) if user.nick else name
|
||||
|
||||
if user.avatar_url:
|
||||
data.set_author(name=name, url=user.avatar_url)
|
||||
data.set_thumbnail(url=user.avatar_url)
|
||||
else:
|
||||
data.set_author(name=name)
|
||||
|
||||
try:
|
||||
await self.bot.say(embed=data)
|
||||
except discord.HTTPException:
|
||||
await self.bot.say("I need the `Embed links` permission "
|
||||
"to send this")
|
||||
|
||||
@commands.command(pass_context=True, no_pm=True)
|
||||
async def serverinfo(self, ctx):
|
||||
"""Shows server's informations"""
|
||||
server = ctx.message.server
|
||||
online = len([m.status for m in server.members
|
||||
if m.status == discord.Status.online or
|
||||
m.status == discord.Status.idle])
|
||||
total_users = len(server.members)
|
||||
text_channels = len([x for x in server.channels
|
||||
if x.type == discord.ChannelType.text])
|
||||
voice_channels = len(server.channels) - text_channels
|
||||
passed = (ctx.message.timestamp - server.created_at).days
|
||||
created_at = ("Since {}. That's over {} days ago!"
|
||||
"".format(server.created_at.strftime("%d %b %Y %H:%M"),
|
||||
passed))
|
||||
|
||||
colour = ''.join([choice('0123456789ABCDEF') for x in range(6)])
|
||||
colour = int(colour, 16)
|
||||
|
||||
data = discord.Embed(
|
||||
description=created_at,
|
||||
colour=discord.Colour(value=colour))
|
||||
data.add_field(name="Region", value=str(server.region))
|
||||
data.add_field(name="Users", value="{}/{}".format(online, total_users))
|
||||
data.add_field(name="Text Channels", value=text_channels)
|
||||
data.add_field(name="Voice Channels", value=voice_channels)
|
||||
data.add_field(name="Roles", value=len(server.roles))
|
||||
data.add_field(name="Owner", value=str(server.owner))
|
||||
data.set_footer(text="Server ID: " + server.id)
|
||||
|
||||
if server.icon_url:
|
||||
data.set_author(name=server.name, url=server.icon_url)
|
||||
data.set_thumbnail(url=server.icon_url)
|
||||
else:
|
||||
data.set_author(name=server.name)
|
||||
|
||||
try:
|
||||
await self.bot.say(embed=data)
|
||||
except discord.HTTPException:
|
||||
await self.bot.say("I need the `Embed links` permission "
|
||||
"to send this")
|
||||
|
||||
@commands.command()
|
||||
async def urban(self, *, search_terms : str, definition_number : int=1):
|
||||
"""Urban Dictionary search
|
||||
|
||||
Definition number must be between 1 and 10"""
|
||||
def encode(s):
|
||||
return quote_plus(s, encoding='utf-8', errors='replace')
|
||||
|
||||
# definition_number is just there to show up in the help
|
||||
# all this mess is to avoid forcing double quotes on the user
|
||||
|
||||
search_terms = search_terms.split(" ")
|
||||
try:
|
||||
if len(search_terms) > 1:
|
||||
pos = int(search_terms[-1]) - 1
|
||||
search_terms = search_terms[:-1]
|
||||
else:
|
||||
pos = 0
|
||||
if pos not in range(0, 11): # API only provides the
|
||||
pos = 0 # top 10 definitions
|
||||
except ValueError:
|
||||
pos = 0
|
||||
|
||||
search_terms = "+".join([encode(s) for s in search_terms])
|
||||
url = "http://api.urbandictionary.com/v0/define?term=" + search_terms
|
||||
try:
|
||||
async with aiohttp.get(url) as r:
|
||||
result = await r.json()
|
||||
if result["list"]:
|
||||
definition = result['list'][pos]['definition']
|
||||
example = result['list'][pos]['example']
|
||||
defs = len(result['list'])
|
||||
msg = ("**Definition #{} out of {}:\n**{}\n\n"
|
||||
"**Example:\n**{}".format(pos+1, defs, definition,
|
||||
example))
|
||||
msg = pagify(msg, ["\n"])
|
||||
for page in msg:
|
||||
await self.bot.say(page)
|
||||
else:
|
||||
await self.bot.say("Your search terms gave no results.")
|
||||
except IndexError:
|
||||
await self.bot.say("There is no definition #{}".format(pos+1))
|
||||
except:
|
||||
await self.bot.say("Error.")
|
||||
|
||||
@commands.command(pass_context=True, no_pm=True)
|
||||
async def poll(self, ctx, *text):
|
||||
"""Starts/stops a poll
|
||||
|
||||
Usage example:
|
||||
poll Is this a poll?;Yes;No;Maybe
|
||||
poll stop"""
|
||||
message = ctx.message
|
||||
if len(text) == 1:
|
||||
if text[0].lower() == "stop":
|
||||
await self.endpoll(message)
|
||||
return
|
||||
if not self.getPollByChannel(message):
|
||||
check = " ".join(text).lower()
|
||||
if "@everyone" in check or "@here" in check:
|
||||
await self.bot.say("Nice try.")
|
||||
return
|
||||
p = NewPoll(message, " ".join(text), self)
|
||||
if p.valid:
|
||||
self.poll_sessions.append(p)
|
||||
await p.start()
|
||||
else:
|
||||
await self.bot.say("poll question;option1;option2 (...)")
|
||||
else:
|
||||
await self.bot.say("A poll is already ongoing in this channel.")
|
||||
|
||||
async def endpoll(self, message):
|
||||
if self.getPollByChannel(message):
|
||||
p = self.getPollByChannel(message)
|
||||
if p.author == message.author.id: # or isMemberAdmin(message)
|
||||
await self.getPollByChannel(message).endPoll()
|
||||
else:
|
||||
await self.bot.say("Only admins and the author can stop the poll.")
|
||||
else:
|
||||
await self.bot.say("There's no poll ongoing in this channel.")
|
||||
|
||||
def getPollByChannel(self, message):
|
||||
for poll in self.poll_sessions:
|
||||
if poll.channel == message.channel:
|
||||
return poll
|
||||
return False
|
||||
|
||||
async def check_poll_votes(self, message):
|
||||
if message.author.id != self.bot.user.id:
|
||||
if self.getPollByChannel(message):
|
||||
self.getPollByChannel(message).checkAnswer(message)
|
||||
|
||||
def fetch_joined_at(self, user, server):
|
||||
"""Just a special case for someone special :^)"""
|
||||
if user.id == "96130341705637888" and server.id == "133049272517001216":
|
||||
return datetime.datetime(2016, 1, 10, 6, 8, 4, 443000)
|
||||
else:
|
||||
return user.joined_at
|
||||
|
||||
class NewPoll():
|
||||
def __init__(self, message, text, main):
|
||||
self.channel = message.channel
|
||||
self.author = message.author.id
|
||||
self.client = main.bot
|
||||
self.poll_sessions = main.poll_sessions
|
||||
msg = [ans.strip() for ans in text.split(";")]
|
||||
if len(msg) < 2: # Needs at least one question and 2 choices
|
||||
self.valid = False
|
||||
return None
|
||||
else:
|
||||
self.valid = True
|
||||
self.already_voted = []
|
||||
self.question = msg[0]
|
||||
msg.remove(self.question)
|
||||
self.answers = {}
|
||||
i = 1
|
||||
for answer in msg: # {id : {answer, votes}}
|
||||
self.answers[i] = {"ANSWER" : answer, "VOTES" : 0}
|
||||
i += 1
|
||||
|
||||
async def start(self):
|
||||
msg = "**POLL STARTED!**\n\n{}\n\n".format(self.question)
|
||||
for id, data in self.answers.items():
|
||||
msg += "{}. *{}*\n".format(id, data["ANSWER"])
|
||||
msg += "\nType the number to vote!"
|
||||
await self.client.send_message(self.channel, msg)
|
||||
await asyncio.sleep(settings["POLL_DURATION"])
|
||||
if self.valid:
|
||||
await self.endPoll()
|
||||
|
||||
async def endPoll(self):
|
||||
self.valid = False
|
||||
msg = "**POLL ENDED!**\n\n{}\n\n".format(self.question)
|
||||
for data in self.answers.values():
|
||||
msg += "*{}* - {} votes\n".format(data["ANSWER"], str(data["VOTES"]))
|
||||
await self.client.send_message(self.channel, msg)
|
||||
self.poll_sessions.remove(self)
|
||||
|
||||
def checkAnswer(self, message):
|
||||
try:
|
||||
i = int(message.content)
|
||||
if i in self.answers.keys():
|
||||
if message.author.id not in self.already_voted:
|
||||
data = self.answers[i]
|
||||
data["VOTES"] += 1
|
||||
self.answers[i] = data
|
||||
self.already_voted.append(message.author.id)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def setup(bot):
|
||||
n = General(bot)
|
||||
bot.add_listener(n.check_poll_votes, "on_message")
|
||||
bot.add_cog(n)
|
||||
|
|
@ -0,0 +1,168 @@
|
|||
from discord.ext import commands
|
||||
from random import choice, shuffle
|
||||
import aiohttp
|
||||
import functools
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
from imgurpython import ImgurClient
|
||||
except:
|
||||
ImgurClient = False
|
||||
|
||||
CLIENT_ID = "1fd3ef04daf8cab"
|
||||
CLIENT_SECRET = "f963e574e8e3c17993c933af4f0522e1dc01e230"
|
||||
GIPHY_API_KEY = "dc6zaTOxFJmzC"
|
||||
|
||||
|
||||
class Image:
|
||||
"""Image related commands."""
|
||||
|
||||
def __init__(self, bot):
|
||||
self.bot = bot
|
||||
self.imgur = ImgurClient(CLIENT_ID, CLIENT_SECRET)
|
||||
|
||||
@commands.group(name="imgur", no_pm=True, pass_context=True)
|
||||
async def _imgur(self, ctx):
|
||||
"""Retrieves pictures from imgur"""
|
||||
if ctx.invoked_subcommand is None:
|
||||
await self.bot.send_cmd_help(ctx)
|
||||
|
||||
@_imgur.command(pass_context=True, name="random")
|
||||
async def imgur_random(self, ctx, *, term: str=None):
|
||||
"""Retrieves a random image from Imgur
|
||||
|
||||
Search terms can be specified"""
|
||||
if term is None:
|
||||
task = functools.partial(self.imgur.gallery_random, page=0)
|
||||
else:
|
||||
task = functools.partial(self.imgur.gallery_search, term,
|
||||
advanced=None, sort='time',
|
||||
window='all', page=0)
|
||||
task = self.bot.loop.run_in_executor(None, task)
|
||||
|
||||
try:
|
||||
results = await asyncio.wait_for(task, timeout=10)
|
||||
except asyncio.TimeoutError:
|
||||
await self.bot.say("Error: request timed out")
|
||||
else:
|
||||
if results:
|
||||
item = choice(results)
|
||||
link = item.gifv if hasattr(item, "gifv") else item.link
|
||||
await self.bot.say(link)
|
||||
else:
|
||||
await self.bot.say("Your search terms gave no results.")
|
||||
|
||||
@_imgur.command(pass_context=True, name="search")
|
||||
async def imgur_search(self, ctx, *, term: str):
|
||||
"""Searches Imgur for the specified term and returns up to 3 results"""
|
||||
task = functools.partial(self.imgur.gallery_search, term,
|
||||
advanced=None, sort='time',
|
||||
window='all', page=0)
|
||||
task = self.bot.loop.run_in_executor(None, task)
|
||||
|
||||
try:
|
||||
results = await asyncio.wait_for(task, timeout=10)
|
||||
except asyncio.TimeoutError:
|
||||
await self.bot.say("Error: request timed out")
|
||||
else:
|
||||
if results:
|
||||
shuffle(results)
|
||||
msg = "Search results...\n"
|
||||
for r in results[:3]:
|
||||
msg += r.gifv if hasattr(r, "gifv") else r.link
|
||||
msg += "\n"
|
||||
await self.bot.say(msg)
|
||||
else:
|
||||
await self.bot.say("Your search terms gave no results.")
|
||||
|
||||
@_imgur.command(pass_context=True, name="subreddit")
|
||||
async def imgur_subreddit(self, ctx, subreddit: str, sort_type: str="top", window: str="day"):
|
||||
"""Gets images from the specified subreddit section
|
||||
|
||||
Sort types: new, top
|
||||
Time windows: day, week, month, year, all"""
|
||||
sort_type = sort_type.lower()
|
||||
|
||||
if sort_type not in ("new", "top"):
|
||||
await self.bot.say("Only 'new' and 'top' are a valid sort type.")
|
||||
return
|
||||
elif window not in ("day", "week", "month", "year", "all"):
|
||||
await self.bot.send_cmd_help(ctx)
|
||||
return
|
||||
|
||||
if sort_type == "new":
|
||||
sort = "time"
|
||||
elif sort_type == "top":
|
||||
sort = "top"
|
||||
|
||||
links = []
|
||||
|
||||
task = functools.partial(self.imgur.subreddit_gallery, subreddit,
|
||||
sort=sort, window=window, page=0)
|
||||
task = self.bot.loop.run_in_executor(None, task)
|
||||
try:
|
||||
items = await asyncio.wait_for(task, timeout=10)
|
||||
except asyncio.TimeoutError:
|
||||
await self.bot.say("Error: request timed out")
|
||||
return
|
||||
|
||||
for item in items[:3]:
|
||||
link = item.gifv if hasattr(item, "gifv") else item.link
|
||||
links.append("{}\n{}".format(item.title, link))
|
||||
|
||||
if links:
|
||||
await self.bot.say("\n".join(links))
|
||||
else:
|
||||
await self.bot.say("No results found.")
|
||||
|
||||
@commands.command(pass_context=True, no_pm=True)
|
||||
async def gif(self, ctx, *keywords):
|
||||
"""Retrieves first search result from giphy"""
|
||||
if keywords:
|
||||
keywords = "+".join(keywords)
|
||||
else:
|
||||
await self.bot.send_cmd_help(ctx)
|
||||
return
|
||||
|
||||
url = ("http://api.giphy.com/v1/gifs/search?&api_key={}&q={}"
|
||||
"".format(GIPHY_API_KEY, keywords))
|
||||
|
||||
async with aiohttp.get(url) as r:
|
||||
result = await r.json()
|
||||
if r.status == 200:
|
||||
if result["data"]:
|
||||
await self.bot.say(result["data"][0]["url"])
|
||||
else:
|
||||
await self.bot.say("No results found.")
|
||||
else:
|
||||
await self.bot.say("Error contacting the API")
|
||||
|
||||
@commands.command(pass_context=True, no_pm=True)
|
||||
async def gifr(self, ctx, *keywords):
|
||||
"""Retrieves a random gif from a giphy search"""
|
||||
if keywords:
|
||||
keywords = "+".join(keywords)
|
||||
else:
|
||||
await self.bot.send_cmd_help(ctx)
|
||||
return
|
||||
|
||||
url = ("http://api.giphy.com/v1/gifs/random?&api_key={}&tag={}"
|
||||
"".format(GIPHY_API_KEY, keywords))
|
||||
|
||||
async with aiohttp.get(url) as r:
|
||||
result = await r.json()
|
||||
if r.status == 200:
|
||||
if result["data"]:
|
||||
await self.bot.say(result["data"]["url"])
|
||||
else:
|
||||
await self.bot.say("No results found.")
|
||||
else:
|
||||
await self.bot.say("Error contacting the API")
|
||||
|
||||
|
||||
def setup(bot):
|
||||
if ImgurClient is False:
|
||||
raise RuntimeError("You need the imgurpython module to use this.\n"
|
||||
"pip3 install imgurpython")
|
||||
|
||||
bot.add_cog(Image(bot))
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,689 @@
|
|||
from discord.ext import commands
|
||||
from .utils.dataIO import dataIO
|
||||
from .utils.chat_formatting import escape_mass_mentions
|
||||
from .utils import checks
|
||||
from collections import defaultdict
|
||||
from string import ascii_letters
|
||||
from random import choice
|
||||
import discord
|
||||
import os
|
||||
import re
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import logging
|
||||
import json
|
||||
|
||||
|
||||
class StreamsError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class StreamNotFound(StreamsError):
|
||||
pass
|
||||
|
||||
|
||||
class APIError(StreamsError):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidCredentials(StreamsError):
|
||||
pass
|
||||
|
||||
|
||||
class OfflineStream(StreamsError):
|
||||
pass
|
||||
|
||||
|
||||
class Streams:
|
||||
"""Streams
|
||||
|
||||
Alerts for a variety of streaming services"""
|
||||
|
||||
def __init__(self, bot):
|
||||
self.bot = bot
|
||||
self.twitch_streams = dataIO.load_json("data/streams/twitch.json")
|
||||
self.hitbox_streams = dataIO.load_json("data/streams/hitbox.json")
|
||||
self.mixer_streams = dataIO.load_json("data/streams/beam.json")
|
||||
self.picarto_streams = dataIO.load_json("data/streams/picarto.json")
|
||||
settings = dataIO.load_json("data/streams/settings.json")
|
||||
self.settings = defaultdict(dict, settings)
|
||||
self.messages_cache = defaultdict(list)
|
||||
|
||||
@commands.command()
|
||||
async def hitbox(self, stream: str):
|
||||
"""Checks if hitbox stream is online"""
|
||||
stream = escape_mass_mentions(stream)
|
||||
regex = r'^(https?\:\/\/)?(www\.)?(hitbox\.tv\/)'
|
||||
stream = re.sub(regex, '', stream)
|
||||
try:
|
||||
embed = await self.hitbox_online(stream)
|
||||
except OfflineStream:
|
||||
await self.bot.say(stream + " is offline.")
|
||||
except StreamNotFound:
|
||||
await self.bot.say("That stream doesn't exist.")
|
||||
except APIError:
|
||||
await self.bot.say("Error contacting the API.")
|
||||
else:
|
||||
await self.bot.say(embed=embed)
|
||||
|
||||
@commands.command(pass_context=True)
|
||||
async def twitch(self, ctx, stream: str):
|
||||
"""Checks if twitch stream is online"""
|
||||
stream = escape_mass_mentions(stream)
|
||||
regex = r'^(https?\:\/\/)?(www\.)?(twitch\.tv\/)'
|
||||
stream = re.sub(regex, '', stream)
|
||||
try:
|
||||
data = await self.fetch_twitch_ids(stream, raise_if_none=True)
|
||||
embed = await self.twitch_online(data[0]["_id"])
|
||||
except OfflineStream:
|
||||
await self.bot.say(stream + " is offline.")
|
||||
except StreamNotFound:
|
||||
await self.bot.say("That stream doesn't exist.")
|
||||
except APIError:
|
||||
await self.bot.say("Error contacting the API.")
|
||||
except InvalidCredentials:
|
||||
await self.bot.say("Owner: Client-ID is invalid or not set. "
|
||||
"See `{}streamset twitchtoken`"
|
||||
"".format(ctx.prefix))
|
||||
else:
|
||||
await self.bot.say(embed=embed)
|
||||
|
||||
@commands.command()
|
||||
async def mixer(self, stream: str):
|
||||
"""Checks if mixer stream is online"""
|
||||
stream = escape_mass_mentions(stream)
|
||||
regex = r'^(https?\:\/\/)?(www\.)?(mixer\.com\/)'
|
||||
stream = re.sub(regex, '', stream)
|
||||
try:
|
||||
embed = await self.mixer_online(stream)
|
||||
except OfflineStream:
|
||||
await self.bot.say(stream + " is offline.")
|
||||
except StreamNotFound:
|
||||
await self.bot.say("That stream doesn't exist.")
|
||||
except APIError:
|
||||
await self.bot.say("Error contacting the API.")
|
||||
else:
|
||||
await self.bot.say(embed=embed)
|
||||
|
||||
@commands.command()
|
||||
async def picarto(self, stream: str):
|
||||
"""Checks if picarto stream is online"""
|
||||
stream = escape_mass_mentions(stream)
|
||||
regex = r'^(https?\:\/\/)?(www\.)?(picarto\.tv\/)'
|
||||
stream = re.sub(regex, '', stream)
|
||||
try:
|
||||
embed = await self.picarto_online(stream)
|
||||
except OfflineStream:
|
||||
await self.bot.say(stream + " is offline.")
|
||||
except StreamNotFound:
|
||||
await self.bot.say("That stream doesn't exist.")
|
||||
except APIError:
|
||||
await self.bot.say("Error contacting the API.")
|
||||
else:
|
||||
await self.bot.say(embed=embed)
|
||||
|
||||
@commands.group(pass_context=True, no_pm=True)
|
||||
@checks.mod_or_permissions(manage_server=True)
|
||||
async def streamalert(self, ctx):
|
||||
"""Adds/removes stream alerts from the current channel"""
|
||||
if ctx.invoked_subcommand is None:
|
||||
await self.bot.send_cmd_help(ctx)
|
||||
|
||||
@streamalert.command(name="twitch", pass_context=True)
|
||||
async def twitch_alert(self, ctx, stream: str):
|
||||
"""Adds/removes twitch alerts from the current channel"""
|
||||
stream = escape_mass_mentions(stream)
|
||||
regex = r'^(https?\:\/\/)?(www\.)?(twitch\.tv\/)'
|
||||
stream = re.sub(regex, '', stream)
|
||||
channel = ctx.message.channel
|
||||
try:
|
||||
data = await self.fetch_twitch_ids(stream, raise_if_none=True)
|
||||
except StreamNotFound:
|
||||
await self.bot.say("That stream doesn't exist.")
|
||||
return
|
||||
except APIError:
|
||||
await self.bot.say("Error contacting the API.")
|
||||
return
|
||||
except InvalidCredentials:
|
||||
await self.bot.say("Owner: Client-ID is invalid or not set. "
|
||||
"See `{}streamset twitchtoken`"
|
||||
"".format(ctx.prefix))
|
||||
return
|
||||
|
||||
enabled = self.enable_or_disable_if_active(self.twitch_streams,
|
||||
stream,
|
||||
channel,
|
||||
_id=data[0]["_id"])
|
||||
|
||||
if enabled:
|
||||
await self.bot.say("Alert activated. I will notify this channel "
|
||||
"when {} is live.".format(stream))
|
||||
else:
|
||||
await self.bot.say("Alert has been removed from this channel.")
|
||||
|
||||
dataIO.save_json("data/streams/twitch.json", self.twitch_streams)
|
||||
|
||||
@streamalert.command(name="hitbox", pass_context=True)
|
||||
async def hitbox_alert(self, ctx, stream: str):
|
||||
"""Adds/removes hitbox alerts from the current channel"""
|
||||
stream = escape_mass_mentions(stream)
|
||||
regex = r'^(https?\:\/\/)?(www\.)?(hitbox\.tv\/)'
|
||||
stream = re.sub(regex, '', stream)
|
||||
channel = ctx.message.channel
|
||||
try:
|
||||
await self.hitbox_online(stream)
|
||||
except StreamNotFound:
|
||||
await self.bot.say("That stream doesn't exist.")
|
||||
return
|
||||
except APIError:
|
||||
await self.bot.say("Error contacting the API.")
|
||||
return
|
||||
except OfflineStream:
|
||||
pass
|
||||
|
||||
enabled = self.enable_or_disable_if_active(self.hitbox_streams,
|
||||
stream,
|
||||
channel)
|
||||
|
||||
if enabled:
|
||||
await self.bot.say("Alert activated. I will notify this channel "
|
||||
"when {} is live.".format(stream))
|
||||
else:
|
||||
await self.bot.say("Alert has been removed from this channel.")
|
||||
|
||||
dataIO.save_json("data/streams/hitbox.json", self.hitbox_streams)
|
||||
|
||||
@streamalert.command(name="mixer", pass_context=True)
|
||||
async def mixer_alert(self, ctx, stream: str):
|
||||
"""Adds/removes mixer alerts from the current channel"""
|
||||
stream = escape_mass_mentions(stream)
|
||||
regex = r'^(https?\:\/\/)?(www\.)?(mixer\.com\/)'
|
||||
stream = re.sub(regex, '', stream)
|
||||
channel = ctx.message.channel
|
||||
try:
|
||||
await self.mixer_online(stream)
|
||||
except StreamNotFound:
|
||||
await self.bot.say("That stream doesn't exist.")
|
||||
return
|
||||
except APIError:
|
||||
await self.bot.say("Error contacting the API.")
|
||||
return
|
||||
except OfflineStream:
|
||||
pass
|
||||
|
||||
enabled = self.enable_or_disable_if_active(self.mixer_streams,
|
||||
stream,
|
||||
channel)
|
||||
|
||||
if enabled:
|
||||
await self.bot.say("Alert activated. I will notify this channel "
|
||||
"when {} is live.".format(stream))
|
||||
else:
|
||||
await self.bot.say("Alert has been removed from this channel.")
|
||||
|
||||
dataIO.save_json("data/streams/beam.json", self.mixer_streams)
|
||||
|
||||
@streamalert.command(name="picarto", pass_context=True)
|
||||
async def picarto_alert(self, ctx, stream: str):
|
||||
"""Adds/removes picarto alerts from the current channel"""
|
||||
stream = escape_mass_mentions(stream)
|
||||
regex = r'^(https?\:\/\/)?(www\.)?(picarto\.tv\/)'
|
||||
stream = re.sub(regex, '', stream)
|
||||
channel = ctx.message.channel
|
||||
try:
|
||||
await self.picarto_online(stream)
|
||||
except StreamNotFound:
|
||||
await self.bot.say("That stream doesn't exist.")
|
||||
return
|
||||
except APIError:
|
||||
await self.bot.say("Error contacting the API.")
|
||||
return
|
||||
except OfflineStream:
|
||||
pass
|
||||
|
||||
enabled = self.enable_or_disable_if_active(self.picarto_streams,
|
||||
stream,
|
||||
channel)
|
||||
|
||||
if enabled:
|
||||
await self.bot.say("Alert activated. I will notify this channel "
|
||||
"when {} is live.".format(stream))
|
||||
else:
|
||||
await self.bot.say("Alert has been removed from this channel.")
|
||||
|
||||
dataIO.save_json("data/streams/picarto.json", self.picarto_streams)
|
||||
|
||||
@streamalert.command(name="stop", pass_context=True)
|
||||
async def stop_alert(self, ctx):
|
||||
"""Stops all streams alerts in the current channel"""
|
||||
channel = ctx.message.channel
|
||||
|
||||
streams = (
|
||||
self.hitbox_streams,
|
||||
self.twitch_streams,
|
||||
self.mixer_streams,
|
||||
self.picarto_streams
|
||||
)
|
||||
|
||||
for stream_type in streams:
|
||||
to_delete = []
|
||||
|
||||
for s in stream_type:
|
||||
if channel.id in s["CHANNELS"]:
|
||||
s["CHANNELS"].remove(channel.id)
|
||||
if not s["CHANNELS"]:
|
||||
to_delete.append(s)
|
||||
|
||||
for s in to_delete:
|
||||
stream_type.remove(s)
|
||||
|
||||
dataIO.save_json("data/streams/twitch.json", self.twitch_streams)
|
||||
dataIO.save_json("data/streams/hitbox.json", self.hitbox_streams)
|
||||
dataIO.save_json("data/streams/beam.json", self.mixer_streams)
|
||||
dataIO.save_json("data/streams/picarto.json", self.picarto_streams)
|
||||
|
||||
await self.bot.say("There will be no more stream alerts in this "
|
||||
"channel.")
|
||||
|
||||
@commands.group(pass_context=True)
|
||||
async def streamset(self, ctx):
|
||||
"""Stream settings"""
|
||||
if ctx.invoked_subcommand is None:
|
||||
await self.bot.send_cmd_help(ctx)
|
||||
|
||||
@streamset.command()
|
||||
@checks.is_owner()
|
||||
async def twitchtoken(self, token : str):
|
||||
"""Sets the Client-ID for Twitch
|
||||
|
||||
https://blog.twitch.tv/client-id-required-for-kraken-api-calls-afbb8e95f843"""
|
||||
self.settings["TWITCH_TOKEN"] = token
|
||||
dataIO.save_json("data/streams/settings.json", self.settings)
|
||||
await self.bot.say('Twitch Client-ID set.')
|
||||
|
||||
@streamset.command(pass_context=True, no_pm=True)
|
||||
@checks.admin()
|
||||
async def mention(self, ctx, *, mention_type : str):
|
||||
"""Sets mentions for stream alerts
|
||||
|
||||
Types: everyone, here, none"""
|
||||
server = ctx.message.server
|
||||
mention_type = mention_type.lower()
|
||||
|
||||
if mention_type in ("everyone", "here"):
|
||||
self.settings[server.id]["MENTION"] = "@" + mention_type
|
||||
await self.bot.say("When a stream is online @\u200b{} will be "
|
||||
"mentioned.".format(mention_type))
|
||||
elif mention_type == "none":
|
||||
self.settings[server.id]["MENTION"] = ""
|
||||
await self.bot.say("Mentions disabled.")
|
||||
else:
|
||||
await self.bot.send_cmd_help(ctx)
|
||||
|
||||
dataIO.save_json("data/streams/settings.json", self.settings)
|
||||
|
||||
@streamset.command(pass_context=True, no_pm=True)
|
||||
@checks.admin()
|
||||
async def autodelete(self, ctx):
|
||||
"""Toggles automatic notification deletion for streams that go offline"""
|
||||
server = ctx.message.server
|
||||
settings = self.settings[server.id]
|
||||
current = settings.get("AUTODELETE", True)
|
||||
settings["AUTODELETE"] = not current
|
||||
if settings["AUTODELETE"]:
|
||||
await self.bot.say("Notifications will be automatically deleted "
|
||||
"once the stream goes offline.")
|
||||
else:
|
||||
await self.bot.say("Notifications won't be deleted anymore.")
|
||||
|
||||
dataIO.save_json("data/streams/settings.json", self.settings)
|
||||
|
||||
async def hitbox_online(self, stream):
|
||||
url = "https://api.hitbox.tv/media/live/" + stream
|
||||
|
||||
async with aiohttp.get(url) as r:
|
||||
data = await r.json(encoding='utf-8')
|
||||
|
||||
if "livestream" not in data:
|
||||
raise StreamNotFound()
|
||||
elif data["livestream"][0]["media_is_live"] == "0":
|
||||
raise OfflineStream()
|
||||
elif data["livestream"][0]["media_is_live"] == "1":
|
||||
return self.hitbox_embed(data)
|
||||
|
||||
raise APIError()
|
||||
|
||||
async def twitch_online(self, stream):
|
||||
session = aiohttp.ClientSession()
|
||||
url = "https://api.twitch.tv/kraken/streams/" + stream
|
||||
header = {
|
||||
'Client-ID': self.settings.get("TWITCH_TOKEN", ""),
|
||||
'Accept': 'application/vnd.twitchtv.v5+json'
|
||||
}
|
||||
|
||||
async with session.get(url, headers=header) as r:
|
||||
data = await r.json(encoding='utf-8')
|
||||
await session.close()
|
||||
if r.status == 200:
|
||||
if data["stream"] is None:
|
||||
raise OfflineStream()
|
||||
return self.twitch_embed(data)
|
||||
elif r.status == 400:
|
||||
raise InvalidCredentials()
|
||||
elif r.status == 404:
|
||||
raise StreamNotFound()
|
||||
else:
|
||||
raise APIError()
|
||||
|
||||
async def mixer_online(self, stream):
|
||||
url = "https://mixer.com/api/v1/channels/" + stream
|
||||
|
||||
async with aiohttp.get(url) as r:
|
||||
data = await r.json(encoding='utf-8')
|
||||
if r.status == 200:
|
||||
if data["online"] is True:
|
||||
return self.mixer_embed(data)
|
||||
else:
|
||||
raise OfflineStream()
|
||||
elif r.status == 404:
|
||||
raise StreamNotFound()
|
||||
else:
|
||||
raise APIError()
|
||||
|
||||
async def picarto_online(self, stream):
|
||||
url = "https://api.picarto.tv/v1/channel/name/" + stream
|
||||
|
||||
async with aiohttp.get(url) as r:
|
||||
data = await r.text(encoding='utf-8')
|
||||
if r.status == 200:
|
||||
data = json.loads(data)
|
||||
if data["online"] is True:
|
||||
return self.picarto_embed(data)
|
||||
else:
|
||||
raise OfflineStream()
|
||||
elif r.status == 404:
|
||||
raise StreamNotFound()
|
||||
else:
|
||||
raise APIError()
|
||||
|
||||
async def fetch_twitch_ids(self, *streams, raise_if_none=False):
|
||||
def chunks(l):
|
||||
for i in range(0, len(l), 100):
|
||||
yield l[i:i + 100]
|
||||
|
||||
base_url = "https://api.twitch.tv/kraken/users?login="
|
||||
header = {
|
||||
'Client-ID': self.settings.get("TWITCH_TOKEN", ""),
|
||||
'Accept': 'application/vnd.twitchtv.v5+json'
|
||||
}
|
||||
results = []
|
||||
|
||||
for streams_list in chunks(streams):
|
||||
session = aiohttp.ClientSession()
|
||||
url = base_url + ",".join(streams_list)
|
||||
async with session.get(url, headers=header) as r:
|
||||
data = await r.json()
|
||||
if r.status == 200:
|
||||
results.extend(data["users"])
|
||||
elif r.status == 400:
|
||||
raise InvalidCredentials()
|
||||
else:
|
||||
raise APIError()
|
||||
await session.close()
|
||||
|
||||
if not results and raise_if_none:
|
||||
raise StreamNotFound()
|
||||
|
||||
return results
|
||||
|
||||
def twitch_embed(self, data):
|
||||
channel = data["stream"]["channel"]
|
||||
url = channel["url"]
|
||||
logo = channel["logo"]
|
||||
if logo is None:
|
||||
logo = "https://static-cdn.jtvnw.net/jtv_user_pictures/xarth/404_user_70x70.png"
|
||||
status = channel["status"]
|
||||
if not status:
|
||||
status = "Untitled broadcast"
|
||||
embed = discord.Embed(title=status, url=url)
|
||||
embed.set_author(name=channel["display_name"])
|
||||
embed.add_field(name="Followers", value=channel["followers"])
|
||||
embed.add_field(name="Total views", value=channel["views"])
|
||||
embed.set_thumbnail(url=logo)
|
||||
if data["stream"]["preview"]["medium"]:
|
||||
embed.set_image(url=data["stream"]["preview"]["medium"] + self.rnd_attr())
|
||||
if channel["game"]:
|
||||
embed.set_footer(text="Playing: " + channel["game"])
|
||||
embed.color = 0x6441A4
|
||||
return embed
|
||||
|
||||
def hitbox_embed(self, data):
|
||||
base_url = "https://edge.sf.hitbox.tv"
|
||||
livestream = data["livestream"][0]
|
||||
channel = livestream["channel"]
|
||||
url = channel["channel_link"]
|
||||
embed = discord.Embed(title=livestream["media_status"], url=url)
|
||||
embed.set_author(name=livestream["media_name"])
|
||||
embed.add_field(name="Followers", value=channel["followers"])
|
||||
#embed.add_field(name="Views", value=channel["views"])
|
||||
embed.set_thumbnail(url=base_url + channel["user_logo"])
|
||||
if livestream["media_thumbnail"]:
|
||||
embed.set_image(url=base_url + livestream["media_thumbnail"] + self.rnd_attr())
|
||||
embed.set_footer(text="Playing: " + livestream["category_name"])
|
||||
embed.color = 0x98CB00
|
||||
return embed
|
||||
|
||||
def mixer_embed(self, data):
|
||||
default_avatar = ("https://mixer.com/_latest/assets/images/main/"
|
||||
"avatars/default.jpg")
|
||||
user = data["user"]
|
||||
url = "https://mixer.com/" + data["token"]
|
||||
embed = discord.Embed(title=data["name"], url=url)
|
||||
embed.set_author(name=user["username"])
|
||||
embed.add_field(name="Followers", value=data["numFollowers"])
|
||||
embed.add_field(name="Total views", value=data["viewersTotal"])
|
||||
if user["avatarUrl"]:
|
||||
embed.set_thumbnail(url=user["avatarUrl"])
|
||||
else:
|
||||
embed.set_thumbnail(url=default_avatar)
|
||||
if data["thumbnail"]:
|
||||
embed.set_image(url=data["thumbnail"]["url"] + self.rnd_attr())
|
||||
embed.color = 0x4C90F3
|
||||
if data["type"] is not None:
|
||||
embed.set_footer(text="Playing: " + data["type"]["name"])
|
||||
return embed
|
||||
|
||||
def picarto_embed(self, data):
|
||||
avatar = ("https://picarto.tv/user_data/usrimg/{}/dsdefault.jpg{}"
|
||||
"".format(data["name"].lower(), self.rnd_attr()))
|
||||
url = "https://picarto.tv/" + data["name"]
|
||||
thumbnail = ("https://thumb.picarto.tv/thumbnail/{}.jpg"
|
||||
"".format(data["name"]))
|
||||
embed = discord.Embed(title=data["title"], url=url)
|
||||
embed.set_author(name=data["name"])
|
||||
embed.set_image(url=thumbnail + self.rnd_attr())
|
||||
embed.add_field(name="Followers", value=data["followers"])
|
||||
embed.add_field(name="Total views", value=data["viewers_total"])
|
||||
embed.set_thumbnail(url=avatar)
|
||||
embed.color = 0x132332
|
||||
data["tags"] = ", ".join(data["tags"])
|
||||
|
||||
if not data["tags"]:
|
||||
data["tags"] = "None"
|
||||
|
||||
if data["adult"]:
|
||||
data["adult"] = "NSFW | "
|
||||
else:
|
||||
data["adult"] = ""
|
||||
|
||||
embed.color = 0x4C90F3
|
||||
embed.set_footer(text="{adult}Category: {category} | Tags: {tags}"
|
||||
"".format(**data))
|
||||
return embed
|
||||
|
||||
def enable_or_disable_if_active(self, streams, stream, channel, _id=None):
|
||||
"""Returns True if enabled or False if disabled"""
|
||||
for i, s in enumerate(streams):
|
||||
if s["NAME"] != stream:
|
||||
continue
|
||||
|
||||
if channel.id in s["CHANNELS"]:
|
||||
streams[i]["CHANNELS"].remove(channel.id)
|
||||
if not s["CHANNELS"]:
|
||||
streams.remove(s)
|
||||
return False
|
||||
else:
|
||||
streams[i]["CHANNELS"].append(channel.id)
|
||||
return True
|
||||
|
||||
data = {"CHANNELS": [channel.id],
|
||||
"NAME": stream,
|
||||
"ALREADY_ONLINE": False}
|
||||
|
||||
if _id:
|
||||
data["ID"] = _id
|
||||
|
||||
streams.append(data)
|
||||
|
||||
return True
|
||||
|
||||
async def stream_checker(self):
|
||||
CHECK_DELAY = 60
|
||||
|
||||
try:
|
||||
await self._migration_twitch_v5()
|
||||
except InvalidCredentials:
|
||||
print("Error during convertion of twitch usernames to IDs: "
|
||||
"invalid token")
|
||||
except Exception as e:
|
||||
print("Error during convertion of twitch usernames to IDs: "
|
||||
"{}".format(e))
|
||||
|
||||
while self == self.bot.get_cog("Streams"):
|
||||
save = False
|
||||
|
||||
streams = ((self.twitch_streams, self.twitch_online),
|
||||
(self.hitbox_streams, self.hitbox_online),
|
||||
(self.mixer_streams, self.mixer_online),
|
||||
(self.picarto_streams, self.picarto_online))
|
||||
|
||||
for streams_list, parser in streams:
|
||||
if parser == self.twitch_online:
|
||||
_type = "ID"
|
||||
else:
|
||||
_type = "NAME"
|
||||
for stream in streams_list:
|
||||
if _type not in stream:
|
||||
continue
|
||||
key = (parser, stream[_type])
|
||||
try:
|
||||
embed = await parser(stream[_type])
|
||||
except OfflineStream:
|
||||
if stream["ALREADY_ONLINE"]:
|
||||
stream["ALREADY_ONLINE"] = False
|
||||
save = True
|
||||
await self.delete_old_notifications(key)
|
||||
except: # We don't want our task to die
|
||||
continue
|
||||
else:
|
||||
if stream["ALREADY_ONLINE"]:
|
||||
continue
|
||||
save = True
|
||||
stream["ALREADY_ONLINE"] = True
|
||||
messages_sent = []
|
||||
for channel_id in stream["CHANNELS"]:
|
||||
channel = self.bot.get_channel(channel_id)
|
||||
if channel is None:
|
||||
continue
|
||||
mention = self.settings.get(channel.server.id, {}).get("MENTION", "")
|
||||
can_speak = channel.permissions_for(channel.server.me).send_messages
|
||||
message = mention + " {} is live!".format(stream["NAME"])
|
||||
if channel and can_speak:
|
||||
m = await self.bot.send_message(channel, message, embed=embed)
|
||||
messages_sent.append(m)
|
||||
self.messages_cache[key] = messages_sent
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
if save:
|
||||
dataIO.save_json("data/streams/twitch.json", self.twitch_streams)
|
||||
dataIO.save_json("data/streams/hitbox.json", self.hitbox_streams)
|
||||
dataIO.save_json("data/streams/beam.json", self.mixer_streams)
|
||||
dataIO.save_json("data/streams/picarto.json", self.picarto_streams)
|
||||
|
||||
await asyncio.sleep(CHECK_DELAY)
|
||||
|
||||
async def delete_old_notifications(self, key):
|
||||
for message in self.messages_cache[key]:
|
||||
server = message.server
|
||||
settings = self.settings.get(server.id, {})
|
||||
is_enabled = settings.get("AUTODELETE", True)
|
||||
try:
|
||||
if is_enabled:
|
||||
await self.bot.delete_message(message)
|
||||
except:
|
||||
pass
|
||||
|
||||
del self.messages_cache[key]
|
||||
|
||||
def rnd_attr(self):
|
||||
"""Avoids Discord's caching"""
|
||||
return "?rnd=" + "".join([choice(ascii_letters) for i in range(6)])
|
||||
|
||||
async def _migration_twitch_v5(self):
|
||||
# Migration of old twitch streams to API v5
|
||||
to_convert = []
|
||||
for stream in self.twitch_streams:
|
||||
if "ID" not in stream:
|
||||
to_convert.append(stream["NAME"])
|
||||
|
||||
if not to_convert:
|
||||
return
|
||||
|
||||
results = await self.fetch_twitch_ids(*to_convert)
|
||||
|
||||
for stream in self.twitch_streams:
|
||||
for result in results:
|
||||
if stream["NAME"].lower() == result["name"].lower():
|
||||
stream["ID"] = result["_id"]
|
||||
|
||||
# We might as well delete the invalid / renamed ones
|
||||
self.twitch_streams = [s for s in self.twitch_streams if "ID" in s]
|
||||
|
||||
dataIO.save_json("data/streams/twitch.json", self.twitch_streams)
|
||||
|
||||
|
||||
def check_folders():
|
||||
if not os.path.exists("data/streams"):
|
||||
print("Creating data/streams folder...")
|
||||
os.makedirs("data/streams")
|
||||
|
||||
|
||||
def check_files():
|
||||
stream_files = (
|
||||
"twitch.json",
|
||||
"hitbox.json",
|
||||
"beam.json",
|
||||
"picarto.json"
|
||||
)
|
||||
|
||||
for filename in stream_files:
|
||||
if not dataIO.is_valid_json("data/streams/" + filename):
|
||||
print("Creating empty {}...".format(filename))
|
||||
dataIO.save_json("data/streams/" + filename, [])
|
||||
|
||||
f = "data/streams/settings.json"
|
||||
if not dataIO.is_valid_json(f):
|
||||
print("Creating empty settings.json...")
|
||||
dataIO.save_json(f, {})
|
||||
|
||||
|
||||
def setup(bot):
|
||||
logger = logging.getLogger('aiohttp.client')
|
||||
logger.setLevel(50) # Stops warning spam
|
||||
check_folders()
|
||||
check_files()
|
||||
n = Streams(bot)
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.create_task(n.stream_checker())
|
||||
bot.add_cog(n)
|
||||
|
|
@ -0,0 +1,332 @@
|
|||
from discord.ext import commands
|
||||
from random import choice
|
||||
from .utils.dataIO import dataIO
|
||||
from .utils import checks
|
||||
from .utils.chat_formatting import box
|
||||
from collections import Counter, defaultdict, namedtuple
|
||||
import discord
|
||||
import time
|
||||
import os
|
||||
import asyncio
|
||||
import chardet
|
||||
|
||||
DEFAULTS = {"MAX_SCORE" : 10,
|
||||
"TIMEOUT" : 120,
|
||||
"DELAY" : 15,
|
||||
"BOT_PLAYS" : False,
|
||||
"REVEAL_ANSWER": True}
|
||||
|
||||
TriviaLine = namedtuple("TriviaLine", "question answers")
|
||||
|
||||
|
||||
class Trivia:
|
||||
"""General commands."""
|
||||
def __init__(self, bot):
|
||||
self.bot = bot
|
||||
self.trivia_sessions = []
|
||||
self.file_path = "data/trivia/settings.json"
|
||||
settings = dataIO.load_json(self.file_path)
|
||||
self.settings = defaultdict(lambda: DEFAULTS.copy(), settings)
|
||||
|
||||
@commands.group(pass_context=True, no_pm=True)
|
||||
@checks.mod_or_permissions(administrator=True)
|
||||
async def triviaset(self, ctx):
|
||||
"""Change trivia settings"""
|
||||
server = ctx.message.server
|
||||
if ctx.invoked_subcommand is None:
|
||||
settings = self.settings[server.id]
|
||||
msg = box("Red gains points: {BOT_PLAYS}\n"
|
||||
"Seconds to answer: {DELAY}\n"
|
||||
"Points to win: {MAX_SCORE}\n"
|
||||
"Reveal answer on timeout: {REVEAL_ANSWER}\n"
|
||||
"".format(**settings))
|
||||
msg += "\nSee {}help triviaset to edit the settings".format(ctx.prefix)
|
||||
await self.bot.say(msg)
|
||||
|
||||
@triviaset.command(pass_context=True)
|
||||
async def maxscore(self, ctx, score : int):
|
||||
"""Points required to win"""
|
||||
server = ctx.message.server
|
||||
if score > 0:
|
||||
self.settings[server.id]["MAX_SCORE"] = score
|
||||
self.save_settings()
|
||||
await self.bot.say("Points required to win set to {}".format(score))
|
||||
else:
|
||||
await self.bot.say("Score must be superior to 0.")
|
||||
|
||||
@triviaset.command(pass_context=True)
|
||||
async def timelimit(self, ctx, seconds : int):
|
||||
"""Maximum seconds to answer"""
|
||||
server = ctx.message.server
|
||||
if seconds > 4:
|
||||
self.settings[server.id]["DELAY"] = seconds
|
||||
self.save_settings()
|
||||
await self.bot.say("Maximum seconds to answer set to {}".format(seconds))
|
||||
else:
|
||||
await self.bot.say("Seconds must be at least 5.")
|
||||
|
||||
@triviaset.command(pass_context=True)
|
||||
async def botplays(self, ctx):
|
||||
"""Red gains points"""
|
||||
server = ctx.message.server
|
||||
if self.settings[server.id]["BOT_PLAYS"]:
|
||||
self.settings[server.id]["BOT_PLAYS"] = False
|
||||
await self.bot.say("Alright, I won't embarass you at trivia anymore.")
|
||||
else:
|
||||
self.settings[server.id]["BOT_PLAYS"] = True
|
||||
await self.bot.say("I'll gain a point everytime you don't answer in time.")
|
||||
self.save_settings()
|
||||
|
||||
@triviaset.command(pass_context=True)
|
||||
async def revealanswer(self, ctx):
|
||||
"""Reveals answer to the question on timeout"""
|
||||
server = ctx.message.server
|
||||
if self.settings[server.id]["REVEAL_ANSWER"]:
|
||||
self.settings[server.id]["REVEAL_ANSWER"] = False
|
||||
await self.bot.say("I won't reveal the answer to the questions anymore.")
|
||||
else:
|
||||
self.settings[server.id]["REVEAL_ANSWER"] = True
|
||||
await self.bot.say("I'll reveal the answer if no one knows it.")
|
||||
self.save_settings()
|
||||
|
||||
@commands.group(pass_context=True, invoke_without_command=True, no_pm=True)
|
||||
async def trivia(self, ctx, list_name: str):
|
||||
"""Start a trivia session with the specified list"""
|
||||
message = ctx.message
|
||||
server = message.server
|
||||
session = self.get_trivia_by_channel(message.channel)
|
||||
if not session:
|
||||
try:
|
||||
trivia_list = self.parse_trivia_list(list_name)
|
||||
except FileNotFoundError:
|
||||
await self.bot.say("That trivia list doesn't exist.")
|
||||
except Exception as e:
|
||||
print(e)
|
||||
await self.bot.say("Error loading the trivia list.")
|
||||
else:
|
||||
settings = self.settings[server.id]
|
||||
t = TriviaSession(self.bot, trivia_list, message, settings)
|
||||
self.trivia_sessions.append(t)
|
||||
await t.new_question()
|
||||
else:
|
||||
await self.bot.say("A trivia session is already ongoing in this channel.")
|
||||
|
||||
@trivia.group(name="stop", pass_context=True, no_pm=True)
|
||||
async def trivia_stop(self, ctx):
|
||||
"""Stops an ongoing trivia session"""
|
||||
author = ctx.message.author
|
||||
server = author.server
|
||||
admin_role = self.bot.settings.get_server_admin(server)
|
||||
mod_role = self.bot.settings.get_server_mod(server)
|
||||
is_admin = discord.utils.get(author.roles, name=admin_role)
|
||||
is_mod = discord.utils.get(author.roles, name=mod_role)
|
||||
is_owner = author.id == self.bot.settings.owner
|
||||
is_server_owner = author == server.owner
|
||||
is_authorized = is_admin or is_mod or is_owner or is_server_owner
|
||||
|
||||
session = self.get_trivia_by_channel(ctx.message.channel)
|
||||
if session:
|
||||
if author == session.starter or is_authorized:
|
||||
await session.end_game()
|
||||
await self.bot.say("Trivia stopped.")
|
||||
else:
|
||||
await self.bot.say("You are not allowed to do that.")
|
||||
else:
|
||||
await self.bot.say("There's no trivia session ongoing in this channel.")
|
||||
|
||||
@trivia.group(name="list")
|
||||
async def trivia_list(self):
|
||||
"""Shows available trivia lists"""
|
||||
lists = os.listdir("data/trivia/")
|
||||
lists = [l for l in lists if l.endswith(".txt") and " " not in l]
|
||||
lists = [l.replace(".txt", "") for l in lists]
|
||||
|
||||
if lists:
|
||||
msg = "+ Available trivia lists\n\n" + ", ".join(sorted(lists))
|
||||
msg = box(msg, lang="diff")
|
||||
if len(lists) < 100:
|
||||
await self.bot.say(msg)
|
||||
else:
|
||||
await self.bot.whisper(msg)
|
||||
else:
|
||||
await self.bot.say("There are no trivia lists available.")
|
||||
|
||||
def parse_trivia_list(self, filename):
|
||||
path = "data/trivia/{}.txt".format(filename)
|
||||
parsed_list = []
|
||||
|
||||
with open(path, "rb") as f:
|
||||
try:
|
||||
encoding = chardet.detect(f.read())["encoding"]
|
||||
except:
|
||||
encoding = "ISO-8859-1"
|
||||
|
||||
with open(path, "r", encoding=encoding) as f:
|
||||
trivia_list = f.readlines()
|
||||
|
||||
for line in trivia_list:
|
||||
if "`" not in line:
|
||||
continue
|
||||
line = line.replace("\n", "")
|
||||
line = line.split("`")
|
||||
question = line[0]
|
||||
answers = []
|
||||
for l in line[1:]:
|
||||
answers.append(l.strip())
|
||||
if len(line) >= 2 and question and answers:
|
||||
line = TriviaLine(question=question, answers=answers)
|
||||
parsed_list.append(line)
|
||||
|
||||
if not parsed_list:
|
||||
raise ValueError("Empty trivia list")
|
||||
|
||||
return parsed_list
|
||||
|
||||
def get_trivia_by_channel(self, channel):
|
||||
for t in self.trivia_sessions:
|
||||
if t.channel == channel:
|
||||
return t
|
||||
return None
|
||||
|
||||
async def on_message(self, message):
|
||||
if message.author != self.bot.user:
|
||||
session = self.get_trivia_by_channel(message.channel)
|
||||
if session:
|
||||
await session.check_answer(message)
|
||||
|
||||
async def on_trivia_end(self, instance):
|
||||
if instance in self.trivia_sessions:
|
||||
self.trivia_sessions.remove(instance)
|
||||
|
||||
def save_settings(self):
|
||||
dataIO.save_json(self.file_path, self.settings)
|
||||
|
||||
|
||||
class TriviaSession():
|
||||
def __init__(self, bot, trivia_list, message, settings):
|
||||
self.bot = bot
|
||||
self.reveal_messages = ("I know this one! {}!",
|
||||
"Easy: {}.",
|
||||
"Oh really? It's {} of course.")
|
||||
self.fail_messages = ("To the next one I guess...",
|
||||
"Moving on...",
|
||||
"I'm sure you'll know the answer of the next one.",
|
||||
"\N{PENSIVE FACE} Next one.")
|
||||
self.current_line = None # {"QUESTION" : "String", "ANSWERS" : []}
|
||||
self.question_list = trivia_list
|
||||
self.channel = message.channel
|
||||
self.starter = message.author
|
||||
self.scores = Counter()
|
||||
self.status = "new question"
|
||||
self.timer = None
|
||||
self.timeout = time.perf_counter()
|
||||
self.count = 0
|
||||
self.settings = settings
|
||||
|
||||
async def stop_trivia(self):
|
||||
self.status = "stop"
|
||||
self.bot.dispatch("trivia_end", self)
|
||||
|
||||
async def end_game(self):
|
||||
self.status = "stop"
|
||||
if self.scores:
|
||||
await self.send_table()
|
||||
self.bot.dispatch("trivia_end", self)
|
||||
|
||||
async def new_question(self):
|
||||
for score in self.scores.values():
|
||||
if score == self.settings["MAX_SCORE"]:
|
||||
await self.end_game()
|
||||
return True
|
||||
if self.question_list == []:
|
||||
await self.end_game()
|
||||
return True
|
||||
self.current_line = choice(self.question_list)
|
||||
self.question_list.remove(self.current_line)
|
||||
self.status = "waiting for answer"
|
||||
self.count += 1
|
||||
self.timer = int(time.perf_counter())
|
||||
msg = "**Question number {}!**\n\n{}".format(self.count, self.current_line.question)
|
||||
await self.bot.say(msg)
|
||||
|
||||
while self.status != "correct answer" and abs(self.timer - int(time.perf_counter())) <= self.settings["DELAY"]:
|
||||
if abs(self.timeout - int(time.perf_counter())) >= self.settings["TIMEOUT"]:
|
||||
await self.bot.say("Guys...? Well, I guess I'll stop then.")
|
||||
await self.stop_trivia()
|
||||
return True
|
||||
await asyncio.sleep(1) #Waiting for an answer or for the time limit
|
||||
if self.status == "correct answer":
|
||||
self.status = "new question"
|
||||
await asyncio.sleep(3)
|
||||
if not self.status == "stop":
|
||||
await self.new_question()
|
||||
elif self.status == "stop":
|
||||
return True
|
||||
else:
|
||||
if self.settings["REVEAL_ANSWER"]:
|
||||
msg = choice(self.reveal_messages).format(self.current_line.answers[0])
|
||||
else:
|
||||
msg = choice(self.fail_messages)
|
||||
if self.settings["BOT_PLAYS"]:
|
||||
msg += " **+1** for me!"
|
||||
self.scores[self.bot.user] += 1
|
||||
self.current_line = None
|
||||
await self.bot.say(msg)
|
||||
await self.bot.type()
|
||||
await asyncio.sleep(3)
|
||||
if not self.status == "stop":
|
||||
await self.new_question()
|
||||
|
||||
async def send_table(self):
|
||||
t = "+ Results: \n\n"
|
||||
for user, score in self.scores.most_common():
|
||||
t += "+ {}\t{}\n".format(user, score)
|
||||
await self.bot.say(box(t, lang="diff"))
|
||||
|
||||
async def check_answer(self, message):
|
||||
if message.author == self.bot.user:
|
||||
return
|
||||
elif self.current_line is None:
|
||||
return
|
||||
|
||||
self.timeout = time.perf_counter()
|
||||
has_guessed = False
|
||||
|
||||
for answer in self.current_line.answers:
|
||||
answer = answer.lower()
|
||||
guess = message.content.lower()
|
||||
if " " not in answer: # Exact matching, issue #331
|
||||
guess = guess.split(" ")
|
||||
for word in guess:
|
||||
if word == answer:
|
||||
has_guessed = True
|
||||
else: # The answer has spaces, we can't be as strict
|
||||
if answer in guess:
|
||||
has_guessed = True
|
||||
|
||||
if has_guessed:
|
||||
self.current_line = None
|
||||
self.status = "correct answer"
|
||||
self.scores[message.author] += 1
|
||||
msg = "You got it {}! **+1** to you!".format(message.author.name)
|
||||
await self.bot.send_message(message.channel, msg)
|
||||
|
||||
|
||||
def check_folders():
|
||||
folders = ("data", "data/trivia/")
|
||||
for folder in folders:
|
||||
if not os.path.exists(folder):
|
||||
print("Creating " + folder + " folder...")
|
||||
os.makedirs(folder)
|
||||
|
||||
|
||||
def check_files():
|
||||
if not os.path.isfile("data/trivia/settings.json"):
|
||||
print("Creating empty settings.json...")
|
||||
dataIO.save_json("data/trivia/settings.json", {})
|
||||
|
||||
|
||||
def setup(bot):
|
||||
check_folders()
|
||||
check_files()
|
||||
bot.add_cog(Trivia(bot))
|
||||
|
|
@ -0,0 +1,80 @@
|
|||
def error(text):
|
||||
return "\N{NO ENTRY SIGN} {}".format(text)
|
||||
|
||||
|
||||
def warning(text):
|
||||
return "\N{WARNING SIGN} {}".format(text)
|
||||
|
||||
|
||||
def info(text):
|
||||
return "\N{INFORMATION SOURCE} {}".format(text)
|
||||
|
||||
|
||||
def question(text):
|
||||
return "\N{BLACK QUESTION MARK ORNAMENT} {}".format(text)
|
||||
|
||||
|
||||
def bold(text):
|
||||
return "**{}**".format(text)
|
||||
|
||||
|
||||
def box(text, lang=""):
|
||||
ret = "```{}\n{}\n```".format(lang, text)
|
||||
return ret
|
||||
|
||||
|
||||
def inline(text):
|
||||
return "`{}`".format(text)
|
||||
|
||||
|
||||
def italics(text):
|
||||
return "*{}*".format(text)
|
||||
|
||||
|
||||
def pagify(text, delims=["\n"], *, escape=True, shorten_by=8,
|
||||
page_length=2000):
|
||||
"""DOES NOT RESPECT MARKDOWN BOXES OR INLINE CODE"""
|
||||
in_text = text
|
||||
if escape:
|
||||
num_mentions = text.count("@here") + text.count("@everyone")
|
||||
shorten_by += num_mentions
|
||||
page_length -= shorten_by
|
||||
while len(in_text) > page_length:
|
||||
closest_delim = max([in_text.rfind(d, 0, page_length)
|
||||
for d in delims])
|
||||
closest_delim = closest_delim if closest_delim != -1 else page_length
|
||||
if escape:
|
||||
to_send = escape_mass_mentions(in_text[:closest_delim])
|
||||
else:
|
||||
to_send = in_text[:closest_delim]
|
||||
yield to_send
|
||||
in_text = in_text[closest_delim:]
|
||||
|
||||
if escape:
|
||||
yield escape_mass_mentions(in_text)
|
||||
else:
|
||||
yield in_text
|
||||
|
||||
|
||||
def strikethrough(text):
|
||||
return "~~{}~~".format(text)
|
||||
|
||||
|
||||
def underline(text):
|
||||
return "__{}__".format(text)
|
||||
|
||||
|
||||
def escape(text, *, mass_mentions=False, formatting=False):
|
||||
if mass_mentions:
|
||||
text = text.replace("@everyone", "@\u200beveryone")
|
||||
text = text.replace("@here", "@\u200bhere")
|
||||
if formatting:
|
||||
text = (text.replace("`", "\\`")
|
||||
.replace("*", "\\*")
|
||||
.replace("_", "\\_")
|
||||
.replace("~", "\\~"))
|
||||
return text
|
||||
|
||||
|
||||
def escape_mass_mentions(text):
|
||||
return escape(text, mass_mentions=True)
|
||||
|
|
@ -0,0 +1,88 @@
|
|||
from discord.ext import commands
|
||||
import discord.utils
|
||||
from __main__ import settings
|
||||
|
||||
#
|
||||
# This is a modified version of checks.py, originally made by Rapptz
|
||||
#
|
||||
# https://github.com/Rapptz
|
||||
# https://github.com/Rapptz/RoboDanny/tree/async
|
||||
#
|
||||
|
||||
def is_owner_check(ctx):
|
||||
return ctx.message.author.id == settings.owner
|
||||
|
||||
def is_owner():
|
||||
return commands.check(is_owner_check)
|
||||
|
||||
# The permission system of the bot is based on a "just works" basis
|
||||
# You have permissions and the bot has permissions. If you meet the permissions
|
||||
# required to execute the command (and the bot does as well) then it goes through
|
||||
# and you can execute the command.
|
||||
# If these checks fail, then there are two fallbacks.
|
||||
# A role with the name of Bot Mod and a role with the name of Bot Admin.
|
||||
# Having these roles provides you access to certain commands without actually having
|
||||
# the permissions required for them.
|
||||
# Of course, the owner will always be able to execute commands.
|
||||
|
||||
def check_permissions(ctx, perms):
|
||||
if is_owner_check(ctx):
|
||||
return True
|
||||
elif not perms:
|
||||
return False
|
||||
|
||||
ch = ctx.message.channel
|
||||
author = ctx.message.author
|
||||
resolved = ch.permissions_for(author)
|
||||
return all(getattr(resolved, name, None) == value for name, value in perms.items())
|
||||
|
||||
def role_or_permissions(ctx, check, **perms):
|
||||
if check_permissions(ctx, perms):
|
||||
return True
|
||||
|
||||
ch = ctx.message.channel
|
||||
author = ctx.message.author
|
||||
if ch.is_private:
|
||||
return False # can't have roles in PMs
|
||||
|
||||
role = discord.utils.find(check, author.roles)
|
||||
return role is not None
|
||||
|
||||
def mod_or_permissions(**perms):
|
||||
def predicate(ctx):
|
||||
server = ctx.message.server
|
||||
mod_role = settings.get_server_mod(server).lower()
|
||||
admin_role = settings.get_server_admin(server).lower()
|
||||
return role_or_permissions(ctx, lambda r: r.name.lower() in (mod_role,admin_role), **perms)
|
||||
|
||||
return commands.check(predicate)
|
||||
|
||||
def admin_or_permissions(**perms):
|
||||
def predicate(ctx):
|
||||
server = ctx.message.server
|
||||
admin_role = settings.get_server_admin(server)
|
||||
return role_or_permissions(ctx, lambda r: r.name.lower() == admin_role.lower(), **perms)
|
||||
|
||||
return commands.check(predicate)
|
||||
|
||||
def serverowner_or_permissions(**perms):
|
||||
def predicate(ctx):
|
||||
if ctx.message.server is None:
|
||||
return False
|
||||
server = ctx.message.server
|
||||
owner = server.owner
|
||||
|
||||
if ctx.message.author.id == owner.id:
|
||||
return True
|
||||
|
||||
return check_permissions(ctx,perms)
|
||||
return commands.check(predicate)
|
||||
|
||||
def serverowner():
|
||||
return serverowner_or_permissions()
|
||||
|
||||
def admin():
|
||||
return admin_or_permissions()
|
||||
|
||||
def mod():
|
||||
return mod_or_permissions()
|
||||
|
|
@ -0,0 +1,79 @@
|
|||
import json
|
||||
import os
|
||||
import logging
|
||||
from random import randint
|
||||
|
||||
class InvalidFileIO(Exception):
|
||||
pass
|
||||
|
||||
class DataIO():
|
||||
def __init__(self):
|
||||
self.logger = logging.getLogger("red")
|
||||
|
||||
def save_json(self, filename, data):
|
||||
"""Atomically saves json file"""
|
||||
rnd = randint(1000, 9999)
|
||||
path, ext = os.path.splitext(filename)
|
||||
tmp_file = "{}-{}.tmp".format(path, rnd)
|
||||
self._save_json(tmp_file, data)
|
||||
try:
|
||||
self._read_json(tmp_file)
|
||||
except json.decoder.JSONDecodeError:
|
||||
self.logger.exception("Attempted to write file {} but JSON "
|
||||
"integrity check on tmp file has failed. "
|
||||
"The original file is unaltered."
|
||||
"".format(filename))
|
||||
return False
|
||||
os.replace(tmp_file, filename)
|
||||
return True
|
||||
|
||||
def load_json(self, filename):
|
||||
"""Loads json file"""
|
||||
return self._read_json(filename)
|
||||
|
||||
def is_valid_json(self, filename):
|
||||
"""Verifies if json file exists / is readable"""
|
||||
try:
|
||||
self._read_json(filename)
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
except json.decoder.JSONDecodeError:
|
||||
return False
|
||||
|
||||
def _read_json(self, filename):
|
||||
with open(filename, encoding='utf-8', mode="r") as f:
|
||||
data = json.load(f)
|
||||
return data
|
||||
|
||||
def _save_json(self, filename, data):
|
||||
with open(filename, encoding='utf-8', mode="w") as f:
|
||||
json.dump(data, f, indent=4,sort_keys=True,
|
||||
separators=(',',' : '))
|
||||
return data
|
||||
|
||||
def _legacy_fileio(self, filename, IO, data=None):
|
||||
"""Old fileIO provided for backwards compatibility"""
|
||||
if IO == "save" and data != None:
|
||||
return self.save_json(filename, data)
|
||||
elif IO == "load" and data == None:
|
||||
return self.load_json(filename)
|
||||
elif IO == "check" and data == None:
|
||||
return self.is_valid_json(filename)
|
||||
else:
|
||||
raise InvalidFileIO("FileIO was called with invalid"
|
||||
" parameters")
|
||||
|
||||
def get_value(filename, key):
|
||||
with open(filename, encoding='utf-8', mode="r") as f:
|
||||
data = json.load(f)
|
||||
return data[key]
|
||||
|
||||
def set_value(filename, key, value):
|
||||
data = fileIO(filename, "load")
|
||||
data[key] = value
|
||||
fileIO(filename, "save", data)
|
||||
return True
|
||||
|
||||
dataIO = DataIO()
|
||||
fileIO = dataIO._legacy_fileio # backwards compatibility
|
||||
|
|
@ -0,0 +1,291 @@
|
|||
from .dataIO import dataIO
|
||||
from copy import deepcopy
|
||||
import discord
|
||||
import os
|
||||
import argparse
|
||||
|
||||
|
||||
default_path = "data/red/settings.json"
|
||||
|
||||
|
||||
class Settings:
|
||||
|
||||
def __init__(self, path=default_path, parse_args=True):
|
||||
self.path = path
|
||||
self.check_folders()
|
||||
self.default_settings = {
|
||||
"TOKEN": None,
|
||||
"EMAIL": None,
|
||||
"PASSWORD": None,
|
||||
"OWNER": None,
|
||||
"PREFIXES": [],
|
||||
"default": {"ADMIN_ROLE": "Transistor",
|
||||
"MOD_ROLE": "Process",
|
||||
"PREFIXES": []}
|
||||
}
|
||||
self._memory_only = False
|
||||
|
||||
if not dataIO.is_valid_json(self.path):
|
||||
self.bot_settings = deepcopy(self.default_settings)
|
||||
self.save_settings()
|
||||
else:
|
||||
current = dataIO.load_json(self.path)
|
||||
if current.keys() != self.default_settings.keys():
|
||||
for key in self.default_settings.keys():
|
||||
if key not in current.keys():
|
||||
current[key] = self.default_settings[key]
|
||||
print("Adding " + str(key) +
|
||||
" field to red settings.json")
|
||||
dataIO.save_json(self.path, current)
|
||||
self.bot_settings = dataIO.load_json(self.path)
|
||||
|
||||
if "default" not in self.bot_settings:
|
||||
self.update_old_settings_v1()
|
||||
|
||||
if "LOGIN_TYPE" in self.bot_settings:
|
||||
self.update_old_settings_v2()
|
||||
if parse_args:
|
||||
self.parse_cmd_arguments()
|
||||
|
||||
def parse_cmd_arguments(self):
|
||||
parser = argparse.ArgumentParser(description="Red - Discord Bot")
|
||||
parser.add_argument("--owner", help="ID of the owner. Only who hosts "
|
||||
"Red should be owner, this has "
|
||||
"security implications")
|
||||
parser.add_argument("--prefix", "-p", action="append",
|
||||
help="Global prefix. Can be multiple")
|
||||
parser.add_argument("--admin-role", help="Role seen as admin role by "
|
||||
"Red")
|
||||
parser.add_argument("--mod-role", help="Role seen as mod role by Red")
|
||||
parser.add_argument("--no-prompt",
|
||||
action="store_true",
|
||||
help="Disables console inputs. Features requiring "
|
||||
"console interaction could be disabled as a "
|
||||
"result")
|
||||
parser.add_argument("--no-cogs",
|
||||
action="store_true",
|
||||
help="Starts Red with no cogs loaded, only core")
|
||||
parser.add_argument("--self-bot",
|
||||
action='store_true',
|
||||
help="Specifies if Red should log in as selfbot")
|
||||
parser.add_argument("--memory-only",
|
||||
action="store_true",
|
||||
help="Arguments passed and future edits to the "
|
||||
"settings will not be saved to disk")
|
||||
parser.add_argument("--dry-run",
|
||||
action="store_true",
|
||||
help="Makes Red quit with code 0 just before the "
|
||||
"login. This is useful for testing the boot "
|
||||
"process.")
|
||||
parser.add_argument("--debug",
|
||||
action="store_true",
|
||||
help="Enables debug mode")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.owner:
|
||||
self.owner = args.owner
|
||||
if args.prefix:
|
||||
self.prefixes = sorted(args.prefix, reverse=True)
|
||||
if args.admin_role:
|
||||
self.default_admin = args.admin_role
|
||||
if args.mod_role:
|
||||
self.default_mod = args.mod_role
|
||||
|
||||
self.no_prompt = args.no_prompt
|
||||
self.self_bot = args.self_bot
|
||||
self._memory_only = args.memory_only
|
||||
self._no_cogs = args.no_cogs
|
||||
self.debug = args.debug
|
||||
self._dry_run = args.dry_run
|
||||
|
||||
self.save_settings()
|
||||
|
||||
def check_folders(self):
|
||||
folders = ("data", os.path.dirname(self.path), "cogs", "cogs/utils")
|
||||
for folder in folders:
|
||||
if not os.path.exists(folder):
|
||||
print("Creating " + folder + " folder...")
|
||||
os.makedirs(folder)
|
||||
|
||||
def save_settings(self):
|
||||
if not self._memory_only:
|
||||
dataIO.save_json(self.path, self.bot_settings)
|
||||
|
||||
def update_old_settings_v1(self):
|
||||
# This converts the old settings format
|
||||
mod = self.bot_settings["MOD_ROLE"]
|
||||
admin = self.bot_settings["ADMIN_ROLE"]
|
||||
del self.bot_settings["MOD_ROLE"]
|
||||
del self.bot_settings["ADMIN_ROLE"]
|
||||
self.bot_settings["default"] = {"MOD_ROLE": mod,
|
||||
"ADMIN_ROLE": admin,
|
||||
"PREFIXES": []
|
||||
}
|
||||
self.save_settings()
|
||||
|
||||
def update_old_settings_v2(self):
|
||||
# The joys of backwards compatibility
|
||||
settings = self.bot_settings
|
||||
if settings["EMAIL"] == "EmailHere":
|
||||
settings["EMAIL"] = None
|
||||
if settings["PASSWORD"] == "":
|
||||
settings["PASSWORD"] = None
|
||||
if settings["LOGIN_TYPE"] == "token":
|
||||
settings["TOKEN"] = settings["EMAIL"]
|
||||
settings["EMAIL"] = None
|
||||
settings["PASSWORD"] = None
|
||||
else:
|
||||
settings["TOKEN"] = None
|
||||
del settings["LOGIN_TYPE"]
|
||||
self.save_settings()
|
||||
|
||||
@property
|
||||
def owner(self):
|
||||
return self.bot_settings["OWNER"]
|
||||
|
||||
@owner.setter
|
||||
def owner(self, value):
|
||||
self.bot_settings["OWNER"] = value
|
||||
|
||||
@property
|
||||
def token(self):
|
||||
return os.environ.get("RED_TOKEN", self.bot_settings["TOKEN"])
|
||||
|
||||
@token.setter
|
||||
def token(self, value):
|
||||
self.bot_settings["TOKEN"] = value
|
||||
self.bot_settings["EMAIL"] = None
|
||||
self.bot_settings["PASSWORD"] = None
|
||||
|
||||
@property
|
||||
def email(self):
|
||||
return os.environ.get("RED_EMAIL", self.bot_settings["EMAIL"])
|
||||
|
||||
@email.setter
|
||||
def email(self, value):
|
||||
self.bot_settings["EMAIL"] = value
|
||||
self.bot_settings["TOKEN"] = None
|
||||
|
||||
@property
|
||||
def password(self):
|
||||
return os.environ.get("RED_PASSWORD", self.bot_settings["PASSWORD"])
|
||||
|
||||
@password.setter
|
||||
def password(self, value):
|
||||
self.bot_settings["PASSWORD"] = value
|
||||
|
||||
@property
|
||||
def login_credentials(self):
|
||||
if self.token:
|
||||
return (self.token,)
|
||||
elif self.email and self.password:
|
||||
return (self.email, self.password)
|
||||
else:
|
||||
return tuple()
|
||||
|
||||
@property
|
||||
def prefixes(self):
|
||||
return self.bot_settings["PREFIXES"]
|
||||
|
||||
@prefixes.setter
|
||||
def prefixes(self, value):
|
||||
assert isinstance(value, list)
|
||||
self.bot_settings["PREFIXES"] = value
|
||||
|
||||
@property
|
||||
def default_admin(self):
|
||||
if "default" not in self.bot_settings:
|
||||
self.update_old_settings()
|
||||
return self.bot_settings["default"].get("ADMIN_ROLE", "")
|
||||
|
||||
@default_admin.setter
|
||||
def default_admin(self, value):
|
||||
if "default" not in self.bot_settings:
|
||||
self.update_old_settings()
|
||||
self.bot_settings["default"]["ADMIN_ROLE"] = value
|
||||
|
||||
@property
|
||||
def default_mod(self):
|
||||
if "default" not in self.bot_settings:
|
||||
self.update_old_settings_v1()
|
||||
return self.bot_settings["default"].get("MOD_ROLE", "")
|
||||
|
||||
@default_mod.setter
|
||||
def default_mod(self, value):
|
||||
if "default" not in self.bot_settings:
|
||||
self.update_old_settings_v1()
|
||||
self.bot_settings["default"]["MOD_ROLE"] = value
|
||||
|
||||
@property
|
||||
def servers(self):
|
||||
ret = {}
|
||||
server_ids = list(
|
||||
filter(lambda x: str(x).isdigit(), self.bot_settings))
|
||||
for server in server_ids:
|
||||
ret.update({server: self.bot_settings[server]})
|
||||
return ret
|
||||
|
||||
def get_server(self, server):
|
||||
if server is None:
|
||||
return self.bot_settings["default"].copy()
|
||||
assert isinstance(server, discord.Server)
|
||||
return self.bot_settings.get(server.id,
|
||||
self.bot_settings["default"]).copy()
|
||||
|
||||
def get_server_admin(self, server):
|
||||
if server is None:
|
||||
return self.default_admin
|
||||
assert isinstance(server, discord.Server)
|
||||
if server.id not in self.bot_settings:
|
||||
return self.default_admin
|
||||
return self.bot_settings[server.id].get("ADMIN_ROLE", "")
|
||||
|
||||
def set_server_admin(self, server, value):
|
||||
if server is None:
|
||||
return
|
||||
assert isinstance(server, discord.Server)
|
||||
if server.id not in self.bot_settings:
|
||||
self.add_server(server.id)
|
||||
self.bot_settings[server.id]["ADMIN_ROLE"] = value
|
||||
self.save_settings()
|
||||
|
||||
def get_server_mod(self, server):
|
||||
if server is None:
|
||||
return self.default_mod
|
||||
assert isinstance(server, discord.Server)
|
||||
if server.id not in self.bot_settings:
|
||||
return self.default_mod
|
||||
return self.bot_settings[server.id].get("MOD_ROLE", "")
|
||||
|
||||
def set_server_mod(self, server, value):
|
||||
if server is None:
|
||||
return
|
||||
assert isinstance(server, discord.Server)
|
||||
if server.id not in self.bot_settings:
|
||||
self.add_server(server.id)
|
||||
self.bot_settings[server.id]["MOD_ROLE"] = value
|
||||
self.save_settings()
|
||||
|
||||
def get_server_prefixes(self, server):
|
||||
if server is None or server.id not in self.bot_settings:
|
||||
return self.prefixes
|
||||
return self.bot_settings[server.id].get("PREFIXES", [])
|
||||
|
||||
def set_server_prefixes(self, server, prefixes):
|
||||
if server is None:
|
||||
return
|
||||
assert isinstance(server, discord.Server)
|
||||
if server.id not in self.bot_settings:
|
||||
self.add_server(server.id)
|
||||
self.bot_settings[server.id]["PREFIXES"] = prefixes
|
||||
self.save_settings()
|
||||
|
||||
def get_prefixes(self, server):
|
||||
"""Returns server's prefixes if set, otherwise global ones"""
|
||||
p = self.get_server_prefixes(server)
|
||||
return p if p else self.prefixes
|
||||
|
||||
def add_server(self, sid):
|
||||
self.bot_settings[sid] = self.bot_settings["default"].copy()
|
||||
self.save_settings()
|
||||
|
|
@ -0,0 +1,573 @@
|
|||
from __future__ import print_function
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
try: # Older Pythons lack this
|
||||
import urllib.request # We'll let them reach the Python
|
||||
from importlib.util import find_spec # check anyway
|
||||
except ImportError:
|
||||
pass
|
||||
import platform
|
||||
import webbrowser
|
||||
import hashlib
|
||||
import argparse
|
||||
import shutil
|
||||
import stat
|
||||
import time
|
||||
try:
|
||||
import pip
|
||||
except ImportError:
|
||||
pip = None
|
||||
|
||||
REQS_DIR = "lib"
|
||||
sys.path.insert(0, REQS_DIR)
|
||||
REQS_TXT = "requirements.txt"
|
||||
REQS_NO_AUDIO_TXT = "requirements_no_audio.txt"
|
||||
FFMPEG_BUILDS_URL = "https://ffmpeg.zeranoe.com/builds/"
|
||||
|
||||
INTRO = ("==========================\n"
|
||||
"Red Discord Bot - Launcher\n"
|
||||
"==========================\n")
|
||||
|
||||
IS_WINDOWS = os.name == "nt"
|
||||
IS_MAC = sys.platform == "darwin"
|
||||
IS_64BIT = platform.machine().endswith("64")
|
||||
INTERACTIVE_MODE = not len(sys.argv) > 1 # CLI flags = non-interactive
|
||||
PYTHON_OK = sys.version_info >= (3, 5)
|
||||
|
||||
FFMPEG_FILES = {
|
||||
"ffmpeg.exe" : "e0d60f7c0d27ad9d7472ddf13e78dc89",
|
||||
"ffplay.exe" : "d100abe8281cbcc3e6aebe550c675e09",
|
||||
"ffprobe.exe" : "0e84b782c0346a98434ed476e937764f"
|
||||
}
|
||||
|
||||
|
||||
def parse_cli_arguments():
|
||||
parser = argparse.ArgumentParser(description="Red - Discord Bot's launcher")
|
||||
parser.add_argument("--start", "-s",
|
||||
help="Starts Red",
|
||||
action="store_true")
|
||||
parser.add_argument("--auto-restart",
|
||||
help="Autorestarts Red in case of issues",
|
||||
action="store_true")
|
||||
parser.add_argument("--update-red",
|
||||
help="Updates Red (git)",
|
||||
action="store_true")
|
||||
parser.add_argument("--update-reqs",
|
||||
help="Updates requirements (w/ audio)",
|
||||
action="store_true")
|
||||
parser.add_argument("--update-reqs-no-audio",
|
||||
help="Updates requirements (w/o audio)",
|
||||
action="store_true")
|
||||
parser.add_argument("--repair",
|
||||
help="Issues a git reset --hard",
|
||||
action="store_true")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def install_reqs(audio):
|
||||
remove_reqs_readonly()
|
||||
interpreter = sys.executable
|
||||
|
||||
if interpreter is None:
|
||||
print("Python interpreter not found.")
|
||||
return
|
||||
|
||||
txt = REQS_TXT if audio else REQS_NO_AUDIO_TXT
|
||||
|
||||
args = [
|
||||
interpreter, "-m",
|
||||
"pip", "install",
|
||||
"--upgrade",
|
||||
"--target", REQS_DIR,
|
||||
"-r", txt
|
||||
]
|
||||
|
||||
if IS_MAC: # --target is a problem on Homebrew. See PR #552
|
||||
args.remove("--target")
|
||||
args.remove(REQS_DIR)
|
||||
|
||||
code = subprocess.call(args)
|
||||
|
||||
if code == 0:
|
||||
print("\nRequirements setup completed.")
|
||||
else:
|
||||
print("\nAn error occurred and the requirements setup might "
|
||||
"not be completed. Consult the docs.\n")
|
||||
|
||||
|
||||
def update_pip():
|
||||
interpreter = sys.executable
|
||||
|
||||
if interpreter is None:
|
||||
print("Python interpreter not found.")
|
||||
return
|
||||
|
||||
args = [
|
||||
interpreter, "-m",
|
||||
"pip", "install",
|
||||
"--upgrade", "pip"
|
||||
]
|
||||
|
||||
code = subprocess.call(args)
|
||||
|
||||
if code == 0:
|
||||
print("\nPip has been updated.")
|
||||
else:
|
||||
print("\nAn error occurred and pip might not have been updated.")
|
||||
|
||||
|
||||
def update_red():
|
||||
try:
|
||||
code = subprocess.call(("git", "pull", "--ff-only"))
|
||||
except FileNotFoundError:
|
||||
print("\nError: Git not found. It's either not installed or not in "
|
||||
"the PATH environment variable like requested in the guide.")
|
||||
return
|
||||
if code == 0:
|
||||
print("\nRed has been updated")
|
||||
else:
|
||||
print("\nRed could not update properly. If this is caused by edits "
|
||||
"you have made to the code you can try the repair option from "
|
||||
"the Maintenance submenu")
|
||||
|
||||
|
||||
def reset_red(reqs=False, data=False, cogs=False, git_reset=False):
|
||||
if reqs:
|
||||
try:
|
||||
shutil.rmtree(REQS_DIR, onerror=remove_readonly)
|
||||
print("Installed local packages have been wiped.")
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except Exception as e:
|
||||
print("An error occurred when trying to remove installed "
|
||||
"requirements: {}".format(e))
|
||||
if data:
|
||||
try:
|
||||
shutil.rmtree("data", onerror=remove_readonly)
|
||||
print("'data' folder has been wiped.")
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except Exception as e:
|
||||
print("An error occurred when trying to remove the 'data' folder: "
|
||||
"{}".format(e))
|
||||
|
||||
if cogs:
|
||||
try:
|
||||
shutil.rmtree("cogs", onerror=remove_readonly)
|
||||
print("'cogs' folder has been wiped.")
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except Exception as e:
|
||||
print("An error occurred when trying to remove the 'cogs' folder: "
|
||||
"{}".format(e))
|
||||
|
||||
if git_reset:
|
||||
code = subprocess.call(("git", "reset", "--hard"))
|
||||
if code == 0:
|
||||
print("Red has been restored to the last local commit.")
|
||||
else:
|
||||
print("The repair has failed.")
|
||||
|
||||
|
||||
def download_ffmpeg(bitness):
|
||||
clear_screen()
|
||||
repo = "https://github.com/Twentysix26/Red-DiscordBot/raw/master/"
|
||||
verified = []
|
||||
|
||||
if bitness == "32bit":
|
||||
print("Please download 'ffmpeg 32bit static' from the page that "
|
||||
"is about to open.\nOnce done, open the 'bin' folder located "
|
||||
"inside the zip.\nThere should be 3 files: ffmpeg.exe, "
|
||||
"ffplay.exe, ffprobe.exe.\nPut all three of them into the "
|
||||
"bot's main folder.")
|
||||
time.sleep(4)
|
||||
webbrowser.open(FFMPEG_BUILDS_URL)
|
||||
return
|
||||
|
||||
for filename in FFMPEG_FILES:
|
||||
if os.path.isfile(filename):
|
||||
print("{} already present. Verifying integrity... "
|
||||
"".format(filename), end="")
|
||||
_hash = calculate_md5(filename)
|
||||
if _hash == FFMPEG_FILES[filename]:
|
||||
verified.append(filename)
|
||||
print("Ok")
|
||||
continue
|
||||
else:
|
||||
print("Hash mismatch. Redownloading.")
|
||||
print("Downloading {}... Please wait.".format(filename))
|
||||
with urllib.request.urlopen(repo + filename) as data:
|
||||
with open(filename, "wb") as f:
|
||||
f.write(data.read())
|
||||
print("Download completed.")
|
||||
|
||||
for filename, _hash in FFMPEG_FILES.items():
|
||||
if filename in verified:
|
||||
continue
|
||||
print("Verifying {}... ".format(filename), end="")
|
||||
if not calculate_md5(filename) != _hash:
|
||||
print("Passed.")
|
||||
else:
|
||||
print("Hash mismatch. Please redownload.")
|
||||
|
||||
print("\nAll files have been downloaded.")
|
||||
|
||||
|
||||
def verify_requirements():
|
||||
sys.path_importer_cache = {} # I don't know if the cache reset has any
|
||||
basic = find_spec("discord") # side effect. Without it, the lib folder
|
||||
audio = find_spec("nacl") # wouldn't be seen if it didn't exist
|
||||
if not basic: # when the launcher was started
|
||||
return None
|
||||
elif not audio:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def is_git_installed():
|
||||
try:
|
||||
subprocess.call(["git", "--version"], stdout=subprocess.DEVNULL,
|
||||
stdin =subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL)
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def requirements_menu():
|
||||
clear_screen()
|
||||
while True:
|
||||
print(INTRO)
|
||||
print("Main requirements:\n")
|
||||
print("1. Install basic + audio requirements (recommended)")
|
||||
print("2. Install basic requirements")
|
||||
if IS_WINDOWS:
|
||||
print("\nffmpeg (required for audio):")
|
||||
print("3. Install ffmpeg 32bit")
|
||||
if IS_64BIT:
|
||||
print("4. Install ffmpeg 64bit (recommended on Windows 64bit)")
|
||||
print("\n0. Go back")
|
||||
choice = user_choice()
|
||||
if choice == "1":
|
||||
install_reqs(audio=True)
|
||||
wait()
|
||||
elif choice == "2":
|
||||
install_reqs(audio=False)
|
||||
wait()
|
||||
elif choice == "3" and IS_WINDOWS:
|
||||
download_ffmpeg(bitness="32bit")
|
||||
wait()
|
||||
elif choice == "4" and (IS_WINDOWS and IS_64BIT):
|
||||
download_ffmpeg(bitness="64bit")
|
||||
wait()
|
||||
elif choice == "0":
|
||||
break
|
||||
clear_screen()
|
||||
|
||||
|
||||
def update_menu():
|
||||
clear_screen()
|
||||
while True:
|
||||
print(INTRO)
|
||||
reqs = verify_requirements()
|
||||
if reqs is None:
|
||||
status = "No requirements installed"
|
||||
elif reqs is False:
|
||||
status = "Basic requirements installed (no audio)"
|
||||
else:
|
||||
status = "Basic + audio requirements installed"
|
||||
print("Status: " + status + "\n")
|
||||
print("Update:\n")
|
||||
print("Red:")
|
||||
print("1. Update Red + requirements (recommended)")
|
||||
print("2. Update Red")
|
||||
print("3. Update requirements")
|
||||
print("\nOthers:")
|
||||
print("4. Update pip (might require admin privileges)")
|
||||
print("\n0. Go back")
|
||||
choice = user_choice()
|
||||
if choice == "1":
|
||||
update_red()
|
||||
print("Updating requirements...")
|
||||
reqs = verify_requirements()
|
||||
if reqs is not None:
|
||||
install_reqs(audio=reqs)
|
||||
else:
|
||||
print("The requirements haven't been installed yet.")
|
||||
wait()
|
||||
elif choice == "2":
|
||||
update_red()
|
||||
wait()
|
||||
elif choice == "3":
|
||||
reqs = verify_requirements()
|
||||
if reqs is not None:
|
||||
install_reqs(audio=reqs)
|
||||
else:
|
||||
print("The requirements haven't been installed yet.")
|
||||
wait()
|
||||
elif choice == "4":
|
||||
update_pip()
|
||||
wait()
|
||||
elif choice == "0":
|
||||
break
|
||||
clear_screen()
|
||||
|
||||
|
||||
def maintenance_menu():
|
||||
clear_screen()
|
||||
while True:
|
||||
print(INTRO)
|
||||
print("Maintenance:\n")
|
||||
print("1. Repair Red (discards code changes, keeps data intact)")
|
||||
print("2. Wipe 'data' folder (all settings, cogs' data...)")
|
||||
print("3. Wipe 'lib' folder (all local requirements / local installed"
|
||||
" python packages)")
|
||||
print("4. Factory reset")
|
||||
print("\n0. Go back")
|
||||
choice = user_choice()
|
||||
if choice == "1":
|
||||
print("Any code modification you have made will be lost. Data/"
|
||||
"non-default cogs will be left intact. Are you sure?")
|
||||
if user_pick_yes_no():
|
||||
reset_red(git_reset=True)
|
||||
wait()
|
||||
elif choice == "2":
|
||||
print("Are you sure? This will wipe the 'data' folder, which "
|
||||
"contains all your settings and cogs' data.\nThe 'cogs' "
|
||||
"folder, however, will be left intact.")
|
||||
if user_pick_yes_no():
|
||||
reset_red(data=True)
|
||||
wait()
|
||||
elif choice == "3":
|
||||
reset_red(reqs=True)
|
||||
wait()
|
||||
elif choice == "4":
|
||||
print("Are you sure? This will wipe ALL your Red's installation "
|
||||
"data.\nYou'll lose all your settings, cogs and any "
|
||||
"modification you have made.\nThere is no going back.")
|
||||
if user_pick_yes_no():
|
||||
reset_red(reqs=True, data=True, cogs=True, git_reset=True)
|
||||
wait()
|
||||
elif choice == "0":
|
||||
break
|
||||
clear_screen()
|
||||
|
||||
|
||||
def run_red(autorestart):
|
||||
interpreter = sys.executable
|
||||
|
||||
if interpreter is None: # This should never happen
|
||||
raise RuntimeError("Couldn't find Python's interpreter")
|
||||
|
||||
if verify_requirements() is None:
|
||||
print("You don't have the requirements to start Red. "
|
||||
"Install them from the launcher.")
|
||||
if not INTERACTIVE_MODE:
|
||||
exit(1)
|
||||
|
||||
cmd = (interpreter, "red.py")
|
||||
|
||||
while True:
|
||||
try:
|
||||
code = subprocess.call(cmd)
|
||||
except KeyboardInterrupt:
|
||||
code = 0
|
||||
break
|
||||
else:
|
||||
if code == 0:
|
||||
break
|
||||
elif code == 26:
|
||||
print("Restarting Red...")
|
||||
continue
|
||||
else:
|
||||
if not autorestart:
|
||||
break
|
||||
|
||||
print("Red has been terminated. Exit code: %d" % code)
|
||||
|
||||
if INTERACTIVE_MODE:
|
||||
wait()
|
||||
|
||||
|
||||
def clear_screen():
|
||||
if IS_WINDOWS:
|
||||
os.system("cls")
|
||||
else:
|
||||
os.system("clear")
|
||||
|
||||
|
||||
def wait():
|
||||
if INTERACTIVE_MODE:
|
||||
input("Press enter to continue.")
|
||||
|
||||
|
||||
def user_choice():
|
||||
return input("> ").lower().strip()
|
||||
|
||||
|
||||
def user_pick_yes_no():
|
||||
choice = None
|
||||
yes = ("yes", "y")
|
||||
no = ("no", "n")
|
||||
while choice not in yes and choice not in no:
|
||||
choice = input("Yes/No > ").lower().strip()
|
||||
return choice in yes
|
||||
|
||||
|
||||
def remove_readonly(func, path, excinfo):
|
||||
os.chmod(path, 0o755)
|
||||
func(path)
|
||||
|
||||
|
||||
def remove_reqs_readonly():
|
||||
"""Workaround for issue #569"""
|
||||
if not os.path.isdir(REQS_DIR):
|
||||
return
|
||||
os.chmod(REQS_DIR, 0o755)
|
||||
for root, dirs, files in os.walk(REQS_DIR):
|
||||
for d in dirs:
|
||||
os.chmod(os.path.join(root, d), 0o755)
|
||||
for f in files:
|
||||
os.chmod(os.path.join(root, f), 0o755)
|
||||
|
||||
|
||||
def calculate_md5(filename):
|
||||
hash_md5 = hashlib.md5()
|
||||
with open(filename, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(4096), b""):
|
||||
hash_md5.update(chunk)
|
||||
return hash_md5.hexdigest()
|
||||
|
||||
|
||||
def create_fast_start_scripts():
|
||||
"""Creates scripts for fast boot of Red without going
|
||||
through the launcher"""
|
||||
interpreter = sys.executable
|
||||
if not interpreter:
|
||||
return
|
||||
|
||||
call = "\"{}\" launcher.py".format(interpreter)
|
||||
start_red = "{} --start".format(call)
|
||||
start_red_autorestart = "{} --start --auto-restart".format(call)
|
||||
modified = False
|
||||
|
||||
if IS_WINDOWS:
|
||||
ccd = "pushd %~dp0\n"
|
||||
pause = "\npause"
|
||||
ext = ".bat"
|
||||
else:
|
||||
ccd = 'cd "$(dirname "$0")"\n'
|
||||
pause = "\nread -rsp $'Press enter to continue...\\n'"
|
||||
if not IS_MAC:
|
||||
ext = ".sh"
|
||||
else:
|
||||
ext = ".command"
|
||||
|
||||
start_red = ccd + start_red + pause
|
||||
start_red_autorestart = ccd + start_red_autorestart + pause
|
||||
|
||||
files = {
|
||||
"start_red" + ext : start_red,
|
||||
"start_red_autorestart" + ext : start_red_autorestart
|
||||
}
|
||||
|
||||
if not IS_WINDOWS:
|
||||
files["start_launcher" + ext] = ccd + call
|
||||
|
||||
for filename, content in files.items():
|
||||
if not os.path.isfile(filename):
|
||||
print("Creating {}... (fast start scripts)".format(filename))
|
||||
modified = True
|
||||
with open(filename, "w") as f:
|
||||
f.write(content)
|
||||
|
||||
if not IS_WINDOWS and modified: # Let's make them executable on Unix
|
||||
for script in files:
|
||||
st = os.stat(script)
|
||||
os.chmod(script, st.st_mode | stat.S_IEXEC)
|
||||
|
||||
|
||||
def main():
|
||||
print("Verifying git installation...")
|
||||
has_git = is_git_installed()
|
||||
is_git_installation = os.path.isdir(".git")
|
||||
if IS_WINDOWS:
|
||||
os.system("TITLE Red Discord Bot - Launcher")
|
||||
clear_screen()
|
||||
|
||||
try:
|
||||
create_fast_start_scripts()
|
||||
except Exception as e:
|
||||
print("Failed making fast start scripts: {}\n".format(e))
|
||||
|
||||
while True:
|
||||
print(INTRO)
|
||||
|
||||
if not is_git_installation:
|
||||
print("WARNING: It doesn't look like Red has been "
|
||||
"installed with git.\nThis means that you won't "
|
||||
"be able to update and some features won't be working.\n"
|
||||
"A reinstallation is recommended. Follow the guide "
|
||||
"properly this time:\n"
|
||||
"https://twentysix26.github.io/Red-Docs/\n")
|
||||
|
||||
if not has_git:
|
||||
print("WARNING: Git not found. This means that it's either not "
|
||||
"installed or not in the PATH environment variable like "
|
||||
"requested in the guide.\n")
|
||||
|
||||
print("1. Run Red /w autorestart in case of issues")
|
||||
print("2. Run Red")
|
||||
print("3. Update")
|
||||
print("4. Install requirements")
|
||||
print("5. Maintenance (repair, reset...)")
|
||||
print("\n0. Quit")
|
||||
choice = user_choice()
|
||||
if choice == "1":
|
||||
run_red(autorestart=True)
|
||||
elif choice == "2":
|
||||
run_red(autorestart=False)
|
||||
elif choice == "3":
|
||||
update_menu()
|
||||
elif choice == "4":
|
||||
requirements_menu()
|
||||
elif choice == "5":
|
||||
maintenance_menu()
|
||||
elif choice == "0":
|
||||
break
|
||||
clear_screen()
|
||||
|
||||
args = parse_cli_arguments()
|
||||
|
||||
if __name__ == '__main__':
|
||||
abspath = os.path.abspath(__file__)
|
||||
dirname = os.path.dirname(abspath)
|
||||
# Sets current directory to the script's
|
||||
os.chdir(dirname)
|
||||
if not PYTHON_OK:
|
||||
print("Red needs Python 3.5 or superior. Install the required "
|
||||
"version.\nPress enter to continue.")
|
||||
if INTERACTIVE_MODE:
|
||||
wait()
|
||||
exit(1)
|
||||
if pip is None:
|
||||
print("Red cannot work without the pip module. Please make sure to "
|
||||
"install Python without unchecking any option during the setup")
|
||||
wait()
|
||||
exit(1)
|
||||
if args.repair:
|
||||
reset_red(git_reset=True)
|
||||
if args.update_red:
|
||||
update_red()
|
||||
if args.update_reqs:
|
||||
install_reqs(audio=True)
|
||||
elif args.update_reqs_no_audio:
|
||||
install_reqs(audio=False)
|
||||
if INTERACTIVE_MODE:
|
||||
main()
|
||||
elif args.start:
|
||||
print("Starting Red...")
|
||||
run_red(autorestart=args.auto_restart)
|
||||
|
|
@ -0,0 +1,91 @@
|
|||
PyNaCl
|
||||
======
|
||||
|
||||
.. image:: https://pypip.in/version/PyNaCl/badge.svg?style=flat
|
||||
:target: https://pypi.python.org/pypi/PyNaCl/
|
||||
:alt: Latest Version
|
||||
|
||||
.. image:: https://travis-ci.org/pyca/pynacl.svg?branch=master
|
||||
:target: https://travis-ci.org/pyca/pynacl
|
||||
|
||||
.. image:: https://coveralls.io/repos/pyca/pynacl/badge.svg?branch=master
|
||||
:target: https://coveralls.io/r/pyca/pynacl?branch=master
|
||||
|
||||
PyNaCl is a Python binding to the `Networking and Cryptography library`_,
|
||||
a crypto library with the stated goal of improving usability, security and
|
||||
speed.
|
||||
|
||||
.. _Networking and Cryptography library: https://nacl.cr.yp.to/
|
||||
|
||||
|
||||
Installation
|
||||
------------
|
||||
|
||||
|
||||
Linux
|
||||
~~~~~
|
||||
|
||||
PyNaCl relies on libsodium_, a portable C library. A copy is bundled
|
||||
with PyNaCl so to install you can run:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ pip install pynacl
|
||||
|
||||
If you'd prefer to use one provided by your distribution you can disable
|
||||
the bundled copy during install by running:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ SODIUM_INSTALL=system pip install pynacl
|
||||
|
||||
|
||||
.. _libsodium: https://github.com/jedisct1/libsodium
|
||||
|
||||
Mac OS X & Windows
|
||||
~~~~~~~~~~~~~~~~~~
|
||||
|
||||
PyNaCl ships as a binary wheel on OS X and Windows so all dependencies
|
||||
are included. Make sure you have an up-to-date pip and run:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ pip install pynacl
|
||||
|
||||
|
||||
Features
|
||||
--------
|
||||
|
||||
* Digital signatures
|
||||
* Secret-key encryption
|
||||
* Public-key encryption
|
||||
|
||||
|
||||
Changes
|
||||
-------
|
||||
|
||||
* 1.0.1:
|
||||
|
||||
* Fix an issue with absolute paths that prevented the creation of wheels.
|
||||
|
||||
* 1.0:
|
||||
|
||||
* PyNaCl has been ported to use the new APIs available in cffi 1.0+.
|
||||
Due to this change we no longer support PyPy releases older than 2.6.
|
||||
|
||||
* Python 3.2 support has been dropped.
|
||||
|
||||
* Functions to convert between Ed25519 and Curve25519 keys have been added.
|
||||
|
||||
* 0.3.0:
|
||||
|
||||
* The low-level API (`nacl.c.*`) has been changed to match the
|
||||
upstream NaCl C/C++ conventions (as well as those of other NaCl bindings).
|
||||
The order of arguments and return values has changed significantly. To
|
||||
avoid silent failures, `nacl.c` has been removed, and replaced with
|
||||
`nacl.bindings` (with the new argument ordering). If you have code which
|
||||
calls these functions (e.g. `nacl.c.crypto_box_keypair()`), you must review
|
||||
the new docstrings and update your code/imports to match the new
|
||||
conventions.
|
||||
|
||||
|
||||
|
|
@ -0,0 +1 @@
|
|||
pip
|
||||
|
|
@ -0,0 +1,114 @@
|
|||
Metadata-Version: 2.0
|
||||
Name: PyNaCl
|
||||
Version: 1.0.1
|
||||
Summary: Python binding to the Networking and Cryptography (NaCl) library
|
||||
Home-page: https://github.com/pyca/pynacl/
|
||||
Author: The PyNaCl developers
|
||||
Author-email: cryptography-dev@python.org
|
||||
License: Apache License 2.0
|
||||
Platform: UNKNOWN
|
||||
Classifier: Programming Language :: Python :: Implementation :: CPython
|
||||
Classifier: Programming Language :: Python :: Implementation :: PyPy
|
||||
Classifier: Programming Language :: Python :: 2
|
||||
Classifier: Programming Language :: Python :: 2.6
|
||||
Classifier: Programming Language :: Python :: 2.7
|
||||
Classifier: Programming Language :: Python :: 3
|
||||
Classifier: Programming Language :: Python :: 3.3
|
||||
Classifier: Programming Language :: Python :: 3.4
|
||||
Classifier: Programming Language :: Python :: 3.5
|
||||
Requires-Dist: cffi (>=1.1.0)
|
||||
Requires-Dist: six
|
||||
Provides-Extra: tests
|
||||
Requires-Dist: pytest; extra == 'tests'
|
||||
|
||||
PyNaCl
|
||||
======
|
||||
|
||||
.. image:: https://pypip.in/version/PyNaCl/badge.svg?style=flat
|
||||
:target: https://pypi.python.org/pypi/PyNaCl/
|
||||
:alt: Latest Version
|
||||
|
||||
.. image:: https://travis-ci.org/pyca/pynacl.svg?branch=master
|
||||
:target: https://travis-ci.org/pyca/pynacl
|
||||
|
||||
.. image:: https://coveralls.io/repos/pyca/pynacl/badge.svg?branch=master
|
||||
:target: https://coveralls.io/r/pyca/pynacl?branch=master
|
||||
|
||||
PyNaCl is a Python binding to the `Networking and Cryptography library`_,
|
||||
a crypto library with the stated goal of improving usability, security and
|
||||
speed.
|
||||
|
||||
.. _Networking and Cryptography library: https://nacl.cr.yp.to/
|
||||
|
||||
|
||||
Installation
|
||||
------------
|
||||
|
||||
|
||||
Linux
|
||||
~~~~~
|
||||
|
||||
PyNaCl relies on libsodium_, a portable C library. A copy is bundled
|
||||
with PyNaCl so to install you can run:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ pip install pynacl
|
||||
|
||||
If you'd prefer to use one provided by your distribution you can disable
|
||||
the bundled copy during install by running:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ SODIUM_INSTALL=system pip install pynacl
|
||||
|
||||
|
||||
.. _libsodium: https://github.com/jedisct1/libsodium
|
||||
|
||||
Mac OS X & Windows
|
||||
~~~~~~~~~~~~~~~~~~
|
||||
|
||||
PyNaCl ships as a binary wheel on OS X and Windows so all dependencies
|
||||
are included. Make sure you have an up-to-date pip and run:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ pip install pynacl
|
||||
|
||||
|
||||
Features
|
||||
--------
|
||||
|
||||
* Digital signatures
|
||||
* Secret-key encryption
|
||||
* Public-key encryption
|
||||
|
||||
|
||||
Changes
|
||||
-------
|
||||
|
||||
* 1.0.1:
|
||||
|
||||
* Fix an issue with absolute paths that prevented the creation of wheels.
|
||||
|
||||
* 1.0:
|
||||
|
||||
* PyNaCl has been ported to use the new APIs available in cffi 1.0+.
|
||||
Due to this change we no longer support PyPy releases older than 2.6.
|
||||
|
||||
* Python 3.2 support has been dropped.
|
||||
|
||||
* Functions to convert between Ed25519 and Curve25519 keys have been added.
|
||||
|
||||
* 0.3.0:
|
||||
|
||||
* The low-level API (`nacl.c.*`) has been changed to match the
|
||||
upstream NaCl C/C++ conventions (as well as those of other NaCl bindings).
|
||||
The order of arguments and return values has changed significantly. To
|
||||
avoid silent failures, `nacl.c` has been removed, and replaced with
|
||||
`nacl.bindings` (with the new argument ordering). If you have code which
|
||||
calls these functions (e.g. `nacl.c.crypto_box_keypair()`), you must review
|
||||
the new docstrings and update your code/imports to match the new
|
||||
conventions.
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
PyNaCl-1.0.1.dist-info/DESCRIPTION.rst,sha256=3qnL8XwzGV1BtOJoieda8zEbuLvAY0wWIroa08MdjXY,2310
|
||||
PyNaCl-1.0.1.dist-info/METADATA,sha256=05N1teI88uA_YyplqlutrzCcpWNADo_66_DDqH483Eo,3194
|
||||
PyNaCl-1.0.1.dist-info/RECORD,,
|
||||
PyNaCl-1.0.1.dist-info/WHEEL,sha256=xiHTm3JxoVljPSD6nSGhq3B4VY9iUqMNXwYQ259n1PI,102
|
||||
PyNaCl-1.0.1.dist-info/metadata.json,sha256=a5yJUJb5gjQCsCaQGG4I-ZbQFD9z8LMYAstuJ8z8XZ0,1064
|
||||
PyNaCl-1.0.1.dist-info/top_level.txt,sha256=wfdEOI_G2RIzmzsMyhpqP17HUh6Jcqi99to9aHLEslo,13
|
||||
nacl/__init__.py,sha256=avK2K7KSLic4oYjD2vCcAYqKG5Ed277wMfZuwNwr2oc,1170
|
||||
nacl/_sodium.cp36-win32.pyd,sha256=XdBkQdVqquHCCaULVcrij_XAtlGq3faEgYwe96huYT4,183296
|
||||
nacl/encoding.py,sha256=tOiyIQVVpGU6A4Lzr0tMuqomhc_Aj0V_c1t56a-ZtPw,1928
|
||||
nacl/exceptions.py,sha256=cY0MvWUHpa443Qi9ZjikX2bg2zWC4ko0vChO90JEaf4,881
|
||||
nacl/hash.py,sha256=SP9wJIcs5bOg2l52JCpqe_p1BjwOA_NYdWsuHuJ9cFs,962
|
||||
nacl/public.py,sha256=BVD0UMu26mYhMrjxuKBC2xzP7YwILO0eLAKN5ZMTMK0,7519
|
||||
nacl/secret.py,sha256=704VLB1VR0FO8vuAILG2O3Idh8KYf2S3jAu_fidf5cU,4777
|
||||
nacl/signing.py,sha256=X8I0AUhA5jZH0pZi9c4uRqg9LQvCXZ1uBmGHJ1QTdmY,6661
|
||||
nacl/utils.py,sha256=E8TKyHN6g_xCOn7eB9KrmO7JIbayuLXbcpv4kWBu_rQ,1601
|
||||
nacl/bindings/__init__.py,sha256=GklZvnvt_q9Mlo9XOIG490c-y-G8CVD6gaWqyGncXtQ,3164
|
||||
nacl/bindings/crypto_box.py,sha256=k2hnwFH5nSyUBqaIYYWXpGJOlYvssC46XmIqxqH811k,5603
|
||||
nacl/bindings/crypto_hash.py,sha256=kBA-JVoRt9WkkFcF--prKz5BnWyHzLmif-PLt7FM4MU,1942
|
||||
nacl/bindings/crypto_scalarmult.py,sha256=VHTiWhkhbXxUiok9SyFxHzgEhBBYMZ7kuPpNZtW9nGs,1579
|
||||
nacl/bindings/crypto_secretbox.py,sha256=Q_E3fpCfhyvaKkb7ndwRfaHrMZk0aTDWUGWpguEUXeA,2641
|
||||
nacl/bindings/crypto_sign.py,sha256=fD_346rtF3CCtlXkjy30A-HGkXY1Z8GOSRv_0kUTsFg,4857
|
||||
nacl/bindings/randombytes.py,sha256=eThts6s-9xBXOl3GNzT57fV1dZUhzPjjAmAVIUHfcrc,988
|
||||
nacl/bindings/sodium_core.py,sha256=8B6CPFXlkmzPCRJ_Asvc-KFS5yaqe6OCYqsL5xOsvgE,950
|
||||
PyNaCl-1.0.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
||||
nacl/bindings/__pycache__/crypto_box.cpython-36.pyc,,
|
||||
nacl/bindings/__pycache__/crypto_hash.cpython-36.pyc,,
|
||||
nacl/bindings/__pycache__/crypto_scalarmult.cpython-36.pyc,,
|
||||
nacl/bindings/__pycache__/crypto_secretbox.cpython-36.pyc,,
|
||||
nacl/bindings/__pycache__/crypto_sign.cpython-36.pyc,,
|
||||
nacl/bindings/__pycache__/randombytes.cpython-36.pyc,,
|
||||
nacl/bindings/__pycache__/sodium_core.cpython-36.pyc,,
|
||||
nacl/bindings/__pycache__/__init__.cpython-36.pyc,,
|
||||
nacl/__pycache__/encoding.cpython-36.pyc,,
|
||||
nacl/__pycache__/exceptions.cpython-36.pyc,,
|
||||
nacl/__pycache__/hash.cpython-36.pyc,,
|
||||
nacl/__pycache__/public.cpython-36.pyc,,
|
||||
nacl/__pycache__/secret.cpython-36.pyc,,
|
||||
nacl/__pycache__/signing.cpython-36.pyc,,
|
||||
nacl/__pycache__/utils.cpython-36.pyc,,
|
||||
nacl/__pycache__/__init__.cpython-36.pyc,,
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
Wheel-Version: 1.0
|
||||
Generator: bdist_wheel (0.29.0)
|
||||
Root-Is-Purelib: false
|
||||
Tag: cp36-cp36m-win32
|
||||
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
_sodium
|
||||
nacl
|
||||
Binary file not shown.
|
|
@ -0,0 +1,352 @@
|
|||
Metadata-Version: 1.1
|
||||
Name: aiohttp
|
||||
Version: 1.0.5
|
||||
Summary: http client/server for asyncio
|
||||
Home-page: https://github.com/KeepSafe/aiohttp/
|
||||
Author: Andrew Svetlov
|
||||
Author-email: andrew.svetlov@gmail.com
|
||||
License: Apache 2
|
||||
Description: http client/server for asyncio
|
||||
==============================
|
||||
|
||||
.. image:: https://raw.github.com/KeepSafe/aiohttp/master/docs/_static/aiohttp-icon-128x128.png
|
||||
:height: 64px
|
||||
:width: 64px
|
||||
:alt: aiohttp logo
|
||||
|
||||
.. image:: https://travis-ci.org/KeepSafe/aiohttp.svg?branch=master
|
||||
:target: https://travis-ci.org/KeepSafe/aiohttp
|
||||
:align: right
|
||||
|
||||
.. image:: https://codecov.io/gh/KeepSafe/aiohttp/branch/master/graph/badge.svg
|
||||
:target: https://codecov.io/gh/KeepSafe/aiohttp
|
||||
|
||||
.. image:: https://badge.fury.io/py/aiohttp.svg
|
||||
:target: https://badge.fury.io/py/aiohttp
|
||||
|
||||
Features
|
||||
--------
|
||||
|
||||
- Supports both client and server side of HTTP protocol.
|
||||
- Supports both client and server Web-Sockets out-of-the-box.
|
||||
- Web-server has middlewares and pluggable routing.
|
||||
|
||||
|
||||
Getting started
|
||||
---------------
|
||||
|
||||
Client
|
||||
^^^^^^
|
||||
|
||||
To retrieve something from the web:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import aiohttp
|
||||
import asyncio
|
||||
|
||||
async def fetch(session, url):
|
||||
with aiohttp.Timeout(10, loop=session.loop):
|
||||
async with session.get(url) as response:
|
||||
return await response.text()
|
||||
|
||||
async def main(loop):
|
||||
async with aiohttp.ClientSession(loop=loop) as session:
|
||||
html = await fetch(session, 'http://python.org')
|
||||
print(html)
|
||||
|
||||
if __name__ == '__main__':
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(main(loop))
|
||||
|
||||
|
||||
Server
|
||||
^^^^^^
|
||||
|
||||
This is simple usage example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
async def handle(request):
|
||||
name = request.match_info.get('name', "Anonymous")
|
||||
text = "Hello, " + name
|
||||
return web.Response(text=text)
|
||||
|
||||
async def wshandler(request):
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
async for msg in ws:
|
||||
if msg.type == web.MsgType.text:
|
||||
ws.send_str("Hello, {}".format(msg.data))
|
||||
elif msg.type == web.MsgType.binary:
|
||||
ws.send_bytes(msg.data)
|
||||
elif msg.type == web.MsgType.close:
|
||||
break
|
||||
|
||||
return ws
|
||||
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_get('/echo', wshandler)
|
||||
app.router.add_get('/', handle)
|
||||
app.router.add_get('/{name}', handle)
|
||||
|
||||
web.run_app(app)
|
||||
|
||||
|
||||
Note: examples are written for Python 3.5+ and utilize PEP-492 aka
|
||||
async/await. If you are using Python 3.4 please replace ``await`` with
|
||||
``yield from`` and ``async def`` with ``@coroutine`` e.g.::
|
||||
|
||||
async def coro(...):
|
||||
ret = await f()
|
||||
|
||||
should be replaced by::
|
||||
|
||||
@asyncio.coroutine
|
||||
def coro(...):
|
||||
ret = yield from f()
|
||||
|
||||
Documentation
|
||||
-------------
|
||||
|
||||
https://aiohttp.readthedocs.io/
|
||||
|
||||
Discussion list
|
||||
---------------
|
||||
|
||||
*aio-libs* google group: https://groups.google.com/forum/#!forum/aio-libs
|
||||
|
||||
Requirements
|
||||
------------
|
||||
|
||||
- Python >= 3.4.2
|
||||
- chardet_
|
||||
- multidict_
|
||||
|
||||
Optionally you may install the cChardet_ and aiodns_ libraries (highly
|
||||
recommended for sake of speed).
|
||||
|
||||
.. _chardet: https://pypi.python.org/pypi/chardet
|
||||
.. _aiodns: https://pypi.python.org/pypi/aiodns
|
||||
.. _multidict: https://pypi.python.org/pypi/multidict
|
||||
.. _cChardet: https://pypi.python.org/pypi/cchardet
|
||||
|
||||
License
|
||||
-------
|
||||
|
||||
``aiohttp`` is offered under the Apache 2 license.
|
||||
|
||||
|
||||
Source code
|
||||
------------
|
||||
|
||||
The latest developer version is available in a github repository:
|
||||
https://github.com/KeepSafe/aiohttp
|
||||
|
||||
Benchmarks
|
||||
----------
|
||||
|
||||
If you are interested in by efficiency, AsyncIO community maintains a
|
||||
list of benchmarks on the official wiki:
|
||||
https://github.com/python/asyncio/wiki/Benchmarks
|
||||
|
||||
CHANGES
|
||||
=======
|
||||
|
||||
1.0.5 (2016-10-11)
|
||||
------------------
|
||||
|
||||
- Fix StreamReader._read_nowait to return all available
|
||||
data up to the requested amount #1297
|
||||
|
||||
|
||||
1.0.4 (2016-09-22)
|
||||
------------------
|
||||
|
||||
- Fix FlowControlStreamReader.read_nowait so that it checks
|
||||
whether the transport is paused #1206
|
||||
|
||||
|
||||
1.0.2 (2016-09-22)
|
||||
------------------
|
||||
|
||||
- Make CookieJar compatible with 32-bit systems #1188
|
||||
|
||||
- Add missing `WSMsgType` to `web_ws.__all__`, see #1200
|
||||
|
||||
- Fix `CookieJar` ctor when called with `loop=None` #1203
|
||||
|
||||
- Fix broken upper-casing in wsgi support #1197
|
||||
|
||||
|
||||
1.0.1 (2016-09-16)
|
||||
------------------
|
||||
|
||||
- Restore `aiohttp.web.MsgType` alias for `aiohttp.WSMsgType` for sake
|
||||
of backward compatibility #1178
|
||||
|
||||
- Tune alabaster schema.
|
||||
|
||||
- Use `text/html` content type for displaying index pages by static
|
||||
file handler.
|
||||
|
||||
- Fix `AssertionError` in static file handling #1177
|
||||
|
||||
- Fix access log formats `%O` and `%b` for static file handling
|
||||
|
||||
- Remove `debug` setting of GunicornWorker, use `app.debug`
|
||||
to control its debug-mode instead
|
||||
|
||||
|
||||
1.0.0 (2016-09-16)
|
||||
-------------------
|
||||
|
||||
- Change default size for client session's connection pool from
|
||||
unlimited to 20 #977
|
||||
|
||||
- Add IE support for cookie deletion. #994
|
||||
|
||||
- Remove deprecated `WebSocketResponse.wait_closed` method (BACKWARD
|
||||
INCOMPATIBLE)
|
||||
|
||||
- Remove deprecated `force` parameter for `ClientResponse.close`
|
||||
method (BACKWARD INCOMPATIBLE)
|
||||
|
||||
- Avoid using of mutable CIMultiDict kw param in make_mocked_request
|
||||
#997
|
||||
|
||||
- Make WebSocketResponse.close a little bit faster by avoiding new
|
||||
task creating just for timeout measurement
|
||||
|
||||
- Add `proxy` and `proxy_auth` params to `client.get()` and family,
|
||||
deprecate `ProxyConnector` #998
|
||||
|
||||
- Add support for websocket send_json and receive_json, synchronize
|
||||
server and client API for websockets #984
|
||||
|
||||
- Implement router shourtcuts for most useful HTTP methods, use
|
||||
`app.router.add_get()`, `app.router.add_post()` etc. instead of
|
||||
`app.router.add_route()` #986
|
||||
|
||||
- Support SSL connections for gunicorn worker #1003
|
||||
|
||||
- Move obsolete examples to legacy folder
|
||||
|
||||
- Switch to multidict 2.0 and title-cased strings #1015
|
||||
|
||||
- `{FOO}e` logger format is case-sensitive now
|
||||
|
||||
- Fix logger report for unix socket 8e8469b
|
||||
|
||||
- Rename aiohttp.websocket to aiohttp._ws_impl
|
||||
|
||||
- Rename aiohttp.MsgType tp aiohttp.WSMsgType
|
||||
|
||||
- Introduce aiohttp.WSMessage officially
|
||||
|
||||
- Rename Message -> WSMessage
|
||||
|
||||
- Remove deprecated decode param from resp.read(decode=True)
|
||||
|
||||
- Use 5min default client timeout #1028
|
||||
|
||||
- Relax HTTP method validation in UrlDispatcher #1037
|
||||
|
||||
- Pin minimal supported asyncio version to 3.4.2+ (`loop.is_close()`
|
||||
should be present)
|
||||
|
||||
- Remove aiohttp.websocket module (BACKWARD INCOMPATIBLE)
|
||||
Please use high-level client and server approaches
|
||||
|
||||
- Link header for 451 status code is mandatory
|
||||
|
||||
- Fix test_client fixture to allow multiple clients per test #1072
|
||||
|
||||
- make_mocked_request now accepts dict as headers #1073
|
||||
|
||||
- Add Python 3.5.2/3.6+ compatibility patch for async generator
|
||||
protocol change #1082
|
||||
|
||||
- Improvement test_client can accept instance object #1083
|
||||
|
||||
- Simplify ServerHttpProtocol implementation #1060
|
||||
|
||||
- Add a flag for optional showing directory index for static file
|
||||
handling #921
|
||||
|
||||
- Define `web.Application.on_startup()` signal handler #1103
|
||||
|
||||
- Drop ChunkedParser and LinesParser #1111
|
||||
|
||||
- Call `Application.startup` in GunicornWebWorker #1105
|
||||
|
||||
- Fix client handling hostnames with 63 bytes when a port is given in
|
||||
the url #1044
|
||||
|
||||
- Implement proxy support for ClientSession.ws_connect #1025
|
||||
|
||||
- Return named tuple from WebSocketResponse.can_prepare #1016
|
||||
|
||||
- Fix access_log_format in `GunicornWebWorker` #1117
|
||||
|
||||
- Setup Content-Type to application/octet-stream by default #1124
|
||||
|
||||
- Deprecate debug parameter from app.make_handler(), use
|
||||
`Application(debug=True)` instead #1121
|
||||
|
||||
- Remove fragment string in request path #846
|
||||
|
||||
- Use aiodns.DNSResolver.gethostbyname() if available #1136
|
||||
|
||||
- Fix static file sending on uvloop when sendfile is available #1093
|
||||
|
||||
- Make prettier urls if query is empty dict #1143
|
||||
|
||||
- Fix redirects for HEAD requests #1147
|
||||
|
||||
- Default value for `StreamReader.read_nowait` is -1 from now #1150
|
||||
|
||||
- `aiohttp.StreamReader` is not inherited from `asyncio.StreamReader` from now
|
||||
(BACKWARD INCOMPATIBLE) #1150
|
||||
|
||||
- Streams documentation added #1150
|
||||
|
||||
- Add `multipart` coroutine method for web Request object #1067
|
||||
|
||||
- Publish ClientSession.loop property #1149
|
||||
|
||||
- Fix static file with spaces #1140
|
||||
|
||||
- Fix piling up asyncio loop by cookie expiration callbacks #1061
|
||||
|
||||
- Drop `Timeout` class for sake of `async_timeout` external library.
|
||||
`aiohttp.Timeout` is an alias for `async_timeout.timeout`
|
||||
|
||||
- `use_dns_cache` parameter of `aiohttp.TCPConnector` is `True` by
|
||||
default (BACKWARD INCOMPATIBLE) #1152
|
||||
|
||||
- `aiohttp.TCPConnector` uses asynchronous DNS resolver if available by
|
||||
default (BACKWARD INCOMPATIBLE) #1152
|
||||
|
||||
- Conform to RFC3986 - do not include url fragments in client requests #1174
|
||||
|
||||
- Drop `ClientSession.cookies` (BACKWARD INCOMPATIBLE) #1173
|
||||
|
||||
- Refactor `AbstractCookieJar` public API (BACKWARD INCOMPATIBLE) #1173
|
||||
|
||||
- Fix clashing cookies with have the same name but belong to different
|
||||
domains (BACKWARD INCOMPATIBLE) #1125
|
||||
|
||||
- Support binary Content-Transfer-Encoding #1169
|
||||
Platform: UNKNOWN
|
||||
Classifier: License :: OSI Approved :: Apache Software License
|
||||
Classifier: Intended Audience :: Developers
|
||||
Classifier: Programming Language :: Python
|
||||
Classifier: Programming Language :: Python :: 3
|
||||
Classifier: Programming Language :: Python :: 3.4
|
||||
Classifier: Programming Language :: Python :: 3.5
|
||||
Classifier: Topic :: Internet :: WWW/HTTP
|
||||
|
|
@ -0,0 +1,162 @@
|
|||
CHANGES.rst
|
||||
CONTRIBUTORS.txt
|
||||
LICENSE.txt
|
||||
MANIFEST.in
|
||||
Makefile
|
||||
README.rst
|
||||
setup.cfg
|
||||
setup.py
|
||||
aiohttp/__init__.py
|
||||
aiohttp/_websocket.c
|
||||
aiohttp/_websocket.pyx
|
||||
aiohttp/_ws_impl.py
|
||||
aiohttp/abc.py
|
||||
aiohttp/client.py
|
||||
aiohttp/client_reqrep.py
|
||||
aiohttp/client_ws.py
|
||||
aiohttp/connector.py
|
||||
aiohttp/cookiejar.py
|
||||
aiohttp/errors.py
|
||||
aiohttp/file_sender.py
|
||||
aiohttp/hdrs.py
|
||||
aiohttp/helpers.py
|
||||
aiohttp/log.py
|
||||
aiohttp/multipart.py
|
||||
aiohttp/parsers.py
|
||||
aiohttp/protocol.py
|
||||
aiohttp/pytest_plugin.py
|
||||
aiohttp/resolver.py
|
||||
aiohttp/server.py
|
||||
aiohttp/signals.py
|
||||
aiohttp/streams.py
|
||||
aiohttp/test_utils.py
|
||||
aiohttp/web.py
|
||||
aiohttp/web_exceptions.py
|
||||
aiohttp/web_reqrep.py
|
||||
aiohttp/web_urldispatcher.py
|
||||
aiohttp/web_ws.py
|
||||
aiohttp/worker.py
|
||||
aiohttp/wsgi.py
|
||||
aiohttp.egg-info/PKG-INFO
|
||||
aiohttp.egg-info/SOURCES.txt
|
||||
aiohttp.egg-info/dependency_links.txt
|
||||
aiohttp.egg-info/requires.txt
|
||||
aiohttp.egg-info/top_level.txt
|
||||
docs/Makefile
|
||||
docs/abc.rst
|
||||
docs/aiohttp-icon.ico
|
||||
docs/aiohttp-icon.svg
|
||||
docs/api.rst
|
||||
docs/changes.rst
|
||||
docs/client.rst
|
||||
docs/client_reference.rst
|
||||
docs/conf.py
|
||||
docs/contributing.rst
|
||||
docs/faq.rst
|
||||
docs/glossary.rst
|
||||
docs/gunicorn.rst
|
||||
docs/index.rst
|
||||
docs/logging.rst
|
||||
docs/make.bat
|
||||
docs/multipart.rst
|
||||
docs/new_router.rst
|
||||
docs/server.rst
|
||||
docs/spelling_wordlist.txt
|
||||
docs/streams.rst
|
||||
docs/testing.rst
|
||||
docs/tutorial.rst
|
||||
docs/web.rst
|
||||
docs/web_reference.rst
|
||||
docs/_static/aiohttp-icon-128x128.png
|
||||
docs/_static/aiohttp-icon-32x32.png
|
||||
docs/_static/aiohttp-icon-64x64.png
|
||||
docs/_static/aiohttp-icon-96x96.png
|
||||
examples/background_tasks.py
|
||||
examples/basic_srv.py
|
||||
examples/cli_app.py
|
||||
examples/client_auth.py
|
||||
examples/client_json.py
|
||||
examples/client_ws.py
|
||||
examples/curl.py
|
||||
examples/fake_server.py
|
||||
examples/server.crt
|
||||
examples/server.csr
|
||||
examples/server.key
|
||||
examples/static_files.py
|
||||
examples/web_classview1.py
|
||||
examples/web_cookies.py
|
||||
examples/web_rewrite_headers_middleware.py
|
||||
examples/web_srv.py
|
||||
examples/web_ws.py
|
||||
examples/websocket.html
|
||||
examples/legacy/crawl.py
|
||||
examples/legacy/srv.py
|
||||
examples/legacy/tcp_protocol_parser.py
|
||||
tests/conftest.py
|
||||
tests/data.unknown_mime_type
|
||||
tests/hello.txt.gz
|
||||
tests/sample.crt
|
||||
tests/sample.crt.der
|
||||
tests/sample.key
|
||||
tests/software_development_in_picture.jpg
|
||||
tests/test_classbasedview.py
|
||||
tests/test_client_connection.py
|
||||
tests/test_client_functional.py
|
||||
tests/test_client_functional_oldstyle.py
|
||||
tests/test_client_request.py
|
||||
tests/test_client_response.py
|
||||
tests/test_client_session.py
|
||||
tests/test_client_ws.py
|
||||
tests/test_client_ws_functional.py
|
||||
tests/test_connector.py
|
||||
tests/test_cookiejar.py
|
||||
tests/test_errors.py
|
||||
tests/test_flowcontrol_streams.py
|
||||
tests/test_helpers.py
|
||||
tests/test_http_parser.py
|
||||
tests/test_multipart.py
|
||||
tests/test_parser_buffer.py
|
||||
tests/test_protocol.py
|
||||
tests/test_proxy.py
|
||||
tests/test_pytest_plugin.py
|
||||
tests/test_resolver.py
|
||||
tests/test_run_app.py
|
||||
tests/test_server.py
|
||||
tests/test_signals.py
|
||||
tests/test_stream_parser.py
|
||||
tests/test_stream_protocol.py
|
||||
tests/test_stream_writer.py
|
||||
tests/test_streams.py
|
||||
tests/test_test_utils.py
|
||||
tests/test_urldispatch.py
|
||||
tests/test_web_application.py
|
||||
tests/test_web_cli.py
|
||||
tests/test_web_exceptions.py
|
||||
tests/test_web_functional.py
|
||||
tests/test_web_middleware.py
|
||||
tests/test_web_request.py
|
||||
tests/test_web_request_handler.py
|
||||
tests/test_web_response.py
|
||||
tests/test_web_sendfile.py
|
||||
tests/test_web_sendfile_functional.py
|
||||
tests/test_web_urldispatcher.py
|
||||
tests/test_web_websocket.py
|
||||
tests/test_web_websocket_functional.py
|
||||
tests/test_web_websocket_functional_oldstyle.py
|
||||
tests/test_websocket_handshake.py
|
||||
tests/test_websocket_parser.py
|
||||
tests/test_websocket_writer.py
|
||||
tests/test_worker.py
|
||||
tests/test_wsgi.py
|
||||
tests/autobahn/client.py
|
||||
tests/autobahn/fuzzingclient.json
|
||||
tests/autobahn/fuzzingserver.json
|
||||
tests/autobahn/server.py
|
||||
tests/test_py35/test_cbv35.py
|
||||
tests/test_py35/test_client.py
|
||||
tests/test_py35/test_client_websocket_35.py
|
||||
tests/test_py35/test_multipart_35.py
|
||||
tests/test_py35/test_resp.py
|
||||
tests/test_py35/test_streams_35.py
|
||||
tests/test_py35/test_test_utils_35.py
|
||||
tests/test_py35/test_web_websocket_35.py
|
||||
|
|
@ -0,0 +1 @@
|
|||
|
||||
|
|
@ -0,0 +1,66 @@
|
|||
..\aiohttp\abc.py
|
||||
..\aiohttp\client.py
|
||||
..\aiohttp\client_reqrep.py
|
||||
..\aiohttp\client_ws.py
|
||||
..\aiohttp\connector.py
|
||||
..\aiohttp\cookiejar.py
|
||||
..\aiohttp\errors.py
|
||||
..\aiohttp\file_sender.py
|
||||
..\aiohttp\hdrs.py
|
||||
..\aiohttp\helpers.py
|
||||
..\aiohttp\log.py
|
||||
..\aiohttp\multipart.py
|
||||
..\aiohttp\parsers.py
|
||||
..\aiohttp\protocol.py
|
||||
..\aiohttp\pytest_plugin.py
|
||||
..\aiohttp\resolver.py
|
||||
..\aiohttp\server.py
|
||||
..\aiohttp\signals.py
|
||||
..\aiohttp\streams.py
|
||||
..\aiohttp\test_utils.py
|
||||
..\aiohttp\web.py
|
||||
..\aiohttp\web_exceptions.py
|
||||
..\aiohttp\web_reqrep.py
|
||||
..\aiohttp\web_urldispatcher.py
|
||||
..\aiohttp\web_ws.py
|
||||
..\aiohttp\worker.py
|
||||
..\aiohttp\wsgi.py
|
||||
..\aiohttp\_ws_impl.py
|
||||
..\aiohttp\__init__.py
|
||||
..\aiohttp\_websocket.c
|
||||
..\aiohttp\_websocket.pyx
|
||||
..\aiohttp\__pycache__\abc.cpython-36.pyc
|
||||
..\aiohttp\__pycache__\client.cpython-36.pyc
|
||||
..\aiohttp\__pycache__\client_reqrep.cpython-36.pyc
|
||||
..\aiohttp\__pycache__\client_ws.cpython-36.pyc
|
||||
..\aiohttp\__pycache__\connector.cpython-36.pyc
|
||||
..\aiohttp\__pycache__\cookiejar.cpython-36.pyc
|
||||
..\aiohttp\__pycache__\errors.cpython-36.pyc
|
||||
..\aiohttp\__pycache__\file_sender.cpython-36.pyc
|
||||
..\aiohttp\__pycache__\hdrs.cpython-36.pyc
|
||||
..\aiohttp\__pycache__\helpers.cpython-36.pyc
|
||||
..\aiohttp\__pycache__\log.cpython-36.pyc
|
||||
..\aiohttp\__pycache__\multipart.cpython-36.pyc
|
||||
..\aiohttp\__pycache__\parsers.cpython-36.pyc
|
||||
..\aiohttp\__pycache__\protocol.cpython-36.pyc
|
||||
..\aiohttp\__pycache__\pytest_plugin.cpython-36.pyc
|
||||
..\aiohttp\__pycache__\resolver.cpython-36.pyc
|
||||
..\aiohttp\__pycache__\server.cpython-36.pyc
|
||||
..\aiohttp\__pycache__\signals.cpython-36.pyc
|
||||
..\aiohttp\__pycache__\streams.cpython-36.pyc
|
||||
..\aiohttp\__pycache__\test_utils.cpython-36.pyc
|
||||
..\aiohttp\__pycache__\web.cpython-36.pyc
|
||||
..\aiohttp\__pycache__\web_exceptions.cpython-36.pyc
|
||||
..\aiohttp\__pycache__\web_reqrep.cpython-36.pyc
|
||||
..\aiohttp\__pycache__\web_urldispatcher.cpython-36.pyc
|
||||
..\aiohttp\__pycache__\web_ws.cpython-36.pyc
|
||||
..\aiohttp\__pycache__\worker.cpython-36.pyc
|
||||
..\aiohttp\__pycache__\wsgi.cpython-36.pyc
|
||||
..\aiohttp\__pycache__\_ws_impl.cpython-36.pyc
|
||||
..\aiohttp\__pycache__\__init__.cpython-36.pyc
|
||||
..\aiohttp\_websocket.cp36-win32.pyd
|
||||
dependency_links.txt
|
||||
PKG-INFO
|
||||
requires.txt
|
||||
SOURCES.txt
|
||||
top_level.txt
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
chardet
|
||||
multidict>=2.0
|
||||
async_timeout
|
||||
|
|
@ -0,0 +1 @@
|
|||
aiohttp
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
__version__ = '1.0.5'
|
||||
|
||||
# Deprecated, keep it here for a while for backward compatibility.
|
||||
import multidict # noqa
|
||||
|
||||
# This relies on each of the submodules having an __all__ variable.
|
||||
|
||||
from multidict import * # noqa
|
||||
from . import hdrs # noqa
|
||||
from .protocol import * # noqa
|
||||
from .connector import * # noqa
|
||||
from .client import * # noqa
|
||||
from .client_reqrep import * # noqa
|
||||
from .errors import * # noqa
|
||||
from .helpers import * # noqa
|
||||
from .parsers import * # noqa
|
||||
from .streams import * # noqa
|
||||
from .multipart import * # noqa
|
||||
from .client_ws import ClientWebSocketResponse # noqa
|
||||
from ._ws_impl import WSMsgType, WSCloseCode, WSMessage, WebSocketError # noqa
|
||||
from .file_sender import FileSender # noqa
|
||||
from .cookiejar import CookieJar # noqa
|
||||
from .resolver import * # noqa
|
||||
|
||||
|
||||
MsgType = WSMsgType # backward compatibility
|
||||
|
||||
|
||||
__all__ = (client.__all__ + # noqa
|
||||
client_reqrep.__all__ + # noqa
|
||||
errors.__all__ + # noqa
|
||||
helpers.__all__ + # noqa
|
||||
parsers.__all__ + # noqa
|
||||
protocol.__all__ + # noqa
|
||||
connector.__all__ + # noqa
|
||||
streams.__all__ + # noqa
|
||||
multidict.__all__ + # noqa
|
||||
multipart.__all__ + # noqa
|
||||
('hdrs', 'FileSender', 'WSMsgType', 'MsgType', 'WSCloseCode',
|
||||
'WebSocketError', 'WSMessage',
|
||||
'ClientWebSocketResponse', 'CookieJar'))
|
||||
File diff suppressed because it is too large
Load Diff
Binary file not shown.
|
|
@ -0,0 +1,48 @@
|
|||
from cpython cimport PyBytes_AsString
|
||||
|
||||
#from cpython cimport PyByteArray_AsString # cython still not exports that
|
||||
cdef extern from "Python.h":
|
||||
char* PyByteArray_AsString(bytearray ba) except NULL
|
||||
|
||||
from libc.stdint cimport uint32_t, uint64_t, uintmax_t
|
||||
|
||||
def _websocket_mask_cython(bytes mask, bytearray data):
|
||||
"""Note, this function mutates it's `data` argument
|
||||
"""
|
||||
cdef:
|
||||
Py_ssize_t data_len, i
|
||||
# bit operations on signed integers are implementation-specific
|
||||
unsigned char * in_buf
|
||||
const unsigned char * mask_buf
|
||||
uint32_t uint32_msk
|
||||
uint64_t uint64_msk
|
||||
|
||||
assert len(mask) == 4
|
||||
|
||||
data_len = len(data)
|
||||
in_buf = <unsigned char*>PyByteArray_AsString(data)
|
||||
mask_buf = <const unsigned char*>PyBytes_AsString(mask)
|
||||
uint32_msk = (<uint32_t*>mask_buf)[0]
|
||||
|
||||
# TODO: align in_data ptr to achieve even faster speeds
|
||||
# does it need in python ?! malloc() always aligns to sizeof(long) bytes
|
||||
|
||||
if sizeof(size_t) >= 8:
|
||||
uint64_msk = uint32_msk
|
||||
uint64_msk = (uint64_msk << 32) | uint32_msk
|
||||
|
||||
while data_len >= 8:
|
||||
(<uint64_t*>in_buf)[0] ^= uint64_msk
|
||||
in_buf += 8
|
||||
data_len -= 8
|
||||
|
||||
|
||||
while data_len >= 4:
|
||||
(<uint32_t*>in_buf)[0] ^= uint32_msk
|
||||
in_buf += 4
|
||||
data_len -= 4
|
||||
|
||||
for i in range(0, data_len):
|
||||
in_buf[i] ^= mask_buf[i]
|
||||
|
||||
return data
|
||||
|
|
@ -0,0 +1,438 @@
|
|||
"""WebSocket protocol versions 13 and 8."""
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
import collections
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
from enum import IntEnum
|
||||
from struct import Struct
|
||||
|
||||
from aiohttp import errors, hdrs
|
||||
from aiohttp.log import ws_logger
|
||||
|
||||
__all__ = ('WebSocketParser', 'WebSocketWriter', 'do_handshake',
|
||||
'WSMessage', 'WebSocketError', 'WSMsgType', 'WSCloseCode')
|
||||
|
||||
|
||||
class WSCloseCode(IntEnum):
|
||||
OK = 1000
|
||||
GOING_AWAY = 1001
|
||||
PROTOCOL_ERROR = 1002
|
||||
UNSUPPORTED_DATA = 1003
|
||||
INVALID_TEXT = 1007
|
||||
POLICY_VIOLATION = 1008
|
||||
MESSAGE_TOO_BIG = 1009
|
||||
MANDATORY_EXTENSION = 1010
|
||||
INTERNAL_ERROR = 1011
|
||||
SERVICE_RESTART = 1012
|
||||
TRY_AGAIN_LATER = 1013
|
||||
|
||||
|
||||
ALLOWED_CLOSE_CODES = {int(i) for i in WSCloseCode}
|
||||
|
||||
|
||||
class WSMsgType(IntEnum):
|
||||
CONTINUATION = 0x0
|
||||
TEXT = 0x1
|
||||
BINARY = 0x2
|
||||
PING = 0x9
|
||||
PONG = 0xa
|
||||
CLOSE = 0x8
|
||||
CLOSED = 0x101
|
||||
ERROR = 0x102
|
||||
|
||||
text = TEXT
|
||||
binary = BINARY
|
||||
ping = PING
|
||||
pong = PONG
|
||||
close = CLOSE
|
||||
closed = CLOSED
|
||||
error = ERROR
|
||||
|
||||
|
||||
WS_KEY = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
|
||||
|
||||
|
||||
UNPACK_LEN2 = Struct('!H').unpack_from
|
||||
UNPACK_LEN3 = Struct('!Q').unpack_from
|
||||
UNPACK_CLOSE_CODE = Struct('!H').unpack
|
||||
PACK_LEN1 = Struct('!BB').pack
|
||||
PACK_LEN2 = Struct('!BBH').pack
|
||||
PACK_LEN3 = Struct('!BBQ').pack
|
||||
PACK_CLOSE_CODE = Struct('!H').pack
|
||||
MSG_SIZE = 2 ** 14
|
||||
|
||||
|
||||
_WSMessageBase = collections.namedtuple('_WSMessageBase',
|
||||
['type', 'data', 'extra'])
|
||||
|
||||
|
||||
class WSMessage(_WSMessageBase):
|
||||
def json(self, *, loads=json.loads):
|
||||
"""Return parsed JSON data.
|
||||
|
||||
.. versionadded:: 0.22
|
||||
"""
|
||||
return loads(self.data)
|
||||
|
||||
@property
|
||||
def tp(self):
|
||||
return self.type
|
||||
|
||||
|
||||
CLOSED_MESSAGE = WSMessage(WSMsgType.CLOSED, None, None)
|
||||
|
||||
|
||||
class WebSocketError(Exception):
|
||||
"""WebSocket protocol parser error."""
|
||||
|
||||
def __init__(self, code, message):
|
||||
self.code = code
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
def WebSocketParser(out, buf):
|
||||
while True:
|
||||
fin, opcode, payload = yield from parse_frame(buf)
|
||||
|
||||
if opcode == WSMsgType.CLOSE:
|
||||
if len(payload) >= 2:
|
||||
close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
|
||||
if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
'Invalid close code: {}'.format(close_code))
|
||||
try:
|
||||
close_message = payload[2:].decode('utf-8')
|
||||
except UnicodeDecodeError as exc:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.INVALID_TEXT,
|
||||
'Invalid UTF-8 text message') from exc
|
||||
msg = WSMessage(WSMsgType.CLOSE, close_code, close_message)
|
||||
elif payload:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
'Invalid close frame: {} {} {!r}'.format(
|
||||
fin, opcode, payload))
|
||||
else:
|
||||
msg = WSMessage(WSMsgType.CLOSE, 0, '')
|
||||
|
||||
out.feed_data(msg, 0)
|
||||
|
||||
elif opcode == WSMsgType.PING:
|
||||
out.feed_data(WSMessage(WSMsgType.PING, payload, ''), len(payload))
|
||||
|
||||
elif opcode == WSMsgType.PONG:
|
||||
out.feed_data(WSMessage(WSMsgType.PONG, payload, ''), len(payload))
|
||||
|
||||
elif opcode not in (WSMsgType.TEXT, WSMsgType.BINARY):
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
"Unexpected opcode={!r}".format(opcode))
|
||||
else:
|
||||
# load text/binary
|
||||
data = [payload]
|
||||
|
||||
while not fin:
|
||||
fin, _opcode, payload = yield from parse_frame(buf, True)
|
||||
|
||||
# We can receive ping/close in the middle of
|
||||
# text message, Case 5.*
|
||||
if _opcode == WSMsgType.PING:
|
||||
out.feed_data(
|
||||
WSMessage(WSMsgType.PING, payload, ''), len(payload))
|
||||
fin, _opcode, payload = yield from parse_frame(buf, True)
|
||||
elif _opcode == WSMsgType.CLOSE:
|
||||
if len(payload) >= 2:
|
||||
close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
|
||||
if (close_code not in ALLOWED_CLOSE_CODES and
|
||||
close_code < 3000):
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
'Invalid close code: {}'.format(close_code))
|
||||
try:
|
||||
close_message = payload[2:].decode('utf-8')
|
||||
except UnicodeDecodeError as exc:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.INVALID_TEXT,
|
||||
'Invalid UTF-8 text message') from exc
|
||||
msg = WSMessage(WSMsgType.CLOSE, close_code,
|
||||
close_message)
|
||||
elif payload:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
'Invalid close frame: {} {} {!r}'.format(
|
||||
fin, opcode, payload))
|
||||
else:
|
||||
msg = WSMessage(WSMsgType.CLOSE, 0, '')
|
||||
|
||||
out.feed_data(msg, 0)
|
||||
fin, _opcode, payload = yield from parse_frame(buf, True)
|
||||
|
||||
if _opcode != WSMsgType.CONTINUATION:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
'The opcode in non-fin frame is expected '
|
||||
'to be zero, got {!r}'.format(_opcode))
|
||||
else:
|
||||
data.append(payload)
|
||||
|
||||
if opcode == WSMsgType.TEXT:
|
||||
try:
|
||||
text = b''.join(data).decode('utf-8')
|
||||
out.feed_data(WSMessage(WSMsgType.TEXT, text, ''),
|
||||
len(text))
|
||||
except UnicodeDecodeError as exc:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.INVALID_TEXT,
|
||||
'Invalid UTF-8 text message') from exc
|
||||
else:
|
||||
data = b''.join(data)
|
||||
out.feed_data(
|
||||
WSMessage(WSMsgType.BINARY, data, ''), len(data))
|
||||
|
||||
|
||||
native_byteorder = sys.byteorder
|
||||
|
||||
|
||||
def _websocket_mask_python(mask, data):
|
||||
"""Websocket masking function.
|
||||
|
||||
`mask` is a `bytes` object of length 4; `data` is a `bytes` object
|
||||
of any length. Returns a `bytes` object of the same length as
|
||||
`data` with the mask applied as specified in section 5.3 of RFC
|
||||
6455.
|
||||
|
||||
This pure-python implementation may be replaced by an optimized
|
||||
version when available.
|
||||
|
||||
"""
|
||||
assert isinstance(data, bytearray), data
|
||||
assert len(mask) == 4, mask
|
||||
datalen = len(data)
|
||||
if datalen == 0:
|
||||
# everything work without this, but may be changed later in Python.
|
||||
return bytearray()
|
||||
data = int.from_bytes(data, native_byteorder)
|
||||
mask = int.from_bytes(mask * (datalen // 4) + mask[: datalen % 4],
|
||||
native_byteorder)
|
||||
return (data ^ mask).to_bytes(datalen, native_byteorder)
|
||||
|
||||
|
||||
if bool(os.environ.get('AIOHTTP_NO_EXTENSIONS')):
|
||||
_websocket_mask = _websocket_mask_python
|
||||
else:
|
||||
try:
|
||||
from ._websocket import _websocket_mask_cython
|
||||
_websocket_mask = _websocket_mask_cython
|
||||
except ImportError: # pragma: no cover
|
||||
_websocket_mask = _websocket_mask_python
|
||||
|
||||
|
||||
def parse_frame(buf, continuation=False):
|
||||
"""Return the next frame from the socket."""
|
||||
# read header
|
||||
data = yield from buf.read(2)
|
||||
first_byte, second_byte = data
|
||||
|
||||
fin = (first_byte >> 7) & 1
|
||||
rsv1 = (first_byte >> 6) & 1
|
||||
rsv2 = (first_byte >> 5) & 1
|
||||
rsv3 = (first_byte >> 4) & 1
|
||||
opcode = first_byte & 0xf
|
||||
|
||||
# frame-fin = %x0 ; more frames of this message follow
|
||||
# / %x1 ; final frame of this message
|
||||
# frame-rsv1 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise
|
||||
# frame-rsv2 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise
|
||||
# frame-rsv3 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise
|
||||
if rsv1 or rsv2 or rsv3:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
'Received frame with non-zero reserved bits')
|
||||
|
||||
if opcode > 0x7 and fin == 0:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
'Received fragmented control frame')
|
||||
|
||||
if fin == 0 and opcode == WSMsgType.CONTINUATION and not continuation:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
'Received new fragment frame with non-zero '
|
||||
'opcode {!r}'.format(opcode))
|
||||
|
||||
has_mask = (second_byte >> 7) & 1
|
||||
length = (second_byte) & 0x7f
|
||||
|
||||
# Control frames MUST have a payload length of 125 bytes or less
|
||||
if opcode > 0x7 and length > 125:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
"Control frame payload cannot be larger than 125 bytes")
|
||||
|
||||
# read payload
|
||||
if length == 126:
|
||||
data = yield from buf.read(2)
|
||||
length = UNPACK_LEN2(data)[0]
|
||||
elif length > 126:
|
||||
data = yield from buf.read(8)
|
||||
length = UNPACK_LEN3(data)[0]
|
||||
|
||||
if has_mask:
|
||||
mask = yield from buf.read(4)
|
||||
|
||||
if length:
|
||||
payload = yield from buf.read(length)
|
||||
else:
|
||||
payload = bytearray()
|
||||
|
||||
if has_mask:
|
||||
payload = _websocket_mask(bytes(mask), payload)
|
||||
|
||||
return fin, opcode, payload
|
||||
|
||||
|
||||
class WebSocketWriter:
|
||||
|
||||
def __init__(self, writer, *, use_mask=False, random=random.Random()):
|
||||
self.writer = writer
|
||||
self.use_mask = use_mask
|
||||
self.randrange = random.randrange
|
||||
|
||||
def _send_frame(self, message, opcode):
|
||||
"""Send a frame over the websocket with message as its payload."""
|
||||
msg_length = len(message)
|
||||
|
||||
use_mask = self.use_mask
|
||||
if use_mask:
|
||||
mask_bit = 0x80
|
||||
else:
|
||||
mask_bit = 0
|
||||
|
||||
if msg_length < 126:
|
||||
header = PACK_LEN1(0x80 | opcode, msg_length | mask_bit)
|
||||
elif msg_length < (1 << 16):
|
||||
header = PACK_LEN2(0x80 | opcode, 126 | mask_bit, msg_length)
|
||||
else:
|
||||
header = PACK_LEN3(0x80 | opcode, 127 | mask_bit, msg_length)
|
||||
if use_mask:
|
||||
mask = self.randrange(0, 0xffffffff)
|
||||
mask = mask.to_bytes(4, 'big')
|
||||
message = _websocket_mask(mask, bytearray(message))
|
||||
self.writer.write(header + mask + message)
|
||||
else:
|
||||
if len(message) > MSG_SIZE:
|
||||
self.writer.write(header)
|
||||
self.writer.write(message)
|
||||
else:
|
||||
self.writer.write(header + message)
|
||||
|
||||
def pong(self, message=b''):
|
||||
"""Send pong message."""
|
||||
if isinstance(message, str):
|
||||
message = message.encode('utf-8')
|
||||
self._send_frame(message, WSMsgType.PONG)
|
||||
|
||||
def ping(self, message=b''):
|
||||
"""Send ping message."""
|
||||
if isinstance(message, str):
|
||||
message = message.encode('utf-8')
|
||||
self._send_frame(message, WSMsgType.PING)
|
||||
|
||||
def send(self, message, binary=False):
|
||||
"""Send a frame over the websocket with message as its payload."""
|
||||
if isinstance(message, str):
|
||||
message = message.encode('utf-8')
|
||||
if binary:
|
||||
self._send_frame(message, WSMsgType.BINARY)
|
||||
else:
|
||||
self._send_frame(message, WSMsgType.TEXT)
|
||||
|
||||
def close(self, code=1000, message=b''):
|
||||
"""Close the websocket, sending the specified code and message."""
|
||||
if isinstance(message, str):
|
||||
message = message.encode('utf-8')
|
||||
self._send_frame(
|
||||
PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE)
|
||||
|
||||
|
||||
def do_handshake(method, headers, transport, protocols=()):
|
||||
"""Prepare WebSocket handshake.
|
||||
|
||||
It return HTTP response code, response headers, websocket parser,
|
||||
websocket writer. It does not perform any IO.
|
||||
|
||||
`protocols` is a sequence of known protocols. On successful handshake,
|
||||
the returned response headers contain the first protocol in this list
|
||||
which the server also knows.
|
||||
|
||||
"""
|
||||
# WebSocket accepts only GET
|
||||
if method.upper() != hdrs.METH_GET:
|
||||
raise errors.HttpProcessingError(
|
||||
code=405, headers=((hdrs.ALLOW, hdrs.METH_GET),))
|
||||
|
||||
if 'websocket' != headers.get(hdrs.UPGRADE, '').lower().strip():
|
||||
raise errors.HttpBadRequest(
|
||||
message='No WebSocket UPGRADE hdr: {}\n Can '
|
||||
'"Upgrade" only to "WebSocket".'.format(headers.get(hdrs.UPGRADE)))
|
||||
|
||||
if 'upgrade' not in headers.get(hdrs.CONNECTION, '').lower():
|
||||
raise errors.HttpBadRequest(
|
||||
message='No CONNECTION upgrade hdr: {}'.format(
|
||||
headers.get(hdrs.CONNECTION)))
|
||||
|
||||
# find common sub-protocol between client and server
|
||||
protocol = None
|
||||
if hdrs.SEC_WEBSOCKET_PROTOCOL in headers:
|
||||
req_protocols = [str(proto.strip()) for proto in
|
||||
headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(',')]
|
||||
|
||||
for proto in req_protocols:
|
||||
if proto in protocols:
|
||||
protocol = proto
|
||||
break
|
||||
else:
|
||||
# No overlap found: Return no protocol as per spec
|
||||
ws_logger.warning(
|
||||
'Client protocols %r don’t overlap server-known ones %r',
|
||||
req_protocols, protocols)
|
||||
|
||||
# check supported version
|
||||
version = headers.get(hdrs.SEC_WEBSOCKET_VERSION, '')
|
||||
if version not in ('13', '8', '7'):
|
||||
raise errors.HttpBadRequest(
|
||||
message='Unsupported version: {}'.format(version),
|
||||
headers=((hdrs.SEC_WEBSOCKET_VERSION, '13'),))
|
||||
|
||||
# check client handshake for validity
|
||||
key = headers.get(hdrs.SEC_WEBSOCKET_KEY)
|
||||
try:
|
||||
if not key or len(base64.b64decode(key)) != 16:
|
||||
raise errors.HttpBadRequest(
|
||||
message='Handshake error: {!r}'.format(key))
|
||||
except binascii.Error:
|
||||
raise errors.HttpBadRequest(
|
||||
message='Handshake error: {!r}'.format(key)) from None
|
||||
|
||||
response_headers = [
|
||||
(hdrs.UPGRADE, 'websocket'),
|
||||
(hdrs.CONNECTION, 'upgrade'),
|
||||
(hdrs.TRANSFER_ENCODING, 'chunked'),
|
||||
(hdrs.SEC_WEBSOCKET_ACCEPT, base64.b64encode(
|
||||
hashlib.sha1(key.encode() + WS_KEY).digest()).decode())]
|
||||
|
||||
if protocol:
|
||||
response_headers.append((hdrs.SEC_WEBSOCKET_PROTOCOL, protocol))
|
||||
|
||||
# response code, headers, parser, writer, protocol
|
||||
return (101,
|
||||
response_headers,
|
||||
WebSocketParser,
|
||||
WebSocketWriter(transport),
|
||||
protocol)
|
||||
|
|
@ -0,0 +1,88 @@
|
|||
import asyncio
|
||||
import sys
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterable, Sized
|
||||
|
||||
PY_35 = sys.version_info >= (3, 5)
|
||||
|
||||
|
||||
class AbstractRouter(ABC):
|
||||
|
||||
@asyncio.coroutine # pragma: no branch
|
||||
@abstractmethod
|
||||
def resolve(self, request):
|
||||
"""Return MATCH_INFO for given request"""
|
||||
|
||||
|
||||
class AbstractMatchInfo(ABC):
|
||||
|
||||
@asyncio.coroutine # pragma: no branch
|
||||
@abstractmethod
|
||||
def handler(self, request):
|
||||
"""Execute matched request handler"""
|
||||
|
||||
@asyncio.coroutine # pragma: no branch
|
||||
@abstractmethod
|
||||
def expect_handler(self, request):
|
||||
"""Expect handler for 100-continue processing"""
|
||||
|
||||
@property # pragma: no branch
|
||||
@abstractmethod
|
||||
def http_exception(self):
|
||||
"""HTTPException instance raised on router's resolving, or None"""
|
||||
|
||||
@abstractmethod # pragma: no branch
|
||||
def get_info(self):
|
||||
"""Return a dict with additional info useful for introspection"""
|
||||
|
||||
|
||||
class AbstractView(ABC):
|
||||
|
||||
def __init__(self, request):
|
||||
self._request = request
|
||||
|
||||
@property
|
||||
def request(self):
|
||||
return self._request
|
||||
|
||||
@asyncio.coroutine # pragma: no branch
|
||||
@abstractmethod
|
||||
def __iter__(self):
|
||||
while False: # pragma: no cover
|
||||
yield None
|
||||
|
||||
if PY_35: # pragma: no branch
|
||||
@abstractmethod
|
||||
def __await__(self):
|
||||
return # pragma: no cover
|
||||
|
||||
|
||||
class AbstractResolver(ABC):
|
||||
|
||||
@asyncio.coroutine # pragma: no branch
|
||||
@abstractmethod
|
||||
def resolve(self, hostname):
|
||||
"""Return IP address for given hostname"""
|
||||
|
||||
@asyncio.coroutine # pragma: no branch
|
||||
@abstractmethod
|
||||
def close(self):
|
||||
"""Release resolver"""
|
||||
|
||||
|
||||
class AbstractCookieJar(Sized, Iterable):
|
||||
|
||||
def __init__(self, *, loop=None):
|
||||
self._loop = loop or asyncio.get_event_loop()
|
||||
|
||||
@abstractmethod
|
||||
def clear(self):
|
||||
"""Clear all cookies."""
|
||||
|
||||
@abstractmethod
|
||||
def update_cookies(self, cookies, response_url=None):
|
||||
"""Update cookies."""
|
||||
|
||||
@abstractmethod
|
||||
def filter_cookies(self, request_url):
|
||||
"""Return the jar's cookies filtered by their attributes."""
|
||||
|
|
@ -0,0 +1,786 @@
|
|||
"""HTTP Client for asyncio."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import urllib.parse
|
||||
import warnings
|
||||
|
||||
from multidict import CIMultiDict, MultiDict, MultiDictProxy, istr
|
||||
|
||||
import aiohttp
|
||||
|
||||
from . import hdrs, helpers
|
||||
from ._ws_impl import WS_KEY, WebSocketParser, WebSocketWriter
|
||||
from .client_reqrep import ClientRequest, ClientResponse
|
||||
from .client_ws import ClientWebSocketResponse
|
||||
from .cookiejar import CookieJar
|
||||
from .errors import WSServerHandshakeError
|
||||
from .helpers import Timeout
|
||||
|
||||
__all__ = ('ClientSession', 'request', 'get', 'options', 'head',
|
||||
'delete', 'post', 'put', 'patch', 'ws_connect')
|
||||
|
||||
PY_35 = sys.version_info >= (3, 5)
|
||||
|
||||
|
||||
class ClientSession:
|
||||
"""First-class interface for making HTTP requests."""
|
||||
|
||||
_source_traceback = None
|
||||
_connector = None
|
||||
|
||||
def __init__(self, *, connector=None, loop=None, cookies=None,
|
||||
headers=None, skip_auto_headers=None,
|
||||
auth=None, request_class=ClientRequest,
|
||||
response_class=ClientResponse,
|
||||
ws_response_class=ClientWebSocketResponse,
|
||||
version=aiohttp.HttpVersion11,
|
||||
cookie_jar=None):
|
||||
|
||||
if connector is None:
|
||||
connector = aiohttp.TCPConnector(loop=loop)
|
||||
loop = connector._loop # never None
|
||||
else:
|
||||
if loop is None:
|
||||
loop = connector._loop # never None
|
||||
elif connector._loop is not loop:
|
||||
raise ValueError("loop argument must agree with connector")
|
||||
|
||||
self._loop = loop
|
||||
if loop.get_debug():
|
||||
self._source_traceback = traceback.extract_stack(sys._getframe(1))
|
||||
|
||||
if cookie_jar is None:
|
||||
cookie_jar = CookieJar(loop=loop)
|
||||
self._cookie_jar = cookie_jar
|
||||
|
||||
if cookies is not None:
|
||||
self._cookie_jar.update_cookies(cookies)
|
||||
self._connector = connector
|
||||
self._default_auth = auth
|
||||
self._version = version
|
||||
|
||||
# Convert to list of tuples
|
||||
if headers:
|
||||
headers = CIMultiDict(headers)
|
||||
else:
|
||||
headers = CIMultiDict()
|
||||
self._default_headers = headers
|
||||
if skip_auto_headers is not None:
|
||||
self._skip_auto_headers = frozenset([istr(i)
|
||||
for i in skip_auto_headers])
|
||||
else:
|
||||
self._skip_auto_headers = frozenset()
|
||||
|
||||
self._request_class = request_class
|
||||
self._response_class = response_class
|
||||
self._ws_response_class = ws_response_class
|
||||
|
||||
def __del__(self, _warnings=warnings):
|
||||
if not self.closed:
|
||||
self.close()
|
||||
|
||||
_warnings.warn("Unclosed client session {!r}".format(self),
|
||||
ResourceWarning)
|
||||
context = {'client_session': self,
|
||||
'message': 'Unclosed client session'}
|
||||
if self._source_traceback is not None:
|
||||
context['source_traceback'] = self._source_traceback
|
||||
self._loop.call_exception_handler(context)
|
||||
|
||||
def request(self, method, url, *,
|
||||
params=None,
|
||||
data=None,
|
||||
headers=None,
|
||||
skip_auto_headers=None,
|
||||
auth=None,
|
||||
allow_redirects=True,
|
||||
max_redirects=10,
|
||||
encoding='utf-8',
|
||||
version=None,
|
||||
compress=None,
|
||||
chunked=None,
|
||||
expect100=False,
|
||||
read_until_eof=True,
|
||||
proxy=None,
|
||||
proxy_auth=None,
|
||||
timeout=5*60):
|
||||
"""Perform HTTP request."""
|
||||
|
||||
return _RequestContextManager(
|
||||
self._request(
|
||||
method,
|
||||
url,
|
||||
params=params,
|
||||
data=data,
|
||||
headers=headers,
|
||||
skip_auto_headers=skip_auto_headers,
|
||||
auth=auth,
|
||||
allow_redirects=allow_redirects,
|
||||
max_redirects=max_redirects,
|
||||
encoding=encoding,
|
||||
version=version,
|
||||
compress=compress,
|
||||
chunked=chunked,
|
||||
expect100=expect100,
|
||||
read_until_eof=read_until_eof,
|
||||
proxy=proxy,
|
||||
proxy_auth=proxy_auth,
|
||||
timeout=timeout))
|
||||
|
||||
@asyncio.coroutine
|
||||
def _request(self, method, url, *,
|
||||
params=None,
|
||||
data=None,
|
||||
headers=None,
|
||||
skip_auto_headers=None,
|
||||
auth=None,
|
||||
allow_redirects=True,
|
||||
max_redirects=10,
|
||||
encoding='utf-8',
|
||||
version=None,
|
||||
compress=None,
|
||||
chunked=None,
|
||||
expect100=False,
|
||||
read_until_eof=True,
|
||||
proxy=None,
|
||||
proxy_auth=None,
|
||||
timeout=5*60):
|
||||
|
||||
if version is not None:
|
||||
warnings.warn("HTTP version should be specified "
|
||||
"by ClientSession constructor", DeprecationWarning)
|
||||
else:
|
||||
version = self._version
|
||||
|
||||
if self.closed:
|
||||
raise RuntimeError('Session is closed')
|
||||
|
||||
redirects = 0
|
||||
history = []
|
||||
|
||||
# Merge with default headers and transform to CIMultiDict
|
||||
headers = self._prepare_headers(headers)
|
||||
if auth is None:
|
||||
auth = self._default_auth
|
||||
# It would be confusing if we support explicit Authorization header
|
||||
# with `auth` argument
|
||||
if (headers is not None and
|
||||
auth is not None and
|
||||
hdrs.AUTHORIZATION in headers):
|
||||
raise ValueError("Can't combine `Authorization` header with "
|
||||
"`auth` argument")
|
||||
|
||||
skip_headers = set(self._skip_auto_headers)
|
||||
if skip_auto_headers is not None:
|
||||
for i in skip_auto_headers:
|
||||
skip_headers.add(istr(i))
|
||||
|
||||
while True:
|
||||
url, _ = urllib.parse.urldefrag(url)
|
||||
|
||||
cookies = self._cookie_jar.filter_cookies(url)
|
||||
|
||||
req = self._request_class(
|
||||
method, url, params=params, headers=headers,
|
||||
skip_auto_headers=skip_headers, data=data,
|
||||
cookies=cookies, encoding=encoding,
|
||||
auth=auth, version=version, compress=compress, chunked=chunked,
|
||||
expect100=expect100,
|
||||
loop=self._loop, response_class=self._response_class,
|
||||
proxy=proxy, proxy_auth=proxy_auth, timeout=timeout)
|
||||
|
||||
with Timeout(timeout, loop=self._loop):
|
||||
conn = yield from self._connector.connect(req)
|
||||
try:
|
||||
resp = req.send(conn.writer, conn.reader)
|
||||
try:
|
||||
yield from resp.start(conn, read_until_eof)
|
||||
except:
|
||||
resp.close()
|
||||
conn.close()
|
||||
raise
|
||||
except (aiohttp.HttpProcessingError,
|
||||
aiohttp.ServerDisconnectedError) as exc:
|
||||
raise aiohttp.ClientResponseError() from exc
|
||||
except OSError as exc:
|
||||
raise aiohttp.ClientOSError(*exc.args) from exc
|
||||
|
||||
self._cookie_jar.update_cookies(resp.cookies, resp.url)
|
||||
|
||||
# redirects
|
||||
if resp.status in (301, 302, 303, 307) and allow_redirects:
|
||||
redirects += 1
|
||||
history.append(resp)
|
||||
if max_redirects and redirects >= max_redirects:
|
||||
resp.close()
|
||||
break
|
||||
else:
|
||||
# TODO: close the connection if BODY is large enough
|
||||
# Redirect with big BODY is forbidden by HTTP protocol
|
||||
# but malformed server may send illegal response.
|
||||
# Small BODIES with text like "Not Found" are still
|
||||
# perfectly fine and should be accepted.
|
||||
yield from resp.release()
|
||||
|
||||
# For 301 and 302, mimic IE behaviour, now changed in RFC.
|
||||
# Details: https://github.com/kennethreitz/requests/pull/269
|
||||
if (resp.status == 303 and resp.method != hdrs.METH_HEAD) \
|
||||
or (resp.status in (301, 302) and
|
||||
resp.method == hdrs.METH_POST):
|
||||
method = hdrs.METH_GET
|
||||
data = None
|
||||
if headers.get(hdrs.CONTENT_LENGTH):
|
||||
headers.pop(hdrs.CONTENT_LENGTH)
|
||||
|
||||
r_url = (resp.headers.get(hdrs.LOCATION) or
|
||||
resp.headers.get(hdrs.URI))
|
||||
|
||||
scheme = urllib.parse.urlsplit(r_url)[0]
|
||||
if scheme not in ('http', 'https', ''):
|
||||
resp.close()
|
||||
raise ValueError('Can redirect only to http or https')
|
||||
elif not scheme:
|
||||
r_url = urllib.parse.urljoin(url, r_url)
|
||||
|
||||
url = r_url
|
||||
params = None
|
||||
yield from resp.release()
|
||||
continue
|
||||
|
||||
break
|
||||
|
||||
resp._history = tuple(history)
|
||||
return resp
|
||||
|
||||
def ws_connect(self, url, *,
|
||||
protocols=(),
|
||||
timeout=10.0,
|
||||
autoclose=True,
|
||||
autoping=True,
|
||||
auth=None,
|
||||
origin=None,
|
||||
headers=None,
|
||||
proxy=None,
|
||||
proxy_auth=None):
|
||||
"""Initiate websocket connection."""
|
||||
return _WSRequestContextManager(
|
||||
self._ws_connect(url,
|
||||
protocols=protocols,
|
||||
timeout=timeout,
|
||||
autoclose=autoclose,
|
||||
autoping=autoping,
|
||||
auth=auth,
|
||||
origin=origin,
|
||||
headers=headers,
|
||||
proxy=proxy,
|
||||
proxy_auth=proxy_auth))
|
||||
|
||||
@asyncio.coroutine
|
||||
def _ws_connect(self, url, *,
|
||||
protocols=(),
|
||||
timeout=10.0,
|
||||
autoclose=True,
|
||||
autoping=True,
|
||||
auth=None,
|
||||
origin=None,
|
||||
headers=None,
|
||||
proxy=None,
|
||||
proxy_auth=None):
|
||||
|
||||
sec_key = base64.b64encode(os.urandom(16))
|
||||
|
||||
if headers is None:
|
||||
headers = CIMultiDict()
|
||||
|
||||
default_headers = {
|
||||
hdrs.UPGRADE: hdrs.WEBSOCKET,
|
||||
hdrs.CONNECTION: hdrs.UPGRADE,
|
||||
hdrs.SEC_WEBSOCKET_VERSION: '13',
|
||||
hdrs.SEC_WEBSOCKET_KEY: sec_key.decode(),
|
||||
}
|
||||
|
||||
for key, value in default_headers.items():
|
||||
if key not in headers:
|
||||
headers[key] = value
|
||||
|
||||
if protocols:
|
||||
headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = ','.join(protocols)
|
||||
if origin is not None:
|
||||
headers[hdrs.ORIGIN] = origin
|
||||
|
||||
# send request
|
||||
resp = yield from self.get(url, headers=headers,
|
||||
read_until_eof=False,
|
||||
auth=auth,
|
||||
proxy=proxy,
|
||||
proxy_auth=proxy_auth)
|
||||
|
||||
try:
|
||||
# check handshake
|
||||
if resp.status != 101:
|
||||
raise WSServerHandshakeError(
|
||||
message='Invalid response status',
|
||||
code=resp.status,
|
||||
headers=resp.headers)
|
||||
|
||||
if resp.headers.get(hdrs.UPGRADE, '').lower() != 'websocket':
|
||||
raise WSServerHandshakeError(
|
||||
message='Invalid upgrade header',
|
||||
code=resp.status,
|
||||
headers=resp.headers)
|
||||
|
||||
if resp.headers.get(hdrs.CONNECTION, '').lower() != 'upgrade':
|
||||
raise WSServerHandshakeError(
|
||||
message='Invalid connection header',
|
||||
code=resp.status,
|
||||
headers=resp.headers)
|
||||
|
||||
# key calculation
|
||||
key = resp.headers.get(hdrs.SEC_WEBSOCKET_ACCEPT, '')
|
||||
match = base64.b64encode(
|
||||
hashlib.sha1(sec_key + WS_KEY).digest()).decode()
|
||||
if key != match:
|
||||
raise WSServerHandshakeError(
|
||||
message='Invalid challenge response',
|
||||
code=resp.status,
|
||||
headers=resp.headers)
|
||||
|
||||
# websocket protocol
|
||||
protocol = None
|
||||
if protocols and hdrs.SEC_WEBSOCKET_PROTOCOL in resp.headers:
|
||||
resp_protocols = [
|
||||
proto.strip() for proto in
|
||||
resp.headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(',')]
|
||||
|
||||
for proto in resp_protocols:
|
||||
if proto in protocols:
|
||||
protocol = proto
|
||||
break
|
||||
|
||||
reader = resp.connection.reader.set_parser(WebSocketParser)
|
||||
resp.connection.writer.set_tcp_nodelay(True)
|
||||
writer = WebSocketWriter(resp.connection.writer, use_mask=True)
|
||||
except Exception:
|
||||
resp.close()
|
||||
raise
|
||||
else:
|
||||
return self._ws_response_class(reader,
|
||||
writer,
|
||||
protocol,
|
||||
resp,
|
||||
timeout,
|
||||
autoclose,
|
||||
autoping,
|
||||
self._loop)
|
||||
|
||||
def _prepare_headers(self, headers):
|
||||
""" Add default headers and transform it to CIMultiDict
|
||||
"""
|
||||
# Convert headers to MultiDict
|
||||
result = CIMultiDict(self._default_headers)
|
||||
if headers:
|
||||
if not isinstance(headers, (MultiDictProxy, MultiDict)):
|
||||
headers = CIMultiDict(headers)
|
||||
added_names = set()
|
||||
for key, value in headers.items():
|
||||
if key in added_names:
|
||||
result.add(key, value)
|
||||
else:
|
||||
result[key] = value
|
||||
added_names.add(key)
|
||||
return result
|
||||
|
||||
def get(self, url, *, allow_redirects=True, **kwargs):
|
||||
"""Perform HTTP GET request."""
|
||||
return _RequestContextManager(
|
||||
self._request(hdrs.METH_GET, url,
|
||||
allow_redirects=allow_redirects,
|
||||
**kwargs))
|
||||
|
||||
def options(self, url, *, allow_redirects=True, **kwargs):
|
||||
"""Perform HTTP OPTIONS request."""
|
||||
return _RequestContextManager(
|
||||
self._request(hdrs.METH_OPTIONS, url,
|
||||
allow_redirects=allow_redirects,
|
||||
**kwargs))
|
||||
|
||||
def head(self, url, *, allow_redirects=False, **kwargs):
|
||||
"""Perform HTTP HEAD request."""
|
||||
return _RequestContextManager(
|
||||
self._request(hdrs.METH_HEAD, url,
|
||||
allow_redirects=allow_redirects,
|
||||
**kwargs))
|
||||
|
||||
def post(self, url, *, data=None, **kwargs):
|
||||
"""Perform HTTP POST request."""
|
||||
return _RequestContextManager(
|
||||
self._request(hdrs.METH_POST, url,
|
||||
data=data,
|
||||
**kwargs))
|
||||
|
||||
def put(self, url, *, data=None, **kwargs):
|
||||
"""Perform HTTP PUT request."""
|
||||
return _RequestContextManager(
|
||||
self._request(hdrs.METH_PUT, url,
|
||||
data=data,
|
||||
**kwargs))
|
||||
|
||||
def patch(self, url, *, data=None, **kwargs):
|
||||
"""Perform HTTP PATCH request."""
|
||||
return _RequestContextManager(
|
||||
self._request(hdrs.METH_PATCH, url,
|
||||
data=data,
|
||||
**kwargs))
|
||||
|
||||
def delete(self, url, **kwargs):
|
||||
"""Perform HTTP DELETE request."""
|
||||
return _RequestContextManager(
|
||||
self._request(hdrs.METH_DELETE, url,
|
||||
**kwargs))
|
||||
|
||||
def close(self):
|
||||
"""Close underlying connector.
|
||||
|
||||
Release all acquired resources.
|
||||
"""
|
||||
if not self.closed:
|
||||
self._connector.close()
|
||||
self._connector = None
|
||||
ret = helpers.create_future(self._loop)
|
||||
ret.set_result(None)
|
||||
return ret
|
||||
|
||||
@property
|
||||
def closed(self):
|
||||
"""Is client session closed.
|
||||
|
||||
A readonly property.
|
||||
"""
|
||||
return self._connector is None or self._connector.closed
|
||||
|
||||
@property
|
||||
def connector(self):
|
||||
"""Connector instance used for the session."""
|
||||
return self._connector
|
||||
|
||||
@property
|
||||
def cookie_jar(self):
|
||||
"""The session cookies."""
|
||||
return self._cookie_jar
|
||||
|
||||
@property
|
||||
def version(self):
|
||||
"""The session HTTP protocol version."""
|
||||
return self._version
|
||||
|
||||
@property
|
||||
def loop(self):
|
||||
"""Session's loop."""
|
||||
return self._loop
|
||||
|
||||
def detach(self):
|
||||
"""Detach connector from session without closing the former.
|
||||
|
||||
Session is switched to closed state anyway.
|
||||
"""
|
||||
self._connector = None
|
||||
|
||||
def __enter__(self):
|
||||
warnings.warn("Use async with instead", DeprecationWarning)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
if PY_35:
|
||||
@asyncio.coroutine
|
||||
def __aenter__(self):
|
||||
return self
|
||||
|
||||
@asyncio.coroutine
|
||||
def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
yield from self.close()
|
||||
|
||||
if PY_35:
|
||||
from collections.abc import Coroutine
|
||||
base = Coroutine
|
||||
else:
|
||||
base = object
|
||||
|
||||
|
||||
class _BaseRequestContextManager(base):
|
||||
|
||||
__slots__ = ('_coro', '_resp')
|
||||
|
||||
def __init__(self, coro):
|
||||
self._coro = coro
|
||||
self._resp = None
|
||||
|
||||
def send(self, value):
|
||||
return self._coro.send(value)
|
||||
|
||||
def throw(self, typ, val=None, tb=None):
|
||||
if val is None:
|
||||
return self._coro.throw(typ)
|
||||
elif tb is None:
|
||||
return self._coro.throw(typ, val)
|
||||
else:
|
||||
return self._coro.throw(typ, val, tb)
|
||||
|
||||
def close(self):
|
||||
return self._coro.close()
|
||||
|
||||
@property
|
||||
def gi_frame(self):
|
||||
return self._coro.gi_frame
|
||||
|
||||
@property
|
||||
def gi_running(self):
|
||||
return self._coro.gi_running
|
||||
|
||||
@property
|
||||
def gi_code(self):
|
||||
return self._coro.gi_code
|
||||
|
||||
def __next__(self):
|
||||
return self.send(None)
|
||||
|
||||
@asyncio.coroutine
|
||||
def __iter__(self):
|
||||
resp = yield from self._coro
|
||||
return resp
|
||||
|
||||
if PY_35:
|
||||
def __await__(self):
|
||||
resp = yield from self._coro
|
||||
return resp
|
||||
|
||||
@asyncio.coroutine
|
||||
def __aenter__(self):
|
||||
self._resp = yield from self._coro
|
||||
return self._resp
|
||||
|
||||
|
||||
if not PY_35:
|
||||
try:
|
||||
from asyncio import coroutines
|
||||
coroutines._COROUTINE_TYPES += (_BaseRequestContextManager,)
|
||||
except: # pragma: no cover
|
||||
pass # Python 3.4.2 and 3.4.3 has no coroutines._COROUTINE_TYPES
|
||||
|
||||
|
||||
class _RequestContextManager(_BaseRequestContextManager):
|
||||
if PY_35:
|
||||
@asyncio.coroutine
|
||||
def __aexit__(self, exc_type, exc, tb):
|
||||
if exc_type is not None:
|
||||
self._resp.close()
|
||||
else:
|
||||
yield from self._resp.release()
|
||||
|
||||
|
||||
class _WSRequestContextManager(_BaseRequestContextManager):
|
||||
if PY_35:
|
||||
@asyncio.coroutine
|
||||
def __aexit__(self, exc_type, exc, tb):
|
||||
yield from self._resp.close()
|
||||
|
||||
|
||||
class _DetachedRequestContextManager(_RequestContextManager):
|
||||
|
||||
__slots__ = _RequestContextManager.__slots__ + ('_session', )
|
||||
|
||||
def __init__(self, coro, session):
|
||||
super().__init__(coro)
|
||||
self._session = session
|
||||
|
||||
@asyncio.coroutine
|
||||
def __iter__(self):
|
||||
try:
|
||||
return (yield from self._coro)
|
||||
except:
|
||||
yield from self._session.close()
|
||||
raise
|
||||
|
||||
if PY_35:
|
||||
def __await__(self):
|
||||
try:
|
||||
return (yield from self._coro)
|
||||
except:
|
||||
yield from self._session.close()
|
||||
raise
|
||||
|
||||
def __del__(self):
|
||||
self._session.detach()
|
||||
|
||||
|
||||
class _DetachedWSRequestContextManager(_WSRequestContextManager):
|
||||
|
||||
__slots__ = _WSRequestContextManager.__slots__ + ('_session', )
|
||||
|
||||
def __init__(self, coro, session):
|
||||
super().__init__(coro)
|
||||
self._session = session
|
||||
|
||||
def __del__(self):
|
||||
self._session.detach()
|
||||
|
||||
|
||||
def request(method, url, *,
|
||||
params=None,
|
||||
data=None,
|
||||
headers=None,
|
||||
skip_auto_headers=None,
|
||||
cookies=None,
|
||||
auth=None,
|
||||
allow_redirects=True,
|
||||
max_redirects=10,
|
||||
encoding='utf-8',
|
||||
version=None,
|
||||
compress=None,
|
||||
chunked=None,
|
||||
expect100=False,
|
||||
connector=None,
|
||||
loop=None,
|
||||
read_until_eof=True,
|
||||
request_class=None,
|
||||
response_class=None,
|
||||
proxy=None,
|
||||
proxy_auth=None):
|
||||
"""Constructs and sends a request. Returns response object.
|
||||
|
||||
method - HTTP method
|
||||
url - request url
|
||||
params - (optional) Dictionary or bytes to be sent in the query
|
||||
string of the new request
|
||||
data - (optional) Dictionary, bytes, or file-like object to
|
||||
send in the body of the request
|
||||
headers - (optional) Dictionary of HTTP Headers to send with
|
||||
the request
|
||||
cookies - (optional) Dict object to send with the request
|
||||
auth - (optional) BasicAuth named tuple represent HTTP Basic Auth
|
||||
auth - aiohttp.helpers.BasicAuth
|
||||
allow_redirects - (optional) If set to False, do not follow
|
||||
redirects
|
||||
version - Request HTTP version.
|
||||
compress - Set to True if request has to be compressed
|
||||
with deflate encoding.
|
||||
chunked - Set to chunk size for chunked transfer encoding.
|
||||
expect100 - Expect 100-continue response from server.
|
||||
connector - BaseConnector sub-class instance to support
|
||||
connection pooling.
|
||||
read_until_eof - Read response until eof if response
|
||||
does not have Content-Length header.
|
||||
request_class - (optional) Custom Request class implementation.
|
||||
response_class - (optional) Custom Response class implementation.
|
||||
loop - Optional event loop.
|
||||
|
||||
Usage::
|
||||
|
||||
>>> import aiohttp
|
||||
>>> resp = yield from aiohttp.request('GET', 'http://python.org/')
|
||||
>>> resp
|
||||
<ClientResponse(python.org/) [200]>
|
||||
>>> data = yield from resp.read()
|
||||
|
||||
"""
|
||||
warnings.warn("Use ClientSession().request() instead", DeprecationWarning)
|
||||
if connector is None:
|
||||
connector = aiohttp.TCPConnector(loop=loop, force_close=True)
|
||||
|
||||
kwargs = {}
|
||||
|
||||
if request_class is not None:
|
||||
kwargs['request_class'] = request_class
|
||||
|
||||
if response_class is not None:
|
||||
kwargs['response_class'] = response_class
|
||||
|
||||
session = ClientSession(loop=loop,
|
||||
cookies=cookies,
|
||||
connector=connector,
|
||||
**kwargs)
|
||||
return _DetachedRequestContextManager(
|
||||
session._request(method, url,
|
||||
params=params,
|
||||
data=data,
|
||||
headers=headers,
|
||||
skip_auto_headers=skip_auto_headers,
|
||||
auth=auth,
|
||||
allow_redirects=allow_redirects,
|
||||
max_redirects=max_redirects,
|
||||
encoding=encoding,
|
||||
version=version,
|
||||
compress=compress,
|
||||
chunked=chunked,
|
||||
expect100=expect100,
|
||||
read_until_eof=read_until_eof,
|
||||
proxy=proxy,
|
||||
proxy_auth=proxy_auth,),
|
||||
session=session)
|
||||
|
||||
|
||||
def get(url, **kwargs):
|
||||
warnings.warn("Use ClientSession().get() instead", DeprecationWarning)
|
||||
return request(hdrs.METH_GET, url, **kwargs)
|
||||
|
||||
|
||||
def options(url, **kwargs):
|
||||
warnings.warn("Use ClientSession().options() instead", DeprecationWarning)
|
||||
return request(hdrs.METH_OPTIONS, url, **kwargs)
|
||||
|
||||
|
||||
def head(url, **kwargs):
|
||||
warnings.warn("Use ClientSession().head() instead", DeprecationWarning)
|
||||
return request(hdrs.METH_HEAD, url, **kwargs)
|
||||
|
||||
|
||||
def post(url, **kwargs):
|
||||
warnings.warn("Use ClientSession().post() instead", DeprecationWarning)
|
||||
return request(hdrs.METH_POST, url, **kwargs)
|
||||
|
||||
|
||||
def put(url, **kwargs):
|
||||
warnings.warn("Use ClientSession().put() instead", DeprecationWarning)
|
||||
return request(hdrs.METH_PUT, url, **kwargs)
|
||||
|
||||
|
||||
def patch(url, **kwargs):
|
||||
warnings.warn("Use ClientSession().patch() instead", DeprecationWarning)
|
||||
return request(hdrs.METH_PATCH, url, **kwargs)
|
||||
|
||||
|
||||
def delete(url, **kwargs):
|
||||
warnings.warn("Use ClientSession().delete() instead", DeprecationWarning)
|
||||
return request(hdrs.METH_DELETE, url, **kwargs)
|
||||
|
||||
|
||||
def ws_connect(url, *, protocols=(), timeout=10.0, connector=None, auth=None,
|
||||
ws_response_class=ClientWebSocketResponse, autoclose=True,
|
||||
autoping=True, loop=None, origin=None, headers=None):
|
||||
|
||||
warnings.warn("Use ClientSession().ws_connect() instead",
|
||||
DeprecationWarning)
|
||||
if loop is None:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
if connector is None:
|
||||
connector = aiohttp.TCPConnector(loop=loop, force_close=True)
|
||||
|
||||
session = aiohttp.ClientSession(loop=loop, connector=connector, auth=auth,
|
||||
ws_response_class=ws_response_class,
|
||||
headers=headers)
|
||||
|
||||
return _DetachedWSRequestContextManager(
|
||||
session._ws_connect(url,
|
||||
protocols=protocols,
|
||||
timeout=timeout,
|
||||
autoclose=autoclose,
|
||||
autoping=autoping,
|
||||
origin=origin),
|
||||
session=session)
|
||||
|
|
@ -0,0 +1,801 @@
|
|||
import asyncio
|
||||
import collections
|
||||
import http.cookies
|
||||
import io
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import urllib.parse
|
||||
import warnings
|
||||
|
||||
from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy
|
||||
|
||||
import aiohttp
|
||||
|
||||
from . import hdrs, helpers, streams
|
||||
from .helpers import Timeout
|
||||
from .log import client_logger
|
||||
from .multipart import MultipartWriter
|
||||
from .protocol import HttpMessage
|
||||
from .streams import EOF_MARKER, FlowControlStreamReader
|
||||
|
||||
try:
|
||||
import cchardet as chardet
|
||||
except ImportError:
|
||||
import chardet
|
||||
|
||||
|
||||
__all__ = ('ClientRequest', 'ClientResponse')
|
||||
|
||||
PY_35 = sys.version_info >= (3, 5)
|
||||
|
||||
HTTP_PORT = 80
|
||||
HTTPS_PORT = 443
|
||||
|
||||
|
||||
class ClientRequest:
|
||||
|
||||
GET_METHODS = {hdrs.METH_GET, hdrs.METH_HEAD, hdrs.METH_OPTIONS}
|
||||
POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT}
|
||||
ALL_METHODS = GET_METHODS.union(POST_METHODS).union(
|
||||
{hdrs.METH_DELETE, hdrs.METH_TRACE})
|
||||
|
||||
DEFAULT_HEADERS = {
|
||||
hdrs.ACCEPT: '*/*',
|
||||
hdrs.ACCEPT_ENCODING: 'gzip, deflate',
|
||||
}
|
||||
|
||||
SERVER_SOFTWARE = HttpMessage.SERVER_SOFTWARE
|
||||
|
||||
body = b''
|
||||
auth = None
|
||||
response = None
|
||||
response_class = None
|
||||
|
||||
_writer = None # async task for streaming data
|
||||
_continue = None # waiter future for '100 Continue' response
|
||||
|
||||
# N.B.
|
||||
# Adding __del__ method with self._writer closing doesn't make sense
|
||||
# because _writer is instance method, thus it keeps a reference to self.
|
||||
# Until writer has finished finalizer will not be called.
|
||||
|
||||
def __init__(self, method, url, *,
|
||||
params=None, headers=None, skip_auto_headers=frozenset(),
|
||||
data=None, cookies=None,
|
||||
auth=None, encoding='utf-8',
|
||||
version=aiohttp.HttpVersion11, compress=None,
|
||||
chunked=None, expect100=False,
|
||||
loop=None, response_class=None,
|
||||
proxy=None, proxy_auth=None,
|
||||
timeout=5*60):
|
||||
|
||||
if loop is None:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
self.url = url
|
||||
self.method = method.upper()
|
||||
self.encoding = encoding
|
||||
self.chunked = chunked
|
||||
self.compress = compress
|
||||
self.loop = loop
|
||||
self.response_class = response_class or ClientResponse
|
||||
self._timeout = timeout
|
||||
|
||||
if loop.get_debug():
|
||||
self._source_traceback = traceback.extract_stack(sys._getframe(1))
|
||||
|
||||
self.update_version(version)
|
||||
self.update_host(url)
|
||||
self.update_path(params)
|
||||
self.update_headers(headers)
|
||||
self.update_auto_headers(skip_auto_headers)
|
||||
self.update_cookies(cookies)
|
||||
self.update_content_encoding(data)
|
||||
self.update_auth(auth)
|
||||
self.update_proxy(proxy, proxy_auth)
|
||||
|
||||
self.update_body_from_data(data, skip_auto_headers)
|
||||
self.update_transfer_encoding()
|
||||
self.update_expect_continue(expect100)
|
||||
|
||||
def update_host(self, url):
|
||||
"""Update destination host, port and connection type (ssl)."""
|
||||
url_parsed = urllib.parse.urlsplit(url)
|
||||
|
||||
# check for network location part
|
||||
netloc = url_parsed.netloc
|
||||
if not netloc:
|
||||
raise ValueError('Host could not be detected.')
|
||||
|
||||
# get host/port
|
||||
host = url_parsed.hostname
|
||||
if not host:
|
||||
raise ValueError('Host could not be detected.')
|
||||
|
||||
try:
|
||||
port = url_parsed.port
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
'Port number could not be converted.') from None
|
||||
|
||||
# check domain idna encoding
|
||||
try:
|
||||
host = host.encode('idna').decode('utf-8')
|
||||
netloc = self.make_netloc(host, url_parsed.port)
|
||||
except UnicodeError:
|
||||
raise ValueError('URL has an invalid label.')
|
||||
|
||||
# basic auth info
|
||||
username, password = url_parsed.username, url_parsed.password
|
||||
if username:
|
||||
self.auth = helpers.BasicAuth(username, password or '')
|
||||
|
||||
# Record entire netloc for usage in host header
|
||||
self.netloc = netloc
|
||||
|
||||
scheme = url_parsed.scheme
|
||||
self.ssl = scheme in ('https', 'wss')
|
||||
|
||||
# set port number if it isn't already set
|
||||
if not port:
|
||||
if self.ssl:
|
||||
port = HTTPS_PORT
|
||||
else:
|
||||
port = HTTP_PORT
|
||||
|
||||
self.host, self.port, self.scheme = host, port, scheme
|
||||
|
||||
def make_netloc(self, host, port):
|
||||
ret = host
|
||||
if port:
|
||||
ret = ret + ':' + str(port)
|
||||
return ret
|
||||
|
||||
def update_version(self, version):
|
||||
"""Convert request version to two elements tuple.
|
||||
|
||||
parser HTTP version '1.1' => (1, 1)
|
||||
"""
|
||||
if isinstance(version, str):
|
||||
v = [l.strip() for l in version.split('.', 1)]
|
||||
try:
|
||||
version = int(v[0]), int(v[1])
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
'Can not parse http version number: {}'
|
||||
.format(version)) from None
|
||||
self.version = version
|
||||
|
||||
def update_path(self, params):
|
||||
"""Build path."""
|
||||
# extract path
|
||||
scheme, netloc, path, query, fragment = urllib.parse.urlsplit(self.url)
|
||||
if not path:
|
||||
path = '/'
|
||||
|
||||
if isinstance(params, collections.Mapping):
|
||||
params = list(params.items())
|
||||
|
||||
if params:
|
||||
if not isinstance(params, str):
|
||||
params = urllib.parse.urlencode(params)
|
||||
if query:
|
||||
query = '%s&%s' % (query, params)
|
||||
else:
|
||||
query = params
|
||||
|
||||
self.path = urllib.parse.urlunsplit(('', '', helpers.requote_uri(path),
|
||||
query, ''))
|
||||
self.url = urllib.parse.urlunsplit(
|
||||
(scheme, netloc, self.path, '', fragment))
|
||||
|
||||
def update_headers(self, headers):
|
||||
"""Update request headers."""
|
||||
self.headers = CIMultiDict()
|
||||
if headers:
|
||||
if isinstance(headers, dict):
|
||||
headers = headers.items()
|
||||
elif isinstance(headers, (MultiDictProxy, MultiDict)):
|
||||
headers = headers.items()
|
||||
|
||||
for key, value in headers:
|
||||
self.headers.add(key, value)
|
||||
|
||||
def update_auto_headers(self, skip_auto_headers):
|
||||
self.skip_auto_headers = skip_auto_headers
|
||||
used_headers = set(self.headers) | skip_auto_headers
|
||||
|
||||
for hdr, val in self.DEFAULT_HEADERS.items():
|
||||
if hdr not in used_headers:
|
||||
self.headers.add(hdr, val)
|
||||
|
||||
# add host
|
||||
if hdrs.HOST not in used_headers:
|
||||
self.headers[hdrs.HOST] = self.netloc
|
||||
|
||||
if hdrs.USER_AGENT not in used_headers:
|
||||
self.headers[hdrs.USER_AGENT] = self.SERVER_SOFTWARE
|
||||
|
||||
def update_cookies(self, cookies):
|
||||
"""Update request cookies header."""
|
||||
if not cookies:
|
||||
return
|
||||
|
||||
c = http.cookies.SimpleCookie()
|
||||
if hdrs.COOKIE in self.headers:
|
||||
c.load(self.headers.get(hdrs.COOKIE, ''))
|
||||
del self.headers[hdrs.COOKIE]
|
||||
|
||||
if isinstance(cookies, dict):
|
||||
cookies = cookies.items()
|
||||
|
||||
for name, value in cookies:
|
||||
if isinstance(value, http.cookies.Morsel):
|
||||
c[value.key] = value.value
|
||||
else:
|
||||
c[name] = value
|
||||
|
||||
self.headers[hdrs.COOKIE] = c.output(header='', sep=';').strip()
|
||||
|
||||
def update_content_encoding(self, data):
|
||||
"""Set request content encoding."""
|
||||
if not data:
|
||||
return
|
||||
|
||||
enc = self.headers.get(hdrs.CONTENT_ENCODING, '').lower()
|
||||
if enc:
|
||||
if self.compress is not False:
|
||||
self.compress = enc
|
||||
# enable chunked, no need to deal with length
|
||||
self.chunked = True
|
||||
elif self.compress:
|
||||
if not isinstance(self.compress, str):
|
||||
self.compress = 'deflate'
|
||||
self.headers[hdrs.CONTENT_ENCODING] = self.compress
|
||||
self.chunked = True # enable chunked, no need to deal with length
|
||||
|
||||
def update_auth(self, auth):
|
||||
"""Set basic auth."""
|
||||
if auth is None:
|
||||
auth = self.auth
|
||||
if auth is None:
|
||||
return
|
||||
|
||||
if not isinstance(auth, helpers.BasicAuth):
|
||||
raise TypeError('BasicAuth() tuple is required instead')
|
||||
|
||||
self.headers[hdrs.AUTHORIZATION] = auth.encode()
|
||||
|
||||
def update_body_from_data(self, data, skip_auto_headers):
|
||||
if not data:
|
||||
return
|
||||
|
||||
if isinstance(data, str):
|
||||
data = data.encode(self.encoding)
|
||||
|
||||
if isinstance(data, (bytes, bytearray)):
|
||||
self.body = data
|
||||
if (hdrs.CONTENT_TYPE not in self.headers and
|
||||
hdrs.CONTENT_TYPE not in skip_auto_headers):
|
||||
self.headers[hdrs.CONTENT_TYPE] = 'application/octet-stream'
|
||||
if hdrs.CONTENT_LENGTH not in self.headers and not self.chunked:
|
||||
self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))
|
||||
|
||||
elif isinstance(data, (asyncio.StreamReader, streams.StreamReader,
|
||||
streams.DataQueue)):
|
||||
self.body = data
|
||||
|
||||
elif asyncio.iscoroutine(data):
|
||||
self.body = data
|
||||
if (hdrs.CONTENT_LENGTH not in self.headers and
|
||||
self.chunked is None):
|
||||
self.chunked = True
|
||||
|
||||
elif isinstance(data, io.IOBase):
|
||||
assert not isinstance(data, io.StringIO), \
|
||||
'attempt to send text data instead of binary'
|
||||
self.body = data
|
||||
if not self.chunked and isinstance(data, io.BytesIO):
|
||||
# Not chunking if content-length can be determined
|
||||
size = len(data.getbuffer())
|
||||
self.headers[hdrs.CONTENT_LENGTH] = str(size)
|
||||
self.chunked = False
|
||||
elif not self.chunked and isinstance(data, io.BufferedReader):
|
||||
# Not chunking if content-length can be determined
|
||||
try:
|
||||
size = os.fstat(data.fileno()).st_size - data.tell()
|
||||
self.headers[hdrs.CONTENT_LENGTH] = str(size)
|
||||
self.chunked = False
|
||||
except OSError:
|
||||
# data.fileno() is not supported, e.g.
|
||||
# io.BufferedReader(io.BytesIO(b'data'))
|
||||
self.chunked = True
|
||||
else:
|
||||
self.chunked = True
|
||||
|
||||
if hasattr(data, 'mode'):
|
||||
if data.mode == 'r':
|
||||
raise ValueError('file {!r} should be open in binary mode'
|
||||
''.format(data))
|
||||
if (hdrs.CONTENT_TYPE not in self.headers and
|
||||
hdrs.CONTENT_TYPE not in skip_auto_headers and
|
||||
hasattr(data, 'name')):
|
||||
mime = mimetypes.guess_type(data.name)[0]
|
||||
mime = 'application/octet-stream' if mime is None else mime
|
||||
self.headers[hdrs.CONTENT_TYPE] = mime
|
||||
|
||||
elif isinstance(data, MultipartWriter):
|
||||
self.body = data.serialize()
|
||||
self.headers.update(data.headers)
|
||||
self.chunked = self.chunked or 8192
|
||||
|
||||
else:
|
||||
if not isinstance(data, helpers.FormData):
|
||||
data = helpers.FormData(data)
|
||||
|
||||
self.body = data(self.encoding)
|
||||
|
||||
if (hdrs.CONTENT_TYPE not in self.headers and
|
||||
hdrs.CONTENT_TYPE not in skip_auto_headers):
|
||||
self.headers[hdrs.CONTENT_TYPE] = data.content_type
|
||||
|
||||
if data.is_multipart:
|
||||
self.chunked = self.chunked or 8192
|
||||
else:
|
||||
if (hdrs.CONTENT_LENGTH not in self.headers and
|
||||
not self.chunked):
|
||||
self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))
|
||||
|
||||
def update_transfer_encoding(self):
|
||||
"""Analyze transfer-encoding header."""
|
||||
te = self.headers.get(hdrs.TRANSFER_ENCODING, '').lower()
|
||||
|
||||
if self.chunked:
|
||||
if hdrs.CONTENT_LENGTH in self.headers:
|
||||
del self.headers[hdrs.CONTENT_LENGTH]
|
||||
if 'chunked' not in te:
|
||||
self.headers[hdrs.TRANSFER_ENCODING] = 'chunked'
|
||||
|
||||
self.chunked = self.chunked if type(self.chunked) is int else 8192
|
||||
else:
|
||||
if 'chunked' in te:
|
||||
self.chunked = 8192
|
||||
else:
|
||||
self.chunked = None
|
||||
if hdrs.CONTENT_LENGTH not in self.headers:
|
||||
self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))
|
||||
|
||||
def update_expect_continue(self, expect=False):
|
||||
if expect:
|
||||
self.headers[hdrs.EXPECT] = '100-continue'
|
||||
elif self.headers.get(hdrs.EXPECT, '').lower() == '100-continue':
|
||||
expect = True
|
||||
|
||||
if expect:
|
||||
self._continue = helpers.create_future(self.loop)
|
||||
|
||||
def update_proxy(self, proxy, proxy_auth):
|
||||
if proxy and not proxy.startswith('http://'):
|
||||
raise ValueError("Only http proxies are supported")
|
||||
if proxy_auth and not isinstance(proxy_auth, helpers.BasicAuth):
|
||||
raise ValueError("proxy_auth must be None or BasicAuth() tuple")
|
||||
self.proxy = proxy
|
||||
self.proxy_auth = proxy_auth
|
||||
|
||||
@asyncio.coroutine
|
||||
def write_bytes(self, request, reader):
|
||||
"""Support coroutines that yields bytes objects."""
|
||||
# 100 response
|
||||
if self._continue is not None:
|
||||
yield from self._continue
|
||||
|
||||
try:
|
||||
if asyncio.iscoroutine(self.body):
|
||||
request.transport.set_tcp_nodelay(True)
|
||||
exc = None
|
||||
value = None
|
||||
stream = self.body
|
||||
|
||||
while True:
|
||||
try:
|
||||
if exc is not None:
|
||||
result = stream.throw(exc)
|
||||
else:
|
||||
result = stream.send(value)
|
||||
except StopIteration as exc:
|
||||
if isinstance(exc.value, bytes):
|
||||
yield from request.write(exc.value, drain=True)
|
||||
break
|
||||
except:
|
||||
self.response.close()
|
||||
raise
|
||||
|
||||
if isinstance(result, asyncio.Future):
|
||||
exc = None
|
||||
value = None
|
||||
try:
|
||||
value = yield result
|
||||
except Exception as err:
|
||||
exc = err
|
||||
elif isinstance(result, (bytes, bytearray)):
|
||||
yield from request.write(result, drain=True)
|
||||
value = None
|
||||
else:
|
||||
raise ValueError(
|
||||
'Bytes object is expected, got: %s.' %
|
||||
type(result))
|
||||
|
||||
elif isinstance(self.body, (asyncio.StreamReader,
|
||||
streams.StreamReader)):
|
||||
request.transport.set_tcp_nodelay(True)
|
||||
chunk = yield from self.body.read(streams.DEFAULT_LIMIT)
|
||||
while chunk:
|
||||
yield from request.write(chunk, drain=True)
|
||||
chunk = yield from self.body.read(streams.DEFAULT_LIMIT)
|
||||
|
||||
elif isinstance(self.body, streams.DataQueue):
|
||||
request.transport.set_tcp_nodelay(True)
|
||||
while True:
|
||||
try:
|
||||
chunk = yield from self.body.read()
|
||||
if chunk is EOF_MARKER:
|
||||
break
|
||||
yield from request.write(chunk, drain=True)
|
||||
except streams.EofStream:
|
||||
break
|
||||
|
||||
elif isinstance(self.body, io.IOBase):
|
||||
chunk = self.body.read(self.chunked)
|
||||
while chunk:
|
||||
request.write(chunk)
|
||||
chunk = self.body.read(self.chunked)
|
||||
request.transport.set_tcp_nodelay(True)
|
||||
|
||||
else:
|
||||
if isinstance(self.body, (bytes, bytearray)):
|
||||
self.body = (self.body,)
|
||||
|
||||
for chunk in self.body:
|
||||
request.write(chunk)
|
||||
request.transport.set_tcp_nodelay(True)
|
||||
|
||||
except Exception as exc:
|
||||
new_exc = aiohttp.ClientRequestError(
|
||||
'Can not write request body for %s' % self.url)
|
||||
new_exc.__context__ = exc
|
||||
new_exc.__cause__ = exc
|
||||
reader.set_exception(new_exc)
|
||||
else:
|
||||
assert request.transport.tcp_nodelay
|
||||
try:
|
||||
ret = request.write_eof()
|
||||
# NB: in asyncio 3.4.1+ StreamWriter.drain() is coroutine
|
||||
# see bug #170
|
||||
if (asyncio.iscoroutine(ret) or
|
||||
isinstance(ret, asyncio.Future)):
|
||||
yield from ret
|
||||
except Exception as exc:
|
||||
new_exc = aiohttp.ClientRequestError(
|
||||
'Can not write request body for %s' % self.url)
|
||||
new_exc.__context__ = exc
|
||||
new_exc.__cause__ = exc
|
||||
reader.set_exception(new_exc)
|
||||
|
||||
self._writer = None
|
||||
|
||||
def send(self, writer, reader):
|
||||
writer.set_tcp_cork(True)
|
||||
request = aiohttp.Request(writer, self.method, self.path, self.version)
|
||||
|
||||
if self.compress:
|
||||
request.add_compression_filter(self.compress)
|
||||
|
||||
if self.chunked is not None:
|
||||
request.enable_chunked_encoding()
|
||||
request.add_chunking_filter(self.chunked)
|
||||
|
||||
# set default content-type
|
||||
if (self.method in self.POST_METHODS and
|
||||
hdrs.CONTENT_TYPE not in self.skip_auto_headers and
|
||||
hdrs.CONTENT_TYPE not in self.headers):
|
||||
self.headers[hdrs.CONTENT_TYPE] = 'application/octet-stream'
|
||||
|
||||
for k, value in self.headers.items():
|
||||
request.add_header(k, value)
|
||||
request.send_headers()
|
||||
|
||||
self._writer = helpers.ensure_future(
|
||||
self.write_bytes(request, reader), loop=self.loop)
|
||||
|
||||
self.response = self.response_class(
|
||||
self.method, self.url, self.host,
|
||||
writer=self._writer, continue100=self._continue,
|
||||
timeout=self._timeout)
|
||||
self.response._post_init(self.loop)
|
||||
return self.response
|
||||
|
||||
@asyncio.coroutine
|
||||
def close(self):
|
||||
if self._writer is not None:
|
||||
try:
|
||||
yield from self._writer
|
||||
finally:
|
||||
self._writer = None
|
||||
|
||||
def terminate(self):
|
||||
if self._writer is not None:
|
||||
if not self.loop.is_closed():
|
||||
self._writer.cancel()
|
||||
self._writer = None
|
||||
|
||||
|
||||
class ClientResponse:
|
||||
|
||||
# from the Status-Line of the response
|
||||
version = None # HTTP-Version
|
||||
status = None # Status-Code
|
||||
reason = None # Reason-Phrase
|
||||
|
||||
cookies = None # Response cookies (Set-Cookie)
|
||||
content = None # Payload stream
|
||||
headers = None # Response headers, CIMultiDictProxy
|
||||
raw_headers = None # Response raw headers, a sequence of pairs
|
||||
|
||||
_connection = None # current connection
|
||||
flow_control_class = FlowControlStreamReader # reader flow control
|
||||
_reader = None # input stream
|
||||
_response_parser = aiohttp.HttpResponseParser()
|
||||
_source_traceback = None
|
||||
# setted up by ClientRequest after ClientResponse object creation
|
||||
# post-init stage allows to not change ctor signature
|
||||
_loop = None
|
||||
_closed = True # to allow __del__ for non-initialized properly response
|
||||
|
||||
def __init__(self, method, url, host='', *, writer=None, continue100=None,
|
||||
timeout=5*60):
|
||||
super().__init__()
|
||||
|
||||
self.method = method
|
||||
self.url = url
|
||||
self.host = host
|
||||
self._content = None
|
||||
self._writer = writer
|
||||
self._continue = continue100
|
||||
self._closed = False
|
||||
self._should_close = True # override by message.should_close later
|
||||
self._history = ()
|
||||
self._timeout = timeout
|
||||
|
||||
def _post_init(self, loop):
|
||||
self._loop = loop
|
||||
if loop.get_debug():
|
||||
self._source_traceback = traceback.extract_stack(sys._getframe(1))
|
||||
|
||||
def __del__(self, _warnings=warnings):
|
||||
if self._loop is None:
|
||||
return # not started
|
||||
if self._closed:
|
||||
return
|
||||
self.close()
|
||||
|
||||
_warnings.warn("Unclosed response {!r}".format(self),
|
||||
ResourceWarning)
|
||||
context = {'client_response': self,
|
||||
'message': 'Unclosed response'}
|
||||
if self._source_traceback:
|
||||
context['source_traceback'] = self._source_traceback
|
||||
self._loop.call_exception_handler(context)
|
||||
|
||||
def __repr__(self):
|
||||
out = io.StringIO()
|
||||
ascii_encodable_url = self.url.encode('ascii', 'backslashreplace') \
|
||||
.decode('ascii')
|
||||
if self.reason:
|
||||
ascii_encodable_reason = self.reason.encode('ascii',
|
||||
'backslashreplace') \
|
||||
.decode('ascii')
|
||||
else:
|
||||
ascii_encodable_reason = self.reason
|
||||
print('<ClientResponse({}) [{} {}]>'.format(
|
||||
ascii_encodable_url, self.status, ascii_encodable_reason),
|
||||
file=out)
|
||||
print(self.headers, file=out)
|
||||
return out.getvalue()
|
||||
|
||||
@property
|
||||
def connection(self):
|
||||
return self._connection
|
||||
|
||||
@property
|
||||
def history(self):
|
||||
"""A sequence of of responses, if redirects occured."""
|
||||
return self._history
|
||||
|
||||
def waiting_for_continue(self):
|
||||
return self._continue is not None
|
||||
|
||||
def _setup_connection(self, connection):
|
||||
self._reader = connection.reader
|
||||
self._connection = connection
|
||||
self.content = self.flow_control_class(
|
||||
connection.reader, loop=connection.loop, timeout=self._timeout)
|
||||
|
||||
def _need_parse_response_body(self):
|
||||
return (self.method.lower() != 'head' and
|
||||
self.status not in [204, 304])
|
||||
|
||||
@asyncio.coroutine
|
||||
def start(self, connection, read_until_eof=False):
|
||||
"""Start response processing."""
|
||||
self._setup_connection(connection)
|
||||
|
||||
while True:
|
||||
httpstream = self._reader.set_parser(self._response_parser)
|
||||
|
||||
# read response
|
||||
with Timeout(self._timeout, loop=self._loop):
|
||||
message = yield from httpstream.read()
|
||||
if message.code != 100:
|
||||
break
|
||||
|
||||
if self._continue is not None and not self._continue.done():
|
||||
self._continue.set_result(True)
|
||||
self._continue = None
|
||||
|
||||
# response status
|
||||
self.version = message.version
|
||||
self.status = message.code
|
||||
self.reason = message.reason
|
||||
self._should_close = message.should_close
|
||||
|
||||
# headers
|
||||
self.headers = CIMultiDictProxy(message.headers)
|
||||
self.raw_headers = tuple(message.raw_headers)
|
||||
|
||||
# payload
|
||||
rwb = self._need_parse_response_body()
|
||||
self._reader.set_parser(
|
||||
aiohttp.HttpPayloadParser(message,
|
||||
readall=read_until_eof,
|
||||
response_with_body=rwb),
|
||||
self.content)
|
||||
|
||||
# cookies
|
||||
self.cookies = http.cookies.SimpleCookie()
|
||||
if hdrs.SET_COOKIE in self.headers:
|
||||
for hdr in self.headers.getall(hdrs.SET_COOKIE):
|
||||
try:
|
||||
self.cookies.load(hdr)
|
||||
except http.cookies.CookieError as exc:
|
||||
client_logger.warning(
|
||||
'Can not load response cookies: %s', exc)
|
||||
return self
|
||||
|
||||
def close(self):
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
self._closed = True
|
||||
|
||||
if self._loop is None or self._loop.is_closed():
|
||||
return
|
||||
|
||||
if self._connection is not None:
|
||||
self._connection.close()
|
||||
self._connection = None
|
||||
self._cleanup_writer()
|
||||
|
||||
@asyncio.coroutine
|
||||
def release(self):
|
||||
if self._closed:
|
||||
return
|
||||
try:
|
||||
content = self.content
|
||||
if content is not None and not content.at_eof():
|
||||
chunk = yield from content.readany()
|
||||
while chunk is not EOF_MARKER or chunk:
|
||||
chunk = yield from content.readany()
|
||||
except Exception:
|
||||
self._connection.close()
|
||||
self._connection = None
|
||||
raise
|
||||
finally:
|
||||
self._closed = True
|
||||
if self._connection is not None:
|
||||
self._connection.release()
|
||||
if self._reader is not None:
|
||||
self._reader.unset_parser()
|
||||
self._connection = None
|
||||
self._cleanup_writer()
|
||||
|
||||
def raise_for_status(self):
|
||||
if 400 <= self.status:
|
||||
raise aiohttp.HttpProcessingError(
|
||||
code=self.status,
|
||||
message=self.reason)
|
||||
|
||||
def _cleanup_writer(self):
|
||||
if self._writer is not None and not self._writer.done():
|
||||
self._writer.cancel()
|
||||
self._writer = None
|
||||
|
||||
@asyncio.coroutine
|
||||
def wait_for_close(self):
|
||||
if self._writer is not None:
|
||||
try:
|
||||
yield from self._writer
|
||||
finally:
|
||||
self._writer = None
|
||||
yield from self.release()
|
||||
|
||||
@asyncio.coroutine
|
||||
def read(self):
|
||||
"""Read response payload."""
|
||||
if self._content is None:
|
||||
try:
|
||||
self._content = yield from self.content.read()
|
||||
except:
|
||||
self.close()
|
||||
raise
|
||||
else:
|
||||
yield from self.release()
|
||||
|
||||
return self._content
|
||||
|
||||
def _get_encoding(self):
|
||||
ctype = self.headers.get(hdrs.CONTENT_TYPE, '').lower()
|
||||
mtype, stype, _, params = helpers.parse_mimetype(ctype)
|
||||
|
||||
encoding = params.get('charset')
|
||||
if not encoding:
|
||||
encoding = chardet.detect(self._content)['encoding']
|
||||
if not encoding:
|
||||
encoding = 'utf-8'
|
||||
|
||||
return encoding
|
||||
|
||||
@asyncio.coroutine
|
||||
def text(self, encoding=None):
|
||||
"""Read response payload and decode."""
|
||||
if self._content is None:
|
||||
yield from self.read()
|
||||
|
||||
if encoding is None:
|
||||
encoding = self._get_encoding()
|
||||
|
||||
return self._content.decode(encoding)
|
||||
|
||||
@asyncio.coroutine
|
||||
def json(self, *, encoding=None, loads=json.loads):
|
||||
"""Read and decodes JSON response."""
|
||||
if self._content is None:
|
||||
yield from self.read()
|
||||
|
||||
ctype = self.headers.get(hdrs.CONTENT_TYPE, '').lower()
|
||||
if 'json' not in ctype:
|
||||
client_logger.warning(
|
||||
'Attempt to decode JSON with unexpected mimetype: %s', ctype)
|
||||
|
||||
stripped = self._content.strip()
|
||||
if not stripped:
|
||||
return None
|
||||
|
||||
if encoding is None:
|
||||
encoding = self._get_encoding()
|
||||
|
||||
return loads(stripped.decode(encoding))
|
||||
|
||||
if PY_35:
|
||||
@asyncio.coroutine
|
||||
def __aenter__(self):
|
||||
return self
|
||||
|
||||
@asyncio.coroutine
|
||||
def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if exc_type is None:
|
||||
yield from self.release()
|
||||
else:
|
||||
self.close()
|
||||
|
|
@ -0,0 +1,193 @@
|
|||
"""WebSocket client for asyncio."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
|
||||
from ._ws_impl import CLOSED_MESSAGE, WebSocketError, WSMessage, WSMsgType
|
||||
|
||||
PY_35 = sys.version_info >= (3, 5)
|
||||
PY_352 = sys.version_info >= (3, 5, 2)
|
||||
|
||||
|
||||
class ClientWebSocketResponse:
|
||||
|
||||
def __init__(self, reader, writer, protocol,
|
||||
response, timeout, autoclose, autoping, loop):
|
||||
self._response = response
|
||||
self._conn = response.connection
|
||||
|
||||
self._writer = writer
|
||||
self._reader = reader
|
||||
self._protocol = protocol
|
||||
self._closed = False
|
||||
self._closing = False
|
||||
self._close_code = None
|
||||
self._timeout = timeout
|
||||
self._autoclose = autoclose
|
||||
self._autoping = autoping
|
||||
self._loop = loop
|
||||
self._waiting = False
|
||||
self._exception = None
|
||||
|
||||
@property
|
||||
def closed(self):
|
||||
return self._closed
|
||||
|
||||
@property
|
||||
def close_code(self):
|
||||
return self._close_code
|
||||
|
||||
@property
|
||||
def protocol(self):
|
||||
return self._protocol
|
||||
|
||||
def exception(self):
|
||||
return self._exception
|
||||
|
||||
def ping(self, message='b'):
|
||||
if self._closed:
|
||||
raise RuntimeError('websocket connection is closed')
|
||||
self._writer.ping(message)
|
||||
|
||||
def pong(self, message='b'):
|
||||
if self._closed:
|
||||
raise RuntimeError('websocket connection is closed')
|
||||
self._writer.pong(message)
|
||||
|
||||
def send_str(self, data):
|
||||
if self._closed:
|
||||
raise RuntimeError('websocket connection is closed')
|
||||
if not isinstance(data, str):
|
||||
raise TypeError('data argument must be str (%r)' % type(data))
|
||||
self._writer.send(data, binary=False)
|
||||
|
||||
def send_bytes(self, data):
|
||||
if self._closed:
|
||||
raise RuntimeError('websocket connection is closed')
|
||||
if not isinstance(data, (bytes, bytearray, memoryview)):
|
||||
raise TypeError('data argument must be byte-ish (%r)' %
|
||||
type(data))
|
||||
self._writer.send(data, binary=True)
|
||||
|
||||
def send_json(self, data, *, dumps=json.dumps):
|
||||
self.send_str(dumps(data))
|
||||
|
||||
@asyncio.coroutine
|
||||
def close(self, *, code=1000, message=b''):
|
||||
if not self._closed:
|
||||
self._closed = True
|
||||
try:
|
||||
self._writer.close(code, message)
|
||||
except asyncio.CancelledError:
|
||||
self._close_code = 1006
|
||||
self._response.close()
|
||||
raise
|
||||
except Exception as exc:
|
||||
self._close_code = 1006
|
||||
self._exception = exc
|
||||
self._response.close()
|
||||
return True
|
||||
|
||||
if self._closing:
|
||||
self._response.close()
|
||||
return True
|
||||
|
||||
while True:
|
||||
try:
|
||||
msg = yield from asyncio.wait_for(
|
||||
self._reader.read(), self._timeout, loop=self._loop)
|
||||
except asyncio.CancelledError:
|
||||
self._close_code = 1006
|
||||
self._response.close()
|
||||
raise
|
||||
except Exception as exc:
|
||||
self._close_code = 1006
|
||||
self._exception = exc
|
||||
self._response.close()
|
||||
return True
|
||||
|
||||
if msg.type == WSMsgType.CLOSE:
|
||||
self._close_code = msg.data
|
||||
self._response.close()
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
@asyncio.coroutine
|
||||
def receive(self):
|
||||
if self._waiting:
|
||||
raise RuntimeError('Concurrent call to receive() is not allowed')
|
||||
|
||||
self._waiting = True
|
||||
try:
|
||||
while True:
|
||||
if self._closed:
|
||||
return CLOSED_MESSAGE
|
||||
|
||||
try:
|
||||
msg = yield from self._reader.read()
|
||||
except (asyncio.CancelledError, asyncio.TimeoutError):
|
||||
raise
|
||||
except WebSocketError as exc:
|
||||
self._close_code = exc.code
|
||||
yield from self.close(code=exc.code)
|
||||
return WSMessage(WSMsgType.ERROR, exc, None)
|
||||
except Exception as exc:
|
||||
self._exception = exc
|
||||
self._closing = True
|
||||
self._close_code = 1006
|
||||
yield from self.close()
|
||||
return WSMessage(WSMsgType.ERROR, exc, None)
|
||||
|
||||
if msg.type == WSMsgType.CLOSE:
|
||||
self._closing = True
|
||||
self._close_code = msg.data
|
||||
if not self._closed and self._autoclose:
|
||||
yield from self.close()
|
||||
return msg
|
||||
if msg.type == WSMsgType.PING and self._autoping:
|
||||
self.pong(msg.data)
|
||||
elif msg.type == WSMsgType.PONG and self._autoping:
|
||||
continue
|
||||
else:
|
||||
return msg
|
||||
finally:
|
||||
self._waiting = False
|
||||
|
||||
@asyncio.coroutine
|
||||
def receive_str(self):
|
||||
msg = yield from self.receive()
|
||||
if msg.type != WSMsgType.TEXT:
|
||||
raise TypeError(
|
||||
"Received message {}:{!r} is not str".format(msg.type,
|
||||
msg.data))
|
||||
return msg.data
|
||||
|
||||
@asyncio.coroutine
|
||||
def receive_bytes(self):
|
||||
msg = yield from self.receive()
|
||||
if msg.type != WSMsgType.BINARY:
|
||||
raise TypeError(
|
||||
"Received message {}:{!r} is not bytes".format(msg.type,
|
||||
msg.data))
|
||||
return msg.data
|
||||
|
||||
@asyncio.coroutine
|
||||
def receive_json(self, *, loads=json.loads):
|
||||
data = yield from self.receive_str()
|
||||
return loads(data)
|
||||
|
||||
if PY_35:
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
if not PY_352: # pragma: no cover
|
||||
__aiter__ = asyncio.coroutine(__aiter__)
|
||||
|
||||
@asyncio.coroutine
|
||||
def __anext__(self):
|
||||
msg = yield from self.receive()
|
||||
if msg.type == WSMsgType.CLOSE:
|
||||
raise StopAsyncIteration # NOQA
|
||||
return msg
|
||||
|
|
@ -0,0 +1,783 @@
|
|||
import asyncio
|
||||
import functools
|
||||
import http.cookies
|
||||
import ssl
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from hashlib import md5, sha1, sha256
|
||||
from itertools import chain
|
||||
from math import ceil
|
||||
from types import MappingProxyType
|
||||
|
||||
import aiohttp
|
||||
|
||||
from . import hdrs, helpers
|
||||
from .client import ClientRequest
|
||||
from .errors import (ClientOSError, ClientTimeoutError, FingerprintMismatch,
|
||||
HttpProxyError, ProxyConnectionError,
|
||||
ServerDisconnectedError)
|
||||
from .helpers import is_ip_address, sentinel
|
||||
from .resolver import DefaultResolver
|
||||
|
||||
__all__ = ('BaseConnector', 'TCPConnector', 'ProxyConnector', 'UnixConnector')
|
||||
|
||||
PY_343 = sys.version_info >= (3, 4, 3)
|
||||
|
||||
HASHFUNC_BY_DIGESTLEN = {
|
||||
16: md5,
|
||||
20: sha1,
|
||||
32: sha256,
|
||||
}
|
||||
|
||||
|
||||
class Connection:
|
||||
|
||||
_source_traceback = None
|
||||
_transport = None
|
||||
|
||||
def __init__(self, connector, key, request, transport, protocol, loop):
|
||||
self._key = key
|
||||
self._connector = connector
|
||||
self._request = request
|
||||
self._transport = transport
|
||||
self._protocol = protocol
|
||||
self._loop = loop
|
||||
self.reader = protocol.reader
|
||||
self.writer = protocol.writer
|
||||
|
||||
if loop.get_debug():
|
||||
self._source_traceback = traceback.extract_stack(sys._getframe(1))
|
||||
|
||||
def __repr__(self):
|
||||
return 'Connection<{}>'.format(self._key)
|
||||
|
||||
def __del__(self, _warnings=warnings):
|
||||
if self._transport is not None:
|
||||
_warnings.warn('Unclosed connection {!r}'.format(self),
|
||||
ResourceWarning)
|
||||
if self._loop.is_closed():
|
||||
return
|
||||
|
||||
self._connector._release(
|
||||
self._key, self._request, self._transport, self._protocol,
|
||||
should_close=True)
|
||||
|
||||
context = {'client_connection': self,
|
||||
'message': 'Unclosed connection'}
|
||||
if self._source_traceback is not None:
|
||||
context['source_traceback'] = self._source_traceback
|
||||
self._loop.call_exception_handler(context)
|
||||
|
||||
@property
|
||||
def loop(self):
|
||||
return self._loop
|
||||
|
||||
def close(self):
|
||||
if self._transport is not None:
|
||||
self._connector._release(
|
||||
self._key, self._request, self._transport, self._protocol,
|
||||
should_close=True)
|
||||
self._transport = None
|
||||
|
||||
def release(self):
|
||||
if self._transport is not None:
|
||||
self._connector._release(
|
||||
self._key, self._request, self._transport, self._protocol,
|
||||
should_close=False)
|
||||
self._transport = None
|
||||
|
||||
def detach(self):
|
||||
self._transport = None
|
||||
|
||||
@property
|
||||
def closed(self):
|
||||
return self._transport is None
|
||||
|
||||
|
||||
class BaseConnector(object):
|
||||
"""Base connector class.
|
||||
|
||||
conn_timeout - (optional) Connect timeout.
|
||||
keepalive_timeout - (optional) Keep-alive timeout.
|
||||
force_close - Set to True to force close and do reconnect
|
||||
after each request (and between redirects).
|
||||
limit - The limit of simultaneous connections to the same endpoint.
|
||||
loop - Optional event loop.
|
||||
"""
|
||||
|
||||
_closed = True # prevent AttributeError in __del__ if ctor was failed
|
||||
_source_traceback = None
|
||||
|
||||
def __init__(self, *, conn_timeout=None, keepalive_timeout=sentinel,
|
||||
force_close=False, limit=20,
|
||||
loop=None):
|
||||
|
||||
if force_close:
|
||||
if keepalive_timeout is not None and \
|
||||
keepalive_timeout is not sentinel:
|
||||
raise ValueError('keepalive_timeout cannot '
|
||||
'be set if force_close is True')
|
||||
else:
|
||||
if keepalive_timeout is sentinel:
|
||||
keepalive_timeout = 30
|
||||
|
||||
if loop is None:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
self._closed = False
|
||||
if loop.get_debug():
|
||||
self._source_traceback = traceback.extract_stack(sys._getframe(1))
|
||||
|
||||
self._conns = {}
|
||||
self._acquired = defaultdict(set)
|
||||
self._conn_timeout = conn_timeout
|
||||
self._keepalive_timeout = keepalive_timeout
|
||||
self._cleanup_handle = None
|
||||
self._force_close = force_close
|
||||
self._limit = limit
|
||||
self._waiters = defaultdict(list)
|
||||
|
||||
self._loop = loop
|
||||
self._factory = functools.partial(
|
||||
aiohttp.StreamProtocol, loop=loop,
|
||||
disconnect_error=ServerDisconnectedError)
|
||||
|
||||
self.cookies = http.cookies.SimpleCookie()
|
||||
|
||||
def __del__(self, _warnings=warnings):
|
||||
if self._closed:
|
||||
return
|
||||
if not self._conns:
|
||||
return
|
||||
|
||||
conns = [repr(c) for c in self._conns.values()]
|
||||
|
||||
self.close()
|
||||
|
||||
_warnings.warn("Unclosed connector {!r}".format(self),
|
||||
ResourceWarning)
|
||||
context = {'connector': self,
|
||||
'connections': conns,
|
||||
'message': 'Unclosed connector'}
|
||||
if self._source_traceback is not None:
|
||||
context['source_traceback'] = self._source_traceback
|
||||
self._loop.call_exception_handler(context)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc):
|
||||
self.close()
|
||||
|
||||
@property
|
||||
def force_close(self):
|
||||
"""Ultimately close connection on releasing if True."""
|
||||
return self._force_close
|
||||
|
||||
@property
|
||||
def limit(self):
|
||||
"""The limit for simultaneous connections to the same endpoint.
|
||||
|
||||
Endpoints are the same if they are have equal
|
||||
(host, port, is_ssl) triple.
|
||||
|
||||
If limit is None the connector has no limit.
|
||||
The default limit size is 20.
|
||||
"""
|
||||
return self._limit
|
||||
|
||||
def _cleanup(self):
|
||||
"""Cleanup unused transports."""
|
||||
if self._cleanup_handle:
|
||||
self._cleanup_handle.cancel()
|
||||
self._cleanup_handle = None
|
||||
|
||||
now = self._loop.time()
|
||||
|
||||
connections = {}
|
||||
timeout = self._keepalive_timeout
|
||||
|
||||
for key, conns in self._conns.items():
|
||||
alive = []
|
||||
for transport, proto, t0 in conns:
|
||||
if transport is not None:
|
||||
if proto and not proto.is_connected():
|
||||
transport = None
|
||||
else:
|
||||
delta = t0 + self._keepalive_timeout - now
|
||||
if delta < 0:
|
||||
transport.close()
|
||||
transport = None
|
||||
elif delta < timeout:
|
||||
timeout = delta
|
||||
|
||||
if transport is not None:
|
||||
alive.append((transport, proto, t0))
|
||||
if alive:
|
||||
connections[key] = alive
|
||||
|
||||
if connections:
|
||||
self._cleanup_handle = self._loop.call_at(
|
||||
ceil(now + timeout), self._cleanup)
|
||||
|
||||
self._conns = connections
|
||||
|
||||
def _start_cleanup_task(self):
|
||||
if self._cleanup_handle is None:
|
||||
now = self._loop.time()
|
||||
self._cleanup_handle = self._loop.call_at(
|
||||
ceil(now + self._keepalive_timeout), self._cleanup)
|
||||
|
||||
def close(self):
|
||||
"""Close all opened transports."""
|
||||
ret = helpers.create_future(self._loop)
|
||||
ret.set_result(None)
|
||||
if self._closed:
|
||||
return ret
|
||||
self._closed = True
|
||||
|
||||
try:
|
||||
if self._loop.is_closed():
|
||||
return ret
|
||||
|
||||
for key, data in self._conns.items():
|
||||
for transport, proto, t0 in data:
|
||||
transport.close()
|
||||
|
||||
for transport in chain(*self._acquired.values()):
|
||||
transport.close()
|
||||
|
||||
if self._cleanup_handle:
|
||||
self._cleanup_handle.cancel()
|
||||
|
||||
finally:
|
||||
self._conns.clear()
|
||||
self._acquired.clear()
|
||||
self._cleanup_handle = None
|
||||
return ret
|
||||
|
||||
@property
|
||||
def closed(self):
|
||||
"""Is connector closed.
|
||||
|
||||
A readonly property.
|
||||
"""
|
||||
return self._closed
|
||||
|
||||
@asyncio.coroutine
|
||||
def connect(self, req):
|
||||
"""Get from pool or create new connection."""
|
||||
key = (req.host, req.port, req.ssl)
|
||||
|
||||
limit = self._limit
|
||||
if limit is not None:
|
||||
fut = helpers.create_future(self._loop)
|
||||
waiters = self._waiters[key]
|
||||
|
||||
# The limit defines the maximum number of concurrent connections
|
||||
# for a key. Waiters must be counted against the limit, even before
|
||||
# the underlying connection is created.
|
||||
available = limit - len(waiters) - len(self._acquired[key])
|
||||
|
||||
# Don't wait if there are connections available.
|
||||
if available > 0:
|
||||
fut.set_result(None)
|
||||
|
||||
# This connection will now count towards the limit.
|
||||
waiters.append(fut)
|
||||
|
||||
try:
|
||||
if limit is not None:
|
||||
yield from fut
|
||||
|
||||
transport, proto = self._get(key)
|
||||
if transport is None:
|
||||
try:
|
||||
if self._conn_timeout:
|
||||
transport, proto = yield from asyncio.wait_for(
|
||||
self._create_connection(req),
|
||||
self._conn_timeout, loop=self._loop)
|
||||
else:
|
||||
transport, proto = \
|
||||
yield from self._create_connection(req)
|
||||
|
||||
except asyncio.TimeoutError as exc:
|
||||
raise ClientTimeoutError(
|
||||
'Connection timeout to host {0[0]}:{0[1]} ssl:{0[2]}'
|
||||
.format(key)) from exc
|
||||
except OSError as exc:
|
||||
raise ClientOSError(
|
||||
exc.errno,
|
||||
'Cannot connect to host {0[0]}:{0[1]} ssl:{0[2]} [{1}]'
|
||||
.format(key, exc.strerror)) from exc
|
||||
except:
|
||||
self._release_waiter(key)
|
||||
raise
|
||||
|
||||
self._acquired[key].add(transport)
|
||||
conn = Connection(self, key, req, transport, proto, self._loop)
|
||||
return conn
|
||||
|
||||
def _get(self, key):
|
||||
try:
|
||||
conns = self._conns[key]
|
||||
except KeyError:
|
||||
return None, None
|
||||
t1 = self._loop.time()
|
||||
while conns:
|
||||
transport, proto, t0 = conns.pop()
|
||||
if transport is not None and proto.is_connected():
|
||||
if t1 - t0 > self._keepalive_timeout:
|
||||
transport.close()
|
||||
transport = None
|
||||
else:
|
||||
if not conns:
|
||||
# The very last connection was reclaimed: drop the key
|
||||
del self._conns[key]
|
||||
return transport, proto
|
||||
# No more connections: drop the key
|
||||
del self._conns[key]
|
||||
return None, None
|
||||
|
||||
def _release_waiter(self, key):
|
||||
waiters = self._waiters[key]
|
||||
while waiters:
|
||||
waiter = waiters.pop(0)
|
||||
if not waiter.done():
|
||||
waiter.set_result(None)
|
||||
break
|
||||
|
||||
def _release(self, key, req, transport, protocol, *, should_close=False):
|
||||
if self._closed:
|
||||
# acquired connection is already released on connector closing
|
||||
return
|
||||
|
||||
acquired = self._acquired[key]
|
||||
try:
|
||||
acquired.remove(transport)
|
||||
except KeyError: # pragma: no cover
|
||||
# this may be result of undetermenistic order of objects
|
||||
# finalization due garbage collection.
|
||||
pass
|
||||
else:
|
||||
if self._limit is not None and len(acquired) < self._limit:
|
||||
self._release_waiter(key)
|
||||
|
||||
resp = req.response
|
||||
|
||||
if not should_close:
|
||||
if self._force_close:
|
||||
should_close = True
|
||||
elif resp is not None:
|
||||
should_close = resp._should_close
|
||||
|
||||
reader = protocol.reader
|
||||
if should_close or (reader.output and not reader.output.at_eof()):
|
||||
transport.close()
|
||||
else:
|
||||
conns = self._conns.get(key)
|
||||
if conns is None:
|
||||
conns = self._conns[key] = []
|
||||
conns.append((transport, protocol, self._loop.time()))
|
||||
reader.unset_parser()
|
||||
|
||||
self._start_cleanup_task()
|
||||
|
||||
@asyncio.coroutine
|
||||
def _create_connection(self, req):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
_SSL_OP_NO_COMPRESSION = getattr(ssl, "OP_NO_COMPRESSION", 0)
|
||||
|
||||
|
||||
class TCPConnector(BaseConnector):
|
||||
"""TCP connector.
|
||||
|
||||
verify_ssl - Set to True to check ssl certifications.
|
||||
fingerprint - Pass the binary md5, sha1, or sha256
|
||||
digest of the expected certificate in DER format to verify
|
||||
that the certificate the server presents matches. See also
|
||||
https://en.wikipedia.org/wiki/Transport_Layer_Security#Certificate_pinning
|
||||
resolve - (Deprecated) Set to True to do DNS lookup for
|
||||
host name.
|
||||
resolver - Enable DNS lookups and use this
|
||||
resolver
|
||||
use_dns_cache - Use memory cache for DNS lookups.
|
||||
family - socket address family
|
||||
local_addr - local tuple of (host, port) to bind socket to
|
||||
|
||||
conn_timeout - (optional) Connect timeout.
|
||||
keepalive_timeout - (optional) Keep-alive timeout.
|
||||
force_close - Set to True to force close and do reconnect
|
||||
after each request (and between redirects).
|
||||
limit - The limit of simultaneous connections to the same endpoint.
|
||||
loop - Optional event loop.
|
||||
"""
|
||||
|
||||
def __init__(self, *, verify_ssl=True, fingerprint=None,
|
||||
resolve=sentinel, use_dns_cache=sentinel,
|
||||
family=0, ssl_context=None, local_addr=None, resolver=None,
|
||||
conn_timeout=None, keepalive_timeout=sentinel,
|
||||
force_close=False, limit=20,
|
||||
loop=None):
|
||||
super().__init__(conn_timeout=conn_timeout,
|
||||
keepalive_timeout=keepalive_timeout,
|
||||
force_close=force_close, limit=limit, loop=loop)
|
||||
|
||||
if not verify_ssl and ssl_context is not None:
|
||||
raise ValueError(
|
||||
"Either disable ssl certificate validation by "
|
||||
"verify_ssl=False or specify ssl_context, not both.")
|
||||
|
||||
self._verify_ssl = verify_ssl
|
||||
|
||||
if fingerprint:
|
||||
digestlen = len(fingerprint)
|
||||
hashfunc = HASHFUNC_BY_DIGESTLEN.get(digestlen)
|
||||
if not hashfunc:
|
||||
raise ValueError('fingerprint has invalid length')
|
||||
self._hashfunc = hashfunc
|
||||
self._fingerprint = fingerprint
|
||||
|
||||
if resolve is not sentinel:
|
||||
warnings.warn(("resolve parameter is deprecated, "
|
||||
"use use_dns_cache instead"),
|
||||
DeprecationWarning, stacklevel=2)
|
||||
|
||||
if use_dns_cache is not sentinel and resolve is not sentinel:
|
||||
if use_dns_cache != resolve:
|
||||
raise ValueError("use_dns_cache must agree with resolve")
|
||||
_use_dns_cache = use_dns_cache
|
||||
elif use_dns_cache is not sentinel:
|
||||
_use_dns_cache = use_dns_cache
|
||||
elif resolve is not sentinel:
|
||||
_use_dns_cache = resolve
|
||||
else:
|
||||
_use_dns_cache = True
|
||||
|
||||
if resolver is None:
|
||||
resolver = DefaultResolver(loop=self._loop)
|
||||
self._resolver = resolver
|
||||
|
||||
self._use_dns_cache = _use_dns_cache
|
||||
self._cached_hosts = {}
|
||||
self._ssl_context = ssl_context
|
||||
self._family = family
|
||||
self._local_addr = local_addr
|
||||
|
||||
@property
|
||||
def verify_ssl(self):
|
||||
"""Do check for ssl certifications?"""
|
||||
return self._verify_ssl
|
||||
|
||||
@property
|
||||
def fingerprint(self):
|
||||
"""Expected ssl certificate fingerprint."""
|
||||
return self._fingerprint
|
||||
|
||||
@property
|
||||
def ssl_context(self):
|
||||
"""SSLContext instance for https requests.
|
||||
|
||||
Lazy property, creates context on demand.
|
||||
"""
|
||||
if self._ssl_context is None:
|
||||
if not self._verify_ssl:
|
||||
sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
|
||||
sslcontext.options |= ssl.OP_NO_SSLv2
|
||||
sslcontext.options |= ssl.OP_NO_SSLv3
|
||||
sslcontext.options |= _SSL_OP_NO_COMPRESSION
|
||||
sslcontext.set_default_verify_paths()
|
||||
else:
|
||||
sslcontext = ssl.create_default_context()
|
||||
self._ssl_context = sslcontext
|
||||
return self._ssl_context
|
||||
|
||||
@property
|
||||
def family(self):
|
||||
"""Socket family like AF_INET."""
|
||||
return self._family
|
||||
|
||||
@property
|
||||
def use_dns_cache(self):
|
||||
"""True if local DNS caching is enabled."""
|
||||
return self._use_dns_cache
|
||||
|
||||
@property
|
||||
def cached_hosts(self):
|
||||
"""Read-only dict of cached DNS record."""
|
||||
return MappingProxyType(self._cached_hosts)
|
||||
|
||||
def clear_dns_cache(self, host=None, port=None):
|
||||
"""Remove specified host/port or clear all dns local cache."""
|
||||
if host is not None and port is not None:
|
||||
self._cached_hosts.pop((host, port), None)
|
||||
elif host is not None or port is not None:
|
||||
raise ValueError("either both host and port "
|
||||
"or none of them are allowed")
|
||||
else:
|
||||
self._cached_hosts.clear()
|
||||
|
||||
@property
|
||||
def resolve(self):
|
||||
"""Do DNS lookup for host name?"""
|
||||
warnings.warn((".resolve property is deprecated, "
|
||||
"use .dns_cache instead"),
|
||||
DeprecationWarning, stacklevel=2)
|
||||
return self.use_dns_cache
|
||||
|
||||
@property
|
||||
def resolved_hosts(self):
|
||||
"""The dict of (host, port) -> (ipaddr, port) pairs."""
|
||||
warnings.warn((".resolved_hosts property is deprecated, "
|
||||
"use .cached_hosts instead"),
|
||||
DeprecationWarning, stacklevel=2)
|
||||
return self.cached_hosts
|
||||
|
||||
def clear_resolved_hosts(self, host=None, port=None):
|
||||
"""Remove specified host/port or clear all resolve cache."""
|
||||
warnings.warn((".clear_resolved_hosts() is deprecated, "
|
||||
"use .clear_dns_cache() instead"),
|
||||
DeprecationWarning, stacklevel=2)
|
||||
if host is not None and port is not None:
|
||||
self.clear_dns_cache(host, port)
|
||||
else:
|
||||
self.clear_dns_cache()
|
||||
|
||||
@asyncio.coroutine
|
||||
def _resolve_host(self, host, port):
|
||||
if is_ip_address(host):
|
||||
return [{'hostname': host, 'host': host, 'port': port,
|
||||
'family': self._family, 'proto': 0, 'flags': 0}]
|
||||
|
||||
if self._use_dns_cache:
|
||||
key = (host, port)
|
||||
|
||||
if key not in self._cached_hosts:
|
||||
self._cached_hosts[key] = yield from \
|
||||
self._resolver.resolve(host, port, family=self._family)
|
||||
|
||||
return self._cached_hosts[key]
|
||||
else:
|
||||
res = yield from self._resolver.resolve(
|
||||
host, port, family=self._family)
|
||||
return res
|
||||
|
||||
@asyncio.coroutine
|
||||
def _create_connection(self, req):
|
||||
"""Create connection.
|
||||
|
||||
Has same keyword arguments as BaseEventLoop.create_connection.
|
||||
"""
|
||||
if req.proxy:
|
||||
transport, proto = yield from self._create_proxy_connection(req)
|
||||
else:
|
||||
transport, proto = yield from self._create_direct_connection(req)
|
||||
|
||||
return transport, proto
|
||||
|
||||
@asyncio.coroutine
|
||||
def _create_direct_connection(self, req):
|
||||
if req.ssl:
|
||||
sslcontext = self.ssl_context
|
||||
else:
|
||||
sslcontext = None
|
||||
|
||||
hosts = yield from self._resolve_host(req.host, req.port)
|
||||
exc = None
|
||||
|
||||
for hinfo in hosts:
|
||||
try:
|
||||
host = hinfo['host']
|
||||
port = hinfo['port']
|
||||
transp, proto = yield from self._loop.create_connection(
|
||||
self._factory, host, port,
|
||||
ssl=sslcontext, family=hinfo['family'],
|
||||
proto=hinfo['proto'], flags=hinfo['flags'],
|
||||
server_hostname=hinfo['hostname'] if sslcontext else None,
|
||||
local_addr=self._local_addr)
|
||||
has_cert = transp.get_extra_info('sslcontext')
|
||||
if has_cert and self._fingerprint:
|
||||
sock = transp.get_extra_info('socket')
|
||||
if not hasattr(sock, 'getpeercert'):
|
||||
# Workaround for asyncio 3.5.0
|
||||
# Starting from 3.5.1 version
|
||||
# there is 'ssl_object' extra info in transport
|
||||
sock = transp._ssl_protocol._sslpipe.ssl_object
|
||||
# gives DER-encoded cert as a sequence of bytes (or None)
|
||||
cert = sock.getpeercert(binary_form=True)
|
||||
assert cert
|
||||
got = self._hashfunc(cert).digest()
|
||||
expected = self._fingerprint
|
||||
if got != expected:
|
||||
transp.close()
|
||||
raise FingerprintMismatch(expected, got, host, port)
|
||||
return transp, proto
|
||||
except OSError as e:
|
||||
exc = e
|
||||
else:
|
||||
raise ClientOSError(exc.errno,
|
||||
'Can not connect to %s:%s [%s]' %
|
||||
(req.host, req.port, exc.strerror)) from exc
|
||||
|
||||
@asyncio.coroutine
|
||||
def _create_proxy_connection(self, req):
|
||||
proxy_req = ClientRequest(
|
||||
hdrs.METH_GET, req.proxy,
|
||||
headers={hdrs.HOST: req.host},
|
||||
auth=req.proxy_auth,
|
||||
loop=self._loop)
|
||||
try:
|
||||
# create connection to proxy server
|
||||
transport, proto = yield from self._create_direct_connection(
|
||||
proxy_req)
|
||||
except OSError as exc:
|
||||
raise ProxyConnectionError(*exc.args) from exc
|
||||
|
||||
if not req.ssl:
|
||||
req.path = '{scheme}://{host}{path}'.format(scheme=req.scheme,
|
||||
host=req.netloc,
|
||||
path=req.path)
|
||||
if hdrs.AUTHORIZATION in proxy_req.headers:
|
||||
auth = proxy_req.headers[hdrs.AUTHORIZATION]
|
||||
del proxy_req.headers[hdrs.AUTHORIZATION]
|
||||
if not req.ssl:
|
||||
req.headers[hdrs.PROXY_AUTHORIZATION] = auth
|
||||
else:
|
||||
proxy_req.headers[hdrs.PROXY_AUTHORIZATION] = auth
|
||||
|
||||
if req.ssl:
|
||||
# For HTTPS requests over HTTP proxy
|
||||
# we must notify proxy to tunnel connection
|
||||
# so we send CONNECT command:
|
||||
# CONNECT www.python.org:443 HTTP/1.1
|
||||
# Host: www.python.org
|
||||
#
|
||||
# next we must do TLS handshake and so on
|
||||
# to do this we must wrap raw socket into secure one
|
||||
# asyncio handles this perfectly
|
||||
proxy_req.method = hdrs.METH_CONNECT
|
||||
proxy_req.path = '{}:{}'.format(req.host, req.port)
|
||||
key = (req.host, req.port, req.ssl)
|
||||
conn = Connection(self, key, proxy_req,
|
||||
transport, proto, self._loop)
|
||||
self._acquired[key].add(conn._transport)
|
||||
proxy_resp = proxy_req.send(conn.writer, conn.reader)
|
||||
try:
|
||||
resp = yield from proxy_resp.start(conn, True)
|
||||
except:
|
||||
proxy_resp.close()
|
||||
conn.close()
|
||||
raise
|
||||
else:
|
||||
conn.detach()
|
||||
if resp.status != 200:
|
||||
raise HttpProxyError(code=resp.status, message=resp.reason)
|
||||
rawsock = transport.get_extra_info('socket', default=None)
|
||||
if rawsock is None:
|
||||
raise RuntimeError(
|
||||
"Transport does not expose socket instance")
|
||||
transport.pause_reading()
|
||||
transport, proto = yield from self._loop.create_connection(
|
||||
self._factory, ssl=self.ssl_context, sock=rawsock,
|
||||
server_hostname=req.host)
|
||||
finally:
|
||||
proxy_resp.close()
|
||||
|
||||
return transport, proto
|
||||
|
||||
|
||||
class ProxyConnector(TCPConnector):
|
||||
"""Http Proxy connector.
|
||||
Deprecated, use ClientSession.request with proxy parameters.
|
||||
Is still here for backward compatibility.
|
||||
|
||||
proxy - Proxy URL address. Only HTTP proxy supported.
|
||||
proxy_auth - (optional) Proxy HTTP Basic Auth
|
||||
proxy_auth - aiohttp.helpers.BasicAuth
|
||||
conn_timeout - (optional) Connect timeout.
|
||||
keepalive_timeout - (optional) Keep-alive timeout.
|
||||
force_close - Set to True to force close and do reconnect
|
||||
after each request (and between redirects).
|
||||
limit - The limit of simultaneous connections to the same endpoint.
|
||||
loop - Optional event loop.
|
||||
|
||||
Usage:
|
||||
|
||||
>>> conn = ProxyConnector(proxy="http://some.proxy.com")
|
||||
>>> session = ClientSession(connector=conn)
|
||||
>>> resp = yield from session.get('http://python.org')
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, proxy, *, proxy_auth=None, force_close=True,
|
||||
conn_timeout=None, keepalive_timeout=sentinel,
|
||||
limit=20, loop=None):
|
||||
warnings.warn("ProxyConnector is deprecated, use "
|
||||
"client.get(url, proxy=proxy_url) instead",
|
||||
DeprecationWarning)
|
||||
super().__init__(force_close=force_close,
|
||||
conn_timeout=conn_timeout,
|
||||
keepalive_timeout=keepalive_timeout,
|
||||
limit=limit, loop=loop)
|
||||
self._proxy = proxy
|
||||
self._proxy_auth = proxy_auth
|
||||
|
||||
@property
|
||||
def proxy(self):
|
||||
return self._proxy
|
||||
|
||||
@property
|
||||
def proxy_auth(self):
|
||||
return self._proxy_auth
|
||||
|
||||
@asyncio.coroutine
|
||||
def _create_connection(self, req):
|
||||
"""
|
||||
Use TCPConnector _create_connection, to emulate old ProxyConnector.
|
||||
"""
|
||||
req.update_proxy(self._proxy, self._proxy_auth)
|
||||
transport, proto = yield from super()._create_connection(req)
|
||||
|
||||
return transport, proto
|
||||
|
||||
|
||||
class UnixConnector(BaseConnector):
|
||||
"""Unix socket connector.
|
||||
|
||||
path - Unix socket path.
|
||||
conn_timeout - (optional) Connect timeout.
|
||||
keepalive_timeout - (optional) Keep-alive timeout.
|
||||
force_close - Set to True to force close and do reconnect
|
||||
after each request (and between redirects).
|
||||
limit - The limit of simultaneous connections to the same endpoint.
|
||||
loop - Optional event loop.
|
||||
|
||||
Usage:
|
||||
|
||||
>>> conn = UnixConnector(path='/path/to/socket')
|
||||
>>> session = ClientSession(connector=conn)
|
||||
>>> resp = yield from session.get('http://python.org')
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, path, force_close=False, conn_timeout=None,
|
||||
keepalive_timeout=sentinel, limit=20, loop=None):
|
||||
super().__init__(force_close=force_close,
|
||||
conn_timeout=conn_timeout,
|
||||
keepalive_timeout=keepalive_timeout,
|
||||
limit=limit, loop=loop)
|
||||
self._path = path
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
"""Path to unix socket."""
|
||||
return self._path
|
||||
|
||||
@asyncio.coroutine
|
||||
def _create_connection(self, req):
|
||||
return (yield from self._loop.create_unix_connection(
|
||||
self._factory, self._path))
|
||||
|
|
@ -0,0 +1,290 @@
|
|||
import datetime
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping
|
||||
from http.cookies import Morsel, SimpleCookie
|
||||
from math import ceil
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
from .abc import AbstractCookieJar
|
||||
from .helpers import is_ip_address
|
||||
|
||||
|
||||
class CookieJar(AbstractCookieJar):
|
||||
"""Implements cookie storage adhering to RFC 6265."""
|
||||
|
||||
DATE_TOKENS_RE = re.compile(
|
||||
"[\x09\x20-\x2F\x3B-\x40\x5B-\x60\x7B-\x7E]*"
|
||||
"(?P<token>[\x00-\x08\x0A-\x1F\d:a-zA-Z\x7F-\xFF]+)")
|
||||
|
||||
DATE_HMS_TIME_RE = re.compile("(\d{1,2}):(\d{1,2}):(\d{1,2})")
|
||||
|
||||
DATE_DAY_OF_MONTH_RE = re.compile("(\d{1,2})")
|
||||
|
||||
DATE_MONTH_RE = re.compile("(jan)|(feb)|(mar)|(apr)|(may)|(jun)|(jul)|"
|
||||
"(aug)|(sep)|(oct)|(nov)|(dec)", re.I)
|
||||
|
||||
DATE_YEAR_RE = re.compile("(\d{2,4})")
|
||||
|
||||
MAX_TIME = 2051215261.0 # so far in future (2035-01-01)
|
||||
|
||||
def __init__(self, *, unsafe=False, loop=None):
|
||||
super().__init__(loop=loop)
|
||||
self._cookies = defaultdict(SimpleCookie)
|
||||
self._host_only_cookies = set()
|
||||
self._unsafe = unsafe
|
||||
self._next_expiration = ceil(self._loop.time())
|
||||
self._expirations = {}
|
||||
|
||||
def clear(self):
|
||||
self._cookies.clear()
|
||||
self._host_only_cookies.clear()
|
||||
self._next_expiration = ceil(self._loop.time())
|
||||
self._expirations.clear()
|
||||
|
||||
def __iter__(self):
|
||||
self._do_expiration()
|
||||
for val in self._cookies.values():
|
||||
yield from val.values()
|
||||
|
||||
def __len__(self):
|
||||
return sum(1 for i in self)
|
||||
|
||||
def _do_expiration(self):
|
||||
now = self._loop.time()
|
||||
if self._next_expiration > now:
|
||||
return
|
||||
if not self._expirations:
|
||||
return
|
||||
next_expiration = self.MAX_TIME
|
||||
to_del = []
|
||||
cookies = self._cookies
|
||||
expirations = self._expirations
|
||||
for (domain, name), when in expirations.items():
|
||||
if when < now:
|
||||
cookies[domain].pop(name, None)
|
||||
to_del.append((domain, name))
|
||||
self._host_only_cookies.discard((domain, name))
|
||||
else:
|
||||
next_expiration = min(next_expiration, when)
|
||||
for key in to_del:
|
||||
del expirations[key]
|
||||
|
||||
self._next_expiration = ceil(next_expiration)
|
||||
|
||||
def _expire_cookie(self, when, domain, name):
|
||||
self._next_expiration = min(self._next_expiration, when)
|
||||
self._expirations[(domain, name)] = when
|
||||
|
||||
def update_cookies(self, cookies, response_url=None):
|
||||
"""Update cookies."""
|
||||
url_parsed = urlsplit(response_url or "")
|
||||
hostname = url_parsed.hostname
|
||||
|
||||
if not self._unsafe and is_ip_address(hostname):
|
||||
# Don't accept cookies from IPs
|
||||
return
|
||||
|
||||
if isinstance(cookies, Mapping):
|
||||
cookies = cookies.items()
|
||||
|
||||
for name, cookie in cookies:
|
||||
if not isinstance(cookie, Morsel):
|
||||
tmp = SimpleCookie()
|
||||
tmp[name] = cookie
|
||||
cookie = tmp[name]
|
||||
|
||||
domain = cookie["domain"]
|
||||
|
||||
# ignore domains with trailing dots
|
||||
if domain.endswith('.'):
|
||||
domain = ""
|
||||
del cookie["domain"]
|
||||
|
||||
if not domain and hostname is not None:
|
||||
# Set the cookie's domain to the response hostname
|
||||
# and set its host-only-flag
|
||||
self._host_only_cookies.add((hostname, name))
|
||||
domain = cookie["domain"] = hostname
|
||||
|
||||
if domain.startswith("."):
|
||||
# Remove leading dot
|
||||
domain = domain[1:]
|
||||
cookie["domain"] = domain
|
||||
|
||||
if hostname and not self._is_domain_match(domain, hostname):
|
||||
# Setting cookies for different domains is not allowed
|
||||
continue
|
||||
|
||||
path = cookie["path"]
|
||||
if not path or not path.startswith("/"):
|
||||
# Set the cookie's path to the response path
|
||||
path = url_parsed.path
|
||||
if not path.startswith("/"):
|
||||
path = "/"
|
||||
else:
|
||||
# Cut everything from the last slash to the end
|
||||
path = "/" + path[1:path.rfind("/")]
|
||||
cookie["path"] = path
|
||||
|
||||
max_age = cookie["max-age"]
|
||||
if max_age:
|
||||
try:
|
||||
delta_seconds = int(max_age)
|
||||
self._expire_cookie(self._loop.time() + delta_seconds,
|
||||
domain, name)
|
||||
except ValueError:
|
||||
cookie["max-age"] = ""
|
||||
|
||||
else:
|
||||
expires = cookie["expires"]
|
||||
if expires:
|
||||
expire_time = self._parse_date(expires)
|
||||
if expire_time:
|
||||
self._expire_cookie(expire_time.timestamp(),
|
||||
domain, name)
|
||||
else:
|
||||
cookie["expires"] = ""
|
||||
|
||||
# use dict method because SimpleCookie class modifies value
|
||||
# before Python 3.4.3
|
||||
dict.__setitem__(self._cookies[domain], name, cookie)
|
||||
|
||||
self._do_expiration()
|
||||
|
||||
def filter_cookies(self, request_url):
|
||||
"""Returns this jar's cookies filtered by their attributes."""
|
||||
self._do_expiration()
|
||||
url_parsed = urlsplit(request_url)
|
||||
filtered = SimpleCookie()
|
||||
hostname = url_parsed.hostname or ""
|
||||
is_not_secure = url_parsed.scheme not in ("https", "wss")
|
||||
|
||||
for cookie in self:
|
||||
name = cookie.key
|
||||
domain = cookie["domain"]
|
||||
|
||||
# Send shared cookies
|
||||
if not domain:
|
||||
filtered[name] = cookie.value
|
||||
continue
|
||||
|
||||
if not self._unsafe and is_ip_address(hostname):
|
||||
continue
|
||||
|
||||
if (domain, name) in self._host_only_cookies:
|
||||
if domain != hostname:
|
||||
continue
|
||||
elif not self._is_domain_match(domain, hostname):
|
||||
continue
|
||||
|
||||
if not self._is_path_match(url_parsed.path, cookie["path"]):
|
||||
continue
|
||||
|
||||
if is_not_secure and cookie["secure"]:
|
||||
continue
|
||||
|
||||
filtered[name] = cookie.value
|
||||
|
||||
return filtered
|
||||
|
||||
@staticmethod
|
||||
def _is_domain_match(domain, hostname):
|
||||
"""Implements domain matching adhering to RFC 6265."""
|
||||
if hostname == domain:
|
||||
return True
|
||||
|
||||
if not hostname.endswith(domain):
|
||||
return False
|
||||
|
||||
non_matching = hostname[:-len(domain)]
|
||||
|
||||
if not non_matching.endswith("."):
|
||||
return False
|
||||
|
||||
return not is_ip_address(hostname)
|
||||
|
||||
@staticmethod
|
||||
def _is_path_match(req_path, cookie_path):
|
||||
"""Implements path matching adhering to RFC 6265."""
|
||||
if not req_path.startswith("/"):
|
||||
req_path = "/"
|
||||
|
||||
if req_path == cookie_path:
|
||||
return True
|
||||
|
||||
if not req_path.startswith(cookie_path):
|
||||
return False
|
||||
|
||||
if cookie_path.endswith("/"):
|
||||
return True
|
||||
|
||||
non_matching = req_path[len(cookie_path):]
|
||||
|
||||
return non_matching.startswith("/")
|
||||
|
||||
@classmethod
|
||||
def _parse_date(cls, date_str):
|
||||
"""Implements date string parsing adhering to RFC 6265."""
|
||||
if not date_str:
|
||||
return
|
||||
|
||||
found_time = False
|
||||
found_day = False
|
||||
found_month = False
|
||||
found_year = False
|
||||
|
||||
hour = minute = second = 0
|
||||
day = 0
|
||||
month = 0
|
||||
year = 0
|
||||
|
||||
for token_match in cls.DATE_TOKENS_RE.finditer(date_str):
|
||||
|
||||
token = token_match.group("token")
|
||||
|
||||
if not found_time:
|
||||
time_match = cls.DATE_HMS_TIME_RE.match(token)
|
||||
if time_match:
|
||||
found_time = True
|
||||
hour, minute, second = [
|
||||
int(s) for s in time_match.groups()]
|
||||
continue
|
||||
|
||||
if not found_day:
|
||||
day_match = cls.DATE_DAY_OF_MONTH_RE.match(token)
|
||||
if day_match:
|
||||
found_day = True
|
||||
day = int(day_match.group())
|
||||
continue
|
||||
|
||||
if not found_month:
|
||||
month_match = cls.DATE_MONTH_RE.match(token)
|
||||
if month_match:
|
||||
found_month = True
|
||||
month = month_match.lastindex
|
||||
continue
|
||||
|
||||
if not found_year:
|
||||
year_match = cls.DATE_YEAR_RE.match(token)
|
||||
if year_match:
|
||||
found_year = True
|
||||
year = int(year_match.group())
|
||||
|
||||
if 70 <= year <= 99:
|
||||
year += 1900
|
||||
elif 0 <= year <= 69:
|
||||
year += 2000
|
||||
|
||||
if False in (found_day, found_month, found_year, found_time):
|
||||
return
|
||||
|
||||
if not 1 <= day <= 31:
|
||||
return
|
||||
|
||||
if year < 1601 or hour > 23 or minute > 59 or second > 59:
|
||||
return
|
||||
|
||||
return datetime.datetime(year, month, day,
|
||||
hour, minute, second,
|
||||
tzinfo=datetime.timezone.utc)
|
||||
|
|
@ -0,0 +1,186 @@
|
|||
"""HTTP related errors."""
|
||||
|
||||
from asyncio import TimeoutError
|
||||
|
||||
__all__ = (
|
||||
'DisconnectedError', 'ClientDisconnectedError', 'ServerDisconnectedError',
|
||||
|
||||
'HttpProcessingError', 'BadHttpMessage',
|
||||
'HttpMethodNotAllowed', 'HttpBadRequest', 'HttpProxyError',
|
||||
'BadStatusLine', 'LineTooLong', 'InvalidHeader',
|
||||
|
||||
'ClientError', 'ClientHttpProcessingError', 'ClientConnectionError',
|
||||
'ClientOSError', 'ClientTimeoutError', 'ProxyConnectionError',
|
||||
'ClientRequestError', 'ClientResponseError',
|
||||
'FingerprintMismatch',
|
||||
|
||||
'WSServerHandshakeError', 'WSClientDisconnectedError')
|
||||
|
||||
|
||||
class DisconnectedError(Exception):
|
||||
"""Disconnected."""
|
||||
|
||||
|
||||
class ClientDisconnectedError(DisconnectedError):
|
||||
"""Client disconnected."""
|
||||
|
||||
|
||||
class ServerDisconnectedError(DisconnectedError):
|
||||
"""Server disconnected."""
|
||||
|
||||
|
||||
class WSClientDisconnectedError(ClientDisconnectedError):
|
||||
"""Deprecated."""
|
||||
|
||||
|
||||
class ClientError(Exception):
|
||||
"""Base class for client connection errors."""
|
||||
|
||||
|
||||
class ClientHttpProcessingError(ClientError):
|
||||
"""Base class for client HTTP processing errors."""
|
||||
|
||||
|
||||
class ClientRequestError(ClientHttpProcessingError):
|
||||
"""Connection error during sending request."""
|
||||
|
||||
|
||||
class ClientResponseError(ClientHttpProcessingError):
|
||||
"""Connection error during reading response."""
|
||||
|
||||
|
||||
class ClientConnectionError(ClientError):
|
||||
"""Base class for client socket errors."""
|
||||
|
||||
|
||||
class ClientOSError(ClientConnectionError, OSError):
|
||||
"""OSError error."""
|
||||
|
||||
|
||||
class ClientTimeoutError(ClientConnectionError, TimeoutError):
|
||||
"""Client connection timeout error."""
|
||||
|
||||
|
||||
class ProxyConnectionError(ClientConnectionError):
|
||||
"""Proxy connection error.
|
||||
|
||||
Raised in :class:`aiohttp.connector.ProxyConnector` if
|
||||
connection to proxy can not be established.
|
||||
"""
|
||||
|
||||
|
||||
class HttpProcessingError(Exception):
|
||||
"""HTTP error.
|
||||
|
||||
Shortcut for raising HTTP errors with custom code, message and headers.
|
||||
|
||||
:param int code: HTTP Error code.
|
||||
:param str message: (optional) Error message.
|
||||
:param list of [tuple] headers: (optional) Headers to be sent in response.
|
||||
"""
|
||||
|
||||
code = 0
|
||||
message = ''
|
||||
headers = None
|
||||
|
||||
def __init__(self, *, code=None, message='', headers=None):
|
||||
if code is not None:
|
||||
self.code = code
|
||||
self.headers = headers
|
||||
self.message = message
|
||||
|
||||
super().__init__("%s, message='%s'" % (self.code, message))
|
||||
|
||||
|
||||
class WSServerHandshakeError(HttpProcessingError):
|
||||
"""websocket server handshake error."""
|
||||
|
||||
|
||||
class HttpProxyError(HttpProcessingError):
|
||||
"""HTTP proxy error.
|
||||
|
||||
Raised in :class:`aiohttp.connector.ProxyConnector` if
|
||||
proxy responds with status other than ``200 OK``
|
||||
on ``CONNECT`` request.
|
||||
"""
|
||||
|
||||
|
||||
class BadHttpMessage(HttpProcessingError):
|
||||
|
||||
code = 400
|
||||
message = 'Bad Request'
|
||||
|
||||
def __init__(self, message, *, headers=None):
|
||||
super().__init__(message=message, headers=headers)
|
||||
|
||||
|
||||
class HttpMethodNotAllowed(HttpProcessingError):
|
||||
|
||||
code = 405
|
||||
message = 'Method Not Allowed'
|
||||
|
||||
|
||||
class HttpBadRequest(BadHttpMessage):
|
||||
|
||||
code = 400
|
||||
message = 'Bad Request'
|
||||
|
||||
|
||||
class ContentEncodingError(BadHttpMessage):
|
||||
"""Content encoding error."""
|
||||
|
||||
|
||||
class TransferEncodingError(BadHttpMessage):
|
||||
"""transfer encoding error."""
|
||||
|
||||
|
||||
class LineTooLong(BadHttpMessage):
|
||||
|
||||
def __init__(self, line, limit='Unknown'):
|
||||
super().__init__(
|
||||
"got more than %s bytes when reading %s" % (limit, line))
|
||||
|
||||
|
||||
class InvalidHeader(BadHttpMessage):
|
||||
|
||||
def __init__(self, hdr):
|
||||
if isinstance(hdr, bytes):
|
||||
hdr = hdr.decode('utf-8', 'surrogateescape')
|
||||
super().__init__('Invalid HTTP Header: {}'.format(hdr))
|
||||
self.hdr = hdr
|
||||
|
||||
|
||||
class BadStatusLine(BadHttpMessage):
|
||||
|
||||
def __init__(self, line=''):
|
||||
if not line:
|
||||
line = repr(line)
|
||||
self.args = line,
|
||||
self.line = line
|
||||
|
||||
|
||||
class LineLimitExceededParserError(HttpBadRequest):
|
||||
"""Line is too long."""
|
||||
|
||||
def __init__(self, msg, limit):
|
||||
super().__init__(msg)
|
||||
self.limit = limit
|
||||
|
||||
|
||||
class FingerprintMismatch(ClientConnectionError):
|
||||
"""SSL certificate does not match expected fingerprint."""
|
||||
|
||||
def __init__(self, expected, got, host, port):
|
||||
self.expected = expected
|
||||
self.got = got
|
||||
self.host = host
|
||||
self.port = port
|
||||
|
||||
def __repr__(self):
|
||||
return '<{} expected={} got={} host={} port={}>'.format(
|
||||
self.__class__.__name__, self.expected, self.got,
|
||||
self.host, self.port)
|
||||
|
||||
|
||||
class InvalidURL(Exception):
|
||||
"""Invalid URL."""
|
||||
|
|
@ -0,0 +1,168 @@
|
|||
import asyncio
|
||||
import mimetypes
|
||||
import os
|
||||
|
||||
from . import hdrs
|
||||
from .helpers import create_future
|
||||
from .web_reqrep import StreamResponse
|
||||
|
||||
|
||||
class FileSender:
|
||||
""""A helper that can be used to send files.
|
||||
"""
|
||||
|
||||
def __init__(self, *, resp_factory=StreamResponse, chunk_size=256*1024):
|
||||
self._response_factory = resp_factory
|
||||
self._chunk_size = chunk_size
|
||||
if bool(os.environ.get("AIOHTTP_NOSENDFILE")):
|
||||
self._sendfile = self._sendfile_fallback
|
||||
|
||||
def _sendfile_cb(self, fut, out_fd, in_fd, offset,
|
||||
count, loop, registered):
|
||||
if registered:
|
||||
loop.remove_writer(out_fd)
|
||||
if fut.cancelled():
|
||||
return
|
||||
try:
|
||||
n = os.sendfile(out_fd, in_fd, offset, count)
|
||||
if n == 0: # EOF reached
|
||||
n = count
|
||||
except (BlockingIOError, InterruptedError):
|
||||
n = 0
|
||||
except Exception as exc:
|
||||
fut.set_exception(exc)
|
||||
return
|
||||
|
||||
if n < count:
|
||||
loop.add_writer(out_fd, self._sendfile_cb, fut, out_fd, in_fd,
|
||||
offset + n, count - n, loop, True)
|
||||
else:
|
||||
fut.set_result(None)
|
||||
|
||||
@asyncio.coroutine
|
||||
def _sendfile_system(self, request, resp, fobj, count):
|
||||
# Write count bytes of fobj to resp using
|
||||
# the os.sendfile system call.
|
||||
#
|
||||
# request should be a aiohttp.web.Request instance.
|
||||
#
|
||||
# resp should be a aiohttp.web.StreamResponse instance.
|
||||
#
|
||||
# fobj should be an open file object.
|
||||
#
|
||||
# count should be an integer > 0.
|
||||
|
||||
transport = request.transport
|
||||
|
||||
if transport.get_extra_info("sslcontext"):
|
||||
yield from self._sendfile_fallback(request, resp, fobj, count)
|
||||
return
|
||||
|
||||
def _send_headers(resp_impl):
|
||||
# Durty hack required for
|
||||
# https://github.com/KeepSafe/aiohttp/issues/1093
|
||||
# don't send headers in sendfile mode
|
||||
pass
|
||||
|
||||
resp._send_headers = _send_headers
|
||||
|
||||
@asyncio.coroutine
|
||||
def write_eof():
|
||||
# Durty hack required for
|
||||
# https://github.com/KeepSafe/aiohttp/issues/1177
|
||||
# do nothing in write_eof
|
||||
pass
|
||||
|
||||
resp.write_eof = write_eof
|
||||
|
||||
resp_impl = yield from resp.prepare(request)
|
||||
|
||||
loop = request.app.loop
|
||||
# See https://github.com/KeepSafe/aiohttp/issues/958 for details
|
||||
|
||||
# send headers
|
||||
headers = ['HTTP/{0.major}.{0.minor} 200 OK\r\n'.format(
|
||||
request.version)]
|
||||
for hdr, val in resp.headers.items():
|
||||
headers.append('{}: {}\r\n'.format(hdr, val))
|
||||
headers.append('\r\n')
|
||||
|
||||
out_socket = transport.get_extra_info("socket").dup()
|
||||
out_socket.setblocking(False)
|
||||
out_fd = out_socket.fileno()
|
||||
in_fd = fobj.fileno()
|
||||
|
||||
bheaders = ''.join(headers).encode('utf-8')
|
||||
headers_length = len(bheaders)
|
||||
resp_impl.headers_length = headers_length
|
||||
resp_impl.output_length = headers_length + count
|
||||
|
||||
try:
|
||||
yield from loop.sock_sendall(out_socket, bheaders)
|
||||
fut = create_future(loop)
|
||||
self._sendfile_cb(fut, out_fd, in_fd, 0, count, loop, False)
|
||||
|
||||
yield from fut
|
||||
finally:
|
||||
out_socket.close()
|
||||
|
||||
@asyncio.coroutine
|
||||
def _sendfile_fallback(self, request, resp, fobj, count):
|
||||
# Mimic the _sendfile_system() method, but without using the
|
||||
# os.sendfile() system call. This should be used on systems
|
||||
# that don't support the os.sendfile().
|
||||
|
||||
# To avoid blocking the event loop & to keep memory usage low,
|
||||
# fobj is transferred in chunks controlled by the
|
||||
# constructor's chunk_size argument.
|
||||
|
||||
yield from resp.prepare(request)
|
||||
|
||||
chunk_size = self._chunk_size
|
||||
|
||||
chunk = fobj.read(chunk_size)
|
||||
while True:
|
||||
resp.write(chunk)
|
||||
yield from resp.drain()
|
||||
count = count - chunk_size
|
||||
if count <= 0:
|
||||
break
|
||||
chunk = fobj.read(count)
|
||||
|
||||
if hasattr(os, "sendfile"): # pragma: no cover
|
||||
_sendfile = _sendfile_system
|
||||
else: # pragma: no cover
|
||||
_sendfile = _sendfile_fallback
|
||||
|
||||
@asyncio.coroutine
|
||||
def send(self, request, filepath):
|
||||
"""Send filepath to client using request."""
|
||||
st = filepath.stat()
|
||||
|
||||
modsince = request.if_modified_since
|
||||
if modsince is not None and st.st_mtime <= modsince.timestamp():
|
||||
from .web_exceptions import HTTPNotModified
|
||||
raise HTTPNotModified()
|
||||
|
||||
ct, encoding = mimetypes.guess_type(str(filepath))
|
||||
if not ct:
|
||||
ct = 'application/octet-stream'
|
||||
|
||||
resp = self._response_factory()
|
||||
resp.content_type = ct
|
||||
if encoding:
|
||||
resp.headers[hdrs.CONTENT_ENCODING] = encoding
|
||||
resp.last_modified = st.st_mtime
|
||||
|
||||
file_size = st.st_size
|
||||
|
||||
resp.content_length = file_size
|
||||
resp.set_tcp_cork(True)
|
||||
try:
|
||||
with filepath.open('rb') as f:
|
||||
yield from self._sendfile(request, resp, f, file_size)
|
||||
|
||||
finally:
|
||||
resp.set_tcp_nodelay(True)
|
||||
|
||||
return resp
|
||||
|
|
@ -0,0 +1,91 @@
|
|||
"""HTTP Headers constants."""
|
||||
from multidict import istr
|
||||
|
||||
METH_ANY = '*'
|
||||
METH_CONNECT = 'CONNECT'
|
||||
METH_HEAD = 'HEAD'
|
||||
METH_GET = 'GET'
|
||||
METH_DELETE = 'DELETE'
|
||||
METH_OPTIONS = 'OPTIONS'
|
||||
METH_PATCH = 'PATCH'
|
||||
METH_POST = 'POST'
|
||||
METH_PUT = 'PUT'
|
||||
METH_TRACE = 'TRACE'
|
||||
|
||||
METH_ALL = {METH_CONNECT, METH_HEAD, METH_GET, METH_DELETE,
|
||||
METH_OPTIONS, METH_PATCH, METH_POST, METH_PUT, METH_TRACE}
|
||||
|
||||
|
||||
ACCEPT = istr('ACCEPT')
|
||||
ACCEPT_CHARSET = istr('ACCEPT-CHARSET')
|
||||
ACCEPT_ENCODING = istr('ACCEPT-ENCODING')
|
||||
ACCEPT_LANGUAGE = istr('ACCEPT-LANGUAGE')
|
||||
ACCEPT_RANGES = istr('ACCEPT-RANGES')
|
||||
ACCESS_CONTROL_MAX_AGE = istr('ACCESS-CONTROL-MAX-AGE')
|
||||
ACCESS_CONTROL_ALLOW_CREDENTIALS = istr('ACCESS-CONTROL-ALLOW-CREDENTIALS')
|
||||
ACCESS_CONTROL_ALLOW_HEADERS = istr('ACCESS-CONTROL-ALLOW-HEADERS')
|
||||
ACCESS_CONTROL_ALLOW_METHODS = istr('ACCESS-CONTROL-ALLOW-METHODS')
|
||||
ACCESS_CONTROL_ALLOW_ORIGIN = istr('ACCESS-CONTROL-ALLOW-ORIGIN')
|
||||
ACCESS_CONTROL_EXPOSE_HEADERS = istr('ACCESS-CONTROL-EXPOSE-HEADERS')
|
||||
ACCESS_CONTROL_REQUEST_HEADERS = istr('ACCESS-CONTROL-REQUEST-HEADERS')
|
||||
ACCESS_CONTROL_REQUEST_METHOD = istr('ACCESS-CONTROL-REQUEST-METHOD')
|
||||
AGE = istr('AGE')
|
||||
ALLOW = istr('ALLOW')
|
||||
AUTHORIZATION = istr('AUTHORIZATION')
|
||||
CACHE_CONTROL = istr('CACHE-CONTROL')
|
||||
CONNECTION = istr('CONNECTION')
|
||||
CONTENT_DISPOSITION = istr('CONTENT-DISPOSITION')
|
||||
CONTENT_ENCODING = istr('CONTENT-ENCODING')
|
||||
CONTENT_LANGUAGE = istr('CONTENT-LANGUAGE')
|
||||
CONTENT_LENGTH = istr('CONTENT-LENGTH')
|
||||
CONTENT_LOCATION = istr('CONTENT-LOCATION')
|
||||
CONTENT_MD5 = istr('CONTENT-MD5')
|
||||
CONTENT_RANGE = istr('CONTENT-RANGE')
|
||||
CONTENT_TRANSFER_ENCODING = istr('CONTENT-TRANSFER-ENCODING')
|
||||
CONTENT_TYPE = istr('CONTENT-TYPE')
|
||||
COOKIE = istr('COOKIE')
|
||||
DATE = istr('DATE')
|
||||
DESTINATION = istr('DESTINATION')
|
||||
DIGEST = istr('DIGEST')
|
||||
ETAG = istr('ETAG')
|
||||
EXPECT = istr('EXPECT')
|
||||
EXPIRES = istr('EXPIRES')
|
||||
FROM = istr('FROM')
|
||||
HOST = istr('HOST')
|
||||
IF_MATCH = istr('IF-MATCH')
|
||||
IF_MODIFIED_SINCE = istr('IF-MODIFIED-SINCE')
|
||||
IF_NONE_MATCH = istr('IF-NONE-MATCH')
|
||||
IF_RANGE = istr('IF-RANGE')
|
||||
IF_UNMODIFIED_SINCE = istr('IF-UNMODIFIED-SINCE')
|
||||
KEEP_ALIVE = istr('KEEP-ALIVE')
|
||||
LAST_EVENT_ID = istr('LAST-EVENT-ID')
|
||||
LAST_MODIFIED = istr('LAST-MODIFIED')
|
||||
LINK = istr('LINK')
|
||||
LOCATION = istr('LOCATION')
|
||||
MAX_FORWARDS = istr('MAX-FORWARDS')
|
||||
ORIGIN = istr('ORIGIN')
|
||||
PRAGMA = istr('PRAGMA')
|
||||
PROXY_AUTHENTICATE = istr('PROXY_AUTHENTICATE')
|
||||
PROXY_AUTHORIZATION = istr('PROXY-AUTHORIZATION')
|
||||
RANGE = istr('RANGE')
|
||||
REFERER = istr('REFERER')
|
||||
RETRY_AFTER = istr('RETRY-AFTER')
|
||||
SEC_WEBSOCKET_ACCEPT = istr('SEC-WEBSOCKET-ACCEPT')
|
||||
SEC_WEBSOCKET_VERSION = istr('SEC-WEBSOCKET-VERSION')
|
||||
SEC_WEBSOCKET_PROTOCOL = istr('SEC-WEBSOCKET-PROTOCOL')
|
||||
SEC_WEBSOCKET_KEY = istr('SEC-WEBSOCKET-KEY')
|
||||
SEC_WEBSOCKET_KEY1 = istr('SEC-WEBSOCKET-KEY1')
|
||||
SERVER = istr('SERVER')
|
||||
SET_COOKIE = istr('SET-COOKIE')
|
||||
TE = istr('TE')
|
||||
TRAILER = istr('TRAILER')
|
||||
TRANSFER_ENCODING = istr('TRANSFER-ENCODING')
|
||||
UPGRADE = istr('UPGRADE')
|
||||
WEBSOCKET = istr('WEBSOCKET')
|
||||
URI = istr('URI')
|
||||
USER_AGENT = istr('USER-AGENT')
|
||||
VARY = istr('VARY')
|
||||
VIA = istr('VIA')
|
||||
WANT_DIGEST = istr('WANT-DIGEST')
|
||||
WARNING = istr('WARNING')
|
||||
WWW_AUTHENTICATE = istr('WWW-AUTHENTICATE')
|
||||
|
|
@ -0,0 +1,534 @@
|
|||
"""Various helper functions"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import binascii
|
||||
import datetime
|
||||
import functools
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
import warnings
|
||||
from collections import namedtuple
|
||||
from pathlib import Path
|
||||
from urllib.parse import quote, urlencode
|
||||
|
||||
from async_timeout import timeout
|
||||
from multidict import MultiDict, MultiDictProxy
|
||||
|
||||
from . import hdrs
|
||||
from .errors import InvalidURL
|
||||
|
||||
try:
|
||||
from asyncio import ensure_future
|
||||
except ImportError:
|
||||
ensure_future = asyncio.async
|
||||
|
||||
|
||||
__all__ = ('BasicAuth', 'create_future', 'FormData', 'parse_mimetype',
|
||||
'Timeout', 'ensure_future')
|
||||
|
||||
|
||||
sentinel = object()
|
||||
Timeout = timeout
|
||||
|
||||
|
||||
class BasicAuth(namedtuple('BasicAuth', ['login', 'password', 'encoding'])):
|
||||
"""Http basic authentication helper.
|
||||
|
||||
:param str login: Login
|
||||
:param str password: Password
|
||||
:param str encoding: (optional) encoding ('latin1' by default)
|
||||
"""
|
||||
|
||||
def __new__(cls, login, password='', encoding='latin1'):
|
||||
if login is None:
|
||||
raise ValueError('None is not allowed as login value')
|
||||
|
||||
if password is None:
|
||||
raise ValueError('None is not allowed as password value')
|
||||
|
||||
return super().__new__(cls, login, password, encoding)
|
||||
|
||||
@classmethod
|
||||
def decode(cls, auth_header, encoding='latin1'):
|
||||
"""Create a :class:`BasicAuth` object from an ``Authorization`` HTTP
|
||||
header."""
|
||||
split = auth_header.strip().split(' ')
|
||||
if len(split) == 2:
|
||||
if split[0].strip().lower() != 'basic':
|
||||
raise ValueError('Unknown authorization method %s' % split[0])
|
||||
to_decode = split[1]
|
||||
else:
|
||||
raise ValueError('Could not parse authorization header.')
|
||||
|
||||
try:
|
||||
username, _, password = base64.b64decode(
|
||||
to_decode.encode('ascii')
|
||||
).decode(encoding).partition(':')
|
||||
except binascii.Error:
|
||||
raise ValueError('Invalid base64 encoding.')
|
||||
|
||||
return cls(username, password, encoding=encoding)
|
||||
|
||||
def encode(self):
|
||||
"""Encode credentials."""
|
||||
creds = ('%s:%s' % (self.login, self.password)).encode(self.encoding)
|
||||
return 'Basic %s' % base64.b64encode(creds).decode(self.encoding)
|
||||
|
||||
|
||||
def create_future(loop):
|
||||
"""Compatibility wrapper for the loop.create_future() call introduced in
|
||||
3.5.2."""
|
||||
if hasattr(loop, 'create_future'):
|
||||
return loop.create_future()
|
||||
else:
|
||||
return asyncio.Future(loop=loop)
|
||||
|
||||
|
||||
class FormData:
|
||||
"""Helper class for multipart/form-data and
|
||||
application/x-www-form-urlencoded body generation."""
|
||||
|
||||
def __init__(self, fields=()):
|
||||
from . import multipart
|
||||
self._writer = multipart.MultipartWriter('form-data')
|
||||
self._fields = []
|
||||
self._is_multipart = False
|
||||
|
||||
if isinstance(fields, dict):
|
||||
fields = list(fields.items())
|
||||
elif not isinstance(fields, (list, tuple)):
|
||||
fields = (fields,)
|
||||
self.add_fields(*fields)
|
||||
|
||||
@property
|
||||
def is_multipart(self):
|
||||
return self._is_multipart
|
||||
|
||||
@property
|
||||
def content_type(self):
|
||||
if self._is_multipart:
|
||||
return self._writer.headers[hdrs.CONTENT_TYPE]
|
||||
else:
|
||||
return 'application/x-www-form-urlencoded'
|
||||
|
||||
def add_field(self, name, value, *, content_type=None, filename=None,
|
||||
content_transfer_encoding=None):
|
||||
|
||||
if isinstance(value, io.IOBase):
|
||||
self._is_multipart = True
|
||||
elif isinstance(value, (bytes, bytearray, memoryview)):
|
||||
if filename is None and content_transfer_encoding is None:
|
||||
filename = name
|
||||
|
||||
type_options = MultiDict({'name': name})
|
||||
if filename is not None and not isinstance(filename, str):
|
||||
raise TypeError('filename must be an instance of str. '
|
||||
'Got: %s' % filename)
|
||||
if filename is None and isinstance(value, io.IOBase):
|
||||
filename = guess_filename(value, name)
|
||||
if filename is not None:
|
||||
type_options['filename'] = filename
|
||||
self._is_multipart = True
|
||||
|
||||
headers = {}
|
||||
if content_type is not None:
|
||||
if not isinstance(content_type, str):
|
||||
raise TypeError('content_type must be an instance of str. '
|
||||
'Got: %s' % content_type)
|
||||
headers[hdrs.CONTENT_TYPE] = content_type
|
||||
self._is_multipart = True
|
||||
if content_transfer_encoding is not None:
|
||||
if not isinstance(content_transfer_encoding, str):
|
||||
raise TypeError('content_transfer_encoding must be an instance'
|
||||
' of str. Got: %s' % content_transfer_encoding)
|
||||
headers[hdrs.CONTENT_TRANSFER_ENCODING] = content_transfer_encoding
|
||||
self._is_multipart = True
|
||||
|
||||
self._fields.append((type_options, headers, value))
|
||||
|
||||
def add_fields(self, *fields):
|
||||
to_add = list(fields)
|
||||
|
||||
while to_add:
|
||||
rec = to_add.pop(0)
|
||||
|
||||
if isinstance(rec, io.IOBase):
|
||||
k = guess_filename(rec, 'unknown')
|
||||
self.add_field(k, rec)
|
||||
|
||||
elif isinstance(rec, (MultiDictProxy, MultiDict)):
|
||||
to_add.extend(rec.items())
|
||||
|
||||
elif isinstance(rec, (list, tuple)) and len(rec) == 2:
|
||||
k, fp = rec
|
||||
self.add_field(k, fp)
|
||||
|
||||
else:
|
||||
raise TypeError('Only io.IOBase, multidict and (name, file) '
|
||||
'pairs allowed, use .add_field() for passing '
|
||||
'more complex parameters')
|
||||
|
||||
def _gen_form_urlencoded(self, encoding):
|
||||
# form data (x-www-form-urlencoded)
|
||||
data = []
|
||||
for type_options, _, value in self._fields:
|
||||
data.append((type_options['name'], value))
|
||||
|
||||
data = urlencode(data, doseq=True)
|
||||
return data.encode(encoding)
|
||||
|
||||
def _gen_form_data(self, *args, **kwargs):
|
||||
"""Encode a list of fields using the multipart/form-data MIME format"""
|
||||
for dispparams, headers, value in self._fields:
|
||||
part = self._writer.append(value, headers)
|
||||
if dispparams:
|
||||
part.set_content_disposition('form-data', **dispparams)
|
||||
# FIXME cgi.FieldStorage doesn't likes body parts with
|
||||
# Content-Length which were sent via chunked transfer encoding
|
||||
part.headers.pop(hdrs.CONTENT_LENGTH, None)
|
||||
yield from self._writer.serialize()
|
||||
|
||||
def __call__(self, encoding):
|
||||
if self._is_multipart:
|
||||
return self._gen_form_data(encoding)
|
||||
else:
|
||||
return self._gen_form_urlencoded(encoding)
|
||||
|
||||
|
||||
def parse_mimetype(mimetype):
|
||||
"""Parses a MIME type into its components.
|
||||
|
||||
:param str mimetype: MIME type
|
||||
|
||||
:returns: 4 element tuple for MIME type, subtype, suffix and parameters
|
||||
:rtype: tuple
|
||||
|
||||
Example:
|
||||
|
||||
>>> parse_mimetype('text/html; charset=utf-8')
|
||||
('text', 'html', '', {'charset': 'utf-8'})
|
||||
|
||||
"""
|
||||
if not mimetype:
|
||||
return '', '', '', {}
|
||||
|
||||
parts = mimetype.split(';')
|
||||
params = []
|
||||
for item in parts[1:]:
|
||||
if not item:
|
||||
continue
|
||||
key, value = item.split('=', 1) if '=' in item else (item, '')
|
||||
params.append((key.lower().strip(), value.strip(' "')))
|
||||
params = dict(params)
|
||||
|
||||
fulltype = parts[0].strip().lower()
|
||||
if fulltype == '*':
|
||||
fulltype = '*/*'
|
||||
|
||||
mtype, stype = fulltype.split('/', 1) \
|
||||
if '/' in fulltype else (fulltype, '')
|
||||
stype, suffix = stype.split('+', 1) if '+' in stype else (stype, '')
|
||||
|
||||
return mtype, stype, suffix, params
|
||||
|
||||
|
||||
def guess_filename(obj, default=None):
|
||||
name = getattr(obj, 'name', None)
|
||||
if name and name[0] != '<' and name[-1] != '>':
|
||||
return Path(name).name
|
||||
return default
|
||||
|
||||
|
||||
class AccessLogger:
|
||||
"""Helper object to log access.
|
||||
|
||||
Usage:
|
||||
log = logging.getLogger("spam")
|
||||
log_format = "%a %{User-Agent}i"
|
||||
access_logger = AccessLogger(log, log_format)
|
||||
access_logger.log(message, environ, response, transport, time)
|
||||
|
||||
Format:
|
||||
%% The percent sign
|
||||
%a Remote IP-address (IP-address of proxy if using reverse proxy)
|
||||
%t Time when the request was started to process
|
||||
%P The process ID of the child that serviced the request
|
||||
%r First line of request
|
||||
%s Response status code
|
||||
%b Size of response in bytes, excluding HTTP headers
|
||||
%O Bytes sent, including headers
|
||||
%T Time taken to serve the request, in seconds
|
||||
%Tf Time taken to serve the request, in seconds with floating fraction
|
||||
in .06f format
|
||||
%D Time taken to serve the request, in microseconds
|
||||
%{FOO}i request.headers['FOO']
|
||||
%{FOO}o response.headers['FOO']
|
||||
%{FOO}e os.environ['FOO']
|
||||
|
||||
"""
|
||||
|
||||
LOG_FORMAT = '%a %l %u %t "%r" %s %b "%{Referrer}i" "%{User-Agent}i"'
|
||||
FORMAT_RE = re.compile(r'%(\{([A-Za-z\-]+)\}([ioe])|[atPrsbOD]|Tf?)')
|
||||
CLEANUP_RE = re.compile(r'(%[^s])')
|
||||
_FORMAT_CACHE = {}
|
||||
|
||||
def __init__(self, logger, log_format=LOG_FORMAT):
|
||||
"""Initialise the logger.
|
||||
|
||||
:param logger: logger object to be used for logging
|
||||
:param log_format: apache compatible log format
|
||||
|
||||
"""
|
||||
self.logger = logger
|
||||
_compiled_format = AccessLogger._FORMAT_CACHE.get(log_format)
|
||||
if not _compiled_format:
|
||||
_compiled_format = self.compile_format(log_format)
|
||||
AccessLogger._FORMAT_CACHE[log_format] = _compiled_format
|
||||
self._log_format, self._methods = _compiled_format
|
||||
|
||||
def compile_format(self, log_format):
|
||||
"""Translate log_format into form usable by modulo formatting
|
||||
|
||||
All known atoms will be replaced with %s
|
||||
Also methods for formatting of those atoms will be added to
|
||||
_methods in apropriate order
|
||||
|
||||
For example we have log_format = "%a %t"
|
||||
This format will be translated to "%s %s"
|
||||
Also contents of _methods will be
|
||||
[self._format_a, self._format_t]
|
||||
These method will be called and results will be passed
|
||||
to translated string format.
|
||||
|
||||
Each _format_* method receive 'args' which is list of arguments
|
||||
given to self.log
|
||||
|
||||
Exceptions are _format_e, _format_i and _format_o methods which
|
||||
also receive key name (by functools.partial)
|
||||
|
||||
"""
|
||||
|
||||
log_format = log_format.replace("%l", "-")
|
||||
log_format = log_format.replace("%u", "-")
|
||||
methods = []
|
||||
|
||||
for atom in self.FORMAT_RE.findall(log_format):
|
||||
if atom[1] == '':
|
||||
methods.append(getattr(AccessLogger, '_format_%s' % atom[0]))
|
||||
else:
|
||||
m = getattr(AccessLogger, '_format_%s' % atom[2])
|
||||
methods.append(functools.partial(m, atom[1]))
|
||||
log_format = self.FORMAT_RE.sub(r'%s', log_format)
|
||||
log_format = self.CLEANUP_RE.sub(r'%\1', log_format)
|
||||
return log_format, methods
|
||||
|
||||
@staticmethod
|
||||
def _format_e(key, args):
|
||||
return (args[1] or {}).get(key, '-')
|
||||
|
||||
@staticmethod
|
||||
def _format_i(key, args):
|
||||
if not args[0]:
|
||||
return '(no headers)'
|
||||
# suboptimal, make istr(key) once
|
||||
return args[0].headers.get(key, '-')
|
||||
|
||||
@staticmethod
|
||||
def _format_o(key, args):
|
||||
# suboptimal, make istr(key) once
|
||||
return args[2].headers.get(key, '-')
|
||||
|
||||
@staticmethod
|
||||
def _format_a(args):
|
||||
if args[3] is None:
|
||||
return '-'
|
||||
peername = args[3].get_extra_info('peername')
|
||||
if isinstance(peername, (list, tuple)):
|
||||
return peername[0]
|
||||
else:
|
||||
return peername
|
||||
|
||||
@staticmethod
|
||||
def _format_t(args):
|
||||
return datetime.datetime.utcnow().strftime('[%d/%b/%Y:%H:%M:%S +0000]')
|
||||
|
||||
@staticmethod
|
||||
def _format_P(args):
|
||||
return "<%s>" % os.getpid()
|
||||
|
||||
@staticmethod
|
||||
def _format_r(args):
|
||||
msg = args[0]
|
||||
if not msg:
|
||||
return '-'
|
||||
return '%s %s HTTP/%s.%s' % tuple((msg.method,
|
||||
msg.path) + msg.version)
|
||||
|
||||
@staticmethod
|
||||
def _format_s(args):
|
||||
return args[2].status
|
||||
|
||||
@staticmethod
|
||||
def _format_b(args):
|
||||
return args[2].body_length
|
||||
|
||||
@staticmethod
|
||||
def _format_O(args):
|
||||
return args[2].output_length
|
||||
|
||||
@staticmethod
|
||||
def _format_T(args):
|
||||
return round(args[4])
|
||||
|
||||
@staticmethod
|
||||
def _format_Tf(args):
|
||||
return '%06f' % args[4]
|
||||
|
||||
@staticmethod
|
||||
def _format_D(args):
|
||||
return round(args[4] * 1000000)
|
||||
|
||||
def _format_line(self, args):
|
||||
return tuple(m(args) for m in self._methods)
|
||||
|
||||
def log(self, message, environ, response, transport, time):
|
||||
"""Log access.
|
||||
|
||||
:param message: Request object. May be None.
|
||||
:param environ: Environment dict. May be None.
|
||||
:param response: Response object.
|
||||
:param transport: Tansport object. May be None
|
||||
:param float time: Time taken to serve the request.
|
||||
"""
|
||||
try:
|
||||
self.logger.info(self._log_format % self._format_line(
|
||||
[message, environ, response, transport, time]))
|
||||
except Exception:
|
||||
self.logger.exception("Error in logging")
|
||||
|
||||
|
||||
class reify:
|
||||
"""Use as a class method decorator. It operates almost exactly like
|
||||
the Python `@property` decorator, but it puts the result of the
|
||||
method it decorates into the instance dict after the first call,
|
||||
effectively replacing the function it decorates with an instance
|
||||
variable. It is, in Python parlance, a data descriptor.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, wrapped):
|
||||
self.wrapped = wrapped
|
||||
try:
|
||||
self.__doc__ = wrapped.__doc__
|
||||
except: # pragma: no cover
|
||||
self.__doc__ = ""
|
||||
self.name = wrapped.__name__
|
||||
|
||||
def __get__(self, inst, owner, _sentinel=sentinel):
|
||||
if inst is None:
|
||||
return self
|
||||
val = inst.__dict__.get(self.name, _sentinel)
|
||||
if val is not _sentinel:
|
||||
return val
|
||||
val = self.wrapped(inst)
|
||||
inst.__dict__[self.name] = val
|
||||
return val
|
||||
|
||||
def __set__(self, inst, value):
|
||||
raise AttributeError("reified property is read-only")
|
||||
|
||||
|
||||
# The unreserved URI characters (RFC 3986)
|
||||
UNRESERVED_SET = frozenset(
|
||||
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" +
|
||||
"0123456789-._~")
|
||||
|
||||
|
||||
def unquote_unreserved(uri):
|
||||
"""Un-escape any percent-escape sequences in a URI that are unreserved
|
||||
characters. This leaves all reserved, illegal and non-ASCII bytes encoded.
|
||||
"""
|
||||
parts = uri.split('%')
|
||||
for i in range(1, len(parts)):
|
||||
h = parts[i][0:2]
|
||||
if len(h) == 2 and h.isalnum():
|
||||
try:
|
||||
c = chr(int(h, 16))
|
||||
except ValueError:
|
||||
raise InvalidURL("Invalid percent-escape sequence: '%s'" % h)
|
||||
|
||||
if c in UNRESERVED_SET:
|
||||
parts[i] = c + parts[i][2:]
|
||||
else:
|
||||
parts[i] = '%' + parts[i]
|
||||
else:
|
||||
parts[i] = '%' + parts[i]
|
||||
return ''.join(parts)
|
||||
|
||||
|
||||
def requote_uri(uri):
|
||||
"""Re-quote the given URI.
|
||||
|
||||
This function passes the given URI through an unquote/quote cycle to
|
||||
ensure that it is fully and consistently quoted.
|
||||
"""
|
||||
safe_with_percent = "!#$%&'()*+,/:;=?@[]~"
|
||||
safe_without_percent = "!#$&'()*+,/:;=?@[]~"
|
||||
try:
|
||||
# Unquote only the unreserved characters
|
||||
# Then quote only illegal characters (do not quote reserved,
|
||||
# unreserved, or '%')
|
||||
return quote(unquote_unreserved(uri), safe=safe_with_percent)
|
||||
except InvalidURL:
|
||||
# We couldn't unquote the given URI, so let's try quoting it, but
|
||||
# there may be unquoted '%'s in the URI. We need to make sure they're
|
||||
# properly quoted so they do not cause issues elsewhere.
|
||||
return quote(uri, safe=safe_without_percent)
|
||||
|
||||
|
||||
_ipv4_pattern = ('^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}'
|
||||
'(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$')
|
||||
_ipv6_pattern = (
|
||||
'^(?:(?:(?:[A-F0-9]{1,4}:){6}|(?=(?:[A-F0-9]{0,4}:){0,6}'
|
||||
'(?:[0-9]{1,3}\.){3}[0-9]{1,3}$)(([0-9A-F]{1,4}:){0,5}|:)'
|
||||
'((:[0-9A-F]{1,4}){1,5}:|:)|::(?:[A-F0-9]{1,4}:){5})'
|
||||
'(?:(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])\.){3}'
|
||||
'(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])|(?:[A-F0-9]{1,4}:){7}'
|
||||
'[A-F0-9]{1,4}|(?=(?:[A-F0-9]{0,4}:){0,7}[A-F0-9]{0,4}$)'
|
||||
'(([0-9A-F]{1,4}:){1,7}|:)((:[0-9A-F]{1,4}){1,7}|:)|(?:[A-F0-9]{1,4}:){7}'
|
||||
':|:(:[A-F0-9]{1,4}){7})$')
|
||||
_ipv4_regex = re.compile(_ipv4_pattern)
|
||||
_ipv6_regex = re.compile(_ipv6_pattern, flags=re.IGNORECASE)
|
||||
_ipv4_regexb = re.compile(_ipv4_pattern.encode('ascii'))
|
||||
_ipv6_regexb = re.compile(_ipv6_pattern.encode('ascii'), flags=re.IGNORECASE)
|
||||
|
||||
|
||||
def is_ip_address(host):
|
||||
if host is None:
|
||||
return False
|
||||
if isinstance(host, str):
|
||||
if _ipv4_regex.match(host) or _ipv6_regex.match(host):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
elif isinstance(host, (bytes, bytearray, memoryview)):
|
||||
if _ipv4_regexb.match(host) or _ipv6_regexb.match(host):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
else:
|
||||
raise TypeError("{} [{}] is not a str or bytes"
|
||||
.format(host, type(host)))
|
||||
|
||||
|
||||
def _get_kwarg(kwargs, old, new, value):
|
||||
val = kwargs.pop(old, sentinel)
|
||||
if val is not sentinel:
|
||||
warnings.warn("{} is deprecated, use {} instead".format(old, new),
|
||||
DeprecationWarning,
|
||||
stacklevel=3)
|
||||
return val
|
||||
else:
|
||||
return value
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
import logging
|
||||
|
||||
access_logger = logging.getLogger('aiohttp.access')
|
||||
client_logger = logging.getLogger('aiohttp.client')
|
||||
internal_logger = logging.getLogger('aiohttp.internal')
|
||||
server_logger = logging.getLogger('aiohttp.server')
|
||||
web_logger = logging.getLogger('aiohttp.web')
|
||||
ws_logger = logging.getLogger('aiohttp.websocket')
|
||||
|
|
@ -0,0 +1,973 @@
|
|||
import asyncio
|
||||
import base64
|
||||
import binascii
|
||||
import io
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import uuid
|
||||
import warnings
|
||||
import zlib
|
||||
from collections import Mapping, Sequence, deque
|
||||
from pathlib import Path
|
||||
from urllib.parse import parse_qsl, quote, unquote, urlencode
|
||||
|
||||
from multidict import CIMultiDict
|
||||
|
||||
from .hdrs import (CONTENT_DISPOSITION, CONTENT_ENCODING, CONTENT_LENGTH,
|
||||
CONTENT_TRANSFER_ENCODING, CONTENT_TYPE)
|
||||
from .helpers import parse_mimetype
|
||||
from .protocol import HttpParser
|
||||
|
||||
__all__ = ('MultipartReader', 'MultipartWriter',
|
||||
'BodyPartReader', 'BodyPartWriter',
|
||||
'BadContentDispositionHeader', 'BadContentDispositionParam',
|
||||
'parse_content_disposition', 'content_disposition_filename')
|
||||
|
||||
|
||||
CHAR = set(chr(i) for i in range(0, 128))
|
||||
CTL = set(chr(i) for i in range(0, 32)) | {chr(127), }
|
||||
SEPARATORS = {'(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']',
|
||||
'?', '=', '{', '}', ' ', chr(9)}
|
||||
TOKEN = CHAR ^ CTL ^ SEPARATORS
|
||||
|
||||
PY_35 = sys.version_info >= (3, 5)
|
||||
PY_352 = sys.version_info >= (3, 5, 2)
|
||||
|
||||
|
||||
class BadContentDispositionHeader(RuntimeWarning):
|
||||
pass
|
||||
|
||||
|
||||
class BadContentDispositionParam(RuntimeWarning):
|
||||
pass
|
||||
|
||||
|
||||
def parse_content_disposition(header):
|
||||
def is_token(string):
|
||||
return string and TOKEN >= set(string)
|
||||
|
||||
def is_quoted(string):
|
||||
return string[0] == string[-1] == '"'
|
||||
|
||||
def is_rfc5987(string):
|
||||
return is_token(string) and string.count("'") == 2
|
||||
|
||||
def is_extended_param(string):
|
||||
return string.endswith('*')
|
||||
|
||||
def is_continuous_param(string):
|
||||
pos = string.find('*') + 1
|
||||
if not pos:
|
||||
return False
|
||||
substring = string[pos:-1] if string.endswith('*') else string[pos:]
|
||||
return substring.isdigit()
|
||||
|
||||
def unescape(text, *, chars=''.join(map(re.escape, CHAR))):
|
||||
return re.sub('\\\\([{}])'.format(chars), '\\1', text)
|
||||
|
||||
if not header:
|
||||
return None, {}
|
||||
|
||||
disptype, *parts = header.split(';')
|
||||
if not is_token(disptype):
|
||||
warnings.warn(BadContentDispositionHeader(header))
|
||||
return None, {}
|
||||
|
||||
params = {}
|
||||
for item in parts:
|
||||
if '=' not in item:
|
||||
warnings.warn(BadContentDispositionHeader(header))
|
||||
return None, {}
|
||||
|
||||
key, value = item.split('=', 1)
|
||||
key = key.lower().strip()
|
||||
value = value.lstrip()
|
||||
|
||||
if key in params:
|
||||
warnings.warn(BadContentDispositionHeader(header))
|
||||
return None, {}
|
||||
|
||||
if not is_token(key):
|
||||
warnings.warn(BadContentDispositionParam(item))
|
||||
continue
|
||||
|
||||
elif is_continuous_param(key):
|
||||
if is_quoted(value):
|
||||
value = unescape(value[1:-1])
|
||||
elif not is_token(value):
|
||||
warnings.warn(BadContentDispositionParam(item))
|
||||
continue
|
||||
|
||||
elif is_extended_param(key):
|
||||
if is_rfc5987(value):
|
||||
encoding, _, value = value.split("'", 2)
|
||||
encoding = encoding or 'utf-8'
|
||||
else:
|
||||
warnings.warn(BadContentDispositionParam(item))
|
||||
continue
|
||||
|
||||
try:
|
||||
value = unquote(value, encoding, 'strict')
|
||||
except UnicodeDecodeError: # pragma: nocover
|
||||
warnings.warn(BadContentDispositionParam(item))
|
||||
continue
|
||||
|
||||
else:
|
||||
if is_quoted(value):
|
||||
value = unescape(value[1:-1].lstrip('\\/'))
|
||||
elif not is_token(value):
|
||||
warnings.warn(BadContentDispositionHeader(header))
|
||||
return None, {}
|
||||
|
||||
params[key] = value
|
||||
|
||||
return disptype.lower(), params
|
||||
|
||||
|
||||
def content_disposition_filename(params):
|
||||
if not params:
|
||||
return None
|
||||
elif 'filename*' in params:
|
||||
return params['filename*']
|
||||
elif 'filename' in params:
|
||||
return params['filename']
|
||||
else:
|
||||
parts = []
|
||||
fnparams = sorted((key, value)
|
||||
for key, value in params.items()
|
||||
if key.startswith('filename*'))
|
||||
for num, (key, value) in enumerate(fnparams):
|
||||
_, tail = key.split('*', 1)
|
||||
if tail.endswith('*'):
|
||||
tail = tail[:-1]
|
||||
if tail == str(num):
|
||||
parts.append(value)
|
||||
else:
|
||||
break
|
||||
if not parts:
|
||||
return None
|
||||
value = ''.join(parts)
|
||||
if "'" in value:
|
||||
encoding, _, value = value.split("'", 2)
|
||||
encoding = encoding or 'utf-8'
|
||||
return unquote(value, encoding, 'strict')
|
||||
return value
|
||||
|
||||
|
||||
class MultipartResponseWrapper(object):
|
||||
"""Wrapper around the :class:`MultipartBodyReader` to take care about
|
||||
underlying connection and close it when it needs in."""
|
||||
|
||||
def __init__(self, resp, stream):
|
||||
self.resp = resp
|
||||
self.stream = stream
|
||||
|
||||
if PY_35:
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
if not PY_352: # pragma: no cover
|
||||
__aiter__ = asyncio.coroutine(__aiter__)
|
||||
|
||||
@asyncio.coroutine
|
||||
def __anext__(self):
|
||||
part = yield from self.next()
|
||||
if part is None:
|
||||
raise StopAsyncIteration # NOQA
|
||||
return part
|
||||
|
||||
def at_eof(self):
|
||||
"""Returns ``True`` when all response data had been read.
|
||||
|
||||
:rtype: bool
|
||||
"""
|
||||
return self.resp.content.at_eof()
|
||||
|
||||
@asyncio.coroutine
|
||||
def next(self):
|
||||
"""Emits next multipart reader object."""
|
||||
item = yield from self.stream.next()
|
||||
if self.stream.at_eof():
|
||||
yield from self.release()
|
||||
return item
|
||||
|
||||
@asyncio.coroutine
|
||||
def release(self):
|
||||
"""Releases the connection gracefully, reading all the content
|
||||
to the void."""
|
||||
yield from self.resp.release()
|
||||
|
||||
|
||||
class BodyPartReader(object):
|
||||
"""Multipart reader for single body part."""
|
||||
|
||||
chunk_size = 8192
|
||||
|
||||
def __init__(self, boundary, headers, content):
|
||||
self.headers = headers
|
||||
self._boundary = boundary
|
||||
self._content = content
|
||||
self._at_eof = False
|
||||
length = self.headers.get(CONTENT_LENGTH, None)
|
||||
self._length = int(length) if length is not None else None
|
||||
self._read_bytes = 0
|
||||
self._unread = deque()
|
||||
self._prev_chunk = None
|
||||
self._content_eof = 0
|
||||
|
||||
if PY_35:
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
if not PY_352: # pragma: no cover
|
||||
__aiter__ = asyncio.coroutine(__aiter__)
|
||||
|
||||
@asyncio.coroutine
|
||||
def __anext__(self):
|
||||
part = yield from self.next()
|
||||
if part is None:
|
||||
raise StopAsyncIteration # NOQA
|
||||
return part
|
||||
|
||||
@asyncio.coroutine
|
||||
def next(self):
|
||||
item = yield from self.read()
|
||||
if not item:
|
||||
return None
|
||||
return item
|
||||
|
||||
@asyncio.coroutine
|
||||
def read(self, *, decode=False):
|
||||
"""Reads body part data.
|
||||
|
||||
:param bool decode: Decodes data following by encoding
|
||||
method from `Content-Encoding` header. If it missed
|
||||
data remains untouched
|
||||
|
||||
:rtype: bytearray
|
||||
"""
|
||||
if self._at_eof:
|
||||
return b''
|
||||
data = bytearray()
|
||||
if self._length is None:
|
||||
while not self._at_eof:
|
||||
data.extend((yield from self.readline()))
|
||||
else:
|
||||
while not self._at_eof:
|
||||
data.extend((yield from self.read_chunk(self.chunk_size)))
|
||||
if decode:
|
||||
return self.decode(data)
|
||||
return data
|
||||
|
||||
@asyncio.coroutine
|
||||
def read_chunk(self, size=chunk_size):
|
||||
"""Reads body part content chunk of the specified size.
|
||||
|
||||
:param int size: chunk size
|
||||
|
||||
:rtype: bytearray
|
||||
"""
|
||||
if self._at_eof:
|
||||
return b''
|
||||
if self._length:
|
||||
chunk = yield from self._read_chunk_from_length(size)
|
||||
else:
|
||||
chunk = yield from self._read_chunk_from_stream(size)
|
||||
|
||||
self._read_bytes += len(chunk)
|
||||
if self._read_bytes == self._length:
|
||||
self._at_eof = True
|
||||
if self._at_eof:
|
||||
assert b'\r\n' == (yield from self._content.readline()), \
|
||||
'reader did not read all the data or it is malformed'
|
||||
return chunk
|
||||
|
||||
@asyncio.coroutine
|
||||
def _read_chunk_from_length(self, size):
|
||||
"""Reads body part content chunk of the specified size.
|
||||
The body part must has `Content-Length` header with proper value.
|
||||
|
||||
:param int size: chunk size
|
||||
|
||||
:rtype: bytearray
|
||||
"""
|
||||
assert self._length is not None, \
|
||||
'Content-Length required for chunked read'
|
||||
chunk_size = min(size, self._length - self._read_bytes)
|
||||
chunk = yield from self._content.read(chunk_size)
|
||||
return chunk
|
||||
|
||||
@asyncio.coroutine
|
||||
def _read_chunk_from_stream(self, size):
|
||||
"""Reads content chunk of body part with unknown length.
|
||||
The `Content-Length` header for body part is not necessary.
|
||||
|
||||
:param int size: chunk size
|
||||
|
||||
:rtype: bytearray
|
||||
"""
|
||||
assert size >= len(self._boundary) + 2, \
|
||||
'Chunk size must be greater or equal than boundary length + 2'
|
||||
first_chunk = self._prev_chunk is None
|
||||
if first_chunk:
|
||||
self._prev_chunk = yield from self._content.read(size)
|
||||
|
||||
chunk = yield from self._content.read(size)
|
||||
self._content_eof += int(self._content.at_eof())
|
||||
assert self._content_eof < 3, "Reading after EOF"
|
||||
window = self._prev_chunk + chunk
|
||||
sub = b'\r\n' + self._boundary
|
||||
if first_chunk:
|
||||
idx = window.find(sub)
|
||||
else:
|
||||
idx = window.find(sub, max(0, len(self._prev_chunk) - len(sub)))
|
||||
if idx >= 0:
|
||||
# pushing boundary back to content
|
||||
self._content.unread_data(window[idx:])
|
||||
if size > idx:
|
||||
self._prev_chunk = self._prev_chunk[:idx]
|
||||
chunk = window[len(self._prev_chunk):idx]
|
||||
if not chunk:
|
||||
self._at_eof = True
|
||||
if 0 < len(chunk) < len(sub) and not self._content_eof:
|
||||
self._prev_chunk += chunk
|
||||
self._at_eof = False
|
||||
return b''
|
||||
result = self._prev_chunk
|
||||
self._prev_chunk = chunk
|
||||
return result
|
||||
|
||||
@asyncio.coroutine
|
||||
def readline(self):
|
||||
"""Reads body part by line by line.
|
||||
|
||||
:rtype: bytearray
|
||||
"""
|
||||
if self._at_eof:
|
||||
return b''
|
||||
|
||||
if self._unread:
|
||||
line = self._unread.popleft()
|
||||
else:
|
||||
line = yield from self._content.readline()
|
||||
|
||||
if line.startswith(self._boundary):
|
||||
# the very last boundary may not come with \r\n,
|
||||
# so set single rules for everyone
|
||||
sline = line.rstrip(b'\r\n')
|
||||
boundary = self._boundary
|
||||
last_boundary = self._boundary + b'--'
|
||||
# ensure that we read exactly the boundary, not something alike
|
||||
if sline == boundary or sline == last_boundary:
|
||||
self._at_eof = True
|
||||
self._unread.append(line)
|
||||
return b''
|
||||
else:
|
||||
next_line = yield from self._content.readline()
|
||||
if next_line.startswith(self._boundary):
|
||||
line = line[:-2] # strip CRLF but only once
|
||||
self._unread.append(next_line)
|
||||
|
||||
return line
|
||||
|
||||
@asyncio.coroutine
|
||||
def release(self):
|
||||
"""Like :meth:`read`, but reads all the data to the void.
|
||||
|
||||
:rtype: None
|
||||
"""
|
||||
if self._at_eof:
|
||||
return
|
||||
if self._length is None:
|
||||
while not self._at_eof:
|
||||
yield from self.readline()
|
||||
else:
|
||||
while not self._at_eof:
|
||||
yield from self.read_chunk(self.chunk_size)
|
||||
|
||||
@asyncio.coroutine
|
||||
def text(self, *, encoding=None):
|
||||
"""Like :meth:`read`, but assumes that body part contains text data.
|
||||
|
||||
:param str encoding: Custom text encoding. Overrides specified
|
||||
in charset param of `Content-Type` header
|
||||
|
||||
:rtype: str
|
||||
"""
|
||||
data = yield from self.read(decode=True)
|
||||
encoding = encoding or self.get_charset(default='latin1')
|
||||
return data.decode(encoding)
|
||||
|
||||
@asyncio.coroutine
|
||||
def json(self, *, encoding=None):
|
||||
"""Like :meth:`read`, but assumes that body parts contains JSON data.
|
||||
|
||||
:param str encoding: Custom JSON encoding. Overrides specified
|
||||
in charset param of `Content-Type` header
|
||||
"""
|
||||
data = yield from self.read(decode=True)
|
||||
if not data:
|
||||
return None
|
||||
encoding = encoding or self.get_charset(default='utf-8')
|
||||
return json.loads(data.decode(encoding))
|
||||
|
||||
@asyncio.coroutine
|
||||
def form(self, *, encoding=None):
|
||||
"""Like :meth:`read`, but assumes that body parts contains form
|
||||
urlencoded data.
|
||||
|
||||
:param str encoding: Custom form encoding. Overrides specified
|
||||
in charset param of `Content-Type` header
|
||||
"""
|
||||
data = yield from self.read(decode=True)
|
||||
if not data:
|
||||
return None
|
||||
encoding = encoding or self.get_charset(default='utf-8')
|
||||
return parse_qsl(data.rstrip().decode(encoding), encoding=encoding)
|
||||
|
||||
def at_eof(self):
|
||||
"""Returns ``True`` if the boundary was reached or
|
||||
``False`` otherwise.
|
||||
|
||||
:rtype: bool
|
||||
"""
|
||||
return self._at_eof
|
||||
|
||||
def decode(self, data):
|
||||
"""Decodes data according the specified `Content-Encoding`
|
||||
or `Content-Transfer-Encoding` headers value.
|
||||
|
||||
Supports ``gzip``, ``deflate`` and ``identity`` encodings for
|
||||
`Content-Encoding` header.
|
||||
|
||||
Supports ``base64``, ``quoted-printable``, ``binary`` encodings for
|
||||
`Content-Transfer-Encoding` header.
|
||||
|
||||
:param bytearray data: Data to decode.
|
||||
|
||||
:raises: :exc:`RuntimeError` - if encoding is unknown.
|
||||
|
||||
:rtype: bytes
|
||||
"""
|
||||
if CONTENT_TRANSFER_ENCODING in self.headers:
|
||||
data = self._decode_content_transfer(data)
|
||||
if CONTENT_ENCODING in self.headers:
|
||||
return self._decode_content(data)
|
||||
return data
|
||||
|
||||
def _decode_content(self, data):
|
||||
encoding = self.headers[CONTENT_ENCODING].lower()
|
||||
|
||||
if encoding == 'deflate':
|
||||
return zlib.decompress(data, -zlib.MAX_WBITS)
|
||||
elif encoding == 'gzip':
|
||||
return zlib.decompress(data, 16 + zlib.MAX_WBITS)
|
||||
elif encoding == 'identity':
|
||||
return data
|
||||
else:
|
||||
raise RuntimeError('unknown content encoding: {}'.format(encoding))
|
||||
|
||||
def _decode_content_transfer(self, data):
|
||||
encoding = self.headers[CONTENT_TRANSFER_ENCODING].lower()
|
||||
|
||||
if encoding == 'base64':
|
||||
return base64.b64decode(data)
|
||||
elif encoding == 'quoted-printable':
|
||||
return binascii.a2b_qp(data)
|
||||
elif encoding == 'binary':
|
||||
return data
|
||||
else:
|
||||
raise RuntimeError('unknown content transfer encoding: {}'
|
||||
''.format(encoding))
|
||||
|
||||
def get_charset(self, default=None):
|
||||
"""Returns charset parameter from ``Content-Type`` header or default.
|
||||
"""
|
||||
ctype = self.headers.get(CONTENT_TYPE, '')
|
||||
*_, params = parse_mimetype(ctype)
|
||||
return params.get('charset', default)
|
||||
|
||||
@property
|
||||
def filename(self):
|
||||
"""Returns filename specified in Content-Disposition header or ``None``
|
||||
if missed or header is malformed."""
|
||||
_, params = parse_content_disposition(
|
||||
self.headers.get(CONTENT_DISPOSITION))
|
||||
return content_disposition_filename(params)
|
||||
|
||||
|
||||
class MultipartReader(object):
|
||||
"""Multipart body reader."""
|
||||
|
||||
#: Response wrapper, used when multipart readers constructs from response.
|
||||
response_wrapper_cls = MultipartResponseWrapper
|
||||
#: Multipart reader class, used to handle multipart/* body parts.
|
||||
#: None points to type(self)
|
||||
multipart_reader_cls = None
|
||||
#: Body part reader class for non multipart/* content types.
|
||||
part_reader_cls = BodyPartReader
|
||||
|
||||
def __init__(self, headers, content):
|
||||
self.headers = headers
|
||||
self._boundary = ('--' + self._get_boundary()).encode()
|
||||
self._content = content
|
||||
self._last_part = None
|
||||
self._at_eof = False
|
||||
self._at_bof = True
|
||||
self._unread = []
|
||||
|
||||
if PY_35:
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
if not PY_352: # pragma: no cover
|
||||
__aiter__ = asyncio.coroutine(__aiter__)
|
||||
|
||||
@asyncio.coroutine
|
||||
def __anext__(self):
|
||||
part = yield from self.next()
|
||||
if part is None:
|
||||
raise StopAsyncIteration # NOQA
|
||||
return part
|
||||
|
||||
@classmethod
|
||||
def from_response(cls, response):
|
||||
"""Constructs reader instance from HTTP response.
|
||||
|
||||
:param response: :class:`~aiohttp.client.ClientResponse` instance
|
||||
"""
|
||||
obj = cls.response_wrapper_cls(response, cls(response.headers,
|
||||
response.content))
|
||||
return obj
|
||||
|
||||
def at_eof(self):
|
||||
"""Returns ``True`` if the final boundary was reached or
|
||||
``False`` otherwise.
|
||||
|
||||
:rtype: bool
|
||||
"""
|
||||
return self._at_eof
|
||||
|
||||
@asyncio.coroutine
|
||||
def next(self):
|
||||
"""Emits the next multipart body part."""
|
||||
# So, if we're at BOF, we need to skip till the boundary.
|
||||
if self._at_eof:
|
||||
return
|
||||
yield from self._maybe_release_last_part()
|
||||
if self._at_bof:
|
||||
yield from self._read_until_first_boundary()
|
||||
self._at_bof = False
|
||||
else:
|
||||
yield from self._read_boundary()
|
||||
if self._at_eof: # we just read the last boundary, nothing to do there
|
||||
return
|
||||
self._last_part = yield from self.fetch_next_part()
|
||||
return self._last_part
|
||||
|
||||
@asyncio.coroutine
|
||||
def release(self):
|
||||
"""Reads all the body parts to the void till the final boundary."""
|
||||
while not self._at_eof:
|
||||
item = yield from self.next()
|
||||
if item is None:
|
||||
break
|
||||
yield from item.release()
|
||||
|
||||
@asyncio.coroutine
|
||||
def fetch_next_part(self):
|
||||
"""Returns the next body part reader."""
|
||||
headers = yield from self._read_headers()
|
||||
return self._get_part_reader(headers)
|
||||
|
||||
def _get_part_reader(self, headers):
|
||||
"""Dispatches the response by the `Content-Type` header, returning
|
||||
suitable reader instance.
|
||||
|
||||
:param dict headers: Response headers
|
||||
"""
|
||||
ctype = headers.get(CONTENT_TYPE, '')
|
||||
mtype, *_ = parse_mimetype(ctype)
|
||||
if mtype == 'multipart':
|
||||
if self.multipart_reader_cls is None:
|
||||
return type(self)(headers, self._content)
|
||||
return self.multipart_reader_cls(headers, self._content)
|
||||
else:
|
||||
return self.part_reader_cls(self._boundary, headers, self._content)
|
||||
|
||||
def _get_boundary(self):
|
||||
mtype, *_, params = parse_mimetype(self.headers[CONTENT_TYPE])
|
||||
|
||||
assert mtype == 'multipart', 'multipart/* content type expected'
|
||||
|
||||
if 'boundary' not in params:
|
||||
raise ValueError('boundary missed for Content-Type: %s'
|
||||
% self.headers[CONTENT_TYPE])
|
||||
|
||||
boundary = params['boundary']
|
||||
if len(boundary) > 70:
|
||||
raise ValueError('boundary %r is too long (70 chars max)'
|
||||
% boundary)
|
||||
|
||||
return boundary
|
||||
|
||||
@asyncio.coroutine
|
||||
def _readline(self):
|
||||
if self._unread:
|
||||
return self._unread.pop()
|
||||
return (yield from self._content.readline())
|
||||
|
||||
@asyncio.coroutine
|
||||
def _read_until_first_boundary(self):
|
||||
while True:
|
||||
chunk = yield from self._readline()
|
||||
if chunk == b'':
|
||||
raise ValueError("Could not find starting boundary %r"
|
||||
% (self._boundary))
|
||||
chunk = chunk.rstrip()
|
||||
if chunk == self._boundary:
|
||||
return
|
||||
elif chunk == self._boundary + b'--':
|
||||
self._at_eof = True
|
||||
return
|
||||
|
||||
@asyncio.coroutine
|
||||
def _read_boundary(self):
|
||||
chunk = (yield from self._readline()).rstrip()
|
||||
if chunk == self._boundary:
|
||||
pass
|
||||
elif chunk == self._boundary + b'--':
|
||||
self._at_eof = True
|
||||
else:
|
||||
raise ValueError('Invalid boundary %r, expected %r'
|
||||
% (chunk, self._boundary))
|
||||
|
||||
@asyncio.coroutine
|
||||
def _read_headers(self):
|
||||
lines = [b'']
|
||||
while True:
|
||||
chunk = yield from self._content.readline()
|
||||
chunk = chunk.strip()
|
||||
lines.append(chunk)
|
||||
if not chunk:
|
||||
break
|
||||
parser = HttpParser()
|
||||
headers, *_ = parser.parse_headers(lines)
|
||||
return headers
|
||||
|
||||
@asyncio.coroutine
|
||||
def _maybe_release_last_part(self):
|
||||
"""Ensures that the last read body part is read completely."""
|
||||
if self._last_part is not None:
|
||||
if not self._last_part.at_eof():
|
||||
yield from self._last_part.release()
|
||||
self._unread.extend(self._last_part._unread)
|
||||
self._last_part = None
|
||||
|
||||
|
||||
class BodyPartWriter(object):
|
||||
"""Multipart writer for single body part."""
|
||||
|
||||
def __init__(self, obj, headers=None, *, chunk_size=8192):
|
||||
if headers is None:
|
||||
headers = CIMultiDict()
|
||||
elif not isinstance(headers, CIMultiDict):
|
||||
headers = CIMultiDict(headers)
|
||||
|
||||
self.obj = obj
|
||||
self.headers = headers
|
||||
self._chunk_size = chunk_size
|
||||
self._fill_headers_with_defaults()
|
||||
|
||||
self._serialize_map = {
|
||||
bytes: self._serialize_bytes,
|
||||
str: self._serialize_str,
|
||||
io.IOBase: self._serialize_io,
|
||||
MultipartWriter: self._serialize_multipart,
|
||||
('application', 'json'): self._serialize_json,
|
||||
('application', 'x-www-form-urlencoded'): self._serialize_form
|
||||
}
|
||||
|
||||
def _fill_headers_with_defaults(self):
|
||||
if CONTENT_TYPE not in self.headers:
|
||||
content_type = self._guess_content_type(self.obj)
|
||||
if content_type is not None:
|
||||
self.headers[CONTENT_TYPE] = content_type
|
||||
|
||||
if CONTENT_LENGTH not in self.headers:
|
||||
content_length = self._guess_content_length(self.obj)
|
||||
if content_length is not None:
|
||||
self.headers[CONTENT_LENGTH] = str(content_length)
|
||||
|
||||
if CONTENT_DISPOSITION not in self.headers:
|
||||
filename = self._guess_filename(self.obj)
|
||||
if filename is not None:
|
||||
self.set_content_disposition('attachment', filename=filename)
|
||||
|
||||
def _guess_content_length(self, obj):
|
||||
if isinstance(obj, bytes):
|
||||
return len(obj)
|
||||
elif isinstance(obj, str):
|
||||
*_, params = parse_mimetype(self.headers.get(CONTENT_TYPE))
|
||||
charset = params.get('charset', 'us-ascii')
|
||||
return len(obj.encode(charset))
|
||||
elif isinstance(obj, io.StringIO):
|
||||
*_, params = parse_mimetype(self.headers.get(CONTENT_TYPE))
|
||||
charset = params.get('charset', 'us-ascii')
|
||||
return len(obj.getvalue().encode(charset)) - obj.tell()
|
||||
elif isinstance(obj, io.BytesIO):
|
||||
return len(obj.getvalue()) - obj.tell()
|
||||
elif isinstance(obj, io.IOBase):
|
||||
try:
|
||||
return os.fstat(obj.fileno()).st_size - obj.tell()
|
||||
except (AttributeError, OSError):
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
|
||||
def _guess_content_type(self, obj, default='application/octet-stream'):
|
||||
if hasattr(obj, 'name'):
|
||||
name = getattr(obj, 'name')
|
||||
return mimetypes.guess_type(name)[0]
|
||||
elif isinstance(obj, (str, io.StringIO)):
|
||||
return 'text/plain; charset=utf-8'
|
||||
else:
|
||||
return default
|
||||
|
||||
def _guess_filename(self, obj):
|
||||
if isinstance(obj, io.IOBase):
|
||||
name = getattr(obj, 'name', None)
|
||||
if name is not None:
|
||||
return Path(name).name
|
||||
|
||||
def serialize(self):
|
||||
"""Yields byte chunks for body part."""
|
||||
|
||||
has_encoding = (
|
||||
CONTENT_ENCODING in self.headers and
|
||||
self.headers[CONTENT_ENCODING] != 'identity' or
|
||||
CONTENT_TRANSFER_ENCODING in self.headers
|
||||
)
|
||||
if has_encoding:
|
||||
# since we're following streaming approach which doesn't assumes
|
||||
# any intermediate buffers, we cannot calculate real content length
|
||||
# with the specified content encoding scheme. So, instead of lying
|
||||
# about content length and cause reading issues, we have to strip
|
||||
# this information.
|
||||
self.headers.pop(CONTENT_LENGTH, None)
|
||||
|
||||
if self.headers:
|
||||
yield b'\r\n'.join(
|
||||
b': '.join(map(lambda i: i.encode('latin1'), item))
|
||||
for item in self.headers.items()
|
||||
)
|
||||
yield b'\r\n\r\n'
|
||||
yield from self._maybe_encode_stream(self._serialize_obj())
|
||||
yield b'\r\n'
|
||||
|
||||
def _serialize_obj(self):
|
||||
obj = self.obj
|
||||
mtype, stype, *_ = parse_mimetype(self.headers.get(CONTENT_TYPE))
|
||||
serializer = self._serialize_map.get((mtype, stype))
|
||||
if serializer is not None:
|
||||
return serializer(obj)
|
||||
|
||||
for key in self._serialize_map:
|
||||
if not isinstance(key, tuple) and isinstance(obj, key):
|
||||
return self._serialize_map[key](obj)
|
||||
return self._serialize_default(obj)
|
||||
|
||||
def _serialize_bytes(self, obj):
|
||||
yield obj
|
||||
|
||||
def _serialize_str(self, obj):
|
||||
*_, params = parse_mimetype(self.headers.get(CONTENT_TYPE))
|
||||
yield obj.encode(params.get('charset', 'us-ascii'))
|
||||
|
||||
def _serialize_io(self, obj):
|
||||
while True:
|
||||
chunk = obj.read(self._chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
if isinstance(chunk, str):
|
||||
yield from self._serialize_str(chunk)
|
||||
else:
|
||||
yield from self._serialize_bytes(chunk)
|
||||
|
||||
def _serialize_multipart(self, obj):
|
||||
yield from obj.serialize()
|
||||
|
||||
def _serialize_json(self, obj):
|
||||
*_, params = parse_mimetype(self.headers.get(CONTENT_TYPE))
|
||||
yield json.dumps(obj).encode(params.get('charset', 'utf-8'))
|
||||
|
||||
def _serialize_form(self, obj):
|
||||
if isinstance(obj, Mapping):
|
||||
obj = list(obj.items())
|
||||
return self._serialize_str(urlencode(obj, doseq=True))
|
||||
|
||||
def _serialize_default(self, obj):
|
||||
raise TypeError('unknown body part type %r' % type(obj))
|
||||
|
||||
def _maybe_encode_stream(self, stream):
|
||||
if CONTENT_ENCODING in self.headers:
|
||||
stream = self._apply_content_encoding(stream)
|
||||
if CONTENT_TRANSFER_ENCODING in self.headers:
|
||||
stream = self._apply_content_transfer_encoding(stream)
|
||||
yield from stream
|
||||
|
||||
def _apply_content_encoding(self, stream):
|
||||
encoding = self.headers[CONTENT_ENCODING].lower()
|
||||
if encoding == 'identity':
|
||||
yield from stream
|
||||
elif encoding in ('deflate', 'gzip'):
|
||||
if encoding == 'gzip':
|
||||
zlib_mode = 16 + zlib.MAX_WBITS
|
||||
else:
|
||||
zlib_mode = -zlib.MAX_WBITS
|
||||
zcomp = zlib.compressobj(wbits=zlib_mode)
|
||||
for chunk in stream:
|
||||
yield zcomp.compress(chunk)
|
||||
else:
|
||||
yield zcomp.flush()
|
||||
else:
|
||||
raise RuntimeError('unknown content encoding: {}'
|
||||
''.format(encoding))
|
||||
|
||||
def _apply_content_transfer_encoding(self, stream):
|
||||
encoding = self.headers[CONTENT_TRANSFER_ENCODING].lower()
|
||||
if encoding == 'base64':
|
||||
buffer = bytearray()
|
||||
while True:
|
||||
if buffer:
|
||||
div, mod = divmod(len(buffer), 3)
|
||||
chunk, buffer = buffer[:div * 3], buffer[div * 3:]
|
||||
if chunk:
|
||||
yield base64.b64encode(chunk)
|
||||
chunk = next(stream, None)
|
||||
if not chunk:
|
||||
if buffer:
|
||||
yield base64.b64encode(buffer[:])
|
||||
return
|
||||
buffer.extend(chunk)
|
||||
elif encoding == 'quoted-printable':
|
||||
for chunk in stream:
|
||||
yield binascii.b2a_qp(chunk)
|
||||
elif encoding == 'binary':
|
||||
yield from stream
|
||||
else:
|
||||
raise RuntimeError('unknown content transfer encoding: {}'
|
||||
''.format(encoding))
|
||||
|
||||
def set_content_disposition(self, disptype, **params):
|
||||
"""Sets ``Content-Disposition`` header.
|
||||
|
||||
:param str disptype: Disposition type: inline, attachment, form-data.
|
||||
Should be valid extension token (see RFC 2183)
|
||||
:param dict params: Disposition params
|
||||
"""
|
||||
if not disptype or not (TOKEN > set(disptype)):
|
||||
raise ValueError('bad content disposition type {!r}'
|
||||
''.format(disptype))
|
||||
value = disptype
|
||||
if params:
|
||||
lparams = []
|
||||
for key, val in params.items():
|
||||
if not key or not (TOKEN > set(key)):
|
||||
raise ValueError('bad content disposition parameter'
|
||||
' {!r}={!r}'.format(key, val))
|
||||
qval = quote(val, '')
|
||||
lparams.append((key, '"%s"' % qval))
|
||||
if key == 'filename':
|
||||
lparams.append(('filename*', "utf-8''" + qval))
|
||||
sparams = '; '.join('='.join(pair) for pair in lparams)
|
||||
value = '; '.join((value, sparams))
|
||||
self.headers[CONTENT_DISPOSITION] = value
|
||||
|
||||
@property
|
||||
def filename(self):
|
||||
"""Returns filename specified in Content-Disposition header or ``None``
|
||||
if missed."""
|
||||
_, params = parse_content_disposition(
|
||||
self.headers.get(CONTENT_DISPOSITION))
|
||||
return content_disposition_filename(params)
|
||||
|
||||
|
||||
class MultipartWriter(object):
|
||||
"""Multipart body writer."""
|
||||
|
||||
#: Body part reader class for non multipart/* content types.
|
||||
part_writer_cls = BodyPartWriter
|
||||
|
||||
def __init__(self, subtype='mixed', boundary=None):
|
||||
boundary = boundary if boundary is not None else uuid.uuid4().hex
|
||||
try:
|
||||
boundary.encode('us-ascii')
|
||||
except UnicodeEncodeError:
|
||||
raise ValueError('boundary should contains ASCII only chars')
|
||||
self.headers = CIMultiDict()
|
||||
self.headers[CONTENT_TYPE] = 'multipart/{}; boundary="{}"'.format(
|
||||
subtype, boundary
|
||||
)
|
||||
self.parts = []
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
pass
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.parts)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.parts)
|
||||
|
||||
@property
|
||||
def boundary(self):
|
||||
*_, params = parse_mimetype(self.headers.get(CONTENT_TYPE))
|
||||
return params['boundary'].encode('us-ascii')
|
||||
|
||||
def append(self, obj, headers=None):
|
||||
"""Adds a new body part to multipart writer."""
|
||||
if isinstance(obj, self.part_writer_cls):
|
||||
if headers:
|
||||
obj.headers.update(headers)
|
||||
self.parts.append(obj)
|
||||
else:
|
||||
if not headers:
|
||||
headers = CIMultiDict()
|
||||
self.parts.append(self.part_writer_cls(obj, headers))
|
||||
return self.parts[-1]
|
||||
|
||||
def append_json(self, obj, headers=None):
|
||||
"""Helper to append JSON part."""
|
||||
if not headers:
|
||||
headers = CIMultiDict()
|
||||
headers[CONTENT_TYPE] = 'application/json'
|
||||
return self.append(obj, headers)
|
||||
|
||||
def append_form(self, obj, headers=None):
|
||||
"""Helper to append form urlencoded part."""
|
||||
if not headers:
|
||||
headers = CIMultiDict()
|
||||
headers[CONTENT_TYPE] = 'application/x-www-form-urlencoded'
|
||||
assert isinstance(obj, (Sequence, Mapping))
|
||||
return self.append(obj, headers)
|
||||
|
||||
def serialize(self):
|
||||
"""Yields multipart byte chunks."""
|
||||
if not self.parts:
|
||||
yield b''
|
||||
return
|
||||
|
||||
for part in self.parts:
|
||||
yield b'--' + self.boundary + b'\r\n'
|
||||
yield from part.serialize()
|
||||
else:
|
||||
yield b'--' + self.boundary + b'--\r\n'
|
||||
|
||||
yield b''
|
||||
|
|
@ -0,0 +1,495 @@
|
|||
"""Parser is a generator function (NOT coroutine).
|
||||
|
||||
Parser receives data with generator's send() method and sends data to
|
||||
destination DataQueue. Parser receives ParserBuffer and DataQueue objects
|
||||
as a parameters of the parser call, all subsequent send() calls should
|
||||
send bytes objects. Parser sends parsed `term` to destination buffer with
|
||||
DataQueue.feed_data() method. DataQueue object should implement two methods.
|
||||
feed_data() - parser uses this method to send parsed protocol data.
|
||||
feed_eof() - parser uses this method for indication of end of parsing stream.
|
||||
To indicate end of incoming data stream EofStream exception should be sent
|
||||
into parser. Parser could throw exceptions.
|
||||
|
||||
There are three stages:
|
||||
|
||||
* Data flow chain:
|
||||
|
||||
1. Application creates StreamParser object for storing incoming data.
|
||||
2. StreamParser creates ParserBuffer as internal data buffer.
|
||||
3. Application create parser and set it into stream buffer:
|
||||
|
||||
parser = HttpRequestParser()
|
||||
data_queue = stream.set_parser(parser)
|
||||
|
||||
3. At this stage StreamParser creates DataQueue object and passes it
|
||||
and internal buffer into parser as an arguments.
|
||||
|
||||
def set_parser(self, parser):
|
||||
output = DataQueue()
|
||||
self.p = parser(output, self._input)
|
||||
return output
|
||||
|
||||
4. Application waits data on output.read()
|
||||
|
||||
while True:
|
||||
msg = yield from output.read()
|
||||
...
|
||||
|
||||
* Data flow:
|
||||
|
||||
1. asyncio's transport reads data from socket and sends data to protocol
|
||||
with data_received() call.
|
||||
2. Protocol sends data to StreamParser with feed_data() call.
|
||||
3. StreamParser sends data into parser with generator's send() method.
|
||||
4. Parser processes incoming data and sends parsed data
|
||||
to DataQueue with feed_data()
|
||||
5. Application received parsed data from DataQueue.read()
|
||||
|
||||
* Eof:
|
||||
|
||||
1. StreamParser receives eof with feed_eof() call.
|
||||
2. StreamParser throws EofStream exception into parser.
|
||||
3. Then it unsets parser.
|
||||
|
||||
_SocketSocketTransport ->
|
||||
-> "protocol" -> StreamParser -> "parser" -> DataQueue <- "application"
|
||||
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import asyncio.streams
|
||||
import inspect
|
||||
import socket
|
||||
|
||||
from . import errors
|
||||
from .streams import EofStream, FlowControlDataQueue
|
||||
|
||||
__all__ = ('EofStream', 'StreamParser', 'StreamProtocol',
|
||||
'ParserBuffer', 'StreamWriter')
|
||||
|
||||
DEFAULT_LIMIT = 2 ** 16
|
||||
|
||||
if hasattr(socket, 'TCP_CORK'): # pragma: no cover
|
||||
CORK = socket.TCP_CORK
|
||||
elif hasattr(socket, 'TCP_NOPUSH'): # pragma: no cover
|
||||
CORK = socket.TCP_NOPUSH
|
||||
else: # pragma: no cover
|
||||
CORK = None
|
||||
|
||||
|
||||
class StreamParser:
|
||||
"""StreamParser manages incoming bytes stream and protocol parsers.
|
||||
|
||||
StreamParser uses ParserBuffer as internal buffer.
|
||||
|
||||
set_parser() sets current parser, it creates DataQueue object
|
||||
and sends ParserBuffer and DataQueue into parser generator.
|
||||
|
||||
unset_parser() sends EofStream into parser and then removes it.
|
||||
"""
|
||||
|
||||
def __init__(self, *, loop=None, buf=None,
|
||||
limit=DEFAULT_LIMIT, eof_exc_class=RuntimeError, **kwargs):
|
||||
self._loop = loop
|
||||
self._eof = False
|
||||
self._exception = None
|
||||
self._parser = None
|
||||
self._output = None
|
||||
self._limit = limit
|
||||
self._eof_exc_class = eof_exc_class
|
||||
self._buffer = buf if buf is not None else ParserBuffer()
|
||||
|
||||
self.paused = False
|
||||
self.transport = None
|
||||
|
||||
@property
|
||||
def output(self):
|
||||
return self._output
|
||||
|
||||
def set_transport(self, transport):
|
||||
assert transport is None or self.transport is None, \
|
||||
'Transport already set'
|
||||
self.transport = transport
|
||||
|
||||
def at_eof(self):
|
||||
return self._eof
|
||||
|
||||
def exception(self):
|
||||
return self._exception
|
||||
|
||||
def set_exception(self, exc):
|
||||
if isinstance(exc, ConnectionError):
|
||||
exc, old_exc = self._eof_exc_class(), exc
|
||||
exc.__cause__ = old_exc
|
||||
exc.__context__ = old_exc
|
||||
|
||||
self._exception = exc
|
||||
|
||||
if self._output is not None:
|
||||
self._output.set_exception(exc)
|
||||
self._output = None
|
||||
self._parser = None
|
||||
|
||||
def feed_data(self, data):
|
||||
"""send data to current parser or store in buffer."""
|
||||
if data is None:
|
||||
return
|
||||
|
||||
if self._parser:
|
||||
try:
|
||||
self._parser.send(data)
|
||||
except StopIteration:
|
||||
self._output.feed_eof()
|
||||
self._output = None
|
||||
self._parser = None
|
||||
except Exception as exc:
|
||||
self._output.set_exception(exc)
|
||||
self._output = None
|
||||
self._parser = None
|
||||
else:
|
||||
self._buffer.feed_data(data)
|
||||
|
||||
def feed_eof(self):
|
||||
"""send eof to all parsers, recursively."""
|
||||
if self._parser:
|
||||
try:
|
||||
if self._buffer:
|
||||
self._parser.send(b'')
|
||||
self._parser.throw(EofStream())
|
||||
except StopIteration:
|
||||
self._output.feed_eof()
|
||||
except EofStream:
|
||||
self._output.set_exception(self._eof_exc_class())
|
||||
except Exception as exc:
|
||||
self._output.set_exception(exc)
|
||||
|
||||
self._parser = None
|
||||
self._output = None
|
||||
|
||||
self._eof = True
|
||||
|
||||
def set_parser(self, parser, output=None):
|
||||
"""set parser to stream. return parser's DataQueue."""
|
||||
if self._parser:
|
||||
self.unset_parser()
|
||||
|
||||
if output is None:
|
||||
output = FlowControlDataQueue(
|
||||
self, limit=self._limit, loop=self._loop)
|
||||
|
||||
if self._exception:
|
||||
output.set_exception(self._exception)
|
||||
return output
|
||||
|
||||
# init parser
|
||||
p = parser(output, self._buffer)
|
||||
assert inspect.isgenerator(p), 'Generator is required'
|
||||
|
||||
try:
|
||||
# initialize parser with data and parser buffers
|
||||
next(p)
|
||||
except StopIteration:
|
||||
pass
|
||||
except Exception as exc:
|
||||
output.set_exception(exc)
|
||||
else:
|
||||
# parser still require more data
|
||||
self._parser = p
|
||||
self._output = output
|
||||
|
||||
if self._eof:
|
||||
self.unset_parser()
|
||||
|
||||
return output
|
||||
|
||||
def unset_parser(self):
|
||||
"""unset parser, send eof to the parser and then remove it."""
|
||||
if self._parser is None:
|
||||
return
|
||||
|
||||
# TODO: write test
|
||||
if self._loop.is_closed():
|
||||
# TODO: log something
|
||||
return
|
||||
|
||||
try:
|
||||
self._parser.throw(EofStream())
|
||||
except StopIteration:
|
||||
self._output.feed_eof()
|
||||
except EofStream:
|
||||
self._output.set_exception(self._eof_exc_class())
|
||||
except Exception as exc:
|
||||
self._output.set_exception(exc)
|
||||
finally:
|
||||
self._output = None
|
||||
self._parser = None
|
||||
|
||||
|
||||
class StreamWriter(asyncio.streams.StreamWriter):
|
||||
|
||||
def __init__(self, transport, protocol, reader, loop):
|
||||
self._transport = transport
|
||||
self._protocol = protocol
|
||||
self._reader = reader
|
||||
self._loop = loop
|
||||
self._tcp_nodelay = False
|
||||
self._tcp_cork = False
|
||||
self._socket = transport.get_extra_info('socket')
|
||||
|
||||
@property
|
||||
def tcp_nodelay(self):
|
||||
return self._tcp_nodelay
|
||||
|
||||
def set_tcp_nodelay(self, value):
|
||||
value = bool(value)
|
||||
if self._tcp_nodelay == value:
|
||||
return
|
||||
self._tcp_nodelay = value
|
||||
if self._socket is None:
|
||||
return
|
||||
if self._socket.family not in (socket.AF_INET, socket.AF_INET6):
|
||||
return
|
||||
if self._tcp_cork:
|
||||
self._tcp_cork = False
|
||||
if CORK is not None: # pragma: no branch
|
||||
self._socket.setsockopt(socket.IPPROTO_TCP, CORK, False)
|
||||
self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, value)
|
||||
|
||||
@property
|
||||
def tcp_cork(self):
|
||||
return self._tcp_cork
|
||||
|
||||
def set_tcp_cork(self, value):
|
||||
value = bool(value)
|
||||
if self._tcp_cork == value:
|
||||
return
|
||||
self._tcp_cork = value
|
||||
if self._socket is None:
|
||||
return
|
||||
if self._socket.family not in (socket.AF_INET, socket.AF_INET6):
|
||||
return
|
||||
if self._tcp_nodelay:
|
||||
self._socket.setsockopt(socket.IPPROTO_TCP,
|
||||
socket.TCP_NODELAY,
|
||||
False)
|
||||
self._tcp_nodelay = False
|
||||
if CORK is not None: # pragma: no branch
|
||||
self._socket.setsockopt(socket.IPPROTO_TCP, CORK, value)
|
||||
|
||||
|
||||
class StreamProtocol(asyncio.streams.FlowControlMixin, asyncio.Protocol):
|
||||
"""Helper class to adapt between Protocol and StreamReader."""
|
||||
|
||||
def __init__(self, *, loop=None, disconnect_error=RuntimeError, **kwargs):
|
||||
super().__init__(loop=loop)
|
||||
|
||||
self.transport = None
|
||||
self.writer = None
|
||||
self.reader = StreamParser(
|
||||
loop=loop, eof_exc_class=disconnect_error, **kwargs)
|
||||
|
||||
def is_connected(self):
|
||||
return self.transport is not None
|
||||
|
||||
def connection_made(self, transport):
|
||||
self.transport = transport
|
||||
self.reader.set_transport(transport)
|
||||
self.writer = StreamWriter(transport, self, self.reader, self._loop)
|
||||
|
||||
def connection_lost(self, exc):
|
||||
self.transport = self.writer = None
|
||||
self.reader.set_transport(None)
|
||||
|
||||
if exc is None:
|
||||
self.reader.feed_eof()
|
||||
else:
|
||||
self.reader.set_exception(exc)
|
||||
|
||||
super().connection_lost(exc)
|
||||
|
||||
def data_received(self, data):
|
||||
self.reader.feed_data(data)
|
||||
|
||||
def eof_received(self):
|
||||
self.reader.feed_eof()
|
||||
|
||||
|
||||
class _ParserBufferHelper:
|
||||
|
||||
__slots__ = ('exception', 'data')
|
||||
|
||||
def __init__(self, exception, data):
|
||||
self.exception = exception
|
||||
self.data = data
|
||||
|
||||
|
||||
class ParserBuffer:
|
||||
"""ParserBuffer is NOT a bytearray extension anymore.
|
||||
|
||||
ParserBuffer provides helper methods for parsers.
|
||||
"""
|
||||
__slots__ = ('_helper', '_writer', '_data')
|
||||
|
||||
def __init__(self, *args):
|
||||
self._data = bytearray(*args)
|
||||
self._helper = _ParserBufferHelper(None, self._data)
|
||||
self._writer = self._feed_data(self._helper)
|
||||
next(self._writer)
|
||||
|
||||
def exception(self):
|
||||
return self._helper.exception
|
||||
|
||||
def set_exception(self, exc):
|
||||
self._helper.exception = exc
|
||||
|
||||
@staticmethod
|
||||
def _feed_data(helper):
|
||||
while True:
|
||||
chunk = yield
|
||||
if chunk:
|
||||
helper.data.extend(chunk)
|
||||
|
||||
if helper.exception:
|
||||
raise helper.exception
|
||||
|
||||
def feed_data(self, data):
|
||||
if not self._helper.exception:
|
||||
self._writer.send(data)
|
||||
|
||||
def read(self, size):
|
||||
"""read() reads specified amount of bytes."""
|
||||
|
||||
while True:
|
||||
if self._helper.exception:
|
||||
raise self._helper.exception
|
||||
|
||||
if len(self._data) >= size:
|
||||
data = self._data[:size]
|
||||
del self._data[:size]
|
||||
return data
|
||||
|
||||
self._writer.send((yield))
|
||||
|
||||
def readsome(self, size=None):
|
||||
"""reads size of less amount of bytes."""
|
||||
|
||||
while True:
|
||||
if self._helper.exception:
|
||||
raise self._helper.exception
|
||||
|
||||
length = len(self._data)
|
||||
if length > 0:
|
||||
if size is None or length < size:
|
||||
size = length
|
||||
|
||||
data = self._data[:size]
|
||||
del self._data[:size]
|
||||
return data
|
||||
|
||||
self._writer.send((yield))
|
||||
|
||||
def readuntil(self, stop, limit=None):
|
||||
assert isinstance(stop, bytes) and stop, \
|
||||
'bytes is required: {!r}'.format(stop)
|
||||
|
||||
stop_len = len(stop)
|
||||
|
||||
while True:
|
||||
if self._helper.exception:
|
||||
raise self._helper.exception
|
||||
|
||||
pos = self._data.find(stop)
|
||||
if pos >= 0:
|
||||
end = pos + stop_len
|
||||
size = end
|
||||
if limit is not None and size > limit:
|
||||
raise errors.LineLimitExceededParserError(
|
||||
'Line is too long.', limit)
|
||||
|
||||
data = self._data[:size]
|
||||
del self._data[:size]
|
||||
return data
|
||||
else:
|
||||
if limit is not None and len(self._data) > limit:
|
||||
raise errors.LineLimitExceededParserError(
|
||||
'Line is too long.', limit)
|
||||
|
||||
self._writer.send((yield))
|
||||
|
||||
def wait(self, size):
|
||||
"""wait() waits for specified amount of bytes
|
||||
then returns data without changing internal buffer."""
|
||||
|
||||
while True:
|
||||
if self._helper.exception:
|
||||
raise self._helper.exception
|
||||
|
||||
if len(self._data) >= size:
|
||||
return self._data[:size]
|
||||
|
||||
self._writer.send((yield))
|
||||
|
||||
def waituntil(self, stop, limit=None):
|
||||
"""waituntil() reads until `stop` bytes sequence."""
|
||||
assert isinstance(stop, bytes) and stop, \
|
||||
'bytes is required: {!r}'.format(stop)
|
||||
|
||||
stop_len = len(stop)
|
||||
|
||||
while True:
|
||||
if self._helper.exception:
|
||||
raise self._helper.exception
|
||||
|
||||
pos = self._data.find(stop)
|
||||
if pos >= 0:
|
||||
size = pos + stop_len
|
||||
if limit is not None and size > limit:
|
||||
raise errors.LineLimitExceededParserError(
|
||||
'Line is too long. %s' % bytes(self._data), limit)
|
||||
|
||||
return self._data[:size]
|
||||
else:
|
||||
if limit is not None and len(self._data) > limit:
|
||||
raise errors.LineLimitExceededParserError(
|
||||
'Line is too long. %s' % bytes(self._data), limit)
|
||||
|
||||
self._writer.send((yield))
|
||||
|
||||
def skip(self, size):
|
||||
"""skip() skips specified amount of bytes."""
|
||||
|
||||
while len(self._data) < size:
|
||||
if self._helper.exception:
|
||||
raise self._helper.exception
|
||||
|
||||
self._writer.send((yield))
|
||||
|
||||
del self._data[:size]
|
||||
|
||||
def skipuntil(self, stop):
|
||||
"""skipuntil() reads until `stop` bytes sequence."""
|
||||
assert isinstance(stop, bytes) and stop, \
|
||||
'bytes is required: {!r}'.format(stop)
|
||||
|
||||
stop_len = len(stop)
|
||||
|
||||
while True:
|
||||
if self._helper.exception:
|
||||
raise self._helper.exception
|
||||
|
||||
stop_line = self._data.find(stop)
|
||||
if stop_line >= 0:
|
||||
size = stop_line + stop_len
|
||||
del self._data[:size]
|
||||
return
|
||||
|
||||
self._writer.send((yield))
|
||||
|
||||
def extend(self, data):
|
||||
self._data.extend(data)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._data)
|
||||
|
||||
def __bytes__(self):
|
||||
return bytes(self._data)
|
||||
|
|
@ -0,0 +1,916 @@
|
|||
"""Http related parsers and protocol."""
|
||||
|
||||
import collections
|
||||
import functools
|
||||
import http.server
|
||||
import re
|
||||
import string
|
||||
import sys
|
||||
import zlib
|
||||
from abc import ABC, abstractmethod
|
||||
from wsgiref.handlers import format_date_time
|
||||
|
||||
from multidict import CIMultiDict, istr
|
||||
|
||||
import aiohttp
|
||||
|
||||
from . import errors, hdrs
|
||||
from .helpers import reify
|
||||
from .log import internal_logger
|
||||
|
||||
__all__ = ('HttpMessage', 'Request', 'Response',
|
||||
'HttpVersion', 'HttpVersion10', 'HttpVersion11',
|
||||
'RawRequestMessage', 'RawResponseMessage',
|
||||
'HttpPrefixParser', 'HttpRequestParser', 'HttpResponseParser',
|
||||
'HttpPayloadParser')
|
||||
|
||||
ASCIISET = set(string.printable)
|
||||
METHRE = re.compile('[A-Z0-9$-_.]+')
|
||||
VERSRE = re.compile('HTTP/(\d+).(\d+)')
|
||||
HDRRE = re.compile(b'[\x00-\x1F\x7F()<>@,;:\[\]={} \t\\\\\"]')
|
||||
EOF_MARKER = object()
|
||||
EOL_MARKER = object()
|
||||
STATUS_LINE_READY = object()
|
||||
|
||||
RESPONSES = http.server.BaseHTTPRequestHandler.responses
|
||||
|
||||
HttpVersion = collections.namedtuple(
|
||||
'HttpVersion', ['major', 'minor'])
|
||||
HttpVersion10 = HttpVersion(1, 0)
|
||||
HttpVersion11 = HttpVersion(1, 1)
|
||||
|
||||
RawStatusLineMessage = collections.namedtuple(
|
||||
'RawStatusLineMessage', ['method', 'path', 'version'])
|
||||
|
||||
RawRequestMessage = collections.namedtuple(
|
||||
'RawRequestMessage',
|
||||
['method', 'path', 'version', 'headers', 'raw_headers',
|
||||
'should_close', 'compression'])
|
||||
|
||||
|
||||
RawResponseMessage = collections.namedtuple(
|
||||
'RawResponseMessage',
|
||||
['version', 'code', 'reason', 'headers', 'raw_headers',
|
||||
'should_close', 'compression'])
|
||||
|
||||
|
||||
class HttpParser:
|
||||
|
||||
def __init__(self, max_line_size=8190, max_headers=32768,
|
||||
max_field_size=8190):
|
||||
self.max_line_size = max_line_size
|
||||
self.max_headers = max_headers
|
||||
self.max_field_size = max_field_size
|
||||
|
||||
def parse_headers(self, lines):
|
||||
"""Parses RFC 5322 headers from a stream.
|
||||
|
||||
Line continuations are supported. Returns list of header name
|
||||
and value pairs. Header name is in upper case.
|
||||
"""
|
||||
close_conn = None
|
||||
encoding = None
|
||||
headers = CIMultiDict()
|
||||
raw_headers = []
|
||||
|
||||
lines_idx = 1
|
||||
line = lines[1]
|
||||
|
||||
while line:
|
||||
header_length = len(line)
|
||||
|
||||
# Parse initial header name : value pair.
|
||||
try:
|
||||
bname, bvalue = line.split(b':', 1)
|
||||
except ValueError:
|
||||
raise errors.InvalidHeader(line) from None
|
||||
|
||||
bname = bname.strip(b' \t').upper()
|
||||
if HDRRE.search(bname):
|
||||
raise errors.InvalidHeader(bname)
|
||||
|
||||
# next line
|
||||
lines_idx += 1
|
||||
line = lines[lines_idx]
|
||||
|
||||
# consume continuation lines
|
||||
continuation = line and line[0] in (32, 9) # (' ', '\t')
|
||||
|
||||
if continuation:
|
||||
bvalue = [bvalue]
|
||||
while continuation:
|
||||
header_length += len(line)
|
||||
if header_length > self.max_field_size:
|
||||
raise errors.LineTooLong(
|
||||
'limit request headers fields size')
|
||||
bvalue.append(line)
|
||||
|
||||
# next line
|
||||
lines_idx += 1
|
||||
line = lines[lines_idx]
|
||||
continuation = line[0] in (32, 9) # (' ', '\t')
|
||||
bvalue = b'\r\n'.join(bvalue)
|
||||
else:
|
||||
if header_length > self.max_field_size:
|
||||
raise errors.LineTooLong(
|
||||
'limit request headers fields size')
|
||||
|
||||
bvalue = bvalue.strip()
|
||||
|
||||
name = istr(bname.decode('utf-8', 'surrogateescape'))
|
||||
value = bvalue.decode('utf-8', 'surrogateescape')
|
||||
|
||||
# keep-alive and encoding
|
||||
if name == hdrs.CONNECTION:
|
||||
v = value.lower()
|
||||
if v == 'close':
|
||||
close_conn = True
|
||||
elif v == 'keep-alive':
|
||||
close_conn = False
|
||||
elif name == hdrs.CONTENT_ENCODING:
|
||||
enc = value.lower()
|
||||
if enc in ('gzip', 'deflate'):
|
||||
encoding = enc
|
||||
|
||||
headers.add(name, value)
|
||||
raw_headers.append((bname, bvalue))
|
||||
|
||||
return headers, raw_headers, close_conn, encoding
|
||||
|
||||
|
||||
class HttpPrefixParser:
|
||||
"""Waits for 'HTTP' prefix (non destructive)"""
|
||||
|
||||
def __init__(self, allowed_methods=()):
|
||||
self.allowed_methods = [m.upper() for m in allowed_methods]
|
||||
|
||||
def __call__(self, out, buf):
|
||||
raw_data = yield from buf.waituntil(b' ', 12)
|
||||
method = raw_data.decode('ascii', 'surrogateescape').strip()
|
||||
|
||||
# method
|
||||
method = method.upper()
|
||||
if not METHRE.match(method):
|
||||
raise errors.BadStatusLine(method)
|
||||
|
||||
# allowed method
|
||||
if self.allowed_methods and method not in self.allowed_methods:
|
||||
raise errors.HttpMethodNotAllowed(message=method)
|
||||
|
||||
out.feed_data(method, len(method))
|
||||
out.feed_eof()
|
||||
|
||||
|
||||
class HttpRequestParser(HttpParser):
|
||||
"""Read request status line. Exception errors.BadStatusLine
|
||||
could be raised in case of any errors in status line.
|
||||
Returns RawRequestMessage.
|
||||
"""
|
||||
|
||||
def __call__(self, out, buf):
|
||||
# read HTTP message (request line + headers)
|
||||
try:
|
||||
raw_data = yield from buf.readuntil(
|
||||
b'\r\n\r\n', self.max_headers)
|
||||
except errors.LineLimitExceededParserError as exc:
|
||||
raise errors.LineTooLong(exc.limit) from None
|
||||
|
||||
lines = raw_data.split(b'\r\n')
|
||||
|
||||
# request line
|
||||
line = lines[0].decode('utf-8', 'surrogateescape')
|
||||
try:
|
||||
method, path, version = line.split(None, 2)
|
||||
except ValueError:
|
||||
raise errors.BadStatusLine(line) from None
|
||||
|
||||
# method
|
||||
method = method.upper()
|
||||
if not METHRE.match(method):
|
||||
raise errors.BadStatusLine(method)
|
||||
|
||||
# version
|
||||
try:
|
||||
if version.startswith('HTTP/'):
|
||||
n1, n2 = version[5:].split('.', 1)
|
||||
version = HttpVersion(int(n1), int(n2))
|
||||
else:
|
||||
raise errors.BadStatusLine(version)
|
||||
except:
|
||||
raise errors.BadStatusLine(version)
|
||||
|
||||
# read headers
|
||||
headers, raw_headers, close, compression = self.parse_headers(lines)
|
||||
if close is None: # then the headers weren't set in the request
|
||||
if version <= HttpVersion10: # HTTP 1.0 must asks to not close
|
||||
close = True
|
||||
else: # HTTP 1.1 must ask to close.
|
||||
close = False
|
||||
|
||||
out.feed_data(
|
||||
RawRequestMessage(
|
||||
method, path, version, headers, raw_headers,
|
||||
close, compression),
|
||||
len(raw_data))
|
||||
out.feed_eof()
|
||||
|
||||
|
||||
class HttpResponseParser(HttpParser):
|
||||
"""Read response status line and headers.
|
||||
|
||||
BadStatusLine could be raised in case of any errors in status line.
|
||||
Returns RawResponseMessage"""
|
||||
|
||||
def __call__(self, out, buf):
|
||||
# read HTTP message (response line + headers)
|
||||
try:
|
||||
raw_data = yield from buf.readuntil(
|
||||
b'\r\n\r\n', self.max_line_size + self.max_headers)
|
||||
except errors.LineLimitExceededParserError as exc:
|
||||
raise errors.LineTooLong(exc.limit) from None
|
||||
|
||||
lines = raw_data.split(b'\r\n')
|
||||
|
||||
line = lines[0].decode('utf-8', 'surrogateescape')
|
||||
try:
|
||||
version, status = line.split(None, 1)
|
||||
except ValueError:
|
||||
raise errors.BadStatusLine(line) from None
|
||||
else:
|
||||
try:
|
||||
status, reason = status.split(None, 1)
|
||||
except ValueError:
|
||||
reason = ''
|
||||
|
||||
# version
|
||||
match = VERSRE.match(version)
|
||||
if match is None:
|
||||
raise errors.BadStatusLine(line)
|
||||
version = HttpVersion(int(match.group(1)), int(match.group(2)))
|
||||
|
||||
# The status code is a three-digit number
|
||||
try:
|
||||
status = int(status)
|
||||
except ValueError:
|
||||
raise errors.BadStatusLine(line) from None
|
||||
|
||||
if status < 100 or status > 999:
|
||||
raise errors.BadStatusLine(line)
|
||||
|
||||
# read headers
|
||||
headers, raw_headers, close, compression = self.parse_headers(lines)
|
||||
|
||||
if close is None:
|
||||
close = version <= HttpVersion10
|
||||
|
||||
out.feed_data(
|
||||
RawResponseMessage(
|
||||
version, status, reason.strip(),
|
||||
headers, raw_headers, close, compression),
|
||||
len(raw_data))
|
||||
out.feed_eof()
|
||||
|
||||
|
||||
class HttpPayloadParser:
|
||||
|
||||
def __init__(self, message, length=None, compression=True,
|
||||
readall=False, response_with_body=True):
|
||||
self.message = message
|
||||
self.length = length
|
||||
self.compression = compression
|
||||
self.readall = readall
|
||||
self.response_with_body = response_with_body
|
||||
|
||||
def __call__(self, out, buf):
|
||||
# payload params
|
||||
length = self.message.headers.get(hdrs.CONTENT_LENGTH, self.length)
|
||||
if hdrs.SEC_WEBSOCKET_KEY1 in self.message.headers:
|
||||
length = 8
|
||||
|
||||
# payload decompression wrapper
|
||||
if (self.response_with_body and
|
||||
self.compression and self.message.compression):
|
||||
out = DeflateBuffer(out, self.message.compression)
|
||||
|
||||
# payload parser
|
||||
if not self.response_with_body:
|
||||
# don't parse payload if it's not expected to be received
|
||||
pass
|
||||
|
||||
elif 'chunked' in self.message.headers.get(
|
||||
hdrs.TRANSFER_ENCODING, ''):
|
||||
yield from self.parse_chunked_payload(out, buf)
|
||||
|
||||
elif length is not None:
|
||||
try:
|
||||
length = int(length)
|
||||
except ValueError:
|
||||
raise errors.InvalidHeader(hdrs.CONTENT_LENGTH) from None
|
||||
|
||||
if length < 0:
|
||||
raise errors.InvalidHeader(hdrs.CONTENT_LENGTH)
|
||||
elif length > 0:
|
||||
yield from self.parse_length_payload(out, buf, length)
|
||||
else:
|
||||
if self.readall and getattr(self.message, 'code', 0) != 204:
|
||||
yield from self.parse_eof_payload(out, buf)
|
||||
elif getattr(self.message, 'method', None) in ('PUT', 'POST'):
|
||||
internal_logger.warning( # pragma: no cover
|
||||
'Content-Length or Transfer-Encoding header is required')
|
||||
|
||||
out.feed_eof()
|
||||
|
||||
def parse_chunked_payload(self, out, buf):
|
||||
"""Chunked transfer encoding parser."""
|
||||
while True:
|
||||
# read next chunk size
|
||||
line = yield from buf.readuntil(b'\r\n', 8192)
|
||||
|
||||
i = line.find(b';')
|
||||
if i >= 0:
|
||||
line = line[:i] # strip chunk-extensions
|
||||
else:
|
||||
line = line.strip()
|
||||
try:
|
||||
size = int(line, 16)
|
||||
except ValueError:
|
||||
raise errors.TransferEncodingError(line) from None
|
||||
|
||||
if size == 0: # eof marker
|
||||
break
|
||||
|
||||
# read chunk and feed buffer
|
||||
while size:
|
||||
chunk = yield from buf.readsome(size)
|
||||
out.feed_data(chunk, len(chunk))
|
||||
size = size - len(chunk)
|
||||
|
||||
# toss the CRLF at the end of the chunk
|
||||
yield from buf.skip(2)
|
||||
|
||||
# read and discard trailer up to the CRLF terminator
|
||||
yield from buf.skipuntil(b'\r\n')
|
||||
|
||||
def parse_length_payload(self, out, buf, length=0):
|
||||
"""Read specified amount of bytes."""
|
||||
required = length
|
||||
while required:
|
||||
chunk = yield from buf.readsome(required)
|
||||
out.feed_data(chunk, len(chunk))
|
||||
required -= len(chunk)
|
||||
|
||||
def parse_eof_payload(self, out, buf):
|
||||
"""Read all bytes until eof."""
|
||||
try:
|
||||
while True:
|
||||
chunk = yield from buf.readsome()
|
||||
out.feed_data(chunk, len(chunk))
|
||||
except aiohttp.EofStream:
|
||||
pass
|
||||
|
||||
|
||||
class DeflateBuffer:
|
||||
"""DeflateStream decompress stream and feed data into specified stream."""
|
||||
|
||||
def __init__(self, out, encoding):
|
||||
self.out = out
|
||||
zlib_mode = (16 + zlib.MAX_WBITS
|
||||
if encoding == 'gzip' else -zlib.MAX_WBITS)
|
||||
|
||||
self.zlib = zlib.decompressobj(wbits=zlib_mode)
|
||||
|
||||
def feed_data(self, chunk, size):
|
||||
try:
|
||||
chunk = self.zlib.decompress(chunk)
|
||||
except Exception:
|
||||
raise errors.ContentEncodingError('deflate')
|
||||
|
||||
if chunk:
|
||||
self.out.feed_data(chunk, len(chunk))
|
||||
|
||||
def feed_eof(self):
|
||||
chunk = self.zlib.flush()
|
||||
self.out.feed_data(chunk, len(chunk))
|
||||
if not self.zlib.eof:
|
||||
raise errors.ContentEncodingError('deflate')
|
||||
|
||||
self.out.feed_eof()
|
||||
|
||||
|
||||
def wrap_payload_filter(func):
|
||||
"""Wraps payload filter and piped filters.
|
||||
|
||||
Filter is a generator that accepts arbitrary chunks of data,
|
||||
modify data and emit new stream of data.
|
||||
|
||||
For example we have stream of chunks: ['1', '2', '3', '4', '5'],
|
||||
we can apply chunking filter to this stream:
|
||||
|
||||
['1', '2', '3', '4', '5']
|
||||
|
|
||||
response.add_chunking_filter(2)
|
||||
|
|
||||
['12', '34', '5']
|
||||
|
||||
It is possible to use different filters at the same time.
|
||||
|
||||
For a example to compress incoming stream with 'deflate' encoding
|
||||
and then split data and emit chunks of 8192 bytes size chunks:
|
||||
|
||||
>>> response.add_compression_filter('deflate')
|
||||
>>> response.add_chunking_filter(8192)
|
||||
|
||||
Filters do not alter transfer encoding.
|
||||
|
||||
Filter can receive types types of data, bytes object or EOF_MARKER.
|
||||
|
||||
1. If filter receives bytes object, it should process data
|
||||
and yield processed data then yield EOL_MARKER object.
|
||||
2. If Filter received EOF_MARKER, it should yield remaining
|
||||
data (buffered) and then yield EOF_MARKER.
|
||||
"""
|
||||
@functools.wraps(func)
|
||||
def wrapper(self, *args, **kw):
|
||||
new_filter = func(self, *args, **kw)
|
||||
|
||||
filter = self.filter
|
||||
if filter is not None:
|
||||
next(new_filter)
|
||||
self.filter = filter_pipe(filter, new_filter)
|
||||
else:
|
||||
self.filter = new_filter
|
||||
|
||||
next(self.filter)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def filter_pipe(filter, filter2, *,
|
||||
EOF_MARKER=EOF_MARKER, EOL_MARKER=EOL_MARKER):
|
||||
"""Creates pipe between two filters.
|
||||
|
||||
filter_pipe() feeds first filter with incoming data and then
|
||||
send yielded from first filter data into filter2, results of
|
||||
filter2 are being emitted.
|
||||
|
||||
1. If filter_pipe receives bytes object, it sends it to the first filter.
|
||||
2. Reads yielded values from the first filter until it receives
|
||||
EOF_MARKER or EOL_MARKER.
|
||||
3. Each of this values is being send to second filter.
|
||||
4. Reads yielded values from second filter until it receives EOF_MARKER
|
||||
or EOL_MARKER. Each of this values yields to writer.
|
||||
"""
|
||||
chunk = yield
|
||||
|
||||
while True:
|
||||
eof = chunk is EOF_MARKER
|
||||
chunk = filter.send(chunk)
|
||||
|
||||
while chunk is not EOL_MARKER:
|
||||
chunk = filter2.send(chunk)
|
||||
|
||||
while chunk not in (EOF_MARKER, EOL_MARKER):
|
||||
yield chunk
|
||||
chunk = next(filter2)
|
||||
|
||||
if chunk is not EOF_MARKER:
|
||||
if eof:
|
||||
chunk = EOF_MARKER
|
||||
else:
|
||||
chunk = next(filter)
|
||||
else:
|
||||
break
|
||||
|
||||
chunk = yield EOL_MARKER
|
||||
|
||||
|
||||
class HttpMessage(ABC):
|
||||
"""HttpMessage allows to write headers and payload to a stream.
|
||||
|
||||
For example, lets say we want to read file then compress it with deflate
|
||||
compression and then send it with chunked transfer encoding, code may look
|
||||
like this:
|
||||
|
||||
>>> response = aiohttp.Response(transport, 200)
|
||||
|
||||
We have to use deflate compression first:
|
||||
|
||||
>>> response.add_compression_filter('deflate')
|
||||
|
||||
Then we want to split output stream into chunks of 1024 bytes size:
|
||||
|
||||
>>> response.add_chunking_filter(1024)
|
||||
|
||||
We can add headers to response with add_headers() method. add_headers()
|
||||
does not send data to transport, send_headers() sends request/response
|
||||
line and then sends headers:
|
||||
|
||||
>>> response.add_headers(
|
||||
... ('Content-Disposition', 'attachment; filename="..."'))
|
||||
>>> response.send_headers()
|
||||
|
||||
Now we can use chunked writer to write stream to a network stream.
|
||||
First call to write() method sends response status line and headers,
|
||||
add_header() and add_headers() method unavailable at this stage:
|
||||
|
||||
>>> with open('...', 'rb') as f:
|
||||
... chunk = fp.read(8192)
|
||||
... while chunk:
|
||||
... response.write(chunk)
|
||||
... chunk = fp.read(8192)
|
||||
|
||||
>>> response.write_eof()
|
||||
|
||||
"""
|
||||
|
||||
writer = None
|
||||
|
||||
# 'filter' is being used for altering write() behaviour,
|
||||
# add_chunking_filter adds deflate/gzip compression and
|
||||
# add_compression_filter splits incoming data into a chunks.
|
||||
filter = None
|
||||
|
||||
HOP_HEADERS = None # Must be set by subclass.
|
||||
|
||||
SERVER_SOFTWARE = 'Python/{0[0]}.{0[1]} aiohttp/{1}'.format(
|
||||
sys.version_info, aiohttp.__version__)
|
||||
|
||||
upgrade = False # Connection: UPGRADE
|
||||
websocket = False # Upgrade: WEBSOCKET
|
||||
has_chunked_hdr = False # Transfer-encoding: chunked
|
||||
|
||||
# subclass can enable auto sending headers with write() call,
|
||||
# this is useful for wsgi's start_response implementation.
|
||||
_send_headers = False
|
||||
|
||||
def __init__(self, transport, version, close):
|
||||
self.transport = transport
|
||||
self._version = version
|
||||
self.closing = close
|
||||
self.keepalive = None
|
||||
self.chunked = False
|
||||
self.length = None
|
||||
self.headers = CIMultiDict()
|
||||
self.headers_sent = False
|
||||
self.output_length = 0
|
||||
self.headers_length = 0
|
||||
self._output_size = 0
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def status_line(self):
|
||||
return b''
|
||||
|
||||
@abstractmethod
|
||||
def autochunked(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def version(self):
|
||||
return self._version
|
||||
|
||||
@property
|
||||
def body_length(self):
|
||||
return self.output_length - self.headers_length
|
||||
|
||||
def force_close(self):
|
||||
self.closing = True
|
||||
self.keepalive = False
|
||||
|
||||
def enable_chunked_encoding(self):
|
||||
self.chunked = True
|
||||
|
||||
def keep_alive(self):
|
||||
if self.keepalive is None:
|
||||
if self.version < HttpVersion10:
|
||||
# keep alive not supported at all
|
||||
return False
|
||||
if self.version == HttpVersion10:
|
||||
if self.headers.get(hdrs.CONNECTION) == 'keep-alive':
|
||||
return True
|
||||
else: # no headers means we close for Http 1.0
|
||||
return False
|
||||
else:
|
||||
return not self.closing
|
||||
else:
|
||||
return self.keepalive
|
||||
|
||||
def is_headers_sent(self):
|
||||
return self.headers_sent
|
||||
|
||||
def add_header(self, name, value):
|
||||
"""Analyze headers. Calculate content length,
|
||||
removes hop headers, etc."""
|
||||
assert not self.headers_sent, 'headers have been sent already'
|
||||
assert isinstance(name, str), \
|
||||
'Header name should be a string, got {!r}'.format(name)
|
||||
assert set(name).issubset(ASCIISET), \
|
||||
'Header name should contain ASCII chars, got {!r}'.format(name)
|
||||
assert isinstance(value, str), \
|
||||
'Header {!r} should have string value, got {!r}'.format(
|
||||
name, value)
|
||||
|
||||
name = istr(name)
|
||||
value = value.strip()
|
||||
|
||||
if name == hdrs.CONTENT_LENGTH:
|
||||
self.length = int(value)
|
||||
|
||||
if name == hdrs.TRANSFER_ENCODING:
|
||||
self.has_chunked_hdr = value.lower().strip() == 'chunked'
|
||||
|
||||
if name == hdrs.CONNECTION:
|
||||
val = value.lower()
|
||||
# handle websocket
|
||||
if 'upgrade' in val:
|
||||
self.upgrade = True
|
||||
# connection keep-alive
|
||||
elif 'close' in val:
|
||||
self.keepalive = False
|
||||
elif 'keep-alive' in val:
|
||||
self.keepalive = True
|
||||
|
||||
elif name == hdrs.UPGRADE:
|
||||
if 'websocket' in value.lower():
|
||||
self.websocket = True
|
||||
self.headers[name] = value
|
||||
|
||||
elif name not in self.HOP_HEADERS:
|
||||
# ignore hop-by-hop headers
|
||||
self.headers.add(name, value)
|
||||
|
||||
def add_headers(self, *headers):
|
||||
"""Adds headers to a HTTP message."""
|
||||
for name, value in headers:
|
||||
self.add_header(name, value)
|
||||
|
||||
def send_headers(self, _sep=': ', _end='\r\n'):
|
||||
"""Writes headers to a stream. Constructs payload writer."""
|
||||
# Chunked response is only for HTTP/1.1 clients or newer
|
||||
# and there is no Content-Length header is set.
|
||||
# Do not use chunked responses when the response is guaranteed to
|
||||
# not have a response body (304, 204).
|
||||
assert not self.headers_sent, 'headers have been sent already'
|
||||
self.headers_sent = True
|
||||
|
||||
if self.chunked or self.autochunked():
|
||||
self.writer = self._write_chunked_payload()
|
||||
self.headers[hdrs.TRANSFER_ENCODING] = 'chunked'
|
||||
|
||||
elif self.length is not None:
|
||||
self.writer = self._write_length_payload(self.length)
|
||||
|
||||
else:
|
||||
self.writer = self._write_eof_payload()
|
||||
|
||||
next(self.writer)
|
||||
|
||||
self._add_default_headers()
|
||||
|
||||
# status + headers
|
||||
headers = self.status_line + ''.join(
|
||||
[k + _sep + v + _end for k, v in self.headers.items()])
|
||||
headers = headers.encode('utf-8') + b'\r\n'
|
||||
|
||||
self.output_length += len(headers)
|
||||
self.headers_length = len(headers)
|
||||
self.transport.write(headers)
|
||||
|
||||
def _add_default_headers(self):
|
||||
# set the connection header
|
||||
connection = None
|
||||
if self.upgrade:
|
||||
connection = 'upgrade'
|
||||
elif not self.closing if self.keepalive is None else self.keepalive:
|
||||
if self.version == HttpVersion10:
|
||||
connection = 'keep-alive'
|
||||
else:
|
||||
if self.version == HttpVersion11:
|
||||
connection = 'close'
|
||||
|
||||
if connection is not None:
|
||||
self.headers[hdrs.CONNECTION] = connection
|
||||
|
||||
def write(self, chunk, *,
|
||||
drain=False, EOF_MARKER=EOF_MARKER, EOL_MARKER=EOL_MARKER):
|
||||
"""Writes chunk of data to a stream by using different writers.
|
||||
|
||||
writer uses filter to modify chunk of data.
|
||||
write_eof() indicates end of stream.
|
||||
writer can't be used after write_eof() method being called.
|
||||
write() return drain future.
|
||||
"""
|
||||
assert (isinstance(chunk, (bytes, bytearray)) or
|
||||
chunk is EOF_MARKER), chunk
|
||||
|
||||
size = self.output_length
|
||||
|
||||
if self._send_headers and not self.headers_sent:
|
||||
self.send_headers()
|
||||
|
||||
assert self.writer is not None, 'send_headers() is not called.'
|
||||
|
||||
if self.filter:
|
||||
chunk = self.filter.send(chunk)
|
||||
while chunk not in (EOF_MARKER, EOL_MARKER):
|
||||
if chunk:
|
||||
self.writer.send(chunk)
|
||||
chunk = next(self.filter)
|
||||
else:
|
||||
if chunk is not EOF_MARKER:
|
||||
self.writer.send(chunk)
|
||||
|
||||
self._output_size += self.output_length - size
|
||||
|
||||
if self._output_size > 64 * 1024:
|
||||
if drain:
|
||||
self._output_size = 0
|
||||
return self.transport.drain()
|
||||
|
||||
return ()
|
||||
|
||||
def write_eof(self):
|
||||
self.write(EOF_MARKER)
|
||||
try:
|
||||
self.writer.throw(aiohttp.EofStream())
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
return self.transport.drain()
|
||||
|
||||
def _write_chunked_payload(self):
|
||||
"""Write data in chunked transfer encoding."""
|
||||
while True:
|
||||
try:
|
||||
chunk = yield
|
||||
except aiohttp.EofStream:
|
||||
self.transport.write(b'0\r\n\r\n')
|
||||
self.output_length += 5
|
||||
break
|
||||
|
||||
chunk = bytes(chunk)
|
||||
chunk_len = '{:x}\r\n'.format(len(chunk)).encode('ascii')
|
||||
self.transport.write(chunk_len + chunk + b'\r\n')
|
||||
self.output_length += len(chunk_len) + len(chunk) + 2
|
||||
|
||||
def _write_length_payload(self, length):
|
||||
"""Write specified number of bytes to a stream."""
|
||||
while True:
|
||||
try:
|
||||
chunk = yield
|
||||
except aiohttp.EofStream:
|
||||
break
|
||||
|
||||
if length:
|
||||
l = len(chunk)
|
||||
if length >= l:
|
||||
self.transport.write(chunk)
|
||||
self.output_length += l
|
||||
length = length-l
|
||||
else:
|
||||
self.transport.write(chunk[:length])
|
||||
self.output_length += length
|
||||
length = 0
|
||||
|
||||
def _write_eof_payload(self):
|
||||
while True:
|
||||
try:
|
||||
chunk = yield
|
||||
except aiohttp.EofStream:
|
||||
break
|
||||
|
||||
self.transport.write(chunk)
|
||||
self.output_length += len(chunk)
|
||||
|
||||
@wrap_payload_filter
|
||||
def add_chunking_filter(self, chunk_size=16*1024, *,
|
||||
EOF_MARKER=EOF_MARKER, EOL_MARKER=EOL_MARKER):
|
||||
"""Split incoming stream into chunks."""
|
||||
buf = bytearray()
|
||||
chunk = yield
|
||||
|
||||
while True:
|
||||
if chunk is EOF_MARKER:
|
||||
if buf:
|
||||
yield buf
|
||||
|
||||
yield EOF_MARKER
|
||||
|
||||
else:
|
||||
buf.extend(chunk)
|
||||
|
||||
while len(buf) >= chunk_size:
|
||||
chunk = bytes(buf[:chunk_size])
|
||||
del buf[:chunk_size]
|
||||
yield chunk
|
||||
|
||||
chunk = yield EOL_MARKER
|
||||
|
||||
@wrap_payload_filter
|
||||
def add_compression_filter(self, encoding='deflate', *,
|
||||
EOF_MARKER=EOF_MARKER, EOL_MARKER=EOL_MARKER):
|
||||
"""Compress incoming stream with deflate or gzip encoding."""
|
||||
zlib_mode = (16 + zlib.MAX_WBITS
|
||||
if encoding == 'gzip' else -zlib.MAX_WBITS)
|
||||
zcomp = zlib.compressobj(wbits=zlib_mode)
|
||||
|
||||
chunk = yield
|
||||
while True:
|
||||
if chunk is EOF_MARKER:
|
||||
yield zcomp.flush()
|
||||
chunk = yield EOF_MARKER
|
||||
|
||||
else:
|
||||
yield zcomp.compress(chunk)
|
||||
chunk = yield EOL_MARKER
|
||||
|
||||
|
||||
class Response(HttpMessage):
|
||||
"""Create HTTP response message.
|
||||
|
||||
Transport is a socket stream transport. status is a response status code,
|
||||
status has to be integer value. http_version is a tuple that represents
|
||||
HTTP version, (1, 0) stands for HTTP/1.0 and (1, 1) is for HTTP/1.1
|
||||
"""
|
||||
|
||||
HOP_HEADERS = ()
|
||||
|
||||
@staticmethod
|
||||
def calc_reason(status, *, _RESPONSES=RESPONSES):
|
||||
record = _RESPONSES.get(status)
|
||||
if record is not None:
|
||||
reason = record[0]
|
||||
else:
|
||||
reason = str(status)
|
||||
return reason
|
||||
|
||||
def __init__(self, transport, status,
|
||||
http_version=HttpVersion11, close=False, reason=None):
|
||||
super().__init__(transport, http_version, close)
|
||||
|
||||
self._status = status
|
||||
if reason is None:
|
||||
reason = self.calc_reason(status)
|
||||
|
||||
self._reason = reason
|
||||
|
||||
@property
|
||||
def status(self):
|
||||
return self._status
|
||||
|
||||
@property
|
||||
def reason(self):
|
||||
return self._reason
|
||||
|
||||
@reify
|
||||
def status_line(self):
|
||||
version = self.version
|
||||
return 'HTTP/{}.{} {} {}\r\n'.format(
|
||||
version[0], version[1], self.status, self.reason)
|
||||
|
||||
def autochunked(self):
|
||||
return (self.length is None and
|
||||
self.version >= HttpVersion11)
|
||||
|
||||
def _add_default_headers(self):
|
||||
super()._add_default_headers()
|
||||
|
||||
if hdrs.DATE not in self.headers:
|
||||
# format_date_time(None) is quite expensive
|
||||
self.headers.setdefault(hdrs.DATE, format_date_time(None))
|
||||
self.headers.setdefault(hdrs.SERVER, self.SERVER_SOFTWARE)
|
||||
|
||||
|
||||
class Request(HttpMessage):
|
||||
|
||||
HOP_HEADERS = ()
|
||||
|
||||
def __init__(self, transport, method, path,
|
||||
http_version=HttpVersion11, close=False):
|
||||
# set the default for HTTP 0.9 to be different
|
||||
# will only be overwritten with keep-alive header
|
||||
if http_version < HttpVersion10:
|
||||
close = True
|
||||
|
||||
super().__init__(transport, http_version, close)
|
||||
|
||||
self._method = method
|
||||
self._path = path
|
||||
|
||||
@property
|
||||
def method(self):
|
||||
return self._method
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
return self._path
|
||||
|
||||
@reify
|
||||
def status_line(self):
|
||||
return '{0} {1} HTTP/{2[0]}.{2[1]}\r\n'.format(
|
||||
self.method, self.path, self.version)
|
||||
|
||||
def autochunked(self):
|
||||
return (self.length is None and
|
||||
self.version >= HttpVersion11 and
|
||||
self.status not in (304, 204))
|
||||
|
|
@ -0,0 +1,113 @@
|
|||
import asyncio
|
||||
import contextlib
|
||||
|
||||
import pytest
|
||||
|
||||
from aiohttp.web import Application
|
||||
|
||||
from .test_utils import unused_port as _unused_port
|
||||
from .test_utils import (TestClient, TestServer, loop_context, setup_test_loop,
|
||||
teardown_test_loop)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _passthrough_loop_context(loop):
|
||||
if loop:
|
||||
# loop already exists, pass it straight through
|
||||
yield loop
|
||||
else:
|
||||
# this shadows loop_context's standard behavior
|
||||
loop = setup_test_loop()
|
||||
yield loop
|
||||
teardown_test_loop(loop)
|
||||
|
||||
|
||||
def pytest_pycollect_makeitem(collector, name, obj):
|
||||
"""
|
||||
Fix pytest collecting for coroutines.
|
||||
"""
|
||||
if collector.funcnamefilter(name) and asyncio.iscoroutinefunction(obj):
|
||||
return list(collector._genfunctions(name, obj))
|
||||
|
||||
|
||||
def pytest_pyfunc_call(pyfuncitem):
|
||||
"""
|
||||
Run coroutines in an event loop instead of a normal function call.
|
||||
"""
|
||||
if asyncio.iscoroutinefunction(pyfuncitem.function):
|
||||
existing_loop = pyfuncitem.funcargs.get('loop', None)
|
||||
with _passthrough_loop_context(existing_loop) as _loop:
|
||||
testargs = {arg: pyfuncitem.funcargs[arg]
|
||||
for arg in pyfuncitem._fixtureinfo.argnames}
|
||||
|
||||
task = _loop.create_task(pyfuncitem.obj(**testargs))
|
||||
_loop.run_until_complete(task)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@pytest.yield_fixture
|
||||
def loop():
|
||||
with loop_context() as _loop:
|
||||
yield _loop
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def unused_port():
|
||||
return _unused_port
|
||||
|
||||
|
||||
@pytest.yield_fixture
|
||||
def test_server(loop):
|
||||
servers = []
|
||||
|
||||
@asyncio.coroutine
|
||||
def go(app, **kwargs):
|
||||
assert app.loop is loop, \
|
||||
"Application is attached to other event loop"
|
||||
|
||||
server = TestServer(app)
|
||||
yield from server.start_server(**kwargs)
|
||||
servers.append(server)
|
||||
return server
|
||||
|
||||
yield go
|
||||
|
||||
@asyncio.coroutine
|
||||
def finalize():
|
||||
while servers:
|
||||
yield from servers.pop().close()
|
||||
|
||||
loop.run_until_complete(finalize())
|
||||
|
||||
|
||||
@pytest.yield_fixture
|
||||
def test_client(loop):
|
||||
clients = []
|
||||
|
||||
@asyncio.coroutine
|
||||
def go(__param, *args, **kwargs):
|
||||
if isinstance(__param, Application):
|
||||
assert not args, "args should be empty"
|
||||
assert not kwargs, "kwargs should be empty"
|
||||
assert __param.loop is loop, \
|
||||
"Application is attached to other event loop"
|
||||
elif isinstance(__param, TestServer):
|
||||
assert __param.app.loop is loop, \
|
||||
"TestServer is attached to other event loop"
|
||||
else:
|
||||
__param = __param(loop, *args, **kwargs)
|
||||
|
||||
client = TestClient(__param)
|
||||
yield from client.start_server()
|
||||
clients.append(client)
|
||||
return client
|
||||
|
||||
yield go
|
||||
|
||||
@asyncio.coroutine
|
||||
def finalize():
|
||||
while clients:
|
||||
yield from clients.pop().close()
|
||||
|
||||
loop.run_until_complete(finalize())
|
||||
|
|
@ -0,0 +1,100 @@
|
|||
import asyncio
|
||||
import socket
|
||||
|
||||
from .abc import AbstractResolver
|
||||
|
||||
__all__ = ('ThreadedResolver', 'AsyncResolver', 'DefaultResolver')
|
||||
|
||||
try:
|
||||
import aiodns
|
||||
aiodns_default = hasattr(aiodns.DNSResolver, 'gethostbyname')
|
||||
except ImportError: # pragma: no cover
|
||||
aiodns = None
|
||||
aiodns_default = False
|
||||
|
||||
|
||||
class ThreadedResolver(AbstractResolver):
|
||||
"""Use Executor for synchronous getaddrinfo() calls, which defaults to
|
||||
concurrent.futures.ThreadPoolExecutor.
|
||||
"""
|
||||
|
||||
def __init__(self, loop=None):
|
||||
if loop is None:
|
||||
loop = asyncio.get_event_loop()
|
||||
self._loop = loop
|
||||
|
||||
@asyncio.coroutine
|
||||
def resolve(self, host, port=0, family=socket.AF_INET):
|
||||
infos = yield from self._loop.getaddrinfo(
|
||||
host, port, type=socket.SOCK_STREAM, family=family)
|
||||
|
||||
hosts = []
|
||||
for family, _, proto, _, address in infos:
|
||||
hosts.append(
|
||||
{'hostname': host,
|
||||
'host': address[0], 'port': address[1],
|
||||
'family': family, 'proto': proto,
|
||||
'flags': socket.AI_NUMERICHOST})
|
||||
|
||||
return hosts
|
||||
|
||||
@asyncio.coroutine
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
|
||||
class AsyncResolver(AbstractResolver):
|
||||
"""Use the `aiodns` package to make asynchronous DNS lookups"""
|
||||
|
||||
def __init__(self, loop=None, *args, **kwargs):
|
||||
if loop is None:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
if aiodns is None:
|
||||
raise RuntimeError("Resolver requires aiodns library")
|
||||
|
||||
self._loop = loop
|
||||
self._resolver = aiodns.DNSResolver(*args, loop=loop, **kwargs)
|
||||
|
||||
if not hasattr(self._resolver, 'gethostbyname'):
|
||||
# aiodns 1.1 is not available, fallback to DNSResolver.query
|
||||
self.resolve = self.resolve_with_query
|
||||
|
||||
@asyncio.coroutine
|
||||
def resolve(self, host, port=0, family=socket.AF_INET):
|
||||
hosts = []
|
||||
resp = yield from self._resolver.gethostbyname(host, family)
|
||||
|
||||
for address in resp.addresses:
|
||||
hosts.append(
|
||||
{'hostname': host,
|
||||
'host': address, 'port': port,
|
||||
'family': family, 'proto': 0,
|
||||
'flags': socket.AI_NUMERICHOST})
|
||||
return hosts
|
||||
|
||||
@asyncio.coroutine
|
||||
def resolve_with_query(self, host, port=0, family=socket.AF_INET):
|
||||
if family == socket.AF_INET6:
|
||||
qtype = 'AAAA'
|
||||
else:
|
||||
qtype = 'A'
|
||||
|
||||
hosts = []
|
||||
resp = yield from self._resolver.query(host, qtype)
|
||||
|
||||
for rr in resp:
|
||||
hosts.append(
|
||||
{'hostname': host,
|
||||
'host': rr.host, 'port': port,
|
||||
'family': family, 'proto': 0,
|
||||
'flags': socket.AI_NUMERICHOST})
|
||||
|
||||
return hosts
|
||||
|
||||
@asyncio.coroutine
|
||||
def close(self):
|
||||
return self._resolver.cancel()
|
||||
|
||||
|
||||
DefaultResolver = AsyncResolver if aiodns_default else ThreadedResolver
|
||||
|
|
@ -0,0 +1,376 @@
|
|||
"""simple HTTP server."""
|
||||
|
||||
import asyncio
|
||||
import http.server
|
||||
import socket
|
||||
import traceback
|
||||
import warnings
|
||||
from contextlib import suppress
|
||||
from html import escape as html_escape
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import errors, hdrs, helpers, streams
|
||||
from aiohttp.helpers import Timeout, _get_kwarg, ensure_future
|
||||
from aiohttp.log import access_logger, server_logger
|
||||
|
||||
__all__ = ('ServerHttpProtocol',)
|
||||
|
||||
|
||||
RESPONSES = http.server.BaseHTTPRequestHandler.responses
|
||||
DEFAULT_ERROR_MESSAGE = """
|
||||
<html>
|
||||
<head>
|
||||
<title>{status} {reason}</title>
|
||||
</head>
|
||||
<body>
|
||||
<h1>{status} {reason}</h1>
|
||||
{message}
|
||||
</body>
|
||||
</html>"""
|
||||
|
||||
|
||||
if hasattr(socket, 'SO_KEEPALIVE'):
|
||||
def tcp_keepalive(server, transport):
|
||||
sock = transport.get_extra_info('socket')
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
|
||||
else:
|
||||
def tcp_keepalive(server, transport): # pragma: no cover
|
||||
pass
|
||||
|
||||
EMPTY_PAYLOAD = streams.EmptyStreamReader()
|
||||
|
||||
|
||||
class ServerHttpProtocol(aiohttp.StreamProtocol):
|
||||
"""Simple HTTP protocol implementation.
|
||||
|
||||
ServerHttpProtocol handles incoming HTTP request. It reads request line,
|
||||
request headers and request payload and calls handle_request() method.
|
||||
By default it always returns with 404 response.
|
||||
|
||||
ServerHttpProtocol handles errors in incoming request, like bad
|
||||
status line, bad headers or incomplete payload. If any error occurs,
|
||||
connection gets closed.
|
||||
|
||||
:param keepalive_timeout: number of seconds before closing
|
||||
keep-alive connection
|
||||
:type keepalive_timeout: int or None
|
||||
|
||||
:param bool tcp_keepalive: TCP keep-alive is on, default is on
|
||||
|
||||
:param int slow_request_timeout: slow request timeout
|
||||
|
||||
:param bool debug: enable debug mode
|
||||
|
||||
:param logger: custom logger object
|
||||
:type logger: aiohttp.log.server_logger
|
||||
|
||||
:param access_log: custom logging object
|
||||
:type access_log: aiohttp.log.server_logger
|
||||
|
||||
:param str access_log_format: access log format string
|
||||
|
||||
:param loop: Optional event loop
|
||||
|
||||
:param int max_line_size: Optional maximum header line size
|
||||
|
||||
:param int max_field_size: Optional maximum header field size
|
||||
|
||||
:param int max_headers: Optional maximum header size
|
||||
|
||||
"""
|
||||
_request_count = 0
|
||||
_request_handler = None
|
||||
_reading_request = False
|
||||
_keepalive = False # keep transport open
|
||||
|
||||
def __init__(self, *, loop=None,
|
||||
keepalive_timeout=75, # NGINX default value is 75 secs
|
||||
tcp_keepalive=True,
|
||||
slow_request_timeout=0,
|
||||
logger=server_logger,
|
||||
access_log=access_logger,
|
||||
access_log_format=helpers.AccessLogger.LOG_FORMAT,
|
||||
debug=False,
|
||||
max_line_size=8190,
|
||||
max_headers=32768,
|
||||
max_field_size=8190,
|
||||
**kwargs):
|
||||
|
||||
# process deprecated params
|
||||
logger = _get_kwarg(kwargs, 'log', 'logger', logger)
|
||||
|
||||
tcp_keepalive = _get_kwarg(kwargs, 'keep_alive_on',
|
||||
'tcp_keepalive', tcp_keepalive)
|
||||
|
||||
keepalive_timeout = _get_kwarg(kwargs, 'keep_alive',
|
||||
'keepalive_timeout', keepalive_timeout)
|
||||
|
||||
slow_request_timeout = _get_kwarg(kwargs, 'timeout',
|
||||
'slow_request_timeout',
|
||||
slow_request_timeout)
|
||||
|
||||
super().__init__(
|
||||
loop=loop,
|
||||
disconnect_error=errors.ClientDisconnectedError, **kwargs)
|
||||
|
||||
self._tcp_keepalive = tcp_keepalive
|
||||
self._keepalive_timeout = keepalive_timeout
|
||||
self._slow_request_timeout = slow_request_timeout
|
||||
self._loop = loop if loop is not None else asyncio.get_event_loop()
|
||||
|
||||
self._request_prefix = aiohttp.HttpPrefixParser()
|
||||
self._request_parser = aiohttp.HttpRequestParser(
|
||||
max_line_size=max_line_size,
|
||||
max_field_size=max_field_size,
|
||||
max_headers=max_headers)
|
||||
|
||||
self.logger = logger
|
||||
self.debug = debug
|
||||
self.access_log = access_log
|
||||
if access_log:
|
||||
self.access_logger = helpers.AccessLogger(access_log,
|
||||
access_log_format)
|
||||
else:
|
||||
self.access_logger = None
|
||||
self._closing = False
|
||||
|
||||
@property
|
||||
def keep_alive_timeout(self):
|
||||
warnings.warn("Use keepalive_timeout property instead",
|
||||
DeprecationWarning,
|
||||
stacklevel=2)
|
||||
return self._keepalive_timeout
|
||||
|
||||
@property
|
||||
def keepalive_timeout(self):
|
||||
return self._keepalive_timeout
|
||||
|
||||
@asyncio.coroutine
|
||||
def shutdown(self, timeout=15.0):
|
||||
"""Worker process is about to exit, we need cleanup everything and
|
||||
stop accepting requests. It is especially important for keep-alive
|
||||
connections."""
|
||||
if self._request_handler is None:
|
||||
return
|
||||
self._closing = True
|
||||
|
||||
if timeout:
|
||||
canceller = self._loop.call_later(timeout,
|
||||
self._request_handler.cancel)
|
||||
with suppress(asyncio.CancelledError):
|
||||
yield from self._request_handler
|
||||
canceller.cancel()
|
||||
else:
|
||||
self._request_handler.cancel()
|
||||
|
||||
def connection_made(self, transport):
|
||||
super().connection_made(transport)
|
||||
|
||||
self._request_handler = ensure_future(self.start(), loop=self._loop)
|
||||
|
||||
if self._tcp_keepalive:
|
||||
tcp_keepalive(self, transport)
|
||||
|
||||
def connection_lost(self, exc):
|
||||
super().connection_lost(exc)
|
||||
|
||||
self._closing = True
|
||||
if self._request_handler is not None:
|
||||
self._request_handler.cancel()
|
||||
|
||||
def data_received(self, data):
|
||||
super().data_received(data)
|
||||
|
||||
# reading request
|
||||
if not self._reading_request:
|
||||
self._reading_request = True
|
||||
|
||||
def keep_alive(self, val):
|
||||
"""Set keep-alive connection mode.
|
||||
|
||||
:param bool val: new state.
|
||||
"""
|
||||
self._keepalive = val
|
||||
|
||||
def log_access(self, message, environ, response, time):
|
||||
if self.access_logger:
|
||||
self.access_logger.log(message, environ, response,
|
||||
self.transport, time)
|
||||
|
||||
def log_debug(self, *args, **kw):
|
||||
if self.debug:
|
||||
self.logger.debug(*args, **kw)
|
||||
|
||||
def log_exception(self, *args, **kw):
|
||||
self.logger.exception(*args, **kw)
|
||||
|
||||
@asyncio.coroutine
|
||||
def start(self):
|
||||
"""Start processing of incoming requests.
|
||||
|
||||
It reads request line, request headers and request payload, then
|
||||
calls handle_request() method. Subclass has to override
|
||||
handle_request(). start() handles various exceptions in request
|
||||
or response handling. Connection is being closed always unless
|
||||
keep_alive(True) specified.
|
||||
"""
|
||||
reader = self.reader
|
||||
|
||||
try:
|
||||
while not self._closing:
|
||||
message = None
|
||||
self._keepalive = False
|
||||
self._request_count += 1
|
||||
self._reading_request = False
|
||||
|
||||
payload = None
|
||||
with Timeout(max(self._slow_request_timeout,
|
||||
self._keepalive_timeout),
|
||||
loop=self._loop):
|
||||
# read HTTP request method
|
||||
prefix = reader.set_parser(self._request_prefix)
|
||||
yield from prefix.read()
|
||||
|
||||
# start reading request
|
||||
self._reading_request = True
|
||||
|
||||
# start slow request timer
|
||||
# read request headers
|
||||
httpstream = reader.set_parser(self._request_parser)
|
||||
message = yield from httpstream.read()
|
||||
|
||||
# request may not have payload
|
||||
try:
|
||||
content_length = int(
|
||||
message.headers.get(hdrs.CONTENT_LENGTH, 0))
|
||||
except ValueError:
|
||||
raise errors.InvalidHeader(hdrs.CONTENT_LENGTH) from None
|
||||
|
||||
if (content_length > 0 or
|
||||
message.method == 'CONNECT' or
|
||||
hdrs.SEC_WEBSOCKET_KEY1 in message.headers or
|
||||
'chunked' in message.headers.get(
|
||||
hdrs.TRANSFER_ENCODING, '')):
|
||||
payload = streams.FlowControlStreamReader(
|
||||
reader, loop=self._loop)
|
||||
reader.set_parser(
|
||||
aiohttp.HttpPayloadParser(message), payload)
|
||||
else:
|
||||
payload = EMPTY_PAYLOAD
|
||||
|
||||
yield from self.handle_request(message, payload)
|
||||
|
||||
if payload and not payload.is_eof():
|
||||
self.log_debug('Uncompleted request.')
|
||||
self._closing = True
|
||||
else:
|
||||
reader.unset_parser()
|
||||
if not self._keepalive or not self._keepalive_timeout:
|
||||
self._closing = True
|
||||
|
||||
except asyncio.CancelledError:
|
||||
self.log_debug(
|
||||
'Request handler cancelled.')
|
||||
return
|
||||
except asyncio.TimeoutError:
|
||||
self.log_debug(
|
||||
'Request handler timed out.')
|
||||
return
|
||||
except errors.ClientDisconnectedError:
|
||||
self.log_debug(
|
||||
'Ignored premature client disconnection #1.')
|
||||
return
|
||||
except errors.HttpProcessingError as exc:
|
||||
yield from self.handle_error(exc.code, message,
|
||||
None, exc, exc.headers,
|
||||
exc.message)
|
||||
except Exception as exc:
|
||||
yield from self.handle_error(500, message, None, exc)
|
||||
finally:
|
||||
self._request_handler = None
|
||||
if self.transport is None:
|
||||
self.log_debug(
|
||||
'Ignored premature client disconnection #2.')
|
||||
else:
|
||||
self.transport.close()
|
||||
|
||||
def handle_error(self, status=500, message=None,
|
||||
payload=None, exc=None, headers=None, reason=None):
|
||||
"""Handle errors.
|
||||
|
||||
Returns HTTP response with specific status code. Logs additional
|
||||
information. It always closes current connection."""
|
||||
now = self._loop.time()
|
||||
try:
|
||||
if self.transport is None:
|
||||
# client has been disconnected during writing.
|
||||
return ()
|
||||
|
||||
if status == 500:
|
||||
self.log_exception("Error handling request")
|
||||
|
||||
try:
|
||||
if reason is None or reason == '':
|
||||
reason, msg = RESPONSES[status]
|
||||
else:
|
||||
msg = reason
|
||||
except KeyError:
|
||||
status = 500
|
||||
reason, msg = '???', ''
|
||||
|
||||
if self.debug and exc is not None:
|
||||
try:
|
||||
tb = traceback.format_exc()
|
||||
tb = html_escape(tb)
|
||||
msg += '<br><h2>Traceback:</h2>\n<pre>{}</pre>'.format(tb)
|
||||
except:
|
||||
pass
|
||||
|
||||
html = DEFAULT_ERROR_MESSAGE.format(
|
||||
status=status, reason=reason, message=msg).encode('utf-8')
|
||||
|
||||
response = aiohttp.Response(self.writer, status, close=True)
|
||||
response.add_header(hdrs.CONTENT_TYPE, 'text/html; charset=utf-8')
|
||||
response.add_header(hdrs.CONTENT_LENGTH, str(len(html)))
|
||||
if headers is not None:
|
||||
for name, value in headers:
|
||||
response.add_header(name, value)
|
||||
response.send_headers()
|
||||
|
||||
response.write(html)
|
||||
# disable CORK, enable NODELAY if needed
|
||||
self.writer.set_tcp_nodelay(True)
|
||||
drain = response.write_eof()
|
||||
|
||||
self.log_access(message, None, response, self._loop.time() - now)
|
||||
return drain
|
||||
finally:
|
||||
self.keep_alive(False)
|
||||
|
||||
def handle_request(self, message, payload):
|
||||
"""Handle a single HTTP request.
|
||||
|
||||
Subclass should override this method. By default it always
|
||||
returns 404 response.
|
||||
|
||||
:param message: Request headers
|
||||
:type message: aiohttp.protocol.HttpRequestParser
|
||||
:param payload: Request payload
|
||||
:type payload: aiohttp.streams.FlowControlStreamReader
|
||||
"""
|
||||
now = self._loop.time()
|
||||
response = aiohttp.Response(
|
||||
self.writer, 404, http_version=message.version, close=True)
|
||||
|
||||
body = b'Page Not Found!'
|
||||
|
||||
response.add_header(hdrs.CONTENT_TYPE, 'text/plain')
|
||||
response.add_header(hdrs.CONTENT_LENGTH, str(len(body)))
|
||||
response.send_headers()
|
||||
response.write(body)
|
||||
drain = response.write_eof()
|
||||
|
||||
self.keep_alive(False)
|
||||
self.log_access(message, None, response, self._loop.time() - now)
|
||||
|
||||
return drain
|
||||
|
|
@ -0,0 +1,71 @@
|
|||
import asyncio
|
||||
from itertools import count
|
||||
|
||||
|
||||
class BaseSignal(list):
|
||||
|
||||
@asyncio.coroutine
|
||||
def _send(self, *args, **kwargs):
|
||||
for receiver in self:
|
||||
res = receiver(*args, **kwargs)
|
||||
if asyncio.iscoroutine(res) or isinstance(res, asyncio.Future):
|
||||
yield from res
|
||||
|
||||
def copy(self):
|
||||
raise NotImplementedError("copy() is forbidden")
|
||||
|
||||
def sort(self):
|
||||
raise NotImplementedError("sort() is forbidden")
|
||||
|
||||
|
||||
class Signal(BaseSignal):
|
||||
"""Coroutine-based signal implementation.
|
||||
|
||||
To connect a callback to a signal, use any list method.
|
||||
|
||||
Signals are fired using the :meth:`send` coroutine, which takes named
|
||||
arguments.
|
||||
"""
|
||||
|
||||
def __init__(self, app):
|
||||
super().__init__()
|
||||
self._app = app
|
||||
klass = self.__class__
|
||||
self._name = klass.__module__ + ':' + klass.__qualname__
|
||||
self._pre = app.on_pre_signal
|
||||
self._post = app.on_post_signal
|
||||
|
||||
@asyncio.coroutine
|
||||
def send(self, *args, **kwargs):
|
||||
"""
|
||||
Sends data to all registered receivers.
|
||||
"""
|
||||
ordinal = None
|
||||
debug = self._app._debug
|
||||
if debug:
|
||||
ordinal = self._pre.ordinal()
|
||||
yield from self._pre.send(ordinal, self._name, *args, **kwargs)
|
||||
yield from self._send(*args, **kwargs)
|
||||
if debug:
|
||||
yield from self._post.send(ordinal, self._name, *args, **kwargs)
|
||||
|
||||
|
||||
class DebugSignal(BaseSignal):
|
||||
|
||||
@asyncio.coroutine
|
||||
def send(self, ordinal, name, *args, **kwargs):
|
||||
yield from self._send(ordinal, name, *args, **kwargs)
|
||||
|
||||
|
||||
class PreSignal(DebugSignal):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._counter = count(1)
|
||||
|
||||
def ordinal(self):
|
||||
return next(self._counter)
|
||||
|
||||
|
||||
class PostSignal(DebugSignal):
|
||||
pass
|
||||
|
|
@ -0,0 +1,672 @@
|
|||
import asyncio
|
||||
import collections
|
||||
import functools
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
from . import helpers
|
||||
from .log import internal_logger
|
||||
|
||||
__all__ = (
|
||||
'EofStream', 'StreamReader', 'DataQueue', 'ChunksQueue',
|
||||
'FlowControlStreamReader',
|
||||
'FlowControlDataQueue', 'FlowControlChunksQueue')
|
||||
|
||||
PY_35 = sys.version_info >= (3, 5)
|
||||
PY_352 = sys.version_info >= (3, 5, 2)
|
||||
|
||||
EOF_MARKER = b''
|
||||
DEFAULT_LIMIT = 2 ** 16
|
||||
|
||||
|
||||
class EofStream(Exception):
|
||||
"""eof stream indication."""
|
||||
|
||||
|
||||
if PY_35:
|
||||
class AsyncStreamIterator:
|
||||
|
||||
def __init__(self, read_func):
|
||||
self.read_func = read_func
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
if not PY_352: # pragma: no cover
|
||||
__aiter__ = asyncio.coroutine(__aiter__)
|
||||
|
||||
@asyncio.coroutine
|
||||
def __anext__(self):
|
||||
try:
|
||||
rv = yield from self.read_func()
|
||||
except EofStream:
|
||||
raise StopAsyncIteration # NOQA
|
||||
if rv == EOF_MARKER:
|
||||
raise StopAsyncIteration # NOQA
|
||||
return rv
|
||||
|
||||
|
||||
class AsyncStreamReaderMixin:
|
||||
|
||||
if PY_35:
|
||||
def __aiter__(self):
|
||||
return AsyncStreamIterator(self.readline)
|
||||
|
||||
if not PY_352: # pragma: no cover
|
||||
__aiter__ = asyncio.coroutine(__aiter__)
|
||||
|
||||
def iter_chunked(self, n):
|
||||
"""Returns an asynchronous iterator that yields chunks of size n.
|
||||
|
||||
Python-3.5 available for Python 3.5+ only
|
||||
"""
|
||||
return AsyncStreamIterator(lambda: self.read(n))
|
||||
|
||||
def iter_any(self):
|
||||
"""Returns an asynchronous iterator that yields slices of data
|
||||
as they come.
|
||||
|
||||
Python-3.5 available for Python 3.5+ only
|
||||
"""
|
||||
return AsyncStreamIterator(self.readany)
|
||||
|
||||
|
||||
class StreamReader(AsyncStreamReaderMixin):
|
||||
"""An enhancement of asyncio.StreamReader.
|
||||
|
||||
Supports asynchronous iteration by line, chunk or as available::
|
||||
|
||||
async for line in reader:
|
||||
...
|
||||
async for chunk in reader.iter_chunked(1024):
|
||||
...
|
||||
async for slice in reader.iter_any():
|
||||
...
|
||||
|
||||
"""
|
||||
|
||||
total_bytes = 0
|
||||
|
||||
def __init__(self, limit=DEFAULT_LIMIT, timeout=None, loop=None):
|
||||
self._limit = limit
|
||||
if loop is None:
|
||||
loop = asyncio.get_event_loop()
|
||||
self._loop = loop
|
||||
self._buffer = collections.deque()
|
||||
self._buffer_size = 0
|
||||
self._buffer_offset = 0
|
||||
self._eof = False
|
||||
self._waiter = None
|
||||
self._canceller = None
|
||||
self._eof_waiter = None
|
||||
self._exception = None
|
||||
self._timeout = timeout
|
||||
|
||||
def __repr__(self):
|
||||
info = ['StreamReader']
|
||||
if self._buffer_size:
|
||||
info.append('%d bytes' % self._buffer_size)
|
||||
if self._eof:
|
||||
info.append('eof')
|
||||
if self._limit != DEFAULT_LIMIT:
|
||||
info.append('l=%d' % self._limit)
|
||||
if self._waiter:
|
||||
info.append('w=%r' % self._waiter)
|
||||
if self._exception:
|
||||
info.append('e=%r' % self._exception)
|
||||
return '<%s>' % ' '.join(info)
|
||||
|
||||
def exception(self):
|
||||
return self._exception
|
||||
|
||||
def set_exception(self, exc):
|
||||
self._exception = exc
|
||||
|
||||
waiter = self._waiter
|
||||
if waiter is not None:
|
||||
self._waiter = None
|
||||
if not waiter.cancelled():
|
||||
waiter.set_exception(exc)
|
||||
|
||||
canceller = self._canceller
|
||||
if canceller is not None:
|
||||
self._canceller = None
|
||||
canceller.cancel()
|
||||
|
||||
def feed_eof(self):
|
||||
self._eof = True
|
||||
|
||||
waiter = self._waiter
|
||||
if waiter is not None:
|
||||
self._waiter = None
|
||||
if not waiter.cancelled():
|
||||
waiter.set_result(True)
|
||||
|
||||
canceller = self._canceller
|
||||
if canceller is not None:
|
||||
self._canceller = None
|
||||
canceller.cancel()
|
||||
|
||||
waiter = self._eof_waiter
|
||||
if waiter is not None:
|
||||
self._eof_waiter = None
|
||||
if not waiter.cancelled():
|
||||
waiter.set_result(True)
|
||||
|
||||
def is_eof(self):
|
||||
"""Return True if 'feed_eof' was called."""
|
||||
return self._eof
|
||||
|
||||
def at_eof(self):
|
||||
"""Return True if the buffer is empty and 'feed_eof' was called."""
|
||||
return self._eof and not self._buffer
|
||||
|
||||
@asyncio.coroutine
|
||||
def wait_eof(self):
|
||||
if self._eof:
|
||||
return
|
||||
|
||||
assert self._eof_waiter is None
|
||||
self._eof_waiter = helpers.create_future(self._loop)
|
||||
try:
|
||||
yield from self._eof_waiter
|
||||
finally:
|
||||
self._eof_waiter = None
|
||||
|
||||
def unread_data(self, data):
|
||||
""" rollback reading some data from stream, inserting it to buffer head.
|
||||
"""
|
||||
if not data:
|
||||
return
|
||||
|
||||
if self._buffer_offset:
|
||||
self._buffer[0] = self._buffer[0][self._buffer_offset:]
|
||||
self._buffer_offset = 0
|
||||
self._buffer.appendleft(data)
|
||||
self._buffer_size += len(data)
|
||||
|
||||
def feed_data(self, data):
|
||||
assert not self._eof, 'feed_data after feed_eof'
|
||||
|
||||
if not data:
|
||||
return
|
||||
|
||||
self._buffer.append(data)
|
||||
self._buffer_size += len(data)
|
||||
self.total_bytes += len(data)
|
||||
|
||||
waiter = self._waiter
|
||||
if waiter is not None:
|
||||
self._waiter = None
|
||||
if not waiter.cancelled():
|
||||
waiter.set_result(False)
|
||||
|
||||
canceller = self._canceller
|
||||
if canceller is not None:
|
||||
self._canceller = None
|
||||
canceller.cancel()
|
||||
|
||||
@asyncio.coroutine
|
||||
def _wait(self, func_name):
|
||||
# StreamReader uses a future to link the protocol feed_data() method
|
||||
# to a read coroutine. Running two read coroutines at the same time
|
||||
# would have an unexpected behaviour. It would not possible to know
|
||||
# which coroutine would get the next data.
|
||||
if self._waiter is not None:
|
||||
raise RuntimeError('%s() called while another coroutine is '
|
||||
'already waiting for incoming data' % func_name)
|
||||
waiter = self._waiter = helpers.create_future(self._loop)
|
||||
if self._timeout:
|
||||
self._canceller = self._loop.call_later(self._timeout,
|
||||
self.set_exception,
|
||||
asyncio.TimeoutError())
|
||||
try:
|
||||
yield from waiter
|
||||
finally:
|
||||
self._waiter = None
|
||||
if self._canceller is not None:
|
||||
self._canceller.cancel()
|
||||
self._canceller = None
|
||||
|
||||
@asyncio.coroutine
|
||||
def readline(self):
|
||||
if self._exception is not None:
|
||||
raise self._exception
|
||||
|
||||
line = []
|
||||
line_size = 0
|
||||
not_enough = True
|
||||
|
||||
while not_enough:
|
||||
while self._buffer and not_enough:
|
||||
offset = self._buffer_offset
|
||||
ichar = self._buffer[0].find(b'\n', offset) + 1
|
||||
# Read from current offset to found b'\n' or to the end.
|
||||
data = self._read_nowait_chunk(ichar - offset if ichar else -1)
|
||||
line.append(data)
|
||||
line_size += len(data)
|
||||
if ichar:
|
||||
not_enough = False
|
||||
|
||||
if line_size > self._limit:
|
||||
raise ValueError('Line is too long')
|
||||
|
||||
if self._eof:
|
||||
break
|
||||
|
||||
if not_enough:
|
||||
yield from self._wait('readline')
|
||||
|
||||
return b''.join(line)
|
||||
|
||||
@asyncio.coroutine
|
||||
def read(self, n=-1):
|
||||
if self._exception is not None:
|
||||
raise self._exception
|
||||
|
||||
# migration problem; with DataQueue you have to catch
|
||||
# EofStream exception, so common way is to run payload.read() inside
|
||||
# infinite loop. what can cause real infinite loop with StreamReader
|
||||
# lets keep this code one major release.
|
||||
if __debug__:
|
||||
if self._eof and not self._buffer:
|
||||
self._eof_counter = getattr(self, '_eof_counter', 0) + 1
|
||||
if self._eof_counter > 5:
|
||||
stack = traceback.format_stack()
|
||||
internal_logger.warning(
|
||||
'Multiple access to StreamReader in eof state, '
|
||||
'might be infinite loop: \n%s', stack)
|
||||
|
||||
if not n:
|
||||
return EOF_MARKER
|
||||
|
||||
if n < 0:
|
||||
# This used to just loop creating a new waiter hoping to
|
||||
# collect everything in self._buffer, but that would
|
||||
# deadlock if the subprocess sends more than self.limit
|
||||
# bytes. So just call self.readany() until EOF.
|
||||
blocks = []
|
||||
while True:
|
||||
block = yield from self.readany()
|
||||
if not block:
|
||||
break
|
||||
blocks.append(block)
|
||||
return b''.join(blocks)
|
||||
|
||||
if not self._buffer and not self._eof:
|
||||
yield from self._wait('read')
|
||||
|
||||
return self._read_nowait(n)
|
||||
|
||||
@asyncio.coroutine
|
||||
def readany(self):
|
||||
if self._exception is not None:
|
||||
raise self._exception
|
||||
|
||||
if not self._buffer and not self._eof:
|
||||
yield from self._wait('readany')
|
||||
|
||||
return self._read_nowait(-1)
|
||||
|
||||
@asyncio.coroutine
|
||||
def readexactly(self, n):
|
||||
if self._exception is not None:
|
||||
raise self._exception
|
||||
|
||||
blocks = []
|
||||
while n > 0:
|
||||
block = yield from self.read(n)
|
||||
if not block:
|
||||
partial = b''.join(blocks)
|
||||
raise asyncio.streams.IncompleteReadError(
|
||||
partial, len(partial) + n)
|
||||
blocks.append(block)
|
||||
n -= len(block)
|
||||
|
||||
return b''.join(blocks)
|
||||
|
||||
def read_nowait(self, n=-1):
|
||||
# default was changed to be consistent with .read(-1)
|
||||
#
|
||||
# I believe the most users don't know about the method and
|
||||
# they are not affected.
|
||||
assert n is not None, "n should be -1"
|
||||
if self._exception is not None:
|
||||
raise self._exception
|
||||
|
||||
if self._waiter and not self._waiter.done():
|
||||
raise RuntimeError(
|
||||
'Called while some coroutine is waiting for incoming data.')
|
||||
|
||||
return self._read_nowait(n)
|
||||
|
||||
def _read_nowait_chunk(self, n):
|
||||
first_buffer = self._buffer[0]
|
||||
offset = self._buffer_offset
|
||||
if n != -1 and len(first_buffer) - offset > n:
|
||||
data = first_buffer[offset:offset + n]
|
||||
self._buffer_offset += n
|
||||
|
||||
elif offset:
|
||||
self._buffer.popleft()
|
||||
data = first_buffer[offset:]
|
||||
self._buffer_offset = 0
|
||||
|
||||
else:
|
||||
data = self._buffer.popleft()
|
||||
|
||||
self._buffer_size -= len(data)
|
||||
return data
|
||||
|
||||
def _read_nowait(self, n):
|
||||
chunks = []
|
||||
|
||||
while self._buffer:
|
||||
chunk = self._read_nowait_chunk(n)
|
||||
chunks.append(chunk)
|
||||
if n != -1:
|
||||
n -= len(chunk)
|
||||
if n == 0:
|
||||
break
|
||||
|
||||
return b''.join(chunks) if chunks else EOF_MARKER
|
||||
|
||||
|
||||
class EmptyStreamReader(AsyncStreamReaderMixin):
|
||||
|
||||
def exception(self):
|
||||
return None
|
||||
|
||||
def set_exception(self, exc):
|
||||
pass
|
||||
|
||||
def feed_eof(self):
|
||||
pass
|
||||
|
||||
def is_eof(self):
|
||||
return True
|
||||
|
||||
def at_eof(self):
|
||||
return True
|
||||
|
||||
@asyncio.coroutine
|
||||
def wait_eof(self):
|
||||
return
|
||||
|
||||
def feed_data(self, data):
|
||||
pass
|
||||
|
||||
@asyncio.coroutine
|
||||
def readline(self):
|
||||
return EOF_MARKER
|
||||
|
||||
@asyncio.coroutine
|
||||
def read(self, n=-1):
|
||||
return EOF_MARKER
|
||||
|
||||
@asyncio.coroutine
|
||||
def readany(self):
|
||||
return EOF_MARKER
|
||||
|
||||
@asyncio.coroutine
|
||||
def readexactly(self, n):
|
||||
raise asyncio.streams.IncompleteReadError(b'', n)
|
||||
|
||||
def read_nowait(self):
|
||||
return EOF_MARKER
|
||||
|
||||
|
||||
class DataQueue:
|
||||
"""DataQueue is a general-purpose blocking queue with one reader."""
|
||||
|
||||
def __init__(self, *, loop=None):
|
||||
self._loop = loop
|
||||
self._eof = False
|
||||
self._waiter = None
|
||||
self._exception = None
|
||||
self._size = 0
|
||||
self._buffer = collections.deque()
|
||||
|
||||
def is_eof(self):
|
||||
return self._eof
|
||||
|
||||
def at_eof(self):
|
||||
return self._eof and not self._buffer
|
||||
|
||||
def exception(self):
|
||||
return self._exception
|
||||
|
||||
def set_exception(self, exc):
|
||||
self._exception = exc
|
||||
|
||||
waiter = self._waiter
|
||||
if waiter is not None:
|
||||
self._waiter = None
|
||||
if not waiter.done():
|
||||
waiter.set_exception(exc)
|
||||
|
||||
def feed_data(self, data, size=0):
|
||||
self._size += size
|
||||
self._buffer.append((data, size))
|
||||
|
||||
waiter = self._waiter
|
||||
if waiter is not None:
|
||||
self._waiter = None
|
||||
if not waiter.cancelled():
|
||||
waiter.set_result(True)
|
||||
|
||||
def feed_eof(self):
|
||||
self._eof = True
|
||||
|
||||
waiter = self._waiter
|
||||
if waiter is not None:
|
||||
self._waiter = None
|
||||
if not waiter.cancelled():
|
||||
waiter.set_result(False)
|
||||
|
||||
@asyncio.coroutine
|
||||
def read(self):
|
||||
if not self._buffer and not self._eof:
|
||||
if self._exception is not None:
|
||||
raise self._exception
|
||||
|
||||
assert not self._waiter
|
||||
self._waiter = helpers.create_future(self._loop)
|
||||
try:
|
||||
yield from self._waiter
|
||||
except (asyncio.CancelledError, asyncio.TimeoutError):
|
||||
self._waiter = None
|
||||
raise
|
||||
|
||||
if self._buffer:
|
||||
data, size = self._buffer.popleft()
|
||||
self._size -= size
|
||||
return data
|
||||
else:
|
||||
if self._exception is not None:
|
||||
raise self._exception
|
||||
else:
|
||||
raise EofStream
|
||||
|
||||
if PY_35:
|
||||
def __aiter__(self):
|
||||
return AsyncStreamIterator(self.read)
|
||||
|
||||
if not PY_352: # pragma: no cover
|
||||
__aiter__ = asyncio.coroutine(__aiter__)
|
||||
|
||||
|
||||
class ChunksQueue(DataQueue):
|
||||
"""Like a :class:`DataQueue`, but for binary chunked data transfer."""
|
||||
|
||||
@asyncio.coroutine
|
||||
def read(self):
|
||||
try:
|
||||
return (yield from super().read())
|
||||
except EofStream:
|
||||
return EOF_MARKER
|
||||
|
||||
readany = read
|
||||
|
||||
|
||||
def maybe_resume(func):
|
||||
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
@asyncio.coroutine
|
||||
@functools.wraps(func)
|
||||
def wrapper(self, *args, **kw):
|
||||
result = yield from func(self, *args, **kw)
|
||||
self._check_buffer_size()
|
||||
return result
|
||||
else:
|
||||
@functools.wraps(func)
|
||||
def wrapper(self, *args, **kw):
|
||||
result = func(self, *args, **kw)
|
||||
self._check_buffer_size()
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class FlowControlStreamReader(StreamReader):
|
||||
|
||||
def __init__(self, stream, limit=DEFAULT_LIMIT, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self._stream = stream
|
||||
self._b_limit = limit * 2
|
||||
|
||||
# resume transport reading
|
||||
if stream.paused:
|
||||
try:
|
||||
self._stream.transport.resume_reading()
|
||||
except (AttributeError, NotImplementedError):
|
||||
pass
|
||||
else:
|
||||
self._stream.paused = False
|
||||
|
||||
def _check_buffer_size(self):
|
||||
if self._stream.paused:
|
||||
if self._buffer_size < self._b_limit:
|
||||
try:
|
||||
self._stream.transport.resume_reading()
|
||||
except (AttributeError, NotImplementedError):
|
||||
pass
|
||||
else:
|
||||
self._stream.paused = False
|
||||
else:
|
||||
if self._buffer_size > self._b_limit:
|
||||
try:
|
||||
self._stream.transport.pause_reading()
|
||||
except (AttributeError, NotImplementedError):
|
||||
pass
|
||||
else:
|
||||
self._stream.paused = True
|
||||
|
||||
def feed_data(self, data, size=0):
|
||||
has_waiter = self._waiter is not None and not self._waiter.cancelled()
|
||||
|
||||
super().feed_data(data)
|
||||
|
||||
if (not self._stream.paused and
|
||||
not has_waiter and self._buffer_size > self._b_limit):
|
||||
try:
|
||||
self._stream.transport.pause_reading()
|
||||
except (AttributeError, NotImplementedError):
|
||||
pass
|
||||
else:
|
||||
self._stream.paused = True
|
||||
|
||||
@maybe_resume
|
||||
@asyncio.coroutine
|
||||
def read(self, n=-1):
|
||||
return (yield from super().read(n))
|
||||
|
||||
@maybe_resume
|
||||
@asyncio.coroutine
|
||||
def readline(self):
|
||||
return (yield from super().readline())
|
||||
|
||||
@maybe_resume
|
||||
@asyncio.coroutine
|
||||
def readany(self):
|
||||
return (yield from super().readany())
|
||||
|
||||
@maybe_resume
|
||||
@asyncio.coroutine
|
||||
def readexactly(self, n):
|
||||
return (yield from super().readexactly(n))
|
||||
|
||||
@maybe_resume
|
||||
def read_nowait(self, n=-1):
|
||||
return super().read_nowait(n)
|
||||
|
||||
|
||||
class FlowControlDataQueue(DataQueue):
|
||||
"""FlowControlDataQueue resumes and pauses an underlying stream.
|
||||
|
||||
It is a destination for parsed data."""
|
||||
|
||||
def __init__(self, stream, *, limit=DEFAULT_LIMIT, loop=None):
|
||||
super().__init__(loop=loop)
|
||||
|
||||
self._stream = stream
|
||||
self._limit = limit * 2
|
||||
|
||||
# resume transport reading
|
||||
if stream.paused:
|
||||
try:
|
||||
self._stream.transport.resume_reading()
|
||||
except (AttributeError, NotImplementedError):
|
||||
pass
|
||||
else:
|
||||
self._stream.paused = False
|
||||
|
||||
def feed_data(self, data, size):
|
||||
has_waiter = self._waiter is not None and not self._waiter.cancelled()
|
||||
|
||||
super().feed_data(data, size)
|
||||
|
||||
if (not self._stream.paused and
|
||||
not has_waiter and self._size > self._limit):
|
||||
try:
|
||||
self._stream.transport.pause_reading()
|
||||
except (AttributeError, NotImplementedError):
|
||||
pass
|
||||
else:
|
||||
self._stream.paused = True
|
||||
|
||||
@asyncio.coroutine
|
||||
def read(self):
|
||||
result = yield from super().read()
|
||||
|
||||
if self._stream.paused:
|
||||
if self._size < self._limit:
|
||||
try:
|
||||
self._stream.transport.resume_reading()
|
||||
except (AttributeError, NotImplementedError):
|
||||
pass
|
||||
else:
|
||||
self._stream.paused = False
|
||||
else:
|
||||
if self._size > self._limit:
|
||||
try:
|
||||
self._stream.transport.pause_reading()
|
||||
except (AttributeError, NotImplementedError):
|
||||
pass
|
||||
else:
|
||||
self._stream.paused = True
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class FlowControlChunksQueue(FlowControlDataQueue):
|
||||
|
||||
@asyncio.coroutine
|
||||
def read(self):
|
||||
try:
|
||||
return (yield from super().read())
|
||||
except EofStream:
|
||||
return EOF_MARKER
|
||||
|
||||
readany = read
|
||||
|
|
@ -0,0 +1,485 @@
|
|||
"""Utilities shared by tests."""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import functools
|
||||
import gc
|
||||
import socket
|
||||
import sys
|
||||
import unittest
|
||||
from unittest import mock
|
||||
|
||||
from multidict import CIMultiDict
|
||||
|
||||
import aiohttp
|
||||
|
||||
from . import ClientSession, hdrs
|
||||
from .helpers import sentinel
|
||||
from .protocol import HttpVersion, RawRequestMessage
|
||||
from .signals import Signal
|
||||
from .web import Application, Request
|
||||
|
||||
PY_35 = sys.version_info >= (3, 5)
|
||||
|
||||
|
||||
def run_briefly(loop):
|
||||
@asyncio.coroutine
|
||||
def once():
|
||||
pass
|
||||
t = asyncio.Task(once(), loop=loop)
|
||||
loop.run_until_complete(t)
|
||||
|
||||
|
||||
def unused_port():
|
||||
"""Return a port that is unused on the current host."""
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(('127.0.0.1', 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
class TestServer:
|
||||
def __init__(self, app, *, scheme="http", host='127.0.0.1'):
|
||||
self.app = app
|
||||
self._loop = app.loop
|
||||
self.port = None
|
||||
self.server = None
|
||||
self.handler = None
|
||||
self._root = None
|
||||
self.host = host
|
||||
self.scheme = scheme
|
||||
self._closed = False
|
||||
|
||||
@asyncio.coroutine
|
||||
def start_server(self, **kwargs):
|
||||
if self.server:
|
||||
return
|
||||
self.port = unused_port()
|
||||
self._root = '{}://{}:{}'.format(self.scheme, self.host, self.port)
|
||||
self.handler = self.app.make_handler(**kwargs)
|
||||
self.server = yield from self._loop.create_server(self.handler,
|
||||
self.host,
|
||||
self.port)
|
||||
|
||||
def make_url(self, path):
|
||||
return self._root + path
|
||||
|
||||
@asyncio.coroutine
|
||||
def close(self):
|
||||
"""Close all fixtures created by the test client.
|
||||
|
||||
After that point, the TestClient is no longer usable.
|
||||
|
||||
This is an idempotent function: running close multiple times
|
||||
will not have any additional effects.
|
||||
|
||||
close is also run when the object is garbage collected, and on
|
||||
exit when used as a context manager.
|
||||
|
||||
"""
|
||||
if self.server is not None and not self._closed:
|
||||
self.server.close()
|
||||
yield from self.server.wait_closed()
|
||||
yield from self.app.shutdown()
|
||||
yield from self.handler.finish_connections()
|
||||
yield from self.app.cleanup()
|
||||
self._root = None
|
||||
self.port = None
|
||||
self._closed = True
|
||||
|
||||
def __enter__(self):
|
||||
self._loop.run_until_complete(self.start_server())
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self._loop.run_until_complete(self.close())
|
||||
|
||||
if PY_35:
|
||||
@asyncio.coroutine
|
||||
def __aenter__(self):
|
||||
yield from self.start_server()
|
||||
return self
|
||||
|
||||
@asyncio.coroutine
|
||||
def __aexit__(self, exc_type, exc_value, traceback):
|
||||
yield from self.close()
|
||||
|
||||
|
||||
class TestClient:
|
||||
"""
|
||||
A test client implementation, for a aiohttp.web.Application.
|
||||
|
||||
:param app: the aiohttp.web application passed to create_test_server
|
||||
|
||||
:type app: aiohttp.web.Application
|
||||
|
||||
:param protocol: http or https
|
||||
|
||||
:type protocol: str
|
||||
|
||||
TestClient can also be used as a contextmanager, returning
|
||||
the instance of itself instantiated.
|
||||
"""
|
||||
|
||||
def __init__(self, app_or_server, *, scheme=sentinel, host=sentinel):
|
||||
if isinstance(app_or_server, TestServer):
|
||||
if scheme is not sentinel or host is not sentinel:
|
||||
raise ValueError("scheme and host are mutable exclusive "
|
||||
"with TestServer parameter")
|
||||
self._server = app_or_server
|
||||
elif isinstance(app_or_server, Application):
|
||||
scheme = "http" if scheme is sentinel else scheme
|
||||
host = '127.0.0.1' if host is sentinel else host
|
||||
self._server = TestServer(app_or_server,
|
||||
scheme=scheme, host=host)
|
||||
else:
|
||||
raise TypeError("app_or_server should be either web.Application "
|
||||
"or TestServer instance")
|
||||
self._loop = self._server.app.loop
|
||||
self._session = ClientSession(
|
||||
loop=self._loop,
|
||||
cookie_jar=aiohttp.CookieJar(unsafe=True,
|
||||
loop=self._loop))
|
||||
self._closed = False
|
||||
self._responses = []
|
||||
|
||||
@asyncio.coroutine
|
||||
def start_server(self):
|
||||
yield from self._server.start_server()
|
||||
|
||||
@property
|
||||
def app(self):
|
||||
return self._server.app
|
||||
|
||||
@property
|
||||
def host(self):
|
||||
return self._server.host
|
||||
|
||||
@property
|
||||
def port(self):
|
||||
return self._server.port
|
||||
|
||||
@property
|
||||
def handler(self):
|
||||
return self._server.handler
|
||||
|
||||
@property
|
||||
def server(self):
|
||||
return self._server.server
|
||||
|
||||
@property
|
||||
def session(self):
|
||||
"""A raw handler to the aiohttp.ClientSession.
|
||||
|
||||
Unlike the methods on the TestClient, client session requests
|
||||
do not automatically include the host in the url queried, and
|
||||
will require an absolute path to the resource.
|
||||
|
||||
"""
|
||||
return self._session
|
||||
|
||||
def make_url(self, path):
|
||||
return self._server.make_url(path)
|
||||
|
||||
@asyncio.coroutine
|
||||
def request(self, method, path, *args, **kwargs):
|
||||
"""Routes a request to the http server.
|
||||
|
||||
The interface is identical to asyncio.ClientSession.request,
|
||||
except the loop kwarg is overridden by the instance used by the
|
||||
application.
|
||||
|
||||
"""
|
||||
resp = yield from self._session.request(
|
||||
method, self.make_url(path), *args, **kwargs
|
||||
)
|
||||
# save it to close later
|
||||
self._responses.append(resp)
|
||||
return resp
|
||||
|
||||
def get(self, path, *args, **kwargs):
|
||||
"""Perform an HTTP GET request."""
|
||||
return self.request(hdrs.METH_GET, path, *args, **kwargs)
|
||||
|
||||
def post(self, path, *args, **kwargs):
|
||||
"""Perform an HTTP POST request."""
|
||||
return self.request(hdrs.METH_POST, path, *args, **kwargs)
|
||||
|
||||
def options(self, path, *args, **kwargs):
|
||||
"""Perform an HTTP OPTIONS request."""
|
||||
return self.request(hdrs.METH_OPTIONS, path, *args, **kwargs)
|
||||
|
||||
def head(self, path, *args, **kwargs):
|
||||
"""Perform an HTTP HEAD request."""
|
||||
return self.request(hdrs.METH_HEAD, path, *args, **kwargs)
|
||||
|
||||
def put(self, path, *args, **kwargs):
|
||||
"""Perform an HTTP PUT request."""
|
||||
return self.request(hdrs.METH_PUT, path, *args, **kwargs)
|
||||
|
||||
def patch(self, path, *args, **kwargs):
|
||||
"""Perform an HTTP PATCH request."""
|
||||
return self.request(hdrs.METH_PATCH, path, *args, **kwargs)
|
||||
|
||||
def delete(self, path, *args, **kwargs):
|
||||
"""Perform an HTTP PATCH request."""
|
||||
return self.request(hdrs.METH_DELETE, path, *args, **kwargs)
|
||||
|
||||
def ws_connect(self, path, *args, **kwargs):
|
||||
"""Initiate websocket connection.
|
||||
|
||||
The api is identical to aiohttp.ClientSession.ws_connect.
|
||||
|
||||
"""
|
||||
return self._session.ws_connect(
|
||||
self.make_url(path), *args, **kwargs
|
||||
)
|
||||
|
||||
@asyncio.coroutine
|
||||
def close(self):
|
||||
"""Close all fixtures created by the test client.
|
||||
|
||||
After that point, the TestClient is no longer usable.
|
||||
|
||||
This is an idempotent function: running close multiple times
|
||||
will not have any additional effects.
|
||||
|
||||
close is also run on exit when used as a(n) (asynchronous)
|
||||
context manager.
|
||||
|
||||
"""
|
||||
if not self._closed:
|
||||
for resp in self._responses:
|
||||
resp.close()
|
||||
yield from self._session.close()
|
||||
yield from self._server.close()
|
||||
self._closed = True
|
||||
|
||||
def __enter__(self):
|
||||
self._loop.run_until_complete(self.start_server())
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self._loop.run_until_complete(self.close())
|
||||
|
||||
if PY_35:
|
||||
@asyncio.coroutine
|
||||
def __aenter__(self):
|
||||
yield from self.start_server()
|
||||
return self
|
||||
|
||||
@asyncio.coroutine
|
||||
def __aexit__(self, exc_type, exc_value, traceback):
|
||||
yield from self.close()
|
||||
|
||||
|
||||
class AioHTTPTestCase(unittest.TestCase):
|
||||
"""A base class to allow for unittest web applications using
|
||||
aiohttp.
|
||||
|
||||
Provides the following:
|
||||
|
||||
* self.client (aiohttp.test_utils.TestClient): an aiohttp test client.
|
||||
* self.loop (asyncio.BaseEventLoop): the event loop in which the
|
||||
application and server are running.
|
||||
* self.app (aiohttp.web.Application): the application returned by
|
||||
self.get_app()
|
||||
|
||||
Note that the TestClient's methods are asynchronous: you have to
|
||||
execute function on the test client using asynchronous methods.
|
||||
"""
|
||||
|
||||
def get_app(self, loop):
|
||||
"""
|
||||
This method should be overridden
|
||||
to return the aiohttp.web.Application
|
||||
object to test.
|
||||
|
||||
:param loop: the event_loop to use
|
||||
:type loop: asyncio.BaseEventLoop
|
||||
"""
|
||||
pass # pragma: no cover
|
||||
|
||||
def setUp(self):
|
||||
self.loop = setup_test_loop()
|
||||
self.app = self.get_app(self.loop)
|
||||
self.client = TestClient(self.app)
|
||||
self.loop.run_until_complete(self.client.start_server())
|
||||
|
||||
def tearDown(self):
|
||||
self.loop.run_until_complete(self.client.close())
|
||||
teardown_test_loop(self.loop)
|
||||
|
||||
|
||||
def unittest_run_loop(func):
|
||||
"""A decorator dedicated to use with asynchronous methods of an
|
||||
AioHTTPTestCase.
|
||||
|
||||
Handles executing an asynchronous function, using
|
||||
the self.loop of the AioHTTPTestCase.
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def new_func(self):
|
||||
return self.loop.run_until_complete(func(self))
|
||||
|
||||
return new_func
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def loop_context(loop_factory=asyncio.new_event_loop):
|
||||
"""A contextmanager that creates an event_loop, for test purposes.
|
||||
|
||||
Handles the creation and cleanup of a test loop.
|
||||
"""
|
||||
loop = setup_test_loop(loop_factory)
|
||||
yield loop
|
||||
teardown_test_loop(loop)
|
||||
|
||||
|
||||
def setup_test_loop(loop_factory=asyncio.new_event_loop):
|
||||
"""Create and return an asyncio.BaseEventLoop
|
||||
instance.
|
||||
|
||||
The caller should also call teardown_test_loop,
|
||||
once they are done with the loop.
|
||||
"""
|
||||
loop = loop_factory()
|
||||
asyncio.set_event_loop(None)
|
||||
return loop
|
||||
|
||||
|
||||
def teardown_test_loop(loop):
|
||||
"""Teardown and cleanup an event_loop created
|
||||
by setup_test_loop.
|
||||
|
||||
:param loop: the loop to teardown
|
||||
:type loop: asyncio.BaseEventLoop
|
||||
"""
|
||||
closed = loop.is_closed()
|
||||
if not closed:
|
||||
loop.call_soon(loop.stop)
|
||||
loop.run_forever()
|
||||
loop.close()
|
||||
gc.collect()
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
|
||||
def _create_app_mock():
|
||||
app = mock.Mock()
|
||||
app._debug = False
|
||||
app.on_response_prepare = Signal(app)
|
||||
return app
|
||||
|
||||
|
||||
def _create_transport(sslcontext=None):
|
||||
transport = mock.Mock()
|
||||
|
||||
def get_extra_info(key):
|
||||
if key == 'sslcontext':
|
||||
return sslcontext
|
||||
else:
|
||||
return None
|
||||
|
||||
transport.get_extra_info.side_effect = get_extra_info
|
||||
return transport
|
||||
|
||||
|
||||
def make_mocked_request(method, path, headers=None, *,
|
||||
version=HttpVersion(1, 1), closing=False,
|
||||
app=None,
|
||||
reader=sentinel,
|
||||
writer=sentinel,
|
||||
transport=sentinel,
|
||||
payload=sentinel,
|
||||
sslcontext=None,
|
||||
secure_proxy_ssl_header=None):
|
||||
"""Creates mocked web.Request testing purposes.
|
||||
|
||||
Useful in unit tests, when spinning full web server is overkill or
|
||||
specific conditions and errors are hard to trigger.
|
||||
|
||||
:param method: str, that represents HTTP method, like; GET, POST.
|
||||
:type method: str
|
||||
|
||||
:param path: str, The URL including *PATH INFO* without the host or scheme
|
||||
:type path: str
|
||||
|
||||
:param headers: mapping containing the headers. Can be anything accepted
|
||||
by the multidict.CIMultiDict constructor.
|
||||
:type headers: dict, multidict.CIMultiDict, list of pairs
|
||||
|
||||
:param version: namedtuple with encoded HTTP version
|
||||
:type version: aiohttp.protocol.HttpVersion
|
||||
|
||||
:param closing: flag indicates that connection should be closed after
|
||||
response.
|
||||
:type closing: bool
|
||||
|
||||
:param app: the aiohttp.web application attached for fake request
|
||||
:type app: aiohttp.web.Application
|
||||
|
||||
:param reader: object for storing and managing incoming data
|
||||
:type reader: aiohttp.parsers.StreamParser
|
||||
|
||||
:param writer: object for managing outcoming data
|
||||
:type wirter: aiohttp.parsers.StreamWriter
|
||||
|
||||
:param transport: asyncio transport instance
|
||||
:type transport: asyncio.transports.Transport
|
||||
|
||||
:param payload: raw payload reader object
|
||||
:type payload: aiohttp.streams.FlowControlStreamReader
|
||||
|
||||
:param sslcontext: ssl.SSLContext object, for HTTPS connection
|
||||
:type sslcontext: ssl.SSLContext
|
||||
|
||||
:param secure_proxy_ssl_header: A tuple representing a HTTP header/value
|
||||
combination that signifies a request is secure.
|
||||
:type secure_proxy_ssl_header: tuple
|
||||
|
||||
"""
|
||||
|
||||
if version < HttpVersion(1, 1):
|
||||
closing = True
|
||||
|
||||
if headers:
|
||||
hdrs = CIMultiDict(headers)
|
||||
raw_hdrs = [
|
||||
(k.encode('utf-8'), v.encode('utf-8')) for k, v in headers.items()]
|
||||
else:
|
||||
hdrs = CIMultiDict()
|
||||
raw_hdrs = []
|
||||
|
||||
message = RawRequestMessage(method, path, version, hdrs,
|
||||
raw_hdrs, closing, False)
|
||||
if app is None:
|
||||
app = _create_app_mock()
|
||||
|
||||
if reader is sentinel:
|
||||
reader = mock.Mock()
|
||||
|
||||
if writer is sentinel:
|
||||
writer = mock.Mock()
|
||||
|
||||
if transport is sentinel:
|
||||
transport = _create_transport(sslcontext)
|
||||
|
||||
if payload is sentinel:
|
||||
payload = mock.Mock()
|
||||
|
||||
req = Request(app, message, payload,
|
||||
transport, reader, writer,
|
||||
secure_proxy_ssl_header=secure_proxy_ssl_header)
|
||||
|
||||
return req
|
||||
|
||||
|
||||
def make_mocked_coro(return_value=sentinel, raise_exception=sentinel):
|
||||
"""Creates a coroutine mock."""
|
||||
@asyncio.coroutine
|
||||
def mock_coro(*args, **kwargs):
|
||||
if raise_exception is not sentinel:
|
||||
raise raise_exception
|
||||
return return_value
|
||||
|
||||
return mock.Mock(wraps=mock_coro)
|
||||
|
|
@ -0,0 +1,376 @@
|
|||
import asyncio
|
||||
import sys
|
||||
import warnings
|
||||
from argparse import ArgumentParser
|
||||
from importlib import import_module
|
||||
|
||||
from . import hdrs, web_exceptions, web_reqrep, web_urldispatcher, web_ws
|
||||
from .abc import AbstractMatchInfo, AbstractRouter
|
||||
from .helpers import sentinel
|
||||
from .log import web_logger
|
||||
from .protocol import HttpVersion # noqa
|
||||
from .server import ServerHttpProtocol
|
||||
from .signals import PostSignal, PreSignal, Signal
|
||||
from .web_exceptions import * # noqa
|
||||
from .web_reqrep import * # noqa
|
||||
from .web_urldispatcher import * # noqa
|
||||
from .web_ws import * # noqa
|
||||
|
||||
|
||||
__all__ = (web_reqrep.__all__ +
|
||||
web_exceptions.__all__ +
|
||||
web_urldispatcher.__all__ +
|
||||
web_ws.__all__ +
|
||||
('Application', 'RequestHandler',
|
||||
'RequestHandlerFactory', 'HttpVersion',
|
||||
'MsgType'))
|
||||
|
||||
|
||||
class RequestHandler(ServerHttpProtocol):
|
||||
|
||||
_meth = 'none'
|
||||
_path = 'none'
|
||||
|
||||
def __init__(self, manager, app, router, *,
|
||||
secure_proxy_ssl_header=None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._manager = manager
|
||||
self._app = app
|
||||
self._router = router
|
||||
self._middlewares = app.middlewares
|
||||
self._secure_proxy_ssl_header = secure_proxy_ssl_header
|
||||
|
||||
def __repr__(self):
|
||||
return "<{} {}:{} {}>".format(
|
||||
self.__class__.__name__, self._meth, self._path,
|
||||
'connected' if self.transport is not None else 'disconnected')
|
||||
|
||||
def connection_made(self, transport):
|
||||
super().connection_made(transport)
|
||||
|
||||
self._manager.connection_made(self, transport)
|
||||
|
||||
def connection_lost(self, exc):
|
||||
self._manager.connection_lost(self, exc)
|
||||
|
||||
super().connection_lost(exc)
|
||||
|
||||
@asyncio.coroutine
|
||||
def handle_request(self, message, payload):
|
||||
self._manager._requests_count += 1
|
||||
if self.access_log:
|
||||
now = self._loop.time()
|
||||
|
||||
app = self._app
|
||||
request = web_reqrep.Request(
|
||||
app, message, payload,
|
||||
self.transport, self.reader, self.writer,
|
||||
secure_proxy_ssl_header=self._secure_proxy_ssl_header)
|
||||
self._meth = request.method
|
||||
self._path = request.path
|
||||
try:
|
||||
match_info = yield from self._router.resolve(request)
|
||||
|
||||
assert isinstance(match_info, AbstractMatchInfo), match_info
|
||||
|
||||
resp = None
|
||||
request._match_info = match_info
|
||||
expect = request.headers.get(hdrs.EXPECT)
|
||||
if expect:
|
||||
resp = (
|
||||
yield from match_info.expect_handler(request))
|
||||
|
||||
if resp is None:
|
||||
handler = match_info.handler
|
||||
for factory in reversed(self._middlewares):
|
||||
handler = yield from factory(app, handler)
|
||||
resp = yield from handler(request)
|
||||
|
||||
assert isinstance(resp, web_reqrep.StreamResponse), \
|
||||
("Handler {!r} should return response instance, "
|
||||
"got {!r} [middlewares {!r}]").format(
|
||||
match_info.handler, type(resp), self._middlewares)
|
||||
except web_exceptions.HTTPException as exc:
|
||||
resp = exc
|
||||
|
||||
resp_msg = yield from resp.prepare(request)
|
||||
yield from resp.write_eof()
|
||||
|
||||
# notify server about keep-alive
|
||||
self.keep_alive(resp.keep_alive)
|
||||
|
||||
# log access
|
||||
if self.access_log:
|
||||
self.log_access(message, None, resp_msg, self._loop.time() - now)
|
||||
|
||||
# for repr
|
||||
self._meth = 'none'
|
||||
self._path = 'none'
|
||||
|
||||
|
||||
class RequestHandlerFactory:
|
||||
|
||||
def __init__(self, app, router, *,
|
||||
handler=RequestHandler, loop=None,
|
||||
secure_proxy_ssl_header=None, **kwargs):
|
||||
self._app = app
|
||||
self._router = router
|
||||
self._handler = handler
|
||||
self._loop = loop
|
||||
self._connections = {}
|
||||
self._secure_proxy_ssl_header = secure_proxy_ssl_header
|
||||
self._kwargs = kwargs
|
||||
self._kwargs.setdefault('logger', app.logger)
|
||||
self._requests_count = 0
|
||||
|
||||
@property
|
||||
def requests_count(self):
|
||||
"""Number of processed requests."""
|
||||
return self._requests_count
|
||||
|
||||
@property
|
||||
def secure_proxy_ssl_header(self):
|
||||
return self._secure_proxy_ssl_header
|
||||
|
||||
@property
|
||||
def connections(self):
|
||||
return list(self._connections.keys())
|
||||
|
||||
def connection_made(self, handler, transport):
|
||||
self._connections[handler] = transport
|
||||
|
||||
def connection_lost(self, handler, exc=None):
|
||||
if handler in self._connections:
|
||||
del self._connections[handler]
|
||||
|
||||
@asyncio.coroutine
|
||||
def finish_connections(self, timeout=None):
|
||||
coros = [conn.shutdown(timeout) for conn in self._connections]
|
||||
yield from asyncio.gather(*coros, loop=self._loop)
|
||||
self._connections.clear()
|
||||
|
||||
def __call__(self):
|
||||
return self._handler(
|
||||
self, self._app, self._router, loop=self._loop,
|
||||
secure_proxy_ssl_header=self._secure_proxy_ssl_header,
|
||||
**self._kwargs)
|
||||
|
||||
|
||||
class Application(dict):
|
||||
|
||||
def __init__(self, *, logger=web_logger, loop=None,
|
||||
router=None, handler_factory=RequestHandlerFactory,
|
||||
middlewares=(), debug=False):
|
||||
if loop is None:
|
||||
loop = asyncio.get_event_loop()
|
||||
if router is None:
|
||||
router = web_urldispatcher.UrlDispatcher()
|
||||
assert isinstance(router, AbstractRouter), router
|
||||
|
||||
self._debug = debug
|
||||
self._router = router
|
||||
self._handler_factory = handler_factory
|
||||
self._loop = loop
|
||||
self.logger = logger
|
||||
|
||||
self._middlewares = list(middlewares)
|
||||
|
||||
self._on_pre_signal = PreSignal()
|
||||
self._on_post_signal = PostSignal()
|
||||
self._on_response_prepare = Signal(self)
|
||||
self._on_startup = Signal(self)
|
||||
self._on_shutdown = Signal(self)
|
||||
self._on_cleanup = Signal(self)
|
||||
|
||||
@property
|
||||
def debug(self):
|
||||
return self._debug
|
||||
|
||||
@property
|
||||
def on_response_prepare(self):
|
||||
return self._on_response_prepare
|
||||
|
||||
@property
|
||||
def on_pre_signal(self):
|
||||
return self._on_pre_signal
|
||||
|
||||
@property
|
||||
def on_post_signal(self):
|
||||
return self._on_post_signal
|
||||
|
||||
@property
|
||||
def on_startup(self):
|
||||
return self._on_startup
|
||||
|
||||
@property
|
||||
def on_shutdown(self):
|
||||
return self._on_shutdown
|
||||
|
||||
@property
|
||||
def on_cleanup(self):
|
||||
return self._on_cleanup
|
||||
|
||||
@property
|
||||
def router(self):
|
||||
return self._router
|
||||
|
||||
@property
|
||||
def loop(self):
|
||||
return self._loop
|
||||
|
||||
@property
|
||||
def middlewares(self):
|
||||
return self._middlewares
|
||||
|
||||
def make_handler(self, **kwargs):
|
||||
debug = kwargs.pop('debug', sentinel)
|
||||
if debug is not sentinel:
|
||||
warnings.warn(
|
||||
"`debug` parameter is deprecated. "
|
||||
"Use Application's debug mode instead", DeprecationWarning)
|
||||
if debug != self.debug:
|
||||
raise ValueError(
|
||||
"The value of `debug` parameter conflicts with the debug "
|
||||
"settings of the `Application` instance. The "
|
||||
"application's debug mode setting should be used instead "
|
||||
"as a single point to setup a debug mode. For more "
|
||||
"information please check "
|
||||
"http://aiohttp.readthedocs.io/en/stable/"
|
||||
"web_reference.html#aiohttp.web.Application"
|
||||
)
|
||||
return self._handler_factory(self, self.router, debug=self.debug,
|
||||
loop=self.loop, **kwargs)
|
||||
|
||||
@asyncio.coroutine
|
||||
def startup(self):
|
||||
"""Causes on_startup signal
|
||||
|
||||
Should be called in the event loop along with the request handler.
|
||||
"""
|
||||
yield from self.on_startup.send(self)
|
||||
|
||||
@asyncio.coroutine
|
||||
def shutdown(self):
|
||||
"""Causes on_shutdown signal
|
||||
|
||||
Should be called before cleanup()
|
||||
"""
|
||||
yield from self.on_shutdown.send(self)
|
||||
|
||||
@asyncio.coroutine
|
||||
def cleanup(self):
|
||||
"""Causes on_cleanup signal
|
||||
|
||||
Should be called after shutdown()
|
||||
"""
|
||||
yield from self.on_cleanup.send(self)
|
||||
|
||||
@asyncio.coroutine
|
||||
def finish(self):
|
||||
"""Finalize an application.
|
||||
|
||||
Deprecated alias for .cleanup()
|
||||
"""
|
||||
warnings.warn("Use .cleanup() instead", DeprecationWarning)
|
||||
yield from self.cleanup()
|
||||
|
||||
def register_on_finish(self, func, *args, **kwargs):
|
||||
warnings.warn("Use .on_cleanup.append() instead", DeprecationWarning)
|
||||
self.on_cleanup.append(lambda app: func(app, *args, **kwargs))
|
||||
|
||||
def copy(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def __call__(self):
|
||||
"""gunicorn compatibility"""
|
||||
return self
|
||||
|
||||
def __repr__(self):
|
||||
return "<Application>"
|
||||
|
||||
|
||||
def run_app(app, *, host='0.0.0.0', port=None,
|
||||
shutdown_timeout=60.0, ssl_context=None,
|
||||
print=print, backlog=128):
|
||||
"""Run an app locally"""
|
||||
if port is None:
|
||||
if not ssl_context:
|
||||
port = 8080
|
||||
else:
|
||||
port = 8443
|
||||
|
||||
loop = app.loop
|
||||
|
||||
handler = app.make_handler()
|
||||
server = loop.create_server(handler, host, port, ssl=ssl_context,
|
||||
backlog=backlog)
|
||||
srv, startup_res = loop.run_until_complete(asyncio.gather(server,
|
||||
app.startup(),
|
||||
loop=loop))
|
||||
|
||||
scheme = 'https' if ssl_context else 'http'
|
||||
print("======== Running on {scheme}://{host}:{port}/ ========\n"
|
||||
"(Press CTRL+C to quit)".format(
|
||||
scheme=scheme, host=host, port=port))
|
||||
|
||||
try:
|
||||
loop.run_forever()
|
||||
except KeyboardInterrupt: # pragma: no cover
|
||||
pass
|
||||
finally:
|
||||
srv.close()
|
||||
loop.run_until_complete(srv.wait_closed())
|
||||
loop.run_until_complete(app.shutdown())
|
||||
loop.run_until_complete(handler.finish_connections(shutdown_timeout))
|
||||
loop.run_until_complete(app.cleanup())
|
||||
loop.close()
|
||||
|
||||
|
||||
def main(argv):
|
||||
arg_parser = ArgumentParser(
|
||||
description="aiohttp.web Application server",
|
||||
prog="aiohttp.web"
|
||||
)
|
||||
arg_parser.add_argument(
|
||||
"entry_func",
|
||||
help=("Callable returning the `aiohttp.web.Application` instance to "
|
||||
"run. Should be specified in the 'module:function' syntax."),
|
||||
metavar="entry-func"
|
||||
)
|
||||
arg_parser.add_argument(
|
||||
"-H", "--hostname",
|
||||
help="TCP/IP hostname to serve on (default: %(default)r)",
|
||||
default="localhost"
|
||||
)
|
||||
arg_parser.add_argument(
|
||||
"-P", "--port",
|
||||
help="TCP/IP port to serve on (default: %(default)r)",
|
||||
type=int,
|
||||
default="8080"
|
||||
)
|
||||
args, extra_argv = arg_parser.parse_known_args(argv)
|
||||
|
||||
# Import logic
|
||||
mod_str, _, func_str = args.entry_func.partition(":")
|
||||
if not func_str or not mod_str:
|
||||
arg_parser.error(
|
||||
"'entry-func' not in 'module:function' syntax"
|
||||
)
|
||||
if mod_str.startswith("."):
|
||||
arg_parser.error("relative module names not supported")
|
||||
try:
|
||||
module = import_module(mod_str)
|
||||
except ImportError:
|
||||
arg_parser.error("module %r not found" % mod_str)
|
||||
try:
|
||||
func = getattr(module, func_str)
|
||||
except AttributeError:
|
||||
arg_parser.error("module %r has no attribute %r" % (mod_str, func_str))
|
||||
|
||||
app = func(extra_argv)
|
||||
run_app(app, host=args.hostname, port=args.port)
|
||||
arg_parser.exit(message="Stopped\n")
|
||||
|
||||
if __name__ == "__main__": # pragma: no branch
|
||||
main(sys.argv[1:]) # pragma: no cover
|
||||
|
|
@ -0,0 +1,349 @@
|
|||
from .web_reqrep import Response
|
||||
|
||||
__all__ = (
|
||||
'HTTPException',
|
||||
'HTTPError',
|
||||
'HTTPRedirection',
|
||||
'HTTPSuccessful',
|
||||
'HTTPOk',
|
||||
'HTTPCreated',
|
||||
'HTTPAccepted',
|
||||
'HTTPNonAuthoritativeInformation',
|
||||
'HTTPNoContent',
|
||||
'HTTPResetContent',
|
||||
'HTTPPartialContent',
|
||||
'HTTPMultipleChoices',
|
||||
'HTTPMovedPermanently',
|
||||
'HTTPFound',
|
||||
'HTTPSeeOther',
|
||||
'HTTPNotModified',
|
||||
'HTTPUseProxy',
|
||||
'HTTPTemporaryRedirect',
|
||||
'HTTPPermanentRedirect',
|
||||
'HTTPClientError',
|
||||
'HTTPBadRequest',
|
||||
'HTTPUnauthorized',
|
||||
'HTTPPaymentRequired',
|
||||
'HTTPForbidden',
|
||||
'HTTPNotFound',
|
||||
'HTTPMethodNotAllowed',
|
||||
'HTTPNotAcceptable',
|
||||
'HTTPProxyAuthenticationRequired',
|
||||
'HTTPRequestTimeout',
|
||||
'HTTPConflict',
|
||||
'HTTPGone',
|
||||
'HTTPLengthRequired',
|
||||
'HTTPPreconditionFailed',
|
||||
'HTTPRequestEntityTooLarge',
|
||||
'HTTPRequestURITooLong',
|
||||
'HTTPUnsupportedMediaType',
|
||||
'HTTPRequestRangeNotSatisfiable',
|
||||
'HTTPExpectationFailed',
|
||||
'HTTPMisdirectedRequest',
|
||||
'HTTPUpgradeRequired',
|
||||
'HTTPPreconditionRequired',
|
||||
'HTTPTooManyRequests',
|
||||
'HTTPRequestHeaderFieldsTooLarge',
|
||||
'HTTPUnavailableForLegalReasons',
|
||||
'HTTPServerError',
|
||||
'HTTPInternalServerError',
|
||||
'HTTPNotImplemented',
|
||||
'HTTPBadGateway',
|
||||
'HTTPServiceUnavailable',
|
||||
'HTTPGatewayTimeout',
|
||||
'HTTPVersionNotSupported',
|
||||
'HTTPVariantAlsoNegotiates',
|
||||
'HTTPNotExtended',
|
||||
'HTTPNetworkAuthenticationRequired',
|
||||
)
|
||||
|
||||
|
||||
############################################################
|
||||
# HTTP Exceptions
|
||||
############################################################
|
||||
|
||||
class HTTPException(Response, Exception):
|
||||
|
||||
# You should set in subclasses:
|
||||
# status = 200
|
||||
|
||||
status_code = None
|
||||
empty_body = False
|
||||
|
||||
def __init__(self, *, headers=None, reason=None,
|
||||
body=None, text=None, content_type=None):
|
||||
Response.__init__(self, status=self.status_code,
|
||||
headers=headers, reason=reason,
|
||||
body=body, text=text, content_type=content_type)
|
||||
Exception.__init__(self, self.reason)
|
||||
if self.body is None and not self.empty_body:
|
||||
self.text = "{}: {}".format(self.status, self.reason)
|
||||
|
||||
|
||||
class HTTPError(HTTPException):
|
||||
"""Base class for exceptions with status codes in the 400s and 500s."""
|
||||
|
||||
|
||||
class HTTPRedirection(HTTPException):
|
||||
"""Base class for exceptions with status codes in the 300s."""
|
||||
|
||||
|
||||
class HTTPSuccessful(HTTPException):
|
||||
"""Base class for exceptions with status codes in the 200s."""
|
||||
|
||||
|
||||
class HTTPOk(HTTPSuccessful):
|
||||
status_code = 200
|
||||
|
||||
|
||||
class HTTPCreated(HTTPSuccessful):
|
||||
status_code = 201
|
||||
|
||||
|
||||
class HTTPAccepted(HTTPSuccessful):
|
||||
status_code = 202
|
||||
|
||||
|
||||
class HTTPNonAuthoritativeInformation(HTTPSuccessful):
|
||||
status_code = 203
|
||||
|
||||
|
||||
class HTTPNoContent(HTTPSuccessful):
|
||||
status_code = 204
|
||||
empty_body = True
|
||||
|
||||
|
||||
class HTTPResetContent(HTTPSuccessful):
|
||||
status_code = 205
|
||||
empty_body = True
|
||||
|
||||
|
||||
class HTTPPartialContent(HTTPSuccessful):
|
||||
status_code = 206
|
||||
|
||||
|
||||
############################################################
|
||||
# 3xx redirection
|
||||
############################################################
|
||||
|
||||
|
||||
class _HTTPMove(HTTPRedirection):
|
||||
|
||||
def __init__(self, location, *, headers=None, reason=None,
|
||||
body=None, text=None, content_type=None):
|
||||
if not location:
|
||||
raise ValueError("HTTP redirects need a location to redirect to.")
|
||||
super().__init__(headers=headers, reason=reason,
|
||||
body=body, text=text, content_type=content_type)
|
||||
self.headers['Location'] = location
|
||||
self.location = location
|
||||
|
||||
|
||||
class HTTPMultipleChoices(_HTTPMove):
|
||||
status_code = 300
|
||||
|
||||
|
||||
class HTTPMovedPermanently(_HTTPMove):
|
||||
status_code = 301
|
||||
|
||||
|
||||
class HTTPFound(_HTTPMove):
|
||||
status_code = 302
|
||||
|
||||
|
||||
# This one is safe after a POST (the redirected location will be
|
||||
# retrieved with GET):
|
||||
class HTTPSeeOther(_HTTPMove):
|
||||
status_code = 303
|
||||
|
||||
|
||||
class HTTPNotModified(HTTPRedirection):
|
||||
# FIXME: this should include a date or etag header
|
||||
status_code = 304
|
||||
empty_body = True
|
||||
|
||||
|
||||
class HTTPUseProxy(_HTTPMove):
|
||||
# Not a move, but looks a little like one
|
||||
status_code = 305
|
||||
|
||||
|
||||
class HTTPTemporaryRedirect(_HTTPMove):
|
||||
status_code = 307
|
||||
|
||||
|
||||
class HTTPPermanentRedirect(_HTTPMove):
|
||||
status_code = 308
|
||||
|
||||
|
||||
############################################################
|
||||
# 4xx client error
|
||||
############################################################
|
||||
|
||||
|
||||
class HTTPClientError(HTTPError):
|
||||
pass
|
||||
|
||||
|
||||
class HTTPBadRequest(HTTPClientError):
|
||||
status_code = 400
|
||||
|
||||
|
||||
class HTTPUnauthorized(HTTPClientError):
|
||||
status_code = 401
|
||||
|
||||
|
||||
class HTTPPaymentRequired(HTTPClientError):
|
||||
status_code = 402
|
||||
|
||||
|
||||
class HTTPForbidden(HTTPClientError):
|
||||
status_code = 403
|
||||
|
||||
|
||||
class HTTPNotFound(HTTPClientError):
|
||||
status_code = 404
|
||||
|
||||
|
||||
class HTTPMethodNotAllowed(HTTPClientError):
|
||||
status_code = 405
|
||||
|
||||
def __init__(self, method, allowed_methods, *, headers=None, reason=None,
|
||||
body=None, text=None, content_type=None):
|
||||
allow = ','.join(sorted(allowed_methods))
|
||||
super().__init__(headers=headers, reason=reason,
|
||||
body=body, text=text, content_type=content_type)
|
||||
self.headers['Allow'] = allow
|
||||
self.allowed_methods = allowed_methods
|
||||
self.method = method.upper()
|
||||
|
||||
|
||||
class HTTPNotAcceptable(HTTPClientError):
|
||||
status_code = 406
|
||||
|
||||
|
||||
class HTTPProxyAuthenticationRequired(HTTPClientError):
|
||||
status_code = 407
|
||||
|
||||
|
||||
class HTTPRequestTimeout(HTTPClientError):
|
||||
status_code = 408
|
||||
|
||||
|
||||
class HTTPConflict(HTTPClientError):
|
||||
status_code = 409
|
||||
|
||||
|
||||
class HTTPGone(HTTPClientError):
|
||||
status_code = 410
|
||||
|
||||
|
||||
class HTTPLengthRequired(HTTPClientError):
|
||||
status_code = 411
|
||||
|
||||
|
||||
class HTTPPreconditionFailed(HTTPClientError):
|
||||
status_code = 412
|
||||
|
||||
|
||||
class HTTPRequestEntityTooLarge(HTTPClientError):
|
||||
status_code = 413
|
||||
|
||||
|
||||
class HTTPRequestURITooLong(HTTPClientError):
|
||||
status_code = 414
|
||||
|
||||
|
||||
class HTTPUnsupportedMediaType(HTTPClientError):
|
||||
status_code = 415
|
||||
|
||||
|
||||
class HTTPRequestRangeNotSatisfiable(HTTPClientError):
|
||||
status_code = 416
|
||||
|
||||
|
||||
class HTTPExpectationFailed(HTTPClientError):
|
||||
status_code = 417
|
||||
|
||||
|
||||
class HTTPMisdirectedRequest(HTTPClientError):
|
||||
status_code = 421
|
||||
|
||||
|
||||
class HTTPUpgradeRequired(HTTPClientError):
|
||||
status_code = 426
|
||||
|
||||
|
||||
class HTTPPreconditionRequired(HTTPClientError):
|
||||
status_code = 428
|
||||
|
||||
|
||||
class HTTPTooManyRequests(HTTPClientError):
|
||||
status_code = 429
|
||||
|
||||
|
||||
class HTTPRequestHeaderFieldsTooLarge(HTTPClientError):
|
||||
status_code = 431
|
||||
|
||||
|
||||
class HTTPUnavailableForLegalReasons(HTTPClientError):
|
||||
status_code = 451
|
||||
|
||||
def __init__(self, link, *, headers=None, reason=None,
|
||||
body=None, text=None, content_type=None):
|
||||
super().__init__(headers=headers, reason=reason,
|
||||
body=body, text=text, content_type=content_type)
|
||||
self.headers['Link'] = '<%s>; rel="blocked-by"' % link
|
||||
self.link = link
|
||||
|
||||
|
||||
############################################################
|
||||
# 5xx Server Error
|
||||
############################################################
|
||||
# Response status codes beginning with the digit "5" indicate cases in
|
||||
# which the server is aware that it has erred or is incapable of
|
||||
# performing the request. Except when responding to a HEAD request, the
|
||||
# server SHOULD include an entity containing an explanation of the error
|
||||
# situation, and whether it is a temporary or permanent condition. User
|
||||
# agents SHOULD display any included entity to the user. These response
|
||||
# codes are applicable to any request method.
|
||||
|
||||
|
||||
class HTTPServerError(HTTPError):
|
||||
pass
|
||||
|
||||
|
||||
class HTTPInternalServerError(HTTPServerError):
|
||||
status_code = 500
|
||||
|
||||
|
||||
class HTTPNotImplemented(HTTPServerError):
|
||||
status_code = 501
|
||||
|
||||
|
||||
class HTTPBadGateway(HTTPServerError):
|
||||
status_code = 502
|
||||
|
||||
|
||||
class HTTPServiceUnavailable(HTTPServerError):
|
||||
status_code = 503
|
||||
|
||||
|
||||
class HTTPGatewayTimeout(HTTPServerError):
|
||||
status_code = 504
|
||||
|
||||
|
||||
class HTTPVersionNotSupported(HTTPServerError):
|
||||
status_code = 505
|
||||
|
||||
|
||||
class HTTPVariantAlsoNegotiates(HTTPServerError):
|
||||
status_code = 506
|
||||
|
||||
|
||||
class HTTPNotExtended(HTTPServerError):
|
||||
status_code = 510
|
||||
|
||||
|
||||
class HTTPNetworkAuthenticationRequired(HTTPServerError):
|
||||
status_code = 511
|
||||
|
|
@ -0,0 +1,895 @@
|
|||
import asyncio
|
||||
import binascii
|
||||
import cgi
|
||||
import collections
|
||||
import datetime
|
||||
import enum
|
||||
import http.cookies
|
||||
import io
|
||||
import json
|
||||
import math
|
||||
import time
|
||||
import warnings
|
||||
from email.utils import parsedate
|
||||
from types import MappingProxyType
|
||||
from urllib.parse import parse_qsl, unquote, urlsplit
|
||||
|
||||
from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy
|
||||
|
||||
from . import hdrs, multipart
|
||||
from .helpers import reify, sentinel
|
||||
from .protocol import Response as ResponseImpl
|
||||
from .protocol import HttpVersion10, HttpVersion11
|
||||
from .streams import EOF_MARKER
|
||||
|
||||
__all__ = (
|
||||
'ContentCoding', 'Request', 'StreamResponse', 'Response',
|
||||
'json_response'
|
||||
)
|
||||
|
||||
|
||||
class HeadersMixin:
|
||||
|
||||
_content_type = None
|
||||
_content_dict = None
|
||||
_stored_content_type = sentinel
|
||||
|
||||
def _parse_content_type(self, raw):
|
||||
self._stored_content_type = raw
|
||||
if raw is None:
|
||||
# default value according to RFC 2616
|
||||
self._content_type = 'application/octet-stream'
|
||||
self._content_dict = {}
|
||||
else:
|
||||
self._content_type, self._content_dict = cgi.parse_header(raw)
|
||||
|
||||
@property
|
||||
def content_type(self, _CONTENT_TYPE=hdrs.CONTENT_TYPE):
|
||||
"""The value of content part for Content-Type HTTP header."""
|
||||
raw = self.headers.get(_CONTENT_TYPE)
|
||||
if self._stored_content_type != raw:
|
||||
self._parse_content_type(raw)
|
||||
return self._content_type
|
||||
|
||||
@property
|
||||
def charset(self, _CONTENT_TYPE=hdrs.CONTENT_TYPE):
|
||||
"""The value of charset part for Content-Type HTTP header."""
|
||||
raw = self.headers.get(_CONTENT_TYPE)
|
||||
if self._stored_content_type != raw:
|
||||
self._parse_content_type(raw)
|
||||
return self._content_dict.get('charset')
|
||||
|
||||
@property
|
||||
def content_length(self, _CONTENT_LENGTH=hdrs.CONTENT_LENGTH):
|
||||
"""The value of Content-Length HTTP header."""
|
||||
l = self.headers.get(_CONTENT_LENGTH)
|
||||
if l is None:
|
||||
return None
|
||||
else:
|
||||
return int(l)
|
||||
|
||||
FileField = collections.namedtuple('Field', 'name filename file content_type')
|
||||
|
||||
|
||||
class ContentCoding(enum.Enum):
|
||||
# The content codings that we have support for.
|
||||
#
|
||||
# Additional registered codings are listed at:
|
||||
# https://www.iana.org/assignments/http-parameters/http-parameters.xhtml#content-coding
|
||||
deflate = 'deflate'
|
||||
gzip = 'gzip'
|
||||
identity = 'identity'
|
||||
|
||||
|
||||
############################################################
|
||||
# HTTP Request
|
||||
############################################################
|
||||
|
||||
|
||||
class Request(dict, HeadersMixin):
|
||||
|
||||
POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT,
|
||||
hdrs.METH_TRACE, hdrs.METH_DELETE}
|
||||
|
||||
def __init__(self, app, message, payload, transport, reader, writer, *,
|
||||
secure_proxy_ssl_header=None):
|
||||
self._app = app
|
||||
self._message = message
|
||||
self._transport = transport
|
||||
self._reader = reader
|
||||
self._writer = writer
|
||||
self._post = None
|
||||
self._post_files_cache = None
|
||||
|
||||
# matchdict, route_name, handler
|
||||
# or information about traversal lookup
|
||||
self._match_info = None # initialized after route resolving
|
||||
|
||||
self._payload = payload
|
||||
|
||||
self._read_bytes = None
|
||||
self._has_body = not payload.at_eof()
|
||||
|
||||
self._secure_proxy_ssl_header = secure_proxy_ssl_header
|
||||
|
||||
@reify
|
||||
def scheme(self):
|
||||
"""A string representing the scheme of the request.
|
||||
|
||||
'http' or 'https'.
|
||||
"""
|
||||
if self._transport.get_extra_info('sslcontext'):
|
||||
return 'https'
|
||||
secure_proxy_ssl_header = self._secure_proxy_ssl_header
|
||||
if secure_proxy_ssl_header is not None:
|
||||
header, value = secure_proxy_ssl_header
|
||||
if self.headers.get(header) == value:
|
||||
return 'https'
|
||||
return 'http'
|
||||
|
||||
@reify
|
||||
def method(self):
|
||||
"""Read only property for getting HTTP method.
|
||||
|
||||
The value is upper-cased str like 'GET', 'POST', 'PUT' etc.
|
||||
"""
|
||||
return self._message.method
|
||||
|
||||
@reify
|
||||
def version(self):
|
||||
"""Read only property for getting HTTP version of request.
|
||||
|
||||
Returns aiohttp.protocol.HttpVersion instance.
|
||||
"""
|
||||
return self._message.version
|
||||
|
||||
@reify
|
||||
def host(self):
|
||||
"""Read only property for getting *HOST* header of request.
|
||||
|
||||
Returns str or None if HTTP request has no HOST header.
|
||||
"""
|
||||
return self._message.headers.get(hdrs.HOST)
|
||||
|
||||
@reify
|
||||
def path_qs(self):
|
||||
"""The URL including PATH_INFO and the query string.
|
||||
|
||||
E.g, /app/blog?id=10
|
||||
"""
|
||||
return self._message.path
|
||||
|
||||
@reify
|
||||
def _splitted_path(self):
|
||||
url = '{}://{}{}'.format(self.scheme, self.host, self.path_qs)
|
||||
return urlsplit(url)
|
||||
|
||||
@reify
|
||||
def raw_path(self):
|
||||
""" The URL including raw *PATH INFO* without the host or scheme.
|
||||
Warning, the path is unquoted and may contains non valid URL characters
|
||||
|
||||
E.g., ``/my%2Fpath%7Cwith%21some%25strange%24characters``
|
||||
"""
|
||||
return self._splitted_path.path
|
||||
|
||||
@reify
|
||||
def path(self):
|
||||
"""The URL including *PATH INFO* without the host or scheme.
|
||||
|
||||
E.g., ``/app/blog``
|
||||
"""
|
||||
return unquote(self.raw_path)
|
||||
|
||||
@reify
|
||||
def query_string(self):
|
||||
"""The query string in the URL.
|
||||
|
||||
E.g., id=10
|
||||
"""
|
||||
return self._splitted_path.query
|
||||
|
||||
@reify
|
||||
def GET(self):
|
||||
"""A multidict with all the variables in the query string.
|
||||
|
||||
Lazy property.
|
||||
"""
|
||||
return MultiDictProxy(MultiDict(parse_qsl(self.query_string,
|
||||
keep_blank_values=True)))
|
||||
|
||||
@reify
|
||||
def POST(self):
|
||||
"""A multidict with all the variables in the POST parameters.
|
||||
|
||||
post() methods has to be called before using this attribute.
|
||||
"""
|
||||
if self._post is None:
|
||||
raise RuntimeError("POST is not available before post()")
|
||||
return self._post
|
||||
|
||||
@reify
|
||||
def headers(self):
|
||||
"""A case-insensitive multidict proxy with all headers."""
|
||||
return CIMultiDictProxy(self._message.headers)
|
||||
|
||||
@reify
|
||||
def raw_headers(self):
|
||||
"""A sequence of pars for all headers."""
|
||||
return tuple(self._message.raw_headers)
|
||||
|
||||
@reify
|
||||
def if_modified_since(self, _IF_MODIFIED_SINCE=hdrs.IF_MODIFIED_SINCE):
|
||||
"""The value of If-Modified-Since HTTP header, or None.
|
||||
|
||||
This header is represented as a `datetime` object.
|
||||
"""
|
||||
httpdate = self.headers.get(_IF_MODIFIED_SINCE)
|
||||
if httpdate is not None:
|
||||
timetuple = parsedate(httpdate)
|
||||
if timetuple is not None:
|
||||
return datetime.datetime(*timetuple[:6],
|
||||
tzinfo=datetime.timezone.utc)
|
||||
return None
|
||||
|
||||
@reify
|
||||
def keep_alive(self):
|
||||
"""Is keepalive enabled by client?"""
|
||||
if self.version < HttpVersion10:
|
||||
return False
|
||||
else:
|
||||
return not self._message.should_close
|
||||
|
||||
@property
|
||||
def match_info(self):
|
||||
"""Result of route resolving."""
|
||||
return self._match_info
|
||||
|
||||
@property
|
||||
def app(self):
|
||||
"""Application instance."""
|
||||
return self._app
|
||||
|
||||
@property
|
||||
def transport(self):
|
||||
"""Transport used for request processing."""
|
||||
return self._transport
|
||||
|
||||
@reify
|
||||
def cookies(self):
|
||||
"""Return request cookies.
|
||||
|
||||
A read-only dictionary-like object.
|
||||
"""
|
||||
raw = self.headers.get(hdrs.COOKIE, '')
|
||||
parsed = http.cookies.SimpleCookie(raw)
|
||||
return MappingProxyType(
|
||||
{key: val.value for key, val in parsed.items()})
|
||||
|
||||
@property
|
||||
def content(self):
|
||||
"""Return raw payload stream."""
|
||||
return self._payload
|
||||
|
||||
@property
|
||||
def has_body(self):
|
||||
"""Return True if request has HTTP BODY, False otherwise."""
|
||||
return self._has_body
|
||||
|
||||
@asyncio.coroutine
|
||||
def release(self):
|
||||
"""Release request.
|
||||
|
||||
Eat unread part of HTTP BODY if present.
|
||||
"""
|
||||
chunk = yield from self._payload.readany()
|
||||
while chunk is not EOF_MARKER or chunk:
|
||||
chunk = yield from self._payload.readany()
|
||||
|
||||
@asyncio.coroutine
|
||||
def read(self):
|
||||
"""Read request body if present.
|
||||
|
||||
Returns bytes object with full request content.
|
||||
"""
|
||||
if self._read_bytes is None:
|
||||
body = bytearray()
|
||||
while True:
|
||||
chunk = yield from self._payload.readany()
|
||||
body.extend(chunk)
|
||||
if chunk is EOF_MARKER:
|
||||
break
|
||||
self._read_bytes = bytes(body)
|
||||
return self._read_bytes
|
||||
|
||||
@asyncio.coroutine
|
||||
def text(self):
|
||||
"""Return BODY as text using encoding from .charset."""
|
||||
bytes_body = yield from self.read()
|
||||
encoding = self.charset or 'utf-8'
|
||||
return bytes_body.decode(encoding)
|
||||
|
||||
@asyncio.coroutine
|
||||
def json(self, *, loads=json.loads, loader=None):
|
||||
"""Return BODY as JSON."""
|
||||
if loader is not None:
|
||||
warnings.warn(
|
||||
"Using loader argument is deprecated, use loads instead",
|
||||
DeprecationWarning)
|
||||
loads = loader
|
||||
body = yield from self.text()
|
||||
return loads(body)
|
||||
|
||||
@asyncio.coroutine
|
||||
def multipart(self, *, reader=multipart.MultipartReader):
|
||||
"""Return async iterator to process BODY as multipart."""
|
||||
return reader(self.headers, self.content)
|
||||
|
||||
@asyncio.coroutine
|
||||
def post(self):
|
||||
"""Return POST parameters."""
|
||||
if self._post is not None:
|
||||
return self._post
|
||||
if self.method not in self.POST_METHODS:
|
||||
self._post = MultiDictProxy(MultiDict())
|
||||
return self._post
|
||||
|
||||
content_type = self.content_type
|
||||
if (content_type not in ('',
|
||||
'application/x-www-form-urlencoded',
|
||||
'multipart/form-data')):
|
||||
self._post = MultiDictProxy(MultiDict())
|
||||
return self._post
|
||||
|
||||
if self.content_type.startswith('multipart/'):
|
||||
warnings.warn('To process multipart requests use .multipart'
|
||||
' coroutine instead.', DeprecationWarning)
|
||||
|
||||
body = yield from self.read()
|
||||
content_charset = self.charset or 'utf-8'
|
||||
|
||||
environ = {'REQUEST_METHOD': self.method,
|
||||
'CONTENT_LENGTH': str(len(body)),
|
||||
'QUERY_STRING': '',
|
||||
'CONTENT_TYPE': self.headers.get(hdrs.CONTENT_TYPE)}
|
||||
|
||||
fs = cgi.FieldStorage(fp=io.BytesIO(body),
|
||||
environ=environ,
|
||||
keep_blank_values=True,
|
||||
encoding=content_charset)
|
||||
|
||||
supported_transfer_encoding = {
|
||||
'base64': binascii.a2b_base64,
|
||||
'quoted-printable': binascii.a2b_qp
|
||||
}
|
||||
|
||||
out = MultiDict()
|
||||
_count = 1
|
||||
for field in fs.list or ():
|
||||
transfer_encoding = field.headers.get(
|
||||
hdrs.CONTENT_TRANSFER_ENCODING, None)
|
||||
if field.filename:
|
||||
ff = FileField(field.name,
|
||||
field.filename,
|
||||
field.file, # N.B. file closed error
|
||||
field.type)
|
||||
if self._post_files_cache is None:
|
||||
self._post_files_cache = {}
|
||||
self._post_files_cache[field.name+str(_count)] = field
|
||||
_count += 1
|
||||
out.add(field.name, ff)
|
||||
else:
|
||||
value = field.value
|
||||
if transfer_encoding in supported_transfer_encoding:
|
||||
# binascii accepts bytes
|
||||
value = value.encode('utf-8')
|
||||
value = supported_transfer_encoding[
|
||||
transfer_encoding](value)
|
||||
out.add(field.name, value)
|
||||
|
||||
self._post = MultiDictProxy(out)
|
||||
return self._post
|
||||
|
||||
def copy(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def __repr__(self):
|
||||
ascii_encodable_path = self.path.encode('ascii', 'backslashreplace') \
|
||||
.decode('ascii')
|
||||
return "<{} {} {} >".format(self.__class__.__name__,
|
||||
self.method, ascii_encodable_path)
|
||||
|
||||
|
||||
############################################################
|
||||
# HTTP Response classes
|
||||
############################################################
|
||||
|
||||
|
||||
class StreamResponse(HeadersMixin):
|
||||
|
||||
def __init__(self, *, status=200, reason=None, headers=None):
|
||||
self._body = None
|
||||
self._keep_alive = None
|
||||
self._chunked = False
|
||||
self._chunk_size = None
|
||||
self._compression = False
|
||||
self._compression_force = False
|
||||
self._headers = CIMultiDict()
|
||||
self._cookies = http.cookies.SimpleCookie()
|
||||
self.set_status(status, reason)
|
||||
|
||||
self._req = None
|
||||
self._resp_impl = None
|
||||
self._eof_sent = False
|
||||
self._tcp_nodelay = True
|
||||
self._tcp_cork = False
|
||||
|
||||
if headers is not None:
|
||||
self._headers.extend(headers)
|
||||
self._parse_content_type(self._headers.get(hdrs.CONTENT_TYPE))
|
||||
self._generate_content_type_header()
|
||||
|
||||
def _copy_cookies(self):
|
||||
for cookie in self._cookies.values():
|
||||
value = cookie.output(header='')[1:]
|
||||
self.headers.add(hdrs.SET_COOKIE, value)
|
||||
|
||||
@property
|
||||
def prepared(self):
|
||||
return self._resp_impl is not None
|
||||
|
||||
@property
|
||||
def started(self):
|
||||
warnings.warn('use Response.prepared instead', DeprecationWarning)
|
||||
return self.prepared
|
||||
|
||||
@property
|
||||
def status(self):
|
||||
return self._status
|
||||
|
||||
@property
|
||||
def chunked(self):
|
||||
return self._chunked
|
||||
|
||||
@property
|
||||
def compression(self):
|
||||
return self._compression
|
||||
|
||||
@property
|
||||
def reason(self):
|
||||
return self._reason
|
||||
|
||||
def set_status(self, status, reason=None):
|
||||
self._status = int(status)
|
||||
if reason is None:
|
||||
reason = ResponseImpl.calc_reason(status)
|
||||
self._reason = reason
|
||||
|
||||
@property
|
||||
def keep_alive(self):
|
||||
return self._keep_alive
|
||||
|
||||
def force_close(self):
|
||||
self._keep_alive = False
|
||||
|
||||
def enable_chunked_encoding(self, chunk_size=None):
|
||||
"""Enables automatic chunked transfer encoding."""
|
||||
self._chunked = True
|
||||
self._chunk_size = chunk_size
|
||||
|
||||
def enable_compression(self, force=None):
|
||||
"""Enables response compression encoding."""
|
||||
# Backwards compatibility for when force was a bool <0.17.
|
||||
if type(force) == bool:
|
||||
force = ContentCoding.deflate if force else ContentCoding.identity
|
||||
elif force is not None:
|
||||
assert isinstance(force, ContentCoding), ("force should one of "
|
||||
"None, bool or "
|
||||
"ContentEncoding")
|
||||
|
||||
self._compression = True
|
||||
self._compression_force = force
|
||||
|
||||
@property
|
||||
def headers(self):
|
||||
return self._headers
|
||||
|
||||
@property
|
||||
def cookies(self):
|
||||
return self._cookies
|
||||
|
||||
def set_cookie(self, name, value, *, expires=None,
|
||||
domain=None, max_age=None, path='/',
|
||||
secure=None, httponly=None, version=None):
|
||||
"""Set or update response cookie.
|
||||
|
||||
Sets new cookie or updates existent with new value.
|
||||
Also updates only those params which are not None.
|
||||
"""
|
||||
|
||||
old = self._cookies.get(name)
|
||||
if old is not None and old.coded_value == '':
|
||||
# deleted cookie
|
||||
self._cookies.pop(name, None)
|
||||
|
||||
self._cookies[name] = value
|
||||
c = self._cookies[name]
|
||||
|
||||
if expires is not None:
|
||||
c['expires'] = expires
|
||||
elif c.get('expires') == 'Thu, 01 Jan 1970 00:00:00 GMT':
|
||||
del c['expires']
|
||||
|
||||
if domain is not None:
|
||||
c['domain'] = domain
|
||||
|
||||
if max_age is not None:
|
||||
c['max-age'] = max_age
|
||||
elif 'max-age' in c:
|
||||
del c['max-age']
|
||||
|
||||
c['path'] = path
|
||||
|
||||
if secure is not None:
|
||||
c['secure'] = secure
|
||||
if httponly is not None:
|
||||
c['httponly'] = httponly
|
||||
if version is not None:
|
||||
c['version'] = version
|
||||
|
||||
def del_cookie(self, name, *, domain=None, path='/'):
|
||||
"""Delete cookie.
|
||||
|
||||
Creates new empty expired cookie.
|
||||
"""
|
||||
# TODO: do we need domain/path here?
|
||||
self._cookies.pop(name, None)
|
||||
self.set_cookie(name, '', max_age=0,
|
||||
expires="Thu, 01 Jan 1970 00:00:00 GMT",
|
||||
domain=domain, path=path)
|
||||
|
||||
@property
|
||||
def content_length(self):
|
||||
# Just a placeholder for adding setter
|
||||
return super().content_length
|
||||
|
||||
@content_length.setter
|
||||
def content_length(self, value):
|
||||
if value is not None:
|
||||
value = int(value)
|
||||
# TODO: raise error if chunked enabled
|
||||
self.headers[hdrs.CONTENT_LENGTH] = str(value)
|
||||
else:
|
||||
self.headers.pop(hdrs.CONTENT_LENGTH, None)
|
||||
|
||||
@property
|
||||
def content_type(self):
|
||||
# Just a placeholder for adding setter
|
||||
return super().content_type
|
||||
|
||||
@content_type.setter
|
||||
def content_type(self, value):
|
||||
self.content_type # read header values if needed
|
||||
self._content_type = str(value)
|
||||
self._generate_content_type_header()
|
||||
|
||||
@property
|
||||
def charset(self):
|
||||
# Just a placeholder for adding setter
|
||||
return super().charset
|
||||
|
||||
@charset.setter
|
||||
def charset(self, value):
|
||||
ctype = self.content_type # read header values if needed
|
||||
if ctype == 'application/octet-stream':
|
||||
raise RuntimeError("Setting charset for application/octet-stream "
|
||||
"doesn't make sense, setup content_type first")
|
||||
if value is None:
|
||||
self._content_dict.pop('charset', None)
|
||||
else:
|
||||
self._content_dict['charset'] = str(value).lower()
|
||||
self._generate_content_type_header()
|
||||
|
||||
@property
|
||||
def last_modified(self, _LAST_MODIFIED=hdrs.LAST_MODIFIED):
|
||||
"""The value of Last-Modified HTTP header, or None.
|
||||
|
||||
This header is represented as a `datetime` object.
|
||||
"""
|
||||
httpdate = self.headers.get(_LAST_MODIFIED)
|
||||
if httpdate is not None:
|
||||
timetuple = parsedate(httpdate)
|
||||
if timetuple is not None:
|
||||
return datetime.datetime(*timetuple[:6],
|
||||
tzinfo=datetime.timezone.utc)
|
||||
return None
|
||||
|
||||
@last_modified.setter
|
||||
def last_modified(self, value):
|
||||
if value is None:
|
||||
self.headers.pop(hdrs.LAST_MODIFIED, None)
|
||||
elif isinstance(value, (int, float)):
|
||||
self.headers[hdrs.LAST_MODIFIED] = time.strftime(
|
||||
"%a, %d %b %Y %H:%M:%S GMT", time.gmtime(math.ceil(value)))
|
||||
elif isinstance(value, datetime.datetime):
|
||||
self.headers[hdrs.LAST_MODIFIED] = time.strftime(
|
||||
"%a, %d %b %Y %H:%M:%S GMT", value.utctimetuple())
|
||||
elif isinstance(value, str):
|
||||
self.headers[hdrs.LAST_MODIFIED] = value
|
||||
|
||||
@property
|
||||
def tcp_nodelay(self):
|
||||
return self._tcp_nodelay
|
||||
|
||||
def set_tcp_nodelay(self, value):
|
||||
value = bool(value)
|
||||
self._tcp_nodelay = value
|
||||
if value:
|
||||
self._tcp_cork = False
|
||||
if self._resp_impl is None:
|
||||
return
|
||||
if value:
|
||||
self._resp_impl.transport.set_tcp_cork(False)
|
||||
self._resp_impl.transport.set_tcp_nodelay(value)
|
||||
|
||||
@property
|
||||
def tcp_cork(self):
|
||||
return self._tcp_cork
|
||||
|
||||
def set_tcp_cork(self, value):
|
||||
value = bool(value)
|
||||
self._tcp_cork = value
|
||||
if value:
|
||||
self._tcp_nodelay = False
|
||||
if self._resp_impl is None:
|
||||
return
|
||||
if value:
|
||||
self._resp_impl.transport.set_tcp_nodelay(False)
|
||||
self._resp_impl.transport.set_tcp_cork(value)
|
||||
|
||||
def _generate_content_type_header(self, CONTENT_TYPE=hdrs.CONTENT_TYPE):
|
||||
params = '; '.join("%s=%s" % i for i in self._content_dict.items())
|
||||
if params:
|
||||
ctype = self._content_type + '; ' + params
|
||||
else:
|
||||
ctype = self._content_type
|
||||
self.headers[CONTENT_TYPE] = ctype
|
||||
|
||||
def _start_pre_check(self, request):
|
||||
if self._resp_impl is not None:
|
||||
if self._req is not request:
|
||||
raise RuntimeError(
|
||||
"Response has been started with different request.")
|
||||
else:
|
||||
return self._resp_impl
|
||||
else:
|
||||
return None
|
||||
|
||||
def _do_start_compression(self, coding):
|
||||
if coding != ContentCoding.identity:
|
||||
self.headers[hdrs.CONTENT_ENCODING] = coding.value
|
||||
self._resp_impl.add_compression_filter(coding.value)
|
||||
self.content_length = None
|
||||
|
||||
def _start_compression(self, request):
|
||||
if self._compression_force:
|
||||
self._do_start_compression(self._compression_force)
|
||||
else:
|
||||
accept_encoding = request.headers.get(
|
||||
hdrs.ACCEPT_ENCODING, '').lower()
|
||||
for coding in ContentCoding:
|
||||
if coding.value in accept_encoding:
|
||||
self._do_start_compression(coding)
|
||||
return
|
||||
|
||||
def start(self, request):
|
||||
warnings.warn('use .prepare(request) instead', DeprecationWarning)
|
||||
resp_impl = self._start_pre_check(request)
|
||||
if resp_impl is not None:
|
||||
return resp_impl
|
||||
|
||||
return self._start(request)
|
||||
|
||||
@asyncio.coroutine
|
||||
def prepare(self, request):
|
||||
resp_impl = self._start_pre_check(request)
|
||||
if resp_impl is not None:
|
||||
return resp_impl
|
||||
yield from request.app.on_response_prepare.send(request, self)
|
||||
|
||||
return self._start(request)
|
||||
|
||||
def _start(self, request):
|
||||
self._req = request
|
||||
keep_alive = self._keep_alive
|
||||
if keep_alive is None:
|
||||
keep_alive = request.keep_alive
|
||||
self._keep_alive = keep_alive
|
||||
|
||||
resp_impl = self._resp_impl = ResponseImpl(
|
||||
request._writer,
|
||||
self._status,
|
||||
request.version,
|
||||
not keep_alive,
|
||||
self._reason)
|
||||
|
||||
self._copy_cookies()
|
||||
|
||||
if self._compression:
|
||||
self._start_compression(request)
|
||||
|
||||
if self._chunked:
|
||||
if request.version != HttpVersion11:
|
||||
raise RuntimeError("Using chunked encoding is forbidden "
|
||||
"for HTTP/{0.major}.{0.minor}".format(
|
||||
request.version))
|
||||
resp_impl.enable_chunked_encoding()
|
||||
if self._chunk_size:
|
||||
resp_impl.add_chunking_filter(self._chunk_size)
|
||||
|
||||
headers = self.headers.items()
|
||||
for key, val in headers:
|
||||
resp_impl.add_header(key, val)
|
||||
|
||||
resp_impl.transport.set_tcp_nodelay(self._tcp_nodelay)
|
||||
resp_impl.transport.set_tcp_cork(self._tcp_cork)
|
||||
self._send_headers(resp_impl)
|
||||
return resp_impl
|
||||
|
||||
def _send_headers(self, resp_impl):
|
||||
# Durty hack required for
|
||||
# https://github.com/KeepSafe/aiohttp/issues/1093
|
||||
# File sender may override it
|
||||
resp_impl.send_headers()
|
||||
|
||||
def write(self, data):
|
||||
assert isinstance(data, (bytes, bytearray, memoryview)), \
|
||||
"data argument must be byte-ish (%r)" % type(data)
|
||||
|
||||
if self._eof_sent:
|
||||
raise RuntimeError("Cannot call write() after write_eof()")
|
||||
if self._resp_impl is None:
|
||||
raise RuntimeError("Cannot call write() before start()")
|
||||
|
||||
if data:
|
||||
return self._resp_impl.write(data)
|
||||
else:
|
||||
return ()
|
||||
|
||||
@asyncio.coroutine
|
||||
def drain(self):
|
||||
if self._resp_impl is None:
|
||||
raise RuntimeError("Response has not been started")
|
||||
yield from self._resp_impl.transport.drain()
|
||||
|
||||
@asyncio.coroutine
|
||||
def write_eof(self):
|
||||
if self._eof_sent:
|
||||
return
|
||||
if self._resp_impl is None:
|
||||
raise RuntimeError("Response has not been started")
|
||||
|
||||
yield from self._resp_impl.write_eof()
|
||||
self._eof_sent = True
|
||||
|
||||
def __repr__(self):
|
||||
if self.started:
|
||||
info = "{} {} ".format(self._req.method, self._req.path)
|
||||
else:
|
||||
info = "not started"
|
||||
return "<{} {} {}>".format(self.__class__.__name__,
|
||||
self.reason, info)
|
||||
|
||||
|
||||
class Response(StreamResponse):
|
||||
|
||||
def __init__(self, *, body=None, status=200,
|
||||
reason=None, text=None, headers=None, content_type=None,
|
||||
charset=None):
|
||||
if body is not None and text is not None:
|
||||
raise ValueError("body and text are not allowed together")
|
||||
|
||||
if headers is None:
|
||||
headers = CIMultiDict()
|
||||
elif not isinstance(headers, (CIMultiDict, CIMultiDictProxy)):
|
||||
headers = CIMultiDict(headers)
|
||||
|
||||
if content_type is not None and ";" in content_type:
|
||||
raise ValueError("charset must not be in content_type "
|
||||
"argument")
|
||||
|
||||
if text is not None:
|
||||
if hdrs.CONTENT_TYPE in headers:
|
||||
if content_type or charset:
|
||||
raise ValueError("passing both Content-Type header and "
|
||||
"content_type or charset params "
|
||||
"is forbidden")
|
||||
else:
|
||||
# fast path for filling headers
|
||||
if not isinstance(text, str):
|
||||
raise TypeError("text argument must be str (%r)" %
|
||||
type(text))
|
||||
if content_type is None:
|
||||
content_type = 'text/plain'
|
||||
if charset is None:
|
||||
charset = 'utf-8'
|
||||
headers[hdrs.CONTENT_TYPE] = (
|
||||
content_type + '; charset=' + charset)
|
||||
body = text.encode(charset)
|
||||
text = None
|
||||
else:
|
||||
if hdrs.CONTENT_TYPE in headers:
|
||||
if content_type is not None or charset is not None:
|
||||
raise ValueError("passing both Content-Type header and "
|
||||
"content_type or charset params "
|
||||
"is forbidden")
|
||||
else:
|
||||
if content_type is not None:
|
||||
if charset is not None:
|
||||
content_type += '; charset=' + charset
|
||||
headers[hdrs.CONTENT_TYPE] = content_type
|
||||
|
||||
super().__init__(status=status, reason=reason, headers=headers)
|
||||
self.set_tcp_cork(True)
|
||||
if text is not None:
|
||||
self.text = text
|
||||
else:
|
||||
self.body = body
|
||||
|
||||
@property
|
||||
def body(self):
|
||||
return self._body
|
||||
|
||||
@body.setter
|
||||
def body(self, body):
|
||||
if body is not None and not isinstance(body, bytes):
|
||||
raise TypeError("body argument must be bytes (%r)" % type(body))
|
||||
self._body = body
|
||||
if body is not None:
|
||||
self.content_length = len(body)
|
||||
else:
|
||||
self.content_length = 0
|
||||
|
||||
@property
|
||||
def text(self):
|
||||
if self._body is None:
|
||||
return None
|
||||
return self._body.decode(self.charset or 'utf-8')
|
||||
|
||||
@text.setter
|
||||
def text(self, text):
|
||||
if text is not None and not isinstance(text, str):
|
||||
raise TypeError("text argument must be str (%r)" % type(text))
|
||||
|
||||
if self.content_type == 'application/octet-stream':
|
||||
self.content_type = 'text/plain'
|
||||
if self.charset is None:
|
||||
self.charset = 'utf-8'
|
||||
|
||||
self.body = text.encode(self.charset)
|
||||
|
||||
@asyncio.coroutine
|
||||
def write_eof(self):
|
||||
try:
|
||||
body = self._body
|
||||
if (body is not None and
|
||||
self._req.method != hdrs.METH_HEAD and
|
||||
self._status not in [204, 304]):
|
||||
self.write(body)
|
||||
finally:
|
||||
self.set_tcp_nodelay(True)
|
||||
yield from super().write_eof()
|
||||
|
||||
|
||||
def json_response(data=sentinel, *, text=None, body=None, status=200,
|
||||
reason=None, headers=None, content_type='application/json',
|
||||
dumps=json.dumps):
|
||||
if data is not sentinel:
|
||||
if text or body:
|
||||
raise ValueError(
|
||||
"only one of data, text, or body should be specified"
|
||||
)
|
||||
else:
|
||||
text = dumps(data)
|
||||
return Response(text=text, body=body, status=status, reason=reason,
|
||||
headers=headers, content_type=content_type)
|
||||
|
|
@ -0,0 +1,825 @@
|
|||
import abc
|
||||
import asyncio
|
||||
import collections
|
||||
import inspect
|
||||
import keyword
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import warnings
|
||||
from collections.abc import Container, Iterable, Sized
|
||||
from pathlib import Path
|
||||
from types import MappingProxyType
|
||||
from urllib.parse import unquote, urlencode
|
||||
|
||||
from . import hdrs
|
||||
from .abc import AbstractMatchInfo, AbstractRouter, AbstractView
|
||||
from .file_sender import FileSender
|
||||
from .protocol import HttpVersion11
|
||||
from .web_exceptions import (HTTPExpectationFailed, HTTPForbidden,
|
||||
HTTPMethodNotAllowed, HTTPNotFound)
|
||||
from .web_reqrep import Response, StreamResponse
|
||||
|
||||
__all__ = ('UrlDispatcher', 'UrlMappingMatchInfo',
|
||||
'AbstractResource', 'Resource', 'PlainResource', 'DynamicResource',
|
||||
'ResourceAdapter',
|
||||
'AbstractRoute', 'ResourceRoute',
|
||||
'Route', 'PlainRoute', 'DynamicRoute', 'StaticRoute', 'View')
|
||||
|
||||
|
||||
PY_35 = sys.version_info >= (3, 5)
|
||||
|
||||
|
||||
HTTP_METHOD_RE = re.compile(r"^[0-9A-Za-z!#\$%&'\*\+\-\.\^_`\|~]+$")
|
||||
|
||||
|
||||
class AbstractResource(Sized, Iterable):
|
||||
|
||||
def __init__(self, *, name=None):
|
||||
self._name = name
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
||||
@abc.abstractmethod # pragma: no branch
|
||||
def url(self, **kwargs):
|
||||
"""Construct url for resource with additional params."""
|
||||
|
||||
@asyncio.coroutine
|
||||
@abc.abstractmethod # pragma: no branch
|
||||
def resolve(self, method, path):
|
||||
"""Resolve resource
|
||||
|
||||
Return (UrlMappingMatchInfo, allowed_methods) pair."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_info(self):
|
||||
"""Return a dict with additional info useful for introspection"""
|
||||
|
||||
@staticmethod
|
||||
def _append_query(url, query):
|
||||
if query:
|
||||
return url + "?" + urlencode(query)
|
||||
else:
|
||||
return url
|
||||
|
||||
|
||||
class AbstractRoute(abc.ABC):
|
||||
|
||||
def __init__(self, method, handler, *,
|
||||
expect_handler=None,
|
||||
resource=None):
|
||||
|
||||
if expect_handler is None:
|
||||
expect_handler = _defaultExpectHandler
|
||||
|
||||
assert asyncio.iscoroutinefunction(expect_handler), \
|
||||
'Coroutine is expected, got {!r}'.format(expect_handler)
|
||||
|
||||
method = method.upper()
|
||||
if not HTTP_METHOD_RE.match(method):
|
||||
raise ValueError("{} is not allowed HTTP method".format(method))
|
||||
|
||||
assert callable(handler), handler
|
||||
if asyncio.iscoroutinefunction(handler):
|
||||
pass
|
||||
elif inspect.isgeneratorfunction(handler):
|
||||
warnings.warn("Bare generators are deprecated, "
|
||||
"use @coroutine wrapper", DeprecationWarning)
|
||||
elif (isinstance(handler, type) and
|
||||
issubclass(handler, AbstractView)):
|
||||
pass
|
||||
else:
|
||||
@asyncio.coroutine
|
||||
def handler_wrapper(*args, **kwargs):
|
||||
result = old_handler(*args, **kwargs)
|
||||
if asyncio.iscoroutine(result):
|
||||
result = yield from result
|
||||
return result
|
||||
old_handler = handler
|
||||
handler = handler_wrapper
|
||||
|
||||
self._method = method
|
||||
self._handler = handler
|
||||
self._expect_handler = expect_handler
|
||||
self._resource = resource
|
||||
|
||||
@property
|
||||
def method(self):
|
||||
return self._method
|
||||
|
||||
@property
|
||||
def handler(self):
|
||||
return self._handler
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def name(self):
|
||||
"""Optional route's name, always equals to resource's name."""
|
||||
|
||||
@property
|
||||
def resource(self):
|
||||
return self._resource
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_info(self):
|
||||
"""Return a dict with additional info useful for introspection"""
|
||||
|
||||
@abc.abstractmethod # pragma: no branch
|
||||
def url(self, **kwargs):
|
||||
"""Construct url for route with additional params."""
|
||||
|
||||
@asyncio.coroutine
|
||||
def handle_expect_header(self, request):
|
||||
return (yield from self._expect_handler(request))
|
||||
|
||||
|
||||
class UrlMappingMatchInfo(dict, AbstractMatchInfo):
|
||||
|
||||
def __init__(self, match_dict, route):
|
||||
super().__init__(match_dict)
|
||||
self._route = route
|
||||
|
||||
@property
|
||||
def handler(self):
|
||||
return self._route.handler
|
||||
|
||||
@property
|
||||
def route(self):
|
||||
return self._route
|
||||
|
||||
@property
|
||||
def expect_handler(self):
|
||||
return self._route.handle_expect_header
|
||||
|
||||
@property
|
||||
def http_exception(self):
|
||||
return None
|
||||
|
||||
def get_info(self):
|
||||
return self._route.get_info()
|
||||
|
||||
def __repr__(self):
|
||||
return "<MatchInfo {}: {}>".format(super().__repr__(), self._route)
|
||||
|
||||
|
||||
class MatchInfoError(UrlMappingMatchInfo):
|
||||
|
||||
def __init__(self, http_exception):
|
||||
self._exception = http_exception
|
||||
super().__init__({}, SystemRoute(self._exception))
|
||||
|
||||
@property
|
||||
def http_exception(self):
|
||||
return self._exception
|
||||
|
||||
def __repr__(self):
|
||||
return "<MatchInfoError {}: {}>".format(self._exception.status,
|
||||
self._exception.reason)
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def _defaultExpectHandler(request):
|
||||
"""Default handler for Expect header.
|
||||
|
||||
Just send "100 Continue" to client.
|
||||
raise HTTPExpectationFailed if value of header is not "100-continue"
|
||||
"""
|
||||
expect = request.headers.get(hdrs.EXPECT)
|
||||
if request.version == HttpVersion11:
|
||||
if expect.lower() == "100-continue":
|
||||
request.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n")
|
||||
else:
|
||||
raise HTTPExpectationFailed(text="Unknown Expect: %s" % expect)
|
||||
|
||||
|
||||
class ResourceAdapter(AbstractResource):
|
||||
|
||||
def __init__(self, route):
|
||||
assert isinstance(route, Route), \
|
||||
'Instance of Route class is required, got {!r}'.format(route)
|
||||
super().__init__(name=route.name)
|
||||
self._route = route
|
||||
route._resource = self
|
||||
|
||||
def url(self, **kwargs):
|
||||
return self._route.url(**kwargs)
|
||||
|
||||
@asyncio.coroutine
|
||||
def resolve(self, method, path):
|
||||
route_method = self._route.method
|
||||
allowed_methods = set()
|
||||
match_dict = self._route.match(path)
|
||||
if match_dict is not None:
|
||||
allowed_methods.add(route_method)
|
||||
if route_method == hdrs.METH_ANY or route_method == method:
|
||||
return (UrlMappingMatchInfo(match_dict, self._route),
|
||||
allowed_methods)
|
||||
return None, allowed_methods
|
||||
|
||||
def get_info(self):
|
||||
return self._route.get_info()
|
||||
|
||||
def __len__(self):
|
||||
return 1
|
||||
|
||||
def __iter__(self):
|
||||
yield self._route
|
||||
|
||||
|
||||
class Resource(AbstractResource):
|
||||
|
||||
def __init__(self, *, name=None):
|
||||
super().__init__(name=name)
|
||||
self._routes = []
|
||||
|
||||
def add_route(self, method, handler, *,
|
||||
expect_handler=None):
|
||||
|
||||
for route in self._routes:
|
||||
if route.method == method or route.method == hdrs.METH_ANY:
|
||||
raise RuntimeError("Added route will never be executed, "
|
||||
"method {route.method} is "
|
||||
"already registered".format(route=route))
|
||||
|
||||
route = ResourceRoute(method, handler, self,
|
||||
expect_handler=expect_handler)
|
||||
self.register_route(route)
|
||||
return route
|
||||
|
||||
def register_route(self, route):
|
||||
assert isinstance(route, ResourceRoute), \
|
||||
'Instance of Route class is required, got {!r}'.format(route)
|
||||
self._routes.append(route)
|
||||
|
||||
@asyncio.coroutine
|
||||
def resolve(self, method, path):
|
||||
allowed_methods = set()
|
||||
|
||||
match_dict = self._match(path)
|
||||
if match_dict is None:
|
||||
return None, allowed_methods
|
||||
|
||||
for route in self._routes:
|
||||
route_method = route.method
|
||||
allowed_methods.add(route_method)
|
||||
|
||||
if route_method == method or route_method == hdrs.METH_ANY:
|
||||
return UrlMappingMatchInfo(match_dict, route), allowed_methods
|
||||
else:
|
||||
return None, allowed_methods
|
||||
|
||||
def __len__(self):
|
||||
return len(self._routes)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._routes)
|
||||
|
||||
|
||||
class PlainResource(Resource):
|
||||
|
||||
def __init__(self, path, *, name=None):
|
||||
super().__init__(name=name)
|
||||
self._path = path
|
||||
|
||||
def _match(self, path):
|
||||
# string comparison is about 10 times faster than regexp matching
|
||||
if self._path == path:
|
||||
return {}
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_info(self):
|
||||
return {'path': self._path}
|
||||
|
||||
def url(self, *, query=None):
|
||||
return self._append_query(self._path, query)
|
||||
|
||||
def __repr__(self):
|
||||
name = "'" + self.name + "' " if self.name is not None else ""
|
||||
return "<PlainResource {name} {path}".format(name=name,
|
||||
path=self._path)
|
||||
|
||||
|
||||
class DynamicResource(Resource):
|
||||
|
||||
def __init__(self, pattern, formatter, *, name=None):
|
||||
super().__init__(name=name)
|
||||
self._pattern = pattern
|
||||
self._formatter = formatter
|
||||
|
||||
def _match(self, path):
|
||||
match = self._pattern.match(path)
|
||||
if match is None:
|
||||
return None
|
||||
else:
|
||||
return {key: unquote(value) for key, value in
|
||||
match.groupdict().items()}
|
||||
|
||||
def get_info(self):
|
||||
return {'formatter': self._formatter,
|
||||
'pattern': self._pattern}
|
||||
|
||||
def url(self, *, parts, query=None):
|
||||
url = self._formatter.format_map(parts)
|
||||
return self._append_query(url, query)
|
||||
|
||||
def __repr__(self):
|
||||
name = "'" + self.name + "' " if self.name is not None else ""
|
||||
return ("<DynamicResource {name} {formatter}"
|
||||
.format(name=name, formatter=self._formatter))
|
||||
|
||||
|
||||
class ResourceRoute(AbstractRoute):
|
||||
"""A route with resource"""
|
||||
|
||||
def __init__(self, method, handler, resource, *,
|
||||
expect_handler=None):
|
||||
super().__init__(method, handler, expect_handler=expect_handler,
|
||||
resource=resource)
|
||||
|
||||
def __repr__(self):
|
||||
return "<ResourceRoute [{method}] {resource} -> {handler!r}".format(
|
||||
method=self.method, resource=self._resource,
|
||||
handler=self.handler)
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._resource.name
|
||||
|
||||
def url(self, **kwargs):
|
||||
"""Construct url for route with additional params."""
|
||||
return self._resource.url(**kwargs)
|
||||
|
||||
def get_info(self):
|
||||
return self._resource.get_info()
|
||||
|
||||
_append_query = staticmethod(Resource._append_query)
|
||||
|
||||
|
||||
class Route(AbstractRoute):
|
||||
"""Old fashion route"""
|
||||
|
||||
def __init__(self, method, handler, name, *, expect_handler=None):
|
||||
super().__init__(method, handler, expect_handler=expect_handler)
|
||||
self._name = name
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
||||
@abc.abstractmethod
|
||||
def match(self, path):
|
||||
"""Return dict with info for given path or
|
||||
None if route cannot process path."""
|
||||
|
||||
_append_query = staticmethod(Resource._append_query)
|
||||
|
||||
|
||||
class PlainRoute(Route):
|
||||
|
||||
def __init__(self, method, handler, name, path, *, expect_handler=None):
|
||||
super().__init__(method, handler, name, expect_handler=expect_handler)
|
||||
self._path = path
|
||||
|
||||
def match(self, path):
|
||||
# string comparison is about 10 times faster than regexp matching
|
||||
if self._path == path:
|
||||
return {}
|
||||
else:
|
||||
return None
|
||||
|
||||
def url(self, *, query=None):
|
||||
return self._append_query(self._path, query)
|
||||
|
||||
def get_info(self):
|
||||
return {'path': self._path}
|
||||
|
||||
def __repr__(self):
|
||||
name = "'" + self.name + "' " if self.name is not None else ""
|
||||
return "<PlainRoute {name}[{method}] {path} -> {handler!r}".format(
|
||||
name=name, method=self.method, path=self._path,
|
||||
handler=self.handler)
|
||||
|
||||
|
||||
class DynamicRoute(Route):
|
||||
|
||||
def __init__(self, method, handler, name, pattern, formatter, *,
|
||||
expect_handler=None):
|
||||
super().__init__(method, handler, name, expect_handler=expect_handler)
|
||||
self._pattern = pattern
|
||||
self._formatter = formatter
|
||||
|
||||
def match(self, path):
|
||||
match = self._pattern.match(path)
|
||||
if match is None:
|
||||
return None
|
||||
else:
|
||||
return match.groupdict()
|
||||
|
||||
def url(self, *, parts, query=None):
|
||||
url = self._formatter.format_map(parts)
|
||||
return self._append_query(url, query)
|
||||
|
||||
def get_info(self):
|
||||
return {'formatter': self._formatter,
|
||||
'pattern': self._pattern}
|
||||
|
||||
def __repr__(self):
|
||||
name = "'" + self.name + "' " if self.name is not None else ""
|
||||
return ("<DynamicRoute {name}[{method}] {formatter} -> {handler!r}"
|
||||
.format(name=name, method=self.method,
|
||||
formatter=self._formatter, handler=self.handler))
|
||||
|
||||
|
||||
class StaticRoute(Route):
|
||||
|
||||
def __init__(self, name, prefix, directory, *,
|
||||
expect_handler=None, chunk_size=256*1024,
|
||||
response_factory=StreamResponse,
|
||||
show_index=False):
|
||||
assert prefix.startswith('/'), prefix
|
||||
assert prefix.endswith('/'), prefix
|
||||
super().__init__(
|
||||
'GET', self.handle, name, expect_handler=expect_handler)
|
||||
self._prefix = prefix
|
||||
self._prefix_len = len(self._prefix)
|
||||
try:
|
||||
directory = Path(directory)
|
||||
if str(directory).startswith('~'):
|
||||
directory = Path(os.path.expanduser(str(directory)))
|
||||
directory = directory.resolve()
|
||||
if not directory.is_dir():
|
||||
raise ValueError('Not a directory')
|
||||
except (FileNotFoundError, ValueError) as error:
|
||||
raise ValueError(
|
||||
"No directory exists at '{}'".format(directory)) from error
|
||||
self._directory = directory
|
||||
self._file_sender = FileSender(resp_factory=response_factory,
|
||||
chunk_size=chunk_size)
|
||||
self._show_index = show_index
|
||||
|
||||
def match(self, path):
|
||||
if not path.startswith(self._prefix):
|
||||
return None
|
||||
return {'filename': path[self._prefix_len:]}
|
||||
|
||||
def url(self, *, filename, query=None):
|
||||
if isinstance(filename, Path):
|
||||
filename = str(filename)
|
||||
while filename.startswith('/'):
|
||||
filename = filename[1:]
|
||||
url = self._prefix + filename
|
||||
return self._append_query(url, query)
|
||||
|
||||
def get_info(self):
|
||||
return {'directory': self._directory,
|
||||
'prefix': self._prefix}
|
||||
|
||||
@asyncio.coroutine
|
||||
def handle(self, request):
|
||||
filename = unquote(request.match_info['filename'])
|
||||
try:
|
||||
filepath = self._directory.joinpath(filename).resolve()
|
||||
filepath.relative_to(self._directory)
|
||||
except (ValueError, FileNotFoundError) as error:
|
||||
# relatively safe
|
||||
raise HTTPNotFound() from error
|
||||
except Exception as error:
|
||||
# perm error or other kind!
|
||||
request.app.logger.exception(error)
|
||||
raise HTTPNotFound() from error
|
||||
|
||||
# on opening a dir, load it's contents if allowed
|
||||
if filepath.is_dir():
|
||||
if self._show_index:
|
||||
try:
|
||||
ret = Response(text=self._directory_as_html(filepath),
|
||||
content_type="text/html")
|
||||
except PermissionError:
|
||||
raise HTTPForbidden()
|
||||
else:
|
||||
raise HTTPForbidden()
|
||||
elif filepath.is_file():
|
||||
ret = yield from self._file_sender.send(request, filepath)
|
||||
else:
|
||||
raise HTTPNotFound
|
||||
|
||||
return ret
|
||||
|
||||
def _directory_as_html(self, filepath):
|
||||
"returns directory's index as html"
|
||||
# sanity check
|
||||
assert filepath.is_dir()
|
||||
|
||||
posix_dir_len = len(self._directory.as_posix())
|
||||
|
||||
# remove the beginning of posix path, so it would be relative
|
||||
# to our added static path
|
||||
relative_path_to_dir = filepath.as_posix()[posix_dir_len:]
|
||||
index_of = "Index of /{}".format(relative_path_to_dir)
|
||||
head = "<head>\n<title>{}</title>\n</head>".format(index_of)
|
||||
h1 = "<h1>{}</h1>".format(index_of)
|
||||
|
||||
index_list = []
|
||||
dir_index = filepath.iterdir()
|
||||
for _file in sorted(dir_index):
|
||||
# show file url as relative to static path
|
||||
file_url = _file.as_posix()[posix_dir_len:]
|
||||
|
||||
# if file is a directory, add '/' to the end of the name
|
||||
if _file.is_dir():
|
||||
file_name = "{}/".format(_file.name)
|
||||
else:
|
||||
file_name = _file.name
|
||||
|
||||
index_list.append(
|
||||
'<li><a href="{url}">{name}</a></li>'.format(url=file_url,
|
||||
name=file_name)
|
||||
)
|
||||
ul = "<ul>\n{}\n</ul>".format('\n'.join(index_list))
|
||||
body = "<body>\n{}\n{}\n</body>".format(h1, ul)
|
||||
|
||||
html = "<html>\n{}\n{}\n</html>".format(head, body)
|
||||
|
||||
return html
|
||||
|
||||
def __repr__(self):
|
||||
name = "'" + self.name + "' " if self.name is not None else ""
|
||||
return "<StaticRoute {name}[{method}] {path} -> {directory!r}".format(
|
||||
name=name, method=self.method, path=self._prefix,
|
||||
directory=self._directory)
|
||||
|
||||
|
||||
class SystemRoute(Route):
|
||||
|
||||
def __init__(self, http_exception):
|
||||
super().__init__(hdrs.METH_ANY, self._handler, None)
|
||||
self._http_exception = http_exception
|
||||
|
||||
def url(self, **kwargs):
|
||||
raise RuntimeError(".url() is not allowed for SystemRoute")
|
||||
|
||||
def match(self, path):
|
||||
return None
|
||||
|
||||
def get_info(self):
|
||||
return {'http_exception': self._http_exception}
|
||||
|
||||
@asyncio.coroutine
|
||||
def _handler(self, request):
|
||||
raise self._http_exception
|
||||
|
||||
@property
|
||||
def status(self):
|
||||
return self._http_exception.status
|
||||
|
||||
@property
|
||||
def reason(self):
|
||||
return self._http_exception.reason
|
||||
|
||||
def __repr__(self):
|
||||
return "<SystemRoute {self.status}: {self.reason}>".format(self=self)
|
||||
|
||||
|
||||
class View(AbstractView):
|
||||
|
||||
@asyncio.coroutine
|
||||
def __iter__(self):
|
||||
if self.request.method not in hdrs.METH_ALL:
|
||||
self._raise_allowed_methods()
|
||||
method = getattr(self, self.request.method.lower(), None)
|
||||
if method is None:
|
||||
self._raise_allowed_methods()
|
||||
resp = yield from method()
|
||||
return resp
|
||||
|
||||
if PY_35:
|
||||
def __await__(self):
|
||||
return (yield from self.__iter__())
|
||||
|
||||
def _raise_allowed_methods(self):
|
||||
allowed_methods = {
|
||||
m for m in hdrs.METH_ALL if hasattr(self, m.lower())}
|
||||
raise HTTPMethodNotAllowed(self.request.method, allowed_methods)
|
||||
|
||||
|
||||
class ResourcesView(Sized, Iterable, Container):
|
||||
|
||||
def __init__(self, resources):
|
||||
self._resources = resources
|
||||
|
||||
def __len__(self):
|
||||
return len(self._resources)
|
||||
|
||||
def __iter__(self):
|
||||
yield from self._resources
|
||||
|
||||
def __contains__(self, resource):
|
||||
return resource in self._resources
|
||||
|
||||
|
||||
class RoutesView(Sized, Iterable, Container):
|
||||
|
||||
def __init__(self, resources):
|
||||
self._routes = []
|
||||
for resource in resources:
|
||||
for route in resource:
|
||||
self._routes.append(route)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._routes)
|
||||
|
||||
def __iter__(self):
|
||||
yield from self._routes
|
||||
|
||||
def __contains__(self, route):
|
||||
return route in self._routes
|
||||
|
||||
|
||||
class UrlDispatcher(AbstractRouter, collections.abc.Mapping):
|
||||
|
||||
DYN = re.compile(r'^\{(?P<var>[a-zA-Z][_a-zA-Z0-9]*)\}$')
|
||||
DYN_WITH_RE = re.compile(
|
||||
r'^\{(?P<var>[a-zA-Z][_a-zA-Z0-9]*):(?P<re>.+)\}$')
|
||||
GOOD = r'[^{}/]+'
|
||||
ROUTE_RE = re.compile(r'(\{[_a-zA-Z][^{}]*(?:\{[^{}]*\}[^{}]*)*\})')
|
||||
NAME_SPLIT_RE = re.compile('[.:-]')
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._resources = []
|
||||
self._named_resources = {}
|
||||
|
||||
@asyncio.coroutine
|
||||
def resolve(self, request):
|
||||
path = request.raw_path
|
||||
method = request.method
|
||||
allowed_methods = set()
|
||||
|
||||
for resource in self._resources:
|
||||
match_dict, allowed = yield from resource.resolve(method, path)
|
||||
if match_dict is not None:
|
||||
return match_dict
|
||||
else:
|
||||
allowed_methods |= allowed
|
||||
else:
|
||||
if allowed_methods:
|
||||
return MatchInfoError(HTTPMethodNotAllowed(method,
|
||||
allowed_methods))
|
||||
else:
|
||||
return MatchInfoError(HTTPNotFound())
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._named_resources)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._named_resources)
|
||||
|
||||
def __contains__(self, name):
|
||||
return name in self._named_resources
|
||||
|
||||
def __getitem__(self, name):
|
||||
return self._named_resources[name]
|
||||
|
||||
def resources(self):
|
||||
return ResourcesView(self._resources)
|
||||
|
||||
def routes(self):
|
||||
return RoutesView(self._resources)
|
||||
|
||||
def named_resources(self):
|
||||
return MappingProxyType(self._named_resources)
|
||||
|
||||
def named_routes(self):
|
||||
# NB: it's ambiguous but it's really resources.
|
||||
warnings.warn("Use .named_resources instead", DeprecationWarning)
|
||||
return self.named_resources()
|
||||
|
||||
def register_route(self, route):
|
||||
warnings.warn("Use resource-based interface", DeprecationWarning)
|
||||
resource = ResourceAdapter(route)
|
||||
self._reg_resource(resource)
|
||||
|
||||
def _reg_resource(self, resource):
|
||||
assert isinstance(resource, AbstractResource), \
|
||||
'Instance of AbstractResource class is required, got {!r}'.format(
|
||||
resource)
|
||||
|
||||
name = resource.name
|
||||
|
||||
if name is not None:
|
||||
parts = self.NAME_SPLIT_RE.split(name)
|
||||
for part in parts:
|
||||
if not part.isidentifier() or keyword.iskeyword(part):
|
||||
raise ValueError('Incorrect route name {!r}, '
|
||||
'the name should be a sequence of '
|
||||
'python identifiers separated '
|
||||
'by dash, dot or column'.format(name))
|
||||
if name in self._named_resources:
|
||||
raise ValueError('Duplicate {!r}, '
|
||||
'already handled by {!r}'
|
||||
.format(name, self._named_resources[name]))
|
||||
self._named_resources[name] = resource
|
||||
self._resources.append(resource)
|
||||
|
||||
def add_resource(self, path, *, name=None):
|
||||
if not path.startswith('/'):
|
||||
raise ValueError("path should be started with /")
|
||||
if not ('{' in path or '}' in path or self.ROUTE_RE.search(path)):
|
||||
resource = PlainResource(path, name=name)
|
||||
self._reg_resource(resource)
|
||||
return resource
|
||||
|
||||
pattern = ''
|
||||
formatter = ''
|
||||
for part in self.ROUTE_RE.split(path):
|
||||
match = self.DYN.match(part)
|
||||
if match:
|
||||
pattern += '(?P<{}>{})'.format(match.group('var'), self.GOOD)
|
||||
formatter += '{' + match.group('var') + '}'
|
||||
continue
|
||||
|
||||
match = self.DYN_WITH_RE.match(part)
|
||||
if match:
|
||||
pattern += '(?P<{var}>{re})'.format(**match.groupdict())
|
||||
formatter += '{' + match.group('var') + '}'
|
||||
continue
|
||||
|
||||
if '{' in part or '}' in part:
|
||||
raise ValueError("Invalid path '{}'['{}']".format(path, part))
|
||||
|
||||
formatter += part
|
||||
pattern += re.escape(part)
|
||||
|
||||
try:
|
||||
compiled = re.compile('^' + pattern + '$')
|
||||
except re.error as exc:
|
||||
raise ValueError(
|
||||
"Bad pattern '{}': {}".format(pattern, exc)) from None
|
||||
resource = DynamicResource(compiled, formatter, name=name)
|
||||
self._reg_resource(resource)
|
||||
return resource
|
||||
|
||||
def add_route(self, method, path, handler,
|
||||
*, name=None, expect_handler=None):
|
||||
resource = self.add_resource(path, name=name)
|
||||
return resource.add_route(method, handler,
|
||||
expect_handler=expect_handler)
|
||||
|
||||
def add_static(self, prefix, path, *, name=None, expect_handler=None,
|
||||
chunk_size=256*1024, response_factory=StreamResponse,
|
||||
show_index=False):
|
||||
"""Add static files view.
|
||||
|
||||
prefix - url prefix
|
||||
path - folder with files
|
||||
|
||||
"""
|
||||
assert prefix.startswith('/')
|
||||
if not prefix.endswith('/'):
|
||||
prefix += '/'
|
||||
route = StaticRoute(name, prefix, path,
|
||||
expect_handler=expect_handler,
|
||||
chunk_size=chunk_size,
|
||||
response_factory=response_factory,
|
||||
show_index=show_index)
|
||||
self.register_route(route)
|
||||
return route
|
||||
|
||||
def add_head(self, *args, **kwargs):
|
||||
"""
|
||||
Shortcut for add_route with method HEAD
|
||||
"""
|
||||
return self.add_route(hdrs.METH_HEAD, *args, **kwargs)
|
||||
|
||||
def add_get(self, *args, **kwargs):
|
||||
"""
|
||||
Shortcut for add_route with method GET
|
||||
"""
|
||||
return self.add_route(hdrs.METH_GET, *args, **kwargs)
|
||||
|
||||
def add_post(self, *args, **kwargs):
|
||||
"""
|
||||
Shortcut for add_route with method POST
|
||||
"""
|
||||
return self.add_route(hdrs.METH_POST, *args, **kwargs)
|
||||
|
||||
def add_put(self, *args, **kwargs):
|
||||
"""
|
||||
Shortcut for add_route with method PUT
|
||||
"""
|
||||
return self.add_route(hdrs.METH_PUT, *args, **kwargs)
|
||||
|
||||
def add_patch(self, *args, **kwargs):
|
||||
"""
|
||||
Shortcut for add_route with method PATCH
|
||||
"""
|
||||
return self.add_route(hdrs.METH_PATCH, *args, **kwargs)
|
||||
|
||||
def add_delete(self, *args, **kwargs):
|
||||
"""
|
||||
Shortcut for add_route with method DELETE
|
||||
"""
|
||||
return self.add_route(hdrs.METH_DELETE, *args, **kwargs)
|
||||
|
|
@ -0,0 +1,320 @@
|
|||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import warnings
|
||||
from collections import namedtuple
|
||||
|
||||
from . import Timeout, hdrs
|
||||
from ._ws_impl import (CLOSED_MESSAGE, WebSocketError, WSMessage, WSMsgType,
|
||||
do_handshake)
|
||||
from .errors import ClientDisconnectedError, HttpProcessingError
|
||||
from .web_exceptions import (HTTPBadRequest, HTTPInternalServerError,
|
||||
HTTPMethodNotAllowed)
|
||||
from .web_reqrep import StreamResponse
|
||||
|
||||
__all__ = ('WebSocketResponse', 'WebSocketReady', 'MsgType', 'WSMsgType',)
|
||||
|
||||
PY_35 = sys.version_info >= (3, 5)
|
||||
PY_352 = sys.version_info >= (3, 5, 2)
|
||||
|
||||
THRESHOLD_CONNLOST_ACCESS = 5
|
||||
|
||||
|
||||
# deprecated since 1.0
|
||||
MsgType = WSMsgType
|
||||
|
||||
|
||||
class WebSocketReady(namedtuple('WebSocketReady', 'ok protocol')):
|
||||
def __bool__(self):
|
||||
return self.ok
|
||||
|
||||
|
||||
class WebSocketResponse(StreamResponse):
|
||||
|
||||
def __init__(self, *,
|
||||
timeout=10.0, autoclose=True, autoping=True, protocols=()):
|
||||
super().__init__(status=101)
|
||||
self._protocols = protocols
|
||||
self._protocol = None
|
||||
self._writer = None
|
||||
self._reader = None
|
||||
self._closed = False
|
||||
self._closing = False
|
||||
self._conn_lost = 0
|
||||
self._close_code = None
|
||||
self._loop = None
|
||||
self._waiting = False
|
||||
self._exception = None
|
||||
self._timeout = timeout
|
||||
self._autoclose = autoclose
|
||||
self._autoping = autoping
|
||||
|
||||
@asyncio.coroutine
|
||||
def prepare(self, request):
|
||||
# make pre-check to don't hide it by do_handshake() exceptions
|
||||
resp_impl = self._start_pre_check(request)
|
||||
if resp_impl is not None:
|
||||
return resp_impl
|
||||
|
||||
parser, protocol, writer = self._pre_start(request)
|
||||
resp_impl = yield from super().prepare(request)
|
||||
self._post_start(request, parser, protocol, writer)
|
||||
return resp_impl
|
||||
|
||||
def _pre_start(self, request):
|
||||
try:
|
||||
status, headers, parser, writer, protocol = do_handshake(
|
||||
request.method, request.headers, request.transport,
|
||||
self._protocols)
|
||||
except HttpProcessingError as err:
|
||||
if err.code == 405:
|
||||
raise HTTPMethodNotAllowed(
|
||||
request.method, [hdrs.METH_GET], body=b'')
|
||||
elif err.code == 400:
|
||||
raise HTTPBadRequest(text=err.message, headers=err.headers)
|
||||
else: # pragma: no cover
|
||||
raise HTTPInternalServerError() from err
|
||||
|
||||
if self.status != status:
|
||||
self.set_status(status)
|
||||
for k, v in headers:
|
||||
self.headers[k] = v
|
||||
self.force_close()
|
||||
return parser, protocol, writer
|
||||
|
||||
def _post_start(self, request, parser, protocol, writer):
|
||||
self._reader = request._reader.set_parser(parser)
|
||||
self._writer = writer
|
||||
self._protocol = protocol
|
||||
self._loop = request.app.loop
|
||||
|
||||
def start(self, request):
|
||||
warnings.warn('use .prepare(request) instead', DeprecationWarning)
|
||||
# make pre-check to don't hide it by do_handshake() exceptions
|
||||
resp_impl = self._start_pre_check(request)
|
||||
if resp_impl is not None:
|
||||
return resp_impl
|
||||
|
||||
parser, protocol, writer = self._pre_start(request)
|
||||
resp_impl = super().start(request)
|
||||
self._post_start(request, parser, protocol, writer)
|
||||
return resp_impl
|
||||
|
||||
def can_prepare(self, request):
|
||||
if self._writer is not None:
|
||||
raise RuntimeError('Already started')
|
||||
try:
|
||||
_, _, _, _, protocol = do_handshake(
|
||||
request.method, request.headers, request.transport,
|
||||
self._protocols)
|
||||
except HttpProcessingError:
|
||||
return WebSocketReady(False, None)
|
||||
else:
|
||||
return WebSocketReady(True, protocol)
|
||||
|
||||
def can_start(self, request):
|
||||
warnings.warn('use .can_prepare(request) instead', DeprecationWarning)
|
||||
return self.can_prepare(request)
|
||||
|
||||
@property
|
||||
def closed(self):
|
||||
return self._closed
|
||||
|
||||
@property
|
||||
def close_code(self):
|
||||
return self._close_code
|
||||
|
||||
@property
|
||||
def protocol(self):
|
||||
return self._protocol
|
||||
|
||||
def exception(self):
|
||||
return self._exception
|
||||
|
||||
def ping(self, message='b'):
|
||||
if self._writer is None:
|
||||
raise RuntimeError('Call .prepare() first')
|
||||
if self._closed:
|
||||
raise RuntimeError('websocket connection is closing')
|
||||
self._writer.ping(message)
|
||||
|
||||
def pong(self, message='b'):
|
||||
# unsolicited pong
|
||||
if self._writer is None:
|
||||
raise RuntimeError('Call .prepare() first')
|
||||
if self._closed:
|
||||
raise RuntimeError('websocket connection is closing')
|
||||
self._writer.pong(message)
|
||||
|
||||
def send_str(self, data):
|
||||
if self._writer is None:
|
||||
raise RuntimeError('Call .prepare() first')
|
||||
if self._closed:
|
||||
raise RuntimeError('websocket connection is closing')
|
||||
if not isinstance(data, str):
|
||||
raise TypeError('data argument must be str (%r)' % type(data))
|
||||
self._writer.send(data, binary=False)
|
||||
|
||||
def send_bytes(self, data):
|
||||
if self._writer is None:
|
||||
raise RuntimeError('Call .prepare() first')
|
||||
if self._closed:
|
||||
raise RuntimeError('websocket connection is closing')
|
||||
if not isinstance(data, (bytes, bytearray, memoryview)):
|
||||
raise TypeError('data argument must be byte-ish (%r)' %
|
||||
type(data))
|
||||
self._writer.send(data, binary=True)
|
||||
|
||||
def send_json(self, data, *, dumps=json.dumps):
|
||||
self.send_str(dumps(data))
|
||||
|
||||
@asyncio.coroutine
|
||||
def write_eof(self):
|
||||
if self._eof_sent:
|
||||
return
|
||||
if self._resp_impl is None:
|
||||
raise RuntimeError("Response has not been started")
|
||||
|
||||
yield from self.close()
|
||||
self._eof_sent = True
|
||||
|
||||
@asyncio.coroutine
|
||||
def close(self, *, code=1000, message=b''):
|
||||
if self._writer is None:
|
||||
raise RuntimeError('Call .prepare() first')
|
||||
|
||||
if not self._closed:
|
||||
self._closed = True
|
||||
try:
|
||||
self._writer.close(code, message)
|
||||
except (asyncio.CancelledError, asyncio.TimeoutError):
|
||||
self._close_code = 1006
|
||||
raise
|
||||
except Exception as exc:
|
||||
self._close_code = 1006
|
||||
self._exception = exc
|
||||
return True
|
||||
|
||||
if self._closing:
|
||||
return True
|
||||
|
||||
begin = self._loop.time()
|
||||
while self._loop.time() - begin < self._timeout:
|
||||
try:
|
||||
with Timeout(timeout=self._timeout,
|
||||
loop=self._loop):
|
||||
msg = yield from self._reader.read()
|
||||
except asyncio.CancelledError:
|
||||
self._close_code = 1006
|
||||
raise
|
||||
except Exception as exc:
|
||||
self._close_code = 1006
|
||||
self._exception = exc
|
||||
return True
|
||||
|
||||
if msg.type == WSMsgType.CLOSE:
|
||||
self._close_code = msg.data
|
||||
return True
|
||||
|
||||
self._close_code = 1006
|
||||
self._exception = asyncio.TimeoutError()
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
@asyncio.coroutine
|
||||
def receive(self):
|
||||
if self._reader is None:
|
||||
raise RuntimeError('Call .prepare() first')
|
||||
if self._waiting:
|
||||
raise RuntimeError('Concurrent call to receive() is not allowed')
|
||||
|
||||
self._waiting = True
|
||||
try:
|
||||
while True:
|
||||
if self._closed:
|
||||
self._conn_lost += 1
|
||||
if self._conn_lost >= THRESHOLD_CONNLOST_ACCESS:
|
||||
raise RuntimeError('WebSocket connection is closed.')
|
||||
return CLOSED_MESSAGE
|
||||
|
||||
try:
|
||||
msg = yield from self._reader.read()
|
||||
except (asyncio.CancelledError, asyncio.TimeoutError):
|
||||
raise
|
||||
except WebSocketError as exc:
|
||||
self._close_code = exc.code
|
||||
yield from self.close(code=exc.code)
|
||||
return WSMessage(WSMsgType.ERROR, exc, None)
|
||||
except ClientDisconnectedError:
|
||||
self._closed = True
|
||||
self._close_code = 1006
|
||||
return WSMessage(WSMsgType.CLOSE, None, None)
|
||||
except Exception as exc:
|
||||
self._exception = exc
|
||||
self._closing = True
|
||||
self._close_code = 1006
|
||||
yield from self.close()
|
||||
return WSMessage(WSMsgType.ERROR, exc, None)
|
||||
|
||||
if msg.type == WSMsgType.CLOSE:
|
||||
self._closing = True
|
||||
self._close_code = msg.data
|
||||
if not self._closed and self._autoclose:
|
||||
yield from self.close()
|
||||
return msg
|
||||
if msg.type == WSMsgType.PING and self._autoping:
|
||||
self.pong(msg.data)
|
||||
elif msg.type == WSMsgType.PONG and self._autoping:
|
||||
continue
|
||||
else:
|
||||
return msg
|
||||
finally:
|
||||
self._waiting = False
|
||||
|
||||
@asyncio.coroutine
|
||||
def receive_msg(self):
|
||||
warnings.warn(
|
||||
'receive_msg() coroutine is deprecated. use receive() instead',
|
||||
DeprecationWarning)
|
||||
return (yield from self.receive())
|
||||
|
||||
@asyncio.coroutine
|
||||
def receive_str(self):
|
||||
msg = yield from self.receive()
|
||||
if msg.type != WSMsgType.TEXT:
|
||||
raise TypeError(
|
||||
"Received message {}:{!r} is not str".format(msg.type,
|
||||
msg.data))
|
||||
return msg.data
|
||||
|
||||
@asyncio.coroutine
|
||||
def receive_bytes(self):
|
||||
msg = yield from self.receive()
|
||||
if msg.type != WSMsgType.BINARY:
|
||||
raise TypeError(
|
||||
"Received message {}:{!r} is not bytes".format(msg.type,
|
||||
msg.data))
|
||||
return msg.data
|
||||
|
||||
@asyncio.coroutine
|
||||
def receive_json(self, *, loads=json.loads):
|
||||
data = yield from self.receive_str()
|
||||
return loads(data)
|
||||
|
||||
def write(self, data):
|
||||
raise RuntimeError("Cannot call .write() for websocket")
|
||||
|
||||
if PY_35:
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
if not PY_352: # pragma: no cover
|
||||
__aiter__ = asyncio.coroutine(__aiter__)
|
||||
|
||||
@asyncio.coroutine
|
||||
def __anext__(self):
|
||||
msg = yield from self.receive()
|
||||
if msg.type == WSMsgType.CLOSE:
|
||||
raise StopAsyncIteration # NOQA
|
||||
return msg
|
||||
|
|
@ -0,0 +1,195 @@
|
|||
"""Async gunicorn worker for aiohttp.web"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import signal
|
||||
import ssl
|
||||
import sys
|
||||
|
||||
import gunicorn.workers.base as base
|
||||
from gunicorn.config import AccessLogFormat as GunicornAccessLogFormat
|
||||
|
||||
from aiohttp.helpers import AccessLogger, ensure_future
|
||||
|
||||
__all__ = ('GunicornWebWorker', 'GunicornUVLoopWebWorker')
|
||||
|
||||
|
||||
class GunicornWebWorker(base.Worker):
|
||||
|
||||
DEFAULT_AIOHTTP_LOG_FORMAT = AccessLogger.LOG_FORMAT
|
||||
DEFAULT_GUNICORN_LOG_FORMAT = GunicornAccessLogFormat.default
|
||||
|
||||
def __init__(self, *args, **kw): # pragma: no cover
|
||||
super().__init__(*args, **kw)
|
||||
|
||||
self.servers = {}
|
||||
self.exit_code = 0
|
||||
|
||||
def init_process(self):
|
||||
# create new event_loop after fork
|
||||
asyncio.get_event_loop().close()
|
||||
|
||||
self.loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self.loop)
|
||||
|
||||
super().init_process()
|
||||
|
||||
def run(self):
|
||||
self.loop.run_until_complete(self.wsgi.startup())
|
||||
self._runner = ensure_future(self._run(), loop=self.loop)
|
||||
|
||||
try:
|
||||
self.loop.run_until_complete(self._runner)
|
||||
finally:
|
||||
self.loop.close()
|
||||
|
||||
sys.exit(self.exit_code)
|
||||
|
||||
def make_handler(self, app):
|
||||
return app.make_handler(
|
||||
logger=self.log,
|
||||
slow_request_timeout=self.cfg.timeout,
|
||||
keepalive_timeout=self.cfg.keepalive,
|
||||
access_log=self.log.access_log,
|
||||
access_log_format=self._get_valid_log_format(
|
||||
self.cfg.access_log_format))
|
||||
|
||||
@asyncio.coroutine
|
||||
def close(self):
|
||||
if self.servers:
|
||||
servers = self.servers
|
||||
self.servers = None
|
||||
|
||||
# stop accepting connections
|
||||
for server, handler in servers.items():
|
||||
self.log.info("Stopping server: %s, connections: %s",
|
||||
self.pid, len(handler.connections))
|
||||
server.close()
|
||||
yield from server.wait_closed()
|
||||
|
||||
# send on_shutdown event
|
||||
yield from self.wsgi.shutdown()
|
||||
|
||||
# stop alive connections
|
||||
tasks = [
|
||||
handler.finish_connections(
|
||||
timeout=self.cfg.graceful_timeout / 100 * 95)
|
||||
for handler in servers.values()]
|
||||
yield from asyncio.gather(*tasks, loop=self.loop)
|
||||
|
||||
# cleanup application
|
||||
yield from self.wsgi.cleanup()
|
||||
|
||||
@asyncio.coroutine
|
||||
def _run(self):
|
||||
|
||||
ctx = self._create_ssl_context(self.cfg) if self.cfg.is_ssl else None
|
||||
|
||||
for sock in self.sockets:
|
||||
handler = self.make_handler(self.wsgi)
|
||||
srv = yield from self.loop.create_server(handler, sock=sock.sock,
|
||||
ssl=ctx)
|
||||
self.servers[srv] = handler
|
||||
|
||||
# If our parent changed then we shut down.
|
||||
pid = os.getpid()
|
||||
try:
|
||||
while self.alive:
|
||||
self.notify()
|
||||
|
||||
cnt = sum(handler.requests_count
|
||||
for handler in self.servers.values())
|
||||
if self.cfg.max_requests and cnt > self.cfg.max_requests:
|
||||
self.alive = False
|
||||
self.log.info("Max requests, shutting down: %s", self)
|
||||
|
||||
elif pid == os.getpid() and self.ppid != os.getppid():
|
||||
self.alive = False
|
||||
self.log.info("Parent changed, shutting down: %s", self)
|
||||
else:
|
||||
yield from asyncio.sleep(1.0, loop=self.loop)
|
||||
|
||||
except BaseException:
|
||||
pass
|
||||
|
||||
yield from self.close()
|
||||
|
||||
def init_signals(self):
|
||||
# Set up signals through the event loop API.
|
||||
|
||||
self.loop.add_signal_handler(signal.SIGQUIT, self.handle_quit,
|
||||
signal.SIGQUIT, None)
|
||||
|
||||
self.loop.add_signal_handler(signal.SIGTERM, self.handle_exit,
|
||||
signal.SIGTERM, None)
|
||||
|
||||
self.loop.add_signal_handler(signal.SIGINT, self.handle_quit,
|
||||
signal.SIGINT, None)
|
||||
|
||||
self.loop.add_signal_handler(signal.SIGWINCH, self.handle_winch,
|
||||
signal.SIGWINCH, None)
|
||||
|
||||
self.loop.add_signal_handler(signal.SIGUSR1, self.handle_usr1,
|
||||
signal.SIGUSR1, None)
|
||||
|
||||
self.loop.add_signal_handler(signal.SIGABRT, self.handle_abort,
|
||||
signal.SIGABRT, None)
|
||||
|
||||
# Don't let SIGTERM and SIGUSR1 disturb active requests
|
||||
# by interrupting system calls
|
||||
signal.siginterrupt(signal.SIGTERM, False)
|
||||
signal.siginterrupt(signal.SIGUSR1, False)
|
||||
|
||||
def handle_quit(self, sig, frame):
|
||||
self.alive = False
|
||||
|
||||
def handle_abort(self, sig, frame):
|
||||
self.alive = False
|
||||
self.exit_code = 1
|
||||
|
||||
@staticmethod
|
||||
def _create_ssl_context(cfg):
|
||||
""" Creates SSLContext instance for usage in asyncio.create_server.
|
||||
|
||||
See ssl.SSLSocket.__init__ for more details.
|
||||
"""
|
||||
ctx = ssl.SSLContext(cfg.ssl_version)
|
||||
ctx.load_cert_chain(cfg.certfile, cfg.keyfile)
|
||||
ctx.verify_mode = cfg.cert_reqs
|
||||
if cfg.ca_certs:
|
||||
ctx.load_verify_locations(cfg.ca_certs)
|
||||
if cfg.ciphers:
|
||||
ctx.set_ciphers(cfg.ciphers)
|
||||
return ctx
|
||||
|
||||
def _get_valid_log_format(self, source_format):
|
||||
if source_format == self.DEFAULT_GUNICORN_LOG_FORMAT:
|
||||
return self.DEFAULT_AIOHTTP_LOG_FORMAT
|
||||
elif re.search(r'%\([^\)]+\)', source_format):
|
||||
raise ValueError(
|
||||
"Gunicorn's style options in form of `%(name)s` are not "
|
||||
"supported for the log formatting. Please use aiohttp's "
|
||||
"format specification to configure access log formatting: "
|
||||
"http://aiohttp.readthedocs.io/en/stable/logging.html"
|
||||
"#format-specification"
|
||||
)
|
||||
else:
|
||||
return source_format
|
||||
|
||||
|
||||
class GunicornUVLoopWebWorker(GunicornWebWorker):
|
||||
|
||||
def init_process(self):
|
||||
import uvloop
|
||||
|
||||
# Close any existing event loop before setting a
|
||||
# new policy.
|
||||
asyncio.get_event_loop().close()
|
||||
|
||||
# Setup uvloop policy, so that every
|
||||
# asyncio.get_event_loop() will create an instance
|
||||
# of uvloop event loop.
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
|
||||
super().init_process()
|
||||
|
|
@ -0,0 +1,235 @@
|
|||
"""wsgi server.
|
||||
|
||||
TODO:
|
||||
* proxy protocol
|
||||
* x-forward security
|
||||
* wsgi file support (os.sendfile)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import io
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import hdrs, server
|
||||
|
||||
__all__ = ('WSGIServerHttpProtocol',)
|
||||
|
||||
|
||||
class WSGIServerHttpProtocol(server.ServerHttpProtocol):
|
||||
"""HTTP Server that implements the Python WSGI protocol.
|
||||
|
||||
It uses 'wsgi.async' of 'True'. 'wsgi.input' can behave differently
|
||||
depends on 'readpayload' constructor parameter. If readpayload is set to
|
||||
True, wsgi server reads all incoming data into BytesIO object and
|
||||
sends it as 'wsgi.input' environ var. If readpayload is set to false
|
||||
'wsgi.input' is a StreamReader and application should read incoming
|
||||
data with "yield from environ['wsgi.input'].read()". It defaults to False.
|
||||
"""
|
||||
|
||||
SCRIPT_NAME = os.environ.get('SCRIPT_NAME', '')
|
||||
|
||||
def __init__(self, app, readpayload=False, is_ssl=False, *args, **kw):
|
||||
super().__init__(*args, **kw)
|
||||
|
||||
self.wsgi = app
|
||||
self.is_ssl = is_ssl
|
||||
self.readpayload = readpayload
|
||||
|
||||
def create_wsgi_response(self, message):
|
||||
return WsgiResponse(self.writer, message)
|
||||
|
||||
def create_wsgi_environ(self, message, payload):
|
||||
uri_parts = urlsplit(message.path)
|
||||
|
||||
environ = {
|
||||
'wsgi.input': payload,
|
||||
'wsgi.errors': sys.stderr,
|
||||
'wsgi.version': (1, 0),
|
||||
'wsgi.async': True,
|
||||
'wsgi.multithread': False,
|
||||
'wsgi.multiprocess': False,
|
||||
'wsgi.run_once': False,
|
||||
'wsgi.file_wrapper': FileWrapper,
|
||||
'SERVER_SOFTWARE': aiohttp.HttpMessage.SERVER_SOFTWARE,
|
||||
'REQUEST_METHOD': message.method,
|
||||
'QUERY_STRING': uri_parts.query or '',
|
||||
'RAW_URI': message.path,
|
||||
'SERVER_PROTOCOL': 'HTTP/%s.%s' % message.version
|
||||
}
|
||||
|
||||
script_name = self.SCRIPT_NAME
|
||||
|
||||
for hdr_name, hdr_value in message.headers.items():
|
||||
hdr_name = hdr_name.upper()
|
||||
if hdr_name == 'SCRIPT_NAME':
|
||||
script_name = hdr_value
|
||||
elif hdr_name == 'CONTENT-TYPE':
|
||||
environ['CONTENT_TYPE'] = hdr_value
|
||||
continue
|
||||
elif hdr_name == 'CONTENT-LENGTH':
|
||||
environ['CONTENT_LENGTH'] = hdr_value
|
||||
continue
|
||||
|
||||
key = 'HTTP_%s' % hdr_name.replace('-', '_')
|
||||
if key in environ:
|
||||
hdr_value = '%s,%s' % (environ[key], hdr_value)
|
||||
|
||||
environ[key] = hdr_value
|
||||
|
||||
url_scheme = environ.get('HTTP_X_FORWARDED_PROTO')
|
||||
if url_scheme is None:
|
||||
url_scheme = 'https' if self.is_ssl else 'http'
|
||||
environ['wsgi.url_scheme'] = url_scheme
|
||||
|
||||
# authors should be aware that REMOTE_HOST and REMOTE_ADDR
|
||||
# may not qualify the remote addr
|
||||
# also SERVER_PORT variable MUST be set to the TCP/IP port number on
|
||||
# which this request is received from the client.
|
||||
# http://www.ietf.org/rfc/rfc3875
|
||||
|
||||
family = self.transport.get_extra_info('socket').family
|
||||
if family in (socket.AF_INET, socket.AF_INET6):
|
||||
peername = self.transport.get_extra_info('peername')
|
||||
environ['REMOTE_ADDR'] = peername[0]
|
||||
environ['REMOTE_PORT'] = str(peername[1])
|
||||
http_host = message.headers.get("HOST", None)
|
||||
if http_host:
|
||||
hostport = http_host.split(":")
|
||||
environ['SERVER_NAME'] = hostport[0]
|
||||
if len(hostport) > 1:
|
||||
environ['SERVER_PORT'] = str(hostport[1])
|
||||
else:
|
||||
environ['SERVER_PORT'] = '80'
|
||||
else:
|
||||
# SERVER_NAME should be set to value of Host header, but this
|
||||
# header is not required. In this case we shoud set it to local
|
||||
# address of socket
|
||||
sockname = self.transport.get_extra_info('sockname')
|
||||
environ['SERVER_NAME'] = sockname[0]
|
||||
environ['SERVER_PORT'] = str(sockname[1])
|
||||
else:
|
||||
# We are behind reverse proxy, so get all vars from headers
|
||||
for header in ('REMOTE_ADDR', 'REMOTE_PORT',
|
||||
'SERVER_NAME', 'SERVER_PORT'):
|
||||
environ[header] = message.headers.get(header, '')
|
||||
|
||||
path_info = uri_parts.path
|
||||
if script_name:
|
||||
path_info = path_info.split(script_name, 1)[-1]
|
||||
|
||||
environ['PATH_INFO'] = path_info
|
||||
environ['SCRIPT_NAME'] = script_name
|
||||
|
||||
environ['async.reader'] = self.reader
|
||||
environ['async.writer'] = self.writer
|
||||
|
||||
return environ
|
||||
|
||||
@asyncio.coroutine
|
||||
def handle_request(self, message, payload):
|
||||
"""Handle a single HTTP request"""
|
||||
now = self._loop.time()
|
||||
|
||||
if self.readpayload:
|
||||
wsgiinput = io.BytesIO()
|
||||
wsgiinput.write((yield from payload.read()))
|
||||
wsgiinput.seek(0)
|
||||
payload = wsgiinput
|
||||
|
||||
environ = self.create_wsgi_environ(message, payload)
|
||||
response = self.create_wsgi_response(message)
|
||||
|
||||
riter = self.wsgi(environ, response.start_response)
|
||||
if isinstance(riter, asyncio.Future) or inspect.isgenerator(riter):
|
||||
riter = yield from riter
|
||||
|
||||
resp = response.response
|
||||
try:
|
||||
for item in riter:
|
||||
if isinstance(item, asyncio.Future):
|
||||
item = yield from item
|
||||
yield from resp.write(item)
|
||||
|
||||
yield from resp.write_eof()
|
||||
finally:
|
||||
if hasattr(riter, 'close'):
|
||||
riter.close()
|
||||
|
||||
if resp.keep_alive():
|
||||
self.keep_alive(True)
|
||||
|
||||
self.log_access(
|
||||
message, environ, response.response, self._loop.time() - now)
|
||||
|
||||
|
||||
class FileWrapper:
|
||||
"""Custom file wrapper."""
|
||||
|
||||
def __init__(self, fobj, chunk_size=8192):
|
||||
self.fobj = fobj
|
||||
self.chunk_size = chunk_size
|
||||
if hasattr(fobj, 'close'):
|
||||
self.close = fobj.close
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
data = self.fobj.read(self.chunk_size)
|
||||
if data:
|
||||
return data
|
||||
raise StopIteration
|
||||
|
||||
|
||||
class WsgiResponse:
|
||||
"""Implementation of start_response() callable as specified by PEP 3333"""
|
||||
|
||||
status = None
|
||||
|
||||
HOP_HEADERS = {
|
||||
hdrs.CONNECTION,
|
||||
hdrs.KEEP_ALIVE,
|
||||
hdrs.PROXY_AUTHENTICATE,
|
||||
hdrs.PROXY_AUTHORIZATION,
|
||||
hdrs.TE,
|
||||
hdrs.TRAILER,
|
||||
hdrs.TRANSFER_ENCODING,
|
||||
hdrs.UPGRADE,
|
||||
}
|
||||
|
||||
def __init__(self, writer, message):
|
||||
self.writer = writer
|
||||
self.message = message
|
||||
|
||||
def start_response(self, status, headers, exc_info=None):
|
||||
if exc_info:
|
||||
try:
|
||||
if self.status:
|
||||
raise exc_info[1]
|
||||
finally:
|
||||
exc_info = None
|
||||
|
||||
status_code = int(status.split(' ', 1)[0])
|
||||
|
||||
self.status = status
|
||||
resp = self.response = aiohttp.Response(
|
||||
self.writer, status_code,
|
||||
self.message.version, self.message.should_close)
|
||||
resp.HOP_HEADERS = self.HOP_HEADERS
|
||||
for name, value in headers:
|
||||
resp.add_header(name, value)
|
||||
|
||||
if resp.has_chunked_hdr:
|
||||
resp.enable_chunked_encoding()
|
||||
|
||||
# send headers immediately for websocket connection
|
||||
if status_code == 101 and resp.upgrade and resp.websocket:
|
||||
resp.send_headers()
|
||||
else:
|
||||
resp._send_headers = True
|
||||
return self.response.write
|
||||
|
|
@ -0,0 +1,76 @@
|
|||
async-timeout
|
||||
=============
|
||||
|
||||
asyncio-compatible timeout context manager.
|
||||
|
||||
|
||||
Usage example
|
||||
-------------
|
||||
|
||||
|
||||
The context manager is useful in cases when you want to apply timeout
|
||||
logic around block of code or in cases when ``asyncio.wait_for()`` is
|
||||
not suitable. Also it's much faster than ``asyncio.wait_for()``
|
||||
because ``timeout`` doesn't create a new task.
|
||||
|
||||
The ``timeout(timeout, *, loop=None)`` call returns a context manager
|
||||
that cancels a block on *timeout* expiring::
|
||||
|
||||
with timeout(1.5):
|
||||
yield from inner()
|
||||
|
||||
1. If ``inner()`` is executed faster than in ``1.5`` seconds nothing
|
||||
happens.
|
||||
2. Otherwise ``inner()`` is cancelled internally by sending
|
||||
``asyncio.CancelledError`` into but ``asyncio.TimeoutError`` is
|
||||
raised outside of context manager scope.
|
||||
|
||||
*timeout* parameter could be ``None`` for skipping timeout functionality.
|
||||
|
||||
Installation
|
||||
------------
|
||||
|
||||
::
|
||||
|
||||
$ pip install async-timeout
|
||||
|
||||
The library is Python 3 only!
|
||||
|
||||
|
||||
|
||||
Authors and License
|
||||
-------------------
|
||||
|
||||
The module is written by Andrew Svetlov.
|
||||
|
||||
It's *Apache 2* licensed and freely available.
|
||||
|
||||
|
||||
CHANGES
|
||||
=======
|
||||
|
||||
1.2.1 (2017-05-02)
|
||||
------------------
|
||||
|
||||
* Support unpublished event loop's "current_task" api.
|
||||
|
||||
|
||||
1.2.0 (2017-03-11)
|
||||
------------------
|
||||
|
||||
* Extra check on context manager exit
|
||||
|
||||
* 0 is no-op timeout
|
||||
|
||||
|
||||
1.1.0 (2016-10-20)
|
||||
------------------
|
||||
|
||||
* Rename to `async-timeout`
|
||||
|
||||
1.0.0 (2016-09-09)
|
||||
------------------
|
||||
|
||||
* The first release.
|
||||
|
||||
|
||||
|
|
@ -0,0 +1 @@
|
|||
pip
|
||||
|
|
@ -0,0 +1,95 @@
|
|||
Metadata-Version: 2.0
|
||||
Name: async-timeout
|
||||
Version: 1.2.1
|
||||
Summary: Timeout context manager for asyncio programs
|
||||
Home-page: https://github.com/aio-libs/async_timeout/
|
||||
Author: Andrew Svetlov
|
||||
Author-email: andrew.svetlov@gmail.com
|
||||
License: Apache 2
|
||||
Platform: UNKNOWN
|
||||
Classifier: License :: OSI Approved :: Apache Software License
|
||||
Classifier: Intended Audience :: Developers
|
||||
Classifier: Programming Language :: Python
|
||||
Classifier: Programming Language :: Python :: 3
|
||||
Classifier: Programming Language :: Python :: 3.4
|
||||
Classifier: Programming Language :: Python :: 3.5
|
||||
Classifier: Programming Language :: Python :: 3.6
|
||||
Classifier: Topic :: Internet :: WWW/HTTP
|
||||
Classifier: Framework :: AsyncIO
|
||||
|
||||
async-timeout
|
||||
=============
|
||||
|
||||
asyncio-compatible timeout context manager.
|
||||
|
||||
|
||||
Usage example
|
||||
-------------
|
||||
|
||||
|
||||
The context manager is useful in cases when you want to apply timeout
|
||||
logic around block of code or in cases when ``asyncio.wait_for()`` is
|
||||
not suitable. Also it's much faster than ``asyncio.wait_for()``
|
||||
because ``timeout`` doesn't create a new task.
|
||||
|
||||
The ``timeout(timeout, *, loop=None)`` call returns a context manager
|
||||
that cancels a block on *timeout* expiring::
|
||||
|
||||
with timeout(1.5):
|
||||
yield from inner()
|
||||
|
||||
1. If ``inner()`` is executed faster than in ``1.5`` seconds nothing
|
||||
happens.
|
||||
2. Otherwise ``inner()`` is cancelled internally by sending
|
||||
``asyncio.CancelledError`` into but ``asyncio.TimeoutError`` is
|
||||
raised outside of context manager scope.
|
||||
|
||||
*timeout* parameter could be ``None`` for skipping timeout functionality.
|
||||
|
||||
Installation
|
||||
------------
|
||||
|
||||
::
|
||||
|
||||
$ pip install async-timeout
|
||||
|
||||
The library is Python 3 only!
|
||||
|
||||
|
||||
|
||||
Authors and License
|
||||
-------------------
|
||||
|
||||
The module is written by Andrew Svetlov.
|
||||
|
||||
It's *Apache 2* licensed and freely available.
|
||||
|
||||
|
||||
CHANGES
|
||||
=======
|
||||
|
||||
1.2.1 (2017-05-02)
|
||||
------------------
|
||||
|
||||
* Support unpublished event loop's "current_task" api.
|
||||
|
||||
|
||||
1.2.0 (2017-03-11)
|
||||
------------------
|
||||
|
||||
* Extra check on context manager exit
|
||||
|
||||
* 0 is no-op timeout
|
||||
|
||||
|
||||
1.1.0 (2016-10-20)
|
||||
------------------
|
||||
|
||||
* Rename to `async-timeout`
|
||||
|
||||
1.0.0 (2016-09-09)
|
||||
------------------
|
||||
|
||||
* The first release.
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
async_timeout/__init__.py,sha256=5ONrYCJMKAzWV1qcK-qhUkRKPGEQfEzCWoxVvwAAru4,1913
|
||||
async_timeout-1.2.1.dist-info/DESCRIPTION.rst,sha256=IQuZGR3YfIcIGhWshP8gce8HXNCvMhxA-ov9oroqnI8,1430
|
||||
async_timeout-1.2.1.dist-info/METADATA,sha256=cOJx0VKD1jtlzkp0JAlBu4nzZ5cRX35I8PCQW4CG5bE,2117
|
||||
async_timeout-1.2.1.dist-info/RECORD,,
|
||||
async_timeout-1.2.1.dist-info/WHEEL,sha256=rNo05PbNqwnXiIHFsYm0m22u4Zm6YJtugFG2THx4w3g,92
|
||||
async_timeout-1.2.1.dist-info/metadata.json,sha256=FwV6Nc2u0faHG1tFFHTGglKZida7muCQA-9_jH0qq5E,889
|
||||
async_timeout-1.2.1.dist-info/top_level.txt,sha256=9oM4e7Twq8iD_7_Q3Mz0E6GPIB6vJvRFo-UBwUQtBDU,14
|
||||
async_timeout-1.2.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
||||
async_timeout/__pycache__/__init__.cpython-36.pyc,,
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
Wheel-Version: 1.0
|
||||
Generator: bdist_wheel (0.29.0)
|
||||
Root-Is-Purelib: true
|
||||
Tag: py3-none-any
|
||||
|
||||
|
|
@ -0,0 +1 @@
|
|||
async_timeout
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
import asyncio
|
||||
|
||||
|
||||
__version__ = '1.2.1'
|
||||
|
||||
|
||||
class timeout:
|
||||
"""timeout context manager.
|
||||
|
||||
Useful in cases when you want to apply timeout logic around block
|
||||
of code or in cases when asyncio.wait_for is not suitable. For example:
|
||||
|
||||
>>> with timeout(0.001):
|
||||
... async with aiohttp.get('https://github.com') as r:
|
||||
... await r.text()
|
||||
|
||||
|
||||
timeout - value in seconds or None to disable timeout logic
|
||||
loop - asyncio compatible event loop
|
||||
"""
|
||||
def __init__(self, timeout, *, loop=None):
|
||||
if timeout is not None and timeout == 0:
|
||||
timeout = None
|
||||
self._timeout = timeout
|
||||
if loop is None:
|
||||
loop = asyncio.get_event_loop()
|
||||
self._loop = loop
|
||||
self._task = None
|
||||
self._cancelled = False
|
||||
self._cancel_handler = None
|
||||
|
||||
def __enter__(self):
|
||||
if self._timeout is not None:
|
||||
self._task = current_task(self._loop)
|
||||
if self._task is None:
|
||||
raise RuntimeError('Timeout context manager should be used '
|
||||
'inside a task')
|
||||
self._cancel_handler = self._loop.call_later(
|
||||
self._timeout, self._cancel_task)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if exc_type is asyncio.CancelledError and self._cancelled:
|
||||
self._cancel_handler = None
|
||||
self._task = None
|
||||
raise asyncio.TimeoutError from None
|
||||
if self._timeout is not None and self._cancel_handler is not None:
|
||||
self._cancel_handler.cancel()
|
||||
self._cancel_handler = None
|
||||
self._task = None
|
||||
|
||||
def _cancel_task(self):
|
||||
self._cancelled = self._task.cancel()
|
||||
|
||||
|
||||
def current_task(loop):
|
||||
task = asyncio.Task.current_task(loop=loop)
|
||||
if task is None:
|
||||
if hasattr(loop, 'current_task'):
|
||||
task = loop.current_task()
|
||||
|
||||
return task
|
||||
|
|
@ -0,0 +1,49 @@
|
|||
Certifi: Python SSL Certificates
|
||||
================================
|
||||
|
||||
`Certifi`_ is a carefully curated collection of Root Certificates for
|
||||
validating the trustworthiness of SSL certificates while verifying the identity
|
||||
of TLS hosts. It has been extracted from the `Requests`_ project.
|
||||
|
||||
Installation
|
||||
------------
|
||||
|
||||
``certifi`` is available on PyPI. Simply install it with ``pip``::
|
||||
|
||||
$ pip install certifi
|
||||
|
||||
Usage
|
||||
-----
|
||||
|
||||
To reference the installed certificate authority (CA) bundle, you can use the
|
||||
built-in function::
|
||||
|
||||
>>> import certifi
|
||||
|
||||
>>> certifi.where()
|
||||
'/usr/local/lib/python2.7/site-packages/certifi/cacert.pem'
|
||||
|
||||
Enjoy!
|
||||
|
||||
1024-bit Root Certificates
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Browsers and certificate authorities have concluded that 1024-bit keys are
|
||||
unacceptably weak for certificates, particularly root certificates. For this
|
||||
reason, Mozilla has removed any weak (i.e. 1024-bit key) certificate from its
|
||||
bundle, replacing it with an equivalent strong (i.e. 2048-bit or greater key)
|
||||
certificate from the same CA. Because Mozilla removed these certificates from
|
||||
its bundle, ``certifi`` removed them as well.
|
||||
|
||||
Unfortunately, old versions of OpenSSL (less than 1.0.2) sometimes fail to
|
||||
validate certificate chains that use the strong roots. For this reason, if you
|
||||
fail to validate a certificate using the ``certifi.where()`` mechanism, you can
|
||||
intentionally re-add the 1024-bit roots back into your bundle by calling
|
||||
``certifi.old_where()`` instead. This is not recommended in production: if at
|
||||
all possible you should upgrade to a newer OpenSSL. However, if you have no
|
||||
other option, this may work for you.
|
||||
|
||||
.. _`Certifi`: http://certifi.io/en/latest/
|
||||
.. _`Requests`: http://docs.python-requests.org/en/latest/
|
||||
|
||||
|
||||
|
|
@ -0,0 +1 @@
|
|||
pip
|
||||
|
|
@ -0,0 +1,68 @@
|
|||
Metadata-Version: 2.0
|
||||
Name: certifi
|
||||
Version: 2017.4.17
|
||||
Summary: Python package for providing Mozilla's CA Bundle.
|
||||
Home-page: http://certifi.io/
|
||||
Author: Kenneth Reitz
|
||||
Author-email: me@kennethreitz.com
|
||||
License: ISC
|
||||
Platform: UNKNOWN
|
||||
Classifier: Development Status :: 5 - Production/Stable
|
||||
Classifier: Intended Audience :: Developers
|
||||
Classifier: Natural Language :: English
|
||||
Classifier: Programming Language :: Python
|
||||
Classifier: Programming Language :: Python :: 2.6
|
||||
Classifier: Programming Language :: Python :: 2.7
|
||||
Classifier: Programming Language :: Python :: 3.3
|
||||
Classifier: Programming Language :: Python :: 3.4
|
||||
Classifier: Programming Language :: Python :: 3.5
|
||||
|
||||
Certifi: Python SSL Certificates
|
||||
================================
|
||||
|
||||
`Certifi`_ is a carefully curated collection of Root Certificates for
|
||||
validating the trustworthiness of SSL certificates while verifying the identity
|
||||
of TLS hosts. It has been extracted from the `Requests`_ project.
|
||||
|
||||
Installation
|
||||
------------
|
||||
|
||||
``certifi`` is available on PyPI. Simply install it with ``pip``::
|
||||
|
||||
$ pip install certifi
|
||||
|
||||
Usage
|
||||
-----
|
||||
|
||||
To reference the installed certificate authority (CA) bundle, you can use the
|
||||
built-in function::
|
||||
|
||||
>>> import certifi
|
||||
|
||||
>>> certifi.where()
|
||||
'/usr/local/lib/python2.7/site-packages/certifi/cacert.pem'
|
||||
|
||||
Enjoy!
|
||||
|
||||
1024-bit Root Certificates
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Browsers and certificate authorities have concluded that 1024-bit keys are
|
||||
unacceptably weak for certificates, particularly root certificates. For this
|
||||
reason, Mozilla has removed any weak (i.e. 1024-bit key) certificate from its
|
||||
bundle, replacing it with an equivalent strong (i.e. 2048-bit or greater key)
|
||||
certificate from the same CA. Because Mozilla removed these certificates from
|
||||
its bundle, ``certifi`` removed them as well.
|
||||
|
||||
Unfortunately, old versions of OpenSSL (less than 1.0.2) sometimes fail to
|
||||
validate certificate chains that use the strong roots. For this reason, if you
|
||||
fail to validate a certificate using the ``certifi.where()`` mechanism, you can
|
||||
intentionally re-add the 1024-bit roots back into your bundle by calling
|
||||
``certifi.old_where()`` instead. This is not recommended in production: if at
|
||||
all possible you should upgrade to a newer OpenSSL. However, if you have no
|
||||
other option, this may work for you.
|
||||
|
||||
.. _`Certifi`: http://certifi.io/en/latest/
|
||||
.. _`Requests`: http://docs.python-requests.org/en/latest/
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
certifi/__init__.py,sha256=fygqpMx6KPrCIRMY4qcO5Zo60MDyQCYtNaI5BbgZyqE,63
|
||||
certifi/__main__.py,sha256=FiOYt1Fltst7wk9DRa6GCoBr8qBUxlNQu_MKJf04E6s,41
|
||||
certifi/cacert.pem,sha256=UgTuBXP5FC1mKK2skamUrJKyL8RVMmtUTKJZpTSFn_U,321422
|
||||
certifi/core.py,sha256=DqvIINYNNXsp3Srlk_NRaiizaww8po3l8t8ksz-Xt6Q,716
|
||||
certifi/old_root.pem,sha256=HT0KIfaM83q0XHFqGEesiGyfmlSWuD2RI0-AVIS2srY,25626
|
||||
certifi/weak.pem,sha256=LGe1E3ewgvNAs_yRA9ZKBN6C5KV2Cx34iJFMPi8_hyo,347048
|
||||
certifi-2017.4.17.dist-info/DESCRIPTION.rst,sha256=wVWYoH3eovdWFPZnYU2NT4itGRx3eN5C_s1IuNm4qF4,1731
|
||||
certifi-2017.4.17.dist-info/METADATA,sha256=ZzDLL1LWFj2SDbZdV_QZ-ZNWd_UDw_NXGbARLwgLYdg,2396
|
||||
certifi-2017.4.17.dist-info/RECORD,,
|
||||
certifi-2017.4.17.dist-info/WHEEL,sha256=5wvfB7GvgZAbKBSE9uX9Zbi6LCL-_KgezgHblXhCRnM,113
|
||||
certifi-2017.4.17.dist-info/metadata.json,sha256=8MYPZlDmqjogcs0bl7CuyQwwrSvqaFjpGMYsjFBQBlw,790
|
||||
certifi-2017.4.17.dist-info/top_level.txt,sha256=KMu4vUCfsjLrkPbSNdgdekS-pVJzBAJFO__nI8NF6-U,8
|
||||
certifi-2017.4.17.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
||||
certifi/__pycache__/core.cpython-36.pyc,,
|
||||
certifi/__pycache__/__init__.cpython-36.pyc,,
|
||||
certifi/__pycache__/__main__.cpython-36.pyc,,
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
Wheel-Version: 1.0
|
||||
Generator: bdist_wheel (0.30.0.a0)
|
||||
Root-Is-Purelib: true
|
||||
Tag: py2-none-any
|
||||
Tag: py3-none-any
|
||||
|
||||
|
|
@ -0,0 +1 @@
|
|||
certifi
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
from .core import where, old_where
|
||||
|
||||
__version__ = "2017.04.17"
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
from certifi import where
|
||||
print(where())
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,36 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
certifi.py
|
||||
~~~~~~~~~~
|
||||
|
||||
This module returns the installation location of cacert.pem.
|
||||
"""
|
||||
import os
|
||||
import warnings
|
||||
|
||||
|
||||
class DeprecatedBundleWarning(DeprecationWarning):
|
||||
"""
|
||||
The weak security bundle is being deprecated. Please bother your service
|
||||
provider to get them to stop using cross-signed roots.
|
||||
"""
|
||||
|
||||
|
||||
def where():
|
||||
f = os.path.split(__file__)[0]
|
||||
|
||||
return os.path.join(f, 'cacert.pem')
|
||||
|
||||
|
||||
def old_where():
|
||||
warnings.warn(
|
||||
"The weak security bundle is being deprecated.",
|
||||
DeprecatedBundleWarning
|
||||
)
|
||||
f = os.path.split(__file__)[0]
|
||||
return os.path.join(f, 'weak.pem')
|
||||
|
||||
if __name__ == '__main__':
|
||||
print(where())
|
||||
|
|
@ -0,0 +1,414 @@
|
|||
# Issuer: CN=Entrust.net Secure Server Certification Authority O=Entrust.net OU=www.entrust.net/CPS incorp. by ref. (limits liab.)/(c) 1999 Entrust.net Limited
|
||||
# Subject: CN=Entrust.net Secure Server Certification Authority O=Entrust.net OU=www.entrust.net/CPS incorp. by ref. (limits liab.)/(c) 1999 Entrust.net Limited
|
||||
# Label: "Entrust.net Secure Server CA"
|
||||
# Serial: 927650371
|
||||
# MD5 Fingerprint: df:f2:80:73:cc:f1:e6:61:73:fc:f5:42:e9:c5:7c:ee
|
||||
# SHA1 Fingerprint: 99:a6:9b:e6:1a:fe:88:6b:4d:2b:82:00:7c:b8:54:fc:31:7e:15:39
|
||||
# SHA256 Fingerprint: 62:f2:40:27:8c:56:4c:4d:d8:bf:7d:9d:4f:6f:36:6e:a8:94:d2:2f:5f:34:d9:89:a9:83:ac:ec:2f:ff:ed:50
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIE2DCCBEGgAwIBAgIEN0rSQzANBgkqhkiG9w0BAQUFADCBwzELMAkGA1UEBhMC
|
||||
VVMxFDASBgNVBAoTC0VudHJ1c3QubmV0MTswOQYDVQQLEzJ3d3cuZW50cnVzdC5u
|
||||
ZXQvQ1BTIGluY29ycC4gYnkgcmVmLiAobGltaXRzIGxpYWIuKTElMCMGA1UECxMc
|
||||
KGMpIDE5OTkgRW50cnVzdC5uZXQgTGltaXRlZDE6MDgGA1UEAxMxRW50cnVzdC5u
|
||||
ZXQgU2VjdXJlIFNlcnZlciBDZXJ0aWZpY2F0aW9uIEF1dGhvcml0eTAeFw05OTA1
|
||||
MjUxNjA5NDBaFw0xOTA1MjUxNjM5NDBaMIHDMQswCQYDVQQGEwJVUzEUMBIGA1UE
|
||||
ChMLRW50cnVzdC5uZXQxOzA5BgNVBAsTMnd3dy5lbnRydXN0Lm5ldC9DUFMgaW5j
|
||||
b3JwLiBieSByZWYuIChsaW1pdHMgbGlhYi4pMSUwIwYDVQQLExwoYykgMTk5OSBF
|
||||
bnRydXN0Lm5ldCBMaW1pdGVkMTowOAYDVQQDEzFFbnRydXN0Lm5ldCBTZWN1cmUg
|
||||
U2VydmVyIENlcnRpZmljYXRpb24gQXV0aG9yaXR5MIGdMA0GCSqGSIb3DQEBAQUA
|
||||
A4GLADCBhwKBgQDNKIM0VBuJ8w+vN5Ex/68xYMmo6LIQaO2f55M28Qpku0f1BBc/
|
||||
I0dNxScZgSYMVHINiC3ZH5oSn7yzcdOAGT9HZnuMNSjSuQrfJNqc1lB5gXpa0zf3
|
||||
wkrYKZImZNHkmGw6AIr1NJtl+O3jEP/9uElY3KDegjlrgbEWGWG5VLbmQwIBA6OC
|
||||
AdcwggHTMBEGCWCGSAGG+EIBAQQEAwIABzCCARkGA1UdHwSCARAwggEMMIHeoIHb
|
||||
oIHYpIHVMIHSMQswCQYDVQQGEwJVUzEUMBIGA1UEChMLRW50cnVzdC5uZXQxOzA5
|
||||
BgNVBAsTMnd3dy5lbnRydXN0Lm5ldC9DUFMgaW5jb3JwLiBieSByZWYuIChsaW1p
|
||||
dHMgbGlhYi4pMSUwIwYDVQQLExwoYykgMTk5OSBFbnRydXN0Lm5ldCBMaW1pdGVk
|
||||
MTowOAYDVQQDEzFFbnRydXN0Lm5ldCBTZWN1cmUgU2VydmVyIENlcnRpZmljYXRp
|
||||
b24gQXV0aG9yaXR5MQ0wCwYDVQQDEwRDUkwxMCmgJ6AlhiNodHRwOi8vd3d3LmVu
|
||||
dHJ1c3QubmV0L0NSTC9uZXQxLmNybDArBgNVHRAEJDAigA8xOTk5MDUyNTE2MDk0
|
||||
MFqBDzIwMTkwNTI1MTYwOTQwWjALBgNVHQ8EBAMCAQYwHwYDVR0jBBgwFoAU8Bdi
|
||||
E1U9s/8KAGv7UISX8+1i0BowHQYDVR0OBBYEFPAXYhNVPbP/CgBr+1CEl/PtYtAa
|
||||
MAwGA1UdEwQFMAMBAf8wGQYJKoZIhvZ9B0EABAwwChsEVjQuMAMCBJAwDQYJKoZI
|
||||
hvcNAQEFBQADgYEAkNwwAvpkdMKnCqV8IY00F6j7Rw7/JXyNEwr75Ji174z4xRAN
|
||||
95K+8cPV1ZVqBLssziY2ZcgxxufuP+NXdYR6Ee9GTxj005i7qIcyunL2POI9n9cd
|
||||
2cNgQ4xYDiKWL2KjLB+6rQXvqzJ4h6BUcxm1XAX5Uj5tLUUL9wqT6u0G+bI=
|
||||
-----END CERTIFICATE-----
|
||||
|
||||
# Issuer: CN=http://www.valicert.com/ O=ValiCert, Inc. OU=ValiCert Class 2 Policy Validation Authority
|
||||
# Subject: CN=http://www.valicert.com/ O=ValiCert, Inc. OU=ValiCert Class 2 Policy Validation Authority
|
||||
# Label: "ValiCert Class 2 VA"
|
||||
# Serial: 1
|
||||
# MD5 Fingerprint: a9:23:75:9b:ba:49:36:6e:31:c2:db:f2:e7:66:ba:87
|
||||
# SHA1 Fingerprint: 31:7a:2a:d0:7f:2b:33:5e:f5:a1:c3:4e:4b:57:e8:b7:d8:f1:fc:a6
|
||||
# SHA256 Fingerprint: 58:d0:17:27:9c:d4:dc:63:ab:dd:b1:96:a6:c9:90:6c:30:c4:e0:87:83:ea:e8:c1:60:99:54:d6:93:55:59:6b
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIC5zCCAlACAQEwDQYJKoZIhvcNAQEFBQAwgbsxJDAiBgNVBAcTG1ZhbGlDZXJ0
|
||||
IFZhbGlkYXRpb24gTmV0d29yazEXMBUGA1UEChMOVmFsaUNlcnQsIEluYy4xNTAz
|
||||
BgNVBAsTLFZhbGlDZXJ0IENsYXNzIDIgUG9saWN5IFZhbGlkYXRpb24gQXV0aG9y
|
||||
aXR5MSEwHwYDVQQDExhodHRwOi8vd3d3LnZhbGljZXJ0LmNvbS8xIDAeBgkqhkiG
|
||||
9w0BCQEWEWluZm9AdmFsaWNlcnQuY29tMB4XDTk5MDYyNjAwMTk1NFoXDTE5MDYy
|
||||
NjAwMTk1NFowgbsxJDAiBgNVBAcTG1ZhbGlDZXJ0IFZhbGlkYXRpb24gTmV0d29y
|
||||
azEXMBUGA1UEChMOVmFsaUNlcnQsIEluYy4xNTAzBgNVBAsTLFZhbGlDZXJ0IENs
|
||||
YXNzIDIgUG9saWN5IFZhbGlkYXRpb24gQXV0aG9yaXR5MSEwHwYDVQQDExhodHRw
|
||||
Oi8vd3d3LnZhbGljZXJ0LmNvbS8xIDAeBgkqhkiG9w0BCQEWEWluZm9AdmFsaWNl
|
||||
cnQuY29tMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDOOnHK5avIWZJV16vY
|
||||
dA757tn2VUdZZUcOBVXc65g2PFxTXdMwzzjsvUGJ7SVCCSRrCl6zfN1SLUzm1NZ9
|
||||
WlmpZdRJEy0kTRxQb7XBhVQ7/nHk01xC+YDgkRoKWzk2Z/M/VXwbP7RfZHM047QS
|
||||
v4dk+NoS/zcnwbNDu+97bi5p9wIDAQABMA0GCSqGSIb3DQEBBQUAA4GBADt/UG9v
|
||||
UJSZSWI4OB9L+KXIPqeCgfYrx+jFzug6EILLGACOTb2oWH+heQC1u+mNr0HZDzTu
|
||||
IYEZoDJJKPTEjlbVUjP9UNV+mWwD5MlM/Mtsq2azSiGM5bUMMj4QssxsodyamEwC
|
||||
W/POuZ6lcg5Ktz885hZo+L7tdEy8W9ViH0Pd
|
||||
-----END CERTIFICATE-----
|
||||
|
||||
# Issuer: CN=NetLock Expressz (Class C) Tanusitvanykiado O=NetLock Halozatbiztonsagi Kft. OU=Tanusitvanykiadok
|
||||
# Subject: CN=NetLock Expressz (Class C) Tanusitvanykiado O=NetLock Halozatbiztonsagi Kft. OU=Tanusitvanykiadok
|
||||
# Label: "NetLock Express (Class C) Root"
|
||||
# Serial: 104
|
||||
# MD5 Fingerprint: 4f:eb:f1:f0:70:c2:80:63:5d:58:9f:da:12:3c:a9:c4
|
||||
# SHA1 Fingerprint: e3:92:51:2f:0a:cf:f5:05:df:f6:de:06:7f:75:37:e1:65:ea:57:4b
|
||||
# SHA256 Fingerprint: 0b:5e:ed:4e:84:64:03:cf:55:e0:65:84:84:40:ed:2a:82:75:8b:f5:b9:aa:1f:25:3d:46:13:cf:a0:80:ff:3f
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIFTzCCBLigAwIBAgIBaDANBgkqhkiG9w0BAQQFADCBmzELMAkGA1UEBhMCSFUx
|
||||
ETAPBgNVBAcTCEJ1ZGFwZXN0MScwJQYDVQQKEx5OZXRMb2NrIEhhbG96YXRiaXp0
|
||||
b25zYWdpIEtmdC4xGjAYBgNVBAsTEVRhbnVzaXR2YW55a2lhZG9rMTQwMgYDVQQD
|
||||
EytOZXRMb2NrIEV4cHJlc3N6IChDbGFzcyBDKSBUYW51c2l0dmFueWtpYWRvMB4X
|
||||
DTk5MDIyNTE0MDgxMVoXDTE5MDIyMDE0MDgxMVowgZsxCzAJBgNVBAYTAkhVMREw
|
||||
DwYDVQQHEwhCdWRhcGVzdDEnMCUGA1UEChMeTmV0TG9jayBIYWxvemF0Yml6dG9u
|
||||
c2FnaSBLZnQuMRowGAYDVQQLExFUYW51c2l0dmFueWtpYWRvazE0MDIGA1UEAxMr
|
||||
TmV0TG9jayBFeHByZXNzeiAoQ2xhc3MgQykgVGFudXNpdHZhbnlraWFkbzCBnzAN
|
||||
BgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEA6+ywbGGKIyWvYCDj2Z/8kwvbXY2wobNA
|
||||
OoLO/XXgeDIDhlqGlZHtU/qdQPzm6N3ZW3oDvV3zOwzDUXmbrVWg6dADEK8KuhRC
|
||||
2VImESLH0iDMgqSaqf64gXadarfSNnU+sYYJ9m5tfk63euyucYT2BDMIJTLrdKwW
|
||||
RMbkQJMdf60CAwEAAaOCAp8wggKbMBIGA1UdEwEB/wQIMAYBAf8CAQQwDgYDVR0P
|
||||
AQH/BAQDAgAGMBEGCWCGSAGG+EIBAQQEAwIABzCCAmAGCWCGSAGG+EIBDQSCAlEW
|
||||
ggJNRklHWUVMRU0hIEV6ZW4gdGFudXNpdHZhbnkgYSBOZXRMb2NrIEtmdC4gQWx0
|
||||
YWxhbm9zIFN6b2xnYWx0YXRhc2kgRmVsdGV0ZWxlaWJlbiBsZWlydCBlbGphcmFz
|
||||
b2sgYWxhcGphbiBrZXN6dWx0LiBBIGhpdGVsZXNpdGVzIGZvbHlhbWF0YXQgYSBO
|
||||
ZXRMb2NrIEtmdC4gdGVybWVrZmVsZWxvc3NlZy1iaXp0b3NpdGFzYSB2ZWRpLiBB
|
||||
IGRpZ2l0YWxpcyBhbGFpcmFzIGVsZm9nYWRhc2FuYWsgZmVsdGV0ZWxlIGF6IGVs
|
||||
b2lydCBlbGxlbm9yemVzaSBlbGphcmFzIG1lZ3RldGVsZS4gQXogZWxqYXJhcyBs
|
||||
ZWlyYXNhIG1lZ3RhbGFsaGF0byBhIE5ldExvY2sgS2Z0LiBJbnRlcm5ldCBob25s
|
||||
YXBqYW4gYSBodHRwczovL3d3dy5uZXRsb2NrLm5ldC9kb2NzIGNpbWVuIHZhZ3kg
|
||||
a2VyaGV0byBheiBlbGxlbm9yemVzQG5ldGxvY2submV0IGUtbWFpbCBjaW1lbi4g
|
||||
SU1QT1JUQU5UISBUaGUgaXNzdWFuY2UgYW5kIHRoZSB1c2Ugb2YgdGhpcyBjZXJ0
|
||||
aWZpY2F0ZSBpcyBzdWJqZWN0IHRvIHRoZSBOZXRMb2NrIENQUyBhdmFpbGFibGUg
|
||||
YXQgaHR0cHM6Ly93d3cubmV0bG9jay5uZXQvZG9jcyBvciBieSBlLW1haWwgYXQg
|
||||
Y3BzQG5ldGxvY2submV0LjANBgkqhkiG9w0BAQQFAAOBgQAQrX/XDDKACtiG8XmY
|
||||
ta3UzbM2xJZIwVzNmtkFLp++UOv0JhQQLdRmF/iewSf98e3ke0ugbLWrmldwpu2g
|
||||
pO0u9f38vf5NNwgMvOOWgyL1SRt/Syu0VMGAfJlOHdCM7tCs5ZL6dVb+ZKATj7i4
|
||||
Fp1hBWeAyNDYpQcCNJgEjTME1A==
|
||||
-----END CERTIFICATE-----
|
||||
|
||||
# Issuer: CN=NetLock Uzleti (Class B) Tanusitvanykiado O=NetLock Halozatbiztonsagi Kft. OU=Tanusitvanykiadok
|
||||
# Subject: CN=NetLock Uzleti (Class B) Tanusitvanykiado O=NetLock Halozatbiztonsagi Kft. OU=Tanusitvanykiadok
|
||||
# Label: "NetLock Business (Class B) Root"
|
||||
# Serial: 105
|
||||
# MD5 Fingerprint: 39:16:aa:b9:6a:41:e1:14:69:df:9e:6c:3b:72:dc:b6
|
||||
# SHA1 Fingerprint: 87:9f:4b:ee:05:df:98:58:3b:e3:60:d6:33:e7:0d:3f:fe:98:71:af
|
||||
# SHA256 Fingerprint: 39:df:7b:68:2b:7b:93:8f:84:71:54:81:cc:de:8d:60:d8:f2:2e:c5:98:87:7d:0a:aa:c1:2b:59:18:2b:03:12
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIFSzCCBLSgAwIBAgIBaTANBgkqhkiG9w0BAQQFADCBmTELMAkGA1UEBhMCSFUx
|
||||
ETAPBgNVBAcTCEJ1ZGFwZXN0MScwJQYDVQQKEx5OZXRMb2NrIEhhbG96YXRiaXp0
|
||||
b25zYWdpIEtmdC4xGjAYBgNVBAsTEVRhbnVzaXR2YW55a2lhZG9rMTIwMAYDVQQD
|
||||
EylOZXRMb2NrIFV6bGV0aSAoQ2xhc3MgQikgVGFudXNpdHZhbnlraWFkbzAeFw05
|
||||
OTAyMjUxNDEwMjJaFw0xOTAyMjAxNDEwMjJaMIGZMQswCQYDVQQGEwJIVTERMA8G
|
||||
A1UEBxMIQnVkYXBlc3QxJzAlBgNVBAoTHk5ldExvY2sgSGFsb3phdGJpenRvbnNh
|
||||
Z2kgS2Z0LjEaMBgGA1UECxMRVGFudXNpdHZhbnlraWFkb2sxMjAwBgNVBAMTKU5l
|
||||
dExvY2sgVXpsZXRpIChDbGFzcyBCKSBUYW51c2l0dmFueWtpYWRvMIGfMA0GCSqG
|
||||
SIb3DQEBAQUAA4GNADCBiQKBgQCx6gTsIKAjwo84YM/HRrPVG/77uZmeBNwcf4xK
|
||||
gZjupNTKihe5In+DCnVMm8Bp2GQ5o+2So/1bXHQawEfKOml2mrriRBf8TKPV/riX
|
||||
iK+IA4kfpPIEPsgHC+b5sy96YhQJRhTKZPWLgLViqNhr1nGTLbO/CVRY7QbrqHvc
|
||||
Q7GhaQIDAQABo4ICnzCCApswEgYDVR0TAQH/BAgwBgEB/wIBBDAOBgNVHQ8BAf8E
|
||||
BAMCAAYwEQYJYIZIAYb4QgEBBAQDAgAHMIICYAYJYIZIAYb4QgENBIICURaCAk1G
|
||||
SUdZRUxFTSEgRXplbiB0YW51c2l0dmFueSBhIE5ldExvY2sgS2Z0LiBBbHRhbGFu
|
||||
b3MgU3pvbGdhbHRhdGFzaSBGZWx0ZXRlbGVpYmVuIGxlaXJ0IGVsamFyYXNvayBh
|
||||
bGFwamFuIGtlc3p1bHQuIEEgaGl0ZWxlc2l0ZXMgZm9seWFtYXRhdCBhIE5ldExv
|
||||
Y2sgS2Z0LiB0ZXJtZWtmZWxlbG9zc2VnLWJpenRvc2l0YXNhIHZlZGkuIEEgZGln
|
||||
aXRhbGlzIGFsYWlyYXMgZWxmb2dhZGFzYW5hayBmZWx0ZXRlbGUgYXogZWxvaXJ0
|
||||
IGVsbGVub3J6ZXNpIGVsamFyYXMgbWVndGV0ZWxlLiBBeiBlbGphcmFzIGxlaXJh
|
||||
c2EgbWVndGFsYWxoYXRvIGEgTmV0TG9jayBLZnQuIEludGVybmV0IGhvbmxhcGph
|
||||
biBhIGh0dHBzOi8vd3d3Lm5ldGxvY2submV0L2RvY3MgY2ltZW4gdmFneSBrZXJo
|
||||
ZXRvIGF6IGVsbGVub3J6ZXNAbmV0bG9jay5uZXQgZS1tYWlsIGNpbWVuLiBJTVBP
|
||||
UlRBTlQhIFRoZSBpc3N1YW5jZSBhbmQgdGhlIHVzZSBvZiB0aGlzIGNlcnRpZmlj
|
||||
YXRlIGlzIHN1YmplY3QgdG8gdGhlIE5ldExvY2sgQ1BTIGF2YWlsYWJsZSBhdCBo
|
||||
dHRwczovL3d3dy5uZXRsb2NrLm5ldC9kb2NzIG9yIGJ5IGUtbWFpbCBhdCBjcHNA
|
||||
bmV0bG9jay5uZXQuMA0GCSqGSIb3DQEBBAUAA4GBAATbrowXr/gOkDFOzT4JwG06
|
||||
sPgzTEdM43WIEJessDgVkcYplswhwG08pXTP2IKlOcNl40JwuyKQ433bNXbhoLXa
|
||||
n3BukxowOR0w2y7jfLKRstE3Kfq51hdcR0/jHTjrn9V7lagonhVK0dHQKwCXoOKS
|
||||
NitjrFgBazMpUIaD8QFI
|
||||
-----END CERTIFICATE-----
|
||||
|
||||
# Issuer: CN=http://www.valicert.com/ O=ValiCert, Inc. OU=ValiCert Class 3 Policy Validation Authority
|
||||
# Subject: CN=http://www.valicert.com/ O=ValiCert, Inc. OU=ValiCert Class 3 Policy Validation Authority
|
||||
# Label: "RSA Root Certificate 1"
|
||||
# Serial: 1
|
||||
# MD5 Fingerprint: a2:6f:53:b7:ee:40:db:4a:68:e7:fa:18:d9:10:4b:72
|
||||
# SHA1 Fingerprint: 69:bd:8c:f4:9c:d3:00:fb:59:2e:17:93:ca:55:6a:f3:ec:aa:35:fb
|
||||
# SHA256 Fingerprint: bc:23:f9:8a:31:3c:b9:2d:e3:bb:fc:3a:5a:9f:44:61:ac:39:49:4c:4a:e1:5a:9e:9d:f1:31:e9:9b:73:01:9a
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIC5zCCAlACAQEwDQYJKoZIhvcNAQEFBQAwgbsxJDAiBgNVBAcTG1ZhbGlDZXJ0
|
||||
IFZhbGlkYXRpb24gTmV0d29yazEXMBUGA1UEChMOVmFsaUNlcnQsIEluYy4xNTAz
|
||||
BgNVBAsTLFZhbGlDZXJ0IENsYXNzIDMgUG9saWN5IFZhbGlkYXRpb24gQXV0aG9y
|
||||
aXR5MSEwHwYDVQQDExhodHRwOi8vd3d3LnZhbGljZXJ0LmNvbS8xIDAeBgkqhkiG
|
||||
9w0BCQEWEWluZm9AdmFsaWNlcnQuY29tMB4XDTk5MDYyNjAwMjIzM1oXDTE5MDYy
|
||||
NjAwMjIzM1owgbsxJDAiBgNVBAcTG1ZhbGlDZXJ0IFZhbGlkYXRpb24gTmV0d29y
|
||||
azEXMBUGA1UEChMOVmFsaUNlcnQsIEluYy4xNTAzBgNVBAsTLFZhbGlDZXJ0IENs
|
||||
YXNzIDMgUG9saWN5IFZhbGlkYXRpb24gQXV0aG9yaXR5MSEwHwYDVQQDExhodHRw
|
||||
Oi8vd3d3LnZhbGljZXJ0LmNvbS8xIDAeBgkqhkiG9w0BCQEWEWluZm9AdmFsaWNl
|
||||
cnQuY29tMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDjmFGWHOjVsQaBalfD
|
||||
cnWTq8+epvzzFlLWLU2fNUSoLgRNB0mKOCn1dzfnt6td3zZxFJmP3MKS8edgkpfs
|
||||
2Ejcv8ECIMYkpChMMFp2bbFc893enhBxoYjHW5tBbcqwuI4V7q0zK89HBFx1cQqY
|
||||
JJgpp0lZpd34t0NiYfPT4tBVPwIDAQABMA0GCSqGSIb3DQEBBQUAA4GBAFa7AliE
|
||||
Zwgs3x/be0kz9dNnnfS0ChCzycUs4pJqcXgn8nCDQtM+z6lU9PHYkhaM0QTLS6vJ
|
||||
n0WuPIqpsHEzXcjFV9+vqDWzf4mH6eglkrh/hXqu1rweN1gqZ8mRzyqBPu3GOd/A
|
||||
PhmcGcwTTYJBtYze4D1gCCAPRX5ron+jjBXu
|
||||
-----END CERTIFICATE-----
|
||||
|
||||
# Issuer: CN=http://www.valicert.com/ O=ValiCert, Inc. OU=ValiCert Class 1 Policy Validation Authority
|
||||
# Subject: CN=http://www.valicert.com/ O=ValiCert, Inc. OU=ValiCert Class 1 Policy Validation Authority
|
||||
# Label: "ValiCert Class 1 VA"
|
||||
# Serial: 1
|
||||
# MD5 Fingerprint: 65:58:ab:15:ad:57:6c:1e:a8:a7:b5:69:ac:bf:ff:eb
|
||||
# SHA1 Fingerprint: e5:df:74:3c:b6:01:c4:9b:98:43:dc:ab:8c:e8:6a:81:10:9f:e4:8e
|
||||
# SHA256 Fingerprint: f4:c1:49:55:1a:30:13:a3:5b:c7:bf:fe:17:a7:f3:44:9b:c1:ab:5b:5a:0a:e7:4b:06:c2:3b:90:00:4c:01:04
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIC5zCCAlACAQEwDQYJKoZIhvcNAQEFBQAwgbsxJDAiBgNVBAcTG1ZhbGlDZXJ0
|
||||
IFZhbGlkYXRpb24gTmV0d29yazEXMBUGA1UEChMOVmFsaUNlcnQsIEluYy4xNTAz
|
||||
BgNVBAsTLFZhbGlDZXJ0IENsYXNzIDEgUG9saWN5IFZhbGlkYXRpb24gQXV0aG9y
|
||||
aXR5MSEwHwYDVQQDExhodHRwOi8vd3d3LnZhbGljZXJ0LmNvbS8xIDAeBgkqhkiG
|
||||
9w0BCQEWEWluZm9AdmFsaWNlcnQuY29tMB4XDTk5MDYyNTIyMjM0OFoXDTE5MDYy
|
||||
NTIyMjM0OFowgbsxJDAiBgNVBAcTG1ZhbGlDZXJ0IFZhbGlkYXRpb24gTmV0d29y
|
||||
azEXMBUGA1UEChMOVmFsaUNlcnQsIEluYy4xNTAzBgNVBAsTLFZhbGlDZXJ0IENs
|
||||
YXNzIDEgUG9saWN5IFZhbGlkYXRpb24gQXV0aG9yaXR5MSEwHwYDVQQDExhodHRw
|
||||
Oi8vd3d3LnZhbGljZXJ0LmNvbS8xIDAeBgkqhkiG9w0BCQEWEWluZm9AdmFsaWNl
|
||||
cnQuY29tMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDYWYJ6ibiWuqYvaG9Y
|
||||
LqdUHAZu9OqNSLwxlBfw8068srg1knaw0KWlAdcAAxIiGQj4/xEjm84H9b9pGib+
|
||||
TunRf50sQB1ZaG6m+FiwnRqP0z/x3BkGgagO4DrdyFNFCQbmD3DD+kCmDuJWBQ8Y
|
||||
TfwggtFzVXSNdnKgHZ0dwN0/cQIDAQABMA0GCSqGSIb3DQEBBQUAA4GBAFBoPUn0
|
||||
LBwGlN+VYH+Wexf+T3GtZMjdd9LvWVXoP+iOBSoh8gfStadS/pyxtuJbdxdA6nLW
|
||||
I8sogTLDAHkY7FkXicnGah5xyf23dKUlRWnFSKsZ4UWKJWsZ7uW7EvV/96aNUcPw
|
||||
nXS3qT6gpf+2SQMT2iLM7XGCK5nPOrf1LXLI
|
||||
-----END CERTIFICATE-----
|
||||
|
||||
# Issuer: CN=Equifax Secure eBusiness CA-1 O=Equifax Secure Inc.
|
||||
# Subject: CN=Equifax Secure eBusiness CA-1 O=Equifax Secure Inc.
|
||||
# Label: "Equifax Secure eBusiness CA 1"
|
||||
# Serial: 4
|
||||
# MD5 Fingerprint: 64:9c:ef:2e:44:fc:c6:8f:52:07:d0:51:73:8f:cb:3d
|
||||
# SHA1 Fingerprint: da:40:18:8b:91:89:a3:ed:ee:ae:da:97:fe:2f:9d:f5:b7:d1:8a:41
|
||||
# SHA256 Fingerprint: cf:56:ff:46:a4:a1:86:10:9d:d9:65:84:b5:ee:b5:8a:51:0c:42:75:b0:e5:f9:4f:40:bb:ae:86:5e:19:f6:73
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIICgjCCAeugAwIBAgIBBDANBgkqhkiG9w0BAQQFADBTMQswCQYDVQQGEwJVUzEc
|
||||
MBoGA1UEChMTRXF1aWZheCBTZWN1cmUgSW5jLjEmMCQGA1UEAxMdRXF1aWZheCBT
|
||||
ZWN1cmUgZUJ1c2luZXNzIENBLTEwHhcNOTkwNjIxMDQwMDAwWhcNMjAwNjIxMDQw
|
||||
MDAwWjBTMQswCQYDVQQGEwJVUzEcMBoGA1UEChMTRXF1aWZheCBTZWN1cmUgSW5j
|
||||
LjEmMCQGA1UEAxMdRXF1aWZheCBTZWN1cmUgZUJ1c2luZXNzIENBLTEwgZ8wDQYJ
|
||||
KoZIhvcNAQEBBQADgY0AMIGJAoGBAM4vGbwXt3fek6lfWg0XTzQaDJj0ItlZ1MRo
|
||||
RvC0NcWFAyDGr0WlIVFFQesWWDYyb+JQYmT5/VGcqiTZ9J2DKocKIdMSODRsjQBu
|
||||
WqDZQu4aIZX5UkxVWsUPOE9G+m34LjXWHXzr4vCwdYDIqROsvojvOm6rXyo4YgKw
|
||||
Env+j6YDAgMBAAGjZjBkMBEGCWCGSAGG+EIBAQQEAwIABzAPBgNVHRMBAf8EBTAD
|
||||
AQH/MB8GA1UdIwQYMBaAFEp4MlIR21kWNl7fwRQ2QGpHfEyhMB0GA1UdDgQWBBRK
|
||||
eDJSEdtZFjZe38EUNkBqR3xMoTANBgkqhkiG9w0BAQQFAAOBgQB1W6ibAxHm6VZM
|
||||
zfmpTMANmvPMZWnmJXbMWbfWVMMdzZmsGd20hdXgPfxiIKeES1hl8eL5lSE/9dR+
|
||||
WB5Hh1Q+WKG1tfgq73HnvMP2sUlG4tega+VWeponmHxGYhTnyfxuAxJ5gDgdSIKN
|
||||
/Bf+KpYrtWKmpj29f5JZzVoqgrI3eQ==
|
||||
-----END CERTIFICATE-----
|
||||
|
||||
# Issuer: CN=Equifax Secure Global eBusiness CA-1 O=Equifax Secure Inc.
|
||||
# Subject: CN=Equifax Secure Global eBusiness CA-1 O=Equifax Secure Inc.
|
||||
# Label: "Equifax Secure Global eBusiness CA"
|
||||
# Serial: 1
|
||||
# MD5 Fingerprint: 8f:5d:77:06:27:c4:98:3c:5b:93:78:e7:d7:7d:9b:cc
|
||||
# SHA1 Fingerprint: 7e:78:4a:10:1c:82:65:cc:2d:e1:f1:6d:47:b4:40:ca:d9:0a:19:45
|
||||
# SHA256 Fingerprint: 5f:0b:62:ea:b5:e3:53:ea:65:21:65:16:58:fb:b6:53:59:f4:43:28:0a:4a:fb:d1:04:d7:7d:10:f9:f0:4c:07
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIICkDCCAfmgAwIBAgIBATANBgkqhkiG9w0BAQQFADBaMQswCQYDVQQGEwJVUzEc
|
||||
MBoGA1UEChMTRXF1aWZheCBTZWN1cmUgSW5jLjEtMCsGA1UEAxMkRXF1aWZheCBT
|
||||
ZWN1cmUgR2xvYmFsIGVCdXNpbmVzcyBDQS0xMB4XDTk5MDYyMTA0MDAwMFoXDTIw
|
||||
MDYyMTA0MDAwMFowWjELMAkGA1UEBhMCVVMxHDAaBgNVBAoTE0VxdWlmYXggU2Vj
|
||||
dXJlIEluYy4xLTArBgNVBAMTJEVxdWlmYXggU2VjdXJlIEdsb2JhbCBlQnVzaW5l
|
||||
c3MgQ0EtMTCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEAuucXkAJlsTRVPEnC
|
||||
UdXfp9E3j9HngXNBUmCbnaEXJnitx7HoJpQytd4zjTov2/KaelpzmKNc6fuKcxtc
|
||||
58O/gGzNqfTWK8D3+ZmqY6KxRwIP1ORROhI8bIpaVIRw28HFkM9yRcuoWcDNM50/
|
||||
o5brhTMhHD4ePmBudpxnhcXIw2ECAwEAAaNmMGQwEQYJYIZIAYb4QgEBBAQDAgAH
|
||||
MA8GA1UdEwEB/wQFMAMBAf8wHwYDVR0jBBgwFoAUvqigdHJQa0S3ySPY+6j/s1dr
|
||||
aGwwHQYDVR0OBBYEFL6ooHRyUGtEt8kj2Puo/7NXa2hsMA0GCSqGSIb3DQEBBAUA
|
||||
A4GBADDiAVGqx+pf2rnQZQ8w1j7aDRRJbpGTJxQx78T3LUX47Me/okENI7SS+RkA
|
||||
Z70Br83gcfxaz2TE4JaY0KNA4gGK7ycH8WUBikQtBmV1UsCGECAhX2xrD2yuCRyv
|
||||
8qIYNMR1pHMc8Y3c7635s3a0kr/clRAevsvIO1qEYBlWlKlV
|
||||
-----END CERTIFICATE-----
|
||||
|
||||
# Issuer: CN=Thawte Premium Server CA O=Thawte Consulting cc OU=Certification Services Division
|
||||
# Subject: CN=Thawte Premium Server CA O=Thawte Consulting cc OU=Certification Services Division
|
||||
# Label: "Thawte Premium Server CA"
|
||||
# Serial: 1
|
||||
# MD5 Fingerprint: 06:9f:69:79:16:66:90:02:1b:8c:8c:a2:c3:07:6f:3a
|
||||
# SHA1 Fingerprint: 62:7f:8d:78:27:65:63:99:d2:7d:7f:90:44:c9:fe:b3:f3:3e:fa:9a
|
||||
# SHA256 Fingerprint: ab:70:36:36:5c:71:54:aa:29:c2:c2:9f:5d:41:91:16:3b:16:2a:22:25:01:13:57:d5:6d:07:ff:a7:bc:1f:72
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIDJzCCApCgAwIBAgIBATANBgkqhkiG9w0BAQQFADCBzjELMAkGA1UEBhMCWkEx
|
||||
FTATBgNVBAgTDFdlc3Rlcm4gQ2FwZTESMBAGA1UEBxMJQ2FwZSBUb3duMR0wGwYD
|
||||
VQQKExRUaGF3dGUgQ29uc3VsdGluZyBjYzEoMCYGA1UECxMfQ2VydGlmaWNhdGlv
|
||||
biBTZXJ2aWNlcyBEaXZpc2lvbjEhMB8GA1UEAxMYVGhhd3RlIFByZW1pdW0gU2Vy
|
||||
dmVyIENBMSgwJgYJKoZIhvcNAQkBFhlwcmVtaXVtLXNlcnZlckB0aGF3dGUuY29t
|
||||
MB4XDTk2MDgwMTAwMDAwMFoXDTIwMTIzMTIzNTk1OVowgc4xCzAJBgNVBAYTAlpB
|
||||
MRUwEwYDVQQIEwxXZXN0ZXJuIENhcGUxEjAQBgNVBAcTCUNhcGUgVG93bjEdMBsG
|
||||
A1UEChMUVGhhd3RlIENvbnN1bHRpbmcgY2MxKDAmBgNVBAsTH0NlcnRpZmljYXRp
|
||||
b24gU2VydmljZXMgRGl2aXNpb24xITAfBgNVBAMTGFRoYXd0ZSBQcmVtaXVtIFNl
|
||||
cnZlciBDQTEoMCYGCSqGSIb3DQEJARYZcHJlbWl1bS1zZXJ2ZXJAdGhhd3RlLmNv
|
||||
bTCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEA0jY2aovXwlue2oFBYo847kkE
|
||||
VdbQ7xwblRZH7xhINTpS9CtqBo87L+pW46+GjZ4X9560ZXUCTe/LCaIhUdib0GfQ
|
||||
ug2SBhRz1JPLlyoAnFxODLz6FVL88kRu2hFKbgifLy3j+ao6hnO2RlNYyIkFvYMR
|
||||
uHM/qgeN9EJN50CdHDcCAwEAAaMTMBEwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG
|
||||
9w0BAQQFAAOBgQAmSCwWwlj66BZ0DKqqX1Q/8tfJeGBeXm43YyJ3Nn6yF8Q0ufUI
|
||||
hfzJATj/Tb7yFkJD57taRvvBxhEf8UqwKEbJw8RCfbz6q1lu1bdRiBHjpIUZa4JM
|
||||
pAwSremkrj/xw0llmozFyD4lt5SZu5IycQfwhl7tUCemDaYj+bvLpgcUQg==
|
||||
-----END CERTIFICATE-----
|
||||
|
||||
# Issuer: CN=Thawte Server CA O=Thawte Consulting cc OU=Certification Services Division
|
||||
# Subject: CN=Thawte Server CA O=Thawte Consulting cc OU=Certification Services Division
|
||||
# Label: "Thawte Server CA"
|
||||
# Serial: 1
|
||||
# MD5 Fingerprint: c5:70:c4:a2:ed:53:78:0c:c8:10:53:81:64:cb:d0:1d
|
||||
# SHA1 Fingerprint: 23:e5:94:94:51:95:f2:41:48:03:b4:d5:64:d2:a3:a3:f5:d8:8b:8c
|
||||
# SHA256 Fingerprint: b4:41:0b:73:e2:e6:ea:ca:47:fb:c4:2f:8f:a4:01:8a:f4:38:1d:c5:4c:fa:a8:44:50:46:1e:ed:09:45:4d:e9
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIDEzCCAnygAwIBAgIBATANBgkqhkiG9w0BAQQFADCBxDELMAkGA1UEBhMCWkEx
|
||||
FTATBgNVBAgTDFdlc3Rlcm4gQ2FwZTESMBAGA1UEBxMJQ2FwZSBUb3duMR0wGwYD
|
||||
VQQKExRUaGF3dGUgQ29uc3VsdGluZyBjYzEoMCYGA1UECxMfQ2VydGlmaWNhdGlv
|
||||
biBTZXJ2aWNlcyBEaXZpc2lvbjEZMBcGA1UEAxMQVGhhd3RlIFNlcnZlciBDQTEm
|
||||
MCQGCSqGSIb3DQEJARYXc2VydmVyLWNlcnRzQHRoYXd0ZS5jb20wHhcNOTYwODAx
|
||||
MDAwMDAwWhcNMjAxMjMxMjM1OTU5WjCBxDELMAkGA1UEBhMCWkExFTATBgNVBAgT
|
||||
DFdlc3Rlcm4gQ2FwZTESMBAGA1UEBxMJQ2FwZSBUb3duMR0wGwYDVQQKExRUaGF3
|
||||
dGUgQ29uc3VsdGluZyBjYzEoMCYGA1UECxMfQ2VydGlmaWNhdGlvbiBTZXJ2aWNl
|
||||
cyBEaXZpc2lvbjEZMBcGA1UEAxMQVGhhd3RlIFNlcnZlciBDQTEmMCQGCSqGSIb3
|
||||
DQEJARYXc2VydmVyLWNlcnRzQHRoYXd0ZS5jb20wgZ8wDQYJKoZIhvcNAQEBBQAD
|
||||
gY0AMIGJAoGBANOkUG7I/1Zr5s9dtuoMaHVHoqrC2oQl/Kj0R1HahbUgdJSGHg91
|
||||
yekIYfUGbTBuFRkC6VLAYttNmZ7iagxEOM3+vuNkCXDF/rFrKbYvScg71CcEJRCX
|
||||
L+eQbcAoQpnXTEPew/UhbVSfXcNY4cDk2VuwuNy0e982OsK1ZiIS1ocNAgMBAAGj
|
||||
EzARMA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQEEBQADgYEAB/pMaVz7lcxG
|
||||
7oWDTSEwjsrZqG9JGubaUeNgcGyEYRGhGshIPllDfU+VPaGLtwtimHp1it2ITk6e
|
||||
QNuozDJ0uW8NxuOzRAvZim+aKZuZGCg70eNAKJpaPNW15yAbi8qkq43pUdniTCxZ
|
||||
qdq5snUb9kLy78fyGPmJvKP/iiMucEc=
|
||||
-----END CERTIFICATE-----
|
||||
|
||||
# Issuer: O=VeriSign, Inc. OU=Class 3 Public Primary Certification Authority
|
||||
# Subject: O=VeriSign, Inc. OU=Class 3 Public Primary Certification Authority
|
||||
# Label: "Verisign Class 3 Public Primary Certification Authority"
|
||||
# Serial: 149843929435818692848040365716851702463
|
||||
# MD5 Fingerprint: 10:fc:63:5d:f6:26:3e:0d:f3:25:be:5f:79:cd:67:67
|
||||
# SHA1 Fingerprint: 74:2c:31:92:e6:07:e4:24:eb:45:49:54:2b:e1:bb:c5:3e:61:74:e2
|
||||
# SHA256 Fingerprint: e7:68:56:34:ef:ac:f6:9a:ce:93:9a:6b:25:5b:7b:4f:ab:ef:42:93:5b:50:a2:65:ac:b5:cb:60:27:e4:4e:70
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIICPDCCAaUCEHC65B0Q2Sk0tjjKewPMur8wDQYJKoZIhvcNAQECBQAwXzELMAkG
|
||||
A1UEBhMCVVMxFzAVBgNVBAoTDlZlcmlTaWduLCBJbmMuMTcwNQYDVQQLEy5DbGFz
|
||||
cyAzIFB1YmxpYyBQcmltYXJ5IENlcnRpZmljYXRpb24gQXV0aG9yaXR5MB4XDTk2
|
||||
MDEyOTAwMDAwMFoXDTI4MDgwMTIzNTk1OVowXzELMAkGA1UEBhMCVVMxFzAVBgNV
|
||||
BAoTDlZlcmlTaWduLCBJbmMuMTcwNQYDVQQLEy5DbGFzcyAzIFB1YmxpYyBQcmlt
|
||||
YXJ5IENlcnRpZmljYXRpb24gQXV0aG9yaXR5MIGfMA0GCSqGSIb3DQEBAQUAA4GN
|
||||
ADCBiQKBgQDJXFme8huKARS0EN8EQNvjV69qRUCPhAwL0TPZ2RHP7gJYHyX3KqhE
|
||||
BarsAx94f56TuZoAqiN91qyFomNFx3InzPRMxnVx0jnvT0Lwdd8KkMaOIG+YD/is
|
||||
I19wKTakyYbnsZogy1Olhec9vn2a/iRFM9x2Fe0PonFkTGUugWhFpwIDAQABMA0G
|
||||
CSqGSIb3DQEBAgUAA4GBALtMEivPLCYATxQT3ab7/AoRhIzzKBxnki98tsX63/Do
|
||||
lbwdj2wsqFHMc9ikwFPwTtYmwHYBV4GSXiHx0bH/59AhWM1pF+NEHJwZRDmJXNyc
|
||||
AA9WjQKZ7aKQRUzkuxCkPfAyAw7xzvjoyVGM5mKf5p/AfbdynMk2OmufTqj/ZA1k
|
||||
-----END CERTIFICATE-----
|
||||
|
||||
# Issuer: O=VeriSign, Inc. OU=Class 3 Public Primary Certification Authority
|
||||
# Subject: O=VeriSign, Inc. OU=Class 3 Public Primary Certification Authority
|
||||
# Label: "Verisign Class 3 Public Primary Certification Authority"
|
||||
# Serial: 80507572722862485515306429940691309246
|
||||
# MD5 Fingerprint: ef:5a:f1:33:ef:f1:cd:bb:51:02:ee:12:14:4b:96:c4
|
||||
# SHA1 Fingerprint: a1:db:63:93:91:6f:17:e4:18:55:09:40:04:15:c7:02:40:b0:ae:6b
|
||||
# SHA256 Fingerprint: a4:b6:b3:99:6f:c2:f3:06:b3:fd:86:81:bd:63:41:3d:8c:50:09:cc:4f:a3:29:c2:cc:f0:e2:fa:1b:14:03:05
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIICPDCCAaUCEDyRMcsf9tAbDpq40ES/Er4wDQYJKoZIhvcNAQEFBQAwXzELMAkG
|
||||
A1UEBhMCVVMxFzAVBgNVBAoTDlZlcmlTaWduLCBJbmMuMTcwNQYDVQQLEy5DbGFz
|
||||
cyAzIFB1YmxpYyBQcmltYXJ5IENlcnRpZmljYXRpb24gQXV0aG9yaXR5MB4XDTk2
|
||||
MDEyOTAwMDAwMFoXDTI4MDgwMjIzNTk1OVowXzELMAkGA1UEBhMCVVMxFzAVBgNV
|
||||
BAoTDlZlcmlTaWduLCBJbmMuMTcwNQYDVQQLEy5DbGFzcyAzIFB1YmxpYyBQcmlt
|
||||
YXJ5IENlcnRpZmljYXRpb24gQXV0aG9yaXR5MIGfMA0GCSqGSIb3DQEBAQUAA4GN
|
||||
ADCBiQKBgQDJXFme8huKARS0EN8EQNvjV69qRUCPhAwL0TPZ2RHP7gJYHyX3KqhE
|
||||
BarsAx94f56TuZoAqiN91qyFomNFx3InzPRMxnVx0jnvT0Lwdd8KkMaOIG+YD/is
|
||||
I19wKTakyYbnsZogy1Olhec9vn2a/iRFM9x2Fe0PonFkTGUugWhFpwIDAQABMA0G
|
||||
CSqGSIb3DQEBBQUAA4GBABByUqkFFBkyCEHwxWsKzH4PIRnN5GfcX6kb5sroc50i
|
||||
2JhucwNhkcV8sEVAbkSdjbCxlnRhLQ2pRdKkkirWmnWXbj9T/UWZYB2oK0z5XqcJ
|
||||
2HUw19JlYD1n1khVdWk/kfVIC0dpImmClr7JyDiGSnoscxlIaU5rfGW/D/xwzoiQ
|
||||
-----END CERTIFICATE-----
|
||||
|
||||
# Issuer: O=VeriSign, Inc. OU=Class 3 Public Primary Certification Authority - G2/(c) 1998 VeriSign, Inc. - For authorized use only/VeriSign Trust Network
|
||||
# Subject: O=VeriSign, Inc. OU=Class 3 Public Primary Certification Authority - G2/(c) 1998 VeriSign, Inc. - For authorized use only/VeriSign Trust Network
|
||||
# Label: "Verisign Class 3 Public Primary Certification Authority - G2"
|
||||
# Serial: 167285380242319648451154478808036881606
|
||||
# MD5 Fingerprint: a2:33:9b:4c:74:78:73:d4:6c:e7:c1:f3:8d:cb:5c:e9
|
||||
# SHA1 Fingerprint: 85:37:1c:a6:e5:50:14:3d:ce:28:03:47:1b:de:3a:09:e8:f8:77:0f
|
||||
# SHA256 Fingerprint: 83:ce:3c:12:29:68:8a:59:3d:48:5f:81:97:3c:0f:91:95:43:1e:da:37:cc:5e:36:43:0e:79:c7:a8:88:63:8b
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIDAjCCAmsCEH3Z/gfPqB63EHln+6eJNMYwDQYJKoZIhvcNAQEFBQAwgcExCzAJ
|
||||
BgNVBAYTAlVTMRcwFQYDVQQKEw5WZXJpU2lnbiwgSW5jLjE8MDoGA1UECxMzQ2xh
|
||||
c3MgMyBQdWJsaWMgUHJpbWFyeSBDZXJ0aWZpY2F0aW9uIEF1dGhvcml0eSAtIEcy
|
||||
MTowOAYDVQQLEzEoYykgMTk5OCBWZXJpU2lnbiwgSW5jLiAtIEZvciBhdXRob3Jp
|
||||
emVkIHVzZSBvbmx5MR8wHQYDVQQLExZWZXJpU2lnbiBUcnVzdCBOZXR3b3JrMB4X
|
||||
DTk4MDUxODAwMDAwMFoXDTI4MDgwMTIzNTk1OVowgcExCzAJBgNVBAYTAlVTMRcw
|
||||
FQYDVQQKEw5WZXJpU2lnbiwgSW5jLjE8MDoGA1UECxMzQ2xhc3MgMyBQdWJsaWMg
|
||||
UHJpbWFyeSBDZXJ0aWZpY2F0aW9uIEF1dGhvcml0eSAtIEcyMTowOAYDVQQLEzEo
|
||||
YykgMTk5OCBWZXJpU2lnbiwgSW5jLiAtIEZvciBhdXRob3JpemVkIHVzZSBvbmx5
|
||||
MR8wHQYDVQQLExZWZXJpU2lnbiBUcnVzdCBOZXR3b3JrMIGfMA0GCSqGSIb3DQEB
|
||||
AQUAA4GNADCBiQKBgQDMXtERXVxp0KvTuWpMmR9ZmDCOFoUgRm1HP9SFIIThbbP4
|
||||
pO0M8RcPO/mn+SXXwc+EY/J8Y8+iR/LGWzOOZEAEaMGAuWQcRXfH2G71lSk8UOg0
|
||||
13gfqLptQ5GVj0VXXn7F+8qkBOvqlzdUMG+7AUcyM83cV5tkaWH4mx0ciU9cZwID
|
||||
AQABMA0GCSqGSIb3DQEBBQUAA4GBAFFNzb5cy5gZnBWyATl4Lk0PZ3BwmcYQWpSk
|
||||
U01UbSuvDV1Ai2TT1+7eVmGSX6bEHRBhNtMsJzzoKQm5EWR0zLVznxxIqbxhAe7i
|
||||
F6YM40AIOw7n60RzKprxaZLvcRTDOaxxp5EJb+RxBrO6WVcmeQD2+A2iMzAo1KpY
|
||||
oJ2daZH9
|
||||
-----END CERTIFICATE-----
|
||||
|
||||
# Issuer: CN=GTE CyberTrust Global Root O=GTE Corporation OU=GTE CyberTrust Solutions, Inc.
|
||||
# Subject: CN=GTE CyberTrust Global Root O=GTE Corporation OU=GTE CyberTrust Solutions, Inc.
|
||||
# Label: "GTE CyberTrust Global Root"
|
||||
# Serial: 421
|
||||
# MD5 Fingerprint: ca:3d:d3:68:f1:03:5c:d0:32:fa:b8:2b:59:e8:5a:db
|
||||
# SHA1 Fingerprint: 97:81:79:50:d8:1c:96:70:cc:34:d8:09:cf:79:44:31:36:7e:f4:74
|
||||
# SHA256 Fingerprint: a5:31:25:18:8d:21:10:aa:96:4b:02:c7:b7:c6:da:32:03:17:08:94:e5:fb:71:ff:fb:66:67:d5:e6:81:0a:36
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIICWjCCAcMCAgGlMA0GCSqGSIb3DQEBBAUAMHUxCzAJBgNVBAYTAlVTMRgwFgYD
|
||||
VQQKEw9HVEUgQ29ycG9yYXRpb24xJzAlBgNVBAsTHkdURSBDeWJlclRydXN0IFNv
|
||||
bHV0aW9ucywgSW5jLjEjMCEGA1UEAxMaR1RFIEN5YmVyVHJ1c3QgR2xvYmFsIFJv
|
||||
b3QwHhcNOTgwODEzMDAyOTAwWhcNMTgwODEzMjM1OTAwWjB1MQswCQYDVQQGEwJV
|
||||
UzEYMBYGA1UEChMPR1RFIENvcnBvcmF0aW9uMScwJQYDVQQLEx5HVEUgQ3liZXJU
|
||||
cnVzdCBTb2x1dGlvbnMsIEluYy4xIzAhBgNVBAMTGkdURSBDeWJlclRydXN0IEds
|
||||
b2JhbCBSb290MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCVD6C28FCc6HrH
|
||||
iM3dFw4usJTQGz0O9pTAipTHBsiQl8i4ZBp6fmw8U+E3KHNgf7KXUwefU/ltWJTS
|
||||
r41tiGeA5u2ylc9yMcqlHHK6XALnZELn+aks1joNrI1CqiQBOeacPwGFVw1Yh0X4
|
||||
04Wqk2kmhXBIgD8SFcd5tB8FLztimQIDAQABMA0GCSqGSIb3DQEBBAUAA4GBAG3r
|
||||
GwnpXtlR22ciYaQqPEh346B8pt5zohQDhT37qw4wxYMWM4ETCJ57NE7fQMh017l9
|
||||
3PR2VX2bY1QY6fDq81yx2YtCHrnAlU66+tXifPVoYb+O7AWXX1uw16OFNMQkpw0P
|
||||
lZPvy5TYnh+dXIVtx6quTx8itc2VrbqnzPmrC3p/
|
||||
-----END CERTIFICATE-----
|
||||
|
||||
# Issuer: C=US, O=Equifax, OU=Equifax Secure Certificate Authority
|
||||
# Subject: C=US, O=Equifax, OU=Equifax Secure Certificate Authority
|
||||
# Label: "Equifax Secure Certificate Authority"
|
||||
# Serial: 903804111
|
||||
# MD5 Fingerprint: 67:cb:9d:c0:13:24:8a:82:9b:b2:17:1e:d1:1b:ec:d4
|
||||
# SHA1 Fingerprint: d2:32:09:ad:23:d3:14:23:21:74:e4:0d:7f:9d:62:13:97:86:63:3a
|
||||
# SHA256 Fingerprint: 08:29:7a:40:47:db:a2:36:80:c7:31:db:6e:31:76:53:ca:78:48:e1:be:bd:3a:0b:01:79:a7:07:f9:2c:f1:78
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIDIDCCAomgAwIBAgIENd70zzANBgkqhkiG9w0BAQUFADBOMQswCQYDVQQGEwJV
|
||||
UzEQMA4GA1UEChMHRXF1aWZheDEtMCsGA1UECxMkRXF1aWZheCBTZWN1cmUgQ2Vy
|
||||
dGlmaWNhdGUgQXV0aG9yaXR5MB4XDTk4MDgyMjE2NDE1MVoXDTE4MDgyMjE2NDE1
|
||||
MVowTjELMAkGA1UEBhMCVVMxEDAOBgNVBAoTB0VxdWlmYXgxLTArBgNVBAsTJEVx
|
||||
dWlmYXggU2VjdXJlIENlcnRpZmljYXRlIEF1dGhvcml0eTCBnzANBgkqhkiG9w0B
|
||||
AQEFAAOBjQAwgYkCgYEAwV2xWGcIYu6gmi0fCG2RFGiYCh7+2gRvE4RiIcPRfM6f
|
||||
BeC4AfBONOziipUEZKzxa1NfBbPLZ4C/QgKO/t0BCezhABRP/PvwDN1Dulsr4R+A
|
||||
cJkVV5MW8Q+XarfCaCMczE1ZMKxRHjuvK9buY0V7xdlfUNLjUA86iOe/FP3gx7kC
|
||||
AwEAAaOCAQkwggEFMHAGA1UdHwRpMGcwZaBjoGGkXzBdMQswCQYDVQQGEwJVUzEQ
|
||||
MA4GA1UEChMHRXF1aWZheDEtMCsGA1UECxMkRXF1aWZheCBTZWN1cmUgQ2VydGlm
|
||||
aWNhdGUgQXV0aG9yaXR5MQ0wCwYDVQQDEwRDUkwxMBoGA1UdEAQTMBGBDzIwMTgw
|
||||
ODIyMTY0MTUxWjALBgNVHQ8EBAMCAQYwHwYDVR0jBBgwFoAUSOZo+SvSspXXR9gj
|
||||
IBBPM5iQn9QwHQYDVR0OBBYEFEjmaPkr0rKV10fYIyAQTzOYkJ/UMAwGA1UdEwQF
|
||||
MAMBAf8wGgYJKoZIhvZ9B0EABA0wCxsFVjMuMGMDAgbAMA0GCSqGSIb3DQEBBQUA
|
||||
A4GBAFjOKer89961zgK5F7WF0bnj4JXMJTENAKaSbn+2kmOeUJXRmm/kEd5jhW6Y
|
||||
7qj/WsjTVbJmcVfewCHrPSqnI0kBBIZCe/zuf6IWUrVnZ9NA2zsmWLIodz2uFHdh
|
||||
1voqZiegDfqnc1zqcPGUIWVEX/r87yloqaKHee9570+sB3c4
|
||||
-----END CERTIFICATE-----
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,13 @@
|
|||
|
||||
CFFI
|
||||
====
|
||||
|
||||
Foreign Function Interface for Python calling C code.
|
||||
Please see the `Documentation <http://cffi.readthedocs.org/>`_.
|
||||
|
||||
Contact
|
||||
-------
|
||||
|
||||
`Mailing list <https://groups.google.com/forum/#!forum/python-cffi>`_
|
||||
|
||||
|
||||
|
|
@ -0,0 +1 @@
|
|||
pip
|
||||
|
|
@ -0,0 +1,35 @@
|
|||
Metadata-Version: 2.0
|
||||
Name: cffi
|
||||
Version: 1.10.0
|
||||
Summary: Foreign Function Interface for Python calling C code.
|
||||
Home-page: http://cffi.readthedocs.org
|
||||
Author: Armin Rigo, Maciej Fijalkowski
|
||||
Author-email: python-cffi@googlegroups.com
|
||||
License: MIT
|
||||
Platform: UNKNOWN
|
||||
Classifier: Programming Language :: Python
|
||||
Classifier: Programming Language :: Python :: 2
|
||||
Classifier: Programming Language :: Python :: 2.6
|
||||
Classifier: Programming Language :: Python :: 2.7
|
||||
Classifier: Programming Language :: Python :: 3
|
||||
Classifier: Programming Language :: Python :: 3.2
|
||||
Classifier: Programming Language :: Python :: 3.3
|
||||
Classifier: Programming Language :: Python :: 3.4
|
||||
Classifier: Programming Language :: Python :: 3.5
|
||||
Classifier: Programming Language :: Python :: Implementation :: CPython
|
||||
Classifier: Programming Language :: Python :: Implementation :: PyPy
|
||||
Requires-Dist: pycparser
|
||||
|
||||
|
||||
CFFI
|
||||
====
|
||||
|
||||
Foreign Function Interface for Python calling C code.
|
||||
Please see the `Documentation <http://cffi.readthedocs.org/>`_.
|
||||
|
||||
Contact
|
||||
-------
|
||||
|
||||
`Mailing list <https://groups.google.com/forum/#!forum/python-cffi>`_
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
_cffi_backend.cp36-win32.pyd,sha256=Go5Po3dC2iO2JU1oFvQuisiCbeIH8GyVYoi2_fJGg-Y,128512
|
||||
cffi/__init__.py,sha256=QcAAIPcVY5LsX041WHzW-GYObh691LT5L5FKo0yVwDI,479
|
||||
cffi/_cffi_include.h,sha256=8SuGAPe_N8n4Uv2B8aQOqp5JFG0un02SKGvy1e9tihQ,10238
|
||||
cffi/_embedding.h,sha256=c_xb0Dw0k7gq9OP6x10yrLBy1offSq5KxJphiU4p3dw,17275
|
||||
cffi/api.py,sha256=FHlxuRrwmZbheKY2HY3l1ScFHNuWUBjGu80YPhleICs,39647
|
||||
cffi/backend_ctypes.py,sha256=CcGNp1XCa7QHBihowgHu5BojpeQZ32s7tpEjFrtFtME,42078
|
||||
cffi/cffi_opcode.py,sha256=hbX3E-hmvcwzfqOhqSfH3ObpmRMNOHf_VZ--32flIEo,5459
|
||||
cffi/commontypes.py,sha256=QS4uxCDI7JhtTyjh1hlnCA-gynmaszWxJaRRLGkJa1A,2689
|
||||
cffi/cparser.py,sha256=AX4kk4BejnA8erNHzTEyeJWbX_He82MBmpIsZlkmdl8,38507
|
||||
cffi/error.py,sha256=yNDakwm_HPJ_T53ivbB7hEts2N-oBGjMLw_25aNi2E8,536
|
||||
cffi/ffiplatform.py,sha256=g-6tBT6R2aXkIDaAmni92fXI4rwiCegHU8AyD_wL3wo,3645
|
||||
cffi/lock.py,sha256=l9TTdwMIMpi6jDkJGnQgE9cvTIR7CAntIJr8EGHt3pY,747
|
||||
cffi/model.py,sha256=1daM9AYkCFmSMUrgbbe9KIK1jXsJPYZn_A828S-Qbv8,21103
|
||||
cffi/parse_c_type.h,sha256=BBca7ODJCzlsK_1c4HT5MFhy1wdUbyHYvKuvCvxaQZ8,5835
|
||||
cffi/recompiler.py,sha256=43UEvyl2mtigxa0laEWHLDtv2T8OTZEMj2NXComOoJU,61597
|
||||
cffi/setuptools_ext.py,sha256=07n99TzG6QAsFDhf5-cE10pN2FIzKQ-DVFV5YbnN6eA,7463
|
||||
cffi/vengine_cpy.py,sha256=Kw_Z38hrBJPUod5R517dRwAvRC7SjQ0qx8fEX1ZaFAM,41325
|
||||
cffi/vengine_gen.py,sha256=dLmNdH0MmI_Jog3PlYuCG5RIdx8_xP_RuOjRa84q_N8,26597
|
||||
cffi/verifier.py,sha256=Vk8v9fePaHkMmDc-wftJ4gowPErbacfC3soTw_rpT8U,11519
|
||||
cffi-1.10.0.dist-info/DESCRIPTION.rst,sha256=9ijQLbcqTWNF-iV0RznFiBeBCNrjArA0P-eutKUPw98,220
|
||||
cffi-1.10.0.dist-info/METADATA,sha256=c9-fyjmuNh52W-A4SeBTlgr35GZEsWW2Tskw_3nHCWM,1090
|
||||
cffi-1.10.0.dist-info/RECORD,,
|
||||
cffi-1.10.0.dist-info/WHEEL,sha256=xiHTm3JxoVljPSD6nSGhq3B4VY9iUqMNXwYQ259n1PI,102
|
||||
cffi-1.10.0.dist-info/entry_points.txt,sha256=Q9f5C9IpjYxo0d2PK9eUcnkgxHc9pHWwjEMaANPKNCI,76
|
||||
cffi-1.10.0.dist-info/metadata.json,sha256=fBsfmNhS5_P6IGaL1mdGMDj8o0NZPotHZeGIB3FprRI,1112
|
||||
cffi-1.10.0.dist-info/top_level.txt,sha256=rE7WR3rZfNKxWI9-jn6hsHCAl7MDkB-FmuQbxWjFehQ,19
|
||||
cffi-1.10.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
||||
cffi/__pycache__/api.cpython-36.pyc,,
|
||||
cffi/__pycache__/backend_ctypes.cpython-36.pyc,,
|
||||
cffi/__pycache__/cffi_opcode.cpython-36.pyc,,
|
||||
cffi/__pycache__/commontypes.cpython-36.pyc,,
|
||||
cffi/__pycache__/cparser.cpython-36.pyc,,
|
||||
cffi/__pycache__/error.cpython-36.pyc,,
|
||||
cffi/__pycache__/ffiplatform.cpython-36.pyc,,
|
||||
cffi/__pycache__/lock.cpython-36.pyc,,
|
||||
cffi/__pycache__/model.cpython-36.pyc,,
|
||||
cffi/__pycache__/recompiler.cpython-36.pyc,,
|
||||
cffi/__pycache__/setuptools_ext.cpython-36.pyc,,
|
||||
cffi/__pycache__/vengine_cpy.cpython-36.pyc,,
|
||||
cffi/__pycache__/vengine_gen.cpython-36.pyc,,
|
||||
cffi/__pycache__/verifier.cpython-36.pyc,,
|
||||
cffi/__pycache__/__init__.cpython-36.pyc,,
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
Wheel-Version: 1.0
|
||||
Generator: bdist_wheel (0.29.0)
|
||||
Root-Is-Purelib: false
|
||||
Tag: cp36-cp36m-win32
|
||||
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
[distutils.setup_keywords]
|
||||
cffi_modules = cffi.setuptools_ext:cffi_modules
|
||||
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
_cffi_backend
|
||||
cffi
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
__all__ = ['FFI', 'VerificationError', 'VerificationMissing', 'CDefError',
|
||||
'FFIError']
|
||||
|
||||
from .api import FFI
|
||||
from .error import CDefError, FFIError, VerificationError, VerificationMissing
|
||||
|
||||
__version__ = "1.10.0"
|
||||
__version_info__ = (1, 10, 0)
|
||||
|
||||
# The verifier module file names are based on the CRC32 of a string that
|
||||
# contains the following version number. It may be older than __version__
|
||||
# if nothing is clearly incompatible.
|
||||
__version_verifier_modules__ = "0.8.6"
|
||||
|
|
@ -0,0 +1,253 @@
|
|||
#define _CFFI_
|
||||
|
||||
/* We try to define Py_LIMITED_API before including Python.h.
|
||||
|
||||
Mess: we can only define it if Py_DEBUG, Py_TRACE_REFS and
|
||||
Py_REF_DEBUG are not defined. This is a best-effort approximation:
|
||||
we can learn about Py_DEBUG from pyconfig.h, but it is unclear if
|
||||
the same works for the other two macros. Py_DEBUG implies them,
|
||||
but not the other way around.
|
||||
*/
|
||||
#ifndef _CFFI_USE_EMBEDDING
|
||||
# include <pyconfig.h>
|
||||
# if !defined(Py_DEBUG) && !defined(Py_TRACE_REFS) && !defined(Py_REF_DEBUG)
|
||||
# define Py_LIMITED_API
|
||||
# endif
|
||||
#endif
|
||||
|
||||
#include <Python.h>
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
#include <stddef.h>
|
||||
#include "parse_c_type.h"
|
||||
|
||||
/* this block of #ifs should be kept exactly identical between
|
||||
c/_cffi_backend.c, cffi/vengine_cpy.py, cffi/vengine_gen.py
|
||||
and cffi/_cffi_include.h */
|
||||
#if defined(_MSC_VER)
|
||||
# include <malloc.h> /* for alloca() */
|
||||
# if _MSC_VER < 1600 /* MSVC < 2010 */
|
||||
typedef __int8 int8_t;
|
||||
typedef __int16 int16_t;
|
||||
typedef __int32 int32_t;
|
||||
typedef __int64 int64_t;
|
||||
typedef unsigned __int8 uint8_t;
|
||||
typedef unsigned __int16 uint16_t;
|
||||
typedef unsigned __int32 uint32_t;
|
||||
typedef unsigned __int64 uint64_t;
|
||||
typedef __int8 int_least8_t;
|
||||
typedef __int16 int_least16_t;
|
||||
typedef __int32 int_least32_t;
|
||||
typedef __int64 int_least64_t;
|
||||
typedef unsigned __int8 uint_least8_t;
|
||||
typedef unsigned __int16 uint_least16_t;
|
||||
typedef unsigned __int32 uint_least32_t;
|
||||
typedef unsigned __int64 uint_least64_t;
|
||||
typedef __int8 int_fast8_t;
|
||||
typedef __int16 int_fast16_t;
|
||||
typedef __int32 int_fast32_t;
|
||||
typedef __int64 int_fast64_t;
|
||||
typedef unsigned __int8 uint_fast8_t;
|
||||
typedef unsigned __int16 uint_fast16_t;
|
||||
typedef unsigned __int32 uint_fast32_t;
|
||||
typedef unsigned __int64 uint_fast64_t;
|
||||
typedef __int64 intmax_t;
|
||||
typedef unsigned __int64 uintmax_t;
|
||||
# else
|
||||
# include <stdint.h>
|
||||
# endif
|
||||
# if _MSC_VER < 1800 /* MSVC < 2013 */
|
||||
# ifndef __cplusplus
|
||||
typedef unsigned char _Bool;
|
||||
# endif
|
||||
# endif
|
||||
#else
|
||||
# include <stdint.h>
|
||||
# if (defined (__SVR4) && defined (__sun)) || defined(_AIX) || defined(__hpux)
|
||||
# include <alloca.h>
|
||||
# endif
|
||||
#endif
|
||||
|
||||
#ifdef __GNUC__
|
||||
# define _CFFI_UNUSED_FN __attribute__((unused))
|
||||
#else
|
||||
# define _CFFI_UNUSED_FN /* nothing */
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
# ifndef _Bool
|
||||
typedef bool _Bool; /* semi-hackish: C++ has no _Bool; bool is builtin */
|
||||
# endif
|
||||
#endif
|
||||
|
||||
/********** CPython-specific section **********/
|
||||
#ifndef PYPY_VERSION
|
||||
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
# define PyInt_FromLong PyLong_FromLong
|
||||
#endif
|
||||
|
||||
#define _cffi_from_c_double PyFloat_FromDouble
|
||||
#define _cffi_from_c_float PyFloat_FromDouble
|
||||
#define _cffi_from_c_long PyInt_FromLong
|
||||
#define _cffi_from_c_ulong PyLong_FromUnsignedLong
|
||||
#define _cffi_from_c_longlong PyLong_FromLongLong
|
||||
#define _cffi_from_c_ulonglong PyLong_FromUnsignedLongLong
|
||||
|
||||
#define _cffi_to_c_double PyFloat_AsDouble
|
||||
#define _cffi_to_c_float PyFloat_AsDouble
|
||||
|
||||
#define _cffi_from_c_int(x, type) \
|
||||
(((type)-1) > 0 ? /* unsigned */ \
|
||||
(sizeof(type) < sizeof(long) ? \
|
||||
PyInt_FromLong((long)x) : \
|
||||
sizeof(type) == sizeof(long) ? \
|
||||
PyLong_FromUnsignedLong((unsigned long)x) : \
|
||||
PyLong_FromUnsignedLongLong((unsigned long long)x)) : \
|
||||
(sizeof(type) <= sizeof(long) ? \
|
||||
PyInt_FromLong((long)x) : \
|
||||
PyLong_FromLongLong((long long)x)))
|
||||
|
||||
#define _cffi_to_c_int(o, type) \
|
||||
((type)( \
|
||||
sizeof(type) == 1 ? (((type)-1) > 0 ? (type)_cffi_to_c_u8(o) \
|
||||
: (type)_cffi_to_c_i8(o)) : \
|
||||
sizeof(type) == 2 ? (((type)-1) > 0 ? (type)_cffi_to_c_u16(o) \
|
||||
: (type)_cffi_to_c_i16(o)) : \
|
||||
sizeof(type) == 4 ? (((type)-1) > 0 ? (type)_cffi_to_c_u32(o) \
|
||||
: (type)_cffi_to_c_i32(o)) : \
|
||||
sizeof(type) == 8 ? (((type)-1) > 0 ? (type)_cffi_to_c_u64(o) \
|
||||
: (type)_cffi_to_c_i64(o)) : \
|
||||
(Py_FatalError("unsupported size for type " #type), (type)0)))
|
||||
|
||||
#define _cffi_to_c_i8 \
|
||||
((int(*)(PyObject *))_cffi_exports[1])
|
||||
#define _cffi_to_c_u8 \
|
||||
((int(*)(PyObject *))_cffi_exports[2])
|
||||
#define _cffi_to_c_i16 \
|
||||
((int(*)(PyObject *))_cffi_exports[3])
|
||||
#define _cffi_to_c_u16 \
|
||||
((int(*)(PyObject *))_cffi_exports[4])
|
||||
#define _cffi_to_c_i32 \
|
||||
((int(*)(PyObject *))_cffi_exports[5])
|
||||
#define _cffi_to_c_u32 \
|
||||
((unsigned int(*)(PyObject *))_cffi_exports[6])
|
||||
#define _cffi_to_c_i64 \
|
||||
((long long(*)(PyObject *))_cffi_exports[7])
|
||||
#define _cffi_to_c_u64 \
|
||||
((unsigned long long(*)(PyObject *))_cffi_exports[8])
|
||||
#define _cffi_to_c_char \
|
||||
((int(*)(PyObject *))_cffi_exports[9])
|
||||
#define _cffi_from_c_pointer \
|
||||
((PyObject *(*)(char *, struct _cffi_ctypedescr *))_cffi_exports[10])
|
||||
#define _cffi_to_c_pointer \
|
||||
((char *(*)(PyObject *, struct _cffi_ctypedescr *))_cffi_exports[11])
|
||||
#define _cffi_get_struct_layout \
|
||||
not used any more
|
||||
#define _cffi_restore_errno \
|
||||
((void(*)(void))_cffi_exports[13])
|
||||
#define _cffi_save_errno \
|
||||
((void(*)(void))_cffi_exports[14])
|
||||
#define _cffi_from_c_char \
|
||||
((PyObject *(*)(char))_cffi_exports[15])
|
||||
#define _cffi_from_c_deref \
|
||||
((PyObject *(*)(char *, struct _cffi_ctypedescr *))_cffi_exports[16])
|
||||
#define _cffi_to_c \
|
||||
((int(*)(char *, struct _cffi_ctypedescr *, PyObject *))_cffi_exports[17])
|
||||
#define _cffi_from_c_struct \
|
||||
((PyObject *(*)(char *, struct _cffi_ctypedescr *))_cffi_exports[18])
|
||||
#define _cffi_to_c_wchar_t \
|
||||
((wchar_t(*)(PyObject *))_cffi_exports[19])
|
||||
#define _cffi_from_c_wchar_t \
|
||||
((PyObject *(*)(wchar_t))_cffi_exports[20])
|
||||
#define _cffi_to_c_long_double \
|
||||
((long double(*)(PyObject *))_cffi_exports[21])
|
||||
#define _cffi_to_c__Bool \
|
||||
((_Bool(*)(PyObject *))_cffi_exports[22])
|
||||
#define _cffi_prepare_pointer_call_argument \
|
||||
((Py_ssize_t(*)(struct _cffi_ctypedescr *, \
|
||||
PyObject *, char **))_cffi_exports[23])
|
||||
#define _cffi_convert_array_from_object \
|
||||
((int(*)(char *, struct _cffi_ctypedescr *, PyObject *))_cffi_exports[24])
|
||||
#define _CFFI_CPIDX 25
|
||||
#define _cffi_call_python \
|
||||
((void(*)(struct _cffi_externpy_s *, char *))_cffi_exports[_CFFI_CPIDX])
|
||||
#define _CFFI_NUM_EXPORTS 26
|
||||
|
||||
struct _cffi_ctypedescr;
|
||||
|
||||
static void *_cffi_exports[_CFFI_NUM_EXPORTS];
|
||||
|
||||
#define _cffi_type(index) ( \
|
||||
assert((((uintptr_t)_cffi_types[index]) & 1) == 0), \
|
||||
(struct _cffi_ctypedescr *)_cffi_types[index])
|
||||
|
||||
static PyObject *_cffi_init(const char *module_name, Py_ssize_t version,
|
||||
const struct _cffi_type_context_s *ctx)
|
||||
{
|
||||
PyObject *module, *o_arg, *new_module;
|
||||
void *raw[] = {
|
||||
(void *)module_name,
|
||||
(void *)version,
|
||||
(void *)_cffi_exports,
|
||||
(void *)ctx,
|
||||
};
|
||||
|
||||
module = PyImport_ImportModule("_cffi_backend");
|
||||
if (module == NULL)
|
||||
goto failure;
|
||||
|
||||
o_arg = PyLong_FromVoidPtr((void *)raw);
|
||||
if (o_arg == NULL)
|
||||
goto failure;
|
||||
|
||||
new_module = PyObject_CallMethod(
|
||||
module, (char *)"_init_cffi_1_0_external_module", (char *)"O", o_arg);
|
||||
|
||||
Py_DECREF(o_arg);
|
||||
Py_DECREF(module);
|
||||
return new_module;
|
||||
|
||||
failure:
|
||||
Py_XDECREF(module);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
/********** end CPython-specific section **********/
|
||||
#else
|
||||
_CFFI_UNUSED_FN
|
||||
static void (*_cffi_call_python_org)(struct _cffi_externpy_s *, char *);
|
||||
# define _cffi_call_python _cffi_call_python_org
|
||||
#endif
|
||||
|
||||
|
||||
#define _cffi_array_len(array) (sizeof(array) / sizeof((array)[0]))
|
||||
|
||||
#define _cffi_prim_int(size, sign) \
|
||||
((size) == 1 ? ((sign) ? _CFFI_PRIM_INT8 : _CFFI_PRIM_UINT8) : \
|
||||
(size) == 2 ? ((sign) ? _CFFI_PRIM_INT16 : _CFFI_PRIM_UINT16) : \
|
||||
(size) == 4 ? ((sign) ? _CFFI_PRIM_INT32 : _CFFI_PRIM_UINT32) : \
|
||||
(size) == 8 ? ((sign) ? _CFFI_PRIM_INT64 : _CFFI_PRIM_UINT64) : \
|
||||
_CFFI__UNKNOWN_PRIM)
|
||||
|
||||
#define _cffi_prim_float(size) \
|
||||
((size) == sizeof(float) ? _CFFI_PRIM_FLOAT : \
|
||||
(size) == sizeof(double) ? _CFFI_PRIM_DOUBLE : \
|
||||
(size) == sizeof(long double) ? _CFFI__UNKNOWN_LONG_DOUBLE : \
|
||||
_CFFI__UNKNOWN_FLOAT_PRIM)
|
||||
|
||||
#define _cffi_check_int(got, got_nonpos, expected) \
|
||||
((got_nonpos) == (expected <= 0) && \
|
||||
(got) == (unsigned long long)expected)
|
||||
|
||||
#ifdef MS_WIN32
|
||||
# define _cffi_stdcall __stdcall
|
||||
#else
|
||||
# define _cffi_stdcall /* nothing */
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
@ -0,0 +1,517 @@
|
|||
|
||||
/***** Support code for embedding *****/
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
# define CFFI_DLLEXPORT __declspec(dllexport)
|
||||
#elif defined(__GNUC__)
|
||||
# define CFFI_DLLEXPORT __attribute__((visibility("default")))
|
||||
#else
|
||||
# define CFFI_DLLEXPORT /* nothing */
|
||||
#endif
|
||||
|
||||
|
||||
/* There are two global variables of type _cffi_call_python_fnptr:
|
||||
|
||||
* _cffi_call_python, which we declare just below, is the one called
|
||||
by ``extern "Python"`` implementations.
|
||||
|
||||
* _cffi_call_python_org, which on CPython is actually part of the
|
||||
_cffi_exports[] array, is the function pointer copied from
|
||||
_cffi_backend.
|
||||
|
||||
After initialization is complete, both are equal. However, the
|
||||
first one remains equal to &_cffi_start_and_call_python until the
|
||||
very end of initialization, when we are (or should be) sure that
|
||||
concurrent threads also see a completely initialized world, and
|
||||
only then is it changed.
|
||||
*/
|
||||
#undef _cffi_call_python
|
||||
typedef void (*_cffi_call_python_fnptr)(struct _cffi_externpy_s *, char *);
|
||||
static void _cffi_start_and_call_python(struct _cffi_externpy_s *, char *);
|
||||
static _cffi_call_python_fnptr _cffi_call_python = &_cffi_start_and_call_python;
|
||||
|
||||
|
||||
#ifndef _MSC_VER
|
||||
/* --- Assuming a GCC not infinitely old --- */
|
||||
# define cffi_compare_and_swap(l,o,n) __sync_bool_compare_and_swap(l,o,n)
|
||||
# define cffi_write_barrier() __sync_synchronize()
|
||||
# if !defined(__amd64__) && !defined(__x86_64__) && \
|
||||
!defined(__i386__) && !defined(__i386)
|
||||
# define cffi_read_barrier() __sync_synchronize()
|
||||
# else
|
||||
# define cffi_read_barrier() (void)0
|
||||
# endif
|
||||
#else
|
||||
/* --- Windows threads version --- */
|
||||
# include <Windows.h>
|
||||
# define cffi_compare_and_swap(l,o,n) \
|
||||
(InterlockedCompareExchangePointer(l,n,o) == (o))
|
||||
# define cffi_write_barrier() InterlockedCompareExchange(&_cffi_dummy,0,0)
|
||||
# define cffi_read_barrier() (void)0
|
||||
static volatile LONG _cffi_dummy;
|
||||
#endif
|
||||
|
||||
#ifdef WITH_THREAD
|
||||
# ifndef _MSC_VER
|
||||
# include <pthread.h>
|
||||
static pthread_mutex_t _cffi_embed_startup_lock;
|
||||
# else
|
||||
static CRITICAL_SECTION _cffi_embed_startup_lock;
|
||||
# endif
|
||||
static char _cffi_embed_startup_lock_ready = 0;
|
||||
#endif
|
||||
|
||||
static void _cffi_acquire_reentrant_mutex(void)
|
||||
{
|
||||
static void *volatile lock = NULL;
|
||||
|
||||
while (!cffi_compare_and_swap(&lock, NULL, (void *)1)) {
|
||||
/* should ideally do a spin loop instruction here, but
|
||||
hard to do it portably and doesn't really matter I
|
||||
think: pthread_mutex_init() should be very fast, and
|
||||
this is only run at start-up anyway. */
|
||||
}
|
||||
|
||||
#ifdef WITH_THREAD
|
||||
if (!_cffi_embed_startup_lock_ready) {
|
||||
# ifndef _MSC_VER
|
||||
pthread_mutexattr_t attr;
|
||||
pthread_mutexattr_init(&attr);
|
||||
pthread_mutexattr_settype(&attr, PTHREAD_MUTEX_RECURSIVE);
|
||||
pthread_mutex_init(&_cffi_embed_startup_lock, &attr);
|
||||
# else
|
||||
InitializeCriticalSection(&_cffi_embed_startup_lock);
|
||||
# endif
|
||||
_cffi_embed_startup_lock_ready = 1;
|
||||
}
|
||||
#endif
|
||||
|
||||
while (!cffi_compare_and_swap(&lock, (void *)1, NULL))
|
||||
;
|
||||
|
||||
#ifndef _MSC_VER
|
||||
pthread_mutex_lock(&_cffi_embed_startup_lock);
|
||||
#else
|
||||
EnterCriticalSection(&_cffi_embed_startup_lock);
|
||||
#endif
|
||||
}
|
||||
|
||||
static void _cffi_release_reentrant_mutex(void)
|
||||
{
|
||||
#ifndef _MSC_VER
|
||||
pthread_mutex_unlock(&_cffi_embed_startup_lock);
|
||||
#else
|
||||
LeaveCriticalSection(&_cffi_embed_startup_lock);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
/********** CPython-specific section **********/
|
||||
#ifndef PYPY_VERSION
|
||||
|
||||
|
||||
#define _cffi_call_python_org _cffi_exports[_CFFI_CPIDX]
|
||||
|
||||
PyMODINIT_FUNC _CFFI_PYTHON_STARTUP_FUNC(void); /* forward */
|
||||
|
||||
static void _cffi_py_initialize(void)
|
||||
{
|
||||
/* XXX use initsigs=0, which "skips initialization registration of
|
||||
signal handlers, which might be useful when Python is
|
||||
embedded" according to the Python docs. But review and think
|
||||
if it should be a user-controllable setting.
|
||||
|
||||
XXX we should also give a way to write errors to a buffer
|
||||
instead of to stderr.
|
||||
|
||||
XXX if importing 'site' fails, CPython (any version) calls
|
||||
exit(). Should we try to work around this behavior here?
|
||||
*/
|
||||
Py_InitializeEx(0);
|
||||
}
|
||||
|
||||
static int _cffi_initialize_python(void)
|
||||
{
|
||||
/* This initializes Python, imports _cffi_backend, and then the
|
||||
present .dll/.so is set up as a CPython C extension module.
|
||||
*/
|
||||
int result;
|
||||
PyGILState_STATE state;
|
||||
PyObject *pycode=NULL, *global_dict=NULL, *x;
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
/* see comments in _cffi_carefully_make_gil() about the
|
||||
Python2/Python3 difference
|
||||
*/
|
||||
#else
|
||||
/* Acquire the GIL. We have no threadstate here. If Python is
|
||||
already initialized, it is possible that there is already one
|
||||
existing for this thread, but it is not made current now.
|
||||
*/
|
||||
PyEval_AcquireLock();
|
||||
|
||||
_cffi_py_initialize();
|
||||
|
||||
/* The Py_InitializeEx() sometimes made a threadstate for us, but
|
||||
not always. Indeed Py_InitializeEx() could be called and do
|
||||
nothing. So do we have a threadstate, or not? We don't know,
|
||||
but we can replace it with NULL in all cases.
|
||||
*/
|
||||
(void)PyThreadState_Swap(NULL);
|
||||
|
||||
/* Now we can release the GIL and re-acquire immediately using the
|
||||
logic of PyGILState(), which handles making or installing the
|
||||
correct threadstate.
|
||||
*/
|
||||
PyEval_ReleaseLock();
|
||||
#endif
|
||||
state = PyGILState_Ensure();
|
||||
|
||||
/* Call the initxxx() function from the present module. It will
|
||||
create and initialize us as a CPython extension module, instead
|
||||
of letting the startup Python code do it---it might reimport
|
||||
the same .dll/.so and get maybe confused on some platforms.
|
||||
It might also have troubles locating the .dll/.so again for all
|
||||
I know.
|
||||
*/
|
||||
(void)_CFFI_PYTHON_STARTUP_FUNC();
|
||||
if (PyErr_Occurred())
|
||||
goto error;
|
||||
|
||||
/* Now run the Python code provided to ffi.embedding_init_code().
|
||||
*/
|
||||
pycode = Py_CompileString(_CFFI_PYTHON_STARTUP_CODE,
|
||||
"<init code for '" _CFFI_MODULE_NAME "'>",
|
||||
Py_file_input);
|
||||
if (pycode == NULL)
|
||||
goto error;
|
||||
global_dict = PyDict_New();
|
||||
if (global_dict == NULL)
|
||||
goto error;
|
||||
if (PyDict_SetItemString(global_dict, "__builtins__",
|
||||
PyThreadState_GET()->interp->builtins) < 0)
|
||||
goto error;
|
||||
x = PyEval_EvalCode(
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
(PyCodeObject *)
|
||||
#endif
|
||||
pycode, global_dict, global_dict);
|
||||
if (x == NULL)
|
||||
goto error;
|
||||
Py_DECREF(x);
|
||||
|
||||
/* Done! Now if we've been called from
|
||||
_cffi_start_and_call_python() in an ``extern "Python"``, we can
|
||||
only hope that the Python code did correctly set up the
|
||||
corresponding @ffi.def_extern() function. Otherwise, the
|
||||
general logic of ``extern "Python"`` functions (inside the
|
||||
_cffi_backend module) will find that the reference is still
|
||||
missing and print an error.
|
||||
*/
|
||||
result = 0;
|
||||
done:
|
||||
Py_XDECREF(pycode);
|
||||
Py_XDECREF(global_dict);
|
||||
PyGILState_Release(state);
|
||||
return result;
|
||||
|
||||
error:;
|
||||
{
|
||||
/* Print as much information as potentially useful.
|
||||
Debugging load-time failures with embedding is not fun
|
||||
*/
|
||||
PyObject *exception, *v, *tb, *f, *modules, *mod;
|
||||
PyErr_Fetch(&exception, &v, &tb);
|
||||
if (exception != NULL) {
|
||||
PyErr_NormalizeException(&exception, &v, &tb);
|
||||
PyErr_Display(exception, v, tb);
|
||||
}
|
||||
Py_XDECREF(exception);
|
||||
Py_XDECREF(v);
|
||||
Py_XDECREF(tb);
|
||||
|
||||
f = PySys_GetObject((char *)"stderr");
|
||||
if (f != NULL && f != Py_None) {
|
||||
PyFile_WriteString("\nFrom: " _CFFI_MODULE_NAME
|
||||
"\ncompiled with cffi version: 1.10.0"
|
||||
"\n_cffi_backend module: ", f);
|
||||
modules = PyImport_GetModuleDict();
|
||||
mod = PyDict_GetItemString(modules, "_cffi_backend");
|
||||
if (mod == NULL) {
|
||||
PyFile_WriteString("not loaded", f);
|
||||
}
|
||||
else {
|
||||
v = PyObject_GetAttrString(mod, "__file__");
|
||||
PyFile_WriteObject(v, f, 0);
|
||||
Py_XDECREF(v);
|
||||
}
|
||||
PyFile_WriteString("\nsys.path: ", f);
|
||||
PyFile_WriteObject(PySys_GetObject((char *)"path"), f, 0);
|
||||
PyFile_WriteString("\n\n", f);
|
||||
}
|
||||
}
|
||||
result = -1;
|
||||
goto done;
|
||||
}
|
||||
|
||||
PyAPI_DATA(char *) _PyParser_TokenNames[]; /* from CPython */
|
||||
|
||||
static int _cffi_carefully_make_gil(void)
|
||||
{
|
||||
/* This does the basic initialization of Python. It can be called
|
||||
completely concurrently from unrelated threads. It assumes
|
||||
that we don't hold the GIL before (if it exists), and we don't
|
||||
hold it afterwards.
|
||||
|
||||
What it really does is completely different in Python 2 and
|
||||
Python 3.
|
||||
|
||||
Python 2
|
||||
========
|
||||
|
||||
Initialize the GIL, without initializing the rest of Python,
|
||||
by calling PyEval_InitThreads().
|
||||
|
||||
PyEval_InitThreads() must not be called concurrently at all.
|
||||
So we use a global variable as a simple spin lock. This global
|
||||
variable must be from 'libpythonX.Y.so', not from this
|
||||
cffi-based extension module, because it must be shared from
|
||||
different cffi-based extension modules. We choose
|
||||
_PyParser_TokenNames[0] as a completely arbitrary pointer value
|
||||
that is never written to. The default is to point to the
|
||||
string "ENDMARKER". We change it temporarily to point to the
|
||||
next character in that string. (Yes, I know it's REALLY
|
||||
obscure.)
|
||||
|
||||
Python 3
|
||||
========
|
||||
|
||||
In Python 3, PyEval_InitThreads() cannot be called before
|
||||
Py_InitializeEx() any more. So this function calls
|
||||
Py_InitializeEx() first. It uses the same obscure logic to
|
||||
make sure we never call it concurrently.
|
||||
|
||||
Arguably, this is less good on the spinlock, because
|
||||
Py_InitializeEx() takes much longer to run than
|
||||
PyEval_InitThreads(). But I didn't find a way around it.
|
||||
*/
|
||||
|
||||
#ifdef WITH_THREAD
|
||||
char *volatile *lock = (char *volatile *)_PyParser_TokenNames;
|
||||
char *old_value;
|
||||
|
||||
while (1) { /* spin loop */
|
||||
old_value = *lock;
|
||||
if (old_value[0] == 'E') {
|
||||
assert(old_value[1] == 'N');
|
||||
if (cffi_compare_and_swap(lock, old_value, old_value + 1))
|
||||
break;
|
||||
}
|
||||
else {
|
||||
assert(old_value[0] == 'N');
|
||||
/* should ideally do a spin loop instruction here, but
|
||||
hard to do it portably and doesn't really matter I
|
||||
think: PyEval_InitThreads() should be very fast, and
|
||||
this is only run at start-up anyway. */
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
/* Python 3: call Py_InitializeEx() */
|
||||
{
|
||||
PyGILState_STATE state = PyGILState_UNLOCKED;
|
||||
if (!Py_IsInitialized())
|
||||
_cffi_py_initialize();
|
||||
else
|
||||
state = PyGILState_Ensure();
|
||||
|
||||
PyEval_InitThreads();
|
||||
PyGILState_Release(state);
|
||||
}
|
||||
#else
|
||||
/* Python 2: call PyEval_InitThreads() */
|
||||
# ifdef WITH_THREAD
|
||||
if (!PyEval_ThreadsInitialized()) {
|
||||
PyEval_InitThreads(); /* makes the GIL */
|
||||
PyEval_ReleaseLock(); /* then release it */
|
||||
}
|
||||
/* else: there is already a GIL, but we still needed to do the
|
||||
spinlock dance to make sure that we see it as fully ready */
|
||||
# endif
|
||||
#endif
|
||||
|
||||
#ifdef WITH_THREAD
|
||||
/* release the lock */
|
||||
while (!cffi_compare_and_swap(lock, old_value + 1, old_value))
|
||||
;
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/********** end CPython-specific section **********/
|
||||
|
||||
|
||||
#else
|
||||
|
||||
|
||||
/********** PyPy-specific section **********/
|
||||
|
||||
PyMODINIT_FUNC _CFFI_PYTHON_STARTUP_FUNC(const void *[]); /* forward */
|
||||
|
||||
static struct _cffi_pypy_init_s {
|
||||
const char *name;
|
||||
void (*func)(const void *[]);
|
||||
const char *code;
|
||||
} _cffi_pypy_init = {
|
||||
_CFFI_MODULE_NAME,
|
||||
(void(*)(const void *[]))_CFFI_PYTHON_STARTUP_FUNC,
|
||||
_CFFI_PYTHON_STARTUP_CODE,
|
||||
};
|
||||
|
||||
extern int pypy_carefully_make_gil(const char *);
|
||||
extern int pypy_init_embedded_cffi_module(int, struct _cffi_pypy_init_s *);
|
||||
|
||||
static int _cffi_carefully_make_gil(void)
|
||||
{
|
||||
return pypy_carefully_make_gil(_CFFI_MODULE_NAME);
|
||||
}
|
||||
|
||||
static int _cffi_initialize_python(void)
|
||||
{
|
||||
return pypy_init_embedded_cffi_module(0xB011, &_cffi_pypy_init);
|
||||
}
|
||||
|
||||
/********** end PyPy-specific section **********/
|
||||
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
#ifdef __GNUC__
|
||||
__attribute__((noinline))
|
||||
#endif
|
||||
static _cffi_call_python_fnptr _cffi_start_python(void)
|
||||
{
|
||||
/* Delicate logic to initialize Python. This function can be
|
||||
called multiple times concurrently, e.g. when the process calls
|
||||
its first ``extern "Python"`` functions in multiple threads at
|
||||
once. It can also be called recursively, in which case we must
|
||||
ignore it. We also have to consider what occurs if several
|
||||
different cffi-based extensions reach this code in parallel
|
||||
threads---it is a different copy of the code, then, and we
|
||||
can't have any shared global variable unless it comes from
|
||||
'libpythonX.Y.so'.
|
||||
|
||||
Idea:
|
||||
|
||||
* _cffi_carefully_make_gil(): "carefully" call
|
||||
PyEval_InitThreads() (possibly with Py_InitializeEx() first).
|
||||
|
||||
* then we use a (local) custom lock to make sure that a call to this
|
||||
cffi-based extension will wait if another call to the *same*
|
||||
extension is running the initialization in another thread.
|
||||
It is reentrant, so that a recursive call will not block, but
|
||||
only one from a different thread.
|
||||
|
||||
* then we grab the GIL and (Python 2) we call Py_InitializeEx().
|
||||
At this point, concurrent calls to Py_InitializeEx() are not
|
||||
possible: we have the GIL.
|
||||
|
||||
* do the rest of the specific initialization, which may
|
||||
temporarily release the GIL but not the custom lock.
|
||||
Only release the custom lock when we are done.
|
||||
*/
|
||||
static char called = 0;
|
||||
|
||||
if (_cffi_carefully_make_gil() != 0)
|
||||
return NULL;
|
||||
|
||||
_cffi_acquire_reentrant_mutex();
|
||||
|
||||
/* Here the GIL exists, but we don't have it. We're only protected
|
||||
from concurrency by the reentrant mutex. */
|
||||
|
||||
/* This file only initializes the embedded module once, the first
|
||||
time this is called, even if there are subinterpreters. */
|
||||
if (!called) {
|
||||
called = 1; /* invoke _cffi_initialize_python() only once,
|
||||
but don't set '_cffi_call_python' right now,
|
||||
otherwise concurrent threads won't call
|
||||
this function at all (we need them to wait) */
|
||||
if (_cffi_initialize_python() == 0) {
|
||||
/* now initialization is finished. Switch to the fast-path. */
|
||||
|
||||
/* We would like nobody to see the new value of
|
||||
'_cffi_call_python' without also seeing the rest of the
|
||||
data initialized. However, this is not possible. But
|
||||
the new value of '_cffi_call_python' is the function
|
||||
'cffi_call_python()' from _cffi_backend. So: */
|
||||
cffi_write_barrier();
|
||||
/* ^^^ we put a write barrier here, and a corresponding
|
||||
read barrier at the start of cffi_call_python(). This
|
||||
ensures that after that read barrier, we see everything
|
||||
done here before the write barrier.
|
||||
*/
|
||||
|
||||
assert(_cffi_call_python_org != NULL);
|
||||
_cffi_call_python = (_cffi_call_python_fnptr)_cffi_call_python_org;
|
||||
}
|
||||
else {
|
||||
/* initialization failed. Reset this to NULL, even if it was
|
||||
already set to some other value. Future calls to
|
||||
_cffi_start_python() are still forced to occur, and will
|
||||
always return NULL from now on. */
|
||||
_cffi_call_python_org = NULL;
|
||||
}
|
||||
}
|
||||
|
||||
_cffi_release_reentrant_mutex();
|
||||
|
||||
return (_cffi_call_python_fnptr)_cffi_call_python_org;
|
||||
}
|
||||
|
||||
static
|
||||
void _cffi_start_and_call_python(struct _cffi_externpy_s *externpy, char *args)
|
||||
{
|
||||
_cffi_call_python_fnptr fnptr;
|
||||
int current_err = errno;
|
||||
#ifdef _MSC_VER
|
||||
int current_lasterr = GetLastError();
|
||||
#endif
|
||||
fnptr = _cffi_start_python();
|
||||
if (fnptr == NULL) {
|
||||
fprintf(stderr, "function %s() called, but initialization code "
|
||||
"failed. Returning 0.\n", externpy->name);
|
||||
memset(args, 0, externpy->size_of_result);
|
||||
}
|
||||
#ifdef _MSC_VER
|
||||
SetLastError(current_lasterr);
|
||||
#endif
|
||||
errno = current_err;
|
||||
|
||||
if (fnptr != NULL)
|
||||
fnptr(externpy, args);
|
||||
}
|
||||
|
||||
|
||||
/* The cffi_start_python() function makes sure Python is initialized
|
||||
and our cffi module is set up. It can be called manually from the
|
||||
user C code. The same effect is obtained automatically from any
|
||||
dll-exported ``extern "Python"`` function. This function returns
|
||||
-1 if initialization failed, 0 if all is OK. */
|
||||
_CFFI_UNUSED_FN
|
||||
static int cffi_start_python(void)
|
||||
{
|
||||
if (_cffi_call_python == &_cffi_start_and_call_python) {
|
||||
if (_cffi_start_python() == NULL)
|
||||
return -1;
|
||||
}
|
||||
cffi_read_barrier();
|
||||
return 0;
|
||||
}
|
||||
|
||||
#undef cffi_compare_and_swap
|
||||
#undef cffi_write_barrier
|
||||
#undef cffi_read_barrier
|
||||
|
|
@ -0,0 +1,916 @@
|
|||
import sys, types
|
||||
from .lock import allocate_lock
|
||||
from .error import CDefError
|
||||
from . import model
|
||||
|
||||
try:
|
||||
callable
|
||||
except NameError:
|
||||
# Python 3.1
|
||||
from collections import Callable
|
||||
callable = lambda x: isinstance(x, Callable)
|
||||
|
||||
try:
|
||||
basestring
|
||||
except NameError:
|
||||
# Python 3.x
|
||||
basestring = str
|
||||
|
||||
|
||||
|
||||
class FFI(object):
|
||||
r'''
|
||||
The main top-level class that you instantiate once, or once per module.
|
||||
|
||||
Example usage:
|
||||
|
||||
ffi = FFI()
|
||||
ffi.cdef("""
|
||||
int printf(const char *, ...);
|
||||
""")
|
||||
|
||||
C = ffi.dlopen(None) # standard library
|
||||
-or-
|
||||
C = ffi.verify() # use a C compiler: verify the decl above is right
|
||||
|
||||
C.printf("hello, %s!\n", ffi.new("char[]", "world"))
|
||||
'''
|
||||
|
||||
def __init__(self, backend=None):
|
||||
"""Create an FFI instance. The 'backend' argument is used to
|
||||
select a non-default backend, mostly for tests.
|
||||
"""
|
||||
if backend is None:
|
||||
# You need PyPy (>= 2.0 beta), or a CPython (>= 2.6) with
|
||||
# _cffi_backend.so compiled.
|
||||
import _cffi_backend as backend
|
||||
from . import __version__
|
||||
if backend.__version__ != __version__:
|
||||
# bad version! Try to be as explicit as possible.
|
||||
if hasattr(backend, '__file__'):
|
||||
# CPython
|
||||
raise Exception("Version mismatch: this is the 'cffi' package version %s, located in %r. When we import the top-level '_cffi_backend' extension module, we get version %s, located in %r. The two versions should be equal; check your installation." % (
|
||||
__version__, __file__,
|
||||
backend.__version__, backend.__file__))
|
||||
else:
|
||||
# PyPy
|
||||
raise Exception("Version mismatch: this is the 'cffi' package version %s, located in %r. This interpreter comes with a built-in '_cffi_backend' module, which is version %s. The two versions should be equal; check your installation." % (
|
||||
__version__, __file__, backend.__version__))
|
||||
# (If you insist you can also try to pass the option
|
||||
# 'backend=backend_ctypes.CTypesBackend()', but don't
|
||||
# rely on it! It's probably not going to work well.)
|
||||
|
||||
from . import cparser
|
||||
self._backend = backend
|
||||
self._lock = allocate_lock()
|
||||
self._parser = cparser.Parser()
|
||||
self._cached_btypes = {}
|
||||
self._parsed_types = types.ModuleType('parsed_types').__dict__
|
||||
self._new_types = types.ModuleType('new_types').__dict__
|
||||
self._function_caches = []
|
||||
self._libraries = []
|
||||
self._cdefsources = []
|
||||
self._included_ffis = []
|
||||
self._windows_unicode = None
|
||||
self._init_once_cache = {}
|
||||
self._cdef_version = None
|
||||
self._embedding = None
|
||||
if hasattr(backend, 'set_ffi'):
|
||||
backend.set_ffi(self)
|
||||
for name in backend.__dict__:
|
||||
if name.startswith('RTLD_'):
|
||||
setattr(self, name, getattr(backend, name))
|
||||
#
|
||||
with self._lock:
|
||||
self.BVoidP = self._get_cached_btype(model.voidp_type)
|
||||
self.BCharA = self._get_cached_btype(model.char_array_type)
|
||||
if isinstance(backend, types.ModuleType):
|
||||
# _cffi_backend: attach these constants to the class
|
||||
if not hasattr(FFI, 'NULL'):
|
||||
FFI.NULL = self.cast(self.BVoidP, 0)
|
||||
FFI.CData, FFI.CType = backend._get_types()
|
||||
else:
|
||||
# ctypes backend: attach these constants to the instance
|
||||
self.NULL = self.cast(self.BVoidP, 0)
|
||||
self.CData, self.CType = backend._get_types()
|
||||
self.buffer = backend.buffer
|
||||
|
||||
def cdef(self, csource, override=False, packed=False):
|
||||
"""Parse the given C source. This registers all declared functions,
|
||||
types, and global variables. The functions and global variables can
|
||||
then be accessed via either 'ffi.dlopen()' or 'ffi.verify()'.
|
||||
The types can be used in 'ffi.new()' and other functions.
|
||||
If 'packed' is specified as True, all structs declared inside this
|
||||
cdef are packed, i.e. laid out without any field alignment at all.
|
||||
"""
|
||||
self._cdef(csource, override=override, packed=packed)
|
||||
|
||||
def embedding_api(self, csource, packed=False):
|
||||
self._cdef(csource, packed=packed, dllexport=True)
|
||||
if self._embedding is None:
|
||||
self._embedding = ''
|
||||
|
||||
def _cdef(self, csource, override=False, **options):
|
||||
if not isinstance(csource, str): # unicode, on Python 2
|
||||
if not isinstance(csource, basestring):
|
||||
raise TypeError("cdef() argument must be a string")
|
||||
csource = csource.encode('ascii')
|
||||
with self._lock:
|
||||
self._cdef_version = object()
|
||||
self._parser.parse(csource, override=override, **options)
|
||||
self._cdefsources.append(csource)
|
||||
if override:
|
||||
for cache in self._function_caches:
|
||||
cache.clear()
|
||||
finishlist = self._parser._recomplete
|
||||
if finishlist:
|
||||
self._parser._recomplete = []
|
||||
for tp in finishlist:
|
||||
tp.finish_backend_type(self, finishlist)
|
||||
|
||||
def dlopen(self, name, flags=0):
|
||||
"""Load and return a dynamic library identified by 'name'.
|
||||
The standard C library can be loaded by passing None.
|
||||
Note that functions and types declared by 'ffi.cdef()' are not
|
||||
linked to a particular library, just like C headers; in the
|
||||
library we only look for the actual (untyped) symbols.
|
||||
"""
|
||||
assert isinstance(name, basestring) or name is None
|
||||
with self._lock:
|
||||
lib, function_cache = _make_ffi_library(self, name, flags)
|
||||
self._function_caches.append(function_cache)
|
||||
self._libraries.append(lib)
|
||||
return lib
|
||||
|
||||
def _typeof_locked(self, cdecl):
|
||||
# call me with the lock!
|
||||
key = cdecl
|
||||
if key in self._parsed_types:
|
||||
return self._parsed_types[key]
|
||||
#
|
||||
if not isinstance(cdecl, str): # unicode, on Python 2
|
||||
cdecl = cdecl.encode('ascii')
|
||||
#
|
||||
type = self._parser.parse_type(cdecl)
|
||||
really_a_function_type = type.is_raw_function
|
||||
if really_a_function_type:
|
||||
type = type.as_function_pointer()
|
||||
btype = self._get_cached_btype(type)
|
||||
result = btype, really_a_function_type
|
||||
self._parsed_types[key] = result
|
||||
return result
|
||||
|
||||
def _typeof(self, cdecl, consider_function_as_funcptr=False):
|
||||
# string -> ctype object
|
||||
try:
|
||||
result = self._parsed_types[cdecl]
|
||||
except KeyError:
|
||||
with self._lock:
|
||||
result = self._typeof_locked(cdecl)
|
||||
#
|
||||
btype, really_a_function_type = result
|
||||
if really_a_function_type and not consider_function_as_funcptr:
|
||||
raise CDefError("the type %r is a function type, not a "
|
||||
"pointer-to-function type" % (cdecl,))
|
||||
return btype
|
||||
|
||||
def typeof(self, cdecl):
|
||||
"""Parse the C type given as a string and return the
|
||||
corresponding <ctype> object.
|
||||
It can also be used on 'cdata' instance to get its C type.
|
||||
"""
|
||||
if isinstance(cdecl, basestring):
|
||||
return self._typeof(cdecl)
|
||||
if isinstance(cdecl, self.CData):
|
||||
return self._backend.typeof(cdecl)
|
||||
if isinstance(cdecl, types.BuiltinFunctionType):
|
||||
res = _builtin_function_type(cdecl)
|
||||
if res is not None:
|
||||
return res
|
||||
if (isinstance(cdecl, types.FunctionType)
|
||||
and hasattr(cdecl, '_cffi_base_type')):
|
||||
with self._lock:
|
||||
return self._get_cached_btype(cdecl._cffi_base_type)
|
||||
raise TypeError(type(cdecl))
|
||||
|
||||
def sizeof(self, cdecl):
|
||||
"""Return the size in bytes of the argument. It can be a
|
||||
string naming a C type, or a 'cdata' instance.
|
||||
"""
|
||||
if isinstance(cdecl, basestring):
|
||||
BType = self._typeof(cdecl)
|
||||
return self._backend.sizeof(BType)
|
||||
else:
|
||||
return self._backend.sizeof(cdecl)
|
||||
|
||||
def alignof(self, cdecl):
|
||||
"""Return the natural alignment size in bytes of the C type
|
||||
given as a string.
|
||||
"""
|
||||
if isinstance(cdecl, basestring):
|
||||
cdecl = self._typeof(cdecl)
|
||||
return self._backend.alignof(cdecl)
|
||||
|
||||
def offsetof(self, cdecl, *fields_or_indexes):
|
||||
"""Return the offset of the named field inside the given
|
||||
structure or array, which must be given as a C type name.
|
||||
You can give several field names in case of nested structures.
|
||||
You can also give numeric values which correspond to array
|
||||
items, in case of an array type.
|
||||
"""
|
||||
if isinstance(cdecl, basestring):
|
||||
cdecl = self._typeof(cdecl)
|
||||
return self._typeoffsetof(cdecl, *fields_or_indexes)[1]
|
||||
|
||||
def new(self, cdecl, init=None):
|
||||
"""Allocate an instance according to the specified C type and
|
||||
return a pointer to it. The specified C type must be either a
|
||||
pointer or an array: ``new('X *')`` allocates an X and returns
|
||||
a pointer to it, whereas ``new('X[n]')`` allocates an array of
|
||||
n X'es and returns an array referencing it (which works
|
||||
mostly like a pointer, like in C). You can also use
|
||||
``new('X[]', n)`` to allocate an array of a non-constant
|
||||
length n.
|
||||
|
||||
The memory is initialized following the rules of declaring a
|
||||
global variable in C: by default it is zero-initialized, but
|
||||
an explicit initializer can be given which can be used to
|
||||
fill all or part of the memory.
|
||||
|
||||
When the returned <cdata> object goes out of scope, the memory
|
||||
is freed. In other words the returned <cdata> object has
|
||||
ownership of the value of type 'cdecl' that it points to. This
|
||||
means that the raw data can be used as long as this object is
|
||||
kept alive, but must not be used for a longer time. Be careful
|
||||
about that when copying the pointer to the memory somewhere
|
||||
else, e.g. into another structure.
|
||||
"""
|
||||
if isinstance(cdecl, basestring):
|
||||
cdecl = self._typeof(cdecl)
|
||||
return self._backend.newp(cdecl, init)
|
||||
|
||||
def new_allocator(self, alloc=None, free=None,
|
||||
should_clear_after_alloc=True):
|
||||
"""Return a new allocator, i.e. a function that behaves like ffi.new()
|
||||
but uses the provided low-level 'alloc' and 'free' functions.
|
||||
|
||||
'alloc' is called with the size as argument. If it returns NULL, a
|
||||
MemoryError is raised. 'free' is called with the result of 'alloc'
|
||||
as argument. Both can be either Python function or directly C
|
||||
functions. If 'free' is None, then no free function is called.
|
||||
If both 'alloc' and 'free' are None, the default is used.
|
||||
|
||||
If 'should_clear_after_alloc' is set to False, then the memory
|
||||
returned by 'alloc' is assumed to be already cleared (or you are
|
||||
fine with garbage); otherwise CFFI will clear it.
|
||||
"""
|
||||
compiled_ffi = self._backend.FFI()
|
||||
allocator = compiled_ffi.new_allocator(alloc, free,
|
||||
should_clear_after_alloc)
|
||||
def allocate(cdecl, init=None):
|
||||
if isinstance(cdecl, basestring):
|
||||
cdecl = self._typeof(cdecl)
|
||||
return allocator(cdecl, init)
|
||||
return allocate
|
||||
|
||||
def cast(self, cdecl, source):
|
||||
"""Similar to a C cast: returns an instance of the named C
|
||||
type initialized with the given 'source'. The source is
|
||||
casted between integers or pointers of any type.
|
||||
"""
|
||||
if isinstance(cdecl, basestring):
|
||||
cdecl = self._typeof(cdecl)
|
||||
return self._backend.cast(cdecl, source)
|
||||
|
||||
def string(self, cdata, maxlen=-1):
|
||||
"""Return a Python string (or unicode string) from the 'cdata'.
|
||||
If 'cdata' is a pointer or array of characters or bytes, returns
|
||||
the null-terminated string. The returned string extends until
|
||||
the first null character, or at most 'maxlen' characters. If
|
||||
'cdata' is an array then 'maxlen' defaults to its length.
|
||||
|
||||
If 'cdata' is a pointer or array of wchar_t, returns a unicode
|
||||
string following the same rules.
|
||||
|
||||
If 'cdata' is a single character or byte or a wchar_t, returns
|
||||
it as a string or unicode string.
|
||||
|
||||
If 'cdata' is an enum, returns the value of the enumerator as a
|
||||
string, or 'NUMBER' if the value is out of range.
|
||||
"""
|
||||
return self._backend.string(cdata, maxlen)
|
||||
|
||||
def unpack(self, cdata, length):
|
||||
"""Unpack an array of C data of the given length,
|
||||
returning a Python string/unicode/list.
|
||||
|
||||
If 'cdata' is a pointer to 'char', returns a byte string.
|
||||
It does not stop at the first null. This is equivalent to:
|
||||
ffi.buffer(cdata, length)[:]
|
||||
|
||||
If 'cdata' is a pointer to 'wchar_t', returns a unicode string.
|
||||
'length' is measured in wchar_t's; it is not the size in bytes.
|
||||
|
||||
If 'cdata' is a pointer to anything else, returns a list of
|
||||
'length' items. This is a faster equivalent to:
|
||||
[cdata[i] for i in range(length)]
|
||||
"""
|
||||
return self._backend.unpack(cdata, length)
|
||||
|
||||
#def buffer(self, cdata, size=-1):
|
||||
# """Return a read-write buffer object that references the raw C data
|
||||
# pointed to by the given 'cdata'. The 'cdata' must be a pointer or
|
||||
# an array. Can be passed to functions expecting a buffer, or directly
|
||||
# manipulated with:
|
||||
#
|
||||
# buf[:] get a copy of it in a regular string, or
|
||||
# buf[idx] as a single character
|
||||
# buf[:] = ...
|
||||
# buf[idx] = ... change the content
|
||||
# """
|
||||
# note that 'buffer' is a type, set on this instance by __init__
|
||||
|
||||
def from_buffer(self, python_buffer):
|
||||
"""Return a <cdata 'char[]'> that points to the data of the
|
||||
given Python object, which must support the buffer interface.
|
||||
Note that this is not meant to be used on the built-in types
|
||||
str or unicode (you can build 'char[]' arrays explicitly)
|
||||
but only on objects containing large quantities of raw data
|
||||
in some other format, like 'array.array' or numpy arrays.
|
||||
"""
|
||||
return self._backend.from_buffer(self.BCharA, python_buffer)
|
||||
|
||||
def memmove(self, dest, src, n):
|
||||
"""ffi.memmove(dest, src, n) copies n bytes of memory from src to dest.
|
||||
|
||||
Like the C function memmove(), the memory areas may overlap;
|
||||
apart from that it behaves like the C function memcpy().
|
||||
|
||||
'src' can be any cdata ptr or array, or any Python buffer object.
|
||||
'dest' can be any cdata ptr or array, or a writable Python buffer
|
||||
object. The size to copy, 'n', is always measured in bytes.
|
||||
|
||||
Unlike other methods, this one supports all Python buffer including
|
||||
byte strings and bytearrays---but it still does not support
|
||||
non-contiguous buffers.
|
||||
"""
|
||||
return self._backend.memmove(dest, src, n)
|
||||
|
||||
def callback(self, cdecl, python_callable=None, error=None, onerror=None):
|
||||
"""Return a callback object or a decorator making such a
|
||||
callback object. 'cdecl' must name a C function pointer type.
|
||||
The callback invokes the specified 'python_callable' (which may
|
||||
be provided either directly or via a decorator). Important: the
|
||||
callback object must be manually kept alive for as long as the
|
||||
callback may be invoked from the C level.
|
||||
"""
|
||||
def callback_decorator_wrap(python_callable):
|
||||
if not callable(python_callable):
|
||||
raise TypeError("the 'python_callable' argument "
|
||||
"is not callable")
|
||||
return self._backend.callback(cdecl, python_callable,
|
||||
error, onerror)
|
||||
if isinstance(cdecl, basestring):
|
||||
cdecl = self._typeof(cdecl, consider_function_as_funcptr=True)
|
||||
if python_callable is None:
|
||||
return callback_decorator_wrap # decorator mode
|
||||
else:
|
||||
return callback_decorator_wrap(python_callable) # direct mode
|
||||
|
||||
def getctype(self, cdecl, replace_with=''):
|
||||
"""Return a string giving the C type 'cdecl', which may be itself
|
||||
a string or a <ctype> object. If 'replace_with' is given, it gives
|
||||
extra text to append (or insert for more complicated C types), like
|
||||
a variable name, or '*' to get actually the C type 'pointer-to-cdecl'.
|
||||
"""
|
||||
if isinstance(cdecl, basestring):
|
||||
cdecl = self._typeof(cdecl)
|
||||
replace_with = replace_with.strip()
|
||||
if (replace_with.startswith('*')
|
||||
and '&[' in self._backend.getcname(cdecl, '&')):
|
||||
replace_with = '(%s)' % replace_with
|
||||
elif replace_with and not replace_with[0] in '[(':
|
||||
replace_with = ' ' + replace_with
|
||||
return self._backend.getcname(cdecl, replace_with)
|
||||
|
||||
def gc(self, cdata, destructor):
|
||||
"""Return a new cdata object that points to the same
|
||||
data. Later, when this new cdata object is garbage-collected,
|
||||
'destructor(old_cdata_object)' will be called.
|
||||
"""
|
||||
return self._backend.gcp(cdata, destructor)
|
||||
|
||||
def _get_cached_btype(self, type):
|
||||
assert self._lock.acquire(False) is False
|
||||
# call me with the lock!
|
||||
try:
|
||||
BType = self._cached_btypes[type]
|
||||
except KeyError:
|
||||
finishlist = []
|
||||
BType = type.get_cached_btype(self, finishlist)
|
||||
for type in finishlist:
|
||||
type.finish_backend_type(self, finishlist)
|
||||
return BType
|
||||
|
||||
def verify(self, source='', tmpdir=None, **kwargs):
|
||||
"""Verify that the current ffi signatures compile on this
|
||||
machine, and return a dynamic library object. The dynamic
|
||||
library can be used to call functions and access global
|
||||
variables declared in this 'ffi'. The library is compiled
|
||||
by the C compiler: it gives you C-level API compatibility
|
||||
(including calling macros). This is unlike 'ffi.dlopen()',
|
||||
which requires binary compatibility in the signatures.
|
||||
"""
|
||||
from .verifier import Verifier, _caller_dir_pycache
|
||||
#
|
||||
# If set_unicode(True) was called, insert the UNICODE and
|
||||
# _UNICODE macro declarations
|
||||
if self._windows_unicode:
|
||||
self._apply_windows_unicode(kwargs)
|
||||
#
|
||||
# Set the tmpdir here, and not in Verifier.__init__: it picks
|
||||
# up the caller's directory, which we want to be the caller of
|
||||
# ffi.verify(), as opposed to the caller of Veritier().
|
||||
tmpdir = tmpdir or _caller_dir_pycache()
|
||||
#
|
||||
# Make a Verifier() and use it to load the library.
|
||||
self.verifier = Verifier(self, source, tmpdir, **kwargs)
|
||||
lib = self.verifier.load_library()
|
||||
#
|
||||
# Save the loaded library for keep-alive purposes, even
|
||||
# if the caller doesn't keep it alive itself (it should).
|
||||
self._libraries.append(lib)
|
||||
return lib
|
||||
|
||||
def _get_errno(self):
|
||||
return self._backend.get_errno()
|
||||
def _set_errno(self, errno):
|
||||
self._backend.set_errno(errno)
|
||||
errno = property(_get_errno, _set_errno, None,
|
||||
"the value of 'errno' from/to the C calls")
|
||||
|
||||
def getwinerror(self, code=-1):
|
||||
return self._backend.getwinerror(code)
|
||||
|
||||
def _pointer_to(self, ctype):
|
||||
with self._lock:
|
||||
return model.pointer_cache(self, ctype)
|
||||
|
||||
def addressof(self, cdata, *fields_or_indexes):
|
||||
"""Return the address of a <cdata 'struct-or-union'>.
|
||||
If 'fields_or_indexes' are given, returns the address of that
|
||||
field or array item in the structure or array, recursively in
|
||||
case of nested structures.
|
||||
"""
|
||||
try:
|
||||
ctype = self._backend.typeof(cdata)
|
||||
except TypeError:
|
||||
if '__addressof__' in type(cdata).__dict__:
|
||||
return type(cdata).__addressof__(cdata, *fields_or_indexes)
|
||||
raise
|
||||
if fields_or_indexes:
|
||||
ctype, offset = self._typeoffsetof(ctype, *fields_or_indexes)
|
||||
else:
|
||||
if ctype.kind == "pointer":
|
||||
raise TypeError("addressof(pointer)")
|
||||
offset = 0
|
||||
ctypeptr = self._pointer_to(ctype)
|
||||
return self._backend.rawaddressof(ctypeptr, cdata, offset)
|
||||
|
||||
def _typeoffsetof(self, ctype, field_or_index, *fields_or_indexes):
|
||||
ctype, offset = self._backend.typeoffsetof(ctype, field_or_index)
|
||||
for field1 in fields_or_indexes:
|
||||
ctype, offset1 = self._backend.typeoffsetof(ctype, field1, 1)
|
||||
offset += offset1
|
||||
return ctype, offset
|
||||
|
||||
def include(self, ffi_to_include):
|
||||
"""Includes the typedefs, structs, unions and enums defined
|
||||
in another FFI instance. Usage is similar to a #include in C,
|
||||
where a part of the program might include types defined in
|
||||
another part for its own usage. Note that the include()
|
||||
method has no effect on functions, constants and global
|
||||
variables, which must anyway be accessed directly from the
|
||||
lib object returned by the original FFI instance.
|
||||
"""
|
||||
if not isinstance(ffi_to_include, FFI):
|
||||
raise TypeError("ffi.include() expects an argument that is also of"
|
||||
" type cffi.FFI, not %r" % (
|
||||
type(ffi_to_include).__name__,))
|
||||
if ffi_to_include is self:
|
||||
raise ValueError("self.include(self)")
|
||||
with ffi_to_include._lock:
|
||||
with self._lock:
|
||||
self._parser.include(ffi_to_include._parser)
|
||||
self._cdefsources.append('[')
|
||||
self._cdefsources.extend(ffi_to_include._cdefsources)
|
||||
self._cdefsources.append(']')
|
||||
self._included_ffis.append(ffi_to_include)
|
||||
|
||||
def new_handle(self, x):
|
||||
return self._backend.newp_handle(self.BVoidP, x)
|
||||
|
||||
def from_handle(self, x):
|
||||
return self._backend.from_handle(x)
|
||||
|
||||
def set_unicode(self, enabled_flag):
|
||||
"""Windows: if 'enabled_flag' is True, enable the UNICODE and
|
||||
_UNICODE defines in C, and declare the types like TCHAR and LPTCSTR
|
||||
to be (pointers to) wchar_t. If 'enabled_flag' is False,
|
||||
declare these types to be (pointers to) plain 8-bit characters.
|
||||
This is mostly for backward compatibility; you usually want True.
|
||||
"""
|
||||
if self._windows_unicode is not None:
|
||||
raise ValueError("set_unicode() can only be called once")
|
||||
enabled_flag = bool(enabled_flag)
|
||||
if enabled_flag:
|
||||
self.cdef("typedef wchar_t TBYTE;"
|
||||
"typedef wchar_t TCHAR;"
|
||||
"typedef const wchar_t *LPCTSTR;"
|
||||
"typedef const wchar_t *PCTSTR;"
|
||||
"typedef wchar_t *LPTSTR;"
|
||||
"typedef wchar_t *PTSTR;"
|
||||
"typedef TBYTE *PTBYTE;"
|
||||
"typedef TCHAR *PTCHAR;")
|
||||
else:
|
||||
self.cdef("typedef char TBYTE;"
|
||||
"typedef char TCHAR;"
|
||||
"typedef const char *LPCTSTR;"
|
||||
"typedef const char *PCTSTR;"
|
||||
"typedef char *LPTSTR;"
|
||||
"typedef char *PTSTR;"
|
||||
"typedef TBYTE *PTBYTE;"
|
||||
"typedef TCHAR *PTCHAR;")
|
||||
self._windows_unicode = enabled_flag
|
||||
|
||||
def _apply_windows_unicode(self, kwds):
|
||||
defmacros = kwds.get('define_macros', ())
|
||||
if not isinstance(defmacros, (list, tuple)):
|
||||
raise TypeError("'define_macros' must be a list or tuple")
|
||||
defmacros = list(defmacros) + [('UNICODE', '1'),
|
||||
('_UNICODE', '1')]
|
||||
kwds['define_macros'] = defmacros
|
||||
|
||||
def _apply_embedding_fix(self, kwds):
|
||||
# must include an argument like "-lpython2.7" for the compiler
|
||||
def ensure(key, value):
|
||||
lst = kwds.setdefault(key, [])
|
||||
if value not in lst:
|
||||
lst.append(value)
|
||||
#
|
||||
if '__pypy__' in sys.builtin_module_names:
|
||||
import os
|
||||
if sys.platform == "win32":
|
||||
# we need 'libpypy-c.lib'. Current distributions of
|
||||
# pypy (>= 4.1) contain it as 'libs/python27.lib'.
|
||||
pythonlib = "python27"
|
||||
if hasattr(sys, 'prefix'):
|
||||
ensure('library_dirs', os.path.join(sys.prefix, 'libs'))
|
||||
else:
|
||||
# we need 'libpypy-c.{so,dylib}', which should be by
|
||||
# default located in 'sys.prefix/bin' for installed
|
||||
# systems.
|
||||
if sys.version_info < (3,):
|
||||
pythonlib = "pypy-c"
|
||||
else:
|
||||
pythonlib = "pypy3-c"
|
||||
if hasattr(sys, 'prefix'):
|
||||
ensure('library_dirs', os.path.join(sys.prefix, 'bin'))
|
||||
# On uninstalled pypy's, the libpypy-c is typically found in
|
||||
# .../pypy/goal/.
|
||||
if hasattr(sys, 'prefix'):
|
||||
ensure('library_dirs', os.path.join(sys.prefix, 'pypy', 'goal'))
|
||||
else:
|
||||
if sys.platform == "win32":
|
||||
template = "python%d%d"
|
||||
if hasattr(sys, 'gettotalrefcount'):
|
||||
template += '_d'
|
||||
else:
|
||||
try:
|
||||
import sysconfig
|
||||
except ImportError: # 2.6
|
||||
from distutils import sysconfig
|
||||
template = "python%d.%d"
|
||||
if sysconfig.get_config_var('DEBUG_EXT'):
|
||||
template += sysconfig.get_config_var('DEBUG_EXT')
|
||||
pythonlib = (template %
|
||||
(sys.hexversion >> 24, (sys.hexversion >> 16) & 0xff))
|
||||
if hasattr(sys, 'abiflags'):
|
||||
pythonlib += sys.abiflags
|
||||
ensure('libraries', pythonlib)
|
||||
if sys.platform == "win32":
|
||||
ensure('extra_link_args', '/MANIFEST')
|
||||
|
||||
def set_source(self, module_name, source, source_extension='.c', **kwds):
|
||||
import os
|
||||
if hasattr(self, '_assigned_source'):
|
||||
raise ValueError("set_source() cannot be called several times "
|
||||
"per ffi object")
|
||||
if not isinstance(module_name, basestring):
|
||||
raise TypeError("'module_name' must be a string")
|
||||
if os.sep in module_name or (os.altsep and os.altsep in module_name):
|
||||
raise ValueError("'module_name' must not contain '/': use a dotted "
|
||||
"name to make a 'package.module' location")
|
||||
self._assigned_source = (str(module_name), source,
|
||||
source_extension, kwds)
|
||||
|
||||
def distutils_extension(self, tmpdir='build', verbose=True):
|
||||
from distutils.dir_util import mkpath
|
||||
from .recompiler import recompile
|
||||
#
|
||||
if not hasattr(self, '_assigned_source'):
|
||||
if hasattr(self, 'verifier'): # fallback, 'tmpdir' ignored
|
||||
return self.verifier.get_extension()
|
||||
raise ValueError("set_source() must be called before"
|
||||
" distutils_extension()")
|
||||
module_name, source, source_extension, kwds = self._assigned_source
|
||||
if source is None:
|
||||
raise TypeError("distutils_extension() is only for C extension "
|
||||
"modules, not for dlopen()-style pure Python "
|
||||
"modules")
|
||||
mkpath(tmpdir)
|
||||
ext, updated = recompile(self, module_name,
|
||||
source, tmpdir=tmpdir, extradir=tmpdir,
|
||||
source_extension=source_extension,
|
||||
call_c_compiler=False, **kwds)
|
||||
if verbose:
|
||||
if updated:
|
||||
sys.stderr.write("regenerated: %r\n" % (ext.sources[0],))
|
||||
else:
|
||||
sys.stderr.write("not modified: %r\n" % (ext.sources[0],))
|
||||
return ext
|
||||
|
||||
def emit_c_code(self, filename):
|
||||
from .recompiler import recompile
|
||||
#
|
||||
if not hasattr(self, '_assigned_source'):
|
||||
raise ValueError("set_source() must be called before emit_c_code()")
|
||||
module_name, source, source_extension, kwds = self._assigned_source
|
||||
if source is None:
|
||||
raise TypeError("emit_c_code() is only for C extension modules, "
|
||||
"not for dlopen()-style pure Python modules")
|
||||
recompile(self, module_name, source,
|
||||
c_file=filename, call_c_compiler=False, **kwds)
|
||||
|
||||
def emit_python_code(self, filename):
|
||||
from .recompiler import recompile
|
||||
#
|
||||
if not hasattr(self, '_assigned_source'):
|
||||
raise ValueError("set_source() must be called before emit_c_code()")
|
||||
module_name, source, source_extension, kwds = self._assigned_source
|
||||
if source is not None:
|
||||
raise TypeError("emit_python_code() is only for dlopen()-style "
|
||||
"pure Python modules, not for C extension modules")
|
||||
recompile(self, module_name, source,
|
||||
c_file=filename, call_c_compiler=False, **kwds)
|
||||
|
||||
def compile(self, tmpdir='.', verbose=0, target=None, debug=None):
|
||||
"""The 'target' argument gives the final file name of the
|
||||
compiled DLL. Use '*' to force distutils' choice, suitable for
|
||||
regular CPython C API modules. Use a file name ending in '.*'
|
||||
to ask for the system's default extension for dynamic libraries
|
||||
(.so/.dll/.dylib).
|
||||
|
||||
The default is '*' when building a non-embedded C API extension,
|
||||
and (module_name + '.*') when building an embedded library.
|
||||
"""
|
||||
from .recompiler import recompile
|
||||
#
|
||||
if not hasattr(self, '_assigned_source'):
|
||||
raise ValueError("set_source() must be called before compile()")
|
||||
module_name, source, source_extension, kwds = self._assigned_source
|
||||
return recompile(self, module_name, source, tmpdir=tmpdir,
|
||||
target=target, source_extension=source_extension,
|
||||
compiler_verbose=verbose, debug=debug, **kwds)
|
||||
|
||||
def init_once(self, func, tag):
|
||||
# Read _init_once_cache[tag], which is either (False, lock) if
|
||||
# we're calling the function now in some thread, or (True, result).
|
||||
# Don't call setdefault() in most cases, to avoid allocating and
|
||||
# immediately freeing a lock; but still use setdefaut() to avoid
|
||||
# races.
|
||||
try:
|
||||
x = self._init_once_cache[tag]
|
||||
except KeyError:
|
||||
x = self._init_once_cache.setdefault(tag, (False, allocate_lock()))
|
||||
# Common case: we got (True, result), so we return the result.
|
||||
if x[0]:
|
||||
return x[1]
|
||||
# Else, it's a lock. Acquire it to serialize the following tests.
|
||||
with x[1]:
|
||||
# Read again from _init_once_cache the current status.
|
||||
x = self._init_once_cache[tag]
|
||||
if x[0]:
|
||||
return x[1]
|
||||
# Call the function and store the result back.
|
||||
result = func()
|
||||
self._init_once_cache[tag] = (True, result)
|
||||
return result
|
||||
|
||||
def embedding_init_code(self, pysource):
|
||||
if self._embedding:
|
||||
raise ValueError("embedding_init_code() can only be called once")
|
||||
# fix 'pysource' before it gets dumped into the C file:
|
||||
# - remove empty lines at the beginning, so it starts at "line 1"
|
||||
# - dedent, if all non-empty lines are indented
|
||||
# - check for SyntaxErrors
|
||||
import re
|
||||
match = re.match(r'\s*\n', pysource)
|
||||
if match:
|
||||
pysource = pysource[match.end():]
|
||||
lines = pysource.splitlines() or ['']
|
||||
prefix = re.match(r'\s*', lines[0]).group()
|
||||
for i in range(1, len(lines)):
|
||||
line = lines[i]
|
||||
if line.rstrip():
|
||||
while not line.startswith(prefix):
|
||||
prefix = prefix[:-1]
|
||||
i = len(prefix)
|
||||
lines = [line[i:]+'\n' for line in lines]
|
||||
pysource = ''.join(lines)
|
||||
#
|
||||
compile(pysource, "cffi_init", "exec")
|
||||
#
|
||||
self._embedding = pysource
|
||||
|
||||
def def_extern(self, *args, **kwds):
|
||||
raise ValueError("ffi.def_extern() is only available on API-mode FFI "
|
||||
"objects")
|
||||
|
||||
def list_types(self):
|
||||
"""Returns the user type names known to this FFI instance.
|
||||
This returns a tuple containing three lists of names:
|
||||
(typedef_names, names_of_structs, names_of_unions)
|
||||
"""
|
||||
typedefs = []
|
||||
structs = []
|
||||
unions = []
|
||||
for key in self._parser._declarations:
|
||||
if key.startswith('typedef '):
|
||||
typedefs.append(key[8:])
|
||||
elif key.startswith('struct '):
|
||||
structs.append(key[7:])
|
||||
elif key.startswith('union '):
|
||||
unions.append(key[6:])
|
||||
typedefs.sort()
|
||||
structs.sort()
|
||||
unions.sort()
|
||||
return (typedefs, structs, unions)
|
||||
|
||||
|
||||
def _load_backend_lib(backend, name, flags):
|
||||
import os
|
||||
if name is None:
|
||||
if sys.platform != "win32":
|
||||
return backend.load_library(None, flags)
|
||||
name = "c" # Windows: load_library(None) fails, but this works
|
||||
# (backward compatibility hack only)
|
||||
first_error = None
|
||||
if '.' in name or '/' in name or os.sep in name:
|
||||
try:
|
||||
return backend.load_library(name, flags)
|
||||
except OSError as e:
|
||||
first_error = e
|
||||
import ctypes.util
|
||||
path = ctypes.util.find_library(name)
|
||||
if path is None:
|
||||
msg = ("ctypes.util.find_library() did not manage "
|
||||
"to locate a library called %r" % (name,))
|
||||
if first_error is not None:
|
||||
msg = "%s. Additionally, %s" % (first_error, msg)
|
||||
raise OSError(msg)
|
||||
return backend.load_library(path, flags)
|
||||
|
||||
def _make_ffi_library(ffi, libname, flags):
|
||||
backend = ffi._backend
|
||||
backendlib = _load_backend_lib(backend, libname, flags)
|
||||
#
|
||||
def accessor_function(name):
|
||||
key = 'function ' + name
|
||||
tp, _ = ffi._parser._declarations[key]
|
||||
BType = ffi._get_cached_btype(tp)
|
||||
value = backendlib.load_function(BType, name)
|
||||
library.__dict__[name] = value
|
||||
#
|
||||
def accessor_variable(name):
|
||||
key = 'variable ' + name
|
||||
tp, _ = ffi._parser._declarations[key]
|
||||
BType = ffi._get_cached_btype(tp)
|
||||
read_variable = backendlib.read_variable
|
||||
write_variable = backendlib.write_variable
|
||||
setattr(FFILibrary, name, property(
|
||||
lambda self: read_variable(BType, name),
|
||||
lambda self, value: write_variable(BType, name, value)))
|
||||
#
|
||||
def addressof_var(name):
|
||||
try:
|
||||
return addr_variables[name]
|
||||
except KeyError:
|
||||
with ffi._lock:
|
||||
if name not in addr_variables:
|
||||
key = 'variable ' + name
|
||||
tp, _ = ffi._parser._declarations[key]
|
||||
BType = ffi._get_cached_btype(tp)
|
||||
if BType.kind != 'array':
|
||||
BType = model.pointer_cache(ffi, BType)
|
||||
p = backendlib.load_function(BType, name)
|
||||
addr_variables[name] = p
|
||||
return addr_variables[name]
|
||||
#
|
||||
def accessor_constant(name):
|
||||
raise NotImplementedError("non-integer constant '%s' cannot be "
|
||||
"accessed from a dlopen() library" % (name,))
|
||||
#
|
||||
def accessor_int_constant(name):
|
||||
library.__dict__[name] = ffi._parser._int_constants[name]
|
||||
#
|
||||
accessors = {}
|
||||
accessors_version = [False]
|
||||
addr_variables = {}
|
||||
#
|
||||
def update_accessors():
|
||||
if accessors_version[0] is ffi._cdef_version:
|
||||
return
|
||||
#
|
||||
for key, (tp, _) in ffi._parser._declarations.items():
|
||||
if not isinstance(tp, model.EnumType):
|
||||
tag, name = key.split(' ', 1)
|
||||
if tag == 'function':
|
||||
accessors[name] = accessor_function
|
||||
elif tag == 'variable':
|
||||
accessors[name] = accessor_variable
|
||||
elif tag == 'constant':
|
||||
accessors[name] = accessor_constant
|
||||
else:
|
||||
for i, enumname in enumerate(tp.enumerators):
|
||||
def accessor_enum(name, tp=tp, i=i):
|
||||
tp.check_not_partial()
|
||||
library.__dict__[name] = tp.enumvalues[i]
|
||||
accessors[enumname] = accessor_enum
|
||||
for name in ffi._parser._int_constants:
|
||||
accessors.setdefault(name, accessor_int_constant)
|
||||
accessors_version[0] = ffi._cdef_version
|
||||
#
|
||||
def make_accessor(name):
|
||||
with ffi._lock:
|
||||
if name in library.__dict__ or name in FFILibrary.__dict__:
|
||||
return # added by another thread while waiting for the lock
|
||||
if name not in accessors:
|
||||
update_accessors()
|
||||
if name not in accessors:
|
||||
raise AttributeError(name)
|
||||
accessors[name](name)
|
||||
#
|
||||
class FFILibrary(object):
|
||||
def __getattr__(self, name):
|
||||
make_accessor(name)
|
||||
return getattr(self, name)
|
||||
def __setattr__(self, name, value):
|
||||
try:
|
||||
property = getattr(self.__class__, name)
|
||||
except AttributeError:
|
||||
make_accessor(name)
|
||||
setattr(self, name, value)
|
||||
else:
|
||||
property.__set__(self, value)
|
||||
def __dir__(self):
|
||||
with ffi._lock:
|
||||
update_accessors()
|
||||
return accessors.keys()
|
||||
def __addressof__(self, name):
|
||||
if name in library.__dict__:
|
||||
return library.__dict__[name]
|
||||
if name in FFILibrary.__dict__:
|
||||
return addressof_var(name)
|
||||
make_accessor(name)
|
||||
if name in library.__dict__:
|
||||
return library.__dict__[name]
|
||||
if name in FFILibrary.__dict__:
|
||||
return addressof_var(name)
|
||||
raise AttributeError("cffi library has no function or "
|
||||
"global variable named '%s'" % (name,))
|
||||
#
|
||||
if libname is not None:
|
||||
try:
|
||||
if not isinstance(libname, str): # unicode, on Python 2
|
||||
libname = libname.encode('utf-8')
|
||||
FFILibrary.__name__ = 'FFILibrary_%s' % libname
|
||||
except UnicodeError:
|
||||
pass
|
||||
library = FFILibrary()
|
||||
return library, library.__dict__
|
||||
|
||||
def _builtin_function_type(func):
|
||||
# a hack to make at least ffi.typeof(builtin_function) work,
|
||||
# if the builtin function was obtained by 'vengine_cpy'.
|
||||
import sys
|
||||
try:
|
||||
module = sys.modules[func.__module__]
|
||||
ffi = module._cffi_original_ffi
|
||||
types_of_builtin_funcs = module._cffi_types_of_builtin_funcs
|
||||
tp = types_of_builtin_funcs[func]
|
||||
except (KeyError, AttributeError, TypeError):
|
||||
return None
|
||||
else:
|
||||
with ffi._lock:
|
||||
return ffi._get_cached_btype(tp)
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,179 @@
|
|||
from .error import VerificationError
|
||||
|
||||
class CffiOp(object):
|
||||
def __init__(self, op, arg):
|
||||
self.op = op
|
||||
self.arg = arg
|
||||
|
||||
def as_c_expr(self):
|
||||
if self.op is None:
|
||||
assert isinstance(self.arg, str)
|
||||
return '(_cffi_opcode_t)(%s)' % (self.arg,)
|
||||
classname = CLASS_NAME[self.op]
|
||||
return '_CFFI_OP(_CFFI_OP_%s, %s)' % (classname, self.arg)
|
||||
|
||||
def as_python_bytes(self):
|
||||
if self.op is None and self.arg.isdigit():
|
||||
value = int(self.arg) # non-negative: '-' not in self.arg
|
||||
if value >= 2**31:
|
||||
raise OverflowError("cannot emit %r: limited to 2**31-1"
|
||||
% (self.arg,))
|
||||
return format_four_bytes(value)
|
||||
if isinstance(self.arg, str):
|
||||
raise VerificationError("cannot emit to Python: %r" % (self.arg,))
|
||||
return format_four_bytes((self.arg << 8) | self.op)
|
||||
|
||||
def __str__(self):
|
||||
classname = CLASS_NAME.get(self.op, self.op)
|
||||
return '(%s %s)' % (classname, self.arg)
|
||||
|
||||
def format_four_bytes(num):
|
||||
return '\\x%02X\\x%02X\\x%02X\\x%02X' % (
|
||||
(num >> 24) & 0xFF,
|
||||
(num >> 16) & 0xFF,
|
||||
(num >> 8) & 0xFF,
|
||||
(num ) & 0xFF)
|
||||
|
||||
OP_PRIMITIVE = 1
|
||||
OP_POINTER = 3
|
||||
OP_ARRAY = 5
|
||||
OP_OPEN_ARRAY = 7
|
||||
OP_STRUCT_UNION = 9
|
||||
OP_ENUM = 11
|
||||
OP_FUNCTION = 13
|
||||
OP_FUNCTION_END = 15
|
||||
OP_NOOP = 17
|
||||
OP_BITFIELD = 19
|
||||
OP_TYPENAME = 21
|
||||
OP_CPYTHON_BLTN_V = 23 # varargs
|
||||
OP_CPYTHON_BLTN_N = 25 # noargs
|
||||
OP_CPYTHON_BLTN_O = 27 # O (i.e. a single arg)
|
||||
OP_CONSTANT = 29
|
||||
OP_CONSTANT_INT = 31
|
||||
OP_GLOBAL_VAR = 33
|
||||
OP_DLOPEN_FUNC = 35
|
||||
OP_DLOPEN_CONST = 37
|
||||
OP_GLOBAL_VAR_F = 39
|
||||
OP_EXTERN_PYTHON = 41
|
||||
|
||||
PRIM_VOID = 0
|
||||
PRIM_BOOL = 1
|
||||
PRIM_CHAR = 2
|
||||
PRIM_SCHAR = 3
|
||||
PRIM_UCHAR = 4
|
||||
PRIM_SHORT = 5
|
||||
PRIM_USHORT = 6
|
||||
PRIM_INT = 7
|
||||
PRIM_UINT = 8
|
||||
PRIM_LONG = 9
|
||||
PRIM_ULONG = 10
|
||||
PRIM_LONGLONG = 11
|
||||
PRIM_ULONGLONG = 12
|
||||
PRIM_FLOAT = 13
|
||||
PRIM_DOUBLE = 14
|
||||
PRIM_LONGDOUBLE = 15
|
||||
|
||||
PRIM_WCHAR = 16
|
||||
PRIM_INT8 = 17
|
||||
PRIM_UINT8 = 18
|
||||
PRIM_INT16 = 19
|
||||
PRIM_UINT16 = 20
|
||||
PRIM_INT32 = 21
|
||||
PRIM_UINT32 = 22
|
||||
PRIM_INT64 = 23
|
||||
PRIM_UINT64 = 24
|
||||
PRIM_INTPTR = 25
|
||||
PRIM_UINTPTR = 26
|
||||
PRIM_PTRDIFF = 27
|
||||
PRIM_SIZE = 28
|
||||
PRIM_SSIZE = 29
|
||||
PRIM_INT_LEAST8 = 30
|
||||
PRIM_UINT_LEAST8 = 31
|
||||
PRIM_INT_LEAST16 = 32
|
||||
PRIM_UINT_LEAST16 = 33
|
||||
PRIM_INT_LEAST32 = 34
|
||||
PRIM_UINT_LEAST32 = 35
|
||||
PRIM_INT_LEAST64 = 36
|
||||
PRIM_UINT_LEAST64 = 37
|
||||
PRIM_INT_FAST8 = 38
|
||||
PRIM_UINT_FAST8 = 39
|
||||
PRIM_INT_FAST16 = 40
|
||||
PRIM_UINT_FAST16 = 41
|
||||
PRIM_INT_FAST32 = 42
|
||||
PRIM_UINT_FAST32 = 43
|
||||
PRIM_INT_FAST64 = 44
|
||||
PRIM_UINT_FAST64 = 45
|
||||
PRIM_INTMAX = 46
|
||||
PRIM_UINTMAX = 47
|
||||
|
||||
_NUM_PRIM = 48
|
||||
_UNKNOWN_PRIM = -1
|
||||
_UNKNOWN_FLOAT_PRIM = -2
|
||||
_UNKNOWN_LONG_DOUBLE = -3
|
||||
|
||||
_IO_FILE_STRUCT = -1
|
||||
|
||||
PRIMITIVE_TO_INDEX = {
|
||||
'char': PRIM_CHAR,
|
||||
'short': PRIM_SHORT,
|
||||
'int': PRIM_INT,
|
||||
'long': PRIM_LONG,
|
||||
'long long': PRIM_LONGLONG,
|
||||
'signed char': PRIM_SCHAR,
|
||||
'unsigned char': PRIM_UCHAR,
|
||||
'unsigned short': PRIM_USHORT,
|
||||
'unsigned int': PRIM_UINT,
|
||||
'unsigned long': PRIM_ULONG,
|
||||
'unsigned long long': PRIM_ULONGLONG,
|
||||
'float': PRIM_FLOAT,
|
||||
'double': PRIM_DOUBLE,
|
||||
'long double': PRIM_LONGDOUBLE,
|
||||
'_Bool': PRIM_BOOL,
|
||||
'wchar_t': PRIM_WCHAR,
|
||||
'int8_t': PRIM_INT8,
|
||||
'uint8_t': PRIM_UINT8,
|
||||
'int16_t': PRIM_INT16,
|
||||
'uint16_t': PRIM_UINT16,
|
||||
'int32_t': PRIM_INT32,
|
||||
'uint32_t': PRIM_UINT32,
|
||||
'int64_t': PRIM_INT64,
|
||||
'uint64_t': PRIM_UINT64,
|
||||
'intptr_t': PRIM_INTPTR,
|
||||
'uintptr_t': PRIM_UINTPTR,
|
||||
'ptrdiff_t': PRIM_PTRDIFF,
|
||||
'size_t': PRIM_SIZE,
|
||||
'ssize_t': PRIM_SSIZE,
|
||||
'int_least8_t': PRIM_INT_LEAST8,
|
||||
'uint_least8_t': PRIM_UINT_LEAST8,
|
||||
'int_least16_t': PRIM_INT_LEAST16,
|
||||
'uint_least16_t': PRIM_UINT_LEAST16,
|
||||
'int_least32_t': PRIM_INT_LEAST32,
|
||||
'uint_least32_t': PRIM_UINT_LEAST32,
|
||||
'int_least64_t': PRIM_INT_LEAST64,
|
||||
'uint_least64_t': PRIM_UINT_LEAST64,
|
||||
'int_fast8_t': PRIM_INT_FAST8,
|
||||
'uint_fast8_t': PRIM_UINT_FAST8,
|
||||
'int_fast16_t': PRIM_INT_FAST16,
|
||||
'uint_fast16_t': PRIM_UINT_FAST16,
|
||||
'int_fast32_t': PRIM_INT_FAST32,
|
||||
'uint_fast32_t': PRIM_UINT_FAST32,
|
||||
'int_fast64_t': PRIM_INT_FAST64,
|
||||
'uint_fast64_t': PRIM_UINT_FAST64,
|
||||
'intmax_t': PRIM_INTMAX,
|
||||
'uintmax_t': PRIM_UINTMAX,
|
||||
}
|
||||
|
||||
F_UNION = 0x01
|
||||
F_CHECK_FIELDS = 0x02
|
||||
F_PACKED = 0x04
|
||||
F_EXTERNAL = 0x08
|
||||
F_OPAQUE = 0x10
|
||||
|
||||
G_FLAGS = dict([('_CFFI_' + _key, globals()[_key])
|
||||
for _key in ['F_UNION', 'F_CHECK_FIELDS', 'F_PACKED',
|
||||
'F_EXTERNAL', 'F_OPAQUE']])
|
||||
|
||||
CLASS_NAME = {}
|
||||
for _name, _value in list(globals().items()):
|
||||
if _name.startswith('OP_') and isinstance(_value, int):
|
||||
CLASS_NAME[_value] = _name[3:]
|
||||
|
|
@ -0,0 +1,80 @@
|
|||
import sys
|
||||
from . import model
|
||||
from .error import FFIError
|
||||
|
||||
|
||||
COMMON_TYPES = {}
|
||||
|
||||
try:
|
||||
# fetch "bool" and all simple Windows types
|
||||
from _cffi_backend import _get_common_types
|
||||
_get_common_types(COMMON_TYPES)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
COMMON_TYPES['FILE'] = model.unknown_type('FILE', '_IO_FILE')
|
||||
COMMON_TYPES['bool'] = '_Bool' # in case we got ImportError above
|
||||
|
||||
for _type in model.PrimitiveType.ALL_PRIMITIVE_TYPES:
|
||||
if _type.endswith('_t'):
|
||||
COMMON_TYPES[_type] = _type
|
||||
del _type
|
||||
|
||||
_CACHE = {}
|
||||
|
||||
def resolve_common_type(parser, commontype):
|
||||
try:
|
||||
return _CACHE[commontype]
|
||||
except KeyError:
|
||||
cdecl = COMMON_TYPES.get(commontype, commontype)
|
||||
if not isinstance(cdecl, str):
|
||||
result, quals = cdecl, 0 # cdecl is already a BaseType
|
||||
elif cdecl in model.PrimitiveType.ALL_PRIMITIVE_TYPES:
|
||||
result, quals = model.PrimitiveType(cdecl), 0
|
||||
elif cdecl == 'set-unicode-needed':
|
||||
raise FFIError("The Windows type %r is only available after "
|
||||
"you call ffi.set_unicode()" % (commontype,))
|
||||
else:
|
||||
if commontype == cdecl:
|
||||
raise FFIError(
|
||||
"Unsupported type: %r. Please look at "
|
||||
"http://cffi.readthedocs.io/en/latest/cdef.html#ffi-cdef-limitations "
|
||||
"and file an issue if you think this type should really "
|
||||
"be supported." % (commontype,))
|
||||
result, quals = parser.parse_type_and_quals(cdecl) # recursive
|
||||
|
||||
assert isinstance(result, model.BaseTypeByIdentity)
|
||||
_CACHE[commontype] = result, quals
|
||||
return result, quals
|
||||
|
||||
|
||||
# ____________________________________________________________
|
||||
# extra types for Windows (most of them are in commontypes.c)
|
||||
|
||||
|
||||
def win_common_types():
|
||||
return {
|
||||
"UNICODE_STRING": model.StructType(
|
||||
"_UNICODE_STRING",
|
||||
["Length",
|
||||
"MaximumLength",
|
||||
"Buffer"],
|
||||
[model.PrimitiveType("unsigned short"),
|
||||
model.PrimitiveType("unsigned short"),
|
||||
model.PointerType(model.PrimitiveType("wchar_t"))],
|
||||
[-1, -1, -1]),
|
||||
"PUNICODE_STRING": "UNICODE_STRING *",
|
||||
"PCUNICODE_STRING": "const UNICODE_STRING *",
|
||||
|
||||
"TBYTE": "set-unicode-needed",
|
||||
"TCHAR": "set-unicode-needed",
|
||||
"LPCTSTR": "set-unicode-needed",
|
||||
"PCTSTR": "set-unicode-needed",
|
||||
"LPTSTR": "set-unicode-needed",
|
||||
"PTSTR": "set-unicode-needed",
|
||||
"PTBYTE": "set-unicode-needed",
|
||||
"PTCHAR": "set-unicode-needed",
|
||||
}
|
||||
|
||||
if sys.platform == 'win32':
|
||||
COMMON_TYPES.update(win_common_types())
|
||||
|
|
@ -0,0 +1,876 @@
|
|||
from . import model
|
||||
from .commontypes import COMMON_TYPES, resolve_common_type
|
||||
from .error import FFIError, CDefError
|
||||
try:
|
||||
from . import _pycparser as pycparser
|
||||
except ImportError:
|
||||
import pycparser
|
||||
import weakref, re, sys
|
||||
|
||||
try:
|
||||
if sys.version_info < (3,):
|
||||
import thread as _thread
|
||||
else:
|
||||
import _thread
|
||||
lock = _thread.allocate_lock()
|
||||
except ImportError:
|
||||
lock = None
|
||||
|
||||
_r_comment = re.compile(r"/\*.*?\*/|//([^\n\\]|\\.)*?$",
|
||||
re.DOTALL | re.MULTILINE)
|
||||
_r_define = re.compile(r"^\s*#\s*define\s+([A-Za-z_][A-Za-z_0-9]*)"
|
||||
r"\b((?:[^\n\\]|\\.)*?)$",
|
||||
re.DOTALL | re.MULTILINE)
|
||||
_r_partial_enum = re.compile(r"=\s*\.\.\.\s*[,}]|\.\.\.\s*\}")
|
||||
_r_enum_dotdotdot = re.compile(r"__dotdotdot\d+__$")
|
||||
_r_partial_array = re.compile(r"\[\s*\.\.\.\s*\]")
|
||||
_r_words = re.compile(r"\w+|\S")
|
||||
_parser_cache = None
|
||||
_r_int_literal = re.compile(r"-?0?x?[0-9a-f]+[lu]*$", re.IGNORECASE)
|
||||
_r_stdcall1 = re.compile(r"\b(__stdcall|WINAPI)\b")
|
||||
_r_stdcall2 = re.compile(r"[(]\s*(__stdcall|WINAPI)\b")
|
||||
_r_cdecl = re.compile(r"\b__cdecl\b")
|
||||
_r_extern_python = re.compile(r'\bextern\s*"'
|
||||
r'(Python|Python\s*\+\s*C|C\s*\+\s*Python)"\s*.')
|
||||
_r_star_const_space = re.compile( # matches "* const "
|
||||
r"[*]\s*((const|volatile|restrict)\b\s*)+")
|
||||
_r_int_dotdotdot = re.compile(r"(\b(int|long|short|signed|unsigned|char)\s*)+"
|
||||
r"\.\.\.")
|
||||
_r_float_dotdotdot = re.compile(r"\b(double|float)\s*\.\.\.")
|
||||
|
||||
def _get_parser():
|
||||
global _parser_cache
|
||||
if _parser_cache is None:
|
||||
_parser_cache = pycparser.CParser()
|
||||
return _parser_cache
|
||||
|
||||
def _workaround_for_old_pycparser(csource):
|
||||
# Workaround for a pycparser issue (fixed between pycparser 2.10 and
|
||||
# 2.14): "char*const***" gives us a wrong syntax tree, the same as
|
||||
# for "char***(*const)". This means we can't tell the difference
|
||||
# afterwards. But "char(*const(***))" gives us the right syntax
|
||||
# tree. The issue only occurs if there are several stars in
|
||||
# sequence with no parenthesis inbetween, just possibly qualifiers.
|
||||
# Attempt to fix it by adding some parentheses in the source: each
|
||||
# time we see "* const" or "* const *", we add an opening
|
||||
# parenthesis before each star---the hard part is figuring out where
|
||||
# to close them.
|
||||
parts = []
|
||||
while True:
|
||||
match = _r_star_const_space.search(csource)
|
||||
if not match:
|
||||
break
|
||||
#print repr(''.join(parts)+csource), '=>',
|
||||
parts.append(csource[:match.start()])
|
||||
parts.append('('); closing = ')'
|
||||
parts.append(match.group()) # e.g. "* const "
|
||||
endpos = match.end()
|
||||
if csource.startswith('*', endpos):
|
||||
parts.append('('); closing += ')'
|
||||
level = 0
|
||||
i = endpos
|
||||
while i < len(csource):
|
||||
c = csource[i]
|
||||
if c == '(':
|
||||
level += 1
|
||||
elif c == ')':
|
||||
if level == 0:
|
||||
break
|
||||
level -= 1
|
||||
elif c in ',;=':
|
||||
if level == 0:
|
||||
break
|
||||
i += 1
|
||||
csource = csource[endpos:i] + closing + csource[i:]
|
||||
#print repr(''.join(parts)+csource)
|
||||
parts.append(csource)
|
||||
return ''.join(parts)
|
||||
|
||||
def _preprocess_extern_python(csource):
|
||||
# input: `extern "Python" int foo(int);` or
|
||||
# `extern "Python" { int foo(int); }`
|
||||
# output:
|
||||
# void __cffi_extern_python_start;
|
||||
# int foo(int);
|
||||
# void __cffi_extern_python_stop;
|
||||
#
|
||||
# input: `extern "Python+C" int foo(int);`
|
||||
# output:
|
||||
# void __cffi_extern_python_plus_c_start;
|
||||
# int foo(int);
|
||||
# void __cffi_extern_python_stop;
|
||||
parts = []
|
||||
while True:
|
||||
match = _r_extern_python.search(csource)
|
||||
if not match:
|
||||
break
|
||||
endpos = match.end() - 1
|
||||
#print
|
||||
#print ''.join(parts)+csource
|
||||
#print '=>'
|
||||
parts.append(csource[:match.start()])
|
||||
if 'C' in match.group(1):
|
||||
parts.append('void __cffi_extern_python_plus_c_start; ')
|
||||
else:
|
||||
parts.append('void __cffi_extern_python_start; ')
|
||||
if csource[endpos] == '{':
|
||||
# grouping variant
|
||||
closing = csource.find('}', endpos)
|
||||
if closing < 0:
|
||||
raise CDefError("'extern \"Python\" {': no '}' found")
|
||||
if csource.find('{', endpos + 1, closing) >= 0:
|
||||
raise NotImplementedError("cannot use { } inside a block "
|
||||
"'extern \"Python\" { ... }'")
|
||||
parts.append(csource[endpos+1:closing])
|
||||
csource = csource[closing+1:]
|
||||
else:
|
||||
# non-grouping variant
|
||||
semicolon = csource.find(';', endpos)
|
||||
if semicolon < 0:
|
||||
raise CDefError("'extern \"Python\": no ';' found")
|
||||
parts.append(csource[endpos:semicolon+1])
|
||||
csource = csource[semicolon+1:]
|
||||
parts.append(' void __cffi_extern_python_stop;')
|
||||
#print ''.join(parts)+csource
|
||||
#print
|
||||
parts.append(csource)
|
||||
return ''.join(parts)
|
||||
|
||||
def _preprocess(csource):
|
||||
# Remove comments. NOTE: this only work because the cdef() section
|
||||
# should not contain any string literal!
|
||||
csource = _r_comment.sub(' ', csource)
|
||||
# Remove the "#define FOO x" lines
|
||||
macros = {}
|
||||
for match in _r_define.finditer(csource):
|
||||
macroname, macrovalue = match.groups()
|
||||
macrovalue = macrovalue.replace('\\\n', '').strip()
|
||||
macros[macroname] = macrovalue
|
||||
csource = _r_define.sub('', csource)
|
||||
#
|
||||
if pycparser.__version__ < '2.14':
|
||||
csource = _workaround_for_old_pycparser(csource)
|
||||
#
|
||||
# BIG HACK: replace WINAPI or __stdcall with "volatile const".
|
||||
# It doesn't make sense for the return type of a function to be
|
||||
# "volatile volatile const", so we abuse it to detect __stdcall...
|
||||
# Hack number 2 is that "int(volatile *fptr)();" is not valid C
|
||||
# syntax, so we place the "volatile" before the opening parenthesis.
|
||||
csource = _r_stdcall2.sub(' volatile volatile const(', csource)
|
||||
csource = _r_stdcall1.sub(' volatile volatile const ', csource)
|
||||
csource = _r_cdecl.sub(' ', csource)
|
||||
#
|
||||
# Replace `extern "Python"` with start/end markers
|
||||
csource = _preprocess_extern_python(csource)
|
||||
#
|
||||
# Replace "[...]" with "[__dotdotdotarray__]"
|
||||
csource = _r_partial_array.sub('[__dotdotdotarray__]', csource)
|
||||
#
|
||||
# Replace "...}" with "__dotdotdotNUM__}". This construction should
|
||||
# occur only at the end of enums; at the end of structs we have "...;}"
|
||||
# and at the end of vararg functions "...);". Also replace "=...[,}]"
|
||||
# with ",__dotdotdotNUM__[,}]": this occurs in the enums too, when
|
||||
# giving an unknown value.
|
||||
matches = list(_r_partial_enum.finditer(csource))
|
||||
for number, match in enumerate(reversed(matches)):
|
||||
p = match.start()
|
||||
if csource[p] == '=':
|
||||
p2 = csource.find('...', p, match.end())
|
||||
assert p2 > p
|
||||
csource = '%s,__dotdotdot%d__ %s' % (csource[:p], number,
|
||||
csource[p2+3:])
|
||||
else:
|
||||
assert csource[p:p+3] == '...'
|
||||
csource = '%s __dotdotdot%d__ %s' % (csource[:p], number,
|
||||
csource[p+3:])
|
||||
# Replace "int ..." or "unsigned long int..." with "__dotdotdotint__"
|
||||
csource = _r_int_dotdotdot.sub(' __dotdotdotint__ ', csource)
|
||||
# Replace "float ..." or "double..." with "__dotdotdotfloat__"
|
||||
csource = _r_float_dotdotdot.sub(' __dotdotdotfloat__ ', csource)
|
||||
# Replace all remaining "..." with the same name, "__dotdotdot__",
|
||||
# which is declared with a typedef for the purpose of C parsing.
|
||||
return csource.replace('...', ' __dotdotdot__ '), macros
|
||||
|
||||
def _common_type_names(csource):
|
||||
# Look in the source for what looks like usages of types from the
|
||||
# list of common types. A "usage" is approximated here as the
|
||||
# appearance of the word, minus a "definition" of the type, which
|
||||
# is the last word in a "typedef" statement. Approximative only
|
||||
# but should be fine for all the common types.
|
||||
look_for_words = set(COMMON_TYPES)
|
||||
look_for_words.add(';')
|
||||
look_for_words.add(',')
|
||||
look_for_words.add('(')
|
||||
look_for_words.add(')')
|
||||
look_for_words.add('typedef')
|
||||
words_used = set()
|
||||
is_typedef = False
|
||||
paren = 0
|
||||
previous_word = ''
|
||||
for word in _r_words.findall(csource):
|
||||
if word in look_for_words:
|
||||
if word == ';':
|
||||
if is_typedef:
|
||||
words_used.discard(previous_word)
|
||||
look_for_words.discard(previous_word)
|
||||
is_typedef = False
|
||||
elif word == 'typedef':
|
||||
is_typedef = True
|
||||
paren = 0
|
||||
elif word == '(':
|
||||
paren += 1
|
||||
elif word == ')':
|
||||
paren -= 1
|
||||
elif word == ',':
|
||||
if is_typedef and paren == 0:
|
||||
words_used.discard(previous_word)
|
||||
look_for_words.discard(previous_word)
|
||||
else: # word in COMMON_TYPES
|
||||
words_used.add(word)
|
||||
previous_word = word
|
||||
return words_used
|
||||
|
||||
|
||||
class Parser(object):
|
||||
|
||||
def __init__(self):
|
||||
self._declarations = {}
|
||||
self._included_declarations = set()
|
||||
self._anonymous_counter = 0
|
||||
self._structnode2type = weakref.WeakKeyDictionary()
|
||||
self._options = {}
|
||||
self._int_constants = {}
|
||||
self._recomplete = []
|
||||
self._uses_new_feature = None
|
||||
|
||||
def _parse(self, csource):
|
||||
csource, macros = _preprocess(csource)
|
||||
# XXX: for more efficiency we would need to poke into the
|
||||
# internals of CParser... the following registers the
|
||||
# typedefs, because their presence or absence influences the
|
||||
# parsing itself (but what they are typedef'ed to plays no role)
|
||||
ctn = _common_type_names(csource)
|
||||
typenames = []
|
||||
for name in sorted(self._declarations):
|
||||
if name.startswith('typedef '):
|
||||
name = name[8:]
|
||||
typenames.append(name)
|
||||
ctn.discard(name)
|
||||
typenames += sorted(ctn)
|
||||
#
|
||||
csourcelines = ['typedef int %s;' % typename for typename in typenames]
|
||||
csourcelines.append('typedef int __dotdotdotint__, __dotdotdotfloat__,'
|
||||
' __dotdotdot__;')
|
||||
csourcelines.append(csource)
|
||||
csource = '\n'.join(csourcelines)
|
||||
if lock is not None:
|
||||
lock.acquire() # pycparser is not thread-safe...
|
||||
try:
|
||||
ast = _get_parser().parse(csource)
|
||||
except pycparser.c_parser.ParseError as e:
|
||||
self.convert_pycparser_error(e, csource)
|
||||
finally:
|
||||
if lock is not None:
|
||||
lock.release()
|
||||
# csource will be used to find buggy source text
|
||||
return ast, macros, csource
|
||||
|
||||
def _convert_pycparser_error(self, e, csource):
|
||||
# xxx look for ":NUM:" at the start of str(e) and try to interpret
|
||||
# it as a line number
|
||||
line = None
|
||||
msg = str(e)
|
||||
if msg.startswith(':') and ':' in msg[1:]:
|
||||
linenum = msg[1:msg.find(':',1)]
|
||||
if linenum.isdigit():
|
||||
linenum = int(linenum, 10)
|
||||
csourcelines = csource.splitlines()
|
||||
if 1 <= linenum <= len(csourcelines):
|
||||
line = csourcelines[linenum-1]
|
||||
return line
|
||||
|
||||
def convert_pycparser_error(self, e, csource):
|
||||
line = self._convert_pycparser_error(e, csource)
|
||||
|
||||
msg = str(e)
|
||||
if line:
|
||||
msg = 'cannot parse "%s"\n%s' % (line.strip(), msg)
|
||||
else:
|
||||
msg = 'parse error\n%s' % (msg,)
|
||||
raise CDefError(msg)
|
||||
|
||||
def parse(self, csource, override=False, packed=False, dllexport=False):
|
||||
prev_options = self._options
|
||||
try:
|
||||
self._options = {'override': override,
|
||||
'packed': packed,
|
||||
'dllexport': dllexport}
|
||||
self._internal_parse(csource)
|
||||
finally:
|
||||
self._options = prev_options
|
||||
|
||||
def _internal_parse(self, csource):
|
||||
ast, macros, csource = self._parse(csource)
|
||||
# add the macros
|
||||
self._process_macros(macros)
|
||||
# find the first "__dotdotdot__" and use that as a separator
|
||||
# between the repeated typedefs and the real csource
|
||||
iterator = iter(ast.ext)
|
||||
for decl in iterator:
|
||||
if decl.name == '__dotdotdot__':
|
||||
break
|
||||
else:
|
||||
assert 0
|
||||
#
|
||||
try:
|
||||
self._inside_extern_python = '__cffi_extern_python_stop'
|
||||
for decl in iterator:
|
||||
if isinstance(decl, pycparser.c_ast.Decl):
|
||||
self._parse_decl(decl)
|
||||
elif isinstance(decl, pycparser.c_ast.Typedef):
|
||||
if not decl.name:
|
||||
raise CDefError("typedef does not declare any name",
|
||||
decl)
|
||||
quals = 0
|
||||
if (isinstance(decl.type.type, pycparser.c_ast.IdentifierType) and
|
||||
decl.type.type.names[-1].startswith('__dotdotdot')):
|
||||
realtype = self._get_unknown_type(decl)
|
||||
elif (isinstance(decl.type, pycparser.c_ast.PtrDecl) and
|
||||
isinstance(decl.type.type, pycparser.c_ast.TypeDecl) and
|
||||
isinstance(decl.type.type.type,
|
||||
pycparser.c_ast.IdentifierType) and
|
||||
decl.type.type.type.names[-1].startswith('__dotdotdot')):
|
||||
realtype = self._get_unknown_ptr_type(decl)
|
||||
else:
|
||||
realtype, quals = self._get_type_and_quals(
|
||||
decl.type, name=decl.name, partial_length_ok=True)
|
||||
self._declare('typedef ' + decl.name, realtype, quals=quals)
|
||||
elif decl.__class__.__name__ == 'Pragma':
|
||||
pass # skip pragma, only in pycparser 2.15
|
||||
else:
|
||||
raise CDefError("unrecognized construct", decl)
|
||||
except FFIError as e:
|
||||
msg = self._convert_pycparser_error(e, csource)
|
||||
if msg:
|
||||
e.args = (e.args[0] + "\n *** Err: %s" % msg,)
|
||||
raise
|
||||
|
||||
def _add_constants(self, key, val):
|
||||
if key in self._int_constants:
|
||||
if self._int_constants[key] == val:
|
||||
return # ignore identical double declarations
|
||||
raise FFIError(
|
||||
"multiple declarations of constant: %s" % (key,))
|
||||
self._int_constants[key] = val
|
||||
|
||||
def _add_integer_constant(self, name, int_str):
|
||||
int_str = int_str.lower().rstrip("ul")
|
||||
neg = int_str.startswith('-')
|
||||
if neg:
|
||||
int_str = int_str[1:]
|
||||
# "010" is not valid oct in py3
|
||||
if (int_str.startswith("0") and int_str != '0'
|
||||
and not int_str.startswith("0x")):
|
||||
int_str = "0o" + int_str[1:]
|
||||
pyvalue = int(int_str, 0)
|
||||
if neg:
|
||||
pyvalue = -pyvalue
|
||||
self._add_constants(name, pyvalue)
|
||||
self._declare('macro ' + name, pyvalue)
|
||||
|
||||
def _process_macros(self, macros):
|
||||
for key, value in macros.items():
|
||||
value = value.strip()
|
||||
if _r_int_literal.match(value):
|
||||
self._add_integer_constant(key, value)
|
||||
elif value == '...':
|
||||
self._declare('macro ' + key, value)
|
||||
else:
|
||||
raise CDefError(
|
||||
'only supports one of the following syntax:\n'
|
||||
' #define %s ... (literally dot-dot-dot)\n'
|
||||
' #define %s NUMBER (with NUMBER an integer'
|
||||
' constant, decimal/hex/octal)\n'
|
||||
'got:\n'
|
||||
' #define %s %s'
|
||||
% (key, key, key, value))
|
||||
|
||||
def _declare_function(self, tp, quals, decl):
|
||||
tp = self._get_type_pointer(tp, quals)
|
||||
if self._options.get('dllexport'):
|
||||
tag = 'dllexport_python '
|
||||
elif self._inside_extern_python == '__cffi_extern_python_start':
|
||||
tag = 'extern_python '
|
||||
elif self._inside_extern_python == '__cffi_extern_python_plus_c_start':
|
||||
tag = 'extern_python_plus_c '
|
||||
else:
|
||||
tag = 'function '
|
||||
self._declare(tag + decl.name, tp)
|
||||
|
||||
def _parse_decl(self, decl):
|
||||
node = decl.type
|
||||
if isinstance(node, pycparser.c_ast.FuncDecl):
|
||||
tp, quals = self._get_type_and_quals(node, name=decl.name)
|
||||
assert isinstance(tp, model.RawFunctionType)
|
||||
self._declare_function(tp, quals, decl)
|
||||
else:
|
||||
if isinstance(node, pycparser.c_ast.Struct):
|
||||
self._get_struct_union_enum_type('struct', node)
|
||||
elif isinstance(node, pycparser.c_ast.Union):
|
||||
self._get_struct_union_enum_type('union', node)
|
||||
elif isinstance(node, pycparser.c_ast.Enum):
|
||||
self._get_struct_union_enum_type('enum', node)
|
||||
elif not decl.name:
|
||||
raise CDefError("construct does not declare any variable",
|
||||
decl)
|
||||
#
|
||||
if decl.name:
|
||||
tp, quals = self._get_type_and_quals(node,
|
||||
partial_length_ok=True)
|
||||
if tp.is_raw_function:
|
||||
self._declare_function(tp, quals, decl)
|
||||
elif (tp.is_integer_type() and
|
||||
hasattr(decl, 'init') and
|
||||
hasattr(decl.init, 'value') and
|
||||
_r_int_literal.match(decl.init.value)):
|
||||
self._add_integer_constant(decl.name, decl.init.value)
|
||||
elif (tp.is_integer_type() and
|
||||
isinstance(decl.init, pycparser.c_ast.UnaryOp) and
|
||||
decl.init.op == '-' and
|
||||
hasattr(decl.init.expr, 'value') and
|
||||
_r_int_literal.match(decl.init.expr.value)):
|
||||
self._add_integer_constant(decl.name,
|
||||
'-' + decl.init.expr.value)
|
||||
elif (tp is model.void_type and
|
||||
decl.name.startswith('__cffi_extern_python_')):
|
||||
# hack: `extern "Python"` in the C source is replaced
|
||||
# with "void __cffi_extern_python_start;" and
|
||||
# "void __cffi_extern_python_stop;"
|
||||
self._inside_extern_python = decl.name
|
||||
else:
|
||||
if self._inside_extern_python !='__cffi_extern_python_stop':
|
||||
raise CDefError(
|
||||
"cannot declare constants or "
|
||||
"variables with 'extern \"Python\"'")
|
||||
if (quals & model.Q_CONST) and not tp.is_array_type:
|
||||
self._declare('constant ' + decl.name, tp, quals=quals)
|
||||
else:
|
||||
self._declare('variable ' + decl.name, tp, quals=quals)
|
||||
|
||||
def parse_type(self, cdecl):
|
||||
return self.parse_type_and_quals(cdecl)[0]
|
||||
|
||||
def parse_type_and_quals(self, cdecl):
|
||||
ast, macros = self._parse('void __dummy(\n%s\n);' % cdecl)[:2]
|
||||
assert not macros
|
||||
exprnode = ast.ext[-1].type.args.params[0]
|
||||
if isinstance(exprnode, pycparser.c_ast.ID):
|
||||
raise CDefError("unknown identifier '%s'" % (exprnode.name,))
|
||||
return self._get_type_and_quals(exprnode.type)
|
||||
|
||||
def _declare(self, name, obj, included=False, quals=0):
|
||||
if name in self._declarations:
|
||||
prevobj, prevquals = self._declarations[name]
|
||||
if prevobj is obj and prevquals == quals:
|
||||
return
|
||||
if not self._options.get('override'):
|
||||
raise FFIError(
|
||||
"multiple declarations of %s (for interactive usage, "
|
||||
"try cdef(xx, override=True))" % (name,))
|
||||
assert '__dotdotdot__' not in name.split()
|
||||
self._declarations[name] = (obj, quals)
|
||||
if included:
|
||||
self._included_declarations.add(obj)
|
||||
|
||||
def _extract_quals(self, type):
|
||||
quals = 0
|
||||
if isinstance(type, (pycparser.c_ast.TypeDecl,
|
||||
pycparser.c_ast.PtrDecl)):
|
||||
if 'const' in type.quals:
|
||||
quals |= model.Q_CONST
|
||||
if 'volatile' in type.quals:
|
||||
quals |= model.Q_VOLATILE
|
||||
if 'restrict' in type.quals:
|
||||
quals |= model.Q_RESTRICT
|
||||
return quals
|
||||
|
||||
def _get_type_pointer(self, type, quals, declname=None):
|
||||
if isinstance(type, model.RawFunctionType):
|
||||
return type.as_function_pointer()
|
||||
if (isinstance(type, model.StructOrUnionOrEnum) and
|
||||
type.name.startswith('$') and type.name[1:].isdigit() and
|
||||
type.forcename is None and declname is not None):
|
||||
return model.NamedPointerType(type, declname, quals)
|
||||
return model.PointerType(type, quals)
|
||||
|
||||
def _get_type_and_quals(self, typenode, name=None, partial_length_ok=False):
|
||||
# first, dereference typedefs, if we have it already parsed, we're good
|
||||
if (isinstance(typenode, pycparser.c_ast.TypeDecl) and
|
||||
isinstance(typenode.type, pycparser.c_ast.IdentifierType) and
|
||||
len(typenode.type.names) == 1 and
|
||||
('typedef ' + typenode.type.names[0]) in self._declarations):
|
||||
tp, quals = self._declarations['typedef ' + typenode.type.names[0]]
|
||||
quals |= self._extract_quals(typenode)
|
||||
return tp, quals
|
||||
#
|
||||
if isinstance(typenode, pycparser.c_ast.ArrayDecl):
|
||||
# array type
|
||||
if typenode.dim is None:
|
||||
length = None
|
||||
else:
|
||||
length = self._parse_constant(
|
||||
typenode.dim, partial_length_ok=partial_length_ok)
|
||||
tp, quals = self._get_type_and_quals(typenode.type,
|
||||
partial_length_ok=partial_length_ok)
|
||||
return model.ArrayType(tp, length), quals
|
||||
#
|
||||
if isinstance(typenode, pycparser.c_ast.PtrDecl):
|
||||
# pointer type
|
||||
itemtype, itemquals = self._get_type_and_quals(typenode.type)
|
||||
tp = self._get_type_pointer(itemtype, itemquals, declname=name)
|
||||
quals = self._extract_quals(typenode)
|
||||
return tp, quals
|
||||
#
|
||||
if isinstance(typenode, pycparser.c_ast.TypeDecl):
|
||||
quals = self._extract_quals(typenode)
|
||||
type = typenode.type
|
||||
if isinstance(type, pycparser.c_ast.IdentifierType):
|
||||
# assume a primitive type. get it from .names, but reduce
|
||||
# synonyms to a single chosen combination
|
||||
names = list(type.names)
|
||||
if names != ['signed', 'char']: # keep this unmodified
|
||||
prefixes = {}
|
||||
while names:
|
||||
name = names[0]
|
||||
if name in ('short', 'long', 'signed', 'unsigned'):
|
||||
prefixes[name] = prefixes.get(name, 0) + 1
|
||||
del names[0]
|
||||
else:
|
||||
break
|
||||
# ignore the 'signed' prefix below, and reorder the others
|
||||
newnames = []
|
||||
for prefix in ('unsigned', 'short', 'long'):
|
||||
for i in range(prefixes.get(prefix, 0)):
|
||||
newnames.append(prefix)
|
||||
if not names:
|
||||
names = ['int'] # implicitly
|
||||
if names == ['int']: # but kill it if 'short' or 'long'
|
||||
if 'short' in prefixes or 'long' in prefixes:
|
||||
names = []
|
||||
names = newnames + names
|
||||
ident = ' '.join(names)
|
||||
if ident == 'void':
|
||||
return model.void_type, quals
|
||||
if ident == '__dotdotdot__':
|
||||
raise FFIError(':%d: bad usage of "..."' %
|
||||
typenode.coord.line)
|
||||
tp0, quals0 = resolve_common_type(self, ident)
|
||||
return tp0, (quals | quals0)
|
||||
#
|
||||
if isinstance(type, pycparser.c_ast.Struct):
|
||||
# 'struct foobar'
|
||||
tp = self._get_struct_union_enum_type('struct', type, name)
|
||||
return tp, quals
|
||||
#
|
||||
if isinstance(type, pycparser.c_ast.Union):
|
||||
# 'union foobar'
|
||||
tp = self._get_struct_union_enum_type('union', type, name)
|
||||
return tp, quals
|
||||
#
|
||||
if isinstance(type, pycparser.c_ast.Enum):
|
||||
# 'enum foobar'
|
||||
tp = self._get_struct_union_enum_type('enum', type, name)
|
||||
return tp, quals
|
||||
#
|
||||
if isinstance(typenode, pycparser.c_ast.FuncDecl):
|
||||
# a function type
|
||||
return self._parse_function_type(typenode, name), 0
|
||||
#
|
||||
# nested anonymous structs or unions end up here
|
||||
if isinstance(typenode, pycparser.c_ast.Struct):
|
||||
return self._get_struct_union_enum_type('struct', typenode, name,
|
||||
nested=True), 0
|
||||
if isinstance(typenode, pycparser.c_ast.Union):
|
||||
return self._get_struct_union_enum_type('union', typenode, name,
|
||||
nested=True), 0
|
||||
#
|
||||
raise FFIError(":%d: bad or unsupported type declaration" %
|
||||
typenode.coord.line)
|
||||
|
||||
def _parse_function_type(self, typenode, funcname=None):
|
||||
params = list(getattr(typenode.args, 'params', []))
|
||||
for i, arg in enumerate(params):
|
||||
if not hasattr(arg, 'type'):
|
||||
raise CDefError("%s arg %d: unknown type '%s'"
|
||||
" (if you meant to use the old C syntax of giving"
|
||||
" untyped arguments, it is not supported)"
|
||||
% (funcname or 'in expression', i + 1,
|
||||
getattr(arg, 'name', '?')))
|
||||
ellipsis = (
|
||||
len(params) > 0 and
|
||||
isinstance(params[-1].type, pycparser.c_ast.TypeDecl) and
|
||||
isinstance(params[-1].type.type,
|
||||
pycparser.c_ast.IdentifierType) and
|
||||
params[-1].type.type.names == ['__dotdotdot__'])
|
||||
if ellipsis:
|
||||
params.pop()
|
||||
if not params:
|
||||
raise CDefError(
|
||||
"%s: a function with only '(...)' as argument"
|
||||
" is not correct C" % (funcname or 'in expression'))
|
||||
args = [self._as_func_arg(*self._get_type_and_quals(argdeclnode.type))
|
||||
for argdeclnode in params]
|
||||
if not ellipsis and args == [model.void_type]:
|
||||
args = []
|
||||
result, quals = self._get_type_and_quals(typenode.type)
|
||||
# the 'quals' on the result type are ignored. HACK: we absure them
|
||||
# to detect __stdcall functions: we textually replace "__stdcall"
|
||||
# with "volatile volatile const" above.
|
||||
abi = None
|
||||
if hasattr(typenode.type, 'quals'): # else, probable syntax error anyway
|
||||
if typenode.type.quals[-3:] == ['volatile', 'volatile', 'const']:
|
||||
abi = '__stdcall'
|
||||
return model.RawFunctionType(tuple(args), result, ellipsis, abi)
|
||||
|
||||
def _as_func_arg(self, type, quals):
|
||||
if isinstance(type, model.ArrayType):
|
||||
return model.PointerType(type.item, quals)
|
||||
elif isinstance(type, model.RawFunctionType):
|
||||
return type.as_function_pointer()
|
||||
else:
|
||||
return type
|
||||
|
||||
def _get_struct_union_enum_type(self, kind, type, name=None, nested=False):
|
||||
# First, a level of caching on the exact 'type' node of the AST.
|
||||
# This is obscure, but needed because pycparser "unrolls" declarations
|
||||
# such as "typedef struct { } foo_t, *foo_p" and we end up with
|
||||
# an AST that is not a tree, but a DAG, with the "type" node of the
|
||||
# two branches foo_t and foo_p of the trees being the same node.
|
||||
# It's a bit silly but detecting "DAG-ness" in the AST tree seems
|
||||
# to be the only way to distinguish this case from two independent
|
||||
# structs. See test_struct_with_two_usages.
|
||||
try:
|
||||
return self._structnode2type[type]
|
||||
except KeyError:
|
||||
pass
|
||||
#
|
||||
# Note that this must handle parsing "struct foo" any number of
|
||||
# times and always return the same StructType object. Additionally,
|
||||
# one of these times (not necessarily the first), the fields of
|
||||
# the struct can be specified with "struct foo { ...fields... }".
|
||||
# If no name is given, then we have to create a new anonymous struct
|
||||
# with no caching; in this case, the fields are either specified
|
||||
# right now or never.
|
||||
#
|
||||
force_name = name
|
||||
name = type.name
|
||||
#
|
||||
# get the type or create it if needed
|
||||
if name is None:
|
||||
# 'force_name' is used to guess a more readable name for
|
||||
# anonymous structs, for the common case "typedef struct { } foo".
|
||||
if force_name is not None:
|
||||
explicit_name = '$%s' % force_name
|
||||
else:
|
||||
self._anonymous_counter += 1
|
||||
explicit_name = '$%d' % self._anonymous_counter
|
||||
tp = None
|
||||
else:
|
||||
explicit_name = name
|
||||
key = '%s %s' % (kind, name)
|
||||
tp, _ = self._declarations.get(key, (None, None))
|
||||
#
|
||||
if tp is None:
|
||||
if kind == 'struct':
|
||||
tp = model.StructType(explicit_name, None, None, None)
|
||||
elif kind == 'union':
|
||||
tp = model.UnionType(explicit_name, None, None, None)
|
||||
elif kind == 'enum':
|
||||
if explicit_name == '__dotdotdot__':
|
||||
raise CDefError("Enums cannot be declared with ...")
|
||||
tp = self._build_enum_type(explicit_name, type.values)
|
||||
else:
|
||||
raise AssertionError("kind = %r" % (kind,))
|
||||
if name is not None:
|
||||
self._declare(key, tp)
|
||||
else:
|
||||
if kind == 'enum' and type.values is not None:
|
||||
raise NotImplementedError(
|
||||
"enum %s: the '{}' declaration should appear on the first "
|
||||
"time the enum is mentioned, not later" % explicit_name)
|
||||
if not tp.forcename:
|
||||
tp.force_the_name(force_name)
|
||||
if tp.forcename and '$' in tp.name:
|
||||
self._declare('anonymous %s' % tp.forcename, tp)
|
||||
#
|
||||
self._structnode2type[type] = tp
|
||||
#
|
||||
# enums: done here
|
||||
if kind == 'enum':
|
||||
return tp
|
||||
#
|
||||
# is there a 'type.decls'? If yes, then this is the place in the
|
||||
# C sources that declare the fields. If no, then just return the
|
||||
# existing type, possibly still incomplete.
|
||||
if type.decls is None:
|
||||
return tp
|
||||
#
|
||||
if tp.fldnames is not None:
|
||||
raise CDefError("duplicate declaration of struct %s" % name)
|
||||
fldnames = []
|
||||
fldtypes = []
|
||||
fldbitsize = []
|
||||
fldquals = []
|
||||
for decl in type.decls:
|
||||
if (isinstance(decl.type, pycparser.c_ast.IdentifierType) and
|
||||
''.join(decl.type.names) == '__dotdotdot__'):
|
||||
# XXX pycparser is inconsistent: 'names' should be a list
|
||||
# of strings, but is sometimes just one string. Use
|
||||
# str.join() as a way to cope with both.
|
||||
self._make_partial(tp, nested)
|
||||
continue
|
||||
if decl.bitsize is None:
|
||||
bitsize = -1
|
||||
else:
|
||||
bitsize = self._parse_constant(decl.bitsize)
|
||||
self._partial_length = False
|
||||
type, fqual = self._get_type_and_quals(decl.type,
|
||||
partial_length_ok=True)
|
||||
if self._partial_length:
|
||||
self._make_partial(tp, nested)
|
||||
if isinstance(type, model.StructType) and type.partial:
|
||||
self._make_partial(tp, nested)
|
||||
fldnames.append(decl.name or '')
|
||||
fldtypes.append(type)
|
||||
fldbitsize.append(bitsize)
|
||||
fldquals.append(fqual)
|
||||
tp.fldnames = tuple(fldnames)
|
||||
tp.fldtypes = tuple(fldtypes)
|
||||
tp.fldbitsize = tuple(fldbitsize)
|
||||
tp.fldquals = tuple(fldquals)
|
||||
if fldbitsize != [-1] * len(fldbitsize):
|
||||
if isinstance(tp, model.StructType) and tp.partial:
|
||||
raise NotImplementedError("%s: using both bitfields and '...;'"
|
||||
% (tp,))
|
||||
tp.packed = self._options.get('packed')
|
||||
if tp.completed: # must be re-completed: it is not opaque any more
|
||||
tp.completed = 0
|
||||
self._recomplete.append(tp)
|
||||
return tp
|
||||
|
||||
def _make_partial(self, tp, nested):
|
||||
if not isinstance(tp, model.StructOrUnion):
|
||||
raise CDefError("%s cannot be partial" % (tp,))
|
||||
if not tp.has_c_name() and not nested:
|
||||
raise NotImplementedError("%s is partial but has no C name" %(tp,))
|
||||
tp.partial = True
|
||||
|
||||
def _parse_constant(self, exprnode, partial_length_ok=False):
|
||||
# for now, limited to expressions that are an immediate number
|
||||
# or positive/negative number
|
||||
if isinstance(exprnode, pycparser.c_ast.Constant):
|
||||
s = exprnode.value
|
||||
if s.startswith('0'):
|
||||
if s.startswith('0x') or s.startswith('0X'):
|
||||
return int(s, 16)
|
||||
return int(s, 8)
|
||||
elif '1' <= s[0] <= '9':
|
||||
return int(s, 10)
|
||||
elif s[0] == "'" and s[-1] == "'" and (
|
||||
len(s) == 3 or (len(s) == 4 and s[1] == "\\")):
|
||||
return ord(s[-2])
|
||||
else:
|
||||
raise CDefError("invalid constant %r" % (s,))
|
||||
#
|
||||
if (isinstance(exprnode, pycparser.c_ast.UnaryOp) and
|
||||
exprnode.op == '+'):
|
||||
return self._parse_constant(exprnode.expr)
|
||||
#
|
||||
if (isinstance(exprnode, pycparser.c_ast.UnaryOp) and
|
||||
exprnode.op == '-'):
|
||||
return -self._parse_constant(exprnode.expr)
|
||||
# load previously defined int constant
|
||||
if (isinstance(exprnode, pycparser.c_ast.ID) and
|
||||
exprnode.name in self._int_constants):
|
||||
return self._int_constants[exprnode.name]
|
||||
#
|
||||
if (isinstance(exprnode, pycparser.c_ast.ID) and
|
||||
exprnode.name == '__dotdotdotarray__'):
|
||||
if partial_length_ok:
|
||||
self._partial_length = True
|
||||
return '...'
|
||||
raise FFIError(":%d: unsupported '[...]' here, cannot derive "
|
||||
"the actual array length in this context"
|
||||
% exprnode.coord.line)
|
||||
#
|
||||
if (isinstance(exprnode, pycparser.c_ast.BinaryOp) and
|
||||
exprnode.op == '+'):
|
||||
return (self._parse_constant(exprnode.left) +
|
||||
self._parse_constant(exprnode.right))
|
||||
#
|
||||
if (isinstance(exprnode, pycparser.c_ast.BinaryOp) and
|
||||
exprnode.op == '-'):
|
||||
return (self._parse_constant(exprnode.left) -
|
||||
self._parse_constant(exprnode.right))
|
||||
#
|
||||
raise FFIError(":%d: unsupported expression: expected a "
|
||||
"simple numeric constant" % exprnode.coord.line)
|
||||
|
||||
def _build_enum_type(self, explicit_name, decls):
|
||||
if decls is not None:
|
||||
partial = False
|
||||
enumerators = []
|
||||
enumvalues = []
|
||||
nextenumvalue = 0
|
||||
for enum in decls.enumerators:
|
||||
if _r_enum_dotdotdot.match(enum.name):
|
||||
partial = True
|
||||
continue
|
||||
if enum.value is not None:
|
||||
nextenumvalue = self._parse_constant(enum.value)
|
||||
enumerators.append(enum.name)
|
||||
enumvalues.append(nextenumvalue)
|
||||
self._add_constants(enum.name, nextenumvalue)
|
||||
nextenumvalue += 1
|
||||
enumerators = tuple(enumerators)
|
||||
enumvalues = tuple(enumvalues)
|
||||
tp = model.EnumType(explicit_name, enumerators, enumvalues)
|
||||
tp.partial = partial
|
||||
else: # opaque enum
|
||||
tp = model.EnumType(explicit_name, (), ())
|
||||
return tp
|
||||
|
||||
def include(self, other):
|
||||
for name, (tp, quals) in other._declarations.items():
|
||||
if name.startswith('anonymous $enum_$'):
|
||||
continue # fix for test_anonymous_enum_include
|
||||
kind = name.split(' ', 1)[0]
|
||||
if kind in ('struct', 'union', 'enum', 'anonymous', 'typedef'):
|
||||
self._declare(name, tp, included=True, quals=quals)
|
||||
for k, v in other._int_constants.items():
|
||||
self._add_constants(k, v)
|
||||
|
||||
def _get_unknown_type(self, decl):
|
||||
typenames = decl.type.type.names
|
||||
if typenames == ['__dotdotdot__']:
|
||||
return model.unknown_type(decl.name)
|
||||
|
||||
if typenames == ['__dotdotdotint__']:
|
||||
if self._uses_new_feature is None:
|
||||
self._uses_new_feature = "'typedef int... %s'" % decl.name
|
||||
return model.UnknownIntegerType(decl.name)
|
||||
|
||||
if typenames == ['__dotdotdotfloat__']:
|
||||
# note: not for 'long double' so far
|
||||
if self._uses_new_feature is None:
|
||||
self._uses_new_feature = "'typedef float... %s'" % decl.name
|
||||
return model.UnknownFloatType(decl.name)
|
||||
|
||||
raise FFIError(':%d: unsupported usage of "..." in typedef'
|
||||
% decl.coord.line)
|
||||
|
||||
def _get_unknown_ptr_type(self, decl):
|
||||
if decl.type.type.type.names == ['__dotdotdot__']:
|
||||
return model.unknown_ptr_type(decl.name)
|
||||
raise FFIError(':%d: unsupported usage of "..." in typedef'
|
||||
% decl.coord.line)
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
|
||||
class FFIError(Exception):
|
||||
pass
|
||||
|
||||
class CDefError(Exception):
|
||||
def __str__(self):
|
||||
try:
|
||||
line = 'line %d: ' % (self.args[1].coord.line,)
|
||||
except (AttributeError, TypeError, IndexError):
|
||||
line = ''
|
||||
return '%s%s' % (line, self.args[0])
|
||||
|
||||
class VerificationError(Exception):
|
||||
""" An error raised when verification fails
|
||||
"""
|
||||
|
||||
class VerificationMissing(Exception):
|
||||
""" An error raised when incomplete structures are passed into
|
||||
cdef, but no verification has been done
|
||||
"""
|
||||
|
|
@ -0,0 +1,115 @@
|
|||
import sys, os
|
||||
from .error import VerificationError
|
||||
|
||||
|
||||
LIST_OF_FILE_NAMES = ['sources', 'include_dirs', 'library_dirs',
|
||||
'extra_objects', 'depends']
|
||||
|
||||
def get_extension(srcfilename, modname, sources=(), **kwds):
|
||||
from distutils.core import Extension
|
||||
allsources = [srcfilename]
|
||||
for src in sources:
|
||||
allsources.append(os.path.normpath(src))
|
||||
return Extension(name=modname, sources=allsources, **kwds)
|
||||
|
||||
def compile(tmpdir, ext, compiler_verbose=0, debug=None):
|
||||
"""Compile a C extension module using distutils."""
|
||||
|
||||
saved_environ = os.environ.copy()
|
||||
try:
|
||||
outputfilename = _build(tmpdir, ext, compiler_verbose, debug)
|
||||
outputfilename = os.path.abspath(outputfilename)
|
||||
finally:
|
||||
# workaround for a distutils bugs where some env vars can
|
||||
# become longer and longer every time it is used
|
||||
for key, value in saved_environ.items():
|
||||
if os.environ.get(key) != value:
|
||||
os.environ[key] = value
|
||||
return outputfilename
|
||||
|
||||
def _build(tmpdir, ext, compiler_verbose=0, debug=None):
|
||||
# XXX compact but horrible :-(
|
||||
from distutils.core import Distribution
|
||||
import distutils.errors, distutils.log
|
||||
#
|
||||
dist = Distribution({'ext_modules': [ext]})
|
||||
dist.parse_config_files()
|
||||
options = dist.get_option_dict('build_ext')
|
||||
if debug is None:
|
||||
debug = sys.flags.debug
|
||||
options['debug'] = ('ffiplatform', debug)
|
||||
options['force'] = ('ffiplatform', True)
|
||||
options['build_lib'] = ('ffiplatform', tmpdir)
|
||||
options['build_temp'] = ('ffiplatform', tmpdir)
|
||||
#
|
||||
try:
|
||||
old_level = distutils.log.set_threshold(0) or 0
|
||||
try:
|
||||
distutils.log.set_verbosity(compiler_verbose)
|
||||
dist.run_command('build_ext')
|
||||
cmd_obj = dist.get_command_obj('build_ext')
|
||||
[soname] = cmd_obj.get_outputs()
|
||||
finally:
|
||||
distutils.log.set_threshold(old_level)
|
||||
except (distutils.errors.CompileError,
|
||||
distutils.errors.LinkError) as e:
|
||||
raise VerificationError('%s: %s' % (e.__class__.__name__, e))
|
||||
#
|
||||
return soname
|
||||
|
||||
try:
|
||||
from os.path import samefile
|
||||
except ImportError:
|
||||
def samefile(f1, f2):
|
||||
return os.path.abspath(f1) == os.path.abspath(f2)
|
||||
|
||||
def maybe_relative_path(path):
|
||||
if not os.path.isabs(path):
|
||||
return path # already relative
|
||||
dir = path
|
||||
names = []
|
||||
while True:
|
||||
prevdir = dir
|
||||
dir, name = os.path.split(prevdir)
|
||||
if dir == prevdir or not dir:
|
||||
return path # failed to make it relative
|
||||
names.append(name)
|
||||
try:
|
||||
if samefile(dir, os.curdir):
|
||||
names.reverse()
|
||||
return os.path.join(*names)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
# ____________________________________________________________
|
||||
|
||||
try:
|
||||
int_or_long = (int, long)
|
||||
import cStringIO
|
||||
except NameError:
|
||||
int_or_long = int # Python 3
|
||||
import io as cStringIO
|
||||
|
||||
def _flatten(x, f):
|
||||
if isinstance(x, str):
|
||||
f.write('%ds%s' % (len(x), x))
|
||||
elif isinstance(x, dict):
|
||||
keys = sorted(x.keys())
|
||||
f.write('%dd' % len(keys))
|
||||
for key in keys:
|
||||
_flatten(key, f)
|
||||
_flatten(x[key], f)
|
||||
elif isinstance(x, (list, tuple)):
|
||||
f.write('%dl' % len(x))
|
||||
for value in x:
|
||||
_flatten(value, f)
|
||||
elif isinstance(x, int_or_long):
|
||||
f.write('%di' % (x,))
|
||||
else:
|
||||
raise TypeError(
|
||||
"the keywords to verify() contains unsupported object %r" % (x,))
|
||||
|
||||
def flatten(x):
|
||||
f = cStringIO.StringIO()
|
||||
_flatten(x, f)
|
||||
return f.getvalue()
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue