summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/DotNetOAuth.Test/DotNetOAuth.Test.csproj8
-rw-r--r--src/DotNetOAuth.Test/Messaging/ChannelTests.cs33
-rw-r--r--src/DotNetOAuth.Test/Messaging/CollectionAssert.cs19
-rw-r--r--src/DotNetOAuth.Test/Messaging/DictionaryXmlReaderTests.cs33
-rw-r--r--src/DotNetOAuth.Test/Messaging/MessageSerializerTests.cs73
-rw-r--r--src/DotNetOAuth.Test/Messaging/MessagingTestBase.cs100
-rw-r--r--src/DotNetOAuth.Test/Messaging/Reflection/MessageDescriptionTests.cs21
-rw-r--r--src/DotNetOAuth.Test/Messaging/Reflection/MessageDictionaryTests.cs347
-rw-r--r--src/DotNetOAuth.Test/Messaging/Reflection/MessagePartTests.cs99
-rw-r--r--src/DotNetOAuth.Test/Messaging/Reflection/ValueMappingTests.cs21
-rw-r--r--src/DotNetOAuth.Test/Mocks/TestBaseMessage.cs16
-rw-r--r--src/DotNetOAuth.Test/Mocks/TestDerivedMessage.cs5
-rw-r--r--src/DotNetOAuth.Test/Mocks/TestDirectedMessage.cs37
-rw-r--r--src/DotNetOAuth.Test/Mocks/TestExpiringMessage.cs5
-rw-r--r--src/DotNetOAuth.Test/Mocks/TestMessage.cs20
-rw-r--r--src/DotNetOAuth.Test/Mocks/TestReplayProtectedMessage.cs5
-rw-r--r--src/DotNetOAuth.Test/Mocks/TestSignedDirectedMessage.cs5
-rw-r--r--src/DotNetOAuth.Test/OAuthChannelTests.cs7
-rw-r--r--src/DotNetOAuth/DotNetOAuth.csproj10
-rw-r--r--src/DotNetOAuth/Messaging/Channel.cs21
-rw-r--r--src/DotNetOAuth/Messaging/DataContractMemberComparer.cs103
-rw-r--r--src/DotNetOAuth/Messaging/DictionaryXmlReader.cs92
-rw-r--r--src/DotNetOAuth/Messaging/DictionaryXmlWriter.cs273
-rw-r--r--src/DotNetOAuth/Messaging/IProtocolMessage.cs9
-rw-r--r--src/DotNetOAuth/Messaging/MessageSerializer.cs121
-rw-r--r--src/DotNetOAuth/Messaging/MessagingStrings.Designer.cs27
-rw-r--r--src/DotNetOAuth/Messaging/MessagingStrings.resx9
-rw-r--r--src/DotNetOAuth/Messaging/ProtocolException.cs13
-rw-r--r--src/DotNetOAuth/Messaging/Reflection/MessageDescription.cs91
-rw-r--r--src/DotNetOAuth/Messaging/Reflection/MessageDictionary.cs206
-rw-r--r--src/DotNetOAuth/Messaging/Reflection/MessagePart.cs148
-rw-r--r--src/DotNetOAuth/Messaging/Reflection/MessagePartAttribute.cs33
-rw-r--r--src/DotNetOAuth/Messaging/Reflection/ValueMapping.cs27
-rw-r--r--src/DotNetOAuth/StandardWebRequestHandler.cs3
34 files changed, 1289 insertions, 751 deletions
diff --git a/src/DotNetOAuth.Test/DotNetOAuth.Test.csproj b/src/DotNetOAuth.Test/DotNetOAuth.Test.csproj
index 09b5974..dfa7fef 100644
--- a/src/DotNetOAuth.Test/DotNetOAuth.Test.csproj
+++ b/src/DotNetOAuth.Test/DotNetOAuth.Test.csproj
@@ -58,13 +58,18 @@
</Reference>
</ItemGroup>
<ItemGroup>
+ <Compile Include="Messaging\CollectionAssert.cs" />
+ <Compile Include="Messaging\MessageSerializerTests.cs" />
+ <Compile Include="Messaging\Reflection\MessageDescriptionTests.cs" />
+ <Compile Include="Messaging\Reflection\MessageDictionaryTests.cs" />
<Compile Include="Messaging\MessagingTestBase.cs" />
<Compile Include="Messaging\MessagingUtilitiesTests.cs" />
<Compile Include="Messaging\ChannelTests.cs" />
- <Compile Include="Messaging\DictionaryXmlReaderTests.cs" />
<Compile Include="Messaging\HttpRequestInfoTests.cs" />
<Compile Include="Messaging\ProtocolExceptionTests.cs" />
<Compile Include="Messaging\Bindings\StandardExpirationBindingElementTests.cs" />
+ <Compile Include="Messaging\Reflection\MessagePartTests.cs" />
+ <Compile Include="Messaging\Reflection\ValueMappingTests.cs" />
<Compile Include="Mocks\MockTransformationBindingElement.cs" />
<Compile Include="Mocks\MockReplayProtectionBindingElement.cs" />
<Compile Include="Mocks\TestBaseMessage.cs" />
@@ -77,7 +82,6 @@
<Compile Include="Mocks\MockSigningBindingElement.cs" />
<Compile Include="Mocks\TestWebRequestHandler.cs" />
<Compile Include="OAuthChannelTests.cs" />
- <Compile Include="Messaging\MessageSerializerTests.cs" />
<Compile Include="Mocks\TestChannel.cs" />
<Compile Include="Mocks\TestMessage.cs" />
<Compile Include="Mocks\TestMessageTypeProvider.cs" />
diff --git a/src/DotNetOAuth.Test/Messaging/ChannelTests.cs b/src/DotNetOAuth.Test/Messaging/ChannelTests.cs
index 5ffd1d6..e8c4514 100644
--- a/src/DotNetOAuth.Test/Messaging/ChannelTests.cs
+++ b/src/DotNetOAuth.Test/Messaging/ChannelTests.cs
@@ -14,6 +14,7 @@ namespace DotNetOAuth.Test.Messaging {
using DotNetOAuth.Messaging.Bindings;
using DotNetOAuth.Test.Mocks;
using Microsoft.VisualStudio.TestTools.UnitTesting;
+ using System.Xml;
[TestClass]
public class ChannelTests : MessagingTestBase {
@@ -64,19 +65,21 @@ namespace DotNetOAuth.Test.Messaging {
[TestMethod]
public void SendIndirectMessage301Get() {
- IProtocolMessage message = new TestDirectedMessage(MessageTransport.Indirect) {
- Age = 15,
- Name = "Andrew",
- Location = new Uri("http://host/path"),
- Recipient = new Uri("http://provider/path"),
- };
+ TestDirectedMessage message = new TestDirectedMessage(MessageTransport.Indirect);
+ GetStandardTestMessage(FieldFill.CompleteBeforeBindings, message);
+ message.Recipient = new Uri("http://provider/path");
+ var expected = GetStandardTestFields(FieldFill.CompleteBeforeBindings);
+
this.Channel.Send(message);
Response response = this.Channel.DequeueIndirectOrResponseMessage();
Assert.AreEqual(HttpStatusCode.Redirect, response.Status);
StringAssert.StartsWith(response.Headers[HttpResponseHeader.Location], "http://provider/path");
- StringAssert.Contains(response.Headers[HttpResponseHeader.Location], "age=15");
- StringAssert.Contains(response.Headers[HttpResponseHeader.Location], "Name=Andrew");
- StringAssert.Contains(response.Headers[HttpResponseHeader.Location], "Location=http%3a%2f%2fhost%2fpath");
+ foreach (var pair in expected) {
+ string key = HttpUtility.UrlEncode(pair.Key);
+ string value = HttpUtility.UrlEncode(pair.Value);
+ string substring = string.Format("{0}={1}", key, value);
+ StringAssert.Contains(response.Headers[HttpResponseHeader.Location], substring);
+ }
}
[TestMethod, ExpectedException(typeof(ArgumentNullException))]
@@ -198,12 +201,14 @@ namespace DotNetOAuth.Test.Messaging {
[TestMethod]
public void ReadFromRequestWithContext() {
// TODO: make this a request with a message in it.
- HttpRequest request = new HttpRequest("somefile", "http://someurl", "age=15");
+ var fields = GetStandardTestFields(FieldFill.AllRequired);
+ TestMessage expectedMessage = GetStandardTestMessage(FieldFill.AllRequired);
+ HttpRequest request = new HttpRequest("somefile", "http://someurl", MessagingUtilities.CreateQueryString(fields));
HttpContext.Current = new HttpContext(request, new HttpResponse(new StringWriter()));
IProtocolMessage message = this.Channel.ReadFromRequest();
Assert.IsNotNull(message);
Assert.IsInstanceOfType(message, typeof(TestMessage));
- Assert.AreEqual(15, ((TestMessage)message).Age);
+ Assert.AreEqual(expectedMessage.Age, ((TestMessage)message).Age);
}
[TestMethod, ExpectedException(typeof(InvalidOperationException))]
@@ -298,5 +303,11 @@ namespace DotNetOAuth.Test.Messaging {
this.Channel = CreateChannel(MessageProtection.None, MessageProtection.TamperProtection);
this.ParameterizedReceiveProtectedTest(DateTime.Now, false);
}
+
+ [TestMethod, ExpectedException(typeof(ProtocolException))]
+ public void IncomingMessageMissingRequiredParameters() {
+ var fields = GetStandardTestFields(FieldFill.IdentifiableButNotAllRequired);
+ this.Channel.ReadFromRequest(CreateHttpRequestInfo("GET", fields));
+ }
}
}
diff --git a/src/DotNetOAuth.Test/Messaging/CollectionAssert.cs b/src/DotNetOAuth.Test/Messaging/CollectionAssert.cs
new file mode 100644
index 0000000..b9f3da5
--- /dev/null
+++ b/src/DotNetOAuth.Test/Messaging/CollectionAssert.cs
@@ -0,0 +1,19 @@
+//-----------------------------------------------------------------------
+// <copyright file="CollectionAssert.cs" company="Andrew Arnott">
+// Copyright (c) Andrew Arnott. All rights reserved.
+// </copyright>
+//-----------------------------------------------------------------------
+
+namespace DotNetOAuth.Test.Messaging {
+ using System.Collections;
+ using System.Collections.Generic;
+ using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+ internal class CollectionAssert<T> {
+ internal static void AreEquivalent(ICollection<T> expected, ICollection<T> actual) {
+ ICollection expectedNonGeneric = new List<T>(expected);
+ ICollection actualNonGeneric = new List<T>(actual);
+ CollectionAssert.AreEquivalent(expectedNonGeneric, actualNonGeneric);
+ }
+ }
+}
diff --git a/src/DotNetOAuth.Test/Messaging/DictionaryXmlReaderTests.cs b/src/DotNetOAuth.Test/Messaging/DictionaryXmlReaderTests.cs
deleted file mode 100644
index 54c47f0..0000000
--- a/src/DotNetOAuth.Test/Messaging/DictionaryXmlReaderTests.cs
+++ /dev/null
@@ -1,33 +0,0 @@
-//-----------------------------------------------------------------------
-// <copyright file="DictionaryXmlReaderTests.cs" company="Andrew Arnott">
-// Copyright (c) Andrew Arnott. All rights reserved.
-// </copyright>
-//-----------------------------------------------------------------------
-
-namespace DotNetOAuth.Test.Messaging {
- using System;
- using System.Collections.Generic;
- using System.Xml.Linq;
- using DotNetOAuth.Messaging;
- using Microsoft.VisualStudio.TestTools.UnitTesting;
-
- [TestClass]
- public class DictionaryXmlReaderTests : TestBase {
- [TestMethod, ExpectedException(typeof(ArgumentNullException))]
- public void CreateWithNullRootElement() {
- IComparer<string> fieldSorter = new DataContractMemberComparer(typeof(Mocks.TestMessage));
- DictionaryXmlReader.Create(null, fieldSorter, new Dictionary<string, string>());
- }
-
- [TestMethod, ExpectedException(typeof(ArgumentNullException))]
- public void CreateWithNullDataContractType() {
- DictionaryXmlReader.Create(XName.Get("name", "ns"), null, new Dictionary<string, string>());
- }
-
- [TestMethod, ExpectedException(typeof(ArgumentNullException))]
- public void CreateWithNullFields() {
- IComparer<string> fieldSorter = new DataContractMemberComparer(typeof(Mocks.TestMessage));
- DictionaryXmlReader.Create(XName.Get("name", "ns"), fieldSorter, null);
- }
- }
-}
diff --git a/src/DotNetOAuth.Test/Messaging/MessageSerializerTests.cs b/src/DotNetOAuth.Test/Messaging/MessageSerializerTests.cs
index 162d456..a41753b 100644
--- a/src/DotNetOAuth.Test/Messaging/MessageSerializerTests.cs
+++ b/src/DotNetOAuth.Test/Messaging/MessageSerializerTests.cs
@@ -9,30 +9,19 @@ namespace DotNetOAuth.Test.Messaging {
using System.Collections.Generic;
using DotNetOAuth.Messaging;
using Microsoft.VisualStudio.TestTools.UnitTesting;
+ using System.Xml;
/// <summary>
/// Tests for the <see cref="MessageSerializer"/> class.
/// </summary>
[TestClass()]
- public class MessageSerializerTests : TestBase {
+ public class MessageSerializerTests : MessagingTestBase {
[TestMethod, ExpectedException(typeof(ArgumentNullException))]
public void SerializeNull() {
var serializer = MessageSerializer.Get(typeof(Mocks.TestMessage));
serializer.Serialize(null);
}
- [TestMethod, ExpectedException(typeof(ArgumentNullException))]
- public void SerializeNullFields() {
- var serializer = MessageSerializer.Get(typeof(Mocks.TestMessage));
- serializer.Serialize(null, new Mocks.TestMessage());
- }
-
- [TestMethod, ExpectedException(typeof(ArgumentNullException))]
- public void SerializeNullMessage() {
- var serializer = MessageSerializer.Get(typeof(Mocks.TestMessage));
- serializer.Serialize(new Dictionary<string, string>(), null);
- }
-
[TestMethod, ExpectedException(typeof(ArgumentException))]
public void GetInvalidMessageType() {
MessageSerializer.Get(typeof(string));
@@ -43,29 +32,11 @@ namespace DotNetOAuth.Test.Messaging {
MessageSerializer.Get(null);
}
- [TestMethod]
- public void GetReturnsSameSerializerTwice() {
- Assert.AreSame(MessageSerializer.Get(typeof(Mocks.TestMessage)), MessageSerializer.Get(typeof(Mocks.TestMessage)));
- }
-
- [TestMethod, ExpectedException(typeof(ProtocolException))]
- public void SerializeInvalidMessage() {
- var serializer = MessageSerializer.Get(typeof(Mocks.TestMessage));
- Dictionary<string, string> fields = new Dictionary<string, string>(StringComparer.Ordinal);
- Mocks.TestMessage message = new Mocks.TestMessage();
- message.EmptyMember = "invalidvalue";
- serializer.Serialize(message);
- }
-
[TestMethod()]
public void SerializeTest() {
var serializer = MessageSerializer.Get(typeof(Mocks.TestMessage));
- var message = new Mocks.TestMessage {
- Age = 15,
- Name = "Andrew",
- Location = new Uri("http://localhost"),
- Timestamp = DateTime.Parse("1/1/1990"),
- };
+ var message = GetStandardTestMessage(FieldFill.CompleteBeforeBindings);
+ var expected = GetStandardTestFields(FieldFill.CompleteBeforeBindings);
IDictionary<string, string> actual = serializer.Serialize(message);
Assert.AreEqual(4, actual.Count);
@@ -74,26 +45,13 @@ namespace DotNetOAuth.Test.Messaging {
Assert.IsTrue(actual.ContainsKey("age"));
// Test contents of dictionary
- Assert.AreEqual("15", actual["age"]);
- Assert.AreEqual("Andrew", actual["Name"]);
- Assert.AreEqual("http://localhost/", actual["Location"]);
- Assert.AreEqual("1990-01-01T00:00:00", actual["Timestamp"]);
+ Assert.AreEqual(expected["age"], actual["age"]);
+ Assert.AreEqual(expected["Name"], actual["Name"]);
+ Assert.AreEqual(expected["Location"], actual["Location"]);
+ Assert.AreEqual(expected["Timestamp"], actual["Timestamp"]);
Assert.IsFalse(actual.ContainsKey("EmptyMember"));
}
- [TestMethod]
- public void SerializeToExistingDictionary() {
- var serializer = MessageSerializer.Get(typeof(Mocks.TestMessage));
- var message = new Mocks.TestMessage { Age = 15, Name = "Andrew" };
- var fields = new Dictionary<string, string>();
- fields["someExtraField"] = "someValue";
- serializer.Serialize(fields, message);
- Assert.AreEqual(4, fields.Count);
- Assert.AreEqual("15", fields["age"]);
- Assert.AreEqual("Andrew", fields["Name"]);
- Assert.AreEqual("someValue", fields["someExtraField"]);
- }
-
[TestMethod, ExpectedException(typeof(ArgumentNullException))]
public void DeserializeNull() {
var serializer = MessageSerializer.Get(typeof(Mocks.TestMessage));
@@ -115,7 +73,7 @@ namespace DotNetOAuth.Test.Messaging {
}
/// <summary>
- /// This tests deserialization of a message that is comprised of [DataMember]'s
+ /// This tests deserialization of a message that is comprised of [MessagePart]'s
/// that are defined in multiple places in the inheritance tree.
/// </summary>
/// <remarks>
@@ -151,6 +109,7 @@ namespace DotNetOAuth.Test.Messaging {
Dictionary<string, string> fields = new Dictionary<string, string>(StringComparer.Ordinal);
fields["age"] = "15";
fields["Name"] = "Andrew";
+ fields["Timestamp"] = XmlConvert.ToString(DateTime.UtcNow, XmlDateTimeSerializationMode.Utc);
// Add some field that is not recognized by the class. This simulates a querystring with
// more parameters than are actually interesting to the protocol message.
fields["someExtraField"] = "asdf";
@@ -161,18 +120,10 @@ namespace DotNetOAuth.Test.Messaging {
}
[TestMethod, ExpectedException(typeof(ProtocolException))]
- public void DeserializeEmpty() {
- var serializer = MessageSerializer.Get(typeof(Mocks.TestMessage));
- var fields = new Dictionary<string, string>(StringComparer.Ordinal);
- serializer.Deserialize(fields);
- }
-
- [TestMethod, ExpectedException(typeof(ProtocolException))]
public void DeserializeInvalidMessage() {
var serializer = MessageSerializer.Get(typeof(Mocks.TestMessage));
- Dictionary<string, string> fields = new Dictionary<string, string>(StringComparer.Ordinal);
- // Set an disallowed value.
- fields["age"] = "-1";
+ var fields = GetStandardTestFields(FieldFill.AllRequired);
+ fields["age"] = "-1"; // Set an disallowed value.
serializer.Deserialize(fields);
}
}
diff --git a/src/DotNetOAuth.Test/Messaging/MessagingTestBase.cs b/src/DotNetOAuth.Test/Messaging/MessagingTestBase.cs
index a21b000..34cf1df 100644
--- a/src/DotNetOAuth.Test/Messaging/MessagingTestBase.cs
+++ b/src/DotNetOAuth.Test/Messaging/MessagingTestBase.cs
@@ -85,29 +85,89 @@ namespace DotNetOAuth.Test {
return new TestChannel(typeProvider, bindingElements.ToArray());
}
+ internal enum FieldFill {
+ /// <summary>
+ /// An empty dictionary is returned.
+ /// </summary>
+ None,
+
+ /// <summary>
+ /// Only enough fields for the <see cref="TestMessageTypeProvider"/>
+ /// to identify the message are included.
+ /// </summary>
+ IdentifiableButNotAllRequired,
+
+ /// <summary>
+ /// All fields marked as required are included.
+ /// </summary>
+ AllRequired,
+
+ /// <summary>
+ /// All user-fillable fields in the message, leaving out those whose
+ /// values are to be set by channel binding elements.
+ /// </summary>
+ CompleteBeforeBindings,
+ }
+
+ internal static IDictionary<string, string> GetStandardTestFields(FieldFill fill) {
+ TestMessage expectedMessage = GetStandardTestMessage(fill);
+
+ var fields = new Dictionary<string, string>();
+ if (fill >= FieldFill.IdentifiableButNotAllRequired) {
+ fields.Add("age", expectedMessage.Age.ToString());
+ }
+ if (fill >= FieldFill.AllRequired) {
+ fields.Add("Timestamp", XmlConvert.ToString(expectedMessage.Timestamp, XmlDateTimeSerializationMode.Utc));
+ }
+ if (fill >= FieldFill.CompleteBeforeBindings) {
+ fields.Add("Name", expectedMessage.Name);
+ fields.Add("Location", expectedMessage.Location.AbsoluteUri);
+ }
+
+ return fields;
+ }
+
+ internal static TestMessage GetStandardTestMessage(FieldFill fill) {
+ TestMessage message = new TestMessage();
+ GetStandardTestMessage(fill, message);
+ return message;
+ }
+
+ internal static void GetStandardTestMessage(FieldFill fill, TestMessage message) {
+ if (message == null) {
+ throw new ArgumentNullException("message");
+ }
+
+ if (fill >= FieldFill.IdentifiableButNotAllRequired) {
+ message.Age = 15;
+ }
+ if (fill >= FieldFill.AllRequired) {
+ message.Timestamp = DateTime.SpecifyKind(DateTime.Parse("9/19/2008 8 AM"), DateTimeKind.Utc);
+ }
+ if (fill >= FieldFill.CompleteBeforeBindings) {
+ message.Name = "Andrew";
+ message.Location = new Uri("http://localtest/path");
+ }
+ }
+
internal void ParameterizedReceiveTest(string method) {
- var fields = new Dictionary<string, string> {
- { "age", "15" },
- { "Name", "Andrew" },
- { "Location", "http://hostb/pathB" },
- };
+ var fields = GetStandardTestFields(FieldFill.CompleteBeforeBindings);
+ TestMessage expectedMessage = GetStandardTestMessage(FieldFill.CompleteBeforeBindings); ;
+
IProtocolMessage requestMessage = this.Channel.ReadFromRequest(CreateHttpRequestInfo(method, fields));
Assert.IsNotNull(requestMessage);
Assert.IsInstanceOfType(requestMessage, typeof(TestMessage));
- TestMessage testMessage = (TestMessage)requestMessage;
- Assert.AreEqual(15, testMessage.Age);
- Assert.AreEqual("Andrew", testMessage.Name);
- Assert.AreEqual("http://hostb/pathB", testMessage.Location.AbsoluteUri);
+ TestMessage actualMessage = (TestMessage)requestMessage;
+ Assert.AreEqual(expectedMessage.Age, actualMessage.Age);
+ Assert.AreEqual(expectedMessage.Name, actualMessage.Name);
+ Assert.AreEqual(expectedMessage.Location, actualMessage.Location);
}
internal void ParameterizedReceiveProtectedTest(DateTime? utcCreatedDate, bool invalidSignature) {
- var fields = new Dictionary<string, string> {
- { "age", "15" },
- { "Name", "Andrew" },
- { "Location", "http://hostb/pathB" },
- { "Signature", invalidSignature ? "badsig" : MockSigningBindingElement.MessageSignature },
- { "Nonce", "someNonce" },
- };
+ TestMessage expectedMessage = GetStandardTestMessage(FieldFill.CompleteBeforeBindings); ;
+ var fields = GetStandardTestFields(FieldFill.CompleteBeforeBindings);
+ fields.Add("Signature", invalidSignature ? "badsig" : MockSigningBindingElement.MessageSignature);
+ fields.Add("Nonce", "someNonce");
if (utcCreatedDate.HasValue) {
utcCreatedDate = DateTime.Parse(utcCreatedDate.Value.ToUniversalTime().ToString()); // round off the milliseconds so comparisons work later
fields.Add("created_on", XmlConvert.ToString(utcCreatedDate.Value, XmlDateTimeSerializationMode.Utc));
@@ -115,10 +175,10 @@ namespace DotNetOAuth.Test {
IProtocolMessage requestMessage = this.Channel.ReadFromRequest(CreateHttpRequestInfo("GET", fields));
Assert.IsNotNull(requestMessage);
Assert.IsInstanceOfType(requestMessage, typeof(TestSignedDirectedMessage));
- TestSignedDirectedMessage testMessage = (TestSignedDirectedMessage)requestMessage;
- Assert.AreEqual(15, testMessage.Age);
- Assert.AreEqual("Andrew", testMessage.Name);
- Assert.AreEqual("http://hostb/pathB", testMessage.Location.AbsoluteUri);
+ TestSignedDirectedMessage actualMessage = (TestSignedDirectedMessage)requestMessage;
+ Assert.AreEqual(expectedMessage.Age, actualMessage.Age);
+ Assert.AreEqual(expectedMessage.Name, actualMessage.Name);
+ Assert.AreEqual(expectedMessage.Location, actualMessage.Location);
if (utcCreatedDate.HasValue) {
IExpiringProtocolMessage expiringMessage = (IExpiringProtocolMessage)requestMessage;
Assert.AreEqual(utcCreatedDate.Value, expiringMessage.UtcCreationDate);
diff --git a/src/DotNetOAuth.Test/Messaging/Reflection/MessageDescriptionTests.cs b/src/DotNetOAuth.Test/Messaging/Reflection/MessageDescriptionTests.cs
new file mode 100644
index 0000000..04c1df8
--- /dev/null
+++ b/src/DotNetOAuth.Test/Messaging/Reflection/MessageDescriptionTests.cs
@@ -0,0 +1,21 @@
+namespace DotNetOAuth.Test.Messaging.Reflection {
+ using System;
+ using System.Collections.Generic;
+ using System.Linq;
+ using System.Text;
+ using Microsoft.VisualStudio.TestTools.UnitTesting;
+ using DotNetOAuth.Messaging.Reflection;
+
+ [TestClass]
+ public class MessageDescriptionTests : MessagingTestBase {
+ [TestMethod, ExpectedException(typeof(ArgumentNullException))]
+ public void GetNull() {
+ MessageDescription.Get(null);
+ }
+
+ [TestMethod, ExpectedException(typeof(ArgumentException))]
+ public void GetNonMessageType() {
+ MessageDescription.Get(typeof(string));
+ }
+ }
+}
diff --git a/src/DotNetOAuth.Test/Messaging/Reflection/MessageDictionaryTests.cs b/src/DotNetOAuth.Test/Messaging/Reflection/MessageDictionaryTests.cs
new file mode 100644
index 0000000..b1ae0b6
--- /dev/null
+++ b/src/DotNetOAuth.Test/Messaging/Reflection/MessageDictionaryTests.cs
@@ -0,0 +1,347 @@
+//-----------------------------------------------------------------------
+// <copyright file="MessageDictionaryTest.cs" company="Andrew Arnott">
+// Copyright (c) Andrew Arnott. All rights reserved.
+// </copyright>
+//-----------------------------------------------------------------------
+
+namespace DotNetOAuth.Test.Messaging.Reflection {
+ using System;
+ using System.Collections;
+ using System.Collections.Generic;
+ using System.Collections.ObjectModel;
+ using DotNetOAuth.Messaging;
+ using DotNetOAuth.Messaging.Reflection;
+ using Microsoft.VisualStudio.TestTools.UnitTesting;
+ using System.Xml;
+
+ [TestClass]
+ public class MessageDictionaryTests : MessagingTestBase {
+ private Mocks.TestMessage message;
+
+ [TestInitialize]
+ public override void SetUp() {
+ base.SetUp();
+
+ this.message = new Mocks.TestMessage();
+ }
+
+ [TestMethod, ExpectedException(typeof(ArgumentNullException))]
+ public void CtorNull() {
+ new MessageDictionary(null);
+ }
+
+ /// <summary>
+ /// A test for System.Collections.Generic.IDictionary&lt;System.String,System.String>.Values
+ /// </summary>
+ [TestMethod]
+ public void Values() {
+ IDictionary<string, string> target = new MessageDictionary(this.message);
+ Collection<string> expected = new Collection<string> {
+ this.message.Age.ToString(),
+ XmlConvert.ToString(DateTime.SpecifyKind(this.message.Timestamp, DateTimeKind.Utc), XmlDateTimeSerializationMode.Utc),
+ };
+ CollectionAssert<string>.AreEquivalent(expected, target.Values);
+
+ this.message.Age = 15;
+ this.message.Location = new Uri("http://localtest");
+ this.message.Name = "Andrew";
+ target["extra"] = "a";
+ expected = new Collection<string> {
+ this.message.Age.ToString(),
+ this.message.Location.AbsoluteUri,
+ this.message.Name,
+ XmlConvert.ToString(DateTime.SpecifyKind(this.message.Timestamp, DateTimeKind.Utc), XmlDateTimeSerializationMode.Utc),
+ "a",
+ };
+ CollectionAssert<string>.AreEquivalent(expected, target.Values);
+ }
+
+ /// <summary>
+ /// A test for System.Collections.Generic.IDictionary&lt;System.String,System.String>.Keys
+ /// </summary>
+ [TestMethod]
+ public void Keys() {
+ // We expect that non-nullable value type fields will automatically have keys
+ // in the dictionary for them.
+ IDictionary<string, string> target = new MessageDictionary(this.message);
+ Collection<string> expected = new Collection<string> {
+ "age",
+ "Timestamp",
+ };
+ CollectionAssert<string>.AreEquivalent(expected, target.Keys);
+
+ this.message.Name = "Andrew";
+ expected.Add("Name");
+ target["extraField"] = string.Empty;
+ expected.Add("extraField");
+ CollectionAssert<string>.AreEquivalent(expected, target.Keys);
+ }
+
+ /// <summary>
+ /// A test for System.Collections.Generic.IDictionary&lt;System.String,System.String>.Item
+ /// </summary>
+ [TestMethod]
+ public void Item() {
+ IDictionary<string, string> target = new MessageDictionary(this.message);
+
+ // Test setting of declared message properties.
+ this.message.Age = 15;
+ Assert.AreEqual("15", target["age"]);
+ target["age"] = "13";
+ Assert.AreEqual(13, this.message.Age);
+
+ // Test setting extra fields
+ target["extra"] = "fun";
+ Assert.AreEqual("fun", target["extra"]);
+ Assert.AreEqual("fun", ((IProtocolMessage)this.message).ExtraData["extra"]);
+
+ // Test clearing extra fields
+ target["extra"] = null;
+ Assert.IsFalse(target.ContainsKey("extra"));
+ }
+
+ /// <summary>
+ /// A test for System.Collections.Generic.ICollection&lt;System.Collections.Generic.KeyValuePair&lt;System.String,System.String&lt;&lt;.IsReadOnly
+ /// </summary>
+ [TestMethod]
+ public void IsReadOnly() {
+ ICollection<KeyValuePair<string, string>> target = new MessageDictionary(this.message);
+ Assert.IsFalse(target.IsReadOnly);
+ }
+
+ /// <summary>
+ /// A test for System.Collections.Generic.ICollection&lt;System.Collections.Generic.KeyValuePair&lt;System.String,System.String&lt;&lt;.Count
+ /// </summary>
+ [TestMethod]
+ public void Count() {
+ ICollection<KeyValuePair<string, string>> target = new MessageDictionary(this.message);
+ IDictionary<string, string> targetDictionary = (IDictionary<string, string>)target;
+ Assert.AreEqual(targetDictionary.Keys.Count, target.Count);
+ targetDictionary["extraField"] = "hi";
+ Assert.AreEqual(targetDictionary.Keys.Count, target.Count);
+ }
+
+ /// <summary>
+ /// A test for System.Collections.Generic.IEnumerable&lt;System.Collections.Generic.KeyValuePair&lt;System.String,System.String&lt;&lt;.GetEnumerator
+ /// </summary>
+ [TestMethod]
+ public void GetEnumerator() {
+ IEnumerable<KeyValuePair<string, string>> target = new MessageDictionary(this.message);
+ IDictionary<string, string> targetDictionary = (IDictionary<string, string>)target;
+ var keys = targetDictionary.Keys.GetEnumerator();
+ var values = targetDictionary.Values.GetEnumerator();
+ IEnumerator<KeyValuePair<string, string>> actual = target.GetEnumerator();
+
+ bool keysLast = true, valuesLast = true, actualLast = true;
+ while (true) {
+ keysLast = keys.MoveNext();
+ valuesLast = values.MoveNext();
+ actualLast = actual.MoveNext();
+ if (!keysLast || !valuesLast || !actualLast) {
+ break;
+ }
+
+ Assert.AreEqual(keys.Current, actual.Current.Key);
+ Assert.AreEqual(values.Current, actual.Current.Value);
+ }
+ Assert.IsTrue(keysLast == valuesLast && keysLast == actualLast);
+ }
+
+ [TestMethod]
+ public void GetEnumeratorUntyped() {
+ IEnumerable target = new MessageDictionary(this.message);
+ IDictionary<string, string> targetDictionary = (IDictionary<string, string>)target;
+ var keys = targetDictionary.Keys.GetEnumerator();
+ var values = targetDictionary.Values.GetEnumerator();
+ IEnumerator actual = target.GetEnumerator();
+
+ bool keysLast = true, valuesLast = true, actualLast = true;
+ while (true) {
+ keysLast = keys.MoveNext();
+ valuesLast = values.MoveNext();
+ actualLast = actual.MoveNext();
+ if (!keysLast || !valuesLast || !actualLast) {
+ break;
+ }
+
+ KeyValuePair<string, string> current = (KeyValuePair<string, string>)actual.Current;
+ Assert.AreEqual(keys.Current, current.Key);
+ Assert.AreEqual(values.Current, current.Value);
+ }
+ Assert.IsTrue(keysLast == valuesLast && keysLast == actualLast);
+ }
+
+ /// <summary>
+ /// A test for System.Collections.Generic.IDictionary&lt;System.String,System.String>.TryGetValue
+ /// </summary>
+ [TestMethod]
+ public void TryGetValue() {
+ IDictionary<string, string> target = new MessageDictionary(this.message);
+ this.message.Name = "andrew";
+ string name;
+ Assert.IsTrue(target.TryGetValue("Name", out name));
+ Assert.AreEqual(this.message.Name, name);
+
+ Assert.IsFalse(target.TryGetValue("name", out name));
+ Assert.IsNull(name);
+
+ target["extra"] = "value";
+ string extra;
+ Assert.IsTrue(target.TryGetValue("extra", out extra));
+ Assert.AreEqual("value", extra);
+ }
+
+ /// <summary>
+ /// A test for System.Collections.Generic.IDictionary&lt;System.String,System.String>.Remove
+ /// </summary>
+ [TestMethod]
+ public void RemoveTest1() {
+ IDictionary<string, string> target = new MessageDictionary(this.message);
+ this.message.Name = "andrew";
+ Assert.IsTrue(target.Remove("Name"));
+ Assert.IsNull(this.message.Name);
+ Assert.IsFalse(target.Remove("Name"));
+
+ Assert.IsFalse(target.Remove("extra"));
+ target["extra"] = "value";
+ Assert.IsTrue(target.Remove("extra"));
+ Assert.IsFalse(target.ContainsKey("extra"));
+ }
+
+ /// <summary>
+ /// A test for System.Collections.Generic.IDictionary&lt;System.String,System.String>.ContainsKey
+ /// </summary>
+ [TestMethod]
+ public void ContainsKey() {
+ IDictionary<string, string> target = new MessageDictionary(this.message);
+ Assert.IsTrue(target.ContainsKey("age"), "Value type declared element should have a key.");
+ Assert.IsFalse(target.ContainsKey("Name"), "Null declared element should NOT have a key.");
+
+ Assert.IsFalse(target.ContainsKey("extra"));
+ target["extra"] = "value";
+ Assert.IsTrue(target.ContainsKey("extra"));
+ }
+
+ /// <summary>
+ /// A test for System.Collections.Generic.IDictionary&lt;System.String,System.String&gt;.Add
+ /// </summary>
+ [TestMethod]
+ public void AddByKeyAndValue() {
+ IDictionary<string, string> target = new MessageDictionary(this.message);
+ target.Add("extra", "value");
+ Assert.IsTrue(target.Contains(new KeyValuePair<string, string>("extra", "value")));
+ target.Add("Name", "Andrew");
+ Assert.AreEqual("Andrew", this.message.Name);
+ }
+
+ [TestMethod, ExpectedException(typeof(ArgumentNullException))]
+ public void AddNullValue() {
+ IDictionary<string, string> target = new MessageDictionary(this.message);
+ target.Add("extra", null);
+ }
+
+ /// <summary>
+ /// A test for System.Collections.Generic.ICollection&lt;System.Collections.Generic.KeyValuePair&lt;System.String,System.String&lt;&lt;.Add
+ /// </summary>
+ [TestMethod]
+ public void AddByKeyValuePair() {
+ IDictionary<string, string> target = new MessageDictionary(this.message);
+ target.Add(new KeyValuePair<string, string>("extra", "value"));
+ Assert.IsTrue(target.Contains(new KeyValuePair<string, string>("extra", "value")));
+ }
+
+ [TestMethod, ExpectedException(typeof(ArgumentException))]
+ public void AddExtraFieldThatAlreadyExists() {
+ IDictionary<string, string> target = new MessageDictionary(this.message);
+ target.Add("extra", "value");
+ target.Add("extra", "value");
+ }
+
+ [TestMethod, ExpectedException(typeof(ArgumentException))]
+ public void AddDeclaredValueThatAlreadyExists() {
+ IDictionary<string, string> target = new MessageDictionary(this.message);
+ target.Add("Name", "andrew");
+ target.Add("Name", "andrew");
+ }
+
+ [TestMethod]
+ public void DefaultReferenceTypeDeclaredPropertyHasNoKey() {
+ IDictionary<string, string> target = new MessageDictionary(this.message);
+ Assert.IsFalse(target.ContainsKey("Name"), "A null value should result in no key.");
+ Assert.IsFalse(target.Keys.Contains("Name"), "A null value should result in no key.");
+ }
+
+ [TestMethod]
+ public void RemoveStructDeclaredProperty() {
+ IDictionary<string, string> target = new MessageDictionary(this.message);
+ this.message.Age = 5;
+ Assert.IsTrue(target.ContainsKey("age"));
+ target.Remove("age");
+ Assert.IsTrue(target.ContainsKey("age"));
+ Assert.AreEqual(0, this.message.Age);
+ }
+
+ /// <summary>
+ /// A test for System.Collections.Generic.ICollection&lt;System.Collections.Generic.KeyValuePair&lt;System.String,System.String&lt;&lt;.Remove
+ /// </summary>
+ [TestMethod]
+ public void RemoveByKeyValuePair() {
+ ICollection<KeyValuePair<string, string>> target = new MessageDictionary(this.message);
+ this.message.Name = "Andrew";
+ Assert.IsFalse(target.Remove(new KeyValuePair<string, string>("Name", "andrew")));
+ Assert.AreEqual("Andrew", this.message.Name);
+ Assert.IsTrue(target.Remove(new KeyValuePair<string, string>("Name", "Andrew")));
+ Assert.IsNull(this.message.Name);
+ }
+
+ /// <summary>
+ /// A test for System.Collections.Generic.ICollection&lt;System.Collections.Generic.KeyValuePair&lt;System.String,System.String&lt;&lt;.CopyTo
+ /// </summary>
+ [TestMethod]
+ public void CopyTo() {
+ ICollection<KeyValuePair<string, string>> target = new MessageDictionary(this.message);
+ IDictionary<string, string> targetAsDictionary = ((IDictionary<string, string>)target);
+ KeyValuePair<string, string>[] array = new KeyValuePair<string, string>[target.Count + 1];
+ int arrayIndex = 1;
+ target.CopyTo(array, arrayIndex);
+ Assert.AreEqual(new KeyValuePair<string, string>(), array[0]);
+ for (int i = 1; i < array.Length; i++) {
+ Assert.IsNotNull(array[i].Key);
+ Assert.AreEqual(targetAsDictionary[array[i].Key], array[i].Value);
+ }
+ }
+
+ /// <summary>
+ /// A test for System.Collections.Generic.ICollection&lt;System.Collections.Generic.KeyValuePair&lt;System.String,System.String&lt;&lt;.Contains
+ /// </summary>
+ [TestMethod]
+ public void ContainsKeyValuePair() {
+ ICollection<KeyValuePair<string, string>> target = new MessageDictionary(this.message);
+ IDictionary<string, string> targetAsDictionary = ((IDictionary<string, string>)target);
+ Assert.IsFalse(target.Contains(new KeyValuePair<string, string>("age", "1")));
+ Assert.IsTrue(target.Contains(new KeyValuePair<string, string>("age", "0")));
+
+ targetAsDictionary["extra"] = "value";
+ Assert.IsFalse(target.Contains(new KeyValuePair<string, string>("extra", "Value")));
+ Assert.IsTrue(target.Contains(new KeyValuePair<string, string>("extra", "value")));
+ Assert.IsFalse(target.Contains(new KeyValuePair<string, string>("wayoff", "value")));
+ }
+
+ /// <summary>
+ /// A test for System.Collections.Generic.ICollection&lt;System.Collections.Generic.KeyValuePair&lt;System.String,System.String&lt;&lt;.Clear
+ /// </summary>
+ [TestMethod]
+ public void Clear() {
+ ICollection<KeyValuePair<string, string>> target = new MessageDictionary(this.message);
+ IDictionary<string, string> targetAsDictionary = ((IDictionary<string, string>)target);
+ this.message.Name = "Andrew";
+ this.message.Age = 15;
+ targetAsDictionary["extra"] = "value";
+ target.Clear();
+ Assert.AreEqual(2, target.Count, "Clearing should remove all keys except for declared non-nullable structs.");
+ Assert.IsFalse(targetAsDictionary.ContainsKey("extra"));
+ Assert.IsNull(this.message.Name);
+ Assert.AreEqual(0, this.message.Age);
+ }
+ }
+}
diff --git a/src/DotNetOAuth.Test/Messaging/Reflection/MessagePartTests.cs b/src/DotNetOAuth.Test/Messaging/Reflection/MessagePartTests.cs
new file mode 100644
index 0000000..b6c3b9d
--- /dev/null
+++ b/src/DotNetOAuth.Test/Messaging/Reflection/MessagePartTests.cs
@@ -0,0 +1,99 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using DotNetOAuth.Messaging.Reflection;
+using System.Reflection;
+using DotNetOAuth.Test.Mocks;
+
+namespace DotNetOAuth.Test.Messaging.Reflection {
+ [TestClass]
+ public class MessagePartTests :MessagingTestBase {
+ class MessageWithNonNullableOptionalStruct : TestMessage {
+ /// <summary>
+ /// Optional structs like int must be nullable for Optional to make sense.
+ /// </summary>
+ [MessagePart(IsRequired = false)]
+ internal int optionalInt = 0;
+ }
+ class MessageWithNonNullableRequiredStruct : TestMessage {
+ /// <summary>
+ /// This should work because a required field will always have a value so it
+ /// need not be nullable.
+ /// </summary>
+ [MessagePart(IsRequired = true)]
+ internal int optionalInt = 0;
+ }
+ class MessageWithNullableOptionalStruct : TestMessage {
+ /// <summary>
+ /// Optional structs like int must be nullable for Optional to make sense.
+ /// </summary>
+ [MessagePart(IsRequired = false)]
+ internal int? optionalInt = 0;
+ }
+ class MessageWithNullableRequiredStruct : TestMessage {
+ [MessagePart(IsRequired = true)]
+ internal int? optionalInt;
+ }
+
+ [TestMethod, ExpectedException(typeof(ArgumentException))]
+ public void OptionalNonNullableStruct() {
+ ParameterizedMessageTypeTest(typeof(MessageWithNonNullableOptionalStruct));
+ }
+
+ [TestMethod]
+ public void RequiredNonNullableStruct() {
+ ParameterizedMessageTypeTest(typeof(MessageWithNonNullableRequiredStruct));
+ }
+
+ [TestMethod]
+ public void OptionalNullableStruct() {
+ ParameterizedMessageTypeTest(typeof(MessageWithNullableOptionalStruct));
+ }
+
+ [TestMethod]
+ public void RequiredNullableStruct() {
+ ParameterizedMessageTypeTest(typeof(MessageWithNullableRequiredStruct));
+ }
+
+ [TestMethod, ExpectedException(typeof(ArgumentNullException))]
+ public void CtorNullMember() {
+ new MessagePart(null, new MessagePartAttribute());
+ }
+
+ [TestMethod, ExpectedException(typeof(ArgumentNullException))]
+ public void CtorNullAttribute() {
+ FieldInfo field = typeof(MessageWithNullableOptionalStruct).GetField("optionalInt", BindingFlags.NonPublic | BindingFlags.Instance);
+ new MessagePart(field, null);
+ }
+
+ [TestMethod]
+ public void SetValue() {
+ var message = new MessageWithNonNullableRequiredStruct();
+ MessagePart part = ParameterizedMessageTypeTest(message.GetType());
+ part.SetValue(message, "5");
+ Assert.AreEqual(5, message.optionalInt);
+ }
+
+ [TestMethod]
+ public void GetValue() {
+ var message = new MessageWithNonNullableRequiredStruct();
+ message.optionalInt = 8;
+ MessagePart part = ParameterizedMessageTypeTest(message.GetType());
+ Assert.AreEqual("8", part.GetValue(message));
+ }
+
+ [TestMethod, ExpectedException(typeof(ArgumentException))]
+ public void NonFieldOrPropertyMember() {
+ MemberInfo method = typeof(MessageWithNullableOptionalStruct).GetMethod("Equals", BindingFlags.Public | BindingFlags.Instance);
+ new MessagePart(method, new MessagePartAttribute());
+ }
+
+ private MessagePart ParameterizedMessageTypeTest(Type messageType) {
+ FieldInfo field = messageType.GetField("optionalInt", BindingFlags.NonPublic | BindingFlags.Instance);
+ MessagePartAttribute attribute = field.GetCustomAttributes(typeof(MessagePartAttribute), true).OfType<MessagePartAttribute>().Single();
+ return new MessagePart(field, attribute);
+ }
+ }
+}
diff --git a/src/DotNetOAuth.Test/Messaging/Reflection/ValueMappingTests.cs b/src/DotNetOAuth.Test/Messaging/Reflection/ValueMappingTests.cs
new file mode 100644
index 0000000..9142a0c
--- /dev/null
+++ b/src/DotNetOAuth.Test/Messaging/Reflection/ValueMappingTests.cs
@@ -0,0 +1,21 @@
+namespace DotNetOAuth.Test.Messaging.Reflection {
+ using System;
+ using System.Collections.Generic;
+ using System.Linq;
+ using System.Text;
+ using Microsoft.VisualStudio.TestTools.UnitTesting;
+ using DotNetOAuth.Messaging.Reflection;
+
+ [TestClass]
+ public class ValueMappingTests {
+ [TestMethod, ExpectedException(typeof(ArgumentNullException))]
+ public void CtorNullToString() {
+ new ValueMapping(null, str => new object());
+ }
+
+ [TestMethod, ExpectedException(typeof(ArgumentNullException))]
+ public void CtorNullToObject() {
+ new ValueMapping(obj => obj.ToString(), null);
+ }
+ }
+}
diff --git a/src/DotNetOAuth.Test/Mocks/TestBaseMessage.cs b/src/DotNetOAuth.Test/Mocks/TestBaseMessage.cs
index 29e2809..2a8cb30 100644
--- a/src/DotNetOAuth.Test/Mocks/TestBaseMessage.cs
+++ b/src/DotNetOAuth.Test/Mocks/TestBaseMessage.cs
@@ -6,8 +6,10 @@
namespace DotNetOAuth.Test.Mocks {
using System;
+ using System.Collections.Generic;
using System.Runtime.Serialization;
using DotNetOAuth.Messaging;
+ using DotNetOAuth.Messaging.Reflection;
internal interface IBaseMessageExplicitMembers {
string ExplicitProperty { get; set; }
@@ -15,13 +17,15 @@ namespace DotNetOAuth.Test.Mocks {
[DataContract(Namespace = Protocol.DataContractNamespaceV10)]
internal class TestBaseMessage : IProtocolMessage, IBaseMessageExplicitMembers {
- [DataMember(Name = "age", IsRequired = true)]
+ private Dictionary<string, string> extraData = new Dictionary<string, string>();
+
+ [MessagePart(Name = "age", IsRequired = true)]
public int Age { get; set; }
- [DataMember]
+ [MessagePart]
public string Name { get; set; }
- [DataMember(Name = "explicit")]
+ [MessagePart(Name = "explicit")]
string IBaseMessageExplicitMembers.ExplicitProperty { get; set; }
Version IProtocolMessage.ProtocolVersion {
@@ -36,12 +40,16 @@ namespace DotNetOAuth.Test.Mocks {
get { return MessageTransport.Indirect; }
}
+ IDictionary<string, string> IProtocolMessage.ExtraData {
+ get { return this.extraData; }
+ }
+
internal string PrivatePropertyAccessor {
get { return this.PrivateProperty; }
set { this.PrivateProperty = value; }
}
- [DataMember(Name = "private")]
+ [MessagePart(Name = "private")]
private string PrivateProperty { get; set; }
void IProtocolMessage.EnsureValidMessage() { }
diff --git a/src/DotNetOAuth.Test/Mocks/TestDerivedMessage.cs b/src/DotNetOAuth.Test/Mocks/TestDerivedMessage.cs
index afd67f6..69e58aa 100644
--- a/src/DotNetOAuth.Test/Mocks/TestDerivedMessage.cs
+++ b/src/DotNetOAuth.Test/Mocks/TestDerivedMessage.cs
@@ -6,6 +6,7 @@
namespace DotNetOAuth.Test.Mocks {
using System.Runtime.Serialization;
+ using DotNetOAuth.Messaging.Reflection;
[DataContract(Namespace = Protocol.DataContractNamespaceV10)]
internal class TestDerivedMessage : TestBaseMessage {
@@ -17,7 +18,7 @@ namespace DotNetOAuth.Test.Mocks {
/// due to alphabetical ordering rules, but after all the elements in the
/// base class due to inheritance rules.
/// </remarks>
- [DataMember]
+ [MessagePart]
public string TheFirstDerivedElement { get; set; }
/// <summary>
@@ -27,7 +28,7 @@ namespace DotNetOAuth.Test.Mocks {
/// This element should appear BEFORE <see cref="TheFirstDerivedElement"/>,
/// but after all the elements in the base class.
/// </remarks>
- [DataMember]
+ [MessagePart]
public string SecondDerivedElement { get; set; }
}
}
diff --git a/src/DotNetOAuth.Test/Mocks/TestDirectedMessage.cs b/src/DotNetOAuth.Test/Mocks/TestDirectedMessage.cs
index 7add28b..17317f5 100644
--- a/src/DotNetOAuth.Test/Mocks/TestDirectedMessage.cs
+++ b/src/DotNetOAuth.Test/Mocks/TestDirectedMessage.cs
@@ -6,25 +6,18 @@
namespace DotNetOAuth.Test.Mocks {
using System;
+ using System.Collections.Generic;
using System.Runtime.Serialization;
using DotNetOAuth.Messaging;
+ using DotNetOAuth.Messaging.Reflection;
[DataContract(Namespace = Protocol.DataContractNamespaceV10)]
- internal class TestDirectedMessage : IDirectedProtocolMessage {
- private MessageTransport transport;
-
- internal TestDirectedMessage(MessageTransport transport) {
- this.transport = transport;
+ internal class TestDirectedMessage : TestMessage, IDirectedProtocolMessage {
+ internal TestDirectedMessage() {
}
- [DataMember(Name = "age", IsRequired = true)]
- public int Age { get; set; }
- [DataMember]
- public string Name { get; set; }
- [DataMember]
- public string EmptyMember { get; set; }
- [DataMember]
- public Uri Location { get; set; }
+ internal TestDirectedMessage(MessageTransport transport) : base(transport) {
+ }
#region IDirectedProtocolMessage Members
@@ -34,32 +27,14 @@ namespace DotNetOAuth.Test.Mocks {
#region IProtocolMessage Properties
- Version IProtocolMessage.ProtocolVersion {
- get { return new Version(1, 0); }
- }
-
MessageProtection IProtocolMessage.RequiredProtection {
get { return this.RequiredProtection; }
}
- MessageTransport IProtocolMessage.Transport {
- get { return this.transport; }
- }
-
#endregion
protected virtual MessageProtection RequiredProtection {
get { return MessageProtection.None; }
}
-
- #region IProtocolMessage Methods
-
- void IProtocolMessage.EnsureValidMessage() {
- if (this.EmptyMember != null || this.Age < 0) {
- throw new ProtocolException();
- }
- }
-
- #endregion
}
}
diff --git a/src/DotNetOAuth.Test/Mocks/TestExpiringMessage.cs b/src/DotNetOAuth.Test/Mocks/TestExpiringMessage.cs
index d51e8ee..1b06969 100644
--- a/src/DotNetOAuth.Test/Mocks/TestExpiringMessage.cs
+++ b/src/DotNetOAuth.Test/Mocks/TestExpiringMessage.cs
@@ -10,19 +10,22 @@ namespace DotNetOAuth.Test.Mocks {
using System.Runtime.Serialization;
using DotNetOAuth.Messaging;
using DotNetOAuth.Messaging.Bindings;
+ using DotNetOAuth.Messaging.Reflection;
[DataContract(Namespace = Protocol.DataContractNamespaceV10)]
internal class TestExpiringMessage : TestSignedDirectedMessage, IExpiringProtocolMessage {
[DebuggerBrowsable(DebuggerBrowsableState.Never)]
private DateTime utcCreationDate;
+ internal TestExpiringMessage() { }
+
internal TestExpiringMessage(MessageTransport transport)
: base(transport) {
}
#region IExpiringProtocolMessage Members
- [DataMember(Name = "created_on")]
+ [MessagePart(Name = "created_on", IsRequired = true)]
DateTime IExpiringProtocolMessage.UtcCreationDate {
get { return this.utcCreationDate; }
set { this.utcCreationDate = value.ToUniversalTime(); }
diff --git a/src/DotNetOAuth.Test/Mocks/TestMessage.cs b/src/DotNetOAuth.Test/Mocks/TestMessage.cs
index e67b582..aede676 100644
--- a/src/DotNetOAuth.Test/Mocks/TestMessage.cs
+++ b/src/DotNetOAuth.Test/Mocks/TestMessage.cs
@@ -6,29 +6,33 @@
namespace DotNetOAuth.Test.Mocks {
using System;
+ using System.Collections.Generic;
using System.Runtime.Serialization;
using DotNetOAuth.Messaging;
+ using DotNetOAuth.Messaging.Reflection;
[DataContract(Namespace = Protocol.DataContractNamespaceV10)]
internal class TestMessage : IProtocolMessage {
private MessageTransport transport;
+ private Dictionary<string, string> extraData = new Dictionary<string, string>();
- internal TestMessage() : this(MessageTransport.Direct) {
+ internal TestMessage()
+ : this(MessageTransport.Direct) {
}
internal TestMessage(MessageTransport transport) {
this.transport = transport;
}
- [DataMember(Name = "age", IsRequired = true)]
+ [MessagePart(Name = "age", IsRequired = true)]
public int Age { get; set; }
- [DataMember]
+ [MessagePart("Name")]
public string Name { get; set; }
- [DataMember]
+ [MessagePart]
public string EmptyMember { get; set; }
- [DataMember]
+ [MessagePart(Name = null)] // null name tests that Location is still the name.
public Uri Location { get; set; }
- [DataMember]
+ [MessagePart(IsRequired = true)]
public DateTime Timestamp { get; set; }
#region IProtocolMessage Members
@@ -45,6 +49,10 @@ namespace DotNetOAuth.Test.Mocks {
get { return this.transport; }
}
+ IDictionary<string, string> IProtocolMessage.ExtraData {
+ get { return this.extraData; }
+ }
+
void IProtocolMessage.EnsureValidMessage() {
if (this.EmptyMember != null || this.Age < 0) {
throw new ProtocolException();
diff --git a/src/DotNetOAuth.Test/Mocks/TestReplayProtectedMessage.cs b/src/DotNetOAuth.Test/Mocks/TestReplayProtectedMessage.cs
index b62957b..f6c3b39 100644
--- a/src/DotNetOAuth.Test/Mocks/TestReplayProtectedMessage.cs
+++ b/src/DotNetOAuth.Test/Mocks/TestReplayProtectedMessage.cs
@@ -8,16 +8,19 @@ namespace DotNetOAuth.Test.Mocks {
using System.Runtime.Serialization;
using DotNetOAuth.Messaging;
using DotNetOAuth.Messaging.Bindings;
+ using DotNetOAuth.Messaging.Reflection;
[DataContract(Namespace = Protocol.DataContractNamespaceV10)]
internal class TestReplayProtectedMessage : TestExpiringMessage, IReplayProtectedProtocolMessage {
+ internal TestReplayProtectedMessage() { }
+
internal TestReplayProtectedMessage(MessageTransport transport)
: base(transport) {
}
#region IReplayProtectedProtocolMessage Members
- [DataMember(Name = "Nonce")]
+ [MessagePart(Name = "Nonce")]
string IReplayProtectedProtocolMessage.Nonce {
get;
set;
diff --git a/src/DotNetOAuth.Test/Mocks/TestSignedDirectedMessage.cs b/src/DotNetOAuth.Test/Mocks/TestSignedDirectedMessage.cs
index d4d2536..bda4255 100644
--- a/src/DotNetOAuth.Test/Mocks/TestSignedDirectedMessage.cs
+++ b/src/DotNetOAuth.Test/Mocks/TestSignedDirectedMessage.cs
@@ -7,17 +7,20 @@
namespace DotNetOAuth.Test.Mocks {
using System.Runtime.Serialization;
using DotNetOAuth.Messaging;
+ using DotNetOAuth.Messaging.Reflection;
using DotNetOAuth.Messaging.Bindings;
[DataContract(Namespace = Protocol.DataContractNamespaceV10)]
internal class TestSignedDirectedMessage : TestDirectedMessage, ITamperResistantProtocolMessage {
+ internal TestSignedDirectedMessage() { }
+
internal TestSignedDirectedMessage(MessageTransport transport)
: base(transport) {
}
#region ISignedProtocolMessage Members
- [DataMember]
+ [MessagePart]
public string Signature {
get;
set;
diff --git a/src/DotNetOAuth.Test/OAuthChannelTests.cs b/src/DotNetOAuth.Test/OAuthChannelTests.cs
index 2d00800..8b7b24b 100644
--- a/src/DotNetOAuth.Test/OAuthChannelTests.cs
+++ b/src/DotNetOAuth.Test/OAuthChannelTests.cs
@@ -15,6 +15,7 @@ namespace DotNetOAuth.Test {
using DotNetOAuth.Messaging;
using DotNetOAuth.Test.Mocks;
using Microsoft.VisualStudio.TestTools.UnitTesting;
+ using System.Xml;
[TestClass]
public class OAuthChannelTests : TestBase {
@@ -85,6 +86,7 @@ namespace DotNetOAuth.Test {
{ "age", "15" },
{ "Name", "Andrew" },
{ "Location", "http://hostb/pathB" },
+ { "Timestamp", XmlConvert.ToString(DateTime.UtcNow, XmlDateTimeSerializationMode.Utc) },
};
MemoryStream ms = new MemoryStream();
@@ -208,6 +210,7 @@ namespace DotNetOAuth.Test {
Name = "Andrew",
Location = new Uri("http://hostb/pathB"),
Recipient = new Uri("http://localtest"),
+ Timestamp = DateTime.UtcNow,
};
Response rawResponse = null;
@@ -216,14 +219,17 @@ namespace DotNetOAuth.Test {
HttpRequestInfo reqInfo = ConvertToRequestInfo(req, this.webRequestHandler.RequestEntityStream);
Assert.AreEqual(scheme == MessageScheme.PostRequest ? "POST" : "GET", reqInfo.HttpMethod);
var incomingMessage = this.channel.ReadFromRequest(reqInfo) as TestMessage;
+ Assert.IsNotNull(incomingMessage);
Assert.AreEqual(request.Age, incomingMessage.Age);
Assert.AreEqual(request.Name, incomingMessage.Name);
Assert.AreEqual(request.Location, incomingMessage.Location);
+ Assert.AreEqual(request.Timestamp, incomingMessage.Timestamp);
var responseFields = new Dictionary<string, string> {
{ "age", request.Age.ToString() },
{ "Name", request.Name },
{ "Location", request.Location.AbsoluteUri },
+ { "Timestamp", XmlConvert.ToString(request.Timestamp, XmlDateTimeSerializationMode.Utc) },
};
rawResponse = new Response {
Body = MessagingUtilities.CreateQueryString(responseFields),
@@ -246,6 +252,7 @@ namespace DotNetOAuth.Test {
{ "age", "15" },
{ "Name", "Andrew" },
{ "Location", "http://hostb/pathB" },
+ { "Timestamp", XmlConvert.ToString(DateTime.UtcNow, XmlDateTimeSerializationMode.Utc) },
};
IProtocolMessage requestMessage = this.channel.ReadFromRequest(CreateHttpRequestInfo(scheme, fields));
Assert.IsNotNull(requestMessage);
diff --git a/src/DotNetOAuth/DotNetOAuth.csproj b/src/DotNetOAuth/DotNetOAuth.csproj
index 16f0dac..f5882ef 100644
--- a/src/DotNetOAuth/DotNetOAuth.csproj
+++ b/src/DotNetOAuth/DotNetOAuth.csproj
@@ -68,21 +68,20 @@
<ItemGroup>
<Compile Include="Consumer.cs" />
<Compile Include="IWebRequestHandler.cs" />
+ <Compile Include="Messaging\Reflection\MessagePartAttribute.cs" />
<Compile Include="Messaging\MessageProtection.cs" />
<Compile Include="Messaging\IChannelBindingElement.cs" />
<Compile Include="Messaging\Bindings\ReplayedMessageException.cs" />
<Compile Include="Messaging\Bindings\ExpiredMessageException.cs" />
- <Compile Include="Messaging\DataContractMemberComparer.cs" />
<Compile Include="Messaging\Bindings\InvalidSignatureException.cs" />
<Compile Include="Messaging\Bindings\IReplayProtectedProtocolMessage.cs" />
<Compile Include="Messaging\Bindings\IExpiringProtocolMessage.cs" />
- <Compile Include="Messaging\DictionaryXmlReader.cs" />
- <Compile Include="Messaging\DictionaryXmlWriter.cs" />
<Compile Include="Messaging\Channel.cs" />
<Compile Include="Messaging\HttpRequestInfo.cs" />
<Compile Include="Messaging\IDirectedProtocolMessage.cs" />
<Compile Include="Messaging\IMessageTypeProvider.cs" />
<Compile Include="Messaging\Bindings\ITamperResistantProtocolMessage.cs" />
+ <Compile Include="Messaging\MessageSerializer.cs" />
<Compile Include="Messaging\MessagingStrings.Designer.cs">
<AutoGen>True</AutoGen>
<DesignTime>True</DesignTime>
@@ -90,6 +89,10 @@
</Compile>
<Compile Include="Messaging\MessagingUtilities.cs" />
<Compile Include="Messaging\Bindings\StandardExpirationBindingElement.cs" />
+ <Compile Include="Messaging\Reflection\ValueMapping.cs" />
+ <Compile Include="Messaging\Reflection\MessageDescription.cs" />
+ <Compile Include="Messaging\Reflection\MessageDictionary.cs" />
+ <Compile Include="Messaging\Reflection\MessagePart.cs" />
<Compile Include="Messaging\UnprotectedMessageException.cs" />
<Compile Include="OAuthChannel.cs" />
<Compile Include="Messaging\Response.cs" />
@@ -103,7 +106,6 @@
<Compile Include="Messaging\MessageTransport.cs" />
<Compile Include="OAuthMessageTypeProvider.cs" />
<Compile Include="Messaging\ProtocolException.cs" />
- <Compile Include="Messaging\MessageSerializer.cs" />
<Compile Include="Properties\AssemblyInfo.cs" />
<Compile Include="StandardWebRequestHandler.cs" />
<Compile Include="Util.cs" />
diff --git a/src/DotNetOAuth/Messaging/Channel.cs b/src/DotNetOAuth/Messaging/Channel.cs
index aebf6e4..db6cc10 100644
--- a/src/DotNetOAuth/Messaging/Channel.cs
+++ b/src/DotNetOAuth/Messaging/Channel.cs
@@ -15,6 +15,8 @@ namespace DotNetOAuth.Messaging {
using System.Net;
using System.Text;
using System.Web;
+ using DotNetOAuth.Messaging.Reflection;
+ using DotNetOAuth.Messaging.Bindings;
/// <summary>
/// Manages sending direct messages to a remote party and receiving responses.
@@ -179,7 +181,10 @@ namespace DotNetOAuth.Messaging {
/// <returns>The deserialized message, if one is found. Null otherwise.</returns>
protected internal IProtocolMessage ReadFromRequest(HttpRequestInfo httpRequest) {
IProtocolMessage requestMessage = this.ReadFromRequestInternal(httpRequest);
- this.VerifyMessageAfterReceiving(requestMessage);
+ if (requestMessage != null) {
+ this.VerifyMessageAfterReceiving(requestMessage);
+ }
+
return requestMessage;
}
@@ -497,6 +502,9 @@ namespace DotNetOAuth.Messaging {
if ((message.RequiredProtection & appliedProtection) != message.RequiredProtection) {
throw new UnprotectedMessageException(message, appliedProtection);
}
+
+ EnsureValidMessageParts(message);
+ message.EnsureValidMessage();
}
/// <summary>
@@ -521,6 +529,17 @@ namespace DotNetOAuth.Messaging {
if ((message.RequiredProtection & appliedProtection) != message.RequiredProtection) {
throw new UnprotectedMessageException(message, appliedProtection);
}
+
+ EnsureValidMessageParts(message);
+ message.EnsureValidMessage();
+ }
+
+ private void EnsureValidMessageParts(IProtocolMessage message) {
+ Debug.Assert(message != null, "message == null");
+
+ MessageDictionary dictionary = new MessageDictionary(message);
+ MessageDescription description = MessageDescription.Get(message.GetType());
+ description.EnsureRequiredMessagePartsArePresent(dictionary.Keys);
}
}
}
diff --git a/src/DotNetOAuth/Messaging/DataContractMemberComparer.cs b/src/DotNetOAuth/Messaging/DataContractMemberComparer.cs
deleted file mode 100644
index 4061b57..0000000
--- a/src/DotNetOAuth/Messaging/DataContractMemberComparer.cs
+++ /dev/null
@@ -1,103 +0,0 @@
-//-----------------------------------------------------------------------
-// <copyright file="DataContractMemberComparer.cs" company="Andrew Arnott">
-// Copyright (c) Andrew Arnott. All rights reserved.
-// </copyright>
-//-----------------------------------------------------------------------
-
-namespace DotNetOAuth.Messaging {
- using System;
- using System.Collections.Generic;
- using System.Diagnostics;
- using System.IO;
- using System.Linq;
- using System.Reflection;
- using System.Runtime.Serialization;
- using System.Xml;
- using System.Xml.Linq;
-
- /// <summary>
- /// A sorting tool to arrange fields in an order expected by the <see cref="DataContractSerializer"/>.
- /// </summary>
- internal class DataContractMemberComparer : IComparer<string> {
- /// <summary>
- /// The cached calculated inheritance ranking of every [DataMember] member of a type.
- /// </summary>
- private Dictionary<string, int> ranking;
-
- /// <summary>
- /// Initializes a new instance of the <see cref="DataContractMemberComparer"/> class.
- /// </summary>
- /// <param name="dataContractType">The data contract type that will be deserialized to.</param>
- internal DataContractMemberComparer(Type dataContractType) {
- // The elements must be serialized in inheritance rank and alphabetical order
- // so the DataContractSerializer will see them.
- this.ranking = GetDataMemberInheritanceRanking(dataContractType);
- }
-
- #region IComparer<string> Members
-
- /// <summary>
- /// Compares to fields and decides what order they should appear in.
- /// </summary>
- /// <param name="field1">The first field.</param>
- /// <param name="field2">The second field.</param>
- /// <returns>-1 if the first field should appear first, 0 if it doesn't matter, 1 if it should appear last.</returns>
- public int Compare(string field1, string field2) {
- int rank1, rank2;
- bool field1Valid = this.ranking.TryGetValue(field1, out rank1);
- bool field2Valid = this.ranking.TryGetValue(field2, out rank2);
-
- // If both fields are invalid, we don't care about the order.
- if (!field1Valid && !field2Valid) {
- return 0;
- }
-
- // If exactly one is valid, put that one first.
- if (field1Valid ^ field2Valid) {
- return field1Valid ? -1 : 1;
- }
-
- // First compare their inheritance ranking.
- if (rank1 != rank2) {
- // We want DESCENDING rank order, putting the members defined in the most
- // base class first.
- return -rank1.CompareTo(rank2);
- }
-
- // Finally sort alphabetically with case sensitivity.
- return string.CompareOrdinal(field1, field2);
- }
-
- #endregion
-
- /// <summary>
- /// Generates a dictionary of field name and inheritance rankings for a given DataContract type.
- /// </summary>
- /// <param name="type">The type to generate member rankings for.</param>
- /// <returns>The generated dictionary.</returns>
- private static Dictionary<string, int> GetDataMemberInheritanceRanking(Type type) {
- Debug.Assert(type != null, "type == null");
- var ranking = new Dictionary<string, int>();
-
- // TODO: review partial trust scenarios and this NonPublic flag.
- Type currentType = type;
- int rank = 0;
- do {
- foreach (MemberInfo member in currentType.GetMembers(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.DeclaredOnly)) {
- if (member is PropertyInfo || member is FieldInfo) {
- DataMemberAttribute dataMemberAttribute = member.GetCustomAttributes(typeof(DataMemberAttribute), true).OfType<DataMemberAttribute>().FirstOrDefault();
- if (dataMemberAttribute != null) {
- string name = dataMemberAttribute.Name ?? member.Name;
- ranking.Add(name, rank);
- }
- }
- }
-
- rank++;
- currentType = currentType.BaseType;
- } while (currentType != null);
-
- return ranking;
- }
- }
-}
diff --git a/src/DotNetOAuth/Messaging/DictionaryXmlReader.cs b/src/DotNetOAuth/Messaging/DictionaryXmlReader.cs
deleted file mode 100644
index 11d553b..0000000
--- a/src/DotNetOAuth/Messaging/DictionaryXmlReader.cs
+++ /dev/null
@@ -1,92 +0,0 @@
-//-----------------------------------------------------------------------
-// <copyright file="DictionaryXmlReader.cs" company="Andrew Arnott">
-// Copyright (c) Andrew Arnott. All rights reserved.
-// </copyright>
-//-----------------------------------------------------------------------
-
-namespace DotNetOAuth.Messaging {
- using System;
- using System.Collections.Generic;
- using System.Diagnostics;
- using System.IO;
- using System.Linq;
- using System.Reflection;
- using System.Runtime.Serialization;
- using System.Xml;
- using System.Xml.Linq;
-
- /// <summary>
- /// An XmlReader-looking object that actually reads from a dictionary.
- /// </summary>
- internal class DictionaryXmlReader {
- /// <summary>
- /// Creates an XmlReader that reads data out of a dictionary instead of XML.
- /// </summary>
- /// <param name="rootElement">The name of the root XML element.</param>
- /// <param name="fieldSorter">The field sorter so that the XmlReader generates xml elements in the order required by the <see cref="DataContractSerializer"/>.</param>
- /// <param name="fields">The dictionary to read data from.</param>
- /// <returns>The XmlReader that will read the data out of the given dictionary.</returns>
- internal static XmlReader Create(XName rootElement, IComparer<string> fieldSorter, IDictionary<string, string> fields) {
- if (rootElement == null) {
- throw new ArgumentNullException("rootElement");
- }
- if (fieldSorter == null) {
- throw new ArgumentNullException("fieldSorter");
- }
- if (fields == null) {
- throw new ArgumentNullException("fields");
- }
-
- return CreateRoundtripReader(rootElement, fieldSorter, fields);
- }
-
- /// <summary>
- /// Creates an <see cref="XmlReader"/> that will read values out of a dictionary.
- /// </summary>
- /// <param name="rootElement">The surrounding root XML element to generate.</param>
- /// <param name="fieldSorter">The field sorter so that the XmlReader generates xml elements in the order required by the <see cref="DataContractSerializer"/>.</param>
- /// <param name="fields">The dictionary to list values from.</param>
- /// <returns>The generated <see cref="XmlReader"/>.</returns>
- private static XmlReader CreateRoundtripReader(XName rootElement, IComparer<string> fieldSorter, IDictionary<string, string> fields) {
- Debug.Assert(rootElement != null, "rootElement == null");
- Debug.Assert(fieldSorter != null, "fieldSorter == null");
- Debug.Assert(fields != null, "fields == null");
-
- MemoryStream stream = new MemoryStream();
- XmlWriter writer = XmlWriter.Create(stream);
- SerializeDictionaryToXml(writer, rootElement, fieldSorter, fields);
- writer.Flush();
- stream.Seek(0, SeekOrigin.Begin);
-
- // For debugging purposes.
- StreamReader sr = new StreamReader(stream);
- Trace.WriteLine(sr.ReadToEnd());
- stream.Seek(0, SeekOrigin.Begin);
-
- return XmlReader.Create(stream);
- }
-
- /// <summary>
- /// Writes out the values in a dictionary as XML.
- /// </summary>
- /// <param name="writer">The <see cref="XmlWriter"/> to write out the XML to.</param>
- /// <param name="rootElement">The name of the root element to use to surround the dictionary values.</param>
- /// <param name="fieldSorter">The field sorter so that the XmlReader generates xml elements in the order required by the <see cref="DataContractSerializer"/>.</param>
- /// <param name="fields">The dictionary with values to serialize.</param>
- private static void SerializeDictionaryToXml(XmlWriter writer, XName rootElement, IComparer<string> fieldSorter, IDictionary<string, string> fields) {
- Debug.Assert(writer != null, "writer == null");
- Debug.Assert(rootElement != null, "rootElement == null");
- Debug.Assert(fields != null, "fields == null");
-
- writer.WriteStartElement(rootElement.LocalName, rootElement.NamespaceName);
-
- foreach (var pair in fields.OrderBy(pair => pair.Key, fieldSorter)) {
- writer.WriteStartElement(pair.Key, rootElement.NamespaceName);
- writer.WriteValue(pair.Value);
- writer.WriteEndElement();
- }
-
- writer.WriteEndElement();
- }
- }
-}
diff --git a/src/DotNetOAuth/Messaging/DictionaryXmlWriter.cs b/src/DotNetOAuth/Messaging/DictionaryXmlWriter.cs
deleted file mode 100644
index 3be043f..0000000
--- a/src/DotNetOAuth/Messaging/DictionaryXmlWriter.cs
+++ /dev/null
@@ -1,273 +0,0 @@
-//-----------------------------------------------------------------------
-// <copyright file="DictionaryXmlWriter.cs" company="Andrew Arnott">
-// Copyright (c) Andrew Arnott. All rights reserved.
-// </copyright>
-//-----------------------------------------------------------------------
-
-namespace DotNetOAuth.Messaging {
- using System;
- using System.Collections.Generic;
- using System.Text;
- using System.Xml;
-
- /// <summary>
- /// An XmlWriter-looking object that actually saves data to a dictionary.
- /// </summary>
- internal class DictionaryXmlWriter {
- /// <summary>
- /// Creates an <see cref="XmlWriter"/> that actually writes to an IDictionary&lt;string, string&gt; instance.
- /// </summary>
- /// <param name="dictionary">The dictionary to save the written XML to.</param>
- /// <returns>The XmlWriter that will save data to the given dictionary.</returns>
- internal static XmlWriter Create(IDictionary<string, string> dictionary) {
- return new PseudoXmlWriter(dictionary);
- }
-
- /// <summary>
- /// Writes out a dictionary as if it were XML.
- /// </summary>
- private class PseudoXmlWriter : XmlWriter {
- /// <summary>
- /// The dictionary to write values to.
- /// </summary>
- private IDictionary<string, string> dictionary;
-
- /// <summary>
- /// The key being written at the moment.
- /// </summary>
- private string key;
-
- /// <summary>
- /// The value being written out at the moment.
- /// </summary>
- private StringBuilder value = new StringBuilder();
-
- /// <summary>
- /// Initializes a new instance of the <see cref="PseudoXmlWriter"/> class.
- /// </summary>
- /// <param name="dictionary">The dictionary that will be written to.</param>
- internal PseudoXmlWriter(IDictionary<string, string> dictionary) {
- if (dictionary == null) {
- throw new ArgumentNullException("dictionary");
- }
-
- this.dictionary = dictionary;
- }
-
- /// <summary>
- /// Gets the spoofed state of the <see cref="XmlWriter"/>.
- /// </summary>
- public override WriteState WriteState {
- get { return WriteState.Element; }
- }
-
- /// <summary>
- /// Prepares to write out a new key/value pair with the given key name to the dictionary.
- /// </summary>
- /// <param name="prefix">This parameter is ignored.</param>
- /// <param name="localName">The key to store in the dictionary.</param>
- /// <param name="ns">This parameter is ignored.</param>
- public override void WriteStartElement(string prefix, string localName, string ns) {
- this.key = localName;
- this.value.Length = 0;
- }
-
- /// <summary>
- /// Appends some text to the value that is to be stored in the dictionary.
- /// </summary>
- /// <param name="text">The text to append to the value.</param>
- public override void WriteString(string text) {
- if (!string.IsNullOrEmpty(this.key)) {
- this.value.Append(text);
- }
- }
-
- /// <summary>
- /// Writes out a completed key/value to the dictionary.
- /// </summary>
- public override void WriteEndElement() {
- if (this.key != null) {
- this.dictionary[this.key] = this.value.ToString();
- this.key = null;
- this.value.Length = 0;
- }
- }
-
- /// <summary>
- /// Clears the internal key/value building state.
- /// </summary>
- /// <param name="prefix">This parameter is ignored.</param>
- /// <param name="localName">This parameter is ignored.</param>
- /// <param name="ns">This parameter is ignored.</param>
- public override void WriteStartAttribute(string prefix, string localName, string ns) {
- this.key = null;
- }
-
- /// <summary>
- /// This method does not do anything.
- /// </summary>
- public override void WriteEndAttribute() { }
-
- /// <summary>
- /// This method does not do anything.
- /// </summary>
- public override void Close() { }
-
- #region Unimplemented methods
-
- /// <summary>
- /// Throws <see cref="NotImplementedException"/>.
- /// </summary>
- public override void Flush() {
- throw new NotImplementedException();
- }
-
- /// <summary>
- /// Throws <see cref="NotImplementedException"/>.
- /// </summary>
- /// <param name="ns">This parameter is ignored.</param>
- /// <returns>None, since an exception is always thrown.</returns>
- public override string LookupPrefix(string ns) {
- throw new NotImplementedException();
- }
-
- /// <summary>
- /// Throws <see cref="NotImplementedException"/>.
- /// </summary>
- /// <param name="buffer">This parameter is ignored.</param>
- /// <param name="index">This parameter is ignored.</param>
- /// <param name="count">This parameter is ignored.</param>
- public override void WriteBase64(byte[] buffer, int index, int count) {
- throw new NotImplementedException();
- }
-
- /// <summary>
- /// Throws <see cref="NotImplementedException"/>.
- /// </summary>
- /// <param name="text">This parameter is ignored.</param>
- public override void WriteCData(string text) {
- throw new NotImplementedException();
- }
-
- /// <summary>
- /// Throws <see cref="NotImplementedException"/>.
- /// </summary>
- /// <param name="ch">This parameter is ignored.</param>
- public override void WriteCharEntity(char ch) {
- throw new NotImplementedException();
- }
-
- /// <summary>
- /// Throws <see cref="NotImplementedException"/>.
- /// </summary>
- /// <param name="buffer">This parameter is ignored.</param>
- /// <param name="index">This parameter is ignored.</param>
- /// <param name="count">This parameter is ignored.</param>
- public override void WriteChars(char[] buffer, int index, int count) {
- throw new NotImplementedException();
- }
-
- /// <summary>
- /// Throws <see cref="NotImplementedException"/>.
- /// </summary>
- /// <param name="text">This parameter is ignored.</param>
- public override void WriteComment(string text) {
- throw new NotImplementedException();
- }
-
- /// <summary>
- /// Throws <see cref="NotImplementedException"/>.
- /// </summary>
- /// <param name="name">This parameter is ignored.</param>
- /// <param name="pubid">This parameter is ignored.</param>
- /// <param name="sysid">This parameter is ignored.</param>
- /// <param name="subset">This parameter is ignored.</param>
- public override void WriteDocType(string name, string pubid, string sysid, string subset) {
- throw new NotImplementedException();
- }
-
- /// <summary>
- /// Throws <see cref="NotImplementedException"/>.
- /// </summary>
- public override void WriteEndDocument() {
- throw new NotImplementedException();
- }
-
- /// <summary>
- /// Throws <see cref="NotImplementedException"/>.
- /// </summary>
- /// <param name="name">This parameter is ignored.</param>
- public override void WriteEntityRef(string name) {
- throw new NotImplementedException();
- }
-
- /// <summary>
- /// Throws <see cref="NotImplementedException"/>.
- /// </summary>
- public override void WriteFullEndElement() {
- throw new NotImplementedException();
- }
-
- /// <summary>
- /// Throws <see cref="NotImplementedException"/>.
- /// </summary>
- /// <param name="name">This parameter is ignored.</param>
- /// <param name="text">This parameter is ignored.</param>
- public override void WriteProcessingInstruction(string name, string text) {
- throw new NotImplementedException();
- }
-
- /// <summary>
- /// Throws <see cref="NotImplementedException"/>.
- /// </summary>
- /// <param name="data">This parameter is ignored.</param>
- public override void WriteRaw(string data) {
- throw new NotImplementedException();
- }
-
- /// <summary>
- /// Throws <see cref="NotImplementedException"/>.
- /// </summary>
- /// <param name="buffer">This parameter is ignored.</param>
- /// <param name="index">This parameter is ignored.</param>
- /// <param name="count">This parameter is ignored.</param>
- public override void WriteRaw(char[] buffer, int index, int count) {
- throw new NotImplementedException();
- }
-
- /// <summary>
- /// Throws <see cref="NotImplementedException"/>.
- /// </summary>
- /// <param name="standalone">This parameter is ignored.</param>
- public override void WriteStartDocument(bool standalone) {
- throw new NotImplementedException();
- }
-
- /// <summary>
- /// Throws <see cref="NotImplementedException"/>.
- /// </summary>
- public override void WriteStartDocument() {
- throw new NotImplementedException();
- }
-
- /// <summary>
- /// Throws <see cref="NotImplementedException"/>.
- /// </summary>
- /// <param name="lowChar">This parameter is ignored.</param>
- /// <param name="highChar">This parameter is ignored.</param>
- public override void WriteSurrogateCharEntity(char lowChar, char highChar) {
- throw new NotImplementedException();
- }
-
- /// <summary>
- /// Throws <see cref="NotImplementedException"/>.
- /// </summary>
- /// <param name="ws">This parameter is ignored.</param>
- public override void WriteWhitespace(string ws) {
- throw new NotImplementedException();
- }
-
- #endregion
- }
- }
-}
diff --git a/src/DotNetOAuth/Messaging/IProtocolMessage.cs b/src/DotNetOAuth/Messaging/IProtocolMessage.cs
index 61df813..a95822c 100644
--- a/src/DotNetOAuth/Messaging/IProtocolMessage.cs
+++ b/src/DotNetOAuth/Messaging/IProtocolMessage.cs
@@ -30,6 +30,15 @@ namespace DotNetOAuth.Messaging {
MessageTransport Transport { get; }
/// <summary>
+ /// Gets the dictionary of additional name/value fields tacked on to this message.
+ /// </summary>
+ /// <remarks>
+ /// Implementations of <see cref="IProtocolMessage"/> should ensure that this property
+ /// never returns null.
+ /// </remarks>
+ IDictionary<string, string> ExtraData { get; }
+
+ /// <summary>
/// Checks the message state for conformity to the protocol specification
/// and throws an exception if the message is invalid.
/// </summary>
diff --git a/src/DotNetOAuth/Messaging/MessageSerializer.cs b/src/DotNetOAuth/Messaging/MessageSerializer.cs
index ad3a12d..6edb0b6 100644
--- a/src/DotNetOAuth/Messaging/MessageSerializer.cs
+++ b/src/DotNetOAuth/Messaging/MessageSerializer.cs
@@ -13,40 +13,19 @@ namespace DotNetOAuth.Messaging {
using System.Runtime.Serialization;
using System.Xml;
using System.Xml.Linq;
+ using DotNetOAuth.Messaging.Reflection;
/// <summary>
/// Serializes/deserializes OAuth messages for/from transit.
/// </summary>
internal class MessageSerializer {
/// <summary>
- /// The serializer that will be used as a reflection engine to extract
- /// the OAuth message properties out of their containing <see cref="IProtocolMessage"/>
- /// objects.
- /// </summary>
- private readonly DataContractSerializer serializer;
-
- /// <summary>
/// The specific <see cref="IProtocolMessage"/>-derived type
/// that will be serialized and deserialized using this class.
/// </summary>
private readonly Type messageType;
/// <summary>
- /// An AppDomain-wide cache of shared serializers for optimization purposes.
- /// </summary>
- private static Dictionary<Type, MessageSerializer> prebuiltSerializers = new Dictionary<Type, MessageSerializer>();
-
- /// <summary>
- /// Backing field for the <see cref="RootElement"/> property
- /// </summary>
- private XName rootElement;
-
- /// <summary>
- /// A field sorter that puts fields in the right order for the <see cref="DataContractSerializer"/>.
- /// </summary>
- private IComparer<string> fieldSorter;
-
- /// <summary>
/// Initializes a new instance of the MessageSerializer class.
/// </summary>
/// <param name="messageType">The specific <see cref="IProtocolMessage"/>-derived type
@@ -65,71 +44,19 @@ namespace DotNetOAuth.Messaging {
}
this.messageType = messageType;
- this.serializer = new DataContractSerializer(
- messageType, this.RootElement.LocalName, this.RootElement.NamespaceName);
- this.fieldSorter = new DataContractMemberComparer(messageType);
- }
-
- /// <summary>
- /// Gets the XML element that is used to surround all the XML values from the dictionary.
- /// </summary>
- private XName RootElement {
- get {
- if (this.rootElement == null) {
- DataContractAttribute attribute;
- try {
- attribute = this.messageType.GetCustomAttributes(typeof(DataContractAttribute), false).OfType<DataContractAttribute>().Single();
- } catch (InvalidOperationException ex) {
- throw new ProtocolException(
- string.Format(
- CultureInfo.CurrentCulture,
- MessagingStrings.DataContractMissingFromMessageType,
- this.messageType.FullName),
- ex);
- }
-
- if (attribute.Namespace == null) {
- throw new ProtocolException(string.Format(
- CultureInfo.CurrentCulture,
- MessagingStrings.DataContractMissingNamespace,
- this.messageType.FullName));
- }
-
- this.rootElement = XName.Get("root", attribute.Namespace);
- }
-
- return this.rootElement;
- }
}
/// <summary>
- /// Returns a message serializer from a reusable collection of serializers.
+ /// Creates or reuses a message serializer for a given message type.
/// </summary>
/// <param name="messageType">The type of message that will be serialized/deserialized.</param>
- /// <returns>A previously created serializer if one exists, or a newly created one.</returns>
+ /// <returns>A message serializer for the given message type.</returns>
internal static MessageSerializer Get(Type messageType) {
if (messageType == null) {
throw new ArgumentNullException("messageType");
}
- // We do this as efficiently as possible by first trying to fetch the
- // serializer out of the dictionary without taking a lock.
- MessageSerializer serializer;
- if (prebuiltSerializers.TryGetValue(messageType, out serializer)) {
- return serializer;
- }
-
- // Since it wasn't there, we'll be trying to write to the dictionary so
- // we take a lock and try reading again first, then creating the serializer
- // and storing it when we're sure it absolutely necessary.
- lock (prebuiltSerializers) {
- if (prebuiltSerializers.TryGetValue(messageType, out serializer)) {
- return serializer;
- }
- serializer = new MessageSerializer(messageType);
- prebuiltSerializers.Add(messageType, serializer);
- }
- return serializer;
+ return new MessageSerializer(messageType);
}
/// <summary>
@@ -142,28 +69,9 @@ namespace DotNetOAuth.Messaging {
throw new ArgumentNullException("message");
}
- var fields = new Dictionary<string, string>(StringComparer.Ordinal);
- this.Serialize(fields, message);
- return fields;
- }
-
- /// <summary>
- /// Saves the [DataMember] properties of a message to an existing dictionary.
- /// </summary>
- /// <param name="fields">The dictionary to save values to.</param>
- /// <param name="message">The message to pull values from.</param>
- internal void Serialize(IDictionary<string, string> fields, IProtocolMessage message) {
- if (fields == null) {
- throw new ArgumentNullException("fields");
- }
- if (message == null) {
- throw new ArgumentNullException("message");
- }
+ var result = new Reflection.MessageDictionary(message);
- message.EnsureValidMessage();
- using (XmlWriter writer = DictionaryXmlWriter.Create(fields)) {
- this.serializer.WriteObjectContent(writer, message);
- }
+ return result;
}
/// <summary>
@@ -176,13 +84,18 @@ namespace DotNetOAuth.Messaging {
throw new ArgumentNullException("fields");
}
- var reader = DictionaryXmlReader.Create(this.RootElement, this.fieldSorter, fields);
- IProtocolMessage result;
+ // Before we deserialize the message, make sure all the required parts are present.
+ MessageDescription.Get(this.messageType).EnsureRequiredMessagePartsArePresent(fields.Keys);
+
+ IProtocolMessage result ;
try {
- result = (IProtocolMessage)this.serializer.ReadObject(reader, false);
- } catch (SerializationException ex) {
- // Missing required fields is one cause of this exception.
- throw new ProtocolException(Strings.InvalidIncomingMessage, ex);
+ result = (IProtocolMessage)Activator.CreateInstance(this.messageType, true);
+ } catch (MissingMethodException ex) {
+ throw new ProtocolException("Failed to instantiate type " + this.messageType.FullName, ex);
+ }
+ foreach (var pair in fields) {
+ IDictionary<string, string> dictionary = new MessageDictionary(result);
+ dictionary.Add(pair);
}
result.EnsureValidMessage();
return result;
diff --git a/src/DotNetOAuth/Messaging/MessagingStrings.Designer.cs b/src/DotNetOAuth/Messaging/MessagingStrings.Designer.cs
index 49e726d..9386f53 100644
--- a/src/DotNetOAuth/Messaging/MessagingStrings.Designer.cs
+++ b/src/DotNetOAuth/Messaging/MessagingStrings.Designer.cs
@@ -151,6 +151,24 @@ namespace DotNetOAuth.Messaging {
}
/// <summary>
+ /// Looks up a localized string similar to Some part(s) of the message have invalid values: {0}.
+ /// </summary>
+ internal static string InvalidMessageParts {
+ get {
+ return ResourceManager.GetString("InvalidMessageParts", resourceCulture);
+ }
+ }
+
+ /// <summary>
+ /// Looks up a localized string similar to An item with the same key has already been added..
+ /// </summary>
+ internal static string KeyAlreadyExists {
+ get {
+ return ResourceManager.GetString("KeyAlreadyExists", resourceCulture);
+ }
+ }
+
+ /// <summary>
/// Looks up a localized string similar to A message response is already queued for sending in the response stream..
/// </summary>
internal static string QueuedMessageResponseAlreadyExists {
@@ -178,6 +196,15 @@ namespace DotNetOAuth.Messaging {
}
/// <summary>
+ /// Looks up a localized string similar to The following required parameters were missing from the {0} message: {1}.
+ /// </summary>
+ internal static string RequiredParametersMissing {
+ get {
+ return ResourceManager.GetString("RequiredParametersMissing", resourceCulture);
+ }
+ }
+
+ /// <summary>
/// Looks up a localized string similar to The binding element offering the {0} protection requires other protection that is not provided..
/// </summary>
internal static string RequiredProtectionMissing {
diff --git a/src/DotNetOAuth/Messaging/MessagingStrings.resx b/src/DotNetOAuth/Messaging/MessagingStrings.resx
index 2133c44..bdbd212 100644
--- a/src/DotNetOAuth/Messaging/MessagingStrings.resx
+++ b/src/DotNetOAuth/Messaging/MessagingStrings.resx
@@ -147,6 +147,12 @@
<data name="InsufficentMessageProtection" xml:space="preserve">
<value>The message required protections {0} but the channel could only apply {1}.</value>
</data>
+ <data name="InvalidMessageParts" xml:space="preserve">
+ <value>Some part(s) of the message have invalid values: {0}</value>
+ </data>
+ <data name="KeyAlreadyExists" xml:space="preserve">
+ <value>An item with the same key has already been added.</value>
+ </data>
<data name="QueuedMessageResponseAlreadyExists" xml:space="preserve">
<value>A message response is already queued for sending in the response stream.</value>
</data>
@@ -156,6 +162,9 @@
<data name="ReplayProtectionNotSupported" xml:space="preserve">
<value>This channel does not support replay protection.</value>
</data>
+ <data name="RequiredParametersMissing" xml:space="preserve">
+ <value>The following required parameters were missing from the {0} message: {1}</value>
+ </data>
<data name="RequiredProtectionMissing" xml:space="preserve">
<value>The binding element offering the {0} protection requires other protection that is not provided.</value>
</data>
diff --git a/src/DotNetOAuth/Messaging/ProtocolException.cs b/src/DotNetOAuth/Messaging/ProtocolException.cs
index 05ff340..d8b3d33 100644
--- a/src/DotNetOAuth/Messaging/ProtocolException.cs
+++ b/src/DotNetOAuth/Messaging/ProtocolException.cs
@@ -6,6 +6,7 @@
namespace DotNetOAuth.Messaging {
using System;
+ using System.Collections.Generic;
/// <summary>
/// An exception to represent errors in the local or remote implementation of the protocol.
@@ -23,6 +24,11 @@ namespace DotNetOAuth.Messaging {
private Uri recipient;
/// <summary>
+ /// A cache for extra name/value pairs tacked on as data when this exception is sent as a message.
+ /// </summary>
+ private Dictionary<string, string> extraData = new Dictionary<string, string>();
+
+ /// <summary>
/// Initializes a new instance of the <see cref="ProtocolException"/> class.
/// </summary>
public ProtocolException() { }
@@ -148,6 +154,13 @@ namespace DotNetOAuth.Messaging {
}
}
+ /// <summary>
+ /// Gets the dictionary of additional name/value fields tacked on to this message.
+ /// </summary>
+ IDictionary<string, string> IProtocolMessage.ExtraData {
+ get { return this.extraData; }
+ }
+
#endregion
/// <summary>
diff --git a/src/DotNetOAuth/Messaging/Reflection/MessageDescription.cs b/src/DotNetOAuth/Messaging/Reflection/MessageDescription.cs
new file mode 100644
index 0000000..9442f15
--- /dev/null
+++ b/src/DotNetOAuth/Messaging/Reflection/MessageDescription.cs
@@ -0,0 +1,91 @@
+//-----------------------------------------------------------------------
+// <copyright file="MessageDescription.cs" company="Andrew Arnott">
+// Copyright (c) Andrew Arnott. All rights reserved.
+// </copyright>
+//-----------------------------------------------------------------------
+
+namespace DotNetOAuth.Messaging.Reflection {
+ using System;
+ using System.Collections.Generic;
+ using System.Linq;
+ using System.Reflection;
+ using System.Globalization;
+ using System.Diagnostics;
+
+ internal class MessageDescription {
+ private static Dictionary<Type, MessageDescription> reflectedMessageTypes = new Dictionary<Type,MessageDescription>();
+ private Type messageType;
+ private Dictionary<string, MessagePart> mapping;
+
+ private MessageDescription(Type messageType) {
+ Debug.Assert(messageType != null, "messageType == null");
+
+ if (!typeof(IProtocolMessage).IsAssignableFrom(messageType)) {
+ throw new ArgumentException(string.Format(
+ CultureInfo.CurrentCulture,
+ MessagingStrings.UnexpectedType,
+ typeof(IProtocolMessage),
+ messageType));
+ }
+
+ this.messageType = messageType;
+ this.ReflectMessageType();
+ }
+
+ internal static MessageDescription Get(Type messageType) {
+ if (messageType == null) {
+ throw new ArgumentNullException("messageType");
+ }
+
+ MessageDescription result;
+ if (!reflectedMessageTypes.TryGetValue(messageType, out result)) {
+ lock (reflectedMessageTypes) {
+ if (!reflectedMessageTypes.TryGetValue(messageType, out result)) {
+ reflectedMessageTypes[messageType] = result = new MessageDescription(messageType);
+ }
+ }
+ }
+
+ return result;
+ }
+
+ internal IDictionary<string, MessagePart> Mapping {
+ get { return this.mapping; }
+ }
+
+ internal void ReflectMessageType() {
+ this.mapping = new Dictionary<string, MessagePart>();
+
+ Type currentType = this.messageType;
+ do {
+ foreach (MemberInfo member in currentType.GetMembers(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.DeclaredOnly)) {
+ if (member is PropertyInfo || member is FieldInfo) {
+ MessagePartAttribute partAttribute = member.GetCustomAttributes(typeof(MessagePartAttribute), true).OfType<MessagePartAttribute>().FirstOrDefault();
+ if (partAttribute != null) {
+ MessagePart part = new MessagePart(member, partAttribute);
+ this.mapping.Add(part.Name, part);
+ }
+ }
+ }
+ currentType = currentType.BaseType;
+ } while (currentType != null);
+ }
+
+ /// <summary>
+ /// Verifies that a given set of keys include all the required parameters
+ /// for this message type or throws an exception.
+ /// </summary>
+ /// <exception cref="ProtocolException">Thrown when required parts of a message are not in <paramref name="keys"/></exception>
+ internal void EnsureRequiredMessagePartsArePresent(IEnumerable<string> keys) {
+ var missingKeys = (from part in Mapping.Values
+ where part.IsRequired && !keys.Contains(part.Name)
+ select part.Name).ToArray();
+ if (missingKeys.Length > 0) {
+ throw new ProtocolException(string.Format(CultureInfo.CurrentCulture,
+ MessagingStrings.RequiredParametersMissing,
+ this.messageType.FullName,
+ string.Join(", ", missingKeys)));
+ }
+ }
+ }
+}
diff --git a/src/DotNetOAuth/Messaging/Reflection/MessageDictionary.cs b/src/DotNetOAuth/Messaging/Reflection/MessageDictionary.cs
new file mode 100644
index 0000000..196af54
--- /dev/null
+++ b/src/DotNetOAuth/Messaging/Reflection/MessageDictionary.cs
@@ -0,0 +1,206 @@
+//-----------------------------------------------------------------------
+// <copyright file="MessageDictionary.cs" company="Andrew Arnott">
+// Copyright (c) Andrew Arnott. All rights reserved.
+// </copyright>
+//-----------------------------------------------------------------------
+
+namespace DotNetOAuth.Messaging.Reflection {
+ using System;
+ using System.Collections;
+ using System.Collections.Generic;
+ using System.Diagnostics;
+
+ /// <summary>
+ /// Wraps an <see cref="IProtocolMessage"/> instance in a dictionary that
+ /// provides access to both well-defined message properties and "extra"
+ /// name/value pairs that have no properties associated with them.
+ /// </summary>
+ internal class MessageDictionary : IDictionary<string, string> {
+ private IProtocolMessage message;
+
+ private MessageDescription description;
+
+ internal MessageDictionary(IProtocolMessage message) {
+ if (message == null) {
+ throw new ArgumentNullException("message");
+ }
+
+ this.message = message;
+ this.description = MessageDescription.Get(message.GetType());
+ }
+
+ #region IDictionary<string,string> Members
+
+ public void Add(string key, string value) {
+ if (value == null) {
+ throw new ArgumentNullException("value");
+ }
+
+ MessagePart part;
+ if (this.description.Mapping.TryGetValue(key, out part)) {
+ if (part.IsNondefaultValueSet(this.message)) {
+ throw new ArgumentException(MessagingStrings.KeyAlreadyExists);
+ }
+ part.SetValue(this.message, value);
+ } else {
+ this.message.ExtraData.Add(key, value);
+ }
+ }
+
+ public bool ContainsKey(string key) {
+ return this.message.ExtraData.ContainsKey(key) ||
+ (this.description.Mapping.ContainsKey(key) && this.description.Mapping[key].GetValue(this.message) != null);
+ }
+
+ public ICollection<string> Keys {
+ get {
+ List<string> keys = new List<string>(this.message.ExtraData.Count + this.description.Mapping.Count);
+ foreach (var pair in this.description.Mapping) {
+ // Don't include keys with null values, but default values for structs is ok
+ if (pair.Value.GetValue(this.message) != null) {
+ keys.Add(pair.Key);
+ }
+ }
+
+ foreach (string key in this.message.ExtraData.Keys) {
+ keys.Add(key);
+ }
+
+ return keys.AsReadOnly();
+ }
+ }
+
+ public bool Remove(string key) {
+ if (this.message.ExtraData.Remove(key)) {
+ return true;
+ } else {
+ MessagePart part;
+ if (this.description.Mapping.TryGetValue(key, out part)) {
+ if (part.GetValue(this.message) != null) {
+ part.SetValue(this.message, null);
+ return true;
+ }
+ }
+ return false;
+ }
+ }
+
+ public bool TryGetValue(string key, out string value) {
+ MessagePart part;
+ if (this.description.Mapping.TryGetValue(key, out part)) {
+ value = part.GetValue(this.message);
+ return true;
+ }
+ return this.message.ExtraData.TryGetValue(key, out value);
+ }
+
+ public ICollection<string> Values {
+ get {
+ List<string> values = new List<string>(this.message.ExtraData.Count + this.description.Mapping.Count);
+ foreach (MessagePart part in this.description.Mapping.Values) {
+ if (part.GetValue(this.message) != null) {
+ values.Add(part.GetValue(this.message));
+ }
+ }
+
+ foreach (string value in this.message.ExtraData.Values) {
+ Debug.Assert(value != null, "Null values should never be allowed in the extra data dictionary.");
+ values.Add(value);
+ }
+
+ return values.AsReadOnly();
+ }
+ }
+
+ public string this[string key] {
+ get {
+ MessagePart part;
+ if (this.description.Mapping.TryGetValue(key, out part)) {
+ // Never throw KeyNotFoundException for declared properties.
+ return part.GetValue(this.message);
+ } else {
+ return this.message.ExtraData[key];
+ }
+ }
+
+ set {
+ MessagePart part;
+ if (this.description.Mapping.TryGetValue(key, out part)) {
+ part.SetValue(this.message, value);
+ } else {
+ if (value == null) {
+ this.message.ExtraData.Remove(key);
+ } else {
+ this.message.ExtraData[key] = value;
+ }
+ }
+ }
+ }
+
+ #endregion
+
+ #region ICollection<KeyValuePair<string,string>> Members
+
+ public void Add(KeyValuePair<string, string> item) {
+ this.Add(item.Key, item.Value);
+ }
+
+ public void Clear() {
+ foreach (string key in this.Keys) {
+ this.Remove(key);
+ }
+ }
+
+ public bool Contains(KeyValuePair<string, string> item) {
+ MessagePart part;
+ if (this.description.Mapping.TryGetValue(item.Key, out part)) {
+ return string.Equals(part.GetValue(this.message), item.Value, StringComparison.Ordinal);
+ } else {
+ return this.message.ExtraData.Contains(item);
+ }
+ }
+
+ void ICollection<KeyValuePair<string, string>>.CopyTo(KeyValuePair<string, string>[] array, int arrayIndex) {
+ foreach (var pair in (IDictionary<string, string>)this) {
+ array[arrayIndex++] = pair;
+ }
+ }
+
+ public int Count {
+ get { return this.Keys.Count; }
+ }
+
+ bool ICollection<KeyValuePair<string, string>>.IsReadOnly {
+ get { return false; }
+ }
+
+ public bool Remove(KeyValuePair<string, string> item) {
+ // We use contains because that checks that the value is equal as well.
+ if (((ICollection<KeyValuePair<string, string>>)this).Contains(item)) {
+ ((IDictionary<string, string>)this).Remove(item.Key);
+ return true;
+ }
+ return false;
+ }
+
+ #endregion
+
+ #region IEnumerable<KeyValuePair<string,string>> Members
+
+ public IEnumerator<KeyValuePair<string, string>> GetEnumerator() {
+ foreach (string key in Keys) {
+ yield return new KeyValuePair<string, string>(key, this[key]);
+ }
+ }
+
+ #endregion
+
+ #region IEnumerable Members
+
+ IEnumerator System.Collections.IEnumerable.GetEnumerator() {
+ return ((IEnumerable<KeyValuePair<string, string>>)this).GetEnumerator();
+ }
+
+ #endregion
+ }
+}
diff --git a/src/DotNetOAuth/Messaging/Reflection/MessagePart.cs b/src/DotNetOAuth/Messaging/Reflection/MessagePart.cs
new file mode 100644
index 0000000..0cf7cd4
--- /dev/null
+++ b/src/DotNetOAuth/Messaging/Reflection/MessagePart.cs
@@ -0,0 +1,148 @@
+//-----------------------------------------------------------------------
+// <copyright file="MessagePart.cs" company="Andrew Arnott">
+// Copyright (c) Andrew Arnott. All rights reserved.
+// </copyright>
+//-----------------------------------------------------------------------
+
+namespace DotNetOAuth.Messaging.Reflection {
+ using System;
+ using System.Collections.Generic;
+ using System.Net.Security;
+ using System.Reflection;
+ using System.Xml;
+ using System.Globalization;
+
+ internal class MessagePart {
+ private static readonly Dictionary<Type, ValueMapping> converters = new Dictionary<Type, ValueMapping>();
+
+ private ValueMapping converter;
+
+ private PropertyInfo property;
+
+ private FieldInfo field;
+
+ private Type memberDeclaredType;
+
+ private object defaultMemberValue;
+
+ static MessagePart() {
+ Map<Uri>(uri => uri.AbsoluteUri, str => new Uri(str));
+ Map<DateTime>(dt => XmlConvert.ToString(dt, XmlDateTimeSerializationMode.Utc), str => XmlConvert.ToDateTime(str, XmlDateTimeSerializationMode.Utc));
+ }
+
+ internal MessagePart(MemberInfo member, MessagePartAttribute attribute) {
+ if (member == null) {
+ throw new ArgumentNullException("member");
+ }
+
+ this.field = member as FieldInfo;
+ this.property = member as PropertyInfo;
+ if (this.field == null && this.property == null) {
+ throw new ArgumentException(string.Format(
+ CultureInfo.CurrentCulture,
+ MessagingStrings.UnexpectedType,
+ typeof(FieldInfo).Name + ", " + typeof(PropertyInfo).Name,
+ member.GetType().Name), "member");
+ }
+
+ if (attribute == null) {
+ throw new ArgumentNullException("attribute");
+ }
+
+ this.Name = attribute.Name ?? member.Name;
+ this.RequiredProtection = attribute.RequiredProtection;
+ this.IsRequired = attribute.IsRequired;
+ this.memberDeclaredType = (this.field != null) ? this.field.FieldType : this.property.PropertyType;
+ this.defaultMemberValue = deriveDefaultValue(this.memberDeclaredType);
+
+ if (!converters.TryGetValue(this.memberDeclaredType, out this.converter)) {
+ this.converter = new ValueMapping(
+ obj => obj != null ? obj.ToString() : null,
+ str => str != null ? Convert.ChangeType(str, memberDeclaredType) : null);
+ }
+
+ // Validate a sane combination of settings
+ ValidateSettings();
+ }
+
+ internal string Name { get; set; }
+
+ internal ProtectionLevel RequiredProtection { get; set; }
+
+ internal bool IsRequired { get; set; }
+
+ internal object ToValue(string value) {
+ return this.converter.StringToValue(value);
+ }
+
+ internal string ToString(object value) {
+ return this.converter.ValueToString(value);
+ }
+
+ internal void SetValue(IProtocolMessage message, string value) {
+ if (this.property != null) {
+ this.property.SetValue(message, this.ToValue(value), null);
+ } else {
+ this.field.SetValue(message, this.ToValue(value));
+ }
+ }
+
+ internal string GetValue(IProtocolMessage message) {
+ return this.ToString(this.GetValueAsObject(message));
+ }
+
+ internal bool IsNondefaultValueSet(IProtocolMessage message) {
+ if (this.memberDeclaredType.IsValueType) {
+ return !GetValueAsObject(message).Equals(this.defaultMemberValue);
+ } else {
+ return this.defaultMemberValue != GetValueAsObject(message);
+ }
+ }
+
+ private static object deriveDefaultValue(Type type) {
+ if (type.IsValueType) {
+ return Activator.CreateInstance(type);
+ } else {
+ return null;
+ }
+ }
+
+ private object GetValueAsObject(IProtocolMessage message) {
+ if (this.property != null) {
+ return this.property.GetValue(message, null);
+ } else {
+ return this.field.GetValue(message);
+ }
+ }
+
+ private static void Map<T>(Func<T, string> toString, Func<string, T> toValue) {
+ converters.Add(
+ typeof(T),
+ new ValueMapping(
+ obj => obj != null ? toString((T)obj) : null,
+ str => str != null ? toValue(str) : default(T)));
+ }
+
+ private void ValidateSettings() {
+ // An optional tag on a non-nullable value type is a contradiction.
+ if (!this.IsRequired && IsNonNullableValueType(this.memberDeclaredType)) {
+ MemberInfo member = (MemberInfo)this.field ?? this.property;
+ throw new ArgumentException(string.Format(CultureInfo.CurrentCulture,
+ "Invalid combination: {0} on message type {1} is a non-nullable value type but is marked as optional.",
+ member.Name, member.DeclaringType));
+ }
+ }
+
+ private static bool IsNonNullableValueType(Type type) {
+ if (!type.IsValueType) {
+ return false;
+ }
+
+ if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>)) {
+ return false;
+ }
+
+ return true;
+ }
+ }
+}
diff --git a/src/DotNetOAuth/Messaging/Reflection/MessagePartAttribute.cs b/src/DotNetOAuth/Messaging/Reflection/MessagePartAttribute.cs
new file mode 100644
index 0000000..c8bb8f5
--- /dev/null
+++ b/src/DotNetOAuth/Messaging/Reflection/MessagePartAttribute.cs
@@ -0,0 +1,33 @@
+//-----------------------------------------------------------------------
+// <copyright file="MessagePartAttribute.cs" company="Andrew Arnott">
+// Copyright (c) Andrew Arnott. All rights reserved.
+// </copyright>
+//-----------------------------------------------------------------------
+
+namespace DotNetOAuth.Messaging.Reflection {
+ using System;
+ using System.Net.Security;
+ using System.Reflection;
+
+ [AttributeUsage(AttributeTargets.Property | AttributeTargets.Field, Inherited = true, AllowMultiple = false)]
+ internal sealed class MessagePartAttribute : Attribute {
+ private bool initialized;
+ private string name;
+
+ internal MessagePartAttribute() {
+ }
+
+ internal MessagePartAttribute(string name) {
+ this.Name = name;
+ }
+
+ public string Name {
+ get { return this.name; }
+ set { this.name = string.IsNullOrEmpty(value) ? null : value; }
+ }
+
+ public ProtectionLevel RequiredProtection { get; set; }
+
+ public bool IsRequired { get; set; }
+ }
+}
diff --git a/src/DotNetOAuth/Messaging/Reflection/ValueMapping.cs b/src/DotNetOAuth/Messaging/Reflection/ValueMapping.cs
new file mode 100644
index 0000000..2371b49
--- /dev/null
+++ b/src/DotNetOAuth/Messaging/Reflection/ValueMapping.cs
@@ -0,0 +1,27 @@
+//-----------------------------------------------------------------------
+// <copyright file="ValueMapping.cs" company="Andrew Arnott">
+// Copyright (c) Andrew Arnott. All rights reserved.
+// </copyright>
+//-----------------------------------------------------------------------
+
+namespace DotNetOAuth.Messaging.Reflection {
+ using System;
+
+ internal struct ValueMapping {
+ internal Func<object, string> ValueToString;
+ internal Func<string, object> StringToValue;
+
+ internal ValueMapping(Func<object, string> toString, Func<string, object> toValue) {
+ if (toString == null) {
+ throw new ArgumentNullException("toString");
+ }
+
+ if (toValue == null) {
+ throw new ArgumentNullException("toValue");
+ }
+
+ this.ValueToString = toString;
+ this.StringToValue = toValue;
+ }
+ }
+}
diff --git a/src/DotNetOAuth/StandardWebRequestHandler.cs b/src/DotNetOAuth/StandardWebRequestHandler.cs
index 715da72..d56562b 100644
--- a/src/DotNetOAuth/StandardWebRequestHandler.cs
+++ b/src/DotNetOAuth/StandardWebRequestHandler.cs
@@ -18,7 +18,8 @@ namespace DotNetOAuth {
#region IWebRequestHandler Members
/// <summary>
- /// Prepares an <see cref="HttpWebRequest"/> that contains an POST entity for sending the entity.
+ /// Prepares a POST <see cref="HttpWebRequest"/> and returns the request stream
+ /// for writing out the POST entity data.
/// </summary>
/// <param name="request">The <see cref="HttpWebRequest"/> that should contain the entity.</param>
/// <returns>The stream the caller should write out the entity data to.</returns>