diff options
author | Andrew Arnott <andrewarnott@gmail.com> | 2008-11-24 22:47:57 -0800 |
---|---|---|
committer | Andrew <andrewarnott@gmail.com> | 2008-11-24 22:47:57 -0800 |
commit | 866385f5426835483eea4d701fe07388dff3f3c3 (patch) | |
tree | 44cde819e064b8dc9ff31ed0d4351b1b6b16ffa8 /src | |
parent | 143d80b2ce76ef6eee4bddda9039a0b0f9673356 (diff) | |
download | DotNetOpenAuth-866385f5426835483eea4d701fe07388dff3f3c3.zip DotNetOpenAuth-866385f5426835483eea4d701fe07388dff3f3c3.tar.gz DotNetOpenAuth-866385f5426835483eea4d701fe07388dff3f3c3.tar.bz2 |
All 249 enabled tests pass.
Diffstat (limited to 'src')
-rw-r--r-- | src/DotNetOpenAuth.Test/Mocks/MockHttpRequest.cs | 26 | ||||
-rw-r--r-- | src/DotNetOpenAuth.Test/OpenId/OpenIdTestBase.cs | 7 | ||||
-rw-r--r-- | src/DotNetOpenAuth.Test/OpenId/UriIdentifierTests.cs | 3 | ||||
-rw-r--r-- | src/DotNetOpenAuth/Messaging/Channel.cs | 11 | ||||
-rw-r--r-- | src/DotNetOpenAuth/Messaging/DirectWebResponse.cs | 149 | ||||
-rw-r--r-- | src/DotNetOpenAuth/Messaging/UntrustedWebRequestHandler.cs | 53 | ||||
-rw-r--r-- | src/DotNetOpenAuth/OAuth/ConsumerBase.cs | 2 | ||||
-rw-r--r-- | src/DotNetOpenAuth/OpenId/XriIdentifier.cs | 6 | ||||
-rw-r--r-- | src/DotNetOpenAuth/Yadis/Yadis.cs | 2 |
9 files changed, 208 insertions, 51 deletions
diff --git a/src/DotNetOpenAuth.Test/Mocks/MockHttpRequest.cs b/src/DotNetOpenAuth.Test/Mocks/MockHttpRequest.cs index a9c6ad4..1473bf0 100644 --- a/src/DotNetOpenAuth.Test/Mocks/MockHttpRequest.cs +++ b/src/DotNetOpenAuth.Test/Mocks/MockHttpRequest.cs @@ -14,14 +14,25 @@ internal class MockHttpRequest { private readonly Dictionary<Uri, DirectWebResponse> registeredMockResponses = new Dictionary<Uri, DirectWebResponse>(); - private readonly TestWebRequestHandler mockHandler; - internal MockHttpRequest(TestWebRequestHandler mockHandler) { + internal static MockHttpRequest CreateUntrustedMockHttpHandler() { + TestWebRequestHandler testHandler = new TestWebRequestHandler(); + UntrustedWebRequestHandler untrustedHandler = new UntrustedWebRequestHandler(testHandler); + if (!untrustedHandler.WhitelistHosts.Contains("localhost")) { + untrustedHandler.WhitelistHosts.Add("localhost"); + } + MockHttpRequest mock = new MockHttpRequest(untrustedHandler); + testHandler.Callback = mock.GetMockResponse; + return mock; + } + + private MockHttpRequest(IDirectSslWebRequestHandler mockHandler) { ErrorUtilities.VerifyArgumentNotNull(mockHandler, "mockHandler"); - this.mockHandler = mockHandler; - this.mockHandler.Callback = this.GetMockResponse; + this.MockWebRequestHandler = mockHandler; } + internal IDirectSslWebRequestHandler MockWebRequestHandler { get; private set; } + private DirectWebResponse GetMockResponse(HttpWebRequest request) { DirectWebResponse response; if (this.registeredMockResponses.TryGetValue(request.RequestUri, out response)) { @@ -36,13 +47,6 @@ } } - /// <summary> - /// Clears all all mock HTTP responses and deactivates HTTP mocking. - /// </summary> - internal void Reset() { - this.registeredMockResponses.Clear(); - } - internal void RegisterMockResponse(DirectWebResponse response) { if (response == null) throw new ArgumentNullException("response"); if (registeredMockResponses.ContainsKey(response.RequestUri)) { diff --git a/src/DotNetOpenAuth.Test/OpenId/OpenIdTestBase.cs b/src/DotNetOpenAuth.Test/OpenId/OpenIdTestBase.cs index 4eff09f..82cf976 100644 --- a/src/DotNetOpenAuth.Test/OpenId/OpenIdTestBase.cs +++ b/src/DotNetOpenAuth.Test/OpenId/OpenIdTestBase.cs @@ -10,13 +10,14 @@ namespace DotNetOpenAuth.Test.OpenId { using DotNetOpenAuth.OpenId.RelyingParty; using Microsoft.VisualStudio.TestTools.UnitTesting; using DotNetOpenAuth.Test.Mocks; + using DotNetOpenAuth.Messaging; public class OpenIdTestBase : TestBase { protected RelyingPartySecuritySettings RelyingPartySecuritySettings { get; private set; } protected ProviderSecuritySettings ProviderSecuritySettings { get; private set; } - internal TestWebRequestHandler requestHandler; + internal IDirectSslWebRequestHandler requestHandler; internal MockHttpRequest mockResponder; [TestInitialize] @@ -26,8 +27,8 @@ namespace DotNetOpenAuth.Test.OpenId { this.RelyingPartySecuritySettings = RelyingPartySection.Configuration.SecuritySettings.CreateSecuritySettings(); this.ProviderSecuritySettings = ProviderSection.Configuration.SecuritySettings.CreateSecuritySettings(); - this.requestHandler = new TestWebRequestHandler(); - this.mockResponder = new MockHttpRequest(requestHandler); + this.mockResponder = MockHttpRequest.CreateUntrustedMockHttpHandler(); + this.requestHandler = this.mockResponder.MockWebRequestHandler; } } } diff --git a/src/DotNetOpenAuth.Test/OpenId/UriIdentifierTests.cs b/src/DotNetOpenAuth.Test/OpenId/UriIdentifierTests.cs index da2c6cc..f9b54b4 100644 --- a/src/DotNetOpenAuth.Test/OpenId/UriIdentifierTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/UriIdentifierTests.cs @@ -13,7 +13,7 @@ namespace DotNetOpenAuth.Test.OpenId { using DotNetOpenAuth.Test.Mocks; using Microsoft.VisualStudio.TestTools.UnitTesting; using DotNetOpenAuth.OpenId; -using DotNetOpenAuth.Messaging; + using DotNetOpenAuth.Messaging; [TestClass] public class UriIdentifierTests : OpenIdTestBase { @@ -320,7 +320,6 @@ using DotNetOpenAuth.Messaging; [TestMethod] public void DiscoverRequireSslWithSecureRedirects() { - this.mockResponder.Reset(); Identifier claimedId = TestSupport.GetMockIdentifier(TestSupport.Scenarios.AutoApproval, this.mockResponder, ProtocolVersion.V20, true); // Add a couple of chained redirect pages that lead to the claimedId. diff --git a/src/DotNetOpenAuth/Messaging/Channel.cs b/src/DotNetOpenAuth/Messaging/Channel.cs index 71afb4f..f6d060b 100644 --- a/src/DotNetOpenAuth/Messaging/Channel.cs +++ b/src/DotNetOpenAuth/Messaging/Channel.cs @@ -358,13 +358,16 @@ namespace DotNetOpenAuth.Messaging { /// </remarks> protected virtual IProtocolMessage RequestInternal(IDirectedProtocolMessage request) { HttpWebRequest webRequest = this.CreateHttpRequest(request); + IDictionary<string, string> responseFields; - DirectWebResponse response = this.WebRequestHandler.GetResponse(webRequest); - if (response.ResponseStream == null) { - return null; + using (DirectWebResponse response = this.WebRequestHandler.GetResponse(webRequest)) { + if (response.ResponseStream == null) { + return null; + } + + responseFields = this.ReadFromResponseInternal(response); } - var responseFields = this.ReadFromResponseInternal(response); IDirectResponseProtocolMessage responseMessage = this.MessageFactory.GetNewResponseMessage(request, responseFields); if (responseMessage == null) { return null; diff --git a/src/DotNetOpenAuth/Messaging/DirectWebResponse.cs b/src/DotNetOpenAuth/Messaging/DirectWebResponse.cs index f8dc31e..60e4f7d 100644 --- a/src/DotNetOpenAuth/Messaging/DirectWebResponse.cs +++ b/src/DotNetOpenAuth/Messaging/DirectWebResponse.cs @@ -8,27 +8,55 @@ using System.Globalization; using System.Net.Mime; using System.Net; + using System.Diagnostics.CodeAnalysis; [Serializable] - [DebuggerDisplay("{StatusCode} {ContentType.MediaType}: {ReadResponseString().Substring(4,50)}")] - public class DirectWebResponse : Response { + [DebuggerDisplay("{Status} {ContentType.MediaType}: {ReadResponseString().Substring(4,50)}")] + public class DirectWebResponse : IDisposable { private const string DefaultContentEncoding = "ISO-8859-1"; + private HttpWebResponse httpWebResponse; + private object responseLock = new object(); internal DirectWebResponse() { + this.Status = HttpStatusCode.OK; + this.Headers = new WebHeaderCollection(); } - internal DirectWebResponse(Uri requestUri, HttpWebResponse response) - : this(requestUri, response, int.MaxValue) { - } - - internal DirectWebResponse(Uri requestUri, HttpWebResponse response, int maximumBytesToRead) : base(response, maximumBytesToRead) { + internal DirectWebResponse(Uri requestUri, HttpWebResponse response) { ErrorUtilities.VerifyArgumentNotNull(requestUri, "requestUri"); ErrorUtilities.VerifyArgumentNotNull(response, "response"); + this.RequestUri = requestUri; - if (!string.IsNullOrEmpty(response.ContentType)) - ContentType = new ContentType(response.ContentType); - ContentEncoding = string.IsNullOrEmpty(response.ContentEncoding) ? DefaultContentEncoding : response.ContentEncoding; - FinalUri = response.ResponseUri; + if (!string.IsNullOrEmpty(response.ContentType)) { + this.ContentType = new ContentType(response.ContentType); + } + this.ContentEncoding = string.IsNullOrEmpty(response.ContentEncoding) ? DefaultContentEncoding : response.ContentEncoding; + this.FinalUri = response.ResponseUri; + this.Status = response.StatusCode; + this.Headers = response.Headers; + this.httpWebResponse = response; + this.ResponseStream = response.GetResponseStream(); + } + + internal void CacheNetworkStreamAndClose() { + this.CacheNetworkStreamAndClose(int.MaxValue); + } + + internal void CacheNetworkStreamAndClose(int maximumBytesToRead) { + lock (responseLock) { + if (this.httpWebResponse != null) { + // Now read and cache the network stream + Stream networkStream = this.ResponseStream; + this.ResponseStream = new MemoryStream(this.httpWebResponse.ContentLength < 0 ? 4 * 1024 : Math.Min((int)this.httpWebResponse.ContentLength, maximumBytesToRead)); + // BUGBUG: strictly speaking, is the response were exactly the limit, we'd report it as truncated here. + this.IsResponseTruncated = networkStream.CopyTo(this.ResponseStream, maximumBytesToRead) == maximumBytesToRead; + this.ResponseStream.Seek(0, SeekOrigin.Begin); + + networkStream.Dispose(); + this.httpWebResponse.Close(); + this.httpWebResponse = null; + } + } } /// <summary> @@ -54,6 +82,40 @@ public Uri RequestUri { get; private set; } public Uri FinalUri { get; private set; } + /// <summary> + /// Gets a value indicating whether the response stream is incomplete due + /// to a length limitation imposed by the HttpWebRequest or calling method. + /// </summary> + public bool IsResponseTruncated { get; internal set; } + + /// <summary> + /// Gets the headers that must be included in the response to the user agent. + /// </summary> + /// <remarks> + /// The headers in this collection are not meant to be a comprehensive list + /// of exactly what should be sent, but are meant to augment whatever headers + /// are generally included in a typical response. + /// </remarks> + public WebHeaderCollection Headers { get; internal set; } + + /// <summary> + /// Gets the HTTP status code to use in the HTTP response. + /// </summary> + public HttpStatusCode Status { get; internal set; } + + /// <summary> + /// Gets the body of the HTTP response. + /// </summary> + public Stream ResponseStream { get; internal set; } + + /// <summary> + /// Gets or sets the body of the response as a string. + /// </summary> + public string Body { + get { return this.ResponseStream != null ? this.GetResponseReader().ReadToEnd() : null; } + set { this.SetResponse(value); } + } + public override string ToString() { StringBuilder sb = new StringBuilder(); sb.AppendLine(string.Format(CultureInfo.CurrentCulture, "RequestUri = {0}", this.RequestUri)); @@ -69,5 +131,70 @@ sb.AppendLine(this.Body); return sb.ToString(); } + /// <summary> + /// Creates a text reader for the response stream. + /// </summary> + /// <returns>The text reader, initialized for the proper encoding.</returns> + [SuppressMessage("Microsoft.Design", "CA1024:UsePropertiesWhereAppropriate", Justification = "Costly operation")] + public StreamReader GetResponseReader() { + this.ResponseStream.Seek(0, SeekOrigin.Begin); + string contentEncoding = this.Headers[HttpResponseHeader.ContentEncoding]; + if (string.IsNullOrEmpty(contentEncoding)) { + return new StreamReader(this.ResponseStream); + } else { + return new StreamReader(this.ResponseStream, Encoding.GetEncoding(contentEncoding)); + } + } + + + /// <summary> + /// Sets the response to some string, encoded as UTF-8. + /// </summary> + /// <param name="body">The string to set the response to.</param> + internal void SetResponse(string body) { + if (body == null) { + this.ResponseStream = null; + return; + } + + Encoding encoding = Encoding.UTF8; + this.Headers[HttpResponseHeader.ContentEncoding] = encoding.HeaderName; + this.ResponseStream = new MemoryStream(); + StreamWriter writer = new StreamWriter(this.ResponseStream, encoding); + writer.Write(body); + writer.Flush(); + this.ResponseStream.Seek(0, SeekOrigin.Begin); + } + + #region IDisposable Members + + /// <summary> + /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources. + /// </summary> + public void Dispose() { + this.Dispose(true); + GC.SuppressFinalize(true); + } + + /// <summary> + /// Releases unmanaged and - optionally - managed resources + /// </summary> + /// <param name="disposing"><c>true</c> to release both managed and unmanaged resources; <c>false</c> to release only unmanaged resources.</param> + protected void Dispose(bool disposing) { + if (disposing) { + lock (responseLock) { + if (this.ResponseStream != null) { + this.ResponseStream.Dispose(); + this.ResponseStream = null; + } + if (this.httpWebResponse != null) { + this.httpWebResponse.Close(); + this.httpWebResponse = null; + } + } + } + } + + #endregion } } diff --git a/src/DotNetOpenAuth/Messaging/UntrustedWebRequestHandler.cs b/src/DotNetOpenAuth/Messaging/UntrustedWebRequestHandler.cs index 4cc8844..dee0cca 100644 --- a/src/DotNetOpenAuth/Messaging/UntrustedWebRequestHandler.cs +++ b/src/DotNetOpenAuth/Messaging/UntrustedWebRequestHandler.cs @@ -62,10 +62,23 @@ namespace DotNetOpenAuth.Messaging { [DebuggerBrowsable(DebuggerBrowsableState.Never)] private int maximumRedirections = Configuration.MaximumRedirections; + private IDirectWebRequestHandler chainedWebRequestHandler; + + /// <summary> + /// Initializes a new instance of the <see cref="UntrustedWebRequestHandler"/> class. + /// </summary> + public UntrustedWebRequestHandler() + : this(new StandardWebRequestHandler()) { + } + /// <summary> /// Initializes a new instance of the <see cref="UntrustedWebRequestHandler"/> class. /// </summary> - public UntrustedWebRequestHandler() { + /// <param name="chainedWebRequestHandler">The chained web request handler.</param> + public UntrustedWebRequestHandler(IDirectWebRequestHandler chainedWebRequestHandler) { + ErrorUtilities.VerifyArgumentNotNull(chainedWebRequestHandler, "chainedWebRequestHandler"); + + this.chainedWebRequestHandler = chainedWebRequestHandler; this.ReadWriteTimeout = Configuration.ReadWriteTimeout; this.Timeout = Configuration.Timeout; #if LONGTIMEOUT @@ -124,7 +137,7 @@ namespace DotNetOpenAuth.Messaging { /// </summary> public ICollection<Regex> BlacklistHostsRegex { get { return blacklistHostsRegex; } } - #region IDirectUntrustedWebRequestHandler Members + #region IDirectSslWebRequestHandler Members /// <summary> /// Prepares an <see cref="HttpWebRequest"/> that contains an POST entity for sending the entity. @@ -146,11 +159,7 @@ namespace DotNetOpenAuth.Messaging { request.AllowAutoRedirect = false; // Submit the request and get the request stream back. - try { - return new StreamWriter(request.GetRequestStream()); - } catch (WebException ex) { - throw ErrorUtilities.Wrap(ex, MessagingStrings.ErrorInRequestReplyMessage); - } + return this.chainedWebRequestHandler.GetRequestStream(request); } /// <summary> @@ -170,8 +179,7 @@ namespace DotNetOpenAuth.Messaging { // we have no guarantee, so do it just to be safe. this.PrepareRequest(request); - // TODO: Code here - throw new NotImplementedException(); + return this.RequestWithManagedRedirects(request, requireSsl); } #endregion @@ -198,7 +206,7 @@ namespace DotNetOpenAuth.Messaging { DirectWebResponse IDirectWebRequestHandler.GetResponse(HttpWebRequest request) { return this.GetResponse(request, false); } - + #endregion internal DirectWebResponse RequestWithManagedRedirects(HttpWebRequest request, bool requireSsl) { @@ -237,15 +245,15 @@ namespace DotNetOpenAuth.Messaging { newRequest.AutomaticDecompression = request.AutomaticDecompression; newRequest.CachePolicy = request.CachePolicy; newRequest.ClientCertificates = request.ClientCertificates; - newRequest.Connection = request.Connection; newRequest.ConnectionGroupName = request.ConnectionGroupName; - newRequest.ContentLength = request.ContentLength; + if (request.ContentLength >= 0) { + newRequest.ContentLength = request.ContentLength; + } newRequest.ContentType = request.ContentType; newRequest.ContinueDelegate = request.ContinueDelegate; newRequest.CookieContainer = request.CookieContainer; newRequest.Credentials = request.Credentials; newRequest.Expect = request.Expect; - newRequest.Headers = request.Headers; newRequest.IfModifiedSince = request.IfModifiedSince; newRequest.ImpersonationLevel = request.ImpersonationLevel; newRequest.KeepAlive = request.KeepAlive; @@ -266,6 +274,15 @@ namespace DotNetOpenAuth.Messaging { newRequest.UseDefaultCredentials = request.UseDefaultCredentials; newRequest.UserAgent = request.UserAgent; + // We copy headers last, and only those that do not yet exist as a result + // of setting these properties, so as to avoid exceptions thrown because + // there are properties .NET wants us to use rather than direct headers. + foreach (string header in request.Headers) { + if (string.IsNullOrEmpty(newRequest.Headers[header])) { + newRequest.Headers.Add(header, request.Headers[header]); + } + } + return newRequest; } @@ -388,9 +405,9 @@ namespace DotNetOpenAuth.Messaging { } } - using (HttpWebResponse response = (HttpWebResponse)request.GetResponse()) { - return new DirectWebResponse(originalRequestUri, response, MaximumBytesToRead); - } + DirectWebResponse response = this.chainedWebRequestHandler.GetResponse(request); + response.CacheNetworkStreamAndClose(MaximumBytesToRead); + return response; } catch (WebException e) { using (HttpWebResponse response = (HttpWebResponse)e.Response) { if (response != null) { @@ -410,7 +427,9 @@ namespace DotNetOpenAuth.Messaging { return RequestCore(request, postEntity, originalRequestUri, requireSsl); } } - return new DirectWebResponse(originalRequestUri, response, MaximumBytesToRead); + var directResponse = new DirectWebResponse(originalRequestUri, response); + directResponse.CacheNetworkStreamAndClose(MaximumBytesToRead); + return directResponse; } else { throw ErrorUtilities.Wrap(e, MessagingStrings.WebRequestFailed, originalRequestUri); } diff --git a/src/DotNetOpenAuth/OAuth/ConsumerBase.cs b/src/DotNetOpenAuth/OAuth/ConsumerBase.cs index 7c634b7..90683e8 100644 --- a/src/DotNetOpenAuth/OAuth/ConsumerBase.cs +++ b/src/DotNetOpenAuth/OAuth/ConsumerBase.cs @@ -87,7 +87,7 @@ namespace DotNetOpenAuth.OAuth { /// <param name="accessToken">The access token that permits access to the protected resource.</param> /// <returns>The initialized WebRequest object.</returns> /// <exception cref="WebException">Thrown if the request fails for any reason after it is sent to the Service Provider.</exception> - public Response PrepareAuthorizedRequestAndSend(MessageReceivingEndpoint endpoint, string accessToken) { + public DirectWebResponse PrepareAuthorizedRequestAndSend(MessageReceivingEndpoint endpoint, string accessToken) { IDirectedProtocolMessage message = this.CreateAuthorizingMessage(endpoint, accessToken); HttpWebRequest wr = this.OAuthChannel.InitializeRequest(message); return this.Channel.WebRequestHandler.GetResponse(wr); diff --git a/src/DotNetOpenAuth/OpenId/XriIdentifier.cs b/src/DotNetOpenAuth/OpenId/XriIdentifier.cs index a6a7fe8..7bbfb9d 100644 --- a/src/DotNetOpenAuth/OpenId/XriIdentifier.cs +++ b/src/DotNetOpenAuth/OpenId/XriIdentifier.cs @@ -150,8 +150,10 @@ namespace DotNetOpenAuth.OpenId { } private XrdsDocument downloadXrds(IDirectSslWebRequestHandler requestHandler) { - var xrdsResponse = Yadis.Request(requestHandler, this.XrdsUrl, this.IsDiscoverySecureEndToEnd); - XrdsDocument doc = new XrdsDocument(XmlReader.Create(xrdsResponse.ResponseStream)); + XrdsDocument doc; + using (var xrdsResponse = Yadis.Request(requestHandler, this.XrdsUrl, this.IsDiscoverySecureEndToEnd)) { + doc = new XrdsDocument(XmlReader.Create(xrdsResponse.ResponseStream)); + } ErrorUtilities.VerifyProtocol(doc.IsXrdResolutionSuccessful, OpenIdStrings.XriResolutionFailed); return doc; } diff --git a/src/DotNetOpenAuth/Yadis/Yadis.cs b/src/DotNetOpenAuth/Yadis/Yadis.cs index 1bbbc73..9339376 100644 --- a/src/DotNetOpenAuth/Yadis/Yadis.cs +++ b/src/DotNetOpenAuth/Yadis/Yadis.cs @@ -55,6 +55,7 @@ namespace DotNetOpenAuth.Yadis { return null; } response = Request(requestHandler, uri, requireSsl, ContentTypes.Html, ContentTypes.XHtml, ContentTypes.Xrds); + response.CacheNetworkStreamAndClose(); if (response.Status != System.Net.HttpStatusCode.OK) { return null; } @@ -84,6 +85,7 @@ namespace DotNetOpenAuth.Yadis { if (url != null) { if (!requireSsl || string.Equals(url.Scheme, Uri.UriSchemeHttps, StringComparison.OrdinalIgnoreCase)) { response2 = Request(requestHandler, url, requireSsl); + response2.CacheNetworkStreamAndClose(); if (response2.Status != System.Net.HttpStatusCode.OK) { return null; } |