github.com/mailru/activerecord@v1.12.2/internal/pkg/parser/partialstruct.go (about)

     1  package parser
     2  
     3  import (
     4  	"fmt"
     5  	"go/ast"
     6  	"go/parser"
     7  	"go/token"
     8  	"path/filepath"
     9  
    10  	"github.com/mailru/activerecord/internal/pkg/arerror"
    11  	"github.com/mailru/activerecord/internal/pkg/ds"
    12  )
    13  
    14  func parseStructFields(dst *ds.RecordPackage, gen *ast.GenDecl, name, pkgName string) ([]ds.PartialFieldDeclaration, error) {
    15  	for _, spec := range gen.Specs {
    16  		currType, ok := spec.(*ast.TypeSpec)
    17  		if !ok {
    18  			continue
    19  		}
    20  
    21  		switch curr := currType.Type.(type) {
    22  		case *ast.StructType:
    23  			if currType.Name.Name != name {
    24  				continue
    25  			}
    26  
    27  			if curr.Fields == nil {
    28  				return nil, &arerror.ErrParseTypeStructDecl{Name: currType.Name.Name, Err: arerror.ErrParseStructureEmpty}
    29  			}
    30  
    31  			partialFields := make([]ds.PartialFieldDeclaration, 0, len(curr.Fields.List))
    32  
    33  			for _, field := range curr.Fields.List {
    34  				if len(field.Names) == 0 {
    35  					continue
    36  				}
    37  
    38  				t, err := ParseFieldType(dst, name, pkgName, field.Type)
    39  				if err != nil {
    40  					return nil, &arerror.ErrParseTypeFieldStructDecl{Name: name, FieldType: field.Names[0].Name, Err: err}
    41  				}
    42  
    43  				field := ds.PartialFieldDeclaration{
    44  					Name: field.Names[0].Name,
    45  					Type: t,
    46  				}
    47  
    48  				partialFields = append(partialFields, field)
    49  			}
    50  
    51  			return partialFields, nil
    52  		}
    53  	}
    54  
    55  	return nil, nil
    56  }
    57  
    58  func ParsePartialStructFields(dst *ds.RecordPackage, name, pkgName, path string) ([]ds.PartialFieldDeclaration, error) {
    59  	relPath, err := filepath.Rel(dst.Namespace.ModuleName, path)
    60  	if err != nil {
    61  		return nil, fmt.Errorf("can't extract rel path of `%s` for module `%s`: %w", path, dst.Namespace.ModuleName, err)
    62  	}
    63  
    64  	pkgs, err := parser.ParseDir(token.NewFileSet(), relPath, nil, parser.DeclarationErrors)
    65  	if err != nil {
    66  		return nil, fmt.Errorf("error parse file `%s`: %w", path, err)
    67  	}
    68  
    69  	files := make(map[string]*ast.File)
    70  	for _, f := range pkgs[pkgName].Files {
    71  		for name, object := range f.Scope.Objects {
    72  			if object.Kind == ast.Typ {
    73  				files[name] = f
    74  			}
    75  		}
    76  	}
    77  
    78  	file, ok := files[name]
    79  	if !ok {
    80  		return nil, fmt.Errorf("can't find struct `%s` in package `%s`: %w", name, pkgName, err)
    81  	}
    82  
    83  	importPkg := ds.NewImportPackage()
    84  	for _, spec := range file.Imports {
    85  		if err = ParseImport(&importPkg, spec); err != nil {
    86  			return nil, fmt.Errorf("can't parse import from package file `%s`: %w", file.Name, err)
    87  		}
    88  	}
    89  
    90  	pkgDecl := ds.LinkedPackageDeclaration{
    91  		Types:  make(map[string]struct{}),
    92  		Import: importPkg,
    93  	}
    94  
    95  	for t := range files {
    96  		pkgDecl.Types[t] = struct{}{}
    97  	}
    98  
    99  	dst.LinkedStructsMap[pkgName] = pkgDecl
   100  
   101  	for _, decl := range file.Decls {
   102  		switch gen := decl.(type) {
   103  		case *ast.GenDecl:
   104  			if gen.Tok != token.TYPE {
   105  				continue
   106  			}
   107  			partialFields, genErr := parseStructFields(dst, gen, name, pkgName)
   108  			if genErr != nil {
   109  				return nil, &arerror.ErrParseGenDecl{Name: pkgName, Err: fmt.Errorf("error parse struct `%s` in package `%s`: %w", name, pkgName, genErr)}
   110  			}
   111  
   112  			if len(partialFields) == 0 {
   113  				continue
   114  			}
   115  
   116  			return partialFields, nil
   117  		}
   118  	}
   119  
   120  	return nil, nil
   121  }
   122  
   123  //nolint:gocognit
   124  func ParseFieldType(dst *ds.RecordPackage, name, pName string, t interface{}) (string, error) {
   125  	switch tv := t.(type) {
   126  	case *ast.Ident:
   127  		v := tv.String()
   128  
   129  		if ls, ok := dst.LinkedStructsMap[pName]; ok {
   130  			if _, ok := ls.Types[v]; ok {
   131  				return pName + "." + v, nil
   132  			}
   133  
   134  			// если импорта нет, то это простой тип
   135  			imp, err := ls.Import.FindImportByPkg(v)
   136  			if err != nil {
   137  				return v, nil
   138  			}
   139  
   140  			_, _ = dst.FindOrAddImport(imp.Path, imp.ImportName)
   141  		}
   142  
   143  		return v, nil
   144  	case *ast.ArrayType:
   145  		var err error
   146  
   147  		len := ""
   148  		if tv.Len != nil {
   149  			len, err = ParseFieldType(dst, name, "", tv.Len)
   150  			if err != nil {
   151  				return "", err
   152  			}
   153  		}
   154  
   155  		t, err := ParseFieldType(dst, name, pName, tv.Elt)
   156  		if err != nil {
   157  			return "", err
   158  		}
   159  
   160  		return "[" + len + "]" + t, nil
   161  	case *ast.InterfaceType:
   162  		return "interface{}", nil
   163  	case *ast.StarExpr:
   164  		t, err := ParseFieldType(dst, name, pName, tv.X)
   165  		if err != nil {
   166  			return "", err
   167  		}
   168  
   169  		return "*" + t, nil
   170  	case *ast.MapType:
   171  		k, err := ParseFieldType(dst, name, pName, tv.Key)
   172  		if err != nil {
   173  			return "", nil
   174  		}
   175  
   176  		v, err := ParseFieldType(dst, name, pName, tv.Value)
   177  		if err != nil {
   178  			return "", nil
   179  		}
   180  
   181  		return "map[" + k + "]" + v, nil
   182  	case *ast.SelectorExpr:
   183  		pkgName, err := ParseFieldType(dst, name, pName, tv.X)
   184  		if err != nil {
   185  			return "", err
   186  		}
   187  
   188  		imp, err := dst.FindImportByPkg(pkgName)
   189  		if err != nil {
   190  			return "", &arerror.ErrParseTypeStructDecl{Name: name, Err: err}
   191  		}
   192  
   193  		reqImportName := imp.ImportName
   194  		if reqImportName == "" {
   195  			reqImportName = pkgName
   196  		}
   197  
   198  		if _, ok := dst.ImportStructFieldsMap[reqImportName+"."+tv.Sel.Name]; !ok {
   199  			fieldDeclarations, err := ParsePartialStructFields(dst, tv.Sel.Name, pkgName, imp.Path)
   200  			if err != nil {
   201  				return "", &arerror.ErrParseTypeStructDecl{Name: name, Err: err}
   202  			}
   203  
   204  			dst.ImportStructFieldsMap[reqImportName+"."+tv.Sel.Name] = fieldDeclarations
   205  		}
   206  
   207  		return reqImportName + "." + tv.Sel.Name, nil
   208  	default:
   209  		return "", &arerror.ErrParseTypeStructDecl{Name: name, Err: arerror.ErrUnknown}
   210  	}
   211  }