//-----------------------------------------------------------------------
//
// Copyright (c) Outercurve Foundation. All rights reserved.
//
//-----------------------------------------------------------------------
namespace DotNetOpenAuth.OpenId.ChannelElements {
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.Contracts;
using System.Globalization;
using System.Linq;
using System.Net.Security;
using System.Web;
using DotNetOpenAuth.Loggers;
using DotNetOpenAuth.Messaging;
using DotNetOpenAuth.Messaging.Bindings;
using DotNetOpenAuth.Messaging.Reflection;
using DotNetOpenAuth.OpenId.Messages;
/// <summary>
/// Signs and verifies authentication assertions.
/// </summary>
[ContractClass(typeof(SigningBindingElementContract))]
internal abstract class SigningBindingElement : IChannelBindingElement {
    #region IChannelBindingElement Properties

    /// <summary>
    /// Gets the protection offered (if any) by this binding element.
    /// </summary>
    /// <value>Always <see cref="MessageProtections.TamperProtection"/>.</value>
    public MessageProtections Protection {
        get { return MessageProtections.TamperProtection; }
    }

    /// <summary>
    /// Gets or sets the channel that this binding element belongs to.
    /// </summary>
    /// <remarks>
    /// Must be set before incoming messages are processed or signatures are calculated,
    /// because message descriptions are obtained from the channel.
    /// </remarks>
    public Channel Channel { get; set; }

    #endregion

    /// <summary>
    /// Gets a value indicating whether this binding element is on a Provider channel.
    /// </summary>
    /// <value><c>false</c> unless overridden by a derived class.</value>
    protected virtual bool IsOnProvider {
        get { return false; }
    }

    #region IChannelBindingElement Methods

    /// <summary>
    /// Prepares a message for sending based on the rules of this channel binding element.
    /// </summary>
    /// <param name="message">The message to prepare for sending.</param>
    /// <returns>
    /// The protections (if any) that this binding element applied to the message.
    /// Null if this binding element did not even apply to this message.
    /// </returns>
    /// <remarks>
    /// This base implementation does nothing and returns <c>null</c>; derived classes
    /// may override it to sign outgoing messages.
    /// </remarks>
    public virtual MessageProtections? ProcessOutgoingMessage(IProtocolMessage message) {
        return null;
    }

    /// <summary>
    /// Performs any transformation on an incoming message that may be necessary and/or
    /// validates an incoming message based on the rules of this channel binding element.
    /// </summary>
    /// <param name="message">The incoming message to process.</param>
    /// <returns>
    /// The protections (if any) that this binding element applied to the message.
    /// Null if this binding element did not even apply to this message (i.e. the message
    /// is not an <see cref="ITamperResistantOpenIdMessage"/>).
    /// </returns>
    /// <exception cref="InvalidSignatureException">
    /// Thrown when the message references a known association but its signature does not
    /// match the signature recomputed with that association.
    /// </exception>
    /// <remarks>
    /// Processing also fails when a parameter that requires protection is not covered
    /// by the message's signed parameter list, or when the message's association handle
    /// is unrecognized and the derived class's verification rejects it.
    /// </remarks>
    public MessageProtections? ProcessIncomingMessage(IProtocolMessage message) {
        var signedMessage = message as ITamperResistantOpenIdMessage;
        if (signedMessage == null) {
            // Unsigned message types are outside this binding element's responsibility.
            return null;
        }

        Logger.Bindings.DebugFormat("Verifying incoming {0} message signature of: {1}", message.GetType().Name, signedMessage.Signature);

        // Every path below reads message descriptions from the channel, so verify it is present
        // before the first dereference instead of failing later with a NullReferenceException.
        ErrorUtilities.VerifyInternal(this.Channel != null, "Cannot verify message signature because we don't have a channel.");

        MessageProtections protectionsApplied = MessageProtections.TamperProtection;
        this.EnsureParametersRequiringSignatureAreSigned(signedMessage);

        Association association = this.GetSpecificAssociation(signedMessage);
        if (association == null) {
            // The association handle is not one we know; the derived class decides how to verify it.
            return this.VerifySignatureByUnrecognizedHandle(message, signedMessage, protectionsApplied);
        }

        string signature = this.GetSignature(signedMessage, association);

        // Constant-time comparison so verification timing does not reveal how much of the
        // expected signature matched.
        if (!MessagingUtilities.EqualsConstantTime(signedMessage.Signature, signature)) {
            Logger.Bindings.Error("Signature verification failed.");
            throw new InvalidSignatureException(message);
        }

        return protectionsApplied;
    }

    #endregion

    /// <summary>
    /// Verifies the signature of a message whose association handle does not resolve
    /// to a known association.
    /// </summary>
    /// <param name="message">The incoming message.</param>
    /// <param name="signedMessage">The same message, viewed as a signed message.</param>
    /// <param name="protectionsApplied">
    /// The protections applied so far (<see cref="MessageProtections.TamperProtection"/>
    /// when called from <see cref="ProcessIncomingMessage"/>).
    /// </param>
    /// <returns>
    /// The protections applied to the message; this value is returned as-is from
    /// <see cref="ProcessIncomingMessage"/>.
    /// </returns>
    protected abstract MessageProtections VerifySignatureByUnrecognizedHandle(IProtocolMessage message, ITamperResistantOpenIdMessage signedMessage, MessageProtections protectionsApplied);

