Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion alyx/alyx/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
VERSION = __version__ = '3.4.2'
VERSION = __version__ = '3.5.0'
61 changes: 59 additions & 2 deletions alyx/alyx/test_base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from datetime import date
import json

from django.test import TestCase
from django.test import Client
from django.test import Client, RequestFactory, TestCase, override_settings
from django.contrib.auth.models import AnonymousUser
from django.contrib.auth import get_user_model

from alyx.base import _custom_filter_parser
from alyx.throttling import AdaptiveScopedRateThrottle, IPRateThrottle


class TestDocView(TestCase):
Expand Down Expand Up @@ -77,3 +79,58 @@ def setup_admin_subject_user(obj):
obj.Lab = Lab.objects.create(name='cortexlab')
obj.subject = Subject.objects.create(
nickname='aQt', birth_date=date(2025, 1, 1), lab=obj.lab, actual_severity=2)


class DocsIPThrottle(IPRateThrottle):
    # Concrete throttle used only by these tests: keys requests by client IP
    # under the 'docs' scope, with an explicit rate so no DRF settings lookup
    # is required to resolve it.
    scope = 'docs'
    rate = '10/minute'


class TestThrottling(TestCase):
    """Unit tests for the custom throttle classes in ``alyx.throttling``."""

    def setUp(self):
        self.factory = RequestFactory()

    def test_ip_rate_throttle_uses_ip_cache_key(self):
        # The cache key must be derived from the client IP, not from the user.
        req = self.factory.get('/docs/', REMOTE_ADDR='203.0.113.10')
        req.user = AnonymousUser()
        key = DocsIPThrottle().get_cache_key(req, view=None)
        self.assertEqual(key, 'throttle_docs_203.0.113.10')

    def test_ip_rate_throttle_returns_none_when_rate_is_disabled(self):
        # A throttle whose rate is unset must opt out by returning None.
        req = self.factory.get('/docs/', REMOTE_ADDR='203.0.113.10')
        req.user = AnonymousUser()
        throttle = DocsIPThrottle()
        throttle.rate = None
        self.assertIsNone(throttle.get_cache_key(req, view=None))

    @override_settings(THROTTLE_MODE='user-based')
    def test_adaptive_scoped_rate_throttle_uses_user_ident_when_authenticated(self):
        # In 'user-based' mode the default ScopedRateThrottle behaviour
        # applies: authenticated requests are keyed by the user's pk.
        account = get_user_model().objects.create_user(
            'throttle_user', 'throttle@example.com', 'pass')
        req = self.factory.get('/docs/', REMOTE_ADDR='198.51.100.1')
        req.user = account
        throttle = AdaptiveScopedRateThrottle()
        throttle.scope = 'docs'
        key = throttle.get_cache_key(req, view=None)
        self.assertEqual(key, f'throttle_docs_{account.pk}')

    @override_settings(THROTTLE_MODE='anonymous')
    def test_adaptive_scoped_rate_throttle_uses_ip_even_when_authenticated(self):
        # In 'anonymous' mode the key is always the client IP, even for an
        # authenticated user (shared-login deployments).
        account = get_user_model().objects.create_user(
            'throttle_user2', 'throttle2@example.com', 'pass')
        req = self.factory.get('/docs/', REMOTE_ADDR='198.51.100.20')
        req.user = account
        throttle = AdaptiveScopedRateThrottle()
        throttle.scope = 'docs'
        key = throttle.get_cache_key(req, view=None)
        self.assertEqual(key, 'throttle_docs_198.51.100.20')
46 changes: 46 additions & 0 deletions alyx/alyx/throttling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Throttling classes for Alyx REST API.

