github.com/unionj-cloud/go-doudou/v2@v2.3.5/toolkit/gormgen/internal/template/query.go (about)

     1  package template
     2  
     3  // DefaultQuery is the Go template for the package-level entry points of the
        // generated query package. For every element $d of .Data (which must expose
        // ModelStructName and QueryStructName) it emits:
        //   - a var block declaring Q = new(Query) plus one variable per model, named
        //     {{$d.ModelStructName}} and typed *{{$d.QueryStructName}};
        //   - SetDefault(db, opts...), which overwrites *Q with *Use(db, opts...) and
        //     points each model variable at the matching field of Q.
        //
        // SetDefault assigns through the existing pointer (*Q = ...) instead of
        // replacing Q, so any copy of the Q pointer taken earlier sees the new DB.
        // Use and Query are emitted by the QueryMethod template. The "-}}" markers
        // trim the whitespace that follows each range/end action.
     4  const DefaultQuery = `
     5  var (
     6  	Q =new(Query)
     7  	{{range $name,$d :=.Data -}}
     8  	{{$d.ModelStructName}} *{{$d.QueryStructName}}
     9  	{{end -}}
    10  )
    11  
    12  func SetDefault(db *gorm.DB, opts ...gormgen.DOOption) {
    13  	*Q = *Use(db,opts...)
    14  	{{range $name,$d :=.Data -}}
    15  	{{$d.ModelStructName}} = &Q.{{$d.ModelStructName}}
    16  	{{end -}}
    17  }
    18  
    19  `
    20  
    21  // QueryMethod is the Go template for the Query type and its methods in the
        // generated query package. Each element $d of .Data must expose
        // ModelStructName, QueryStructName and ReturnObject. The template emits:
        //   - Use(db, opts...), which builds a Query holding db and one field per model,
        //     initialised with new<ModelStructName>(db, opts...);
        //   - the Query struct: the *gorm.DB plus one value-typed field per model;
        //   - Available (db != nil) and Db (returns the stored *gorm.DB);
        //   - Clone and ReplaceDB, which both rebuild a Query on a new *gorm.DB (via the
        //     per-model clone and replaceDB methods respectively);
        //   - ReadDB and WriteDB, which call ReplaceDB with the dbresolver.Read or
        //     dbresolver.Write clause applied;
        //   - queryCtx and WithContext, which collect each model's WithContext(ctx)
        //     result (typed ReturnObject);
        //   - Transaction, which runs fc against a Query cloned onto the transaction DB;
        //   - Begin, which returns a QueryTx wrapping the started transaction together
        //     with the error from starting it. Callers should check QueryTx.Error
        //     before using the transaction.
        //   - QueryTx.Commit, Rollback, SavePoint and RollbackTo, each delegating to the
        //     wrapped *gorm.DB and returning its Error.
        //
        // new<Model>, clone, replaceDB and the per-model WithContext are not defined in
        // this template. They are presumably emitted by the per-model templates; confirm.
    22  const QueryMethod = `
    23  func Use(db *gorm.DB, opts ...gormgen.DOOption) *Query {
    24  	return &Query{
    25  		db: db,
    26  		{{range $name,$d :=.Data -}}
    27  		{{$d.ModelStructName}}: new{{$d.ModelStructName}}(db,opts...),
    28  		{{end -}}
    29  	}
    30  }
    31  
    32  type Query struct{
    33  	db *gorm.DB
    34  
    35  	{{range $name,$d :=.Data -}}
    36  	{{$d.ModelStructName}} {{$d.QueryStructName}}
    37  	{{end}}
    38  }
    39  
    40  func (q *Query) Available() bool { return q.db != nil }
    41  
    42  func (q *Query) Clone(db *gorm.DB) *Query {
    43  	return &Query{
    44  		db: db,
    45  		{{range $name,$d :=.Data -}}
    46  		{{$d.ModelStructName}}: q.{{$d.ModelStructName}}.clone(db),
    47  		{{end}}
    48  	}
    49  }
    50  
    51  func (q *Query) Db() *gorm.DB {
    52  	return q.db
    53  }
    54  
    55  func (q *Query) ReadDB() *Query {
    56  	return q.ReplaceDB(q.db.Clauses(dbresolver.Read))
    57  }
    58  
    59  func (q *Query) WriteDB() *Query {
    60  	return q.ReplaceDB(q.db.Clauses(dbresolver.Write))
    61  }
    62  
    63  func (q *Query) ReplaceDB(db *gorm.DB) *Query {
    64  	return &Query{
    65  		db: db,
    66  		{{range $name,$d :=.Data -}}
    67  		{{$d.ModelStructName}}: q.{{$d.ModelStructName}}.replaceDB(db),
    68  		{{end}}
    69  	}
    70  }
    71  
    72  type queryCtx struct{ 
    73  	{{range $name,$d :=.Data -}}
    74  	{{$d.ModelStructName}} {{$d.ReturnObject}}
    75  	{{end}}
    76  }
    77  
    78  func (q *Query) WithContext(ctx context.Context) *queryCtx  {
    79  	return &queryCtx{
    80  		{{range $name,$d :=.Data -}}
    81  		{{$d.ModelStructName}}: q.{{$d.ModelStructName}}.WithContext(ctx),
    82  		{{end}}
    83  	}
    84  }
    85  
    86  func (q *Query) Transaction(fc func(tx *Query) error, opts ...*sql.TxOptions) error {
    87  	return q.db.Transaction(func(tx *gorm.DB) error { return fc(q.Clone(tx)) }, opts...)
    88  }
    89  
    90  func (q *Query) Begin(opts ...*sql.TxOptions) *QueryTx {
    91  	tx := q.db.Begin(opts...)
    92  	return &QueryTx{Query: q.Clone(tx), Error: tx.Error}
    93  }
    94  
    95  type QueryTx struct {
    96  	*Query
    97  	Error error
    98  }
    99  
   100  func (q *QueryTx) Commit() error {
   101  	return q.db.Commit().Error
   102  }
   103  
   104  func (q *QueryTx) Rollback() error {
   105  	return q.db.Rollback().Error
   106  }
   107  
   108  func (q *QueryTx) SavePoint(name string) error {
   109  	return q.db.SavePoint(name).Error
   110  }
   111  
   112  func (q *QueryTx) RollbackTo(name string) error {
   113  	return q.db.RollbackTo(name).Error
   114  }
   115  
   116  `
   117  
   118  // QueryMethodTest is the Go template for the generated query package's test
        // file. The emitted code opens a SQLite database file named "gen_test.db" once
        // (guarded by sync.Once), panicking if the open fails. The path is relative, so
        // it is resolved against the test's working directory. init then auto-migrates
        // a helper table "another_for_unit_test". The template also emits:
        //   - assert, a reflect.DeepEqual-based comparison helper;
        //   - Test_Available, which checks Use(db).Available();
        //   - Test_WithContext, which checks that a value stored in the context passed to
        //     WithContext is visible through every model's
        //     UnderlyingDB().Statement.Context;
        //   - Test_Transaction, which exercises Transaction, Begin, SavePoint, RollbackTo,
        //     Commit and Rollback.
        //
        // The _another struct tag is spliced in with + "`" + because a Go raw string
        // literal cannot contain a backtick.
        //
        // NOTE(review): the generated init ignores the error returned by db.AutoMigrate.
        // The Available checks use t.Errorf rather than t.Fatalf, so the test keeps
        // running after that failure. The QueryTx.Error returned by Begin is never
        // checked.
   119  const QueryMethodTest = `
   120  
   121  const dbName = "gen_test.db"
   122  
   123  var db *gorm.DB
   124  var once sync.Once
   125  
   126  func init() {
   127  	InitializeDB()
   128  	db.AutoMigrate(&_another{})
   129  }
   130  
   131  func InitializeDB() {
   132  	once.Do(func() {
   133  		var err error
   134  		db, err = gorm.Open(sqlite.Open(dbName), &gorm.Config{})
   135  		if err != nil {
   136  			panic(fmt.Errorf("open sqlite %q fail: %w", dbName, err))
   137  		}
   138  	})
   139  }
   140  
   141  func assert(t *testing.T, methodName string, res, exp interface{}) {
   142  	if !reflect.DeepEqual(res, exp) {
   143  		t.Errorf("%v() gotResult = %v, want %v", methodName, res, exp)
   144  	}
   145  }
   146  
   147  type _another struct {
   148  	ID uint64 ` + "`" + `gorm:"primaryKey"` + "`" + `
   149  }
   150  
   151  func (*_another) TableName() string { return "another_for_unit_test" }
   152  
   153  func Test_Available(t *testing.T) {
   154  	if !Use(db).Available() {
   155  		t.Errorf("query.Available() == false")
   156  	}
   157  }
   158  
   159  func Test_WithContext(t *testing.T) {
   160  	query := Use(db)
   161  	if !query.Available() {
   162  		t.Errorf("query Use(db) fail: query.Available() == false")
   163  	}
   164  
   165  	type Content string
   166  	var key, value Content = "gen_tag", "unit_test"
   167  	qCtx := query.WithContext(context.WithValue(context.Background(), key, value))
   168  
   169  	for _, ctx := range []context.Context{
   170  		{{range $name,$d :=.Data -}}
   171  		qCtx.{{$d.ModelStructName}}.UnderlyingDB().Statement.Context,
   172  		{{end}}
   173  	} {
   174  		if v := ctx.Value(key); v != value {
   175  			t.Errorf("get value from context fail, expect %q, got %q", value, v)
   176  		}
   177  	}
   178  }
   179  
   180  func Test_Transaction(t *testing.T) {
   181  	query := Use(db)
   182  	if !query.Available() {
   183  		t.Errorf("query Use(db) fail: query.Available() == false")
   184  	}
   185  
   186  	err := query.Transaction(func(tx *Query) error { return nil })
   187  	if err != nil {
   188  		t.Errorf("query.Transaction execute fail: %s", err)
   189  	}
   190  
   191  	tx := query.Begin()
   192  
   193  	err = tx.SavePoint("point")
   194  	if err != nil {
   195  		t.Errorf("query tx SavePoint fail: %s", err)
   196  	}
   197  	err = tx.RollbackTo("point")
   198  	if err != nil {
   199  		t.Errorf("query tx RollbackTo fail: %s", err)
   200  	}
   201  	err = tx.Commit()
   202  	if err != nil {
   203  		t.Errorf("query tx Commit fail: %s", err)
   204  	}
   205  
   206  	err = query.Begin().Rollback()
   207  	if err != nil {
   208  		t.Errorf("query tx Rollback fail: %s", err)
   209  	}
   210  }
   211  `