//-----------------------------------------------------------------------
//
//     Copyright (c) Andrew Arnott. All rights reserved.
//
//-----------------------------------------------------------------------

namespace DotNetOAuth.Messaging {
	using System;
	using System.Collections.Generic;
	using System.Collections.ObjectModel;
	using System.Diagnostics;
	using System.Globalization;
	using System.IO;
	using System.Linq;
	using System.Net;
	using System.Text;
	using System.Web;
	using DotNetOAuth.Messaging.Reflection;

	/// <summary>
	/// Manages sending direct messages to a remote party and receiving responses.
	/// </summary>
	public abstract class Channel {
		/// <summary>
		/// The maximum allowable size for an indirect message sent as a redirect (GET) before we send
		/// a 200 OK response with a scripted form POST with the parameters instead
		/// in order to ensure successfully sending a large payload to another server
		/// that might have a maximum allowable size restriction on its GET request.
		/// </summary>
		private const int IndirectMessageGetToPostThreshold = 2 * 1024; // 2KB, recommended by OpenID group

		/// <summary>
		/// The template for indirect messages that require form POST to forward through the user agent.
		/// </summary>
		/// <remarks>
		/// <para>We are intentionally using " instead of the html single quote ' below because
		/// the HtmlEncode'd values that we inject will only escape the double quote, so
		/// only the double-quote used around these values is safe.</para>
		/// <para>This is a composite format string: {0} receives the HTML-encoded recipient URL and
		/// {1} receives the hidden input fields. It must not contain any other curly braces,
		/// or string formatting will fail at runtime.</para>
		/// </remarks>
		// NOTE(review): markup reconstructed from the placeholders this class feeds it; confirm against the canonical template.
		private const string IndirectMessageFormPostFormat = @"
<html>
<body onload=""document.getElementById('indirect_message').submit()"">
<form id=""indirect_message"" action=""{0}"" method=""post"" accept-charset=""UTF-8"" enctype=""application/x-www-form-urlencoded"">
{1}
	<noscript><input type=""submit"" value=""Continue"" /></noscript>
</form>
</body>
</html>
";

		/// <summary>
		/// A tool that can figure out what kind of message is being received
		/// so it can be deserialized.
		/// </summary>
		private readonly IMessageTypeProvider messageTypeProvider;

		/// <summary>
		/// A list of binding elements in the order they must be applied to outgoing messages.
		/// </summary>
		/// <remarks>
		/// Incoming messages should have the binding elements applied in reverse order.
		/// </remarks>
		[DebuggerBrowsable(DebuggerBrowsableState.Never)]
		private readonly List<IChannelBindingElement> bindingElements;

		/// <summary>
		/// Initializes a new instance of the <see cref="Channel"/> class.
		/// </summary>
		/// <param name="messageTypeProvider">
		/// A class prepared to analyze incoming messages and indicate what concrete
		/// message types can deserialize from it.
		/// </param>
		/// <param name="bindingElements">The binding elements to use in sending and receiving messages.</param>
		/// <exception cref="ArgumentNullException">Thrown when <paramref name="messageTypeProvider"/> is null.</exception>
		/// <exception cref="ArgumentException">Thrown when <paramref name="bindingElements"/> contains a null element.</exception>
		/// <exception cref="ProtocolException">Thrown when the binding elements are incomplete or inconsistent with each other.</exception>
		protected Channel(IMessageTypeProvider messageTypeProvider, params IChannelBindingElement[] bindingElements) {
			if (messageTypeProvider == null) {
				throw new ArgumentNullException("messageTypeProvider");
			}

			this.messageTypeProvider = messageTypeProvider;
			this.bindingElements = new List<IChannelBindingElement>(ValidateAndPrepareBindingElements(bindingElements));
		}

		/// <summary>
		/// Gets the binding elements used by this channel, in the order they are applied to outgoing messages.
		/// </summary>
		/// <remarks>
		/// Incoming messages are processed by these binding elements in the reverse order.
		/// </remarks>
		protected internal ReadOnlyCollection<IChannelBindingElement> BindingElements {
			get { return this.bindingElements.AsReadOnly(); }
		}

		/// <summary>
		/// Gets a tool that can figure out what kind of message is being received
		/// so it can be deserialized.
		/// </summary>
		protected IMessageTypeProvider MessageTypeProvider {
			get { return this.messageTypeProvider; }
		}

