diff options
Diffstat (limited to 'src')
29 files changed, 936 insertions, 1042 deletions
diff --git a/src/DotNetOpenAuth.OAuth.Consumer/OAuth/Consumer.cs b/src/DotNetOpenAuth.OAuth.Consumer/OAuth/Consumer.cs index 560e536..93fbaac 100644 --- a/src/DotNetOpenAuth.OAuth.Consumer/OAuth/Consumer.cs +++ b/src/DotNetOpenAuth.OAuth.Consumer/OAuth/Consumer.cs @@ -39,21 +39,24 @@ namespace DotNetOpenAuth.OAuth { } /// <summary> - /// Initializes a new instance of the <see cref="Consumer"/> class. + /// Initializes a new instance of the <see cref="Consumer" /> class. /// </summary> /// <param name="consumerKey">The consumer key.</param> /// <param name="consumerSecret">The consumer secret.</param> /// <param name="serviceProvider">The service provider.</param> /// <param name="temporaryCredentialStorage">The temporary credential storage.</param> + /// <param name="hostFactories">The host factories.</param> public Consumer( string consumerKey, string consumerSecret, ServiceProviderDescription serviceProvider, - ITemporaryCredentialStorage temporaryCredentialStorage) { + ITemporaryCredentialStorage temporaryCredentialStorage, + IHostFactories hostFactories = null) { this.ConsumerKey = consumerKey; this.ConsumerSecret = consumerSecret; this.ServiceProvider = serviceProvider; this.TemporaryCredentialStorage = temporaryCredentialStorage; + this.HostFactories = hostFactories ?? new DefaultOAuthHostFactories(); } /// <summary> diff --git a/src/DotNetOpenAuth.OAuth2.Client/OAuth2/WebServerClient.cs b/src/DotNetOpenAuth.OAuth2.Client/OAuth2/WebServerClient.cs index 5560fd5..2b5a80a 100644 --- a/src/DotNetOpenAuth.OAuth2.Client/OAuth2/WebServerClient.cs +++ b/src/DotNetOpenAuth.OAuth2.Client/OAuth2/WebServerClient.cs @@ -84,7 +84,7 @@ namespace DotNetOpenAuth.OAuth2 { /// <returns> /// The authorization request. /// </returns> - public async Task<HttpResponseMessage> PrepareRequestUserAuthorizationAsync(IAuthorizationState authorization, CancellationToken cancellationToken) { + public async Task<HttpResponseMessage> PrepareRequestUserAuthorizationAsync(IAuthorizationState authorization, CancellationToken cancellationToken = default(CancellationToken)) { Requires.NotNull(authorization, "authorization"); RequiresEx.ValidState(authorization.Callback != null || (HttpContext.Current != null && HttpContext.Current.Request != null), MessagingStrings.HttpContextRequired); RequiresEx.ValidState(!string.IsNullOrEmpty(this.ClientIdentifier), Strings.RequiredPropertyNotYetPreset, "ClientIdentifier"); diff --git a/src/DotNetOpenAuth.Test/CoordinatorBase.cs b/src/DotNetOpenAuth.Test/CoordinatorBase.cs deleted file mode 100644 index 12b03f8..0000000 --- a/src/DotNetOpenAuth.Test/CoordinatorBase.cs +++ /dev/null @@ -1,43 +0,0 @@ -//----------------------------------------------------------------------- -// <copyright file="CoordinatorBase.cs" company="Outercurve Foundation"> -// Copyright (c) Outercurve Foundation. All rights reserved. -// </copyright> -//----------------------------------------------------------------------- - -namespace DotNetOpenAuth.Test { - using System; - using System.Collections.Generic; - using System.Net; - using System.Net.Http; - using System.Net.Http.Headers; - using System.Threading; - using System.Threading.Tasks; - using DotNetOpenAuth.Messaging; - using DotNetOpenAuth.OpenId.Provider; - using DotNetOpenAuth.OpenId.RelyingParty; - using DotNetOpenAuth.Test.Mocks; - using DotNetOpenAuth.Test.OpenId; - - using NUnit.Framework; - using Validation; - - using System.Linq; - - internal class CoordinatorBase { - private Func<IHostFactories, CancellationToken, Task> driver; - - internal CoordinatorBase(Func<IHostFactories, CancellationToken, Task> driver, params TestBase.Handler[] handlers) { - Requires.NotNull(driver, "driver"); - Requires.NotNull(handlers, "handlers"); - - this.driver = driver; - this.HostFactories = new MockingHostFactories(handlers.ToList()); - } - - internal MockingHostFactories HostFactories { get; set; } - - protected internal virtual async Task RunAsync(CancellationToken cancellationToken = default(CancellationToken)) { - await this.driver(this.HostFactories, cancellationToken); - } - } -} diff --git a/src/DotNetOpenAuth.Test/DotNetOpenAuth.Test.csproj b/src/DotNetOpenAuth.Test/DotNetOpenAuth.Test.csproj index d777f50..07ee8a9 100644 --- a/src/DotNetOpenAuth.Test/DotNetOpenAuth.Test.csproj +++ b/src/DotNetOpenAuth.Test/DotNetOpenAuth.Test.csproj @@ -104,7 +104,6 @@ </ItemGroup> <ItemGroup> <Compile Include="Configuration\SectionTests.cs" /> - <Compile Include="CoordinatorBase.cs" /> <Compile Include="Hosting\AspNetHost.cs" /> <Compile Include="Hosting\HostingTests.cs" /> <Compile Include="Hosting\HttpHost.cs" /> diff --git a/src/DotNetOpenAuth.Test/MockingHostFactories.cs b/src/DotNetOpenAuth.Test/MockingHostFactories.cs index a7d24e4..16a746a 100644 --- a/src/DotNetOpenAuth.Test/MockingHostFactories.cs +++ b/src/DotNetOpenAuth.Test/MockingHostFactories.cs @@ -49,7 +49,7 @@ namespace DotNetOpenAuth.Test { protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) { foreach (var handler in this.handlers) { if (handler.Uri.IsBaseOf(request.RequestUri) && handler.Uri.AbsolutePath == request.RequestUri.AbsolutePath) { - var response = await handler.MessageHandler(this.hostFactories, request, cancellationToken); + var response = await handler.MessageHandler(request); if (response != null) { return response; } diff --git a/src/DotNetOpenAuth.Test/Mocks/MockHttpRequest.cs b/src/DotNetOpenAuth.Test/Mocks/MockHttpRequest.cs index dbf57bc..87faac2 100644 --- a/src/DotNetOpenAuth.Test/Mocks/MockHttpRequest.cs +++ b/src/DotNetOpenAuth.Test/Mocks/MockHttpRequest.cs @@ -22,7 +22,7 @@ namespace DotNetOpenAuth.Test.Mocks { using Validation; internal static class MockHttpRequest { - internal static TestBase.Handler RegisterMockXrdsResponse(IdentifierDiscoveryResult endpoint) { + internal static void RegisterMockXrdsResponse(this TestBase test, IdentifierDiscoveryResult endpoint) { Requires.NotNull(endpoint, "endpoint"); string identityUri; @@ -32,10 +32,10 @@ namespace DotNetOpenAuth.Test.Mocks { identityUri = endpoint.UserSuppliedIdentifier ?? endpoint.ClaimedIdentifier; } - return RegisterMockXrdsResponse(new Uri(identityUri), new IdentifierDiscoveryResult[] { endpoint }); + RegisterMockXrdsResponse(test, new Uri(identityUri), new IdentifierDiscoveryResult[] { endpoint }); } - internal static TestBase.Handler RegisterMockXrdsResponse(Uri respondingUri, IEnumerable<IdentifierDiscoveryResult> endpoints) { + internal static void RegisterMockXrdsResponse(this TestBase test, Uri respondingUri, IEnumerable<IdentifierDiscoveryResult> endpoints) { Requires.NotNull(endpoints, "endpoints"); var xrds = new StringBuilder(); @@ -67,10 +67,10 @@ namespace DotNetOpenAuth.Test.Mocks { </XRD> </xrds:XRDS>"); - return TestBase.Handle(respondingUri).By(xrds.ToString(), ContentTypes.Xrds); + test.Handle(respondingUri).By(xrds.ToString(), ContentTypes.Xrds); } - internal static TestBase.Handler RegisterMockXrdsResponse(UriIdentifier directedIdentityAssignedIdentifier, IdentifierDiscoveryResult providerEndpoint) { + internal static void RegisterMockXrdsResponse(this TestBase test, UriIdentifier directedIdentityAssignedIdentifier, IdentifierDiscoveryResult providerEndpoint) { IdentifierDiscoveryResult identityEndpoint = IdentifierDiscoveryResult.CreateForClaimedIdentifier( directedIdentityAssignedIdentifier, directedIdentityAssignedIdentifier, @@ -78,16 +78,16 @@ namespace DotNetOpenAuth.Test.Mocks { new ProviderEndpointDescription(providerEndpoint.ProviderEndpoint, providerEndpoint.Capabilities), 10, 10); - return RegisterMockXrdsResponse(identityEndpoint); + RegisterMockXrdsResponse(test, identityEndpoint); } - internal static TestBase.Handler RegisterMockXrdsResponse(string embeddedResourcePath, out Identifier id) { + internal static void RegisterMockXrdsResponse(this TestBase test, string embeddedResourcePath, out Identifier id) { id = new Uri(new Uri("http://localhost/"), embeddedResourcePath); - return TestBase.Handle(new Uri(id)) - .By(OpenIdTestBase.LoadEmbeddedFile(embeddedResourcePath), "application/xrds+xml"); + test.Handle(new Uri(id)) + .By(OpenIdTestBase.LoadEmbeddedFile(embeddedResourcePath), "application/xrds+xml"); } - internal static TestBase.Handler RegisterMockRPDiscovery(bool ssl) { + internal static void RegisterMockRPDiscovery(this TestBase test, bool ssl) { string template = @"<xrds:XRDS xmlns:xrds='xri://$xrds' xmlns:openid='http://openid.net/xmlns/1.0' xmlns='xri://$xrd*($v*2.0)'> <XRD> <Service priority='10'> @@ -104,38 +104,34 @@ namespace DotNetOpenAuth.Test.Mocks { HttpUtility.HtmlEncode(OpenIdTestBase.RPRealmUri.AbsoluteUri), HttpUtility.HtmlEncode(OpenIdTestBase.RPRealmUriSsl.AbsoluteUri)); - return new TestBase.Handler(ssl ? OpenIdTestBase.RPRealmUriSsl : OpenIdTestBase.RPRealmUri) + test.Handle(ssl ? OpenIdTestBase.RPRealmUriSsl : OpenIdTestBase.RPRealmUri) .By(xrds, ContentTypes.Xrds); } - internal static TestBase.Handler RegisterMockRedirect(Uri origin, Uri redirectLocation) { + internal static void RegisterMockRedirect(this TestBase test, Uri origin, Uri redirectLocation) { var response = new HttpResponseMessage(HttpStatusCode.Redirect); response.Headers.Location = redirectLocation; - return new TestBase.Handler(origin).By(req => response); + test.Handle(origin).By(req => response); } - internal static TestBase.Handler[] RegisterMockXrdsResponses( - IEnumerable<KeyValuePair<string, string>> urlXrdsPairs) { + internal static void RegisterMockXrdsResponses(this TestBase test, IEnumerable<KeyValuePair<string, string>> urlXrdsPairs) { Requires.NotNull(urlXrdsPairs, "urlXrdsPairs"); - var results = new List<TestBase.Handler>(); foreach (var keyValuePair in urlXrdsPairs) { - results.Add(TestBase.Handle(new Uri(keyValuePair.Key)).By(keyValuePair.Value, ContentTypes.Xrds)); + test.Handle(new Uri(keyValuePair.Key)).By(keyValuePair.Value, ContentTypes.Xrds); } - - return results.ToArray(); } - internal static TestBase.Handler RegisterMockResponse(Uri url, string contentType, string content) { - return TestBase.Handle(url).By(content, contentType); + internal static void RegisterMockResponse(this TestBase test, Uri url, string contentType, string content) { + test.Handle(url).By(content, contentType); } - internal static TestBase.Handler RegisterMockResponse(Uri requestUri, Uri responseUri, string contentType, string content) { - return RegisterMockResponse(requestUri, responseUri, contentType, null, content); + internal static void RegisterMockResponse(this TestBase test, Uri requestUri, Uri responseUri, string contentType, string content) { + RegisterMockResponse(test, requestUri, responseUri, contentType, null, content); } - internal static TestBase.Handler RegisterMockResponse(Uri requestUri, Uri responseUri, string contentType, WebHeaderCollection headers, string content) { - return TestBase.Handle(requestUri).By(req => { + internal static void RegisterMockResponse(this TestBase test, Uri requestUri, Uri responseUri, string contentType, WebHeaderCollection headers, string content) { + test.Handle(requestUri).By(req => { var response = new HttpResponseMessage(); response.CopyHeadersFrom(headers); response.Content = new StringContent(content, Encoding.Default, contentType); diff --git a/src/DotNetOpenAuth.Test/OAuth/AppendixScenarios.cs b/src/DotNetOpenAuth.Test/OAuth/AppendixScenarios.cs index 2d58b1d..0d05c5e 100644 --- a/src/DotNetOpenAuth.Test/OAuth/AppendixScenarios.cs +++ b/src/DotNetOpenAuth.Test/OAuth/AppendixScenarios.cs @@ -38,61 +38,57 @@ namespace DotNetOpenAuth.Test.OAuth { tokenManager.AddConsumer(consumerDescription); var sp = new ServiceProvider(serviceHostDescription, tokenManager); - var coordinator = new CoordinatorBase( - async (hostFactories, ct) => { - var consumer = new Consumer( - consumerDescription.ConsumerKey, - consumerDescription.ConsumerSecret, - serviceDescription, - new MemoryTemporaryCredentialStorage()); - consumer.HostFactories = hostFactories; - var authorizeUrl = await consumer.RequestUserAuthorizationAsync(new Uri("http://printer.example.com/request_token_ready")); - Uri authorizeResponseUri; - using (var httpClient = hostFactories.CreateHttpClient()) { - using (var response = await httpClient.GetAsync(authorizeUrl, ct)) { - Assert.That(response.StatusCode, Is.EqualTo(HttpStatusCode.Redirect)); - authorizeResponseUri = response.Headers.Location; - } - } + Handle(serviceDescription.TemporaryCredentialsRequestEndpoint).By( + async (request, ct) => { + var requestTokenMessage = await sp.ReadTokenRequestAsync(request, ct); + return await sp.Channel.PrepareResponseAsync(sp.PrepareUnauthorizedTokenMessage(requestTokenMessage)); + }); + Handle(serviceDescription.ResourceOwnerAuthorizationEndpoint).By( + async (request, ct) => { + var authRequest = await sp.ReadAuthorizationRequestAsync(request, ct); + ((InMemoryTokenManager)sp.TokenManager).AuthorizeRequestToken(authRequest.RequestToken); + return await sp.Channel.PrepareResponseAsync(sp.PrepareAuthorizationResponse(authRequest)); + }); + Handle(serviceDescription.TokenRequestEndpoint).By( + async (request, ct) => { + var accessRequest = await sp.ReadAccessTokenRequestAsync(request, ct); + return await sp.Channel.PrepareResponseAsync(sp.PrepareAccessTokenMessage(accessRequest), ct); + }); + Handle(accessPhotoEndpoint).By( + async (request, ct) => { + string accessToken = (await sp.ReadProtectedResourceAuthorizationAsync(request)).AccessToken; + Assert.That(accessToken, Is.Not.Null.And.Not.Empty); + var responseMessage = new HttpResponseMessage { Content = new ByteArrayContent(new byte[] { 0x33, 0x66 }), }; + responseMessage.Content.Headers.ContentType = new MediaTypeHeaderValue("image/jpeg"); + return responseMessage; + }); - var accessTokenResponse = await consumer.ProcessUserAuthorizationAsync(authorizeResponseUri, ct); - Assert.That(accessTokenResponse, Is.Not.Null); + var consumer = new Consumer( + consumerDescription.ConsumerKey, + consumerDescription.ConsumerSecret, + serviceDescription, + new MemoryTemporaryCredentialStorage()); + consumer.HostFactories = this.HostFactories; + var authorizeUrl = await consumer.RequestUserAuthorizationAsync(new Uri("http://printer.example.com/request_token_ready")); + Uri authorizeResponseUri; + using (var httpClient = this.HostFactories.CreateHttpClient()) { + using (var response = await httpClient.GetAsync(authorizeUrl)) { + Assert.That(response.StatusCode, Is.EqualTo(HttpStatusCode.Redirect)); + authorizeResponseUri = response.Headers.Location; + } + } - using (var authorizingClient = consumer.CreateHttpClient(accessTokenResponse.AccessToken)) { - using (var protectedPhoto = await authorizingClient.GetAsync(accessPhotoEndpoint, ct)) { - Assert.That(protectedPhoto, Is.Not.Null); - protectedPhoto.EnsureSuccessStatusCode(); - Assert.That("image/jpeg", Is.EqualTo(protectedPhoto.Content.Headers.ContentType.MediaType)); - Assert.That(protectedPhoto.Content.Headers.ContentLength, Is.Not.EqualTo(0)); - } - } - }, - Handle(serviceDescription.TemporaryCredentialsRequestEndpoint).By( - async (request, ct) => { - var requestTokenMessage = await sp.ReadTokenRequestAsync(request, ct); - return await sp.Channel.PrepareResponseAsync(sp.PrepareUnauthorizedTokenMessage(requestTokenMessage)); - }), - Handle(serviceDescription.ResourceOwnerAuthorizationEndpoint).By( - async (request, ct) => { - var authRequest = await sp.ReadAuthorizationRequestAsync(request, ct); - ((InMemoryTokenManager)sp.TokenManager).AuthorizeRequestToken(authRequest.RequestToken); - return await sp.Channel.PrepareResponseAsync(sp.PrepareAuthorizationResponse(authRequest)); - }), - Handle(serviceDescription.TokenRequestEndpoint).By( - async (request, ct) => { - var accessRequest = await sp.ReadAccessTokenRequestAsync(request, ct); - return await sp.Channel.PrepareResponseAsync(sp.PrepareAccessTokenMessage(accessRequest), ct); - }), - Handle(accessPhotoEndpoint).By( - async (request, ct) => { - string accessToken = (await sp.ReadProtectedResourceAuthorizationAsync(request)).AccessToken; - Assert.That(accessToken, Is.Not.Null.And.Not.Empty); - var responseMessage = new HttpResponseMessage { Content = new ByteArrayContent(new byte[] { 0x33, 0x66 }), }; - responseMessage.Content.Headers.ContentType = new MediaTypeHeaderValue("image/jpeg"); - return responseMessage; - })); + var accessTokenResponse = await consumer.ProcessUserAuthorizationAsync(authorizeResponseUri); + Assert.That(accessTokenResponse, Is.Not.Null); - await coordinator.RunAsync(); + using (var authorizingClient = consumer.CreateHttpClient(accessTokenResponse.AccessToken)) { + using (var protectedPhoto = await authorizingClient.GetAsync(accessPhotoEndpoint)) { + Assert.That(protectedPhoto, Is.Not.Null); + protectedPhoto.EnsureSuccessStatusCode(); + Assert.That("image/jpeg", Is.EqualTo(protectedPhoto.Content.Headers.ContentType.MediaType)); + Assert.That(protectedPhoto.Content.Headers.ContentLength, Is.Not.EqualTo(0)); + } + } } } } diff --git a/src/DotNetOpenAuth.Test/OAuth/ChannelElements/OAuthChannelTests.cs b/src/DotNetOpenAuth.Test/OAuth/ChannelElements/OAuthChannelTests.cs index a1db784..834aba2 100644 --- a/src/DotNetOpenAuth.Test/OAuth/ChannelElements/OAuthChannelTests.cs +++ b/src/DotNetOpenAuth.Test/OAuth/ChannelElements/OAuthChannelTests.cs @@ -301,17 +301,8 @@ namespace DotNetOpenAuth.Test.OAuth.ChannelElements { HttpMethods = scheme, }; - await RunAsync( - async (hostFactories, CancellationToken) => { - IProtocolMessage response = await this.channel.RequestAsync(request, CancellationToken.None); - Assert.IsNotNull(response); - Assert.IsInstanceOf<TestMessage>(response); - TestMessage responseMessage = (TestMessage)response; - Assert.AreEqual(request.Age, responseMessage.Age); - Assert.AreEqual(request.Name, responseMessage.Name); - Assert.AreEqual(request.Location, responseMessage.Location); - }, - Handle(request.Location).By(async (req, ct) => { + Handle(request.Location).By( + async (req, ct) => { Assert.IsNotNull(req); Assert.AreEqual(MessagingUtilities.GetHttpVerb(scheme), req.Method); var incomingMessage = (await this.channel.ReadFromRequestAsync(req, CancellationToken.None)) as TestMessage; @@ -330,7 +321,15 @@ namespace DotNetOpenAuth.Test.OAuth.ChannelElements { var rawResponse = new HttpResponseMessage(); rawResponse.Content = new StringContent(MessagingUtilities.CreateQueryString(responseFields)); return rawResponse; - })); + }); + + IProtocolMessage response = await this.channel.RequestAsync(request, CancellationToken.None); + Assert.IsNotNull(response); + Assert.IsInstanceOf<TestMessage>(response); + TestMessage responseMessage = (TestMessage)response; + Assert.AreEqual(request.Age, responseMessage.Age); + Assert.AreEqual(request.Name, responseMessage.Name); + Assert.AreEqual(request.Location, responseMessage.Location); } private async Task ParameterizedReceiveTestAsync(HttpDeliveryMethods scheme) { diff --git a/src/DotNetOpenAuth.Test/OAuth2/AuthorizationServerTests.cs b/src/DotNetOpenAuth.Test/OAuth2/AuthorizationServerTests.cs index b2f2666..3302db7 100644 --- a/src/DotNetOpenAuth.Test/OAuth2/AuthorizationServerTests.cs +++ b/src/DotNetOpenAuth.Test/OAuth2/AuthorizationServerTests.cs @@ -28,66 +28,60 @@ namespace DotNetOpenAuth.Test.OAuth2 { /// </summary> [Test] public async Task ErrorResponseTest() { - var coordinator = new CoordinatorBase( - async (hostFactories, ct) => { - var request = new AccessTokenAuthorizationCodeRequestC(AuthorizationServerDescription) { ClientIdentifier = ClientId, ClientSecret = ClientSecret, AuthorizationCode = "foo" }; - var client = new UserAgentClient(AuthorizationServerDescription, hostFactories: hostFactories); - var response = await client.Channel.RequestAsync<AccessTokenFailedResponse>(request, CancellationToken.None); - Assert.That(response.Error, Is.Not.Null.And.Not.Empty); - Assert.That(response.Error, Is.EqualTo(Protocol.AccessTokenRequestErrorCodes.InvalidRequest)); - }, - Handle(AuthorizationServerDescription.TokenEndpoint).By(async (req, ct) => { + Handle(AuthorizationServerDescription.TokenEndpoint).By( + async (req, ct) => { var server = new AuthorizationServer(AuthorizationServerMock); return await server.HandleTokenRequestAsync(req, ct); - })); - await coordinator.RunAsync(); + }); + var request = new AccessTokenAuthorizationCodeRequestC(AuthorizationServerDescription) { ClientIdentifier = ClientId, ClientSecret = ClientSecret, AuthorizationCode = "foo" }; + var client = new UserAgentClient(AuthorizationServerDescription, hostFactories: this.HostFactories); + var response = await client.Channel.RequestAsync<AccessTokenFailedResponse>(request, CancellationToken.None); + Assert.That(response.Error, Is.Not.Null.And.Not.Empty); + Assert.That(response.Error, Is.EqualTo(Protocol.AccessTokenRequestErrorCodes.InvalidRequest)); } [Test] public async Task DecodeRefreshToken() { var refreshTokenSource = new TaskCompletionSource<string>(); - var coordinator = new CoordinatorBase( - async (hostFactories, ct) => { - var client = new WebServerClient(AuthorizationServerDescription); - try { - var authState = new AuthorizationState(TestScopes) { Callback = ClientCallback, }; - var authRedirectResponse = await client.PrepareRequestUserAuthorizationAsync(authState, ct); - Uri authCompleteUri; - using (var httpClient = hostFactories.CreateHttpClient()) { - using (var response = await httpClient.GetAsync(authRedirectResponse.Headers.Location)) { - response.EnsureSuccessStatusCode(); - authCompleteUri = response.Headers.Location; - } - } - - var authCompleteRequest = new HttpRequestMessage(HttpMethod.Get, authCompleteUri); - authCompleteRequest.Headers.Add("Cookie", string.Join("; ", authRedirectResponse.Headers.GetValues("Set-Cookie"))); - var result = await client.ProcessUserAuthorizationAsync(authCompleteRequest, ct); - Assert.That(result.AccessToken, Is.Not.Null.And.Not.Empty); - Assert.That(result.RefreshToken, Is.Not.Null.And.Not.Empty); - refreshTokenSource.SetResult(result.RefreshToken); - } catch { - refreshTokenSource.TrySetCanceled(); + Handle(AuthorizationServerDescription.AuthorizationEndpoint).By( + async (req, ct) => { + var server = new AuthorizationServer(AuthorizationServerMock); + var request = await server.ReadAuthorizationRequestAsync(req, ct); + Assert.That(request, Is.Not.Null); + var response = server.PrepareApproveAuthorizationRequest(request, ResourceOwnerUsername); + return await server.Channel.PrepareResponseAsync(response); + }); + Handle(AuthorizationServerDescription.TokenEndpoint).By( + async (req, ct) => { + var server = new AuthorizationServer(AuthorizationServerMock); + var response = await server.HandleTokenRequestAsync(req, ct); + var authorization = server.DecodeRefreshToken(refreshTokenSource.Task.Result); + Assert.That(authorization, Is.Not.Null); + Assert.That(authorization.User, Is.EqualTo(ResourceOwnerUsername)); + return response; + }); + + var client = new WebServerClient(AuthorizationServerDescription); + try { + var authState = new AuthorizationState(TestScopes) { Callback = ClientCallback, }; + var authRedirectResponse = await client.PrepareRequestUserAuthorizationAsync(authState); + Uri authCompleteUri; + using (var httpClient = this.HostFactories.CreateHttpClient()) { + using (var response = await httpClient.GetAsync(authRedirectResponse.Headers.Location)) { + response.EnsureSuccessStatusCode(); + authCompleteUri = response.Headers.Location; } - }, - Handle(AuthorizationServerDescription.AuthorizationEndpoint).By( - async (req, ct) => { - var server = new AuthorizationServer(AuthorizationServerMock); - var request = await server.ReadAuthorizationRequestAsync(req, ct); - Assert.That(request, Is.Not.Null); - var response = server.PrepareApproveAuthorizationRequest(request, ResourceOwnerUsername); - return await server.Channel.PrepareResponseAsync(response); - }), - Handle(AuthorizationServerDescription.TokenEndpoint).By( - async (req, ct) => { - var server = new AuthorizationServer(AuthorizationServerMock); - var response = await server.HandleTokenRequestAsync(req, ct); - var authorization = server.DecodeRefreshToken(refreshTokenSource.Task.Result); - Assert.That(authorization, Is.Not.Null); - Assert.That(authorization.User, Is.EqualTo(ResourceOwnerUsername)); - return response; - })); - await coordinator.RunAsync(); + } + + var authCompleteRequest = new HttpRequestMessage(HttpMethod.Get, authCompleteUri); + authCompleteRequest.Headers.Add("Cookie", string.Join("; ", authRedirectResponse.Headers.GetValues("Set-Cookie"))); + var result = await client.ProcessUserAuthorizationAsync(authCompleteRequest); + Assert.That(result.AccessToken, Is.Not.Null.And.Not.Empty); + Assert.That(result.RefreshToken, Is.Not.Null.And.Not.Empty); + refreshTokenSource.SetResult(result.RefreshToken); + } catch { + refreshTokenSource.TrySetCanceled(); + } } [Test] @@ -104,21 +98,15 @@ namespace DotNetOpenAuth.Test.OAuth2 { return response; }); - // AuthorizationServerDescription, - //authServerMock.Object, - //new WebServerClient(AuthorizationServerDescription), - var coordinator = new CoordinatorBase( - async (hostFactories, ct) => { - var client = new WebServerClient(AuthorizationServerDescription, hostFactories: hostFactories); - var result = await client.ExchangeUserCredentialForTokenAsync(ResourceOwnerUsername, ResourceOwnerPassword, clientRequestedScopes, ct); - Assert.That(result.Scope, Is.EquivalentTo(serverOverriddenScopes)); - }, - Handle(AuthorizationServerDescription.TokenEndpoint).By( - async (req, ct) => { - var server = new AuthorizationServer(authServerMock.Object); - return await server.HandleTokenRequestAsync(req, ct); - })); - await coordinator.RunAsync(); + Handle(AuthorizationServerDescription.TokenEndpoint).By( + async (req, ct) => { + var server = new AuthorizationServer(authServerMock.Object); + return await server.HandleTokenRequestAsync(req, ct); + }); + + var client = new WebServerClient(AuthorizationServerDescription, hostFactories: this.HostFactories); + var result = await client.ExchangeUserCredentialForTokenAsync(ResourceOwnerUsername, ResourceOwnerPassword, clientRequestedScopes); + Assert.That(result.Scope, Is.EquivalentTo(serverOverriddenScopes)); } [Test] @@ -131,18 +119,16 @@ namespace DotNetOpenAuth.Test.OAuth2 { Assert.That(req.UserName, Is.EqualTo(ResourceOwnerUsername)); return response; }); - var coordinator = new CoordinatorBase( - async (hostFactories, ct) => { - var client = new WebServerClient(AuthorizationServerDescription, hostFactories: hostFactories); - var result = await client.ExchangeUserCredentialForTokenAsync(ResourceOwnerUsername, ResourceOwnerPassword, TestScopes, ct); - Assert.That(result.AccessToken, Is.Not.Null); - }, - Handle(AuthorizationServerDescription.TokenEndpoint).By( - async (req, ct) => { - var server = new AuthorizationServer(authServerMock.Object); - return await server.HandleTokenRequestAsync(req, ct); - })); - await coordinator.RunAsync(); + + Handle(AuthorizationServerDescription.TokenEndpoint).By( + async (req, ct) => { + var server = new AuthorizationServer(authServerMock.Object); + return await server.HandleTokenRequestAsync(req, ct); + }); + + var client = new WebServerClient(AuthorizationServerDescription, hostFactories: this.HostFactories); + var result = await client.ExchangeUserCredentialForTokenAsync(ResourceOwnerUsername, ResourceOwnerPassword, TestScopes); + Assert.That(result.AccessToken, Is.Not.Null); } [Test] @@ -154,18 +140,16 @@ namespace DotNetOpenAuth.Test.OAuth2 { Assert.That(req.UserName, Is.Null); return new AutomatedAuthorizationCheckResponse(req, true); }); - var coordinator = new CoordinatorBase( - async (hostFactories, ct) => { - var client = new WebServerClient(AuthorizationServerDescription, hostFactories: hostFactories); - var result = await client.GetClientAccessTokenAsync(TestScopes, ct); - Assert.That(result.AccessToken, Is.Not.Null); - }, - Handle(AuthorizationServerDescription.TokenEndpoint).By( - async (req, ct) => { - var server = new AuthorizationServer(authServerMock.Object); - return await server.HandleTokenRequestAsync(req, ct); - })); - await coordinator.RunAsync(); + + Handle(AuthorizationServerDescription.TokenEndpoint).By( + async (req, ct) => { + var server = new AuthorizationServer(authServerMock.Object); + return await server.HandleTokenRequestAsync(req, ct); + }); + + var client = new WebServerClient(AuthorizationServerDescription, hostFactories: this.HostFactories); + var result = await client.GetClientAccessTokenAsync(TestScopes); + Assert.That(result.AccessToken, Is.Not.Null); } [Test] @@ -177,40 +161,38 @@ namespace DotNetOpenAuth.Test.OAuth2 { Assert.That(req.User, Is.EqualTo(ResourceOwnerUsername)); return true; }); - var coordinator = new CoordinatorBase( - async (hostFactories, ct) => { - var client = new WebServerClient(AuthorizationServerDescription, hostFactories: hostFactories); - var authState = new AuthorizationState(TestScopes) { - Callback = ClientCallback, - }; - var authRedirectResponse = await client.PrepareRequestUserAuthorizationAsync(authState, ct); - Uri authCompleteUri; - using (var httpClient = hostFactories.CreateHttpClient()) { - using (var response = await httpClient.GetAsync(authRedirectResponse.Headers.Location)) { - response.EnsureSuccessStatusCode(); - authCompleteUri = response.Headers.Location; - } - } - var authCompleteRequest = new HttpRequestMessage(HttpMethod.Get, authCompleteUri); - var result = await client.ProcessUserAuthorizationAsync(authCompleteRequest, ct); - Assert.That(result.AccessToken, Is.Not.Null.And.Not.Empty); - Assert.That(result.RefreshToken, Is.Not.Null.And.Not.Empty); - }, - Handle(AuthorizationServerDescription.TokenEndpoint).By( - async (req, ct) => { - var server = new AuthorizationServer(authServerMock.Object); - var request = await server.ReadAuthorizationRequestAsync(req, ct); - Assert.That(request, Is.Not.Null); - var response = server.PrepareApproveAuthorizationRequest(request, ResourceOwnerUsername); - return await server.Channel.PrepareResponseAsync(response); - }), - Handle(AuthorizationServerDescription.TokenEndpoint).By( - async (req, ct) => { - var server = new AuthorizationServer(authServerMock.Object); - return await server.HandleTokenRequestAsync(req, ct); - })); - await coordinator.RunAsync(); + Handle(AuthorizationServerDescription.TokenEndpoint).By( + async (req, ct) => { + var server = new AuthorizationServer(authServerMock.Object); + var request = await server.ReadAuthorizationRequestAsync(req, ct); + Assert.That(request, Is.Not.Null); + var response = server.PrepareApproveAuthorizationRequest(request, ResourceOwnerUsername); + return await server.Channel.PrepareResponseAsync(response); + }); + Handle(AuthorizationServerDescription.TokenEndpoint).By( + async (req, ct) => { + var server = new AuthorizationServer(authServerMock.Object); + return await server.HandleTokenRequestAsync(req, ct); + }); + + var client = new WebServerClient(AuthorizationServerDescription, hostFactories: this.HostFactories); + var authState = new AuthorizationState(TestScopes) { + Callback = ClientCallback, + }; + var authRedirectResponse = await client.PrepareRequestUserAuthorizationAsync(authState); + Uri authCompleteUri; + using (var httpClient = this.HostFactories.CreateHttpClient()) { + using (var response = await httpClient.GetAsync(authRedirectResponse.Headers.Location)) { + response.EnsureSuccessStatusCode(); + authCompleteUri = response.Headers.Location; + } + } + + var authCompleteRequest = new HttpRequestMessage(HttpMethod.Get, authCompleteUri); + var result = await client.ProcessUserAuthorizationAsync(authCompleteRequest); + Assert.That(result.AccessToken, Is.Not.Null.And.Not.Empty); + Assert.That(result.RefreshToken, Is.Not.Null.And.Not.Empty); } [Test] @@ -226,21 +208,19 @@ namespace DotNetOpenAuth.Test.OAuth2 { response.ApprovedScope.UnionWith(serverOverriddenScopes); return response; }); - var coordinator = new CoordinatorBase( - async (hostFactories, ct) => { - var client = new WebServerClient(AuthorizationServerDescription, hostFactories: hostFactories); - var authState = new AuthorizationState(TestScopes) { - Callback = ClientCallback, - }; - var result = await client.GetClientAccessTokenAsync(clientRequestedScopes, ct); - Assert.That(result.Scope, Is.EquivalentTo(serverOverriddenScopes)); - }, - Handle(AuthorizationServerDescription.TokenEndpoint).By( - async (req, ct) => { - var server = new AuthorizationServer(authServerMock.Object); - return await server.HandleTokenRequestAsync(req, ct); - })); - await coordinator.RunAsync(); + + Handle(AuthorizationServerDescription.TokenEndpoint).By( + async (req, ct) => { + var server = new AuthorizationServer(authServerMock.Object); + return await server.HandleTokenRequestAsync(req, ct); + }); + + var client = new WebServerClient(AuthorizationServerDescription, hostFactories: this.HostFactories); + ////var authState = new AuthorizationState(TestScopes) { + //// Callback = ClientCallback, + ////}; + var result = await client.GetClientAccessTokenAsync(clientRequestedScopes); + Assert.That(result.Scope, Is.EquivalentTo(serverOverriddenScopes)); } } } diff --git a/src/DotNetOpenAuth.Test/OAuth2/ResourceServerTests.cs b/src/DotNetOpenAuth.Test/OAuth2/ResourceServerTests.cs index d4618cf..55b333d 100644 --- a/src/DotNetOpenAuth.Test/OAuth2/ResourceServerTests.cs +++ b/src/DotNetOpenAuth.Test/OAuth2/ResourceServerTests.cs @@ -91,19 +91,18 @@ namespace DotNetOpenAuth.Test.OAuth2 { authServer.Setup( a => a.CheckAuthorizeClientCredentialsGrant(It.Is<IAccessTokenRequest>(d => d.ClientIdentifier == ClientId && MessagingUtilities.AreEquivalent(d.Scope, TestScopes)))) .Returns<IAccessTokenRequest>(req => new AutomatedAuthorizationCheckResponse(req, true)); - var coordinator = new CoordinatorBase( - async (hostFactories, ct) => { - var client = new WebServerClient(AuthorizationServerDescription); - var authState = await client.GetClientAccessTokenAsync(TestScopes, ct); - Assert.That(authState.AccessToken, Is.Not.Null.And.Not.Empty); - Assert.That(authState.RefreshToken, Is.Null); - accessToken = authState.AccessToken; - }, - Handle(AuthorizationServerDescription.TokenEndpoint).By(async (req, ct) => { + + Handle(AuthorizationServerDescription.TokenEndpoint).By( + async (req, ct) => { var server = new AuthorizationServer(authServer.Object); return await server.HandleTokenRequestAsync(req, ct); - })); - await coordinator.RunAsync(); + }); + + var client = new WebServerClient(AuthorizationServerDescription, hostFactories: this.HostFactories); + var authState = await client.GetClientAccessTokenAsync(TestScopes); + Assert.That(authState.AccessToken, Is.Not.Null.And.Not.Empty); + Assert.That(authState.RefreshToken, Is.Null); + accessToken = authState.AccessToken; return accessToken; } diff --git a/src/DotNetOpenAuth.Test/OAuth2/UserAgentClientAuthorizeTests.cs b/src/DotNetOpenAuth.Test/OAuth2/UserAgentClientAuthorizeTests.cs index eaa0d44..c911416 100644 --- a/src/DotNetOpenAuth.Test/OAuth2/UserAgentClientAuthorizeTests.cs +++ b/src/DotNetOpenAuth.Test/OAuth2/UserAgentClientAuthorizeTests.cs @@ -10,6 +10,7 @@ namespace DotNetOpenAuth.Test.OAuth2 { using System.Linq; using System.Net.Http; using System.Text; + using System.Threading; using System.Threading.Tasks; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Bindings; @@ -23,79 +24,74 @@ namespace DotNetOpenAuth.Test.OAuth2 { public class UserAgentClientAuthorizeTests : OAuth2TestBase { [Test] public async Task AuthorizationCodeGrant() { - var coordinator = new CoordinatorBase( - async (hostFactories, ct) => { - var client = new UserAgentClient(AuthorizationServerDescription); - var authState = new AuthorizationState(TestScopes) { - Callback = ClientCallback, - }; - var request = client.PrepareRequestUserAuthorization(authState); - Assert.AreEqual(EndUserAuthorizationResponseType.AuthorizationCode, request.ResponseType); - var authRequestRedirect = await client.Channel.PrepareResponseAsync(request, ct); - Uri authRequestResponse; - using (var httpClient = hostFactories.CreateHttpClient()) { - using (var httpResponse = await httpClient.GetAsync(authRequestRedirect.Headers.Location, ct)) { - authRequestResponse = httpResponse.Headers.Location; - } + Handle(AuthorizationServerDescription.AuthorizationEndpoint).By( + async (req, ct) => { + var server = new AuthorizationServer(AuthorizationServerMock); + var request = await server.ReadAuthorizationRequestAsync(req, ct); + Assert.That(request, Is.Not.Null); + var response = server.PrepareApproveAuthorizationRequest(request, ResourceOwnerUsername); + return await server.Channel.PrepareResponseAsync(response, ct); + }); + Handle(AuthorizationServerDescription.TokenEndpoint).By( + async (req, ct) => { + var server = new AuthorizationServer(AuthorizationServerMock); + return await server.HandleTokenRequestAsync(req, ct); + }); + { + var client = new UserAgentClient(AuthorizationServerDescription); + var authState = new AuthorizationState(TestScopes) { Callback = ClientCallback, }; + var request = client.PrepareRequestUserAuthorization(authState); + Assert.AreEqual(EndUserAuthorizationResponseType.AuthorizationCode, request.ResponseType); + var authRequestRedirect = await client.Channel.PrepareResponseAsync(request); + Uri authRequestResponse; + using (var httpClient = this.HostFactories.CreateHttpClient()) { + using (var httpResponse = await httpClient.GetAsync(authRequestRedirect.Headers.Location)) { + authRequestResponse = httpResponse.Headers.Location; } - var incoming = await client.Channel.ReadFromRequestAsync(new HttpRequestMessage(HttpMethod.Get, authRequestResponse), ct); - var result = await client.ProcessUserAuthorizationAsync(authState, incoming, ct); - Assert.That(result.AccessToken, Is.Not.Null.And.Not.Empty); - Assert.That(result.RefreshToken, Is.Not.Null.And.Not.Empty); - }, - Handle(AuthorizationServerDescription.AuthorizationEndpoint).By( - async (req, ct) => { - var server = new AuthorizationServer(AuthorizationServerMock); - var request = await server.ReadAuthorizationRequestAsync(req, ct); - Assert.That(request, Is.Not.Null); - var response = server.PrepareApproveAuthorizationRequest(request, ResourceOwnerUsername); - return await server.Channel.PrepareResponseAsync(response, ct); - }), - Handle(AuthorizationServerDescription.TokenEndpoint).By( - async (req, ct) => { - var server = new AuthorizationServer(AuthorizationServerMock); - return await server.HandleTokenRequestAsync(req, ct); - })); - await coordinator.RunAsync(); + } + var incoming = + await + client.Channel.ReadFromRequestAsync( + new HttpRequestMessage(HttpMethod.Get, authRequestResponse), CancellationToken.None); + var result = await client.ProcessUserAuthorizationAsync(authState, incoming, CancellationToken.None); + Assert.That(result.AccessToken, Is.Not.Null.And.Not.Empty); + Assert.That(result.RefreshToken, Is.Not.Null.And.Not.Empty); + } } [Test] public async Task ImplicitGrant() { var coordinatorClient = new UserAgentClient(AuthorizationServerDescription); - var coordinator = new CoordinatorBase( - async (hostFactories, ct) => { - var client = new UserAgentClient(AuthorizationServerDescription); - var authState = new AuthorizationState(TestScopes) { - Callback = ClientCallback, - }; - var request = client.PrepareRequestUserAuthorization(authState, implicitResponseType: true); - Assert.That(request.ResponseType, Is.EqualTo(EndUserAuthorizationResponseType.AccessToken)); - var authRequestRedirect = await client.Channel.PrepareResponseAsync(request, ct); - Uri authRequestResponse; - using (var httpClient = hostFactories.CreateHttpClient()) { - using (var httpResponse = await httpClient.GetAsync(authRequestRedirect.Headers.Location, ct)) { - authRequestResponse = httpResponse.Headers.Location; - } + coordinatorClient.ClientCredentialApplicator = null; // implicit grant clients don't need a secret. + Handle(AuthorizationServerDescription.AuthorizationEndpoint).By( + async (req, ct) => { + var server = new AuthorizationServer(AuthorizationServerMock); + var request = await server.ReadAuthorizationRequestAsync(req, ct); + Assert.That(request, Is.Not.Null); + IAccessTokenRequest accessTokenRequest = (EndUserAuthorizationImplicitRequest)request; + Assert.That(accessTokenRequest.ClientAuthenticated, Is.False); + var response = server.PrepareApproveAuthorizationRequest(request, ResourceOwnerUsername); + return await server.Channel.PrepareResponseAsync(response, ct); + }); + { + var client = new UserAgentClient(AuthorizationServerDescription); + var authState = new AuthorizationState(TestScopes) { Callback = ClientCallback, }; + var request = client.PrepareRequestUserAuthorization(authState, implicitResponseType: true); + Assert.That(request.ResponseType, Is.EqualTo(EndUserAuthorizationResponseType.AccessToken)); + var authRequestRedirect = await client.Channel.PrepareResponseAsync(request); + Uri authRequestResponse; + using (var httpClient = this.HostFactories.CreateHttpClient()) { + using (var httpResponse = await httpClient.GetAsync(authRequestRedirect.Headers.Location)) { + authRequestResponse = httpResponse.Headers.Location; } + } - var incoming = await client.Channel.ReadFromRequestAsync(new HttpRequestMessage(HttpMethod.Get, authRequestResponse), ct); - var result = await client.ProcessUserAuthorizationAsync(authState, incoming, ct); - Assert.That(result.AccessToken, Is.Not.Null.And.Not.Empty); - Assert.That(result.RefreshToken, Is.Null); - }, - Handle(AuthorizationServerDescription.AuthorizationEndpoint).By( - async (req, ct) => { - var server = new AuthorizationServer(AuthorizationServerMock); - var request = await server.ReadAuthorizationRequestAsync(req, ct); - Assert.That(request, Is.Not.Null); - IAccessTokenRequest accessTokenRequest = (EndUserAuthorizationImplicitRequest)request; - Assert.That(accessTokenRequest.ClientAuthenticated, Is.False); - var response = server.PrepareApproveAuthorizationRequest(request, ResourceOwnerUsername); - return await server.Channel.PrepareResponseAsync(response, ct); - })); - - coordinatorClient.ClientCredentialApplicator = null; // implicit grant clients don't need a secret. - await coordinator.RunAsync(); + var incoming = + await client.Channel.ReadFromRequestAsync(new HttpRequestMessage(HttpMethod.Get, authRequestResponse), CancellationToken.None); + var result = await client.ProcessUserAuthorizationAsync(authState, incoming, CancellationToken.None); + Assert.That(result.AccessToken, Is.Not.Null.And.Not.Empty); + Assert.That(result.RefreshToken, Is.Null); + } } } } diff --git a/src/DotNetOpenAuth.Test/OAuth2/WebServerClientAuthorizeTests.cs b/src/DotNetOpenAuth.Test/OAuth2/WebServerClientAuthorizeTests.cs index 7b5c32e..2befd35 100644 --- a/src/DotNetOpenAuth.Test/OAuth2/WebServerClientAuthorizeTests.cs +++ b/src/DotNetOpenAuth.Test/OAuth2/WebServerClientAuthorizeTests.cs @@ -11,6 +11,7 @@ namespace DotNetOpenAuth.Test.OAuth2 { using System.Net; using System.Net.Http; using System.Text; + using System.Threading; using System.Threading.Tasks; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OAuth2; @@ -23,37 +24,35 @@ namespace DotNetOpenAuth.Test.OAuth2 { public class WebServerClientAuthorizeTests : OAuth2TestBase { [Test] public async Task AuthorizationCodeGrant() { - var coordinator = new CoordinatorBase( - async (hostFactories, ct) => { - var client = new WebServerClient(AuthorizationServerDescription); - var authState = new AuthorizationState(TestScopes) { - Callback = ClientCallback, - }; - var authRequestRedirect = await client.PrepareRequestUserAuthorizationAsync(authState, ct); - Uri authRequestResponse; - using (var httpClient = hostFactories.CreateHttpClient()) { - using (var httpResponse = await httpClient.GetAsync(authRequestRedirect.Headers.Location, ct)) { - authRequestResponse = httpResponse.Headers.Location; - } - } + Handle(AuthorizationServerDescription.AuthorizationEndpoint).By( + async (req, ct) => { + var server = new AuthorizationServer(AuthorizationServerMock); + var request = await server.ReadAuthorizationRequestAsync(req, ct); + Assert.That(request, Is.Not.Null); + var response = server.PrepareApproveAuthorizationRequest(request, ResourceOwnerUsername); + return await server.Channel.PrepareResponseAsync(response, ct); + }); + Handle(AuthorizationServerDescription.TokenEndpoint).By( + async (req, ct) => { + var server = new AuthorizationServer(AuthorizationServerMock); + return await server.HandleTokenRequestAsync(req, ct); + }); + + var client = new WebServerClient(AuthorizationServerDescription); + var authState = new AuthorizationState(TestScopes) { + Callback = ClientCallback, + }; + var authRequestRedirect = await client.PrepareRequestUserAuthorizationAsync(authState); + Uri authRequestResponse; + using (var httpClient = this.HostFactories.CreateHttpClient()) { + using (var httpResponse = await httpClient.GetAsync(authRequestRedirect.Headers.Location)) { + authRequestResponse = httpResponse.Headers.Location; + } + } - var result = await client.ProcessUserAuthorizationAsync(new HttpRequestMessage(HttpMethod.Get, authRequestResponse), ct); - Assert.That(result.AccessToken, Is.Not.Null.And.Not.Empty); - Assert.That(result.RefreshToken, Is.Not.Null.And.Not.Empty); - }, - Handle(AuthorizationServerDescription.AuthorizationEndpoint).By( - async (req, ct) => { - var server = new AuthorizationServer(AuthorizationServerMock); - var request = await server.ReadAuthorizationRequestAsync(req, ct); - Assert.That(request, Is.Not.Null); - var response = server.PrepareApproveAuthorizationRequest(request, ResourceOwnerUsername); - return await server.Channel.PrepareResponseAsync(response, ct); - }), - Handle(AuthorizationServerDescription.TokenEndpoint).By(async (req, ct) => { - var server = new AuthorizationServer(AuthorizationServerMock); - return await server.HandleTokenRequestAsync(req, ct); - })); - await coordinator.RunAsync(); + var result = await client.ProcessUserAuthorizationAsync(new HttpRequestMessage(HttpMethod.Get, authRequestResponse), CancellationToken.None); + Assert.That(result.AccessToken, Is.Not.Null.And.Not.Empty); + Assert.That(result.RefreshToken, Is.Not.Null.And.Not.Empty); } [Theory] @@ -69,22 +68,19 @@ namespace DotNetOpenAuth.Test.OAuth2 { MessagingUtilities.AreEquivalent(d.Scope, TestScopes)))).Returns(true); } - var coordinator = new CoordinatorBase( - async (hostFactories, ct) => { - var client = new WebServerClient(AuthorizationServerDescription); - if (anonymousClient) { - client.ClientIdentifier = null; - } + Handle(AuthorizationServerDescription.TokenEndpoint).By(async (req, ct) => { + var server = new AuthorizationServer(authHostMock.Object); + return await server.HandleTokenRequestAsync(req, ct); + }); - var authState = await client.ExchangeUserCredentialForTokenAsync(ResourceOwnerUsername, ResourceOwnerPassword, TestScopes, ct); - Assert.That(authState.AccessToken, Is.Not.Null.And.Not.Empty); - Assert.That(authState.RefreshToken, Is.Not.Null.And.Not.Empty); - }, - Handle(AuthorizationServerDescription.TokenEndpoint).By(async (req, ct) => { - var server = new AuthorizationServer(authHostMock.Object); - return await server.HandleTokenRequestAsync(req, ct); - })); - await coordinator.RunAsync(); + var client = new WebServerClient(AuthorizationServerDescription); + if (anonymousClient) { + client.ClientIdentifier = null; + } + + var authState = await client.ExchangeUserCredentialForTokenAsync(ResourceOwnerUsername, ResourceOwnerPassword, TestScopes); + Assert.That(authState.AccessToken, Is.Not.Null.And.Not.Empty); + Assert.That(authState.RefreshToken, Is.Not.Null.And.Not.Empty); } [Test] @@ -96,18 +92,15 @@ namespace DotNetOpenAuth.Test.OAuth2 { authServer.Setup( a => a.CheckAuthorizeClientCredentialsGrant(It.Is<IAccessTokenRequest>(d => d.ClientIdentifier == ClientId && MessagingUtilities.AreEquivalent(d.Scope, TestScopes)))) .Returns<IAccessTokenRequest>(req => new AutomatedAuthorizationCheckResponse(req, true)); - var coordinator = new CoordinatorBase( - async (hostFactories, ct) => { - var client = new WebServerClient(AuthorizationServerDescription); - var authState = await client.GetClientAccessTokenAsync(TestScopes, ct); - Assert.That(authState.AccessToken, Is.Not.Null.And.Not.Empty); - Assert.That(authState.RefreshToken, Is.Null); - }, - Handle(AuthorizationServerDescription.TokenEndpoint).By(async (req, ct) => { + Handle(AuthorizationServerDescription.TokenEndpoint).By( + async (req, ct) => { var server = new AuthorizationServer(authServer.Object); return await server.HandleTokenRequestAsync(req, ct); - })); - await coordinator.RunAsync(); + }); + var client = new WebServerClient(AuthorizationServerDescription); + var authState = await client.GetClientAccessTokenAsync(TestScopes); + Assert.That(authState.AccessToken, Is.Not.Null.And.Not.Empty); + Assert.That(authState.RefreshToken, Is.Null); } [Test] @@ -124,17 +117,15 @@ namespace DotNetOpenAuth.Test.OAuth2 { response.ApprovedScope.ResetContents(approvedScopes); return response; }); - var coordinator = new CoordinatorBase( - async (hostFactories, ct) => { - var client = new WebServerClient(AuthorizationServerDescription); - var authState = await client.GetClientAccessTokenAsync(TestScopes, ct); - Assert.That(authState.Scope, Is.EquivalentTo(approvedScopes)); - }, - Handle(AuthorizationServerDescription.TokenEndpoint).By(async (req, ct) => { + Handle(AuthorizationServerDescription.TokenEndpoint).By( + async (req, ct) => { var server = new AuthorizationServer(authServer.Object); return await server.HandleTokenRequestAsync(req, ct); - })); - await coordinator.RunAsync(); + }); + + var client = new WebServerClient(AuthorizationServerDescription); + var authState = await client.GetClientAccessTokenAsync(TestScopes); + Assert.That(authState.Scope, Is.EquivalentTo(approvedScopes)); } [Test] diff --git a/src/DotNetOpenAuth.Test/OpenId/AssociationHandshakeTests.cs b/src/DotNetOpenAuth.Test/OpenId/AssociationHandshakeTests.cs index 99fa56c..eb60866 100644 --- a/src/DotNetOpenAuth.Test/OpenId/AssociationHandshakeTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/AssociationHandshakeTests.cs @@ -43,22 +43,20 @@ namespace DotNetOpenAuth.Test.OpenId { [Test] public async Task AssociateDiffieHellmanOverHttps() { Protocol protocol = Protocol.V20; - var coordinator = new CoordinatorBase( - RelyingPartyDriver(async (rp, ct) => { - // We have to formulate the associate request manually, - // since the DNOI RP won't voluntarily use DH on HTTPS. - var request = new AssociateDiffieHellmanRequest(protocol.Version, OPUri) { - AssociationType = protocol.Args.SignatureAlgorithm.HMAC_SHA256, - SessionType = protocol.Args.SessionType.DH_SHA256 - }; - request.InitializeRequest(); - var response = await rp.Channel.RequestAsync<AssociateSuccessfulResponse>(request, CancellationToken.None); - Assert.IsNotNull(response); - Assert.AreEqual(request.AssociationType, response.AssociationType); - Assert.AreEqual(request.SessionType, response.SessionType); - }), - AutoProvider); - await coordinator.RunAsync(); + this.RegisterAutoProvider(); + var rp = this.CreateRelyingParty(); + + // We have to formulate the associate request manually, + // since the DNOI RP won't voluntarily use DH on HTTPS. + var request = new AssociateDiffieHellmanRequest(protocol.Version, OPUri) { + AssociationType = protocol.Args.SignatureAlgorithm.HMAC_SHA256, + SessionType = protocol.Args.SessionType.DH_SHA256 + }; + request.InitializeRequest(); + var response = await rp.Channel.RequestAsync<AssociateSuccessfulResponse>(request, CancellationToken.None); + Assert.IsNotNull(response); + Assert.AreEqual(request.AssociationType, response.AssociationType); + Assert.AreEqual(request.SessionType, response.SessionType); } /// <summary> @@ -73,28 +71,23 @@ namespace DotNetOpenAuth.Test.OpenId { // and to more carefully observe the Provider-side of things to make sure that both // the OP and RP are behaving as expected. int providerAttemptCount = 0; - var coordinator = new CoordinatorBase( - RelyingPartyDriver(async (rp, ct) => { - var opDescription = new ProviderEndpointDescription(OPUri, protocol.Version); - Association association = await rp.AssociationManager.GetOrCreateAssociationAsync(opDescription, ct); - Assert.IsNotNull(association, "Association failed to be created."); - Assert.AreEqual(protocol.Args.SignatureAlgorithm.HMAC_SHA1, association.GetAssociationType(protocol)); - }), - HandleProvider(async (op, request, ct) => { + HandleProvider( + async (op, request) => { op.SecuritySettings.MaximumHashBitLength = 160; // Force OP to reject HMAC-SHA256 switch (++providerAttemptCount) { case 1: // Receive initial request for an HMAC-SHA256 association. - var req = (AutoResponsiveRequest)await op.GetRequestAsync(request, ct); + var req = (AutoResponsiveRequest)await op.GetRequestAsync(request); var associateRequest = (AssociateRequest)req.RequestMessage; Assert.That(associateRequest, Is.Not.Null); Assert.AreEqual(protocol.Args.SignatureAlgorithm.HMAC_SHA256, associateRequest.AssociationType); // Ensure that the response is a suggestion that the RP try again with HMAC-SHA1 - var renegotiateResponse = (AssociateUnsuccessfulResponse)await req.GetResponseMessageAsyncTestHook(ct); + var renegotiateResponse = + (AssociateUnsuccessfulResponse)await req.GetResponseMessageAsyncTestHook(CancellationToken.None); Assert.AreEqual(protocol.Args.SignatureAlgorithm.HMAC_SHA1, renegotiateResponse.AssociationType); - return await op.PrepareResponseAsync(req, ct); + return await op.PrepareResponseAsync(req); case 2: // Receive second attempt request for an HMAC-SHA1 association. @@ -103,15 +96,20 @@ namespace DotNetOpenAuth.Test.OpenId { Assert.AreEqual(protocol.Args.SignatureAlgorithm.HMAC_SHA1, associateRequest.AssociationType); // Ensure that the response is a success response. - var successResponse = (AssociateSuccessfulResponse)await req.GetResponseMessageAsyncTestHook(ct); + var successResponse = + (AssociateSuccessfulResponse)await req.GetResponseMessageAsyncTestHook(CancellationToken.None); Assert.AreEqual(protocol.Args.SignatureAlgorithm.HMAC_SHA1, successResponse.AssociationType); - return await op.PrepareResponseAsync(req, ct); + return await op.PrepareResponseAsync(req); default: throw Assumes.NotReachable(); } - })); - await coordinator.RunAsync(); + }); + var rp = this.CreateRelyingParty(); + var opDescription = new ProviderEndpointDescription(OPUri, protocol.Version); + Association association = await rp.AssociationManager.GetOrCreateAssociationAsync(opDescription, CancellationToken.None); + Assert.IsNotNull(association, "Association failed to be created."); + Assert.AreEqual(protocol.Args.SignatureAlgorithm.HMAC_SHA1, association.GetAssociationType(protocol)); } /// <summary> @@ -123,18 +121,16 @@ namespace DotNetOpenAuth.Test.OpenId { [Test] public async Task OPRejectsHttpNoEncryptionAssociateRequests() { Protocol protocol = Protocol.V20; - var coordinator = new CoordinatorBase( - RelyingPartyDriver(async (rp, ct) => { - // We have to formulate the associate request manually, - // since the DNOA RP won't voluntarily suggest no encryption at all. - var request = new AssociateUnencryptedRequestNoSslCheck(protocol.Version, OPUri); - request.AssociationType = protocol.Args.SignatureAlgorithm.HMAC_SHA256; - request.SessionType = protocol.Args.SessionType.NoEncryption; - var response = await rp.Channel.RequestAsync<DirectErrorResponse>(request, ct); - Assert.IsNotNull(response); - }), - AutoProvider); - await coordinator.RunAsync(); + this.RegisterAutoProvider(); + var rp = this.CreateRelyingParty(); + + // We have to formulate the associate request manually, + // since the DNOA RP won't voluntarily suggest no encryption at all. + var request = new AssociateUnencryptedRequestNoSslCheck(protocol.Version, OPUri); + request.AssociationType = protocol.Args.SignatureAlgorithm.HMAC_SHA256; + request.SessionType = protocol.Args.SessionType.NoEncryption; + var response = await rp.Channel.RequestAsync<DirectErrorResponse>(request, CancellationToken.None); + Assert.IsNotNull(response); } /// <summary> @@ -144,21 +140,19 @@ namespace DotNetOpenAuth.Test.OpenId { [Test] public async Task OPRejectsMismatchingAssociationAndSessionTypes() { Protocol protocol = Protocol.V20; - var coordinator = new CoordinatorBase( - RelyingPartyDriver(async (rp, ct) => { - // We have to formulate the associate request manually, - // since the DNOI RP won't voluntarily mismatch the association and session types. - AssociateDiffieHellmanRequest request = new AssociateDiffieHellmanRequest(protocol.Version, new Uri("https://Provider")); - request.AssociationType = protocol.Args.SignatureAlgorithm.HMAC_SHA256; - request.SessionType = protocol.Args.SessionType.DH_SHA1; - request.InitializeRequest(); - var response = await rp.Channel.RequestAsync<AssociateUnsuccessfulResponse>(request, ct); - Assert.IsNotNull(response); - Assert.AreEqual(protocol.Args.SignatureAlgorithm.HMAC_SHA1, response.AssociationType); - Assert.AreEqual(protocol.Args.SessionType.DH_SHA1, response.SessionType); - }), - AutoProvider); - await coordinator.RunAsync(); + this.RegisterAutoProvider(); + var rp = this.CreateRelyingParty(); + + // We have to formulate the associate request manually, + // since the DNOI RP won't voluntarily mismatch the association and session types. + var request = new AssociateDiffieHellmanRequest(protocol.Version, new Uri("https://Provider")); + request.AssociationType = protocol.Args.SignatureAlgorithm.HMAC_SHA256; + request.SessionType = protocol.Args.SessionType.DH_SHA1; + request.InitializeRequest(); + var response = await rp.Channel.RequestAsync<AssociateUnsuccessfulResponse>(request, CancellationToken.None); + Assert.IsNotNull(response); + Assert.AreEqual(protocol.Args.SignatureAlgorithm.HMAC_SHA1, response.AssociationType); + Assert.AreEqual(protocol.Args.SessionType.DH_SHA1, response.SessionType); } /// <summary> @@ -167,22 +161,20 @@ namespace DotNetOpenAuth.Test.OpenId { [Test] public async Task RPRejectsUnrecognizedAssociationType() { Protocol protocol = Protocol.V20; - var coordinator = new CoordinatorBase( - RelyingPartyDriver(async (rp, ct) => { - var association = await rp.AssociationManager.GetOrCreateAssociationAsync(new ProviderEndpointDescription(OPUri, protocol.Version), ct); - Assert.IsNull(association, "The RP should quietly give up when the OP misbehaves."); - }), - HandleProvider(async (op, req, ct) => { + HandleProvider( + async (op, req) => { // Receive initial request. - var request = await op.Channel.ReadFromRequestAsync<AssociateRequest>(req, ct); + var request = await op.Channel.ReadFromRequestAsync<AssociateRequest>(req, CancellationToken.None); // Send a response that suggests a foreign association type. var renegotiateResponse = new AssociateUnsuccessfulResponse(request.Version, request); renegotiateResponse.AssociationType = "HMAC-UNKNOWN"; renegotiateResponse.SessionType = "DH-UNKNOWN"; - return await op.Channel.PrepareResponseAsync(renegotiateResponse, ct); - })); - await coordinator.RunAsync(); + return await op.Channel.PrepareResponseAsync(renegotiateResponse); + }); + var rp = this.CreateRelyingParty(); + var association = await rp.AssociationManager.GetOrCreateAssociationAsync(new ProviderEndpointDescription(OPUri, protocol.Version), CancellationToken.None); + Assert.IsNull(association, "The RP should quietly give up when the OP misbehaves."); } /// <summary> @@ -194,22 +186,21 @@ namespace DotNetOpenAuth.Test.OpenId { [Test] public async Task RPRejectsUnencryptedSuggestion() { Protocol protocol = Protocol.V20; - var coordinator = new CoordinatorBase( - RelyingPartyDriver(async (rp, ct) => { - var association = await rp.AssociationManager.GetOrCreateAssociationAsync(new ProviderEndpointDescription(OPUri, protocol.Version), ct); - Assert.IsNull(association, "The RP should quietly give up when the OP misbehaves."); - }), - HandleProvider(async (op, req, ct) => { + this.HandleProvider( + async (op, req) => { // Receive initial request. - var request = await op.Channel.ReadFromRequestAsync<AssociateRequest>(req, ct); + var request = await op.Channel.ReadFromRequestAsync<AssociateRequest>(req, CancellationToken.None); // Send a response that suggests a no encryption. var renegotiateResponse = new AssociateUnsuccessfulResponse(request.Version, request); renegotiateResponse.AssociationType = protocol.Args.SignatureAlgorithm.HMAC_SHA1; renegotiateResponse.SessionType = protocol.Args.SessionType.NoEncryption; - return await op.Channel.PrepareResponseAsync(renegotiateResponse, ct); - })); - await coordinator.RunAsync(); + return await op.Channel.PrepareResponseAsync(renegotiateResponse); + }); + + var rp = this.CreateRelyingParty(); + var association = await rp.AssociationManager.GetOrCreateAssociationAsync(new ProviderEndpointDescription(OPUri, protocol.Version), CancellationToken.None); + Assert.IsNull(association, "The RP should quietly give up when the OP misbehaves."); } /// <summary> @@ -219,22 +210,20 @@ namespace DotNetOpenAuth.Test.OpenId { [Test] public async Task RPRejectsMismatchingAssociationAndSessionBitLengths() { Protocol protocol = Protocol.V20; - var coordinator = new CoordinatorBase( - RelyingPartyDriver(async (rp, ct) => { - var association = await rp.AssociationManager.GetOrCreateAssociationAsync(new ProviderEndpointDescription(OPUri, protocol.Version), ct); - Assert.IsNull(association, "The RP should quietly give up when the OP misbehaves."); - }), - HandleProvider(async (op, req, ct) => { + this.HandleProvider( + async (op, req) => { // Receive initial request. - var request = await op.Channel.ReadFromRequestAsync<AssociateRequest>(req, ct); + var request = await op.Channel.ReadFromRequestAsync<AssociateRequest>(req, CancellationToken.None); // Send a mismatched response AssociateUnsuccessfulResponse renegotiateResponse = new AssociateUnsuccessfulResponse(request.Version, request); renegotiateResponse.AssociationType = protocol.Args.SignatureAlgorithm.HMAC_SHA1; renegotiateResponse.SessionType = protocol.Args.SessionType.DH_SHA256; - return await op.Channel.PrepareResponseAsync(renegotiateResponse, ct); - })); - await coordinator.RunAsync(); + return await op.Channel.PrepareResponseAsync(renegotiateResponse); + }); + var rp = this.CreateRelyingParty(); + var association = await rp.AssociationManager.GetOrCreateAssociationAsync(new ProviderEndpointDescription(OPUri, protocol.Version), CancellationToken.None); + Assert.IsNull(association, "The RP should quietly give up when the OP misbehaves."); } /// <summary> @@ -245,37 +234,36 @@ namespace DotNetOpenAuth.Test.OpenId { public async Task RPOnlyRenegotiatesOnce() { Protocol protocol = Protocol.V20; int opStep = 0; - await RunAsync( - RelyingPartyDriver(async (rp, ct) => { - var association = await rp.AssociationManager.GetOrCreateAssociationAsync(new ProviderEndpointDescription(OPUri, protocol.Version), ct); - Assert.IsNull(association, "The RP should quietly give up when the OP misbehaves."); - }), - HandleProvider(async (op, req, ct) => { + HandleProvider( + async (op, req) => { switch (++opStep) { case 1: // Receive initial request. - var request = await op.Channel.ReadFromRequestAsync<AssociateRequest>(req, ct); + var request = await op.Channel.ReadFromRequestAsync<AssociateRequest>(req, CancellationToken.None); // Send a renegotiate response var renegotiateResponse = new AssociateUnsuccessfulResponse(request.Version, request); renegotiateResponse.AssociationType = protocol.Args.SignatureAlgorithm.HMAC_SHA1; renegotiateResponse.SessionType = protocol.Args.SessionType.DH_SHA1; - return await op.Channel.PrepareResponseAsync(renegotiateResponse, ct); + return await op.Channel.PrepareResponseAsync(renegotiateResponse, CancellationToken.None); case 2: // Receive second-try - request = await op.Channel.ReadFromRequestAsync<AssociateRequest>(req, ct); + request = await op.Channel.ReadFromRequestAsync<AssociateRequest>(req, CancellationToken.None); // Send ANOTHER renegotiate response, at which point the DNOI RP should give up. renegotiateResponse = new AssociateUnsuccessfulResponse(request.Version, request); renegotiateResponse.AssociationType = protocol.Args.SignatureAlgorithm.HMAC_SHA256; renegotiateResponse.SessionType = protocol.Args.SessionType.DH_SHA256; - return await op.Channel.PrepareResponseAsync(renegotiateResponse, ct); + return await op.Channel.PrepareResponseAsync(renegotiateResponse, CancellationToken.None); default: throw Assumes.NotReachable(); } - })); + }); + var rp = this.CreateRelyingParty(); + var association = await rp.AssociationManager.GetOrCreateAssociationAsync(new ProviderEndpointDescription(OPUri, protocol.Version), CancellationToken.None); + Assert.IsNull(association, "The RP should quietly give up when the OP misbehaves."); } /// <summary> @@ -284,17 +272,15 @@ namespace DotNetOpenAuth.Test.OpenId { [Test] public async Task AssociateRenegotiateLimitedByRPSecuritySettings() { Protocol protocol = Protocol.V20; - var coordinator = new CoordinatorBase( - RelyingPartyDriver(async (rp, ct) => { - rp.SecuritySettings.MinimumHashBitLength = 256; - var association = await rp.AssociationManager.GetOrCreateAssociationAsync(new ProviderEndpointDescription(OPUri, protocol.Version), ct); - Assert.IsNull(association, "No association should have been created when RP and OP could not agree on association strength."); - }), - HandleProvider(async (op, req, ct) => { + HandleProvider( + async (op, req) => { op.SecuritySettings.MaximumHashBitLength = 160; - return await AutoProviderActionAsync(op, req, ct); - })); - await coordinator.RunAsync(); + return await AutoProviderActionAsync(op, req, CancellationToken.None); + }); + var rp = this.CreateRelyingParty(); + rp.SecuritySettings.MinimumHashBitLength = 256; + var association = await rp.AssociationManager.GetOrCreateAssociationAsync(new ProviderEndpointDescription(OPUri, protocol.Version), CancellationToken.None); + Assert.IsNull(association, "No association should have been created when RP and OP could not agree on association strength."); } /// <summary> @@ -345,19 +331,13 @@ namespace DotNetOpenAuth.Test.OpenId { var provider = new OpenIdProvider(new StandardProviderApplicationStore(), this.HostFactories) { SecuritySettings = this.ProviderSecuritySettings }; - var coordinator = new CoordinatorBase( - async (hostFactories, ct) => { - relyingParty.SecuritySettings = this.RelyingPartySecuritySettings; - rpAssociation = await relyingParty.AssociationManager.GetOrCreateAssociationAsync(opDescription, ct); - }, - Handle(opDescription.Uri).By(async (request, ct) => { + Handle(opDescription.Uri).By( + async (request, ct) => { IRequest req = await provider.GetRequestAsync(request, ct); Assert.IsNotNull(req, "Expected incoming request but did not receive it."); Assert.IsTrue(req.IsResponseReady); return await provider.PrepareResponseAsync(req, ct); - })); - this.HostFactories.Handlers.AddRange(coordinator.HostFactories.Handlers); - coordinator.HostFactories = this.HostFactories; + }); relyingParty.Channel.IncomingMessageFilter = message => { Assert.AreSame(opDescription.Version, message.Version, "The message was recognized as version {0} but was expected to be {1}.", message.Version, Protocol.Lookup(opDescription.Version).ProtocolVersion); var associateSuccess = message as AssociateSuccessfulResponse; @@ -372,7 +352,9 @@ namespace DotNetOpenAuth.Test.OpenId { relyingParty.Channel.OutgoingMessageFilter = message => { Assert.AreEqual(opDescription.Version, message.Version, "The message was for version {0} but was expected to be for {1}.", message.Version, opDescription.Version); }; - await coordinator.RunAsync(); + + relyingParty.SecuritySettings = this.RelyingPartySecuritySettings; + rpAssociation = await relyingParty.AssociationManager.GetOrCreateAssociationAsync(opDescription, CancellationToken.None); if (expectSuccess) { Assert.IsNotNull(rpAssociation); diff --git a/src/DotNetOpenAuth.Test/OpenId/AuthenticationTests.cs b/src/DotNetOpenAuth.Test/OpenId/AuthenticationTests.cs index 93f74e6..01a1b7f 100644 --- a/src/DotNetOpenAuth.Test/OpenId/AuthenticationTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/AuthenticationTests.cs @@ -7,6 +7,7 @@ namespace DotNetOpenAuth.Test.OpenId { using System; using System.Net.Http; + using System.Threading; using System.Threading.Tasks; using DotNetOpenAuth.Messaging; @@ -65,60 +66,61 @@ namespace DotNetOpenAuth.Test.OpenId { [Test] public async Task UnsolicitedAssertion() { var opStore = new StandardProviderApplicationStore(); - var coordinator = new CoordinatorBase( - async (hostFactories, ct) => { - var op = new OpenIdProvider(opStore); - Identifier id = GetMockIdentifier(ProtocolVersion.V20); - var assertion = await op.PrepareUnsolicitedAssertionAsync(OPUri, RPRealmUri, id, OPLocalIdentifiers[0], ct); - - using (var httpClient = hostFactories.CreateHttpClient()) { - using (var response = await httpClient.GetAsync(assertion.Headers.Location)) { - response.EnsureSuccessStatusCode(); - } - } - }, - Handle(RPRealmUri).By(async (hostFactories, req, ct) => { - var rp = new OpenIdRelyingParty(new StandardRelyingPartyApplicationStore(), hostFactories); + Handle(RPRealmUri).By( + async req => { + var rp = new OpenIdRelyingParty(new StandardRelyingPartyApplicationStore(), this.HostFactories); IAuthenticationResponse response = await rp.GetResponseAsync(); Assert.AreEqual(AuthenticationStatus.Authenticated, response.Status); return new HttpResponseMessage(); - }), - Handle(OPUri).By( - async (req, ct) => { - var op = new OpenIdProvider(opStore); - return await this.AutoProviderActionAsync(op, req, ct); - }), - MockHttpRequest.RegisterMockRPDiscovery(ssl: false)); - await coordinator.RunAsync(); + }); + Handle(OPUri).By( + async (req, ct) => { + var op = new OpenIdProvider(opStore); + return await this.AutoProviderActionAsync(op, req, ct); + }); + this.RegisterMockRPDiscovery(ssl: false); + + { + var op = new OpenIdProvider(opStore); + Identifier id = GetMockIdentifier(ProtocolVersion.V20); + var assertion = await op.PrepareUnsolicitedAssertionAsync(OPUri, RPRealmUri, id, OPLocalIdentifiers[0]); + + using (var httpClient = this.HostFactories.CreateHttpClient()) { + using (var response = await httpClient.GetAsync(assertion.Headers.Location)) { + response.EnsureSuccessStatusCode(); + } + } + } } [Test] public async Task UnsolicitedAssertionRejected() { var opStore = new StandardProviderApplicationStore(); - var coordinator = new CoordinatorBase( - async (hostFactories, ct) => { - var op = new OpenIdProvider(opStore); - Identifier id = GetMockIdentifier(ProtocolVersion.V20); - var assertion = await op.PrepareUnsolicitedAssertionAsync(OPUri, RPRealmUri, id, OPLocalIdentifiers[0], ct); - using (var httpClient = hostFactories.CreateHttpClient()) { - using (var response = await httpClient.GetAsync(assertion.Headers.Location, ct)) { - response.EnsureSuccessStatusCode(); - } - } - }, - Handle(RPRealmUri).By(async (hostFactories, req, ct) => { - var rp = new OpenIdRelyingParty(new StandardRelyingPartyApplicationStore(), hostFactories); + Handle(RPRealmUri).By( + async req => { + var rp = new OpenIdRelyingParty(new StandardRelyingPartyApplicationStore(), this.HostFactories); rp.SecuritySettings.RejectUnsolicitedAssertions = true; - IAuthenticationResponse response = await rp.GetResponseAsync(req, ct); + IAuthenticationResponse response = await rp.GetResponseAsync(req); Assert.AreEqual(AuthenticationStatus.Failed, response.Status); return new HttpResponseMessage(); - }), - Handle(OPUri).By(async (hostFactories, req, ct) => { + }); + Handle(OPUri).By( + async req => { var op = new OpenIdProvider(opStore); - return await this.AutoProviderActionAsync(op, req, ct); - }), - MockHttpRequest.RegisterMockRPDiscovery(false)); - await coordinator.RunAsync(); + return await this.AutoProviderActionAsync(op, req, CancellationToken.None); + }); + this.RegisterMockRPDiscovery(ssl: false); + + { + var op = new OpenIdProvider(opStore); + Identifier id = GetMockIdentifier(ProtocolVersion.V20); + var assertion = await op.PrepareUnsolicitedAssertionAsync(OPUri, RPRealmUri, id, OPLocalIdentifiers[0]); + using (var httpClient = this.HostFactories.CreateHttpClient()) { + using (var response = await httpClient.GetAsync(assertion.Headers.Location)) { + response.EnsureSuccessStatusCode(); + } + } + } } /// <summary> @@ -128,30 +130,31 @@ namespace DotNetOpenAuth.Test.OpenId { [Test] public async Task UnsolicitedDelegatingIdentifierRejection() { var opStore = new StandardProviderApplicationStore(); - var coordinator = new CoordinatorBase( - async (hostFactories, ct) => { - var op = new OpenIdProvider(opStore); - Identifier id = GetMockIdentifier(ProtocolVersion.V20, false, true); - var assertion = await op.PrepareUnsolicitedAssertionAsync(OPUri, RPRealmUri, id, OPLocalIdentifiers[0], ct); - using (var httpClient = hostFactories.CreateHttpClient()) { - using (var response = await httpClient.GetAsync(assertion.Headers.Location, ct)) { - response.EnsureSuccessStatusCode(); - } - } - }, - Handle(RPRealmUri).By(async (hostFactories, req, ct) => { - var rp = new OpenIdRelyingParty(new StandardRelyingPartyApplicationStore(), hostFactories); + Handle(RPRealmUri).By( + async req => { + var rp = this.CreateRelyingParty(); rp.SecuritySettings.RejectDelegatingIdentifiers = true; - IAuthenticationResponse response = await rp.GetResponseAsync(req, ct); + IAuthenticationResponse response = await rp.GetResponseAsync(req); Assert.AreEqual(AuthenticationStatus.Failed, response.Status); return new HttpResponseMessage(); - }), - Handle(OPUri).By(async (hostFactories, req, ct) => { - var op = new OpenIdProvider(opStore); - return await this.AutoProviderActionAsync(op, req, ct); - }), - MockHttpRequest.RegisterMockRPDiscovery(false)); - await coordinator.RunAsync(); + }); + Handle(OPUri).By( + async req => { + var op = new OpenIdProvider(opStore, this.HostFactories); + return await this.AutoProviderActionAsync(op, req, CancellationToken.None); + }); + this.RegisterMockRPDiscovery(ssl: false); + + { + var op = new OpenIdProvider(opStore); + Identifier id = GetMockIdentifier(ProtocolVersion.V20, false, true); + var assertion = await op.PrepareUnsolicitedAssertionAsync(OPUri, RPRealmUri, id, OPLocalIdentifiers[0]); + using (var httpClient = this.HostFactories.CreateHttpClient()) { + using (var response = await httpClient.GetAsync(assertion.Headers.Location)) { + response.EnsureSuccessStatusCode(); + } + } + } } private async Task ParameterizedAuthenticationTestAsync(bool sharedAssociation, bool positive, bool tamper) { @@ -178,83 +181,18 @@ namespace DotNetOpenAuth.Test.OpenId { var associationStore = new ProviderAssociationHandleEncoder(cryptoKeyStore); Association association = sharedAssociation ? HmacShaAssociationProvider.Create(protocol, protocol.Args.SignatureAlgorithm.Best, AssociationRelyingPartyType.Smart, associationStore, securitySettings) : null; int opStep = 0; - var coordinator = new CoordinatorBase( - RelyingPartyDriver(async (rp, ct) => { - if (statelessRP) { - rp = new OpenIdRelyingParty(null, rp.Channel.HostFactories); - } - - var request = new CheckIdRequest(protocol.Version, OPUri, immediate ? AuthenticationRequestMode.Immediate : AuthenticationRequestMode.Setup); - + HandleProvider( + async (op, req) => { if (association != null) { - StoreAssociation(rp, OPUri, association); - request.AssociationHandle = association.Handle; - } - - request.ClaimedIdentifier = "http://claimedid"; - request.LocalIdentifier = "http://localid"; - request.ReturnTo = RPUri; - request.Realm = RPUri; - var redirectRequest = await rp.Channel.PrepareResponseAsync(request, ct); - Uri redirectResponse; - using (var httpClient = rp.Channel.HostFactories.CreateHttpClient()) { - using (var response = await httpClient.GetAsync(redirectRequest.Headers.Location)) { - redirectResponse = response.Headers.Location; - } - } - - var assertionMessage = new HttpRequestMessage(HttpMethod.Get, redirectResponse.AbsoluteUri); - if (positive) { - if (tamper) { - try { - await rp.Channel.ReadFromRequestAsync<PositiveAssertionResponse>(assertionMessage, ct); - Assert.Fail("Expected exception {0} not thrown.", typeof(InvalidSignatureException).Name); - } catch (InvalidSignatureException) { - TestLogger.InfoFormat("Caught expected {0} exception after tampering with signed data.", typeof(InvalidSignatureException).Name); - } - } else { - var response = await rp.Channel.ReadFromRequestAsync<PositiveAssertionResponse>(assertionMessage, ct); - Assert.IsNotNull(response); - Assert.AreEqual(request.ClaimedIdentifier, response.ClaimedIdentifier); - Assert.AreEqual(request.LocalIdentifier, response.LocalIdentifier); - Assert.AreEqual(request.ReturnTo, response.ReturnTo); - - // Attempt to replay the message and verify that it fails. - // Because in various scenarios and protocol versions different components - // notice the replay, we can get one of two exceptions thrown. - // When the OP notices the replay we get a generic InvalidSignatureException. - // When the RP notices the replay we get a specific ReplayMessageException. - try { - // TODO: fix this. - ////CoordinatingChannel channel = (CoordinatingChannel)rp.Channel; - ////await channel.ReplayAsync(response); - Assert.Fail("Expected ProtocolException was not thrown."); - } catch (ProtocolException ex) { - Assert.IsTrue(ex is ReplayedMessageException || ex is InvalidSignatureException, "A {0} exception was thrown instead of the expected {1} or {2}.", ex.GetType(), typeof(ReplayedMessageException).Name, typeof(InvalidSignatureException).Name); - } - } - } else { - var response = await rp.Channel.ReadFromRequestAsync<NegativeAssertionResponse>(assertionMessage, ct); - Assert.IsNotNull(response); - if (immediate) { - // Only 1.1 was required to include user_setup_url - if (protocol.Version.Major < 2) { - Assert.IsNotNull(response.UserSetupUrl); - } - } else { - Assert.IsNull(response.UserSetupUrl); - } - } - }), - HandleProvider(async (op, req, ct) => { - if (association != null) { - var key = cryptoKeyStore.GetCurrentKey(ProviderAssociationHandleEncoder.AssociationHandleEncodingSecretBucket, TimeSpan.FromSeconds(1)); - op.CryptoKeyStore.StoreKey(ProviderAssociationHandleEncoder.AssociationHandleEncodingSecretBucket, key.Key, key.Value); + var key = cryptoKeyStore.GetCurrentKey( + ProviderAssociationHandleEncoder.AssociationHandleEncodingSecretBucket, TimeSpan.FromSeconds(1)); + op.CryptoKeyStore.StoreKey( + ProviderAssociationHandleEncoder.AssociationHandleEncodingSecretBucket, key.Key, key.Value); } switch (++opStep) { case 1: - var request = await op.Channel.ReadFromRequestAsync<CheckIdRequest>(req, ct); + var request = await op.Channel.ReadFromRequestAsync<CheckIdRequest>(req, CancellationToken.None); Assert.IsNotNull(request); IProtocolMessage response; if (positive) { @@ -263,13 +201,14 @@ namespace DotNetOpenAuth.Test.OpenId { response = new NegativeAssertionResponse(request.Version, request.ReturnTo, request.Mode); } - return await op.Channel.PrepareResponseAsync(response, ct); + return await op.Channel.PrepareResponseAsync(response); case 2: if (positive && (statelessRP || !sharedAssociation)) { - var checkauthRequest = await op.Channel.ReadFromRequestAsync<CheckAuthenticationRequest>(req, ct); + var checkauthRequest = + await op.Channel.ReadFromRequestAsync<CheckAuthenticationRequest>(req, CancellationToken.None); var checkauthResponse = new CheckAuthenticationResponse(checkauthRequest.Version, checkauthRequest); checkauthResponse.IsValid = checkauthRequest.IsValid; - return await op.Channel.PrepareResponseAsync(checkauthResponse, ct); + return await op.Channel.PrepareResponseAsync(checkauthResponse); } throw Assumes.NotReachable(); @@ -277,10 +216,11 @@ namespace DotNetOpenAuth.Test.OpenId { if (positive && (statelessRP || !sharedAssociation)) { if (!tamper) { // Respond to the replay attack. - var checkauthRequest = await op.Channel.ReadFromRequestAsync<CheckAuthenticationRequest>(req, ct); + var checkauthRequest = + await op.Channel.ReadFromRequestAsync<CheckAuthenticationRequest>(req, CancellationToken.None); var checkauthResponse = new CheckAuthenticationResponse(checkauthRequest.Version, checkauthRequest); checkauthResponse.IsValid = checkauthRequest.IsValid; - return await op.Channel.PrepareResponseAsync(checkauthResponse, ct); + return await op.Channel.PrepareResponseAsync(checkauthResponse); } } @@ -288,21 +228,93 @@ namespace DotNetOpenAuth.Test.OpenId { default: throw Assumes.NotReachable(); } - })); - if (tamper) { - // TODO: fix this. - ////coordinator.IncomingMessageFilter = message => { - //// var assertion = message as PositiveAssertionResponse; - //// if (assertion != null) { - //// // Alter the Local Identifier between the Provider and the Relying Party. - //// // If the signature binding element does its job, this should cause the RP - //// // to throw. - //// assertion.LocalIdentifier = "http://victim"; - //// } - ////}; - } + }); + + { + var rp = this.CreateRelyingParty(statelessRP); + if (tamper) { + rp.Channel.IncomingMessageFilter = message => { + var assertion = message as PositiveAssertionResponse; + if (assertion != null) { + // Alter the Local Identifier between the Provider and the Relying Party. + // If the signature binding element does its job, this should cause the RP + // to throw. + assertion.LocalIdentifier = "http://victim"; + } + }; + } + + var request = new CheckIdRequest( + protocol.Version, OPUri, immediate ? AuthenticationRequestMode.Immediate : AuthenticationRequestMode.Setup); - await coordinator.RunAsync(); + if (association != null) { + StoreAssociation(rp, OPUri, association); + request.AssociationHandle = association.Handle; + } + + request.ClaimedIdentifier = "http://claimedid"; + request.LocalIdentifier = "http://localid"; + request.ReturnTo = RPUri; + request.Realm = RPUri; + var redirectRequest = await rp.Channel.PrepareResponseAsync(request); + Uri redirectResponse; + using (var httpClient = rp.Channel.HostFactories.CreateHttpClient()) { + using (var response = await httpClient.GetAsync(redirectRequest.Headers.Location)) { + redirectResponse = response.Headers.Location; + } + } + + var assertionMessage = new HttpRequestMessage(HttpMethod.Get, redirectResponse.AbsoluteUri); + if (positive) { + if (tamper) { + try { + await rp.Channel.ReadFromRequestAsync<PositiveAssertionResponse>(assertionMessage, CancellationToken.None); + Assert.Fail("Expected exception {0} not thrown.", typeof(InvalidSignatureException).Name); + } catch (InvalidSignatureException) { + TestLogger.InfoFormat( + "Caught expected {0} exception after tampering with signed data.", typeof(InvalidSignatureException).Name); + } + } else { + var response = + await rp.Channel.ReadFromRequestAsync<PositiveAssertionResponse>(assertionMessage, CancellationToken.None); + Assert.IsNotNull(response); + Assert.AreEqual(request.ClaimedIdentifier, response.ClaimedIdentifier); + Assert.AreEqual(request.LocalIdentifier, response.LocalIdentifier); + Assert.AreEqual(request.ReturnTo, response.ReturnTo); + + // Attempt to replay the message and verify that it fails. + // Because in various scenarios and protocol versions different components + // notice the replay, we can get one of two exceptions thrown. + // When the OP notices the replay we get a generic InvalidSignatureException. + // When the RP notices the replay we get a specific ReplayMessageException. + try { + // TODO: fix this. + ////CoordinatingChannel channel = (CoordinatingChannel)rp.Channel; + ////await channel.ReplayAsync(response); + Assert.Fail("Expected ProtocolException was not thrown."); + } catch (ProtocolException ex) { + Assert.IsTrue( + ex is ReplayedMessageException || ex is InvalidSignatureException, + "A {0} exception was thrown instead of the expected {1} or {2}.", + ex.GetType(), + typeof(ReplayedMessageException).Name, + typeof(InvalidSignatureException).Name); + } + } + } else { + var response = + await rp.Channel.ReadFromRequestAsync<NegativeAssertionResponse>(assertionMessage, CancellationToken.None); + Assert.IsNotNull(response); + if (immediate) { + // Only 1.1 was required to include user_setup_url + if (protocol.Version.Major < 2) { + Assert.IsNotNull(response.UserSetupUrl); + } + } else { + Assert.IsNull(response.UserSetupUrl); + } + } + } } } } diff --git a/src/DotNetOpenAuth.Test/OpenId/ChannelElements/ExtensionsBindingElementTests.cs b/src/DotNetOpenAuth.Test/OpenId/ChannelElements/ExtensionsBindingElementTests.cs index 7a1add2..7dfae7b 100644 --- a/src/DotNetOpenAuth.Test/OpenId/ChannelElements/ExtensionsBindingElementTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/ChannelElements/ExtensionsBindingElementTests.cs @@ -46,7 +46,7 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { public async Task RoundTripFullStackTest() { IOpenIdMessageExtension request = new MockOpenIdExtension("requestPart", "requestData"); IOpenIdMessageExtension response = new MockOpenIdExtension("responsePart", "responseData"); - await ExtensionTestUtilities.RoundtripAsync( + await this.RoundtripAsync( Protocol.Default, new IOpenIdMessageExtension[] { request }, new IOpenIdMessageExtension[] { response }); @@ -123,37 +123,21 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { Protocol protocol = Protocol.Default; var opStore = new StandardProviderApplicationStore(); int rpStep = 0; - var coordinator = new CoordinatorBase( - async (hostFactories, ct) => { - var op = new OpenIdProvider(opStore); - RegisterMockExtension(op.Channel); - var redirectingResponse = await op.Channel.PrepareResponseAsync(CreateResponseWithExtensions(protocol)); - using (var httpClient = hostFactories.CreateHttpClient()) { - using (var response = await httpClient.GetAsync(redirectingResponse.Headers.Location)) { - response.EnsureSuccessStatusCode(); - } - } - op.SecuritySettings.SignOutgoingExtensions = false; - redirectingResponse = await op.Channel.PrepareResponseAsync(CreateResponseWithExtensions(protocol)); - using (var httpClient = hostFactories.CreateHttpClient()) { - using (var response = await httpClient.GetAsync(redirectingResponse.Headers.Location)) { - response.EnsureSuccessStatusCode(); - } - } - }, - Handle(RPRealmUri).By(async (hostFactories, req, ct) => { - var rp = new OpenIdRelyingParty(new StandardRelyingPartyApplicationStore(), hostFactories); + + Handle(RPRealmUri).By( + async req => { + var rp = new OpenIdRelyingParty(new StandardRelyingPartyApplicationStore(), this.HostFactories); RegisterMockExtension(rp.Channel); switch (++rpStep) { case 1: - var response = await rp.Channel.ReadFromRequestAsync<IndirectSignedResponse>(req, ct); + var response = await rp.Channel.ReadFromRequestAsync<IndirectSignedResponse>(req, CancellationToken.None); Assert.AreEqual(1, response.SignedExtensions.Count(), "Signed extension should have been received."); Assert.AreEqual(0, response.UnsignedExtensions.Count(), "No unsigned extension should be present."); break; case 2: - response = await rp.Channel.ReadFromRequestAsync<IndirectSignedResponse>(req, ct); + response = await rp.Channel.ReadFromRequestAsync<IndirectSignedResponse>(req, CancellationToken.None); Assert.AreEqual(0, response.SignedExtensions.Count(), "No signed extension should have been received."); Assert.AreEqual(1, response.UnsignedExtensions.Count(), "Unsigned extension should have been received."); break; @@ -163,12 +147,31 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { } return new HttpResponseMessage(); - }), - Handle(OPUri).By(async (hostFactories, req, ct) => { + }); + Handle(OPUri).By( + async req => { var op = new OpenIdProvider(opStore); - return await AutoProviderActionAsync(op, req, ct); - })); - await coordinator.RunAsync(); + return await AutoProviderActionAsync(op, req, CancellationToken.None); + }); + + { + var op = new OpenIdProvider(opStore); + RegisterMockExtension(op.Channel); + var redirectingResponse = await op.Channel.PrepareResponseAsync(CreateResponseWithExtensions(protocol)); + using (var httpClient = this.HostFactories.CreateHttpClient()) { + using (var response = await httpClient.GetAsync(redirectingResponse.Headers.Location)) { + response.EnsureSuccessStatusCode(); + } + } + + op.SecuritySettings.SignOutgoingExtensions = false; + redirectingResponse = await op.Channel.PrepareResponseAsync(CreateResponseWithExtensions(protocol)); + using (var httpClient = this.HostFactories.CreateHttpClient()) { + using (var response = await httpClient.GetAsync(redirectingResponse.Headers.Location)) { + response.EnsureSuccessStatusCode(); + } + } + } } /// <summary> @@ -183,7 +186,7 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { IOpenIdMessageExtension request1 = new MockOpenIdExtension("requestPart1", "requestData1"); IOpenIdMessageExtension request2 = new MockOpenIdExtension("requestPart2", "requestData2"); try { - await ExtensionTestUtilities.RoundtripAsync( + await this.RoundtripAsync( Protocol.Default, new IOpenIdMessageExtension[] { request1, request2 }, new IOpenIdMessageExtension[0]); diff --git a/src/DotNetOpenAuth.Test/OpenId/DiscoveryServices/UriDiscoveryServiceTests.cs b/src/DotNetOpenAuth.Test/OpenId/DiscoveryServices/UriDiscoveryServiceTests.cs index 06cb745..0dfb880 100644 --- a/src/DotNetOpenAuth.Test/OpenId/DiscoveryServices/UriDiscoveryServiceTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/DiscoveryServices/UriDiscoveryServiceTests.cs @@ -30,8 +30,8 @@ namespace DotNetOpenAuth.Test.OpenId.DiscoveryServices { // Add a couple of chained redirect pages that lead to the claimedId. Uri userSuppliedUri = new Uri("https://localhost/someSecurePage"); Uri insecureMidpointUri = new Uri("http://localhost/insecureStop"); - this.HostFactories.Handlers.Add(MockHttpRequest.RegisterMockRedirect(userSuppliedUri, insecureMidpointUri)); - this.HostFactories.Handlers.Add(MockHttpRequest.RegisterMockRedirect(insecureMidpointUri, new Uri(claimedId.ToString()))); + this.RegisterMockRedirect(userSuppliedUri, insecureMidpointUri); + this.RegisterMockRedirect(insecureMidpointUri, new Uri(claimedId.ToString())); // don't require secure SSL discovery for this test. Identifier userSuppliedIdentifier = new UriIdentifier(userSuppliedUri, false); @@ -47,8 +47,8 @@ namespace DotNetOpenAuth.Test.OpenId.DiscoveryServices { // All redirects should be secure. Uri userSuppliedUri = new Uri("https://localhost/someSecurePage"); Uri secureMidpointUri = new Uri("https://localhost/secureStop"); - this.HostFactories.Handlers.Add(MockHttpRequest.RegisterMockRedirect(userSuppliedUri, secureMidpointUri)); - this.HostFactories.Handlers.Add(MockHttpRequest.RegisterMockRedirect(secureMidpointUri, new Uri(claimedId.ToString()))); + this.RegisterMockRedirect(userSuppliedUri, secureMidpointUri); + this.RegisterMockRedirect(secureMidpointUri, new Uri(claimedId.ToString())); Identifier userSuppliedIdentifier = new UriIdentifier(userSuppliedUri, true); var discoveryResult = await this.DiscoverAsync(userSuppliedIdentifier); @@ -64,8 +64,8 @@ namespace DotNetOpenAuth.Test.OpenId.DiscoveryServices { // the ultimate endpoint is never found as a result of high security profile. Uri userSuppliedUri = new Uri("https://localhost/someSecurePage"); Uri insecureMidpointUri = new Uri("http://localhost/insecureStop"); - this.HostFactories.Handlers.Add(MockHttpRequest.RegisterMockRedirect(userSuppliedUri, insecureMidpointUri)); - this.HostFactories.Handlers.Add(MockHttpRequest.RegisterMockRedirect(insecureMidpointUri, new Uri(claimedId.ToString()))); + this.RegisterMockRedirect(userSuppliedUri, insecureMidpointUri); + this.RegisterMockRedirect(insecureMidpointUri, new Uri(claimedId.ToString())); Identifier userSuppliedIdentifier = new UriIdentifier(userSuppliedUri, true); await this.DiscoverAsync(userSuppliedIdentifier); @@ -77,7 +77,7 @@ namespace DotNetOpenAuth.Test.OpenId.DiscoveryServices { Uri secureClaimedUri = new Uri("https://localhost/secureId"); string html = string.Format("<html><head><meta http-equiv='X-XRDS-Location' content='{0}'/></head><body></body></html>", insecureXrdsSource); - this.HostFactories.Handlers.Add(MockHttpRequest.RegisterMockResponse(secureClaimedUri, "text/html", html)); + this.RegisterMockResponse(secureClaimedUri, "text/html", html); Identifier userSuppliedIdentifier = new UriIdentifier(secureClaimedUri, true); var discoveryResult = await this.DiscoverAsync(userSuppliedIdentifier); @@ -92,7 +92,7 @@ namespace DotNetOpenAuth.Test.OpenId.DiscoveryServices { WebHeaderCollection headers = new WebHeaderCollection { { "X-XRDS-Location", insecureXrdsSource } }; - this.HostFactories.Handlers.Add(MockHttpRequest.RegisterMockResponse(VanityUriSsl, VanityUriSsl, "text/html", headers, html)); + this.RegisterMockResponse(VanityUriSsl, VanityUriSsl, "text/html", headers, html); Identifier userSuppliedIdentifier = new UriIdentifier(VanityUriSsl, true); var discoveryResult = await this.DiscoverAsync(userSuppliedIdentifier); @@ -112,7 +112,7 @@ namespace DotNetOpenAuth.Test.OpenId.DiscoveryServices { HttpUtility.HtmlEncode(insecureXrdsSource), HttpUtility.HtmlEncode(OPUriSsl.AbsoluteUri), HttpUtility.HtmlEncode(OPLocalIdentifiersSsl[1].AbsoluteUri)); - this.HostFactories.Handlers.Add(Handle(VanityUriSsl).By(html, "text/html")); + this.Handle(VanityUriSsl).By(html, "text/html"); Identifier userSuppliedIdentifier = new UriIdentifier(VanityUriSsl, true); @@ -127,7 +127,7 @@ namespace DotNetOpenAuth.Test.OpenId.DiscoveryServices { var insecureEndpoint = GetServiceEndpoint(0, ProtocolVersion.V20, 10, false); var secureEndpoint = GetServiceEndpoint(1, ProtocolVersion.V20, 20, true); UriIdentifier secureClaimedId = new UriIdentifier(VanityUriSsl, true); - this.HostFactories.Handlers.Add(MockHttpRequest.RegisterMockXrdsResponse(secureClaimedId, new IdentifierDiscoveryResult[] { insecureEndpoint, secureEndpoint })); + this.RegisterMockXrdsResponse(secureClaimedId, new[] { insecureEndpoint, secureEndpoint }); var discoverResult = await this.DiscoverAsync(secureClaimedId); Assert.AreEqual(secureEndpoint.ProviderLocalIdentifier, discoverResult.Single().ProviderLocalIdentifier); } @@ -176,7 +176,10 @@ namespace DotNetOpenAuth.Test.OpenId.DiscoveryServices { [Test] public async Task XrdsDiscoveryFromHead() { - this.HostFactories.Handlers.Add(MockHttpRequest.RegisterMockResponse(new Uri("http://localhost/xrds1020.xml"), "application/xrds+xml", LoadEmbeddedFile("/Discovery/xrdsdiscovery/xrds1020.xml"))); + this.RegisterMockResponse( + new Uri("http://localhost/xrds1020.xml"), + "application/xrds+xml", + LoadEmbeddedFile("/Discovery/xrdsdiscovery/xrds1020.xml")); await this.DiscoverXrdsAsync("XrdsReferencedInHead.html", ProtocolVersion.V10, null, "http://a/b"); } @@ -184,7 +187,7 @@ namespace DotNetOpenAuth.Test.OpenId.DiscoveryServices { public async Task XrdsDiscoveryFromHttpHeader() { WebHeaderCollection headers = new WebHeaderCollection(); headers.Add("X-XRDS-Location", new Uri("http://localhost/xrds1020.xml").AbsoluteUri); - this.HostFactories.Handlers.Add(MockHttpRequest.RegisterMockResponse(new Uri("http://localhost/xrds1020.xml"), "application/xrds+xml", LoadEmbeddedFile("/Discovery/xrdsdiscovery/xrds1020.xml"))); + this.RegisterMockResponse(new Uri("http://localhost/xrds1020.xml"), "application/xrds+xml", LoadEmbeddedFile("/Discovery/xrdsdiscovery/xrds1020.xml")); await this.DiscoverXrdsAsync("XrdsReferencedInHttpHeader.html", ProtocolVersion.V10, null, "http://a/b", headers); } @@ -193,7 +196,7 @@ namespace DotNetOpenAuth.Test.OpenId.DiscoveryServices { /// </summary> [Test] public async Task HtmlDiscoveryProceedsIfXrdsIsEmpty() { - this.HostFactories.Handlers.Add(MockHttpRequest.RegisterMockResponse(new Uri("http://localhost/xrds-irrelevant.xml"), "application/xrds+xml", LoadEmbeddedFile("/Discovery/xrdsdiscovery/xrds-irrelevant.xml"))); + this.RegisterMockResponse(new Uri("http://localhost/xrds-irrelevant.xml"), "application/xrds+xml", LoadEmbeddedFile("/Discovery/xrdsdiscovery/xrds-irrelevant.xml")); await this.DiscoverHtmlAsync("html20provWithEmptyXrds", ProtocolVersion.V20, null, "http://a/b"); } @@ -210,7 +213,7 @@ namespace DotNetOpenAuth.Test.OpenId.DiscoveryServices { /// </summary> [Test] public async Task DualIdentifierOffByDefault() { - this.HostFactories.Handlers.Add(MockHttpRequest.RegisterMockResponse(VanityUri, "application/xrds+xml", LoadEmbeddedFile("/Discovery/xrdsdiscovery/xrds20dual.xml"))); + this.RegisterMockResponse(VanityUri, "application/xrds+xml", LoadEmbeddedFile("/Discovery/xrdsdiscovery/xrds20dual.xml")); var results = (await this.DiscoverAsync(VanityUri)).ToList(); Assert.AreEqual(1, results.Count(r => r.ClaimedIdentifier == r.Protocol.ClaimedIdentifierForOPIdentifier), "OP Identifier missing from discovery results."); Assert.AreEqual(1, results.Count, "Unexpected additional services discovered."); @@ -221,7 +224,7 @@ namespace DotNetOpenAuth.Test.OpenId.DiscoveryServices { /// </summary> [Test] public async Task DualIdentifier() { - this.HostFactories.Handlers.Add(MockHttpRequest.RegisterMockResponse(VanityUri, "application/xrds+xml", LoadEmbeddedFile("/Discovery/xrdsdiscovery/xrds20dual.xml"))); + this.RegisterMockResponse(VanityUri, "application/xrds+xml", LoadEmbeddedFile("/Discovery/xrdsdiscovery/xrds20dual.xml")); var rp = this.CreateRelyingParty(true); rp.SecuritySettings.AllowDualPurposeIdentifiers = true; var results = (await rp.DiscoverAsync(VanityUri, CancellationToken.None)).ToList(); @@ -252,7 +255,7 @@ namespace DotNetOpenAuth.Test.OpenId.DiscoveryServices { } else { throw new InvalidOperationException(); } - this.HostFactories.Handlers.Add(MockHttpRequest.RegisterMockResponse(new Uri(idToDiscover), claimedId, contentType, headers ?? new WebHeaderCollection(), LoadEmbeddedFile(url))); + this.RegisterMockResponse(new Uri(idToDiscover), claimedId, contentType, headers ?? new WebHeaderCollection(), LoadEmbeddedFile(url)); IdentifierDiscoveryResult expected = IdentifierDiscoveryResult.CreateForClaimedIdentifier( claimedId, @@ -297,7 +300,7 @@ namespace DotNetOpenAuth.Test.OpenId.DiscoveryServices { private async Task FailDiscoverAsync(string url) { UriIdentifier userSuppliedId = new Uri(new Uri("http://localhost"), url); - this.HostFactories.Handlers.Add(MockHttpRequest.RegisterMockResponse(new Uri(userSuppliedId), userSuppliedId, "text/html", LoadEmbeddedFile(url))); + this.RegisterMockResponse(new Uri(userSuppliedId), userSuppliedId, "text/html", LoadEmbeddedFile(url)); var discoveryResult = await this.DiscoverAsync(userSuppliedId); Assert.AreEqual(0, discoveryResult.Count()); // ... but that no endpoint info is discoverable diff --git a/src/DotNetOpenAuth.Test/OpenId/DiscoveryServices/XriDiscoveryProxyServiceTests.cs b/src/DotNetOpenAuth.Test/OpenId/DiscoveryServices/XriDiscoveryProxyServiceTests.cs index 03f9349..23ddbfe 100644 --- a/src/DotNetOpenAuth.Test/OpenId/DiscoveryServices/XriDiscoveryProxyServiceTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/DiscoveryServices/XriDiscoveryProxyServiceTests.cs @@ -54,7 +54,7 @@ namespace DotNetOpenAuth.Test.OpenId.DiscoveryServices { { "https://xri.net/=Arnott?_xrd_r=application/xrd%2Bxml;sep=false", xrds }, { "https://xri.net/=!9B72.7DD1.50A9.5CCD?_xrd_r=application/xrd%2Bxml;sep=false", xrds }, }; - this.HostFactories.Handlers.AddRange(MockHttpRequest.RegisterMockXrdsResponses(mocks)); + this.RegisterMockXrdsResponses(mocks); string expectedCanonicalId = "=!9B72.7DD1.50A9.5CCD"; IdentifierDiscoveryResult se = await this.VerifyCanonicalIdAsync("=Arnott", expectedCanonicalId); @@ -280,14 +280,13 @@ uEyb50RJ7DWmXctSC0b3eymZ2lSXxAWNOsNy </X509Data> </KeyInfo> </XRD>"; - this.HostFactories.Handlers.AddRange( - MockHttpRequest.RegisterMockXrdsResponses(new Dictionary<string, string> { - { "https://xri.net/@llli?_xrd_r=application/xrd%2Bxml;sep=false", llliResponse }, - { "https://xri.net/@llli*area?_xrd_r=application/xrd%2Bxml;sep=false", llliAreaResponse }, - { "https://xri.net/@llli*area*canada.unattached?_xrd_r=application/xrd%2Bxml;sep=false", llliAreaCanadaUnattachedResponse }, - { "https://xri.net/@llli*area*canada.unattached*ada?_xrd_r=application/xrd%2Bxml;sep=false", llliAreaCanadaUnattachedAdaResponse }, - { "https://xri.net/=Web?_xrd_r=application/xrd%2Bxml;sep=false", webResponse }, - })); + this.RegisterMockXrdsResponses(new Dictionary<string, string> { + { "https://xri.net/@llli?_xrd_r=application/xrd%2Bxml;sep=false", llliResponse }, + { "https://xri.net/@llli*area?_xrd_r=application/xrd%2Bxml;sep=false", llliAreaResponse }, + { "https://xri.net/@llli*area*canada.unattached?_xrd_r=application/xrd%2Bxml;sep=false", llliAreaCanadaUnattachedResponse }, + { "https://xri.net/@llli*area*canada.unattached*ada?_xrd_r=application/xrd%2Bxml;sep=false", llliAreaCanadaUnattachedAdaResponse }, + { "https://xri.net/=Web?_xrd_r=application/xrd%2Bxml;sep=false", webResponse }, + }); await this.VerifyCanonicalIdAsync("@llli", "@!72CD.A072.157E.A9C6"); await this.VerifyCanonicalIdAsync("@llli*area", "@!72CD.A072.157E.A9C6!0000.0000.3B9A.CA0C"); await this.VerifyCanonicalIdAsync("@llli*area*canada.unattached", "@!72CD.A072.157E.A9C6!0000.0000.3B9A.CA0C!0000.0000.3B9A.CA41"); @@ -297,8 +296,7 @@ uEyb50RJ7DWmXctSC0b3eymZ2lSXxAWNOsNy [Test] public async Task DiscoveryCommunityInameDelegateWithoutCanonicalID() { - this.HostFactories.Handlers.AddRange( - MockHttpRequest.RegisterMockXrdsResponses(new Dictionary<string, string> { + this.RegisterMockXrdsResponses(new Dictionary<string, string> { { "https://xri.net/=Web*andrew.arnott?_xrd_r=application/xrd%2Bxml;sep=false", @"<?xml version='1.0' encoding='UTF-8'?> <XRD xmlns='xri://$xrd*($v*2.0)'> <Query>*andrew.arnott</Query> @@ -376,7 +374,7 @@ uEyb50RJ7DWmXctSC0b3eymZ2lSXxAWNOsNy </Service> <ServedBy>OpenXRI</ServedBy> </XRD>" }, - })); + }); // Consistent with spec section 7.3.2.3, we do not permit // delegation on XRI discovery when there is no CanonicalID present. await this.VerifyCanonicalIdAsync("=Web*andrew.arnott", null); diff --git a/src/DotNetOpenAuth.Test/OpenId/Extensions/AttributeExchange/AttributeExchangeRoundtripTests.cs b/src/DotNetOpenAuth.Test/OpenId/Extensions/AttributeExchange/AttributeExchangeRoundtripTests.cs index 6f46daa..0d0d36c 100644 --- a/src/DotNetOpenAuth.Test/OpenId/Extensions/AttributeExchange/AttributeExchangeRoundtripTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/Extensions/AttributeExchange/AttributeExchangeRoundtripTests.cs @@ -28,7 +28,7 @@ namespace DotNetOpenAuth.Test.OpenId.Extensions { response.Attributes.Add(new AttributeValues(NicknameTypeUri, "Andrew")); response.Attributes.Add(new AttributeValues(EmailTypeUri, "a@a.com", "b@b.com")); - await ExtensionTestUtilities.RoundtripAsync(Protocol.Default, new[] { request }, new[] { response }); + await this.RoundtripAsync(Protocol.Default, new[] { request }, new[] { response }); } [Test] @@ -43,13 +43,13 @@ namespace DotNetOpenAuth.Test.OpenId.Extensions { var successResponse = new StoreResponse(); successResponse.Succeeded = true; - await ExtensionTestUtilities.RoundtripAsync(Protocol.Default, new[] { request }, new[] { successResponse }); + await this.RoundtripAsync(Protocol.Default, new[] { request }, new[] { successResponse }); var failureResponse = new StoreResponse(); failureResponse.Succeeded = false; failureResponse.FailureReason = "Some error"; - await ExtensionTestUtilities.RoundtripAsync(Protocol.Default, new[] { request }, new[] { failureResponse }); + await this.RoundtripAsync(Protocol.Default, new[] { request }, new[] { failureResponse }); } } } diff --git a/src/DotNetOpenAuth.Test/OpenId/Extensions/ExtensionTestUtilities.cs b/src/DotNetOpenAuth.Test/OpenId/Extensions/ExtensionTestUtilities.cs index 7c7f945..fd53fd1 100644 --- a/src/DotNetOpenAuth.Test/OpenId/Extensions/ExtensionTestUtilities.cs +++ b/src/DotNetOpenAuth.Test/OpenId/Extensions/ExtensionTestUtilities.cs @@ -9,6 +9,7 @@ namespace DotNetOpenAuth.Test.OpenId.Extensions { using System.Collections.Generic; using System.Linq; using System.Net.Http; + using System.Threading; using System.Threading.Tasks; using DotNetOpenAuth.Messaging; @@ -23,68 +24,6 @@ namespace DotNetOpenAuth.Test.OpenId.Extensions { using Validation; public static class ExtensionTestUtilities { - /// <summary> - /// Simulates an extension request and response. - /// </summary> - /// <param name="protocol">The protocol to use in the roundtripping.</param> - /// <param name="requests">The extensions to add to the request message.</param> - /// <param name="responses">The extensions to add to the response message.</param> - /// <remarks> - /// This method relies on the extension objects' Equals methods to verify - /// accurate transport. The Equals methods should be verified by separate tests. - /// </remarks> - internal static async Task RoundtripAsync( - Protocol protocol, - IEnumerable<IOpenIdMessageExtension> requests, - IEnumerable<IOpenIdMessageExtension> responses) { - var securitySettings = new ProviderSecuritySettings(); - var cryptoKeyStore = new MemoryCryptoKeyStore(); - var associationStore = new ProviderAssociationHandleEncoder(cryptoKeyStore); - Association association = HmacShaAssociationProvider.Create(protocol, protocol.Args.SignatureAlgorithm.Best, AssociationRelyingPartyType.Smart, associationStore, securitySettings); - await TestBase.RunAsync( - OpenIdTestBase.RelyingPartyDriver(async (rp, ct) => { - RegisterExtension(rp.Channel, Mocks.MockOpenIdExtension.Factory); - var requestBase = new CheckIdRequest(protocol.Version, OpenIdTestBase.OPUri, AuthenticationRequestMode.Immediate); - OpenIdTestBase.StoreAssociation(rp, OpenIdTestBase.OPUri, association); - requestBase.AssociationHandle = association.Handle; - requestBase.ClaimedIdentifier = "http://claimedid"; - requestBase.LocalIdentifier = "http://localid"; - requestBase.ReturnTo = OpenIdTestBase.RPUri; - - foreach (IOpenIdMessageExtension extension in requests) { - requestBase.Extensions.Add(extension); - } - - var redirectingRequest = await rp.Channel.PrepareResponseAsync(requestBase); - Uri redirectingResponseUri; - using (var httpClient = rp.Channel.HostFactories.CreateHttpClient()) { - using (var redirectingResponse = await httpClient.GetAsync(redirectingRequest.Headers.Location, ct)) { - redirectingResponse.EnsureSuccessStatusCode(); - redirectingResponseUri = redirectingResponse.Headers.Location; - } - } - - var response = await rp.Channel.ReadFromRequestAsync<PositiveAssertionResponse>(new HttpRequestMessage(HttpMethod.Get, redirectingResponseUri), ct); - var receivedResponses = response.Extensions.Cast<IOpenIdMessageExtension>(); - CollectionAssert<IOpenIdMessageExtension>.AreEquivalentByEquality(responses.ToArray(), receivedResponses.ToArray()); - }), - OpenIdTestBase.HandleProvider(async (op, req, ct) => { - RegisterExtension(op.Channel, Mocks.MockOpenIdExtension.Factory); - var key = cryptoKeyStore.GetCurrentKey(ProviderAssociationHandleEncoder.AssociationHandleEncodingSecretBucket, TimeSpan.FromSeconds(1)); - op.CryptoKeyStore.StoreKey(ProviderAssociationHandleEncoder.AssociationHandleEncodingSecretBucket, key.Key, key.Value); - var request = await op.Channel.ReadFromRequestAsync<CheckIdRequest>(req, ct); - var response = new PositiveAssertionResponse(request); - var receivedRequests = request.Extensions.Cast<IOpenIdMessageExtension>(); - CollectionAssert<IOpenIdMessageExtension>.AreEquivalentByEquality(requests.ToArray(), receivedRequests.ToArray()); - - foreach (var extensionResponse in responses) { - response.Extensions.Add(extensionResponse); - } - - return await op.Channel.PrepareResponseAsync(response, ct); - })); - } - internal static void RegisterExtension(Channel channel, StandardOpenIdExtensionFactory.CreateDelegate extensionFactory) { Requires.NotNull(channel, "channel"); diff --git a/src/DotNetOpenAuth.Test/OpenId/Extensions/ProviderAuthenticationPolicy/PapeRoundTripTests.cs b/src/DotNetOpenAuth.Test/OpenId/Extensions/ProviderAuthenticationPolicy/PapeRoundTripTests.cs index 2969511..3cb3028 100644 --- a/src/DotNetOpenAuth.Test/OpenId/Extensions/ProviderAuthenticationPolicy/PapeRoundTripTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/Extensions/ProviderAuthenticationPolicy/PapeRoundTripTests.cs @@ -19,7 +19,7 @@ namespace DotNetOpenAuth.Test.OpenId.Extensions.ProviderAuthenticationPolicy { public async Task Trivial() { var request = new PolicyRequest(); var response = new PolicyResponse(); - await ExtensionTestUtilities.RoundtripAsync(Protocol.Default, new[] { request }, new[] { response }); + await this.RoundtripAsync(Protocol.Default, new[] { request }, new[] { response }); } [Test] @@ -39,7 +39,7 @@ namespace DotNetOpenAuth.Test.OpenId.Extensions.ProviderAuthenticationPolicy { response.AssuranceLevels["customlevel"] = "ABC"; response.NistAssuranceLevel = NistAssuranceLevel.Level2; - await ExtensionTestUtilities.RoundtripAsync(Protocol.Default, new[] { request }, new[] { response }); + await this.RoundtripAsync(Protocol.Default, new[] { request }, new[] { response }); } } } diff --git a/src/DotNetOpenAuth.Test/OpenId/Extensions/SimpleRegistration/ClaimsResponseTests.cs b/src/DotNetOpenAuth.Test/OpenId/Extensions/SimpleRegistration/ClaimsResponseTests.cs index 40f9d76..1aa6e33 100644 --- a/src/DotNetOpenAuth.Test/OpenId/Extensions/SimpleRegistration/ClaimsResponseTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/Extensions/SimpleRegistration/ClaimsResponseTests.cs @@ -17,7 +17,7 @@ namespace DotNetOpenAuth.Test.OpenId.Extensions { using NUnit.Framework; [TestFixture] - public class ClaimsResponseTests { + public class ClaimsResponseTests : OpenIdTestBase { [Test] public void EmptyMailAddress() { ClaimsResponse response = new ClaimsResponse(Constants.TypeUris.Standard); @@ -140,7 +140,7 @@ namespace DotNetOpenAuth.Test.OpenId.Extensions { var response = new ClaimsResponse(Constants.TypeUris.Variant10); response.Email = "a@b.com"; - await ExtensionTestUtilities.RoundtripAsync(Protocol.Default, new[] { request }, new[] { response }); + await this.RoundtripAsync(Protocol.Default, new[] { request }, new[] { response }); } private ClaimsResponse GetFilledData() { diff --git a/src/DotNetOpenAuth.Test/OpenId/NonIdentityTests.cs b/src/DotNetOpenAuth.Test/OpenId/NonIdentityTests.cs index 16c096f..b42fb46 100644 --- a/src/DotNetOpenAuth.Test/OpenId/NonIdentityTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/NonIdentityTests.cs @@ -7,6 +7,7 @@ namespace DotNetOpenAuth.Test.OpenId { using System; using System.Net.Http; + using System.Threading; using System.Threading.Tasks; using DotNetOpenAuth.Messaging; @@ -23,58 +24,63 @@ namespace DotNetOpenAuth.Test.OpenId { Protocol protocol = Protocol.V20; var mode = AuthenticationRequestMode.Setup; - await RunAsync( - RelyingPartyDriver(async (rp, ct) => { - var request = new SignedResponseRequest(protocol.Version, OPUri, mode); - var authRequest = await rp.Channel.PrepareResponseAsync(request); - using (var httpClient = rp.Channel.HostFactories.CreateHttpClient()) { - using (var response = await httpClient.GetAsync(authRequest.Headers.Location, ct)) { - response.EnsureSuccessStatusCode(); - } - } - }), - HandleProvider(async (op, req, ct) => { - var request = await op.Channel.ReadFromRequestAsync<SignedResponseRequest>(req, ct); + HandleProvider( + async (op, req) => { + var request = await op.Channel.ReadFromRequestAsync<SignedResponseRequest>(req, CancellationToken.None); Assert.IsNotInstanceOf<CheckIdRequest>(request); return new HttpResponseMessage(); - })); + }); + + { + var rp = this.CreateRelyingParty(); + var request = new SignedResponseRequest(protocol.Version, OPUri, mode); + var authRequest = await rp.Channel.PrepareResponseAsync(request); + using (var httpClient = rp.Channel.HostFactories.CreateHttpClient()) { + using (var response = await httpClient.GetAsync(authRequest.Headers.Location)) { + response.EnsureSuccessStatusCode(); + } + } + } } [Test] public async Task ExtensionOnlyFacadeLevel() { Protocol protocol = Protocol.V20; int opStep = 0; - await RunAsync( - RelyingPartyDriver(async (rp, ct) => { - var request = await rp.CreateRequestAsync(GetMockIdentifier(protocol.ProtocolVersion), RPRealmUri, RPUri, ct); - - request.IsExtensionOnly = true; - var redirectRequest = await request.GetRedirectingResponseAsync(ct); - Uri redirectResponseUrl; - using (var httpClient = rp.Channel.HostFactories.CreateHttpClient()) { - using (var redirectResponse = await httpClient.GetAsync(redirectRequest.Headers.Location, ct)) { - redirectResponse.EnsureSuccessStatusCode(); - redirectResponseUrl = redirectRequest.Headers.Location; - } - } - - IAuthenticationResponse response = await rp.GetResponseAsync(new HttpRequestMessage(HttpMethod.Get, redirectResponseUrl)); - Assert.AreEqual(AuthenticationStatus.ExtensionsOnly, response.Status); - }), - HandleProvider(async (op, req, ct) => { + HandleProvider( + async (op, req) => { switch (++opStep) { case 1: - var assocRequest = await op.GetRequestAsync(req, ct); - return await op.PrepareResponseAsync(assocRequest, ct); + var assocRequest = await op.GetRequestAsync(req); + return await op.PrepareResponseAsync(assocRequest); case 2: - var request = (IAnonymousRequest)await op.GetRequestAsync(req, ct); + var request = (IAnonymousRequest)await op.GetRequestAsync(req); request.IsApproved = true; Assert.IsNotInstanceOf<CheckIdRequest>(request); - return await op.PrepareResponseAsync(request, ct); + return await op.PrepareResponseAsync(request); default: throw Assumes.NotReachable(); } - })); + }); + + { + var rp = this.CreateRelyingParty(); + var request = await rp.CreateRequestAsync(GetMockIdentifier(protocol.ProtocolVersion), RPRealmUri, RPUri); + + request.IsExtensionOnly = true; + var redirectRequest = await request.GetRedirectingResponseAsync(); + Uri redirectResponseUrl; + using (var httpClient = rp.Channel.HostFactories.CreateHttpClient()) { + using (var redirectResponse = await httpClient.GetAsync(redirectRequest.Headers.Location)) { + redirectResponse.EnsureSuccessStatusCode(); + redirectResponseUrl = redirectRequest.Headers.Location; + } + } + + IAuthenticationResponse response = + await rp.GetResponseAsync(new HttpRequestMessage(HttpMethod.Get, redirectResponseUrl)); + Assert.AreEqual(AuthenticationStatus.ExtensionsOnly, response.Status); + } } } } diff --git a/src/DotNetOpenAuth.Test/OpenId/OpenIdTestBase.cs b/src/DotNetOpenAuth.Test/OpenId/OpenIdTestBase.cs index f4a9684..6fd0240 100644 --- a/src/DotNetOpenAuth.Test/OpenId/OpenIdTestBase.cs +++ b/src/DotNetOpenAuth.Test/OpenId/OpenIdTestBase.cs @@ -8,6 +8,7 @@ namespace DotNetOpenAuth.Test.OpenId { using System; using System.Collections.Generic; using System.IO; + using System.Linq; using System.Net.Http; using System.Reflection; using System.Threading; @@ -17,9 +18,13 @@ namespace DotNetOpenAuth.Test.OpenId { using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Bindings; using DotNetOpenAuth.OpenId; + using DotNetOpenAuth.OpenId.Messages; using DotNetOpenAuth.OpenId.Provider; using DotNetOpenAuth.OpenId.RelyingParty; + using DotNetOpenAuth.Test.Messaging; using DotNetOpenAuth.Test.Mocks; + using DotNetOpenAuth.Test.OpenId.Extensions; + using NUnit.Framework; using IAuthenticationRequest = DotNetOpenAuth.OpenId.Provider.IAuthenticationRequest; @@ -141,14 +146,12 @@ namespace DotNetOpenAuth.Test.OpenId { /// <remarks> /// This is a very useful method to pass to the OpenIdCoordinator constructor for the Provider argument. /// </remarks> - internal TestBase.Handler AutoProvider { - get { - return Handle(OPUri).By( - async (req, ct) => { - var provider = new OpenIdProvider(new StandardProviderApplicationStore()); - return await this.AutoProviderActionAsync(provider, req, ct); - }); - } + internal void RegisterAutoProvider() { + this.Handle(OPUri).By( + async (req, ct) => { + var provider = new OpenIdProvider(new StandardProviderApplicationStore(), this.HostFactories); + return await this.AutoProviderActionAsync(provider, req, ct); + }); } /// <summary> @@ -207,7 +210,7 @@ namespace DotNetOpenAuth.Test.OpenId { protected Identifier GetMockIdentifier(ProtocolVersion providerVersion, bool useSsl, bool delegating) { var se = GetServiceEndpoint(0, providerVersion, 10, useSsl, delegating); - this.HostFactories.Handlers.Add(MockHttpRequest.RegisterMockXrdsResponse(se)); + this.RegisterMockXrdsResponse(se); return se.ClaimedIdentifier; } @@ -219,7 +222,7 @@ namespace DotNetOpenAuth.Test.OpenId { IdentifierDiscoveryResult.CreateForProviderIdentifier(protocol.ClaimedIdentifierForOPIdentifier, opDesc, 20, 20), }; - this.HostFactories.Handlers.Add(MockHttpRequest.RegisterMockXrdsResponse(VanityUri, dualResults)); + this.RegisterMockXrdsResponse(VanityUri, dualResults); return VanityUri; } @@ -250,25 +253,84 @@ namespace DotNetOpenAuth.Test.OpenId { return op; } - protected internal static Func<IHostFactories, CancellationToken, Task> RelyingPartyDriver(Func<OpenIdRelyingParty, CancellationToken, Task> relyingPartyDriver) { - return async (hostFactories, ct) => { - var rp = new OpenIdRelyingParty(new StandardRelyingPartyApplicationStore(), hostFactories); - await relyingPartyDriver(rp, ct); - }; - } - - protected internal static Func<IHostFactories, CancellationToken, Task> ProviderDriver(Func<OpenIdProvider, CancellationToken, Task> providerDriver) { - return async (hostFactories, ct) => { - var op = new OpenIdProvider(new StandardProviderApplicationStore(), hostFactories); - await providerDriver(op, ct); - }; - } - - protected internal static Handler HandleProvider(Func<OpenIdProvider, HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>> provider) { - return Handle(OPUri).By(async (req, ct) => { + protected internal void HandleProvider(Func<OpenIdProvider, HttpRequestMessage, Task<HttpResponseMessage>> provider) { + this.Handle(OPUri).By(async req => { var op = new OpenIdProvider(new StandardProviderApplicationStore()); - return await provider(op, req, ct); + return await provider(op, req); }); } + + /// <summary> + /// Simulates an extension request and response. + /// </summary> + /// <param name="protocol">The protocol to use in the roundtripping.</param> + /// <param name="requests">The extensions to add to the request message.</param> + /// <param name="responses">The extensions to add to the response message.</param> + /// <remarks> + /// This method relies on the extension objects' Equals methods to verify + /// accurate transport. The Equals methods should be verified by separate tests. + /// </remarks> + internal async Task RoundtripAsync( + Protocol protocol, IEnumerable<IOpenIdMessageExtension> requests, IEnumerable<IOpenIdMessageExtension> responses) { + var securitySettings = new ProviderSecuritySettings(); + var cryptoKeyStore = new MemoryCryptoKeyStore(); + var associationStore = new ProviderAssociationHandleEncoder(cryptoKeyStore); + Association association = HmacShaAssociationProvider.Create( + protocol, + protocol.Args.SignatureAlgorithm.Best, + AssociationRelyingPartyType.Smart, + associationStore, + securitySettings); + + this.HandleProvider( + async (op, req) => { + ExtensionTestUtilities.RegisterExtension(op.Channel, Mocks.MockOpenIdExtension.Factory); + var key = cryptoKeyStore.GetCurrentKey( + ProviderAssociationHandleEncoder.AssociationHandleEncodingSecretBucket, TimeSpan.FromSeconds(1)); + op.CryptoKeyStore.StoreKey( + ProviderAssociationHandleEncoder.AssociationHandleEncodingSecretBucket, key.Key, key.Value); + var request = await op.Channel.ReadFromRequestAsync<CheckIdRequest>(req, CancellationToken.None); + var response = new PositiveAssertionResponse(request); + var receivedRequests = request.Extensions.Cast<IOpenIdMessageExtension>(); + CollectionAssert<IOpenIdMessageExtension>.AreEquivalentByEquality(requests.ToArray(), receivedRequests.ToArray()); + + foreach (var extensionResponse in responses) { + response.Extensions.Add(extensionResponse); + } + + return await op.Channel.PrepareResponseAsync(response); + }); + + { + var rp = this.CreateRelyingParty(); + ExtensionTestUtilities.RegisterExtension(rp.Channel, Mocks.MockOpenIdExtension.Factory); + var requestBase = new CheckIdRequest(protocol.Version, OpenIdTestBase.OPUri, AuthenticationRequestMode.Immediate); + OpenIdTestBase.StoreAssociation(rp, OpenIdTestBase.OPUri, association); + requestBase.AssociationHandle = association.Handle; + requestBase.ClaimedIdentifier = "http://claimedid"; + requestBase.LocalIdentifier = "http://localid"; + requestBase.ReturnTo = OpenIdTestBase.RPUri; + + foreach (IOpenIdMessageExtension extension in requests) { + requestBase.Extensions.Add(extension); + } + + var redirectingRequest = await rp.Channel.PrepareResponseAsync(requestBase); + Uri redirectingResponseUri; + using (var httpClient = rp.Channel.HostFactories.CreateHttpClient()) { + using (var redirectingResponse = await httpClient.GetAsync(redirectingRequest.Headers.Location)) { + redirectingResponse.EnsureSuccessStatusCode(); + redirectingResponseUri = redirectingResponse.Headers.Location; + } + } + + var response = + await + rp.Channel.ReadFromRequestAsync<PositiveAssertionResponse>( + new HttpRequestMessage(HttpMethod.Get, redirectingResponseUri), CancellationToken.None); + var receivedResponses = response.Extensions.Cast<IOpenIdMessageExtension>(); + CollectionAssert<IOpenIdMessageExtension>.AreEquivalentByEquality(responses.ToArray(), receivedResponses.ToArray()); + } + } } } diff --git a/src/DotNetOpenAuth.Test/OpenId/Provider/HostProcessedRequestTests.cs b/src/DotNetOpenAuth.Test/OpenId/Provider/HostProcessedRequestTests.cs index dc692fc..da883f5 100644 --- a/src/DotNetOpenAuth.Test/OpenId/Provider/HostProcessedRequestTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/Provider/HostProcessedRequestTests.cs @@ -41,12 +41,10 @@ namespace DotNetOpenAuth.Test.OpenId.Provider { [Test] public async Task IsReturnUrlDiscoverableValidResponse() { - await RunAsync( - async (hostFactories, ct) => { - this.request = new AuthenticationRequest(this.provider, this.checkIdRequest); - Assert.AreEqual(RelyingPartyDiscoveryResult.Success, await this.request.IsReturnUrlDiscoverableAsync(this.provider.Channel.HostFactories, CancellationToken.None)); - }, - MockHttpRequest.RegisterMockRPDiscovery(false)); + this.RegisterMockRPDiscovery(false); + + this.request = new AuthenticationRequest(this.provider, this.checkIdRequest); + Assert.AreEqual(RelyingPartyDiscoveryResult.Success, await this.request.IsReturnUrlDiscoverableAsync(this.provider.Channel.HostFactories, CancellationToken.None)); } /// <summary> @@ -55,12 +53,9 @@ namespace DotNetOpenAuth.Test.OpenId.Provider { /// </summary> [Test] public async Task IsReturnUrlDiscoverableNotSsl() { - await RunAsync( - async (hostFactories, ct) => { - this.provider.SecuritySettings.RequireSsl = true; - Assert.AreEqual(RelyingPartyDiscoveryResult.NoServiceDocument, await this.request.IsReturnUrlDiscoverableAsync(this.provider.Channel.HostFactories, CancellationToken.None)); - }, - MockHttpRequest.RegisterMockRPDiscovery(false)); + this.RegisterMockRPDiscovery(false); + this.provider.SecuritySettings.RequireSsl = true; + Assert.AreEqual(RelyingPartyDiscoveryResult.NoServiceDocument, await this.request.IsReturnUrlDiscoverableAsync(this.provider.Channel.HostFactories, CancellationToken.None)); } /// <summary> @@ -68,34 +63,30 @@ namespace DotNetOpenAuth.Test.OpenId.Provider { /// </summary> [Test] public async Task IsReturnUrlDiscoverableRequireSsl() { - await RunAsync( - async (hostFactories, ct) => { - this.checkIdRequest.Realm = RPRealmUriSsl; - this.checkIdRequest.ReturnTo = RPUriSsl; + this.RegisterMockRPDiscovery(false); + this.checkIdRequest.Realm = RPRealmUriSsl; + this.checkIdRequest.ReturnTo = RPUriSsl; - // Try once with RequireSsl - this.provider.SecuritySettings.RequireSsl = true; - this.request = new AuthenticationRequest(this.provider, this.checkIdRequest); - Assert.AreEqual(RelyingPartyDiscoveryResult.Success, await this.request.IsReturnUrlDiscoverableAsync(this.provider.Channel.HostFactories, CancellationToken.None)); + // Try once with RequireSsl + this.provider.SecuritySettings.RequireSsl = true; + this.request = new AuthenticationRequest(this.provider, this.checkIdRequest); + Assert.AreEqual(RelyingPartyDiscoveryResult.Success, await this.request.IsReturnUrlDiscoverableAsync(this.provider.Channel.HostFactories, CancellationToken.None)); - // And again without RequireSsl - this.provider.SecuritySettings.RequireSsl = false; - this.request = new AuthenticationRequest(this.provider, this.checkIdRequest); - Assert.AreEqual(RelyingPartyDiscoveryResult.Success, await this.request.IsReturnUrlDiscoverableAsync(this.provider.Channel.HostFactories, CancellationToken.None)); - }, - MockHttpRequest.RegisterMockRPDiscovery(false)); + // And again without RequireSsl + this.provider.SecuritySettings.RequireSsl = false; + this.request = new AuthenticationRequest(this.provider, this.checkIdRequest); + Assert.AreEqual(RelyingPartyDiscoveryResult.Success, await this.request.IsReturnUrlDiscoverableAsync(this.provider.Channel.HostFactories, CancellationToken.None)); } [Test] public async Task IsReturnUrlDiscoverableValidButNoMatch() { - await RunAsync( - async (hostFactories, ct) => { - this.provider.SecuritySettings.RequireSsl = false; // reset for another failure test case - this.checkIdRequest.ReturnTo = new Uri("http://somerandom/host"); - this.request = new AuthenticationRequest(this.provider, this.checkIdRequest); - Assert.AreEqual(RelyingPartyDiscoveryResult.NoMatchingReturnTo, await this.request.IsReturnUrlDiscoverableAsync(this.provider.Channel.HostFactories, CancellationToken.None)); - }, - MockHttpRequest.RegisterMockRPDiscovery(false)); + this.RegisterMockRPDiscovery(false); + this.provider.SecuritySettings.RequireSsl = false; // reset for another failure test case + this.checkIdRequest.ReturnTo = new Uri("http://somerandom/host"); + this.request = new AuthenticationRequest(this.provider, this.checkIdRequest); + Assert.AreEqual( + RelyingPartyDiscoveryResult.NoMatchingReturnTo, + await this.request.IsReturnUrlDiscoverableAsync(this.provider.Channel.HostFactories, CancellationToken.None)); } } } diff --git a/src/DotNetOpenAuth.Test/OpenId/Provider/OpenIdProviderTests.cs b/src/DotNetOpenAuth.Test/OpenId/Provider/OpenIdProviderTests.cs index 1033198..4780e37 100644 --- a/src/DotNetOpenAuth.Test/OpenId/Provider/OpenIdProviderTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/Provider/OpenIdProviderTests.cs @@ -100,34 +100,28 @@ namespace DotNetOpenAuth.Test.OpenId.Provider { var providerDescription = new ProviderEndpointDescription(OPUri, Protocol.Default.Version); // Test some non-empty request scenario. - var coordinator = new CoordinatorBase( - RelyingPartyDriver(async (rp, ct) => { - await rp.Channel.RequestAsync(AssociateRequestRelyingParty.Create(rp.SecuritySettings, providerDescription), ct); - }), - HandleProvider(async (op, req, ct) => { + HandleProvider( + async (op, req) => { IRequest request = await op.GetRequestAsync(req); Assert.IsInstanceOf<AutoResponsiveRequest>(request); - return await op.PrepareResponseAsync(request, ct); - })); - await coordinator.RunAsync(); + return await op.PrepareResponseAsync(request); + }); + var rp = this.CreateRelyingParty(); + await rp.Channel.RequestAsync(AssociateRequestRelyingParty.Create(rp.SecuritySettings, providerDescription), CancellationToken.None); } [Test] public async Task BadRequestsGenerateValidErrorResponses() { - var coordinator = new CoordinatorBase( - RelyingPartyDriver(async (rp, ct) => { - var nonOpenIdMessage = new Mocks.TestDirectedMessage { - Recipient = OPUri, - HttpMethods = HttpDeliveryMethods.PostRequest - }; - MessagingTestBase.GetStandardTestMessage(MessagingTestBase.FieldFill.AllRequired, nonOpenIdMessage); - var response = await rp.Channel.RequestAsync<DirectErrorResponse>(nonOpenIdMessage, ct); - Assert.IsNotNull(response.ErrorMessage); - Assert.AreEqual(Protocol.Default.Version, response.Version); - }), - AutoProvider); - - await coordinator.RunAsync(); + this.RegisterAutoProvider(); + var rp = this.CreateRelyingParty(); + var nonOpenIdMessage = new Mocks.TestDirectedMessage { + Recipient = OPUri, + HttpMethods = HttpDeliveryMethods.PostRequest + }; + MessagingTestBase.GetStandardTestMessage(MessagingTestBase.FieldFill.AllRequired, nonOpenIdMessage); + var response = await rp.Channel.RequestAsync<DirectErrorResponse>(nonOpenIdMessage, CancellationToken.None); + Assert.IsNotNull(response.ErrorMessage); + Assert.AreEqual(Protocol.Default.Version, response.Version); } [Test, Category("HostASPNET")] diff --git a/src/DotNetOpenAuth.Test/OpenId/RelyingParty/AuthenticationRequestTests.cs b/src/DotNetOpenAuth.Test/OpenId/RelyingParty/AuthenticationRequestTests.cs index af66bed..333169f 100644 --- a/src/DotNetOpenAuth.Test/OpenId/RelyingParty/AuthenticationRequestTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/RelyingParty/AuthenticationRequestTests.cs @@ -73,37 +73,34 @@ namespace DotNetOpenAuth.Test.OpenId.RelyingParty { /// </summary> [Test] public async Task CreateRequestMessage() { - var coordinator = new CoordinatorBase( - RelyingPartyDriver(async (rp, ct) => { - Identifier id = this.GetMockIdentifier(ProtocolVersion.V20); - IAuthenticationRequest authRequest = await rp.CreateRequestAsync(id, this.realm, this.returnTo); - - // Add some callback arguments - authRequest.AddCallbackArguments("a", "b"); - authRequest.AddCallbackArguments(new Dictionary<string, string> { { "c", "d" }, { "e", "f" } }); - - // Assembly an extension request. - var sregRequest = new ClaimsRequest(); - sregRequest.Nickname = DemandLevel.Request; - authRequest.AddExtension(sregRequest); - - // Construct the actual authentication request message. - var authRequestAccessor = (AuthenticationRequest)authRequest; - var req = await authRequestAccessor.CreateRequestMessageTestHookAsync(ct); - Assert.IsNotNull(req); - - // Verify that callback arguments were included. - NameValueCollection callbackArguments = HttpUtility.ParseQueryString(req.ReturnTo.Query); - Assert.AreEqual("b", callbackArguments["a"]); - Assert.AreEqual("d", callbackArguments["c"]); - Assert.AreEqual("f", callbackArguments["e"]); - - // Verify that extensions were included. - Assert.AreEqual(1, req.Extensions.Count); - Assert.IsTrue(req.Extensions.Contains(sregRequest)); - }), - AutoProvider); - await coordinator.RunAsync(); + this.RegisterAutoProvider(); + var rp = this.CreateRelyingParty(); + Identifier id = this.GetMockIdentifier(ProtocolVersion.V20); + IAuthenticationRequest authRequest = await rp.CreateRequestAsync(id, this.realm, this.returnTo); + + // Add some callback arguments + authRequest.AddCallbackArguments("a", "b"); + authRequest.AddCallbackArguments(new Dictionary<string, string> { { "c", "d" }, { "e", "f" } }); + + // Assembly an extension request. + var sregRequest = new ClaimsRequest(); + sregRequest.Nickname = DemandLevel.Request; + authRequest.AddExtension(sregRequest); + + // Construct the actual authentication request message. + var authRequestAccessor = (AuthenticationRequest)authRequest; + var req = await authRequestAccessor.CreateRequestMessageTestHookAsync(CancellationToken.None); + Assert.IsNotNull(req); + + // Verify that callback arguments were included. + NameValueCollection callbackArguments = HttpUtility.ParseQueryString(req.ReturnTo.Query); + Assert.AreEqual("b", callbackArguments["a"]); + Assert.AreEqual("d", callbackArguments["c"]); + Assert.AreEqual("f", callbackArguments["e"]); + + // Verify that extensions were included. + Assert.AreEqual(1, req.Extensions.Count); + Assert.IsTrue(req.Extensions.Contains(sregRequest)); } /// <summary> diff --git a/src/DotNetOpenAuth.Test/OpenId/RelyingParty/OpenIdRelyingPartyTests.cs b/src/DotNetOpenAuth.Test/OpenId/RelyingParty/OpenIdRelyingPartyTests.cs index c2c5db5..13049db 100644 --- a/src/DotNetOpenAuth.Test/OpenId/RelyingParty/OpenIdRelyingPartyTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/RelyingParty/OpenIdRelyingPartyTests.cs @@ -89,26 +89,18 @@ namespace DotNetOpenAuth.Test.OpenId.RelyingParty { [Test, ExpectedException(typeof(ProtocolException))] public async Task CreateRequestOnNonOpenID() { var nonOpenId = new Uri("http://www.microsoft.com/"); - var coordinator = new CoordinatorBase( - RelyingPartyDriver( - async (rp, ct) => { - await rp.CreateRequestAsync(nonOpenId, RPRealmUri, RPUri); - }), - Handle(nonOpenId).By("<html/>", "text/html")); - await coordinator.RunAsync(); + Handle(nonOpenId).By("<html/>", "text/html"); + var rp = this.CreateRelyingParty(); + await rp.CreateRequestAsync(nonOpenId, RPRealmUri, RPUri); } [Test] public async Task CreateRequestsOnNonOpenID() { var nonOpenId = new Uri("http://www.microsoft.com/"); - var coordinator = new CoordinatorBase( - RelyingPartyDriver( - async (rp, ct) => { - var requests = await rp.CreateRequestsAsync(nonOpenId, RPRealmUri, RPUri); - Assert.AreEqual(0, requests.Count()); - }), - Handle(nonOpenId).By("<html/>", "text/html")); - await coordinator.RunAsync(); + Handle(nonOpenId).By("<html/>", "text/html"); + var rp = this.CreateRelyingParty(); + var requests = await rp.CreateRequestsAsync(nonOpenId, RPRealmUri, RPUri); + Assert.AreEqual(0, requests.Count()); } /// <summary> @@ -118,35 +110,29 @@ namespace DotNetOpenAuth.Test.OpenId.RelyingParty { [Test] public async Task AssertionWithEndpointFilter() { var opStore = new StandardProviderApplicationStore(); - var coordinator = new CoordinatorBase( - async (hostFactories, ct) => { - var op = new OpenIdProvider(opStore); - Identifier id = GetMockIdentifier(ProtocolVersion.V20); - var assertion = await op.PrepareUnsolicitedAssertionAsync(OPUri, GetMockRealm(false), id, id, ct); - using (var httpClient = hostFactories.CreateHttpClient()) { - using (var response = await httpClient.GetAsync(assertion.Headers.Location, ct)) { - response.EnsureSuccessStatusCode(); - } + Handle(RPRealmUri).By( + async req => { + var rp = new OpenIdRelyingParty(new StandardRelyingPartyApplicationStore()); + + // Rig it to always deny the incoming OP + rp.EndpointFilter = op => false; + + // Receive the unsolicited assertion + var response = await rp.GetResponseAsync(req); + Assert.AreEqual(AuthenticationStatus.Failed, response.Status); + return new HttpResponseMessage(); + }); + this.RegisterAutoProvider(); + { + var op = new OpenIdProvider(opStore); + Identifier id = GetMockIdentifier(ProtocolVersion.V20); + var assertion = await op.PrepareUnsolicitedAssertionAsync(OPUri, GetMockRealm(false), id, id); + using (var httpClient = this.HostFactories.CreateHttpClient()) { + using (var response = await httpClient.GetAsync(assertion.Headers.Location)) { + response.EnsureSuccessStatusCode(); } - }, - Handle(RPRealmUri).By( - async (hostFactories, req, ct) => { - var rp = new OpenIdRelyingParty(new StandardRelyingPartyApplicationStore()); - - // register with RP so that id discovery passes - // TODO: Fix this - ////rp.Channel.WebRequestHandler = this.MockResponder.MockWebRequestHandler; - - // Rig it to always deny the incoming OP - rp.EndpointFilter = op => false; - - // Receive the unsolicited assertion - var response = await rp.GetResponseAsync(req, ct); - Assert.AreEqual(AuthenticationStatus.Failed, response.Status); - return new HttpResponseMessage(); - }), - AutoProvider); - await coordinator.RunAsync(); + } + } } } } diff --git a/src/DotNetOpenAuth.Test/OpenId/RelyingParty/PositiveAuthenticationResponseTests.cs b/src/DotNetOpenAuth.Test/OpenId/RelyingParty/PositiveAuthenticationResponseTests.cs index a998ef2..a39c425 100644 --- a/src/DotNetOpenAuth.Test/OpenId/RelyingParty/PositiveAuthenticationResponseTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/RelyingParty/PositiveAuthenticationResponseTests.cs @@ -131,17 +131,15 @@ namespace DotNetOpenAuth.Test.OpenId.RelyingParty { string claimed_id = BaseMockUri + "a./b."; var se = IdentifierDiscoveryResult.CreateForClaimedIdentifier(claimed_id, claimed_id, providerEndpoint, null, null); var identityUri = (UriIdentifier)se.ClaimedIdentifier; - var coordinator = new CoordinatorBase( - RelyingPartyDriver(async (rp, ct) => { - var positiveAssertion = this.GetPositiveAssertion(); - positiveAssertion.ClaimedIdentifier = claimed_id; - positiveAssertion.LocalIdentifier = claimed_id; - var authResponse = new PositiveAuthenticationResponse(positiveAssertion, rp); - Assert.AreEqual(AuthenticationStatus.Authenticated, authResponse.Status); - Assert.AreEqual(claimed_id, authResponse.ClaimedIdentifier.ToString()); - }), - MockHttpRequest.RegisterMockXrdsResponse(se)); - await coordinator.RunAsync(); + this.RegisterMockXrdsResponse(se); + + var rp = this.CreateRelyingParty(); + var positiveAssertion = this.GetPositiveAssertion(); + positiveAssertion.ClaimedIdentifier = claimed_id; + positiveAssertion.LocalIdentifier = claimed_id; + var authResponse = new PositiveAuthenticationResponse(positiveAssertion, rp); + Assert.AreEqual(AuthenticationStatus.Authenticated, authResponse.Status); + Assert.AreEqual(claimed_id, authResponse.ClaimedIdentifier.ToString()); } private PositiveAssertionResponse GetPositiveAssertion() { diff --git a/src/DotNetOpenAuth.Test/TestBase.cs b/src/DotNetOpenAuth.Test/TestBase.cs index 875af7a..96ce98d 100644 --- a/src/DotNetOpenAuth.Test/TestBase.cs +++ b/src/DotNetOpenAuth.Test/TestBase.cs @@ -114,45 +114,52 @@ namespace DotNetOpenAuth.Test { new HttpResponse(new StringWriter())); } - protected internal static Task RunAsync(Func<IHostFactories, CancellationToken, Task> driver, params Handler[] handlers) { - var coordinator = new CoordinatorBase(driver, handlers); - return coordinator.RunAsync(); + protected internal Handler Handle(string uri) { + return new Handler(this, new Uri(uri)); } - protected internal static Handler Handle(Uri uri) { - return new Handler(uri); + protected internal Handler Handle(Uri uri) { + return new Handler(this, uri); } protected internal struct Handler { - internal Handler(Uri uri) + private TestBase test; + + internal Handler(TestBase test, Uri uri) : this() { + this.test = test; this.Uri = uri; } - public Uri Uri { get; private set; } + private Handler(Handler previous, Func<HttpRequestMessage, Task<HttpResponseMessage>> handler) + : this(previous.test, previous.Uri) { + this.MessageHandler = handler; + } + + internal Uri Uri { get; private set; } - public Func<IHostFactories, HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>> MessageHandler { get; private set; } + internal Func<HttpRequestMessage, Task<HttpResponseMessage>> MessageHandler { get; private set; } - internal Handler By(Func<IHostFactories, HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>> handler) { - return new Handler(this.Uri) { MessageHandler = handler }; + internal void By(Func<HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>> handler) { + this.test.HostFactories.Handlers.Add(new Handler(this, req => handler(req, CancellationToken.None))); } - internal Handler By(Func<HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>> handler) { - return this.By((hf, req, ct) => handler(req, ct)); + internal void By(Func<HttpRequestMessage, Task<HttpResponseMessage>> handler) { + this.test.HostFactories.Handlers.Add(new Handler(this, handler)); } - internal Handler By(Func<HttpRequestMessage, HttpResponseMessage> handler) { - return this.By((req, ct) => Task.FromResult(handler(req))); + internal void By(Func<HttpRequestMessage, HttpResponseMessage> handler) { + this.By(req => Task.FromResult(handler(req))); } - internal Handler By(string responseContent, string contentType, HttpStatusCode statusCode = HttpStatusCode.OK) { - return this.By( - req => { - var response = new HttpResponseMessage(statusCode); - response.Content = new StringContent(responseContent); - response.Content.Headers.ContentType = new MediaTypeHeaderValue(contentType); - return response; - }); + internal void By(string responseContent, string contentType, HttpStatusCode statusCode = HttpStatusCode.OK) { + this.By( + req => { + var response = new HttpResponseMessage(statusCode); + response.Content = new StringContent(responseContent); + response.Content.Headers.ContentType = new MediaTypeHeaderValue(contentType); + return response; + }); } } } |