github.com/keybase/client/go@v0.0.0-20240309051027-028f7c731f8b/kbfs/kbfscodec/codec_msgpack.go (about)

     1  // Copyright 2016 Keybase Inc. All rights reserved.
     2  // Use of this source code is governed by a BSD
     3  // license that can be found in the LICENSE file.
     4  
     5  package kbfscodec
     6  
     7  import (
     8  	"fmt"
     9  	"reflect"
    10  
    11  	"github.com/keybase/go-codec/codec"
    12  	"github.com/pkg/errors"
    13  )
    14  
    15  // ext is a no-op extension that's useful for tagging interfaces with
    16  // a type.  Note that it cannot be used for anything that has nested
    17  // extensions.
    18  type ext struct {
    19  	// codec should NOT encode extension types
    20  	codec Codec
    21  }
    22  
    23  // ConvertExt implements the codec.Ext interface for ext.
    24  func (e ext) ConvertExt(v interface{}) interface{} {
    25  	panic("ConvertExt not supported")
    26  }
    27  
    28  // UpdateExt implements the codec.Ext interface for ext.
    29  func (e ext) UpdateExt(dest interface{}, v interface{}) {
    30  	panic("UpdateExt not supported")
    31  }
    32  
    33  // WriteExt implements the codec.Ext interface for ext.
    34  func (e ext) WriteExt(v interface{}) (buf []byte) {
    35  	buf, err := e.codec.Encode(v)
    36  	if err != nil {
    37  		panic(fmt.Sprintf("Couldn't encode data in %v", v))
    38  	}
    39  	return buf
    40  }
    41  
    42  // ReadExt implements the codec.Ext interface for ext.
    43  func (e ext) ReadExt(v interface{}, buf []byte) {
    44  	err := e.codec.Decode(buf, v)
    45  	if err != nil {
    46  		panic(fmt.Sprintf("Couldn't decode data into %v", v))
    47  	}
    48  }
    49  
    50  // extSlice is an extension that's useful for slices that contain
    51  // extension types as elements.  The contained extension types cannot
    52  // themselves contain nested extension types.
    53  type extSlice struct {
    54  	// codec SHOULD encode extension types
    55  	codec Codec
    56  	typer func(interface{}) reflect.Value
    57  }
    58  
    59  // ConvertExt implements the codec.Ext interface for extSlice.
    60  func (es extSlice) ConvertExt(v interface{}) interface{} {
    61  	panic("ConvertExt not supported")
    62  }
    63  
    64  // UpdateExt implements the codec.Ext interface for extSlice.
    65  func (es extSlice) UpdateExt(dest interface{}, v interface{}) {
    66  	panic("UpdateExt not supported")
    67  }
    68  
    69  // WriteExt implements the codec.Ext interface for extSlice.
    70  func (es extSlice) WriteExt(v interface{}) (buf []byte) {
    71  	val := reflect.ValueOf(v)
    72  	if val.Kind() != reflect.Slice {
    73  		panic(fmt.Sprintf("Non-slice passed to extSlice.WriteExt %v",
    74  			val.Kind()))
    75  	}
    76  
    77  	ifaceArray := make([]interface{}, val.Len())
    78  	for i := 0; i < val.Len(); i++ {
    79  		ifaceArray[i] = val.Index(i).Interface()
    80  	}
    81  
    82  	buf, err := es.codec.Encode(ifaceArray)
    83  	if err != nil {
    84  		panic(fmt.Sprintf("Couldn't encode data in %v", v))
    85  	}
    86  	return buf
    87  }
    88  
    89  // ReadExt implements the codec.Ext interface for extSlice.
    90  func (es extSlice) ReadExt(v interface{}, buf []byte) {
    91  	// ReadExt actually receives a pointer to the list
    92  	val := reflect.ValueOf(v)
    93  	if val.Kind() != reflect.Ptr {
    94  		panic(fmt.Sprintf("Non-pointer passed to extSlice.ReadExt: %v",
    95  			val.Kind()))
    96  	}
    97  
    98  	val = val.Elem()
    99  	if val.Kind() != reflect.Slice {
   100  		panic(fmt.Sprintf("Non-slice passed to extSlice.ReadExt %v",
   101  			val.Kind()))
   102  	}
   103  
   104  	var ifaceArray []interface{}
   105  	err := es.codec.Decode(buf, &ifaceArray)
   106  	if err != nil {
   107  		panic(fmt.Sprintf("Couldn't decode data into %v", v))
   108  	}
   109  
   110  	if len(ifaceArray) > 0 {
   111  		val.Set(reflect.MakeSlice(val.Type(), len(ifaceArray),
   112  			len(ifaceArray)))
   113  	}
   114  
   115  	for i, v := range ifaceArray {
   116  		if es.typer != nil {
   117  			val.Index(i).Set(es.typer(v))
   118  		} else {
   119  			val.Index(i).Set(reflect.ValueOf(v))
   120  		}
   121  	}
   122  }
   123  
   124  // CodecMsgpack implements the Codec interface using msgpack
   125  // marshaling and unmarshaling.
   126  type CodecMsgpack struct {
   127  	h        codec.Handle
   128  	ExtCodec *CodecMsgpack
   129  }
   130  
   131  // newCodecMsgpackHelper constructs a new CodecMsgpack that may or may
   132  // not handle unknown fields.
   133  func newCodecMsgpackHelper(handleUnknownFields bool) *CodecMsgpack {
   134  	handle := codec.MsgpackHandle{}
   135  	handle.Canonical = true
   136  	handle.WriteExt = true
   137  	handle.DecodeUnknownFields = handleUnknownFields
   138  	handle.EncodeUnknownFields = handleUnknownFields
   139  
   140  	// save a codec that doesn't write extensions, so that we can just
   141  	// call Encode/Decode when we want to (de)serialize extension
   142  	// types.
   143  	handleNoExt := handle
   144  	handleNoExt.WriteExt = false
   145  	ExtCodec := &CodecMsgpack{&handleNoExt, nil}
   146  	return &CodecMsgpack{&handle, ExtCodec}
   147  }
   148  
   149  // NewMsgpack constructs a new CodecMsgpack.
   150  func NewMsgpack() *CodecMsgpack {
   151  	return newCodecMsgpackHelper(true)
   152  }
   153  
   154  // NewMsgpackNoUnknownFields constructs a new CodecMsgpack that
   155  // doesn't handle unknown fields.
   156  func NewMsgpackNoUnknownFields() *CodecMsgpack {
   157  	return newCodecMsgpackHelper(false)
   158  }
   159  
   160  // Decode implements the Codec interface for CodecMsgpack
   161  func (c *CodecMsgpack) Decode(buf []byte, obj interface{}) error {
   162  	err := codec.NewDecoderBytes(buf, c.h).Decode(obj)
   163  	if err != nil {
   164  		return errors.Wrap(err, "failed to decode")
   165  	}
   166  	return nil
   167  }
   168  
   169  // Encode implements the Codec interface for CodecMsgpack
   170  func (c *CodecMsgpack) Encode(obj interface{}) (buf []byte, err error) {
   171  	err = codec.NewEncoderBytes(&buf, c.h).Encode(obj)
   172  	if err != nil {
   173  		return nil, errors.Wrap(err, "failed to encode")
   174  	}
   175  	return buf, nil
   176  }
   177  
   178  // RegisterType implements the Codec interface for CodecMsgpack
   179  func (c *CodecMsgpack) RegisterType(rt reflect.Type, code ExtCode) {
   180  	err := c.h.(*codec.MsgpackHandle).SetBytesExt(
   181  		rt, uint64(code), ext{c.ExtCodec})
   182  	if err != nil {
   183  		panic(err)
   184  	}
   185  }
   186  
   187  // RegisterIfaceSliceType implements the Codec interface for CodecMsgpack
   188  func (c *CodecMsgpack) RegisterIfaceSliceType(
   189  	rt reflect.Type, code ExtCode, typer func(interface{}) reflect.Value) {
   190  	err := c.h.(*codec.MsgpackHandle).SetBytesExt(
   191  		rt, uint64(code), extSlice{c, typer})
   192  	if err != nil {
   193  		panic(err)
   194  	}
   195  }