1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
|
//-----------------------------------------------------------------------
// <copyright file="CoordinatingOAuthChannel.cs" company="Andrew Arnott">
// Copyright (c) Andrew Arnott. All rights reserved.
// </copyright>
//-----------------------------------------------------------------------
namespace DotNetOAuth.Test.Scenarios {
using System;
using System.Reflection;
using System.Threading;
using DotNetOAuth.ChannelElements;
using DotNetOAuth.Messages;
using DotNetOAuth.Messaging;
using DotNetOAuth.Messaging.Bindings;
using DotNetOAuth.Messaging.Reflection;
using DotNetOAuth.Test.Mocks;
/// <summary>
/// A special channel used in test simulations to pass messages directly between two parties.
/// </summary>
/// <remarks>
/// Outgoing messages are dropped into the single in-slot of <see cref="RemoteChannel"/> and the
/// remote channel's wait handle is signaled. Incoming messages are read from this channel's own
/// in-slots after blocking on this channel's wait handle. Each slot holds one item, so a second
/// delivery before the first is consumed overwrites it.
/// </remarks>
internal class CoordinatingOAuthChannel : OAuthChannel {
    /// <summary>
    /// Signaled by the remote party whenever it fills one of this channel's in-slots.
    /// </summary>
    private readonly EventWaitHandle incomingMessageSignal = new AutoResetEvent(false);

    /// <summary>
    /// In-slot for a protocol message delivered by the remote party; cleared once consumed.
    /// </summary>
    private IProtocolMessage incomingMessage;

    /// <summary>
    /// In-slot for a raw response delivered by the remote party; cleared once consumed.
    /// </summary>
    private Response incomingRawResponse;

    /// <summary>
    /// Initializes a new instance of the <see cref="CoordinatingOAuthChannel"/> class
    /// for either a Consumer or a Service Provider.
    /// </summary>
    /// <param name="signingBindingElement">
    /// The signing element for the Consumer to use. Null for the Service Provider.
    /// </param>
    internal CoordinatingOAuthChannel(ITamperProtectionChannelBindingElement signingBindingElement)
        : base(
            signingBindingElement,
            new NonceMemoryStore(StandardExpirationBindingElement.DefaultMaximumMessageAge),
            new OAuthMessageTypeProvider(new InMemoryTokenManager()),
            new TestWebRequestHandler()) {
    }

    /// <summary>
    /// Gets or sets the coordinating channel used by the other party.
    /// </summary>
    internal CoordinatingOAuthChannel RemoteChannel { get; set; }

    /// <summary>
    /// Prepares (signs) a protected resource request, delivers it to the remote party, and blocks
    /// until the remote party replies via <see cref="SendDirectRawResponse"/>.
    /// </summary>
    /// <param name="request">The protected resource request to send.</param>
    /// <returns>The raw response delivered by the remote party.</returns>
    /// <exception cref="ArgumentNullException">Thrown when <paramref name="request"/> is null.</exception>
    /// <exception cref="InvalidOperationException">Thrown when <see cref="RemoteChannel"/> has not been set.</exception>
    internal Response RequestProtectedResource(AccessProtectedResourcesMessage request) {
        if (request == null) {
            throw new ArgumentNullException("request");
        }

        // OAuth signatures cover the HTTP method, so it must be set before the message is prepared for sending.
        var signedRequest = (ITamperResistantOAuthMessage)request;
        signedRequest.HttpMethod = GetHttpMethod(signedRequest.HttpMethods);
        this.PrepareMessageForSending(request);

        HttpRequestInfo requestInfo = this.SpoofHttpMethod(request);
        TestBase.TestLogger.InfoFormat("Sending protected resource request: {0}", requestInfo.Message);
        this.DeliverMessageToRemote(requestInfo.Message);
        return this.AwaitIncomingRawResponse();
    }

