//-----------------------------------------------------------------------
//
// Copyright (c) Outercurve Foundation. All rights reserved.
//
//-----------------------------------------------------------------------
namespace DotNetOpenAuth.OpenId.ChannelElements {
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using DotNetOpenAuth.Messaging;
using DotNetOpenAuth.Messaging.Reflection;
using DotNetOpenAuth.OpenId.Extensions;
using DotNetOpenAuth.OpenId.Messages;
using Validation;
/// <summary>
/// The binding element that serializes/deserializes OpenID extensions to/from
/// their carrying OpenID messages.
/// </summary>
internal class ExtensionsBindingElement : IChannelBindingElement {
    /// <summary>
    /// A cached, already-completed task whose result is <c>null</c>; returned when the message
    /// does not implement <see cref="IProtocolMessageWithExtensions"/>, i.e. this element does not apply.
    /// </summary>
    private static readonly Task<MessageProtections?> NullTask = Task.FromResult<MessageProtections?>(null);

    /// <summary>
    /// A cached, already-completed task whose result is <see cref="MessageProtections.None"/>;
    /// returned when this element processed a message's extensions without applying any protection.
    /// </summary>
    private static readonly Task<MessageProtections?> NoneTask = Task.FromResult<MessageProtections?>(MessageProtections.None);

    /// <summary>
    /// False if unsigned extensions should be dropped. Must always be true on Providers, since RPs never sign extensions.
    /// </summary>
    private readonly bool receiveUnsignedExtensions;

    /// <summary>
    /// Initializes a new instance of the <see cref="ExtensionsBindingElement"/> class.
    /// </summary>
    /// <param name="extensionFactory">The extension factory used to instantiate incoming extensions.</param>
    /// <param name="securitySettings">The security settings. Checked for null; not otherwise used by this class.</param>
    /// <param name="receiveUnsignedExtensions">Security setting for relying parties. Should be true for Providers.</param>
    internal ExtensionsBindingElement(IOpenIdExtensionFactory extensionFactory, SecuritySettings securitySettings, bool receiveUnsignedExtensions) {
        Requires.NotNull(extensionFactory, "extensionFactory");
        Requires.NotNull(securitySettings, "securitySettings");

        this.ExtensionFactory = extensionFactory;
        this.receiveUnsignedExtensions = receiveUnsignedExtensions;
    }

    #region IChannelBindingElement Members

    /// <summary>
    /// Gets or sets the channel that this binding element belongs to.
    /// </summary>
    /// <remarks>
    /// This property is set by the channel when it is first constructed.
    /// </remarks>
    public Channel Channel { get; set; }

    /// <summary>
    /// Gets the extension factory used to instantiate incoming extensions.
    /// </summary>
    public IOpenIdExtensionFactory ExtensionFactory { get; private set; }

    /// <summary>
    /// Gets the protection offered (if any) by this binding element.
    /// </summary>
    /// <value>Always <see cref="MessageProtections.None"/>.</value>
    public MessageProtections Protection {
        get { return MessageProtections.None; }
    }

