gitlab.com/picnic-app/backend/role-api@v0.0.0-20230614140944-06a76ff3696d/internal/repo/spanner/helpers/helpers.go (about)

     1  package helpers
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"reflect"
     7  
     8  	"cloud.google.com/go/spanner"
     9  	"github.com/Masterminds/squirrel"
    10  	"google.golang.org/api/iterator"
    11  	"google.golang.org/grpc/codes"
    12  	"google.golang.org/grpc/status"
    13  
    14  	"gitlab.com/picnic-app/backend/role-api/internal/repo/spanner/tables"
    15  )
    16  
    17  type Reader interface {
    18  	Read(ctx context.Context, table string, k spanner.KeySet, cols []string) *spanner.RowIterator
    19  }
    20  
    21  type RowReader interface {
    22  	ReadRow(ctx context.Context, table string, k spanner.Key, cols []string) (*spanner.Row, error)
    23  }
    24  
    25  type Queryer interface {
    26  	Query(context.Context, spanner.Statement) *spanner.RowIterator
    27  }
    28  
    29  type Updater interface {
    30  	Update(context.Context, spanner.Statement) (int64, error)
    31  }
    32  
    33  type BufferWriter interface {
    34  	BufferWrite([]*spanner.Mutation) error
    35  }
    36  
    37  type DBWriter interface {
    38  	Updater
    39  	BufferWriter
    40  }
    41  
    42  type IndexReader interface {
    43  	ReadUsingIndex(ctx context.Context, table, index string, keys spanner.KeySet, columns []string) *spanner.RowIterator
    44  }
    45  
    46  func DeleteByBuilder(ctx context.Context, db Updater, b squirrel.DeleteBuilder) (int64, error) {
    47  	q, args, err := b.PlaceholderFormat(squirrel.AtP).ToSql()
    48  	if err != nil {
    49  		return 0, err
    50  	}
    51  
    52  	stmt := spanner.Statement{SQL: q, Params: ArgsToParams(args)}
    53  	return db.Update(ctx, stmt)
    54  }
    55  
    56  func UpdateByBuilder(ctx context.Context, db Updater, b squirrel.UpdateBuilder) (int64, error) {
    57  	q, args, err := b.PlaceholderFormat(squirrel.AtP).ToSql()
    58  	if err != nil {
    59  		return 0, err
    60  	}
    61  
    62  	stmt := spanner.Statement{SQL: q, Params: ArgsToParams(args)}
    63  	return db.Update(ctx, stmt)
    64  }
    65  
    66  func GetResultsByBuilder[M any](ctx context.Context, db Queryer, b squirrel.SelectBuilder) ([]M, error) {
    67  	q, args, err := b.PlaceholderFormat(squirrel.AtP).ToSql()
    68  	if err != nil {
    69  		return nil, err
    70  	}
    71  
    72  	stmt := spanner.Statement{SQL: q, Params: ArgsToParams(args)}
    73  	return GetResults[M](db.Query(ctx, stmt))
    74  }
    75  
    76  func GetResultByBuilder[M any](ctx context.Context, db Queryer, b squirrel.SelectBuilder) (m M, err error) {
    77  	q, args, err := b.PlaceholderFormat(squirrel.AtP).ToSql()
    78  	if err != nil {
    79  		return m, err
    80  	}
    81  
    82  	stmt := spanner.Statement{SQL: q, Params: ArgsToParams(args)}
    83  	return GetResult[M](db.Query(ctx, stmt))
    84  }
    85  
    86  func GetByKey[M any, T tables.Table, K Key](ctx context.Context, db RowReader, k K) (m M, err error) {
    87  	var t T
    88  	row, err := db.ReadRow(ctx, t.TableName(), key(k), t.Columns())
    89  	if err != nil {
    90  		return m, err
    91  	}
    92  	return scanFunc[M]()(row)
    93  }
    94  
    95  func GetByKeys[M any, T tables.Table, K Key](ctx context.Context, db Reader, ids ...K) ([]M, error) {
    96  	var t T
    97  	return GetResults[M](db.Read(ctx, t.TableName(), keySet(ids...), t.Columns()))
    98  }
    99  
   100  func Delete[T tables.Table, K Key](db BufferWriter, keys ...K) error {
   101  	if len(keys) == 0 {
   102  		return nil
   103  	}
   104  
   105  	var t T
   106  	m := spanner.Delete(t.TableName(), keySet(keys...))
   107  	return db.BufferWrite([]*spanner.Mutation{m})
   108  }
   109  
   110  // GetResult returns the result from the iterator. Calls Stop after the iterator
   111  // is finished.
   112  func GetResult[M any](iter *spanner.RowIterator) (out M, err error) {
   113  	defer iter.Stop()
   114  
   115  	row, err := iter.Next()
   116  	if err != nil {
   117  		if err == iterator.Done {
   118  			return out, status.Error(codes.NotFound, reflect.TypeOf(out).Name())
   119  		}
   120  
   121  		return out, err
   122  	}
   123  
   124  	return scanFunc[M]()(row)
   125  }
   126  
   127  // GetResults returns the results from the iterator. Calls Stop after the
   128  // iterator is finished.
   129  func GetResults[M any](iter *spanner.RowIterator) (out []M, err error) {
   130  	f := scanFunc[M]()
   131  	err = iter.Do(func(row *spanner.Row) error {
   132  		m, err := f(row)
   133  		if err != nil {
   134  			return err
   135  		}
   136  		out = append(out, m)
   137  		return nil
   138  	})
   139  	return out, err
   140  }
   141  
   142  // GetPtrResults returns the results from the iterator. Calls Stop after the
   143  // iterator is finished.
   144  func GetPtrResults[M any](iter *spanner.RowIterator) (out []*M, err error) {
   145  	f := scanFunc[M]()
   146  	err = iter.Do(
   147  		func(row *spanner.Row) (err error) {
   148  			m, err := f(row)
   149  			if err == nil {
   150  				out = append(out, &m)
   151  			}
   152  			return err
   153  		},
   154  	)
   155  	return out, err
   156  }
   157  
   158  func structScanFunc[M any](row *spanner.Row) (m M, err error)    { return m, row.ToStructLenient(&m) }
   159  func primitiveScanFunc[M any](row *spanner.Row) (m M, err error) { return m, row.Columns(&m) }
   160  func scanFunc[M any]() func(row *spanner.Row) (M, error) {
   161  	var m M
   162  	if reflect.TypeOf(m).Kind() != reflect.Struct {
   163  		return primitiveScanFunc[M]
   164  	}
   165  	return structScanFunc[M]
   166  }
   167  
   168  func ArgsToParams(args []interface{}) map[string]interface{} {
   169  	params := make(map[string]interface{}, len(args))
   170  	for n := 0; n < len(args); n++ {
   171  		params[fmt.Sprintf("p%d", n+1)] = args[n]
   172  	}
   173  
   174  	return params
   175  }
   176  
   177  type Key interface {
   178  	string | spanner.Key | spanner.KeyRange
   179  }
   180  
   181  func key[K Key](key K) spanner.Key {
   182  	switch ids := (interface{})(key).(type) {
   183  	case spanner.Key:
   184  		return ids
   185  	case string:
   186  		return spanner.Key{ids}
   187  	}
   188  	return nil
   189  }
   190  
   191  func keySet[K Key](keys ...K) spanner.KeySet {
   192  	if len(keys) == 0 {
   193  		return spanner.AllKeys()
   194  	}
   195  
   196  	switch ids := (interface{})(keys).(type) {
   197  	case []spanner.KeyRange:
   198  		keys := make([]spanner.KeySet, len(ids))
   199  		for i, id := range ids {
   200  			keys[i] = id
   201  		}
   202  		return spanner.KeySets(keys...)
   203  	case []spanner.Key:
   204  		return spanner.KeySetFromKeys(ids...)
   205  	case []string:
   206  		keys := make([]spanner.Key, len(ids))
   207  		for i, id := range ids {
   208  			keys[i] = spanner.Key{id}
   209  		}
   210  		return spanner.KeySetFromKeys(keys...)
   211  	}
   212  	return nil
   213  }