# -*- coding: utf-8 -
#
# This file is part of restkit released under the MIT license.
# See the NOTICE for more information.
from webob import exc
from webob.compat import PY3, url_quote
import logging
import socket
import six
import re
try:
import urlparse
except ImportError: # pragma: nocover
import urllib.parse as urlparse # NOQA
try:
import httplib
except ImportError: # pragma: nocover
import http.client as httplib # NOQA
# Every 7-bit ASCII character. Used as the "safe" set for url_quote so that
# only bytes outside the ASCII range get percent-encoded.
LOW_CHAR_SAFE = ''.join(map(chr, range(128)))

# Matches a location that starts with an explicit http:// or https:// scheme
# (case-insensitive).
ABSOLUTE_URL_RE = re.compile(r"^https?://", re.IGNORECASE)

# Request methods forwarded by default; any other method is answered with 405.
ALLOWED_METHODS = ['GET', 'HEAD', 'POST', 'PUT', 'DELETE']

# Error detail returned when the request to forward carries
# "Content-Length: -1".
WEBOB_ERROR = (
    "Content-Length is set to -1. This usually mean that WebOb has "
    "already parsed the content body. You should set the Content-Length "
    "header to the correct value before forwarding your request to the "
    "proxy: ``req.content_length = str(len(req.body));`` "
    "req.get_response(proxy)")
def rewrite_location(host_uri, location, prefix_path=None):
    """Rewrite a ``Location`` value returned by the upstream server.

    :param host_uri: upstream base URI the request was sent to.
    :param location: ``Location`` header value from the upstream response.
    :param prefix_path: path inserted in front of locations that point at
        *host_uri* (``None`` or empty for no prefix).
    :returns: relative locations resolved against *host_uri* with
        *prefix_path* inserted right after it; absolute (or ``//host``)
        locations on *host_uri*'s scheme and netloc re-built with
        *prefix_path* prepended to their path; any other location unchanged.
    """
    prefix_path = prefix_path or ''
    host_url = urlparse.urlparse(host_uri)

    # "//host/path" is a network-path reference (RFC 3986, section 4.2): it
    # names a host, so it must not be resolved as a path on host_uri.
    if location.startswith('//') or ABSOLUTE_URL_RE.match(location):
        url = urlparse.urlparse(location)
        # An empty scheme (network-path reference) inherits host_uri's.
        if (url.scheme in ('', host_url.scheme) and
                url.netloc == host_url.netloc):
            return urlparse.urlunparse((host_url.scheme, host_url.netloc,
                                        prefix_path + url.path, url.params,
                                        url.query, url.fragment))
        return location

    # Relative location: the remote server doesn't follow rfc2616.
    location = urlparse.urljoin(host_uri, location.lstrip('/'))
    prefix_path = prefix_path.strip('/')
    if prefix_path:
        # Only the leading occurrence: host_uri may also appear inside the
        # query string, which must not be rewritten.
        location = location.replace(host_uri,
                                    host_uri + '/' + prefix_path, 1)
    return location
class HttpClient(object):
    """An HTTP client using stdlib's httplib (Default client).

    Instances are called as ``client(uri, method, body, headers)`` and return
    ``(status, location, headers, app_iter)``:

    - ``status`` is a ``"<code> <reason>"`` string.
    - ``location`` is the upstream ``Location`` header, or ``None``.
    - ``headers`` is the list of upstream ``(name, value)`` pairs.
    - ``app_iter`` is a one-item list holding the whole response body.
    """

    # Connection classes; class attributes so subclasses can substitute them.
    HTTPConnection = httplib.HTTPConnection
    HTTPSConnection = httplib.HTTPSConnection

    def __init__(self, **connection_options):
        """Store keyword arguments passed to every connection constructor
        (for example ``timeout``)."""
        self.options = connection_options

    def __call__(self, uri, method, body, headers):
        """Send one request to *uri* and return its fully buffered response.

        :param uri: absolute ``http://`` or ``https://`` URI including the
            authority; a missing path is sent as ``/``.
        :param method: request method.
        :param body: file-like object. ``Content-Length`` bytes are read from
            it when that header is set and non-empty; otherwise no body is
            sent.
        :param headers: dict of request headers. Any ``Transfer-Encoding``
            entry is deleted from it (the dict is modified in place).
        """
        ssl = uri.startswith('https://')
        conn_class = self.HTTPSConnection if ssl else self.HTTPConnection
        rest = uri[len('https://'):] if ssl else uri[len('http://'):]

        # The authority ends at the first '/' or '?'; everything from there on
        # is the request target (so "host:80?q=1" and "host:80" both work).
        cut = len(rest)
        for sep in '/?':
            pos = rest.find(sep)
            if pos != -1 and pos < cut:
                cut = pos
        netloc, path = rest[:cut], rest[cut:]
        if not path.startswith('/'):
            path = '/' + path

        # httplib parses "host[:port]" itself (bracketed IPv6 included) and
        # falls back to the class default port (80 for HTTP, 443 for HTTPS).
        # Connection options belong to the constructor, not to request().
        conn = conn_class(netloc, **self.options)
        try:
            if 'Transfer-Encoding' in headers:
                del headers['Transfer-Encoding']
            if headers.get('Content-Length'):
                body = body.read(int(headers['Content-Length']))
            else:
                body = None
            conn.request(method, path, body, headers)
            response = conn.getresponse()
            status = '%s %s' % (response.status, response.reason)
            length = response.getheader('content-length')
            data = response.read(int(length)) if length else response.read()
            return (status, response.getheader('location', None),
                    response.getheaders(), [data])
        finally:
            # The body has been read completely, so the socket can be
            # released now instead of waiting for garbage collection.
            conn.close()
