//-----------------------------------------------------------------------
//
//     Copyright (c) Outercurve Foundation. All rights reserved.
//
//-----------------------------------------------------------------------

namespace DotNetOpenAuth.Test.OpenId.Extensions {
    using System;
    using System.Collections.Generic;
    using System.Linq;
    using System.Net.Http;
    using System.Threading.Tasks;
    using DotNetOpenAuth.Messaging;
    using DotNetOpenAuth.Messaging.Bindings;
    using DotNetOpenAuth.OpenId;
    using DotNetOpenAuth.OpenId.ChannelElements;
    using DotNetOpenAuth.OpenId.Extensions;
    using DotNetOpenAuth.OpenId.Messages;
    using DotNetOpenAuth.OpenId.Provider;
    using DotNetOpenAuth.OpenId.RelyingParty;
    using DotNetOpenAuth.Test.Messaging;
    using Validation;

    /// <summary>
    /// Helpers for tests that send OpenID extensions through a simulated
    /// relying party / provider exchange.
    /// </summary>
    public static class ExtensionTestUtilities {
        /// <summary>
        /// Simulates an extension request and response.
        /// </summary>
        /// <param name="protocol">The protocol to use in the roundtripping.</param>
        /// <param name="requests">The extensions to add to the request message.</param>
        /// <param name="responses">The extensions to add to the response message.</param>
        /// <returns>A task that completes once the coordinated exchange has finished.</returns>
        /// <remarks>
        /// This method relies on the extension objects' Equals methods to verify
        /// accurate transport. The Equals methods should be verified by separate tests.
        /// </remarks>
        internal static async Task RoundtripAsync(
            Protocol protocol,
            IEnumerable<IOpenIdMessageExtension> requests,
            IEnumerable<IOpenIdMessageExtension> responses) {
            Requires.NotNull(protocol, "protocol");
            Requires.NotNull(requests, "requests");
            Requires.NotNull(responses, "responses");

            // Materialize once: both sequences are walked by the relying party
            // and by the provider, and a lazy sequence would be re-evaluated each time.
            IOpenIdMessageExtension[] expectedRequests = requests.ToArray();
            IOpenIdMessageExtension[] expectedResponses = responses.ToArray();

            var securitySettings = new ProviderSecuritySettings();
            var cryptoKeyStore = new MemoryCryptoKeyStore();
            var associationStore = new ProviderAssociationHandleEncoder(cryptoKeyStore);
            Association association = HmacShaAssociationProvider.Create(
                protocol,
                protocol.Args.SignatureAlgorithm.Best,
                AssociationRelyingPartyType.Smart,
                associationStore,
                securitySettings);

            await CoordinatorBase.RunAsync(
                CoordinatorBase.RelyingPartyDriver(async (rp, ct) => {
                    RegisterExtension(rp.Channel, Mocks.MockOpenIdExtension.Factory);

                    // Build an immediate-mode request bound to the association created above.
                    var requestBase = new CheckIdRequest(protocol.Version, OpenIdTestBase.OPUri, AuthenticationRequestMode.Immediate);
                    OpenIdTestBase.StoreAssociation(rp, OpenIdTestBase.OPUri, association);
                    requestBase.AssociationHandle = association.Handle;
                    requestBase.ClaimedIdentifier = "http://claimedid";
                    requestBase.LocalIdentifier = "http://localid";
                    requestBase.ReturnTo = OpenIdTestBase.RPUri;

                    foreach (IOpenIdMessageExtension extension in expectedRequests) {
                        requestBase.Extensions.Add(extension);
                    }

                    var redirectingRequest = await rp.Channel.PrepareResponseAsync(requestBase);
                    Uri redirectingResponseUri;
                    using (var httpClient = rp.Channel.HostFactories.CreateHttpClient()) {
                        using (var redirectingResponse = await httpClient.GetAsync(redirectingRequest.Headers.Location, ct)) {
                            // NOTE(review): EnsureSuccessStatusCode throws on non-2xx statuses, yet the
                            // Location header is read right after - confirm the test host answers with
                            // a success status that still carries Location.
                            redirectingResponse.EnsureSuccessStatusCode();
                            redirectingResponseUri = redirectingResponse.Headers.Location;
                        }
                    }

                    // NOTE(review): message type argument restored by inference (the provider side
                    // replies with a PositiveAssertionResponse) - confirm against the Channel API.
                    var response = await rp.Channel.ReadFromRequestAsync<PositiveAssertionResponse>(
                        new HttpRequestMessage(HttpMethod.Get, redirectingResponseUri), ct);
                    var receivedResponses = response.Extensions.Cast<IOpenIdMessageExtension>();

                    // NOTE(review): assumes the project's CollectionAssert helper is generic - confirm.
                    CollectionAssert<IOpenIdMessageExtension>.AreEquivalentByEquality(expectedResponses, receivedResponses.ToArray());
                }),
                CoordinatorBase.HandleProvider(async (op, req, ct) => {
                    RegisterExtension(op.Channel, Mocks.MockOpenIdExtension.Factory);

                    // Copy the current handle-encoding key from the local key store into the
                    // provider's store (presumably so the provider can decode the association
                    // handle that was created with the local store).
                    var key = cryptoKeyStore.GetCurrentKey(
                        ProviderAssociationHandleEncoder.AssociationHandleEncodingSecretBucket,
                        TimeSpan.FromSeconds(1));
                    op.CryptoKeyStore.StoreKey(
                        ProviderAssociationHandleEncoder.AssociationHandleEncodingSecretBucket,
                        key.Key,
                        key.Value);

                    // NOTE(review): message type argument restored by inference (the relying
                    // party sends a CheckIdRequest) - confirm against the Channel API.
                    var request = await op.Channel.ReadFromRequestAsync<CheckIdRequest>(req, ct);
                    var response = new PositiveAssertionResponse(request);
                    var receivedRequests = request.Extensions.Cast<IOpenIdMessageExtension>();
                    CollectionAssert<IOpenIdMessageExtension>.AreEquivalentByEquality(expectedRequests, receivedRequests.ToArray());

                    foreach (var extensionResponse in expectedResponses) {
                        response.Extensions.Add(extensionResponse);
                    }

                    return await op.Channel.PrepareResponseAsync(response, ct);
                }));
        }

        /// <summary>
        /// Registers an extension factory delegate with the standard OpenID
        /// extension factory found on the given channel.
        /// </summary>
        /// <param name="channel">The channel whose binding elements are searched. Must not be null.</param>
        /// <param name="extensionFactory">The factory delegate to register.</param>
        /// <remarks>
        /// <c>Single()</c> is used for both lookups, so this throws unless exactly one
        /// matching binding element and exactly one matching factory are present.
        /// </remarks>
        internal static void RegisterExtension(Channel channel, StandardOpenIdExtensionFactory.CreateDelegate extensionFactory) {
            Requires.NotNull(channel, "channel");

            // NOTE(review): binding element type argument restored by inference from the
            // ExtensionFactory member access - confirm it is ExtensionsBindingElement.
            var factory = (OpenIdExtensionFactoryAggregator)channel.BindingElements.OfType<ExtensionsBindingElement>().Single().ExtensionFactory;
            factory.Factories.OfType<StandardOpenIdExtensionFactory>().Single().RegisterExtension(extensionFactory);
        }
    }
}