github.com/cloudwego/dynamicgo@v0.2.6-0.20240519101509-707f41b6b834/proto/idl.go (about)

     1  package proto
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"math"
     7  
     8  	"github.com/cloudwego/dynamicgo/meta"
     9  	"github.com/jhump/protoreflect/desc"
    10  	"github.com/jhump/protoreflect/desc/protoparse"
    11  )
    12  
    13  type compilingInstance struct {
    14  	desc        *TypeDescriptor
    15  	opts        Options
    16  	parseTarget ParseTarget
    17  }
    18  
    19  type compilingCache map[string]*compilingInstance
    20  
    21  const (
    22  	Request ParseTarget = iota
    23  	Response
    24  	Exception
    25  )
    26  
    27  // ParseTarget indicates the target to parse
    28  type ParseTarget uint8
    29  
    30  // Options is options for parsing thrift IDL.
    31  type Options struct {
    32  	// ParseServiceMode indicates how to parse service.
    33  	ParseServiceMode meta.ParseServiceMode
    34  
    35  	MapFieldWay meta.MapFieldWay // not implemented.
    36  
    37  	ParseFieldRandomRate float64 // not implemented.
    38  
    39  	ParseEnumAsInt64 bool // not implemented.
    40  
    41  	SetOptionalBitmap bool // not implemented.
    42  
    43  	UseDefaultValue bool // not implemented.
    44  
    45  	ParseFunctionMode meta.ParseFunctionMode // not implemented.
    46  
    47  	EnableProtoBase bool // not implemented.
    48  }
    49  
    50  // NewDefaultOptions creates a default Options.
    51  func NewDefaultOptions() Options {
    52  	return Options{}
    53  }
    54  
    55  // NewDescritorFromPath behaviors like NewDescritorFromPath, besides it uses DefaultOptions.
    56  func NewDescritorFromPath(ctx context.Context, path string, importDirs ...string) (*ServiceDescriptor, error) {
    57  	return NewDefaultOptions().NewDescriptorFromPath(ctx, path, importDirs...)
    58  }
    59  
    60  // NewDescritorFromContent creates a ServiceDescriptor from a proto path and its imports, which uses the given options.
    61  // The importDirs is used to find the include files.
    62  func (opts Options) NewDescriptorFromPath(ctx context.Context, path string, importDirs ...string) (*ServiceDescriptor, error) {
    63  	var pbParser protoparse.Parser
    64  	ImportPaths := []string{""} // default import "" when path is absolute path, no need to join with importDirs
    65  	// append importDirs to ImportPaths
    66  	ImportPaths = append(ImportPaths, importDirs...)
    67  	pbParser.ImportPaths = ImportPaths
    68  	fds, err := pbParser.ParseFiles(path)
    69  	if err != nil {
    70  		return nil, err
    71  	}
    72  	fd := fds[0]
    73  	svc, err := parse(ctx, fd, opts.ParseServiceMode, opts)
    74  	if err != nil {
    75  		return nil, err
    76  	}
    77  	return svc, nil
    78  }
    79  
    80  // NewDescritorFromContent behaviors like NewDescritorFromPath, besides it uses DefaultOptions.
    81  func NewDescritorFromContent(ctx context.Context, path, content string, includes map[string]string, importDirs ...string) (*ServiceDescriptor, error) {
    82  	return NewDefaultOptions().NewDesccriptorFromContent(ctx, path, content, includes, importDirs...)
    83  }
    84  
    85  func (opts Options) NewDesccriptorFromContent(ctx context.Context, path, content string, includes map[string]string, importDirs ...string) (*ServiceDescriptor, error) {
    86  
    87  	var pbParser protoparse.Parser
    88  	// add main proto to includes
    89  	includes[path] = content
    90  
    91  	ImportPaths := []string{""} // default import "" when path is absolute path, no need to join with importDirs
    92  	// append importDirs to ImportPaths
    93  	ImportPaths = append(ImportPaths, importDirs...)
    94  
    95  	pbParser.ImportPaths = ImportPaths
    96  	pbParser.Accessor = protoparse.FileContentsFromMap(includes)
    97  	fds, err := pbParser.ParseFiles(path)
    98  	if err != nil {
    99  		return nil, err
   100  	}
   101  
   102  	fd := fds[0]
   103  	sdsc, err := parse(ctx, fd, opts.ParseServiceMode, opts)
   104  	if err != nil {
   105  		return nil, err
   106  	}
   107  
   108  	return sdsc, nil
   109  }
   110  
   111  // Parse descriptor from fileDescriptor
   112  func parse(ctx context.Context, fileDesc *desc.FileDescriptor, mode meta.ParseServiceMode, opts Options, methods ...string) (*ServiceDescriptor, error) {
   113  	svcs := fileDesc.GetServices()
   114  	if len(svcs) == 0 {
   115  		return nil, errors.New("empty service from idls")
   116  	}
   117  
   118  	sDsc := &ServiceDescriptor{
   119  		methods: map[string]*MethodDescriptor{},
   120  	}
   121  
   122  	structsCache := compilingCache{}
   123  	// support one service
   124  	switch mode {
   125  	case meta.LastServiceOnly:
   126  		svcs = svcs[len(svcs)-1:]
   127  		sDsc.serviceName = svcs[len(svcs)-1].GetName()
   128  	case meta.FirstServiceOnly:
   129  		svcs = svcs[:1]
   130  		sDsc.serviceName = svcs[0].GetName()
   131  	case meta.CombineServices:
   132  		sDsc.serviceName = "CombinedService"
   133  	}
   134  
   135  	for _, svc := range svcs {
   136  		for _, mtd := range svc.GetMethods() {
   137  			var req *TypeDescriptor
   138  			var resp *TypeDescriptor
   139  			var err error
   140  
   141  			req, err = parseMessage(ctx, mtd.GetInputType(), structsCache, 0, opts, Request)
   142  			if err != nil {
   143  				return nil, err
   144  			}
   145  
   146  			resp, err = parseMessage(ctx, mtd.GetOutputType(), structsCache, 0, opts, Response)
   147  			if err != nil {
   148  				return nil, err
   149  			}
   150  
   151  			sDsc.methods[mtd.GetName()] = &MethodDescriptor{
   152  				name:   mtd.GetName(),
   153  				input:  req,
   154  				output: resp,
   155  			}
   156  
   157  		}
   158  	}
   159  	return sDsc, nil
   160  }
   161  
   162  func parseMessage(ctx context.Context, msgDesc *desc.MessageDescriptor, cache compilingCache, recursionDepth int, opts Options, parseTarget ParseTarget) (*TypeDescriptor, error) {
   163  	if tycache, ok := cache[msgDesc.GetName()]; ok && tycache.parseTarget == parseTarget {
   164  		return tycache.desc, nil
   165  	}
   166  
   167  	var ty *TypeDescriptor
   168  	var err error
   169  	fields := msgDesc.GetFields()
   170  	md := &MessageDescriptor{
   171  		baseId:    FieldNumber(math.MaxInt32),
   172  		ids:       FieldNumberMap{},
   173  		names:     FieldNameMap{},
   174  	}
   175  
   176  	ty = &TypeDescriptor{
   177  		typ:  MESSAGE,
   178  		name: msgDesc.GetName(),
   179  		msg:  md,
   180  	}
   181  
   182  	cache[ty.name] = &compilingInstance{
   183  		desc:        ty,
   184  		opts:        opts,
   185  		parseTarget: parseTarget,
   186  	}
   187  
   188  	for _, field := range fields {
   189  		descpbType := field.GetType()
   190  		id := field.GetNumber()
   191  		name := field.GetName()
   192  		jsonName := field.GetJSONName()
   193  		fieldDesc := &FieldDescriptor{
   194  			id:       FieldNumber(id),
   195  			name:     name,
   196  			jsonName: jsonName,
   197  		}
   198  
   199  		// MAP TypeDescriptor
   200  		if field.IsMap() {
   201  			kt := builtinTypes[field.GetMapKeyType().GetType()]
   202  			vt := builtinTypes[field.GetMapValueType().GetType()]
   203  			if vt.Type() == MESSAGE {
   204  				vt, err = parseMessage(ctx, field.GetMapValueType().GetMessageType(), cache, recursionDepth+1, opts, parseTarget)
   205  				if err != nil {
   206  					return nil, err
   207  				}
   208  			}
   209  
   210  			mapt, err := parseMessage(ctx, field.GetMessageType(), cache, recursionDepth+1, opts, parseTarget)
   211  			if err != nil {
   212  				return nil, err
   213  			}
   214  
   215  			fieldDesc.typ = &TypeDescriptor{
   216  				typ:    MAP,
   217  				name:   name,
   218  				key:    kt,
   219  				elem:   vt,
   220  				baseId: FieldNumber(id),
   221  				msg:    mapt.msg,
   222  			}
   223  			fieldDesc.kind = MessageKind
   224  		} else {
   225  			// basic type or message TypeDescriptor
   226  			t := builtinTypes[descpbType]
   227  			if t.Type() == MESSAGE {
   228  				t, err = parseMessage(ctx, field.GetMessageType(), cache, recursionDepth+1, opts, parseTarget)
   229  				if err != nil {
   230  					return nil, err
   231  				}
   232  			}
   233  			fieldDesc.kind = t.typ.TypeToKind()
   234  
   235  			// LIST TypeDescriptor
   236  			if field.IsRepeated() {
   237  				t = &TypeDescriptor{
   238  					typ:    LIST,
   239  					name:   name,
   240  					elem:   t,
   241  					baseId: FieldNumber(id),
   242  					msg:    t.msg,
   243  				}
   244  			}
   245  			fieldDesc.typ = t
   246  		}
   247  
   248  		// add fieldDescriptor to MessageDescriptor
   249  		// md.ids[FieldNumber(id)] = fieldDesc
   250  		md.ids.Set(FieldNumber(id), fieldDesc)
   251  		md.names.Set(name, fieldDesc)
   252  		md.names.Set(jsonName, fieldDesc)
   253  	}
   254  	md.names.Build()
   255  
   256  	return ty, nil
   257  }