summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorAndrew Arnott <andrewarnott@gmail.com>2010-02-22 08:39:49 -0800
committerAndrew Arnott <andrewarnott@gmail.com>2010-02-22 08:39:49 -0800
commit1b3fac76d3f63830d2a5fbe07b90fbcfb2bb8b8b (patch)
tree1615bc256cbe5c9536ca4492c8296072d57b70c6 /src
parent79ce3eee130d5bc58e44b78acd9b4b9693b07ecc (diff)
parentba1511f42007de1b6439d32a563b0f4a58dcbb53 (diff)
downloadDotNetOpenAuth-1b3fac76d3f63830d2a5fbe07b90fbcfb2bb8b8b.zip
DotNetOpenAuth-1b3fac76d3f63830d2a5fbe07b90fbcfb2bb8b8b.tar.gz
DotNetOpenAuth-1b3fac76d3f63830d2a5fbe07b90fbcfb2bb8b8b.tar.bz2
Merge branch 'standardmessagefactory' into oauthWRAP
Conflicts: src/DotNetOpenAuth/DotNetOpenAuth.csproj
Diffstat (limited to 'src')
-rw-r--r--src/DotNetOpenAuth.Test/DotNetOpenAuth.Test.csproj1
-rw-r--r--src/DotNetOpenAuth.Test/Messaging/Reflection/MessageDescriptionTests.cs10
-rw-r--r--src/DotNetOpenAuth.Test/Messaging/StandardMessageFactoryTests.cs178
-rw-r--r--src/DotNetOpenAuth/DotNetOpenAuth.csproj1
-rw-r--r--src/DotNetOpenAuth/Messaging/MessagingStrings.Designer.cs11
-rw-r--r--src/DotNetOpenAuth/Messaging/MessagingStrings.resx3
-rw-r--r--src/DotNetOpenAuth/Messaging/Reflection/MessageDescription.cs198
-rw-r--r--src/DotNetOpenAuth/Messaging/Reflection/MessageDescriptionCollection.cs32
-rw-r--r--src/DotNetOpenAuth/Messaging/StandardMessageFactory.cs228
9 files changed, 597 insertions, 65 deletions
diff --git a/src/DotNetOpenAuth.Test/DotNetOpenAuth.Test.csproj b/src/DotNetOpenAuth.Test/DotNetOpenAuth.Test.csproj
index 6540714..a0c849a 100644
--- a/src/DotNetOpenAuth.Test/DotNetOpenAuth.Test.csproj
+++ b/src/DotNetOpenAuth.Test/DotNetOpenAuth.Test.csproj
@@ -196,6 +196,7 @@
<Compile Include="Messaging\Bindings\StandardExpirationBindingElementTests.cs" />
<Compile Include="Messaging\Reflection\MessagePartTests.cs" />
<Compile Include="Messaging\Reflection\ValueMappingTests.cs" />
+ <Compile Include="Messaging\StandardMessageFactoryTests.cs" />
<Compile Include="Mocks\AssociateUnencryptedRequestNoSslCheck.cs" />
<Compile Include="Mocks\CoordinatingChannel.cs" />
<Compile Include="Mocks\CoordinatingHttpRequestInfo.cs" />
diff --git a/src/DotNetOpenAuth.Test/Messaging/Reflection/MessageDescriptionTests.cs b/src/DotNetOpenAuth.Test/Messaging/Reflection/MessageDescriptionTests.cs
index e57df65..92f39cc 100644
--- a/src/DotNetOpenAuth.Test/Messaging/Reflection/MessageDescriptionTests.cs
+++ b/src/DotNetOpenAuth.Test/Messaging/Reflection/MessageDescriptionTests.cs
@@ -73,6 +73,16 @@ namespace DotNetOpenAuth.Test.Messaging.Reflection {
Assert.IsTrue(v30.Mapping["OptionalIn10RequiredIn25AndLater"].IsRequired);
}
+ /// <summary>
+ /// Verifies that the constructors cache is properly initialized.
+ /// </summary>
+ [TestCase]
+ public void CtorsCache() {
+ var message = new MessageDescription(typeof(MultiVersionMessage), new Version(1, 0));
+ Assert.IsNotNull(message.Constructors);
+ Assert.AreEqual(1, message.Constructors.Length);
+ }
+
private class MultiVersionMessage : Mocks.TestBaseMessage {
#pragma warning disable 0649 // these fields are never written to, but part of the test
[MessagePart]
diff --git a/src/DotNetOpenAuth.Test/Messaging/StandardMessageFactoryTests.cs b/src/DotNetOpenAuth.Test/Messaging/StandardMessageFactoryTests.cs
new file mode 100644
index 0000000..2b0b4e7
--- /dev/null
+++ b/src/DotNetOpenAuth.Test/Messaging/StandardMessageFactoryTests.cs
@@ -0,0 +1,178 @@
+//-----------------------------------------------------------------------
+// <copyright file="StandardMessageFactoryTests.cs" company="Andrew Arnott">
+// Copyright (c) Andrew Arnott. All rights reserved.
+// </copyright>
+//-----------------------------------------------------------------------
+
+namespace DotNetOpenAuth.Test.Messaging {
+ using System;
+ using System.Collections.Generic;
+ using System.Linq;
+ using System.Text;
+ using DotNetOpenAuth.Messaging;
+ using DotNetOpenAuth.Messaging.Reflection;
+ using DotNetOpenAuth.Test.Mocks;
+ using NUnit.Framework;
+
+ [TestFixture]
+ public class StandardMessageFactoryTests : MessagingTestBase {
+ private static readonly Version V1 = new Version(1, 0);
+ private static readonly MessageReceivingEndpoint receiver = new MessageReceivingEndpoint("http://receiver", HttpDeliveryMethods.PostRequest);
+
+ private StandardMessageFactory factory;
+
+ public override void SetUp() {
+ base.SetUp();
+
+ this.factory = new StandardMessageFactory();
+ }
+
+ /// <summary>
+ /// Verifies that AddMessageTypes throws the appropriate exception on null input.
+ /// </summary>
+ [TestCase, ExpectedException(typeof(ArgumentNullException))]
+ public void AddMessageTypesNull() {
+ this.factory.AddMessageTypes(null);
+ }
+
+ /// <summary>
+ /// Verifies that AddMessageTypes throws the appropriate exception on null input.
+ /// </summary>
+ [TestCase, ExpectedException(typeof(ArgumentException))]
+ public void AddMessageTypesNullMessageDescription() {
+ this.factory.AddMessageTypes(new MessageDescription[] { null });
+ }
+
+ /// <summary>
+ /// Verifies very simple recognition of a single message type
+ /// </summary>
+ [TestCase]
+ public void SingleRequestMessageType() {
+ this.factory.AddMessageTypes(new MessageDescription[] { MessageDescriptions.Get(typeof(RequestMessageMock), V1) });
+ var fields = new Dictionary<string, string> {
+ { "random", "bits" },
+ };
+ Assert.IsNull(this.factory.GetNewRequestMessage(receiver, fields));
+ fields["Age"] = "18";
+ Assert.IsInstanceOf(typeof(RequestMessageMock), this.factory.GetNewRequestMessage(receiver, fields));
+ }
+
+ /// <summary>
+ /// Verifies very simple recognition of a single message type
+ /// </summary>
+ [TestCase]
+ public void SingleResponseMessageType() {
+ this.factory.AddMessageTypes(new MessageDescription[] { MessageDescriptions.Get(typeof(DirectResponseMessageMock), V1) });
+ var fields = new Dictionary<string, string> {
+ { "random", "bits" },
+ };
+ IDirectedProtocolMessage request = new RequestMessageMock(receiver.Location, V1);
+ Assert.IsNull(this.factory.GetNewResponseMessage(request, fields));
+ fields["Age"] = "18";
+ IDirectResponseProtocolMessage response = this.factory.GetNewResponseMessage(request, fields);
+ Assert.IsInstanceOf<DirectResponseMessageMock>(response);
+ Assert.AreSame(request, response.OriginatingRequest);
+
+ // Verify that we can instantiate a response with a derived-type of an expected request message.
+ request = new TestSignedDirectedMessage();
+ response = this.factory.GetNewResponseMessage(request, fields);
+ Assert.IsInstanceOf<DirectResponseMessageMock>(response);
+ Assert.AreSame(request, response.OriginatingRequest);
+ }
+
+ private class DirectResponseMessageMock : IDirectResponseProtocolMessage {
+ internal DirectResponseMessageMock(RequestMessageMock request) {
+ this.OriginatingRequest = request;
+ }
+
+ internal DirectResponseMessageMock(TestDirectedMessage request) {
+ this.OriginatingRequest = request;
+ }
+
+ [MessagePart(IsRequired = true)]
+ public int Age { get; set; }
+
+ #region IDirectResponseProtocolMessage Members
+
+ public IDirectedProtocolMessage OriginatingRequest { get; private set; }
+
+ #endregion
+
+ #region IProtocolMessage Members
+
+ public MessageProtections RequiredProtection {
+ get { throw new NotImplementedException(); }
+ }
+
+ public MessageTransport Transport {
+ get { throw new NotImplementedException(); }
+ }
+
+ #endregion
+
+ #region IMessage Members
+
+ public Version Version {
+ get { throw new NotImplementedException(); }
+ }
+
+ public System.Collections.Generic.IDictionary<string, string> ExtraData {
+ get { throw new NotImplementedException(); }
+ }
+
+ public void EnsureValidMessage() {
+ throw new NotImplementedException();
+ }
+
+ #endregion
+ }
+
+ private class RequestMessageMock : IDirectedProtocolMessage {
+ internal RequestMessageMock(Uri recipient, Version version) {
+ }
+
+ [MessagePart(IsRequired = true)]
+ public int Age { get; set; }
+
+ #region IDirectedProtocolMessage Members
+
+ public HttpDeliveryMethods HttpMethods {
+ get { throw new NotImplementedException(); }
+ }
+
+ public Uri Recipient {
+ get { throw new NotImplementedException(); }
+ }
+
+ #endregion
+
+ #region IProtocolMessage Members
+
+ public MessageProtections RequiredProtection {
+ get { throw new NotImplementedException(); }
+ }
+
+ public MessageTransport Transport {
+ get { throw new NotImplementedException(); }
+ }
+
+ #endregion
+
+ #region IMessage Members
+
+ public Version Version {
+ get { throw new NotImplementedException(); }
+ }
+
+ public System.Collections.Generic.IDictionary<string, string> ExtraData {
+ get { throw new NotImplementedException(); }
+ }
+
+ public void EnsureValidMessage() {
+ throw new NotImplementedException();
+ }
+
+ #endregion
+ }
+ }
+}
diff --git a/src/DotNetOpenAuth/DotNetOpenAuth.csproj b/src/DotNetOpenAuth/DotNetOpenAuth.csproj
index 7f5b298..669863c 100644
--- a/src/DotNetOpenAuth/DotNetOpenAuth.csproj
+++ b/src/DotNetOpenAuth/DotNetOpenAuth.csproj
@@ -303,6 +303,7 @@ http://opensource.org/licenses/ms-pl.html
<Compile Include="Messaging\Reflection\IMessagePartEncoder.cs" />
<Compile Include="Messaging\Reflection\IMessagePartNullEncoder.cs" />
<Compile Include="Messaging\Reflection\MessageDescriptionCollection.cs" />
+ <Compile Include="Messaging\StandardMessageFactory.cs" />
<Compile Include="OAuthWrap\IClientTokenManager.cs" />
<Compile Include="OAuthWrap\Messages\Assertion\AssertionRequest.cs" />
<Compile Include="OAuthWrap\Messages\Assertion\AssertionFailedResponse.cs" />
diff --git a/src/DotNetOpenAuth/Messaging/MessagingStrings.Designer.cs b/src/DotNetOpenAuth/Messaging/MessagingStrings.Designer.cs
index 0bbac42..ea3bf6b 100644
--- a/src/DotNetOpenAuth/Messaging/MessagingStrings.Designer.cs
+++ b/src/DotNetOpenAuth/Messaging/MessagingStrings.Designer.cs
@@ -1,7 +1,7 @@
//------------------------------------------------------------------------------
// <auto-generated>
// This code was generated by a tool.
-// Runtime Version:4.0.30104.0
+// Runtime Version:4.0.30128.0
//
// Changes to this file may cause incorrect behavior and will be lost if
// the code is regenerated.
@@ -412,6 +412,15 @@ namespace DotNetOpenAuth.Messaging {
}
/// <summary>
+ /// Looks up a localized string similar to This message factory does not support message type(s): {0}.
+ /// </summary>
+ internal static string StandardMessageFactoryUnsupportedMessageType {
+ get {
+ return ResourceManager.GetString("StandardMessageFactoryUnsupportedMessageType", resourceCulture);
+ }
+ }
+
+ /// <summary>
/// Looks up a localized string similar to The stream must have a known length..
/// </summary>
internal static string StreamMustHaveKnownLength {
diff --git a/src/DotNetOpenAuth/Messaging/MessagingStrings.resx b/src/DotNetOpenAuth/Messaging/MessagingStrings.resx
index 34385d4..8b50767 100644
--- a/src/DotNetOpenAuth/Messaging/MessagingStrings.resx
+++ b/src/DotNetOpenAuth/Messaging/MessagingStrings.resx
@@ -300,4 +300,7 @@
<data name="BinaryDataRequiresMultipart" xml:space="preserve">
<value>Unable to send all message data because some of it requires multi-part POST, but IMessageWithBinaryData.SendAsMultipart was false.</value>
</data>
+ <data name="StandardMessageFactoryUnsupportedMessageType" xml:space="preserve">
+ <value>This message factory does not support message type(s): {0}</value>
+ </data>
</root> \ No newline at end of file
diff --git a/src/DotNetOpenAuth/Messaging/Reflection/MessageDescription.cs b/src/DotNetOpenAuth/Messaging/Reflection/MessageDescription.cs
index 5493ba6..17b5304 100644
--- a/src/DotNetOpenAuth/Messaging/Reflection/MessageDescription.cs
+++ b/src/DotNetOpenAuth/Messaging/Reflection/MessageDescription.cs
@@ -19,16 +19,6 @@ namespace DotNetOpenAuth.Messaging.Reflection {
/// </summary>
internal class MessageDescription {
/// <summary>
- /// The type of message this instance was generated from.
- /// </summary>
- private Type messageType;
-
- /// <summary>
- /// The message version this instance was generated from.
- /// </summary>
- private Version messageVersion;
-
- /// <summary>
/// A mapping between the serialized key names and their
/// describing <see cref="MessagePart"/> instances.
/// </summary>
@@ -44,8 +34,8 @@ namespace DotNetOpenAuth.Messaging.Reflection {
Contract.Requires<ArgumentException>(typeof(IMessage).IsAssignableFrom(messageType));
Contract.Requires<ArgumentNullException>(messageVersion != null);
- this.messageType = messageType;
- this.messageVersion = messageVersion;
+ this.MessageType = messageType;
+ this.MessageVersion = messageVersion;
this.ReflectMessageType();
}
@@ -58,6 +48,32 @@ namespace DotNetOpenAuth.Messaging.Reflection {
}
/// <summary>
+ /// Gets the message version this instance was generated from.
+ /// </summary>
+ internal Version MessageVersion { get; private set; }
+
+ /// <summary>
+ /// Gets the type of message this instance was generated from.
+ /// </summary>
+ /// <value>The type of the described message.</value>
+ internal Type MessageType { get; private set; }
+
+ /// <summary>
+ /// Gets the constructors available on the message type.
+ /// </summary>
+ internal ConstructorInfo[] Constructors { get; private set; }
+
+ /// <summary>
+ /// Returns a <see cref="System.String"/> that represents this instance.
+ /// </summary>
+ /// <returns>
+ /// A <see cref="System.String"/> that represents this instance.
+ /// </returns>
+ public override string ToString() {
+ return this.MessageType.Name + " (" + this.MessageVersion + ")";
+ }
+
+ /// <summary>
/// Gets a dictionary that provides read/write access to a message.
/// </summary>
/// <param name="message">The message the dictionary should provide access to.</param>
@@ -70,51 +86,17 @@ namespace DotNetOpenAuth.Messaging.Reflection {
}
/// <summary>
- /// Reflects over some <see cref="IMessage"/>-implementing type
- /// and prepares to serialize/deserialize instances of that type.
- /// </summary>
- internal void ReflectMessageType() {
- this.mapping = new Dictionary<string, MessagePart>();
-
- Type currentType = this.messageType;
- do {
- foreach (MemberInfo member in currentType.GetMembers(BindingFlags.Instance | BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.DeclaredOnly)) {
- if (member is PropertyInfo || member is FieldInfo) {
- MessagePartAttribute partAttribute =
- (from a in member.GetCustomAttributes(typeof(MessagePartAttribute), true).OfType<MessagePartAttribute>()
- orderby a.MinVersionValue descending
- where a.MinVersionValue <= this.messageVersion
- where a.MaxVersionValue >= this.messageVersion
- select a).FirstOrDefault();
- if (partAttribute != null) {
- MessagePart part = new MessagePart(member, partAttribute);
- if (this.mapping.ContainsKey(part.Name)) {
- Logger.Messaging.WarnFormat(
- "Message type {0} has more than one message part named {1}. Inherited members will be hidden.",
- this.messageType.Name,
- part.Name);
- } else {
- this.mapping.Add(part.Name, part);
- }
- }
- }
- }
- currentType = currentType.BaseType;
- } while (currentType != null);
- }
-
- /// <summary>
/// Ensures the message parts pass basic validation.
/// </summary>
/// <param name="parts">The key/value pairs of the serialized message.</param>
internal void EnsureMessagePartsPassBasicValidation(IDictionary<string, string> parts) {
try {
- this.EnsureRequiredMessagePartsArePresent(parts.Keys);
- this.EnsureRequiredProtocolMessagePartsAreNotEmpty(parts);
+ this.CheckRequiredMessagePartsArePresent(parts.Keys, true);
+ this.CheckRequiredProtocolMessagePartsAreNotEmpty(parts, true);
} catch (ProtocolException) {
Logger.Messaging.ErrorFormat(
"Error while performing basic validation of {0} with these message parts:{1}{2}",
- this.messageType.Name,
+ this.MessageType.Name,
Environment.NewLine,
parts.ToStringDeferred());
throw;
@@ -122,42 +104,134 @@ namespace DotNetOpenAuth.Messaging.Reflection {
}
/// <summary>
+ /// Tests whether all the required message parts pass basic validation for the given data.
+ /// </summary>
+ /// <param name="parts">The key/value pairs of the serialized message.</param>
+ /// <returns>A value indicating whether the provided data fits the message's basic requirements.</returns>
+ internal bool CheckMessagePartsPassBasicValidation(IDictionary<string, string> parts) {
+ Contract.Requires<ArgumentNullException>(parts != null);
+
+ return this.CheckRequiredMessagePartsArePresent(parts.Keys, false) &&
+ this.CheckRequiredProtocolMessagePartsAreNotEmpty(parts, false);
+ }
+
+ /// <summary>
/// Verifies that a given set of keys include all the required parameters
/// for this message type or throws an exception.
/// </summary>
/// <param name="keys">The names of all parameters included in a message.</param>
- /// <exception cref="ProtocolException">Thrown when required parts of a message are not in <paramref name="keys"/></exception>
- private void EnsureRequiredMessagePartsArePresent(IEnumerable<string> keys) {
+ /// <param name="throwOnFailure">if set to <c>true</c> an exception is thrown on failure with details.</param>
+ /// <returns>A value indicating whether the provided data fits the message's basic requirements.</returns>
+ /// <exception cref="ProtocolException">
+ /// Thrown when required parts of a message are not in <paramref name="keys"/>
+ /// if <paramref name="throwOnFailure"/> is <c>true</c>.
+ /// </exception>
+ private bool CheckRequiredMessagePartsArePresent(IEnumerable<string> keys, bool throwOnFailure) {
+ Contract.Requires<ArgumentNullException>(keys != null);
+
var missingKeys = (from part in this.Mapping.Values
where part.IsRequired && !keys.Contains(part.Name)
select part.Name).ToArray();
if (missingKeys.Length > 0) {
- throw new ProtocolException(
- string.Format(
- CultureInfo.CurrentCulture,
+ if (throwOnFailure) {
+ ErrorUtilities.ThrowProtocol(
+ MessagingStrings.RequiredParametersMissing,
+ this.MessageType.FullName,
+ string.Join(", ", missingKeys));
+ } else {
+ Logger.Messaging.DebugFormat(
MessagingStrings.RequiredParametersMissing,
- this.messageType.FullName,
- string.Join(", ", missingKeys)));
+ this.MessageType.FullName,
+ missingKeys.ToStringDeferred());
+ return false;
+ }
}
+
+ return true;
}
/// <summary>
/// Ensures the protocol message parts that must not be empty are in fact not empty.
/// </summary>
/// <param name="partValues">A dictionary of key/value pairs that make up the serialized message.</param>
- private void EnsureRequiredProtocolMessagePartsAreNotEmpty(IDictionary<string, string> partValues) {
+ /// <param name="throwOnFailure">if set to <c>true</c> an exception is thrown on failure with details.</param>
+ /// <returns>A value indicating whether the provided data fits the message's basic requirements.</returns>
+ /// <exception cref="ProtocolException">
+ /// Thrown when required parts of a message are not in <paramref name="partValues"/>
+ /// if <paramref name="throwOnFailure"/> is <c>true</c>.
+ /// </exception>
+ private bool CheckRequiredProtocolMessagePartsAreNotEmpty(IDictionary<string, string> partValues, bool throwOnFailure) {
+ Contract.Requires<ArgumentNullException>(partValues != null);
+
string value;
var emptyValuedKeys = (from part in this.Mapping.Values
where !part.AllowEmpty && partValues.TryGetValue(part.Name, out value) && value != null && value.Length == 0
select part.Name).ToArray();
if (emptyValuedKeys.Length > 0) {
- throw new ProtocolException(
- string.Format(
- CultureInfo.CurrentCulture,
+ if (throwOnFailure) {
+ ErrorUtilities.ThrowProtocol(
+ MessagingStrings.RequiredNonEmptyParameterWasEmpty,
+ this.MessageType.FullName,
+ string.Join(", ", emptyValuedKeys));
+ } else {
+ Logger.Messaging.DebugFormat(
MessagingStrings.RequiredNonEmptyParameterWasEmpty,
- this.messageType.FullName,
- string.Join(", ", emptyValuedKeys)));
+ this.MessageType.FullName,
+ emptyValuedKeys.ToStringDeferred());
+ return false;
+ }
}
+
+ return true;
+ }
+
+ /// <summary>
+ /// Reflects over some <see cref="IMessage"/>-implementing type
+ /// and prepares to serialize/deserialize instances of that type.
+ /// </summary>
+ private void ReflectMessageType() {
+ this.mapping = new Dictionary<string, MessagePart>();
+
+ Type currentType = this.MessageType;
+ do {
+ foreach (MemberInfo member in currentType.GetMembers(BindingFlags.Instance | BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.DeclaredOnly)) {
+ if (member is PropertyInfo || member is FieldInfo) {
+ MessagePartAttribute partAttribute =
+ (from a in member.GetCustomAttributes(typeof(MessagePartAttribute), true).OfType<MessagePartAttribute>()
+ orderby a.MinVersionValue descending
+ where a.MinVersionValue <= this.MessageVersion
+ where a.MaxVersionValue >= this.MessageVersion
+ select a).FirstOrDefault();
+ if (partAttribute != null) {
+ MessagePart part = new MessagePart(member, partAttribute);
+ if (this.mapping.ContainsKey(part.Name)) {
+ Logger.Messaging.WarnFormat(
+ "Message type {0} has more than one message part named {1}. Inherited members will be hidden.",
+ this.MessageType.Name,
+ part.Name);
+ } else {
+ this.mapping.Add(part.Name, part);
+ }
+ }
+ }
+ }
+ currentType = currentType.BaseType;
+ } while (currentType != null);
+
+ BindingFlags flags = BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public;
+ this.Constructors = this.MessageType.GetConstructors(flags);
+ }
+
+#if CONTRACTS_FULL
+ /// <summary>
+ /// Describes traits of this class that are always true.
+ /// </summary>
+ [ContractInvariantMethod]
+ private void Invariant() {
+ Contract.Invariant(this.MessageType != null);
+ Contract.Invariant(this.MessageVersion != null);
+ Contract.Invariant(this.Constructors != null);
}
+#endif
}
}
diff --git a/src/DotNetOpenAuth/Messaging/Reflection/MessageDescriptionCollection.cs b/src/DotNetOpenAuth/Messaging/Reflection/MessageDescriptionCollection.cs
index ff8b74b..8911960 100644
--- a/src/DotNetOpenAuth/Messaging/Reflection/MessageDescriptionCollection.cs
+++ b/src/DotNetOpenAuth/Messaging/Reflection/MessageDescriptionCollection.cs
@@ -14,7 +14,7 @@ namespace DotNetOpenAuth.Messaging.Reflection {
/// A cache of <see cref="MessageDescription"/> instances.
/// </summary>
[ContractVerification(true)]
- internal class MessageDescriptionCollection {
+ internal class MessageDescriptionCollection : IEnumerable<MessageDescription> {
/// <summary>
/// A dictionary of reflected message types and the generated reflection information.
/// </summary>
@@ -23,9 +23,37 @@ namespace DotNetOpenAuth.Messaging.Reflection {
/// <summary>
/// Initializes a new instance of the <see cref="MessageDescriptionCollection"/> class.
/// </summary>
- public MessageDescriptionCollection() {
+ internal MessageDescriptionCollection() {
}
+ #region IEnumerable<MessageDescription> Members
+
+ /// <summary>
+ /// Returns an enumerator that iterates through a collection.
+ /// </summary>
+ /// <returns>
+ /// An <see cref="T:System.Collections.IEnumerator"/> object that can be used to iterate through the collection.
+ /// </returns>
+ public IEnumerator<MessageDescription> GetEnumerator() {
+ return this.reflectedMessageTypes.Values.GetEnumerator();
+ }
+
+ #endregion
+
+ #region IEnumerable Members
+
+ /// <summary>
+ /// Returns an enumerator that iterates through a collection.
+ /// </summary>
+ /// <returns>
+ /// An <see cref="T:System.Collections.IEnumerator"/> object that can be used to iterate through the collection.
+ /// </returns>
+ System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() {
+ return this.reflectedMessageTypes.Values.GetEnumerator();
+ }
+
+ #endregion
+
/// <summary>
/// Gets a <see cref="MessageDescription"/> instance prepared for the
/// given message type.
diff --git a/src/DotNetOpenAuth/Messaging/StandardMessageFactory.cs b/src/DotNetOpenAuth/Messaging/StandardMessageFactory.cs
new file mode 100644
index 0000000..670d750
--- /dev/null
+++ b/src/DotNetOpenAuth/Messaging/StandardMessageFactory.cs
@@ -0,0 +1,228 @@
+//-----------------------------------------------------------------------
+// <copyright file="StandardMessageFactory.cs" company="Andrew Arnott">
+// Copyright (c) Andrew Arnott. All rights reserved.
+// </copyright>
+//-----------------------------------------------------------------------
+
+namespace DotNetOpenAuth.Messaging {
+ using System;
+ using System.Collections.Generic;
+ using System.Diagnostics.Contracts;
+ using System.Linq;
+ using System.Reflection;
+ using System.Text;
+ using DotNetOpenAuth.Messaging.Reflection;
+
+ /// <summary>
+ /// A message factory that automatically selects the message type based on the incoming data.
+ /// </summary>
+ internal class StandardMessageFactory : IMessageFactory {
+ /// <summary>
+ /// The request message types and their constructors to use for instantiating the messages.
+ /// </summary>
+ private readonly Dictionary<MessageDescription, ConstructorInfo> requestMessageTypes = new Dictionary<MessageDescription, ConstructorInfo>();
+
+ /// <summary>
+ /// The response message types and their constructors to use for instantiating the messages.
+ /// </summary>
+ /// <value>
+ /// The value is a dictionary, whose key is the type of the constructor's lone parameter.
+ /// </value>
+ private readonly Dictionary<MessageDescription, Dictionary<Type, ConstructorInfo>> responseMessageTypes = new Dictionary<MessageDescription, Dictionary<Type, ConstructorInfo>>();
+
+ /// <summary>
+ /// Initializes a new instance of the <see cref="StandardMessageFactory"/> class.
+ /// </summary>
+ internal StandardMessageFactory() {
+ }
+
+ /// <summary>
+ /// Adds message types to the set that this factory can create.
+ /// </summary>
+ /// <param name="messageTypes">The message types that this factory may instantiate.</param>
+ public virtual void AddMessageTypes(IEnumerable<MessageDescription> messageTypes) {
+ Contract.Requires<ArgumentNullException>(messageTypes != null);
+ Contract.Requires<ArgumentException>(messageTypes.All(msg => msg != null));
+
+ var unsupportedMessageTypes = new List<MessageDescription>(0);
+ foreach (MessageDescription messageDescription in messageTypes) {
+ bool supportedMessageType = false;
+
+ // First see whether this message fits the recognized pattern for request messages.
+ if (typeof(IDirectedProtocolMessage).IsAssignableFrom(messageDescription.MessageType)) {
+ foreach (ConstructorInfo ctor in messageDescription.Constructors) {
+ ParameterInfo[] parameters = ctor.GetParameters();
+ if (parameters.Length == 2 && parameters[0].ParameterType == typeof(Uri) && parameters[1].ParameterType == typeof(Version)) {
+ supportedMessageType = true;
+ this.requestMessageTypes.Add(messageDescription, ctor);
+ break;
+ }
+ }
+ }
+
+ // Also see if this message fits the recognized pattern for response messages.
+ if (typeof(IDirectResponseProtocolMessage).IsAssignableFrom(messageDescription.MessageType)) {
+ var responseCtors = new Dictionary<Type, ConstructorInfo>(messageDescription.Constructors.Length);
+ foreach (ConstructorInfo ctor in messageDescription.Constructors) {
+ ParameterInfo[] parameters = ctor.GetParameters();
+ if (parameters.Length == 1 && typeof(IDirectedProtocolMessage).IsAssignableFrom(parameters[0].ParameterType)) {
+ responseCtors.Add(parameters[0].ParameterType, ctor);
+ }
+ }
+
+ if (responseCtors.Count > 0) {
+ supportedMessageType = true;
+ this.responseMessageTypes.Add(messageDescription, responseCtors);
+ }
+ }
+
+ if (!supportedMessageType) {
+ unsupportedMessageTypes.Add(messageDescription);
+ }
+ }
+
+ ErrorUtilities.VerifySupported(
+ !unsupportedMessageTypes.Any(),
+ MessagingStrings.StandardMessageFactoryUnsupportedMessageType,
+ unsupportedMessageTypes.ToStringDeferred());
+ }
+
+ #region IMessageFactory Members
+
+ /// <summary>
+ /// Analyzes an incoming request message payload to discover what kind of
+ /// message is embedded in it and returns the type, or null if no match is found.
+ /// </summary>
+ /// <param name="recipient">The intended or actual recipient of the request message.</param>
+ /// <param name="fields">The name/value pairs that make up the message payload.</param>
+ /// <returns>
+ /// A newly instantiated <see cref="IProtocolMessage"/>-derived object that this message can
+ /// deserialize to. Null if the request isn't recognized as a valid protocol message.
+ /// </returns>
+ public virtual IDirectedProtocolMessage GetNewRequestMessage(MessageReceivingEndpoint recipient, IDictionary<string, string> fields) {
+ MessageDescription matchingType = this.GetMessageDescription(recipient, fields);
+ if (matchingType != null) {
+ return this.InstantiateAsRequest(matchingType, recipient);
+ } else {
+ return null;
+ }
+ }
+
+ /// <summary>
+ /// Analyzes an incoming request message payload to discover what kind of
+ /// message is embedded in it and returns the type, or null if no match is found.
+ /// </summary>
+ /// <param name="request">The message that was sent as a request that resulted in the response.</param>
+ /// <param name="fields">The name/value pairs that make up the message payload.</param>
+ /// <returns>
+ /// A newly instantiated <see cref="IProtocolMessage"/>-derived object that this message can
+ /// deserialize to. Null if the request isn't recognized as a valid protocol message.
+ /// </returns>
+ public virtual IDirectResponseProtocolMessage GetNewResponseMessage(IDirectedProtocolMessage request, IDictionary<string, string> fields) {
+ MessageDescription matchingType = this.GetMessageDescription(request, fields);
+ if (matchingType != null) {
+ return this.InstantiateAsResponse(matchingType, request);
+ } else {
+ return null;
+ }
+ }
+
+ #endregion
+
+ /// <summary>
+ /// Gets the message type that best fits the given incoming request data.
+ /// </summary>
+ /// <param name="recipient">The recipient of the incoming data. Typically not used, but included just in case.</param>
+ /// <param name="fields">The data of the incoming message.</param>
+ /// <returns>
+ /// The message type that matches the incoming data; or <c>null</c> if no match.
+ /// </returns>
+ /// <exception cref="ProtocolException">May be thrown if the incoming data is ambiguous.</exception>
+ protected virtual MessageDescription GetMessageDescription(MessageReceivingEndpoint recipient, IDictionary<string, string> fields) {
+ Contract.Requires<ArgumentNullException>(recipient != null);
+ Contract.Requires<ArgumentNullException>(fields != null);
+
+ var basicMatches = this.requestMessageTypes.Keys.Where(message => message.CheckMessagePartsPassBasicValidation(fields));
+ var match = basicMatches.FirstOrDefault();
+ if (match != null) {
+ if (Logger.Messaging.IsDebugEnabled && basicMatches.Count() > 1) {
+ Logger.Messaging.DebugFormat(
+ "Multiple message types seemed to fit the incoming data: {0}",
+ basicMatches.ToStringDeferred());
+ }
+
+ return match;
+ } else {
+ // No message type matches the incoming data.
+ return null;
+ }
+ }
+
+ /// <summary>
+ /// Gets the message type that best fits the given incoming direct response data.
+ /// </summary>
+ /// <param name="request">The request message that prompted the response data.</param>
+ /// <param name="fields">The data of the incoming message.</param>
+ /// <returns>
+ /// The message type that matches the incoming data; or <c>null</c> if no match.
+ /// </returns>
+ /// <exception cref="ProtocolException">May be thrown if the incoming data is ambiguous.</exception>
+ protected virtual MessageDescription GetMessageDescription(IDirectedProtocolMessage request, IDictionary<string, string> fields) {
+ var basicMatches = this.responseMessageTypes.Keys.Where(message => message.CheckMessagePartsPassBasicValidation(fields)).CacheGeneratedResults();
+ var match = basicMatches.FirstOrDefault();
+ if (match != null) {
+ if (Logger.Messaging.IsDebugEnabled && basicMatches.Count() > 1) {
+ Logger.Messaging.DebugFormat(
+ "Multiple message types seemed to fit the incoming data: {0}",
+ basicMatches.ToStringDeferred());
+ }
+
+ return match;
+ } else {
+ // No message type matches the incoming data.
+ return null;
+ }
+ }
+
+ /// <summary>
+ /// Instantiates the given request message type.
+ /// </summary>
+ /// <param name="messageDescription">The message description.</param>
+ /// <param name="recipient">The recipient.</param>
+ /// <returns>The instantiated message. Never null.</returns>
+ protected virtual IDirectedProtocolMessage InstantiateAsRequest(MessageDescription messageDescription, MessageReceivingEndpoint recipient) {
+ Contract.Requires<ArgumentNullException>(messageDescription != null);
+ Contract.Requires<ArgumentNullException>(recipient != null);
+ Contract.Ensures(Contract.Result<IDirectedProtocolMessage>() != null);
+
+ ConstructorInfo ctor = this.requestMessageTypes[messageDescription];
+ return (IDirectedProtocolMessage)ctor.Invoke(new object[] { recipient.Location, messageDescription.MessageVersion });
+ }
+
+ /// <summary>
+ /// Instantiates the given request message type.
+ /// </summary>
+ /// <param name="messageDescription">The message description.</param>
+ /// <param name="request">The request that resulted in this response.</param>
+ /// <returns>The instantiated message. Never null.</returns>
+ protected virtual IDirectResponseProtocolMessage InstantiateAsResponse(MessageDescription messageDescription, IDirectedProtocolMessage request) {
+ Contract.Requires<ArgumentNullException>(messageDescription != null);
+ Contract.Requires<ArgumentNullException>(request != null);
+ Contract.Ensures(Contract.Result<IDirectResponseProtocolMessage>() != null);
+
+ Type requestType = request.GetType();
+ var ctors = this.responseMessageTypes[messageDescription].Where(pair => pair.Key.IsAssignableFrom(requestType));
+ ConstructorInfo ctor = null;
+ try {
+ ctor = ctors.Single().Value;
+ } catch (InvalidOperationException) {
+ if (ctors.Any()) {
+ ErrorUtilities.ThrowInternal("More than one matching constructor for request type " + requestType.Name + " and response type " + messageDescription.MessageType.Name);
+ } else {
+ ErrorUtilities.ThrowInternal("Unexpected request message type " + requestType.FullName + " for response type " + messageDescription.MessageType.Name);
+ }
+ }
+ return (IDirectResponseProtocolMessage)ctor.Invoke(new object[] { request });
+ }
+ }
+}