//-----------------------------------------------------------------------
//
//     Copyright (c) Outercurve Foundation. All rights reserved.
//
//-----------------------------------------------------------------------

namespace DotNetOpenAuth.OpenId.ChannelElements {
	using System;
	using System.Collections.Generic;
	using System.Diagnostics;
	using System.Diagnostics.Contracts;
	using System.Globalization;
	using System.Linq;
	using System.Net.Security;
	using System.Web;
	using DotNetOpenAuth.Loggers;
	using DotNetOpenAuth.Messaging;
	using DotNetOpenAuth.Messaging.Bindings;
	using DotNetOpenAuth.Messaging.Reflection;
	using DotNetOpenAuth.OpenId.Messages;

	/// <summary>
	/// Signs and verifies authentication assertions.
	/// </summary>
	[ContractClass(typeof(SigningBindingElementContract))]
	internal abstract class SigningBindingElement : IChannelBindingElement {
		#region IChannelBindingElement Properties

		/// <summary>
		/// Gets the protection offered (if any) by this binding element.
		/// </summary>
		/// <value>Always <see cref="MessageProtections.TamperProtection"/>.</value>
		public MessageProtections Protection {
			get { return MessageProtections.TamperProtection; }
		}

		/// <summary>
		/// Gets or sets the channel that this binding element belongs to.
		/// </summary>
		public Channel Channel { get; set; }

		#endregion

		/// <summary>
		/// Gets a value indicating whether this binding element is on a Provider channel.
		/// </summary>
		/// <value><c>false</c> in this base class; derived classes may override it.</value>
		protected virtual bool IsOnProvider {
			get { return false; }
		}

		#region IChannelBindingElement Methods

		/// <summary>
		/// Prepares a message for sending based on the rules of this channel binding element.
		/// </summary>
		/// <param name="message">The message to prepare for sending.</param>
		/// <returns>
		/// The protections (if any) that this binding element applied to the message.
		/// Null if this binding element did not even apply to this message.
		/// </returns>
		/// <remarks>
		/// This base implementation does nothing and always returns <c>null</c>.
		/// </remarks>
		public virtual MessageProtections? ProcessOutgoingMessage(IProtocolMessage message) {
			return null;
		}

		/// <summary>
		/// Performs any transformation on an incoming message that may be necessary and/or
		/// validates an incoming message based on the rules of this channel binding element.
		/// </summary>
		/// <param name="message">The incoming message to process.</param>
		/// <returns>
		/// The protections (if any) that this binding element applied to the message.
		/// Null if this binding element did not even apply to this message
		/// (i.e. it is not an <see cref="ITamperResistantOpenIdMessage"/>).
		/// </returns>
		/// <exception cref="InvalidSignatureException">
		/// Thrown when the message's signature does not match the signature computed
		/// from the association referenced by the message.
		/// </exception>
		/// <remarks>
		/// Other rule violations are rejected as well. One example is a part that requires
		/// protection but is missing from the signed parameter list, which is rejected via
		/// <c>ErrorUtilities.VerifyProtocol</c>. In every such case the message should NOT
		/// be processed.
		/// </remarks>
		public MessageProtections? ProcessIncomingMessage(IProtocolMessage message) {
			var signedMessage = message as ITamperResistantOpenIdMessage;
			if (signedMessage == null) {
				// Not a signable OpenID message: this binding element does not apply.
				return null;
			}

			Logger.Bindings.DebugFormat("Verifying incoming {0} message signature of: {1}", message.GetType().Name, signedMessage.Signature);

			// The signed-part check below (and signature computation) dereference this.Channel,
			// so verify it up front instead of failing later with a NullReferenceException.
			ErrorUtilities.VerifyInternal(this.Channel != null, "Cannot verify incoming message signature because we don't have a channel.");

			MessageProtections protectionsApplied = MessageProtections.TamperProtection;
			this.EnsureParametersRequiringSignatureAreSigned(signedMessage);

			Association association = this.GetSpecificAssociation(signedMessage);
			if (association != null) {
				string signature = this.GetSignature(signedMessage, association);

				// Constant-time comparison so timing does not reveal how much of a forged signature matched.
				if (!MessagingUtilities.EqualsConstantTime(signedMessage.Signature, signature)) {
					Logger.Bindings.Error("Signature verification failed.");
					throw new InvalidSignatureException(message);
				}
			} else {
				// The association handle was not recognized; derived classes decide how to verify it.
				protectionsApplied = this.VerifySignatureByUnrecognizedHandle(message, signedMessage, protectionsApplied);
			}

			return protectionsApplied;
		}

		/// <summary>
		/// Verifies the signature of a message whose association handle did not resolve
		/// to a known association.
		/// </summary>
		/// <param name="message">The message.</param>
		/// <param name="signedMessage">The signed message.</param>
		/// <param name="protectionsApplied">The protections applied so far.</param>
		/// <returns>The applied protections.</returns>
		protected abstract MessageProtections VerifySignatureByUnrecognizedHandle(IProtocolMessage message, ITamperResistantOpenIdMessage signedMessage, MessageProtections protectionsApplied);

