//-----------------------------------------------------------------------
//
// Copyright (c) Outercurve Foundation. All rights reserved.
//
//-----------------------------------------------------------------------
namespace DotNetOpenAuth.OAuth.ChannelElements {
using System;
using System.Collections.Generic;
using System.Collections.Specialized;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using System.Web;
using DotNetOpenAuth.Messaging;
using DotNetOpenAuth.Messaging.Bindings;
using DotNetOpenAuth.Messaging.Reflection;
using Validation;
/// <summary>
/// A binding element that signs outgoing messages and verifies the signature on incoming messages.
/// </summary>
public abstract class SigningBindingElementBase : ITamperProtectionChannelBindingElement {
    /// <summary>
    /// The signature method this binding element uses.
    /// </summary>
    /// <remarks>
    /// Assigned only by the constructor.
    /// </remarks>
    private readonly string signatureMethod;

    /// <summary>
    /// Initializes a new instance of the <see cref="SigningBindingElementBase"/> class.
    /// </summary>
    /// <param name="signatureMethod">The OAuth signature method that the binding element uses.</param>
    internal SigningBindingElementBase(string signatureMethod) {
        this.signatureMethod = signatureMethod;
    }

    #region IChannelBindingElement Properties

    /// <summary>
    /// Gets the message protection provided by this binding element.
    /// </summary>
    /// <value>Always <see cref="MessageProtections.TamperProtection"/>.</value>
    public MessageProtections Protection {
        get { return MessageProtections.TamperProtection; }
    }

    /// <summary>
    /// Gets or sets the channel that this binding element belongs to.
    /// </summary>
    public Channel Channel { get; set; }

    #endregion

    #region ITamperProtectionChannelBindingElement members

    /// <summary>
    /// Gets or sets the delegate that will initialize the non-serialized properties necessary on a signed
    /// message so that its signature can be correctly calculated for verification.
    /// </summary>
    /// <remarks>
    /// When this is <c>null</c>, signing and verification still proceed, but a warning is logged.
    /// </remarks>
    public Action<ITamperResistantOAuthMessage> SignatureCallback { get; set; }

    /// <summary>
    /// Creates a new object that is a copy of the current instance.
    /// </summary>
    /// <returns>
    /// A new object that is a copy of this instance, sharing this instance's <see cref="SignatureCallback"/>.
    /// </returns>
    ITamperProtectionChannelBindingElement ITamperProtectionChannelBindingElement.Clone() {
        ITamperProtectionChannelBindingElement clone = this.Clone();

        // Derived Clone() implementations are not required to copy the callback; do it here once for all of them.
        clone.SignatureCallback = this.SignatureCallback;
        return clone;
    }

    #endregion

    #region IChannelBindingElement Methods

    /// <summary>
    /// Signs the outgoing message.
    /// </summary>
    /// <param name="message">The message to sign.</param>
    /// <param name="cancellationToken">The cancellation token. Not used; the work completes synchronously.</param>
    /// <returns>
    /// The protections (if any) that this binding element applied to the message.
    /// Null if this binding element does not apply to the message.
    /// </returns>
    public Task<MessageProtections?> ProcessOutgoingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) {
        var signedMessage = message as ITamperResistantOAuthMessage;
        if (signedMessage != null && this.IsMessageApplicable(signedMessage)) {
            if (this.SignatureCallback != null) {
                this.SignatureCallback(signedMessage);
            } else {
                Logger.Bindings.Warn("Signing required, but callback delegate was not provided to provide additional data for signing.");
            }

            // The signature method must be set before computing the signature, since it is itself a signed parameter.
            signedMessage.SignatureMethod = this.signatureMethod;
            Logger.Bindings.DebugFormat("Signing {0} message using {1}.", message.GetType().Name, this.signatureMethod);
            signedMessage.Signature = this.GetSignature(signedMessage);
            return MessageProtectionTasks.TamperProtection;
        }

        return MessageProtectionTasks.Null;
    }

    /// <summary>
    /// Verifies the signature on an incoming message.
    /// </summary>
    /// <param name="message">The message whose signature should be verified.</param>
    /// <param name="cancellationToken">The cancellation token. Not used; the work completes synchronously.</param>
    /// <returns>
    /// The protections (if any) that this binding element applied to the message.
    /// <see cref="MessageProtections.None"/> if the message declares a different signature method.
    /// Null if this binding element does not apply to the message.
    /// </returns>
    /// <exception cref="InvalidSignatureException">Thrown if the signature is invalid.</exception>
    public Task<MessageProtections?> ProcessIncomingMessageAsync(IProtocolMessage message, CancellationToken cancellationToken) {
        var signedMessage = message as ITamperResistantOAuthMessage;
        if (signedMessage != null && this.IsMessageApplicable(signedMessage)) {
            Logger.Bindings.DebugFormat("Verifying incoming {0} message signature of: {1}", message.GetType().Name, signedMessage.Signature);

            // IsMessageApplicable also admits an empty signature method; on the incoming side an exact match is required.
            if (!string.Equals(signedMessage.SignatureMethod, this.signatureMethod, StringComparison.Ordinal)) {
                Logger.Bindings.WarnFormat("Expected signature method '{0}' but received message with a signature method of '{1}'.", this.signatureMethod, signedMessage.SignatureMethod);
                return MessageProtectionTasks.None;
            }

            if (this.SignatureCallback != null) {
                this.SignatureCallback(signedMessage);
            } else {
                Logger.Bindings.Warn("Signature verification required, but callback delegate was not provided to provide additional data for signature verification.");
            }

            if (!this.IsSignatureValid(signedMessage)) {
                Logger.Bindings.Error("Signature verification failed.");
                throw new InvalidSignatureException(message);
            }

            return MessageProtectionTasks.TamperProtection;
        }

        return MessageProtectionTasks.Null;
    }

