diff options
author | Andrew Arnott <andrewarnott@gmail.com> | 2010-02-22 08:39:49 -0800 |
---|---|---|
committer | Andrew Arnott <andrewarnott@gmail.com> | 2010-02-22 08:39:49 -0800 |
commit | 1b3fac76d3f63830d2a5fbe07b90fbcfb2bb8b8b (patch) | |
tree | 1615bc256cbe5c9536ca4492c8296072d57b70c6 /src | |
parent | 79ce3eee130d5bc58e44b78acd9b4b9693b07ecc (diff) | |
parent | ba1511f42007de1b6439d32a563b0f4a58dcbb53 (diff) | |
download | DotNetOpenAuth-1b3fac76d3f63830d2a5fbe07b90fbcfb2bb8b8b.zip DotNetOpenAuth-1b3fac76d3f63830d2a5fbe07b90fbcfb2bb8b8b.tar.gz DotNetOpenAuth-1b3fac76d3f63830d2a5fbe07b90fbcfb2bb8b8b.tar.bz2 |
Merge branch 'standardmessagefactory' into oauthWRAP
Conflicts:
src/DotNetOpenAuth/DotNetOpenAuth.csproj
Diffstat (limited to 'src')
9 files changed, 597 insertions, 65 deletions
diff --git a/src/DotNetOpenAuth.Test/DotNetOpenAuth.Test.csproj b/src/DotNetOpenAuth.Test/DotNetOpenAuth.Test.csproj index 6540714..a0c849a 100644 --- a/src/DotNetOpenAuth.Test/DotNetOpenAuth.Test.csproj +++ b/src/DotNetOpenAuth.Test/DotNetOpenAuth.Test.csproj @@ -196,6 +196,7 @@ <Compile Include="Messaging\Bindings\StandardExpirationBindingElementTests.cs" /> <Compile Include="Messaging\Reflection\MessagePartTests.cs" /> <Compile Include="Messaging\Reflection\ValueMappingTests.cs" /> + <Compile Include="Messaging\StandardMessageFactoryTests.cs" /> <Compile Include="Mocks\AssociateUnencryptedRequestNoSslCheck.cs" /> <Compile Include="Mocks\CoordinatingChannel.cs" /> <Compile Include="Mocks\CoordinatingHttpRequestInfo.cs" /> diff --git a/src/DotNetOpenAuth.Test/Messaging/Reflection/MessageDescriptionTests.cs b/src/DotNetOpenAuth.Test/Messaging/Reflection/MessageDescriptionTests.cs index e57df65..92f39cc 100644 --- a/src/DotNetOpenAuth.Test/Messaging/Reflection/MessageDescriptionTests.cs +++ b/src/DotNetOpenAuth.Test/Messaging/Reflection/MessageDescriptionTests.cs @@ -73,6 +73,16 @@ namespace DotNetOpenAuth.Test.Messaging.Reflection { Assert.IsTrue(v30.Mapping["OptionalIn10RequiredIn25AndLater"].IsRequired); } + /// <summary> + /// Verifies that the constructors cache is properly initialized. + /// </summary> + [TestCase] + public void CtorsCache() { + var message = new MessageDescription(typeof(MultiVersionMessage), new Version(1, 0)); + Assert.IsNotNull(message.Constructors); + Assert.AreEqual(1, message.Constructors.Length); + } + private class MultiVersionMessage : Mocks.TestBaseMessage { #pragma warning disable 0649 // these fields are never written to, but part of the test [MessagePart] diff --git a/src/DotNetOpenAuth.Test/Messaging/StandardMessageFactoryTests.cs b/src/DotNetOpenAuth.Test/Messaging/StandardMessageFactoryTests.cs new file mode 100644 index 0000000..2b0b4e7 --- /dev/null +++ b/src/DotNetOpenAuth.Test/Messaging/StandardMessageFactoryTests.cs @@ -0,0 +1,178 @@ +//----------------------------------------------------------------------- +// <copyright file="StandardMessageFactoryTests.cs" company="Andrew Arnott"> +// Copyright (c) Andrew Arnott. All rights reserved. +// </copyright> +//----------------------------------------------------------------------- + +namespace DotNetOpenAuth.Test.Messaging { + using System; + using System.Collections.Generic; + using System.Linq; + using System.Text; + using DotNetOpenAuth.Messaging; + using DotNetOpenAuth.Messaging.Reflection; + using DotNetOpenAuth.Test.Mocks; + using NUnit.Framework; + + [TestFixture] + public class StandardMessageFactoryTests : MessagingTestBase { + private static readonly Version V1 = new Version(1, 0); + private static readonly MessageReceivingEndpoint receiver = new MessageReceivingEndpoint("http://receiver", HttpDeliveryMethods.PostRequest); + + private StandardMessageFactory factory; + + public override void SetUp() { + base.SetUp(); + + this.factory = new StandardMessageFactory(); + } + + /// <summary> + /// Verifies that AddMessageTypes throws the appropriate exception on null input. + /// </summary> + [TestCase, ExpectedException(typeof(ArgumentNullException))] + public void AddMessageTypesNull() { + this.factory.AddMessageTypes(null); + } + + /// <summary> + /// Verifies that AddMessageTypes throws the appropriate exception on null input. + /// </summary> + [TestCase, ExpectedException(typeof(ArgumentException))] + public void AddMessageTypesNullMessageDescription() { + this.factory.AddMessageTypes(new MessageDescription[] { null }); + } + + /// <summary> + /// Verifies very simple recognition of a single message type + /// </summary> + [TestCase] + public void SingleRequestMessageType() { + this.factory.AddMessageTypes(new MessageDescription[] { MessageDescriptions.Get(typeof(RequestMessageMock), V1) }); + var fields = new Dictionary<string, string> { + { "random", "bits" }, + }; + Assert.IsNull(this.factory.GetNewRequestMessage(receiver, fields)); + fields["Age"] = "18"; + Assert.IsInstanceOf(typeof(RequestMessageMock), this.factory.GetNewRequestMessage(receiver, fields)); + } + + /// <summary> + /// Verifies very simple recognition of a single message type + /// </summary> + [TestCase] + public void SingleResponseMessageType() { + this.factory.AddMessageTypes(new MessageDescription[] { MessageDescriptions.Get(typeof(DirectResponseMessageMock), V1) }); + var fields = new Dictionary<string, string> { + { "random", "bits" }, + }; + IDirectedProtocolMessage request = new RequestMessageMock(receiver.Location, V1); + Assert.IsNull(this.factory.GetNewResponseMessage(request, fields)); + fields["Age"] = "18"; + IDirectResponseProtocolMessage response = this.factory.GetNewResponseMessage(request, fields); + Assert.IsInstanceOf<DirectResponseMessageMock>(response); + Assert.AreSame(request, response.OriginatingRequest); + + // Verify that we can instantiate a response with a derived-type of an expected request message. + request = new TestSignedDirectedMessage(); + response = this.factory.GetNewResponseMessage(request, fields); + Assert.IsInstanceOf<DirectResponseMessageMock>(response); + Assert.AreSame(request, response.OriginatingRequest); + } + + private class DirectResponseMessageMock : IDirectResponseProtocolMessage { + internal DirectResponseMessageMock(RequestMessageMock request) { + this.OriginatingRequest = request; + } + + internal DirectResponseMessageMock(TestDirectedMessage request) { + this.OriginatingRequest = request; + } + + [MessagePart(IsRequired = true)] + public int Age { get; set; } + + #region IDirectResponseProtocolMessage Members + + public IDirectedProtocolMessage OriginatingRequest { get; private set; } + + #endregion + + #region IProtocolMessage Members + + public MessageProtections RequiredProtection { + get { throw new NotImplementedException(); } + } + + public MessageTransport Transport { + get { throw new NotImplementedException(); } + } + + #endregion + + #region IMessage Members + + public Version Version { + get { throw new NotImplementedException(); } + } + + public System.Collections.Generic.IDictionary<string, string> ExtraData { + get { throw new NotImplementedException(); } + } + + public void EnsureValidMessage() { + throw new NotImplementedException(); + } + + #endregion + } + + private class RequestMessageMock : IDirectedProtocolMessage { + internal RequestMessageMock(Uri recipient, Version version) { + } + + [MessagePart(IsRequired = true)] + public int Age { get; set; } + + #region IDirectedProtocolMessage Members + + public HttpDeliveryMethods HttpMethods { + get { throw new NotImplementedException(); } + } + + public Uri Recipient { + get { throw new NotImplementedException(); } + } + + #endregion + + #region IProtocolMessage Members + + public MessageProtections RequiredProtection { + get { throw new NotImplementedException(); } + } + + public MessageTransport Transport { + get { throw new NotImplementedException(); } + } + + #endregion + + #region IMessage Members + + public Version Version { + get { throw new NotImplementedException(); } + } + + public System.Collections.Generic.IDictionary<string, string> ExtraData { + get { throw new NotImplementedException(); } + } + + public void EnsureValidMessage() { + throw new NotImplementedException(); + } + + #endregion + } + } +} diff --git a/src/DotNetOpenAuth/DotNetOpenAuth.csproj b/src/DotNetOpenAuth/DotNetOpenAuth.csproj index 7f5b298..669863c 100644 --- a/src/DotNetOpenAuth/DotNetOpenAuth.csproj +++ b/src/DotNetOpenAuth/DotNetOpenAuth.csproj @@ -303,6 +303,7 @@ http://opensource.org/licenses/ms-pl.html <Compile Include="Messaging\Reflection\IMessagePartEncoder.cs" /> <Compile Include="Messaging\Reflection\IMessagePartNullEncoder.cs" /> <Compile Include="Messaging\Reflection\MessageDescriptionCollection.cs" /> + <Compile Include="Messaging\StandardMessageFactory.cs" /> <Compile Include="OAuthWrap\IClientTokenManager.cs" /> <Compile Include="OAuthWrap\Messages\Assertion\AssertionRequest.cs" /> <Compile Include="OAuthWrap\Messages\Assertion\AssertionFailedResponse.cs" /> diff --git a/src/DotNetOpenAuth/Messaging/MessagingStrings.Designer.cs b/src/DotNetOpenAuth/Messaging/MessagingStrings.Designer.cs index 0bbac42..ea3bf6b 100644 --- a/src/DotNetOpenAuth/Messaging/MessagingStrings.Designer.cs +++ b/src/DotNetOpenAuth/Messaging/MessagingStrings.Designer.cs @@ -1,7 +1,7 @@ //------------------------------------------------------------------------------ // <auto-generated> // This code was generated by a tool. -// Runtime Version:4.0.30104.0 +// Runtime Version:4.0.30128.0 // // Changes to this file may cause incorrect behavior and will be lost if // the code is regenerated. @@ -412,6 +412,15 @@ namespace DotNetOpenAuth.Messaging { } /// <summary> + /// Looks up a localized string similar to This message factory does not support message type(s): {0}. + /// </summary> + internal static string StandardMessageFactoryUnsupportedMessageType { + get { + return ResourceManager.GetString("StandardMessageFactoryUnsupportedMessageType", resourceCulture); + } + } + + /// <summary> /// Looks up a localized string similar to The stream must have a known length.. /// </summary> internal static string StreamMustHaveKnownLength { diff --git a/src/DotNetOpenAuth/Messaging/MessagingStrings.resx b/src/DotNetOpenAuth/Messaging/MessagingStrings.resx index 34385d4..8b50767 100644 --- a/src/DotNetOpenAuth/Messaging/MessagingStrings.resx +++ b/src/DotNetOpenAuth/Messaging/MessagingStrings.resx @@ -300,4 +300,7 @@ <data name="BinaryDataRequiresMultipart" xml:space="preserve"> <value>Unable to send all message data because some of it requires multi-part POST, but IMessageWithBinaryData.SendAsMultipart was false.</value> </data> + <data name="StandardMessageFactoryUnsupportedMessageType" xml:space="preserve"> + <value>This message factory does not support message type(s): {0}</value> + </data> </root>
\ No newline at end of file diff --git a/src/DotNetOpenAuth/Messaging/Reflection/MessageDescription.cs b/src/DotNetOpenAuth/Messaging/Reflection/MessageDescription.cs index 5493ba6..17b5304 100644 --- a/src/DotNetOpenAuth/Messaging/Reflection/MessageDescription.cs +++ b/src/DotNetOpenAuth/Messaging/Reflection/MessageDescription.cs @@ -19,16 +19,6 @@ namespace DotNetOpenAuth.Messaging.Reflection { /// </summary> internal class MessageDescription { /// <summary> - /// The type of message this instance was generated from. - /// </summary> - private Type messageType; - - /// <summary> - /// The message version this instance was generated from. - /// </summary> - private Version messageVersion; - - /// <summary> /// A mapping between the serialized key names and their /// describing <see cref="MessagePart"/> instances. /// </summary> @@ -44,8 +34,8 @@ namespace DotNetOpenAuth.Messaging.Reflection { Contract.Requires<ArgumentException>(typeof(IMessage).IsAssignableFrom(messageType)); Contract.Requires<ArgumentNullException>(messageVersion != null); - this.messageType = messageType; - this.messageVersion = messageVersion; + this.MessageType = messageType; + this.MessageVersion = messageVersion; this.ReflectMessageType(); } @@ -58,6 +48,32 @@ namespace DotNetOpenAuth.Messaging.Reflection { } /// <summary> + /// Gets the message version this instance was generated from. + /// </summary> + internal Version MessageVersion { get; private set; } + + /// <summary> + /// Gets the type of message this instance was generated from. + /// </summary> + /// <value>The type of the described message.</value> + internal Type MessageType { get; private set; } + + /// <summary> + /// Gets the constructors available on the message type. + /// </summary> + internal ConstructorInfo[] Constructors { get; private set; } + + /// <summary> + /// Returns a <see cref="System.String"/> that represents this instance. + /// </summary> + /// <returns> + /// A <see cref="System.String"/> that represents this instance. + /// </returns> + public override string ToString() { + return this.MessageType.Name + " (" + this.MessageVersion + ")"; + } + + /// <summary> /// Gets a dictionary that provides read/write access to a message. /// </summary> /// <param name="message">The message the dictionary should provide access to.</param> @@ -70,51 +86,17 @@ namespace DotNetOpenAuth.Messaging.Reflection { } /// <summary> - /// Reflects over some <see cref="IMessage"/>-implementing type - /// and prepares to serialize/deserialize instances of that type. - /// </summary> - internal void ReflectMessageType() { - this.mapping = new Dictionary<string, MessagePart>(); - - Type currentType = this.messageType; - do { - foreach (MemberInfo member in currentType.GetMembers(BindingFlags.Instance | BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.DeclaredOnly)) { - if (member is PropertyInfo || member is FieldInfo) { - MessagePartAttribute partAttribute = - (from a in member.GetCustomAttributes(typeof(MessagePartAttribute), true).OfType<MessagePartAttribute>() - orderby a.MinVersionValue descending - where a.MinVersionValue <= this.messageVersion - where a.MaxVersionValue >= this.messageVersion - select a).FirstOrDefault(); - if (partAttribute != null) { - MessagePart part = new MessagePart(member, partAttribute); - if (this.mapping.ContainsKey(part.Name)) { - Logger.Messaging.WarnFormat( - "Message type {0} has more than one message part named {1}. Inherited members will be hidden.", - this.messageType.Name, - part.Name); - } else { - this.mapping.Add(part.Name, part); - } - } - } - } - currentType = currentType.BaseType; - } while (currentType != null); - } - - /// <summary> /// Ensures the message parts pass basic validation. /// </summary> /// <param name="parts">The key/value pairs of the serialized message.</param> internal void EnsureMessagePartsPassBasicValidation(IDictionary<string, string> parts) { try { - this.EnsureRequiredMessagePartsArePresent(parts.Keys); - this.EnsureRequiredProtocolMessagePartsAreNotEmpty(parts); + this.CheckRequiredMessagePartsArePresent(parts.Keys, true); + this.CheckRequiredProtocolMessagePartsAreNotEmpty(parts, true); } catch (ProtocolException) { Logger.Messaging.ErrorFormat( "Error while performing basic validation of {0} with these message parts:{1}{2}", - this.messageType.Name, + this.MessageType.Name, Environment.NewLine, parts.ToStringDeferred()); throw; @@ -122,42 +104,134 @@ namespace DotNetOpenAuth.Messaging.Reflection { } /// <summary> + /// Tests whether all the required message parts pass basic validation for the given data. + /// </summary> + /// <param name="parts">The key/value pairs of the serialized message.</param> + /// <returns>A value indicating whether the provided data fits the message's basic requirements.</returns> + internal bool CheckMessagePartsPassBasicValidation(IDictionary<string, string> parts) { + Contract.Requires<ArgumentNullException>(parts != null); + + return this.CheckRequiredMessagePartsArePresent(parts.Keys, false) && + this.CheckRequiredProtocolMessagePartsAreNotEmpty(parts, false); + } + + /// <summary> /// Verifies that a given set of keys include all the required parameters /// for this message type or throws an exception. /// </summary> /// <param name="keys">The names of all parameters included in a message.</param> - /// <exception cref="ProtocolException">Thrown when required parts of a message are not in <paramref name="keys"/></exception> - private void EnsureRequiredMessagePartsArePresent(IEnumerable<string> keys) { + /// <param name="throwOnFailure">if set to <c>true</c> an exception is thrown on failure with details.</param> + /// <returns>A value indicating whether the provided data fits the message's basic requirements.</returns> + /// <exception cref="ProtocolException"> + /// Thrown when required parts of a message are not in <paramref name="keys"/> + /// if <paramref name="throwOnFailure"/> is <c>true</c>. + /// </exception> + private bool CheckRequiredMessagePartsArePresent(IEnumerable<string> keys, bool throwOnFailure) { + Contract.Requires<ArgumentNullException>(keys != null); + var missingKeys = (from part in this.Mapping.Values where part.IsRequired && !keys.Contains(part.Name) select part.Name).ToArray(); if (missingKeys.Length > 0) { - throw new ProtocolException( - string.Format( - CultureInfo.CurrentCulture, + if (throwOnFailure) { + ErrorUtilities.ThrowProtocol( + MessagingStrings.RequiredParametersMissing, + this.MessageType.FullName, + string.Join(", ", missingKeys)); + } else { + Logger.Messaging.DebugFormat( MessagingStrings.RequiredParametersMissing, - this.messageType.FullName, - string.Join(", ", missingKeys))); + this.MessageType.FullName, + missingKeys.ToStringDeferred()); + return false; + } } + + return true; } /// <summary> /// Ensures the protocol message parts that must not be empty are in fact not empty. /// </summary> /// <param name="partValues">A dictionary of key/value pairs that make up the serialized message.</param> - private void EnsureRequiredProtocolMessagePartsAreNotEmpty(IDictionary<string, string> partValues) { + /// <param name="throwOnFailure">if set to <c>true</c> an exception is thrown on failure with details.</param> + /// <returns>A value indicating whether the provided data fits the message's basic requirements.</returns> + /// <exception cref="ProtocolException"> + /// Thrown when required parts of a message are not in <paramref name="partValues"/> + /// if <paramref name="throwOnFailure"/> is <c>true</c>. + /// </exception> + private bool CheckRequiredProtocolMessagePartsAreNotEmpty(IDictionary<string, string> partValues, bool throwOnFailure) { + Contract.Requires<ArgumentNullException>(partValues != null); + string value; var emptyValuedKeys = (from part in this.Mapping.Values where !part.AllowEmpty && partValues.TryGetValue(part.Name, out value) && value != null && value.Length == 0 select part.Name).ToArray(); if (emptyValuedKeys.Length > 0) { - throw new ProtocolException( - string.Format( - CultureInfo.CurrentCulture, + if (throwOnFailure) { + ErrorUtilities.ThrowProtocol( + MessagingStrings.RequiredNonEmptyParameterWasEmpty, + this.MessageType.FullName, + string.Join(", ", emptyValuedKeys)); + } else { + Logger.Messaging.DebugFormat( MessagingStrings.RequiredNonEmptyParameterWasEmpty, - this.messageType.FullName, - string.Join(", ", emptyValuedKeys))); + this.MessageType.FullName, + emptyValuedKeys.ToStringDeferred()); + return false; + } } + + return true; + } + + /// <summary> + /// Reflects over some <see cref="IMessage"/>-implementing type + /// and prepares to serialize/deserialize instances of that type. + /// </summary> + private void ReflectMessageType() { + this.mapping = new Dictionary<string, MessagePart>(); + + Type currentType = this.MessageType; + do { + foreach (MemberInfo member in currentType.GetMembers(BindingFlags.Instance | BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.DeclaredOnly)) { + if (member is PropertyInfo || member is FieldInfo) { + MessagePartAttribute partAttribute = + (from a in member.GetCustomAttributes(typeof(MessagePartAttribute), true).OfType<MessagePartAttribute>() + orderby a.MinVersionValue descending + where a.MinVersionValue <= this.MessageVersion + where a.MaxVersionValue >= this.MessageVersion + select a).FirstOrDefault(); + if (partAttribute != null) { + MessagePart part = new MessagePart(member, partAttribute); + if (this.mapping.ContainsKey(part.Name)) { + Logger.Messaging.WarnFormat( + "Message type {0} has more than one message part named {1}. Inherited members will be hidden.", + this.MessageType.Name, + part.Name); + } else { + this.mapping.Add(part.Name, part); + } + } + } + } + currentType = currentType.BaseType; + } while (currentType != null); + + BindingFlags flags = BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public; + this.Constructors = this.MessageType.GetConstructors(flags); + } + +#if CONTRACTS_FULL + /// <summary> + /// Describes traits of this class that are always true. + /// </summary> + [ContractInvariantMethod] + private void Invariant() { + Contract.Invariant(this.MessageType != null); + Contract.Invariant(this.MessageVersion != null); + Contract.Invariant(this.Constructors != null); } +#endif } } diff --git a/src/DotNetOpenAuth/Messaging/Reflection/MessageDescriptionCollection.cs b/src/DotNetOpenAuth/Messaging/Reflection/MessageDescriptionCollection.cs index ff8b74b..8911960 100644 --- a/src/DotNetOpenAuth/Messaging/Reflection/MessageDescriptionCollection.cs +++ b/src/DotNetOpenAuth/Messaging/Reflection/MessageDescriptionCollection.cs @@ -14,7 +14,7 @@ namespace DotNetOpenAuth.Messaging.Reflection { /// A cache of <see cref="MessageDescription"/> instances. /// </summary> [ContractVerification(true)] - internal class MessageDescriptionCollection { + internal class MessageDescriptionCollection : IEnumerable<MessageDescription> { /// <summary> /// A dictionary of reflected message types and the generated reflection information. /// </summary> @@ -23,9 +23,37 @@ namespace DotNetOpenAuth.Messaging.Reflection { /// <summary> /// Initializes a new instance of the <see cref="MessageDescriptionCollection"/> class. /// </summary> - public MessageDescriptionCollection() { + internal MessageDescriptionCollection() { } + #region IEnumerable<MessageDescription> Members + + /// <summary> + /// Returns an enumerator that iterates through a collection. + /// </summary> + /// <returns> + /// An <see cref="T:System.Collections.IEnumerator"/> object that can be used to iterate through the collection. + /// </returns> + public IEnumerator<MessageDescription> GetEnumerator() { + return this.reflectedMessageTypes.Values.GetEnumerator(); + } + + #endregion + + #region IEnumerable Members + + /// <summary> + /// Returns an enumerator that iterates through a collection. + /// </summary> + /// <returns> + /// An <see cref="T:System.Collections.IEnumerator"/> object that can be used to iterate through the collection. + /// </returns> + System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() { + return this.reflectedMessageTypes.Values.GetEnumerator(); + } + + #endregion + /// <summary> /// Gets a <see cref="MessageDescription"/> instance prepared for the /// given message type. diff --git a/src/DotNetOpenAuth/Messaging/StandardMessageFactory.cs b/src/DotNetOpenAuth/Messaging/StandardMessageFactory.cs new file mode 100644 index 0000000..670d750 --- /dev/null +++ b/src/DotNetOpenAuth/Messaging/StandardMessageFactory.cs @@ -0,0 +1,228 @@ +//----------------------------------------------------------------------- +// <copyright file="StandardMessageFactory.cs" company="Andrew Arnott"> +// Copyright (c) Andrew Arnott. All rights reserved. +// </copyright> +//----------------------------------------------------------------------- + +namespace DotNetOpenAuth.Messaging { + using System; + using System.Collections.Generic; + using System.Diagnostics.Contracts; + using System.Linq; + using System.Reflection; + using System.Text; + using DotNetOpenAuth.Messaging.Reflection; + + /// <summary> + /// A message factory that automatically selects the message type based on the incoming data. + /// </summary> + internal class StandardMessageFactory : IMessageFactory { + /// <summary> + /// The request message types and their constructors to use for instantiating the messages. + /// </summary> + private readonly Dictionary<MessageDescription, ConstructorInfo> requestMessageTypes = new Dictionary<MessageDescription, ConstructorInfo>(); + + /// <summary> + /// The response message types and their constructors to use for instantiating the messages. + /// </summary> + /// <value> + /// The value is a dictionary, whose key is the type of the constructor's lone parameter. + /// </value> + private readonly Dictionary<MessageDescription, Dictionary<Type, ConstructorInfo>> responseMessageTypes = new Dictionary<MessageDescription, Dictionary<Type, ConstructorInfo>>(); + + /// <summary> + /// Initializes a new instance of the <see cref="StandardMessageFactory"/> class. + /// </summary> + internal StandardMessageFactory() { + } + + /// <summary> + /// Adds message types to the set that this factory can create. + /// </summary> + /// <param name="messageTypes">The message types that this factory may instantiate.</param> + public virtual void AddMessageTypes(IEnumerable<MessageDescription> messageTypes) { + Contract.Requires<ArgumentNullException>(messageTypes != null); + Contract.Requires<ArgumentException>(messageTypes.All(msg => msg != null)); + + var unsupportedMessageTypes = new List<MessageDescription>(0); + foreach (MessageDescription messageDescription in messageTypes) { + bool supportedMessageType = false; + + // First see whether this message fits the recognized pattern for request messages. + if (typeof(IDirectedProtocolMessage).IsAssignableFrom(messageDescription.MessageType)) { + foreach (ConstructorInfo ctor in messageDescription.Constructors) { + ParameterInfo[] parameters = ctor.GetParameters(); + if (parameters.Length == 2 && parameters[0].ParameterType == typeof(Uri) && parameters[1].ParameterType == typeof(Version)) { + supportedMessageType = true; + this.requestMessageTypes.Add(messageDescription, ctor); + break; + } + } + } + + // Also see if this message fits the recognized pattern for response messages. + if (typeof(IDirectResponseProtocolMessage).IsAssignableFrom(messageDescription.MessageType)) { + var responseCtors = new Dictionary<Type, ConstructorInfo>(messageDescription.Constructors.Length); + foreach (ConstructorInfo ctor in messageDescription.Constructors) { + ParameterInfo[] parameters = ctor.GetParameters(); + if (parameters.Length == 1 && typeof(IDirectedProtocolMessage).IsAssignableFrom(parameters[0].ParameterType)) { + responseCtors.Add(parameters[0].ParameterType, ctor); + } + } + + if (responseCtors.Count > 0) { + supportedMessageType = true; + this.responseMessageTypes.Add(messageDescription, responseCtors); + } + } + + if (!supportedMessageType) { + unsupportedMessageTypes.Add(messageDescription); + } + } + + ErrorUtilities.VerifySupported( + !unsupportedMessageTypes.Any(), + MessagingStrings.StandardMessageFactoryUnsupportedMessageType, + unsupportedMessageTypes.ToStringDeferred()); + } + + #region IMessageFactory Members + + /// <summary> + /// Analyzes an incoming request message payload to discover what kind of + /// message is embedded in it and returns the type, or null if no match is found. + /// </summary> + /// <param name="recipient">The intended or actual recipient of the request message.</param> + /// <param name="fields">The name/value pairs that make up the message payload.</param> + /// <returns> + /// A newly instantiated <see cref="IProtocolMessage"/>-derived object that this message can + /// deserialize to. Null if the request isn't recognized as a valid protocol message. + /// </returns> + public virtual IDirectedProtocolMessage GetNewRequestMessage(MessageReceivingEndpoint recipient, IDictionary<string, string> fields) { + MessageDescription matchingType = this.GetMessageDescription(recipient, fields); + if (matchingType != null) { + return this.InstantiateAsRequest(matchingType, recipient); + } else { + return null; + } + } + + /// <summary> + /// Analyzes an incoming request message payload to discover what kind of + /// message is embedded in it and returns the type, or null if no match is found. + /// </summary> + /// <param name="request">The message that was sent as a request that resulted in the response.</param> + /// <param name="fields">The name/value pairs that make up the message payload.</param> + /// <returns> + /// A newly instantiated <see cref="IProtocolMessage"/>-derived object that this message can + /// deserialize to. Null if the request isn't recognized as a valid protocol message. + /// </returns> + public virtual IDirectResponseProtocolMessage GetNewResponseMessage(IDirectedProtocolMessage request, IDictionary<string, string> fields) { + MessageDescription matchingType = this.GetMessageDescription(request, fields); + if (matchingType != null) { + return this.InstantiateAsResponse(matchingType, request); + } else { + return null; + } + } + + #endregion + + /// <summary> + /// Gets the message type that best fits the given incoming request data. + /// </summary> + /// <param name="recipient">The recipient of the incoming data. Typically not used, but included just in case.</param> + /// <param name="fields">The data of the incoming message.</param> + /// <returns> + /// The message type that matches the incoming data; or <c>null</c> if no match. + /// </returns> + /// <exception cref="ProtocolException">May be thrown if the incoming data is ambiguous.</exception> + protected virtual MessageDescription GetMessageDescription(MessageReceivingEndpoint recipient, IDictionary<string, string> fields) { + Contract.Requires<ArgumentNullException>(recipient != null); + Contract.Requires<ArgumentNullException>(fields != null); + + var basicMatches = this.requestMessageTypes.Keys.Where(message => message.CheckMessagePartsPassBasicValidation(fields)); + var match = basicMatches.FirstOrDefault(); + if (match != null) { + if (Logger.Messaging.IsDebugEnabled && basicMatches.Count() > 1) { + Logger.Messaging.DebugFormat( + "Multiple message types seemed to fit the incoming data: {0}", + basicMatches.ToStringDeferred()); + } + + return match; + } else { + // No message type matches the incoming data. + return null; + } + } + + /// <summary> + /// Gets the message type that best fits the given incoming direct response data. + /// </summary> + /// <param name="request">The request message that prompted the response data.</param> + /// <param name="fields">The data of the incoming message.</param> + /// <returns> + /// The message type that matches the incoming data; or <c>null</c> if no match. + /// </returns> + /// <exception cref="ProtocolException">May be thrown if the incoming data is ambiguous.</exception> + protected virtual MessageDescription GetMessageDescription(IDirectedProtocolMessage request, IDictionary<string, string> fields) { + var basicMatches = this.responseMessageTypes.Keys.Where(message => message.CheckMessagePartsPassBasicValidation(fields)).CacheGeneratedResults(); + var match = basicMatches.FirstOrDefault(); + if (match != null) { + if (Logger.Messaging.IsDebugEnabled && basicMatches.Count() > 1) { + Logger.Messaging.DebugFormat( + "Multiple message types seemed to fit the incoming data: {0}", + basicMatches.ToStringDeferred()); + } + + return match; + } else { + // No message type matches the incoming data. + return null; + } + } + + /// <summary> + /// Instantiates the given request message type. + /// </summary> + /// <param name="messageDescription">The message description.</param> + /// <param name="recipient">The recipient.</param> + /// <returns>The instantiated message. Never null.</returns> + protected virtual IDirectedProtocolMessage InstantiateAsRequest(MessageDescription messageDescription, MessageReceivingEndpoint recipient) { + Contract.Requires<ArgumentNullException>(messageDescription != null); + Contract.Requires<ArgumentNullException>(recipient != null); + Contract.Ensures(Contract.Result<IDirectedProtocolMessage>() != null); + + ConstructorInfo ctor = this.requestMessageTypes[messageDescription]; + return (IDirectedProtocolMessage)ctor.Invoke(new object[] { recipient.Location, messageDescription.MessageVersion }); + } + + /// <summary> + /// Instantiates the given request message type. + /// </summary> + /// <param name="messageDescription">The message description.</param> + /// <param name="request">The request that resulted in this response.</param> + /// <returns>The instantiated message. Never null.</returns> + protected virtual IDirectResponseProtocolMessage InstantiateAsResponse(MessageDescription messageDescription, IDirectedProtocolMessage request) { + Contract.Requires<ArgumentNullException>(messageDescription != null); + Contract.Requires<ArgumentNullException>(request != null); + Contract.Ensures(Contract.Result<IDirectResponseProtocolMessage>() != null); + + Type requestType = request.GetType(); + var ctors = this.responseMessageTypes[messageDescription].Where(pair => pair.Key.IsAssignableFrom(requestType)); + ConstructorInfo ctor = null; + try { + ctor = ctors.Single().Value; + } catch (InvalidOperationException) { + if (ctors.Any()) { + ErrorUtilities.ThrowInternal("More than one matching constructor for request type " + requestType.Name + " and response type " + messageDescription.MessageType.Name); + } else { + ErrorUtilities.ThrowInternal("Unexpected request message type " + requestType.FullName + " for response type " + messageDescription.MessageType.Name); + } + } + return (IDirectResponseProtocolMessage)ctor.Invoke(new object[] { request }); + } + } +} |