
     1  // Copyright 2014 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     4  package state
     6  import (
     7  	"fmt"
     8  	"regexp"
     9  	"strings"
    11  	""
    12  	""
    13  	statetxn ""
    14  	""
    15  	""
    16  	""
    18  	""
    19  )
    21  // A regular expression for parsing ports document id into corresponding machine
    22  // and subnet ids.
    23  var portsIDRe = regexp.MustCompile(fmt.Sprintf("m#(?P<machine>%s)#(?P<subnet>.*)$", names.MachineSnippet))
    25  type portIDPart int
    27  const (
    28  	_ portIDPart = iota
    29  	machineIDPart
    30  	subnetIDPart
    31  )
// PortRange represents a single range of ports opened
// by one unit.
//
// Both bounds are inclusive. PortRange values are compared with == in this
// file (CheckConflicts, OpenPorts, ClosePorts), so all four fields take part
// in equality, including the letter case of Protocol. The field declaration
// order is kept as-is because values of this type are also stored in, and
// matched against, the ports document.
type PortRange struct {
	// UnitName is the name of the unit that opened the range.
	UnitName string
	// FromPort is the first port of the range (1-65535 when valid).
	FromPort int
	// ToPort is the last port of the range (1-65535 when valid).
	ToPort   int
	// Protocol is "tcp" or "udp"; Validate accepts either in any case.
	Protocol string
}
    42  // NewPortRange create a new port range and validate it.
    43  func NewPortRange(unitName string, fromPort, toPort int, protocol string) (PortRange, error) {
    44  	p := PortRange{
    45  		UnitName: unitName,
    46  		FromPort: fromPort,
    47  		ToPort:   toPort,
    48  		Protocol: strings.ToLower(protocol),
    49  	}
    50  	if err := p.Validate(); err != nil {
    51  		return PortRange{}, err
    52  	}
    53  	return p, nil
    54  }
    56  // PortRangeFromNetworkPortRange constructs a state.PortRange from the
    57  // given unitName and network.PortRange.
    58  func PortRangeFromNetworkPortRange(unitName string, portRange network.PortRange) (PortRange, error) {
    59  	return NewPortRange(unitName, portRange.FromPort, portRange.ToPort, portRange.Protocol)
    60  }
    62  // Validate checks if the port range is valid.
    63  func (p PortRange) Validate() error {
    64  	proto := strings.ToLower(p.Protocol)
    65  	if proto != "tcp" && proto != "udp" {
    66  		return errors.Errorf("invalid protocol %q", proto)
    67  	}
    68  	if !names.IsValidUnit(p.UnitName) {
    69  		return errors.Errorf("invalid unit %q", p.UnitName)
    70  	}
    71  	if p.FromPort > p.ToPort {
    72  		return errors.Errorf("invalid port range %d-%d", p.FromPort, p.ToPort)
    73  	}
    74  	if p.FromPort <= 0 || p.FromPort > 65535 ||
    75  		p.ToPort <= 0 || p.ToPort > 65535 {
    76  		return errors.Errorf("port range bounds must be between 1 and 65535, got %d-%d", p.FromPort, p.ToPort)
    77  	}
    78  	return nil
    79  }
    81  // Length returns the number of ports in the range.
    82  // If the range is not valid, it returns 0.
    83  func (a PortRange) Length() int {
    84  	if err := a.Validate(); err != nil {
    85  		// Invalid range (from > to or something equally bad)
    86  		return 0
    87  	}
    88  	return (a.ToPort - a.FromPort) + 1
    89  }
    91  // Sanitize returns a copy of the port range, which is guaranteed to
    92  // have FromPort >= ToPort and both FromPort and ToPort fit into the
    93  // valid range from 1 to 65535, inclusive.
    94  func (a PortRange) SanitizeBounds() PortRange {
    95  	b := a
    96  	if b.FromPort > b.ToPort {
    97  		b.FromPort, b.ToPort = b.ToPort, b.FromPort
    98  	}
    99  	for _, bound := range []*int{&b.FromPort, &b.ToPort} {
   100  		switch {
   101  		case *bound <= 0:
   102  			*bound = 1
   103  		case *bound > 65535:
   104  			*bound = 65535
   105  		}
   106  	}
   107  	return b
   108  }
   110  // CheckConflicts determines if the two port ranges conflict.
   111  func (prA PortRange) CheckConflicts(prB PortRange) error {
   112  	if err := prA.Validate(); err != nil {
   113  		return err
   114  	}
   115  	if err := prB.Validate(); err != nil {
   116  		return err
   117  	}
   119  	// An exact port range match (including the associated unit name) is not
   120  	// considered a conflict due to the fact that many charms issue commands
   121  	// to open the same port multiple times.
   122  	if prA == prB {
   123  		return nil
   124  	}
   125  	if prA.Protocol != prB.Protocol {
   126  		return nil
   127  	}
   128  	if prA.ToPort >= prB.FromPort && prB.ToPort >= prA.FromPort {
   129  		return errors.Errorf("port ranges %v and %v conflict", prA, prB)
   130  	}
   131  	return nil
   132  }
   134  // Strings returns the port range as a string.
   135  func (p PortRange) String() string {
   136  	return fmt.Sprintf("%d-%d/%s (%q)", p.FromPort, p.ToPort, strings.ToLower(p.Protocol), p.UnitName)
   137  }
