//-----------------------------------------------------------------------
//
//     Copyright (c) Outercurve Foundation. All rights reserved.
//
//-----------------------------------------------------------------------

namespace DotNetOpenAuth.Messaging {
    using System;
    using System.Collections.Generic;
    using System.Diagnostics.Contracts;
    using System.Linq;
    using System.Reflection;
    using System.Text;
    using DotNetOpenAuth.Messaging.Reflection;

    /// <summary>
    /// A message factory that automatically selects the message type based on the incoming data.
    /// </summary>
    internal class StandardMessageFactory : IMessageFactory {
        /// <summary>
        /// The request message types and their constructors to use for instantiating the messages.
        /// </summary>
        private readonly Dictionary<MessageDescription, ConstructorInfo> requestMessageTypes = new Dictionary<MessageDescription, ConstructorInfo>();

        /// <summary>
        /// The response message types and their constructors to use for instantiating the messages.
        /// </summary>
        /// <remarks>
        /// The value is a dictionary, whose key is the type of the constructor's lone parameter.
        /// </remarks>
        private readonly Dictionary<MessageDescription, Dictionary<Type, ConstructorInfo>> responseMessageTypes = new Dictionary<MessageDescription, Dictionary<Type, ConstructorInfo>>();

        /// <summary>
        /// Initializes a new instance of the <see cref="StandardMessageFactory"/> class.
        /// </summary>
        internal StandardMessageFactory() {
        }

        /// <summary>
        /// Adds message types to the set that this factory can create.
        /// </summary>
        /// <param name="messageTypes">The message types that this factory may instantiate. No element may be null.</param>
        /// <remarks>
        /// A type is registered as a request when it implements <see cref="IDirectedProtocolMessage"/>
        /// and has a constructor taking exactly (<see cref="Uri"/>, <see cref="Version"/>).
        /// A type is registered as a response when it implements <see cref="IDirectResponseProtocolMessage"/>
        /// and has at least one single-parameter constructor whose parameter type is assignable to
        /// <see cref="IDirectedProtocolMessage"/>.  A type may be registered as both.
        /// Types that fit neither pattern are collected and reported together through
        /// <c>ErrorUtilities.VerifySupported</c> after the whole sequence has been processed, so
        /// supported types from the same call are already registered when that check runs.
        /// </remarks>
        public virtual void AddMessageTypes(IEnumerable<MessageDescription> messageTypes) {
            Requires.NotNull(messageTypes, "messageTypes");
            Requires.True(messageTypes.All(msg => msg != null), "messageTypes");

            var unsupportedMessageTypes = new List<MessageDescription>(0);
            foreach (MessageDescription messageDescription in messageTypes) {
                bool supportedMessageType = false;

                // First see whether this message fits the recognized pattern for request messages.
                if (typeof(IDirectedProtocolMessage).IsAssignableFrom(messageDescription.MessageType)) {
                    foreach (ConstructorInfo ctor in messageDescription.Constructors) {
                        ParameterInfo[] parameters = ctor.GetParameters();
                        if (parameters.Length == 2 && parameters[0].ParameterType == typeof(Uri) && parameters[1].ParameterType == typeof(Version)) {
                            supportedMessageType = true;
                            this.requestMessageTypes.Add(messageDescription, ctor);
                            break;
                        }
                    }
                }

                // Also see if this message fits the recognized pattern for response messages.
                if (typeof(IDirectResponseProtocolMessage).IsAssignableFrom(messageDescription.MessageType)) {
                    var responseCtors = new Dictionary<Type, ConstructorInfo>(messageDescription.Constructors.Length);
                    foreach (ConstructorInfo ctor in messageDescription.Constructors) {
                        ParameterInfo[] parameters = ctor.GetParameters();
                        if (parameters.Length == 1 && typeof(IDirectedProtocolMessage).IsAssignableFrom(parameters[0].ParameterType)) {
                            responseCtors.Add(parameters[0].ParameterType, ctor);
                        }
                    }

                    if (responseCtors.Count > 0) {
                        supportedMessageType = true;
                        this.responseMessageTypes.Add(messageDescription, responseCtors);
                    }
                }

                if (!supportedMessageType) {
                    unsupportedMessageTypes.Add(messageDescription);
                }
            }

            ErrorUtilities.VerifySupported(
                !unsupportedMessageTypes.Any(),
                MessagingStrings.StandardMessageFactoryUnsupportedMessageType,
                unsupportedMessageTypes.ToStringDeferred());
        }

