github.com/mdaxf/iac@v0.0.0-20240519030858-58a061660378/databases/orm/orm_object.go (about)

     1  // The original package was migrated from beego and modified; the original can be found at the following link:
     2  //    "github.com/beego/beego/"
     3  //
     4  // Copyright 2023 IAC. All Rights Reserved.
     5  //
     6  // Licensed under the Apache License, Version 2.0 (the "License");
     7  // you may not use this file except in compliance with the License.
     8  // You may obtain a copy of the License at
     9  //
    10  //      http://www.apache.org/licenses/LICENSE-2.0
    11  //
    12  // Unless required by applicable law or agreed to in writing, software
    13  // distributed under the License is distributed on an "AS IS" BASIS,
    14  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    15  // See the License for the specific language governing permissions and
    16  // limitations under the License.
    17  
    18  package orm
    19  
    20  import (
    21  	"context"
    22  	"fmt"
    23  	"reflect"
    24  )
    25  
    26  // an insert queryer struct
    27  type insertSet struct {
    28  	mi     *modelInfo
    29  	orm    *ormBase
    30  	stmt   stmtQuerier
    31  	closed bool
    32  }
    33  
    34  var _ Inserter = new(insertSet)
    35  
    36  // insert model ignore it's registered or not.
    37  func (o *insertSet) Insert(md interface{}) (int64, error) {
    38  	return o.InsertWithCtx(context.Background(), md)
    39  }
    40  
    41  func (o *insertSet) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) {
    42  	if o.closed {
    43  		return 0, ErrStmtClosed
    44  	}
    45  	val := reflect.ValueOf(md)
    46  	ind := reflect.Indirect(val)
    47  	typ := ind.Type()
    48  	name := getFullName(typ)
    49  	if val.Kind() != reflect.Ptr {
    50  		panic(fmt.Errorf("<Inserter.Insert> cannot use non-ptr model struct `%s`", name))
    51  	}
    52  	if name != o.mi.fullName {
    53  		panic(fmt.Errorf("<Inserter.Insert> need model `%s` but found `%s`", o.mi.fullName, name))
    54  	}
    55  	id, err := o.orm.alias.DbBaser.InsertStmt(ctx, o.stmt, o.mi, ind, o.orm.alias.TZ)
    56  	if err != nil {
    57  		return id, err
    58  	}
    59  	if id > 0 {
    60  		if o.mi.fields.pk.auto {
    61  			if o.mi.fields.pk.fieldType&IsPositiveIntegerField > 0 {
    62  				ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetUint(uint64(id))
    63  			} else {
    64  				ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetInt(id)
    65  			}
    66  		}
    67  	}
    68  	return id, nil
    69  }
    70  
    71  // close insert queryer statement
    72  func (o *insertSet) Close() error {
    73  	if o.closed {
    74  		return ErrStmtClosed
    75  	}
    76  	o.closed = true
    77  	return o.stmt.Close()
    78  }
    79  
    80  // create new insert queryer.
    81  func newInsertSet(ctx context.Context, orm *ormBase, mi *modelInfo) (Inserter, error) {
    82  	bi := new(insertSet)
    83  	bi.orm = orm
    84  	bi.mi = mi
    85  	st, query, err := orm.alias.DbBaser.PrepareInsert(ctx, orm.db, mi)
    86  	if err != nil {
    87  		return nil, err
    88  	}
    89  	if Debug {
    90  		bi.stmt = newStmtQueryLog(orm.alias, st, query)
    91  	} else {
    92  		bi.stmt = st
    93  	}
    94  	return bi, nil
    95  }