88from cohort .serializers import WSJobStatus , JobName
99from cohort .services .base_service import CommonService
1010from cohort .services .dated_measure import dm_service
11- from cohort .services .utils import get_authorization_header , ServerError
11+ from cohort .services .utils import ServerError
1212from admin_cohort .services .ws_event_manager import WebsocketManager , WebSocketMessageType
1313from cohort .tasks import create_cohort
1414
@@ -36,7 +36,12 @@ def build_query(cohort_source_id: str, fhir_filter_id: str = None) -> str:
3636 }
3737 return json .dumps (query )
3838
39- def create_cohort_subset (self , request , owner_id : str , table_name : str , source_cohort : CohortResult , fhir_filter_id : str ) -> CohortResult :
39+ def create_cohort_subset (self ,
40+ auth_headers : dict ,
41+ owner_id : str ,
42+ table_name : str ,
43+ source_cohort : CohortResult ,
44+ fhir_filter_id : str ) -> CohortResult :
4045 def copy_query_snapshot (snapshot : RequestQuerySnapshot ) -> RequestQuerySnapshot :
4146 return RequestQuerySnapshot .objects .create (owner = snapshot .owner ,
4247 request = snapshot .request ,
@@ -62,7 +67,7 @@ def copy_dated_measure(dm: DatedMeasure) -> DatedMeasure:
6267 dated_measure = new_dm ,
6368 request_query_snapshot = new_rqs )
6469 with transaction .atomic ():
65- self .handle_cohort_creation (cohort_subset , request , False )
70+ self .handle_cohort_creation (cohort_subset , auth_headers )
6671 return cohort_subset
6772
6873 @staticmethod
@@ -74,9 +79,9 @@ def count_active_jobs():
7479 return CohortResult .objects .filter (request_job_status__in = active_statuses ) \
7580 .count ()
7681
77- def handle_cohort_creation (self , cohort : CohortResult , request , global_estimate : bool ) -> None :
82+ def handle_cohort_creation (self , cohort : CohortResult , auth_headers : dict , global_estimate : bool = False ) -> None :
7883 if global_estimate :
79- dm_service .handle_global_count (cohort , request )
84+ dm_service .handle_global_count (cohort , auth_headers )
8085 try :
8186 if cohort .parent_cohort and cohort .sampling_ratio :
8287 json_query = self .build_query (cohort_source_id = cohort .parent_cohort .group_id )
@@ -91,7 +96,7 @@ def handle_cohort_creation(self, cohort: CohortResult, request, global_estimate:
9196
9297 create_cohort .s (cohort_id = cohort .pk ,
9398 json_query = json_query ,
94- auth_headers = get_authorization_header ( request ) ,
99+ auth_headers = auth_headers ,
95100 cohort_creator_cls = self .operator_cls ,
96101 sampling_ratio = cohort .sampling_ratio ) \
97102 .apply_async ()
0 commit comments