github.com/acoshift/pgsql@v0.15.3/pgmodel/do.go (about)

     1  package pgmodel
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"reflect"
     7  
     8  	"github.com/acoshift/pgsql"
     9  	"github.com/acoshift/pgsql/pgstmt"
    10  )
    11  
    12  func Do(ctx context.Context, model any, filter ...Filter) error {
    13  	var err error
    14  	switch m := model.(type) {
    15  	case Selector:
    16  		stmt := pgstmt.Select(func(b pgstmt.SelectStatement) {
    17  			m.Select(b)
    18  			for _, f := range filter {
    19  				err = f.Apply(ctx, b)
    20  				if err != nil {
    21  					return
    22  				}
    23  			}
    24  		})
    25  		if err != nil {
    26  			return err
    27  		}
    28  		return m.Scan(stmt.QueryRowWith(ctx).Scan)
    29  	case Inserter:
    30  		stmt := pgstmt.Insert(func(b pgstmt.InsertStatement) {
    31  			m.Insert(b)
    32  		})
    33  
    34  		if scanner, ok := m.(Scanner); ok {
    35  			return scanner.Scan(stmt.QueryRowWith(ctx).Scan)
    36  		}
    37  		_, err := stmt.ExecWith(ctx)
    38  		return err
    39  	case Updater:
    40  		stmt := pgstmt.Update(func(b pgstmt.UpdateStatement) {
    41  			m.Update(b)
    42  			for _, f := range filter {
    43  				err = f.Apply(ctx, condUpdateWrapper{b})
    44  				if err != nil {
    45  					return
    46  				}
    47  			}
    48  		})
    49  		if err != nil {
    50  			return err
    51  		}
    52  
    53  		if scanner, ok := m.(Scanner); ok {
    54  			return scanner.Scan(stmt.QueryRowWith(ctx).Scan)
    55  		}
    56  		_, err := stmt.ExecWith(ctx)
    57  		return err
    58  	}
    59  
    60  	// *[]*model => []*model => *model => model
    61  	rf := reflect.ValueOf(model).Elem()
    62  	typeSlice := rf.Type()
    63  	typeElem := typeSlice.Elem().Elem()
    64  	rs := reflect.MakeSlice(typeSlice, 0, 0)
    65  	m := reflect.New(typeElem).Interface()
    66  
    67  	if m, ok := m.(Selector); ok {
    68  		stmt := pgstmt.Select(func(b pgstmt.SelectStatement) {
    69  			m.Select(b)
    70  			for _, f := range filter {
    71  				err = f.Apply(ctx, b)
    72  				if err != nil {
    73  					return
    74  				}
    75  			}
    76  		})
    77  		if err != nil {
    78  			return err
    79  		}
    80  
    81  		err = stmt.IterWith(ctx, func(scan pgsql.Scanner) error {
    82  			rx := reflect.New(typeElem)
    83  			err := rx.Interface().(Selector).Scan(scan)
    84  			if err != nil {
    85  				return err
    86  			}
    87  			rs = reflect.Append(rs, rx)
    88  			return nil
    89  		})
    90  		if err != nil {
    91  			return err
    92  		}
    93  		rf.Set(rs)
    94  		return nil
    95  	}
    96  
    97  	return fmt.Errorf("not implement")
    98  }