github.com/psiphon-labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/tun/tun.go (about)

     1  /*
     2   * Copyright (c) 2017, Psiphon Inc.
     3   * All rights reserved.
     4   *
     5   * This program is free software: you can redistribute it and/or modify
     6   * it under the terms of the GNU General Public License as published by
     7   * the Free Software Foundation, either version 3 of the License, or
     8   * (at your option) any later version.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  // Copyright 2009 The Go Authors. All rights reserved.
    21  // Use of this source code is governed by a BSD-style
    22  // license that can be found in the LICENSE file.
    23  
    24  /*
    25  Package tun is an IP packet tunnel server and client. It supports tunneling
    26  both IPv4 and IPv6.
    27  
    28  	.........................................................       .-,(  ),-.
    29  	. [server]                                     .-----.  .    .-(          )-.
    30  	.                                              | NIC |<---->(    Internet    )
    31  	. .......................................      '-----'  .    '-(          ).-'
    32  	. . [packet tunnel daemon]              .         ^     .        '-.( ).-'
    33  	. .                                     .         |     .
    34  	. . ...........................         .         |     .
    35  	. . . [session]               .         .        NAT    .
    36  	. . .                         .         .         |     .
    37  	. . .                         .         .         v     .
    38  	. . .                         .         .       .---.   .
    39  	. . .                         .         .       | t |   .
    40  	. . .                         .         .       | u |   .
    41  	. . .                 .---.   .  .---.  .       | n |   .
    42  	. . .                 | q |   .  | d |  .       |   |   .
    43  	. . .                 | u |   .  | e |  .       | d |   .
    44  	. . .          .------| e |<-----| m |<---------| e |   .
    45  	. . .          |      | u |   .  | u |  .       | v |   .
    46  	. . .          |      | e |   .  | x |  .       | i |   .
    47  	. . .       rewrite   '---'   .  '---'  .       | c |   .
    48  	. . .          |              .         .       | e |   .
    49  	. . .          v              .         .       '---'   .
    50  	. . .     .---------.         .         .         ^     .
    51  	. . .     | channel |--rewrite--------------------'     .
    52  	. . .     '---------'         .         .               .
    53  	. . ...........^...............         .               .
    54  	. .............|.........................               .
    55  	...............|.........................................
    56  	               |
    57  	               | (typically via Internet)
    58  	               |
    59  	...............|.................
    60  	. [client]     |                .
    61  	.              |                .
    62  	. .............|............... .
    63  	. .            v              . .
    64  	. .       .---------.         . .
    65  	. .       | channel |         . .
    66  	. .       '---------'         . .
    67  	. .            ^              . .
    68  	. .............|............... .
    69  	.              v                .
    70  	.        .------------.         .
    71  	.        | tun device |         .
    72  	.        '------------'         .
    73  	.................................
    74  
    75  The client relays IP packets between a local tun device and a channel, which
    76  is a transport to the server. In Psiphon, the channel will be an SSH channel
    77  within an SSH connection to a Psiphon server.
    78  
    79  The server relays packets between each client and its own tun device. The
    80  server tun device is NATed to the Internet via an external network interface.
    81  In this way, client traffic is tunneled and will egress from the server host.
    82  
    83  Similar to a typical VPN, IP addresses are assigned to each client. Unlike
    84  a typical VPN, the assignment is not transmitted to the client. Instead, the
    85  server transparently rewrites the source addresses of client packets to
    86  the assigned IP address. The server also rewrites the destination address of
    87  certain DNS packets. The purpose of this is to allow clients to reconnect
    88  to different servers without having to tear down or change their local
    89  network configuration. Clients may configure their local tun device with an
    90  arbitrary IP address and a static DNS resolver address.
    91  
    92  The server uses the 24-bit 10.0.0.0/8 IPv4 private address space to maximize
    93  the number of addresses available, due to Psiphon client churn and minimum
    94  address lease time constraints. For IPv6, a 24-bit unique local space is used.
    95  When a client is allocated addresses, a unique, unused 24-bit "index" is
    96  reserved/leased. This index maps to and from IPv4 and IPv6 private addresses.
    97  The server multiplexes all client packets into a single tun device. When a
    98  packet is read, the destination address is used to map the packet back to the
    99  correct index, which maps back to the client.
   100  
   101  The server maintains client "sessions". A session maintains client IP
   102  address state and effectively holds the lease on assigned addresses. If a
   103  client is disconnected and quickly reconnects, it will resume its previous
   104  session, retaining its IP address and network connection states. Idle
   105  sessions with no client connection will eventually expire.
   106  
   107  Packet count and bytes transferred metrics are logged for each client session.
   108  
   109  The server integrates with and enforces Psiphon traffic rules and logging
   110  facilities. The server parses and validates packets. Client-to-client packets
   111  are not permitted. Only global unicast packets are permitted. Only TCP and UDP
   112  packets are permitted. The client also filters out, before sending, packets
   113  that the server won't route.
   114  
   115  Certain aspects of packet tunneling are outside the scope of this package;
e.g., the Psiphon client and server are responsible for establishing an SSH
   117  channel and negotiating the correct MTU and DNS settings. The Psiphon
   118  server will call Server.ClientConnected when a client connects and establishes
   119  a packet tunnel channel; and Server.ClientDisconnected when the client closes
   120  the channel and/or disconnects.
   121  */
   122  package tun
   123  
   124  import (
   125  	"context"
   126  	"encoding/binary"
   127  	"fmt"
   128  	"io"
   129  	"math/rand"
   130  	"net"
   131  	"sync"
   132  	"sync/atomic"
   133  	"time"
   134  	"unsafe"
   135  
   136  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
   137  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
   138  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/monotime"
   139  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
   140  )
   141  
const (
	// DEFAULT_MTU is the default MTU value, in bytes.
	// NOTE(review): presumably applied by getMTU when a configured MTU is 0;
	// getMTU is not visible here -- confirm.
	DEFAULT_MTU                          = 1500
	// DEFAULT_DOWNSTREAM_PACKET_QUEUE_SIZE is the downstream packet queue
	// size used by resumeSession when ServerConfig.DownstreamPacketQueueSize
	// is not greater than 0.
	DEFAULT_DOWNSTREAM_PACKET_QUEUE_SIZE = 32768 * 16
	// DEFAULT_UPSTREAM_PACKET_QUEUE_SIZE is the default upstream packet queue
	// size.
	// NOTE(review): its consumer is outside this part of the file.
	DEFAULT_UPSTREAM_PACKET_QUEUE_SIZE   = 32768
	// DEFAULT_IDLE_SESSION_EXPIRY_SECONDS is the session idle expiry used by
	// sessionIdleExpiry when ServerConfig.SessionIdleExpirySeconds is not
	// greater than 2.
	DEFAULT_IDLE_SESSION_EXPIRY_SECONDS  = 300
	// ORPHAN_METRICS_CHECKPOINTER_PERIOD is the ticker period of
	// runOrphanMetricsCheckpointer.
	ORPHAN_METRICS_CHECKPOINTER_PERIOD   = 30 * time.Minute
	// FLOW_IDLE_EXPIRY is an idle period for flows.
	// NOTE(review): presumably the time after which an inactive tracked flow
	// is expired; its use is outside this part of the file -- confirm.
	FLOW_IDLE_EXPIRY                     = 60 * time.Second
)
   150  
// ServerConfig specifies the configuration of a packet tunnel server.
type ServerConfig struct {

	// Logger is used for logging events and metrics.
	Logger common.Logger

	// SudoNetworkConfigCommands specifies whether to use "sudo"
	// when executing network configuration commands. This is required
	// when the packet tunnel server is not run as root and when
	// process capabilities are not available (only Linux kernel 4.3+
	// has the required capabilities support). The host sudoers file
	// must be configured to allow the tunnel server process user to
	// execute the commands invoked in configureServerInterface; see
	// the implementation for the appropriate platform.
	SudoNetworkConfigCommands bool

	// AllowNoIPv6NetworkConfiguration indicates that failures while
	// configuring tun interfaces and routing for IPv6 are to be
	// logged as warnings only. This option is intended to support
	// test cases on hosts without IPv6 and is not for production use;
	// the packet tunnel server will still accept IPv6 packets and
	// relay them to the tun device.
	// AllowNoIPv6NetworkConfiguration may not be supported on all
	// platforms.
	AllowNoIPv6NetworkConfiguration bool

	// EgressInterface is the interface to which client traffic is
	// masqueraded/NATed. For example, "eth0". If blank, a platform-
	// appropriate default is used.
	EgressInterface string

	// GetDNSResolverIPv4Addresses is a function which returns the
	// DNS resolvers to use as transparent DNS rewrite targets for
	// IPv4 DNS traffic.
	//
	// GetDNSResolverIPv4Addresses is invoked for each new client
	// session and the list of resolvers is stored with the session.
	// This is a compromise between checking current resolvers for
	// each packet (too expensive) and simply passing in a static
	// list (won't pick up resolver changes). As implemented, only
	// new client sessions will pick up resolver changes.
	//
	// The function is called without a nil check when a session is
	// created, so it must be set. Each returned address is stored in
	// its 4-byte form (net.IP.To4), so each is expected to be an
	// IPv4 address.
	//
	// Transparent DNS rewriting occurs when the client uses the
	// specific, target transparent DNS addresses specified by
	// GetTransparentDNSResolverIPv4/6Address.
	//
	// For outbound DNS packets with a target resolver IP address,
	// a random resolver is selected and used for the rewrite.
	// For inbound packets, _any_ resolver in the list is rewritten
	// back to the target resolver IP address. As a side-effect,
	// responses to client DNS packets originally destined for a
	// resolver in GetDNSResolverIPv4Addresses will be lost.
	GetDNSResolverIPv4Addresses func() []net.IP

	// GetDNSResolverIPv6Addresses is a function which returns the
	// DNS resolvers to use as transparent DNS rewrite targets for
	// IPv6 DNS traffic. It functions like GetDNSResolverIPv4Addresses
	// and, likewise, is called without a nil check for each new
	// session.
	GetDNSResolverIPv6Addresses func() []net.IP

	// EnableDNSFlowTracking specifies whether to apply flow tracking to DNS
	// flows, as required for DNS quality metrics. Typically there are many
	// short-lived DNS flows to track and each tracked flow adds some overhead,
	// so this defaults to off.
	EnableDNSFlowTracking bool

	// DownstreamPacketQueueSize specifies the size of the downstream
	// packet queue. The packet tunnel server multiplexes all client
	// packets through a single tun device, so when a packet is read,
	// it must be queued or dropped if it cannot be immediately routed
	// to the appropriate client. Note that the TCP and SSH windows
	// for the underlying channel transport will impact transfer rate
	// and queuing.
	// When DownstreamPacketQueueSize is 0 (or negative), the default,
	// DEFAULT_DOWNSTREAM_PACKET_QUEUE_SIZE, a value tuned for Psiphon,
	// is used.
	DownstreamPacketQueueSize int

	// MTU specifies the maximum transmission unit for the packet
	// tunnel. Clients must be configured with the same MTU. The
	// server's tun device will be set to this MTU value and is
	// assumed not to change for the duration of the server.
	// When MTU is 0, a default value is used.
	MTU int

	// SessionIdleExpirySeconds specifies how long to retain client
	// sessions which have no client attached. Sessions are retained
	// across client connections so reconnecting clients can resume
	// a previous session. Resuming avoids leasing new IP addresses
	// for reconnection, and also retains NAT state for active
	// tunneled connections.
	//
	// SessionIdleExpirySeconds is also, effectively, the lease
	// time for assigned IP addresses.
	//
	// Values of 2 or less, including the zero value, select the
	// default, DEFAULT_IDLE_SESSION_EXPIRY_SECONDS.
	SessionIdleExpirySeconds int

	// AllowBogons disables bogon checks. This should be used only
	// for testing.
	AllowBogons bool
}
   249  
// Server is a packet tunnel server. A packet tunnel server
// maintains client sessions, relays packets through client
// channels, and multiplexes packets through a single tun
// device. The server assigns IP addresses to clients, performs
// IP address and transparent DNS rewriting, and enforces
// traffic rules.
type Server struct {
	// config is the configuration passed to NewServer.
	config              *ServerConfig
	// device is the server device returned by NewServerDevice; it is
	// closed in Stop.
	device              *Device
	// indexToSession maps a session's int32 index to its *session.
	indexToSession      sync.Map
	// sessionIDToIndex maps a session ID string to the session's int32
	// index.
	sessionIDToIndex    sync.Map
	// connectedInProgress counts ClientConnected calls currently
	// executing; Stop waits on it before interrupting sessions.
	connectedInProgress *sync.WaitGroup
	// workers tracks the goroutines launched by Start; Stop waits on it.
	workers             *sync.WaitGroup
	// runContext is cancelled, via stopRunning, when Stop is called.
	runContext          context.Context
	// stopRunning cancels runContext.
	stopRunning         context.CancelFunc
	// orphanMetrics accumulates packet metrics that are not associated
	// with any session; see runOrphanMetricsCheckpointer.
	orphanMetrics       *packetMetrics
}
   267  
   268  // NewServer initializes a server.
   269  func NewServer(config *ServerConfig) (*Server, error) {
   270  
   271  	device, err := NewServerDevice(config)
   272  	if err != nil {
   273  		return nil, errors.Trace(err)
   274  	}
   275  
   276  	runContext, stopRunning := context.WithCancel(context.Background())
   277  
   278  	return &Server{
   279  		config:              config,
   280  		device:              device,
   281  		connectedInProgress: new(sync.WaitGroup),
   282  		workers:             new(sync.WaitGroup),
   283  		runContext:          runContext,
   284  		stopRunning:         stopRunning,
   285  		orphanMetrics:       new(packetMetrics),
   286  	}, nil
   287  }
   288  
   289  // Start starts a server and returns with it running.
   290  func (server *Server) Start() {
   291  
   292  	server.config.Logger.WithTrace().Info("starting")
   293  
   294  	server.workers.Add(1)
   295  	go server.runSessionReaper()
   296  
   297  	server.workers.Add(1)
   298  	go server.runOrphanMetricsCheckpointer()
   299  
   300  	server.workers.Add(1)
   301  	go server.runDeviceDownstream()
   302  }
   303  
   304  // Stop halts a running server.
   305  func (server *Server) Stop() {
   306  
   307  	server.config.Logger.WithTrace().Info("stopping")
   308  
   309  	server.stopRunning()
   310  
   311  	// Interrupt blocked device read/writes.
   312  	server.device.Close()
   313  
   314  	// Wait for any in-progress ClientConnected calls to complete.
   315  	server.connectedInProgress.Wait()
   316  
   317  	// After this point, no further clients will be added: all
   318  	// in-progress ClientConnected calls have finished; and any
   319  	// later ClientConnected calls won't get past their
   320  	// server.runContext.Done() checks.
   321  
   322  	// Close all clients. Client workers will be joined
   323  	// by the following server.workers.Wait().
   324  	server.indexToSession.Range(func(_, value interface{}) bool {
   325  		session := value.(*session)
   326  		server.interruptSession(session)
   327  		return true
   328  	})
   329  
   330  	server.workers.Wait()
   331  
   332  	server.config.Logger.WithTrace().Info("stopped")
   333  }
   334  
// AllowedPortChecker is a function which returns true when it is
// permitted to relay packets to the specified upstream IP address
// and/or port. Checkers passed to Server.ClientConnected must be
// efficient and safe for concurrent calls.
type AllowedPortChecker func(upstreamIPAddress net.IP, port int) bool
   339  
// AllowedDomainChecker is a function which returns true when it is
// permitted to resolve the specified domain name. Checkers passed to
// Server.ClientConnected must be efficient and safe for concurrent
// calls.
type AllowedDomainChecker func(string) bool
   343  
// FlowActivityUpdater defines an interface for receiving updates for
// flow activity. The values passed to UpdateProgress are deltas: the
// bytes transferred in each direction and the flow duration, in
// nanoseconds, since the previous UpdateProgress call.
type FlowActivityUpdater interface {
	UpdateProgress(downstreamBytes, upstreamBytes, durationNanoseconds int64)
}
   350  
// FlowActivityUpdaterMaker is a function which returns a list of
// appropriate updaters for a new flow to the specified upstream
// hostname and IP address. upstreamHostname is "" when the hostname
// is not known.
// The flow is TCP when isTCP is true, and UDP otherwise.
type FlowActivityUpdaterMaker func(
	isTCP bool, upstreamHostname string, upstreamIPAddress net.IP) []FlowActivityUpdater
   357  
