//-----------------------------------------------------------------------
// <copyright file="SigningBindingElementBase.cs" company="Andrew Arnott">
//     Copyright (c) Andrew Arnott. All rights reserved.
// </copyright>
//-----------------------------------------------------------------------

namespace DotNetOAuth.ChannelElements {
	using System;
	using System.Collections.Generic;
	using System.Globalization;
	using System.Text;
	using DotNetOAuth.Messaging;
	using DotNetOAuth.Messaging.Bindings;

	/// <summary>
	/// A binding element that signs outgoing messages and verifies the signature on incoming messages.
	/// </summary>
	public abstract class SigningBindingElementBase : ITamperProtectionChannelBindingElement {
		/// <summary>
		/// The signature method this binding element uses.
		/// </summary>
		private readonly string signatureMethod;

		/// <summary>
		/// Initializes a new instance of the <see cref="SigningBindingElementBase"/> class.
		/// </summary>
		/// <param name="signatureMethod">The OAuth signature method that the binding element uses.</param>
		internal SigningBindingElementBase(string signatureMethod) {
			this.signatureMethod = signatureMethod;
		}

		#region ITamperProtectionChannelBindingElement members

		/// <summary>
		/// Gets or sets the delegate that will initialize the non-serialized properties necessary on a signed
		/// message so that its signature can be correctly calculated for verification.
		/// </summary>
		public Action<ITamperResistantOAuthMessage> SignatureVerificationCallback { get; set; }

		#endregion

		#region IChannelBindingElement Members

		/// <summary>
		/// Gets the message protection provided by this binding element.
		/// </summary>
		public MessageProtection Protection {
			get { return MessageProtection.TamperProtection; }
		}

		/// <summary>
		/// Signs the outgoing message.
		/// </summary>
		/// <param name="message">The message to sign.</param>
		/// <returns>True if the message was signed.  False otherwise.</returns>
		public bool PrepareMessageForSending(IProtocolMessage message) {
			var signedMessage = message as ITamperResistantOAuthMessage;
			if (signedMessage != null && this.IsMessageApplicable(signedMessage)) {
				// The signature method must be set before signing, because it is itself
				// one of the message parameters that goes into the signature base string.
				signedMessage.SignatureMethod = this.signatureMethod;
				signedMessage.Signature = this.GetSignature(signedMessage);
				return true;
			}

			return false;
		}

		/// <summary>
		/// Verifies the signature on an incoming message.
		/// </summary>
		/// <param name="message">The message whose signature should be verified.</param>
		/// <returns>True if the signature was verified.  False if the message had no signature.</returns>
		/// <exception cref="InvalidSignatureException">Thrown if the signature is invalid.</exception>
		public bool PrepareMessageForReceiving(IProtocolMessage message) {
			var signedMessage = message as ITamperResistantOAuthMessage;
			if (signedMessage != null && this.IsMessageApplicable(signedMessage)) {
				if (!string.Equals(signedMessage.SignatureMethod, this.signatureMethod, StringComparison.Ordinal)) {
					Logger.WarnFormat("Expected signature method '{0}' but received message with a signature method of '{1}'.", this.signatureMethod, signedMessage.SignatureMethod);
					return false;
				}

				// Let the host populate any non-serialized properties the signature
				// calculation depends on before we recompute the signature.
				if (this.SignatureVerificationCallback != null) {
					this.SignatureVerificationCallback(signedMessage);
				} else {
					Logger.Warn("Signature verification required, but callback delegate was not provided to provide additional data for signing.");
				}

				string signature = this.GetSignature(signedMessage);

				// Compare in constant time: an early-exit comparison on attacker-supplied input
				// leaks, through response timing, how much of a forged signature was correct.
				if (!EqualsConstantTime(signedMessage.Signature, signature)) {
					Logger.Error("Signature verification failed.");
					throw new InvalidSignatureException(message);
				}

				return true;
			}

			return false;
		}

		#endregion