For Alyx instances where a single user login is shared across multiple users, we want to throttle based on IP address rather than user ID.
This module defines custom throttling classes that can be used in the Django REST Framework settings to achieve this behavior.
"""
from django.conf import settings
from rest_framework.throttling import SimpleRateThrottle, UserRateThrottle, ScopedRateThrottle


def _get_throttle_mode():
    """Return the ``THROTTLE_MODE`` setting, lower-cased.

    Defaults to ``'user-based'`` when the setting is absent.
    """
    mode = getattr(settings, 'THROTTLE_MODE', 'user-based')
    return mode.lower()


class IPRateThrottle(SimpleRateThrottle):
    """Throttle all requests by client IP, regardless of authentication state."""

    def get_cache_key(self, request, view):
        # No configured rate means the throttle is disabled: returning None
        # tells DRF not to track this request at all.
        if self.rate is None:
            return None
        ident = self.get_ident(request)
        return self.cache_format % {'scope': self.scope, 'ident': ident}


# The throttle mode is resolved once, at import time, to pick the base class
# for the two default throttles below.
if _get_throttle_mode() == 'anonymous':
    _default_base = IPRateThrottle
else:
    _default_base = UserRateThrottle


class BurstRateThrottle(_default_base):
    """Default throttle for the DRF 'burst' scope."""
    scope = 'burst'


class SustainedRateThrottle(_default_base):
    """Default throttle for the DRF 'sustained' scope."""
    scope = 'sustained'


class AdaptiveScopedRateThrottle(ScopedRateThrottle):
    """Scoped throttle whose identity switches with ``THROTTLE_MODE``.

    In ``'anonymous'`` mode every request is keyed by client IP — useful when
    a single login is shared across many users.  Otherwise the standard
    ``ScopedRateThrottle`` behaviour applies.
    """

    def get_cache_key(self, request, view):
        # Reuse the module helper rather than re-parsing settings inline;
        # reading the mode per-request (instead of at import time) means
        # override_settings in tests takes effect.
        if _get_throttle_mode() == 'anonymous':
            return self.cache_format % {
                'scope': self.scope,
                'ident': self.get_ident(request),  # always the client IP
            }
        return super().get_cache_key(request, view)  # default ScopedRateThrottle behavior
3 changes: 3 additions & 0 deletions alyx/alyx/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from drf_spectacular.views import SpectacularRedocView

import alyx
from alyx.throttling import AdaptiveScopedRateThrottle


class IgnoreClientContentNegotiation(BaseContentNegotiation):
Expand All @@ -29,6 +30,8 @@ class SpectacularRedocViewCoreAPIDeprecation(SpectacularRedocView):
are now served via the api/schema
"""
content_negotiation_class = IgnoreClientContentNegotiation
throttle_classes = [AdaptiveScopedRateThrottle]
throttle_scope = 'docs'

def get(self, request, *args, **kwargs):
if request.headers['Accept'].startswith('application/coreapi+json'):
Expand Down
2 changes: 1 addition & 1 deletion alyx/data/management/commands/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def handle(self, *args, **options):
dr.data_url = 'http://ibl.flatironinstitute.org/cortexlab/Subjects/'
dr.save()

qs = DatasetType.objects.filter(filename_pattern__isnull=False)
qs = DatasetType.objects.filter(filename_pattern__isnull=False).values('pk', 'name', 'filename_pattern')
dt = None
for d in FileRecord.objects.all().select_related('dataset'):
try:
Expand Down
71 changes: 70 additions & 1 deletion alyx/data/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def test_model_methods(self):
('some_file.ext', 'some_file')
)

dtypes = DatasetType.objects.all()
dtypes = DatasetType.objects.values('name', 'filename_pattern')
for filename, dataname in filename_typename:
with self.subTest(filename=filename):
self.assertEqual(get_dataset_type(filename, dtypes).name, dataname)
Expand Down Expand Up @@ -306,3 +306,72 @@ def test_get_name_collection_revision(self):
self.assertIsInstance(resp, Response)
self.assertEqual(resp.status_code, 400)
self.assertIn('Invalid ALF path', resp.data['detail'])