// MetricsUpdater is a function which receives a checkpoint summary
// of application bytes transferred through a packet tunnel, broken
// down by protocol (TCP, UDP) and direction (down, up).
type MetricsUpdater func(
	TCPApplicationBytesDown, TCPApplicationBytesUp,
	UDPApplicationBytesDown, UDPApplicationBytesUp int64)
   363  
// DNSQualityReporter is a function which receives a DNS quality report:
// whether a DNS request received a response, the elapsed time, and the
// resolver used.
type DNSQualityReporter func(
	receivedResponse bool, requestDuration time.Duration, resolverIP net.IP)
   369  
   370  // ClientConnected handles new client connections, creating or resuming
   371  // a session and returns with client packet handlers running.
   372  //
   373  // sessionID is used to identify sessions for resumption.
   374  //
   375  // transport provides the channel for relaying packets to and from
   376  // the client.
   377  //
   378  // checkAllowedTCPPortFunc/checkAllowedUDPPortFunc/checkAllowedDomainFunc
   379  // are callbacks used to enforce traffic rules. For each TCP/UDP flow, the
   380  // corresponding AllowedPort function is called to check if traffic to the
   381  // packet's port is permitted. For upstream DNS query packets,
   382  // checkAllowedDomainFunc is called to check if domain resolution is
   383  // permitted. These callbacks must be efficient and safe for concurrent
   384  // calls.
   385  //
   386  // flowActivityUpdaterMaker is a callback invoked for each new packet
   387  // flow; it may create updaters to track flow activity.
   388  //
   389  // metricsUpdater is a callback invoked at metrics checkpoints (usually
   390  // when the client disconnects) with a summary of application bytes
   391  // transferred.
   392  //
   393  // It is safe to make concurrent calls to ClientConnected for distinct
   394  // session IDs. The caller is responsible for serializing calls with the
   395  // same session ID. Further, the caller must ensure, in the case of a client
   396  // transport reconnect when an existing transport has not yet disconnected,
   397  // that ClientDisconnected is called first -- so it doesn't undo the new
   398  // ClientConnected. (psiphond meets these constraints by closing any
   399  // existing SSH client with duplicate session ID early in the lifecycle of
   400  // a new SSH client connection.)
   401  func (server *Server) ClientConnected(
   402  	sessionID string,
   403  	transport io.ReadWriteCloser,
   404  	checkAllowedTCPPortFunc, checkAllowedUDPPortFunc AllowedPortChecker,
   405  	checkAllowedDomainFunc AllowedDomainChecker,
   406  	flowActivityUpdaterMaker FlowActivityUpdaterMaker,
   407  	metricsUpdater MetricsUpdater,
   408  	dnsQualityReporter DNSQualityReporter) error {
   409  
   410  	// It's unusual to call both sync.WaitGroup.Add() _and_ Done() in the same
   411  	// goroutine. There's no other place to call Add() since ClientConnected is
   412  	// an API entrypoint. And Done() works because the invariant enforced by
   413  	// connectedInProgress.Wait() is not that no ClientConnected calls are in
   414  	// progress, but that no such calls are in progress past the
   415  	// server.runContext.Done() check.
   416  
   417  	// TODO: will this violate https://golang.org/pkg/sync/#WaitGroup.Add:
   418  	// "calls with a positive delta that occur when the counter is zero must happen before a Wait"?
   419  
   420  	server.connectedInProgress.Add(1)
   421  	defer server.connectedInProgress.Done()
   422  
   423  	select {
   424  	case <-server.runContext.Done():
   425  		return errors.TraceNew("server stopping")
   426  	default:
   427  	}
   428  
   429  	server.config.Logger.WithTraceFields(
   430  		common.LogFields{"sessionID": sessionID}).Debug("client connected")
   431  
   432  	MTU := getMTU(server.config.MTU)
   433  
   434  	clientSession := server.getSession(sessionID)
   435  
   436  	if clientSession != nil {
   437  
   438  		// Call interruptSession to ensure session is in the
   439  		// expected idle state.
   440  
   441  		server.interruptSession(clientSession)
   442  
   443  		// Note: we don't check the session expiry; whether it has
   444  		// already expired and not yet been reaped; or is about
   445  		// to expire very shortly. It could happen that the reaper
   446  		// will kill this session between now and when the expiry
   447  		// is reset in the following resumeSession call. In this
   448  		// unlikely case, the packet tunnel client should reconnect.
   449  
   450  	} else {
   451  
   452  		// Store IPv4 resolver addresses in 4-byte representation
   453  		// for use in rewritting.
   454  		resolvers := server.config.GetDNSResolverIPv4Addresses()
   455  		DNSResolverIPv4Addresses := make([]net.IP, len(resolvers))
   456  		for i, resolver := range resolvers {
   457  			// Assumes To4 is non-nil
   458  			DNSResolverIPv4Addresses[i] = resolver.To4()
   459  		}
   460  
   461  		clientSession = &session{
   462  			allowBogons:              server.config.AllowBogons,
   463  			lastActivity:             int64(monotime.Now()),
   464  			sessionID:                sessionID,
   465  			metrics:                  new(packetMetrics),
   466  			enableDNSFlowTracking:    server.config.EnableDNSFlowTracking,
   467  			DNSResolverIPv4Addresses: append([]net.IP(nil), DNSResolverIPv4Addresses...),
   468  			DNSResolverIPv6Addresses: append([]net.IP(nil), server.config.GetDNSResolverIPv6Addresses()...),
   469  			workers:                  new(sync.WaitGroup),
   470  		}
   471  
   472  		// One-time, for this session, random resolver selection for TCP transparent
   473  		// DNS forwarding. See comment in processPacket.
   474  		if len(clientSession.DNSResolverIPv4Addresses) > 0 {
   475  			clientSession.TCPDNSResolverIPv4Index = prng.Intn(len(clientSession.DNSResolverIPv4Addresses))
   476  		}
   477  		if len(clientSession.DNSResolverIPv6Addresses) > 0 {
   478  			clientSession.TCPDNSResolverIPv6Index = prng.Intn(len(clientSession.DNSResolverIPv6Addresses))
   479  		}
   480  
   481  		// allocateIndex initializes session.index, session.assignedIPv4Address,
   482  		// and session.assignedIPv6Address; and updates server.indexToSession and
   483  		// server.sessionIDToIndex.
   484  
   485  		err := server.allocateIndex(clientSession)
   486  		if err != nil {
   487  			return errors.Trace(err)
   488  		}
   489  	}
   490  
   491  	// Note: it's possible that a client disconnects (or reconnects before a
   492  	// disconnect is detected) and interruptSession is called between
   493  	// allocateIndex and resumeSession calls here, so interruptSession and
   494  	// related code must not assume resumeSession has been called.
   495  
   496  	server.resumeSession(
   497  		clientSession,
   498  		NewChannel(transport, MTU),
   499  		checkAllowedTCPPortFunc,
   500  		checkAllowedUDPPortFunc,
   501  		checkAllowedDomainFunc,
   502  		flowActivityUpdaterMaker,
   503  		metricsUpdater,
   504  		dnsQualityReporter)
   505  
   506  	return nil
   507  }
   508  
   509  // ClientDisconnected handles clients disconnecting. Packet handlers
   510  // are halted, but the client session is left intact to reserve the
   511  // assigned IP addresses and retain network state in case the client
   512  // soon reconnects.
   513  func (server *Server) ClientDisconnected(sessionID string) {
   514  
   515  	session := server.getSession(sessionID)
   516  	if session != nil {
   517  
   518  		server.config.Logger.WithTraceFields(
   519  			common.LogFields{"sessionID": sessionID}).Debug("client disconnected")
   520  
   521  		server.interruptSession(session)
   522  	}
   523  }
   524  
   525  func (server *Server) getSession(sessionID string) *session {
   526  
   527  	if index, ok := server.sessionIDToIndex.Load(sessionID); ok {
   528  		s, ok := server.indexToSession.Load(index.(int32))
   529  		if ok {
   530  			return s.(*session)
   531  		}
   532  		server.config.Logger.WithTrace().Warning("unexpected missing session")
   533  	}
   534  	return nil
   535  }
   536  
   537  func (server *Server) resumeSession(
   538  	session *session,
   539  	channel *Channel,
   540  	checkAllowedTCPPortFunc, checkAllowedUDPPortFunc AllowedPortChecker,
   541  	checkAllowedDomainFunc AllowedDomainChecker,
   542  	flowActivityUpdaterMaker FlowActivityUpdaterMaker,
   543  	metricsUpdater MetricsUpdater,
   544  	dnsQualityReporter DNSQualityReporter) {
   545  
   546  	session.mutex.Lock()
   547  	defer session.mutex.Unlock()
   548  
   549  	// Performance/concurrency note: the downstream packet queue
   550  	// and various packet event callbacks may be accessed while
   551  	// the session is idle, via the runDeviceDownstream goroutine,
   552  	// which runs concurrent to resumeSession/interruptSession calls.
   553  	// Consequently, all accesses to these fields must be
   554  	// synchronized.
   555  	//
   556  	// Benchmarking indicates the atomic.LoadPointer mechanism
   557  	// outperforms a mutex; approx. 2 ns/op vs. 20 ns/op in the case
   558  	// of getCheckAllowedTCPPortFunc. Since these accesses occur
   559  	// multiple times per packet, atomic.LoadPointer is used and so
   560  	// each of these fields is an unsafe.Pointer in the session
   561  	// struct.
   562  
   563  	// Begin buffering downstream packets.
   564  
   565  	downstreamPacketQueueSize := DEFAULT_DOWNSTREAM_PACKET_QUEUE_SIZE
   566  	if server.config.DownstreamPacketQueueSize > 0 {
   567  		downstreamPacketQueueSize = server.config.DownstreamPacketQueueSize
   568  	}
   569  	downstreamPackets := NewPacketQueue(downstreamPacketQueueSize)
   570  
   571  	session.setDownstreamPackets(downstreamPackets)
   572  
   573  	// Set new access control, flow monitoring, and metrics
   574  	// callbacks; all associated with the new client connection.
   575  
   576  	// IMPORTANT: any new callbacks or references to the outer client added
   577  	// here must be cleared in interruptSession to ensure that a paused
   578  	// session does not retain references to old client connection objects
   579  	// after the client disconnects.
   580  
   581  	session.setCheckAllowedTCPPortFunc(&checkAllowedTCPPortFunc)
   582  
   583  	session.setCheckAllowedUDPPortFunc(&checkAllowedUDPPortFunc)
   584  
   585  	session.setCheckAllowedDomainFunc(&checkAllowedDomainFunc)
   586  
   587  	session.setFlowActivityUpdaterMaker(&flowActivityUpdaterMaker)
   588  
   589  	session.setMetricsUpdater(&metricsUpdater)
   590  
   591  	session.setDNSQualityReporter(&dnsQualityReporter)
   592  
   593  	session.channel = channel
   594  
   595  	// Parent context is not server.runContext so that session workers
   596  	// need only check session.stopRunning to act on shutdown events.
   597  	session.runContext, session.stopRunning = context.WithCancel(context.Background())
   598  
   599  	// When a session is interrupted, all goroutines in session.workers
   600  	// are joined. When the server is stopped, all goroutines in
   601  	// server.workers are joined. So, in both cases we synchronously
   602  	// stop all workers associated with this session.
   603  
   604  	session.workers.Add(1)
   605  	go server.runClientUpstream(session)
   606  
   607  	session.workers.Add(1)
   608  	go server.runClientDownstream(session)
   609  
   610  	session.touch()
   611  }
   612  
   613  func (server *Server) interruptSession(session *session) {
   614  
   615  	session.mutex.Lock()
   616  	defer session.mutex.Unlock()
   617  
   618  	wasRunning := (session.channel != nil)
   619  
   620  	if session.stopRunning != nil {
   621  		session.stopRunning()
   622  	}
   623  
   624  	if session.channel != nil {
   625  		// Interrupt blocked channel read/writes.
   626  		session.channel.Close()
   627  	}
   628  
   629  	session.workers.Wait()
   630  
   631  	if session.channel != nil {
   632  		// Don't hold a reference to channel, allowing both it and
   633  		// its conn to be garbage collected.
   634  		// Setting channel to nil must happen after workers.Wait()
   635  		// to ensure no goroutine remains which may access
   636  		// session.channel.
   637  		session.channel = nil
   638  	}
   639  
   640  	metricsUpdater := session.getMetricsUpdater()
   641  
   642  	// interruptSession may be called for idle sessions, to ensure
   643  	// the session is in an expected state: in ClientConnected,
   644  	// and in server.Stop(); don't log in those cases.
   645  	if wasRunning {
   646  		session.metrics.checkpoint(
   647  			server.config.Logger,
   648  			metricsUpdater,
   649  			"server_packet_metrics",
   650  			packetMetricsAll)
   651  	}
   652  
   653  	// Release the downstream packet buffer, so the associated
   654  	// memory is not consumed while no client is connected.
   655  	//
   656  	// Since runDeviceDownstream continues to run and will access
   657  	// session.downstreamPackets, an atomic pointer is used to
   658  	// synchronize access.
   659  	session.setDownstreamPackets(nil)
   660  
   661  	session.setCheckAllowedTCPPortFunc(nil)
   662  
   663  	session.setCheckAllowedUDPPortFunc(nil)
   664  
   665  	session.setCheckAllowedDomainFunc(nil)
   666  
   667  	session.setFlowActivityUpdaterMaker(nil)
   668  
   669  	session.setMetricsUpdater(nil)
   670  
   671  	session.setDNSQualityReporter(nil)
   672  }
   673  
   674  func (server *Server) runSessionReaper() {
   675  
   676  	defer server.workers.Done()
   677  
   678  	// Periodically iterate over all sessions and discard expired
   679  	// sessions. This action, removing the index from server.indexToSession,
   680  	// releases the IP addresses assigned  to the session.
   681  
   682  	// TODO: As-is, this will discard sessions for live SSH tunnels,
   683  	// as long as the SSH channel for such a session has been idle for
   684  	// a sufficient period. Should the session be retained as long as
   685  	// the SSH tunnel is alive (e.g., expose and call session.touch()
   686  	// on keepalive events)? Or is it better to free up resources held
   687  	// by idle sessions?
   688  
   689  	idleExpiry := server.sessionIdleExpiry()
   690  
   691  	ticker := time.NewTicker(idleExpiry / 2)
   692  	defer ticker.Stop()
   693  
   694  	for {
   695  		select {
   696  		case <-ticker.C:
   697  			server.indexToSession.Range(func(_, value interface{}) bool {
   698  				session := value.(*session)
   699  				if session.expired(idleExpiry) {
   700  					server.removeSession(session)
   701  				}
   702  				return true
   703  			})
   704  		case <-server.runContext.Done():
   705  			return
   706  		}
   707  	}
   708  }
   709  
   710  func (server *Server) sessionIdleExpiry() time.Duration {
   711  	sessionIdleExpirySeconds := DEFAULT_IDLE_SESSION_EXPIRY_SECONDS
   712  	if server.config.SessionIdleExpirySeconds > 2 {
   713  		sessionIdleExpirySeconds = server.config.SessionIdleExpirySeconds
   714  	}
   715  	return time.Duration(sessionIdleExpirySeconds) * time.Second
   716  }
   717  
