gitee.com/liuxuezhan/go-micro-v1.18.0@v1.0.0/codec/protorpc/protorpc.go (about)

// Package protorpc provides a net/rpc proto-rpc codec. See envelope.proto for the format.
     2  package protorpc
     3  
     4  import (
     5  	"bytes"
     6  	"fmt"
     7  	"io"
     8  	"strconv"
     9  	"sync"
    10  
    11  	"github.com/golang/protobuf/proto"
    12  	"gitee.com/liuxuezhan/go-micro-v1.18.0/codec"
    13  )
    14  
    15  type flusher interface {
    16  	Flush() error
    17  }
    18  
    19  type protoCodec struct {
    20  	sync.Mutex
    21  	rwc io.ReadWriteCloser
    22  	mt  codec.MessageType
    23  	buf *bytes.Buffer
    24  }
    25  
    26  func (c *protoCodec) Close() error {
    27  	c.buf.Reset()
    28  	return c.rwc.Close()
    29  }
    30  
    31  func (c *protoCodec) String() string {
    32  	return "proto-rpc"
    33  }
    34  
    35  func id(id string) *uint64 {
    36  	p, err := strconv.ParseInt(id, 10, 64)
    37  	if err != nil {
    38  		p = 0
    39  	}
    40  	i := uint64(p)
    41  	return &i
    42  }
    43  
    44  func (c *protoCodec) Write(m *codec.Message, b interface{}) error {
    45  	switch m.Type {
    46  	case codec.Request:
    47  		c.Lock()
    48  		defer c.Unlock()
    49  		// This is protobuf, of course we copy it.
    50  		pbr := &Request{ServiceMethod: &m.Method, Seq: id(m.Id)}
    51  		data, err := proto.Marshal(pbr)
    52  		if err != nil {
    53  			return err
    54  		}
    55  		_, err = WriteNetString(c.rwc, data)
    56  		if err != nil {
    57  			return err
    58  		}
    59  		// Of course this is a protobuf! Trust me or detonate the program.
    60  		data, err = proto.Marshal(b.(proto.Message))
    61  		if err != nil {
    62  			return err
    63  		}
    64  		_, err = WriteNetString(c.rwc, data)
    65  		if err != nil {
    66  			return err
    67  		}
    68  		if flusher, ok := c.rwc.(flusher); ok {
    69  			if err = flusher.Flush(); err != nil {
    70  				return err
    71  			}
    72  		}
    73  	case codec.Response, codec.Error:
    74  		c.Lock()
    75  		defer c.Unlock()
    76  		rtmp := &Response{ServiceMethod: &m.Method, Seq: id(m.Id), Error: &m.Error}
    77  		data, err := proto.Marshal(rtmp)
    78  		if err != nil {
    79  			return err
    80  		}
    81  		_, err = WriteNetString(c.rwc, data)
    82  		if err != nil {
    83  			return err
    84  		}
    85  		if pb, ok := b.(proto.Message); ok {
    86  			data, err = proto.Marshal(pb)
    87  			if err != nil {
    88  				return err
    89  			}
    90  		} else {
    91  			data = nil
    92  		}
    93  		_, err = WriteNetString(c.rwc, data)
    94  		if err != nil {
    95  			return err
    96  		}
    97  		if flusher, ok := c.rwc.(flusher); ok {
    98  			if err = flusher.Flush(); err != nil {
    99  				return err
   100  			}
   101  		}
   102  	case codec.Event:
   103  		data, err := proto.Marshal(b.(proto.Message))
   104  		if err != nil {
   105  			return err
   106  		}
   107  		c.rwc.Write(data)
   108  	default:
   109  		return fmt.Errorf("Unrecognised message type: %v", m.Type)
   110  	}
   111  	return nil
   112  }
   113  
   114  func (c *protoCodec) ReadHeader(m *codec.Message, mt codec.MessageType) error {
   115  	c.buf.Reset()
   116  	c.mt = mt
   117  
   118  	switch mt {
   119  	case codec.Request:
   120  		data, err := ReadNetString(c.rwc)
   121  		if err != nil {
   122  			return err
   123  		}
   124  		rtmp := new(Request)
   125  		err = proto.Unmarshal(data, rtmp)
   126  		if err != nil {
   127  			return err
   128  		}
   129  		m.Method = rtmp.GetServiceMethod()
   130  		m.Id = fmt.Sprintf("%d", rtmp.GetSeq())
   131  	case codec.Response:
   132  		data, err := ReadNetString(c.rwc)
   133  		if err != nil {
   134  			return err
   135  		}
   136  		rtmp := new(Response)
   137  		err = proto.Unmarshal(data, rtmp)
   138  		if err != nil {
   139  			return err
   140  		}
   141  		m.Method = rtmp.GetServiceMethod()
   142  		m.Id = fmt.Sprintf("%d", rtmp.GetSeq())
   143  		m.Error = rtmp.GetError()
   144  	case codec.Event:
   145  		_, err := io.Copy(c.buf, c.rwc)
   146  		return err
   147  	default:
   148  		return fmt.Errorf("Unrecognised message type: %v", mt)
   149  	}
   150  	return nil
   151  }
   152  
   153  func (c *protoCodec) ReadBody(b interface{}) error {
   154  	var data []byte
   155  	switch c.mt {
   156  	case codec.Request, codec.Response:
   157  		var err error
   158  		data, err = ReadNetString(c.rwc)
   159  		if err != nil {
   160  			return err
   161  		}
   162  	case codec.Event:
   163  		data = c.buf.Bytes()
   164  	default:
   165  		return fmt.Errorf("Unrecognised message type: %v", c.mt)
   166  	}
   167  	if b != nil {
   168  		return proto.Unmarshal(data, b.(proto.Message))
   169  	}
   170  	return nil
   171  }
   172  
   173  func NewCodec(rwc io.ReadWriteCloser) codec.Codec {
   174  	return &protoCodec{
   175  		buf: bytes.NewBuffer(nil),
   176  		rwc: rwc,
   177  	}
   178  }