nanomsg.org/go/mangos/v2@v2.0.9-0.20200203084354-8a092611e461/transport/tlstcp/tlstcp.go (about)

     1  // Copyright 2019 The Mangos Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use file except in compliance with the License.
     5  // You may obtain a copy of the license at
     6  //
     7  //    http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
// Package tlstcp implements the TLS over TCP transport for mangos.
// To enable it, import this package; its init function registers the
// transport, so a blank import is sufficient.
    17  package tlstcp
    18  
    19  import (
    20  	"context"
    21  	"crypto/tls"
    22  	"net"
    23  	"sync"
    24  	"time"
    25  
    26  	"nanomsg.org/go/mangos/v2"
    27  	"nanomsg.org/go/mangos/v2/transport"
    28  )
    29  
    30  // Transport is a transport.Transport for TLS over TCP.
    31  const Transport = tlsTran(0)
    32  
    33  func init() {
    34  	transport.RegisterTransport(Transport)
    35  }
    36  
    37  type dialer struct {
    38  	addr        string
    39  	proto       transport.ProtocolInfo
    40  	hs          transport.Handshaker
    41  	d           *net.Dialer
    42  	config      *tls.Config
    43  	maxRecvSize int
    44  	lock        sync.Mutex
    45  }
    46  
    47  func (d *dialer) Dial() (transport.Pipe, error) {
    48  
    49  	d.lock.Lock()
    50  	config := d.config
    51  	maxRecvSize := d.maxRecvSize
    52  	d.lock.Unlock()
    53  
    54  	conn, err := tls.DialWithDialer(d.d, "tcp", d.addr, config)
    55  	if err != nil {
    56  		return nil, err
    57  	}
    58  	opts := make(map[string]interface{})
    59  	opts[mangos.OptionTLSConnState] = conn.ConnectionState()
    60  	p := transport.NewConnPipe(conn, d.proto, opts)
    61  	p.SetMaxRecvSize(maxRecvSize)
    62  	d.hs.Start(p)
    63  	return d.hs.Wait()
    64  }
    65  
    66  func (d *dialer) SetOption(n string, v interface{}) error {
    67  	switch n {
    68  	case mangos.OptionMaxRecvSize:
    69  		if b, ok := v.(int); ok {
    70  			d.maxRecvSize = b
    71  			return nil
    72  		}
    73  		return mangos.ErrBadValue
    74  	case mangos.OptionTLSConfig:
    75  		if b, ok := v.(*tls.Config); ok {
    76  			d.config = b
    77  			return nil
    78  		}
    79  		return mangos.ErrBadValue
    80  	case mangos.OptionKeepAliveTime:
    81  		if b, ok := v.(time.Duration); ok {
    82  			d.d.KeepAlive = b
    83  			return nil
    84  		}
    85  		return mangos.ErrBadValue
    86  
    87  	// The following options exist *only* for compatibility reasons.
    88  	// Remove them from new code.
    89  
    90  	// We don't support disabling Nagle anymore.
    91  	case mangos.OptionNoDelay:
    92  		if _, ok := v.(bool); ok {
    93  			return nil
    94  		}
    95  		return mangos.ErrBadValue
    96  	case mangos.OptionKeepAlive:
    97  		if b, ok := v.(bool); ok {
    98  			if b {
    99  				d.d.KeepAlive = 0 // Enable (default time)
   100  			} else {
   101  				d.d.KeepAlive = -1 // Disable
   102  			}
   103  			return nil
   104  		}
   105  		return mangos.ErrBadValue
   106  	}
   107  	return mangos.ErrBadOption
   108  }
   109  
   110  func (d *dialer) GetOption(n string) (interface{}, error) {
   111  	d.lock.Lock()
   112  	defer d.lock.Unlock()
   113  	switch n {
   114  	case mangos.OptionMaxRecvSize:
   115  		return d.maxRecvSize, nil
   116  	case mangos.OptionNoDelay:
   117  		return true, nil // Compatibility only, always true
   118  	case mangos.OptionTLSConfig:
   119  		return d.config, nil
   120  	case mangos.OptionKeepAlive:
   121  		if d.d.KeepAlive >= 0 {
   122  			return true, nil
   123  		}
   124  		return false, nil
   125  	case mangos.OptionKeepAliveTime:
   126  		return d.d.KeepAlive, nil
   127  	}
   128  	return nil, mangos.ErrBadOption
   129  }
   130  
   131  type listener struct {
   132  	addr        string
   133  	bound       net.Addr
   134  	lc          net.ListenConfig
   135  	l           net.Listener
   136  	maxRecvSize int
   137  	proto       transport.ProtocolInfo
   138  	config      *tls.Config
   139  	hs          transport.Handshaker
   140  	closeQ      chan struct{}
   141  	once        sync.Once
   142  	lock        sync.Mutex
   143  }
   144  
   145  func (l *listener) Listen() error {
   146  	var err error
   147  	select {
   148  	case <-l.closeQ:
   149  		return mangos.ErrClosed
   150  	default:
   151  	}
   152  	l.lock.Lock()
   153  	config := l.config
   154  	if config == nil {
   155  		return mangos.ErrTLSNoConfig
   156  	}
   157  	if config.Certificates == nil || len(config.Certificates) == 0 {
   158  		l.lock.Unlock()
   159  		return mangos.ErrTLSNoCert
   160  	}
   161  
   162  	inner, err := l.lc.Listen(context.Background(), "tcp", l.addr)
   163  	if err != nil {
   164  		l.lock.Unlock()
   165  		return err
   166  	}
   167  	l.l = tls.NewListener(inner, config)
   168  	l.bound = l.l.Addr()
   169  	l.lock.Unlock()
   170  
   171  	go func() {
   172  		for {
   173  			conn, err := l.l.Accept()
   174  			if err != nil {
   175  				select {
   176  				case <-l.closeQ:
   177  					return
   178  				default:
   179  					time.Sleep(time.Millisecond)
   180  					continue
   181  				}
   182  			}
   183  
   184  			tc := conn.(*tls.Conn)
   185  			opts := make(map[string]interface{})
   186  			l.lock.Lock()
   187  			maxRecvSize := l.maxRecvSize
   188  			l.lock.Unlock()
   189  			opts[mangos.OptionTLSConnState] = tc.ConnectionState()
   190  			p := transport.NewConnPipe(conn, l.proto, opts)
   191  			p.SetMaxRecvSize(maxRecvSize)
   192  
   193  			l.hs.Start(p)
   194  		}
   195  	}()
   196  
   197  	return nil
   198  }
   199  
   200  func (l *listener) Address() string {
   201  	if b := l.bound; b != nil {
   202  		return "tls+tcp://" + b.String()
   203  	}
   204  	return "tls+tcp://" + l.addr
   205  }
   206  
   207  func (l *listener) Accept() (transport.Pipe, error) {
   208  	if l.l == nil {
   209  		return nil, mangos.ErrClosed
   210  	}
   211  	return l.hs.Wait()
   212  }
   213  
   214  func (l *listener) Close() error {
   215  	l.once.Do(func() {
   216  		if l.l != nil {
   217  			_ = l.l.Close()
   218  		}
   219  		l.hs.Close()
   220  		close(l.closeQ)
   221  	})
   222  	return nil
   223  }
   224  
   225  func (l *listener) SetOption(n string, v interface{}) error {
   226  	l.lock.Lock()
   227  	defer l.lock.Unlock()
   228  	switch n {
   229  	case mangos.OptionMaxRecvSize:
   230  		if b, ok := v.(int); ok {
   231  			l.maxRecvSize = b
   232  			return nil
   233  		}
   234  		return mangos.ErrBadValue
   235  	case mangos.OptionTLSConfig:
   236  		if b, ok := v.(*tls.Config); ok {
   237  			l.config = b
   238  			return nil
   239  		}
   240  		return mangos.ErrBadValue
   241  	case mangos.OptionKeepAliveTime:
   242  		if b, ok := v.(time.Duration); ok {
   243  			l.lc.KeepAlive = b
   244  			return nil
   245  		}
   246  		return mangos.ErrBadValue
   247  
   248  		// Legacy stuff follows
   249  	case mangos.OptionNoDelay:
   250  		if _, ok := v.(bool); ok {
   251  			return nil
   252  		}
   253  		return mangos.ErrBadValue
   254  	case mangos.OptionKeepAlive:
   255  		if b, ok := v.(bool); ok {
   256  			if b {
   257  				l.lc.KeepAlive = 0
   258  			} else {
   259  				l.lc.KeepAlive = -1
   260  			}
   261  			return nil
   262  		}
   263  		return mangos.ErrBadValue
   264  
   265  	}
   266  	return mangos.ErrBadOption
   267  }
   268  
   269  func (l *listener) GetOption(n string) (interface{}, error) {
   270  	l.lock.Lock()
   271  	defer l.lock.Unlock()
   272  	switch n {
   273  	case mangos.OptionMaxRecvSize:
   274  		return l.maxRecvSize, nil
   275  	case mangos.OptionTLSConfig:
   276  		return l.config, nil
   277  	case mangos.OptionKeepAliveTime:
   278  		return l.lc.KeepAlive, nil
   279  	case mangos.OptionNoDelay:
   280  		return true, nil
   281  	case mangos.OptionKeepAlive:
   282  		if l.lc.KeepAlive >= 0 {
   283  			return true, nil
   284  		}
   285  		return false, nil
   286  	}
   287  	return nil, mangos.ErrBadOption
   288  }
   289  
   290  type tlsTran int
   291  
   292  func (t tlsTran) Scheme() string {
   293  	return "tls+tcp"
   294  }
   295  
   296  func (t tlsTran) NewDialer(addr string, sock mangos.Socket) (transport.Dialer, error) {
   297  	var err error
   298  
   299  	if addr, err = transport.StripScheme(t, addr); err != nil {
   300  		return nil, err
   301  	}
   302  
   303  	// check to ensure the provided addr resolves correctly.
   304  	if _, err = transport.ResolveTCPAddr(addr); err != nil {
   305  		return nil, err
   306  	}
   307  
   308  	d := &dialer{
   309  		proto: sock.Info(),
   310  		addr:  addr,
   311  		hs:    transport.NewConnHandshaker(),
   312  		d:     &net.Dialer{},
   313  	}
   314  	return d, nil
   315  }
   316  
   317  // NewListener implements the Transport NewListener method.
   318  func (t tlsTran) NewListener(addr string, sock mangos.Socket) (transport.Listener, error) {
   319  	l := &listener{
   320  		proto:  sock.Info(),
   321  		closeQ: make(chan struct{}),
   322  	}
   323  
   324  	var err error
   325  	if addr, err = transport.StripScheme(t, addr); err != nil {
   326  		return nil, err
   327  	}
   328  	l.addr = addr
   329  	l.hs = transport.NewConnHandshaker()
   330  
   331  	return l, nil
   332  }