diff options
22 files changed, 193 insertions, 724 deletions
diff --git a/src/DotNetOAuth.Test/DotNetOAuth.Test.csproj b/src/DotNetOAuth.Test/DotNetOAuth.Test.csproj index 028e1b9..b65b5c6 100644 --- a/src/DotNetOAuth.Test/DotNetOAuth.Test.csproj +++ b/src/DotNetOAuth.Test/DotNetOAuth.Test.csproj @@ -59,14 +59,15 @@ </ItemGroup>
<ItemGroup>
<Compile Include="Messaging\CollectionAssert.cs" />
+ <Compile Include="Messaging\MessageSerializerTests.cs" />
<Compile Include="Messaging\Reflection\MessageDictionaryTest.cs" />
<Compile Include="Messaging\MessagingTestBase.cs" />
<Compile Include="Messaging\MessagingUtilitiesTests.cs" />
<Compile Include="Messaging\ChannelTests.cs" />
- <Compile Include="Messaging\DictionaryXmlReaderTests.cs" />
<Compile Include="Messaging\HttpRequestInfoTests.cs" />
<Compile Include="Messaging\ProtocolExceptionTests.cs" />
<Compile Include="Messaging\Bindings\StandardExpirationBindingElementTests.cs" />
+ <Compile Include="Messaging\Reflection\MessagePartTests.cs" />
<Compile Include="Mocks\MockTransformationBindingElement.cs" />
<Compile Include="Mocks\MockReplayProtectionBindingElement.cs" />
<Compile Include="Mocks\TestBaseMessage.cs" />
@@ -79,7 +80,6 @@ <Compile Include="Mocks\MockSigningBindingElement.cs" />
<Compile Include="Mocks\TestWebRequestHandler.cs" />
<Compile Include="OAuthChannelTests.cs" />
- <Compile Include="Messaging\MessageSerializerTests.cs" />
<Compile Include="Mocks\TestChannel.cs" />
<Compile Include="Mocks\TestMessage.cs" />
<Compile Include="Mocks\TestMessageTypeProvider.cs" />
diff --git a/src/DotNetOAuth.Test/Messaging/DictionaryXmlReaderTests.cs b/src/DotNetOAuth.Test/Messaging/DictionaryXmlReaderTests.cs deleted file mode 100644 index 54c47f0..0000000 --- a/src/DotNetOAuth.Test/Messaging/DictionaryXmlReaderTests.cs +++ /dev/null @@ -1,33 +0,0 @@ -//-----------------------------------------------------------------------
-// <copyright file="DictionaryXmlReaderTests.cs" company="Andrew Arnott">
-// Copyright (c) Andrew Arnott. All rights reserved.
-// </copyright>
-//-----------------------------------------------------------------------
-
-namespace DotNetOAuth.Test.Messaging {
- using System;
- using System.Collections.Generic;
- using System.Xml.Linq;
- using DotNetOAuth.Messaging;
- using Microsoft.VisualStudio.TestTools.UnitTesting;
-
- [TestClass]
- public class DictionaryXmlReaderTests : TestBase {
- [TestMethod, ExpectedException(typeof(ArgumentNullException))]
- public void CreateWithNullRootElement() {
- IComparer<string> fieldSorter = new DataContractMemberComparer(typeof(Mocks.TestMessage));
- DictionaryXmlReader.Create(null, fieldSorter, new Dictionary<string, string>());
- }
-
- [TestMethod, ExpectedException(typeof(ArgumentNullException))]
- public void CreateWithNullDataContractType() {
- DictionaryXmlReader.Create(XName.Get("name", "ns"), null, new Dictionary<string, string>());
- }
-
- [TestMethod, ExpectedException(typeof(ArgumentNullException))]
- public void CreateWithNullFields() {
- IComparer<string> fieldSorter = new DataContractMemberComparer(typeof(Mocks.TestMessage));
- DictionaryXmlReader.Create(XName.Get("name", "ns"), fieldSorter, null);
- }
- }
-}
diff --git a/src/DotNetOAuth.Test/Messaging/MessageSerializerTests.cs b/src/DotNetOAuth.Test/Messaging/MessageSerializerTests.cs index 162d456..2eb1dbf 100644 --- a/src/DotNetOAuth.Test/Messaging/MessageSerializerTests.cs +++ b/src/DotNetOAuth.Test/Messaging/MessageSerializerTests.cs @@ -21,18 +21,6 @@ namespace DotNetOAuth.Test.Messaging { serializer.Serialize(null);
}
- [TestMethod, ExpectedException(typeof(ArgumentNullException))]
- public void SerializeNullFields() {
- var serializer = MessageSerializer.Get(typeof(Mocks.TestMessage));
- serializer.Serialize(null, new Mocks.TestMessage());
- }
-
- [TestMethod, ExpectedException(typeof(ArgumentNullException))]
- public void SerializeNullMessage() {
- var serializer = MessageSerializer.Get(typeof(Mocks.TestMessage));
- serializer.Serialize(new Dictionary<string, string>(), null);
- }
-
[TestMethod, ExpectedException(typeof(ArgumentException))]
public void GetInvalidMessageType() {
MessageSerializer.Get(typeof(string));
@@ -43,11 +31,6 @@ namespace DotNetOAuth.Test.Messaging { MessageSerializer.Get(null);
}
- [TestMethod]
- public void GetReturnsSameSerializerTwice() {
- Assert.AreSame(MessageSerializer.Get(typeof(Mocks.TestMessage)), MessageSerializer.Get(typeof(Mocks.TestMessage)));
- }
-
[TestMethod, ExpectedException(typeof(ProtocolException))]
public void SerializeInvalidMessage() {
var serializer = MessageSerializer.Get(typeof(Mocks.TestMessage));
@@ -81,19 +64,6 @@ namespace DotNetOAuth.Test.Messaging { Assert.IsFalse(actual.ContainsKey("EmptyMember"));
}
- [TestMethod]
- public void SerializeToExistingDictionary() {
- var serializer = MessageSerializer.Get(typeof(Mocks.TestMessage));
- var message = new Mocks.TestMessage { Age = 15, Name = "Andrew" };
- var fields = new Dictionary<string, string>();
- fields["someExtraField"] = "someValue";
- serializer.Serialize(fields, message);
- Assert.AreEqual(4, fields.Count);
- Assert.AreEqual("15", fields["age"]);
- Assert.AreEqual("Andrew", fields["Name"]);
- Assert.AreEqual("someValue", fields["someExtraField"]);
- }
-
[TestMethod, ExpectedException(typeof(ArgumentNullException))]
public void DeserializeNull() {
var serializer = MessageSerializer.Get(typeof(Mocks.TestMessage));
@@ -115,7 +85,7 @@ namespace DotNetOAuth.Test.Messaging { }
/// <summary>
- /// This tests deserialization of a message that is comprised of [DataMember]'s
+ /// This tests deserialization of a message that is comprised of [MessagePart]'s
/// that are defined in multiple places in the inheritance tree.
/// </summary>
/// <remarks>
diff --git a/src/DotNetOAuth.Test/Messaging/Reflection/MessageDictionaryTest.cs b/src/DotNetOAuth.Test/Messaging/Reflection/MessageDictionaryTest.cs index 6bbd849..f798bde 100644 --- a/src/DotNetOAuth.Test/Messaging/Reflection/MessageDictionaryTest.cs +++ b/src/DotNetOAuth.Test/Messaging/Reflection/MessageDictionaryTest.cs @@ -37,9 +37,6 @@ namespace DotNetOAuth.Test.Messaging.Reflection { IDictionary<string, string> target = new MessageDictionary(this.message);
Collection<string> expected = new Collection<string> {
this.message.Age.ToString(),
- this.message.EmptyMember,
- null, // this.message.Location.AbsoluteUri, (Location is null)
- this.message.Name,
this.message.Timestamp.ToString(),
};
CollectionAssert<string>.AreEquivalent(expected, target.Values);
@@ -50,7 +47,6 @@ namespace DotNetOAuth.Test.Messaging.Reflection { target["extra"] = "a";
expected = new Collection<string> {
this.message.Age.ToString(),
- this.message.EmptyMember,
this.message.Location.AbsoluteUri,
this.message.Name,
this.message.Timestamp.ToString(),
@@ -67,13 +63,12 @@ namespace DotNetOAuth.Test.Messaging.Reflection { IDictionary<string, string> target = new MessageDictionary(this.message);
Collection<string> expected = new Collection<string> {
"age",
- "EmptyMember",
- "Location",
- "Name",
"Timestamp",
};
CollectionAssert<string>.AreEquivalent(expected, target.Keys);
+ this.message.Name = "Andrew";
+ expected.Add("Name");
target["extraField"] = string.Empty;
expected.Add("extraField");
CollectionAssert<string>.AreEquivalent(expected, target.Keys);
@@ -213,9 +208,7 @@ namespace DotNetOAuth.Test.Messaging.Reflection { public void ContainsKeyTest() {
IDictionary<string, string> target = new MessageDictionary(this.message);
Assert.IsTrue(target.ContainsKey("age"), "Value type declared element should have a key.");
- Assert.IsTrue(target.ContainsKey("Name"), "Null declared element should have a key.");
- target.Remove("Name");
- Assert.IsTrue(target.ContainsKey("Name"), "Removed declared element should still have a key.");
+ Assert.IsFalse(target.ContainsKey("Name"), "Null declared element should NOT have a key.");
Assert.IsFalse(target.ContainsKey("extra"));
target["extra"] = "value";
@@ -258,6 +251,23 @@ namespace DotNetOAuth.Test.Messaging.Reflection { target.Add("Name", "andrew");
}
+ [TestMethod]
+ public void DefaultReferenceTypeDeclaredPropertyHasNoKey() {
+ IDictionary<string, string> target = new MessageDictionary(this.message);
+ Assert.IsFalse(target.ContainsKey("Name"), "A null value should result in no key.");
+ Assert.IsFalse(target.Keys.Contains("Name"), "A null value should result in no key.");
+ }
+
+ [TestMethod]
+ public void RemoveStructDeclaredProperty() {
+ IDictionary<string, string> target = new MessageDictionary(this.message);
+ this.message.Age = 5;
+ Assert.IsTrue(target.ContainsKey("age"));
+ target.Remove("age");
+ Assert.IsTrue(target.ContainsKey("age"));
+ Assert.AreEqual(0, this.message.Age);
+ }
+
/// <summary>
/// A test for System.Collections.Generic.ICollection<System.Collections.Generic.KeyValuePair<System.String,System.String<<.Remove
/// </summary>
@@ -314,9 +324,8 @@ namespace DotNetOAuth.Test.Messaging.Reflection { this.message.Name = "Andrew";
this.message.Age = 15;
targetAsDictionary["extra"] = "value";
- int countBeforeClear = targetAsDictionary.Count;
target.Clear();
- Assert.AreEqual(countBeforeClear - 1, target.Count, "Clearing with one extra parameter should reduce count by 1.");
+ Assert.AreEqual(2, target.Count, "Clearing should remove all keys except for declared non-nullable structs.");
Assert.IsFalse(targetAsDictionary.ContainsKey("extra"));
Assert.IsNull(this.message.Name);
Assert.AreEqual(0, this.message.Age);
diff --git a/src/DotNetOAuth.Test/Messaging/Reflection/MessagePartTests.cs b/src/DotNetOAuth.Test/Messaging/Reflection/MessagePartTests.cs new file mode 100644 index 0000000..646599c --- /dev/null +++ b/src/DotNetOAuth.Test/Messaging/Reflection/MessagePartTests.cs @@ -0,0 +1,27 @@ +using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using DotNetOAuth.Messaging.Reflection;
+using System.Reflection;
+
+namespace DotNetOAuth.Test.Messaging.Reflection {
+ [TestClass]
+ public class MessagePartTests :MessagingTestBase {
+ class MessageWithNonNullableOptionalStruct {
+ /// <summary>
+ /// Optional structs like int must be nullable for Optional to make sense.
+ /// </summary>
+ [MessagePart(IsRequired = false)]
+ internal int optionalInt;
+ }
+
+ [TestMethod, ExpectedException(typeof(ArgumentException))]
+ public void OptionalNonNullableStruct() {
+ FieldInfo field = typeof(MessageWithNonNullableOptionalStruct).GetField("optionalInt", BindingFlags.NonPublic | BindingFlags.Instance);
+ MessagePartAttribute attribute = field.GetCustomAttributes(typeof(MessagePartAttribute), true).OfType<MessagePartAttribute>().Single();
+ new MessagePart(field, attribute); // should recognize invalid optional non-nullable struct
+ }
+ }
+}
diff --git a/src/DotNetOAuth.Test/Mocks/TestBaseMessage.cs b/src/DotNetOAuth.Test/Mocks/TestBaseMessage.cs index 8f500cf..2a8cb30 100644 --- a/src/DotNetOAuth.Test/Mocks/TestBaseMessage.cs +++ b/src/DotNetOAuth.Test/Mocks/TestBaseMessage.cs @@ -9,6 +9,7 @@ namespace DotNetOAuth.Test.Mocks { using System.Collections.Generic;
using System.Runtime.Serialization;
using DotNetOAuth.Messaging;
+ using DotNetOAuth.Messaging.Reflection;
internal interface IBaseMessageExplicitMembers {
string ExplicitProperty { get; set; }
@@ -18,13 +19,13 @@ namespace DotNetOAuth.Test.Mocks { internal class TestBaseMessage : IProtocolMessage, IBaseMessageExplicitMembers {
private Dictionary<string, string> extraData = new Dictionary<string, string>();
- [DataMember(Name = "age", IsRequired = true)]
+ [MessagePart(Name = "age", IsRequired = true)]
public int Age { get; set; }
- [DataMember]
+ [MessagePart]
public string Name { get; set; }
- [DataMember(Name = "explicit")]
+ [MessagePart(Name = "explicit")]
string IBaseMessageExplicitMembers.ExplicitProperty { get; set; }
Version IProtocolMessage.ProtocolVersion {
@@ -48,7 +49,7 @@ namespace DotNetOAuth.Test.Mocks { set { this.PrivateProperty = value; }
}
- [DataMember(Name = "private")]
+ [MessagePart(Name = "private")]
private string PrivateProperty { get; set; }
void IProtocolMessage.EnsureValidMessage() { }
diff --git a/src/DotNetOAuth.Test/Mocks/TestDerivedMessage.cs b/src/DotNetOAuth.Test/Mocks/TestDerivedMessage.cs index afd67f6..69e58aa 100644 --- a/src/DotNetOAuth.Test/Mocks/TestDerivedMessage.cs +++ b/src/DotNetOAuth.Test/Mocks/TestDerivedMessage.cs @@ -6,6 +6,7 @@ namespace DotNetOAuth.Test.Mocks {
using System.Runtime.Serialization;
+ using DotNetOAuth.Messaging.Reflection;
[DataContract(Namespace = Protocol.DataContractNamespaceV10)]
internal class TestDerivedMessage : TestBaseMessage {
@@ -17,7 +18,7 @@ namespace DotNetOAuth.Test.Mocks { /// due to alphabetical ordering rules, but after all the elements in the
/// base class due to inheritance rules.
/// </remarks>
- [DataMember]
+ [MessagePart]
public string TheFirstDerivedElement { get; set; }
/// <summary>
@@ -27,7 +28,7 @@ namespace DotNetOAuth.Test.Mocks { /// This element should appear BEFORE <see cref="TheFirstDerivedElement"/>,
/// but after all the elements in the base class.
/// </remarks>
- [DataMember]
+ [MessagePart]
public string SecondDerivedElement { get; set; }
}
}
diff --git a/src/DotNetOAuth.Test/Mocks/TestDirectedMessage.cs b/src/DotNetOAuth.Test/Mocks/TestDirectedMessage.cs index cdd130d..8066adf 100644 --- a/src/DotNetOAuth.Test/Mocks/TestDirectedMessage.cs +++ b/src/DotNetOAuth.Test/Mocks/TestDirectedMessage.cs @@ -9,6 +9,7 @@ namespace DotNetOAuth.Test.Mocks { using System.Collections.Generic;
using System.Runtime.Serialization;
using DotNetOAuth.Messaging;
+ using DotNetOAuth.Messaging.Reflection;
[DataContract(Namespace = Protocol.DataContractNamespaceV10)]
internal class TestDirectedMessage : IDirectedProtocolMessage {
@@ -16,17 +17,20 @@ namespace DotNetOAuth.Test.Mocks { private Dictionary<string, string> extraData = new Dictionary<string, string>();
+ internal TestDirectedMessage() {
+ }
+
internal TestDirectedMessage(MessageTransport transport) {
this.transport = transport;
}
- [DataMember(Name = "age", IsRequired = true)]
+ [MessagePart(Name = "age", IsRequired = true)]
public int Age { get; set; }
- [DataMember]
+ [MessagePart]
public string Name { get; set; }
- [DataMember]
+ [MessagePart]
public string EmptyMember { get; set; }
- [DataMember]
+ [MessagePart]
public Uri Location { get; set; }
#region IDirectedProtocolMessage Members
diff --git a/src/DotNetOAuth.Test/Mocks/TestExpiringMessage.cs b/src/DotNetOAuth.Test/Mocks/TestExpiringMessage.cs index d51e8ee..fbe0d9a 100644 --- a/src/DotNetOAuth.Test/Mocks/TestExpiringMessage.cs +++ b/src/DotNetOAuth.Test/Mocks/TestExpiringMessage.cs @@ -10,19 +10,22 @@ namespace DotNetOAuth.Test.Mocks { using System.Runtime.Serialization;
using DotNetOAuth.Messaging;
using DotNetOAuth.Messaging.Bindings;
+ using DotNetOAuth.Messaging.Reflection;
[DataContract(Namespace = Protocol.DataContractNamespaceV10)]
internal class TestExpiringMessage : TestSignedDirectedMessage, IExpiringProtocolMessage {
[DebuggerBrowsable(DebuggerBrowsableState.Never)]
private DateTime utcCreationDate;
+ internal TestExpiringMessage() { }
+
internal TestExpiringMessage(MessageTransport transport)
: base(transport) {
}
#region IExpiringProtocolMessage Members
- [DataMember(Name = "created_on")]
+ [MessagePart(Name = "created_on")]
DateTime IExpiringProtocolMessage.UtcCreationDate {
get { return this.utcCreationDate; }
set { this.utcCreationDate = value.ToUniversalTime(); }
diff --git a/src/DotNetOAuth.Test/Mocks/TestMessage.cs b/src/DotNetOAuth.Test/Mocks/TestMessage.cs index 1a9840d..ceb6dbd 100644 --- a/src/DotNetOAuth.Test/Mocks/TestMessage.cs +++ b/src/DotNetOAuth.Test/Mocks/TestMessage.cs @@ -9,6 +9,7 @@ namespace DotNetOAuth.Test.Mocks { using System.Collections.Generic;
using System.Runtime.Serialization;
using DotNetOAuth.Messaging;
+ using DotNetOAuth.Messaging.Reflection;
[DataContract(Namespace = Protocol.DataContractNamespaceV10)]
internal class TestMessage : IProtocolMessage {
@@ -23,20 +24,15 @@ namespace DotNetOAuth.Test.Mocks { this.transport = transport;
}
- [DataMember(Name = "age", IsRequired = true)]
- [MessagePart("age")]
+ [MessagePart(Name = "age", IsRequired = true)]
public int Age { get; set; }
- [DataMember]
- [MessagePart(Optional = true)]
+ [MessagePart]
public string Name { get; set; }
- [DataMember]
- [MessagePart(Optional = true)]
+ [MessagePart]
public string EmptyMember { get; set; }
- [DataMember]
- [MessagePart(Optional = true)]
+ [MessagePart]
public Uri Location { get; set; }
- [DataMember]
- [MessagePart(Optional = true)]
+ [MessagePart]
public DateTime Timestamp { get; set; }
#region IProtocolMessage Members
diff --git a/src/DotNetOAuth.Test/Mocks/TestReplayProtectedMessage.cs b/src/DotNetOAuth.Test/Mocks/TestReplayProtectedMessage.cs index b62957b..f6c3b39 100644 --- a/src/DotNetOAuth.Test/Mocks/TestReplayProtectedMessage.cs +++ b/src/DotNetOAuth.Test/Mocks/TestReplayProtectedMessage.cs @@ -8,16 +8,19 @@ namespace DotNetOAuth.Test.Mocks { using System.Runtime.Serialization;
using DotNetOAuth.Messaging;
using DotNetOAuth.Messaging.Bindings;
+ using DotNetOAuth.Messaging.Reflection;
[DataContract(Namespace = Protocol.DataContractNamespaceV10)]
internal class TestReplayProtectedMessage : TestExpiringMessage, IReplayProtectedProtocolMessage {
+ internal TestReplayProtectedMessage() { }
+
internal TestReplayProtectedMessage(MessageTransport transport)
: base(transport) {
}
#region IReplayProtectedProtocolMessage Members
- [DataMember(Name = "Nonce")]
+ [MessagePart(Name = "Nonce")]
string IReplayProtectedProtocolMessage.Nonce {
get;
set;
diff --git a/src/DotNetOAuth.Test/Mocks/TestSignedDirectedMessage.cs b/src/DotNetOAuth.Test/Mocks/TestSignedDirectedMessage.cs index d4d2536..bda4255 100644 --- a/src/DotNetOAuth.Test/Mocks/TestSignedDirectedMessage.cs +++ b/src/DotNetOAuth.Test/Mocks/TestSignedDirectedMessage.cs @@ -7,17 +7,20 @@ namespace DotNetOAuth.Test.Mocks {
using System.Runtime.Serialization;
using DotNetOAuth.Messaging;
+ using DotNetOAuth.Messaging.Reflection;
using DotNetOAuth.Messaging.Bindings;
[DataContract(Namespace = Protocol.DataContractNamespaceV10)]
internal class TestSignedDirectedMessage : TestDirectedMessage, ITamperResistantProtocolMessage {
+ internal TestSignedDirectedMessage() { }
+
internal TestSignedDirectedMessage(MessageTransport transport)
: base(transport) {
}
#region ISignedProtocolMessage Members
- [DataMember]
+ [MessagePart]
public string Signature {
get;
set;
diff --git a/src/DotNetOAuth.Test/OAuthChannelTests.cs b/src/DotNetOAuth.Test/OAuthChannelTests.cs index 2d00800..8011496 100644 --- a/src/DotNetOAuth.Test/OAuthChannelTests.cs +++ b/src/DotNetOAuth.Test/OAuthChannelTests.cs @@ -216,6 +216,7 @@ namespace DotNetOAuth.Test { HttpRequestInfo reqInfo = ConvertToRequestInfo(req, this.webRequestHandler.RequestEntityStream);
Assert.AreEqual(scheme == MessageScheme.PostRequest ? "POST" : "GET", reqInfo.HttpMethod);
var incomingMessage = this.channel.ReadFromRequest(reqInfo) as TestMessage;
+ Assert.IsNotNull(incomingMessage);
Assert.AreEqual(request.Age, incomingMessage.Age);
Assert.AreEqual(request.Name, incomingMessage.Name);
Assert.AreEqual(request.Location, incomingMessage.Location);
diff --git a/src/DotNetOAuth/DotNetOAuth.csproj b/src/DotNetOAuth/DotNetOAuth.csproj index c53f1e0..f5882ef 100644 --- a/src/DotNetOAuth/DotNetOAuth.csproj +++ b/src/DotNetOAuth/DotNetOAuth.csproj @@ -68,22 +68,20 @@ <ItemGroup>
<Compile Include="Consumer.cs" />
<Compile Include="IWebRequestHandler.cs" />
- <Compile Include="Messaging\MessagePartAttribute.cs" />
+ <Compile Include="Messaging\Reflection\MessagePartAttribute.cs" />
<Compile Include="Messaging\MessageProtection.cs" />
<Compile Include="Messaging\IChannelBindingElement.cs" />
<Compile Include="Messaging\Bindings\ReplayedMessageException.cs" />
<Compile Include="Messaging\Bindings\ExpiredMessageException.cs" />
- <Compile Include="Messaging\DataContractMemberComparer.cs" />
<Compile Include="Messaging\Bindings\InvalidSignatureException.cs" />
<Compile Include="Messaging\Bindings\IReplayProtectedProtocolMessage.cs" />
<Compile Include="Messaging\Bindings\IExpiringProtocolMessage.cs" />
- <Compile Include="Messaging\DictionaryXmlReader.cs" />
- <Compile Include="Messaging\DictionaryXmlWriter.cs" />
<Compile Include="Messaging\Channel.cs" />
<Compile Include="Messaging\HttpRequestInfo.cs" />
<Compile Include="Messaging\IDirectedProtocolMessage.cs" />
<Compile Include="Messaging\IMessageTypeProvider.cs" />
<Compile Include="Messaging\Bindings\ITamperResistantProtocolMessage.cs" />
+ <Compile Include="Messaging\MessageSerializer.cs" />
<Compile Include="Messaging\MessagingStrings.Designer.cs">
<AutoGen>True</AutoGen>
<DesignTime>True</DesignTime>
@@ -108,7 +106,6 @@ <Compile Include="Messaging\MessageTransport.cs" />
<Compile Include="OAuthMessageTypeProvider.cs" />
<Compile Include="Messaging\ProtocolException.cs" />
- <Compile Include="Messaging\MessageSerializer.cs" />
<Compile Include="Properties\AssemblyInfo.cs" />
<Compile Include="StandardWebRequestHandler.cs" />
<Compile Include="Util.cs" />
diff --git a/src/DotNetOAuth/Messaging/Channel.cs b/src/DotNetOAuth/Messaging/Channel.cs index aebf6e4..e3b7a73 100644 --- a/src/DotNetOAuth/Messaging/Channel.cs +++ b/src/DotNetOAuth/Messaging/Channel.cs @@ -179,7 +179,10 @@ namespace DotNetOAuth.Messaging { /// <returns>The deserialized message, if one is found. Null otherwise.</returns>
protected internal IProtocolMessage ReadFromRequest(HttpRequestInfo httpRequest) {
IProtocolMessage requestMessage = this.ReadFromRequestInternal(httpRequest);
- this.VerifyMessageAfterReceiving(requestMessage);
+ if (requestMessage != null) {
+ this.VerifyMessageAfterReceiving(requestMessage);
+ }
+
return requestMessage;
}
@@ -521,6 +524,11 @@ namespace DotNetOAuth.Messaging { if ((message.RequiredProtection & appliedProtection) != message.RequiredProtection) {
throw new UnprotectedMessageException(message, appliedProtection);
}
+
+ // TODO: call MessagePart.IsValidValue()
+
+
+ message.EnsureValidMessage();
}
}
}
diff --git a/src/DotNetOAuth/Messaging/DataContractMemberComparer.cs b/src/DotNetOAuth/Messaging/DataContractMemberComparer.cs deleted file mode 100644 index 4061b57..0000000 --- a/src/DotNetOAuth/Messaging/DataContractMemberComparer.cs +++ /dev/null @@ -1,103 +0,0 @@ -//-----------------------------------------------------------------------
-// <copyright file="DataContractMemberComparer.cs" company="Andrew Arnott">
-// Copyright (c) Andrew Arnott. All rights reserved.
-// </copyright>
-//-----------------------------------------------------------------------
-
-namespace DotNetOAuth.Messaging {
- using System;
- using System.Collections.Generic;
- using System.Diagnostics;
- using System.IO;
- using System.Linq;
- using System.Reflection;
- using System.Runtime.Serialization;
- using System.Xml;
- using System.Xml.Linq;
-
- /// <summary>
- /// A sorting tool to arrange fields in an order expected by the <see cref="DataContractSerializer"/>.
- /// </summary>
- internal class DataContractMemberComparer : IComparer<string> {
- /// <summary>
- /// The cached calculated inheritance ranking of every [DataMember] member of a type.
- /// </summary>
- private Dictionary<string, int> ranking;
-
- /// <summary>
- /// Initializes a new instance of the <see cref="DataContractMemberComparer"/> class.
- /// </summary>
- /// <param name="dataContractType">The data contract type that will be deserialized to.</param>
- internal DataContractMemberComparer(Type dataContractType) {
- // The elements must be serialized in inheritance rank and alphabetical order
- // so the DataContractSerializer will see them.
- this.ranking = GetDataMemberInheritanceRanking(dataContractType);
- }
-
- #region IComparer<string> Members
-
- /// <summary>
- /// Compares to fields and decides what order they should appear in.
- /// </summary>
- /// <param name="field1">The first field.</param>
- /// <param name="field2">The second field.</param>
- /// <returns>-1 if the first field should appear first, 0 if it doesn't matter, 1 if it should appear last.</returns>
- public int Compare(string field1, string field2) {
- int rank1, rank2;
- bool field1Valid = this.ranking.TryGetValue(field1, out rank1);
- bool field2Valid = this.ranking.TryGetValue(field2, out rank2);
-
- // If both fields are invalid, we don't care about the order.
- if (!field1Valid && !field2Valid) {
- return 0;
- }
-
- // If exactly one is valid, put that one first.
- if (field1Valid ^ field2Valid) {
- return field1Valid ? -1 : 1;
- }
-
- // First compare their inheritance ranking.
- if (rank1 != rank2) {
- // We want DESCENDING rank order, putting the members defined in the most
- // base class first.
- return -rank1.CompareTo(rank2);
- }
-
- // Finally sort alphabetically with case sensitivity.
- return string.CompareOrdinal(field1, field2);
- }
-
- #endregion
-
- /// <summary>
- /// Generates a dictionary of field name and inheritance rankings for a given DataContract type.
- /// </summary>
- /// <param name="type">The type to generate member rankings for.</param>
- /// <returns>The generated dictionary.</returns>
- private static Dictionary<string, int> GetDataMemberInheritanceRanking(Type type) {
- Debug.Assert(type != null, "type == null");
- var ranking = new Dictionary<string, int>();
-
- // TODO: review partial trust scenarios and this NonPublic flag.
- Type currentType = type;
- int rank = 0;
- do {
- foreach (MemberInfo member in currentType.GetMembers(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.DeclaredOnly)) {
- if (member is PropertyInfo || member is FieldInfo) {
- DataMemberAttribute dataMemberAttribute = member.GetCustomAttributes(typeof(DataMemberAttribute), true).OfType<DataMemberAttribute>().FirstOrDefault();
- if (dataMemberAttribute != null) {
- string name = dataMemberAttribute.Name ?? member.Name;
- ranking.Add(name, rank);
- }
- }
- }
-
- rank++;
- currentType = currentType.BaseType;
- } while (currentType != null);
-
- return ranking;
- }
- }
-}
diff --git a/src/DotNetOAuth/Messaging/DictionaryXmlReader.cs b/src/DotNetOAuth/Messaging/DictionaryXmlReader.cs deleted file mode 100644 index 11d553b..0000000 --- a/src/DotNetOAuth/Messaging/DictionaryXmlReader.cs +++ /dev/null @@ -1,92 +0,0 @@ -//-----------------------------------------------------------------------
-// <copyright file="DictionaryXmlReader.cs" company="Andrew Arnott">
-// Copyright (c) Andrew Arnott. All rights reserved.
-// </copyright>
-//-----------------------------------------------------------------------
-
-namespace DotNetOAuth.Messaging {
- using System;
- using System.Collections.Generic;
- using System.Diagnostics;
- using System.IO;
- using System.Linq;
- using System.Reflection;
- using System.Runtime.Serialization;
- using System.Xml;
- using System.Xml.Linq;
-
- /// <summary>
- /// An XmlReader-looking object that actually reads from a dictionary.
- /// </summary>
- internal class DictionaryXmlReader {
- /// <summary>
- /// Creates an XmlReader that reads data out of a dictionary instead of XML.
- /// </summary>
- /// <param name="rootElement">The name of the root XML element.</param>
- /// <param name="fieldSorter">The field sorter so that the XmlReader generates xml elements in the order required by the <see cref="DataContractSerializer"/>.</param>
- /// <param name="fields">The dictionary to read data from.</param>
- /// <returns>The XmlReader that will read the data out of the given dictionary.</returns>
- internal static XmlReader Create(XName rootElement, IComparer<string> fieldSorter, IDictionary<string, string> fields) {
- if (rootElement == null) {
- throw new ArgumentNullException("rootElement");
- }
- if (fieldSorter == null) {
- throw new ArgumentNullException("fieldSorter");
- }
- if (fields == null) {
- throw new ArgumentNullException("fields");
- }
-
- return CreateRoundtripReader(rootElement, fieldSorter, fields);
- }
-
- /// <summary>
- /// Creates an <see cref="XmlReader"/> that will read values out of a dictionary.
- /// </summary>
- /// <param name="rootElement">The surrounding root XML element to generate.</param>
- /// <param name="fieldSorter">The field sorter so that the XmlReader generates xml elements in the order required by the <see cref="DataContractSerializer"/>.</param>
- /// <param name="fields">The dictionary to list values from.</param>
- /// <returns>The generated <see cref="XmlReader"/>.</returns>
- private static XmlReader CreateRoundtripReader(XName rootElement, IComparer<string> fieldSorter, IDictionary<string, string> fields) {
- Debug.Assert(rootElement != null, "rootElement == null");
- Debug.Assert(fieldSorter != null, "fieldSorter == null");
- Debug.Assert(fields != null, "fields == null");
-
- MemoryStream stream = new MemoryStream();
- XmlWriter writer = XmlWriter.Create(stream);
- SerializeDictionaryToXml(writer, rootElement, fieldSorter, fields);
- writer.Flush();
- stream.Seek(0, SeekOrigin.Begin);
-
- // For debugging purposes.
- StreamReader sr = new StreamReader(stream);
- Trace.WriteLine(sr.ReadToEnd());
- stream.Seek(0, SeekOrigin.Begin);
-
- return XmlReader.Create(stream);
- }
-
- /// <summary>
- /// Writes out the values in a dictionary as XML.
- /// </summary>
- /// <param name="writer">The <see cref="XmlWriter"/> to write out the XML to.</param>
- /// <param name="rootElement">The name of the root element to use to surround the dictionary values.</param>
- /// <param name="fieldSorter">The field sorter so that the XmlReader generates xml elements in the order required by the <see cref="DataContractSerializer"/>.</param>
- /// <param name="fields">The dictionary with values to serialize.</param>
- private static void SerializeDictionaryToXml(XmlWriter writer, XName rootElement, IComparer<string> fieldSorter, IDictionary<string, string> fields) {
- Debug.Assert(writer != null, "writer == null");
- Debug.Assert(rootElement != null, "rootElement == null");
- Debug.Assert(fields != null, "fields == null");
-
- writer.WriteStartElement(rootElement.LocalName, rootElement.NamespaceName);
-
- foreach (var pair in fields.OrderBy(pair => pair.Key, fieldSorter)) {
- writer.WriteStartElement(pair.Key, rootElement.NamespaceName);
- writer.WriteValue(pair.Value);
- writer.WriteEndElement();
- }
-
- writer.WriteEndElement();
- }
- }
-}
diff --git a/src/DotNetOAuth/Messaging/DictionaryXmlWriter.cs b/src/DotNetOAuth/Messaging/DictionaryXmlWriter.cs deleted file mode 100644 index 3be043f..0000000 --- a/src/DotNetOAuth/Messaging/DictionaryXmlWriter.cs +++ /dev/null @@ -1,273 +0,0 @@ -//-----------------------------------------------------------------------
-// <copyright file="DictionaryXmlWriter.cs" company="Andrew Arnott">
-// Copyright (c) Andrew Arnott. All rights reserved.
-// </copyright>
-//-----------------------------------------------------------------------
-
-namespace DotNetOAuth.Messaging {
- using System;
- using System.Collections.Generic;
- using System.Text;
- using System.Xml;
-
- /// <summary>
- /// An XmlWriter-looking object that actually saves data to a dictionary.
- /// </summary>
- internal class DictionaryXmlWriter {
- /// <summary>
- /// Creates an <see cref="XmlWriter"/> that actually writes to an IDictionary<string, string> instance.
- /// </summary>
- /// <param name="dictionary">The dictionary to save the written XML to.</param>
- /// <returns>The XmlWriter that will save data to the given dictionary.</returns>
- internal static XmlWriter Create(IDictionary<string, string> dictionary) {
- return new PseudoXmlWriter(dictionary);
- }
-
- /// <summary>
- /// Writes out a dictionary as if it were XML.
- /// </summary>
- private class PseudoXmlWriter : XmlWriter {
- /// <summary>
- /// The dictionary to write values to.
- /// </summary>
- private IDictionary<string, string> dictionary;
-
- /// <summary>
- /// The key being written at the moment.
- /// </summary>
- private string key;
-
- /// <summary>
- /// The value being written out at the moment.
- /// </summary>
- private StringBuilder value = new StringBuilder();
-
- /// <summary>
- /// Initializes a new instance of the <see cref="PseudoXmlWriter"/> class.
- /// </summary>
- /// <param name="dictionary">The dictionary that will be written to.</param>
- internal PseudoXmlWriter(IDictionary<string, string> dictionary) {
- if (dictionary == null) {
- throw new ArgumentNullException("dictionary");
- }
-
- this.dictionary = dictionary;
- }
-
- /// <summary>
- /// Gets the spoofed state of the <see cref="XmlWriter"/>.
- /// </summary>
- public override WriteState WriteState {
- get { return WriteState.Element; }
- }
-
- /// <summary>
- /// Prepares to write out a new key/value pair with the given key name to the dictionary.
- /// </summary>
- /// <param name="prefix">This parameter is ignored.</param>
- /// <param name="localName">The key to store in the dictionary.</param>
- /// <param name="ns">This parameter is ignored.</param>
- public override void WriteStartElement(string prefix, string localName, string ns) {
- this.key = localName;
- this.value.Length = 0;
- }
-
- /// <summary>
- /// Appends some text to the value that is to be stored in the dictionary.
- /// </summary>
- /// <param name="text">The text to append to the value.</param>
- public override void WriteString(string text) {
- if (!string.IsNullOrEmpty(this.key)) {
- this.value.Append(text);
- }
- }
-
- /// <summary>
- /// Writes out a completed key/value to the dictionary.
- /// </summary>
- public override void WriteEndElement() {
- if (this.key != null) {
- this.dictionary[this.key] = this.value.ToString();
- this.key = null;
- this.value.Length = 0;
- }
- }
-
- /// <summary>
- /// Clears the internal key/value building state.
- /// </summary>
- /// <param name="prefix">This parameter is ignored.</param>
- /// <param name="localName">This parameter is ignored.</param>
- /// <param name="ns">This parameter is ignored.</param>
- public override void WriteStartAttribute(string prefix, string localName, string ns) {
- this.key = null;
- }
-
- /// <summary>
- /// This method does not do anything.
- /// </summary>
- public override void WriteEndAttribute() { }
-
- /// <summary>
- /// This method does not do anything.
- /// </summary>
- public override void Close() { }
-
- #region Unimplemented methods
-
- /// <summary>
- /// Throws <see cref="NotImplementedException"/>.
- /// </summary>
- public override void Flush() {
- throw new NotImplementedException();
- }
-
- /// <summary>
- /// Throws <see cref="NotImplementedException"/>.
- /// </summary>
- /// <param name="ns">This parameter is ignored.</param>
- /// <returns>None, since an exception is always thrown.</returns>
- public override string LookupPrefix(string ns) {
- throw new NotImplementedException();
- }
-
- /// <summary>
- /// Throws <see cref="NotImplementedException"/>.
- /// </summary>
- /// <param name="buffer">This parameter is ignored.</param>
- /// <param name="index">This parameter is ignored.</param>
- /// <param name="count">This parameter is ignored.</param>
- public override void WriteBase64(byte[] buffer, int index, int count) {
- throw new NotImplementedException();
- }
-
- /// <summary>
- /// Throws <see cref="NotImplementedException"/>.
- /// </summary>
- /// <param name="text">This parameter is ignored.</param>
- public override void WriteCData(string text) {
- throw new NotImplementedException();
- }
-
- /// <summary>
- /// Throws <see cref="NotImplementedException"/>.
- /// </summary>
- /// <param name="ch">This parameter is ignored.</param>
- public override void WriteCharEntity(char ch) {
- throw new NotImplementedException();
- }
-
- /// <summary>
- /// Throws <see cref="NotImplementedException"/>.
- /// </summary>
- /// <param name="buffer">This parameter is ignored.</param>
- /// <param name="index">This parameter is ignored.</param>
- /// <param name="count">This parameter is ignored.</param>
- public override void WriteChars(char[] buffer, int index, int count) {
- throw new NotImplementedException();
- }
-
- /// <summary>
- /// Throws <see cref="NotImplementedException"/>.
- /// </summary>
- /// <param name="text">This parameter is ignored.</param>
- public override void WriteComment(string text) {
- throw new NotImplementedException();
- }
-
- /// <summary>
- /// Throws <see cref="NotImplementedException"/>.
- /// </summary>
- /// <param name="name">This parameter is ignored.</param>
- /// <param name="pubid">This parameter is ignored.</param>
- /// <param name="sysid">This parameter is ignored.</param>
- /// <param name="subset">This parameter is ignored.</param>
- public override void WriteDocType(string name, string pubid, string sysid, string subset) {
- throw new NotImplementedException();
- }
-
- /// <summary>
- /// Throws <see cref="NotImplementedException"/>.
- /// </summary>
- public override void WriteEndDocument() {
- throw new NotImplementedException();
- }
-
- /// <summary>
- /// Throws <see cref="NotImplementedException"/>.
- /// </summary>
- /// <param name="name">This parameter is ignored.</param>
- public override void WriteEntityRef(string name) {
- throw new NotImplementedException();
- }
-
- /// <summary>
- /// Throws <see cref="NotImplementedException"/>.
- /// </summary>
- public override void WriteFullEndElement() {
- throw new NotImplementedException();
- }
-
- /// <summary>
- /// Throws <see cref="NotImplementedException"/>.
- /// </summary>
- /// <param name="name">This parameter is ignored.</param>
- /// <param name="text">This parameter is ignored.</param>
- public override void WriteProcessingInstruction(string name, string text) {
- throw new NotImplementedException();
- }
-
- /// <summary>
- /// Throws <see cref="NotImplementedException"/>.
- /// </summary>
- /// <param name="data">This parameter is ignored.</param>
- public override void WriteRaw(string data) {
- throw new NotImplementedException();
- }
-
- /// <summary>
- /// Throws <see cref="NotImplementedException"/>.
- /// </summary>
- /// <param name="buffer">This parameter is ignored.</param>
- /// <param name="index">This parameter is ignored.</param>
- /// <param name="count">This parameter is ignored.</param>
- public override void WriteRaw(char[] buffer, int index, int count) {
- throw new NotImplementedException();
- }
-
- /// <summary>
- /// Throws <see cref="NotImplementedException"/>.
- /// </summary>
- /// <param name="standalone">This parameter is ignored.</param>
- public override void WriteStartDocument(bool standalone) {
- throw new NotImplementedException();
- }
-
- /// <summary>
- /// Throws <see cref="NotImplementedException"/>.
- /// </summary>
- public override void WriteStartDocument() {
- throw new NotImplementedException();
- }
-
- /// <summary>
- /// Throws <see cref="NotImplementedException"/>.
- /// </summary>
- /// <param name="lowChar">This parameter is ignored.</param>
- /// <param name="highChar">This parameter is ignored.</param>
- public override void WriteSurrogateCharEntity(char lowChar, char highChar) {
- throw new NotImplementedException();
- }
-
- /// <summary>
- /// Throws <see cref="NotImplementedException"/>.
- /// </summary>
- /// <param name="ws">This parameter is ignored.</param>
- public override void WriteWhitespace(string ws) {
- throw new NotImplementedException();
- }
-
- #endregion
- }
- }
-}
diff --git a/src/DotNetOAuth/Messaging/MessageSerializer.cs b/src/DotNetOAuth/Messaging/MessageSerializer.cs index ad3a12d..4a06076 100644 --- a/src/DotNetOAuth/Messaging/MessageSerializer.cs +++ b/src/DotNetOAuth/Messaging/MessageSerializer.cs @@ -19,34 +19,12 @@ namespace DotNetOAuth.Messaging { /// </summary>
internal class MessageSerializer {
/// <summary>
- /// The serializer that will be used as a reflection engine to extract
- /// the OAuth message properties out of their containing <see cref="IProtocolMessage"/>
- /// objects.
- /// </summary>
- private readonly DataContractSerializer serializer;
-
- /// <summary>
/// The specific <see cref="IProtocolMessage"/>-derived type
/// that will be serialized and deserialized using this class.
/// </summary>
private readonly Type messageType;
/// <summary>
- /// An AppDomain-wide cache of shared serializers for optimization purposes.
- /// </summary>
- private static Dictionary<Type, MessageSerializer> prebuiltSerializers = new Dictionary<Type, MessageSerializer>();
-
- /// <summary>
- /// Backing field for the <see cref="RootElement"/> property
- /// </summary>
- private XName rootElement;
-
- /// <summary>
- /// A field sorter that puts fields in the right order for the <see cref="DataContractSerializer"/>.
- /// </summary>
- private IComparer<string> fieldSorter;
-
- /// <summary>
/// Initializes a new instance of the MessageSerializer class.
/// </summary>
/// <param name="messageType">The specific <see cref="IProtocolMessage"/>-derived type
@@ -65,71 +43,19 @@ namespace DotNetOAuth.Messaging { }
this.messageType = messageType;
- this.serializer = new DataContractSerializer(
- messageType, this.RootElement.LocalName, this.RootElement.NamespaceName);
- this.fieldSorter = new DataContractMemberComparer(messageType);
}
/// <summary>
- /// Gets the XML element that is used to surround all the XML values from the dictionary.
- /// </summary>
- private XName RootElement {
- get {
- if (this.rootElement == null) {
- DataContractAttribute attribute;
- try {
- attribute = this.messageType.GetCustomAttributes(typeof(DataContractAttribute), false).OfType<DataContractAttribute>().Single();
- } catch (InvalidOperationException ex) {
- throw new ProtocolException(
- string.Format(
- CultureInfo.CurrentCulture,
- MessagingStrings.DataContractMissingFromMessageType,
- this.messageType.FullName),
- ex);
- }
-
- if (attribute.Namespace == null) {
- throw new ProtocolException(string.Format(
- CultureInfo.CurrentCulture,
- MessagingStrings.DataContractMissingNamespace,
- this.messageType.FullName));
- }
-
- this.rootElement = XName.Get("root", attribute.Namespace);
- }
-
- return this.rootElement;
- }
- }
-
- /// <summary>
- /// Returns a message serializer from a reusable collection of serializers.
+ /// Creates or reuses a message serializer for a given message type.
/// </summary>
/// <param name="messageType">The type of message that will be serialized/deserialized.</param>
- /// <returns>A previously created serializer if one exists, or a newly created one.</returns>
+ /// <returns>A message serializer for the given message type.</returns>
internal static MessageSerializer Get(Type messageType) {
if (messageType == null) {
throw new ArgumentNullException("messageType");
}
- // We do this as efficiently as possible by first trying to fetch the
- // serializer out of the dictionary without taking a lock.
- MessageSerializer serializer;
- if (prebuiltSerializers.TryGetValue(messageType, out serializer)) {
- return serializer;
- }
-
- // Since it wasn't there, we'll be trying to write to the dictionary so
- // we take a lock and try reading again first, then creating the serializer
- // and storing it when we're sure it absolutely necessary.
- lock (prebuiltSerializers) {
- if (prebuiltSerializers.TryGetValue(messageType, out serializer)) {
- return serializer;
- }
- serializer = new MessageSerializer(messageType);
- prebuiltSerializers.Add(messageType, serializer);
- }
- return serializer;
+ return new MessageSerializer(messageType);
}
/// <summary>
@@ -142,28 +68,7 @@ namespace DotNetOAuth.Messaging { throw new ArgumentNullException("message");
}
- var fields = new Dictionary<string, string>(StringComparer.Ordinal);
- this.Serialize(fields, message);
- return fields;
- }
-
- /// <summary>
- /// Saves the [DataMember] properties of a message to an existing dictionary.
- /// </summary>
- /// <param name="fields">The dictionary to save values to.</param>
- /// <param name="message">The message to pull values from.</param>
- internal void Serialize(IDictionary<string, string> fields, IProtocolMessage message) {
- if (fields == null) {
- throw new ArgumentNullException("fields");
- }
- if (message == null) {
- throw new ArgumentNullException("message");
- }
-
- message.EnsureValidMessage();
- using (XmlWriter writer = DictionaryXmlWriter.Create(fields)) {
- this.serializer.WriteObjectContent(writer, message);
- }
+ return new Reflection.MessageDictionary(message);
}
/// <summary>
@@ -176,13 +81,15 @@ namespace DotNetOAuth.Messaging { throw new ArgumentNullException("fields");
}
- var reader = DictionaryXmlReader.Create(this.RootElement, this.fieldSorter, fields);
- IProtocolMessage result;
+ IProtocolMessage result ;
try {
- result = (IProtocolMessage)this.serializer.ReadObject(reader, false);
- } catch (SerializationException ex) {
- // Missing required fields is one cause of this exception.
- throw new ProtocolException(Strings.InvalidIncomingMessage, ex);
+ result = (IProtocolMessage)Activator.CreateInstance(this.messageType, true);
+ } catch (MissingMethodException ex) {
+ throw new ProtocolException("Failed to instantiate type " + this.messageType.FullName, ex);
+ }
+ foreach (var pair in fields) {
+ IDictionary<string, string> dictionary = new Reflection.MessageDictionary(result);
+ dictionary.Add(pair);
}
result.EnsureValidMessage();
return result;
diff --git a/src/DotNetOAuth/Messaging/Reflection/MessageDictionary.cs b/src/DotNetOAuth/Messaging/Reflection/MessageDictionary.cs index eca7a9b..9490894 100644 --- a/src/DotNetOAuth/Messaging/Reflection/MessageDictionary.cs +++ b/src/DotNetOAuth/Messaging/Reflection/MessageDictionary.cs @@ -8,6 +8,7 @@ namespace DotNetOAuth.Messaging.Reflection { using System;
using System.Collections;
using System.Collections.Generic;
+ using System.Diagnostics;
/// <summary>
/// Wraps an <see cref="IProtocolMessage"/> instance in a dictionary that
@@ -30,10 +31,14 @@ namespace DotNetOAuth.Messaging.Reflection { #region IDictionary<string,string> Members
- void IDictionary<string, string>.Add(string key, string value) {
+ public void Add(string key, string value) {
+ if (value == null) {
+ throw new ArgumentNullException("value");
+ }
+
MessagePart part;
if (this.description.Mapping.TryGetValue(key, out part)) {
- if (part.GetValue(this.message) != null) {
+ if (part.IsNondefaultValueSet(this.message)) {
throw new ArgumentException(MessagingStrings.KeyAlreadyExists);
}
part.SetValue(this.message, value);
@@ -42,27 +47,30 @@ namespace DotNetOAuth.Messaging.Reflection { }
}
- bool IDictionary<string, string>.ContainsKey(string key) {
- return this.message.ExtraData.ContainsKey(key) || this.description.Mapping.ContainsKey(key);
+ public bool ContainsKey(string key) {
+ return this.message.ExtraData.ContainsKey(key) ||
+ (this.description.Mapping.ContainsKey(key) && this.description.Mapping[key].GetValue(this.message) != null);
}
- ICollection<string> IDictionary<string, string>.Keys {
+ public ICollection<string> Keys {
get {
- string[] keys = new string[this.message.ExtraData.Count + this.description.Mapping.Count];
- int i = 0;
- foreach (string key in this.description.Mapping.Keys) {
- keys[i++] = key;
+ List<string> keys = new List<string>(this.message.ExtraData.Count + this.description.Mapping.Count);
+ foreach (var pair in this.description.Mapping) {
+ // Don't include keys with null values, but default values for structs is ok
+ if (pair.Value.GetValue(this.message) != null) {
+ keys.Add(pair.Key);
+ }
}
foreach (string key in this.message.ExtraData.Keys) {
- keys[i++] = key;
+ keys.Add(key);
}
- return keys;
+ return keys.AsReadOnly();
}
}
- bool IDictionary<string, string>.Remove(string key) {
+ public bool Remove(string key) {
if (this.message.ExtraData.Remove(key)) {
return true;
} else {
@@ -77,7 +85,7 @@ namespace DotNetOAuth.Messaging.Reflection { }
}
- bool IDictionary<string, string>.TryGetValue(string key, out string value) {
+ public bool TryGetValue(string key, out string value) {
MessagePart part;
if (this.description.Mapping.TryGetValue(key, out part)) {
value = part.GetValue(this.message);
@@ -86,26 +94,29 @@ namespace DotNetOAuth.Messaging.Reflection { return this.message.ExtraData.TryGetValue(key, out value);
}
- ICollection<string> IDictionary<string, string>.Values {
+ public ICollection<string> Values {
get {
- string[] values = new string[this.message.ExtraData.Count + this.description.Mapping.Count];
- int i = 0;
+ List<string> values = new List<string>(this.message.ExtraData.Count + this.description.Mapping.Count);
foreach (MessagePart part in this.description.Mapping.Values) {
- values[i++] = part.GetValue(this.message);
+ if (part.GetValue(this.message) != null) {
+ values.Add(part.GetValue(this.message));
+ }
}
foreach (string value in this.message.ExtraData.Values) {
- values[i++] = value;
+ Debug.Assert(value != null, "Null values should never be allowed in the extra data dictionary.");
+ values.Add(value);
}
- return values;
+ return values.AsReadOnly();
}
}
- string IDictionary<string, string>.this[string key] {
+ public string this[string key] {
get {
MessagePart part;
if (this.description.Mapping.TryGetValue(key, out part)) {
+ // Never throw KeyNotFoundException for declared properties.
return part.GetValue(this.message);
} else {
return this.message.ExtraData[key];
@@ -117,7 +128,11 @@ namespace DotNetOAuth.Messaging.Reflection { if (this.description.Mapping.TryGetValue(key, out part)) {
part.SetValue(this.message, value);
} else {
- this.message.ExtraData[key] = value;
+ if (value == null) {
+ this.message.ExtraData.Remove(key);
+ } else {
+ this.message.ExtraData[key] = value;
+ }
}
}
}
@@ -126,17 +141,17 @@ namespace DotNetOAuth.Messaging.Reflection { #region ICollection<KeyValuePair<string,string>> Members
- void ICollection<KeyValuePair<string, string>>.Add(KeyValuePair<string, string> item) {
- ((IDictionary<string, string>)this).Add(item.Key, item.Value);
+ public void Add(KeyValuePair<string, string> item) {
+ this.Add(item.Key, item.Value);
}
- void ICollection<KeyValuePair<string, string>>.Clear() {
- foreach (string key in ((IDictionary<string, string>)this).Keys) {
- ((IDictionary<string, string>)this).Remove(key);
+ public void Clear() {
+ foreach (string key in this.Keys) {
+ this.Remove(key);
}
}
- bool ICollection<KeyValuePair<string, string>>.Contains(KeyValuePair<string, string> item) {
+ public bool Contains(KeyValuePair<string, string> item) {
MessagePart part;
if (this.description.Mapping.TryGetValue(item.Key, out part)) {
return string.Equals(part.GetValue(this.message), item.Value, StringComparison.Ordinal);
@@ -151,15 +166,15 @@ namespace DotNetOAuth.Messaging.Reflection { }
}
- int ICollection<KeyValuePair<string, string>>.Count {
- get { return this.description.Mapping.Count + this.message.ExtraData.Count; }
+ public int Count {
+ get { return this.Keys.Count; }
}
bool ICollection<KeyValuePair<string, string>>.IsReadOnly {
get { return false; }
}
- bool ICollection<KeyValuePair<string, string>>.Remove(KeyValuePair<string, string> item) {
+ public bool Remove(KeyValuePair<string, string> item) {
// We use contains because that checks that the value is equal as well.
if (((ICollection<KeyValuePair<string, string>>)this).Contains(item)) {
((IDictionary<string, string>)this).Remove(item.Key);
@@ -172,13 +187,9 @@ namespace DotNetOAuth.Messaging.Reflection { #region IEnumerable<KeyValuePair<string,string>> Members
- IEnumerator<KeyValuePair<string, string>> IEnumerable<KeyValuePair<string, string>>.GetEnumerator() {
- foreach (MessagePart part in this.description.Mapping.Values) {
- yield return new KeyValuePair<string, string>(part.Name, part.GetValue(this.message));
- }
-
- foreach (var pair in this.message.ExtraData) {
- yield return pair;
+ public IEnumerator<KeyValuePair<string, string>> GetEnumerator() {
+ foreach (string key in Keys) {
+ yield return new KeyValuePair<string, string>(key, this[key]);
}
}
diff --git a/src/DotNetOAuth/Messaging/Reflection/MessagePart.cs b/src/DotNetOAuth/Messaging/Reflection/MessagePart.cs index 09f959b..2005370 100644 --- a/src/DotNetOAuth/Messaging/Reflection/MessagePart.cs +++ b/src/DotNetOAuth/Messaging/Reflection/MessagePart.cs @@ -19,6 +19,10 @@ namespace DotNetOAuth.Messaging.Reflection { private FieldInfo field;
+ private Type memberDeclaredType;
+
+ private object defaultMemberValue;
+
static MessagePart() {
Map<Uri>(uri => uri.AbsoluteUri, str => new Uri(str));
}
@@ -40,10 +44,11 @@ namespace DotNetOAuth.Messaging.Reflection { this.Name = attribute.Name ?? member.Name;
this.Signed = attribute.Signed;
- this.IsRequired = !attribute.Optional;
+ this.IsRequired = attribute.IsRequired;
+ this.memberDeclaredType = (this.field != null) ? this.field.FieldType : this.property.PropertyType;
+ this.defaultMemberValue = deriveDefaultValue(this.memberDeclaredType);
- if (!converters.TryGetValue(member.DeclaringType, out this.converter)) {
- Type memberDeclaredType = (this.field != null) ? this.field.FieldType : this.property.PropertyType;
+ if (!converters.TryGetValue(this.memberDeclaredType, out this.converter)) {
this.converter = new ValueMapping(
obj => obj != null ? obj.ToString() : null,
str => str != null ? Convert.ChangeType(str, memberDeclaredType) : null);
@@ -73,10 +78,34 @@ namespace DotNetOAuth.Messaging.Reflection { }
internal string GetValue(IProtocolMessage message) {
+ return this.ToString(this.GetValueAsObject(message));
+ }
+
+ internal bool IsNondefaultValueSet(IProtocolMessage message) {
+ if (this.memberDeclaredType.IsValueType) {
+ return !GetValueAsObject(message).Equals(this.defaultMemberValue);
+ } else {
+ return this.defaultMemberValue != GetValueAsObject(message);
+ }
+ }
+
+ internal bool IsValidValue(IProtocolMessage message) {
+ return true;
+ }
+
+ private static object deriveDefaultValue(Type type) {
+ if (type.IsValueType) {
+ return Activator.CreateInstance(type);
+ } else {
+ return null;
+ }
+ }
+
+ private object GetValueAsObject(IProtocolMessage message) {
if (this.property != null) {
- return this.ToString(this.property.GetValue(message, null));
+ return this.property.GetValue(message, null);
} else {
- return this.ToString(this.field.GetValue(message));
+ return this.field.GetValue(message);
}
}
diff --git a/src/DotNetOAuth/Messaging/MessagePartAttribute.cs b/src/DotNetOAuth/Messaging/Reflection/MessagePartAttribute.cs index 6620f7f..e5c429e 100644 --- a/src/DotNetOAuth/Messaging/MessagePartAttribute.cs +++ b/src/DotNetOAuth/Messaging/Reflection/MessagePartAttribute.cs @@ -4,7 +4,7 @@ // </copyright>
//-----------------------------------------------------------------------
-namespace DotNetOAuth.Messaging {
+namespace DotNetOAuth.Messaging.Reflection {
using System;
using System.Net.Security;
using System.Reflection;
@@ -28,7 +28,7 @@ namespace DotNetOAuth.Messaging { public ProtectionLevel Signed { get; set; }
- public bool Optional { get; set; }
+ public bool IsRequired { get; set; }
internal void Initialize(MemberInfo member) {
if (member == null) {
|