code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/cursor.go (about) 1 // Copyright (C) 2023 Gobalsky Labs Limited 2 // 3 // This program is free software: you can redistribute it and/or modify 4 // it under the terms of the GNU Affero General Public License as 5 // published by the Free Software Foundation, either version 3 of the 6 // License, or (at your option) any later version. 7 // 8 // This program is distributed in the hope that it will be useful, 9 // but WITHOUT ANY WARRANTY; without even the implied warranty of 10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 // GNU Affero General Public License for more details. 12 // 13 // You should have received a copy of the GNU Affero General Public License 14 // along with this program. If not, see <http://www.gnu.org/licenses/>. 15 16 package sqlstore 17 18 import ( 19 "bytes" 20 "encoding/gob" 21 "fmt" 22 "strings" 23 24 "code.vegaprotocol.io/vega/datanode/entities" 25 ) 26 27 type ( 28 Sorting = string 29 Compare = string 30 ) 31 32 const ( 33 ASC Sorting = "ASC" 34 DESC Sorting = "DESC" 35 36 EQ Compare = "=" 37 NE Compare = "!=" 38 GT Compare = ">" 39 LT Compare = "<" 40 GE Compare = ">=" 41 LE Compare = "<=" 42 ) 43 44 type ColumnOrdering struct { 45 // Name of the column in the database table to match to the struct field 46 Name string 47 // Sorting is the sorting order to use for the column 48 Sorting Sorting 49 // Prefix is the prefix to add to the column name in order to resolve duplicate 50 // column names that might be in the query 51 Prefix string 52 // If the column originates from parsing a JSON field, how it should be referenced in the query. 
53 Ref string 54 } 55 56 func NewColumnOrdering(name string, sorting Sorting) ColumnOrdering { 57 return ColumnOrdering{Name: name, Sorting: sorting} 58 } 59 60 type TableOrdering []ColumnOrdering 61 62 func (t *TableOrdering) OrderByClause() string { 63 if len(*t) == 0 { 64 return "" 65 } 66 67 fragments := make([]string, len(*t)) 68 for i, column := range *t { 69 prefix := column.Prefix 70 if column.Prefix != "" && !strings.HasSuffix(column.Prefix, ".") { 71 prefix += "." 72 } 73 fragments[i] = fmt.Sprintf("%s%s %s", prefix, column.Name, column.Sorting) 74 } 75 return fmt.Sprintf("ORDER BY %s", strings.Join(fragments, ",")) 76 } 77 78 func (t *TableOrdering) Reversed() TableOrdering { 79 reversed := make([]ColumnOrdering, len(*t)) 80 for i, column := range *t { 81 if column.Sorting == DESC { 82 reversed[i] = ColumnOrdering{Name: column.Name, Sorting: ASC, Ref: column.Ref} 83 } 84 if column.Sorting == ASC { 85 reversed[i] = ColumnOrdering{Name: column.Name, Sorting: DESC, Ref: column.Ref} 86 } 87 } 88 return reversed 89 } 90 91 // SetPrefixAll sets a prefix for all columns in the table ordering slice. 92 func (t *TableOrdering) SetPrefixAll(pf string) { 93 if len(*t) == 0 { 94 return 95 } 96 // need to cast to underlying slice type to be able to re-assign elements. 97 ts := []ColumnOrdering(*t) 98 for i, col := range *t { 99 col.Prefix = pf 100 ts[i] = col 101 } 102 // cast is needed here, if not the unit test fails. 103 *t = TableOrdering(ts) 104 } 105 106 // CursorPredicate generates an SQL predicate which excludes all rows before the supplied cursor, 107 // with regards to the supplied table ordering. The values used for comparison are added to 108 // the args list and bind variables used in the query fragment. 
109 // 110 // For example, with if you had a query with columns sorted foo ASCENDING, bar DESCENDING and a 111 // cursor with {foo=1, bar=2}, it would yield a string predicate like this: 112 // 113 // (foo > $1) OR (foo = $1 AND bar <= $2) 114 // 115 // And 'args' would have 1 and 2 appended to it. 116 // 117 // Notes: 118 // - The predicate *includes* the value at the cursor 119 // - Only fields that are present in both the cursor and the ordering are considered 120 // - The union of those fields must have enough information to uniquely identify a row 121 // - The table ordering must be sufficient to ensure that a row identified by a cursor cannot 122 // change position in relation to the other rows 123 func CursorPredicate(args []interface{}, cursor interface{}, ordering TableOrdering) (string, []interface{}, error) { 124 cursorPredicates := []string{} 125 equalPredicates := []string{} 126 127 for i, column := range ordering { 128 // For the non-last columns, use LT/GT, so we don't include stuff before the cursor 129 var operator string 130 if column.Sorting == ASC { 131 operator = ">" 132 } else if column.Sorting == DESC { 133 operator = "<" 134 } else { 135 return "", nil, fmt.Errorf("unknown sort direction %s", column.Sorting) 136 } 137 138 // For the last column, we want to use GTE/LTE so we include the value at the cursor 139 isLast := i == (len(ordering) - 1) 140 if isLast { 141 operator = operator + "=" 142 } 143 144 value, err := StructValueForColumn(cursor, column.Name) 145 if err != nil { 146 return "", nil, err 147 } 148 149 prefix := column.Prefix 150 if column.Prefix != "" && !strings.HasSuffix(column.Prefix, ".") { 151 prefix += "." 
152 } 153 154 bindVar := nextBindVar(&args, value) 155 ref := column.Name 156 if len(column.Ref) > 0 { 157 ref = column.Ref 158 } 159 inequalityPredicate := fmt.Sprintf("%s%s %s %s", prefix, ref, operator, bindVar) 160 161 colPredicates := append(equalPredicates, inequalityPredicate) 162 colPredicateString := strings.Join(colPredicates, " AND ") 163 colPredicateString = fmt.Sprintf("(%s)", colPredicateString) 164 cursorPredicates = append(cursorPredicates, colPredicateString) 165 166 equalityPredicate := fmt.Sprintf("%s%s = %s", prefix, ref, bindVar) 167 equalPredicates = append(equalPredicates, equalityPredicate) 168 } 169 170 predicateString := strings.Join(cursorPredicates, " OR ") 171 172 return predicateString, args, nil 173 } 174 175 type parser interface { 176 Parse(string) error 177 } 178 179 // This is a bit magical, it allows us to use the real cursor type for instantiation and the pointer 180 // type for calling methods with pointer receivers (e.g. Parse) for details see 181 // https://go.googlesource.com/proposal/+/refs/heads/master/design/43651-type-parameters.md#pointer-method-example 182 type parserPtr[T any] interface { 183 parser 184 *T 185 } 186 187 // We have to roll our own equals function here for comparing the cursors because some cursor parameters use 188 // types that do not implement `comparable`. 
func equals[T any](actual, other T) (bool, error) {
	// Each value gets its own encoder so both byte streams carry the same type header.
	encode := func(v T) ([]byte, error) {
		var buf bytes.Buffer
		if err := gob.NewEncoder(&buf).Encode(v); err != nil {
			return nil, err
		}
		return buf.Bytes(), nil
	}

	lhs, err := encode(actual)
	if err != nil {
		return false, err
	}
	rhs, err := encode(other)
	if err != nil {
		return false, err
	}
	return bytes.Equal(lhs, rhs), nil
}

