//-----------------------------------------------------------------------
//
// Copyright (c) Andrew Arnott. All rights reserved.
//
//-----------------------------------------------------------------------
namespace DotNetOAuth.Messaging {
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Diagnostics;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Net;
using System.Text;
using System.Web;
using DotNetOAuth.Messaging.Reflection;
///
/// Manages sending direct messages to a remote party and receiving responses.
///
public abstract class Channel {
	/// <summary>
	/// The maximum allowable (estimated) size of an indirect message payload before we stop
	/// using a redirect (<see cref="HttpStatusCode.Redirect"/>) response and instead send
	/// a 200 OK response with a scripted form POST carrying the parameters, in order to
	/// ensure successfully sending a large payload to another server that might have a
	/// maximum allowable size restriction on its GET request.
	/// </summary>
	private const int IndirectMessageGetToPostThreshold = 2 * 1024; // 2KB, recommended by OpenID group

	/// <summary>
	/// The template for indirect messages that require form POST to forward through the user agent.
	/// </summary>
	/// <remarks>
	/// <para>Placeholder {0} receives the HTML-encoded recipient URL and {1} receives the
	/// block of hidden input fields.</para>
	/// <para>We are intentionally using " instead of the html single quote ' around the
	/// injected values because the HtmlEncode'd values that we inject will only escape the
	/// double quote, so only the double-quote used around these values is safe.</para>
	/// <para>This string is used as a composite format string, so it must not contain any
	/// literal curly braces other than the two placeholders.</para>
	/// </remarks>
	private const string IndirectMessageFormPostFormat = @"
<html>
<body onload=""var btn = document.getElementById('submit_button'); btn.disabled = true; document.getElementById('indirect_message_form').submit();"">
<form id=""indirect_message_form"" action=""{0}"" method=""post"" accept-charset=""UTF-8"" enctype=""application/x-www-form-urlencoded"">
{1}
	<input id=""submit_button"" type=""submit"" value=""Continue"" />
</form>
</body>
</html>
";

	/// <summary>
	/// The composite format string for one hidden form field in
	/// <see cref="IndirectMessageFormPostFormat"/>, where {0} is the HTML-encoded
	/// field name and {1} is the HTML-encoded field value.
	/// </summary>
	private const string IndirectMessageFormPostFieldFormat = "\t<input type=\"hidden\" name=\"{0}\" value=\"{1}\" />\r\n";

	/// <summary>
	/// A tool that can figure out what kind of message is being received
	/// so it can be deserialized.
	/// </summary>
	private readonly IMessageTypeProvider messageTypeProvider;

	/// <summary>
	/// A list of binding elements in the order they must be applied to outgoing messages.
	/// </summary>
	/// <remarks>
	/// Incoming messages should have the binding elements applied in reverse order.
	/// </remarks>
	[DebuggerBrowsable(DebuggerBrowsableState.Never)]
	private readonly List<IChannelBindingElement> bindingElements;

	/// <summary>
	/// Initializes a new instance of the <see cref="Channel"/> class.
	/// </summary>
	/// <param name="messageTypeProvider">
	/// A class prepared to analyze incoming messages and indicate what concrete
	/// message types can deserialize from it.
	/// </param>
	/// <param name="bindingElements">The binding elements to use in sending and receiving messages.</param>
	/// <exception cref="ArgumentNullException">Thrown when <paramref name="messageTypeProvider"/> is null.</exception>
	/// <exception cref="ArgumentException">Thrown when <paramref name="bindingElements"/> contains a null element.</exception>
	/// <exception cref="ProtocolException">Thrown when the binding elements are incomplete or inconsistent with each other.</exception>
	protected Channel(IMessageTypeProvider messageTypeProvider, params IChannelBindingElement[] bindingElements) {
		if (messageTypeProvider == null) {
			throw new ArgumentNullException("messageTypeProvider");
		}

		this.messageTypeProvider = messageTypeProvider;
		this.bindingElements = new List<IChannelBindingElement>(ValidateAndPrepareBindingElements(bindingElements));
	}

