github.com/team-ide/go-dialect@v1.9.20/vitess/sqlparser/parsed_query.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package sqlparser
    18  
    19  import (
    20  	"encoding/json"
    21  	"fmt"
    22  	"strings"
    23  
    24  	"github.com/team-ide/go-dialect/vitess/vterrors"
    25  	vtrpcpb "github.com/team-ide/go-dialect/vitess/vtrpc"
    26  
    27  	"github.com/team-ide/go-dialect/vitess/bytes2"
    28  
    29  	"github.com/team-ide/go-dialect/vitess/sqltypes"
    30  
    31  	querypb "github.com/team-ide/go-dialect/vitess/query"
    32  )
    33  
    34  // ParsedQuery represents a parsed query where
    35  // bind locations are precomputed for fast substitutions.
    36  type ParsedQuery struct {
    37  	Query         string
    38  	bindLocations []bindLocation
    39  }
    40  
    41  type bindLocation struct {
    42  	offset, length int
    43  }
    44  
    45  // NewParsedQuery returns a ParsedQuery of the ast.
    46  func NewParsedQuery(node SQLNode) *ParsedQuery {
    47  	buf := NewTrackedBuffer(nil)
    48  	buf.Myprintf("%v", node)
    49  	return buf.ParsedQuery()
    50  }
    51  
    52  // GenerateQuery generates a query by substituting the specified
    53  // bindVariables. The extras parameter specifies special parameters
    54  // that can perform custom encoding.
    55  func (pq *ParsedQuery) GenerateQuery(bindVariables map[string]*querypb.BindVariable, extras map[string]Encodable) (string, error) {
    56  	if len(pq.bindLocations) == 0 {
    57  		return pq.Query, nil
    58  	}
    59  	var buf strings.Builder
    60  	buf.Grow(len(pq.Query))
    61  	if err := pq.Append(&buf, bindVariables, extras); err != nil {
    62  		return "", err
    63  	}
    64  	return buf.String(), nil
    65  }
    66  
    67  // Append appends the generated query to the provided buffer.
    68  func (pq *ParsedQuery) Append(buf *strings.Builder, bindVariables map[string]*querypb.BindVariable, extras map[string]Encodable) error {
    69  	current := 0
    70  	for _, loc := range pq.bindLocations {
    71  		buf.WriteString(pq.Query[current:loc.offset])
    72  		name := pq.Query[loc.offset : loc.offset+loc.length]
    73  		if encodable, ok := extras[name[1:]]; ok {
    74  			encodable.EncodeSQL(buf)
    75  		} else {
    76  			supplied, _, err := FetchBindVar(name, bindVariables)
    77  			if err != nil {
    78  				return err
    79  			}
    80  			EncodeValue(buf, supplied)
    81  		}
    82  		current = loc.offset + loc.length
    83  	}
    84  	buf.WriteString(pq.Query[current:])
    85  	return nil
    86  }
    87  
    88  // AppendFromRow behaves like Append but takes a querypb.Row directly, assuming that
    89  // the fields in the row are in the same order as the placeholders in this query. The fields might include generated
    90  // columns which are dropped, by checking against skipFields, before binding the variables
    91  // note: there can be more fields than bind locations since extra columns might be requested from the source if not all
    92  // primary keys columns are present in the target table, for example. Also some values in the row may not correspond for
    93  // values from the database on the source: sum/count for aggregation queries, for example
    94  func (pq *ParsedQuery) AppendFromRow(buf *bytes2.Buffer, fields []*querypb.Field, row *querypb.Row, skipFields map[string]bool) error {
    95  	if len(fields) < len(pq.bindLocations) {
    96  		return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "wrong number of fields: got %d fields for %d bind locations ",
    97  			len(fields), len(pq.bindLocations))
    98  	}
    99  
   100  	type colInfo struct {
   101  		typ    querypb.Type
   102  		length int64
   103  		offset int64
   104  	}
   105  	rowInfo := make([]*colInfo, 0)
   106  
   107  	offset := int64(0)
   108  	for i, field := range fields { // collect info required for fields to be bound
   109  		length := row.Lengths[i]
   110  		if !skipFields[strings.ToLower(field.Name)] {
   111  			rowInfo = append(rowInfo, &colInfo{
   112  				typ:    field.Type,
   113  				length: length,
   114  				offset: offset,
   115  			})
   116  		}
   117  		if length > 0 {
   118  			offset += row.Lengths[i]
   119  		}
   120  	}
   121  
   122  	// bind field values to locations
   123  	var offsetQuery int
   124  	for i, loc := range pq.bindLocations {
   125  		col := rowInfo[i]
   126  		buf.WriteString(pq.Query[offsetQuery:loc.offset])
   127  
   128  		typ := col.typ
   129  		if typ == querypb.Type_TUPLE {
   130  			return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "unexpected Type_TUPLE for value %d", i)
   131  		}
   132  
   133  		length := col.length
   134  		if length < 0 {
   135  			// -1 means a null variable; serialize it directly
   136  			buf.WriteString("null")
   137  		} else {
   138  			vv := sqltypes.MakeTrusted(typ, row.Values[col.offset:col.offset+col.length])
   139  			vv.EncodeSQLBytes2(buf)
   140  		}
   141  
   142  		offsetQuery = loc.offset + loc.length
   143  	}
   144  	buf.WriteString(pq.Query[offsetQuery:])
   145  	return nil
   146  }
   147  
   148  // MarshalJSON is a custom JSON marshaler for ParsedQuery.
   149  // Note that any queries longer that 512 bytes will be truncated.
   150  func (pq *ParsedQuery) MarshalJSON() ([]byte, error) {
   151  	return json.Marshal(TruncateForUI(pq.Query))
   152  }
   153  
   154  // EncodeValue encodes one bind variable value into the query.
   155  func EncodeValue(buf *strings.Builder, value *querypb.BindVariable) {
   156  	if value.Type != querypb.Type_TUPLE {
   157  		// Since we already check for TUPLE, we don't expect an error.
   158  		v, _ := sqltypes.BindVariableToValue(value)
   159  		v.EncodeSQLStringBuilder(buf)
   160  		return
   161  	}
   162  
   163  	// It's a TUPLE.
   164  	buf.WriteByte('(')
   165  	for i, bv := range value.Values {
   166  		if i != 0 {
   167  			buf.WriteString(", ")
   168  		}
   169  		sqltypes.ProtoToValue(bv).EncodeSQLStringBuilder(buf)
   170  	}
   171  	buf.WriteByte(')')
   172  }
   173  
   174  // FetchBindVar resolves the bind variable by fetching it from bindVariables.
   175  func FetchBindVar(name string, bindVariables map[string]*querypb.BindVariable) (val *querypb.BindVariable, isList bool, err error) {
   176  	name = name[1:]
   177  	if name[0] == ':' {
   178  		name = name[1:]
   179  		isList = true
   180  	}
   181  	supplied, ok := bindVariables[name]
   182  	if !ok {
   183  		return nil, false, fmt.Errorf("missing bind var %s", name)
   184  	}
   185  
   186  	if isList {
   187  		if supplied.Type != querypb.Type_TUPLE {
   188  			return nil, false, fmt.Errorf("unexpected list arg type (%v) for key %s", supplied.Type, name)
   189  		}
   190  		if len(supplied.Values) == 0 {
   191  			return nil, false, fmt.Errorf("empty list supplied for %s", name)
   192  		}
   193  		return supplied, true, nil
   194  	}
   195  
   196  	if supplied.Type == querypb.Type_TUPLE {
   197  		return nil, false, fmt.Errorf("unexpected arg type (TUPLE) for non-list key %s", name)
   198  	}
   199  
   200  	return supplied, false, nil
   201  }
   202  
   203  // ParseAndBind is a one step sweep that binds variables to an input query, in order of placeholders.
   204  // It is useful when one doesn't have any parser-variables, just bind variables.
   205  // Example:
   206  //   query, err := ParseAndBind("select * from tbl where name=%a", sqltypes.StringBindVariable("it's me"))
   207  func ParseAndBind(in string, binds ...*querypb.BindVariable) (query string, err error) {
   208  	vars := make([]interface{}, len(binds))
   209  	for i := range binds {
   210  		vars[i] = fmt.Sprintf(":var%d", i)
   211  	}
   212  	parsed := BuildParsedQuery(in, vars...)
   213  
   214  	bindVars := map[string]*querypb.BindVariable{}
   215  	for i := range binds {
   216  		bindVars[fmt.Sprintf("var%d", i)] = binds[i]
   217  	}
   218  	return parsed.GenerateQuery(bindVars, nil)
   219  }