//-----------------------------------------------------------------------
//
// Copyright (c) Andrew Arnott. All rights reserved.
//
//-----------------------------------------------------------------------
namespace DotNetOpenAuth.Messaging {
using System;
using System.Collections.Generic;
using System.Diagnostics.Contracts;
using System.Linq;
using System.Reflection;
using System.Text;
using DotNetOpenAuth.Messaging.Reflection;
/// <summary>
/// A message factory that automatically selects the message type based on the incoming data.
/// </summary>
internal class StandardMessageFactory : IMessageFactory {
	/// <summary>
	/// The request message types and the constructor to use for instantiating each message.
	/// </summary>
	/// <remarks>
	/// Each constructor takes exactly two parameters: a <see cref="Uri"/> followed by a <see cref="Version"/>.
	/// </remarks>
	private readonly Dictionary<MessageDescription, ConstructorInfo> requestMessageTypes = new Dictionary<MessageDescription, ConstructorInfo>();

	/// <summary>
	/// The response message types and their constructors to use for instantiating the messages.
	/// </summary>
	/// <remarks>
	/// The value is a dictionary, whose key is the type of the constructor's lone parameter
	/// (a type assignable to <see cref="IDirectedProtocolMessage"/>).
	/// </remarks>
	private readonly Dictionary<MessageDescription, Dictionary<Type, ConstructorInfo>> responseMessageTypes = new Dictionary<MessageDescription, Dictionary<Type, ConstructorInfo>>();

	/// <summary>
	/// Initializes a new instance of the <see cref="StandardMessageFactory"/> class.
	/// </summary>
	internal StandardMessageFactory() {
	}

	/// <summary>
	/// Adds message types to the set that this factory can create.
	/// </summary>
	/// <param name="messageTypes">The message types that this factory may instantiate.</param>
	/// <remarks>
	/// A type is registered as a request message when it implements <see cref="IDirectedProtocolMessage"/>
	/// and has a (<see cref="Uri"/>, <see cref="Version"/>) constructor. It is registered as a response
	/// message when it implements <see cref="IDirectResponseProtocolMessage"/> and has at least one
	/// single-parameter constructor whose parameter type is assignable to <see cref="IDirectedProtocolMessage"/>.
	/// A type may qualify as both. After every type has been examined, the call fails through
	/// <c>ErrorUtilities.VerifySupported</c> if any type qualified as neither; supported types from the
	/// same call have already been registered at that point.
	/// </remarks>
	public virtual void AddMessageTypes(IEnumerable<MessageDescription> messageTypes) {
		Requires.NotNull(messageTypes, "messageTypes");
		Requires.True(messageTypes.All(msg => msg != null), "messageTypes");

		var unsupportedMessageTypes = new List<MessageDescription>(0);
		foreach (MessageDescription messageDescription in messageTypes) {
			bool supportedMessageType = false;

			// First see whether this message fits the recognized pattern for request messages.
			if (typeof(IDirectedProtocolMessage).IsAssignableFrom(messageDescription.MessageType)) {
				foreach (ConstructorInfo ctor in messageDescription.Constructors) {
					ParameterInfo[] parameters = ctor.GetParameters();
					if (parameters.Length == 2 && parameters[0].ParameterType == typeof(Uri) && parameters[1].ParameterType == typeof(Version)) {
						supportedMessageType = true;
						this.requestMessageTypes.Add(messageDescription, ctor);
						break;
					}
				}
			}

			// Also see if this message fits the recognized pattern for response messages.
			if (typeof(IDirectResponseProtocolMessage).IsAssignableFrom(messageDescription.MessageType)) {
				var responseCtors = new Dictionary<Type, ConstructorInfo>(messageDescription.Constructors.Length);
				foreach (ConstructorInfo ctor in messageDescription.Constructors) {
					ParameterInfo[] parameters = ctor.GetParameters();
					if (parameters.Length == 1 && typeof(IDirectedProtocolMessage).IsAssignableFrom(parameters[0].ParameterType)) {
						responseCtors.Add(parameters[0].ParameterType, ctor);
					}
				}

				if (responseCtors.Count > 0) {
					supportedMessageType = true;
					this.responseMessageTypes.Add(messageDescription, responseCtors);
				}
			}

			if (!supportedMessageType) {
				unsupportedMessageTypes.Add(messageDescription);
			}
		}

		ErrorUtilities.VerifySupported(
			!unsupportedMessageTypes.Any(),
			MessagingStrings.StandardMessageFactoryUnsupportedMessageType,
			unsupportedMessageTypes.ToStringDeferred());
	}

