github.com/clubpay/ronykit/kit@v0.14.4-0.20240515065620-d0dace45cbc7/desc/parser.go (about)

     1  package desc
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  	"reflect"
     7  	"sort"
     8  	"strings"
     9  
    10  	"github.com/clubpay/ronykit/kit"
    11  	"github.com/clubpay/ronykit/kit/utils"
    12  )
    13  
// ParsedService is the outcome of parsing a Service descriptor: the original
// descriptor, the flattened list of parsed contracts and the bookkeeping the
// parser needs while walking message types.
type ParsedService struct {
	// Origin is the original service descriptor untouched by the parser
	Origin *Service
	// Contracts is the list of parsed contracts. The relation between ParsedContract
	// and Contract is not 1:1 because a Contract can have multiple RouteSelectors.
	// Each RouteSelector will be parsed into a ParsedContract.
	Contracts []ParsedContract

	// internals
	//
	// visited holds the name of every struct type parseMessage has started to
	// parse; parsed maps a struct type name to its message once parsing of that
	// struct has finished. Both are keyed by the bare type name.
	visited map[string]struct{}
	parsed  map[string]*ParsedMessage
}
    26  
    27  func (ps *ParsedService) Messages() []ParsedMessage {
    28  	var msgs []ParsedMessage //nolint:prealloc
    29  	for _, m := range ps.parsed {
    30  		msgs = append(msgs, *m)
    31  	}
    32  
    33  	sort.Slice(msgs, func(i, j int) bool {
    34  		return msgs[i].Name < msgs[j].Name
    35  	})
    36  
    37  	return msgs
    38  }
    39  
    40  func (ps *ParsedService) parseContract(c Contract) []ParsedContract {
    41  	var pcs []ParsedContract //nolint:prealloc
    42  	for idx, s := range c.RouteSelectors {
    43  		name := s.Name
    44  		if name == "" {
    45  			name = c.Name
    46  		}
    47  		pc := ParsedContract{
    48  			Index:     idx,
    49  			GroupName: c.Name,
    50  			Name:      name,
    51  			Encoding:  s.Selector.GetEncoding().Tag(),
    52  		}
    53  
    54  		switch r := s.Selector.(type) {
    55  		case kit.RESTRouteSelector:
    56  			pc.Type = REST
    57  			pc.Path = r.GetPath()
    58  			pc.Method = r.GetMethod()
    59  
    60  			for _, p := range strings.Split(pc.Path, "/") {
    61  				if strings.HasPrefix(p, ":") {
    62  					pc.PathParams = append(pc.PathParams, p[1:])
    63  				}
    64  			}
    65  		case kit.RPCRouteSelector:
    66  			pc.Type = RPC
    67  			pc.Predicate = r.GetPredicate()
    68  		}
    69  
    70  		pc.Request = ParsedRequest{
    71  			Headers: c.InputHeaders,
    72  			Message: ps.parseMessage(c.Input, s.Selector.GetEncoding()),
    73  		}
    74  
    75  		if c.Output != nil {
    76  			pc.Responses = append(
    77  				pc.Responses,
    78  				ParsedResponse{
    79  					Message: ps.parseMessage(c.Output, s.Selector.GetEncoding()),
    80  				},
    81  			)
    82  		}
    83  
    84  		for _, e := range c.PossibleErrors {
    85  			pc.Responses = append(
    86  				pc.Responses,
    87  				ParsedResponse{
    88  					Message: ps.parseMessage(e.Message, s.Selector.GetEncoding()),
    89  					ErrCode: e.Code,
    90  					ErrItem: e.Item,
    91  				},
    92  			)
    93  		}
    94  
    95  		pcs = append(pcs, pc)
    96  	}
    97  
    98  	return pcs
    99  }
   100  
   101  func (ps *ParsedService) parseMessage(m kit.Message, enc kit.Encoding) ParsedMessage {
   102  	mt := reflect.TypeOf(m)
   103  	if mt.Kind() == reflect.Ptr {
   104  		mt = mt.Elem()
   105  	}
   106  
   107  	pm := ParsedMessage{
   108  		Name: mt.Name(),
   109  	}
   110  
   111  	if mt.Kind() != reflect.Struct {
   112  		return pm
   113  	}
   114  
   115  	ps.visited[mt.Name()] = struct{}{}
   116  
   117  	tagName := enc.Tag()
   118  	if tagName == "" {
   119  		tagName = kit.JSON.Tag()
   120  	}
   121  
   122  	// if we are here, it means that mt is a struct
   123  	var fields []ParsedField
   124  	for i := 0; i < mt.NumField(); i++ {
   125  		f := mt.Field(i)
   126  		ptn := getParsedStructTag(f.Tag, tagName)
   127  		pp := ParsedField{
   128  			GoName: f.Name,
   129  			Name:   ptn.Name,
   130  			Tag:    ptn,
   131  		}
   132  
   133  		ft := f.Type
   134  		if ft.Kind() == reflect.Ptr {
   135  			pp.Optional = true
   136  			ft = ft.Elem()
   137  		}
   138  
   139  		pp.Embedded = f.Anonymous
   140  		pp.Kind = parseKind(ft)
   141  		pp.Type = ft.String()
   142  		switch pp.Kind {
   143  		case Map:
   144  			// we only support maps with string keys
   145  			if ft.Key().Kind() != reflect.String {
   146  				continue
   147  			}
   148  
   149  			fallthrough
   150  		case Array:
   151  			pe := &ParsedElement{}
   152  			pp.Element = pe
   153  
   154  			keepGoing := true
   155  			for keepGoing {
   156  				ft = ft.Elem()
   157  				pe.Kind = parseKind(ft)
   158  				switch pe.Kind {
   159  				case Map, Array:
   160  					pe.Element = &ParsedElement{}
   161  					pe = pe.Element
   162  				case Object:
   163  					if ft.Kind() == reflect.Ptr {
   164  						ft = ft.Elem()
   165  					}
   166  					pe.Message = utils.ValPtr(ps.parseMessage(reflect.New(ft).Interface(), enc))
   167  					keepGoing = false
   168  				default:
   169  					keepGoing = false
   170  				}
   171  			}
   172  		case Object:
   173  			if ps.isParsed(ft.Name()) {
   174  				pp.Message = ps.parsed[ft.Name()]
   175  			} else if ps.isVisited(ft.Name()) {
   176  				panic(fmt.Sprintf("infinite recursion detected: %s.%s", mt.Name(), ft.Name()))
   177  			} else {
   178  				pp.Message = utils.ValPtr(ps.parseMessage(reflect.New(ft).Interface(), enc))
   179  			}
   180  
   181  		case None:
   182  			continue
   183  		}
   184  
   185  		fields = append(fields, pp)
   186  	}
   187  
   188  	pm.Fields = fields
   189  	ps.parsed[mt.Name()] = &pm
   190  
   191  	return pm
   192  }
   193  
   194  func (ps *ParsedService) isParsed(name string) bool {
   195  	_, ok := ps.parsed[name]
   196  
   197  	return ok
   198  }
   199  
   200  func (ps *ParsedService) isVisited(name string) bool {
   201  	_, ok := ps.visited[name]
   202  
   203  	return ok
   204  }
   205  
