#if DEBUG
#define LONGTIMEOUT
#endif
namespace DotNetOpenId {
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Globalization;
using System.IO;
using System.Net;
using System.Text.RegularExpressions;
using System.Configuration;
using DotNetOpenId.Configuration;
using System.Reflection;
/// <summary>
/// A paranoid HTTP get/post request engine.  It helps to protect against attacks from remote
/// servers leaving dangling connections, sending too much data, causing requests against
/// internal servers, etc.
/// </summary>
/// <remarks>
/// Protections include:
/// * Conservative maximum time to receive the complete response.
/// * Only HTTP and HTTPS schemes are permitted.
/// * Internal IP address ranges are not permitted, including 127.*.*.*, 10.*.*.* and ::1.
/// * Internal host names are not permitted (periods must be found in the host name).
/// If a particular host would be permitted but is in the blacklist, it is not allowed.
/// If a particular host would not be permitted but is in the whitelist, it is allowed.
/// </remarks>
public static class UntrustedWebRequest {
// Value sent as the User-Agent header on every outgoing request, built as
// "<executing assembly name>/<executing assembly version>".
private static string UserAgentValue = Assembly.GetExecutingAssembly().GetName().Name + "/" + Assembly.GetExecutingAssembly().GetName().Version;
// Shortcut to the configuration section that supplies the initial values for
// this class's limits, timeouts and host lists.
static Configuration.UntrustedWebRequestSection Configuration {
get { return UntrustedWebRequestSection.Configuration; }
}
[DebuggerBrowsable(DebuggerBrowsableState.Never)]
static int maximumBytesToRead = Configuration.MaximumBytesToRead;
/// <summary>
/// Gets or sets the maximum number of bytes that will be read from any single HTTP response.
/// </summary>
/// <remarks>
/// The initial value is taken from the configuration section (documented default: 1MB).
/// Values below 2KB are rejected.
/// </remarks>
/// <exception cref="ArgumentOutOfRangeException">Thrown when set to a value smaller than 2048.</exception>
public static int MaximumBytesToRead {
    get {
        return maximumBytesToRead;
    }
    set {
        const int smallestAllowed = 2048;
        if (value < smallestAllowed) {
            throw new ArgumentOutOfRangeException("value");
        }
        maximumBytesToRead = value;
    }
}
[DebuggerBrowsable(DebuggerBrowsableState.Never)]
static int maximumRedirections = Configuration.MaximumRedirections;
/// <summary>
/// Gets or sets the total number of redirections to allow on any one request.
/// </summary>
/// <remarks>
/// The initial value is taken from the configuration section (documented default: 10).
/// </remarks>
/// <exception cref="ArgumentOutOfRangeException">Thrown when set to a negative value.</exception>
public static int MaximumRedirections {
    get {
        return maximumRedirections;
    }
    set {
        bool isNegative = value < 0;
        if (isNegative) {
            throw new ArgumentOutOfRangeException("value");
        }
        maximumRedirections = value;
    }
}
/// <summary>
/// Gets or sets the time allowed to wait for a single read or write operation to complete.
/// </summary>
/// <remarks>
/// Assigned from the configuration section by the static constructor
/// (documented default: 500 milliseconds).  Applied to each request as
/// its ReadWriteTimeout, converted to whole milliseconds.
/// </remarks>
public static TimeSpan ReadWriteTimeout { get; set; }
/// <summary>
/// Gets or sets the time allowed for an entire HTTP request.
/// </summary>
/// <remarks>
/// Assigned from the configuration section by the static constructor
/// (documented default: 5 seconds).  Applied to each request as its
/// Timeout, converted to whole milliseconds.
/// </remarks>
public static TimeSpan Timeout { get; set; }
/// <summary>
/// Signature of a hook that produces the response for a request in place of a real HTTP call.
/// </summary>
/// <param name="uri">The URI being requested.</param>
/// <param name="body">The request body, or null when none is being sent.</param>
/// <param name="acceptTypes">The acceptable content types, or null when unspecified.</param>
/// <returns>The response to hand back to the caller.</returns>
internal delegate UntrustedWebResponse MockRequestResponse(Uri uri, byte[] body, string[] acceptTypes);
/// <summary>
/// Used in unit testing to mock HTTP responses to expected requests.
/// </summary>
/// <remarks>
/// If null, no mocking will take place.  But if non-null, all requests
/// will be channeled through this mock method for processing.
/// The URI safety and SSL checks still run before the mock is invoked.
/// </remarks>
internal static MockRequestResponse MockRequests;
/// <summary>
/// Initializes the timeout properties from the configuration section.
/// </summary>
[System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Performance", "CA1810:InitializeReferenceTypeStaticFieldsInline")]
static UntrustedWebRequest() {
ReadWriteTimeout = Configuration.ReadWriteTimeout;
Timeout = Configuration.Timeout;
#if LONGTIMEOUT
// LONGTIMEOUT is defined at the top of this file for DEBUG builds; it replaces the
// configured timeouts with one hour each (presumably so that requests do not time
// out while stepping through a debugger).
ReadWriteTimeout = TimeSpan.FromHours(1);
Timeout = TimeSpan.FromHours(1);
#endif
}
/// <summary>
/// Tests whether an address consists of all-zero bytes followed by a final byte of 1,
/// which for an IPv6 address is the loopback address ::1.
/// </summary>
/// <param name="ip">The address to test.  Must not be null.</param>
/// <returns>True when the address bytes match the loopback pattern; otherwise false.</returns>
static bool isIPv6Loopback(IPAddress ip) {
    Debug.Assert(ip != null);
    byte[] octets = ip.GetAddressBytes();
    int lastIndex = octets.Length - 1;

    // Every byte ahead of the last one must be zero.
    int index = 0;
    while (index < lastIndex) {
        if (octets[index] != 0) {
            return false;
        }
        index++;
    }

    // The final byte must be exactly 1.
    return octets[lastIndex] == 1;
}
// The only URI schemes that requests may be sent to.  The generic type arguments
// are required: a bare ICollection/List does not compile with this file's usings.
static ICollection<string> allowableSchemes = new List<string> { "http", "https" };
// Backing store for WhitelistHosts, seeded from the configuration section.
static ICollection<string> whitelistHosts = new List<string>(Configuration.WhitelistHosts.KeysAsStrings);
/// <summary>
/// A collection of host name literals that should be allowed even if they don't
/// pass standard security checks.
/// </summary>
/// <remarks>
/// Entries are compared to the requested host name case-insensitively.
/// </remarks>
public static ICollection<string> WhitelistHosts { get { return whitelistHosts; } }
// Backing store for WhitelistHostsRegex, seeded from the configuration section.
static ICollection<Regex> whitelistHostsRegex = new List<Regex>(Configuration.WhitelistHostsRegex.KeysAsRegexs);
/// <summary>
/// A collection of host name regular expressions that indicate hosts that should
/// be allowed even though they don't pass standard security checks.
/// </summary>
public static ICollection<Regex> WhitelistHostsRegex { get { return whitelistHostsRegex; } }
// Backing store for BlacklistHosts, seeded from the configuration section.
static ICollection<string> blacklistHosts = new List<string>(Configuration.BlacklistHosts.KeysAsStrings);
/// <summary>
/// A collection of host name literals that should be rejected even if they
/// pass standard security checks.
/// </summary>
/// <remarks>
/// Entries are compared to the requested host name case-insensitively.
/// </remarks>
public static ICollection<string> BlacklistHosts { get { return blacklistHosts; } }
// Backing store for BlacklistHostsRegex, seeded from the configuration section.
static ICollection<Regex> blacklistHostsRegex = new List<Regex>(Configuration.BlacklistHostsRegex.KeysAsRegexs);
/// <summary>
/// A collection of host name regular expressions that indicate hosts that should
/// be rejected even if they pass standard security checks.
/// </summary>
public static ICollection<Regex> BlacklistHostsRegex { get { return blacklistHostsRegex; } }
/// <summary>
/// Checks whether a host matches the whitelist, either as a literal entry in
/// WhitelistHosts or through one of the expressions in WhitelistHostsRegex.
/// </summary>
/// <param name="host">The host name to look up.</param>
/// <returns>True if the host is whitelisted; otherwise false.</returns>
static bool isHostWhitelisted(string host) {
return isHostInList(host, WhitelistHosts, WhitelistHostsRegex);
}
/// <summary>
/// Checks whether a host matches the blacklist, either as a literal entry in
/// BlacklistHosts or through one of the expressions in BlacklistHostsRegex.
/// </summary>
/// <param name="host">The host name to look up.</param>
/// <returns>True if the host is blacklisted; otherwise false.</returns>
static bool isHostBlacklisted(string host) {
return isHostInList(host, BlacklistHosts, BlacklistHostsRegex);
}
/// <summary>
/// Determines whether a host name matches any entry of a literal list or any
/// expression of a regular expression list.
/// </summary>
/// <param name="host">The host name to look up.  Must not be null or empty.</param>
/// <param name="stringList">Host name literals, compared case-insensitively (ordinal).</param>
/// <param name="regexList">Expressions tested against the host with <c>IsMatch</c>.</param>
/// <returns>True on the first literal or expression that matches; otherwise false.</returns>
static bool isHostInList(string host, ICollection<string> stringList, ICollection<Regex> regexList) {
    Debug.Assert(!string.IsNullOrEmpty(host));
    Debug.Assert(stringList != null);
    Debug.Assert(regexList != null);

    // Literal entries first: these are cheap, exact comparisons.
    foreach (string testHost in stringList) {
        if (string.Equals(host, testHost, StringComparison.OrdinalIgnoreCase)) {
            return true;
        }
    }

    // Then the patterns.  Case sensitivity here is whatever each Regex was built with.
    foreach (Regex regex in regexList) {
        if (regex.IsMatch(host)) {
            return true;
        }
    }

    return false;
}
/// <summary>
/// Decides whether a request may be sent to the given URI.
/// </summary>
/// <remarks>
/// A URI is rejected when it uses a scheme other than those in allowableSchemes,
/// when its host is an internal IP address, when its host is a name without a
/// period, or when its host is blacklisted.  A whitelisted host overrides the
/// IP address and host name checks (but not the scheme check).
/// </remarks>
/// <param name="uri">The URI to evaluate.  Must not be null.</param>
/// <returns>True if the request may proceed; otherwise false (the reason is logged).</returns>
static bool isUriAllowable(Uri uri) {
    Debug.Assert(uri != null);
    if (!allowableSchemes.Contains(uri.Scheme)) {
        Logger.WarnFormat("Rejecting URL {0} because it uses a disallowed scheme.", uri);
        return false;
    }

    // Try to interpret the hostname as an IP address so we can test for internal
    // IP address ranges.  Note that IP addresses can appear in many forms
    // (e.g. http://127.0.0.1, http://2130706433, http://0x0100007f, http://::1),
    // so we convert them to a canonical IPAddress instance and test its bytes.
    // Note that Uri.IsLoopback is very unreliable, not catching many of these variants.
    IPAddress hostIPAddress;
    if (IPAddress.TryParse(uri.DnsSafeHost, out hostIPAddress)) {
        // The host is actually an IP address.
        switch (hostIPAddress.AddressFamily) {
            case System.Net.Sockets.AddressFamily.InterNetwork:
                if (isInternalIPv4Address(hostIPAddress.GetAddressBytes())) {
                    return isWhitelistedOrLogRejection(uri, "it is a loopback or private network address");
                }
                break;
            case System.Net.Sockets.AddressFamily.InterNetworkV6:
                if (isIPv6Loopback(hostIPAddress)) {
                    return isWhitelistedOrLogRejection(uri, "it is a loopback address");
                }
                break;
            default:
                return isWhitelistedOrLogRejection(uri, "it does not use an IPv4 or IPv6 address");
        }
    } else {
        // The host is given by name.  We require names to contain periods to
        // help make sure it's not an internal address.
        if (!uri.Host.Contains(".")) {
            return isWhitelistedOrLogRejection(uri, "it does not contain a period in the host name");
        }
    }

    // The blacklist only applies to hosts that passed the checks above.
    if (isHostBlacklisted(uri.DnsSafeHost)) {
        Logger.WarnFormat("Rejected URL {0} because it is blacklisted.", uri);
        return false;
    }
    return true;
}

/// <summary>
/// Lets the whitelist override a failed safety check: returns true when the URI's
/// host is whitelisted, otherwise logs the rejection reason and returns false.
/// </summary>
/// <param name="uri">The URI that failed a safety check.</param>
/// <param name="reason">Why it failed, phrased to follow "because" and without a trailing period.</param>
static bool isWhitelistedOrLogRejection(Uri uri, string reason) {
    if (isHostWhitelisted(uri.DnsSafeHost)) {
        return true;
    }
    Logger.WarnFormat("Rejecting URL {0} because {1}.", uri, reason);
    return false;
}

/// <summary>
/// Tests whether the four bytes of an IPv4 address fall in a range that is not
/// reachable over the public internet: 0.0.0.0/8, 10.0.0.0/8, 127.0.0.0/8,
/// 169.254.0.0/16, 172.16.0.0/12 or 192.168.0.0/16.
/// </summary>
/// <param name="addressBytes">The address bytes, most significant first.</param>
static bool isInternalIPv4Address(byte[] addressBytes) {
    byte first = addressBytes[0];
    byte second = addressBytes[1];
    if (first == 0 || first == 10 || first == 127) {
        return true; // "this network", private-use, loopback
    }
    if (first == 169 && second == 254) {
        return true; // link-local
    }
    if (first == 172 && second >= 16 && second <= 31) {
        return true; // private-use
    }
    if (first == 192 && second == 168) {
        return true; // private-use
    }
    return false;
}
/// <summary>
/// Reads a bounded number of bytes from a response stream into a new buffer.
/// </summary>
/// <param name="resp">The response whose stream is read; the stream is closed before returning.</param>
/// <param name="buffer">
/// Receives the buffer holding the data.  Its size is MaximumBytesToRead, or the
/// response's ContentLength when that is known and smaller.
/// </param>
/// <param name="length">
/// Receives the number of bytes actually read.
/// WARNING: This can be fewer than the size of the returned buffer.
/// </param>
static void readData(HttpWebResponse resp, out byte[] buffer, out int length) {
    // Start from the configured cap and shrink it when the server declared a smaller length.
    int capacity;
    if (resp.ContentLength >= 0 && resp.ContentLength < int.MaxValue) {
        capacity = Math.Min(MaximumBytesToRead, (int)resp.ContentLength);
    } else {
        capacity = MaximumBytesToRead;
    }

    buffer = new byte[capacity];
    int totalRead = 0;
    using (Stream responseStream = resp.GetResponseStream()) {
        // Keep reading until the buffer is full or the stream reports no more data.
        while (totalRead < capacity) {
            int justRead = responseStream.Read(buffer, totalRead, capacity - totalRead);
            if (justRead <= 0) {
                break;
            }
            totalRead += justRead;
        }
    }
    length = totalRead;
}
/// <summary>
/// Reads the (size-limited) body of a response and wraps it, together with the
/// request URIs, in an UntrustedWebResponse.
/// </summary>
/// <param name="requestUri">The URI originally requested.</param>
/// <param name="finalRequestUri">The URI the response was actually obtained from.</param>
/// <param name="resp">The HTTP response to read.</param>
/// <returns>The wrapped response, backed by an in-memory copy of the data read.</returns>
static UntrustedWebResponse getResponse(Uri requestUri, Uri finalRequestUri, HttpWebResponse resp) {
    byte[] payload;
    int payloadLength;
    readData(resp, out payload, out payloadLength);

    // Only the bytes actually read are exposed, not the whole buffer.
    MemoryStream content = new MemoryStream(payload, 0, payloadLength);
    return new UntrustedWebResponse(requestUri, finalRequestUri, resp, content);
}
/// <summary>
/// Sends a request with no body to the given URI.
/// </summary>
/// <param name="uri">The URI to request.</param>
/// <returns>The response, after any redirects have been followed.</returns>
internal static UntrustedWebResponse Request(Uri uri) {
return Request(uri, null);
}
/// <summary>
/// Sends a request to the given URI without restricting the accepted content types.
/// </summary>
/// <param name="uri">The URI to request.</param>
/// <param name="body">The body to POST, or null to send no body.</param>
/// <returns>The response, after any redirects have been followed.</returns>
internal static UntrustedWebResponse Request(Uri uri, byte[] body) {
return Request(uri, body, null);
}
/// <summary>
/// Sends a request to the given URI without requiring SSL.
/// </summary>
/// <param name="uri">The URI to request.</param>
/// <param name="body">The body to POST, or null to send no body.</param>
/// <param name="acceptTypes">Values for the Accept header, or null to send none.</param>
/// <returns>The response, after any redirects have been followed.</returns>
internal static UntrustedWebResponse Request(Uri uri, byte[] body, string[] acceptTypes) {
return Request(uri, body, acceptTypes, false);
}
/// <summary>
/// Sends a request to the given URI, following redirects manually so that every
/// hop is subjected to the same safety and SSL checks.
/// </summary>
/// <param name="uri">The URI to request.</param>
/// <param name="body">The body to POST, or null to send no body.</param>
/// <param name="acceptTypes">Values for the Accept header, or null to send none.</param>
/// <param name="requireSsl">Whether every URI in the redirect chain must use HTTPS.</param>
/// <returns>The first response whose status is not one of the redirect codes 301, 302, 303 or 307.</returns>
/// <exception cref="WebException">
/// Thrown when the server redirects more than <see cref="MaximumRedirections"/> times.
/// </exception>
internal static UntrustedWebResponse Request(Uri uri, byte[] body, string[] acceptTypes, bool requireSsl) {
    // Since we may require SSL for every redirect, we handle each redirect manually
    // in order to detect and fail if any redirect sends us to an HTTP url.
    // We COULD allow automatic redirect in the cases where HTTPS is not required,
    // but our mock request infrastructure can't do redirects on its own either.
    Uri originalRequestUri = uri;

    // Count redirects followed rather than requests issued: the initial request is
    // not a redirection, so MaximumRedirections == 0 still sends one request and
    // N allowed redirections permit N + 1 requests in total.
    int redirectsFollowed = 0;
    while (true) {
        UntrustedWebResponse response = RequestInternal(uri, body, acceptTypes, requireSsl, false, originalRequestUri);
        bool isRedirect =
            response.StatusCode == HttpStatusCode.MovedPermanently ||
            response.StatusCode == HttpStatusCode.Redirect ||
            response.StatusCode == HttpStatusCode.RedirectMethod ||
            response.StatusCode == HttpStatusCode.RedirectKeepVerb;
        if (!isRedirect) {
            return response;
        }
        if (redirectsFollowed >= MaximumRedirections) {
            throw new WebException(string.Format(CultureInfo.CurrentCulture, Strings.TooManyRedirects, originalRequestUri));
        }

        // The Location header may be relative; resolve it against the URI just requested.
        uri = new Uri(response.FinalUri, response.Headers[HttpResponseHeader.Location]);
        redirectsFollowed++;
    }
}
/// <summary>
/// Performs a single HTTP request (no redirect following) after validating the target URI.
/// </summary>
/// <param name="uri">The URI to request on this hop.</param>
/// <param name="body">The body to POST as application/x-www-form-urlencoded, or null to send no body.</param>
/// <param name="acceptTypes">Values joined into the Accept header, or null to send none.</param>
/// <param name="requireSsl">Whether <paramref name="uri"/> must use HTTPS.</param>
/// <param name="avoidSendingExpect100Continue">
/// True on the retry that follows a 417 response, to suppress the Expect: 100-Continue header.
/// </param>
/// <param name="originalRequestUri">The URI the caller first asked for, reported on the response and in errors.</param>
/// <returns>The response, including responses carrying HTTP error status codes.</returns>
/// <exception cref="ArgumentNullException">Thrown when either URI is null.</exception>
/// <exception cref="ArgumentException">Thrown when <paramref name="uri"/> fails the safety checks.</exception>
/// <exception cref="OpenIdException">
/// Thrown when SSL is required but <paramref name="uri"/> is not HTTPS, or when the
/// request fails without the server having produced a response.
/// </exception>
static UntrustedWebResponse RequestInternal(Uri uri, byte[] body, string[] acceptTypes,
    bool requireSsl, bool avoidSendingExpect100Continue, Uri originalRequestUri) {
    if (uri == null) {
        throw new ArgumentNullException("uri");
    }
    if (originalRequestUri == null) {
        throw new ArgumentNullException("originalRequestUri");
    }
    if (!isUriAllowable(uri)) {
        string unsafeMessage = string.Format(CultureInfo.CurrentCulture, Strings.UnsafeWebRequestDetected, uri);
        throw new ArgumentException(unsafeMessage, "uri");
    }
    bool usesHttps = String.Equals(uri.Scheme, Uri.UriSchemeHttps, StringComparison.OrdinalIgnoreCase);
    if (requireSsl && !usesHttps) {
        string insecureMessage = string.Format(CultureInfo.CurrentCulture, Strings.InsecureWebRequestWithSslRequired, uri);
        throw new OpenIdException(insecureMessage);
    }

    // A hosting unit test may have installed a mock; if so it answers instead of the network.
    if (MockRequests != null) {
        return MockRequests(uri, body, acceptTypes);
    }

    HttpWebRequest request = (HttpWebRequest)WebRequest.Create(uri);
    // Automatic redirects stay off: with SSL required throughout, an automatic redirect
    // could pass through an unprotected HTTP request, so the caller follows redirects
    // itself.  It also spares us from trusting the response's own notion of its final
    // URI, which can be affected by the Content-Location header and open security holes.
    request.AllowAutoRedirect = false;
    request.ReadWriteTimeout = (int)ReadWriteTimeout.TotalMilliseconds;
    request.Timeout = (int)Timeout.TotalMilliseconds;
    request.KeepAlive = false;
    request.UserAgent = UserAgentValue;
    if (acceptTypes != null) {
        request.Accept = string.Join(",", acceptTypes);
    }

    bool hasBody = body != null;
    if (hasBody) {
        request.ContentType = "application/x-www-form-urlencoded";
        request.ContentLength = body.Length;
        request.Method = "POST";
        if (avoidSendingExpect100Continue) {
            // Some OpenID servers don't understand the Expect header and answer with a
            // 417 error.  This server just did, so we retry without "Expect: 100-Continue"
            // (see Google Code Issue 72).
            // The flag is only ever switched off here, never assigned from
            // !avoidSendingExpect100Continue, so that later requests don't turn it back
            // on and have to try twice as well.
            // Nor do we blindly disable it for all ServicePoints: that would let any
            // visitor to a web site change the site's global behavior toward that host.
            request.ServicePoint.Expect100Continue = false;
        }
    }

    try {
        if (hasBody) {
            using (Stream requestStream = request.GetRequestStream()) {
                requestStream.Write(body, 0, body.Length);
            }
        }
        using (HttpWebResponse response = (HttpWebResponse)request.GetResponse()) {
            return getResponse(originalRequestUri, request.RequestUri, response);
        }
    } catch (WebException e) {
        using (HttpWebResponse errorResponse = (HttpWebResponse)e.Response) {
            // No response at all means the request itself failed (e.g. never reached the server).
            if (errorResponse == null) {
                string failedMessage = string.Format(CultureInfo.CurrentCulture,
                    Strings.WebRequestFailed, originalRequestUri);
                throw new OpenIdException(failedMessage, e);
            }

            // A 417 gets exactly one retry, this time without the Expect header.
            bool expectationFailed = errorResponse.StatusCode == HttpStatusCode.ExpectationFailed;
            if (expectationFailed && !avoidSendingExpect100Continue) {
                return RequestInternal(uri, body, acceptTypes, requireSsl, true, originalRequestUri);
            }

            // Any other error status is handed back to the caller as an ordinary response.
            return getResponse(originalRequestUri, request.RequestUri, errorResponse);
        }
    }
}
}
}