// PaginateQuery takes a query string and bind-argument list and returns them with additional
// SQL appended to
//   - exclude rows before the cursor (or after it if the cursor is a backwards-looking one),
//   - limit the number of rows to the pagination limit +1 (no cursor) or +2 (cursor)
//     [for purposes of later figuring out whether there are next or previous pages], and
//   - order the query according to the supplied TableOrdering. The ordering is reversed when
//     paging against it: a Backward request without NewestFirst, or a Forward (or
//     direction-less) request with NewestFirst.
//
// For example, with a cursor to a row where foo=42 and a pagination asking for the next 3 rows:
//
//	PaginateQuery[MyCursor]("SELECT foo FROM my_table", args, ordering, pagination)
//
// would append 42 to the (initially empty) arg list and return
//
//	SELECT foo FROM my_table WHERE ((foo >= $1)) ORDER BY foo ASC LIMIT 5
//
// See CursorPredicate() for more details about how the cursor filtering is done.
220 func PaginateQuery[T any, PT parserPtr[T]]( 221 query string, 222 args []interface{}, 223 ordering TableOrdering, 224 pagination entities.CursorPagination, 225 ) (string, []interface{}, error) { 226 return paginateQueryInternal[T, PT](query, args, ordering, pagination, false, false) 227 } 228 229 func PaginateQueryWithWhere[T any, PT parserPtr[T]]( 230 query string, 231 args []interface{}, 232 ordering TableOrdering, 233 pagination entities.CursorPagination, 234 ) (string, []interface{}, error) { 235 return paginateQueryInternal[T, PT](query, args, ordering, pagination, false, true) 236 } 237 238 func PaginateQueryWithoutOrderBy[T any, PT parserPtr[T]]( 239 query string, 240 args []interface{}, 241 ordering TableOrdering, 242 pagination entities.CursorPagination, 243 ) (string, []interface{}, error) { 244 return paginateQueryInternal[T, PT](query, args, ordering, pagination, true, false) 245 } 246 247 func paginateQueryInternal[T any, PT parserPtr[T]]( 248 query string, 249 args []interface{}, 250 ordering TableOrdering, 251 pagination entities.CursorPagination, 252 omitOrderBy bool, 253 forceWhere bool, 254 ) (string, []interface{}, error) { 255 // Extract a cursor struct from the pagination struct 256 cursor, err := parseCursor[T, PT](pagination) 257 if err != nil { 258 return "", nil, fmt.Errorf("parsing cursor: %w", err) 259 } 260 261 // If we're fetching rows before the cursor, reverse the ordering 262 if (pagination.HasBackward() && !pagination.NewestFirst) || // Navigating backwards in time order 263 (pagination.HasForward() && pagination.NewestFirst) || // Navigating forward in reverse time order 264 (!pagination.HasBackward() && !pagination.HasForward() && pagination.NewestFirst) { // No pagination provided, but in reverse time order 265 ordering = ordering.Reversed() 266 } 267 268 // If the cursor wasn't empty, exclude rows preceding the cursor's row 269 var emptyCursor T 270 isEmpty, err := equals[T](cursor, emptyCursor) 271 if err != nil { 272 
return "", nil, fmt.Errorf("checking empty cursor: %w", err) 273 } 274 if !isEmpty { 275 whereOrAnd := "WHERE" 276 if !forceWhere && strings.Contains(strings.ToUpper(query), "WHERE") { 277 whereOrAnd = "AND" 278 } 279 280 var predicate string 281 predicate, args, err = CursorPredicate(args, cursor, ordering) 282 if err != nil { 283 return "", nil, fmt.Errorf("building cursor predicate: %w", err) 284 } 285 query = fmt.Sprintf("%s %s (%s)", query, whereOrAnd, predicate) 286 } 287 288 // Add an ORDER BY clause if requested 289 if !omitOrderBy { 290 query = fmt.Sprintf("%s %s", query, ordering.OrderByClause()) 291 } 292 293 // And a LIMIT clause 294 limit := calculateLimit(pagination) 295 if limit != 0 { 296 query = fmt.Sprintf("%s LIMIT %d", query, limit) 297 } 298 299 return query, args, nil 300 } 301 302 func parseCursor[T any, PT parserPtr[T]](pagination entities.CursorPagination) (T, error) { 303 cursor := PT(new(T)) 304 305 cursorStr := "" 306 if pagination.HasForward() && pagination.Forward.HasCursor() { 307 cursorStr = pagination.Forward.Cursor.Value() 308 } else if pagination.HasBackward() && pagination.Backward.HasCursor() { 309 cursorStr = pagination.Backward.Cursor.Value() 310 } 311 312 if cursorStr != "" { 313 err := cursor.Parse(cursorStr) 314 if err != nil { 315 return *cursor, fmt.Errorf("parsing cursor: %w", err) 316 } 317 } 318 return *cursor, nil 319 } 320 321 type CursorQueryParameter struct { 322 ColumnName string 323 Sort Sorting 324 Cmp Compare 325 Value any 326 } 327 328 func NewCursorQueryParameter(columnName string, sort Sorting, cmp Compare, value any) CursorQueryParameter { 329 return CursorQueryParameter{ 330 ColumnName: columnName, 331 Sort: sort, 332 Cmp: cmp, 333 Value: value, 334 } 335 } 336 337 func (c CursorQueryParameter) Where(args ...interface{}) (string, []interface{}) { 338 if c.Cmp == "" || c.Value == nil { 339 return "", args 340 } 341 342 where := fmt.Sprintf("%s %s %v", c.ColumnName, c.Cmp, nextBindVar(&args, c.Value)) 343 
return where, args 344 } 345 346 func (c CursorQueryParameter) OrderBy() string { 347 return fmt.Sprintf("%s %s", c.ColumnName, c.Sort) 348 } 349 350 type CursorQueryParameters []CursorQueryParameter 351 352 func (c CursorQueryParameters) Where(args ...interface{}) (string, []interface{}) { 353 var where string 354 355 for i, cursor := range c { 356 var cursorCondition string 357 cursorCondition, args = cursor.Where(args...) 358 if i > 0 && strings.TrimSpace(where) != "" && strings.TrimSpace(cursorCondition) != "" { 359 where = fmt.Sprintf("%s AND", where) 360 } 361 where = fmt.Sprintf("%s %s", where, cursorCondition) 362 } 363 364 return strings.TrimSpace(where), args 365 } 366 367 func (c CursorQueryParameters) OrderBy() string { 368 var orderBy string 369 370 for i, cursor := range c { 371 if i > 0 { 372 orderBy = fmt.Sprintf("%s,", orderBy) 373 } 374 orderBy = fmt.Sprintf("%s %s", orderBy, cursor.OrderBy()) 375 } 376 377 return strings.TrimSpace(orderBy) 378 }