	/// <summary>
	/// Gets the binding elements used by this channel, in the order they are applied to outgoing messages.
	/// </summary>
	/// <remarks>
	/// Incoming messages are processed by these binding elements in the reverse order.
	/// </remarks>
	protected internal ReadOnlyCollection<IChannelBindingElement> BindingElements {
		get {
			return this.bindingElements.AsReadOnly();
		}
	}

	/// <summary>
	/// Gets a tool that can figure out what kind of message is being received
	/// so it can be deserialized.
	/// </summary>
	protected IMessageTypeProvider MessageTypeProvider {
		get { return this.messageTypeProvider; }
	}

	/// <summary>
	/// Queues an indirect message (either a request or response)
	/// or direct message response for transmission to a remote party.
	/// </summary>
	/// <param name="message">The one-way message to send.</param>
	/// <returns>The pending user agent redirect based message to be sent as an HttpResponse.</returns>
	/// <exception cref="ArgumentNullException">Thrown when <paramref name="message"/> is null.</exception>
	/// <exception cref="ArgumentException">
	/// Thrown when an indirect message does not implement <see cref="IDirectedProtocolMessage"/>
	/// or has no recipient, or when the message's transport is not recognized.
	/// </exception>
	public Response Send(IProtocolMessage message) {
		if (message == null) {
			throw new ArgumentNullException("message");
		}

		this.PrepareMessageForSending(message);
		Logger.DebugFormat("Sending message: {0}", message);

		switch (message.Transport) {
			case MessageTransport.Direct:
				// This is a response to a direct message.
				return this.SendDirectMessageResponse(message);
			case MessageTransport.Indirect:
				var directedMessage = message as IDirectedProtocolMessage;
				if (directedMessage == null) {
					throw new ArgumentException(
						string.Format(
							CultureInfo.CurrentCulture,
							MessagingStrings.IndirectMessagesMustImplementIDirectedProtocolMessage,
							typeof(IDirectedProtocolMessage).FullName),
						"message");
				}

				if (directedMessage.Recipient == null) {
					throw new ArgumentException(MessagingStrings.DirectedMessageMissingRecipient, "message");
				}

				return this.SendIndirectMessage(directedMessage);
			default:
				throw new ArgumentException(
					string.Format(
						CultureInfo.CurrentCulture,
						MessagingStrings.UnrecognizedEnumValue,
						"Transport",
						message.Transport),
					"message");
		}
	}

	/// <summary>
	/// Gets the protocol message embedded in the current HTTP request, if present.
	/// </summary>
	/// <returns>The deserialized message, if one is found. Null otherwise.</returns>
	/// <remarks>
	/// Requires an HttpContext.Current context.
	/// </remarks>
	/// <exception cref="InvalidOperationException">Thrown when <see cref="HttpContext.Current"/> is null.</exception>
	public IProtocolMessage ReadFromRequest() {
		return this.ReadFromRequest(this.GetRequestFromContext());
	}

	/// <summary>
	/// Gets the protocol message embedded in the current HTTP request, if present.
	/// </summary>
	/// <typeparam name="TREQUEST">The expected type of the message to be received.</typeparam>
	/// <param name="request">The deserialized message, if one is found. Null otherwise.</param>
	/// <returns>True if the expected message was recognized and deserialized. False otherwise.</returns>
	/// <remarks>
	/// Requires an HttpContext.Current context.
	/// </remarks>
	/// <exception cref="InvalidOperationException">Thrown when <see cref="HttpContext.Current"/> is null.</exception>
	/// <exception cref="ProtocolException">Thrown when a request message of an unexpected type is received.</exception>
	public bool TryReadFromRequest<TREQUEST>(out TREQUEST request)
		where TREQUEST : class, IProtocolMessage {
		return this.TryReadFromRequest<TREQUEST>(this.GetRequestFromContext(), out request);
	}

