github.com/unionj-cloud/go-doudou/v2@v2.3.5/toolkit/gormgen/internal/generate/section.go (about)

     1  package generate
     2  
     3  import (
     4  	"fmt"
     5  	"strconv"
     6  	"strings"
     7  
     8  	"github.com/unionj-cloud/go-doudou/v2/toolkit/gormgen/internal/model"
     9  )
    10  
    11  // NewSection create and initialize Sections
    12  func NewSection() *Section {
    13  	return &Section{
    14  		ClauseTotal: map[model.Status]int{
    15  			model.WHERE: 0,
    16  			model.SET:   0,
    17  		},
    18  	}
    19  }
    20  
// Section holds the chunks a SQL template was split into, together with the
// cursor and bookkeeping used while those chunks are turned into clauses.
type Section struct {
	// members are the chunks, visited in order through currentIndex.
	members []section
	// Tmpls collects the generated code lines appended while parsing.
	Tmpls []string
	// currentIndex is the position of the cursor inside members.
	currentIndex int
	// ClauseTotal counts, per clause type, how many names GetName has handed out.
	ClauseTotal map[model.Status]int
	// forValue is pushed to when parseFor starts a loop and popped when that
	// loop reaches its END.
	forValue []ForRange
}
    29  
    30  // next return next section and increase index by 1
    31  func (s *Section) next() section {
    32  	if s.currentIndex < len(s.members)-1 {
    33  		s.currentIndex++
    34  		return s.members[s.currentIndex]
    35  	}
    36  	return section{Type: model.END}
    37  }
    38  
    39  // SubIndex take index one step back
    40  func (s *Section) SubIndex() {
    41  	s.currentIndex--
    42  }
    43  
    44  // HasMore is has more section
    45  func (s *Section) HasMore() bool {
    46  	return s.currentIndex < len(s.members)-1
    47  }
    48  
    49  // IsNull whether section is empty
    50  func (s *Section) IsNull() bool {
    51  	return len(s.members) == 0
    52  }
    53  
    54  // current return current section
    55  func (s *Section) current() section {
    56  	return s.members[s.currentIndex]
    57  }
    58  
    59  func (s *Section) appendTmpl(value string) {
    60  	s.Tmpls = append(s.Tmpls, value)
    61  }
    62  
    63  func (s *Section) hasSameName(value string) bool {
    64  	for _, p := range s.members {
    65  		if p.Type == model.FOR && p.ForRange.value == value {
    66  			return true
    67  		}
    68  	}
    69  	return false
    70  }
    71  
    72  // BuildSQL sql sections and append to tmpl, return a Clause array
    73  func (s *Section) BuildSQL() ([]Clause, error) {
    74  	if s.IsNull() {
    75  		return nil, fmt.Errorf("sql is null")
    76  	}
    77  	name := "generateSQL"
    78  	res := make([]Clause, 0, len(s.members))
    79  	for {
    80  		c := s.current()
    81  		switch c.Type {
    82  		case model.SQL, model.DATA, model.VARIABLE:
    83  			sqlClause := s.parseSQL(name)
    84  			res = append(res, sqlClause)
    85  			s.appendTmpl(sqlClause.Finish())
    86  		case model.IF:
    87  			ifClause, err := s.parseIF(name)
    88  			if err != nil {
    89  				return nil, err
    90  			}
    91  			res = append(res, ifClause)
    92  			s.appendTmpl(ifClause.Finish())
    93  		case model.WHERE:
    94  			whereClause, err := s.parseWhere()
    95  			if err != nil {
    96  				return nil, err
    97  			}
    98  			res = append(res, whereClause)
    99  			s.appendTmpl(whereClause.Finish(name))
   100  		case model.SET:
   101  			setClause, err := s.parseSet()
   102  			if err != nil {
   103  				return nil, err
   104  			}
   105  			res = append(res, setClause)
   106  			s.appendTmpl(setClause.Finish(name))
   107  		case model.TRIM:
   108  			trimClause, err := s.parseTrim()
   109  			if err != nil {
   110  				return nil, err
   111  			}
   112  			res = append(res, trimClause)
   113  			s.appendTmpl(trimClause.Finish(name))
   114  		case model.FOR:
   115  			forClause, err := s.parseFor(name)
   116  			_, _ = forClause, err
   117  			if err != nil {
   118  				return nil, err
   119  			}
   120  			res = append(res, forClause)
   121  			s.appendTmpl(forClause.Finish())
   122  		case model.END:
   123  		default:
   124  			return nil, fmt.Errorf("unknow clause:%s", c.Value)
   125  		}
   126  		if !s.HasMore() {
   127  			break
   128  		}
   129  		c = s.next()
   130  	}
   131  	return res, nil
   132  }
   133  
   134  // parseIF parse if clause
   135  func (s *Section) parseIF(name string) (res IfClause, err error) {
   136  	c := s.current()
   137  	res.slice = c
   138  
   139  	s.appendTmpl(res.Create())
   140  	if !s.HasMore() {
   141  		return
   142  	}
   143  	c = s.next()
   144  	for {
   145  		switch c.Type {
   146  		case model.SQL, model.DATA, model.VARIABLE:
   147  			sqlClause := s.parseSQL(name)
   148  			res.Value = append(res.Value, sqlClause)
   149  			s.appendTmpl(sqlClause.Finish())
   150  		case model.IF:
   151  			var ifClause IfClause
   152  			ifClause, err = s.parseIF(name)
   153  			if err != nil {
   154  				return
   155  			}
   156  			res.Value = append(res.Value, ifClause)
   157  			s.appendTmpl(ifClause.Finish())
   158  		case model.WHERE:
   159  			var whereClause WhereClause
   160  			whereClause, err = s.parseWhere()
   161  			if err != nil {
   162  				return
   163  			}
   164  			res.Value = append(res.Value, whereClause)
   165  			s.appendTmpl(whereClause.Finish(name))
   166  		case model.SET:
   167  			var setClause SetClause
   168  			setClause, err = s.parseSet()
   169  			if err != nil {
   170  				return
   171  			}
   172  			res.Value = append(res.Value, setClause)
   173  			s.appendTmpl(setClause.Finish(name))
   174  		case model.ELSE:
   175  			var elseClause ElseClause
   176  			elseClause, err = s.parseElSE(name)
   177  			if err != nil {
   178  				return
   179  			}
   180  			res.Value = append(res.Value, elseClause)
   181  		case model.FOR:
   182  			var forClause ForClause
   183  			forClause, err = s.parseFor(name)
   184  			if err != nil {
   185  				return
   186  			}
   187  			res.Value = append(res.Value, forClause)
   188  			s.appendTmpl(res.Finish())
   189  		case model.TRIM:
   190  			var trimClause TrimClause
   191  			trimClause, err = s.parseTrim()
   192  			if err != nil {
   193  				return
   194  			}
   195  			res.Value = append(res.Value, trimClause)
   196  			s.appendTmpl(trimClause.Finish(name))
   197  		case model.END:
   198  			return
   199  		default:
   200  			err = fmt.Errorf("unknow clause : %s", c.Value)
   201  			return
   202  		}
   203  		if !s.HasMore() {
   204  			break
   205  		}
   206  		c = s.next()
   207  	}
   208  	if c.isEnd() {
   209  		err = fmt.Errorf("incomplete SQL,if not end")
   210  	}
   211  	return
   212  }
   213  
   214  // parseElSE parse else clause, the clause' type must be one of if, where, set, SQL condition
   215  func (s *Section) parseElSE(name string) (res ElseClause, err error) {
   216  	res.slice = s.current()
   217  	s.appendTmpl(res.Create())
   218  
   219  	if !s.HasMore() {
   220  		return
   221  	}
   222  	c := s.next()
   223  	for {
   224  		switch c.Type {
   225  		case model.SQL, model.DATA, model.VARIABLE:
   226  			sqlClause := s.parseSQL(name)
   227  			res.Value = append(res.Value, sqlClause)
   228  			s.appendTmpl(sqlClause.Create())
   229  		case model.IF:
   230  			var ifClause IfClause
   231  			ifClause, err = s.parseIF(name)
   232  			if err != nil {
   233  				return
   234  			}
   235  			res.Value = append(res.Value, ifClause)
   236  			s.appendTmpl(ifClause.Finish())
   237  		case model.WHERE:
   238  			var whereClause WhereClause
   239  			whereClause, err = s.parseWhere()
   240  			if err != nil {
   241  				return
   242  			}
   243  			res.Value = append(res.Value, whereClause)
   244  			s.appendTmpl(whereClause.Finish(name))
   245  		case model.SET:
   246  			var setClause SetClause
   247  			setClause, err = s.parseSet()
   248  			if err != nil {
   249  				return
   250  			}
   251  			res.Value = append(res.Value, setClause)
   252  			s.appendTmpl(setClause.Finish(name))
   253  		case model.ELSE:
   254  			var elseClause ElseClause
   255  			elseClause, err = s.parseElSE(name)
   256  			if err != nil {
   257  				return
   258  			}
   259  			res.Value = append(res.Value, elseClause)
   260  		case model.FOR:
   261  			var forClause ForClause
   262  			forClause, err = s.parseFor(name)
   263  			if err != nil {
   264  				return
   265  			}
   266  			res.Value = append(res.Value, forClause)
   267  			s.appendTmpl(forClause.Finish())
   268  		case model.TRIM:
   269  			var trimClause TrimClause
   270  			trimClause, err = s.parseTrim()
   271  			if err != nil {
   272  				return
   273  			}
   274  			res.Value = append(res.Value, trimClause)
   275  			s.appendTmpl(trimClause.Finish(name))
   276  		default:
   277  			s.SubIndex()
   278  			return
   279  		}
   280  		if !s.HasMore() {
   281  			break
   282  		}
   283  		c = s.next()
   284  	}
   285  	return
   286  }
   287  
   288  // parseWhere parse where clause, the clause' type must be one of if, SQL condition
   289  func (s *Section) parseWhere() (res WhereClause, err error) {
   290  	c := s.current()
   291  	res.VarName = s.GetName(c.Type)
   292  	s.appendTmpl(res.Create())
   293  	res.Type = c.Type
   294  
   295  	if !s.HasMore() {
   296  		return
   297  	}
   298  	c = s.next()
   299  	for {
   300  		switch c.Type {
   301  		case model.SQL, model.DATA, model.VARIABLE:
   302  			sqlClause := s.parseSQL(res.VarName)
   303  			res.Value = append(res.Value, sqlClause)
   304  			s.appendTmpl(sqlClause.Finish())
   305  		case model.IF:
   306  			var ifClause IfClause
   307  			ifClause, err = s.parseIF(res.VarName)
   308  			if err != nil {
   309  				return
   310  			}
   311  			res.Value = append(res.Value, ifClause)
   312  			s.appendTmpl(ifClause.Finish())
   313  		case model.FOR:
   314  			var forClause ForClause
   315  			forClause, err = s.parseFor(res.VarName)
   316  			if err != nil {
   317  				return
   318  			}
   319  			res.Value = append(res.Value, forClause)
   320  			s.appendTmpl(forClause.Finish())
   321  		case model.WHERE:
   322  			var whereClause WhereClause
   323  			whereClause, err = s.parseWhere()
   324  			if err != nil {
   325  				return
   326  			}
   327  			res.Value = append(res.Value, whereClause)
   328  			s.appendTmpl(whereClause.Finish(res.VarName))
   329  		case model.TRIM:
   330  			var trimClause TrimClause
   331  			trimClause, err = s.parseTrim()
   332  			if err != nil {
   333  				return
   334  			}
   335  			res.Value = append(res.Value, trimClause)
   336  			s.appendTmpl(trimClause.Finish(res.VarName))
   337  		case model.END:
   338  			return
   339  		default:
   340  			err = fmt.Errorf("unknow clause : %s", c.Value)
   341  			return
   342  		}
   343  		if !s.HasMore() {
   344  			break
   345  		}
   346  		c = s.next()
   347  	}
   348  	if c.isEnd() {
   349  		return
   350  	}
   351  	err = fmt.Errorf("incomplete SQL,where not end")
   352  	return
   353  }
   354  
   355  // parseSet parse set clause, the clause' type must be one of if, SQL condition
   356  func (s *Section) parseSet() (res SetClause, err error) {
   357  	c := s.current()
   358  	res.VarName = s.GetName(c.Type)
   359  	s.appendTmpl(res.Create())
   360  	if !s.HasMore() {
   361  		return
   362  	}
   363  	c = s.next()
   364  
   365  	res.Type = c.Type
   366  	for {
   367  		switch c.Type {
   368  		case model.SQL, model.DATA, model.VARIABLE:
   369  			sqlClause := s.parseSQL(res.VarName)
   370  			res.Value = append(res.Value, sqlClause)
   371  			s.appendTmpl(sqlClause.Finish())
   372  		case model.IF:
   373  			var ifClause IfClause
   374  			ifClause, err = s.parseIF(res.VarName)
   375  			if err != nil {
   376  				return
   377  			}
   378  			res.Value = append(res.Value, ifClause)
   379  			s.appendTmpl(ifClause.Finish())
   380  		case model.FOR:
   381  			var forClause ForClause
   382  			forClause, err = s.parseFor(res.VarName)
   383  			if err != nil {
   384  				return
   385  			}
   386  			res.Value = append(res.Value, forClause)
   387  			s.appendTmpl(forClause.Finish())
   388  		case model.WHERE:
   389  			var whereClause WhereClause
   390  			whereClause, err = s.parseWhere()
   391  			if err != nil {
   392  				return
   393  			}
   394  			res.Value = append(res.Value, whereClause)
   395  			s.appendTmpl(whereClause.Finish(res.VarName))
   396  		case model.TRIM:
   397  			var trimClause TrimClause
   398  			trimClause, err = s.parseTrim()
   399  			if err != nil {
   400  				return
   401  			}
   402  			res.Value = append(res.Value, trimClause)
   403  			s.appendTmpl(trimClause.Finish(res.VarName))
   404  		case model.END:
   405  			return
   406  		default:
   407  			err = fmt.Errorf("unknow clause : %s", c.Value)
   408  			return
   409  		}
   410  		if !s.HasMore() {
   411  			break
   412  		}
   413  		c = s.next()
   414  	}
   415  	if c.isEnd() {
   416  		err = fmt.Errorf("incomplete SQL,set not end")
   417  	}
   418  	return
   419  }
   420  
   421  // parseTrim parse set clause, the clause' type must be one of if, SQL condition
   422  func (s *Section) parseTrim() (res TrimClause, err error) {
   423  	c := s.current()
   424  	res.VarName = s.GetName(c.Type)
   425  	s.appendTmpl(res.Create())
   426  	if !s.HasMore() {
   427  		return
   428  	}
   429  	c = s.next()
   430  
   431  	res.Type = c.Type
   432  	for {
   433  		switch c.Type {
   434  		case model.SQL, model.DATA, model.VARIABLE:
   435  			sqlClause := s.parseSQL(res.VarName)
   436  			res.Value = append(res.Value, sqlClause)
   437  			s.appendTmpl(sqlClause.Finish())
   438  		case model.IF:
   439  			var ifClause IfClause
   440  			ifClause, err = s.parseIF(res.VarName)
   441  			if err != nil {
   442  				return
   443  			}
   444  			res.Value = append(res.Value, ifClause)
   445  			s.appendTmpl(ifClause.Finish())
   446  		case model.FOR:
   447  			var forClause ForClause
   448  			forClause, err = s.parseFor(res.VarName)
   449  			if err != nil {
   450  				return
   451  			}
   452  			res.Value = append(res.Value, forClause)
   453  			s.appendTmpl(forClause.Finish())
   454  		case model.WHERE:
   455  			var whereClause WhereClause
   456  			whereClause, err = s.parseWhere()
   457  			if err != nil {
   458  				return
   459  			}
   460  			res.Value = append(res.Value, whereClause)
   461  			s.appendTmpl(whereClause.Finish(res.VarName))
   462  		case model.END:
   463  			return
   464  		default:
   465  			err = fmt.Errorf("unknow clause : %s", c.Value)
   466  			return
   467  		}
   468  		if !s.HasMore() {
   469  			break
   470  		}
   471  		c = s.next()
   472  	}
   473  	if c.isEnd() {
   474  		err = fmt.Errorf("incomplete SQL,set not end")
   475  	}
   476  	return
   477  }
   478  
   479  func (s *Section) parseFor(name string) (res ForClause, err error) {
   480  	c := s.current()
   481  	res.forSlice = c
   482  	s.appendTmpl(res.Create())
   483  	s.forValue = append(s.forValue, res.forSlice.ForRange)
   484  
   485  	if !s.HasMore() {
   486  		return
   487  	}
   488  	c = s.next()
   489  	for {
   490  		switch c.Type {
   491  		case model.SQL, model.DATA, model.VARIABLE:
   492  			strClause := s.parseSQL(name)
   493  			res.Value = append(res.Value, strClause)
   494  			s.appendTmpl(fmt.Sprintf("%s.WriteString(%s)", name, strClause.String()))
   495  		case model.IF:
   496  			var ifClause IfClause
   497  			ifClause, err = s.parseIF(name)
   498  			if err != nil {
   499  				return
   500  			}
   501  			res.Value = append(res.Value, ifClause)
   502  			s.appendTmpl(ifClause.Finish())
   503  		case model.FOR:
   504  			var forClause ForClause
   505  			forClause, err = s.parseFor(name)
   506  			if err != nil {
   507  				return
   508  			}
   509  			res.Value = append(res.Value, forClause)
   510  			s.appendTmpl(forClause.Finish())
   511  		case model.TRIM:
   512  			var trimClause TrimClause
   513  			trimClause, err = s.parseTrim()
   514  			if err != nil {
   515  				return
   516  			}
   517  			res.Value = append(res.Value, trimClause)
   518  			s.appendTmpl(trimClause.Finish(name))
   519  		case model.END:
   520  			s.forValue = s.forValue[:len(s.forValue)-1]
   521  			return
   522  		default:
   523  			err = fmt.Errorf("unknow clause : %s", c.Value)
   524  			return
   525  		}
   526  		if !s.HasMore() {
   527  			break
   528  		}
   529  		c = s.next()
   530  	}
   531  	if c.isEnd() {
   532  		err = fmt.Errorf("incomplete SQL,set not end")
   533  	}
   534  	return
   535  }
   536  
   537  // parseSQL parse sql condition, the clause' type must be one of SQL condition, VARIABLE, Data
   538  func (s *Section) parseSQL(name string) (res SQLClause) {
   539  	res.VarName = name
   540  	res.Type = model.SQL
   541  	for {
   542  		c := s.current()
   543  		switch c.Type {
   544  		case model.SQL:
   545  			res.Value = append(res.Value, c.Value)
   546  		case model.VARIABLE:
   547  			res.Value = append(res.Value, c.Value)
   548  		case model.DATA:
   549  			s.appendTmpl(fmt.Sprintf("params = append(params,%s)", c.Value))
   550  			res.Value = append(res.Value, "\"?\"")
   551  		default:
   552  			s.SubIndex()
   553  			return
   554  		}
   555  		if !s.HasMore() {
   556  			return
   557  		}
   558  		c = s.next()
   559  	}
   560  }
   561  
   562  // checkSQLVar check sql variable by for loops value and external params
   563  func (s *Section) checkSQLVar(param string, status model.Status, method *InterfaceMethod) (result section, err error) {
   564  	if status == model.VARIABLE && param == "table" {
   565  		result = section{
   566  			Type:  model.SQL,
   567  			Value: strconv.Quote(method.Table),
   568  		}
   569  		return
   570  	}
   571  	if status == model.DATA {
   572  		method.HasForParams = true
   573  	}
   574  	if status == model.VARIABLE {
   575  		param = fmt.Sprintf("%s.Quote(%s)", method.S, param)
   576  	}
   577  	result = section{
   578  		Type:  status,
   579  		Value: param,
   580  	}
   581  	return
   582  }
   583  
   584  // GetName ...
   585  func (s *Section) GetName(status model.Status) string {
   586  	switch status {
   587  	case model.WHERE:
   588  		defer func() { s.ClauseTotal[model.WHERE]++ }()
   589  		return fmt.Sprintf("whereSQL%d", s.ClauseTotal[model.WHERE])
   590  	case model.SET:
   591  		defer func() { s.ClauseTotal[model.SET]++ }()
   592  		return fmt.Sprintf("setSQL%d", s.ClauseTotal[model.SET])
   593  	case model.TRIM:
   594  		defer func() { s.ClauseTotal[model.TRIM]++ }()
   595  		return fmt.Sprintf("trimSQL%d", s.ClauseTotal[model.TRIM])
   596  	default:
   597  		return "generateSQL"
   598  	}
   599  }
   600  
   601  // checkTemplate check sql template's syntax (if/else/where/set/for)
   602  func (s *Section) checkTemplate(tmpl string) (part section, err error) {
   603  	part.Value = tmpl
   604  	part.SQLSlice = s
   605  	part.splitTemplate()
   606  
   607  	err = part.checkTemplate()
   608  
   609  	return
   610  }
   611  