    /// <summary>
    /// Prepares a message for sending by serializing each of its OpenID extensions
    /// into extra parameters on the outgoing message.
    /// </summary>
    /// <param name="message">The message to prepare for sending.</param>
    /// <param name="cancellationToken">The cancellation token. Not observed by this implementation.</param>
    /// <returns>
    /// The protections (if any) that this binding element applied to the message.
    /// Null if this binding element did not even apply to this message.
    /// </returns>
    /// <remarks>
    /// Implementations that provide message protection must honor the protection
    /// requirements of message parts where applicable. This element applies none.
    /// </remarks>
    [SuppressMessage("Microsoft.Maintainability", "CA1506:AvoidExcessiveClassCoupling", Justification = "It doesn't look too bad to me. :)")]
    public Task<MessageProtections?> ProcessOutgoingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) {
        var extendableMessage = message as IProtocolMessageWithExtensions;
        if (extendableMessage == null) {
            // Only messages able to carry extensions are of interest here.
            return NullTask;
        }

        // Same precondition as GetExtensionsDictionary: we need the channel's message descriptions.
        RequiresEx.ValidState(this.Channel != null);

        Protocol protocol = Protocol.Lookup(message.Version);
        MessageDictionary baseMessageDictionary = this.Channel.MessageDescriptions.GetAccessor(message);

        // We have a helper class that will do all the heavy-lifting of organizing
        // all the extensions, their aliases, and their parameters.
        var extensionManager = ExtensionArgumentsManager.CreateOutgoingExtensions(protocol);
        foreach (IExtensionMessage protocolExtension in extendableMessage.Extensions) {
            var extension = protocolExtension as IOpenIdMessageExtension;
            if (extension == null) {
                // Non-OpenID extensions cannot be serialized by this element; skip them with a warning.
                Logger.OpenId.WarnFormat("Unexpected extension type {0} did not implement {1}.", protocolExtension.GetType(), typeof(IOpenIdMessageExtension).Name);
                continue;
            }

            Reporting.RecordFeatureUse(protocolExtension);

            // Give extensions that require custom serialization a chance to do their work.
            var customSerializingExtension = extension as IMessageWithEvents;
            if (customSerializingExtension != null) {
                customSerializingExtension.OnSending();
            }

            // OpenID 2.0 Section 12 forbids two extensions with the same TypeURI in the same message.
            ErrorUtilities.VerifyProtocol(!extensionManager.ContainsExtension(extension.TypeUri), OpenIdStrings.ExtensionAlreadyAddedWithSameTypeURI, extension.TypeUri);

            // Ensure that we're sending out a valid extension.
            var extensionDescription = this.Channel.MessageDescriptions.Get(extension);
            var extensionDictionary = extensionDescription.GetDictionary(extension).Serialize();
            extensionDescription.EnsureMessagePartsPassBasicValidation(extensionDictionary);

            // Add the extension to the outgoing message payload.
            extensionManager.AddExtensionArguments(extension.TypeUri, extensionDictionary);
        }

        // We use a cheap trick (for now at least) to determine whether the 'openid.' prefix
        // belongs on the parameters by just looking at what other parameters do.
        // Technically, direct message responses from Provider to Relying Party are the only
        // messages that leave off the 'openid.' prefix.
        bool includeOpenIdPrefix = baseMessageDictionary.Keys.Any(key => key.StartsWith(protocol.openid.Prefix, StringComparison.Ordinal));

        // Add the extension parameters to the base message for transmission.
        baseMessageDictionary.AddExtraParameters(extensionManager.GetArgumentsToSend(includeOpenIdPrefix));
        return NoneTask;
    }

