github.com/unionj-cloud/go-doudou/v2@v2.3.5/toolkit/gormgen/internal/generate/interface.go (about)

     1  package generate
     2  
     3  import (
     4  	"fmt"
     5  	"strconv"
     6  	"strings"
     7  
     8  	"github.com/unionj-cloud/go-doudou/v2/toolkit/gormgen/internal/model"
     9  	"github.com/unionj-cloud/go-doudou/v2/toolkit/gormgen/internal/parser"
    10  )
    11  
// InterfaceMethod holds everything the generator collects about one method of
// a user-defined query interface: its signature, the SQL taken from its doc
// comment and the parsed sections of that SQL.
type InterfaceMethod struct { // feature will replace InterfaceMethod to parser.Method
	Doc           string         // doc comment of the method; the SQL text is extracted from it
	S             string         // identifier emitted in front of ".Quote(...)" for @@variables; presumably the first letter of the target struct name
	OriginStruct  parser.Param   // origin struct; gen.T params and results are rewritten to its type and package
	TargetStruct  string         // generated query struct name
	MethodName    string         // generated function name
	Params        []parser.Param // function input params
	Result        []parser.Param // function output params
	ResultData    parser.Param   // the single data-carrying output, set by checkResult
	Section       *Section       // SQL split into sections by sqlStateCheckAndSplit
	SQLParams     []parser.Param // input params referenced as variables inside the SQL
	SQLString     string         // SQL text extracted from Doc
	GormOption    string         // gorm call used to run the SQL, e.g. Raw, Exec or Where
	Table         string         // specified by user. if empty, generate it with gorm
	InterfaceName string         // origin interface name
	Package       string         // interface package name
	HasForParams  bool           // when true HasSQLData reports true even without SQLParams
}
    31  
    32  // FuncSign function signature
    33  func (m *InterfaceMethod) FuncSign() string {
    34  	return fmt.Sprintf("%s(%s) (%s)", m.MethodName, m.GetParamInTmpl(), m.GetResultParamInTmpl())
    35  }
    36  
    37  // HasSQLData has variable or for params will creat params map
    38  func (m *InterfaceMethod) HasSQLData() bool {
    39  	return len(m.SQLParams) > 0 || m.HasForParams
    40  }
    41  
    42  // HasGotPoint parameter has pointer or not
    43  func (m *InterfaceMethod) HasGotPoint() bool {
    44  	return !m.HasNeedNewResult()
    45  }
    46  
    47  // HasNeedNewResult need pointer or not
    48  func (m *InterfaceMethod) HasNeedNewResult() bool {
    49  	return !m.ResultData.IsArray && ((m.ResultData.IsNull() && m.ResultData.IsTime()) || m.ResultData.IsMap())
    50  }
    51  
    52  // GormRunMethodName return single data use Take() return array use Find
    53  func (m *InterfaceMethod) GormRunMethodName() string {
    54  	if m.ResultData.IsArray {
    55  		return "Find"
    56  	}
    57  	return "Take"
    58  }
    59  
    60  // ReturnSQLResult return sql result
    61  func (m *InterfaceMethod) ReturnSQLResult() bool {
    62  	for _, res := range m.Result {
    63  		if res.IsSQLResult() {
    64  			return true
    65  		}
    66  	}
    67  	return false
    68  }
    69  
    70  // ReturnSQLRow return sql result
    71  func (m *InterfaceMethod) ReturnSQLRow() bool {
    72  	for _, res := range m.Result {
    73  		if res.IsSQLRow() {
    74  			return true
    75  		}
    76  	}
    77  	return false
    78  }
    79  
    80  // ReturnSQLRows return sql result
    81  func (m *InterfaceMethod) ReturnSQLRows() bool {
    82  	for _, res := range m.Result {
    83  		if res.IsSQLRows() {
    84  			return true
    85  		}
    86  	}
    87  	return false
    88  }
    89  
    90  // ReturnNothing not return error and rowAffected
    91  func (m *InterfaceMethod) ReturnNothing() bool {
    92  	for _, res := range m.Result {
    93  		if res.IsError() || res.Name == "rowsAffected" {
    94  			return false
    95  		}
    96  	}
    97  	return true
    98  }
    99  
   100  // ReturnRowsAffected return rows affected
   101  func (m *InterfaceMethod) ReturnRowsAffected() bool {
   102  	for _, res := range m.Result {
   103  		if res.Name == "rowsAffected" {
   104  			return true
   105  		}
   106  	}
   107  	return false
   108  }
   109  
   110  // ReturnError return error
   111  func (m *InterfaceMethod) ReturnError() bool {
   112  	for _, res := range m.Result {
   113  		if res.IsError() {
   114  			return true
   115  		}
   116  	}
   117  	return false
   118  }
   119  
   120  // IsRepeatFromDifferentInterface check different interface has same mame method
   121  func (m *InterfaceMethod) IsRepeatFromDifferentInterface(newMethod *InterfaceMethod) bool {
   122  	return m.MethodName == newMethod.MethodName && m.InterfaceName != newMethod.InterfaceName && m.TargetStruct == newMethod.TargetStruct
   123  }
   124  
   125  // IsRepeatFromSameInterface check different interface has same mame method
   126  func (m *InterfaceMethod) IsRepeatFromSameInterface(newMethod *InterfaceMethod) bool {
   127  	return m.MethodName == newMethod.MethodName && m.InterfaceName == newMethod.InterfaceName && m.TargetStruct == newMethod.TargetStruct
   128  }
   129  
   130  //GetParamInTmpl return param list
   131  func (m *InterfaceMethod) GetParamInTmpl() string {
   132  	return paramToString(m.Params)
   133  }
   134  
   135  // GetResultParamInTmpl return result list
   136  func (m *InterfaceMethod) GetResultParamInTmpl() string {
   137  	return paramToString(m.Result)
   138  }
   139  
   140  // SQLParamName sql param map key,
   141  func (m *InterfaceMethod) SQLParamName(param string) string {
   142  	return strings.Replace(param, ".", "", -1)
   143  }
   144  
   145  // paramToString param list to string used in tmpl
   146  func paramToString(params []parser.Param) string {
   147  	var res []string
   148  	for _, param := range params {
   149  		res = append(res, param.TmplString())
   150  	}
   151  	return strings.Join(res, ",")
   152  }
   153  
   154  // DocComment return comment sql add "//" every line
   155  func (m *InterfaceMethod) DocComment() string {
   156  	return strings.Replace(strings.Replace(strings.TrimSpace(m.Doc), "\n", "\n// ", -1), "//  ", "// ", -1)
   157  }
   158  
   159  // checkParams check all parameters
   160  func (m *InterfaceMethod) checkMethod(methods []*InterfaceMethod, s *QueryStructMeta) (err error) {
   161  	if model.GormKeywords.FullMatch(m.MethodName) {
   162  		return fmt.Errorf("can not use keyword as method name:%s", m.MethodName)
   163  	}
   164  	// TODO check methods Always empty?
   165  	for _, method := range methods {
   166  		if m.IsRepeatFromDifferentInterface(method) {
   167  			return fmt.Errorf("can not generate method with the same name from different interface:[%s.%s] and [%s.%s]",
   168  				m.InterfaceName, m.MethodName, method.InterfaceName, method.MethodName)
   169  		}
   170  	}
   171  	for _, f := range s.Fields {
   172  		if f.Name == m.MethodName {
   173  			return fmt.Errorf("can not generate method same name with struct field:[%s.%s] and [%s.%s]",
   174  				m.InterfaceName, m.MethodName, s.ModelStructName, f.Name)
   175  		}
   176  	}
   177  
   178  	return nil
   179  }
   180  
   181  // checkParams check all parameters
   182  func (m *InterfaceMethod) checkParams(params []parser.Param) (err error) {
   183  	paramList := make([]parser.Param, len(params))
   184  	for i, param := range params {
   185  		switch {
   186  		case param.Package == "UNDEFINED":
   187  			param.Package = m.Package
   188  		case param.IsError() || param.IsNull():
   189  			return fmt.Errorf("type error on interface [%s] param: [%s]", m.InterfaceName, param.Name)
   190  		case param.IsGenM():
   191  			param.Type = "map[string]interface{}"
   192  			param.Package = ""
   193  		case param.IsGenT():
   194  			param.Type = m.OriginStruct.Type
   195  			param.Package = m.OriginStruct.Package
   196  		}
   197  		paramList[i] = param
   198  	}
   199  	m.Params = paramList
   200  	return
   201  }
   202  
   203  // checkResult check all parameters and replace gen.T by target structure. Parameters must be one of int/string/struct/map
   204  func (m *InterfaceMethod) checkResult(result []parser.Param) (err error) {
   205  	resList := make([]parser.Param, len(result))
   206  	var hasError bool
   207  	for i, param := range result {
   208  		if param.Package == "UNDEFINED" {
   209  			param.Package = m.Package
   210  		}
   211  		if param.IsGenM() {
   212  			param.Type = "map[string]interface{}"
   213  			param.Package = ""
   214  		}
   215  		switch {
   216  		case param.InMainPkg():
   217  			return fmt.Errorf("query method cannot return struct of main package in [%s.%s]", m.InterfaceName, m.MethodName)
   218  		case param.IsError():
   219  			if hasError {
   220  				return fmt.Errorf("query method cannot return more than 1 error value in [%s.%s]", m.InterfaceName, m.MethodName)
   221  			}
   222  			param.SetName("err")
   223  			hasError = true
   224  		case param.Eq(m.OriginStruct) || param.IsGenT():
   225  			if !m.ResultData.IsNull() {
   226  				return fmt.Errorf("query method cannot return more than 1 data value in [%s.%s]", m.InterfaceName, m.MethodName)
   227  			}
   228  			param.SetName("result")
   229  			param.Type = m.OriginStruct.Type
   230  			param.Package = m.OriginStruct.Package
   231  			m.ResultData = param
   232  		case param.IsInterface():
   233  			return fmt.Errorf("query method can not return interface in [%s.%s]", m.InterfaceName, m.MethodName)
   234  		case param.IsGenRowsAffected():
   235  			param.Type = "int64"
   236  			param.Package = ""
   237  			param.SetName("rowsAffected")
   238  			m.GormOption = "Exec"
   239  		case param.IsSQLResult():
   240  			param.Type = "Result"
   241  			param.Package = "sql"
   242  			param.SetName("result")
   243  			m.GormOption = "Statement.ConnPool.ExecContext"
   244  		case param.IsSQLRow():
   245  			param.Type = "Row"
   246  			param.Package = "sql"
   247  			param.SetName("row")
   248  			m.GormOption = "Raw"
   249  			param.IsPointer = true
   250  		case param.IsSQLRows():
   251  			param.Type = "Rows"
   252  			param.Package = "sql"
   253  			param.SetName("rows")
   254  			m.GormOption = "Raw"
   255  			param.IsPointer = true
   256  		default:
   257  			if !m.ResultData.IsNull() {
   258  				return fmt.Errorf("query method cannot return more than 1 data value in [%s.%s]", m.InterfaceName, m.MethodName)
   259  			}
   260  			if param.Package == "" && !(param.IsBaseType() || param.IsMap() || param.IsTime()) {
   261  				param.Package = m.Package
   262  			}
   263  			param.SetName("result")
   264  			m.ResultData = param
   265  		}
   266  		resList[i] = param
   267  	}
   268  	m.Result = resList
   269  	return
   270  }
   271  
   272  // checkSQL get sql from comment and check it
   273  func (m *InterfaceMethod) checkSQL() (err error) {
   274  	m.SQLString = m.parseDocString()
   275  	if err = m.sqlStateCheckAndSplit(); err != nil {
   276  		err = fmt.Errorf("interface %s member method %s check sql err:%w", m.InterfaceName, m.MethodName, err)
   277  	}
   278  	return
   279  }
   280  
   281  func (m *InterfaceMethod) parseDocString() string {
   282  	docString := strings.TrimSpace(m.getSQLDocString())
   283  	switch {
   284  	case strings.HasPrefix(strings.ToLower(docString), "sql("):
   285  		docString = docString[4 : len(docString)-1]
   286  		m.GormOption = "Raw"
   287  		if m.ResultData.IsNull() {
   288  			m.GormOption = "Exec"
   289  		}
   290  	case strings.HasPrefix(strings.ToLower(docString), "where("):
   291  		docString = docString[6 : len(docString)-1]
   292  		m.GormOption = "Where"
   293  	default:
   294  		m.GormOption = "Raw"
   295  		if m.ResultData.IsNull() {
   296  			m.GormOption = "Exec"
   297  		}
   298  	}
   299  
   300  	// if wrapped by ", trim it
   301  	if strings.HasPrefix(docString, `"`) && strings.HasSuffix(docString, `"`) {
   302  		docString = docString[1 : len(docString)-1]
   303  	}
   304  	return docString
   305  }
   306  
   307  func (m *InterfaceMethod) getSQLDocString() string {
   308  	docString := strings.TrimSpace(m.Doc)
   309  	/*
   310  		// methodName descriptive message
   311  		// (this blank line is needed)
   312  		// sql
   313  	*/
   314  	if index := strings.Index(docString, "\n\n"); index != -1 {
   315  		if strings.Contains(docString[index+2:], m.MethodName) {
   316  			docString = docString[:index]
   317  		} else {
   318  			docString = docString[index+2:]
   319  		}
   320  	}
   321  	/* //methodName sql */
   322  	docString = strings.TrimPrefix(docString, m.MethodName)
   323  	// TODO: using sql key word to split comment
   324  	return docString
   325  }
   326  
   327  // sqlStateCheckAndSplit check sql with an adeterministic finite automaton
   328  func (m *InterfaceMethod) sqlStateCheckAndSplit() error {
   329  	sqlString := m.SQLString
   330  	m.Section = NewSection()
   331  	var buf model.SQLBuffer
   332  	for i := 0; !strOutRange(i, sqlString); i++ {
   333  		b := sqlString[i]
   334  		switch b {
   335  		case '"':
   336  			_ = buf.WriteByte(sqlString[i])
   337  			for i++; ; i++ {
   338  				if strOutRange(i, sqlString) {
   339  					return fmt.Errorf("incomplete SQL:%s", sqlString)
   340  				}
   341  				_ = buf.WriteByte(sqlString[i])
   342  				if sqlString[i] == '"' && sqlString[i-1] != '\\' {
   343  					break
   344  				}
   345  			}
   346  		case '\'':
   347  			_ = buf.WriteByte(sqlString[i])
   348  			for i++; ; i++ {
   349  				if strOutRange(i, sqlString) {
   350  					return fmt.Errorf("incomplete SQL:%s", sqlString)
   351  				}
   352  				_ = buf.WriteByte(sqlString[i])
   353  				if sqlString[i] == '\'' && sqlString[i-1] != '\\' {
   354  					break
   355  				}
   356  			}
   357  		case '\\':
   358  			if sqlString[i+1] == '@' {
   359  				i++
   360  				buf.WriteSQL(sqlString[i])
   361  				continue
   362  			}
   363  			buf.WriteSQL(b)
   364  		case '{', '@':
   365  			if sqlClause := buf.Dump(); strings.TrimSpace(sqlClause) != "" {
   366  				m.Section.members = append(m.Section.members, section{
   367  					Type:  model.SQL,
   368  					Value: strconv.Quote(sqlClause),
   369  				})
   370  			}
   371  
   372  			if strOutRange(i+1, sqlString) {
   373  				return fmt.Errorf("incomplete SQL:%s", sqlString)
   374  			}
   375  			if b == '{' && sqlString[i+1] == '{' {
   376  				for i += 2; ; i++ {
   377  					if strOutRange(i, sqlString) {
   378  						return fmt.Errorf("incomplete SQL:%s", sqlString)
   379  					}
   380  					if sqlString[i] == '"' {
   381  						_ = buf.WriteByte(sqlString[i])
   382  						for i++; ; i++ {
   383  							if strOutRange(i, sqlString) {
   384  								return fmt.Errorf("incomplete SQL:%s", sqlString)
   385  							}
   386  							_ = buf.WriteByte(sqlString[i])
   387  							if sqlString[i] == '"' && sqlString[i-1] != '\\' {
   388  								break
   389  							}
   390  						}
   391  						i++
   392  					}
   393  
   394  					if strOutRange(i+1, sqlString) {
   395  						return fmt.Errorf("incomplete SQL:%s", sqlString)
   396  					}
   397  					if sqlString[i] == '}' && sqlString[i+1] == '}' {
   398  						i++
   399  						sqlClause := buf.Dump()
   400  						part, err := m.Section.checkTemplate(sqlClause)
   401  						if err != nil {
   402  							return fmt.Errorf("sql [%s] dynamic template %s err:%w", sqlString, sqlClause, err)
   403  						}
   404  						m.Section.members = append(m.Section.members, part)
   405  						break
   406  					}
   407  					buf.WriteSQL(sqlString[i])
   408  				}
   409  			}
   410  			if b == '@' {
   411  				i++
   412  				status := model.DATA
   413  				if sqlString[i] == '@' {
   414  					i++
   415  					status = model.VARIABLE
   416  				}
   417  				for ; ; i++ {
   418  					if strOutRange(i, sqlString) || isEnd(sqlString[i]) {
   419  						varString := buf.Dump()
   420  						params, err := m.Section.checkSQLVar(varString, status, m)
   421  						if err != nil {
   422  							return fmt.Errorf("sql [%s] varable %s err:%s", sqlString, varString, err)
   423  						}
   424  						m.Section.members = append(m.Section.members, params)
   425  						i--
   426  						break
   427  					}
   428  					buf.WriteSQL(sqlString[i])
   429  				}
   430  			}
   431  		default:
   432  			buf.WriteSQL(b)
   433  		}
   434  	}
   435  	if sqlClause := buf.Dump(); strings.TrimSpace(sqlClause) != "" {
   436  		m.Section.members = append(m.Section.members, section{
   437  			Type:  model.SQL,
   438  			Value: strconv.Quote(sqlClause),
   439  		})
   440  	}
   441  
   442  	return nil
   443  }
   444  
   445  // checkSQLVarByParams return external parameters, table name
   446  func (m *InterfaceMethod) checkSQLVarByParams(param string, status model.Status) (result section, err error) {
   447  	for _, p := range m.Params {
   448  		structName := strings.Split(param, ".")[0]
   449  		if p.Name == structName {
   450  			if p.Name != param {
   451  				p = parser.Param{
   452  					Name: param,
   453  					Type: "string",
   454  				}
   455  			}
   456  			switch status {
   457  			case model.DATA:
   458  				if !m.isParamExist(param) {
   459  					m.SQLParams = append(m.SQLParams, p)
   460  				}
   461  			case model.VARIABLE:
   462  				if p.Type != "string" || p.IsArray {
   463  					err = fmt.Errorf("variable name must be string :%s type is %s", param, p.TypeName())
   464  					return
   465  				}
   466  				param = fmt.Sprintf("%s.Quote(%s)", m.S, param)
   467  			}
   468  			result = section{
   469  				Type:  status,
   470  				Value: param,
   471  			}
   472  			return
   473  		}
   474  	}
   475  	if param == "table" {
   476  		result = section{
   477  			Type:  model.SQL,
   478  			Value: strconv.Quote(m.Table),
   479  		}
   480  		return
   481  	}
   482  
   483  	return result, fmt.Errorf("unknow variable param:%s", param)
   484  }
   485  
   486  // isParamExist check param duplicate
   487  func (m *InterfaceMethod) isParamExist(paramName string) bool {
   488  	for _, param := range m.SQLParams {
   489  		if param.Name == paramName {
   490  			return true
   491  		}
   492  	}
   493  	return false
   494  }