github.com/acoshift/pgsql@v0.15.3/pgstmt/update_test.go (about)

     1  package pgstmt_test
     2  
     3  import (
     4  	"testing"
     5  	"time"
     6  
     7  	"github.com/stretchr/testify/assert"
     8  
     9  	"github.com/acoshift/pgsql/pgstmt"
    10  )
    11  
    12  func TestUpdate(t *testing.T) {
    13  	t.Parallel()
    14  
    15  	t.Run("update", func(t *testing.T) {
    16  		q, args := pgstmt.Update(func(b pgstmt.UpdateStatement) {
    17  			b.Table("users")
    18  			b.Set("name").To("test")
    19  			b.Set("email", "address", "updated_at").To("test@localhost", "123", pgstmt.Raw("now()"))
    20  			b.Set("age").ToRaw(1)
    21  			b.Where(func(b pgstmt.Cond) {
    22  				b.Eq("id", 5)
    23  			})
    24  			b.Returning("id", "name")
    25  		}).SQL()
    26  
    27  		assert.Equal(t,
    28  			stripSpace(`
    29  				update users
    30  				set name = $1,
    31  					(email, address, updated_at) = row($2, $3, now()),
    32  					age = 1
    33  				where (id = $4)
    34  				returning id, name
    35  			`),
    36  			q,
    37  		)
    38  		assert.EqualValues(t,
    39  			[]any{
    40  				"test",
    41  				"test@localhost", "123",
    42  				5,
    43  			},
    44  			args,
    45  		)
    46  	})
    47  
    48  	t.Run("update set select", func(t *testing.T) {
    49  		q, args := pgstmt.Update(func(b pgstmt.UpdateStatement) {
    50  			b.Table("users")
    51  			b.Set("name", "age", "updated_at").Select(func(b pgstmt.SelectStatement) {
    52  				b.Columns("name", "age", "now()")
    53  				b.From("users")
    54  				b.Where(func(b pgstmt.Cond) {
    55  					b.Eq("id", 6)
    56  				})
    57  			})
    58  			b.Set("updated_count").ToRaw("updated_count + 1")
    59  			b.Set("email", "address").To("test@localhost", "123")
    60  			b.Where(func(b pgstmt.Cond) {
    61  				b.Eq("id", 5)
    62  			})
    63  		}).SQL()
    64  
    65  		assert.Equal(t,
    66  			stripSpace(`
    67  				update users
    68  				set (name, age, updated_at) = (select name, age, now()
    69  											   from users
    70  											   where (id = $1)),
    71  					updated_count = updated_count + 1,
    72  					(email, address) = row($2, $3)
    73  				where (id = $4)
    74  			`),
    75  			q,
    76  		)
    77  		assert.EqualValues(t,
    78  			[]any{
    79  				6,
    80  				"test@localhost", "123",
    81  				5,
    82  			},
    83  			args,
    84  		)
    85  	})
    86  
    87  	t.Run("update from join", func(t *testing.T) {
    88  		q, args := pgstmt.Update(func(b pgstmt.UpdateStatement) {
    89  			b.Table("users")
    90  			b.Set("name").ToRaw("p.name")
    91  			b.Set("address").ToRaw("p.address")
    92  			b.Set("updated_at").ToRaw("now()")
    93  			b.Set("date").To(pgstmt.NotArg(time.Date(2022, 1, 2, 3, 4, 5, 6, time.UTC)))
    94  			b.From("users")
    95  			b.InnerJoin("profiles p").Using("email")
    96  			b.Where(func(b pgstmt.Cond) {
    97  				b.Eq("users.id", 2)
    98  			})
    99  		}).SQL()
   100  
   101  		assert.Equal(t,
   102  			stripSpace(`
   103  				update users
   104  				set name = p.name,
   105  					address = p.address,
   106  					updated_at = now(),
   107  					date = '2022-01-02 03:04:05.000000006Z'
   108  				from users
   109  				inner join profiles p using (email)
   110  				where (users.id = $1)
   111  			`),
   112  			q,
   113  		)
   114  		assert.EqualValues(t,
   115  			[]any{
   116  				2,
   117  			},
   118  			args,
   119  		)
   120  	})
   121  }