//-----------------------------------------------------------------------
//
// Copyright (c) Andrew Arnott. All rights reserved.
//
//-----------------------------------------------------------------------
namespace DotNetOAuth.ChannelElements {
using System;
using System.Collections.Generic;
using System.Globalization;
using System.Text;
using DotNetOAuth.Messaging;
using DotNetOAuth.Messaging.Bindings;
/// <summary>
/// A binding element that signs outgoing messages and verifies the signature on incoming messages.
/// </summary>
public abstract class SigningBindingElementBase : ITamperProtectionChannelBindingElement {
    /// <summary>
    /// Characters that RFC 3986 (and therefore OAuth 1.0 section 5.1) requires to be percent-encoded,
    /// but that <see cref="Uri.EscapeDataString"/> leaves unescaped on some .NET Framework versions.
    /// </summary>
    private static readonly char[] UriRfc3986CharsToEscape = new char[] { '!', '*', '\'', '(', ')' };

    /// <summary>
    /// The signature method this binding element uses.
    /// </summary>
    private readonly string signatureMethod;

    /// <summary>
    /// Initializes a new instance of the <see cref="SigningBindingElementBase"/> class.
    /// </summary>
    /// <param name="signatureMethod">The OAuth signature method that the binding element uses.</param>
    internal SigningBindingElementBase(string signatureMethod) {
        this.signatureMethod = signatureMethod;
    }

    #region ITamperProtectionChannelBindingElement members

    /// <summary>
    /// Gets or sets the delegate that will initialize the non-serialized properties necessary on a signed
    /// message so that its signature can be correctly calculated for verification.
    /// </summary>
    public Action<ITamperResistantOAuthMessage> SignatureVerificationCallback { get; set; }

    #endregion

    #region IChannelBindingElement Members

    /// <summary>
    /// Gets the message protection provided by this binding element.
    /// </summary>
    public MessageProtection Protection {
        get { return MessageProtection.TamperProtection; }
    }

    /// <summary>
    /// Signs the outgoing message.
    /// </summary>
    /// <param name="message">The message to sign.</param>
    /// <returns>True if the message was signed. False otherwise.</returns>
    public bool PrepareMessageForSending(IProtocolMessage message) {
        var signedMessage = message as ITamperResistantOAuthMessage;
        if (signedMessage != null && this.IsMessageApplicable(signedMessage)) {
            signedMessage.SignatureMethod = this.signatureMethod;
            signedMessage.Signature = this.GetSignature(signedMessage);
            return true;
        }

        return false;
    }

    /// <summary>
    /// Verifies the signature on an incoming message.
    /// </summary>
    /// <param name="message">The message whose signature should be verified.</param>
    /// <returns>
    /// True if the signature was verified.
    /// False if the message is not an <see cref="ITamperResistantOAuthMessage"/> handled by this
    /// element, or if it does not declare this element's signature method.
    /// </returns>
    /// <exception cref="InvalidSignatureException">Thrown if the signature is invalid.</exception>
    public bool PrepareMessageForReceiving(IProtocolMessage message) {
        var signedMessage = message as ITamperResistantOAuthMessage;
        if (signedMessage != null && this.IsMessageApplicable(signedMessage)) {
            if (!string.Equals(signedMessage.SignatureMethod, this.signatureMethod, StringComparison.Ordinal)) {
                Logger.WarnFormat("Expected signature method '{0}' but received message with a signature method of '{1}'.", this.signatureMethod, signedMessage.SignatureMethod);
                return false;
            }

            if (this.SignatureVerificationCallback != null) {
                this.SignatureVerificationCallback(signedMessage);
            } else {
                Logger.Warn("Signature verification required, but callback delegate was not provided to provide additional data for signing.");
            }

            string signature = this.GetSignature(signedMessage);

            // Compare in constant time so that response timing does not reveal how many
            // leading characters of a forged signature happened to be correct.
            if (!EqualsConstantTime(signedMessage.Signature, signature)) {
                Logger.Error("Signature verification failed.");
                throw new InvalidSignatureException(message);
            }

            return true;
        }

        return false;
    }

    #endregion

