diff options
Diffstat (limited to 'src/DotNetOpenAuth.Core/Messaging/Reflection')
8 files changed, 1522 insertions, 0 deletions
diff --git a/src/DotNetOpenAuth.Core/Messaging/Reflection/IMessagePartEncoder.cs b/src/DotNetOpenAuth.Core/Messaging/Reflection/IMessagePartEncoder.cs new file mode 100644 index 0000000..bbb3737 --- /dev/null +++ b/src/DotNetOpenAuth.Core/Messaging/Reflection/IMessagePartEncoder.cs @@ -0,0 +1,78 @@ +//----------------------------------------------------------------------- +// <copyright file="IMessagePartEncoder.cs" company="Andrew Arnott"> +// Copyright (c) Andrew Arnott. All rights reserved. +// </copyright> +//----------------------------------------------------------------------- + +namespace DotNetOpenAuth.Messaging.Reflection { + using System; + using System.Collections.Generic; + using System.Diagnostics.Contracts; + using System.Linq; + using System.Text; + + /// <summary> + /// An interface describing how various objects can be serialized and deserialized between their object and string forms. + /// </summary> + /// <remarks> + /// Implementations of this interface must include a default constructor and must be thread-safe. + /// </remarks> + [ContractClass(typeof(IMessagePartEncoderContract))] + public interface IMessagePartEncoder { + /// <summary> + /// Encodes the specified value. + /// </summary> + /// <param name="value">The value. Guaranteed to never be null.</param> + /// <returns>The <paramref name="value"/> in string form, ready for message transport.</returns> + string Encode(object value); + + /// <summary> + /// Decodes the specified value. + /// </summary> + /// <param name="value">The string value carried by the transport. Guaranteed to never be null, although it may be empty.</param> + /// <returns>The deserialized form of the given string.</returns> + /// <exception cref="FormatException">Thrown when the string value given cannot be decoded into the required object type.</exception> + object Decode(string value); + } + + /// <summary> + /// Code contract for the <see cref="IMessagePartEncoder"/> type. + /// </summary> + [ContractClassFor(typeof(IMessagePartEncoder))] + internal abstract class IMessagePartEncoderContract : IMessagePartEncoder { + /// <summary> + /// Initializes a new instance of the <see cref="IMessagePartEncoderContract"/> class. + /// </summary> + protected IMessagePartEncoderContract() { + } + + #region IMessagePartEncoder Members + + /// <summary> + /// Encodes the specified value. + /// </summary> + /// <param name="value">The value. Guaranteed to never be null.</param> + /// <returns> + /// The <paramref name="value"/> in string form, ready for message transport. + /// </returns> + string IMessagePartEncoder.Encode(object value) { + Requires.NotNull(value, "value"); + throw new NotImplementedException(); + } + + /// <summary> + /// Decodes the specified value. + /// </summary> + /// <param name="value">The string value carried by the transport. Guaranteed to never be null, although it may be empty.</param> + /// <returns> + /// The deserialized form of the given string. + /// </returns> + /// <exception cref="FormatException">Thrown when the string value given cannot be decoded into the required object type.</exception> + object IMessagePartEncoder.Decode(string value) { + Requires.NotNull(value, "value"); + throw new NotImplementedException(); + } + + #endregion + } +} diff --git a/src/DotNetOpenAuth.Core/Messaging/Reflection/IMessagePartNullEncoder.cs b/src/DotNetOpenAuth.Core/Messaging/Reflection/IMessagePartNullEncoder.cs new file mode 100644 index 0000000..7581550 --- /dev/null +++ b/src/DotNetOpenAuth.Core/Messaging/Reflection/IMessagePartNullEncoder.cs @@ -0,0 +1,18 @@ +//----------------------------------------------------------------------- +// <copyright file="IMessagePartNullEncoder.cs" company="Andrew Arnott"> +// Copyright (c) Andrew Arnott. All rights reserved. +// </copyright> +//----------------------------------------------------------------------- + +namespace DotNetOpenAuth.Messaging.Reflection { + /// <summary> + /// A message part encoder that has a special encoding for a null value. + /// </summary> + public interface IMessagePartNullEncoder : IMessagePartEncoder { + /// <summary> + /// Gets the string representation to include in a serialized message + /// when the message part has a <c>null</c> value. + /// </summary> + string EncodedNullValue { get; } + } +} diff --git a/src/DotNetOpenAuth.Core/Messaging/Reflection/IMessagePartOriginalEncoder.cs b/src/DotNetOpenAuth.Core/Messaging/Reflection/IMessagePartOriginalEncoder.cs new file mode 100644 index 0000000..9ad55c9 --- /dev/null +++ b/src/DotNetOpenAuth.Core/Messaging/Reflection/IMessagePartOriginalEncoder.cs @@ -0,0 +1,22 @@ +//----------------------------------------------------------------------- +// <copyright file="IMessagePartOriginalEncoder.cs" company="Andrew Arnott"> +// Copyright (c) Andrew Arnott. All rights reserved. +// </copyright> +//----------------------------------------------------------------------- + +namespace DotNetOpenAuth.Messaging.Reflection { + /// <summary> + /// An interface describing how various objects can be serialized and deserialized between their object and string forms. + /// </summary> + /// <remarks> + /// Implementations of this interface must include a default constructor and must be thread-safe. + /// </remarks> + public interface IMessagePartOriginalEncoder : IMessagePartEncoder { + /// <summary> + /// Encodes the specified value as the original value that was formerly decoded. + /// </summary> + /// <param name="value">The value. Guaranteed to never be null.</param> + /// <returns>The <paramref name="value"/> in string form, ready for message transport.</returns> + string EncodeAsOriginalString(object value); + } +} diff --git a/src/DotNetOpenAuth.Core/Messaging/Reflection/MessageDescription.cs b/src/DotNetOpenAuth.Core/Messaging/Reflection/MessageDescription.cs new file mode 100644 index 0000000..9a8098b --- /dev/null +++ b/src/DotNetOpenAuth.Core/Messaging/Reflection/MessageDescription.cs @@ -0,0 +1,283 @@ +//----------------------------------------------------------------------- +// <copyright file="MessageDescription.cs" company="Andrew Arnott"> +// Copyright (c) Andrew Arnott. All rights reserved. +// </copyright> +//----------------------------------------------------------------------- + +namespace DotNetOpenAuth.Messaging.Reflection { + using System; + using System.Collections.Generic; + using System.Diagnostics; + using System.Diagnostics.Contracts; + using System.Globalization; + using System.Linq; + using System.Reflection; + + /// <summary> + /// A mapping between serialized key names and <see cref="MessagePart"/> instances describing + /// those key/values pairs. + /// </summary> + internal class MessageDescription { + /// <summary> + /// A mapping between the serialized key names and their + /// describing <see cref="MessagePart"/> instances. + /// </summary> + private Dictionary<string, MessagePart> mapping; + + /// <summary> + /// Initializes a new instance of the <see cref="MessageDescription"/> class. + /// </summary> + /// <param name="messageType">Type of the message.</param> + /// <param name="messageVersion">The message version.</param> + internal MessageDescription(Type messageType, Version messageVersion) { + Requires.NotNullSubtype<IMessage>(messageType, "messageType"); + Requires.NotNull(messageVersion, "messageVersion"); + + this.MessageType = messageType; + this.MessageVersion = messageVersion; + this.ReflectMessageType(); + } + + /// <summary> + /// Gets the mapping between the serialized key names and their describing + /// <see cref="MessagePart"/> instances. + /// </summary> + internal IDictionary<string, MessagePart> Mapping { + get { return this.mapping; } + } + + /// <summary> + /// Gets the message version this instance was generated from. + /// </summary> + internal Version MessageVersion { get; private set; } + + /// <summary> + /// Gets the type of message this instance was generated from. + /// </summary> + /// <value>The type of the described message.</value> + internal Type MessageType { get; private set; } + + /// <summary> + /// Gets the constructors available on the message type. + /// </summary> + internal ConstructorInfo[] Constructors { get; private set; } + + /// <summary> + /// Returns a <see cref="System.String"/> that represents this instance. + /// </summary> + /// <returns> + /// A <see cref="System.String"/> that represents this instance. + /// </returns> + public override string ToString() { + return this.MessageType.Name + " (" + this.MessageVersion + ")"; + } + + /// <summary> + /// Gets a dictionary that provides read/write access to a message. + /// </summary> + /// <param name="message">The message the dictionary should provide access to.</param> + /// <returns>The dictionary accessor to the message</returns> + [Pure] + internal MessageDictionary GetDictionary(IMessage message) { + Requires.NotNull(message, "message"); + Contract.Ensures(Contract.Result<MessageDictionary>() != null); + return this.GetDictionary(message, false); + } + + /// <summary> + /// Gets a dictionary that provides read/write access to a message. + /// </summary> + /// <param name="message">The message the dictionary should provide access to.</param> + /// <param name="getOriginalValues">A value indicating whether this message dictionary will retrieve original values instead of normalized ones.</param> + /// <returns>The dictionary accessor to the message</returns> + [Pure] + internal MessageDictionary GetDictionary(IMessage message, bool getOriginalValues) { + Requires.NotNull(message, "message"); + Contract.Ensures(Contract.Result<MessageDictionary>() != null); + return new MessageDictionary(message, this, getOriginalValues); + } + + /// <summary> + /// Ensures the message parts pass basic validation. + /// </summary> + /// <param name="parts">The key/value pairs of the serialized message.</param> + internal void EnsureMessagePartsPassBasicValidation(IDictionary<string, string> parts) { + try { + this.CheckRequiredMessagePartsArePresent(parts.Keys, true); + this.CheckRequiredProtocolMessagePartsAreNotEmpty(parts, true); + this.CheckMessagePartsConstantValues(parts, true); + } catch (ProtocolException) { + Logger.Messaging.ErrorFormat( + "Error while performing basic validation of {0} with these message parts:{1}{2}", + this.MessageType.Name, + Environment.NewLine, + parts.ToStringDeferred()); + throw; + } + } + + /// <summary> + /// Tests whether all the required message parts pass basic validation for the given data. + /// </summary> + /// <param name="parts">The key/value pairs of the serialized message.</param> + /// <returns>A value indicating whether the provided data fits the message's basic requirements.</returns> + internal bool CheckMessagePartsPassBasicValidation(IDictionary<string, string> parts) { + Requires.NotNull(parts, "parts"); + + return this.CheckRequiredMessagePartsArePresent(parts.Keys, false) && + this.CheckRequiredProtocolMessagePartsAreNotEmpty(parts, false) && + this.CheckMessagePartsConstantValues(parts, false); + } + + /// <summary> + /// Verifies that a given set of keys include all the required parameters + /// for this message type or throws an exception. + /// </summary> + /// <param name="keys">The names of all parameters included in a message.</param> + /// <param name="throwOnFailure">if set to <c>true</c> an exception is thrown on failure with details.</param> + /// <returns>A value indicating whether the provided data fits the message's basic requirements.</returns> + /// <exception cref="ProtocolException"> + /// Thrown when required parts of a message are not in <paramref name="keys"/> + /// if <paramref name="throwOnFailure"/> is <c>true</c>. + /// </exception> + private bool CheckRequiredMessagePartsArePresent(IEnumerable<string> keys, bool throwOnFailure) { + Requires.NotNull(keys, "keys"); + + var missingKeys = (from part in this.Mapping.Values + where part.IsRequired && !keys.Contains(part.Name) + select part.Name).ToArray(); + if (missingKeys.Length > 0) { + if (throwOnFailure) { + ErrorUtilities.ThrowProtocol( + MessagingStrings.RequiredParametersMissing, + this.MessageType.FullName, + string.Join(", ", missingKeys)); + } else { + Logger.Messaging.DebugFormat( + MessagingStrings.RequiredParametersMissing, + this.MessageType.FullName, + missingKeys.ToStringDeferred()); + return false; + } + } + + return true; + } + + /// <summary> + /// Ensures the protocol message parts that must not be empty are in fact not empty. + /// </summary> + /// <param name="partValues">A dictionary of key/value pairs that make up the serialized message.</param> + /// <param name="throwOnFailure">if set to <c>true</c> an exception is thrown on failure with details.</param> + /// <returns>A value indicating whether the provided data fits the message's basic requirements.</returns> + /// <exception cref="ProtocolException"> + /// Thrown when required parts of a message are not in <paramref name="partValues"/> + /// if <paramref name="throwOnFailure"/> is <c>true</c>. + /// </exception> + private bool CheckRequiredProtocolMessagePartsAreNotEmpty(IDictionary<string, string> partValues, bool throwOnFailure) { + Requires.NotNull(partValues, "partValues"); + + string value; + var emptyValuedKeys = (from part in this.Mapping.Values + where !part.AllowEmpty && partValues.TryGetValue(part.Name, out value) && value != null && value.Length == 0 + select part.Name).ToArray(); + if (emptyValuedKeys.Length > 0) { + if (throwOnFailure) { + ErrorUtilities.ThrowProtocol( + MessagingStrings.RequiredNonEmptyParameterWasEmpty, + this.MessageType.FullName, + string.Join(", ", emptyValuedKeys)); + } else { + Logger.Messaging.DebugFormat( + MessagingStrings.RequiredNonEmptyParameterWasEmpty, + this.MessageType.FullName, + emptyValuedKeys.ToStringDeferred()); + return false; + } + } + + return true; + } + + /// <summary> + /// Checks that a bunch of message part values meet the constant value requirements of this message description. + /// </summary> + /// <param name="partValues">The part values.</param> + /// <param name="throwOnFailure">if set to <c>true</c>, this method will throw on failure.</param> + /// <returns>A value indicating whether all the requirements are met.</returns> + private bool CheckMessagePartsConstantValues(IDictionary<string, string> partValues, bool throwOnFailure) { + Requires.NotNull(partValues, "partValues"); + + var badConstantValues = (from part in this.Mapping.Values + where part.IsConstantValueAvailableStatically + where partValues.ContainsKey(part.Name) + where !string.Equals(partValues[part.Name], part.StaticConstantValue, StringComparison.Ordinal) + select part.Name).ToArray(); + if (badConstantValues.Length > 0) { + if (throwOnFailure) { + ErrorUtilities.ThrowProtocol( + MessagingStrings.RequiredMessagePartConstantIncorrect, + this.MessageType.FullName, + string.Join(", ", badConstantValues)); + } else { + Logger.Messaging.DebugFormat( + MessagingStrings.RequiredMessagePartConstantIncorrect, + this.MessageType.FullName, + badConstantValues.ToStringDeferred()); + return false; + } + } + + return true; + } + + /// <summary> + /// Reflects over some <see cref="IMessage"/>-implementing type + /// and prepares to serialize/deserialize instances of that type. + /// </summary> + private void ReflectMessageType() { + this.mapping = new Dictionary<string, MessagePart>(); + + Type currentType = this.MessageType; + do { + foreach (MemberInfo member in currentType.GetMembers(BindingFlags.Instance | BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.DeclaredOnly)) { + if (member is PropertyInfo || member is FieldInfo) { + MessagePartAttribute partAttribute = + (from a in member.GetCustomAttributes(typeof(MessagePartAttribute), true).OfType<MessagePartAttribute>() + orderby a.MinVersionValue descending + where a.MinVersionValue <= this.MessageVersion + where a.MaxVersionValue >= this.MessageVersion + select a).FirstOrDefault(); + if (partAttribute != null) { + MessagePart part = new MessagePart(member, partAttribute); + if (this.mapping.ContainsKey(part.Name)) { + Logger.Messaging.WarnFormat( + "Message type {0} has more than one message part named {1}. Inherited members will be hidden.", + this.MessageType.Name, + part.Name); + } else { + this.mapping.Add(part.Name, part); + } + } + } + } + currentType = currentType.BaseType; + } while (currentType != null); + + BindingFlags flags = BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public; + this.Constructors = this.MessageType.GetConstructors(flags); + } + +#if CONTRACTS_FULL + /// <summary> + /// Describes traits of this class that are always true. + /// </summary> + [ContractInvariantMethod] + private void Invariant() { + Contract.Invariant(this.MessageType != null); + Contract.Invariant(this.MessageVersion != null); + Contract.Invariant(this.Constructors != null); + } +#endif + } +} diff --git a/src/DotNetOpenAuth.Core/Messaging/Reflection/MessageDescriptionCollection.cs b/src/DotNetOpenAuth.Core/Messaging/Reflection/MessageDescriptionCollection.cs new file mode 100644 index 0000000..79ef172 --- /dev/null +++ b/src/DotNetOpenAuth.Core/Messaging/Reflection/MessageDescriptionCollection.cs @@ -0,0 +1,217 @@ +//----------------------------------------------------------------------- +// <copyright file="MessageDescriptionCollection.cs" company="Andrew Arnott"> +// Copyright (c) Andrew Arnott. All rights reserved. +// </copyright> +//----------------------------------------------------------------------- + +namespace DotNetOpenAuth.Messaging.Reflection { + using System; + using System.Collections.Generic; + using System.Diagnostics.CodeAnalysis; + using System.Diagnostics.Contracts; + + /// <summary> + /// A cache of <see cref="MessageDescription"/> instances. + /// </summary> + [ContractVerification(true)] + internal class MessageDescriptionCollection : IEnumerable<MessageDescription> { + /// <summary> + /// A dictionary of reflected message types and the generated reflection information. + /// </summary> + private readonly Dictionary<MessageTypeAndVersion, MessageDescription> reflectedMessageTypes = new Dictionary<MessageTypeAndVersion, MessageDescription>(); + + /// <summary> + /// Initializes a new instance of the <see cref="MessageDescriptionCollection"/> class. + /// </summary> + internal MessageDescriptionCollection() { + } + + #region IEnumerable<MessageDescription> Members + + /// <summary> + /// Returns an enumerator that iterates through a collection. + /// </summary> + /// <returns> + /// An <see cref="T:System.Collections.IEnumerator"/> object that can be used to iterate through the collection. + /// </returns> + public IEnumerator<MessageDescription> GetEnumerator() { + return this.reflectedMessageTypes.Values.GetEnumerator(); + } + + #endregion + + #region IEnumerable Members + + /// <summary> + /// Returns an enumerator that iterates through a collection. + /// </summary> + /// <returns> + /// An <see cref="T:System.Collections.IEnumerator"/> object that can be used to iterate through the collection. + /// </returns> + System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() { + return this.reflectedMessageTypes.Values.GetEnumerator(); + } + + #endregion + + /// <summary> + /// Gets a <see cref="MessageDescription"/> instance prepared for the + /// given message type. + /// </summary> + /// <param name="messageType">A type that implements <see cref="IMessage"/>.</param> + /// <param name="messageVersion">The protocol version of the message.</param> + /// <returns>A <see cref="MessageDescription"/> instance.</returns> + [SuppressMessage("Microsoft.Globalization", "CA1303:Do not pass literals as localized parameters", MessageId = "System.Diagnostics.Contracts.__ContractsRuntime.Assume(System.Boolean,System.String,System.String)", Justification = "No localization required.")] + [Pure] + internal MessageDescription Get(Type messageType, Version messageVersion) { + Requires.NotNullSubtype<IMessage>(messageType, "messageType"); + Requires.NotNull(messageVersion, "messageVersion"); + Contract.Ensures(Contract.Result<MessageDescription>() != null); + + MessageTypeAndVersion key = new MessageTypeAndVersion(messageType, messageVersion); + + MessageDescription result; + if (!this.reflectedMessageTypes.TryGetValue(key, out result)) { + lock (this.reflectedMessageTypes) { + if (!this.reflectedMessageTypes.TryGetValue(key, out result)) { + this.reflectedMessageTypes[key] = result = new MessageDescription(messageType, messageVersion); + } + } + } + + Contract.Assume(result != null, "We should never assign null values to this dictionary."); + return result; + } + + /// <summary> + /// Gets a <see cref="MessageDescription"/> instance prepared for the + /// given message type. + /// </summary> + /// <param name="message">The message for which a <see cref="MessageDescription"/> should be obtained.</param> + /// <returns> + /// A <see cref="MessageDescription"/> instance. + /// </returns> + [Pure] + internal MessageDescription Get(IMessage message) { + Requires.NotNull(message, "message"); + Contract.Ensures(Contract.Result<MessageDescription>() != null); + return this.Get(message.GetType(), message.Version); + } + + /// <summary> + /// Gets the dictionary that provides read/write access to a message. + /// </summary> + /// <param name="message">The message.</param> + /// <returns>The dictionary.</returns> + [Pure] + internal MessageDictionary GetAccessor(IMessage message) { + Requires.NotNull(message, "message"); + return this.GetAccessor(message, false); + } + + /// <summary> + /// Gets the dictionary that provides read/write access to a message. + /// </summary> + /// <param name="message">The message.</param> + /// <param name="getOriginalValues">A value indicating whether this message dictionary will retrieve original values instead of normalized ones.</param> + /// <returns>The dictionary.</returns> + [Pure] + internal MessageDictionary GetAccessor(IMessage message, bool getOriginalValues) { + Requires.NotNull(message, "message"); + return this.Get(message).GetDictionary(message, getOriginalValues); + } + + /// <summary> + /// A struct used as the key to bundle message type and version. + /// </summary> + [ContractVerification(true)] + private struct MessageTypeAndVersion { + /// <summary> + /// Backing store for the <see cref="Type"/> property. + /// </summary> + private readonly Type type; + + /// <summary> + /// Backing store for the <see cref="Version"/> property. + /// </summary> + private readonly Version version; + + /// <summary> + /// Initializes a new instance of the <see cref="MessageTypeAndVersion"/> struct. + /// </summary> + /// <param name="messageType">Type of the message.</param> + /// <param name="messageVersion">The message version.</param> + internal MessageTypeAndVersion(Type messageType, Version messageVersion) { + Requires.NotNull(messageType, "messageType"); + Requires.NotNull(messageVersion, "messageVersion"); + + this.type = messageType; + this.version = messageVersion; + } + + /// <summary> + /// Gets the message type. + /// </summary> + [SuppressMessage("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode", Justification = "Exposes basic identity on the type.")] + internal Type Type { + get { return this.type; } + } + + /// <summary> + /// Gets the message version. + /// </summary> + [SuppressMessage("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode", Justification = "Exposes basic identity on the type.")] + internal Version Version { + get { return this.version; } + } + + /// <summary> + /// Implements the operator ==. + /// </summary> + /// <param name="first">The first object to compare.</param> + /// <param name="second">The second object to compare.</param> + /// <returns>The result of the operator.</returns> + public static bool operator ==(MessageTypeAndVersion first, MessageTypeAndVersion second) { + // structs cannot be null, so this is safe + return first.Equals(second); + } + + /// <summary> + /// Implements the operator !=. + /// </summary> + /// <param name="first">The first object to compare.</param> + /// <param name="second">The second object to compare.</param> + /// <returns>The result of the operator.</returns> + public static bool operator !=(MessageTypeAndVersion first, MessageTypeAndVersion second) { + // structs cannot be null, so this is safe + return !first.Equals(second); + } + + /// <summary> + /// Indicates whether this instance and a specified object are equal. + /// </summary> + /// <param name="obj">Another object to compare to.</param> + /// <returns> + /// true if <paramref name="obj"/> and this instance are the same type and represent the same value; otherwise, false. + /// </returns> + public override bool Equals(object obj) { + if (obj is MessageTypeAndVersion) { + MessageTypeAndVersion other = (MessageTypeAndVersion)obj; + return this.type == other.type && this.version == other.version; + } else { + return false; + } + } + + /// <summary> + /// Returns the hash code for this instance. + /// </summary> + /// <returns> + /// A 32-bit signed integer that is the hash code for this instance. + /// </returns> + public override int GetHashCode() { + return this.type.GetHashCode(); + } + } + } +} diff --git a/src/DotNetOpenAuth.Core/Messaging/Reflection/MessageDictionary.cs b/src/DotNetOpenAuth.Core/Messaging/Reflection/MessageDictionary.cs new file mode 100644 index 0000000..54e2dd5 --- /dev/null +++ b/src/DotNetOpenAuth.Core/Messaging/Reflection/MessageDictionary.cs @@ -0,0 +1,409 @@ +//----------------------------------------------------------------------- +// <copyright file="MessageDictionary.cs" company="Andrew Arnott"> +// Copyright (c) Andrew Arnott. All rights reserved. +// </copyright> +//----------------------------------------------------------------------- + +namespace DotNetOpenAuth.Messaging.Reflection { + using System; + using System.Collections; + using System.Collections.Generic; + using System.Diagnostics; + using System.Diagnostics.CodeAnalysis; + using System.Diagnostics.Contracts; + + /// <summary> + /// Wraps an <see cref="IMessage"/> instance in a dictionary that + /// provides access to both well-defined message properties and "extra" + /// name/value pairs that have no properties associated with them. + /// </summary> + [ContractVerification(false)] + internal class MessageDictionary : IDictionary<string, string> { + /// <summary> + /// The <see cref="IMessage"/> instance manipulated by this dictionary. + /// </summary> + private readonly IMessage message; + + /// <summary> + /// The <see cref="MessageDescription"/> instance that describes the message type. + /// </summary> + private readonly MessageDescription description; + + /// <summary> + /// Whether original string values should be retrieved instead of normalized ones. + /// </summary> + private readonly bool getOriginalValues; + + /// <summary> + /// Initializes a new instance of the <see cref="MessageDictionary"/> class. + /// </summary> + /// <param name="message">The message instance whose values will be manipulated by this dictionary.</param> + /// <param name="description">The message description.</param> + /// <param name="getOriginalValues">A value indicating whether this message dictionary will retrieve original values instead of normalized ones.</param> + [Pure] + internal MessageDictionary(IMessage message, MessageDescription description, bool getOriginalValues) { + Requires.NotNull(message, "message"); + Requires.NotNull(description, "description"); + + this.message = message; + this.description = description; + this.getOriginalValues = getOriginalValues; + } + + /// <summary> + /// Gets the message this dictionary provides access to. + /// </summary> + public IMessage Message { + get { + Contract.Ensures(Contract.Result<IMessage>() != null); + return this.message; + } + } + + /// <summary> + /// Gets the description of the type of message this dictionary provides access to. + /// </summary> + public MessageDescription Description { + get { + Contract.Ensures(Contract.Result<MessageDescription>() != null); + return this.description; + } + } + + #region ICollection<KeyValuePair<string,string>> Properties + + /// <summary> + /// Gets the number of explicitly set values in the message. + /// </summary> + public int Count { + get { return this.Keys.Count; } + } + + /// <summary> + /// Gets a value indicating whether this message is read only. + /// </summary> + bool ICollection<KeyValuePair<string, string>>.IsReadOnly { + get { return false; } + } + + #endregion + + #region IDictionary<string,string> Properties + + /// <summary> + /// Gets all the keys that have values associated with them. + /// </summary> + public ICollection<string> Keys { + get { + List<string> keys = new List<string>(this.message.ExtraData.Count + this.description.Mapping.Count); + keys.AddRange(this.DeclaredKeys); + keys.AddRange(this.AdditionalKeys); + return keys.AsReadOnly(); + } + } + + /// <summary> + /// Gets the set of official message part names that have non-null values associated with them. + /// </summary> + public ICollection<string> DeclaredKeys { + get { + List<string> keys = new List<string>(this.description.Mapping.Count); + foreach (var pair in this.description.Mapping) { + // Don't include keys with null values, but default values for structs is ok + if (pair.Value.GetValue(this.message, this.getOriginalValues) != null) { + keys.Add(pair.Key); + } + } + + return keys.AsReadOnly(); + } + } + + /// <summary> + /// Gets the keys that are in the message but not declared as official OAuth properties. + /// </summary> + public ICollection<string> AdditionalKeys { + get { return this.message.ExtraData.Keys; } + } + + /// <summary> + /// Gets all the values. + /// </summary> + public ICollection<string> Values { + get { + List<string> values = new List<string>(this.message.ExtraData.Count + this.description.Mapping.Count); + foreach (MessagePart part in this.description.Mapping.Values) { + if (part.GetValue(this.message, this.getOriginalValues) != null) { + values.Add(part.GetValue(this.message, this.getOriginalValues)); + } + } + + foreach (string value in this.message.ExtraData.Values) { + Debug.Assert(value != null, "Null values should never be allowed in the extra data dictionary."); + values.Add(value); + } + + return values.AsReadOnly(); + } + } + + #endregion + + /// <summary> + /// Gets the serializer for the message this dictionary provides access to. + /// </summary> + private MessageSerializer Serializer { + get { return MessageSerializer.Get(this.Message.GetType()); } + } + + #region IDictionary<string,string> Indexers + + /// <summary> + /// Gets or sets a value for some named value. + /// </summary> + /// <param name="key">The serialized form of a name for the value to read or write.</param> + /// <returns>The named value.</returns> + /// <remarks> + /// If the key matches a declared property or field on the message type, + /// that type member is set. Otherwise the key/value is stored in a + /// dictionary for extra (weakly typed) strings. + /// </remarks> + /// <exception cref="ArgumentException">Thrown when setting a value that is not allowed for a given <paramref name="key"/>.</exception> + public string this[string key] { + get { + MessagePart part; + if (this.description.Mapping.TryGetValue(key, out part)) { + // Never throw KeyNotFoundException for declared properties. + return part.GetValue(this.message, this.getOriginalValues); + } else { + return this.message.ExtraData[key]; + } + } + + set { + MessagePart part; + if (this.description.Mapping.TryGetValue(key, out part)) { + part.SetValue(this.message, value); + } else { + if (value == null) { + this.message.ExtraData.Remove(key); + } else { + this.message.ExtraData[key] = value; + } + } + } + } + + #endregion + + #region IDictionary<string,string> Methods + + /// <summary> + /// Adds a named value to the message. + /// </summary> + /// <param name="key">The serialized form of the name whose value is being set.</param> + /// <param name="value">The serialized form of the value.</param> + /// <exception cref="ArgumentException"> + /// Thrown if <paramref name="key"/> already has a set value in this message. + /// </exception> + /// <exception cref="ArgumentNullException"> + /// Thrown if <paramref name="value"/> is null. + /// </exception> + public void Add(string key, string value) { + ErrorUtilities.VerifyArgumentNotNull(value, "value"); + + MessagePart part; + if (this.description.Mapping.TryGetValue(key, out part)) { + if (part.IsNondefaultValueSet(this.message)) { + throw new ArgumentException(MessagingStrings.KeyAlreadyExists); + } + part.SetValue(this.message, value); + } else { + this.message.ExtraData.Add(key, value); + } + } + + /// <summary> + /// Checks whether some named parameter has a value set in the message. + /// </summary> + /// <param name="key">The serialized form of the message part's name.</param> + /// <returns>True if the parameter by the given name has a set value. False otherwise.</returns> + public bool ContainsKey(string key) { + return this.message.ExtraData.ContainsKey(key) || + (this.description.Mapping.ContainsKey(key) && this.description.Mapping[key].GetValue(this.message, this.getOriginalValues) != null); + } + + /// <summary> + /// Removes a name and value from the message given its name. + /// </summary> + /// <param name="key">The serialized form of the name to remove.</param> + /// <returns>True if a message part by the given name was found and removed. False otherwise.</returns> + public bool Remove(string key) { + if (this.message.ExtraData.Remove(key)) { + return true; + } else { + MessagePart part; + if (this.description.Mapping.TryGetValue(key, out part)) { + if (part.GetValue(this.message, this.getOriginalValues) != null) { + part.SetValue(this.message, null); + return true; + } + } + return false; + } + } + + /// <summary> + /// Gets some named value if the key has a value. + /// </summary> + /// <param name="key">The name (in serialized form) of the value being sought.</param> + /// <param name="value">The variable where the value will be set.</param> + /// <returns>True if the key was found and <paramref name="value"/> was set. False otherwise.</returns> + public bool TryGetValue(string key, out string value) { + MessagePart part; + if (this.description.Mapping.TryGetValue(key, out part)) { + value = part.GetValue(this.message, this.getOriginalValues); + return value != null; + } + return this.message.ExtraData.TryGetValue(key, out value); + } + + #endregion + + #region ICollection<KeyValuePair<string,string>> Methods + + /// <summary> + /// Sets a named value in the message. + /// </summary> + /// <param name="item">The name-value pair to add. The name is the serialized form of the key.</param> + [SuppressMessage("Microsoft.Design", "CA1031:DoNotCatchGeneralExceptionTypes", Justification = "Code Contracts ccrewrite does this.")] + public void Add(KeyValuePair<string, string> item) { + this.Add(item.Key, item.Value); + } + + /// <summary> + /// Removes all values in the message. + /// </summary> + public void ClearValues() { + foreach (string key in this.Keys) { + this.Remove(key); + } + } + + /// <summary> + /// Removes all items from the <see cref="T:System.Collections.Generic.ICollection`1"/>. + /// </summary> + /// <exception cref="T:System.NotSupportedException"> + /// The <see cref="T:System.Collections.Generic.ICollection`1"/> is read-only. + /// </exception> + /// <remarks> + /// This method cannot be implemented because keys are not guaranteed to be removed + /// since some are inherent to the type of message that this dictionary provides + /// access to. + /// </remarks> + public void Clear() { + throw new NotSupportedException(); + } + + /// <summary> + /// Checks whether a named value has been set on the message. + /// </summary> + /// <param name="item">The name/value pair.</param> + /// <returns>True if the key exists and has the given value. False otherwise.</returns> + public bool Contains(KeyValuePair<string, string> item) { + MessagePart part; + if (this.description.Mapping.TryGetValue(item.Key, out part)) { + return string.Equals(part.GetValue(this.message, this.getOriginalValues), item.Value, StringComparison.Ordinal); + } else { + return this.message.ExtraData.Contains(item); + } + } + + /// <summary> + /// Copies all the serializable data from the message to a key/value array. + /// </summary> + /// <param name="array">The array to copy to.</param> + /// <param name="arrayIndex">The index in the <paramref name="array"/> to begin copying to.</param> + void ICollection<KeyValuePair<string, string>>.CopyTo(KeyValuePair<string, string>[] array, int arrayIndex) { + foreach (var pair in (IDictionary<string, string>)this) { + array[arrayIndex++] = pair; + } + } + + /// <summary> + /// Removes a named value from the message if it exists. + /// </summary> + /// <param name="item">The serialized form of the name and value to remove.</param> + /// <returns>True if the name/value was found and removed. False otherwise.</returns> + public bool Remove(KeyValuePair<string, string> item) { + // We use contains because that checks that the value is equal as well. + if (((ICollection<KeyValuePair<string, string>>)this).Contains(item)) { + ((IDictionary<string, string>)this).Remove(item.Key); + return true; + } + return false; + } + + #endregion + + #region IEnumerable<KeyValuePair<string,string>> Members + + /// <summary> + /// Gets an enumerator that generates KeyValuePair<string, string> instances + /// for all the key/value pairs that are set in the message. + /// </summary> + /// <returns>The enumerator that can generate the name/value pairs.</returns> + public IEnumerator<KeyValuePair<string, string>> GetEnumerator() { + foreach (string key in this.Keys) { + yield return new KeyValuePair<string, string>(key, this[key]); + } + } + + #endregion + + #region IEnumerable Members + + /// <summary> + /// Gets an enumerator that generates KeyValuePair<string, string> instances + /// for all the key/value pairs that are set in the message. + /// </summary> + /// <returns>The enumerator that can generate the name/value pairs.</returns> + IEnumerator System.Collections.IEnumerable.GetEnumerator() { + return ((IEnumerable<KeyValuePair<string, string>>)this).GetEnumerator(); + } + + #endregion + + /// <summary> + /// Saves the data in a message to a standard dictionary. + /// </summary> + /// <returns>The generated dictionary.</returns> + [Pure] + public IDictionary<string, string> Serialize() { + Contract.Ensures(Contract.Result<IDictionary<string, string>>() != null); + return this.Serializer.Serialize(this); + } + + /// <summary> + /// Loads data from a dictionary into the message. + /// </summary> + /// <param name="fields">The data to load into the message.</param> + public void Deserialize(IDictionary<string, string> fields) { + Requires.NotNull(fields, "fields"); + this.Serializer.Deserialize(fields, this); + } + +#if CONTRACTS_FULL + /// <summary> + /// Verifies conditions that should be true for any valid state of this object. + /// </summary> + [SuppressMessage("Microsoft.Performance", "CA1822:MarkMembersAsStatic", Justification = "Called by code contracts.")] + [SuppressMessage("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode", Justification = "Called by code contracts.")] + [ContractInvariantMethod] + private void ObjectInvariant() { + Contract.Invariant(this.Message != null); + Contract.Invariant(this.Description != null); + } +#endif + } +} diff --git a/src/DotNetOpenAuth.Core/Messaging/Reflection/MessagePart.cs b/src/DotNetOpenAuth.Core/Messaging/Reflection/MessagePart.cs new file mode 100644 index 0000000..f439c4d --- /dev/null +++ b/src/DotNetOpenAuth.Core/Messaging/Reflection/MessagePart.cs @@ -0,0 +1,428 @@ +//----------------------------------------------------------------------- +// <copyright file="MessagePart.cs" company="Andrew Arnott"> +// Copyright (c) Andrew Arnott. All rights reserved. +// </copyright> +//----------------------------------------------------------------------- + +namespace DotNetOpenAuth.Messaging.Reflection { + using System; + using System.Collections.Generic; + using System.Diagnostics; + using System.Diagnostics.CodeAnalysis; + using System.Diagnostics.Contracts; + using System.Globalization; + using System.Linq; + using System.Net.Security; + using System.Reflection; + using System.Xml; + using DotNetOpenAuth.Configuration; + + /// <summary> + /// Describes an individual member of a message and assists in its serialization. + /// </summary> + [ContractVerification(true)] + [DebuggerDisplay("MessagePart {Name}")] + internal class MessagePart { + /// <summary> + /// A map of converters that help serialize custom objects to string values and back again. + /// </summary> + private static readonly Dictionary<Type, ValueMapping> converters = new Dictionary<Type, ValueMapping>(); + + /// <summary> + /// A map of instantiated custom encoders used to encode/decode message parts. + /// </summary> + private static readonly Dictionary<Type, IMessagePartEncoder> encoders = new Dictionary<Type, IMessagePartEncoder>(); + + /// <summary> + /// The string-object conversion routines to use for this individual message part. + /// </summary> + private ValueMapping converter; + + /// <summary> + /// The property that this message part is associated with, if aplicable. + /// </summary> + private PropertyInfo property; + + /// <summary> + /// The field that this message part is associated with, if aplicable. + /// </summary> + private FieldInfo field; + + /// <summary> + /// The type of the message part. (Not the type of the message itself). + /// </summary> + private Type memberDeclaredType; + + /// <summary> + /// The default (uninitialized) value of the member inherent in its type. + /// </summary> + private object defaultMemberValue; + + /// <summary> + /// Initializes static members of the <see cref="MessagePart"/> class. + /// </summary> + [SuppressMessage("Microsoft.Maintainability", "CA1502:AvoidExcessiveComplexity", Justification = "This simplifies the rest of the code.")] + [SuppressMessage("Microsoft.Globalization", "CA1308:NormalizeStringsToUppercase", Justification = "By design.")] + [SuppressMessage("Microsoft.Performance", "CA1810:InitializeReferenceTypeStaticFieldsInline", Justification = "Much more efficient initialization when we can call methods.")] + static MessagePart() { + Func<string, Uri> safeUri = str => { + Contract.Assume(str != null); + return new Uri(str); + }; + Func<string, bool> safeBool = str => { + Contract.Assume(str != null); + return bool.Parse(str); + }; + + Func<byte[], string> safeFromByteArray = bytes => { + Contract.Assume(bytes != null); + return Convert.ToBase64String(bytes); + }; + Func<string, byte[]> safeToByteArray = str => { + Contract.Assume(str != null); + return Convert.FromBase64String(str); + }; + Map<Uri>(uri => uri.AbsoluteUri, uri => uri.OriginalString, safeUri); + Map<DateTime>(dt => XmlConvert.ToString(dt, XmlDateTimeSerializationMode.Utc), null, str => XmlConvert.ToDateTime(str, XmlDateTimeSerializationMode.Utc)); + Map<TimeSpan>(ts => ts.ToString(), null, str => TimeSpan.Parse(str)); + Map<byte[]>(safeFromByteArray, null, safeToByteArray); + Map<bool>(value => value.ToString().ToLowerInvariant(), null, safeBool); + Map<CultureInfo>(c => c.Name, null, str => new CultureInfo(str)); + Map<CultureInfo[]>(cs => string.Join(",", cs.Select(c => c.Name).ToArray()), null, str => str.Split(',').Select(s => new CultureInfo(s)).ToArray()); + Map<Type>(t => t.FullName, null, str => Type.GetType(str)); + } + + /// <summary> + /// Initializes a new instance of the <see cref="MessagePart"/> class. + /// </summary> + /// <param name="member"> + /// A property or field of an <see cref="IMessage"/> implementing type + /// that has a <see cref="MessagePartAttribute"/> attached to it. + /// </param> + /// <param name="attribute"> + /// The attribute discovered on <paramref name="member"/> that describes the + /// serialization requirements of the message part. + /// </param> + [SuppressMessage("Microsoft.Maintainability", "CA1502:AvoidExcessiveComplexity", Justification = "Unavoidable"), SuppressMessage("Microsoft.Performance", "CA1800:DoNotCastUnnecessarily", Justification = "Code contracts requires it.")] + internal MessagePart(MemberInfo member, MessagePartAttribute attribute) { + Requires.NotNull(member, "member"); + Requires.True(member is FieldInfo || member is PropertyInfo, "member"); + Requires.NotNull(attribute, "attribute"); + + this.field = member as FieldInfo; + this.property = member as PropertyInfo; + this.Name = attribute.Name ?? member.Name; + this.RequiredProtection = attribute.RequiredProtection; + this.IsRequired = attribute.IsRequired; + this.AllowEmpty = attribute.AllowEmpty; + this.memberDeclaredType = (this.field != null) ? this.field.FieldType : this.property.PropertyType; + this.defaultMemberValue = DeriveDefaultValue(this.memberDeclaredType); + + Contract.Assume(this.memberDeclaredType != null); // CC missing PropertyInfo.PropertyType ensures result != null + if (attribute.Encoder == null) { + if (!converters.TryGetValue(this.memberDeclaredType, out this.converter)) { + if (this.memberDeclaredType.IsGenericType && + this.memberDeclaredType.GetGenericTypeDefinition() == typeof(Nullable<>)) { + // It's a nullable type. Try again to look up an appropriate converter for the underlying type. + Type underlyingType = Nullable.GetUnderlyingType(this.memberDeclaredType); + ValueMapping underlyingMapping; + if (converters.TryGetValue(underlyingType, out underlyingMapping)) { + this.converter = new ValueMapping( + underlyingMapping.ValueToString, + null, + str => str != null ? underlyingMapping.StringToValue(str) : null); + } else { + this.converter = new ValueMapping( + obj => obj != null ? obj.ToString() : null, + null, + str => str != null ? Convert.ChangeType(str, underlyingType, CultureInfo.InvariantCulture) : null); + } + } else { + this.converter = new ValueMapping( + obj => obj != null ? obj.ToString() : null, + null, + str => str != null ? Convert.ChangeType(str, this.memberDeclaredType, CultureInfo.InvariantCulture) : null); + } + } + } else { + this.converter = new ValueMapping(GetEncoder(attribute.Encoder)); + } + + // readonly and const fields are considered legal, and "constants" for message transport. + FieldAttributes constAttributes = FieldAttributes.Static | FieldAttributes.Literal | FieldAttributes.HasDefault; + if (this.field != null && ( + (this.field.Attributes & FieldAttributes.InitOnly) == FieldAttributes.InitOnly || + (this.field.Attributes & constAttributes) == constAttributes)) { + this.IsConstantValue = true; + this.IsConstantValueAvailableStatically = this.field.IsStatic; + } else if (this.property != null && !this.property.CanWrite) { + this.IsConstantValue = true; + } + + // Validate a sane combination of settings + this.ValidateSettings(); + } + + /// <summary> + /// Gets or sets the name to use when serializing or deserializing this parameter in a message. + /// </summary> + internal string Name { get; set; } + + /// <summary> + /// Gets or sets whether this message part must be signed. + /// </summary> + internal ProtectionLevel RequiredProtection { get; set; } + + /// <summary> + /// Gets or sets a value indicating whether this message part is required for the + /// containing message to be valid. + /// </summary> + internal bool IsRequired { get; set; } + + /// <summary> + /// Gets or sets a value indicating whether the string value is allowed to be empty in the serialized message. + /// </summary> + internal bool AllowEmpty { get; set; } + + /// <summary> + /// Gets or sets a value indicating whether the field or property must remain its default value. + /// </summary> + internal bool IsConstantValue { get; set; } + + /// <summary> + /// Gets or sets a value indicating whether this part is defined as a constant field and can be read without a message instance. + /// </summary> + internal bool IsConstantValueAvailableStatically { get; set; } + + /// <summary> + /// Gets the static constant value for this message part without a message instance. + /// </summary> + internal string StaticConstantValue { + get { + Requires.ValidState(this.IsConstantValueAvailableStatically); + return this.ToString(this.field.GetValue(null), false); + } + } + + /// <summary> + /// Gets the type of the declared member. + /// </summary> + internal Type MemberDeclaredType { + get { return this.memberDeclaredType; } + } + + /// <summary> + /// Adds a pair of type conversion functions to the static conversion map. + /// </summary> + /// <typeparam name="T">The custom type to convert to and from strings.</typeparam> + /// <param name="toString">The function to convert the custom type to a string.</param> + /// <param name="toOriginalString">The mapping function that converts some custom value to its original (non-normalized) string. May be null if the same as the <paramref name="toString"/> function.</param> + /// <param name="toValue">The function to convert a string to the custom type.</param> + [SuppressMessage("Microsoft.Globalization", "CA1303:Do not pass literals as localized parameters", MessageId = "System.Diagnostics.Contracts.__ContractsRuntime.Requires<System.ArgumentNullException>(System.Boolean,System.String,System.String)", Justification = "Code contracts"), SuppressMessage("Microsoft.Naming", "CA2204:Literals should be spelled correctly", MessageId = "toString", Justification = "Code contracts"), SuppressMessage("Microsoft.Naming", "CA2204:Literals should be spelled correctly", MessageId = "toValue", Justification = "Code contracts")] + internal static void Map<T>(Func<T, string> toString, Func<T, string> toOriginalString, Func<string, T> toValue) { + Requires.NotNull(toString, "toString"); + Requires.NotNull(toValue, "toValue"); + + if (toOriginalString == null) { + toOriginalString = toString; + } + + Func<object, string> safeToString = obj => obj != null ? toString((T)obj) : null; + Func<object, string> safeToOriginalString = obj => obj != null ? toOriginalString((T)obj) : null; + Func<string, object> safeToT = str => str != null ? toValue(str) : default(T); + converters.Add(typeof(T), new ValueMapping(safeToString, safeToOriginalString, safeToT)); + } + + /// <summary> + /// Sets the member of a given message to some given value. + /// Used in deserialization. + /// </summary> + /// <param name="message">The message instance containing the member whose value should be set.</param> + /// <param name="value">The string representation of the value to set.</param> + internal void SetValue(IMessage message, string value) { + Requires.NotNull(message, "message"); + + try { + if (this.IsConstantValue) { + string constantValue = this.GetValue(message); + var caseSensitivity = DotNetOpenAuthSection.Messaging.Strict ? StringComparison.Ordinal : StringComparison.OrdinalIgnoreCase; + if (!string.Equals(constantValue, value, caseSensitivity)) { + throw new ArgumentException(string.Format( + CultureInfo.CurrentCulture, + MessagingStrings.UnexpectedMessagePartValueForConstant, + message.GetType().Name, + this.Name, + constantValue, + value)); + } + } else { + this.SetValueAsObject(message, this.ToValue(value)); + } + } catch (Exception ex) { + throw ErrorUtilities.Wrap(ex, MessagingStrings.MessagePartReadFailure, message.GetType(), this.Name, value); + } + } + + /// <summary> + /// Gets the normalized form of a value of a member of a given message. + /// Used in serialization. + /// </summary> + /// <param name="message">The message instance to read the value from.</param> + /// <returns>The string representation of the member's value.</returns> + internal string GetValue(IMessage message) { + try { + object value = this.GetValueAsObject(message); + return this.ToString(value, false); + } catch (FormatException ex) { + throw ErrorUtilities.Wrap(ex, MessagingStrings.MessagePartWriteFailure, message.GetType(), this.Name); + } + } + + /// <summary> + /// Gets the value of a member of a given message. + /// Used in serialization. + /// </summary> + /// <param name="message">The message instance to read the value from.</param> + /// <param name="originalValue">A value indicating whether the original value should be retrieved (as opposed to a normalized form of it).</param> + /// <returns>The string representation of the member's value.</returns> + internal string GetValue(IMessage message, bool originalValue) { + try { + object value = this.GetValueAsObject(message); + return this.ToString(value, originalValue); + } catch (FormatException ex) { + throw ErrorUtilities.Wrap(ex, MessagingStrings.MessagePartWriteFailure, message.GetType(), this.Name); + } + } + + /// <summary> + /// Gets whether the value has been set to something other than its CLR type default value. + /// </summary> + /// <param name="message">The message instance to check the value on.</param> + /// <returns>True if the value is not the CLR default value.</returns> + internal bool IsNondefaultValueSet(IMessage message) { + if (this.memberDeclaredType.IsValueType) { + return !this.GetValueAsObject(message).Equals(this.defaultMemberValue); + } else { + return this.defaultMemberValue != this.GetValueAsObject(message); + } + } + + /// <summary> + /// Figures out the CLR default value for a given type. + /// </summary> + /// <param name="type">The type whose default value is being sought.</param> + /// <returns>Either null, or some default value like 0 or 0.0.</returns> + private static object DeriveDefaultValue(Type type) { + if (type.IsValueType) { + return Activator.CreateInstance(type); + } else { + return null; + } + } + + /// <summary> + /// Checks whether a type is a nullable value type (i.e. int?) + /// </summary> + /// <param name="type">The type in question.</param> + /// <returns>True if this is a nullable value type.</returns> + private static bool IsNonNullableValueType(Type type) { + if (!type.IsValueType) { + return false; + } + + if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>)) { + return false; + } + + return true; + } + + /// <summary> + /// Retrieves a previously instantiated encoder of a given type, or creates a new one and stores it for later retrieval as well. + /// </summary> + /// <param name="messagePartEncoder">The message part encoder type.</param> + /// <returns>An instance of the desired encoder.</returns> + private static IMessagePartEncoder GetEncoder(Type messagePartEncoder) { + Requires.NotNull(messagePartEncoder, "messagePartEncoder"); + Contract.Ensures(Contract.Result<IMessagePartEncoder>() != null); + + IMessagePartEncoder encoder; + if (!encoders.TryGetValue(messagePartEncoder, out encoder)) { + try { + encoder = encoders[messagePartEncoder] = (IMessagePartEncoder)Activator.CreateInstance(messagePartEncoder); + } catch (MissingMethodException ex) { + throw ErrorUtilities.Wrap(ex, MessagingStrings.EncoderInstantiationFailed, messagePartEncoder.FullName); + } + } + + return encoder; + } + + /// <summary> + /// Gets the value of the message part, without converting it to/from a string. + /// </summary> + /// <param name="message">The message instance to read from.</param> + /// <returns>The value of the member.</returns> + private object GetValueAsObject(IMessage message) { + if (this.property != null) { + return this.property.GetValue(message, null); + } else { + return this.field.GetValue(message); + } + } + + /// <summary> + /// Sets the value of a message part directly with a given value. + /// </summary> + /// <param name="message">The message instance to read from.</param> + /// <param name="value">The value to set on the this part.</param> + private void SetValueAsObject(IMessage message, object value) { + if (this.property != null) { + this.property.SetValue(message, value, null); + } else { + this.field.SetValue(message, value); + } + } + + /// <summary> + /// Converts a string representation of the member's value to the appropriate type. + /// </summary> + /// <param name="value">The string representation of the member's value.</param> + /// <returns> + /// An instance of the appropriate type for setting the member. + /// </returns> + private object ToValue(string value) { + return this.converter.StringToValue(value); + } + + /// <summary> + /// Converts the member's value to its string representation. + /// </summary> + /// <param name="value">The value of the member.</param> + /// <param name="originalString">A value indicating whether a string matching the originally decoded string should be returned (as opposed to a normalized string).</param> + /// <returns> + /// The string representation of the member's value. + /// </returns> + private string ToString(object value, bool originalString) { + return originalString ? this.converter.ValueToOriginalString(value) : this.converter.ValueToString(value); + } + + /// <summary> + /// Validates that the message part and its attribute have agreeable settings. + /// </summary> + /// <exception cref="ArgumentException"> + /// Thrown when a non-nullable value type is set as optional. + /// </exception> + private void ValidateSettings() { + if (!this.IsRequired && IsNonNullableValueType(this.memberDeclaredType)) { + MemberInfo member = (MemberInfo)this.field ?? this.property; + throw new ArgumentException( + string.Format( + CultureInfo.CurrentCulture, + "Invalid combination: {0} on message type {1} is a non-nullable value type but is marked as optional.", + member.Name, + member.DeclaringType)); + } + } + } +} diff --git a/src/DotNetOpenAuth.Core/Messaging/Reflection/ValueMapping.cs b/src/DotNetOpenAuth.Core/Messaging/Reflection/ValueMapping.cs new file mode 100644 index 0000000..9c0fa83 --- /dev/null +++ b/src/DotNetOpenAuth.Core/Messaging/Reflection/ValueMapping.cs @@ -0,0 +1,67 @@ +//----------------------------------------------------------------------- +// <copyright file="ValueMapping.cs" company="Andrew Arnott"> +// Copyright (c) Andrew Arnott. All rights reserved. +// </copyright> +//----------------------------------------------------------------------- + +namespace DotNetOpenAuth.Messaging.Reflection { + using System; + using System.Diagnostics.Contracts; + + /// <summary> + /// A pair of conversion functions to map some type to a string and back again. + /// </summary> + [ContractVerification(true)] + internal struct ValueMapping { + /// <summary> + /// The mapping function that converts some custom type to a string. + /// </summary> + internal readonly Func<object, string> ValueToString; + + /// <summary> + /// The mapping function that converts some custom type to the original string + /// (possibly non-normalized) that represents it. + /// </summary> + internal readonly Func<object, string> ValueToOriginalString; + + /// <summary> + /// The mapping function that converts a string to some custom type. + /// </summary> + internal readonly Func<string, object> StringToValue; + + /// <summary> + /// Initializes a new instance of the <see cref="ValueMapping"/> struct. + /// </summary> + /// <param name="toString">The mapping function that converts some custom value to a string.</param> + /// <param name="toOriginalString">The mapping function that converts some custom value to its original (non-normalized) string. May be null if the same as the <paramref name="toString"/> function.</param> + /// <param name="toValue">The mapping function that converts a string to some custom value.</param> + internal ValueMapping(Func<object, string> toString, Func<object, string> toOriginalString, Func<string, object> toValue) { + Requires.NotNull(toString, "toString"); + Requires.NotNull(toValue, "toValue"); + + this.ValueToString = toString; + this.ValueToOriginalString = toOriginalString ?? toString; + this.StringToValue = toValue; + } + + /// <summary> + /// Initializes a new instance of the <see cref="ValueMapping"/> struct. + /// </summary> + /// <param name="encoder">The encoder.</param> + internal ValueMapping(IMessagePartEncoder encoder) { + Requires.NotNull(encoder, "encoder"); + var nullEncoder = encoder as IMessagePartNullEncoder; + string nullString = nullEncoder != null ? nullEncoder.EncodedNullValue : null; + + var originalStringEncoder = encoder as IMessagePartOriginalEncoder; + Func<object, string> originalStringEncode = encoder.Encode; + if (originalStringEncoder != null) { + originalStringEncode = originalStringEncoder.EncodeAsOriginalString; + } + + this.ValueToString = obj => (obj != null) ? encoder.Encode(obj) : nullString; + this.StringToValue = str => (str != null) ? encoder.Decode(str) : null; + this.ValueToOriginalString = obj => (obj != null) ? originalStringEncode(obj) : nullString; + } + } +} |