gitee.com/h79/goutils@v1.22.10/dao/wrapper/from.go (about)

     1  package wrapper
     2  
     3  import (
     4  	commonoption "gitee.com/h79/goutils/common/option"
     5  	"strings"
     6  )
     7  
     8  var _ IFrom = (*From)(nil)
     9  var gFromBuilder FromBuilder
    10  
    11  type FromBuilder func(table string, opts ...commonoption.Option) (IFrom, bool)
    12  
    13  // SetFromBuilder 表名检测,比如名称为 virtual表,返回真名表名是一个子查询
    14  func SetFromBuilder(fn FromBuilder) {
    15  	gFromBuilder = fn
    16  }
    17  
// From describes the table (or virtual table) that a query reads from.
type From struct {
	// _from caches the IFrom that the global FromBuilder returned for this
	// table name; nil until resolved by getTableName, and cleared by Reset.
	_from IFrom  // for virtual table
	// From is the configured table name; an empty value means Build emits no
	// FROM clause.
	From  string `json:"from" yaml:"from"`
	// As is the alias passed to AddAlias when it differs from the resolved
	// table text.
	As    string `json:"as" yaml:"as"`
}
    23  
    24  func (f *From) Reset() {
    25  	f._from = nil
    26  	f.From = ""
    27  	f.As = ""
    28  }
    29  
    30  func (f *From) SetAs(as string) {
    31  	f.As = as
    32  }
    33  
    34  func (f *From) Is() bool {
    35  	return f.From != ""
    36  }
    37  
    38  func (f *From) Build(opts ...commonoption.Option) string {
    39  	if !f.Is() {
    40  		return ""
    41  	}
    42  	table := f.getTableName(opts...)
    43  	child := f.checkChild(table)
    44  	builder := strings.Builder{}
    45  	builder.WriteString(" FROM ")
    46  	if child {
    47  		builder.WriteByte('(')
    48  	}
    49  	f.from(&builder, table, child)
    50  	if child {
    51  		builder.WriteByte(')')
    52  	}
    53  	if f.As != table {
    54  		AddAlias(&builder, f.As)
    55  	}
    56  	return builder.String()
    57  }
    58  
    59  func (f *From) from(builder *strings.Builder, table string, child bool) bool {
    60  	quo := !child
    61  	first := table[0]
    62  	end := table[len(table)-1]
    63  	if first == '`' {
    64  		quo = false
    65  	}
    66  	if quo {
    67  		builder.WriteByte('`')
    68  		first = '`'
    69  	}
    70  	builder.WriteString(table)
    71  	if first == '`' && end != '`' {
    72  		builder.WriteByte('`')
    73  	}
    74  	return true
    75  }
    76  
    77  func (f *From) getTableName(opts ...commonoption.Option) string {
    78  	if gFromBuilder == nil {
    79  		return f.From
    80  	}
    81  	if f._from == nil {
    82  		from, ok := gFromBuilder(f.From, opts...)
    83  		if ok && from != nil {
    84  			f._from = from
    85  		}
    86  	}
    87  	if f._from != nil {
    88  		return f._from.Build(opts...)
    89  	}
    90  	return f.From
    91  }
    92  
    93  func (f *From) checkChild(table string) bool {
    94  	fm := strings.TrimSpace(table)
    95  	fm = strings.TrimLeft(fm, "(")
    96  	l := len(fm)
    97  	if l > 9 {
    98  		l = 9
    99  	}
   100  	up := strings.ToUpper(fm[0:l])
   101  	if strings.Contains(up, "DROP") ||
   102  		strings.Contains(up, "DELETE") ||
   103  		strings.Contains(up, "ALTER") ||
   104  		strings.Contains(up, "INSERT") {
   105  		strings.Contains(up, "UPDATE")
   106  		panic("may be illegal,lead to serious consequences")
   107  	}
   108  	return strings.Contains(up, "SELECT")
   109  }