	#region IMessageFactory Members

	/// <summary>
	/// Analyzes an incoming request message payload to discover what kind of
	/// message is embedded in it and returns the type, or null if no match is found.
	/// </summary>
	/// <param name="recipient">The intended or actual recipient of the request message.</param>
	/// <param name="fields">The name/value pairs that make up the message payload.</param>
	/// <returns>
	/// A newly instantiated <see cref="IDirectedProtocolMessage"/>-derived object that this message can
	/// deserialize to. Null if the request isn't recognized as a valid protocol message.
	/// </returns>
	public virtual IDirectedProtocolMessage GetNewRequestMessage(MessageReceivingEndpoint recipient, IDictionary<string, string> fields) {
		MessageDescription matchingType = this.GetMessageDescription(recipient, fields);
		if (matchingType != null) {
			return this.InstantiateAsRequest(matchingType, recipient);
		} else {
			return null;
		}
	}

	/// <summary>
	/// Analyzes an incoming direct response message payload to discover what kind of
	/// message is embedded in it and returns the type, or null if no match is found.
	/// </summary>
	/// <param name="request">The message that was sent as a request that resulted in the response.</param>
	/// <param name="fields">The name/value pairs that make up the message payload.</param>
	/// <returns>
	/// A newly instantiated <see cref="IDirectResponseProtocolMessage"/>-derived object that this message can
	/// deserialize to. Null if the response isn't recognized as a valid protocol message.
	/// </returns>
	public virtual IDirectResponseProtocolMessage GetNewResponseMessage(IDirectedProtocolMessage request, IDictionary<string, string> fields) {
		MessageDescription matchingType = this.GetMessageDescription(request, fields);
		if (matchingType != null) {
			return this.InstantiateAsResponse(matchingType, request);
		} else {
			return null;
		}
	}

	#endregion

	/// <summary>
	/// Gets the message type that best fits the given incoming request data.
	/// </summary>
	/// <param name="recipient">The recipient of the incoming data. Typically not used, but included just in case.</param>
	/// <param name="fields">The data of the incoming message.</param>
	/// <returns>
	/// The message type that matches the incoming data; or null if no match.
	/// </returns>
	/// <remarks>
	/// Candidates must pass basic validation against <paramref name="fields"/>. They are ranked by the
	/// number of field names they share with the payload, then by how many parts they map, both descending.
	/// When several candidates fit, a warning is logged and the top-ranked one is returned; no exception
	/// is thrown for ambiguity.
	/// </remarks>
	protected virtual MessageDescription GetMessageDescription(MessageReceivingEndpoint recipient, IDictionary<string, string> fields) {
		Requires.NotNull(recipient, "recipient");
		Requires.NotNull(fields, "fields");

		var matches = this.requestMessageTypes.Keys
			.Where(message => message.CheckMessagePartsPassBasicValidation(fields))
			.OrderByDescending(message => CountInCommon(message.Mapping.Keys, fields.Keys))
			.ThenByDescending(message => message.Mapping.Count)
			.CacheGeneratedResults();

		return SelectBestMatch(matches);
	}

