Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 139 additions & 45 deletions src/hooks/hooks.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ extern "C" {
#include "utils/lsyscache.h"
#include "utils/timestamp.h"

#include "parser/analyze.h"

#if PG_VERSION_NUM >= 140000
#include "nodes/queryjumble.h"
#endif
Expand All @@ -32,10 +34,13 @@ extern "C" {

#include "config/guc.h"
#include "hooks/hooks.h"
#include "hooks/query_normalize.h"
#include "hooks/query_normalize_state.h"
#include "queue/event.h"
#include "queue/shmem.h"

// Previous hook values for chaining
static post_parse_analyze_hook_type prev_post_parse_analyze = nullptr;
static ExecutorStart_hook_type prev_executor_start = nullptr;
static ExecutorRun_hook_type prev_executor_run = nullptr;
static ExecutorFinish_hook_type prev_executor_finish = nullptr;
Expand Down Expand Up @@ -72,8 +77,10 @@ static uint8 CopyName(char* dst, size_t dst_size, const char* src) {
// Cache for session-stable values to avoid repeated catalog lookups on every query.
// Following pg_stat_monitor's pattern of caching client IP (pg_stat_monitor.c:73-96).
// Database name and client address never change within a session. Username is
// re-resolved when userid changes (handles SET ROLE).
struct BackendCache {
// re-resolved when userid changes (handles SET ROLE). This also carries the
// session-local registry of normalized statements waiting to be consumed by
// ExecutorEnd or ProcessUtility. Error events built in emit_log_hook do not
// read it; they leave the query text empty.
struct PschBackendState {
bool initialized;
char datname[NAMEDATALEN];
uint8 datname_len;
Expand All @@ -82,8 +89,9 @@ struct BackendCache {
uint8 username_len;
char client_addr[46]; // INET6_ADDRSTRLEN
uint8 client_addr_len;
PschNormalizedQueryState normalized_queries;
};
static BackendCache backend_cache = {};
static PschBackendState backend_state = {};

// Resolve and cache the current username. On initial resolve, falls back to
// "<unknown>" if resolution fails. On SET ROLE re-resolve, keeps the existing
Expand All @@ -92,13 +100,13 @@ static BackendCache backend_cache = {};
static void CacheUsername(Oid userid, bool fallback_on_null) {
const char* username = GetUserNameFromId(userid, true);
if (username != nullptr) {
backend_cache.username_len =
CopyName(backend_cache.username, sizeof(backend_cache.username), username);
backend_cache.cached_userid = userid;
backend_state.username_len =
CopyName(backend_state.username, sizeof(backend_state.username), username);
backend_state.cached_userid = userid;
} else if (fallback_on_null) {
backend_cache.username_len =
CopyName(backend_cache.username, sizeof(backend_cache.username), "<unknown>");
backend_cache.cached_userid = userid;
backend_state.username_len =
CopyName(backend_state.username, sizeof(backend_state.username), "<unknown>");
backend_state.cached_userid = userid;
}
}

Expand All @@ -108,30 +116,30 @@ static void CacheUsername(Oid userid, bool fallback_on_null) {
static void EnsureBackendCache(void) {
Oid userid = GetUserId();

if (!backend_cache.initialized) {
if (!backend_state.initialized) {
// Can't resolve catalog names outside a transaction
if (!IsTransactionState()) {
return;
}

// Database name (session-stable)
const char* datname = get_database_name(MyDatabaseId);
backend_cache.datname_len = CopyName(backend_cache.datname, sizeof(backend_cache.datname),
backend_state.datname_len = CopyName(backend_state.datname, sizeof(backend_state.datname),
datname != nullptr ? datname : "<unknown>");

// Client address (session-stable)
backend_cache.client_addr_len = static_cast<uint8>(
GetClientAddress(backend_cache.client_addr, sizeof(backend_cache.client_addr)));
backend_state.client_addr_len = static_cast<uint8>(
GetClientAddress(backend_state.client_addr, sizeof(backend_state.client_addr)));

// Username (may change via SET ROLE)
CacheUsername(userid, true);

backend_cache.initialized = true;
backend_state.initialized = true;
return;
}

// Re-resolve username if userid changed (SET ROLE)
if (backend_cache.cached_userid != userid) {
if (backend_state.cached_userid != userid) {
if (IsTransactionState()) {
CacheUsername(userid, false);
}
Expand Down Expand Up @@ -335,20 +343,52 @@ static void CopyClientContext(PschEvent* event) {
GetApplicationName(event->application_name, sizeof(event->application_name)));

EnsureBackendCache();
if (backend_cache.initialized) {
memcpy(event->client_addr, backend_cache.client_addr, backend_cache.client_addr_len + 1);
event->client_addr_len = backend_cache.client_addr_len;
if (backend_state.initialized) {
memcpy(event->client_addr, backend_state.client_addr, backend_state.client_addr_len + 1);
event->client_addr_len = backend_state.client_addr_len;
} else {
event->client_addr_len =
static_cast<uint8>(GetClientAddress(event->client_addr, sizeof(event->client_addr)));
}
}

static void CopyQueryText(PschEvent* event, const char* query_text) {
if (query_text != nullptr) {
event->query_len =
static_cast<uint16>(CopyTrimmed(event->query, PSCH_MAX_QUERY_LEN, query_text));
// Copy the original SQL text for one statement into the event buffer.
//
// PostgreSQL can hand us a multi-statement source string plus stmt_location /
// stmt_len. CleanQuerytext narrows that down to the current statement, but it
// does not parameterize literals. This helper is the raw-text fallback when
// we have no normalized entry for the statement.
//
// event          Event whose query / query_len fields are filled in.
// statement_key  Source string plus statement location/length. A null
//                source_text leaves the event untouched.
static void CopyRawStatementText(PschEvent* event, const PschStatementKey& statement_key) {
  if (statement_key.source_text == nullptr) {
    return;
  }

  // Without a statement location the whole NUL-terminated source string is
  // the statement, so the plain string copy is correct.
  if (statement_key.stmt_location < 0) {
    event->query_len = static_cast<uint16>(
        CopyTrimmed(event->query, PSCH_MAX_QUERY_LEN, statement_key.source_text));
    return;
  }

  int query_loc = statement_key.stmt_location;
  int query_len = statement_key.stmt_len;
  const char* query_text = CleanQuerytext(statement_key.source_text, &query_loc, &query_len);

  // CleanQuerytext returns a pointer INTO the source string and reports the
  // statement length through query_len; the text is not NUL-terminated at the
  // end of the statement. Copying up to the NUL would drag every following
  // statement of a multi-statement string into this event, so copy exactly
  // query_len bytes, clamped to the buffer (one byte reserved for the NUL).
  // NOTE(review): truncation here is byte-based. If CopyTrimmed clips on
  // multibyte character boundaries, mirror that here — confirm.
  const size_t capacity = sizeof(event->query) - 1;
  size_t copy_len = (query_len > 0) ? static_cast<size_t>(query_len) : 0;
  if (copy_len > capacity) {
    copy_len = capacity;
  }

  memcpy(event->query, query_text, copy_len);
  event->query[copy_len] = '\0';
  event->query_len = static_cast<uint16>(copy_len);
}

// Fill event->query / event->query_len for one statement.
//
// The normalized form stashed by post_parse_analyze_hook is preferred: the
// backend's normalized-query registry is asked for an entry matching
// statement_key and, when it reports success, it has already written the text
// into the event. Otherwise the literal SQL text of the statement is copied
// via CopyRawStatementText.
//
// NOTE(review): the trailing `false` argument presumably keeps the registry
// entry in place so repeated executions of a cached plan can reuse it —
// confirm against PschCopyNormalizedQueryForStatement.
static void CopyQueryText(PschEvent* event, const PschStatementKey& statement_key) {
  const bool used_normalized = PschCopyNormalizedQueryForStatement(
      &backend_state.normalized_queries, event->query, sizeof(event->query), &event->query_len,
      statement_key, false);

  if (!used_normalized) {
    CopyRawStatementText(event, statement_key);
  }
}

// Resolve database and user names, using the session cache when available.
Expand All @@ -357,11 +397,11 @@ static void CopyQueryText(PschEvent* event, const char* query_text) {
static void ResolveNames(PschEvent* event) {
EnsureBackendCache();

if (backend_cache.initialized) {
memcpy(event->datname, backend_cache.datname, backend_cache.datname_len + 1);
event->datname_len = backend_cache.datname_len;
memcpy(event->username, backend_cache.username, backend_cache.username_len + 1);
event->username_len = backend_cache.username_len;
if (backend_state.initialized) {
memcpy(event->datname, backend_state.datname, backend_state.datname_len + 1);
event->datname_len = backend_state.datname_len;
memcpy(event->username, backend_state.username, backend_state.username_len + 1);
event->username_len = backend_state.username_len;
return;
}

Expand Down Expand Up @@ -440,11 +480,58 @@ static void BuildEventFromQueryDesc(QueryDesc* query_desc, PschEvent* event, int
CopyJitInstrumentation(event, query_desc);
CopyParallelWorkerInfo(event, query_desc);
CopyClientContext(event);
CopyQueryText(event, query_desc->sourceText);
const PschStatementKey statement_key =
PschMakeStatementKey(query_desc->sourceText, query_desc->plannedstmt->stmt_location,
query_desc->plannedstmt->stmt_len);
CopyQueryText(event, statement_key);
}

extern "C" {

// Remove a pending normalized entry for one statement when execution exits
// without building a normal executor/utility event from it.
static void ForgetNormalizedStatement(const char* source_text, int stmt_location, int stmt_len) {
const PschStatementKey statement_key = PschMakeStatementKey(source_text, stmt_location, stmt_len);
PschForgetNormalizedQueryForStatement(&backend_state.normalized_queries, statement_key);
}

// post_parse_analyze_hook: build the normalized text for a statement while
// its JumbleState is still available.
//
// The constant locations recorded in the JumbleState are only handed to this
// hook, so the normalized text has to be generated here. It is stored in the
// backend's normalized-query registry under the statement's key, where
// CopyQueryText looks it up when ExecutorEnd or ProcessUtility builds an event.
//
// NOTE(review): this signature takes a JumbleState*, which the file otherwise
// treats as PostgreSQL 14+ only (queryjumble.h is included under a version
// guard) — confirm older server versions are not a build target.
static void PschPostParseAnalyze(ParseState* pstate, Query* query, JumbleState* jstate) {
  // Always give the previously installed hook its turn first.
  if (prev_post_parse_analyze != nullptr) {
    prev_post_parse_analyze(pstate, query, jstate);
  }

  // Nothing to stash when capture is off, when running in a parallel worker,
  // or when no constant locations were recorded for this statement.
  const bool capture_active = psch_enabled && !IsParallelWorker();
  const bool has_constants = (jstate != nullptr) && (jstate->clocations_count > 0);
  if (!capture_active || !has_constants) {
    return;
  }

  // Narrow a possibly multi-statement source string down to this statement.
  // CleanQuerytext adjusts the location and length it is given and returns
  // the start of the statement text.
  int cleaned_loc = query->stmt_location;
  int cleaned_len = query->stmt_len;
  const char* cleaned_text = CleanQuerytext(pstate->p_sourcetext, &cleaned_loc, &cleaned_len);

  // Allocate in TopMemoryContext so the normalized text outlives parsing and
  // is still there when the event is built. The key is made from the
  // untouched source pointer and the parser's original location/length, not
  // the cleaned values.
  MemoryContext saved_cxt = MemoryContextSwitchTo(TopMemoryContext);
  char* normalized = PschNormalizeQuery(cleaned_text, cleaned_loc, &cleaned_len, jstate);
  if (normalized != nullptr) {
    const PschStatementKey key =
        PschMakeStatementKey(pstate->p_sourcetext, query->stmt_location, query->stmt_len);
    PschRememberNormalizedQuery(&backend_state.normalized_queries, key, normalized, cleaned_len);
  }
  MemoryContextSwitchTo(saved_cxt);
}
Comment on lines +498 to +533
Copy link

Copilot AI Apr 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

normalized_query is a single per-backend buffer that can become stale or be overwritten before it is consumed. In particular: (1) PschPostParseAnalyze() can run for statements that later hit the early-return path in PschExecutorEnd() (e.g. queryId == 0), leaving normalized_query set; the next captured query will then incorrectly use this stale normalized text. (2) nested statements (SPI / functions) can overwrite normalized_query before the outer statement reaches ExecutorEnd, causing the outer event to lose normalization or pick up the inner query’s text. Consider storing normalized text in a stack keyed by nesting_level, or in a small per-backend map keyed by queryId/statement identity, and ensure the entry is removed on all completion paths (including the PschExecutorEnd() early return and utility paths).

Copilot uses AI. Check for mistakes.

static void PschExecutorStart(QueryDesc* query_desc, int eflags) {
if (IsParallelWorker()) {
if (prev_executor_start != nullptr) {
Expand Down Expand Up @@ -557,6 +644,8 @@ static void PschExecutorFinish(QueryDesc* query_desc) {

static void PschExecutorEnd(QueryDesc* query_desc) {
if (!psch_enabled || IsParallelWorker() || query_desc->plannedstmt->queryId == UINT64CONST(0)) {
ForgetNormalizedStatement(query_desc->sourceText, query_desc->plannedstmt->stmt_location,
query_desc->plannedstmt->stmt_len);
if (prev_executor_end != nullptr) {
prev_executor_end(query_desc);
} else {
Expand Down Expand Up @@ -591,9 +680,9 @@ static void PschExecutorEnd(QueryDesc* query_desc) {

// Build a PschEvent for utility statements (no QueryDesc available)
static void BuildEventForUtility(PschEvent* event, const char* queryString, TimestampTz start_ts,
uint64 duration_us, bool is_top_level, uint64 rows,
BufferUsage* bufusage, WalUsage* walusage, int64 cpu_user_us,
int64 cpu_sys_us) {
int stmt_location, int stmt_len, uint64 duration_us,
bool is_top_level, uint64 rows, BufferUsage* bufusage,
WalUsage* walusage, int64 cpu_user_us, int64 cpu_sys_us) {
InitBaseEvent(event, start_ts, is_top_level, PSCH_CMD_UTILITY);
event->duration_us = duration_us;
event->rows = rows;
Expand All @@ -604,7 +693,8 @@ static void BuildEventForUtility(PschEvent* event, const char* queryString, Time
CopyIoTiming(event, bufusage);
CopyWalUsage(event, walusage);
CopyClientContext(event);
CopyQueryText(event, queryString);
const PschStatementKey statement_key = PschMakeStatementKey(queryString, stmt_location, stmt_len);
CopyQueryText(event, statement_key);
}

// Helper macro to call ProcessUtility (previous hook or standard)
Expand Down Expand Up @@ -684,13 +774,18 @@ static void PschProcessUtility(PlannedStmt* pstmt, const char* queryString,
QueryCompletion* qc) {
#endif
if (!ShouldTrackUtility(pstmt->utilityStmt)) {
int stmt_location = pstmt->stmt_location;
int stmt_len = pstmt->stmt_len;
CALL_PROCESS_UTILITY();
ForgetNormalizedStatement(queryString, stmt_location, stmt_len);
return;
}

// Capture state before execution
bool is_top_level = (nesting_level == 0);
TimestampTz start_ts = GetCurrentTimestamp();
int stmt_location = pstmt->stmt_location;
int stmt_len = pstmt->stmt_len;
BufferUsage bufusage_start = pgBufferUsage;
WalUsage walusage_start = pgWalUsage;
struct rusage rusage_util_start;
Expand Down Expand Up @@ -724,9 +819,9 @@ static void PschProcessUtility(PlannedStmt* pstmt, const char* queryString,
}

PschEvent event;
BuildEventForUtility(&event, queryString, start_ts, INSTR_TIME_GET_MICROSEC(duration),
is_top_level, GetUtilityRowCount(qc), &bufusage_delta, &walusage_delta,
cpu_user_us, cpu_sys_us);
BuildEventForUtility(&event, queryString, start_ts, stmt_location, stmt_len,
INSTR_TIME_GET_MICROSEC(duration), is_top_level, GetUtilityRowCount(qc),
&bufusage_delta, &walusage_delta, cpu_user_us, cpu_sys_us);
PschEnqueueEvent(&event);
}

Expand Down Expand Up @@ -763,14 +858,14 @@ static bool ShouldCaptureLog(ErrorData* edata) {
return true;
}

// Build and enqueue an error event from ErrorData
// Build and enqueue an error event from ErrorData.
//
// NOTE: Query text captured here may contain sensitive data (passwords in CREATE USER,
// connection strings, etc.). Consider this PII concern when configuring ClickHouse
// retention policies. Future enhancement: GUC to disable query capture in error events.
// We intentionally leave event.query empty here. emit_log_hook only exposes
// debug_query_string and cursor position, not the exact statement identity used
// by ExecutorEnd/ProcessUtility, so reconstructing normalized SQL required
// fuzzy matching and extra backend-local state. Error events still carry the
// message, SQLSTATE, and client/session metadata.
static void CaptureLogEvent(ErrorData* edata) {
const char* query = (debug_query_string != nullptr) ? debug_query_string : "";

PschEvent event;
InitBaseEvent(&event, GetCurrentTimestamp(), (nesting_level == 0), PSCH_CMD_UNKNOWN);

Expand All @@ -782,10 +877,6 @@ static void CaptureLogEvent(ErrorData* edata) {
static_cast<uint16>(CopyTrimmed(event.err_message, PSCH_MAX_ERR_MSG_LEN, edata->message));
}

if (query[0] != '\0') {
event.query_len = static_cast<uint16>(CopyTrimmed(event.query, PSCH_MAX_QUERY_LEN, query));
}

CopyClientContext(&event);

// Enqueue with recursion guard
Expand Down Expand Up @@ -825,6 +916,9 @@ void PschInstallHooks(void) {
EnableQueryId();
#endif

prev_post_parse_analyze = post_parse_analyze_hook;
post_parse_analyze_hook = PschPostParseAnalyze;

prev_executor_start = ExecutorStart_hook;
ExecutorStart_hook = PschExecutorStart;

Expand Down
Loading
Loading