    /// <summary>
    /// Delivers a raw response to the remote party, releasing it if it is blocked in
    /// <see cref="RequestProtectedResource"/>.
    /// </summary>
    /// <param name="response">The raw response to deliver.</param>
    /// <exception cref="InvalidOperationException">Thrown when <see cref="RemoteChannel"/> has not been set.</exception>
    internal void SendDirectRawResponse(Response response) {
        CoordinatingOAuthChannel remote = this.GetRemoteChannel();
        remote.incomingRawResponse = response;
        remote.incomingMessageSignal.Set();
    }

    /// <summary>
    /// Delivers a request to the remote party and blocks until it replies with a message.
    /// </summary>
    /// <param name="request">The request to send.</param>
    /// <returns>The message the remote party sent back.</returns>
    protected override IProtocolMessage RequestInternal(IDirectedProtocolMessage request) {
        HttpRequestInfo requestInfo = this.SpoofHttpMethod(request);
        TestBase.TestLogger.InfoFormat("Sending request: {0}", requestInfo.Message);
        this.DeliverMessageToRemote(requestInfo.Message);

        // Now wait for a response...
        return this.AwaitIncomingMessage();
    }

    /// <summary>
    /// Delivers a serialized round-tripped copy of <paramref name="response"/> to the remote party.
    /// </summary>
    /// <param name="response">The response message to send.</param>
    protected override void SendDirectMessageResponse(IProtocolMessage response) {
        TestBase.TestLogger.InfoFormat("Sending response: {0}", response);
        IProtocolMessage copy = this.CloneSerializedParts(response);
        CopyDirectionalParts(response, copy);
        this.DeliverMessageToRemote(copy);
    }

    /// <summary>
    /// Sends an indirect message. In this mock transport it is delivered exactly like a direct response.
    /// </summary>
    /// <param name="message">The message to send.</param>
    protected override void SendIndirectMessage(IDirectedProtocolMessage message) {
        TestBase.TestLogger.Info("Next response is an indirect message...");

        // In this mock transport, direct and indirect messages are the same.
        this.SendDirectMessageResponse(message);
    }

    /// <summary>
    /// Blocks until the remote party delivers a message, then wraps it as the "current" request.
    /// </summary>
    /// <returns>The wrapped incoming message.</returns>
    protected override HttpRequestInfo GetRequestFromContext() {
        return new HttpRequestInfo(this.AwaitIncomingMessage());
    }

    /// <summary>
    /// Returns the message already carried by the spoofed request; no parsing is needed.
    /// </summary>
    /// <param name="request">The spoofed request.</param>
    /// <returns>The message carried by <paramref name="request"/>.</returns>
    protected override IProtocolMessage ReadFromRequestInternal(HttpRequestInfo request) {
        return request.Message;
    }

    /// <summary>
    /// Spoof HTTP request information for signing/verification purposes.
    /// </summary>
    /// <param name="message">The message to add a pretend HTTP method to.</param>
    /// <returns>A spoofed HttpRequestInfo that wraps a serialized copy of the message.</returns>
    private HttpRequestInfo SpoofHttpMethod(IDirectedProtocolMessage message) {
        HttpRequestInfo requestInfo = new HttpRequestInfo(message);
        var signedMessage = message as ITamperResistantOAuthMessage;
        if (signedMessage != null) {
            string httpMethod = GetHttpMethod(signedMessage.HttpMethods);
            requestInfo.HttpMethod = httpMethod;
            requestInfo.Url = message.Recipient;
            signedMessage.HttpMethod = httpMethod;
        }

        // Hand the remote party a serialize/deserialize round-tripped copy, then carry over the
        // HTTP method, which CloneSerializedParts does not appear to preserve on its own.
        requestInfo.Message = this.CloneSerializedParts(message);
        CopyDirectionalParts(message, requestInfo.Message);
        return requestInfo;
    }