	/// <summary>
	/// Gets the message type that best fits the given incoming direct response data.
	/// </summary>
	/// <param name="request">The request message that prompted the response data.</param>
	/// <param name="fields">The data of the incoming message.</param>
	/// <returns>
	/// The message type that matches the incoming data; or null if no match.
	/// </returns>
	/// <remarks>
	/// Candidates must pass basic validation against <paramref name="fields"/> and have at least one
	/// constructor accepting the runtime type of <paramref name="request"/>. They are ranked by how closely
	/// their best-fitting constructor parameter type matches the request type (closest first), then by
	/// the number of field names shared with the payload, then by how many parts they map. When several
	/// candidates fit, a warning is logged and the top-ranked one is returned; no exception is thrown for ambiguity.
	/// </remarks>
	protected virtual MessageDescription GetMessageDescription(IDirectedProtocolMessage request, IDictionary<string, string> fields) {
		Requires.NotNull(request, "request");
		Requires.NotNull(fields, "fields");

		Type requestType = request.GetType();
		var matches = (from message in this.responseMessageTypes.Keys
					   where message.CheckMessagePartsPassBasicValidation(fields)
					   let ctors = this.FindMatchingResponseConstructors(message, requestType).ToList()
					   where ctors.Count > 0
					   // Rank on the closest matching constructor rather than an arbitrary one.
					   orderby ctors.Min(ctor => GetDerivationDistance(ctor.GetParameters()[0].ParameterType, requestType)),
						   CountInCommon(message.Mapping.Keys, fields.Keys) descending,
						   message.Mapping.Count descending
					   select message).CacheGeneratedResults();

		return SelectBestMatch(matches);
	}

	/// <summary>
	/// Instantiates the given request message type.
	/// </summary>
	/// <param name="messageDescription">The message description.</param>
	/// <param name="recipient">The recipient.</param>
	/// <returns>The instantiated message. Never null.</returns>
	protected virtual IDirectedProtocolMessage InstantiateAsRequest(MessageDescription messageDescription, MessageReceivingEndpoint recipient) {
		Requires.NotNull(messageDescription, "messageDescription");
		Requires.NotNull(recipient, "recipient");
		Contract.Ensures(Contract.Result<IDirectedProtocolMessage>() != null);

		// The stored constructor has the (Uri, Version) signature verified in AddMessageTypes.
		ConstructorInfo ctor = this.requestMessageTypes[messageDescription];
		return (IDirectedProtocolMessage)ctor.Invoke(new object[] { recipient.Location, messageDescription.MessageVersion });
	}

	/// <summary>
	/// Instantiates the given response message type.
	/// </summary>
	/// <param name="messageDescription">The message description.</param>
	/// <param name="request">The request that resulted in this response.</param>
	/// <returns>The instantiated message. Never null.</returns>
	/// <remarks>
	/// Exactly one constructor of the response type must accept the runtime type of <paramref name="request"/>;
	/// zero or multiple matches are reported through <c>ErrorUtilities.ThrowInternal</c>.
	/// </remarks>
	protected virtual IDirectResponseProtocolMessage InstantiateAsResponse(MessageDescription messageDescription, IDirectedProtocolMessage request) {
		Requires.NotNull(messageDescription, "messageDescription");
		Requires.NotNull(request, "request");
		Contract.Ensures(Contract.Result<IDirectResponseProtocolMessage>() != null);

		Type requestType = request.GetType();

		// Materialize once so the count checks and the invocation all see the same candidates,
		// instead of catching the InvalidOperationException thrown by Single().
		List<ConstructorInfo> ctors = this.FindMatchingResponseConstructors(messageDescription, requestType).ToList();
		if (ctors.Count > 1) {
			ErrorUtilities.ThrowInternal("More than one matching constructor for request type " + requestType.Name + " and response type " + messageDescription.MessageType.Name);
		} else if (ctors.Count == 0) {
			ErrorUtilities.ThrowInternal("Unexpected request message type " + requestType.FullName + " for response type " + messageDescription.MessageType.Name);
		}

		return (IDirectResponseProtocolMessage)ctors[0].Invoke(new object[] { request });
	}

