github.com/containerd/nerdctl/v2@v2.0.0-beta.5.0.20240520001846-b5758f54fa28/pkg/idutil/netwalker/netwalker.go (about)

     1  /*
     2     Copyright The containerd Authors.
     3  
     4     Licensed under the Apache License, Version 2.0 (the "License");
     5     you may not use this file except in compliance with the License.
     6     You may obtain a copy of the License at
     7  
     8         http://www.apache.org/licenses/LICENSE-2.0
     9  
    10     Unless required by applicable law or agreed to in writing, software
    11     distributed under the License is distributed on an "AS IS" BASIS,
    12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13     See the License for the specific language governing permissions and
    14     limitations under the License.
    15  */
    16  
    17  package netwalker
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  	"regexp"
    23  	"strings"
    24  
    25  	"github.com/containerd/nerdctl/v2/pkg/netutil"
    26  )
    27  
    28  type Found struct {
    29  	Network    *netutil.NetworkConfig
    30  	Req        string // The raw request string. name, short ID, or long ID.
    31  	MatchIndex int    // Begins with 0, up to MatchCount - 1.
    32  	MatchCount int    // 1 on exact match. > 1 on ambiguous match. Never be <= 0.
    33  }
    34  
    35  type OnFound func(ctx context.Context, found Found) error
    36  
    37  type NetworkWalker struct {
    38  	Client  *netutil.CNIEnv
    39  	OnFound OnFound
    40  }
    41  
    42  // Walk walks networks and calls w.OnFound .
    43  // Req is name, short ID, or long ID.
    44  // Returns the number of the found entries.
    45  func (w *NetworkWalker) Walk(ctx context.Context, req string) (int, error) {
    46  	longIDExp, err := regexp.Compile(fmt.Sprintf("^sha256:%s.*", regexp.QuoteMeta(req)))
    47  	if err != nil {
    48  		return 0, err
    49  	}
    50  
    51  	shortIDExp, err := regexp.Compile(fmt.Sprintf("^%s", regexp.QuoteMeta(req)))
    52  	if err != nil {
    53  		return 0, err
    54  	}
    55  
    56  	idFilterF := func(n *netutil.NetworkConfig) bool {
    57  		if n.NerdctlID == nil {
    58  			// External network
    59  			return n.Name == req
    60  		}
    61  		return n.Name == req || longIDExp.Match([]byte(*n.NerdctlID)) || shortIDExp.Match([]byte(*n.NerdctlID))
    62  	}
    63  	networks, err := w.Client.FilterNetworks(idFilterF)
    64  	if err != nil {
    65  		return 0, err
    66  	}
    67  
    68  	matchCount := len(networks)
    69  
    70  	for i, network := range networks {
    71  		f := Found{
    72  			Network:    network,
    73  			Req:        req,
    74  			MatchIndex: i,
    75  			MatchCount: matchCount,
    76  		}
    77  		if e := w.OnFound(ctx, f); e != nil {
    78  			return -1, e
    79  		}
    80  	}
    81  	return matchCount, nil
    82  }
    83  
    84  // WalkAll calls `Walk` for each req in `reqs`.
    85  //
    86  // It can be used when the matchCount is not important (e.g., only care if there
    87  // is any error or if matchCount == 0 (not found error) when walking all reqs).
    88  // If `forceAll`, it calls `Walk` on every req
    89  // and return all errors joined by `\n`. If not `forceAll`, it returns the first error
    90  // encountered while calling `Walk`.
    91  // `allowSeudoNetwork` allows seudo network (host, none) to be passed to `Walk`, otherwise
    92  // an error is recorded for it.
    93  func (w *NetworkWalker) WalkAll(ctx context.Context, reqs []string, forceAll, allowSeudoNetwork bool) error {
    94  	var errs []string
    95  	for _, req := range reqs {
    96  		if !allowSeudoNetwork && (req == "host" || req == "none") {
    97  			err := fmt.Errorf("pseudo network not allowed: %s", req)
    98  			if !forceAll {
    99  				return err
   100  			}
   101  			errs = append(errs, err.Error())
   102  		} else {
   103  			n, err := w.Walk(ctx, req)
   104  			if err == nil && n == 0 {
   105  				err = fmt.Errorf("no such network: %s", req)
   106  			}
   107  			if err != nil {
   108  				if !forceAll {
   109  					return err
   110  				}
   111  				errs = append(errs, err.Error())
   112  			}
   113  		}
   114  	}
   115  	if len(errs) > 0 {
   116  		return fmt.Errorf("%d errors:\n%s", len(errs), strings.Join(errs, "\n"))
   117  	}
   118  	return nil
   119  }