github.com/psiphon-Labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/portlist.go (about)

     1  /*
     2   * Copyright (c) 2021, Psiphon Inc.
     3   * All rights reserved.
     4   *
     5   * This program is free software: you can redistribute it and/or modify
     6   * it under the terms of the GNU General Public License as published by
     7   * the Free Software Foundation, either version 3 of the License, or
     8   * (at your option) any later version.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package common
    21  
    22  import (
    23  	"bytes"
    24  	"encoding/json"
    25  	"strconv"
    26  
    27  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
    28  )
    29  
// PortList provides a lookup for a configured list of IP ports and port
// ranges. PortList is intended for use with JSON config files and is
// initialized via UnmarshalJSON.
//
// A JSON port list field should look like:
//
// "FieldName": [1, 2, 3, [10, 20], [30, 40]]
//
// where the ports in the list are 1, 2, 3, 10-20, 30-40. UnmarshalJSON
// validates that each port is in the range 1-65535 and that ranges have two
// elements in increasing order. PortList is designed to be backwards
// compatible with existing JSON config files where port list fields were
// defined as `[]int`.
//
// Internally, every list entry is kept as an inclusive [start, end] pair; a
// single port is stored as a pair with start == end. The optional lookup map
// is only populated by OptimizeLookups and, once non-nil, is what Lookup
// consults instead of scanning the pairs.
type PortList struct {
	portRanges [][2]int     // inclusive [start, end] pairs, one per list entry
	lookup     map[int]bool // set of listed ports; nil unless built by OptimizeLookups
}
    47  
    48  const lookupThreshold = 10
    49  
    50  // OptimizeLookups converts the internal port list representation to use a
    51  // map, which increases the performance of lookups for longer lists with an
    52  // increased memory footprint tradeoff. OptimizeLookups is not safe to use
    53  // concurrently with Lookup and should be called immediately after
    54  // UnmarshalJSON and before performing lookups.
    55  func (p *PortList) OptimizeLookups() {
    56  	if p == nil {
    57  		return
    58  	}
    59  	// TODO: does the threshold take long ranges into account?
    60  	if len(p.portRanges) > lookupThreshold {
    61  		p.lookup = make(map[int]bool)
    62  		for _, portRange := range p.portRanges {
    63  			for i := portRange[0]; i <= portRange[1]; i++ {
    64  				p.lookup[i] = true
    65  			}
    66  		}
    67  	}
    68  }
    69  
    70  // IsEmpty returns true for a nil PortList or a PortList with no entries.
    71  func (p *PortList) IsEmpty() bool {
    72  	if p == nil {
    73  		return true
    74  	}
    75  	return len(p.portRanges) == 0
    76  }
    77  
    78  // Lookup returns true if the specified port is in the port list and false
    79  // otherwise. Lookups on a nil PortList are allowed and return false.
    80  func (p *PortList) Lookup(port int) bool {
    81  	if p == nil {
    82  		return false
    83  	}
    84  	if p.lookup != nil {
    85  		return p.lookup[port]
    86  	}
    87  	for _, portRange := range p.portRanges {
    88  		if port >= portRange[0] && port <= portRange[1] {
    89  			return true
    90  		}
    91  	}
    92  	return false
    93  }
    94  
    95  // UnmarshalJSON implements the json.Unmarshaler interface.
    96  func (p *PortList) UnmarshalJSON(b []byte) error {
    97  
    98  	p.portRanges = nil
    99  	p.lookup = nil
   100  
   101  	if bytes.Equal(b, []byte("null")) {
   102  		return nil
   103  	}
   104  
   105  	decoder := json.NewDecoder(bytes.NewReader(b))
   106  	decoder.UseNumber()
   107  
   108  	var array []interface{}
   109  
   110  	err := decoder.Decode(&array)
   111  	if err != nil {
   112  		return errors.Trace(err)
   113  	}
   114  
   115  	p.portRanges = make([][2]int, len(array))
   116  
   117  	for i, portRange := range array {
   118  
   119  		var startPort, endPort int64
   120  
   121  		if portNumber, ok := portRange.(json.Number); ok {
   122  
   123  			port, err := portNumber.Int64()
   124  			if err != nil {
   125  				return errors.Trace(err)
   126  			}
   127  
   128  			startPort = port
   129  			endPort = port
   130  
   131  		} else if array, ok := portRange.([]interface{}); ok {
   132  
   133  			if len(array) != 2 {
   134  				return errors.TraceNew("invalid range size")
   135  			}
   136  
   137  			portNumber, ok := array[0].(json.Number)
   138  			if !ok {
   139  				return errors.TraceNew("invalid type")
   140  			}
   141  			port, err := portNumber.Int64()
   142  			if err != nil {
   143  				return errors.Trace(err)
   144  			}
   145  			startPort = port
   146  
   147  			portNumber, ok = array[1].(json.Number)
   148  			if !ok {
   149  				return errors.TraceNew("invalid type")
   150  			}
   151  			port, err = portNumber.Int64()
   152  			if err != nil {
   153  				return errors.Trace(err)
   154  			}
   155  			endPort = port
   156  
   157  		} else {
   158  
   159  			return errors.TraceNew("invalid type")
   160  		}
   161  
   162  		if startPort < 1 || startPort > 65535 {
   163  			return errors.TraceNew("invalid range start")
   164  		}
   165  
   166  		if endPort < 1 || endPort > 65535 || endPort < startPort {
   167  			return errors.TraceNew("invalid range end")
   168  		}
   169  
   170  		p.portRanges[i] = [2]int{int(startPort), int(endPort)}
   171  	}
   172  
   173  	return nil
   174  }
   175  
   176  // MarshalJSON implements the json.Marshaler interface.
   177  func (p *PortList) MarshalJSON() ([]byte, error) {
   178  	var json bytes.Buffer
   179  	json.WriteString("[")
   180  	for i, portRange := range p.portRanges {
   181  		if i > 0 {
   182  			json.WriteString(",")
   183  		}
   184  		if portRange[0] == portRange[1] {
   185  			json.WriteString(strconv.Itoa(portRange[0]))
   186  		} else {
   187  			json.WriteString("[")
   188  			json.WriteString(strconv.Itoa(portRange[0]))
   189  			json.WriteString(",")
   190  			json.WriteString(strconv.Itoa(portRange[1]))
   191  			json.WriteString("]")
   192  		}
   193  	}
   194  	json.WriteString("]")
   195  	return json.Bytes(), nil
   196  }