github.com/octohelm/storage@v0.0.0-20240516030302-1ac2cc1ea347/pkg/dal/session.go (about) 1 package dal 2 3 import ( 4 "context" 5 "fmt" 6 "sync" 7 8 "github.com/pkg/errors" 9 10 "github.com/octohelm/storage/internal/sql/adapter" 11 "github.com/octohelm/storage/pkg/sqlbuilder" 12 contextx "github.com/octohelm/x/context" 13 ) 14 15 func Tx(ctx context.Context, m sqlbuilder.Model, action func(ctx context.Context) error) error { 16 return SessionFor(ctx, m).Tx(ctx, action) 17 } 18 19 var catalogs = sync.Map{} 20 21 func registerSessionCatalog(name string, tables *sqlbuilder.Tables) { 22 tables.Range(func(tab sqlbuilder.Table, idx int) bool { 23 catalogs.Store(tab.TableName(), name) 24 return true 25 }) 26 } 27 28 type TableWrapper interface { 29 Unwrap() sqlbuilder.Model 30 } 31 32 func SessionFor(ctx context.Context, nameOrTable any) Session { 33 if u, ok := nameOrTable.(TableWrapper); ok { 34 return SessionFor(ctx, u.Unwrap()) 35 } 36 37 switch x := nameOrTable.(type) { 38 case string: 39 return FromContext(ctx, x) 40 case sqlbuilder.Model: 41 if t, ok := catalogs.Load(x.TableName()); ok { 42 return FromContext(ctx, t.(string)) 43 } 44 } 45 46 panic(errors.Errorf("invalid section target %#v", nameOrTable)) 47 } 48 49 type contextSession struct { 50 name string 51 } 52 53 func InjectContext(ctx context.Context, repo Session) context.Context { 54 return contextx.WithValue(ctx, contextSession{name: repo.Name()}, repo) 55 } 56 57 func FromContext(ctx context.Context, name string) Session { 58 r, ok := ctx.Value(contextSession{name: name}).(Session) 59 if ok { 60 return r 61 } 62 panic(fmt.Sprintf("missing session of %s", name)) 63 } 64 65 type Session interface { 66 // Name of database 67 Name() string 68 T(m any) sqlbuilder.Table 69 Tx(ctx context.Context, fn func(ctx context.Context) error) error 70 71 Adapter() adapter.Adapter 72 } 73 74 func New(a adapter.Adapter, name string) Session { 75 return &session{ 76 name: name, 77 adapter: a, 78 } 79 } 80 81 type session struct { 82 name string 83 adapter adapter.Adapter 84 } 85 86 func (s *session) Adapter() adapter.Adapter { 87 return s.adapter 88 } 89 90 func (s *session) Name() string { 91 return s.name 92 } 93 94 func (s *session) Tx(ctx context.Context, fn func(ctx context.Context) error) error { 95 return s.adapter.Transaction(ctx, fn) 96 } 97 98 func (s *session) T(m any) sqlbuilder.Table { 99 if td, ok := m.(sqlbuilder.TableDefinition); ok { 100 return td.T() 101 } 102 if td, ok := m.(sqlbuilder.Table); ok { 103 return td 104 } 105 return sqlbuilder.TableFromModel(m) 106 }