// removeSession unregisters session from both server lookup maps
// (sessionIDToIndex and indexToSession), interrupts it, and finally deletes
// its tracked flows.
func (server *Server) removeSession(session *session) {
	// Unregister first, then interrupt.
	server.sessionIDToIndex.Delete(session.sessionID)
	server.indexToSession.Delete(session.index)
	server.interruptSession(session)

	// Delete flows to ensure any pending flow metrics are reported.
	session.deleteFlows()
}
   726  
   727  func (server *Server) runOrphanMetricsCheckpointer() {
   728  
   729  	defer server.workers.Done()
   730  
   731  	// Periodically log orphan packet metrics. Orphan metrics
   732  	// are not associated with any session. This includes
   733  	// packets that are rejected before they can be associated
   734  	// with a session.
   735  
   736  	ticker := time.NewTicker(ORPHAN_METRICS_CHECKPOINTER_PERIOD)
   737  	defer ticker.Stop()
   738  
   739  	for {
   740  		done := false
   741  		select {
   742  		case <-ticker.C:
   743  		case <-server.runContext.Done():
   744  			done = true
   745  		}
   746  
   747  		// TODO: skip log if all zeros?
   748  		server.orphanMetrics.checkpoint(
   749  			server.config.Logger,
   750  			nil,
   751  			"server_orphan_packet_metrics",
   752  			packetMetricsRejected)
   753  		if done {
   754  			return
   755  		}
   756  	}
   757  }
   758  
   759  func (server *Server) runDeviceDownstream() {
   760  
   761  	defer server.workers.Done()
   762  
   763  	// Read incoming packets from the tun device, parse and validate the
   764  	// packets, map them to a session/client, perform rewriting, and relay
   765  	// the packets to the client.
   766  
   767  	for {
   768  		readPacket, err := server.device.ReadPacket()
   769  
   770  		select {
   771  		case <-server.runContext.Done():
   772  			// No error is logged as shutdown may have interrupted read.
   773  			return
   774  		default:
   775  		}
   776  
   777  		if err != nil {
   778  			server.config.Logger.WithTraceFields(
   779  				common.LogFields{"error": err}).Warning("read device packet failed")
   780  			// May be temporary error condition, keep reading.
   781  			continue
   782  		}
   783  
   784  		// destinationIPAddress determines which client receives this packet.
   785  		// At this point, only enough of the packet is inspected to determine
   786  		// this routing info; further validation happens in subsequent
   787  		// processPacket in runClientDownstream.
   788  
   789  		// Note that masquerading/NAT stands between the Internet and the tun
   790  		// device, so arbitrary packets cannot be sent through to this point.
   791  
   792  		// TODO: getPacketDestinationIPAddress and processPacket perform redundant
   793  		// packet parsing; refactor to avoid extra work?
   794  
   795  		destinationIPAddress, ok := getPacketDestinationIPAddress(
   796  			server.orphanMetrics, packetDirectionServerDownstream, readPacket)
   797  
   798  		if !ok {
   799  			// Packet is dropped. Reason will be counted in orphan metrics.
   800  			continue
   801  		}
   802  
   803  		// Map destination IP address to client session.
   804  
   805  		index := server.convertIPAddressToIndex(destinationIPAddress)
   806  		s, ok := server.indexToSession.Load(index)
   807  
   808  		if !ok {
   809  			server.orphanMetrics.rejectedPacket(
   810  				packetDirectionServerDownstream, packetRejectNoSession)
   811  			continue
   812  		}
   813  
   814  		session := s.(*session)
   815  
   816  		downstreamPackets := session.getDownstreamPackets()
   817  
   818  		// No downstreamPackets buffer is maintained when no client is
   819  		// connected, so the packet is dropped.
   820  
   821  		if downstreamPackets == nil {
   822  			server.orphanMetrics.rejectedPacket(
   823  				packetDirectionServerDownstream, packetRejectNoClient)
   824  			continue
   825  		}
   826  
   827  		// Simply enqueue the packet for client handling, and move on to
   828  		// read the next packet. The packet tunnel server multiplexes all
   829  		// client packets through a single tun device, so we must not block
   830  		// on client channel I/O here.
   831  		//
   832  		// When the queue is full, the packet is dropped. This is standard
   833  		// behavior for routers, VPN servers, etc.
   834  		//
   835  		// TODO: processPacket is performed here, instead of runClientDownstream,
   836  		// since packets are packed contiguously into the packet queue and if
   837  		// the packet it to be omitted, that should be done before enqueuing.
   838  		// The potential downside is that all packet processing is done in this
   839  		// single thread of execution, blocking the next packet for the next
   840  		// client. Try handing off the packet to another worker which will
   841  		// call processPacket and Enqueue?
   842  
   843  		// In downstream mode, processPacket rewrites the destination address
   844  		// to the original client source IP address, and also rewrites DNS
   845  		// packets. As documented in runClientUpstream, the original address
   846  		// should already be populated via an upstream packet; if not, the
   847  		// packet will be rejected.
   848  
   849  		if !processPacket(
   850  			session.metrics,
   851  			session,
   852  			packetDirectionServerDownstream,
   853  			readPacket) {
   854  			// Packet is rejected and dropped. Reason will be counted in metrics.
   855  			continue
   856  		}
   857  
   858  		downstreamPackets.Enqueue(readPacket)
   859  	}
   860  }
   861  
   862  func (server *Server) runClientUpstream(session *session) {
   863  
   864  	defer session.workers.Done()
   865  
   866  	// Read incoming packets from the client channel, validate the packets,
   867  	// perform rewriting, and send them through to the tun device.
   868  
   869  	for {
   870  		readPacket, err := session.channel.ReadPacket()
   871  
   872  		select {
   873  		case <-session.runContext.Done():
   874  			// No error is logged as shutdown may have interrupted read.
   875  			return
   876  		default:
   877  		}
   878  
   879  		if err != nil {
   880  
   881  			// Debug since channel I/O errors occur during normal operation.
   882  			server.config.Logger.WithTraceFields(
   883  				common.LogFields{"error": err}).Debug("read channel packet failed")
   884  
   885  			// Tear down the session. Must be invoked asynchronously.
   886  			go server.interruptSession(session)
   887  
   888  			return
   889  		}
   890  
   891  		session.touch()
   892  
   893  		// processPacket transparently rewrites the source address to the
   894  		// session's assigned address and rewrites the destination of any
   895  		// DNS packets destined to the target DNS resolver.
   896  		//
   897  		// The first time the source address is rewritten, the original
   898  		// value is recorded so inbound packets can have the reverse
   899  		// rewrite applied. This assumes that the client will send a
   900  		// packet before receiving any packet, which is the case since
   901  		// only clients can initiate TCP or UDP connections or flows.
   902  
   903  		if !processPacket(
   904  			session.metrics,
   905  			session,
   906  			packetDirectionServerUpstream,
   907  			readPacket) {
   908  
   909  			// Packet is rejected and dropped. Reason will be counted in metrics.
   910  			continue
   911  		}
   912  
   913  		err = server.device.WritePacket(readPacket)
   914  
   915  		if err != nil {
   916  			server.config.Logger.WithTraceFields(
   917  				common.LogFields{"error": err}).Warning("write device packet failed")
   918  			// May be temporary error condition, keep working. The packet is
   919  			// most likely dropped.
   920  			continue
   921  		}
   922  	}
   923  }
   924  
   925  func (server *Server) runClientDownstream(session *session) {
   926  
   927  	defer session.workers.Done()
   928  
   929  	// Dequeue, process, and relay packets to be sent to the client channel.
   930  
   931  	for {
   932  
   933  		downstreamPackets := session.getDownstreamPackets()
   934  		// Note: downstreamPackets will not be nil, since this goroutine only
   935  		// runs while the session has a connected client.
   936  
   937  		packetBuffer, ok := downstreamPackets.DequeueFramedPackets(session.runContext)
   938  		if !ok {
   939  			// Dequeue aborted due to session.runContext.Done()
   940  			return
   941  		}
   942  
   943  		err := session.channel.WriteFramedPackets(packetBuffer)
   944  		if err != nil {
   945  
   946  			// Debug since channel I/O errors occur during normal operation.
   947  			server.config.Logger.WithTraceFields(
   948  				common.LogFields{"error": err}).Debug("write channel packets failed")
   949  
   950  			downstreamPackets.Replace(packetBuffer)
   951  
   952  			// Tear down the session. Must be invoked asynchronously.
   953  			go server.interruptSession(session)
   954  
   955  			return
   956  		}
   957  
   958  		session.touch()
   959  
   960  		downstreamPackets.Replace(packetBuffer)
   961  	}
   962  }
   963  
// Static addressing for the packet tunnel's private IPv4 and IPv6 subnets.
// In each family, address 1 is the server tun device address and address 2
// is the transparent DNS resolver target. The assigned address templates are
// filled in with the three bytes of a session index (see
// convertIndexToIPv4Address and convertIndexToIPv6Address).
//
// Note: the error results of net.ParseCIDR are discarded here.
var (
	serverIPv4AddressCIDR             = "10.0.0.1/8"
	transparentDNSResolverIPv4Address = net.ParseIP("10.0.0.2").To4() // 4-byte for rewriting
	_, privateSubnetIPv4, _           = net.ParseCIDR("10.0.0.0/8")
	assignedIPv4AddressTemplate       = "10.%d.%d.%d"

	serverIPv6AddressCIDR             = "fd19:ca83:e6d5:1c44:0000:0000:0000:0001/64"
	transparentDNSResolverIPv6Address = net.ParseIP("fd19:ca83:e6d5:1c44:0000:0000:0000:0002")
	_, privateSubnetIPv6, _           = net.ParseCIDR("fd19:ca83:e6d5:1c44::/64")
	assignedIPv6AddressTemplate       = "fd19:ca83:e6d5:1c44:8c57:4434:ee%02x:%02x%02x"
)
   975  
   976  func (server *Server) allocateIndex(newSession *session) error {
   977  
   978  	// Find and assign an available index in the 24-bit index space.
   979  	// The index directly maps to and so determines the assigned
   980  	// IPv4 and IPv6 addresses.
   981  
   982  	// Search is a random index selection followed by a linear probe.
   983  	// TODO: is this the most effective (fast on average, simple) algorithm?
   984  
   985  	max := 0x00FFFFFF
   986  
   987  	randomInt := prng.Intn(max + 1)
   988  
   989  	index := int32(randomInt)
   990  	index &= int32(max)
   991  
   992  	idleExpiry := server.sessionIdleExpiry()
   993  
   994  	for tries := 0; tries < 100000; index++ {
   995  
   996  		tries++
   997  
   998  		// The index/address space isn't exactly 24-bits:
   999  		// - 0 and 0x00FFFFFF are reserved since they map to
  1000  		//   the network identifier (10.0.0.0) and broadcast
  1001  		//   address (10.255.255.255) respectively
  1002  		// - 1 is reserved as the server tun device address,
  1003  		//   (10.0.0.1, and IPv6 equivalent)
  1004  		// - 2 is reserved as the transparent DNS target
  1005  		//   address (10.0.0.2, and IPv6 equivalent)
  1006  
  1007  		if index <= 2 {
  1008  			continue
  1009  		}
  1010  		if index == 0x00FFFFFF {
  1011  			index = 0
  1012  			continue
  1013  		}
  1014  
  1015  		IPv4Address := server.convertIndexToIPv4Address(index).To4()
  1016  		IPv6Address := server.convertIndexToIPv6Address(index)
  1017  
  1018  		// Ensure that the index converts to valid IPs. This is not expected
  1019  		// to fail, but continuing with nil IPs will silently misroute
  1020  		// packets with rewritten source IPs.
  1021  		if IPv4Address == nil || IPv6Address == nil {
  1022  			server.config.Logger.WithTraceFields(
  1023  				common.LogFields{"index": index}).Warning("convert index to IP address failed")
  1024  			continue
  1025  		}
  1026  
  1027  		if s, ok := server.indexToSession.LoadOrStore(index, newSession); ok {
  1028  			// Index is already in use or acquired concurrently.
  1029  			// If the existing session is expired, reap it and try again
  1030  			// to acquire it.
  1031  			existingSession := s.(*session)
  1032  			if existingSession.expired(idleExpiry) {
  1033  				server.removeSession(existingSession)
  1034  				// Try to acquire this index again. We can't fall through and
  1035  				// use this index as removeSession has cleared indexToSession.
  1036  				index--
  1037  			}
  1038  			continue
  1039  		}
  1040  
  1041  		// Note: the To4() for assignedIPv4Address is essential since
  1042  		// that address value is assumed to be 4 bytes when rewriting.
  1043  
  1044  		newSession.index = index
  1045  		newSession.assignedIPv4Address = IPv4Address
  1046  		newSession.assignedIPv6Address = IPv6Address
  1047  		server.sessionIDToIndex.Store(newSession.sessionID, index)
  1048  
  1049  		server.resetRouting(newSession.assignedIPv4Address, newSession.assignedIPv6Address)
  1050  
  1051  		return nil
  1052  	}
  1053  
  1054  	return errors.TraceNew("unallocated index not found")
  1055  }
  1056  
  1057  func (server *Server) resetRouting(IPv4Address, IPv6Address net.IP) {
  1058  
  1059  	// Attempt to clear the NAT table of any existing connection
  1060  	// states. This will prevent the (already unlikely) delivery
  1061  	// of packets to the wrong client when an assigned IP address is
  1062  	// recycled. Silently has no effect on some platforms, see
  1063  	// resetNATTables implementations.
  1064  
  1065  	err := resetNATTables(server.config, IPv4Address)
  1066  	if err != nil {
  1067  		server.config.Logger.WithTraceFields(
  1068  			common.LogFields{"error": err}).Warning("reset IPv4 routing failed")
  1069  
  1070  	}
  1071  
  1072  	err = resetNATTables(server.config, IPv6Address)
  1073  	if err != nil {
  1074  		server.config.Logger.WithTraceFields(
  1075  			common.LogFields{"error": err}).Warning("reset IPv6 routing failed")
  1076  
  1077  	}
  1078  }
  1079  
  1080  func (server *Server) convertIPAddressToIndex(IP net.IP) int32 {
  1081  	// Assumes IP is at least 3 bytes.
  1082  	size := len(IP)
  1083  	return int32(IP[size-3])<<16 | int32(IP[size-2])<<8 | int32(IP[size-1])
  1084  }
  1085  
  1086  func (server *Server) convertIndexToIPv4Address(index int32) net.IP {
  1087  	return net.ParseIP(
  1088  		fmt.Sprintf(
  1089  			assignedIPv4AddressTemplate,
  1090  			(index>>16)&0xFF,
  1091  			(index>>8)&0xFF,
  1092  			index&0xFF))
  1093  }
  1094  
  1095  func (server *Server) convertIndexToIPv6Address(index int32) net.IP {
  1096  	return net.ParseIP(
  1097  		fmt.Sprintf(
  1098  			assignedIPv6AddressTemplate,
  1099  			(index>>16)&0xFF,
  1100  			(index>>8)&0xFF,
  1101  			index&0xFF))
  1102  }
  1103  
// GetTransparentDNSResolverIPv4Address returns the static IPv4 address
// to use as a DNS resolver when transparent DNS rewriting is desired.
//
// The returned net.IP is the package-level value in its 4-byte form, not a
// copy; callers should treat it as read-only.
func GetTransparentDNSResolverIPv4Address() net.IP {
	return transparentDNSResolverIPv4Address
}
  1109  
// GetTransparentDNSResolverIPv6Address returns the static IPv6 address
// to use as a DNS resolver when transparent DNS rewriting is desired.
//
// The returned net.IP is the package-level value, not a copy; callers
// should treat it as read-only.
func GetTransparentDNSResolverIPv6Address() net.IP {
	return transparentDNSResolverIPv6Address
}
  1115  