// ContractType tells which kind of route selector a ParsedContract was built
// from.
type ContractType string

// Contract types assigned by parseContract: REST for selectors that implement
// kit.RESTRouteSelector and RPC for selectors that implement
// kit.RPCRouteSelector.
const (
	REST ContractType = "REST"
	RPC  ContractType = "RPC"
)
   212  
// ParsedContract describes a single route selector of a Contract together with
// the request and the responses of that contract.
//
// Index is the position of the selector within the contract's RouteSelectors,
// GroupName is the contract name, Name is the selector name (or the contract
// name when the selector has none) and Encoding is the tag of the selector's
// encoding.
//
// Type tells which routing fields are filled in: Path, PathParams and Method
// for REST selectors, Predicate for RPC selectors. PathParams holds the path
// segments that start with ":", with that prefix removed.
//
// Request is built from the contract input. Responses holds the contract
// output (when set) followed by one entry per possible error.
type ParsedContract struct {
	Index     int
	GroupName string
	Name      string
	Encoding  string

	Type       ContractType
	Path       string
	PathParams []string
	Method     string
	Predicate  string

	Request   ParsedRequest
	Responses []ParsedResponse
}
   228  
   229  func (pc ParsedContract) SuggestName() string {
   230  	if pc.Name != "" {
   231  		return pc.Name
   232  	}
   233  
   234  	switch pc.Type {
   235  	case REST:
   236  		parts := strings.Split(pc.Path, "/")
   237  		for i := len(parts) - 1; i >= 0; i-- {
   238  			if strings.HasPrefix(parts[i], ":") {
   239  				continue
   240  			}
   241  
   242  			return utils.ToCamel(parts[i])
   243  		}
   244  	case RPC:
   245  		return utils.ToCamel(pc.Predicate)
   246  	}
   247  
   248  	return fmt.Sprintf("%s%d", pc.GroupName, pc.Index)
   249  }
   250  
   251  func (pc ParsedContract) OKResponse() ParsedResponse {
   252  	for _, r := range pc.Responses {
   253  		if !r.IsError() {
   254  			return r
   255  		}
   256  	}
   257  
   258  	return ParsedResponse{}
   259  }
   260  
   261  func (pc ParsedContract) IsPathParam(name string) bool {
   262  	for _, p := range pc.PathParams {
   263  		if p == name {
   264  			return true
   265  		}
   266  	}
   267  
   268  	return false
   269  }
   270  
