//-----------------------------------------------------------------------
//
// Copyright (c) Andrew Arnott. All rights reserved.
//
//-----------------------------------------------------------------------
namespace DotNetOAuth.Messaging {
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Net;
using System.Text;
using System.Web;
/// <summary>
/// Manages sending direct messages to a remote party and receiving responses.
/// </summary>
internal abstract class Channel {
	/// <summary>
	/// The maximum allowable (estimated, URL-encoded) payload size for an indirect message
	/// sent as a redirect (GET) response before we send a 200 OK response with a scripted
	/// form POST with the parameters instead, in order to ensure successfully sending a large
	/// payload to another server that might have a maximum allowable size restriction on its
	/// GET request.
	/// </summary>
	private static readonly int indirectMessageGetToPostThreshold = 2 * 1024; // 2KB, recommended by OpenID group

	/// <summary>
	/// The template for indirect messages that require form POST to forward through the user agent.
	/// </summary>
	/// <remarks>
	/// <para>{0} is replaced with the HTML-encoded recipient URL and {1} with the
	/// hidden input fields that carry the message payload.</para>
	/// <para>We are intentionally using " instead of the html single quote ' below because
	/// the HtmlEncode'd values that we inject will only escape the double quote, so
	/// only the double-quote used around these values is safe.</para>
	/// <para>This string is passed through String.Format, so it must not contain any
	/// curly braces other than the two placeholders.</para>
	/// </remarks>
	private static readonly string indirectMessageFormPostFormat = @"
<html>
<body onload=""document.getElementById('indirectMessageForm').submit()"">
<form id=""indirectMessageForm"" action=""{0}"" method=""post"" accept-charset=""UTF-8"" enctype=""application/x-www-form-urlencoded"">
{1}	<input type=""submit"" value=""Continue"" />
</form>
</body>
</html>
";

	/// <summary>
	/// A tool that can figure out what kind of message is being received
	/// so it can be deserialized.
	/// </summary>
	private readonly IMessageTypeProvider messageTypeProvider;

	/// <summary>
	/// The HTTP response queued for sending as a reply to the current incoming HTTP request,
	/// or null if nothing has been queued since the last dequeue.
	/// </summary>
	private Response queuedIndirectOrResponseMessage;

	/// <summary>
	/// Initializes a new instance of the <see cref="Channel"/> class.
	/// </summary>
	/// <param name="messageTypeProvider">
	/// A class prepared to analyze incoming messages and indicate what concrete
	/// message types can deserialize from it.
	/// </param>
	/// <exception cref="ArgumentNullException">Thrown when <paramref name="messageTypeProvider"/> is null.</exception>
	protected Channel(IMessageTypeProvider messageTypeProvider) {
		if (messageTypeProvider == null) {
			throw new ArgumentNullException("messageTypeProvider");
		}

		this.messageTypeProvider = messageTypeProvider;
	}

	/// <summary>
	/// Gets or sets the message that came in as a request, if any.
	/// </summary>
	/// <remarks>
	/// This message is used to help determine how to transmit the response.
	/// </remarks>
	internal IProtocolMessage RequestInProcess { get; set; }

	/// <summary>
	/// Gets a tool that can figure out what kind of message is being received
	/// so it can be deserialized.
	/// </summary>
	protected IMessageTypeProvider MessageTypeProvider {
		get { return this.messageTypeProvider; }
	}

	/// <summary>
	/// Retrieves the stored response for sending and clears it from the channel.
	/// </summary>
	/// <returns>The response to send as the HTTP response, or null if none was queued.</returns>
	internal Response DequeueIndirectOrResponseMessage() {
		Response response = this.queuedIndirectOrResponseMessage;
		this.queuedIndirectOrResponseMessage = null;
		return response;
	}

	/// <summary>
	/// Queues an indirect message (either a request or response)
	/// or direct message response for transmission to a remote party.
	/// </summary>
	/// <param name="message">The one-way message to send.</param>
	/// <exception cref="ArgumentNullException">Thrown when <paramref name="message"/> is null.</exception>
	/// <exception cref="InvalidOperationException">
	/// Thrown when <paramref name="message"/> is a directed message with no recipient
	/// that is not a <see cref="ProtocolException"/>.
	/// </exception>
	internal void Send(IProtocolMessage message) {
		if (message == null) {
			throw new ArgumentNullException("message");
		}

		var directedMessage = message as IDirectedProtocolMessage;
		if (directedMessage == null) {
			// This is a response to a direct message.
			this.SendDirectMessageResponse(message);
		} else {
			if (directedMessage.Recipient != null) {
				// This is an indirect message request or reply.
				this.SendIndirectMessage(directedMessage);
			} else {
				// A directed message without a recipient is only meaningful as an error report.
				ProtocolException exception = message as ProtocolException;
				if (exception != null) {
					if (this.RequestInProcess is IDirectedProtocolMessage) {
						this.ReportErrorAsDirectResponse(exception);
					} else {
						this.ReportErrorToUser(exception);
					}
				} else {
					throw new InvalidOperationException(
						"Cannot send a directed message that has no recipient unless it is a ProtocolException.");
				}
			}
		}
	}

