github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/sqlparse/sqlparser/tracked_buffer.go (about)

     1  /*
     2  Copyright 2017 Google Inc.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package sqlparser
    18  
    19  import (
    20  	"bytes"
    21  	"fmt"
    22  )
    23  
// TrackedBuffer is used to rebuild a query from the AST.
//
// bindLocations keeps track of locations in the buffer that hold bind
// variables (WriteArg appends one entry per argument it writes), for
// efficient future substitutions.
//
// nodeFormatter is the formatting function the buffer will use to format a
// node for Myprintf's %v verb. When it is nil (the default), Myprintf calls
// node.Format(buf) directly. Supply a different formatting function if you
// want to generate a query that differs from the default.
//
// The zero value has a nil *bytes.Buffer and cannot be written to; create
// instances with NewTrackedBuffer.
type TrackedBuffer struct {
	// IdQuoter and PlaceholderFormatter are embedded, so any methods they
	// declare are promoted onto TrackedBuffer; both are declared elsewhere in
	// this package.
	// NOTE(review): presumably they customize identifier quoting and
	// placeholder rendering during formatting — confirm at their definitions.
	IdQuoter
	PlaceholderFormatter

	// Buffer accumulates the query text; NewTrackedBuffer allocates it.
	*bytes.Buffer
	// bindLocations holds the offset and length of each argument written by WriteArg.
	bindLocations []bindLocation
	// nodeFormatter, when non-nil, replaces SQLNode.Format for the %v verb.
	nodeFormatter func(buf *TrackedBuffer, node SQLNode)
}
    39  
    40  // NewTrackedBuffer creates a new TrackedBuffer.
    41  func NewTrackedBuffer(nodeFormatter func(buf *TrackedBuffer, node SQLNode)) *TrackedBuffer {
    42  	return &TrackedBuffer{
    43  		Buffer:        new(bytes.Buffer),
    44  		nodeFormatter: nodeFormatter,
    45  	}
    46  }
    47  
    48  // WriteNode function, initiates the writing of a single SQLNode tree by passing
    49  // through to Myprintf with a default format string
    50  func (buf *TrackedBuffer) WriteNode(node SQLNode) *TrackedBuffer {
    51  	buf.Myprintf("%v", node)
    52  	return buf
    53  }
    54  
    55  // Myprintf mimics fmt.Fprintf(buf, ...), but limited to Node(%v),
    56  // Node.Value(%s) and string(%s). It also allows a %a for a value argument, in
    57  // which case it adds tracking info for future substitutions.
    58  //
    59  // The name must be something other than the usual Printf() to avoid "go vet"
    60  // warnings due to our custom format specifiers.
    61  func (buf *TrackedBuffer) Myprintf(format string, values ...interface{}) {
    62  	end := len(format)
    63  	fieldnum := 0
    64  	for i := 0; i < end; {
    65  		lasti := i
    66  		for i < end && format[i] != '%' {
    67  			i++
    68  		}
    69  		if i > lasti {
    70  			buf.WriteString(format[lasti:i])
    71  		}
    72  		if i >= end {
    73  			break
    74  		}
    75  		i++ // '%'
    76  		switch format[i] {
    77  		case 'c':
    78  			switch v := values[fieldnum].(type) {
    79  			case byte:
    80  				buf.WriteByte(v)
    81  			case rune:
    82  				buf.WriteRune(v)
    83  			default:
    84  				panic(fmt.Sprintf("unexpected TrackedBuffer type %T", v))
    85  			}
    86  		case 's':
    87  			switch v := values[fieldnum].(type) {
    88  			case []byte:
    89  				buf.Write(v)
    90  			case string:
    91  				buf.WriteString(v)
    92  			default:
    93  				panic(fmt.Sprintf("unexpected TrackedBuffer type %T", v))
    94  			}
    95  		case 'v':
    96  			node := values[fieldnum].(SQLNode)
    97  			if buf.nodeFormatter == nil {
    98  				node.Format(buf)
    99  			} else {
   100  				buf.nodeFormatter(buf, node)
   101  			}
   102  		case 'a':
   103  			buf.WriteArg(values[fieldnum].(string))
   104  		default:
   105  			panic("unexpected")
   106  		}
   107  		fieldnum++
   108  		i++
   109  	}
   110  }
   111  
   112  // WriteArg writes a value argument into the buffer along with
   113  // tracking information for future substitutions. arg must contain
   114  // the ":" or "::" prefix.
   115  func (buf *TrackedBuffer) WriteArg(arg string) {
   116  	buf.bindLocations = append(buf.bindLocations, bindLocation{
   117  		offset: buf.Len(),
   118  		length: len(arg),
   119  	})
   120  	buf.WriteString(arg)
   121  }
   122  
   123  // ParsedQuery returns a ParsedQuery that contains bind
   124  // locations for easy substitution.
   125  func (buf *TrackedBuffer) ParsedQuery() *ParsedQuery {
   126  	return &ParsedQuery{Query: buf.String(), bindLocations: buf.bindLocations}
   127  }
   128  
   129  // HasBindVars returns true if the parsed query uses bind vars.
   130  func (buf *TrackedBuffer) HasBindVars() bool {
   131  	return len(buf.bindLocations) != 0
   132  }
   133  
   134  // BuildParsedQuery builds a ParsedQuery from the input.
   135  func BuildParsedQuery(in string, vars ...interface{}) *ParsedQuery {
   136  	buf := NewTrackedBuffer(nil)
   137  	buf.Myprintf(in, vars...)
   138  	return buf.ParsedQuery()
   139  }