github.com/amazechain/amc@v0.1.3/internal/vm/absint_cfg_proof_check.go (about)

     1  package vm
     2  
     3  import (
     4  	"errors"
     5  	"log"
     6  	"reflect"
     7  
     8  	"github.com/holiman/uint256"
     9  )
    10  
    11  type CfgOpSem struct {
    12  	isPush   bool
    13  	isDup    bool
    14  	isSwap   bool
    15  	numBytes int
    16  	opNum    int
    17  	numPush  int
    18  	numPop   int
    19  }
    20  
    21  type CfgAbsSem map[OpCode]*CfgOpSem
    22  
    23  func NewCfgAbsSem() *CfgAbsSem {
    24  	jt := newIstanbulInstructionSet()
    25  
    26  	sem := CfgAbsSem{}
    27  
    28  	for opcode, op := range jt {
    29  		if op == nil {
    30  			continue
    31  		}
    32  		opsem := CfgOpSem{}
    33  		opsem.isPush = op.isPush
    34  		opsem.isDup = op.isDup
    35  		opsem.isSwap = op.isSwap
    36  		opsem.opNum = op.opNum
    37  		opsem.numPush = op.numPush
    38  		opsem.numPop = op.numPop
    39  
    40  		if opsem.isPush {
    41  			opsem.numBytes = op.opNum + 1
    42  		} else {
    43  			opsem.numBytes = 1
    44  
    45  		}
    46  		sem[OpCode(opcode)] = &opsem
    47  	}
    48  
    49  	return &sem
    50  }
    51  
    52  func getPushValue(code []byte, pc int, opsem0 *CfgOpSem) uint256.Int {
    53  	pushByteSize := opsem0.opNum
    54  	startMin := pc + 1
    55  	if startMin >= len(code) {
    56  		startMin = len(code)
    57  	}
    58  	endMin := startMin + pushByteSize
    59  	if startMin+pushByteSize >= len(code) {
    60  		endMin = len(code)
    61  	}
    62  	integer := new(uint256.Int)
    63  	integer.SetBytes(code[startMin:endMin])
    64  	return *integer
    65  }
    66  
    67  func isJumpDest(code []byte, value *uint256.Int) bool {
    68  	if !value.IsUint64() {
    69  		return false
    70  	}
    71  
    72  	pc := value.Uint64()
    73  	if pc >= uint64(len(code)) {
    74  		return false
    75  	}
    76  
    77  	return OpCode(code[pc]) == JUMPDEST
    78  }
    79  
    80  func resolveCheck(sem *CfgAbsSem, code []byte, st0 *astate, pc0 int) (map[int]bool, map[int]bool, error) {
    81  	opcode := OpCode(code[pc0])
    82  	opsem := (*sem)[opcode]
    83  	succs := make(map[int]bool)
    84  	jumps := make(map[int]bool)
    85  
    86  	if opsem == nil {
    87  		return succs, jumps, nil
    88  	}
    89  
    90  	codeLen := len(code)
    91  
    92  	for _, stack := range st0.stackset {
    93  		if opcode == JUMP || opcode == JUMPI {
    94  			if stack.hasIndices(0) {
    95  				jumpDest := stack.values[0]
    96  				if jumpDest.kind == InvalidValue {
    97  					//program terminates, don't add edges
    98  				} else if jumpDest.kind == TopValue {
    99  					empty := make(map[int]bool)
   100  					return empty, empty, errors.New("unresolvable jumps found")
   101  				} else if jumpDest.kind == ConcreteValue {
   102  					if isJumpDest(code, jumpDest.value) {
   103  						pc1 := int(jumpDest.value.Uint64())
   104  						succs[pc1] = true
   105  						jumps[pc1] = true
   106  					}
   107  				}
   108  			}
   109  		}
   110  	}
   111  
   112  	//fall-thru edge
   113  	if opcode != JUMP {
   114  		if pc0 < codeLen-opsem.numBytes {
   115  			succs[pc0+opsem.numBytes] = true
   116  		}
   117  	}
   118  
   119  	return succs, jumps, nil
   120  }
   121  
   122  func postCheck(sem *CfgAbsSem, code []byte, st0 *astate, pc0 int, pc1 int, isJump bool) *astate {
   123  	st1 := emptyState()
   124  	op0 := OpCode(code[pc0])
   125  	opsem0 := (*sem)[op0]
   126  
   127  	for _, stack0 := range st0.stackset {
   128  		if isJump {
   129  			if !stack0.hasIndices(0) {
   130  				continue
   131  			}
   132  
   133  			elm0 := stack0.values[0]
   134  			if elm0.kind == ConcreteValue && elm0.value.IsUint64() && int(elm0.value.Uint64()) != pc1 {
   135  				continue
   136  			}
   137  		}
   138  
   139  		stack1 := stack0.Copy()
   140  
   141  		if opsem0.isPush {
   142  			pushValue := getPushValue(code, pc0, opsem0)
   143  			if isJumpDest(code, &pushValue) || isFF(&pushValue) {
   144  				stack1.Push(AbsValueConcrete(pushValue))
   145  			} else {
   146  				stack1.Push(AbsValueInvalid())
   147  			}
   148  		} else if opsem0.isDup {
   149  			if !stack0.hasIndices(opsem0.opNum - 1) {
   150  				continue
   151  			}
   152  
   153  			value := stack1.values[opsem0.opNum-1]
   154  			stack1.Push(value)
   155  		} else if opsem0.isSwap {
   156  			opNum := opsem0.opNum
   157  
   158  			if !stack0.hasIndices(0, opNum) {
   159  				continue
   160  			}
   161  
   162  			a := stack1.values[0]
   163  			b := stack1.values[opNum]
   164  			stack1.values[0] = b
   165  			stack1.values[opNum] = a
   166  
   167  		} else if op0 == AND {
   168  			if !stack0.hasIndices(0, 1) {
   169  				continue
   170  			}
   171  
   172  			a := stack1.Pop(pc0)
   173  			b := stack1.Pop(pc0)
   174  
   175  			if a.kind == ConcreteValue && b.kind == ConcreteValue {
   176  				v := uint256.NewInt(0)
   177  				v.And(a.value, b.value)
   178  				stack1.Push(AbsValueConcrete(*v))
   179  			} else {
   180  				stack1.Push(AbsValueTop(pc0))
   181  			}
   182  		} else if op0 == PC {
   183  			v := uint256.NewInt(0)
   184  			v.SetUint64(uint64(pc0))
   185  			stack1.Push(AbsValueConcrete(*v))
   186  		} else {
   187  			if !stack0.hasIndices(opsem0.numPop - 1) {
   188  				continue
   189  			}
   190  
   191  			for i := 0; i < opsem0.numPop; i++ {
   192  				stack1.Pop(pc0)
   193  			}
   194  
   195  			for i := 0; i < opsem0.numPush; i++ {
   196  				stack1.Push(AbsValueTop(pc0))
   197  			}
   198  		}
   199  
   200  		stack1.updateHash()
   201  		st1.Add(stack1)
   202  	}
   203  
   204  	return st1
   205  }
   206  
   207  func CheckCfg(code []byte, proof *CfgProof) bool {
   208  	sem := NewCfgAbsSem()
   209  
   210  	if !proof.isValid() {
   211  		return false
   212  	}
   213  
   214  	preLub := make(map[int][]*astate)
   215  	for _, block := range proof.Blocks {
   216  		st := intoAState(block.Entry.Stacks)
   217  		pc0 := block.Entry.Pc
   218  		blockSuccs := intMap(block.Succs)
   219  		for pc0 <= block.Exit.Pc {
   220  
   221  			if pc0 == block.Exit.Pc {
   222  				if !Eq(st, intoAState(block.Exit.Stacks)) {
   223  					return false
   224  				}
   225  			}
   226  
   227  			succs, isJump, err := resolveCheck(sem, code, st, pc0)
   228  			if err != nil {
   229  				return false
   230  			}
   231  
   232  			if pc0 == block.Exit.Pc {
   233  				if !reflect.DeepEqual(succs, blockSuccs) {
   234  					return false
   235  				}
   236  				for succEntryPc := range succs {
   237  					succEntrySt := postCheck(sem, code, st, pc0, succEntryPc, isJump[succEntryPc])
   238  					preLub[succEntryPc] = append(preLub[succEntryPc], succEntrySt)
   239  				}
   240  				break
   241  			} else {
   242  				if len(succs) != 1 {
   243  					return false
   244  				}
   245  
   246  				pc1 := one(succs)
   247  				if pc0 >= pc1 || pc1 > block.Exit.Pc {
   248  					return false
   249  				}
   250  
   251  				st = postCheck(sem, code, st, pc0, pc1, false)
   252  				pc0 = pc1
   253  			}
   254  		}
   255  	}
   256  
   257  	for _, block := range proof.Blocks {
   258  
   259  		var inferredEntry *astate
   260  		if block.Entry.Pc == 0 {
   261  			inferredEntry = botState()
   262  		} else {
   263  			lub := emptyState()
   264  			for _, preSt := range preLub[block.Entry.Pc] {
   265  				lub = Lub(lub, preSt)
   266  			}
   267  			inferredEntry = lub
   268  		}
   269  
   270  		if !Eq(inferredEntry, intoAState(block.Entry.Stacks)) {
   271  			return false
   272  		}
   273  	}
   274  
   275  	return true
   276  }
   277  
   278  func intMap(succs []int) map[int]bool {
   279  	res := make(map[int]bool)
   280  	for _, succ := range succs {
   281  		res[succ] = true
   282  	}
   283  	return res
   284  }
   285  
   286  func one(m map[int]bool) int {
   287  	for k := range m {
   288  		return k
   289  	}
   290  	log.Fatal("must have exactly one element")
   291  	return -1
   292  }