diff options
Diffstat (limited to 'src')
8 files changed, 93 insertions, 116 deletions
diff --git a/src/DotNetOpenAuth.Test/CoordinatorBase.cs b/src/DotNetOpenAuth.Test/CoordinatorBase.cs index d1c6f85..48067af 100644 --- a/src/DotNetOpenAuth.Test/CoordinatorBase.cs +++ b/src/DotNetOpenAuth.Test/CoordinatorBase.cs @@ -6,7 +6,9 @@ namespace DotNetOpenAuth.Test { using System; + using System.Collections.Generic; using System.Threading; + using System.Threading.Tasks; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId.RelyingParty; using DotNetOpenAuth.Test.Mocks; @@ -14,10 +16,10 @@ namespace DotNetOpenAuth.Test { using Validation; internal abstract class CoordinatorBase<T1, T2> { - private Action<T1> party1Action; - private Action<T2> party2Action; + private Func<T1, CancellationToken, Task> party1Action; + private Func<T2, CancellationToken, Task> party2Action; - protected CoordinatorBase(Action<T1> party1Action, Action<T2> party2Action) { + protected CoordinatorBase(Func<T1, CancellationToken, Task> party1Action, Func<T2, CancellationToken, Task> party2Action) { Requires.NotNull(party1Action, "party1Action"); Requires.NotNull(party2Action, "party2Action"); @@ -29,62 +31,24 @@ namespace DotNetOpenAuth.Test { protected internal Action<IProtocolMessage> OutgoingMessageFilter { get; set; } - internal abstract void Run(); + internal abstract Task RunAsync(); - protected void RunCore(T1 party1Object, T2 party2Object) { - Thread party1Thread = null, party2Thread = null; - Exception failingException = null; + protected async Task RunCoreAsync(T1 party1Object, T2 party2Object) { + var cts = new CancellationTokenSource(); - // Each thread we create needs a surrounding exception catcher so that we can - // terminate the other thread and inform the test host that the test failed. - Action<Action> safeWrapper = (action) => { - try { - TestBase.SetMockHttpContext(); - action(); - } catch (Exception ex) { - // We may be the second thread in an ThreadAbortException, so check the "flag" - lock (this) { - if (failingException == null || (failingException is ThreadAbortException && !(ex is ThreadAbortException))) { - failingException = ex; - if (Thread.CurrentThread == party1Thread) { - party2Thread.Abort(); - } else { - party1Thread.Abort(); - } - } - } - } - }; - - // Run the threads, and wait for them to complete. - // If this main thread is aborted (test run aborted), go ahead and abort the other two threads. - party1Thread = new Thread(() => { safeWrapper(() => { this.party1Action(party1Object); }); }); - party2Thread = new Thread(() => { safeWrapper(() => { this.party2Action(party2Object); }); }); - party1Thread.Name = "P1"; - party2Thread.Name = "P2"; try { - party1Thread.Start(); - party2Thread.Start(); - party1Thread.Join(); - party2Thread.Join(); - } catch (ThreadAbortException) { - party1Thread.Abort(); - party2Thread.Abort(); + var parties = new List<Task> { + Task.Run(() => this.party1Action(party1Object, cts.Token)), + Task.Run(() => this.party2Action(party2Object, cts.Token)), + }; + var completingTask = await Task.WhenAny(parties); + await completingTask; // rethrow any exception from the first completing task. + + // if no exception, then block for the second task now. + await Task.WhenAll(parties); + } catch { + cts.Cancel(); // cause the second party to terminate, if necessary. throw; - } catch (ThreadStartException ex) { - if (ex.InnerException is ThreadAbortException) { - // if party1Thread threw an exception - // (which may even have been intentional for the test) - // before party2Thread even started, then this exception - // can be thrown, and should be ignored. - } else { - throw; - } - } - - // Use the failing reason of a failing sub-thread as our reason, if anything failed. - if (failingException != null) { - throw new AssertionException("Coordinator thread threw unhandled exception: " + failingException, failingException); } } } diff --git a/src/DotNetOpenAuth.Test/Messaging/Bindings/StandardExpirationBindingElementTests.cs b/src/DotNetOpenAuth.Test/Messaging/Bindings/StandardExpirationBindingElementTests.cs index 6aa9461..d525766 100644 --- a/src/DotNetOpenAuth.Test/Messaging/Bindings/StandardExpirationBindingElementTests.cs +++ b/src/DotNetOpenAuth.Test/Messaging/Bindings/StandardExpirationBindingElementTests.cs @@ -6,7 +6,7 @@ namespace DotNetOpenAuth.Test.Messaging.Bindings { using System; - + using System.Threading.Tasks; using DotNetOpenAuth.Configuration; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Bindings; @@ -16,13 +16,13 @@ namespace DotNetOpenAuth.Test.Messaging.Bindings { [TestFixture] public class StandardExpirationBindingElementTests : MessagingTestBase { [Test] - public void SendSetsTimestamp() { + public async Task SendSetsTimestamp() { TestExpiringMessage message = new TestExpiringMessage(MessageTransport.Indirect); message.Recipient = new Uri("http://localtest"); ((IExpiringProtocolMessage)message).UtcCreationDate = DateTime.Parse("1/1/1990"); Channel channel = CreateChannel(MessageProtections.Expiration); - channel.PrepareResponse(message); + await channel.PrepareResponseAsync(message); Assert.IsTrue(DateTime.UtcNow - ((IExpiringProtocolMessage)message).UtcCreationDate < TimeSpan.FromSeconds(3), "The timestamp on the message was not set on send."); } diff --git a/src/DotNetOpenAuth.Test/Messaging/Bindings/StandardReplayProtectionBindingElementTests.cs b/src/DotNetOpenAuth.Test/Messaging/Bindings/StandardReplayProtectionBindingElementTests.cs index 9a46e42..04c63ef 100644 --- a/src/DotNetOpenAuth.Test/Messaging/Bindings/StandardReplayProtectionBindingElementTests.cs +++ b/src/DotNetOpenAuth.Test/Messaging/Bindings/StandardReplayProtectionBindingElementTests.cs @@ -9,6 +9,8 @@ namespace DotNetOpenAuth.Test.Messaging.Bindings { using System.Collections.Generic; using System.Linq; using System.Text; + using System.Threading; + using System.Threading.Tasks; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Bindings; using DotNetOpenAuth.OpenId; @@ -40,14 +42,14 @@ namespace DotNetOpenAuth.Test.Messaging.Bindings { /// Verifies that the generated nonce includes random characters. /// </summary> [Test] - public void RandomCharactersTest() { - Assert.IsNotNull(this.nonceElement.ProcessOutgoingMessage(this.message)); + public async Task RandomCharactersTest() { + Assert.IsNotNull(await this.nonceElement.ProcessOutgoingMessageAsync(this.message, CancellationToken.None)); Assert.IsNotNull(this.message.Nonce, "No nonce was set on the message."); Assert.AreNotEqual(0, this.message.Nonce.Length, "The generated nonce was empty."); string firstNonce = this.message.Nonce; // Apply another nonce and verify that they are different than the first ones. - Assert.IsNotNull(this.nonceElement.ProcessOutgoingMessage(this.message)); + Assert.IsNotNull(await this.nonceElement.ProcessOutgoingMessageAsync(this.message, CancellationToken.None)); Assert.IsNotNull(this.message.Nonce, "No nonce was set on the message."); Assert.AreNotEqual(0, this.message.Nonce.Length, "The generated nonce was empty."); Assert.AreNotEqual(firstNonce, this.message.Nonce, "The two generated nonces are identical."); @@ -57,41 +59,41 @@ namespace DotNetOpenAuth.Test.Messaging.Bindings { /// Verifies that a message is received correctly. /// </summary> [Test] - public void ValidMessageReceivedTest() { + public async Task ValidMessageReceivedTest() { this.message.Nonce = "a"; - Assert.IsNotNull(this.nonceElement.ProcessIncomingMessage(this.message)); + Assert.IsNotNull(await this.nonceElement.ProcessIncomingMessageAsync(this.message, CancellationToken.None)); } /// <summary> /// Verifies that a message that doesn't have a string of random characters is received correctly. /// </summary> [Test] - public void ValidMessageNoNonceReceivedTest() { + public async Task ValidMessageNoNonceReceivedTest() { this.message.Nonce = string.Empty; this.nonceElement.AllowZeroLengthNonce = true; - Assert.IsNotNull(this.nonceElement.ProcessIncomingMessage(this.message)); + Assert.IsNotNull(await this.nonceElement.ProcessIncomingMessageAsync(this.message, CancellationToken.None)); } /// <summary> /// Verifies that a message that doesn't have a string of random characters is received correctly. /// </summary> [Test, ExpectedException(typeof(ProtocolException))] - public void InvalidMessageNoNonceReceivedTest() { + public async Task InvalidMessageNoNonceReceivedTest() { this.message.Nonce = string.Empty; this.nonceElement.AllowZeroLengthNonce = false; - Assert.IsNotNull(this.nonceElement.ProcessIncomingMessage(this.message)); + Assert.IsNotNull(await this.nonceElement.ProcessIncomingMessageAsync(this.message, CancellationToken.None)); } /// <summary> /// Verifies that a replayed message is rejected. /// </summary> [Test, ExpectedException(typeof(ReplayedMessageException))] - public void ReplayDetectionTest() { + public async Task ReplayDetectionTest() { this.message.Nonce = "a"; - Assert.IsNotNull(this.nonceElement.ProcessIncomingMessage(this.message)); + Assert.IsNotNull(await this.nonceElement.ProcessIncomingMessageAsync(this.message, CancellationToken.None)); // Now receive the same message again. This should throw because it's a message replay. - this.nonceElement.ProcessIncomingMessage(this.message); + await this.nonceElement.ProcessIncomingMessageAsync(this.message, CancellationToken.None); } } } diff --git a/src/DotNetOpenAuth.Test/Messaging/ChannelTests.cs b/src/DotNetOpenAuth.Test/Messaging/ChannelTests.cs index 5646a7e..4540373 100644 --- a/src/DotNetOpenAuth.Test/Messaging/ChannelTests.cs +++ b/src/DotNetOpenAuth.Test/Messaging/ChannelTests.cs @@ -9,6 +9,8 @@ namespace DotNetOpenAuth.Test.Messaging { using System.Collections.Generic; using System.IO; using System.Net; + using System.Threading; + using System.Threading.Tasks; using System.Web; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Bindings; @@ -39,44 +41,44 @@ namespace DotNetOpenAuth.Test.Messaging { /// will reject messages that come with an unexpected HTTP verb. /// </summary> [Test, ExpectedException(typeof(ProtocolException))] - public void ReadFromRequestDisallowedHttpMethod() { + public async Task ReadFromRequestDisallowedHttpMethod() { var fields = GetStandardTestFields(FieldFill.CompleteBeforeBindings); fields["GetOnly"] = "true"; - this.Channel.ReadFromRequest(CreateHttpRequestInfo("POST", fields)); + await this.Channel.ReadFromRequestAsync(CreateHttpRequestInfo("POST", fields), CancellationToken.None); } [Test, ExpectedException(typeof(ArgumentNullException))] - public void SendNull() { - this.Channel.PrepareResponse(null); + public async Task SendNull() { + await this.Channel.PrepareResponseAsync(null); } [Test, ExpectedException(typeof(ArgumentException))] - public void SendIndirectedUndirectedMessage() { + public async Task SendIndirectedUndirectedMessage() { IProtocolMessage message = new TestDirectedMessage(MessageTransport.Indirect); - this.Channel.PrepareResponse(message); + await this.Channel.PrepareResponseAsync(message); } [Test, ExpectedException(typeof(ArgumentException))] - public void SendDirectedNoRecipientMessage() { + public async Task SendDirectedNoRecipientMessage() { IProtocolMessage message = new TestDirectedMessage(MessageTransport.Indirect); - this.Channel.PrepareResponse(message); + await this.Channel.PrepareResponseAsync(message); } [Test, ExpectedException(typeof(ArgumentException))] - public void SendInvalidMessageTransport() { + public async Task SendInvalidMessageTransport() { IProtocolMessage message = new TestDirectedMessage((MessageTransport)100); - this.Channel.PrepareResponse(message); + await this.Channel.PrepareResponseAsync(message); } [Test] - public void SendIndirectMessage301Get() { + public async Task SendIndirectMessage301Get() { TestDirectedMessage message = new TestDirectedMessage(MessageTransport.Indirect); GetStandardTestMessage(FieldFill.CompleteBeforeBindings, message); message.Recipient = new Uri("http://provider/path"); var expected = GetStandardTestFields(FieldFill.CompleteBeforeBindings); - OutgoingWebResponse response = this.Channel.PrepareResponse(message); - Assert.AreEqual(HttpStatusCode.Redirect, response.Status); + var response = await this.Channel.PrepareResponseAsync(message); + Assert.AreEqual(HttpStatusCode.Redirect, response.StatusCode); Assert.AreEqual("text/html; charset=utf-8", response.Headers[HttpResponseHeader.ContentType]); Assert.IsTrue(response.Body != null && response.Body.Length > 0); // a non-empty body helps get passed filters like WebSense StringAssert.StartsWith("http://provider/path", response.Headers[HttpResponseHeader.Location]); @@ -110,7 +112,7 @@ namespace DotNetOpenAuth.Test.Messaging { } [Test] - public void SendIndirectMessageFormPost() { + public async Task SendIndirectMessageFormPost() { // We craft a very large message to force fallback to form POST. // We'll also stick some HTML reserved characters in the string value // to test proper character escaping. @@ -120,8 +122,8 @@ namespace DotNetOpenAuth.Test.Messaging { Location = new Uri("http://host/path"), Recipient = new Uri("http://provider/path"), }; - OutgoingWebResponse response = this.Channel.PrepareResponse(message); - Assert.AreEqual(HttpStatusCode.OK, response.Status, "A form redirect should be an HTTP successful response."); + var response = await this.Channel.PrepareResponseAsync(message); + Assert.AreEqual(HttpStatusCode.OK, response.StatusCode, "A form redirect should be an HTTP successful response."); Assert.IsNull(response.Headers[HttpResponseHeader.Location], "There should not be a redirection header in the response."); string body = response.Body; StringAssert.Contains("<form ", body); @@ -162,13 +164,13 @@ namespace DotNetOpenAuth.Test.Messaging { /// we just check that the right method was called. /// </remarks> [Test, ExpectedException(typeof(NotImplementedException))] - public void SendDirectMessageResponse() { + public async Task SendDirectMessageResponse() { IProtocolMessage message = new TestDirectedMessage { Age = 15, Name = "Andrew", Location = new Uri("http://host/path"), }; - this.Channel.PrepareResponse(message); + await this.Channel.PrepareResponseAsync(message); } [Test, ExpectedException(typeof(ArgumentNullException))] @@ -190,12 +192,12 @@ namespace DotNetOpenAuth.Test.Messaging { } [Test] - public void ReadFromRequestWithContext() { + public async Task ReadFromRequestWithContext() { var fields = GetStandardTestFields(FieldFill.AllRequired); TestMessage expectedMessage = GetStandardTestMessage(FieldFill.AllRequired); HttpRequest request = new HttpRequest("somefile", "http://someurl", MessagingUtilities.CreateQueryString(fields)); HttpContext.Current = new HttpContext(request, new HttpResponse(new StringWriter())); - IProtocolMessage message = this.Channel.ReadFromRequest(); + IProtocolMessage message = await this.Channel.ReadFromRequestAsync(CancellationToken.None); Assert.IsNotNull(message); Assert.IsInstanceOf<TestMessage>(message); Assert.AreEqual(expectedMessage.Age, ((TestMessage)message).Age); @@ -215,12 +217,12 @@ namespace DotNetOpenAuth.Test.Messaging { } [Test] - public void SendReplayProtectedMessageSetsNonce() { + public async Task SendReplayProtectedMessageSetsNonce() { TestReplayProtectedMessage message = new TestReplayProtectedMessage(MessageTransport.Indirect); message.Recipient = new Uri("http://localtest"); this.Channel = CreateChannel(MessageProtections.ReplayProtection); - this.Channel.PrepareResponse(message); + await this.Channel.PrepareResponseAsync(message); Assert.IsNotNull(((IReplayProtectedProtocolMessage)message).Nonce); } @@ -251,12 +253,12 @@ namespace DotNetOpenAuth.Test.Messaging { } [Test, ExpectedException(typeof(ProtocolException))] - public void TooManyBindingElementsProvidingSameProtection() { + public async Task TooManyBindingElementsProvidingSameProtection() { Channel channel = new TestChannel( new TestMessageFactory(), new MockSigningBindingElement(), new MockSigningBindingElement()); - channel.ProcessOutgoingMessageTestHook(new TestSignedDirectedMessage()); + await channel.ProcessOutgoingMessageTestHookAsync(new TestSignedDirectedMessage()); } [Test] @@ -284,10 +286,10 @@ namespace DotNetOpenAuth.Test.Messaging { } [Test, ExpectedException(typeof(UnprotectedMessageException))] - public void InsufficientlyProtectedMessageSent() { + public async Task InsufficientlyProtectedMessageSent() { var message = new TestSignedDirectedMessage(MessageTransport.Direct); message.Recipient = new Uri("http://localtest"); - this.Channel.PrepareResponse(message); + await this.Channel.PrepareResponseAsync(message); } [Test, ExpectedException(typeof(UnprotectedMessageException))] @@ -297,9 +299,9 @@ namespace DotNetOpenAuth.Test.Messaging { } [Test, ExpectedException(typeof(ProtocolException))] - public void IncomingMessageMissingRequiredParameters() { + public async Task IncomingMessageMissingRequiredParameters() { var fields = GetStandardTestFields(FieldFill.IdentifiableButNotAllRequired); - this.Channel.ReadFromRequest(CreateHttpRequestInfo("GET", fields)); + await this.Channel.ReadFromRequestAsync(CreateHttpRequestInfo("GET", fields), CancellationToken.None); } } } diff --git a/src/DotNetOpenAuth.Test/Messaging/MessagingTestBase.cs b/src/DotNetOpenAuth.Test/Messaging/MessagingTestBase.cs index b7c0980..092b89c 100644 --- a/src/DotNetOpenAuth.Test/Messaging/MessagingTestBase.cs +++ b/src/DotNetOpenAuth.Test/Messaging/MessagingTestBase.cs @@ -10,6 +10,8 @@ namespace DotNetOpenAuth.Test { using System.Collections.Specialized; using System.IO; using System.Net; + using System.Threading; + using System.Threading.Tasks; using System.Xml; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Bindings; @@ -143,11 +145,11 @@ namespace DotNetOpenAuth.Test { } } - internal void ParameterizedReceiveTest(string method) { + internal async Task ParameterizedReceiveTestAsync(string method) { var fields = GetStandardTestFields(FieldFill.CompleteBeforeBindings); TestMessage expectedMessage = GetStandardTestMessage(FieldFill.CompleteBeforeBindings); - IDirectedProtocolMessage requestMessage = this.Channel.ReadFromRequest(CreateHttpRequestInfo(method, fields)); + IDirectedProtocolMessage requestMessage = await this.Channel.ReadFromRequestAsync(CreateHttpRequestInfo(method, fields), CancellationToken.None); Assert.IsNotNull(requestMessage); Assert.IsInstanceOf<TestMessage>(requestMessage); TestMessage actualMessage = (TestMessage)requestMessage; @@ -156,7 +158,7 @@ namespace DotNetOpenAuth.Test { Assert.AreEqual(expectedMessage.Location, actualMessage.Location); } - internal void ParameterizedReceiveProtectedTest(DateTime? utcCreatedDate, bool invalidSignature) { + internal async Task ParameterizedReceiveProtectedTestAsync(DateTime? utcCreatedDate, bool invalidSignature) { TestMessage expectedMessage = GetStandardTestMessage(FieldFill.CompleteBeforeBindings); var fields = GetStandardTestFields(FieldFill.CompleteBeforeBindings); fields.Add("Signature", invalidSignature ? "badsig" : MockSigningBindingElement.MessageSignature); @@ -165,7 +167,7 @@ namespace DotNetOpenAuth.Test { utcCreatedDate = DateTime.Parse(utcCreatedDate.Value.ToUniversalTime().ToString()); // round off the milliseconds so comparisons work later fields.Add("created_on", XmlConvert.ToString(utcCreatedDate.Value, XmlDateTimeSerializationMode.Utc)); } - IProtocolMessage requestMessage = this.Channel.ReadFromRequest(CreateHttpRequestInfo("GET", fields)); + IProtocolMessage requestMessage = await this.Channel.ReadFromRequestAsync(CreateHttpRequestInfo("GET", fields), CancellationToken.None); Assert.IsNotNull(requestMessage); Assert.IsInstanceOf<TestSignedDirectedMessage>(requestMessage); TestSignedDirectedMessage actualMessage = (TestSignedDirectedMessage)requestMessage; diff --git a/src/DotNetOpenAuth.Test/OpenId/AssociationHandshakeTests.cs b/src/DotNetOpenAuth.Test/OpenId/AssociationHandshakeTests.cs index 029447d..9e1bbff 100644 --- a/src/DotNetOpenAuth.Test/OpenId/AssociationHandshakeTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/AssociationHandshakeTests.cs @@ -6,6 +6,8 @@ namespace DotNetOpenAuth.Test.OpenId { using System; + using System.Threading; + using System.Threading.Tasks; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId; using DotNetOpenAuth.OpenId.Messages; @@ -42,14 +44,14 @@ namespace DotNetOpenAuth.Test.OpenId { public void AssociateDiffieHellmanOverHttps() { Protocol protocol = Protocol.V20; OpenIdCoordinator coordinator = new OpenIdCoordinator( - rp => { + (rp, ct) => { // We have to formulate the associate request manually, // since the DNOI RP won't voluntarily use DH on HTTPS. AssociateDiffieHellmanRequest request = new AssociateDiffieHellmanRequest(protocol.Version, new Uri("https://Provider")); request.AssociationType = protocol.Args.SignatureAlgorithm.HMAC_SHA256; request.SessionType = protocol.Args.SessionType.DH_SHA256; request.InitializeRequest(); - var response = rp.Channel.Request<AssociateSuccessfulResponse>(request); + var response = await rp.Channel.RequestAsync<AssociateSuccessfulResponse>(request, CancellationToken.None); Assert.IsNotNull(response); Assert.AreEqual(request.AssociationType, response.AssociationType); Assert.AreEqual(request.SessionType, response.SessionType); diff --git a/src/DotNetOpenAuth.Test/OpenId/OpenIdCoordinator.cs b/src/DotNetOpenAuth.Test/OpenId/OpenIdCoordinator.cs index 5000833..aac69bd 100644 --- a/src/DotNetOpenAuth.Test/OpenId/OpenIdCoordinator.cs +++ b/src/DotNetOpenAuth.Test/OpenId/OpenIdCoordinator.cs @@ -6,6 +6,8 @@ namespace DotNetOpenAuth.Test.OpenId { using System; + using System.Threading; + using System.Threading.Tasks; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Bindings; using DotNetOpenAuth.OpenId; @@ -15,7 +17,7 @@ namespace DotNetOpenAuth.Test.OpenId { using Validation; internal class OpenIdCoordinator : CoordinatorBase<OpenIdRelyingParty, OpenIdProvider> { - internal OpenIdCoordinator(Action<OpenIdRelyingParty> rpAction, Action<OpenIdProvider> opAction) + internal OpenIdCoordinator(Func<OpenIdRelyingParty, CancellationToken, Task> rpAction, Func<OpenIdProvider, CancellationToken, Task> opAction) : base(WrapAction(rpAction), WrapAction(opAction)) { } @@ -23,7 +25,7 @@ namespace DotNetOpenAuth.Test.OpenId { internal OpenIdRelyingParty RelyingParty { get; set; } - internal override void Run() { + internal override Task RunAsync() { this.EnsurePartiesAreInitialized(); var rpCoordinatingChannel = new CoordinatingChannel(this.RelyingParty.Channel, this.IncomingMessageFilter, this.OutgoingMessageFilter); var opCoordinatingChannel = new CoordinatingChannel(this.Provider.Channel, this.IncomingMessageFilter, this.OutgoingMessageFilter); @@ -33,23 +35,23 @@ namespace DotNetOpenAuth.Test.OpenId { this.RelyingParty.Channel = rpCoordinatingChannel; this.Provider.Channel = opCoordinatingChannel; - RunCore(this.RelyingParty, this.Provider); + return this.RunCoreAsync(this.RelyingParty, this.Provider); } - private static Action<OpenIdRelyingParty> WrapAction(Action<OpenIdRelyingParty> action) { + private static Func<OpenIdRelyingParty, Task> WrapAction(Func<OpenIdRelyingParty, Task> action) { Requires.NotNull(action, "action"); - return rp => { - action(rp); + return async rp => { + await action(rp); ((CoordinatingChannel)rp.Channel).Close(); }; } - private static Action<OpenIdProvider> WrapAction(Action<OpenIdProvider> action) { + private static Func<OpenIdProvider, Task> WrapAction(Func<OpenIdProvider, Task> action) { Requires.NotNull(action, "action"); - return op => { - action(op); + return async op => { + await action(op); ((CoordinatingChannel)op.Channel).Close(); }; } diff --git a/src/DotNetOpenAuth.Test/OpenId/OpenIdTestBase.cs b/src/DotNetOpenAuth.Test/OpenId/OpenIdTestBase.cs index 3a27e96..7e58047 100644 --- a/src/DotNetOpenAuth.Test/OpenId/OpenIdTestBase.cs +++ b/src/DotNetOpenAuth.Test/OpenId/OpenIdTestBase.cs @@ -9,6 +9,9 @@ namespace DotNetOpenAuth.Test.OpenId { using System.Collections.Generic; using System.IO; using System.Reflection; + using System.Threading; + using System.Threading.Tasks; + using DotNetOpenAuth.Configuration; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Bindings; @@ -142,9 +145,9 @@ namespace DotNetOpenAuth.Test.OpenId { /// <remarks> /// This is a very useful method to pass to the OpenIdCoordinator constructor for the Provider argument. /// </remarks> - internal void AutoProvider(OpenIdProvider provider) { + internal async Task AutoProvider(OpenIdProvider provider, CancellationToken cancellationToken) { while (!((CoordinatingChannel)provider.Channel).RemoteChannel.IsDisposed) { - IRequest request = provider.GetRequest(); + IRequest request = await provider.GetRequestAsync(cancellationToken); if (request == null) { continue; } @@ -171,7 +174,7 @@ namespace DotNetOpenAuth.Test.OpenId { } } - provider.Respond(request); + await provider.Channel.PrepareResponseAsync(request, cancellationToken); } } |