diff options
Diffstat (limited to 'src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/OpenIdChannel.cs')
-rw-r--r-- | src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/OpenIdChannel.cs | 130 |
1 files changed, 82 insertions, 48 deletions
diff --git a/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/OpenIdChannel.cs b/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/OpenIdChannel.cs index 5a6b8bb..89bcd77 100644 --- a/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/OpenIdChannel.cs +++ b/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/OpenIdChannel.cs @@ -7,12 +7,19 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { using System; using System.Collections.Generic; + using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Globalization; using System.IO; using System.Linq; using System.Net; + using System.Net.Http; + using System.Net.Http.Headers; using System.Text; + using System.Threading; + using System.Threading.Tasks; + + using DotNetOpenAuth.Configuration; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Bindings; using DotNetOpenAuth.OpenId.Extensions; @@ -40,13 +47,14 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { private KeyValueFormEncoding keyValueForm = new KeyValueFormEncoding(); /// <summary> - /// Initializes a new instance of the <see cref="OpenIdChannel"/> class. + /// Initializes a new instance of the <see cref="OpenIdChannel" /> class. /// </summary> /// <param name="messageTypeProvider">A class prepared to analyze incoming messages and indicate what concrete /// message types can deserialize from it.</param> /// <param name="bindingElements">The binding elements to use in sending and receiving messages.</param> - protected OpenIdChannel(IMessageFactory messageTypeProvider, IChannelBindingElement[] bindingElements) - : base(messageTypeProvider, bindingElements) { + /// <param name="hostFactories">The host factories.</param> + protected OpenIdChannel(IMessageFactory messageTypeProvider, IChannelBindingElement[] bindingElements, IHostFactories hostFactories) + : base(messageTypeProvider, bindingElements, hostFactories ?? new DefaultOpenIdHostFactories()) { Requires.NotNull(messageTypeProvider, "messageTypeProvider"); // Customize the binding element order, since we play some tricks for higher @@ -68,33 +76,30 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { } this.CustomizeBindingElementOrder(outgoingBindingElements, incomingBindingElements); - - // Change out the standard web request handler to reflect the standard - // OpenID pattern that outgoing web requests are to unknown and untrusted - // servers on the Internet. - this.WebRequestHandler = new UntrustedWebRequestHandler(); } /// <summary> /// Verifies the integrity and applicability of an incoming message. /// </summary> /// <param name="message">The message just received.</param> - /// <exception cref="ProtocolException"> - /// Thrown when the message is somehow invalid, except for check_authentication messages. - /// This can be due to tampering, replay attack or expiration, among other things. - /// </exception> - protected override void ProcessIncomingMessage(IProtocolMessage message) { + /// <param name="cancellationToken">The cancellation token.</param> + /// <returns> + /// A task that completes with the asynchronous operation. + /// </returns> + /// <exception cref="ProtocolException">Thrown when the message is somehow invalid, except for check_authentication messages. + /// This can be due to tampering, replay attack or expiration, among other things.</exception> + protected override async Task ProcessIncomingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { var checkAuthRequest = message as CheckAuthenticationRequest; if (checkAuthRequest != null) { IndirectSignedResponse originalResponse = new IndirectSignedResponse(checkAuthRequest, this); try { - base.ProcessIncomingMessage(originalResponse); + await base.ProcessIncomingMessageAsync(originalResponse, cancellationToken); checkAuthRequest.IsValid = true; } catch (ProtocolException) { checkAuthRequest.IsValid = false; } } else { - base.ProcessIncomingMessage(message); + await base.ProcessIncomingMessageAsync(message, cancellationToken); } // Convert an OpenID indirect error message, which we never expect @@ -120,7 +125,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// <returns> /// The <see cref="HttpWebRequest"/> prepared to send the request. /// </returns> - protected override HttpWebRequest CreateHttpRequest(IDirectedProtocolMessage request) { + protected override HttpRequestMessage CreateHttpRequest(IDirectedProtocolMessage request) { return this.InitializeRequestAsPost(request); } @@ -128,13 +133,16 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// Gets the protocol message that may be in the given HTTP response. /// </summary> /// <param name="response">The response that is anticipated to contain an protocol message.</param> + /// <param name="cancellationToken">The cancellation token.</param> /// <returns> /// The deserialized message parts, if found. Null otherwise. /// </returns> /// <exception cref="ProtocolException">Thrown when the response is not valid.</exception> - protected override IDictionary<string, string> ReadFromResponseCore(IncomingWebResponse response) { + protected override async Task<IDictionary<string, string>> ReadFromResponseCoreAsync(HttpResponseMessage response, CancellationToken cancellationToken) { try { - return this.keyValueForm.GetDictionary(response.ResponseStream); + using (var responseStream = await response.Content.ReadAsStreamAsync()) { + return await this.keyValueForm.GetDictionaryAsync(responseStream, cancellationToken); + } } catch (FormatException ex) { throw ErrorUtilities.Wrap(ex, ex.Message); } @@ -145,7 +153,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// </summary> /// <param name="response">The HTTP direct response.</param> /// <param name="message">The newly instantiated message, prior to deserialization.</param> - protected override void OnReceivingDirectResponse(IncomingWebResponse response, IDirectResponseProtocolMessage message) { + protected override void OnReceivingDirectResponse(HttpResponseMessage response, IDirectResponseProtocolMessage message) { base.OnReceivingDirectResponse(response, message); // Verify that the expected HTTP status code was used for the message, @@ -155,10 +163,10 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { var httpDirectResponse = message as IHttpDirectResponse; if (httpDirectResponse != null) { ErrorUtilities.VerifyProtocol( - httpDirectResponse.HttpStatusCode == response.Status, + httpDirectResponse.HttpStatusCode == response.StatusCode, MessagingStrings.UnexpectedHttpStatusCode, (int)httpDirectResponse.HttpStatusCode, - (int)response.Status); + (int)response.StatusCode); } } } @@ -174,55 +182,81 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// <remarks> /// This method implements spec V1.0 section 5.3. /// </remarks> - protected override OutgoingWebResponse PrepareDirectResponse(IProtocolMessage response) { + protected override HttpResponseMessage PrepareDirectResponse(IProtocolMessage response) { var messageAccessor = this.MessageDescriptions.GetAccessor(response); var fields = messageAccessor.Serialize(); byte[] keyValueEncoding = KeyValueFormEncoding.GetBytes(fields); - OutgoingWebResponse preparedResponse = new OutgoingWebResponse(); + var preparedResponse = new HttpResponseMessage(); ApplyMessageTemplate(response, preparedResponse); - preparedResponse.Headers.Add(HttpResponseHeader.ContentType, KeyValueFormContentType); - preparedResponse.OriginalMessage = response; - preparedResponse.ResponseStream = new MemoryStream(keyValueEncoding); + var content = new StreamContent(new MemoryStream(keyValueEncoding)); + content.Headers.ContentType = new MediaTypeHeaderValue(KeyValueFormContentType); + preparedResponse.Content = content; IHttpDirectResponse httpMessage = response as IHttpDirectResponse; if (httpMessage != null) { - preparedResponse.Status = httpMessage.HttpStatusCode; + preparedResponse.StatusCode = httpMessage.HttpStatusCode; } return preparedResponse; } /// <summary> - /// Gets the direct response of a direct HTTP request. + /// Provides derived-types the opportunity to wrap an <see cref="HttpMessageHandler" /> with another one. + /// </summary> + /// <param name="innerHandler">The inner handler received from <see cref="IHostFactories" /></param> + /// <returns> + /// The handler to use in <see cref="HttpClient" /> instances. + /// </returns> + protected override HttpMessageHandler WrapMessageHandler(HttpMessageHandler innerHandler) { + return new ErrorFilteringMessageHandler(base.WrapMessageHandler(innerHandler)); + } + + /// <summary> + /// An HTTP handler that throws an exception if the response message's HTTP status code doesn't fall + /// within those allowed by the OpenID spec. /// </summary> - /// <param name="webRequest">The web request.</param> - /// <returns>The response to the web request.</returns> - /// <exception cref="ProtocolException">Thrown on network or protocol errors.</exception> - protected override IncomingWebResponse GetDirectResponse(HttpWebRequest webRequest) { - IncomingWebResponse response = this.WebRequestHandler.GetResponse(webRequest, DirectWebRequestOptions.AcceptAllHttpResponses); - - // Filter the responses to the allowable set of HTTP status codes. - if (response.Status != HttpStatusCode.OK && response.Status != HttpStatusCode.BadRequest) { - if (Logger.Channel.IsErrorEnabled) { - using (var reader = new StreamReader(response.ResponseStream)) { + private class ErrorFilteringMessageHandler : DelegatingHandler { + /// <summary> + /// Initializes a new instance of the <see cref="ErrorFilteringMessageHandler" /> class. + /// </summary> + /// <param name="innerHandler">The inner handler which is responsible for processing the HTTP response messages.</param> + internal ErrorFilteringMessageHandler(HttpMessageHandler innerHandler) + : base(innerHandler) { + } + + /// <summary> + /// Sends an HTTP request to the inner handler to send to the server as an asynchronous operation. + /// </summary> + /// <param name="request">The HTTP request message to send to the server.</param> + /// <param name="cancellationToken">A cancellation token to cancel operation.</param> + /// <returns> + /// Returns <see cref="T:System.Threading.Tasks.Task`1" />. The task object representing the asynchronous operation. + /// </returns> + protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, System.Threading.CancellationToken cancellationToken) { + var response = await base.SendAsync(request, cancellationToken); + + // Filter the responses to the allowable set of HTTP status codes. + if (response.StatusCode != HttpStatusCode.OK && response.StatusCode != HttpStatusCode.BadRequest) { + if (Logger.Channel.IsErrorEnabled) { + var content = await response.Content.ReadAsStringAsync(); Logger.Channel.ErrorFormat( "Unexpected HTTP status code {0} {1} received in direct response:{2}{3}", - (int)response.Status, - response.Status, + (int)response.StatusCode, + response.StatusCode, Environment.NewLine, - reader.ReadToEnd()); + content); } - } - // Call dispose before throwing since we're not including the response in the - // exception we're throwing. - response.Dispose(); + // Call dispose before throwing since we're not including the response in the + // exception we're throwing. + response.Dispose(); - ErrorUtilities.ThrowProtocol(OpenIdStrings.UnexpectedHttpStatusCode, (int)response.Status, response.Status); - } + ErrorUtilities.ThrowProtocol(OpenIdStrings.UnexpectedHttpStatusCode, (int)response.StatusCode, response.StatusCode); + } - return response; + return response; + } } } } |