		/// <summary>
		/// Queues an indirect message (either a request or response)
		/// or direct message response for transmission to a remote party.
		/// </summary>
		/// <param name="message">The one-way message to send.</param>
		/// <returns>The pending user agent redirect based message to be sent as an HttpResponse.</returns>
		/// <exception cref="ArgumentNullException">Thrown when <paramref name="message"/> is null.</exception>
		/// <exception cref="ArgumentException">
		/// Thrown when an indirect message does not implement <see cref="IDirectedProtocolMessage"/>,
		/// has no recipient, or when the message's transport is not recognized.
		/// </exception>
		public Response Send(IProtocolMessage message) {
			if (message == null) {
				throw new ArgumentNullException("message");
			}

			this.PrepareMessageForSending(message);
			Logger.DebugFormat("Sending message: {0}", message);

			switch (message.Transport) {
				case MessageTransport.Direct:
					// This is a response to a direct message.
					return this.SendDirectMessageResponse(message);
				case MessageTransport.Indirect:
					var directedMessage = message as IDirectedProtocolMessage;
					if (directedMessage == null) {
						throw new ArgumentException(
							string.Format(
								CultureInfo.CurrentCulture,
								MessagingStrings.IndirectMessagesMustImplementIDirectedProtocolMessage,
								typeof(IDirectedProtocolMessage).FullName),
							"message");
					}

					if (directedMessage.Recipient == null) {
						throw new ArgumentException(MessagingStrings.DirectedMessageMissingRecipient, "message");
					}

					return this.SendIndirectMessage(directedMessage);
				default:
					throw new ArgumentException(
						string.Format(
							CultureInfo.CurrentCulture,
							MessagingStrings.UnrecognizedEnumValue,
							"Transport",
							message.Transport),
						"message");
			}
		}

		/// <summary>
		/// Gets the protocol message embedded in the current HTTP request, if present.
		/// </summary>
		/// <returns>The deserialized message, if one is found. Null otherwise.</returns>
		/// <remarks>
		/// Requires an HttpContext.Current context.
		/// </remarks>
		/// <exception cref="InvalidOperationException">Thrown when <see cref="HttpContext.Current"/> is null.</exception>
		public IProtocolMessage ReadFromRequest() {
			return this.ReadFromRequest(this.GetRequestFromContext());
		}

		/// <summary>
		/// Gets the protocol message embedded in the current HTTP request, if present.
		/// </summary>
		/// <typeparam name="TREQUEST">The expected type of the message to be received.</typeparam>
		/// <param name="request">The deserialized message, if one is found. Null otherwise.</param>
		/// <returns>True if the expected message was recognized and deserialized. False otherwise.</returns>
		/// <remarks>
		/// Requires an HttpContext.Current context.
		/// </remarks>
		/// <exception cref="InvalidOperationException">Thrown when <see cref="HttpContext.Current"/> is null.</exception>
		/// <exception cref="ProtocolException">Thrown when a request message of an unexpected type is received.</exception>
		public bool TryReadFromRequest<TREQUEST>(out TREQUEST request)
			where TREQUEST : class, IProtocolMessage {
			return this.TryReadFromRequest<TREQUEST>(this.GetRequestFromContext(), out request);
		}

		/// <summary>
		/// Gets the protocol message embedded in the given HTTP request, if present.
		/// </summary>
		/// <typeparam name="TREQUEST">The expected type of the message to be received.</typeparam>
		/// <param name="httpRequest">The request to search for an embedded message.</param>
		/// <param name="request">The deserialized message, if one is found. Null otherwise.</param>
		/// <returns>True if the expected message was recognized and deserialized. False otherwise.</returns>
		/// <exception cref="ArgumentNullException">Thrown when <paramref name="httpRequest"/> is null.</exception>
		/// <exception cref="ProtocolException">Thrown when a request message of an unexpected type is received.</exception>
		public bool TryReadFromRequest<TREQUEST>(HttpRequestInfo httpRequest, out TREQUEST request)
			where TREQUEST : class, IProtocolMessage {
			IProtocolMessage untypedRequest = this.ReadFromRequest(httpRequest);
			if (untypedRequest == null) {
				request = null;
				return false;
			}

			request = untypedRequest as TREQUEST;
			if (request == null) {
				throw new ProtocolException(
					string.Format(
						CultureInfo.CurrentCulture,
						MessagingStrings.UnexpectedMessageReceived,
						typeof(TREQUEST),
						untypedRequest.GetType()));
			}

			return true;
		}

