1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
|
//-----------------------------------------------------------------------
// <copyright file="CoordinatingOAuthChannel.cs" company="Andrew Arnott">
// Copyright (c) Andrew Arnott. All rights reserved.
// </copyright>
//-----------------------------------------------------------------------
namespace DotNetOpenAuth.Test.Mocks {
using System;
using System.Diagnostics.Contracts;
using System.Threading;
using DotNetOpenAuth.Messaging;
using DotNetOpenAuth.Messaging.Bindings;
using DotNetOpenAuth.OAuth.ChannelElements;
using DotNetOpenAuth.OAuth.Messages;
/// <summary>
/// A special channel used in test simulations to pass messages directly between two parties.
/// </summary>
/// <remarks>
/// Each instance owns an in-slot (<c>incomingMessage</c> for protocol messages,
/// <c>incomingRawResponse</c> for raw HTTP responses). The paired <see cref="RemoteChannel"/>
/// fills one of the slots and then raises this channel's single <c>incomingMessageSignal</c>.
/// Both slots share that one signal.
/// </remarks>
internal class CoordinatingOAuthChannel : OAuthChannel {
    /// <summary>
    /// Raised by the remote channel after it deposits a message or raw response into one of
    /// this channel's in-slots. As an <see cref="AutoResetEvent"/>, it releases one waiter and
    /// then resets itself.
    /// </summary>
    private readonly EventWaitHandle incomingMessageSignal = new AutoResetEvent(false);

    /// <summary>
    /// In-slot for a protocol message deposited by the remote channel; cleared once consumed.
    /// </summary>
    private IProtocolMessage incomingMessage;

    /// <summary>
    /// In-slot for a raw HTTP response deposited by the remote channel; cleared once consumed.
    /// </summary>
    private OutgoingWebResponse incomingRawResponse;

    /// <summary>
    /// Initializes a new instance of the <see cref="CoordinatingOAuthChannel"/> class for Consumers.
    /// </summary>
    /// <param name="signingBindingElement">
    /// The signing element for the Consumer to use. Null for the Service Provider.
    /// </param>
    /// <param name="tokenManager">The token manager to use.</param>
    internal CoordinatingOAuthChannel(ITamperProtectionChannelBindingElement signingBindingElement, IConsumerTokenManager tokenManager)
        : base(
            signingBindingElement,
            new NonceMemoryStore(StandardExpirationBindingElement.MaximumMessageAge),
            tokenManager) {
    }

    /// <summary>
    /// Initializes a new instance of the <see cref="CoordinatingOAuthChannel"/> class for Service Providers.
    /// </summary>
    /// <param name="signingBindingElement">
    /// The signing element for the Service Provider to use.
    /// </param>
    /// <param name="tokenManager">The Service Provider token manager to use.</param>
    internal CoordinatingOAuthChannel(ITamperProtectionChannelBindingElement signingBindingElement, IServiceProviderTokenManager tokenManager)
        : base(
            signingBindingElement,
            new NonceMemoryStore(StandardExpirationBindingElement.MaximumMessageAge),
            tokenManager) {
    }

    /// <summary>
    /// Gets or sets the coordinating channel used by the other party.
    /// </summary>
    internal CoordinatingOAuthChannel RemoteChannel { get; set; }

    /// <summary>
    /// Prepares and sends a protected resource request to the remote channel, then blocks until
    /// the remote party replies through <see cref="SendDirectRawResponse"/>.
    /// </summary>
    /// <param name="request">The protected resource request to send.</param>
    /// <returns>The raw response deposited by the remote channel.</returns>
    internal OutgoingWebResponse RequestProtectedResource(AccessProtectedResourceRequest request) {
        // Choose the pretend HTTP method before the outgoing binding elements run, since signing depends on it.
        var signedRequest = (ITamperResistantOAuthMessage)request;
        signedRequest.HttpMethod = GetHttpMethod(signedRequest.HttpMethods);
        this.ProcessOutgoingMessage(request);

        HttpRequestInfo requestInfo = this.SpoofHttpMethod(request);
        TestBase.TestLogger.InfoFormat("Sending protected resource request: {0}", requestInfo.Message);

        this.PostToRemote(requestInfo.Message);
        return this.AwaitIncomingRawResponse();
    }

    /// <summary>
    /// Deposits a raw HTTP response into the remote channel's in-slot and signals it.
    /// </summary>
    /// <param name="response">The raw response to hand to the other party.</param>
    internal void SendDirectRawResponse(OutgoingWebResponse response) {
        this.RemoteChannel.incomingRawResponse = response;
        this.RemoteChannel.incomingMessageSignal.Set();
    }

    /// <summary>
    /// Blocks until the remote channel delivers a message, then wraps it as though it had
    /// arrived in an HTTP request.
    /// </summary>
    /// <returns>A request wrapper around the delivered message.</returns>
    protected internal override HttpRequestInfo GetRequestFromContext() {
        var directedMessage = (IDirectedProtocolMessage)this.AwaitIncomingMessage();
        return new HttpRequestInfo(directedMessage, directedMessage.HttpMethods);
    }

    /// <summary>
    /// Sends a request to the remote channel and blocks until it posts a message back.
    /// </summary>
    /// <param name="request">The request to send.</param>
    /// <returns>The message deposited by the remote channel in reply.</returns>
    protected override IProtocolMessage RequestCore(IDirectedProtocolMessage request) {
        HttpRequestInfo requestInfo = this.SpoofHttpMethod(request);
        this.PostToRemote(requestInfo.Message);

        // Now wait for a response...
        return this.AwaitIncomingMessage();
    }

