//-----------------------------------------------------------------------
// <copyright file="OpenIdTestBase.cs" company="Outercurve Foundation">
//     Copyright (c) Outercurve Foundation. All rights reserved.
// </copyright>
//-----------------------------------------------------------------------

namespace DotNetOpenAuth.Test.OpenId {
	using System;
	using System.Collections.Generic;
	using System.IO;
	using System.Linq;
	using System.Net;
	using System.Net.Http;
	using System.Reflection;
	using System.Threading;
	using System.Threading.Tasks;
	using DotNetOpenAuth.Configuration;
	using DotNetOpenAuth.Messaging;
	using DotNetOpenAuth.Messaging.Bindings;
	using DotNetOpenAuth.OpenId;
	using DotNetOpenAuth.OpenId.Messages;
	using DotNetOpenAuth.OpenId.Provider;
	using DotNetOpenAuth.OpenId.RelyingParty;
	using DotNetOpenAuth.Test.Messaging;
	using DotNetOpenAuth.Test.Mocks;
	using DotNetOpenAuth.Test.OpenId.Extensions;
	using NUnit.Framework;
	using IAuthenticationRequest = DotNetOpenAuth.OpenId.Provider.IAuthenticationRequest;

	/// <summary>
	/// Base class for OpenID tests: shared mock URIs, helpers that register mock discovery
	/// documents and mock provider endpoints, and per-test security settings.
	/// </summary>
	public class OpenIdTestBase : TestBase {
		/// <summary>
		/// The OpenID 2.0 identifier_select value.
		/// </summary>
		protected internal const string IdentifierSelect = "http://specs.openid.net/auth/2.0/identifier_select";

		protected internal static readonly Uri BaseMockUri = new Uri("http://localhost/");
		protected internal static readonly Uri BaseMockUriSsl = new Uri("https://localhost/");
		protected internal static readonly Uri OPUri = new Uri(BaseMockUri, "/provider/endpoint");
		protected internal static readonly Uri OPUriSsl = new Uri(BaseMockUriSsl, "/provider/endpoint");
		protected internal static readonly Uri[] OPLocalIdentifiers = new[] { new Uri(OPUri, "/provider/someUser0"), new Uri(OPUri, "/provider/someUser1") };
		protected internal static readonly Uri[] OPLocalIdentifiersSsl = new[] { new Uri(OPUriSsl, "/provider/someUser0"), new Uri(OPUriSsl, "/provider/someUser1") };

		// Vanity URLs are Claimed Identifiers that delegate to some OP and its local identifier.
		protected internal static readonly Uri VanityUri = new Uri(BaseMockUri, "/userControlled/identity");
		protected internal static readonly Uri VanityUriSsl = new Uri(BaseMockUriSsl, "/userControlled/identity");

		protected internal static readonly Uri RPUri = new Uri(BaseMockUri, "/relyingparty/login");
		protected internal static readonly Uri RPUriSsl = new Uri(BaseMockUriSsl, "/relyingparty/login");
		protected internal static readonly Uri RPRealmUri = new Uri(BaseMockUri, "/relyingparty/");
		protected internal static readonly Uri RPRealmUriSsl = new Uri(BaseMockUriSsl, "/relyingparty/");

		/// <summary>
		/// Initializes a new instance of the <see cref="OpenIdTestBase"/> class.
		/// </summary>
		internal OpenIdTestBase() {
			this.AutoProviderScenario = Scenarios.AutoApproval;
		}

		/// <summary>
		/// The behaviors <see cref="AutoProviderActionAsync"/> can simulate for authentication requests.
		/// </summary>
		public enum Scenarios {
			/// <summary>Every authentication request is approved.</summary>
			AutoApproval,

			/// <summary>The claimed identifier fragment is set to "frag", then the request is approved.</summary>
			AutoApprovalAddFragment,

			/// <summary>Only non-immediate (setup) requests are approved.</summary>
			ApproveOnSetup,

			/// <summary>Every authentication request is denied.</summary>
			AlwaysDeny,
		}

		/// <summary>
		/// Gets or sets the scenario that <see cref="AutoProviderActionAsync"/> simulates.
		/// It is reset to <see cref="Scenarios.AutoApproval"/> before each test.
		/// </summary>
		internal Scenarios AutoProviderScenario { get; set; }

		/// <summary>
		/// Gets the relying party security settings built from configuration in <see cref="SetUp"/>.
		/// </summary>
		protected RelyingPartySecuritySettings RelyingPartySecuritySettings { get; private set; }

