Skip to content

Commit dea33bd

Browse files
committed
Move make_identifier function to collect.py
in aggregate_issues, we have both the Issue object and the deserialized data, so we can compute the identifier without relying on a fragile iteration on list of key_lists Also add some more precise typing
1 parent ad42432 commit dea33bd

5 files changed

Lines changed: 107 additions & 71 deletions

File tree

bugwarrior/collect.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from collections.abc import Iterator
1+
from collections.abc import Iterable, Iterator
22
import copy
33
from functools import cache
44
from importlib.metadata import entry_points
5+
import json
56
import logging
67
import multiprocessing
78
import time
@@ -10,6 +11,8 @@
1011
from jinja2 import Template
1112
from taskw.task import Task
1213

14+
from bugwarrior.types import CollectedIssue, CollectionErrorData, TaskwarriorData
15+
1316
if TYPE_CHECKING:
1417
from bugwarrior.config.validation import Config
1518
from bugwarrior.services import Issue, Service
@@ -81,7 +84,9 @@ def _aggregate_issues(service: "Service", queue: multiprocessing.Queue) -> None:
8184
log.info(f"Done with [{target}] in {duration}.")
8285

8386

84-
def aggregate_issues(conf: "Config", debug: bool) -> Iterator[dict | tuple[str, str]]:
87+
def aggregate_issues(
88+
conf: "Config", debug: bool
89+
) -> Iterator[CollectedIssue | CollectionErrorData]:
8590
"""Return all issues from every target."""
8691
log.info("Starting to aggregate remote issues.")
8792

