github.com/machinefi/w3bstream@v1.6.5-rc9.0.20240426031326-b8c7c4876e72/pkg/depends/kit/httptransport/transformer/tsfm_protobuf.go (about)

     1  package transformer
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"net/textproto"
     7  	"reflect"
     8  
     9  	"github.com/golang/protobuf/proto"
    10  	"github.com/pkg/errors"
    11  
    12  	"github.com/machinefi/w3bstream/pkg/depends/kit/httptransport/httpx"
    13  	"github.com/machinefi/w3bstream/pkg/depends/kit/validator"
    14  	"github.com/machinefi/w3bstream/pkg/depends/x/typesx"
    15  )
    16  
    // ErrEncodeDataNotProtobuf is returned by Protobuf.EncodeTo when the value
    // handed to it does not implement proto.Message.
    var ErrEncodeDataNotProtobuf = errors.New("encode data must be `proto.Message`")

    // ErrDecodeTargetNotProtobuf is returned by Protobuf.DecodeFrom when the
    // target it should decode into does not implement proto.Message.
    var ErrDecodeTargetNotProtobuf = errors.New("decode target must be `proto.Message`")
    21  
    22  func init() { DefaultFactory.Register(&Protobuf{}) }
    23  
    24  type Protobuf struct{}
    25  
    26  func (Protobuf) Names() []string { return []string{httpx.MIME_PROTOBUF, "protobuf", "x-protobuf"} }
    27  
    28  func (Protobuf) NamedByTag() string { return "protobuf" }
    29  
    30  func (t *Protobuf) String() string { return httpx.MIME_PROTOBUF }
    31  
    32  func (Protobuf) New(context.Context, typesx.Type) (Transformer, error) { return &Protobuf{}, nil }
    33  
    34  func (t *Protobuf) EncodeTo(ctx context.Context, w io.Writer, v interface{}) error {
    35  	if rv, ok := v.(reflect.Value); ok {
    36  		v = rv.Interface()
    37  	}
    38  	httpx.MaybeWriteHeader(ctx, w, t.String(), map[string]string{})
    39  	pv, ok := v.(proto.Message)
    40  	if !ok {
    41  		return ErrEncodeDataNotProtobuf
    42  	}
    43  	data, err := proto.Marshal(pv)
    44  	if err != nil {
    45  		return err
    46  	}
    47  	_, err = w.Write(data)
    48  	return err
    49  }
    50  
    51  func (t *Protobuf) DecodeFrom(ctx context.Context, r io.Reader, v interface{}, _ ...textproto.MIMEHeader) error {
    52  	if rv, ok := v.(reflect.Value); ok {
    53  		if rv.Kind() != reflect.Ptr && rv.CanAddr() {
    54  			rv = rv.Addr()
    55  		}
    56  		v = rv.Interface()
    57  	}
    58  
    59  	pv, ok := v.(proto.Message)
    60  	if !ok {
    61  		return ErrDecodeTargetNotProtobuf
    62  	}
    63  	data, err := io.ReadAll(r)
    64  	if err != nil {
    65  		return err
    66  	}
    67  
    68  	return proto.Unmarshal(data, pv)
    69  }
    70  
    71  // NewValidator returns empty validator to implements interface `MayValidate` to skip protobuf struct validation
    72  func (t *Protobuf) NewValidator(_ context.Context, _ typesx.Type) (validator.Validator, error) {
    73  	return nil, nil
    74  }