// session is the server-side state for one packet tunnel session. A session
// is registered in the server's indexToSession and sessionIDToIndex maps by
// allocateIndex and unregistered by removeSession.
type session struct {
	// Note: 64-bit ints used with atomic operations are placed
	// at the start of struct to ensure 64-bit alignment.
	// (https://golang.org/pkg/sync/atomic/#pkg-note-BUG)

	// lastActivity is a monotime timestamp stored by touch and read by
	// expired. lastFlowReapIndex is used by startTrackingFlow to run
	// reapFlows at most once per reap period.
	lastActivity             int64
	lastFlowReapIndex        int64

	// The following pointers are accessed with atomic loads and stores via
	// the corresponding get*/set* methods.
	downstreamPackets        unsafe.Pointer
	checkAllowedTCPPortFunc  unsafe.Pointer
	checkAllowedUDPPortFunc  unsafe.Pointer
	checkAllowedDomainFunc   unsafe.Pointer
	flowActivityUpdaterMaker unsafe.Pointer
	metricsUpdater           unsafe.Pointer
	dnsQualityReporter       unsafe.Pointer

	// index, assignedIPv4Address, and assignedIPv6Address are set by
	// allocateIndex. setOriginalIPv4Address/setOriginalIPv6Address are 0/1
	// flags that guard originalIPv4Address/originalIPv6Address (see the
	// setOriginalIPv*AddressIfNotSet methods). flows maps flowID to
	// *flowState.
	allowBogons              bool
	metrics                  *packetMetrics
	sessionID                string
	index                    int32
	enableDNSFlowTracking    bool
	DNSResolverIPv4Addresses []net.IP
	TCPDNSResolverIPv4Index  int
	assignedIPv4Address      net.IP
	setOriginalIPv4Address   int32
	originalIPv4Address      net.IP
	DNSResolverIPv6Addresses []net.IP
	TCPDNSResolverIPv6Index  int
	assignedIPv6Address      net.IP
	setOriginalIPv6Address   int32
	originalIPv6Address      net.IP
	flows                    sync.Map
	workers                  *sync.WaitGroup
	mutex                    sync.Mutex
	channel                  *Channel
	runContext               context.Context
	stopRunning              context.CancelFunc
}
  1152  
  1153  func (session *session) touch() {
  1154  	atomic.StoreInt64(&session.lastActivity, int64(monotime.Now()))
  1155  }
  1156  
  1157  func (session *session) expired(idleExpiry time.Duration) bool {
  1158  	lastActivity := monotime.Time(atomic.LoadInt64(&session.lastActivity))
  1159  	return monotime.Since(lastActivity) > idleExpiry
  1160  }
  1161  
  1162  func (session *session) setOriginalIPv4AddressIfNotSet(IPAddress net.IP) {
  1163  	if !atomic.CompareAndSwapInt32(&session.setOriginalIPv4Address, 0, 1) {
  1164  		return
  1165  	}
  1166  	// Make a copy of IPAddress; don't reference a slice of a reusable
  1167  	// packet buffer, which will be overwritten.
  1168  	session.originalIPv4Address = net.IP(append([]byte(nil), []byte(IPAddress)...))
  1169  }
  1170  
  1171  func (session *session) getOriginalIPv4Address() net.IP {
  1172  	if atomic.LoadInt32(&session.setOriginalIPv4Address) == 0 {
  1173  		return nil
  1174  	}
  1175  	return session.originalIPv4Address
  1176  }
  1177  
  1178  func (session *session) setOriginalIPv6AddressIfNotSet(IPAddress net.IP) {
  1179  	if !atomic.CompareAndSwapInt32(&session.setOriginalIPv6Address, 0, 1) {
  1180  		return
  1181  	}
  1182  	// Make a copy of IPAddress.
  1183  	session.originalIPv6Address = net.IP(append([]byte(nil), []byte(IPAddress)...))
  1184  }
  1185  
  1186  func (session *session) getOriginalIPv6Address() net.IP {
  1187  	if atomic.LoadInt32(&session.setOriginalIPv6Address) == 0 {
  1188  		return nil
  1189  	}
  1190  	return session.originalIPv6Address
  1191  }
  1192  
  1193  func (session *session) setDownstreamPackets(p *PacketQueue) {
  1194  	atomic.StorePointer(&session.downstreamPackets, unsafe.Pointer(p))
  1195  }
  1196  
  1197  func (session *session) getDownstreamPackets() *PacketQueue {
  1198  	return (*PacketQueue)(atomic.LoadPointer(&session.downstreamPackets))
  1199  }
  1200  
  1201  func (session *session) setCheckAllowedTCPPortFunc(p *AllowedPortChecker) {
  1202  	atomic.StorePointer(&session.checkAllowedTCPPortFunc, unsafe.Pointer(p))
  1203  }
  1204  
  1205  func (session *session) getCheckAllowedTCPPortFunc() AllowedPortChecker {
  1206  	p := (*AllowedPortChecker)(atomic.LoadPointer(&session.checkAllowedTCPPortFunc))
  1207  	if p == nil {
  1208  		return nil
  1209  	}
  1210  	return *p
  1211  }
  1212  
  1213  func (session *session) setCheckAllowedUDPPortFunc(p *AllowedPortChecker) {
  1214  	atomic.StorePointer(&session.checkAllowedUDPPortFunc, unsafe.Pointer(p))
  1215  }
  1216  
  1217  func (session *session) getCheckAllowedUDPPortFunc() AllowedPortChecker {
  1218  	p := (*AllowedPortChecker)(atomic.LoadPointer(&session.checkAllowedUDPPortFunc))
  1219  	if p == nil {
  1220  		return nil
  1221  	}
  1222  	return *p
  1223  }
  1224  
  1225  func (session *session) setCheckAllowedDomainFunc(p *AllowedDomainChecker) {
  1226  	atomic.StorePointer(&session.checkAllowedDomainFunc, unsafe.Pointer(p))
  1227  }
  1228  
  1229  func (session *session) getCheckAllowedDomainFunc() AllowedDomainChecker {
  1230  	p := (*AllowedDomainChecker)(atomic.LoadPointer(&session.checkAllowedDomainFunc))
  1231  	if p == nil {
  1232  		return nil
  1233  	}
  1234  	return *p
  1235  }
  1236  
  1237  func (session *session) setFlowActivityUpdaterMaker(p *FlowActivityUpdaterMaker) {
  1238  	atomic.StorePointer(&session.flowActivityUpdaterMaker, unsafe.Pointer(p))
  1239  }
  1240  
  1241  func (session *session) getFlowActivityUpdaterMaker() FlowActivityUpdaterMaker {
  1242  	p := (*FlowActivityUpdaterMaker)(atomic.LoadPointer(&session.flowActivityUpdaterMaker))
  1243  	if p == nil {
  1244  		return nil
  1245  	}
  1246  	return *p
  1247  }
  1248  
  1249  func (session *session) setMetricsUpdater(p *MetricsUpdater) {
  1250  	atomic.StorePointer(&session.metricsUpdater, unsafe.Pointer(p))
  1251  }
  1252  
  1253  func (session *session) getMetricsUpdater() MetricsUpdater {
  1254  	p := (*MetricsUpdater)(atomic.LoadPointer(&session.metricsUpdater))
  1255  	if p == nil {
  1256  		return nil
  1257  	}
  1258  	return *p
  1259  }
  1260  
  1261  func (session *session) setDNSQualityReporter(p *DNSQualityReporter) {
  1262  	atomic.StorePointer(&session.dnsQualityReporter, unsafe.Pointer(p))
  1263  }
  1264  
  1265  func (session *session) getDNSQualityReporter() DNSQualityReporter {
  1266  	p := (*DNSQualityReporter)(atomic.LoadPointer(&session.dnsQualityReporter))
  1267  	if p == nil {
  1268  		return nil
  1269  	}
  1270  	return *p
  1271  }
  1272  
// flowID identifies an IP traffic flow using the conventional
// network 5-tuple. flowIDs track bidirectional flows.
//
// Both addresses are held in fixed 16-byte arrays; the set method stores
// 4-byte IPv4 addresses behind the v4InV6Prefix. The struct contains only
// arrays and scalar fields and is used as the key of the session flows map.
type flowID struct {
	downstreamIPAddress [net.IPv6len]byte
	downstreamPort      uint16
	upstreamIPAddress   [net.IPv6len]byte
	upstreamPort        uint16
	protocol            internetProtocol
}
  1282  
// v4InV6Prefix is the 12-byte prefix that flowID.set writes ahead of a
// 4-byte IPv4 address to fill a 16-byte address field.
//
// From: https://github.com/golang/go/blob/b88efc7e7ac15f9e0b5d8d9c82f870294f6a3839/src/net/ip.go#L55
var v4InV6Prefix = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff}
  1285  
  1286  func (f *flowID) set(
  1287  	downstreamIPAddress net.IP,
  1288  	downstreamPort uint16,
  1289  	upstreamIPAddress net.IP,
  1290  	upstreamPort uint16,
  1291  	protocol internetProtocol) {
  1292  
  1293  	if len(downstreamIPAddress) == net.IPv4len {
  1294  		copy(f.downstreamIPAddress[:], v4InV6Prefix)
  1295  		copy(f.downstreamIPAddress[len(v4InV6Prefix):], downstreamIPAddress)
  1296  	} else { // net.IPv6len
  1297  		copy(f.downstreamIPAddress[:], downstreamIPAddress)
  1298  	}
  1299  	f.downstreamPort = downstreamPort
  1300  
  1301  	if len(upstreamIPAddress) == net.IPv4len {
  1302  		copy(f.upstreamIPAddress[:], v4InV6Prefix)
  1303  		copy(f.upstreamIPAddress[len(v4InV6Prefix):], upstreamIPAddress)
  1304  	} else { // net.IPv6len
  1305  		copy(f.upstreamIPAddress[:], upstreamIPAddress)
  1306  	}
  1307  	f.upstreamPort = upstreamPort
  1308  
  1309  	f.protocol = protocol
  1310  }
  1311  