		/// <summary>
		/// Gets the protocol message embedded in the current HTTP request.
		/// </summary>
		/// <typeparam name="TREQUEST">The expected type of the message to be received.</typeparam>
		/// <returns>The deserialized message.</returns>
		/// <remarks>
		/// Requires an HttpContext.Current context.
		/// </remarks>
		/// <exception cref="InvalidOperationException">Thrown when <see cref="HttpContext.Current"/> is null.</exception>
		/// <exception cref="ProtocolException">Thrown if the expected message was not recognized in the request.</exception>
		public TREQUEST ReadFromRequest<TREQUEST>()
			where TREQUEST : class, IProtocolMessage {
			// The explicit type argument matters: without it this would bind to the
			// non-generic overload, which returns IProtocolMessage rather than TREQUEST.
			return this.ReadFromRequest<TREQUEST>(this.GetRequestFromContext());
		}

		/// <summary>
		/// Gets the protocol message that must be embedded in the given HTTP request.
		/// </summary>
		/// <typeparam name="TREQUEST">The expected type of the message to be received.</typeparam>
		/// <param name="httpRequest">The request to search for an embedded message.</param>
		/// <returns>The deserialized message.</returns>
		/// <exception cref="ArgumentNullException">Thrown when <paramref name="httpRequest"/> is null.</exception>
		/// <exception cref="ProtocolException">Thrown if the expected message was not recognized in the request.</exception>
		public TREQUEST ReadFromRequest<TREQUEST>(HttpRequestInfo httpRequest)
			where TREQUEST : class, IProtocolMessage {
			TREQUEST request;
			if (this.TryReadFromRequest<TREQUEST>(httpRequest, out request)) {
				return request;
			} else {
				throw new ProtocolException(
					string.Format(
						CultureInfo.CurrentCulture,
						MessagingStrings.ExpectedMessageNotReceived,
						typeof(TREQUEST)));
			}
		}

		/// <summary>
		/// Gets the protocol message that may be embedded in the given HTTP request.
		/// </summary>
		/// <param name="httpRequest">The request to search for an embedded message.</param>
		/// <returns>The deserialized message, if one is found. Null otherwise.</returns>
		/// <exception cref="ArgumentNullException">Thrown when <paramref name="httpRequest"/> is null.</exception>
		public IProtocolMessage ReadFromRequest(HttpRequestInfo httpRequest) {
			// Validate here rather than relying on every (virtual) ReadFromRequestInternal override to do so.
			if (httpRequest == null) {
				throw new ArgumentNullException("httpRequest");
			}

			IProtocolMessage requestMessage = this.ReadFromRequestInternal(httpRequest);
			if (requestMessage != null) {
				Logger.DebugFormat("Incoming request received: {0}", requestMessage);
				this.VerifyMessageAfterReceiving(requestMessage);
			}

			return requestMessage;
		}

		/// <summary>
		/// Sends a direct message to a remote party and waits for the response.
		/// </summary>
		/// <typeparam name="TRESPONSE">The expected type of the message to be received.</typeparam>
		/// <param name="request">The message to send.</param>
		/// <returns>The remote party's response.</returns>
		/// <exception cref="ProtocolException">
		/// Thrown if no message is recognized in the response
		/// or an unexpected type of message is received.
		/// </exception>
		public TRESPONSE Request<TRESPONSE>(IDirectedProtocolMessage request)
			where TRESPONSE : class, IProtocolMessage {
			IProtocolMessage response = this.Request(request);
			if (response == null) {
				throw new ProtocolException(
					string.Format(
						CultureInfo.CurrentCulture,
						MessagingStrings.ExpectedMessageNotReceived,
						typeof(TRESPONSE)));
			}

			var expectedResponse = response as TRESPONSE;
			if (expectedResponse == null) {
				throw new ProtocolException(
					string.Format(
						CultureInfo.CurrentCulture,
						MessagingStrings.UnexpectedMessageReceived,
						typeof(TRESPONSE),
						response.GetType()));
			}

			return expectedResponse;
		}

