Skip to content

Commit 59b34ba

Browse files
authored
improve type extraction (#399)
- [x] refactor getColumnInfo to separate responsibilities - [ ] make functions idempotent - [ ] decrease repetition - [ ] improve field analysis
1 parent 91bba82 commit 59b34ba

File tree

4 files changed

+231
-201
lines changed

4 files changed

+231
-201
lines changed

packages/typegen/src/index.ts

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,7 @@ import {psqlClient} from './pg'
1111
import {AnalyseQueryError, columnInfoGetter, isUntypeable, removeSimpleComments, simplifySql} from './query'
1212
import {parameterTypesGetter} from './query/parameters'
1313
import {AnalysedQuery, DescribedQuery, ExtractedQuery, Options, QueryField, QueryParameter} from './types'
14-
import {
15-
changedFiles,
16-
checkClean,
17-
containsIgnoreComment,
18-
globAsync,
19-
globList,
20-
maybeDo,
21-
truncateQuery,
22-
tryOrDefault,
23-
} from './util'
14+
import {changedFiles, checkClean, containsIgnoreComment, globAsync, globList, maybeDo, tryOrDefault} from './util'
2415
import * as write from './write'
2516

2617
import memoizee = require('memoizee')
@@ -81,15 +72,13 @@ export const generate = async (params: Partial<Options>) => {
8172

8273
const getFields = async (query: ExtractedQuery): Promise<QueryField[]> => {
8374
const rows = await gdesc(query.sql)
84-
const fields = await Promise.all(
75+
return await Promise.all(
8576
rows.map<Promise<QueryField>>(async row => ({
8677
name: row.Column,
8778
regtype: row.Type,
8879
typescript: await getTypeScriptType(row.Type, row.Column),
8980
})),
9081
)
91-
92-
return Promise.all(fields)
9382
}
9483

9584
const regTypeToTypeScript = async (regtype: string) => {

packages/typegen/src/query/column-info.ts

Lines changed: 101 additions & 186 deletions
Original file line numberDiff line numberDiff line change
@@ -2,79 +2,24 @@ import * as assert from 'assert'
22
import {createHash} from 'crypto'
33

44
import * as lodash from 'lodash'
5+
import {SelectFromStatement, Statement} from 'pgsql-ast-parser'
56
import {singular} from 'pluralize'
67
import {DatabasePool, sql} from 'slonik'
78

89
import {AnalysedQuery, AnalysedQueryField, DescribedQuery, QueryField} from '../types'
910
import {tryOrDefault} from '../util'
11+
import {ViewResult, getViewResult} from './getViewResult'
1012
import * as parse from './index'
11-
import {getHopefullyViewableAST, getSuggestedTags, isCTE, suggestedTags} from './parse'
12-
import {getViewFriendlySql} from '.'
13-
14-
const _sql = sql
15-
16-
const getTypesSql = _sql`
17-
drop type if exists pg_temp.types_type cascade;
18-
19-
create type pg_temp.types_type as (
20-
schema_name text,
21-
view_name text,
22-
table_column_name text,
23-
query_column_name text,
24-
comment text,
25-
underlying_table_name text,
26-
is_underlying_nullable text,
27-
formatted_query text
28-
);
29-
30-
-- taken from https://dataedo.com/kb/query/postgresql/list-views-columns
31-
-- and https://www.cybertec-postgresql.com/en/abusing-postgresql-as-an-sql-beautifier
32-
-- nullable: https://stackoverflow.com/a/63980243
33-
34-
create or replace function pg_temp.gettypes(sql_query text)
35-
returns setof pg_temp.types_type as
36-
$$
37-
declare
38-
v_tmp_name text;
39-
returnrec types_type;
40-
begin
41-
v_tmp_name := 'temp_view_' || md5(sql_query);
42-
execute 'drop view if exists ' || v_tmp_name;
43-
execute 'create temporary view ' || v_tmp_name || ' as ' || sql_query;
44-
45-
FOR returnrec in
46-
select
47-
vcu.table_schema as schema_name,
48-
vcu.view_name as view_name,
49-
c.column_name,
50-
vcu.column_name,
51-
col_description(
52-
to_regclass(quote_ident(c.table_schema) || '.' || quote_ident(c.table_name)),
53-
c.ordinal_position
54-
),
55-
vcu.table_name as underlying_table_name,
56-
c.is_nullable as is_underlying_nullable,
57-
pg_get_viewdef(v_tmp_name) as formatted_query
58-
from
59-
information_schema.columns c
60-
join
61-
information_schema.view_column_usage vcu
62-
on c.table_name = vcu.table_name
63-
and c.column_name = vcu.column_name
64-
and c.table_schema = vcu.table_schema
65-
where
66-
c.table_name = v_tmp_name
67-
or vcu.view_name = v_tmp_name -- todo: this includes too much! columns which are part of table queried but not selected
68-
loop
69-
return next returnrec;
70-
end loop;
71-
72-
execute 'drop view if exists ' || v_tmp_name;
73-
74-
end;
75-
$$
76-
LANGUAGE 'plpgsql';
77-
`
13+
import {
14+
AliasMapping,
15+
aliasMappings,
16+
astToViewFriendlySql,
17+
getHopefullyViewableAST,
18+
getSuggestedTags,
19+
isCTE,
20+
suggestedTags,
21+
templateToHopefullyViewableAST,
22+
} from './parse'
7823

7924
export class AnalyseQueryError extends Error {
8025
public readonly [Symbol.toStringTag] = 'AnalyseQueryError'
@@ -91,112 +36,93 @@ export class AnalyseQueryError extends Error {
9136
// todo: get table description from obj_description(oid) (like column)
9237

9338
export const columnInfoGetter = (pool: DatabasePool) => {
94-
// const createViewAnalyser = lodash.once(() => pool.query(getTypesSql))
95-
9639
const addColumnInfo = async (query: DescribedQuery): Promise<AnalysedQuery> => {
97-
const cte = isCTE(query.template)
98-
const viewFriendlySql = getViewFriendlySql(query.template)
99-
const suggestedTags = tagsFromDescribedQuery(query)
40+
const viewFriendlyAst = templateToHopefullyViewableAST(query.template)
10041

101-
// await createViewAnalyser()
102-
103-
const viewResultQuery = _sql<GetTypes>`
104-
select
105-
schema_name,
106-
table_column_name,
107-
underlying_table_name,
108-
is_underlying_nullable,
109-
comment,
110-
formatted_query
111-
from
112-
pg_temp.gettypes(${viewFriendlySql})
113-
`
114-
115-
const ast = getHopefullyViewableAST(viewFriendlySql)
116-
if (ast.type !== 'select') {
117-
return {
118-
...query,
119-
suggestedTags,
120-
fields: query.fields.map(defaultAnalysedQueryField),
121-
}
42+
if (viewFriendlyAst.type !== 'select') {
43+
return getDefaultAnalysedQuery(query)
12244
}
12345

124-
const viewResult = cte
46+
const viewFriendlySql = astToViewFriendlySql(viewFriendlyAst)
47+
const viewResult = isCTE(query.template)
12548
? [] // not smart enough to figure out what types are referenced via a CTE
126-
: await pool.transaction(async t => {
127-
await t.query(getTypesSql)
128-
const results = await t.any(viewResultQuery)
129-
return lodash.uniqBy(results, JSON.stringify)
130-
})
131-
132-
const formattedSqlStatements = [...new Set(viewResult.map(r => r.formatted_query))]
133-
134-
assert.ok(formattedSqlStatements.length <= 1, `Expected exactly 1 formatted sql, got ${formattedSqlStatements}`)
135-
136-
const parseableSql = formattedSqlStatements[0] || viewFriendlySql
49+
: await getViewResult(pool, viewFriendlySql)
13750

138-
const parsed = parse.getAliasMappings(parseableSql)
51+
const getFieldInfo = buildGetFieldInfo(viewResult, viewFriendlyAst)
13952

14053
return {
14154
...query,
142-
suggestedTags,
143-
fields: query.fields.map(f => {
144-
const relatedResults = parsed.flatMap(c =>
145-
viewResult
146-
.map(v => ({
147-
...v,
148-
hasNullableJoin: c.hasNullableJoin,
149-
}))
150-
.filter(v => {
151-
assert.ok(v.underlying_table_name, `Table name for ${JSON.stringify(c)} not found`)
152-
return (
153-
c.queryColumn === f.name &&
154-
c.tablesColumnCouldBeFrom.includes(v.underlying_table_name) &&
155-
c.aliasFor === v.table_column_name
156-
)
157-
}),
158-
)
159-
160-
const res = relatedResults.length === 1 ? relatedResults[0] : undefined
161-
162-
// determine nullability
163-
let nullability: AnalysedQueryField['nullability'] = 'unknown'
164-
if (res?.is_underlying_nullable === 'YES') {
165-
nullability = 'nullable'
166-
} else if (res?.hasNullableJoin) {
167-
nullability = 'nullable_via_join'
168-
} else if (res?.is_underlying_nullable === 'NO' || isNonNullableField(parseableSql, f)) {
169-
nullability = 'not_null'
170-
} else {
171-
nullability = 'unknown'
172-
}
173-
174-
return {
175-
...f,
176-
nullability,
177-
column: res && {
178-
schema: res.schema_name!,
179-
table: res.underlying_table_name!,
180-
name: res.table_column_name!,
181-
},
182-
comment: res?.comment || undefined,
183-
}
184-
}),
55+
suggestedTags: generateTags(query),
56+
fields: query.fields.map(getFieldInfo),
18557
}
18658
}
18759

18860
return async (query: DescribedQuery): Promise<AnalysedQuery> =>
18961
addColumnInfo(query).catch(e => {
190-
const recover = {
191-
...query,
192-
suggestedTags: tagsFromDescribedQuery(query),
193-
fields: query.fields.map(defaultAnalysedQueryField),
194-
}
62+
const recover = getDefaultAnalysedQuery(query)
19563
throw new AnalyseQueryError(e, query, recover)
19664
})
19765
}
19866

199-
const tagOptions = (query: DescribedQuery) => {
67+
const buildGetFieldInfo = (viewResult: ViewResult[], ast: SelectFromStatement) => {
68+
const viewableAst =
69+
viewResult[0]?.formatted_query === undefined ? ast : getHopefullyViewableAST(viewResult[0].formatted_query!) // TODO: explore why this fallback might be needed - can't we always use the original ast?
70+
71+
const mappings = aliasMappings(viewableAst)
72+
73+
return function getFieldInfo(field: QueryField) {
74+
const relatedResults = mappings.flatMap(c =>
75+
viewResult
76+
.map(v => ({
77+
...v,
78+
hasNullableJoin: c.hasNullableJoin,
79+
}))
80+
.filter(v => {
81+
assert.ok(v.underlying_table_name, `Table name for ${JSON.stringify(c)} not found`)
82+
return (
83+
c.queryColumn === field.name &&
84+
c.tablesColumnCouldBeFrom.includes(v.underlying_table_name) &&
85+
c.aliasFor === v.table_column_name
86+
)
87+
}),
88+
)
89+
90+
const res = relatedResults.length === 1 ? relatedResults[0] : undefined
91+
92+
// determine nullability
93+
let nullability: AnalysedQueryField['nullability'] = 'unknown'
94+
if (res?.is_underlying_nullable === 'YES') {
95+
nullability = 'nullable'
96+
} else if (res?.hasNullableJoin) {
97+
nullability = 'nullable_via_join'
98+
// TODO: we're converting from sql to ast back and forth for `isNonNullableField`. this is probably unneded
99+
} else if (res?.is_underlying_nullable === 'NO' || isNonNullableField(astToViewFriendlySql(viewableAst), field)) {
100+
nullability = 'not_null'
101+
} else {
102+
nullability = 'unknown'
103+
}
104+
105+
return {
106+
...field,
107+
nullability,
108+
column: res && {
109+
schema: res.schema_name!,
110+
table: res.underlying_table_name!,
111+
name: res.table_column_name!,
112+
},
113+
comment: res?.comment || undefined,
114+
}
115+
}
116+
}
117+
118+
/**
119+
* Generate short hash
120+
*/
121+
const shortHexHash = (str: string) => createHash('md5').update(str).digest('hex').slice(0, 6)
122+
/**
123+
* Uses various strategies to come up with options for tags
124+
*/
125+
const generateTagOptions = (query: DescribedQuery) => {
200126
const sqlTags = tryOrDefault(() => getSuggestedTags(query.template), [])
201127

202128
const codeContextTags = query.context
@@ -224,10 +150,15 @@ const tagOptions = (query: DescribedQuery) => {
224150
return {sqlTags, codeContextTags, fieldTags, anonymousTags}
225151
}
226152

227-
const tagsFromDescribedQuery = (query: DescribedQuery) => {
228-
const options = tagOptions(query)
153+
/**
154+
* Generates a list of tag options based on a query
155+
* @param query DescribedQuery
156+
* @returns List of tag options sorted by quality
157+
*/
158+
const generateTags = (query: DescribedQuery) => {
159+
const options = generateTagOptions(query)
229160

230-
const tags = options.sqlTags.slice()
161+
const tags = [...options.sqlTags]
231162
tags.splice(tags[0]?.slice(1).includes('_') ? 0 : 1, 0, ...options.codeContextTags)
232163
tags.push(...options.fieldTags)
233164
tags.push(...options.codeContextTags)
@@ -236,13 +167,18 @@ const tagsFromDescribedQuery = (query: DescribedQuery) => {
236167
return tags
237168
}
238169

239-
const shortHexHash = (str: string) => createHash('md5').update(str).digest('hex').slice(0, 6)
240-
241-
export const defaultAnalysedQueryField = (f: QueryField): AnalysedQueryField => ({
242-
...f,
243-
nullability: 'unknown',
244-
comment: undefined,
245-
column: undefined,
170+
/**
171+
* Create a fallback, in case we fail to analyse the query
172+
*/
173+
const getDefaultAnalysedQuery = (query: DescribedQuery): AnalysedQuery => ({
174+
...query,
175+
suggestedTags: generateTags(query),
176+
fields: query.fields.map(f => ({
177+
...f,
178+
nullability: 'unknown',
179+
comment: undefined,
180+
column: undefined,
181+
})),
246182
})
247183

248184
const nonNullableExpressionTypes = new Set([
@@ -290,24 +226,3 @@ export const isNonNullableField = (sql: string, field: QueryField) => {
290226
// if there's exactly one column with the same name as the field and matching the conditions above, we can be confident it's not nullable.
291227
return nonNullableColumns.length === 1
292228
}
293-
294-
// this query is for a type in a temp schema so this tool doesn't work with it
295-
export interface GetTypes {
296-
/** postgres type: `text` */
297-
schema_name: string | null
298-
299-
/** postgres type: `text` */
300-
table_column_name: string | null
301-
302-
/** postgres type: `text` */
303-
underlying_table_name: string | null
304-
305-
/** postgres type: `text` */
306-
is_underlying_nullable: string | null
307-
308-
/** postgres type: `text` */
309-
comment: string | null
310-
311-
/** postgres type: `text` */
312-
formatted_query: string | null
313-
}

0 commit comments

Comments
 (0)