    /// <summary>
    /// Deserializes the OpenID extensions carried by an incoming message and adds them
    /// to the message's <c>Extensions</c> collection, flagging each as signed or unsigned.
    /// </summary>
    /// <param name="message">The incoming message to process.</param>
    /// <param name="cancellationToken">The cancellation token. Not observed by this implementation.</param>
    /// <returns>
    /// The protections (if any) that this binding element applied to the message.
    /// Null if this binding element did not even apply to this message.
    /// </returns>
    /// <exception cref="ProtocolException">
    /// Thrown when the binding element rules indicate that this message is invalid and should
    /// NOT be processed.
    /// </exception>
    /// <remarks>
    /// Implementations that provide message protection must honor the protection
    /// requirements of message parts where applicable. This element applies none.
    /// </remarks>
    public Task<MessageProtections?> ProcessIncomingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) {
        var extendableMessage = message as IProtocolMessageWithExtensions;
        if (extendableMessage == null) {
            return NullTask;
        }

        // First add the extensions that are signed by the Provider.
        foreach (IOpenIdMessageExtension signedExtension in this.GetExtensions(extendableMessage, true, null)) {
            Reporting.RecordFeatureUse(signedExtension);
            signedExtension.IsSignedByRemoteParty = true;
            extendableMessage.Extensions.Add(signedExtension);
        }

        // Now search again, considering ALL extensions whether they are signed or not,
        // skipping the signed ones and adding the new ones as unsigned extensions.
        if (this.receiveUnsignedExtensions) {
            // OfType (not Cast) so that a non-OpenID IExtensionMessage already present in the
            // collection is simply ignored rather than causing an InvalidCastException.
            Func<string, bool> isNotSigned = typeUri => !extendableMessage.Extensions.OfType<IOpenIdMessageExtension>().Any(ext => ext.TypeUri == typeUri);
            foreach (IOpenIdMessageExtension unsignedExtension in this.GetExtensions(extendableMessage, false, isNotSigned)) {
                Reporting.RecordFeatureUse(unsignedExtension);
                unsignedExtension.IsSignedByRemoteParty = false;
                extendableMessage.Extensions.Add(unsignedExtension);
            }
        }

        return NoneTask;
    }

    #endregion

    /// <summary>
    /// Gets the extensions on a message.
    /// </summary>
    /// <param name="message">The carrier of the extensions.</param>
    /// <param name="ignoreUnsigned">If set to <c>true</c> only signed extensions will be available.</param>
    /// <param name="extensionFilter">An optional filter that takes an extension type URI and
    /// returns a value indicating whether that extension should be deserialized and
    /// returned in the sequence. May be null.</param>
    /// <returns>
    /// A lazily-evaluated sequence of extensions in the message. Extensions whose type URI is not
    /// recognized by <see cref="ExtensionFactory"/>, or that fail validation, are logged and omitted.
    /// </returns>
    private IEnumerable<IOpenIdMessageExtension> GetExtensions(IProtocolMessageWithExtensions message, bool ignoreUnsigned, Func<string, bool> extensionFilter) {
        bool isAtProvider = message is SignedResponseRequest;

        // We have a helper class that will do all the heavy-lifting of organizing
        // all the extensions, their aliases, and their parameters.
        var extensionManager = ExtensionArgumentsManager.CreateIncomingExtensions(this.GetExtensionsDictionary(message, ignoreUnsigned));
        foreach (string typeUri in extensionManager.GetExtensionTypeUris()) {
            // Our caller may have already obtained a signed version of this extension,
            // so skip it if they don't want this one.
            if (extensionFilter != null && !extensionFilter(typeUri)) {
                continue;
            }

            var extensionData = extensionManager.GetExtensionArguments(typeUri);

            // Initialize this particular extension.
            IOpenIdMessageExtension extension = this.ExtensionFactory.Create(typeUri, extensionData, message, isAtProvider);
            if (extension == null) {
                Logger.OpenId.DebugFormat("Extension with type URI '{0}' ignored because it is not a recognized extension.", typeUri);
                continue;
            }

            // C# forbids 'yield return' inside a try block that has a catch clause,
            // so failures are recorded by nulling 'extension' and the yield happens afterward.
            try {
                // Make sure the extension fulfills spec requirements before deserializing it.
                MessageDescription messageDescription = this.Channel.MessageDescriptions.Get(extension);
                messageDescription.EnsureMessagePartsPassBasicValidation(extensionData);

                // Deserialize the extension.
                MessageDictionary extensionDictionary = messageDescription.GetDictionary(extension);
                foreach (var pair in extensionData) {
                    extensionDictionary[pair.Key] = pair.Value;
                }

                // Give extensions that require custom serialization a chance to do their work.
                var customSerializingExtension = extension as IMessageWithEvents;
                if (customSerializingExtension != null) {
                    customSerializingExtension.OnReceiving();
                }
            } catch (ProtocolException ex) {
                // A malformed extension is dropped rather than failing the whole message.
                Logger.OpenId.ErrorFormat(OpenIdStrings.BadExtension, extension.GetType(), ex);
                extension = null;
            }

            if (extension != null) {
                yield return extension;
            }
        }
    }

    /// <summary>
    /// Gets the dictionary of message parts that should be deserialized into extensions.
    /// </summary>
    /// <param name="message">The message.</param>
    /// <param name="ignoreUnsigned">If set to <c>true</c> only signed extensions will be available.</param>
    /// <returns>
    /// A dictionary of message parts, including only signed parts when <paramref name="ignoreUnsigned"/>
    /// is <c>true</c> and the message is an <see cref="IndirectSignedResponse"/>; otherwise all parts.
    /// </returns>
    private IDictionary<string, string> GetExtensionsDictionary(IProtocolMessage message, bool ignoreUnsigned) {
        RequiresEx.ValidState(this.Channel != null);

        IndirectSignedResponse signedResponse = message as IndirectSignedResponse;
        if (signedResponse != null && ignoreUnsigned) {
            return signedResponse.GetSignedMessageParts(this.Channel);
        } else {
            return this.Channel.MessageDescriptions.GetAccessor(message);
        }
    }
}
}