github.com/btccom/go-micro/v2@v2.9.3/codec/protorpc/protorpc.go (about)

// Package protorpc provides a net/rpc-style proto-rpc codec. See envelope.proto for the format.
     2  package protorpc
     3  
     4  import (
     5  	"bytes"
     6  	"fmt"
     7  	"io"
     8  	"strconv"
     9  	"sync"
    10  
    11  	"github.com/golang/protobuf/proto"
    12  	"github.com/btccom/go-micro/v2/codec"
    13  )
    14  
    15  type flusher interface {
    16  	Flush() error
    17  }
    18  
    19  type protoCodec struct {
    20  	sync.Mutex
    21  	rwc io.ReadWriteCloser
    22  	mt  codec.MessageType
    23  	buf *bytes.Buffer
    24  }
    25  
    26  func (c *protoCodec) Close() error {
    27  	c.buf.Reset()
    28  	return c.rwc.Close()
    29  }
    30  
    31  func (c *protoCodec) String() string {
    32  	return "proto-rpc"
    33  }
    34  
    35  func id(id string) uint64 {
    36  	p, err := strconv.ParseInt(id, 10, 64)
    37  	if err != nil {
    38  		p = 0
    39  	}
    40  	i := uint64(p)
    41  	return i
    42  }
    43  
    44  func (c *protoCodec) Write(m *codec.Message, b interface{}) error {
    45  	switch m.Type {
    46  	case codec.Request:
    47  		c.Lock()
    48  		defer c.Unlock()
    49  		// This is protobuf, of course we copy it.
    50  		pbr := &Request{ServiceMethod: m.Method, Seq: id(m.Id)}
    51  		data, err := proto.Marshal(pbr)
    52  		if err != nil {
    53  			return err
    54  		}
    55  		_, err = WriteNetString(c.rwc, data)
    56  		if err != nil {
    57  			return err
    58  		}
    59  		// dont trust or incoming message
    60  		m, ok := b.(proto.Message)
    61  		if !ok {
    62  			return codec.ErrInvalidMessage
    63  		}
    64  		data, err = proto.Marshal(m)
    65  		if err != nil {
    66  			return err
    67  		}
    68  		_, err = WriteNetString(c.rwc, data)
    69  		if err != nil {
    70  			return err
    71  		}
    72  		if flusher, ok := c.rwc.(flusher); ok {
    73  			if err = flusher.Flush(); err != nil {
    74  				return err
    75  			}
    76  		}
    77  	case codec.Response, codec.Error:
    78  		c.Lock()
    79  		defer c.Unlock()
    80  		rtmp := &Response{ServiceMethod: m.Method, Seq: id(m.Id), Error: m.Error}
    81  		data, err := proto.Marshal(rtmp)
    82  		if err != nil {
    83  			return err
    84  		}
    85  		_, err = WriteNetString(c.rwc, data)
    86  		if err != nil {
    87  			return err
    88  		}
    89  		if pb, ok := b.(proto.Message); ok {
    90  			data, err = proto.Marshal(pb)
    91  			if err != nil {
    92  				return err
    93  			}
    94  		} else {
    95  			data = nil
    96  		}
    97  		_, err = WriteNetString(c.rwc, data)
    98  		if err != nil {
    99  			return err
   100  		}
   101  		if flusher, ok := c.rwc.(flusher); ok {
   102  			if err = flusher.Flush(); err != nil {
   103  				return err
   104  			}
   105  		}
   106  	case codec.Event:
   107  		m, ok := b.(proto.Message)
   108  		if !ok {
   109  			return codec.ErrInvalidMessage
   110  		}
   111  		data, err := proto.Marshal(m)
   112  		if err != nil {
   113  			return err
   114  		}
   115  		c.rwc.Write(data)
   116  	default:
   117  		return fmt.Errorf("Unrecognised message type: %v", m.Type)
   118  	}
   119  	return nil
   120  }
   121  
   122  func (c *protoCodec) ReadHeader(m *codec.Message, mt codec.MessageType) error {
   123  	c.buf.Reset()
   124  	c.mt = mt
   125  
   126  	switch mt {
   127  	case codec.Request:
   128  		data, err := ReadNetString(c.rwc)
   129  		if err != nil {
   130  			return err
   131  		}
   132  		rtmp := new(Request)
   133  		err = proto.Unmarshal(data, rtmp)
   134  		if err != nil {
   135  			return err
   136  		}
   137  		m.Method = rtmp.GetServiceMethod()
   138  		m.Id = fmt.Sprintf("%d", rtmp.GetSeq())
   139  	case codec.Response:
   140  		data, err := ReadNetString(c.rwc)
   141  		if err != nil {
   142  			return err
   143  		}
   144  		rtmp := new(Response)
   145  		err = proto.Unmarshal(data, rtmp)
   146  		if err != nil {
   147  			return err
   148  		}
   149  		m.Method = rtmp.GetServiceMethod()
   150  		m.Id = fmt.Sprintf("%d", rtmp.GetSeq())
   151  		m.Error = rtmp.GetError()
   152  	case codec.Event:
   153  		_, err := io.Copy(c.buf, c.rwc)
   154  		return err
   155  	default:
   156  		return fmt.Errorf("Unrecognised message type: %v", mt)
   157  	}
   158  	return nil
   159  }
   160  
   161  func (c *protoCodec) ReadBody(b interface{}) error {
   162  	var data []byte
   163  	switch c.mt {
   164  	case codec.Request, codec.Response:
   165  		var err error
   166  		data, err = ReadNetString(c.rwc)
   167  		if err != nil {
   168  			return err
   169  		}
   170  	case codec.Event:
   171  		data = c.buf.Bytes()
   172  	default:
   173  		return fmt.Errorf("Unrecognised message type: %v", c.mt)
   174  	}
   175  	if b != nil {
   176  		return proto.Unmarshal(data, b.(proto.Message))
   177  	}
   178  	return nil
   179  }
   180  
   181  func NewCodec(rwc io.ReadWriteCloser) codec.Codec {
   182  	return &protoCodec{
   183  		buf: bytes.NewBuffer(nil),
   184  		rwc: rwc,
   185  	}
   186  }