summaryrefslogtreecommitdiffstats
path: root/src/DotNetOpenAuth.Core/Messaging/StandardMessageFactory.cs
blob: 762b54b5f22a477d6ab785d386bc5346d12c26c8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
//-----------------------------------------------------------------------
// <copyright file="StandardMessageFactory.cs" company="Outercurve Foundation">
//     Copyright (c) Outercurve Foundation. All rights reserved.
// </copyright>
//-----------------------------------------------------------------------

namespace DotNetOpenAuth.Messaging {
	using System;
	using System.Collections.Generic;
	using System.Diagnostics.Contracts;
	using System.Linq;
	using System.Reflection;
	using System.Text;
	using DotNetOpenAuth.Messaging.Reflection;

	/// <summary>
	/// A message factory that automatically selects the message type based on the incoming data.
	/// </summary>
	internal class StandardMessageFactory : IMessageFactory {
		/// <summary>
		/// The request message types and their constructors to use for instantiating the messages.
		/// </summary>
		private readonly Dictionary<MessageDescription, ConstructorInfo> requestMessageTypes = new Dictionary<MessageDescription, ConstructorInfo>();

		/// <summary>
		/// The response message types and their constructors to use for instantiating the messages.
		/// </summary>
		/// <value>
		/// The value is a dictionary, whose key is the type of the constructor's lone parameter.
		/// </value>
		private readonly Dictionary<MessageDescription, Dictionary<Type, ConstructorInfo>> responseMessageTypes = new Dictionary<MessageDescription, Dictionary<Type, ConstructorInfo>>();

		/// <summary>
		/// Initializes a new instance of the <see cref="StandardMessageFactory"/> class.
		/// </summary>
		internal StandardMessageFactory() {
		}

		/// <summary>
		/// Adds message types to the set that this factory can create.
		/// </summary>
		/// <param name="messageTypes">The message types that this factory may instantiate.</param>
		/// <remarks>
		/// A message type is accepted as a request message when it implements <see cref="IDirectedProtocolMessage"/>
		/// and has a (<see cref="Uri"/>, <see cref="Version"/>) constructor.  It is accepted as a response message
		/// when it implements <see cref="IDirectResponseProtocolMessage"/> and has at least one constructor whose
		/// lone parameter is an <see cref="IDirectedProtocolMessage"/>-assignable type.  A type may qualify as both.
		/// </remarks>
		/// <exception cref="NotSupportedException">Thrown (via <c>ErrorUtilities.VerifySupported</c>) if any message type fits neither pattern.</exception>
		public virtual void AddMessageTypes(IEnumerable<MessageDescription> messageTypes) {
			Requires.NotNull(messageTypes, "messageTypes");
			Requires.True(messageTypes.All(msg => msg != null), "messageTypes");

			var unsupportedMessageTypes = new List<MessageDescription>(0);
			foreach (MessageDescription messageDescription in messageTypes) {
				bool supportedMessageType = false;

				// First see whether this message fits the recognized pattern for request messages.
				if (typeof(IDirectedProtocolMessage).IsAssignableFrom(messageDescription.MessageType)) {
					foreach (ConstructorInfo ctor in messageDescription.Constructors) {
						ParameterInfo[] parameters = ctor.GetParameters();
						if (parameters.Length == 2 && parameters[0].ParameterType == typeof(Uri) && parameters[1].ParameterType == typeof(Version)) {
							supportedMessageType = true;
							this.requestMessageTypes.Add(messageDescription, ctor);
							break;
						}
					}
				}

				// Also see if this message fits the recognized pattern for response messages.
				if (typeof(IDirectResponseProtocolMessage).IsAssignableFrom(messageDescription.MessageType)) {
					// Keyed by the constructor's parameter type so the best ctor can be chosen per request type later.
					var responseCtors = new Dictionary<Type, ConstructorInfo>(messageDescription.Constructors.Length);
					foreach (ConstructorInfo ctor in messageDescription.Constructors) {
						ParameterInfo[] parameters = ctor.GetParameters();
						if (parameters.Length == 1 && typeof(IDirectedProtocolMessage).IsAssignableFrom(parameters[0].ParameterType)) {
							responseCtors.Add(parameters[0].ParameterType, ctor);
						}
					}

					if (responseCtors.Count > 0) {
						supportedMessageType = true;
						this.responseMessageTypes.Add(messageDescription, responseCtors);
					}
				}

				if (!supportedMessageType) {
					unsupportedMessageTypes.Add(messageDescription);
				}
			}

			ErrorUtilities.VerifySupported(
				!unsupportedMessageTypes.Any(),
				MessagingStrings.StandardMessageFactoryUnsupportedMessageType,
				unsupportedMessageTypes.ToStringDeferred());
		}

		#region IMessageFactory Members

