diff options
Diffstat (limited to 'src/DotNetOpenAuth.Test')
75 files changed, 2110 insertions, 3237 deletions
diff --git a/src/DotNetOpenAuth.Test/AutoRedirectHandler.cs b/src/DotNetOpenAuth.Test/AutoRedirectHandler.cs new file mode 100644 index 0000000..3f67259 --- /dev/null +++ b/src/DotNetOpenAuth.Test/AutoRedirectHandler.cs @@ -0,0 +1,33 @@ +namespace DotNetOpenAuth.Test { + using System; + using System.Collections.Generic; + using System.Linq; + using System.Net; + using System.Net.Http; + using System.Text; + using System.Threading.Tasks; + + using DotNetOpenAuth.Messaging; + + internal class AutoRedirectHandler : DelegatingHandler { + internal AutoRedirectHandler(HttpMessageHandler innerHandler) + : base(innerHandler) { + } + + protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, System.Threading.CancellationToken cancellationToken) { + HttpResponseMessage response = null; + do { + if (response != null) { + var modifiedRequest = MessagingUtilities.Clone(request); + modifiedRequest.RequestUri = new Uri(request.RequestUri, response.Headers.Location); + request = modifiedRequest; + } + + response = await base.SendAsync(request, cancellationToken); + } + while (response.StatusCode == HttpStatusCode.Redirect); + + return response; + } + } +} diff --git a/src/DotNetOpenAuth.Test/CookieContainerExtensions.cs b/src/DotNetOpenAuth.Test/CookieContainerExtensions.cs new file mode 100644 index 0000000..5a09d13 --- /dev/null +++ b/src/DotNetOpenAuth.Test/CookieContainerExtensions.cs @@ -0,0 +1,38 @@ +//----------------------------------------------------------------------- +// <copyright file="CookieContainerExtensions.cs" company="Andrew Arnott"> +// Copyright (c) Andrew Arnott. All rights reserved. +// </copyright> +//----------------------------------------------------------------------- + +namespace DotNetOpenAuth.Test { + using System; + using System.Collections.Generic; + using System.Net; + using System.Net.Http; + + using Validation; + + internal static class CookieContainerExtensions { + internal static void SetCookies(this CookieContainer container, HttpResponseMessage response, Uri requestUri = null) { + Requires.NotNull(container, "container"); + Requires.NotNull(response, "response"); + + IEnumerable<string> cookieHeaders; + if (response.Headers.TryGetValues("Set-Cookie", out cookieHeaders)) { + foreach (string cookie in cookieHeaders) { + container.SetCookies(requestUri ?? response.RequestMessage.RequestUri, cookie); + } + } + } + + internal static void ApplyCookies(this CookieContainer container, HttpRequestMessage request) { + Requires.NotNull(container, "container"); + Requires.NotNull(request, "request"); + + string cookieHeader = container.GetCookieHeader(request.RequestUri); + if (!string.IsNullOrEmpty(cookieHeader)) { + request.Headers.TryAddWithoutValidation("Cookie", cookieHeader); + } + } + } +}
\ No newline at end of file diff --git a/src/DotNetOpenAuth.Test/CookieDelegatingHandler.cs b/src/DotNetOpenAuth.Test/CookieDelegatingHandler.cs new file mode 100644 index 0000000..1b25dc0 --- /dev/null +++ b/src/DotNetOpenAuth.Test/CookieDelegatingHandler.cs @@ -0,0 +1,31 @@ +//----------------------------------------------------------------------- +// <copyright file="CookieDelegatingHandler.cs" company="Andrew Arnott"> +// Copyright (c) Andrew Arnott. All rights reserved. +// </copyright> +//----------------------------------------------------------------------- + +namespace DotNetOpenAuth.Test { + using System.Linq; + using System.Net; + using System.Net.Http; + using System.Net.Http.Headers; + using System.Text; + using System.Threading; + using System.Threading.Tasks; + + internal class CookieDelegatingHandler : DelegatingHandler { + internal CookieDelegatingHandler(HttpMessageHandler innerHandler, CookieContainer cookieContainer = null) + : base(innerHandler) { + this.Container = cookieContainer ?? new CookieContainer(); + } + + public CookieContainer Container { get; set; } + + protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) { + this.Container.ApplyCookies(request); + var response = await base.SendAsync(request, cancellationToken); + this.Container.SetCookies(response); + return response; + } + } +} diff --git a/src/DotNetOpenAuth.Test/CoordinatorBase.cs b/src/DotNetOpenAuth.Test/CoordinatorBase.cs deleted file mode 100644 index d1c6f85..0000000 --- a/src/DotNetOpenAuth.Test/CoordinatorBase.cs +++ /dev/null @@ -1,91 +0,0 @@ -//----------------------------------------------------------------------- -// <copyright file="CoordinatorBase.cs" company="Outercurve Foundation"> -// Copyright (c) Outercurve Foundation. All rights reserved. -// </copyright> -//----------------------------------------------------------------------- - -namespace DotNetOpenAuth.Test { - using System; - using System.Threading; - using DotNetOpenAuth.Messaging; - using DotNetOpenAuth.OpenId.RelyingParty; - using DotNetOpenAuth.Test.Mocks; - using NUnit.Framework; - using Validation; - - internal abstract class CoordinatorBase<T1, T2> { - private Action<T1> party1Action; - private Action<T2> party2Action; - - protected CoordinatorBase(Action<T1> party1Action, Action<T2> party2Action) { - Requires.NotNull(party1Action, "party1Action"); - Requires.NotNull(party2Action, "party2Action"); - - this.party1Action = party1Action; - this.party2Action = party2Action; - } - - protected internal Action<IProtocolMessage> IncomingMessageFilter { get; set; } - - protected internal Action<IProtocolMessage> OutgoingMessageFilter { get; set; } - - internal abstract void Run(); - - protected void RunCore(T1 party1Object, T2 party2Object) { - Thread party1Thread = null, party2Thread = null; - Exception failingException = null; - - // Each thread we create needs a surrounding exception catcher so that we can - // terminate the other thread and inform the test host that the test failed. - Action<Action> safeWrapper = (action) => { - try { - TestBase.SetMockHttpContext(); - action(); - } catch (Exception ex) { - // We may be the second thread in an ThreadAbortException, so check the "flag" - lock (this) { - if (failingException == null || (failingException is ThreadAbortException && !(ex is ThreadAbortException))) { - failingException = ex; - if (Thread.CurrentThread == party1Thread) { - party2Thread.Abort(); - } else { - party1Thread.Abort(); - } - } - } - } - }; - - // Run the threads, and wait for them to complete. - // If this main thread is aborted (test run aborted), go ahead and abort the other two threads. - party1Thread = new Thread(() => { safeWrapper(() => { this.party1Action(party1Object); }); }); - party2Thread = new Thread(() => { safeWrapper(() => { this.party2Action(party2Object); }); }); - party1Thread.Name = "P1"; - party2Thread.Name = "P2"; - try { - party1Thread.Start(); - party2Thread.Start(); - party1Thread.Join(); - party2Thread.Join(); - } catch (ThreadAbortException) { - party1Thread.Abort(); - party2Thread.Abort(); - throw; - } catch (ThreadStartException ex) { - if (ex.InnerException is ThreadAbortException) { - // if party1Thread threw an exception - // (which may even have been intentional for the test) - // before party2Thread even started, then this exception - // can be thrown, and should be ignored. - } else { - throw; - } - } - - // Use the failing reason of a failing sub-thread as our reason, if anything failed. - if (failingException != null) { - throw new AssertionException("Coordinator thread threw unhandled exception: " + failingException, failingException); - } - } - } -} diff --git a/src/DotNetOpenAuth.Test/DotNetOpenAuth.Test.csproj b/src/DotNetOpenAuth.Test/DotNetOpenAuth.Test.csproj index 0e13449..ab90935 100644 --- a/src/DotNetOpenAuth.Test/DotNetOpenAuth.Test.csproj +++ b/src/DotNetOpenAuth.Test/DotNetOpenAuth.Test.csproj @@ -84,6 +84,7 @@ <RequiredTargetFramework>3.5</RequiredTargetFramework> </Reference> <Reference Include="System.Net.Http" /> + <Reference Include="System.Net.Http.Formatting, Version=4.0.0.0, Culture=neutral, PublicKeyToken=31bf3856ad364e35, processorArchitecture=MSIL" /> <Reference Include="System.Net.Http.WebRequest" /> <Reference Include="System.Runtime.Serialization"> <RequiredTargetFramework>3.0</RequiredTargetFramework> @@ -97,14 +98,14 @@ <Reference Include="System.Xml.Linq"> <RequiredTargetFramework>3.5</RequiredTargetFramework> </Reference> - <Reference Include="Validation"> - <HintPath>..\packages\Validation.2.0.1.12362\lib\portable-windows8+net40+sl5+windowsphone8\Validation.dll</HintPath> - <Private>True</Private> + <Reference Include="Validation, Version=2.0.0.0, Culture=neutral, PublicKeyToken=2fc06f0d701809a7, processorArchitecture=MSIL"> + <SpecificVersion>False</SpecificVersion> + <HintPath>..\packages\Validation.2.0.2.13022\lib\portable-windows8+net40+sl5+windowsphone8\Validation.dll</HintPath> </Reference> </ItemGroup> <ItemGroup> + <Compile Include="AutoRedirectHandler.cs" /> <Compile Include="Configuration\SectionTests.cs" /> - <Compile Include="CoordinatorBase.cs" /> <Compile Include="Hosting\AspNetHost.cs" /> <Compile Include="Hosting\HostingTests.cs" /> <Compile Include="Hosting\HttpHost.cs" /> @@ -114,8 +115,6 @@ <Compile Include="Messaging\EnumerableCacheTests.cs" /> <Compile Include="Messaging\ErrorUtilitiesTests.cs" /> <Compile Include="Messaging\MessageSerializerTests.cs" /> - <Compile Include="Messaging\MultipartPostPartTests.cs" /> - <Compile Include="Messaging\OutgoingWebResponseTests.cs" /> <Compile Include="Messaging\Reflection\MessageDescriptionTests.cs" /> <Compile Include="Messaging\Reflection\MessageDictionaryTests.cs" /> <Compile Include="Messaging\MessagingTestBase.cs" /> @@ -127,19 +126,14 @@ <Compile Include="Messaging\Reflection\MessagePartTests.cs" /> <Compile Include="Messaging\Reflection\ValueMappingTests.cs" /> <Compile Include="Messaging\StandardMessageFactoryTests.cs" /> + <Compile Include="MockingHostFactories.cs" /> <Compile Include="Mocks\AssociateUnencryptedRequestNoSslCheck.cs" /> - <Compile Include="Mocks\CoordinatingChannel.cs" /> - <Compile Include="Mocks\CoordinatingHttpRequestInfo.cs" /> - <Compile Include="Mocks\CoordinatingOAuth2AuthServerChannel.cs" /> - <Compile Include="Mocks\CoordinatingOAuth2ClientChannel.cs" /> - <Compile Include="Mocks\CoordinatingOutgoingWebResponse.cs" /> - <Compile Include="Mocks\CoordinatingOAuthConsumerChannel.cs" /> + <Compile Include="CookieContainerExtensions.cs" /> + <Compile Include="CookieDelegatingHandler.cs" /> <Compile Include="Mocks\IBaseMessageExplicitMembers.cs" /> <Compile Include="Mocks\InMemoryTokenManager.cs" /> <Compile Include="Mocks\MockHttpMessageHandler.cs" /> <Compile Include="Mocks\MockHttpRequest.cs" /> - <Compile Include="Mocks\MockIdentifier.cs" /> - <Compile Include="Mocks\MockIdentifierDiscoveryService.cs" /> <Compile Include="Mocks\MockOpenIdExtension.cs" /> <Compile Include="Mocks\MockRealm.cs" /> <Compile Include="Mocks\MockTransformationBindingElement.cs" /> @@ -154,7 +148,6 @@ <Compile Include="Mocks\TestExpiringMessage.cs" /> <Compile Include="Mocks\TestSignedDirectedMessage.cs" /> <Compile Include="Mocks\MockSigningBindingElement.cs" /> - <Compile Include="Mocks\TestWebRequestHandler.cs" /> <Compile Include="Mocks\TestChannel.cs" /> <Compile Include="Mocks\TestMessage.cs" /> <Compile Include="Mocks\TestMessageFactory.cs" /> @@ -162,7 +155,6 @@ <Compile Include="OAuth2\MessageFactoryTests.cs" /> <Compile Include="OAuth2\ResourceServerTests.cs" /> <Compile Include="OAuth2\UserAgentClientAuthorizeTests.cs" /> - <Compile Include="OAuth2\OAuth2Coordinator.cs" /> <Compile Include="OAuth2\OAuth2TestBase.cs" /> <Compile Include="OAuth2\WebServerClientAuthorizeTests.cs" /> <Compile Include="OAuth\ChannelElements\HmacSha1SigningBindingElementTests.cs" /> @@ -217,7 +209,6 @@ <Compile Include="OpenId\Messages\PositiveAssertionResponseTests.cs" /> <Compile Include="OpenId\Messages\SignedResponseRequestTests.cs" /> <Compile Include="OpenId\NonIdentityTests.cs" /> - <Compile Include="OpenId\OpenIdCoordinator.cs" /> <Compile Include="OpenId\AssociationHandshakeTests.cs" /> <Compile Include="OpenId\OpenIdTestBase.cs" /> <Compile Include="OpenId\OpenIdUtilitiesTests.cs" /> @@ -244,10 +235,7 @@ <Compile Include="Performance\HighPerformance.cs" /> <Compile Include="Performance\PerformanceTestUtilities.cs" /> <Compile Include="Properties\AssemblyInfo.cs" /> - <Compile Include="Messaging\ResponseTests.cs" /> <Compile Include="OAuth\AppendixScenarios.cs" /> - <Compile Include="Mocks\CoordinatingOAuthServiceProviderChannel.cs" /> - <Compile Include="OAuth\OAuthCoordinator.cs" /> <Compile Include="TestBase.cs" /> <Compile Include="TestUtilities.cs" /> <Compile Include="UriUtilTests.cs" /> @@ -338,6 +326,10 @@ <Project>{60426312-6AE5-4835-8667-37EDEA670222}</Project> <Name>DotNetOpenAuth.Core</Name> </ProjectReference> + <ProjectReference Include="..\DotNetOpenAuth.OAuth.Common\DotNetOpenAuth.OAuth.Common.csproj"> + <Project>{115217c5-22cd-415c-a292-0dd0238cdd89}</Project> + <Name>DotNetOpenAuth.OAuth.Common</Name> + </ProjectReference> <ProjectReference Include="..\DotNetOpenAuth.OAuth.Consumer\DotNetOpenAuth.OAuth.Consumer.csproj"> <Project>{B202E40D-4663-4A2B-ACDA-865F88FF7CAA}</Project> <Name>DotNetOpenAuth.OAuth.Consumer</Name> @@ -394,6 +386,14 @@ <Project>{75E13AAE-7D51-4421-ABFD-3F3DC91F576E}</Project> <Name>DotNetOpenAuth.OpenId.UI</Name> </ProjectReference> + <ProjectReference Include="..\DotNetOpenAuth.OpenIdInfoCard.UI\DotNetOpenAuth.OpenIdInfoCard.UI.csproj"> + <Project>{3a8347e8-59a5-4092-8842-95c75d7d2f36}</Project> + <Name>DotNetOpenAuth.OpenIdInfoCard.UI</Name> + </ProjectReference> + <ProjectReference Include="..\DotNetOpenAuth.OpenIdOAuth\DotNetOpenAuth.OpenIdOAuth.csproj"> + <Project>{4bfaa336-5df3-4f27-82d3-06d13240e8ab}</Project> + <Name>DotNetOpenAuth.OpenIdOAuth</Name> + </ProjectReference> <ProjectReference Include="..\DotNetOpenAuth.OpenId\DotNetOpenAuth.OpenId.csproj"> <Project>{3896A32A-E876-4C23-B9B8-78E17D134CD3}</Project> <Name>DotNetOpenAuth.OpenId</Name> diff --git a/src/DotNetOpenAuth.Test/Messaging/Bindings/StandardExpirationBindingElementTests.cs b/src/DotNetOpenAuth.Test/Messaging/Bindings/StandardExpirationBindingElementTests.cs index 6aa9461..a8eb7c1 100644 --- a/src/DotNetOpenAuth.Test/Messaging/Bindings/StandardExpirationBindingElementTests.cs +++ b/src/DotNetOpenAuth.Test/Messaging/Bindings/StandardExpirationBindingElementTests.cs @@ -6,7 +6,7 @@ namespace DotNetOpenAuth.Test.Messaging.Bindings { using System; - + using System.Threading.Tasks; using DotNetOpenAuth.Configuration; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Bindings; @@ -16,38 +16,38 @@ namespace DotNetOpenAuth.Test.Messaging.Bindings { [TestFixture] public class StandardExpirationBindingElementTests : MessagingTestBase { [Test] - public void SendSetsTimestamp() { + public async Task SendSetsTimestamp() { TestExpiringMessage message = new TestExpiringMessage(MessageTransport.Indirect); message.Recipient = new Uri("http://localtest"); ((IExpiringProtocolMessage)message).UtcCreationDate = DateTime.Parse("1/1/1990"); Channel channel = CreateChannel(MessageProtections.Expiration); - channel.PrepareResponse(message); + await channel.PrepareResponseAsync(message); Assert.IsTrue(DateTime.UtcNow - ((IExpiringProtocolMessage)message).UtcCreationDate < TimeSpan.FromSeconds(3), "The timestamp on the message was not set on send."); } [Test] - public void VerifyGoodTimestampIsAccepted() { + public async Task VerifyGoodTimestampIsAccepted() { this.Channel = CreateChannel(MessageProtections.Expiration); - this.ParameterizedReceiveProtectedTest(DateTime.UtcNow, false); + await this.ParameterizedReceiveProtectedTestAsync(DateTime.UtcNow, false); } [Test] - public void VerifyFutureTimestampWithinClockSkewIsAccepted() { + public async Task VerifyFutureTimestampWithinClockSkewIsAccepted() { this.Channel = CreateChannel(MessageProtections.Expiration); - this.ParameterizedReceiveProtectedTest(DateTime.UtcNow + DotNetOpenAuthSection.Messaging.MaximumClockSkew - TimeSpan.FromSeconds(1), false); + await this.ParameterizedReceiveProtectedTestAsync(DateTime.UtcNow + DotNetOpenAuthSection.Messaging.MaximumClockSkew - TimeSpan.FromSeconds(1), false); } [Test, ExpectedException(typeof(ExpiredMessageException))] - public void VerifyOldTimestampIsRejected() { + public async Task VerifyOldTimestampIsRejected() { this.Channel = CreateChannel(MessageProtections.Expiration); - this.ParameterizedReceiveProtectedTest(DateTime.UtcNow - StandardExpirationBindingElement.MaximumMessageAge - TimeSpan.FromSeconds(1), false); + await this.ParameterizedReceiveProtectedTestAsync(DateTime.UtcNow - StandardExpirationBindingElement.MaximumMessageAge - TimeSpan.FromSeconds(1), false); } [Test, ExpectedException(typeof(ProtocolException))] - public void VerifyFutureTimestampIsRejected() { + public async Task VerifyFutureTimestampIsRejected() { this.Channel = CreateChannel(MessageProtections.Expiration); - this.ParameterizedReceiveProtectedTest(DateTime.UtcNow + DotNetOpenAuthSection.Messaging.MaximumClockSkew + TimeSpan.FromSeconds(2), false); + await this.ParameterizedReceiveProtectedTestAsync(DateTime.UtcNow + DotNetOpenAuthSection.Messaging.MaximumClockSkew + TimeSpan.FromSeconds(2), false); } } } diff --git a/src/DotNetOpenAuth.Test/Messaging/Bindings/StandardReplayProtectionBindingElementTests.cs b/src/DotNetOpenAuth.Test/Messaging/Bindings/StandardReplayProtectionBindingElementTests.cs index 9a46e42..04c63ef 100644 --- a/src/DotNetOpenAuth.Test/Messaging/Bindings/StandardReplayProtectionBindingElementTests.cs +++ b/src/DotNetOpenAuth.Test/Messaging/Bindings/StandardReplayProtectionBindingElementTests.cs @@ -9,6 +9,8 @@ namespace DotNetOpenAuth.Test.Messaging.Bindings { using System.Collections.Generic; using System.Linq; using System.Text; + using System.Threading; + using System.Threading.Tasks; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Bindings; using DotNetOpenAuth.OpenId; @@ -40,14 +42,14 @@ namespace DotNetOpenAuth.Test.Messaging.Bindings { /// Verifies that the generated nonce includes random characters. /// </summary> [Test] - public void RandomCharactersTest() { - Assert.IsNotNull(this.nonceElement.ProcessOutgoingMessage(this.message)); + public async Task RandomCharactersTest() { + Assert.IsNotNull(await this.nonceElement.ProcessOutgoingMessageAsync(this.message, CancellationToken.None)); Assert.IsNotNull(this.message.Nonce, "No nonce was set on the message."); Assert.AreNotEqual(0, this.message.Nonce.Length, "The generated nonce was empty."); string firstNonce = this.message.Nonce; // Apply another nonce and verify that they are different than the first ones. - Assert.IsNotNull(this.nonceElement.ProcessOutgoingMessage(this.message)); + Assert.IsNotNull(await this.nonceElement.ProcessOutgoingMessageAsync(this.message, CancellationToken.None)); Assert.IsNotNull(this.message.Nonce, "No nonce was set on the message."); Assert.AreNotEqual(0, this.message.Nonce.Length, "The generated nonce was empty."); Assert.AreNotEqual(firstNonce, this.message.Nonce, "The two generated nonces are identical."); @@ -57,41 +59,41 @@ namespace DotNetOpenAuth.Test.Messaging.Bindings { /// Verifies that a message is received correctly. /// </summary> [Test] - public void ValidMessageReceivedTest() { + public async Task ValidMessageReceivedTest() { this.message.Nonce = "a"; - Assert.IsNotNull(this.nonceElement.ProcessIncomingMessage(this.message)); + Assert.IsNotNull(await this.nonceElement.ProcessIncomingMessageAsync(this.message, CancellationToken.None)); } /// <summary> /// Verifies that a message that doesn't have a string of random characters is received correctly. /// </summary> [Test] - public void ValidMessageNoNonceReceivedTest() { + public async Task ValidMessageNoNonceReceivedTest() { this.message.Nonce = string.Empty; this.nonceElement.AllowZeroLengthNonce = true; - Assert.IsNotNull(this.nonceElement.ProcessIncomingMessage(this.message)); + Assert.IsNotNull(await this.nonceElement.ProcessIncomingMessageAsync(this.message, CancellationToken.None)); } /// <summary> /// Verifies that a message that doesn't have a string of random characters is received correctly. /// </summary> [Test, ExpectedException(typeof(ProtocolException))] - public void InvalidMessageNoNonceReceivedTest() { + public async Task InvalidMessageNoNonceReceivedTest() { this.message.Nonce = string.Empty; this.nonceElement.AllowZeroLengthNonce = false; - Assert.IsNotNull(this.nonceElement.ProcessIncomingMessage(this.message)); + Assert.IsNotNull(await this.nonceElement.ProcessIncomingMessageAsync(this.message, CancellationToken.None)); } /// <summary> /// Verifies that a replayed message is rejected. /// </summary> [Test, ExpectedException(typeof(ReplayedMessageException))] - public void ReplayDetectionTest() { + public async Task ReplayDetectionTest() { this.message.Nonce = "a"; - Assert.IsNotNull(this.nonceElement.ProcessIncomingMessage(this.message)); + Assert.IsNotNull(await this.nonceElement.ProcessIncomingMessageAsync(this.message, CancellationToken.None)); // Now receive the same message again. This should throw because it's a message replay. - this.nonceElement.ProcessIncomingMessage(this.message); + await this.nonceElement.ProcessIncomingMessageAsync(this.message, CancellationToken.None); } } } diff --git a/src/DotNetOpenAuth.Test/Messaging/ChannelTests.cs b/src/DotNetOpenAuth.Test/Messaging/ChannelTests.cs index 5646a7e..9050fad 100644 --- a/src/DotNetOpenAuth.Test/Messaging/ChannelTests.cs +++ b/src/DotNetOpenAuth.Test/Messaging/ChannelTests.cs @@ -9,29 +9,41 @@ namespace DotNetOpenAuth.Test.Messaging { using System.Collections.Generic; using System.IO; using System.Net; + using System.Net.Http; + using System.Threading; + using System.Threading.Tasks; using System.Web; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Bindings; + using DotNetOpenAuth.OpenId; using DotNetOpenAuth.Test.Mocks; using NUnit.Framework; [TestFixture] public class ChannelTests : MessagingTestBase { [Test, ExpectedException(typeof(ArgumentNullException))] - public void CtorNull() { - // This bad channel is deliberately constructed to pass null to - // its protected base class' constructor. - new TestBadChannel(true); + public void CtorNullFirstParameter() { + new TestBadChannel(null, new IChannelBindingElement[0], new DefaultOpenIdHostFactories()); + } + + [Test, ExpectedException(typeof(ArgumentNullException))] + public void CtorNullSecondParameter() { + new TestBadChannel(new TestMessageFactory(), null, new DefaultOpenIdHostFactories()); + } + + [Test, ExpectedException(typeof(ArgumentNullException))] + public void CtorNullThirdParameter() { + new TestBadChannel(new TestMessageFactory(), new IChannelBindingElement[0], null); } [Test] - public void ReadFromRequestQueryString() { - this.ParameterizedReceiveTest("GET"); + public async Task ReadFromRequestQueryString() { + await this.ParameterizedReceiveTestAsync(HttpMethod.Get); } [Test] - public void ReadFromRequestForm() { - this.ParameterizedReceiveTest("POST"); + public async Task ReadFromRequestForm() { + await this.ParameterizedReceiveTestAsync(HttpMethod.Post); } /// <summary> @@ -39,78 +51,78 @@ namespace DotNetOpenAuth.Test.Messaging { /// will reject messages that come with an unexpected HTTP verb. /// </summary> [Test, ExpectedException(typeof(ProtocolException))] - public void ReadFromRequestDisallowedHttpMethod() { + public async Task ReadFromRequestDisallowedHttpMethod() { var fields = GetStandardTestFields(FieldFill.CompleteBeforeBindings); fields["GetOnly"] = "true"; - this.Channel.ReadFromRequest(CreateHttpRequestInfo("POST", fields)); + await this.Channel.ReadFromRequestAsync(CreateHttpRequestInfo(HttpMethod.Post, fields), CancellationToken.None); } [Test, ExpectedException(typeof(ArgumentNullException))] - public void SendNull() { - this.Channel.PrepareResponse(null); + public async Task SendNull() { + await this.Channel.PrepareResponseAsync(null); } [Test, ExpectedException(typeof(ArgumentException))] - public void SendIndirectedUndirectedMessage() { + public async Task SendIndirectedUndirectedMessage() { IProtocolMessage message = new TestDirectedMessage(MessageTransport.Indirect); - this.Channel.PrepareResponse(message); + await this.Channel.PrepareResponseAsync(message); } [Test, ExpectedException(typeof(ArgumentException))] - public void SendDirectedNoRecipientMessage() { + public async Task SendDirectedNoRecipientMessage() { IProtocolMessage message = new TestDirectedMessage(MessageTransport.Indirect); - this.Channel.PrepareResponse(message); + await this.Channel.PrepareResponseAsync(message); } [Test, ExpectedException(typeof(ArgumentException))] - public void SendInvalidMessageTransport() { + public async Task SendInvalidMessageTransport() { IProtocolMessage message = new TestDirectedMessage((MessageTransport)100); - this.Channel.PrepareResponse(message); + await this.Channel.PrepareResponseAsync(message); } [Test] - public void SendIndirectMessage301Get() { + public async Task SendIndirectMessage301Get() { TestDirectedMessage message = new TestDirectedMessage(MessageTransport.Indirect); GetStandardTestMessage(FieldFill.CompleteBeforeBindings, message); message.Recipient = new Uri("http://provider/path"); var expected = GetStandardTestFields(FieldFill.CompleteBeforeBindings); - OutgoingWebResponse response = this.Channel.PrepareResponse(message); - Assert.AreEqual(HttpStatusCode.Redirect, response.Status); - Assert.AreEqual("text/html; charset=utf-8", response.Headers[HttpResponseHeader.ContentType]); - Assert.IsTrue(response.Body != null && response.Body.Length > 0); // a non-empty body helps get passed filters like WebSense - StringAssert.StartsWith("http://provider/path", response.Headers[HttpResponseHeader.Location]); + var response = await this.Channel.PrepareResponseAsync(message); + Assert.AreEqual(HttpStatusCode.Redirect, response.StatusCode); + Assert.AreEqual("text/html; charset=utf-8", response.Content.Headers.ContentType.ToString()); + Assert.IsTrue(response.Content != null && response.Content.Headers.ContentLength > 0); // a non-empty body helps get passed filters like WebSense + StringAssert.StartsWith("http://provider/path", response.Headers.Location.AbsoluteUri); foreach (var pair in expected) { string key = MessagingUtilities.EscapeUriDataStringRfc3986(pair.Key); string value = MessagingUtilities.EscapeUriDataStringRfc3986(pair.Value); string substring = string.Format("{0}={1}", key, value); - StringAssert.Contains(substring, response.Headers[HttpResponseHeader.Location]); + StringAssert.Contains(substring, response.Headers.Location.AbsoluteUri); } } [Test, ExpectedException(typeof(ArgumentNullException))] public void SendIndirectMessage301GetNullMessage() { - TestBadChannel badChannel = new TestBadChannel(false); + TestBadChannel badChannel = new TestBadChannel(); badChannel.Create301RedirectResponse(null, new Dictionary<string, string>()); } [Test, ExpectedException(typeof(ArgumentException))] public void SendIndirectMessage301GetEmptyRecipient() { - TestBadChannel badChannel = new TestBadChannel(false); + TestBadChannel badChannel = new TestBadChannel(); var message = new TestDirectedMessage(MessageTransport.Indirect); badChannel.Create301RedirectResponse(message, new Dictionary<string, string>()); } [Test, ExpectedException(typeof(ArgumentNullException))] public void SendIndirectMessage301GetNullFields() { - TestBadChannel badChannel = new TestBadChannel(false); + TestBadChannel badChannel = new TestBadChannel(); var message = new TestDirectedMessage(MessageTransport.Indirect); message.Recipient = new Uri("http://someserver"); badChannel.Create301RedirectResponse(message, null); } [Test] - public void SendIndirectMessageFormPost() { + public async Task SendIndirectMessageFormPost() { // We craft a very large message to force fallback to form POST. // We'll also stick some HTML reserved characters in the string value // to test proper character escaping. @@ -120,10 +132,10 @@ namespace DotNetOpenAuth.Test.Messaging { Location = new Uri("http://host/path"), Recipient = new Uri("http://provider/path"), }; - OutgoingWebResponse response = this.Channel.PrepareResponse(message); - Assert.AreEqual(HttpStatusCode.OK, response.Status, "A form redirect should be an HTTP successful response."); - Assert.IsNull(response.Headers[HttpResponseHeader.Location], "There should not be a redirection header in the response."); - string body = response.Body; + var response = await this.Channel.PrepareResponseAsync(message); + Assert.AreEqual(HttpStatusCode.OK, response.StatusCode, "A form redirect should be an HTTP successful response."); + Assert.IsNull(response.Headers.Location, "There should not be a redirection header in the response."); + string body = await response.Content.ReadAsStringAsync(); StringAssert.Contains("<form ", body); StringAssert.Contains("action=\"http://provider/path\"", body); StringAssert.Contains("method=\"post\"", body); @@ -135,20 +147,20 @@ namespace DotNetOpenAuth.Test.Messaging { [Test, ExpectedException(typeof(ArgumentNullException))] public void SendIndirectMessageFormPostNullMessage() { - TestBadChannel badChannel = new TestBadChannel(false); + TestBadChannel badChannel = new TestBadChannel(); badChannel.CreateFormPostResponse(null, new Dictionary<string, string>()); } [Test, ExpectedException(typeof(ArgumentException))] public void SendIndirectMessageFormPostEmptyRecipient() { - TestBadChannel badChannel = new TestBadChannel(false); + TestBadChannel badChannel = new TestBadChannel(); var message = new TestDirectedMessage(MessageTransport.Indirect); badChannel.CreateFormPostResponse(message, new Dictionary<string, string>()); } [Test, ExpectedException(typeof(ArgumentNullException))] public void SendIndirectMessageFormPostNullFields() { - TestBadChannel badChannel = new TestBadChannel(false); + TestBadChannel badChannel = new TestBadChannel(); var message = new TestDirectedMessage(MessageTransport.Indirect); message.Recipient = new Uri("http://someserver"); badChannel.CreateFormPostResponse(message, null); @@ -162,101 +174,103 @@ namespace DotNetOpenAuth.Test.Messaging { /// we just check that the right method was called. /// </remarks> [Test, ExpectedException(typeof(NotImplementedException))] - public void SendDirectMessageResponse() { + public async Task SendDirectMessageResponse() { IProtocolMessage message = new TestDirectedMessage { Age = 15, Name = "Andrew", Location = new Uri("http://host/path"), }; - this.Channel.PrepareResponse(message); + await this.Channel.PrepareResponseAsync(message); } [Test, ExpectedException(typeof(ArgumentNullException))] public void SendIndirectMessageNull() { - TestBadChannel badChannel = new TestBadChannel(false); + TestBadChannel badChannel = new TestBadChannel(); badChannel.PrepareIndirectResponse(null); } [Test, ExpectedException(typeof(ArgumentNullException))] public void ReceiveNull() { - TestBadChannel badChannel = new TestBadChannel(false); + TestBadChannel badChannel = new TestBadChannel(); badChannel.Receive(null, null); } [Test] public void ReceiveUnrecognizedMessage() { - TestBadChannel badChannel = new TestBadChannel(false); + TestBadChannel badChannel = new TestBadChannel(); Assert.IsNull(badChannel.Receive(new Dictionary<string, string>(), null)); } [Test] - public void ReadFromRequestWithContext() { + public async Task ReadFromRequestWithContext() { var fields = GetStandardTestFields(FieldFill.AllRequired); TestMessage expectedMessage = GetStandardTestMessage(FieldFill.AllRequired); HttpRequest request = new HttpRequest("somefile", "http://someurl", MessagingUtilities.CreateQueryString(fields)); HttpContext.Current = new HttpContext(request, new HttpResponse(new StringWriter())); - IProtocolMessage message = this.Channel.ReadFromRequest(); + var requestBase = this.Channel.GetRequestFromContext(); + IProtocolMessage message = await this.Channel.ReadFromRequestAsync(requestBase.AsHttpRequestMessage(), CancellationToken.None); Assert.IsNotNull(message); Assert.IsInstanceOf<TestMessage>(message); Assert.AreEqual(expectedMessage.Age, ((TestMessage)message).Age); } [Test, ExpectedException(typeof(InvalidOperationException))] - public void ReadFromRequestNoContext() { + public void GetRequestFromContextNoContext() { HttpContext.Current = null; - TestBadChannel badChannel = new TestBadChannel(false); - badChannel.ReadFromRequest(); + TestBadChannel badChannel = new TestBadChannel(); + badChannel.GetRequestFromContext(); } [Test, ExpectedException(typeof(ArgumentNullException))] - public void ReadFromRequestNull() { - TestBadChannel badChannel = new TestBadChannel(false); - badChannel.ReadFromRequest(null); + public async Task ReadFromRequestNull() { + TestBadChannel badChannel = new TestBadChannel(); + await badChannel.ReadFromRequestAsync(null, CancellationToken.None); } [Test] - public void SendReplayProtectedMessageSetsNonce() { + public async Task SendReplayProtectedMessageSetsNonce() { TestReplayProtectedMessage message = new TestReplayProtectedMessage(MessageTransport.Indirect); message.Recipient = new Uri("http://localtest"); this.Channel = CreateChannel(MessageProtections.ReplayProtection); - this.Channel.PrepareResponse(message); + await this.Channel.PrepareResponseAsync(message); Assert.IsNotNull(((IReplayProtectedProtocolMessage)message).Nonce); } [Test, ExpectedException(typeof(InvalidSignatureException))] - public void ReceivedInvalidSignature() { + public async Task ReceivedInvalidSignature() { this.Channel = CreateChannel(MessageProtections.TamperProtection); - this.ParameterizedReceiveProtectedTest(DateTime.UtcNow, true); + await this.ParameterizedReceiveProtectedTestAsync(DateTime.UtcNow, true); } [Test] - public void ReceivedReplayProtectedMessageJustOnce() { + public async Task ReceivedReplayProtectedMessageJustOnce() { this.Channel = CreateChannel(MessageProtections.ReplayProtection); - this.ParameterizedReceiveProtectedTest(DateTime.UtcNow, false); + await this.ParameterizedReceiveProtectedTestAsync(DateTime.UtcNow, false); } [Test, ExpectedException(typeof(ReplayedMessageException))] - public void ReceivedReplayProtectedMessageTwice() { + public async Task ReceivedReplayProtectedMessageTwice() { this.Channel = CreateChannel(MessageProtections.ReplayProtection); - this.ParameterizedReceiveProtectedTest(DateTime.UtcNow, false); - this.ParameterizedReceiveProtectedTest(DateTime.UtcNow, false); + await this.ParameterizedReceiveProtectedTestAsync(DateTime.UtcNow, false); + await this.ParameterizedReceiveProtectedTestAsync(DateTime.UtcNow, false); } [Test, ExpectedException(typeof(ProtocolException))] public void MessageExpirationWithoutTamperResistance() { new TestChannel( new TestMessageFactory(), - new StandardExpirationBindingElement()); + new IChannelBindingElement[] { new StandardExpirationBindingElement() }, + new DefaultOpenIdHostFactories()); } [Test, ExpectedException(typeof(ProtocolException))] - public void TooManyBindingElementsProvidingSameProtection() { + public async Task TooManyBindingElementsProvidingSameProtection() { Channel channel = new TestChannel( new TestMessageFactory(), - new MockSigningBindingElement(), - new MockSigningBindingElement()); - channel.ProcessOutgoingMessageTestHook(new TestSignedDirectedMessage()); + new IChannelBindingElement[] { new MockSigningBindingElement(), new MockSigningBindingElement() }, + new DefaultOpenIdHostFactories()); + await channel.ProcessOutgoingMessageTestHookAsync(new TestSignedDirectedMessage()); } [Test] @@ -269,11 +283,8 @@ namespace DotNetOpenAuth.Test.Messaging { Channel channel = new TestChannel( new TestMessageFactory(), - sign, - replay, - expire, - transformB, - transformA); + new[] { sign, replay, expire, transformB, transformA }, + new DefaultOpenIdHostFactories()); Assert.AreEqual(5, channel.BindingElements.Count); Assert.AreSame(transformB, channel.BindingElements[0]); @@ -284,22 +295,22 @@ namespace DotNetOpenAuth.Test.Messaging { } [Test, ExpectedException(typeof(UnprotectedMessageException))] - public void InsufficientlyProtectedMessageSent() { + public async Task InsufficientlyProtectedMessageSent() { var message = new TestSignedDirectedMessage(MessageTransport.Direct); message.Recipient = new Uri("http://localtest"); - this.Channel.PrepareResponse(message); + await this.Channel.PrepareResponseAsync(message); } [Test, ExpectedException(typeof(UnprotectedMessageException))] - public void InsufficientlyProtectedMessageReceived() { + public async Task InsufficientlyProtectedMessageReceived() { this.Channel = CreateChannel(MessageProtections.None, MessageProtections.TamperProtection); - this.ParameterizedReceiveProtectedTest(DateTime.Now, false); + await this.ParameterizedReceiveProtectedTestAsync(DateTime.Now, false); } [Test, ExpectedException(typeof(ProtocolException))] - public void IncomingMessageMissingRequiredParameters() { + public async Task IncomingMessageMissingRequiredParameters() { var fields = GetStandardTestFields(FieldFill.IdentifiableButNotAllRequired); - this.Channel.ReadFromRequest(CreateHttpRequestInfo("GET", fields)); + await this.Channel.ReadFromRequestAsync(CreateHttpRequestInfo(HttpMethod.Get, fields), CancellationToken.None); } } } diff --git a/src/DotNetOpenAuth.Test/Messaging/MessagingTestBase.cs b/src/DotNetOpenAuth.Test/Messaging/MessagingTestBase.cs index b7c0980..7903e89 100644 --- a/src/DotNetOpenAuth.Test/Messaging/MessagingTestBase.cs +++ b/src/DotNetOpenAuth.Test/Messaging/MessagingTestBase.cs @@ -10,6 +10,9 @@ namespace DotNetOpenAuth.Test { using System.Collections.Specialized; using System.IO; using System.Net; + using System.Net.Http; + using System.Threading; + using System.Threading.Tasks; using System.Xml; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Bindings; @@ -52,30 +55,29 @@ namespace DotNetOpenAuth.Test { public override void SetUp() { base.SetUp(); - this.Channel = new TestChannel(); + this.Channel = new TestChannel(this.HostFactories); } - internal static HttpRequestInfo CreateHttpRequestInfo(string method, IDictionary<string, string> fields) { + internal static HttpRequestMessage CreateHttpRequestInfo(HttpMethod method, IDictionary<string, string> fields) { + var result = new HttpRequestMessage() { Method = method }; var requestUri = new UriBuilder(DefaultUrlForHttpRequestInfo); - var headers = new NameValueCollection(); - NameValueCollection form = null; - if (method == "POST") { - form = fields.ToNameValueCollection(); - headers.Add(HttpRequestHeaders.ContentType, Channel.HttpFormUrlEncoded); - } else if (method == "GET") { - requestUri.Query = MessagingUtilities.CreateQueryString(fields); + if (method == HttpMethod.Post) { + result.Content = new FormUrlEncodedContent(fields); + } else if (method == HttpMethod.Get) { + requestUri.AppendQueryArgs(fields); } else { throw new ArgumentOutOfRangeException("method", method, "Expected POST or GET"); } - return new HttpRequestInfo(method, requestUri.Uri, form: form, headers: headers); + result.RequestUri = requestUri.Uri; + return result; } - internal static Channel CreateChannel(MessageProtections capabilityAndRecognition) { - return CreateChannel(capabilityAndRecognition, capabilityAndRecognition); + internal Channel CreateChannel(MessageProtections capabilityAndRecognition) { + return this.CreateChannel(capabilityAndRecognition, capabilityAndRecognition); } - internal static Channel CreateChannel(MessageProtections capability, MessageProtections recognition) { + internal Channel CreateChannel(MessageProtections capability, MessageProtections recognition) { var bindingElements = new List<IChannelBindingElement>(); if (capability >= MessageProtections.TamperProtection) { bindingElements.Add(new MockSigningBindingElement()); @@ -99,7 +101,7 @@ namespace DotNetOpenAuth.Test { } var typeProvider = new TestMessageFactory(signing, expiration, replay); - return new TestChannel(typeProvider, bindingElements.ToArray()); + return new TestChannel(typeProvider, bindingElements.ToArray(), this.HostFactories); } internal static IDictionary<string, string> GetStandardTestFields(FieldFill fill) { @@ -143,11 +145,11 @@ namespace DotNetOpenAuth.Test { } } - internal void ParameterizedReceiveTest(string method) { + internal async Task ParameterizedReceiveTestAsync(HttpMethod method) { var fields = GetStandardTestFields(FieldFill.CompleteBeforeBindings); TestMessage expectedMessage = GetStandardTestMessage(FieldFill.CompleteBeforeBindings); - IDirectedProtocolMessage requestMessage = this.Channel.ReadFromRequest(CreateHttpRequestInfo(method, fields)); + IDirectedProtocolMessage requestMessage = await this.Channel.ReadFromRequestAsync(CreateHttpRequestInfo(method, fields), CancellationToken.None); Assert.IsNotNull(requestMessage); Assert.IsInstanceOf<TestMessage>(requestMessage); TestMessage actualMessage = (TestMessage)requestMessage; @@ -156,7 +158,7 @@ namespace DotNetOpenAuth.Test { Assert.AreEqual(expectedMessage.Location, actualMessage.Location); } - internal void ParameterizedReceiveProtectedTest(DateTime? utcCreatedDate, bool invalidSignature) { + internal async Task ParameterizedReceiveProtectedTestAsync(DateTime? utcCreatedDate, bool invalidSignature) { TestMessage expectedMessage = GetStandardTestMessage(FieldFill.CompleteBeforeBindings); var fields = GetStandardTestFields(FieldFill.CompleteBeforeBindings); fields.Add("Signature", invalidSignature ? "badsig" : MockSigningBindingElement.MessageSignature); @@ -165,7 +167,7 @@ namespace DotNetOpenAuth.Test { utcCreatedDate = DateTime.Parse(utcCreatedDate.Value.ToUniversalTime().ToString()); // round off the milliseconds so comparisons work later fields.Add("created_on", XmlConvert.ToString(utcCreatedDate.Value, XmlDateTimeSerializationMode.Utc)); } - IProtocolMessage requestMessage = this.Channel.ReadFromRequest(CreateHttpRequestInfo("GET", fields)); + IProtocolMessage requestMessage = await this.Channel.ReadFromRequestAsync(CreateHttpRequestInfo(HttpMethod.Get, fields), CancellationToken.None); Assert.IsNotNull(requestMessage); Assert.IsInstanceOf<TestSignedDirectedMessage>(requestMessage); TestSignedDirectedMessage actualMessage = (TestSignedDirectedMessage)requestMessage; diff --git a/src/DotNetOpenAuth.Test/Messaging/MessagingUtilitiesTests.cs b/src/DotNetOpenAuth.Test/Messaging/MessagingUtilitiesTests.cs index cf0f9ca..0abf60f 100644 --- a/src/DotNetOpenAuth.Test/Messaging/MessagingUtilitiesTests.cs +++ b/src/DotNetOpenAuth.Test/Messaging/MessagingUtilitiesTests.cs @@ -67,41 +67,6 @@ namespace DotNetOpenAuth.Test.Messaging { } [Test] - public void AsHttpResponseMessage() { - var responseContent = new byte[10]; - (new Random()).NextBytes(responseContent); - var responseStream = new MemoryStream(responseContent); - var outgoingResponse = new OutgoingWebResponse(); - outgoingResponse.Headers.Add("X-SOME-HEADER", "value"); - outgoingResponse.Headers.Add("Content-Length", responseContent.Length.ToString(CultureInfo.InvariantCulture)); - outgoingResponse.ResponseStream = responseStream; - - var httpResponseMessage = outgoingResponse.AsHttpResponseMessage(); - Assert.That(httpResponseMessage, Is.Not.Null); - Assert.That(httpResponseMessage.Headers.GetValues("X-SOME-HEADER").ToList(), Is.EqualTo(new[] { "value" })); - Assert.That( - httpResponseMessage.Content.Headers.GetValues("Content-Length").ToList(), - Is.EqualTo(new[] { responseContent.Length.ToString(CultureInfo.InvariantCulture) })); - var actualContent = new byte[responseContent.Length + 1]; // give the opportunity to provide a bit more data than we expect. - var bytesRead = httpResponseMessage.Content.ReadAsStreamAsync().Result.Read(actualContent, 0, actualContent.Length); - Assert.That(bytesRead, Is.EqualTo(responseContent.Length)); // verify that only the data we expected came back. - var trimmedActualContent = new byte[bytesRead]; - Array.Copy(actualContent, trimmedActualContent, bytesRead); - Assert.That(trimmedActualContent, Is.EqualTo(responseContent)); - } - - [Test] - public void AsHttpResponseMessageNoContent() { - var outgoingResponse = new OutgoingWebResponse(); - outgoingResponse.Headers.Add("X-SOME-HEADER", "value"); - - var httpResponseMessage = outgoingResponse.AsHttpResponseMessage(); - Assert.That(httpResponseMessage, Is.Not.Null); - Assert.That(httpResponseMessage.Headers.GetValues("X-SOME-HEADER").ToList(), Is.EqualTo(new[] { "value" })); - Assert.That(httpResponseMessage.Content, Is.Null); - } - - [Test] public void ToDictionary() { NameValueCollection nvc = new NameValueCollection(); nvc["a"] = "b"; @@ -142,11 +107,6 @@ namespace DotNetOpenAuth.Test.Messaging { } [Test, ExpectedException(typeof(ArgumentNullException))] - public void ApplyHeadersToResponseNullListenerResponse() { - MessagingUtilities.ApplyHeadersToResponse(new WebHeaderCollection(), (HttpListenerResponse)null); - } - - [Test, ExpectedException(typeof(ArgumentNullException))] public void ApplyHeadersToResponseNullHeaders() { MessagingUtilities.ApplyHeadersToResponse(null, new HttpResponseWrapper(new HttpResponse(new StringWriter()))); } @@ -183,54 +143,25 @@ namespace DotNetOpenAuth.Test.Messaging { } /// <summary> - /// Verifies the overall format of the multipart POST is correct. - /// </summary> - [Test] - public void PostMultipart() { - var httpHandler = new TestWebRequestHandler(); - bool callbackTriggered = false; - httpHandler.Callback = req => { - var m = Regex.Match(req.ContentType, "multipart/form-data; boundary=(.+)"); - Assert.IsTrue(m.Success, "Content-Type HTTP header not set correctly."); - string boundary = m.Groups[1].Value; - boundary = boundary.Substring(0, boundary.IndexOf(';')); // trim off charset - string expectedEntity = "--{0}\r\nContent-Disposition: form-data; name=\"a\"\r\n\r\nb\r\n--{0}--\r\n"; - expectedEntity = string.Format(expectedEntity, boundary); - string actualEntity = httpHandler.RequestEntityAsString; - Assert.AreEqual(expectedEntity, actualEntity); - callbackTriggered = true; - Assert.AreEqual(req.ContentLength, actualEntity.Length); - IncomingWebResponse resp = new CachedDirectWebResponse(); - return resp; - }; - var request = (HttpWebRequest)WebRequest.Create("http://someserver"); - var parts = new[] { - MultipartPostPart.CreateFormPart("a", "b"), - }; - request.PostMultipart(httpHandler, parts); - Assert.IsTrue(callbackTriggered); - } - - /// <summary> /// Verifies proper behavior of GetHttpVerb /// </summary> [Test] public void GetHttpVerbTest() { - Assert.AreEqual("GET", MessagingUtilities.GetHttpVerb(HttpDeliveryMethods.GetRequest)); - Assert.AreEqual("POST", MessagingUtilities.GetHttpVerb(HttpDeliveryMethods.PostRequest)); - Assert.AreEqual("HEAD", MessagingUtilities.GetHttpVerb(HttpDeliveryMethods.HeadRequest)); - Assert.AreEqual("DELETE", MessagingUtilities.GetHttpVerb(HttpDeliveryMethods.DeleteRequest)); - Assert.AreEqual("PUT", MessagingUtilities.GetHttpVerb(HttpDeliveryMethods.PutRequest));
- Assert.AreEqual("PATCH", MessagingUtilities.GetHttpVerb(HttpDeliveryMethods.PatchRequest));
- Assert.AreEqual("OPTIONS", MessagingUtilities.GetHttpVerb(HttpDeliveryMethods.OptionsRequest)); - - Assert.AreEqual("GET", MessagingUtilities.GetHttpVerb(HttpDeliveryMethods.GetRequest | HttpDeliveryMethods.AuthorizationHeaderRequest)); - Assert.AreEqual("POST", MessagingUtilities.GetHttpVerb(HttpDeliveryMethods.PostRequest | HttpDeliveryMethods.AuthorizationHeaderRequest)); - Assert.AreEqual("HEAD", MessagingUtilities.GetHttpVerb(HttpDeliveryMethods.HeadRequest | HttpDeliveryMethods.AuthorizationHeaderRequest)); - Assert.AreEqual("DELETE", MessagingUtilities.GetHttpVerb(HttpDeliveryMethods.DeleteRequest | HttpDeliveryMethods.AuthorizationHeaderRequest)); - Assert.AreEqual("PUT", MessagingUtilities.GetHttpVerb(HttpDeliveryMethods.PutRequest | HttpDeliveryMethods.AuthorizationHeaderRequest));
- Assert.AreEqual("PATCH", MessagingUtilities.GetHttpVerb(HttpDeliveryMethods.PatchRequest | HttpDeliveryMethods.AuthorizationHeaderRequest));
- Assert.AreEqual("OPTIONS", MessagingUtilities.GetHttpVerb(HttpDeliveryMethods.OptionsRequest | HttpDeliveryMethods.AuthorizationHeaderRequest)); + Assert.AreEqual(HttpMethod.Get, MessagingUtilities.GetHttpVerb(HttpDeliveryMethods.GetRequest)); + Assert.AreEqual(HttpMethod.Post, MessagingUtilities.GetHttpVerb(HttpDeliveryMethods.PostRequest)); + Assert.AreEqual(HttpMethod.Head, MessagingUtilities.GetHttpVerb(HttpDeliveryMethods.HeadRequest)); + Assert.AreEqual(HttpMethod.Delete, MessagingUtilities.GetHttpVerb(HttpDeliveryMethods.DeleteRequest)); + Assert.AreEqual(HttpMethod.Put, MessagingUtilities.GetHttpVerb(HttpDeliveryMethods.PutRequest)); + Assert.AreEqual(new HttpMethod("PATCH"), MessagingUtilities.GetHttpVerb(HttpDeliveryMethods.PatchRequest)); + Assert.AreEqual(HttpMethod.Options, MessagingUtilities.GetHttpVerb(HttpDeliveryMethods.OptionsRequest)); + + Assert.AreEqual(HttpMethod.Get, MessagingUtilities.GetHttpVerb(HttpDeliveryMethods.GetRequest | HttpDeliveryMethods.AuthorizationHeaderRequest)); + Assert.AreEqual(HttpMethod.Post, MessagingUtilities.GetHttpVerb(HttpDeliveryMethods.PostRequest | HttpDeliveryMethods.AuthorizationHeaderRequest)); + Assert.AreEqual(HttpMethod.Head, MessagingUtilities.GetHttpVerb(HttpDeliveryMethods.HeadRequest | HttpDeliveryMethods.AuthorizationHeaderRequest)); + Assert.AreEqual(HttpMethod.Delete, MessagingUtilities.GetHttpVerb(HttpDeliveryMethods.DeleteRequest | HttpDeliveryMethods.AuthorizationHeaderRequest)); + Assert.AreEqual(HttpMethod.Put, MessagingUtilities.GetHttpVerb(HttpDeliveryMethods.PutRequest | HttpDeliveryMethods.AuthorizationHeaderRequest)); + Assert.AreEqual(new HttpMethod("PATCH"), MessagingUtilities.GetHttpVerb(HttpDeliveryMethods.PatchRequest | HttpDeliveryMethods.AuthorizationHeaderRequest)); + Assert.AreEqual(HttpMethod.Options, MessagingUtilities.GetHttpVerb(HttpDeliveryMethods.OptionsRequest | HttpDeliveryMethods.AuthorizationHeaderRequest)); } /// <summary> @@ -250,8 +181,8 @@ namespace DotNetOpenAuth.Test.Messaging { Assert.AreEqual(HttpDeliveryMethods.PostRequest, MessagingUtilities.GetHttpDeliveryMethod("POST")); Assert.AreEqual(HttpDeliveryMethods.HeadRequest, MessagingUtilities.GetHttpDeliveryMethod("HEAD")); Assert.AreEqual(HttpDeliveryMethods.PutRequest, MessagingUtilities.GetHttpDeliveryMethod("PUT")); - Assert.AreEqual(HttpDeliveryMethods.DeleteRequest, MessagingUtilities.GetHttpDeliveryMethod("DELETE"));
- Assert.AreEqual(HttpDeliveryMethods.PatchRequest, MessagingUtilities.GetHttpDeliveryMethod("PATCH"));
+ Assert.AreEqual(HttpDeliveryMethods.DeleteRequest, MessagingUtilities.GetHttpDeliveryMethod("DELETE")); + Assert.AreEqual(HttpDeliveryMethods.PatchRequest, MessagingUtilities.GetHttpDeliveryMethod("PATCH")); Assert.AreEqual(HttpDeliveryMethods.OptionsRequest, MessagingUtilities.GetHttpDeliveryMethod("OPTIONS")); } diff --git a/src/DotNetOpenAuth.Test/Messaging/MultipartPostPartTests.cs b/src/DotNetOpenAuth.Test/Messaging/MultipartPostPartTests.cs deleted file mode 100644 index e9ac5aa..0000000 --- a/src/DotNetOpenAuth.Test/Messaging/MultipartPostPartTests.cs +++ /dev/null @@ -1,108 +0,0 @@ -//----------------------------------------------------------------------- -// <copyright file="MultipartPostPartTests.cs" company="Outercurve Foundation"> -// Copyright (c) Outercurve Foundation. All rights reserved. -// </copyright> -//----------------------------------------------------------------------- - -namespace DotNetOpenAuth.Test.Messaging { - using System.CodeDom.Compiler; - using System.Collections.Generic; - using System.IO; - using System.Net; - using DotNetOpenAuth.Messaging; - using NUnit.Framework; - using Validation; - - [TestFixture] - public class MultipartPostPartTests : TestBase { - /// <summary> - /// Verifies that the Length property matches the length actually serialized. - /// </summary> - [Test] - public void FormDataSerializeMatchesLength() { - var part = MultipartPostPart.CreateFormPart("a", "b"); - VerifyLength(part); - } - - /// <summary> - /// Verifies that the length property matches the length actually serialized. - /// </summary> - [Test] - public void FileSerializeMatchesLength() { - using (TempFileCollection tfc = new TempFileCollection()) { - string file = tfc.AddExtension(".txt"); - File.WriteAllText(file, "sometext"); - var part = MultipartPostPart.CreateFormFilePart("someformname", file, "text/plain"); - VerifyLength(part); - } - } - - /// <summary> - /// Verifies file multiparts identify themselves as files and not merely form-data. - /// </summary> - [Test] - public void FilePartAsFile() { - var part = MultipartPostPart.CreateFormFilePart("somename", "somefile", "plain/text", new MemoryStream()); - Assert.AreEqual("file", part.ContentDisposition); - } - - /// <summary> - /// Verifies MultiPartPost sends the right number of bytes. - /// </summary> - [Test] - public void MultiPartPostAscii() { - using (TempFileCollection tfc = new TempFileCollection()) { - string file = tfc.AddExtension("txt"); - File.WriteAllText(file, "sometext"); - this.VerifyFullPost(new List<MultipartPostPart> { - MultipartPostPart.CreateFormPart("a", "b"), - MultipartPostPart.CreateFormFilePart("SomeFormField", file, "text/plain"), - }); - } - } - - /// <summary> - /// Verifies MultiPartPost sends the right number of bytes. - /// </summary> - [Test] - public void MultiPartPostMultiByteCharacters() { - using (TempFileCollection tfc = new TempFileCollection()) { - string file = tfc.AddExtension("txt"); - File.WriteAllText(file, "\x1020\x818"); - this.VerifyFullPost(new List<MultipartPostPart> { - MultipartPostPart.CreateFormPart("a", "\x987"), - MultipartPostPart.CreateFormFilePart("SomeFormField", file, "text/plain"), - }); - } - } - - private static void VerifyLength(MultipartPostPart part) { - Requires.NotNull(part, "part"); - - var expectedLength = part.Length; - var ms = new MemoryStream(); - var sw = new StreamWriter(ms); - part.Serialize(sw); - sw.Flush(); - var actualLength = ms.Length; - Assert.AreEqual(expectedLength, actualLength); - } - - private void VerifyFullPost(List<MultipartPostPart> parts) { - var request = (HttpWebRequest)WebRequest.Create("http://localhost"); - var handler = new Mocks.TestWebRequestHandler(); - bool posted = false; - handler.Callback = req => { - foreach (string header in req.Headers) { - TestUtilities.TestLogger.InfoFormat("{0}: {1}", header, req.Headers[header]); - } - TestUtilities.TestLogger.InfoFormat(handler.RequestEntityAsString); - Assert.AreEqual(req.ContentLength, handler.RequestEntityStream.Length); - posted = true; - return null; - }; - request.PostMultipart(handler, parts); - Assert.IsTrue(posted, "HTTP POST never sent."); - } - } -} diff --git a/src/DotNetOpenAuth.Test/Messaging/OutgoingWebResponseTests.cs b/src/DotNetOpenAuth.Test/Messaging/OutgoingWebResponseTests.cs deleted file mode 100644 index 3efc471..0000000 --- a/src/DotNetOpenAuth.Test/Messaging/OutgoingWebResponseTests.cs +++ /dev/null @@ -1,38 +0,0 @@ -//----------------------------------------------------------------------- -// <copyright file="OutgoingWebResponseTests.cs" company="Outercurve Foundation"> -// Copyright (c) Outercurve Foundation. All rights reserved. -// </copyright> -//----------------------------------------------------------------------- - -namespace DotNetOpenAuth.Test.Messaging { - using System.Net; - using System.Net.Mime; - using System.Text; - using DotNetOpenAuth.Messaging; - using NUnit.Framework; - - [TestFixture] - public class OutgoingWebResponseTests { - /// <summary> - /// Verifies that setting the Body property correctly converts to a byte stream. - /// </summary> - [Test] - public void SetBodyToByteStream() { - var response = new OutgoingWebResponse(); - string stringValue = "abc"; - response.Body = stringValue; - Assert.AreEqual(stringValue.Length, response.ResponseStream.Length); - - // Verify that the actual bytes are correct. - Encoding encoding = new UTF8Encoding(false); // avoid emitting a byte-order mark - var expectedBuffer = encoding.GetBytes(stringValue); - var actualBuffer = new byte[stringValue.Length]; - Assert.AreEqual(stringValue.Length, response.ResponseStream.Read(actualBuffer, 0, stringValue.Length)); - CollectionAssert.AreEqual(expectedBuffer, actualBuffer); - - // Verify that the header was set correctly. - Assert.IsNull(response.Headers[HttpResponseHeader.ContentEncoding]); - Assert.AreEqual(encoding.HeaderName, new ContentType(response.Headers[HttpResponseHeader.ContentType]).CharSet); - } - } -} diff --git a/src/DotNetOpenAuth.Test/Messaging/ResponseTests.cs b/src/DotNetOpenAuth.Test/Messaging/ResponseTests.cs deleted file mode 100644 index 52f031e..0000000 --- a/src/DotNetOpenAuth.Test/Messaging/ResponseTests.cs +++ /dev/null @@ -1,40 +0,0 @@ -//----------------------------------------------------------------------- -// <copyright file="ResponseTests.cs" company="Outercurve Foundation"> -// Copyright (c) Outercurve Foundation. All rights reserved. -// </copyright> -//----------------------------------------------------------------------- - -namespace DotNetOpenAuth.Test.Messaging { - using System; - using System.IO; - using System.Web; - using DotNetOpenAuth.Messaging; - using NUnit.Framework; - - [TestFixture] - public class ResponseTests : TestBase { - [Test, ExpectedException(typeof(InvalidOperationException))] - public void RespondWithoutAspNetContext() { - HttpContext.Current = null; - new OutgoingWebResponse().Respond(); - } - - [Test] - public void Respond() { - StringWriter writer = new StringWriter(); - HttpRequest httpRequest = new HttpRequest("file", "http://server", string.Empty); - HttpResponse httpResponse = new HttpResponse(writer); - HttpContext context = new HttpContext(httpRequest, httpResponse); - HttpContext.Current = context; - - OutgoingWebResponse response = new OutgoingWebResponse(); - response.Status = System.Net.HttpStatusCode.OK; - response.Headers["someHeaderName"] = "someHeaderValue"; - response.Body = "some body"; - response.Respond(); - string results = writer.ToString(); - // For some reason the only output in test is the body... the headers require a web host - Assert.AreEqual(response.Body, results); - } - } -} diff --git a/src/DotNetOpenAuth.Test/MockingHostFactories.cs b/src/DotNetOpenAuth.Test/MockingHostFactories.cs new file mode 100644 index 0000000..b8cbeb0 --- /dev/null +++ b/src/DotNetOpenAuth.Test/MockingHostFactories.cs @@ -0,0 +1,83 @@ +//----------------------------------------------------------------------- +// <copyright file="MockingHostFactories.cs" company="Andrew Arnott"> +// Copyright (c) Andrew Arnott. All rights reserved. +// </copyright> +//----------------------------------------------------------------------- + +namespace DotNetOpenAuth.Test { + using System.Collections.Generic; + using System.Net; + using System.Net.Http; + using System.Threading; + using System.Threading.Tasks; + using System.Linq; + + using DotNetOpenAuth.OpenId; + + using Validation; + using System; + + internal class MockingHostFactories : IHostFactories { + public MockingHostFactories(Dictionary<Uri, Func<HttpRequestMessage, Task<HttpResponseMessage>>> handlers = null) { + this.Handlers = handlers ?? new Dictionary<Uri, Func<HttpRequestMessage, Task<HttpResponseMessage>>>(); + this.CookieContainer = new CookieContainer(); + this.AllowAutoRedirects = true; + } + + public Dictionary<Uri, Func<HttpRequestMessage, Task<HttpResponseMessage>>> Handlers { get; private set; } + + public CookieContainer CookieContainer { get; set; } + + public bool AllowAutoRedirects { get; set; } + + public bool InstallUntrustedWebReqestHandler { get; set; } + + public HttpMessageHandler CreateHttpMessageHandler() { + var forwardingMessageHandler = new ForwardingMessageHandler(this.Handlers, this); + var cookieDelegatingHandler = new CookieDelegatingHandler(forwardingMessageHandler, this.CookieContainer); + if (this.InstallUntrustedWebReqestHandler) { + var untrustedHandler = new UntrustedWebRequestHandler(cookieDelegatingHandler); + untrustedHandler.AllowAutoRedirect = this.AllowAutoRedirects; + return untrustedHandler; + } else if (this.AllowAutoRedirects) { + return new AutoRedirectHandler(cookieDelegatingHandler); + } else { + return cookieDelegatingHandler; + } + } + + public HttpClient CreateHttpClient(HttpMessageHandler handler = null) { + return new HttpClient(handler ?? this.CreateHttpMessageHandler()); + } + + private class ForwardingMessageHandler : HttpMessageHandler { + private readonly Dictionary<Uri, Func<HttpRequestMessage, Task<HttpResponseMessage>>> handlers; + + private readonly IHostFactories hostFactories; + + public ForwardingMessageHandler(Dictionary<Uri, Func<HttpRequestMessage, Task<HttpResponseMessage>>> handlers, IHostFactories hostFactories) { + Requires.NotNull(handlers, "handlers"); + + this.handlers = handlers; + this.hostFactories = hostFactories; + } + + protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) { + foreach (var pair in this.handlers) { + if (pair.Key.IsBaseOf(request.RequestUri) && pair.Key.AbsolutePath == request.RequestUri.AbsolutePath) { + var response = await pair.Value(request); + if (response != null) { + if (response.RequestMessage == null) { + response.RequestMessage = request; + } + + return response; + } + } + } + + return new HttpResponseMessage(HttpStatusCode.NotFound); + } + } + } +}
\ No newline at end of file diff --git a/src/DotNetOpenAuth.Test/Mocks/CoordinatingChannel.cs b/src/DotNetOpenAuth.Test/Mocks/CoordinatingChannel.cs deleted file mode 100644 index 475f4b5..0000000 --- a/src/DotNetOpenAuth.Test/Mocks/CoordinatingChannel.cs +++ /dev/null @@ -1,348 +0,0 @@ -//----------------------------------------------------------------------- -// <copyright file="CoordinatingChannel.cs" company="Outercurve Foundation"> -// Copyright (c) Outercurve Foundation. All rights reserved. -// </copyright> -//----------------------------------------------------------------------- - -namespace DotNetOpenAuth.Test.Mocks { - using System; - using System.Collections.Generic; - using System.Linq; - using System.Net; - using System.Text; - using System.Threading; - using System.Web; - using DotNetOpenAuth.Messaging; - using DotNetOpenAuth.Messaging.Reflection; - using DotNetOpenAuth.Test.OpenId; - using NUnit.Framework; - using Validation; - - internal class CoordinatingChannel : Channel { - /// <summary> - /// A lock to use when checking and setting the <see cref="waitingForMessage"/> - /// or the <see cref="simulationCompleted"/> fields. - /// </summary> - /// <remarks> - /// This is a static member so that all coordinating channels share a lock - /// since they peak at each others fields. - /// </remarks> - private static readonly object waitingForMessageCoordinationLock = new object(); - - /// <summary> - /// The original product channel whose behavior is being modified to work - /// better in automated testing. - /// </summary> - private Channel wrappedChannel; - - /// <summary> - /// A flag set to true when this party in a two-party test has completed - /// its part of the testing. - /// </summary> - private bool simulationCompleted; - - /// <summary> - /// A thread-coordinating signal that is set when another thread has a - /// message ready for this channel to receive. - /// </summary> - private EventWaitHandle incomingMessageSignal = new AutoResetEvent(false); - - /// <summary> - /// A thread-coordinating signal that is set briefly by this thread whenever - /// a message is picked up. - /// </summary> - private EventWaitHandle messageReceivedSignal = new AutoResetEvent(false); - - /// <summary> - /// A flag used to indicate when this channel is waiting for a message - /// to arrive. - /// </summary> - private bool waitingForMessage; - - /// <summary> - /// An incoming message that has been posted by a remote channel and - /// is waiting for receipt by this channel. - /// </summary> - private IDictionary<string, string> incomingMessage; - - /// <summary> - /// The recipient URL of the <see cref="incomingMessage"/>, where applicable. - /// </summary> - private MessageReceivingEndpoint incomingMessageRecipient; - - /// <summary> - /// The headers of the <see cref="incomingMessage"/>, where applicable. - /// </summary> - private WebHeaderCollection incomingMessageHttpHeaders; - - /// <summary> - /// A delegate that gets a chance to peak at and fiddle with all - /// incoming messages. - /// </summary> - private Action<IProtocolMessage> incomingMessageFilter; - - /// <summary> - /// A delegate that gets a chance to peak at and fiddle with all - /// outgoing messages. - /// </summary> - private Action<IProtocolMessage> outgoingMessageFilter; - - /// <summary> - /// The simulated clients cookies. - /// </summary> - private HttpCookieCollection cookies = new HttpCookieCollection(); - - /// <summary> - /// Initializes a new instance of the <see cref="CoordinatingChannel"/> class. - /// </summary> - /// <param name="wrappedChannel">The wrapped channel. Must not be null.</param> - /// <param name="incomingMessageFilter">The incoming message filter. May be null.</param> - /// <param name="outgoingMessageFilter">The outgoing message filter. May be null.</param> - internal CoordinatingChannel(Channel wrappedChannel, Action<IProtocolMessage> incomingMessageFilter, Action<IProtocolMessage> outgoingMessageFilter) - : base(GetMessageFactory(wrappedChannel), wrappedChannel.BindingElements.ToArray()) { - Requires.NotNull(wrappedChannel, "wrappedChannel"); - - this.wrappedChannel = wrappedChannel; - this.incomingMessageFilter = incomingMessageFilter; - this.outgoingMessageFilter = outgoingMessageFilter; - - // Preserve any customized binding element ordering. - this.CustomizeBindingElementOrder(this.wrappedChannel.OutgoingBindingElements, this.wrappedChannel.IncomingBindingElements); - } - - /// <summary> - /// Gets or sets the coordinating channel used by the other party. - /// </summary> - internal CoordinatingChannel RemoteChannel { get; set; } - - /// <summary> - /// Indicates that the simulation that uses this channel has completed work. - /// </summary> - /// <remarks> - /// Calling this method is not strictly necessary, but it gives the channel - /// coordination a chance to recognize when another channel is left dangling - /// waiting for a message from another channel that may never come. - /// </remarks> - internal void Close() { - lock (waitingForMessageCoordinationLock) { - this.simulationCompleted = true; - if (this.RemoteChannel.waitingForMessage && this.RemoteChannel.incomingMessage == null) { - TestUtilities.TestLogger.Debug("CoordinatingChannel is closing while remote channel is waiting for an incoming message. Signaling channel to unblock it to receive a null message."); - this.RemoteChannel.incomingMessageSignal.Set(); - } - - this.Dispose(); - } - } - - /// <summary> - /// Replays the specified message as if it were received again. - /// </summary> - /// <param name="message">The message to replay.</param> - internal void Replay(IProtocolMessage message) { - this.ProcessIncomingMessage(this.CloneSerializedParts(message)); - } - - /// <summary> - /// Called from a remote party's thread to post a message to this channel for processing. - /// </summary> - /// <param name="message">The message that this channel should receive. This message will be cloned.</param> - internal void PostMessage(IProtocolMessage message) { - if (this.incomingMessage != null) { - // The remote party hasn't picked up the last message we sent them. - // Wait for a short period for them to pick it up before failing. - TestBase.TestLogger.Warn("We're blocked waiting to send a message to the remote party and they haven't processed the last message we sent them."); - this.RemoteChannel.messageReceivedSignal.WaitOne(500); - } - ErrorUtilities.VerifyInternal(this.incomingMessage == null, "Oops, a message is already waiting for the remote party!"); - this.incomingMessage = this.MessageDescriptions.GetAccessor(message).Serialize(); - var directedMessage = message as IDirectedProtocolMessage; - this.incomingMessageRecipient = (directedMessage != null && directedMessage.Recipient != null) ? new MessageReceivingEndpoint(directedMessage.Recipient, directedMessage.HttpMethods) : null; - var httpMessage = message as IHttpDirectRequest; - this.incomingMessageHttpHeaders = (httpMessage != null) ? httpMessage.Headers.Clone() : null; - this.incomingMessageSignal.Set(); - } - - internal void SaveCookies(HttpCookieCollection cookies) { - Requires.NotNull(cookies, "cookies"); - foreach (string cookieName in cookies) { - var cookie = cookies[cookieName]; - this.cookies.Set(cookie); - } - } - - protected internal override HttpRequestBase GetRequestFromContext() { - MessageReceivingEndpoint recipient; - WebHeaderCollection headers; - var messageData = this.AwaitIncomingMessage(out recipient, out headers); - CoordinatingHttpRequestInfo result; - if (messageData != null) { - result = new CoordinatingHttpRequestInfo(this, this.MessageFactory, messageData, recipient, this.cookies); - } else { - result = new CoordinatingHttpRequestInfo(recipient, this.cookies); - } - - if (headers != null) { - headers.ApplyTo(result.Headers); - } - - return result; - } - - protected override IProtocolMessage RequestCore(IDirectedProtocolMessage request) { - this.ProcessMessageFilter(request, true); - - // Drop the outgoing message in the other channel's in-slot and let them know it's there. - this.RemoteChannel.PostMessage(request); - - // Now wait for a response... - MessageReceivingEndpoint recipient; - WebHeaderCollection headers; - IDictionary<string, string> responseData = this.AwaitIncomingMessage(out recipient, out headers); - ErrorUtilities.VerifyInternal(recipient == null, "The recipient is expected to be null for direct responses."); - - // And deserialize it. - IDirectResponseProtocolMessage responseMessage = this.MessageFactory.GetNewResponseMessage(request, responseData); - if (responseMessage == null) { - return null; - } - - var responseAccessor = this.MessageDescriptions.GetAccessor(responseMessage); - responseAccessor.Deserialize(responseData); - var responseMessageHttpRequest = responseMessage as IHttpDirectRequest; - if (headers != null && responseMessageHttpRequest != null) { - headers.ApplyTo(responseMessageHttpRequest.Headers); - } - - this.ProcessMessageFilter(responseMessage, false); - return responseMessage; - } - - protected override OutgoingWebResponse PrepareDirectResponse(IProtocolMessage response) { - this.ProcessMessageFilter(response, true); - return new CoordinatingOutgoingWebResponse(response, this.RemoteChannel, this); - } - - protected override OutgoingWebResponse PrepareIndirectResponse(IDirectedProtocolMessage message) { - this.ProcessMessageFilter(message, true); - // In this mock transport, direct and indirect messages are the same. - return this.PrepareDirectResponse(message); - } - - protected override IDirectedProtocolMessage ReadFromRequestCore(HttpRequestBase request) { - var mockRequest = (CoordinatingHttpRequestInfo)request; - if (mockRequest.Message != null) { - this.ProcessMessageFilter(mockRequest.Message, false); - } - - return mockRequest.Message; - } - - protected override IDictionary<string, string> ReadFromResponseCore(IncomingWebResponse response) { - return this.wrappedChannel.ReadFromResponseCoreTestHook(response); - } - - protected override void ProcessIncomingMessage(IProtocolMessage message) { - this.wrappedChannel.ProcessIncomingMessageTestHook(message); - } - - /// <summary> - /// Clones a message, instantiating the new instance using <i>this</i> channel's - /// message factory. - /// </summary> - /// <typeparam name="T">The type of message to clone.</typeparam> - /// <param name="message">The message to clone.</param> - /// <returns>The new instance of the message.</returns> - /// <remarks> - /// This Clone method should <i>not</i> be used to send message clones to the remote - /// channel since their message factory is not used. - /// </remarks> - protected virtual T CloneSerializedParts<T>(T message) where T : class, IProtocolMessage { - Requires.NotNull(message, "message"); - - IProtocolMessage clonedMessage; - var messageAccessor = this.MessageDescriptions.GetAccessor(message); - var fields = messageAccessor.Serialize(); - - MessageReceivingEndpoint recipient = null; - var directedMessage = message as IDirectedProtocolMessage; - var directResponse = message as IDirectResponseProtocolMessage; - if (directedMessage != null && directedMessage.IsRequest()) { - if (directedMessage.Recipient != null) { - recipient = new MessageReceivingEndpoint(directedMessage.Recipient, directedMessage.HttpMethods); - } - - clonedMessage = this.MessageFactory.GetNewRequestMessage(recipient, fields); - } else if (directResponse != null && directResponse.IsDirectResponse()) { - clonedMessage = this.MessageFactory.GetNewResponseMessage(directResponse.OriginatingRequest, fields); - } else { - throw new InvalidOperationException("Totally expected a message to implement one of the two derived interface types."); - } - - ErrorUtilities.VerifyInternal(clonedMessage != null, "Message factory did not generate a message instance for " + message.GetType().Name); - - // Fill the cloned message with data. - var clonedMessageAccessor = this.MessageDescriptions.GetAccessor(clonedMessage); - clonedMessageAccessor.Deserialize(fields); - - return (T)clonedMessage; - } - - private static IMessageFactory GetMessageFactory(Channel channel) { - Requires.NotNull(channel, "channel"); - - return channel.MessageFactoryTestHook; - } - - private IDictionary<string, string> AwaitIncomingMessage(out MessageReceivingEndpoint recipient, out WebHeaderCollection headers) { - // Special care should be taken so that we don't indefinitely - // wait for a message that may never come due to a bug in the product - // or the test. - // There are two scenarios that we need to watch out for: - // 1. Two channels are waiting to receive messages from each other. - // 2. One channel is waiting for a message that will never come because - // the remote party has already finished executing. - lock (waitingForMessageCoordinationLock) { - // It's possible that a message was just barely transmitted either to this - // or the remote channel. So it's ok for the remote channel to be waiting - // if either it or we are already about to receive a message. - ErrorUtilities.VerifyInternal(!this.RemoteChannel.waitingForMessage || this.RemoteChannel.incomingMessage != null || this.incomingMessage != null, "This channel is expecting an incoming message from another channel that is also blocked waiting for an incoming message from us!"); - - // It's permissible that the remote channel has already closed if it left a message - // for us already. - ErrorUtilities.VerifyInternal(!this.RemoteChannel.simulationCompleted || this.incomingMessage != null, "This channel is expecting an incoming message from another channel that has already been closed."); - this.waitingForMessage = true; - } - - this.incomingMessageSignal.WaitOne(); - - lock (waitingForMessageCoordinationLock) { - this.waitingForMessage = false; - var response = this.incomingMessage; - recipient = this.incomingMessageRecipient; - headers = this.incomingMessageHttpHeaders; - this.incomingMessage = null; - this.incomingMessageRecipient = null; - this.incomingMessageHttpHeaders = null; - - // Briefly signal to another thread that might be waiting for our inbox to be empty - this.messageReceivedSignal.Set(); - this.messageReceivedSignal.Reset(); - - return response; - } - } - - private void ProcessMessageFilter(IProtocolMessage message, bool outgoing) { - if (outgoing) { - if (this.outgoingMessageFilter != null) { - this.outgoingMessageFilter(message); - } - } else { - if (this.incomingMessageFilter != null) { - this.incomingMessageFilter(message); - } - } - } - } -} diff --git a/src/DotNetOpenAuth.Test/Mocks/CoordinatingHttpRequestInfo.cs b/src/DotNetOpenAuth.Test/Mocks/CoordinatingHttpRequestInfo.cs deleted file mode 100644 index 497503c..0000000 --- a/src/DotNetOpenAuth.Test/Mocks/CoordinatingHttpRequestInfo.cs +++ /dev/null @@ -1,109 +0,0 @@ -//----------------------------------------------------------------------- -// <copyright file="CoordinatingHttpRequestInfo.cs" company="Outercurve Foundation"> -// Copyright (c) Outercurve Foundation. All rights reserved. -// </copyright> -//----------------------------------------------------------------------- - -namespace DotNetOpenAuth.Test.Mocks { - using System; - using System.Collections.Generic; - using System.Net; - using System.Web; - - using DotNetOpenAuth.Messaging; - using Validation; - - internal class CoordinatingHttpRequestInfo : HttpRequestInfo { - private readonly Channel channel; - - private readonly IDictionary<string, string> messageData; - - private readonly IMessageFactory messageFactory; - - private readonly MessageReceivingEndpoint recipient; - - private IDirectedProtocolMessage message; - - /// <summary> - /// Initializes a new instance of the <see cref="CoordinatingHttpRequestInfo"/> class - /// that will generate a message when the <see cref="Message"/> property getter is called. - /// </summary> - /// <param name="channel">The channel.</param> - /// <param name="messageFactory">The message factory.</param> - /// <param name="messageData">The message data.</param> - /// <param name="recipient">The recipient.</param> - /// <param name="cookies">Cookies included in the incoming request.</param> - internal CoordinatingHttpRequestInfo( - Channel channel, - IMessageFactory messageFactory, - IDictionary<string, string> messageData, - MessageReceivingEndpoint recipient, - HttpCookieCollection cookies) - : this(recipient, cookies) { - Requires.NotNull(channel, "channel"); - Requires.NotNull(messageFactory, "messageFactory"); - Requires.NotNull(messageData, "messageData"); - this.channel = channel; - this.messageData = messageData; - this.messageFactory = messageFactory; - } - - /// <summary> - /// Initializes a new instance of the <see cref="CoordinatingHttpRequestInfo"/> class - /// that will not generate any message. - /// </summary> - /// <param name="recipient">The recipient.</param> - /// <param name="cookies">Cookies included in the incoming request.</param> - internal CoordinatingHttpRequestInfo(MessageReceivingEndpoint recipient, HttpCookieCollection cookies) - : base(GetHttpVerb(recipient), recipient != null ? recipient.Location : new Uri("http://host/path"), cookies: cookies) { - this.recipient = recipient; - } - - /// <summary> - /// Initializes a new instance of the <see cref="CoordinatingHttpRequestInfo"/> class. - /// </summary> - /// <param name="message">The message being passed in through a mock transport. May be null.</param> - /// <param name="httpMethod">The HTTP method that the incoming request came in on, whether or not <paramref name="message"/> is null.</param> - internal CoordinatingHttpRequestInfo(IDirectedProtocolMessage message, HttpDeliveryMethods httpMethod) - : base(GetHttpVerb(httpMethod), message.Recipient) { - this.message = message; - } - - /// <summary> - /// Gets the message deserialized from the remote channel. - /// </summary> - internal IDirectedProtocolMessage Message { - get { - if (this.message == null && this.messageData != null) { - var message = this.messageFactory.GetNewRequestMessage(this.recipient, this.messageData); - if (message != null) { - this.channel.MessageDescriptions.GetAccessor(message).Deserialize(this.messageData); - this.message = message; - } - } - - return this.message; - } - } - - private static string GetHttpVerb(MessageReceivingEndpoint recipient) { - if (recipient == null) { - return "GET"; - } - - return GetHttpVerb(recipient.AllowedMethods); - } - - private static string GetHttpVerb(HttpDeliveryMethods httpMethod) { - if ((httpMethod & HttpDeliveryMethods.GetRequest) != 0) { - return "GET"; - } - - if ((httpMethod & HttpDeliveryMethods.PostRequest) != 0) { - return "POST"; - } - - throw new ArgumentOutOfRangeException(); - } - } -} diff --git a/src/DotNetOpenAuth.Test/Mocks/CoordinatingOAuth2AuthServerChannel.cs b/src/DotNetOpenAuth.Test/Mocks/CoordinatingOAuth2AuthServerChannel.cs deleted file mode 100644 index 463b149..0000000 --- a/src/DotNetOpenAuth.Test/Mocks/CoordinatingOAuth2AuthServerChannel.cs +++ /dev/null @@ -1,33 +0,0 @@ -//----------------------------------------------------------------------- -// <copyright file="CoordinatingOAuth2AuthServerChannel.cs" company="Outercurve Foundation"> -// Copyright (c) Outercurve Foundation. All rights reserved. -// </copyright> -//----------------------------------------------------------------------- - -namespace DotNetOpenAuth.Test.Mocks { - using System; - using System.Collections.Generic; - using System.Linq; - using System.Text; - using DotNetOpenAuth.Messaging; - using DotNetOpenAuth.OAuth2; - using DotNetOpenAuth.OAuth2.ChannelElements; - - internal class CoordinatingOAuth2AuthServerChannel : CoordinatingChannel, IOAuth2ChannelWithAuthorizationServer { - private OAuth2AuthorizationServerChannel wrappedChannel; - - internal CoordinatingOAuth2AuthServerChannel(Channel wrappedChannel, Action<IProtocolMessage> incomingMessageFilter, Action<IProtocolMessage> outgoingMessageFilter) - : base(wrappedChannel, incomingMessageFilter, outgoingMessageFilter) { - this.wrappedChannel = (OAuth2AuthorizationServerChannel)wrappedChannel; - } - - public IAuthorizationServerHost AuthorizationServer { - get { return this.wrappedChannel.AuthorizationServer; } - } - - public IScopeSatisfiedCheck ScopeSatisfiedCheck { - get { return this.wrappedChannel.ScopeSatisfiedCheck; } - set { this.wrappedChannel.ScopeSatisfiedCheck = value; } - } - } -} diff --git a/src/DotNetOpenAuth.Test/Mocks/CoordinatingOAuth2ClientChannel.cs b/src/DotNetOpenAuth.Test/Mocks/CoordinatingOAuth2ClientChannel.cs deleted file mode 100644 index 96091ac..0000000 --- a/src/DotNetOpenAuth.Test/Mocks/CoordinatingOAuth2ClientChannel.cs +++ /dev/null @@ -1,37 +0,0 @@ -//----------------------------------------------------------------------- -// <copyright file="CoordinatingOAuth2ClientChannel.cs" company="Andrew Arnott"> -// Copyright (c) Andrew Arnott. All rights reserved. -// </copyright> -//----------------------------------------------------------------------- - -namespace DotNetOpenAuth.Test.Mocks { - using System; - using System.Collections.Generic; - using System.Linq; - using System.Text; - using DotNetOpenAuth.Messaging; - using DotNetOpenAuth.OAuth2.ChannelElements; - - internal class CoordinatingOAuth2ClientChannel : CoordinatingChannel, IOAuth2ChannelWithClient { - private OAuth2ClientChannel wrappedChannel; - - internal CoordinatingOAuth2ClientChannel(Channel wrappedChannel, Action<IProtocolMessage> incomingMessageFilter, Action<IProtocolMessage> outgoingMessageFilter) - : base(wrappedChannel, incomingMessageFilter, outgoingMessageFilter) { - this.wrappedChannel = (OAuth2ClientChannel)wrappedChannel; - } - - public string ClientIdentifier { - get { return this.wrappedChannel.ClientIdentifier; } - set { this.wrappedChannel.ClientIdentifier = value; } - } - - public DotNetOpenAuth.OAuth2.ClientCredentialApplicator ClientCredentialApplicator { - get { return this.wrappedChannel.ClientCredentialApplicator; } - set { this.wrappedChannel.ClientCredentialApplicator = value; } - } - - public System.Xml.XmlDictionaryReaderQuotas JsonReaderQuotas { - get { return this.XmlDictionaryReaderQuotas; } - } - } -}
\ No newline at end of file diff --git a/src/DotNetOpenAuth.Test/Mocks/CoordinatingOAuthConsumerChannel.cs b/src/DotNetOpenAuth.Test/Mocks/CoordinatingOAuthConsumerChannel.cs deleted file mode 100644 index 9b552d3..0000000 --- a/src/DotNetOpenAuth.Test/Mocks/CoordinatingOAuthConsumerChannel.cs +++ /dev/null @@ -1,155 +0,0 @@ -//----------------------------------------------------------------------- -// <copyright file="CoordinatingOAuthConsumerChannel.cs" company="Outercurve Foundation"> -// Copyright (c) Outercurve Foundation. All rights reserved. -// </copyright> -//----------------------------------------------------------------------- - -namespace DotNetOpenAuth.Test.Mocks { - using System; - using System.Threading; - using System.Web; - - using DotNetOpenAuth.Messaging; - using DotNetOpenAuth.Messaging.Bindings; - using DotNetOpenAuth.OAuth.ChannelElements; - using DotNetOpenAuth.OAuth.Messages; - using Validation; - - /// <summary> - /// A special channel used in test simulations to pass messages directly between two parties. - /// </summary> - internal class CoordinatingOAuthConsumerChannel : OAuthConsumerChannel { - private EventWaitHandle incomingMessageSignal = new AutoResetEvent(false); - - /// <summary> - /// Initializes a new instance of the <see cref="CoordinatingOAuthConsumerChannel"/> class. - /// </summary> - /// <param name="signingBindingElement">The signing element for the Consumer to use. Null for the Service Provider.</param> - /// <param name="tokenManager">The token manager to use.</param> - /// <param name="securitySettings">The security settings.</param> - internal CoordinatingOAuthConsumerChannel(ITamperProtectionChannelBindingElement signingBindingElement, IConsumerTokenManager tokenManager, DotNetOpenAuth.OAuth.ConsumerSecuritySettings securitySettings) - : base( - signingBindingElement, - new NonceMemoryStore(StandardExpirationBindingElement.MaximumMessageAge), - tokenManager, - securitySettings) { - } - - internal EventWaitHandle IncomingMessageSignal { - get { return this.incomingMessageSignal; } - } - - internal IProtocolMessage IncomingMessage { get; set; } - - internal OutgoingWebResponse IncomingRawResponse { get; set; } - - /// <summary> - /// Gets or sets the coordinating channel used by the other party. - /// </summary> - internal CoordinatingOAuthServiceProviderChannel RemoteChannel { get; set; } - - internal OutgoingWebResponse RequestProtectedResource(AccessProtectedResourceRequest request) { - ((ITamperResistantOAuthMessage)request).HttpMethod = this.GetHttpMethod(((ITamperResistantOAuthMessage)request).HttpMethods); - this.ProcessOutgoingMessage(request); - var requestInfo = this.SpoofHttpMethod(request); - TestBase.TestLogger.InfoFormat("Sending protected resource request: {0}", requestInfo.Message); - // Drop the outgoing message in the other channel's in-slot and let them know it's there. - this.RemoteChannel.IncomingMessage = requestInfo.Message; - this.RemoteChannel.IncomingMessageSignal.Set(); - return this.AwaitIncomingRawResponse(); - } - - protected internal override HttpRequestBase GetRequestFromContext() { - var directedMessage = (IDirectedProtocolMessage)this.AwaitIncomingMessage(); - return new CoordinatingHttpRequestInfo(directedMessage, directedMessage.HttpMethods); - } - - protected override IProtocolMessage RequestCore(IDirectedProtocolMessage request) { - var requestInfo = this.SpoofHttpMethod(request); - // Drop the outgoing message in the other channel's in-slot and let them know it's there. - this.RemoteChannel.IncomingMessage = requestInfo.Message; - this.RemoteChannel.IncomingMessageSignal.Set(); - // Now wait for a response... - return this.AwaitIncomingMessage(); - } - - protected override OutgoingWebResponse PrepareDirectResponse(IProtocolMessage response) { - this.RemoteChannel.IncomingMessage = this.CloneSerializedParts(response); - this.RemoteChannel.IncomingMessageSignal.Set(); - return new OutgoingWebResponse(); // not used, but returning null is not allowed - } - - protected override OutgoingWebResponse PrepareIndirectResponse(IDirectedProtocolMessage message) { - // In this mock transport, direct and indirect messages are the same. - return this.PrepareDirectResponse(message); - } - - protected override IDirectedProtocolMessage ReadFromRequestCore(HttpRequestBase request) { - var mockRequest = (CoordinatingHttpRequestInfo)request; - return mockRequest.Message; - } - - /// <summary> - /// Spoof HTTP request information for signing/verification purposes. - /// </summary> - /// <param name="message">The message to add a pretend HTTP method to.</param> - /// <returns>A spoofed HttpRequestInfo that wraps the new message.</returns> - private CoordinatingHttpRequestInfo SpoofHttpMethod(IDirectedProtocolMessage message) { - var signedMessage = message as ITamperResistantOAuthMessage; - if (signedMessage != null) { - string httpMethod = this.GetHttpMethod(signedMessage.HttpMethods); - signedMessage.HttpMethod = httpMethod; - } - - var requestInfo = new CoordinatingHttpRequestInfo(this.CloneSerializedParts(message), message.HttpMethods); - return requestInfo; - } - - private IProtocolMessage AwaitIncomingMessage() { - this.incomingMessageSignal.WaitOne(); - IProtocolMessage response = this.IncomingMessage; - this.IncomingMessage = null; - return response; - } - - private OutgoingWebResponse AwaitIncomingRawResponse() { - this.incomingMessageSignal.WaitOne(); - OutgoingWebResponse response = this.IncomingRawResponse; - this.IncomingRawResponse = null; - return response; - } - - private T CloneSerializedParts<T>(T message) where T : class, IProtocolMessage { - Requires.NotNull(message, "message"); - - IProtocolMessage clonedMessage; - var messageAccessor = this.MessageDescriptions.GetAccessor(message); - var fields = messageAccessor.Serialize(); - - MessageReceivingEndpoint recipient = null; - var directedMessage = message as IDirectedProtocolMessage; - var directResponse = message as IDirectResponseProtocolMessage; - if (directedMessage != null && directedMessage.IsRequest()) { - if (directedMessage.Recipient != null) { - recipient = new MessageReceivingEndpoint(directedMessage.Recipient, directedMessage.HttpMethods); - } - - clonedMessage = this.RemoteChannel.MessageFactoryTestHook.GetNewRequestMessage(recipient, fields); - } else if (directResponse != null && directResponse.IsDirectResponse()) { - clonedMessage = this.RemoteChannel.MessageFactoryTestHook.GetNewResponseMessage(directResponse.OriginatingRequest, fields); - } else { - throw new InvalidOperationException("Totally expected a message to implement one of the two derived interface types."); - } - - // Fill the cloned message with data. - var clonedMessageAccessor = this.MessageDescriptions.GetAccessor(clonedMessage); - clonedMessageAccessor.Deserialize(fields); - - return (T)clonedMessage; - } - - private string GetHttpMethod(HttpDeliveryMethods methods) { - return (methods & HttpDeliveryMethods.PostRequest) != 0 ? "POST" : "GET"; - } - } -} diff --git a/src/DotNetOpenAuth.Test/Mocks/CoordinatingOAuthServiceProviderChannel.cs b/src/DotNetOpenAuth.Test/Mocks/CoordinatingOAuthServiceProviderChannel.cs deleted file mode 100644 index a6f2a7f..0000000 --- a/src/DotNetOpenAuth.Test/Mocks/CoordinatingOAuthServiceProviderChannel.cs +++ /dev/null @@ -1,163 +0,0 @@ -//----------------------------------------------------------------------- -// <copyright file="CoordinatingOAuthServiceProviderChannel.cs" company="Outercurve Foundation"> -// Copyright (c) Outercurve Foundation. All rights reserved. -// </copyright> -//----------------------------------------------------------------------- - -namespace DotNetOpenAuth.Test.Mocks { - using System; - using System.Threading; - using System.Web; - - using DotNetOpenAuth.Messaging; - using DotNetOpenAuth.Messaging.Bindings; - using DotNetOpenAuth.OAuth.ChannelElements; - using DotNetOpenAuth.OAuth.Messages; - using NUnit.Framework; - using Validation; - - /// <summary> - /// A special channel used in test simulations to pass messages directly between two parties. - /// </summary> - internal class CoordinatingOAuthServiceProviderChannel : OAuthServiceProviderChannel { - private EventWaitHandle incomingMessageSignal = new AutoResetEvent(false); - - /// <summary> - /// Initializes a new instance of the <see cref="CoordinatingOAuthServiceProviderChannel"/> class. - /// </summary> - /// <param name="signingBindingElement">The signing element for the Consumer to use. Null for the Service Provider.</param> - /// <param name="tokenManager">The token manager to use.</param> - /// <param name="securitySettings">The security settings.</param> - internal CoordinatingOAuthServiceProviderChannel(ITamperProtectionChannelBindingElement signingBindingElement, IServiceProviderTokenManager tokenManager, DotNetOpenAuth.OAuth.ServiceProviderSecuritySettings securitySettings) - : base( - signingBindingElement, - new NonceMemoryStore(StandardExpirationBindingElement.MaximumMessageAge), - tokenManager, - securitySettings, - new OAuthServiceProviderMessageFactory(tokenManager)) { - } - - internal EventWaitHandle IncomingMessageSignal { - get { return this.incomingMessageSignal; } - } - - internal IProtocolMessage IncomingMessage { get; set; } - - internal OutgoingWebResponse IncomingRawResponse { get; set; } - - /// <summary> - /// Gets or sets the coordinating channel used by the other party. - /// </summary> - internal CoordinatingOAuthConsumerChannel RemoteChannel { get; set; } - - internal OutgoingWebResponse RequestProtectedResource(AccessProtectedResourceRequest request) { - ((ITamperResistantOAuthMessage)request).HttpMethod = GetHttpMethod(((ITamperResistantOAuthMessage)request).HttpMethods); - this.ProcessOutgoingMessage(request); - var requestInfo = this.SpoofHttpMethod(request); - TestBase.TestLogger.InfoFormat("Sending protected resource request: {0}", requestInfo.Message); - // Drop the outgoing message in the other channel's in-slot and let them know it's there. - this.RemoteChannel.IncomingMessage = requestInfo.Message; - this.RemoteChannel.IncomingMessageSignal.Set(); - return this.AwaitIncomingRawResponse(); - } - - internal void SendDirectRawResponse(OutgoingWebResponse response) { - this.RemoteChannel.IncomingRawResponse = response; - this.RemoteChannel.IncomingMessageSignal.Set(); - } - - protected internal override HttpRequestBase GetRequestFromContext() { - var directedMessage = (IDirectedProtocolMessage)this.AwaitIncomingMessage(); - return new CoordinatingHttpRequestInfo(directedMessage, directedMessage.HttpMethods); - } - - protected override IProtocolMessage RequestCore(IDirectedProtocolMessage request) { - var requestInfo = this.SpoofHttpMethod(request); - // Drop the outgoing message in the other channel's in-slot and let them know it's there. - this.RemoteChannel.IncomingMessage = requestInfo.Message; - this.RemoteChannel.IncomingMessageSignal.Set(); - // Now wait for a response... - return this.AwaitIncomingMessage(); - } - - protected override OutgoingWebResponse PrepareDirectResponse(IProtocolMessage response) { - this.RemoteChannel.IncomingMessage = this.CloneSerializedParts(response); - this.RemoteChannel.IncomingMessageSignal.Set(); - return new OutgoingWebResponse(); // not used, but returning null is not allowed - } - - protected override OutgoingWebResponse PrepareIndirectResponse(IDirectedProtocolMessage message) { - // In this mock transport, direct and indirect messages are the same. - return this.PrepareDirectResponse(message); - } - - protected override IDirectedProtocolMessage ReadFromRequestCore(HttpRequestBase request) { - var mockRequest = (CoordinatingHttpRequestInfo)request; - return mockRequest.Message; - } - - private static string GetHttpMethod(HttpDeliveryMethods methods) { - return (methods & HttpDeliveryMethods.PostRequest) != 0 ? "POST" : "GET"; - } - - /// <summary> - /// Spoof HTTP request information for signing/verification purposes. - /// </summary> - /// <param name="message">The message to add a pretend HTTP method to.</param> - /// <returns>A spoofed HttpRequestInfo that wraps the new message.</returns> - private CoordinatingHttpRequestInfo SpoofHttpMethod(IDirectedProtocolMessage message) { - var signedMessage = message as ITamperResistantOAuthMessage; - if (signedMessage != null) { - string httpMethod = GetHttpMethod(signedMessage.HttpMethods); - signedMessage.HttpMethod = httpMethod; - } - - var requestInfo = new CoordinatingHttpRequestInfo(this.CloneSerializedParts(message), message.HttpMethods); - return requestInfo; - } - - private IProtocolMessage AwaitIncomingMessage() { - this.IncomingMessageSignal.WaitOne(); - Assert.That(this.IncomingMessage, Is.Not.Null, "Incoming message signaled, but none supplied."); - IProtocolMessage response = this.IncomingMessage; - this.IncomingMessage = null; - return response; - } - - private OutgoingWebResponse AwaitIncomingRawResponse() { - this.IncomingMessageSignal.WaitOne(); - OutgoingWebResponse response = this.IncomingRawResponse; - this.IncomingRawResponse = null; - return response; - } - - private T CloneSerializedParts<T>(T message) where T : class, IProtocolMessage { - Requires.NotNull(message, "message"); - - IProtocolMessage clonedMessage; - var messageAccessor = this.MessageDescriptions.GetAccessor(message); - var fields = messageAccessor.Serialize(); - - MessageReceivingEndpoint recipient = null; - var directedMessage = message as IDirectedProtocolMessage; - var directResponse = message as IDirectResponseProtocolMessage; - if (directedMessage != null && directedMessage.IsRequest()) { - if (directedMessage.Recipient != null) { - recipient = new MessageReceivingEndpoint(directedMessage.Recipient, directedMessage.HttpMethods); - } - - clonedMessage = this.RemoteChannel.MessageFactoryTestHook.GetNewRequestMessage(recipient, fields); - } else if (directResponse != null && directResponse.IsDirectResponse()) { - clonedMessage = this.RemoteChannel.MessageFactoryTestHook.GetNewResponseMessage(directResponse.OriginatingRequest, fields); - } else { - throw new InvalidOperationException("Totally expected a message to implement one of the two derived interface types."); - } - - // Fill the cloned message with data. - var clonedMessageAccessor = this.MessageDescriptions.GetAccessor(clonedMessage); - clonedMessageAccessor.Deserialize(fields); - - return (T)clonedMessage; - } - } -} diff --git a/src/DotNetOpenAuth.Test/Mocks/CoordinatingOutgoingWebResponse.cs b/src/DotNetOpenAuth.Test/Mocks/CoordinatingOutgoingWebResponse.cs deleted file mode 100644 index 9df791c..0000000 --- a/src/DotNetOpenAuth.Test/Mocks/CoordinatingOutgoingWebResponse.cs +++ /dev/null @@ -1,47 +0,0 @@ -//----------------------------------------------------------------------- -// <copyright file="CoordinatingOutgoingWebResponse.cs" company="Outercurve Foundation"> -// Copyright (c) Outercurve Foundation. All rights reserved. -// </copyright> -//----------------------------------------------------------------------- - -namespace DotNetOpenAuth.Test.Mocks { - using System; - using System.Collections.Generic; - using System.ComponentModel; - using System.Linq; - using System.Text; - using DotNetOpenAuth.Messaging; - using Validation; - - internal class CoordinatingOutgoingWebResponse : OutgoingWebResponse { - private CoordinatingChannel receivingChannel; - - private CoordinatingChannel sendingChannel; - - /// <summary> - /// Initializes a new instance of the <see cref="CoordinatingOutgoingWebResponse"/> class. - /// </summary> - /// <param name="message">The direct response message to send to the remote channel. This message will be cloned.</param> - /// <param name="receivingChannel">The receiving channel.</param> - /// <param name="sendingChannel">The sending channel.</param> - internal CoordinatingOutgoingWebResponse(IProtocolMessage message, CoordinatingChannel receivingChannel, CoordinatingChannel sendingChannel) { - Requires.NotNull(message, "message"); - Requires.NotNull(receivingChannel, "receivingChannel"); - Requires.NotNull(sendingChannel, "sendingChannel"); - - this.receivingChannel = receivingChannel; - this.sendingChannel = sendingChannel; - this.OriginalMessage = message; - } - - [EditorBrowsable(EditorBrowsableState.Never)] - public override void Send() { - this.Respond(); - } - - public override void Respond() { - this.sendingChannel.SaveCookies(this.Cookies); - this.receivingChannel.PostMessage(this.OriginalMessage); - } - } -} diff --git a/src/DotNetOpenAuth.Test/Mocks/InMemoryTokenManager.cs b/src/DotNetOpenAuth.Test/Mocks/InMemoryTokenManager.cs index 494a1c1..7fac125 100644 --- a/src/DotNetOpenAuth.Test/Mocks/InMemoryTokenManager.cs +++ b/src/DotNetOpenAuth.Test/Mocks/InMemoryTokenManager.cs @@ -14,7 +14,7 @@ namespace DotNetOpenAuth.Test.Mocks { using DotNetOpenAuth.OAuth.Messages; using DotNetOpenAuth.Test.OAuth; - internal class InMemoryTokenManager : IConsumerTokenManager, IServiceProviderTokenManager { + internal class InMemoryTokenManager : IServiceProviderTokenManager { private KeyedCollectionDelegate<string, ConsumerInfo> consumers = new KeyedCollectionDelegate<string, ConsumerInfo>(c => c.Key); private KeyedCollectionDelegate<string, TokenInfo> tokens = new KeyedCollectionDelegate<string, TokenInfo>(t => t.Token); @@ -28,18 +28,6 @@ namespace DotNetOpenAuth.Test.Mocks { /// </summary> private List<string> accessTokens = new List<string>(); - #region IConsumerTokenManager Members - - public string ConsumerKey { - get { return this.consumers.Single().Key; } - } - - public string ConsumerSecret { - get { return this.consumers.Single().Secret; } - } - - #endregion - #region ITokenManager Members public string GetTokenSecret(string token) { diff --git a/src/DotNetOpenAuth.Test/Mocks/MockHttpRequest.cs b/src/DotNetOpenAuth.Test/Mocks/MockHttpRequest.cs index d20671e..349be56 100644 --- a/src/DotNetOpenAuth.Test/Mocks/MockHttpRequest.cs +++ b/src/DotNetOpenAuth.Test/Mocks/MockHttpRequest.cs @@ -10,7 +10,9 @@ namespace DotNetOpenAuth.Test.Mocks { using System.Globalization; using System.IO; using System.Net; + using System.Net.Http; using System.Text; + using System.Threading.Tasks; using System.Web; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId; @@ -19,71 +21,8 @@ namespace DotNetOpenAuth.Test.Mocks { using DotNetOpenAuth.Yadis; using Validation; - internal class MockHttpRequest { - private readonly Dictionary<Uri, IncomingWebResponse> registeredMockResponses = new Dictionary<Uri, IncomingWebResponse>(); - - private MockHttpRequest(IDirectWebRequestHandler mockHandler) { - Requires.NotNull(mockHandler, "mockHandler"); - this.MockWebRequestHandler = mockHandler; - } - - internal IDirectWebRequestHandler MockWebRequestHandler { get; private set; } - - internal static MockHttpRequest CreateUntrustedMockHttpHandler() { - TestWebRequestHandler testHandler = new TestWebRequestHandler(); - UntrustedWebRequestHandler untrustedHandler = new UntrustedWebRequestHandler(testHandler); - if (!untrustedHandler.WhitelistHosts.Contains("localhost")) { - untrustedHandler.WhitelistHosts.Add("localhost"); - } - untrustedHandler.WhitelistHosts.Add(OpenIdTestBase.OPUri.Host); - MockHttpRequest mock = new MockHttpRequest(untrustedHandler); - testHandler.Callback = mock.GetMockResponse; - return mock; - } - - internal void RegisterMockResponse(IncomingWebResponse response) { - Requires.NotNull(response, "response"); - if (this.registeredMockResponses.ContainsKey(response.RequestUri)) { - Logger.Http.WarnFormat("Mock HTTP response already registered for {0}.", response.RequestUri); - } else { - this.registeredMockResponses.Add(response.RequestUri, response); - } - } - - internal void RegisterMockResponse(Uri requestUri, string contentType, string responseBody) { - this.RegisterMockResponse(requestUri, requestUri, contentType, responseBody); - } - - internal void RegisterMockResponse(Uri requestUri, Uri responseUri, string contentType, string responseBody) { - this.RegisterMockResponse(requestUri, responseUri, contentType, new WebHeaderCollection(), responseBody); - } - - internal void RegisterMockResponse(Uri requestUri, Uri responseUri, string contentType, WebHeaderCollection headers, string responseBody) { - Requires.NotNull(requestUri, "requestUri"); - Requires.NotNull(responseUri, "responseUri"); - Requires.NotNullOrEmpty(contentType, "contentType"); - - // Set up the redirect if appropriate - if (requestUri != responseUri) { - this.RegisterMockRedirect(requestUri, responseUri); - } - - string contentEncoding = null; - MemoryStream stream = new MemoryStream(); - StreamWriter sw = new StreamWriter(stream); - sw.Write(responseBody); - sw.Flush(); - stream.Seek(0, SeekOrigin.Begin); - this.RegisterMockResponse(new CachedDirectWebResponse(responseUri, responseUri, headers ?? new WebHeaderCollection(), HttpStatusCode.OK, contentType, contentEncoding, stream)); - } - - internal void RegisterMockXrdsResponses(IDictionary<string, string> requestUriAndResponseBody) { - foreach (var pair in requestUriAndResponseBody) { - this.RegisterMockResponse(new Uri(pair.Key), "text/xml; saml=false; https=false; charset=UTF-8", pair.Value); - } - } - - internal void RegisterMockXrdsResponse(IdentifierDiscoveryResult endpoint) { + internal static class MockHttpRequest { + internal static void RegisterMockXrdsResponse(this TestBase test, IdentifierDiscoveryResult endpoint) { Requires.NotNull(endpoint, "endpoint"); string identityUri; @@ -92,13 +31,14 @@ namespace DotNetOpenAuth.Test.Mocks { } else { identityUri = endpoint.UserSuppliedIdentifier ?? endpoint.ClaimedIdentifier; } - this.RegisterMockXrdsResponse(new Uri(identityUri), new IdentifierDiscoveryResult[] { endpoint }); + + RegisterMockXrdsResponse(test, new Uri(identityUri), new IdentifierDiscoveryResult[] { endpoint }); } - internal void RegisterMockXrdsResponse(Uri respondingUri, IEnumerable<IdentifierDiscoveryResult> endpoints) { + internal static void RegisterMockXrdsResponse(this TestBase test, Uri respondingUri, IEnumerable<IdentifierDiscoveryResult> endpoints) { Requires.NotNull(endpoints, "endpoints"); - StringBuilder xrds = new StringBuilder(); + var xrds = new StringBuilder(); xrds.AppendLine(@"<xrds:XRDS xmlns:xrds='xri://$xrds' xmlns:openid='http://openid.net/xmlns/1.0' xmlns='xri://$xrd*($v*2.0)'> <XRD>"); foreach (var endpoint in endpoints) { @@ -127,10 +67,10 @@ namespace DotNetOpenAuth.Test.Mocks { </XRD> </xrds:XRDS>"); - this.RegisterMockResponse(respondingUri, ContentTypes.Xrds, xrds.ToString()); + test.Handle(respondingUri).By(xrds.ToString(), ContentTypes.Xrds); } - internal void RegisterMockXrdsResponse(UriIdentifier directedIdentityAssignedIdentifier, IdentifierDiscoveryResult providerEndpoint) { + internal static void RegisterMockXrdsResponse(this TestBase test, UriIdentifier directedIdentityAssignedIdentifier, IdentifierDiscoveryResult providerEndpoint) { IdentifierDiscoveryResult identityEndpoint = IdentifierDiscoveryResult.CreateForClaimedIdentifier( directedIdentityAssignedIdentifier, directedIdentityAssignedIdentifier, @@ -138,16 +78,16 @@ namespace DotNetOpenAuth.Test.Mocks { new ProviderEndpointDescription(providerEndpoint.ProviderEndpoint, providerEndpoint.Capabilities), 10, 10); - this.RegisterMockXrdsResponse(identityEndpoint); + RegisterMockXrdsResponse(test, identityEndpoint); } - internal Identifier RegisterMockXrdsResponse(string embeddedResourcePath) { - UriIdentifier id = new Uri(new Uri("http://localhost/"), embeddedResourcePath); - this.RegisterMockResponse(id, "application/xrds+xml", OpenIdTestBase.LoadEmbeddedFile(embeddedResourcePath)); - return id; + internal static void RegisterMockXrdsResponse(this TestBase test, string embeddedResourcePath, out Identifier id) { + id = new Uri(new Uri("http://localhost/"), embeddedResourcePath); + test.Handle(new Uri(id)) + .By(OpenIdTestBase.LoadEmbeddedFile(embeddedResourcePath), "application/xrds+xml"); } - internal void RegisterMockRPDiscovery() { + internal static void RegisterMockRPDiscovery(this TestBase test, bool ssl) { string template = @"<xrds:XRDS xmlns:xrds='xri://$xrds' xmlns:openid='http://openid.net/xmlns/1.0' xmlns='xri://$xrd*($v*2.0)'> <XRD> <Service priority='10'> @@ -164,44 +104,62 @@ namespace DotNetOpenAuth.Test.Mocks { HttpUtility.HtmlEncode(OpenIdTestBase.RPRealmUri.AbsoluteUri), HttpUtility.HtmlEncode(OpenIdTestBase.RPRealmUriSsl.AbsoluteUri)); - this.RegisterMockResponse(OpenIdTestBase.RPRealmUri, ContentTypes.Xrds, xrds); - this.RegisterMockResponse(OpenIdTestBase.RPRealmUriSsl, ContentTypes.Xrds, xrds); + test.Handle(ssl ? OpenIdTestBase.RPRealmUriSsl : OpenIdTestBase.RPRealmUri) + .By(xrds, ContentTypes.Xrds); } - internal void DeleteResponse(Uri requestUri) { - this.registeredMockResponses.Remove(requestUri); + internal static void RegisterMockRedirect(this TestBase test, Uri origin, Uri redirectLocation) { + var response = new HttpResponseMessage(HttpStatusCode.Redirect); + response.Headers.Location = redirectLocation; + test.Handle(origin).By(req => response); } - internal void RegisterMockRedirect(Uri origin, Uri redirectLocation) { - var redirectionHeaders = new WebHeaderCollection { - { HttpResponseHeader.Location, redirectLocation.AbsoluteUri }, - }; - IncomingWebResponse response = new CachedDirectWebResponse(origin, origin, redirectionHeaders, HttpStatusCode.Redirect, null, null, new MemoryStream()); - this.RegisterMockResponse(response); + internal static void RegisterMockXrdsResponses(this TestBase test, IEnumerable<KeyValuePair<string, string>> urlXrdsPairs) { + Requires.NotNull(urlXrdsPairs, "urlXrdsPairs"); + + foreach (var keyValuePair in urlXrdsPairs) { + test.Handle(new Uri(keyValuePair.Key)).By(keyValuePair.Value, ContentTypes.Xrds); + } } - internal void RegisterMockNotFound(Uri requestUri) { - CachedDirectWebResponse errorResponse = new CachedDirectWebResponse( - requestUri, - requestUri, - new WebHeaderCollection(), - HttpStatusCode.NotFound, - "text/plain", - Encoding.UTF8.WebName, - new MemoryStream(Encoding.UTF8.GetBytes("Not found."))); - this.RegisterMockResponse(errorResponse); + internal static void RegisterMockResponse(this TestBase test, Uri url, string contentType, string content) { + test.Handle(url).By(content, contentType); } - private IncomingWebResponse GetMockResponse(HttpWebRequest request) { - IncomingWebResponse response; - if (this.registeredMockResponses.TryGetValue(request.RequestUri, out response)) { - // reset response stream position so this response can be reused on a subsequent request. - response.ResponseStream.Seek(0, SeekOrigin.Begin); + internal static void RegisterMockResponse(this TestBase test, Uri requestUri, Uri responseUri, string contentType, string content) { + RegisterMockResponse(test, requestUri, responseUri, contentType, null, content); + } + + internal static void RegisterMockResponse(this TestBase test, Uri requestUri, Uri responseUri, string contentType, WebHeaderCollection headers, string content) { + Requires.NotNull(requestUri, "requestUri"); + Requires.NotNull(responseUri, "responseUri"); + Requires.NotNullOrEmpty(contentType, "contentType"); + + test.Handle(requestUri).By(req => { + var response = new HttpResponseMessage(); + response.RequestMessage = req; + + if (requestUri != responseUri) { + // Simulate having followed redirects to get the final response. + var clonedRequest = MessagingUtilities.Clone(req); + clonedRequest.RequestUri = responseUri; + response.RequestMessage = clonedRequest; + } + + response.CopyHeadersFrom(headers); + response.Content = new StringContent(content, Encoding.Default, contentType); return response; - } else { - ////Assert.Fail("Unexpected HTTP request: {0}", uri); - Logger.Http.WarnFormat("Unexpected HTTP request: {0}", request.RequestUri); - return new CachedDirectWebResponse(request.RequestUri, request.RequestUri, new WebHeaderCollection(), HttpStatusCode.NotFound, "text/html", null, new MemoryStream()); + }); + } + + private static void CopyHeadersFrom(this HttpResponseMessage message, WebHeaderCollection headers) { + if (headers != null) { + foreach (string headerName in headers) { + string[] headerValues = headers.GetValues(headerName); + if (!message.Headers.TryAddWithoutValidation(headerName, headerValues)) { + message.Content.Headers.TryAddWithoutValidation(headerName, headerValues); + } + } } } } diff --git a/src/DotNetOpenAuth.Test/Mocks/MockIdentifier.cs b/src/DotNetOpenAuth.Test/Mocks/MockIdentifier.cs deleted file mode 100644 index f020923..0000000 --- a/src/DotNetOpenAuth.Test/Mocks/MockIdentifier.cs +++ /dev/null @@ -1,71 +0,0 @@ -//----------------------------------------------------------------------- -// <copyright file="MockIdentifier.cs" company="Outercurve Foundation"> -// Copyright (c) Outercurve Foundation. All rights reserved. -// </copyright> -//----------------------------------------------------------------------- - -namespace DotNetOpenAuth.Test.Mocks { - using System; - using System.Collections.Generic; - using DotNetOpenAuth.Messaging; - using DotNetOpenAuth.OpenId; - using DotNetOpenAuth.OpenId.RelyingParty; - using Validation; - - /// <summary> - /// Performs similar to an ordinary <see cref="Identifier"/>, but when called upon - /// to perform discovery, it returns a preset list of sevice endpoints to avoid - /// having a dependency on a hosted web site to actually perform discovery on. - /// </summary> - internal class MockIdentifier : Identifier { - private IEnumerable<IdentifierDiscoveryResult> endpoints; - - private MockHttpRequest mockHttpRequest; - - private Identifier wrappedIdentifier; - - public MockIdentifier(Identifier wrappedIdentifier, MockHttpRequest mockHttpRequest, IEnumerable<IdentifierDiscoveryResult> endpoints) - : base(wrappedIdentifier.OriginalString, false) { - Requires.NotNull(wrappedIdentifier, "wrappedIdentifier"); - Requires.NotNull(mockHttpRequest, "mockHttpRequest"); - Requires.NotNull(endpoints, "endpoints"); - - this.wrappedIdentifier = wrappedIdentifier; - this.endpoints = endpoints; - this.mockHttpRequest = mockHttpRequest; - - // Register a mock HTTP response to enable discovery of this identifier within the RP - // without having to host an ASP.NET site within the test. - mockHttpRequest.RegisterMockXrdsResponse(new Uri(wrappedIdentifier.ToString()), endpoints); - } - - internal IEnumerable<IdentifierDiscoveryResult> DiscoveryEndpoints { - get { return this.endpoints; } - } - - public override string ToString() { - return this.wrappedIdentifier.ToString(); - } - - public override bool Equals(object obj) { - return this.wrappedIdentifier.Equals(obj); - } - - public override int GetHashCode() { - return this.wrappedIdentifier.GetHashCode(); - } - - internal override Identifier TrimFragment() { - return this; - } - - internal override bool TryRequireSsl(out Identifier secureIdentifier) { - // We take special care to make our wrapped identifier secure, but still - // return a mocked (secure) identifier. - Identifier secureWrappedIdentifier; - bool result = this.wrappedIdentifier.TryRequireSsl(out secureWrappedIdentifier); - secureIdentifier = new MockIdentifier(secureWrappedIdentifier, this.mockHttpRequest, this.endpoints); - return result; - } - } -} diff --git a/src/DotNetOpenAuth.Test/Mocks/MockIdentifierDiscoveryService.cs b/src/DotNetOpenAuth.Test/Mocks/MockIdentifierDiscoveryService.cs deleted file mode 100644 index 0118851..0000000 --- a/src/DotNetOpenAuth.Test/Mocks/MockIdentifierDiscoveryService.cs +++ /dev/null @@ -1,47 +0,0 @@ -//----------------------------------------------------------------------- -// <copyright file="MockIdentifierDiscoveryService.cs" company="Outercurve Foundation"> -// Copyright (c) Outercurve Foundation. All rights reserved. -// </copyright> -//----------------------------------------------------------------------- - -namespace DotNetOpenAuth.Test.Mocks { - using System; - using System.Collections.Generic; - using System.Linq; - using System.Text; - using DotNetOpenAuth.Messaging; - using DotNetOpenAuth.OpenId; - using DotNetOpenAuth.OpenId.RelyingParty; - - internal class MockIdentifierDiscoveryService : IIdentifierDiscoveryService { - /// <summary> - /// Initializes a new instance of the <see cref="MockIdentifierDiscoveryService"/> class. - /// </summary> - public MockIdentifierDiscoveryService() { - } - - #region IIdentifierDiscoveryService Members - - /// <summary> - /// Performs discovery on the specified identifier. - /// </summary> - /// <param name="identifier">The identifier to perform discovery on.</param> - /// <param name="requestHandler">The means to place outgoing HTTP requests.</param> - /// <param name="abortDiscoveryChain">if set to <c>true</c>, no further discovery services will be called for this identifier.</param> - /// <returns> - /// A sequence of service endpoints yielded by discovery. Must not be null, but may be empty. - /// </returns> - public IEnumerable<IdentifierDiscoveryResult> Discover(Identifier identifier, IDirectWebRequestHandler requestHandler, out bool abortDiscoveryChain) { - var mockIdentifier = identifier as MockIdentifier; - if (mockIdentifier == null) { - abortDiscoveryChain = false; - return Enumerable.Empty<IdentifierDiscoveryResult>(); - } - - abortDiscoveryChain = true; - return mockIdentifier.DiscoveryEndpoints; - } - - #endregion - } -} diff --git a/src/DotNetOpenAuth.Test/Mocks/MockRealm.cs b/src/DotNetOpenAuth.Test/Mocks/MockRealm.cs index 8509c03..38f9daf 100644 --- a/src/DotNetOpenAuth.Test/Mocks/MockRealm.cs +++ b/src/DotNetOpenAuth.Test/Mocks/MockRealm.cs @@ -7,6 +7,9 @@ namespace DotNetOpenAuth.Test.Mocks { using System; using System.Collections.Generic; + using System.Threading; + using System.Threading.Tasks; + using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId; using Validation; @@ -30,15 +33,16 @@ namespace DotNetOpenAuth.Test.Mocks { /// Searches for an XRDS document at the realm URL, and if found, searches /// for a description of a relying party endpoints (OpenId login pages). /// </summary> - /// <param name="requestHandler">The mechanism to use for sending HTTP requests.</param> + /// <param name="hostFactories">The host factories.</param> /// <param name="allowRedirects">Whether redirects may be followed when discovering the Realm. /// This may be true when creating an unsolicited assertion, but must be /// false when performing return URL verification per 2.0 spec section 9.2.1.</param> + /// <param name="cancellationToken">The cancellation token.</param> /// <returns> /// The details of the endpoints if found, otherwise null. /// </returns> - internal override IEnumerable<RelyingPartyEndpointDescription> DiscoverReturnToEndpoints(IDirectWebRequestHandler requestHandler, bool allowRedirects) { - return this.relyingPartyDescriptions; + internal override Task<IEnumerable<RelyingPartyEndpointDescription>> DiscoverReturnToEndpointsAsync(IHostFactories hostFactories, bool allowRedirects, CancellationToken cancellationToken) { + return Task.FromResult<IEnumerable<RelyingPartyEndpointDescription>>(this.relyingPartyDescriptions); } } } diff --git a/src/DotNetOpenAuth.Test/Mocks/MockReplayProtectionBindingElement.cs b/src/DotNetOpenAuth.Test/Mocks/MockReplayProtectionBindingElement.cs index 1733f17..58a2367 100644 --- a/src/DotNetOpenAuth.Test/Mocks/MockReplayProtectionBindingElement.cs +++ b/src/DotNetOpenAuth.Test/Mocks/MockReplayProtectionBindingElement.cs @@ -5,6 +5,9 @@ //----------------------------------------------------------------------- namespace DotNetOpenAuth.Test.Mocks { + using System.Threading; + using System.Threading.Tasks; + using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Bindings; using NUnit.Framework; @@ -23,17 +26,17 @@ namespace DotNetOpenAuth.Test.Mocks { /// </summary> public Channel Channel { get; set; } - MessageProtections? IChannelBindingElement.ProcessOutgoingMessage(IProtocolMessage message) { + Task<MessageProtections?> IChannelBindingElement.ProcessOutgoingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { var replayMessage = message as IReplayProtectedProtocolMessage; if (replayMessage != null) { replayMessage.Nonce = "someNonce"; - return MessageProtections.ReplayProtection; + return MessageProtectionTasks.ReplayProtection; } - return null; + return MessageProtectionTasks.Null; } - MessageProtections? IChannelBindingElement.ProcessIncomingMessage(IProtocolMessage message) { + Task<MessageProtections?> IChannelBindingElement.ProcessIncomingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { var replayMessage = message as IReplayProtectedProtocolMessage; if (replayMessage != null) { Assert.AreEqual("someNonce", replayMessage.Nonce, "The nonce didn't serialize correctly, or something"); @@ -42,10 +45,10 @@ namespace DotNetOpenAuth.Test.Mocks { throw new ReplayedMessageException(message); } this.messageReceived = true; - return MessageProtections.ReplayProtection; + return MessageProtectionTasks.ReplayProtection; } - return null; + return MessageProtectionTasks.Null; } #endregion diff --git a/src/DotNetOpenAuth.Test/Mocks/MockSigningBindingElement.cs b/src/DotNetOpenAuth.Test/Mocks/MockSigningBindingElement.cs index aa68b0b..f0b2bfc 100644 --- a/src/DotNetOpenAuth.Test/Mocks/MockSigningBindingElement.cs +++ b/src/DotNetOpenAuth.Test/Mocks/MockSigningBindingElement.cs @@ -9,6 +9,9 @@ namespace DotNetOpenAuth.Test.Mocks { using System.Collections.Generic; using System.Linq; using System.Text; + using System.Threading; + using System.Threading.Tasks; + using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Bindings; @@ -26,26 +29,26 @@ namespace DotNetOpenAuth.Test.Mocks { /// </summary> public Channel Channel { get; set; } - MessageProtections? IChannelBindingElement.ProcessOutgoingMessage(IProtocolMessage message) { - ITamperResistantProtocolMessage signedMessage = message as ITamperResistantProtocolMessage; + Task<MessageProtections?> IChannelBindingElement.ProcessOutgoingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { + var signedMessage = message as ITamperResistantProtocolMessage; if (signedMessage != null) { signedMessage.Signature = MessageSignature; - return MessageProtections.TamperProtection; + return MessageProtectionTasks.TamperProtection; } - return null; + return MessageProtectionTasks.Null; } - MessageProtections? IChannelBindingElement.ProcessIncomingMessage(IProtocolMessage message) { - ITamperResistantProtocolMessage signedMessage = message as ITamperResistantProtocolMessage; + Task<MessageProtections?> IChannelBindingElement.ProcessIncomingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { + var signedMessage = message as ITamperResistantProtocolMessage; if (signedMessage != null) { if (signedMessage.Signature != MessageSignature) { throw new InvalidSignatureException(message); } - return MessageProtections.TamperProtection; + return MessageProtectionTasks.TamperProtection; } - return null; + return MessageProtectionTasks.Null; } #endregion diff --git a/src/DotNetOpenAuth.Test/Mocks/MockTransformationBindingElement.cs b/src/DotNetOpenAuth.Test/Mocks/MockTransformationBindingElement.cs index 2b3249f..35d7f1b 100644 --- a/src/DotNetOpenAuth.Test/Mocks/MockTransformationBindingElement.cs +++ b/src/DotNetOpenAuth.Test/Mocks/MockTransformationBindingElement.cs @@ -9,11 +9,14 @@ namespace DotNetOpenAuth.Test.Mocks { using System.Collections.Generic; using System.Linq; using System.Text; + using System.Threading; + using System.Threading.Tasks; + using DotNetOpenAuth.Messaging; using NUnit.Framework; internal class MockTransformationBindingElement : IChannelBindingElement { - private string transform; + private readonly string transform; internal MockTransformationBindingElement(string transform) { if (transform == null) { @@ -34,25 +37,25 @@ namespace DotNetOpenAuth.Test.Mocks { /// </summary> public Channel Channel { get; set; } - MessageProtections? IChannelBindingElement.ProcessOutgoingMessage(IProtocolMessage message) { + Task<MessageProtections?> IChannelBindingElement.ProcessOutgoingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { var testMessage = message as TestMessage; if (testMessage != null) { testMessage.Name = this.transform + testMessage.Name; - return MessageProtections.None; + return MessageProtectionTasks.None; } - return null; + return MessageProtectionTasks.Null; } - MessageProtections? IChannelBindingElement.ProcessIncomingMessage(IProtocolMessage message) { + Task<MessageProtections?> IChannelBindingElement.ProcessIncomingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) { var testMessage = message as TestMessage; if (testMessage != null) { StringAssert.StartsWith(this.transform, testMessage.Name); testMessage.Name = testMessage.Name.Substring(this.transform.Length); - return MessageProtections.None; + return MessageProtectionTasks.None; } - return null; + return MessageProtectionTasks.Null; } #endregion diff --git a/src/DotNetOpenAuth.Test/Mocks/TestBadChannel.cs b/src/DotNetOpenAuth.Test/Mocks/TestBadChannel.cs index 263f0fd..344598f 100644 --- a/src/DotNetOpenAuth.Test/Mocks/TestBadChannel.cs +++ b/src/DotNetOpenAuth.Test/Mocks/TestBadChannel.cs @@ -7,15 +7,32 @@ namespace DotNetOpenAuth.Test.Mocks { using System; using System.Collections.Generic; + using System.Net.Http; + using System.Threading; + using System.Threading.Tasks; using System.Web; using DotNetOpenAuth.Messaging; + using DotNetOpenAuth.OpenId; /// <summary> /// A Channel derived type that passes null to the protected constructor. /// </summary> internal class TestBadChannel : Channel { - internal TestBadChannel(bool badConstructorParam) - : base(badConstructorParam ? null : new TestMessageFactory()) { + /// <summary> + /// Initializes a new instance of the <see cref="TestBadChannel" /> class. + /// </summary> + /// <param name="messageFactory">The message factory. Could be <see cref="TestMessageFactory"/></param> + /// <param name="bindingElements">The binding elements.</param> + /// <param name="hostFactories">The host factories.</param> + internal TestBadChannel(IMessageFactory messageFactory, IChannelBindingElement[] bindingElements, IHostFactories hostFactories) + : base(messageFactory, bindingElements, hostFactories) { + } + + /// <summary> + /// Initializes a new instance of the <see cref="TestBadChannel"/> class. + /// </summary> + internal TestBadChannel() + : this(new TestMessageFactory(), new IChannelBindingElement[0], new DefaultOpenIdHostFactories()) { } internal new void Create301RedirectResponse(IDirectedProtocolMessage message, IDictionary<string, string> fields, bool payloadInFragment = false) { @@ -34,15 +51,15 @@ namespace DotNetOpenAuth.Test.Mocks { return base.Receive(fields, recipient); } - internal new IProtocolMessage ReadFromRequest(HttpRequestBase request) { - return base.ReadFromRequest(request); + internal new Task<IDirectedProtocolMessage> ReadFromRequestAsync(HttpRequestMessage request, CancellationToken cancellationToken) { + return base.ReadFromRequestAsync(request, cancellationToken); } - protected override IDictionary<string, string> ReadFromResponseCore(IncomingWebResponse response) { + protected override Task<IDictionary<string, string>> ReadFromResponseCoreAsync(HttpResponseMessage response, CancellationToken cancellationToken) { throw new NotImplementedException(); } - protected override OutgoingWebResponse PrepareDirectResponse(IProtocolMessage response) { + protected override HttpResponseMessage PrepareDirectResponse(IProtocolMessage response) { throw new NotImplementedException(); } } diff --git a/src/DotNetOpenAuth.Test/Mocks/TestChannel.cs b/src/DotNetOpenAuth.Test/Mocks/TestChannel.cs index 1472231..5b318d5 100644 --- a/src/DotNetOpenAuth.Test/Mocks/TestChannel.cs +++ b/src/DotNetOpenAuth.Test/Mocks/TestChannel.cs @@ -8,21 +8,26 @@ namespace DotNetOpenAuth.Test.Mocks { using System; using System.Collections.Generic; using System.Net; + using System.Net.Http; + using System.Threading; + using System.Threading.Tasks; + using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Reflection; + using DotNetOpenAuth.OpenId; internal class TestChannel : Channel { - internal TestChannel() - : this(new TestMessageFactory()) { + internal TestChannel(IHostFactories hostFactories = null) + : this(new TestMessageFactory(), new IChannelBindingElement[0], hostFactories ?? new DefaultOpenIdHostFactories()) { } - internal TestChannel(MessageDescriptionCollection messageDescriptions) - : this() { + internal TestChannel(MessageDescriptionCollection messageDescriptions, IHostFactories hostFactories = null) + : this(hostFactories) { this.MessageDescriptions = messageDescriptions; } - internal TestChannel(IMessageFactory messageTypeProvider, params IChannelBindingElement[] bindingElements) - : base(messageTypeProvider, bindingElements) { + internal TestChannel(IMessageFactory messageTypeProvider, IChannelBindingElement[] bindingElements, IHostFactories hostFactories) + : base(messageTypeProvider, bindingElements, hostFactories) { } /// <summary> @@ -40,15 +45,15 @@ namespace DotNetOpenAuth.Test.Mocks { return base.Receive(fields, recipient); } - protected override IDictionary<string, string> ReadFromResponseCore(IncomingWebResponse response) { + protected override Task<IDictionary<string, string>> ReadFromResponseCoreAsync(HttpResponseMessage response, CancellationToken cancellationToken) { throw new NotImplementedException("ReadFromResponseInternal"); } - protected override HttpWebRequest CreateHttpRequest(IDirectedProtocolMessage request) { + protected override HttpRequestMessage CreateHttpRequest(IDirectedProtocolMessage request) { throw new NotImplementedException("CreateHttpRequest"); } - protected override OutgoingWebResponse PrepareDirectResponse(IProtocolMessage response) { + protected override HttpResponseMessage PrepareDirectResponse(IProtocolMessage response) { throw new NotImplementedException("SendDirectMessageResponse"); } } diff --git a/src/DotNetOpenAuth.Test/Mocks/TestWebRequestHandler.cs b/src/DotNetOpenAuth.Test/Mocks/TestWebRequestHandler.cs deleted file mode 100644 index b38a3d8..0000000 --- a/src/DotNetOpenAuth.Test/Mocks/TestWebRequestHandler.cs +++ /dev/null @@ -1,116 +0,0 @@ -//----------------------------------------------------------------------- -// <copyright file="TestWebRequestHandler.cs" company="Outercurve Foundation"> -// Copyright (c) Outercurve Foundation. All rights reserved. -// </copyright> -//----------------------------------------------------------------------- - -namespace DotNetOpenAuth.Test.Mocks { - using System; - using System.IO; - using System.Net; - using System.Text; - using DotNetOpenAuth.Messaging; - using DotNetOpenAuth.OAuth.ChannelElements; - - internal class TestWebRequestHandler : IDirectWebRequestHandler { - private Stream postEntity; - - /// <summary> - /// Gets or sets the callback used to provide the mock response for the mock request. - /// </summary> - internal Func<HttpWebRequest, IncomingWebResponse> Callback { get; set; } - - /// <summary> - /// Gets the stream that was written out as if on an HTTP request. - /// </summary> - internal Stream RequestEntityStream { - get { - if (this.postEntity == null) { - return null; - } - - Stream result = new MemoryStream(); - long originalPosition = this.postEntity.Position; - this.postEntity.Position = 0; - this.postEntity.CopyTo(result); - this.postEntity.Position = originalPosition; - result.Position = 0; - return result; - } - } - - /// <summary> - /// Gets the stream that was written out as if on an HTTP request as an ordinary string. - /// </summary> - internal string RequestEntityAsString { - get { - if (this.postEntity == null) { - return null; - } - - StreamReader reader = new StreamReader(this.RequestEntityStream); - return reader.ReadToEnd(); - } - } - - #region IWebRequestHandler Members - - public bool CanSupport(DirectWebRequestOptions options) { - return true; - } - - /// <summary> - /// Prepares an <see cref="HttpWebRequest"/> that contains an POST entity for sending the entity. - /// </summary> - /// <param name="request">The <see cref="HttpWebRequest"/> that should contain the entity.</param> - /// <returns> - /// The writer the caller should write out the entity data to. - /// </returns> - public Stream GetRequestStream(HttpWebRequest request) { - return this.GetRequestStream(request, DirectWebRequestOptions.None); - } - - public Stream GetRequestStream(HttpWebRequest request, DirectWebRequestOptions options) { - this.postEntity = new MemoryStream(); - return this.postEntity; - } - - /// <summary> - /// Processes an <see cref="HttpWebRequest"/> and converts the - /// <see cref="HttpWebResponse"/> to a <see cref="Response"/> instance. - /// </summary> - /// <param name="request">The <see cref="HttpWebRequest"/> to handle.</param> - /// <returns> - /// An instance of <see cref="Response"/> describing the response. - /// </returns> - public IncomingWebResponse GetResponse(HttpWebRequest request) { - return this.GetResponse(request, DirectWebRequestOptions.None); - } - - public IncomingWebResponse GetResponse(HttpWebRequest request, DirectWebRequestOptions options) { - if (this.Callback == null) { - throw new InvalidOperationException("Set the Callback property first."); - } - - return this.Callback(request); - } - - #endregion - - #region IDirectSslWebRequestHandler Members - - public Stream GetRequestStream(HttpWebRequest request, bool requireSsl) { - ErrorUtilities.VerifyProtocol(!requireSsl || request.RequestUri.Scheme == Uri.UriSchemeHttps, "disallowed request"); - return this.GetRequestStream(request); - } - - public IncomingWebResponse GetResponse(HttpWebRequest request, bool requireSsl) { - ErrorUtilities.VerifyProtocol(!requireSsl || request.RequestUri.Scheme == Uri.UriSchemeHttps, "disallowed request"); - var result = this.GetResponse(request); - ErrorUtilities.VerifyProtocol(!requireSsl || result.FinalUri.Scheme == Uri.UriSchemeHttps, "disallowed request"); - return result; - } - - #endregion - } -} diff --git a/src/DotNetOpenAuth.Test/OAuth/AppendixScenarios.cs b/src/DotNetOpenAuth.Test/OAuth/AppendixScenarios.cs index a295732..c7b2bfa 100644 --- a/src/DotNetOpenAuth.Test/OAuth/AppendixScenarios.cs +++ b/src/DotNetOpenAuth.Test/OAuth/AppendixScenarios.cs @@ -8,6 +8,9 @@ namespace DotNetOpenAuth.Test.OAuth { using System; using System.IO; using System.Net; + using System.Net.Http; + using System.Net.Http.Headers; + using System.Threading.Tasks; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OAuth; using DotNetOpenAuth.OAuth.ChannelElements; @@ -17,50 +20,76 @@ namespace DotNetOpenAuth.Test.OAuth { [TestFixture] public class AppendixScenarios : TestBase { [Test] - public void SpecAppendixAExample() { - ServiceProviderDescription serviceDescription = new ServiceProviderDescription() { - RequestTokenEndpoint = new MessageReceivingEndpoint("https://photos.example.net/request_token", HttpDeliveryMethods.PostRequest), - UserAuthorizationEndpoint = new MessageReceivingEndpoint("http://photos.example.net/authorize", HttpDeliveryMethods.GetRequest), - AccessTokenEndpoint = new MessageReceivingEndpoint("https://photos.example.net/access_token", HttpDeliveryMethods.PostRequest), - TamperProtectionElements = new ITamperProtectionChannelBindingElement[] { - new PlaintextSigningBindingElement(), - new HmacSha1SigningBindingElement(), - }, + public async Task SpecAppendixAExample() { + var serviceDescription = new ServiceProviderDescription( + "https://photos.example.net/request_token", + "http://photos.example.net/authorize", + "https://photos.example.net/access_token"); + var serviceHostDescription = new ServiceProviderHostDescription { + RequestTokenEndpoint = new MessageReceivingEndpoint(serviceDescription.TemporaryCredentialsRequestEndpoint, HttpDeliveryMethods.PostRequest | HttpDeliveryMethods.AuthorizationHeaderRequest), + UserAuthorizationEndpoint = new MessageReceivingEndpoint(serviceDescription.ResourceOwnerAuthorizationEndpoint, HttpDeliveryMethods.GetRequest), + AccessTokenEndpoint = new MessageReceivingEndpoint(serviceDescription.TokenRequestEndpoint, HttpDeliveryMethods.PostRequest | HttpDeliveryMethods.AuthorizationHeaderRequest), + TamperProtectionElements = new ITamperProtectionChannelBindingElement[] { new HmacSha1SigningBindingElement(), }, }; - MessageReceivingEndpoint accessPhotoEndpoint = new MessageReceivingEndpoint("http://photos.example.net/photos?file=vacation.jpg&size=original", HttpDeliveryMethods.AuthorizationHeaderRequest | HttpDeliveryMethods.GetRequest); - ConsumerDescription consumerDescription = new ConsumerDescription("dpf43f3p2l4k3l03", "kd94hf93k423kf44"); + var accessPhotoEndpoint = new Uri("http://photos.example.net/photos?file=vacation.jpg&size=original"); + var consumerDescription = new ConsumerDescription("dpf43f3p2l4k3l03", "kd94hf93k423kf44"); - OAuthCoordinator coordinator = new OAuthCoordinator( - consumerDescription, - serviceDescription, - consumer => { - consumer.Channel.PrepareResponse(consumer.PrepareRequestUserAuthorization(new Uri("http://printer.example.com/request_token_ready"), null, null)); // .Send() dropped because this is just a simulation - string accessToken = consumer.ProcessUserAuthorization().AccessToken; - var photoRequest = consumer.CreateAuthorizingMessage(accessPhotoEndpoint, accessToken); - OutgoingWebResponse protectedPhoto = ((CoordinatingOAuthConsumerChannel)consumer.Channel).RequestProtectedResource(photoRequest); - Assert.IsNotNull(protectedPhoto); - Assert.AreEqual(HttpStatusCode.OK, protectedPhoto.Status); - Assert.AreEqual("image/jpeg", protectedPhoto.Headers[HttpResponseHeader.ContentType]); - Assert.AreNotEqual(0, protectedPhoto.ResponseStream.Length); - }, - sp => { - var requestTokenMessage = sp.ReadTokenRequest(); - sp.Channel.PrepareResponse(sp.PrepareUnauthorizedTokenMessage(requestTokenMessage)); // .Send() dropped because this is just a simulation - var authRequest = sp.ReadAuthorizationRequest(); + var tokenManager = new InMemoryTokenManager(); + tokenManager.AddConsumer(consumerDescription); + var sp = new ServiceProvider(serviceHostDescription, tokenManager); + + Handle(serviceDescription.TemporaryCredentialsRequestEndpoint).By( + async (request, ct) => { + var requestTokenMessage = await sp.ReadTokenRequestAsync(request, ct); + return await sp.Channel.PrepareResponseAsync(sp.PrepareUnauthorizedTokenMessage(requestTokenMessage)); + }); + Handle(serviceDescription.ResourceOwnerAuthorizationEndpoint).By( + async (request, ct) => { + var authRequest = await sp.ReadAuthorizationRequestAsync(request, ct); ((InMemoryTokenManager)sp.TokenManager).AuthorizeRequestToken(authRequest.RequestToken); - sp.Channel.PrepareResponse(sp.PrepareAuthorizationResponse(authRequest)); // .Send() dropped because this is just a simulation - var accessRequest = sp.ReadAccessTokenRequest(); - sp.Channel.PrepareResponse(sp.PrepareAccessTokenMessage(accessRequest)); // .Send() dropped because this is just a simulation - string accessToken = sp.ReadProtectedResourceAuthorization().AccessToken; - ((CoordinatingOAuthServiceProviderChannel)sp.Channel).SendDirectRawResponse(new OutgoingWebResponse { - ResponseStream = new MemoryStream(new byte[] { 0x33, 0x66 }), - Headers = new WebHeaderCollection { - { HttpResponseHeader.ContentType, "image/jpeg" }, - }, - }); + return await sp.Channel.PrepareResponseAsync(sp.PrepareAuthorizationResponse(authRequest)); + }); + Handle(serviceDescription.TokenRequestEndpoint).By( + async (request, ct) => { + var accessRequest = await sp.ReadAccessTokenRequestAsync(request, ct); + return await sp.Channel.PrepareResponseAsync(sp.PrepareAccessTokenMessage(accessRequest), ct); }); + Handle(accessPhotoEndpoint).By( + async (request, ct) => { + string accessToken = (await sp.ReadProtectedResourceAuthorizationAsync(request)).AccessToken; + Assert.That(accessToken, Is.Not.Null.And.Not.Empty); + var responseMessage = new HttpResponseMessage { Content = new ByteArrayContent(new byte[] { 0x33, 0x66 }), }; + responseMessage.Content.Headers.ContentType = new MediaTypeHeaderValue("image/jpeg"); + return responseMessage; + }); + + var consumer = new Consumer( + consumerDescription.ConsumerKey, + consumerDescription.ConsumerSecret, + serviceDescription, + new MemoryTemporaryCredentialStorage()); + consumer.HostFactories = this.HostFactories; + var authorizeUrl = await consumer.RequestUserAuthorizationAsync(new Uri("http://printer.example.com/request_token_ready")); + Uri authorizeResponseUri; + this.HostFactories.AllowAutoRedirects = false; + using (var httpClient = this.HostFactories.CreateHttpClient()) { + using (var response = await httpClient.GetAsync(authorizeUrl)) { + Assert.That(response.StatusCode, Is.EqualTo(HttpStatusCode.Redirect)); + authorizeResponseUri = response.Headers.Location; + } + } + + var accessTokenResponse = await consumer.ProcessUserAuthorizationAsync(authorizeResponseUri); + Assert.That(accessTokenResponse, Is.Not.Null); - coordinator.Run(); + using (var authorizingClient = consumer.CreateHttpClient(accessTokenResponse.AccessToken)) { + using (var protectedPhoto = await authorizingClient.GetAsync(accessPhotoEndpoint)) { + Assert.That(protectedPhoto, Is.Not.Null); + protectedPhoto.EnsureSuccessStatusCode(); + Assert.That("image/jpeg", Is.EqualTo(protectedPhoto.Content.Headers.ContentType.MediaType)); + Assert.That(protectedPhoto.Content.Headers.ContentLength, Is.Not.EqualTo(0)); + } + } } } } diff --git a/src/DotNetOpenAuth.Test/OAuth/ChannelElements/HmacSha1SigningBindingElementTests.cs b/src/DotNetOpenAuth.Test/OAuth/ChannelElements/HmacSha1SigningBindingElementTests.cs index 49260eb..e436143 100644 --- a/src/DotNetOpenAuth.Test/OAuth/ChannelElements/HmacSha1SigningBindingElementTests.cs +++ b/src/DotNetOpenAuth.Test/OAuth/ChannelElements/HmacSha1SigningBindingElementTests.cs @@ -5,6 +5,8 @@ //----------------------------------------------------------------------- namespace DotNetOpenAuth.Test.OAuth.ChannelElements { + using System.Net.Http; + using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Reflection; using DotNetOpenAuth.OAuth; @@ -33,7 +35,7 @@ namespace DotNetOpenAuth.Test.OAuth.ChannelElements { ((ITamperResistantOAuthMessage)message).ConsumerSecret = "ExJXsYl7Or8OfK98"; ((ITamperResistantOAuthMessage)message).TokenSecret = "b197333b-470a-43b3-bcd7-49d6d2229c4c"; var signedMessage = (ITamperResistantOAuthMessage)message; - signedMessage.HttpMethod = "GET"; + signedMessage.HttpMethod = HttpMethod.Get; signedMessage.SignatureMethod = "HMAC-SHA1"; MessageDictionary dictionary = this.MessageDescriptions.GetAccessor(message); dictionary["oauth_timestamp"] = "1353545248"; diff --git a/src/DotNetOpenAuth.Test/OAuth/ChannelElements/OAuthChannelTests.cs b/src/DotNetOpenAuth.Test/OAuth/ChannelElements/OAuthChannelTests.cs index b081038..fdf652c 100644 --- a/src/DotNetOpenAuth.Test/OAuth/ChannelElements/OAuthChannelTests.cs +++ b/src/DotNetOpenAuth.Test/OAuth/ChannelElements/OAuthChannelTests.cs @@ -10,7 +10,10 @@ namespace DotNetOpenAuth.Test.OAuth.ChannelElements { using System.Collections.Specialized; using System.IO; using System.Net; + using System.Net.Http; using System.Text; + using System.Threading; + using System.Threading.Tasks; using System.Web; using System.Xml; using DotNetOpenAuth.Messaging; @@ -24,41 +27,32 @@ namespace DotNetOpenAuth.Test.OAuth.ChannelElements { [TestFixture] public class OAuthChannelTests : TestBase { private OAuthChannel channel; - private TestWebRequestHandler webRequestHandler; private SigningBindingElementBase signingElement; private INonceStore nonceStore; private DotNetOpenAuth.OAuth.ServiceProviderSecuritySettings serviceProviderSecuritySettings = DotNetOpenAuth.Configuration.OAuthElement.Configuration.ServiceProvider.SecuritySettings.CreateSecuritySettings(); - private DotNetOpenAuth.OAuth.ConsumerSecuritySettings consumerSecuritySettings = DotNetOpenAuth.Configuration.OAuthElement.Configuration.Consumer.SecuritySettings.CreateSecuritySettings(); [SetUp] public override void SetUp() { base.SetUp(); - this.webRequestHandler = new TestWebRequestHandler(); this.signingElement = new RsaSha1ServiceProviderSigningBindingElement(new InMemoryTokenManager()); this.nonceStore = new NonceMemoryStore(StandardExpirationBindingElement.MaximumMessageAge); - this.channel = new OAuthServiceProviderChannel(this.signingElement, this.nonceStore, new InMemoryTokenManager(), this.serviceProviderSecuritySettings, new TestMessageFactory()); - this.channel.WebRequestHandler = this.webRequestHandler; + this.channel = new OAuthServiceProviderChannel(this.signingElement, this.nonceStore, new InMemoryTokenManager(), this.serviceProviderSecuritySettings, new TestMessageFactory(), this.HostFactories); } [Test, ExpectedException(typeof(ArgumentException))] public void CtorNullSigner() { - new OAuthConsumerChannel(null, this.nonceStore, new InMemoryTokenManager(), this.consumerSecuritySettings, new TestMessageFactory()); + new OAuthServiceProviderChannel(null, this.nonceStore, new InMemoryTokenManager(), this.serviceProviderSecuritySettings, new TestMessageFactory()); } [Test, ExpectedException(typeof(ArgumentNullException))] public void CtorNullStore() { - new OAuthConsumerChannel(new RsaSha1ServiceProviderSigningBindingElement(new InMemoryTokenManager()), null, new InMemoryTokenManager(), this.consumerSecuritySettings, new TestMessageFactory()); + new OAuthServiceProviderChannel(new RsaSha1ServiceProviderSigningBindingElement(new InMemoryTokenManager()), null, new InMemoryTokenManager(), this.serviceProviderSecuritySettings, new TestMessageFactory()); } [Test, ExpectedException(typeof(ArgumentNullException))] public void CtorNullTokenManager() { - new OAuthConsumerChannel(new RsaSha1ServiceProviderSigningBindingElement(new InMemoryTokenManager()), this.nonceStore, null, this.consumerSecuritySettings, new TestMessageFactory()); - } - - [Test] - public void CtorSimpleConsumer() { - new OAuthConsumerChannel(new RsaSha1ServiceProviderSigningBindingElement(new InMemoryTokenManager()), this.nonceStore, (IConsumerTokenManager)new InMemoryTokenManager(), this.consumerSecuritySettings); + new OAuthServiceProviderChannel(new RsaSha1ServiceProviderSigningBindingElement(new InMemoryTokenManager()), this.nonceStore, null, this.serviceProviderSecuritySettings, new TestMessageFactory()); } [Test] @@ -67,8 +61,8 @@ namespace DotNetOpenAuth.Test.OAuth.ChannelElements { } [Test] - public void ReadFromRequestAuthorization() { - this.ParameterizedReceiveTest(HttpDeliveryMethods.AuthorizationHeaderRequest); + public async Task ReadFromRequestAuthorization() { + await this.ParameterizedReceiveTestAsync(HttpDeliveryMethods.AuthorizationHeaderRequest); } /// <summary> @@ -76,7 +70,7 @@ namespace DotNetOpenAuth.Test.OAuth.ChannelElements { /// from the Authorization header, the query string and the entity form data. /// </summary> [Test] - public void ReadFromRequestAuthorizationScattered() { + public async Task ReadFromRequestAuthorizationScattered() { // Start by creating a standard POST HTTP request. var postedFields = new Dictionary<string, string> { { "age", "15" }, @@ -97,7 +91,7 @@ namespace DotNetOpenAuth.Test.OAuth.ChannelElements { var requestInfo = new HttpRequestInfo("POST", builder.Uri, form: postedFields.ToNameValueCollection(), headers: headers); - IDirectedProtocolMessage requestMessage = this.channel.ReadFromRequest(requestInfo); + IDirectedProtocolMessage requestMessage = await this.channel.ReadFromRequestAsync(requestInfo.AsHttpRequestMessage(), CancellationToken.None); Assert.IsNotNull(requestMessage); Assert.IsInstanceOf<TestMessage>(requestMessage); @@ -108,36 +102,35 @@ namespace DotNetOpenAuth.Test.OAuth.ChannelElements { } [Test] - public void ReadFromRequestForm() { - this.ParameterizedReceiveTest(HttpDeliveryMethods.PostRequest); + public async Task ReadFromRequestForm() { + await this.ParameterizedReceiveTestAsync(HttpDeliveryMethods.PostRequest); } [Test] - public void ReadFromRequestQueryString() { - this.ParameterizedReceiveTest(HttpDeliveryMethods.GetRequest); + public async Task ReadFromRequestQueryString() { + await this.ParameterizedReceiveTestAsync(HttpDeliveryMethods.GetRequest); } [Test] - public void SendDirectMessageResponse() { + public async Task SendDirectMessageResponse() { IProtocolMessage message = new TestDirectedMessage { Age = 15, Name = "Andrew", Location = new Uri("http://hostb/pathB"), }; - OutgoingWebResponse response = this.channel.PrepareResponse(message); - Assert.AreSame(message, response.OriginalMessage); - Assert.AreEqual(HttpStatusCode.OK, response.Status); - Assert.AreEqual(2, response.Headers.Count); + var response = await this.channel.PrepareResponseAsync(message); + Assert.AreEqual(HttpStatusCode.OK, response.StatusCode); + Assert.AreEqual(Channel.HttpFormUrlEncodedContentType.MediaType, response.Content.Headers.ContentType.MediaType); - NameValueCollection body = HttpUtility.ParseQueryString(response.Body); + NameValueCollection body = HttpUtility.ParseQueryString(await response.Content.ReadAsStringAsync()); Assert.AreEqual("15", body["age"]); Assert.AreEqual("Andrew", body["Name"]); Assert.AreEqual("http://hostb/pathB", body["Location"]); } [Test] - public void ReadFromResponse() { + public async Task ReadFromResponse() { var fields = new Dictionary<string, string> { { "age", "15" }, { "Name", "Andrew" }, @@ -150,7 +143,9 @@ namespace DotNetOpenAuth.Test.OAuth.ChannelElements { writer.Write(MessagingUtilities.CreateQueryString(fields)); writer.Flush(); ms.Seek(0, SeekOrigin.Begin); - IDictionary<string, string> deserializedFields = this.channel.ReadFromResponseCoreTestHook(new CachedDirectWebResponse { CachedResponseStream = ms }); + IDictionary<string, string> deserializedFields = await this.channel.ReadFromResponseCoreAsyncTestHook( + new HttpResponseMessage { Content = new StreamContent(ms) }, + CancellationToken.None); Assert.AreEqual(fields.Count, deserializedFields.Count); foreach (string key in fields.Keys) { Assert.AreEqual(fields[key], deserializedFields[key]); @@ -158,34 +153,34 @@ namespace DotNetOpenAuth.Test.OAuth.ChannelElements { } [Test, ExpectedException(typeof(ArgumentNullException))] - public void RequestNull() { - this.channel.Request(null); + public async Task RequestNull() { + await this.channel.RequestAsync(null, CancellationToken.None); } [Test, ExpectedException(typeof(ArgumentException))] - public void RequestNullRecipient() { + public async Task RequestNullRecipient() { IDirectedProtocolMessage message = new TestDirectedMessage(MessageTransport.Direct); - this.channel.Request(message); + await this.channel.RequestAsync(message, CancellationToken.None); } [Test, ExpectedException(typeof(NotSupportedException))] - public void RequestBadPreferredScheme() { + public async Task RequestBadPreferredScheme() { TestDirectedMessage message = new TestDirectedMessage(MessageTransport.Direct); message.Recipient = new Uri("http://localtest"); message.HttpMethods = HttpDeliveryMethods.None; - this.channel.Request(message); + await this.channel.RequestAsync(message, CancellationToken.None); } [Test] - public void RequestUsingAuthorizationHeader() { - this.ParameterizedRequestTest(HttpDeliveryMethods.AuthorizationHeaderRequest); + public async Task RequestUsingAuthorizationHeader() { + await this.ParameterizedRequestTestAsync(HttpDeliveryMethods.AuthorizationHeaderRequest); } /// <summary> /// Verifies that message parts can be distributed to the query, form, and Authorization header. /// </summary> [Test] - public void RequestUsingAuthorizationHeaderScattered() { + public async Task RequestUsingAuthorizationHeaderScattered() { TestDirectedMessage request = new TestDirectedMessage(MessageTransport.Direct) { Age = 15, Name = "Andrew", @@ -201,9 +196,9 @@ namespace DotNetOpenAuth.Test.OAuth.ChannelElements { request.Recipient = new Uri("http://localhost/?appearinquery=queryish"); request.HttpMethods = HttpDeliveryMethods.AuthorizationHeaderRequest | HttpDeliveryMethods.PostRequest; - HttpWebRequest webRequest = this.channel.InitializeRequest(request); + var webRequest = await this.channel.InitializeRequestAsync(request, CancellationToken.None); Assert.IsNotNull(webRequest); - Assert.AreEqual("POST", webRequest.Method); + Assert.AreEqual(HttpMethod.Post, webRequest.Method); Assert.AreEqual(request.Recipient, webRequest.RequestUri); var declaredParts = new Dictionary<string, string> { @@ -213,23 +208,23 @@ namespace DotNetOpenAuth.Test.OAuth.ChannelElements { { "Timestamp", XmlConvert.ToString(request.Timestamp, XmlDateTimeSerializationMode.Utc) }, }; - Assert.AreEqual(CreateAuthorizationHeader(declaredParts), webRequest.Headers[HttpRequestHeader.Authorization]); - Assert.AreEqual("appearinform=formish", this.webRequestHandler.RequestEntityAsString); + Assert.AreEqual(CreateAuthorizationHeader(declaredParts), webRequest.Headers.Authorization.ToString()); + Assert.AreEqual("appearinform=formish", await webRequest.Content.ReadAsStringAsync()); } [Test] - public void RequestUsingGet() { - this.ParameterizedRequestTest(HttpDeliveryMethods.GetRequest); + public async Task RequestUsingGet() { + await this.ParameterizedRequestTestAsync(HttpDeliveryMethods.GetRequest); } [Test] - public void RequestUsingPost() { - this.ParameterizedRequestTest(HttpDeliveryMethods.PostRequest); + public async Task RequestUsingPost() { + await this.ParameterizedRequestTestAsync(HttpDeliveryMethods.PostRequest); } [Test] - public void RequestUsingHead() { - this.ParameterizedRequestTest(HttpDeliveryMethods.HeadRequest); + public async Task RequestUsingHead() { + await this.ParameterizedRequestTestAsync(HttpDeliveryMethods.HeadRequest); } /// <summary> @@ -238,14 +233,14 @@ namespace DotNetOpenAuth.Test.OAuth.ChannelElements { [Test] public void SendDirectMessageResponseHonorsHttpStatusCodes() { IProtocolMessage message = MessagingTestBase.GetStandardTestMessage(MessagingTestBase.FieldFill.AllRequired); - OutgoingWebResponse directResponse = this.channel.PrepareDirectResponseTestHook(message); - Assert.AreEqual(HttpStatusCode.OK, directResponse.Status); + var directResponse = this.channel.PrepareDirectResponseTestHook(message); + Assert.AreEqual(HttpStatusCode.OK, directResponse.StatusCode); var httpMessage = new TestDirectResponseMessageWithHttpStatus(); MessagingTestBase.GetStandardTestMessage(MessagingTestBase.FieldFill.AllRequired, httpMessage); httpMessage.HttpStatusCode = HttpStatusCode.NotAcceptable; directResponse = this.channel.PrepareDirectResponseTestHook(httpMessage); - Assert.AreEqual(HttpStatusCode.NotAcceptable, directResponse.Status); + Assert.AreEqual(HttpStatusCode.NotAcceptable, directResponse.StatusCode); } private static string CreateAuthorizationHeader(IDictionary<string, string> fields) { @@ -296,8 +291,8 @@ namespace DotNetOpenAuth.Test.OAuth.ChannelElements { return new HttpRequestInfo(request.Method, request.RequestUri, request.Headers, postEntity); } - private void ParameterizedRequestTest(HttpDeliveryMethods scheme) { - TestDirectedMessage request = new TestDirectedMessage(MessageTransport.Direct) { + private async Task ParameterizedRequestTestAsync(HttpDeliveryMethods scheme) { + var request = new TestDirectedMessage(MessageTransport.Direct) { Age = 15, Name = "Andrew", Location = new Uri("http://hostb/pathB"), @@ -306,30 +301,29 @@ namespace DotNetOpenAuth.Test.OAuth.ChannelElements { HttpMethods = scheme, }; - CachedDirectWebResponse rawResponse = null; - this.webRequestHandler.Callback = (req) => { - Assert.IsNotNull(req); - HttpRequestInfo reqInfo = ConvertToRequestInfo(req, this.webRequestHandler.RequestEntityStream); - Assert.AreEqual(MessagingUtilities.GetHttpVerb(scheme), reqInfo.HttpMethod); - var incomingMessage = this.channel.ReadFromRequest(reqInfo) as TestMessage; - Assert.IsNotNull(incomingMessage); - Assert.AreEqual(request.Age, incomingMessage.Age); - Assert.AreEqual(request.Name, incomingMessage.Name); - Assert.AreEqual(request.Location, incomingMessage.Location); - Assert.AreEqual(request.Timestamp, incomingMessage.Timestamp); - - var responseFields = new Dictionary<string, string> { - { "age", request.Age.ToString() }, - { "Name", request.Name }, - { "Location", request.Location.AbsoluteUri }, - { "Timestamp", XmlConvert.ToString(request.Timestamp, XmlDateTimeSerializationMode.Utc) }, - }; - rawResponse = new CachedDirectWebResponse(); - rawResponse.SetResponse(MessagingUtilities.CreateQueryString(responseFields)); - return rawResponse; - }; - - IProtocolMessage response = this.channel.Request(request); + Handle(request.Recipient).By( + async (req, ct) => { + Assert.IsNotNull(req); + Assert.AreEqual(MessagingUtilities.GetHttpVerb(scheme), req.Method); + var incomingMessage = (await this.channel.ReadFromRequestAsync(req, CancellationToken.None)) as TestMessage; + Assert.IsNotNull(incomingMessage); + Assert.AreEqual(request.Age, incomingMessage.Age); + Assert.AreEqual(request.Name, incomingMessage.Name); + Assert.AreEqual(request.Location, incomingMessage.Location); + Assert.AreEqual(request.Timestamp, incomingMessage.Timestamp); + + var responseFields = new Dictionary<string, string> { + { "age", request.Age.ToString() }, + { "Name", request.Name }, + { "Location", request.Location.AbsoluteUri }, + { "Timestamp", XmlConvert.ToString(request.Timestamp, XmlDateTimeSerializationMode.Utc) }, + }; + var rawResponse = new HttpResponseMessage(); + rawResponse.Content = new StringContent(MessagingUtilities.CreateQueryString(responseFields)); + return rawResponse; + }); + + IProtocolMessage response = await this.channel.RequestAsync(request, CancellationToken.None); Assert.IsNotNull(response); Assert.IsInstanceOf<TestMessage>(response); TestMessage responseMessage = (TestMessage)response; @@ -338,7 +332,7 @@ namespace DotNetOpenAuth.Test.OAuth.ChannelElements { Assert.AreEqual(request.Location, responseMessage.Location); } - private void ParameterizedReceiveTest(HttpDeliveryMethods scheme) { + private async Task ParameterizedReceiveTestAsync(HttpDeliveryMethods scheme) { var fields = new Dictionary<string, string> { { "age", "15" }, { "Name", "Andrew" }, @@ -346,7 +340,7 @@ namespace DotNetOpenAuth.Test.OAuth.ChannelElements { { "Timestamp", XmlConvert.ToString(DateTime.UtcNow, XmlDateTimeSerializationMode.Utc) }, { "realm", "someValue" }, }; - IProtocolMessage requestMessage = this.channel.ReadFromRequest(CreateHttpRequestInfo(scheme, fields)); + IProtocolMessage requestMessage = await this.channel.ReadFromRequestAsync(CreateHttpRequestInfo(scheme, fields).AsHttpRequestMessage(), CancellationToken.None); Assert.IsNotNull(requestMessage); Assert.IsInstanceOf<TestMessage>(requestMessage); TestMessage testMessage = (TestMessage)requestMessage; diff --git a/src/DotNetOpenAuth.Test/OAuth/ChannelElements/PlaintextSigningBindingElementTest.cs b/src/DotNetOpenAuth.Test/OAuth/ChannelElements/PlaintextSigningBindingElementTest.cs index b3869e7..b8d4f2b 100644 --- a/src/DotNetOpenAuth.Test/OAuth/ChannelElements/PlaintextSigningBindingElementTest.cs +++ b/src/DotNetOpenAuth.Test/OAuth/ChannelElements/PlaintextSigningBindingElementTest.cs @@ -5,6 +5,9 @@ //----------------------------------------------------------------------- namespace DotNetOpenAuth.Test.OAuth.ChannelElements { + using System.Threading; + using System.Threading.Tasks; + using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OAuth; using DotNetOpenAuth.OAuth.ChannelElements; @@ -15,20 +18,20 @@ namespace DotNetOpenAuth.Test.OAuth.ChannelElements { [TestFixture] public class PlaintextSigningBindingElementTest { [Test] - public void HttpsSignatureGeneration() { + public async Task HttpsSignatureGeneration() { SigningBindingElementBase target = new PlaintextSigningBindingElement(); target.Channel = new TestChannel(); MessageReceivingEndpoint endpoint = new MessageReceivingEndpoint("https://localtest", HttpDeliveryMethods.GetRequest); ITamperResistantOAuthMessage message = new UnauthorizedTokenRequest(endpoint, Protocol.Default.Version); message.ConsumerSecret = "cs"; message.TokenSecret = "ts"; - Assert.IsNotNull(target.ProcessOutgoingMessage(message)); + Assert.IsNotNull(await target.ProcessOutgoingMessageAsync(message, CancellationToken.None)); Assert.AreEqual("PLAINTEXT", message.SignatureMethod); Assert.AreEqual("cs&ts", message.Signature); } [Test] - public void HttpsSignatureVerification() { + public async Task HttpsSignatureVerification() { MessageReceivingEndpoint endpoint = new MessageReceivingEndpoint("https://localtest", HttpDeliveryMethods.GetRequest); ITamperProtectionChannelBindingElement target = new PlaintextSigningBindingElement(); target.Channel = new TestChannel(); @@ -37,11 +40,11 @@ namespace DotNetOpenAuth.Test.OAuth.ChannelElements { message.TokenSecret = "ts"; message.SignatureMethod = "PLAINTEXT"; message.Signature = "cs&ts"; - Assert.IsNotNull(target.ProcessIncomingMessage(message)); + Assert.IsNotNull(target.ProcessIncomingMessageAsync(message, CancellationToken.None)); } [Test] - public void HttpsSignatureVerificationNotApplicable() { + public async Task HttpsSignatureVerificationNotApplicable() { SigningBindingElementBase target = new PlaintextSigningBindingElement(); target.Channel = new TestChannel(); MessageReceivingEndpoint endpoint = new MessageReceivingEndpoint("https://localtest", HttpDeliveryMethods.GetRequest); @@ -50,11 +53,11 @@ namespace DotNetOpenAuth.Test.OAuth.ChannelElements { message.TokenSecret = "ts"; message.SignatureMethod = "ANOTHERALGORITHM"; message.Signature = "somethingelse"; - Assert.AreEqual(MessageProtections.None, target.ProcessIncomingMessage(message), "PLAINTEXT binding element should opt-out where it doesn't understand."); + Assert.AreEqual(MessageProtections.None, await target.ProcessIncomingMessageAsync(message, CancellationToken.None), "PLAINTEXT binding element should opt-out where it doesn't understand."); } [Test] - public void HttpSignatureGeneration() { + public async Task HttpSignatureGeneration() { SigningBindingElementBase target = new PlaintextSigningBindingElement(); target.Channel = new TestChannel(); MessageReceivingEndpoint endpoint = new MessageReceivingEndpoint("http://localtest", HttpDeliveryMethods.GetRequest); @@ -63,13 +66,13 @@ namespace DotNetOpenAuth.Test.OAuth.ChannelElements { message.TokenSecret = "ts"; // Since this is (non-encrypted) HTTP, so the plain text signer should not be used - Assert.IsNull(target.ProcessOutgoingMessage(message)); + Assert.IsNull(await target.ProcessOutgoingMessageAsync(message, CancellationToken.None)); Assert.IsNull(message.SignatureMethod); Assert.IsNull(message.Signature); } [Test] - public void HttpSignatureVerification() { + public async Task HttpSignatureVerification() { SigningBindingElementBase target = new PlaintextSigningBindingElement(); target.Channel = new TestChannel(); MessageReceivingEndpoint endpoint = new MessageReceivingEndpoint("http://localtest", HttpDeliveryMethods.GetRequest); @@ -78,7 +81,7 @@ namespace DotNetOpenAuth.Test.OAuth.ChannelElements { message.TokenSecret = "ts"; message.SignatureMethod = "PLAINTEXT"; message.Signature = "cs%26ts"; - Assert.IsNull(target.ProcessIncomingMessage(message), "PLAINTEXT signature binding element should refuse to participate in non-encrypted messages."); + Assert.IsNull(await target.ProcessIncomingMessageAsync(message, CancellationToken.None), "PLAINTEXT signature binding element should refuse to participate in non-encrypted messages."); } } } diff --git a/src/DotNetOpenAuth.Test/OAuth/ChannelElements/SigningBindingElementBaseTests.cs b/src/DotNetOpenAuth.Test/OAuth/ChannelElements/SigningBindingElementBaseTests.cs index 490399c..e356c64 100644 --- a/src/DotNetOpenAuth.Test/OAuth/ChannelElements/SigningBindingElementBaseTests.cs +++ b/src/DotNetOpenAuth.Test/OAuth/ChannelElements/SigningBindingElementBaseTests.cs @@ -6,6 +6,8 @@ namespace DotNetOpenAuth.Test.OAuth.ChannelElements { using System.Collections.Generic; + using System.Net.Http; + using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Reflection; using DotNetOpenAuth.OAuth; @@ -80,7 +82,7 @@ namespace DotNetOpenAuth.Test.OAuth.ChannelElements { message.AccessToken = "tokenpublic"; var signedMessage = (ITamperResistantOAuthMessage)message; - signedMessage.HttpMethod = "GET"; + signedMessage.HttpMethod = HttpMethod.Get; signedMessage.SignatureMethod = "HMAC-SHA1"; MessageDictionary dictionary = this.MessageDescriptions.GetAccessor(message); @@ -115,7 +117,7 @@ namespace DotNetOpenAuth.Test.OAuth.ChannelElements { message.ConsumerKey = "nerdbank.org"; ((ITamperResistantOAuthMessage)message).ConsumerSecret = "nerdbanksecret"; var signedMessage = (ITamperResistantOAuthMessage)message; - signedMessage.HttpMethod = "GET"; + signedMessage.HttpMethod = HttpMethod.Get; signedMessage.SignatureMethod = "HMAC-SHA1"; MessageDictionary dictionary = messageDescriptions.GetAccessor(message); dictionary["oauth_timestamp"] = "1222665749"; diff --git a/src/DotNetOpenAuth.Test/OAuth/ConsumerDescription.cs b/src/DotNetOpenAuth.Test/OAuth/ConsumerDescription.cs index 74752f8..5c25d30 100644 --- a/src/DotNetOpenAuth.Test/OAuth/ConsumerDescription.cs +++ b/src/DotNetOpenAuth.Test/OAuth/ConsumerDescription.cs @@ -5,6 +5,8 @@ //----------------------------------------------------------------------- namespace DotNetOpenAuth.Test.OAuth { + using DotNetOpenAuth.OAuth; + /// <summary> /// Information necessary to initialize a <see cref="Consumer"/>, /// and to tell a <see cref="ServiceProvider"/> about it. diff --git a/src/DotNetOpenAuth.Test/OAuth/OAuthCoordinator.cs b/src/DotNetOpenAuth.Test/OAuth/OAuthCoordinator.cs deleted file mode 100644 index 21c1775..0000000 --- a/src/DotNetOpenAuth.Test/OAuth/OAuthCoordinator.cs +++ /dev/null @@ -1,71 +0,0 @@ -//----------------------------------------------------------------------- -// <copyright file="OAuthCoordinator.cs" company="Outercurve Foundation"> -// Copyright (c) Outercurve Foundation. All rights reserved. -// </copyright> -//----------------------------------------------------------------------- - -namespace DotNetOpenAuth.Test.OAuth { - using System; - using DotNetOpenAuth.Messaging; - using DotNetOpenAuth.Messaging.Bindings; - using DotNetOpenAuth.OAuth; - using DotNetOpenAuth.OAuth.ChannelElements; - using DotNetOpenAuth.Test.Mocks; - using Validation; - - /// <summary> - /// Runs a Consumer and Service Provider simultaneously so they can interact in a full simulation. - /// </summary> - internal class OAuthCoordinator : CoordinatorBase<WebConsumer, ServiceProvider> { - private ConsumerDescription consumerDescription; - private ServiceProviderDescription serviceDescription; - private DotNetOpenAuth.OAuth.ServiceProviderSecuritySettings serviceProviderSecuritySettings = DotNetOpenAuth.Configuration.OAuthElement.Configuration.ServiceProvider.SecuritySettings.CreateSecuritySettings(); - private DotNetOpenAuth.OAuth.ConsumerSecuritySettings consumerSecuritySettings = DotNetOpenAuth.Configuration.OAuthElement.Configuration.Consumer.SecuritySettings.CreateSecuritySettings(); - - /// <summary>Initializes a new instance of the <see cref="OAuthCoordinator"/> class.</summary> - /// <param name="consumerDescription">The description of the consumer.</param> - /// <param name="serviceDescription">The service description that will be used to construct the Consumer and ServiceProvider objects.</param> - /// <param name="consumerAction">The code path of the Consumer.</param> - /// <param name="serviceProviderAction">The code path of the Service Provider.</param> - internal OAuthCoordinator(ConsumerDescription consumerDescription, ServiceProviderDescription serviceDescription, Action<WebConsumer> consumerAction, Action<ServiceProvider> serviceProviderAction) - : base(consumerAction, serviceProviderAction) { - Requires.NotNull(consumerDescription, "consumerDescription"); - Requires.NotNull(serviceDescription, "serviceDescription"); - - this.consumerDescription = consumerDescription; - this.serviceDescription = serviceDescription; - } - - /// <summary> - /// Starts the simulation. - /// </summary> - internal override void Run() { - // Clone the template signing binding element. - var signingElement = this.serviceDescription.CreateTamperProtectionElement(); - var consumerSigningElement = signingElement.Clone(); - var spSigningElement = signingElement.Clone(); - - // Prepare token managers - InMemoryTokenManager consumerTokenManager = new InMemoryTokenManager(); - InMemoryTokenManager serviceTokenManager = new InMemoryTokenManager(); - consumerTokenManager.AddConsumer(this.consumerDescription); - serviceTokenManager.AddConsumer(this.consumerDescription); - - // Prepare channels that will pass messages directly back and forth. - var consumerChannel = new CoordinatingOAuthConsumerChannel(consumerSigningElement, (IConsumerTokenManager)consumerTokenManager, this.consumerSecuritySettings); - var serviceProviderChannel = new CoordinatingOAuthServiceProviderChannel(spSigningElement, (IServiceProviderTokenManager)serviceTokenManager, this.serviceProviderSecuritySettings); - consumerChannel.RemoteChannel = serviceProviderChannel; - serviceProviderChannel.RemoteChannel = consumerChannel; - - // Prepare the Consumer and Service Provider objects - WebConsumer consumer = new WebConsumer(this.serviceDescription, consumerTokenManager) { - OAuthChannel = consumerChannel, - }; - ServiceProvider serviceProvider = new ServiceProvider(this.serviceDescription, serviceTokenManager, new NonceMemoryStore()) { - OAuthChannel = serviceProviderChannel, - }; - - this.RunCore(consumer, serviceProvider); - } - } -} diff --git a/src/DotNetOpenAuth.Test/OAuth/ServiceProviderDescriptionTests.cs b/src/DotNetOpenAuth.Test/OAuth/ServiceProviderDescriptionTests.cs index cdc8de5..3da3112 100644 --- a/src/DotNetOpenAuth.Test/OAuth/ServiceProviderDescriptionTests.cs +++ b/src/DotNetOpenAuth.Test/OAuth/ServiceProviderDescriptionTests.cs @@ -11,7 +11,7 @@ namespace DotNetOpenAuth.Test.OAuth { using NUnit.Framework; /// <summary> - /// Tests for the <see cref="ServiceProviderEndpoints"/> class. + /// Tests for the <see cref="ServiceProviderHostDescription"/> class. /// </summary> [TestFixture] public class ServiceProviderDescriptionTests : TestBase { @@ -20,8 +20,8 @@ namespace DotNetOpenAuth.Test.OAuth { /// </summary> [Test] public void UserAuthorizationUriTest() { - ServiceProviderDescription target = new ServiceProviderDescription(); - MessageReceivingEndpoint expected = new MessageReceivingEndpoint("http://localhost/authorization", HttpDeliveryMethods.GetRequest); + var target = new ServiceProviderHostDescription(); + var expected = new MessageReceivingEndpoint("http://localhost/authorization", HttpDeliveryMethods.GetRequest); MessageReceivingEndpoint actual; target.UserAuthorizationEndpoint = expected; actual = target.UserAuthorizationEndpoint; @@ -36,8 +36,8 @@ namespace DotNetOpenAuth.Test.OAuth { /// </summary> [Test] public void RequestTokenUriTest() { - var target = new ServiceProviderDescription(); - MessageReceivingEndpoint expected = new MessageReceivingEndpoint("http://localhost/requesttoken", HttpDeliveryMethods.GetRequest); + var target = new ServiceProviderHostDescription(); + var expected = new MessageReceivingEndpoint("http://localhost/requesttoken", HttpDeliveryMethods.GetRequest); MessageReceivingEndpoint actual; target.RequestTokenEndpoint = expected; actual = target.RequestTokenEndpoint; @@ -53,7 +53,7 @@ namespace DotNetOpenAuth.Test.OAuth { /// </summary> [Test, ExpectedException(typeof(ArgumentException))] public void RequestTokenUriWithOAuthParametersTest() { - var target = new ServiceProviderDescription(); + var target = new ServiceProviderHostDescription(); target.RequestTokenEndpoint = new MessageReceivingEndpoint("http://localhost/requesttoken?oauth_token=something", HttpDeliveryMethods.GetRequest); } @@ -62,7 +62,7 @@ namespace DotNetOpenAuth.Test.OAuth { /// </summary> [Test] public void AccessTokenUriTest() { - var target = new ServiceProviderDescription(); + var target = new ServiceProviderHostDescription(); MessageReceivingEndpoint expected = new MessageReceivingEndpoint("http://localhost/accesstoken", HttpDeliveryMethods.GetRequest); MessageReceivingEndpoint actual; target.AccessTokenEndpoint = expected; diff --git a/src/DotNetOpenAuth.Test/OAuth2/AuthorizationServerTests.cs b/src/DotNetOpenAuth.Test/OAuth2/AuthorizationServerTests.cs index e8f7172..8ec25b0 100644 --- a/src/DotNetOpenAuth.Test/OAuth2/AuthorizationServerTests.cs +++ b/src/DotNetOpenAuth.Test/OAuth2/AuthorizationServerTests.cs @@ -8,7 +8,10 @@ namespace DotNetOpenAuth.Test.OAuth2 { using System; using System.Collections.Generic; using System.Linq; + using System.Net; + using System.Net.Http; using System.Text; + using System.Threading; using System.Threading.Tasks; using DotNetOpenAuth.OAuth2; using DotNetOpenAuth.OAuth2.ChannelElements; @@ -25,59 +28,66 @@ namespace DotNetOpenAuth.Test.OAuth2 { /// Verifies that authorization server responds with an appropriate error response. /// </summary> [Test] - public void ErrorResponseTest() { - var coordinator = new OAuth2Coordinator<UserAgentClient>( - AuthorizationServerDescription, - AuthorizationServerMock, - new UserAgentClient(AuthorizationServerDescription), - client => { - var request = new AccessTokenAuthorizationCodeRequestC(AuthorizationServerDescription) { ClientIdentifier = ClientId, ClientSecret = ClientSecret, AuthorizationCode = "foo" }; - - var response = client.Channel.Request<AccessTokenFailedResponse>(request); - Assert.That(response.Error, Is.Not.Null.And.Not.Empty); - Assert.That(response.Error, Is.EqualTo(Protocol.AccessTokenRequestErrorCodes.InvalidRequest)); - }, - server => { - server.HandleTokenRequest().Respond(); + public async Task ErrorResponseTest() { + Handle(AuthorizationServerDescription.TokenEndpoint).By( + async (req, ct) => { + var server = new AuthorizationServer(AuthorizationServerMock); + return await server.HandleTokenRequestAsync(req, ct); }); - coordinator.Run(); + var request = new AccessTokenAuthorizationCodeRequestC(AuthorizationServerDescription) { ClientIdentifier = ClientId, ClientSecret = ClientSecret, AuthorizationCode = "foo" }; + var client = new UserAgentClient(AuthorizationServerDescription, hostFactories: this.HostFactories); + var response = await client.Channel.RequestAsync<AccessTokenFailedResponse>(request, CancellationToken.None); + Assert.That(response.Error, Is.Not.Null.And.Not.Empty); + Assert.That(response.Error, Is.EqualTo(Protocol.AccessTokenRequestErrorCodes.InvalidRequest)); } [Test] - public void DecodeRefreshToken() { + public async Task DecodeRefreshToken() { var refreshTokenSource = new TaskCompletionSource<string>(); - var coordinator = new OAuth2Coordinator<WebServerClient>( - AuthorizationServerDescription, - AuthorizationServerMock, - new WebServerClient(AuthorizationServerDescription), - client => { - try { - var authState = new AuthorizationState(TestScopes) { - Callback = ClientCallback, - }; - client.PrepareRequestUserAuthorization(authState).Respond(); - var result = client.ProcessUserAuthorization(); - Assert.That(result.AccessToken, Is.Not.Null.And.Not.Empty); - Assert.That(result.RefreshToken, Is.Not.Null.And.Not.Empty); - refreshTokenSource.SetResult(result.RefreshToken); - } catch { - refreshTokenSource.TrySetCanceled(); - } - }, - server => { - var request = server.ReadAuthorizationRequest(); + Handle(AuthorizationServerDescription.AuthorizationEndpoint).By( + async (req, ct) => { + var server = new AuthorizationServer(AuthorizationServerMock); + var request = await server.ReadAuthorizationRequestAsync(req, ct); Assert.That(request, Is.Not.Null); - server.ApproveAuthorizationRequest(request, ResourceOwnerUsername); - server.HandleTokenRequest().Respond(); + var response = server.PrepareApproveAuthorizationRequest(request, ResourceOwnerUsername); + return await server.Channel.PrepareResponseAsync(response); + }); + Handle(AuthorizationServerDescription.TokenEndpoint).By( + async (req, ct) => { + var server = new AuthorizationServer(AuthorizationServerMock); + var response = await server.HandleTokenRequestAsync(req, ct); var authorization = server.DecodeRefreshToken(refreshTokenSource.Task.Result); Assert.That(authorization, Is.Not.Null); Assert.That(authorization.User, Is.EqualTo(ResourceOwnerUsername)); + return response; }); - coordinator.Run(); + + var client = new WebServerClient(AuthorizationServerDescription); + try { + var authState = new AuthorizationState(TestScopes) { Callback = ClientCallback, }; + var authRedirectResponse = await client.PrepareRequestUserAuthorizationAsync(authState); + this.HostFactories.CookieContainer.SetCookies(authRedirectResponse, ClientCallback); + Uri authCompleteUri; + using (var httpClient = this.HostFactories.CreateHttpClient()) { + using (var response = await httpClient.GetAsync(authRedirectResponse.Headers.Location)) { + response.EnsureSuccessStatusCode(); + authCompleteUri = response.Headers.Location; + } + } + + var authCompleteRequest = new HttpRequestMessage(HttpMethod.Get, authCompleteUri); + this.HostFactories.CookieContainer.ApplyCookies(authCompleteRequest); + var result = await client.ProcessUserAuthorizationAsync(authCompleteRequest); + Assert.That(result.AccessToken, Is.Not.Null.And.Not.Empty); + Assert.That(result.RefreshToken, Is.Not.Null.And.Not.Empty); + refreshTokenSource.SetResult(result.RefreshToken); + } catch { + refreshTokenSource.TrySetCanceled(); + } } [Test] - public void ResourceOwnerScopeOverride() { + public async Task ResourceOwnerScopeOverride() { var clientRequestedScopes = new[] { "scope1", "scope2" }; var serverOverriddenScopes = new[] { "scope1", "differentScope" }; var authServerMock = CreateAuthorizationServerMock(); @@ -89,25 +99,20 @@ namespace DotNetOpenAuth.Test.OAuth2 { response.ApprovedScope.UnionWith(serverOverriddenScopes); return response; }); - var coordinator = new OAuth2Coordinator<WebServerClient>( - AuthorizationServerDescription, - authServerMock.Object, - new WebServerClient(AuthorizationServerDescription), - client => { - var authState = new AuthorizationState(TestScopes) { - Callback = ClientCallback, - }; - var result = client.ExchangeUserCredentialForToken(ResourceOwnerUsername, ResourceOwnerPassword, clientRequestedScopes); - Assert.That(result.Scope, Is.EquivalentTo(serverOverriddenScopes)); - }, - server => { - server.HandleTokenRequest().Respond(); + + Handle(AuthorizationServerDescription.TokenEndpoint).By( + async (req, ct) => { + var server = new AuthorizationServer(authServerMock.Object); + return await server.HandleTokenRequestAsync(req, ct); }); - coordinator.Run(); + + var client = new WebServerClient(AuthorizationServerDescription, hostFactories: this.HostFactories); + var result = await client.ExchangeUserCredentialForTokenAsync(ResourceOwnerUsername, ResourceOwnerPassword, clientRequestedScopes); + Assert.That(result.Scope, Is.EquivalentTo(serverOverriddenScopes)); } [Test] - public void CreateAccessTokenSeesAuthorizingUserResourceOwnerGrant() { + public async Task CreateAccessTokenSeesAuthorizingUserResourceOwnerGrant() { var authServerMock = CreateAuthorizationServerMock(); authServerMock .Setup(a => a.CheckAuthorizeResourceOwnerCredentialGrant(ResourceOwnerUsername, ResourceOwnerPassword, It.IsAny<IAccessTokenRequest>())) @@ -116,25 +121,20 @@ namespace DotNetOpenAuth.Test.OAuth2 { Assert.That(req.UserName, Is.EqualTo(ResourceOwnerUsername)); return response; }); - var coordinator = new OAuth2Coordinator<WebServerClient>( - AuthorizationServerDescription, - authServerMock.Object, - new WebServerClient(AuthorizationServerDescription), - client => { - var authState = new AuthorizationState(TestScopes) { - Callback = ClientCallback, - }; - var result = client.ExchangeUserCredentialForToken(ResourceOwnerUsername, ResourceOwnerPassword, TestScopes); - Assert.That(result.AccessToken, Is.Not.Null); - }, - server => { - server.HandleTokenRequest().Respond(); + + Handle(AuthorizationServerDescription.TokenEndpoint).By( + async (req, ct) => { + var server = new AuthorizationServer(authServerMock.Object); + return await server.HandleTokenRequestAsync(req, ct); }); - coordinator.Run(); + + var client = new WebServerClient(AuthorizationServerDescription, hostFactories: this.HostFactories); + var result = await client.ExchangeUserCredentialForTokenAsync(ResourceOwnerUsername, ResourceOwnerPassword, TestScopes); + Assert.That(result.AccessToken, Is.Not.Null); } [Test] - public void CreateAccessTokenSeesAuthorizingUserClientCredentialGrant() { + public async Task CreateAccessTokenSeesAuthorizingUserClientCredentialGrant() { var authServerMock = CreateAuthorizationServerMock(); authServerMock .Setup(a => a.CheckAuthorizeClientCredentialsGrant(It.IsAny<IAccessTokenRequest>())) @@ -142,25 +142,20 @@ namespace DotNetOpenAuth.Test.OAuth2 { Assert.That(req.UserName, Is.Null); return new AutomatedAuthorizationCheckResponse(req, true); }); - var coordinator = new OAuth2Coordinator<WebServerClient>( - AuthorizationServerDescription, - authServerMock.Object, - new WebServerClient(AuthorizationServerDescription), - client => { - var authState = new AuthorizationState(TestScopes) { - Callback = ClientCallback, - }; - var result = client.GetClientAccessToken(TestScopes); - Assert.That(result.AccessToken, Is.Not.Null); - }, - server => { - server.HandleTokenRequest().Respond(); + + Handle(AuthorizationServerDescription.TokenEndpoint).By( + async (req, ct) => { + var server = new AuthorizationServer(authServerMock.Object); + return await server.HandleTokenRequestAsync(req, ct); }); - coordinator.Run(); + + var client = new WebServerClient(AuthorizationServerDescription, ClientId, ClientSecret, this.HostFactories); + var result = await client.GetClientAccessTokenAsync(TestScopes); + Assert.That(result.AccessToken, Is.Not.Null); } [Test] - public void CreateAccessTokenSeesAuthorizingUserAuthorizationCodeGrant() { + public async Task CreateAccessTokenSeesAuthorizingUserAuthorizationCodeGrant() { var authServerMock = CreateAuthorizationServerMock(); authServerMock .Setup(a => a.IsAuthorizationValid(It.IsAny<IAuthorizationDescription>())) @@ -168,30 +163,45 @@ namespace DotNetOpenAuth.Test.OAuth2 { Assert.That(req.User, Is.EqualTo(ResourceOwnerUsername)); return true; }); - var coordinator = new OAuth2Coordinator<WebServerClient>( - AuthorizationServerDescription, - authServerMock.Object, - new WebServerClient(AuthorizationServerDescription), - client => { - var authState = new AuthorizationState(TestScopes) { - Callback = ClientCallback, - }; - client.PrepareRequestUserAuthorization(authState).Respond(); - var result = client.ProcessUserAuthorization(); - Assert.That(result.AccessToken, Is.Not.Null.And.Not.Empty); - Assert.That(result.RefreshToken, Is.Not.Null.And.Not.Empty); - }, - server => { - var request = server.ReadAuthorizationRequest(); + + Handle(AuthorizationServerDescription.AuthorizationEndpoint).By( + async (req, ct) => { + var server = new AuthorizationServer(authServerMock.Object); + var request = await server.ReadAuthorizationRequestAsync(req, ct); Assert.That(request, Is.Not.Null); - server.ApproveAuthorizationRequest(request, ResourceOwnerUsername); - server.HandleTokenRequest().Respond(); + var response = server.PrepareApproveAuthorizationRequest(request, ResourceOwnerUsername); + return await server.Channel.PrepareResponseAsync(response); + }); + Handle(AuthorizationServerDescription.TokenEndpoint).By( + async (req, ct) => { + var server = new AuthorizationServer(authServerMock.Object); + return await server.HandleTokenRequestAsync(req, ct); }); - coordinator.Run(); + + var client = new WebServerClient(AuthorizationServerDescription, ClientId, ClientSecret, this.HostFactories); + var authState = new AuthorizationState(TestScopes) { + Callback = ClientCallback, + }; + var authRedirectResponse = await client.PrepareRequestUserAuthorizationAsync(authState); + this.HostFactories.CookieContainer.SetCookies(authRedirectResponse, ClientCallback); + Uri authCompleteUri; + this.HostFactories.AllowAutoRedirects = false; + using (var httpClient = this.HostFactories.CreateHttpClient()) { + using (var response = await httpClient.GetAsync(authRedirectResponse.Headers.Location)) { + Assert.That(response.StatusCode, Is.EqualTo(HttpStatusCode.Redirect)); + authCompleteUri = response.Headers.Location; + } + } + + var authCompleteRequest = new HttpRequestMessage(HttpMethod.Get, authCompleteUri); + this.HostFactories.CookieContainer.ApplyCookies(authCompleteRequest); + var result = await client.ProcessUserAuthorizationAsync(authCompleteRequest); + Assert.That(result.AccessToken, Is.Not.Null.And.Not.Empty); + Assert.That(result.RefreshToken, Is.Not.Null.And.Not.Empty); } [Test] - public void ClientCredentialScopeOverride() { + public async Task ClientCredentialScopeOverride() { var clientRequestedScopes = new[] { "scope1", "scope2" }; var serverOverriddenScopes = new[] { "scope1", "differentScope" }; var authServerMock = CreateAuthorizationServerMock(); @@ -203,21 +213,17 @@ namespace DotNetOpenAuth.Test.OAuth2 { response.ApprovedScope.UnionWith(serverOverriddenScopes); return response; }); - var coordinator = new OAuth2Coordinator<WebServerClient>( - AuthorizationServerDescription, - authServerMock.Object, - new WebServerClient(AuthorizationServerDescription), - client => { - var authState = new AuthorizationState(TestScopes) { - Callback = ClientCallback, - }; - var result = client.GetClientAccessToken(clientRequestedScopes); - Assert.That(result.Scope, Is.EquivalentTo(serverOverriddenScopes)); - }, - server => { - server.HandleTokenRequest().Respond(); + + Handle(AuthorizationServerDescription.TokenEndpoint).By( + async (req, ct) => { + var server = new AuthorizationServer(authServerMock.Object); + return await server.HandleTokenRequestAsync(req, ct); }); - coordinator.Run(); + + var client = new WebServerClient(AuthorizationServerDescription, ClientId, ClientSecret, this.HostFactories); + var result = await client.GetClientAccessTokenAsync(clientRequestedScopes); + Assert.That(result.AccessToken, Is.Not.Null.And.Not.Empty); + Assert.That(result.Scope, Is.EquivalentTo(serverOverriddenScopes)); } } } diff --git a/src/DotNetOpenAuth.Test/OAuth2/MessageFactoryTests.cs b/src/DotNetOpenAuth.Test/OAuth2/MessageFactoryTests.cs index 52b5371..810d830 100644 --- a/src/DotNetOpenAuth.Test/OAuth2/MessageFactoryTests.cs +++ b/src/DotNetOpenAuth.Test/OAuth2/MessageFactoryTests.cs @@ -31,7 +31,7 @@ namespace DotNetOpenAuth.Test.OAuth2 { var authServerChannel = new OAuth2AuthorizationServerChannel(new Mock<IAuthorizationServerHost>().Object, new Mock<ClientAuthenticationModule>().Object); this.authServerMessageFactory = authServerChannel.MessageFactoryTestHook; - var clientChannel = new OAuth2ClientChannel(); + var clientChannel = new OAuth2ClientChannel(null); this.clientMessageFactory = clientChannel.MessageFactoryTestHook; } diff --git a/src/DotNetOpenAuth.Test/OAuth2/OAuth2Coordinator.cs b/src/DotNetOpenAuth.Test/OAuth2/OAuth2Coordinator.cs deleted file mode 100644 index eeda125..0000000 --- a/src/DotNetOpenAuth.Test/OAuth2/OAuth2Coordinator.cs +++ /dev/null @@ -1,74 +0,0 @@ -//----------------------------------------------------------------------- -// <copyright file="OAuth2Coordinator.cs" company="Outercurve Foundation"> -// Copyright (c) Outercurve Foundation. All rights reserved. -// </copyright> -//----------------------------------------------------------------------- - -namespace DotNetOpenAuth.Test.OAuth2 { - using System; - using System.Collections.Generic; - using System.Linq; - using System.Net; - using System.Text; - using DotNetOpenAuth.OAuth2; - using DotNetOpenAuth.Test.Mocks; - using Validation; - - internal class OAuth2Coordinator<TClient> : CoordinatorBase<TClient, AuthorizationServer> - where TClient : ClientBase { - private readonly AuthorizationServerDescription serverDescription; - private readonly IAuthorizationServerHost authServerHost; - private readonly TClient client; - - internal OAuth2Coordinator( - AuthorizationServerDescription serverDescription, - IAuthorizationServerHost authServerHost, - TClient client, - Action<TClient> clientAction, - Action<AuthorizationServer> authServerAction) - : base(clientAction, authServerAction) { - Requires.NotNull(serverDescription, "serverDescription"); - Requires.NotNull(authServerHost, "authServerHost"); - Requires.NotNull(client, "client"); - - this.serverDescription = serverDescription; - this.authServerHost = authServerHost; - this.client = client; - - this.client.ClientIdentifier = OAuth2TestBase.ClientId; - this.client.ClientCredentialApplicator = ClientCredentialApplicator.PostParameter(OAuth2TestBase.ClientSecret); - } - - internal override void Run() { - var authServer = new AuthorizationServer(this.authServerHost); - - var rpCoordinatingChannel = new CoordinatingOAuth2ClientChannel(this.client.Channel, this.IncomingMessageFilter, this.OutgoingMessageFilter); - var opCoordinatingChannel = new CoordinatingOAuth2AuthServerChannel(authServer.Channel, this.IncomingMessageFilter, this.OutgoingMessageFilter); - rpCoordinatingChannel.RemoteChannel = opCoordinatingChannel; - opCoordinatingChannel.RemoteChannel = rpCoordinatingChannel; - - this.client.Channel = rpCoordinatingChannel; - authServer.Channel = opCoordinatingChannel; - - this.RunCore(this.client, authServer); - } - - private static Action<WebServerClient> WrapAction(Action<WebServerClient> action) { - Requires.NotNull(action, "action"); - - return client => { - action(client); - ((CoordinatingChannel)client.Channel).Close(); - }; - } - - private static Action<AuthorizationServer> WrapAction(Action<AuthorizationServer> action) { - Requires.NotNull(action, "action"); - - return authServer => { - action(authServer); - ((CoordinatingChannel)authServer.Channel).Close(); - }; - } - } -} diff --git a/src/DotNetOpenAuth.Test/OAuth2/OAuth2TestBase.cs b/src/DotNetOpenAuth.Test/OAuth2/OAuth2TestBase.cs index 395b18c..f01b5b7 100644 --- a/src/DotNetOpenAuth.Test/OAuth2/OAuth2TestBase.cs +++ b/src/DotNetOpenAuth.Test/OAuth2/OAuth2TestBase.cs @@ -37,10 +37,7 @@ namespace DotNetOpenAuth.Test.OAuth2 { TokenEndpoint = new Uri("https://authserver/token"), }; - protected static readonly IClientDescription ClientDescription = new ClientDescription( - ClientSecret, - ClientCallback, - ClientType.Confidential); + protected static readonly IClientDescription ClientDescription = new ClientDescription(ClientSecret, ClientCallback); protected static readonly IAuthorizationServerHost AuthorizationServerMock = CreateAuthorizationServerMock().Object; diff --git a/src/DotNetOpenAuth.Test/OAuth2/ResourceServerTests.cs b/src/DotNetOpenAuth.Test/OAuth2/ResourceServerTests.cs index 80a8392..c232450 100644 --- a/src/DotNetOpenAuth.Test/OAuth2/ResourceServerTests.cs +++ b/src/DotNetOpenAuth.Test/OAuth2/ResourceServerTests.cs @@ -11,6 +11,8 @@ namespace DotNetOpenAuth.Test.OAuth2 { using System.Linq; using System.Security.Cryptography; using System.Text; + using System.Threading.Tasks; + using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OAuth2; using DotNetOpenAuth.OAuth2.ChannelElements; @@ -28,7 +30,7 @@ namespace DotNetOpenAuth.Test.OAuth2 { { "Authorization", "Bearer " }, }; var request = new HttpRequestInfo("GET", new Uri("http://localhost/resource"), headers: requestHeaders); - Assert.That(() => resourceServer.GetAccessToken(request), Throws.InstanceOf<ProtocolException>()); + Assert.That(() => resourceServer.GetAccessTokenAsync(request).GetAwaiter().GetResult(), Throws.InstanceOf<ProtocolException>()); } [Test] @@ -39,7 +41,7 @@ namespace DotNetOpenAuth.Test.OAuth2 { { "Authorization", "Bearer " }, }; var request = new HttpRequestInfo("GET", new Uri("http://localhost/resource"), headers: requestHeaders); - Assert.That(() => resourceServer.GetPrincipal(request), Throws.InstanceOf<ProtocolException>()); + Assert.That(() => resourceServer.GetPrincipalAsync(request).GetAwaiter().GetResult(), Throws.InstanceOf<ProtocolException>()); } [Test] @@ -50,12 +52,12 @@ namespace DotNetOpenAuth.Test.OAuth2 { { "Authorization", "Bearer foobar" }, }; var request = new HttpRequestInfo("GET", new Uri("http://localhost/resource"), headers: requestHeaders); - Assert.That(() => resourceServer.GetAccessToken(request), Throws.InstanceOf<ProtocolException>()); + Assert.That(() => resourceServer.GetAccessTokenAsync(request).GetAwaiter().GetResult(), Throws.InstanceOf<ProtocolException>()); } [Test] - public void GetAccessTokenWithCorruptedToken() { - var accessToken = this.ObtainValidAccessToken(); + public async Task GetAccessTokenWithCorruptedToken() { + var accessToken = await this.ObtainValidAccessTokenAsync(); var resourceServer = new ResourceServer(new StandardAccessTokenAnalyzer(AsymmetricKey, null)); @@ -63,12 +65,12 @@ namespace DotNetOpenAuth.Test.OAuth2 { { "Authorization", "Bearer " + accessToken.Substring(0, accessToken.Length - 1) + "zzz" }, }; var request = new HttpRequestInfo("GET", new Uri("http://localhost/resource"), headers: requestHeaders); - Assert.That(() => resourceServer.GetAccessToken(request), Throws.InstanceOf<ProtocolException>()); + Assert.That(() => resourceServer.GetAccessTokenAsync(request).GetAwaiter().GetResult(), Throws.InstanceOf<ProtocolException>()); } [Test] - public void GetAccessTokenWithValidToken() { - var accessToken = this.ObtainValidAccessToken(); + public async Task GetAccessTokenWithValidToken() { + var accessToken = await this.ObtainValidAccessTokenAsync(); var resourceServer = new ResourceServer(new StandardAccessTokenAnalyzer(AsymmetricKey, null)); @@ -76,11 +78,11 @@ namespace DotNetOpenAuth.Test.OAuth2 { { "Authorization", "Bearer " + accessToken }, }; var request = new HttpRequestInfo("GET", new Uri("http://localhost/resource"), headers: requestHeaders); - var resourceServerDecodedToken = resourceServer.GetAccessToken(request); + var resourceServerDecodedToken = await resourceServer.GetAccessTokenAsync(request); Assert.That(resourceServerDecodedToken, Is.Not.Null); } - private string ObtainValidAccessToken() { + private async Task<string> ObtainValidAccessTokenAsync() { string accessToken = null; var authServer = CreateAuthorizationServerMock(); authServer.Setup( @@ -89,20 +91,18 @@ namespace DotNetOpenAuth.Test.OAuth2 { authServer.Setup( a => a.CheckAuthorizeClientCredentialsGrant(It.Is<IAccessTokenRequest>(d => d.ClientIdentifier == ClientId && MessagingUtilities.AreEquivalent(d.Scope, TestScopes)))) .Returns<IAccessTokenRequest>(req => new AutomatedAuthorizationCheckResponse(req, true)); - var coordinator = new OAuth2Coordinator<WebServerClient>( - AuthorizationServerDescription, - authServer.Object, - new WebServerClient(AuthorizationServerDescription), - client => { - var authState = client.GetClientAccessToken(TestScopes); - Assert.That(authState.AccessToken, Is.Not.Null.And.Not.Empty); - Assert.That(authState.RefreshToken, Is.Null); - accessToken = authState.AccessToken; - }, - server => { - server.HandleTokenRequest().Respond(); + + Handle(AuthorizationServerDescription.TokenEndpoint).By( + async (req, ct) => { + var server = new AuthorizationServer(authServer.Object); + return await server.HandleTokenRequestAsync(req, ct); }); - coordinator.Run(); + + var client = new WebServerClient(AuthorizationServerDescription, ClientId, ClientSecret, this.HostFactories); + var authState = await client.GetClientAccessTokenAsync(TestScopes); + Assert.That(authState.AccessToken, Is.Not.Null.And.Not.Empty); + Assert.That(authState.RefreshToken, Is.Null); + accessToken = authState.AccessToken; return accessToken; } diff --git a/src/DotNetOpenAuth.Test/OAuth2/UserAgentClientAuthorizeTests.cs b/src/DotNetOpenAuth.Test/OAuth2/UserAgentClientAuthorizeTests.cs index ae03b0c..d0e9617 100644 --- a/src/DotNetOpenAuth.Test/OAuth2/UserAgentClientAuthorizeTests.cs +++ b/src/DotNetOpenAuth.Test/OAuth2/UserAgentClientAuthorizeTests.cs @@ -8,7 +8,10 @@ namespace DotNetOpenAuth.Test.OAuth2 { using System; using System.Collections.Generic; using System.Linq; + using System.Net.Http; using System.Text; + using System.Threading; + using System.Threading.Tasks; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Bindings; using DotNetOpenAuth.OAuth2; @@ -20,61 +23,77 @@ namespace DotNetOpenAuth.Test.OAuth2 { [TestFixture] public class UserAgentClientAuthorizeTests : OAuth2TestBase { [Test] - public void AuthorizationCodeGrant() { - var coordinator = new OAuth2Coordinator<UserAgentClient>( - AuthorizationServerDescription, - AuthorizationServerMock, - new UserAgentClient(AuthorizationServerDescription), - client => { - var authState = new AuthorizationState(TestScopes) { - Callback = ClientCallback, - }; - var request = client.PrepareRequestUserAuthorization(authState); - Assert.AreEqual(EndUserAuthorizationResponseType.AuthorizationCode, request.ResponseType); - client.Channel.Respond(request); - var incoming = client.Channel.ReadFromRequest(); - var result = client.ProcessUserAuthorization(authState, incoming); - Assert.That(result.AccessToken, Is.Not.Null.And.Not.Empty); - Assert.That(result.RefreshToken, Is.Not.Null.And.Not.Empty); - }, - server => { - var request = server.ReadAuthorizationRequest(); + public async Task AuthorizationCodeGrant() { + Handle(AuthorizationServerDescription.AuthorizationEndpoint).By( + async (req, ct) => { + var server = new AuthorizationServer(AuthorizationServerMock); + var request = await server.ReadAuthorizationRequestAsync(req, ct); Assert.That(request, Is.Not.Null); - server.ApproveAuthorizationRequest(request, ResourceOwnerUsername); - server.HandleTokenRequest().Respond(); + var response = server.PrepareApproveAuthorizationRequest(request, ResourceOwnerUsername); + return await server.Channel.PrepareResponseAsync(response, ct); }); - coordinator.Run(); + Handle(AuthorizationServerDescription.TokenEndpoint).By( + async (req, ct) => { + var server = new AuthorizationServer(AuthorizationServerMock); + return await server.HandleTokenRequestAsync(req, ct); + }); + { + var client = new UserAgentClient(AuthorizationServerDescription, ClientId, ClientSecret, this.HostFactories); + var authState = new AuthorizationState(TestScopes) { Callback = ClientCallback, }; + var request = client.PrepareRequestUserAuthorization(authState); + Assert.AreEqual(EndUserAuthorizationResponseType.AuthorizationCode, request.ResponseType); + var authRequestRedirect = await client.Channel.PrepareResponseAsync(request); + Uri authRequestResponse; + this.HostFactories.AllowAutoRedirects = false; + using (var httpClient = this.HostFactories.CreateHttpClient()) { + using (var httpResponse = await httpClient.GetAsync(authRequestRedirect.Headers.Location)) { + authRequestResponse = httpResponse.Headers.Location; + } + } + var incoming = + await + client.Channel.ReadFromRequestAsync( + new HttpRequestMessage(HttpMethod.Get, authRequestResponse), CancellationToken.None); + var result = await client.ProcessUserAuthorizationAsync(authState, incoming, CancellationToken.None); + Assert.That(result.AccessToken, Is.Not.Null.And.Not.Empty); + Assert.That(result.RefreshToken, Is.Not.Null.And.Not.Empty); + } } [Test] - public void ImplicitGrant() { + public async Task ImplicitGrant() { var coordinatorClient = new UserAgentClient(AuthorizationServerDescription); - var coordinator = new OAuth2Coordinator<UserAgentClient>( - AuthorizationServerDescription, - AuthorizationServerMock, - coordinatorClient, - client => { - var authState = new AuthorizationState(TestScopes) { - Callback = ClientCallback, - }; - var request = client.PrepareRequestUserAuthorization(authState, implicitResponseType: true); - Assert.That(request.ResponseType, Is.EqualTo(EndUserAuthorizationResponseType.AccessToken)); - client.Channel.Respond(request); - var incoming = client.Channel.ReadFromRequest(); - var result = client.ProcessUserAuthorization(authState, incoming); - Assert.That(result.AccessToken, Is.Not.Null.And.Not.Empty); - Assert.That(result.RefreshToken, Is.Null); - }, - server => { - var request = server.ReadAuthorizationRequest(); + coordinatorClient.ClientCredentialApplicator = null; // implicit grant clients don't need a secret. + Handle(AuthorizationServerDescription.AuthorizationEndpoint).By( + async (req, ct) => { + var server = new AuthorizationServer(AuthorizationServerMock); + var request = await server.ReadAuthorizationRequestAsync(req, ct); Assert.That(request, Is.Not.Null); IAccessTokenRequest accessTokenRequest = (EndUserAuthorizationImplicitRequest)request; Assert.That(accessTokenRequest.ClientAuthenticated, Is.False); - server.ApproveAuthorizationRequest(request, ResourceOwnerUsername); + var response = server.PrepareApproveAuthorizationRequest(request, ResourceOwnerUsername); + return await server.Channel.PrepareResponseAsync(response, ct); }); + { + var client = new UserAgentClient(AuthorizationServerDescription, ClientId, ClientSecret, this.HostFactories); + var authState = new AuthorizationState(TestScopes) { Callback = ClientCallback, }; + var request = client.PrepareRequestUserAuthorization(authState, implicitResponseType: true); + Assert.That(request.ResponseType, Is.EqualTo(EndUserAuthorizationResponseType.AccessToken)); + var authRequestRedirect = await client.Channel.PrepareResponseAsync(request); + Uri authRequestResponse; + this.HostFactories.AllowAutoRedirects = false; + using (var httpClient = this.HostFactories.CreateHttpClient()) { + using (var httpResponse = await httpClient.GetAsync(authRequestRedirect.Headers.Location)) { + authRequestResponse = httpResponse.Headers.Location; + } + } - coordinatorClient.ClientCredentialApplicator = null; // implicit grant clients don't need a secret. - coordinator.Run(); + var incoming = + await client.Channel.ReadFromRequestAsync(new HttpRequestMessage(HttpMethod.Get, authRequestResponse), CancellationToken.None); + var result = await client.ProcessUserAuthorizationAsync(authState, incoming, CancellationToken.None); + Assert.That(result.AccessToken, Is.Not.Null.And.Not.Empty); + Assert.That(result.RefreshToken, Is.Null); + } } } } diff --git a/src/DotNetOpenAuth.Test/OAuth2/WebServerClientAuthorizeTests.cs b/src/DotNetOpenAuth.Test/OAuth2/WebServerClientAuthorizeTests.cs index 2a4241e..433cbf3 100644 --- a/src/DotNetOpenAuth.Test/OAuth2/WebServerClientAuthorizeTests.cs +++ b/src/DotNetOpenAuth.Test/OAuth2/WebServerClientAuthorizeTests.cs @@ -11,6 +11,7 @@ namespace DotNetOpenAuth.Test.OAuth2 { using System.Net; using System.Net.Http; using System.Text; + using System.Threading; using System.Threading.Tasks; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OAuth2; @@ -22,31 +23,45 @@ namespace DotNetOpenAuth.Test.OAuth2 { [TestFixture] public class WebServerClientAuthorizeTests : OAuth2TestBase { [Test] - public void AuthorizationCodeGrant() { - var coordinator = new OAuth2Coordinator<WebServerClient>( - AuthorizationServerDescription, - AuthorizationServerMock, - new WebServerClient(AuthorizationServerDescription), - client => { - var authState = new AuthorizationState(TestScopes) { - Callback = ClientCallback, - }; - client.PrepareRequestUserAuthorization(authState).Respond(); - var result = client.ProcessUserAuthorization(); - Assert.That(result.AccessToken, Is.Not.Null.And.Not.Empty); - Assert.That(result.RefreshToken, Is.Not.Null.And.Not.Empty); - }, - server => { - var request = server.ReadAuthorizationRequest(); + public async Task AuthorizationCodeGrant() { + Handle(AuthorizationServerDescription.AuthorizationEndpoint).By( + async (req, ct) => { + var server = new AuthorizationServer(AuthorizationServerMock); + var request = await server.ReadAuthorizationRequestAsync(req, ct); Assert.That(request, Is.Not.Null); - server.ApproveAuthorizationRequest(request, ResourceOwnerUsername); - server.HandleTokenRequest().Respond(); + var response = server.PrepareApproveAuthorizationRequest(request, ResourceOwnerUsername); + return await server.Channel.PrepareResponseAsync(response, ct); }); - coordinator.Run(); + Handle(AuthorizationServerDescription.TokenEndpoint).By( + async (req, ct) => { + var server = new AuthorizationServer(AuthorizationServerMock); + return await server.HandleTokenRequestAsync(req, ct); + }); + + var client = new WebServerClient(AuthorizationServerDescription, ClientId, ClientSecret, this.HostFactories); + var authState = new AuthorizationState(TestScopes) { + Callback = ClientCallback, + }; + var authRequestRedirect = await client.PrepareRequestUserAuthorizationAsync(authState); + this.HostFactories.CookieContainer.SetCookies(authRequestRedirect, ClientCallback); + Uri authRequestResponse; + this.HostFactories.AllowAutoRedirects = false; + using (var httpClient = this.HostFactories.CreateHttpClient()) { + using (var httpResponse = await httpClient.GetAsync(authRequestRedirect.Headers.Location)) { + Assert.That(httpResponse.StatusCode, Is.EqualTo(HttpStatusCode.Redirect)); + authRequestResponse = httpResponse.Headers.Location; + } + } + + var authorizationResponse = new HttpRequestMessage(HttpMethod.Get, authRequestResponse); + this.HostFactories.CookieContainer.ApplyCookies(authorizationResponse); + var result = await client.ProcessUserAuthorizationAsync(authorizationResponse, CancellationToken.None); + Assert.That(result.AccessToken, Is.Not.Null.And.Not.Empty); + Assert.That(result.RefreshToken, Is.Not.Null.And.Not.Empty); } [Theory] - public void ResourceOwnerPasswordCredentialGrant(bool anonymousClient) { + public async Task ResourceOwnerPasswordCredentialGrant(bool anonymousClient) { var authHostMock = CreateAuthorizationServerMock(); if (anonymousClient) { authHostMock.Setup( @@ -58,27 +73,23 @@ namespace DotNetOpenAuth.Test.OAuth2 { MessagingUtilities.AreEquivalent(d.Scope, TestScopes)))).Returns(true); } - var coordinator = new OAuth2Coordinator<WebServerClient>( - AuthorizationServerDescription, - authHostMock.Object, - new WebServerClient(AuthorizationServerDescription), - client => { - if (anonymousClient) { - client.ClientIdentifier = null; - } + Handle(AuthorizationServerDescription.TokenEndpoint).By(async (req, ct) => { + var server = new AuthorizationServer(authHostMock.Object); + return await server.HandleTokenRequestAsync(req, ct); + }); + + var client = new WebServerClient(AuthorizationServerDescription, ClientId, ClientSecret, this.HostFactories); + if (anonymousClient) { + client.ClientIdentifier = null; + } - var authState = client.ExchangeUserCredentialForToken(ResourceOwnerUsername, ResourceOwnerPassword, TestScopes); - Assert.That(authState.AccessToken, Is.Not.Null.And.Not.Empty); - Assert.That(authState.RefreshToken, Is.Not.Null.And.Not.Empty); - }, - server => { - server.HandleTokenRequest().Respond(); - }); - coordinator.Run(); + var authState = await client.ExchangeUserCredentialForTokenAsync(ResourceOwnerUsername, ResourceOwnerPassword, TestScopes); + Assert.That(authState.AccessToken, Is.Not.Null.And.Not.Empty); + Assert.That(authState.RefreshToken, Is.Not.Null.And.Not.Empty); } [Test] - public void ClientCredentialGrant() { + public async Task ClientCredentialGrant() { var authServer = CreateAuthorizationServerMock(); authServer.Setup( a => a.IsAuthorizationValid(It.Is<IAuthorizationDescription>(d => d.User == null && d.ClientIdentifier == ClientId && MessagingUtilities.AreEquivalent(d.Scope, TestScopes)))) @@ -86,23 +97,19 @@ namespace DotNetOpenAuth.Test.OAuth2 { authServer.Setup( a => a.CheckAuthorizeClientCredentialsGrant(It.Is<IAccessTokenRequest>(d => d.ClientIdentifier == ClientId && MessagingUtilities.AreEquivalent(d.Scope, TestScopes)))) .Returns<IAccessTokenRequest>(req => new AutomatedAuthorizationCheckResponse(req, true)); - var coordinator = new OAuth2Coordinator<WebServerClient>( - AuthorizationServerDescription, - authServer.Object, - new WebServerClient(AuthorizationServerDescription), - client => { - var authState = client.GetClientAccessToken(TestScopes); - Assert.That(authState.AccessToken, Is.Not.Null.And.Not.Empty); - Assert.That(authState.RefreshToken, Is.Null); - }, - server => { - server.HandleTokenRequest().Respond(); + Handle(AuthorizationServerDescription.TokenEndpoint).By( + async (req, ct) => { + var server = new AuthorizationServer(authServer.Object); + return await server.HandleTokenRequestAsync(req, ct); }); - coordinator.Run(); + var client = new WebServerClient(AuthorizationServerDescription, ClientId, ClientSecret, this.HostFactories); + var authState = await client.GetClientAccessTokenAsync(TestScopes); + Assert.That(authState.AccessToken, Is.Not.Null.And.Not.Empty); + Assert.That(authState.RefreshToken, Is.Null); } [Test] - public void GetClientAccessTokenReturnsApprovedScope() { + public async Task GetClientAccessTokenReturnsApprovedScope() { string[] approvedScopes = new[] { "Scope2", "Scope3" }; var authServer = CreateAuthorizationServerMock(); authServer.Setup( @@ -111,22 +118,19 @@ namespace DotNetOpenAuth.Test.OAuth2 { authServer.Setup( a => a.CheckAuthorizeClientCredentialsGrant(It.Is<IAccessTokenRequest>(d => d.ClientIdentifier == ClientId && MessagingUtilities.AreEquivalent(d.Scope, TestScopes)))) .Returns<IAccessTokenRequest>(req => { - var response = new AutomatedAuthorizationCheckResponse(req, true); - response.ApprovedScope.ResetContents(approvedScopes); - return response; - }); - var coordinator = new OAuth2Coordinator<WebServerClient>( - AuthorizationServerDescription, - authServer.Object, - new WebServerClient(AuthorizationServerDescription), - client => { - var authState = client.GetClientAccessToken(TestScopes); - Assert.That(authState.Scope, Is.EquivalentTo(approvedScopes)); - }, - server => { - server.HandleTokenRequest().Respond(); + var response = new AutomatedAuthorizationCheckResponse(req, true); + response.ApprovedScope.ResetContents(approvedScopes); + return response; + }); + Handle(AuthorizationServerDescription.TokenEndpoint).By( + async (req, ct) => { + var server = new AuthorizationServer(authServer.Object); + return await server.HandleTokenRequestAsync(req, ct); }); - coordinator.Run(); + + var client = new WebServerClient(AuthorizationServerDescription, ClientId, ClientSecret, this.HostFactories); + var authState = await client.GetClientAccessTokenAsync(TestScopes); + Assert.That(authState.Scope, Is.EquivalentTo(approvedScopes)); } [Test] diff --git a/src/DotNetOpenAuth.Test/OpenId/AssociationHandshakeTests.cs b/src/DotNetOpenAuth.Test/OpenId/AssociationHandshakeTests.cs index 029447d..6e3d7dc 100644 --- a/src/DotNetOpenAuth.Test/OpenId/AssociationHandshakeTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/AssociationHandshakeTests.cs @@ -6,6 +6,8 @@ namespace DotNetOpenAuth.Test.OpenId { using System; + using System.Threading; + using System.Threading.Tasks; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId; using DotNetOpenAuth.OpenId.Messages; @@ -22,13 +24,13 @@ namespace DotNetOpenAuth.Test.OpenId { } [Test] - public void AssociateUnencrypted() { - this.ParameterizedAssociationTest(new Uri("https://host")); + public async Task AssociateUnencrypted() { + await this.ParameterizedAssociationTestAsync(OPUriSsl); } [Test] - public void AssociateDiffieHellmanOverHttp() { - this.ParameterizedAssociationTest(new Uri("http://host")); + public async Task AssociateDiffieHellmanOverHttp() { + await this.ParameterizedAssociationTestAsync(OPUri); } /// <summary> @@ -39,23 +41,22 @@ namespace DotNetOpenAuth.Test.OpenId { /// putting the two together, so we verify that DNOI can handle it. /// </remarks> [Test] - public void AssociateDiffieHellmanOverHttps() { + public async Task AssociateDiffieHellmanOverHttps() { Protocol protocol = Protocol.V20; - OpenIdCoordinator coordinator = new OpenIdCoordinator( - rp => { - // We have to formulate the associate request manually, - // since the DNOI RP won't voluntarily use DH on HTTPS. - AssociateDiffieHellmanRequest request = new AssociateDiffieHellmanRequest(protocol.Version, new Uri("https://Provider")); - request.AssociationType = protocol.Args.SignatureAlgorithm.HMAC_SHA256; - request.SessionType = protocol.Args.SessionType.DH_SHA256; - request.InitializeRequest(); - var response = rp.Channel.Request<AssociateSuccessfulResponse>(request); - Assert.IsNotNull(response); - Assert.AreEqual(request.AssociationType, response.AssociationType); - Assert.AreEqual(request.SessionType, response.SessionType); - }, - AutoProvider); - coordinator.Run(); + this.RegisterAutoProvider(); + var rp = this.CreateRelyingParty(); + + // We have to formulate the associate request manually, + // since the DNOI RP won't voluntarily use DH on HTTPS. + var request = new AssociateDiffieHellmanRequest(protocol.Version, OPUri) { + AssociationType = protocol.Args.SignatureAlgorithm.HMAC_SHA256, + SessionType = protocol.Args.SessionType.DH_SHA256 + }; + request.InitializeRequest(); + var response = await rp.Channel.RequestAsync<AssociateSuccessfulResponse>(request, CancellationToken.None); + Assert.IsNotNull(response); + Assert.AreEqual(request.AssociationType, response.AssociationType); + Assert.AreEqual(request.SessionType, response.SessionType); } /// <summary> @@ -63,44 +64,52 @@ namespace DotNetOpenAuth.Test.OpenId { /// initial request for an association is for a type the OP doesn't support. /// </summary> [Test] - public void AssociateRenegotiateBitLength() { + public async Task AssociateRenegotiateBitLength() { Protocol protocol = Protocol.V20; // The strategy is to make a simple request of the RP to establish an association, // and to more carefully observe the Provider-side of things to make sure that both // the OP and RP are behaving as expected. - OpenIdCoordinator coordinator = new OpenIdCoordinator( - rp => { - var opDescription = new ProviderEndpointDescription(OPUri, protocol.Version); - Association association = rp.AssociationManager.GetOrCreateAssociation(opDescription); - Assert.IsNotNull(association, "Association failed to be created."); - Assert.AreEqual(protocol.Args.SignatureAlgorithm.HMAC_SHA1, association.GetAssociationType(protocol)); - }, - op => { + int providerAttemptCount = 0; + HandleProvider( + async (op, request) => { op.SecuritySettings.MaximumHashBitLength = 160; // Force OP to reject HMAC-SHA256 - // Receive initial request for an HMAC-SHA256 association. - AutoResponsiveRequest req = (AutoResponsiveRequest)op.GetRequest(); - AssociateRequest associateRequest = (AssociateRequest)req.RequestMessage; - Assert.That(associateRequest, Is.Not.Null); - Assert.AreEqual(protocol.Args.SignatureAlgorithm.HMAC_SHA256, associateRequest.AssociationType); - - // Ensure that the response is a suggestion that the RP try again with HMAC-SHA1 - AssociateUnsuccessfulResponse renegotiateResponse = (AssociateUnsuccessfulResponse)req.ResponseMessageTestHook; - Assert.AreEqual(protocol.Args.SignatureAlgorithm.HMAC_SHA1, renegotiateResponse.AssociationType); - op.Respond(req); - - // Receive second attempt request for an HMAC-SHA1 association. - req = (AutoResponsiveRequest)op.GetRequest(); - associateRequest = (AssociateRequest)req.RequestMessage; - Assert.AreEqual(protocol.Args.SignatureAlgorithm.HMAC_SHA1, associateRequest.AssociationType); - - // Ensure that the response is a success response. - AssociateSuccessfulResponse successResponse = (AssociateSuccessfulResponse)req.ResponseMessageTestHook; - Assert.AreEqual(protocol.Args.SignatureAlgorithm.HMAC_SHA1, successResponse.AssociationType); - op.Respond(req); + switch (++providerAttemptCount) { + case 1: + // Receive initial request for an HMAC-SHA256 association. + var req = (AutoResponsiveRequest)await op.GetRequestAsync(request); + var associateRequest = (AssociateRequest)req.RequestMessage; + Assert.That(associateRequest, Is.Not.Null); + Assert.AreEqual(protocol.Args.SignatureAlgorithm.HMAC_SHA256, associateRequest.AssociationType); + + // Ensure that the response is a suggestion that the RP try again with HMAC-SHA1 + var renegotiateResponse = + (AssociateUnsuccessfulResponse)await req.GetResponseMessageAsyncTestHook(CancellationToken.None); + Assert.AreEqual(protocol.Args.SignatureAlgorithm.HMAC_SHA1, renegotiateResponse.AssociationType); + return await op.PrepareResponseAsync(req); + + case 2: + // Receive second attempt request for an HMAC-SHA1 association. + req = (AutoResponsiveRequest)await op.GetRequestAsync(request); + associateRequest = (AssociateRequest)req.RequestMessage; + Assert.AreEqual(protocol.Args.SignatureAlgorithm.HMAC_SHA1, associateRequest.AssociationType); + + // Ensure that the response is a success response. + var successResponse = + (AssociateSuccessfulResponse)await req.GetResponseMessageAsyncTestHook(CancellationToken.None); + Assert.AreEqual(protocol.Args.SignatureAlgorithm.HMAC_SHA1, successResponse.AssociationType); + return await op.PrepareResponseAsync(req); + + default: + throw Assumes.NotReachable(); + } }); - coordinator.Run(); + var rp = this.CreateRelyingParty(); + var opDescription = new ProviderEndpointDescription(OPUri, protocol.Version); + Association association = await rp.AssociationManager.GetOrCreateAssociationAsync(opDescription, CancellationToken.None); + Assert.IsNotNull(association, "Association failed to be created."); + Assert.AreEqual(protocol.Args.SignatureAlgorithm.HMAC_SHA1, association.GetAssociationType(protocol)); } /// <summary> @@ -110,20 +119,18 @@ namespace DotNetOpenAuth.Test.OpenId { /// Verifies OP's compliance with OpenID 2.0 section 8.4.1. /// </remarks> [Test] - public void OPRejectsHttpNoEncryptionAssociateRequests() { + public async Task OPRejectsHttpNoEncryptionAssociateRequests() { Protocol protocol = Protocol.V20; - OpenIdCoordinator coordinator = new OpenIdCoordinator( - rp => { - // We have to formulate the associate request manually, - // since the DNOI RP won't voluntarily suggest no encryption at all. - var request = new AssociateUnencryptedRequestNoSslCheck(protocol.Version, OPUri); - request.AssociationType = protocol.Args.SignatureAlgorithm.HMAC_SHA256; - request.SessionType = protocol.Args.SessionType.NoEncryption; - var response = rp.Channel.Request<DirectErrorResponse>(request); - Assert.IsNotNull(response); - }, - AutoProvider); - coordinator.Run(); + this.RegisterAutoProvider(); + var rp = this.CreateRelyingParty(); + + // We have to formulate the associate request manually, + // since the DNOA RP won't voluntarily suggest no encryption at all. + var request = new AssociateUnencryptedRequestNoSslCheck(protocol.Version, OPUri); + request.AssociationType = protocol.Args.SignatureAlgorithm.HMAC_SHA256; + request.SessionType = protocol.Args.SessionType.NoEncryption; + var response = await rp.Channel.RequestAsync<DirectErrorResponse>(request, CancellationToken.None); + Assert.IsNotNull(response); } /// <summary> @@ -131,47 +138,43 @@ namespace DotNetOpenAuth.Test.OpenId { /// when the HMAC and DH bit lengths do not match. /// </summary> [Test] - public void OPRejectsMismatchingAssociationAndSessionTypes() { + public async Task OPRejectsMismatchingAssociationAndSessionTypes() { Protocol protocol = Protocol.V20; - OpenIdCoordinator coordinator = new OpenIdCoordinator( - rp => { - // We have to formulate the associate request manually, - // since the DNOI RP won't voluntarily mismatch the association and session types. - AssociateDiffieHellmanRequest request = new AssociateDiffieHellmanRequest(protocol.Version, new Uri("https://Provider")); - request.AssociationType = protocol.Args.SignatureAlgorithm.HMAC_SHA256; - request.SessionType = protocol.Args.SessionType.DH_SHA1; - request.InitializeRequest(); - var response = rp.Channel.Request<AssociateUnsuccessfulResponse>(request); - Assert.IsNotNull(response); - Assert.AreEqual(protocol.Args.SignatureAlgorithm.HMAC_SHA1, response.AssociationType); - Assert.AreEqual(protocol.Args.SessionType.DH_SHA1, response.SessionType); - }, - AutoProvider); - coordinator.Run(); + this.RegisterAutoProvider(); + var rp = this.CreateRelyingParty(); + + // We have to formulate the associate request manually, + // since the DNOI RP won't voluntarily mismatch the association and session types. + var request = new AssociateDiffieHellmanRequest(protocol.Version, OPUri); + request.AssociationType = protocol.Args.SignatureAlgorithm.HMAC_SHA256; + request.SessionType = protocol.Args.SessionType.DH_SHA1; + request.InitializeRequest(); + var response = await rp.Channel.RequestAsync<AssociateUnsuccessfulResponse>(request, CancellationToken.None); + Assert.IsNotNull(response); + Assert.AreEqual(protocol.Args.SignatureAlgorithm.HMAC_SHA1, response.AssociationType); + Assert.AreEqual(protocol.Args.SessionType.DH_SHA1, response.SessionType); } /// <summary> /// Verifies that the RP quietly rejects an OP that suggests an unknown association type. /// </summary> [Test] - public void RPRejectsUnrecognizedAssociationType() { + public async Task RPRejectsUnrecognizedAssociationType() { Protocol protocol = Protocol.V20; - OpenIdCoordinator coordinator = new OpenIdCoordinator( - rp => { - var association = rp.AssociationManager.GetOrCreateAssociation(new ProviderEndpointDescription(OPUri, protocol.Version)); - Assert.IsNull(association, "The RP should quietly give up when the OP misbehaves."); - }, - op => { + HandleProvider( + async (op, req) => { // Receive initial request. - var request = op.Channel.ReadFromRequest<AssociateRequest>(); + var request = await op.Channel.ReadFromRequestAsync<AssociateRequest>(req, CancellationToken.None); // Send a response that suggests a foreign association type. - AssociateUnsuccessfulResponse renegotiateResponse = new AssociateUnsuccessfulResponse(request.Version, request); + var renegotiateResponse = new AssociateUnsuccessfulResponse(request.Version, request); renegotiateResponse.AssociationType = "HMAC-UNKNOWN"; renegotiateResponse.SessionType = "DH-UNKNOWN"; - op.Channel.Respond(renegotiateResponse); + return await op.Channel.PrepareResponseAsync(renegotiateResponse); }); - coordinator.Run(); + var rp = this.CreateRelyingParty(); + var association = await rp.AssociationManager.GetOrCreateAssociationAsync(new ProviderEndpointDescription(OPUri, protocol.Version), CancellationToken.None); + Assert.IsNull(association, "The RP should quietly give up when the OP misbehaves."); } /// <summary> @@ -181,24 +184,23 @@ namespace DotNetOpenAuth.Test.OpenId { /// Verifies RP's compliance with OpenID 2.0 section 8.4.1. /// </remarks> [Test] - public void RPRejectsUnencryptedSuggestion() { + public async Task RPRejectsUnencryptedSuggestion() { Protocol protocol = Protocol.V20; - OpenIdCoordinator coordinator = new OpenIdCoordinator( - rp => { - var association = rp.AssociationManager.GetOrCreateAssociation(new ProviderEndpointDescription(OPUri, protocol.Version)); - Assert.IsNull(association, "The RP should quietly give up when the OP misbehaves."); - }, - op => { + this.HandleProvider( + async (op, req) => { // Receive initial request. - var request = op.Channel.ReadFromRequest<AssociateRequest>(); + var request = await op.Channel.ReadFromRequestAsync<AssociateRequest>(req, CancellationToken.None); // Send a response that suggests a no encryption. - AssociateUnsuccessfulResponse renegotiateResponse = new AssociateUnsuccessfulResponse(request.Version, request); + var renegotiateResponse = new AssociateUnsuccessfulResponse(request.Version, request); renegotiateResponse.AssociationType = protocol.Args.SignatureAlgorithm.HMAC_SHA1; renegotiateResponse.SessionType = protocol.Args.SessionType.NoEncryption; - op.Channel.Respond(renegotiateResponse); + return await op.Channel.PrepareResponseAsync(renegotiateResponse); }); - coordinator.Run(); + + var rp = this.CreateRelyingParty(); + var association = await rp.AssociationManager.GetOrCreateAssociationAsync(new ProviderEndpointDescription(OPUri, protocol.Version), CancellationToken.None); + Assert.IsNull(association, "The RP should quietly give up when the OP misbehaves."); } /// <summary> @@ -206,24 +208,22 @@ namespace DotNetOpenAuth.Test.OpenId { /// when the HMAC and DH bit lengths do not match. /// </summary> [Test] - public void RPRejectsMismatchingAssociationAndSessionBitLengths() { + public async Task RPRejectsMismatchingAssociationAndSessionBitLengths() { Protocol protocol = Protocol.V20; - OpenIdCoordinator coordinator = new OpenIdCoordinator( - rp => { - var association = rp.AssociationManager.GetOrCreateAssociation(new ProviderEndpointDescription(OPUri, protocol.Version)); - Assert.IsNull(association, "The RP should quietly give up when the OP misbehaves."); - }, - op => { + this.HandleProvider( + async (op, req) => { // Receive initial request. - var request = op.Channel.ReadFromRequest<AssociateRequest>(); + var request = await op.Channel.ReadFromRequestAsync<AssociateRequest>(req, CancellationToken.None); // Send a mismatched response AssociateUnsuccessfulResponse renegotiateResponse = new AssociateUnsuccessfulResponse(request.Version, request); renegotiateResponse.AssociationType = protocol.Args.SignatureAlgorithm.HMAC_SHA1; renegotiateResponse.SessionType = protocol.Args.SessionType.DH_SHA256; - op.Channel.Respond(renegotiateResponse); + return await op.Channel.PrepareResponseAsync(renegotiateResponse); }); - coordinator.Run(); + var rp = this.CreateRelyingParty(); + var association = await rp.AssociationManager.GetOrCreateAssociationAsync(new ProviderEndpointDescription(OPUri, protocol.Version), CancellationToken.None); + Assert.IsNull(association, "The RP should quietly give up when the OP misbehaves."); } /// <summary> @@ -231,52 +231,56 @@ namespace DotNetOpenAuth.Test.OpenId { /// keeps sending it association retry messages. /// </summary> [Test] - public void RPOnlyRenegotiatesOnce() { + public async Task RPOnlyRenegotiatesOnce() { Protocol protocol = Protocol.V20; - OpenIdCoordinator coordinator = new OpenIdCoordinator( - rp => { - var association = rp.AssociationManager.GetOrCreateAssociation(new ProviderEndpointDescription(OPUri, protocol.Version)); - Assert.IsNull(association, "The RP should quietly give up when the OP misbehaves."); - }, - op => { - // Receive initial request. - var request = op.Channel.ReadFromRequest<AssociateRequest>(); + int opStep = 0; + HandleProvider( + async (op, req) => { + switch (++opStep) { + case 1: + // Receive initial request. + var request = await op.Channel.ReadFromRequestAsync<AssociateRequest>(req, CancellationToken.None); - // Send a renegotiate response - AssociateUnsuccessfulResponse renegotiateResponse = new AssociateUnsuccessfulResponse(request.Version, request); - renegotiateResponse.AssociationType = protocol.Args.SignatureAlgorithm.HMAC_SHA1; - renegotiateResponse.SessionType = protocol.Args.SessionType.DH_SHA1; - op.Channel.Respond(renegotiateResponse); + // Send a renegotiate response + var renegotiateResponse = new AssociateUnsuccessfulResponse(request.Version, request); + renegotiateResponse.AssociationType = protocol.Args.SignatureAlgorithm.HMAC_SHA1; + renegotiateResponse.SessionType = protocol.Args.SessionType.DH_SHA1; + return await op.Channel.PrepareResponseAsync(renegotiateResponse, CancellationToken.None); - // Receive second-try - request = op.Channel.ReadFromRequest<AssociateRequest>(); + case 2: + // Receive second-try + request = await op.Channel.ReadFromRequestAsync<AssociateRequest>(req, CancellationToken.None); - // Send ANOTHER renegotiate response, at which point the DNOI RP should give up. - renegotiateResponse = new AssociateUnsuccessfulResponse(request.Version, request); - renegotiateResponse.AssociationType = protocol.Args.SignatureAlgorithm.HMAC_SHA256; - renegotiateResponse.SessionType = protocol.Args.SessionType.DH_SHA256; - op.Channel.Respond(renegotiateResponse); + // Send ANOTHER renegotiate response, at which point the DNOI RP should give up. + renegotiateResponse = new AssociateUnsuccessfulResponse(request.Version, request); + renegotiateResponse.AssociationType = protocol.Args.SignatureAlgorithm.HMAC_SHA256; + renegotiateResponse.SessionType = protocol.Args.SessionType.DH_SHA256; + return await op.Channel.PrepareResponseAsync(renegotiateResponse, CancellationToken.None); + + default: + throw Assumes.NotReachable(); + } }); - coordinator.Run(); + var rp = this.CreateRelyingParty(); + var association = await rp.AssociationManager.GetOrCreateAssociationAsync(new ProviderEndpointDescription(OPUri, protocol.Version), CancellationToken.None); + Assert.IsNull(association, "The RP should quietly give up when the OP misbehaves."); } /// <summary> /// Verifies security settings limit RP's acceptance of OP's counter-suggestion /// </summary> [Test] - public void AssociateRenegotiateLimitedByRPSecuritySettings() { + public async Task AssociateRenegotiateLimitedByRPSecuritySettings() { Protocol protocol = Protocol.V20; - OpenIdCoordinator coordinator = new OpenIdCoordinator( - rp => { - rp.SecuritySettings.MinimumHashBitLength = 256; - var association = rp.AssociationManager.GetOrCreateAssociation(new ProviderEndpointDescription(OPUri, protocol.Version)); - Assert.IsNull(association, "No association should have been created when RP and OP could not agree on association strength."); - }, - op => { + HandleProvider( + async (op, req) => { op.SecuritySettings.MaximumHashBitLength = 160; - AutoProvider(op); + return await AutoProviderActionAsync(op, req, CancellationToken.None); }); - coordinator.Run(); + var rp = this.CreateRelyingParty(); + rp.SecuritySettings.MinimumHashBitLength = 256; + var association = await rp.AssociationManager.GetOrCreateAssociationAsync(new ProviderEndpointDescription(OPUri, protocol.Version), CancellationToken.None); + Assert.IsNull(association, "No association should have been created when RP and OP could not agree on association strength."); } /// <summary> @@ -284,10 +288,10 @@ namespace DotNetOpenAuth.Test.OpenId { /// response from the OP, for example in the HTTP timeout case. /// </summary> [Test] - public void AssociateQuietlyFailsAfterHttpError() { - this.MockResponder.RegisterMockNotFound(OPUri); + public async Task AssociateQuietlyFailsAfterHttpError() { + // Without wiring up a mock HTTP handler, the RP will get a 404 Not Found error. var rp = this.CreateRelyingParty(); - var association = rp.AssociationManager.GetOrCreateAssociation(new ProviderEndpointDescription(OPUri, Protocol.V20.Version)); + var association = await rp.AssociationManager.GetOrCreateAssociationAsync(new ProviderEndpointDescription(OPUri, Protocol.V20.Version), CancellationToken.None); Assert.IsNull(association); } @@ -295,11 +299,11 @@ namespace DotNetOpenAuth.Test.OpenId { /// Runs a parameterized association flow test using all supported OpenID versions. /// </summary> /// <param name="opEndpoint">The OP endpoint to simulate using.</param> - private void ParameterizedAssociationTest(Uri opEndpoint) { + private async Task ParameterizedAssociationTestAsync(Uri opEndpoint) { foreach (Protocol protocol in Protocol.AllPracticalVersions) { var endpoint = new ProviderEndpointDescription(opEndpoint, protocol.Version); var associationType = protocol.Version.Major < 2 ? protocol.Args.SignatureAlgorithm.HMAC_SHA1 : protocol.Args.SignatureAlgorithm.HMAC_SHA256; - this.ParameterizedAssociationTest(endpoint, associationType); + await this.ParameterizedAssociationTestAsync(endpoint, associationType); } } @@ -314,7 +318,7 @@ namespace DotNetOpenAuth.Test.OpenId { /// The value of the openid.assoc_type parameter expected, /// or null if a failure is anticipated. /// </param> - private void ParameterizedAssociationTest( + private async Task ParameterizedAssociationTestAsync( ProviderEndpointDescription opDescription, string expectedAssociationType) { Protocol protocol = Protocol.Lookup(Protocol.Lookup(opDescription.Version).ProtocolVersion); @@ -323,19 +327,18 @@ namespace DotNetOpenAuth.Test.OpenId { Association rpAssociation = null, opAssociation; AssociateSuccessfulResponse associateSuccessfulResponse = null; AssociateUnsuccessfulResponse associateUnsuccessfulResponse = null; - OpenIdCoordinator coordinator = new OpenIdCoordinator( - rp => { - rp.SecuritySettings = this.RelyingPartySecuritySettings; - rpAssociation = rp.AssociationManager.GetOrCreateAssociation(opDescription); - }, - op => { - op.SecuritySettings = this.ProviderSecuritySettings; - IRequest req = op.GetRequest(); + var relyingParty = new OpenIdRelyingParty(new StandardRelyingPartyApplicationStore(), this.HostFactories); + var provider = new OpenIdProvider(new StandardProviderApplicationStore(), this.HostFactories) { + SecuritySettings = this.ProviderSecuritySettings + }; + Handle(opDescription.Uri).By( + async (request, ct) => { + IRequest req = await provider.GetRequestAsync(request, ct); Assert.IsNotNull(req, "Expected incoming request but did not receive it."); Assert.IsTrue(req.IsResponseReady); - op.Respond(req); + return await provider.PrepareResponseAsync(req, ct); }); - coordinator.IncomingMessageFilter = message => { + relyingParty.Channel.IncomingMessageFilter = message => { Assert.AreSame(opDescription.Version, message.Version, "The message was recognized as version {0} but was expected to be {1}.", message.Version, Protocol.Lookup(opDescription.Version).ProtocolVersion); var associateSuccess = message as AssociateSuccessfulResponse; var associateFailed = message as AssociateUnsuccessfulResponse; @@ -346,16 +349,18 @@ namespace DotNetOpenAuth.Test.OpenId { associateUnsuccessfulResponse = associateFailed; } }; - coordinator.OutgoingMessageFilter = message => { + relyingParty.Channel.OutgoingMessageFilter = message => { Assert.AreEqual(opDescription.Version, message.Version, "The message was for version {0} but was expected to be for {1}.", message.Version, opDescription.Version); }; - coordinator.Run(); + + relyingParty.SecuritySettings = this.RelyingPartySecuritySettings; + rpAssociation = await relyingParty.AssociationManager.GetOrCreateAssociationAsync(opDescription, CancellationToken.None); if (expectSuccess) { Assert.IsNotNull(rpAssociation); - Association actual = coordinator.RelyingParty.AssociationManager.AssociationStoreTestHook.GetAssociation(opDescription.Uri, rpAssociation.Handle); + Association actual = relyingParty.AssociationManager.AssociationStoreTestHook.GetAssociation(opDescription.Uri, rpAssociation.Handle); Assert.AreEqual(rpAssociation, actual); - opAssociation = coordinator.Provider.AssociationStore.Deserialize(new TestSignedDirectedMessage(), false, rpAssociation.Handle); + opAssociation = provider.AssociationStore.Deserialize(new TestSignedDirectedMessage(), false, rpAssociation.Handle); Assert.IsNotNull(opAssociation, "The Provider could not decode the association handle."); Assert.AreEqual(opAssociation.Handle, rpAssociation.Handle); @@ -373,7 +378,7 @@ namespace DotNetOpenAuth.Test.OpenId { var unencryptedResponse = (AssociateUnencryptedResponse)associateSuccessfulResponse; } } else { - Assert.IsNull(coordinator.RelyingParty.AssociationManager.AssociationStoreTestHook.GetAssociation(opDescription.Uri, new RelyingPartySecuritySettings())); + Assert.IsNull(relyingParty.AssociationManager.AssociationStoreTestHook.GetAssociation(opDescription.Uri, new RelyingPartySecuritySettings())); } } } diff --git a/src/DotNetOpenAuth.Test/OpenId/AuthenticationTests.cs b/src/DotNetOpenAuth.Test/OpenId/AuthenticationTests.cs index 14bcaec..1bc65e5 100644 --- a/src/DotNetOpenAuth.Test/OpenId/AuthenticationTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/AuthenticationTests.cs @@ -6,6 +6,11 @@ namespace DotNetOpenAuth.Test.OpenId { using System; + using System.Net; + using System.Net.Http; + using System.Threading; + using System.Threading.Tasks; + using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Bindings; using DotNetOpenAuth.OpenId; @@ -24,76 +29,101 @@ namespace DotNetOpenAuth.Test.OpenId { } [Test] - public void SharedAssociationPositive() { - this.ParameterizedAuthenticationTest(true, true, false); + public async Task SharedAssociationPositive() { + await this.ParameterizedAuthenticationTestAsync(true, true, false); } /// <summary> /// Verifies that a shared association protects against tampering. /// </summary> [Test] - public void SharedAssociationTampered() { - this.ParameterizedAuthenticationTest(true, true, true); + public async Task SharedAssociationTampered() { + await this.ParameterizedAuthenticationTestAsync(true, true, true); } [Test] - public void SharedAssociationNegative() { - this.ParameterizedAuthenticationTest(true, false, false); + public async Task SharedAssociationNegative() { + await this.ParameterizedAuthenticationTestAsync(true, false, false); } [Test] - public void PrivateAssociationPositive() { - this.ParameterizedAuthenticationTest(false, true, false); + public async Task PrivateAssociationPositive() { + await this.ParameterizedAuthenticationTestAsync(false, true, false); } /// <summary> /// Verifies that a private association protects against tampering. /// </summary> [Test] - public void PrivateAssociationTampered() { - this.ParameterizedAuthenticationTest(false, true, true); + public async Task PrivateAssociationTampered() { + await this.ParameterizedAuthenticationTestAsync(false, true, true); } [Test] - public void NoAssociationNegative() { - this.ParameterizedAuthenticationTest(false, false, false); + public async Task NoAssociationNegative() { + await this.ParameterizedAuthenticationTestAsync(false, false, false); } [Test] - public void UnsolicitedAssertion() { - this.MockResponder.RegisterMockRPDiscovery(); - OpenIdCoordinator coordinator = new OpenIdCoordinator( - rp => { - rp.Channel.WebRequestHandler = this.MockResponder.MockWebRequestHandler; - IAuthenticationResponse response = rp.GetResponse(); + public async Task UnsolicitedAssertion() { + var opStore = new StandardProviderApplicationStore(); + Handle(RPUri).By( + async req => { + var rp = new OpenIdRelyingParty(new StandardRelyingPartyApplicationStore(), this.HostFactories); + IAuthenticationResponse response = await rp.GetResponseAsync(req); + Assert.That(response, Is.Not.Null); Assert.AreEqual(AuthenticationStatus.Authenticated, response.Status); - }, - op => { - op.Channel.WebRequestHandler = this.MockResponder.MockWebRequestHandler; - Identifier id = GetMockIdentifier(ProtocolVersion.V20); - op.SendUnsolicitedAssertion(OPUri, RPRealmUri, id, OPLocalIdentifiers[0]); - AutoProvider(op); // handle check_auth + return new HttpResponseMessage(); + }); + Handle(OPUri).By( + async (req, ct) => { + var op = new OpenIdProvider(opStore, this.HostFactories); + return await this.AutoProviderActionAsync(op, req, ct); }); - coordinator.Run(); + this.RegisterMockRPDiscovery(ssl: false); + + { + var op = new OpenIdProvider(opStore, this.HostFactories); + Identifier id = GetMockIdentifier(ProtocolVersion.V20); + var assertion = await op.PrepareUnsolicitedAssertionAsync(OPUri, RPRealmUri, id, OPLocalIdentifiers[0]); + + using (var httpClient = this.HostFactories.CreateHttpClient()) { + using (var response = await httpClient.GetAsync(assertion.Headers.Location)) { + response.EnsureSuccessStatusCode(); + } + } + } } [Test] - public void UnsolicitedAssertionRejected() { - this.MockResponder.RegisterMockRPDiscovery(); - OpenIdCoordinator coordinator = new OpenIdCoordinator( - rp => { - rp.Channel.WebRequestHandler = this.MockResponder.MockWebRequestHandler; + public async Task UnsolicitedAssertionRejected() { + var opStore = new StandardProviderApplicationStore(); + Handle(RPUri).By( + async req => { + var rp = new OpenIdRelyingParty(new StandardRelyingPartyApplicationStore(), this.HostFactories); rp.SecuritySettings.RejectUnsolicitedAssertions = true; - IAuthenticationResponse response = rp.GetResponse(); + IAuthenticationResponse response = await rp.GetResponseAsync(req); + Assert.That(response, Is.Not.Null); Assert.AreEqual(AuthenticationStatus.Failed, response.Status); - }, - op => { - op.Channel.WebRequestHandler = this.MockResponder.MockWebRequestHandler; - Identifier id = GetMockIdentifier(ProtocolVersion.V20); - op.SendUnsolicitedAssertion(OPUri, RPRealmUri, id, OPLocalIdentifiers[0]); - AutoProvider(op); // handle check_auth + return new HttpResponseMessage(); + }); + Handle(OPUri).By( + async req => { + var op = new OpenIdProvider(opStore, this.HostFactories); + return await this.AutoProviderActionAsync(op, req, CancellationToken.None); }); - coordinator.Run(); + this.RegisterMockRPDiscovery(ssl: false); + + { + var op = new OpenIdProvider(opStore, this.HostFactories); + Identifier id = GetMockIdentifier(ProtocolVersion.V20); + var assertion = await op.PrepareUnsolicitedAssertionAsync(OPUri, RPRealmUri, id, OPLocalIdentifiers[0]); + using (var httpClient = this.HostFactories.CreateHttpClient()) { + using (var response = await httpClient.GetAsync(assertion.Headers.Location)) { + response.EnsureSuccessStatusCode(); + } + } + } } /// <summary> @@ -101,25 +131,37 @@ namespace DotNetOpenAuth.Test.OpenId { /// when the appropriate security setting is set. /// </summary> [Test] - public void UnsolicitedDelegatingIdentifierRejection() { - this.MockResponder.RegisterMockRPDiscovery(); - OpenIdCoordinator coordinator = new OpenIdCoordinator( - rp => { - rp.Channel.WebRequestHandler = this.MockResponder.MockWebRequestHandler; + public async Task UnsolicitedDelegatingIdentifierRejection() { + var opStore = new StandardProviderApplicationStore(); + Handle(RPUri).By( + async req => { + var rp = this.CreateRelyingParty(); rp.SecuritySettings.RejectDelegatingIdentifiers = true; - IAuthenticationResponse response = rp.GetResponse(); + IAuthenticationResponse response = await rp.GetResponseAsync(req); + Assert.That(response, Is.Not.Null); Assert.AreEqual(AuthenticationStatus.Failed, response.Status); - }, - op => { - op.Channel.WebRequestHandler = this.MockResponder.MockWebRequestHandler; - Identifier id = GetMockIdentifier(ProtocolVersion.V20, false, true); - op.SendUnsolicitedAssertion(OPUri, RPRealmUri, id, OPLocalIdentifiers[0]); - AutoProvider(op); // handle check_auth + return new HttpResponseMessage(); }); - coordinator.Run(); + Handle(OPUri).By( + async req => { + var op = new OpenIdProvider(opStore, this.HostFactories); + return await this.AutoProviderActionAsync(op, req, CancellationToken.None); + }); + this.RegisterMockRPDiscovery(ssl: false); + + { + var op = new OpenIdProvider(opStore, this.HostFactories); + Identifier id = GetMockIdentifier(ProtocolVersion.V20, false, true); + var assertion = await op.PrepareUnsolicitedAssertionAsync(OPUri, RPRealmUri, id, OPLocalIdentifiers[0]); + using (var httpClient = this.HostFactories.CreateHttpClient()) { + using (var response = await httpClient.GetAsync(assertion.Headers.Location)) { + response.EnsureSuccessStatusCode(); + } + } + } } - private void ParameterizedAuthenticationTest(bool sharedAssociation, bool positive, bool tamper) { + private async Task ParameterizedAuthenticationTestAsync(bool sharedAssociation, bool positive, bool tamper) { foreach (Protocol protocol in Protocol.AllPracticalVersions) { foreach (bool statelessRP in new[] { false, true }) { if (sharedAssociation && statelessRP) { @@ -129,121 +171,154 @@ namespace DotNetOpenAuth.Test.OpenId { foreach (bool immediate in new[] { false, true }) { TestLogger.InfoFormat("Beginning authentication test scenario. OpenID: {0}, Shared: {1}, positive: {2}, tamper: {3}, stateless: {4}, immediate: {5}", protocol.Version, sharedAssociation, positive, tamper, statelessRP, immediate); - this.ParameterizedAuthenticationTest(protocol, statelessRP, sharedAssociation, positive, immediate, tamper); + await this.ParameterizedAuthenticationTestAsync(protocol, statelessRP, sharedAssociation, positive, immediate, tamper); } } } } - private void ParameterizedAuthenticationTest(Protocol protocol, bool statelessRP, bool sharedAssociation, bool positive, bool immediate, bool tamper) { + private async Task ParameterizedAuthenticationTestAsync(Protocol protocol, bool statelessRP, bool sharedAssociation, bool positive, bool immediate, bool tamper) { Requires.That(!statelessRP || !sharedAssociation, null, "The RP cannot be stateless while sharing an association with the OP."); Requires.That(positive || !tamper, null, "Cannot tamper with a negative response."); var securitySettings = new ProviderSecuritySettings(); var cryptoKeyStore = new MemoryCryptoKeyStore(); var associationStore = new ProviderAssociationHandleEncoder(cryptoKeyStore); Association association = sharedAssociation ? HmacShaAssociationProvider.Create(protocol, protocol.Args.SignatureAlgorithm.Best, AssociationRelyingPartyType.Smart, associationStore, securitySettings) : null; - var coordinator = new OpenIdCoordinator( - rp => { - var request = new CheckIdRequest(protocol.Version, OPUri, immediate ? AuthenticationRequestMode.Immediate : AuthenticationRequestMode.Setup); - + int opStep = 0; + HandleProvider( + async (op, req) => { if (association != null) { - StoreAssociation(rp, OPUri, association); - request.AssociationHandle = association.Handle; + var key = cryptoKeyStore.GetCurrentKey( + ProviderAssociationHandleEncoder.AssociationHandleEncodingSecretBucket, TimeSpan.FromSeconds(1)); + op.CryptoKeyStore.StoreKey( + ProviderAssociationHandleEncoder.AssociationHandleEncodingSecretBucket, key.Key, key.Value); } - request.ClaimedIdentifier = "http://claimedid"; - request.LocalIdentifier = "http://localid"; - request.ReturnTo = RPUri; - request.Realm = RPUri; - rp.Channel.Respond(request); - if (positive) { - if (tamper) { - try { - rp.Channel.ReadFromRequest<PositiveAssertionResponse>(); - Assert.Fail("Expected exception {0} not thrown.", typeof(InvalidSignatureException).Name); - } catch (InvalidSignatureException) { - TestLogger.InfoFormat("Caught expected {0} exception after tampering with signed data.", typeof(InvalidSignatureException).Name); + switch (++opStep) { + case 1: + var request = await op.Channel.ReadFromRequestAsync<CheckIdRequest>(req, CancellationToken.None); + Assert.IsNotNull(request); + IProtocolMessage response; + if (positive) { + response = new PositiveAssertionResponse(request); + } else { + response = await NegativeAssertionResponse.CreateAsync(request, CancellationToken.None, op.Channel); } - } else { - var response = rp.Channel.ReadFromRequest<PositiveAssertionResponse>(); - Assert.IsNotNull(response); - Assert.AreEqual(request.ClaimedIdentifier, response.ClaimedIdentifier); - Assert.AreEqual(request.LocalIdentifier, response.LocalIdentifier); - Assert.AreEqual(request.ReturnTo, response.ReturnTo); - - // Attempt to replay the message and verify that it fails. - // Because in various scenarios and protocol versions different components - // notice the replay, we can get one of two exceptions thrown. - // When the OP notices the replay we get a generic InvalidSignatureException. - // When the RP notices the replay we get a specific ReplayMessageException. - try { - CoordinatingChannel channel = (CoordinatingChannel)rp.Channel; - channel.Replay(response); - Assert.Fail("Expected ProtocolException was not thrown."); - } catch (ProtocolException ex) { - Assert.IsTrue(ex is ReplayedMessageException || ex is InvalidSignatureException, "A {0} exception was thrown instead of the expected {1} or {2}.", ex.GetType(), typeof(ReplayedMessageException).Name, typeof(InvalidSignatureException).Name); + + return await op.Channel.PrepareResponseAsync(response); + case 2: + if (positive && (statelessRP || !sharedAssociation)) { + var checkauthRequest = + await op.Channel.ReadFromRequestAsync<CheckAuthenticationRequest>(req, CancellationToken.None); + var checkauthResponse = new CheckAuthenticationResponse(checkauthRequest.Version, checkauthRequest); + checkauthResponse.IsValid = checkauthRequest.IsValid; + return await op.Channel.PrepareResponseAsync(checkauthResponse); } - } - } else { - var response = rp.Channel.ReadFromRequest<NegativeAssertionResponse>(); - Assert.IsNotNull(response); - if (immediate) { - // Only 1.1 was required to include user_setup_url - if (protocol.Version.Major < 2) { - Assert.IsNotNull(response.UserSetupUrl); + + throw Assumes.NotReachable(); + case 3: + if (positive && (statelessRP || !sharedAssociation)) { + if (!tamper) { + // Respond to the replay attack. + var checkauthRequest = + await op.Channel.ReadFromRequestAsync<CheckAuthenticationRequest>(req, CancellationToken.None); + var checkauthResponse = new CheckAuthenticationResponse(checkauthRequest.Version, checkauthRequest); + checkauthResponse.IsValid = checkauthRequest.IsValid; + return await op.Channel.PrepareResponseAsync(checkauthResponse); + } } - } else { - Assert.IsNull(response.UserSetupUrl); - } + + throw Assumes.NotReachable(); + default: + throw Assumes.NotReachable(); } - }, - op => { - if (association != null) { - var key = cryptoKeyStore.GetCurrentKey(ProviderAssociationHandleEncoder.AssociationHandleEncodingSecretBucket, TimeSpan.FromSeconds(1)); - op.CryptoKeyStore.StoreKey(ProviderAssociationHandleEncoder.AssociationHandleEncodingSecretBucket, key.Key, key.Value); + }); + + { + var rp = this.CreateRelyingParty(statelessRP); + if (tamper) { + rp.Channel.IncomingMessageFilter = message => { + var assertion = message as PositiveAssertionResponse; + if (assertion != null) { + // Alter the Local Identifier between the Provider and the Relying Party. + // If the signature binding element does its job, this should cause the RP + // to throw. + assertion.LocalIdentifier = "http://victim"; + } + }; + } + + var request = new CheckIdRequest( + protocol.Version, OPUri, immediate ? AuthenticationRequestMode.Immediate : AuthenticationRequestMode.Setup); + + if (association != null) { + StoreAssociation(rp, OPUri, association); + request.AssociationHandle = association.Handle; + } + + request.ClaimedIdentifier = "http://claimedid"; + request.LocalIdentifier = "http://localid"; + request.ReturnTo = RPUri; + request.Realm = RPUri; + var redirectRequest = await rp.Channel.PrepareResponseAsync(request); + Uri redirectResponse; + this.HostFactories.AllowAutoRedirects = false; + using (var httpClient = rp.Channel.HostFactories.CreateHttpClient()) { + using (var response = await httpClient.GetAsync(redirectRequest.Headers.Location)) { + Assert.That(response.StatusCode, Is.EqualTo(HttpStatusCode.Redirect)); + redirectResponse = response.Headers.Location; } + } - var request = op.Channel.ReadFromRequest<CheckIdRequest>(); - Assert.IsNotNull(request); - IProtocolMessage response; - if (positive) { - response = new PositiveAssertionResponse(request); + var assertionMessage = new HttpRequestMessage(HttpMethod.Get, redirectResponse); + if (positive) { + if (tamper) { + try { + await rp.Channel.ReadFromRequestAsync<PositiveAssertionResponse>(assertionMessage, CancellationToken.None); + Assert.Fail("Expected exception {0} not thrown.", typeof(InvalidSignatureException).Name); + } catch (InvalidSignatureException) { + TestLogger.InfoFormat( + "Caught expected {0} exception after tampering with signed data.", typeof(InvalidSignatureException).Name); + } } else { - response = new NegativeAssertionResponse(request, op.Channel); - } - op.Channel.Respond(response); - - if (positive && (statelessRP || !sharedAssociation)) { - var checkauthRequest = op.Channel.ReadFromRequest<CheckAuthenticationRequest>(); - var checkauthResponse = new CheckAuthenticationResponse(checkauthRequest.Version, checkauthRequest); - checkauthResponse.IsValid = checkauthRequest.IsValid; - op.Channel.Respond(checkauthResponse); - - if (!tamper) { - // Respond to the replay attack. - checkauthRequest = op.Channel.ReadFromRequest<CheckAuthenticationRequest>(); - checkauthResponse = new CheckAuthenticationResponse(checkauthRequest.Version, checkauthRequest); - checkauthResponse.IsValid = checkauthRequest.IsValid; - op.Channel.Respond(checkauthResponse); + var response = + await rp.Channel.ReadFromRequestAsync<PositiveAssertionResponse>(assertionMessage, CancellationToken.None); + Assert.IsNotNull(response); + Assert.AreEqual(request.ClaimedIdentifier, response.ClaimedIdentifier); + Assert.AreEqual(request.LocalIdentifier, response.LocalIdentifier); + Assert.AreEqual(request.ReturnTo, response.ReturnTo); + + // Attempt to replay the message and verify that it fails. + // Because in various scenarios and protocol versions different components + // notice the replay, we can get one of two exceptions thrown. + // When the OP notices the replay we get a generic InvalidSignatureException. + // When the RP notices the replay we get a specific ReplayMessageException. + try { + await rp.Channel.ReadFromRequestAsync<PositiveAssertionResponse>(assertionMessage, CancellationToken.None); + Assert.Fail("Expected ProtocolException was not thrown."); + } catch (ProtocolException ex) { + Assert.IsTrue( + ex is ReplayedMessageException || ex is InvalidSignatureException, + "A {0} exception was thrown instead of the expected {1} or {2}.", + ex.GetType(), + typeof(ReplayedMessageException).Name, + typeof(InvalidSignatureException).Name); } } - }); - if (tamper) { - coordinator.IncomingMessageFilter = message => { - var assertion = message as PositiveAssertionResponse; - if (assertion != null) { - // Alter the Local Identifier between the Provider and the Relying Party. - // If the signature binding element does its job, this should cause the RP - // to throw. - assertion.LocalIdentifier = "http://victim"; + } else { + var response = + await rp.Channel.ReadFromRequestAsync<NegativeAssertionResponse>(assertionMessage, CancellationToken.None); + Assert.IsNotNull(response); + if (immediate) { + // Only 1.1 was required to include user_setup_url + if (protocol.Version.Major < 2) { + Assert.IsNotNull(response.UserSetupUrl); + } + } else { + Assert.IsNull(response.UserSetupUrl); } - }; - } - if (statelessRP) { - coordinator.RelyingParty = new OpenIdRelyingParty(null); + } } - - coordinator.Run(); } } } diff --git a/src/DotNetOpenAuth.Test/OpenId/ChannelElements/ExtensionsBindingElementTests.cs b/src/DotNetOpenAuth.Test/OpenId/ChannelElements/ExtensionsBindingElementTests.cs index dd47782..1a1307d 100644 --- a/src/DotNetOpenAuth.Test/OpenId/ChannelElements/ExtensionsBindingElementTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/ChannelElements/ExtensionsBindingElementTests.cs @@ -8,12 +8,17 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { using System; using System.Collections.Generic; using System.Linq; + using System.Net.Http; using System.Text.RegularExpressions; + using System.Threading; + using System.Threading.Tasks; + using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId; using DotNetOpenAuth.OpenId.ChannelElements; using DotNetOpenAuth.OpenId.Extensions; using DotNetOpenAuth.OpenId.Messages; + using DotNetOpenAuth.OpenId.Provider; using DotNetOpenAuth.OpenId.RelyingParty; using DotNetOpenAuth.Test.Mocks; using DotNetOpenAuth.Test.OpenId.Extensions; @@ -38,10 +43,10 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { } [Test] - public void RoundTripFullStackTest() { + public async Task RoundTripFullStackTest() { IOpenIdMessageExtension request = new MockOpenIdExtension("requestPart", "requestData"); IOpenIdMessageExtension response = new MockOpenIdExtension("responsePart", "responseData"); - ExtensionTestUtilities.Roundtrip( + await this.RoundtripAsync( Protocol.Default, new IOpenIdMessageExtension[] { request }, new IOpenIdMessageExtension[] { response }); @@ -53,23 +58,23 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { } [Test] - public void PrepareMessageForSendingNull() { - Assert.IsNull(this.rpElement.ProcessOutgoingMessage(null)); + public async Task PrepareMessageForSendingNull() { + Assert.IsNull(await this.rpElement.ProcessOutgoingMessageAsync(null, CancellationToken.None)); } /// <summary> /// Verifies that false is returned when a non-extendable message is sent. /// </summary> [Test] - public void PrepareMessageForSendingNonExtendableMessage() { + public async Task PrepareMessageForSendingNonExtendableMessage() { IProtocolMessage request = new AssociateDiffieHellmanRequest(Protocol.Default.Version, OpenIdTestBase.OPUri); - Assert.IsNull(this.rpElement.ProcessOutgoingMessage(request)); + Assert.IsNull(await this.rpElement.ProcessOutgoingMessageAsync(request, CancellationToken.None)); } [Test] - public void PrepareMessageForSending() { + public async Task PrepareMessageForSending() { this.request.Extensions.Add(new MockOpenIdExtension("part", "extra")); - Assert.IsNotNull(this.rpElement.ProcessOutgoingMessage(this.request)); + Assert.IsNotNull(await this.rpElement.ProcessOutgoingMessageAsync(this.request, CancellationToken.None)); string alias = GetAliases(this.request.ExtraData).Single(); Assert.AreEqual(MockOpenIdExtension.MockTypeUri, this.request.ExtraData["openid.ns." + alias]); @@ -78,11 +83,11 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { } [Test] - public void PrepareMessageForReceiving() { + public async Task PrepareMessageForReceiving() { this.request.ExtraData["openid.ns.mock"] = MockOpenIdExtension.MockTypeUri; this.request.ExtraData["openid.mock.Part"] = "part"; this.request.ExtraData["openid.mock.data"] = "extra"; - Assert.IsNotNull(this.rpElement.ProcessIncomingMessage(this.request)); + Assert.IsNotNull(await this.rpElement.ProcessIncomingMessageAsync(this.request, CancellationToken.None)); MockOpenIdExtension ext = this.request.Extensions.OfType<MockOpenIdExtension>().Single(); Assert.AreEqual("part", ext.Part); Assert.AreEqual("extra", ext.Data); @@ -92,12 +97,12 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { /// Verifies that extension responses are included in the OP's signature. /// </summary> [Test] - public void ExtensionResponsesAreSigned() { + public async Task ExtensionResponsesAreSigned() { Protocol protocol = Protocol.Default; var op = this.CreateProvider(); IndirectSignedResponse response = this.CreateResponseWithExtensions(protocol); - op.Channel.PrepareResponse(response); - ITamperResistantOpenIdMessage signedResponse = (ITamperResistantOpenIdMessage)response; + await op.Channel.PrepareResponseAsync(response); + ITamperResistantOpenIdMessage signedResponse = response; string extensionAliasKey = signedResponse.ExtraData.Single(kv => kv.Value == MockOpenIdExtension.MockTypeUri).Key; Assert.IsTrue(extensionAliasKey.StartsWith("openid.ns.")); string extensionAlias = extensionAliasKey.Substring("openid.ns.".Length); @@ -114,27 +119,59 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { /// Verifies that unsigned extension responses (where any or all fields are unsigned) are ignored. /// </summary> [Test] - public void ExtensionsAreIdentifiedAsSignedOrUnsigned() { + public async Task ExtensionsAreIdentifiedAsSignedOrUnsigned() { Protocol protocol = Protocol.Default; - OpenIdCoordinator coordinator = new OpenIdCoordinator( - rp => { + var opStore = new StandardProviderApplicationStore(); + int rpStep = 0; + + + Handle(RPUri).By( + async req => { + var rp = new OpenIdRelyingParty(new StandardRelyingPartyApplicationStore(), this.HostFactories); RegisterMockExtension(rp.Channel); - var response = rp.Channel.ReadFromRequest<IndirectSignedResponse>(); - Assert.AreEqual(1, response.SignedExtensions.Count(), "Signed extension should have been received."); - Assert.AreEqual(0, response.UnsignedExtensions.Count(), "No unsigned extension should be present."); - response = rp.Channel.ReadFromRequest<IndirectSignedResponse>(); - Assert.AreEqual(0, response.SignedExtensions.Count(), "No signed extension should have been received."); - Assert.AreEqual(1, response.UnsignedExtensions.Count(), "Unsigned extension should have been received."); - }, - op => { - RegisterMockExtension(op.Channel); - op.Channel.Respond(CreateResponseWithExtensions(protocol)); - op.Respond(op.GetRequest()); // check_auth - op.SecuritySettings.SignOutgoingExtensions = false; - op.Channel.Respond(CreateResponseWithExtensions(protocol)); - op.Respond(op.GetRequest()); // check_auth + + switch (++rpStep) { + case 1: + var response = await rp.Channel.ReadFromRequestAsync<IndirectSignedResponse>(req, CancellationToken.None); + Assert.AreEqual(1, response.SignedExtensions.Count(), "Signed extension should have been received."); + Assert.AreEqual(0, response.UnsignedExtensions.Count(), "No unsigned extension should be present."); + break; + case 2: + response = await rp.Channel.ReadFromRequestAsync<IndirectSignedResponse>(req, CancellationToken.None); + Assert.AreEqual(0, response.SignedExtensions.Count(), "No signed extension should have been received."); + Assert.AreEqual(1, response.UnsignedExtensions.Count(), "Unsigned extension should have been received."); + break; + + default: + throw Assumes.NotReachable(); + } + + return new HttpResponseMessage(); }); - coordinator.Run(); + Handle(OPUri).By( + async req => { + var op = new OpenIdProvider(opStore, this.HostFactories); + return await AutoProviderActionAsync(op, req, CancellationToken.None); + }); + + { + var op = new OpenIdProvider(opStore, this.HostFactories); + RegisterMockExtension(op.Channel); + var redirectingResponse = await op.Channel.PrepareResponseAsync(CreateResponseWithExtensions(protocol)); + using (var httpClient = this.HostFactories.CreateHttpClient()) { + using (var response = await httpClient.GetAsync(redirectingResponse.Headers.Location)) { + response.EnsureSuccessStatusCode(); + } + } + + op.SecuritySettings.SignOutgoingExtensions = false; + redirectingResponse = await op.Channel.PrepareResponseAsync(CreateResponseWithExtensions(protocol)); + using (var httpClient = this.HostFactories.CreateHttpClient()) { + using (var response = await httpClient.GetAsync(redirectingResponse.Headers.Location)) { + response.EnsureSuccessStatusCode(); + } + } + } } /// <summary> @@ -145,17 +182,17 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { /// "A namespace MUST NOT be assigned more than one alias in the same message". /// </remarks> [Test] - public void TwoExtensionsSameTypeUri() { + public async Task TwoExtensionsSameTypeUri() { IOpenIdMessageExtension request1 = new MockOpenIdExtension("requestPart1", "requestData1"); IOpenIdMessageExtension request2 = new MockOpenIdExtension("requestPart2", "requestData2"); try { - ExtensionTestUtilities.Roundtrip( + await this.RoundtripAsync( Protocol.Default, new IOpenIdMessageExtension[] { request1, request2 }, new IOpenIdMessageExtension[0]); Assert.Fail("Expected ProtocolException not thrown."); - } catch (AssertionException ex) { - Assert.IsInstanceOf<ProtocolException>(ex.InnerException); + } catch (ProtocolException) { + // success } } @@ -181,7 +218,7 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { private IndirectSignedResponse CreateResponseWithExtensions(Protocol protocol) { Requires.NotNull(protocol, "protocol"); - IndirectSignedResponse response = new IndirectSignedResponse(protocol.Version, RPUri); + var response = new IndirectSignedResponse(protocol.Version, RPUri); response.ProviderEndpoint = OPUri; response.Extensions.Add(new MockOpenIdExtension("pv", "ev")); return response; diff --git a/src/DotNetOpenAuth.Test/OpenId/ChannelElements/KeyValueFormEncodingTests.cs b/src/DotNetOpenAuth.Test/OpenId/ChannelElements/KeyValueFormEncodingTests.cs index 93ad028..b64701d 100644 --- a/src/DotNetOpenAuth.Test/OpenId/ChannelElements/KeyValueFormEncodingTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/ChannelElements/KeyValueFormEncodingTests.cs @@ -11,6 +11,9 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { using System.Linq; using System.Net; using System.Text; + using System.Threading; + using System.Threading.Tasks; + using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Reflection; using DotNetOpenAuth.OpenId.ChannelElements; @@ -56,9 +59,9 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { Assert.AreEqual(this.sampleData.Count, count); } - public void KVDictTest(byte[] kvform, IDictionary<string, string> dict, TestMode mode) { + public async Task KVDictTestAsync(byte[] kvform, IDictionary<string, string> dict, TestMode mode) { if ((mode & TestMode.Decoder) == TestMode.Decoder) { - var d = this.keyValueForm.GetDictionary(new MemoryStream(kvform)); + var d = await this.keyValueForm.GetDictionaryAsync(new MemoryStream(kvform), CancellationToken.None); foreach (string key in dict.Keys) { Assert.AreEqual(d[key], dict[key], "Decoder fault: " + d[key] + " and " + dict[key] + " do not match."); } @@ -70,91 +73,91 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { } [Test] - public void EncodeDecode() { - this.KVDictTest(UTF8Encoding.UTF8.GetBytes(string.Empty), new Dictionary<string, string>(), TestMode.Both); + public async Task EncodeDecode() { + await this.KVDictTestAsync(UTF8Encoding.UTF8.GetBytes(string.Empty), new Dictionary<string, string>(), TestMode.Both); Dictionary<string, string> d1 = new Dictionary<string, string>(); d1.Add("college", "harvey mudd"); - this.KVDictTest(UTF8Encoding.UTF8.GetBytes("college:harvey mudd\n"), d1, TestMode.Both); + await this.KVDictTestAsync(UTF8Encoding.UTF8.GetBytes("college:harvey mudd\n"), d1, TestMode.Both); Dictionary<string, string> d2 = new Dictionary<string, string>(); d2.Add("city", "claremont"); d2.Add("state", "CA"); - this.KVDictTest(UTF8Encoding.UTF8.GetBytes("city:claremont\nstate:CA\n"), d2, TestMode.Both); + await this.KVDictTestAsync(UTF8Encoding.UTF8.GetBytes("city:claremont\nstate:CA\n"), d2, TestMode.Both); Dictionary<string, string> d3 = new Dictionary<string, string>(); d3.Add("is_valid", "true"); d3.Add("invalidate_handle", "{HMAC-SHA1:2398410938412093}"); - this.KVDictTest(UTF8Encoding.UTF8.GetBytes("is_valid:true\ninvalidate_handle:{HMAC-SHA1:2398410938412093}\n"), d3, TestMode.Both); + await this.KVDictTestAsync(UTF8Encoding.UTF8.GetBytes("is_valid:true\ninvalidate_handle:{HMAC-SHA1:2398410938412093}\n"), d3, TestMode.Both); Dictionary<string, string> d4 = new Dictionary<string, string>(); d4.Add(string.Empty, string.Empty); - this.KVDictTest(UTF8Encoding.UTF8.GetBytes(":\n"), d4, TestMode.Both); + await this.KVDictTestAsync(UTF8Encoding.UTF8.GetBytes(":\n"), d4, TestMode.Both); Dictionary<string, string> d5 = new Dictionary<string, string>(); d5.Add(string.Empty, "missingkey"); - this.KVDictTest(UTF8Encoding.UTF8.GetBytes(":missingkey\n"), d5, TestMode.Both); + await this.KVDictTestAsync(UTF8Encoding.UTF8.GetBytes(":missingkey\n"), d5, TestMode.Both); Dictionary<string, string> d6 = new Dictionary<string, string>(); d6.Add("street", "foothill blvd"); - this.KVDictTest(UTF8Encoding.UTF8.GetBytes("street:foothill blvd\n"), d6, TestMode.Both); + await this.KVDictTestAsync(UTF8Encoding.UTF8.GetBytes("street:foothill blvd\n"), d6, TestMode.Both); Dictionary<string, string> d7 = new Dictionary<string, string>(); d7.Add("major", "computer science"); - this.KVDictTest(UTF8Encoding.UTF8.GetBytes("major:computer science\n"), d7, TestMode.Both); + await this.KVDictTestAsync(UTF8Encoding.UTF8.GetBytes("major:computer science\n"), d7, TestMode.Both); Dictionary<string, string> d8 = new Dictionary<string, string>(); d8.Add("dorm", "east"); - this.KVDictTest(UTF8Encoding.UTF8.GetBytes(" dorm : east \n"), d8, TestMode.Decoder); + await this.KVDictTestAsync(UTF8Encoding.UTF8.GetBytes(" dorm : east \n"), d8, TestMode.Decoder); Dictionary<string, string> d9 = new Dictionary<string, string>(); d9.Add("e^(i*pi)+1", "0"); - this.KVDictTest(UTF8Encoding.UTF8.GetBytes("e^(i*pi)+1:0"), d9, TestMode.Decoder); + await this.KVDictTestAsync(UTF8Encoding.UTF8.GetBytes("e^(i*pi)+1:0"), d9, TestMode.Decoder); Dictionary<string, string> d10 = new Dictionary<string, string>(); d10.Add("east", "west"); d10.Add("north", "south"); - this.KVDictTest(UTF8Encoding.UTF8.GetBytes("east:west\nnorth:south"), d10, TestMode.Decoder); + await this.KVDictTestAsync(UTF8Encoding.UTF8.GetBytes("east:west\nnorth:south"), d10, TestMode.Decoder); } [Test, ExpectedException(typeof(FormatException))] - public void NoValue() { - this.Illegal("x\n", KeyValueFormConformanceLevel.OpenId11); + public async Task NoValue() { + await this.IllegalAsync("x\n", KeyValueFormConformanceLevel.OpenId11); } [Test, ExpectedException(typeof(FormatException))] - public void NoValueLoose() { + public async Task NoValueLoose() { Dictionary<string, string> d = new Dictionary<string, string>(); - this.KVDictTest(Encoding.UTF8.GetBytes("x\n"), d, TestMode.Decoder); + await this.KVDictTestAsync(Encoding.UTF8.GetBytes("x\n"), d, TestMode.Decoder); } [Test, ExpectedException(typeof(FormatException))] - public void EmptyLine() { - this.Illegal("x:b\n\n", KeyValueFormConformanceLevel.OpenId20); + public async Task EmptyLine() { + await this.IllegalAsync("x:b\n\n", KeyValueFormConformanceLevel.OpenId20); } [Test] - public void EmptyLineLoose() { + public async Task EmptyLineLoose() { Dictionary<string, string> d = new Dictionary<string, string>(); d.Add("x", "b"); - this.KVDictTest(Encoding.UTF8.GetBytes("x:b\n\n"), d, TestMode.Decoder); + await this.KVDictTestAsync(Encoding.UTF8.GetBytes("x:b\n\n"), d, TestMode.Decoder); } [Test, ExpectedException(typeof(FormatException))] - public void LastLineNotTerminated() { - this.Illegal("x:y\na:b", KeyValueFormConformanceLevel.OpenId11); + public async Task LastLineNotTerminated() { + await this.IllegalAsync("x:y\na:b", KeyValueFormConformanceLevel.OpenId11); } [Test] - public void LastLineNotTerminatedLoose() { + public async Task LastLineNotTerminatedLoose() { Dictionary<string, string> d = new Dictionary<string, string>(); d.Add("x", "y"); d.Add("a", "b"); - this.KVDictTest(Encoding.UTF8.GetBytes("x:y\na:b"), d, TestMode.Decoder); + await this.KVDictTestAsync(Encoding.UTF8.GetBytes("x:y\na:b"), d, TestMode.Decoder); } - private void Illegal(string s, KeyValueFormConformanceLevel level) { - new KeyValueFormEncoding(level).GetDictionary(new MemoryStream(Encoding.UTF8.GetBytes(s))); + private async Task IllegalAsync(string s, KeyValueFormConformanceLevel level) { + await new KeyValueFormEncoding(level).GetDictionaryAsync(new MemoryStream(Encoding.UTF8.GetBytes(s)), CancellationToken.None); } } } diff --git a/src/DotNetOpenAuth.Test/OpenId/ChannelElements/OpenIdChannelTests.cs b/src/DotNetOpenAuth.Test/OpenId/ChannelElements/OpenIdChannelTests.cs index f50137d..c9cd52c 100644 --- a/src/DotNetOpenAuth.Test/OpenId/ChannelElements/OpenIdChannelTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/ChannelElements/OpenIdChannelTests.cs @@ -10,7 +10,10 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { using System.IO; using System.Linq; using System.Net; + using System.Net.Http; using System.Text; + using System.Threading; + using System.Threading.Tasks; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Bindings; using DotNetOpenAuth.Messaging.Reflection; @@ -21,16 +24,13 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { using NUnit.Framework; [TestFixture] - public class OpenIdChannelTests : TestBase { + public class OpenIdChannelTests : OpenIdTestBase { private static readonly TimeSpan maximumMessageAge = TimeSpan.FromHours(3); // good for tests, too long for production private OpenIdChannel channel; - private Mocks.TestWebRequestHandler webHandler; [SetUp] public void Setup() { - this.webHandler = new Mocks.TestWebRequestHandler(); - this.channel = new OpenIdRelyingPartyChannel(new MemoryCryptoKeyStore(), new NonceMemoryStore(maximumMessageAge), new RelyingPartySecuritySettings()); - this.channel.WebRequestHandler = this.webHandler; + this.channel = new OpenIdRelyingPartyChannel(new MemoryCryptoKeyStore(), new NonceMemoryStore(maximumMessageAge), new RelyingPartySecuritySettings(), this.HostFactories); } [Test] @@ -52,14 +52,14 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { /// Verifies that the channel sends direct message requests as HTTP POST requests. /// </summary> [Test] - public void DirectRequestsUsePost() { + public async Task DirectRequestsUsePost() { IDirectedProtocolMessage requestMessage = new Mocks.TestDirectedMessage(MessageTransport.Direct) { Recipient = new Uri("http://host"), Name = "Andrew", }; - HttpWebRequest httpRequest = this.channel.CreateHttpRequestTestHook(requestMessage); - Assert.AreEqual("POST", httpRequest.Method); - StringAssert.Contains("Name=Andrew", this.webHandler.RequestEntityAsString); + var httpRequest = this.channel.CreateHttpRequestTestHook(requestMessage); + Assert.AreEqual(HttpMethod.Post, httpRequest.Method); + StringAssert.Contains("Name=Andrew", await httpRequest.Content.ReadAsStringAsync()); } /// <summary> @@ -72,16 +72,15 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { /// <see cref="OpenIdChannel.SendDirectMessageResponse"/> method. /// </remarks> [Test] - public void DirectResponsesSentUsingKeyValueForm() { + public async Task DirectResponsesSentUsingKeyValueForm() { IProtocolMessage message = MessagingTestBase.GetStandardTestMessage(MessagingTestBase.FieldFill.AllRequired); MessageDictionary messageFields = this.MessageDescriptions.GetAccessor(message); byte[] expectedBytes = KeyValueFormEncoding.GetBytes(messageFields); string expectedContentType = OpenIdChannel.KeyValueFormContentType; - OutgoingWebResponse directResponse = this.channel.PrepareDirectResponseTestHook(message); - Assert.AreEqual(expectedContentType, directResponse.Headers[HttpResponseHeader.ContentType]); - byte[] actualBytes = new byte[directResponse.ResponseStream.Length]; - directResponse.ResponseStream.Read(actualBytes, 0, actualBytes.Length); + var directResponse = this.channel.PrepareDirectResponseTestHook(message); + Assert.AreEqual(expectedContentType, directResponse.Content.Headers.ContentType.MediaType); + byte[] actualBytes = await directResponse.Content.ReadAsByteArrayAsync(); Assert.IsTrue(MessagingUtilities.AreEquivalent(expectedBytes, actualBytes)); } @@ -89,15 +88,15 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { /// Verifies that direct message responses are read in using the Key Value Form decoder. /// </summary> [Test] - public void DirectResponsesReceivedAsKeyValueForm() { + public async Task DirectResponsesReceivedAsKeyValueForm() { var fields = new Dictionary<string, string> { { "var1", "value1" }, { "var2", "value2" }, }; - var response = new CachedDirectWebResponse { - CachedResponseStream = new MemoryStream(KeyValueFormEncoding.GetBytes(fields)), + var response = new HttpResponseMessage { + Content = new StreamContent(new MemoryStream(KeyValueFormEncoding.GetBytes(fields))), }; - Assert.IsTrue(MessagingUtilities.AreEquivalent(fields, this.channel.ReadFromResponseCoreTestHook(response))); + Assert.IsTrue(MessagingUtilities.AreEquivalent(fields, await this.channel.ReadFromResponseCoreAsyncTestHook(response, CancellationToken.None))); } /// <summary> @@ -106,14 +105,14 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { [Test] public void SendDirectMessageResponseHonorsHttpStatusCodes() { IProtocolMessage message = MessagingTestBase.GetStandardTestMessage(MessagingTestBase.FieldFill.AllRequired); - OutgoingWebResponse directResponse = this.channel.PrepareDirectResponseTestHook(message); - Assert.AreEqual(HttpStatusCode.OK, directResponse.Status); + var directResponse = this.channel.PrepareDirectResponseTestHook(message); + Assert.AreEqual(HttpStatusCode.OK, directResponse.StatusCode); var httpMessage = new TestDirectResponseMessageWithHttpStatus(); MessagingTestBase.GetStandardTestMessage(MessagingTestBase.FieldFill.AllRequired, httpMessage); httpMessage.HttpStatusCode = HttpStatusCode.NotAcceptable; directResponse = this.channel.PrepareDirectResponseTestHook(httpMessage); - Assert.AreEqual(HttpStatusCode.NotAcceptable, directResponse.Status); + Assert.AreEqual(HttpStatusCode.NotAcceptable, directResponse.StatusCode); } } } diff --git a/src/DotNetOpenAuth.Test/OpenId/ChannelElements/SigningBindingElementTests.cs b/src/DotNetOpenAuth.Test/OpenId/ChannelElements/SigningBindingElementTests.cs index f7722e3..42b447e 100644 --- a/src/DotNetOpenAuth.Test/OpenId/ChannelElements/SigningBindingElementTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/ChannelElements/SigningBindingElementTests.cs @@ -8,6 +8,9 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { using System; using System.Collections.Generic; using System.Linq; + using System.Threading; + using System.Threading.Tasks; + using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Bindings; using DotNetOpenAuth.OpenId; @@ -24,7 +27,7 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { /// Verifies that the signatures generated match Known Good signatures. /// </summary> [Test] - public void SignaturesMatchKnownGood() { + public async Task SignaturesMatchKnownGood() { Protocol protocol = Protocol.V20; var settings = new ProviderSecuritySettings(); var cryptoStore = new MemoryCryptoKeyStore(); @@ -41,7 +44,7 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { message.ProviderEndpoint = new Uri("http://provider"); signedMessage.UtcCreationDate = DateTime.Parse("1/1/2009"); signedMessage.AssociationHandle = handle; - Assert.IsNotNull(signer.ProcessOutgoingMessage(message)); + Assert.IsNotNull(await signer.ProcessOutgoingMessageAsync(message, CancellationToken.None)); Assert.AreEqual("o9+uN7qTaUS9v0otbHTuNAtbkpBm14+es9QnNo6IHD4=", signedMessage.Signature); } @@ -49,7 +52,7 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { /// Verifies that all parameters in ExtraData in signed responses are signed. /// </summary> [Test] - public void SignedResponsesIncludeExtraDataInSignature() { + public async Task SignedResponsesIncludeExtraDataInSignature() { Protocol protocol = Protocol.Default; SigningBindingElement sbe = new ProviderSigningBindingElement(new ProviderAssociationHandleEncoder(new MemoryCryptoKeyStore()), new ProviderSecuritySettings()); sbe.Channel = new TestChannel(this.MessageDescriptions); @@ -60,7 +63,7 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { response.ExtraData["someunsigned"] = "value"; response.ExtraData["openid.somesigned"] = "value"; - Assert.IsNotNull(sbe.ProcessOutgoingMessage(response)); + Assert.IsNotNull(await sbe.ProcessOutgoingMessageAsync(response, CancellationToken.None)); ITamperResistantOpenIdMessage signedResponse = (ITamperResistantOpenIdMessage)response; // Make sure that the extra parameters are signed. @@ -76,14 +79,14 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { /// Regression test for bug #45 (https://github.com/AArnott/dotnetopenid/issues/45) /// </summary> [Test, ExpectedException(typeof(ProtocolException))] - public void MissingSignedParameter() { + public async Task MissingSignedParameter() { var cryptoStore = new MemoryCryptoKeyStore(); byte[] associationSecret = Convert.FromBase64String("rsSwv1zPWfjPRQU80hciu8FPDC+GONAMJQ/AvSo1a2M="); string handle = "{634477555066085461}{TTYcIg==}{32}"; cryptoStore.StoreKey(ProviderAssociationKeyStorage.PrivateAssociationBucket, handle, new CryptoKey(associationSecret, DateTime.UtcNow.AddDays(1))); var signer = new ProviderSigningBindingElement(new ProviderAssociationKeyStorage(cryptoStore), new ProviderSecuritySettings()); - var testChannel = new TestChannel(new OpenIdProviderMessageFactory()); + var testChannel = new TestChannel(new OpenIdProviderMessageFactory(), new IChannelBindingElement[0], this.HostFactories); signer.Channel = testChannel; var buggyRPMessage = new Dictionary<string, string>() { @@ -101,7 +104,7 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { }; var message = (CheckAuthenticationRequest)testChannel.Receive(buggyRPMessage, new MessageReceivingEndpoint(OPUri, HttpDeliveryMethods.PostRequest)); var originalResponse = new IndirectSignedResponse(message, signer.Channel); - signer.ProcessIncomingMessage(originalResponse); + await signer.ProcessIncomingMessageAsync(originalResponse, CancellationToken.None); } } } diff --git a/src/DotNetOpenAuth.Test/OpenId/DiscoveryServices/UriDiscoveryServiceTests.cs b/src/DotNetOpenAuth.Test/OpenId/DiscoveryServices/UriDiscoveryServiceTests.cs index 88ad208..3c54d98 100644 --- a/src/DotNetOpenAuth.Test/OpenId/DiscoveryServices/UriDiscoveryServiceTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/DiscoveryServices/UriDiscoveryServiceTests.cs @@ -9,48 +9,55 @@ namespace DotNetOpenAuth.Test.OpenId.DiscoveryServices { using System.Collections.Generic; using System.Linq; using System.Net; + using System.Net.Http; using System.Text; + using System.Threading; + using System.Threading.Tasks; using System.Web; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId; using DotNetOpenAuth.OpenId.Extensions.SimpleRegistration; using DotNetOpenAuth.OpenId.RelyingParty; + using DotNetOpenAuth.Test.Mocks; + using NUnit.Framework; [TestFixture] public class UriDiscoveryServiceTests : OpenIdTestBase { [Test] - public void DiscoveryWithRedirects() { + public async Task DiscoveryWithRedirects() { Identifier claimedId = this.GetMockIdentifier(ProtocolVersion.V20, false); // Add a couple of chained redirect pages that lead to the claimedId. Uri userSuppliedUri = new Uri("https://localhost/someSecurePage"); Uri insecureMidpointUri = new Uri("http://localhost/insecureStop"); - this.MockResponder.RegisterMockRedirect(userSuppliedUri, insecureMidpointUri); - this.MockResponder.RegisterMockRedirect(insecureMidpointUri, new Uri(claimedId.ToString())); + this.RegisterMockRedirect(userSuppliedUri, insecureMidpointUri); + this.RegisterMockRedirect(insecureMidpointUri, new Uri(claimedId.ToString())); // don't require secure SSL discovery for this test. Identifier userSuppliedIdentifier = new UriIdentifier(userSuppliedUri, false); - Assert.AreEqual(1, this.Discover(userSuppliedIdentifier).Count()); + var discoveryResult = await this.DiscoverAsync(userSuppliedIdentifier); + Assert.AreEqual(1, discoveryResult.Count()); } [Test] - public void DiscoverRequireSslWithSecureRedirects() { + public async Task DiscoverRequireSslWithSecureRedirects() { Identifier claimedId = this.GetMockIdentifier(ProtocolVersion.V20, true); // Add a couple of chained redirect pages that lead to the claimedId. // All redirects should be secure. Uri userSuppliedUri = new Uri("https://localhost/someSecurePage"); Uri secureMidpointUri = new Uri("https://localhost/secureStop"); - this.MockResponder.RegisterMockRedirect(userSuppliedUri, secureMidpointUri); - this.MockResponder.RegisterMockRedirect(secureMidpointUri, new Uri(claimedId.ToString())); + this.RegisterMockRedirect(userSuppliedUri, secureMidpointUri); + this.RegisterMockRedirect(secureMidpointUri, new Uri(claimedId.ToString())); Identifier userSuppliedIdentifier = new UriIdentifier(userSuppliedUri, true); - Assert.AreEqual(1, this.Discover(userSuppliedIdentifier).Count()); + var discoveryResult = await this.DiscoverAsync(userSuppliedIdentifier); + Assert.AreEqual(1, discoveryResult.Count()); } [Test, ExpectedException(typeof(ProtocolException))] - public void DiscoverRequireSslWithInsecureRedirect() { + public async Task DiscoverRequireSslWithInsecureRedirect() { Identifier claimedId = this.GetMockIdentifier(ProtocolVersion.V20, true); // Add a couple of chained redirect pages that lead to the claimedId. @@ -58,41 +65,43 @@ namespace DotNetOpenAuth.Test.OpenId.DiscoveryServices { // the ultimate endpoint is never found as a result of high security profile. Uri userSuppliedUri = new Uri("https://localhost/someSecurePage"); Uri insecureMidpointUri = new Uri("http://localhost/insecureStop"); - this.MockResponder.RegisterMockRedirect(userSuppliedUri, insecureMidpointUri); - this.MockResponder.RegisterMockRedirect(insecureMidpointUri, new Uri(claimedId.ToString())); + this.RegisterMockRedirect(userSuppliedUri, insecureMidpointUri); + this.RegisterMockRedirect(insecureMidpointUri, new Uri(claimedId.ToString())); Identifier userSuppliedIdentifier = new UriIdentifier(userSuppliedUri, true); - this.Discover(userSuppliedIdentifier); + await this.DiscoverAsync(userSuppliedIdentifier); } [Test] - public void DiscoveryRequireSslWithInsecureXrdsInSecureHtmlHead() { + public async Task DiscoveryRequireSslWithInsecureXrdsInSecureHtmlHead() { var insecureXrdsSource = this.GetMockIdentifier(ProtocolVersion.V20, false); Uri secureClaimedUri = new Uri("https://localhost/secureId"); string html = string.Format("<html><head><meta http-equiv='X-XRDS-Location' content='{0}'/></head><body></body></html>", insecureXrdsSource); - this.MockResponder.RegisterMockResponse(secureClaimedUri, "text/html", html); + this.RegisterMockResponse(secureClaimedUri, "text/html", html); Identifier userSuppliedIdentifier = new UriIdentifier(secureClaimedUri, true); - Assert.AreEqual(0, this.Discover(userSuppliedIdentifier).Count()); + var discoveryResult = await this.DiscoverAsync(userSuppliedIdentifier); + Assert.AreEqual(0, discoveryResult.Count()); } [Test] - public void DiscoveryRequireSslWithInsecureXrdsInSecureHttpHeader() { + public async Task DiscoveryRequireSslWithInsecureXrdsInSecureHttpHeader() { var insecureXrdsSource = this.GetMockIdentifier(ProtocolVersion.V20, false); string html = "<html><head></head><body></body></html>"; WebHeaderCollection headers = new WebHeaderCollection { { "X-XRDS-Location", insecureXrdsSource } }; - this.MockResponder.RegisterMockResponse(VanityUriSsl, VanityUriSsl, "text/html", headers, html); + this.RegisterMockResponse(VanityUriSsl, VanityUriSsl, "text/html", headers, html); Identifier userSuppliedIdentifier = new UriIdentifier(VanityUriSsl, true); - Assert.AreEqual(0, this.Discover(userSuppliedIdentifier).Count()); + var discoveryResult = await this.DiscoverAsync(userSuppliedIdentifier); + Assert.AreEqual(0, discoveryResult.Count()); } [Test] - public void DiscoveryRequireSslWithInsecureXrdsButSecureLinkTags() { + public async Task DiscoveryRequireSslWithInsecureXrdsButSecureLinkTags() { var insecureXrdsSource = this.GetMockIdentifier(ProtocolVersion.V20, false); string html = string.Format( @" @@ -104,104 +113,109 @@ namespace DotNetOpenAuth.Test.OpenId.DiscoveryServices { HttpUtility.HtmlEncode(insecureXrdsSource), HttpUtility.HtmlEncode(OPUriSsl.AbsoluteUri), HttpUtility.HtmlEncode(OPLocalIdentifiersSsl[1].AbsoluteUri)); - this.MockResponder.RegisterMockResponse(VanityUriSsl, "text/html", html); + this.Handle(VanityUriSsl).By(html, "text/html"); Identifier userSuppliedIdentifier = new UriIdentifier(VanityUriSsl, true); // We verify that the XRDS was ignored and the LINK tags were used // because the XRDS OP-LocalIdentifier uses different local identifiers. - Assert.AreEqual(OPLocalIdentifiersSsl[1].AbsoluteUri, this.Discover(userSuppliedIdentifier).Single().ProviderLocalIdentifier.ToString()); + var discoveryResult = await this.DiscoverAsync(userSuppliedIdentifier); + Assert.AreEqual(OPLocalIdentifiersSsl[1].AbsoluteUri, discoveryResult.Single().ProviderLocalIdentifier.ToString()); } [Test] - public void DiscoveryRequiresSslIgnoresInsecureEndpointsInXrds() { + public async Task DiscoveryRequiresSslIgnoresInsecureEndpointsInXrds() { var insecureEndpoint = GetServiceEndpoint(0, ProtocolVersion.V20, 10, false); var secureEndpoint = GetServiceEndpoint(1, ProtocolVersion.V20, 20, true); UriIdentifier secureClaimedId = new UriIdentifier(VanityUriSsl, true); - this.MockResponder.RegisterMockXrdsResponse(secureClaimedId, new IdentifierDiscoveryResult[] { insecureEndpoint, secureEndpoint }); - Assert.AreEqual(secureEndpoint.ProviderLocalIdentifier, this.Discover(secureClaimedId).Single().ProviderLocalIdentifier); + this.RegisterMockXrdsResponse(secureClaimedId, new[] { insecureEndpoint, secureEndpoint }); + var discoverResult = await this.DiscoverAsync(secureClaimedId); + Assert.AreEqual(secureEndpoint.ProviderLocalIdentifier, discoverResult.Single().ProviderLocalIdentifier); } [Test] - public void XrdsDirectDiscovery_10() { - this.FailDiscoverXrds("xrds-irrelevant"); - this.DiscoverXrds("xrds10", ProtocolVersion.V10, null, "http://a/b"); - this.DiscoverXrds("xrds11", ProtocolVersion.V11, null, "http://a/b"); - this.DiscoverXrds("xrds1020", ProtocolVersion.V10, null, "http://a/b"); + public async Task XrdsDirectDiscovery_10() { + await this.FailDiscoverXrdsAsync("xrds-irrelevant"); + await this.DiscoverXrdsAsync("xrds10", ProtocolVersion.V10, null, "http://a/b"); + await this.DiscoverXrdsAsync("xrds11", ProtocolVersion.V11, null, "http://a/b"); + await this.DiscoverXrdsAsync("xrds1020", ProtocolVersion.V10, null, "http://a/b"); } [Test] - public void XrdsDirectDiscovery_20() { - this.DiscoverXrds("xrds20", ProtocolVersion.V20, null, "http://a/b"); - this.DiscoverXrds("xrds2010a", ProtocolVersion.V20, null, "http://a/b"); - this.DiscoverXrds("xrds2010b", ProtocolVersion.V20, null, "http://a/b"); + public async Task XrdsDirectDiscovery_20() { + await this.DiscoverXrdsAsync("xrds20", ProtocolVersion.V20, null, "http://a/b"); + await this.DiscoverXrdsAsync("xrds2010a", ProtocolVersion.V20, null, "http://a/b"); + await this.DiscoverXrdsAsync("xrds2010b", ProtocolVersion.V20, null, "http://a/b"); } [Test] - public void HtmlDiscover_11() { - this.DiscoverHtml("html10prov", ProtocolVersion.V11, null, "http://a/b"); - this.DiscoverHtml("html10both", ProtocolVersion.V11, "http://c/d", "http://a/b"); - this.FailDiscoverHtml("html10del"); + public async Task HtmlDiscover_11() { + await this.DiscoverHtmlAsync("html10prov", ProtocolVersion.V11, null, "http://a/b"); + await this.DiscoverHtmlAsync("html10both", ProtocolVersion.V11, "http://c/d", "http://a/b"); + await this.FailDiscoverHtmlAsync("html10del"); // Verify that HTML discovery generates the 1.x endpoints when appropriate - this.DiscoverHtml("html2010", ProtocolVersion.V11, "http://g/h", "http://e/f"); - this.DiscoverHtml("html1020", ProtocolVersion.V11, "http://g/h", "http://e/f"); - this.DiscoverHtml("html2010combinedA", ProtocolVersion.V11, "http://c/d", "http://a/b"); - this.DiscoverHtml("html2010combinedB", ProtocolVersion.V11, "http://c/d", "http://a/b"); - this.DiscoverHtml("html2010combinedC", ProtocolVersion.V11, "http://c/d", "http://a/b"); + await this.DiscoverHtmlAsync("html2010", ProtocolVersion.V11, "http://g/h", "http://e/f"); + await this.DiscoverHtmlAsync("html1020", ProtocolVersion.V11, "http://g/h", "http://e/f"); + await this.DiscoverHtmlAsync("html2010combinedA", ProtocolVersion.V11, "http://c/d", "http://a/b"); + await this.DiscoverHtmlAsync("html2010combinedB", ProtocolVersion.V11, "http://c/d", "http://a/b"); + await this.DiscoverHtmlAsync("html2010combinedC", ProtocolVersion.V11, "http://c/d", "http://a/b"); } [Test] - public void HtmlDiscover_20() { - this.DiscoverHtml("html20prov", ProtocolVersion.V20, null, "http://a/b"); - this.DiscoverHtml("html20both", ProtocolVersion.V20, "http://c/d", "http://a/b"); - this.FailDiscoverHtml("html20del"); - this.DiscoverHtml("html2010", ProtocolVersion.V20, "http://c/d", "http://a/b"); - this.DiscoverHtml("html1020", ProtocolVersion.V20, "http://c/d", "http://a/b"); - this.DiscoverHtml("html2010combinedA", ProtocolVersion.V20, "http://c/d", "http://a/b"); - this.DiscoverHtml("html2010combinedB", ProtocolVersion.V20, "http://c/d", "http://a/b"); - this.DiscoverHtml("html2010combinedC", ProtocolVersion.V20, "http://c/d", "http://a/b"); - this.FailDiscoverHtml("html20relative"); + public async Task HtmlDiscover_20() { + await this.DiscoverHtmlAsync("html20prov", ProtocolVersion.V20, null, "http://a/b"); + await this.DiscoverHtmlAsync("html20both", ProtocolVersion.V20, "http://c/d", "http://a/b"); + await this.FailDiscoverHtmlAsync("html20del"); + await this.DiscoverHtmlAsync("html2010", ProtocolVersion.V20, "http://c/d", "http://a/b"); + await this.DiscoverHtmlAsync("html1020", ProtocolVersion.V20, "http://c/d", "http://a/b"); + await this.DiscoverHtmlAsync("html2010combinedA", ProtocolVersion.V20, "http://c/d", "http://a/b"); + await this.DiscoverHtmlAsync("html2010combinedB", ProtocolVersion.V20, "http://c/d", "http://a/b"); + await this.DiscoverHtmlAsync("html2010combinedC", ProtocolVersion.V20, "http://c/d", "http://a/b"); + await this.FailDiscoverHtmlAsync("html20relative"); } [Test] - public void XrdsDiscoveryFromHead() { - this.MockResponder.RegisterMockResponse(new Uri("http://localhost/xrds1020.xml"), "application/xrds+xml", LoadEmbeddedFile("/Discovery/xrdsdiscovery/xrds1020.xml")); - this.DiscoverXrds("XrdsReferencedInHead.html", ProtocolVersion.V10, null, "http://a/b"); + public async Task XrdsDiscoveryFromHead() { + this.RegisterMockResponse( + new Uri("http://localhost/xrds1020.xml"), + "application/xrds+xml", + LoadEmbeddedFile("/Discovery/xrdsdiscovery/xrds1020.xml")); + await this.DiscoverXrdsAsync("XrdsReferencedInHead.html", ProtocolVersion.V10, null, "http://a/b"); } [Test] - public void XrdsDiscoveryFromHttpHeader() { + public async Task XrdsDiscoveryFromHttpHeader() { WebHeaderCollection headers = new WebHeaderCollection(); headers.Add("X-XRDS-Location", new Uri("http://localhost/xrds1020.xml").AbsoluteUri); - this.MockResponder.RegisterMockResponse(new Uri("http://localhost/xrds1020.xml"), "application/xrds+xml", LoadEmbeddedFile("/Discovery/xrdsdiscovery/xrds1020.xml")); - this.DiscoverXrds("XrdsReferencedInHttpHeader.html", ProtocolVersion.V10, null, "http://a/b", headers); + this.RegisterMockResponse(new Uri("http://localhost/xrds1020.xml"), "application/xrds+xml", LoadEmbeddedFile("/Discovery/xrdsdiscovery/xrds1020.xml")); + await this.DiscoverXrdsAsync("XrdsReferencedInHttpHeader.html", ProtocolVersion.V10, null, "http://a/b", headers); } /// <summary> /// Verifies HTML discovery proceeds if an XRDS document is referenced that doesn't contain OpenID endpoints. /// </summary> [Test] - public void HtmlDiscoveryProceedsIfXrdsIsEmpty() { - this.MockResponder.RegisterMockResponse(new Uri("http://localhost/xrds-irrelevant.xml"), "application/xrds+xml", LoadEmbeddedFile("/Discovery/xrdsdiscovery/xrds-irrelevant.xml")); - this.DiscoverHtml("html20provWithEmptyXrds", ProtocolVersion.V20, null, "http://a/b"); + public async Task HtmlDiscoveryProceedsIfXrdsIsEmpty() { + this.RegisterMockResponse(new Uri("http://localhost/xrds-irrelevant.xml"), "application/xrds+xml", LoadEmbeddedFile("/Discovery/xrdsdiscovery/xrds-irrelevant.xml")); + await this.DiscoverHtmlAsync("html20provWithEmptyXrds", ProtocolVersion.V20, null, "http://a/b"); } /// <summary> /// Verifies HTML discovery proceeds if the XRDS that is referenced cannot be found. /// </summary> [Test] - public void HtmlDiscoveryProceedsIfXrdsIsBadOrMissing() { - this.DiscoverHtml("html20provWithBadXrds", ProtocolVersion.V20, null, "http://a/b"); + public async Task HtmlDiscoveryProceedsIfXrdsIsBadOrMissing() { + await this.DiscoverHtmlAsync("html20provWithBadXrds", ProtocolVersion.V20, null, "http://a/b"); } /// <summary> /// Verifies that a dual identifier yields only one service endpoint by default. /// </summary> [Test] - public void DualIdentifierOffByDefault() { - this.MockResponder.RegisterMockResponse(VanityUri, "application/xrds+xml", LoadEmbeddedFile("/Discovery/xrdsdiscovery/xrds20dual.xml")); - var results = this.Discover(VanityUri).ToList(); + public async Task DualIdentifierOffByDefault() { + this.RegisterMockResponse(VanityUri, "application/xrds+xml", LoadEmbeddedFile("/Discovery/xrdsdiscovery/xrds20dual.xml")); + var results = (await this.DiscoverAsync(VanityUri)).ToList(); Assert.AreEqual(1, results.Count(r => r.ClaimedIdentifier == r.Protocol.ClaimedIdentifierForOPIdentifier), "OP Identifier missing from discovery results."); Assert.AreEqual(1, results.Count, "Unexpected additional services discovered."); } @@ -210,26 +224,39 @@ namespace DotNetOpenAuth.Test.OpenId.DiscoveryServices { /// Verifies that a dual identifier yields two service endpoints when that feature is turned on. /// </summary> [Test] - public void DualIdentifier() { - this.MockResponder.RegisterMockResponse(VanityUri, "application/xrds+xml", LoadEmbeddedFile("/Discovery/xrdsdiscovery/xrds20dual.xml")); + public async Task DualIdentifier() { + this.RegisterMockResponse(VanityUri, "application/xrds+xml", LoadEmbeddedFile("/Discovery/xrdsdiscovery/xrds20dual.xml")); var rp = this.CreateRelyingParty(true); - rp.Channel.WebRequestHandler = this.RequestHandler; rp.SecuritySettings.AllowDualPurposeIdentifiers = true; - var results = rp.Discover(VanityUri).ToList(); + var results = (await rp.DiscoverAsync(VanityUri, CancellationToken.None)).ToList(); Assert.AreEqual(1, results.Count(r => r.ClaimedIdentifier == r.Protocol.ClaimedIdentifierForOPIdentifier), "OP Identifier missing from discovery results."); Assert.AreEqual(1, results.Count(r => r.ClaimedIdentifier == VanityUri), "Claimed identifier missing from discovery results."); Assert.AreEqual(2, results.Count, "Unexpected additional services discovered."); } - private void Discover(string url, ProtocolVersion version, Identifier expectedLocalId, string providerEndpoint, bool expectSreg, bool useRedirect) { - this.Discover(url, version, expectedLocalId, providerEndpoint, expectSreg, useRedirect, null); + private async Task DiscoverAsync(string url, ProtocolVersion version, Identifier expectedLocalId, string providerEndpoint, bool expectSreg, bool useRedirect) { + await this.DiscoverAsync(url, version, expectedLocalId, providerEndpoint, expectSreg, useRedirect, null); + } + + private string RegisterDiscoveryRedirector(Uri baseUrl) { + var redirectorUrl = new Uri(baseUrl, "Discovery/htmldiscovery/redirect.aspx"); + this.Handle(redirectorUrl).By(req => { + string redirectTarget = HttpUtility.ParseQueryString(req.RequestUri.Query)["target"]; + var response = new HttpResponseMessage(HttpStatusCode.Redirect); + response.Headers.Location = new Uri(redirectTarget, UriKind.RelativeOrAbsolute); + response.RequestMessage = req; + return response; + }); + + return redirectorUrl.AbsoluteUri + "?target="; } - private void Discover(string url, ProtocolVersion version, Identifier expectedLocalId, string providerEndpoint, bool expectSreg, bool useRedirect, WebHeaderCollection headers) { + private async Task DiscoverAsync(string url, ProtocolVersion version, Identifier expectedLocalId, string providerEndpoint, bool expectSreg, bool useRedirect, WebHeaderCollection headers) { Protocol protocol = Protocol.Lookup(version); Uri baseUrl = new Uri("http://localhost/"); + string redirectBase = this.RegisterDiscoveryRedirector(baseUrl); UriIdentifier claimedId = new Uri(baseUrl, url); - UriIdentifier userSuppliedIdentifier = new Uri(baseUrl, "Discovery/htmldiscovery/redirect.aspx?target=" + url); + UriIdentifier userSuppliedIdentifier = new Uri(redirectBase + Uri.EscapeDataString(url)); if (expectedLocalId == null) { expectedLocalId = claimedId; } @@ -243,7 +270,7 @@ namespace DotNetOpenAuth.Test.OpenId.DiscoveryServices { } else { throw new InvalidOperationException(); } - this.MockResponder.RegisterMockResponse(new Uri(idToDiscover), claimedId, contentType, headers ?? new WebHeaderCollection(), LoadEmbeddedFile(url)); + this.RegisterMockResponse(claimedId, claimedId, contentType, headers ?? new WebHeaderCollection(), LoadEmbeddedFile(url)); IdentifierDiscoveryResult expected = IdentifierDiscoveryResult.CreateForClaimedIdentifier( claimedId, @@ -252,7 +279,8 @@ namespace DotNetOpenAuth.Test.OpenId.DiscoveryServices { null, null); - IdentifierDiscoveryResult se = this.Discover(idToDiscover).FirstOrDefault(ep => ep.Equals(expected)); + var discoveryResult = await this.DiscoverAsync(idToDiscover); + IdentifierDiscoveryResult se = discoveryResult.FirstOrDefault(ep => ep.Equals(expected)); Assert.IsNotNull(se, url + " failed to be discovered."); // Do extra checking of service type URIs, which aren't included in @@ -262,42 +290,43 @@ namespace DotNetOpenAuth.Test.OpenId.DiscoveryServices { Assert.AreEqual(expectSreg, se.IsExtensionSupported<ClaimsRequest>()); } - private void DiscoverXrds(string page, ProtocolVersion version, Identifier expectedLocalId, string providerEndpoint) { - this.DiscoverXrds(page, version, expectedLocalId, providerEndpoint, null); + private async Task DiscoverXrdsAsync(string page, ProtocolVersion version, Identifier expectedLocalId, string providerEndpoint) { + await this.DiscoverXrdsAsync(page, version, expectedLocalId, providerEndpoint, null); } - private void DiscoverXrds(string page, ProtocolVersion version, Identifier expectedLocalId, string providerEndpoint, WebHeaderCollection headers) { + private async Task DiscoverXrdsAsync(string page, ProtocolVersion version, Identifier expectedLocalId, string providerEndpoint, WebHeaderCollection headers) { if (!page.Contains(".")) { page += ".xml"; } - this.Discover("/Discovery/xrdsdiscovery/" + page, version, expectedLocalId, providerEndpoint, true, false, headers); - this.Discover("/Discovery/xrdsdiscovery/" + page, version, expectedLocalId, providerEndpoint, true, true, headers); + await this.DiscoverAsync("/Discovery/xrdsdiscovery/" + page, version, expectedLocalId, providerEndpoint, true, false, headers); + await this.DiscoverAsync("/Discovery/xrdsdiscovery/" + page, version, expectedLocalId, providerEndpoint, true, true, headers); } - private void DiscoverHtml(string page, ProtocolVersion version, Identifier expectedLocalId, string providerEndpoint, bool useRedirect) { - this.Discover("/Discovery/htmldiscovery/" + page, version, expectedLocalId, providerEndpoint, false, useRedirect); + private async Task DiscoverHtmlAsync(string page, ProtocolVersion version, Identifier expectedLocalId, string providerEndpoint, bool useRedirect) { + await this.DiscoverAsync("/Discovery/htmldiscovery/" + page, version, expectedLocalId, providerEndpoint, false, useRedirect); } - private void DiscoverHtml(string scenario, ProtocolVersion version, Identifier expectedLocalId, string providerEndpoint) { + private async Task DiscoverHtmlAsync(string scenario, ProtocolVersion version, Identifier expectedLocalId, string providerEndpoint) { string page = scenario + ".html"; - this.DiscoverHtml(page, version, expectedLocalId, providerEndpoint, false); - this.DiscoverHtml(page, version, expectedLocalId, providerEndpoint, true); + await this.DiscoverHtmlAsync(page, version, expectedLocalId, providerEndpoint, false); + await this.DiscoverHtmlAsync(page, version, expectedLocalId, providerEndpoint, true); } - private void FailDiscover(string url) { + private async Task FailDiscoverAsync(string url) { UriIdentifier userSuppliedId = new Uri(new Uri("http://localhost"), url); - this.MockResponder.RegisterMockResponse(new Uri(userSuppliedId), userSuppliedId, "text/html", LoadEmbeddedFile(url)); + this.RegisterMockResponse(new Uri(userSuppliedId), userSuppliedId, "text/html", LoadEmbeddedFile(url)); - Assert.AreEqual(0, this.Discover(userSuppliedId).Count()); // ... but that no endpoint info is discoverable + var discoveryResult = await this.DiscoverAsync(userSuppliedId); + Assert.AreEqual(0, discoveryResult.Count()); // ... but that no endpoint info is discoverable } - private void FailDiscoverHtml(string scenario) { - this.FailDiscover("/Discovery/htmldiscovery/" + scenario + ".html"); + private async Task FailDiscoverHtmlAsync(string scenario) { + await this.FailDiscoverAsync("/Discovery/htmldiscovery/" + scenario + ".html"); } - private void FailDiscoverXrds(string scenario) { - this.FailDiscover("/Discovery/xrdsdiscovery/" + scenario + ".xml"); + private async Task FailDiscoverXrdsAsync(string scenario) { + await this.FailDiscoverAsync("/Discovery/xrdsdiscovery/" + scenario + ".xml"); } } } diff --git a/src/DotNetOpenAuth.Test/OpenId/DiscoveryServices/XriDiscoveryProxyServiceTests.cs b/src/DotNetOpenAuth.Test/OpenId/DiscoveryServices/XriDiscoveryProxyServiceTests.cs index fe767ea..23ddbfe 100644 --- a/src/DotNetOpenAuth.Test/OpenId/DiscoveryServices/XriDiscoveryProxyServiceTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/DiscoveryServices/XriDiscoveryProxyServiceTests.cs @@ -9,14 +9,16 @@ namespace DotNetOpenAuth.Test.OpenId.DiscoveryServices { using System.Collections.Generic; using System.Linq; using System.Text; + using System.Threading.Tasks; using DotNetOpenAuth.OpenId; using DotNetOpenAuth.OpenId.RelyingParty; + using DotNetOpenAuth.Test.Mocks; using NUnit.Framework; [TestFixture] public class XriDiscoveryProxyServiceTests : OpenIdTestBase { [Test] - public void Discover() { + public async Task Discover() { string xrds = @"<?xml version='1.0' encoding='UTF-8'?> <XRD version='2.0' xmlns='xri://$xrd*($v*2.0)'> <Query>*Arnott</Query> @@ -48,14 +50,14 @@ namespace DotNetOpenAuth.Test.OpenId.DiscoveryServices { <URI append='none' priority='10'>http://1id.com/sso</URI> </Service> </XRD>"; - Dictionary<string, string> mocks = new Dictionary<string, string> { + var mocks = new Dictionary<string, string> { { "https://xri.net/=Arnott?_xrd_r=application/xrd%2Bxml;sep=false", xrds }, { "https://xri.net/=!9B72.7DD1.50A9.5CCD?_xrd_r=application/xrd%2Bxml;sep=false", xrds }, }; - this.MockResponder.RegisterMockXrdsResponses(mocks); + this.RegisterMockXrdsResponses(mocks); string expectedCanonicalId = "=!9B72.7DD1.50A9.5CCD"; - IdentifierDiscoveryResult se = this.VerifyCanonicalId("=Arnott", expectedCanonicalId); + IdentifierDiscoveryResult se = await this.VerifyCanonicalIdAsync("=Arnott", expectedCanonicalId); Assert.AreEqual(Protocol.V10, Protocol.Lookup(se.Version)); Assert.AreEqual("http://1id.com/sso", se.ProviderEndpoint.ToString()); Assert.AreEqual(se.ClaimedIdentifier, se.ProviderLocalIdentifier); @@ -63,7 +65,7 @@ namespace DotNetOpenAuth.Test.OpenId.DiscoveryServices { } [Test] - public void DiscoverCommunityInameCanonicalIDs() { + public async Task DiscoverCommunityInameCanonicalIDs() { string llliResponse = @"<?xml version='1.0' encoding='UTF-8'?> <XRD version='2.0' xmlns='xri://$xrd*($v*2.0)'> <Query>*llli</Query> @@ -278,23 +280,23 @@ uEyb50RJ7DWmXctSC0b3eymZ2lSXxAWNOsNy </X509Data> </KeyInfo> </XRD>"; - this.MockResponder.RegisterMockXrdsResponses(new Dictionary<string, string> { + this.RegisterMockXrdsResponses(new Dictionary<string, string> { { "https://xri.net/@llli?_xrd_r=application/xrd%2Bxml;sep=false", llliResponse }, { "https://xri.net/@llli*area?_xrd_r=application/xrd%2Bxml;sep=false", llliAreaResponse }, { "https://xri.net/@llli*area*canada.unattached?_xrd_r=application/xrd%2Bxml;sep=false", llliAreaCanadaUnattachedResponse }, { "https://xri.net/@llli*area*canada.unattached*ada?_xrd_r=application/xrd%2Bxml;sep=false", llliAreaCanadaUnattachedAdaResponse }, { "https://xri.net/=Web?_xrd_r=application/xrd%2Bxml;sep=false", webResponse }, }); - this.VerifyCanonicalId("@llli", "@!72CD.A072.157E.A9C6"); - this.VerifyCanonicalId("@llli*area", "@!72CD.A072.157E.A9C6!0000.0000.3B9A.CA0C"); - this.VerifyCanonicalId("@llli*area*canada.unattached", "@!72CD.A072.157E.A9C6!0000.0000.3B9A.CA0C!0000.0000.3B9A.CA41"); - this.VerifyCanonicalId("@llli*area*canada.unattached*ada", "@!72CD.A072.157E.A9C6!0000.0000.3B9A.CA0C!0000.0000.3B9A.CA41!0000.0000.3B9A.CA01"); - this.VerifyCanonicalId("=Web", "=!91F2.8153.F600.AE24"); + await this.VerifyCanonicalIdAsync("@llli", "@!72CD.A072.157E.A9C6"); + await this.VerifyCanonicalIdAsync("@llli*area", "@!72CD.A072.157E.A9C6!0000.0000.3B9A.CA0C"); + await this.VerifyCanonicalIdAsync("@llli*area*canada.unattached", "@!72CD.A072.157E.A9C6!0000.0000.3B9A.CA0C!0000.0000.3B9A.CA41"); + await this.VerifyCanonicalIdAsync("@llli*area*canada.unattached*ada", "@!72CD.A072.157E.A9C6!0000.0000.3B9A.CA0C!0000.0000.3B9A.CA41!0000.0000.3B9A.CA01"); + await this.VerifyCanonicalIdAsync("=Web", "=!91F2.8153.F600.AE24"); } [Test] - public void DiscoveryCommunityInameDelegateWithoutCanonicalID() { - this.MockResponder.RegisterMockXrdsResponses(new Dictionary<string, string> { + public async Task DiscoveryCommunityInameDelegateWithoutCanonicalID() { + this.RegisterMockXrdsResponses(new Dictionary<string, string> { { "https://xri.net/=Web*andrew.arnott?_xrd_r=application/xrd%2Bxml;sep=false", @"<?xml version='1.0' encoding='UTF-8'?> <XRD xmlns='xri://$xrd*($v*2.0)'> <Query>*andrew.arnott</Query> @@ -375,12 +377,12 @@ uEyb50RJ7DWmXctSC0b3eymZ2lSXxAWNOsNy }); // Consistent with spec section 7.3.2.3, we do not permit // delegation on XRI discovery when there is no CanonicalID present. - this.VerifyCanonicalId("=Web*andrew.arnott", null); - this.VerifyCanonicalId("@id*andrewarnott", null); + await this.VerifyCanonicalIdAsync("=Web*andrew.arnott", null); + await this.VerifyCanonicalIdAsync("@id*andrewarnott", null); } - private IdentifierDiscoveryResult VerifyCanonicalId(Identifier iname, string expectedClaimedIdentifier) { - var se = this.Discover(iname).FirstOrDefault(); + private async Task<IdentifierDiscoveryResult> VerifyCanonicalIdAsync(Identifier iname, string expectedClaimedIdentifier) { + var se = (await this.DiscoverAsync(iname)).FirstOrDefault(); if (expectedClaimedIdentifier != null) { Assert.IsNotNull(se); Assert.AreEqual(expectedClaimedIdentifier, se.ClaimedIdentifier.ToString(), "i-name {0} discovery resulted in unexpected CanonicalId", iname); diff --git a/src/DotNetOpenAuth.Test/OpenId/Extensions/AttributeExchange/AttributeExchangeRoundtripTests.cs b/src/DotNetOpenAuth.Test/OpenId/Extensions/AttributeExchange/AttributeExchangeRoundtripTests.cs index ab0a10b..0d0d36c 100644 --- a/src/DotNetOpenAuth.Test/OpenId/Extensions/AttributeExchange/AttributeExchangeRoundtripTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/Extensions/AttributeExchange/AttributeExchangeRoundtripTests.cs @@ -5,6 +5,8 @@ //----------------------------------------------------------------------- namespace DotNetOpenAuth.Test.OpenId.Extensions { + using System.Threading.Tasks; + using DotNetOpenAuth.OpenId; using DotNetOpenAuth.OpenId.Extensions.AttributeExchange; using NUnit.Framework; @@ -17,7 +19,7 @@ namespace DotNetOpenAuth.Test.OpenId.Extensions { private int incrementingAttributeValue = 1; [Test] - public void Fetch() { + public async Task Fetch() { var request = new FetchRequest(); request.Attributes.Add(new AttributeRequest(NicknameTypeUri)); request.Attributes.Add(new AttributeRequest(EmailTypeUri, false, int.MaxValue)); @@ -26,11 +28,11 @@ namespace DotNetOpenAuth.Test.OpenId.Extensions { response.Attributes.Add(new AttributeValues(NicknameTypeUri, "Andrew")); response.Attributes.Add(new AttributeValues(EmailTypeUri, "a@a.com", "b@b.com")); - ExtensionTestUtilities.Roundtrip(Protocol.Default, new[] { request }, new[] { response }); + await this.RoundtripAsync(Protocol.Default, new[] { request }, new[] { response }); } [Test] - public void Store() { + public async Task Store() { var request = new StoreRequest(); var newAttribute = new AttributeValues( IncrementingAttribute, @@ -41,13 +43,13 @@ namespace DotNetOpenAuth.Test.OpenId.Extensions { var successResponse = new StoreResponse(); successResponse.Succeeded = true; - ExtensionTestUtilities.Roundtrip(Protocol.Default, new[] { request }, new[] { successResponse }); + await this.RoundtripAsync(Protocol.Default, new[] { request }, new[] { successResponse }); var failureResponse = new StoreResponse(); failureResponse.Succeeded = false; failureResponse.FailureReason = "Some error"; - ExtensionTestUtilities.Roundtrip(Protocol.Default, new[] { request }, new[] { failureResponse }); + await this.RoundtripAsync(Protocol.Default, new[] { request }, new[] { failureResponse }); } } } diff --git a/src/DotNetOpenAuth.Test/OpenId/Extensions/ExtensionTestUtilities.cs b/src/DotNetOpenAuth.Test/OpenId/Extensions/ExtensionTestUtilities.cs index 8d0e6ff..fd53fd1 100644 --- a/src/DotNetOpenAuth.Test/OpenId/Extensions/ExtensionTestUtilities.cs +++ b/src/DotNetOpenAuth.Test/OpenId/Extensions/ExtensionTestUtilities.cs @@ -8,6 +8,10 @@ namespace DotNetOpenAuth.Test.OpenId.Extensions { using System; using System.Collections.Generic; using System.Linq; + using System.Net.Http; + using System.Threading; + using System.Threading.Tasks; + using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Bindings; using DotNetOpenAuth.OpenId; @@ -20,62 +24,6 @@ namespace DotNetOpenAuth.Test.OpenId.Extensions { using Validation; public static class ExtensionTestUtilities { - /// <summary> - /// Simulates an extension request and response. - /// </summary> - /// <param name="protocol">The protocol to use in the roundtripping.</param> - /// <param name="requests">The extensions to add to the request message.</param> - /// <param name="responses">The extensions to add to the response message.</param> - /// <remarks> - /// This method relies on the extension objects' Equals methods to verify - /// accurate transport. The Equals methods should be verified by separate tests. - /// </remarks> - internal static void Roundtrip( - Protocol protocol, - IEnumerable<IOpenIdMessageExtension> requests, - IEnumerable<IOpenIdMessageExtension> responses) { - var securitySettings = new ProviderSecuritySettings(); - var cryptoKeyStore = new MemoryCryptoKeyStore(); - var associationStore = new ProviderAssociationHandleEncoder(cryptoKeyStore); - Association association = HmacShaAssociationProvider.Create(protocol, protocol.Args.SignatureAlgorithm.Best, AssociationRelyingPartyType.Smart, associationStore, securitySettings); - var coordinator = new OpenIdCoordinator( - rp => { - RegisterExtension(rp.Channel, Mocks.MockOpenIdExtension.Factory); - var requestBase = new CheckIdRequest(protocol.Version, OpenIdTestBase.OPUri, AuthenticationRequestMode.Immediate); - OpenIdTestBase.StoreAssociation(rp, OpenIdTestBase.OPUri, association); - requestBase.AssociationHandle = association.Handle; - requestBase.ClaimedIdentifier = "http://claimedid"; - requestBase.LocalIdentifier = "http://localid"; - requestBase.ReturnTo = OpenIdTestBase.RPUri; - - foreach (IOpenIdMessageExtension extension in requests) { - requestBase.Extensions.Add(extension); - } - - rp.Channel.Respond(requestBase); - var response = rp.Channel.ReadFromRequest<PositiveAssertionResponse>(); - - var receivedResponses = response.Extensions.Cast<IOpenIdMessageExtension>(); - CollectionAssert<IOpenIdMessageExtension>.AreEquivalentByEquality(responses.ToArray(), receivedResponses.ToArray()); - }, - op => { - RegisterExtension(op.Channel, Mocks.MockOpenIdExtension.Factory); - var key = cryptoKeyStore.GetCurrentKey(ProviderAssociationHandleEncoder.AssociationHandleEncodingSecretBucket, TimeSpan.FromSeconds(1)); - op.CryptoKeyStore.StoreKey(ProviderAssociationHandleEncoder.AssociationHandleEncodingSecretBucket, key.Key, key.Value); - var request = op.Channel.ReadFromRequest<CheckIdRequest>(); - var response = new PositiveAssertionResponse(request); - var receivedRequests = request.Extensions.Cast<IOpenIdMessageExtension>(); - CollectionAssert<IOpenIdMessageExtension>.AreEquivalentByEquality(requests.ToArray(), receivedRequests.ToArray()); - - foreach (var extensionResponse in responses) { - response.Extensions.Add(extensionResponse); - } - - op.Channel.Respond(response); - }); - coordinator.Run(); - } - internal static void RegisterExtension(Channel channel, StandardOpenIdExtensionFactory.CreateDelegate extensionFactory) { Requires.NotNull(channel, "channel"); diff --git a/src/DotNetOpenAuth.Test/OpenId/Extensions/ExtensionsInteropHelperOPTests.cs b/src/DotNetOpenAuth.Test/OpenId/Extensions/ExtensionsInteropHelperOPTests.cs index e9ff7a4..60075f3 100644 --- a/src/DotNetOpenAuth.Test/OpenId/Extensions/ExtensionsInteropHelperOPTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/Extensions/ExtensionsInteropHelperOPTests.cs @@ -7,6 +7,8 @@ namespace DotNetOpenAuth.Test.OpenId.Extensions { using System.Collections.Generic; using System.Linq; + using System.Threading; + using System.Threading.Tasks; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId; using DotNetOpenAuth.OpenId.Extensions; @@ -38,7 +40,7 @@ namespace DotNetOpenAuth.Test.OpenId.Extensions { /// Verifies no extensions appear as no extensions /// </summary> [Test] - public void NoRequestedExtensions() { + public async Task NoRequestedExtensions() { var sreg = ExtensionsInteropHelper.UnifyExtensionsAsSreg(this.request); Assert.IsNull(sreg); @@ -47,22 +49,22 @@ namespace DotNetOpenAuth.Test.OpenId.Extensions { // to directly create a response without a request. var sregResponse = new ClaimsResponse(); this.request.AddResponseExtension(sregResponse); - ExtensionsInteropHelper.ConvertSregToMatchRequest(this.request); - var extensions = this.GetResponseExtensions(); + await ExtensionsInteropHelper.ConvertSregToMatchRequestAsync(this.request, CancellationToken.None); + var extensions = await this.GetResponseExtensionsAsync(); Assert.AreSame(sregResponse, extensions.Single()); } [Test] - public void NegativeResponse() { + public async Task NegativeResponse() { this.request.IsAuthenticated = false; - ExtensionsInteropHelper.ConvertSregToMatchRequest(this.request); + await ExtensionsInteropHelper.ConvertSregToMatchRequestAsync(this.request, CancellationToken.None); } /// <summary> /// Verifies sreg coming in is seen as sreg. /// </summary> [Test] - public void UnifyExtensionsAsSregWithSreg() { + public async Task UnifyExtensionsAsSregWithSreg() { var sregInjected = new ClaimsRequest(DotNetOpenAuth.OpenId.Extensions.SimpleRegistration.Constants.TypeUris.Standard) { Nickname = DemandLevel.Request, }; @@ -74,8 +76,8 @@ namespace DotNetOpenAuth.Test.OpenId.Extensions { var sregResponse = sreg.CreateResponse(); this.request.AddResponseExtension(sregResponse); - ExtensionsInteropHelper.ConvertSregToMatchRequest(this.request); - var extensions = this.GetResponseExtensions(); + await ExtensionsInteropHelper.ConvertSregToMatchRequestAsync(this.request, CancellationToken.None); + var extensions = await this.GetResponseExtensionsAsync(); Assert.AreSame(sregResponse, extensions.Single()); } @@ -83,23 +85,23 @@ namespace DotNetOpenAuth.Test.OpenId.Extensions { /// Verifies AX coming in looks like sreg. /// </summary> [Test] - public void UnifyExtensionsAsSregWithAX() { - this.ParameterizedAXTest(AXAttributeFormats.AXSchemaOrg); + public async Task UnifyExtensionsAsSregWithAX() { + await this.ParameterizedAXTestAsync(AXAttributeFormats.AXSchemaOrg); } /// <summary> /// Verifies AX coming in looks like sreg. /// </summary> [Test] - public void UnifyExtensionsAsSregWithAXSchemaOpenIdNet() { - this.ParameterizedAXTest(AXAttributeFormats.SchemaOpenIdNet); + public async Task UnifyExtensionsAsSregWithAXSchemaOpenIdNet() { + await this.ParameterizedAXTestAsync(AXAttributeFormats.SchemaOpenIdNet); } /// <summary> /// Verifies sreg and AX in one request has a preserved sreg request. /// </summary> [Test] - public void UnifyExtensionsAsSregWithBothSregAndAX() { + public async Task UnifyExtensionsAsSregWithBothSregAndAX() { var sregInjected = new ClaimsRequest(DotNetOpenAuth.OpenId.Extensions.SimpleRegistration.Constants.TypeUris.Standard) { Nickname = DemandLevel.Request, }; @@ -118,20 +120,20 @@ namespace DotNetOpenAuth.Test.OpenId.Extensions { var axResponseInjected = new FetchResponse(); axResponseInjected.Attributes.Add(WellKnownAttributes.Contact.Email, "a@b.com"); this.request.AddResponseExtension(axResponseInjected); - ExtensionsInteropHelper.ConvertSregToMatchRequest(this.request); - var extensions = this.GetResponseExtensions(); + await ExtensionsInteropHelper.ConvertSregToMatchRequestAsync(this.request, CancellationToken.None); + var extensions = await this.GetResponseExtensionsAsync(); var sregResponse = extensions.OfType<ClaimsResponse>().Single(); Assert.AreEqual("andy", sregResponse.Nickname); var axResponse = extensions.OfType<FetchResponse>().Single(); Assert.AreEqual("a@b.com", axResponse.GetAttributeValue(WellKnownAttributes.Contact.Email)); } - private IList<IExtensionMessage> GetResponseExtensions() { - IProtocolMessageWithExtensions response = (IProtocolMessageWithExtensions)this.request.Response; + private async Task<IList<IExtensionMessage>> GetResponseExtensionsAsync() { + var response = (IProtocolMessageWithExtensions)await this.request.GetResponseAsync(CancellationToken.None); return response.Extensions; } - private void ParameterizedAXTest(AXAttributeFormats format) { + private async Task ParameterizedAXTestAsync(AXAttributeFormats format) { var axInjected = new FetchRequest(); axInjected.Attributes.AddOptional(ExtensionsInteropHelper.TransformAXFormatTestHook(WellKnownAttributes.Name.Alias, format)); axInjected.Attributes.AddRequired(ExtensionsInteropHelper.TransformAXFormatTestHook(WellKnownAttributes.Name.FullName, format)); @@ -145,8 +147,8 @@ namespace DotNetOpenAuth.Test.OpenId.Extensions { var sregResponse = sreg.CreateResponse(); sregResponse.Nickname = "andy"; this.request.AddResponseExtension(sregResponse); - ExtensionsInteropHelper.ConvertSregToMatchRequest(this.request); - var extensions = this.GetResponseExtensions(); + await ExtensionsInteropHelper.ConvertSregToMatchRequestAsync(this.request, CancellationToken.None); + var extensions = await this.GetResponseExtensionsAsync(); var axResponse = extensions.OfType<FetchResponse>().Single(); Assert.AreEqual("andy", axResponse.GetAttributeValue(ExtensionsInteropHelper.TransformAXFormatTestHook(WellKnownAttributes.Name.Alias, format))); } diff --git a/src/DotNetOpenAuth.Test/OpenId/Extensions/ExtensionsInteropHelperRPRequestTests.cs b/src/DotNetOpenAuth.Test/OpenId/Extensions/ExtensionsInteropHelperRPRequestTests.cs index 05ba3ad..055cf8c 100644 --- a/src/DotNetOpenAuth.Test/OpenId/Extensions/ExtensionsInteropHelperRPRequestTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/Extensions/ExtensionsInteropHelperRPRequestTests.cs @@ -27,7 +27,7 @@ namespace DotNetOpenAuth.Test.OpenId.Extensions { var rp = CreateRelyingParty(true); Identifier identifier = this.GetMockIdentifier(ProtocolVersion.V20); - this.authReq = (AuthenticationRequest)rp.CreateRequest(identifier, RPRealmUri, RPUri); + this.authReq = (AuthenticationRequest)rp.CreateRequestAsync(identifier, RPRealmUri, RPUri).Result; this.sreg = new ClaimsRequest { Nickname = DemandLevel.Request, FullName = DemandLevel.Request, diff --git a/src/DotNetOpenAuth.Test/OpenId/Extensions/ProviderAuthenticationPolicy/PapeRoundTripTests.cs b/src/DotNetOpenAuth.Test/OpenId/Extensions/ProviderAuthenticationPolicy/PapeRoundTripTests.cs index cba54bf..3cb3028 100644 --- a/src/DotNetOpenAuth.Test/OpenId/Extensions/ProviderAuthenticationPolicy/PapeRoundTripTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/Extensions/ProviderAuthenticationPolicy/PapeRoundTripTests.cs @@ -6,6 +6,8 @@ namespace DotNetOpenAuth.Test.OpenId.Extensions.ProviderAuthenticationPolicy { using System; + using System.Threading.Tasks; + using DotNetOpenAuth.OpenId; using DotNetOpenAuth.OpenId.Extensions.ProviderAuthenticationPolicy; using DotNetOpenAuth.Test.OpenId.Extensions; @@ -14,14 +16,14 @@ namespace DotNetOpenAuth.Test.OpenId.Extensions.ProviderAuthenticationPolicy { [TestFixture] public class PapeRoundTripTests : OpenIdTestBase { [Test] - public void Trivial() { + public async Task Trivial() { var request = new PolicyRequest(); var response = new PolicyResponse(); - ExtensionTestUtilities.Roundtrip(Protocol.Default, new[] { request }, new[] { response }); + await this.RoundtripAsync(Protocol.Default, new[] { request }, new[] { response }); } [Test] - public void Full() { + public async Task Full() { var request = new PolicyRequest(); request.MaximumAuthenticationAge = TimeSpan.FromMinutes(10); request.PreferredAuthLevelTypes.Add(Constants.AssuranceLevels.NistTypeUri); @@ -37,7 +39,7 @@ namespace DotNetOpenAuth.Test.OpenId.Extensions.ProviderAuthenticationPolicy { response.AssuranceLevels["customlevel"] = "ABC"; response.NistAssuranceLevel = NistAssuranceLevel.Level2; - ExtensionTestUtilities.Roundtrip(Protocol.Default, new[] { request }, new[] { response }); + await this.RoundtripAsync(Protocol.Default, new[] { request }, new[] { response }); } } } diff --git a/src/DotNetOpenAuth.Test/OpenId/Extensions/SimpleRegistration/ClaimsResponseTests.cs b/src/DotNetOpenAuth.Test/OpenId/Extensions/SimpleRegistration/ClaimsResponseTests.cs index f898511..1aa6e33 100644 --- a/src/DotNetOpenAuth.Test/OpenId/Extensions/SimpleRegistration/ClaimsResponseTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/Extensions/SimpleRegistration/ClaimsResponseTests.cs @@ -10,13 +10,14 @@ namespace DotNetOpenAuth.Test.OpenId.Extensions { using System.IO; using System.Runtime.Serialization; using System.Runtime.Serialization.Formatters.Binary; + using System.Threading.Tasks; using System.Xml.Serialization; using DotNetOpenAuth.OpenId; using DotNetOpenAuth.OpenId.Extensions.SimpleRegistration; using NUnit.Framework; [TestFixture] - public class ClaimsResponseTests { + public class ClaimsResponseTests : OpenIdTestBase { [Test] public void EmptyMailAddress() { ClaimsResponse response = new ClaimsResponse(Constants.TypeUris.Standard); @@ -132,14 +133,14 @@ namespace DotNetOpenAuth.Test.OpenId.Extensions { } [Test] - public void ResponseAlternateTypeUriTests() { + public async Task ResponseAlternateTypeUriTests() { var request = new ClaimsRequest(Constants.TypeUris.Variant10); request.Email = DemandLevel.Require; var response = new ClaimsResponse(Constants.TypeUris.Variant10); response.Email = "a@b.com"; - ExtensionTestUtilities.Roundtrip(Protocol.Default, new[] { request }, new[] { response }); + await this.RoundtripAsync(Protocol.Default, new[] { request }, new[] { response }); } private ClaimsResponse GetFilledData() { diff --git a/src/DotNetOpenAuth.Test/OpenId/NonIdentityTests.cs b/src/DotNetOpenAuth.Test/OpenId/NonIdentityTests.cs index 393239b..11463c7 100644 --- a/src/DotNetOpenAuth.Test/OpenId/NonIdentityTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/NonIdentityTests.cs @@ -5,53 +5,84 @@ //----------------------------------------------------------------------- namespace DotNetOpenAuth.Test.OpenId { + using System; + using System.Net.Http; + using System.Threading; + using System.Threading.Tasks; + + using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId; using DotNetOpenAuth.OpenId.Messages; using DotNetOpenAuth.OpenId.Provider; using DotNetOpenAuth.OpenId.RelyingParty; using NUnit.Framework; + using System.Net; [TestFixture] public class NonIdentityTests : OpenIdTestBase { [Test] - public void ExtensionOnlyChannelLevel() { + public async Task ExtensionOnlyChannelLevel() { Protocol protocol = Protocol.V20; - AuthenticationRequestMode mode = AuthenticationRequestMode.Setup; - - var coordinator = new OpenIdCoordinator( - rp => { - var request = new SignedResponseRequest(protocol.Version, OPUri, mode); - rp.Channel.Respond(request); - }, - op => { - var request = op.Channel.ReadFromRequest<SignedResponseRequest>(); + var mode = AuthenticationRequestMode.Setup; + + HandleProvider( + async (op, req) => { + var request = await op.Channel.ReadFromRequestAsync<SignedResponseRequest>(req, CancellationToken.None); Assert.IsNotInstanceOf<CheckIdRequest>(request); + return new HttpResponseMessage(); }); - coordinator.Run(); + + { + var rp = this.CreateRelyingParty(); + var request = new SignedResponseRequest(protocol.Version, OPUri, mode); + var authRequest = await rp.Channel.PrepareResponseAsync(request); + using (var httpClient = rp.Channel.HostFactories.CreateHttpClient()) { + using (var response = await httpClient.GetAsync(authRequest.Headers.Location)) { + response.EnsureSuccessStatusCode(); + } + } + } } [Test] - public void ExtensionOnlyFacadeLevel() { + public async Task ExtensionOnlyFacadeLevel() { Protocol protocol = Protocol.V20; - var coordinator = new OpenIdCoordinator( - rp => { - var request = rp.CreateRequest(GetMockIdentifier(protocol.ProtocolVersion), RPRealmUri, RPUri); - - request.IsExtensionOnly = true; - rp.Channel.Respond(request.RedirectingResponse.OriginalMessage); - IAuthenticationResponse response = rp.GetResponse(); - Assert.AreEqual(AuthenticationStatus.ExtensionsOnly, response.Status); - }, - op => { - var assocRequest = op.GetRequest(); - op.Respond(assocRequest); - - var request = (IAnonymousRequest)op.GetRequest(); - request.IsApproved = true; - Assert.IsNotInstanceOf<CheckIdRequest>(request); - op.Respond(request); + int opStep = 0; + HandleProvider( + async (op, req) => { + switch (++opStep) { + case 1: + var assocRequest = await op.GetRequestAsync(req); + return await op.PrepareResponseAsync(assocRequest); + case 2: + var request = (IAnonymousRequest)await op.GetRequestAsync(req); + request.IsApproved = true; + Assert.IsNotInstanceOf<CheckIdRequest>(request); + return await op.PrepareResponseAsync(request); + default: + throw Assumes.NotReachable(); + } }); - coordinator.Run(); + + { + var rp = this.CreateRelyingParty(); + var request = await rp.CreateRequestAsync(GetMockIdentifier(protocol.ProtocolVersion), RPRealmUri, RPUri); + + request.IsExtensionOnly = true; + var redirectRequest = await request.GetRedirectingResponseAsync(); + Uri redirectResponseUrl; + this.HostFactories.AllowAutoRedirects = false; + using (var httpClient = this.HostFactories.CreateHttpClient()) { + using (var redirectResponse = await httpClient.GetAsync(redirectRequest.Headers.Location)) { + Assert.That(redirectResponse.StatusCode, Is.EqualTo(HttpStatusCode.Redirect)); + redirectResponseUrl = redirectResponse.Headers.Location; + } + } + + IAuthenticationResponse response = + await rp.GetResponseAsync(new HttpRequestMessage(HttpMethod.Get, redirectResponseUrl)); + Assert.AreEqual(AuthenticationStatus.ExtensionsOnly, response.Status); + } } } } diff --git a/src/DotNetOpenAuth.Test/OpenId/OpenIdCoordinator.cs b/src/DotNetOpenAuth.Test/OpenId/OpenIdCoordinator.cs deleted file mode 100644 index 5000833..0000000 --- a/src/DotNetOpenAuth.Test/OpenId/OpenIdCoordinator.cs +++ /dev/null @@ -1,69 +0,0 @@ -//----------------------------------------------------------------------- -// <copyright file="OpenIdCoordinator.cs" company="Outercurve Foundation"> -// Copyright (c) Outercurve Foundation. All rights reserved. -// </copyright> -//----------------------------------------------------------------------- - -namespace DotNetOpenAuth.Test.OpenId { - using System; - using DotNetOpenAuth.Messaging; - using DotNetOpenAuth.Messaging.Bindings; - using DotNetOpenAuth.OpenId; - using DotNetOpenAuth.OpenId.Provider; - using DotNetOpenAuth.OpenId.RelyingParty; - using DotNetOpenAuth.Test.Mocks; - using Validation; - - internal class OpenIdCoordinator : CoordinatorBase<OpenIdRelyingParty, OpenIdProvider> { - internal OpenIdCoordinator(Action<OpenIdRelyingParty> rpAction, Action<OpenIdProvider> opAction) - : base(WrapAction(rpAction), WrapAction(opAction)) { - } - - internal OpenIdProvider Provider { get; set; } - - internal OpenIdRelyingParty RelyingParty { get; set; } - - internal override void Run() { - this.EnsurePartiesAreInitialized(); - var rpCoordinatingChannel = new CoordinatingChannel(this.RelyingParty.Channel, this.IncomingMessageFilter, this.OutgoingMessageFilter); - var opCoordinatingChannel = new CoordinatingChannel(this.Provider.Channel, this.IncomingMessageFilter, this.OutgoingMessageFilter); - rpCoordinatingChannel.RemoteChannel = opCoordinatingChannel; - opCoordinatingChannel.RemoteChannel = rpCoordinatingChannel; - - this.RelyingParty.Channel = rpCoordinatingChannel; - this.Provider.Channel = opCoordinatingChannel; - - RunCore(this.RelyingParty, this.Provider); - } - - private static Action<OpenIdRelyingParty> WrapAction(Action<OpenIdRelyingParty> action) { - Requires.NotNull(action, "action"); - - return rp => { - action(rp); - ((CoordinatingChannel)rp.Channel).Close(); - }; - } - - private static Action<OpenIdProvider> WrapAction(Action<OpenIdProvider> action) { - Requires.NotNull(action, "action"); - - return op => { - action(op); - ((CoordinatingChannel)op.Channel).Close(); - }; - } - - private void EnsurePartiesAreInitialized() { - if (this.RelyingParty == null) { - this.RelyingParty = new OpenIdRelyingParty(new StandardRelyingPartyApplicationStore()); - this.RelyingParty.DiscoveryServices.Add(new MockIdentifierDiscoveryService()); - } - - if (this.Provider == null) { - this.Provider = new OpenIdProvider(new StandardProviderApplicationStore()); - this.Provider.DiscoveryServices.Add(new MockIdentifierDiscoveryService()); - } - } - } -} diff --git a/src/DotNetOpenAuth.Test/OpenId/OpenIdTestBase.cs b/src/DotNetOpenAuth.Test/OpenId/OpenIdTestBase.cs index 3a27e96..ea2867c 100644 --- a/src/DotNetOpenAuth.Test/OpenId/OpenIdTestBase.cs +++ b/src/DotNetOpenAuth.Test/OpenId/OpenIdTestBase.cs @@ -8,21 +8,29 @@ namespace DotNetOpenAuth.Test.OpenId { using System; using System.Collections.Generic; using System.IO; + using System.Linq; + using System.Net; + using System.Net.Http; using System.Reflection; + using System.Threading; + using System.Threading.Tasks; + using DotNetOpenAuth.Configuration; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Bindings; using DotNetOpenAuth.OpenId; + using DotNetOpenAuth.OpenId.Messages; using DotNetOpenAuth.OpenId.Provider; using DotNetOpenAuth.OpenId.RelyingParty; + using DotNetOpenAuth.Test.Messaging; using DotNetOpenAuth.Test.Mocks; - using NUnit.Framework; + using DotNetOpenAuth.Test.OpenId.Extensions; - public class OpenIdTestBase : TestBase { - internal IDirectWebRequestHandler RequestHandler; + using NUnit.Framework; - internal MockHttpRequest MockResponder; + using IAuthenticationRequest = DotNetOpenAuth.OpenId.Provider.IAuthenticationRequest; + public class OpenIdTestBase : TestBase { protected internal const string IdentifierSelect = "http://specs.openid.net/auth/2.0/identifier_select"; protected internal static readonly Uri BaseMockUri = new Uri("http://localhost/"); @@ -69,10 +77,9 @@ namespace DotNetOpenAuth.Test.OpenId { this.RelyingPartySecuritySettings = OpenIdElement.Configuration.RelyingParty.SecuritySettings.CreateSecuritySettings(); this.ProviderSecuritySettings = OpenIdElement.Configuration.Provider.SecuritySettings.CreateSecuritySettings(); - this.MockResponder = MockHttpRequest.CreateUntrustedMockHttpHandler(); - this.RequestHandler = this.MockResponder.MockWebRequestHandler; this.AutoProviderScenario = Scenarios.AutoApproval; Identifier.EqualityOnStrings = true; + this.HostFactories.InstallUntrustedWebReqestHandler = true; } [TearDown] @@ -121,7 +128,7 @@ namespace DotNetOpenAuth.Test.OpenId { internal static IdentifierDiscoveryResult GetServiceEndpoint(int user, ProtocolVersion providerVersion, int servicePriority, bool useSsl, bool delegating) { var providerEndpoint = new ProviderEndpointDescription( - useSsl ? OpenIdTestBase.OPUriSsl : OpenIdTestBase.OPUri, + useSsl ? OPUriSsl : OPUri, new string[] { Protocol.Lookup(providerVersion).ClaimedIdentifierServiceTypeURI }); var local_id = useSsl ? OPLocalIdentifiersSsl[user] : OPLocalIdentifiers[user]; var claimed_id = delegating ? (useSsl ? VanityUriSsl : VanityUri) : local_id; @@ -135,50 +142,59 @@ namespace DotNetOpenAuth.Test.OpenId { } /// <summary> - /// A default implementation of a simple provider that responds to authentication requests + /// Gets a default implementation of a simple provider that responds to authentication requests /// per the scenario that is being simulated. /// </summary> - /// <param name="provider">The OpenIdProvider on which the process messages.</param> /// <remarks> /// This is a very useful method to pass to the OpenIdCoordinator constructor for the Provider argument. /// </remarks> - internal void AutoProvider(OpenIdProvider provider) { - while (!((CoordinatingChannel)provider.Channel).RemoteChannel.IsDisposed) { - IRequest request = provider.GetRequest(); - if (request == null) { - continue; - } + internal void RegisterAutoProvider() { + this.Handle(OPUri).By( + async (req, ct) => { + var provider = new OpenIdProvider(new StandardProviderApplicationStore(), this.HostFactories); + return await this.AutoProviderActionAsync(provider, req, ct); + }); + } - if (!request.IsResponseReady) { - var authRequest = (DotNetOpenAuth.OpenId.Provider.IAuthenticationRequest)request; - switch (this.AutoProviderScenario) { - case Scenarios.AutoApproval: - authRequest.IsAuthenticated = true; - break; - case Scenarios.AutoApprovalAddFragment: - authRequest.SetClaimedIdentifierFragment("frag"); - authRequest.IsAuthenticated = true; - break; - case Scenarios.ApproveOnSetup: - authRequest.IsAuthenticated = !authRequest.Immediate; - break; - case Scenarios.AlwaysDeny: - authRequest.IsAuthenticated = false; - break; - default: - // All other scenarios are done programmatically only. - throw new InvalidOperationException("Unrecognized scenario"); - } + /// <summary> + /// Gets a default implementation of a simple provider that responds to authentication requests + /// per the scenario that is being simulated. + /// </summary> + /// <remarks> + /// This is a very useful method to pass to the OpenIdCoordinator constructor for the Provider argument. + /// </remarks> + internal async Task<HttpResponseMessage> AutoProviderActionAsync(OpenIdProvider provider, HttpRequestMessage req, CancellationToken ct) { + IRequest request = await provider.GetRequestAsync(req, ct); + Assert.That(request, Is.Not.Null); + + if (!request.IsResponseReady) { + var authRequest = (IAuthenticationRequest)request; + switch (this.AutoProviderScenario) { + case Scenarios.AutoApproval: + authRequest.IsAuthenticated = true; + break; + case Scenarios.AutoApprovalAddFragment: + authRequest.SetClaimedIdentifierFragment("frag"); + authRequest.IsAuthenticated = true; + break; + case Scenarios.ApproveOnSetup: + authRequest.IsAuthenticated = !authRequest.Immediate; + break; + case Scenarios.AlwaysDeny: + authRequest.IsAuthenticated = false; + break; + default: + // All other scenarios are done programmatically only. + throw new InvalidOperationException("Unrecognized scenario"); } - - provider.Respond(request); } + + return await provider.PrepareResponseAsync(request, ct); } - internal IEnumerable<IdentifierDiscoveryResult> Discover(Identifier identifier) { + internal Task<IEnumerable<IdentifierDiscoveryResult>> DiscoverAsync(Identifier identifier, CancellationToken cancellationToken = default(CancellationToken)) { var rp = this.CreateRelyingParty(true); - rp.Channel.WebRequestHandler = this.RequestHandler; - return rp.Discover(identifier); + return rp.DiscoverAsync(identifier, cancellationToken); } protected Realm GetMockRealm(bool useSsl) { @@ -196,8 +212,8 @@ namespace DotNetOpenAuth.Test.OpenId { protected Identifier GetMockIdentifier(ProtocolVersion providerVersion, bool useSsl, bool delegating) { var se = GetServiceEndpoint(0, providerVersion, 10, useSsl, delegating); - UriIdentifier identityUri = (UriIdentifier)se.ClaimedIdentifier; - return new MockIdentifier(identityUri, this.MockResponder, new IdentifierDiscoveryResult[] { se }); + this.RegisterMockXrdsResponse(se); + return se.ClaimedIdentifier; } protected Identifier GetMockDualIdentifier() { @@ -208,8 +224,8 @@ namespace DotNetOpenAuth.Test.OpenId { IdentifierDiscoveryResult.CreateForProviderIdentifier(protocol.ClaimedIdentifierForOPIdentifier, opDesc, 20, 20), }; - Identifier dualId = new MockIdentifier(VanityUri, this.MockResponder, dualResults); - return dualId; + this.RegisterMockXrdsResponse(VanityUri, dualResults); + return VanityUri; } /// <summary> @@ -226,9 +242,7 @@ namespace DotNetOpenAuth.Test.OpenId { /// <param name="stateless">if set to <c>true</c> a stateless RP is created.</param> /// <returns>The new instance.</returns> protected OpenIdRelyingParty CreateRelyingParty(bool stateless) { - var rp = new OpenIdRelyingParty(stateless ? null : new StandardRelyingPartyApplicationStore()); - rp.Channel.WebRequestHandler = this.MockResponder.MockWebRequestHandler; - rp.DiscoveryServices.Add(new MockIdentifierDiscoveryService()); + var rp = new OpenIdRelyingParty(stateless ? null : new StandardRelyingPartyApplicationStore(), this.HostFactories); return rp; } @@ -237,10 +251,89 @@ namespace DotNetOpenAuth.Test.OpenId { /// </summary> /// <returns>The new instance.</returns> protected OpenIdProvider CreateProvider() { - var op = new OpenIdProvider(new StandardProviderApplicationStore()); - op.Channel.WebRequestHandler = this.MockResponder.MockWebRequestHandler; - op.DiscoveryServices.Add(new MockIdentifierDiscoveryService()); + var op = new OpenIdProvider(new StandardProviderApplicationStore(), this.HostFactories); return op; } + + protected internal void HandleProvider(Func<OpenIdProvider, HttpRequestMessage, Task<HttpResponseMessage>> provider) { + var op = this.CreateProvider(); + this.Handle(OPUri).By(async req => { + return await provider(op, req); + }); + } + + /// <summary> + /// Simulates an extension request and response. + /// </summary> + /// <param name="protocol">The protocol to use in the roundtripping.</param> + /// <param name="requests">The extensions to add to the request message.</param> + /// <param name="responses">The extensions to add to the response message.</param> + /// <remarks> + /// This method relies on the extension objects' Equals methods to verify + /// accurate transport. The Equals methods should be verified by separate tests. + /// </remarks> + internal async Task RoundtripAsync( + Protocol protocol, IEnumerable<IOpenIdMessageExtension> requests, IEnumerable<IOpenIdMessageExtension> responses) { + var securitySettings = new ProviderSecuritySettings(); + var cryptoKeyStore = new MemoryCryptoKeyStore(); + var associationStore = new ProviderAssociationHandleEncoder(cryptoKeyStore); + Association association = HmacShaAssociationProvider.Create( + protocol, + protocol.Args.SignatureAlgorithm.Best, + AssociationRelyingPartyType.Smart, + associationStore, + securitySettings); + + this.HandleProvider( + async (op, req) => { + ExtensionTestUtilities.RegisterExtension(op.Channel, Mocks.MockOpenIdExtension.Factory); + var key = cryptoKeyStore.GetCurrentKey( + ProviderAssociationHandleEncoder.AssociationHandleEncodingSecretBucket, TimeSpan.FromSeconds(1)); + op.CryptoKeyStore.StoreKey( + ProviderAssociationHandleEncoder.AssociationHandleEncodingSecretBucket, key.Key, key.Value); + var request = await op.Channel.ReadFromRequestAsync<CheckIdRequest>(req, CancellationToken.None); + var response = new PositiveAssertionResponse(request); + var receivedRequests = request.Extensions.Cast<IOpenIdMessageExtension>(); + CollectionAssert<IOpenIdMessageExtension>.AreEquivalentByEquality(requests.ToArray(), receivedRequests.ToArray()); + + foreach (var extensionResponse in responses) { + response.Extensions.Add(extensionResponse); + } + + return await op.Channel.PrepareResponseAsync(response); + }); + + { + var rp = this.CreateRelyingParty(); + ExtensionTestUtilities.RegisterExtension(rp.Channel, Mocks.MockOpenIdExtension.Factory); + var requestBase = new CheckIdRequest(protocol.Version, OpenIdTestBase.OPUri, AuthenticationRequestMode.Immediate); + OpenIdTestBase.StoreAssociation(rp, OpenIdTestBase.OPUri, association); + requestBase.AssociationHandle = association.Handle; + requestBase.ClaimedIdentifier = "http://claimedid"; + requestBase.LocalIdentifier = "http://localid"; + requestBase.ReturnTo = OpenIdTestBase.RPUri; + + foreach (IOpenIdMessageExtension extension in requests) { + requestBase.Extensions.Add(extension); + } + + var redirectingRequest = await rp.Channel.PrepareResponseAsync(requestBase); + Uri redirectingResponseUri; + this.HostFactories.AllowAutoRedirects = false; + using (var httpClient = rp.Channel.HostFactories.CreateHttpClient()) { + using (var redirectingResponse = await httpClient.GetAsync(redirectingRequest.Headers.Location)) { + Assert.AreEqual(HttpStatusCode.Found, redirectingResponse.StatusCode); + redirectingResponseUri = redirectingResponse.Headers.Location; + } + } + + var response = + await + rp.Channel.ReadFromRequestAsync<PositiveAssertionResponse>( + new HttpRequestMessage(HttpMethod.Get, redirectingResponseUri), CancellationToken.None); + var receivedResponses = response.Extensions.Cast<IOpenIdMessageExtension>(); + CollectionAssert<IOpenIdMessageExtension>.AreEquivalentByEquality(responses.ToArray(), receivedResponses.ToArray()); + } + } } } diff --git a/src/DotNetOpenAuth.Test/OpenId/Provider/AnonymousRequestTests.cs b/src/DotNetOpenAuth.Test/OpenId/Provider/AnonymousRequestTests.cs index 7310eb3..8657910 100644 --- a/src/DotNetOpenAuth.Test/OpenId/Provider/AnonymousRequestTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/Provider/AnonymousRequestTests.cs @@ -8,6 +8,9 @@ namespace DotNetOpenAuth.Test.OpenId.Provider { using System.IO; using System.Runtime.Serialization; using System.Runtime.Serialization.Formatters.Binary; + using System.Threading; + using System.Threading.Tasks; + using DotNetOpenAuth.OpenId; using DotNetOpenAuth.OpenId.Messages; using DotNetOpenAuth.OpenId.Provider; @@ -20,7 +23,7 @@ namespace DotNetOpenAuth.Test.OpenId.Provider { /// Verifies that IsApproved controls which response message is returned. /// </summary> [Test] - public void IsApprovedDeterminesReturnedMessage() { + public async Task IsApprovedDeterminesReturnedMessage() { var op = CreateProvider(); Protocol protocol = Protocol.V20; var req = new SignedResponseRequest(protocol.Version, OPUri, AuthenticationRequestMode.Setup); @@ -30,15 +33,15 @@ namespace DotNetOpenAuth.Test.OpenId.Provider { Assert.IsFalse(anonReq.IsApproved.HasValue); anonReq.IsApproved = false; - Assert.IsInstanceOf<NegativeAssertionResponse>(anonReq.Response); + Assert.IsInstanceOf<NegativeAssertionResponse>(await anonReq.GetResponseAsync(CancellationToken.None)); anonReq.IsApproved = true; - Assert.IsInstanceOf<IndirectSignedResponse>(anonReq.Response); - Assert.IsNotInstanceOf<PositiveAssertionResponse>(anonReq.Response); + Assert.IsInstanceOf<IndirectSignedResponse>(await anonReq.GetResponseAsync(CancellationToken.None)); + Assert.IsNotInstanceOf<PositiveAssertionResponse>(await anonReq.GetResponseAsync(CancellationToken.None)); } /// <summary> - /// Verifies that the AuthenticationRequest method is serializable. + /// Verifies that the AnonymousRequest type is serializable. /// </summary> [Test] public void Serializable() { diff --git a/src/DotNetOpenAuth.Test/OpenId/Provider/AuthenticationRequestTest.cs b/src/DotNetOpenAuth.Test/OpenId/Provider/AuthenticationRequestTest.cs index baf5377..e9c5465 100644 --- a/src/DotNetOpenAuth.Test/OpenId/Provider/AuthenticationRequestTest.cs +++ b/src/DotNetOpenAuth.Test/OpenId/Provider/AuthenticationRequestTest.cs @@ -7,8 +7,12 @@ namespace DotNetOpenAuth.Test.OpenId.Provider { using System; using System.IO; + using System.Net.Http; using System.Runtime.Serialization; using System.Runtime.Serialization.Formatters.Binary; + using System.Threading; + using System.Threading.Tasks; + using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId; using DotNetOpenAuth.OpenId.Messages; @@ -21,24 +25,24 @@ namespace DotNetOpenAuth.Test.OpenId.Provider { /// Verifies the user_setup_url is set properly for immediate negative responses. /// </summary> [Test] - public void UserSetupUrl() { + public async Task UserSetupUrl() { // Construct a V1 immediate request Protocol protocol = Protocol.V11; OpenIdProvider provider = this.CreateProvider(); - CheckIdRequest immediateRequest = new CheckIdRequest(protocol.Version, OPUri, DotNetOpenAuth.OpenId.AuthenticationRequestMode.Immediate); + var immediateRequest = new CheckIdRequest(protocol.Version, OPUri, DotNetOpenAuth.OpenId.AuthenticationRequestMode.Immediate); immediateRequest.Realm = RPRealmUri; immediateRequest.ReturnTo = RPUri; immediateRequest.LocalIdentifier = "http://somebody"; - AuthenticationRequest request = new AuthenticationRequest(provider, immediateRequest); + var request = new AuthenticationRequest(provider, immediateRequest); // Now simulate the request being rejected and extract the user_setup_url request.IsAuthenticated = false; - Uri userSetupUrl = ((NegativeAssertionResponse)request.Response).UserSetupUrl; + Uri userSetupUrl = ((NegativeAssertionResponse)await request.GetResponseAsync(CancellationToken.None)).UserSetupUrl; Assert.IsNotNull(userSetupUrl); // Now construct a new request as if it had just come in. - HttpRequestInfo httpRequest = new HttpRequestInfo("GET", userSetupUrl); - var setupRequest = (AuthenticationRequest)provider.GetRequest(httpRequest); + var httpRequest = new HttpRequestMessage(HttpMethod.Get, userSetupUrl); + var setupRequest = (AuthenticationRequest)await provider.GetRequestAsync(httpRequest); var setupRequestMessage = (CheckIdRequest)setupRequest.RequestMessage; // And make sure all the right properties are set. diff --git a/src/DotNetOpenAuth.Test/OpenId/Provider/HostProcessedRequestTests.cs b/src/DotNetOpenAuth.Test/OpenId/Provider/HostProcessedRequestTests.cs index 2e3e7ec..ce5f417 100644 --- a/src/DotNetOpenAuth.Test/OpenId/Provider/HostProcessedRequestTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/Provider/HostProcessedRequestTests.cs @@ -6,9 +6,13 @@ namespace DotNetOpenAuth.Test.OpenId.Provider { using System; + using System.Threading; + using System.Threading.Tasks; using DotNetOpenAuth.OpenId; using DotNetOpenAuth.OpenId.Messages; using DotNetOpenAuth.OpenId.Provider; + using DotNetOpenAuth.Test.Mocks; + using NUnit.Framework; [TestFixture] @@ -24,22 +28,23 @@ namespace DotNetOpenAuth.Test.OpenId.Provider { this.protocol = Protocol.Default; this.provider = this.CreateProvider(); - this.checkIdRequest = new CheckIdRequest(this.protocol.Version, OPUri, DotNetOpenAuth.OpenId.AuthenticationRequestMode.Setup); + this.checkIdRequest = new CheckIdRequest(this.protocol.Version, OPUri, AuthenticationRequestMode.Setup); this.checkIdRequest.Realm = RPRealmUri; this.checkIdRequest.ReturnTo = RPUri; this.request = new AuthenticationRequest(this.provider, this.checkIdRequest); } [Test] - public void IsReturnUrlDiscoverableNoResponse() { - Assert.AreEqual(RelyingPartyDiscoveryResult.NoServiceDocument, this.request.IsReturnUrlDiscoverable(this.provider.Channel.WebRequestHandler)); + public async Task IsReturnUrlDiscoverableNoResponse() { + Assert.AreEqual(RelyingPartyDiscoveryResult.NoServiceDocument, await this.request.IsReturnUrlDiscoverableAsync(this.provider.Channel.HostFactories, CancellationToken.None)); } [Test] - public void IsReturnUrlDiscoverableValidResponse() { - this.MockResponder.RegisterMockRPDiscovery(); + public async Task IsReturnUrlDiscoverableValidResponse() { + this.RegisterMockRPDiscovery(false); + this.request = new AuthenticationRequest(this.provider, this.checkIdRequest); - Assert.AreEqual(RelyingPartyDiscoveryResult.Success, this.request.IsReturnUrlDiscoverable(this.provider.Channel.WebRequestHandler)); + Assert.AreEqual(RelyingPartyDiscoveryResult.Success, await this.request.IsReturnUrlDiscoverableAsync(this.provider.Channel.HostFactories, CancellationToken.None)); } /// <summary> @@ -47,39 +52,42 @@ namespace DotNetOpenAuth.Test.OpenId.Provider { /// is set, that discovery fails. /// </summary> [Test] - public void IsReturnUrlDiscoverableNotSsl() { + public async Task IsReturnUrlDiscoverableNotSsl() { + this.RegisterMockRPDiscovery(false); this.provider.SecuritySettings.RequireSsl = true; - this.MockResponder.RegisterMockRPDiscovery(); - Assert.AreEqual(RelyingPartyDiscoveryResult.NoServiceDocument, this.request.IsReturnUrlDiscoverable(this.provider.Channel.WebRequestHandler)); + Assert.AreEqual(RelyingPartyDiscoveryResult.NoServiceDocument, await this.request.IsReturnUrlDiscoverableAsync(this.provider.Channel.HostFactories, CancellationToken.None)); } /// <summary> /// Verifies that when discovery would be performed over HTTPS that discovery succeeds. /// </summary> [Test] - public void IsReturnUrlDiscoverableRequireSsl() { - this.MockResponder.RegisterMockRPDiscovery(); + public async Task IsReturnUrlDiscoverableRequireSsl() { + this.RegisterMockRPDiscovery(ssl: false); + this.RegisterMockRPDiscovery(ssl: true); this.checkIdRequest.Realm = RPRealmUriSsl; this.checkIdRequest.ReturnTo = RPUriSsl; // Try once with RequireSsl this.provider.SecuritySettings.RequireSsl = true; this.request = new AuthenticationRequest(this.provider, this.checkIdRequest); - Assert.AreEqual(RelyingPartyDiscoveryResult.Success, this.request.IsReturnUrlDiscoverable(this.provider.Channel.WebRequestHandler)); + Assert.AreEqual(RelyingPartyDiscoveryResult.Success, await this.request.IsReturnUrlDiscoverableAsync(this.HostFactories, CancellationToken.None)); // And again without RequireSsl this.provider.SecuritySettings.RequireSsl = false; this.request = new AuthenticationRequest(this.provider, this.checkIdRequest); - Assert.AreEqual(RelyingPartyDiscoveryResult.Success, this.request.IsReturnUrlDiscoverable(this.provider.Channel.WebRequestHandler)); + Assert.AreEqual(RelyingPartyDiscoveryResult.Success, await this.request.IsReturnUrlDiscoverableAsync(this.HostFactories, CancellationToken.None)); } [Test] - public void IsReturnUrlDiscoverableValidButNoMatch() { - this.MockResponder.RegisterMockRPDiscovery(); + public async Task IsReturnUrlDiscoverableValidButNoMatch() { + this.RegisterMockRPDiscovery(false); this.provider.SecuritySettings.RequireSsl = false; // reset for another failure test case this.checkIdRequest.ReturnTo = new Uri("http://somerandom/host"); this.request = new AuthenticationRequest(this.provider, this.checkIdRequest); - Assert.AreEqual(RelyingPartyDiscoveryResult.NoMatchingReturnTo, this.request.IsReturnUrlDiscoverable(this.provider.Channel.WebRequestHandler)); + Assert.AreEqual( + RelyingPartyDiscoveryResult.NoMatchingReturnTo, + await this.request.IsReturnUrlDiscoverableAsync(this.provider.Channel.HostFactories, CancellationToken.None)); } } } diff --git a/src/DotNetOpenAuth.Test/OpenId/Provider/OpenIdProviderTests.cs b/src/DotNetOpenAuth.Test/OpenId/Provider/OpenIdProviderTests.cs index c15f5b8..4780e37 100644 --- a/src/DotNetOpenAuth.Test/OpenId/Provider/OpenIdProviderTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/Provider/OpenIdProviderTests.cs @@ -7,6 +7,9 @@ namespace DotNetOpenAuth.Test.OpenId.Provider { using System; using System.IO; + using System.Net.Http; + using System.Threading; + using System.Threading.Tasks; using System.Web; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId; @@ -74,60 +77,55 @@ namespace DotNetOpenAuth.Test.OpenId.Provider { /// Verifies the GetRequest method throws outside an HttpContext. /// </summary> [Test, ExpectedException(typeof(InvalidOperationException))] - public void GetRequestNoContext() { + public async Task GetRequestNoContext() { HttpContext.Current = null; - this.provider.GetRequest(); + await this.provider.GetRequestAsync(); } /// <summary> /// Verifies GetRequest throws on null input. /// </summary> [Test, ExpectedException(typeof(ArgumentNullException))] - public void GetRequestNull() { - this.provider.GetRequest(null); + public async Task GetRequestNull() { + await this.provider.GetRequestAsync((HttpRequestMessage)null); } /// <summary> /// Verifies that GetRequest correctly returns the right messages. /// </summary> [Test] - public void GetRequest() { - var httpInfo = new HttpRequestInfo("GET", new Uri("http://someUri")); - Assert.IsNull(this.provider.GetRequest(httpInfo), "An irrelevant request should return null."); + public async Task GetRequest() { + var httpInfo = new HttpRequestMessage(HttpMethod.Get, "http://someUri"); + Assert.IsNull(await this.provider.GetRequestAsync(httpInfo), "An irrelevant request should return null."); var providerDescription = new ProviderEndpointDescription(OPUri, Protocol.Default.Version); // Test some non-empty request scenario. - OpenIdCoordinator coordinator = new OpenIdCoordinator( - rp => { - rp.Channel.Request(AssociateRequestRelyingParty.Create(rp.SecuritySettings, providerDescription)); - }, - op => { - IRequest request = op.GetRequest(); + HandleProvider( + async (op, req) => { + IRequest request = await op.GetRequestAsync(req); Assert.IsInstanceOf<AutoResponsiveRequest>(request); - op.Respond(request); + return await op.PrepareResponseAsync(request); }); - coordinator.Run(); + var rp = this.CreateRelyingParty(); + await rp.Channel.RequestAsync(AssociateRequestRelyingParty.Create(rp.SecuritySettings, providerDescription), CancellationToken.None); } [Test] - public void BadRequestsGenerateValidErrorResponses() { - var coordinator = new OpenIdCoordinator( - rp => { - var nonOpenIdMessage = new Mocks.TestDirectedMessage(); - nonOpenIdMessage.Recipient = OPUri; - nonOpenIdMessage.HttpMethods = HttpDeliveryMethods.PostRequest; - MessagingTestBase.GetStandardTestMessage(MessagingTestBase.FieldFill.AllRequired, nonOpenIdMessage); - var response = rp.Channel.Request<DirectErrorResponse>(nonOpenIdMessage); - Assert.IsNotNull(response.ErrorMessage); - Assert.AreEqual(Protocol.Default.Version, response.Version); - }, - AutoProvider); - - coordinator.Run(); + public async Task BadRequestsGenerateValidErrorResponses() { + this.RegisterAutoProvider(); + var rp = this.CreateRelyingParty(); + var nonOpenIdMessage = new Mocks.TestDirectedMessage { + Recipient = OPUri, + HttpMethods = HttpDeliveryMethods.PostRequest + }; + MessagingTestBase.GetStandardTestMessage(MessagingTestBase.FieldFill.AllRequired, nonOpenIdMessage); + var response = await rp.Channel.RequestAsync<DirectErrorResponse>(nonOpenIdMessage, CancellationToken.None); + Assert.IsNotNull(response.ErrorMessage); + Assert.AreEqual(Protocol.Default.Version, response.Version); } [Test, Category("HostASPNET")] - public void BadRequestsGenerateValidErrorResponsesHosted() { + public async Task BadRequestsGenerateValidErrorResponsesHosted() { try { using (AspNetHost host = AspNetHost.CreateHost(TestWebDirectory)) { Uri opEndpoint = new Uri(host.BaseUri, "/OpenIdProviderEndpoint.ashx"); @@ -136,7 +134,7 @@ namespace DotNetOpenAuth.Test.OpenId.Provider { nonOpenIdMessage.Recipient = opEndpoint; nonOpenIdMessage.HttpMethods = HttpDeliveryMethods.PostRequest; MessagingTestBase.GetStandardTestMessage(MessagingTestBase.FieldFill.AllRequired, nonOpenIdMessage); - var response = rp.Channel.Request<DirectErrorResponse>(nonOpenIdMessage); + var response = await rp.Channel.RequestAsync<DirectErrorResponse>(nonOpenIdMessage, CancellationToken.None); Assert.IsNotNull(response.ErrorMessage); } } catch (FileNotFoundException ex) { diff --git a/src/DotNetOpenAuth.Test/OpenId/Provider/PerformanceTests.cs b/src/DotNetOpenAuth.Test/OpenId/Provider/PerformanceTests.cs index e2c719d..1501150 100644 --- a/src/DotNetOpenAuth.Test/OpenId/Provider/PerformanceTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/Provider/PerformanceTests.cs @@ -37,10 +37,10 @@ namespace DotNetOpenAuth.Test.OpenId.Provider { public void AssociateDH() { var associateRequest = this.CreateAssociateRequest(OPUri); MeasurePerformance( - () => { - IRequest request = this.provider.GetRequest(associateRequest); - var response = this.provider.PrepareResponse(request); - Assert.IsInstanceOf<AssociateSuccessfulResponse>(response.OriginalMessage); + async delegate { + IRequest request = await this.provider.GetRequestAsync(associateRequest); + var response = await this.provider.PrepareResponseAsync(request); + Assert.IsInstanceOf<AssociateSuccessfulResponse>(((HttpResponseMessageWithOriginal)response).OriginalMessage); }, maximumAllowedUnitTime: 3.5e6f, iterations: 1); @@ -50,10 +50,10 @@ namespace DotNetOpenAuth.Test.OpenId.Provider { public void AssociateClearText() { var associateRequest = this.CreateAssociateRequest(OPUriSsl); // SSL will cause a plaintext association MeasurePerformance( - () => { - IRequest request = this.provider.GetRequest(associateRequest); - var response = this.provider.PrepareResponse(request); - Assert.IsInstanceOf<AssociateSuccessfulResponse>(response.OriginalMessage); + async delegate { + IRequest request = await this.provider.GetRequestAsync(associateRequest); + var response = await this.provider.PrepareResponseAsync(request); + Assert.IsInstanceOf<AssociateSuccessfulResponse>(((HttpResponseMessageWithOriginal)response).OriginalMessage); }, maximumAllowedUnitTime: 1.5e4f, iterations: 1000); @@ -82,11 +82,11 @@ namespace DotNetOpenAuth.Test.OpenId.Provider { this.provider.SecuritySettings); var checkidRequest = this.CreateCheckIdRequest(true); MeasurePerformance( - () => { - var request = (IAuthenticationRequest)this.provider.GetRequest(checkidRequest); + async delegate { + var request = (IAuthenticationRequest)await this.provider.GetRequestAsync(checkidRequest); request.IsAuthenticated = true; - var response = this.provider.PrepareResponse(request); - Assert.IsInstanceOf<PositiveAssertionResponse>(response.OriginalMessage); + var response = await this.provider.PrepareResponseAsync(request); + Assert.IsInstanceOf<PositiveAssertionResponse>(((HttpResponseMessageWithOriginal)response).OriginalMessage); }, maximumAllowedUnitTime: 6.8e4f); } diff --git a/src/DotNetOpenAuth.Test/OpenId/RelyingParty/AuthenticationRequestTests.cs b/src/DotNetOpenAuth.Test/OpenId/RelyingParty/AuthenticationRequestTests.cs index cd72fdb..333169f 100644 --- a/src/DotNetOpenAuth.Test/OpenId/RelyingParty/AuthenticationRequestTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/RelyingParty/AuthenticationRequestTests.cs @@ -10,6 +10,8 @@ namespace DotNetOpenAuth.Test.OpenId.RelyingParty { using System.Collections.Specialized; using System.Linq; using System.Text; + using System.Threading; + using System.Threading.Tasks; using System.Web; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId; @@ -70,54 +72,51 @@ namespace DotNetOpenAuth.Test.OpenId.RelyingParty { /// Verifies RedirectingResponse. /// </summary> [Test] - public void CreateRequestMessage() { - OpenIdCoordinator coordinator = new OpenIdCoordinator( - rp => { - Identifier id = this.GetMockIdentifier(ProtocolVersion.V20); - IAuthenticationRequest authRequest = rp.CreateRequest(id, this.realm, this.returnTo); - - // Add some callback arguments - authRequest.AddCallbackArguments("a", "b"); - authRequest.AddCallbackArguments(new Dictionary<string, string> { { "c", "d" }, { "e", "f" } }); - - // Assembly an extension request. - ClaimsRequest sregRequest = new ClaimsRequest(); - sregRequest.Nickname = DemandLevel.Request; - authRequest.AddExtension(sregRequest); - - // Construct the actual authentication request message. - var authRequestAccessor = (AuthenticationRequest)authRequest; - var req = authRequestAccessor.CreateRequestMessageTestHook(); - Assert.IsNotNull(req); - - // Verify that callback arguments were included. - NameValueCollection callbackArguments = HttpUtility.ParseQueryString(req.ReturnTo.Query); - Assert.AreEqual("b", callbackArguments["a"]); - Assert.AreEqual("d", callbackArguments["c"]); - Assert.AreEqual("f", callbackArguments["e"]); - - // Verify that extensions were included. - Assert.AreEqual(1, req.Extensions.Count); - Assert.IsTrue(req.Extensions.Contains(sregRequest)); - }, - AutoProvider); - coordinator.Run(); + public async Task CreateRequestMessage() { + this.RegisterAutoProvider(); + var rp = this.CreateRelyingParty(); + Identifier id = this.GetMockIdentifier(ProtocolVersion.V20); + IAuthenticationRequest authRequest = await rp.CreateRequestAsync(id, this.realm, this.returnTo); + + // Add some callback arguments + authRequest.AddCallbackArguments("a", "b"); + authRequest.AddCallbackArguments(new Dictionary<string, string> { { "c", "d" }, { "e", "f" } }); + + // Assembly an extension request. + var sregRequest = new ClaimsRequest(); + sregRequest.Nickname = DemandLevel.Request; + authRequest.AddExtension(sregRequest); + + // Construct the actual authentication request message. + var authRequestAccessor = (AuthenticationRequest)authRequest; + var req = await authRequestAccessor.CreateRequestMessageTestHookAsync(CancellationToken.None); + Assert.IsNotNull(req); + + // Verify that callback arguments were included. + NameValueCollection callbackArguments = HttpUtility.ParseQueryString(req.ReturnTo.Query); + Assert.AreEqual("b", callbackArguments["a"]); + Assert.AreEqual("d", callbackArguments["c"]); + Assert.AreEqual("f", callbackArguments["e"]); + + // Verify that extensions were included. + Assert.AreEqual(1, req.Extensions.Count); + Assert.IsTrue(req.Extensions.Contains(sregRequest)); } /// <summary> /// Verifies that delegating authentication requests are filtered out when configured to do so. /// </summary> [Test] - public void CreateFiltersDelegatingIdentifiers() { + public async Task CreateFiltersDelegatingIdentifiers() { Identifier id = GetMockIdentifier(ProtocolVersion.V20, false, true); var rp = CreateRelyingParty(); // First verify that delegating identifiers work - Assert.IsTrue(AuthenticationRequest.Create(id, rp, this.realm, this.returnTo, false).Any(), "The delegating identifier should have not generated any results."); + Assert.IsTrue((await AuthenticationRequest.CreateAsync(id, rp, this.realm, this.returnTo, false, CancellationToken.None)).Any(), "The delegating identifier should have not generated any results."); // Now disable them and try again. rp.SecuritySettings.RejectDelegatingIdentifiers = true; - Assert.IsFalse(AuthenticationRequest.Create(id, rp, this.realm, this.returnTo, false).Any(), "The delegating identifier should have not generated any results."); + Assert.IsFalse((await AuthenticationRequest.CreateAsync(id, rp, this.realm, this.returnTo, false, CancellationToken.None)).Any(), "The delegating identifier should have not generated any results."); } /// <summary> @@ -135,11 +134,12 @@ namespace DotNetOpenAuth.Test.OpenId.RelyingParty { /// Verifies that AddCallbackArguments adds query arguments to the return_to URL of the message. /// </summary> [Test] - public void AddCallbackArgument() { + public async Task AddCallbackArgument() { var authRequest = this.CreateAuthenticationRequest(this.claimedId, this.claimedId); Assert.AreEqual(this.returnTo, authRequest.ReturnToUrl); authRequest.AddCallbackArguments("p1", "v1"); - var req = (SignedResponseRequest)authRequest.RedirectingResponse.OriginalMessage; + var response = (HttpResponseMessageWithOriginal)await authRequest.GetRedirectingResponseAsync(CancellationToken.None); + var req = (SignedResponseRequest)response.OriginalMessage; NameValueCollection query = HttpUtility.ParseQueryString(req.ReturnTo.Query); Assert.AreEqual("v1", query["p1"]); } @@ -149,13 +149,14 @@ namespace DotNetOpenAuth.Test.OpenId.RelyingParty { /// rather than appending them. /// </summary> [Test] - public void AddCallbackArgumentClearsPreviousArgument() { + public async Task AddCallbackArgumentClearsPreviousArgument() { UriBuilder returnToWithArgs = new UriBuilder(this.returnTo); returnToWithArgs.AppendQueryArgs(new Dictionary<string, string> { { "p1", "v1" } }); this.returnTo = returnToWithArgs.Uri; var authRequest = this.CreateAuthenticationRequest(this.claimedId, this.claimedId); authRequest.AddCallbackArguments("p1", "v2"); - var req = (SignedResponseRequest)authRequest.RedirectingResponse.OriginalMessage; + var response = (HttpResponseMessageWithOriginal)await authRequest.GetRedirectingResponseAsync(CancellationToken.None); + var req = (SignedResponseRequest)response.OriginalMessage; NameValueCollection query = HttpUtility.ParseQueryString(req.ReturnTo.Query); Assert.AreEqual("v2", query["p1"]); } @@ -164,11 +165,12 @@ namespace DotNetOpenAuth.Test.OpenId.RelyingParty { /// Verifies identity-less checkid_* request behavior. /// </summary> [Test] - public void NonIdentityRequest() { + public async Task NonIdentityRequest() { var authRequest = this.CreateAuthenticationRequest(this.claimedId, this.claimedId); authRequest.IsExtensionOnly = true; Assert.IsTrue(authRequest.IsExtensionOnly); - var req = (SignedResponseRequest)authRequest.RedirectingResponse.OriginalMessage; + var response = (HttpResponseMessageWithOriginal)await authRequest.GetRedirectingResponseAsync(CancellationToken.None); + var req = (SignedResponseRequest)response.OriginalMessage; Assert.IsNotInstanceOf<CheckIdRequest>(req, "An unexpected SignedResponseRequest derived type was generated."); } @@ -177,15 +179,15 @@ namespace DotNetOpenAuth.Test.OpenId.RelyingParty { /// only generate OP Identifier auth requests. /// </summary> [Test] - public void DualIdentifierUsedOnlyAsOPIdentifierForAuthRequest() { + public async Task DualIdentifierUsedOnlyAsOPIdentifierForAuthRequest() { var rp = this.CreateRelyingParty(true); - var results = AuthenticationRequest.Create(GetMockDualIdentifier(), rp, this.realm, this.returnTo, false).ToList(); + var results = (await AuthenticationRequest.CreateAsync(GetMockDualIdentifier(), rp, this.realm, this.returnTo, false, CancellationToken.None)).ToList(); Assert.AreEqual(1, results.Count); Assert.IsTrue(results[0].IsDirectedIdentity); // Also test when dual identiifer support is turned on. rp.SecuritySettings.AllowDualPurposeIdentifiers = true; - results = AuthenticationRequest.Create(GetMockDualIdentifier(), rp, this.realm, this.returnTo, false).ToList(); + results = (await AuthenticationRequest.CreateAsync(GetMockDualIdentifier(), rp, this.realm, this.returnTo, false, CancellationToken.None)).ToList(); Assert.AreEqual(1, results.Count); Assert.IsTrue(results[0].IsDirectedIdentity); } diff --git a/src/DotNetOpenAuth.Test/OpenId/RelyingParty/OpenIdRelyingPartyTests.cs b/src/DotNetOpenAuth.Test/OpenId/RelyingParty/OpenIdRelyingPartyTests.cs index a2a4efa..2d9413d 100644 --- a/src/DotNetOpenAuth.Test/OpenId/RelyingParty/OpenIdRelyingPartyTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/RelyingParty/OpenIdRelyingPartyTests.cs @@ -7,11 +7,18 @@ namespace DotNetOpenAuth.Test.OpenId.RelyingParty { using System; using System.Linq; + using System.Net.Http; + using System.Threading.Tasks; + using System.Web; + using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId; using DotNetOpenAuth.OpenId.Extensions; using DotNetOpenAuth.OpenId.Messages; + using DotNetOpenAuth.OpenId.Provider; using DotNetOpenAuth.OpenId.RelyingParty; + using DotNetOpenAuth.Test.Mocks; + using NUnit.Framework; [TestFixture] @@ -22,12 +29,13 @@ namespace DotNetOpenAuth.Test.OpenId.RelyingParty { } [Test] - public void CreateRequestDumbMode() { + public async Task CreateRequestDumbMode() { var rp = this.CreateRelyingParty(true); Identifier id = this.GetMockIdentifier(ProtocolVersion.V20); - var authReq = rp.CreateRequest(id, RPRealmUri, RPUri); - CheckIdRequest requestMessage = (CheckIdRequest)authReq.RedirectingResponse.OriginalMessage; - Assert.IsNull(requestMessage.AssociationHandle); + var authReq = await rp.CreateRequestAsync(id, RPRealmUri, RPUri); + var httpMessage = await authReq.GetRedirectingResponseAsync(); + var data = HttpUtility.ParseQueryString(httpMessage.GetDirectUriRequest().Query); + Assert.IsNull(data[Protocol.Default.openid.assoc_handle]); } [Test, ExpectedException(typeof(ArgumentNullException))] @@ -46,52 +54,52 @@ namespace DotNetOpenAuth.Test.OpenId.RelyingParty { } [Test] - public void CreateRequest() { + public async Task CreateRequest() { var rp = this.CreateRelyingParty(); StoreAssociation(rp, OPUri, HmacShaAssociation.Create("somehandle", new byte[20], TimeSpan.FromDays(1))); Identifier id = Identifier.Parse(GetMockIdentifier(ProtocolVersion.V20)); - var req = rp.CreateRequest(id, RPRealmUri, RPUri); + var req = await rp.CreateRequestAsync(id, RPRealmUri, RPUri); Assert.IsNotNull(req); } [Test] - public void CreateRequests() { + public async Task CreateRequests() { var rp = this.CreateRelyingParty(); StoreAssociation(rp, OPUri, HmacShaAssociation.Create("somehandle", new byte[20], TimeSpan.FromDays(1))); Identifier id = Identifier.Parse(GetMockIdentifier(ProtocolVersion.V20)); - var requests = rp.CreateRequests(id, RPRealmUri, RPUri); + var requests = await rp.CreateRequestsAsync(id, RPRealmUri, RPUri); Assert.AreEqual(1, requests.Count()); } [Test] - public void CreateRequestsWithEndpointFilter() { + public async Task CreateRequestsWithEndpointFilter() { var rp = this.CreateRelyingParty(); StoreAssociation(rp, OPUri, HmacShaAssociation.Create("somehandle", new byte[20], TimeSpan.FromDays(1))); Identifier id = Identifier.Parse(GetMockIdentifier(ProtocolVersion.V20)); rp.EndpointFilter = opendpoint => true; - var requests = rp.CreateRequests(id, RPRealmUri, RPUri); + var requests = await rp.CreateRequestsAsync(id, RPRealmUri, RPUri); Assert.AreEqual(1, requests.Count()); rp.EndpointFilter = opendpoint => false; - requests = rp.CreateRequests(id, RPRealmUri, RPUri); + requests = await rp.CreateRequestsAsync(id, RPRealmUri, RPUri); Assert.AreEqual(0, requests.Count()); } [Test, ExpectedException(typeof(ProtocolException))] - public void CreateRequestOnNonOpenID() { - Uri nonOpenId = new Uri("http://www.microsoft.com/"); + public async Task CreateRequestOnNonOpenID() { + var nonOpenId = new Uri("http://www.microsoft.com/"); + Handle(nonOpenId).By("<html/>", "text/html"); var rp = this.CreateRelyingParty(); - this.MockResponder.RegisterMockResponse(nonOpenId, "text/html", "<html/>"); - rp.CreateRequest(nonOpenId, RPRealmUri, RPUri); + await rp.CreateRequestAsync(nonOpenId, RPRealmUri, RPUri); } [Test] - public void CreateRequestsOnNonOpenID() { - Uri nonOpenId = new Uri("http://www.microsoft.com/"); + public async Task CreateRequestsOnNonOpenID() { + var nonOpenId = new Uri("http://www.microsoft.com/"); + Handle(nonOpenId).By("<html/>", "text/html"); var rp = this.CreateRelyingParty(); - this.MockResponder.RegisterMockResponse(nonOpenId, "text/html", "<html/>"); - var requests = rp.CreateRequests(nonOpenId, RPRealmUri, RPUri); + var requests = await rp.CreateRequestsAsync(nonOpenId, RPRealmUri, RPUri); Assert.AreEqual(0, requests.Count()); } @@ -100,25 +108,32 @@ namespace DotNetOpenAuth.Test.OpenId.RelyingParty { /// OPs that are not approved by <see cref="OpenIdRelyingParty.EndpointFilter"/>. /// </summary> [Test] - public void AssertionWithEndpointFilter() { - var coordinator = new OpenIdCoordinator( - rp => { - // register with RP so that id discovery passes - rp.Channel.WebRequestHandler = this.MockResponder.MockWebRequestHandler; + public async Task AssertionWithEndpointFilter() { + var opStore = new StandardProviderApplicationStore(); + Handle(RPUri).By( + async req => { + var rp = new OpenIdRelyingParty(new StandardRelyingPartyApplicationStore(), this.HostFactories); // Rig it to always deny the incoming OP rp.EndpointFilter = op => false; // Receive the unsolicited assertion - var response = rp.GetResponse(); + var response = await rp.GetResponseAsync(req); + Assert.That(response, Is.Not.Null); Assert.AreEqual(AuthenticationStatus.Failed, response.Status); - }, - op => { - Identifier id = GetMockIdentifier(ProtocolVersion.V20); - op.SendUnsolicitedAssertion(OPUri, GetMockRealm(false), id, id); - AutoProvider(op); + return new HttpResponseMessage(); }); - coordinator.Run(); + this.RegisterAutoProvider(); + { + var op = new OpenIdProvider(opStore, this.HostFactories); + Identifier id = GetMockIdentifier(ProtocolVersion.V20); + var assertion = await op.PrepareUnsolicitedAssertionAsync(OPUri, GetMockRealm(false), id, id); + using (var httpClient = this.HostFactories.CreateHttpClient()) { + using (var response = await httpClient.GetAsync(assertion.Headers.Location)) { + response.EnsureSuccessStatusCode(); + } + } + } } } } diff --git a/src/DotNetOpenAuth.Test/OpenId/RelyingParty/PositiveAuthenticationResponseTests.cs b/src/DotNetOpenAuth.Test/OpenId/RelyingParty/PositiveAuthenticationResponseTests.cs index 91318f5..6ce74fd 100644 --- a/src/DotNetOpenAuth.Test/OpenId/RelyingParty/PositiveAuthenticationResponseTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/RelyingParty/PositiveAuthenticationResponseTests.cs @@ -7,6 +7,9 @@ namespace DotNetOpenAuth.Test.OpenId.RelyingParty { using System; using System.Collections.Generic; + using System.Threading; + using System.Threading.Tasks; + using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId; using DotNetOpenAuth.OpenId.Extensions.SimpleRegistration; @@ -29,12 +32,12 @@ namespace DotNetOpenAuth.Test.OpenId.RelyingParty { /// Verifies good, positive assertions are accepted. /// </summary> [Test] - public void Valid() { + public async Task Valid() { PositiveAssertionResponse assertion = this.GetPositiveAssertion(); ClaimsResponse extension = new ClaimsResponse(); assertion.Extensions.Add(extension); var rp = CreateRelyingParty(); - var authResponse = new PositiveAuthenticationResponse(assertion, rp); + var authResponse = await PositiveAuthenticationResponse.CreateAsync(assertion, rp, CancellationToken.None); Assert.AreEqual(AuthenticationStatus.Authenticated, authResponse.Status); Assert.IsNull(authResponse.Exception); Assert.AreEqual((string)assertion.ClaimedIdentifier, (string)authResponse.ClaimedIdentifier); @@ -49,13 +52,13 @@ namespace DotNetOpenAuth.Test.OpenId.RelyingParty { /// Verifies that discovery verification of a positive assertion can match a dual identifier. /// </summary> [Test] - public void DualIdentifierMatchesInAssertionVerification() { + public async Task DualIdentifierMatchesInAssertionVerification() { PositiveAssertionResponse assertion = this.GetPositiveAssertion(true); ClaimsResponse extension = new ClaimsResponse(); assertion.Extensions.Add(extension); var rp = CreateRelyingParty(); rp.SecuritySettings.AllowDualPurposeIdentifiers = true; - new PositiveAuthenticationResponse(assertion, rp); // this will throw if it fails to find a match + await PositiveAuthenticationResponse.CreateAsync(assertion, rp, CancellationToken.None); // this will throw if it fails to find a match } /// <summary> @@ -63,12 +66,12 @@ namespace DotNetOpenAuth.Test.OpenId.RelyingParty { /// if the default settings are in place. /// </summary> [Test, ExpectedException(typeof(ProtocolException))] - public void DualIdentifierNoMatchInAssertionVerificationByDefault() { + public async Task DualIdentifierNoMatchInAssertionVerificationByDefault() { PositiveAssertionResponse assertion = this.GetPositiveAssertion(true); ClaimsResponse extension = new ClaimsResponse(); assertion.Extensions.Add(extension); var rp = CreateRelyingParty(); - new PositiveAuthenticationResponse(assertion, rp); // this will throw if it fails to find a match + await PositiveAuthenticationResponse.CreateAsync(assertion, rp, CancellationToken.None); // this will throw if it fails to find a match } /// <summary> @@ -77,11 +80,11 @@ namespace DotNetOpenAuth.Test.OpenId.RelyingParty { /// that the OP has no authority to assert positively regarding. /// </summary> [Test, ExpectedException(typeof(ProtocolException))] - public void SpoofedClaimedIdDetectionSolicited() { + public async Task SpoofedClaimedIdDetectionSolicited() { PositiveAssertionResponse assertion = this.GetPositiveAssertion(); assertion.ProviderEndpoint = new Uri("http://rogueOP"); var rp = CreateRelyingParty(); - var authResponse = new PositiveAuthenticationResponse(assertion, rp); + var authResponse = await PositiveAuthenticationResponse.CreateAsync(assertion, rp, CancellationToken.None); Assert.AreEqual(AuthenticationStatus.Failed, authResponse.Status); } @@ -90,22 +93,22 @@ namespace DotNetOpenAuth.Test.OpenId.RelyingParty { /// Cdentifiers when RequireSsl is set to true. /// </summary> [Test, ExpectedException(typeof(ProtocolException))] - public void InsecureIdentifiersRejectedWithRequireSsl() { + public async Task InsecureIdentifiersRejectedWithRequireSsl() { PositiveAssertionResponse assertion = this.GetPositiveAssertion(); var rp = CreateRelyingParty(); rp.SecuritySettings.RequireSsl = true; - var authResponse = new PositiveAuthenticationResponse(assertion, rp); + var authResponse = await PositiveAuthenticationResponse.CreateAsync(assertion, rp, CancellationToken.None); } [Test] - public void GetCallbackArguments() { + public async Task GetCallbackArguments() { PositiveAssertionResponse assertion = this.GetPositiveAssertion(); var rp = CreateRelyingParty(); UriBuilder returnToBuilder = new UriBuilder(assertion.ReturnTo); returnToBuilder.AppendQueryArgs(new Dictionary<string, string> { { "a", "b" } }); assertion.ReturnTo = returnToBuilder.Uri; - var authResponse = new PositiveAuthenticationResponse(assertion, rp); + var authResponse = await PositiveAuthenticationResponse.CreateAsync(assertion, rp, CancellationToken.None); // First pretend that the return_to args were signed. assertion.ReturnToParametersSignatureValidated = true; @@ -124,18 +127,18 @@ namespace DotNetOpenAuth.Test.OpenId.RelyingParty { /// Verifies that certain problematic claimed identifiers pass through to the RP response correctly. /// </summary> [Test] - public void ProblematicClaimedId() { + public async Task ProblematicClaimedId() { var providerEndpoint = new ProviderEndpointDescription(OpenIdTestBase.OPUri, Protocol.Default.Version); string claimed_id = BaseMockUri + "a./b."; var se = IdentifierDiscoveryResult.CreateForClaimedIdentifier(claimed_id, claimed_id, providerEndpoint, null, null); - UriIdentifier identityUri = (UriIdentifier)se.ClaimedIdentifier; - var mockId = new MockIdentifier(identityUri, this.MockResponder, new IdentifierDiscoveryResult[] { se }); + var identityUri = (UriIdentifier)se.ClaimedIdentifier; + this.RegisterMockXrdsResponse(se); + var rp = this.CreateRelyingParty(); var positiveAssertion = this.GetPositiveAssertion(); - positiveAssertion.ClaimedIdentifier = mockId; - positiveAssertion.LocalIdentifier = mockId; - var rp = CreateRelyingParty(); - var authResponse = new PositiveAuthenticationResponse(positiveAssertion, rp); + positiveAssertion.ClaimedIdentifier = claimed_id; + positiveAssertion.LocalIdentifier = claimed_id; + var authResponse = await PositiveAuthenticationResponse.CreateAsync(positiveAssertion, rp, CancellationToken.None); Assert.AreEqual(AuthenticationStatus.Authenticated, authResponse.Status); Assert.AreEqual(claimed_id, authResponse.ClaimedIdentifier.ToString()); } diff --git a/src/DotNetOpenAuth.Test/OpenId/UriIdentifierTests.cs b/src/DotNetOpenAuth.Test/OpenId/UriIdentifierTests.cs index b6a52a7..465fd23 100644 --- a/src/DotNetOpenAuth.Test/OpenId/UriIdentifierTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/UriIdentifierTests.cs @@ -8,6 +8,7 @@ namespace DotNetOpenAuth.Test.OpenId { using System; using System.Linq; using System.Net; + using System.Threading.Tasks; using System.Web; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId; @@ -243,7 +244,7 @@ namespace DotNetOpenAuth.Test.OpenId { } [Test] - public void TryRequireSslAdjustsIdentifier() { + public async Task TryRequireSslAdjustsIdentifier() { Identifier secureId; // Try Parse and ctor without explicit scheme var id = Identifier.Parse("www.yahoo.com"); @@ -263,13 +264,13 @@ namespace DotNetOpenAuth.Test.OpenId { Assert.IsFalse(id.TryRequireSsl(out secureId)); Assert.IsTrue(secureId.IsDiscoverySecureEndToEnd, "Although the TryRequireSsl failed, the created identifier should retain the Ssl status."); Assert.AreEqual("http://www.yahoo.com/", secureId.ToString()); - Assert.AreEqual(0, Discover(secureId).Count(), "Since TryRequireSsl failed, the created Identifier should never discover anything."); + Assert.AreEqual(0, (await DiscoverAsync(secureId)).Count(), "Since TryRequireSsl failed, the created Identifier should never discover anything."); id = new UriIdentifier("http://www.yahoo.com"); Assert.IsFalse(id.TryRequireSsl(out secureId)); Assert.IsTrue(secureId.IsDiscoverySecureEndToEnd); Assert.AreEqual("http://www.yahoo.com/", secureId.ToString()); - Assert.AreEqual(0, Discover(secureId).Count()); + Assert.AreEqual(0, (await DiscoverAsync(secureId)).Count()); } /// <summary> diff --git a/src/DotNetOpenAuth.Test/TestBase.cs b/src/DotNetOpenAuth.Test/TestBase.cs index 92adafa..4758b7d 100644 --- a/src/DotNetOpenAuth.Test/TestBase.cs +++ b/src/DotNetOpenAuth.Test/TestBase.cs @@ -7,7 +7,12 @@ namespace DotNetOpenAuth.Test { using System; using System.IO; + using System.Net; + using System.Net.Http; + using System.Net.Http.Headers; using System.Reflection; + using System.Threading; + using System.Threading.Tasks; using System.Web; using DotNetOpenAuth.Messaging.Reflection; using DotNetOpenAuth.OAuth.Messages; @@ -16,6 +21,8 @@ namespace DotNetOpenAuth.Test { using log4net; using NUnit.Framework; + using log4net.Config; + /// <summary> /// The base class that all test classes inherit from. /// </summary> @@ -48,14 +55,17 @@ namespace DotNetOpenAuth.Test { get { return this.messageDescriptions; } } + internal MockingHostFactories HostFactories; + /// <summary> /// The TestInitialize method for the test cases. /// </summary> [SetUp] public virtual void SetUp() { - log4net.Config.XmlConfigurator.Configure(Assembly.GetExecutingAssembly().GetManifestResourceStream("DotNetOpenAuth.Test.Logging.config")); + XmlConfigurator.Configure(Assembly.GetExecutingAssembly().GetManifestResourceStream("DotNetOpenAuth.Test.Logging.config")); MessageBase.LowSecurityMode = true; this.messageDescriptions = new MessageDescriptionCollection(); + this.HostFactories = new MockingHostFactories(); SetMockHttpContext(); } @@ -64,10 +74,10 @@ namespace DotNetOpenAuth.Test { /// </summary> [TearDown] public virtual void Cleanup() { - log4net.LogManager.Shutdown(); + LogManager.Shutdown(); } - internal static Stats MeasurePerformance(Action action, float maximumAllowedUnitTime, int samples = 10, int iterations = 100, string name = null) { + internal static Stats MeasurePerformance(Func<Task> action, float maximumAllowedUnitTime, int samples = 10, int iterations = 100, string name = null) { if (!PerformanceTestUtilities.IsOptimized(typeof(OpenIdRelyingParty).Assembly)) { Assert.Inconclusive("Unoptimized code."); } @@ -75,7 +85,7 @@ namespace DotNetOpenAuth.Test { var timer = new MultiSampleCodeTimer(samples, iterations); Stats stats; using (new HighPerformance()) { - stats = timer.Measure(name ?? TestContext.CurrentContext.Test.FullName, action); + stats = timer.Measure(name ?? TestContext.CurrentContext.Test.FullName, () => action().Wait()); } stats.AdjustForScale(PerformanceTestUtilities.Baseline.Median); @@ -103,5 +113,49 @@ namespace DotNetOpenAuth.Test { new HttpRequest("mock", "http://mock", "mock"), new HttpResponse(new StringWriter())); } + + protected internal Handler Handle(string uri) { + return new Handler(this, new Uri(uri)); + } + + protected internal Handler Handle(Uri uri) { + return new Handler(this, uri); + } + + protected internal struct Handler { + private TestBase test; + + internal Handler(TestBase test, Uri uri) + : this() { + this.test = test; + this.Uri = uri; + } + + internal Uri Uri { get; private set; } + + internal Func<HttpRequestMessage, Task<HttpResponseMessage>> MessageHandler { get; private set; } + + internal void By(Func<HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>> handler) { + this.test.HostFactories.Handlers[this.Uri] = req => handler(req, CancellationToken.None); + } + + internal void By(Func<HttpRequestMessage, Task<HttpResponseMessage>> handler) { + this.test.HostFactories.Handlers[this.Uri] = handler; + } + + internal void By(Func<HttpRequestMessage, HttpResponseMessage> handler) { + this.By(req => Task.FromResult(handler(req))); + } + + internal void By(string responseContent, string contentType, HttpStatusCode statusCode = HttpStatusCode.OK) { + this.By( + req => { + var response = new HttpResponseMessage(statusCode); + response.Content = new StringContent(responseContent); + response.Content.Headers.ContentType = new MediaTypeHeaderValue(contentType); + return response; + }); + } + } } } diff --git a/src/DotNetOpenAuth.Test/packages.config b/src/DotNetOpenAuth.Test/packages.config index 38e70d7..cab5101 100644 --- a/src/DotNetOpenAuth.Test/packages.config +++ b/src/DotNetOpenAuth.Test/packages.config @@ -3,5 +3,5 @@ <package id="Microsoft.Net.Http" version="2.0.20710.0" targetFramework="net45" /> <package id="Moq" version="4.0.10827" targetFramework="net45" /> <package id="NUnit" version="2.6.1" targetFramework="net45" /> - <package id="Validation" version="2.0.1.12362" targetFramework="net45" /> + <package id="Validation" version="2.0.2.13022" targetFramework="net45" /> </packages>
\ No newline at end of file |