Skip to content

Commit 1002c67

Browse files
committed
fix: plugin threads synchronization
- use reentrant lock to synchronize access to Pool singletone. - make concurrent modifications of _cache structure thread-safe.
1 parent 3ab4aed commit 1002c67

File tree

1 file changed

+104
-92
lines changed
  • mamonsu/plugins/pgsql/driver

1 file changed

+104
-92
lines changed

mamonsu/plugins/pgsql/driver/pool.py

+104-92
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from .connection import Connection, ConnectionInfo
22

33
from mamonsu.lib.version import parse
4+
import threading
5+
46

57
class Pool(object):
68
ExcludeDBs = ["template0", "template1"]
@@ -107,10 +109,11 @@ def __init__(self, params=None):
107109
"bootstrap": {"storage": {}, "counter": 0, "cache": 10, "version": False},
108110
"recovery": {"storage": {}, "counter": 0, "cache": 10},
109111
"extension_schema": {"pg_buffercache": {}, "pg_stat_statements": {}, "pg_wait_sampling": {}, "pgpro_stats": {}},
110-
"extension_versions" : {},
112+
"extension_versions": {},
111113
"pgpro": {"storage": {}},
112114
"pgproee": {"storage": {}}
113115
}
116+
self._lock = threading.RLock()
114117

115118
def connection_string(self, db=None):
116119
db = self._normalize_db(db)
@@ -122,30 +125,32 @@ def query(self, query, db=None):
122125
return self._connections[db].query(query)
123126

124127
def server_version(self, db=None):
125-
db = self._normalize_db(db)
126-
if db in self._cache["server_version"]["storage"]:
128+
with self._lock:
129+
db = self._normalize_db(db)
130+
if db in self._cache["server_version"]["storage"]:
131+
return self._cache["server_version"]["storage"][db]
132+
133+
version_string = self.query("show server_version", db)[0][0]
134+
result = bytes(
135+
version_string.split(" ")[0], "utf-8")
136+
self._cache["server_version"]["storage"][db] = "{0}".format(
137+
result.decode("ascii"))
127138
return self._cache["server_version"]["storage"][db]
128139

129-
version_string = self.query("show server_version", db)[0][0]
130-
result = bytes(
131-
version_string.split(" ")[0], "utf-8")
132-
self._cache["server_version"]["storage"][db] = "{0}".format(
133-
result.decode("ascii"))
134-
return self._cache["server_version"]["storage"][db]
135-
136140
def extension_version(self, extension, db=None):
137-
db = self._normalize_db(db)
138-
if extension in self._cache["extension_versions"] and db in self._cache["extension_versions"][extension][db]:
141+
with self._lock:
142+
db = self._normalize_db(db)
143+
if extension in self._cache["extension_versions"] and db in self._cache["extension_versions"][extension][db]:
144+
return self._cache["extension_versions"][extension][db]
145+
146+
version_string = self.query("select extversion from pg_catalog.pg_extension where lower(extname) = lower('{0}');".format(extension), db)[0][0]
147+
result = bytes(
148+
version_string.split(" ")[0], "utf-8")
149+
self._cache["extension_versions"][extension] = {}
150+
self._cache["extension_versions"][extension][db] = "{0}".format(
151+
result.decode("ascii"))
139152
return self._cache["extension_versions"][extension][db]
140153

141-
version_string = self.query("select extversion from pg_catalog.pg_extension where lower(extname) = lower('{0}');".format(extension), db)[0][0]
142-
result = bytes(
143-
version_string.split(" ")[0], "utf-8")
144-
self._cache["extension_versions"][extension] = {}
145-
self._cache["extension_versions"][extension][db] = "{0}".format(
146-
result.decode("ascii"))
147-
return self._cache["extension_versions"][extension][db]
148-
149154
def server_version_greater(self, version, db=None):
150155
db = self._normalize_db(db)
151156
return parse(self.server_version(db)) >= parse(version)
@@ -155,49 +160,53 @@ def server_version_less(self, version, db=None):
155160
return parse(self.server_version(db)) <= parse(version)
156161