	/// <summary>
	/// Gets the protocol message embedded in the given HTTP request, if present.
	/// </summary>
	/// <typeparam name="TREQUEST">The expected type of the message to be received.</typeparam>
	/// <param name="httpRequest">The request to search for an embedded message.</param>
	/// <param name="request">The deserialized message, if one is found. Null otherwise.</param>
	/// <returns>True if the expected message was recognized and deserialized. False otherwise.</returns>
	/// <exception cref="ArgumentNullException">Thrown when <paramref name="httpRequest"/> is null.</exception>
	/// <exception cref="ProtocolException">Thrown when a request message of an unexpected type is received.</exception>
	public bool TryReadFromRequest<TREQUEST>(HttpRequestInfo httpRequest, out TREQUEST request)
		where TREQUEST : class, IProtocolMessage {
		IProtocolMessage untypedRequest = this.ReadFromRequest(httpRequest);
		if (untypedRequest == null) {
			request = null;
			return false;
		}

		request = untypedRequest as TREQUEST;
		if (request == null) {
			// A message was recognized, but it is not the kind the caller expected.
			throw new ProtocolException(
				string.Format(
					CultureInfo.CurrentCulture,
					MessagingStrings.UnexpectedMessageReceived,
					typeof(TREQUEST),
					untypedRequest.GetType()));
		}

		return true;
	}

	/// <summary>
	/// Gets the protocol message embedded in the current HTTP request.
	/// </summary>
	/// <typeparam name="TREQUEST">The expected type of the message to be received.</typeparam>
	/// <returns>The deserialized message.</returns>
	/// <remarks>
	/// Requires an HttpContext.Current context.
	/// </remarks>
	/// <exception cref="InvalidOperationException">Thrown when <see cref="HttpContext.Current"/> is null.</exception>
	/// <exception cref="ProtocolException">Thrown if the expected message was not recognized in the request.</exception>
	public TREQUEST ReadFromRequest<TREQUEST>()
		where TREQUEST : class, IProtocolMessage {
		// The explicit type argument is required: without it overload resolution picks the
		// non-generic ReadFromRequest(HttpRequestInfo), whose IProtocolMessage result
		// does not convert to TREQUEST.
		return this.ReadFromRequest<TREQUEST>(this.GetRequestFromContext());
	}

	/// <summary>
	/// Gets the protocol message embedded in the given HTTP request.
	/// </summary>
	/// <typeparam name="TREQUEST">The expected type of the message to be received.</typeparam>
	/// <param name="httpRequest">The request to search for an embedded message.</param>
	/// <returns>The deserialized message.</returns>
	/// <exception cref="ProtocolException">
	/// Thrown if the expected message was not recognized in the request,
	/// or a message of an unexpected type was received.
	/// </exception>
	public TREQUEST ReadFromRequest<TREQUEST>(HttpRequestInfo httpRequest)
		where TREQUEST : class, IProtocolMessage {
		TREQUEST request;
		if (this.TryReadFromRequest<TREQUEST>(httpRequest, out request)) {
			return request;
		} else {
			throw new ProtocolException(
				string.Format(
					CultureInfo.CurrentCulture,
					MessagingStrings.ExpectedMessageNotReceived,
					typeof(TREQUEST)));
		}
	}

	/// <summary>
	/// Gets the protocol message that may be embedded in the given HTTP request.
	/// </summary>
	/// <param name="httpRequest">The request to search for an embedded message.</param>
	/// <returns>The deserialized message, if one is found. Null otherwise.</returns>
	/// <remarks>
	/// A recognized message is run through the binding elements (in reverse order)
	/// and validated before being returned.
	/// </remarks>
	public IProtocolMessage ReadFromRequest(HttpRequestInfo httpRequest) {
		IProtocolMessage requestMessage = this.ReadFromRequestInternal(httpRequest);
		if (requestMessage != null) {
			Logger.DebugFormat("Incoming request received: {0}", requestMessage);
			this.VerifyMessageAfterReceiving(requestMessage);
		}

		return requestMessage;
	}

