github.com/lastbackend/toolkit@v0.0.0-20241020043710-cafa37b95aad/pkg/server/http/marshaler/jsonpb/jsonpb.go (about)

     1  /*
     2  Copyright [2014] - [2023] The Last.Backend authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package jsonpb
    18  
    19  import (
    20  	"bytes"
    21  	"encoding/json"
    22  	"errors"
    23  	"fmt"
    24  	"github.com/lastbackend/toolkit/pkg/server/http/marshaler"
    25  	"github.com/lastbackend/toolkit/pkg/server/http/marshaler/util"
    26  	"google.golang.org/protobuf/encoding/protojson"
    27  	"google.golang.org/protobuf/proto"
    28  	"io"
    29  	"reflect"
    30  	"regexp"
    31  )
    32  
    33  var (
    34  	protoMessageType = reflect.TypeOf((*proto.Message)(nil)).Elem()
    35  	typeProtoMessage = reflect.TypeOf((*proto.Message)(nil)).Elem()
    36  	convFromType     = map[reflect.Kind]reflect.Value{
    37  		reflect.String:  reflect.ValueOf(util.String),
    38  		reflect.Bool:    reflect.ValueOf(util.Bool),
    39  		reflect.Float64: reflect.ValueOf(util.Float64),
    40  		reflect.Float32: reflect.ValueOf(util.Float32),
    41  		reflect.Int64:   reflect.ValueOf(util.Int64),
    42  		reflect.Int32:   reflect.ValueOf(util.Int32),
    43  		reflect.Uint64:  reflect.ValueOf(util.Uint64),
    44  		reflect.Uint32:  reflect.ValueOf(util.Uint32),
    45  		reflect.Slice:   reflect.ValueOf(util.Bytes),
    46  	}
    47  )
    48  
    49  type protoEnum interface {
    50  	fmt.Stringer
    51  	EnumDescriptor() ([]byte, []int)
    52  }
    53  
    54  type JSONPb struct {
    55  	protojson.MarshalOptions
    56  	protojson.UnmarshalOptions
    57  }
    58  
    59  func (*JSONPb) ContentType() string {
    60  	return "application/json"
    61  }
    62  
    63  func (j *JSONPb) Marshal(v interface{}) ([]byte, error) {
    64  	if _, ok := v.(proto.Message); !ok {
    65  		return j.marshalNonProtoField(v)
    66  	}
    67  
    68  	var buf bytes.Buffer
    69  
    70  	if err := j.marshalTo(&buf, v); err != nil {
    71  		return nil, err
    72  	}
    73  
    74  	return buf.Bytes(), nil
    75  }
    76  
    77  func (j *JSONPb) Unmarshal(data []byte, v interface{}) error {
    78  	return unmarshalJSONPb(data, j.UnmarshalOptions, v)
    79  }
    80  
    81  func (j *JSONPb) NewDecoder(r io.Reader) marshaler.Decoder {
    82  	d := json.NewDecoder(r)
    83  	return DecoderWrapper{
    84  		Decoder:          d,
    85  		UnmarshalOptions: j.UnmarshalOptions,
    86  	}
    87  }
    88  
    89  func (j *JSONPb) Delimiter() []byte {
    90  	return []byte("\n")
    91  }
    92  
    93  func (j *JSONPb) NewEncoder(w io.Writer) marshaler.Encoder {
    94  	return marshaler.EncoderFunc(func(v interface{}) error {
    95  		if err := j.marshalTo(w, v); err != nil {
    96  			return err
    97  		}
    98  		_, err := w.Write(j.Delimiter())
    99  		return err
   100  	})
   101  }
   102  
   103  func (j *JSONPb) marshalTo(w io.Writer, v interface{}) error {
   104  	p, ok := v.(proto.Message)
   105  	if !ok {
   106  		buf, err := j.marshalNonProtoField(v)
   107  		if err != nil {
   108  			return err
   109  		}
   110  		_, err = w.Write(buf)
   111  		return err
   112  	}
   113  
   114  	b, err := j.MarshalOptions.Marshal(p)
   115  	if err != nil {
   116  		return err
   117  	}
   118  
   119  	_, err = w.Write(b)
   120  
   121  	return err
   122  }
   123  
   124  func (j *JSONPb) marshalNonProtoField(v interface{}) ([]byte, error) {
   125  	if v == nil {
   126  		return []byte("null"), nil
   127  	}
   128  
   129  	rv := reflect.ValueOf(v)
   130  
   131  	for rv.Kind() == reflect.Ptr {
   132  		if rv.IsNil() {
   133  			return []byte("null"), nil
   134  		}
   135  		rv = rv.Elem()
   136  	}
   137  
   138  	if rv.Kind() == reflect.Slice {
   139  		if rv.IsNil() {
   140  			if j.EmitUnpopulated {
   141  				return []byte("[]"), nil
   142  			}
   143  			return []byte("null"), nil
   144  		}
   145  
   146  		if rv.Type().Elem().Implements(protoMessageType) {
   147  			var buf bytes.Buffer
   148  			err := buf.WriteByte('[')
   149  			if err != nil {
   150  				return nil, err
   151  			}
   152  			for i := 0; i < rv.Len(); i++ {
   153  				if i != 0 {
   154  					err = buf.WriteByte(',')
   155  					if err != nil {
   156  						return nil, err
   157  					}
   158  				}
   159  				if err = j.marshalTo(&buf, rv.Index(i).Interface().(proto.Message)); err != nil {
   160  					return nil, err
   161  				}
   162  			}
   163  			err = buf.WriteByte(']')
   164  			if err != nil {
   165  				return nil, err
   166  			}
   167  
   168  			return buf.Bytes(), nil
   169  		}
   170  	}
   171  
   172  	if rv.Kind() == reflect.Map {
   173  		m := make(map[string]*json.RawMessage)
   174  		for _, k := range rv.MapKeys() {
   175  			buf, err := j.Marshal(rv.MapIndex(k).Interface())
   176  			if err != nil {
   177  				return nil, err
   178  			}
   179  			m[fmt.Sprintf("%v", k.Interface())] = (*json.RawMessage)(&buf)
   180  		}
   181  		if j.Indent != "" {
   182  			return json.MarshalIndent(m, "", j.Indent)
   183  		}
   184  		return json.Marshal(m)
   185  	}
   186  
   187  	if enum, ok := rv.Interface().(protoEnum); ok && !j.UseEnumNumbers {
   188  		return json.Marshal(enum.String())
   189  	}
   190  
   191  	return json.Marshal(rv.Interface())
   192  }
   193  
   194  type DecoderWrapper struct {
   195  	*json.Decoder
   196  	protojson.UnmarshalOptions
   197  }
   198  
   199  func (d DecoderWrapper) Decode(v interface{}) error {
   200  	return decodeJSONPb(d.Decoder, d.UnmarshalOptions, v)
   201  }
   202  
   203  func unmarshalJSONPb(data []byte, unmarshaler protojson.UnmarshalOptions, v interface{}) error {
   204  	d := json.NewDecoder(bytes.NewReader(data))
   205  	return decodeJSONPb(d, unmarshaler, v)
   206  }
   207  
   208  func decodeJSONPb(d *json.Decoder, unmarshaler protojson.UnmarshalOptions, v interface{}) error {
   209  	p, ok := v.(proto.Message)
   210  	if !ok {
   211  		return decodeNonProtoField(d, unmarshaler, v)
   212  	}
   213  	var b json.RawMessage
   214  	err := d.Decode(&b)
   215  	if err != nil {
   216  		return err
   217  	}
   218  	return handleUnmarshalError(unmarshaler.Unmarshal(b, p))
   219  }
   220  
   221  func decodeNonProtoField(d *json.Decoder, unmarshaler protojson.UnmarshalOptions, v interface{}) error {
   222  	rv := reflect.ValueOf(v)
   223  
   224  	if rv.Kind() != reflect.Ptr {
   225  		return fmt.Errorf("%T is not a pointer", v)
   226  	}
   227  
   228  	for rv.Kind() == reflect.Ptr {
   229  		if rv.IsNil() {
   230  			rv.Set(reflect.New(rv.Type().Elem()))
   231  		}
   232  		if rv.Type().ConvertibleTo(typeProtoMessage) {
   233  			var b json.RawMessage
   234  			err := d.Decode(&b)
   235  			if err != nil {
   236  				return err
   237  			}
   238  
   239  			return unmarshaler.Unmarshal(b, rv.Interface().(proto.Message))
   240  		}
   241  		rv = rv.Elem()
   242  	}
   243  
   244  	if rv.Kind() == reflect.Map {
   245  		if rv.IsNil() {
   246  			rv.Set(reflect.MakeMap(rv.Type()))
   247  		}
   248  		conv, ok := convFromType[rv.Type().Key().Kind()]
   249  		if !ok {
   250  			return fmt.Errorf("unsupported type of map field key: %v", rv.Type().Key())
   251  		}
   252  
   253  		m := make(map[string]*json.RawMessage)
   254  		if err := d.Decode(&m); err != nil {
   255  			return err
   256  		}
   257  		for k, v := range m {
   258  			result := conv.Call([]reflect.Value{reflect.ValueOf(k)})
   259  			if err := result[1].Interface(); err != nil {
   260  				return err.(error)
   261  			}
   262  			bk := result[0]
   263  			bv := reflect.New(rv.Type().Elem())
   264  			if v == nil {
   265  				null := json.RawMessage("null")
   266  				v = &null
   267  			}
   268  			if err := unmarshalJSONPb(*v, unmarshaler, bv.Interface()); err != nil {
   269  				return err
   270  			}
   271  			rv.SetMapIndex(bk, bv.Elem())
   272  		}
   273  		return nil
   274  	}
   275  
   276  	if rv.Kind() == reflect.Slice {
   277  		var sl []json.RawMessage
   278  		if err := d.Decode(&sl); err != nil {
   279  			return err
   280  		}
   281  		if sl != nil {
   282  			rv.Set(reflect.MakeSlice(rv.Type(), 0, 0))
   283  		}
   284  		for _, item := range sl {
   285  			bv := reflect.New(rv.Type().Elem())
   286  			if err := unmarshalJSONPb(item, unmarshaler, bv.Interface()); err != nil {
   287  				return err
   288  			}
   289  			rv.Set(reflect.Append(rv, bv.Elem()))
   290  		}
   291  		return nil
   292  	}
   293  
   294  	if _, ok := rv.Interface().(protoEnum); ok {
   295  		var data interface{}
   296  		if err := d.Decode(&data); err != nil {
   297  			return err
   298  		}
   299  		switch v := data.(type) {
   300  		case string:
   301  			return fmt.Errorf("unmarshaling of symbolic enum %q not supported: %T", data, rv.Interface())
   302  		case float64:
   303  			rv.Set(reflect.ValueOf(int32(v)).Convert(rv.Type()))
   304  			return nil
   305  		default:
   306  			return fmt.Errorf("cannot assign %#v into Go type %T", data, rv.Interface())
   307  		}
   308  	}
   309  
   310  	return d.Decode(v)
   311  }
   312  
   313  func handleUnmarshalError(err error) error {
   314  	if err == nil {
   315  		return nil
   316  	}
   317  
   318  	message := err.Error()
   319  	re := regexp.MustCompile(`^proto:.*\(.*line.*\):\s*(.*)$`)
   320  	match := re.FindStringSubmatch(message)
   321  
   322  	if len(match) > 1 {
   323  		message = match[1]
   324  	}
   325  
   326  	return errors.New(message)
   327  }