github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/json/json_set_test.go (about)

     1  // Copyright 2023 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package json
    16  
    17  import (
    18  	json2 "encoding/json"
    19  	"strconv"
    20  	"strings"
    21  	"testing"
    22  
    23  	"github.com/stretchr/testify/require"
    24  	"gopkg.in/src-d/go-errors.v1"
    25  
    26  	"github.com/dolthub/go-mysql-server/sql"
    27  	"github.com/dolthub/go-mysql-server/sql/expression"
    28  	"github.com/dolthub/go-mysql-server/sql/types"
    29  )
    30  
    31  func TestJSONSet(t *testing.T) {
    32  	_, err := NewJSONSet()
    33  	require.True(t, errors.Is(err, sql.ErrInvalidArgumentNumber))
    34  
    35  	_, err = NewJSONSet(
    36  		expression.NewGetField(0, types.LongText, "arg1", false),
    37  	)
    38  	require.True(t, errors.Is(err, sql.ErrInvalidArgumentNumber))
    39  
    40  	_, err = NewJSONSet(
    41  		expression.NewGetField(0, types.LongText, "arg1", false),
    42  		expression.NewGetField(1, types.LongText, "arg2", false),
    43  	)
    44  	require.True(t, errors.Is(err, sql.ErrInvalidArgumentNumber))
    45  
    46  	f1 := buildGetFieldExpressions(t, NewJSONSet, 3)
    47  
    48  	f2 := buildGetFieldExpressions(t, NewJSONSet, 5)
    49  
    50  	json := `{"a": 1, "b": [2, 3], "c": {"d": "foo"}}`
    51  
    52  	testCases := []struct {
    53  		f        sql.Expression
    54  		row      sql.Row
    55  		expected interface{}
    56  		err      error
    57  	}{
    58  		{f1, sql.Row{json, "$.a", 10.1}, `{"a": 10.1, "b": [2, 3], "c": {"d": "foo"}}`, nil},                                                           // update existing
    59  		{f1, sql.Row{json, "$.e", "new"}, `{"a": 1, "b": [2, 3], "c": {"d": "foo"},"e":"new"}`, nil},                                                   // set new
    60  		{f1, sql.Row{json, "$.c.d", "test"}, `{"a": 1, "b": [2, 3], "c": {"d": "test"}}`, nil},                                                         // update existing nested
    61  		{f2, sql.Row{json, "$.a", 10.1, "$.e", "new"}, `{"a": 10.1, "b": [2, 3], "c": {"d": "foo"},"e":"new"}`, nil},                                   // update existing and set new
    62  		{f1, sql.Row{json, "$.a.e", "test"}, `{"a": 1, "b": [2, 3], "c": {"d": "foo"}}`, nil},                                                          // set new nested does nothing
    63  		{f1, sql.Row{json, "$.c.e", "test"}, `{"a": 1, "b": [2, 3], "c": {"d": "foo","e":"test"}}`, nil},                                               // set new nested in existing struct
    64  		{f1, sql.Row{json, "$.c[5]", 4.1}, `{"a": 1, "b": [2, 3], "c": [{"d": "foo"}, 4.1]}`, nil},                                                     // update struct with indexing out of range
    65  		{f1, sql.Row{json, "$.b[0]", 4.1}, `{"a": 1, "b": [4.1, 3], "c": {"d": "foo"}}`, nil},                                                          // update element in array
    66  		{f1, sql.Row{json, "$.b[5]", 4.1}, `{"a": 1, "b": [2, 3, 4.1], "c": {"d": "foo"}}`, nil},                                                       // update element in array out of range
    67  		{f1, sql.Row{json, "$.b.c", 4}, `{"a": 1, "b": [2, 3], "c": {"d": "foo"}}`, nil},                                                               // set nested in array does nothing
    68  		{f1, sql.Row{json, "$.a[0]", 4.1}, `{"a": 4.1, "b": [2, 3], "c": {"d": "foo"}}`, nil},                                                          // update single element with indexing
    69  		{f1, sql.Row{json, "$[0]", 4.1}, `4.1`, nil},                                                                                                   // struct indexing
    70  		{f1, sql.Row{json, "$.[0]", 4.1}, nil, ErrInvalidPath},                                                                                         // improper struct indexing
    71  		{f1, sql.Row{json, "foo", "test"}, nil, ErrInvalidPath},                                                                                        // invalid path
    72  		{f1, sql.Row{json, "$.c.*", "test"}, nil, ErrPathWildcard},                                                                                     // path contains * wildcard
    73  		{f1, sql.Row{json, "$.c.**", "test"}, nil, ErrPathWildcard},                                                                                    // path contains ** wildcard
    74  		{f1, sql.Row{json, "$", 10.1}, `10.1`, nil},                                                                                                    // whole document
    75  		{f1, sql.Row{nil, "$", 42.7}, nil, nil},                                                                                                        // null document
    76  		{f1, sql.Row{json, nil, 10}, nil, nil},                                                                                                         // if any path is null, return null
    77  		{f2, sql.Row{json, "$.z", map[string]interface{}{"zz": 1.1}, "$.z.zz", 42.1}, `{"a": 1, "b": [2, 3], "c": {"d": "foo"},"z":{"zz":42.1}}`, nil}, // accumulates L->R
    78  
    79  		// mysql> select JSON_SET(JSON_ARRAY(), "$[2]", 1 , "$[2]", 2 ,"$[2]", 3 ,"$[2]", 4);
    80  		// +---------------------------------------------------------------------+
    81  		// | JSON_SET(JSON_ARRAY(), "$[2]", 1 , "$[2]", 2 ,"$[2]", 3 ,"$[2]", 4) |
    82  		// +---------------------------------------------------------------------+
    83  		// | [1, 2, 4]                                                           |
    84  		// +---------------------------------------------------------------------+
    85  		{buildGetFieldExpressions(t, NewJSONSet, 9),
    86  			sql.Row{`[]`,
    87  				"$[2]", 1.1, // [] -> [1.1]
    88  				"$[2]", 2.2, // [1.1] -> [1.1,2.2]
    89  				"$[2]", 3.3, // [1.1, 2.2] -> [1.1, 2.2, 3.3]
    90  				"$[2]", 4.4}, // [1.1, 2.2, 3.3] -> [1.1, 2.2, 4.4]
    91  			`[1.1, 2.2, 4.4]`, nil},
    92  	}
    93  
    94  	for _, tt := range testCases {
    95  		var paths []string
    96  		for _, path := range tt.row[1:] {
    97  			if _, ok := path.(string); ok {
    98  				paths = append(paths, path.(string))
    99  			} else {
   100  				if path == nil {
   101  					paths = append(paths, "null")
   102  				} else if _, ok := path.(int); ok {
   103  					paths = append(paths, strconv.Itoa(path.(int)))
   104  				} else {
   105  					m, _ := json2.Marshal(path)
   106  					paths = append(paths, string(m))
   107  				}
   108  			}
   109  		}
   110  
   111  		t.Run(tt.f.String()+"."+strings.Join(paths, ","), func(t *testing.T) {
   112  			require := require.New(t)
   113  			result, err := tt.f.Eval(sql.NewEmptyContext(), tt.row)
   114  			if tt.err == nil {
   115  				require.NoError(err)
   116  
   117  				var expect interface{}
   118  				if tt.expected != nil {
   119  					expect, _, err = types.JSON.Convert(tt.expected)
   120  					if err != nil {
   121  						panic("Bad test string. Can't convert string to JSONDocument: " + tt.expected.(string))
   122  					}
   123  				}
   124  
   125  				require.Equal(expect, result)
   126  			} else {
   127  				require.Error(tt.err, err)
   128  			}
   129  		})
   130  	}
   131  }
   132  
   133  func buildGetFieldExpressions(t *testing.T, construct func(...sql.Expression) (sql.Expression, error), argCount int) sql.Expression {
   134  	expressions := make([]sql.Expression, 0, argCount)
   135  	for i := 0; i < argCount; i++ {
   136  		expressions = append(expressions, expression.NewGetField(i, types.LongText, "arg"+strconv.Itoa(i), false))
   137  	}
   138  
   139  	result, err := construct(expressions...)
   140  	require.NoError(t, err)
   141  
   142  	return result
   143  }