github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/update.go (about) 1 // Copyright 2021 ecodeclub 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package eorm 16 17 import ( 18 "context" 19 "reflect" 20 21 "github.com/ecodeclub/eorm/internal/errs" 22 "github.com/valyala/bytebufferpool" 23 ) 24 25 var _ QueryBuilder = &Updater[any]{} 26 27 // Updater is the builder responsible for building UPDATE query 28 type Updater[T any] struct { 29 Session 30 updaterBuilder 31 table interface{} 32 } 33 34 // NewUpdater 开始构建一个 UPDATE 查询 35 func NewUpdater[T any](sess Session) *Updater[T] { 36 return &Updater[T]{ 37 updaterBuilder: updaterBuilder{ 38 builder: builder{ 39 core: sess.getCore(), 40 buffer: bytebufferpool.Get(), 41 }, 42 }, 43 Session: sess, 44 } 45 } 46 47 func (u *Updater[T]) Update(val *T) *Updater[T] { 48 u.table = val 49 return u 50 } 51 52 // Build returns UPDATE query 53 func (u *Updater[T]) Build() (Query, error) { 54 defer bytebufferpool.Put(u.buffer) 55 var err error 56 t := new(T) 57 if u.table == nil { 58 u.table = t 59 } 60 u.meta, err = u.metaRegistry.Get(t) 61 if err != nil { 62 return EmptyQuery, err 63 } 64 65 u.val = u.valCreator.NewPrimitiveValue(u.table, u.meta) 66 u.args = make([]interface{}, 0, len(u.meta.Columns)) 67 68 u.writeString("UPDATE ") 69 u.quote(u.meta.TableName) 70 u.writeString(" SET ") 71 if len(u.assigns) == 0 { 72 err = u.buildDefaultColumns() 73 } else { 74 err = u.buildAssigns() 75 } 76 if err != nil { 77 return EmptyQuery, err 78 } 79 80 if len(u.where) > 0 { 81 u.writeString(" WHERE ") 82 err = u.buildPredicates(u.where) 83 if err != nil { 84 return EmptyQuery, err 85 } 86 } 87 88 u.end() 89 return Query{ 90 SQL: u.buffer.String(), 91 Args: u.args, 92 }, nil 93 } 94 95 func (u *Updater[T]) buildAssigns() error { 96 has := false 97 for _, assign := range u.assigns { 98 if has { 99 u.comma() 100 } 101 switch a := assign.(type) { 102 case Column: 103 c, ok := u.meta.FieldMap[a.name] 104 if !ok { 105 return errs.NewInvalidFieldError(a.name) 106 } 107 refVal, _ := u.val.Field(a.name) 108 u.quote(c.ColumnName) 109 _ = u.buffer.WriteByte('=') 110 u.parameter(refVal.Interface()) 111 has = true 112 case columns: 113 for _, name := range a.cs { 114 c, ok := u.meta.FieldMap[name] 115 if !ok { 116 return errs.NewInvalidFieldError(name) 117 } 118 refVal, _ := u.val.Field(name) 119 if has { 120 u.comma() 121 } 122 u.quote(c.ColumnName) 123 _ = u.buffer.WriteByte('=') 124 u.parameter(refVal.Interface()) 125 has = true 126 } 127 case Assignment: 128 if err := u.buildExpr(binaryExpr(a)); err != nil { 129 return err 130 } 131 has = true 132 default: 133 return errs.ErrUnsupportedAssignment 134 } 135 } 136 if !has { 137 return errs.NewValueNotSetError() 138 } 139 return nil 140 } 141 142 func (u *Updater[T]) buildDefaultColumns() error { 143 has := false 144 for _, c := range u.meta.Columns { 145 refVal, _ := u.val.Field(c.FieldName) 146 if u.ignoreZeroVal && isZeroValue(refVal) { 147 continue 148 } 149 if u.ignoreNilVal && isNilValue(refVal) { 150 continue 151 } 152 if has { 153 _ = u.buffer.WriteByte(',') 154 } 155 u.quote(c.ColumnName) 156 _ = u.buffer.WriteByte('=') 157 u.parameter(refVal.Interface()) 158 has = true 159 } 160 if !has { 161 return errs.NewValueNotSetError() 162 } 163 return nil 164 } 165 166 // Set represents SET clause 167 func (u *Updater[T]) Set(assigns ...Assignable) *Updater[T] { 168 u.assigns = assigns 169 return u 170 } 171 172 // Where represents WHERE clause 173 func (u *Updater[T]) Where(predicates ...Predicate) *Updater[T] { 174 u.where = predicates 175 return u 176 } 177 178 // SkipNilValue 忽略 nil 值 columns 179 func (u *Updater[T]) SkipNilValue() *Updater[T] { 180 u.ignoreNilVal = true 181 return u 182 } 183 184 func isNilValue(val reflect.Value) bool { 185 switch val.Kind() { 186 case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice: 187 return val.IsNil() 188 } 189 return false 190 } 191 192 // SkipZeroValue 忽略零值 columns 193 func (u *Updater[T]) SkipZeroValue() *Updater[T] { 194 u.ignoreZeroVal = true 195 return u 196 } 197 198 func isZeroValue(val reflect.Value) bool { 199 return val.IsZero() 200 } 201 202 // Exec sql 203 func (u *Updater[T]) Exec(ctx context.Context) Result { 204 query, err := u.Build() 205 if err != nil { 206 return Result{err: err} 207 } 208 return newQuerier[T](u.Session, query, u.meta, UPDATE).Exec(ctx) 209 }