	/// <summary>
	/// Takes a response and temporarily stores it for sending as the hosting site's
	/// HTTP response to the current request.
	/// </summary>
	/// <param name="response">The response to store for sending.</param>
	/// <exception cref="ArgumentNullException">Thrown when <paramref name="response"/> is null.</exception>
	/// <exception cref="InvalidOperationException">Thrown when a response is already queued.</exception>
	protected void QueueIndirectOrResponseMessage(Response response) {
		if (response == null) {
			throw new ArgumentNullException("response");
		}

		if (this.queuedIndirectOrResponseMessage != null) {
			throw new InvalidOperationException(MessagingStrings.QueuedMessageResponseAlreadyExists);
		}

		this.queuedIndirectOrResponseMessage = response;
	}

	/// <summary>
	/// Sends a direct message to a remote party and waits for the response.
	/// </summary>
	/// <param name="request">The message to send.</param>
	/// <returns>The remote party's response.</returns>
	protected abstract IProtocolMessage Request(IDirectedProtocolMessage request);

	/// <summary>
	/// Queues an indirect message for transmittal via the user agent.
	/// </summary>
	/// <param name="message">The message to send.</param>
	/// <remarks>
	/// Small payloads are sent as a redirect; payloads whose estimated encoded size exceeds
	/// the GET-to-POST threshold are sent as an auto-submitting HTML form instead.
	/// </remarks>
	/// <exception cref="ArgumentNullException">Thrown when <paramref name="message"/> is null.</exception>
	protected virtual void SendIndirectMessage(IDirectedProtocolMessage message) {
		if (message == null) {
			throw new ArgumentNullException("message");
		}

		var serializer = MessageSerializer.Get(message.GetType());
		var fields = serializer.Serialize(message);
		Response response;
		if (CalculateSizeOfPayload(fields) > indirectMessageGetToPostThreshold) {
			response = this.CreateFormPostResponse(message, fields);
		} else {
			response = this.Create301RedirectResponse(message, fields);
		}

		this.QueueIndirectOrResponseMessage(response);
	}

	/// <summary>
	/// Encodes an HTTP response that will instruct the user agent to forward a message to
	/// some remote third party using a redirect and the GET method.
	/// </summary>
	/// <param name="message">The message to forward.</param>
	/// <param name="fields">The pre-serialized fields from the message.</param>
	/// <returns>The encoded HTTP response.</returns>
	/// <remarks>
	/// Despite the method name, the response status is <see cref="HttpStatusCode.Redirect"/> (302),
	/// not a permanent 301 redirect. The name is kept for compatibility with existing overrides.
	/// </remarks>
	/// <exception cref="ArgumentNullException">Thrown when <paramref name="message"/> or <paramref name="fields"/> is null.</exception>
	/// <exception cref="ArgumentException">Thrown when <paramref name="message"/> has no recipient.</exception>
	protected virtual Response Create301RedirectResponse(IDirectedProtocolMessage message, IDictionary<string, string> fields) {
		if (message == null) {
			throw new ArgumentNullException("message");
		}

		if (message.Recipient == null) {
			throw new ArgumentException("The message must specify a recipient.", "message");
		}

		if (fields == null) {
			throw new ArgumentNullException("fields");
		}

		WebHeaderCollection headers = new WebHeaderCollection();
		UriBuilder builder = new UriBuilder(message.Recipient);
		MessagingUtilities.AppendQueryArgs(builder, fields);
		headers.Add(HttpResponseHeader.Location, builder.Uri.AbsoluteUri);
		Logger.DebugFormat("Redirecting to {0}", builder.Uri.AbsoluteUri);
		Response response = new Response {
			Status = HttpStatusCode.Redirect,
			Headers = headers,
			Body = new byte[0],
			OriginalMessage = message
		};
		return response;
	}

