@@ -39,45 +39,54 @@ def allow_request(self, request, view):
3939 if self .rate is None :
4040 return True
4141
42- # Amend scope with optional bucket
43- bucket = getattr (view , self .scope_attr + "_bucket" , None )
44- if bucket is not None :
45- self .scope += ":" + sha1 (bucket .encode ()).hexdigest ()
42+ buckets = getattr (view , self .scope_attr + "_bucket" , None )
43+ if not isinstance (buckets , set ):
44+ buckets = {buckets }
4645
4746 self .now = self .timer ()
4847 self .num_requests , self .duration = zip (* self .parse_rate (self .rate ))
49- self .key = self .get_cache_key (request , view )
50- self .history = {key : [] for key in self .key }
51- self .history .update (self .cache .get_many (self .key ))
52-
53- for num_requests , duration , key in zip (
54- self .num_requests , self .duration , self .key
55- ):
56- history = self .history [key ]
57- # Drop any requests from the history which have now passed the
58- # throttle duration
59- while history and history [- 1 ] <= self .now - duration :
60- history .pop ()
61- if len (history ) >= num_requests :
62- # Prepare variables used by the Throttle's wait() method that gets called by APIView.check_throttles()
63- self .num_requests , self .duration , self .key , self .history = (
64- num_requests ,
65- duration ,
66- key ,
67- history ,
68- )
69- response = self .throttle_failure ()
70- metrics .get ("desecapi_throttle_failure" ).labels (
71- request .method , scope , request .user .pk , bucket
72- ).inc ()
73- return response
74- self .history [key ] = history
48+ self .histories = {}
49+ for bucket in buckets :
50+ # Amend scope with optional bucket
51+ if bucket is not None :
52+ self .scope = scope + ":" + sha1 (bucket .encode ()).hexdigest ()
53+ else :
54+ self .scope = scope
55+
56+ self .key = self .get_cache_key (request , view )
57+ bucket_history = {key : [] for key in self .key }
58+ bucket_history .update (self .cache .get_many (self .key ))
59+
60+ for num_requests , duration , key in zip (
61+ self .num_requests , self .duration , self .key
62+ ):
63+ history = bucket_history [key ]
64+ # Drop any requests from the history which have now passed the
65+ # throttle duration
66+ while history and history [- 1 ] <= self .now - duration :
67+ history .pop ()
68+ if len (history ) >= num_requests :
69+ # Prepare variables used by the Throttle's wait() method that gets called by APIView.check_throttles()
70+ self .num_requests , self .duration , self .key , self .history = (
71+ num_requests ,
72+ duration ,
73+ key ,
74+ history ,
75+ )
76+ response = self .throttle_failure ()
77+ metrics .get ("desecapi_throttle_failure" ).labels (
78+ request .method , scope , request .user .pk , bucket
79+ ).inc ()
80+ return response
81+ bucket_history [key ] = history
82+ self .histories [bucket ] = bucket_history
7583 return self .throttle_success ()
7684
7785 def throttle_success (self ):
78- for key in self .history :
79- self .history [key ].insert (0 , self .now )
80- self .cache .set_many (self .history , max (self .duration ))
86+ for bucket_history in self .histories .values ():
87+ for history in bucket_history .values ():
88+ history .insert (0 , self .now )
89+ self .cache .set_many (bucket_history , max (self .duration ))
8190 return True
8291
8392 # Override the static attribute of the parent class so that we can dynamically apply override settings for testing
0 commit comments