    /// <summary>
    /// Delivers a copy of <paramref name="response"/> directly to the remote channel.
    /// </summary>
    /// <param name="response">The response message to deliver.</param>
    /// <returns>An empty placeholder response.</returns>
    protected override OutgoingWebResponse PrepareDirectResponse(IProtocolMessage response) {
        this.PostToRemote(this.CloneSerializedParts(response));
        return new OutgoingWebResponse(); // not used, but returning null is not allowed
    }

    /// <summary>
    /// Delivers an indirect message exactly as a direct response would be delivered.
    /// </summary>
    /// <param name="message">The message to deliver.</param>
    /// <returns>An empty placeholder response.</returns>
    protected override OutgoingWebResponse PrepareIndirectResponse(IDirectedProtocolMessage message) {
        // In this mock transport, direct and indirect messages are the same.
        return this.PrepareDirectResponse(message);
    }

    /// <summary>
    /// Returns the message object already carried by the spoofed request; no wire parsing is needed.
    /// </summary>
    /// <param name="request">The spoofed request.</param>
    /// <returns>The message attached to <paramref name="request"/>.</returns>
    protected override IDirectedProtocolMessage ReadFromRequestCore(HttpRequestInfo request) {
        return request.Message;
    }

    /// <summary>
    /// Picks the pretend HTTP verb: "POST" when POST delivery is allowed, otherwise "GET".
    /// </summary>
    /// <param name="methods">The allowed delivery methods.</param>
    /// <returns>"POST" or "GET".</returns>
    private static string GetHttpMethod(HttpDeliveryMethods methods) {
        return (methods & HttpDeliveryMethods.PostRequest) != 0 ? "POST" : "GET";
    }

    /// <summary>
    /// Drops a message into the remote channel's in-slot and lets it know the message is there.
    /// </summary>
    /// <param name="message">The message to hand to the other party.</param>
    private void PostToRemote(IProtocolMessage message) {
        // The slot must be written before signaling, so the woken waiter sees the message.
        this.RemoteChannel.incomingMessage = message;
        this.RemoteChannel.incomingMessageSignal.Set();
    }

    /// <summary>
    /// Spoof HTTP request information for signing/verification purposes.
    /// </summary>
    /// <param name="message">The message to add a pretend HTTP method to.</param>
    /// <returns>A spoofed HttpRequestInfo that wraps a clone of the message built by the remote channel's factory.</returns>
    private HttpRequestInfo SpoofHttpMethod(IDirectedProtocolMessage message) {
        HttpRequestInfo requestInfo = new HttpRequestInfo(message, message.HttpMethods);

        var signedMessage = message as ITamperResistantOAuthMessage;
        if (signedMessage != null) {
            // Keep the request wrapper and the message itself in agreement about the verb and URL.
            string httpMethod = GetHttpMethod(signedMessage.HttpMethods);
            requestInfo.HttpMethod = httpMethod;
            requestInfo.UrlBeforeRewriting = message.Recipient;
            signedMessage.HttpMethod = httpMethod;
        }

        requestInfo.Message = this.CloneSerializedParts(message);
        return requestInfo;
    }

    /// <summary>
    /// Blocks until signaled, then takes the message from this channel's in-slot and clears the slot.
    /// </summary>
    /// <returns>
    /// The deposited message. The same signal is raised for raw responses, so this may be null
    /// if the remote party sent a raw response instead.
    /// </returns>
    private IProtocolMessage AwaitIncomingMessage() {
        this.incomingMessageSignal.WaitOne();
        IProtocolMessage response = this.incomingMessage;
        this.incomingMessage = null;
        return response;
    }

    /// <summary>
    /// Blocks until signaled, then takes the raw response from this channel's in-slot and clears the slot.
    /// </summary>
    /// <returns>The deposited raw response (may be null if a protocol message was sent instead).</returns>
    private OutgoingWebResponse AwaitIncomingRawResponse() {
        this.incomingMessageSignal.WaitOne();
        OutgoingWebResponse response = this.incomingRawResponse;
        this.incomingRawResponse = null;
        return response;
    }

    /// <summary>
    /// Produces a copy of <paramref name="message"/> by serializing its parts and deserializing
    /// them into a fresh instance created by the remote channel's message factory.
    /// </summary>
    /// <typeparam name="T">The static type of the message.</typeparam>
    /// <param name="message">The message to clone.</param>
    /// <returns>The cloned message.</returns>
    /// <exception cref="InvalidOperationException">
    /// Thrown when <paramref name="message"/> is neither a request nor a direct response.
    /// </exception>
    private T CloneSerializedParts<T>(T message) where T : class, IProtocolMessage {
        Contract.Requires<ArgumentNullException>(message != null);

        var messageAccessor = this.MessageDescriptions.GetAccessor(message);
        var fields = messageAccessor.Serialize();

        IProtocolMessage clonedMessage;
        var directedMessage = message as IDirectedProtocolMessage;
        var directResponse = message as IDirectResponseProtocolMessage;
        if (directedMessage != null && directedMessage.IsRequest()) {
            MessageReceivingEndpoint recipient = null;
            if (directedMessage.Recipient != null) {
                recipient = new MessageReceivingEndpoint(directedMessage.Recipient, directedMessage.HttpMethods);
            }

            clonedMessage = this.RemoteChannel.MessageFactory.GetNewRequestMessage(recipient, fields);
        } else if (directResponse != null && directResponse.IsDirectResponse()) {
            clonedMessage = this.RemoteChannel.MessageFactory.GetNewResponseMessage(directResponse.OriginatingRequest, fields);
        } else {
            throw new InvalidOperationException("Totally expected a message to implement one of the two derived interface types.");
        }

        // Fill the cloned message with data.
        var clonedMessageAccessor = this.MessageDescriptions.GetAccessor(clonedMessage);
        clonedMessageAccessor.Deserialize(fields);

        return (T)clonedMessage;
    }
}
}
|