157162
def bootstrap_version_greater(self, version):
158-
return parse(str(self._cache["bootstrap"]["version"])) >= parse(version)
163+
with self._lock:
164+
return parse(str(self._cache["bootstrap"]["version"])) >= parse(version)
159165

160166
def bootstrap_version_less(self, version):
161-
return parse(str(self._cache["bootstrap"]["version"])) <= parse(version)
167+
with self._lock:
168+
return parse(str(self._cache["bootstrap"]["version"])) <= parse(version)
162169

163170
def in_recovery(self, db=None):
164-
db = self._normalize_db(db)
165-
if db in self._cache["recovery"]["storage"]:
166-
if self._cache["recovery"]["counter"] < self._cache["recovery"]["cache"]:
167-
self._cache["recovery"]["counter"] += 1
168-
return self._cache["recovery"]["storage"][db]
169-
self._cache["recovery"]["counter"] = 0
170-
self._cache["recovery"]["storage"][db] = self.query(
171-
"select pg_catalog.pg_is_in_recovery()", db)[0][0]
172-
return self._cache["recovery"]["storage"][db]
171+
with self._lock:
172+
db = self._normalize_db(db)
173+
if db in self._cache["recovery"]["storage"]:
174+
if self._cache["recovery"]["counter"] < self._cache["recovery"]["cache"]:
175+
self._cache["recovery"]["counter"] += 1
176+
return self._cache["recovery"]["storage"][db]
177+
self._cache["recovery"]["counter"] = 0
178+
self._cache["recovery"]["storage"][db] = self.query(
179+
"select pg_catalog.pg_is_in_recovery()", db)[0][0]
180+
return self._cache["recovery"]["storage"][db]
173181

174182
def is_bootstraped(self, db=None):
175-
db = self._normalize_db(db)
176-
if db in self._cache["bootstrap"]["storage"]:
177-
if self._cache["bootstrap"]["counter"] < self._cache["bootstrap"]["cache"]:
178-
self._cache["bootstrap"]["counter"] += 1
179-
return self._cache["bootstrap"]["storage"][db]
180-
self._cache["bootstrap"]["counter"] = 0
181-
# TODO: изменить на нормальное название, 'config' слишком общее
182-
sql = """
183-
SELECT count(*)
184-
FROM pg_catalog.pg_class
185-
WHERE relname = 'config';
186-
"""
187-
result = int(self.query(sql, db)[0][0])
188-
self._cache["bootstrap"]["storage"][db] = (result == 1)
189-
if self._cache["bootstrap"]["storage"][db]:
190-
self._connections[db].log.info("Found mamonsu bootstrap")
183+
with self._lock:
184+
db = self._normalize_db(db)
185+
if db in self._cache["bootstrap"]["storage"]:
186+
if self._cache["bootstrap"]["counter"] < self._cache["bootstrap"]["cache"]:
187+
self._cache["bootstrap"]["counter"] += 1
188+
return self._cache["bootstrap"]["storage"][db]
189+
self._cache["bootstrap"]["counter"] = 0
190+
# TODO: изменить на нормальное название, 'config' слишком общее
191191
sql = """
192-
SELECT max(version)
193-
FROM mamonsu.config;
192+
SELECT count(*)
193+
FROM pg_catalog.pg_class
194+
WHERE relname = 'config';
194195
"""
195-
self._cache["bootstrap"]["version"] = self.query(sql, db)[0][0]
196-
else:
197-
self._connections[db].log.info("Mamonsu bootstrap is not found")
198-
self._connections[db].log.info(
199-
"hint: run `mamonsu bootstrap` if you want to run without superuser rights")
200-
return self._cache["bootstrap"]["storage"][db]
196+
result = int(self.query(sql, db)[0][0])
197+
self._cache["bootstrap"]["storage"][db] = (result == 1)
198+
if self._cache["bootstrap"]["storage"][db]:
199+
self._connections[db].log.info("Found mamonsu bootstrap")
200+
sql = """
201+
SELECT max(version)
202+
FROM mamonsu.config;
203+
"""
204+
self._cache["bootstrap"]["version"] = self.query(sql, db)[0][0]
205+
else:
206+
self._connections[db].log.info("Mamonsu bootstrap is not found")
207+
self._connections[db].log.info(
208+
"hint: run `mamonsu bootstrap` if you want to run without superuser rights")
209+
return self._cache["bootstrap"]["storage"][db]
201210

