github.com/cloudwego/kitex@v0.9.0/pkg/remote/trans/nphttp2/stream.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package nphttp2
    18  
    19  import (
    20  	"context"
    21  	"net"
    22  
    23  	"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata"
    24  
    25  	"github.com/cloudwego/kitex/pkg/remote"
    26  	"github.com/cloudwego/kitex/pkg/rpcinfo"
    27  	"github.com/cloudwego/kitex/pkg/serviceinfo"
    28  	"github.com/cloudwego/kitex/pkg/streaming"
    29  )
    30  
// Streamer is the function type of a stream creator: given a context, the
// service info, the underlying connection and the TransReadWriter used to
// read and write messages, it returns a streaming.Stream.
//
// NOTE(review): svcInfo is taken by value here, whereas NewStream takes a
// *serviceinfo.ServiceInfo, so NewStream is not directly assignable to
// Streamer — confirm whether that is intended.
type Streamer func(ctx context.Context, svcInfo serviceinfo.ServiceInfo, conn net.Conn,
	handler remote.TransReadWriter) streaming.Stream
    34  
// stream is the streaming.Stream implementation returned by NewStream. It
// pairs a connection with the handler that reads and writes messages on it.
type stream struct {
	// ctx is returned by Context and is where RecvMsg/SendMsg look up the RPCInfo.
	ctx     context.Context
	// svcInfo is passed to remote.NewMessage for every message sent or received.
	svcInfo *serviceinfo.ServiceInfo
	// conn is type-asserted to *clientConn or *serverConn by the metadata
	// methods and by the payload codec lookup.
	conn    net.Conn // clientConn or serverConn
	// handler performs the actual Read/Write of messages on conn.
	handler remote.TransReadWriter
}
    41  
    42  // NewStream ...
    43  func NewStream(ctx context.Context, svcInfo *serviceinfo.ServiceInfo, conn net.Conn,
    44  	handler remote.TransReadWriter,
    45  ) streaming.Stream {
    46  	return &stream{
    47  		ctx:     ctx,
    48  		svcInfo: svcInfo,
    49  		conn:    conn,
    50  		handler: handler,
    51  	}
    52  }
    53  
// Context returns the context the stream was created with.
func (s *stream) Context() context.Context {
	return s.ctx
}
    57  
    58  // Trailer is used for client side stream
    59  func (s *stream) Trailer() metadata.MD {
    60  	sc, ok := s.conn.(*clientConn)
    61  	if !ok {
    62  		panic("this method should only be used in client side stream!")
    63  	}
    64  	return sc.s.Trailer()
    65  }
    66  
    67  // Header is used for client side stream
    68  func (s *stream) Header() (metadata.MD, error) {
    69  	sc, ok := s.conn.(*clientConn)
    70  	if !ok {
    71  		panic("this method should only be used in client side stream!")
    72  	}
    73  	return sc.s.Header()
    74  }
    75  
    76  // SendHeader is used for server side stream
    77  func (s *stream) SendHeader(md metadata.MD) error {
    78  	sc := s.conn.(*serverConn)
    79  	return sc.s.SendHeader(md)
    80  }
    81  
    82  // SetHeader is used for server side stream
    83  func (s *stream) SetHeader(md metadata.MD) error {
    84  	sc := s.conn.(*serverConn)
    85  	return sc.s.SetHeader(md)
    86  }
    87  
    88  // SetTrailer is used for server side stream
    89  func (s *stream) SetTrailer(md metadata.MD) {
    90  	sc := s.conn.(*serverConn)
    91  	sc.s.SetTrailer(md)
    92  }
    93  
    94  func (s *stream) RecvMsg(m interface{}) error {
    95  	ri := rpcinfo.GetRPCInfo(s.ctx)
    96  
    97  	msg := remote.NewMessage(m, s.svcInfo, ri, remote.Stream, remote.Client)
    98  	payloadCodec, err := s.getPayloadCodecFromContentType()
    99  	if err != nil {
   100  		return err
   101  	}
   102  	msg.SetProtocolInfo(remote.NewProtocolInfo(ri.Config().TransportProtocol(), payloadCodec))
   103  	defer msg.Recycle()
   104  
   105  	_, err = s.handler.Read(s.ctx, s.conn, msg)
   106  	return err
   107  }
   108  
   109  func (s *stream) SendMsg(m interface{}) error {
   110  	ri := rpcinfo.GetRPCInfo(s.ctx)
   111  
   112  	msg := remote.NewMessage(m, s.svcInfo, ri, remote.Stream, remote.Client)
   113  	payloadCodec, err := s.getPayloadCodecFromContentType()
   114  	if err != nil {
   115  		return err
   116  	}
   117  	msg.SetProtocolInfo(remote.NewProtocolInfo(ri.Config().TransportProtocol(), payloadCodec))
   118  	defer msg.Recycle()
   119  
   120  	_, err = s.handler.Write(s.ctx, s.conn, msg)
   121  	return err
   122  }
   123  
// Close closes the stream's underlying connection and returns its error.
func (s *stream) Close() error {
	return s.conn.Close()
}
   127  
   128  func (s *stream) getPayloadCodecFromContentType() (serviceinfo.PayloadCodec, error) {
   129  	// TODO: handle other protocols in the future. currently only supports grpc
   130  	var subType string
   131  	switch sc := s.conn.(type) {
   132  	case *clientConn:
   133  		subType = sc.s.ContentSubtype()
   134  	case *serverConn:
   135  		subType = sc.s.ContentSubtype()
   136  	}
   137  	switch subType {
   138  	case contentSubTypeThrift:
   139  		return serviceinfo.Thrift, nil
   140  	default:
   141  		return serviceinfo.Protobuf, nil
   142  	}
   143  }