github.com/cloudwego/kitex@v0.9.0/pkg/remote/trans/gonet/bytebuffer.go (about)

     1  /*
     2   * Copyright 2022 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package gonet
    18  
    19  import (
    20  	"errors"
    21  	"io"
    22  	"sync"
    23  
    24  	"github.com/cloudwego/netpoll"
    25  
    26  	"github.com/cloudwego/kitex/pkg/remote"
    27  )
    28  
    29  var rwPool sync.Pool
    30  
    31  func init() {
    32  	rwPool.New = newBufferReadWriter
    33  }
    34  
    35  var _ remote.ByteBuffer = &bufferReadWriter{}
    36  
    37  type bufferReadWriter struct {
    38  	reader netpoll.Reader
    39  	writer netpoll.Writer
    40  
    41  	ioReader io.Reader
    42  	ioWriter io.Writer
    43  
    44  	readSize int
    45  	status   int
    46  }
    47  
    48  func newBufferReadWriter() interface{} {
    49  	return &bufferReadWriter{}
    50  }
    51  
    52  // NewBufferReader creates a new remote.ByteBuffer using the given netpoll.ZeroCopyReader.
    53  func NewBufferReader(ir io.Reader) remote.ByteBuffer {
    54  	rw := rwPool.Get().(*bufferReadWriter)
    55  	if npReader, ok := ir.(interface{ Reader() netpoll.Reader }); ok {
    56  		rw.reader = npReader.Reader()
    57  	} else {
    58  		rw.reader = netpoll.NewReader(ir)
    59  	}
    60  	rw.ioReader = ir
    61  	rw.status = remote.BitReadable
    62  	rw.readSize = 0
    63  	return rw
    64  }
    65  
    66  // NewBufferWriter creates a new remote.ByteBuffer using the given netpoll.ZeroCopyWriter.
    67  func NewBufferWriter(iw io.Writer) remote.ByteBuffer {
    68  	rw := rwPool.Get().(*bufferReadWriter)
    69  	rw.writer = netpoll.NewWriter(iw)
    70  	rw.ioWriter = iw
    71  	rw.status = remote.BitWritable
    72  	return rw
    73  }
    74  
    75  // NewBufferReadWriter creates a new remote.ByteBuffer using the given netpoll.ZeroCopyReadWriter.
    76  func NewBufferReadWriter(irw io.ReadWriter) remote.ByteBuffer {
    77  	rw := rwPool.Get().(*bufferReadWriter)
    78  	rw.writer = netpoll.NewWriter(irw)
    79  	rw.reader = netpoll.NewReader(irw)
    80  	rw.ioWriter = irw
    81  	rw.ioReader = irw
    82  	rw.status = remote.BitWritable | remote.BitReadable
    83  	return rw
    84  }
    85  
    86  func (rw *bufferReadWriter) readable() bool {
    87  	return rw.status&remote.BitReadable != 0
    88  }
    89  
    90  func (rw *bufferReadWriter) writable() bool {
    91  	return rw.status&remote.BitWritable != 0
    92  }
    93  
    94  func (rw *bufferReadWriter) Next(n int) (p []byte, err error) {
    95  	if !rw.readable() {
    96  		return nil, errors.New("unreadable buffer, cannot support Next")
    97  	}
    98  	if p, err = rw.reader.Next(n); err == nil {
    99  		rw.readSize += n
   100  	}
   101  	return
   102  }
   103  
   104  func (rw *bufferReadWriter) Peek(n int) (buf []byte, err error) {
   105  	if !rw.readable() {
   106  		return nil, errors.New("unreadable buffer, cannot support Peek")
   107  	}
   108  	return rw.reader.Peek(n)
   109  }
   110  
   111  func (rw *bufferReadWriter) Skip(n int) (err error) {
   112  	if !rw.readable() {
   113  		return errors.New("unreadable buffer, cannot support Skip")
   114  	}
   115  	return rw.reader.Skip(n)
   116  }
   117  
   118  func (rw *bufferReadWriter) ReadableLen() (n int) {
   119  	if !rw.readable() {
   120  		return -1
   121  	}
   122  	return rw.reader.Len()
   123  }
   124  
   125  func (rw *bufferReadWriter) ReadString(n int) (s string, err error) {
   126  	if !rw.readable() {
   127  		return "", errors.New("unreadable buffer, cannot support ReadString")
   128  	}
   129  	if s, err = rw.reader.ReadString(n); err == nil {
   130  		rw.readSize += n
   131  	}
   132  	return
   133  }
   134  
   135  func (rw *bufferReadWriter) ReadBinary(n int) (p []byte, err error) {
   136  	if !rw.readable() {
   137  		return p, errors.New("unreadable buffer, cannot support ReadBinary")
   138  	}
   139  	if p, err = rw.reader.ReadBinary(n); err == nil {
   140  		rw.readSize += n
   141  	}
   142  	return
   143  }
   144  
   145  func (rw *bufferReadWriter) Read(p []byte) (n int, err error) {
   146  	if !rw.readable() {
   147  		return -1, errors.New("unreadable buffer, cannot support Read")
   148  	}
   149  	if rw.ioReader != nil {
   150  		return rw.ioReader.Read(p)
   151  	}
   152  	return -1, errors.New("ioReader is nil")
   153  }
   154  
   155  func (rw *bufferReadWriter) ReadLen() (n int) {
   156  	return rw.readSize
   157  }
   158  
   159  func (rw *bufferReadWriter) Malloc(n int) (buf []byte, err error) {
   160  	if !rw.writable() {
   161  		return nil, errors.New("unwritable buffer, cannot support Malloc")
   162  	}
   163  	return rw.writer.Malloc(n)
   164  }
   165  
   166  func (rw *bufferReadWriter) MallocLen() (length int) {
   167  	if !rw.writable() {
   168  		return -1
   169  	}
   170  	return rw.writer.MallocLen()
   171  }
   172  
   173  func (rw *bufferReadWriter) WriteString(s string) (n int, err error) {
   174  	if !rw.writable() {
   175  		return -1, errors.New("unwritable buffer, cannot support WriteString")
   176  	}
   177  	return rw.writer.WriteString(s)
   178  }
   179  
   180  func (rw *bufferReadWriter) WriteBinary(b []byte) (n int, err error) {
   181  	if !rw.writable() {
   182  		return -1, errors.New("unwritable buffer, cannot support WriteBinary")
   183  	}
   184  	return rw.writer.WriteBinary(b)
   185  }
   186  
   187  func (rw *bufferReadWriter) Flush() (err error) {
   188  	if !rw.writable() {
   189  		return errors.New("unwritable buffer, cannot support Flush")
   190  	}
   191  	return rw.writer.Flush()
   192  }
   193  
   194  func (rw *bufferReadWriter) Write(p []byte) (n int, err error) {
   195  	if !rw.writable() {
   196  		return -1, errors.New("unwritable buffer, cannot support Write")
   197  	}
   198  	if rw.ioWriter != nil {
   199  		return rw.ioWriter.Write(p)
   200  	}
   201  	return -1, errors.New("ioWriter is nil")
   202  }
   203  
   204  func (rw *bufferReadWriter) Release(e error) (err error) {
   205  	if rw.reader != nil {
   206  		err = rw.reader.Release()
   207  	}
   208  	rw.zero()
   209  	rwPool.Put(rw)
   210  	return
   211  }
   212  
   213  // WriteDirect is a way to write []byte without copying, and splits the original buffer.
   214  func (rw *bufferReadWriter) WriteDirect(p []byte, remainCap int) error {
   215  	if !rw.writable() {
   216  		return errors.New("unwritable buffer, cannot support WriteBinary")
   217  	}
   218  	return rw.writer.WriteDirect(p, remainCap)
   219  }
   220  
   221  func (rw *bufferReadWriter) AppendBuffer(buf remote.ByteBuffer) (err error) {
   222  	subBuf, ok := buf.(*bufferReadWriter)
   223  	if !ok {
   224  		return errors.New("AppendBuffer failed, Buffer is not bufferReadWriter")
   225  	}
   226  	if err = rw.writer.Append(subBuf.writer); err != nil {
   227  		return
   228  	}
   229  	return buf.Release(nil)
   230  }
   231  
   232  // NewBuffer returns a new writable remote.ByteBuffer.
   233  func (rw *bufferReadWriter) NewBuffer() remote.ByteBuffer {
   234  	panic("unimplemented")
   235  }
   236  
   237  func (rw *bufferReadWriter) Bytes() (buf []byte, err error) {
   238  	panic("unimplemented")
   239  }
   240  
   241  func (rw *bufferReadWriter) zero() {
   242  	rw.reader = nil
   243  	rw.writer = nil
   244  	rw.ioReader = nil
   245  	rw.ioWriter = nil
   246  	rw.readSize = 0
   247  	rw.status = 0
   248  }