diff options
Diffstat (limited to 'src/DotNetOpenAuth.OpenId')
25 files changed, 1091 insertions, 245 deletions
diff --git a/src/DotNetOpenAuth.OpenId/DefaultOpenIdHostFactories.cs b/src/DotNetOpenAuth.OpenId/DefaultOpenIdHostFactories.cs new file mode 100644 index 0000000..cd41c72 --- /dev/null +++ b/src/DotNetOpenAuth.OpenId/DefaultOpenIdHostFactories.cs @@ -0,0 +1,58 @@ +//----------------------------------------------------------------------- +// <copyright file="DefaultOpenIdHostFactories.cs" company="Andrew Arnott"> +// Copyright (c) Andrew Arnott. All rights reserved. +// </copyright> +//----------------------------------------------------------------------- + +namespace DotNetOpenAuth.OpenId { + using System; + using System.Collections.Generic; + using System.Linq; + using System.Net.Cache; + using System.Net.Http; + using System.Text; + using System.Threading.Tasks; + + /// <summary> + /// Creates default instances of required dependencies. + /// </summary> + public class DefaultOpenIdHostFactories : IHostFactories { + /// <summary> + /// Initializes a new instance of a concrete derivation of <see cref="HttpMessageHandler" /> + /// to be used for outbound HTTP traffic. + /// </summary> + /// <returns>An instance of <see cref="HttpMessageHandler"/>.</returns> + /// <remarks> + /// An instance of <see cref="WebRequestHandler" /> is recommended where available; + /// otherwise an instance of <see cref="HttpClientHandler" /> is recommended. + /// </remarks> + public virtual HttpMessageHandler CreateHttpMessageHandler() { + var handler = new UntrustedWebRequestHandler(); + ((WebRequestHandler)handler.InnerHandler).CachePolicy = new RequestCachePolicy(RequestCacheLevel.NoCacheNoStore); + return handler; + } + + /// <summary> + /// Initializes a new instance of the <see cref="HttpClient" /> class + /// to be used for outbound HTTP traffic. + /// </summary> + /// <param name="handler">The handler to pass to the <see cref="HttpClient" /> constructor. + /// May be null to use the default that would be provided by <see cref="CreateHttpMessageHandler" />.</param> + /// <returns> + /// An instance of <see cref="HttpClient" />. + /// </returns> + public HttpClient CreateHttpClient(HttpMessageHandler handler) { + handler = handler ?? this.CreateHttpMessageHandler(); + var untrustedHandler = handler as UntrustedWebRequestHandler; + HttpClient client; + if (untrustedHandler != null) { + client = untrustedHandler.CreateClient(); + } else { + client = new HttpClient(handler); + } + + client.DefaultRequestHeaders.UserAgent.Add(Util.LibraryVersionHeader); + return client; + } + } +} diff --git a/src/DotNetOpenAuth.OpenId/DotNetOpenAuth.OpenId.csproj b/src/DotNetOpenAuth.OpenId/DotNetOpenAuth.OpenId.csproj index e238d58..ab4b6a7 100644 --- a/src/DotNetOpenAuth.OpenId/DotNetOpenAuth.OpenId.csproj +++ b/src/DotNetOpenAuth.OpenId/DotNetOpenAuth.OpenId.csproj @@ -30,6 +30,7 @@ <Compile Include="Configuration\OpenIdRelyingPartyElement.cs" /> <Compile Include="Configuration\OpenIdRelyingPartySecuritySettingsElement.cs" /> <Compile Include="Configuration\XriResolverElement.cs" /> + <Compile Include="DefaultOpenIdHostFactories.cs" /> <Compile Include="OpenIdXrdsHelperRelyingParty.cs" /> <Compile Include="OpenId\Association.cs" /> <Compile Include="OpenId\AuthenticationRequestMode.cs" /> @@ -141,6 +142,7 @@ <Compile Include="OpenId\Protocol.cs" /> <Compile Include="OpenId\IOpenIdApplicationStore.cs" /> <Compile Include="OpenId\RelyingParty\RelyingPartySecuritySettings.cs" /> + <Compile Include="OpenId\UntrustedWebRequestHandler.cs" /> <Compile Include="OpenId\UriDiscoveryService.cs" /> <Compile Include="OpenId\XriDiscoveryProxyService.cs" /> <Compile Include="OpenId\SecuritySettings.cs" /> @@ -183,9 +185,11 @@ </ProjectReference> </ItemGroup> <ItemGroup> - <Reference Include="Validation"> - <HintPath>..\packages\Validation.2.0.1.12362\lib\portable-windows8+net40+sl5+windowsphone8\Validation.dll</HintPath> - <Private>True</Private> + <Reference Include="System.Net.Http" /> + <Reference Include="System.Net.Http.WebRequest" /> + <Reference Include="Validation, Version=2.0.0.0, Culture=neutral, PublicKeyToken=2fc06f0d701809a7, processorArchitecture=MSIL"> + <SpecificVersion>False</SpecificVersion> + <HintPath>..\packages\Validation.2.0.2.13022\lib\portable-windows8+net40+sl5+windowsphone8\Validation.dll</HintPath> </Reference> </ItemGroup> <ItemGroup> diff --git a/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/BackwardCompatibilityBindingElement.cs b/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/BackwardCompatibilityBindingElement.cs index ff8a766..4c55360 100644 --- a/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/BackwardCompatibilityBindingElement.cs +++ b/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/BackwardCompatibilityBindingElement.cs @@ -6,6 +6,8 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { using System; + using System.Threading; + using System.Threading.Tasks; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Reflection; using DotNetOpenAuth.OpenId.Messages; @@ -17,6 +19,17 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// </summary> internal class BackwardCompatibilityBindingElement : IChannelBindingElement { /// <summary> + /// A reusable pre-completed task that may be returned multiple times to reduce GC pressure. + /// </summary> + private static readonly Task<MessageProtections?> NullTask = Task.FromResult<MessageProtections?>(null); + + /// <summary> + /// A reusable pre-completed task that may be returned multiple times to reduce GC pressure. + /// </summary> + private static readonly Task<MessageProtections?> NoneTask = + Task.FromResult<MessageProtections?>(MessageProtections.None); + + /// <summary> /// The "dnoa.op_endpoint" callback parameter that stores the Provider Endpoint URL /// to tack onto the return_to URI. /// </summary> @@ -51,6 +64,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// Prepares a message for sending based on the rules of this channel binding element. /// </summary> /// <param name="message">The message to prepare for sending.</param> + /// <param name="cancellationToken">The cancellation token.</param> /// <returns> /// The protections (if any) that this binding element applied to the message. /// Null if this binding element did not even apply to this binding element. @@ -59,7 +73,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// Implementations that provide message protection must honor the /// <see cref="MessagePartAttribute.RequiredProtection"/> properties where applicable. /// </remarks> - public MessageProtections? ProcessOutgoingMessage(IProtocolMessage message) { + public Task<MessageProtections?> ProcessOutgoingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { SignedResponseRequest request = message as SignedResponseRequest; if (request != null && request.Version.Major < 2) { request.AddReturnToArguments(ProviderEndpointParameterName, request.Recipient.AbsoluteUri); @@ -69,10 +83,10 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { request.AddReturnToArguments(ClaimedIdentifierParameterName, authRequest.ClaimedIdentifier); } - return MessageProtections.None; + return NoneTask; } - return null; + return NullTask; } /// <summary> @@ -80,6 +94,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// validates an incoming message based on the rules of this channel binding element. /// </summary> /// <param name="message">The incoming message to process.</param> + /// <param name="cancellationToken">The cancellation token.</param> /// <returns> /// The protections (if any) that this binding element applied to the message. /// Null if this binding element did not even apply to this binding element. @@ -92,7 +107,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// Implementations that provide message protection must honor the /// <see cref="MessagePartAttribute.RequiredProtection"/> properties where applicable. /// </remarks> - public MessageProtections? ProcessIncomingMessage(IProtocolMessage message) { + public Task<MessageProtections?> ProcessIncomingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { IndirectSignedResponse response = message as IndirectSignedResponse; if (response != null && response.Version.Major < 2) { // GetReturnToArgument may return parameters that are not signed, @@ -118,10 +133,10 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { } } - return MessageProtections.None; + return NoneTask; } - return null; + return NullTask; } #endregion diff --git a/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/ExtensionsBindingElement.cs b/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/ExtensionsBindingElement.cs index f24c8b4..727dad7 100644 --- a/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/ExtensionsBindingElement.cs +++ b/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/ExtensionsBindingElement.cs @@ -10,6 +10,8 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Text; + using System.Threading; + using System.Threading.Tasks; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Reflection; using DotNetOpenAuth.OpenId.Extensions; @@ -22,6 +24,17 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// </summary> internal class ExtensionsBindingElement : IChannelBindingElement { /// <summary> + /// A reusable pre-completed task that may be returned multiple times to reduce GC pressure. + /// </summary> + private static readonly Task<MessageProtections?> NullTask = Task.FromResult<MessageProtections?>(null); + + /// <summary> + /// A reusable pre-completed task that may be returned multiple times to reduce GC pressure. + /// </summary> + private static readonly Task<MessageProtections?> NoneTask = + Task.FromResult<MessageProtections?>(MessageProtections.None); + + /// <summary> /// False if unsigned extensions should be dropped. Must always be true on Providers, since RPs never sign extensions. /// </summary> private readonly bool receiveUnsignedExtensions; @@ -68,6 +81,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// Prepares a message for sending based on the rules of this channel binding element. /// </summary> /// <param name="message">The message to prepare for sending.</param> + /// <param name="cancellationToken">The cancellation token.</param> /// <returns> /// The protections (if any) that this binding element applied to the message. /// Null if this binding element did not even apply to this binding element. @@ -77,7 +91,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// <see cref="MessagePartAttribute.RequiredProtection"/> properties where applicable. /// </remarks> [SuppressMessage("Microsoft.Maintainability", "CA1506:AvoidExcessiveClassCoupling", Justification = "It doesn't look too bad to me. :)")] - public MessageProtections? ProcessOutgoingMessage(IProtocolMessage message) { + public Task<MessageProtections?> ProcessOutgoingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { var extendableMessage = message as IProtocolMessageWithExtensions; if (extendableMessage != null) { Protocol protocol = Protocol.Lookup(message.Version); @@ -120,10 +134,10 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { // Add the extension parameters to the base message for transmission. baseMessageDictionary.AddExtraParameters(extensionManager.GetArgumentsToSend(includeOpenIdPrefix)); - return MessageProtections.None; + return NoneTask; } - return null; + return NullTask; } /// <summary> @@ -131,6 +145,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// validates an incoming message based on the rules of this channel binding element. /// </summary> /// <param name="message">The incoming message to process.</param> + /// <param name="cancellationToken">The cancellation token.</param> /// <returns> /// The protections (if any) that this binding element applied to the message. /// Null if this binding element did not even apply to this binding element. @@ -143,7 +158,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// Implementations that provide message protection must honor the /// <see cref="MessagePartAttribute.RequiredProtection"/> properties where applicable. /// </remarks> - public MessageProtections? ProcessIncomingMessage(IProtocolMessage message) { + public Task<MessageProtections?> ProcessIncomingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { var extendableMessage = message as IProtocolMessageWithExtensions; if (extendableMessage != null) { // First add the extensions that are signed by the Provider. @@ -164,10 +179,10 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { } } - return MessageProtections.None; + return NoneTask; } - return null; + return NullTask; } #endregion diff --git a/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/KeyValueFormEncoding.cs b/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/KeyValueFormEncoding.cs index 6ad66c0..9c06e6b 100644 --- a/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/KeyValueFormEncoding.cs +++ b/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/KeyValueFormEncoding.cs @@ -11,6 +11,8 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { using System.Globalization; using System.IO; using System.Text; + using System.Threading; + using System.Threading.Tasks; using DotNetOpenAuth.Messaging; using Validation; @@ -131,12 +133,13 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// <param name="data">The stream of Key-Value Form encoded bytes.</param> /// <returns>The deserialized dictionary.</returns> /// <exception cref="FormatException">Thrown when the data is not in the expected format.</exception> - public IDictionary<string, string> GetDictionary(Stream data) { + public async Task<IDictionary<string, string>> GetDictionaryAsync(Stream data, CancellationToken cancellationToken) { using (StreamReader reader = new StreamReader(data, textEncoding)) { var dict = new Dictionary<string, string>(); int line_num = 0; string line; - while ((line = reader.ReadLine()) != null) { + while ((line = await reader.ReadLineAsync()) != null) { + cancellationToken.ThrowIfCancellationRequested(); line_num++; if (this.ConformanceLevel == KeyValueFormConformanceLevel.Loose) { line = line.Trim(); diff --git a/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/OpenIdChannel.cs b/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/OpenIdChannel.cs index 5a6b8bb..89bcd77 100644 --- a/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/OpenIdChannel.cs +++ b/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/OpenIdChannel.cs @@ -7,12 +7,19 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { using System; using System.Collections.Generic; + using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Globalization; using System.IO; using System.Linq; using System.Net; + using System.Net.Http; + using System.Net.Http.Headers; using System.Text; + using System.Threading; + using System.Threading.Tasks; + + using DotNetOpenAuth.Configuration; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Bindings; using DotNetOpenAuth.OpenId.Extensions; @@ -40,13 +47,14 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { private KeyValueFormEncoding keyValueForm = new KeyValueFormEncoding(); /// <summary> - /// Initializes a new instance of the <see cref="OpenIdChannel"/> class. + /// Initializes a new instance of the <see cref="OpenIdChannel" /> class. /// </summary> /// <param name="messageTypeProvider">A class prepared to analyze incoming messages and indicate what concrete /// message types can deserialize from it.</param> /// <param name="bindingElements">The binding elements to use in sending and receiving messages.</param> - protected OpenIdChannel(IMessageFactory messageTypeProvider, IChannelBindingElement[] bindingElements) - : base(messageTypeProvider, bindingElements) { + /// <param name="hostFactories">The host factories.</param> + protected OpenIdChannel(IMessageFactory messageTypeProvider, IChannelBindingElement[] bindingElements, IHostFactories hostFactories) + : base(messageTypeProvider, bindingElements, hostFactories ?? new DefaultOpenIdHostFactories()) { Requires.NotNull(messageTypeProvider, "messageTypeProvider"); // Customize the binding element order, since we play some tricks for higher @@ -68,33 +76,30 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { } this.CustomizeBindingElementOrder(outgoingBindingElements, incomingBindingElements); - - // Change out the standard web request handler to reflect the standard - // OpenID pattern that outgoing web requests are to unknown and untrusted - // servers on the Internet. - this.WebRequestHandler = new UntrustedWebRequestHandler(); } /// <summary> /// Verifies the integrity and applicability of an incoming message. /// </summary> /// <param name="message">The message just received.</param> - /// <exception cref="ProtocolException"> - /// Thrown when the message is somehow invalid, except for check_authentication messages. - /// This can be due to tampering, replay attack or expiration, among other things. - /// </exception> - protected override void ProcessIncomingMessage(IProtocolMessage message) { + /// <param name="cancellationToken">The cancellation token.</param> + /// <returns> + /// A task that completes with the asynchronous operation. + /// </returns> + /// <exception cref="ProtocolException">Thrown when the message is somehow invalid, except for check_authentication messages. + /// This can be due to tampering, replay attack or expiration, among other things.</exception> + protected override async Task ProcessIncomingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { var checkAuthRequest = message as CheckAuthenticationRequest; if (checkAuthRequest != null) { IndirectSignedResponse originalResponse = new IndirectSignedResponse(checkAuthRequest, this); try { - base.ProcessIncomingMessage(originalResponse); + await base.ProcessIncomingMessageAsync(originalResponse, cancellationToken); checkAuthRequest.IsValid = true; } catch (ProtocolException) { checkAuthRequest.IsValid = false; } } else { - base.ProcessIncomingMessage(message); + await base.ProcessIncomingMessageAsync(message, cancellationToken); } // Convert an OpenID indirect error message, which we never expect @@ -120,7 +125,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// <returns> /// The <see cref="HttpWebRequest"/> prepared to send the request. /// </returns> - protected override HttpWebRequest CreateHttpRequest(IDirectedProtocolMessage request) { + protected override HttpRequestMessage CreateHttpRequest(IDirectedProtocolMessage request) { return this.InitializeRequestAsPost(request); } @@ -128,13 +133,16 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// Gets the protocol message that may be in the given HTTP response. /// </summary> /// <param name="response">The response that is anticipated to contain an protocol message.</param> + /// <param name="cancellationToken">The cancellation token.</param> /// <returns> /// The deserialized message parts, if found. Null otherwise. /// </returns> /// <exception cref="ProtocolException">Thrown when the response is not valid.</exception> - protected override IDictionary<string, string> ReadFromResponseCore(IncomingWebResponse response) { + protected override async Task<IDictionary<string, string>> ReadFromResponseCoreAsync(HttpResponseMessage response, CancellationToken cancellationToken) { try { - return this.keyValueForm.GetDictionary(response.ResponseStream); + using (var responseStream = await response.Content.ReadAsStreamAsync()) { + return await this.keyValueForm.GetDictionaryAsync(responseStream, cancellationToken); + } } catch (FormatException ex) { throw ErrorUtilities.Wrap(ex, ex.Message); } @@ -145,7 +153,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// </summary> /// <param name="response">The HTTP direct response.</param> /// <param name="message">The newly instantiated message, prior to deserialization.</param> - protected override void OnReceivingDirectResponse(IncomingWebResponse response, IDirectResponseProtocolMessage message) { + protected override void OnReceivingDirectResponse(HttpResponseMessage response, IDirectResponseProtocolMessage message) { base.OnReceivingDirectResponse(response, message); // Verify that the expected HTTP status code was used for the message, @@ -155,10 +163,10 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { var httpDirectResponse = message as IHttpDirectResponse; if (httpDirectResponse != null) { ErrorUtilities.VerifyProtocol( - httpDirectResponse.HttpStatusCode == response.Status, + httpDirectResponse.HttpStatusCode == response.StatusCode, MessagingStrings.UnexpectedHttpStatusCode, (int)httpDirectResponse.HttpStatusCode, - (int)response.Status); + (int)response.StatusCode); } } } @@ -174,55 +182,81 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// <remarks> /// This method implements spec V1.0 section 5.3. /// </remarks> - protected override OutgoingWebResponse PrepareDirectResponse(IProtocolMessage response) { + protected override HttpResponseMessage PrepareDirectResponse(IProtocolMessage response) { var messageAccessor = this.MessageDescriptions.GetAccessor(response); var fields = messageAccessor.Serialize(); byte[] keyValueEncoding = KeyValueFormEncoding.GetBytes(fields); - OutgoingWebResponse preparedResponse = new OutgoingWebResponse(); + var preparedResponse = new HttpResponseMessage(); ApplyMessageTemplate(response, preparedResponse); - preparedResponse.Headers.Add(HttpResponseHeader.ContentType, KeyValueFormContentType); - preparedResponse.OriginalMessage = response; - preparedResponse.ResponseStream = new MemoryStream(keyValueEncoding); + var content = new StreamContent(new MemoryStream(keyValueEncoding)); + content.Headers.ContentType = new MediaTypeHeaderValue(KeyValueFormContentType); + preparedResponse.Content = content; IHttpDirectResponse httpMessage = response as IHttpDirectResponse; if (httpMessage != null) { - preparedResponse.Status = httpMessage.HttpStatusCode; + preparedResponse.StatusCode = httpMessage.HttpStatusCode; } return preparedResponse; } /// <summary> - /// Gets the direct response of a direct HTTP request. + /// Provides derived-types the opportunity to wrap an <see cref="HttpMessageHandler" /> with another one. + /// </summary> + /// <param name="innerHandler">The inner handler received from <see cref="IHostFactories" /></param> + /// <returns> + /// The handler to use in <see cref="HttpClient" /> instances. + /// </returns> + protected override HttpMessageHandler WrapMessageHandler(HttpMessageHandler innerHandler) { + return new ErrorFilteringMessageHandler(base.WrapMessageHandler(innerHandler)); + } + + /// <summary> + /// An HTTP handler that throws an exception if the response message's HTTP status code doesn't fall + /// within those allowed by the OpenID spec. /// </summary> - /// <param name="webRequest">The web request.</param> - /// <returns>The response to the web request.</returns> - /// <exception cref="ProtocolException">Thrown on network or protocol errors.</exception> - protected override IncomingWebResponse GetDirectResponse(HttpWebRequest webRequest) { - IncomingWebResponse response = this.WebRequestHandler.GetResponse(webRequest, DirectWebRequestOptions.AcceptAllHttpResponses); - - // Filter the responses to the allowable set of HTTP status codes. - if (response.Status != HttpStatusCode.OK && response.Status != HttpStatusCode.BadRequest) { - if (Logger.Channel.IsErrorEnabled) { - using (var reader = new StreamReader(response.ResponseStream)) { + private class ErrorFilteringMessageHandler : DelegatingHandler { + /// <summary> + /// Initializes a new instance of the <see cref="ErrorFilteringMessageHandler" /> class. + /// </summary> + /// <param name="innerHandler">The inner handler which is responsible for processing the HTTP response messages.</param> + internal ErrorFilteringMessageHandler(HttpMessageHandler innerHandler) + : base(innerHandler) { + } + + /// <summary> + /// Sends an HTTP request to the inner handler to send to the server as an asynchronous operation. + /// </summary> + /// <param name="request">The HTTP request message to send to the server.</param> + /// <param name="cancellationToken">A cancellation token to cancel operation.</param> + /// <returns> + /// Returns <see cref="T:System.Threading.Tasks.Task`1" />. The task object representing the asynchronous operation. + /// </returns> + protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, System.Threading.CancellationToken cancellationToken) { + var response = await base.SendAsync(request, cancellationToken); + + // Filter the responses to the allowable set of HTTP status codes. + if (response.StatusCode != HttpStatusCode.OK && response.StatusCode != HttpStatusCode.BadRequest) { + if (Logger.Channel.IsErrorEnabled) { + var content = await response.Content.ReadAsStringAsync(); Logger.Channel.ErrorFormat( "Unexpected HTTP status code {0} {1} received in direct response:{2}{3}", - (int)response.Status, - response.Status, + (int)response.StatusCode, + response.StatusCode, Environment.NewLine, - reader.ReadToEnd()); + content); } - } - // Call dispose before throwing since we're not including the response in the - // exception we're throwing. - response.Dispose(); + // Call dispose before throwing since we're not including the response in the + // exception we're throwing. + response.Dispose(); - ErrorUtilities.ThrowProtocol(OpenIdStrings.UnexpectedHttpStatusCode, (int)response.Status, response.Status); - } + ErrorUtilities.ThrowProtocol(OpenIdStrings.UnexpectedHttpStatusCode, (int)response.StatusCode, response.StatusCode); + } - return response; + return response; + } } } } diff --git a/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/ReturnToSignatureBindingElement.cs b/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/ReturnToSignatureBindingElement.cs index 726c01f..2aad922 100644 --- a/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/ReturnToSignatureBindingElement.cs +++ b/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/ReturnToSignatureBindingElement.cs @@ -9,6 +9,8 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { using System.Collections.Generic; using System.Collections.Specialized; using System.Security.Cryptography; + using System.Threading; + using System.Threading.Tasks; using System.Web; using DotNetOpenAuth.Configuration; using DotNetOpenAuth.Messaging; @@ -32,6 +34,17 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// </remarks> internal class ReturnToSignatureBindingElement : IChannelBindingElement { /// <summary> + /// A reusable pre-completed task that may be returned multiple times to reduce GC pressure. + /// </summary> + private static readonly Task<MessageProtections?> NullTask = Task.FromResult<MessageProtections?>(null); + + /// <summary> + /// A reusable pre-completed task that may be returned multiple times to reduce GC pressure. + /// </summary> + private static readonly Task<MessageProtections?> NoneTask = + Task.FromResult<MessageProtections?>(MessageProtections.None); + + /// <summary> /// The name of the callback parameter we'll tack onto the return_to value /// to store our signature on the return_to parameter. /// </summary> @@ -90,6 +103,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// Prepares a message for sending based on the rules of this channel binding element. /// </summary> /// <param name="message">The message to prepare for sending.</param> + /// <param name="cancellationToken">The cancellation token.</param> /// <returns> /// The protections (if any) that this binding element applied to the message. /// Null if this binding element did not even apply to this binding element. @@ -98,7 +112,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// Implementations that provide message protection must honor the /// <see cref="MessagePartAttribute.RequiredProtection"/> properties where applicable. /// </remarks> - public MessageProtections? ProcessOutgoingMessage(IProtocolMessage message) { + public Task<MessageProtections?> ProcessOutgoingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { SignedResponseRequest request = message as SignedResponseRequest; if (request != null && request.ReturnTo != null && request.SignReturnTo) { var cryptoKeyPair = this.cryptoKeyStore.GetCurrentKey(SecretUri.AbsoluteUri, OpenIdElement.Configuration.MaxAuthenticationTime); @@ -107,10 +121,10 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { request.AddReturnToArguments(ReturnToSignatureParameterName, signature); // We return none because we are not signing the entire message (only a part). - return MessageProtections.None; + return NoneTask; } - return null; + return NullTask; } /// <summary> @@ -118,6 +132,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// validates an incoming message based on the rules of this channel binding element. /// </summary> /// <param name="message">The incoming message to process.</param> + /// <param name="cancellationToken">The cancellation token.</param> /// <returns> /// The protections (if any) that this binding element applied to the message. /// Null if this binding element did not even apply to this binding element. @@ -130,7 +145,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// Implementations that provide message protection must honor the /// <see cref="MessagePartAttribute.RequiredProtection"/> properties where applicable. /// </remarks> - public MessageProtections? ProcessIncomingMessage(IProtocolMessage message) { + public Task<MessageProtections?> ProcessIncomingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { IndirectSignedResponse response = message as IndirectSignedResponse; if (response != null) { @@ -150,11 +165,11 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { Logger.Bindings.WarnFormat("The return_to signature failed verification."); } - return MessageProtections.None; + return NoneTask; } } - return null; + return NullTask; } #endregion diff --git a/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/SigningBindingElement.cs b/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/SigningBindingElement.cs index 584b0e9..8f602cf 100644 --- a/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/SigningBindingElement.cs +++ b/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/SigningBindingElement.cs @@ -11,6 +11,8 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { using System.Globalization; using System.Linq; using System.Net.Security; + using System.Threading; + using System.Threading.Tasks; using System.Web; using DotNetOpenAuth.Loggers; using DotNetOpenAuth.Messaging; @@ -23,6 +25,11 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// Signs and verifies authentication assertions. /// </summary> internal abstract class SigningBindingElement : IChannelBindingElement { + /// <summary> + /// A reusable pre-completed task that may be returned multiple times to reduce GC pressure. + /// </summary> + private static readonly Task<MessageProtections?> NullTask = Task.FromResult<MessageProtections?>(null); + #region IChannelBindingElement Properties /// <summary> @@ -53,12 +60,13 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// Prepares a message for sending based on the rules of this channel binding element. /// </summary> /// <param name="message">The message to prepare for sending.</param> + /// <param name="cancellationToken">The cancellation token.</param> /// <returns> /// The protections (if any) that this binding element applied to the message. /// Null if this binding element did not even apply to this binding element. /// </returns> - public virtual MessageProtections? ProcessOutgoingMessage(IProtocolMessage message) { - return null; + public virtual Task<MessageProtections?> ProcessOutgoingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { + return NullTask; } /// <summary> @@ -66,6 +74,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// validates an incoming message based on the rules of this channel binding element. /// </summary> /// <param name="message">The incoming message to process.</param> + /// <param name="cancellationToken">The cancellation token.</param> /// <returns> /// The protections (if any) that this binding element applied to the message. /// Null if this binding element did not even apply to this binding element. @@ -74,7 +83,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// Thrown when the binding element rules indicate that this message is invalid and should /// NOT be processed. /// </exception> - public MessageProtections? ProcessIncomingMessage(IProtocolMessage message) { + public async Task<MessageProtections?> ProcessIncomingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { var signedMessage = message as ITamperResistantOpenIdMessage; if (signedMessage != null) { Logger.Bindings.DebugFormat("Verifying incoming {0} message signature of: {1}", message.GetType().Name, signedMessage.Signature); @@ -92,7 +101,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { } else { ErrorUtilities.VerifyInternal(this.Channel != null, "Cannot verify private association signature because we don't have a channel."); - protectionsApplied = this.VerifySignatureByUnrecognizedHandle(message, signedMessage, protectionsApplied); + protectionsApplied = await this.VerifySignatureByUnrecognizedHandleAsync(message, signedMessage, protectionsApplied, cancellationToken); } return protectionsApplied; @@ -107,8 +116,9 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// <param name="message">The message.</param> /// <param name="signedMessage">The signed message.</param> /// <param name="protectionsApplied">The protections applied.</param> + /// <param name="cancellationToken">The cancellation token.</param> /// <returns>The applied protections.</returns> - protected abstract MessageProtections VerifySignatureByUnrecognizedHandle(IProtocolMessage message, ITamperResistantOpenIdMessage signedMessage, MessageProtections protectionsApplied); + protected abstract Task<MessageProtections> VerifySignatureByUnrecognizedHandleAsync(IProtocolMessage message, ITamperResistantOpenIdMessage signedMessage, MessageProtections protectionsApplied, CancellationToken cancellationToken); #endregion diff --git a/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/SkipSecurityBindingElement.cs b/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/SkipSecurityBindingElement.cs index d162cf6..ad8d59e 100644 --- a/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/SkipSecurityBindingElement.cs +++ b/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/SkipSecurityBindingElement.cs @@ -10,12 +10,19 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { using System.Diagnostics; using System.Linq; using System.Text; + using System.Threading; + using System.Threading.Tasks; using DotNetOpenAuth.Messaging; /// <summary> /// Spoofs security checks on incoming OpenID messages. /// </summary> internal class SkipSecurityBindingElement : IChannelBindingElement { + /// <summary> + /// A reusable pre-completed task that may be returned multiple times to reduce GC pressure. + /// </summary> + private static readonly Task<MessageProtections?> NullTask = Task.FromResult<MessageProtections?>(null); + #region IChannelBindingElement Members /// <summary> @@ -42,6 +49,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// Prepares a message for sending based on the rules of this channel binding element. /// </summary> /// <param name="message">The message to prepare for sending.</param> + /// <param name="cancellationToken">The cancellation token.</param> /// <returns> /// The protections (if any) that this binding element applied to the message. /// Null if this binding element did not even apply to this binding element. @@ -50,7 +58,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// Implementations that provide message protection must honor the /// <see cref="MessagePartAttribute.RequiredProtection"/> properties where applicable. /// </remarks> - public MessageProtections? ProcessOutgoingMessage(IProtocolMessage message) { + public Task<MessageProtections?> ProcessOutgoingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { Debug.Fail("SkipSecurityBindingElement.ProcessOutgoingMessage should never be called."); return null; } @@ -60,6 +68,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// validates an incoming message based on the rules of this channel binding element. /// </summary> /// <param name="message">The incoming message to process.</param> + /// <param name="cancellationToken">The cancellation token.</param> /// <returns> /// The protections (if any) that this binding element applied to the message. /// Null if this binding element did not even apply to this binding element. @@ -72,14 +81,14 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// Implementations that provide message protection must honor the /// <see cref="MessagePartAttribute.RequiredProtection"/> properties where applicable. /// </remarks> - public MessageProtections? ProcessIncomingMessage(IProtocolMessage message) { + public Task<MessageProtections?> ProcessIncomingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { var signedMessage = message as ITamperResistantOpenIdMessage; if (signedMessage != null) { Logger.Bindings.DebugFormat("Skipped security checks of incoming {0} message for preview purposes.", message.GetType().Name); - return this.Protection; + return Task.FromResult<MessageProtections?>(this.Protection); } - return null; + return NullTask; } #endregion diff --git a/src/DotNetOpenAuth.OpenId/OpenId/Extensions/OpenIdExtensionFactoryAggregator.cs b/src/DotNetOpenAuth.OpenId/OpenId/Extensions/OpenIdExtensionFactoryAggregator.cs index ddd60f3..3f88d41 100644 --- a/src/DotNetOpenAuth.OpenId/OpenId/Extensions/OpenIdExtensionFactoryAggregator.cs +++ b/src/DotNetOpenAuth.OpenId/OpenId/Extensions/OpenIdExtensionFactoryAggregator.cs @@ -72,7 +72,7 @@ namespace DotNetOpenAuth.OpenId.Extensions { var factoriesElement = DotNetOpenAuth.Configuration.OpenIdElement.Configuration.ExtensionFactories; var aggregator = new OpenIdExtensionFactoryAggregator(); aggregator.Factories.Add(new StandardOpenIdExtensionFactory()); - aggregator.factories.AddRange(factoriesElement.CreateInstances(false)); + aggregator.factories.AddRange(factoriesElement.CreateInstances(false, null)); return aggregator; } } diff --git a/src/DotNetOpenAuth.OpenId/OpenId/IIdentifierDiscoveryService.cs b/src/DotNetOpenAuth.OpenId/OpenId/IIdentifierDiscoveryService.cs index 20b8f1c..1055d15 100644 --- a/src/DotNetOpenAuth.OpenId/OpenId/IIdentifierDiscoveryService.cs +++ b/src/DotNetOpenAuth.OpenId/OpenId/IIdentifierDiscoveryService.cs @@ -10,7 +10,11 @@ namespace DotNetOpenAuth.OpenId { using System.Diagnostics.CodeAnalysis; using System.Diagnostics.Contracts; using System.Linq; + using System.Net.Http; + using System.Runtime.CompilerServices; using System.Text; + using System.Threading; + using System.Threading.Tasks; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId.RelyingParty; using Validation; @@ -23,13 +27,38 @@ namespace DotNetOpenAuth.OpenId { /// Performs discovery on the specified identifier. /// </summary> /// <param name="identifier">The identifier to perform discovery on.</param> - /// <param name="requestHandler">The means to place outgoing HTTP requests.</param> - /// <param name="abortDiscoveryChain">if set to <c>true</c>, no further discovery services will be called for this identifier.</param> + /// <param name="cancellationToken">The cancellation token.</param> /// <returns> /// A sequence of service endpoints yielded by discovery. Must not be null, but may be empty. /// </returns> [SuppressMessage("Microsoft.Design", "CA1021:AvoidOutParameters", MessageId = "2#", Justification = "By design")] - [Pure] - IEnumerable<IdentifierDiscoveryResult> Discover(Identifier identifier, IDirectWebRequestHandler requestHandler, out bool abortDiscoveryChain); + Task<IdentifierDiscoveryServiceResult> DiscoverAsync(Identifier identifier, CancellationToken cancellationToken); + } + + /// <summary> + /// Describes the result of <see cref="IIdentifierDiscoveryService.DiscoverAsync"/>. + /// </summary> + public class IdentifierDiscoveryServiceResult { + /// <summary> + /// Initializes a new instance of the <see cref="IdentifierDiscoveryServiceResult" /> class. + /// </summary> + /// <param name="results">The results.</param> + /// <param name="abortDiscoveryChain">if set to <c>true</c>, no further discovery services will be called for this identifier.</param> + public IdentifierDiscoveryServiceResult(IEnumerable<IdentifierDiscoveryResult> results, bool abortDiscoveryChain = false) { + Requires.NotNull(results, "results"); + + this.Results = results; + this.AbortDiscoveryChain = abortDiscoveryChain; + } + + /// <summary> + /// Gets the results from this individual discovery service. + /// </summary> + public IEnumerable<IdentifierDiscoveryResult> Results { get; private set; } + + /// <summary> + /// Gets a value indicating whether no further discovery services should be called for this identifier. + /// </summary> + public bool AbortDiscoveryChain { get; private set; } } } diff --git a/src/DotNetOpenAuth.OpenId/OpenId/IOpenIdHost.cs b/src/DotNetOpenAuth.OpenId/OpenId/IOpenIdHost.cs index 419cc84..cf52fef 100644 --- a/src/DotNetOpenAuth.OpenId/OpenId/IOpenIdHost.cs +++ b/src/DotNetOpenAuth.OpenId/OpenId/IOpenIdHost.cs @@ -8,6 +8,7 @@ namespace DotNetOpenAuth.OpenId { using System; using System.Collections.Generic; using System.Linq; + using System.Net.Http; using System.Text; using DotNetOpenAuth.Messaging; @@ -21,8 +22,8 @@ namespace DotNetOpenAuth.OpenId { SecuritySettings SecuritySettings { get; } /// <summary> - /// Gets the web request handler. + /// Gets the factory for various dependencies. /// </summary> - IDirectWebRequestHandler WebRequestHandler { get; } + IHostFactories HostFactories { get; } } } diff --git a/src/DotNetOpenAuth.OpenId/OpenId/IdentifierDiscoveryServices.cs b/src/DotNetOpenAuth.OpenId/OpenId/IdentifierDiscoveryServices.cs index 1b20d4e..a515033 100644 --- a/src/DotNetOpenAuth.OpenId/OpenId/IdentifierDiscoveryServices.cs +++ b/src/DotNetOpenAuth.OpenId/OpenId/IdentifierDiscoveryServices.cs @@ -7,6 +7,9 @@ namespace DotNetOpenAuth.OpenId { using System.Collections.Generic; using System.Linq; + using System.Runtime.CompilerServices; + using System.Threading; + using System.Threading.Tasks; using DotNetOpenAuth.Configuration; using DotNetOpenAuth.Messaging; using Validation; @@ -33,7 +36,7 @@ namespace DotNetOpenAuth.OpenId { Requires.NotNull(host, "host"); this.host = host; - this.discoveryServices.AddRange(OpenIdElement.Configuration.RelyingParty.DiscoveryServices.CreateInstances(true)); + this.discoveryServices.AddRange(OpenIdElement.Configuration.RelyingParty.DiscoveryServices.CreateInstances(true, host.HostFactories)); } /// <summary> @@ -47,16 +50,16 @@ namespace DotNetOpenAuth.OpenId { /// Performs discovery on the specified identifier. /// </summary> /// <param name="identifier">The identifier to discover services for.</param> + /// <param name="cancellationToken">The cancellation token.</param> /// <returns>A non-null sequence of services discovered for the identifier.</returns> - public IEnumerable<IdentifierDiscoveryResult> Discover(Identifier identifier) { + public async Task<IEnumerable<IdentifierDiscoveryResult>> DiscoverAsync(Identifier identifier, CancellationToken cancellationToken) { Requires.NotNull(identifier, "identifier"); IEnumerable<IdentifierDiscoveryResult> results = Enumerable.Empty<IdentifierDiscoveryResult>(); foreach (var discoverer in this.DiscoveryServices) { - bool abortDiscoveryChain; - var discoveryResults = discoverer.Discover(identifier, this.host.WebRequestHandler, out abortDiscoveryChain).CacheGeneratedResults(); - results = results.Concat(discoveryResults); - if (abortDiscoveryChain) { + var discoveryResults = await discoverer.DiscoverAsync(identifier, cancellationToken); + results = results.Concat(discoveryResults.Results.CacheGeneratedResults()); + if (discoveryResults.AbortDiscoveryChain) { Logger.OpenId.InfoFormat("Further discovery on '{0}' was stopped by the {1} discovery service.", identifier, discoverer.GetType().Name); break; } diff --git a/src/DotNetOpenAuth.OpenId/OpenId/Messages/NegativeAssertionResponse.cs b/src/DotNetOpenAuth.OpenId/OpenId/Messages/NegativeAssertionResponse.cs index 9aac107..cb5b856 100644 --- a/src/DotNetOpenAuth.OpenId/OpenId/Messages/NegativeAssertionResponse.cs +++ b/src/DotNetOpenAuth.OpenId/OpenId/Messages/NegativeAssertionResponse.cs @@ -9,6 +9,8 @@ namespace DotNetOpenAuth.OpenId.Messages { using System.Collections.Generic; using System.Linq; using System.Text; + using System.Threading; + using System.Threading.Tasks; using DotNetOpenAuth.Messaging; using Validation; @@ -21,34 +23,19 @@ namespace DotNetOpenAuth.OpenId.Messages { /// <summary> /// Initializes a new instance of the <see cref="NegativeAssertionResponse"/> class. /// </summary> - /// <param name="request">The request that the relying party sent.</param> - internal NegativeAssertionResponse(CheckIdRequest request) - : this(request, null) { + /// <param name="version">The version.</param> + /// <param name="relyingPartyReturnTo">The relying party return to.</param> + /// <param name="mode">The value of the openid.mode parameter.</param> + internal NegativeAssertionResponse(Version version, Uri relyingPartyReturnTo, string mode) + : base(version, relyingPartyReturnTo, mode) { } /// <summary> /// Initializes a new instance of the <see cref="NegativeAssertionResponse"/> class. /// </summary> - /// <param name="request">The request that the relying party sent.</param> - /// <param name="channel">The channel to use to simulate construction of the user_setup_url, if applicable. May be null, but the user_setup_url will not be constructed.</param> - internal NegativeAssertionResponse(SignedResponseRequest request, Channel channel) + /// <param name="request">The request.</param> + internal NegativeAssertionResponse(SignedResponseRequest request) : base(request, GetMode(request)) { - // If appropriate, and when we're provided with a channel to do it, - // go ahead and construct the user_setup_url - if (this.Version.Major < 2 && request.Immediate && channel != null) { - // All requests are CheckIdRequests in OpenID 1.x, so this cast should be safe. - this.UserSetupUrl = ConstructUserSetupUrl((CheckIdRequest)request, channel); - } - } - - /// <summary> - /// Initializes a new instance of the <see cref="NegativeAssertionResponse"/> class. - /// </summary> - /// <param name="version">The version.</param> - /// <param name="relyingPartyReturnTo">The relying party return to.</param> - /// <param name="mode">The value of the openid.mode parameter.</param> - internal NegativeAssertionResponse(Version version, Uri relyingPartyReturnTo, string mode) - : base(version, relyingPartyReturnTo, mode) { } /// <summary> @@ -107,13 +94,34 @@ namespace DotNetOpenAuth.OpenId.Messages { } /// <summary> + /// Initializes a new instance of the <see cref="NegativeAssertionResponse" /> class. + /// </summary> + /// <param name="request">The request that the relying party sent.</param> + /// <param name="cancellationToken">The cancellation token.</param> + /// <param name="channel">The channel to use to simulate construction of the user_setup_url, if applicable. May be null, but the user_setup_url will not be constructed.</param> + /// <returns>The negative assertion message that will indicate failure for the user to authenticate or an unwillingness to log into the relying party.</returns> + internal static async Task<NegativeAssertionResponse> CreateAsync(SignedResponseRequest request, CancellationToken cancellationToken, Channel channel = null) { + var result = new NegativeAssertionResponse(request); + + // If appropriate, and when we're provided with a channel to do it, + // go ahead and construct the user_setup_url + if (result.Version.Major < 2 && request.Immediate && channel != null) { + // All requests are CheckIdRequests in OpenID 1.x, so this cast should be safe. + result.UserSetupUrl = await ConstructUserSetupUrlAsync((CheckIdRequest)request, channel, cancellationToken); + } + + return result; + } + + /// <summary> /// Constructs the value for the user_setup_url parameter to be sent back /// in negative assertions in response to OpenID 1.x RP's checkid_immediate requests. /// </summary> /// <param name="immediateRequest">The immediate request.</param> /// <param name="channel">The channel to use to simulate construction of the message.</param> + /// <param name="cancellationToken">The cancellation token.</param> /// <returns>The value to use for the user_setup_url parameter.</returns> - private static Uri ConstructUserSetupUrl(CheckIdRequest immediateRequest, Channel channel) { + private static async Task<Uri> ConstructUserSetupUrlAsync(CheckIdRequest immediateRequest, Channel channel, CancellationToken cancellationToken) { Requires.NotNull(immediateRequest, "immediateRequest"); Requires.NotNull(channel, "channel"); ErrorUtilities.VerifyInternal(immediateRequest.Immediate, "Only immediate requests should be sent here."); @@ -123,7 +131,8 @@ namespace DotNetOpenAuth.OpenId.Messages { setupRequest.ReturnTo = immediateRequest.ReturnTo; setupRequest.Realm = immediateRequest.Realm; setupRequest.AssociationHandle = immediateRequest.AssociationHandle; - return channel.PrepareResponse(setupRequest).GetDirectUriRequest(channel); + var response = await channel.PrepareResponseAsync(setupRequest, cancellationToken); + return response.GetDirectUriRequest(); } /// <summary> diff --git a/src/DotNetOpenAuth.OpenId/OpenId/OpenIdUtilities.cs b/src/DotNetOpenAuth.OpenId/OpenId/OpenIdUtilities.cs index e04a633..b797f3a 100644 --- a/src/DotNetOpenAuth.OpenId/OpenId/OpenIdUtilities.cs +++ b/src/DotNetOpenAuth.OpenId/OpenId/OpenIdUtilities.cs @@ -11,12 +11,19 @@ namespace DotNetOpenAuth.OpenId { using System.Globalization; using System.IO; using System.Linq; + using System.Net; + using System.Net.Cache; + using System.Net.Http; using System.Text.RegularExpressions; + using System.Threading; + using System.Threading.Tasks; using System.Web; using System.Web.UI; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId.ChannelElements; using DotNetOpenAuth.OpenId.Extensions; + using DotNetOpenAuth.OpenId.RelyingParty; + using Org.Mentalis.Security.Cryptography; using Validation; @@ -76,6 +83,24 @@ namespace DotNetOpenAuth.OpenId { } /// <summary> + /// Immediately sends a redirect response to the browser to initiate an authentication request. + /// </summary> + /// <param name="authenticationRequest">The authentication request to send via redirect.</param> + /// <param name="context">The context.</param> + /// <param name="cancellationToken">The cancellation token.</param> + /// <returns> + /// A task that completes with the asynchronous operation. + /// </returns> + public static async Task RedirectToProviderAsync(this IAuthenticationRequest authenticationRequest, HttpContextBase context = null, CancellationToken cancellationToken = default(CancellationToken)) { + Requires.NotNull(authenticationRequest, "authenticationRequest"); + Verify.Operation(context != null || HttpContext.Current != null, MessagingStrings.HttpContextRequired); + + context = context ?? new HttpContextWrapper(HttpContext.Current); + var response = await authenticationRequest.GetRedirectingResponseAsync(cancellationToken); + await response.SendAsync(context, cancellationToken); + } + + /// <summary> /// Gets the OpenID protocol instance for the version in a message. /// </summary> /// <param name="message">The message.</param> @@ -181,6 +206,51 @@ namespace DotNetOpenAuth.OpenId { } /// <summary> + /// Creates a new HTTP client for use by OpenID relying parties and providers. + /// </summary> + /// <param name="hostFactories">The host factories.</param> + /// <param name="requireSsl">if set to <c>true</c> [require SSL].</param> + /// <param name="cachePolicy">The cache policy.</param> + /// <returns>An HttpClient instance with appropriate caching policies set for OpenID operations.</returns> + internal static HttpClient CreateHttpClient(this IHostFactories hostFactories, bool requireSsl, RequestCachePolicy cachePolicy = null) { + Requires.NotNull(hostFactories, "hostFactories"); + + var rootHandler = hostFactories.CreateHttpMessageHandler(); + var handler = rootHandler; + bool sslRequiredSet = false, cachePolicySet = false; + do { + var webRequestHandler = handler as WebRequestHandler; + var untrustedHandler = handler as UntrustedWebRequestHandler; + var delegatingHandler = handler as DelegatingHandler; + if (webRequestHandler != null) { + if (cachePolicy != null) { + webRequestHandler.CachePolicy = cachePolicy; + cachePolicySet = true; + } + } else if (untrustedHandler != null) { + untrustedHandler.IsSslRequired = requireSsl; + sslRequiredSet = true; + } + + if (delegatingHandler != null) { + handler = delegatingHandler.InnerHandler; + } else { + break; + } + } + while (true); + + if (cachePolicy != null && !cachePolicySet) { + Logger.OpenId.Warn( + "Unable to set cache policy due to HttpMessageHandler instances not being of type WebRequestHandler."); + } + + ErrorUtilities.VerifyProtocol(!requireSsl || sslRequiredSet, "Unable to set RequireSsl on message handler because no HttpMessageHandler was of type {0}.", typeof(UntrustedWebRequestHandler).FullName); + + return hostFactories.CreateHttpClient(rootHandler); + } + + /// <summary> /// Gets the extension factories from the extension aggregator on an OpenID channel. /// </summary> /// <param name="channel">The channel.</param> diff --git a/src/DotNetOpenAuth.OpenId/OpenId/Provider/IHostProcessedRequest.cs b/src/DotNetOpenAuth.OpenId/OpenId/Provider/IHostProcessedRequest.cs index 4a464b9..0615144 100644 --- a/src/DotNetOpenAuth.OpenId/OpenId/Provider/IHostProcessedRequest.cs +++ b/src/DotNetOpenAuth.OpenId/OpenId/Provider/IHostProcessedRequest.cs @@ -6,6 +6,9 @@ namespace DotNetOpenAuth.OpenId.Provider { using System; + using System.Net.Http; + using System.Threading; + using System.Threading.Tasks; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId.Messages; using Validation; @@ -45,7 +48,8 @@ namespace DotNetOpenAuth.OpenId.Provider { /// <summary> /// Attempts to perform relying party discovery of the return URL claimed by the Relying Party. /// </summary> - /// <param name="webRequestHandler">The web request handler.</param> + /// <param name="hostFactories">The host factories.</param> + /// <param name="cancellationToken">The cancellation token.</param> /// <returns> /// The details of how successful the relying party discovery was. /// </returns> @@ -53,6 +57,6 @@ namespace DotNetOpenAuth.OpenId.Provider { /// <para>Return URL verification is only attempted if this method is called.</para> /// <para>See OpenID Authentication 2.0 spec section 9.2.1.</para> /// </remarks> - RelyingPartyDiscoveryResult IsReturnUrlDiscoverable(IDirectWebRequestHandler webRequestHandler); + Task<RelyingPartyDiscoveryResult> IsReturnUrlDiscoverableAsync(IHostFactories hostFactories = null, CancellationToken cancellationToken = default(CancellationToken)); } } diff --git a/src/DotNetOpenAuth.OpenId/OpenId/Provider/IProviderBehavior.cs b/src/DotNetOpenAuth.OpenId/OpenId/Provider/IProviderBehavior.cs index 57fe66b..d7ec647 100644 --- a/src/DotNetOpenAuth.OpenId/OpenId/Provider/IProviderBehavior.cs +++ b/src/DotNetOpenAuth.OpenId/OpenId/Provider/IProviderBehavior.cs @@ -6,6 +6,8 @@ namespace DotNetOpenAuth.OpenId.Provider { using System; + using System.Threading; + using System.Threading.Tasks; using DotNetOpenAuth.OpenId.ChannelElements; using Validation; @@ -28,6 +30,7 @@ namespace DotNetOpenAuth.OpenId.Provider { /// Called when a request is received by the Provider. /// </summary> /// <param name="request">The incoming request.</param> + /// <param name="cancellationToken">The cancellation token.</param> /// <returns> /// <c>true</c> if this behavior owns this request and wants to stop other behaviors /// from handling it; <c>false</c> to allow other behaviors to process this request. @@ -37,16 +40,17 @@ namespace DotNetOpenAuth.OpenId.Provider { /// should not change the properties on the instance of <see cref="ProviderSecuritySettings"/> /// itself as that instance may be shared across many requests. /// </remarks> - bool OnIncomingRequest(IRequest request); + Task<bool> OnIncomingRequestAsync(IRequest request, CancellationToken cancellationToken); /// <summary> /// Called when the Provider is preparing to send a response to an authentication request. /// </summary> /// <param name="request">The request that is configured to generate the outgoing response.</param> + /// <param name="cancellationToken">The cancellation token.</param> /// <returns> /// <c>true</c> if this behavior owns this request and wants to stop other behaviors /// from handling it; <c>false</c> to allow other behaviors to process this request. /// </returns> - bool OnOutgoingResponse(IAuthenticationRequest request); + Task<bool> OnOutgoingResponseAsync(IAuthenticationRequest request, CancellationToken cancellationToken); } } diff --git a/src/DotNetOpenAuth.OpenId/OpenId/Realm.cs b/src/DotNetOpenAuth.OpenId/OpenId/Realm.cs index c1a959e..f6b4129 100644 --- a/src/DotNetOpenAuth.OpenId/OpenId/Realm.cs +++ b/src/DotNetOpenAuth.OpenId/OpenId/Realm.cs @@ -12,7 +12,10 @@ namespace DotNetOpenAuth.OpenId { using System.Diagnostics.Contracts; using System.Globalization; using System.Linq; + using System.Net.Http; using System.Text.RegularExpressions; + using System.Threading; + using System.Threading.Tasks; using System.Web; using System.Xml; using DotNetOpenAuth.Messaging; @@ -415,15 +418,16 @@ namespace DotNetOpenAuth.OpenId { /// Searches for an XRDS document at the realm URL, and if found, searches /// for a description of a relying party endpoints (OpenId login pages). /// </summary> - /// <param name="requestHandler">The mechanism to use for sending HTTP requests.</param> + /// <param name="hostFactories">The host factories.</param> /// <param name="allowRedirects">Whether redirects may be followed when discovering the Realm. /// This may be true when creating an unsolicited assertion, but must be /// false when performing return URL verification per 2.0 spec section 9.2.1.</param> + /// <param name="cancellationToken">The cancellation token.</param> /// <returns> /// The details of the endpoints if found; or <c>null</c> if no service document was discovered. /// </returns> - internal virtual IEnumerable<RelyingPartyEndpointDescription> DiscoverReturnToEndpoints(IDirectWebRequestHandler requestHandler, bool allowRedirects) { - XrdsDocument xrds = this.Discover(requestHandler, allowRedirects); + internal virtual async Task<IEnumerable<RelyingPartyEndpointDescription>> DiscoverReturnToEndpointsAsync(IHostFactories hostFactories, bool allowRedirects, CancellationToken cancellationToken) { + XrdsDocument xrds = await this.DiscoverAsync(hostFactories, allowRedirects, cancellationToken); if (xrds != null) { return xrds.FindRelyingPartyReceivingEndpoints(); } @@ -434,16 +438,17 @@ namespace DotNetOpenAuth.OpenId { /// <summary> /// Searches for an XRDS document at the realm URL. /// </summary> - /// <param name="requestHandler">The mechanism to use for sending HTTP requests.</param> + /// <param name="hostFactories">The host factories.</param> /// <param name="allowRedirects">Whether redirects may be followed when discovering the Realm. /// This may be true when creating an unsolicited assertion, but must be /// false when performing return URL verification per 2.0 spec section 9.2.1.</param> + /// <param name="cancellationToken">The cancellation token.</param> /// <returns> /// The XRDS document if found; or <c>null</c> if no service document was discovered. /// </returns> - internal virtual XrdsDocument Discover(IDirectWebRequestHandler requestHandler, bool allowRedirects) { + internal virtual async Task<XrdsDocument> DiscoverAsync(IHostFactories hostFactories, bool allowRedirects, CancellationToken cancellationToken) { // Attempt YADIS discovery - DiscoveryResult yadisResult = Yadis.Discover(requestHandler, this.UriWithWildcardChangedToWww, false); + DiscoveryResult yadisResult = await Yadis.DiscoverAsync(hostFactories, this.UriWithWildcardChangedToWww, false, cancellationToken); if (yadisResult != null) { // Detect disallowed redirects, since realm discovery never allows them for security. ErrorUtilities.VerifyProtocol(allowRedirects || yadisResult.NormalizedUri == yadisResult.RequestUri, OpenIdStrings.RealmCausedRedirectUponDiscovery, yadisResult.RequestUri); diff --git a/src/DotNetOpenAuth.OpenId/OpenId/RelyingParty/IAuthenticationRequest.cs b/src/DotNetOpenAuth.OpenId/OpenId/RelyingParty/IAuthenticationRequest.cs index 886029c..3e922d4 100644 --- a/src/DotNetOpenAuth.OpenId/OpenId/RelyingParty/IAuthenticationRequest.cs +++ b/src/DotNetOpenAuth.OpenId/OpenId/RelyingParty/IAuthenticationRequest.cs @@ -8,7 +8,10 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { using System; using System.Collections.Generic; using System.Linq; + using System.Net.Http; using System.Text; + using System.Threading; + using System.Threading.Tasks; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId.Messages; @@ -24,12 +27,6 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { AuthenticationRequestMode Mode { get; set; } /// <summary> - /// Gets the HTTP response the relying party should send to the user agent - /// to redirect it to the OpenID Provider to start the OpenID authentication process. - /// </summary> - OutgoingWebResponse RedirectingResponse { get; } - - /// <summary> /// Gets the URL that the user agent will return to after authentication /// completes or fails at the Provider. /// </summary> @@ -173,12 +170,11 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { void AddExtension(IOpenIdMessageExtension extension); /// <summary> - /// Redirects the user agent to the provider for authentication. - /// Execution of the current page terminates after this call. + /// Gets the HTTP response the relying party should send to the user agent + /// to redirect it to the OpenID Provider to start the OpenID authentication process. /// </summary> - /// <remarks> - /// This method requires an ASP.NET HttpContext. - /// </remarks> - void RedirectToProvider(); + /// <param name="cancellationToken">The cancellation token.</param> + /// <returns>The response message that will cause the client to redirect to the Provider.</returns> + Task<HttpResponseMessage> GetRedirectingResponseAsync(CancellationToken cancellationToken = default(CancellationToken)); } } diff --git a/src/DotNetOpenAuth.OpenId/OpenId/UntrustedWebRequestHandler.cs b/src/DotNetOpenAuth.OpenId/OpenId/UntrustedWebRequestHandler.cs new file mode 100644 index 0000000..94d92e5 --- /dev/null +++ b/src/DotNetOpenAuth.OpenId/OpenId/UntrustedWebRequestHandler.cs @@ -0,0 +1,475 @@ +//----------------------------------------------------------------------- +// <copyright file="UntrustedWebRequestHandler.cs" company="Outercurve Foundation"> +// Copyright (c) Outercurve Foundation. All rights reserved. +// </copyright> +//----------------------------------------------------------------------- + +namespace DotNetOpenAuth.OpenId { + using System; + using System.Collections.Generic; + using System.Diagnostics; + using System.Diagnostics.CodeAnalysis; + using System.Diagnostics.Contracts; + using System.Globalization; + using System.IO; + using System.Net; + using System.Net.Cache; + using System.Net.Http; + using System.Text.RegularExpressions; + using System.Threading; + using System.Threading.Tasks; + using DotNetOpenAuth.Configuration; + using DotNetOpenAuth.Messaging; + using Validation; + + /// <summary> + /// A paranoid HTTP get/post request engine. It helps to protect against attacks from remote + /// server leaving dangling connections, sending too much data, causing requests against + /// internal servers, etc. + /// </summary> + /// <remarks> + /// Protections include: + /// * Conservative maximum time to receive the complete response. + /// * Only HTTP and HTTPS schemes are permitted. + /// * Internal IP address ranges are not permitted: 127.*.*.*, 1::* + /// * Internal host names are not permitted (periods must be found in the host name) + /// If a particular host would be permitted but is in the blacklist, it is not allowed. + /// If a particular host would not be permitted but is in the whitelist, it is allowed. + /// </remarks> + public class UntrustedWebRequestHandler : DelegatingHandler { + /// <summary> + /// The set of URI schemes allowed in untrusted web requests. + /// </summary> + private ICollection<string> allowableSchemes = new List<string> { "http", "https" }; + + /// <summary> + /// The collection of blacklisted hosts. + /// </summary> + private ICollection<string> blacklistHosts = new List<string>(Configuration.BlacklistHosts.KeysAsStrings); + + /// <summary> + /// The collection of regular expressions used to identify additional blacklisted hosts. + /// </summary> + private ICollection<Regex> blacklistHostsRegex = new List<Regex>(Configuration.BlacklistHostsRegex.KeysAsRegexs); + + /// <summary> + /// The collection of whitelisted hosts. + /// </summary> + private ICollection<string> whitelistHosts = new List<string>(Configuration.WhitelistHosts.KeysAsStrings); + + /// <summary> + /// The collection of regular expressions used to identify additional whitelisted hosts. + /// </summary> + private ICollection<Regex> whitelistHostsRegex = new List<Regex>(Configuration.WhitelistHostsRegex.KeysAsRegexs); + + /// <summary> + /// The maximum redirections to follow in the course of a single request. + /// </summary> + [DebuggerBrowsable(DebuggerBrowsableState.Never)] + private int maxAutomaticRedirections = Configuration.MaximumRedirections; + + /// <summary> + /// A value indicating whether to automatically follow redirects. + /// </summary> + [DebuggerBrowsable(DebuggerBrowsableState.Never)] + private bool allowAutoRedirect = true; + + /// <summary> + /// Initializes a new instance of the <see cref="UntrustedWebRequestHandler" /> class. + /// </summary> + /// <param name="innerHandler"> + /// The inner handler. This handler will be modified to suit the purposes of this wrapping handler, + /// and should not be used independently of this wrapper after construction of this object. + /// </param> + public UntrustedWebRequestHandler(WebRequestHandler innerHandler = null) + : base(innerHandler ?? new WebRequestHandler()) { + // If SSL is required throughout, we cannot allow auto redirects because + // it may include a pass through an unprotected HTTP request. + // We have to follow redirects manually. + // It also allows us to ignore HttpWebResponse.FinalUri since that can be affected by + // the Content-Location header and open security holes. + this.MaxAutomaticRedirections = Configuration.MaximumRedirections; + this.InnerWebRequestHandler.AllowAutoRedirect = false; + + if (Debugger.IsAttached) { + // Since a debugger is attached, requests may be MUCH slower, + // so give ourselves huge timeouts. + this.InnerWebRequestHandler.ReadWriteTimeout = (int)TimeSpan.FromHours(1).TotalMilliseconds; + } else { + this.InnerWebRequestHandler.ReadWriteTimeout = (int)Configuration.ReadWriteTimeout.TotalMilliseconds; + } + } + + /// <summary> + /// Initializes a new instance of the <see cref="UntrustedWebRequestHandler"/> class + /// for use in unit testing. + /// </summary> + /// <param name="innerHandler"> + /// The inner handler which is responsible for processing the HTTP response messages. + /// This handler should NOT automatically follow redirects. + /// </param> + internal UntrustedWebRequestHandler(HttpMessageHandler innerHandler) + : base(innerHandler) { + } + + /// <summary> + /// Gets or sets a value indicating whether all requests must use SSL. + /// </summary> + /// <value> + /// <c>true</c> if SSL is required; otherwise, <c>false</c>. + /// </value> + public bool IsSslRequired { get; set; } + + /// <summary> + /// Gets or sets the total number of redirections to allow on any one request. + /// Default is 10. + /// </summary> + public int MaxAutomaticRedirections { + get { + return base.InnerHandler is WebRequestHandler ? this.InnerWebRequestHandler.MaxAutomaticRedirections : this.maxAutomaticRedirections; + } + + set { + Requires.Range(value >= 0, "value"); + this.maxAutomaticRedirections = value; + if (base.InnerHandler is WebRequestHandler) { + this.InnerWebRequestHandler.MaxAutomaticRedirections = value; + } + } + } + + /// <summary> + /// Gets or sets a value indicating whether to automatically follow redirects. + /// </summary> + public bool AllowAutoRedirect { + get { + return base.InnerHandler is WebRequestHandler ? this.InnerWebRequestHandler.AllowAutoRedirect : this.allowAutoRedirect; + } + + set { + this.allowAutoRedirect = value; + if (base.InnerHandler is WebRequestHandler) { + this.InnerWebRequestHandler.AllowAutoRedirect = value; + } + } + } + + /// <summary> + /// Gets or sets the time (in milliseconds) allowed to wait for single read or write operation to complete. + /// Default is 500 milliseconds. + /// </summary> + public int ReadWriteTimeout { + get { return this.InnerWebRequestHandler.ReadWriteTimeout; } + set { this.InnerWebRequestHandler.ReadWriteTimeout = value; } + } + + /// <summary> + /// Gets a collection of host name literals that should be allowed even if they don't + /// pass standard security checks. + /// </summary> + [SuppressMessage("Microsoft.Naming", "CA1704:IdentifiersShouldBeSpelledCorrectly", MessageId = "Whitelist", + Justification = "Spelling as intended.")] + public ICollection<string> WhitelistHosts { + get { + return this.whitelistHosts; + } + } + + /// <summary> + /// Gets a collection of host name regular expressions that indicate hosts that should + /// be allowed even though they don't pass standard security checks. + /// </summary> + [SuppressMessage("Microsoft.Naming", "CA1704:IdentifiersShouldBeSpelledCorrectly", MessageId = "Whitelist", + Justification = "Spelling as intended.")] + public ICollection<Regex> WhitelistHostsRegex { + get { + return this.whitelistHostsRegex; + } + } + + /// <summary> + /// Gets a collection of host name literals that should be rejected even if they + /// pass standard security checks. + /// </summary> + public ICollection<string> BlacklistHosts { + get { + return this.blacklistHosts; + } + } + + /// <summary> + /// Gets a collection of host name regular expressions that indicate hosts that should + /// be rejected even if they pass standard security checks. + /// </summary> + public ICollection<Regex> BlacklistHostsRegex { + get { + return this.blacklistHostsRegex; + } + } + + /// <summary> + /// Gets the inner web request handler. + /// </summary> + /// <value> + /// The inner web request handler. + /// </value> + public WebRequestHandler InnerWebRequestHandler { + get { return (WebRequestHandler)this.InnerHandler; } + } + + /// <summary> + /// Gets the configuration for this class that is specified in the host's .config file. + /// </summary> + private static UntrustedWebRequestElement Configuration { + get { return DotNetOpenAuthSection.Messaging.UntrustedWebRequest; } + } + + /// <summary> + /// Creates an HTTP client that uses this instance as an HTTP handler. + /// </summary> + /// <returns>The initialized instance.</returns> + public HttpClient CreateClient() { + var client = new HttpClient(this); + client.MaxResponseContentBufferSize = Configuration.MaximumBytesToRead; + + if (Debugger.IsAttached) { + // Since a debugger is attached, requests may be MUCH slower, + // so give ourselves huge timeouts. + client.Timeout = TimeSpan.FromHours(1); + } else { + client.Timeout = Configuration.Timeout; + } + + return client; + } + + /// <summary> + /// Determines whether an exception was thrown because of the remote HTTP server returning HTTP 417 Expectation Failed. + /// </summary> + /// <param name="ex">The caught exception.</param> + /// <returns> + /// <c>true</c> if the failure was originally caused by a 417 Exceptation Failed error; otherwise, <c>false</c>. + /// </returns> + internal static bool IsExceptionFrom417ExpectationFailed(Exception ex) { + while (ex != null) { + WebException webEx = ex as WebException; + if (webEx != null) { + HttpWebResponse response = webEx.Response as HttpWebResponse; + if (response != null) { + if (response.StatusCode == HttpStatusCode.ExpectationFailed) { + return true; + } + } + } + + ex = ex.InnerException; + } + + return false; + } + + /// <summary> + /// Send an HTTP request as an asynchronous operation. + /// </summary> + /// <param name="request">The HTTP request message to send.</param> + /// <param name="cancellationToken">The cancellation token to cancel operation.</param> + /// <returns> + /// Returns <see cref="T:System.Threading.Tasks.Task`1" />.The task object representing the asynchronous operation. + /// </returns> + protected override async Task<HttpResponseMessage> SendAsync( + HttpRequestMessage request, CancellationToken cancellationToken) { + this.EnsureAllowableRequestUri(request.RequestUri); + + // Since we may require SSL for every redirect, we handle each redirect manually + // in order to detect and fail if any redirect sends us to an HTTP url. + // We COULD allow automatic redirect in the cases where HTTPS is not required, + // but our mock request infrastructure can't do redirects on its own either. + Uri originalRequestUri = request.RequestUri; + int i; + for (i = 0; i < this.MaxAutomaticRedirections; i++) { + this.EnsureAllowableRequestUri(request.RequestUri); + var response = await base.SendAsync(request, cancellationToken); + if (this.AllowAutoRedirect) { + if (response.StatusCode == HttpStatusCode.MovedPermanently || response.StatusCode == HttpStatusCode.Redirect + || response.StatusCode == HttpStatusCode.RedirectMethod + || response.StatusCode == HttpStatusCode.RedirectKeepVerb) { + // We have no copy of the post entity stream to repeat on our manually + // cloned HttpWebRequest, so we have to bail. + ErrorUtilities.VerifyProtocol( + request.Method != HttpMethod.Post, MessagingStrings.UntrustedRedirectsOnPOSTNotSupported); + Uri redirectUri = new Uri(request.RequestUri, response.Headers.Location); + request = request.Clone(); + request.RequestUri = redirectUri; + continue; + } + } + + if (response.StatusCode == HttpStatusCode.ExpectationFailed) { + // Some OpenID servers doesn't understand the Expect header and send 417 error back. + // If this server just failed from that, alter the ServicePoint for this server + // so that we don't send that header again next time (whenever that is). + // "Expect: 100-Continue" HTTP header. (see Google Code Issue 72) + // We don't want to blindly set all ServicePoints to not use the Expect header + // as that would be a security hole allowing any visitor to a web site change + // the web site's global behavior when calling that host. + // TODO: verify that this still works in DNOA 5.0 + var servicePoint = ServicePointManager.FindServicePoint(request.RequestUri); + Logger.Http.InfoFormat( + "HTTP POST to {0} resulted in 417 Expectation Failed. Changing ServicePoint to not use Expect: Continue next time.", + request.RequestUri); + servicePoint.Expect100Continue = false; + } + + return response; + } + + throw ErrorUtilities.ThrowProtocol(MessagingStrings.TooManyRedirects, originalRequestUri); + } + + /// <summary> + /// Determines whether an IP address is the IPv6 equivalent of "localhost/127.0.0.1". + /// </summary> + /// <param name="ip">The ip address to check.</param> + /// <returns> + /// <c>true</c> if this is a loopback IP address; <c>false</c> otherwise. + /// </returns> + private static bool IsIPv6Loopback(IPAddress ip) { + Requires.NotNull(ip, "ip"); + byte[] addressBytes = ip.GetAddressBytes(); + for (int i = 0; i < addressBytes.Length - 1; i++) { + if (addressBytes[i] != 0) { + return false; + } + } + if (addressBytes[addressBytes.Length - 1] != 1) { + return false; + } + return true; + } + + /// <summary> + /// Determines whether the given host name is in a host list or host name regex list. + /// </summary> + /// <param name="host">The host name.</param> + /// <param name="stringList">The list of host names.</param> + /// <param name="regexList">The list of regex patterns of host names.</param> + /// <returns> + /// <c>true</c> if the specified host falls within at least one of the given lists; otherwise, <c>false</c>. + /// </returns> + private static bool IsHostInList(string host, ICollection<string> stringList, ICollection<Regex> regexList) { + Requires.NotNullOrEmpty(host, "host"); + Requires.NotNull(stringList, "stringList"); + Requires.NotNull(regexList, "regexList"); + foreach (string testHost in stringList) { + if (string.Equals(host, testHost, StringComparison.OrdinalIgnoreCase)) { + return true; + } + } + foreach (Regex regex in regexList) { + if (regex.IsMatch(host)) { + return true; + } + } + return false; + } + + /// <summary> + /// Determines whether a given host is whitelisted. + /// </summary> + /// <param name="host">The host name to test.</param> + /// <returns> + /// <c>true</c> if the host is whitelisted; otherwise, <c>false</c>. + /// </returns> + private bool IsHostWhitelisted(string host) { + return IsHostInList(host, this.WhitelistHosts, this.WhitelistHostsRegex); + } + + /// <summary> + /// Determines whether a given host is blacklisted. + /// </summary> + /// <param name="host">The host name to test.</param> + /// <returns> + /// <c>true</c> if the host is blacklisted; otherwise, <c>false</c>. + /// </returns> + private bool IsHostBlacklisted(string host) { + return IsHostInList(host, this.BlacklistHosts, this.BlacklistHostsRegex); + } + + /// <summary> + /// Verify that the request qualifies under our security policies + /// </summary> + /// <param name="requestUri">The request URI.</param> + /// <exception cref="ProtocolException">Thrown when the URI is disallowed for security reasons.</exception> + private void EnsureAllowableRequestUri(Uri requestUri) { + ErrorUtilities.VerifyProtocol( + this.IsUriAllowable(requestUri), MessagingStrings.UnsafeWebRequestDetected, requestUri); + ErrorUtilities.VerifyProtocol( + !this.IsSslRequired || string.Equals(requestUri.Scheme, Uri.UriSchemeHttps, StringComparison.OrdinalIgnoreCase), + MessagingStrings.InsecureWebRequestWithSslRequired, + requestUri); + } + + /// <summary> + /// Determines whether a URI is allowed based on scheme and host name. + /// No requireSSL check is done here + /// </summary> + /// <param name="uri">The URI to test for whether it should be allowed.</param> + /// <returns> + /// <c>true</c> if [is URI allowable] [the specified URI]; otherwise, <c>false</c>. + /// </returns> + private bool IsUriAllowable(Uri uri) { + Requires.NotNull(uri, "uri"); + if (!this.allowableSchemes.Contains(uri.Scheme)) { + Logger.Http.WarnFormat("Rejecting URL {0} because it uses a disallowed scheme.", uri); + return false; + } + + // Allow for whitelist or blacklist to override our detection. + Func<string, bool> failsUnlessWhitelisted = (string reason) => { + if (IsHostWhitelisted(uri.DnsSafeHost)) { + return true; + } + Logger.Http.WarnFormat("Rejecting URL {0} because {1}.", uri, reason); + return false; + }; + + // Try to interpret the hostname as an IP address so we can test for internal + // IP address ranges. Note that IP addresses can appear in many forms + // (e.g. http://127.0.0.1, http://2130706433, http://0x0100007f, http://::1 + // So we convert them to a canonical IPAddress instance, and test for all + // non-routable IP ranges: 10.*.*.*, 127.*.*.*, ::1 + // Note that Uri.IsLoopback is very unreliable, not catching many of these variants. + IPAddress hostIPAddress; + if (IPAddress.TryParse(uri.DnsSafeHost, out hostIPAddress)) { + byte[] addressBytes = hostIPAddress.GetAddressBytes(); + + // The host is actually an IP address. + switch (hostIPAddress.AddressFamily) { + case System.Net.Sockets.AddressFamily.InterNetwork: + if (addressBytes[0] == 127 || addressBytes[0] == 10) { + return failsUnlessWhitelisted("it is a loopback address."); + } + break; + case System.Net.Sockets.AddressFamily.InterNetworkV6: + if (IsIPv6Loopback(hostIPAddress)) { + return failsUnlessWhitelisted("it is a loopback address."); + } + break; + default: + return failsUnlessWhitelisted("it does not use an IPv4 or IPv6 address."); + } + } else { + // The host is given by name. We require names to contain periods to + // help make sure it's not an internal address. + if (!uri.Host.Contains(".")) { + return failsUnlessWhitelisted("it does not contain a period in the host name."); + } + } + if (this.IsHostBlacklisted(uri.DnsSafeHost)) { + Logger.Http.WarnFormat("Rejected URL {0} because it is blacklisted.", uri); + return false; + } + return true; + } + } +} diff --git a/src/DotNetOpenAuth.OpenId/OpenId/UriDiscoveryService.cs b/src/DotNetOpenAuth.OpenId/OpenId/UriDiscoveryService.cs index c262ac9..7cb2a9a 100644 --- a/src/DotNetOpenAuth.OpenId/OpenId/UriDiscoveryService.cs +++ b/src/DotNetOpenAuth.OpenId/OpenId/UriDiscoveryService.cs @@ -8,47 +8,63 @@ namespace DotNetOpenAuth.OpenId { using System; using System.Collections.Generic; using System.Linq; + using System.Net.Http; + using System.Runtime.CompilerServices; using System.Text; using System.Text.RegularExpressions; + using System.Threading; + using System.Threading.Tasks; using System.Web.UI.HtmlControls; using System.Xml; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId.RelyingParty; using DotNetOpenAuth.Xrds; using DotNetOpenAuth.Yadis; + using Validation; /// <summary> /// The discovery service for URI identifiers. /// </summary> - public class UriDiscoveryService : IIdentifierDiscoveryService { + public class UriDiscoveryService : IIdentifierDiscoveryService, IRequireHostFactories { /// <summary> /// Initializes a new instance of the <see cref="UriDiscoveryService"/> class. /// </summary> public UriDiscoveryService() { } - #region IDiscoveryService Members + /// <summary> + /// Gets or sets the host factories used by this instance. + /// </summary> + /// <value> + /// The host factories. + /// </value> + public IHostFactories HostFactories { get; set; } + + #region IIdentifierDiscoveryService Members /// <summary> /// Performs discovery on the specified identifier. /// </summary> /// <param name="identifier">The identifier to perform discovery on.</param> - /// <param name="requestHandler">The means to place outgoing HTTP requests.</param> - /// <param name="abortDiscoveryChain">if set to <c>true</c>, no further discovery services will be called for this identifier.</param> + /// <param name="cancellationToken">The cancellation token.</param> /// <returns> /// A sequence of service endpoints yielded by discovery. Must not be null, but may be empty. /// </returns> - public IEnumerable<IdentifierDiscoveryResult> Discover(Identifier identifier, IDirectWebRequestHandler requestHandler, out bool abortDiscoveryChain) { - abortDiscoveryChain = false; + public async Task<IdentifierDiscoveryServiceResult> DiscoverAsync(Identifier identifier, CancellationToken cancellationToken) { + Requires.NotNull(identifier, "identifier"); + Verify.Operation(this.HostFactories != null, Strings.HostFactoriesRequired); + cancellationToken.ThrowIfCancellationRequested(); + var uriIdentifier = identifier as UriIdentifier; if (uriIdentifier == null) { - return Enumerable.Empty<IdentifierDiscoveryResult>(); + return new IdentifierDiscoveryServiceResult(Enumerable.Empty<IdentifierDiscoveryResult>()); } var endpoints = new List<IdentifierDiscoveryResult>(); // Attempt YADIS discovery - DiscoveryResult yadisResult = Yadis.Discover(requestHandler, uriIdentifier, identifier.IsDiscoverySecureEndToEnd); + DiscoveryResult yadisResult = await Yadis.DiscoverAsync(this.HostFactories, uriIdentifier, identifier.IsDiscoverySecureEndToEnd, cancellationToken); + cancellationToken.ThrowIfCancellationRequested(); if (yadisResult != null) { if (yadisResult.IsXrds) { try { @@ -67,7 +83,7 @@ namespace DotNetOpenAuth.OpenId { // Failing YADIS discovery of an XRDS document, we try HTML discovery. if (endpoints.Count == 0) { - yadisResult.TryRevertToHtmlResponse(); + await yadisResult.TryRevertToHtmlResponseAsync(); var htmlEndpoints = new List<IdentifierDiscoveryResult>(DiscoverFromHtml(yadisResult.NormalizedUri, uriIdentifier, yadisResult.ResponseText)); if (htmlEndpoints.Any()) { Logger.Yadis.DebugFormat("Total services discovered in HTML: {0}", htmlEndpoints.Count); @@ -83,7 +99,8 @@ namespace DotNetOpenAuth.OpenId { Logger.Yadis.Debug("Skipping HTML discovery because XRDS contained service endpoints."); } } - return endpoints; + + return new IdentifierDiscoveryServiceResult(endpoints); } #endregion diff --git a/src/DotNetOpenAuth.OpenId/OpenId/XriDiscoveryProxyService.cs b/src/DotNetOpenAuth.OpenId/OpenId/XriDiscoveryProxyService.cs index a3e8345..bea8752 100644 --- a/src/DotNetOpenAuth.OpenId/OpenId/XriDiscoveryProxyService.cs +++ b/src/DotNetOpenAuth.OpenId/OpenId/XriDiscoveryProxyService.cs @@ -10,7 +10,11 @@ namespace DotNetOpenAuth.OpenId { using System.Diagnostics.CodeAnalysis; using System.Globalization; using System.Linq; + using System.Net.Http; + using System.Runtime.CompilerServices; using System.Text; + using System.Threading; + using System.Threading.Tasks; using System.Xml; using DotNetOpenAuth.Configuration; using DotNetOpenAuth.Messaging; @@ -23,7 +27,7 @@ namespace DotNetOpenAuth.OpenId { /// The discovery service for XRI identifiers that uses an XRI proxy resolver for discovery. /// </summary> [SuppressMessage("Microsoft.Naming", "CA1704:IdentifiersShouldBeSpelledCorrectly", MessageId = "Xri", Justification = "Acronym")] - public class XriDiscoveryProxyService : IIdentifierDiscoveryService { + public class XriDiscoveryProxyService : IIdentifierDiscoveryService, IRequireHostFactories { /// <summary> /// The magic URL that will provide us an XRDS document for a given XRI identifier. /// </summary> @@ -42,25 +46,33 @@ namespace DotNetOpenAuth.OpenId { public XriDiscoveryProxyService() { } + /// <summary> + /// Gets or sets the host factories used by this instance. + /// </summary> + public IHostFactories HostFactories { get; set; } + #region IDiscoveryService Members /// <summary> /// Performs discovery on the specified identifier. /// </summary> /// <param name="identifier">The identifier to perform discovery on.</param> - /// <param name="requestHandler">The means to place outgoing HTTP requests.</param> - /// <param name="abortDiscoveryChain">if set to <c>true</c>, no further discovery services will be called for this identifier.</param> + /// <param name="cancellationToken">The cancellation token.</param> /// <returns> /// A sequence of service endpoints yielded by discovery. Must not be null, but may be empty. /// </returns> - public IEnumerable<IdentifierDiscoveryResult> Discover(Identifier identifier, IDirectWebRequestHandler requestHandler, out bool abortDiscoveryChain) { - abortDiscoveryChain = false; + public async Task<IdentifierDiscoveryServiceResult> DiscoverAsync(Identifier identifier, CancellationToken cancellationToken) { + Requires.NotNull(identifier, "identifier"); + Verify.Operation(this.HostFactories != null, Strings.HostFactoriesRequired); + var xriIdentifier = identifier as XriIdentifier; if (xriIdentifier == null) { - return Enumerable.Empty<IdentifierDiscoveryResult>(); + return new IdentifierDiscoveryServiceResult(Enumerable.Empty<IdentifierDiscoveryResult>()); } - return DownloadXrds(xriIdentifier, requestHandler).XrdElements.CreateServiceEndpoints(xriIdentifier); + var xrds = await DownloadXrdsAsync(xriIdentifier, this.HostFactories, cancellationToken); + var endpoints = xrds.XrdElements.CreateServiceEndpoints(xriIdentifier); + return new IdentifierDiscoveryServiceResult(endpoints); } #endregion @@ -69,16 +81,26 @@ namespace DotNetOpenAuth.OpenId { /// Downloads the XRDS document for this XRI. /// </summary> /// <param name="identifier">The identifier.</param> - /// <param name="requestHandler">The request handler.</param> - /// <returns>The XRDS document.</returns> - private static XrdsDocument DownloadXrds(XriIdentifier identifier, IDirectWebRequestHandler requestHandler) { + /// <param name="hostFactories">The host factories.</param> + /// <param name="cancellationToken">The cancellation token.</param> + /// <returns> + /// The XRDS document. + /// </returns> + private static async Task<XrdsDocument> DownloadXrdsAsync(XriIdentifier identifier, IHostFactories hostFactories, CancellationToken cancellationToken) { Requires.NotNull(identifier, "identifier"); - Requires.NotNull(requestHandler, "requestHandler"); + Requires.NotNull(hostFactories, "hostFactories"); + XrdsDocument doc; - using (var xrdsResponse = Yadis.Request(requestHandler, GetXrdsUrl(identifier), identifier.IsDiscoverySecureEndToEnd)) { + using (var xrdsResponse = await Yadis.RequestAsync(GetXrdsUrl(identifier), identifier.IsDiscoverySecureEndToEnd, hostFactories, cancellationToken)) { + xrdsResponse.EnsureSuccessStatusCode(); var readerSettings = MessagingUtilities.CreateUntrustedXmlReaderSettings(); - doc = new XrdsDocument(XmlReader.Create(xrdsResponse.ResponseStream, readerSettings)); + ErrorUtilities.VerifyProtocol(xrdsResponse.Content != null, "XRDS request \"{0}\" returned no response.", GetXrdsUrl(identifier)); + await xrdsResponse.Content.LoadIntoBufferAsync(); + using (var xrdsStream = await xrdsResponse.Content.ReadAsStreamAsync()) { + doc = new XrdsDocument(XmlReader.Create(xrdsStream, readerSettings)); + } } + ErrorUtilities.VerifyProtocol(doc.IsXrdResolutionSuccessful, OpenIdStrings.XriResolutionFailed); return doc; } diff --git a/src/DotNetOpenAuth.OpenId/Yadis/DiscoveryResult.cs b/src/DotNetOpenAuth.OpenId/Yadis/DiscoveryResult.cs index 06c6fc7..8266ff0 100644 --- a/src/DotNetOpenAuth.OpenId/Yadis/DiscoveryResult.cs +++ b/src/DotNetOpenAuth.OpenId/Yadis/DiscoveryResult.cs @@ -7,10 +7,15 @@ namespace DotNetOpenAuth.Yadis { using System; using System.IO; + using System.Net; + using System.Net.Http; + using System.Net.Http.Headers; using System.Net.Mime; + using System.Threading.Tasks; using System.Web.UI.HtmlControls; using System.Xml; using DotNetOpenAuth.Messaging; + using Validation; /// <summary> /// Contains the result of YADIS discovery. @@ -20,30 +25,12 @@ namespace DotNetOpenAuth.Yadis { /// The original web response, backed up here if the final web response is the preferred response to use /// in case it turns out to not work out. /// </summary> - private CachedDirectWebResponse htmlFallback; + private HttpResponseMessage htmlFallback; /// <summary> - /// Initializes a new instance of the <see cref="DiscoveryResult"/> class. + /// Prevents a default instance of the <see cref="DiscoveryResult" /> class from being created. /// </summary> - /// <param name="requestUri">The user-supplied identifier.</param> - /// <param name="initialResponse">The initial response.</param> - /// <param name="finalResponse">The final response.</param> - public DiscoveryResult(Uri requestUri, CachedDirectWebResponse initialResponse, CachedDirectWebResponse finalResponse) { - this.RequestUri = requestUri; - this.NormalizedUri = initialResponse.FinalUri; - if (finalResponse == null || finalResponse.Status != System.Net.HttpStatusCode.OK) { - this.ApplyHtmlResponse(initialResponse); - } else { - this.ContentType = finalResponse.ContentType; - this.ResponseText = finalResponse.GetResponseString(); - this.IsXrds = true; - if (initialResponse != finalResponse) { - this.YadisLocation = finalResponse.RequestUri; - } - - // Back up the initial HTML response in case the XRDS is not useful. - this.htmlFallback = initialResponse; - } + private DiscoveryResult() { } /// <summary> @@ -68,7 +55,7 @@ namespace DotNetOpenAuth.Yadis { /// <summary> /// Gets the Content-Type associated with the <see cref="ResponseText"/>. /// </summary> - public ContentType ContentType { get; private set; } + public MediaTypeHeaderValue ContentType { get; private set; } /// <summary> /// Gets the text in the final response. @@ -84,11 +71,42 @@ namespace DotNetOpenAuth.Yadis { public bool IsXrds { get; private set; } /// <summary> + /// Initializes a new instance of the <see cref="DiscoveryResult"/> class. + /// </summary> + /// <param name="requestUri">The request URI.</param> + /// <param name="initialResponse">The initial response.</param> + /// <param name="finalResponse">The final response.</param> + /// <returns>The newly initialized instance.</returns> + internal static async Task<DiscoveryResult> CreateAsync(Uri requestUri, HttpResponseMessage initialResponse, HttpResponseMessage finalResponse) { + var result = new DiscoveryResult(); + result.RequestUri = requestUri; + result.NormalizedUri = initialResponse.RequestMessage.RequestUri; + if (finalResponse == null || finalResponse.StatusCode != HttpStatusCode.OK) { + await result.ApplyHtmlResponseAsync(initialResponse); + } else { + result.ContentType = finalResponse.Content.Headers.ContentType; + result.ResponseText = await finalResponse.Content.ReadAsStringAsync(); + result.IsXrds = true; + if (initialResponse != finalResponse) { + result.YadisLocation = finalResponse.RequestMessage.RequestUri; + } + + // Back up the initial HTML response in case the XRDS is not useful. + result.htmlFallback = initialResponse; + } + + return result; + } + + /// <summary> /// Reverts to the HTML response after the XRDS response didn't work out. /// </summary> - internal void TryRevertToHtmlResponse() { + /// <returns> + /// A task that completes with the asynchronous operation. + /// </returns> + internal async Task TryRevertToHtmlResponseAsync() { if (this.htmlFallback != null) { - this.ApplyHtmlResponse(this.htmlFallback); + await this.ApplyHtmlResponseAsync(this.htmlFallback); this.htmlFallback = null; } } @@ -96,10 +114,15 @@ namespace DotNetOpenAuth.Yadis { /// <summary> /// Applies the HTML response to the object. /// </summary> - /// <param name="initialResponse">The initial response.</param> - private void ApplyHtmlResponse(CachedDirectWebResponse initialResponse) { - this.ContentType = initialResponse.ContentType; - this.ResponseText = initialResponse.GetResponseString(); + /// <param name="response">The initial response.</param> + /// <returns> + /// A task that completes with the asynchronous operation. + /// </returns> + private async Task ApplyHtmlResponseAsync(HttpResponseMessage response) { + Requires.NotNull(response, "response"); + + this.ContentType = response.Content.Headers.ContentType; + this.ResponseText = await response.Content.ReadAsStringAsync(); this.IsXrds = this.ContentType != null && this.ContentType.MediaType == ContentTypes.Xrds; } } diff --git a/src/DotNetOpenAuth.OpenId/Yadis/Yadis.cs b/src/DotNetOpenAuth.OpenId/Yadis/Yadis.cs index 4a06ea7..77d4926 100644 --- a/src/DotNetOpenAuth.OpenId/Yadis/Yadis.cs +++ b/src/DotNetOpenAuth.OpenId/Yadis/Yadis.cs @@ -6,9 +6,15 @@ namespace DotNetOpenAuth.Yadis { using System; + using System.Collections.Generic; using System.IO; + using System.Linq; using System.Net; using System.Net.Cache; + using System.Net.Http; + using System.Net.Http.Headers; + using System.Threading; + using System.Threading.Tasks; using System.Web.UI.HtmlControls; using System.Xml; using DotNetOpenAuth.Configuration; @@ -44,39 +50,51 @@ namespace DotNetOpenAuth.Yadis { /// <summary> /// Performs YADIS discovery on some identifier. /// </summary> - /// <param name="requestHandler">The mechanism to use for sending HTTP requests.</param> + /// <param name="hostFactories">The host factories.</param> /// <param name="uri">The URI to perform discovery on.</param> /// <param name="requireSsl">Whether discovery should fail if any step of it is not encrypted.</param> + /// <param name="cancellationToken">The cancellation token.</param> /// <returns> /// The result of discovery on the given URL. /// Null may be returned if an error occurs, - /// or if <paramref name="requireSsl"/> is true but part of discovery + /// or if <paramref name="requireSsl" /> is true but part of discovery /// is not protected by SSL. /// </returns> - public static DiscoveryResult Discover(IDirectWebRequestHandler requestHandler, UriIdentifier uri, bool requireSsl) { - CachedDirectWebResponse response; + public static async Task<DiscoveryResult> DiscoverAsync(IHostFactories hostFactories, UriIdentifier uri, bool requireSsl, CancellationToken cancellationToken) { + Requires.NotNull(hostFactories, "hostFactories"); + Requires.NotNull(uri, "uri"); + + HttpResponseMessage response; try { if (requireSsl && !string.Equals(uri.Uri.Scheme, Uri.UriSchemeHttps, StringComparison.OrdinalIgnoreCase)) { Logger.Yadis.WarnFormat("Discovery on insecure identifier '{0}' aborted.", uri); return null; } - response = Request(requestHandler, uri, requireSsl, ContentTypes.Html, ContentTypes.XHtml, ContentTypes.Xrds).GetSnapshot(MaximumResultToScan); - if (response.Status != System.Net.HttpStatusCode.OK) { - Logger.Yadis.ErrorFormat("HTTP error {0} {1} while performing discovery on {2}.", (int)response.Status, response.Status, uri); + + response = await RequestAsync(uri, requireSsl, hostFactories, cancellationToken, ContentTypes.Html, ContentTypes.XHtml, ContentTypes.Xrds); + if (response.StatusCode != System.Net.HttpStatusCode.OK) { + Logger.Yadis.ErrorFormat("HTTP error {0} {1} while performing discovery on {2}.", (int)response.StatusCode, response.StatusCode, uri); return null; } + + await response.Content.LoadIntoBufferAsync(); } catch (ArgumentException ex) { // Unsafe URLs generate this Logger.Yadis.WarnFormat("Unsafe OpenId URL detected ({0}). Request aborted. {1}", uri, ex); return null; } - CachedDirectWebResponse response2 = null; - if (IsXrdsDocument(response)) { + HttpResponseMessage response2 = null; + if (await IsXrdsDocumentAsync(response)) { Logger.Yadis.Debug("An XRDS response was received from GET at user-supplied identifier."); Reporting.RecordEventOccurrence("Yadis", "XRDS in initial response"); response2 = response; } else { - string uriString = response.Headers.Get(HeaderName); + IEnumerable<string> uriStrings; + string uriString = null; + if (response.Headers.TryGetValues(HeaderName, out uriStrings)) { + uriString = uriStrings.FirstOrDefault(); + } + Uri url = null; if (uriString != null) { if (Uri.TryCreate(uriString, UriKind.Absolute, out url)) { @@ -84,8 +102,10 @@ namespace DotNetOpenAuth.Yadis { Reporting.RecordEventOccurrence("Yadis", "XRDS referenced in HTTP header"); } } - if (url == null && response.ContentType != null && (response.ContentType.MediaType == ContentTypes.Html || response.ContentType.MediaType == ContentTypes.XHtml)) { - url = FindYadisDocumentLocationInHtmlMetaTags(response.GetResponseString()); + + var contentType = response.Content.Headers.ContentType; + if (url == null && contentType != null && (contentType.MediaType == ContentTypes.Html || contentType.MediaType == ContentTypes.XHtml)) { + url = FindYadisDocumentLocationInHtmlMetaTags(await response.Content.ReadAsStringAsync()); if (url != null) { Logger.Yadis.DebugFormat("{0} found in HTML Http-Equiv tag. Preparing to pull XRDS from {1}", HeaderName, url); Reporting.RecordEventOccurrence("Yadis", "XRDS referenced in HTML"); @@ -93,16 +113,17 @@ namespace DotNetOpenAuth.Yadis { } if (url != null) { if (!requireSsl || string.Equals(url.Scheme, Uri.UriSchemeHttps, StringComparison.OrdinalIgnoreCase)) { - response2 = Request(requestHandler, url, requireSsl, ContentTypes.Xrds).GetSnapshot(MaximumResultToScan); - if (response2.Status != HttpStatusCode.OK) { - Logger.Yadis.ErrorFormat("HTTP error {0} {1} while performing discovery on {2}.", (int)response2.Status, response2.Status, uri); + response2 = await RequestAsync(url, requireSsl, hostFactories, cancellationToken, ContentTypes.Xrds); + if (response2.StatusCode != HttpStatusCode.OK) { + Logger.Yadis.ErrorFormat("HTTP error {0} {1} while performing discovery on {2}.", (int)response2.StatusCode, response2.StatusCode, uri); } } else { Logger.Yadis.WarnFormat("XRDS document at insecure location '{0}'. Aborting YADIS discovery.", url); } } } - return new DiscoveryResult(uri, response, response2); + + return await DiscoveryResult.CreateAsync(uri, response, response2); } /// <summary> @@ -129,43 +150,45 @@ namespace DotNetOpenAuth.Yadis { /// <summary> /// Sends a YADIS HTTP request as part of identifier discovery. /// </summary> - /// <param name="requestHandler">The request handler to use to actually submit the request.</param> /// <param name="uri">The URI to GET.</param> /// <param name="requireSsl">Whether only HTTPS URLs should ever be retrieved.</param> + /// <param name="hostFactories">The host factories.</param> + /// <param name="cancellationToken">The cancellation token.</param> /// <param name="acceptTypes">The value of the Accept HTTP header to include in the request.</param> - /// <returns>The HTTP response retrieved from the request.</returns> - internal static IncomingWebResponse Request(IDirectWebRequestHandler requestHandler, Uri uri, bool requireSsl, params string[] acceptTypes) { - Requires.NotNull(requestHandler, "requestHandler"); + /// <returns> + /// The HTTP response retrieved from the request. + /// </returns> + internal static async Task<HttpResponseMessage> RequestAsync(Uri uri, bool requireSsl, IHostFactories hostFactories, CancellationToken cancellationToken, params string[] acceptTypes) { Requires.NotNull(uri, "uri"); + Requires.NotNull(hostFactories, "hostFactories"); - HttpWebRequest request = (HttpWebRequest)WebRequest.Create(uri); - request.CachePolicy = IdentifierDiscoveryCachePolicy; - if (acceptTypes != null) { - request.Accept = string.Join(",", acceptTypes); - } - - DirectWebRequestOptions options = DirectWebRequestOptions.None; - if (requireSsl) { - options |= DirectWebRequestOptions.RequireSsl; - } + using (var httpClient = hostFactories.CreateHttpClient(requireSsl, IdentifierDiscoveryCachePolicy)) { + var request = new HttpRequestMessage(HttpMethod.Get, uri); + if (acceptTypes != null) { + request.Headers.Accept.AddRange(acceptTypes.Select(at => new MediaTypeWithQualityHeaderValue(at))); + } - try { - return requestHandler.GetResponse(request, options); - } catch (ProtocolException ex) { - var webException = ex.InnerException as WebException; - if (webException != null) { - var response = webException.Response as HttpWebResponse; - if (response != null && response.IsFromCache) { + HttpResponseMessage response = null; + try { + response = await httpClient.SendAsync(request, cancellationToken); + // http://stackoverflow.com/questions/14103154/how-to-determine-if-an-httpresponsemessage-was-fulfilled-from-cache-using-httpcl + if (!response.IsSuccessStatusCode && response.Headers.Age.HasValue && response.Headers.Age.Value > TimeSpan.Zero) { // We don't want to report error responses from the cache, since the server may have fixed // whatever was causing the problem. So try again with cache disabled. - Logger.Messaging.Error("An HTTP error response was obtained from the cache. Retrying with cache disabled.", ex); + Logger.Messaging.ErrorFormat("An HTTP {0} response was obtained from the cache. Retrying with cache disabled.", response.StatusCode); + response.Dispose(); // discard the old one + var nonCachingRequest = request.Clone(); - nonCachingRequest.CachePolicy = new HttpRequestCachePolicy(HttpRequestCacheLevel.Reload); - return requestHandler.GetResponse(nonCachingRequest, options); + using (var nonCachingHttpClient = hostFactories.CreateHttpClient(requireSsl, new RequestCachePolicy(RequestCacheLevel.Reload))) { + response = await nonCachingHttpClient.SendAsync(nonCachingRequest, cancellationToken); + } } - } - throw; + return response; + } catch { + response.DisposeIfNotNull(); + throw; + } } } @@ -176,25 +199,26 @@ namespace DotNetOpenAuth.Yadis { /// <returns> /// <c>true</c> if the response constains an XRDS document; otherwise, <c>false</c>. /// </returns> - private static bool IsXrdsDocument(CachedDirectWebResponse response) { - if (response.ContentType == null) { + private static async Task<bool> IsXrdsDocumentAsync(HttpResponseMessage response) { + if (response.Content.Headers.ContentType == null) { return false; } - if (response.ContentType.MediaType == ContentTypes.Xrds) { + if (response.Content.Headers.ContentType.MediaType == ContentTypes.Xrds) { return true; } - if (response.ContentType.MediaType == ContentTypes.Xml) { + if (response.Content.Headers.ContentType.MediaType == ContentTypes.Xml) { // This COULD be an XRDS document with an imprecise content-type. - response.ResponseStream.Seek(0, SeekOrigin.Begin); - var readerSettings = MessagingUtilities.CreateUntrustedXmlReaderSettings(); - XmlReader reader = XmlReader.Create(response.ResponseStream, readerSettings); - while (reader.Read() && reader.NodeType != XmlNodeType.Element) { - // intentionally blank - } - if (reader.NamespaceURI == XrdsNode.XrdsNamespace && reader.Name == "XRDS") { - return true; + using (var responseStream = await response.Content.ReadAsStreamAsync()) { + var readerSettings = MessagingUtilities.CreateUntrustedXmlReaderSettings(); + XmlReader reader = XmlReader.Create(responseStream, readerSettings); + while (await reader.ReadAsync() && reader.NodeType != XmlNodeType.Element) { + // intentionally blank + } + if (reader.NamespaceURI == XrdsNode.XrdsNamespace && reader.Name == "XRDS") { + return true; + } } } diff --git a/src/DotNetOpenAuth.OpenId/packages.config b/src/DotNetOpenAuth.OpenId/packages.config index 58890d8..d32d62f 100644 --- a/src/DotNetOpenAuth.OpenId/packages.config +++ b/src/DotNetOpenAuth.OpenId/packages.config @@ -1,4 +1,5 @@ <?xml version="1.0" encoding="utf-8"?> <packages> - <package id="Validation" version="2.0.1.12362" targetFramework="net45" /> + <package id="Microsoft.Net.Http" version="2.0.20710.0" targetFramework="net45" /> + <package id="Validation" version="2.0.2.13022" targetFramework="net45" /> </packages>
\ No newline at end of file |