github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/elasticsql/elasticsql.go (about)

     1  package elasticsql
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"strings"
     7  
     8  	"github.com/bingoohuang/gg/pkg/sqlparse/sqlparser"
     9  	"github.com/bingoohuang/gg/pkg/ss"
    10  )
    11  
    12  // Convert will transform sql to elasticsearch dsl string
    13  func Convert(sql string) (dsl string, err error) {
    14  	switch firstWord := strings.ToLower(ss.FirstWord(sql)); firstWord {
    15  	case "update", "delete", "insert":
    16  		return "", errors.New("unsupported")
    17  	case "limit", "order", "where":
    18  		sql = "select * from t " + sql
    19  	default:
    20  		sql = "select * from t where " + sql
    21  	}
    22  
    23  	stmt, err := sqlparser.Parse(sql)
    24  	if err != nil {
    25  		return "", err
    26  	}
    27  
    28  	// sql valid, start to handle
    29  	switch t := stmt.(type) {
    30  	case *sqlparser.Select:
    31  		return handleSelect(t)
    32  	default:
    33  		return "", errors.New("unsupported")
    34  	}
    35  }
    36  
    37  func handleSelect(sel *sqlparser.Select) (dsl string, err error) {
    38  	// Handle where
    39  	// top level node pass in an empty interface
    40  	// to tell the children this is root
    41  	// is there any better way?
    42  	var rootParent sqlparser.Expr
    43  	var queryMapStr string
    44  
    45  	// use may not pass where clauses
    46  	if sel.Where != nil {
    47  		queryMapStr, err = handleSelectWhere(&sel.Where.Expr, true, &rootParent)
    48  		if err != nil {
    49  			return "", err
    50  		}
    51  	}
    52  	if queryMapStr == "" {
    53  		queryMapStr = `{"bool" : {"must": [{"match_all" : {}}]}}`
    54  	}
    55  
    56  	queryFrom, querySize := "", ""
    57  
    58  	// if the request is to aggregation
    59  	// then set aggFlag to true, and querySize to 0
    60  	// to not return any query result
    61  
    62  	sel.GroupBy = nil
    63  	sel.SelectExprs = nil
    64  
    65  	// Handle limit
    66  	if sel.Limit != nil {
    67  		if sel.Limit.Offset != nil {
    68  			queryFrom = sqlparser.String(sel.Limit.Offset)
    69  		}
    70  		querySize = sqlparser.String(sel.Limit.Rowcount)
    71  	}
    72  
    73  	// Handle order by
    74  	// when executing aggregations, order by is useless
    75  	var orderByArr []string
    76  	for _, orderByExpr := range sel.OrderBy {
    77  		s := strings.Replace(sqlparser.String(orderByExpr.Expr), "`", "", -1)
    78  		orderByStr := fmt.Sprintf(`{"%v": "%v"}`, s, orderByExpr.Direction)
    79  		orderByArr = append(orderByArr, orderByStr)
    80  	}
    81  
    82  	resultMap := map[string]interface{}{"query": queryMapStr}
    83  
    84  	if querySize != "" {
    85  		resultMap["size"] = ss.ParseInt(querySize)
    86  	}
    87  	if queryFrom != "" {
    88  		resultMap["from"] = ss.ParseInt(queryFrom)
    89  	}
    90  
    91  	if len(orderByArr) > 0 {
    92  		resultMap["sort"] = fmt.Sprintf("[%v]", strings.Join(orderByArr, ","))
    93  	}
    94  
    95  	// keep the traversal in order, avoid unpredicted json
    96  	var resultArr []string
    97  	for _, mapKey := range []string{"query", "from", "size", "sort"} {
    98  		if val, ok := resultMap[mapKey]; ok {
    99  			resultArr = append(resultArr, fmt.Sprintf(`"%v" : %v`, mapKey, val))
   100  		}
   101  	}
   102  
   103  	dsl = "{" + strings.Join(resultArr, ",") + "}"
   104  	return dsl, nil
   105  }
   106  
   107  func buildNestedFuncStrValue(nestedFunc *sqlparser.FuncExpr) (string, error) {
   108  	return "", errors.New("elasticsql: unsupported function" + nestedFunc.Name.String())
   109  }
   110  
   111  func handleSelectWhereAndExpr(expr *sqlparser.Expr, parent *sqlparser.Expr) (string, error) {
   112  	andExpr := (*expr).(*sqlparser.AndExpr)
   113  	leftExpr := andExpr.Left
   114  	rightExpr := andExpr.Right
   115  	leftStr, err := handleSelectWhere(&leftExpr, false, expr)
   116  	if err != nil {
   117  		return "", err
   118  	}
   119  	rightStr, err := handleSelectWhere(&rightExpr, false, expr)
   120  	if err != nil {
   121  		return "", err
   122  	}
   123  
   124  	// not toplevel
   125  	// if the parent node is also and, then the result can be merged
   126  
   127  	var resultStr string
   128  	if leftStr == "" || rightStr == "" {
   129  		resultStr = leftStr + rightStr
   130  	} else {
   131  		resultStr = leftStr + `,` + rightStr
   132  	}
   133  
   134  	if _, ok := (*parent).(*sqlparser.AndExpr); ok {
   135  		return resultStr, nil
   136  	}
   137  	return fmt.Sprintf(`{"bool" : {"must" : [%v]}}`, resultStr), nil
   138  }
   139  
   140  func handleSelectWhereOrExpr(expr *sqlparser.Expr, parent *sqlparser.Expr) (string, error) {
   141  	orExpr := (*expr).(*sqlparser.OrExpr)
   142  	leftExpr := orExpr.Left
   143  	rightExpr := orExpr.Right
   144  
   145  	leftStr, err := handleSelectWhere(&leftExpr, false, expr)
   146  	if err != nil {
   147  		return "", err
   148  	}
   149  
   150  	rightStr, err := handleSelectWhere(&rightExpr, false, expr)
   151  	if err != nil {
   152  		return "", err
   153  	}
   154  
   155  	var resultStr string
   156  	if leftStr == "" || rightStr == "" {
   157  		resultStr = leftStr + rightStr
   158  	} else {
   159  		resultStr = leftStr + `,` + rightStr
   160  	}
   161  
   162  	// not toplevel
   163  	// if the parent node is also or node, then merge the query param
   164  	if _, ok := (*parent).(*sqlparser.OrExpr); ok {
   165  		return resultStr, nil
   166  	}
   167  
   168  	return fmt.Sprintf(`{"bool" : {"should" : [%v]}}`, resultStr), nil
   169  }
   170  
   171  func buildComparisonExprRightStr(expr sqlparser.Expr) (string, bool, error) {
   172  	var rightStr string
   173  	var err error
   174  	switch expr.(type) {
   175  	case *sqlparser.SQLVal:
   176  		rightStr = sqlparser.String(expr)
   177  		rightStr = strings.Trim(rightStr, `'`)
   178  	case *sqlparser.GroupConcatExpr:
   179  		return "", false, errors.New("elasticsql: group_concat not supported")
   180  	case *sqlparser.FuncExpr:
   181  		// parse nested
   182  		funcExpr := expr.(*sqlparser.FuncExpr)
   183  		rightStr, err = buildNestedFuncStrValue(funcExpr)
   184  		if err != nil {
   185  			return "", false, err
   186  		}
   187  	case *sqlparser.ColName:
   188  		if sqlparser.String(expr) == "missing" {
   189  			return "", true, nil
   190  		}
   191  
   192  		return "", true, errors.New("elasticsql: column name on the right side of compare operator is not supported")
   193  	case sqlparser.ValTuple:
   194  		rightStr = sqlparser.String(expr)
   195  	default:
   196  		// cannot reach here
   197  	}
   198  	return rightStr, false, err
   199  }
   200  
   201  func handleSelectWhereComparisonExpr(expr *sqlparser.Expr, topLevel bool, parent *sqlparser.Expr) (string, error) {
   202  	comparisonExpr := (*expr).(*sqlparser.ComparisonExpr)
   203  	colName, ok := comparisonExpr.Left.(*sqlparser.ColName)
   204  
   205  	if !ok {
   206  		return "", errors.New("elasticsql: invalid comparison expression, the left must be a column name")
   207  	}
   208  
   209  	colNameStr := sqlparser.String(colName)
   210  	colNameStr = strings.Replace(colNameStr, "`", "", -1)
   211  	colNameStr = strings.ToLower(colNameStr)
   212  	rightStr, missingCheck, err := buildComparisonExprRightStr(comparisonExpr.Right)
   213  	if err != nil {
   214  		return "", err
   215  	}
   216  
   217  	resultStr := ""
   218  
   219  	switch comparisonExpr.Operator {
   220  	case ">=":
   221  		resultStr = fmt.Sprintf(`{"range" : {"%v" : {"from" : "%v"}}}`, colNameStr, rightStr)
   222  	case "<=":
   223  		resultStr = fmt.Sprintf(`{"range" : {"%v" : {"to" : "%v"}}}`, colNameStr, rightStr)
   224  	case "=":
   225  		// field is missing
   226  		if missingCheck {
   227  			resultStr = fmt.Sprintf(`{"missing":{"field":"%v"}}`, colNameStr)
   228  		} else {
   229  			resultStr = fmt.Sprintf(`{"match" : {"%v" : {"query" : "%v"}}}`, colNameStr, rightStr)
   230  		}
   231  	case ">":
   232  		resultStr = fmt.Sprintf(`{"range" : {"%v" : {"gt" : "%v"}}}`, colNameStr, rightStr)
   233  	case "<":
   234  		resultStr = fmt.Sprintf(`{"range" : {"%v" : {"lt" : "%v"}}}`, colNameStr, rightStr)
   235  	case "!=":
   236  		if missingCheck {
   237  			resultStr = fmt.Sprintf(`{"bool" : {"must_not" : [{"missing":{"field":"%v"}}]}}`, colNameStr)
   238  		} else {
   239  			resultStr = fmt.Sprintf(`{"bool" : {"must_not" : [{"match" : {"%v" : {"query" : "%v"}}}]}}`, colNameStr, rightStr)
   240  		}
   241  	case "in":
   242  		// the default valTuple is ('1', '2', '3') like
   243  		// so need to drop the () and replace ' to "
   244  		rightStr = strings.Replace(rightStr, `'`, `"`, -1)
   245  		rightStr = strings.Trim(rightStr, "(")
   246  		rightStr = strings.Trim(rightStr, ")")
   247  		resultStr = fmt.Sprintf(`{"terms" : {"%v" : [%v]}}`, colNameStr, rightStr)
   248  	case "like":
   249  		rightStr = strings.Replace(rightStr, `%`, ``, -1)
   250  		resultStr = fmt.Sprintf(`{"match" : {"%v" : {"query" : "%v"}}}`, colNameStr, rightStr)
   251  	case "not like":
   252  		rightStr = strings.Replace(rightStr, `%`, ``, -1)
   253  		resultStr = fmt.Sprintf(`{"bool" : {"must_not" : {"match" : {"%v" : {"query" : "%v"}}}}}`, colNameStr, rightStr)
   254  	case "not in":
   255  		// the default valTuple is ('1', '2', '3') like
   256  		// so need to drop the () and replace ' to "
   257  		rightStr = strings.Replace(rightStr, `'`, `"`, -1)
   258  		rightStr = strings.Trim(rightStr, "(")
   259  		rightStr = strings.Trim(rightStr, ")")
   260  		resultStr = fmt.Sprintf(`{"bool" : {"must_not" : {"terms" : {"%v" : [%v]}}}}`, colNameStr, rightStr)
   261  	}
   262  
   263  	// the root node need to have bool and must
   264  	if topLevel {
   265  		resultStr = fmt.Sprintf(`{"bool" : {"must" : [%v]}}`, resultStr)
   266  	}
   267  
   268  	return resultStr, nil
   269  }
   270  
   271  func handleSelectWhere(expr *sqlparser.Expr, topLevel bool, parent *sqlparser.Expr) (string, error) {
   272  	if expr == nil {
   273  		return "", errors.New("elasticsql: error expression cannot be nil here")
   274  	}
   275  
   276  	switch e := (*expr).(type) {
   277  	case *sqlparser.AndExpr:
   278  		return handleSelectWhereAndExpr(expr, parent)
   279  
   280  	case *sqlparser.OrExpr:
   281  		return handleSelectWhereOrExpr(expr, parent)
   282  	case *sqlparser.ComparisonExpr:
   283  		return handleSelectWhereComparisonExpr(expr, topLevel, parent)
   284  
   285  	case *sqlparser.IsExpr:
   286  		return "", errors.New("elasticsql: is expression currently not supported")
   287  	case *sqlparser.RangeCond:
   288  		// between a and b
   289  		// the meaning is equal to range query
   290  		rangeCond := (*expr).(*sqlparser.RangeCond)
   291  		colName, ok := rangeCond.Left.(*sqlparser.ColName)
   292  
   293  		if !ok {
   294  			return "", errors.New("elasticsql: range column name missing")
   295  		}
   296  
   297  		colNameStr := sqlparser.String(colName)
   298  		colNameStr = strings.Replace(colNameStr, "`", "", -1)
   299  		fromStr := strings.Trim(sqlparser.String(rangeCond.From), `'`)
   300  		toStr := strings.Trim(sqlparser.String(rangeCond.To), `'`)
   301  
   302  		resultStr := fmt.Sprintf(`{"range" : {"%v" : {"from" : "%v", "to" : "%v"}}}`, colNameStr, fromStr, toStr)
   303  		if topLevel {
   304  			resultStr = fmt.Sprintf(`{"bool" : {"must" : [%v]}}`, resultStr)
   305  		}
   306  
   307  		return resultStr, nil
   308  
   309  	case *sqlparser.ParenExpr:
   310  		parentBoolExpr := (*expr).(*sqlparser.ParenExpr)
   311  		boolExpr := parentBoolExpr.Expr
   312  
   313  		// if paren is the top level, bool must is needed
   314  		isThisTopLevel := false
   315  		if topLevel {
   316  			isThisTopLevel = true
   317  		}
   318  		return handleSelectWhere(&boolExpr, isThisTopLevel, parent)
   319  	case *sqlparser.NotExpr:
   320  		return "", errors.New("elasticsql: not expression currently not supported")
   321  	case *sqlparser.FuncExpr:
   322  		switch e.Name.Lowered() {
   323  		case "multi_match":
   324  			params := e.Exprs
   325  			if len(params) > 3 || len(params) < 2 {
   326  				return "", errors.New("elasticsql: the multi_match must have 2 or 3 params, (query, fields and type) or (query, fields)")
   327  			}
   328  
   329  			var typ, query, fields string
   330  			for i := 0; i < len(params); i++ {
   331  				elem := strings.Replace(sqlparser.String(params[i]), "`", "", -1) // a = b
   332  				kv := strings.Split(elem, "=")
   333  				if len(kv) != 2 {
   334  					return "", errors.New("elasticsql: the param should be query = xxx, field = yyy, type = zzz")
   335  				}
   336  				k, v := strings.TrimSpace(kv[0]), strings.TrimSpace(kv[1])
   337  				switch k {
   338  				case "type":
   339  					typ = strings.Replace(v, "'", "", -1)
   340  				case "query":
   341  					query = strings.Replace(v, "`", "", -1)
   342  					query = strings.Replace(query, "'", "", -1)
   343  				case "fields":
   344  					fieldList := strings.Split(strings.TrimRight(strings.TrimLeft(v, "("), ")"), ",")
   345  					for idx, field := range fieldList {
   346  						fieldList[idx] = fmt.Sprintf(`"%v"`, strings.TrimSpace(field))
   347  					}
   348  					fields = strings.Join(fieldList, ",")
   349  				default:
   350  					return "", errors.New("elaticsql: unknow param for multi_match")
   351  				}
   352  			}
   353  			if typ == "" {
   354  				return fmt.Sprintf(`{"multi_match" : {"query" : "%v", "fields" : [%v]}}`, query, fields), nil
   355  			}
   356  			return fmt.Sprintf(`{"multi_match" : {"query" : "%v", "type" : "%v", "fields" : [%v]}}`, query, typ, fields), nil
   357  		default:
   358  			return "", errors.New("elaticsql: function in where not supported" + e.Name.Lowered())
   359  		}
   360  	}
   361  
   362  	return "", errors.New("elaticsql: logically cannot reached here")
   363  }