		/// <summary>
		/// Gets the provider security settings built from configuration in <see cref="SetUp"/>.
		/// </summary>
		protected ProviderSecuritySettings ProviderSecuritySettings { get; private set; }

		/// <summary>
		/// Prepares per-test state: fresh security settings from configuration, the default
		/// auto-provider scenario, and string-based <see cref="Identifier"/> equality.
		/// </summary>
		[SetUp]
		public override void SetUp() {
			base.SetUp();

			this.RelyingPartySecuritySettings = OpenIdElement.Configuration.RelyingParty.SecuritySettings.CreateSecuritySettings();
			this.ProviderSecuritySettings = OpenIdElement.Configuration.Provider.SecuritySettings.CreateSecuritySettings();
			this.AutoProviderScenario = Scenarios.AutoApproval;

			// Turned on for the duration of each test and turned off again in Cleanup.
			Identifier.EqualityOnStrings = true;
		}

		/// <summary>
		/// Undoes the global state changed by <see cref="SetUp"/>.
		/// </summary>
		[TearDown]
		public override void Cleanup() {
			base.Cleanup();
			Identifier.EqualityOnStrings = false;
		}

		/// <summary>
		/// Forces storage of an association in an RP's association store.
		/// </summary>
		/// <param name="relyingParty">The relying party.</param>
		/// <param name="providerEndpoint">The provider endpoint.</param>
		/// <param name="association">The association.</param>
		internal static void StoreAssociation(OpenIdRelyingParty relyingParty, Uri providerEndpoint, Association association) {
			// Only store the association if the RP is not in stateless mode.
			if (relyingParty.AssociationManager.AssociationStoreTestHook != null) {
				relyingParty.AssociationManager.AssociationStoreTestHook.StoreAssociation(providerEndpoint, association);
			}
		}

		/// <summary>
		/// Returns the content of a given embedded resource.
		/// </summary>
		/// <param name="path">The path of the file as it appears within the project,
		/// where the leading / marks the root directory of the project.</param>
		/// <returns>The content of the requested resource.</returns>
		/// <exception cref="ArgumentNullException">Thrown when <paramref name="path"/> is null.</exception>
		/// <exception cref="ArgumentException">Thrown when no matching embedded resource exists.</exception>
		internal static string LoadEmbeddedFile(string path) {
			if (path == null) {
				throw new ArgumentNullException("path");
			}

			// Ordinal: this is a path separator check, not a linguistic comparison.
			if (!path.StartsWith("/", StringComparison.Ordinal)) {
				path = "/" + path;
			}

			// Embedded resource names use '.' where the project path uses '/'.
			string resourceName = "DotNetOpenAuth.Test.OpenId" + path.Replace('/', '.');
			Stream resource = Assembly.GetExecutingAssembly().GetManifestResourceStream(resourceName);
			if (resource == null) {
				throw new ArgumentException("No embedded resource named '" + resourceName + "' was found in the test assembly.", "path");
			}

			// The reader takes ownership of the resource stream and disposes it.
			using (StreamReader sr = new StreamReader(resource)) {
				return sr.ReadToEnd();
			}
		}

		/// <summary>
		/// Creates a non-delegating discovery result for one of the mock OP local identifiers.
		/// </summary>
		/// <param name="user">Index into <see cref="OPLocalIdentifiers"/> (or <see cref="OPLocalIdentifiersSsl"/>).</param>
		/// <param name="providerVersion">The OpenID version whose claimed identifier service type the provider advertises.</param>
		/// <param name="servicePriority">The service priority of the result.</param>
		/// <param name="useSsl">Whether to use the HTTPS variants of the mock URIs.</param>
		/// <returns>A discovery result whose claimed identifier is the OP local identifier itself.</returns>
		internal static IdentifierDiscoveryResult GetServiceEndpoint(int user, ProtocolVersion providerVersion, int servicePriority, bool useSsl) {
			return GetServiceEndpoint(user, providerVersion, servicePriority, useSsl, false);
		}

