github.com/goravel/framework@v1.13.9/database/orm.go (about)

     1  package database
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"fmt"
     7  
     8  	"github.com/gookit/color"
     9  	"github.com/pkg/errors"
    10  
    11  	"github.com/goravel/framework/contracts/config"
    12  	ormcontract "github.com/goravel/framework/contracts/database/orm"
    13  	databasegorm "github.com/goravel/framework/database/gorm"
    14  	"github.com/goravel/framework/database/orm"
    15  )
    16  
// OrmImpl is the Orm implementation bound to one named database connection.
//
// OrmImpl values produced by Connection and WithContext reuse the same
// queries map as the value they were derived from, so a query initialized
// through one of them is cached for all of them.
type OrmImpl struct {
	// ctx is passed to databasegorm.InitializeQuery when a new connection is initialized.
	ctx        context.Context
	// config is read for "database.default" and passed to databasegorm.InitializeQuery.
	config     config.Config
	// connection is the name of the connection this instance is bound to.
	connection string
	// query is the query for connection; Query() returns it.
	query      ormcontract.Query
	// queries caches query instances by connection name; shared between derived instances.
	queries    map[string]ormcontract.Query
}
    24  
    25  func NewOrmImpl(ctx context.Context, config config.Config, connection string, query ormcontract.Query) (*OrmImpl, error) {
    26  	return &OrmImpl{
    27  		ctx:        ctx,
    28  		config:     config,
    29  		connection: connection,
    30  		query:      query,
    31  		queries: map[string]ormcontract.Query{
    32  			connection: query,
    33  		},
    34  	}, nil
    35  }
    36  
    37  func (r *OrmImpl) Connection(name string) ormcontract.Orm {
    38  	if name == "" {
    39  		name = r.config.GetString("database.default")
    40  	}
    41  	if instance, exist := r.queries[name]; exist {
    42  		return &OrmImpl{
    43  			ctx:        r.ctx,
    44  			config:     r.config,
    45  			connection: name,
    46  			query:      instance,
    47  			queries:    r.queries,
    48  		}
    49  	}
    50  
    51  	queue, err := databasegorm.InitializeQuery(r.ctx, r.config, name)
    52  	if err != nil || queue == nil {
    53  		color.Redln(fmt.Sprintf("[Orm] Init %s connection error: %v", name, err))
    54  
    55  		return nil
    56  	}
    57  
    58  	r.queries[name] = queue
    59  
    60  	return &OrmImpl{
    61  		ctx:        r.ctx,
    62  		config:     r.config,
    63  		connection: name,
    64  		query:      queue,
    65  		queries:    r.queries,
    66  	}
    67  }
    68  
    69  func (r *OrmImpl) DB() (*sql.DB, error) {
    70  	query := r.Query().(*databasegorm.QueryImpl)
    71  
    72  	return query.Instance().DB()
    73  }
    74  
    75  func (r *OrmImpl) Query() ormcontract.Query {
    76  	return r.query
    77  }
    78  
    79  func (r *OrmImpl) Factory() ormcontract.Factory {
    80  	return NewFactoryImpl(r.Query())
    81  }
    82  
    83  func (r *OrmImpl) Observe(model any, observer ormcontract.Observer) {
    84  	orm.Observers = append(orm.Observers, orm.Observer{
    85  		Model:    model,
    86  		Observer: observer,
    87  	})
    88  }
    89  
    90  func (r *OrmImpl) Transaction(txFunc func(tx ormcontract.Transaction) error) error {
    91  	tx, err := r.Query().Begin()
    92  	if err != nil {
    93  		return err
    94  	}
    95  
    96  	if err := txFunc(tx); err != nil {
    97  		if err := tx.Rollback(); err != nil {
    98  			return errors.Wrapf(err, "rollback error: %v", err)
    99  		}
   100  
   101  		return err
   102  	} else {
   103  		return tx.Commit()
   104  	}
   105  }
   106  
   107  func (r *OrmImpl) WithContext(ctx context.Context) ormcontract.Orm {
   108  	for _, query := range r.queries {
   109  		query := query.(*databasegorm.QueryImpl)
   110  		query.SetContext(ctx)
   111  	}
   112  
   113  	query := r.query.(*databasegorm.QueryImpl)
   114  	query.SetContext(ctx)
   115  
   116  	return &OrmImpl{
   117  		ctx:        ctx,
   118  		config:     r.config,
   119  		connection: r.connection,
   120  		query:      query,
   121  		queries:    r.queries,
   122  	}
   123  }