		/// <summary>
		/// Sends a direct message to a remote party and waits for the response.
		/// </summary>
		/// <param name="request">The message to send.</param>
		/// <returns>The remote party's response, or null if no message was recognized in it.</returns>
		/// <exception cref="ArgumentNullException">Thrown when <paramref name="request"/> is null.</exception>
		public IProtocolMessage Request(IDirectedProtocolMessage request) {
			if (request == null) {
				throw new ArgumentNullException("request");
			}

			this.PrepareMessageForSending(request);
			Logger.DebugFormat("Sending request: {0}", request);
			IProtocolMessage response = this.RequestInternal(request);
			if (response != null) {
				Logger.DebugFormat("Received response: {0}", response);
				this.VerifyMessageAfterReceiving(response);
			}

			return response;
		}

		/// <summary>
		/// Gets the protocol message that may be in the given HTTP response stream.
		/// </summary>
		/// <param name="responseStream">The response that is anticipated to contain an OAuth message.</param>
		/// <returns>The deserialized message, if one is found. Null otherwise.</returns>
		/// <exception cref="ArgumentNullException">Thrown when <paramref name="responseStream"/> is null.</exception>
		private IProtocolMessage ReadFromResponse(Stream responseStream) {
			if (responseStream == null) {
				throw new ArgumentNullException("responseStream");
			}

			IProtocolMessage message = this.ReadFromResponseInternal(responseStream);

			// No recognizable message: report null like ReadFromRequest/Request do, rather than
			// handing null to VerifyMessageAfterReceiving (which dereferences it).
			if (message != null) {
				Logger.DebugFormat("Received message response: {0}", message);
				this.VerifyMessageAfterReceiving(message);
			}

			return message;
		}

		/// <summary>
		/// Gets the current HTTP request being processed.
		/// </summary>
		/// <returns>The HttpRequestInfo for the current request.</returns>
		/// <remarks>
		/// Requires an HttpContext.Current context.
		/// </remarks>
		/// <exception cref="InvalidOperationException">Thrown when <see cref="HttpContext.Current"/> is null.</exception>
		protected internal virtual HttpRequestInfo GetRequestFromContext() {
			if (HttpContext.Current == null) {
				throw new InvalidOperationException(MessagingStrings.HttpContextRequired);
			}

			return new HttpRequestInfo(HttpContext.Current.Request);
		}

		/// <summary>
		/// Gets the protocol message that may be embedded in the given HTTP request.
		/// </summary>
		/// <param name="request">The request to search for an embedded message.</param>
		/// <returns>The deserialized message, if one is found. Null otherwise.</returns>
		/// <exception cref="ArgumentNullException">Thrown when <paramref name="request"/> is null.</exception>
		protected virtual IProtocolMessage ReadFromRequestInternal(HttpRequestInfo request) {
			if (request == null) {
				throw new ArgumentNullException("request");
			}

			// Search Form data first, and if nothing is there search the QueryString
			var fields = request.Form.ToDictionary();
			if (fields.Count == 0) {
				fields = request.QueryString.ToDictionary();
			}

			return this.Receive(fields, request.GetRecipient());
		}

		/// <summary>
		/// Deserializes a dictionary of values into a message.
		/// </summary>
		/// <param name="fields">The dictionary of values that were read from an HTTP request or response.</param>
		/// <param name="recipient">Information about where the message was directed. Null for direct response messages.</param>
		/// <returns>The deserialized message, or null if no message could be recognized in the provided data.</returns>
		/// <exception cref="ArgumentNullException">Thrown when <paramref name="fields"/> is null.</exception>
		protected virtual IProtocolMessage Receive(Dictionary<string, string> fields, MessageReceivingEndpoint recipient) {
			if (fields == null) {
				throw new ArgumentNullException("fields");
			}

			Type messageType = this.MessageTypeProvider.GetRequestMessageType(fields);

			// If there was no data, or we couldn't recognize it as a message, abort.
			if (messageType == null) {
				return null;
			}

			// We have a message! Assemble it.
			var serializer = MessageSerializer.Get(messageType);
			IProtocolMessage message = serializer.Deserialize(fields, recipient);

			return message;
		}

		/// <summary>
		/// Queues an indirect message for transmittal via the user agent.
		/// </summary>
		/// <param name="message">The message to send.</param>
		/// <returns>The pending user agent redirect based message to be sent as an HttpResponse.</returns>
		/// <remarks>
		/// Small payloads are sent as a redirect; payloads whose estimated size exceeds
		/// <see cref="IndirectMessageGetToPostThreshold"/> are sent as a scripted form POST.
		/// </remarks>
		/// <exception cref="ArgumentNullException">Thrown when <paramref name="message"/> is null.</exception>
		protected virtual Response SendIndirectMessage(IDirectedProtocolMessage message) {
			if (message == null) {
				throw new ArgumentNullException("message");
			}

			var serializer = MessageSerializer.Get(message.GetType());
			var fields = serializer.Serialize(message);
			Response response;
			if (CalculateSizeOfPayload(fields) > IndirectMessageGetToPostThreshold) {
				response = this.CreateFormPostResponse(message, fields);
			} else {
				response = this.Create301RedirectResponse(message, fields);
			}

			return response;
		}