		/// <summary>
		/// Creates a discovery result for one of the mock OP local identifiers.
		/// </summary>
		/// <param name="user">Index into <see cref="OPLocalIdentifiers"/> (or <see cref="OPLocalIdentifiersSsl"/>).</param>
		/// <param name="providerVersion">The OpenID version whose claimed identifier service type the provider advertises.</param>
		/// <param name="servicePriority">The service priority of the result.</param>
		/// <param name="useSsl">Whether to use the HTTPS variants of the mock URIs.</param>
		/// <param name="delegating">If set to <c>true</c>, the claimed identifier is <see cref="VanityUri"/>
		/// (or <see cref="VanityUriSsl"/>) delegating to the OP local identifier; otherwise the
		/// claimed identifier is the OP local identifier itself.</param>
		/// <returns>The discovery result.</returns>
		internal static IdentifierDiscoveryResult GetServiceEndpoint(int user, ProtocolVersion providerVersion, int servicePriority, bool useSsl, bool delegating) {
			var providerEndpoint = new ProviderEndpointDescription(
				useSsl ? OPUriSsl : OPUri,
				new string[] { Protocol.Lookup(providerVersion).ClaimedIdentifierServiceTypeURI });
			var localId = useSsl ? OPLocalIdentifiersSsl[user] : OPLocalIdentifiers[user];
			var claimedId = delegating ? (useSsl ? VanityUriSsl : VanityUri) : localId;
			return IdentifierDiscoveryResult.CreateForClaimedIdentifier(
				claimedId,
				claimedId,
				localId,
				providerEndpoint,
				servicePriority,
				10);
		}

		/// <summary>
		/// Registers a handler at <see cref="OPUri"/> that, for each request, creates a fresh
		/// <see cref="OpenIdProvider"/> and answers according to <see cref="AutoProviderScenario"/>.
		/// </summary>
		internal void RegisterAutoProvider() {
			this.Handle(OPUri).By(
				async (req, ct) => {
					var provider = new OpenIdProvider(new StandardProviderApplicationStore(), this.HostFactories);
					return await this.AutoProviderActionAsync(provider, req, ct);
				});
		}

		/// <summary>
		/// Reads an incoming request with the given provider. If the request still needs a
		/// decision, approves or denies it according to <see cref="AutoProviderScenario"/>.
		/// </summary>
		/// <param name="provider">The provider that reads the request and prepares the response.</param>
		/// <param name="req">The incoming HTTP request.</param>
		/// <param name="ct">The cancellation token.</param>
		/// <returns>The HTTP response prepared by the provider.</returns>
		/// <exception cref="InvalidOperationException">Thrown when <see cref="AutoProviderScenario"/>
		/// is not one of the scenarios handled here.</exception>
		internal async Task<HttpResponseMessage> AutoProviderActionAsync(OpenIdProvider provider, HttpRequestMessage req, CancellationToken ct) {
			IRequest request = await provider.GetRequestAsync(req, ct);
			Assert.That(request, Is.Not.Null);

			if (!request.IsResponseReady) {
				// Any request still awaiting a decision is expected to be an authentication request.
				var authRequest = (IAuthenticationRequest)request;
				switch (this.AutoProviderScenario) {
					case Scenarios.AutoApproval:
						authRequest.IsAuthenticated = true;
						break;
					case Scenarios.AutoApprovalAddFragment:
						authRequest.SetClaimedIdentifierFragment("frag");
						authRequest.IsAuthenticated = true;
						break;
					case Scenarios.ApproveOnSetup:
						authRequest.IsAuthenticated = !authRequest.Immediate;
						break;
					case Scenarios.AlwaysDeny:
						authRequest.IsAuthenticated = false;
						break;
					default:
						// All other scenarios are done programmatically only.
						throw new InvalidOperationException("Unrecognized scenario");
				}
			}

			return await provider.PrepareResponseAsync(request, ct);
		}

		/// <summary>
		/// Performs discovery on an identifier using a new stateless relying party.
		/// </summary>
		/// <param name="identifier">The identifier to discover.</param>
		/// <param name="cancellationToken">The cancellation token.</param>
		/// <returns>The discovery results reported by the relying party.</returns>
		internal Task<IEnumerable<IdentifierDiscoveryResult>> DiscoverAsync(Identifier identifier, CancellationToken cancellationToken = default(CancellationToken)) {
			var rp = this.CreateRelyingParty(true);
			return rp.DiscoverAsync(identifier, cancellationToken);
		}

		/// <summary>
		/// Creates a mock realm rooted at <see cref="RPRealmUri"/> (or <see cref="RPRealmUriSsl"/>)
		/// that advertises <see cref="RPUri"/> (or <see cref="RPUriSsl"/>) as an OpenID 2.0 return_to endpoint.
		/// </summary>
		/// <param name="useSsl">Whether to use the HTTPS variants of the mock URIs.</param>
		/// <returns>The mock realm.</returns>
		protected Realm GetMockRealm(bool useSsl) {
			var rpDescription = new RelyingPartyEndpointDescription(useSsl ? RPUriSsl : RPUri, new string[] { Protocol.V20.RPReturnToTypeURI });
			return new MockRealm(useSsl ? RPRealmUriSsl : RPRealmUri, rpDescription);
		}

