diff options
author | Andrew Arnott <andrewarnott@gmail.com> | 2012-12-31 22:54:20 -0800 |
---|---|---|
committer | Andrew Arnott <andrewarnott@gmail.com> | 2012-12-31 22:54:20 -0800 |
commit | 30cdda15c5e8b6db0d7260697c0a13c06943afec (patch) | |
tree | 4b3ea104a96a617502bc11193ad2a99c59d917d7 | |
parent | 90b6aa8ba9d15e0254eccf05b73b24f334128654 (diff) | |
download | DotNetOpenAuth-30cdda15c5e8b6db0d7260697c0a13c06943afec.zip DotNetOpenAuth-30cdda15c5e8b6db0d7260697c0a13c06943afec.tar.gz DotNetOpenAuth-30cdda15c5e8b6db0d7260697c0a13c06943afec.tar.bz2 |
DNOA.OpenId.RP now builds.
30 files changed, 505 insertions, 386 deletions
diff --git a/src/DotNetOpenAuth.Core/Messaging/Bindings/StandardExpirationBindingElement.cs b/src/DotNetOpenAuth.Core/Messaging/Bindings/StandardExpirationBindingElement.cs index 7ab78db..ace4cf5 100644 --- a/src/DotNetOpenAuth.Core/Messaging/Bindings/StandardExpirationBindingElement.cs +++ b/src/DotNetOpenAuth.Core/Messaging/Bindings/StandardExpirationBindingElement.cs @@ -6,6 +6,8 @@ namespace DotNetOpenAuth.Messaging.Bindings { using System; + using System.Threading; + using System.Threading.Tasks; using DotNetOpenAuth.Configuration; /// <summary> @@ -13,6 +15,9 @@ namespace DotNetOpenAuth.Messaging.Bindings { /// implementing the <see cref="IExpiringProtocolMessage"/> interface. /// </summary> internal class StandardExpirationBindingElement : IChannelBindingElement { + private static readonly Task<MessageProtections?> NullTask = Task.FromResult<MessageProtections?>(null); + private static readonly Task<MessageProtections?> CompletedExpirationTask = Task.FromResult<MessageProtections?>(MessageProtections.Expiration); + /// <summary> /// Initializes a new instance of the <see cref="StandardExpirationBindingElement"/> class. /// </summary> @@ -55,14 +60,14 @@ namespace DotNetOpenAuth.Messaging.Bindings { /// The protections (if any) that this binding element applied to the message. /// Null if this binding element did not even apply to this binding element. /// </returns> - public MessageProtections? ProcessOutgoingMessage(IProtocolMessage message) { + public Task<MessageProtections?> ProcessOutgoingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { IExpiringProtocolMessage expiringMessage = message as IExpiringProtocolMessage; if (expiringMessage != null) { expiringMessage.UtcCreationDate = DateTime.UtcNow; - return MessageProtections.Expiration; + return CompletedExpirationTask; } - return null; + return NullTask; } /// <summary> @@ -78,7 +83,7 @@ namespace DotNetOpenAuth.Messaging.Bindings { /// Thrown when the binding element rules indicate that this message is invalid and should /// NOT be processed. /// </exception> - public MessageProtections? ProcessIncomingMessage(IProtocolMessage message) { + public Task<MessageProtections?> ProcessIncomingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { IExpiringProtocolMessage expiringMessage = message as IExpiringProtocolMessage; if (expiringMessage != null) { // Yes the UtcCreationDate is supposed to always be in UTC already, @@ -96,7 +101,7 @@ namespace DotNetOpenAuth.Messaging.Bindings { MessagingStrings.MessageTimestampInFuture, creationDate); - return MessageProtections.Expiration; + return CompletedExpirationTask; } return null; diff --git a/src/DotNetOpenAuth.Core/Messaging/Bindings/StandardReplayProtectionBindingElement.cs b/src/DotNetOpenAuth.Core/Messaging/Bindings/StandardReplayProtectionBindingElement.cs index 45bccdf..2502742 100644 --- a/src/DotNetOpenAuth.Core/Messaging/Bindings/StandardReplayProtectionBindingElement.cs +++ b/src/DotNetOpenAuth.Core/Messaging/Bindings/StandardReplayProtectionBindingElement.cs @@ -7,12 +7,17 @@ namespace DotNetOpenAuth.Messaging.Bindings { using System; using System.Diagnostics; + using System.Threading; + using System.Threading.Tasks; using Validation; /// <summary> /// A binding element that checks/verifies a nonce message part. /// </summary> internal class StandardReplayProtectionBindingElement : IChannelBindingElement { + private static readonly Task<MessageProtections?> NullTask = Task.FromResult<MessageProtections?>(null); + private static readonly Task<MessageProtections?> CompletedReplayProtectionTask = Task.FromResult<MessageProtections?>(MessageProtections.ReplayProtection); + /// <summary> /// These are the characters that may be chosen from when forming a random nonce. /// </summary> @@ -100,14 +105,14 @@ namespace DotNetOpenAuth.Messaging.Bindings { /// The protections (if any) that this binding element applied to the message. /// Null if this binding element did not even apply to this binding element. /// </returns> - public MessageProtections? ProcessOutgoingMessage(IProtocolMessage message) { + public Task<MessageProtections?> ProcessOutgoingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { IReplayProtectedProtocolMessage nonceMessage = message as IReplayProtectedProtocolMessage; if (nonceMessage != null) { nonceMessage.Nonce = this.GenerateUniqueFragment(); - return MessageProtections.ReplayProtection; + return CompletedReplayProtectionTask; } - return null; + return NullTask; } /// <summary> @@ -119,7 +124,7 @@ namespace DotNetOpenAuth.Messaging.Bindings { /// Null if this binding element did not even apply to this binding element. /// </returns> /// <exception cref="ReplayedMessageException">Thrown when the nonce check revealed a replayed message.</exception> - public MessageProtections? ProcessIncomingMessage(IProtocolMessage message) { + public Task<MessageProtections?> ProcessIncomingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { IReplayProtectedProtocolMessage nonceMessage = message as IReplayProtectedProtocolMessage; if (nonceMessage != null && nonceMessage.Nonce != null) { ErrorUtilities.VerifyProtocol(nonceMessage.Nonce.Length > 0 || this.AllowZeroLengthNonce, MessagingStrings.InvalidNonceReceived); @@ -129,10 +134,10 @@ namespace DotNetOpenAuth.Messaging.Bindings { throw new ReplayedMessageException(message); } - return MessageProtections.ReplayProtection; + return CompletedReplayProtectionTask; } - return null; + return NullTask; } #endregion diff --git a/src/DotNetOpenAuth.Core/Messaging/Channel.cs b/src/DotNetOpenAuth.Core/Messaging/Channel.cs index bff44e7..4d35c90 100644 --- a/src/DotNetOpenAuth.Core/Messaging/Channel.cs +++ b/src/DotNetOpenAuth.Core/Messaging/Channel.cs @@ -279,10 +279,10 @@ namespace DotNetOpenAuth.Messaging { /// </summary> /// <param name="message">The one-way message to send</param> /// <returns>The pending user agent redirect based message to be sent as an HttpResponse.</returns> - public HttpResponseMessage PrepareResponse(IProtocolMessage message) { + public async Task<HttpResponseMessage> PrepareResponseAsync(IProtocolMessage message, CancellationToken cancellationToken) { Requires.NotNull(message, "message"); - this.ProcessOutgoingMessage(message); + await this.ProcessOutgoingMessageAsync(message, cancellationToken); Logger.Channel.DebugFormat("Sending message: {0}", message.GetType().Name); HttpResponseMessage result; @@ -333,24 +333,8 @@ namespace DotNetOpenAuth.Messaging { /// Requires an HttpContext.Current context. /// </remarks> /// <exception cref="InvalidOperationException">Thrown when <see cref="HttpContext.Current"/> is null.</exception> - public IDirectedProtocolMessage ReadFromRequest() { - return this.ReadFromRequest(this.GetRequestFromContext()); - } - - /// <summary> - /// Gets the protocol message embedded in the given HTTP request, if present. - /// </summary> - /// <typeparam name="TRequest">The expected type of the message to be received.</typeparam> - /// <param name="request">The deserialized message, if one is found. Null otherwise.</param> - /// <returns>True if the expected message was recognized and deserialized. False otherwise.</returns> - /// <remarks> - /// Requires an HttpContext.Current context. - /// </remarks> - /// <exception cref="InvalidOperationException">Thrown when <see cref="HttpContext.Current"/> is null.</exception> - /// <exception cref="ProtocolException">Thrown when a request message of an unexpected type is received.</exception> - public bool TryReadFromRequest<TRequest>(out TRequest request) - where TRequest : class, IProtocolMessage { - return TryReadFromRequest<TRequest>(this.GetRequestFromContext(), out request); + public Task<IDirectedProtocolMessage> ReadFromRequestAsync(CancellationToken cancellationToken) { + return this.ReadFromRequestAsync(this.GetRequestFromContext(), cancellationToken); } /// <summary> @@ -362,36 +346,18 @@ namespace DotNetOpenAuth.Messaging { /// <returns>True if the expected message was recognized and deserialized. False otherwise.</returns> /// <exception cref="InvalidOperationException">Thrown when <see cref="HttpContext.Current"/> is null.</exception> /// <exception cref="ProtocolException">Thrown when a request message of an unexpected type is received.</exception> - public bool TryReadFromRequest<TRequest>(HttpRequestBase httpRequest, out TRequest request) + public async Task<TRequest> TryReadFromRequestAsync<TRequest>(CancellationToken cancellationToken, HttpRequestBase httpRequest = null) where TRequest : class, IProtocolMessage { - Requires.NotNull(httpRequest, "httpRequest"); + httpRequest = httpRequest ?? this.GetRequestFromContext(); - IProtocolMessage untypedRequest = this.ReadFromRequest(httpRequest); + IProtocolMessage untypedRequest = await this.ReadFromRequestAsync(httpRequest, cancellationToken); if (untypedRequest == null) { - request = null; - return false; + return null; } - request = untypedRequest as TRequest; + var request = untypedRequest as TRequest; ErrorUtilities.VerifyProtocol(request != null, MessagingStrings.UnexpectedMessageReceived, typeof(TRequest), untypedRequest.GetType()); - - return true; - } - - /// <summary> - /// Gets the protocol message embedded in the current HTTP request. - /// </summary> - /// <typeparam name="TRequest">The expected type of the message to be received.</typeparam> - /// <returns>The deserialized message. Never null.</returns> - /// <remarks> - /// Requires an HttpContext.Current context. - /// </remarks> - /// <exception cref="InvalidOperationException">Thrown when <see cref="HttpContext.Current"/> is null.</exception> - /// <exception cref="ProtocolException">Thrown if the expected message was not recognized in the response.</exception> - [SuppressMessage("Microsoft.Design", "CA1004:GenericMethodsShouldProvideTypeParameter", Justification = "This returns and verifies the appropriate message type.")] - public TRequest ReadFromRequest<TRequest>() - where TRequest : class, IProtocolMessage { - return this.ReadFromRequest<TRequest>(this.GetRequestFromContext()); + return request; } /// <summary> @@ -402,15 +368,12 @@ namespace DotNetOpenAuth.Messaging { /// <returns>The deserialized message. Never null.</returns> /// <exception cref="ProtocolException">Thrown if the expected message was not recognized in the response.</exception> [SuppressMessage("Microsoft.Design", "CA1004:GenericMethodsShouldProvideTypeParameter", Justification = "This returns and verifies the appropriate message type.")] - public TRequest ReadFromRequest<TRequest>(HttpRequestBase httpRequest) + public async Task<TRequest> ReadFromRequestAsync<TRequest>(CancellationToken cancellationToken, HttpRequestBase httpRequest = null) where TRequest : class, IProtocolMessage { - Requires.NotNull(httpRequest, "httpRequest"); - TRequest request; - if (this.TryReadFromRequest<TRequest>(httpRequest, out request)) { - return request; - } else { - throw ErrorUtilities.ThrowProtocol(MessagingStrings.ExpectedMessageNotReceived, typeof(TRequest)); - } + httpRequest = httpRequest ?? this.GetRequestFromContext(); + TRequest request = await this.TryReadFromRequestAsync<TRequest>(cancellationToken, httpRequest); + ErrorUtilities.VerifyProtocol(request != null, MessagingStrings.ExpectedMessageNotReceived, typeof(TRequest)); + return request; } /// <summary> @@ -418,13 +381,13 @@ namespace DotNetOpenAuth.Messaging { /// </summary> /// <param name="httpRequest">The request to search for an embedded message.</param> /// <returns>The deserialized message, if one is found. Null otherwise.</returns> - public IDirectedProtocolMessage ReadFromRequest(HttpRequestBase httpRequest) { + public async Task<IDirectedProtocolMessage> ReadFromRequestAsync(HttpRequestBase httpRequest, CancellationToken cancellationToken) { Requires.NotNull(httpRequest, "httpRequest"); if (Logger.Channel.IsInfoEnabled && httpRequest.GetPublicFacingUrl() != null) { Logger.Channel.InfoFormat("Scanning incoming request for messages: {0}", httpRequest.GetPublicFacingUrl().AbsoluteUri); } - IDirectedProtocolMessage requestMessage = this.ReadFromRequestCore(httpRequest); + IDirectedProtocolMessage requestMessage = this.ReadFromRequestCore(httpRequest, cancellationToken); if (requestMessage != null) { Logger.Channel.DebugFormat("Incoming request received: {0}", requestMessage.GetType().Name); @@ -435,7 +398,7 @@ namespace DotNetOpenAuth.Messaging { } } - this.ProcessIncomingMessage(requestMessage); + await this.ProcessIncomingMessageAsync(requestMessage, cancellationToken); } return requestMessage; @@ -446,17 +409,18 @@ namespace DotNetOpenAuth.Messaging { /// </summary> /// <typeparam name="TResponse">The expected type of the message to be received.</typeparam> /// <param name="requestMessage">The message to send.</param> - /// <returns>The remote party's response.</returns> - /// <exception cref="ProtocolException"> - /// Thrown if no message is recognized in the response - /// or an unexpected type of message is received. - /// </exception> + /// <param name="cancellationToken">The cancellation token.</param> + /// <returns> + /// The remote party's response. + /// </returns> + /// <exception cref="ProtocolException">Thrown if no message is recognized in the response + /// or an unexpected type of message is received.</exception> [SuppressMessage("Microsoft.Design", "CA1004:GenericMethodsShouldProvideTypeParameter", Justification = "This returns and verifies the appropriate message type.")] - public async Task<TResponse> RequestAsync<TResponse>(IDirectedProtocolMessage requestMessage) + public async Task<TResponse> RequestAsync<TResponse>(IDirectedProtocolMessage requestMessage, CancellationToken cancellationToken) where TResponse : class, IProtocolMessage { Requires.NotNull(requestMessage, "requestMessage"); - IProtocolMessage response = await this.RequestAsync(requestMessage); + IProtocolMessage response = await this.RequestAsync(requestMessage, cancellationToken); ErrorUtilities.VerifyProtocol(response != null, MessagingStrings.ExpectedMessageNotReceived, typeof(TResponse)); var expectedResponse = response as TResponse; @@ -469,18 +433,21 @@ namespace DotNetOpenAuth.Messaging { /// Sends a direct message to a remote party and waits for the response. /// </summary> /// <param name="requestMessage">The message to send.</param> - /// <returns>The remote party's response. Guaranteed to never be null.</returns> + /// <param name="cancellationToken">The cancellation token.</param> + /// <returns> + /// The remote party's response. Guaranteed to never be null. + /// </returns> /// <exception cref="ProtocolException">Thrown if the response does not include a protocol message.</exception> - public async Task<IProtocolMessage> RequestAsync(IDirectedProtocolMessage requestMessage) { + public async Task<IProtocolMessage> RequestAsync(IDirectedProtocolMessage requestMessage, CancellationToken cancellationToken) { Requires.NotNull(requestMessage, "requestMessage"); - this.ProcessOutgoingMessage(requestMessage); + await this.ProcessOutgoingMessageAsync(requestMessage, cancellationToken); Logger.Channel.DebugFormat("Sending {0} request.", requestMessage.GetType().Name); - var responseMessage = await this.RequestCoreAsync(requestMessage); + var responseMessage = await this.RequestCoreAsync(requestMessage, cancellationToken); ErrorUtilities.VerifyProtocol(responseMessage != null, MessagingStrings.ExpectedMessageNotReceived, typeof(IProtocolMessage).Name); Logger.Channel.DebugFormat("Received {0} response.", responseMessage.GetType().Name); - this.ProcessIncomingMessage(responseMessage); + await this.ProcessIncomingMessageAsync(responseMessage, cancellationToken); return responseMessage; } @@ -505,8 +472,8 @@ namespace DotNetOpenAuth.Messaging { /// Thrown when the message is somehow invalid. /// This can be due to tampering, replay attack or expiration, among other things. /// </exception> - internal void ProcessIncomingMessageTestHook(IProtocolMessage message) { - this.ProcessIncomingMessage(message); + internal Task ProcessIncomingMessageTestHookAsync(IProtocolMessage message, CancellationToken cancellationToken) { + return this.ProcessIncomingMessageAsync(message, cancellationToken); } /// <summary> @@ -545,16 +512,18 @@ namespace DotNetOpenAuth.Messaging { return this.ReadFromResponseCoreAsync(response); } - /// <remarks> - /// This method should NOT be called by derived types - /// except when sending ONE WAY request messages. - /// </remarks> /// <summary> /// Prepares a message for transmit by applying signatures, nonces, etc. /// </summary> /// <param name="message">The message to prepare for sending.</param> - internal void ProcessOutgoingMessageTestHook(IProtocolMessage message) { - this.ProcessOutgoingMessage(message); + /// <param name="cancellationToken">The cancellation token.</param> + /// <returns></returns> + /// <remarks> + /// This method should NOT be called by derived types + /// except when sending ONE WAY request messages. + /// </remarks> + internal Task ProcessOutgoingMessageTestHookAsync(IProtocolMessage message, CancellationToken cancellationToken) { + return this.ProcessOutgoingMessageAsync(message, cancellationToken); } /// <summary> @@ -658,7 +627,7 @@ namespace DotNetOpenAuth.Messaging { /// behavior. However in non-HTTP frameworks, such as unit test mocks, it may be appropriate to override /// this method to eliminate all use of an HTTP transport. /// </remarks> - protected virtual async Task<IProtocolMessage> RequestCoreAsync(IDirectedProtocolMessage request) { + protected virtual async Task<IProtocolMessage> RequestCoreAsync(IDirectedProtocolMessage request, CancellationToken cancellationToken) { Requires.NotNull(request, "request"); Requires.That(request.Recipient != null, "request", MessagingStrings.DirectedMessageMissingRecipient); @@ -674,7 +643,8 @@ namespace DotNetOpenAuth.Messaging { IDirectResponseProtocolMessage responseMessage; using (var httpClient = this.HostFactories.CreateHttpClient()) { - using (HttpResponseMessage response = await httpClient.SendAsync(webRequest)) { + using (HttpResponseMessage response = await httpClient.SendAsync(webRequest, cancellationToken)) { + response.EnsureSuccessStatusCode(); if (response.Content == null) { return null; } @@ -717,7 +687,7 @@ namespace DotNetOpenAuth.Messaging { /// </summary> /// <param name="request">The request to search for an embedded message.</param> /// <returns>The deserialized message, if one is found. Null otherwise.</returns> - protected virtual IDirectedProtocolMessage ReadFromRequestCore(HttpRequestBase request) { + protected virtual IDirectedProtocolMessage ReadFromRequestCore(HttpRequestBase request, CancellationToken cancellationToken) { Requires.NotNull(request, "request"); Logger.Channel.DebugFormat("Incoming HTTP request: {0} {1}", request.HttpMethod, request.GetPublicFacingUrl().AbsoluteUri); @@ -843,7 +813,7 @@ namespace DotNetOpenAuth.Messaging { }; response.Headers.Location = builder.Uri; - response.Content.Headers.ContentType = new MediaTypeHeaderValue("text/html; charset=utf-8"); + response.Content.Headers.ContentType = new MediaTypeHeaderValue("text/html; charset=utf-8"); return response; } @@ -949,11 +919,14 @@ namespace DotNetOpenAuth.Messaging { /// Prepares a message for transmit by applying signatures, nonces, etc. /// </summary> /// <param name="message">The message to prepare for sending.</param> + /// <param name="cancellationToken">The cancellation token.</param> + /// <returns></returns> + /// <exception cref="UnprotectedMessageException"></exception> /// <remarks> /// This method should NOT be called by derived types /// except when sending ONE WAY request messages. /// </remarks> - protected void ProcessOutgoingMessage(IProtocolMessage message) { + protected async Task ProcessOutgoingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { Requires.NotNull(message, "message"); Logger.Channel.DebugFormat("Preparing to send {0} ({1}) message.", message.GetType().Name, message.Version); @@ -968,7 +941,7 @@ namespace DotNetOpenAuth.Messaging { MessageProtections appliedProtection = MessageProtections.None; foreach (IChannelBindingElement bindingElement in this.outgoingBindingElements) { Assumes.True(bindingElement.Channel != null); - MessageProtections? elementProtection = bindingElement.ProcessOutgoingMessage(message); + MessageProtections? elementProtection = await bindingElement.ProcessOutgoingMessageAsync(message, cancellationToken); if (elementProtection.HasValue) { Logger.Bindings.DebugFormat("Binding element {0} applied to message.", bindingElement.GetType().FullName); @@ -1125,7 +1098,7 @@ namespace DotNetOpenAuth.Messaging { /// Thrown when the message is somehow invalid. /// This can be due to tampering, replay attack or expiration, among other things. /// </exception> - protected virtual void ProcessIncomingMessage(IProtocolMessage message) { + protected virtual async Task ProcessIncomingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { Requires.NotNull(message, "message"); if (Logger.Channel.IsInfoEnabled) { @@ -1141,7 +1114,7 @@ namespace DotNetOpenAuth.Messaging { MessageProtections appliedProtection = MessageProtections.None; foreach (IChannelBindingElement bindingElement in this.IncomingBindingElements) { Assumes.True(bindingElement.Channel != null); // CC bug: this.IncomingBindingElements ensures this... why must we assume it here? - MessageProtections? elementProtection = bindingElement.ProcessIncomingMessage(message); + MessageProtections? elementProtection = await bindingElement.ProcessIncomingMessageAsync(message, cancellationToken); if (elementProtection.HasValue) { Logger.Bindings.DebugFormat("Binding element {0} applied to message.", bindingElement.GetType().FullName); diff --git a/src/DotNetOpenAuth.Core/Messaging/IChannelBindingElement.cs b/src/DotNetOpenAuth.Core/Messaging/IChannelBindingElement.cs index fca46a0..dc026a4 100644 --- a/src/DotNetOpenAuth.Core/Messaging/IChannelBindingElement.cs +++ b/src/DotNetOpenAuth.Core/Messaging/IChannelBindingElement.cs @@ -6,6 +6,8 @@ namespace DotNetOpenAuth.Messaging { using System; + using System.Threading; + using System.Threading.Tasks; using Validation; /// <summary> @@ -41,7 +43,7 @@ namespace DotNetOpenAuth.Messaging { /// Implementations that provide message protection must honor the /// <see cref="MessagePartAttribute.RequiredProtection"/> properties where applicable. /// </remarks> - MessageProtections? ProcessOutgoingMessage(IProtocolMessage message); + Task<MessageProtections?> ProcessOutgoingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken); /// <summary> /// Performs any transformation on an incoming message that may be necessary and/or @@ -60,6 +62,6 @@ namespace DotNetOpenAuth.Messaging { /// Implementations that provide message protection must honor the /// <see cref="MessagePartAttribute.RequiredProtection"/> properties where applicable. /// </remarks> - MessageProtections? ProcessIncomingMessage(IProtocolMessage message); + Task<MessageProtections?> ProcessIncomingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken); } } diff --git a/src/DotNetOpenAuth.Core/Messaging/ProtocolFaultResponseException.cs b/src/DotNetOpenAuth.Core/Messaging/ProtocolFaultResponseException.cs index 3c8839e..c77577b 100644 --- a/src/DotNetOpenAuth.Core/Messaging/ProtocolFaultResponseException.cs +++ b/src/DotNetOpenAuth.Core/Messaging/ProtocolFaultResponseException.cs @@ -10,6 +10,8 @@ namespace DotNetOpenAuth.Messaging { using System.Linq; using System.Net.Http; using System.Text; + using System.Threading; + using System.Threading.Tasks; using Validation; /// <summary> @@ -63,9 +65,8 @@ namespace DotNetOpenAuth.Messaging { /// Creates the HTTP response to forward to the client to report the error. /// </summary> /// <returns>The HTTP response.</returns> - public HttpResponseMessage CreateErrorResponse() { - var response = this.channel.PrepareResponse(this.ErrorResponseMessage); - return response; + public Task<HttpResponseMessage> CreateErrorResponse(CancellationToken cancellationToken) { + return this.channel.PrepareResponseAsync(this.ErrorResponseMessage, cancellationToken); } } } diff --git a/src/DotNetOpenAuth.OpenId.Provider/OpenId/ChannelElements/ProviderSigningBindingElement.cs b/src/DotNetOpenAuth.OpenId.Provider/OpenId/ChannelElements/ProviderSigningBindingElement.cs index fbbf37a..f84860b 100644 --- a/src/DotNetOpenAuth.OpenId.Provider/OpenId/ChannelElements/ProviderSigningBindingElement.cs +++ b/src/DotNetOpenAuth.OpenId.Provider/OpenId/ChannelElements/ProviderSigningBindingElement.cs @@ -9,6 +9,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { using System.Collections.Generic; using System.Linq; using System.Text; + using System.Threading.Tasks; using System.Web; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Bindings; @@ -161,10 +162,12 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// <returns> /// The applied protections. /// </returns> - protected override MessageProtections VerifySignatureByUnrecognizedHandle(IProtocolMessage message, ITamperResistantOpenIdMessage signedMessage, MessageProtections protectionsApplied) { + protected override Task<MessageProtections> VerifySignatureByUnrecognizedHandleAsync(IProtocolMessage message, ITamperResistantOpenIdMessage signedMessage, MessageProtections protectionsApplied) { // If we're on the Provider, then the RP sent us a check_auth with a signature // we don't have an association for. (It may have expired, or it may be a faulty RP). - throw new InvalidSignatureException(message); + var tcs = new TaskCompletionSource<MessageProtections>(); + tcs.SetException(new InvalidSignatureException(message)); + return tcs.Task; } /// <summary> @@ -224,9 +227,9 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { MessageDescription description = this.Channel.MessageDescriptions.Get(signedMessage); var signedParts = from part in description.Mapping.Values - where (part.RequiredProtection & System.Net.Security.ProtectionLevel.Sign) != 0 - && part.GetValue(signedMessage) != null - select part.Name; + where (part.RequiredProtection & System.Net.Security.ProtectionLevel.Sign) != 0 + && part.GetValue(signedMessage) != null + select part.Name; string prefix = Protocol.V20.openid.Prefix; ErrorUtilities.VerifyInternal(signedParts.All(name => name.StartsWith(prefix, StringComparison.Ordinal)), "All signed message parts must start with 'openid.'."); diff --git a/src/DotNetOpenAuth.OpenId.RelyingParty/DotNetOpenAuth.OpenId.RelyingParty.csproj b/src/DotNetOpenAuth.OpenId.RelyingParty/DotNetOpenAuth.OpenId.RelyingParty.csproj index 43fd0ae..5e0eeca 100644 --- a/src/DotNetOpenAuth.OpenId.RelyingParty/DotNetOpenAuth.OpenId.RelyingParty.csproj +++ b/src/DotNetOpenAuth.OpenId.RelyingParty/DotNetOpenAuth.OpenId.RelyingParty.csproj @@ -79,6 +79,8 @@ </ItemGroup> <ItemGroup> <Reference Include="System" /> + <Reference Include="System.Net.Http" /> + <Reference Include="System.Net.Http.WebRequest" /> <Reference Include="Validation"> <HintPath>..\packages\Validation.2.0.1.12362\lib\portable-windows8+net40+sl5+windowsphone8\Validation.dll</HintPath> <Private>True</Private> diff --git a/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/ChannelElements/RelyingPartySecurityOptions.cs b/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/ChannelElements/RelyingPartySecurityOptions.cs index b9328dd..01aa16e 100644 --- a/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/ChannelElements/RelyingPartySecurityOptions.cs +++ b/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/ChannelElements/RelyingPartySecurityOptions.cs @@ -5,6 +5,8 @@ //----------------------------------------------------------------------- namespace DotNetOpenAuth.OpenId.ChannelElements { + using System.Threading; + using System.Threading.Tasks; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId.Messages; using DotNetOpenAuth.OpenId.RelyingParty; @@ -13,6 +15,11 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// Helps ensure compliance to some properties in the <see cref="RelyingPartySecuritySettings"/>. /// </summary> internal class RelyingPartySecurityOptions : IChannelBindingElement { + private static readonly Task<MessageProtections?> NullTask = Task.FromResult<MessageProtections?>(null); + + private static readonly Task<MessageProtections?> NoneTask = + Task.FromResult<MessageProtections?>(MessageProtections.None); + /// <summary> /// The security settings that are active on the relying party. /// </summary> @@ -58,8 +65,8 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// Implementations that provide message protection must honor the /// <see cref="MessagePartAttribute.RequiredProtection"/> properties where applicable. /// </remarks> - public MessageProtections? ProcessOutgoingMessage(IProtocolMessage message) { - return null; + public Task<MessageProtections?> ProcessOutgoingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { + return NullTask; } /// <summary> @@ -79,7 +86,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// Implementations that provide message protection must honor the /// <see cref="MessagePartAttribute.RequiredProtection"/> properties where applicable. /// </remarks> - public MessageProtections? ProcessIncomingMessage(IProtocolMessage message) { + public Task<MessageProtections?> ProcessIncomingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { var positiveAssertion = message as PositiveAssertionResponse; if (positiveAssertion != null) { ErrorUtilities.VerifyProtocol( @@ -87,10 +94,10 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { positiveAssertion.LocalIdentifier == positiveAssertion.ClaimedIdentifier, OpenIdStrings.DelegatingIdentifiersNotAllowed); - return MessageProtections.None; + return NoneTask; } - return null; + return NullTask; } #endregion diff --git a/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/ChannelElements/RelyingPartySigningBindingElement.cs b/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/ChannelElements/RelyingPartySigningBindingElement.cs index 3ec2eee..6a6dee2 100644 --- a/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/ChannelElements/RelyingPartySigningBindingElement.cs +++ b/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/ChannelElements/RelyingPartySigningBindingElement.cs @@ -9,6 +9,8 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { using System.Collections.Generic; using System.Linq; using System.Text; + using System.Threading; + using System.Threading.Tasks; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Bindings; using DotNetOpenAuth.OpenId.Messages; @@ -78,12 +80,12 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// <returns> /// The applied protections. /// </returns> - protected override MessageProtections VerifySignatureByUnrecognizedHandle(IProtocolMessage message, ITamperResistantOpenIdMessage signedMessage, MessageProtections protectionsApplied) { + protected override async Task<MessageProtections> VerifySignatureByUnrecognizedHandleAsync(IProtocolMessage message, ITamperResistantOpenIdMessage signedMessage, MessageProtections protectionsApplied, CancellationToken cancellationToken) { // We did not recognize the association the provider used to sign the message. // Ask the provider to check the signature then. var indirectSignedResponse = (IndirectSignedResponse)signedMessage; var checkSignatureRequest = new CheckAuthenticationRequest(indirectSignedResponse, this.Channel); - var checkSignatureResponse = this.Channel.Request<CheckAuthenticationResponse>(checkSignatureRequest); + var checkSignatureResponse = await this.Channel.RequestAsync<CheckAuthenticationResponse>(checkSignatureRequest, cancellationToken); if (!checkSignatureResponse.IsValid) { Logger.Bindings.Error("Provider reports signature verification failed."); throw new InvalidSignatureException(message); diff --git a/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/ChannelElements/ReturnToNonceBindingElement.cs b/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/ChannelElements/ReturnToNonceBindingElement.cs index c459487..d71a086 100644 --- a/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/ChannelElements/ReturnToNonceBindingElement.cs +++ b/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/ChannelElements/ReturnToNonceBindingElement.cs @@ -9,6 +9,8 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { using System.Collections.Generic; using System.Linq; using System.Text; + using System.Threading; + using System.Threading.Tasks; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Bindings; using DotNetOpenAuth.OpenId.Messages; @@ -46,6 +48,11 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// only on the RP side and only on some messages.</para> /// </remarks> internal class ReturnToNonceBindingElement : IChannelBindingElement { + private static readonly Task<MessageProtections?> NullTask = Task.FromResult<MessageProtections?>(null); + + private static readonly Task<MessageProtections?> ReplayProtectionTask = + Task.FromResult<MessageProtections?>(MessageProtections.ReplayProtection); + /// <summary> /// The context within which return_to nonces must be unique -- they all go into the same bucket. /// </summary> @@ -136,17 +143,17 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// Implementations that provide message protection must honor the /// <see cref="MessagePartAttribute.RequiredProtection"/> properties where applicable. /// </remarks> - public MessageProtections? ProcessOutgoingMessage(IProtocolMessage message) { + public Task<MessageProtections?> ProcessOutgoingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { // We only add a nonce to some auth requests. SignedResponseRequest request = message as SignedResponseRequest; if (this.UseRequestNonce(request)) { request.AddReturnToArguments(Protocol.ReturnToNonceParameter, CustomNonce.NewNonce().Serialize()); request.SignReturnTo = true; // a nonce without a signature is completely pointless - return MessageProtections.ReplayProtection; + return ReplayProtectionTask; } - return null; + return NullTask; } /// <summary> @@ -166,7 +173,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// Implementations that provide message protection must honor the /// <see cref="MessagePartAttribute.RequiredProtection"/> properties where applicable. /// </remarks> - public MessageProtections? ProcessIncomingMessage(IProtocolMessage message) { + public Task<MessageProtections?> ProcessIncomingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { IndirectSignedResponse response = message as IndirectSignedResponse; if (this.UseRequestNonce(response)) { if (!response.ReturnToParametersSignatureValidated) { @@ -190,10 +197,10 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { throw new ReplayedMessageException(message); } - return MessageProtections.ReplayProtection; + return ReplayProtectionTask; } - return null; + return NullTask; } #endregion diff --git a/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/HostMetaDiscoveryService.cs b/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/HostMetaDiscoveryService.cs index 1871f19..6a517ad 100644 --- a/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/HostMetaDiscoveryService.cs +++ b/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/HostMetaDiscoveryService.cs @@ -13,12 +13,17 @@ namespace DotNetOpenAuth.OpenId { using System.IO; using System.Linq; using System.Net; + using System.Net.Cache; + using System.Net.Http; + using System.Net.Http.Headers; using System.Security; using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; using System.Security.Permissions; using System.Text; using System.Text.RegularExpressions; + using System.Threading; + using System.Threading.Tasks; using System.Xml; using System.Xml.XPath; using DotNetOpenAuth.Configuration; @@ -62,10 +67,15 @@ namespace DotNetOpenAuth.OpenId { /// <summary> /// Initializes a new instance of the <see cref="HostMetaDiscoveryService"/> class. /// </summary> - public HostMetaDiscoveryService() { + public HostMetaDiscoveryService(IHostFactories hostFactories) { + Requires.NotNull(hostFactories, "hostFactories"); + this.TrustedHostMetaProxies = new List<HostMetaProxy>(); + this.HostFactories = hostFactories; } + public IHostFactories HostFactories { get; private set; } + /// <summary> /// Gets the set of URI templates to use to contact host-meta hosting proxies /// for domain discovery. @@ -106,27 +116,25 @@ namespace DotNetOpenAuth.OpenId { /// <returns> /// A sequence of service endpoints yielded by discovery. Must not be null, but may be empty. /// </returns> - public IEnumerable<IdentifierDiscoveryResult> Discover(Identifier identifier, IDirectWebRequestHandler requestHandler, out bool abortDiscoveryChain) { - abortDiscoveryChain = false; - + public async Task<IdentifierDiscoveryServiceResult> DiscoverAsync(Identifier identifier, CancellationToken cancellationToken) { // Google Apps are always URIs -- not XRIs. var uriIdentifier = identifier as UriIdentifier; if (uriIdentifier == null) { - return Enumerable.Empty<IdentifierDiscoveryResult>(); + return new IdentifierDiscoveryServiceResult(Enumerable.Empty<IdentifierDiscoveryResult>()); } var results = new List<IdentifierDiscoveryResult>(); - string signingHost; - using (var response = GetXrdsResponse(uriIdentifier, requestHandler, out signingHost)) { - if (response != null) { + using (var response = await this.GetXrdsResponseAsync(uriIdentifier, cancellationToken)) { + if (response.Result != null) { try { var readerSettings = MessagingUtilities.CreateUntrustedXmlReaderSettings(); - var document = new XrdsDocument(XmlReader.Create(response.ResponseStream, readerSettings)); - ValidateXmlDSig(document, uriIdentifier, response, signingHost); + var responseStream = await response.Result.Content.ReadAsStreamAsync(); + var document = new XrdsDocument(XmlReader.Create(responseStream, readerSettings)); + await ValidateXmlDSigAsync(document, uriIdentifier, response.Result, response.SigningHost); var xrds = GetXrdElements(document, uriIdentifier.Uri.Host); // Look for claimed identifier template URIs for an additional XRDS document. - results.AddRange(GetExternalServices(xrds, uriIdentifier, requestHandler)); + results.AddRange(await this.GetExternalServicesAsync(xrds, uriIdentifier, cancellationToken)); // If we couldn't find any claimed identifiers, look for OP identifiers. // Normally this would be the opposite (OP Identifiers take precedence over @@ -136,15 +144,13 @@ namespace DotNetOpenAuth.OpenId { if (results.Count == 0) { results.AddRange(xrds.CreateServiceEndpoints(uriIdentifier, uriIdentifier)); } - - abortDiscoveryChain = true; } catch (XmlException ex) { - Logger.Yadis.ErrorFormat("Error while parsing XRDS document at {0} pointed to by host-meta: {1}", response.FinalUri, ex); + Logger.Yadis.ErrorFormat("Error while parsing XRDS document at {0} pointed to by host-meta: {1}", response.Result.RequestMessage.RequestUri, ex); } } } - return results; + return new IdentifierDiscoveryServiceResult(results, abortDiscoveryChain: true); } #endregion @@ -181,10 +187,9 @@ namespace DotNetOpenAuth.OpenId { /// <param name="identifier">The identifier under discovery.</param> /// <param name="requestHandler">The request handler.</param> /// <returns>The discovered services.</returns> - private static IEnumerable<IdentifierDiscoveryResult> GetExternalServices(IEnumerable<XrdElement> xrds, UriIdentifier identifier, IDirectWebRequestHandler requestHandler) { + private async Task<IEnumerable<IdentifierDiscoveryResult>> GetExternalServicesAsync(IEnumerable<XrdElement> xrds, UriIdentifier identifier, CancellationToken cancellationToken) { Requires.NotNull(xrds, "xrds"); Requires.NotNull(identifier, "identifier"); - Requires.NotNull(requestHandler, "requestHandler"); var results = new List<IdentifierDiscoveryResult>(); foreach (var serviceElement in GetDescribedByServices(xrds)) { @@ -194,10 +199,11 @@ namespace DotNetOpenAuth.OpenId { Uri externalLocation = new Uri(templateNode.Value.Trim().Replace("{%uri}", Uri.EscapeDataString(identifier.Uri.AbsoluteUri))); string nextAuthority = nextAuthorityNode != null ? nextAuthorityNode.Value.Trim() : identifier.Uri.Host; try { - using (var externalXrdsResponse = GetXrdsResponse(identifier, requestHandler, externalLocation)) { + using (var externalXrdsResponse = await this.GetXrdsResponseAsync(identifier, externalLocation, cancellationToken)) { var readerSettings = MessagingUtilities.CreateUntrustedXmlReaderSettings(); - XrdsDocument externalXrds = new XrdsDocument(XmlReader.Create(externalXrdsResponse.ResponseStream, readerSettings)); - ValidateXmlDSig(externalXrds, identifier, externalXrdsResponse, nextAuthority); + var responseStream = await externalXrdsResponse.Content.ReadAsStreamAsync(); + XrdsDocument externalXrds = new XrdsDocument(XmlReader.Create(responseStream, readerSettings)); + await ValidateXmlDSigAsync(externalXrds, identifier, externalXrdsResponse, nextAuthority); results.AddRange(GetXrdElements(externalXrds, identifier).CreateServiceEndpoints(identifier, identifier)); } } catch (ProtocolException ex) { @@ -220,7 +226,7 @@ namespace DotNetOpenAuth.OpenId { /// <param name="signingHost">The host name on the certificate that should be used to verify the signature in the XRDS.</param> /// <exception cref="ProtocolException">Thrown if the XRDS document has an invalid or a missing signature.</exception> [SuppressMessage("Microsoft.Naming", "CA2204:Literals should be spelled correctly", MessageId = "XmlDSig", Justification = "xml")] - private static void ValidateXmlDSig(XrdsDocument document, UriIdentifier identifier, IncomingWebResponse response, string signingHost) { + private static async Task ValidateXmlDSigAsync(XrdsDocument document, UriIdentifier identifier, HttpResponseMessage response, string signingHost) { Requires.NotNull(document, "document"); Requires.NotNull(identifier, "identifier"); Requires.NotNull(response, "response"); @@ -246,11 +252,12 @@ namespace DotNetOpenAuth.OpenId { ErrorUtilities.VerifyProtocol(string.Equals(hostName, signingHost, StringComparison.OrdinalIgnoreCase), OpenIdStrings.MisdirectedSigningCertificate, hostName, signingHost); // Verify the signature itself - byte[] signature = Convert.FromBase64String(response.Headers["Signature"]); + byte[] signature = Convert.FromBase64String(response.Headers.GetValues("Signature").First()); var provider = (RSACryptoServiceProvider)certs.First().PublicKey.Key; - byte[] data = new byte[response.ResponseStream.Length]; - response.ResponseStream.Seek(0, SeekOrigin.Begin); - response.ResponseStream.Read(data, 0, data.Length); + var responseStream = await response.Content.ReadAsStreamAsync(); + byte[] data = new byte[responseStream.Length]; + responseStream.Seek(0, SeekOrigin.Begin); + await responseStream.ReadAsync(data, 0, data.Length); ErrorUtilities.VerifyProtocol(provider.VerifyData(data, "SHA1", signature), OpenIdStrings.InvalidDSig); } @@ -292,21 +299,28 @@ namespace DotNetOpenAuth.OpenId { /// A HTTP response carrying an XRDS document. /// </returns> /// <exception cref="ProtocolException">Thrown if the XRDS document could not be obtained.</exception> - private static IncomingWebResponse GetXrdsResponse(UriIdentifier identifier, IDirectWebRequestHandler requestHandler, Uri xrdsLocation) { + private async Task<HttpResponseMessage> GetXrdsResponseAsync(UriIdentifier identifier, Uri xrdsLocation, CancellationToken cancellationToken) { Requires.NotNull(identifier, "identifier"); - Requires.NotNull(requestHandler, "requestHandler"); Requires.NotNull(xrdsLocation, "xrdsLocation"); - var request = (HttpWebRequest)WebRequest.Create(xrdsLocation); - request.CachePolicy = Yadis.IdentifierDiscoveryCachePolicy; - request.Accept = ContentTypes.Xrds; - var options = identifier.IsDiscoverySecureEndToEnd ? DirectWebRequestOptions.RequireSsl : DirectWebRequestOptions.None; - var response = requestHandler.GetResponse(request, options).GetSnapshot(Yadis.MaximumResultToScan); - if (!string.Equals(response.ContentType.MediaType, ContentTypes.Xrds, StringComparison.Ordinal)) { - Logger.Yadis.WarnFormat("Host-meta pointed to XRDS at {0}, but Content-Type at that URL was unexpected value '{1}'.", xrdsLocation, response.ContentType); - } + using (var httpClient = this.HostFactories.CreateHttpClient(identifier.IsDiscoverySecureEndToEnd, Yadis.IdentifierDiscoveryCachePolicy)) { + var request = new HttpRequestMessage(HttpMethod.Get, xrdsLocation); + request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue(ContentTypes.Xrds)); + var response = await httpClient.SendAsync(request, cancellationToken); + try { + if (!string.Equals(response.Content.Headers.ContentType.MediaType, ContentTypes.Xrds, StringComparison.Ordinal)) { + Logger.Yadis.WarnFormat( + "Host-meta pointed to XRDS at {0}, but Content-Type at that URL was unexpected value '{1}'.", + xrdsLocation, + response.Content.Headers.ContentType); + } - return response; + return response; + } catch { + response.Dispose(); + throw; + } + } } /// <summary> @@ -358,17 +372,16 @@ namespace DotNetOpenAuth.OpenId { /// <param name="signingHost">The host name on the certificate that should be used to verify the signature in the XRDS.</param> /// <returns>A HTTP response carrying an XRDS document, or <c>null</c> if one could not be obtained.</returns> /// <exception cref="ProtocolException">Thrown if the XRDS document could not be obtained.</exception> - private IncomingWebResponse GetXrdsResponse(UriIdentifier identifier, IDirectWebRequestHandler requestHandler, out string signingHost) { + private async Task<ResultWithSigningHost<HttpResponseMessage>> GetXrdsResponseAsync(UriIdentifier identifier, CancellationToken cancellationToken) { Requires.NotNull(identifier, "identifier"); - Requires.NotNull(requestHandler, "requestHandler"); - Uri xrdsLocation = this.GetXrdsLocation(identifier, requestHandler, out signingHost); - if (xrdsLocation == null) { - return null; - } - var response = GetXrdsResponse(identifier, requestHandler, xrdsLocation); + var result = await this.GetXrdsLocationAsync(identifier, cancellationToken); + if (result.Result == null) { + return new ResultWithSigningHost<HttpResponseMessage>(); + } - return response; + var response = await this.GetXrdsResponseAsync(identifier, result.Result, cancellationToken); + return new ResultWithSigningHost<HttpResponseMessage>(response, result.SigningHost); } /// <summary> @@ -378,26 +391,26 @@ namespace DotNetOpenAuth.OpenId { /// <param name="requestHandler">The request handler.</param> /// <param name="signingHost">The host name on the certificate that should be used to verify the signature in the XRDS.</param> /// <returns>An absolute URI, or <c>null</c> if one could not be determined.</returns> - private Uri GetXrdsLocation(UriIdentifier identifier, IDirectWebRequestHandler requestHandler, out string signingHost) { + private async Task<ResultWithSigningHost<Uri>> GetXrdsLocationAsync(UriIdentifier identifier, CancellationToken cancellationToken) { Requires.NotNull(identifier, "identifier"); - Requires.NotNull(requestHandler, "requestHandler"); - using (var hostMetaResponse = this.GetHostMeta(identifier, requestHandler, out signingHost)) { - if (hostMetaResponse == null) { - return null; + + using (var hostMetaResponse = await this.GetHostMetaAsync(identifier, cancellationToken)) { + if (hostMetaResponse.Result == null) { + return new ResultWithSigningHost<Uri>(); } - using (var sr = hostMetaResponse.GetResponseReader()) { - string line = sr.ReadLine(); + using (var sr = new StreamReader(await hostMetaResponse.Result.Content.ReadAsStreamAsync())) { + string line = await sr.ReadLineAsync(); Match m = HostMetaLink.Match(line); if (m.Success) { Uri location = new Uri(m.Groups["location"].Value); - Logger.Yadis.InfoFormat("Found link to XRDS at {0} in host-meta document {1}.", location, hostMetaResponse.FinalUri); - return location; + Logger.Yadis.InfoFormat("Found link to XRDS at {0} in host-meta document {1}.", location, hostMetaResponse.Result.RequestMessage.RequestUri); + return new ResultWithSigningHost<Uri>(location, hostMetaResponse.SigningHost); } } - Logger.Yadis.WarnFormat("Could not find link to XRDS in host-meta document: {0}", hostMetaResponse.FinalUri); - return null; + Logger.Yadis.WarnFormat("Could not find link to XRDS in host-meta document: {0}", hostMetaResponse.Result.RequestMessage.RequestUri); + return new ResultWithSigningHost<Uri>(); } } @@ -410,35 +423,29 @@ namespace DotNetOpenAuth.OpenId { /// <returns> /// The host-meta response, or <c>null</c> if no host-meta document could be obtained. /// </returns> - private IncomingWebResponse GetHostMeta(UriIdentifier identifier, IDirectWebRequestHandler requestHandler, out string signingHost) { + private async Task<ResultWithSigningHost<HttpResponseMessage>> GetHostMetaAsync(UriIdentifier identifier, CancellationToken cancellationToken) { Requires.NotNull(identifier, "identifier"); - Requires.NotNull(requestHandler, "requestHandler"); - foreach (var hostMetaProxy in this.GetHostMetaLocations(identifier)) { - var hostMetaLocation = hostMetaProxy.GetProxy(identifier); - var request = (HttpWebRequest)WebRequest.Create(hostMetaLocation); - request.CachePolicy = Yadis.IdentifierDiscoveryCachePolicy; - var options = DirectWebRequestOptions.AcceptAllHttpResponses; - if (identifier.IsDiscoverySecureEndToEnd) { - options |= DirectWebRequestOptions.RequireSsl; - } - var response = requestHandler.GetResponse(request, options).GetSnapshot(Yadis.MaximumResultToScan); - try { - if (response.Status == HttpStatusCode.OK) { - Logger.Yadis.InfoFormat("Found host-meta for {0} at: {1}", identifier.Uri.Host, hostMetaLocation); - signingHost = hostMetaProxy.GetSigningHost(identifier); - return response; - } else { - Logger.Yadis.InfoFormat("Could not obtain host-meta for {0} from {1}", identifier.Uri.Host, hostMetaLocation); + + using (var httpClient = this.HostFactories.CreateHttpClient(identifier.IsDiscoverySecureEndToEnd, Yadis.IdentifierDiscoveryCachePolicy)) { + foreach (var hostMetaProxy in this.GetHostMetaLocations(identifier)) { + var hostMetaLocation = hostMetaProxy.GetProxy(identifier); + var response = await httpClient.GetAsync(hostMetaLocation, cancellationToken); + try { + if (response.IsSuccessStatusCode) { + Logger.Yadis.InfoFormat("Found host-meta for {0} at: {1}", identifier.Uri.Host, hostMetaLocation); + return new ResultWithSigningHost<HttpResponseMessage>(response, hostMetaProxy.GetSigningHost(identifier)); + } else { + Logger.Yadis.InfoFormat("Could not obtain host-meta for {0} from {1}", identifier.Uri.Host, hostMetaLocation); + response.Dispose(); + } + } catch { response.Dispose(); + throw; } - } catch { - response.Dispose(); - throw; } } - signingHost = null; - return null; + return new ResultWithSigningHost<HttpResponseMessage>(); } /// <summary> @@ -546,5 +553,22 @@ namespace DotNetOpenAuth.OpenId { return this.ProxyFormat.GetHashCode(); } } + + private struct ResultWithSigningHost<T> : IDisposable { + internal ResultWithSigningHost(T result, string signingHost) + : this() { + this.Result = result; + this.SigningHost = signingHost; + } + + public T Result { get; private set; } + + public string SigningHost { get; private set; } + + public void Dispose() { + var disposable = this.Result as IDisposable; + disposable.DisposeIfNotNull(); + } + } } } diff --git a/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/Interop/OpenIdRelyingPartyShim.cs b/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/Interop/OpenIdRelyingPartyShim.cs index eb37d86..9568c1d 100644 --- a/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/Interop/OpenIdRelyingPartyShim.cs +++ b/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/Interop/OpenIdRelyingPartyShim.cs @@ -11,6 +11,7 @@ namespace DotNetOpenAuth.OpenId.Interop { using System.IO; using System.Runtime.InteropServices; using System.Text; + using System.Threading; using System.Web; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId.Extensions.SimpleRegistration; @@ -68,8 +69,9 @@ namespace DotNetOpenAuth.OpenId.Interop { /// <exception cref="ProtocolException">Thrown if no OpenID endpoint could be found.</exception> [SuppressMessage("Microsoft.Usage", "CA2234:PassSystemUriObjectsInsteadOfStrings", Justification = "COM requires primitive types")] public string CreateRequest(string userSuppliedIdentifier, string realm, string returnToUrl) { - var request = relyingParty.CreateRequest(userSuppliedIdentifier, realm, new Uri(returnToUrl)); - return request.RedirectingResponse.GetDirectUriRequest(relyingParty.Channel).AbsoluteUri; + var request = relyingParty.CreateRequestAsync(userSuppliedIdentifier, realm, new Uri(returnToUrl)).Result; + var response = request.GetRedirectingResponseAsync(CancellationToken.None).Result; + return response.GetDirectUriRequest().AbsoluteUri; } /// <summary> @@ -91,7 +93,7 @@ namespace DotNetOpenAuth.OpenId.Interop { /// <exception cref="ProtocolException">Thrown if no OpenID endpoint could be found.</exception> [SuppressMessage("Microsoft.Usage", "CA2234:PassSystemUriObjectsInsteadOfStrings", Justification = "COM requires primitive types")] public string CreateRequestWithSimpleRegistration(string userSuppliedIdentifier, string realm, string returnToUrl, string optionalSreg, string requiredSreg) { - var request = relyingParty.CreateRequest(userSuppliedIdentifier, realm, new Uri(returnToUrl)); + var request = relyingParty.CreateRequestAsync(userSuppliedIdentifier, realm, new Uri(returnToUrl)).Result; ClaimsRequest sreg = new ClaimsRequest(); if (!string.IsNullOrEmpty(optionalSreg)) { @@ -101,7 +103,8 @@ namespace DotNetOpenAuth.OpenId.Interop { sreg.SetProfileRequestFromList(requiredSreg.Split(','), DemandLevel.Require); } request.AddExtension(sreg); - return request.RedirectingResponse.GetDirectUriRequest(relyingParty.Channel).AbsoluteUri; + var response = request.GetRedirectingResponseAsync(CancellationToken.None).Result; + return response.GetDirectUriRequest().AbsoluteUri; } /// <summary> @@ -120,7 +123,7 @@ namespace DotNetOpenAuth.OpenId.Interop { } HttpRequestBase requestInfo = new HttpRequestInfo(method, new Uri(url), form: formMap); - var response = relyingParty.GetResponse(requestInfo); + var response = relyingParty.GetResponseAsync(requestInfo, CancellationToken.None).Result; if (response != null) { return new AuthenticationResponseShim(response); } diff --git a/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/RelyingParty/AssociationManager.cs b/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/RelyingParty/AssociationManager.cs index dfb307b..14566e1 100644 --- a/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/RelyingParty/AssociationManager.cs +++ b/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/RelyingParty/AssociationManager.cs @@ -11,6 +11,8 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { using System.Net; using System.Security; using System.Text; + using System.Threading; + using System.Threading.Tasks; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId.ChannelElements; using DotNetOpenAuth.OpenId.Messages; @@ -131,8 +133,8 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { /// </summary> /// <param name="provider">The provider to get an association for.</param> /// <returns>The existing or new association; <c>null</c> if none existed and one could not be created.</returns> - internal Association GetOrCreateAssociation(IProviderEndpoint provider) { - return this.GetExistingAssociation(provider) ?? this.CreateNewAssociation(provider); + internal async Task<Association> GetOrCreateAssociationAsync(IProviderEndpoint provider, CancellationToken cancellationToken) { + return this.GetExistingAssociation(provider) ?? await this.CreateNewAssociationAsync(provider, cancellationToken); } /// <summary> @@ -148,7 +150,7 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { /// association store. /// Any new association is automatically added to the <see cref="associationStore"/>. /// </remarks> - private Association CreateNewAssociation(IProviderEndpoint provider) { + private async Task<Association> CreateNewAssociationAsync(IProviderEndpoint provider, CancellationToken cancellationToken) { Requires.NotNull(provider, "provider"); // If there is no association store, there is no point in creating an association. @@ -160,7 +162,7 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { var associateRequest = AssociateRequestRelyingParty.Create(this.securitySettings, provider); const int RenegotiateRetries = 1; - return this.CreateNewAssociation(provider, associateRequest, RenegotiateRetries); + return await this.CreateNewAssociationAsync(provider, associateRequest, RenegotiateRetries, cancellationToken); } catch (VerificationException ex) { // See Trac ticket #163. In partial trust host environments, the // Diffie-Hellman implementation we're using for HTTP OP endpoints @@ -182,7 +184,7 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { /// The newly created association, or null if no association can be created with /// the given Provider given the current security settings. /// </returns> - private Association CreateNewAssociation(IProviderEndpoint provider, AssociateRequest associateRequest, int retriesRemaining) { + private async Task<Association> CreateNewAssociationAsync(IProviderEndpoint provider, AssociateRequest associateRequest, int retriesRemaining, CancellationToken cancellationToken) { Requires.NotNull(provider, "provider"); if (associateRequest == null || retriesRemaining < 0) { @@ -191,8 +193,9 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { return null; } + Exception exception = null; try { - var associateResponse = this.channel.Request(associateRequest); + var associateResponse = await this.channel.RequestAsync(associateRequest, cancellationToken); var associateSuccessfulResponse = associateResponse as IAssociateSuccessfulResponseRelyingParty; var associateUnsuccessfulResponse = associateResponse as AssociateUnsuccessfulResponse; if (associateSuccessfulResponse != null) { @@ -224,23 +227,27 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { associateUnsuccessfulResponse.SessionType); associateRequest = AssociateRequestRelyingParty.Create(this.securitySettings, provider, associateUnsuccessfulResponse.AssociationType, associateUnsuccessfulResponse.SessionType); - return this.CreateNewAssociation(provider, associateRequest, retriesRemaining - 1); + return await this.CreateNewAssociationAsync(provider, associateRequest, retriesRemaining - 1, cancellationToken); } else { throw new ProtocolException(MessagingStrings.UnexpectedMessageReceivedOfMany); } } catch (ProtocolException ex) { - // If the association failed because the remote server can't handle Expect: 100 Continue headers, - // then our web request handler should have already accomodated for future calls. Go ahead and - // immediately make one of those future calls now to try to get the association to succeed. - if (StandardWebRequestHandler.IsExceptionFrom417ExpectationFailed(ex)) { - return this.CreateNewAssociation(provider, associateRequest, retriesRemaining - 1); - } + exception = ex; + } - // Since having associations with OPs is not totally critical, we'll log and eat - // the exception so that auth may continue in dumb mode. - Logger.OpenId.ErrorFormat("An error occurred while trying to create an association with {0}. {1}", provider.Uri, ex); - return null; + Assumes.NotNull(exception); + + // If the association failed because the remote server can't handle Expect: 100 Continue headers, + // then our web request handler should have already accomodated for future calls. Go ahead and + // immediately make one of those future calls now to try to get the association to succeed. + if (UntrustedWebRequestHandler.IsExceptionFrom417ExpectationFailed(exception)) { + return await this.CreateNewAssociationAsync(provider, associateRequest, retriesRemaining - 1, cancellationToken); } + + // Since having associations with OPs is not totally critical, we'll log and eat + // the exception so that auth may continue in dumb mode. + Logger.OpenId.ErrorFormat("An error occurred while trying to create an association with {0}. {1}", provider.Uri, exception); + return null; } } } diff --git a/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/RelyingParty/AuthenticationRequest.cs b/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/RelyingParty/AuthenticationRequest.cs index 92af297..f9abc37 100644 --- a/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/RelyingParty/AuthenticationRequest.cs +++ b/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/RelyingParty/AuthenticationRequest.cs @@ -9,10 +9,12 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { using System.Collections.Generic; using System.Collections.Specialized; using System.Linq; + using System.Net.Http; + using System.ServiceModel.Channels; using System.Text; using System.Threading; + using System.Threading.Tasks; using System.Web; - using DotNetOpenAuth.Configuration; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId.ChannelElements; @@ -95,14 +97,13 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { /// to redirect it to the OpenID Provider to start the OpenID authentication process. /// </summary> /// <value></value> - public OutgoingWebResponse RedirectingResponse { - get { - foreach (var behavior in this.RelyingParty.Behaviors) { - behavior.OnOutgoingAuthenticationRequest(this); - } - - return this.RelyingParty.Channel.PrepareResponse(this.CreateRequestMessage()); + public async Task<HttpResponseMessage> GetRedirectingResponseAsync(CancellationToken cancellationToken) { + foreach (var behavior in this.RelyingParty.Behaviors) { + behavior.OnOutgoingAuthenticationRequest(this); } + + var request = await this.CreateRequestMessageAsync(cancellationToken); + return await this.RelyingParty.Channel.PrepareResponseAsync(request, cancellationToken); } /// <summary> @@ -293,16 +294,6 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { this.extensions.Add(extension); } - /// <summary> - /// Redirects the user agent to the provider for authentication. - /// </summary> - /// <remarks> - /// This method requires an ASP.NET HttpContext. - /// </remarks> - public void RedirectToProvider() { - this.RedirectingResponse.Send(); - } - #endregion /// <summary> @@ -318,7 +309,7 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { /// A sequence of authentication requests, any of which constitutes a valid identity assertion on the Claimed Identifier. /// Never null, but may be empty. /// </returns> - internal static IEnumerable<AuthenticationRequest> Create(Identifier userSuppliedIdentifier, OpenIdRelyingParty relyingParty, Realm realm, Uri returnToUrl, bool createNewAssociationsAsNeeded) { + internal static async Task<IEnumerable<AuthenticationRequest>> CreateAsync(Identifier userSuppliedIdentifier, OpenIdRelyingParty relyingParty, Realm realm, Uri returnToUrl, bool createNewAssociationsAsNeeded, CancellationToken cancellationToken) { Requires.NotNull(userSuppliedIdentifier, "userSuppliedIdentifier"); Requires.NotNull(relyingParty, "relyingParty"); Requires.NotNull(realm, "realm"); @@ -360,7 +351,7 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { // Perform discovery right now (not deferred). IEnumerable<IdentifierDiscoveryResult> serviceEndpoints; try { - var results = relyingParty.Discover(userSuppliedIdentifier).CacheGeneratedResults(); + var results = (await relyingParty.DiscoverAsync(userSuppliedIdentifier, cancellationToken)).CacheGeneratedResults(); // If any OP Identifier service elements were found, we must not proceed // to use any Claimed Identifier services, per OpenID 2.0 sections 7.3.2.2 and 11.2. @@ -381,7 +372,7 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { serviceEndpoints = relyingParty.SecuritySettings.FilterEndpoints(serviceEndpoints); // Call another method that defers request generation. - return CreateInternal(userSuppliedIdentifier, relyingParty, realm, returnToUrl, serviceEndpoints, createNewAssociationsAsNeeded); + return await CreateInternalAsync(userSuppliedIdentifier, relyingParty, realm, returnToUrl, serviceEndpoints, createNewAssociationsAsNeeded, cancellationToken); } /// <summary> @@ -401,9 +392,8 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { /// based on the properties in this instance. /// </summary> /// <returns>The message to send to the Provider.</returns> - internal SignedResponseRequest CreateRequestMessageTestHook() - { - return this.CreateRequestMessage(); + internal Task<SignedResponseRequest> CreateRequestMessageTestHookAsync(CancellationToken cancellationToken) { + return this.CreateRequestMessageAsync(cancellationToken); } /// <summary> @@ -423,18 +413,18 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { /// All data validation and cleansing steps must have ALREADY taken place /// before calling this method. /// </remarks> - private static IEnumerable<AuthenticationRequest> CreateInternal(Identifier userSuppliedIdentifier, OpenIdRelyingParty relyingParty, Realm realm, Uri returnToUrl, IEnumerable<IdentifierDiscoveryResult> serviceEndpoints, bool createNewAssociationsAsNeeded) { - // DO NOT USE CODE CONTRACTS IN THIS METHOD, since it uses yield return - ErrorUtilities.VerifyArgumentNotNull(userSuppliedIdentifier, "userSuppliedIdentifier"); - ErrorUtilities.VerifyArgumentNotNull(relyingParty, "relyingParty"); - ErrorUtilities.VerifyArgumentNotNull(realm, "realm"); - ErrorUtilities.VerifyArgumentNotNull(serviceEndpoints, "serviceEndpoints"); + private static async Task<IEnumerable<AuthenticationRequest>> CreateInternalAsync(Identifier userSuppliedIdentifier, OpenIdRelyingParty relyingParty, Realm realm, Uri returnToUrl, IEnumerable<IdentifierDiscoveryResult> serviceEndpoints, bool createNewAssociationsAsNeeded, CancellationToken cancellationToken) { + Requires.NotNull(userSuppliedIdentifier, "userSuppliedIdentifier"); + Requires.NotNull(relyingParty, "relyingParty"); + Requires.NotNull(realm, "realm"); + Requires.NotNull(serviceEndpoints, "serviceEndpoints"); //// // If shared associations are required, then we had better have an association store. ErrorUtilities.VerifyOperation(!relyingParty.SecuritySettings.RequireAssociation || relyingParty.AssociationManager.HasAssociationStore, OpenIdStrings.AssociationStoreRequired); Logger.Yadis.InfoFormat("Performing discovery on user-supplied identifier: {0}", userSuppliedIdentifier); IEnumerable<IdentifierDiscoveryResult> endpoints = FilterAndSortEndpoints(serviceEndpoints, relyingParty); + var results = new List<AuthenticationRequest>(); // Maintain a list of endpoints that we could not form an association with. // We'll fallback to generating requests to these if the ones we CAN create @@ -450,7 +440,7 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { // In some scenarios (like the AJAX control wanting ALL auth requests possible), // we don't want to create associations with every Provider. But we'll use // associations where they are already formed from previous authentications. - association = createNewAssociationsAsNeeded ? relyingParty.AssociationManager.GetOrCreateAssociation(endpoint) : relyingParty.AssociationManager.GetExistingAssociation(endpoint); + association = createNewAssociationsAsNeeded ? await relyingParty.AssociationManager.GetOrCreateAssociationAsync(endpoint, cancellationToken) : relyingParty.AssociationManager.GetExistingAssociation(endpoint); if (association == null && createNewAssociationsAsNeeded) { Logger.OpenId.WarnFormat("Failed to create association with {0}. Skipping to next endpoint.", endpoint.ProviderEndpoint); @@ -461,7 +451,7 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { } } - yield return new AuthenticationRequest(endpoint, realm, returnToUrl, relyingParty); + results.Add(new AuthenticationRequest(endpoint, realm, returnToUrl, relyingParty)); } // Now that we've run out of endpoints that respond to association requests, @@ -481,10 +471,12 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { // because we've already tried. Let's not have it waste time trying again. var authRequest = new AuthenticationRequest(endpoint, realm, returnToUrl, relyingParty); authRequest.associationPreference = AssociationPreference.IfAlreadyEstablished; - yield return authRequest; + results.Add(authRequest); } } } + + return results; } /// <summary> @@ -535,8 +527,8 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { /// based on the properties in this instance. /// </summary> /// <returns>The message to send to the Provider.</returns> - private SignedResponseRequest CreateRequestMessage() { - Association association = this.GetAssociation(); + private async Task<SignedResponseRequest> CreateRequestMessageAsync(CancellationToken cancellationToken) { + Association association = await this.GetAssociationAsync(cancellationToken); SignedResponseRequest request; if (!this.IsExtensionOnly) { @@ -566,11 +558,11 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { /// Gets the association to use for this authentication request. /// </summary> /// <returns>The association to use; <c>null</c> to use 'dumb mode'.</returns> - private Association GetAssociation() { + private async Task<Association> GetAssociationAsync(CancellationToken cancellationToken) { Association association = null; switch (this.associationPreference) { case AssociationPreference.IfPossible: - association = this.RelyingParty.AssociationManager.GetOrCreateAssociation(this.DiscoveryResult); + association = await this.RelyingParty.AssociationManager.GetOrCreateAssociationAsync(this.DiscoveryResult, cancellationToken); if (association == null) { // Avoid trying to create the association again if the redirecting response // is generated again. diff --git a/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/RelyingParty/Extensions/UIUtilities.cs b/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/RelyingParty/Extensions/UIUtilities.cs index a5de08b..80bfe65 100644 --- a/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/RelyingParty/Extensions/UIUtilities.cs +++ b/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/RelyingParty/Extensions/UIUtilities.cs @@ -6,10 +6,12 @@ namespace DotNetOpenAuth.OpenId.RelyingParty.Extensions.UI { using System; - using System.Globalization; - using DotNetOpenAuth.Messaging; - using DotNetOpenAuth.OpenId.RelyingParty; - using Validation; +using System.Globalization; +using System.Threading; +using System.Threading.Tasks; +using DotNetOpenAuth.Messaging; +using DotNetOpenAuth.OpenId.RelyingParty; +using Validation; /// <summary> /// Constants used in implementing support for the UI extension. @@ -23,12 +25,13 @@ namespace DotNetOpenAuth.OpenId.RelyingParty.Extensions.UI { /// <param name="request">The authentication request to place in the window.</param> /// <param name="windowName">The name to assign to the popup window.</param> /// <returns>A string starting with 'window.open' and forming just that one method call.</returns> - internal static string GetWindowPopupScript(OpenIdRelyingParty relyingParty, IAuthenticationRequest request, string windowName) { + internal static async Task<string> GetWindowPopupScriptAsync(OpenIdRelyingParty relyingParty, IAuthenticationRequest request, string windowName, CancellationToken cancellationToken) { Requires.NotNull(relyingParty, "relyingParty"); Requires.NotNull(request, "request"); Requires.NotNullOrEmpty(windowName, "windowName"); - Uri popupUrl = request.RedirectingResponse.GetDirectUriRequest(relyingParty.Channel); + var response = await request.GetRedirectingResponseAsync(cancellationToken); + Uri popupUrl = response.GetDirectUriRequest(); return string.Format( CultureInfo.InvariantCulture, diff --git a/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/RelyingParty/OpenIdRelyingParty.cs b/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/RelyingParty/OpenIdRelyingParty.cs index 2177591..a55e042 100644 --- a/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/RelyingParty/OpenIdRelyingParty.cs +++ b/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/RelyingParty/OpenIdRelyingParty.cs @@ -14,8 +14,12 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { using System.Globalization; using System.Linq; using System.Net; + using System.Net.Http; + using System.Net.Http.Headers; using System.Net.Mime; using System.Text; + using System.Threading; + using System.Threading.Tasks; using System.Web; using DotNetOpenAuth.Configuration; using DotNetOpenAuth.Messaging; @@ -272,11 +276,10 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { } /// <summary> - /// Gets the web request handler to use for discovery and the part of - /// authentication where direct messages are sent to an untrusted remote party. + /// Gets the factory for various dependencies. /// </summary> - IDirectWebRequestHandler IOpenIdHost.WebRequestHandler { - get { return this.Channel.WebRequestHandler; } + IHostFactories IOpenIdHost.HostFactories { + get { return this.channel.HostFactories; } } /// <summary> @@ -288,14 +291,6 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { } /// <summary> - /// Gets the web request handler to use for discovery and the part of - /// authentication where direct messages are sent to an untrusted remote party. - /// </summary> - internal IDirectWebRequestHandler WebRequestHandler { - get { return this.Channel.WebRequestHandler; } - } - - /// <summary> /// Gets the association manager. /// </summary> internal AssociationManager AssociationManager { get; private set; } @@ -339,12 +334,12 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { /// an object to send to the user agent to initiate the authentication. /// </returns> /// <exception cref="ProtocolException">Thrown if no OpenID endpoint could be found.</exception> - public IAuthenticationRequest CreateRequest(Identifier userSuppliedIdentifier, Realm realm, Uri returnToUrl) { + public async Task<IAuthenticationRequest> CreateRequestAsync(Identifier userSuppliedIdentifier, Realm realm, Uri returnToUrl, CancellationToken cancellationToken = default(CancellationToken)) { Requires.NotNull(userSuppliedIdentifier, "userSuppliedIdentifier"); Requires.NotNull(realm, "realm"); Requires.NotNull(returnToUrl, "returnToUrl"); try { - return this.CreateRequests(userSuppliedIdentifier, realm, returnToUrl).First(); + return (await this.CreateRequestsAsync(userSuppliedIdentifier, realm, returnToUrl)).First(); } catch (InvalidOperationException ex) { throw ErrorUtilities.Wrap(ex, OpenIdStrings.OpenIdEndpointNotFound); } @@ -371,11 +366,11 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { /// </remarks> /// <exception cref="ProtocolException">Thrown if no OpenID endpoint could be found.</exception> /// <exception cref="InvalidOperationException">Thrown if <see cref="HttpContext.Current">HttpContext.Current</see> == <c>null</c>.</exception> - public IAuthenticationRequest CreateRequest(Identifier userSuppliedIdentifier, Realm realm) { + public async Task<IAuthenticationRequest> CreateRequestAsync(Identifier userSuppliedIdentifier, Realm realm) { Requires.NotNull(userSuppliedIdentifier, "userSuppliedIdentifier"); Requires.NotNull(realm, "realm"); try { - var result = this.CreateRequests(userSuppliedIdentifier, realm).First(); + var result = (await this.CreateRequestsAsync(userSuppliedIdentifier, realm)).First(); Assumes.True(result != null); return result; } catch (InvalidOperationException ex) { @@ -399,10 +394,10 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { /// </remarks> /// <exception cref="ProtocolException">Thrown if no OpenID endpoint could be found.</exception> /// <exception cref="InvalidOperationException">Thrown if <see cref="HttpContext.Current">HttpContext.Current</see> == <c>null</c>.</exception> - public IAuthenticationRequest CreateRequest(Identifier userSuppliedIdentifier) { + public async Task<IAuthenticationRequest> CreateRequestAsync(Identifier userSuppliedIdentifier) { Requires.NotNull(userSuppliedIdentifier, "userSuppliedIdentifier"); try { - return this.CreateRequests(userSuppliedIdentifier).First(); + return (await this.CreateRequestsAsync(userSuppliedIdentifier)).First(); } catch (InvalidOperationException ex) { throw ErrorUtilities.Wrap(ex, OpenIdStrings.OpenIdEndpointNotFound); } @@ -435,12 +430,13 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { /// <para>No exception is thrown if no OpenID endpoints were discovered. /// An empty enumerable is returned instead.</para> /// </remarks> - public virtual IEnumerable<IAuthenticationRequest> CreateRequests(Identifier userSuppliedIdentifier, Realm realm, Uri returnToUrl) { + public virtual async Task<IEnumerable<IAuthenticationRequest>> CreateRequestsAsync(Identifier userSuppliedIdentifier, Realm realm, Uri returnToUrl, CancellationToken cancellationToken = default(CancellationToken)) { Requires.NotNull(userSuppliedIdentifier, "userSuppliedIdentifier"); Requires.NotNull(realm, "realm"); Requires.NotNull(returnToUrl, "returnToUrl"); - return AuthenticationRequest.Create(userSuppliedIdentifier, this, realm, returnToUrl, true).Cast<IAuthenticationRequest>().CacheGeneratedResults(); + var requests = await AuthenticationRequest.CreateAsync(userSuppliedIdentifier, this, realm, returnToUrl, true, cancellationToken); + return requests.Cast<IAuthenticationRequest>().CacheGeneratedResults(); } /// <summary> @@ -468,7 +464,7 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { /// <para>Requires an <see cref="HttpContext.Current">HttpContext.Current</see> context.</para> /// </remarks> /// <exception cref="InvalidOperationException">Thrown if <see cref="HttpContext.Current">HttpContext.Current</see> == <c>null</c>.</exception> - public IEnumerable<IAuthenticationRequest> CreateRequests(Identifier userSuppliedIdentifier, Realm realm) { + public async Task<IEnumerable<IAuthenticationRequest>> CreateRequestsAsync(Identifier userSuppliedIdentifier, Realm realm) { RequiresEx.ValidState(HttpContext.Current != null && HttpContext.Current.Request != null, MessagingStrings.HttpContextRequired); Requires.NotNull(userSuppliedIdentifier, "userSuppliedIdentifier"); Requires.NotNull(realm, "realm"); @@ -491,7 +487,7 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { } returnTo.AppendQueryArgs(returnToParams); - return this.CreateRequests(userSuppliedIdentifier, realm, returnTo.Uri); + return await this.CreateRequestsAsync(userSuppliedIdentifier, realm, returnTo.Uri); } /// <summary> @@ -514,11 +510,11 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { /// <para>Requires an <see cref="HttpContext.Current">HttpContext.Current</see> context.</para> /// </remarks> /// <exception cref="InvalidOperationException">Thrown if <see cref="HttpContext.Current">HttpContext.Current</see> == <c>null</c>.</exception> - public IEnumerable<IAuthenticationRequest> CreateRequests(Identifier userSuppliedIdentifier) { + public async Task<IEnumerable<IAuthenticationRequest>> CreateRequestsAsync(Identifier userSuppliedIdentifier) { Requires.NotNull(userSuppliedIdentifier, "userSuppliedIdentifier"); RequiresEx.ValidState(HttpContext.Current != null && HttpContext.Current.Request != null, MessagingStrings.HttpContextRequired); - return this.CreateRequests(userSuppliedIdentifier, Realm.AutoDetect); + return await this.CreateRequestsAsync(userSuppliedIdentifier, Realm.AutoDetect); } /// <summary> @@ -528,9 +524,9 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { /// <remarks> /// <para>Requires an <see cref="HttpContext.Current">HttpContext.Current</see> context.</para> /// </remarks> - public IAuthenticationResponse GetResponse() { + public Task<IAuthenticationResponse> GetResponseAsync(CancellationToken cancellationToken) { RequiresEx.ValidState(HttpContext.Current != null && HttpContext.Current.Request != null, MessagingStrings.HttpContextRequired); - return this.GetResponse(this.Channel.GetRequestFromContext()); + return this.GetResponseAsync(this.Channel.GetRequestFromContext(), cancellationToken); } /// <summary> @@ -538,10 +534,10 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { /// </summary> /// <param name="httpRequestInfo">The HTTP request that may be carrying an authentication response from the Provider.</param> /// <returns>The processed authentication response if there is any; <c>null</c> otherwise.</returns> - public IAuthenticationResponse GetResponse(HttpRequestBase httpRequestInfo) { + public async Task<IAuthenticationResponse> GetResponseAsync(HttpRequestBase httpRequestInfo, CancellationToken cancellationToken) { Requires.NotNull(httpRequestInfo, "httpRequestInfo"); try { - var message = this.Channel.ReadFromRequest(httpRequestInfo); + var message = await this.Channel.ReadFromRequestAsync(httpRequestInfo, cancellationToken); PositiveAssertionResponse positiveAssertion; NegativeAssertionResponse negativeAssertion; IndirectSignedResponse positiveExtensionOnly; @@ -554,7 +550,7 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { OpenIdStrings.PositiveAssertionFromNonQualifiedProvider, providerEndpoint.Uri); - var response = new PositiveAuthenticationResponse(positiveAssertion, this); + var response = await PositiveAuthenticationResponse.CreateAsync(positiveAssertion, this, cancellationToken); foreach (var behavior in this.Behaviors) { behavior.OnIncomingPositiveAssertion(response); } @@ -581,10 +577,10 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { /// <remarks> /// <para>Requires an <see cref="HttpContext.Current">HttpContext.Current</see> context.</para> /// </remarks> - public OutgoingWebResponse ProcessResponseFromPopup() { + public Task<HttpResponseMessage> ProcessResponseFromPopupAsync(CancellationToken cancellationToken) { RequiresEx.ValidState(HttpContext.Current != null && HttpContext.Current.Request != null, MessagingStrings.HttpContextRequired); - return this.ProcessResponseFromPopup(this.Channel.GetRequestFromContext()); + return this.ProcessResponseFromPopupAsync(this.Channel.GetRequestFromContext(), cancellationToken); } /// <summary> @@ -592,10 +588,10 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { /// </summary> /// <param name="request">The incoming HTTP request that is expected to carry an OpenID authentication response.</param> /// <returns>The HTTP response to send to this HTTP request.</returns> - public OutgoingWebResponse ProcessResponseFromPopup(HttpRequestBase request) { + public Task<HttpResponseMessage> ProcessResponseFromPopupAsync(HttpRequestBase request, CancellationToken cancellationToken) { Requires.NotNull(request, "request"); - return this.ProcessResponseFromPopup(request, null); + return this.ProcessResponseFromPopupAsync(request, null, cancellationToken); } /// <summary> @@ -678,11 +674,11 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { /// The HTTP response to send to this HTTP request. /// </returns> [SuppressMessage("Microsoft.Naming", "CA2204:Literals should be spelled correctly", MessageId = "OpenID", Justification = "real word"), SuppressMessage("Microsoft.Naming", "CA2204:Literals should be spelled correctly", MessageId = "iframe", Justification = "Code contracts")] - internal OutgoingWebResponse ProcessResponseFromPopup(HttpRequestBase request, Action<AuthenticationStatus> callback) { + internal async Task<HttpResponseMessage> ProcessResponseFromPopupAsync(HttpRequestBase request, Action<AuthenticationStatus> callback, CancellationToken cancellationToken) { Requires.NotNull(request, "request"); string extensionsJson = null; - var authResponse = this.NonVerifyingRelyingParty.GetResponse(); + var authResponse = await this.NonVerifyingRelyingParty.GetResponseAsync(cancellationToken); ErrorUtilities.VerifyProtocol(authResponse != null, OpenIdStrings.PopupRedirectMissingResponse); // Give the caller a chance to notify the hosting page and fill up the clientScriptExtensions collection. @@ -734,8 +730,8 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { /// </summary> /// <param name="identifier">The identifier to discover services for.</param> /// <returns>A non-null sequence of services discovered for the identifier.</returns> - internal IEnumerable<IdentifierDiscoveryResult> Discover(Identifier identifier) { - return this.discoveryServices.Discover(identifier); + internal Task<IEnumerable<IdentifierDiscoveryResult>> DiscoverAsync(Identifier identifier, CancellationToken cancellationToken) { + return this.discoveryServices.DiscoverAsync(identifier, cancellationToken); } /// <summary> @@ -795,7 +791,7 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { /// <param name="methodCall">The method to call on the parent window, including /// parameters. (i.e. "callback('arg1', 2)"). No escaping is done by this method.</param> /// <returns>The entire HTTP response to send to the popup window or iframe to perform the invocation.</returns> - private static OutgoingWebResponse InvokeParentPageScript(string methodCall) { + private static HttpResponseMessage InvokeParentPageScript(string methodCall) { Requires.NotNullOrEmpty(methodCall, "methodCall"); Logger.OpenId.DebugFormat("Sending Javascript callback: {0}", methodCall); @@ -824,9 +820,9 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { builder.AppendLine("//]]>--></script>"); builder.AppendLine("</body></html>"); - var response = new OutgoingWebResponse(); - response.Body = builder.ToString(); - response.Headers.Add(HttpResponseHeader.ContentType, new ContentType("text/html").ToString()); + var response = new HttpResponseMessage(); + response.Content = new StringContent(builder.ToString()); + response.Content.Headers.ContentType = new MediaTypeHeaderValue("text/html"); return response; } diff --git a/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/RelyingParty/PositiveAuthenticationResponse.cs b/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/RelyingParty/PositiveAuthenticationResponse.cs index 509eb60..f05abaa 100644 --- a/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/RelyingParty/PositiveAuthenticationResponse.cs +++ b/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/RelyingParty/PositiveAuthenticationResponse.cs @@ -8,6 +8,8 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { using System; using System.Diagnostics; using System.Linq; + using System.Threading; + using System.Threading.Tasks; using System.Web; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId.Messages; @@ -24,7 +26,7 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { /// </summary> /// <param name="response">The positive assertion response that was just received by the Relying Party.</param> /// <param name="relyingParty">The relying party.</param> - internal PositiveAuthenticationResponse(PositiveAssertionResponse response, OpenIdRelyingParty relyingParty) + private PositiveAuthenticationResponse(PositiveAssertionResponse response, OpenIdRelyingParty relyingParty) : base(response) { Requires.NotNull(relyingParty, "relyingParty"); @@ -36,8 +38,6 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { null, null); - this.VerifyDiscoveryMatchesAssertion(relyingParty); - Logger.OpenId.InfoFormat("Received identity assertion for {0} via {1}.", this.Response.ClaimedIdentifier, this.Provider.Uri); } @@ -123,6 +123,13 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { get { return (PositiveAssertionResponse)base.Response; } } + internal static async Task<PositiveAuthenticationResponse> CreateAsync( + PositiveAssertionResponse response, OpenIdRelyingParty relyingParty, CancellationToken cancellationToken) { + var result = new PositiveAuthenticationResponse(response, relyingParty); + await result.VerifyDiscoveryMatchesAssertionAsync(relyingParty, cancellationToken); + return result; + } + /// <summary> /// Verifies that the positive assertion data matches the results of /// discovery on the Claimed Identifier. @@ -134,7 +141,7 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { /// This would be an indication of either a misconfigured Provider or /// an attempt by someone to spoof another user's identity with a rogue Provider. /// </exception> - private void VerifyDiscoveryMatchesAssertion(OpenIdRelyingParty relyingParty) { + private async Task VerifyDiscoveryMatchesAssertionAsync(OpenIdRelyingParty relyingParty, CancellationToken cancellationToken) { Logger.OpenId.Debug("Verifying assertion matches identifier discovery results..."); // Ensure that we abide by the RP's rules regarding RequireSsl for this discovery step. @@ -163,7 +170,7 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { // is signed by the RP before it's considered reliable. In 1.x stateless mode, this RP // doesn't (and can't) sign its own return_to URL, so its cached discovery information // is merely a hint that must be verified by performing discovery again here. - var discoveryResults = relyingParty.Discover(claimedId); + var discoveryResults = await relyingParty.DiscoverAsync(claimedId, cancellationToken); ErrorUtilities.VerifyProtocol( discoveryResults.Contains(this.Endpoint), OpenIdStrings.IssuedAssertionFailsIdentifierDiscovery, diff --git a/src/DotNetOpenAuth.OpenId.RelyingParty/packages.config b/src/DotNetOpenAuth.OpenId.RelyingParty/packages.config index 58890d8..1d93cf5 100644 --- a/src/DotNetOpenAuth.OpenId.RelyingParty/packages.config +++ b/src/DotNetOpenAuth.OpenId.RelyingParty/packages.config @@ -1,4 +1,5 @@ <?xml version="1.0" encoding="utf-8"?> <packages> + <package id="Microsoft.Net.Http" version="2.0.20710.0" targetFramework="net45" /> <package id="Validation" version="2.0.1.12362" targetFramework="net45" /> </packages>
\ No newline at end of file diff --git a/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/BackwardCompatibilityBindingElement.cs b/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/BackwardCompatibilityBindingElement.cs index ff8a766..c448e2f 100644 --- a/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/BackwardCompatibilityBindingElement.cs +++ b/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/BackwardCompatibilityBindingElement.cs @@ -6,6 +6,8 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { using System; + using System.Threading; + using System.Threading.Tasks; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Reflection; using DotNetOpenAuth.OpenId.Messages; @@ -16,6 +18,11 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// are required to send back with positive assertions. /// </summary> internal class BackwardCompatibilityBindingElement : IChannelBindingElement { + private static readonly Task<MessageProtections?> NullTask = Task.FromResult<MessageProtections?>(null); + + private static readonly Task<MessageProtections?> NoneTask = + Task.FromResult<MessageProtections?>(MessageProtections.None); + /// <summary> /// The "dnoa.op_endpoint" callback parameter that stores the Provider Endpoint URL /// to tack onto the return_to URI. @@ -59,7 +66,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// Implementations that provide message protection must honor the /// <see cref="MessagePartAttribute.RequiredProtection"/> properties where applicable. /// </remarks> - public MessageProtections? ProcessOutgoingMessage(IProtocolMessage message) { + public Task<MessageProtections?> ProcessOutgoingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { SignedResponseRequest request = message as SignedResponseRequest; if (request != null && request.Version.Major < 2) { request.AddReturnToArguments(ProviderEndpointParameterName, request.Recipient.AbsoluteUri); @@ -69,10 +76,10 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { request.AddReturnToArguments(ClaimedIdentifierParameterName, authRequest.ClaimedIdentifier); } - return MessageProtections.None; + return NoneTask; } - return null; + return NullTask; } /// <summary> @@ -92,7 +99,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// Implementations that provide message protection must honor the /// <see cref="MessagePartAttribute.RequiredProtection"/> properties where applicable. /// </remarks> - public MessageProtections? ProcessIncomingMessage(IProtocolMessage message) { + public Task<MessageProtections?> ProcessIncomingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { IndirectSignedResponse response = message as IndirectSignedResponse; if (response != null && response.Version.Major < 2) { // GetReturnToArgument may return parameters that are not signed, @@ -118,10 +125,10 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { } } - return MessageProtections.None; + return NoneTask; } - return null; + return NullTask; } #endregion diff --git a/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/ExtensionsBindingElement.cs b/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/ExtensionsBindingElement.cs index f24c8b4..2a5946a 100644 --- a/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/ExtensionsBindingElement.cs +++ b/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/ExtensionsBindingElement.cs @@ -10,6 +10,8 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Text; + using System.Threading; + using System.Threading.Tasks; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Reflection; using DotNetOpenAuth.OpenId.Extensions; @@ -21,6 +23,11 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// their carrying OpenID messages. /// </summary> internal class ExtensionsBindingElement : IChannelBindingElement { + private static readonly Task<MessageProtections?> NullTask = Task.FromResult<MessageProtections?>(null); + + private static readonly Task<MessageProtections?> NoneTask = + Task.FromResult<MessageProtections?>(MessageProtections.None); + /// <summary> /// False if unsigned extensions should be dropped. Must always be true on Providers, since RPs never sign extensions. /// </summary> @@ -77,7 +84,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// <see cref="MessagePartAttribute.RequiredProtection"/> properties where applicable. /// </remarks> [SuppressMessage("Microsoft.Maintainability", "CA1506:AvoidExcessiveClassCoupling", Justification = "It doesn't look too bad to me. :)")] - public MessageProtections? ProcessOutgoingMessage(IProtocolMessage message) { + public Task<MessageProtections?> ProcessOutgoingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { var extendableMessage = message as IProtocolMessageWithExtensions; if (extendableMessage != null) { Protocol protocol = Protocol.Lookup(message.Version); @@ -120,10 +127,10 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { // Add the extension parameters to the base message for transmission. baseMessageDictionary.AddExtraParameters(extensionManager.GetArgumentsToSend(includeOpenIdPrefix)); - return MessageProtections.None; + return NoneTask; } - return null; + return NullTask; } /// <summary> @@ -143,7 +150,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// Implementations that provide message protection must honor the /// <see cref="MessagePartAttribute.RequiredProtection"/> properties where applicable. /// </remarks> - public MessageProtections? ProcessIncomingMessage(IProtocolMessage message) { + public Task<MessageProtections?> ProcessIncomingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { var extendableMessage = message as IProtocolMessageWithExtensions; if (extendableMessage != null) { // First add the extensions that are signed by the Provider. @@ -164,10 +171,10 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { } } - return MessageProtections.None; + return NoneTask; } - return null; + return NullTask; } #endregion diff --git a/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/OpenIdChannel.cs b/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/OpenIdChannel.cs index eb4ca65..221994a 100644 --- a/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/OpenIdChannel.cs +++ b/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/OpenIdChannel.cs @@ -16,6 +16,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { using System.Net.Http; using System.Net.Http.Headers; using System.Text; + using System.Threading; using System.Threading.Tasks; using DotNetOpenAuth.Configuration; @@ -51,7 +52,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// <param name="messageTypeProvider">A class prepared to analyze incoming messages and indicate what concrete /// message types can deserialize from it.</param> /// <param name="bindingElements">The binding elements to use in sending and receiving messages.</param> - protected OpenIdChannel(IMessageFactory messageTypeProvider, IChannelBindingElement[] bindingElements, IHostFactories hostFactories) + protected OpenIdChannel(IMessageFactory messageTypeProvider, IChannelBindingElement[] bindingElements, IHostFactories hostFactories = null) : base(messageTypeProvider, bindingElements, hostFactories ?? new DefaultOpenIdHostFactories()) { Requires.NotNull(messageTypeProvider, "messageTypeProvider"); @@ -84,18 +85,18 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// Thrown when the message is somehow invalid, except for check_authentication messages. /// This can be due to tampering, replay attack or expiration, among other things. /// </exception> - protected override void ProcessIncomingMessage(IProtocolMessage message) { + protected override async Task ProcessIncomingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { var checkAuthRequest = message as CheckAuthenticationRequest; if (checkAuthRequest != null) { IndirectSignedResponse originalResponse = new IndirectSignedResponse(checkAuthRequest, this); try { - base.ProcessIncomingMessage(originalResponse); + await base.ProcessIncomingMessageAsync(originalResponse, cancellationToken); checkAuthRequest.IsValid = true; } catch (ProtocolException) { checkAuthRequest.IsValid = false; } } else { - base.ProcessIncomingMessage(message); + await base.ProcessIncomingMessageAsync(message, cancellationToken); } // Convert an OpenID indirect error message, which we never expect diff --git a/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/ReturnToSignatureBindingElement.cs b/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/ReturnToSignatureBindingElement.cs index 726c01f..c55704d 100644 --- a/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/ReturnToSignatureBindingElement.cs +++ b/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/ReturnToSignatureBindingElement.cs @@ -9,6 +9,8 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { using System.Collections.Generic; using System.Collections.Specialized; using System.Security.Cryptography; + using System.Threading; + using System.Threading.Tasks; using System.Web; using DotNetOpenAuth.Configuration; using DotNetOpenAuth.Messaging; @@ -31,6 +33,11 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// anything except a particular message part.</para> /// </remarks> internal class ReturnToSignatureBindingElement : IChannelBindingElement { + private static readonly Task<MessageProtections?> NullTask = Task.FromResult<MessageProtections?>(null); + + private static readonly Task<MessageProtections?> NoneTask = + Task.FromResult<MessageProtections?>(MessageProtections.None); + /// <summary> /// The name of the callback parameter we'll tack onto the return_to value /// to store our signature on the return_to parameter. @@ -98,7 +105,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// Implementations that provide message protection must honor the /// <see cref="MessagePartAttribute.RequiredProtection"/> properties where applicable. /// </remarks> - public MessageProtections? ProcessOutgoingMessage(IProtocolMessage message) { + public Task<MessageProtections?> ProcessOutgoingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { SignedResponseRequest request = message as SignedResponseRequest; if (request != null && request.ReturnTo != null && request.SignReturnTo) { var cryptoKeyPair = this.cryptoKeyStore.GetCurrentKey(SecretUri.AbsoluteUri, OpenIdElement.Configuration.MaxAuthenticationTime); @@ -107,10 +114,10 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { request.AddReturnToArguments(ReturnToSignatureParameterName, signature); // We return none because we are not signing the entire message (only a part). - return MessageProtections.None; + return NoneTask; } - return null; + return NullTask; } /// <summary> @@ -130,7 +137,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// Implementations that provide message protection must honor the /// <see cref="MessagePartAttribute.RequiredProtection"/> properties where applicable. /// </remarks> - public MessageProtections? ProcessIncomingMessage(IProtocolMessage message) { + public Task<MessageProtections?> ProcessIncomingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { IndirectSignedResponse response = message as IndirectSignedResponse; if (response != null) { @@ -150,11 +157,11 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { Logger.Bindings.WarnFormat("The return_to signature failed verification."); } - return MessageProtections.None; + return NoneTask; } } - return null; + return NullTask; } #endregion diff --git a/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/SigningBindingElement.cs b/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/SigningBindingElement.cs index 584b0e9..83d45a1 100644 --- a/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/SigningBindingElement.cs +++ b/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/SigningBindingElement.cs @@ -11,6 +11,8 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { using System.Globalization; using System.Linq; using System.Net.Security; + using System.Threading; + using System.Threading.Tasks; using System.Web; using DotNetOpenAuth.Loggers; using DotNetOpenAuth.Messaging; @@ -23,6 +25,8 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// Signs and verifies authentication assertions. /// </summary> internal abstract class SigningBindingElement : IChannelBindingElement { + private static readonly Task<MessageProtections?> NullTask = Task.FromResult<MessageProtections?>(null); + #region IChannelBindingElement Properties /// <summary> @@ -57,8 +61,8 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// The protections (if any) that this binding element applied to the message. /// Null if this binding element did not even apply to this binding element. /// </returns> - public virtual MessageProtections? ProcessOutgoingMessage(IProtocolMessage message) { - return null; + public virtual Task<MessageProtections?> ProcessOutgoingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { + return NullTask; } /// <summary> @@ -74,7 +78,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// Thrown when the binding element rules indicate that this message is invalid and should /// NOT be processed. /// </exception> - public MessageProtections? ProcessIncomingMessage(IProtocolMessage message) { + public async Task<MessageProtections?> ProcessIncomingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { var signedMessage = message as ITamperResistantOpenIdMessage; if (signedMessage != null) { Logger.Bindings.DebugFormat("Verifying incoming {0} message signature of: {1}", message.GetType().Name, signedMessage.Signature); @@ -92,7 +96,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { } else { ErrorUtilities.VerifyInternal(this.Channel != null, "Cannot verify private association signature because we don't have a channel."); - protectionsApplied = this.VerifySignatureByUnrecognizedHandle(message, signedMessage, protectionsApplied); + protectionsApplied = await this.VerifySignatureByUnrecognizedHandleAsync(message, signedMessage, protectionsApplied, cancellationToken); } return protectionsApplied; @@ -108,7 +112,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// <param name="signedMessage">The signed message.</param> /// <param name="protectionsApplied">The protections applied.</param> /// <returns>The applied protections.</returns> - protected abstract MessageProtections VerifySignatureByUnrecognizedHandle(IProtocolMessage message, ITamperResistantOpenIdMessage signedMessage, MessageProtections protectionsApplied); + protected abstract Task<MessageProtections> VerifySignatureByUnrecognizedHandleAsync(IProtocolMessage message, ITamperResistantOpenIdMessage signedMessage, MessageProtections protectionsApplied, CancellationToken cancellationToken); #endregion diff --git a/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/SkipSecurityBindingElement.cs b/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/SkipSecurityBindingElement.cs index d162cf6..900a422 100644 --- a/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/SkipSecurityBindingElement.cs +++ b/src/DotNetOpenAuth.OpenId/OpenId/ChannelElements/SkipSecurityBindingElement.cs @@ -10,12 +10,16 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { using System.Diagnostics; using System.Linq; using System.Text; + using System.Threading; + using System.Threading.Tasks; using DotNetOpenAuth.Messaging; /// <summary> /// Spoofs security checks on incoming OpenID messages. /// </summary> internal class SkipSecurityBindingElement : IChannelBindingElement { + private static readonly Task<MessageProtections?> NullTask = Task.FromResult<MessageProtections?>(null); + #region IChannelBindingElement Members /// <summary> @@ -50,7 +54,7 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// Implementations that provide message protection must honor the /// <see cref="MessagePartAttribute.RequiredProtection"/> properties where applicable. /// </remarks> - public MessageProtections? ProcessOutgoingMessage(IProtocolMessage message) { + public Task<MessageProtections?> ProcessOutgoingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { Debug.Fail("SkipSecurityBindingElement.ProcessOutgoingMessage should never be called."); return null; } @@ -72,14 +76,14 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// Implementations that provide message protection must honor the /// <see cref="MessagePartAttribute.RequiredProtection"/> properties where applicable. /// </remarks> - public MessageProtections? ProcessIncomingMessage(IProtocolMessage message) { + public Task<MessageProtections?> ProcessIncomingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { var signedMessage = message as ITamperResistantOpenIdMessage; if (signedMessage != null) { Logger.Bindings.DebugFormat("Skipped security checks of incoming {0} message for preview purposes.", message.GetType().Name); - return this.Protection; + return Task.FromResult<MessageProtections?>(this.Protection); } - return null; + return NullTask; } #endregion diff --git a/src/DotNetOpenAuth.OpenId/OpenId/IOpenIdHost.cs b/src/DotNetOpenAuth.OpenId/OpenId/IOpenIdHost.cs index 0c5bf80..cf52fef 100644 --- a/src/DotNetOpenAuth.OpenId/OpenId/IOpenIdHost.cs +++ b/src/DotNetOpenAuth.OpenId/OpenId/IOpenIdHost.cs @@ -22,7 +22,7 @@ namespace DotNetOpenAuth.OpenId { SecuritySettings SecuritySettings { get; } /// <summary> - /// Gets the web request handler. + /// Gets the factory for various dependencies. /// </summary> IHostFactories HostFactories { get; } } diff --git a/src/DotNetOpenAuth.OpenId/OpenId/Messages/NegativeAssertionResponse.cs b/src/DotNetOpenAuth.OpenId/OpenId/Messages/NegativeAssertionResponse.cs index d67e9fe..1bb52b9 100644 --- a/src/DotNetOpenAuth.OpenId/OpenId/Messages/NegativeAssertionResponse.cs +++ b/src/DotNetOpenAuth.OpenId/OpenId/Messages/NegativeAssertionResponse.cs @@ -9,6 +9,7 @@ namespace DotNetOpenAuth.OpenId.Messages { using System.Collections.Generic; using System.Linq; using System.Text; + using System.Threading; using System.Threading.Tasks; using DotNetOpenAuth.Messaging; using Validation; @@ -22,9 +23,19 @@ namespace DotNetOpenAuth.OpenId.Messages { /// <summary> /// Initializes a new instance of the <see cref="NegativeAssertionResponse"/> class. /// </summary> - /// <param name="request">The request that the relying party sent.</param> - internal NegativeAssertionResponse(CheckIdRequest request) - : this(request, null) { + /// <param name="request">The request.</param> + private NegativeAssertionResponse(SignedResponseRequest request) + : base(request, GetMode(request)) { + } + + /// <summary> + /// Initializes a new instance of the <see cref="NegativeAssertionResponse"/> class. + /// </summary> + /// <param name="version">The version.</param> + /// <param name="relyingPartyReturnTo">The relying party return to.</param> + /// <param name="mode">The value of the openid.mode parameter.</param> + internal NegativeAssertionResponse(Version version, Uri relyingPartyReturnTo, string mode) + : base(version, relyingPartyReturnTo, mode) { } /// <summary> @@ -32,24 +43,17 @@ namespace DotNetOpenAuth.OpenId.Messages { /// </summary> /// <param name="request">The request that the relying party sent.</param> /// <param name="channel">The channel to use to simulate construction of the user_setup_url, if applicable. May be null, but the user_setup_url will not be constructed.</param> - internal NegativeAssertionResponse(SignedResponseRequest request, Channel channel) - : base(request, GetMode(request)) { + internal static async Task<NegativeAssertionResponse> CreateAsync(SignedResponseRequest request, CancellationToken cancellationToken, Channel channel = null) { + var result = new NegativeAssertionResponse(request); + // If appropriate, and when we're provided with a channel to do it, // go ahead and construct the user_setup_url - if (this.Version.Major < 2 && request.Immediate && channel != null) { + if (result.Version.Major < 2 && request.Immediate && channel != null) { // All requests are CheckIdRequests in OpenID 1.x, so this cast should be safe. - this.UserSetupUrl = ConstructUserSetupUrl((CheckIdRequest)request, channel); + result.UserSetupUrl = await ConstructUserSetupUrlAsync((CheckIdRequest)request, channel, cancellationToken); } - } - /// <summary> - /// Initializes a new instance of the <see cref="NegativeAssertionResponse"/> class. - /// </summary> - /// <param name="version">The version.</param> - /// <param name="relyingPartyReturnTo">The relying party return to.</param> - /// <param name="mode">The value of the openid.mode parameter.</param> - internal NegativeAssertionResponse(Version version, Uri relyingPartyReturnTo, string mode) - : base(version, relyingPartyReturnTo, mode) { + return result; } /// <summary> @@ -114,7 +118,7 @@ namespace DotNetOpenAuth.OpenId.Messages { /// <param name="immediateRequest">The immediate request.</param> /// <param name="channel">The channel to use to simulate construction of the message.</param> /// <returns>The value to use for the user_setup_url parameter.</returns> - private static Uri ConstructUserSetupUrl(CheckIdRequest immediateRequest, Channel channel) { + private static async Task<Uri> ConstructUserSetupUrlAsync(CheckIdRequest immediateRequest, Channel channel, CancellationToken cancellationToken) { Requires.NotNull(immediateRequest, "immediateRequest"); Requires.NotNull(channel, "channel"); ErrorUtilities.VerifyInternal(immediateRequest.Immediate, "Only immediate requests should be sent here."); @@ -124,7 +128,7 @@ namespace DotNetOpenAuth.OpenId.Messages { setupRequest.ReturnTo = immediateRequest.ReturnTo; setupRequest.Realm = immediateRequest.Realm; setupRequest.AssociationHandle = immediateRequest.AssociationHandle; - var response = channel.PrepareResponse(setupRequest); + var response = await channel.PrepareResponseAsync(setupRequest, cancellationToken); return response.GetDirectUriRequest(); } diff --git a/src/DotNetOpenAuth.OpenId/OpenId/OpenIdUtilities.cs b/src/DotNetOpenAuth.OpenId/OpenId/OpenIdUtilities.cs index f0ed946..f8d542d 100644 --- a/src/DotNetOpenAuth.OpenId/OpenId/OpenIdUtilities.cs +++ b/src/DotNetOpenAuth.OpenId/OpenId/OpenIdUtilities.cs @@ -187,24 +187,36 @@ namespace DotNetOpenAuth.OpenId { internal static HttpClient CreateHttpClient(this IHostFactories hostFactories, bool requireSsl, RequestCachePolicy cachePolicy = null) { Requires.NotNull(hostFactories, "hostFactories"); - var handler = hostFactories.CreateHttpMessageHandler(); - var webRequestHandler = handler as WebRequestHandler; - var untrustedHandler = handler as UntrustedWebRequestHandler; - if (webRequestHandler != null) { - if (cachePolicy != null) { - webRequestHandler.CachePolicy = cachePolicy; - } - } else if (untrustedHandler != null) { - if (cachePolicy != null) { - untrustedHandler.CachePolicy = cachePolicy; - } + var rootHandler = hostFactories.CreateHttpMessageHandler(); + var handler = rootHandler; + do { + var webRequestHandler = handler as WebRequestHandler; + var untrustedHandler = handler as UntrustedWebRequestHandler; + var delegatingHandler = handler as DelegatingHandler; + if (webRequestHandler != null) { + if (cachePolicy != null) { + webRequestHandler.CachePolicy = cachePolicy; + } - untrustedHandler.IsSslRequired = requireSsl; - } else { - Logger.Http.DebugFormat("Unable to set cache policy on unsupported {0}.", handler.GetType().FullName); + break; + } else if (untrustedHandler != null) { + if (cachePolicy != null) { + untrustedHandler.CachePolicy = cachePolicy; + } + + untrustedHandler.IsSslRequired = requireSsl; + break; + } else if (delegatingHandler != null) { + handler = delegatingHandler.InnerHandler; + } else { + Logger.Http.DebugFormat("Unable to set cache policy on unsupported {0}.", handler.GetType().FullName); + break; + } } + while (true); + - return hostFactories.CreateHttpClient(handler); + return hostFactories.CreateHttpClient(rootHandler); } internal static Uri GetDirectUriRequest(this HttpResponseMessage response) { diff --git a/src/DotNetOpenAuth.OpenId/OpenId/RelyingParty/IAuthenticationRequest.cs b/src/DotNetOpenAuth.OpenId/OpenId/RelyingParty/IAuthenticationRequest.cs index 35a92e1..34a2595 100644 --- a/src/DotNetOpenAuth.OpenId/OpenId/RelyingParty/IAuthenticationRequest.cs +++ b/src/DotNetOpenAuth.OpenId/OpenId/RelyingParty/IAuthenticationRequest.cs @@ -10,6 +10,8 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { using System.Linq; using System.Net.Http; using System.Text; + using System.Threading; + using System.Threading.Tasks; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId.Messages; @@ -25,12 +27,6 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { AuthenticationRequestMode Mode { get; set; } /// <summary> - /// Gets the HTTP response the relying party should send to the user agent - /// to redirect it to the OpenID Provider to start the OpenID authentication process. - /// </summary> - HttpResponseMessage RedirectingResponse { get; } - - /// <summary> /// Gets the URL that the user agent will return to after authentication /// completes or fails at the Provider. /// </summary> @@ -174,12 +170,9 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { void AddExtension(IOpenIdMessageExtension extension); /// <summary> - /// Redirects the user agent to the provider for authentication. - /// Execution of the current page terminates after this call. + /// Gets the HTTP response the relying party should send to the user agent + /// to redirect it to the OpenID Provider to start the OpenID authentication process. /// </summary> - /// <remarks> - /// This method requires an ASP.NET HttpContext. - /// </remarks> - void RedirectToProvider(); + Task<HttpResponseMessage> GetRedirectingResponseAsync(CancellationToken cancellationToken); } } diff --git a/src/DotNetOpenAuth.OpenId/OpenId/UntrustedWebRequestHandler.cs b/src/DotNetOpenAuth.OpenId/OpenId/UntrustedWebRequestHandler.cs index c61ac7f..25d4bb6 100644 --- a/src/DotNetOpenAuth.OpenId/OpenId/UntrustedWebRequestHandler.cs +++ b/src/DotNetOpenAuth.OpenId/OpenId/UntrustedWebRequestHandler.cs @@ -207,6 +207,31 @@ namespace DotNetOpenAuth.OpenId { return client; } + /// <summary> + /// Determines whether an exception was thrown because of the remote HTTP server returning HTTP 417 Expectation Failed. + /// </summary> + /// <param name="ex">The caught exception.</param> + /// <returns> + /// <c>true</c> if the failure was originally caused by a 417 Exceptation Failed error; otherwise, <c>false</c>. + /// </returns> + internal static bool IsExceptionFrom417ExpectationFailed(Exception ex) { + while (ex != null) { + WebException webEx = ex as WebException; + if (webEx != null) { + HttpWebResponse response = webEx.Response as HttpWebResponse; + if (response != null) { + if (response.StatusCode == HttpStatusCode.ExpectationFailed) { + return true; + } + } + } + + ex = ex.InnerException; + } + + return false; + } + protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) { this.EnsureAllowableRequestUri(request.RequestUri); @@ -226,9 +251,24 @@ namespace DotNetOpenAuth.OpenId { ErrorUtilities.VerifyProtocol(request.Method != HttpMethod.Post, MessagingStrings.UntrustedRedirectsOnPOSTNotSupported); Uri redirectUri = new Uri(request.RequestUri, response.Headers.Location); request = request.Clone(redirectUri); - } else { - return response; + continue; } + + if (response.StatusCode == HttpStatusCode.ExpectationFailed) { + // Some OpenID servers doesn't understand the Expect header and send 417 error back. + // If this server just failed from that, alter the ServicePoint for this server + // so that we don't send that header again next time (whenever that is). + // "Expect: 100-Continue" HTTP header. (see Google Code Issue 72) + // We don't want to blindly set all ServicePoints to not use the Expect header + // as that would be a security hole allowing any visitor to a web site change + // the web site's global behavior when calling that host. + // TODO: verify that this still works in DNOA 5.0 + var servicePoint = ServicePointManager.FindServicePoint(request.RequestUri); + Logger.Http.InfoFormat("HTTP POST to {0} resulted in 417 Expectation Failed. Changing ServicePoint to not use Expect: Continue next time.", request.RequestUri); + servicePoint.Expect100Continue = false; + } + + return response; } throw ErrorUtilities.ThrowProtocol(MessagingStrings.TooManyRedirects, originalRequestUri); diff --git a/src/DotNetOpenAuth.sln b/src/DotNetOpenAuth.sln index c56772f..db126bc 100644 --- a/src/DotNetOpenAuth.sln +++ b/src/DotNetOpenAuth.sln @@ -535,12 +535,12 @@ Global {C7EF1823-3AA7-477E-8476-28929F5C05D2} = {8D4236F7-C49B-49D3-BA71-6B86C9514BDE} {9AF74F53-10F5-49A2-B747-87B97CD559D3} = {8D4236F7-C49B-49D3-BA71-6B86C9514BDE} {529B4262-6B5A-4EF9-BD3B-1D29A2597B67} = {8D4236F7-C49B-49D3-BA71-6B86C9514BDE} - {238B6BA8-AD99-43C9-B8E2-D2BCE6CE04DC} = {8D4236F7-C49B-49D3-BA71-6B86C9514BDE} {57A7DD35-666C-4FA3-9A1B-38961E50CA27} = {8D4236F7-C49B-49D3-BA71-6B86C9514BDE} {60426312-6AE5-4835-8667-37EDEA670222} = {8D4236F7-C49B-49D3-BA71-6B86C9514BDE} {173E7B8D-E751-46E2-A133-F72297C0D2F4} = {8D4236F7-C49B-49D3-BA71-6B86C9514BDE} {51835086-9611-4C53-819B-F2D5C9320873} = {8D4236F7-C49B-49D3-BA71-6B86C9514BDE} {115217C5-22CD-415C-A292-0DD0238CDD89} = {8D4236F7-C49B-49D3-BA71-6B86C9514BDE} + {238B6BA8-AD99-43C9-B8E2-D2BCE6CE04DC} = {8D4236F7-C49B-49D3-BA71-6B86C9514BDE} {3896A32A-E876-4C23-B9B8-78E17D134CD3} = {C7EF1823-3AA7-477E-8476-28929F5C05D2} {F8284738-3B5D-4733-A511-38C23F4A763F} = {C7EF1823-3AA7-477E-8476-28929F5C05D2} {F458AB60-BA1C-43D9-8CEF-EC01B50BE87B} = {C7EF1823-3AA7-477E-8476-28929F5C05D2} |