202211
def is_superuser(self, db=None):
203212
_ = self._normalize_db(db)
@@ -209,34 +218,36 @@ def is_superuser(self, db=None):
209218
return False
210219

211220
def is_pgpro(self, db=None):
212-
db = self._normalize_db(db)
213-
if db in self._cache["pgpro"]:
221+
with self._lock:
222+
db = self._normalize_db(db)
223+
if db in self._cache["pgpro"]:
224+
return self._cache["pgpro"][db]
225+
try:
226+
self.query("""
227+
SELECT pgpro_version();
228+
""")
229+
self._cache["pgpro"][db] = True
230+
except:
231+
self._cache["pgpro"][db] = False
214232
return self._cache["pgpro"][db]
215-
try:
216-
self.query("""
217-
SELECT pgpro_version();
218-
""")
219-
self._cache["pgpro"][db] = True
220-
except:
221-
self._cache["pgpro"][db] = False
222-
return self._cache["pgpro"][db]
223233

224234
def is_pgpro_ee(self, db=None):
225-
db = self._normalize_db(db)
226-
if not self.is_pgpro(db):
227-
return False
228-
if db in self._cache["pgproee"]:
235+
with self._lock:
236+
db = self._normalize_db(db)
237+
if not self.is_pgpro(db):
238+
return False
239+
if db in self._cache["pgproee"]:
240+
return self._cache["pgproee"][db]
241+
try:
242+
ed = self.query("""
243+
SELECT pgpro_edition();
244+
""")[0][0]
245+
self._connections[db].log.info("pgpro_edition is {}".format(ed))
246+
self._cache["pgproee"][db] = (ed.lower() == "enterprise")
247+
except:
248+
self._connections[db].log.info("pgpro_edition() is not defined")
249+
self._cache["pgproee"][db] = False
229250
return self._cache["pgproee"][db]
230-
try:
231-
ed = self.query("""
232-
SELECT pgpro_edition();
233-
""")[0][0]
234-
self._connections[db].log.info("pgpro_edition is {}".format(ed))
235-
self._cache["pgproee"][db] = (ed.lower() == "enterprise")
236-
except:
237-
self._connections[db].log.info("pgpro_edition() is not defined")
238-
self._cache["pgproee"][db] = False
239-
return self._cache["pgproee"][db]
240251

241252
def extension_version_greater(self, extension, version, db=None):
242253
db = self._normalize_db(db)
@@ -256,19 +267,20 @@ def extension_installed(self, ext, db=None):
256267
return (int(result[0][0])) == 1
257268

258269
def extension_schema(self, extension, db=None):
259-
db = self._normalize_db(db)
260-
if db in self._cache["extension_schema"][extension]:
261-
return self._cache["extension_schema"][extension][db]
262-
try:
263-
self._cache["extension_schema"][extension][db] = self.query("""
264-
SELECT n.nspname
265-
FROM pg_extension e
266-
JOIN pg_namespace n ON e.extnamespace = n.oid
267-
WHERE e.extname = '{0}'
268-
""".format(extension), db)[0][0]
269-
return self._cache["extension_schema"][extension][db]
270-
except:
271-
self._connections[db].log.info("{0} is not installed".format(extension))
270+
with self._lock:
271+
db = self._normalize_db(db)
272+
if db in self._cache["extension_schema"][extension]:
273+
return self._cache["extension_schema"][extension][db]
274+
try:
275+
self._cache["extension_schema"][extension][db] = self.query("""
276+
SELECT n.nspname
277+
FROM pg_extension e
278+
JOIN pg_namespace n ON e.extnamespace = n.oid
279+
WHERE e.extname = '{0}'
280+
""".format(extension), db)[0][0]
281+
return self._cache["extension_schema"][extension][db]
282+
except:
283+
self._connections[db].log.info("{0} is not installed".format(extension))
272284

273285
def databases(self):
274286
result, databases = self.query("""

0 commit comments

Comments
 (0)