	/// <summary>
	/// Sends a direct message to a remote party and waits for the response.
	/// </summary>
	/// <typeparam name="TRESPONSE">The expected type of the message to be received.</typeparam>
	/// <param name="request">The message to send.</param>
	/// <returns>The remote party's response.</returns>
	/// <exception cref="ProtocolException">
	/// Thrown if no message is recognized in the response
	/// or an unexpected type of message is received.
	/// </exception>
	public TRESPONSE Request<TRESPONSE>(IDirectedProtocolMessage request)
		where TRESPONSE : class, IProtocolMessage {
		IProtocolMessage response = this.Request(request);
		if (response == null) {
			throw new ProtocolException(
				string.Format(
					CultureInfo.CurrentCulture,
					MessagingStrings.ExpectedMessageNotReceived,
					typeof(TRESPONSE)));
		}

		var expectedResponse = response as TRESPONSE;
		if (expectedResponse == null) {
			throw new ProtocolException(
				string.Format(
					CultureInfo.CurrentCulture,
					MessagingStrings.UnexpectedMessageReceived,
					typeof(TRESPONSE),
					response.GetType()));
		}

		return expectedResponse;
	}

	/// <summary>
	/// Sends a direct message to a remote party and waits for the response.
	/// </summary>
	/// <param name="request">The message to send.</param>
	/// <returns>The remote party's response, or null if no message was recognized in it.</returns>
	/// <exception cref="ArgumentNullException">Thrown when <paramref name="request"/> is null.</exception>
	public IProtocolMessage Request(IDirectedProtocolMessage request) {
		if (request == null) {
			throw new ArgumentNullException("request");
		}

		this.PrepareMessageForSending(request);
		Logger.DebugFormat("Sending request: {0}", request);
		IProtocolMessage response = this.RequestInternal(request);
		if (response != null) {
			Logger.DebugFormat("Received response: {0}", response);
			this.VerifyMessageAfterReceiving(response);
		}

		return response;
	}

	/// <summary>
	/// Gets the protocol message that may be in the given HTTP response stream.
	/// </summary>
	/// <param name="responseStream">The response that is anticipated to contain an OAuth message.</param>
	/// <returns>The deserialized message, if one is found. Null otherwise.</returns>
	private IProtocolMessage ReadFromResponse(Stream responseStream) {
		IProtocolMessage message = this.ReadFromResponseInternal(responseStream);

		// Only verify when a message was actually recognized, consistent with the other
		// receive paths; VerifyMessageAfterReceiving dereferences the message.
		if (message != null) {
			Logger.DebugFormat("Received message response: {0}", message);
			this.VerifyMessageAfterReceiving(message);
		}

		return message;
	}

	/// <summary>
	/// Gets the current HTTP request being processed.
	/// </summary>
	/// <returns>The HttpRequestInfo for the current request.</returns>
	/// <remarks>
	/// Requires an HttpContext.Current context.
	/// </remarks>
	/// <exception cref="InvalidOperationException">Thrown when <see cref="HttpContext.Current"/> is null.</exception>
	protected internal virtual HttpRequestInfo GetRequestFromContext() {
		if (HttpContext.Current == null) {
			throw new InvalidOperationException(MessagingStrings.HttpContextRequired);
		}

		return new HttpRequestInfo(HttpContext.Current.Request);
	}

