diff --git a/alyx/alyx/__init__.py b/alyx/alyx/__init__.py index 823f79397..427ff2af4 100644 --- a/alyx/alyx/__init__.py +++ b/alyx/alyx/__init__.py @@ -1 +1 @@ -VERSION = __version__ = '3.4.2' +VERSION = __version__ = '3.5.0' diff --git a/alyx/alyx/test_base.py b/alyx/alyx/test_base.py index dc01c04fd..39f532f65 100644 --- a/alyx/alyx/test_base.py +++ b/alyx/alyx/test_base.py @@ -1,10 +1,12 @@ from datetime import date import json -from django.test import TestCase -from django.test import Client +from django.test import Client, RequestFactory, TestCase, override_settings +from django.contrib.auth.models import AnonymousUser +from django.contrib.auth import get_user_model from alyx.base import _custom_filter_parser +from alyx.throttling import AdaptiveScopedRateThrottle, IPRateThrottle class TestDocView(TestCase): @@ -77,3 +79,58 @@ def setup_admin_subject_user(obj): obj.Lab = Lab.objects.create(name='cortexlab') obj.subject = Subject.objects.create( nickname='aQt', birth_date=date(2025, 1, 1), lab=obj.lab, actual_severity=2) + + +class DocsIPThrottle(IPRateThrottle): + scope = 'docs' + rate = '10/minute' + + +class TestThrottling(TestCase): + def setUp(self): + self.factory = RequestFactory() + + def test_ip_rate_throttle_uses_ip_cache_key(self): + request = self.factory.get('/docs/', REMOTE_ADDR='203.0.113.10') + request.user = AnonymousUser() + throttle = DocsIPThrottle() + + key = throttle.get_cache_key(request, view=None) + + self.assertEqual(key, 'throttle_docs_203.0.113.10') + + def test_ip_rate_throttle_returns_none_when_rate_is_disabled(self): + request = self.factory.get('/docs/', REMOTE_ADDR='203.0.113.10') + request.user = AnonymousUser() + throttle = DocsIPThrottle() + throttle.rate = None + + key = throttle.get_cache_key(request, view=None) + + self.assertIsNone(key) + + @override_settings(THROTTLE_MODE='user-based') + def test_adaptive_scoped_rate_throttle_uses_user_ident_when_authenticated(self): + user = get_user_model().objects.create_user('throttle_user', 'throttle@example.com', 'pass') + request = self.factory.get('/docs/', REMOTE_ADDR='198.51.100.1') + request.user = user + + throttle = AdaptiveScopedRateThrottle() + throttle.scope = 'docs' + + key = throttle.get_cache_key(request, view=None) + + self.assertEqual(key, f'throttle_docs_{user.pk}') + + @override_settings(THROTTLE_MODE='anonymous') + def test_adaptive_scoped_rate_throttle_uses_ip_even_when_authenticated(self): + user = get_user_model().objects.create_user('throttle_user2', 'throttle2@example.com', 'pass') + request = self.factory.get('/docs/', REMOTE_ADDR='198.51.100.20') + request.user = user + + throttle = AdaptiveScopedRateThrottle() + throttle.scope = 'docs' + + key = throttle.get_cache_key(request, view=None) + + self.assertEqual(key, 'throttle_docs_198.51.100.20') diff --git a/alyx/alyx/throttling.py b/alyx/alyx/throttling.py new file mode 100644 index 000000000..22bc0a7f3 --- /dev/null +++ b/alyx/alyx/throttling.py @@ -0,0 +1,46 @@ +"""Throttling classes for Alyx REST API. + +For Alyx instances where a single user login is shared across multiple users, we want to throttle based on IP address rather than user ID. +This module defines custom throttling classes that can be used in the Django REST Framework settings to achieve this behavior. 
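+
+A configuration sketch (mirroring ``deploy/docker/settings-deploy.py`` in this patch;
+the rates shown are that deployment's defaults, not requirements)::
+
+    THROTTLE_MODE = 'anonymous'  # throttle by client IP; 'user-based' keys on the user ID
+    REST_FRAMEWORK = {
+        'DEFAULT_THROTTLE_CLASSES': [
+            'alyx.throttling.BurstRateThrottle',
+            'alyx.throttling.SustainedRateThrottle',
+        ],
+        'DEFAULT_THROTTLE_RATES': {'burst': '550/minute', 'sustained': None, 'docs': '20/minute'},
+    }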
+""" +from django.conf import settings +from rest_framework.throttling import SimpleRateThrottle, UserRateThrottle, ScopedRateThrottle + + +def _get_throttle_mode(): + return getattr(settings, 'THROTTLE_MODE', 'user-based').lower() + + +class IPRateThrottle(SimpleRateThrottle): + """Throttle all requests by client IP, regardless of authentication state.""" + + def get_cache_key(self, request, view): + if self.rate is None: + return None + return self.cache_format % {'scope': self.scope, 'ident': self.get_ident(request)} + + +if _get_throttle_mode() == 'anonymous': + class BurstRateThrottle(IPRateThrottle): + scope = 'burst' + + + class SustainedRateThrottle(IPRateThrottle): + scope = 'sustained' +else: + class BurstRateThrottle(UserRateThrottle): + scope = 'burst' + + + class SustainedRateThrottle(UserRateThrottle): + scope = 'sustained' + + +class AdaptiveScopedRateThrottle(ScopedRateThrottle): + def get_cache_key(self, request, view): + if getattr(settings, "THROTTLE_MODE", "user-based").lower() == "anonymous": + return self.cache_format % { + "scope": self.scope, + "ident": self.get_ident(request), # always IP + } + return super().get_cache_key(request, view) # default ScopedRateThrottle behavior \ No newline at end of file diff --git a/alyx/alyx/views.py b/alyx/alyx/views.py index f8f8f3380..c346c5b99 100644 --- a/alyx/alyx/views.py +++ b/alyx/alyx/views.py @@ -6,6 +6,7 @@ from drf_spectacular.views import SpectacularRedocView import alyx +from alyx.throttling import AdaptiveScopedRateThrottle class IgnoreClientContentNegotiation(BaseContentNegotiation): @@ -29,6 +30,8 @@ class SpectacularRedocViewCoreAPIDeprecation(SpectacularRedocView): are now served via the api/schema """ content_negotiation_class = IgnoreClientContentNegotiation + throttle_classes = [AdaptiveScopedRateThrottle] + throttle_scope = 'docs' def get(self, request, *args, **kwargs): if request.headers['Accept'].startswith('application/coreapi+json'): diff --git a/alyx/data/management/commands/files.py b/alyx/data/management/commands/files.py index f0a5aec29..a7e5efc0b 100644 --- a/alyx/data/management/commands/files.py +++ b/alyx/data/management/commands/files.py @@ -199,7 +199,7 @@ def handle(self, *args, **options): dr.data_url = 'http://ibl.flatironinstitute.org/cortexlab/Subjects/' dr.save() - qs = DatasetType.objects.filter(filename_pattern__isnull=False) + qs = DatasetType.objects.filter(filename_pattern__isnull=False).values('pk', 'name', 'filename_pattern') dt = None for d in FileRecord.objects.all().select_related('dataset'): try: diff --git a/alyx/data/tests.py b/alyx/data/tests.py index 27e583886..e3bc4ae02 100644 --- a/alyx/data/tests.py +++ b/alyx/data/tests.py @@ -81,7 +81,7 @@ def test_model_methods(self): ('some_file.ext', 'some_file') ) - dtypes = DatasetType.objects.all() + dtypes = DatasetType.objects.values('name', 'filename_pattern') for filename, dataname in filename_typename: with self.subTest(filename=filename): self.assertEqual(get_dataset_type(filename, dtypes).name, dataname) @@ -306,3 +306,72 @@ def test_get_name_collection_revision(self): self.assertIsInstance(resp, Response) self.assertEqual(resp.status_code, 400) self.assertIn('Invalid ALF path', resp.data['detail']) + + def test_get_aggregate_collection_revision(self): + # For example: Subjects/cortexlab/SP044/#2020-01-01#/obj.attr.ext + # Tags/2026_Q1_Wang_Yu_et_al/obj.attr.ext + f = transfers.get_aggregate_collection_revision + # Check processes revisions in filenames list + # Check parses identifier and relation parts + relative_path = 
+        filenames = [
+            'obj.attr.ext',
+            '#2020-01-01#/obj.attr.ext'
+        ]
+        dataset_path_parsed, resp = f(filenames, relative_path)
+        self.assertIsNone(resp)  # should be no error response
+        self.assertEqual(len(dataset_path_parsed), 2)
+        expected = {
+            'full_path': 'Subjects/cortexlab/SP044/obj.attr.ext',
+            'filename': filenames[0],
+            'rel_dir_path': relative_path,
+            'collection': relative_path,
+            'revision': None,
+            'relation': 'Subjects',
+            'identifier': 'cortexlab/SP044'
+        }
+        self.assertDictEqual(dataset_path_parsed[0], expected)
+        self.assertEqual(dataset_path_parsed[1]['revision'], '2020-01-01')
+
+        # Check handles no collection and no revision
+        dataset_path_parsed, resp = f(filenames, '')
+        expected = [
+            {
+                'full_path': filenames[0],
+                'filename': filenames[0],
+                'rel_dir_path': '',
+                'collection': None,
+                'revision': None,
+                'identifier': None,
+                'relation': None
+            },
+            {
+                'full_path': filenames[1],
+                'filename': filenames[0],
+                'rel_dir_path': '#2020-01-01#',
+                'collection': None,
+                'revision': '2020-01-01',
+                'identifier': None,
+                'relation': None
+            }
+        ]
+        self.assertIsNone(resp)
+        self.assertDictEqual(dataset_path_parsed[0], expected[0])
+        self.assertDictEqual(dataset_path_parsed[1], expected[1])
+
+        # Check handles single collection (shouldn't error parsing identifier and relation)
+        dataset_path_parsed, resp = f(filenames, 'foo')
+        self.assertIsNone(resp)
+        self.assertEqual(dataset_path_parsed[0]['collection'], 'foo')
+        self.assertIsNone(dataset_path_parsed[0]['identifier'])
+        self.assertIsNone(dataset_path_parsed[0]['relation'])
+
+        # Check error on repeated revisions
+        dataset_path_parsed, resp = f(filenames, '#rev1a#')
+        self.assertIsInstance(resp, Response)
+        self.assertEqual(resp.status_code, 400)
+
+        # Check error on unsupported characters
+        dataset_path_parsed, resp = f(filenames[:1], 'foo/b+r/baz')
+        self.assertIsInstance(resp, Response)
+        self.assertEqual(resp.status_code, 400)
diff --git a/alyx/data/tests_rest.py b/alyx/data/tests_rest.py
index 71fb614e8..75bd9fd50 100644
--- a/alyx/data/tests_rest.py
+++ b/alyx/data/tests_rest.py
@@ -2,11 +2,13 @@
 from pathlib import PurePosixPath
 import uuid
 
+from django.contrib.contenttypes.models import ContentType
 from django.contrib.auth import get_user_model
 from django.urls import reverse
 
 from alyx.base import BaseTests
-from data.models import Dataset, FileRecord, Download, Tag
+from data.models import Dataset, FileRecord, Download, Tag, DatasetType, DataFormat
+from subjects.models import Subject
 
 
 class APIDataTests(BaseTests):
@@ -746,7 +748,7 @@ def test_protected_view(self):
         self.client.post(reverse('tag-list'), {'name': 'tag1', 'protected': True})
 
         # Create some datasets and register
-        data = {'path': '%s/2018-01-01/002/' % self.subject,
+        data = {'path': f'{self.subject}/2018-01-01/002/',
                 'filenames': 'test_prot/a.d.e2,test_prot/a.d.e1,',
                 'name': 'drb1',  # this is the repository name
                 }
@@ -763,7 +765,7 @@
         # 1. already created + protected --> expect protected=True
         # 2. already created --> expect protected=False
         # 3. not yet created --> expect protected=False
-        data = {'path': '%s/2018-01-01/002/' % self.subject,
+        data = {'path': f'{self.subject}/2018-01-01/002/',
                 'filenames': 'test_prot/a.d.e2,test_prot/a.d.e1,test_prot/a.b.e1',
                 'name': 'drb1',
                 'check_protected': True
@@ -792,7 +794,7 @@ def test_check_protected(self):
         self.client.post(reverse('tag-list'), {'name': 'tag1', 'protected': True})
 
         # Create some datasets and register
-        data = {'path': '%s/2018-01-01/002/' % self.subject,
+        data = {'path': f'{self.subject}/2018-01-01/002/',
                 'filenames': 'test_prot/a.c.e2',
                 'name': 'drb1',  # this is the repository name
                 }
@@ -800,7 +802,7 @@
         d = self.ar(self.client.post(reverse('register-file'), data), 201)
 
         # Check the same dataset to see if it is protected, should be unprotected
-        # and get a status 200 respons
+        # and get a status 200 response
         _ = data.pop('name')
         r = self.ar(self.client.get(reverse('check-protected'), data=data,
@@ -818,6 +820,40 @@
         self.assertEqual(r['status_code'], 403)
         self.assertEqual(r['error'], 'One or more datasets is protected')
 
+        # Test with aggregate dataset
+        for i in range(2):
+            dataset, is_new = Dataset.objects.get_or_create(
+                collection=f'subjects/laba/{self.subject}', name=f'a.d.e{i}',
+                dataset_type=DatasetType.objects.first(), data_format=DataFormat.objects.first(),
+                content_type=ContentType.objects.get(app_label='subjects', model='subject'),
+                object_id=Subject.objects.get(nickname=self.subject).pk
+            )
+
+        # add protected tag to the second dataset
+        dataset.tags.add(tag1)
+
+        # Check the protected status of three files
+        # 1. already created + protected --> expect protected=True
+        # 2. already created --> expect protected=False
+        # 3. not yet created --> expect protected=False
+        data = {
+            'path': dataset.collection,
+            'filenames': 'a.d.e0,a.d.e1,a.d.e2',
+            'object_id': str(Subject.objects.get(nickname=self.subject).pk),
+            'content_type': 'subjects.subject'}
+
+        r = self.client.get(reverse('check-protected'), data)
+        r = self.ar(r, 200)
+        self.assertEqual(r['status_code'], 403)
+        self.assertEqual(r['error'], 'One or more datasets is protected')
+        expected = [{'a.d.e0': [{'': False}]}, {'a.d.e1': [{'': True}]}, {'a.d.e2': []}]
+        self.assertEqual(r['details'], expected)
+        # Untag dataset and check that the status is now 200
+        dataset.tags.remove(tag1)
+        r = self.client.get(reverse('check-protected'), data)
+        r = self.ar(r, 200)
+        self.assertEqual(r['details'], 'None of the datasets are protected')
+
     def test_revisions(self):
         # Check revision lookup with name
         self.post(reverse('revision-list'), {'name': 'v2'})
@@ -907,3 +943,74 @@ def test_auto_datetime_field(self):
         # check that all modified_datetime fields are set to the value we chose
         for iurl, url in enumerate(dset_urls):
             self.assertEqual(self.client.get(url).data['auto_datetime'], mod_dates[0])
+
+    def test_register_aggregate(self):
+        """Test endpoint for registering a dataset aggregated across sessions.
+
+        Such datasets are associated with a model via a generic foreign key, and the relative
+        path does not include the session pattern.
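+
+        A minimal request payload (a sketch of the ``data`` dict used below; the path and
+        names are just this test's fixtures):
+
+            {'path': 'Subjects/laba/<subject>', 'filenames': 'a.a.e1,a.b.e1',
+             'hostname': 'hostname', 'object_id': '<subject uuid>',
+             'content_type': 'subjects.subject'}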
+ """ + # Check basic registration of aggregate dataset + subject_uuid = str(Subject.objects.get(nickname=self.subject).pk) + data = { + 'path': f'Subjects/laba/{self.subject}', + 'filenames': 'a.a.e1,a.b.e1', + 'hostname': 'hostname', + 'object_id': subject_uuid, + 'content_type': 'subject', + 'check_protected': True + } + r = self.client.post(reverse('register-file'), data) + self.ar(r, 201) + fr = FileRecord.objects.filter(dataset=Dataset.objects.get(name='a.a.e1')) + self.assertTrue(fr.count() == 1) + # Should support app label in content type + data['content_type'] = 'subjects.subject' + r = self.client.post(reverse('register-file'), data) + self.ar(r, 201) + fr = FileRecord.objects.filter(dataset=Dataset.objects.get(name='a.a.e1')) + self.assertTrue(fr.count() == 1) + # Response should include content type and object id + self.assertEqual(r.data[0]['content_type'], 'subjects.subject') + self.assertEqual(str(r.data[0]['object_id']), subject_uuid) + + # Check behaviour when dataset protected + tag = Tag.objects.create(name='protected_tag', protected=True) + dataset = Dataset.objects.get(name='a.a.e1') + dataset.tags.add(tag) + r = self.client.post(reverse('register-file'), data) + self.ar(r, 403) + self.assertEqual(r.data['error'], 'One or more datasets is protected') + details = r.data['details'] + expected = [{'a.a.e1': [{'': True}]}, {'a.b.e1': [{'': False}]}] + self.assertEqual(details, expected) + + # And with check-protected False + data['check_protected'] = False + r = self.client.post(reverse('register-file'), data) + self.ar(r, 403) + self.assertRegex(r.data['detail'], rf'Dataset {str(dataset.pk)} is protected, cannot patch') + + # Test some validation + del data['content_type'] + r = self.client.post(reverse('register-file'), data) + self.ar(r, 400) + data['content_type'] = 'subject' + del data['hostname'] + r = self.client.post(reverse('register-file'), data) + self.ar(r, 400) + # Incorrect content type + data['hostname'] = 'hostname' + data['content_type'] = 'incorrect' + r = self.client.post(reverse('register-file'), data) + self.ar(r, 400) + # Incorrect object id + data['content_type'] = 'subject' + data['object_id'] = 'incorrect' + r = self.client.post(reverse('register-file'), data) + self.ar(r, 400) + # Invalid path + data['object_id'] = subject_uuid + data['path'] = f'Subjects/l+ba/{self.subject}' + r = self.client.post(reverse('register-file'), data) + self.ar(r, 400) diff --git a/alyx/data/transfers.py b/alyx/data/transfers.py index c17ebd3ad..859c0f344 100644 --- a/alyx/data/transfers.py +++ b/alyx/data/transfers.py @@ -4,19 +4,20 @@ import os.path as op import re import time -from pathlib import Path, PurePosixPath +from pathlib import Path, PurePosixPath, PurePath from django.db.models import Case, When, Count, Q, F import globus_sdk import numpy as np from one.alf.path import add_uuid_string, folder_parts from one.registration import get_dataset_type -from one.alf.spec import QC +from one.alf.spec import QC, regex, COLLECTION_SPEC from alyx import settings from data.models import FileRecord, Dataset, DatasetType, DataFormat, DataRepository from rest_framework.response import Response from actions.models import Session +from subjects.models import Subject logger = logging.getLogger(__name__) @@ -31,6 +32,7 @@ def get_config_path(path=''): def create_globus_client(): + # FIXME use ONE Globus client instead client = globus_sdk.NativeAppAuthClient(settings.GLOBUS_CLIENT_ID) client.oauth2_start_flow(refresh_tokens=True) return client @@ -188,52 +190,115 @@ def 
_get_repositories_for_labs(labs, server_only=False):
     return list(repositories)
 
 
+def _parse_path(path):
+    pattern = regex(spec='{subject}/{date}/{number}').pattern + '.*'
+    m = re.match(pattern, path)
+    if not m:
+        raise ValueError("The path %s should be `nickname/YYYY-MM-DD/n/...`" % path)
+    date = m.group('date')
+    nickname = m.group('subject')
+    session_number = int(m.group('number'))
+    # An error is raised if the subject does not exist.
+    subject = Subject.objects.get(nickname=nickname)
+    return subject, date, session_number
+
+
 def _get_name_collection_revision(file, rel_dir_path):
     """
     Extract collection, revision and session parts from the full file path.
 
     :param file: The filename
     :param rel_dir_path: The relative path (subject/date/number/collection/revision)
-    :return: dict of path parts
+    :return: dict of path parts, or a list of dicts if `file` is a list
     :return: a REST Response object if ALF path is invalid, otherwise None
     """
-    # Get collections/revisions for each file
-    fullpath = Path(rel_dir_path).joinpath(file)
-    try:
-        info = folder_parts(fullpath.parent, as_dict=True)
-        if info['revision']:
-            path_parts = fullpath.parent.parts
-            assert path_parts.index(f"#{info['revision']}#") == len(path_parts) - 1
-    except AssertionError:
-        data = {'status_code': 400,
-                'detail': 'Invalid ALF path. There must be only 1 revision and it cannot contain'
-                          'sub folders. A revision folder must be surrounded by pound signs (#).'}
-        return None, Response(data=data, status=400)
-    except ValueError:
-        data = {'status_code': 400,
-                'detail': 'Invalid ALF path. Only letters, numbers, hyphen and underscores '
-                          'allowed. A revision folder must be surrounded by pound signs (#).'}
-        return None, Response(data=data, status=400)
-
-    info['full_path'] = fullpath.as_posix()
-    info['filename'] = fullpath.name
-    info['rel_dir_path'] = '{subject}/{date}/{number}'.format(**info)
-    info = {k: v or '' for k, v in info.items()}
-    return info, None
-
-
-def _change_default_dataset(session, collection, filename):
+    if return_single := isinstance(file, (str, PurePath)):
+        file = [file]
+
+    parsed_paths = []
+    for f in file:
+        # Get collections/revisions for each file
+        fullpath = Path(rel_dir_path).joinpath(f)
+        try:
+            info = folder_parts(fullpath.parent, as_dict=True)
+            if info['revision']:
+                path_parts = fullpath.parent.parts
+                assert path_parts.index(f"#{info['revision']}#") == len(path_parts) - 1
+        except AssertionError:
+            data = {'status_code': 400,
+                    'detail': 'Invalid ALF path. There must be only 1 revision and it cannot contain '
+                              'sub folders. A revision folder must be surrounded by pound signs (#).'}
+            return None, Response(data=data, status=400)
+        except ValueError:
+            data = {'status_code': 400,
+                    'detail': 'Invalid ALF path. Only letters, numbers, hyphen and underscores '
+                              'allowed. A revision folder must be surrounded by pound signs (#).'}
+            return None, Response(data=data, status=400)
+
+        info['full_path'] = fullpath.as_posix()
+        info['filename'] = fullpath.name
+        info['rel_dir_path'] = '{subject}/{date}/{number}'.format(**info)
+        info = {k: v or '' for k, v in info.items()}
+        parsed_paths.append(info)
+
+    if return_single:
+        return parsed_paths[0], None
+    return parsed_paths, None
+
+
+def get_aggregate_collection_revision(files, rel_dir_path):
+    """Parse the path of a dataset not associated with a session.
+
+    This is less strict than `_get_name_collection_revision`.
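+
+    For instance (mirroring the expected values asserted in ``data/tests.py``),
+    'Subjects/cortexlab/SP044/obj.attr.ext' parses to:
+
+        {'full_path': 'Subjects/cortexlab/SP044/obj.attr.ext', 'filename': 'obj.attr.ext',
+         'rel_dir_path': 'Subjects/cortexlab/SP044', 'collection': 'Subjects/cortexlab/SP044',
+         'revision': None, 'relation': 'Subjects', 'identifier': 'cortexlab/SP044'}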
+
+    By convention, the pattern is this: relation/identifier/(revision)/dataset
+    For example: Subjects/cortexlab/SP044/#2020-01-01#/obj.attr.ext
+                 Tags/2026_Q1_Wang_Yu_et_al/obj.attr.ext
+    """
+    dataset_path_parsed = []
+    for f in files:
+        assert f
+        fullpath = Path(rel_dir_path, f)
+        _rel_dir_path = '' if len(fullpath.parts) == 1 else fullpath.parent.as_posix()
+        info = regex(spec=COLLECTION_SPEC).match(_rel_dir_path + '/').groupdict()
+        info['full_path'] = fullpath.as_posix()
+        info['filename'] = fullpath.name
+        info['rel_dir_path'] = _rel_dir_path
+        info['relation'] = info['identifier'] = None
+        if info['collection']:
+            if '/' in info['collection']:
+                # This is just convention and is not enforced
+                info['relation'], info['identifier'] = info['collection'].split('/', 1)
+        # Check that collection (if present) was captured correctly
+        expected_len = 0 if not info['collection'] else len(info['collection'].split('/'))
+        expected_len += int(bool(info['revision']))
+        if len(fullpath.parent.parts) != expected_len:
+            data = {'status_code': 400,
+                    'detail': 'Invalid ALF path. Only letters, numbers, hyphen and underscores '
+                              'allowed. A revision folder must be surrounded by pound signs (#).'}
+            return None, Response(data=data, status=400)
+        dataset_path_parsed.append(info)
+    return dataset_path_parsed, None
+
+
+def _change_default_dataset(session, collection, filename, **kwargs):
     dataset = Dataset.objects.filter(
-        session=session, collection=collection or '', name=filename, default_dataset=True)
+        session=session, collection=collection or '',
+        name=filename, default_dataset=True, **kwargs)
     if dataset.count() > 0:
         dataset.update(default_dataset=False)
 
 
-def _check_dataset_protected(session, collection, filename):
+def _check_dataset_protected(session, collection, filename, **kwargs):
     # Order datasets by the latest revision with the original one last
-    dataset = Dataset.objects.filter(
-        session=session, collection=collection or '', name=filename).order_by(
-        F('revision__created_datetime').desc(nulls_last=True))
+    dataset = Dataset.objects.filter(
+        session=session, collection=collection or '', name=filename, **kwargs
+    ).order_by(F('revision__created_datetime').desc(nulls_last=True))
     if dataset.count() == 0:
         return False, []
     else:
@@ -246,9 +311,9 @@ def _create_dataset_file_records(
         rel_dir_path=None, filename=None, session=None, user=None,
         repositories=None, exists_in=None, collection=None, hash=None,
-        file_size=None, version=None, revision=None, default=None, qc=None):
-
-    assert session is not None
+        file_size=None, version=None, revision=None, default=None,
+        qc=None, content_type=None, object_id=None):
+    assert session or all([content_type, object_id])
     collection = collection or ''
     revision_name = f'#{revision.name}#' if revision else ''
     relative_path = PurePosixPath(rel_dir_path, collection, revision_name, filename)
@@ -260,12 +325,14 @@
     # If we are going to set this one as the default we need to change any previous ones with
     # same session, collection and name to have default flag to false
     if default:
-        _change_default_dataset(session, collection, filename)
+        _change_default_dataset(session, collection, filename,
+                                content_type=content_type, object_id=object_id)
 
     # Get or create the dataset.
dataset, is_new = Dataset.objects.get_or_create( - collection=collection, name=filename, session=session, # content_object=session, - dataset_type=dataset_type, data_format=data_format, revision=revision + collection=collection, name=filename, session=session, + dataset_type=dataset_type, data_format=data_format, revision=revision, + content_type=content_type, object_id=object_id ) dataset.default_dataset = default is True try: @@ -273,7 +340,7 @@ def _create_dataset_file_records( except ValueError: data = {'status_code': 400, 'detail': f'Invalid QC value "{qc}" for dataset "{relative_path}"'} - return None, Response(data=data, status=403) + return None, Response(data=data, status=400) dataset.save() # If the dataset already existed see if it is protected (i.e can't be overwritten) diff --git a/alyx/data/views.py b/alyx/data/views.py index aa0551707..dc67d32f6 100644 --- a/alyx/data/views.py +++ b/alyx/data/views.py @@ -1,16 +1,16 @@ import logging -import re from django.contrib.auth import get_user_model +from django.contrib.contenttypes.models import ContentType +from django.core.exceptions import ValidationError from django.db import models from rest_framework import generics, viewsets, mixins, serializers from rest_framework.response import Response import django_filters -from one.alf.spec import regex from alyx.base import BaseFilterSet, rest_permission_classes -from subjects.models import Subject, Project from experiments.models import ProbeInsertion +from subjects.models import Subject, Project from misc.models import Lab from .models import (DataRepositoryType, DataRepository, @@ -33,9 +33,9 @@ RevisionSerializer, TagSerializer ) -from .transfers import (_get_session, _get_repositories_for_labs, - _create_dataset_file_records, bulk_sync, _check_dataset_protected, - _get_name_collection_revision) +from .transfers import (_get_session, _parse_path, _get_repositories_for_labs, bulk_sync, + _check_dataset_protected, _get_name_collection_revision, + get_aggregate_collection_revision, _create_dataset_file_records) logger = logging.getLogger(__name__) @@ -299,35 +299,39 @@ def _make_dataset_response(dataset): 'id': dataset.pk, 'name': dataset.name, 'file_size': dataset.file_size, - 'subject': dataset.session.subject.nickname, + 'subject': dataset.session.subject.nickname if dataset.session else None, 'created_by': dataset.created_by.username, 'created_datetime': dataset.created_datetime, 'dataset_type': getattr(dataset.dataset_type, 'name', ''), - 'data_format': getattr(dataset.data_format, 'name', ''), - 'session': getattr(dataset.session, 'pk', ''), - 'session_number': dataset.session.number, - 'session_users': ','.join(_.username for _ in dataset.session.users.all()), - 'session_start_time': dataset.session.start_time, + 'data_format': getattr(dataset.data_format, 'name', '')} + if dataset.session: + out.update({ + 'session': getattr(dataset.session, 'pk', ''), + 'session_number': dataset.session.number, + 'session_users': ','.join(_.username for _ in dataset.session.users.all()), + 'session_start_time': dataset.session.start_time}) + elif dataset.content_object: + out.update({ + 'content_type': f'{dataset.content_type.app_label}.{dataset.content_type.model}', + 'object_id': dataset.object_id}) + # NB: Keeping this order for backward compatibility + out.update({ 'collection': dataset.collection, 'revision': getattr(dataset.revision, 'name', ''), 'default': dataset.default_dataset, 'qc': dataset.qc - } + }) out['file_records'] = file_records return out -def _parse_path(path): - pattern 
= regex(spec='{subject}/{date}/{number}').pattern + '.*'
-    m = re.match(pattern, path)
-    if not m:
-        raise ValueError(r"The path %s should be `nickname/YYYY-MM-DD/n/..." % path)
-    date = m.group('date')
-    nickname = m.group('subject')
-    session_number = int(m.group('number'))
-    # An error is raised if the subject or data repository do not exist.
-    subject = Subject.objects.get(nickname=nickname)
-    return subject, date, session_number
+def _get_content_type(content_type_str):
+    """Return the ContentType object for a string that is either 'model' or 'app_label.model'."""
+    if '.' in content_type_str:
+        app_label, model_name = content_type_str.split('.', 1)
+        return ContentType.objects.get(app_label=app_label, model=model_name)
+    else:
+        return ContentType.objects.get(model=content_type_str)
 
 
 class ProtectedFileViewSet(mixins.ListModelMixin,
@@ -354,6 +358,7 @@ def list(self, request):
         Returns a response indicating if any of the datasets are protected or not
         - Status 403 if a dataset is protected, details contains a list of protected datasets
         - Status 200 is none of the datasets are protected
+
         """
 
         req = request.GET.dict() if len(request.data) == 0 else request.data
@@ -371,25 +376,53 @@
         # Extract the data repository from the hostname, the subject, the directory path.
         rel_dir_path = rel_dir_path.replace('\\', '/')
         rel_dir_path = rel_dir_path.replace('//', '/')
-        subject, date, session_number = _parse_path(rel_dir_path)
 
         filenames = req.get('filenames', ())
         if isinstance(filenames, str):
             filenames = filenames.split(',')
-
-        session = _get_session(
-            subject=subject, date=date, number=session_number, user=user)
-        assert session
-
+        filenames = list(filter(None, filenames))  # Remove empty strings
+
+        content_type = req.get('content_type', None)
+        object_id = req.get('object_id', None)
+
+        if content_type and object_id:
+            # Aggregate dataset
+            session = None
+            try:
+                content_type = _get_content_type(content_type)
+                model = content_type.model_class()
+            except ContentType.DoesNotExist:
+                data = {'status_code': 400,
+                        'detail': f'Invalid content type: {content_type}'}
+                return Response(data=data, status=400)
+            try:
+                model.objects.get(pk=object_id)
+            except (model.DoesNotExist, ValidationError):
+                data = {'status_code': 400,
+                        'detail': f'Invalid object ID: {object_id}'}
+                return Response(data=data, status=400)
+            dataset_path_parsed, resp = get_aggregate_collection_revision(filenames, rel_dir_path)
+            if resp:
+                return resp
+        elif any([content_type, object_id]):
+            data = {'status_code': 400,
+                    'error': 'Both content_type and object_id should be provided together.'}
+            return Response(data=data, status=400)
+        else:
+            subject, date, session_number = _parse_path(rel_dir_path)
+            session = _get_session(
+                subject=subject, date=date, number=session_number, user=user)
+            assert session
+            dataset_path_parsed, resp = _get_name_collection_revision(filenames, rel_dir_path)
+            if resp:
+                return resp
 
         # Loop through the files to see if any are protected
         prot_response = []
         protected = []
-        for file in filenames:
-            info, resp = _get_name_collection_revision(file, rel_dir_path)
-            if resp:
-                return resp
+        for file, info in zip(filenames, dataset_path_parsed):
             prot, prot_info = _check_dataset_protected(
-                session, info['collection'], info['filename'])
+                session, info['collection'], info['filename'],
+                content_type=content_type, object_id=object_id)
             protected.append(prot)
             prot_response.append({file: prot_info})
         if any(protected):
@@ -410,10 +443,10 @@ class RegisterFileViewSet(mixins.CreateModelMixin,
 
     def create(self, request):
         """
-        Endpoint to create a register a dataset record through the REST API.
+        Endpoint to create and register a dataset and file record through the REST API.
 
         The session is retrieved by the ALF convention in the relative path, so this field has to
-        match the format Subject/Date/Number as shown below.
+        match the format Subject/Date/Number as shown below, unless the dataset is registered
+        against another model via the content_type and object_id fields (see below).
 
         The set of repositories are given through the labs. The lab is by default the subject
         lab, but if it is specified, it overrides the subject lab entirely.
@@ -451,6 +484,14 @@ def create(self, request):
          'projects': 'alyx_lab_name',  # optional, alias of lab field above
          }
         ```
+
+        For registering data that is not associated with a session, the following fields should be
+        provided:
+        ```python
+        r_ = {'content_type': 'model_name',  # e.g. 'probeinsertion', 'experiments.probeinsertion'
+              'object_id': 'object_pk',  # UUID of the object the dataset is associated with
+              }
+        ```
+        NB: The repository hostname or name should be provided for non-session datasets.
 
         If the dataset already exists, it will use the file hash to deduce if the file has been
         patched or not (i.e. the filerecords will be created as not existing)
@@ -475,20 +516,20 @@
         rel_dir_path = request.data.get('path', '')
         if not rel_dir_path:
             raise ValueError("The path argument is required.")
-
-        # Extract the data repository from the hostname, the subject, the directory path.
         rel_dir_path = rel_dir_path.replace('\\', '/')
-        rel_dir_path = rel_dir_path.replace('//', '/')
-        subject, date, session_number = _parse_path(rel_dir_path)
+        rel_dir_path = rel_dir_path.replace('//', '/').strip('/')
 
         filenames = request.data.get('filenames', ())
         if isinstance(filenames, str):
             filenames = filenames.split(',')
+        filenames = list(filter(None, filenames))  # Remove empty strings
 
         # versions if provided
         versions = request.data.get('versions', [None] * len(filenames))
         if isinstance(versions, str):
             versions = versions.split(',')
+        if len(versions) == 1:
+            versions = versions * len(filenames)
 
         # file hashes if provided
         hashes = request.data.get('hashes', [None] * len(filenames))
@@ -522,15 +563,73 @@
         if isinstance(check_protected, str):
             check_protected = check_protected == 'True'
 
-        # Multiple labs (NB: projects is an alias of labs)
-        labs = request.data.get('labs', [])
-        if isinstance(labs, str):
-            labs = labs.split(',')
-        projects = request.data.get('projects', [])
-        if isinstance(projects, str):
-            projects = projects.split(',')
-        labs = [Lab.objects.get(name=lab) for lab in labs + projects if lab]
-        repositories = _get_repositories_for_labs(labs or [subject.lab], server_only=server_only)
+        # If the content type and object id are provided, we skip the session retrieval;
+        # the dataset is associated with a model other than actions.Session
+        content_type = request.data.get('content_type', None)
+        object_id = request.data.get('object_id', None)
+        if content_type and object_id:
+            # Aggregate dataset
+            try:
+                content_type = _get_content_type(content_type)
+                model = content_type.model_class()
+            except ContentType.DoesNotExist:
+                data = {'status_code': 400,
+                        'detail': f'Invalid content type: {content_type}'}
+                return Response(data=data, status=400)
+            try:
+                model.objects.get(pk=object_id)
+            except (model.DoesNotExist, ValidationError):
+                data = {'status_code': 400,
+                        'detail': f'Invalid object ID: {object_id}'}
+                return Response(data=data, status=400)
+            # For aggregate datasets the repository
must be specified, + # as we cannot retrieve it from the session subject + if not repo: + data = {'status_code': 400, + 'detail': 'A valid repository name or hostname must be provided for non-session datasets.'} + return Response(data=data, status=400) + repositories = [repo] + session = subject = None + dataset_path_parsed, resp = get_aggregate_collection_revision(filenames, rel_dir_path) + if resp: + return resp + elif any([content_type, object_id]): + data = {'status_code': 400, + 'error': 'Both content_type and object_id should be provided together.'} + return Response(data=data, status=400) + else: + # Extract the session from the directory path. + try: + subject, date, session_number = _parse_path(rel_dir_path) + except ValueError as e: + data = {'status_code': 400, 'error': str(e)} + return Response(data=data, status=400) + except Subject.DoesNotExist: + err = f'A subject with nickname "{rel_dir_path.split("/")[0]}" does not exist.' + data = {'status_code': 400, 'error': err} + return Response(data=data, status=400) + try: + session = _get_session( + subject=subject, date=date, number=session_number, user=user) + except ValueError as e: + data = {'status_code': 400, 'error': str(e)} + return Response(data=data, status=400) + # Parse paths + dataset_path_parsed, resp = _get_name_collection_revision(filenames, rel_dir_path) + if resp: + return resp + + # Multiple labs (NB: projects is an alias of labs) + labs = request.data.get('labs', []) + if isinstance(labs, str): + labs = labs.split(',') + projects = request.data.get('projects', []) + if isinstance(projects, str): + projects = projects.split(',') + labs = [Lab.objects.get(name=lab) for lab in labs + projects if lab] + + repositories = _get_repositories_for_labs(labs or [subject.lab], server_only=server_only) + if repo and repo not in repositories: repositories += [repo] if server_only: @@ -544,23 +643,14 @@ def create(self, request): if not exists: exists_in = (None,) - session = _get_session( - subject=subject, date=date, number=session_number, user=user) - assert session - # If the check protected flag is True, loop through the files to see if any are protected if check_protected: prot_response = [] protected = [] - for file in filenames: - - info, resp = _get_name_collection_revision(file, rel_dir_path) - - if resp: - return resp - + for file, info in zip(filenames, dataset_path_parsed): prot, prot_info = _check_dataset_protected( - session, info['collection'], info['filename']) + session, info['collection'], info['filename'], + content_type=content_type, object_id=object_id) protected.append(prot) prot_response.append({file: prot_info}) @@ -571,14 +661,8 @@ def create(self, request): return Response(data=data, status=403) response = [] - for filename, hash, fsize, version, qc in zip(filenames, hashes, filesizes, versions, qcs): - if not filename: - continue - info, resp = _get_name_collection_revision(filename, rel_dir_path) - - if resp: - return resp - + all_info = zip(dataset_path_parsed, hashes, filesizes, versions, qcs) + for info, hash, fsize, version, qc in all_info: if info['revision']: revision, _ = Revision.objects.get_or_create(name=info['revision']) else: @@ -588,7 +672,7 @@ def create(self, request): collection=info['collection'], rel_dir_path=info['rel_dir_path'], filename=info['filename'], session=session, user=user, repositories=repositories, exists_in=exists_in, hash=hash or '', file_size=fsize, version=version or '', - revision=revision, default=default, qc=qc) + revision=revision, default=default, qc=qc, 
content_type=content_type, object_id=object_id) if resp: return resp out = _make_dataset_response(dataset) diff --git a/alyx/experiments/serializers.py b/alyx/experiments/serializers.py index cfc21b6f9..550c93ce3 100644 --- a/alyx/experiments/serializers.py +++ b/alyx/experiments/serializers.py @@ -109,7 +109,7 @@ def setup_eager_loading(queryset): """ queryset = queryset.select_related('model', 'session', 'session__subject', 'session__lab') queryset = queryset.prefetch_related('session__projects') - return queryset.order_by('-session__start_time') + return queryset.order_by('-session__start_time', 'name') model = serializers.SlugRelatedField(read_only=True, slug_field='name') session_info = SessionListSerializer(read_only=True, source='session') @@ -126,7 +126,7 @@ def setup_eager_loading(queryset): """ Perform necessary eager loading of data to avoid horrible performance.""" queryset = queryset.select_related('model', 'session', 'session__subject', 'session__lab') queryset = queryset.prefetch_related('session__projects', 'datasets') - return queryset.order_by('-session__start_time') + return queryset.order_by('-session__start_time', 'name') session = serializers.SlugRelatedField( read_only=False, required=False, slug_field='id', @@ -161,7 +161,7 @@ def setup_eager_loading(queryset): """ Perform necessary eager loading of data to avoid horrible performance.""" queryset = queryset.select_related('model', 'session', 'session__subject', 'session__lab') queryset = queryset.prefetch_related('session__projects') - return queryset.order_by('-session__start_time') + return queryset.order_by('-session__start_time', 'name') session = serializers.SlugRelatedField( read_only=False, required=False, slug_field='id', @@ -311,7 +311,7 @@ def setup_eager_loading(queryset): queryset = queryset.prefetch_related( 'datasets', Prefetch('location', queryset=location_qs) ) - return queryset.order_by('-session__start_time') + return queryset.order_by('-session__start_time', 'name') class Meta: model = FOV diff --git a/alyx/experiments/tests_rest.py b/alyx/experiments/tests_rest.py index d9a572477..326cf9e9b 100644 --- a/alyx/experiments/tests_rest.py +++ b/alyx/experiments/tests_rest.py @@ -120,6 +120,24 @@ def test_probe_insertion_rest(self): probe_ins = self.ar(self.client.get(urlf)) self.assertTrue(len(probe_ins) == 0) + # Test that the serializer returns probes sorted by session start time and then name + # Create a more recent session + session = Session.objects.create(subject=self.session.subject, number=2) + session.task_protocol = 'ephys' + session.save() + insertions = [] + for name in ['probe01', 'probe02', 'probe00']: + insertion = {'session': str(session.id), + 'name': name, + 'model': '3A' + } + url = reverse('probeinsertion-list') + insertions.append(self.ar(self.post(url, insertion), 201)) + + probe_ins = self.ar(self.client.get(reverse('probeinsertion-list')), 200) + self.assertEqual(probe_ins[0]['session_info']['id'], str(session.id)) + self.assertEqual([x['name'] for x in probe_ins], ['probe00', 'probe01', 'probe02', 'probe00', 'probe01']) + def test_probe_insertion_dataset_interaction(self): # First create two insertions and attach to session probe_names = ['probe00', 'probe01'] diff --git a/deploy/docker/settings-deploy.py b/deploy/docker/settings-deploy.py index a1ea9850b..251bfc84f 100644 --- a/deploy/docker/settings-deploy.py +++ b/deploy/docker/settings-deploy.py @@ -185,11 +185,24 @@ WSGI_APPLICATION = 'alyx.wsgi.application' +THROTTLE_MODE = os.getenv('THROTTLE_MODE', 
'user-based').strip().lower() +if THROTTLE_MODE not in ('user-based', 'anonymous'): + raise ValueError('THROTTLE_MODE must be one of: user-based, anonymous') + REST_FRAMEWORK = { 'DEFAULT_AUTHENTICATION_CLASSES': ( 'rest_framework.authentication.SessionAuthentication', 'rest_framework.authentication.TokenAuthentication', ), + 'DEFAULT_THROTTLE_CLASSES': [ + 'alyx.throttling.BurstRateThrottle', + 'alyx.throttling.SustainedRateThrottle', + ], + 'DEFAULT_THROTTLE_RATES': { + 'burst': os.getenv('THROTTLE_BURST_RATE', '550/minute'), + 'sustained': os.getenv('THROTTLE_SUSTAINED_RATE', None), + 'docs': os.getenv('THROTTLE_DOCS_RATE', '20/minute'), + }, 'DEFAULT_FILTER_BACKENDS': ('django_filters.rest_framework.DjangoFilterBackend',), 'STRICT_JSON': False, 'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.LimitOffsetPagination', diff --git a/requirements.txt b/requirements.txt index c7793382b..8ff649d3c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,7 +26,7 @@ gunicorn ipython markdown matplotlib -ONE-api>=3.0 +ONE-api>=3.5.1 pillow>=11.3.0 psycopg2-binary python-dateutil diff --git a/requirements_frozen.txt b/requirements_frozen.txt index bc19880a2..8c3daa6c0 100644 --- a/requirements_frozen.txt +++ b/requirements_frozen.txt @@ -11,9 +11,9 @@ coreapi==2.3.3 coreschema==0.0.4 coverage==7.8.2 coveralls==4.0.1 -cryptography==46.0.5 +cryptography==46.0.6 cycler==0.12.1 -Django==5.2.11 +Django==5.2.12 django-admin-list-filter-dropdown==1.0.3 django-admin-rangefilter==0.13.2 django-autocomplete-light==3.12.1 @@ -43,7 +43,7 @@ jmespath==1.0.1 kiwisolver==1.4.8 llvmlite==0.44.0 lxml==5.4.0 -Markdown==3.8 +Markdown==3.8.1 MarkupSafe==3.0.2 matplotlib==3.10.3 mccabe==0.7.0 @@ -58,14 +58,14 @@ pyarrow==20.0.0 pycodestyle==2.13.0 pycparser==2.22 pyflakes==3.3.2 -PyJWT==2.10.1 +PyJWT==2.12.0 pyparsing==3.2.3 python-dateutil==2.9.0.post0 python-ipware==3.0.0 python-magic==0.4.27 pytz==2025.2 PyYAML==6.0.2 -requests==2.32.3 +requests==2.33.0 ruff==0.11.11 s3transfer==0.13.0 setuptools==80.9.0