// ParsedRequest is the request side of a ParsedContract: the input headers
// declared on the contract and the parsed input message.
type ParsedRequest struct {
	Headers []Header
	Message ParsedMessage
}
   275  
// ParsedResponse is one possible response of a ParsedContract. For the
// contract output only Message is set; for a possible error ErrCode and
// ErrItem carry the code and item declared on the contract.
type ParsedResponse struct {
	Message ParsedMessage
	ErrCode int
	ErrItem string
}
   281  
   282  func (pr ParsedResponse) IsError() bool {
   283  	return pr.ErrCode != 0
   284  }
   285  
// Kind is the parser's coarse classification of a Go type.
type Kind string

// Kinds produced by parseKind after dereferencing pointers:
//   - Bool for bool and String for string
//   - Byte for uint8 and int8
//   - Integer for the remaining signed and unsigned integer types handled there
//   - Float for float32 and float64
//   - Object for structs, Map for maps and Array for slices and arrays
//   - None for every other type
const (
	None    Kind = ""
	Bool    Kind = "boolean"
	String  Kind = "string"
	Integer Kind = "integer"
	Float   Kind = "float"
	Byte    Kind = "byte"
	Object  Kind = "object"
	Map     Kind = "map"
	Array   Kind = "array"
)
   299  
// ParsedMessage describes a message type: Name is the Go type name and Fields
// lists the parsed struct fields in declaration order (empty for non-struct
// messages).
//
// NOTE(review): parseMessage in this file never assigns Kind, so it stays at
// its zero value unless it is set elsewhere — confirm whether that is intended.
type ParsedMessage struct {
	Name   string
	Kind   Kind
	Fields []ParsedField
}
   305  
   306  func (pm ParsedMessage) JSON() string {
   307  	m := map[string]any{}
   308  	for _, p := range pm.Fields {
   309  		switch p.Kind {
   310  		default:
   311  			m[p.Name] = p.Kind
   312  		case Object:
   313  			m[p.Name] = json.RawMessage(p.Message.JSON())
   314  		case Map:
   315  			var inner any
   316  			switch p.Element.Kind {
   317  			default:
   318  				inner = p.Element.Kind
   319  			case Object:
   320  				inner = json.RawMessage(p.Element.Message.JSON())
   321  			case Integer, Float, Byte:
   322  				inner = 0
   323  			case Array:
   324  				inner = []any{p.Element.Element.Kind}
   325  			case Map:
   326  				inner = map[string]any{
   327  					"keyName": p.Element.Element.Kind,
   328  				}
   329  			}
   330  			m[p.Name] = map[string]any{
   331  				"keyName": inner,
   332  			}
   333  		case Array:
   334  			var inner any
   335  			switch p.Element.Kind {
   336  			default:
   337  				inner = p.Element.Kind
   338  			case Object:
   339  				inner = json.RawMessage(p.Element.Message.JSON())
   340  			case Integer, Float, Byte:
   341  				inner = 0
   342  			case Array:
   343  				inner = []any{p.Element.Element.Kind}
   344  			case Map:
   345  				inner = map[string]any{
   346  					"keyName": p.Element.Element.Kind,
   347  				}
   348  			}
   349  			m[p.Name] = []any{inner}
   350  		case Integer, Float, Byte:
   351  			m[p.Name] = 0
   352  		}
   353  	}
   354  
   355  	d, _ := json.MarshalIndent(m, "", "  ")
   356  
   357  	return string(d)
   358  }
   359  
   360  func (pm ParsedMessage) String() string {
   361  	sb := strings.Builder{}
   362  	sb.WriteString(pm.Name)
   363  	sb.WriteString("[")
   364  	for idx, p := range pm.Fields {
   365  		if idx > 0 {
   366  			sb.WriteString(", ")
   367  		}
   368  		sb.WriteString(p.Name)
   369  		sb.WriteString(":")
   370  		sb.WriteString(string(p.Kind))
   371  
   372  		switch p.Kind {
   373  		case Map, Array:
   374  			sb.WriteString(":")
   375  			sb.WriteString(p.Element.String())
   376  		case Object:
   377  			sb.WriteString(":")
   378  		}
   379  	}
   380  	sb.WriteString("]")
   381  
   382  	return sb.String()
   383  }
   384  
   385  func (pm ParsedMessage) FieldByName(name string) *ParsedField {
   386  	for _, f := range pm.Fields {
   387  		if f.Name == name {
   388  			return &f
   389  		}
   390  	}
   391  
   392  	return nil
   393  }
   394  
   395  func (pm ParsedMessage) FieldByGoName(name string) *ParsedField {
   396  	for _, f := range pm.Fields {
   397  		if f.GoName == name {
   398  			return &f
   399  		}
   400  	}
   401  
   402  	return nil
   403  }
   404  
