github.com/gramework/gramework@v1.8.1-0.20231027140105-82555c9057f5/x/sqlgen/insert.go (about)

     1  package sqlgen
     2  
     3  import (
     4  	"fmt"
     5  	"strings"
     6  	"sync"
     7  )
     8  
     9  // Insert statement builder generates
    10  // an insert statement using `?` placeholders
    11  // for values
    12  func Insert(table string) *InsertBuilder {
    13  	return &InsertBuilder{
    14  		tableName: table,
    15  		query:     fmt.Sprintln(`INSERT INTO`, table),
    16  		lock:      new(sync.Mutex),
    17  	}
    18  }
    19  
    20  // PreparedInsert statement builder generates
    21  // the insert SQL statement with values built
    22  // in the statement without using placeholders
    23  func PreparedInsert(table string) *InsertBuilder {
    24  	i := Insert(table)
    25  	i.prepared = true
    26  	return i
    27  }
    28  
    29  // Columns defines column list
    30  func (b *InsertBuilder) Columns(columns ...string) *InsertBuilder {
    31  	b.lock.Lock()
    32  	b.columns = columns
    33  	b.query = fmt.Sprintf(`%s(`, b.query)
    34  	for i, column := range columns {
    35  		b.query = fmt.Sprintf(`%s%s`, b.query, column)
    36  		if i < len(columns)-1 {
    37  			b.query = fmt.Sprintf(`%s,`, b.query)
    38  		}
    39  	}
    40  	b.query = fmt.Sprintf(`%s)`, b.query)
    41  	b.lock.Unlock()
    42  	return b
    43  }
    44  
    45  // Values appends column values to the query
    46  func (b *InsertBuilder) Values(columnValues ...interface{}) *InsertBuilder {
    47  	b.lock.Lock()
    48  	sqlValue := "("
    49  	if b.prepared {
    50  		for k, columnValue := range columnValues {
    51  			switch v := columnValue.(type) {
    52  			case string:
    53  				sqlValue = fmt.Sprintf("%s'%s'", sqlValue, strings.Replace(v, "'", "''", -1))
    54  			default:
    55  				sqlValue = fmt.Sprintf("%s%v", sqlValue, v)
    56  			}
    57  			if k < len(columnValues)-1 {
    58  				sqlValue = fmt.Sprintf("%s, ", sqlValue)
    59  			}
    60  		}
    61  	} else {
    62  		for k := range columnValues {
    63  			sqlValue = fmt.Sprintf("%s?", sqlValue)
    64  			if k < len(columnValues)-1 {
    65  				sqlValue = fmt.Sprintf("%s, ", sqlValue)
    66  			}
    67  		}
    68  	}
    69  	b.sqlValues = append(b.sqlValues, fmt.Sprintf(`%s)`, sqlValue))
    70  
    71  	b.lock.Unlock()
    72  
    73  	return b
    74  }
    75  
    76  // Build the query
    77  func (b *InsertBuilder) Build() string {
    78  	b.lock.Lock()
    79  	defer b.lock.Unlock()
    80  	return fmt.Sprintf("%s\n    VALUES %s;", b.query, strings.Join(b.sqlValues, ", \n        "))
    81  }