// flowState is the tracking state kept for one flow in session.flows.
//
// The four packet time fields hold monotime values as int64 and are
// accessed with atomic operations; 0 means no packet has been recorded for
// that direction. isDNS and dnsQualityReporter are consulted by deleteFlow
// to report DNS quality, and activityUpdaters are invoked by updateFlow.
type flowState struct {
	// Note: 64-bit ints used with atomic operations are placed
	// at the start of struct to ensure 64-bit alignment.
	// (https://golang.org/pkg/sync/atomic/#pkg-note-BUG)
	firstUpstreamPacketTime   int64
	lastUpstreamPacketTime    int64
	firstDownstreamPacketTime int64
	lastDownstreamPacketTime  int64
	isDNS                     bool
	dnsQualityReporter        DNSQualityReporter
	activityUpdaters          []FlowActivityUpdater
}
  1324  
  1325  func (flowState *flowState) expired(idleExpiry time.Duration) bool {
  1326  	now := monotime.Now()
  1327  
  1328  	// Traffic in either direction keeps the flow alive. Initially, only one of
  1329  	// lastUpstreamPacketTime or lastDownstreamPacketTime will be set by
  1330  	// startTrackingFlow, and the other value will be 0 and evaluate as expired.
  1331  
  1332  	return (now.Sub(monotime.Time(atomic.LoadInt64(&flowState.lastUpstreamPacketTime))) > idleExpiry) &&
  1333  		(now.Sub(monotime.Time(atomic.LoadInt64(&flowState.lastDownstreamPacketTime))) > idleExpiry)
  1334  }
  1335  
  1336  // isTrackingFlow checks if a flow is being tracked.
  1337  func (session *session) isTrackingFlow(ID flowID) bool {
  1338  
  1339  	f, ok := session.flows.Load(ID)
  1340  	if !ok {
  1341  		return false
  1342  	}
  1343  	flowState := f.(*flowState)
  1344  
  1345  	// Check if flow is expired but not yet reaped.
  1346  	if flowState.expired(FLOW_IDLE_EXPIRY) {
  1347  		session.deleteFlow(ID, flowState)
  1348  		return false
  1349  	}
  1350  
  1351  	return true
  1352  }
  1353  
  1354  // startTrackingFlow starts flow tracking for the flow identified
  1355  // by ID.
  1356  //
  1357  // Flow tracking is used to implement:
  1358  // - one-time permissions checks for a flow
  1359  // - OSLs
  1360  // - domain bytes transferred [TODO]
  1361  // - DNS quality metrics
  1362  //
  1363  // The applicationData from the first packet in the flow is
  1364  // inspected to determine any associated hostname, using HTTP or
  1365  // TLS payload. The session's FlowActivityUpdaterMaker is invoked
  1366  // to determine a list of updaters to track flow activity.
  1367  //
  1368  // Updaters receive reports with the number of application data
  1369  // bytes in each flow packet. This number, totalled for all packets
  1370  // in a flow, may exceed the total bytes transferred at the
  1371  // application level due to TCP retransmission. Currently, the flow
  1372  // tracking logic doesn't exclude retransmitted packets from update
  1373  // reporting.
  1374  //
  1375  // Flows are untracked after an idle expiry period. Transport
  1376  // protocol indicators of end of flow, such as FIN or RST for TCP,
  1377  // which may or may not appear in a flow, are not currently used.
  1378  //
  1379  // startTrackingFlow may be called from concurrent goroutines; if
  1380  // the flow is already tracked, it is simply updated.
  1381  func (session *session) startTrackingFlow(
  1382  	ID flowID,
  1383  	direction packetDirection,
  1384  	applicationData []byte,
  1385  	isDNS bool) {
  1386  
  1387  	now := int64(monotime.Now())
  1388  
  1389  	// Once every period, iterate over flows and reap expired entries.
  1390  	reapIndex := now / int64(monotime.Time(FLOW_IDLE_EXPIRY/2))
  1391  	previousReapIndex := atomic.LoadInt64(&session.lastFlowReapIndex)
  1392  	if reapIndex != previousReapIndex &&
  1393  		atomic.CompareAndSwapInt64(&session.lastFlowReapIndex, previousReapIndex, reapIndex) {
  1394  		session.reapFlows()
  1395  	}
  1396  
  1397  	var isTCP bool
  1398  	var hostname string
  1399  	if ID.protocol == internetProtocolTCP {
  1400  		// TODO: implement
  1401  		// hostname = common.ExtractHostnameFromTCPFlow(applicationData)
  1402  		isTCP = true
  1403  	}
  1404  
  1405  	var activityUpdaters []FlowActivityUpdater
  1406  
  1407  	// Don't incur activity monitor overhead for DNS requests
  1408  	if !isDNS {
  1409  		flowActivityUpdaterMaker := session.getFlowActivityUpdaterMaker()
  1410  		if flowActivityUpdaterMaker != nil {
  1411  			activityUpdaters = flowActivityUpdaterMaker(
  1412  				isTCP,
  1413  				hostname,
  1414  				net.IP(ID.upstreamIPAddress[:]))
  1415  		}
  1416  	}
  1417  
  1418  	flowState := &flowState{
  1419  		isDNS:              isDNS,
  1420  		activityUpdaters:   activityUpdaters,
  1421  		dnsQualityReporter: session.getDNSQualityReporter(),
  1422  	}
  1423  
  1424  	if direction == packetDirectionServerUpstream {
  1425  		flowState.firstUpstreamPacketTime = now
  1426  		flowState.lastUpstreamPacketTime = now
  1427  	} else {
  1428  		flowState.firstDownstreamPacketTime = now
  1429  		flowState.lastDownstreamPacketTime = now
  1430  	}
  1431  
  1432  	// LoadOrStore will retain any existing entry
  1433  	session.flows.LoadOrStore(ID, flowState)
  1434  
  1435  	session.updateFlow(ID, direction, applicationData)
  1436  }
  1437  
  1438  func (session *session) updateFlow(
  1439  	ID flowID,
  1440  	direction packetDirection,
  1441  	applicationData []byte) {
  1442  
  1443  	f, ok := session.flows.Load(ID)
  1444  	if !ok {
  1445  		return
  1446  	}
  1447  	flowState := f.(*flowState)
  1448  
  1449  	// Note: no expired check here, since caller is assumed to
  1450  	// have just called isTrackingFlow.
  1451  
  1452  	now := int64(monotime.Now())
  1453  	var upstreamBytes, downstreamBytes, durationNanoseconds int64
  1454  
  1455  	if direction == packetDirectionServerUpstream {
  1456  		upstreamBytes = int64(len(applicationData))
  1457  
  1458  		atomic.CompareAndSwapInt64(&flowState.firstUpstreamPacketTime, 0, now)
  1459  
  1460  		atomic.StoreInt64(&flowState.lastUpstreamPacketTime, now)
  1461  
  1462  	} else {
  1463  		downstreamBytes = int64(len(applicationData))
  1464  
  1465  		atomic.CompareAndSwapInt64(&flowState.firstDownstreamPacketTime, 0, now)
  1466  
  1467  		// Follows common.ActivityMonitoredConn semantics, where
  1468  		// duration is updated only for downstream activity. This
  1469  		// is intened to produce equivalent behaviour for port
  1470  		// forward clients (tracked with ActivityUpdaters) and
  1471  		// packet tunnel clients (tracked with FlowActivityUpdaters).
  1472  
  1473  		durationNanoseconds = now - atomic.SwapInt64(&flowState.lastDownstreamPacketTime, now)
  1474  	}
  1475  
  1476  	for _, updater := range flowState.activityUpdaters {
  1477  		updater.UpdateProgress(downstreamBytes, upstreamBytes, durationNanoseconds)
  1478  	}
  1479  }
  1480  
  1481  // deleteFlow stops tracking a flow and logs any outstanding metrics.
  1482  // flowState is passed in to avoid duplicating the lookup that all callers
  1483  // have already performed.
  1484  func (session *session) deleteFlow(ID flowID, flowState *flowState) {
  1485  
  1486  	if flowState.isDNS {
  1487  
  1488  		dnsStartTime := monotime.Time(
  1489  			atomic.LoadInt64(&flowState.firstUpstreamPacketTime))
  1490  
  1491  		if dnsStartTime > 0 {
  1492  
  1493  			// Record DNS quality metrics using a heuristic: if a packet was sent and
  1494  			// then a packet was received, assume the DNS request successfully received
  1495  			// a valid response; failure occurs when the resolver fails to provide a
  1496  			// response; a "no such host" response is still a success. Limitations: we
  1497  			// assume a resolver will not respond when, e.g., rate limiting; we ignore
  1498  			// subsequent requests made via the same UDP/TCP flow; deleteFlow may be
  1499  			// called only after the flow has expired, which adds some delay to the
  1500  			// recording of the DNS metric.
  1501  
  1502  			dnsEndTime := monotime.Time(
  1503  				atomic.LoadInt64(&flowState.firstDownstreamPacketTime))
  1504  
  1505  			dnsSuccess := true
  1506  			if dnsEndTime == 0 {
  1507  				dnsSuccess = false
  1508  				dnsEndTime = monotime.Now()
  1509  			}
  1510  
  1511  			resolveElapsedTime := dnsEndTime.Sub(dnsStartTime)
  1512  
  1513  			if flowState.dnsQualityReporter != nil {
  1514  				flowState.dnsQualityReporter(
  1515  					dnsSuccess,
  1516  					resolveElapsedTime,
  1517  					net.IP(ID.upstreamIPAddress[:]))
  1518  			}
  1519  		}
  1520  	}
  1521  
  1522  	session.flows.Delete(ID)
  1523  }
  1524  
  1525  // reapFlows removes expired idle flows.
  1526  func (session *session) reapFlows() {
  1527  	session.flows.Range(func(key, value interface{}) bool {
  1528  		flowState := value.(*flowState)
  1529  		if flowState.expired(FLOW_IDLE_EXPIRY) {
  1530  			session.deleteFlow(key.(flowID), flowState)
  1531  		}
  1532  		return true
  1533  	})
  1534  }
  1535  
  1536  // deleteFlows deletes all flows.
  1537  func (session *session) deleteFlows() {
  1538  	session.flows.Range(func(key, value interface{}) bool {
  1539  		session.deleteFlow(key.(flowID), value.(*flowState))
  1540  		return true
  1541  	})
  1542  }
  1543  
  1544  type packetMetrics struct {
  1545  	upstreamRejectReasons   [packetRejectReasonCount]int64
  1546  	downstreamRejectReasons [packetRejectReasonCount]int64
  1547  	TCPIPv4                 relayedPacketMetrics
  1548  	TCPIPv6                 relayedPacketMetrics
  1549  	UDPIPv4                 relayedPacketMetrics
  1550  	UDPIPv6                 relayedPacketMetrics
  1551  }
  1552  
  1553  type relayedPacketMetrics struct {
  1554  	packetsUp            int64
  1555  	packetsDown          int64
  1556  	bytesUp              int64
  1557  	bytesDown            int64
  1558  	applicationBytesUp   int64
  1559  	applicationBytesDown int64
  1560  }
  1561  
  1562  func (metrics *packetMetrics) rejectedPacket(
  1563  	direction packetDirection,
  1564  	reason packetRejectReason) {
  1565  
  1566  	if direction == packetDirectionServerUpstream ||
  1567  		direction == packetDirectionClientUpstream {
  1568  
  1569  		atomic.AddInt64(&metrics.upstreamRejectReasons[reason], 1)
  1570  
  1571  	} else { // packetDirectionDownstream
  1572  
  1573  		atomic.AddInt64(&metrics.downstreamRejectReasons[reason], 1)
  1574  
  1575  	}
  1576  }
  1577  
  1578  func (metrics *packetMetrics) relayedPacket(
  1579  	direction packetDirection,
  1580  	version int,
  1581  	protocol internetProtocol,
  1582  	packetLength, applicationDataLength int) {
  1583  
  1584  	var packetsMetric, bytesMetric, applicationBytesMetric *int64
  1585  
  1586  	if direction == packetDirectionServerUpstream ||
  1587  		direction == packetDirectionClientUpstream {
  1588  
  1589  		if version == 4 {
  1590  
  1591  			if protocol == internetProtocolTCP {
  1592  				packetsMetric = &metrics.TCPIPv4.packetsUp
  1593  				bytesMetric = &metrics.TCPIPv4.bytesUp
  1594  				applicationBytesMetric = &metrics.TCPIPv4.applicationBytesUp
  1595  			} else { // UDP
  1596  				packetsMetric = &metrics.UDPIPv4.packetsUp
  1597  				bytesMetric = &metrics.UDPIPv4.bytesUp
  1598  				applicationBytesMetric = &metrics.UDPIPv4.applicationBytesUp
  1599  			}
  1600  
  1601  		} else { // IPv6
  1602  
  1603  			if protocol == internetProtocolTCP {
  1604  				packetsMetric = &metrics.TCPIPv6.packetsUp
  1605  				bytesMetric = &metrics.TCPIPv6.bytesUp
  1606  				applicationBytesMetric = &metrics.TCPIPv6.applicationBytesUp
  1607  			} else { // UDP
  1608  				packetsMetric = &metrics.UDPIPv6.packetsUp
  1609  				bytesMetric = &metrics.UDPIPv6.bytesUp
  1610  				applicationBytesMetric = &metrics.UDPIPv6.applicationBytesUp
  1611  			}
  1612  		}
  1613  
  1614  	} else { // packetDirectionDownstream
  1615  
  1616  		if version == 4 {
  1617  
  1618  			if protocol == internetProtocolTCP {
  1619  				packetsMetric = &metrics.TCPIPv4.packetsDown
  1620  				bytesMetric = &metrics.TCPIPv4.bytesDown
  1621  				applicationBytesMetric = &metrics.TCPIPv4.applicationBytesDown
  1622  			} else { // UDP
  1623  				packetsMetric = &metrics.UDPIPv4.packetsDown
  1624  				bytesMetric = &metrics.UDPIPv4.bytesDown
  1625  				applicationBytesMetric = &metrics.UDPIPv4.applicationBytesDown
  1626  			}
  1627  
  1628  		} else { // IPv6
  1629  
  1630  			if protocol == internetProtocolTCP {
  1631  				packetsMetric = &metrics.TCPIPv6.packetsDown
  1632  				bytesMetric = &metrics.TCPIPv6.bytesDown
  1633  				applicationBytesMetric = &metrics.TCPIPv6.applicationBytesDown
  1634  			} else { // UDP
  1635  				packetsMetric = &metrics.UDPIPv6.packetsDown
  1636  				bytesMetric = &metrics.UDPIPv6.bytesDown
  1637  				applicationBytesMetric = &metrics.UDPIPv6.applicationBytesDown
  1638  			}
  1639  		}
  1640  	}
  1641  
  1642  	atomic.AddInt64(packetsMetric, 1)
  1643  	atomic.AddInt64(bytesMetric, int64(packetLength))
  1644  	atomic.AddInt64(applicationBytesMetric, int64(applicationDataLength))
  1645  }
  1646  
  1647  const (
  1648  	packetMetricsRejected = 1
  1649  	packetMetricsRelayed  = 2
  1650  	packetMetricsAll      = packetMetricsRejected | packetMetricsRelayed
  1651  )
  1652  
  1653  func (metrics *packetMetrics) checkpoint(
  1654  	logger common.Logger, updater MetricsUpdater, logName string, whichMetrics int) {
  1655  
  1656  	// Report all metric counters in a single log message. Each
  1657  	// counter is reset to 0 when added to the log.
  1658  
  1659  	logFields := make(common.LogFields)
  1660  
  1661  	if whichMetrics&packetMetricsRejected != 0 {
  1662  
  1663  		for i := 0; i < packetRejectReasonCount; i++ {
  1664  			logFields["upstream_packet_rejected_"+packetRejectReasonDescription(packetRejectReason(i))] =
  1665  				atomic.SwapInt64(&metrics.upstreamRejectReasons[i], 0)
  1666  			logFields["downstream_packet_rejected_"+packetRejectReasonDescription(packetRejectReason(i))] =
  1667  				atomic.SwapInt64(&metrics.downstreamRejectReasons[i], 0)
  1668  		}
  1669  	}
  1670  
  1671  	if whichMetrics&packetMetricsRelayed != 0 {
  1672  
  1673  		var TCPApplicationBytesUp, TCPApplicationBytesDown,
  1674  			UDPApplicationBytesUp, UDPApplicationBytesDown int64
  1675  
  1676  		relayedMetrics := []struct {
  1677  			prefix           string
  1678  			metrics          *relayedPacketMetrics
  1679  			updaterBytesUp   *int64
  1680  			updaterBytesDown *int64
  1681  		}{
  1682  			{"tcp_ipv4_", &metrics.TCPIPv4, &TCPApplicationBytesUp, &TCPApplicationBytesDown},
  1683  			{"tcp_ipv6_", &metrics.TCPIPv6, &TCPApplicationBytesUp, &TCPApplicationBytesDown},
  1684  			{"udp_ipv4_", &metrics.UDPIPv4, &UDPApplicationBytesUp, &UDPApplicationBytesDown},
  1685  			{"udp_ipv6_", &metrics.UDPIPv6, &UDPApplicationBytesUp, &UDPApplicationBytesDown},
  1686  		}
  1687  
  1688  		for _, r := range relayedMetrics {
  1689  
  1690  			applicationBytesUp := atomic.SwapInt64(&r.metrics.applicationBytesUp, 0)
  1691  			applicationBytesDown := atomic.SwapInt64(&r.metrics.applicationBytesDown, 0)
  1692  
  1693  			*r.updaterBytesUp += applicationBytesUp
  1694  			*r.updaterBytesDown += applicationBytesDown
  1695  
  1696  			logFields[r.prefix+"packets_up"] = atomic.SwapInt64(&r.metrics.packetsUp, 0)
  1697  			logFields[r.prefix+"packets_down"] = atomic.SwapInt64(&r.metrics.packetsDown, 0)
  1698  			logFields[r.prefix+"bytes_up"] = atomic.SwapInt64(&r.metrics.bytesUp, 0)
  1699  			logFields[r.prefix+"bytes_down"] = atomic.SwapInt64(&r.metrics.bytesDown, 0)
  1700  			logFields[r.prefix+"application_bytes_up"] = applicationBytesUp
  1701  			logFields[r.prefix+"application_bytes_down"] = applicationBytesDown
  1702  		}
  1703  
  1704  		if updater != nil {
  1705  			updater(
  1706  				TCPApplicationBytesUp, TCPApplicationBytesDown,
  1707  				UDPApplicationBytesUp, UDPApplicationBytesDown)
  1708  		}
  1709  	}
  1710  
  1711  	logger.LogMetric(logName, logFields)
  1712  }
  1713  
  1714  // PacketQueue is a fixed-size, preallocated queue of packets.
  1715  // Enqueued packets are packed into a contiguous buffer with channel
  1716  // framing, allowing the entire queue to be written to a channel
  1717  // in a single call.
  1718  // Reuse of the queue buffers avoids GC churn. To avoid memory use
  1719  // spikes when many clients connect and may disconnect before relaying
  1720  // packets, the packet queue buffers start small and grow when required,
  1721  // up to the maximum size, and then remain static.
  1722  type PacketQueue struct {
  1723  	maxSize      int
  1724  	emptyBuffers chan []byte
  1725  	activeBuffer chan []byte
  1726  }
  1727  
  1728  // NewPacketQueue creates a new PacketQueue.
  1729  // The caller must ensure that maxSize exceeds the
  1730  // packet MTU, or packets will will never enqueue.
  1731  func NewPacketQueue(maxSize int) *PacketQueue {
  1732  
  1733  	// Two buffers of size up to maxSize are allocated, to
  1734  	// allow packets to continue to enqueue while one buffer
  1735  	// is borrowed by the DequeueFramedPackets caller.
  1736  	//
  1737  	// TODO: is there a way to implement this without
  1738  	// allocating up to 2x maxSize bytes? A circular queue
  1739  	// won't work because we want DequeueFramedPackets
  1740  	// to return a contiguous buffer. Perhaps a Bip
  1741  	// Buffer would work here:
  1742  	// https://www.codeproject.com/Articles/3479/The-Bip-Buffer-The-Circular-Buffer-with-a-Twist
  1743  
  1744  	queue := &PacketQueue{
  1745  		maxSize:      maxSize,
  1746  		emptyBuffers: make(chan []byte, 2),
  1747  		activeBuffer: make(chan []byte, 1),
  1748  	}
  1749  
  1750  	queue.emptyBuffers <- make([]byte, 0)
  1751  	queue.emptyBuffers <- make([]byte, 0)
  1752  
  1753  	return queue
  1754  }
  1755  
  1756  // Enqueue adds a packet to the queue.
  1757  // If the queue is full, the packet is dropped.
  1758  // Enqueue is _not_ safe for concurrent calls.
  1759  func (queue *PacketQueue) Enqueue(packet []byte) {
  1760  
  1761  	var buffer []byte
  1762  
  1763  	select {
  1764  	case buffer = <-queue.activeBuffer:
  1765  	default:
  1766  		buffer = <-queue.emptyBuffers
  1767  	}
  1768  
  1769  	packetSize := len(packet)
  1770  
  1771  	if queue.maxSize-len(buffer) >= channelHeaderSize+packetSize {
  1772  		// Assumes len(packet)/MTU <= 64K
  1773  		var channelHeader [channelHeaderSize]byte
  1774  		binary.BigEndian.PutUint16(channelHeader[:], uint16(packetSize))
  1775  
  1776  		// Once the buffer has reached maxSize capacity
  1777  		// and been replaced (buffer = buffer[0:0]), these
  1778  		// appends should no longer allocate new memory and
  1779  		// should just copy to preallocated memory.
  1780  
  1781  		buffer = append(buffer, channelHeader[:]...)
  1782  		buffer = append(buffer, packet...)
  1783  	}
  1784  	// Else, queue is full, so drop packet.
  1785  
  1786  	queue.activeBuffer <- buffer
  1787  }
  1788  
  1789  // DequeueFramedPackets waits until at least one packet is
  1790  // enqueued, and then returns a packet buffer containing one
  1791  // or more framed packets. The returned buffer remains part
  1792  // of the PacketQueue structure and the caller _must_ replace
  1793  // the buffer by calling Replace.
  1794  // DequeueFramedPackets unblocks and returns false if it receives
  1795  // runContext.Done().
  1796  // DequeueFramedPackets is _not_ safe for concurrent calls.
  1797  func (queue *PacketQueue) DequeueFramedPackets(
  1798  	runContext context.Context) ([]byte, bool) {
  1799  
  1800  	var buffer []byte
  1801  
  1802  	select {
  1803  	case buffer = <-queue.activeBuffer:
  1804  	case <-runContext.Done():
  1805  		return nil, false
  1806  	}
  1807  
  1808  	return buffer, true
  1809  }
  1810  
  1811  // Replace returns the buffer to the PacketQueue to be
  1812  // reused.
  1813  // The input must be a return value from DequeueFramedPackets.
  1814  func (queue *PacketQueue) Replace(buffer []byte) {
  1815  
  1816  	buffer = buffer[0:0]
  1817  
  1818  	// This won't block (as long as it is a DequeueFramedPackets return value).
  1819  	queue.emptyBuffers <- buffer
  1820  }
  1821  
  1822  // ClientConfig specifies the configuration of a packet tunnel client.
  1823  type ClientConfig struct {
  1824  
  1825  	// Logger is used for logging events and metrics.
  1826  	Logger common.Logger
  1827  
  1828  	// SudoNetworkConfigCommands specifies whether to use "sudo"
  1829  	// when executing network configuration commands. See description
  1830  	// for ServerConfig.SudoNetworkConfigCommands.
  1831  	SudoNetworkConfigCommands bool
  1832  
  1833  	// AllowNoIPv6NetworkConfiguration indicates that failures while
  1834  	// configuring tun interfaces and routing for IPv6 are to be
  1835  	// logged as warnings only. See description for
  1836  	// ServerConfig.AllowNoIPv6NetworkConfiguration.
  1837  	AllowNoIPv6NetworkConfiguration bool
  1838  
  1839  	// MTU is the packet MTU value to use; this value
  1840  	// should be obtained from the packet tunnel server.
  1841  	// When MTU is 0, a default value is used.
  1842  	MTU int
  1843  
  1844  	// UpstreamPacketQueueSize specifies the size of the upstream
  1845  	// packet queue.
  1846  	// When UpstreamPacketQueueSize is 0, a default value tuned for
  1847  	// Psiphon is used.
  1848  	UpstreamPacketQueueSize int
  1849  
  1850  	// Transport is an established transport channel that
  1851  	// will be used to relay packets to and from a packet
  1852  	// tunnel server.
  1853  	Transport io.ReadWriteCloser
  1854  
  1855  	// TunFileDescriptor specifies a file descriptor to use to
  1856  	// read and write packets to be relayed to the client. When
  1857  	// TunFileDescriptor is specified, the Client will use this
  1858  	// existing tun device and not create its own; in this case,
  1859  	// network address and routing configuration is not performed
  1860  	// by the Client. As the packet tunnel server performs
  1861  	// transparent source IP address and DNS rewriting, the tun
  1862  	// device may have any assigned IP address, but should be
  1863  	// configured with the given MTU; and DNS should be configured
  1864  	// to use the transparent DNS target resolver addresses.
  1865  	// Set TunFileDescriptor to <= 0 to ignore this parameter
  1866  	// and create and configure a tun device.
  1867  	TunFileDescriptor int
  1868  
  1869  	// IPv4AddressCIDR is the IPv4 address and netmask to
  1870  	// assign to a newly created tun device.
  1871  	IPv4AddressCIDR string
  1872  
  1873  	// IPv6AddressCIDR is the IPv6 address and prefix to
  1874  	// assign to a newly created tun device.
  1875  	IPv6AddressCIDR string
  1876  
  1877  	// RouteDestinations are hosts (IPs) or networks (CIDRs)
  1878  	// to be configured to be routed through a newly
  1879  	// created tun device.
  1880  	RouteDestinations []string
  1881  }
  1882  
  1883  // Client is a packet tunnel client. A packet tunnel client
  1884  // relays packets between a local tun device and a packet
  1885  // tunnel server via a transport channel.
  1886  type Client struct {
  1887  	config          *ClientConfig
  1888  	device          *Device
  1889  	channel         *Channel
  1890  	upstreamPackets *PacketQueue
  1891  	metrics         *packetMetrics
  1892  	runContext      context.Context
  1893  	stopRunning     context.CancelFunc
  1894  	workers         *sync.WaitGroup
  1895  }
  1896  
  1897  // NewClient initializes a new Client. Unless using the
  1898  // TunFileDescriptor configuration parameter, a new tun
  1899  // device is created for the client.
  1900  func NewClient(config *ClientConfig) (*Client, error) {
  1901  
  1902  	var device *Device
  1903  	var err error
  1904  
  1905  	if config.TunFileDescriptor > 0 {
  1906  		device, err = NewClientDeviceFromFD(config)
  1907  	} else {
  1908  		device, err = NewClientDevice(config)
  1909  	}
  1910  
  1911  	if err != nil {
  1912  		return nil, errors.Trace(err)
  1913  	}
  1914  
  1915  	upstreamPacketQueueSize := DEFAULT_UPSTREAM_PACKET_QUEUE_SIZE
  1916  	if config.UpstreamPacketQueueSize > 0 {
  1917  		upstreamPacketQueueSize = config.UpstreamPacketQueueSize
  1918  	}
  1919  
  1920  	runContext, stopRunning := context.WithCancel(context.Background())
  1921  
  1922  	return &Client{
  1923  		config:          config,
  1924  		device:          device,
  1925  		channel:         NewChannel(config.Transport, getMTU(config.MTU)),
  1926  		upstreamPackets: NewPacketQueue(upstreamPacketQueueSize),
  1927  		metrics:         new(packetMetrics),
  1928  		runContext:      runContext,
  1929  		stopRunning:     stopRunning,
  1930  		workers:         new(sync.WaitGroup),
  1931  	}, nil
  1932  }
  1933  
  1934  // Start starts a client and returns with it running.
  1935  func (client *Client) Start() {
  1936  
  1937  	client.config.Logger.WithTrace().Info("starting")
  1938  
  1939  	client.workers.Add(1)
  1940  	go func() {
  1941  		defer client.workers.Done()
  1942  		for {
  1943  			readPacket, err := client.device.ReadPacket()
  1944  
  1945  			select {
  1946  			case <-client.runContext.Done():
  1947  				// No error is logged as shutdown may have interrupted read.
  1948  				return
  1949  			default:
  1950  			}
  1951  
  1952  			if err != nil {
  1953  				client.config.Logger.WithTraceFields(
  1954  					common.LogFields{"error": err}).Info("read device packet failed")
  1955  				// May be temporary error condition, keep working.
  1956  				continue
  1957  			}
  1958  
  1959  			// processPacket will check for packets the server will reject
  1960  			// and drop those without sending.
  1961  
  1962  			// Limitation: packet metrics, including successful relay count,
  1963  			// are incremented _before_ the packet is written to the channel.
  1964  
  1965  			if !processPacket(
  1966  				client.metrics,
  1967  				nil,
  1968  				packetDirectionClientUpstream,
  1969  				readPacket) {
  1970  				continue
  1971  			}
  1972  
  1973  			// Instead of immediately writing to the channel, the
  1974  			// packet is enqueued, which has the effect of batching
  1975  			// up IP packets into a single channel packet (for Psiphon,
  1976  			// an SSH packet) to minimize overhead and, as benchmarked,
  1977  			// improve throughput.
  1978  			// Packet will be dropped if queue is full.
  1979  
  1980  			client.upstreamPackets.Enqueue(readPacket)
  1981  		}
  1982  	}()
  1983  
  1984  	client.workers.Add(1)
  1985  	go func() {
  1986  		defer client.workers.Done()
  1987  		for {
  1988  			packetBuffer, ok := client.upstreamPackets.DequeueFramedPackets(client.runContext)
  1989  			if !ok {
  1990  				// Dequeue aborted due to session.runContext.Done()
  1991  				return
  1992  			}
  1993  
  1994  			err := client.channel.WriteFramedPackets(packetBuffer)
  1995  
  1996  			client.upstreamPackets.Replace(packetBuffer)
  1997  
  1998  			if err != nil {
  1999  				client.config.Logger.WithTraceFields(
  2000  					common.LogFields{"error": err}).Info("write channel packets failed")
  2001  				// May be temporary error condition, such as reconnecting the tunnel;
  2002  				// keep working. The packets are most likely dropped.
  2003  				continue
  2004  			}
  2005  		}
  2006  	}()
  2007  
  2008  	client.workers.Add(1)
  2009  	go func() {
  2010  		defer client.workers.Done()
  2011  		for {
  2012  			readPacket, err := client.channel.ReadPacket()
  2013  
  2014  			select {
  2015  			case <-client.runContext.Done():
  2016  				// No error is logged as shutdown may have interrupted read.
  2017  				return
  2018  			default:
  2019  			}
  2020  
  2021  			if err != nil {
  2022  				client.config.Logger.WithTraceFields(
  2023  					common.LogFields{"error": err}).Info("read channel packet failed")
  2024  				// May be temporary error condition, such as reconnecting the tunnel;
  2025  				// keep working.
  2026  				continue
  2027  			}
  2028  
  2029  			if !processPacket(
  2030  				client.metrics,
  2031  				nil,
  2032  				packetDirectionClientDownstream,
  2033  				readPacket) {
  2034  				continue
  2035  			}
  2036  
  2037  			err = client.device.WritePacket(readPacket)
  2038  
  2039  			if err != nil {
  2040  				client.config.Logger.WithTraceFields(
  2041  					common.LogFields{"error": err}).Info("write device packet failed")
  2042  				// May be temporary error condition, keep working. The packet is
  2043  				// most likely dropped.
  2044  				continue
  2045  			}
  2046  		}
  2047  	}()
  2048  }
  2049  
  2050  // Stop halts a running client.
  2051  func (client *Client) Stop() {
  2052  
  2053  	client.config.Logger.WithTrace().Info("stopping")
  2054  
  2055  	client.stopRunning()
  2056  	client.device.Close()
  2057  	client.channel.Close()
  2058  
  2059  	client.workers.Wait()
  2060  
  2061  	client.metrics.checkpoint(
  2062  		client.config.Logger, nil, "packet_metrics", packetMetricsAll)
  2063  
  2064  	client.config.Logger.WithTrace().Info("stopped")
  2065  }
  2066  
  2067  /*
  2068     Packet offset constants in getPacketDestinationIPAddress and
  2069     processPacket are from the following RFC definitions.
  2070  
  2071  
  2072     IPv4 header: https://tools.ietf.org/html/rfc791
  2073  
  2074      0                   1                   2                   3
  2075      0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
  2076     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
  2077     |Version|  IHL  |Type of Service|          Total Length         |
  2078     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
  2079     |         Identification        |Flags|      Fragment Offset    |
  2080     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
  2081     |  Time to Live |    Protocol   |         Header Checksum       |
  2082     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
  2083     |                       Source Address                          |
  2084     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
  2085     |                    Destination Address                        |
  2086     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
  2087     |                    Options                    |    Padding    |
  2088     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
  2089  
  2090     IPv6 header: https://tools.ietf.org/html/rfc2460
  2091  
  2092      0                   1                   2                   3
  2093      0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
  2094     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
  2095     |Version| Traffic Class |           Flow Label                  |
  2096     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
  2097     |         Payload Length        |  Next Header  |   Hop Limit   |
  2098     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
  2099     |                                                               |
  2100     +                                                               +
  2101     |                                                               |
  2102     +                         Source Address                        +
  2103     |                                                               |
  2104     +                                                               +
  2105     |                                                               |
  2106     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
  2107     |                                                               |
  2108     +                                                               +
  2109     |                                                               |
  2110     +                      Destination Address                      +
  2111     |                                                               |
  2112     +                                                               +
  2113     |                                                               |
  2114     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
  2115  
  2116     TCP header: https://tools.ietf.org/html/rfc793
  2117  
  2118      0                   1                   2                   3
  2119      0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
  2120     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
  2121     |          Source Port          |       Destination Port        |
  2122     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
  2123     |                        Sequence Number                        |
  2124     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
  2125     |                    Acknowledgment Number                      |
  2126     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
  2127     |  Data |           |U|A|P|R|S|F|                               |
  2128     | Offset| Reserved  |R|C|S|S|Y|I|            Window             |
  2129     |       |           |G|K|H|T|N|N|                               |
  2130     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
  2131     |           Checksum            |         Urgent Pointer        |
  2132     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
  2133     |                    Options                    |    Padding    |
  2134     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
  2135     |                             data                              |
  2136     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
  2137  
  2138     UDP header: https://tools.ietf.org/html/rfc768
  2139  
  2140                    0      7 8     15 16    23 24    31
  2141                   +--------+--------+--------+--------+
  2142                   |     Source      |   Destination   |
  2143                   |      Port       |      Port       |
  2144                   +--------+--------+--------+--------+
  2145                   |                 |                 |
  2146                   |     Length      |    Checksum     |
  2147                   +--------+--------+--------+--------+
  2148                   |
  2149                   |          data octets ...
  2150                   +---------------- ...
  2151  */
  2152  
  2153  const (
  2154  	packetDirectionServerUpstream   = 0
  2155  	packetDirectionServerDownstream = 1
  2156  	packetDirectionClientUpstream   = 2
  2157  	packetDirectionClientDownstream = 3
  2158  
  2159  	internetProtocolTCP = 6
  2160  	internetProtocolUDP = 17
  2161  
  2162  	portNumberDNS = 53
  2163  
  2164  	packetRejectNoSession          = 0
  2165  	packetRejectDestinationAddress = 1
  2166  	packetRejectLength             = 2
  2167  	packetRejectVersion            = 3
  2168  	packetRejectOptions            = 4
  2169  	packetRejectProtocol           = 5
  2170  	packetRejectTCPProtocolLength  = 6
  2171  	packetRejectUDPProtocolLength  = 7
  2172  	packetRejectTCPPort            = 8
  2173  	packetRejectUDPPort            = 9
  2174  	packetRejectNoOriginalAddress  = 10
  2175  	packetRejectNoDNSResolvers     = 11
  2176  	packetRejectInvalidDNSMessage  = 12
  2177  	packetRejectDisallowedDomain   = 13
  2178  	packetRejectNoClient           = 14
  2179  	packetRejectReasonCount        = 15
  2180  	packetOk                       = 15
  2181  )
  2182  
  2183  type packetDirection int
  2184  type internetProtocol int
  2185  type packetRejectReason int
  2186  
  2187  func packetRejectReasonDescription(reason packetRejectReason) string {
  2188  
  2189  	// Description strings follow the metrics naming
  2190  	// convention: all lowercase; underscore seperators.
  2191  
  2192  	switch reason {
  2193  	case packetRejectNoSession:
  2194  		return "no_session"
  2195  	case packetRejectDestinationAddress:
  2196  		return "invalid_destination_address"
  2197  	case packetRejectLength:
  2198  		return "invalid_ip_packet_length"
  2199  	case packetRejectVersion:
  2200  		return "invalid_ip_header_version"
  2201  	case packetRejectOptions:
  2202  		return "invalid_ip_header_options"
  2203  	case packetRejectProtocol:
  2204  		return "invalid_ip_header_protocol"
  2205  	case packetRejectTCPProtocolLength:
  2206  		return "invalid_tcp_packet_length"
  2207  	case packetRejectUDPProtocolLength:
  2208  		return "invalid_tcp_packet_length"
  2209  	case packetRejectTCPPort:
  2210  		return "disallowed_tcp_destination_port"
  2211  	case packetRejectUDPPort:
  2212  		return "disallowed_udp_destination_port"
  2213  	case packetRejectNoOriginalAddress:
  2214  		return "no_original_address"
  2215  	case packetRejectNoDNSResolvers:
  2216  		return "no_dns_resolvers"
  2217  	case packetRejectInvalidDNSMessage:
  2218  		return "invalid_dns_message"
  2219  	case packetRejectDisallowedDomain:
  2220  		return "disallowed_domain"
  2221  	case packetRejectNoClient:
  2222  		return "no_client"
  2223  	}
  2224  
  2225  	return "unknown_reason"
  2226  }
  2227  
  2228  // Caller: the destination IP address return value is
  2229  // a slice of the packet input value and only valid while
  2230  // the packet buffer remains valid.
  2231  func getPacketDestinationIPAddress(
  2232  	metrics *packetMetrics,
  2233  	direction packetDirection,
  2234  	packet []byte) (net.IP, bool) {
  2235  
  2236  	// TODO: this function duplicates a subset of the packet
  2237  	// parsing code in processPacket. Refactor to reuse code;
  2238  	// also, both getPacketDestinationIPAddress and processPacket
  2239  	// are called for some packets; refactor to only parse once.
  2240  
  2241  	if len(packet) < 1 {
  2242  		metrics.rejectedPacket(direction, packetRejectLength)
  2243  		return nil, false
  2244  	}
  2245  
  2246  	version := packet[0] >> 4
  2247  
  2248  	if version != 4 && version != 6 {
  2249  		metrics.rejectedPacket(direction, packetRejectVersion)
  2250  		return nil, false
  2251  	}
  2252  
  2253  	if version == 4 {
  2254  		if len(packet) < 20 {
  2255  			metrics.rejectedPacket(direction, packetRejectLength)
  2256  			return nil, false
  2257  		}
  2258  
  2259  		return packet[16:20], true
  2260  
  2261  	} else { // IPv6
  2262  		if len(packet) < 40 {
  2263  			metrics.rejectedPacket(direction, packetRejectLength)
  2264  			return nil, false
  2265  		}
  2266  
  2267  		return packet[24:40], true
  2268  	}
  2269  }
  2270  
  2271  // processPacket parses IP packets, applies relaying rules,
  2272  // and rewrites packet elements as required. processPacket
  2273  // returns true if a packet parses correctly, is accepted
  2274  // by the relay rules, and is successfully rewritten.
  2275  //
  2276  // When a packet is rejected, processPacket returns false
  2277  // and updates a reason in the supplied metrics.
  2278  //
  2279  // Rejection may result in partially rewritten packets.
  2280  func processPacket(
  2281  	metrics *packetMetrics,
  2282  	session *session,
  2283  	direction packetDirection,
  2284  	packet []byte) bool {
  2285  
  2286  	// Parse and validate IP packet structure
  2287  
  2288  	// Must have an IP version field.
  2289  	if len(packet) < 1 {
  2290  		metrics.rejectedPacket(direction, packetRejectLength)
  2291  		return false
  2292  	}
  2293  
  2294  	version := packet[0] >> 4
  2295  
  2296  	// Must be IPv4 or IPv6.
  2297  	if version != 4 && version != 6 {
  2298  		metrics.rejectedPacket(direction, packetRejectVersion)
  2299  		return false
  2300  	}
  2301  
  2302  	var protocol internetProtocol
  2303  	var sourceIPAddress, destinationIPAddress net.IP
  2304  	var sourcePort, destinationPort uint16
  2305  	var IPChecksum, TCPChecksum, UDPChecksum []byte
  2306  	var applicationData []byte
  2307  
  2308  	if version == 4 {
  2309  
  2310  		// IHL must be 5: options are not supported; a fixed
  2311  		// 20 byte header is expected.
  2312  
  2313  		headerLength := packet[0] & 0x0F
  2314  
  2315  		if headerLength != 5 {
  2316  			metrics.rejectedPacket(direction, packetRejectOptions)
  2317  			return false
  2318  		}
  2319  
  2320  		if len(packet) < 20 {
  2321  			metrics.rejectedPacket(direction, packetRejectLength)
  2322  			return false
  2323  		}
  2324  
  2325  		// Protocol must be TCP or UDP.
  2326  
  2327  		protocol = internetProtocol(packet[9])
  2328  		dataOffset := 0
  2329  
  2330  		if protocol == internetProtocolTCP {
  2331  			if len(packet) < 33 {
  2332  				metrics.rejectedPacket(direction, packetRejectTCPProtocolLength)
  2333  				return false
  2334  			}
  2335  			dataOffset = 20 + 4*int(packet[32]>>4)
  2336  			if len(packet) < dataOffset {
  2337  				metrics.rejectedPacket(direction, packetRejectTCPProtocolLength)
  2338  				return false
  2339  			}
  2340  		} else if protocol == internetProtocolUDP {
  2341  			dataOffset = 28
  2342  			if len(packet) < dataOffset {
  2343  				metrics.rejectedPacket(direction, packetRejectUDPProtocolLength)
  2344  				return false
  2345  			}
  2346  		} else {
  2347  			metrics.rejectedPacket(direction, packetRejectProtocol)
  2348  			return false
  2349  		}
  2350  
  2351  		applicationData = packet[dataOffset:]
  2352  
  2353  		// Slices reference packet bytes to be rewritten.
  2354  
  2355  		sourceIPAddress = packet[12:16]
  2356  		destinationIPAddress = packet[16:20]
  2357  		IPChecksum = packet[10:12]
  2358  
  2359  		// Port numbers have the same offset in TCP and UDP.
  2360  
  2361  		sourcePort = binary.BigEndian.Uint16(packet[20:22])
  2362  		destinationPort = binary.BigEndian.Uint16(packet[22:24])
  2363  
  2364  		if protocol == internetProtocolTCP {
  2365  			TCPChecksum = packet[36:38]
  2366  		} else { // UDP
  2367  			UDPChecksum = packet[26:28]
  2368  		}
  2369  
  2370  	} else { // IPv6
  2371  
  2372  		if len(packet) < 40 {
  2373  			metrics.rejectedPacket(direction, packetRejectLength)
  2374  			return false
  2375  		}
  2376  
  2377  		// Next Header must be TCP or UDP.
  2378  
  2379  		nextHeader := packet[6]
  2380  
  2381  		protocol = internetProtocol(nextHeader)
  2382  		dataOffset := 0
  2383  
  2384  		if protocol == internetProtocolTCP {
  2385  			if len(packet) < 53 {
  2386  				metrics.rejectedPacket(direction, packetRejectTCPProtocolLength)
  2387  				return false
  2388  			}
  2389  			dataOffset = 40 + 4*int(packet[52]>>4)
  2390  			if len(packet) < dataOffset {
  2391  				metrics.rejectedPacket(direction, packetRejectTCPProtocolLength)
  2392  				return false
  2393  			}
  2394  		} else if protocol == internetProtocolUDP {
  2395  			dataOffset = 48
  2396  			if len(packet) < dataOffset {
  2397  				metrics.rejectedPacket(direction, packetRejectUDPProtocolLength)
  2398  				return false
  2399  			}
  2400  		} else {
  2401  			metrics.rejectedPacket(direction, packetRejectProtocol)
  2402  			return false
  2403  		}
  2404  
  2405  		applicationData = packet[dataOffset:]
  2406  
  2407  		// Slices reference packet bytes to be rewritten.
  2408  
  2409  		sourceIPAddress = packet[8:24]
  2410  		destinationIPAddress = packet[24:40]
  2411  
  2412  		// Port numbers have the same offset in TCP and UDP.
  2413  
  2414  		sourcePort = binary.BigEndian.Uint16(packet[40:42])
  2415  		destinationPort = binary.BigEndian.Uint16(packet[42:44])
  2416  
  2417  		if protocol == internetProtocolTCP {
  2418  			TCPChecksum = packet[56:58]
  2419  		} else { // UDP
  2420  			UDPChecksum = packet[46:48]
  2421  		}
  2422  	}
  2423  
  2424  	// Apply rules
  2425  	//
  2426  	// Most of this logic is only applied on the server, as only
  2427  	// the server knows the traffic rules configuration, and is
  2428  	// tracking flows.
  2429  
  2430  	isServer := (direction == packetDirectionServerUpstream ||
  2431  		direction == packetDirectionServerDownstream)
  2432  
  2433  	// Check if the packet qualifies for transparent DNS rewriting
  2434  	//
  2435  	// - Both TCP and UDP DNS packets may qualify
  2436  	// - Unless configured, transparent DNS flows are not tracked,
  2437  	//   as most DNS resolutions are very-short lived exchanges
  2438  	// - The traffic rules checks are bypassed, since transparent
  2439  	//   DNS is essential
  2440  
  2441  	doTransparentDNS := false
  2442  
  2443  	if isServer {
  2444  		if direction == packetDirectionServerUpstream {
  2445  
  2446  			// DNS packets destinated for the transparent DNS target addresses
  2447  			// will be rewritten to go to one of the server's resolvers.
  2448  
  2449  			if destinationPort == portNumberDNS {
  2450  				if version == 4 &&
  2451  					destinationIPAddress.Equal(transparentDNSResolverIPv4Address) {
  2452  
  2453  					numResolvers := len(session.DNSResolverIPv4Addresses)
  2454  					if numResolvers > 0 {
  2455  						doTransparentDNS = true
  2456  					} else {
  2457  						metrics.rejectedPacket(direction, packetRejectNoDNSResolvers)
  2458  						return false
  2459  					}
  2460  
  2461  				} else if version == 6 &&
  2462  					destinationIPAddress.Equal(transparentDNSResolverIPv6Address) {
  2463  
  2464  					numResolvers := len(session.DNSResolverIPv6Addresses)
  2465  					if numResolvers > 0 {
  2466  						doTransparentDNS = true
  2467  					} else {
  2468  						metrics.rejectedPacket(direction, packetRejectNoDNSResolvers)
  2469  						return false
  2470  					}
  2471  				}
  2472  
  2473  				// Limitation: checkAllowedDomainFunc is applied only to DNS queries in
  2474  				// UDP; currently DNS-over-TCP will bypass the domain block list check.
  2475  
  2476  				if doTransparentDNS && protocol == internetProtocolUDP {
  2477  
  2478  					domain, err := common.ParseDNSQuestion(applicationData)
  2479  					if err != nil {
  2480  						metrics.rejectedPacket(direction, packetRejectInvalidDNSMessage)
  2481  						return false
  2482  					}
  2483  					if domain != "" {
  2484  						checkAllowedDomainFunc := session.getCheckAllowedDomainFunc()
  2485  						if !checkAllowedDomainFunc(domain) {
  2486  							metrics.rejectedPacket(direction, packetRejectDisallowedDomain)
  2487  							return false
  2488  						}
  2489  					}
  2490  				}
  2491  			}
  2492  
  2493  		} else { // packetDirectionServerDownstream
  2494  
  2495  			// DNS packets with a source address of any of the server's
  2496  			// resolvers will be rewritten back to the transparent DNS target
  2497  			// address.
  2498  
  2499  			// Limitation: responses to client DNS packets _originally
  2500  			// destined_ for a resolver in GetDNSResolverIPv4Addresses will
  2501  			// be lost. This would happen if some process on the client
  2502  			// ignores the system set DNS values; and forces use of the same
  2503  			// resolvers as the server.
  2504  
  2505  			if sourcePort == portNumberDNS {
  2506  				if version == 4 {
  2507  					for _, IPAddress := range session.DNSResolverIPv4Addresses {
  2508  						if sourceIPAddress.Equal(IPAddress) {
  2509  							doTransparentDNS = true
  2510  							break
  2511  						}
  2512  					}
  2513  				} else if version == 6 {
  2514  					for _, IPAddress := range session.DNSResolverIPv6Addresses {
  2515  						if sourceIPAddress.Equal(IPAddress) {
  2516  							doTransparentDNS = true
  2517  							break
  2518  						}
  2519  					}
  2520  				}
  2521  			}
  2522  		}
  2523  	}
  2524  
  2525  	// Apply rewrites before determining flow ID to ensure that corresponding up-
  2526  	// and downstream flows yield the same flow ID.
  2527  
  2528  	var rewriteSourceIPAddress, rewriteDestinationIPAddress net.IP
  2529  
  2530  	if direction == packetDirectionServerUpstream {
  2531  
  2532  		// Store original source IP address to be replaced in
  2533  		// downstream rewriting.
  2534  
  2535  		if version == 4 {
  2536  			session.setOriginalIPv4AddressIfNotSet(sourceIPAddress)
  2537  			rewriteSourceIPAddress = session.assignedIPv4Address
  2538  		} else { // version == 6
  2539  			session.setOriginalIPv6AddressIfNotSet(sourceIPAddress)
  2540  			rewriteSourceIPAddress = session.assignedIPv6Address
  2541  		}
  2542  
  2543  		// Rewrite DNS packets destinated for the transparent DNS target addresses
  2544  		// to go to one of the server's resolvers. This random selection uses
  2545  		// math/rand to minimize overhead.
  2546  		//
  2547  		// Limitation: TCP packets are always assigned to the same resolver, as
  2548  		// currently there is no method for tracking the assigned resolver per TCP
  2549  		// flow.
  2550  
  2551  		if doTransparentDNS {
  2552  			if version == 4 {
  2553  
  2554  				index := session.TCPDNSResolverIPv4Index
  2555  				if protocol == internetProtocolUDP {
  2556  					index = rand.Intn(len(session.DNSResolverIPv4Addresses))
  2557  				}
  2558  				rewriteDestinationIPAddress = session.DNSResolverIPv4Addresses[index]
  2559  
  2560  			} else { // version == 6
  2561  
  2562  				index := session.TCPDNSResolverIPv6Index
  2563  				if protocol == internetProtocolUDP {
  2564  					index = rand.Intn(len(session.DNSResolverIPv6Addresses))
  2565  				}
  2566  				rewriteDestinationIPAddress = session.DNSResolverIPv6Addresses[index]
  2567  			}
  2568  		}
  2569  
  2570  	} else if direction == packetDirectionServerDownstream {
  2571  
  2572  		// Destination address will be original source address.
  2573  
  2574  		if version == 4 {
  2575  			rewriteDestinationIPAddress = session.getOriginalIPv4Address()
  2576  		} else { // version == 6
  2577  			rewriteDestinationIPAddress = session.getOriginalIPv6Address()
  2578  		}
  2579  
  2580  		if rewriteDestinationIPAddress == nil {
  2581  			metrics.rejectedPacket(direction, packetRejectNoOriginalAddress)
  2582  			return false
  2583  		}
  2584  
  2585  		// Rewrite source address of packets from servers' resolvers
  2586  		// to transparent DNS target address.
  2587  
  2588  		if doTransparentDNS {
  2589  
  2590  			if version == 4 {
  2591  				rewriteSourceIPAddress = transparentDNSResolverIPv4Address
  2592  			} else { // version == 6
  2593  				rewriteSourceIPAddress = transparentDNSResolverIPv6Address
  2594  			}
  2595  		}
  2596  	}
  2597  
  2598  	// Check if flow is tracked before checking traffic permission
  2599  
  2600  	doFlowTracking := isServer && (!doTransparentDNS || session.enableDNSFlowTracking)
  2601  
  2602  	// TODO: verify this struct is stack allocated
  2603  	var ID flowID
  2604  
  2605  	isTrackingFlow := false
  2606  
  2607  	if doFlowTracking {
  2608  
  2609  		if direction == packetDirectionServerUpstream {
  2610  
  2611  			// Reflect rewrites in the upstream case and don't reflect rewrites in the
  2612  			// following downstream case: all flow IDs are in the upstream space, with
  2613  			// the assigned private IP for the client and, in the case of DNS, the
  2614  			// actual resolver IP.
  2615  
  2616  			srcIP := sourceIPAddress
  2617  			if rewriteSourceIPAddress != nil {
  2618  				srcIP = rewriteSourceIPAddress
  2619  			}
  2620  
  2621  			destIP := destinationIPAddress
  2622  			if rewriteDestinationIPAddress != nil {
  2623  				destIP = rewriteDestinationIPAddress
  2624  			}
  2625  
  2626  			ID.set(srcIP, sourcePort, destIP, destinationPort, protocol)
  2627  
  2628  		} else if direction == packetDirectionServerDownstream {
  2629  
  2630  			ID.set(
  2631  				destinationIPAddress,
  2632  				destinationPort,
  2633  				sourceIPAddress,
  2634  				sourcePort,
  2635  				protocol)
  2636  		}
  2637  
  2638  		isTrackingFlow = session.isTrackingFlow(ID)
  2639  	}
  2640  
  2641  	// Check packet source/destination is permitted; except for:
  2642  	// - existing flows, which have already been checked
  2643  	// - transparent DNS, which is always allowed
  2644  
  2645  	if !doTransparentDNS && !isTrackingFlow {
  2646  
  2647  		// Enforce traffic rules (allowed TCP/UDP ports).
  2648  
  2649  		checkPort := 0
  2650  		if direction == packetDirectionServerUpstream ||
  2651  			direction == packetDirectionClientUpstream {
  2652  
  2653  			checkPort = int(destinationPort)
  2654  
  2655  		} else if direction == packetDirectionServerDownstream ||
  2656  			direction == packetDirectionClientDownstream {
  2657  
  2658  			checkPort = int(sourcePort)
  2659  		}
  2660  
  2661  		if protocol == internetProtocolTCP {
  2662  
  2663  			invalidPort := (checkPort == 0)
  2664  
  2665  			if !invalidPort && isServer {
  2666  				checkAllowedTCPPortFunc := session.getCheckAllowedTCPPortFunc()
  2667  				if checkAllowedTCPPortFunc == nil ||
  2668  					!checkAllowedTCPPortFunc(net.IP(ID.upstreamIPAddress[:]), checkPort) {
  2669  					invalidPort = true
  2670  				}
  2671  			}
  2672  
  2673  			if invalidPort {
  2674  				metrics.rejectedPacket(direction, packetRejectTCPPort)
  2675  				return false
  2676  			}
  2677  
  2678  		} else if protocol == internetProtocolUDP {
  2679  
  2680  			invalidPort := (checkPort == 0)
  2681  
  2682  			if !invalidPort && isServer {
  2683  				checkAllowedUDPPortFunc := session.getCheckAllowedUDPPortFunc()
  2684  				if checkAllowedUDPPortFunc == nil ||
  2685  					!checkAllowedUDPPortFunc(net.IP(ID.upstreamIPAddress[:]), checkPort) {
  2686  					invalidPort = true
  2687  				}
  2688  			}
  2689  
  2690  			if invalidPort {
  2691  				metrics.rejectedPacket(direction, packetRejectUDPPort)
  2692  				return false
  2693  			}
  2694  		}
  2695  
  2696  		// Enforce no localhost, multicast or broadcast packets; and no
  2697  		// client-to-client packets.
  2698  		//
  2699  		// TODO: a client-side check could check that destination IP
  2700  		// is strictly a tun device IP address.
  2701  
  2702  		if !destinationIPAddress.IsGlobalUnicast() ||
  2703  
  2704  			(direction == packetDirectionServerUpstream &&
  2705  				!session.allowBogons &&
  2706  				common.IsBogon(destinationIPAddress)) ||
  2707  
  2708  			// Client-to-client packets are disallowed even when other bogons are
  2709  			// allowed.
  2710  			(direction == packetDirectionServerUpstream &&
  2711  				((version == 4 &&
  2712  					!destinationIPAddress.Equal(transparentDNSResolverIPv4Address) &&
  2713  					privateSubnetIPv4.Contains(destinationIPAddress)) ||
  2714  					(version == 6 &&
  2715  						!destinationIPAddress.Equal(transparentDNSResolverIPv6Address) &&
  2716  						privateSubnetIPv6.Contains(destinationIPAddress)))) {
  2717  
  2718  			metrics.rejectedPacket(direction, packetRejectDestinationAddress)
  2719  			return false
  2720  		}
  2721  	}
  2722  
  2723  	// Apply packet rewrites. IP (v4 only) and TCP/UDP all have packet
  2724  	// checksums which are updated to relect the rewritten headers.
  2725  
  2726  	var checksumAccumulator int32
  2727  
  2728  	if rewriteSourceIPAddress != nil {
  2729  		checksumAccumulate(sourceIPAddress, false, &checksumAccumulator)
  2730  		copy(sourceIPAddress, rewriteSourceIPAddress)
  2731  		checksumAccumulate(sourceIPAddress, true, &checksumAccumulator)
  2732  	}
  2733  
  2734  	if rewriteDestinationIPAddress != nil {
  2735  		checksumAccumulate(destinationIPAddress, false, &checksumAccumulator)
  2736  		copy(destinationIPAddress, rewriteDestinationIPAddress)
  2737  		checksumAccumulate(destinationIPAddress, true, &checksumAccumulator)
  2738  	}
  2739  
  2740  	if rewriteSourceIPAddress != nil || rewriteDestinationIPAddress != nil {
  2741  
  2742  		// IPv6 doesn't have an IP header checksum.
  2743  		if version == 4 {
  2744  			checksumAdjust(IPChecksum, checksumAccumulator)
  2745  		}
  2746  
  2747  		if protocol == internetProtocolTCP {
  2748  			checksumAdjust(TCPChecksum, checksumAccumulator)
  2749  		} else { // UDP
  2750  			checksumAdjust(UDPChecksum, checksumAccumulator)
  2751  		}
  2752  	}
  2753  
  2754  	// Start/update flow tracking, only once past all possible packet rejects
  2755  
  2756  	if doFlowTracking {
  2757  		if !isTrackingFlow {
  2758  			session.startTrackingFlow(ID, direction, applicationData, doTransparentDNS)
  2759  		} else {
  2760  			session.updateFlow(ID, direction, applicationData)
  2761  		}
  2762  	}
  2763  
  2764  	metrics.relayedPacket(direction, int(version), protocol, len(packet), len(applicationData))
  2765  
  2766  	return true
  2767  }
  2768  
  2769  // Checksum code based on https://github.com/OpenVPN/openvpn:
  2770  /*
  2771  OpenVPN (TM) -- An Open Source VPN daemon
  2772  
  2773  Copyright (C) 2002-2017 OpenVPN Technologies, Inc. <sales@openvpn.net>
  2774  
  2775  OpenVPN license:
  2776  ----------------
  2777  
  2778  OpenVPN is distributed under the GPL license version 2 (see COPYRIGHT.GPL).
  2779  */
  2780  
  2781  func checksumAccumulate(data []byte, newData bool, accumulator *int32) {
  2782  
  2783  	// Based on ADD_CHECKSUM_32 and SUB_CHECKSUM_32 macros from OpenVPN:
  2784  	// https://github.com/OpenVPN/openvpn/blob/58716979640b5d8850b39820f91da616964398cc/src/openvpn/proto.h#L177
  2785  
  2786  	// Assumes length of data is factor of 4.
  2787  
  2788  	for i := 0; i < len(data); i += 4 {
  2789  		word := uint32(data[i+0])<<24 | uint32(data[i+1])<<16 | uint32(data[i+2])<<8 | uint32(data[i+3])
  2790  		if newData {
  2791  			*accumulator -= int32(word & 0xFFFF)
  2792  			*accumulator -= int32(word >> 16)
  2793  		} else {
  2794  			*accumulator += int32(word & 0xFFFF)
  2795  			*accumulator += int32(word >> 16)
  2796  		}
  2797  	}
  2798  }
  2799  
  2800  func checksumAdjust(checksumData []byte, accumulator int32) {
  2801  
  2802  	// Based on ADJUST_CHECKSUM macro from OpenVPN:
  2803  	// https://github.com/OpenVPN/openvpn/blob/58716979640b5d8850b39820f91da616964398cc/src/openvpn/proto.h#L177
  2804  
  2805  	// Assumes checksumData is 2 byte slice.
  2806  
  2807  	checksum := uint16(checksumData[0])<<8 | uint16(checksumData[1])
  2808  
  2809  	accumulator += int32(checksum)
  2810  	if accumulator < 0 {
  2811  		accumulator = -accumulator
  2812  		accumulator = (accumulator >> 16) + (accumulator & 0xFFFF)
  2813  		accumulator += accumulator >> 16
  2814  		checksum = uint16(^accumulator)
  2815  	} else {
  2816  		accumulator = (accumulator >> 16) + (accumulator & 0xFFFF)
  2817  		accumulator += accumulator >> 16
  2818  		checksum = uint16(accumulator)
  2819  	}
  2820  
  2821  	checksumData[0] = byte(checksum >> 8)
  2822  	checksumData[1] = byte(checksum & 0xFF)
  2823  }
  2824  
  2825  /*
  2826  
  2827  packet debugging snippet:
  2828  
  2829  	import (
  2830          "github.com/google/gopacket"
  2831          "github.com/google/gopacket/layers"
  2832  	)
  2833  
  2834  
  2835  	func tracePacket(where string, packet []byte) {
  2836  		var p gopacket.Packet
  2837  		if len(packet) > 0 && packet[0]>>4 == 4 {
  2838  			p = gopacket.NewPacket(packet, layers.LayerTypeIPv4, gopacket.Default)
  2839  		} else {
  2840  			p = gopacket.NewPacket(packet, layers.LayerTypeIPv6, gopacket.Default)
  2841  		}
  2842  		fmt.Printf("[%s packet]:\n%s\n\n", where, p)
  2843  	}
  2844  */
  2845  