	/// <summary>
	/// Gets the protocol message that may be embedded in the given HTTP request.
	/// </summary>
	/// <param name="request">The request to search for an embedded message.</param>
	/// <returns>The deserialized message, if one is found. Null otherwise.</returns>
	/// <exception cref="ArgumentNullException">Thrown when <paramref name="request"/> is null.</exception>
	protected virtual IProtocolMessage ReadFromRequestInternal(HttpRequestInfo request) {
		if (request == null) {
			throw new ArgumentNullException("request");
		}

		// Search Form data first, and if nothing is there search the QueryString
		var fields = request.Form.ToDictionary();
		if (fields.Count == 0) {
			fields = request.QueryString.ToDictionary();
		}

		return this.Receive(fields, request.GetRecipient());
	}

	/// <summary>
	/// Deserializes a dictionary of values into a message.
	/// </summary>
	/// <param name="fields">The dictionary of values that were read from an HTTP request or response.</param>
	/// <param name="recipient">Information about where the message was directed. Null for direct response messages.</param>
	/// <returns>The deserialized message, or null if no message could be recognized in the provided data.</returns>
	/// <exception cref="ArgumentNullException">Thrown when <paramref name="fields"/> is null.</exception>
	protected virtual IProtocolMessage Receive(Dictionary<string, string> fields, MessageReceivingEndpoint recipient) {
		if (fields == null) {
			throw new ArgumentNullException("fields");
		}

		Type messageType = this.MessageTypeProvider.GetRequestMessageType(fields);

		// If there was no data, or we couldn't recognize it as a message, abort.
		if (messageType == null) {
			return null;
		}

		// We have a message! Assemble it.
		var serializer = MessageSerializer.Get(messageType);
		IProtocolMessage message = serializer.Deserialize(fields, recipient);

		return message;
	}

	/// <summary>
	/// Queues an indirect message for transmittal via the user agent.
	/// </summary>
	/// <param name="message">The message to send.</param>
	/// <returns>The pending user agent redirect based message to be sent as an HttpResponse.</returns>
	/// <remarks>
	/// Small payloads are sent as a redirect; payloads whose estimated size exceeds
	/// <see cref="IndirectMessageGetToPostThreshold"/> are sent as an auto-submitting form POST.
	/// </remarks>
	/// <exception cref="ArgumentNullException">Thrown when <paramref name="message"/> is null.</exception>
	protected virtual Response SendIndirectMessage(IDirectedProtocolMessage message) {
		if (message == null) {
			throw new ArgumentNullException("message");
		}

		var serializer = MessageSerializer.Get(message.GetType());
		var fields = serializer.Serialize(message);
		Response response;
		if (CalculateSizeOfPayload(fields) > IndirectMessageGetToPostThreshold) {
			response = this.CreateFormPostResponse(message, fields);
		} else {
			response = this.Create301RedirectResponse(message, fields);
		}

		return response;
	}

	/// <summary>
	/// Encodes an HTTP response that will instruct the user agent to forward a message to
	/// some remote third party using a redirect with the GET method.
	/// </summary>
	/// <param name="message">The message to forward.</param>
	/// <param name="fields">The pre-serialized fields from the message.</param>
	/// <returns>The encoded HTTP response.</returns>
	/// <remarks>
	/// Despite the method name, the response status is <see cref="HttpStatusCode.Redirect"/>
	/// (HTTP 302), not 301; the name is kept for backward compatibility with overriders.
	/// </remarks>
	/// <exception cref="ArgumentNullException">Thrown when <paramref name="message"/> or <paramref name="fields"/> is null.</exception>
	/// <exception cref="ArgumentException">Thrown when the message has no recipient.</exception>
	protected virtual Response Create301RedirectResponse(IDirectedProtocolMessage message, IDictionary<string, string> fields) {
		if (message == null) {
			throw new ArgumentNullException("message");
		}
		if (message.Recipient == null) {
			throw new ArgumentException(MessagingStrings.DirectedMessageMissingRecipient, "message");
		}
		if (fields == null) {
			throw new ArgumentNullException("fields");
		}

		WebHeaderCollection headers = new WebHeaderCollection();
		UriBuilder builder = new UriBuilder(message.Recipient);
		MessagingUtilities.AppendQueryArgs(builder, fields);
		headers.Add(HttpResponseHeader.Location, builder.Uri.AbsoluteUri);
		Logger.DebugFormat("Redirecting to {0}", builder.Uri.AbsoluteUri);
		Response response = new Response {
			Status = HttpStatusCode.Redirect,
			Headers = headers,
			Body = null,
			OriginalMessage = message
		};

		return response;
	}

