github.com/octohelm/storage@v0.0.0-20240516030302-1ac2cc1ea347/pkg/dal/session.go (about)

     1  package dal
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sync"
     7  
     8  	"github.com/pkg/errors"
     9  
    10  	"github.com/octohelm/storage/internal/sql/adapter"
    11  	"github.com/octohelm/storage/pkg/sqlbuilder"
    12  	contextx "github.com/octohelm/x/context"
    13  )
    14  
    15  func Tx(ctx context.Context, m sqlbuilder.Model, action func(ctx context.Context) error) error {
    16  	return SessionFor(ctx, m).Tx(ctx, action)
    17  }
    18  
    19  var catalogs = sync.Map{}
    20  
    21  func registerSessionCatalog(name string, tables *sqlbuilder.Tables) {
    22  	tables.Range(func(tab sqlbuilder.Table, idx int) bool {
    23  		catalogs.Store(tab.TableName(), name)
    24  		return true
    25  	})
    26  }
    27  
    28  type TableWrapper interface {
    29  	Unwrap() sqlbuilder.Model
    30  }
    31  
    32  func SessionFor(ctx context.Context, nameOrTable any) Session {
    33  	if u, ok := nameOrTable.(TableWrapper); ok {
    34  		return SessionFor(ctx, u.Unwrap())
    35  	}
    36  
    37  	switch x := nameOrTable.(type) {
    38  	case string:
    39  		return FromContext(ctx, x)
    40  	case sqlbuilder.Model:
    41  		if t, ok := catalogs.Load(x.TableName()); ok {
    42  			return FromContext(ctx, t.(string))
    43  		}
    44  	}
    45  
    46  	panic(errors.Errorf("invalid section target %#v", nameOrTable))
    47  }
    48  
    49  type contextSession struct {
    50  	name string
    51  }
    52  
    53  func InjectContext(ctx context.Context, repo Session) context.Context {
    54  	return contextx.WithValue(ctx, contextSession{name: repo.Name()}, repo)
    55  }
    56  
    57  func FromContext(ctx context.Context, name string) Session {
    58  	r, ok := ctx.Value(contextSession{name: name}).(Session)
    59  	if ok {
    60  		return r
    61  	}
    62  	panic(fmt.Sprintf("missing session of %s", name))
    63  }
    64  
    65  type Session interface {
    66  	// Name of database
    67  	Name() string
    68  	T(m any) sqlbuilder.Table
    69  	Tx(ctx context.Context, fn func(ctx context.Context) error) error
    70  
    71  	Adapter() adapter.Adapter
    72  }
    73  
    74  func New(a adapter.Adapter, name string) Session {
    75  	return &session{
    76  		name:    name,
    77  		adapter: a,
    78  	}
    79  }
    80  
    81  type session struct {
    82  	name    string
    83  	adapter adapter.Adapter
    84  }
    85  
    86  func (s *session) Adapter() adapter.Adapter {
    87  	return s.adapter
    88  }
    89  
    90  func (s *session) Name() string {
    91  	return s.name
    92  }
    93  
    94  func (s *session) Tx(ctx context.Context, fn func(ctx context.Context) error) error {
    95  	return s.adapter.Transaction(ctx, fn)
    96  }
    97  
    98  func (s *session) T(m any) sqlbuilder.Table {
    99  	if td, ok := m.(sqlbuilder.TableDefinition); ok {
   100  		return td.T()
   101  	}
   102  	if td, ok := m.(sqlbuilder.Table); ok {
   103  		return td
   104  	}
   105  	return sqlbuilder.TableFromModel(m)
   106  }