		/// <summary>
		/// Registers a mock XRDS document for the first mock user over HTTP, without delegation.
		/// </summary>
		/// <param name="providerVersion">The OpenID version the mock provider advertises.</param>
		/// <returns>The claimed identifier that was registered.</returns>
		protected Identifier GetMockIdentifier(ProtocolVersion providerVersion) {
			return this.GetMockIdentifier(providerVersion, false);
		}

		/// <summary>
		/// Registers a mock XRDS document for the first mock user, without delegation.
		/// </summary>
		/// <param name="providerVersion">The OpenID version the mock provider advertises.</param>
		/// <param name="useSsl">Whether to use the HTTPS variants of the mock URIs.</param>
		/// <returns>The claimed identifier that was registered.</returns>
		protected Identifier GetMockIdentifier(ProtocolVersion providerVersion, bool useSsl) {
			return this.GetMockIdentifier(providerVersion, useSsl, false);
		}

		/// <summary>
		/// Registers a mock XRDS document describing the first mock user with service priority 10.
		/// </summary>
		/// <param name="providerVersion">The OpenID version the mock provider advertises.</param>
		/// <param name="useSsl">Whether to use the HTTPS variants of the mock URIs.</param>
		/// <param name="delegating">Whether the claimed identifier is the vanity URI delegating to the OP.</param>
		/// <returns>The claimed identifier that was registered.</returns>
		protected Identifier GetMockIdentifier(ProtocolVersion providerVersion, bool useSsl, bool delegating) {
			var se = GetServiceEndpoint(0, providerVersion, 10, useSsl, delegating);
			this.RegisterMockXrdsResponse(se);
			return se.ClaimedIdentifier;
		}

		/// <summary>
		/// Registers a mock XRDS document at <see cref="VanityUri"/> that describes the URI both as
		/// a claimed identifier (delegating to the first OP local identifier) and as an OP identifier.
		/// </summary>
		/// <returns><see cref="VanityUri"/>.</returns>
		protected Identifier GetMockDualIdentifier() {
			Protocol protocol = Protocol.Default;
			var opDesc = new ProviderEndpointDescription(OPUri, protocol.Version);
			var dualResults = new IdentifierDiscoveryResult[] {
				IdentifierDiscoveryResult.CreateForClaimedIdentifier(VanityUri.AbsoluteUri, OPLocalIdentifiers[0], opDesc, 10, 10),
				IdentifierDiscoveryResult.CreateForProviderIdentifier(protocol.ClaimedIdentifierForOPIdentifier, opDesc, 20, 20),
			};

			this.RegisterMockXrdsResponse(VanityUri, dualResults);
			return VanityUri;
		}

		/// <summary>
		/// Creates a standard <see cref="OpenIdRelyingParty"/> instance for general testing.
		/// </summary>
		/// <returns>The new instance.</returns>
		protected OpenIdRelyingParty CreateRelyingParty() {
			return this.CreateRelyingParty(false);
		}

		/// <summary>
		/// Creates a standard <see cref="OpenIdRelyingParty"/> instance for general testing.
		/// </summary>
		/// <param name="stateless">if set to <c>true</c> a stateless RP is created.</param>
		/// <returns>The new instance.</returns>
		protected OpenIdRelyingParty CreateRelyingParty(bool stateless) {
			var rp = new OpenIdRelyingParty(stateless ? null : new StandardRelyingPartyApplicationStore(), this.HostFactories);
			return rp;
		}

		/// <summary>
		/// Creates a standard <see cref="OpenIdProvider"/> instance for general testing.
		/// </summary>
		/// <returns>The new instance.</returns>
		protected OpenIdProvider CreateProvider() {
			var op = new OpenIdProvider(new StandardProviderApplicationStore(), this.HostFactories);
			return op;
		}

		/// <summary>
		/// Registers a handler at <see cref="OPUri"/> that creates a fresh <see cref="OpenIdProvider"/>
		/// for each request and lets <paramref name="provider"/> produce the response.
		/// </summary>
		/// <param name="provider">Callback that receives the new provider and the incoming request.</param>
		protected internal void HandleProvider(Func<OpenIdProvider, HttpRequestMessage, Task<HttpResponseMessage>> provider) {
			this.Handle(OPUri).By(async req => {
				// Pass the test's host factories, as CreateProvider and RegisterAutoProvider do,
				// rather than leaving the provider with whatever it uses when none are supplied.
				var op = new OpenIdProvider(new StandardProviderApplicationStore(), this.HostFactories);
				return await provider(op, req);
			});
		}