		/// <summary>
		/// Constructs the OAuth Signature Base String and returns the result.
		/// </summary>
		/// <param name="message">The message to derive the signature base string from.</param>
		/// <returns>The signature base string.</returns>
		/// <exception cref="ArgumentNullException">Thrown if <paramref name="message"/> is null.</exception>
		/// <exception cref="ArgumentException">Thrown if the message has no HttpMethod.</exception>
		/// <remarks>
		/// This method implements OAuth 1.0 section 9.1.
		/// </remarks>
		protected static string ConstructSignatureBaseString(ITamperResistantOAuthMessage message) {
			if (message == null) {
				throw new ArgumentNullException("message");
			}

			if (string.IsNullOrEmpty(message.HttpMethod)) {
				throw new ArgumentException(
					string.Format(
						CultureInfo.CurrentCulture,
						MessagingStrings.ArgumentPropertyMissing,
						typeof(ITamperResistantOAuthMessage).Name,
						"HttpMethod"),
					"message");
			}

			List<string> signatureBaseStringElements = new List<string>(3);

			// Element 1: the HTTP request method, upper-cased.
			signatureBaseStringElements.Add(message.HttpMethod.ToUpperInvariant());

			// Element 2: the request URL with its query string and fragment removed.
			UriBuilder endpoint = new UriBuilder(message.Recipient);
			endpoint.Query = null;
			endpoint.Fragment = null;
			signatureBaseStringElements.Add(endpoint.Uri.AbsoluteUri);

			// Element 3: all encoded parameters except the signature itself, sorted by
			// name and then by value, written as name=value pairs joined with '&'.
			var encodedDictionary = OAuthChannel.GetEncodedParameters(message);
			encodedDictionary.Remove("oauth_signature");
			var sortedKeyValueList = new List<KeyValuePair<string, string>>(encodedDictionary);
			sortedKeyValueList.Sort(SignatureBaseStringParameterComparer);
			StringBuilder paramBuilder = new StringBuilder();
			foreach (var pair in sortedKeyValueList) {
				if (paramBuilder.Length > 0) {
					paramBuilder.Append("&");
				}

				paramBuilder.Append(pair.Key);
				paramBuilder.Append('=');
				paramBuilder.Append(pair.Value);
			}

			signatureBaseStringElements.Add(paramBuilder.ToString());

			// Each element is percent-encoded and the elements are joined with '&'.
			// NOTE(review): on .NET Framework versions before 4.5, Uri.EscapeDataString follows
			// RFC 2396 and leaves characters such as ! * ' ( ) unescaped, while OAuth 1.0
			// requires RFC 3986 encoding. Confirm the target framework before relying on interop.
			StringBuilder signatureBaseString = new StringBuilder();
			foreach (string element in signatureBaseStringElements) {
				if (signatureBaseString.Length > 0) {
					signatureBaseString.Append("&");
				}

				signatureBaseString.Append(Uri.EscapeDataString(element));
			}

			return signatureBaseString.ToString();
		}

		/// <summary>
		/// Calculates a signature for a given message.
		/// </summary>
		/// <param name="message">The message to sign.</param>
		/// <returns>The signature for the message.</returns>
		protected abstract string GetSignature(ITamperResistantOAuthMessage message);

		/// <summary>
		/// Checks whether this binding element applies to this message.
		/// </summary>
		/// <param name="message">The message that needs to be signed.</param>
		/// <returns>True if this binding element can be used to sign the message.  False otherwise.</returns>
		/// <remarks>
		/// A message with no signature method yet is considered applicable, as is one whose
		/// signature method exactly (ordinally) matches the one this element uses.
		/// </remarks>
		protected virtual bool IsMessageApplicable(ITamperResistantOAuthMessage message) {
			return string.IsNullOrEmpty(message.SignatureMethod)
				|| string.Equals(message.SignatureMethod, this.signatureMethod, StringComparison.Ordinal);
		}

		/// <summary>
		/// Gets the "ConsumerSecret&amp;TokenSecret" string, allowing either property to be empty or null.
		/// </summary>
		/// <param name="message">The message to extract the secrets from.</param>
		/// <returns>The concatenated string.</returns>
		/// <remarks>
		/// Each secret is escaped with <see cref="Uri.EscapeDataString"/>. The '&amp;' separator
		/// is always present, even when one or both secrets are missing.
		/// </remarks>
		protected string GetConsumerAndTokenSecretString(ITamperResistantOAuthMessage message) {
			StringBuilder builder = new StringBuilder();
			if (!string.IsNullOrEmpty(message.ConsumerSecret)) {
				builder.Append(Uri.EscapeDataString(message.ConsumerSecret));
			}

			builder.Append("&");
			if (!string.IsNullOrEmpty(message.TokenSecret)) {
				builder.Append(Uri.EscapeDataString(message.TokenSecret));
			}

			return builder.ToString();
		}

		/// <summary>
		/// Sorts parameters according to OAuth signature base string rules.
		/// </summary>
		/// <param name="left">The first parameter to compare.</param>
		/// <param name="right">The second parameter to compare.</param>
		/// <returns>Negative, zero or positive.</returns>
		private static int SignatureBaseStringParameterComparer(KeyValuePair<string, string> left, KeyValuePair<string, string> right) {
			int result = string.CompareOrdinal(left.Key, right.Key);
			if (result != 0) {
				return result;
			}

			return string.CompareOrdinal(left.Value, right.Value);
		}

		/// <summary>
		/// Compares two strings for ordinal equality in time that depends only on their length,
		/// not on the position of the first differing character.
		/// </summary>
		/// <param name="first">The first string; may be null.</param>
		/// <param name="second">The second string; may be null.</param>
		/// <returns>True if both are null, or both are non-null with identical characters.  False otherwise.</returns>
		private static bool EqualsConstantTime(string first, string second) {
			if (first == null || second == null) {
				// Matches the original != semantics: equal only when both are null.
				return first == null && second == null;
			}

			// Signature length is not secret, so an early exit on length is acceptable.
			if (first.Length != second.Length) {
				return false;
			}

			int difference = 0;
			for (int i = 0; i < first.Length; i++) {
				difference |= first[i] ^ second[i];
			}

			return difference == 0;
		}
	}
}