		/// <summary>
		/// Encodes an HTTP response that will instruct the user agent to forward a message to
		/// some remote third party using a redirect GET method.
		/// </summary>
		/// <param name="message">The message to forward.</param>
		/// <param name="fields">The pre-serialized fields from the message.</param>
		/// <returns>The encoded HTTP response.</returns>
		/// <remarks>
		/// Despite the method name, the response status is <see cref="HttpStatusCode.Redirect"/> (302 Found).
		/// </remarks>
		/// <exception cref="ArgumentNullException">Thrown when <paramref name="message"/> or <paramref name="fields"/> is null.</exception>
		/// <exception cref="ArgumentException">Thrown when the message has no recipient.</exception>
		protected virtual Response Create301RedirectResponse(IDirectedProtocolMessage message, IDictionary<string, string> fields) {
			if (message == null) {
				throw new ArgumentNullException("message");
			}
			if (message.Recipient == null) {
				throw new ArgumentException(MessagingStrings.DirectedMessageMissingRecipient, "message");
			}
			if (fields == null) {
				throw new ArgumentNullException("fields");
			}

			WebHeaderCollection headers = new WebHeaderCollection();
			UriBuilder builder = new UriBuilder(message.Recipient);
			MessagingUtilities.AppendQueryArgs(builder, fields);
			headers.Add(HttpResponseHeader.Location, builder.Uri.AbsoluteUri);
			Logger.DebugFormat("Redirecting to {0}", builder.Uri.AbsoluteUri);
			Response response = new Response {
				Status = HttpStatusCode.Redirect,
				Headers = headers,
				Body = null,
				OriginalMessage = message
			};

			return response;
		}

		/// <summary>
		/// Encodes an HTTP response that will instruct the user agent to forward a message to
		/// some remote third party using a form POST method.
		/// </summary>
		/// <param name="message">The message to forward.</param>
		/// <param name="fields">The pre-serialized fields from the message.</param>
		/// <returns>The encoded HTTP response.</returns>
		/// <exception cref="ArgumentNullException">Thrown when <paramref name="message"/> or <paramref name="fields"/> is null.</exception>
		/// <exception cref="ArgumentException">Thrown when the message has no recipient.</exception>
		protected virtual Response CreateFormPostResponse(IDirectedProtocolMessage message, IDictionary<string, string> fields) {
			if (message == null) {
				throw new ArgumentNullException("message");
			}
			if (message.Recipient == null) {
				throw new ArgumentException(MessagingStrings.DirectedMessageMissingRecipient, "message");
			}
			if (fields == null) {
				throw new ArgumentNullException("fields");
			}

			WebHeaderCollection headers = new WebHeaderCollection();
			StringBuilder hiddenFields = new StringBuilder();
			foreach (var field in fields) {
				// Keys and values are HTML-encoded and placed inside double-quoted attributes,
				// which HtmlEncode escapes (see the remarks on IndirectMessageFormPostFormat).
				hiddenFields.AppendFormat(
					CultureInfo.InvariantCulture,
					"\t<input type=\"hidden\" name=\"{0}\" value=\"{1}\" />\r\n",
					HttpUtility.HtmlEncode(field.Key),
					HttpUtility.HtmlEncode(field.Value));
			}

			string body;
			using (StringWriter bodyWriter = new StringWriter(CultureInfo.InvariantCulture)) {
				bodyWriter.WriteLine(
					IndirectMessageFormPostFormat,
					HttpUtility.HtmlEncode(message.Recipient.AbsoluteUri),
					hiddenFields);
				body = bodyWriter.ToString();
			}

			Response response = new Response {
				Status = HttpStatusCode.OK,
				Headers = headers,
				Body = body,
				OriginalMessage = message
			};

			return response;
		}

