33import json
44import logging
55import sys
6- import uuid
76from datetime import datetime
87from typing import Any , Mapping , Optional , Sequence , Type , TypeVar , Union
98
109from marshmallow import fields
10+ from ...utils .uuid_utils import uuid4
1111
1212from ...cache .base import BaseCache
1313from ...config .settings import BaseSettings
1414from ...core .profile import ProfileSession
15- from ...storage .base import BaseStorage , StorageDuplicateError , StorageNotFoundError
15+ from ...storage .base import (
16+ DEFAULT_PAGE_SIZE ,
17+ BaseStorage ,
18+ StorageDuplicateError ,
19+ StorageNotFoundError ,
20+ )
1621from ...storage .record import StorageRecord
1722from ..util import datetime_to_str , time_now
18- from ..valid import INDY_ISO8601_DATETIME_EXAMPLE , INDY_ISO8601_DATETIME_VALIDATE
23+ from ..valid import (
24+ INDY_ISO8601_DATETIME_EXAMPLE ,
25+ INDY_ISO8601_DATETIME_VALIDATE ,
26+ )
1927from .base import BaseModel , BaseModelError , BaseModelSchema
2028
2129LOGGER = logging .getLogger (__name__ )
@@ -46,8 +54,7 @@ def match_post_filter(
4654 return (
4755 positive
4856 and all (
49- record .get (k ) and record .get (k ) in alts
50- for k , alts in post_filter .items ()
57+ record .get (k ) and record .get (k ) in alts for k , alts in post_filter .items ()
5158 )
5259 ) or (
5360 (not positive )
@@ -224,11 +231,12 @@ async def retrieve_by_id(
224231 Args:
225232 session: The profile session to use
226233 record_id: The ID of the record to find
234+ for_update: Whether to lock the record for update
227235 """
228236
229237 storage = session .inject (BaseStorage )
230238 result = await storage .get_record (
231- cls .RECORD_TYPE , record_id , {"forUpdate" : for_update , "retrieveTags" : False }
239+ cls .RECORD_TYPE , record_id , options = {"forUpdate" : for_update }
232240 )
233241 vals = json .loads (result .value )
234242 return cls .from_storage (record_id , vals )
@@ -238,24 +246,26 @@ async def retrieve_by_tag_filter(
238246 cls : Type [RecordType ],
239247 session : ProfileSession ,
240248 tag_filter : dict ,
241- post_filter : dict = None ,
249+ post_filter : Optional [ dict ] = None ,
242250 * ,
243251 for_update = False ,
244252 ) -> RecordType :
245253 """Retrieve a record by tag filter.
246254
247255 Args:
256+ cls: The record class
248257 session: The profile session to use
249258 tag_filter: The filter dictionary to apply
250259 post_filter: Additional value filters to apply matching positively,
251260 with sequence values specifying alternatives to match (hit any)
261+ for_update: Whether to lock the record for update
252262 """
253263
254264 storage = session .inject (BaseStorage )
255265 rows = await storage .find_all_records (
256266 cls .RECORD_TYPE ,
257267 cls .prefix_tag_filter (tag_filter ),
258- options = {"forUpdate" : for_update , "retrieveTags" : False },
268+ options = {"forUpdate" : for_update },
259269 )
260270 found = None
261271 for record in rows :
@@ -282,65 +292,107 @@ async def retrieve_by_tag_filter(
282292 async def query (
283293 cls : Type [RecordType ],
284294 session : ProfileSession ,
285- tag_filter : dict = None ,
295+ tag_filter : Optional [ dict ] = None ,
286296 * ,
287- post_filter_positive : dict = None ,
288- post_filter_negative : dict = None ,
297+ limit : Optional [int ] = None ,
298+ offset : Optional [int ] = None ,
299+ order_by : Optional [str ] = None ,
300+ descending : bool = False ,
301+ post_filter_positive : Optional [dict ] = None ,
302+ post_filter_negative : Optional [dict ] = None ,
289303 alt : bool = False ,
290304 ) -> Sequence [RecordType ]:
291305 """Query stored records.
292306
293307 Args:
294308 session: The profile session to use
295309 tag_filter: An optional dictionary of tag filter clauses
310+ limit: The maximum number of records to retrieve
311+ offset: The offset to start retrieving records from
312+ order_by: An optional field by which to order the records.
313+ descending: Whether to order the records in descending order.
296314 post_filter_positive: Additional value filters to apply matching positively
297315 post_filter_negative: Additional value filters to apply matching negatively
298316 alt: set to match any (positive=True) value or miss all (positive=False)
299317 values in post_filter
300318 """
301-
302319 storage = session .inject (BaseStorage )
303- rows = await storage .find_all_records (
304- cls .RECORD_TYPE ,
305- cls .prefix_tag_filter (tag_filter ),
306- options = {"retrieveTags" : False },
307- )
320+
321+ tag_query = cls .prefix_tag_filter (tag_filter )
322+ post_filter = post_filter_positive or post_filter_negative
323+
324+ # set flag to indicate if pagination is requested or not, then set defaults
325+ paginated = limit is not None or offset is not None
326+ limit = limit or DEFAULT_PAGE_SIZE
327+ offset = offset or 0
328+
329+ if not post_filter and paginated :
330+ # Only fetch paginated records if post-filter is not being applied
331+ rows = await storage .find_paginated_records (
332+ type_filter = cls .RECORD_TYPE ,
333+ tag_query = tag_query ,
334+ limit = limit ,
335+ offset = offset ,
336+ order_by = order_by ,
337+ descending = descending ,
338+ )
339+ else :
340+ rows = await storage .find_all_records (
341+ type_filter = cls .RECORD_TYPE ,
342+ tag_query = tag_query ,
343+ order_by = order_by ,
344+ descending = descending ,
345+ )
346+
347+ num_results_post_filter = 0 # used if applying pagination post-filter
348+ num_records_to_match = limit + offset # ignored if not paginated
349+
308350 result = []
309351 for record in rows :
310- vals = json .loads (record .value )
311- if match_post_filter (
312- vals ,
313- post_filter_positive ,
314- positive = True ,
315- alt = alt ,
316- ) and match_post_filter (
317- vals ,
318- post_filter_negative ,
319- positive = False ,
320- alt = alt ,
321- ):
322- try :
352+ try :
353+ vals = json .loads (record .value )
354+ if not post_filter : # pagination would already be applied if requested
323355 result .append (cls .from_storage (record .id , vals ))
324- except BaseModelError as err :
325- raise BaseModelError (f"{ err } , for record id { record .id } " )
356+ else :
357+ continue_processing = (
358+ not paginated or num_results_post_filter < num_records_to_match
359+ )
360+ if not continue_processing :
361+ break
362+
363+ post_filter_match = match_post_filter (
364+ vals , post_filter_positive , positive = True , alt = alt
365+ ) and match_post_filter (
366+ vals , post_filter_negative , positive = False , alt = alt
367+ )
368+
369+ if not post_filter_match :
370+ continue
371+
372+ if num_results_post_filter >= offset : # append only after offset
373+ result .append (cls .from_storage (record .id , vals ))
374+
375+ num_results_post_filter += 1
376+ except (BaseModelError , json .JSONDecodeError , TypeError ) as err :
377+ raise BaseModelError (f"{ err } , for record id { record .id } " )
326378 return result
327379
328380 async def save (
329381 self ,
330382 session : ProfileSession ,
331383 * ,
332- reason : str = None ,
384+ reason : Optional [ str ] = None ,
333385 log_params : Mapping [str , Any ] = None ,
334386 log_override : bool = False ,
335- event : bool = None ,
387+ event : Optional [ bool ] = None ,
336388 ) -> str :
337389 """Persist the record to storage.
338390
339391 Args:
340392 session: The profile session to use
341393 reason: A reason to add to the log
342394 log_params: Additional parameters to log
343- override : Override configured logging regimen, print to stderr instead
395+ log_override : Override configured logging regimen, print to stderr instead
344396 event: Flag to override whether the event is sent
345397 """
346398
@@ -355,7 +407,7 @@ async def save(
355407 new_record = False
356408 else :
357409 if not self ._id :
358- self ._id = str (uuid . uuid4 ())
410+ self ._id = str (uuid4 ())
359411 self .created_at = self .updated_at
360412 await storage .add_record (self .storage_record )
361413 new_record = True
@@ -380,7 +432,7 @@ async def post_save(
380432 session : ProfileSession ,
381433 new_record : bool ,
382434 last_state : Optional [str ],
383- event : bool = None ,
435+ event : Optional [ bool ] = None ,
384436 ):
385437 """Perform post-save actions.
386438
@@ -411,7 +463,7 @@ async def delete_record(self, session: ProfileSession):
411463 await self .emit_event (session , self .serialize ())
412464 await storage .delete_record (self .storage_record )
413465
414- async def emit_event (self , session : ProfileSession , payload : Any = None ):
466+ async def emit_event (self , session : ProfileSession , payload : Optional [ Any ] = None ):
415467 """Emit an event.
416468
417469 Args:
@@ -436,12 +488,11 @@ async def emit_event(self, session: ProfileSession, payload: Any = None):
436488 def log_state (
437489 cls ,
438490 msg : str ,
439- params : dict = None ,
440- settings : BaseSettings = None ,
491+ params : Optional [ dict ] = None ,
492+ settings : Optional [ BaseSettings ] = None ,
441493 override : bool = False ,
442494 ):
443495 """Print a message with increased visibility (for testing)."""
444-
445496 if override or (
446497 cls .LOG_STATE_FLAG and settings and settings .get (cls .LOG_STATE_FLAG )
447498 ):
@@ -454,10 +505,7 @@ def log_state(
454505 @classmethod
455506 def strip_tag_prefix (cls , tags : dict ):
456507 """Strip tilde from unencrypted tag names."""
457-
458- return (
459- {(k [1 :] if "~" in k else k ): v for (k , v ) in tags .items ()} if tags else {}
460- )
508+ return {(k [1 :] if "~" in k else k ): v for (k , v ) in tags .items ()} if tags else {}
461509
462510 @classmethod
463511 def prefix_tag_filter (cls , tag_filter : dict ):
0 commit comments