		/// <summary>
		/// Simulates an extension request and response.
		/// </summary>
		/// <param name="protocol">The protocol to use in the roundtripping.</param>
		/// <param name="requests">The extensions to add to the request message.</param>
		/// <param name="responses">The extensions to add to the response message.</param>
		/// <returns>A task that completes when the roundtrip has been verified.</returns>
		/// <remarks>
		/// This method relies on the extension objects' Equals methods to verify
		/// accurate transport. The Equals methods should be verified by separate tests.
		/// </remarks>
		internal async Task RoundtripAsync(
			Protocol protocol, IEnumerable<IOpenIdMessageExtension> requests, IEnumerable<IOpenIdMessageExtension> responses) {
			var securitySettings = new ProviderSecuritySettings();
			var cryptoKeyStore = new MemoryCryptoKeyStore();
			var associationStore = new ProviderAssociationHandleEncoder(cryptoKeyStore);
			Association association = HmacShaAssociationProvider.Create(
				protocol,
				protocol.Args.SignatureAlgorithm.Best,
				AssociationRelyingPartyType.Smart,
				associationStore,
				securitySettings);

			this.HandleProvider(
				async (op, req) => {
					ExtensionTestUtilities.RegisterExtension(op.Channel, Mocks.MockOpenIdExtension.Factory);

					// Share the secret used to encode the association handle with this provider instance.
					var key = cryptoKeyStore.GetCurrentKey(
						ProviderAssociationHandleEncoder.AssociationHandleEncodingSecretBucket, TimeSpan.FromSeconds(1));
					op.CryptoKeyStore.StoreKey(
						ProviderAssociationHandleEncoder.AssociationHandleEncodingSecretBucket, key.Key, key.Value);

					var request = await op.Channel.ReadFromRequestAsync<CheckIdRequest>(req, CancellationToken.None);
					var response = new PositiveAssertionResponse(request);
					var receivedRequests = request.Extensions.Cast<IOpenIdMessageExtension>();
					CollectionAssert<IOpenIdMessageExtension>.AreEquivalentByEquality(requests.ToArray(), receivedRequests.ToArray());

					foreach (var extensionResponse in responses) {
						response.Extensions.Add(extensionResponse);
					}

					return await op.Channel.PrepareResponseAsync(response);
				});

			// Separate scope: the provider callback above declares locals with the same names (e.g. response).
			{
				var rp = this.CreateRelyingParty();
				ExtensionTestUtilities.RegisterExtension(rp.Channel, Mocks.MockOpenIdExtension.Factory);
				var requestBase = new CheckIdRequest(protocol.Version, OpenIdTestBase.OPUri, AuthenticationRequestMode.Immediate);
				OpenIdTestBase.StoreAssociation(rp, OpenIdTestBase.OPUri, association);
				requestBase.AssociationHandle = association.Handle;
				requestBase.ClaimedIdentifier = "http://claimedid";
				requestBase.LocalIdentifier = "http://localid";
				requestBase.ReturnTo = OpenIdTestBase.RPUri;

				foreach (IOpenIdMessageExtension extension in requests) {
					requestBase.Extensions.Add(extension);
				}

				var redirectingRequest = await rp.Channel.PrepareResponseAsync(requestBase);
				Uri redirectingResponseUri;

				// Disable automatic redirects so the provider's 302 response and its Location header can be inspected.
				this.HostFactories.AllowAutoRedirects = false;
				using (var httpClient = rp.Channel.HostFactories.CreateHttpClient()) {
					using (var redirectingResponse = await httpClient.GetAsync(redirectingRequest.Headers.Location)) {
						Assert.AreEqual(HttpStatusCode.Found, redirectingResponse.StatusCode);
						redirectingResponseUri = redirectingResponse.Headers.Location;
					}
				}

				var response = await rp.Channel.ReadFromRequestAsync<PositiveAssertionResponse>(
					new HttpRequestMessage(HttpMethod.Get, redirectingResponseUri),
					CancellationToken.None);
				var receivedResponses = response.Extensions.Cast<IOpenIdMessageExtension>();
				CollectionAssert<IOpenIdMessageExtension>.AreEquivalentByEquality(responses.ToArray(), receivedResponses.ToArray());
			}
		}
	}
}