// Device manages a tun device. It handles packet I/O using static,
// preallocated buffers to avoid GC churn.
type Device struct {
	name           string             // interface name returned by Name; "" when wrapping an existing file descriptor
	writeMutex     sync.Mutex         // serializes WritePacket calls, which share outboundBuffer
	deviceIO       io.ReadWriteCloser // underlying tun device handle; closed by Close
	inboundBuffer  []byte             // reused read buffer; ReadPacket returns a slice of it
	outboundBuffer []byte             // reused write buffer; only used while holding writeMutex
}
  2855  
  2856  // NewServerDevice creates and configures a new server tun device.
  2857  // Since the server uses fixed address spaces, only one server
  2858  // device may exist per host.
  2859  func NewServerDevice(config *ServerConfig) (*Device, error) {
  2860  
  2861  	file, deviceName, err := OpenTunDevice("")
  2862  	if err != nil {
  2863  		return nil, errors.Trace(err)
  2864  	}
  2865  
  2866  	err = configureServerInterface(config, deviceName)
  2867  	if err != nil {
  2868  		_ = file.Close()
  2869  		return nil, errors.Trace(err)
  2870  	}
  2871  
  2872  	return newDevice(
  2873  		deviceName,
  2874  		file,
  2875  		getMTU(config.MTU)), nil
  2876  }
  2877  
  2878  // NewClientDevice creates and configures a new client tun device.
  2879  // Multiple client tun devices may exist per host.
  2880  func NewClientDevice(config *ClientConfig) (*Device, error) {
  2881  
  2882  	file, deviceName, err := OpenTunDevice("")
  2883  	if err != nil {
  2884  		return nil, errors.Trace(err)
  2885  	}
  2886  
  2887  	err = configureClientInterface(
  2888  		config, deviceName)
  2889  	if err != nil {
  2890  		_ = file.Close()
  2891  		return nil, errors.Trace(err)
  2892  	}
  2893  
  2894  	return newDevice(
  2895  		deviceName,
  2896  		file,
  2897  		getMTU(config.MTU)), nil
  2898  }
  2899  
  2900  func newDevice(
  2901  	name string,
  2902  	deviceIO io.ReadWriteCloser,
  2903  	MTU int) *Device {
  2904  
  2905  	return &Device{
  2906  		name:           name,
  2907  		deviceIO:       deviceIO,
  2908  		inboundBuffer:  makeDeviceInboundBuffer(MTU),
  2909  		outboundBuffer: makeDeviceOutboundBuffer(MTU),
  2910  	}
  2911  }
  2912  
  2913  // NewClientDeviceFromFD wraps an existing tun device.
  2914  func NewClientDeviceFromFD(config *ClientConfig) (*Device, error) {
  2915  
  2916  	file, err := fileFromFD(config.TunFileDescriptor, "")
  2917  	if err != nil {
  2918  		return nil, errors.Trace(err)
  2919  	}
  2920  
  2921  	MTU := getMTU(config.MTU)
  2922  
  2923  	return &Device{
  2924  		name:           "",
  2925  		deviceIO:       file,
  2926  		inboundBuffer:  makeDeviceInboundBuffer(MTU),
  2927  		outboundBuffer: makeDeviceOutboundBuffer(MTU),
  2928  	}, nil
  2929  }
  2930  
