diff options
Diffstat (limited to 'src/DotNetOpenAuth.Test/OpenId')
13 files changed, 622 insertions, 422 deletions
diff --git a/src/DotNetOpenAuth.Test/OpenId/AssociationHandshakeTests.cs b/src/DotNetOpenAuth.Test/OpenId/AssociationHandshakeTests.cs index 4a43142..e293011 100644 --- a/src/DotNetOpenAuth.Test/OpenId/AssociationHandshakeTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/AssociationHandshakeTests.cs @@ -24,13 +24,13 @@ namespace DotNetOpenAuth.Test.OpenId { } [Test] - public void AssociateUnencrypted() { - this.ParameterizedAssociationTest(new Uri("https://host")); + public async Task AssociateUnencrypted() { + await this.ParameterizedAssociationTestAsync(new Uri("https://host")); } [Test] - public void AssociateDiffieHellmanOverHttp() { - this.ParameterizedAssociationTest(new Uri("http://host")); + public async Task AssociateDiffieHellmanOverHttp() { + await this.ParameterizedAssociationTestAsync(new Uri("http://host")); } /// <summary> @@ -124,15 +124,15 @@ namespace DotNetOpenAuth.Test.OpenId { public async Task OPRejectsHttpNoEncryptionAssociateRequests() { Protocol protocol = Protocol.V20; var coordinator = new CoordinatorBase( - rp => { + CoordinatorBase.RelyingPartyDriver(async (rp, ct) => { // We have to formulate the associate request manually, - // since the DNOI RP won't voluntarily suggest no encryption at all. + // since the DNOA RP won't voluntarily suggest no encryption at all. var request = new AssociateUnencryptedRequestNoSslCheck(protocol.Version, OPUri); request.AssociationType = protocol.Args.SignatureAlgorithm.HMAC_SHA256; request.SessionType = protocol.Args.SessionType.NoEncryption; - var response = rp.Channel.Request<DirectErrorResponse>(request); + var response = await rp.Channel.RequestAsync<DirectErrorResponse>(request, ct); Assert.IsNotNull(response); - }, + }), AutoProvider); await coordinator.RunAsync(); } @@ -145,18 +145,18 @@ namespace DotNetOpenAuth.Test.OpenId { public async Task OPRejectsMismatchingAssociationAndSessionTypes() { Protocol protocol = Protocol.V20; var coordinator = new CoordinatorBase( - rp => { + CoordinatorBase.RelyingPartyDriver(async (rp, ct) => { // We have to formulate the associate request manually, // since the DNOI RP won't voluntarily mismatch the association and session types. AssociateDiffieHellmanRequest request = new AssociateDiffieHellmanRequest(protocol.Version, new Uri("https://Provider")); request.AssociationType = protocol.Args.SignatureAlgorithm.HMAC_SHA256; request.SessionType = protocol.Args.SessionType.DH_SHA1; request.InitializeRequest(); - var response = rp.Channel.Request<AssociateUnsuccessfulResponse>(request); + var response = await rp.Channel.RequestAsync<AssociateUnsuccessfulResponse>(request, ct); Assert.IsNotNull(response); Assert.AreEqual(protocol.Args.SignatureAlgorithm.HMAC_SHA1, response.AssociationType); Assert.AreEqual(protocol.Args.SessionType.DH_SHA1, response.SessionType); - }, + }), AutoProvider); await coordinator.RunAsync(); } @@ -168,20 +168,20 @@ namespace DotNetOpenAuth.Test.OpenId { public async Task RPRejectsUnrecognizedAssociationType() { Protocol protocol = Protocol.V20; var coordinator = new CoordinatorBase( - rp => { - var association = rp.AssociationManager.GetOrCreateAssociation(new ProviderEndpointDescription(OPUri, protocol.Version)); + CoordinatorBase.RelyingPartyDriver(async (rp, ct) => { + var association = await rp.AssociationManager.GetOrCreateAssociationAsync(new ProviderEndpointDescription(OPUri, protocol.Version), ct); Assert.IsNull(association, "The RP should quietly give up when the OP misbehaves."); - }, - op => { + }), + CoordinatorBase.HandleProvider(async (op, req, ct) => { // Receive initial request. - var request = op.Channel.ReadFromRequest<AssociateRequest>(); + var request = await op.Channel.ReadFromRequestAsync<AssociateRequest>(req, ct); // Send a response that suggests a foreign association type. - AssociateUnsuccessfulResponse renegotiateResponse = new AssociateUnsuccessfulResponse(request.Version, request); + var renegotiateResponse = new AssociateUnsuccessfulResponse(request.Version, request); renegotiateResponse.AssociationType = "HMAC-UNKNOWN"; renegotiateResponse.SessionType = "DH-UNKNOWN"; - op.Channel.Respond(renegotiateResponse); - }); + return await op.Channel.PrepareResponseAsync(renegotiateResponse, ct); + })); await coordinator.RunAsync(); } @@ -195,20 +195,20 @@ namespace DotNetOpenAuth.Test.OpenId { public async Task RPRejectsUnencryptedSuggestion() { Protocol protocol = Protocol.V20; var coordinator = new CoordinatorBase( - rp => { - var association = rp.AssociationManager.GetOrCreateAssociation(new ProviderEndpointDescription(OPUri, protocol.Version)); + CoordinatorBase.RelyingPartyDriver(async (rp, ct) => { + var association = await rp.AssociationManager.GetOrCreateAssociationAsync(new ProviderEndpointDescription(OPUri, protocol.Version), ct); Assert.IsNull(association, "The RP should quietly give up when the OP misbehaves."); - }, - op => { + }), + CoordinatorBase.HandleProvider(async (op, req, ct) => { // Receive initial request. - var request = op.Channel.ReadFromRequest<AssociateRequest>(); + var request = await op.Channel.ReadFromRequestAsync<AssociateRequest>(req, ct); // Send a response that suggests a no encryption. - AssociateUnsuccessfulResponse renegotiateResponse = new AssociateUnsuccessfulResponse(request.Version, request); + var renegotiateResponse = new AssociateUnsuccessfulResponse(request.Version, request); renegotiateResponse.AssociationType = protocol.Args.SignatureAlgorithm.HMAC_SHA1; renegotiateResponse.SessionType = protocol.Args.SessionType.NoEncryption; - op.Channel.Respond(renegotiateResponse); - }); + return await op.Channel.PrepareResponseAsync(renegotiateResponse, ct); + })); await coordinator.RunAsync(); } @@ -220,20 +220,20 @@ namespace DotNetOpenAuth.Test.OpenId { public async Task RPRejectsMismatchingAssociationAndSessionBitLengths() { Protocol protocol = Protocol.V20; var coordinator = new CoordinatorBase( - rp => { - var association = rp.AssociationManager.GetOrCreateAssociation(new ProviderEndpointDescription(OPUri, protocol.Version)); + CoordinatorBase.RelyingPartyDriver(async (rp, ct) => { + var association = await rp.AssociationManager.GetOrCreateAssociationAsync(new ProviderEndpointDescription(OPUri, protocol.Version), ct); Assert.IsNull(association, "The RP should quietly give up when the OP misbehaves."); - }, - op => { + }), + CoordinatorBase.HandleProvider(async (op, req, ct) => { // Receive initial request. - var request = op.Channel.ReadFromRequest<AssociateRequest>(); + var request = await op.Channel.ReadFromRequestAsync<AssociateRequest>(req, ct); // Send a mismatched response AssociateUnsuccessfulResponse renegotiateResponse = new AssociateUnsuccessfulResponse(request.Version, request); renegotiateResponse.AssociationType = protocol.Args.SignatureAlgorithm.HMAC_SHA1; renegotiateResponse.SessionType = protocol.Args.SessionType.DH_SHA256; - op.Channel.Respond(renegotiateResponse); - }); + return await op.Channel.PrepareResponseAsync(renegotiateResponse, ct); + })); await coordinator.RunAsync(); } @@ -244,30 +244,35 @@ namespace DotNetOpenAuth.Test.OpenId { [Test] public async Task RPOnlyRenegotiatesOnce() { Protocol protocol = Protocol.V20; + int opStep = 0; var coordinator = new CoordinatorBase( - rp => { - var association = rp.AssociationManager.GetOrCreateAssociation(new ProviderEndpointDescription(OPUri, protocol.Version)); + CoordinatorBase.RelyingPartyDriver(async (rp, ct) => { + var association = await rp.AssociationManager.GetOrCreateAssociationAsync(new ProviderEndpointDescription(OPUri, protocol.Version), ct); Assert.IsNull(association, "The RP should quietly give up when the OP misbehaves."); - }, - op => { - // Receive initial request. - var request = op.Channel.ReadFromRequest<AssociateRequest>(); + }), + CoordinatorBase.HandleProvider(async (op, req, ct) => { + switch (++opStep) { + case 1: + // Receive initial request. + var request = await op.Channel.ReadFromRequestAsync<AssociateRequest>(req, ct); - // Send a renegotiate response - AssociateUnsuccessfulResponse renegotiateResponse = new AssociateUnsuccessfulResponse(request.Version, request); - renegotiateResponse.AssociationType = protocol.Args.SignatureAlgorithm.HMAC_SHA1; - renegotiateResponse.SessionType = protocol.Args.SessionType.DH_SHA1; - op.Channel.Respond(renegotiateResponse); + // Send a renegotiate response + var renegotiateResponse = new AssociateUnsuccessfulResponse(request.Version, request); + renegotiateResponse.AssociationType = protocol.Args.SignatureAlgorithm.HMAC_SHA1; + renegotiateResponse.SessionType = protocol.Args.SessionType.DH_SHA1; + return await op.Channel.PrepareResponseAsync(renegotiateResponse, ct); - // Receive second-try - request = op.Channel.ReadFromRequest<AssociateRequest>(); + case 2: + // Receive second-try + request = await op.Channel.ReadFromRequestAsync<AssociateRequest>(req, ct); - // Send ANOTHER renegotiate response, at which point the DNOI RP should give up. - renegotiateResponse = new AssociateUnsuccessfulResponse(request.Version, request); - renegotiateResponse.AssociationType = protocol.Args.SignatureAlgorithm.HMAC_SHA256; - renegotiateResponse.SessionType = protocol.Args.SessionType.DH_SHA256; - op.Channel.Respond(renegotiateResponse); - }); + // Send ANOTHER renegotiate response, at which point the DNOI RP should give up. + renegotiateResponse = new AssociateUnsuccessfulResponse(request.Version, request); + renegotiateResponse.AssociationType = protocol.Args.SignatureAlgorithm.HMAC_SHA256; + renegotiateResponse.SessionType = protocol.Args.SessionType.DH_SHA256; + return await op.Channel.PrepareResponseAsync(renegotiateResponse, ct); + } + })); await coordinator.RunAsync(); } @@ -278,15 +283,15 @@ namespace DotNetOpenAuth.Test.OpenId { public async Task AssociateRenegotiateLimitedByRPSecuritySettings() { Protocol protocol = Protocol.V20; var coordinator = new CoordinatorBase( - rp => { + CoordinatorBase.RelyingPartyDriver(async (rp, ct) => { rp.SecuritySettings.MinimumHashBitLength = 256; - var association = rp.AssociationManager.GetOrCreateAssociation(new ProviderEndpointDescription(OPUri, protocol.Version)); + var association = await rp.AssociationManager.GetOrCreateAssociationAsync(new ProviderEndpointDescription(OPUri, protocol.Version), ct); Assert.IsNull(association, "No association should have been created when RP and OP could not agree on association strength."); - }, - op => { + }), + CoordinatorBase.HandleProvider(async (op, req, ct) => { op.SecuritySettings.MaximumHashBitLength = 160; - AutoProvider(op); - }); + return await AutoProviderActionAsync(op, req, ct); + })); await coordinator.RunAsync(); } @@ -298,7 +303,7 @@ namespace DotNetOpenAuth.Test.OpenId { public async Task AssociateQuietlyFailsAfterHttpError() { this.MockResponder.RegisterMockNotFound(OPUri); var rp = this.CreateRelyingParty(); - var association = rp.AssociationManager.GetOrCreateAssociation(new ProviderEndpointDescription(OPUri, Protocol.V20.Version)); + var association = await rp.AssociationManager.GetOrCreateAssociationAsync(new ProviderEndpointDescription(OPUri, Protocol.V20.Version)); Assert.IsNull(association); } @@ -306,11 +311,11 @@ namespace DotNetOpenAuth.Test.OpenId { /// Runs a parameterized association flow test using all supported OpenID versions. /// </summary> /// <param name="opEndpoint">The OP endpoint to simulate using.</param> - private void ParameterizedAssociationTest(Uri opEndpoint) { + private async Task ParameterizedAssociationTestAsync(Uri opEndpoint) { foreach (Protocol protocol in Protocol.AllPracticalVersions) { var endpoint = new ProviderEndpointDescription(opEndpoint, protocol.Version); var associationType = protocol.Version.Major < 2 ? protocol.Args.SignatureAlgorithm.HMAC_SHA1 : protocol.Args.SignatureAlgorithm.HMAC_SHA256; - this.ParameterizedAssociationTest(endpoint, associationType); + await this.ParameterizedAssociationTestAsync(endpoint, associationType); } } @@ -325,7 +330,7 @@ namespace DotNetOpenAuth.Test.OpenId { /// The value of the openid.assoc_type parameter expected, /// or null if a failure is anticipated. /// </param> - private void ParameterizedAssociationTest( + private async Task ParameterizedAssociationTestAsync( ProviderEndpointDescription opDescription, string expectedAssociationType) { Protocol protocol = Protocol.Lookup(Protocol.Lookup(opDescription.Version).ProtocolVersion); @@ -335,17 +340,17 @@ namespace DotNetOpenAuth.Test.OpenId { AssociateSuccessfulResponse associateSuccessfulResponse = null; AssociateUnsuccessfulResponse associateUnsuccessfulResponse = null; var coordinator = new CoordinatorBase( - rp => { + CoordinatorBase.RelyingPartyDriver(async (rp, ct) => { rp.SecuritySettings = this.RelyingPartySecuritySettings; - rpAssociation = rp.AssociationManager.GetOrCreateAssociation(opDescription); - }, - op => { + rpAssociation = await rp.AssociationManager.GetOrCreateAssociationAsync(opDescription, ct); + }), + CoordinatorBase.HandleProvider(async (op, request, ct) => { op.SecuritySettings = this.ProviderSecuritySettings; - IRequest req = op.GetRequest(); + IRequest req = await op.GetRequestAsync(request, ct); Assert.IsNotNull(req, "Expected incoming request but did not receive it."); Assert.IsTrue(req.IsResponseReady); - op.Respond(req); - }); + return await op.PrepareResponseAsync(req, ct); + })); coordinator.IncomingMessageFilter = message => { Assert.AreSame(opDescription.Version, message.Version, "The message was recognized as version {0} but was expected to be {1}.", message.Version, Protocol.Lookup(opDescription.Version).ProtocolVersion); var associateSuccess = message as AssociateSuccessfulResponse; diff --git a/src/DotNetOpenAuth.Test/OpenId/AuthenticationTests.cs b/src/DotNetOpenAuth.Test/OpenId/AuthenticationTests.cs index acb37fc..6fb0d7a 100644 --- a/src/DotNetOpenAuth.Test/OpenId/AuthenticationTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/AuthenticationTests.cs @@ -6,6 +6,9 @@ namespace DotNetOpenAuth.Test.OpenId { using System; + using System.Net.Http; + using System.Threading.Tasks; + using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Bindings; using DotNetOpenAuth.OpenId; @@ -24,76 +27,98 @@ namespace DotNetOpenAuth.Test.OpenId { } [Test] - public void SharedAssociationPositive() { - this.ParameterizedAuthenticationTest(true, true, false); + public async Task SharedAssociationPositive() { + await this.ParameterizedAuthenticationTestAsync(true, true, false); } /// <summary> /// Verifies that a shared association protects against tampering. /// </summary> [Test] - public void SharedAssociationTampered() { - this.ParameterizedAuthenticationTest(true, true, true); + public async Task SharedAssociationTampered() { + await this.ParameterizedAuthenticationTestAsync(true, true, true); } [Test] - public void SharedAssociationNegative() { - this.ParameterizedAuthenticationTest(true, false, false); + public async Task SharedAssociationNegative() { + await this.ParameterizedAuthenticationTestAsync(true, false, false); } [Test] - public void PrivateAssociationPositive() { - this.ParameterizedAuthenticationTest(false, true, false); + public async Task PrivateAssociationPositive() { + await this.ParameterizedAuthenticationTestAsync(false, true, false); } /// <summary> /// Verifies that a private association protects against tampering. /// </summary> [Test] - public void PrivateAssociationTampered() { - this.ParameterizedAuthenticationTest(false, true, true); + public async Task PrivateAssociationTampered() { + await this.ParameterizedAuthenticationTestAsync(false, true, true); } [Test] - public void NoAssociationNegative() { - this.ParameterizedAuthenticationTest(false, false, false); + public async Task NoAssociationNegative() { + await this.ParameterizedAuthenticationTestAsync(false, false, false); } [Test] - public void UnsolicitedAssertion() { - this.MockResponder.RegisterMockRPDiscovery(); - OpenIdCoordinator coordinator = new OpenIdCoordinator( - rp => { - rp.Channel.WebRequestHandler = this.MockResponder.MockWebRequestHandler; - IAuthenticationResponse response = rp.GetResponse(); - Assert.AreEqual(AuthenticationStatus.Authenticated, response.Status); - }, - op => { - op.Channel.WebRequestHandler = this.MockResponder.MockWebRequestHandler; + public async Task UnsolicitedAssertion() { + var opStore = new StandardProviderApplicationStore(); + var coordinator = new CoordinatorBase( + async (hostFactories, ct) => { + var op = new OpenIdProvider(opStore); Identifier id = GetMockIdentifier(ProtocolVersion.V20); - op.SendUnsolicitedAssertion(OPUri, RPRealmUri, id, OPLocalIdentifiers[0]); - AutoProvider(op); // handle check_auth - }); - coordinator.Run(); + var assertion = await op.PrepareUnsolicitedAssertionAsync(OPUri, RPRealmUri, id, OPLocalIdentifiers[0], ct); + + using (var httpClient = hostFactories.CreateHttpClient()) { + using (var response = await httpClient.GetAsync(assertion.Headers.Location)) { + response.EnsureSuccessStatusCode(); + } + } + }, + CoordinatorBase.Handle(RPRealmUri).By(async (hostFactories, req, ct) => { + var rp = new OpenIdRelyingParty(new StandardRelyingPartyApplicationStore(), hostFactories); + IAuthenticationResponse response = await rp.GetResponseAsync(); + Assert.AreEqual(AuthenticationStatus.Authenticated, response.Status); + return new HttpResponseMessage(); + }), + CoordinatorBase.Handle(OPUri).By( + async (req, ct) => { + var op = new OpenIdProvider(opStore); + return await this.AutoProviderActionAsync(op, req, ct); + }), + MockHttpRequest.RegisterMockRPDiscovery(ssl: false)); + await coordinator.RunAsync(); } [Test] - public void UnsolicitedAssertionRejected() { - this.MockResponder.RegisterMockRPDiscovery(); - OpenIdCoordinator coordinator = new OpenIdCoordinator( - rp => { - rp.Channel.WebRequestHandler = this.MockResponder.MockWebRequestHandler; + public async Task UnsolicitedAssertionRejected() { + var opStore = new StandardProviderApplicationStore(); + var coordinator = new CoordinatorBase( + async (hostFactories, ct) => { + var op = new OpenIdProvider(opStore); + Identifier id = GetMockIdentifier(ProtocolVersion.V20); + var assertion = await op.PrepareUnsolicitedAssertionAsync(OPUri, RPRealmUri, id, OPLocalIdentifiers[0], ct); + using (var httpClient = hostFactories.CreateHttpClient()) { + using (var response = await httpClient.GetAsync(assertion.Headers.Location, ct)) { + response.EnsureSuccessStatusCode(); + } + } + }, + CoordinatorBase.Handle(RPRealmUri).By(async (hostFactories, req, ct) => { + var rp = new OpenIdRelyingParty(new StandardRelyingPartyApplicationStore(), hostFactories); rp.SecuritySettings.RejectUnsolicitedAssertions = true; - IAuthenticationResponse response = rp.GetResponse(); + IAuthenticationResponse response = await rp.GetResponseAsync(req, ct); Assert.AreEqual(AuthenticationStatus.Failed, response.Status); - }, - op => { - op.Channel.WebRequestHandler = this.MockResponder.MockWebRequestHandler; - Identifier id = GetMockIdentifier(ProtocolVersion.V20); - op.SendUnsolicitedAssertion(OPUri, RPRealmUri, id, OPLocalIdentifiers[0]); - AutoProvider(op); // handle check_auth - }); - coordinator.Run(); + return new HttpResponseMessage(); + }), + CoordinatorBase.Handle(OPUri).By(async (hostFactories, req, ct) => { + var op = new OpenIdProvider(opStore); + return await this.AutoProviderActionAsync(op, req, ct); + }), + MockHttpRequest.RegisterMockRPDiscovery(false)); + await coordinator.RunAsync(); } /// <summary> @@ -101,25 +126,35 @@ namespace DotNetOpenAuth.Test.OpenId { /// when the appropriate security setting is set. /// </summary> [Test] - public void UnsolicitedDelegatingIdentifierRejection() { - this.MockResponder.RegisterMockRPDiscovery(); - OpenIdCoordinator coordinator = new OpenIdCoordinator( - rp => { - rp.Channel.WebRequestHandler = this.MockResponder.MockWebRequestHandler; + public async Task UnsolicitedDelegatingIdentifierRejection() { + var opStore = new StandardProviderApplicationStore(); + var coordinator = new CoordinatorBase( + async (hostFactories, ct) => { + var op = new OpenIdProvider(opStore); + Identifier id = GetMockIdentifier(ProtocolVersion.V20, false, true); + var assertion = await op.PrepareUnsolicitedAssertionAsync(OPUri, RPRealmUri, id, OPLocalIdentifiers[0], ct); + using (var httpClient = hostFactories.CreateHttpClient()) { + using (var response = await httpClient.GetAsync(assertion.Headers.Location, ct)) { + response.EnsureSuccessStatusCode(); + } + } + }, + CoordinatorBase.Handle(RPRealmUri).By(async (hostFactories, req, ct) => { + var rp = new OpenIdRelyingParty(new StandardRelyingPartyApplicationStore(), hostFactories); rp.SecuritySettings.RejectDelegatingIdentifiers = true; - IAuthenticationResponse response = rp.GetResponse(); + IAuthenticationResponse response = await rp.GetResponseAsync(req, ct); Assert.AreEqual(AuthenticationStatus.Failed, response.Status); - }, - op => { - op.Channel.WebRequestHandler = this.MockResponder.MockWebRequestHandler; - Identifier id = GetMockIdentifier(ProtocolVersion.V20, false, true); - op.SendUnsolicitedAssertion(OPUri, RPRealmUri, id, OPLocalIdentifiers[0]); - AutoProvider(op); // handle check_auth - }); - coordinator.Run(); + return new HttpResponseMessage(); + }), + CoordinatorBase.Handle(OPUri).By(async (hostFactories, req, ct) => { + var op = new OpenIdProvider(opStore); + return await this.AutoProviderActionAsync(op, req, ct); + }), + MockHttpRequest.RegisterMockRPDiscovery(false)); + await coordinator.RunAsync(); } - private void ParameterizedAuthenticationTest(bool sharedAssociation, bool positive, bool tamper) { + private async Task ParameterizedAuthenticationTestAsync(bool sharedAssociation, bool positive, bool tamper) { foreach (Protocol protocol in Protocol.AllPracticalVersions) { foreach (bool statelessRP in new[] { false, true }) { if (sharedAssociation && statelessRP) { @@ -129,21 +164,26 @@ namespace DotNetOpenAuth.Test.OpenId { foreach (bool immediate in new[] { false, true }) { TestLogger.InfoFormat("Beginning authentication test scenario. OpenID: {0}, Shared: {1}, positive: {2}, tamper: {3}, stateless: {4}, immediate: {5}", protocol.Version, sharedAssociation, positive, tamper, statelessRP, immediate); - this.ParameterizedAuthenticationTest(protocol, statelessRP, sharedAssociation, positive, immediate, tamper); + await this.ParameterizedAuthenticationTestAsync(protocol, statelessRP, sharedAssociation, positive, immediate, tamper); } } } } - private void ParameterizedAuthenticationTest(Protocol protocol, bool statelessRP, bool sharedAssociation, bool positive, bool immediate, bool tamper) { + private async Task ParameterizedAuthenticationTestAsync(Protocol protocol, bool statelessRP, bool sharedAssociation, bool positive, bool immediate, bool tamper) { Requires.That(!statelessRP || !sharedAssociation, null, "The RP cannot be stateless while sharing an association with the OP."); Requires.That(positive || !tamper, null, "Cannot tamper with a negative response."); var securitySettings = new ProviderSecuritySettings(); var cryptoKeyStore = new MemoryCryptoKeyStore(); var associationStore = new ProviderAssociationHandleEncoder(cryptoKeyStore); Association association = sharedAssociation ? HmacShaAssociationProvider.Create(protocol, protocol.Args.SignatureAlgorithm.Best, AssociationRelyingPartyType.Smart, associationStore, securitySettings) : null; - var coordinator = new OpenIdCoordinator( - rp => { + int opStep = 0; + var coordinator = new CoordinatorBase( + CoordinatorBase.RelyingPartyDriver(async (rp, ct) => { + if (statelessRP) { + rp = new OpenIdRelyingParty(null, rp.Channel.HostFactories); + } + var request = new CheckIdRequest(protocol.Version, OPUri, immediate ? AuthenticationRequestMode.Immediate : AuthenticationRequestMode.Setup); if (association != null) { @@ -155,17 +195,25 @@ namespace DotNetOpenAuth.Test.OpenId { request.LocalIdentifier = "http://localid"; request.ReturnTo = RPUri; request.Realm = RPUri; - rp.Channel.Respond(request); + var redirectRequest = await rp.Channel.PrepareResponseAsync(request, ct); + Uri redirectResponse; + using (var httpClient = rp.Channel.HostFactories.CreateHttpClient()) { + using (var response = await httpClient.GetAsync(redirectRequest.Headers.Location)) { + redirectResponse = response.Headers.Location; + } + } + + var assertionMessage = new HttpRequestMessage(HttpMethod.Get, redirectResponse.AbsoluteUri); if (positive) { if (tamper) { try { - rp.Channel.ReadFromRequest<PositiveAssertionResponse>(); + await rp.Channel.ReadFromRequestAsync<PositiveAssertionResponse>(assertionMessage, ct); Assert.Fail("Expected exception {0} not thrown.", typeof(InvalidSignatureException).Name); } catch (InvalidSignatureException) { TestLogger.InfoFormat("Caught expected {0} exception after tampering with signed data.", typeof(InvalidSignatureException).Name); } } else { - var response = rp.Channel.ReadFromRequest<PositiveAssertionResponse>(); + var response = await rp.Channel.ReadFromRequestAsync<PositiveAssertionResponse>(assertionMessage, ct); Assert.IsNotNull(response); Assert.AreEqual(request.ClaimedIdentifier, response.ClaimedIdentifier); Assert.AreEqual(request.LocalIdentifier, response.LocalIdentifier); @@ -177,15 +225,16 @@ namespace DotNetOpenAuth.Test.OpenId { // When the OP notices the replay we get a generic InvalidSignatureException. // When the RP notices the replay we get a specific ReplayMessageException. try { - CoordinatingChannel channel = (CoordinatingChannel)rp.Channel; - await channel.ReplayAsync(response); + // TODO: fix this. + ////CoordinatingChannel channel = (CoordinatingChannel)rp.Channel; + ////await channel.ReplayAsync(response); Assert.Fail("Expected ProtocolException was not thrown."); } catch (ProtocolException ex) { Assert.IsTrue(ex is ReplayedMessageException || ex is InvalidSignatureException, "A {0} exception was thrown instead of the expected {1} or {2}.", ex.GetType(), typeof(ReplayedMessageException).Name, typeof(InvalidSignatureException).Name); } } } else { - var response = rp.Channel.ReadFromRequest<NegativeAssertionResponse>(); + var response = await rp.Channel.ReadFromRequestAsync<NegativeAssertionResponse>(assertionMessage, ct); Assert.IsNotNull(response); if (immediate) { // Only 1.1 was required to include user_setup_url @@ -196,54 +245,64 @@ namespace DotNetOpenAuth.Test.OpenId { Assert.IsNull(response.UserSetupUrl); } } - }, - op => { + }), + CoordinatorBase.HandleProvider(async (op, req, ct) => { if (association != null) { var key = cryptoKeyStore.GetCurrentKey(ProviderAssociationHandleEncoder.AssociationHandleEncodingSecretBucket, TimeSpan.FromSeconds(1)); op.CryptoKeyStore.StoreKey(ProviderAssociationHandleEncoder.AssociationHandleEncodingSecretBucket, key.Key, key.Value); } - var request = op.Channel.ReadFromRequest<CheckIdRequest>(); - Assert.IsNotNull(request); - IProtocolMessage response; - if (positive) { - response = new PositiveAssertionResponse(request); - } else { - response = new NegativeAssertionResponse(request, op.Channel); - } - op.Channel.Respond(response); - - if (positive && (statelessRP || !sharedAssociation)) { - var checkauthRequest = op.Channel.ReadFromRequest<CheckAuthenticationRequest>(); - var checkauthResponse = new CheckAuthenticationResponse(checkauthRequest.Version, checkauthRequest); - checkauthResponse.IsValid = checkauthRequest.IsValid; - op.Channel.Respond(checkauthResponse); - - if (!tamper) { - // Respond to the replay attack. - checkauthRequest = op.Channel.ReadFromRequest<CheckAuthenticationRequest>(); - checkauthResponse = new CheckAuthenticationResponse(checkauthRequest.Version, checkauthRequest); - checkauthResponse.IsValid = checkauthRequest.IsValid; - op.Channel.Respond(checkauthResponse); - } + switch (++opStep) { + case 1: + var request = await op.Channel.ReadFromRequestAsync<CheckIdRequest>(req, ct); + Assert.IsNotNull(request); + IProtocolMessage response; + if (positive) { + response = new PositiveAssertionResponse(request); + } else { + response = new NegativeAssertionResponse(request.Version, request.ReturnTo, request.Mode); + } + + return await op.Channel.PrepareResponseAsync(response, ct); + case 2: + if (positive && (statelessRP || !sharedAssociation)) { + var checkauthRequest = await op.Channel.ReadFromRequestAsync<CheckAuthenticationRequest>(req, ct); + var checkauthResponse = new CheckAuthenticationResponse(checkauthRequest.Version, checkauthRequest); + checkauthResponse.IsValid = checkauthRequest.IsValid; + return await op.Channel.PrepareResponseAsync(checkauthResponse, ct); + } + + throw Assumes.NotReachable(); + case 3: + if (positive && (statelessRP || !sharedAssociation)) { + if (!tamper) { + // Respond to the replay attack. + var checkauthRequest = await op.Channel.ReadFromRequestAsync<CheckAuthenticationRequest>(req, ct); + var checkauthResponse = new CheckAuthenticationResponse(checkauthRequest.Version, checkauthRequest); + checkauthResponse.IsValid = checkauthRequest.IsValid; + return await op.Channel.PrepareResponseAsync(checkauthResponse, ct); + } + } + + throw Assumes.NotReachable(); + default: + throw Assumes.NotReachable(); } - }); + })); if (tamper) { - coordinator.IncomingMessageFilter = message => { - var assertion = message as PositiveAssertionResponse; - if (assertion != null) { - // Alter the Local Identifier between the Provider and the Relying Party. - // If the signature binding element does its job, this should cause the RP - // to throw. - assertion.LocalIdentifier = "http://victim"; - } - }; - } - if (statelessRP) { - coordinator.RelyingParty = new OpenIdRelyingParty(null); + // TODO: fix this. + ////coordinator.IncomingMessageFilter = message => { + //// var assertion = message as PositiveAssertionResponse; + //// if (assertion != null) { + //// // Alter the Local Identifier between the Provider and the Relying Party. + //// // If the signature binding element does its job, this should cause the RP + //// // to throw. + //// assertion.LocalIdentifier = "http://victim"; + //// } + ////}; } - coordinator.Run(); + await coordinator.RunAsync(); } } } diff --git a/src/DotNetOpenAuth.Test/OpenId/ChannelElements/ExtensionsBindingElementTests.cs b/src/DotNetOpenAuth.Test/OpenId/ChannelElements/ExtensionsBindingElementTests.cs index dd47782..e60ac3c 100644 --- a/src/DotNetOpenAuth.Test/OpenId/ChannelElements/ExtensionsBindingElementTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/ChannelElements/ExtensionsBindingElementTests.cs @@ -8,12 +8,17 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { using System; using System.Collections.Generic; using System.Linq; + using System.Net.Http; using System.Text.RegularExpressions; + using System.Threading; + using System.Threading.Tasks; + using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId; using DotNetOpenAuth.OpenId.ChannelElements; using DotNetOpenAuth.OpenId.Extensions; using DotNetOpenAuth.OpenId.Messages; + using DotNetOpenAuth.OpenId.Provider; using DotNetOpenAuth.OpenId.RelyingParty; using DotNetOpenAuth.Test.Mocks; using DotNetOpenAuth.Test.OpenId.Extensions; @@ -53,23 +58,23 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { } [Test] - public void PrepareMessageForSendingNull() { - Assert.IsNull(this.rpElement.ProcessOutgoingMessage(null)); + public async Task PrepareMessageForSendingNull() { + Assert.IsNull(await this.rpElement.ProcessOutgoingMessageAsync(null, CancellationToken.None)); } /// <summary> /// Verifies that false is returned when a non-extendable message is sent. /// </summary> [Test] - public void PrepareMessageForSendingNonExtendableMessage() { + public async Task PrepareMessageForSendingNonExtendableMessage() { IProtocolMessage request = new AssociateDiffieHellmanRequest(Protocol.Default.Version, OpenIdTestBase.OPUri); - Assert.IsNull(this.rpElement.ProcessOutgoingMessage(request)); + Assert.IsNull(await this.rpElement.ProcessOutgoingMessageAsync(request, CancellationToken.None)); } [Test] - public void PrepareMessageForSending() { + public async Task PrepareMessageForSending() { this.request.Extensions.Add(new MockOpenIdExtension("part", "extra")); - Assert.IsNotNull(this.rpElement.ProcessOutgoingMessage(this.request)); + Assert.IsNotNull(await this.rpElement.ProcessOutgoingMessageAsync(this.request, CancellationToken.None)); string alias = GetAliases(this.request.ExtraData).Single(); Assert.AreEqual(MockOpenIdExtension.MockTypeUri, this.request.ExtraData["openid.ns." + alias]); @@ -78,11 +83,11 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { } [Test] - public void PrepareMessageForReceiving() { + public async Task PrepareMessageForReceiving() { this.request.ExtraData["openid.ns.mock"] = MockOpenIdExtension.MockTypeUri; this.request.ExtraData["openid.mock.Part"] = "part"; this.request.ExtraData["openid.mock.data"] = "extra"; - Assert.IsNotNull(this.rpElement.ProcessIncomingMessage(this.request)); + Assert.IsNotNull(await this.rpElement.ProcessIncomingMessageAsync(this.request, CancellationToken.None)); MockOpenIdExtension ext = this.request.Extensions.OfType<MockOpenIdExtension>().Single(); Assert.AreEqual("part", ext.Part); Assert.AreEqual("extra", ext.Data); @@ -92,12 +97,12 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { /// Verifies that extension responses are included in the OP's signature. /// </summary> [Test] - public void ExtensionResponsesAreSigned() { + public async Task ExtensionResponsesAreSigned() { Protocol protocol = Protocol.Default; var op = this.CreateProvider(); IndirectSignedResponse response = this.CreateResponseWithExtensions(protocol); - op.Channel.PrepareResponse(response); - ITamperResistantOpenIdMessage signedResponse = (ITamperResistantOpenIdMessage)response; + await op.Channel.PrepareResponseAsync(response); + ITamperResistantOpenIdMessage signedResponse = response; string extensionAliasKey = signedResponse.ExtraData.Single(kv => kv.Value == MockOpenIdExtension.MockTypeUri).Key; Assert.IsTrue(extensionAliasKey.StartsWith("openid.ns.")); string extensionAlias = extensionAliasKey.Substring("openid.ns.".Length); @@ -114,27 +119,56 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { /// Verifies that unsigned extension responses (where any or all fields are unsigned) are ignored. /// </summary> [Test] - public void ExtensionsAreIdentifiedAsSignedOrUnsigned() { + public async Task ExtensionsAreIdentifiedAsSignedOrUnsigned() { Protocol protocol = Protocol.Default; - OpenIdCoordinator coordinator = new OpenIdCoordinator( - rp => { - RegisterMockExtension(rp.Channel); - var response = rp.Channel.ReadFromRequest<IndirectSignedResponse>(); - Assert.AreEqual(1, response.SignedExtensions.Count(), "Signed extension should have been received."); - Assert.AreEqual(0, response.UnsignedExtensions.Count(), "No unsigned extension should be present."); - response = rp.Channel.ReadFromRequest<IndirectSignedResponse>(); - Assert.AreEqual(0, response.SignedExtensions.Count(), "No signed extension should have been received."); - Assert.AreEqual(1, response.UnsignedExtensions.Count(), "Unsigned extension should have been received."); - }, - op => { + var opStore = new StandardProviderApplicationStore(); + int rpStep = 0; + var coordinator = new CoordinatorBase( + async (hostFactories, ct) => { + var op = new OpenIdProvider(opStore); RegisterMockExtension(op.Channel); - op.Channel.Respond(CreateResponseWithExtensions(protocol)); - op.Respond(op.GetRequest()); // check_auth + var redirectingResponse = await op.Channel.PrepareResponseAsync(CreateResponseWithExtensions(protocol)); + using (var httpClient = hostFactories.CreateHttpClient()) { + using (var response = await httpClient.GetAsync(redirectingResponse.Headers.Location)) { + response.EnsureSuccessStatusCode(); + } + } + op.SecuritySettings.SignOutgoingExtensions = false; - op.Channel.Respond(CreateResponseWithExtensions(protocol)); - op.Respond(op.GetRequest()); // check_auth - }); - coordinator.Run(); + redirectingResponse = await op.Channel.PrepareResponseAsync(CreateResponseWithExtensions(protocol)); + using (var httpClient = hostFactories.CreateHttpClient()) { + using (var response = await httpClient.GetAsync(redirectingResponse.Headers.Location)) { + response.EnsureSuccessStatusCode(); + } + } + }, + CoordinatorBase.Handle(RPRealmUri).By(async (hostFactories, req, ct) => { + var rp = new OpenIdRelyingParty(new StandardRelyingPartyApplicationStore(), hostFactories); + RegisterMockExtension(rp.Channel); + + switch (++rpStep) { + case 1: + var response = await rp.Channel.ReadFromRequestAsync<IndirectSignedResponse>(req, ct); + Assert.AreEqual(1, response.SignedExtensions.Count(), "Signed extension should have been received."); + Assert.AreEqual(0, response.UnsignedExtensions.Count(), "No unsigned extension should be present."); + break; + case 2: + response = await rp.Channel.ReadFromRequestAsync<IndirectSignedResponse>(req, ct); + Assert.AreEqual(0, response.SignedExtensions.Count(), "No signed extension should have been received."); + Assert.AreEqual(1, response.UnsignedExtensions.Count(), "Unsigned extension should have been received."); + break; + + default: + throw Assumes.NotReachable(); + } + + return new HttpResponseMessage(); + }), + CoordinatorBase.Handle(OPUri).By(async (hostFactories, req, ct) => { + var op = new OpenIdProvider(opStore); + return await AutoProviderActionAsync(op, req, ct); + })); + await coordinator.RunAsync(); } /// <summary> @@ -181,7 +215,7 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { private IndirectSignedResponse CreateResponseWithExtensions(Protocol protocol) { Requires.NotNull(protocol, "protocol"); - IndirectSignedResponse response = new IndirectSignedResponse(protocol.Version, RPUri); + var response = new IndirectSignedResponse(protocol.Version, RPUri); response.ProviderEndpoint = OPUri; response.Extensions.Add(new MockOpenIdExtension("pv", "ev")); return response; diff --git a/src/DotNetOpenAuth.Test/OpenId/ChannelElements/KeyValueFormEncodingTests.cs b/src/DotNetOpenAuth.Test/OpenId/ChannelElements/KeyValueFormEncodingTests.cs index 93ad028..b64701d 100644 --- a/src/DotNetOpenAuth.Test/OpenId/ChannelElements/KeyValueFormEncodingTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/ChannelElements/KeyValueFormEncodingTests.cs @@ -11,6 +11,9 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { using System.Linq; using System.Net; using System.Text; + using System.Threading; + using System.Threading.Tasks; + using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Reflection; using DotNetOpenAuth.OpenId.ChannelElements; @@ -56,9 +59,9 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { Assert.AreEqual(this.sampleData.Count, count); } - public void KVDictTest(byte[] kvform, IDictionary<string, string> dict, TestMode mode) { + public async Task KVDictTestAsync(byte[] kvform, IDictionary<string, string> dict, TestMode mode) { if ((mode & TestMode.Decoder) == TestMode.Decoder) { - var d = this.keyValueForm.GetDictionary(new MemoryStream(kvform)); + var d = await this.keyValueForm.GetDictionaryAsync(new MemoryStream(kvform), CancellationToken.None); foreach (string key in dict.Keys) { Assert.AreEqual(d[key], dict[key], "Decoder fault: " + d[key] + " and " + dict[key] + " do not match."); } @@ -70,91 +73,91 @@ namespace DotNetOpenAuth.Test.OpenId.ChannelElements { } [Test] - public void EncodeDecode() { - this.KVDictTest(UTF8Encoding.UTF8.GetBytes(string.Empty), new Dictionary<string, string>(), TestMode.Both); + public async Task EncodeDecode() { + await this.KVDictTestAsync(UTF8Encoding.UTF8.GetBytes(string.Empty), new Dictionary<string, string>(), TestMode.Both); Dictionary<string, string> d1 = new Dictionary<string, string>(); d1.Add("college", "harvey mudd"); - this.KVDictTest(UTF8Encoding.UTF8.GetBytes("college:harvey mudd\n"), d1, TestMode.Both); + await this.KVDictTestAsync(UTF8Encoding.UTF8.GetBytes("college:harvey mudd\n"), d1, TestMode.Both); Dictionary<string, string> d2 = new Dictionary<string, string>(); d2.Add("city", "claremont"); d2.Add("state", "CA"); - this.KVDictTest(UTF8Encoding.UTF8.GetBytes("city:claremont\nstate:CA\n"), d2, TestMode.Both); + await this.KVDictTestAsync(UTF8Encoding.UTF8.GetBytes("city:claremont\nstate:CA\n"), d2, TestMode.Both); Dictionary<string, string> d3 = new Dictionary<string, string>(); d3.Add("is_valid", "true"); d3.Add("invalidate_handle", "{HMAC-SHA1:2398410938412093}"); - this.KVDictTest(UTF8Encoding.UTF8.GetBytes("is_valid:true\ninvalidate_handle:{HMAC-SHA1:2398410938412093}\n"), d3, TestMode.Both); + await this.KVDictTestAsync(UTF8Encoding.UTF8.GetBytes("is_valid:true\ninvalidate_handle:{HMAC-SHA1:2398410938412093}\n"), d3, TestMode.Both); Dictionary<string, string> d4 = new Dictionary<string, string>(); d4.Add(string.Empty, string.Empty); - this.KVDictTest(UTF8Encoding.UTF8.GetBytes(":\n"), d4, TestMode.Both); + await this.KVDictTestAsync(UTF8Encoding.UTF8.GetBytes(":\n"), d4, TestMode.Both); Dictionary<string, string> d5 = new Dictionary<string, string>(); d5.Add(string.Empty, "missingkey"); - this.KVDictTest(UTF8Encoding.UTF8.GetBytes(":missingkey\n"), d5, TestMode.Both); + await this.KVDictTestAsync(UTF8Encoding.UTF8.GetBytes(":missingkey\n"), d5, TestMode.Both); Dictionary<string, string> d6 = new Dictionary<string, string>(); d6.Add("street", "foothill blvd"); - this.KVDictTest(UTF8Encoding.UTF8.GetBytes("street:foothill blvd\n"), d6, TestMode.Both); + await this.KVDictTestAsync(UTF8Encoding.UTF8.GetBytes("street:foothill blvd\n"), d6, TestMode.Both); Dictionary<string, string> d7 = new Dictionary<string, string>(); d7.Add("major", "computer science"); - this.KVDictTest(UTF8Encoding.UTF8.GetBytes("major:computer science\n"), d7, TestMode.Both); + await this.KVDictTestAsync(UTF8Encoding.UTF8.GetBytes("major:computer science\n"), d7, TestMode.Both); Dictionary<string, string> d8 = new Dictionary<string, string>(); d8.Add("dorm", "east"); - this.KVDictTest(UTF8Encoding.UTF8.GetBytes(" dorm : east \n"), d8, TestMode.Decoder); + await this.KVDictTestAsync(UTF8Encoding.UTF8.GetBytes(" dorm : east \n"), d8, TestMode.Decoder); Dictionary<string, string> d9 = new Dictionary<string, string>(); d9.Add("e^(i*pi)+1", "0"); - this.KVDictTest(UTF8Encoding.UTF8.GetBytes("e^(i*pi)+1:0"), d9, TestMode.Decoder); + await this.KVDictTestAsync(UTF8Encoding.UTF8.GetBytes("e^(i*pi)+1:0"), d9, TestMode.Decoder); Dictionary<string, string> d10 = new Dictionary<string, string>(); d10.Add("east", "west"); d10.Add("north", "south"); - this.KVDictTest(UTF8Encoding.UTF8.GetBytes("east:west\nnorth:south"), d10, TestMode.Decoder); + await this.KVDictTestAsync(UTF8Encoding.UTF8.GetBytes("east:west\nnorth:south"), d10, TestMode.Decoder); } [Test, ExpectedException(typeof(FormatException))] - public void NoValue() { - this.Illegal("x\n", KeyValueFormConformanceLevel.OpenId11); + public async Task NoValue() { + await this.IllegalAsync("x\n", KeyValueFormConformanceLevel.OpenId11); } [Test, ExpectedException(typeof(FormatException))] - public void NoValueLoose() { + public async Task NoValueLoose() { Dictionary<string, string> d = new Dictionary<string, string>(); - this.KVDictTest(Encoding.UTF8.GetBytes("x\n"), d, TestMode.Decoder); + await this.KVDictTestAsync(Encoding.UTF8.GetBytes("x\n"), d, TestMode.Decoder); } [Test, ExpectedException(typeof(FormatException))] - public void EmptyLine() { - this.Illegal("x:b\n\n", KeyValueFormConformanceLevel.OpenId20); + public async Task EmptyLine() { + await this.IllegalAsync("x:b\n\n", KeyValueFormConformanceLevel.OpenId20); } [Test] - public void EmptyLineLoose() { + public async Task EmptyLineLoose() { Dictionary<string, string> d = new Dictionary<string, string>(); d.Add("x", "b"); - this.KVDictTest(Encoding.UTF8.GetBytes("x:b\n\n"), d, TestMode.Decoder); + await this.KVDictTestAsync(Encoding.UTF8.GetBytes("x:b\n\n"), d, TestMode.Decoder); } [Test, ExpectedException(typeof(FormatException))] - public void LastLineNotTerminated() { - this.Illegal("x:y\na:b", KeyValueFormConformanceLevel.OpenId11); + public async Task LastLineNotTerminated() { + await this.IllegalAsync("x:y\na:b", KeyValueFormConformanceLevel.OpenId11); } [Test] - public void LastLineNotTerminatedLoose() { + public async Task LastLineNotTerminatedLoose() { Dictionary<string, string> d = new Dictionary<string, string>(); d.Add("x", "y"); d.Add("a", "b"); - this.KVDictTest(Encoding.UTF8.GetBytes("x:y\na:b"), d, TestMode.Decoder); + await this.KVDictTestAsync(Encoding.UTF8.GetBytes("x:y\na:b"), d, TestMode.Decoder); } - private void Illegal(string s, KeyValueFormConformanceLevel level) { - new KeyValueFormEncoding(level).GetDictionary(new MemoryStream(Encoding.UTF8.GetBytes(s))); + private async Task IllegalAsync(string s, KeyValueFormConformanceLevel level) { + await new KeyValueFormEncoding(level).GetDictionaryAsync(new MemoryStream(Encoding.UTF8.GetBytes(s)), CancellationToken.None); } } } diff --git a/src/DotNetOpenAuth.Test/OpenId/NonIdentityTests.cs b/src/DotNetOpenAuth.Test/OpenId/NonIdentityTests.cs index 393239b..21b1d0b 100644 --- a/src/DotNetOpenAuth.Test/OpenId/NonIdentityTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/NonIdentityTests.cs @@ -5,6 +5,11 @@ //----------------------------------------------------------------------- namespace DotNetOpenAuth.Test.OpenId { + using System; + using System.Net.Http; + using System.Threading.Tasks; + + using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId; using DotNetOpenAuth.OpenId.Messages; using DotNetOpenAuth.OpenId.Provider; @@ -14,44 +19,62 @@ namespace DotNetOpenAuth.Test.OpenId { [TestFixture] public class NonIdentityTests : OpenIdTestBase { [Test] - public void ExtensionOnlyChannelLevel() { + public async Task ExtensionOnlyChannelLevel() { Protocol protocol = Protocol.V20; - AuthenticationRequestMode mode = AuthenticationRequestMode.Setup; + var mode = AuthenticationRequestMode.Setup; - var coordinator = new OpenIdCoordinator( - rp => { + await CoordinatorBase.RunAsync( + CoordinatorBase.RelyingPartyDriver(async (rp, ct) => { var request = new SignedResponseRequest(protocol.Version, OPUri, mode); - rp.Channel.Respond(request); - }, - op => { - var request = op.Channel.ReadFromRequest<SignedResponseRequest>(); + var authRequest = await rp.Channel.PrepareResponseAsync(request); + using (var httpClient = rp.Channel.HostFactories.CreateHttpClient()) { + using (var response = await httpClient.GetAsync(authRequest.Headers.Location, ct)) { + response.EnsureSuccessStatusCode(); + } + } + }), + CoordinatorBase.HandleProvider(async (op, req, ct) => { + var request = await op.Channel.ReadFromRequestAsync<SignedResponseRequest>(req, ct); Assert.IsNotInstanceOf<CheckIdRequest>(request); - }); - coordinator.Run(); + return new HttpResponseMessage(); + })); } [Test] - public void ExtensionOnlyFacadeLevel() { + public async Task ExtensionOnlyFacadeLevel() { Protocol protocol = Protocol.V20; - var coordinator = new OpenIdCoordinator( - rp => { - var request = rp.CreateRequest(GetMockIdentifier(protocol.ProtocolVersion), RPRealmUri, RPUri); + int opStep = 0; + await CoordinatorBase.RunAsync( + CoordinatorBase.RelyingPartyDriver(async (rp, ct) => { + var request = await rp.CreateRequestAsync(GetMockIdentifier(protocol.ProtocolVersion), RPRealmUri, RPUri, ct); request.IsExtensionOnly = true; - rp.Channel.Respond(request.RedirectingResponse.OriginalMessage); - IAuthenticationResponse response = rp.GetResponse(); - Assert.AreEqual(AuthenticationStatus.ExtensionsOnly, response.Status); - }, - op => { - var assocRequest = op.GetRequest(); - op.Respond(assocRequest); + var redirectRequest = await request.GetRedirectingResponseAsync(ct); + Uri redirectResponseUrl; + using (var httpClient = rp.Channel.HostFactories.CreateHttpClient()) { + using (var redirectResponse = await httpClient.GetAsync(redirectRequest.Headers.Location, ct)) { + redirectResponse.EnsureSuccessStatusCode(); + redirectResponseUrl = redirectRequest.Headers.Location; + } + } - var request = (IAnonymousRequest)op.GetRequest(); - request.IsApproved = true; - Assert.IsNotInstanceOf<CheckIdRequest>(request); - op.Respond(request); - }); - coordinator.Run(); + IAuthenticationResponse response = await rp.GetResponseAsync(new HttpRequestMessage(HttpMethod.Get, redirectResponseUrl)); + Assert.AreEqual(AuthenticationStatus.ExtensionsOnly, response.Status); + }), + CoordinatorBase.HandleProvider(async (op, req, ct) => { + switch (++opStep) { + case 1: + var assocRequest = await op.GetRequestAsync(req, ct); + return await op.PrepareResponseAsync(assocRequest, ct); + case 2: + var request = (IAnonymousRequest)await op.GetRequestAsync(req, ct); + request.IsApproved = true; + Assert.IsNotInstanceOf<CheckIdRequest>(request); + return await op.PrepareResponseAsync(request, ct); + default: + throw Assumes.NotReachable(); + } + })); } } } diff --git a/src/DotNetOpenAuth.Test/OpenId/OpenIdTestBase.cs b/src/DotNetOpenAuth.Test/OpenId/OpenIdTestBase.cs index 75217fd..c23f042 100644 --- a/src/DotNetOpenAuth.Test/OpenId/OpenIdTestBase.cs +++ b/src/DotNetOpenAuth.Test/OpenId/OpenIdTestBase.cs @@ -8,6 +8,7 @@ namespace DotNetOpenAuth.Test.OpenId { using System; using System.Collections.Generic; using System.IO; + using System.Net.Http; using System.Reflection; using System.Threading; using System.Threading.Tasks; @@ -149,36 +150,47 @@ namespace DotNetOpenAuth.Test.OpenId { return CoordinatorBase.Handle(OPUri).By( async (req, ct) => { var provider = new OpenIdProvider(new StandardProviderApplicationStore()); - IRequest request = await provider.GetRequestAsync(req, ct); - Assert.That(request, Is.Not.Null); - - if (!request.IsResponseReady) { - var authRequest = (DotNetOpenAuth.OpenId.Provider.IAuthenticationRequest)request; - switch (this.AutoProviderScenario) { - case Scenarios.AutoApproval: - authRequest.IsAuthenticated = true; - break; - case Scenarios.AutoApprovalAddFragment: - authRequest.SetClaimedIdentifierFragment("frag"); - authRequest.IsAuthenticated = true; - break; - case Scenarios.ApproveOnSetup: - authRequest.IsAuthenticated = !authRequest.Immediate; - break; - case Scenarios.AlwaysDeny: - authRequest.IsAuthenticated = false; - break; - default: - // All other scenarios are done programmatically only. - throw new InvalidOperationException("Unrecognized scenario"); - } - } - - return await provider.PrepareResponseAsync(request, ct); + return await this.AutoProviderActionAsync(provider, req, ct); }); } } + /// <summary> + /// Gets a default implementation of a simple provider that responds to authentication requests + /// per the scenario that is being simulated. + /// </summary> + /// <remarks> + /// This is a very useful method to pass to the OpenIdCoordinator constructor for the Provider argument. + /// </remarks> + internal async Task<HttpResponseMessage> AutoProviderActionAsync(OpenIdProvider provider, HttpRequestMessage req, CancellationToken ct) { + IRequest request = await provider.GetRequestAsync(req, ct); + Assert.That(request, Is.Not.Null); + + if (!request.IsResponseReady) { + var authRequest = (DotNetOpenAuth.OpenId.Provider.IAuthenticationRequest)request; + switch (this.AutoProviderScenario) { + case Scenarios.AutoApproval: + authRequest.IsAuthenticated = true; + break; + case Scenarios.AutoApprovalAddFragment: + authRequest.SetClaimedIdentifierFragment("frag"); + authRequest.IsAuthenticated = true; + break; + case Scenarios.ApproveOnSetup: + authRequest.IsAuthenticated = !authRequest.Immediate; + break; + case Scenarios.AlwaysDeny: + authRequest.IsAuthenticated = false; + break; + default: + // All other scenarios are done programmatically only. + throw new InvalidOperationException("Unrecognized scenario"); + } + } + + return await provider.PrepareResponseAsync(request, ct); + } + internal IEnumerable<IdentifierDiscoveryResult> Discover(Identifier identifier) { var rp = this.CreateRelyingParty(true); rp.Channel.WebRequestHandler = this.RequestHandler; diff --git a/src/DotNetOpenAuth.Test/OpenId/Provider/AnonymousRequestTests.cs b/src/DotNetOpenAuth.Test/OpenId/Provider/AnonymousRequestTests.cs index 7310eb3..ad619f7 100644 --- a/src/DotNetOpenAuth.Test/OpenId/Provider/AnonymousRequestTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/Provider/AnonymousRequestTests.cs @@ -8,6 +8,9 @@ namespace DotNetOpenAuth.Test.OpenId.Provider { using System.IO; using System.Runtime.Serialization; using System.Runtime.Serialization.Formatters.Binary; + using System.Threading; + using System.Threading.Tasks; + using DotNetOpenAuth.OpenId; using DotNetOpenAuth.OpenId.Messages; using DotNetOpenAuth.OpenId.Provider; @@ -20,7 +23,7 @@ namespace DotNetOpenAuth.Test.OpenId.Provider { /// Verifies that IsApproved controls which response message is returned. /// </summary> [Test] - public void IsApprovedDeterminesReturnedMessage() { + public async Task IsApprovedDeterminesReturnedMessage() { var op = CreateProvider(); Protocol protocol = Protocol.V20; var req = new SignedResponseRequest(protocol.Version, OPUri, AuthenticationRequestMode.Setup); @@ -30,11 +33,11 @@ namespace DotNetOpenAuth.Test.OpenId.Provider { Assert.IsFalse(anonReq.IsApproved.HasValue); anonReq.IsApproved = false; - Assert.IsInstanceOf<NegativeAssertionResponse>(anonReq.Response); + Assert.IsInstanceOf<NegativeAssertionResponse>(await anonReq.GetResponseAsync(CancellationToken.None)); anonReq.IsApproved = true; - Assert.IsInstanceOf<IndirectSignedResponse>(anonReq.Response); - Assert.IsNotInstanceOf<PositiveAssertionResponse>(anonReq.Response); + Assert.IsInstanceOf<IndirectSignedResponse>(await anonReq.GetResponseAsync(CancellationToken.None)); + Assert.IsNotInstanceOf<PositiveAssertionResponse>(await anonReq.GetResponseAsync(CancellationToken.None)); } /// <summary> diff --git a/src/DotNetOpenAuth.Test/OpenId/Provider/AuthenticationRequestTest.cs b/src/DotNetOpenAuth.Test/OpenId/Provider/AuthenticationRequestTest.cs index baf5377..e9c5465 100644 --- a/src/DotNetOpenAuth.Test/OpenId/Provider/AuthenticationRequestTest.cs +++ b/src/DotNetOpenAuth.Test/OpenId/Provider/AuthenticationRequestTest.cs @@ -7,8 +7,12 @@ namespace DotNetOpenAuth.Test.OpenId.Provider { using System; using System.IO; + using System.Net.Http; using System.Runtime.Serialization; using System.Runtime.Serialization.Formatters.Binary; + using System.Threading; + using System.Threading.Tasks; + using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId; using DotNetOpenAuth.OpenId.Messages; @@ -21,24 +25,24 @@ namespace DotNetOpenAuth.Test.OpenId.Provider { /// Verifies the user_setup_url is set properly for immediate negative responses. /// </summary> [Test] - public void UserSetupUrl() { + public async Task UserSetupUrl() { // Construct a V1 immediate request Protocol protocol = Protocol.V11; OpenIdProvider provider = this.CreateProvider(); - CheckIdRequest immediateRequest = new CheckIdRequest(protocol.Version, OPUri, DotNetOpenAuth.OpenId.AuthenticationRequestMode.Immediate); + var immediateRequest = new CheckIdRequest(protocol.Version, OPUri, DotNetOpenAuth.OpenId.AuthenticationRequestMode.Immediate); immediateRequest.Realm = RPRealmUri; immediateRequest.ReturnTo = RPUri; immediateRequest.LocalIdentifier = "http://somebody"; - AuthenticationRequest request = new AuthenticationRequest(provider, immediateRequest); + var request = new AuthenticationRequest(provider, immediateRequest); // Now simulate the request being rejected and extract the user_setup_url request.IsAuthenticated = false; - Uri userSetupUrl = ((NegativeAssertionResponse)request.Response).UserSetupUrl; + Uri userSetupUrl = ((NegativeAssertionResponse)await request.GetResponseAsync(CancellationToken.None)).UserSetupUrl; Assert.IsNotNull(userSetupUrl); // Now construct a new request as if it had just come in. - HttpRequestInfo httpRequest = new HttpRequestInfo("GET", userSetupUrl); - var setupRequest = (AuthenticationRequest)provider.GetRequest(httpRequest); + var httpRequest = new HttpRequestMessage(HttpMethod.Get, userSetupUrl); + var setupRequest = (AuthenticationRequest)await provider.GetRequestAsync(httpRequest); var setupRequestMessage = (CheckIdRequest)setupRequest.RequestMessage; // And make sure all the right properties are set. diff --git a/src/DotNetOpenAuth.Test/OpenId/Provider/HostProcessedRequestTests.cs b/src/DotNetOpenAuth.Test/OpenId/Provider/HostProcessedRequestTests.cs index 2e3e7ec..966e712 100644 --- a/src/DotNetOpenAuth.Test/OpenId/Provider/HostProcessedRequestTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/Provider/HostProcessedRequestTests.cs @@ -6,9 +6,13 @@ namespace DotNetOpenAuth.Test.OpenId.Provider { using System; + using System.Threading; + using System.Threading.Tasks; using DotNetOpenAuth.OpenId; using DotNetOpenAuth.OpenId.Messages; using DotNetOpenAuth.OpenId.Provider; + using DotNetOpenAuth.Test.Mocks; + using NUnit.Framework; [TestFixture] @@ -24,22 +28,25 @@ namespace DotNetOpenAuth.Test.OpenId.Provider { this.protocol = Protocol.Default; this.provider = this.CreateProvider(); - this.checkIdRequest = new CheckIdRequest(this.protocol.Version, OPUri, DotNetOpenAuth.OpenId.AuthenticationRequestMode.Setup); + this.checkIdRequest = new CheckIdRequest(this.protocol.Version, OPUri, AuthenticationRequestMode.Setup); this.checkIdRequest.Realm = RPRealmUri; this.checkIdRequest.ReturnTo = RPUri; this.request = new AuthenticationRequest(this.provider, this.checkIdRequest); } [Test] - public void IsReturnUrlDiscoverableNoResponse() { - Assert.AreEqual(RelyingPartyDiscoveryResult.NoServiceDocument, this.request.IsReturnUrlDiscoverable(this.provider.Channel.WebRequestHandler)); + public async Task IsReturnUrlDiscoverableNoResponse() { + Assert.AreEqual(RelyingPartyDiscoveryResult.NoServiceDocument, await this.request.IsReturnUrlDiscoverableAsync(this.provider.Channel.HostFactories, CancellationToken.None)); } [Test] - public void IsReturnUrlDiscoverableValidResponse() { - this.MockResponder.RegisterMockRPDiscovery(); - this.request = new AuthenticationRequest(this.provider, this.checkIdRequest); - Assert.AreEqual(RelyingPartyDiscoveryResult.Success, this.request.IsReturnUrlDiscoverable(this.provider.Channel.WebRequestHandler)); + public async Task IsReturnUrlDiscoverableValidResponse() { + await CoordinatorBase.RunAsync( + async (hostFactories, ct) => { + this.request = new AuthenticationRequest(this.provider, this.checkIdRequest); + Assert.AreEqual(RelyingPartyDiscoveryResult.Success, await this.request.IsReturnUrlDiscoverableAsync(this.provider.Channel.HostFactories, CancellationToken.None)); + }, + MockHttpRequest.RegisterMockRPDiscovery(false)); } /// <summary> @@ -47,39 +54,48 @@ namespace DotNetOpenAuth.Test.OpenId.Provider { /// is set, that discovery fails. /// </summary> [Test] - public void IsReturnUrlDiscoverableNotSsl() { - this.provider.SecuritySettings.RequireSsl = true; - this.MockResponder.RegisterMockRPDiscovery(); - Assert.AreEqual(RelyingPartyDiscoveryResult.NoServiceDocument, this.request.IsReturnUrlDiscoverable(this.provider.Channel.WebRequestHandler)); + public async Task IsReturnUrlDiscoverableNotSsl() { + await CoordinatorBase.RunAsync( + async (hostFactories, ct) => { + this.provider.SecuritySettings.RequireSsl = true; + Assert.AreEqual(RelyingPartyDiscoveryResult.NoServiceDocument, await this.request.IsReturnUrlDiscoverableAsync(this.provider.Channel.HostFactories, CancellationToken.None)); + }, + MockHttpRequest.RegisterMockRPDiscovery(false)); } /// <summary> /// Verifies that when discovery would be performed over HTTPS that discovery succeeds. /// </summary> [Test] - public void IsReturnUrlDiscoverableRequireSsl() { - this.MockResponder.RegisterMockRPDiscovery(); - this.checkIdRequest.Realm = RPRealmUriSsl; - this.checkIdRequest.ReturnTo = RPUriSsl; + public async Task IsReturnUrlDiscoverableRequireSsl() { + await CoordinatorBase.RunAsync( + async (hostFactories, ct) => { + this.checkIdRequest.Realm = RPRealmUriSsl; + this.checkIdRequest.ReturnTo = RPUriSsl; - // Try once with RequireSsl - this.provider.SecuritySettings.RequireSsl = true; - this.request = new AuthenticationRequest(this.provider, this.checkIdRequest); - Assert.AreEqual(RelyingPartyDiscoveryResult.Success, this.request.IsReturnUrlDiscoverable(this.provider.Channel.WebRequestHandler)); + // Try once with RequireSsl + this.provider.SecuritySettings.RequireSsl = true; + this.request = new AuthenticationRequest(this.provider, this.checkIdRequest); + Assert.AreEqual(RelyingPartyDiscoveryResult.Success, await this.request.IsReturnUrlDiscoverableAsync(this.provider.Channel.HostFactories, CancellationToken.None)); - // And again without RequireSsl - this.provider.SecuritySettings.RequireSsl = false; - this.request = new AuthenticationRequest(this.provider, this.checkIdRequest); - Assert.AreEqual(RelyingPartyDiscoveryResult.Success, this.request.IsReturnUrlDiscoverable(this.provider.Channel.WebRequestHandler)); + // And again without RequireSsl + this.provider.SecuritySettings.RequireSsl = false; + this.request = new AuthenticationRequest(this.provider, this.checkIdRequest); + Assert.AreEqual(RelyingPartyDiscoveryResult.Success, await this.request.IsReturnUrlDiscoverableAsync(this.provider.Channel.HostFactories, CancellationToken.None)); + }, + MockHttpRequest.RegisterMockRPDiscovery(false)); } [Test] - public void IsReturnUrlDiscoverableValidButNoMatch() { - this.MockResponder.RegisterMockRPDiscovery(); - this.provider.SecuritySettings.RequireSsl = false; // reset for another failure test case - this.checkIdRequest.ReturnTo = new Uri("http://somerandom/host"); - this.request = new AuthenticationRequest(this.provider, this.checkIdRequest); - Assert.AreEqual(RelyingPartyDiscoveryResult.NoMatchingReturnTo, this.request.IsReturnUrlDiscoverable(this.provider.Channel.WebRequestHandler)); + public async Task IsReturnUrlDiscoverableValidButNoMatch() { + await CoordinatorBase.RunAsync( + async (hostFactories, ct) => { + this.provider.SecuritySettings.RequireSsl = false; // reset for another failure test case + this.checkIdRequest.ReturnTo = new Uri("http://somerandom/host"); + this.request = new AuthenticationRequest(this.provider, this.checkIdRequest); + Assert.AreEqual(RelyingPartyDiscoveryResult.NoMatchingReturnTo, await this.request.IsReturnUrlDiscoverableAsync(this.provider.Channel.HostFactories, CancellationToken.None)); + }, + MockHttpRequest.RegisterMockRPDiscovery(false)); } } } diff --git a/src/DotNetOpenAuth.Test/OpenId/Provider/OpenIdProviderTests.cs b/src/DotNetOpenAuth.Test/OpenId/Provider/OpenIdProviderTests.cs index c15f5b8..984337c 100644 --- a/src/DotNetOpenAuth.Test/OpenId/Provider/OpenIdProviderTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/Provider/OpenIdProviderTests.cs @@ -7,6 +7,9 @@ namespace DotNetOpenAuth.Test.OpenId.Provider { using System; using System.IO; + using System.Net.Http; + using System.Threading; + using System.Threading.Tasks; using System.Web; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId; @@ -74,60 +77,61 @@ namespace DotNetOpenAuth.Test.OpenId.Provider { /// Verifies the GetRequest method throws outside an HttpContext. /// </summary> [Test, ExpectedException(typeof(InvalidOperationException))] - public void GetRequestNoContext() { + public async Task GetRequestNoContext() { HttpContext.Current = null; - this.provider.GetRequest(); + await this.provider.GetRequestAsync(); } /// <summary> /// Verifies GetRequest throws on null input. /// </summary> [Test, ExpectedException(typeof(ArgumentNullException))] - public void GetRequestNull() { - this.provider.GetRequest(null); + public async Task GetRequestNull() { + await this.provider.GetRequestAsync((HttpRequestMessage)null); } /// <summary> /// Verifies that GetRequest correctly returns the right messages. /// </summary> [Test] - public void GetRequest() { - var httpInfo = new HttpRequestInfo("GET", new Uri("http://someUri")); - Assert.IsNull(this.provider.GetRequest(httpInfo), "An irrelevant request should return null."); + public async Task GetRequest() { + var httpInfo = new HttpRequestMessage(HttpMethod.Get, "http://someUri"); + Assert.IsNull(await this.provider.GetRequestAsync(httpInfo), "An irrelevant request should return null."); var providerDescription = new ProviderEndpointDescription(OPUri, Protocol.Default.Version); // Test some non-empty request scenario. - OpenIdCoordinator coordinator = new OpenIdCoordinator( - rp => { - rp.Channel.Request(AssociateRequestRelyingParty.Create(rp.SecuritySettings, providerDescription)); - }, - op => { - IRequest request = op.GetRequest(); + var coordinator = new CoordinatorBase( + CoordinatorBase.RelyingPartyDriver(async (rp, ct) => { + await rp.Channel.RequestAsync(AssociateRequestRelyingParty.Create(rp.SecuritySettings, providerDescription), ct); + }), + CoordinatorBase.HandleProvider(async (op, req, ct) => { + IRequest request = await op.GetRequestAsync(req); Assert.IsInstanceOf<AutoResponsiveRequest>(request); - op.Respond(request); - }); - coordinator.Run(); + return await op.PrepareResponseAsync(request, ct); + })); + await coordinator.RunAsync(); } [Test] - public void BadRequestsGenerateValidErrorResponses() { - var coordinator = new OpenIdCoordinator( - rp => { - var nonOpenIdMessage = new Mocks.TestDirectedMessage(); - nonOpenIdMessage.Recipient = OPUri; - nonOpenIdMessage.HttpMethods = HttpDeliveryMethods.PostRequest; + public async Task BadRequestsGenerateValidErrorResponses() { + var coordinator = new CoordinatorBase( + CoordinatorBase.RelyingPartyDriver(async (rp, ct) => { + var nonOpenIdMessage = new Mocks.TestDirectedMessage { + Recipient = OPUri, + HttpMethods = HttpDeliveryMethods.PostRequest + }; MessagingTestBase.GetStandardTestMessage(MessagingTestBase.FieldFill.AllRequired, nonOpenIdMessage); - var response = rp.Channel.Request<DirectErrorResponse>(nonOpenIdMessage); + var response = await rp.Channel.RequestAsync<DirectErrorResponse>(nonOpenIdMessage, ct); Assert.IsNotNull(response.ErrorMessage); Assert.AreEqual(Protocol.Default.Version, response.Version); - }, + }), AutoProvider); - coordinator.Run(); + await coordinator.RunAsync(); } [Test, Category("HostASPNET")] - public void BadRequestsGenerateValidErrorResponsesHosted() { + public async Task BadRequestsGenerateValidErrorResponsesHosted() { try { using (AspNetHost host = AspNetHost.CreateHost(TestWebDirectory)) { Uri opEndpoint = new Uri(host.BaseUri, "/OpenIdProviderEndpoint.ashx"); @@ -136,7 +140,7 @@ namespace DotNetOpenAuth.Test.OpenId.Provider { nonOpenIdMessage.Recipient = opEndpoint; nonOpenIdMessage.HttpMethods = HttpDeliveryMethods.PostRequest; MessagingTestBase.GetStandardTestMessage(MessagingTestBase.FieldFill.AllRequired, nonOpenIdMessage); - var response = rp.Channel.Request<DirectErrorResponse>(nonOpenIdMessage); + var response = await rp.Channel.RequestAsync<DirectErrorResponse>(nonOpenIdMessage, CancellationToken.None); Assert.IsNotNull(response.ErrorMessage); } } catch (FileNotFoundException ex) { diff --git a/src/DotNetOpenAuth.Test/OpenId/RelyingParty/AuthenticationRequestTests.cs b/src/DotNetOpenAuth.Test/OpenId/RelyingParty/AuthenticationRequestTests.cs index cd72fdb..f8eef61 100644 --- a/src/DotNetOpenAuth.Test/OpenId/RelyingParty/AuthenticationRequestTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/RelyingParty/AuthenticationRequestTests.cs @@ -10,6 +10,8 @@ namespace DotNetOpenAuth.Test.OpenId.RelyingParty { using System.Collections.Specialized; using System.Linq; using System.Text; + using System.Threading; + using System.Threading.Tasks; using System.Web; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId; @@ -70,24 +72,24 @@ namespace DotNetOpenAuth.Test.OpenId.RelyingParty { /// Verifies RedirectingResponse. /// </summary> [Test] - public void CreateRequestMessage() { - OpenIdCoordinator coordinator = new OpenIdCoordinator( - rp => { + public async Task CreateRequestMessage() { + var coordinator = new CoordinatorBase( + CoordinatorBase.RelyingPartyDriver(async (rp, ct) => { Identifier id = this.GetMockIdentifier(ProtocolVersion.V20); - IAuthenticationRequest authRequest = rp.CreateRequest(id, this.realm, this.returnTo); + IAuthenticationRequest authRequest = await rp.CreateRequestAsync(id, this.realm, this.returnTo); // Add some callback arguments authRequest.AddCallbackArguments("a", "b"); authRequest.AddCallbackArguments(new Dictionary<string, string> { { "c", "d" }, { "e", "f" } }); // Assembly an extension request. - ClaimsRequest sregRequest = new ClaimsRequest(); + var sregRequest = new ClaimsRequest(); sregRequest.Nickname = DemandLevel.Request; authRequest.AddExtension(sregRequest); // Construct the actual authentication request message. var authRequestAccessor = (AuthenticationRequest)authRequest; - var req = authRequestAccessor.CreateRequestMessageTestHook(); + var req = await authRequestAccessor.CreateRequestMessageTestHookAsync(ct); Assert.IsNotNull(req); // Verify that callback arguments were included. @@ -99,25 +101,25 @@ namespace DotNetOpenAuth.Test.OpenId.RelyingParty { // Verify that extensions were included. Assert.AreEqual(1, req.Extensions.Count); Assert.IsTrue(req.Extensions.Contains(sregRequest)); - }, + }), AutoProvider); - coordinator.Run(); + await coordinator.RunAsync(); } /// <summary> /// Verifies that delegating authentication requests are filtered out when configured to do so. /// </summary> [Test] - public void CreateFiltersDelegatingIdentifiers() { + public async Task CreateFiltersDelegatingIdentifiers() { Identifier id = GetMockIdentifier(ProtocolVersion.V20, false, true); var rp = CreateRelyingParty(); // First verify that delegating identifiers work - Assert.IsTrue(AuthenticationRequest.Create(id, rp, this.realm, this.returnTo, false).Any(), "The delegating identifier should have not generated any results."); + Assert.IsTrue((await AuthenticationRequest.CreateAsync(id, rp, this.realm, this.returnTo, false, CancellationToken.None)).Any(), "The delegating identifier should have not generated any results."); // Now disable them and try again. rp.SecuritySettings.RejectDelegatingIdentifiers = true; - Assert.IsFalse(AuthenticationRequest.Create(id, rp, this.realm, this.returnTo, false).Any(), "The delegating identifier should have not generated any results."); + Assert.IsFalse((await AuthenticationRequest.CreateAsync(id, rp, this.realm, this.returnTo, false, CancellationToken.None)).Any(), "The delegating identifier should have not generated any results."); } /// <summary> @@ -135,11 +137,12 @@ namespace DotNetOpenAuth.Test.OpenId.RelyingParty { /// Verifies that AddCallbackArguments adds query arguments to the return_to URL of the message. /// </summary> [Test] - public void AddCallbackArgument() { + public async Task AddCallbackArgument() { var authRequest = this.CreateAuthenticationRequest(this.claimedId, this.claimedId); Assert.AreEqual(this.returnTo, authRequest.ReturnToUrl); authRequest.AddCallbackArguments("p1", "v1"); - var req = (SignedResponseRequest)authRequest.RedirectingResponse.OriginalMessage; + var response = (HttpResponseMessageWithOriginal)await authRequest.GetRedirectingResponseAsync(CancellationToken.None); + var req = (SignedResponseRequest)response.OriginalMessage; NameValueCollection query = HttpUtility.ParseQueryString(req.ReturnTo.Query); Assert.AreEqual("v1", query["p1"]); } @@ -149,13 +152,14 @@ namespace DotNetOpenAuth.Test.OpenId.RelyingParty { /// rather than appending them. /// </summary> [Test] - public void AddCallbackArgumentClearsPreviousArgument() { + public async Task AddCallbackArgumentClearsPreviousArgument() { UriBuilder returnToWithArgs = new UriBuilder(this.returnTo); returnToWithArgs.AppendQueryArgs(new Dictionary<string, string> { { "p1", "v1" } }); this.returnTo = returnToWithArgs.Uri; var authRequest = this.CreateAuthenticationRequest(this.claimedId, this.claimedId); authRequest.AddCallbackArguments("p1", "v2"); - var req = (SignedResponseRequest)authRequest.RedirectingResponse.OriginalMessage; + var response = (HttpResponseMessageWithOriginal)await authRequest.GetRedirectingResponseAsync(CancellationToken.None); + var req = (SignedResponseRequest)response.OriginalMessage; NameValueCollection query = HttpUtility.ParseQueryString(req.ReturnTo.Query); Assert.AreEqual("v2", query["p1"]); } @@ -164,11 +168,12 @@ namespace DotNetOpenAuth.Test.OpenId.RelyingParty { /// Verifies identity-less checkid_* request behavior. /// </summary> [Test] - public void NonIdentityRequest() { + public async Task NonIdentityRequest() { var authRequest = this.CreateAuthenticationRequest(this.claimedId, this.claimedId); authRequest.IsExtensionOnly = true; Assert.IsTrue(authRequest.IsExtensionOnly); - var req = (SignedResponseRequest)authRequest.RedirectingResponse.OriginalMessage; + var response = (HttpResponseMessageWithOriginal)await authRequest.GetRedirectingResponseAsync(CancellationToken.None); + var req = (SignedResponseRequest)response.OriginalMessage; Assert.IsNotInstanceOf<CheckIdRequest>(req, "An unexpected SignedResponseRequest derived type was generated."); } @@ -177,15 +182,15 @@ namespace DotNetOpenAuth.Test.OpenId.RelyingParty { /// only generate OP Identifier auth requests. /// </summary> [Test] - public void DualIdentifierUsedOnlyAsOPIdentifierForAuthRequest() { + public async Task DualIdentifierUsedOnlyAsOPIdentifierForAuthRequest() { var rp = this.CreateRelyingParty(true); - var results = AuthenticationRequest.Create(GetMockDualIdentifier(), rp, this.realm, this.returnTo, false).ToList(); + var results = (await AuthenticationRequest.CreateAsync(GetMockDualIdentifier(), rp, this.realm, this.returnTo, false, CancellationToken.None)).ToList(); Assert.AreEqual(1, results.Count); Assert.IsTrue(results[0].IsDirectedIdentity); // Also test when dual identiifer support is turned on. rp.SecuritySettings.AllowDualPurposeIdentifiers = true; - results = AuthenticationRequest.Create(GetMockDualIdentifier(), rp, this.realm, this.returnTo, false).ToList(); + results = (await AuthenticationRequest.CreateAsync(GetMockDualIdentifier(), rp, this.realm, this.returnTo, false, CancellationToken.None)).ToList(); Assert.AreEqual(1, results.Count); Assert.IsTrue(results[0].IsDirectedIdentity); } diff --git a/src/DotNetOpenAuth.Test/OpenId/RelyingParty/OpenIdRelyingPartyTests.cs b/src/DotNetOpenAuth.Test/OpenId/RelyingParty/OpenIdRelyingPartyTests.cs index a2a4efa..986a64e 100644 --- a/src/DotNetOpenAuth.Test/OpenId/RelyingParty/OpenIdRelyingPartyTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/RelyingParty/OpenIdRelyingPartyTests.cs @@ -7,11 +7,18 @@ namespace DotNetOpenAuth.Test.OpenId.RelyingParty { using System; using System.Linq; + using System.Net.Http; + using System.Threading.Tasks; + using System.Web; + using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId; using DotNetOpenAuth.OpenId.Extensions; using DotNetOpenAuth.OpenId.Messages; + using DotNetOpenAuth.OpenId.Provider; using DotNetOpenAuth.OpenId.RelyingParty; + using DotNetOpenAuth.Test.Mocks; + using NUnit.Framework; [TestFixture] @@ -22,12 +29,13 @@ namespace DotNetOpenAuth.Test.OpenId.RelyingParty { } [Test] - public void CreateRequestDumbMode() { + public async Task CreateRequestDumbMode() { var rp = this.CreateRelyingParty(true); Identifier id = this.GetMockIdentifier(ProtocolVersion.V20); - var authReq = rp.CreateRequest(id, RPRealmUri, RPUri); - CheckIdRequest requestMessage = (CheckIdRequest)authReq.RedirectingResponse.OriginalMessage; - Assert.IsNull(requestMessage.AssociationHandle); + var authReq = await rp.CreateRequestAsync(id, RPRealmUri, RPUri); + var httpMessage = await authReq.GetRedirectingResponseAsync(); + var data = HttpUtility.ParseQueryString(httpMessage.GetDirectUriRequest().Query); + Assert.IsNull(data[Protocol.Default.openid.assoc_handle]); } [Test, ExpectedException(typeof(ArgumentNullException))] @@ -46,53 +54,61 @@ namespace DotNetOpenAuth.Test.OpenId.RelyingParty { } [Test] - public void CreateRequest() { + public async Task CreateRequest() { var rp = this.CreateRelyingParty(); StoreAssociation(rp, OPUri, HmacShaAssociation.Create("somehandle", new byte[20], TimeSpan.FromDays(1))); Identifier id = Identifier.Parse(GetMockIdentifier(ProtocolVersion.V20)); - var req = rp.CreateRequest(id, RPRealmUri, RPUri); + var req = await rp.CreateRequestAsync(id, RPRealmUri, RPUri); Assert.IsNotNull(req); } [Test] - public void CreateRequests() { + public async Task CreateRequests() { var rp = this.CreateRelyingParty(); StoreAssociation(rp, OPUri, HmacShaAssociation.Create("somehandle", new byte[20], TimeSpan.FromDays(1))); Identifier id = Identifier.Parse(GetMockIdentifier(ProtocolVersion.V20)); - var requests = rp.CreateRequests(id, RPRealmUri, RPUri); + var requests = await rp.CreateRequestsAsync(id, RPRealmUri, RPUri); Assert.AreEqual(1, requests.Count()); } [Test] - public void CreateRequestsWithEndpointFilter() { + public async Task CreateRequestsWithEndpointFilter() { var rp = this.CreateRelyingParty(); StoreAssociation(rp, OPUri, HmacShaAssociation.Create("somehandle", new byte[20], TimeSpan.FromDays(1))); Identifier id = Identifier.Parse(GetMockIdentifier(ProtocolVersion.V20)); rp.EndpointFilter = opendpoint => true; - var requests = rp.CreateRequests(id, RPRealmUri, RPUri); + var requests = await rp.CreateRequestsAsync(id, RPRealmUri, RPUri); Assert.AreEqual(1, requests.Count()); rp.EndpointFilter = opendpoint => false; - requests = rp.CreateRequests(id, RPRealmUri, RPUri); + requests = await rp.CreateRequestsAsync(id, RPRealmUri, RPUri); Assert.AreEqual(0, requests.Count()); } [Test, ExpectedException(typeof(ProtocolException))] - public void CreateRequestOnNonOpenID() { - Uri nonOpenId = new Uri("http://www.microsoft.com/"); - var rp = this.CreateRelyingParty(); - this.MockResponder.RegisterMockResponse(nonOpenId, "text/html", "<html/>"); - rp.CreateRequest(nonOpenId, RPRealmUri, RPUri); + public async Task CreateRequestOnNonOpenID() { + var nonOpenId = new Uri("http://www.microsoft.com/"); + var coordinator = new CoordinatorBase( + CoordinatorBase.RelyingPartyDriver( + async (rp, ct) => { + await rp.CreateRequestAsync(nonOpenId, RPRealmUri, RPUri); + }), + CoordinatorBase.Handle(nonOpenId).By("<html/>", "text/html")); + await coordinator.RunAsync(); } [Test] - public void CreateRequestsOnNonOpenID() { - Uri nonOpenId = new Uri("http://www.microsoft.com/"); - var rp = this.CreateRelyingParty(); - this.MockResponder.RegisterMockResponse(nonOpenId, "text/html", "<html/>"); - var requests = rp.CreateRequests(nonOpenId, RPRealmUri, RPUri); - Assert.AreEqual(0, requests.Count()); + public async Task CreateRequestsOnNonOpenID() { + var nonOpenId = new Uri("http://www.microsoft.com/"); + var coordinator = new CoordinatorBase( + CoordinatorBase.RelyingPartyDriver( + async (rp, ct) => { + var requests = await rp.CreateRequestsAsync(nonOpenId, RPRealmUri, RPUri); + Assert.AreEqual(0, requests.Count()); + }), + CoordinatorBase.Handle(nonOpenId).By("<html/>", "text/html")); + await coordinator.RunAsync(); } /// <summary> @@ -100,25 +116,37 @@ namespace DotNetOpenAuth.Test.OpenId.RelyingParty { /// OPs that are not approved by <see cref="OpenIdRelyingParty.EndpointFilter"/>. /// </summary> [Test] - public void AssertionWithEndpointFilter() { - var coordinator = new OpenIdCoordinator( - rp => { - // register with RP so that id discovery passes - rp.Channel.WebRequestHandler = this.MockResponder.MockWebRequestHandler; - - // Rig it to always deny the incoming OP - rp.EndpointFilter = op => false; - - // Receive the unsolicited assertion - var response = rp.GetResponse(); - Assert.AreEqual(AuthenticationStatus.Failed, response.Status); - }, - op => { + public async Task AssertionWithEndpointFilter() { + var opStore = new StandardProviderApplicationStore(); + var coordinator = new CoordinatorBase( + async (hostFactories, ct) => { + var op = new OpenIdProvider(opStore); Identifier id = GetMockIdentifier(ProtocolVersion.V20); - op.SendUnsolicitedAssertion(OPUri, GetMockRealm(false), id, id); - AutoProvider(op); - }); - coordinator.Run(); + var assertion = await op.PrepareUnsolicitedAssertionAsync(OPUri, GetMockRealm(false), id, id, ct); + using (var httpClient = hostFactories.CreateHttpClient()) { + using (var response = await httpClient.GetAsync(assertion.Headers.Location, ct)) { + response.EnsureSuccessStatusCode(); + } + } + }, + CoordinatorBase.Handle(RPRealmUri).By( + async (hostFactories, req, ct) => { + var rp = new OpenIdRelyingParty(new StandardRelyingPartyApplicationStore()); + + // register with RP so that id discovery passes + // TODO: Fix this + ////rp.Channel.WebRequestHandler = this.MockResponder.MockWebRequestHandler; + + // Rig it to always deny the incoming OP + rp.EndpointFilter = op => false; + + // Receive the unsolicited assertion + var response = await rp.GetResponseAsync(req, ct); + Assert.AreEqual(AuthenticationStatus.Failed, response.Status); + return new HttpResponseMessage(); + }), + AutoProvider); + await coordinator.RunAsync(); } } } diff --git a/src/DotNetOpenAuth.Test/OpenId/RelyingParty/PositiveAuthenticationResponseTests.cs b/src/DotNetOpenAuth.Test/OpenId/RelyingParty/PositiveAuthenticationResponseTests.cs index 91318f5..5642bc5 100644 --- a/src/DotNetOpenAuth.Test/OpenId/RelyingParty/PositiveAuthenticationResponseTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/RelyingParty/PositiveAuthenticationResponseTests.cs @@ -7,6 +7,8 @@ namespace DotNetOpenAuth.Test.OpenId.RelyingParty { using System; using System.Collections.Generic; + using System.Threading.Tasks; + using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId; using DotNetOpenAuth.OpenId.Extensions.SimpleRegistration; @@ -124,20 +126,22 @@ namespace DotNetOpenAuth.Test.OpenId.RelyingParty { /// Verifies that certain problematic claimed identifiers pass through to the RP response correctly. /// </summary> [Test] - public void ProblematicClaimedId() { + public async Task ProblematicClaimedId() { var providerEndpoint = new ProviderEndpointDescription(OpenIdTestBase.OPUri, Protocol.Default.Version); string claimed_id = BaseMockUri + "a./b."; var se = IdentifierDiscoveryResult.CreateForClaimedIdentifier(claimed_id, claimed_id, providerEndpoint, null, null); - UriIdentifier identityUri = (UriIdentifier)se.ClaimedIdentifier; - var mockId = new MockIdentifier(identityUri, this.MockResponder, new IdentifierDiscoveryResult[] { se }); - - var positiveAssertion = this.GetPositiveAssertion(); - positiveAssertion.ClaimedIdentifier = mockId; - positiveAssertion.LocalIdentifier = mockId; - var rp = CreateRelyingParty(); - var authResponse = new PositiveAuthenticationResponse(positiveAssertion, rp); - Assert.AreEqual(AuthenticationStatus.Authenticated, authResponse.Status); - Assert.AreEqual(claimed_id, authResponse.ClaimedIdentifier.ToString()); + var identityUri = (UriIdentifier)se.ClaimedIdentifier; + var coordinator = new CoordinatorBase( + CoordinatorBase.RelyingPartyDriver(async (rp, ct) => { + var positiveAssertion = this.GetPositiveAssertion(); + positiveAssertion.ClaimedIdentifier = claimed_id; + positiveAssertion.LocalIdentifier = claimed_id; + var authResponse = new PositiveAuthenticationResponse(positiveAssertion, rp); + Assert.AreEqual(AuthenticationStatus.Authenticated, authResponse.Status); + Assert.AreEqual(claimed_id, authResponse.ClaimedIdentifier.ToString()); + }), + MockHttpRequest.RegisterMockXrdsResponse(se)); + await coordinator.RunAsync(); } private PositiveAssertionResponse GetPositiveAssertion() { |