github.com/matrixorigin/matrixone@v1.2.0/pkg/sql/parsers/sqlparse.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package parsers
    16  
    17  import (
    18  	"context"
    19  	gotrace "runtime/trace"
    20  	"strings"
    21  
    22  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    23  	"github.com/matrixorigin/matrixone/pkg/sql/parsers/dialect"
    24  	"github.com/matrixorigin/matrixone/pkg/sql/parsers/dialect/mysql"
    25  	"github.com/matrixorigin/matrixone/pkg/sql/parsers/dialect/postgresql"
    26  	"github.com/matrixorigin/matrixone/pkg/sql/parsers/tree"
    27  )
    28  
    29  func Parse(ctx context.Context, dialectType dialect.DialectType, sql string, lower int64, useOrigin int64) ([]tree.Statement, error) {
    30  	_, task := gotrace.NewTask(context.TODO(), "parser.Parse")
    31  	defer task.End()
    32  	switch dialectType {
    33  	case dialect.MYSQL:
    34  		return mysql.Parse(ctx, sql, lower, useOrigin)
    35  	case dialect.POSTGRESQL:
    36  		return postgresql.Parse(ctx, sql)
    37  	default:
    38  		return nil, moerr.NewInternalError(ctx, "type of dialect error")
    39  	}
    40  }
    41  
    42  func ParseOne(ctx context.Context, dialectType dialect.DialectType, sql string, lower int64, useOrigin int64) (tree.Statement, error) {
    43  	switch dialectType {
    44  	case dialect.MYSQL:
    45  		return mysql.ParseOne(ctx, sql, lower, useOrigin)
    46  	case dialect.POSTGRESQL:
    47  		return postgresql.ParseOne(ctx, sql)
    48  	default:
    49  		return nil, moerr.NewInternalError(ctx, "type of dialect error")
    50  	}
    51  }
    52  
// Marker comments that can appear inside SQL text. Each marker is declared in
// two forms: the complete "/* ... */" comment, and the bare text between the
// comment delimiters (the *Content constants). The *Content constants are the
// keys of stripContents.
const (
	stripCloudUser           = "/* cloud_user */"
	stripCloudUserContent    = "cloud_user"
	stripCloudNonUser        = "/* cloud_nonuser */"
	stripCloudNonUserContent = "cloud_nonuser"
	stripSaveQuery           = "/* save_result */"
	stripSaveQueryContent    = "save_result"
)
    61  
