github.com/scraniel/migrate@v0.0.0-20230320185700-339088f36cee/database/multistmt/parse_test.go (about)

     1  package multistmt_test
     2  
import (
	"bufio"
	"strings"
	"testing"

	"github.com/stretchr/testify/assert"

	"github.com/golang-migrate/migrate/v4/database/multistmt"
)
    11  
    12  const maxMigrationSize = 1024
    13  
    14  func TestParse(t *testing.T) {
    15  	testCases := []struct {
    16  		name        string
    17  		multiStmt   string
    18  		delimiter   string
    19  		expected    []string
    20  		expectedErr error
    21  	}{
    22  		{name: "single statement, no delimiter", multiStmt: "single statement, no delimiter", delimiter: ";",
    23  			expected: []string{"single statement, no delimiter"}, expectedErr: nil},
    24  		{name: "single statement, one delimiter", multiStmt: "single statement, one delimiter;", delimiter: ";",
    25  			expected: []string{"single statement, one delimiter;"}, expectedErr: nil},
    26  		{name: "two statements, no trailing delimiter", multiStmt: "statement one; statement two", delimiter: ";",
    27  			expected: []string{"statement one;", " statement two"}, expectedErr: nil},
    28  		{name: "two statements, with trailing delimiter", multiStmt: "statement one; statement two;", delimiter: ";",
    29  			expected: []string{"statement one;", " statement two;"}, expectedErr: nil},
    30  	}
    31  
    32  	for _, tc := range testCases {
    33  		t.Run(tc.name, func(t *testing.T) {
    34  			stmts := make([]string, 0, len(tc.expected))
    35  			err := multistmt.Parse(strings.NewReader(tc.multiStmt), []byte(tc.delimiter), maxMigrationSize, func(b []byte) bool {
    36  				stmts = append(stmts, string(b))
    37  				return true
    38  			})
    39  			assert.Equal(t, tc.expectedErr, err)
    40  			assert.Equal(t, tc.expected, stmts)
    41  		})
    42  	}
    43  }
    44  
    45  func TestParseDiscontinue(t *testing.T) {
    46  	multiStmt := "statement one; statement two"
    47  	delimiter := ";"
    48  	expected := []string{"statement one;"}
    49  
    50  	stmts := make([]string, 0, len(expected))
    51  	err := multistmt.Parse(strings.NewReader(multiStmt), []byte(delimiter), maxMigrationSize, func(b []byte) bool {
    52  		stmts = append(stmts, string(b))
    53  		return false
    54  	})
    55  	assert.Nil(t, err)
    56  	assert.Equal(t, expected, stmts)
    57  }