github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/insert.go (about) 1 // Copyright 2021 ecodeclub 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package eorm 16 17 import ( 18 "context" 19 "errors" 20 21 "github.com/ecodeclub/eorm/internal/errs" 22 "github.com/ecodeclub/eorm/internal/model" 23 "github.com/valyala/bytebufferpool" 24 ) 25 26 var _ QueryBuilder = &Inserter[any]{} 27 28 // Inserter is used to construct an insert query 29 // More details check Build function 30 type Inserter[T any] struct { 31 inserterBuilder 32 db Session 33 values []*T 34 } 35 36 // NewInserter 开始构建一个 INSERT 查询 37 func NewInserter[T any](sess Session) *Inserter[T] { 38 return &Inserter[T]{ 39 inserterBuilder: inserterBuilder{ 40 builder: builder{ 41 core: sess.getCore(), 42 buffer: bytebufferpool.Get(), 43 }, 44 }, 45 db: sess, 46 } 47 } 48 49 func (i *Inserter[T]) SkipPK() *Inserter[T] { 50 i.ignorePK = true 51 return i 52 } 53 54 // Build function build the query 55 // notes: 56 // - All the values from function Values should have the same type. 57 // - It will insert all columns including auto-increment primary key 58 func (i *Inserter[T]) Build() (Query, error) { 59 defer bytebufferpool.Put(i.buffer) 60 var err error 61 if len(i.values) == 0 { 62 return EmptyQuery, errors.New("插入0行") 63 } 64 i.writeString("INSERT INTO ") 65 i.meta, err = i.metaRegistry.Get(i.values[0]) 66 if err != nil { 67 return EmptyQuery, err 68 } 69 i.quote(i.meta.TableName) 70 i.writeString("(") 71 fields, err := i.buildColumns() 72 if err != nil { 73 return EmptyQuery, err 74 } 75 i.writeString(")") 76 i.writeString(" VALUES") 77 for index, val := range i.values { 78 if index > 0 { 79 i.comma() 80 } 81 i.writeString("(") 82 refVal := i.valCreator.NewPrimitiveValue(val, i.meta) 83 for j, v := range fields { 84 fdVal, err := refVal.Field(v.FieldName) 85 if err != nil { 86 return EmptyQuery, err 87 } 88 i.parameter(fdVal.Interface()) 89 if j != len(fields)-1 { 90 i.comma() 91 } 92 } 93 i.writeString(")") 94 } 95 i.end() 96 return Query{SQL: i.buffer.String(), Args: i.args}, nil 97 } 98 99 // Columns specifies the columns that need to be inserted 100 // if cs is empty, all columns will be inserted 101 // cs must be the same with the field name in model 102 func (i *Inserter[T]) Columns(cs ...string) *Inserter[T] { 103 i.columns = cs 104 return i 105 } 106 107 // Values specify the rows 108 // all the elements must be the same type 109 // and users are supposed to passing at least one element 110 func (i *Inserter[T]) Values(values ...*T) *Inserter[T] { 111 i.values = values 112 return i 113 } 114 115 // Exec 发起查询 116 func (i *Inserter[T]) Exec(ctx context.Context) Result { 117 query, err := i.Build() 118 if err != nil { 119 return Result{err: err} 120 } 121 return newQuerier[T](i.db, query, i.meta, INSERT).Exec(ctx) 122 } 123 124 func (i *Inserter[T]) buildColumns() ([]*model.ColumnMeta, error) { 125 cs := make([]*model.ColumnMeta, 0, len(i.columns)) 126 if len(i.columns) != 0 { 127 for index, c := range i.columns { 128 v, isOk := i.meta.FieldMap[c] 129 if !isOk { 130 return cs, errs.NewInvalidFieldError(c) 131 } 132 i.quote(v.ColumnName) 133 if index != len(i.columns)-1 { 134 i.comma() 135 } 136 cs = append(cs, v) 137 } 138 } else { 139 for index, val := range i.meta.Columns { 140 if i.ignorePK && val.IsPrimaryKey { 141 continue 142 } 143 i.quote(val.ColumnName) 144 if index != len(i.meta.Columns)-1 { 145 i.comma() 146 } 147 cs = append(cs, val) 148 } 149 } 150 return cs, nil 151 }