gitlab.com/picnic-app/backend/role-api@v0.0.0-20230614140944-06a76ff3696d/internal/repo/spanner/helpers/helpers.go (about) 1 package helpers 2 3 import ( 4 "context" 5 "fmt" 6 "reflect" 7 8 "cloud.google.com/go/spanner" 9 "github.com/Masterminds/squirrel" 10 "google.golang.org/api/iterator" 11 "google.golang.org/grpc/codes" 12 "google.golang.org/grpc/status" 13 14 "gitlab.com/picnic-app/backend/role-api/internal/repo/spanner/tables" 15 ) 16 17 type Reader interface { 18 Read(ctx context.Context, table string, k spanner.KeySet, cols []string) *spanner.RowIterator 19 } 20 21 type RowReader interface { 22 ReadRow(ctx context.Context, table string, k spanner.Key, cols []string) (*spanner.Row, error) 23 } 24 25 type Queryer interface { 26 Query(context.Context, spanner.Statement) *spanner.RowIterator 27 } 28 29 type Updater interface { 30 Update(context.Context, spanner.Statement) (int64, error) 31 } 32 33 type BufferWriter interface { 34 BufferWrite([]*spanner.Mutation) error 35 } 36 37 type DBWriter interface { 38 Updater 39 BufferWriter 40 } 41 42 type IndexReader interface { 43 ReadUsingIndex(ctx context.Context, table, index string, keys spanner.KeySet, columns []string) *spanner.RowIterator 44 } 45 46 func DeleteByBuilder(ctx context.Context, db Updater, b squirrel.DeleteBuilder) (int64, error) { 47 q, args, err := b.PlaceholderFormat(squirrel.AtP).ToSql() 48 if err != nil { 49 return 0, err 50 } 51 52 stmt := spanner.Statement{SQL: q, Params: ArgsToParams(args)} 53 return db.Update(ctx, stmt) 54 } 55 56 func UpdateByBuilder(ctx context.Context, db Updater, b squirrel.UpdateBuilder) (int64, error) { 57 q, args, err := b.PlaceholderFormat(squirrel.AtP).ToSql() 58 if err != nil { 59 return 0, err 60 } 61 62 stmt := spanner.Statement{SQL: q, Params: ArgsToParams(args)} 63 return db.Update(ctx, stmt) 64 } 65 66 func GetResultsByBuilder[M any](ctx context.Context, db Queryer, b squirrel.SelectBuilder) ([]M, error) { 67 q, args, err := b.PlaceholderFormat(squirrel.AtP).ToSql() 68 if err != nil { 69 return nil, err 70 } 71 72 stmt := spanner.Statement{SQL: q, Params: ArgsToParams(args)} 73 return GetResults[M](db.Query(ctx, stmt)) 74 } 75 76 func GetResultByBuilder[M any](ctx context.Context, db Queryer, b squirrel.SelectBuilder) (m M, err error) { 77 q, args, err := b.PlaceholderFormat(squirrel.AtP).ToSql() 78 if err != nil { 79 return m, err 80 } 81 82 stmt := spanner.Statement{SQL: q, Params: ArgsToParams(args)} 83 return GetResult[M](db.Query(ctx, stmt)) 84 } 85 86 func GetByKey[M any, T tables.Table, K Key](ctx context.Context, db RowReader, k K) (m M, err error) { 87 var t T 88 row, err := db.ReadRow(ctx, t.TableName(), key(k), t.Columns()) 89 if err != nil { 90 return m, err 91 } 92 return scanFunc[M]()(row) 93 } 94 95 func GetByKeys[M any, T tables.Table, K Key](ctx context.Context, db Reader, ids ...K) ([]M, error) { 96 var t T 97 return GetResults[M](db.Read(ctx, t.TableName(), keySet(ids...), t.Columns())) 98 } 99 100 func Delete[T tables.Table, K Key](db BufferWriter, keys ...K) error { 101 if len(keys) == 0 { 102 return nil 103 } 104 105 var t T 106 m := spanner.Delete(t.TableName(), keySet(keys...)) 107 return db.BufferWrite([]*spanner.Mutation{m}) 108 } 109 110 // GetResult returns the result from the iterator. Calls Stop after the iterator 111 // is finished. 112 func GetResult[M any](iter *spanner.RowIterator) (out M, err error) { 113 defer iter.Stop() 114 115 row, err := iter.Next() 116 if err != nil { 117 if err == iterator.Done { 118 return out, status.Error(codes.NotFound, reflect.TypeOf(out).Name()) 119 } 120 121 return out, err 122 } 123 124 return scanFunc[M]()(row) 125 } 126 127 // GetResults returns the results from the iterator. Calls Stop after the 128 // iterator is finished. 129 func GetResults[M any](iter *spanner.RowIterator) (out []M, err error) { 130 f := scanFunc[M]() 131 err = iter.Do(func(row *spanner.Row) error { 132 m, err := f(row) 133 if err != nil { 134 return err 135 } 136 out = append(out, m) 137 return nil 138 }) 139 return out, err 140 } 141 142 // GetPtrResults returns the results from the iterator. Calls Stop after the 143 // iterator is finished. 144 func GetPtrResults[M any](iter *spanner.RowIterator) (out []*M, err error) { 145 f := scanFunc[M]() 146 err = iter.Do( 147 func(row *spanner.Row) (err error) { 148 m, err := f(row) 149 if err == nil { 150 out = append(out, &m) 151 } 152 return err 153 }, 154 ) 155 return out, err 156 } 157 158 func structScanFunc[M any](row *spanner.Row) (m M, err error) { return m, row.ToStructLenient(&m) } 159 func primitiveScanFunc[M any](row *spanner.Row) (m M, err error) { return m, row.Columns(&m) } 160 func scanFunc[M any]() func(row *spanner.Row) (M, error) { 161 var m M 162 if reflect.TypeOf(m).Kind() != reflect.Struct { 163 return primitiveScanFunc[M] 164 } 165 return structScanFunc[M] 166 } 167 168 func ArgsToParams(args []interface{}) map[string]interface{} { 169 params := make(map[string]interface{}, len(args)) 170 for n := 0; n < len(args); n++ { 171 params[fmt.Sprintf("p%d", n+1)] = args[n] 172 } 173 174 return params 175 } 176 177 type Key interface { 178 string | spanner.Key | spanner.KeyRange 179 } 180 181 func key[K Key](key K) spanner.Key { 182 switch ids := (interface{})(key).(type) { 183 case spanner.Key: 184 return ids 185 case string: 186 return spanner.Key{ids} 187 } 188 return nil 189 } 190 191 func keySet[K Key](keys ...K) spanner.KeySet { 192 if len(keys) == 0 { 193 return spanner.AllKeys() 194 } 195 196 switch ids := (interface{})(keys).(type) { 197 case []spanner.KeyRange: 198 keys := make([]spanner.KeySet, len(ids)) 199 for i, id := range ids { 200 keys[i] = id 201 } 202 return spanner.KeySets(keys...) 203 case []spanner.Key: 204 return spanner.KeySetFromKeys(ids...) 205 case []string: 206 keys := make([]spanner.Key, len(ids)) 207 for i, id := range ids { 208 keys[i] = spanner.Key{id} 209 } 210 return spanner.KeySetFromKeys(keys...) 211 } 212 return nil 213 }