summaryrefslogtreecommitdiffstats
path: root/src/DotNetOAuth/Messaging/Reflection
diff options
context:
space:
mode:
Diffstat (limited to 'src/DotNetOAuth/Messaging/Reflection')
-rw-r--r--src/DotNetOAuth/Messaging/Reflection/MessageDictionary.cs85
-rw-r--r--src/DotNetOAuth/Messaging/Reflection/MessagePart.cs39
-rw-r--r--src/DotNetOAuth/Messaging/Reflection/MessagePartAttribute.cs47
3 files changed, 129 insertions, 42 deletions
diff --git a/src/DotNetOAuth/Messaging/Reflection/MessageDictionary.cs b/src/DotNetOAuth/Messaging/Reflection/MessageDictionary.cs
index eca7a9b..9490894 100644
--- a/src/DotNetOAuth/Messaging/Reflection/MessageDictionary.cs
+++ b/src/DotNetOAuth/Messaging/Reflection/MessageDictionary.cs
@@ -8,6 +8,7 @@ namespace DotNetOAuth.Messaging.Reflection {
using System;
using System.Collections;
using System.Collections.Generic;
+ using System.Diagnostics;
/// <summary>
/// Wraps an <see cref="IProtocolMessage"/> instance in a dictionary that
@@ -30,10 +31,14 @@ namespace DotNetOAuth.Messaging.Reflection {
#region IDictionary<string,string> Members
- void IDictionary<string, string>.Add(string key, string value) {
+ public void Add(string key, string value) {
+ if (value == null) {
+ throw new ArgumentNullException("value");
+ }
+
MessagePart part;
if (this.description.Mapping.TryGetValue(key, out part)) {
- if (part.GetValue(this.message) != null) {
+ if (part.IsNondefaultValueSet(this.message)) {
throw new ArgumentException(MessagingStrings.KeyAlreadyExists);
}
part.SetValue(this.message, value);
@@ -42,27 +47,30 @@ namespace DotNetOAuth.Messaging.Reflection {
}
}
- bool IDictionary<string, string>.ContainsKey(string key) {
- return this.message.ExtraData.ContainsKey(key) || this.description.Mapping.ContainsKey(key);
+ public bool ContainsKey(string key) {
+ return this.message.ExtraData.ContainsKey(key) ||
+ (this.description.Mapping.ContainsKey(key) && this.description.Mapping[key].GetValue(this.message) != null);
}
- ICollection<string> IDictionary<string, string>.Keys {
+ public ICollection<string> Keys {
get {
- string[] keys = new string[this.message.ExtraData.Count + this.description.Mapping.Count];
- int i = 0;
- foreach (string key in this.description.Mapping.Keys) {
- keys[i++] = key;
+ List<string> keys = new List<string>(this.message.ExtraData.Count + this.description.Mapping.Count);
+ foreach (var pair in this.description.Mapping) {
+ // Don't include keys with null values, but default values for structs is ok
+ if (pair.Value.GetValue(this.message) != null) {
+ keys.Add(pair.Key);
+ }
}
foreach (string key in this.message.ExtraData.Keys) {
- keys[i++] = key;
+ keys.Add(key);
}
- return keys;
+ return keys.AsReadOnly();
}
}
- bool IDictionary<string, string>.Remove(string key) {
+ public bool Remove(string key) {
if (this.message.ExtraData.Remove(key)) {
return true;
} else {
@@ -77,7 +85,7 @@ namespace DotNetOAuth.Messaging.Reflection {
}
}
- bool IDictionary<string, string>.TryGetValue(string key, out string value) {
+ public bool TryGetValue(string key, out string value) {
MessagePart part;
if (this.description.Mapping.TryGetValue(key, out part)) {
value = part.GetValue(this.message);
@@ -86,26 +94,29 @@ namespace DotNetOAuth.Messaging.Reflection {
return this.message.ExtraData.TryGetValue(key, out value);
}
- ICollection<string> IDictionary<string, string>.Values {
+ public ICollection<string> Values {
get {
- string[] values = new string[this.message.ExtraData.Count + this.description.Mapping.Count];
- int i = 0;
+ List<string> values = new List<string>(this.message.ExtraData.Count + this.description.Mapping.Count);
foreach (MessagePart part in this.description.Mapping.Values) {
- values[i++] = part.GetValue(this.message);
+ if (part.GetValue(this.message) != null) {
+ values.Add(part.GetValue(this.message));
+ }
}
foreach (string value in this.message.ExtraData.Values) {
- values[i++] = value;
+ Debug.Assert(value != null, "Null values should never be allowed in the extra data dictionary.");
+ values.Add(value);
}
- return values;
+ return values.AsReadOnly();
}
}
- string IDictionary<string, string>.this[string key] {
+ public string this[string key] {
get {
MessagePart part;
if (this.description.Mapping.TryGetValue(key, out part)) {
+ // Never throw KeyNotFoundException for declared properties.
return part.GetValue(this.message);
} else {
return this.message.ExtraData[key];
@@ -117,7 +128,11 @@ namespace DotNetOAuth.Messaging.Reflection {
if (this.description.Mapping.TryGetValue(key, out part)) {
part.SetValue(this.message, value);
} else {
- this.message.ExtraData[key] = value;
+ if (value == null) {
+ this.message.ExtraData.Remove(key);
+ } else {
+ this.message.ExtraData[key] = value;
+ }
}
}
}
@@ -126,17 +141,17 @@ namespace DotNetOAuth.Messaging.Reflection {
#region ICollection<KeyValuePair<string,string>> Members
- void ICollection<KeyValuePair<string, string>>.Add(KeyValuePair<string, string> item) {
- ((IDictionary<string, string>)this).Add(item.Key, item.Value);
+ public void Add(KeyValuePair<string, string> item) {
+ this.Add(item.Key, item.Value);
}
- void ICollection<KeyValuePair<string, string>>.Clear() {
- foreach (string key in ((IDictionary<string, string>)this).Keys) {
- ((IDictionary<string, string>)this).Remove(key);
+ public void Clear() {
+ foreach (string key in this.Keys) {
+ this.Remove(key);
}
}
- bool ICollection<KeyValuePair<string, string>>.Contains(KeyValuePair<string, string> item) {
+ public bool Contains(KeyValuePair<string, string> item) {
MessagePart part;
if (this.description.Mapping.TryGetValue(item.Key, out part)) {
return string.Equals(part.GetValue(this.message), item.Value, StringComparison.Ordinal);
@@ -151,15 +166,15 @@ namespace DotNetOAuth.Messaging.Reflection {
}
}
- int ICollection<KeyValuePair<string, string>>.Count {
- get { return this.description.Mapping.Count + this.message.ExtraData.Count; }
+ public int Count {
+ get { return this.Keys.Count; }
}
bool ICollection<KeyValuePair<string, string>>.IsReadOnly {
get { return false; }
}
- bool ICollection<KeyValuePair<string, string>>.Remove(KeyValuePair<string, string> item) {
+ public bool Remove(KeyValuePair<string, string> item) {
// We use contains because that checks that the value is equal as well.
if (((ICollection<KeyValuePair<string, string>>)this).Contains(item)) {
((IDictionary<string, string>)this).Remove(item.Key);
@@ -172,13 +187,9 @@ namespace DotNetOAuth.Messaging.Reflection {
#region IEnumerable<KeyValuePair<string,string>> Members
- IEnumerator<KeyValuePair<string, string>> IEnumerable<KeyValuePair<string, string>>.GetEnumerator() {
- foreach (MessagePart part in this.description.Mapping.Values) {
- yield return new KeyValuePair<string, string>(part.Name, part.GetValue(this.message));
- }
-
- foreach (var pair in this.message.ExtraData) {
- yield return pair;
+ public IEnumerator<KeyValuePair<string, string>> GetEnumerator() {
+ foreach (string key in Keys) {
+ yield return new KeyValuePair<string, string>(key, this[key]);
}
}
diff --git a/src/DotNetOAuth/Messaging/Reflection/MessagePart.cs b/src/DotNetOAuth/Messaging/Reflection/MessagePart.cs
index 09f959b..2005370 100644
--- a/src/DotNetOAuth/Messaging/Reflection/MessagePart.cs
+++ b/src/DotNetOAuth/Messaging/Reflection/MessagePart.cs
@@ -19,6 +19,10 @@ namespace DotNetOAuth.Messaging.Reflection {
private FieldInfo field;
+ private Type memberDeclaredType;
+
+ private object defaultMemberValue;
+
static MessagePart() {
Map<Uri>(uri => uri.AbsoluteUri, str => new Uri(str));
}
@@ -40,10 +44,11 @@ namespace DotNetOAuth.Messaging.Reflection {
this.Name = attribute.Name ?? member.Name;
this.Signed = attribute.Signed;
- this.IsRequired = !attribute.Optional;
+ this.IsRequired = attribute.IsRequired;
+ this.memberDeclaredType = (this.field != null) ? this.field.FieldType : this.property.PropertyType;
+ this.defaultMemberValue = deriveDefaultValue(this.memberDeclaredType);
- if (!converters.TryGetValue(member.DeclaringType, out this.converter)) {
- Type memberDeclaredType = (this.field != null) ? this.field.FieldType : this.property.PropertyType;
+ if (!converters.TryGetValue(this.memberDeclaredType, out this.converter)) {
this.converter = new ValueMapping(
obj => obj != null ? obj.ToString() : null,
str => str != null ? Convert.ChangeType(str, memberDeclaredType) : null);
@@ -73,10 +78,34 @@ namespace DotNetOAuth.Messaging.Reflection {
}
internal string GetValue(IProtocolMessage message) {
+ return this.ToString(this.GetValueAsObject(message));
+ }
+
+ internal bool IsNondefaultValueSet(IProtocolMessage message) {
+ if (this.memberDeclaredType.IsValueType) {
+ return !GetValueAsObject(message).Equals(this.defaultMemberValue);
+ } else {
+ return this.defaultMemberValue != GetValueAsObject(message);
+ }
+ }
+
+ internal bool IsValidValue(IProtocolMessage message) {
+ return true;
+ }
+
+ private static object deriveDefaultValue(Type type) {
+ if (type.IsValueType) {
+ return Activator.CreateInstance(type);
+ } else {
+ return null;
+ }
+ }
+
+ private object GetValueAsObject(IProtocolMessage message) {
if (this.property != null) {
- return this.ToString(this.property.GetValue(message, null));
+ return this.property.GetValue(message, null);
} else {
- return this.ToString(this.field.GetValue(message));
+ return this.field.GetValue(message);
}
}
diff --git a/src/DotNetOAuth/Messaging/Reflection/MessagePartAttribute.cs b/src/DotNetOAuth/Messaging/Reflection/MessagePartAttribute.cs
new file mode 100644
index 0000000..e5c429e
--- /dev/null
+++ b/src/DotNetOAuth/Messaging/Reflection/MessagePartAttribute.cs
@@ -0,0 +1,47 @@
+//-----------------------------------------------------------------------
+// <copyright file="MessagePartAttribute.cs" company="Andrew Arnott">
+// Copyright (c) Andrew Arnott. All rights reserved.
+// </copyright>
+//-----------------------------------------------------------------------
+
+namespace DotNetOAuth.Messaging.Reflection {
+ using System;
+ using System.Net.Security;
+ using System.Reflection;
+
+ [AttributeUsage(AttributeTargets.Property | AttributeTargets.Field, Inherited = true, AllowMultiple = false)]
+ internal sealed class MessagePartAttribute : Attribute {
+ private bool initialized;
+ private string name;
+
+ internal MessagePartAttribute() {
+ }
+
+ internal MessagePartAttribute(string name) {
+ this.Name = name;
+ }
+
+ public string Name {
+ get { return this.name; }
+ set { this.name = string.IsNullOrEmpty(value) ? null : value; }
+ }
+
+ public ProtectionLevel Signed { get; set; }
+
+ public bool IsRequired { get; set; }
+
+ internal void Initialize(MemberInfo member) {
+ if (member == null) {
+ throw new ArgumentNullException("member");
+ }
+
+ if (!this.initialized) {
+ if (String.IsNullOrEmpty(this.Name)) {
+ this.Name = member.Name;
+ }
+
+ this.initialized = true;
+ }
+ }
+ }
+}