//-----------------------------------------------------------------------
// <copyright file="Channel.cs" company="Andrew Arnott">
//     Copyright (c) Andrew Arnott. All rights reserved.
// </copyright>
//-----------------------------------------------------------------------

namespace DotNetOAuth.Messaging {
    using System;
    using System.Collections.Generic;
    using System.IO;
    using System.Net;
    using System.Text;
    using System.Web;

    /// <summary>
    /// Manages sending direct messages to a remote party and receiving responses.
    /// </summary>
    internal abstract class Channel {
        /// <summary>
        /// The maximum allowable size for a 301 Redirect response before we send
        /// a 200 OK response with a scripted form POST with the parameters instead
        /// in order to ensure successfully sending a large payload to another server
        /// that might have a maximum allowable size restriction on its GET request.
        /// </summary>
        private const int IndirectMessageGetToPostThreshold = 2 * 1024; // 2KB, recommended by OpenID group

        /// <summary>
        /// The template for indirect messages that require form POST to forward through the user agent.
        /// </summary>
        /// <remarks>
        /// We are intentionally using " instead of the html single quote ' below because
        /// the HtmlEncode'd values that we inject will only escape the double quote, so
        /// only the double-quote used around these values is safe.
        /// Placeholder {0} receives the HTML-encoded recipient URL and {1} receives the
        /// pre-rendered hidden input fields.  Because this template is passed through
        /// string.Format, it must not contain any other curly braces.
        /// </remarks>
        private const string IndirectMessageFormPostFormat = @"
<html>
<body onload=""document.getElementById('messageForm').submit()"">
<form id=""messageForm"" action=""{0}"" method=""post"" accept-charset=""UTF-8"" enctype=""application/x-www-form-urlencoded"">
{1}
	<input type=""submit"" value=""Continue"" />
</form>
</body>
</html>
";

        /// <summary>
        /// The template for one hidden form field in a form POST indirect message.
        /// Placeholder {0} receives the HTML-encoded field name and {1} the HTML-encoded value.
        /// </summary>
        private const string HiddenFieldFormat = "\t<input type=\"hidden\" name=\"{0}\" value=\"{1}\" />\r\n";

        /// <summary>
        /// A tool that can figure out what kind of message is being received
        /// so it can be deserialized.
        /// </summary>
        private readonly IMessageTypeProvider messageTypeProvider;

        /// <summary>
        /// The HTTP response queued for sending as the reply to the current incoming HTTP request,
        /// or null when nothing is queued.
        /// </summary>
        private Response queuedIndirectOrResponseMessage;

        /// <summary>
        /// Initializes a new instance of the <see cref="Channel"/> class.
        /// </summary>
        /// <param name="messageTypeProvider">
        /// A class prepared to analyze incoming messages and indicate what concrete
        /// message types can deserialize from it.
        /// </param>
        /// <exception cref="ArgumentNullException">Thrown when <paramref name="messageTypeProvider"/> is null.</exception>
        protected Channel(IMessageTypeProvider messageTypeProvider) {
            if (messageTypeProvider == null) {
                throw new ArgumentNullException("messageTypeProvider");
            }

            this.messageTypeProvider = messageTypeProvider;
        }

        /// <summary>
        /// Gets a tool that can figure out what kind of message is being received
        /// so it can be deserialized.
        /// </summary>
        protected IMessageTypeProvider MessageTypeProvider {
            get { return this.messageTypeProvider; }
        }

        /// <summary>
        /// Retrieves the stored response for sending and clears it from the channel.
        /// </summary>
        /// <returns>The response to send as the HTTP response, or null if none was queued.</returns>
        internal Response DequeueIndirectOrResponseMessage() {
            Response response = this.queuedIndirectOrResponseMessage;
            this.queuedIndirectOrResponseMessage = null;
            return response;
        }

        /// <summary>
        /// Queues an indirect message (either a request or response)
        /// or direct message response for transmission to a remote party.
        /// </summary>
        /// <param name="message">The one-way message to send.</param>
        internal void Send(IProtocolMessage message) {
            this.Send(message, null);
        }

        /// <summary>
        /// Queues an indirect message (either a request or response)
        /// or direct message response for transmission to a remote party.
        /// </summary>
        /// <param name="message">The one-way message to send.</param>
        /// <param name="inResponseTo">
        /// If <paramref name="message"/> is a response to an incoming message, this is the incoming message.
        /// This is useful for error scenarios in deciding just how to send the response message.
        /// May be null.
        /// </param>
        /// <exception cref="ArgumentNullException">Thrown when <paramref name="message"/> is null.</exception>
        /// <exception cref="InvalidOperationException">
        /// Thrown when <paramref name="message"/> is an <see cref="IDirectedProtocolMessage"/> with no
        /// recipient and is not a <see cref="ProtocolException"/>.
        /// </exception>
        internal void Send(IProtocolMessage message, IProtocolMessage inResponseTo) {
            if (message == null) {
                throw new ArgumentNullException("message");
            }

            var directedMessage = message as IDirectedProtocolMessage;
            if (directedMessage == null) {
                // This is a response to a direct message.
                this.SendDirectMessageResponse(message);
            } else {
                if (directedMessage.Recipient != null) {
                    // This is an indirect message request or reply.
                    this.SendIndirectMessage(directedMessage);
                } else {
                    ProtocolException exception = message as ProtocolException;
                    if (exception != null) {
                        // How the error is reported depends on the kind of message that triggered it.
                        if (inResponseTo is IDirectedProtocolMessage) {
                            this.ReportErrorAsDirectResponse(exception);
                        } else {
                            this.ReportErrorToUser(exception);
                        }
                    } else {
                        throw new InvalidOperationException();
                    }
                }
            }
        }