class Proxy(object):
    """WSGI application that forwards each request to an upstream server.

    By default the upstream is SERVER_NAME:SERVER_PORT (see
    :meth:`extract_uri`). The HTTP call is made by ``self.http`` through
    :meth:`process_request`.

    The client's ``HTTP_HOST`` is forwarded as the ``Host`` header. The
    environ values listed in :attr:`header_map` are forwarded as
    ``X-Forwarded-*`` headers.
    """

    # Source environ key -> HTTP_* suffix it is copied to before the upstream
    # headers are built (e.g. REMOTE_ADDR ends up as "X-Forwarded-For").
    header_map = {
        'HTTP_HOST': 'X_FORWARDED_SERVER',
        'SCRIPT_NAME': 'X_FORWARDED_SCRIPT_NAME',
        'wsgi.url_scheme': 'X_FORWARDED_SCHEME',
        'REMOTE_ADDR': 'X_FORWARDED_FOR',
    }

    def __init__(self, client=None, allowed_methods=ALLOWED_METHODS,
                 strip_script_name=True, **client_options):
        """Configure the proxy.

        :param client: one of:
            - ``None`` or ``'httplib'``: use :class:`HttpClient`.
            - A callable taking ``(uri, method, body, headers)`` and returning
              ``(status, location, headers, app_iter)``.
            - A name ``X``: the ``HttpClient`` class of the
              ``wsgiproxy.X_client`` module is instantiated.
        :param allowed_methods: methods that are forwarded; any other method
            gets a 405 response. ``None`` allows every method.
        :param strip_script_name: when true, ``SCRIPT_NAME`` is left out of
            the reconstructed upstream path and is instead prefixed to
            rewritten ``Location`` headers.
        :param client_options: keyword arguments for the client class;
            ignored when *client* is a callable.
        """
        self.allowed_methods = allowed_methods
        self.strip_script_name = strip_script_name
        if client is None or client == 'httplib':
            self.http = HttpClient(**client_options)
        elif callable(client):
            self.http = client
        else:
            mod = __import__('wsgiproxy.%s_client' % client,
                             globals(), locals(), [''])
            self.http = mod.HttpClient(**client_options)
        self.logger = logging.getLogger(__name__)

    def extract_uri(self, environ):
        """Return the upstream base URI as ``scheme://host:port``.

        The host comes from ``SERVER_NAME``, falling back to ``HTTP_HOST``.
        The port is taken, in order of preference, from:

        1. the host value itself;
        2. ``SERVER_PORT``;
        3. the default for ``wsgi.url_scheme`` (443 for https, 80 otherwise).
        """
        scheme = environ['wsgi.url_scheme']
        if 'SERVER_NAME' in environ:
            host = environ['SERVER_NAME']
        else:
            host = environ['HTTP_HOST']
        port = None
        if host.startswith('['):
            # Bracketed IPv6 literal ("[::1]:8080"): its colons are not a port
            # separator; a port can only follow the closing bracket.
            address, bracket, rest = host.partition(']')
            host = address + bracket
            if rest.startswith(':'):
                port = rest[1:]
        elif host.count(':') == 1:
            host, port = host.split(':')
        if not port:
            if 'SERVER_PORT' in environ:
                port = environ['SERVER_PORT']
            else:
                port = '443' if scheme == 'https' else '80'
        return '%s://%s:%s' % (scheme, host, port)

    def process_request(self, uri, method, headers, environ):
        """Send the request through ``self.http``.

        The request body is read from ``environ['wsgi.input']``. Returns the
        client's ``(status, location, headers, app_iter)`` tuple.
        """
        return self.http(uri, method, environ['wsgi.input'], headers)

    def __call__(self, environ, start_response):
        """WSGI entry point: forward the request and relay the response."""
        method = environ['REQUEST_METHOD']
        if (self.allowed_methods is not None and
                method not in self.allowed_methods):
            return exc.HTTPMethodNotAllowed()(environ, start_response)

        path_info = self._request_path(environ)

        for key, dest in self.header_map.items():
            value = environ.get(key)
            if value:
                environ['HTTP_%s' % dest] = value

        # extract_uri() may rewrite environ['HTTP_HOST'] (HostProxy does).
        # It therefore runs after the X-Forwarded-* copies above and before
        # the upstream headers are built from the environ.
        host_uri = self.extract_uri(environ)
        uri = host_uri + path_info
        new_headers = self._request_headers(environ)
        if new_headers.get('Content-Length', '0') == '-1':
            resp = exc.HTTPInternalServerError(detail=WEBOB_ERROR)
            return resp(environ, start_response)

        try:
            response = self.process_request(uri, method, new_headers, environ)
        # socket.timeout is a subclass of socket.error, so it must come first.
        except socket.timeout:
            return exc.HTTPGatewayTimeout()(environ, start_response)
        except (socket.error, socket.gaierror):
            return exc.HTTPBadGateway()(environ, start_response)
        except Exception as e:
            self.logger.exception(e)
            return exc.HTTPInternalServerError()(environ, start_response)

        status, location, headerslist, app_iter = response
        if location:
            headerslist = self._rewrite_location_headers(
                environ, host_uri, location, headerslist)
        start_response(status, headerslist)
        if method == "HEAD":
            return [six.b('')]
        return app_iter

    def _request_path(self, environ):
        """Return the request target (path plus query string) to send upstream."""
        # RAW_URI / REQUEST_URI hold the request target as the client sent it,
        # query string included, so QUERY_STRING must not be appended again.
        # NOTE(review): SCRIPT_NAME is not stripped from these even when
        # strip_script_name is true -- confirm that is intended.
        for key in ('RAW_URI', 'REQUEST_URI'):
            if key in environ:
                path, query_string = environ[key], ''
                break
        else:
            path = '' if self.strip_script_name else environ['SCRIPT_NAME']
            path += environ['PATH_INFO']
            # QUERY_STRING may be absent from a WSGI environ.
            query_string = environ.get('QUERY_STRING', '')
        if PY3:
            # On Python 3, WSGI strings carry latin-1 decoded bytes. Re-encode
            # them and percent-encode only the non-ASCII bytes.
            path = url_quote(path.encode('latin-1'), LOW_CHAR_SAFE)
        if query_string:
            path += '?' + query_string
        return path

    def _request_headers(self, environ):
        """Build the upstream request header dict from the WSGI environ."""
        new_headers = {}
        for key, value in environ.items():
            if key.startswith('HTTP_'):
                new_headers[key[5:].replace('_', '-').title()] = value
        content_type = environ.get("CONTENT_TYPE")
        if content_type:
            new_headers['Content-Type'] = content_type
        content_length = environ.get('CONTENT_LENGTH')
        # Request headers live under HTTP_* keys in the WSGI environ.
        transfer_encoding = environ.get('HTTP_TRANSFER_ENCODING', '').lower()
        if not content_length and transfer_encoding != 'chunked':
            new_headers['Transfer-Encoding'] = 'chunked'
        elif content_length:
            new_headers['Content-Length'] = content_length
        return new_headers

    def _rewrite_location_headers(self, environ, host_uri, location,
                                  headerslist):
        """Return *headerslist* with every Location header rewritten."""
        if self.strip_script_name:
            prefix_path = environ['SCRIPT_NAME']
        else:
            prefix_path = None
        new_location = rewrite_location(host_uri, location,
                                        prefix_path=prefix_path)
        return [(k, new_location if k.lower() == 'location' else v)
                for k, v in headerslist]