		/// <summary>
		/// Analyzes an incoming request message payload to discover what kind of
		/// message is embedded in it and instantiates that type, or returns null if no match is found.
		/// </summary>
		/// <param name="recipient">The intended or actual recipient of the request message.</param>
		/// <param name="fields">The name/value pairs that make up the message payload.</param>
		/// <returns>
		/// A newly instantiated <see cref="IProtocolMessage"/>-derived object that this message can
		/// deserialize to.  Null if the request isn't recognized as a valid protocol message.
		/// </returns>
		public virtual IDirectedProtocolMessage GetNewRequestMessage(MessageReceivingEndpoint recipient, IDictionary<string, string> fields) {
			MessageDescription matchingType = this.GetMessageDescription(recipient, fields);
			if (matchingType != null) {
				return this.InstantiateAsRequest(matchingType, recipient);
			} else {
				return null;
			}
		}

		/// <summary>
		/// Analyzes an incoming direct response payload to discover what kind of
		/// message is embedded in it and instantiates that type, or returns null if no match is found.
		/// </summary>
		/// <param name="request">The message that was sent as a request that resulted in the response.</param>
		/// <param name="fields">The name/value pairs that make up the message payload.</param>
		/// <returns>
		/// A newly instantiated <see cref="IProtocolMessage"/>-derived object that this message can
		/// deserialize to.  Null if the response isn't recognized as a valid protocol message.
		/// </returns>
		public virtual IDirectResponseProtocolMessage GetNewResponseMessage(IDirectedProtocolMessage request, IDictionary<string, string> fields) {
			MessageDescription matchingType = this.GetMessageDescription(request, fields);
			if (matchingType != null) {
				return this.InstantiateAsResponse(matchingType, request);
			} else {
				return null;
			}
		}

		#endregion

		/// <summary>
		/// Gets the message type that best fits the given incoming request data.
		/// </summary>
		/// <param name="recipient">The recipient of the incoming data.  Typically not used, but included just in case.</param>
		/// <param name="fields">The data of the incoming message.</param>
		/// <returns>
		/// The message type that matches the incoming data; or <c>null</c> if no match.
		/// </returns>
		/// <exception cref="ProtocolException">May be thrown if the incoming data is ambiguous.</exception>
		protected virtual MessageDescription GetMessageDescription(MessageReceivingEndpoint recipient, IDictionary<string, string> fields) {
			Requires.NotNull(recipient, "recipient");
			Requires.NotNull(fields, "fields");

			// Prefer the candidate that recognizes the most incoming field names; break ties in favor
			// of the candidate with the larger part mapping.
			var matches = this.requestMessageTypes.Keys
				.Where(message => message.CheckMessagePartsPassBasicValidation(fields))
				.OrderByDescending(message => CountInCommon(message.Mapping.Keys, fields.Keys))
				.ThenByDescending(message => message.Mapping.Count)
				.CacheGeneratedResults();
			return SelectBestMatch(matches);
		}

		/// <summary>
		/// Gets the message type that best fits the given incoming direct response data.
		/// </summary>
		/// <param name="request">The request message that prompted the response data.</param>
		/// <param name="fields">The data of the incoming message.</param>
		/// <returns>
		/// The message type that matches the incoming data; or <c>null</c> if no match.
		/// </returns>
		/// <exception cref="ProtocolException">May be thrown if the incoming data is ambiguous.</exception>
		protected virtual MessageDescription GetMessageDescription(IDirectedProtocolMessage request, IDictionary<string, string> fields) {
			Requires.NotNull(request, "request");
			Requires.NotNull(fields, "fields");

			Type requestType = request.GetType();

			// Rank first by how closely the best-fitting constructor's parameter type matches the actual
			// request type, then by field overlap, then by the size of the part mapping.
			var matches = (from responseMessageType in this.responseMessageTypes
			               let message = responseMessageType.Key
			               where message.CheckMessagePartsPassBasicValidation(fields)
			               let ctors = this.FindMatchingResponseConstructors(message, requestType)
			               where ctors.Any()
			               let distance = ctors.Min(ctor => GetDerivationDistance(ctor.GetParameters()[0].ParameterType, requestType))
			               orderby distance,
			                 CountInCommon(message.Mapping.Keys, fields.Keys) descending,
			                 message.Mapping.Count descending
			               select message).CacheGeneratedResults();
			return SelectBestMatch(matches);
		}

		/// <summary>
		/// Instantiates the given request message type.
		/// </summary>
		/// <param name="messageDescription">The message description.</param>
		/// <param name="recipient">The recipient.</param>
		/// <returns>The instantiated message.  Never null.</returns>
		protected virtual IDirectedProtocolMessage InstantiateAsRequest(MessageDescription messageDescription, MessageReceivingEndpoint recipient) {
			Requires.NotNull(messageDescription, "messageDescription");
			Requires.NotNull(recipient, "recipient");
			Contract.Ensures(Contract.Result<IDirectedProtocolMessage>() != null);

			// The constructor was verified in AddMessageTypes to take (Uri, Version).
			ConstructorInfo ctor = this.requestMessageTypes[messageDescription];
			return (IDirectedProtocolMessage)ctor.Invoke(new object[] { recipient.Location, messageDescription.MessageVersion });
		}

