code.gitea.io/gitea@v1.22.3/build/codeformat/formatimports.go (about)

     1  // Copyright 2021 The Gitea Authors. All rights reserved.
     2  // SPDX-License-Identifier: MIT
     3  
     4  package codeformat
     5  
     6  import (
     7  	"bytes"
     8  	"errors"
     9  	"io"
    10  	"os"
    11  	"sort"
    12  	"strings"
    13  )
    14  
    15  var importPackageGroupOrders = map[string]int{
    16  	"":                     1, // internal
    17  	"code.gitea.io/gitea/": 2,
    18  }
    19  
    20  var errInvalidCommentBetweenImports = errors.New("comments between imported packages are invalid, please move comments to the end of the package line")
    21  
    22  var (
    23  	importBlockBegin = []byte("\nimport (\n")
    24  	importBlockEnd   = []byte("\n)")
    25  )
    26  
    27  type importLineParsed struct {
    28  	group   string
    29  	pkg     string
    30  	content string
    31  }
    32  
    33  func parseImportLine(line string) (*importLineParsed, error) {
    34  	il := &importLineParsed{content: line}
    35  	p1 := strings.IndexRune(line, '"')
    36  	if p1 == -1 {
    37  		return nil, errors.New("invalid import line: " + line)
    38  	}
    39  	p1++
    40  	p := strings.IndexRune(line[p1:], '"')
    41  	if p == -1 {
    42  		return nil, errors.New("invalid import line: " + line)
    43  	}
    44  	p2 := p1 + p
    45  	il.pkg = line[p1:p2]
    46  
    47  	pDot := strings.IndexRune(il.pkg, '.')
    48  	pSlash := strings.IndexRune(il.pkg, '/')
    49  	if pDot != -1 && pDot < pSlash {
    50  		il.group = "domain-package"
    51  	}
    52  	for groupName := range importPackageGroupOrders {
    53  		if groupName == "" {
    54  			continue // skip internal
    55  		}
    56  		if strings.HasPrefix(il.pkg, groupName) {
    57  			il.group = groupName
    58  		}
    59  	}
    60  	return il, nil
    61  }
    62  
    63  type (
    64  	importLineGroup    []*importLineParsed
    65  	importLineGroupMap map[string]importLineGroup
    66  )
    67  
    68  func formatGoImports(contentBytes []byte) ([]byte, error) {
    69  	p1 := bytes.Index(contentBytes, importBlockBegin)
    70  	if p1 == -1 {
    71  		return nil, nil
    72  	}
    73  	p1 += len(importBlockBegin)
    74  	p := bytes.Index(contentBytes[p1:], importBlockEnd)
    75  	if p == -1 {
    76  		return nil, nil
    77  	}
    78  	p2 := p1 + p
    79  
    80  	importGroups := importLineGroupMap{}
    81  	r := bytes.NewBuffer(contentBytes[p1:p2])
    82  	eof := false
    83  	for !eof {
    84  		line, err := r.ReadString('\n')
    85  		eof = err == io.EOF
    86  		if err != nil && !eof {
    87  			return nil, err
    88  		}
    89  		line = strings.TrimSpace(line)
    90  		if line != "" {
    91  			if strings.HasPrefix(line, "//") || strings.HasPrefix(line, "/*") {
    92  				return nil, errInvalidCommentBetweenImports
    93  			}
    94  			importLine, err := parseImportLine(line)
    95  			if err != nil {
    96  				return nil, err
    97  			}
    98  			importGroups[importLine.group] = append(importGroups[importLine.group], importLine)
    99  		}
   100  	}
   101  
   102  	var groupNames []string
   103  	for groupName, importLines := range importGroups {
   104  		groupNames = append(groupNames, groupName)
   105  		sort.Slice(importLines, func(i, j int) bool {
   106  			return strings.Compare(importLines[i].pkg, importLines[j].pkg) < 0
   107  		})
   108  	}
   109  
   110  	sort.Slice(groupNames, func(i, j int) bool {
   111  		n1 := groupNames[i]
   112  		n2 := groupNames[j]
   113  		o1 := importPackageGroupOrders[n1]
   114  		o2 := importPackageGroupOrders[n2]
   115  		if o1 != 0 && o2 != 0 {
   116  			return o1 < o2
   117  		}
   118  		if o1 == 0 && o2 == 0 {
   119  			return strings.Compare(n1, n2) < 0
   120  		}
   121  		return o1 != 0
   122  	})
   123  
   124  	formattedBlock := bytes.Buffer{}
   125  	for _, groupName := range groupNames {
   126  		hasNormalImports := false
   127  		hasDummyImports := false
   128  		// non-dummy import comes first
   129  		for _, importLine := range importGroups[groupName] {
   130  			if strings.HasPrefix(importLine.content, "_") {
   131  				hasDummyImports = true
   132  			} else {
   133  				formattedBlock.WriteString("\t" + importLine.content + "\n")
   134  				hasNormalImports = true
   135  			}
   136  		}
   137  		// dummy (_ "pkg") comes later
   138  		if hasDummyImports {
   139  			if hasNormalImports {
   140  				formattedBlock.WriteString("\n")
   141  			}
   142  			for _, importLine := range importGroups[groupName] {
   143  				if strings.HasPrefix(importLine.content, "_") {
   144  					formattedBlock.WriteString("\t" + importLine.content + "\n")
   145  				}
   146  			}
   147  		}
   148  		formattedBlock.WriteString("\n")
   149  	}
   150  	formattedBlockBytes := bytes.TrimRight(formattedBlock.Bytes(), "\n")
   151  
   152  	var formattedBytes []byte
   153  	formattedBytes = append(formattedBytes, contentBytes[:p1]...)
   154  	formattedBytes = append(formattedBytes, formattedBlockBytes...)
   155  	formattedBytes = append(formattedBytes, contentBytes[p2:]...)
   156  	return formattedBytes, nil
   157  }
   158  
   159  // FormatGoImports format the imports by our rules (see unit tests)
   160  func FormatGoImports(file string, doWriteFile bool) error {
   161  	f, err := os.Open(file)
   162  	if err != nil {
   163  		return err
   164  	}
   165  	var contentBytes []byte
   166  	{
   167  		defer f.Close()
   168  		contentBytes, err = io.ReadAll(f)
   169  		if err != nil {
   170  			return err
   171  		}
   172  	}
   173  	formattedBytes, err := formatGoImports(contentBytes)
   174  	if err != nil {
   175  		return err
   176  	}
   177  	if formattedBytes == nil {
   178  		return nil
   179  	}
   180  	if bytes.Equal(contentBytes, formattedBytes) {
   181  		return nil
   182  	}
   183  
   184  	if doWriteFile {
   185  		f, err = os.OpenFile(file, os.O_TRUNC|os.O_WRONLY, 0o644)
   186  		if err != nil {
   187  			return err
   188  		}
   189  		defer f.Close()
   190  		_, err = f.Write(formattedBytes)
   191  		return err
   192  	}
   193  
   194  	return err
   195  }