summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--omaha_server/omaha/api.py10
-rw-r--r--omaha_server/omaha/tests/test_api.py73
-rw-r--r--omaha_server/sparkle/api.py4
-rw-r--r--omaha_server/sparkle/tests/test_api.py20
4 files changed, 99 insertions, 8 deletions
diff --git a/omaha_server/omaha/api.py b/omaha_server/omaha/api.py
index 4fda594..92d7a79 100644
--- a/omaha_server/omaha/api.py
+++ b/omaha_server/omaha/api.py
@@ -74,7 +74,7 @@ class StandardResultsSetPagination(pagination.PageNumberPagination):
max_page_size = 100
-class AppViewSet(BaseView):
+class AppViewSet(viewsets.ModelViewSet):
"""
API endpoint that allows applications to be viewed.
@@ -152,12 +152,12 @@ class AppViewSet(BaseView):
serializer_class = AppSerializer
-class DataViewSet(BaseView):
+class DataViewSet(viewsets.ModelViewSet):
queryset = Data.objects.all().order_by('-id')
serializer_class = DataSerializer
-class PlatformViewSet(BaseView):
+class PlatformViewSet(viewsets.ModelViewSet):
queryset = Platform.objects.all().order_by('-id')
serializer_class = PlatformSerializer
@@ -167,12 +167,12 @@ class ChannelViewSet(viewsets.ModelViewSet):
serializer_class = ChannelSerializer
-class VersionViewSet(BaseView, mixins.UpdateModelMixin):
+class VersionViewSet(viewsets.ModelViewSet):
queryset = Version.objects.all().order_by('-id')
serializer_class = VersionSerializer
-class ActionViewSet(BaseView):
+class ActionViewSet(viewsets.ModelViewSet):
queryset = Action.objects.all().order_by('-id')
serializer_class = ActionSerializer
diff --git a/omaha_server/omaha/tests/test_api.py b/omaha_server/omaha/tests/test_api.py
index 09f0cb0..35a33b3 100644
--- a/omaha_server/omaha/tests/test_api.py
+++ b/omaha_server/omaha/tests/test_api.py
@@ -122,6 +122,20 @@ class AppTest(BaseTest, APITestCase):
obj = Application.objects.get(id=response.data['id'])
self.assertEqual(response.data, self.serializer(obj).data)
+ @is_private()
+ def test_update(self):
+ data = dict(id='test_id', name='test_name', data_set=[])
+ response = self.client.post(reverse(self.url), data, format='json')
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ obj_id = response.data['id']
+ obj = Application.objects.get(id=obj_id)
+ self.assertEqual(obj.name, 'test_name')
+ url = reverse(self.url_detail, kwargs=dict(pk=obj_id))
+ response = self.client.patch(url, dict(name='test_other_name'))
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ obj = Application.objects.get(id=obj_id)
+ self.assertEqual(obj.name, 'test_other_name')
+
class DataTest(BaseTest, APITestCase):
url = 'data-list'
@@ -138,6 +152,21 @@ class DataTest(BaseTest, APITestCase):
obj = Data.objects.get(id=response.data['id'])
self.assertEqual(response.data, self.serializer(obj).data)
+ @is_private()
+ def test_update(self):
+ app = ApplicationFactory.create()
+ data = dict(name=0, app=app.pk)
+ response = self.client.post(reverse(self.url), data, format='json')
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ obj_id = response.data['id']
+ obj = Data.objects.get(id=obj_id)
+ self.assertEqual(obj.name, 0)
+ url = reverse(self.url_detail, kwargs=dict(pk=obj_id))
+ response = self.client.patch(url, dict(name=1))
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ obj = Data.objects.get(id=obj_id)
+ self.assertEqual(obj.name, 1)
+
class PlatformTest(BaseTest, APITestCase):
url = 'platform-list'
@@ -157,6 +186,20 @@ class PlatformTest(BaseTest, APITestCase):
obj = Platform.objects.get(id=response.data['id'])
self.assertEqual(response.data, self.serializer(obj).data)
+ @is_private()
+ def test_create(self):
+ data = dict(name='test_name')
+ response = self.client.post(reverse(self.url), data, format='json')
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ obj_id = response.data['id']
+ obj = Platform.objects.get(id=obj_id)
+ self.assertEqual(obj.name, 'test_name')
+ url = reverse(self.url_detail, kwargs=dict(pk=obj_id))
+ response = self.client.patch(url, dict(name='test_name2'))
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ obj = Platform.objects.get(id=obj_id)
+ self.assertEqual(obj.name, 'test_name2')
+
class ChannelTest(BaseTest, APITestCase):
url = 'channel-list'
@@ -172,6 +215,19 @@ class ChannelTest(BaseTest, APITestCase):
obj = Channel.objects.get(id=response.data['id'])
self.assertEqual(response.data, self.serializer(obj).data)
+ @is_private()
+ def test_update(self):
+ data = dict(name='test_name')
+ response = self.client.post(reverse(self.url), data, format='json')
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ obj_id = response.data['id']
+ obj = Channel.objects.get(id=obj_id)
+ self.assertEqual(response.data, self.serializer(obj).data)
+ url = reverse(self.url_detail, kwargs=dict(pk=obj_id))
+ response = self.client.patch(url, dict(name='test_name2'))
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ obj = Channel.objects.get(id=obj_id)
+ self.assertEqual(obj.name, 'test_name2')
class VersionTest(BaseTest, APITestCase):
url = 'version-list'
@@ -221,7 +277,6 @@ class VersionTest(BaseTest, APITestCase):
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
obj_id = response.data['id']
version = Version.objects.get(id=obj_id)
- self.assertEqual(response.data, self.serializer(version).data)
self.assertEqual(version.file_size, len(b'content'))
self.assertFalse(version.is_enabled)
url = reverse(self.url_detail, kwargs=dict(pk=obj_id))
@@ -247,6 +302,22 @@ class ActionTest(BaseTest, APITestCase):
obj = Action.objects.get(id=response.data['id'])
self.assertEqual(response.data, self.serializer(obj).data)
+ @is_private()
+ def test_update(self):
+ version = VersionFactory.create()
+ data = dict(event=1, version=version.pk)
+ response = self.client.post(reverse(self.url), data, format='json')
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ obj_id = response.data['id']
+ obj = Action.objects.get(id=obj_id)
+ self.assertEqual(response.data, self.serializer(obj).data)
+ url = reverse(self.url_detail, kwargs=dict(pk=obj_id))
+ response = self.client.patch(url, dict(event=2))
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ obj = Action.objects.get(id=obj_id)
+ self.assertEqual(obj.event, 2)
+
+
class LiveStatistics(APITestCase):
maxDiff = None
diff --git a/omaha_server/sparkle/api.py b/omaha_server/sparkle/api.py
index 594ad1f..005db7f 100644
--- a/omaha_server/sparkle/api.py
+++ b/omaha_server/sparkle/api.py
@@ -18,12 +18,12 @@ License for the specific language governing permissions and limitations under
the License.
"""
-from omaha.api import BaseView
+from rest_framework import viewsets
from sparkle.serializers import SparkleVersionSerializer
from sparkle.models import SparkleVersion
-class SparkleVersionViewSet(BaseView):
+class SparkleVersionViewSet(viewsets.ModelViewSet):
queryset = SparkleVersion.objects.all().order_by('-id')
serializer_class = SparkleVersionSerializer
diff --git a/omaha_server/sparkle/tests/test_api.py b/omaha_server/sparkle/tests/test_api.py
index d70f981..d05247f 100644
--- a/omaha_server/sparkle/tests/test_api.py
+++ b/omaha_server/sparkle/tests/test_api.py
@@ -66,3 +66,23 @@ class VersionTest(BaseTest, APITestCase):
self.assertEqual(response.data, self.serializer(version).data)
self.assertEqual(version.file_size, len(b'content'))
self.assertTrue(version.is_enabled)
+
+ @is_private()
+ @temporary_media_root(MEDIA_URL='http://cache.pack.google.com/edgedl/chrome/install/782.112/')
+ def test_update(self):
+ data = dict(
+ app=ApplicationFactory.create().id,
+ channel=ChannelFactory.create().id,
+ version='1.2.3.4',
+ file=SimpleUploadedFile("chrome.exe", b'content'),
+ )
+ response = self.client.post(reverse(self.url), data)
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ obj_id = response.data['id']
+ version = SparkleVersion.objects.get(id=obj_id)
+ self.assertEqual(version.version, '1.2.3.4')
+ url = reverse(self.url_detail, kwargs=dict(pk=obj_id))
+ response = self.client.patch(url, dict(version='1.2.3.5'))
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ version = SparkleVersion.objects.get(id=obj_id)
+ self.assertEqual(version.version, '1.2.3.5')