//----------------------------------------------------------------------- // // Copyright (c) Andrew Arnott. All rights reserved. // //----------------------------------------------------------------------- namespace DotNetOpenAuth.Test.OAuth.ChannelElements { using System; using System.Collections.Generic; using System.Collections.Specialized; using System.Diagnostics.Contracts; using System.IO; using System.Net; using System.Text; using System.Web; using System.Xml; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Bindings; using DotNetOpenAuth.Messaging.Reflection; using DotNetOpenAuth.OAuth.ChannelElements; using DotNetOpenAuth.Test.Mocks; using NUnit.Framework; [TestFixture] public class OAuthChannelTests : TestBase { private OAuthChannel channel; private TestWebRequestHandler webRequestHandler; private SigningBindingElementBase signingElement; private INonceStore nonceStore; [SetUp] public override void SetUp() { base.SetUp(); this.webRequestHandler = new TestWebRequestHandler(); this.signingElement = new RsaSha1SigningBindingElement(new InMemoryTokenManager()); this.nonceStore = new NonceMemoryStore(StandardExpirationBindingElement.MaximumMessageAge); this.channel = new OAuthChannel(this.signingElement, this.nonceStore, new InMemoryTokenManager(), new TestMessageFactory()); this.channel.WebRequestHandler = this.webRequestHandler; } [TestCase, ExpectedException(typeof(ArgumentNullException))] public void CtorNullSigner() { new OAuthChannel(null, this.nonceStore, new InMemoryTokenManager(), new TestMessageFactory()); } [TestCase, ExpectedException(typeof(ArgumentNullException))] public void CtorNullStore() { new OAuthChannel(new RsaSha1SigningBindingElement(new InMemoryTokenManager()), null, new InMemoryTokenManager(), new TestMessageFactory()); } [TestCase, ExpectedException(typeof(ArgumentNullException))] public void CtorNullTokenManager() { new OAuthChannel(new RsaSha1SigningBindingElement(new InMemoryTokenManager()), this.nonceStore, null, new TestMessageFactory()); } [TestCase] public void CtorSimpleConsumer() { new OAuthChannel(new RsaSha1SigningBindingElement(new InMemoryTokenManager()), this.nonceStore, (IConsumerTokenManager)new InMemoryTokenManager()); } [TestCase] public void CtorSimpleServiceProvider() { new OAuthChannel(new RsaSha1SigningBindingElement(new InMemoryTokenManager()), this.nonceStore, (IServiceProviderTokenManager)new InMemoryTokenManager()); } [TestCase] public void ReadFromRequestAuthorization() { this.ParameterizedReceiveTest(HttpDeliveryMethods.AuthorizationHeaderRequest); } /// /// Verifies that the OAuth ReadFromRequest method gathers parameters /// from the Authorization header, the query string and the entity form data. /// [TestCase] public void ReadFromRequestAuthorizationScattered() { // Start by creating a standard POST HTTP request. var fields = new Dictionary { { "age", "15" }, }; HttpRequestInfo requestInfo = CreateHttpRequestInfo(HttpDeliveryMethods.PostRequest, fields); // Now add another field to the request URL UriBuilder builder = new UriBuilder(requestInfo.UrlBeforeRewriting); builder.Query = "Name=Andrew"; requestInfo.UrlBeforeRewriting = builder.Uri; requestInfo.RawUrl = builder.Path + builder.Query + builder.Fragment; // Finally, add an Authorization header fields = new Dictionary { { "Location", "http://hostb/pathB" }, { "Timestamp", XmlConvert.ToString(DateTime.UtcNow, XmlDateTimeSerializationMode.Utc) }, }; requestInfo.Headers.Add(HttpRequestHeader.Authorization, CreateAuthorizationHeader(fields)); IDirectedProtocolMessage requestMessage = this.channel.ReadFromRequest(requestInfo); Assert.IsNotNull(requestMessage); Assert.IsInstanceOf(requestMessage); TestMessage testMessage = (TestMessage)requestMessage; Assert.AreEqual(15, testMessage.Age); Assert.AreEqual("Andrew", testMessage.Name); Assert.AreEqual("http://hostb/pathB", testMessage.Location.AbsoluteUri); } [TestCase] public void ReadFromRequestForm() { this.ParameterizedReceiveTest(HttpDeliveryMethods.PostRequest); } [TestCase] public void ReadFromRequestQueryString() { this.ParameterizedReceiveTest(HttpDeliveryMethods.GetRequest); } [TestCase] public void SendDirectMessageResponse() { IProtocolMessage message = new TestDirectedMessage { Age = 15, Name = "Andrew", Location = new Uri("http://hostb/pathB"), }; OutgoingWebResponse response = this.channel.PrepareResponse(message); Assert.AreSame(message, response.OriginalMessage); Assert.AreEqual(HttpStatusCode.OK, response.Status); Assert.AreEqual(0, response.Headers.Count); NameValueCollection body = HttpUtility.ParseQueryString(response.Body); Assert.AreEqual("15", body["age"]); Assert.AreEqual("Andrew", body["Name"]); Assert.AreEqual("http://hostb/pathB", body["Location"]); } [TestCase] public void ReadFromResponse() { var fields = new Dictionary { { "age", "15" }, { "Name", "Andrew" }, { "Location", "http://hostb/pathB" }, { "Timestamp", XmlConvert.ToString(DateTime.UtcNow, XmlDateTimeSerializationMode.Utc) }, }; MemoryStream ms = new MemoryStream(); StreamWriter writer = new StreamWriter(ms); writer.Write(MessagingUtilities.CreateQueryString(fields)); writer.Flush(); ms.Seek(0, SeekOrigin.Begin); IDictionary deserializedFields = this.channel.ReadFromResponseCoreTestHook(new CachedDirectWebResponse { CachedResponseStream = ms }); Assert.AreEqual(fields.Count, deserializedFields.Count); foreach (string key in fields.Keys) { Assert.AreEqual(fields[key], deserializedFields[key]); } } [TestCase, ExpectedException(typeof(ArgumentNullException))] public void RequestNull() { this.channel.Request(null); } [TestCase, ExpectedException(typeof(ArgumentException))] public void RequestNullRecipient() { IDirectedProtocolMessage message = new TestDirectedMessage(MessageTransport.Direct); this.channel.Request(message); } [TestCase, ExpectedException(typeof(NotSupportedException))] public void RequestBadPreferredScheme() { TestDirectedMessage message = new TestDirectedMessage(MessageTransport.Direct); message.Recipient = new Uri("http://localtest"); message.HttpMethods = HttpDeliveryMethods.None; this.channel.Request(message); } [TestCase] public void RequestUsingAuthorizationHeader() { this.ParameterizedRequestTest(HttpDeliveryMethods.AuthorizationHeaderRequest); } /// /// Verifies that message parts can be distributed to the query, form, and Authorization header. /// [TestCase] public void RequestUsingAuthorizationHeaderScattered() { TestDirectedMessage request = new TestDirectedMessage(MessageTransport.Direct) { Age = 15, Name = "Andrew", Location = new Uri("http://hostb/pathB"), Recipient = new Uri("http://localtest"), Timestamp = DateTime.UtcNow, HttpMethods = HttpDeliveryMethods.AuthorizationHeaderRequest, }; // ExtraData should appear in the form since this is a POST request, // and only official message parts get a place in the Authorization header. ((IProtocolMessage)request).ExtraData["appearinform"] = "formish"; request.Recipient = new Uri("http://localhost/?appearinquery=queryish"); request.HttpMethods = HttpDeliveryMethods.AuthorizationHeaderRequest | HttpDeliveryMethods.PostRequest; HttpWebRequest webRequest = this.channel.InitializeRequest(request); Assert.IsNotNull(webRequest); Assert.AreEqual("POST", webRequest.Method); Assert.AreEqual(request.Recipient, webRequest.RequestUri); var declaredParts = new Dictionary { { "age", request.Age.ToString() }, { "Name", request.Name }, { "Location", request.Location.AbsoluteUri }, { "Timestamp", XmlConvert.ToString(request.Timestamp, XmlDateTimeSerializationMode.Utc) }, }; Assert.AreEqual(CreateAuthorizationHeader(declaredParts), webRequest.Headers[HttpRequestHeader.Authorization]); Assert.AreEqual("appearinform=formish", this.webRequestHandler.RequestEntityAsString); } [TestCase] public void RequestUsingGet() { this.ParameterizedRequestTest(HttpDeliveryMethods.GetRequest); } [TestCase] public void RequestUsingPost() { this.ParameterizedRequestTest(HttpDeliveryMethods.PostRequest); } [TestCase] public void RequestUsingHead() { this.ParameterizedRequestTest(HttpDeliveryMethods.HeadRequest); } /// /// Verifies that messages asking for special HTTP status codes get them. /// [TestCase] public void SendDirectMessageResponseHonorsHttpStatusCodes() { IProtocolMessage message = MessagingTestBase.GetStandardTestMessage(MessagingTestBase.FieldFill.AllRequired); OutgoingWebResponse directResponse = this.channel.PrepareDirectResponseTestHook(message); Assert.AreEqual(HttpStatusCode.OK, directResponse.Status); var httpMessage = new TestDirectResponseMessageWithHttpStatus(); MessagingTestBase.GetStandardTestMessage(MessagingTestBase.FieldFill.AllRequired, httpMessage); httpMessage.HttpStatusCode = HttpStatusCode.NotAcceptable; directResponse = this.channel.PrepareDirectResponseTestHook(httpMessage); Assert.AreEqual(HttpStatusCode.NotAcceptable, directResponse.Status); } private static string CreateAuthorizationHeader(IDictionary fields) { Contract.Requires(fields != null); StringBuilder authorization = new StringBuilder(); authorization.Append("OAuth "); foreach (var pair in fields) { string key = Uri.EscapeDataString(pair.Key); string value = Uri.EscapeDataString(pair.Value); authorization.Append(key); authorization.Append("=\""); authorization.Append(value); authorization.Append("\","); } authorization.Length--; // remove trailing comma return authorization.ToString(); } private static HttpRequestInfo CreateHttpRequestInfo(HttpDeliveryMethods scheme, IDictionary fields) { string query = MessagingUtilities.CreateQueryString(fields); UriBuilder requestUri = new UriBuilder("http://localhost/path"); WebHeaderCollection headers = new WebHeaderCollection(); MemoryStream ms = new MemoryStream(); string method; switch (scheme) { case HttpDeliveryMethods.PostRequest: method = "POST"; headers.Add(HttpRequestHeader.ContentType, "application/x-www-form-urlencoded"); StreamWriter sw = new StreamWriter(ms); sw.Write(query); sw.Flush(); ms.Position = 0; break; case HttpDeliveryMethods.GetRequest: method = "GET"; requestUri.Query = query; break; case HttpDeliveryMethods.AuthorizationHeaderRequest: method = "GET"; headers.Add(HttpRequestHeader.Authorization, CreateAuthorizationHeader(fields)); break; default: throw new ArgumentOutOfRangeException("scheme", scheme, "Unexpected value"); } HttpRequestInfo request = new HttpRequestInfo { HttpMethod = method, UrlBeforeRewriting = requestUri.Uri, RawUrl = requestUri.Path + requestUri.Query + requestUri.Fragment, Headers = headers, InputStream = ms, }; return request; } private static HttpRequestInfo ConvertToRequestInfo(HttpWebRequest request, Stream postEntity) { HttpRequestInfo info = new HttpRequestInfo { HttpMethod = request.Method, UrlBeforeRewriting = request.RequestUri, RawUrl = request.RequestUri.AbsolutePath + request.RequestUri.Query + request.RequestUri.Fragment, Headers = request.Headers, InputStream = postEntity, }; return info; } private void ParameterizedRequestTest(HttpDeliveryMethods scheme) { TestDirectedMessage request = new TestDirectedMessage(MessageTransport.Direct) { Age = 15, Name = "Andrew", Location = new Uri("http://hostb/pathB"), Recipient = new Uri("http://localtest"), Timestamp = DateTime.UtcNow, HttpMethods = scheme, }; CachedDirectWebResponse rawResponse = null; this.webRequestHandler.Callback = (req) => { Assert.IsNotNull(req); HttpRequestInfo reqInfo = ConvertToRequestInfo(req, this.webRequestHandler.RequestEntityStream); Assert.AreEqual(MessagingUtilities.GetHttpVerb(scheme), reqInfo.HttpMethod); var incomingMessage = this.channel.ReadFromRequest(reqInfo) as TestMessage; Assert.IsNotNull(incomingMessage); Assert.AreEqual(request.Age, incomingMessage.Age); Assert.AreEqual(request.Name, incomingMessage.Name); Assert.AreEqual(request.Location, incomingMessage.Location); Assert.AreEqual(request.Timestamp, incomingMessage.Timestamp); var responseFields = new Dictionary { { "age", request.Age.ToString() }, { "Name", request.Name }, { "Location", request.Location.AbsoluteUri }, { "Timestamp", XmlConvert.ToString(request.Timestamp, XmlDateTimeSerializationMode.Utc) }, }; rawResponse = new CachedDirectWebResponse(); rawResponse.SetResponse(MessagingUtilities.CreateQueryString(responseFields)); return rawResponse; }; IProtocolMessage response = this.channel.Request(request); Assert.IsNotNull(response); Assert.IsInstanceOf(response); TestMessage responseMessage = (TestMessage)response; Assert.AreEqual(request.Age, responseMessage.Age); Assert.AreEqual(request.Name, responseMessage.Name); Assert.AreEqual(request.Location, responseMessage.Location); } private void ParameterizedReceiveTest(HttpDeliveryMethods scheme) { var fields = new Dictionary { { "age", "15" }, { "Name", "Andrew" }, { "Location", "http://hostb/pathB" }, { "Timestamp", XmlConvert.ToString(DateTime.UtcNow, XmlDateTimeSerializationMode.Utc) }, { "realm", "someValue" }, }; IProtocolMessage requestMessage = this.channel.ReadFromRequest(CreateHttpRequestInfo(scheme, fields)); Assert.IsNotNull(requestMessage); Assert.IsInstanceOf(requestMessage); TestMessage testMessage = (TestMessage)requestMessage; Assert.AreEqual(15, testMessage.Age); Assert.AreEqual("Andrew", testMessage.Name); Assert.AreEqual("http://hostb/pathB", testMessage.Location.AbsoluteUri); if (scheme == HttpDeliveryMethods.AuthorizationHeaderRequest) { // The realm value should be ignored in the authorization header Assert.IsFalse(((IMessage)testMessage).ExtraData.ContainsKey("realm")); } else { Assert.AreEqual("someValue", ((IMessage)testMessage).ExtraData["realm"]); } } } }