// section is a single chunk of a split SQL template.
type section struct {
	// Type is the kind of chunk (SQL, DATA, VARIABLE, IF, FOR, END, ...).
	Type model.Status
	// Value is the text of the chunk.
	Value string
	// ForRange is filled in by checkTemplate when Type is FOR.
	ForRange ForRange
	// SQLSlice is the Section consulted for duplicate for-loop value names.
	SQLSlice *Section
	// splitList holds the fields of Value produced by splitTemplate.
	splitList []string
}
   619  
   620  func (s *section) isEnd() bool {
   621  	return s.Type == model.END
   622  }
   623  
   624  func (s *section) String() string {
   625  	if s.Type == model.FOR {
   626  		return s.ForRange.String()
   627  	}
   628  	return s.Value
   629  }
   630  
   631  func (s *section) splitTemplate() {
   632  	s.splitList = strings.FieldsFunc(strings.TrimSpace(s.Value), func(r rune) bool {
   633  		return r == ':' || r == ' ' || r == '=' || r == ','
   634  	})
   635  }
   636  
   637  func (s *section) checkTemplate() error {
   638  	if len(s.splitList) == 0 {
   639  		return fmt.Errorf("template is null")
   640  	}
   641  	if model.GenKeywords.Contain(s.Value) {
   642  		return fmt.Errorf("template can not use gen keywords")
   643  	}
   644  
   645  	err := s.sectionType(s.splitList[0])
   646  	if err != nil {
   647  		return err
   648  	}
   649  
   650  	if s.Type == model.FOR {
   651  		if len(s.splitList) != 5 {
   652  			return fmt.Errorf("for range syntax error: %s", s.Value)
   653  		}
   654  		if s.SQLSlice.hasSameName(s.splitList[2]) {
   655  			return fmt.Errorf("cannot use the same value name in different for loops")
   656  		}
   657  		s.ForRange.index = s.splitList[1]
   658  		s.ForRange.value = s.splitList[2]
   659  		s.ForRange.rangeList = s.splitList[4]
   660  	}
   661  	return nil
   662  }
   663  
   664  func (s *section) sectionType(str string) error {
   665  	switch str {
   666  	case "if":
   667  		s.Type = model.IF
   668  	case "else":
   669  		s.Type = model.ELSE
   670  	case "for":
   671  		s.Type = model.FOR
   672  	case "where":
   673  		s.Type = model.WHERE
   674  	case "set":
   675  		s.Type = model.SET
   676  	case "end":
   677  		s.Type = model.END
   678  	case "trim":
   679  		s.Type = model.TRIM
   680  	default:
   681  		return fmt.Errorf("unknown syntax: %s", str)
   682  	}
   683  	return nil
   684  }
   685  
   686  func (s *section) SQLParamName() string {
   687  	return strings.Replace(s.Value, ".", "", -1)
   688  }
   689  
   690  // ForRange for range clause for diy method
   691  type ForRange struct {
   692  	index     string
   693  	value     string
   694  	suffix    string
   695  	rangeList string
   696  }
   697  
   698  func (f *ForRange) String() string {
   699  	return fmt.Sprintf("for %s, %s := range %s", f.index, f.value, f.rangeList)
   700  }