		/// <summary>
		/// Instantiates the given response message type.
		/// </summary>
		/// <param name="messageDescription">The message description.</param>
		/// <param name="request">The request that resulted in this response.</param>
		/// <returns>The instantiated message.  Never null.</returns>
		/// <remarks>
		/// Exactly one constructor of the response type must accept <paramref name="request"/>;
		/// zero or several matching constructors are reported via <c>ErrorUtilities.ThrowInternal</c>.
		/// </remarks>
		protected virtual IDirectResponseProtocolMessage InstantiateAsResponse(MessageDescription messageDescription, IDirectedProtocolMessage request) {
			Requires.NotNull(messageDescription, "messageDescription");
			Requires.NotNull(request, "request");
			Contract.Ensures(Contract.Result<IDirectResponseProtocolMessage>() != null);

			Type requestType = request.GetType();

			// Materialize once so the count check and the invocation see the same constructors.
			List<ConstructorInfo> ctors = this.FindMatchingResponseConstructors(messageDescription, requestType).ToList();
			if (ctors.Count > 1) {
				ErrorUtilities.ThrowInternal("More than one matching constructor for request type " + requestType.Name + " and response type " + messageDescription.MessageType.Name);
			} else if (ctors.Count == 0) {
				ErrorUtilities.ThrowInternal("Unexpected request message type " + requestType.FullName + " for response type " + messageDescription.MessageType.Name);
			}

			return (IDirectResponseProtocolMessage)ctors[0].Invoke(new object[] { request });
		}

		/// <summary>
		/// Returns the first (best-ranked) candidate of an already-ordered sequence,
		/// logging a warning if more than one candidate was found.
		/// </summary>
		/// <param name="rankedMatches">The candidate message types, best first.</param>
		/// <returns>The best candidate; or <c>null</c> if the sequence is empty.</returns>
		private static MessageDescription SelectBestMatch(IEnumerable<MessageDescription> rankedMatches) {
			MessageDescription match = rankedMatches.FirstOrDefault();
			if (match != null && Logger.Messaging.IsWarnEnabled && rankedMatches.Skip(1).Any()) {
				Logger.Messaging.WarnFormat(
					"Multiple message types seemed to fit the incoming data: {0}",
					rankedMatches.ToStringDeferred());
			}

			// A null result means no message type matches the incoming data.
			return match;
		}

		/// <summary>
		/// Gets the hierarchical distance between a type and a type it derives from or implements.
		/// </summary>
		/// <param name="assignableType">The base type or interface.</param>
		/// <param name="derivedType">The concrete class that implements the <paramref name="assignableType"/>.</param>
		/// <returns>The distance between the two types.  0 if the types are equivalent, 1 if the type immediately derives from or implements the base type, or progressively higher integers.</returns>
		private static int GetDerivationDistance(Type assignableType, Type derivedType) {
			Requires.NotNull(assignableType, "assignableType");
			Requires.NotNull(derivedType, "derivedType");
			Requires.True(assignableType.IsAssignableFrom(derivedType), "assignableType");

			// Mutual assignability means the two types are equivalent.
			if (derivedType.IsAssignableFrom(assignableType)) {
				return 0;
			}

			// Count base-class hops while the ancestor is still assignable to assignableType.
			// For a class target this stops exactly on assignableType itself; for an interface target
			// it stops on the most-base class that still implements the interface.
			int steps = 0;
			for (Type current = derivedType.BaseType; current != null && assignableType.IsAssignableFrom(current); current = current.BaseType) {
				steps++;
			}

			// Reaching an interface takes one more step beyond the class that implements it.
			return assignableType.IsInterface ? steps + 1 : steps;
		}

		/// <summary>
		/// Counts how many strings are in the intersection of two collections.
		/// </summary>
		/// <param name="collection1">The first collection.</param>
		/// <param name="collection2">The second collection.</param>
		/// <param name="comparison">The string comparison method to use.</param>
		/// <returns>A non-negative integer no greater than the count of elements in the smallest collection.</returns>
		private static int CountInCommon(ICollection<string> collection1, ICollection<string> collection2, StringComparison comparison = StringComparison.Ordinal) {
			Requires.NotNull(collection1, "collection1");
			Requires.NotNull(collection2, "collection2");
			Contract.Ensures(Contract.Result<int>() >= 0 && Contract.Result<int>() <= Math.Min(collection1.Count, collection2.Count));

			return collection1.Count(value1 => collection2.Any(value2 => string.Equals(value1, value2, comparison)));
		}

		/// <summary>
		/// Finds constructors for response messages that take a given request message type.
		/// </summary>
		/// <param name="messageDescription">The message description.</param>
		/// <param name="requestType">Type of the request message.</param>
		/// <returns>A lazily-evaluated sequence of matching constructors.</returns>
		private IEnumerable<ConstructorInfo> FindMatchingResponseConstructors(MessageDescription messageDescription, Type requestType) {
			Requires.NotNull(messageDescription, "messageDescription");
			Requires.NotNull(requestType, "requestType");

			return this.responseMessageTypes[messageDescription].Where(pair => pair.Key.IsAssignableFrom(requestType)).Select(pair => pair.Value);
		}
	}
}