summaryrefslogtreecommitdiffstats
path: root/src/DotNetOAuth/Messaging/Channel.cs
diff options
context:
space:
mode:
Diffstat (limited to 'src/DotNetOAuth/Messaging/Channel.cs')
-rw-r--r--src/DotNetOAuth/Messaging/Channel.cs135
1 files changed, 134 insertions, 1 deletions
diff --git a/src/DotNetOAuth/Messaging/Channel.cs b/src/DotNetOAuth/Messaging/Channel.cs
index b9a7629..6951046 100644
--- a/src/DotNetOAuth/Messaging/Channel.cs
+++ b/src/DotNetOAuth/Messaging/Channel.cs
@@ -7,6 +7,7 @@
namespace DotNetOAuth.Messaging {
using System;
using System.Collections.Generic;
+ using System.Diagnostics;
using System.IO;
using System.Net;
using System.Text;
@@ -17,6 +18,33 @@ namespace DotNetOAuth.Messaging {
/// </summary>
internal abstract class Channel {
/// <summary>
+ /// The maximum allowable size for a 301 Redirect response before we send
+ /// a 200 OK response with a scripted form POST with the parameters instead
+ /// in order to ensure successfully sending a large payload to another server
+ /// that might have a maximum allowable size restriction on its GET request.
+ /// </summary>
+ private static int indirectMessageGetToPostThreshold = 2 * 1024; // 2KB, recommended by OpenID group
+
+ /// <summary>
+ /// The template for indirect messages that require form POST to forward through the user agent.
+ /// </summary>
+ /// <remarks>
+ /// We are intentionally using " instead of the html single quote ' below because
+ /// the HtmlEncode'd values that we inject will only escape the double quote, so
+ /// only the double-quote used around these values is safe.
+ /// </remarks>
+ private static string indirectMessageFormPostFormat = @"
+<html>
+<body onload=""var btn = document.getElementById('submit_button'); btn.disabled = true; btn.value = 'Login in progress'; document.getElementById('openid_message').submit()"">
+<form id=""openid_message"" action=""{0}"" method=""post"" accept-charset=""UTF-8"" enctype=""application/x-www-form-urlencoded"" onSubmit=""var btn = document.getElementById('submit_button'); btn.disabled = true; btn.value = 'Login in progress'; return true;"">
+{1}
+ <input id=""submit_button"" type=""submit"" value=""Continue"" />
+</form>
+</body>
+</html>
+";
+
+ /// <summary>
/// A tool that can figure out what kind of message is being received
/// so it can be deserialized.
/// </summary>
@@ -128,7 +156,92 @@ namespace DotNetOAuth.Messaging {
/// Queues an indirect message for transmittal via the user agent.
/// </summary>
/// <param name="message">The message to send.</param>
- protected abstract void SendIndirectMessage(IDirectedProtocolMessage message);
+ protected virtual void SendIndirectMessage(IDirectedProtocolMessage message) {
+ if (message == null) {
+ throw new ArgumentNullException("message");
+ }
+
+ var serializer = MessageSerializer.Get(message.GetType());
+ var fields = serializer.Serialize(message);
+ Response response;
+ if (CalculateSizeOfPayload(fields) > indirectMessageGetToPostThreshold) {
+ response = this.CreateFormPostResponse(message, fields);
+ } else {
+ response = this.Create301RedirectResponse(message, fields);
+ }
+
+ this.QueueIndirectOrResponseMessage(response);
+ }
+
+ /// <summary>
+ /// Encodes an HTTP response that will instruct the user agent to forward a message to
+ /// some remote third party using a 301 Redirect GET method.
+ /// </summary>
+ /// <param name="message">The message to forward.</param>
+ /// <param name="fields">The pre-serialized fields from the message.</param>
+ /// <returns>The encoded HTTP response.</returns>
+ protected virtual Response Create301RedirectResponse(IDirectedProtocolMessage message, IDictionary<string, string> fields) {
+ if (message == null) {
+ throw new ArgumentNullException("message");
+ }
+ if (fields == null) {
+ throw new ArgumentNullException("fields");
+ }
+
+ WebHeaderCollection headers = new WebHeaderCollection();
+ UriBuilder builder = new UriBuilder(message.Recipient);
+ MessagingUtilities.AppendQueryArgs(builder, fields);
+ headers.Add(HttpResponseHeader.Location, builder.Uri.AbsoluteUri);
+ Logger.DebugFormat("Redirecting to {0}", builder.Uri.AbsoluteUri);
+ Response response = new Response {
+ Status = HttpStatusCode.Redirect,
+ Headers = headers,
+ Body = new byte[0],
+ OriginalMessage = message
+ };
+
+ return response;
+ }
+
+ /// <summary>
+ /// Encodes an HTTP response that will instruct the user agent to forward a message to
+ /// some remote third party using a form POST method.
+ /// </summary>
+ /// <param name="message">The message to forward.</param>
+ /// <param name="fields">The pre-serialized fields from the message.</param>
+ /// <returns>The encoded HTTP response.</returns>
+ protected virtual Response CreateFormPostResponse(IDirectedProtocolMessage message, IDictionary<string, string> fields) {
+ if (message == null) {
+ throw new ArgumentNullException("message");
+ }
+ if (fields == null) {
+ throw new ArgumentNullException("fields");
+ }
+
+ WebHeaderCollection headers = new WebHeaderCollection();
+ MemoryStream body = new MemoryStream();
+ StreamWriter bodyWriter = new StreamWriter(body);
+ StringBuilder hiddenFields = new StringBuilder();
+ foreach (var field in fields) {
+ hiddenFields.AppendFormat(
+ "\t<input type=\"hidden\" name=\"{0}\" value=\"{1}\" />\r\n",
+ HttpUtility.HtmlEncode(field.Key),
+ HttpUtility.HtmlEncode(field.Value));
+ }
+ bodyWriter.WriteLine(
+ indirectMessageFormPostFormat,
+ HttpUtility.HtmlEncode(message.Recipient.AbsoluteUri),
+ hiddenFields);
+ bodyWriter.Flush();
+ Response response = new Response {
+ Status = HttpStatusCode.Redirect,
+ Headers = headers,
+ Body = body.ToArray(),
+ OriginalMessage = message
+ };
+
+ return response;
+ }
/// <summary>
/// Queues a message for sending in the response stream where the fields
@@ -152,5 +265,25 @@ namespace DotNetOAuth.Messaging {
/// </summary>
/// <param name="exception">The error information.</param>
protected abstract void ReportErrorAsDirectResponse(ProtocolException exception);
+
+ /// <summary>
+ /// Calculates a fairly accurate estimation on the size of a message that contains
+ /// a given set of fields.
+ /// </summary>
+ /// <param name="fields">The fields that would be included in a message.</param>
+ /// <returns>The size (in bytes) of the message payload.</returns>
+ private static int CalculateSizeOfPayload(IDictionary<string, string> fields) {
+ if (fields == null) {
+ throw new ArgumentNullException("fields");
+ }
+
+ int size = 0;
+ foreach (var field in fields) {
+ size += field.Key.Length;
+ size += field.Value.Length;
+ size += 2; // & and =
+ }
+ return size;
+ }
}
}