// stripContents lists the comment contents that HandleSqlForRecord removes
// from a statement. It is used as a set: HandleSqlForRecord only checks
// whether a key is present and never reads the int8 value.
var stripContents = map[string]int8{
	stripCloudUserContent:    0,
	stripCloudNonUserContent: 0,
	stripSaveQueryContent:    0,
}
    67  
    68  var HandleSqlForRecord = func(sql string) []string {
    69  	split := SplitSqlBySemicolon(sql)
    70  	for i := range split {
    71  		stripScanner := mysql.NewScanner(dialect.MYSQL, split[i])
    72  		//strip needed comment "/*XXX*/"
    73  		var commentIdx [][]int
    74  		for stripScanner.Pos < len(split[i]) {
    75  			typ, comment := stripScanner.ScanComment()
    76  			if typ == mysql.COMMENT {
    77  				//only strip needed comment "/*XXX*/"
    78  				if strings.HasPrefix(comment, "/*") && strings.HasSuffix(comment, "*/") {
    79  					commentContent := strings.ToLower(strings.TrimSpace(comment[2 : len(comment)-2]))
    80  					if _, ok := stripContents[commentContent]; ok {
    81  						commentIdx = append(commentIdx, []int{stripScanner.Pos - len(comment), stripScanner.Pos})
    82  					}
    83  				}
    84  			} else if typ == mysql.EofChar() || typ == mysql.LEX_ERROR {
    85  				break
    86  			}
    87  		}
    88  
    89  		if len(commentIdx) > 0 {
    90  			var builder strings.Builder
    91  			for j := 0; j < len(commentIdx); j++ {
    92  				if j == 0 {
    93  					builder.WriteString(split[i][0:commentIdx[j][0]])
    94  				} else {
    95  					builder.WriteString(split[i][commentIdx[j-1][1]:commentIdx[j][0]])
    96  				}
    97  			}
    98  
    99  			builder.WriteString(split[i][commentIdx[len(commentIdx)-1][1]:len(split[i])])
   100  			split[i] = strings.TrimSpace(builder.String())
   101  		}
   102  
   103  		// Hide secret key for split[i],
   104  		// for example:
   105  		// before: create account nihao admin_name 'admin' identified with '123'
   106  		// after: create account nihao admin_name 'admin' identified with '******'
   107  
   108  		// Slice indexes helps to get the final ranges from split[i],
   109  		// for example:
   110  		// Secret keys' indexes ranges in split[i] are:
   111  		// 1, 2, 3, 3
   112  		// These mean [1, 2] and [3, 3] in split[i] are secret keys
   113  		// And if len(split[i]) is 10, then we get slice indexes:
   114  		// -1, 1, 2, 3, 3, 10
   115  		// These mean we need to get (-1, 1), (2, 3), (3, 10) from split[i]
   116  		scanner := mysql.NewScanner(dialect.MYSQL, split[i])
   117  		indexes := []int{-1}
   118  		eq := int('=')
   119  		for scanner.Pos < len(split[i]) {
   120  			typ, s := scanner.Scan()
   121  			if typ == mysql.IDENTIFIED {
   122  				typ, _ = scanner.Scan()
   123  				if typ == mysql.BY || typ == mysql.WITH {
   124  					typ, s = scanner.Scan()
   125  					if typ != mysql.RANDOM {
   126  						indexes = append(indexes, scanner.Pos-len(s)-1, scanner.Pos-2)
   127  					}
   128  				}
   129  			} else if strings.ToLower(s) == "access_key_id" || strings.ToLower(s) == "secret_access_key" {
   130  				typ, _ = scanner.Scan()
   131  				if typ == eq {
   132  					_, s = scanner.Scan()
   133  					indexes = append(indexes, scanner.Pos-len(s)-1, scanner.Pos-2)
   134  				}
   135  			}
   136  		}
   137  		indexes = append(indexes, len(split[i]))
   138  
   139  		if len(indexes) > 2 {
   140  			var builder strings.Builder
   141  			for j := 0; j < len(indexes); j += 2 {
   142  				builder.WriteString(split[i][indexes[j]+1 : indexes[j+1]])
   143  				if j < len(indexes)-2 {
   144  					builder.WriteString("******")
   145  				}
   146  			}
   147  			split[i] = builder.String()
   148  		}
   149  		split[i] = strings.TrimSpace(split[i])
   150  	}
   151  	return split
   152  }
   153  
   154  func SplitSqlBySemicolon(sql string) []string {
   155  	var ret []string
   156  	if len(sql) == 0 {
   157  		// case 1 : "" => [""]
   158  		return []string{sql}
   159  	}
   160  	scanner := mysql.NewScanner(dialect.MYSQL, sql)
   161  	lastEnd := 0
   162  	endWithSemicolon := false
   163  	for scanner.Pos < len(sql) {
   164  		typ, _ := scanner.Scan()
   165  		for scanner.Pos < len(sql) && typ != ';' {
   166  			typ, _ = scanner.Scan()
   167  		}
   168  		if typ == ';' {
   169  			ret = append(ret, sql[lastEnd:scanner.Pos-1])
   170  			lastEnd = scanner.Pos
   171  			endWithSemicolon = true
   172  		} else {
   173  			ret = append(ret, sql[lastEnd:scanner.Pos])
   174  			endWithSemicolon = false
   175  		}
   176  	}
   177  
   178  	if len(ret) == 0 {
   179  		//!!!NOTE there is at least one element in ret slice
   180  		panic("there is at least one element")
   181  	}
   182  	//handle whitespace characters in the front and end of the sql
   183  	for i := range ret {
   184  		ret[i] = strings.TrimSpace(ret[i])
   185  	}
   186  	// do nothing
   187  	//if len(ret) == 1 {
   188  	//	//case 1 : "   " => [""]
   189  	//	//case 2 : " abc " = > ["abc"]
   190  	//	//case 3 : " /* abc */  " = > ["/* abc */"]
   191  	//}
   192  	if len(ret) > 1 {
   193  		last := len(ret) - 1
   194  		if !endWithSemicolon && len(ret[last]) == 0 {
   195  			//case 3 : "abc;   " => ["abc"]
   196  			//if the last one is end empty, remove it
   197  			ret = ret[:last]
   198  		}
   199  		//case 4 : "abc; def; /* abc */  " => ["abc", "def", "/* abc */"]
   200  	}
   201  
   202  	return ret
   203  }