github.com/tickoalcantara12/micro/v3@v3.0.0-20221007104245-9d75b9bcbab9/util/codec/protorpc/protorpc.go (about)

     1  // Licensed under the Apache License, Version 2.0 (the "License");
     2  // you may not use this file except in compliance with the License.
     3  // You may obtain a copy of the License at
     4  //
     5  //     https://www.apache.org/licenses/LICENSE-2.0
     6  //
     7  // Unless required by applicable law or agreed to in writing, software
     8  // distributed under the License is distributed on an "AS IS" BASIS,
     9  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    10  // See the License for the specific language governing permissions and
    11  // limitations under the License.
    12  //
    13  // Original source: github.com/micro/go-micro/v3/codec/protorpc/protorpc.go
    14  
    15  // Package protorpc provides a net/rpc proto-rpc codec. See envelope.proto for the format.
    16  package protorpc
    17  
    18  import (
    19  	"bytes"
    20  	"fmt"
    21  	"io"
    22  	"strconv"
    23  	"sync"
    24  
    25  	"github.com/golang/protobuf/proto"
    26  	"github.com/tickoalcantara12/micro/v3/util/codec"
    27  )
    28  
    29  type flusher interface {
    30  	Flush() error
    31  }
    32  
    33  type protoCodec struct {
    34  	sync.Mutex
    35  	rwc io.ReadWriteCloser
    36  	mt  codec.MessageType
    37  	buf *bytes.Buffer
    38  }
    39  
    40  func (c *protoCodec) Close() error {
    41  	c.buf.Reset()
    42  	return c.rwc.Close()
    43  }
    44  
    45  func (c *protoCodec) String() string {
    46  	return "proto-rpc"
    47  }
    48  
    49  func id(id string) uint64 {
    50  	p, err := strconv.ParseInt(id, 10, 64)
    51  	if err != nil {
    52  		p = 0
    53  	}
    54  	i := uint64(p)
    55  	return i
    56  }
    57  
    58  func (c *protoCodec) Write(m *codec.Message, b interface{}) error {
    59  	switch m.Type {
    60  	case codec.Request:
    61  		c.Lock()
    62  		defer c.Unlock()
    63  		// This is protobuf, of course we copy it.
    64  		pbr := &Request{ServiceMethod: m.Method, Seq: id(m.Id)}
    65  		data, err := proto.Marshal(pbr)
    66  		if err != nil {
    67  			return err
    68  		}
    69  		_, err = WriteNetString(c.rwc, data)
    70  		if err != nil {
    71  			return err
    72  		}
    73  		// dont trust or incoming message
    74  		m, ok := b.(proto.Message)
    75  		if !ok {
    76  			return codec.ErrInvalidMessage
    77  		}
    78  		data, err = proto.Marshal(m)
    79  		if err != nil {
    80  			return err
    81  		}
    82  		_, err = WriteNetString(c.rwc, data)
    83  		if err != nil {
    84  			return err
    85  		}
    86  		if flusher, ok := c.rwc.(flusher); ok {
    87  			if err = flusher.Flush(); err != nil {
    88  				return err
    89  			}
    90  		}
    91  	case codec.Response, codec.Error:
    92  		c.Lock()
    93  		defer c.Unlock()
    94  		rtmp := &Response{ServiceMethod: m.Method, Seq: id(m.Id), Error: m.Error}
    95  		data, err := proto.Marshal(rtmp)
    96  		if err != nil {
    97  			return err
    98  		}
    99  		_, err = WriteNetString(c.rwc, data)
   100  		if err != nil {
   101  			return err
   102  		}
   103  		if pb, ok := b.(proto.Message); ok {
   104  			data, err = proto.Marshal(pb)
   105  			if err != nil {
   106  				return err
   107  			}
   108  		} else {
   109  			data = nil
   110  		}
   111  		_, err = WriteNetString(c.rwc, data)
   112  		if err != nil {
   113  			return err
   114  		}
   115  		if flusher, ok := c.rwc.(flusher); ok {
   116  			if err = flusher.Flush(); err != nil {
   117  				return err
   118  			}
   119  		}
   120  	case codec.Event:
   121  		m, ok := b.(proto.Message)
   122  		if !ok {
   123  			return codec.ErrInvalidMessage
   124  		}
   125  		data, err := proto.Marshal(m)
   126  		if err != nil {
   127  			return err
   128  		}
   129  		c.rwc.Write(data)
   130  	default:
   131  		return fmt.Errorf("Unrecognised message type: %v", m.Type)
   132  	}
   133  	return nil
   134  }
   135  
   136  func (c *protoCodec) ReadHeader(m *codec.Message, mt codec.MessageType) error {
   137  	c.buf.Reset()
   138  	c.mt = mt
   139  
   140  	switch mt {
   141  	case codec.Request:
   142  		data, err := ReadNetString(c.rwc)
   143  		if err != nil {
   144  			return err
   145  		}
   146  		rtmp := new(Request)
   147  		err = proto.Unmarshal(data, rtmp)
   148  		if err != nil {
   149  			return err
   150  		}
   151  		m.Method = rtmp.GetServiceMethod()
   152  		m.Id = fmt.Sprintf("%d", rtmp.GetSeq())
   153  	case codec.Response:
   154  		data, err := ReadNetString(c.rwc)
   155  		if err != nil {
   156  			return err
   157  		}
   158  		rtmp := new(Response)
   159  		err = proto.Unmarshal(data, rtmp)
   160  		if err != nil {
   161  			return err
   162  		}
   163  		m.Method = rtmp.GetServiceMethod()
   164  		m.Id = fmt.Sprintf("%d", rtmp.GetSeq())
   165  		m.Error = rtmp.GetError()
   166  	case codec.Event:
   167  		_, err := io.Copy(c.buf, c.rwc)
   168  		return err
   169  	default:
   170  		return fmt.Errorf("Unrecognised message type: %v", mt)
   171  	}
   172  	return nil
   173  }
   174  
   175  func (c *protoCodec) ReadBody(b interface{}) error {
   176  	var data []byte
   177  	switch c.mt {
   178  	case codec.Request, codec.Response:
   179  		var err error
   180  		data, err = ReadNetString(c.rwc)
   181  		if err != nil {
   182  			return err
   183  		}
   184  	case codec.Event:
   185  		data = c.buf.Bytes()
   186  	default:
   187  		return fmt.Errorf("Unrecognised message type: %v", c.mt)
   188  	}
   189  	if b != nil {
   190  		return proto.Unmarshal(data, b.(proto.Message))
   191  	}
   192  	return nil
   193  }
   194  
   195  func NewCodec(rwc io.ReadWriteCloser) codec.Codec {
   196  	return &protoCodec{
   197  		buf: bytes.NewBuffer(nil),
   198  		rwc: rwc,
   199  	}
   200  }