	/// <summary>
	/// Encodes an HTTP response that will instruct the user agent to forward a message to
	/// some remote third party using a form POST method.
	/// </summary>
	/// <param name="message">The message to forward.</param>
	/// <param name="fields">The pre-serialized fields from the message.</param>
	/// <returns>The encoded HTTP response.</returns>
	/// <exception cref="ArgumentNullException">Thrown when <paramref name="message"/> or <paramref name="fields"/> is null.</exception>
	/// <exception cref="ArgumentException">Thrown when the message has no recipient.</exception>
	protected virtual Response CreateFormPostResponse(IDirectedProtocolMessage message, IDictionary<string, string> fields) {
		if (message == null) {
			throw new ArgumentNullException("message");
		}
		if (message.Recipient == null) {
			throw new ArgumentException(MessagingStrings.DirectedMessageMissingRecipient, "message");
		}
		if (fields == null) {
			throw new ArgumentNullException("fields");
		}

		WebHeaderCollection headers = new WebHeaderCollection();

		// Every injected value is HTML-encoded and placed inside double quotes, which
		// HtmlEncode escapes, so untrusted field contents cannot break out of the attribute.
		StringBuilder hiddenFields = new StringBuilder();
		foreach (var field in fields) {
			hiddenFields.AppendFormat(
				CultureInfo.InvariantCulture,
				IndirectMessageFormPostFieldFormat,
				HttpUtility.HtmlEncode(field.Key),
				HttpUtility.HtmlEncode(field.Value));
		}

		string body;
		using (StringWriter bodyWriter = new StringWriter(CultureInfo.InvariantCulture)) {
			bodyWriter.WriteLine(
				IndirectMessageFormPostFormat,
				HttpUtility.HtmlEncode(message.Recipient.AbsoluteUri),
				hiddenFields);
			body = bodyWriter.ToString();
		}

		Response response = new Response {
			Status = HttpStatusCode.OK,
			Headers = headers,
			Body = body,
			OriginalMessage = message
		};

		return response;
	}

	/// <summary>
	/// Gets the protocol message that may be in the given HTTP response stream.
	/// </summary>
	/// <param name="responseStream">The response that is anticipated to contain an OAuth message.</param>
	/// <returns>The deserialized message, if one is found. Null otherwise.</returns>
	protected abstract IProtocolMessage ReadFromResponseInternal(Stream responseStream);

	/// <summary>
	/// Sends a direct message to a remote party and waits for the response.
	/// </summary>
	/// <param name="request">The message to send.</param>
	/// <returns>The remote party's response.</returns>
	protected abstract IProtocolMessage RequestInternal(IDirectedProtocolMessage request);

	/// <summary>
	/// Queues a message for sending in the response stream where the fields
	/// are sent in the response stream in querystring style.
	/// </summary>
	/// <param name="response">The message to send as a response.</param>
	/// <returns>The pending user agent redirect based message to be sent as an HttpResponse.</returns>
	/// <remarks>
	/// This method implements spec V1.0 section 5.3.
	/// </remarks>
	protected abstract Response SendDirectMessageResponse(IProtocolMessage response);