// portsDoc is the persisted record of the ports opened on one machine for one
// subnet. SubnetID may be empty, meaning no subnet is set.
type portsDoc struct {
	// DocID is the document id; getOrCreatePorts builds it with st.docID
	// from portsGlobalKey(MachineID, SubnetID).
	DocID     string      `bson:"_id"`
	ModelUUID string      `bson:"model-uuid"`
	MachineID string      `bson:"machine-id"`
	SubnetID  string      `bson:"subnet-id"`
	// Ports holds every range opened on this machine/subnet, across units.
	Ports     []PortRange `bson:"ports"`
	// TxnRevno is read back from the database and used as a "txn-revno"
	// assertion when the document is updated.
	TxnRevno  int64       `bson:"txn-revno"`
}
// Ports represents the state of ports on a machine: an in-memory copy of a
// single ports document (one machine, one subnet) plus the State used to
// read and update it.
type Ports struct {
	st  *State
	// doc is the local copy of the stored document; it may be stale until
	// Refresh is called.
	doc portsDoc
	// areNew is true for documents not in state yet.
	areNew bool
}
   157  // String returns p as a user-readable string.
   158  func (p *Ports) String() string {
   159  	return fmt.Sprintf("ports for machine %q, subnet %q", p.doc.MachineID, p.doc.SubnetID)
   160  }
   162  // globalKey returns the id of the ports document.
   163  func (p *Ports) globalKey() string {
   164  	return portsGlobalKey(p.doc.MachineID, p.doc.SubnetID)
   165  }
   167  // portsGlobalKey returns the global database key for the opened ports
   168  // document for the given machine and subnet.
   169  func portsGlobalKey(machineID, subnetID string) string {
   170  	return fmt.Sprintf("m#%s#%s", machineID, subnetID)
   171  }
   173  // extractPortsIDParts parses the given ports global key and extracts
   174  // its parts.
   175  func extractPortsIDParts(globalKey string) ([]string, error) {
   176  	if parts := portsIDRe.FindStringSubmatch(globalKey); len(parts) == 3 {
   177  		return parts, nil
   178  	}
   179  	return nil, errors.NotValidf("ports document key %q", globalKey)
   180  }
   182  // SubnetID returns the subnet ID associated with this ports document.
   183  func (p *Ports) SubnetID() string {
   184  	return p.doc.SubnetID
   185  }
