gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/tools/go_generics/imports.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package main
    16  
    17  import (
    18  	"bytes"
    19  	"fmt"
    20  	"go/ast"
    21  	"go/format"
    22  	"go/parser"
    23  	"go/token"
    24  	"sort"
    25  	"strconv"
    26  
    27  	"gvisor.dev/gvisor/tools/go_generics/globals"
    28  )
    29  
    30  type importedPackage struct {
    31  	newName string
    32  	path    string
    33  }
    34  
    35  // updateImportIdent modifies the given import identifier with the new name
    36  // stored in the used map. If the identifier doesn't exist in the used map yet,
    37  // a new name is generated and inserted into the map.
    38  func updateImportIdent(orig string, imports mapValue, id *ast.Ident, used map[string]*importedPackage) error {
    39  	importName := id.Name
    40  
    41  	// If the name is already in the table, just use the new name.
    42  	m := used[importName]
    43  	if m != nil {
    44  		id.Name = m.newName
    45  		return nil
    46  	}
    47  
    48  	// Create a new entry in the used map.
    49  	path := imports[importName]
    50  	if path == "" {
    51  		return fmt.Errorf("unknown path to package '%s', used in '%s'", importName, orig)
    52  	}
    53  
    54  	m = &importedPackage{
    55  		newName: fmt.Sprintf("__generics_imported%d", len(used)),
    56  		path:    strconv.Quote(path),
    57  	}
    58  	used[importName] = m
    59  
    60  	id.Name = m.newName
    61  
    62  	return nil
    63  }
    64  
    65  // convertExpression creates a new string that is a copy of the input one with
    66  // all imports references renamed to the names in the "used" map. If the
    67  // referenced import isn't in "used" yet, a new one is created based on the path
    68  // in "imports" and stored in "used". For example, if string s is
    69  // "math.MaxUint32-math.MaxUint16+10", it would be converted to
    70  // "x.MaxUint32-x.MathUint16+10", where x is a generated name.
    71  func convertExpression(s string, imports mapValue, used map[string]*importedPackage) (string, error) {
    72  	// Parse the expression in the input string.
    73  	expr, err := parser.ParseExpr(s)
    74  	if err != nil {
    75  		return "", fmt.Errorf("unable to parse \"%s\": %v", s, err)
    76  	}
    77  
    78  	// Go through the AST and update references.
    79  	var retErr error
    80  	ast.Inspect(expr, func(n ast.Node) bool {
    81  		switch x := n.(type) {
    82  		case *ast.SelectorExpr:
    83  			if id := globals.GetIdent(x.X); id != nil {
    84  				if err := updateImportIdent(s, imports, id, used); err != nil {
    85  					retErr = err
    86  				}
    87  				return false
    88  			}
    89  		}
    90  		return true
    91  	})
    92  	if retErr != nil {
    93  		return "", retErr
    94  	}
    95  
    96  	// Convert the modified AST back to a string.
    97  	fset := token.NewFileSet()
    98  	var buf bytes.Buffer
    99  	if err := format.Node(&buf, fset, expr); err != nil {
   100  		return "", err
   101  	}
   102  
   103  	return string(buf.Bytes()), nil
   104  }
   105  
   106  // updateImports replaces all maps in the input slice with copies where the
   107  // mapped values have had all references to imported packages renamed to
   108  // generated names. It also returns an import declaration for all the renamed
   109  // import packages.
   110  //
   111  // For example, if the input maps contains A=math.B and C=math.D, the updated
   112  // maps will instead contain A=__generics_imported0.B and
   113  // C=__generics_imported0.C, and the 'import __generics_imported0 "math"' would
   114  // be returned as the import declaration.
   115  func updateImports(maps []mapValue, imports mapValue) (ast.Decl, error) {
   116  	importsUsed := make(map[string]*importedPackage)
   117  
   118  	// Update all maps.
   119  	for i, m := range maps {
   120  		newMap := make(mapValue)
   121  		for n, e := range m {
   122  			updated, err := convertExpression(e, imports, importsUsed)
   123  			if err != nil {
   124  				return nil, err
   125  			}
   126  
   127  			newMap[n] = updated
   128  		}
   129  		maps[i] = newMap
   130  	}
   131  
   132  	// Nothing else to do if no imports are used in the expressions.
   133  	if len(importsUsed) == 0 {
   134  		return nil, nil
   135  	}
   136  	var names []string
   137  	for n := range importsUsed {
   138  		names = append(names, n)
   139  	}
   140  	// Sort the new imports for deterministic build outputs.
   141  	sort.Strings(names)
   142  
   143  	// Create spec array for each new import.
   144  	specs := make([]ast.Spec, 0, len(importsUsed))
   145  	for _, n := range names {
   146  		i := importsUsed[n]
   147  		specs = append(specs, &ast.ImportSpec{
   148  			Name: &ast.Ident{Name: i.newName},
   149  			Path: &ast.BasicLit{Value: i.path},
   150  		})
   151  	}
   152  
   153  	return &ast.GenDecl{
   154  		Tok:    token.IMPORT,
   155  		Specs:  specs,
   156  		Lparen: token.NoPos + 1,
   157  	}, nil
   158  }