	/// <summary>
	/// Prepares a message for transmit by applying signatures, nonces, etc.
	/// </summary>
	/// <param name="message">The message to prepare for sending.</param>
	/// <remarks>
	/// This method should NOT be called by derived types
	/// except when sending ONE WAY request messages.
	/// </remarks>
	/// <exception cref="ArgumentNullException">Thrown when <paramref name="message"/> is null.</exception>
	/// <exception cref="UnprotectedMessageException">
	/// Thrown when the binding elements did not apply every protection the message requires.
	/// </exception>
	protected void PrepareMessageForSending(IProtocolMessage message) {
		if (message == null) {
			throw new ArgumentNullException("message");
		}

		MessageProtections appliedProtection = MessageProtections.None;
		foreach (IChannelBindingElement bindingElement in this.bindingElements) {
			if (bindingElement.PrepareMessageForSending(message)) {
				appliedProtection |= bindingElement.Protection;
			}
		}

		// Ensure that the message's protection requirements have been satisfied.
		if ((message.RequiredProtection & appliedProtection) != message.RequiredProtection) {
			throw new UnprotectedMessageException(message, appliedProtection);
		}

		EnsureValidMessageParts(message);
		message.EnsureValidMessage();
	}

	/// <summary>
	/// Calculates a fairly accurate estimation on the size of a message that contains
	/// a given set of fields.
	/// </summary>
	/// <param name="fields">The fields that would be included in a message.</param>
	/// <returns>The size (in bytes) of the URL-encoded message payload, excluding the recipient URL itself.</returns>
	private static int CalculateSizeOfPayload(IDictionary<string, string> fields) {
		Debug.Assert(fields != null, "fields == null");

		int size = 0;
		foreach (var field in fields) {
			// Measure the URL-encoded form, since that is what ends up in a redirect's
			// query string; counting raw characters underestimates values containing
			// reserved or non-ASCII characters, which expand to %XX sequences.
			size += GetUrlEncodedLength(field.Key);
			size += GetUrlEncodedLength(field.Value);
			size += 2; // & and =
		}

		return size;
	}

	/// <summary>
	/// Gets the length of a string once it has been URL-encoded.
	/// </summary>
	/// <param name="value">The string to measure. May be null.</param>
	/// <returns>The number of characters in the encoded string; 0 for null.</returns>
	private static int GetUrlEncodedLength(string value) {
		return value == null ? 0 : HttpUtility.UrlEncode(value).Length;
	}

	/// <summary>
	/// Ensures a consistent and secure set of binding elements and
	/// sorts them as necessary for a valid sequence of operations.
	/// </summary>
	/// <param name="elements">The binding elements provided to the channel.</param>
	/// <returns>The properly ordered list of elements.</returns>
	/// <exception cref="ArgumentException">Thrown when <paramref name="elements"/> contains a null element.</exception>
	/// <exception cref="ProtocolException">Thrown when the binding elements are incomplete or inconsistent with each other.</exception>
	private static IEnumerable<IChannelBindingElement> ValidateAndPrepareBindingElements(IEnumerable<IChannelBindingElement> elements) {
		if (elements == null) {
			return new IChannelBindingElement[0];
		}
		if (elements.Contains(null)) {
			throw new ArgumentException(MessagingStrings.SequenceContainsNullElement, "elements");
		}

		// Filter the elements between the mere transforming ones and the protection ones.
		var transformationElements = new List<IChannelBindingElement>(
			elements.Where(element => element.Protection == MessageProtections.None));
		var protectionElements = new List<IChannelBindingElement>(
			elements.Where(element => element.Protection != MessageProtections.None));

		bool wasLastProtectionPresent = true;
		foreach (MessageProtections protectionKind in Enum.GetValues(typeof(MessageProtections))) {
			if (protectionKind == MessageProtections.None) {
				continue;
			}

			int countProtectionsOfThisKind = protectionElements.Count(element => (element.Protection & protectionKind) == protectionKind);

			// Each protection binding element is backed by the presence of its dependent protection(s).
			if (countProtectionsOfThisKind > 0 && !wasLastProtectionPresent) {
				throw new ProtocolException(
					string.Format(
						CultureInfo.CurrentCulture,
						MessagingStrings.RequiredProtectionMissing,
						protectionKind));
			}

			// At most one binding element for each protection type.
			if (countProtectionsOfThisKind > 1) {
				throw new ProtocolException(
					string.Format(
						CultureInfo.CurrentCulture,
						MessagingStrings.TooManyBindingsOfferingSameProtection,
						protectionKind,
						countProtectionsOfThisKind));
			}

			wasLastProtectionPresent = countProtectionsOfThisKind > 0;
		}

		// Put the binding elements in order so they are correctly applied to outgoing messages.
		// Start with the transforming (non-protecting) binding elements first and preserve their original order.
		var orderedList = new List<IChannelBindingElement>(transformationElements);

		// Now sort the protection binding elements among themselves and add them to the list.
		orderedList.AddRange(protectionElements.OrderBy(element => element.Protection, BindingElementOutgoingMessageApplicationOrder));
		return orderedList;
	}