        #region IMessageFactory Members

        /// <summary>
        /// Analyzes an incoming request message payload to discover what kind of
        /// message is embedded in it and returns a new instance of that type,
        /// or null if no match is found.
        /// </summary>
        /// <param name="recipient">The intended or actual recipient of the request message.</param>
        /// <param name="fields">The name/value pairs that make up the message payload.</param>
        /// <returns>
        /// A newly instantiated <see cref="IDirectedProtocolMessage"/> that this message can
        /// deserialize to.  Null if the request isn't recognized as a valid protocol message.
        /// </returns>
        public virtual IDirectedProtocolMessage GetNewRequestMessage(MessageReceivingEndpoint recipient, IDictionary<string, string> fields) {
            MessageDescription matchingType = this.GetMessageDescription(recipient, fields);
            return matchingType != null ? this.InstantiateAsRequest(matchingType, recipient) : null;
        }

        /// <summary>
        /// Analyzes an incoming direct response payload to discover what kind of
        /// message is embedded in it and returns a new instance of that type,
        /// or null if no match is found.
        /// </summary>
        /// <param name="request">The message that was sent as a request that resulted in the response.</param>
        /// <param name="fields">The name/value pairs that make up the message payload.</param>
        /// <returns>
        /// A newly instantiated <see cref="IDirectResponseProtocolMessage"/> that this message can
        /// deserialize to.  Null if the response isn't recognized as a valid protocol message.
        /// </returns>
        public virtual IDirectResponseProtocolMessage GetNewResponseMessage(IDirectedProtocolMessage request, IDictionary<string, string> fields) {
            MessageDescription matchingType = this.GetMessageDescription(request, fields);
            return matchingType != null ? this.InstantiateAsResponse(matchingType, request) : null;
        }

        #endregion

        /// <summary>
        /// Gets the message type that best fits the given incoming request data.
        /// </summary>
        /// <param name="recipient">The recipient of the incoming data.  Not used by this implementation beyond the null check.</param>
        /// <param name="fields">The data of the incoming message.</param>
        /// <returns>
        /// The message type that matches the incoming data; or null if no match.
        /// </returns>
        /// <remarks>
        /// Candidates are the registered request types whose parts pass basic validation against
        /// <paramref name="fields"/>.  They are ranked by the number of part names found in
        /// <paramref name="fields"/>, then by total number of parts, both descending.
        /// When more than one candidate fits, the best-ranked one is returned and a warning is logged;
        /// ambiguity is not treated as an error here.
        /// </remarks>
        protected virtual MessageDescription GetMessageDescription(MessageReceivingEndpoint recipient, IDictionary<string, string> fields) {
            Requires.NotNull(recipient, "recipient");
            Requires.NotNull(fields, "fields");

            var matches = this.requestMessageTypes.Keys
                .Where(message => message.CheckMessagePartsPassBasicValidation(fields))
                .OrderByDescending(message => CountInCommon(message.Mapping.Keys, fields.Keys))
                .ThenByDescending(message => message.Mapping.Count)
                .CacheGeneratedResults();
            var match = matches.FirstOrDefault();
            if (match == null) {
                // No message type matches the incoming data.
                return null;
            }

            if (Logger.Messaging.IsWarnEnabled && matches.Count() > 1) {
                Logger.Messaging.WarnFormat(
                    "Multiple message types seemed to fit the incoming data: {0}",
                    matches.ToStringDeferred());
            }

            return match;
        }