        /// <summary>
        /// Gets the protocol message embedded in the given HTTP request, if present.
        /// </summary>
        /// <returns>The deserialized message, if one is found.  Null otherwise.</returns>
        /// <remarks>
        /// Requires an HttpContext.Current context.
        /// </remarks>
        internal IProtocolMessage ReadFromRequest() {
            return this.ReadFromRequest(new HttpRequestInfo(HttpContext.Current.Request));
        }

        /// <summary>
        /// Gets the protocol message that may be embedded in the given HTTP request.
        /// </summary>
        /// <param name="request">The request to search for an embedded message.</param>
        /// <returns>The deserialized message, if one is found.  Null otherwise.</returns>
        /// <exception cref="ArgumentNullException">Thrown when <paramref name="request"/> is null.</exception>
        protected internal virtual IProtocolMessage ReadFromRequest(HttpRequestInfo request) {
            if (request == null) {
                throw new ArgumentNullException("request");
            }

            // Search Form data first, and if nothing is there search the QueryString
            var fields = request.Form.ToDictionary();
            if (fields.Count == 0) {
                fields = request.QueryString.ToDictionary();
            }

            return this.Receive(fields);
        }

        /// <summary>
        /// Gets the protocol message that may be in the given HTTP response stream.
        /// </summary>
        /// <param name="responseStream">The response that is anticipated to contain an OAuth message.</param>
        /// <returns>The deserialized message, if one is found.  Null otherwise.</returns>
        protected internal abstract IProtocolMessage ReadFromResponse(Stream responseStream);

        /// <summary>
        /// Sends a direct message to a remote party and waits for the response.
        /// </summary>
        /// <param name="request">The message to send.</param>
        /// <returns>The remote party's response.</returns>
        protected internal abstract IProtocolMessage Request(IDirectedProtocolMessage request);

        /// <summary>
        /// Deserializes a dictionary of values into a message.
        /// </summary>
        /// <param name="fields">The dictionary of values that were read from an HTTP request or response.  May be null.</param>
        /// <returns>The deserialized message, or null if <paramref name="fields"/> is null or not recognized as a message.</returns>
        protected virtual IProtocolMessage Receive(Dictionary<string, string> fields) {
            Type messageType = null;
            if (fields != null) {
                messageType = this.MessageTypeProvider.GetRequestMessageType(fields);
            }

            // If there was no data, or we couldn't recognize it as a message, abort.
            if (messageType == null) {
                return null;
            }

            // We have a message!  Assemble it.
            var serializer = MessageSerializer.Get(messageType);
            IProtocolMessage message = serializer.Deserialize(fields);

            return message;
        }

        /// <summary>
        /// Takes a message and temporarily stores it for sending as the hosting site's
        /// HTTP response to the current request.
        /// </summary>
        /// <param name="response">The message to store for sending.</param>
        /// <exception cref="ArgumentNullException">Thrown when <paramref name="response"/> is null.</exception>
        /// <exception cref="InvalidOperationException">Thrown when a response is already queued and has not been dequeued.</exception>
        protected void QueueIndirectOrResponseMessage(Response response) {
            if (response == null) {
                throw new ArgumentNullException("response");
            }

            if (this.queuedIndirectOrResponseMessage != null) {
                throw new InvalidOperationException(MessagingStrings.QueuedMessageResponseAlreadyExists);
            }

            this.queuedIndirectOrResponseMessage = response;
        }

        /// <summary>
        /// Queues an indirect message for transmittal via the user agent.
        /// </summary>
        /// <param name="message">The message to send.</param>
        /// <exception cref="ArgumentNullException">Thrown when <paramref name="message"/> is null.</exception>
        protected virtual void SendIndirectMessage(IDirectedProtocolMessage message) {
            if (message == null) {
                throw new ArgumentNullException("message");
            }

            var serializer = MessageSerializer.Get(message.GetType());
            var fields = serializer.Serialize(message);

            // Large payloads may exceed GET length limits on the receiving server,
            // so switch to a scripted form POST above the threshold.
            Response response;
            if (CalculateSizeOfPayload(fields) > IndirectMessageGetToPostThreshold) {
                response = this.CreateFormPostResponse(message, fields);
            } else {
                response = this.Create301RedirectResponse(message, fields);
            }

            this.QueueIndirectOrResponseMessage(response);
        }

