trpc.group/trpc-go/trpc-go@v1.0.3/transport/client_transport_udp.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package transport
    15  
    16  import (
    17  	"context"
    18  	"fmt"
    19  	"net"
    20  
    21  	"trpc.group/trpc-go/trpc-go/codec"
    22  	"trpc.group/trpc-go/trpc-go/errs"
    23  	"trpc.group/trpc-go/trpc-go/internal/packetbuffer"
    24  	"trpc.group/trpc-go/trpc-go/internal/report"
    25  )
    26  
    27  const defaultUDPRecvBufSize = 64 * 1024
    28  
    29  // udpRoundTrip sends UDP requests.
    30  func (c *clientTransport) udpRoundTrip(ctx context.Context, reqData []byte,
    31  	opts *RoundTripOptions) ([]byte, error) {
    32  	if opts.FramerBuilder == nil {
    33  		return nil, errs.NewFrameError(errs.RetClientConnectFail,
    34  			"udp client transport: framer builder empty")
    35  	}
    36  
    37  	conn, addr, err := c.dialUDP(ctx, opts)
    38  	if err != nil {
    39  		return nil, err
    40  	}
    41  	defer conn.Close()
    42  	msg := codec.Message(ctx)
    43  	msg.WithRemoteAddr(addr)
    44  	msg.WithLocalAddr(conn.LocalAddr())
    45  
    46  	if ctx.Err() == context.Canceled {
    47  		return nil, errs.NewFrameError(errs.RetClientCanceled,
    48  			"udp client transport canceled before Write: "+ctx.Err().Error())
    49  	}
    50  	if ctx.Err() == context.DeadlineExceeded {
    51  		return nil, errs.NewFrameError(errs.RetClientTimeout,
    52  			"udp client transport timeout before Write: "+ctx.Err().Error())
    53  	}
    54  
    55  	report.UDPClientTransportSendSize.Set(float64(len(reqData)))
    56  	if err := c.udpWriteFrame(conn, reqData, addr, opts); err != nil {
    57  		return nil, err
    58  	}
    59  	return c.udpReadFrame(ctx, conn, opts)
    60  }
    61  
    62  // udpReadFrame reads UDP frame.
    63  func (c *clientTransport) udpReadFrame(
    64  	ctx context.Context, conn net.PacketConn, opts *RoundTripOptions) ([]byte, error) {
    65  	// If it is SendOnly, returns directly without waiting for the server's response.
    66  	if opts.ReqType == SendOnly {
    67  		return nil, errs.ErrClientNoResponse
    68  	}
    69  
    70  	select {
    71  	case <-ctx.Done():
    72  		return nil, errs.NewFrameError(errs.RetClientTimeout, "udp client transport select after Write: "+ctx.Err().Error())
    73  	default:
    74  	}
    75  
    76  	buf := packetbuffer.New(conn, defaultUDPRecvBufSize)
    77  	defer buf.Close()
    78  	fr := opts.FramerBuilder.New(buf)
    79  	req, err := fr.ReadFrame()
    80  	if err != nil {
    81  		report.UDPClientTransportReadFail.Incr()
    82  		if e, ok := err.(net.Error); ok {
    83  			if e.Timeout() {
    84  				return nil, errs.NewFrameError(errs.RetClientTimeout,
    85  					"udp client transport ReadFrame: "+err.Error())
    86  			}
    87  			return nil, errs.NewFrameError(errs.RetClientNetErr,
    88  				"udp client transport ReadFrom: "+err.Error())
    89  		}
    90  		return nil, errs.NewFrameError(errs.RetClientReadFrameErr,
    91  			"udp client transport ReadFrame: "+err.Error())
    92  	}
    93  	// One packet of udp corresponds to one trpc packet,
    94  	// and after parsing, there should not be any remaining data
    95  	if err := buf.Next(); err != nil {
    96  		report.UDPClientTransportUnRead.Incr()
    97  		return nil, errs.NewFrameError(errs.RetClientReadFrameErr,
    98  			fmt.Sprintf("udp client transport ReadFrame: %s", err))
    99  	}
   100  	report.UDPClientTransportReceiveSize.Set(float64(len(req)))
   101  	// Framer is used for every request so there is no need to copy memory.
   102  	return req, nil
   103  }
   104  
   105  // udpWriteReqData write UDP frame.
   106  func (c *clientTransport) udpWriteFrame(conn net.PacketConn,
   107  	reqData []byte, addr *net.UDPAddr, opts *RoundTripOptions) error {
   108  	// Sending udp request packets
   109  	var num int
   110  	var err error
   111  	if opts.ConnectionMode == Connected {
   112  		udpconn := conn.(*net.UDPConn)
   113  		num, err = udpconn.Write(reqData)
   114  	} else {
   115  		num, err = conn.WriteTo(reqData, addr)
   116  	}
   117  	if err != nil {
   118  		if e, ok := err.(net.Error); ok && e.Timeout() {
   119  			return errs.NewFrameError(errs.RetClientTimeout, "udp client transport WriteTo: "+err.Error())
   120  		}
   121  		return errs.NewFrameError(errs.RetClientNetErr, "udp client transport WriteTo: "+err.Error())
   122  	}
   123  	if num != len(reqData) {
   124  		return errs.NewFrameError(errs.RetClientNetErr, "udp client transport WriteTo: num mismatch")
   125  	}
   126  	return nil
   127  }
   128  
   129  // dialUDP establishes an UDP connection.
   130  func (c *clientTransport) dialUDP(ctx context.Context, opts *RoundTripOptions) (net.PacketConn, *net.UDPAddr, error) {
   131  	addr, err := net.ResolveUDPAddr(opts.Network, opts.Address)
   132  	if err != nil {
   133  		return nil, nil, errs.NewFrameError(errs.RetClientNetErr,
   134  			"udp client transport ResolveUDPAddr: "+err.Error())
   135  	}
   136  
   137  	var conn net.PacketConn
   138  	if opts.ConnectionMode == Connected {
   139  		var localAddr net.Addr
   140  		if opts.LocalAddr != "" {
   141  			localAddr, err = net.ResolveUDPAddr(opts.Network, opts.LocalAddr)
   142  			if err != nil {
   143  				return nil, nil, errs.NewFrameError(errs.RetClientNetErr,
   144  					"udp client transport LocalAddr ResolveUDPAddr: "+err.Error())
   145  			}
   146  		}
   147  		dialer := net.Dialer{
   148  			LocalAddr: localAddr,
   149  		}
   150  		var udpConn net.Conn
   151  		udpConn, err = dialer.Dial(opts.Network, opts.Address)
   152  		if err != nil {
   153  			return nil, nil, errs.NewFrameError(errs.RetClientConnectFail,
   154  				fmt.Sprintf("dial udp fail: %s", err.Error()))
   155  		}
   156  
   157  		var ok bool
   158  		conn, ok = udpConn.(net.PacketConn)
   159  		if !ok {
   160  			return nil, nil, errs.NewFrameError(errs.RetClientConnectFail,
   161  				"udp conn not implement net.PacketConn")
   162  		}
   163  	} else {
   164  		// Listen on all available IP addresses of the local system by default,
   165  		// and a port number is automatically chosen.
   166  		const defaultLocalAddr = ":"
   167  		localAddr := defaultLocalAddr
   168  		if opts.LocalAddr != "" {
   169  			localAddr = opts.LocalAddr
   170  		}
   171  		conn, err = net.ListenPacket(opts.Network, localAddr)
   172  	}
   173  	if err != nil {
   174  		return nil, nil, errs.NewFrameError(errs.RetClientNetErr, "udp client transport Dial: "+err.Error())
   175  	}
   176  	d, ok := ctx.Deadline()
   177  	if ok {
   178  		conn.SetDeadline(d)
   179  	}
   180  	return conn, addr, nil
   181  }