github.com/lastbackend/toolkit@v0.0.0-20241020043710-cafa37b95aad/pkg/server/http/marshaler/jsonpb/jsonpb.go (about) 1 /* 2 Copyright [2014] - [2023] The Last.Backend authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package jsonpb 18 19 import ( 20 "bytes" 21 "encoding/json" 22 "errors" 23 "fmt" 24 "github.com/lastbackend/toolkit/pkg/server/http/marshaler" 25 "github.com/lastbackend/toolkit/pkg/server/http/marshaler/util" 26 "google.golang.org/protobuf/encoding/protojson" 27 "google.golang.org/protobuf/proto" 28 "io" 29 "reflect" 30 "regexp" 31 ) 32 33 var ( 34 protoMessageType = reflect.TypeOf((*proto.Message)(nil)).Elem() 35 typeProtoMessage = reflect.TypeOf((*proto.Message)(nil)).Elem() 36 convFromType = map[reflect.Kind]reflect.Value{ 37 reflect.String: reflect.ValueOf(util.String), 38 reflect.Bool: reflect.ValueOf(util.Bool), 39 reflect.Float64: reflect.ValueOf(util.Float64), 40 reflect.Float32: reflect.ValueOf(util.Float32), 41 reflect.Int64: reflect.ValueOf(util.Int64), 42 reflect.Int32: reflect.ValueOf(util.Int32), 43 reflect.Uint64: reflect.ValueOf(util.Uint64), 44 reflect.Uint32: reflect.ValueOf(util.Uint32), 45 reflect.Slice: reflect.ValueOf(util.Bytes), 46 } 47 ) 48 49 type protoEnum interface { 50 fmt.Stringer 51 EnumDescriptor() ([]byte, []int) 52 } 53 54 type JSONPb struct { 55 protojson.MarshalOptions 56 protojson.UnmarshalOptions 57 } 58 59 func (*JSONPb) ContentType() string { 60 return "application/json" 61 } 62 63 func (j *JSONPb) Marshal(v interface{}) ([]byte, error) { 64 if _, ok := v.(proto.Message); !ok { 65 return j.marshalNonProtoField(v) 66 } 67 68 var buf bytes.Buffer 69 70 if err := j.marshalTo(&buf, v); err != nil { 71 return nil, err 72 } 73 74 return buf.Bytes(), nil 75 } 76 77 func (j *JSONPb) Unmarshal(data []byte, v interface{}) error { 78 return unmarshalJSONPb(data, j.UnmarshalOptions, v) 79 } 80 81 func (j *JSONPb) NewDecoder(r io.Reader) marshaler.Decoder { 82 d := json.NewDecoder(r) 83 return DecoderWrapper{ 84 Decoder: d, 85 UnmarshalOptions: j.UnmarshalOptions, 86 } 87 } 88 89 func (j *JSONPb) Delimiter() []byte { 90 return []byte("\n") 91 } 92 93 func (j *JSONPb) NewEncoder(w io.Writer) marshaler.Encoder { 94 return marshaler.EncoderFunc(func(v interface{}) error { 95 if err := j.marshalTo(w, v); err != nil { 96 return err 97 } 98 _, err := w.Write(j.Delimiter()) 99 return err 100 }) 101 } 102 103 func (j *JSONPb) marshalTo(w io.Writer, v interface{}) error { 104 p, ok := v.(proto.Message) 105 if !ok { 106 buf, err := j.marshalNonProtoField(v) 107 if err != nil { 108 return err 109 } 110 _, err = w.Write(buf) 111 return err 112 } 113 114 b, err := j.MarshalOptions.Marshal(p) 115 if err != nil { 116 return err 117 } 118 119 _, err = w.Write(b) 120 121 return err 122 } 123 124 func (j *JSONPb) marshalNonProtoField(v interface{}) ([]byte, error) { 125 if v == nil { 126 return []byte("null"), nil 127 } 128 129 rv := reflect.ValueOf(v) 130 131 for rv.Kind() == reflect.Ptr { 132 if rv.IsNil() { 133 return []byte("null"), nil 134 } 135 rv = rv.Elem() 136 } 137 138 if rv.Kind() == reflect.Slice { 139 if rv.IsNil() { 140 if j.EmitUnpopulated { 141 return []byte("[]"), nil 142 } 143 return []byte("null"), nil 144 } 145 146 if rv.Type().Elem().Implements(protoMessageType) { 147 var buf bytes.Buffer 148 err := buf.WriteByte('[') 149 if err != nil { 150 return nil, err 151 } 152 for i := 0; i < rv.Len(); i++ { 153 if i != 0 { 154 err = buf.WriteByte(',') 155 if err != nil { 156 return nil, err 157 } 158 } 159 if err = j.marshalTo(&buf, rv.Index(i).Interface().(proto.Message)); err != nil { 160 return nil, err 161 } 162 } 163 err = buf.WriteByte(']') 164 if err != nil { 165 return nil, err 166 } 167 168 return buf.Bytes(), nil 169 } 170 } 171 172 if rv.Kind() == reflect.Map { 173 m := make(map[string]*json.RawMessage) 174 for _, k := range rv.MapKeys() { 175 buf, err := j.Marshal(rv.MapIndex(k).Interface()) 176 if err != nil { 177 return nil, err 178 } 179 m[fmt.Sprintf("%v", k.Interface())] = (*json.RawMessage)(&buf) 180 } 181 if j.Indent != "" { 182 return json.MarshalIndent(m, "", j.Indent) 183 } 184 return json.Marshal(m) 185 } 186 187 if enum, ok := rv.Interface().(protoEnum); ok && !j.UseEnumNumbers { 188 return json.Marshal(enum.String()) 189 } 190 191 return json.Marshal(rv.Interface()) 192 } 193 194 type DecoderWrapper struct { 195 *json.Decoder 196 protojson.UnmarshalOptions 197 } 198 199 func (d DecoderWrapper) Decode(v interface{}) error { 200 return decodeJSONPb(d.Decoder, d.UnmarshalOptions, v) 201 } 202 203 func unmarshalJSONPb(data []byte, unmarshaler protojson.UnmarshalOptions, v interface{}) error { 204 d := json.NewDecoder(bytes.NewReader(data)) 205 return decodeJSONPb(d, unmarshaler, v) 206 } 207 208 func decodeJSONPb(d *json.Decoder, unmarshaler protojson.UnmarshalOptions, v interface{}) error { 209 p, ok := v.(proto.Message) 210 if !ok { 211 return decodeNonProtoField(d, unmarshaler, v) 212 } 213 var b json.RawMessage 214 err := d.Decode(&b) 215 if err != nil { 216 return err 217 } 218 return handleUnmarshalError(unmarshaler.Unmarshal(b, p)) 219 } 220 221 func decodeNonProtoField(d *json.Decoder, unmarshaler protojson.UnmarshalOptions, v interface{}) error { 222 rv := reflect.ValueOf(v) 223 224 if rv.Kind() != reflect.Ptr { 225 return fmt.Errorf("%T is not a pointer", v) 226 } 227 228 for rv.Kind() == reflect.Ptr { 229 if rv.IsNil() { 230 rv.Set(reflect.New(rv.Type().Elem())) 231 } 232 if rv.Type().ConvertibleTo(typeProtoMessage) { 233 var b json.RawMessage 234 err := d.Decode(&b) 235 if err != nil { 236 return err 237 } 238 239 return unmarshaler.Unmarshal(b, rv.Interface().(proto.Message)) 240 } 241 rv = rv.Elem() 242 } 243 244 if rv.Kind() == reflect.Map { 245 if rv.IsNil() { 246 rv.Set(reflect.MakeMap(rv.Type())) 247 } 248 conv, ok := convFromType[rv.Type().Key().Kind()] 249 if !ok { 250 return fmt.Errorf("unsupported type of map field key: %v", rv.Type().Key()) 251 } 252 253 m := make(map[string]*json.RawMessage) 254 if err := d.Decode(&m); err != nil { 255 return err 256 } 257 for k, v := range m { 258 result := conv.Call([]reflect.Value{reflect.ValueOf(k)}) 259 if err := result[1].Interface(); err != nil { 260 return err.(error) 261 } 262 bk := result[0] 263 bv := reflect.New(rv.Type().Elem()) 264 if v == nil { 265 null := json.RawMessage("null") 266 v = &null 267 } 268 if err := unmarshalJSONPb(*v, unmarshaler, bv.Interface()); err != nil { 269 return err 270 } 271 rv.SetMapIndex(bk, bv.Elem()) 272 } 273 return nil 274 } 275 276 if rv.Kind() == reflect.Slice { 277 var sl []json.RawMessage 278 if err := d.Decode(&sl); err != nil { 279 return err 280 } 281 if sl != nil { 282 rv.Set(reflect.MakeSlice(rv.Type(), 0, 0)) 283 } 284 for _, item := range sl { 285 bv := reflect.New(rv.Type().Elem()) 286 if err := unmarshalJSONPb(item, unmarshaler, bv.Interface()); err != nil { 287 return err 288 } 289 rv.Set(reflect.Append(rv, bv.Elem())) 290 } 291 return nil 292 } 293 294 if _, ok := rv.Interface().(protoEnum); ok { 295 var data interface{} 296 if err := d.Decode(&data); err != nil { 297 return err 298 } 299 switch v := data.(type) { 300 case string: 301 return fmt.Errorf("unmarshaling of symbolic enum %q not supported: %T", data, rv.Interface()) 302 case float64: 303 rv.Set(reflect.ValueOf(int32(v)).Convert(rv.Type())) 304 return nil 305 default: 306 return fmt.Errorf("cannot assign %#v into Go type %T", data, rv.Interface()) 307 } 308 } 309 310 return d.Decode(v) 311 } 312 313 func handleUnmarshalError(err error) error { 314 if err == nil { 315 return nil 316 } 317 318 message := err.Error() 319 re := regexp.MustCompile(`^proto:.*\(.*line.*\):\s*(.*)$`) 320 match := re.FindStringSubmatch(message) 321 322 if len(match) > 1 { 323 message = match[1] 324 } 325 326 return errors.New(message) 327 }