	/// <summary>
	/// Returns the first element of an already-ranked candidate sequence, logging a warning
	/// when more than one candidate was found.
	/// </summary>
	/// <param name="rankedMatches">Candidates ordered best-first.</param>
	/// <returns>The top-ranked candidate; or null when the sequence is empty.</returns>
	private static MessageDescription SelectBestMatch(IEnumerable<MessageDescription> rankedMatches) {
		MessageDescription match = rankedMatches.FirstOrDefault();
		if (match == null) {
			// No message type matches the incoming data.
			return null;
		}

		if (Logger.Messaging.IsWarnEnabled && rankedMatches.Count() > 1) {
			Logger.Messaging.WarnFormat(
				"Multiple message types seemed to fit the incoming data: {0}",
				rankedMatches.ToStringDeferred());
		}

		return match;
	}

	/// <summary>
	/// Gets the hierarchical distance between a type and a type it derives from or implements.
	/// </summary>
	/// <param name="assignableType">The base type or interface.</param>
	/// <param name="derivedType">The concrete class that implements the <paramref name="assignableType"/>.</param>
	/// <returns>The distance between the two types. 0 if the types are equivalent, 1 if the type immediately derives from or implements the base type, or progressively higher integers.</returns>
	private static int GetDerivationDistance(Type assignableType, Type derivedType) {
		Requires.NotNull(assignableType, "assignableType");
		Requires.NotNull(derivedType, "derivedType");
		Requires.True(assignableType.IsAssignableFrom(derivedType), "assignableType");

		// Given the precondition, mutual assignability means the two types are equivalent.
		if (derivedType.IsAssignableFrom(assignableType)) {
			return 0;
		}

		// derivedType itself is one step away. Each further ancestor that still satisfies
		// assignableType, without being assignableType itself, adds another step.
		// (Previously a base class was counted as an extra step, so a direct base scored 2.)
		int steps = 1;
		for (Type ancestor = derivedType.BaseType;
			ancestor != null && ancestor != assignableType && assignableType.IsAssignableFrom(ancestor);
			ancestor = ancestor.BaseType) {
			steps++;
		}

		return steps;
	}

	/// <summary>
	/// Counts how many strings are in the intersection of two collections.
	/// </summary>
	/// <param name="collection1">The first collection.</param>
	/// <param name="collection2">The second collection.</param>
	/// <param name="comparison">The string comparison method to use.</param>
	/// <returns>A non-negative integer no greater than the count of elements in the smallest collection.</returns>
	private static int CountInCommon(ICollection<string> collection1, ICollection<string> collection2, StringComparison comparison = StringComparison.Ordinal) {
		Requires.NotNull(collection1, "collection1");
		Requires.NotNull(collection2, "collection2");
		Contract.Ensures(Contract.Result<int>() >= 0 && Contract.Result<int>() <= Math.Min(collection1.Count, collection2.Count));

		return collection1.Count(value1 => collection2.Any(value2 => string.Equals(value1, value2, comparison)));
	}

	/// <summary>
	/// Finds constructors for response messages that take a given request message type.
	/// </summary>
	/// <param name="messageDescription">The message description, which must have been registered as a response type.</param>
	/// <param name="requestType">Type of the request message.</param>
	/// <returns>A lazily evaluated sequence of constructors whose lone parameter accepts <paramref name="requestType"/>.</returns>
	private IEnumerable<ConstructorInfo> FindMatchingResponseConstructors(MessageDescription messageDescription, Type requestType) {
		Requires.NotNull(messageDescription, "messageDescription");
		Requires.NotNull(requestType, "requestType");

		return this.responseMessageTypes[messageDescription].Where(pair => pair.Key.IsAssignableFrom(requestType)).Select(pair => pair.Value);
	}
}
}