Skip to content

Commit db01a53

Browse files
committed
tools: Generate functions for all of contrib
1 parent f420281 commit db01a53

File tree

1 file changed

+125
-15
lines changed

1 file changed

+125
-15
lines changed

internal/tools/sqlc-pg-gen/main.go

Lines changed: 125 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,16 @@ ORDER BY 1;
5959
`
6060

6161
const catalogTmpl = `
62-
package postgresql
62+
// Code generated by sqlc-pg-gen. DO NOT EDIT.
63+
64+
package {{.Pkg}}
6365
6466
import (
6567
"github.com/kyleconroy/sqlc/internal/sql/ast"
6668
"github.com/kyleconroy/sqlc/internal/sql/catalog"
6769
)
6870
69-
func gen{{.Name}}() *catalog.Schema {
71+
func {{.Name}}() *catalog.Schema {
7072
s := &catalog.Schema{Name: "pg_catalog"}
7173
s.Funcs = []*catalog.Function{
7274
{{- range .Funcs}}
@@ -92,7 +94,29 @@ func gen{{.Name}}() *catalog.Schema {
9294
}
9395
`
9496

97+
const loaderFuncTmpl = `
98+
// Code generated by sqlc-pg-gen. DO NOT EDIT.
99+
100+
package postgresql
101+
102+
import (
103+
"github.com/kyleconroy/sqlc/internal/engine/postgresql/contrib"
104+
"github.com/kyleconroy/sqlc/internal/sql/catalog"
105+
)
106+
107+
func loadExtension(name string) *catalog.Schema {
108+
switch name {
109+
{{- range .}}
110+
case "{{.Name}}":
111+
return contrib.{{.Func}}()
112+
{{- end}}
113+
}
114+
return nil
115+
}
116+
`
117+
95118
type tmplCtx struct {
119+
Pkg string
96120
Name string
97121
Funcs []catalog.Function
98122
}
@@ -220,48 +244,134 @@ func run(ctx context.Context) error {
220244
return err
221245
}
222246
out := bytes.NewBuffer([]byte{})
223-
if err := tmpl.Execute(out, tmplCtx{Name: "PGCatalog", Funcs: funcs}); err != nil {
247+
if err := tmpl.Execute(out, tmplCtx{Pkg: "postgresql", Name: "genPGCatalog", Funcs: funcs}); err != nil {
224248
return err
225249
}
226250
code, err := format.Source(out.Bytes())
227251
if err != nil {
228252
return err
229253
}
230-
err = ioutil.WriteFile(filepath.Join("internal", "engine", "postgresql", "pg_catalog.gen.go"), code, 0644)
254+
err = ioutil.WriteFile(filepath.Join("internal", "engine", "postgresql", "pg_catalog.go"), code, 0644)
231255
if err != nil {
232256
return err
233257
}
234258

235-
// https://www.postgresql.org/docs/current/contrib.html
236-
extensions := map[string]string{
237-
"citext": "CIText",
238-
"pg_trgm": "PGTrigram",
239-
"pgcrypto": "PGCrypto",
240-
"uuid-ossp": "UUIDOSSP",
241-
}
259+
loaded := []extensionPair{}
260+
261+
for _, extension := range extensions {
262+
name := strings.Replace(extension, "-", "_", -1)
263+
264+
var funcName string
265+
for _, part := range strings.Split(name, "_") {
266+
funcName += strings.Title(part)
267+
}
242268

243-
for extension, name := range extensions {
244269
_, err := conn.Exec(ctx, fmt.Sprintf("CREATE EXTENSION IF NOT EXISTS \"%s\"", extension))
245270
if err != nil {
246-
return err
271+
log.Printf("error creating %s: %s", extension, err)
272+
continue
247273
}
274+
248275
rows, err := conn.Query(ctx, extensionFuncs, extension)
276+
if err != nil {
277+
return err
278+
}
249279
funcs, err := scanFuncs(rows)
250280
if err != nil {
251281
return err
252282
}
283+
if len(funcs) == 0 {
284+
log.Printf("no functions in %s, skipping", extension)
285+
continue
286+
}
253287
out := bytes.NewBuffer([]byte{})
254-
if err := tmpl.Execute(out, tmplCtx{Name: name, Funcs: funcs}); err != nil {
288+
if err := tmpl.Execute(out, tmplCtx{Pkg: "contrib", Name: funcName, Funcs: funcs}); err != nil {
255289
return err
256290
}
257291
code, err := format.Source(out.Bytes())
258292
if err != nil {
259293
return err
260294
}
261-
err = ioutil.WriteFile(filepath.Join("internal", "engine", "postgresql", "extension_"+strings.Replace(extension, "-", "_", -1)+".gen.go"), code, 0644)
295+
err = ioutil.WriteFile(filepath.Join("internal", "engine", "postgresql", "contrib", name+".go"), code, 0644)
262296
if err != nil {
263297
return err
264298
}
299+
300+
loaded = append(loaded, extensionPair{Name: extension, Func: funcName})
265301
}
302+
303+
{
304+
tmpl, err := template.New("").Parse(loaderFuncTmpl)
305+
if err != nil {
306+
return err
307+
}
308+
out := bytes.NewBuffer([]byte{})
309+
if err := tmpl.Execute(out, loaded); err != nil {
310+
return err
311+
}
312+
code, err := format.Source(out.Bytes())
313+
if err != nil {
314+
return err
315+
}
316+
err = ioutil.WriteFile(filepath.Join("internal", "engine", "postgresql", "extension.go"), code, 0644)
317+
if err != nil {
318+
return err
319+
}
320+
}
321+
266322
return nil
267323
}
324+
325+
type extensionPair struct {
326+
Name string
327+
Func string
328+
}
329+
330+
// https://www.postgresql.org/docs/current/contrib.html
331+
var extensions = []string{
332+
"adminpack",
333+
"amcheck",
334+
"auth_delay",
335+
"auto_explain",
336+
"bloom",
337+
"btree_gin",
338+
"btree_gist",
339+
"citext",
340+
"cube",
341+
"dblink",
342+
"dict_int",
343+
"dict_xsyn",
344+
"earthdistance",
345+
"file_fdw",
346+
"fuzzystrmatch",
347+
"hstore",
348+
"intagg",
349+
"intarray",
350+
"isn",
351+
"lo",
352+
"ltree",
353+
"pageinspect",
354+
"passwordcheck",
355+
"pg_buffercache",
356+
"pgcrypto",
357+
"pg_freespacemap",
358+
"pg_prewarm",
359+
"pgrowlocks",
360+
"pg_stat_statements",
361+
"pgstattuple",
362+
"pg_trgm",
363+
"pg_visibility",
364+
"postgres_fdw",
365+
"seg",
366+
"sepgsql",
367+
"spi",
368+
"sslinfo",
369+
"tablefunc",
370+
"tcn",
371+
"test_decoding",
372+
"tsm_system_rows",
373+
"tsm_system_time",
374+
"unaccent",
375+
"uuid-ossp",
376+
"xml2",
377+
}

0 commit comments

Comments
 (0)