		/// <summary>
		/// Gets the protocol message that may be in the given HTTP response stream.
		/// </summary>
		/// <param name="responseStream">The response that is anticipated to contain an OAuth message.</param>
		/// <returns>The deserialized message, if one is found. Null otherwise.</returns>
		protected abstract IProtocolMessage ReadFromResponseInternal(Stream responseStream);

		/// <summary>
		/// Sends a direct message to a remote party and waits for the response.
		/// </summary>
		/// <param name="request">The message to send.</param>
		/// <returns>The remote party's response.</returns>
		protected abstract IProtocolMessage RequestInternal(IDirectedProtocolMessage request);

		/// <summary>
		/// Queues a message for sending in the response stream where the fields
		/// are sent in the response stream in querystring style.
		/// </summary>
		/// <param name="response">The message to send as a response.</param>
		/// <returns>The pending user agent redirect based message to be sent as an HttpResponse.</returns>
		/// <remarks>
		/// This method implements spec V1.0 section 5.3.
		/// </remarks>
		protected abstract Response SendDirectMessageResponse(IProtocolMessage response);

		/// <summary>
		/// Prepares a message for transmit by applying signatures, nonces, etc.
		/// </summary>
		/// <param name="message">The message to prepare for sending.</param>
		/// <remarks>
		/// This method should NOT be called by derived types
		/// except when sending ONE WAY request messages.
		/// </remarks>
		/// <exception cref="ArgumentNullException">Thrown when <paramref name="message"/> is null.</exception>
		/// <exception cref="UnprotectedMessageException">
		/// Thrown when the binding elements did not apply every protection the message requires.
		/// </exception>
		protected void PrepareMessageForSending(IProtocolMessage message) {
			if (message == null) {
				throw new ArgumentNullException("message");
			}

			MessageProtections appliedProtection = MessageProtections.None;
			foreach (IChannelBindingElement bindingElement in this.bindingElements) {
				if (bindingElement.PrepareMessageForSending(message)) {
					appliedProtection |= bindingElement.Protection;
				}
			}

			// Ensure that the message's protection requirements have been satisfied.
			if ((message.RequiredProtection & appliedProtection) != message.RequiredProtection) {
				throw new UnprotectedMessageException(message, appliedProtection);
			}

			EnsureValidMessageParts(message);
			message.EnsureValidMessage();
		}

		/// <summary>
		/// Calculates a fairly accurate estimation on the size of a message that contains
		/// a given set of fields.
		/// </summary>
		/// <param name="fields">The fields that would be included in a message.</param>
		/// <returns>The size (in bytes) of the URL-encoded message payload.</returns>
		private static int CalculateSizeOfPayload(IDictionary<string, string> fields) {
			Debug.Assert(fields != null, "fields == null");

			int size = 0;
			foreach (var field in fields) {
				// Measure the URL-encoded form, since that is what a redirect puts on the query
				// string; raw lengths under-count every character that must be escaped.
				size += HttpUtility.UrlEncode(field.Key).Length;
				size += HttpUtility.UrlEncode(field.Value).Length;
				size += 2; // & and =
			}

			return size;
		}

		/// <summary>
		/// Ensures a consistent and secure set of binding elements and
		/// sorts them as necessary for a valid sequence of operations.
		/// </summary>
		/// <param name="elements">The binding elements provided to the channel.</param>
		/// <returns>The properly ordered list of elements.</returns>
		/// <exception cref="ArgumentException">Thrown when <paramref name="elements"/> contains a null element.</exception>
		/// <exception cref="ProtocolException">Thrown when the binding elements are incomplete or inconsistent with each other.</exception>
		private static IEnumerable<IChannelBindingElement> ValidateAndPrepareBindingElements(IEnumerable<IChannelBindingElement> elements) {
			if (elements == null) {
				return new IChannelBindingElement[0];
			}

			// Materialize once so a lazily-evaluated sequence is not enumerated repeatedly below.
			var elementList = elements.ToList();
			if (elementList.Contains(null)) {
				throw new ArgumentException(MessagingStrings.SequenceContainsNullElement, "elements");
			}

			// Filter the elements between the mere transforming ones and the protection ones.
			var transformationElements = new List<IChannelBindingElement>(
				elementList.Where(element => element.Protection == MessageProtections.None));
			var protectionElements = new List<IChannelBindingElement>(
				elementList.Where(element => element.Protection != MessageProtections.None));

			bool wasLastProtectionPresent = true;
			foreach (MessageProtections protectionKind in Enum.GetValues(typeof(MessageProtections))) {
				if (protectionKind == MessageProtections.None) {
					continue;
				}

				int countProtectionsOfThisKind = protectionElements.Count(element => (element.Protection & protectionKind) == protectionKind);

				// Each protection binding element is backed by the presence of its dependent protection(s).
				if (countProtectionsOfThisKind > 0 && !wasLastProtectionPresent) {
					throw new ProtocolException(
						string.Format(
							CultureInfo.CurrentCulture,
							MessagingStrings.RequiredProtectionMissing,
							protectionKind));
				}

				// At most one binding element for each protection type.
				if (countProtectionsOfThisKind > 1) {
					throw new ProtocolException(
						string.Format(
							CultureInfo.CurrentCulture,
							MessagingStrings.TooManyBindingsOfferingSameProtection,
							protectionKind,
							countProtectionsOfThisKind));
				}

				wasLastProtectionPresent = countProtectionsOfThisKind > 0;
			}

			// Put the binding elements in order so they are correctly applied to outgoing messages.
			// Start with the transforming (non-protecting) binding elements first and preserve their original order.
			var orderedList = new List<IChannelBindingElement>(transformationElements);

			// Now sort the protection binding elements among themselves and add them to the list.
			orderedList.AddRange(protectionElements.OrderBy(element => element.Protection, BindingElementOutgoingMessageApplicationOrder));
			return orderedList;
		}

