github.com/mdaxf/iac@v0.0.0-20240519030858-58a061660378/databases/orm/orm_object.go (about) 1 // The original package is migrated from beego and modified, you can find orignal from following link: 2 // "github.com/beego/beego/" 3 // 4 // Copyright 2023 IAC. All Rights Reserved. 5 // 6 // Licensed under the Apache License, Version 2.0 (the "License"); 7 // you may not use this file except in compliance with the License. 8 // You may obtain a copy of the License at 9 // 10 // http://www.apache.org/licenses/LICENSE-2.0 11 // 12 // Unless required by applicable law or agreed to in writing, software 13 // distributed under the License is distributed on an "AS IS" BASIS, 14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 // See the License for the specific language governing permissions and 16 // limitations under the License. 17 18 package orm 19 20 import ( 21 "context" 22 "fmt" 23 "reflect" 24 ) 25 26 // an insert queryer struct 27 type insertSet struct { 28 mi *modelInfo 29 orm *ormBase 30 stmt stmtQuerier 31 closed bool 32 } 33 34 var _ Inserter = new(insertSet) 35 36 // insert model ignore it's registered or not. 37 func (o *insertSet) Insert(md interface{}) (int64, error) { 38 return o.InsertWithCtx(context.Background(), md) 39 } 40 41 func (o *insertSet) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) { 42 if o.closed { 43 return 0, ErrStmtClosed 44 } 45 val := reflect.ValueOf(md) 46 ind := reflect.Indirect(val) 47 typ := ind.Type() 48 name := getFullName(typ) 49 if val.Kind() != reflect.Ptr { 50 panic(fmt.Errorf("<Inserter.Insert> cannot use non-ptr model struct `%s`", name)) 51 } 52 if name != o.mi.fullName { 53 panic(fmt.Errorf("<Inserter.Insert> need model `%s` but found `%s`", o.mi.fullName, name)) 54 } 55 id, err := o.orm.alias.DbBaser.InsertStmt(ctx, o.stmt, o.mi, ind, o.orm.alias.TZ) 56 if err != nil { 57 return id, err 58 } 59 if id > 0 { 60 if o.mi.fields.pk.auto { 61 if o.mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { 62 ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetUint(uint64(id)) 63 } else { 64 ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetInt(id) 65 } 66 } 67 } 68 return id, nil 69 } 70 71 // close insert queryer statement 72 func (o *insertSet) Close() error { 73 if o.closed { 74 return ErrStmtClosed 75 } 76 o.closed = true 77 return o.stmt.Close() 78 } 79 80 // create new insert queryer. 81 func newInsertSet(ctx context.Context, orm *ormBase, mi *modelInfo) (Inserter, error) { 82 bi := new(insertSet) 83 bi.orm = orm 84 bi.mi = mi 85 st, query, err := orm.alias.DbBaser.PrepareInsert(ctx, orm.db, mi) 86 if err != nil { 87 return nil, err 88 } 89 if Debug { 90 bi.stmt = newStmtQueryLog(orm.alias, st, query) 91 } else { 92 bi.stmt = st 93 } 94 return bi, nil 95 }