code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/cursor.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package sqlstore
    17  
    18  import (
    19  	"bytes"
    20  	"encoding/gob"
    21  	"fmt"
    22  	"strings"
    23  
    24  	"code.vegaprotocol.io/vega/datanode/entities"
    25  )
    26  
    27  type (
    28  	Sorting = string
    29  	Compare = string
    30  )
    31  
    32  const (
    33  	ASC  Sorting = "ASC"
    34  	DESC Sorting = "DESC"
    35  
    36  	EQ Compare = "="
    37  	NE Compare = "!="
    38  	GT Compare = ">"
    39  	LT Compare = "<"
    40  	GE Compare = ">="
    41  	LE Compare = "<="
    42  )
    43  
    44  type ColumnOrdering struct {
    45  	// Name of the column in the database table to match to the struct field
    46  	Name string
    47  	// Sorting is the sorting order to use for the column
    48  	Sorting Sorting
    49  	// Prefix is the prefix to add to the column name in order to resolve duplicate
    50  	// column names that might be in the query
    51  	Prefix string
    52  	// If the column originates from parsing a JSON field, how it should be referenced in the query.
    53  	Ref string
    54  }
    55  
    56  func NewColumnOrdering(name string, sorting Sorting) ColumnOrdering {
    57  	return ColumnOrdering{Name: name, Sorting: sorting}
    58  }
    59  
    60  type TableOrdering []ColumnOrdering
    61  
    62  func (t *TableOrdering) OrderByClause() string {
    63  	if len(*t) == 0 {
    64  		return ""
    65  	}
    66  
    67  	fragments := make([]string, len(*t))
    68  	for i, column := range *t {
    69  		prefix := column.Prefix
    70  		if column.Prefix != "" && !strings.HasSuffix(column.Prefix, ".") {
    71  			prefix += "."
    72  		}
    73  		fragments[i] = fmt.Sprintf("%s%s %s", prefix, column.Name, column.Sorting)
    74  	}
    75  	return fmt.Sprintf("ORDER BY %s", strings.Join(fragments, ","))
    76  }
    77  
    78  func (t *TableOrdering) Reversed() TableOrdering {
    79  	reversed := make([]ColumnOrdering, len(*t))
    80  	for i, column := range *t {
    81  		if column.Sorting == DESC {
    82  			reversed[i] = ColumnOrdering{Name: column.Name, Sorting: ASC, Ref: column.Ref}
    83  		}
    84  		if column.Sorting == ASC {
    85  			reversed[i] = ColumnOrdering{Name: column.Name, Sorting: DESC, Ref: column.Ref}
    86  		}
    87  	}
    88  	return reversed
    89  }
    90  
    91  // SetPrefixAll sets a prefix for all columns in the table ordering slice.
    92  func (t *TableOrdering) SetPrefixAll(pf string) {
    93  	if len(*t) == 0 {
    94  		return
    95  	}
    96  	// need to cast to underlying slice type to be able to re-assign elements.
    97  	ts := []ColumnOrdering(*t)
    98  	for i, col := range *t {
    99  		col.Prefix = pf
   100  		ts[i] = col
   101  	}
   102  	// cast is needed here, if not the unit test fails.
   103  	*t = TableOrdering(ts)
   104  }
   105  
   106  // CursorPredicate generates an SQL predicate which excludes all rows before the supplied cursor,
   107  // with regards to the supplied table ordering. The values used for comparison are added to
   108  // the args list and bind variables used in the query fragment.
   109  //
   110  // For example, with if you had a query with columns sorted foo ASCENDING, bar DESCENDING and a
   111  // cursor with {foo=1, bar=2}, it would yield a string predicate like this:
   112  //
   113  // (foo > $1) OR (foo = $1 AND bar <= $2)
   114  //
   115  // And 'args' would have 1 and 2 appended to it.
   116  //
   117  // Notes:
   118  //   - The predicate *includes* the value at the cursor
   119  //   - Only fields that are present in both the cursor and the ordering are considered
   120  //   - The union of those fields must have enough information to uniquely identify a row
   121  //   - The table ordering must be sufficient to ensure that a row identified by a cursor cannot
   122  //     change position in relation to the other rows
   123  func CursorPredicate(args []interface{}, cursor interface{}, ordering TableOrdering) (string, []interface{}, error) {
   124  	cursorPredicates := []string{}
   125  	equalPredicates := []string{}
   126  
   127  	for i, column := range ordering {
   128  		// For the non-last columns, use LT/GT, so we don't include stuff before the cursor
   129  		var operator string
   130  		if column.Sorting == ASC {
   131  			operator = ">"
   132  		} else if column.Sorting == DESC {
   133  			operator = "<"
   134  		} else {
   135  			return "", nil, fmt.Errorf("unknown sort direction %s", column.Sorting)
   136  		}
   137  
   138  		// For the last column, we want to use GTE/LTE so we include the value at the cursor
   139  		isLast := i == (len(ordering) - 1)
   140  		if isLast {
   141  			operator = operator + "="
   142  		}
   143  
   144  		value, err := StructValueForColumn(cursor, column.Name)
   145  		if err != nil {
   146  			return "", nil, err
   147  		}
   148  
   149  		prefix := column.Prefix
   150  		if column.Prefix != "" && !strings.HasSuffix(column.Prefix, ".") {
   151  			prefix += "."
   152  		}
   153  
   154  		bindVar := nextBindVar(&args, value)
   155  		ref := column.Name
   156  		if len(column.Ref) > 0 {
   157  			ref = column.Ref
   158  		}
   159  		inequalityPredicate := fmt.Sprintf("%s%s %s %s", prefix, ref, operator, bindVar)
   160  
   161  		colPredicates := append(equalPredicates, inequalityPredicate)
   162  		colPredicateString := strings.Join(colPredicates, " AND ")
   163  		colPredicateString = fmt.Sprintf("(%s)", colPredicateString)
   164  		cursorPredicates = append(cursorPredicates, colPredicateString)
   165  
   166  		equalityPredicate := fmt.Sprintf("%s%s = %s", prefix, ref, bindVar)
   167  		equalPredicates = append(equalPredicates, equalityPredicate)
   168  	}
   169  
   170  	predicateString := strings.Join(cursorPredicates, " OR ")
   171  
   172  	return predicateString, args, nil
   173  }
   174  
   175  type parser interface {
   176  	Parse(string) error
   177  }
   178  
   179  // This is a bit magical, it allows us to use the real cursor type for instantiation and the pointer
   180  // type for calling methods with pointer receivers (e.g. Parse) for details see
   181  // https://go.googlesource.com/proposal/+/refs/heads/master/design/43651-type-parameters.md#pointer-method-example
   182  type parserPtr[T any] interface {
   183  	parser
   184  	*T
   185  }
   186  
   187  // We have to roll our own equals function here for comparing the cursors because some cursor parameters use
   188  // types that do not implement `comparable`.
   189  func equals[T any](actual, other T) (bool, error) {
   190  	var a, b bytes.Buffer
   191  	enc := gob.NewEncoder(&a)
   192  	err := enc.Encode(actual)
   193  	if err != nil {
   194  		return false, err
   195  	}
   196  
   197  	enc = gob.NewEncoder(&b)
   198  	err = enc.Encode(other)
   199  	if err != nil {
   200  		return false, err
   201  	}
   202  
   203  	return bytes.Equal(a.Bytes(), b.Bytes()), nil
   204  }
   205  
   206  // PaginateQuery takes a query string & bind arg list and returns the same with additional SQL to
   207  //   - exclude rows before the cursor (or after it if the cursor is a backwards looking one)
   208  //   - limit the number of rows to the pagination limit +1 (no cursor) or +2 (cursor)
   209  //     [for purposes of later figuring out whether there are next or previous pages]
   210  //   - order the query according to the TableOrdering supplied
   211  //     the order is reversed if pagination request is backwards
   212  //
   213  // For example with cursor to a row where foo=42, and a pagination saying get the next 3 then:
   214  // PaginateQuery[MyCursor]("SELECT foo FROM my_table", args, ordering, pagination)
   215  //
   216  // Would append `42` to the arg list and return
   217  // SELECT foo FROM my_table WHERE foo>=$1 ORDER BY foo ASC LIMIT 5
   218  //
   219  // See CursorPredicate() for more details about how the cursor filtering is done.
   220  func PaginateQuery[T any, PT parserPtr[T]](
   221  	query string,
   222  	args []interface{},
   223  	ordering TableOrdering,
   224  	pagination entities.CursorPagination,
   225  ) (string, []interface{}, error) {
   226  	return paginateQueryInternal[T, PT](query, args, ordering, pagination, false, false)
   227  }
   228  
   229  func PaginateQueryWithWhere[T any, PT parserPtr[T]](
   230  	query string,
   231  	args []interface{},
   232  	ordering TableOrdering,
   233  	pagination entities.CursorPagination,
   234  ) (string, []interface{}, error) {
   235  	return paginateQueryInternal[T, PT](query, args, ordering, pagination, false, true)
   236  }
   237  
   238  func PaginateQueryWithoutOrderBy[T any, PT parserPtr[T]](
   239  	query string,
   240  	args []interface{},
   241  	ordering TableOrdering,
   242  	pagination entities.CursorPagination,
   243  ) (string, []interface{}, error) {
   244  	return paginateQueryInternal[T, PT](query, args, ordering, pagination, true, false)
   245  }
   246  
   247  func paginateQueryInternal[T any, PT parserPtr[T]](
   248  	query string,
   249  	args []interface{},
   250  	ordering TableOrdering,
   251  	pagination entities.CursorPagination,
   252  	omitOrderBy bool,
   253  	forceWhere bool,
   254  ) (string, []interface{}, error) {
   255  	// Extract a cursor struct from the pagination struct
   256  	cursor, err := parseCursor[T, PT](pagination)
   257  	if err != nil {
   258  		return "", nil, fmt.Errorf("parsing cursor: %w", err)
   259  	}
   260  
   261  	// If we're fetching rows before the cursor, reverse the ordering
   262  	if (pagination.HasBackward() && !pagination.NewestFirst) || // Navigating backwards in time order
   263  		(pagination.HasForward() && pagination.NewestFirst) || // Navigating forward in reverse time order
   264  		(!pagination.HasBackward() && !pagination.HasForward() && pagination.NewestFirst) { // No pagination provided, but in reverse time order
   265  		ordering = ordering.Reversed()
   266  	}
   267  
   268  	// If the cursor wasn't empty, exclude rows preceding the cursor's row
   269  	var emptyCursor T
   270  	isEmpty, err := equals[T](cursor, emptyCursor)
   271  	if err != nil {
   272  		return "", nil, fmt.Errorf("checking empty cursor: %w", err)
   273  	}
   274  	if !isEmpty {
   275  		whereOrAnd := "WHERE"
   276  		if !forceWhere && strings.Contains(strings.ToUpper(query), "WHERE") {
   277  			whereOrAnd = "AND"
   278  		}
   279  
   280  		var predicate string
   281  		predicate, args, err = CursorPredicate(args, cursor, ordering)
   282  		if err != nil {
   283  			return "", nil, fmt.Errorf("building cursor predicate: %w", err)
   284  		}
   285  		query = fmt.Sprintf("%s %s (%s)", query, whereOrAnd, predicate)
   286  	}
   287  
   288  	// Add an ORDER BY clause if requested
   289  	if !omitOrderBy {
   290  		query = fmt.Sprintf("%s %s", query, ordering.OrderByClause())
   291  	}
   292  
   293  	// And a LIMIT clause
   294  	limit := calculateLimit(pagination)
   295  	if limit != 0 {
   296  		query = fmt.Sprintf("%s LIMIT %d", query, limit)
   297  	}
   298  
   299  	return query, args, nil
   300  }
   301  
   302  func parseCursor[T any, PT parserPtr[T]](pagination entities.CursorPagination) (T, error) {
   303  	cursor := PT(new(T))
   304  
   305  	cursorStr := ""
   306  	if pagination.HasForward() && pagination.Forward.HasCursor() {
   307  		cursorStr = pagination.Forward.Cursor.Value()
   308  	} else if pagination.HasBackward() && pagination.Backward.HasCursor() {
   309  		cursorStr = pagination.Backward.Cursor.Value()
   310  	}
   311  
   312  	if cursorStr != "" {
   313  		err := cursor.Parse(cursorStr)
   314  		if err != nil {
   315  			return *cursor, fmt.Errorf("parsing cursor: %w", err)
   316  		}
   317  	}
   318  	return *cursor, nil
   319  }
   320  
   321  type CursorQueryParameter struct {
   322  	ColumnName string
   323  	Sort       Sorting
   324  	Cmp        Compare
   325  	Value      any
   326  }
   327  
   328  func NewCursorQueryParameter(columnName string, sort Sorting, cmp Compare, value any) CursorQueryParameter {
   329  	return CursorQueryParameter{
   330  		ColumnName: columnName,
   331  		Sort:       sort,
   332  		Cmp:        cmp,
   333  		Value:      value,
   334  	}
   335  }
   336  
   337  func (c CursorQueryParameter) Where(args ...interface{}) (string, []interface{}) {
   338  	if c.Cmp == "" || c.Value == nil {
   339  		return "", args
   340  	}
   341  
   342  	where := fmt.Sprintf("%s %s %v", c.ColumnName, c.Cmp, nextBindVar(&args, c.Value))
   343  	return where, args
   344  }
   345  
   346  func (c CursorQueryParameter) OrderBy() string {
   347  	return fmt.Sprintf("%s %s", c.ColumnName, c.Sort)
   348  }
   349  
   350  type CursorQueryParameters []CursorQueryParameter
   351  
   352  func (c CursorQueryParameters) Where(args ...interface{}) (string, []interface{}) {
   353  	var where string
   354  
   355  	for i, cursor := range c {
   356  		var cursorCondition string
   357  		cursorCondition, args = cursor.Where(args...)
   358  		if i > 0 && strings.TrimSpace(where) != "" && strings.TrimSpace(cursorCondition) != "" {
   359  			where = fmt.Sprintf("%s AND", where)
   360  		}
   361  		where = fmt.Sprintf("%s %s", where, cursorCondition)
   362  	}
   363  
   364  	return strings.TrimSpace(where), args
   365  }
   366  
   367  func (c CursorQueryParameters) OrderBy() string {
   368  	var orderBy string
   369  
   370  	for i, cursor := range c {
   371  		if i > 0 {
   372  			orderBy = fmt.Sprintf("%s,", orderBy)
   373  		}
   374  		orderBy = fmt.Sprintf("%s %s", orderBy, cursor.OrderBy())
   375  	}
   376  
   377  	return strings.TrimSpace(orderBy)
   378  }