github.com/Azareal/Gosora@v0.0.0-20210729070923-553e66b59003/query_gen/utils.go (about)

     1  /*
     2  *
     3  *	Query Generator Library
     4  *	WIP Under Construction
     5  *	Copyright Azareal 2017 - 2020
     6  *
     7   */
     8  package qgen
     9  
    10  import (
    11  	"os"
    12  	"strings"
    13  	//"fmt"
    14  )
    15  
    16  // TODO: Add support for numbers and strings?
    17  func processColumns(colStr string) (columns []DBColumn) {
    18  	if colStr == "" {
    19  		return columns
    20  	}
    21  	colStr = strings.Replace(colStr, " as ", " AS ", -1)
    22  	for _, segment := range strings.Split(colStr, ",") {
    23  		var outCol DBColumn
    24  		dotHalves := strings.Split(strings.TrimSpace(segment), ".")
    25  
    26  		var halves []string
    27  		if len(dotHalves) == 2 {
    28  			outCol.Table = dotHalves[0]
    29  			halves = strings.Split(dotHalves[1], " AS ")
    30  		} else {
    31  			halves = strings.Split(dotHalves[0], " AS ")
    32  		}
    33  
    34  		halves[0] = strings.TrimSpace(halves[0])
    35  		if len(halves) == 2 {
    36  			outCol.Alias = strings.TrimSpace(halves[1])
    37  		}
    38  		//fmt.Printf("halves: %+v\n", halves)
    39  		//fmt.Printf("halves[0]: %+v\n", halves[0])
    40  		switch {
    41  		case halves[0][0] == '(':
    42  			outCol.Type = TokenScope
    43  			outCol.Table = ""
    44  		case halves[0][len(halves[0])-1] == ')':
    45  			outCol.Type = TokenFunc
    46  		case halves[0] == "?":
    47  			outCol.Type = TokenSub
    48  		default:
    49  			outCol.Type = TokenColumn
    50  		}
    51  
    52  		outCol.Left = halves[0]
    53  		columns = append(columns, outCol)
    54  	}
    55  	return columns
    56  }
    57  
    58  // TODO: Allow order by statements without a direction
    59  func processOrderby(orderStr string) (order []DBOrder) {
    60  	if orderStr == "" {
    61  		return order
    62  	}
    63  	for _, segment := range strings.Split(orderStr, ",") {
    64  		var outOrder DBOrder
    65  		halves := strings.Split(strings.TrimSpace(segment), " ")
    66  		if len(halves) != 2 {
    67  			continue
    68  		}
    69  		outOrder.Column = halves[0]
    70  		outOrder.Order = strings.ToLower(halves[1])
    71  		order = append(order, outOrder)
    72  	}
    73  	return order
    74  }
    75  
    76  func processJoiner(joinStr string) (joiner []DBJoiner) {
    77  	if joinStr == "" {
    78  		return joiner
    79  	}
    80  	joinStr = strings.Replace(joinStr, " on ", " ON ", -1)
    81  	joinStr = strings.Replace(joinStr, " and ", " AND ", -1)
    82  	for _, segment := range strings.Split(joinStr, " AND ") {
    83  		var outJoin DBJoiner
    84  		var parseOffset int
    85  		var left, right string
    86  
    87  		left, parseOffset = getIdentifier(segment, parseOffset)
    88  		outJoin.Operator, parseOffset = getOperator(segment, parseOffset+1)
    89  		right, parseOffset = getIdentifier(segment, parseOffset+1)
    90  
    91  		leftColumn := strings.Split(left, ".")
    92  		rightColumn := strings.Split(right, ".")
    93  		outJoin.LeftTable = strings.TrimSpace(leftColumn[0])
    94  		outJoin.RightTable = strings.TrimSpace(rightColumn[0])
    95  		outJoin.LeftColumn = strings.TrimSpace(leftColumn[1])
    96  		outJoin.RightColumn = strings.TrimSpace(rightColumn[1])
    97  
    98  		joiner = append(joiner, outJoin)
    99  	}
   100  	return joiner
   101  }
   102  
   103  func (wh *DBWhere) parseNumber(seg string, i int) int {
   104  	//var buffer string
   105  	si := i
   106  	l := 0
   107  	for ; i < len(seg); i++ {
   108  		ch := seg[i]
   109  		if '0' <= ch && ch <= '9' {
   110  			//buffer += string(ch)
   111  			l++
   112  		} else {
   113  			i--
   114  			var str string
   115  			if l != 0 {
   116  				str = seg[si : si+l]
   117  			}
   118  			wh.Expr = append(wh.Expr, DBToken{str, TokenNumber})
   119  			return i
   120  		}
   121  	}
   122  	return i
   123  }
   124  
   125  func (wh *DBWhere) parseColumn(seg string, i int) int {
   126  	//var buffer string
   127  	si := i
   128  	l := 0
   129  	for ; i < len(seg); i++ {
   130  		ch := seg[i]
   131  		switch {
   132  		case ('a' <= ch && ch <= 'z') || ('A' <= ch && ch <= 'Z') || ch == '.' || ch == '_':
   133  			//buffer += string(ch)
   134  			l++
   135  		case ch == '(':
   136  			var str string
   137  			if l != 0 {
   138  				str = seg[si : si+l]
   139  			}
   140  			return wh.parseFunction(seg, str, i)
   141  		default:
   142  			i--
   143  			var str string
   144  			if l != 0 {
   145  				str = seg[si : si+l]
   146  			}
   147  			wh.Expr = append(wh.Expr, DBToken{str, TokenColumn})
   148  			return i
   149  		}
   150  	}
   151  	return i
   152  }
   153  
   154  func (wh *DBWhere) parseFunction(seg, buffer string, i int) int {
   155  	preI := i
   156  	i = skipFunctionCall(seg, i-1)
   157  	buffer += seg[preI:i] + string(seg[i])
   158  	wh.Expr = append(wh.Expr, DBToken{buffer, TokenFunc})
   159  	return i
   160  }
   161  
   162  func (wh *DBWhere) parseString(seg string, i int) int {
   163  	//var buffer string
   164  	i++
   165  	si := i
   166  	l := 0
   167  	for ; i < len(seg); i++ {
   168  		ch := seg[i]
   169  		if ch != '\'' {
   170  			//buffer += string(ch)
   171  			l++
   172  		} else {
   173  			var str string
   174  			if l != 0 {
   175  				str = seg[si : si+l]
   176  			}
   177  			wh.Expr = append(wh.Expr, DBToken{str, TokenString})
   178  			return i
   179  		}
   180  	}
   181  	return i
   182  }
   183  
   184  func (wh *DBWhere) parseOperator(seg string, i int) int {
   185  	//var buffer string
   186  	si := i
   187  	l := 0
   188  	for ; i < len(seg); i++ {
   189  		ch := seg[i]
   190  		if isOpByte(ch) {
   191  			//buffer += string(ch)
   192  			l++
   193  		} else {
   194  			i--
   195  			var str string
   196  			if l != 0 {
   197  				str = seg[si : si+l]
   198  			}
   199  			wh.Expr = append(wh.Expr, DBToken{str, TokenOp})
   200  			return i
   201  		}
   202  	}
   203  	return i
   204  }
   205  
   206  // TODO: Make this case insensitive
   207  func normalizeAnd(in string) string {
   208  	in = strings.Replace(in, " and ", " AND ", -1)
   209  	return strings.Replace(in, " && ", " AND ", -1)
   210  }
   211  func normalizeOr(in string) string {
   212  	in = strings.Replace(in, " or ", " OR ", -1)
   213  	return strings.Replace(in, " || ", " OR ", -1)
   214  }
   215  
   216  // TODO: Write tests for this
   217  func processWhere(whereStr string) (where []DBWhere) {
   218  	if whereStr == "" {
   219  		return where
   220  	}
   221  	whereStr = normalizeAnd(whereStr)
   222  	whereStr = normalizeOr(whereStr)
   223  
   224  	for _, seg := range strings.Split(whereStr, " AND ") {
   225  		tmpWhere := &DBWhere{[]DBToken{}}
   226  		seg += ")"
   227  		for i := 0; i < len(seg); i++ {
   228  			ch := seg[i]
   229  			switch {
   230  			case '0' <= ch && ch <= '9':
   231  				i = tmpWhere.parseNumber(seg, i)
   232  			// TODO: Sniff the third byte offset from char or it's non-existent to avoid matching uppercase strings which start with OR
   233  			case ch == 'O' && (i+1) < len(seg) && seg[i+1] == 'R':
   234  				tmpWhere.Expr = append(tmpWhere.Expr, DBToken{"OR", TokenOr})
   235  				i += 1
   236  			case ch == 'N' && (i+2) < len(seg) && seg[i+1] == 'O' && seg[i+2] == 'T':
   237  				tmpWhere.Expr = append(tmpWhere.Expr, DBToken{"NOT", TokenNot})
   238  				i += 2
   239  			case ch == 'L' && (i+3) < len(seg) && seg[i+1] == 'I' && seg[i+2] == 'K' && seg[i+3] == 'E':
   240  				tmpWhere.Expr = append(tmpWhere.Expr, DBToken{"LIKE", TokenLike})
   241  				i += 3
   242  			case ('a' <= ch && ch <= 'z') || ('A' <= ch && ch <= 'Z') || ch == '_':
   243  				i = tmpWhere.parseColumn(seg, i)
   244  			case ch == '\'':
   245  				i = tmpWhere.parseString(seg, i)
   246  			case ch == ')' && i < (len(seg)-1):
   247  				tmpWhere.Expr = append(tmpWhere.Expr, DBToken{")", TokenOp})
   248  			case isOpByte(ch):
   249  				i = tmpWhere.parseOperator(seg, i)
   250  			case ch == '?':
   251  				tmpWhere.Expr = append(tmpWhere.Expr, DBToken{"?", TokenSub})
   252  			}
   253  		}
   254  		where = append(where, *tmpWhere)
   255  	}
   256  	return where
   257  }
   258  
   259  func (set *DBSetter) parseNumber(seg string, i int) int {
   260  	//var buffer string
   261  	si := i
   262  	l := 0
   263  	for ; i < len(seg); i++ {
   264  		ch := seg[i]
   265  		if '0' <= ch && ch <= '9' {
   266  			//buffer += string(ch)
   267  			l++
   268  		} else {
   269  			var str string
   270  			if l != 0 {
   271  				str = seg[si : si+l]
   272  			}
   273  			set.Expr = append(set.Expr, DBToken{str, TokenNumber})
   274  			return i
   275  		}
   276  	}
   277  	return i
   278  }
   279  
   280  func (set *DBSetter) parseColumn(seg string, i int) int {
   281  	//var buffer string
   282  	si := i
   283  	l := 0
   284  	for ; i < len(seg); i++ {
   285  		ch := seg[i]
   286  		switch {
   287  		case ('a' <= ch && ch <= 'z') || ('A' <= ch && ch <= 'Z') || ch == '_':
   288  			//buffer += string(ch)
   289  			l++
   290  		case ch == '(':
   291  			var str string
   292  			if l != 0 {
   293  				str = seg[si : si+l]
   294  			}
   295  			return set.parseFunction(seg, str, i)
   296  		default:
   297  			i--
   298  			var str string
   299  			if l != 0 {
   300  				str = seg[si : si+l]
   301  			}
   302  			set.Expr = append(set.Expr, DBToken{str, TokenColumn})
   303  			return i
   304  		}
   305  	}
   306  	return i
   307  }
   308  
   309  func (set *DBSetter) parseFunction(segment, buffer string, i int) int {
   310  	preI := i
   311  	i = skipFunctionCall(segment, i-1)
   312  	buffer += segment[preI:i] + string(segment[i])
   313  	set.Expr = append(set.Expr, DBToken{buffer, TokenFunc})
   314  	return i
   315  }
   316  
   317  func (set *DBSetter) parseString(seg string, i int) int {
   318  	//var buffer string
   319  	i++
   320  	si := i
   321  	l := 0
   322  	for ; i < len(seg); i++ {
   323  		ch := seg[i]
   324  		if ch != '\'' {
   325  			//buffer += string(ch)
   326  			l++
   327  		} else {
   328  			var str string
   329  			if l != 0 {
   330  				str = seg[si : si+l]
   331  			}
   332  			set.Expr = append(set.Expr, DBToken{str, TokenString})
   333  			return i
   334  		}
   335  	}
   336  	return i
   337  }
   338  
   339  func (set *DBSetter) parseOperator(seg string, i int) int {
   340  	//var buffer string
   341  	si := i
   342  	l := 0
   343  	for ; i < len(seg); i++ {
   344  		ch := seg[i]
   345  		if isOpByte(ch) {
   346  			//buffer += string(ch)
   347  			l++
   348  		} else {
   349  			i--
   350  			var str string
   351  			if l != 0 {
   352  				str = seg[si : si+l]
   353  			}
   354  			set.Expr = append(set.Expr, DBToken{str, TokenOp})
   355  			return i
   356  		}
   357  	}
   358  	return i
   359  }
   360  
   361  func processSet(setstr string) (setter []DBSetter) {
   362  	if setstr == "" {
   363  		return setter
   364  	}
   365  
   366  	// First pass, splitting the string by commas while ignoring the innards of functions
   367  	var setset []string
   368  	var buffer string
   369  	var lastItem int
   370  	setstr += ","
   371  	for i := 0; i < len(setstr); i++ {
   372  		if setstr[i] == '(' {
   373  			i = skipFunctionCall(setstr, i-1)
   374  			setset = append(setset, setstr[lastItem:i+1])
   375  			buffer = ""
   376  			lastItem = i + 2
   377  		} else if setstr[i] == ',' && buffer != "" {
   378  			setset = append(setset, buffer)
   379  			buffer = ""
   380  			lastItem = i + 1
   381  		} else if (setstr[i] > 32) && setstr[i] != ',' && setstr[i] != ')' {
   382  			buffer += string(setstr[i])
   383  		}
   384  	}
   385  
   386  	// Second pass. Break this setitem into manageable chunks
   387  	for _, setitem := range setset {
   388  		halves := strings.Split(setitem, "=")
   389  		if len(halves) != 2 {
   390  			continue
   391  		}
   392  		tmpSetter := &DBSetter{Column: strings.TrimSpace(halves[0])}
   393  		segment := halves[1] + ")"
   394  
   395  		for i := 0; i < len(segment); i++ {
   396  			ch := segment[i]
   397  			switch {
   398  			case '0' <= ch && ch <= '9':
   399  				i = tmpSetter.parseNumber(segment, i)
   400  			case ('a' <= ch && ch <= 'z') || ('A' <= ch && ch <= 'Z') || ch == '_':
   401  				i = tmpSetter.parseColumn(segment, i)
   402  			case ch == '\'':
   403  				i = tmpSetter.parseString(segment, i)
   404  			case isOpByte(ch):
   405  				i = tmpSetter.parseOperator(segment, i)
   406  			case ch == '?':
   407  				tmpSetter.Expr = append(tmpSetter.Expr, DBToken{"?", TokenSub})
   408  			}
   409  		}
   410  		setter = append(setter, *tmpSetter)
   411  	}
   412  	return setter
   413  }
   414  
   415  func processLimit(limitStr string) (limit DBLimit) {
   416  	halves := strings.Split(limitStr, ",")
   417  	if len(halves) == 2 {
   418  		limit.Offset = halves[0]
   419  		limit.MaxCount = halves[1]
   420  	} else {
   421  		limit.MaxCount = halves[0]
   422  	}
   423  	return limit
   424  }
   425  
   426  func isOpByte(ch byte) bool {
   427  	return ch == '<' || ch == '>' || ch == '=' || ch == '!' || ch == '*' || ch == '%' || ch == '+' || ch == '-' || ch == '/' || ch == '(' || ch == ')'
   428  }
   429  
   430  func isOpRune(ch rune) bool {
   431  	return ch == '<' || ch == '>' || ch == '=' || ch == '!' || ch == '*' || ch == '%' || ch == '+' || ch == '-' || ch == '/' || ch == '(' || ch == ')'
   432  }
   433  
   434  func processFields(fieldStr string) (fields []DBField) {
   435  	if fieldStr == "" {
   436  		return fields
   437  	}
   438  	var buffer string
   439  	var lastItem int
   440  	fieldStr += ","
   441  	for i := 0; i < len(fieldStr); i++ {
   442  		ch := fieldStr[i]
   443  		if ch == '(' {
   444  			i = skipFunctionCall(fieldStr, i-1)
   445  			fields = append(fields, DBField{Name: fieldStr[lastItem : i+1], Type: getIdentifierType(fieldStr[lastItem : i+1])})
   446  			buffer = ""
   447  			lastItem = i + 2
   448  		} else if ch == ',' && buffer != "" {
   449  			fields = append(fields, DBField{Name: buffer, Type: getIdentifierType(buffer)})
   450  			buffer = ""
   451  			lastItem = i + 1
   452  		} else if (ch >= 32) && ch != ',' && ch != ')' {
   453  			buffer += string(ch)
   454  		}
   455  	}
   456  	return fields
   457  }
   458  
   459  func getIdentifierType(iden string) int {
   460  	if ('a' <= iden[0] && iden[0] <= 'z') || ('A' <= iden[0] && iden[0] <= 'Z') {
   461  		if iden[len(iden)-1] == ')' {
   462  			return IdenFunc
   463  		}
   464  		return IdenColumn
   465  	}
   466  	if iden[0] == '\'' || iden[0] == '"' {
   467  		return IdenString
   468  	}
   469  	return IdenLiteral
   470  }
   471  
   472  func getIdentifier(seg string, startOffset int) (out string, i int) {
   473  	seg = strings.TrimSpace(seg)
   474  	seg += " " // Avoid overflow bugs with slicing
   475  	for i = startOffset; i < len(seg); i++ {
   476  		ch := seg[i]
   477  		if ch == '(' {
   478  			i = skipFunctionCall(seg, i)
   479  			return strings.TrimSpace(seg[startOffset:i]), (i - 1)
   480  		}
   481  		if (ch == ' ' || isOpByte(ch)) && i != startOffset {
   482  			return strings.TrimSpace(seg[startOffset:i]), (i - 1)
   483  		}
   484  	}
   485  	return strings.TrimSpace(seg[startOffset:]), (i - 1)
   486  }
   487  
   488  func getOperator(seg string, startOffset int) (out string, i int) {
   489  	seg = strings.TrimSpace(seg)
   490  	seg += " " // Avoid overflow bugs with slicing
   491  	for i = startOffset; i < len(seg); i++ {
   492  		if !isOpByte(seg[i]) && i != startOffset {
   493  			return strings.TrimSpace(seg[startOffset:i]), (i - 1)
   494  		}
   495  	}
   496  	return strings.TrimSpace(seg[startOffset:]), (i - 1)
   497  }
   498  
   499  func skipFunctionCall(data string, index int) int {
   500  	var braceCount int
   501  	for ; index < len(data); index++ {
   502  		char := data[index]
   503  		if char == '(' {
   504  			braceCount++
   505  		} else if char == ')' {
   506  			braceCount--
   507  			if braceCount == 0 {
   508  				return index
   509  			}
   510  		}
   511  	}
   512  	return index
   513  }
   514  
   515  func writeFile(name, content string) (err error) {
   516  	f, err := os.Create(name)
   517  	if err != nil {
   518  		return err
   519  	}
   520  	_, err = f.WriteString(content)
   521  	if err != nil {
   522  		return err
   523  	}
   524  	err = f.Sync()
   525  	if err != nil {
   526  		return err
   527  	}
   528  	return f.Close()
   529  }