diff options
Diffstat (limited to 'src/DotNetOpenAuth.Test/OpenId/OpenIdTestBase.cs')
-rw-r--r-- | src/DotNetOpenAuth.Test/OpenId/OpenIdTestBase.cs | 193 |
1 files changed, 143 insertions, 50 deletions
diff --git a/src/DotNetOpenAuth.Test/OpenId/OpenIdTestBase.cs b/src/DotNetOpenAuth.Test/OpenId/OpenIdTestBase.cs index 3a27e96..ea2867c 100644 --- a/src/DotNetOpenAuth.Test/OpenId/OpenIdTestBase.cs +++ b/src/DotNetOpenAuth.Test/OpenId/OpenIdTestBase.cs @@ -8,21 +8,29 @@ namespace DotNetOpenAuth.Test.OpenId { using System; using System.Collections.Generic; using System.IO; + using System.Linq; + using System.Net; + using System.Net.Http; using System.Reflection; + using System.Threading; + using System.Threading.Tasks; + using DotNetOpenAuth.Configuration; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Bindings; using DotNetOpenAuth.OpenId; + using DotNetOpenAuth.OpenId.Messages; using DotNetOpenAuth.OpenId.Provider; using DotNetOpenAuth.OpenId.RelyingParty; + using DotNetOpenAuth.Test.Messaging; using DotNetOpenAuth.Test.Mocks; - using NUnit.Framework; + using DotNetOpenAuth.Test.OpenId.Extensions; - public class OpenIdTestBase : TestBase { - internal IDirectWebRequestHandler RequestHandler; + using NUnit.Framework; - internal MockHttpRequest MockResponder; + using IAuthenticationRequest = DotNetOpenAuth.OpenId.Provider.IAuthenticationRequest; + public class OpenIdTestBase : TestBase { protected internal const string IdentifierSelect = "http://specs.openid.net/auth/2.0/identifier_select"; protected internal static readonly Uri BaseMockUri = new Uri("http://localhost/"); @@ -69,10 +77,9 @@ namespace DotNetOpenAuth.Test.OpenId { this.RelyingPartySecuritySettings = OpenIdElement.Configuration.RelyingParty.SecuritySettings.CreateSecuritySettings(); this.ProviderSecuritySettings = OpenIdElement.Configuration.Provider.SecuritySettings.CreateSecuritySettings(); - this.MockResponder = MockHttpRequest.CreateUntrustedMockHttpHandler(); - this.RequestHandler = this.MockResponder.MockWebRequestHandler; this.AutoProviderScenario = Scenarios.AutoApproval; Identifier.EqualityOnStrings = true; + this.HostFactories.InstallUntrustedWebReqestHandler = true; } [TearDown] @@ -121,7 +128,7 @@ namespace DotNetOpenAuth.Test.OpenId { internal static IdentifierDiscoveryResult GetServiceEndpoint(int user, ProtocolVersion providerVersion, int servicePriority, bool useSsl, bool delegating) { var providerEndpoint = new ProviderEndpointDescription( - useSsl ? OpenIdTestBase.OPUriSsl : OpenIdTestBase.OPUri, + useSsl ? OPUriSsl : OPUri, new string[] { Protocol.Lookup(providerVersion).ClaimedIdentifierServiceTypeURI }); var local_id = useSsl ? OPLocalIdentifiersSsl[user] : OPLocalIdentifiers[user]; var claimed_id = delegating ? (useSsl ? VanityUriSsl : VanityUri) : local_id; @@ -135,50 +142,59 @@ namespace DotNetOpenAuth.Test.OpenId { } /// <summary> - /// A default implementation of a simple provider that responds to authentication requests + /// Gets a default implementation of a simple provider that responds to authentication requests /// per the scenario that is being simulated. /// </summary> - /// <param name="provider">The OpenIdProvider on which the process messages.</param> /// <remarks> /// This is a very useful method to pass to the OpenIdCoordinator constructor for the Provider argument. /// </remarks> - internal void AutoProvider(OpenIdProvider provider) { - while (!((CoordinatingChannel)provider.Channel).RemoteChannel.IsDisposed) { - IRequest request = provider.GetRequest(); - if (request == null) { - continue; - } + internal void RegisterAutoProvider() { + this.Handle(OPUri).By( + async (req, ct) => { + var provider = new OpenIdProvider(new StandardProviderApplicationStore(), this.HostFactories); + return await this.AutoProviderActionAsync(provider, req, ct); + }); + } - if (!request.IsResponseReady) { - var authRequest = (DotNetOpenAuth.OpenId.Provider.IAuthenticationRequest)request; - switch (this.AutoProviderScenario) { - case Scenarios.AutoApproval: - authRequest.IsAuthenticated = true; - break; - case Scenarios.AutoApprovalAddFragment: - authRequest.SetClaimedIdentifierFragment("frag"); - authRequest.IsAuthenticated = true; - break; - case Scenarios.ApproveOnSetup: - authRequest.IsAuthenticated = !authRequest.Immediate; - break; - case Scenarios.AlwaysDeny: - authRequest.IsAuthenticated = false; - break; - default: - // All other scenarios are done programmatically only. - throw new InvalidOperationException("Unrecognized scenario"); - } + /// <summary> + /// Gets a default implementation of a simple provider that responds to authentication requests + /// per the scenario that is being simulated. + /// </summary> + /// <remarks> + /// This is a very useful method to pass to the OpenIdCoordinator constructor for the Provider argument. + /// </remarks> + internal async Task<HttpResponseMessage> AutoProviderActionAsync(OpenIdProvider provider, HttpRequestMessage req, CancellationToken ct) { + IRequest request = await provider.GetRequestAsync(req, ct); + Assert.That(request, Is.Not.Null); + + if (!request.IsResponseReady) { + var authRequest = (IAuthenticationRequest)request; + switch (this.AutoProviderScenario) { + case Scenarios.AutoApproval: + authRequest.IsAuthenticated = true; + break; + case Scenarios.AutoApprovalAddFragment: + authRequest.SetClaimedIdentifierFragment("frag"); + authRequest.IsAuthenticated = true; + break; + case Scenarios.ApproveOnSetup: + authRequest.IsAuthenticated = !authRequest.Immediate; + break; + case Scenarios.AlwaysDeny: + authRequest.IsAuthenticated = false; + break; + default: + // All other scenarios are done programmatically only. + throw new InvalidOperationException("Unrecognized scenario"); } - - provider.Respond(request); } + + return await provider.PrepareResponseAsync(request, ct); } - internal IEnumerable<IdentifierDiscoveryResult> Discover(Identifier identifier) { + internal Task<IEnumerable<IdentifierDiscoveryResult>> DiscoverAsync(Identifier identifier, CancellationToken cancellationToken = default(CancellationToken)) { var rp = this.CreateRelyingParty(true); - rp.Channel.WebRequestHandler = this.RequestHandler; - return rp.Discover(identifier); + return rp.DiscoverAsync(identifier, cancellationToken); } protected Realm GetMockRealm(bool useSsl) { @@ -196,8 +212,8 @@ namespace DotNetOpenAuth.Test.OpenId { protected Identifier GetMockIdentifier(ProtocolVersion providerVersion, bool useSsl, bool delegating) { var se = GetServiceEndpoint(0, providerVersion, 10, useSsl, delegating); - UriIdentifier identityUri = (UriIdentifier)se.ClaimedIdentifier; - return new MockIdentifier(identityUri, this.MockResponder, new IdentifierDiscoveryResult[] { se }); + this.RegisterMockXrdsResponse(se); + return se.ClaimedIdentifier; } protected Identifier GetMockDualIdentifier() { @@ -208,8 +224,8 @@ namespace DotNetOpenAuth.Test.OpenId { IdentifierDiscoveryResult.CreateForProviderIdentifier(protocol.ClaimedIdentifierForOPIdentifier, opDesc, 20, 20), }; - Identifier dualId = new MockIdentifier(VanityUri, this.MockResponder, dualResults); - return dualId; + this.RegisterMockXrdsResponse(VanityUri, dualResults); + return VanityUri; } /// <summary> @@ -226,9 +242,7 @@ namespace DotNetOpenAuth.Test.OpenId { /// <param name="stateless">if set to <c>true</c> a stateless RP is created.</param> /// <returns>The new instance.</returns> protected OpenIdRelyingParty CreateRelyingParty(bool stateless) { - var rp = new OpenIdRelyingParty(stateless ? null : new StandardRelyingPartyApplicationStore()); - rp.Channel.WebRequestHandler = this.MockResponder.MockWebRequestHandler; - rp.DiscoveryServices.Add(new MockIdentifierDiscoveryService()); + var rp = new OpenIdRelyingParty(stateless ? null : new StandardRelyingPartyApplicationStore(), this.HostFactories); return rp; } @@ -237,10 +251,89 @@ namespace DotNetOpenAuth.Test.OpenId { /// </summary> /// <returns>The new instance.</returns> protected OpenIdProvider CreateProvider() { - var op = new OpenIdProvider(new StandardProviderApplicationStore()); - op.Channel.WebRequestHandler = this.MockResponder.MockWebRequestHandler; - op.DiscoveryServices.Add(new MockIdentifierDiscoveryService()); + var op = new OpenIdProvider(new StandardProviderApplicationStore(), this.HostFactories); return op; } + + protected internal void HandleProvider(Func<OpenIdProvider, HttpRequestMessage, Task<HttpResponseMessage>> provider) { + var op = this.CreateProvider(); + this.Handle(OPUri).By(async req => { + return await provider(op, req); + }); + } + + /// <summary> + /// Simulates an extension request and response. + /// </summary> + /// <param name="protocol">The protocol to use in the roundtripping.</param> + /// <param name="requests">The extensions to add to the request message.</param> + /// <param name="responses">The extensions to add to the response message.</param> + /// <remarks> + /// This method relies on the extension objects' Equals methods to verify + /// accurate transport. The Equals methods should be verified by separate tests. + /// </remarks> + internal async Task RoundtripAsync( + Protocol protocol, IEnumerable<IOpenIdMessageExtension> requests, IEnumerable<IOpenIdMessageExtension> responses) { + var securitySettings = new ProviderSecuritySettings(); + var cryptoKeyStore = new MemoryCryptoKeyStore(); + var associationStore = new ProviderAssociationHandleEncoder(cryptoKeyStore); + Association association = HmacShaAssociationProvider.Create( + protocol, + protocol.Args.SignatureAlgorithm.Best, + AssociationRelyingPartyType.Smart, + associationStore, + securitySettings); + + this.HandleProvider( + async (op, req) => { + ExtensionTestUtilities.RegisterExtension(op.Channel, Mocks.MockOpenIdExtension.Factory); + var key = cryptoKeyStore.GetCurrentKey( + ProviderAssociationHandleEncoder.AssociationHandleEncodingSecretBucket, TimeSpan.FromSeconds(1)); + op.CryptoKeyStore.StoreKey( + ProviderAssociationHandleEncoder.AssociationHandleEncodingSecretBucket, key.Key, key.Value); + var request = await op.Channel.ReadFromRequestAsync<CheckIdRequest>(req, CancellationToken.None); + var response = new PositiveAssertionResponse(request); + var receivedRequests = request.Extensions.Cast<IOpenIdMessageExtension>(); + CollectionAssert<IOpenIdMessageExtension>.AreEquivalentByEquality(requests.ToArray(), receivedRequests.ToArray()); + + foreach (var extensionResponse in responses) { + response.Extensions.Add(extensionResponse); + } + + return await op.Channel.PrepareResponseAsync(response); + }); + + { + var rp = this.CreateRelyingParty(); + ExtensionTestUtilities.RegisterExtension(rp.Channel, Mocks.MockOpenIdExtension.Factory); + var requestBase = new CheckIdRequest(protocol.Version, OpenIdTestBase.OPUri, AuthenticationRequestMode.Immediate); + OpenIdTestBase.StoreAssociation(rp, OpenIdTestBase.OPUri, association); + requestBase.AssociationHandle = association.Handle; + requestBase.ClaimedIdentifier = "http://claimedid"; + requestBase.LocalIdentifier = "http://localid"; + requestBase.ReturnTo = OpenIdTestBase.RPUri; + + foreach (IOpenIdMessageExtension extension in requests) { + requestBase.Extensions.Add(extension); + } + + var redirectingRequest = await rp.Channel.PrepareResponseAsync(requestBase); + Uri redirectingResponseUri; + this.HostFactories.AllowAutoRedirects = false; + using (var httpClient = rp.Channel.HostFactories.CreateHttpClient()) { + using (var redirectingResponse = await httpClient.GetAsync(redirectingRequest.Headers.Location)) { + Assert.AreEqual(HttpStatusCode.Found, redirectingResponse.StatusCode); + redirectingResponseUri = redirectingResponse.Headers.Location; + } + } + + var response = + await + rp.Channel.ReadFromRequestAsync<PositiveAssertionResponse>( + new HttpRequestMessage(HttpMethod.Get, redirectingResponseUri), CancellationToken.None); + var receivedResponses = response.Extensions.Cast<IOpenIdMessageExtension>(); + CollectionAssert<IOpenIdMessageExtension>.AreEquivalentByEquality(responses.ToArray(), receivedResponses.ToArray()); + } + } } } |