github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/json/json_set_test.go (about) 1 // Copyright 2023 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package json 16 17 import ( 18 json2 "encoding/json" 19 "strconv" 20 "strings" 21 "testing" 22 23 "github.com/stretchr/testify/require" 24 "gopkg.in/src-d/go-errors.v1" 25 26 "github.com/dolthub/go-mysql-server/sql" 27 "github.com/dolthub/go-mysql-server/sql/expression" 28 "github.com/dolthub/go-mysql-server/sql/types" 29 ) 30 31 func TestJSONSet(t *testing.T) { 32 _, err := NewJSONSet() 33 require.True(t, errors.Is(err, sql.ErrInvalidArgumentNumber)) 34 35 _, err = NewJSONSet( 36 expression.NewGetField(0, types.LongText, "arg1", false), 37 ) 38 require.True(t, errors.Is(err, sql.ErrInvalidArgumentNumber)) 39 40 _, err = NewJSONSet( 41 expression.NewGetField(0, types.LongText, "arg1", false), 42 expression.NewGetField(1, types.LongText, "arg2", false), 43 ) 44 require.True(t, errors.Is(err, sql.ErrInvalidArgumentNumber)) 45 46 f1 := buildGetFieldExpressions(t, NewJSONSet, 3) 47 48 f2 := buildGetFieldExpressions(t, NewJSONSet, 5) 49 50 json := `{"a": 1, "b": [2, 3], "c": {"d": "foo"}}` 51 52 testCases := []struct { 53 f sql.Expression 54 row sql.Row 55 expected interface{} 56 err error 57 }{ 58 {f1, sql.Row{json, "$.a", 10.1}, `{"a": 10.1, "b": [2, 3], "c": {"d": "foo"}}`, nil}, // update existing 59 {f1, sql.Row{json, "$.e", "new"}, `{"a": 1, "b": [2, 3], "c": {"d": "foo"},"e":"new"}`, nil}, // set new 60 {f1, sql.Row{json, "$.c.d", "test"}, `{"a": 1, "b": [2, 3], "c": {"d": "test"}}`, nil}, // update existing nested 61 {f2, sql.Row{json, "$.a", 10.1, "$.e", "new"}, `{"a": 10.1, "b": [2, 3], "c": {"d": "foo"},"e":"new"}`, nil}, // update existing and set new 62 {f1, sql.Row{json, "$.a.e", "test"}, `{"a": 1, "b": [2, 3], "c": {"d": "foo"}}`, nil}, // set new nested does nothing 63 {f1, sql.Row{json, "$.c.e", "test"}, `{"a": 1, "b": [2, 3], "c": {"d": "foo","e":"test"}}`, nil}, // set new nested in existing struct 64 {f1, sql.Row{json, "$.c[5]", 4.1}, `{"a": 1, "b": [2, 3], "c": [{"d": "foo"}, 4.1]}`, nil}, // update struct with indexing out of range 65 {f1, sql.Row{json, "$.b[0]", 4.1}, `{"a": 1, "b": [4.1, 3], "c": {"d": "foo"}}`, nil}, // update element in array 66 {f1, sql.Row{json, "$.b[5]", 4.1}, `{"a": 1, "b": [2, 3, 4.1], "c": {"d": "foo"}}`, nil}, // update element in array out of range 67 {f1, sql.Row{json, "$.b.c", 4}, `{"a": 1, "b": [2, 3], "c": {"d": "foo"}}`, nil}, // set nested in array does nothing 68 {f1, sql.Row{json, "$.a[0]", 4.1}, `{"a": 4.1, "b": [2, 3], "c": {"d": "foo"}}`, nil}, // update single element with indexing 69 {f1, sql.Row{json, "$[0]", 4.1}, `4.1`, nil}, // struct indexing 70 {f1, sql.Row{json, "$.[0]", 4.1}, nil, ErrInvalidPath}, // improper struct indexing 71 {f1, sql.Row{json, "foo", "test"}, nil, ErrInvalidPath}, // invalid path 72 {f1, sql.Row{json, "$.c.*", "test"}, nil, ErrPathWildcard}, // path contains * wildcard 73 {f1, sql.Row{json, "$.c.**", "test"}, nil, ErrPathWildcard}, // path contains ** wildcard 74 {f1, sql.Row{json, "$", 10.1}, `10.1`, nil}, // whole document 75 {f1, sql.Row{nil, "$", 42.7}, nil, nil}, // null document 76 {f1, sql.Row{json, nil, 10}, nil, nil}, // if any path is null, return null 77 {f2, sql.Row{json, "$.z", map[string]interface{}{"zz": 1.1}, "$.z.zz", 42.1}, `{"a": 1, "b": [2, 3], "c": {"d": "foo"},"z":{"zz":42.1}}`, nil}, // accumulates L->R 78 79 // mysql> select JSON_SET(JSON_ARRAY(), "$[2]", 1 , "$[2]", 2 ,"$[2]", 3 ,"$[2]", 4); 80 // +---------------------------------------------------------------------+ 81 // | JSON_SET(JSON_ARRAY(), "$[2]", 1 , "$[2]", 2 ,"$[2]", 3 ,"$[2]", 4) | 82 // +---------------------------------------------------------------------+ 83 // | [1, 2, 4] | 84 // +---------------------------------------------------------------------+ 85 {buildGetFieldExpressions(t, NewJSONSet, 9), 86 sql.Row{`[]`, 87 "$[2]", 1.1, // [] -> [1.1] 88 "$[2]", 2.2, // [1.1] -> [1.1,2.2] 89 "$[2]", 3.3, // [1.1, 2.2] -> [1.1, 2.2, 3.3] 90 "$[2]", 4.4}, // [1.1, 2.2, 3.3] -> [1.1, 2.2, 4.4] 91 `[1.1, 2.2, 4.4]`, nil}, 92 } 93 94 for _, tt := range testCases { 95 var paths []string 96 for _, path := range tt.row[1:] { 97 if _, ok := path.(string); ok { 98 paths = append(paths, path.(string)) 99 } else { 100 if path == nil { 101 paths = append(paths, "null") 102 } else if _, ok := path.(int); ok { 103 paths = append(paths, strconv.Itoa(path.(int))) 104 } else { 105 m, _ := json2.Marshal(path) 106 paths = append(paths, string(m)) 107 } 108 } 109 } 110 111 t.Run(tt.f.String()+"."+strings.Join(paths, ","), func(t *testing.T) { 112 require := require.New(t) 113 result, err := tt.f.Eval(sql.NewEmptyContext(), tt.row) 114 if tt.err == nil { 115 require.NoError(err) 116 117 var expect interface{} 118 if tt.expected != nil { 119 expect, _, err = types.JSON.Convert(tt.expected) 120 if err != nil { 121 panic("Bad test string. Can't convert string to JSONDocument: " + tt.expected.(string)) 122 } 123 } 124 125 require.Equal(expect, result) 126 } else { 127 require.Error(tt.err, err) 128 } 129 }) 130 } 131 } 132 133 func buildGetFieldExpressions(t *testing.T, construct func(...sql.Expression) (sql.Expression, error), argCount int) sql.Expression { 134 expressions := make([]sql.Expression, 0, argCount) 135 for i := 0; i < argCount; i++ { 136 expressions = append(expressions, expression.NewGetField(i, types.LongText, "arg"+strconv.Itoa(i), false)) 137 } 138 139 result, err := construct(expressions...) 140 require.NoError(t, err) 141 142 return result 143 }