    /// <summary>
    /// Constructs the OAuth Signature Base String and returns the result.
    /// </summary>
    /// <param name="message">The message to derive the signature base string from.</param>
    /// <returns>The signature base string.</returns>
    /// <exception cref="ArgumentNullException">Thrown if <paramref name="message"/> is null.</exception>
    /// <exception cref="ArgumentException">
    /// Thrown if the message's HttpMethod is null or empty, or its Recipient is null.
    /// </exception>
    /// <remarks>
    /// This method implements OAuth 1.0 section 9.1.
    /// </remarks>
    protected static string ConstructSignatureBaseString(ITamperResistantOAuthMessage message) {
        if (message == null) {
            throw new ArgumentNullException("message");
        }

        if (String.IsNullOrEmpty(message.HttpMethod)) {
            throw new ArgumentException(
                string.Format(
                    CultureInfo.CurrentCulture,
                    MessagingStrings.ArgumentPropertyMissing,
                    typeof(ITamperResistantOAuthMessage).Name,
                    "HttpMethod"),
                "message");
        }

        if (message.Recipient == null) {
            throw new ArgumentException(
                string.Format(
                    CultureInfo.CurrentCulture,
                    MessagingStrings.ArgumentPropertyMissing,
                    typeof(ITamperResistantOAuthMessage).Name,
                    "Recipient"),
                "message");
        }

        List<string> signatureBaseStringElements = new List<string>(3);

        // Element 1: the HTTP method, upper-cased.
        signatureBaseStringElements.Add(message.HttpMethod.ToUpperInvariant());

        // Element 2: the recipient URL with its query and fragment removed.
        UriBuilder endpoint = new UriBuilder(message.Recipient);
        endpoint.Query = null;
        endpoint.Fragment = null;
        signatureBaseStringElements.Add(endpoint.Uri.AbsoluteUri);

        // Element 3: the message's encoded parameters (minus oauth_signature), sorted by
        // key then value, joined as key=value pairs separated by '&'.
        var encodedDictionary = OAuthChannel.GetEncodedParameters(message);
        encodedDictionary.Remove("oauth_signature");
        var sortedKeyValueList = new List<KeyValuePair<string, string>>(encodedDictionary);
        sortedKeyValueList.Sort(SignatureBaseStringParameterComparer);
        StringBuilder paramBuilder = new StringBuilder();
        foreach (var pair in sortedKeyValueList) {
            if (paramBuilder.Length > 0) {
                paramBuilder.Append('&');
            }

            paramBuilder.Append(pair.Key);
            paramBuilder.Append('=');
            paramBuilder.Append(pair.Value);
        }

        signatureBaseStringElements.Add(paramBuilder.ToString());

        // Each element is percent-encoded and the three are joined with '&'.
        StringBuilder signatureBaseString = new StringBuilder();
        foreach (string element in signatureBaseStringElements) {
            if (signatureBaseString.Length > 0) {
                signatureBaseString.Append('&');
            }

            signatureBaseString.Append(EscapeUriDataStringRfc3986(element));
        }

        return signatureBaseString.ToString();
    }

    /// <summary>
    /// Calculates a signature for a given message.
    /// </summary>
    /// <param name="message">The message to sign.</param>
    /// <returns>The signature for the message.</returns>
    protected abstract string GetSignature(ITamperResistantOAuthMessage message);

    /// <summary>
    /// Checks whether this binding element applies to this message.
    /// </summary>
    /// <param name="message">The message that needs to be signed.</param>
    /// <returns>
    /// True if the message declares no signature method or declares this element's signature method.
    /// False otherwise.
    /// </returns>
    protected virtual bool IsMessageApplicable(ITamperResistantOAuthMessage message) {
        return string.IsNullOrEmpty(message.SignatureMethod)
            || string.Equals(message.SignatureMethod, this.signatureMethod, StringComparison.Ordinal);
    }

    /// <summary>
    /// Gets the "ConsumerSecret&amp;TokenSecret" string, allowing either property to be empty or null.
    /// </summary>
    /// <param name="message">The message to extract the secrets from.</param>
    /// <returns>
    /// The percent-encoded consumer secret, an '&amp;', then the percent-encoded token secret.
    /// A null or empty secret contributes nothing, but the '&amp;' is always present.
    /// </returns>
    protected string GetConsumerAndTokenSecretString(ITamperResistantOAuthMessage message) {
        StringBuilder builder = new StringBuilder();
        if (!string.IsNullOrEmpty(message.ConsumerSecret)) {
            builder.Append(EscapeUriDataStringRfc3986(message.ConsumerSecret));
        }

        builder.Append('&');

        if (!string.IsNullOrEmpty(message.TokenSecret)) {
            builder.Append(EscapeUriDataStringRfc3986(message.TokenSecret));
        }

        return builder.ToString();
    }

    /// <summary>
    /// Sorts parameters according to OAuth signature base string rules.
    /// </summary>
    /// <param name="left">The first parameter to compare.</param>
    /// <param name="right">The second parameter to compare.</param>
    /// <returns>Negative, zero or positive.</returns>
    private static int SignatureBaseStringParameterComparer(KeyValuePair<string, string> left, KeyValuePair<string, string> right) {
        int result = string.CompareOrdinal(left.Key, right.Key);
        if (result != 0) {
            return result;
        }

        return string.CompareOrdinal(left.Value, right.Value);
    }

    /// <summary>
    /// Percent-encodes a string per RFC 3986.
    /// This is <see cref="Uri.EscapeDataString"/>, plus encoding of the reserved characters
    /// that older .NET Framework versions leave unescaped.
    /// </summary>
    /// <param name="value">The string to escape. Must not be null.</param>
    /// <returns>The escaped string.</returns>
    private static string EscapeUriDataStringRfc3986(string value) {
        StringBuilder escaped = new StringBuilder(Uri.EscapeDataString(value));

        // On runtimes that already escape these characters none remain, so this is a no-op there.
        foreach (char reserved in UriRfc3986CharsToEscape) {
            escaped.Replace(new string(reserved, 1), Uri.HexEscape(reserved));
        }

        return escaped.ToString();
    }

    /// <summary>
    /// Compares two strings for equality in time that depends only on their length, not their content.
    /// </summary>
    /// <param name="first">The first string. May be null.</param>
    /// <param name="second">The second string. May be null.</param>
    /// <returns>True if both are null, or both are non-null with identical characters. False otherwise.</returns>
    private static bool EqualsConstantTime(string first, string second) {
        if (first == null || second == null) {
            return first == null && second == null;
        }

        if (first.Length != second.Length) {
            return false;
        }

        int difference = 0;
        for (int i = 0; i < first.Length; i++) {
            difference |= first[i] ^ second[i];
        }

        return difference == 0;
    }
}
}