// Name returns the interface name for a created tun device,
// or returns "" for a device created by NewClientDeviceFromFD.
// The interface name may be used for additional network and
// routing configuration.
//
// The value is fixed when the Device is constructed.
func (device *Device) Name() string {
	return device.name
}
  2938  
  2939  // ReadPacket reads one full packet from the tun device. The
  2940  // return value is a slice of a static, reused buffer, so the
  2941  // value is only valid until the next ReadPacket call.
  2942  // Concurrent calls to ReadPacket are _not_ supported.
  2943  func (device *Device) ReadPacket() ([]byte, error) {
  2944  
  2945  	// readTunPacket performs the platform dependent
  2946  	// packet read operation.
  2947  	offset, size, err := device.readTunPacket()
  2948  	if err != nil {
  2949  		return nil, errors.Trace(err)
  2950  	}
  2951  
  2952  	return device.inboundBuffer[offset : offset+size], nil
  2953  }
  2954  
  2955  // WritePacket writes one full packet to the tun device.
  2956  // Concurrent calls to WritePacket are supported.
  2957  func (device *Device) WritePacket(packet []byte) error {
  2958  
  2959  	// This mutex ensures that only one concurrent goroutine
  2960  	// can use outboundBuffer when writing.
  2961  	device.writeMutex.Lock()
  2962  	defer device.writeMutex.Unlock()
  2963  
  2964  	// writeTunPacket performs the platform dependent
  2965  	// packet write operation.
  2966  	err := device.writeTunPacket(packet)
  2967  	if err != nil {
  2968  		return errors.Trace(err)
  2969  	}
  2970  
  2971  	return nil
  2972  }
  2973  
