github.com/google/syzkaller@v0.0.0-20251211124644-a066d2bc4b02/pkg/codesearch/codesearch.go (about)

     1  // Copyright 2025 syzkaller project authors. All rights reserved.
     2  // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
     3  
     4  package codesearch
     5  
     6  import (
     7  	"bytes"
     8  	"fmt"
     9  	"os"
    10  	"path/filepath"
    11  	"strings"
    12  
    13  	"github.com/google/syzkaller/pkg/osutil"
    14  )
    15  
// Index answers code search queries using a definitions database and
// the source trees the database refers to.
type Index struct {
	// db is the definitions database loaded by NewIndex.
	db *Database
	// srcDirs are the directories probed, in order, when resolving
	// a file name from the database to a file on disk.
	srcDirs []string
}
    20  
// Command describes a single textual codesearch command.
type Command struct {
	// Name is the command name matched by Index.Command.
	Name string
	// NArgs is the exact number of arguments the command requires.
	NArgs int
	// Func executes the command; Index.Command calls it only after
	// checking that len(args) == NArgs.
	Func func(*Index, []string) (string, error)
}
    26  
    27  // Commands are used to run unit tests and for the syz-codesearch tool.
    28  var Commands = []Command{
    29  	{"file-index", 1, func(index *Index, args []string) (string, error) {
    30  		ok, entities, err := index.FileIndex(args[0])
    31  		if err != nil || !ok {
    32  			return notFound, err
    33  		}
    34  		b := new(strings.Builder)
    35  		fmt.Fprintf(b, "file %v defines the following entities:\n\n", args[0])
    36  		for _, ent := range entities {
    37  			fmt.Fprintf(b, "%v %v\n", ent.Kind, ent.Name)
    38  		}
    39  		return b.String(), nil
    40  	}},
    41  	{"def-comment", 2, func(index *Index, args []string) (string, error) {
    42  		info, err := index.DefinitionComment(args[0], args[1])
    43  		if err != nil || info == nil {
    44  			return notFound, err
    45  		}
    46  		if info.Body == "" {
    47  			return fmt.Sprintf("%v %v is defined in %v and is not commented\n",
    48  				info.Kind, args[1], info.File), nil
    49  		}
    50  		return fmt.Sprintf("%v %v is defined in %v and commented as:\n\n%v",
    51  			info.Kind, args[1], info.File, info.Body), nil
    52  	}},
    53  	{"def-source", 3, func(index *Index, args []string) (string, error) {
    54  		info, err := index.DefinitionSource(args[0], args[1], args[2] == "yes")
    55  		if err != nil || info == nil {
    56  			return notFound, err
    57  		}
    58  		return fmt.Sprintf("%v %v is defined in %v:\n\n%v", info.Kind, args[1], info.File, info.Body), nil
    59  	}},
    60  }
    61  
// notFound is the output of the commands in Commands when the requested
// file or entity is not present in the index.
const notFound = "not found\n"
    63  
    64  func NewIndex(databaseFile string, srcDirs []string) (*Index, error) {
    65  	db, err := osutil.ReadJSON[*Database](databaseFile)
    66  	if err != nil {
    67  		return nil, err
    68  	}
    69  	return &Index{
    70  		db:      db,
    71  		srcDirs: srcDirs,
    72  	}, nil
    73  }
    74  
    75  func (index *Index) Command(cmd string, args []string) (string, error) {
    76  	for _, meta := range Commands {
    77  		if cmd == meta.Name {
    78  			if len(args) != meta.NArgs {
    79  				return "", fmt.Errorf("codesearch command %v requires %v args, but %v provided",
    80  					cmd, meta.NArgs, len(args))
    81  			}
    82  			return meta.Func(index, args)
    83  		}
    84  	}
    85  	return "", fmt.Errorf("unknown codesearch command %v", cmd)
    86  }
    87  