// ParsedField describes one struct field of a parsed message.
//
// GoName is the Go field name, Name the name taken from the encoding struct
// tag and Tag the full parsed tag. Optional is set for pointer fields, Type is
// the string form of the field type after removing that pointer, Kind is its
// classification and Embedded marks anonymous (embedded) fields.
//
// NOTE(review): SampleValue is not assigned anywhere in this file; presumably
// it is filled in by other code — confirm.
type ParsedField struct {
	GoName      string
	Name        string
	Tag         ParsedStructTag
	SampleValue string
	Optional    bool
	Type        string
	Kind        Kind
	Embedded    bool

	// Kind == Object
	// Message is the parsed message if the kind is Object.
	Message *ParsedMessage
	// Kind == Array || Kind == Map
	// Element describes the element type if the kind is Array or Map.
	Element *ParsedElement
}
   421  
// ParsedElement describes the element type of a map, slice or array. Element
// is set when the element is itself a map or array, and Message when the
// element is an object.
type ParsedElement struct {
	Kind    Kind
	Element *ParsedElement
	Message *ParsedMessage
}
   427  
   428  func (pf ParsedElement) String() string {
   429  	switch pf.Kind {
   430  	case Map:
   431  		return fmt.Sprintf("map[%s]", pf.Element.String())
   432  	case Array:
   433  		return fmt.Sprintf("array[%s]", pf.Element.String())
   434  	case Object:
   435  		return pf.Message.String()
   436  	default:
   437  		return string(pf.Kind)
   438  	}
   439  }
   440  
   441  // Parse extracts the Service descriptor from the input ServiceDesc
   442  // Refer to ParseService for more details.
   443  func Parse(desc ServiceDesc) ParsedService {
   444  	return ParseService(desc.Desc())
   445  }
   446  
   447  // ParseService extracts information from a Service descriptor using reflection.
   448  // It returns a ParsedService. The ParsedService is useful to generate custom
   449  // code based on the service descriptor.
   450  // In the contrib package this is used to generate the swagger spec and postman collections.
   451  func ParseService(svc *Service) ParsedService {
   452  	// reset the parsed map
   453  	// we need this map, to prevent infinite recursion
   454  
   455  	pd := ParsedService{
   456  		Origin:  svc,
   457  		parsed:  make(map[string]*ParsedMessage),
   458  		visited: make(map[string]struct{}),
   459  	}
   460  
   461  	for _, c := range svc.Contracts {
   462  		pd.Contracts = append(pd.Contracts, pd.parseContract(c)...)
   463  	}
   464  
   465  	return pd
   466  }
   467  
   468  func parseKind(t reflect.Type) Kind {
   469  	for t.Kind() == reflect.Ptr {
   470  		t = t.Elem()
   471  	}
   472  
   473  	switch t.Kind() {
   474  	default:
   475  	case reflect.Bool:
   476  		return Bool
   477  	case reflect.String:
   478  		return String
   479  	case reflect.Uint8, reflect.Int8:
   480  		return Byte
   481  	case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64,
   482  		reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   483  		return Integer
   484  	case reflect.Float32, reflect.Float64:
   485  		return Float
   486  	case reflect.Struct:
   487  		return Object
   488  	case reflect.Map:
   489  		return Map
   490  	case reflect.Slice, reflect.Array:
   491  		return Array
   492  	}
   493  
   494  	return None
   495  }
   496  
   497  const (
   498  	swagTagKey   = "swag"
   499  	swagSep      = ";"
   500  	swagIdentSep = ":"
   501  	swagValueSep = ","
   502  )
   503  
   504  type ParsedStructTag struct {
   505  	Name           string
   506  	Optional       bool
   507  	PossibleValues []string
   508  	Deprecated     bool
   509  }
   510  
   511  func getParsedStructTag(tag reflect.StructTag, name string) ParsedStructTag {
   512  	pst := ParsedStructTag{}
   513  	nameTag := tag.Get(name)
   514  	if nameTag == "" {
   515  		return pst
   516  	}
   517  
   518  	// This is a hack to remove omitempty from tags
   519  	fNameParts := strings.Split(nameTag, swagValueSep)
   520  	if len(fNameParts) > 0 {
   521  		pst.Name = strings.TrimSpace(fNameParts[0])
   522  	}
   523  
   524  	swagTag := tag.Get(swagTagKey)
   525  	parts := strings.Split(swagTag, swagSep)
   526  	for _, p := range parts {
   527  		x := strings.TrimSpace(strings.ToLower(p))
   528  		switch {
   529  		case x == "optional":
   530  			pst.Optional = true
   531  		case x == "deprecated":
   532  			pst.Deprecated = true
   533  		case strings.HasPrefix(x, "enum:"):
   534  			xx := strings.SplitN(p, swagIdentSep, 2)
   535  			if len(xx) == 2 {
   536  				xx = strings.Split(xx[1], swagValueSep)
   537  				for _, v := range xx {
   538  					pst.PossibleValues = append(pst.PossibleValues, strings.TrimSpace(v))
   539  				}
   540  			}
   541  		}
   542  	}
   543  
   544  	return pst
   545  }