summaryrefslogtreecommitdiffstats
path: root/src/DotNetOpenId/RelyingParty/AuthenticationRequest.cs
diff options
context:
space:
mode:
Diffstat (limited to 'src/DotNetOpenId/RelyingParty/AuthenticationRequest.cs')
-rw-r--r--src/DotNetOpenId/RelyingParty/AuthenticationRequest.cs223
1 files changed, 175 insertions, 48 deletions
diff --git a/src/DotNetOpenId/RelyingParty/AuthenticationRequest.cs b/src/DotNetOpenId/RelyingParty/AuthenticationRequest.cs
index cc9c3ae..a3c71ca 100644
--- a/src/DotNetOpenId/RelyingParty/AuthenticationRequest.cs
+++ b/src/DotNetOpenId/RelyingParty/AuthenticationRequest.cs
@@ -1,11 +1,10 @@
using System;
using System.Collections.Generic;
-using System.Text;
-using DotNetOpenId;
+using System.Collections.ObjectModel;
using System.Collections.Specialized;
+using System.Diagnostics;
using System.Globalization;
using System.Web;
-using System.Diagnostics;
namespace DotNetOpenId.RelyingParty {
/// <summary>
@@ -30,18 +29,19 @@ namespace DotNetOpenId.RelyingParty {
class AuthenticationRequest : IAuthenticationRequest {
Association assoc;
ServiceEndpoint endpoint;
- MessageEncoder encoder;
Protocol protocol { get { return endpoint.Protocol; } }
+ internal OpenIdRelyingParty RelyingParty;
AuthenticationRequest(string token, Association assoc, ServiceEndpoint endpoint,
- Realm realm, Uri returnToUrl, MessageEncoder encoder) {
+ Realm realm, Uri returnToUrl, OpenIdRelyingParty relyingParty) {
if (endpoint == null) throw new ArgumentNullException("endpoint");
if (realm == null) throw new ArgumentNullException("realm");
if (returnToUrl == null) throw new ArgumentNullException("returnToUrl");
- if (encoder == null) throw new ArgumentNullException("encoder");
+ if (relyingParty == null) throw new ArgumentNullException("relyingParty");
+
this.assoc = assoc;
this.endpoint = endpoint;
- this.encoder = encoder;
+ RelyingParty = relyingParty;
Realm = realm;
ReturnToUrl = returnToUrl;
@@ -52,38 +52,37 @@ namespace DotNetOpenId.RelyingParty {
AddCallbackArguments(DotNetOpenId.RelyingParty.Token.TokenKey, token);
}
internal static AuthenticationRequest Create(Identifier userSuppliedIdentifier,
- Realm realm, Uri returnToUrl, IRelyingPartyApplicationStore store, MessageEncoder encoder) {
+ OpenIdRelyingParty relyingParty, Realm realm, Uri returnToUrl) {
if (userSuppliedIdentifier == null) throw new ArgumentNullException("userSuppliedIdentifier");
+ if (relyingParty == null) throw new ArgumentNullException("relyingParty");
if (realm == null) throw new ArgumentNullException("realm");
- if (TraceUtil.Switch.TraceInfo) {
- Trace.TraceInformation("Creating authentication request for user supplied Identifier: {0}",
- userSuppliedIdentifier);
+ userSuppliedIdentifier = userSuppliedIdentifier.TrimFragment();
+ if (relyingParty.Settings.RequireSsl) {
+ // Rather than check for successful SSL conversion at this stage,
+ // We'll wait for secure discovery to fail on the new identifier.
+ userSuppliedIdentifier.TryRequireSsl(out userSuppliedIdentifier);
}
- if (TraceUtil.Switch.TraceVerbose) {
- Trace.Indent();
- Trace.TraceInformation("Realm: {0}", realm);
- Trace.TraceInformation("Return To: {0}", returnToUrl);
- Trace.Unindent();
- }
- if (TraceUtil.Switch.TraceWarning && returnToUrl.Query != null) {
+ Logger.InfoFormat("Creating authentication request for user supplied Identifier: {0}",
+ userSuppliedIdentifier);
+ Logger.DebugFormat("Realm: {0}", realm);
+ Logger.DebugFormat("Return To: {0}", returnToUrl);
+ Logger.DebugFormat("RequireSsl: {0}", userSuppliedIdentifier.IsDiscoverySecureEndToEnd);
+
+ if (Logger.IsWarnEnabled && returnToUrl.Query != null) {
NameValueCollection returnToArgs = HttpUtility.ParseQueryString(returnToUrl.Query);
foreach (string key in returnToArgs) {
if (OpenIdRelyingParty.ShouldParameterBeStrippedFromReturnToUrl(key)) {
- Trace.TraceWarning("OpenId argument \"{0}\" found in return_to URL. This can corrupt an OpenID response.", key);
+ Logger.WarnFormat("OpenId argument \"{0}\" found in return_to URL. This can corrupt an OpenID response.", key);
break;
}
}
}
- var endpoint = userSuppliedIdentifier.Discover();
+ var endpoints = new List<ServiceEndpoint>(userSuppliedIdentifier.Discover());
+ ServiceEndpoint endpoint = selectEndpoint(endpoints.AsReadOnly(), relyingParty);
if (endpoint == null)
throw new OpenIdException(Strings.OpenIdEndpointNotFound);
- if (TraceUtil.Switch.TraceVerbose) {
- Trace.Indent();
- Trace.TraceInformation("Discovered provider endpoint: {0}", endpoint);
- Trace.Unindent();
- }
// Throw an exception now if the realm and the return_to URLs don't match
// as required by the provider. We could wait for the provider to test this and
@@ -92,35 +91,151 @@ namespace DotNetOpenId.RelyingParty {
throw new OpenIdException(string.Format(CultureInfo.CurrentCulture,
Strings.ReturnToNotUnderRealm, returnToUrl, realm));
+ string token = new Token(endpoint).Serialize(relyingParty.Store);
+ // Retrieve the association, but don't create one, as a creation was already
+ // attempted by the selectEndpoint method.
+ Association association = relyingParty.Store != null ? getAssociation(relyingParty, endpoint, false) : null;
+
return new AuthenticationRequest(
- new Token(endpoint).Serialize(store),
- store != null ? getAssociation(endpoint, store) : null,
- endpoint, realm, returnToUrl, encoder);
+ token, association, endpoint, realm, returnToUrl, relyingParty);
+ }
+
+ /// <summary>
+ /// Returns a filtered and sorted list of the available OP endpoints for a discovered Identifier.
+ /// </summary>
+ private static List<ServiceEndpoint> filterAndSortEndpoints(ReadOnlyCollection<ServiceEndpoint> endpoints,
+ OpenIdRelyingParty relyingParty) {
+ if (endpoints == null) throw new ArgumentNullException("endpoints");
+ if (relyingParty == null) throw new ArgumentNullException("relyingParty");
+
+ // Construct the endpoints filters based on criteria given by the host web site.
+ EndpointSelector versionFilter = ep => ((ServiceEndpoint)ep).Protocol.Version >= Protocol.Lookup(relyingParty.Settings.MinimumRequiredOpenIdVersion).Version;
+ EndpointSelector hostingSiteFilter = relyingParty.EndpointFilter ?? (ep => true);
+
+ var filteredEndpoints = new List<IXrdsProviderEndpoint>(endpoints.Count);
+ foreach (ServiceEndpoint endpoint in endpoints) {
+ if (versionFilter(endpoint) && hostingSiteFilter(endpoint)) {
+ filteredEndpoints.Add(endpoint);
+ }
+ }
+
+ // Sort endpoints so that the first one in the list is the most preferred one.
+ filteredEndpoints.Sort(relyingParty.EndpointOrder);
+
+ List<ServiceEndpoint> endpointList = new List<ServiceEndpoint>(filteredEndpoints.Count);
+ foreach (ServiceEndpoint endpoint in filteredEndpoints) {
+ endpointList.Add(endpoint);
+ }
+ return endpointList;
}
- static Association getAssociation(ServiceEndpoint provider, IRelyingPartyApplicationStore store) {
+
+ /// <summary>
+ /// Chooses which provider endpoint is the best one to use.
+ /// </summary>
+ /// <returns>The best endpoint, or null if no acceptable endpoints were found.</returns>
+ private static ServiceEndpoint selectEndpoint(ReadOnlyCollection<ServiceEndpoint> endpoints,
+ OpenIdRelyingParty relyingParty) {
+
+ List<ServiceEndpoint> filteredEndpoints = filterAndSortEndpoints(endpoints, relyingParty);
+ if (filteredEndpoints.Count != endpoints.Count) {
+ Logger.DebugFormat("Some endpoints were filtered out. Total endpoints remaining: {0}", filteredEndpoints.Count);
+ }
+ if (Logger.IsDebugEnabled) {
+ if (Util.AreSequencesEquivalent(endpoints, filteredEndpoints)) {
+ Logger.Debug("Filtering and sorting of endpoints did not affect the list.");
+ } else {
+ Logger.Debug("After filtering and sorting service endpoints, this is the new prioritized list:");
+ Logger.Debug(Util.ToString(filteredEndpoints, true));
+ }
+ }
+
+ // If there are no endpoint candidates...
+ if (filteredEndpoints.Count == 0) {
+ return null;
+ }
+
+ // If we don't have an application store, we have no place to record an association to
+ // and therefore can only take our best shot at one of the endpoints.
+ if (relyingParty.Store == null) {
+ Logger.Debug("No state store, so the first endpoint available is selected.");
+ return filteredEndpoints[0];
+ }
+
+ // Go through each endpoint until we find one that we can successfully create
+ // an association with. This is our only hint about whether an OP is up and running.
+ // The idea here is that we don't want to redirect the user to a dead OP for authentication.
+ // If the user has multiple OPs listed in his/her XRDS document, then we'll go down the list
+ // and try each one until we find one that's good.
+ int winningEndpointIndex = 0;
+ foreach (ServiceEndpoint endpointCandidate in filteredEndpoints) {
+ winningEndpointIndex++;
+ // One weakness of this method is that an OP that's down, but with whom we already
+ // created an association in the past will still pass this "are you alive?" test.
+ Association association = getAssociation(relyingParty, endpointCandidate, true);
+ if (association != null) {
+ Logger.DebugFormat("Endpoint #{0} (1-based index) responded to an association request. Selecting that endpoint.", winningEndpointIndex);
+ // We have a winner!
+ return endpointCandidate;
+ }
+ }
+
+ // Since all OPs failed to form an association with us, just return the first endpoint
+ // and hope for the best.
+ Logger.Debug("All endpoints failed to respond to an association request. Selecting first endpoint to try to authenticate to.");
+ return endpoints[0];
+ }
+ static Association getAssociation(OpenIdRelyingParty relyingParty, ServiceEndpoint provider, bool createNewAssociationIfNeeded) {
+ if (relyingParty == null) throw new ArgumentNullException("relyingParty");
if (provider == null) throw new ArgumentNullException("provider");
- if (store == null) throw new ArgumentNullException("store");
- Association assoc = store.GetAssociation(provider.ProviderEndpoint);
+ // TODO: we need a way to lookup an association that fulfills a given set of security
+ // requirements. We may have a SHA-1 association and a SHA-256 association that need
+ // to be called for specifically. (a bizzare scenario, admittedly, making this low priority).
+ Association assoc = relyingParty.Store.GetAssociation(provider.ProviderEndpoint);
+
+ // If the returned association does not fulfill security requirements, ignore it.
+ if (assoc != null && !relyingParty.Settings.IsAssociationInPermittedRange(provider.Protocol, assoc.GetAssociationType(provider.Protocol))) {
+ assoc = null;
+ }
- if (assoc == null || !assoc.HasUsefulLifeRemaining) {
- var req = AssociateRequest.Create(provider);
+ if ((assoc == null || !assoc.HasUsefulLifeRemaining) && createNewAssociationIfNeeded) {
+ var req = AssociateRequest.Create(relyingParty, provider);
+ if (req == null) {
+ // this can happen if security requirements and protocol conflict
+ // to where there are no association types to choose from.
+ return null;
+ }
if (req.Response != null) {
// try again if we failed the first time and have a worthy second-try.
if (req.Response.Association == null && req.Response.SecondAttempt != null) {
- if (TraceUtil.Switch.TraceWarning) {
- Trace.TraceWarning("Initial association attempt failed, but will retry with Provider-suggested parameters.");
- }
+ Logger.Warn("Initial association attempt failed, but will retry with Provider-suggested parameters.");
req = req.Response.SecondAttempt;
}
assoc = req.Response.Association;
+ // Confirm that the association matches the type we requested (section 8.2.1)
+ // if this is a 2.0 OP (1.x OPs had freedom to differ from the requested type).
+ if (assoc != null && provider.Protocol.Version.Major >= 2) {
+ if (!string.Equals(
+ req.Args[provider.Protocol.openid.assoc_type],
+ Util.GetRequiredArg(req.Response.Args, provider.Protocol.openidnp.assoc_type),
+ StringComparison.Ordinal) ||
+ !string.Equals(
+ req.Args[provider.Protocol.openid.session_type],
+ Util.GetRequiredArg(req.Response.Args, provider.Protocol.openidnp.session_type),
+ StringComparison.Ordinal)) {
+ Logger.ErrorFormat("Provider responded with contradicting association parameters. Requested [{0}, {1}] but got [{2}, {3}] back.",
+ req.Args[provider.Protocol.openid.assoc_type],
+ req.Args[provider.Protocol.openid.session_type],
+ Util.GetRequiredArg(req.Response.Args, provider.Protocol.openidnp.assoc_type),
+ Util.GetRequiredArg(req.Response.Args, provider.Protocol.openidnp.session_type));
+
+ assoc = null;
+ }
+ }
if (assoc != null) {
- if (TraceUtil.Switch.TraceInfo)
- Trace.TraceInformation("Association with {0} established.", provider.ProviderEndpoint);
- store.StoreAssociation(provider.ProviderEndpoint, assoc);
+ Logger.InfoFormat("Association with {0} established.", provider.ProviderEndpoint);
+ relyingParty.Store.StoreAssociation(provider.ProviderEndpoint, assoc);
} else {
- if (TraceUtil.Switch.TraceError) {
- Trace.TraceError("Association attempt with {0} provider failed.", provider);
- }
+ Logger.ErrorFormat("Association attempt with {0} provider failed.", provider.ProviderEndpoint);
}
}
}
@@ -141,19 +256,31 @@ namespace DotNetOpenId.RelyingParty {
public AuthenticationRequestMode Mode { get; set; }
public Realm Realm { get; private set; }
public Uri ReturnToUrl { get; private set; }
- public Identifier ClaimedIdentifier { get { return endpoint.ClaimedIdentifier; } }
+ public Identifier ClaimedIdentifier {
+ get { return IsDirectedIdentity ? null : endpoint.ClaimedIdentifier; }
+ }
+ public bool IsDirectedIdentity {
+ get { return endpoint.ClaimedIdentifier == endpoint.Protocol.ClaimedIdentifierForOPIdentifier; }
+ }
/// <summary>
/// The detected version of OpenID implemented by the Provider.
/// </summary>
public Version ProviderVersion { get { return protocol.Version; } }
/// <summary>
- /// Gets the URL the user agent should be redirected to to begin the
+ /// Gets information about the OpenId Provider, as advertised by the
+ /// OpenId discovery documents found at the <see cref="ClaimedIdentifier"/>
+ /// location.
+ /// </summary>
+ IProviderEndpoint IAuthenticationRequest.Provider { get { return endpoint; } }
+
+ /// <summary>
+ /// Gets the response to send to the user agent to begin the
/// OpenID authentication process.
/// </summary>
public IResponse RedirectingResponse {
get {
UriBuilder returnToBuilder = new UriBuilder(ReturnToUrl);
- UriUtil.AppendQueryArgs(returnToBuilder, this.ReturnToArgs);
+ UriUtil.AppendAndReplaceQueryArgs(returnToBuilder, this.ReturnToArgs);
var qsArgs = new Dictionary<string, string>();
@@ -172,11 +299,11 @@ namespace DotNetOpenId.RelyingParty {
qsArgs.Add(protocol.openid.assoc_handle, this.assoc.Handle);
// Add on extension arguments
- foreach(var pair in OutgoingExtensions.GetArgumentsToSend(true))
+ foreach (var pair in OutgoingExtensions.GetArgumentsToSend(true))
qsArgs.Add(pair.Key, pair.Value);
var request = new IndirectMessageRequest(this.endpoint.ProviderEndpoint, qsArgs);
- return this.encoder.Encode(request);
+ return RelyingParty.Encoder.Encode(request);
}
}
@@ -214,7 +341,7 @@ namespace DotNetOpenId.RelyingParty {
/// This method requires an ASP.NET HttpContext.
/// </remarks>
public void RedirectToProvider() {
- if (HttpContext.Current == null || HttpContext.Current.Response == null)
+ if (HttpContext.Current == null || HttpContext.Current.Response == null)
throw new InvalidOperationException(Strings.CurrentHttpContextRequired);
RedirectingResponse.Send();
}