github.com/streamingfast/substreams@v1.6.2/info/proto.go (about)

     1  package info
     2  
     3  import (
     4  	"fmt"
     5  	"regexp"
     6  	"strings"
     7  
     8  	"github.com/jhump/protoreflect/desc"
     9  	"github.com/jhump/protoreflect/desc/protoprint"
    10  	"google.golang.org/protobuf/types/descriptorpb"
    11  )
    12  
    13  type ProtoPackageParser struct {
    14  	allFiles            []*descriptorpb.FileDescriptorProto
    15  	nestedMessagesAdded map[string]bool
    16  
    17  	fileCodeMap    map[string]string
    18  	filePacakgeMap map[string]string
    19  }
    20  
    21  func NewProtoPackageParser(files []*descriptorpb.FileDescriptorProto) (*ProtoPackageParser, error) {
    22  	p := &ProtoPackageParser{
    23  		allFiles:            files,
    24  		nestedMessagesAdded: make(map[string]bool),
    25  	}
    26  
    27  	desc, err := desc.CreateFileDescriptors(p.allFiles)
    28  	if err != nil {
    29  		return nil, err
    30  	}
    31  
    32  	printer := &protoprint.Printer{
    33  		Compact: true,
    34  	}
    35  	fileCodeMap := make(map[string]string)
    36  	filePackageMap := make(map[string]string)
    37  	for fd, d := range desc {
    38  		r, err := printer.PrintProtoToString(d)
    39  		if err != nil {
    40  			return nil, err
    41  		}
    42  		fileCodeMap[fd] = r
    43  		filePackageMap[fd] = d.GetPackage()
    44  	}
    45  	p.fileCodeMap = fileCodeMap
    46  	p.filePacakgeMap = filePackageMap
    47  
    48  	return p, nil
    49  }
    50  
    51  func (p *ProtoPackageParser) Parse() (map[string][]*ProtoMessageInfo, error) {
    52  	result := map[string][]*ProtoMessageInfo{}
    53  
    54  	for _, file := range p.allFiles {
    55  		result[file.GetPackage()] = append(result[file.GetPackage()], p.extractMessages(file, "", file.MessageType)...)
    56  
    57  		for _, enum := range file.GetEnumType() {
    58  			doc := getDocumentationForSymbol(file.GetSourceCodeInfo(), enum.GetName())
    59  			protoCode, err := extractEnumBlock(p.fileCodeMap[file.GetName()], enum.GetName())
    60  			if err != nil {
    61  				return nil, fmt.Errorf("extract message block: %w", err)
    62  			}
    63  			result[file.GetPackage()] = append(result[file.GetPackage()], &ProtoMessageInfo{
    64  				Name:          enum.GetName(),
    65  				Package:       file.GetPackage(),
    66  				Type:          "enum",
    67  				File:          file.GetName(),
    68  				Proto:         protoCode,
    69  				Documentation: doc,
    70  			})
    71  		}
    72  
    73  	}
    74  
    75  	return result, nil
    76  }
    77  
    78  func (p *ProtoPackageParser) extractMessages(file *descriptorpb.FileDescriptorProto, prefix string, messages []*descriptorpb.DescriptorProto) []*ProtoMessageInfo {
    79  	var results []*ProtoMessageInfo
    80  
    81  	for _, msg := range messages {
    82  		doc := getDocumentationForSymbol(file.GetSourceCodeInfo(), msg.GetName())
    83  		protoCode, err := extractMessageBlock(p.fileCodeMap[file.GetName()], msg.GetName())
    84  		if err != nil {
    85  			return nil
    86  		}
    87  
    88  		name := prefix + msg.GetName()
    89  		result := &ProtoMessageInfo{
    90  			Name:          name,
    91  			Package:       file.GetPackage(),
    92  			Type:          "Message",
    93  			File:          file.GetName(),
    94  			Proto:         protoCode,
    95  			Documentation: doc,
    96  		}
    97  		if len(msg.GetNestedType()) > 0 {
    98  			result.NestedMessages = append(result.NestedMessages, p.extractMessages(file, name+".", msg.GetNestedType())...)
    99  		}
   100  		if len(msg.GetEnumType()) > 0 {
   101  			result.NestedMessages = append(result.NestedMessages, p.extractEnums(file, name+".", msg.GetEnumType())...)
   102  		}
   103  		results = append(results, result)
   104  	}
   105  
   106  	return results
   107  }
   108  
   109  func (p *ProtoPackageParser) extractEnums(file *descriptorpb.FileDescriptorProto, prefix string, enums []*descriptorpb.EnumDescriptorProto) []*ProtoMessageInfo {
   110  	var results []*ProtoMessageInfo
   111  
   112  	for _, enum := range enums {
   113  		doc := getDocumentationForSymbol(file.GetSourceCodeInfo(), enum.GetName())
   114  		protoCode, err := extractEnumBlock(p.fileCodeMap[file.GetName()], enum.GetName())
   115  		if err != nil {
   116  			return nil
   117  		}
   118  
   119  		name := prefix + enum.GetName()
   120  		results = append(results, &ProtoMessageInfo{
   121  			Name:          name,
   122  			Package:       file.GetPackage(),
   123  			Type:          "Enum",
   124  			File:          file.GetName(),
   125  			Proto:         protoCode,
   126  			Documentation: doc,
   127  		})
   128  	}
   129  
   130  	return results
   131  }
   132  
   133  func (p *ProtoPackageParser) GetPackagesList() []string {
   134  	packages := make(map[string]bool)
   135  	for _, file := range p.allFiles {
   136  		packages[file.GetPackage()] = true
   137  	}
   138  
   139  	var result []string
   140  	for pkg := range packages {
   141  		result = append(result, pkg)
   142  	}
   143  
   144  	return result
   145  }
   146  
   147  func (p *ProtoPackageParser) GetFilesSourceCode() map[string][]*SourceCodeInfo {
   148  	result := make(map[string][]*SourceCodeInfo)
   149  	for filename, pkg := range p.filePacakgeMap {
   150  		source := p.fileCodeMap[filename]
   151  		result[pkg] = append(result[pkg], &SourceCodeInfo{
   152  			Filename: filename,
   153  			Source:   source,
   154  		})
   155  	}
   156  
   157  	return result
   158  }
   159  
   160  // getDocumentationForSymbol extracts the leading comments associated with a named symbol (message/enum)
   161  func getDocumentationForSymbol(sourceInfo *descriptorpb.SourceCodeInfo, symbolName string) string {
   162  	for _, location := range sourceInfo.GetLocation() {
   163  		if strings.HasPrefix(strings.TrimSpace(location.GetLeadingComments()), symbolName) {
   164  			return strings.TrimSpace(location.GetLeadingComments())
   165  		}
   166  	}
   167  	return ""
   168  }
   169  
   170  func extractMessageBlock(protoContent, messageName string) (string, error) {
   171  	pattern := fmt.Sprintf(`(?s)message\s+%s\s+\{.*?\}`, messageName)
   172  	re := regexp.MustCompile(pattern)
   173  
   174  	matches := re.FindStringSubmatch(protoContent)
   175  	if matches == nil {
   176  		return "", fmt.Errorf("no message block found for message %q", messageName)
   177  	}
   178  
   179  	return matches[0], nil
   180  }
   181  
   182  func extractEnumBlock(protoContent, messageName string) (string, error) {
   183  	pattern := fmt.Sprintf(`(?s)enum\s+%s\s+\{.*?\}`, messageName)
   184  	re := regexp.MustCompile(pattern)
   185  
   186  	matches := re.FindStringSubmatch(protoContent)
   187  	if matches == nil {
   188  		return "", fmt.Errorf("no message block found for enum %q", messageName)
   189  	}
   190  
   191  	return matches[0], nil
   192  }