        /// <summary>
        /// Encodes an HTTP response that will instruct the user agent to forward a message to
        /// some remote third party using a 301 Redirect GET method.
        /// </summary>
        /// <param name="message">The message to forward.</param>
        /// <param name="fields">The pre-serialized fields from the message.</param>
        /// <returns>The encoded HTTP response.</returns>
        /// <exception cref="ArgumentNullException">Thrown when <paramref name="message"/> or <paramref name="fields"/> is null.</exception>
        protected virtual Response Create301RedirectResponse(IDirectedProtocolMessage message, IDictionary<string, string> fields) {
            if (message == null) {
                throw new ArgumentNullException("message");
            }
            if (fields == null) {
                throw new ArgumentNullException("fields");
            }

            WebHeaderCollection headers = new WebHeaderCollection();
            UriBuilder builder = new UriBuilder(message.Recipient);
            MessagingUtilities.AppendQueryArgs(builder, fields);
            headers.Add(HttpResponseHeader.Location, builder.Uri.AbsoluteUri);
            Logger.DebugFormat("Redirecting to {0}", builder.Uri.AbsoluteUri);

            // NOTE(review): HttpStatusCode.Redirect is 302 Found, not 301 Moved Permanently,
            // despite this method's name; a temporary redirect is presumably intended — confirm.
            Response response = new Response {
                Status = HttpStatusCode.Redirect,
                Headers = headers,
                Body = new byte[0],
                OriginalMessage = message
            };

            return response;
        }

        /// <summary>
        /// Encodes an HTTP response that will instruct the user agent to forward a message to
        /// some remote third party using a form POST method.
        /// </summary>
        /// <param name="message">The message to forward.</param>
        /// <param name="fields">The pre-serialized fields from the message.</param>
        /// <returns>The encoded HTTP response: a 200 OK page whose script submits the form.</returns>
        /// <exception cref="ArgumentNullException">Thrown when <paramref name="message"/> or <paramref name="fields"/> is null.</exception>
        protected virtual Response CreateFormPostResponse(IDirectedProtocolMessage message, IDictionary<string, string> fields) {
            if (message == null) {
                throw new ArgumentNullException("message");
            }
            if (fields == null) {
                throw new ArgumentNullException("fields");
            }

            WebHeaderCollection headers = new WebHeaderCollection();

            // Every injected value is HTML-encoded; the templates quote them with double quotes,
            // which HtmlEncode escapes.
            StringBuilder hiddenFields = new StringBuilder();
            foreach (var field in fields) {
                hiddenFields.AppendFormat(
                    HiddenFieldFormat,
                    HttpUtility.HtmlEncode(field.Key),
                    HttpUtility.HtmlEncode(field.Value));
            }

            MemoryStream body = new MemoryStream();
            using (StreamWriter bodyWriter = new StreamWriter(body)) {
                bodyWriter.WriteLine(
                    IndirectMessageFormPostFormat,
                    HttpUtility.HtmlEncode(message.Recipient.AbsoluteUri),
                    hiddenFields);
            }

            // The page itself is the payload and must be rendered by the user agent, so it is
            // sent as 200 OK rather than as a redirect (a 3xx without Location would break it).
            // MemoryStream.ToArray remains valid after the writer has closed the stream.
            Response response = new Response {
                Status = HttpStatusCode.OK,
                Headers = headers,
                Body = body.ToArray(),
                OriginalMessage = message
            };

            return response;
        }

        /// <summary>
        /// Queues a message for sending in the response stream where the fields
        /// are sent in the response stream in querystring style.
        /// </summary>
        /// <param name="response">The message to send as a response.</param>
        /// <remarks>
        /// This method implements spec V1.0 section 5.3.
        /// </remarks>
        protected abstract void SendDirectMessageResponse(IProtocolMessage response);

        /// <summary>
        /// Reports an error to the user via the user agent.
        /// </summary>
        /// <param name="exception">The error information.</param>
        protected abstract void ReportErrorToUser(ProtocolException exception);

        /// <summary>
        /// Sends an error result directly to the calling remote party according to the
        /// rules of the protocol.
        /// </summary>
        /// <param name="exception">The error information.</param>
        protected abstract void ReportErrorAsDirectResponse(ProtocolException exception);

        /// <summary>
        /// Calculates a fairly accurate estimation on the size of a message that contains
        /// a given set of fields.
        /// </summary>
        /// <param name="fields">The fields that would be included in a message.</param>
        /// <returns>
        /// The approximate payload size, counted in characters of the un-encoded keys and
        /// values plus one '=' and one '&amp;' per field.  URL encoding is not accounted for,
        /// so the actual query string may be longer.
        /// </returns>
        /// <exception cref="ArgumentNullException">Thrown when <paramref name="fields"/> is null.</exception>
        private static int CalculateSizeOfPayload(IDictionary<string, string> fields) {
            if (fields == null) {
                throw new ArgumentNullException("fields");
            }

            int size = 0;
            foreach (var field in fields) {
                size += field.Key.Length;
                size += field.Value.Length;
                size += 2; // & and =
            }

            return size;
        }
    }
}