    /// <summary>
    /// Blocks (with no timeout) until the remote party signals, then takes the message from the in-slot.
    /// </summary>
    /// <returns>The delivered message, or null if the signal was for a raw response instead.</returns>
    private IProtocolMessage AwaitIncomingMessage() {
        this.incomingMessageSignal.WaitOne();
        IProtocolMessage response = this.incomingMessage;
        this.incomingMessage = null;
        return response;
    }

    /// <summary>
    /// Blocks (with no timeout) until the remote party signals, then takes the raw response from the in-slot.
    /// </summary>
    /// <returns>The delivered raw response, or null if the signal was for a message instead.</returns>
    private Response AwaitIncomingRawResponse() {
        this.incomingMessageSignal.WaitOne();
        Response response = this.incomingRawResponse;
        this.incomingRawResponse = null;
        return response;
    }

    /// <summary>
    /// Gets <see cref="RemoteChannel"/>, failing with a descriptive error if it has not been set.
    /// </summary>
    /// <returns>The non-null remote channel.</returns>
    /// <exception cref="InvalidOperationException">Thrown when <see cref="RemoteChannel"/> is null.</exception>
    private CoordinatingOAuthChannel GetRemoteChannel() {
        CoordinatingOAuthChannel remote = this.RemoteChannel;
        if (remote == null) {
            throw new InvalidOperationException("RemoteChannel must be set before messages can be sent.");
        }

        return remote;
    }

    /// <summary>
    /// Drops a message into the remote channel's in-slot and lets it know the message is there.
    /// </summary>
    /// <param name="message">The message to deliver.</param>
    private void DeliverMessageToRemote(IProtocolMessage message) {
        CoordinatingOAuthChannel remote = this.GetRemoteChannel();
        remote.incomingMessage = message;
        remote.incomingMessageSignal.Set();
    }

    /// <summary>
    /// Produces a copy of <paramref name="message"/> by serializing and deserializing it.
    /// </summary>
    /// <typeparam name="T">The message type.</typeparam>
    /// <param name="message">The message to clone.</param>
    /// <returns>The deserialized copy.</returns>
    /// <exception cref="ArgumentNullException">Thrown when <paramref name="message"/> is null.</exception>
    private T CloneSerializedParts<T>(T message) where T : class, IProtocolMessage {
        if (message == null) {
            throw new ArgumentNullException("message");
        }

        MessageReceivingEndpoint recipient = null;
        var directedMessage = message as IOAuthDirectedMessage;
        if (directedMessage != null && directedMessage.Recipient != null) {
            recipient = new MessageReceivingEndpoint(directedMessage.Recipient, directedMessage.HttpMethods);
        }

        MessageSerializer serializer = MessageSerializer.Get(message.GetType());
        return (T)serializer.Deserialize(serializer.Serialize(message), recipient);
    }

    /// <summary>
    /// Copies the HTTP method from <paramref name="original"/> to <paramref name="copy"/>
    /// when both are tamper-resistant OAuth messages; otherwise does nothing.
    /// </summary>
    /// <param name="original">The source message.</param>
    /// <param name="copy">The destination message.</param>
    private static void CopyDirectionalParts(IProtocolMessage original, IProtocolMessage copy) {
        var signedOriginal = original as ITamperResistantOAuthMessage;
        var signedCopy = copy as ITamperResistantOAuthMessage;
        if (signedOriginal != null && signedCopy != null) {
            signedCopy.HttpMethod = signedOriginal.HttpMethod;
        }
    }

    /// <summary>
    /// Picks "POST" when POST is among the allowed delivery methods, otherwise "GET".
    /// </summary>
    /// <param name="methods">The allowed HTTP delivery methods.</param>
    /// <returns>"POST" or "GET".</returns>
    private static string GetHttpMethod(HttpDeliveryMethod methods) {
        return (methods & HttpDeliveryMethod.PostRequest) != 0 ? "POST" : "GET";
    }
}
}
|