class TransparentProxy(Proxy):
    """A proxy based on the HTTP_HOST environ variable.

    Requests are forwarded to whatever host the client addressed (its
    ``Host`` header) instead of ``SERVER_NAME``.
    """

    def extract_uri(self, environ):
        """Return ``scheme://host:port`` built from ``HTTP_HOST``.

        The port is taken from ``HTTP_HOST`` when present. Otherwise the
        default for ``wsgi.url_scheme`` is used (443 for https, 80 otherwise).
        ``SERVER_PORT`` is not consulted.
        """
        scheme = environ['wsgi.url_scheme']
        host = environ['HTTP_HOST']
        port = None
        if host.startswith('['):
            # Bracketed IPv6 literal ("[::1]:8080"): its colons are not a port
            # separator; a port can only follow the closing bracket.
            address, bracket, rest = host.partition(']')
            host = address + bracket
            if rest.startswith(':'):
                port = rest[1:]
        elif host.count(':') == 1:
            host, port = host.split(':')
        if not port:
            port = '443' if scheme == 'https' else '80'
        return '%s://%s:%s' % (scheme, host, port)
class HostProxy(Proxy):
    """A proxy that sends every request to one fixed upstream URI."""

    def __init__(self, uri, client=None, allowed_methods=ALLOWED_METHODS,
                 strip_script_name=True, **client_options):
        """Create a proxy bound to *uri*.

        :param uri: upstream base URI; trailing ``/`` characters are removed.

        Every other argument is handed unchanged to :class:`Proxy`.
        """
        proxy_kwargs = dict(client_options,
                            client=client,
                            allowed_methods=allowed_methods,
                            strip_script_name=strip_script_name)
        super(HostProxy, self).__init__(**proxy_kwargs)
        self.uri = str(uri.rstrip('/'))
        parts = urlparse.urlparse(self.uri)
        self.scheme = parts.scheme
        self.net_loc = parts.netloc

    def extract_uri(self, environ):
        """Return the configured upstream URI.

        Also overwrites ``environ['HTTP_HOST']`` with the upstream netloc.
        """
        environ['HTTP_HOST'] = self.net_loc
        return self.uri