        /// <summary>
        /// Gets the message type that best fits the given incoming direct response data.
        /// </summary>
        /// <param name="request">The request message that prompted the response data.</param>
        /// <param name="fields">The data of the incoming message.</param>
        /// <returns>
        /// The message type that matches the incoming data; or null if no match.
        /// </returns>
        /// <remarks>
        /// Candidates are the registered response types whose parts pass basic validation against
        /// <paramref name="fields"/> and that have a constructor accepting the runtime type of
        /// <paramref name="request"/>.  They are ranked first by how closely the constructor's
        /// parameter type matches the request type (closest first), then by the number of part
        /// names found in <paramref name="fields"/>, then by total number of parts (both descending).
        /// When more than one candidate fits, the best-ranked one is returned and a warning is logged;
        /// ambiguity is not treated as an error here.
        /// </remarks>
        protected virtual MessageDescription GetMessageDescription(IDirectedProtocolMessage request, IDictionary<string, string> fields) {
            Requires.NotNull(request, "request");
            Requires.NotNull(fields, "fields");

            Type requestType = request.GetType();
            var matches = (from responseMessageType in this.responseMessageTypes
                           let message = responseMessageType.Key
                           where message.CheckMessagePartsPassBasicValidation(fields)
                           let ctors = this.FindMatchingResponseConstructors(message, requestType)
                           where ctors.Any()
                           orderby GetDerivationDistance(ctors.First().GetParameters()[0].ParameterType, requestType),
                             CountInCommon(message.Mapping.Keys, fields.Keys) descending,
                             message.Mapping.Count descending
                           select message).CacheGeneratedResults();
            var match = matches.FirstOrDefault();
            if (match == null) {
                // No message type matches the incoming data.
                return null;
            }

            if (Logger.Messaging.IsWarnEnabled && matches.Count() > 1) {
                Logger.Messaging.WarnFormat(
                    "Multiple message types seemed to fit the incoming data: {0}",
                    matches.ToStringDeferred());
            }

            return match;
        }

        /// <summary>
        /// Instantiates the given request message type.
        /// </summary>
        /// <param name="messageDescription">The message description.  Must have been registered as a request type.</param>
        /// <param name="recipient">The recipient, whose location is passed to the message constructor.</param>
        /// <returns>The instantiated message.  Never null.</returns>
        protected virtual IDirectedProtocolMessage InstantiateAsRequest(MessageDescription messageDescription, MessageReceivingEndpoint recipient) {
            Requires.NotNull(messageDescription, "messageDescription");
            Requires.NotNull(recipient, "recipient");
            Contract.Ensures(Contract.Result<IDirectedProtocolMessage>() != null);

            // The registered constructor has the (Uri, Version) shape checked in AddMessageTypes.
            ConstructorInfo ctor = this.requestMessageTypes[messageDescription];
            return (IDirectedProtocolMessage)ctor.Invoke(new object[] { recipient.Location, messageDescription.MessageVersion });
        }

        /// <summary>
        /// Instantiates the given response message type.
        /// </summary>
        /// <param name="messageDescription">The message description.  Must have been registered as a response type.</param>
        /// <param name="request">The request that resulted in this response.</param>
        /// <returns>The instantiated message.  Never null.</returns>
        /// <remarks>
        /// Exactly one registered constructor must accept the runtime type of <paramref name="request"/>;
        /// zero or several matches are reported as internal errors.
        /// </remarks>
        protected virtual IDirectResponseProtocolMessage InstantiateAsResponse(MessageDescription messageDescription, IDirectedProtocolMessage request) {
            Requires.NotNull(messageDescription, "messageDescription");
            Requires.NotNull(request, "request");
            Contract.Ensures(Contract.Result<IDirectResponseProtocolMessage>() != null);

            Type requestType = request.GetType();

            // Materialize once so the deferred query is not re-enumerated, and decide by count
            // rather than by catching the InvalidOperationException that Single() would throw.
            List<ConstructorInfo> ctors = this.FindMatchingResponseConstructors(messageDescription, requestType).ToList();
            if (ctors.Count > 1) {
                ErrorUtilities.ThrowInternal("More than one matching constructor for request type " + requestType.Name + " and response type " + messageDescription.MessageType.Name);
            } else if (ctors.Count == 0) {
                ErrorUtilities.ThrowInternal("Unexpected request message type " + requestType.FullName + " for response type " + messageDescription.MessageType.Name);
            }

            // NOTE(review): relies on ErrorUtilities.ThrowInternal never returning normally,
            // as the original code did -- confirm against its definition.
            return (IDirectResponseProtocolMessage)ctors[0].Invoke(new object[] { request });
        }