		/// <summary>
		/// Puts binding elements in their correct outgoing message processing order.
		/// </summary>
		/// <param name="protection1">The first protection type to compare.</param>
		/// <param name="protection2">The second protection type to compare.</param>
		/// <returns>
		/// -1 if <paramref name="protection1"/> should be applied to an outgoing message before <paramref name="protection2"/>.
		/// 1 if <paramref name="protection2"/> should be applied to an outgoing message before <paramref name="protection1"/>.
		/// 0 if it doesn't matter.
		/// </returns>
		private static int BindingElementOutgoingMessageApplicationOrder(MessageProtections protection1, MessageProtections protection2) {
			// Both operands must be protection elements; a None here would mean a transforming
			// element is being reordered, which would change the user-defined transformation order.
			Debug.Assert(protection1 != MessageProtections.None && protection2 != MessageProtections.None, "This comparison function should only be used to compare protection binding elements. Otherwise we change the order of user-defined message transformations.");

			// Now put the protection ones in the right order.
			return -((int)protection1).CompareTo((int)protection2); // descending flag ordinal order
		}

		/// <summary>
		/// Verifies that all required message parts are initialized to values
		/// prior to sending the message to a remote party.
		/// </summary>
		/// <param name="message">The message to verify.</param>
		/// <exception cref="ProtocolException">
		/// Thrown when any required message part does not have a value.
		/// </exception>
		private static void EnsureValidMessageParts(IProtocolMessage message) {
			Debug.Assert(message != null, "message == null");

			MessageDictionary dictionary = new MessageDictionary(message);
			MessageDescription description = MessageDescription.Get(message.GetType());
			description.EnsureRequiredMessagePartsArePresent(dictionary.Keys);
		}

		/// <summary>
		/// Verifies the integrity and applicability of an incoming message.
		/// </summary>
		/// <param name="message">The message just received.</param>
		/// <exception cref="ProtocolException">
		/// Thrown when the message is somehow invalid.
		/// This can be due to tampering, replay attack or expiration, among other things.
		/// </exception>
		private void VerifyMessageAfterReceiving(IProtocolMessage message) {
			Debug.Assert(message != null, "message == null");

			MessageProtections appliedProtection = MessageProtections.None;

			// Enumerable.Reverse yields a reversed view without touching the field. Calling
			// this.bindingElements.Reverse() would instead bind to the void List<T>.Reverse().
			foreach (IChannelBindingElement bindingElement in Enumerable.Reverse(this.bindingElements)) {
				if (bindingElement.PrepareMessageForReceiving(message)) {
					appliedProtection |= bindingElement.Protection;
				}
			}

			// Ensure that the message's protection requirements have been satisfied.
			if ((message.RequiredProtection & appliedProtection) != message.RequiredProtection) {
				throw new UnprotectedMessageException(message, appliedProtection);
			}

			// We do NOT verify that all required message parts are present here... the
			// message deserializer did for us. It would be too late to do it here since
			// they might look initialized by the time we have an IProtocolMessage instance.
			message.EnsureValidMessage();
		}
	}
}