	/// <summary>
	/// Puts binding elements in their correct outgoing message processing order.
	/// </summary>
	/// <param name="protection1">The first protection type to compare.</param>
	/// <param name="protection2">The second protection type to compare.</param>
	/// <returns>
	/// A negative value if <paramref name="protection1"/> should be applied to an outgoing message before <paramref name="protection2"/>;
	/// a positive value if <paramref name="protection2"/> should be applied to an outgoing message before <paramref name="protection1"/>;
	/// 0 if it doesn't matter.
	/// </returns>
	private static int BindingElementOutgoingMessageApplicationOrder(MessageProtections protection1, MessageProtections protection2) {
		// Both operands must be protection kinds; comparing against None would reorder
		// user-defined message transformations.
		Debug.Assert(protection1 != MessageProtections.None && protection2 != MessageProtections.None, "This comparison function should only be used to compare protection binding elements. Otherwise we change the order of user-defined message transformations.");

		// Now put the protection ones in the right order.
		return -((int)protection1).CompareTo((int)protection2); // descending flag ordinal order
	}

	/// <summary>
	/// Verifies that all required message parts are initialized to values
	/// prior to sending the message to a remote party.
	/// </summary>
	/// <param name="message">The message to verify.</param>
	/// <exception cref="ProtocolException">
	/// Thrown when any required message part does not have a value.
	/// </exception>
	private static void EnsureValidMessageParts(IProtocolMessage message) {
		Debug.Assert(message != null, "message == null");

		MessageDictionary dictionary = new MessageDictionary(message);
		MessageDescription description = MessageDescription.Get(message.GetType());
		description.EnsureRequiredMessagePartsArePresent(dictionary.Keys);
	}

	/// <summary>
	/// Verifies the integrity and applicability of an incoming message.
	/// </summary>
	/// <param name="message">The message just received. Must not be null.</param>
	/// <exception cref="ProtocolException">
	/// Thrown when the message is somehow invalid.
	/// This can be due to tampering, replay attack or expiration, among other things.
	/// </exception>
	private void VerifyMessageAfterReceiving(IProtocolMessage message) {
		Debug.Assert(message != null, "message == null");

		MessageProtections appliedProtection = MessageProtections.None;

		// Incoming messages pass through the binding elements in the reverse of their
		// outgoing order. Enumerable.Reverse is called explicitly because the instance
		// method List<T>.Reverse() reverses the list in place and returns void.
		foreach (IChannelBindingElement bindingElement in Enumerable.Reverse(this.bindingElements)) {
			if (bindingElement.PrepareMessageForReceiving(message)) {
				appliedProtection |= bindingElement.Protection;
			}
		}

		// Ensure that the message's protection requirements have been satisfied.
		if ((message.RequiredProtection & appliedProtection) != message.RequiredProtection) {
			throw new UnprotectedMessageException(message, appliedProtection);
		}

		// We do NOT verify that all required message parts are present here... the
		// message deserializer did for us. It would be too late to do it here since
		// they might look initialized by the time we have an IProtocolMessage instance.
		message.EnsureValidMessage();
	}
}
}