Skip to content

Commit 17808a8

Browse files
authored
Merge pull request #19 from bitechdev/feature-keystore
Feature keystore
2 parents aa095d6 + 134ff85 commit 17808a8

File tree

18 files changed

+1485
-168
lines changed

18 files changed

+1485
-168
lines changed

pkg/common/adapters/database/bun.go

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"reflect"
88
"strings"
9+
"sync"
910
"time"
1011

1112
"github.com/uptrace/bun"
@@ -95,6 +96,8 @@ func debugScanIntoStruct(rows interface{}, dest interface{}) error {
9596
// This demonstrates how the abstraction works with different ORMs
9697
type BunAdapter struct {
9798
db *bun.DB
99+
dbMu sync.RWMutex
100+
dbFactory func() (*bun.DB, error)
98101
driverName string
99102
}
100103

@@ -106,10 +109,36 @@ func NewBunAdapter(db *bun.DB) *BunAdapter {
106109
return adapter
107110
}
108111

112+
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
113+
func (b *BunAdapter) WithDBFactory(factory func() (*bun.DB, error)) *BunAdapter {
114+
b.dbFactory = factory
115+
return b
116+
}
117+
118+
func (b *BunAdapter) getDB() *bun.DB {
119+
b.dbMu.RLock()
120+
defer b.dbMu.RUnlock()
121+
return b.db
122+
}
123+
124+
func (b *BunAdapter) reconnectDB() error {
125+
if b.dbFactory == nil {
126+
return fmt.Errorf("no db factory configured for reconnect")
127+
}
128+
newDB, err := b.dbFactory()
129+
if err != nil {
130+
return err
131+
}
132+
b.dbMu.Lock()
133+
b.db = newDB
134+
b.dbMu.Unlock()
135+
return nil
136+
}
137+
109138
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
110139
// This is useful for debugging preload queries that may be failing
111140
func (b *BunAdapter) EnableQueryDebug() {
112-
b.db.AddQueryHook(&QueryDebugHook{})
141+
b.getDB().AddQueryHook(&QueryDebugHook{})
113142
logger.Info("Bun query debug mode enabled - all SQL queries will be logged")
114143
}
115144

@@ -130,22 +159,22 @@ func (b *BunAdapter) DisableQueryDebug() {
130159

131160
func (b *BunAdapter) NewSelect() common.SelectQuery {
132161
return &BunSelectQuery{
133-
query: b.db.NewSelect(),
162+
query: b.getDB().NewSelect(),
134163
db: b.db,
135164
driverName: b.driverName,
136165
}
137166
}
138167

139168
func (b *BunAdapter) NewInsert() common.InsertQuery {
140-
return &BunInsertQuery{query: b.db.NewInsert()}
169+
return &BunInsertQuery{query: b.getDB().NewInsert()}
141170
}
142171

143172
func (b *BunAdapter) NewUpdate() common.UpdateQuery {
144-
return &BunUpdateQuery{query: b.db.NewUpdate()}
173+
return &BunUpdateQuery{query: b.getDB().NewUpdate()}
145174
}
146175

147176
func (b *BunAdapter) NewDelete() common.DeleteQuery {
148-
return &BunDeleteQuery{query: b.db.NewDelete()}
177+
return &BunDeleteQuery{query: b.getDB().NewDelete()}
149178
}
150179

151180
func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
@@ -154,7 +183,14 @@ func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}
154183
err = logger.HandlePanic("BunAdapter.Exec", r)
155184
}
156185
}()
157-
result, err := b.db.ExecContext(ctx, query, args...)
186+
var result sql.Result
187+
run := func() error { var e error; result, e = b.getDB().ExecContext(ctx, query, args...); return e }
188+
err = run()
189+
if isDBClosed(err) {
190+
if reconnErr := b.reconnectDB(); reconnErr == nil {
191+
err = run()
192+
}
193+
}
158194
return &BunResult{result: result}, err
159195
}
160196

@@ -164,11 +200,17 @@ func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string,
164200
err = logger.HandlePanic("BunAdapter.Query", r)
165201
}
166202
}()
167-
return b.db.NewRaw(query, args...).Scan(ctx, dest)
203+
err = b.getDB().NewRaw(query, args...).Scan(ctx, dest)
204+
if isDBClosed(err) {
205+
if reconnErr := b.reconnectDB(); reconnErr == nil {
206+
err = b.getDB().NewRaw(query, args...).Scan(ctx, dest)
207+
}
208+
}
209+
return err
168210
}
169211

170212
func (b *BunAdapter) BeginTx(ctx context.Context) (common.Database, error) {
171-
tx, err := b.db.BeginTx(ctx, &sql.TxOptions{})
213+
tx, err := b.getDB().BeginTx(ctx, &sql.TxOptions{})
172214
if err != nil {
173215
return nil, err
174216
}
@@ -194,15 +236,15 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa
194236
err = logger.HandlePanic("BunAdapter.RunInTransaction", r)
195237
}
196238
}()
197-
return b.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
239+
return b.getDB().RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
198240
// Create adapter with transaction
199241
adapter := &BunTxAdapter{tx: tx, driverName: b.driverName}
200242
return fn(adapter)
201243
})
202244
}
203245

204246
func (b *BunAdapter) GetUnderlyingDB() interface{} {
205-
return b.db
247+
return b.getDB()
206248
}
207249

