github.com/cloudwego/hertz@v0.9.3/pkg/network/standard/transport.go (about)

     1  /*
     2   * Copyright 2022 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package standard
    18  
import (
	"context"
	"crypto/tls"
	"errors"
	"net"
	"sync"
	"time"

	"github.com/cloudwego/hertz/pkg/common/config"
	"github.com/cloudwego/hertz/pkg/common/hlog"
	"github.com/cloudwego/hertz/pkg/network"
)
    30  
// transport is a network.Transporter built on the standard library net
// package: serve listens with net.Listen (or a custom net.ListenConfig),
// accepts connections in a loop, and hands each one to handler on its own
// goroutine.
type transport struct {
	// Per-connection buffer size for requests' reading.
	// This also limits the maximum header size.
	//
	// Increase this buffer if your clients send multi-KB RequestURIs
	// and/or multi-KB headers (for example, BIG cookies).
	//
	// Default buffer size is used if not set.
	readBufferSize   int
	// network and addr are handed to Listen in serve, and to
	// network.UnlinkUdsFile before listening and after Shutdown.
	network          string
	addr             string
	// keepAliveTimeout and readTimeout are copied from config.Options by
	// NewTransporter. NOTE(review): neither is read in this file — confirm
	// whether another file in the package consumes them.
	keepAliveTimeout time.Duration
	readTimeout      time.Duration
	// handler is installed by ListenAndServe; serve runs it in a new
	// goroutine for every accepted connection.
	handler          network.OnData
	// ln is the active listener: assigned by serve, closed by Shutdown.
	ln               net.Listener
	// tls, when non-nil, makes serve wrap each accepted connection with
	// tls.Server using this config.
	tls              *tls.Config
	// listenConfig, when non-nil, is used by serve instead of net.Listen.
	listenConfig     *net.ListenConfig
	// lock serializes serve's assignment of ln against Shutdown's close of it.
	lock             sync.Mutex
	// OnAccept is an optional hook called with the raw accepted net.Conn
	// (before any TLS wrapping); the context it returns replaces the
	// per-connection context.
	OnAccept         func(conn net.Conn) context.Context
	// OnConnect is an optional hook called with the wrapped network.Conn
	// after OnAccept; the context it returns is the one passed to handler.
	OnConnect        func(ctx context.Context, conn network.Conn) context.Context
}
    52  
    53  func (t *transport) serve() (err error) {
    54  	network.UnlinkUdsFile(t.network, t.addr) //nolint:errcheck
    55  	t.lock.Lock()
    56  	if t.listenConfig != nil {
    57  		t.ln, err = t.listenConfig.Listen(context.Background(), t.network, t.addr)
    58  	} else {
    59  		t.ln, err = net.Listen(t.network, t.addr)
    60  	}
    61  	t.lock.Unlock()
    62  	if err != nil {
    63  		return err
    64  	}
    65  	hlog.SystemLogger().Infof("HTTP server listening on address=%s", t.ln.Addr().String())
    66  	for {
    67  		ctx := context.Background()
    68  		conn, err := t.ln.Accept()
    69  		var c network.Conn
    70  		if err != nil {
    71  			hlog.SystemLogger().Errorf("Error=%s", err.Error())
    72  			return err
    73  		}
    74  
    75  		if t.OnAccept != nil {
    76  			ctx = t.OnAccept(conn)
    77  		}
    78  
    79  		if t.tls != nil {
    80  			c = newTLSConn(tls.Server(conn, t.tls), t.readBufferSize)
    81  		} else {
    82  			c = newConn(conn, t.readBufferSize)
    83  		}
    84  
    85  		if t.OnConnect != nil {
    86  			ctx = t.OnConnect(ctx, c)
    87  		}
    88  		go t.handler(ctx, c)
    89  	}
    90  }
    91  
    92  func (t *transport) ListenAndServe(onData network.OnData) (err error) {
    93  	t.handler = onData
    94  	return t.serve()
    95  }
    96  
    97  func (t *transport) Close() error {
    98  	ctx, cancel := context.WithTimeout(context.Background(), 0)
    99  	defer cancel()
   100  	return t.Shutdown(ctx)
   101  }
   102  
   103  func (t *transport) Shutdown(ctx context.Context) error {
   104  	defer func() {
   105  		network.UnlinkUdsFile(t.network, t.addr) //nolint:errcheck
   106  	}()
   107  	t.lock.Lock()
   108  	if t.ln != nil {
   109  		_ = t.ln.Close()
   110  	}
   111  	t.lock.Unlock()
   112  	<-ctx.Done()
   113  	return nil
   114  }
   115  
   116  // For transporter switch
   117  func NewTransporter(options *config.Options) network.Transporter {
   118  	return &transport{
   119  		readBufferSize:   options.ReadBufferSize,
   120  		network:          options.Network,
   121  		addr:             options.Addr,
   122  		keepAliveTimeout: options.KeepAliveTimeout,
   123  		readTimeout:      options.ReadTimeout,
   124  		tls:              options.TLS,
   125  		listenConfig:     options.ListenConfig,
   126  		OnAccept:         options.OnAccept,
   127  		OnConnect:        options.OnConnect,
   128  	}
   129  }