github.com/xaionaro-go/rand@v0.0.0-20191005105903-aba1befc54a5/internal/autogen/parse_methods.go (about)

     1  package main
     2  
     3  import (
     4  	"fmt"
     5  	"go/ast"
     6  	"go/parser"
     7  	"go/token"
     8  	"strconv"
     9  	"strings"
    10  )
    11  
    12  type Method struct {
    13  	Name           string
    14  	InitCode       string
    15  	GetValueCode   string
    16  	ResultVariable string
    17  	FinishCode     string
    18  	ResultSize     uint
    19  	AdditionalInfo string
    20  
    21  	genInfo           *genInfo
    22  	initCodeLines     []string
    23  	getValueCodeLines []string
    24  	finishCodeLines   []string
    25  	variableExists    map[string]struct{}
    26  }
    27  type Methods []*Method
    28  
    29  func ParseMethods() (methods Methods, err error) {
    30  	for _, filePath := range []string{`uint64.go`, `uint32.go`} {
    31  		var newMethods Methods
    32  		newMethods, err = parseMethodsFromFile(filePath)
    33  		if err != nil {
    34  			return
    35  		}
    36  		methods = append(methods, newMethods...)
    37  	}
    38  
    39  	return
    40  }
    41  
    42  func parseMethodsFromFile(path string) (methods Methods, err error) {
    43  	fileSet := token.NewFileSet()
    44  	parsedFile, err := parser.ParseFile(fileSet, path, nil, parser.ParseComments)
    45  	if err != nil {
    46  		return
    47  	}
    48  	var genInfo *genInfo
    49  	for _, decl := range parsedFile.Decls {
    50  		switch decl := decl.(type) {
    51  		case *ast.GenDecl:
    52  			genInfo = parseGenInfo(decl)
    53  		case *ast.FuncDecl:
    54  			method := parseMethod(decl, genInfo)
    55  			if method == nil {
    56  				continue
    57  			}
    58  			methods = append(methods, method)
    59  		default:
    60  			panic(fmt.Sprintf("%T:%v", decl, decl))
    61  		}
    62  	}
    63  
    64  	return
    65  }
    66  
    67  type genInfo struct {
    68  	constMap map[string]ast.Expr
    69  	funcMap  map[string]*ast.FuncDecl
    70  }
    71  
    72  func parseGenInfo(genDecl *ast.GenDecl) (result *genInfo) {
    73  	result = &genInfo{
    74  		constMap: map[string]ast.Expr{},
    75  		funcMap:  map[string]*ast.FuncDecl{},
    76  	}
    77  	for _, spec := range genDecl.Specs {
    78  		switch spec := spec.(type) {
    79  		case *ast.ValueSpec:
    80  			if len(spec.Names) != 1 {
    81  				panic(spec.Names)
    82  			}
    83  			if len(spec.Values) != 1 {
    84  				panic(spec.Values)
    85  			}
    86  			result.constMap[spec.Names[0].String()] = spec.Values[0]
    87  		default:
    88  			panic(fmt.Sprintf("%T:%v\n", spec, spec))
    89  		}
    90  	}
    91  
    92  	return
    93  }
    94  
    95  func parseMethod(funcDecl *ast.FuncDecl, genInfo *genInfo) (method *Method) {
    96  	if funcDecl.Recv == nil { // it's not a method, just a function (in the global scope)
    97  		genInfo.funcMap[funcDecl.Name.String()] = funcDecl
    98  		return
    99  	}
   100  
   101  	recvPtrType, ok := funcDecl.Recv.List[0].Type.(*ast.StarExpr)
   102  	if !ok {
   103  		return
   104  	}
   105  	recvType, ok := recvPtrType.X.(*ast.Ident)
   106  	if !ok {
   107  		return
   108  	}
   109  	if recvType.String() != `PRNG` {
   110  		return
   111  	}
   112  
   113  	method = &Method{
   114  		Name:           funcDecl.Name.String(),
   115  		genInfo:        genInfo,
   116  		variableExists: map[string]struct{}{},
   117  	}
   118  
   119  	resultTypeName := funcDecl.Type.Results.List[0].Type.(*ast.Ident).String()
   120  	switch resultTypeName {
   121  	case `uint32`:
   122  		method.ResultSize = 4
   123  	case `uint64`:
   124  		method.ResultSize = 8
   125  	}
   126  
   127  	if funcDecl.Doc != nil {
   128  		var docs []string
   129  		for _, doc := range funcDecl.Doc.List {
   130  			docs = append(docs, strings.Trim(doc.Text, "/ "))
   131  		}
   132  		method.AdditionalInfo = strings.Join(docs, "\n// ")
   133  	}
   134  
   135  	if len(funcDecl.Type.Results.List[0].Names) == 1 {
   136  		method.ResultVariable = funcDecl.Type.Results.List[0].Names[0].String()
   137  		method.initCodeLines = append(method.initCodeLines, `var `+method.ResultVariable+` `+resultTypeName)
   138  	}
   139  
   140  	method.addBodyStmt(funcDecl.Body)
   141  
   142  	method.compile()
   143  	return
   144  }
   145  
   146  func (m *Method) addBodyStmt(body *ast.BlockStmt) {
   147  	for _, decl := range body.List {
   148  		switch decl := decl.(type) {
   149  		case *ast.AssignStmt:
   150  			m.addAssignStmt(decl)
   151  		case *ast.ReturnStmt:
   152  			m.addReturnStmt(decl)
   153  		case *ast.IfStmt:
   154  			m.addIfStmt(decl)
   155  		default:
   156  			panic(fmt.Sprintf("%T", decl))
   157  		}
   158  	}
   159  }
   160  
   161  func (m *Method) compile() {
   162  	m.InitCode = strings.Join(m.initCodeLines, "\n\t\t")
   163  	m.GetValueCode = strings.Join(m.getValueCodeLines, "\n\t\t\t")
   164  	m.FinishCode = strings.Join(m.finishCodeLines, "\n\t\t")
   165  }
   166  
   167  func (m *Method) addReturnStmt(stmt *ast.ReturnStmt) {
   168  	if len(stmt.Results) == 0 {
   169  		return
   170  	}
   171  	if len(stmt.Results) != 1 {
   172  		panic(stmt.Results)
   173  	}
   174  	m.ResultVariable = m.getCodeString(stmt.Results[0])
   175  }
   176  
   177  func (m *Method) addAssignStmt(stmt *ast.AssignStmt) {
   178  	if len(stmt.Lhs) != 1 {
   179  		panic(len(stmt.Lhs))
   180  	}
   181  	if len(stmt.Rhs) != 1 {
   182  		panic(len(stmt.Rhs))
   183  	}
   184  	l := m.getCodeString(stmt.Lhs[0])
   185  	r := m.getCodeString(stmt.Rhs[0])
   186  	m.getValueCodeLines = append(m.getValueCodeLines, l+` `+stmt.Tok.String()+` `+r)
   187  }
   188  
   189  func (m *Method) addIfStmt(stmt *ast.IfStmt) {
   190  	m.getValueCodeLines = append(m.getValueCodeLines, `if ` + m.getCodeString(stmt.Cond)+` {`)
   191  	s := m.getValueCodeLines
   192  	m.addBodyStmt(stmt.Body)
   193  	for idx := range m.getValueCodeLines[len(s):] {
   194  		m.getValueCodeLines[len(s)+idx] = "\t" + m.getValueCodeLines[len(s)+idx]
   195  	}
   196  	m.getValueCodeLines = append(m.getValueCodeLines, `}`)
   197  }
   198  
   199  func (m *Method) getCodeString(expr ast.Expr) string {
   200  	switch expr := expr.(type) {
   201  	case *ast.Ident:
   202  		v := expr.String()
   203  		if expr.Obj != nil && expr.Obj.Kind.String() == `var` {
   204  			m.considerVariable(v)
   205  		}
   206  		return v
   207  	case *ast.SelectorExpr:
   208  		if expr.X.(*ast.Ident).String() != `prng` {
   209  			panic(expr.X)
   210  		}
   211  		v := expr.Sel.String() + "Temp"
   212  		m.considerVariable(v)
   213  		return v
   214  	case *ast.IndexExpr:
   215  		variable := m.getCodeString(expr.X) + strings.ToTitle(m.getCodeString(expr.Index))
   216  		m.considerVariable(variable)
   217  		return variable
   218  	case *ast.CallExpr:
   219  		var args []string
   220  		for _, arg := range expr.Args {
   221  			args = append(args, m.getCodeString(arg))
   222  		}
   223  		return m.getCodeString(expr.Fun) + `(` + strings.Join(args, `, `) + `)`
   224  	case *ast.BasicLit:
   225  		return expr.Value
   226  	case *ast.BinaryExpr:
   227  		y := m.getCodeString(expr.Y)
   228  		return m.getCodeString(expr.X) + ` ` + expr.Op.String() + ` ` + y
   229  	default:
   230  		panic(fmt.Sprintf("%T", expr))
   231  	}
   232  }
   233  
   234  func (m *Method) considerVariable(variable string) {
   235  	if _, ok := m.variableExists[variable]; ok {
   236  		return
   237  	}
   238  	m.variableExists[variable] = struct{}{}
   239  	/*if m.genInfo.constMap[variable] != nil {
   240  		return
   241  	}*/
   242  	if m.genInfo.funcMap[variable] != nil {
   243  		return
   244  	}
   245  	if strings.HasPrefix(variable, `state64Temp`) {
   246  		idxString := variable[len(`state64Temp`):]
   247  		if len(idxString) == 0 {
   248  			return
   249  		}
   250  		idx, err := strconv.ParseUint(idxString, 10, 64)
   251  		if err != nil {
   252  			panic(variable)
   253  		}
   254  		m.initCodeLines = append(m.initCodeLines, fmt.Sprintf(`%v := prng.state64[%v]`, variable, idx))
   255  		m.finishCodeLines = append(m.finishCodeLines, fmt.Sprintf(`prng.state64[%v] = %v`, idx, variable))
   256  		return
   257  	}
   258  	if strings.HasPrefix(variable, `state32Temp`) {
   259  		idxString := variable[len(`state32Temp`):]
   260  		if len(idxString) == 0 {
   261  			return
   262  		}
   263  		idx, err := strconv.ParseUint(idxString, 10, 64)
   264  		if err != nil {
   265  			panic(err)
   266  		}
   267  		m.initCodeLines = append(m.initCodeLines, fmt.Sprintf(`%v := prng.state32[%v]`, variable, idx))
   268  		m.finishCodeLines = append(m.finishCodeLines, fmt.Sprintf(`prng.state32[%v] = %v`, idx, variable))
   269  		return
   270  	}
   271  	if variable == `pcgStateTemp` {
   272  		m.initCodeLines = append(m.initCodeLines, fmt.Sprintf(`%v := prng.pcgState`, variable))
   273  		m.finishCodeLines = append(m.finishCodeLines, fmt.Sprintf(`prng.pcgState = %v`, variable))
   274  	}
   275  }