// Close interrupts any blocking Read/Write calls and
// tears down the tun device.
//
// Close delegates to the underlying device I/O handle and
// returns that handle's Close error unchanged.
func (device *Device) Close() error {
	return device.deviceIO.Close()
}
  2979  
// Channel manages packet transport over a communications channel.
// Any io.ReadWriteCloser can provide transport. In psiphond, the
// io.ReadWriteCloser will be an SSH channel. Channel I/O frames
// packets with a length header and uses static, preallocated
// buffers to avoid GC churn.
type Channel struct {
	transport      io.ReadWriteCloser // carries the length-framed packet stream; closed by Close
	inboundBuffer  []byte             // reused read buffer: length header followed by one packet
	outboundBuffer []byte             // reused write buffer: length header followed by one packet
}
  2990  
// IP packets cannot be larger than 64K, so a 16-bit length
// header is sufficient.
const (
	// channelHeaderSize is the size, in bytes, of the big-endian
	// length header that precedes each packet on a Channel.
	channelHeaderSize = 2
)
  2996  
  2997  // NewChannel initializes a new Channel.
  2998  func NewChannel(transport io.ReadWriteCloser, MTU int) *Channel {
  2999  	return &Channel{
  3000  		transport:      transport,
  3001  		inboundBuffer:  make([]byte, channelHeaderSize+MTU),
  3002  		outboundBuffer: make([]byte, channelHeaderSize+MTU),
  3003  	}
  3004  }
  3005  
  3006  // ReadPacket reads one full packet from the channel. The
  3007  // return value is a slice of a static, reused buffer, so the
  3008  // value is only valid until the next ReadPacket call.
  3009  // Concurrent calls to ReadPacket are not supported.
  3010  func (channel *Channel) ReadPacket() ([]byte, error) {
  3011  
  3012  	header := channel.inboundBuffer[0:channelHeaderSize]
  3013  	_, err := io.ReadFull(channel.transport, header)
  3014  	if err != nil {
  3015  		return nil, errors.Trace(err)
  3016  	}
  3017  
  3018  	size := int(binary.BigEndian.Uint16(header))
  3019  	if size > len(channel.inboundBuffer[channelHeaderSize:]) {
  3020  		return nil, errors.Tracef("packet size exceeds MTU: %d", size)
  3021  	}
  3022  
  3023  	packet := channel.inboundBuffer[channelHeaderSize : channelHeaderSize+size]
  3024  	_, err = io.ReadFull(channel.transport, packet)
  3025  	if err != nil {
  3026  		return nil, errors.Trace(err)
  3027  	}
  3028  
  3029  	return packet, nil
  3030  }
  3031  
  3032  // WritePacket writes one full packet to the channel.
  3033  // Concurrent calls to WritePacket are not supported.
  3034  func (channel *Channel) WritePacket(packet []byte) error {
  3035  
  3036  	// Flow control assumed to be provided by the transport. In the case
  3037  	// of SSH, the channel window size will determine whether the packet
  3038  	// data is transmitted immediately or whether the transport.Write will
  3039  	// block. When the channel window is full and transport.Write blocks,
  3040  	// the sender's tun device will not be read (client case) or the send
  3041  	// queue will fill (server case) and packets will be dropped. In this
  3042  	// way, the channel window size will influence the TCP window size for
  3043  	// tunneled traffic.
  3044  
  3045  	// When the transport is an SSH channel, the overhead per packet message
  3046  	// includes:
  3047  	//
  3048  	// - SSH_MSG_CHANNEL_DATA: 5 bytes (https://tools.ietf.org/html/rfc4254#section-5.2)
  3049  	// - SSH packet: ~28 bytes (https://tools.ietf.org/html/rfc4253#section-5.3), with MAC
  3050  	// - TCP/IP transport for SSH: 40 bytes for IPv4
  3051  
  3052  	// Assumes MTU <= 64K and len(packet) <= MTU
  3053  
  3054  	size := len(packet)
  3055  	binary.BigEndian.PutUint16(channel.outboundBuffer, uint16(size))
  3056  	copy(channel.outboundBuffer[channelHeaderSize:], packet)
  3057  	_, err := channel.transport.Write(channel.outboundBuffer[0 : channelHeaderSize+size])
  3058  	if err != nil {
  3059  		return errors.Trace(err)
  3060  	}
  3061  
  3062  	return nil
  3063  }
  3064  
  3065  // WriteFramedPackets writes a buffer of pre-framed packets to
  3066  // the channel.
  3067  // Concurrent calls to WriteFramedPackets are not supported.
  3068  func (channel *Channel) WriteFramedPackets(packetBuffer []byte) error {
  3069  	_, err := channel.transport.Write(packetBuffer)
  3070  	if err != nil {
  3071  		return errors.Trace(err)
  3072  	}
  3073  	return nil
  3074  }
  3075  
// Close interrupts any blocking Read/Write calls and
// closes the channel transport.
//
// Close delegates to the transport and returns the
// transport's Close error unchanged.
func (channel *Channel) Close() error {
	return channel.transport.Close()
}