        /// <summary>
        /// Gets the hierarchical distance between a type and a type it derives from or implements.
        /// </summary>
        /// <param name="assignableType">The base type or interface.</param>
        /// <param name="derivedType">The concrete class that implements the <paramref name="assignableType"/>.</param>
        /// <returns>
        /// The distance between the two types.  0 if the types are equivalent, 1 if the type
        /// immediately derives from or implements the base type, or progressively higher integers.
        /// </returns>
        private static int GetDerivationDistance(Type assignableType, Type derivedType) {
            Requires.NotNull(assignableType, "assignableType");
            Requires.NotNull(derivedType, "derivedType");
            Requires.True(assignableType.IsAssignableFrom(derivedType), "assignableType");

            // If the two types are equivalent (each is assignable from the other)...
            if (derivedType.IsAssignableFrom(assignableType)) {
                return 0;
            }

            // Walk up the base class chain, counting steps for as long as the ancestor is still
            // assignable to assignableType.  The walk also ends when BaseType runs out, because
            // Type.IsAssignableFrom returns false for a null argument.
            int steps;
            derivedType = derivedType.BaseType;
            for (steps = 1; assignableType.IsAssignableFrom(derivedType); steps++) {
                derivedType = derivedType.BaseType;
            }

            return steps;
        }

        /// <summary>
        /// Counts how many strings of the first collection also appear in the second collection.
        /// </summary>
        /// <param name="collection1">The first collection.</param>
        /// <param name="collection2">The second collection.</param>
        /// <param name="comparison">The string comparison method to use.</param>
        /// <returns>A non-negative integer no greater than the count of elements in the smallest collection.</returns>
        /// <remarks>
        /// Each element of <paramref name="collection1"/> is counted once if any element of
        /// <paramref name="collection2"/> equals it, so the stated upper bound assumes the elements
        /// of each collection are distinct under <paramref name="comparison"/> (as dictionary key
        /// collections compared ordinally are).
        /// </remarks>
        private static int CountInCommon(ICollection<string> collection1, ICollection<string> collection2, StringComparison comparison = StringComparison.Ordinal) {
            Requires.NotNull(collection1, "collection1");
            Requires.NotNull(collection2, "collection2");
            Contract.Ensures(Contract.Result<int>() >= 0 && Contract.Result<int>() <= Math.Min(collection1.Count, collection2.Count));

            return collection1.Count(value1 => collection2.Any(value2 => string.Equals(value1, value2, comparison)));
        }

        /// <summary>
        /// Finds constructors for response messages that take a given request message type.
        /// </summary>
        /// <param name="messageDescription">The message description.  Must have been registered as a response type.</param>
        /// <param name="requestType">Type of the request message.</param>
        /// <returns>A lazily evaluated sequence of the constructors whose parameter type is assignable from <paramref name="requestType"/>.</returns>
        private IEnumerable<ConstructorInfo> FindMatchingResponseConstructors(MessageDescription messageDescription, Type requestType) {
            Requires.NotNull(messageDescription, "messageDescription");
            Requires.NotNull(requestType, "requestType");

            return this.responseMessageTypes[messageDescription].Where(pair => pair.Key.IsAssignableFrom(requestType)).Select(pair => pair.Value);
        }
    }
}