// OpenPorts adds the specified port range to the list of ports
// maintained by this document.
//
// The range is validated first. The change is then made by a transaction
// assembled in buildTxn. On retries (attempt > 0), buildTxn re-checks that
// the model is active and that the subnet, if any, is alive. It also
// refreshes a working copy of the document and switches between creating and
// updating it, depending on whether the document still exists. Re-opening a
// range already held by the same unit yields statetxn.ErrNoOperations, so the
// document (and its txn-revno) is left untouched. On success p is marked as
// stored and portRange is appended to p.doc.Ports.
//
// Errors are annotated with "cannot open ports <range>".
func (p *Ports) OpenPorts(portRange PortRange) (err error) {
	defer errors.DeferredAnnotatef(&err, "cannot open ports %s", portRange)
	if err = portRange.Validate(); err != nil {
		return errors.Trace(err)
	}
	// Working copy that buildTxn refreshes on retries.
	ports := Ports{st:, doc: p.doc, areNew: p.areNew}
	buildTxn := func(attempt int) ([]txn.Op, error) {
		if attempt > 0 {
			if err := checkModelActive(; err != nil {
				return nil, errors.Trace(err)
			}
			if err := p.verifySubnetAliveWhenSet(); err != nil {
				return nil, errors.Trace(err)
			}
			// NOTE(review): this assigns OpenPorts' named result err rather
			// than a closure-local variable.
			if err = ports.Refresh(); errors.IsNotFound(err) {
				// No longer exists, we'll create it.
				if !ports.areNew {
					ports.areNew = true
				}
			} else if err != nil {
				return nil, errors.Trace(err)
			} else if ports.areNew {
				// Already created, we'll update it.
				ports.areNew = false
			}
		}
		// Check for conflicts with existing ports.
		// NOTE(review): this ranges over p.doc.Ports, the caller's snapshot,
		// not ports.doc.Ports, which Refresh updates on retries. Ranges added
		// concurrently since p was loaded are therefore never checked for
		// conflicts. Ranging over ports.doc.Ports looks like the intent.
		for _, existingPorts := range p.doc.Ports {
			if err := existingPorts.CheckConflicts(portRange); err != nil {
				return nil, errors.Trace(err)
			} else if existingPorts == portRange {
				// Trying to open the same range for the same unit is
				// ignored, as we don't need to change the document
				// and hence its txn-revno and trigger unnecessary
				// watcher notifications.
				return nil, statetxn.ErrNoOperations
			}
		}
		ops := []txn.Op{
			assertModelActiveOp(,
		}
		if ports.areNew {
			// Create a new document.
			assert := txn.DocMissing
			ops = append(ops, addPortsDocOps(, &ports.doc, assert, portRange)...)
		} else {
			// Update an existing document, guarded by the txn-revno seen
			// when it was last loaded.
			assert := bson.D{{"txn-revno", ports.doc.TxnRevno}}
			ops = append(ops, updatePortsDocOps(, ports.doc, assert, portRange)...)
		}
		return ops, nil
	}
	// Run the transaction using the state transaction runner.
	if err =; err != nil {
		return errors.Trace(err)
	}
	// Mark object as created.
	// NOTE(review): p.doc is not replaced by the refreshed ports.doc, so after
	// a retry its TxnRevno and Ports may be stale. Also, if the runner treats
	// statetxn.ErrNoOperations as success (not visible here), re-opening an
	// already-open range appends a duplicate entry to p.doc.Ports.
	p.areNew = false
	p.doc.Ports = append(p.doc.Ports, portRange)
	return nil
}
// verifySubnetAliveWhenSet checks the subnet referenced by this ports
// document, if any. It returns nil when no subnet is set (empty SubnetID).
// Otherwise it returns the traced lookup error, or an error naming the
// subnet's CIDR when the subnet's life is anything other than Alive.
// NOTE(review): presumably the subnet is looked up by p.doc.SubnetID; confirm.
func (p *Ports) verifySubnetAliveWhenSet() error {
	if p.doc.SubnetID == "" {
		return nil
	}
	subnet, err :=
	if err != nil {
		return errors.Trace(err)
	} else if subnet.Life() != Alive {
		return errors.Errorf("subnet %q not alive", subnet.CIDR())
	}
	return nil
}
// ClosePorts removes the specified port range from the list of ports
// maintained by this document.
//
// portRange must match a stored entry exactly, unit name included. If it
// does not, no operations are produced. An error is returned if portRange
// overlaps, without exactly matching, another range owned by the same unit.
// Overlaps with other units' ranges are ignored. If closing the range leaves
// no ranges at all, the whole document is removed. On success p.doc.Ports is
// replaced by the remaining ranges.
//
// Errors are annotated with "cannot close ports <range>".
func (p *Ports) ClosePorts(portRange PortRange) (err error) {
	defer errors.DeferredAnnotatef(&err, "cannot close ports %s", portRange)
	if err = portRange.Validate(); err != nil {
		return errors.Trace(err)
	}
	// newPorts collects the ranges that survive; it is reused across attempts.
	var newPorts []PortRange
	// Working copy that buildTxn refreshes on retries.
	ports := Ports{st:, doc: p.doc, areNew: p.areNew}
	buildTxn := func(attempt int) ([]txn.Op, error) {
		if attempt > 0 {
			if err := p.verifySubnetAliveWhenSet(); err != nil {
				return nil, errors.Trace(err)
			}
			// NOTE(review): if the document has vanished, newPorts still holds
			// the previous attempt's result. If the runner treats
			// ErrNoOperations as success (not visible here), p.doc.Ports is
			// then set to that list below.
			if err = ports.Refresh(); errors.IsNotFound(err) {
				// No longer exists, nothing to do.
				return nil, statetxn.ErrNoOperations
			} else if err != nil {
				return nil, errors.Trace(err)
			}
		}
		// Truncate but keep the backing array.
		newPorts = newPorts[0:0]
		found := false
		for _, existingPortsDef := range ports.doc.Ports {
			if existingPortsDef == portRange {
				found = true
				continue
			}
			// NOTE(review): this assigns ClosePorts' named result err. Only
			// conflicts with the same unit's ranges abort the close.
			err = existingPortsDef.CheckConflicts(portRange)
			if existingPortsDef.UnitName == portRange.UnitName && err != nil {
				return nil, errors.Trace(err)
			}
			newPorts = append(newPorts, existingPortsDef)
		}
		if !found {
			return nil, statetxn.ErrNoOperations
		}
		if len(newPorts) == 0 {
			// All ports closed, so remove the ports doc instead.
			// NOTE(review): removeOps carries no txn-revno assertion, unlike
			// the $set branch. A range opened concurrently by another unit
			// between the read and the commit would be deleted with the
			// document. Consider asserting ports.doc.TxnRevno here too.
			return p.removeOps(), nil
		} else {
			assert := bson.D{{"txn-revno", ports.doc.TxnRevno}}
			return setPortsDocOps(, ports.doc, assert, newPorts...), nil
		}
	}
	if err =; err != nil {
		return errors.Trace(err)
	}
	p.doc.Ports = newPorts
	return nil
}
   324  // PortsForUnit returns the ports associated with specified unitName that are
   325  // maintained on this document (i.e. are open on this unit's assigned machine).
   326  func (p *Ports) PortsForUnit(unitName string) []PortRange {
   327  	ports := []PortRange{}
   328  	for _, port := range p.doc.Ports {
   329  		if port.UnitName == unitName {
   330  			ports = append(ports, port)
   331  		}
   332  	}
   333  	return ports
   334  }
// Refresh reloads the port document from state, decoding the stored document
// with id p.doc.DocID into p.doc.
//
// It returns a NotFound error (described by p.String()) when the document no
// longer exists. Any other lookup failure is returned annotated.
// NOTE(review): p.String() is passed to errors.NotFoundf as the format string.
// Passing it as an argument (NotFoundf("%s", p)) would stop a stray '%' in an
// id from garbling the message, and would satisfy go vet's printf check.
func (p *Ports) Refresh() error {
	openedPorts, closer :=
	defer closer()
	err := openedPorts.FindId(p.doc.DocID).One(&p.doc)
	if err == mgo.ErrNotFound {
		return errors.NotFoundf(p.String())
	} else if err != nil {
		return errors.Annotatef(err, "cannot refresh %s", p)
	}
	return nil
}
   350  // AllPortRanges returns a map with network.PortRange as keys and unit
   351  // names as values.
   352  func (p *Ports) AllPortRanges() map[network.PortRange]string {
   353  	result := make(map[network.PortRange]string)
   354  	for _, portRange := range p.doc.Ports {
   355  		rawRange := network.PortRange{
   356  			FromPort: portRange.FromPort,
   357  			ToPort:   portRange.ToPort,
   358  			Protocol: portRange.Protocol,
   359  		}
   360  		result[rawRange] = portRange.UnitName
   361  	}
   362  	return result
   363  }
// Remove removes the ports document from state.
//
// buildTxn works on a copy of this document, with areNew left false. On
// retries it refreshes the copy and yields statetxn.ErrNoOperations if the
// document has already gone. Otherwise it emits removeOps, which delete the
// document without asserting its txn-revno.
func (p *Ports) Remove() error {
	ports := &Ports{st:, doc: p.doc}
	buildTxn := func(attempt int) ([]txn.Op, error) {
		if attempt > 0 {
			err := ports.Refresh()
			if errors.IsNotFound(err) {
				return nil, statetxn.ErrNoOperations
			} else if err != nil {
				return nil, errors.Trace(err)
			}
		}
		return ports.removeOps(), nil
	}
	return
}
// OpenedPorts returns this machine's ports document for the given subnetID.
//
// A missing document is not an error. getPorts returns a nil *Ports with its
// NotFound error, which is swallowed here, so the result is (nil, nil).
// Callers must therefore nil-check the returned pointer. Any other error is
// traced and returned.
func (m *Machine) OpenedPorts(subnetID string) (*Ports, error) {
	ports, err := getPorts(, m.Id(), subnetID)
	if err != nil && !errors.IsNotFound(err) {
		return nil, errors.Trace(err)
	}
	return ports, nil
}
// AllPorts returns all opened ports for this machine (on all
// networks): one *Ports for each ports document whose machine-id is this
// machine's id. When there are none, the result is an empty, non-nil slice.
// The returned values have areNew == false.
func (m *Machine) AllPorts() ([]*Ports, error) {
	openedPorts, closer :=
	defer closer()
	docs := []portsDoc{}
	err := openedPorts.Find(bson.D{{"machine-id", m.Id()}}).All(&docs)
	if err != nil {
		return nil, errors.Trace(err)
	}
	results := make([]*Ports, len(docs))
	for i, doc := range docs {
		results[i] = &Ports{st:, doc: doc}
	}
	return results, nil
}
   409  // addPortsDocOps returns the ops for adding a number of port ranges
   410  // to a new ports document. portsAssert allows specifying an assert
   411  // statement for on the openedPorts collection op.
   412  var addPortsDocOps = addPortsDocOpsFunc
   414  func addPortsDocOpsFunc(st *State, pDoc *portsDoc, portsAssert interface{}, ports ...PortRange) []txn.Op {
   415  	pDoc.Ports = ports
   417  	ops := assertMachineNotDeadAndSubnetNotDeadWhenSetOps(st, pDoc)
   418  	return append(ops, txn.Op{
   419  		C:      openedPortsC,
   420  		Id:     pDoc.DocID,
   421  		Assert: portsAssert,
   422  		Insert: pDoc,
   423  	})
   424  }
   426  func assertMachineNotDeadAndSubnetNotDeadWhenSetOps(st *State, pDoc *portsDoc) []txn.Op {
   427  	ops := []txn.Op{{
   428  		C:      machinesC,
   429  		Id:     st.docID(pDoc.MachineID),
   430  		Assert: notDeadDoc,
   431  	}}
   433  	if pDoc.SubnetID != "" {
   434  		ops = append(ops, txn.Op{
   435  			C:      subnetsC,
   436  			Id:     st.docID(pDoc.SubnetID),
   437  			Assert: notDeadDoc,
   438  		})
   439  	}
   440  	return ops
   441  }
   443  // updatePortsDocOps returns the ops for adding a port range to an
   444  // existing ports document. portsAssert allows specifying an assert
   445  // statement on the openedPorts collection op.
   446  var updatePortsDocOps = updatePortsDocOpsFunc
   448  func updatePortsDocOpsFunc(st *State, pDoc portsDoc, portsAssert interface{}, portRange PortRange) []txn.Op {
   449  	ops := assertMachineNotDeadAndSubnetNotDeadWhenSetOps(st, &pDoc)
   450  	return append(ops, []txn.Op{{
   451  		C:      unitsC,
   452  		Id:     st.docID(portRange.UnitName),
   453  		Assert: notDeadDoc,
   454  	}, {
   455  		C:      openedPortsC,
   456  		Id:     pDoc.DocID,
   457  		Assert: portsAssert,
   458  		Update: bson.D{{"$addToSet", bson.D{{"ports", portRange}}}},
   459  	}}...)
   460  }
   462  // setPortsDocOps returns the ops for setting given port ranges to an
   463  // existing ports document. portsAssert allows specifying an assert
   464  // statement on the openedPorts collection op.
   465  var setPortsDocOps = setPortsDocOpsFunc
   467  func setPortsDocOpsFunc(st *State, pDoc portsDoc, portsAssert interface{}, ports ...PortRange) []txn.Op {
   468  	ops := assertMachineNotDeadAndSubnetNotDeadWhenSetOps(st, &pDoc)
   469  	return append(ops, txn.Op{
   470  		C:      openedPortsC,
   471  		Id:     pDoc.DocID,
   472  		Assert: portsAssert,
   473  		Update: bson.D{{"$set", bson.D{{"ports", ports}}}},
   474  	})
   475  }
   477  // removeOps returns the ops for removing the ports document from
   478  // state.
   479  func (p *Ports) removeOps() []txn.Op {
   480  	return []txn.Op{{
   481  		C:      openedPortsC,
   482  		Id:     p.doc.DocID,
   483  		Remove: true,
   484  	}}
   485  }
   487  // removePortsForUnitOps returns the ops needed to remove all opened
   488  // ports for the given unit on its assigned machine.
   489  func removePortsForUnitOps(st *State, unit *Unit) ([]txn.Op, error) {
   490  	machineId, err := unit.AssignedMachineId()
   491  	if err != nil {
   492  		// No assigned machine, so there won't be any ports.
   493  		return nil, nil
   494  	}
   495  	machine, err := st.Machine(machineId)
   496  	if errors.IsNotFound(err) {
   497  		// Machine is removed, so there won't be a ports doc for it.
   498  		return nil, nil
   499  	} else if err != nil {
   500  		return nil, errors.Trace(err)
   501  	}
   502  	allPorts, err := machine.AllPorts()
   503  	if err != nil {
   504  		return nil, errors.Trace(err)
   505  	}
   506  	var ops []txn.Op
   507  	for _, ports := range allPorts {
   508  		allRanges := ports.AllPortRanges()
   509  		var keepPorts []PortRange
   510  		for portRange, unitName := range allRanges {
   511  			if unitName != unit.Name() {
   512  				unitRange := PortRange{
   513  					UnitName: unitName,
   514  					FromPort: portRange.FromPort,
   515  					ToPort:   portRange.ToPort,
   516  					Protocol: portRange.Protocol,
   517  				}
   518  				keepPorts = append(keepPorts, unitRange)
   519  			}
   520  		}
   521  		if len(keepPorts) > 0 {
   522  			assert := bson.D{{"txn-revno", ports.doc.TxnRevno}}
   523  			ops = append(ops, setPortsDocOps(st, ports.doc, assert, keepPorts...)...)
   524  		} else {
   525  			// No other ports left, remove the doc.
   526  			ops = append(ops, ports.removeOps()...)
   527  		}
   528  	}
   529  	return ops, nil
   530  }
   532  // getPorts returns the ports document for the specified machine and subnet.
   533  func getPorts(st *State, machineID, subnetID string) (*Ports, error) {
   534  	openedPorts, closer := st.getCollection(openedPortsC)
   535  	defer closer()
   537  	var doc portsDoc
   538  	key := portsGlobalKey(machineID, subnetID)
   539  	err := openedPorts.FindId(key).One(&doc)
   540  	if err != nil {
   541  		doc.MachineID = machineID
   542  		doc.SubnetID = subnetID
   543  		p := Ports{st, doc, false}
   544  		if err == mgo.ErrNotFound {
   545  			return nil, errors.NotFoundf(p.String())
   546  		}
   547  		return nil, errors.Annotatef(err, "cannot get %s", p.String())
   548  	}
   550  	return &Ports{st, doc, false}, nil
   551  }
   553  // getOrCreatePorts attempts to retrieve a ports document and returns a newly
   554  // created one if it does not exist.
   555  func getOrCreatePorts(st *State, machineID, subnetID string) (*Ports, error) {
   556  	ports, err := getPorts(st, machineID, subnetID)
   557  	if errors.IsNotFound(err) {
   558  		key := portsGlobalKey(machineID, subnetID)
   559  		doc := portsDoc{
   560  			DocID:     st.docID(key),
   561  			MachineID: machineID,
   562  			SubnetID:  subnetID,
   563  			ModelUUID: st.ModelUUID(),
   564  		}
   565  		ports = &Ports{st, doc, true}
   566  	} else if err != nil {
   567  		return nil, errors.Trace(err)
   568  	}
   569  	return ports, nil
   570  }