github.com/cloudwego/hertz@v0.9.3/pkg/network/netpoll/transport.go (about)

     1  // Copyright 2022 CloudWeGo Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  //
    15  
    16  //go:build !windows
    17  // +build !windows
    18  
    19  package netpoll
    20  
    21  import (
    22  	"context"
    23  	"io"
    24  	"net"
    25  	"sync"
    26  	"time"
    27  
    28  	"github.com/cloudwego/hertz/pkg/common/config"
    29  	"github.com/cloudwego/hertz/pkg/common/hlog"
    30  	"github.com/cloudwego/hertz/pkg/network"
    31  	"github.com/cloudwego/netpoll"
    32  )
    33  
    34  func init() {
    35  	// disable netpoll's log
    36  	netpoll.SetLoggerOutput(io.Discard)
    37  }
    38  
    39  type ctxCancelKeyStruct struct{}
    40  
    41  var ctxCancelKey = ctxCancelKeyStruct{}
    42  
    43  func cancelContext(ctx context.Context) context.Context {
    44  	ctx, cancel := context.WithCancel(ctx)
    45  	ctx = context.WithValue(ctx, ctxCancelKey, cancel)
    46  	return ctx
    47  }
    48  
// transporter is the netpoll-based network.Transporter returned by
// NewTransporter.
//
// The embedded RWMutex guards eventLoop: ListenAndServe assigns it under
// Lock and reads it under RLock, and Shutdown reads it under RLock.
type transporter struct {
	sync.RWMutex
	// When true, each connection's context is wrapped by cancelContext and
	// the stored cancel func is invoked from netpoll's OnDisconnect hook.
	senseClientDisconnection bool
	// Network name and address handed to ListenConfig.Listen / net.Listen.
	// UnlinkUdsFile is called with them before listening and on Shutdown.
	network                  string
	addr                     string
	// Passed to the event loop via netpoll.WithIdleTimeout.
	keepAliveTimeout         time.Duration
	// Applied to every prepared connection via SetReadTimeout.
	readTimeout              time.Duration
	// Applied via SetWriteTimeout only when positive.
	writeTimeout             time.Duration
	// Set by ListenAndServe; nil until then.
	listener                 net.Listener
	// Created by ListenAndServe; Shutdown treats nil as "nothing to stop".
	eventLoop                netpoll.EventLoop
	// Optional; when non-nil it is used instead of net.Listen.
	listenConfig             *net.ListenConfig
	// Optional hook run while netpoll prepares a connection; the context it
	// returns becomes the connection's base context.
	OnAccept                 func(conn net.Conn) context.Context
	// Optional hook registered with netpoll.WithOnConnect; it receives the
	// connection wrapped by newConn.
	OnConnect                func(ctx context.Context, conn network.Conn) context.Context
}
    63  
    64  // For transporter switch
    65  func NewTransporter(options *config.Options) network.Transporter {
    66  	return &transporter{
    67  		senseClientDisconnection: options.SenseClientDisconnection,
    68  		network:                  options.Network,
    69  		addr:                     options.Addr,
    70  		keepAliveTimeout:         options.KeepAliveTimeout,
    71  		readTimeout:              options.ReadTimeout,
    72  		writeTimeout:             options.WriteTimeout,
    73  		listener:                 nil,
    74  		eventLoop:                nil,
    75  		listenConfig:             options.ListenConfig,
    76  		OnAccept:                 options.OnAccept,
    77  		OnConnect:                options.OnConnect,
    78  	}
    79  }
    80  
    81  // ListenAndServe binds listen address and keep serving, until an error occurs
    82  // or the transport shutdowns
    83  func (t *transporter) ListenAndServe(onReq network.OnData) (err error) {
    84  	network.UnlinkUdsFile(t.network, t.addr) //nolint:errcheck
    85  	if t.listenConfig != nil {
    86  		t.listener, err = t.listenConfig.Listen(context.Background(), t.network, t.addr)
    87  	} else {
    88  		t.listener, err = net.Listen(t.network, t.addr)
    89  	}
    90  
    91  	if err != nil {
    92  		panic("create netpoll listener fail: " + err.Error())
    93  	}
    94  
    95  	// Initialize custom option for EventLoop
    96  	opts := []netpoll.Option{
    97  		netpoll.WithIdleTimeout(t.keepAliveTimeout),
    98  		netpoll.WithOnPrepare(func(conn netpoll.Connection) context.Context {
    99  			conn.SetReadTimeout(t.readTimeout) // nolint:errcheck
   100  			if t.writeTimeout > 0 {
   101  				conn.SetWriteTimeout(t.writeTimeout)
   102  			}
   103  			ctx := context.Background()
   104  			if t.OnAccept != nil {
   105  				ctx = t.OnAccept(newConn(conn))
   106  			}
   107  			if t.senseClientDisconnection {
   108  				ctx = cancelContext(ctx)
   109  			}
   110  			return ctx
   111  		}),
   112  	}
   113  
   114  	if t.OnConnect != nil {
   115  		opts = append(opts, netpoll.WithOnConnect(func(ctx context.Context, conn netpoll.Connection) context.Context {
   116  			return t.OnConnect(ctx, newConn(conn))
   117  		}))
   118  	}
   119  
   120  	if t.senseClientDisconnection {
   121  		opts = append(opts, netpoll.WithOnDisconnect(func(ctx context.Context, connection netpoll.Connection) {
   122  			cancelFunc, ok := ctx.Value(ctxCancelKey).(context.CancelFunc)
   123  			if cancelFunc != nil && ok {
   124  				cancelFunc()
   125  			}
   126  		}))
   127  	}
   128  
   129  	// Create EventLoop
   130  	t.Lock()
   131  	t.eventLoop, err = netpoll.NewEventLoop(func(ctx context.Context, connection netpoll.Connection) error {
   132  		return onReq(ctx, newConn(connection))
   133  	}, opts...)
   134  	t.Unlock()
   135  	if err != nil {
   136  		panic("create netpoll event-loop fail")
   137  	}
   138  
   139  	// Start Server
   140  	hlog.SystemLogger().Infof("HTTP server listening on address=%s", t.listener.Addr().String())
   141  	t.RLock()
   142  	err = t.eventLoop.Serve(t.listener)
   143  	t.RUnlock()
   144  	if err != nil {
   145  		panic("netpoll server exit")
   146  	}
   147  
   148  	return nil
   149  }
   150  
   151  // Close forces transport to close immediately (no wait timeout)
   152  func (t *transporter) Close() error {
   153  	ctx, cancel := context.WithTimeout(context.Background(), 0)
   154  	defer cancel()
   155  	return t.Shutdown(ctx)
   156  }
   157  
   158  // Shutdown will trigger listener stop and graceful shutdown
   159  // It will wait all connections close until reaching ctx.Deadline()
   160  func (t *transporter) Shutdown(ctx context.Context) error {
   161  	defer func() {
   162  		network.UnlinkUdsFile(t.network, t.addr) //nolint:errcheck
   163  		t.RUnlock()
   164  	}()
   165  	t.RLock()
   166  	if t.eventLoop == nil {
   167  		return nil
   168  	}
   169  	return t.eventLoop.Shutdown(ctx)
   170  }