def test_get_aggregate_collection_revision(self):
    """Check parsing of aggregate dataset paths into their component parts."""
    # For example: Subjects/cortexlab/SP044/#2020-01-01#/obj.attr.ext
    # Tags/2026_Q1_Wang_Yu_et_al/obj.attr.ext
    f = transfers.get_aggregate_collection_revision
    # Check processes revisions in filenames list
    # Check parses identifier and relation parts
    relative_path = 'Subjects/cortexlab/SP044'
    filenames =[
        'obj.attr.ext',
        '#2020-01-01#/obj.attr.ext'
    ]
    dataset_path_parsed, resp = f(filenames, relative_path)
    self.assertIsNone(resp)  # should be no error response
    self.assertEqual(len(dataset_path_parsed), 2)
    # One parsed dict per input filename; the first has no revision folder.
    expected = {
        'full_path': 'Subjects/cortexlab/SP044/obj.attr.ext',
        'filename': filenames[0],
        'rel_dir_path': relative_path,
        'collection': relative_path,
        'revision': None,
        'relation': 'Subjects',
        'identifier': 'cortexlab/SP044'
    }
    self.assertDictEqual(dataset_path_parsed[0], expected)
    self.assertEqual(dataset_path_parsed[1]['revision'], '2020-01-01')

    # Check handles no collection and no revision
    dataset_path_parsed, resp = f(filenames, '')
    expected = [
        {
            'full_path': filenames[0],
            'filename': filenames[0],
            'rel_dir_path': '',
            'collection': None,
            'revision': None,
            'identifier': None,
            'relation': None
        },
        {
            'full_path': filenames[1],
            # NOTE(review): 'filename' here is filenames[0] — presumably the
            # basename with the revision folder stripped, which happens to
            # equal the first filename.  Confirm this is the helper's
            # intended behaviour rather than a copy-paste slip.
            'filename': filenames[0],
            'rel_dir_path': '#2020-01-01#',
            'collection': None,
            'revision': '2020-01-01',
            'identifier': None,
            'relation': None
        }
    ]
    self.assertIsNone(resp)
    self.assertDictEqual(dataset_path_parsed[0], expected[0])
    self.assertDictEqual(dataset_path_parsed[1], expected[1])

    # Check handles single collection (shouldn't error parsing identifier and relation)
    dataset_path_parsed, resp = f(filenames, 'foo')
    self.assertIsNone(resp)
    self.assertEqual(dataset_path_parsed[0]['collection'], 'foo')
    self.assertIsNone(dataset_path_parsed[0]['identifier'])
    self.assertIsNone(dataset_path_parsed[0]['relation'])

    # Check error on repeated revisions
    # (relative path already carries a revision folder and the second
    # filename adds another -> expect a 400 Response)
    dataset_path_parsed, resp = f(filenames, '#rev1a#')
    self.assertIsInstance(resp, Response)
    self.assertEqual(resp.status_code, 400)

    # Check error on unsupported characters
    # ('+' in a collection path should be rejected with a 400 Response)
    dataset_path_parsed, resp = f(filenames[:1], 'foo/b+r/baz')
    self.assertIsInstance(resp, Response)
    self.assertEqual(resp.status_code, 400)
117 changes: 112 additions & 5 deletions alyx/data/tests_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
from pathlib import PurePosixPath
import uuid

from django.contrib.contenttypes.models import ContentType
from django.contrib.auth import get_user_model
from django.urls import reverse

from alyx.base import BaseTests
from data.models import Dataset, FileRecord, Download, Tag
from data.models import Dataset, FileRecord, Download, Tag, DatasetType, DataFormat
from subjects.models import Subject


class APIDataTests(BaseTests):
Expand Down Expand Up @@ -746,7 +748,7 @@ def test_protected_view(self):
self.client.post(reverse('tag-list'), {'name': 'tag1', 'protected': True})