    #endregion

    /// <summary>
    /// Constructs the OAuth Signature Base String and returns the result.
    /// </summary>
    /// <param name="message">The message.</param>
    /// <param name="messageDictionary">The dictionary view of <paramref name="message"/> from which the signed parameters are read.</param>
    /// <returns>The signature base string.</returns>
    /// <remarks>
    /// This method implements OAuth 1.0 section 9.1.
    /// </remarks>
    [SuppressMessage("Microsoft.Maintainability", "CA1506:AvoidExcessiveClassCoupling", Justification = "Unavoidable")]
    internal static string ConstructSignatureBaseString(ITamperResistantOAuthMessage message, MessageDictionary messageDictionary) {
        Requires.NotNull(message, "message");
        Requires.NotNull(message.HttpMethod, "message.HttpMethod");
        Requires.NotNull(message.Recipient, "message.Recipient");
        Requires.NotNull(messageDictionary, "messageDictionary");
        ErrorUtilities.VerifyInternal(messageDictionary.Message == message, "Message references are not equal.");

        // For multipart POST messages, only include the message parts that are NOT
        // in the POST entity (those parts that may appear in an OAuth authorization header).
        IEnumerable<KeyValuePair<string, string>> partsToInclude;
        var binaryMessage = message as IMessageWithBinaryData;
        if (binaryMessage != null && binaryMessage.SendAsMultipart) {
            HttpDeliveryMethods authHeaderInUseFlags = HttpDeliveryMethods.PostRequest | HttpDeliveryMethods.AuthorizationHeaderRequest;
            ErrorUtilities.VerifyProtocol((binaryMessage.HttpMethods & authHeaderInUseFlags) == authHeaderInUseFlags, OAuthStrings.MultipartPostMustBeUsedWithAuthHeader);

            // Include the declared keys in the signature as those will be signable.
            // Cache in local variable to avoid recalculating DeclaredKeys in the delegate.
            ICollection<string> declaredKeys = messageDictionary.DeclaredKeys;
            partsToInclude = messageDictionary.Where(pair => declaredKeys.Contains(pair.Key));
        } else {
            partsToInclude = messageDictionary;
        }

        // If this message was deserialized, include only those explicitly included message parts (excludes defaulted values)
        // in the signature.  Cache the payload in a local for the same reason as declaredKeys above.
        var originalPayloadMessage = message as IMessageOriginalPayload;
        var originalPayload = originalPayloadMessage != null ? originalPayloadMessage.OriginalPayload : null;
        if (originalPayload != null) {
            partsToInclude = partsToInclude.Where(pair => originalPayload.ContainsKey(pair.Key));
        }

        var encodedDictionary = new Dictionary<string, string>();
        foreach (var pair in OAuthChannel.GetUriEscapedParameters(partsToInclude)) {
            encodedDictionary[pair.Key] = pair.Value;
        }

        // An incoming message will already have included the query and form parameters
        // in the message dictionary, but an outgoing message COULD have SOME parameters
        // in the query that are not in the message dictionary because they were included
        // in the receiving endpoint (the original URL).
        // In an outgoing message, the POST entity can only contain parameters if they were
        // in the message dictionary, so no need to pull out any parameters from there.
        if (!string.IsNullOrEmpty(message.Recipient.Query)) {
            // NOTE(review): NameValueCollection's string indexer joins the values of a repeated key
            // with commas, so "?a=1&a=2" contributes a single escaped "a=1%2C2" pair rather than two
            // separately sorted pairs -- confirm this matches what the peer signs.
            NameValueCollection nvc = HttpUtility.ParseQueryString(message.Recipient.Query);
            foreach (string key in nvc) {
                string escapedKey = MessagingUtilities.EscapeUriDataStringRfc3986(key);
                string escapedValue = MessagingUtilities.EscapeUriDataStringRfc3986(nvc[key]);
                string existingValue;
                if (!encodedDictionary.TryGetValue(escapedKey, out existingValue)) {
                    encodedDictionary.Add(escapedKey, escapedValue);
                } else {
                    ErrorUtilities.VerifyInternal(escapedValue == existingValue, "Somehow we have conflicting values for the '{0}' parameter.", escapedKey);
                }
            }
        }

        // The signature itself is never part of the signed content.
        encodedDictionary.Remove("oauth_signature");

        // The endpoint element excludes the query (merged into the parameters above) and any fragment.
        UriBuilder endpoint = new UriBuilder(message.Recipient);
        endpoint.Query = null;
        endpoint.Fragment = null;

        // Parameters are sorted by name, then value, and rendered as name=value pairs joined by '&'.
        var sortedKeyValueList = new List<KeyValuePair<string, string>>(encodedDictionary);
        sortedKeyValueList.Sort(SignatureBaseStringParameterComparer);
        string normalizedParameters = string.Join("&", sortedKeyValueList.Select(pair => pair.Key + "=" + pair.Value));

        // The base string is the three escaped elements joined by '&'.  string.Join always emits
        // both separators, even when an element happens to be empty.
        var signatureBaseStringElements = new List<string>(3) {
            message.HttpMethod.ToString().ToUpperInvariant(),
            endpoint.Uri.AbsoluteUri,
            normalizedParameters,
        };
        string signatureBaseString = string.Join("&", signatureBaseStringElements.Select(element => MessagingUtilities.EscapeUriDataStringRfc3986(element)));

        Logger.Bindings.DebugFormat("Constructed signature base string: {0}", signatureBaseString);
        return signatureBaseString;
    }