		#endregion

		/// <summary>
		/// Calculates the signature for a given message.
		/// </summary>
		/// <param name="signedMessage">The message to sign or verify.</param>
		/// <param name="association">The association to use to sign the message.</param>
		/// <returns>The calculated signature of the message, Base64-encoded.</returns>
		protected string GetSignature(ITamperResistantOpenIdMessage signedMessage, Association association) {
			Requires.NotNull(signedMessage, "signedMessage");
			Requires.True(!string.IsNullOrEmpty(signedMessage.SignedParameterOrder), "signedMessage");
			Requires.NotNull(association, "association");

			// Prepare the parts to sign, taking care to replace an openid.mode value
			// of check_authentication with its original id_res so the signature matches.
			MessageDictionary dictionary = this.Channel.MessageDescriptions.GetAccessor(signedMessage);

			// SignedParameterOrder holds comma-separated names without the "openid." prefix.
			// Materialize once: the list is consumed for signing and again for debug logging.
			List<KeyValuePair<string, string>> parametersToSign =
				(from name in signedMessage.SignedParameterOrder.Split(',')
				 let prefixedName = Protocol.V20.openid.Prefix + name
				 select new KeyValuePair<string, string>(name, dictionary.GetValueOrThrow(prefixedName, signedMessage))).ToList();

			byte[] dataToSign = KeyValueFormEncoding.GetBytes(parametersToSign);
			string signature = Convert.ToBase64String(association.Sign(dataToSign));

			if (Logger.Signatures.IsDebugEnabled) {
				Logger.Signatures.DebugFormat(
					"Signing these message parts: {0}{1}{0}Base64 representation of signed data: {2}{0}Signature: {3}",
					Environment.NewLine,
					parametersToSign.ToStringDeferred(),
					Convert.ToBase64String(dataToSign),
					signature);
			}

			return signature;
		}

		/// <summary>
		/// Gets the association to use to sign or verify a message.
		/// </summary>
		/// <param name="signedMessage">The message to sign or verify.</param>
		/// <returns>The association to use to sign or verify the message.</returns>
		protected abstract Association GetAssociation(ITamperResistantOpenIdMessage signedMessage);

		/// <summary>
		/// Gets a specific association referenced in a given message's association handle.
		/// </summary>
		/// <param name="signedMessage">The signed message whose association handle should be used to lookup the association to return.</param>
		/// <returns>The referenced association; or null if such an association cannot be found.</returns>
		/// <remarks>
		/// If the association handle set in the message does not match any valid association,
		/// the association handle property is cleared, and the message's invalidate-handle
		/// property is set to the handle that could not be found.
		/// (NOTE(review): the exact member names were lost from this comment; confirm against
		/// <see cref="ITamperResistantOpenIdMessage"/>.)
		/// </remarks>
		protected abstract Association GetSpecificAssociation(ITamperResistantOpenIdMessage signedMessage);

		/// <summary>
		/// Gets a private Provider association used for signing messages in "dumb" mode.
		/// </summary>
		/// <returns>An existing or newly created association.</returns>
		/// <remarks>
		/// This base implementation throws <see cref="NotImplementedException"/>;
		/// derived classes that support dumb-mode signing must override it.
		/// </remarks>
		protected virtual Association GetDumbAssociationForSigning() {
			throw new NotImplementedException();
		}

		/// <summary>
		/// Ensures that all message parameters that must be signed are in fact included
		/// in the signature.
		/// </summary>
		/// <param name="signedMessage">The signed message.</param>
		private void EnsureParametersRequiringSignatureAreSigned(ITamperResistantOpenIdMessage signedMessage) {
			// Verify that the signed parameter order includes the mandated fields.
			// We do this in such a way that derived classes that add mandated fields automatically
			// get included in the list of checked parameters.
			Protocol protocol = Protocol.Lookup(signedMessage.Version);
			string prefix = protocol.openid.Prefix;

			// Materialize once; the parts are checked for the prefix and then for membership.
			string[] partsRequiringProtection =
				(from part in this.Channel.MessageDescriptions.Get(signedMessage).Mapping.Values
				 where part.RequiredProtection != ProtectionLevel.None
				 where part.IsRequired || part.IsNondefaultValueSet(signedMessage)
				 select part.Name).ToArray();

			ErrorUtilities.VerifyInternal(partsRequiringProtection.All(name => name.StartsWith(prefix, StringComparison.Ordinal)), "Signing only works when the parameters start with the 'openid.' prefix.");

			// SignedParameterOrder lists names without the prefix; a set gives O(1) lookups.
			var signedParts = new HashSet<string>(signedMessage.SignedParameterOrder.Split(','));
			string[] unsignedParts = partsRequiringProtection
				.Where(partName => !signedParts.Contains(partName.Substring(prefix.Length)))
				.ToArray();

			ErrorUtilities.VerifyProtocol(unsignedParts.Length == 0, OpenIdStrings.SignatureDoesNotIncludeMandatoryParts, string.Join(", ", unsignedParts));
		}
	}
}