	/// <summary>
	/// Encodes an HTTP response that will instruct the user agent to forward a message to
	/// some remote third party using a form POST method.
	/// </summary>
	/// <param name="message">The message to forward.</param>
	/// <param name="fields">The pre-serialized fields from the message.</param>
	/// <returns>
	/// A 200 OK response whose UTF-8 HTML body contains a form that posts
	/// <paramref name="fields"/> to the message recipient.
	/// </returns>
	/// <exception cref="ArgumentNullException">Thrown when <paramref name="message"/> or <paramref name="fields"/> is null.</exception>
	/// <exception cref="ArgumentException">Thrown when <paramref name="message"/> has no recipient.</exception>
	protected virtual Response CreateFormPostResponse(IDirectedProtocolMessage message, IDictionary<string, string> fields) {
		if (message == null) {
			throw new ArgumentNullException("message");
		}

		if (message.Recipient == null) {
			throw new ArgumentException("The message must specify a recipient.", "message");
		}

		if (fields == null) {
			throw new ArgumentNullException("fields");
		}

		WebHeaderCollection headers = new WebHeaderCollection();

		// The body below is written as UTF-8 without a byte order mark.
		headers.Add(HttpResponseHeader.ContentType, "text/html; charset=utf-8");

		StringBuilder hiddenFields = new StringBuilder();
		foreach (var field in fields) {
			hiddenFields.AppendFormat(
				"\t<input type=\"hidden\" name=\"{0}\" value=\"{1}\" />\r\n",
				HttpUtility.HtmlEncode(field.Key),
				HttpUtility.HtmlEncode(field.Value));
		}

		MemoryStream body = new MemoryStream();
		using (StreamWriter bodyWriter = new StreamWriter(body, new UTF8Encoding(false))) {
			bodyWriter.WriteLine(
				indirectMessageFormPostFormat,
				HttpUtility.HtmlEncode(message.Recipient.AbsoluteUri),
				hiddenFields.ToString());
		}

		// MemoryStream.ToArray remains valid after the writer has closed the stream.
		Response response = new Response {
			Status = HttpStatusCode.OK,
			Headers = headers,
			Body = body.ToArray(),
			OriginalMessage = message
		};
		return response;
	}

	/// <summary>
	/// Queues a message for sending in the response stream where the fields
	/// are sent in the response stream in querystring style.
	/// </summary>
	/// <param name="response">The message to send as a response.</param>
	/// <remarks>
	/// This method implements spec V1.0 section 5.3.
	/// </remarks>
	protected abstract void SendDirectMessageResponse(IProtocolMessage response);

	/// <summary>
	/// Reports an error to the user via the user agent.
	/// </summary>
	/// <param name="exception">The error information.</param>
	protected abstract void ReportErrorToUser(ProtocolException exception);

	/// <summary>
	/// Sends an error result directly to the calling remote party according to the
	/// rules of the protocol.
	/// </summary>
	/// <param name="exception">The error information.</param>
	protected abstract void ReportErrorAsDirectResponse(ProtocolException exception);

	/// <summary>
	/// Calculates a fairly accurate estimation on the size of a message that contains
	/// a given set of fields.
	/// </summary>
	/// <param name="fields">The fields that would be included in a message.</param>
	/// <returns>
	/// The estimated size (in bytes) of the message payload once URL-encoded
	/// as a query string (UTF-8, percent-encoded).
	/// </returns>
	/// <exception cref="ArgumentNullException">Thrown when <paramref name="fields"/> is null.</exception>
	private static int CalculateSizeOfPayload(IDictionary<string, string> fields) {
		if (fields == null) {
			throw new ArgumentNullException("fields");
		}

		int size = 0;
		foreach (var field in fields) {
			// Measure the encoded form: non-ASCII and reserved characters expand to
			// multiple bytes on the wire, which is what the GET size limit applies to.
			size += GetUrlEncodedLength(field.Key);
			size += GetUrlEncodedLength(field.Value);
			size += 2; // & and =
		}

		return size;
	}

	/// <summary>
	/// Gets the length of a string after URL-encoding it, treating null as empty.
	/// </summary>
	/// <param name="value">The value to measure; may be null.</param>
	/// <returns>The number of characters in the URL-encoded form of <paramref name="value"/>.</returns>
	private static int GetUrlEncodedLength(string value) {
		return string.IsNullOrEmpty(value) ? 0 : HttpUtility.UrlEncode(value).Length;
	}
}
}