    /// <summary>
    /// Calculates the signature for a given message.
    /// </summary>
    /// <param name="signedMessage">The message to sign or verify.</param>
    /// <param name="association">The association to use to sign the message.</param>
    /// <returns>The Base64-encoded signature computed over the message's signed parameters.</returns>
    /// <remarks>
    /// The parameters named in the message's signed parameter order (comma-separated, without
    /// the "openid." prefix) are key-value-form encoded in that order and signed with
    /// <paramref name="association"/>.
    /// </remarks>
    protected string GetSignature(ITamperResistantOpenIdMessage signedMessage, Association association) {
        Requires.NotNull(signedMessage, "signedMessage");
        Requires.True(!string.IsNullOrEmpty(signedMessage.SignedParameterOrder), "signedMessage");
        Requires.NotNull(association, "association");

        // Prepare the parts to sign. For check_authentication messages the signature must match the
        // original id_res assertion. NOTE(review): no openid.mode substitution happens in this method,
        // so presumably the message accessor already reports the original mode; confirm.
        MessageDictionary dictionary = this.Channel.MessageDescriptions.GetAccessor(signedMessage);

        // Materialized once: the pairs are consumed both by the encoder and by debug logging.
        List<KeyValuePair<string, string>> parametersToSign =
            (from name in signedMessage.SignedParameterOrder.Split(',')
             let prefixedName = Protocol.V20.openid.Prefix + name
             select new KeyValuePair<string, string>(name, dictionary.GetValueOrThrow(prefixedName, signedMessage))).ToList();

        byte[] dataToSign = KeyValueFormEncoding.GetBytes(parametersToSign);
        string signature = Convert.ToBase64String(association.Sign(dataToSign));

        if (Logger.Signatures.IsDebugEnabled) {
            Logger.Signatures.DebugFormat(
                "Signing these message parts: {0}{1}{0}Base64 representation of signed data: {2}{0}Signature: {3}",
                Environment.NewLine,
                parametersToSign.ToStringDeferred(),
                Convert.ToBase64String(dataToSign),
                signature);
        }

        return signature;
    }

    /// <summary>
    /// Gets the association to use to sign or verify a message.
    /// </summary>
    /// <param name="signedMessage">The message to sign or verify.</param>
    /// <returns>The association to use to sign or verify the message.</returns>
    protected abstract Association GetAssociation(ITamperResistantOpenIdMessage signedMessage);

    /// <summary>
    /// Gets a specific association referenced in a given message's association handle.
    /// </summary>
    /// <param name="signedMessage">The signed message whose association handle should be used to look up the association to return.</param>
    /// <returns>The referenced association; or <c>null</c> if such an association cannot be found.</returns>
    /// <remarks>
    /// If the association handle set in the message does not match any valid association,
    /// the message's association handle property is cleared, and the message property that
    /// reports invalidated handles (presumably InvalidateHandle; confirm) is set to the
    /// handle that could not be found.
    /// </remarks>
    protected abstract Association GetSpecificAssociation(ITamperResistantOpenIdMessage signedMessage);

    /// <summary>
    /// Gets a private Provider association used for signing messages in "dumb" mode.
    /// </summary>
    /// <returns>An existing or newly created association.</returns>
    /// <exception cref="NotImplementedException">
    /// Always thrown by this base implementation; derived classes that sign in dumb mode must override it.
    /// </exception>
    protected virtual Association GetDumbAssociationForSigning() {
        throw new NotImplementedException();
    }

    /// <summary>
    /// Ensures that all message parameters that must be signed are in fact included
    /// in the signature.
    /// </summary>
    /// <param name="signedMessage">The signed message.</param>
    /// <remarks>
    /// A parameter must be signed when its message part requires protection and it is either
    /// required or carries a non-default value. The check fails via
    /// <c>ErrorUtilities.VerifyProtocol</c>, listing every such parameter missing from the
    /// signed parameter order.
    /// </remarks>
    private void EnsureParametersRequiringSignatureAreSigned(ITamperResistantOpenIdMessage signedMessage) {
        // Verify that the signed parameter order includes the mandated fields.
        // We do this in such a way that derived classes that add mandated fields automatically
        // get included in the list of checked parameters.
        Protocol protocol = Protocol.Lookup(signedMessage.Version);
        string prefix = protocol.openid.Prefix;

        // Materialized once: the list is enumerated by both checks below.
        List<string> partsRequiringProtection =
            (from part in this.Channel.MessageDescriptions.Get(signedMessage).Mapping.Values
             where part.RequiredProtection != ProtectionLevel.None
             where part.IsRequired || part.IsNondefaultValueSet(signedMessage)
             select part.Name).ToList();
        ErrorUtilities.VerifyInternal(partsRequiringProtection.All(name => name.StartsWith(prefix, StringComparison.Ordinal)), "Signing only works when the parameters start with the 'openid.' prefix.");

        // Signed parameter names are listed without the "openid." prefix; ordinal matching
        // mirrors the default string equality used by Array.Contains.
        var signedParts = new HashSet<string>(signedMessage.SignedParameterOrder.Split(','), StringComparer.Ordinal);
        string[] unsignedParts = partsRequiringProtection
            .Where(partName => !signedParts.Contains(partName.Substring(prefix.Length)))
            .ToArray();
        ErrorUtilities.VerifyProtocol(unsignedParts.Length == 0, OpenIdStrings.SignatureDoesNotIncludeMandatoryParts, string.Join(", ", unsignedParts));
    }
}
}