summaryrefslogtreecommitdiffstats
path: root/src/DotNetOpenId/UntrustedWebRequest.cs
diff options
context:
space:
mode:
Diffstat (limited to 'src/DotNetOpenId/UntrustedWebRequest.cs')
-rw-r--r--src/DotNetOpenId/UntrustedWebRequest.cs106
1 files changed, 79 insertions, 27 deletions
diff --git a/src/DotNetOpenId/UntrustedWebRequest.cs b/src/DotNetOpenId/UntrustedWebRequest.cs
index f169a89..a621a65 100644
--- a/src/DotNetOpenId/UntrustedWebRequest.cs
+++ b/src/DotNetOpenId/UntrustedWebRequest.cs
@@ -3,13 +3,15 @@
#endif
namespace DotNetOpenId {
using System;
- using System.Net;
- using System.IO;
+ using System.Collections.Generic;
using System.Diagnostics;
using System.Globalization;
- using System.Collections.Generic;
+ using System.IO;
+ using System.Net;
using System.Text.RegularExpressions;
-
+ using System.Configuration;
+ using DotNetOpenId.Configuration;
+ using System.Reflection;
/// <summary>
/// A paranoid HTTP get/post request engine. It helps to protect against attacks from remote
/// server leaving dangling connections, sending too much data, causing requests against
@@ -25,8 +27,14 @@ namespace DotNetOpenId {
/// If a particular host would not be permitted but is in the whitelist, it is allowed.
/// </remarks>
public static class UntrustedWebRequest {
+ private static string UserAgentValue = Assembly.GetExecutingAssembly().GetName().Name + "/" + Assembly.GetExecutingAssembly().GetName().Version;
+
+ static Configuration.UntrustedWebRequestSection Configuration {
+ get { return UntrustedWebRequestSection.Configuration; }
+ }
+
[DebuggerBrowsable(DebuggerBrowsableState.Never)]
- static int maximumBytesToRead = 1024 * 1024;
+ static int maximumBytesToRead = Configuration.MaximumBytesToRead;
/// <summary>
/// The default maximum bytes to read in any given HTTP request.
/// Default is 1MB. Cannot be less than 2KB.
@@ -39,7 +47,7 @@ namespace DotNetOpenId {
}
}
[DebuggerBrowsable(DebuggerBrowsableState.Never)]
- static int maximumRedirections = 10;
+ static int maximumRedirections = Configuration.MaximumRedirections;
/// <summary>
/// The total number of redirections to allow on any one request.
/// Default is 10.
@@ -62,10 +70,20 @@ namespace DotNetOpenId {
/// </summary>
public static TimeSpan Timeout { get; set; }
+ internal delegate UntrustedWebResponse MockRequestResponse(Uri uri, byte[] body, string[] acceptTypes);
+ /// <summary>
+ /// Used in unit testing to mock HTTP responses to expected requests.
+ /// </summary>
+ /// <remarks>
+ /// If null, no mocking will take place. But if non-null, all requests
+ /// will be channeled through this mock method for processing.
+ /// </remarks>
+ internal static MockRequestResponse MockRequests;
+
[System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Performance", "CA1810:InitializeReferenceTypeStaticFieldsInline")]
static UntrustedWebRequest() {
- ReadWriteTimeout = TimeSpan.FromMilliseconds(500);
- Timeout = TimeSpan.FromSeconds(10);
+ ReadWriteTimeout = Configuration.ReadWriteTimeout;
+ Timeout = Configuration.Timeout;
#if LONGTIMEOUT
ReadWriteTimeout = TimeSpan.FromHours(1);
Timeout = TimeSpan.FromHours(1);
@@ -81,25 +99,25 @@ namespace DotNetOpenId {
return true;
}
static ICollection<string> allowableSchemes = new List<string> { "http", "https" };
- static ICollection<string> whitelistHosts = new List<string>();
+ static ICollection<string> whitelistHosts = new List<string>(Configuration.WhitelistHosts.KeysAsStrings);
/// <summary>
/// A collection of host name literals that should be allowed even if they don't
/// pass standard security checks.
/// </summary>
public static ICollection<string> WhitelistHosts { get { return whitelistHosts; } }
- static ICollection<Regex> whitelistHostsRegex = new List<Regex>();
+ static ICollection<Regex> whitelistHostsRegex = new List<Regex>(Configuration.WhitelistHostsRegex.KeysAsRegexs);
/// <summary>
/// A collection of host name regular expressions that indicate hosts that should
/// be allowed even though they don't pass standard security checks.
/// </summary>
public static ICollection<Regex> WhitelistHostsRegex { get { return whitelistHostsRegex; } }
- static ICollection<string> blacklistHosts = new List<string>();
+ static ICollection<string> blacklistHosts = new List<string>(Configuration.BlacklistHosts.KeysAsStrings);
/// <summary>
/// A collection of host name literals that should be rejected even if they
/// pass standard security checks.
/// </summary>
public static ICollection<string> BlacklistHosts { get { return blacklistHosts; } }
- static ICollection<Regex> blacklistHostsRegex = new List<Regex>();
+ static ICollection<Regex> blacklistHostsRegex = new List<Regex>(Configuration.BlacklistHostsRegex.KeysAsRegexs);
/// <summary>
/// A collection of host name regular expressions that indicate hosts that should
/// be rjected even if they pass standard security checks.
@@ -128,16 +146,14 @@ namespace DotNetOpenId {
static bool isUriAllowable(Uri uri) {
Debug.Assert(uri != null);
if (!allowableSchemes.Contains(uri.Scheme)) {
- if (TraceUtil.Switch.TraceWarning)
- Trace.TraceWarning("Rejecting URL {0} because it uses a disallowed scheme.", uri);
+ Logger.WarnFormat("Rejecting URL {0} because it uses a disallowed scheme.", uri);
return false;
}
// Allow for whitelist or blacklist to override our detection.
DotNetOpenId.Util.Func<string, bool> failsUnlessWhitelisted = (string reason) => {
if (isHostWhitelisted(uri.DnsSafeHost)) return true;
- if (TraceUtil.Switch.TraceWarning)
- Trace.TraceWarning("Rejecting URL {0} because {1}.", uri, reason);
+ Logger.WarnFormat("Rejecting URL {0} because {1}.", uri, reason);
return false;
};
@@ -171,8 +187,7 @@ namespace DotNetOpenId {
}
}
if (isHostBlacklisted(uri.DnsSafeHost)) {
- if (TraceUtil.Switch.TraceWarning)
- Trace.TraceWarning("Rejected URL {0} because it is blacklisted.", uri);
+ Logger.WarnFormat("Rejected URL {0} because it is blacklisted.", uri);
return false;
}
return true;
@@ -198,11 +213,11 @@ namespace DotNetOpenId {
}
}
- static UntrustedWebResponse getResponse(Uri requestUri, HttpWebResponse resp) {
+ static UntrustedWebResponse getResponse(Uri requestUri, Uri finalRequestUri, HttpWebResponse resp) {
byte[] data;
int length;
readData(resp, out data, out length);
- return new UntrustedWebResponse(requestUri, resp, new MemoryStream(data, 0, length));
+ return new UntrustedWebResponse(requestUri, finalRequestUri, resp, new MemoryStream(data, 0, length));
}
internal static UntrustedWebResponse Request(Uri uri) {
@@ -217,17 +232,53 @@ namespace DotNetOpenId {
return Request(uri, body, acceptTypes, false);
}
- static UntrustedWebResponse Request(Uri uri, byte[] body, string[] acceptTypes,
- bool avoidSendingExpect100Continue) {
+ internal static UntrustedWebResponse Request(Uri uri, byte[] body, string[] acceptTypes, bool requireSsl) {
+ // Since we may require SSL for every redirect, we handle each redirect manually
+ // in order to detect and fail if any redirect sends us to an HTTP url.
+ // We COULD allow automatic redirect in the cases where HTTPS is not required,
+ // but our mock request infrastructure can't do redirects on its own either.
+ Uri originalRequestUri = uri;
+ int i;
+ for (i = 0; i < MaximumRedirections; i++) {
+ UntrustedWebResponse response = RequestInternal(uri, body, acceptTypes, requireSsl, false, originalRequestUri);
+ if (response.StatusCode == HttpStatusCode.MovedPermanently ||
+ response.StatusCode == HttpStatusCode.Redirect ||
+ response.StatusCode == HttpStatusCode.RedirectMethod ||
+ response.StatusCode == HttpStatusCode.RedirectKeepVerb) {
+ uri = new Uri(response.FinalUri, response.Headers[HttpResponseHeader.Location]);
+ } else {
+ return response;
+ }
+ }
+ throw new WebException(string.Format(CultureInfo.CurrentCulture, Strings.TooManyRedirects, originalRequestUri));
+ }
+
+ static UntrustedWebResponse RequestInternal(Uri uri, byte[] body, string[] acceptTypes,
+ bool requireSsl, bool avoidSendingExpect100Continue, Uri originalRequestUri) {
if (uri == null) throw new ArgumentNullException("uri");
+ if (originalRequestUri == null) throw new ArgumentNullException("originalRequestUri");
if (!isUriAllowable(uri)) throw new ArgumentException(string.Format(CultureInfo.CurrentCulture,
Strings.UnsafeWebRequestDetected, uri), "uri");
+ if (requireSsl && !String.Equals(uri.Scheme, Uri.UriSchemeHttps, StringComparison.OrdinalIgnoreCase)) {
+ throw new OpenIdException(string.Format(CultureInfo.CurrentCulture, Strings.InsecureWebRequestWithSslRequired, uri));
+ }
+
+ // mock the request if a hosting unit test has configured it.
+ if (MockRequests != null) {
+ return MockRequests(uri, body, acceptTypes);
+ }
HttpWebRequest request = (HttpWebRequest)WebRequest.Create(uri);
+ // If SSL is required throughout, we cannot allow auto redirects because
+ // it may include a pass through an unprotected HTTP request.
+ // We have to follow redirects manually, and our caller will be responsible for that.
+ // It also allows us to ignore HttpWebResponse.FinalUri since that can be affected by
+ // the Content-Location header and open security holes.
+ request.AllowAutoRedirect = false;
request.ReadWriteTimeout = (int)ReadWriteTimeout.TotalMilliseconds;
request.Timeout = (int)Timeout.TotalMilliseconds;
request.KeepAlive = false;
- request.MaximumAutomaticRedirections = MaximumRedirections;
+ request.UserAgent = UserAgentValue;
if (acceptTypes != null)
request.Accept = string.Join(",", acceptTypes);
if (body != null) {
@@ -255,19 +306,20 @@ namespace DotNetOpenId {
}
using (HttpWebResponse response = (HttpWebResponse)request.GetResponse()) {
- return getResponse(uri, response);
+ return getResponse(originalRequestUri, request.RequestUri, response);
}
} catch (WebException e) {
using (HttpWebResponse response = (HttpWebResponse)e.Response) {
if (response != null) {
if (response.StatusCode == HttpStatusCode.ExpectationFailed) {
if (!avoidSendingExpect100Continue) { // must only try this once more
- return Request(uri, body, acceptTypes, true);
+ return RequestInternal(uri, body, acceptTypes, requireSsl, true, originalRequestUri);
}
}
- return getResponse(uri, response);
+ return getResponse(originalRequestUri, request.RequestUri, response);
} else {
- throw;
+ throw new OpenIdException(string.Format(CultureInfo.CurrentCulture,
+ Strings.WebRequestFailed, originalRequestUri), e);
}
}
}