github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/update.go (about)

     1  // Copyright 2021 ecodeclub
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  // http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package eorm
    16  
    17  import (
    18  	"context"
    19  	"reflect"
    20  
    21  	"github.com/ecodeclub/eorm/internal/errs"
    22  	"github.com/valyala/bytebufferpool"
    23  )
    24  
    25  var _ QueryBuilder = &Updater[any]{}
    26  
    27  // Updater is the builder responsible for building UPDATE query
    28  type Updater[T any] struct {
    29  	Session
    30  	updaterBuilder
    31  	table interface{}
    32  }
    33  
    34  // NewUpdater 开始构建一个 UPDATE 查询
    35  func NewUpdater[T any](sess Session) *Updater[T] {
    36  	return &Updater[T]{
    37  		updaterBuilder: updaterBuilder{
    38  			builder: builder{
    39  				core:   sess.getCore(),
    40  				buffer: bytebufferpool.Get(),
    41  			},
    42  		},
    43  		Session: sess,
    44  	}
    45  }
    46  
    47  func (u *Updater[T]) Update(val *T) *Updater[T] {
    48  	u.table = val
    49  	return u
    50  }
    51  
    52  // Build returns UPDATE query
    53  func (u *Updater[T]) Build() (Query, error) {
    54  	defer bytebufferpool.Put(u.buffer)
    55  	var err error
    56  	t := new(T)
    57  	if u.table == nil {
    58  		u.table = t
    59  	}
    60  	u.meta, err = u.metaRegistry.Get(t)
    61  	if err != nil {
    62  		return EmptyQuery, err
    63  	}
    64  
    65  	u.val = u.valCreator.NewPrimitiveValue(u.table, u.meta)
    66  	u.args = make([]interface{}, 0, len(u.meta.Columns))
    67  
    68  	u.writeString("UPDATE ")
    69  	u.quote(u.meta.TableName)
    70  	u.writeString(" SET ")
    71  	if len(u.assigns) == 0 {
    72  		err = u.buildDefaultColumns()
    73  	} else {
    74  		err = u.buildAssigns()
    75  	}
    76  	if err != nil {
    77  		return EmptyQuery, err
    78  	}
    79  
    80  	if len(u.where) > 0 {
    81  		u.writeString(" WHERE ")
    82  		err = u.buildPredicates(u.where)
    83  		if err != nil {
    84  			return EmptyQuery, err
    85  		}
    86  	}
    87  
    88  	u.end()
    89  	return Query{
    90  		SQL:  u.buffer.String(),
    91  		Args: u.args,
    92  	}, nil
    93  }
    94  
    95  func (u *Updater[T]) buildAssigns() error {
    96  	has := false
    97  	for _, assign := range u.assigns {
    98  		if has {
    99  			u.comma()
   100  		}
   101  		switch a := assign.(type) {
   102  		case Column:
   103  			c, ok := u.meta.FieldMap[a.name]
   104  			if !ok {
   105  				return errs.NewInvalidFieldError(a.name)
   106  			}
   107  			refVal, _ := u.val.Field(a.name)
   108  			u.quote(c.ColumnName)
   109  			_ = u.buffer.WriteByte('=')
   110  			u.parameter(refVal.Interface())
   111  			has = true
   112  		case columns:
   113  			for _, name := range a.cs {
   114  				c, ok := u.meta.FieldMap[name]
   115  				if !ok {
   116  					return errs.NewInvalidFieldError(name)
   117  				}
   118  				refVal, _ := u.val.Field(name)
   119  				if has {
   120  					u.comma()
   121  				}
   122  				u.quote(c.ColumnName)
   123  				_ = u.buffer.WriteByte('=')
   124  				u.parameter(refVal.Interface())
   125  				has = true
   126  			}
   127  		case Assignment:
   128  			if err := u.buildExpr(binaryExpr(a)); err != nil {
   129  				return err
   130  			}
   131  			has = true
   132  		default:
   133  			return errs.ErrUnsupportedAssignment
   134  		}
   135  	}
   136  	if !has {
   137  		return errs.NewValueNotSetError()
   138  	}
   139  	return nil
   140  }
   141  
   142  func (u *Updater[T]) buildDefaultColumns() error {
   143  	has := false
   144  	for _, c := range u.meta.Columns {
   145  		refVal, _ := u.val.Field(c.FieldName)
   146  		if u.ignoreZeroVal && isZeroValue(refVal) {
   147  			continue
   148  		}
   149  		if u.ignoreNilVal && isNilValue(refVal) {
   150  			continue
   151  		}
   152  		if has {
   153  			_ = u.buffer.WriteByte(',')
   154  		}
   155  		u.quote(c.ColumnName)
   156  		_ = u.buffer.WriteByte('=')
   157  		u.parameter(refVal.Interface())
   158  		has = true
   159  	}
   160  	if !has {
   161  		return errs.NewValueNotSetError()
   162  	}
   163  	return nil
   164  }
   165  
   166  // Set represents SET clause
   167  func (u *Updater[T]) Set(assigns ...Assignable) *Updater[T] {
   168  	u.assigns = assigns
   169  	return u
   170  }
   171  
   172  // Where represents WHERE clause
   173  func (u *Updater[T]) Where(predicates ...Predicate) *Updater[T] {
   174  	u.where = predicates
   175  	return u
   176  }
   177  
   178  // SkipNilValue 忽略 nil 值 columns
   179  func (u *Updater[T]) SkipNilValue() *Updater[T] {
   180  	u.ignoreNilVal = true
   181  	return u
   182  }
   183  
   184  func isNilValue(val reflect.Value) bool {
   185  	switch val.Kind() {
   186  	case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice:
   187  		return val.IsNil()
   188  	}
   189  	return false
   190  }
   191  
   192  // SkipZeroValue 忽略零值 columns
   193  func (u *Updater[T]) SkipZeroValue() *Updater[T] {
   194  	u.ignoreZeroVal = true
   195  	return u
   196  }
   197  
   198  func isZeroValue(val reflect.Value) bool {
   199  	return val.IsZero()
   200  }
   201  
   202  // Exec sql
   203  func (u *Updater[T]) Exec(ctx context.Context) Result {
   204  	query, err := u.Build()
   205  	if err != nil {
   206  		return Result{err: err}
   207  	}
   208  	return newQuerier[T](u.Session, query, u.meta, UPDATE).Exec(ctx)
   209  }