go.nanomsg.org/mangos/v3@v3.4.3-0.20240217232803-46464076f1f5/transport/tlstcp/tlstcp.go (about)

     1  // Copyright 2020 The Mangos Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the license at
     6  //
     7  //    http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package tlstcp implements the TLS over TCP transport for mangos.
    16  // To enable it simply import it.
    17  package tlstcp
    18  
    19  import (
    20  	"context"
    21  	"crypto/tls"
    22  	"net"
    23  	"sync"
    24  	"time"
    25  
    26  	"go.nanomsg.org/mangos/v3"
    27  	"go.nanomsg.org/mangos/v3/transport"
    28  )
    29  
    30  // Transport is a transport.Transport for TLS over TCP.
    31  const Transport = tlsTran(0)
    32  
    33  func init() {
    34  	transport.RegisterTransport(Transport)
    35  }
    36  
    37  type dialer struct {
    38  	addr        string
    39  	proto       transport.ProtocolInfo
    40  	hs          transport.Handshaker
    41  	d           *net.Dialer
    42  	config      *tls.Config
    43  	maxRecvSize int
    44  	lock        sync.Mutex
    45  }
    46  
    47  func (d *dialer) Dial() (transport.Pipe, error) {
    48  
    49  	d.lock.Lock()
    50  	config := d.config
    51  	maxRecvSize := d.maxRecvSize
    52  	d.lock.Unlock()
    53  
    54  	conn, err := tls.DialWithDialer(d.d, "tcp", d.addr, config)
    55  	if err != nil {
    56  		return nil, err
    57  	}
    58  	p := transport.NewConnPipe(conn, d.proto)
    59  	p.SetOption(mangos.OptionMaxRecvSize, maxRecvSize)
    60  	p.SetOption(mangos.OptionTLSConnState, conn.ConnectionState())
    61  	d.hs.Start(p)
    62  	return d.hs.Wait()
    63  }
    64  
    65  func (d *dialer) SetOption(n string, v interface{}) error {
    66  	switch n {
    67  	case mangos.OptionMaxRecvSize:
    68  		if b, ok := v.(int); ok {
    69  			d.maxRecvSize = b
    70  			return nil
    71  		}
    72  		return mangos.ErrBadValue
    73  	case mangos.OptionTLSConfig:
    74  		if b, ok := v.(*tls.Config); ok {
    75  			d.config = b
    76  			return nil
    77  		}
    78  		return mangos.ErrBadValue
    79  	case mangos.OptionKeepAliveTime:
    80  		if b, ok := v.(time.Duration); ok {
    81  			d.d.KeepAlive = b
    82  			return nil
    83  		}
    84  		return mangos.ErrBadValue
    85  
    86  	// The following options exist *only* for compatibility reasons.
    87  	// Remove them from new code.
    88  
    89  	// We don't support disabling Nagle anymore.
    90  	case mangos.OptionNoDelay:
    91  		if _, ok := v.(bool); ok {
    92  			return nil
    93  		}
    94  		return mangos.ErrBadValue
    95  	case mangos.OptionKeepAlive:
    96  		if b, ok := v.(bool); ok {
    97  			if b {
    98  				d.d.KeepAlive = 0 // Enable (default time)
    99  			} else {
   100  				d.d.KeepAlive = -1 // Disable
   101  			}
   102  			return nil
   103  		}
   104  		return mangos.ErrBadValue
   105  	}
   106  	return mangos.ErrBadOption
   107  }
   108  
   109  func (d *dialer) GetOption(n string) (interface{}, error) {
   110  	d.lock.Lock()
   111  	defer d.lock.Unlock()
   112  	switch n {
   113  	case mangos.OptionMaxRecvSize:
   114  		return d.maxRecvSize, nil
   115  	case mangos.OptionNoDelay:
   116  		return true, nil // Compatibility only, always true
   117  	case mangos.OptionTLSConfig:
   118  		return d.config, nil
   119  	case mangos.OptionKeepAlive:
   120  		if d.d.KeepAlive >= 0 {
   121  			return true, nil
   122  		}
   123  		return false, nil
   124  	case mangos.OptionKeepAliveTime:
   125  		return d.d.KeepAlive, nil
   126  	}
   127  	return nil, mangos.ErrBadOption
   128  }
   129  
   130  type listener struct {
   131  	addr        string
   132  	bound       net.Addr
   133  	lc          net.ListenConfig
   134  	l           net.Listener
   135  	maxRecvSize int
   136  	proto       transport.ProtocolInfo
   137  	config      *tls.Config
   138  	hs          transport.Handshaker
   139  	closeQ      chan struct{}
   140  	once        sync.Once
   141  	lock        sync.Mutex
   142  }
   143  
   144  func (l *listener) Listen() error {
   145  	var err error
   146  	select {
   147  	case <-l.closeQ:
   148  		return mangos.ErrClosed
   149  	default:
   150  	}
   151  	l.lock.Lock()
   152  	config := l.config
   153  	if config == nil {
   154  		return mangos.ErrTLSNoConfig
   155  	}
   156  	if config.Certificates == nil || len(config.Certificates) == 0 {
   157  		l.lock.Unlock()
   158  		return mangos.ErrTLSNoCert
   159  	}
   160  
   161  	inner, err := l.lc.Listen(context.Background(), "tcp", l.addr)
   162  	if err != nil {
   163  		l.lock.Unlock()
   164  		return err
   165  	}
   166  	l.l = tls.NewListener(inner, config)
   167  	l.bound = l.l.Addr()
   168  	l.lock.Unlock()
   169  
   170  	go func() {
   171  		for {
   172  			conn, err := l.l.Accept()
   173  			if err != nil {
   174  				select {
   175  				case <-l.closeQ:
   176  					return
   177  				default:
   178  					time.Sleep(time.Millisecond)
   179  					continue
   180  				}
   181  			}
   182  
   183  			tc := conn.(*tls.Conn)
   184  			p := transport.NewConnPipe(conn, l.proto)
   185  			l.lock.Lock()
   186  			p.SetOption(mangos.OptionMaxRecvSize, l.maxRecvSize)
   187  			p.SetOption(mangos.OptionTLSConnState, tc.ConnectionState())
   188  			l.lock.Unlock()
   189  
   190  			l.hs.Start(p)
   191  		}
   192  	}()
   193  
   194  	return nil
   195  }
   196  
   197  func (l *listener) Address() string {
   198  	if b := l.bound; b != nil {
   199  		return "tls+tcp://" + b.String()
   200  	}
   201  	return "tls+tcp://" + l.addr
   202  }
   203  
   204  func (l *listener) Accept() (transport.Pipe, error) {
   205  	if l.l == nil {
   206  		return nil, mangos.ErrClosed
   207  	}
   208  	return l.hs.Wait()
   209  }
   210  
   211  func (l *listener) Close() error {
   212  	l.once.Do(func() {
   213  		if l.l != nil {
   214  			_ = l.l.Close()
   215  		}
   216  		l.hs.Close()
   217  		close(l.closeQ)
   218  	})
   219  	return nil
   220  }
   221  
   222  func (l *listener) SetOption(n string, v interface{}) error {
   223  	l.lock.Lock()
   224  	defer l.lock.Unlock()
   225  	switch n {
   226  	case mangos.OptionMaxRecvSize:
   227  		if b, ok := v.(int); ok {
   228  			l.maxRecvSize = b
   229  			return nil
   230  		}
   231  		return mangos.ErrBadValue
   232  	case mangos.OptionTLSConfig:
   233  		if b, ok := v.(*tls.Config); ok {
   234  			l.config = b
   235  			return nil
   236  		}
   237  		return mangos.ErrBadValue
   238  	case mangos.OptionKeepAliveTime:
   239  		if b, ok := v.(time.Duration); ok {
   240  			l.lc.KeepAlive = b
   241  			return nil
   242  		}
   243  		return mangos.ErrBadValue
   244  
   245  		// Legacy stuff follows
   246  	case mangos.OptionNoDelay:
   247  		if _, ok := v.(bool); ok {
   248  			return nil
   249  		}
   250  		return mangos.ErrBadValue
   251  	case mangos.OptionKeepAlive:
   252  		if b, ok := v.(bool); ok {
   253  			if b {
   254  				l.lc.KeepAlive = 0
   255  			} else {
   256  				l.lc.KeepAlive = -1
   257  			}
   258  			return nil
   259  		}
   260  		return mangos.ErrBadValue
   261  
   262  	}
   263  	return mangos.ErrBadOption
   264  }
   265  
   266  func (l *listener) GetOption(n string) (interface{}, error) {
   267  	l.lock.Lock()
   268  	defer l.lock.Unlock()
   269  	switch n {
   270  	case mangos.OptionMaxRecvSize:
   271  		return l.maxRecvSize, nil
   272  	case mangos.OptionTLSConfig:
   273  		return l.config, nil
   274  	case mangos.OptionKeepAliveTime:
   275  		return l.lc.KeepAlive, nil
   276  	case mangos.OptionNoDelay:
   277  		return true, nil
   278  	case mangos.OptionKeepAlive:
   279  		if l.lc.KeepAlive >= 0 {
   280  			return true, nil
   281  		}
   282  		return false, nil
   283  	}
   284  	return nil, mangos.ErrBadOption
   285  }
   286  
   287  type tlsTran int
   288  
   289  func (t tlsTran) Scheme() string {
   290  	return "tls+tcp"
   291  }
   292  
   293  func (t tlsTran) NewDialer(addr string, sock mangos.Socket) (transport.Dialer, error) {
   294  	var err error
   295  
   296  	if addr, err = transport.StripScheme(t, addr); err != nil {
   297  		return nil, err
   298  	}
   299  
   300  	// check to ensure the provided addr resolves correctly.
   301  	if _, err = transport.ResolveTCPAddr(addr); err != nil {
   302  		return nil, err
   303  	}
   304  
   305  	d := &dialer{
   306  		proto: sock.Info(),
   307  		addr:  addr,
   308  		hs:    transport.NewConnHandshaker(),
   309  		d:     &net.Dialer{},
   310  	}
   311  	return d, nil
   312  }
   313  
   314  // NewListener implements the Transport NewListener method.
   315  func (t tlsTran) NewListener(addr string, sock mangos.Socket) (transport.Listener, error) {
   316  	l := &listener{
   317  		proto:  sock.Info(),
   318  		closeQ: make(chan struct{}),
   319  	}
   320  
   321  	var err error
   322  	if addr, err = transport.StripScheme(t, addr); err != nil {
   323  		return nil, err
   324  	}
   325  	l.addr = addr
   326  	l.hs = transport.NewConnHandshaker()
   327  
   328  	return l, nil
   329  }