|
5 | 5 | # this file be licensed under the Apache-2.0 license or a
|
6 | 6 | # compatible open source license.
|
7 | 7 |
|
8 |
| -import sys |
9 |
| -from types import ModuleType |
10 |
| -from unittest.mock import Mock |
11 |
| - |
12 | 8 | import pytest
|
13 |
| - |
14 |
| - |
15 |
| -class DeletionTracker: |
16 |
| - """Helper class to track object deletions""" |
17 |
| - |
18 |
| - def __init__(self): |
19 |
| - self.deleted_objects = set() |
20 |
| - |
21 |
| - def mark_deleted(self, obj_id): |
22 |
| - self.deleted_objects.add(obj_id) |
23 |
| - |
24 |
| - def is_deleted(self, obj_id): |
25 |
| - return obj_id in self.deleted_objects |
26 |
| - |
27 |
| - def reset(self): |
28 |
| - self.deleted_objects.clear() |
29 |
| - |
30 |
| - |
31 |
| -# Create global deletion tracker |
32 |
| -_deletion_tracker = DeletionTracker() |
| 9 | +from core.common.models.index_build_parameters import ( |
| 10 | + AlgorithmParameters, |
| 11 | + IndexBuildParameters, |
| 12 | + IndexParameters, |
| 13 | + SpaceType, |
| 14 | +) |
33 | 15 |
|
34 | 16 |
|
35 | 17 | @pytest.fixture
|
36 |
| -def deletion_tracker(): |
37 |
| - """Fixture to provide access to deletion tracker""" |
38 |
| - _deletion_tracker.reset() # Reset before each test |
39 |
| - return _deletion_tracker |
40 |
| - |
41 |
| - |
42 |
| -@pytest.fixture(autouse=True) |
43 |
| -def reset_deletion_tracker(): |
44 |
| - """Reset deletion tracker before each test""" |
45 |
| - _deletion_tracker.reset() |
46 |
| - yield |
47 |
| - |
48 |
| - |
49 |
| -class MockGpuIndexCagra: |
50 |
| - """Mock for faiss.GpuIndexCagra with deletion tracking""" |
51 |
| - |
52 |
| - def __init__(self, *args, **kwargs): |
53 |
| - self.id = id(self) |
54 |
| - self.thisown = False |
55 |
| - self.args = args |
56 |
| - self.kwargs = kwargs |
57 |
| - |
58 |
| - def __del__(self): |
59 |
| - print("deleting MockGpuIndexCagra:", self.id) |
60 |
| - _deletion_tracker.mark_deleted(self.id) |
61 |
| - |
62 |
| - @property |
63 |
| - def is_deleted(self): |
64 |
| - return _deletion_tracker.is_deleted(self.id) |
65 |
| - |
66 |
| - def copyTo(self, cpu_index): |
67 |
| - """Mock implementation of copyTo method""" |
68 |
| - if not isinstance(cpu_index, MockIndexHNSWCagra): |
69 |
| - raise TypeError("Target must be IndexHNSWCagra") |
70 |
| - # Simulate copying data to CPU index |
71 |
| - return True |
72 |
| - |
73 |
| - |
74 |
| -class MockIndexIDMap: |
75 |
| - """Mock for faiss.IndexIDMap with deletion tracking""" |
76 |
| - |
77 |
| - def __init__(self, *args, **kwargs): |
78 |
| - self.id = id(self) |
79 |
| - self.own_fields = False |
80 |
| - self.index = None |
81 |
| - self.args = args |
82 |
| - self.kwargs = kwargs |
83 |
| - |
84 |
| - def __del__(self): |
85 |
| - print("deleting MockIndexIDMap:", self.id) |
86 |
| - _deletion_tracker.mark_deleted(self.id) |
87 |
| - |
88 |
| - @property |
89 |
| - def is_deleted(self): |
90 |
| - return _deletion_tracker.is_deleted(self.id) |
91 |
| - |
92 |
| - def add_with_ids(self, vectors, ids): |
93 |
| - pass |
94 |
| - |
95 |
| - |
96 |
| -class MockIndexHNSWCagra(Mock): |
97 |
| - """Mock for faiss.IndexHNSWCagra""" |
98 |
| - |
99 |
| - def __init__(self, *args, **kwargs): |
100 |
| - super().__init__(*args, **kwargs) |
101 |
| - self.hnsw = Mock() |
102 |
| - self.base_level_only = True |
103 |
| - |
104 |
| - def __del__(self): |
105 |
| - _deletion_tracker.mark_deleted(self.id) |
106 |
| - |
107 |
| - @property |
108 |
| - def is_deleted(self): |
109 |
| - return _deletion_tracker.is_deleted(self.id) |
110 |
| - |
111 |
| - |
112 |
| -class MockIVFPQBuildCagraConfig: |
113 |
| - """Mock class for faiss.IVFPQBuildCagraConfig""" |
114 |
| - |
115 |
| - def __init__(self): |
116 |
| - self.n_lists = 1024 |
117 |
| - self.kmeans_n_iters = 20 |
118 |
| - self.kmeans_trainset_fraction = 0.5 |
119 |
| - self.pq_bits = 8 |
120 |
| - self.pq_dim = 0 |
121 |
| - self.conservative_memory_allocation = True |
122 |
| - |
123 |
| - |
124 |
| -class MockIVFPQSearchCagraConfig: |
125 |
| - """Mock class for faiss.IVFPQSearchCagraConfig""" |
126 |
| - |
127 |
| - def __init__(self): |
128 |
| - self.n_probes = 20 |
129 |
| - |
130 |
| - |
131 |
| -class MockGpuIndexCagraConfig: |
132 |
| - """Mock class for faiss.GpuIndexCagraConfig""" |
133 |
| - |
134 |
| - def __init__(self): |
135 |
| - self.intermediate_graph_degree = 64 |
136 |
| - self.graph_degree = 32 |
137 |
| - self.store_dataset = False |
138 |
| - self.device = 0 |
139 |
| - self.refine_rate = 2.0 |
140 |
| - self.build_algo = None |
141 |
| - self.ivf_pq_build_config = None |
142 |
| - self.ivf_pq_search_config = None |
143 |
| - |
144 |
| - |
145 |
| -class FaissMock(ModuleType): |
146 |
| - """Complete mock for faiss module""" |
147 |
| - |
148 |
| - def __init__(self): |
149 |
| - super().__init__("faiss") |
150 |
| - # Classes |
151 |
| - self.StandardGpuResources = Mock() |
152 |
| - self.GpuIndexCagra = MockGpuIndexCagra |
153 |
| - self.IndexIDMap = MockIndexIDMap |
154 |
| - self.IndexHNSWCagra = MockIndexHNSWCagra |
155 |
| - self.IVFPQBuildCagraConfig = MockIVFPQBuildCagraConfig |
156 |
| - self.IVFPQSearchCagraConfig = MockIVFPQSearchCagraConfig |
157 |
| - self.GpuIndexCagraConfig = MockGpuIndexCagraConfig |
158 |
| - |
159 |
| - # Enums |
160 |
| - self.graph_build_algo_IVF_PQ = 1 |
161 |
| - |
162 |
| - self.METRIC_L2 = 0 |
163 |
| - self.METRIC_INNER_PRODUCT = 1 |
164 |
| - |
165 |
| - self._num_threads = None |
166 |
| - self.omp_set_num_threads = self._omp_set_num_threads |
167 |
| - self.omp_get_num_threads = self._omp_get_num_threads |
168 |
| - |
169 |
| - self.write_index = self._write_index |
170 |
| - |
171 |
| - def _omp_set_num_threads(self, num_threads: int) -> None: |
172 |
| - self._num_threads = num_threads |
173 |
| - |
174 |
| - def _omp_get_num_threads(self) -> int: |
175 |
| - return self._num_threads |
176 |
| - |
177 |
| - def _write_index(self, index, filepath): |
178 |
| - if not isinstance(filepath, str): |
179 |
| - raise TypeError("Filepath must be a string") |
180 |
| - if not index: |
181 |
| - raise ValueError("Index cannot be None") |
182 |
| - try: |
183 |
| - with open(filepath, "wb") as f: |
184 |
| - f.write(b"MOCK_INDEX") |
185 |
| - except IOError as e: |
186 |
| - raise IOError(f"Failed to write to {filepath}: {str(e)}") |
187 |
| - |
188 |
| - |
189 |
| -# Create the mock and patch faiss |
190 |
| -faiss_mock = FaissMock() |
191 |
| -sys.modules["faiss"] = faiss_mock |
| 18 | +def index_build_parameters(): |
| 19 | + """Create sample IndexBuildParameters for testing""" |
| 20 | + return IndexBuildParameters( |
| 21 | + container_name="testbucket", |
| 22 | + vector_path="vec.knnvec", |
| 23 | + doc_id_path="doc.knndid", |
| 24 | + dimension=3, |
| 25 | + doc_count=5, |
| 26 | + index_parameters=IndexParameters( |
| 27 | + space_type=SpaceType.INNERPRODUCT, |
| 28 | + algorithm_parameters=AlgorithmParameters( |
| 29 | + ef_construction=200, ef_search=200 |
| 30 | + ), |
| 31 | + ), |
| 32 | + repository_type="s3", |
| 33 | + ) |
0 commit comments