# Create some datasets and register
data = {'path': '%s/2018-01-01/002/' % self.subject,
data = {'path': f'{self.subject}/2018-01-01/002/',
'filenames': 'test_prot/a.d.e2,test_prot/a.d.e1,',
'name': 'drb1', # this is the repository name
}
Expand All @@ -763,7 +765,7 @@ def test_protected_view(self):
# 1. already created + protected --> expect protected=True
# 2. already created --> expect protected=False
# 3. not yet created --> expect protected=False
data = {'path': '%s/2018-01-01/002/' % self.subject,
data = {'path': f'{self.subject}/2018-01-01/002/',
'filenames': 'test_prot/a.d.e2,test_prot/a.d.e1,test_prot/a.b.e1',
'name': 'drb1',
'check_protected': True
Expand Down Expand Up @@ -792,15 +794,15 @@ def test_check_protected(self):
self.client.post(reverse('tag-list'), {'name': 'tag1', 'protected': True})

# Create some datasets and register
data = {'path': '%s/2018-01-01/002/' % self.subject,
data = {'path': f'{self.subject}/2018-01-01/002/',
'filenames': 'test_prot/a.c.e2',
'name': 'drb1', # this is the repository name
}

d = self.ar(self.client.post(reverse('register-file'), data), 201)

# Check the same dataset to see if it is protected, should be unprotected
# and get a status 200 respons
# and get a status 200 response
_ = data.pop('name')

r = self.ar(self.client.get(reverse('check-protected'), data=data,
Expand All @@ -818,6 +820,40 @@ def test_check_protected(self):
self.assertEqual(r['status_code'], 403)
self.assertEqual(r['error'], 'One or more datasets is protected')

# Test with aggregate dataset
for i in range(2):
dataset, is_new = Dataset.objects.get_or_create(
collection=f'subjects/laba/{self.subject}', name=f'a.d.e{i}',
dataset_type=DatasetType.objects.first(), data_format=DataFormat.objects.first(),
content_type=ContentType.objects.get(app_label='subjects', model='subject'),
object_id=Subject.objects.get(nickname=self.subject).pk
)

# add protected tag to the second dataset
dataset.tags.add(tag1)

# Check the protected status of three files
# 1. already created + protected --> expect protected=True
# 2. already created --> expect protected=False
# 3. not yet created --> expect protected=False
data = {
'path': dataset.collection,
'filenames': 'a.d.e0,a.d.e1,a.d.e2',
'object_id': str(Subject.objects.get(nickname=self.subject).pk),
'content_type': 'subjects.subject'}

r = self.client.get(reverse('check-protected'), data)
r = self.ar(r, 200)
self.assertEqual(r['status_code'], 403)
self.assertEqual(r['error'], 'One or more datasets is protected')
expected = [{'a.d.e0': [{'': False}]}, {'a.d.e1': [{'': True}]}, {'a.d.e2': []}]
self.assertEqual(r['details'], expected)
# Untag dataset and check that the status is now 200
dataset.tags.remove(tag1)
r = self.client.get(reverse('check-protected'), data)
r = self.ar(r, 200)
self.assertEqual(r['details'], 'None of the datasets are protected')

def test_revisions(self):
# Check revision lookup with name
self.post(reverse('revision-list'), {'name': 'v2'})
Expand Down Expand Up @@ -907,3 +943,74 @@ def test_auto_datetime_field(self):
# check that all modified_datetime fields are set to the value we chose
for iurl, url in enumerate(dset_urls):
self.assertEqual(self.client.get(url).data['auto_datetime'], mod_dates[0])

def test_register_aggregate(self):
    """Test endpoint for registering a dataset aggregated across sessions.

    Such datasets are associated to a model via a generic foreign key, and the relative
    path does not include the session pattern.
    """
    # Check basic registration of aggregate dataset
    subject_uuid = str(Subject.objects.get(nickname=self.subject).pk)
    data = {
        'path': f'Subjects/laba/{self.subject}',
        'filenames': 'a.a.e1,a.b.e1',
        'hostname': 'hostname',
        'object_id': subject_uuid,
        'content_type': 'subject',
        'check_protected': True
    }
    r = self.client.post(reverse('register-file'), data)
    self.ar(r, 201)
    fr = FileRecord.objects.filter(dataset=Dataset.objects.get(name='a.a.e1'))
    self.assertTrue(fr.count() == 1)
    # Should support app label in content type
    data['content_type'] = 'subjects.subject'
    r = self.client.post(reverse('register-file'), data)
    self.ar(r, 201)
    # Re-registering the same files must not create a second file record.
    fr = FileRecord.objects.filter(dataset=Dataset.objects.get(name='a.a.e1'))
    self.assertTrue(fr.count() == 1)
    # Response should include content type and object id
    self.assertEqual(r.data[0]['content_type'], 'subjects.subject')
    self.assertEqual(str(r.data[0]['object_id']), subject_uuid)

    # Check behaviour when dataset protected
    tag = Tag.objects.create(name='protected_tag', protected=True)
    dataset = Dataset.objects.get(name='a.a.e1')
    dataset.tags.add(tag)
    r = self.client.post(reverse('register-file'), data)
    self.ar(r, 403)
    self.assertEqual(r.data['error'], 'One or more datasets is protected')
    details = r.data['details']
    # Per-file protection map: a.a.e1 is now tagged protected, a.b.e1 is not.
    expected = [{'a.a.e1': [{'': True}]}, {'a.b.e1': [{'': False}]}]
    self.assertEqual(details, expected)

    # And with check-protected False
    # (still 403, but with a different error payload)
    data['check_protected'] = False
    r = self.client.post(reverse('register-file'), data)
    self.ar(r, 403)
    self.assertRegex(r.data['detail'], rf'Dataset {str(dataset.pk)} is protected, cannot patch')

    # Test some validation
    # Missing content type -> 400
    del data['content_type']
    r = self.client.post(reverse('register-file'), data)
    self.ar(r, 400)
    data['content_type'] = 'subject'
    # Missing hostname -> 400
    del data['hostname']
    r = self.client.post(reverse('register-file'), data)
    self.ar(r, 400)
    # Incorrect content type
    data['hostname'] = 'hostname'
    data['content_type'] = 'incorrect'
    r = self.client.post(reverse('register-file'), data)
    self.ar(r, 400)
    # Incorrect object id
    data['content_type'] = 'subject'
    data['object_id'] = 'incorrect'
    r = self.client.post(reverse('register-file'), data)
    self.ar(r, 400)
    # Invalid path
    # ('+' is not an allowed character in the relative path)
    data['object_id'] = subject_uuid
    data['path'] = f'Subjects/l+ba/{self.subject}'
    r = self.client.post(reverse('register-file'), data)
    self.ar(r, 400)
Loading
Loading