    /// <summary>
    /// Calculates a signature for a given message.
    /// </summary>
    /// <param name="message">The message to sign.</param>
    /// <returns>The signature for the message.</returns>
    /// <remarks>
    /// This method signs the message per OAuth 1.0 section 9.2.
    /// It exists so tests can reach the protected <see cref="GetSignature"/> method.
    /// </remarks>
    internal string GetSignatureTestHook(ITamperResistantOAuthMessage message) {
        return this.GetSignature(message);
    }

    /// <summary>
    /// Gets the "ConsumerSecret&amp;TokenSecret" string, allowing either property to be empty or null.
    /// </summary>
    /// <param name="message">The message to extract the secrets from.</param>
    /// <returns>
    /// The RFC 3986-escaped consumer secret and token secret joined by '&amp;'.
    /// A missing secret contributes an empty string, so the separator is always present.
    /// </returns>
    protected static string GetConsumerAndTokenSecretString(ITamperResistantOAuthMessage message) {
        StringBuilder builder = new StringBuilder();
        if (!string.IsNullOrEmpty(message.ConsumerSecret)) {
            builder.Append(MessagingUtilities.EscapeUriDataStringRfc3986(message.ConsumerSecret));
        }

        builder.Append("&");
        if (!string.IsNullOrEmpty(message.TokenSecret)) {
            builder.Append(MessagingUtilities.EscapeUriDataStringRfc3986(message.TokenSecret));
        }

        return builder.ToString();
    }

    /// <summary>
    /// Determines whether the signature on some message is valid.
    /// </summary>
    /// <param name="message">The message to check the signature on.</param>
    /// <returns>
    /// <c>true</c> if the signature on the message is valid; otherwise, <c>false</c>.
    /// </returns>
    protected virtual bool IsSignatureValid(ITamperResistantOAuthMessage message) {
        Requires.NotNull(message, "message");

        string signature = this.GetSignature(message);

        // Constant-time comparison avoids leaking how many leading characters matched.
        return MessagingUtilities.EqualsConstantTime(message.Signature, signature);
    }

    /// <summary>
    /// Clones this instance.
    /// </summary>
    /// <returns>A new instance of the binding element.</returns>
    /// <remarks>
    /// Implementations of this method need not clone the <see cref="SignatureCallback"/> member, as the
    /// <see cref="SigningBindingElementBase"/> class does this.
    /// </remarks>
    protected abstract ITamperProtectionChannelBindingElement Clone();

    /// <summary>
    /// Calculates a signature for a given message.
    /// </summary>
    /// <param name="message">The message to sign.</param>
    /// <returns>The signature for the message.</returns>
    protected abstract string GetSignature(ITamperResistantOAuthMessage message);

    /// <summary>
    /// Checks whether this binding element applies to this message.
    /// </summary>
    /// <param name="message">The message that needs to be signed.</param>
    /// <returns>
    /// <c>true</c> if the message has no signature method yet, or uses this binding element's method; otherwise, <c>false</c>.
    /// </returns>
    protected virtual bool IsMessageApplicable(ITamperResistantOAuthMessage message) {
        return string.IsNullOrEmpty(message.SignatureMethod) || string.Equals(message.SignatureMethod, this.signatureMethod, StringComparison.Ordinal);
    }

    /// <summary>
    /// Sorts parameters according to OAuth signature base string rules.
    /// </summary>
    /// <param name="left">The first parameter to compare.</param>
    /// <param name="right">The second parameter to compare.</param>
    /// <returns>Negative, zero or positive.</returns>
    /// <remarks>
    /// Orders by name, then by value, both using ordinal (byte-wise) comparison.
    /// </remarks>
    private static int SignatureBaseStringParameterComparer(KeyValuePair<string, string> left, KeyValuePair<string, string> right) {
        int result = string.CompareOrdinal(left.Key, right.Key);
        if (result != 0) {
            return result;
        }

        return string.CompareOrdinal(left.Value, right.Value);
    }
}
}