@@ -111,22 +116,33 @@ def aggregate_issues(conf: "Config", debug: bool) -> Iterator[dict | tuple[str,
111116
while currently_running > 0:
112117
issue = queue.get(True)
113118
try:
114-
record = TaskConstructor(issue).get_taskwarrior_record()
115-
record['target'] = issue.config.target
119+
record = TaskConstructor(issue).get_data_to_sync()
116120
yield record
117121
except AttributeError:
118122
if isinstance(issue, tuple):
119123
currently_running -= 1
120124
completion_type, target = issue
121125
if completion_type == SERVICE_FINISHED_ERROR:
122126
log.error(f"Aborted [{target}] due to critical error.")
123-
yield ('SERVICE FAILED', target)
127+
yield CollectionErrorData('SERVICE FAILED', target)
124128
continue
125129
raise
126130

127131
log.info("Done aggregating remote issues.")
128132

129133

134+
def make_unique_identifier(
135+
unique_keys: Iterable[str], taskwarrior_data: TaskwarriorData
136+
) -> str:
137+
"""For a given issue, make an identifier from its unique keys.
138+
139+
This is not the same as the taskwarrior uuid, which is assigned
140+
only once the task is created.
141+
"""
142+
subset = {key: taskwarrior_data[key] for key in unique_keys}
143+
return json.dumps(subset, sort_keys=True)
144+
145+
130146
class TaskConstructor:
131147
"""Construct a taskwarrior task from a foreign record."""
132148

@@ -152,6 +168,10 @@ def get_taskwarrior_record(self, refined: bool = True) -> dict[str, Any]:
152168
record['tags'] = []
153169
if refined:
154170
record['tags'].extend(self.get_added_tags())
171+
172+
# Blank priority should mean *no* priority
173+
if record['priority'] == '':
174+
record['priority'] = None
155175
return record
156176

157177
def get_template_context(self) -> dict[str, Any]:
@@ -168,3 +188,11 @@ def refine_record(self, record: dict[str, Any]) -> dict[str, Any]:
168188
elif field == 'description':
169189
record['description'] = self.issue.get_default_description()
170190
return record
191+
192+
def get_data_to_sync(self) -> CollectedIssue:
193+
taskwarrior_data = self.get_taskwarrior_record()
194+
return CollectedIssue(
195+
taskwarrior_data=taskwarrior_data,
196+
identifier=make_unique_identifier(self.issue.UNIQUE_KEY, taskwarrior_data),
197+
target=self.issue.config.target,
198+
)

bugwarrior/db.py

Lines changed: 14 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from collections.abc import Collection, Iterable, Iterator
22
import itertools
3-
import json
43
import logging
54
import re
65
import subprocess
@@ -11,6 +10,7 @@
1110

1211
from bugwarrior.collect import get_service
1312
from bugwarrior.notifications import send_notification
13+
from bugwarrior.types import CollectedIssue, CollectionErrorData
1414

1515
if TYPE_CHECKING:
1616
from bugwarrior.config.schema import MainSectionConfig
@@ -43,21 +43,6 @@ def get_managed_task_uuids(
4343
return expected_task_ids
4444

4545

46-
def make_unique_identifier(
47-
unique_key_sets: Iterable[Collection[str]], issue: dict[str, Any]
48-
) -> str:
49-
"""For a given issue, make an identifier from its unique keys.
50-
51-
This is not the same as the taskwarrior uuid, which is assigned
52-
only once the task is created.
53-
"""
54-
for unique_keys in unique_key_sets:
55-
if all(key in issue for key in unique_keys):
56-
subset = {key: issue[key] for key in unique_keys}
57-
return json.dumps(subset, sort_keys=True)
58-
raise RuntimeError("Could not determine unique identifier for %s" % issue)
59-
60-
6146
def find_taskwarrior_uuid(
6247
tw: TaskWarriorShellout,
6348
unique_key_sets: Iterable[Collection[str]],
@@ -175,7 +160,7 @@ def run_hooks(pre_import: list[str]) -> None:
175160

176161

177162
def synchronize(
178-
issue_generator: Iterable[dict | tuple[str, str]],
163+
issue_generator: Iterator[CollectedIssue | CollectionErrorData],
179164
conf: "Config",
180165
dry_run: bool = False,
181166
) -> None:
@@ -207,27 +192,25 @@ def synchronize(
207192
}
208193

209194
for issue in issue_generator:
210-
if isinstance(issue, tuple):
211-
assert issue[0] == 'SERVICE FAILED', (
212-
"'issue' should only be a tuple in case of a failure"
213-
)
214-
successful_config_map.pop(issue[1])
195+
if isinstance(issue, CollectionErrorData):
196+
successful_config_map.pop(issue.target)
215197
continue
216198

217199
# De-duplicate issues coming in
218-
unique_identifier = make_unique_identifier(unique_key_sets, issue)
219-
if unique_identifier in issue_map:
220-
log.debug(f"Merging tags and skipping. Seen {unique_identifier} of {issue}")
200+
if issue.identifier in issue_map:
201+
log.debug(f"Merging tags and skipping. Seen {issue.identifier} of {issue}")
221202
# Merge and deduplicate tags.
222-
issue_map[unique_identifier]['tags'] += issue['tags']
223-
issue_map[unique_identifier]['tags'] = list(
224-
set(issue_map[unique_identifier]['tags'])
203+
new_tags = sorted(
204+
set(issue_map[issue.identifier].taskwarrior_data['tags'])
205+
| set(issue.taskwarrior_data['tags'])
225206
)
207+
issue_map[issue.identifier].taskwarrior_data['tags'] = new_tags
208+
226209
else:
227-
issue_map[unique_identifier] = issue
210+
issue_map[issue.identifier] = issue
228211

229212
seen_uuids = set()
230-
for issue in issue_map.values():
213+
for issue, target, _ in issue_map.values():
231214
# We received this issue from The Internet, but we're not sure what
232215
# kind of encoding the service providers may have handed us. Let's try
233216
# and decode all byte strings from UTF8 off the bat. If we encounter
@@ -240,12 +223,7 @@ def synchronize(
240223
except UnicodeDecodeError:
241224
log.warning("Failed to interpret %r as utf-8" % key)
242225

243-
# Blank priority should mean *no* priority
244-
if issue['priority'] == '':
245-
issue['priority'] = None
246-
247-
# Target was only tacked on to pass configuration to this function.
248-
service_config = successful_config_map[issue.pop('target')]
226+
service_config = successful_config_map[target]
249227

250228
try:
251229
existing_taskwarrior_uuid = find_taskwarrior_uuid(

bugwarrior/services/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import requests
2020

2121
from bugwarrior.config import schema, secrets
22+
from bugwarrior.types import TaskwarriorData
2223

2324
log = logging.getLogger(__name__)
2425

@@ -121,7 +122,7 @@ def __init__(
121122
self.extra = extra
122123

123124
@abc.abstractmethod
124-
def to_taskwarrior(self) -> dict[str, Any]:
125+
def to_taskwarrior(self) -> TaskwarriorData:
125126
"""Transform a foreign record into a taskwarrior dictionary."""
126127
raise NotImplementedError()
127128

bugwarrior/types.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from typing import Any, NamedTuple
2+
3+
TaskwarriorData = dict[str, Any]
4+
5+
6+
class CollectedIssue(NamedTuple):
7+
taskwarrior_data: TaskwarriorData
8+
target: str
9+
identifier: str
10+
11+
12+
class CollectionErrorData(NamedTuple):
13+
error_message: str
14+
target: str

tests/test_db.py

Lines changed: 44 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import taskw.task
55

66
from bugwarrior import db
7+
from bugwarrior.types import CollectedIssue
78

89
from .base import ConfigTest
910

@@ -58,19 +59,8 @@ def test_handles_missing_tags(self):
5859

5960

6061
class TestSynchronize(ConfigTest):
61-
def test_synchronize(self):
62-
def remove_non_deterministic_keys(tasks):
63-
for status in ['pending', 'completed']:
64-
for task in tasks[status]:
65-
del task['modified']
66-
del task['entry']
67-
del task['uuid']
68-
task['tags'] = sorted(task['tags'])
69-
return tasks
70-
71-
def get_tasks(tw):
72-
return remove_non_deterministic_keys(tw.load_tasks())
73-
62+
def setUp(self):
63+
super().setUp()
7464
self.config = {
7565
'general': {
7666
'targets': ['my_service'],
@@ -84,10 +74,38 @@ def get_tasks(tw):
8474
'token': 'abc123',
8575
},
8676
}
87-
bwconfig = self.validate()
77+
self.bwconfig = self.validate()
78+
self.tw = taskw.TaskWarrior(self.taskrc)
79+
80+
def synchronize(self, issues_data):
81+
82+
issue_generator = [
83+
CollectedIssue(
84+
taskwarrior_data=copy.deepcopy(issue_data),
85+
target="my_service",
86+
identifier="abcd",
87+
)
88+
for issue_data in issues_data
89+
]
90+
db.synchronize(iter(issue_generator), self.bwconfig)
91+
92+
def remove_non_deterministic_keys(self, tasks):
93+
for status in ['pending', 'completed']:
94+
for task in tasks[status]:
95+
del task['modified']
96+
del task['entry']
97+
del task['uuid']
98+
task['tags'] = sorted(task['tags'])
8899

89-
tw = taskw.TaskWarrior(self.taskrc)
90-
self.assertEqual(tw.load_tasks(), {'completed': [], 'pending': []})
100+
return tasks
101+
102+
def get_tasks(self):
103+
104+
return self.remove_non_deterministic_keys(self.tw.load_tasks())
105+
106+
def test_synchronize(self):
107+
108+
self.assertEqual(self.tw.load_tasks(), {'completed': [], 'pending': []})
91109

92110
issue = {
93111
'description': 'Blah blah blah. ☃',
@@ -96,7 +114,6 @@ def get_tasks(tw):
96114
'githuburl': 'https://example.com',
97115
'priority': 'M',
98116
'tags': ['foo'],
99-
'target': 'my_service',
100117
}
101118
duplicate_issue = copy.deepcopy(issue)
102119
duplicate_issue['tags'] = ['bar']
@@ -107,11 +124,10 @@ def get_tasks(tw):
107124
# These should be de-duplicated in db.synchronize before
108125
# writing out to taskwarrior.
109126
# https://github.com/ralphbean/bugwarrior/issues/601
110-
issue_generator = iter((copy.deepcopy(issue), duplicate_issue))
111-
db.synchronize(issue_generator, bwconfig)
127+
self.synchronize([issue, duplicate_issue])
112128

113129
self.assertEqual(
114-
get_tasks(tw),
130+
self.get_tasks(),
115131
{
116132
'completed': [],
117133
'pending': [
@@ -135,11 +151,10 @@ def get_tasks(tw):
135151

136152
# Change static field
137153
issue['project'] = 'other_project'
138-
139-
db.synchronize(iter((copy.deepcopy(issue),)), bwconfig)
154+
self.synchronize([issue])
140155

141156
self.assertEqual(
142-
get_tasks(tw),
157+
self.get_tasks(),
143158
{
144159
'completed': [],
145160
'pending': [
@@ -159,11 +174,11 @@ def get_tasks(tw):
159174
)
160175

161176
# TEST CLOSED ISSUE.
162-
db.synchronize(iter(()), bwconfig)
177+
self.synchronize([])
163178

164-
completed_tasks = tw.load_tasks()
179+
completed_tasks = self.tw.load_tasks()
165180

166-
tasks = remove_non_deterministic_keys(copy.deepcopy(completed_tasks))
181+
tasks = self.remove_non_deterministic_keys(copy.deepcopy(completed_tasks))
167182
del tasks['completed'][0]['end']
168183
self.assertEqual(
169184
tasks,
@@ -186,14 +201,14 @@ def get_tasks(tw):
186201
)
187202

188203
# TEST REOPENED ISSUE
189-
db.synchronize(iter((copy.deepcopy(issue),)), bwconfig)
204+
self.synchronize([issue])
190205

191-
tasks = tw.load_tasks()
206+
tasks = self.tw.load_tasks()
192207
self.assertEqual(
193208
completed_tasks['completed'][0]['uuid'], tasks['pending'][0]['uuid']
194209
)
195210

196-
tasks = remove_non_deterministic_keys(tasks)
211+
tasks = self.remove_non_deterministic_keys(tasks)
197212
self.assertEqual(
198213
tasks,
199214
{

0 commit comments

Comments
 (0)