208250
func (b *BunAdapter) DriverName() string {

pkg/common/adapters/database/pgsql.go

Lines changed: 55 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"reflect"
88
"strings"
9+
"sync"
910

1011
"github.com/bitechdev/ResolveSpec/pkg/common"
1112
"github.com/bitechdev/ResolveSpec/pkg/logger"
@@ -17,6 +18,8 @@ import (
1718
// This provides a lightweight PostgreSQL adapter without ORM overhead
1819
type PgSQLAdapter struct {
1920
db *sql.DB
21+
dbMu sync.RWMutex
22+
dbFactory func() (*sql.DB, error)
2023
driverName string
2124
}
2225

@@ -31,14 +34,44 @@ func NewPgSQLAdapter(db *sql.DB, driverName ...string) *PgSQLAdapter {
3134
return &PgSQLAdapter{db: db, driverName: name}
3235
}
3336

37+
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
38+
func (p *PgSQLAdapter) WithDBFactory(factory func() (*sql.DB, error)) *PgSQLAdapter {
39+
p.dbFactory = factory
40+
return p
41+
}
42+
43+
func (p *PgSQLAdapter) getDB() *sql.DB {
44+
p.dbMu.RLock()
45+
defer p.dbMu.RUnlock()
46+
return p.db
47+
}
48+
49+
func (p *PgSQLAdapter) reconnectDB() error {
50+
if p.dbFactory == nil {
51+
return fmt.Errorf("no db factory configured for reconnect")
52+
}
53+
newDB, err := p.dbFactory()
54+
if err != nil {
55+
return err
56+
}
57+
p.dbMu.Lock()
58+
p.db = newDB
59+
p.dbMu.Unlock()
60+
return nil
61+
}
62+
63+
func isDBClosed(err error) bool {
64+
return err != nil && strings.Contains(err.Error(), "sql: database is closed")
65+
}
66+
3467
// EnableQueryDebug enables query debugging for development
3568
func (p *PgSQLAdapter) EnableQueryDebug() {
3669
logger.Info("PgSQL query debug mode - logging enabled via logger")
3770
}
3871

3972
func (p *PgSQLAdapter) NewSelect() common.SelectQuery {
4073
return &PgSQLSelectQuery{
41-
db: p.db,
74+
db: p.getDB(),
4275
driverName: p.driverName,
4376
columns: []string{"*"},
4477
args: make([]interface{}, 0),
@@ -47,15 +80,15 @@ func (p *PgSQLAdapter) NewSelect() common.SelectQuery {
4780

4881
func (p *PgSQLAdapter) NewInsert() common.InsertQuery {
4982
return &PgSQLInsertQuery{
50-
db: p.db,
83+
db: p.getDB(),
5184
driverName: p.driverName,
5285
values: make(map[string]interface{}),
5386
}
5487
}
5588

5689
func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery {
5790
return &PgSQLUpdateQuery{
58-
db: p.db,
91+
db: p.getDB(),
5992
driverName: p.driverName,
6093
sets: make(map[string]interface{}),
6194
args: make([]interface{}, 0),
@@ -65,7 +98,7 @@ func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery {
6598

6699
func (p *PgSQLAdapter) NewDelete() common.DeleteQuery {
67100
return &PgSQLDeleteQuery{
68-
db: p.db,
101+
db: p.getDB(),
69102
driverName: p.driverName,
70103
args: make([]interface{}, 0),
71104
whereClauses: make([]string, 0),
@@ -79,7 +112,14 @@ func (p *PgSQLAdapter) Exec(ctx context.Context, query string, args ...interface
79112
}
80113
}()
81114
logger.Debug("PgSQL Exec: %s [args: %v]", query, args)
82-
result, err := p.db.ExecContext(ctx, query, args...)
115+
var result sql.Result
116+
run := func() error { var e error; result, e = p.getDB().ExecContext(ctx, query, args...); return e }
117+
err = run()
118+
if isDBClosed(err) {
119+
if reconnErr := p.reconnectDB(); reconnErr == nil {
120+
err = run()
121+
}
122+
}
83123
if err != nil {
84124
logger.Error("PgSQL Exec failed: %v", err)
85125
return nil, err
@@ -94,7 +134,14 @@ func (p *PgSQLAdapter) Query(ctx context.Context, dest interface{}, query string
94134
}
95135
}()
96136
logger.Debug("PgSQL Query: %s [args: %v]", query, args)
97-
rows, err := p.db.QueryContext(ctx, query, args...)
137+
var rows *sql.Rows
138+
run := func() error { var e error; rows, e = p.getDB().QueryContext(ctx, query, args...); return e }
139+
err = run()
140+
if isDBClosed(err) {
141+
if reconnErr := p.reconnectDB(); reconnErr == nil {
142+
err = run()
143+
}
144+
}
98145
if err != nil {
99146
logger.Error("PgSQL Query failed: %v", err)
100147
return err
@@ -105,7 +152,7 @@ func (p *PgSQLAdapter) Query(ctx context.Context, dest interface{}, query string
105152
}
106153

107154
func (p *PgSQLAdapter) BeginTx(ctx context.Context) (common.Database, error) {
108-
tx, err := p.db.BeginTx(ctx, nil)
155+
tx, err := p.getDB().BeginTx(ctx, nil)
109156
if err != nil {
110157
return nil, err
111158
}
@@ -127,7 +174,7 @@ func (p *PgSQLAdapter) RunInTransaction(ctx context.Context, fn func(common.Data
127174
}
128175
}()
129176

130-
tx, err := p.db.BeginTx(ctx, nil)
177+
tx, err := p.getDB().BeginTx(ctx, nil)
131178
if err != nil {
132179
return err
133180
}

0 commit comments

Comments
 (0)