github.com/inklabsfoundation/inkchain@v0.17.1-0.20181025012015-c3cef8062f19/common/tools/protolator/json.go (about)

     1  /*
     2  Copyright IBM Corp. 2017 All Rights Reserved.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8                   http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package protolator
    18  
    19  import (
    20  	"bytes"
    21  	"encoding/json"
    22  	"fmt"
    23  	"io"
    24  	"io/ioutil"
    25  	"reflect"
    26  
    27  	"github.com/golang/protobuf/jsonpb"
    28  	"github.com/golang/protobuf/proto"
    29  )
    30  
    31  type protoFieldFactory interface {
    32  	// Handles should return whether or not this particular protoFieldFactory instance
    33  	// is responsible for the given proto's field
    34  	Handles(msg proto.Message, fieldName string, fieldType reflect.Type, fieldValue reflect.Value) bool
    35  
    36  	// NewProtoField should create a backing protoField implementor
    37  	// Note that the fieldValue may represent nil, so the fieldType is also
    38  	// included (as reflecting the type of a nil value causes a panic)
    39  	NewProtoField(msg proto.Message, fieldName string, fieldType reflect.Type, fieldValue reflect.Value) (protoField, error)
    40  }
    41  
    42  type protoField interface {
    43  	// Name returns the proto name of the field
    44  	Name() string
    45  
    46  	// PopulateFrom mutates the underlying object, by taking the intermediate JSON representation
    47  	// and converting it into the proto representation, then assigning it to the backing value
    48  	// via reflection
    49  	PopulateFrom(source interface{}) error
    50  
    51  	// PopulateTo does not mutate the underlying object, but instead converts it
    52  	// into the intermediate JSON representation (ie a struct -> map[string]interface{}
    53  	// or a slice of structs to []map[string]interface{}
    54  	PopulateTo() (interface{}, error)
    55  }
    56  
    57  var (
    58  	protoMsgType           = reflect.TypeOf((*proto.Message)(nil)).Elem()
    59  	mapStringInterfaceType = reflect.TypeOf(map[string]interface{}{})
    60  	bytesType              = reflect.TypeOf([]byte{})
    61  )
    62  
    63  type baseField struct {
    64  	msg   proto.Message
    65  	name  string
    66  	fType reflect.Type
    67  	vType reflect.Type
    68  	value reflect.Value
    69  }
    70  
    71  func (bf *baseField) Name() string {
    72  	return bf.name
    73  }
    74  
    75  type plainField struct {
    76  	baseField
    77  	populateFrom func(source interface{}, destType reflect.Type) (reflect.Value, error)
    78  	populateTo   func(source reflect.Value) (interface{}, error)
    79  }
    80  
    81  func (pf *plainField) PopulateFrom(source interface{}) error {
    82  	if !reflect.TypeOf(source).AssignableTo(pf.fType) {
    83  		return fmt.Errorf("expected field %s for message %T to be assignable from %v but was not.  Is %T", pf.name, pf.msg, pf.fType, source)
    84  	}
    85  	value, err := pf.populateFrom(source, pf.vType)
    86  	if err != nil {
    87  		return fmt.Errorf("error in PopulateFrom for field %s for message %T: %s", pf.name, pf.msg, err)
    88  	}
    89  	pf.value.Set(value)
    90  	return nil
    91  }
    92  
    93  func (pf *plainField) PopulateTo() (interface{}, error) {
    94  	if !pf.value.Type().AssignableTo(pf.vType) {
    95  		return nil, fmt.Errorf("expected field %s for message %T to be assignable to %v but was not. Got %T.", pf.name, pf.msg, pf.fType, pf.value)
    96  	}
    97  	value, err := pf.populateTo(pf.value)
    98  	if err != nil {
    99  		return nil, fmt.Errorf("error in PopulateTo for field %s for message %T: %s", pf.name, pf.msg, err)
   100  	}
   101  	return value, nil
   102  }
   103  
   104  type mapField struct {
   105  	baseField
   106  	populateFrom func(key string, value interface{}, destType reflect.Type) (reflect.Value, error)
   107  	populateTo   func(key string, value reflect.Value) (interface{}, error)
   108  }
   109  
   110  func (mf *mapField) PopulateFrom(source interface{}) error {
   111  	tree, ok := source.(map[string]interface{})
   112  	if !ok {
   113  		return fmt.Errorf("expected map field %s for message %T to be assignable from map[string]interface{} but was not. Got %T", mf.name, mf.msg, source)
   114  	}
   115  
   116  	result := reflect.MakeMap(mf.vType)
   117  
   118  	for k, v := range tree {
   119  		if !reflect.TypeOf(v).AssignableTo(mf.fType) {
   120  			return fmt.Errorf("expected map field %s value for %s for message %T to be assignable from %v but was not.  Is %T", mf.name, k, mf.msg, mf.fType, v)
   121  		}
   122  		newValue, err := mf.populateFrom(k, v, mf.vType.Elem())
   123  		if err != nil {
   124  			return fmt.Errorf("error in PopulateFrom for map field %s with key %s for message %T: %s", mf.name, k, mf.msg, err)
   125  		}
   126  		result.SetMapIndex(reflect.ValueOf(k), newValue)
   127  	}
   128  
   129  	mf.value.Set(result)
   130  	return nil
   131  }
   132  
   133  func (mf *mapField) PopulateTo() (interface{}, error) {
   134  	result := make(map[string]interface{})
   135  	keys := mf.value.MapKeys()
   136  	for _, key := range keys {
   137  		k, ok := key.Interface().(string)
   138  		if !ok {
   139  			return nil, fmt.Errorf("expected map field %s for message %T to have string keys, but did not.", mf.name, mf.msg)
   140  		}
   141  
   142  		subValue := mf.value.MapIndex(key)
   143  
   144  		if !subValue.Type().AssignableTo(mf.vType.Elem()) {
   145  			return nil, fmt.Errorf("expected map field %s with key %s for message %T to be assignable to %v but was not. Got %v.", mf.name, k, mf.msg, mf.vType.Elem(), subValue.Type())
   146  		}
   147  
   148  		value, err := mf.populateTo(k, subValue)
   149  		if err != nil {
   150  			return nil, fmt.Errorf("error in PopulateTo for map field %s and key %s for message %T: %s", mf.name, k, mf.msg, err)
   151  		}
   152  		result[k] = value
   153  	}
   154  
   155  	return result, nil
   156  }
   157  
   158  type sliceField struct {
   159  	baseField
   160  	populateTo   func(i int, source reflect.Value) (interface{}, error)
   161  	populateFrom func(i int, source interface{}, destType reflect.Type) (reflect.Value, error)
   162  }
   163  
   164  func (sf *sliceField) PopulateFrom(source interface{}) error {
   165  	slice, ok := source.([]interface{})
   166  	if !ok {
   167  		return fmt.Errorf("expected slice field %s for message %T to be assignable from []interface{} but was not. Got %T", sf.name, sf.msg, source)
   168  	}
   169  
   170  	result := reflect.MakeSlice(sf.vType, len(slice), len(slice))
   171  
   172  	for i, v := range slice {
   173  		if !reflect.TypeOf(v).AssignableTo(sf.fType) {
   174  			return fmt.Errorf("expected slice field %s value at index %d for message %T to be assignable from %v but was not.  Is %T", sf.name, i, sf.msg, sf.fType, v)
   175  		}
   176  		subValue, err := sf.populateFrom(i, v, sf.vType.Elem())
   177  		if err != nil {
   178  			return fmt.Errorf("error in PopulateFrom for slice field %s at index %d for message %T: %s", sf.name, i, sf.msg, err)
   179  		}
   180  		result.Index(i).Set(subValue)
   181  	}
   182  
   183  	sf.value.Set(result)
   184  	return nil
   185  }
   186  
   187  func (sf *sliceField) PopulateTo() (interface{}, error) {
   188  	result := make([]interface{}, sf.value.Len())
   189  	for i := range result {
   190  		subValue := sf.value.Index(i)
   191  		if !subValue.Type().AssignableTo(sf.vType.Elem()) {
   192  			return nil, fmt.Errorf("expected slice field %s at index %d for message %T to be assignable to %v but was not. Got %v.", sf.name, i, sf.msg, sf.vType.Elem(), subValue.Type())
   193  		}
   194  
   195  		value, err := sf.populateTo(i, subValue)
   196  		if err != nil {
   197  			return nil, fmt.Errorf("error in PopulateTo for slice field %s at index %d for message %T: %s", sf.name, i, sf.msg, err)
   198  		}
   199  		result[i] = value
   200  	}
   201  
   202  	return result, nil
   203  }
   204  
   205  func stringInSlice(target string, slice []string) bool {
   206  	for _, name := range slice {
   207  		if name == target {
   208  			return true
   209  		}
   210  	}
   211  	return false
   212  }
   213  
   214  // protoToJSON is a simple shortcut wrapper around the proto JSON marshaler
   215  func protoToJSON(msg proto.Message) ([]byte, error) {
   216  	var b bytes.Buffer
   217  	m := jsonpb.Marshaler{
   218  		EnumsAsInts:  false,
   219  		EmitDefaults: false,
   220  		Indent:       "    ",
   221  		OrigName:     true,
   222  	}
   223  	err := m.Marshal(&b, msg)
   224  	if err != nil {
   225  		return nil, err
   226  	}
   227  	return b.Bytes(), nil
   228  }
   229  
   230  func mapToProto(tree map[string]interface{}, msg proto.Message) error {
   231  	jsonOut, err := json.Marshal(tree)
   232  	if err != nil {
   233  		return err
   234  	}
   235  
   236  	return jsonpb.UnmarshalString(string(jsonOut), msg)
   237  }
   238  
   239  // jsonToMap allocates a map[string]interface{}, unmarshals a JSON document into it
   240  // and returns it, or error
   241  func jsonToMap(marshaled []byte) (map[string]interface{}, error) {
   242  	tree := make(map[string]interface{})
   243  	d := json.NewDecoder(bytes.NewReader(marshaled))
   244  	d.UseNumber()
   245  	err := d.Decode(&tree)
   246  	if err != nil {
   247  		return nil, fmt.Errorf("error unmarshaling intermediate JSON: %s", err)
   248  	}
   249  	return tree, nil
   250  }
   251  
// fieldFactories lists the factory implementations from most greedy to least
// greedy. protoFields offers each field to the factories in this order and
// uses the first factory whose Handles reports true. Factories listed lower
// may therefore depend on the ones listed higher having been evaluated first.
// protoFields then returns the collected fields grouped by factory, walking
// this list in reverse.
var fieldFactories = []protoFieldFactory{
	dynamicSliceFieldFactory{},
	dynamicMapFieldFactory{},
	dynamicFieldFactory{},
	variablyOpaqueSliceFieldFactory{},
	variablyOpaqueMapFieldFactory{},
	variablyOpaqueFieldFactory{},
	staticallyOpaqueSliceFieldFactory{},
	staticallyOpaqueMapFieldFactory{},
	staticallyOpaqueFieldFactory{},
	nestedSliceFieldFactory{},
	nestedMapFieldFactory{},
	nestedFieldFactory{},
}
   269  
   270  func protoFields(msg proto.Message, uMsg proto.Message) ([]protoField, error) {
   271  	var result []protoField
   272  
   273  	pmVal := reflect.ValueOf(uMsg)
   274  	if pmVal.Kind() != reflect.Ptr {
   275  		return nil, fmt.Errorf("expected proto.Message %T to be pointer kind", msg)
   276  	}
   277  
   278  	if pmVal.IsNil() {
   279  		return nil, nil
   280  	}
   281  
   282  	mVal := pmVal.Elem()
   283  	if mVal.Kind() != reflect.Struct {
   284  		return nil, fmt.Errorf("expected proto.Message %T ptr value to be struct, was %v", uMsg, mVal.Kind())
   285  	}
   286  
   287  	iResult := make([][]protoField, len(fieldFactories))
   288  
   289  	protoProps := proto.GetProperties(mVal.Type())
   290  	// TODO, this will skip oneof fields, this should be handled
   291  	// correctly at some point
   292  	for _, prop := range protoProps.Prop {
   293  		fieldName := prop.OrigName
   294  		fieldValue := mVal.FieldByName(prop.Name)
   295  		fieldTypeStruct, ok := mVal.Type().FieldByName(prop.Name)
   296  		if !ok {
   297  			return nil, fmt.Errorf("programming error: proto does not have field advertised by proto package")
   298  		}
   299  		fieldType := fieldTypeStruct.Type
   300  
   301  		for i, factory := range fieldFactories {
   302  			if !factory.Handles(msg, fieldName, fieldType, fieldValue) {
   303  				continue
   304  			}
   305  
   306  			field, err := factory.NewProtoField(msg, fieldName, fieldType, fieldValue)
   307  			if err != nil {
   308  				return nil, err
   309  			}
   310  			iResult[i] = append(iResult[i], field)
   311  			break
   312  		}
   313  	}
   314  
   315  	// Loop over the collected fields in reverse order to collect them in
   316  	// correct dependency order as specified in fieldFactories
   317  	for i := len(iResult) - 1; i >= 0; i-- {
   318  		result = append(result, iResult[i]...)
   319  	}
   320  
   321  	return result, nil
   322  }
   323  
   324  func recursivelyCreateTreeFromMessage(msg proto.Message) (tree map[string]interface{}, err error) {
   325  	defer func() {
   326  		// Because this function is recursive, it's difficult to determine which level
   327  		// of the proto the error originated from, this wrapper leaves breadcrumbs for debugging
   328  		if err != nil {
   329  			err = fmt.Errorf("%T: %s", msg, err)
   330  		}
   331  	}()
   332  
   333  	uMsg := msg
   334  	decorated, ok := msg.(DecoratedProto)
   335  	if ok {
   336  		uMsg = decorated.Underlying()
   337  	}
   338  
   339  	fields, err := protoFields(msg, uMsg)
   340  	if err != nil {
   341  		return nil, err
   342  	}
   343  
   344  	jsonBytes, err := protoToJSON(uMsg)
   345  	if err != nil {
   346  		return nil, err
   347  	}
   348  
   349  	tree, err = jsonToMap(jsonBytes)
   350  	if err != nil {
   351  		return nil, err
   352  	}
   353  
   354  	for _, field := range fields {
   355  		if _, ok := tree[field.Name()]; !ok {
   356  			continue
   357  		}
   358  		delete(tree, field.Name())
   359  		tree[field.Name()], err = field.PopulateTo()
   360  		if err != nil {
   361  			return nil, err
   362  		}
   363  	}
   364  
   365  	return tree, nil
   366  }
   367  
   368  // DeepMarshalJSON marshals msg to w as JSON, but instead of marshaling bytes fields which contain nested
   369  // marshaled messages as base64 (like the standard proto encoding), these nested messages are remarshaled
   370  // as the JSON representation of those messages.  This is done so that the JSON representation is as non-binary
   371  // and human readable as possible.
   372  func DeepMarshalJSON(w io.Writer, msg proto.Message) error {
   373  	root, err := recursivelyCreateTreeFromMessage(msg)
   374  	if err != nil {
   375  		return err
   376  	}
   377  
   378  	encoder := json.NewEncoder(w)
   379  	encoder.SetIndent("", "\t")
   380  	return encoder.Encode(root)
   381  }
   382  
   383  func recursivelyPopulateMessageFromTree(tree map[string]interface{}, msg proto.Message) (err error) {
   384  	defer func() {
   385  		// Because this function is recursive, it's difficult to determine which level
   386  		// of the proto the error orginated from, this wrapper leaves breadcrumbs for debugging
   387  		if err != nil {
   388  			err = fmt.Errorf("%T: %s", msg, err)
   389  		}
   390  	}()
   391  
   392  	uMsg := msg
   393  	decorated, ok := msg.(DecoratedProto)
   394  	if ok {
   395  		uMsg = decorated.Underlying()
   396  	}
   397  
   398  	fields, err := protoFields(msg, uMsg)
   399  	if err != nil {
   400  		return err
   401  	}
   402  
   403  	specialFieldsMap := make(map[string]interface{})
   404  
   405  	for _, field := range fields {
   406  		specialField, ok := tree[field.Name()]
   407  		if !ok {
   408  			continue
   409  		}
   410  		specialFieldsMap[field.Name()] = specialField
   411  		delete(tree, field.Name())
   412  	}
   413  
   414  	if err = mapToProto(tree, uMsg); err != nil {
   415  		return err
   416  	}
   417  
   418  	for _, field := range fields {
   419  		specialField, ok := specialFieldsMap[field.Name()]
   420  		if !ok {
   421  			continue
   422  		}
   423  		if err := field.PopulateFrom(specialField); err != nil {
   424  			return err
   425  		}
   426  	}
   427  
   428  	return nil
   429  }
   430  
   431  // DeepUnmarshalJSON takes JSON output as generated by DeepMarshalJSON and decodes it into msg
   432  // This includes re-marshaling the expanded nested elements to binary form
   433  func DeepUnmarshalJSON(r io.Reader, msg proto.Message) error {
   434  	b, err := ioutil.ReadAll(r)
   435  	if err != nil {
   436  		return err
   437  	}
   438  
   439  	root, err := jsonToMap(b)
   440  	if err != nil {
   441  		return err
   442  	}
   443  
   444  	return recursivelyPopulateMessageFromTree(root, msg)
   445  }