// Entity identifies a definition listed by Index.FileIndex.
type Entity struct {
	// Kind is the kind of the definition as recorded in the database.
	Kind string
	// Name is the name of the definition.
	Name string
}
    92  
    93  func (index *Index) FileIndex(file string) (bool, []Entity, error) {
    94  	var entities []Entity
    95  	for _, def := range index.db.Definitions {
    96  		if def.Body.File == file {
    97  			entities = append(entities, Entity{
    98  				Kind: def.Kind,
    99  				Name: def.Name,
   100  			})
   101  		}
   102  	}
   103  	return len(entities) != 0, entities, nil
   104  }
   105  
// EntityInfo is the result of a definition lookup.
type EntityInfo struct {
	// File is the file that contains the body of the definition.
	File string
	// Kind is the kind of the definition as recorded in the database.
	Kind string
	// Body is the requested source text: the definition's comment for
	// DefinitionComment, or its body for DefinitionSource. It is empty
	// when the corresponding line range has no file recorded.
	Body string
}
   111  
   112  func (index *Index) DefinitionComment(contextFile, name string) (*EntityInfo, error) {
   113  	return index.definitionSource(contextFile, name, true, false)
   114  }
   115  
   116  func (index *Index) DefinitionSource(contextFile, name string, includeLines bool) (*EntityInfo, error) {
   117  	return index.definitionSource(contextFile, name, false, includeLines)
   118  }
   119  
   120  func (index *Index) definitionSource(contextFile, name string, comment, includeLines bool) (*EntityInfo, error) {
   121  	def := index.findDefinition(contextFile, name)
   122  	if def == nil {
   123  		return nil, nil
   124  	}
   125  	lineRange := def.Body
   126  	if comment {
   127  		lineRange = def.Comment
   128  	}
   129  	src, err := index.formatSource(lineRange, includeLines)
   130  	if err != nil {
   131  		return nil, err
   132  	}
   133  	return &EntityInfo{
   134  		File: def.Body.File,
   135  		Kind: def.Kind,
   136  		Body: src,
   137  	}, nil
   138  }
   139  
   140  func (index *Index) findDefinition(contextFile, name string) *Definition {
   141  	var weakMatch *Definition
   142  	for _, def := range index.db.Definitions {
   143  		if def.Name == name {
   144  			if def.Body.File == contextFile {
   145  				return def
   146  			}
   147  			if !def.IsStatic {
   148  				weakMatch = def
   149  			}
   150  		}
   151  	}
   152  	return weakMatch
   153  }
   154  
   155  func (index *Index) formatSource(lines LineRange, includeLines bool) (string, error) {
   156  	if lines.File == "" {
   157  		return "", nil
   158  	}
   159  	for _, dir := range index.srcDirs {
   160  		file := filepath.Join(dir, lines.File)
   161  		if !osutil.IsExist(file) {
   162  			continue
   163  		}
   164  		return formatSourceFile(file, lines.StartLine, lines.EndLine, includeLines)
   165  	}
   166  	return "", fmt.Errorf("codesearch: can't find %q file in any of %v", lines.File, index.srcDirs)
   167  }
   168  
   169  func formatSourceFile(file string, start, end int, includeLines bool) (string, error) {
   170  	data, err := os.ReadFile(file)
   171  	if err != nil {
   172  		return "", err
   173  	}
   174  	lines := bytes.Split(data, []byte{'\n'})
   175  	start--
   176  	end--
   177  	if start < 0 || end < start || end > len(lines) {
   178  		return "", fmt.Errorf("codesearch: bad line range [%v-%v] for file %v with %v lines",
   179  			start, end, file, len(lines))
   180  	}
   181  	b := new(strings.Builder)
   182  	for line := start; line <= end; line++ {
   183  		if includeLines {
   184  			fmt.Fprintf(b, "%4v:\t%s\n", line, lines[line])
   185  		} else {
   186  			fmt.Fprintf(b, "%s\n", lines[line])
   187  		}
   188  	}
   189  	return b.String(), nil
   190  }