github.com/ethereum-optimism/optimism@v1.7.2/packages/contracts-bedrock/src/cannon/MIPS.sol (about)

     1  // SPDX-License-Identifier: MIT
     2  pragma solidity 0.8.15;
     3  
     4  import { ISemver } from "src/universal/ISemver.sol";
     5  import { IPreimageOracle } from "./interfaces/IPreimageOracle.sol";
     6  import { PreimageKeyLib } from "./PreimageKeyLib.sol";
     7  
     8  /// @title MIPS
     9  /// @notice The MIPS contract emulates a single MIPS instruction.
    10  ///         Note that delay slots are isolated instructions:
    11  ///         the nextPC in the state pre-schedules where the VM jumps next.
    12  ///         The Step input is a packed VM state, with binary-merkle-tree
    13  ///         witness data for memory reads/writes.
    14  ///         The Step outputs a keccak256 hash of the packed VM State,
    15  ///         and logs the resulting state for offchain usage.
    16  /// @dev https://inst.eecs.berkeley.edu/~cs61c/resources/MIPS_Green_Sheet.pdf
    17  /// @dev https://www.cs.cmu.edu/afs/cs/academic/class/15740-f97/public/doc/mips-isa.pdf
    18  ///      (page A-177)
    19  /// @dev https://uweb.engr.arizona.edu/~ece369/Resources/spim/MIPSReference.pdf
    20  /// @dev https://en.wikibooks.org/wiki/MIPS_Assembly/Instruction_Formats
    21  /// @dev https://github.com/golang/go/blob/master/src/syscall/zerrors_linux_mips.go
    22  ///      MIPS linux kernel errors used by Go runtime
    23  contract MIPS is ISemver {
    24      /// @notice Stores the VM state.
    25      ///         Total packed state size: 32 + 32 + 6 * 4 + 1 + 1 + 8 + 32 * 4 = 226 bytes
    26      ///         If nextPC != pc + 4, then the VM is executing a branch/jump delay slot.
            /// @dev The assembly in this contract addresses the struct at memory offset 0x80 and walks the
            ///      scalar fields in declaration order, one 32-byte memory word per field, so the order of
            ///      the fields below must not change.
    27      struct State {
    28          bytes32 memRoot; // merkle root of the VM memory; checked by readMem, replaced by writeMem
    29          bytes32 preimageKey; // pre-image key, built up by write syscalls on FD_PREIMAGE_WRITE
    30          uint32 preimageOffset; // read position in the pre-image; advanced by reads, reset by key writes
    31          uint32 pc; // program counter
    32          uint32 nextPC; // becomes `pc` once the current instruction has been handled
    33          uint32 lo; // LO register (quotient / low word of a product)
    34          uint32 hi; // HI register (remainder / high word of a product)
    35          uint32 heap; // address returned by the next mmap syscall that passes a0 == 0
    36          uint8 exitCode; // set from a0 by the exit group syscall
    37          bool exited; // set to true by the exit group syscall
    38          uint64 step; // step counter; NOTE(review): its update is not visible in this excerpt
    39          uint32[32] registers; // the 32 general-purpose registers
    40      }
    41  
    42      /// @notice Start of the data segment. This is the fixed value returned by the brk syscall.
    43      uint32 public constant BRK_START = 0x40000000;
    44  
    45      /// @notice The semantic version of the MIPS contract.
    46      /// @custom:semver 0.1.0
    47      string public constant version = "0.1.0";
    48  
            // File descriptors recognised by the read, write and fcntl syscall handlers.
            // Reads are accepted on STDIN, HINT_READ and PREIMAGE_READ;
            // writes are accepted on STDOUT, STDERR, HINT_WRITE and PREIMAGE_WRITE.
    49      uint32 internal constant FD_STDIN = 0;
    50      uint32 internal constant FD_STDOUT = 1;
    51      uint32 internal constant FD_STDERR = 2;
    52      uint32 internal constant FD_HINT_READ = 3;
    53      uint32 internal constant FD_HINT_WRITE = 4;
    54      uint32 internal constant FD_PREIMAGE_READ = 5;
    55      uint32 internal constant FD_PREIMAGE_WRITE = 6;
    56  
            // Error numbers handed back to the program in register 7 by the syscall handlers.
    57      uint32 internal constant EBADF = 0x9; // returned for a file descriptor the syscall does not accept
    58      uint32 internal constant EINVAL = 0x16; // returned for an fcntl command other than F_GETFL
    59  
    60      /// @notice The preimage oracle contract, queried by the read syscall on FD_PREIMAGE_READ.
    61      IPreimageOracle internal immutable ORACLE;
    62  
            /// @notice Binds this contract to a preimage oracle; the reference is immutable afterwards.
    63      /// @param _oracle The address of the preimage oracle contract.
    64      constructor(IPreimageOracle _oracle) {
    65          ORACLE = _oracle;
    66      }
    67  
    68      /// @notice Returns the pre-image oracle contract this VM reads pre-images from.
    69      /// @return oracle_ The IPreimageOracle contract set at construction.
    70      function oracle() external view returns (IPreimageOracle oracle_) {
    71          return ORACLE;
    72      }
    73  
    74      /// @notice Sign-extends the low `_idx` bits of `_dat` to a full 32-bit value.
    75      function SE(uint32 _dat, uint32 _idx) internal pure returns (uint32 out_) {
    76          unchecked {
    77              uint256 lowBits = (1 << _idx) - 1; // keeps bits 0 .. _idx-1
    78              uint256 highBits = ((1 << (32 - _idx)) - 1) << _idx; // bits _idx .. 31, set when negative
    79              bool negative = (_dat >> (_idx - 1)) != 0; // any set bit at or above the sign position counts
    80              out_ = uint32(negative ? (_dat & lowBits) | highBits : _dat & lowBits);
    81          }
    82      }
    83  
    84      /// @notice Computes the hash of the MIPS state.
            /// @dev Packs the state found at memory offset 0x80 into 226 contiguous bytes starting at the
            ///      free memory pointer (the pointer itself is not advanced), emits the packed bytes as a
            ///      topic-less log, and hashes them. The most significant byte of the hash is then replaced
            ///      by the VM status: 0 = valid, 1 = invalid, 2 = panic, 3 = unfinished.
    85      /// @return out_ The hashed MIPS state.
    86      function outputState() internal returns (bytes32 out_) {
    87          assembly {
    88              // copies 'size' bytes, right-aligned in word at 'from', to 'to', incl. trailing data
                    // (the trailing bytes spill past 'to + size' and are overwritten by the next copy)
    89              function copyMem(from, to, size) -> fromOut, toOut {
    90                  mstore(to, mload(add(from, sub(32, size))))
    91                  fromOut := add(from, 32)
    92                  toOut := add(to, size)
    93              }
    94  
    95              // From points to the MIPS State
    96              let from := 0x80
    97  
    98              // Copy to the free memory pointer
    99              let start := mload(0x40)
   100              let to := start
   101  
   102              // Copy state to free memory
   103              from, to := copyMem(from, to, 32) // memRoot
   104              from, to := copyMem(from, to, 32) // preimageKey
   105              from, to := copyMem(from, to, 4) // preimageOffset
   106              from, to := copyMem(from, to, 4) // pc
   107              from, to := copyMem(from, to, 4) // nextPC
   108              from, to := copyMem(from, to, 4) // lo
   109              from, to := copyMem(from, to, 4) // hi
   110              from, to := copyMem(from, to, 4) // heap
                    // Keep the full memory words of exitCode and exited for the status computation below
   111              let exitCode := mload(from)
   112              from, to := copyMem(from, to, 1) // exitCode
   113              let exited := mload(from)
   114              from, to := copyMem(from, to, 1) // exited
   115              from, to := copyMem(from, to, 8) // step
   116              from := add(from, 32) // offset to registers (skips the word that points at the array)
   117  
   118              // Copy registers
   119              for { let i := 0 } lt(i, 32) { i := add(i, 1) } { from, to := copyMem(from, to, 4) }
   120  
   121              // Clean up end of memory (zeroes the word after the packed state, left dirty by the last copy)
   122              mstore(to, 0)
   123  
   124              // Log the resulting MIPS state, for debugging
   125              log0(start, sub(to, start))
   126  
   127              // Determine the VM status
   128              let status := 0
   129              switch exited
   130              case 1 {
   131                  switch exitCode
   132                  // VMStatusValid
   133                  case 0 { status := 0 }
   134                  // VMStatusInvalid
   135                  case 1 { status := 1 }
   136                  // VMStatusPanic
   137                  default { status := 2 }
   138              }
   139              // VMStatusUnfinished
   140              default { status := 3 }
   141  
   142              // Compute the hash of the resulting MIPS state and set the status byte
   143              out_ := keccak256(start, sub(to, start))
   144              out_ := or(and(not(shl(248, 0xFF)), out_), shl(248, status))
   145          }
   146      }
   147  
   148      /// @notice Handles a syscall.
            /// @dev The syscall number is read from register 2 and the arguments a0..a2 from registers 4..6.
            ///      The result v0 is written back to register 2 and the error code v1 to register 7, after
            ///      which pc / nextPC advance by one instruction. The exit group syscall is the exception:
            ///      it returns right after setting `exited` / `exitCode`, without touching registers or PCs.
            ///      Syscall numbers that are not handled below complete with v0 = v1 = 0.
   149      /// @param _localContext The local key context for the preimage oracle.
   150      /// @return out_ The hashed MIPS state.
   151      function handleSyscall(bytes32 _localContext) internal returns (bytes32 out_) {
   152          unchecked {
   153              // Load state from memory
   154              State memory state;
   155              assembly {
   156                  state := 0x80
   157              }
   158  
   159              // Load the syscall number from the registers
   160              uint32 syscall_no = state.registers[2];
   161              uint32 v0 = 0;
   162              uint32 v1 = 0;
   163  
   164              // Load the syscall arguments from the registers
   165              uint32 a0 = state.registers[4];
   166              uint32 a1 = state.registers[5];
   167              uint32 a2 = state.registers[6];
   168  
   169              // mmap: Allocates a page from the heap.
   170              if (syscall_no == 4090) {
   171                  uint32 sz = a1;
   172                  if (sz & 4095 != 0) {
   173                      // adjust size to align with page size
   174                      sz += 4096 - (sz & 4095);
   175                  }
   176                  if (a0 == 0) {
                            // No address requested: hand out the current heap pointer and bump it
   177                      v0 = state.heap;
   178                      state.heap += sz;
   179                  } else {
                            // An address was requested: echo it back, the heap pointer is left unchanged
   180                      v0 = a0;
   181                  }
   182              }
   183              // brk: Returns a fixed address for the program break at 0x40000000
   184              else if (syscall_no == 4045) {
   185                  v0 = BRK_START;
   186              }
   187              // clone (not supported) returns 1
   188              else if (syscall_no == 4120) {
   189                  v0 = 1;
   190              }
   191              // exit group: Sets the Exited and ExitCode states to true and argument 0.
   192              else if (syscall_no == 4246) {
   193                  state.exited = true;
   194                  state.exitCode = uint8(a0);
   195                  return outputState();
   196              }
   197              // read: Like Linux read syscall. Splits unaligned reads into aligned reads.
   198              else if (syscall_no == 4003) {
   199                  // args: a0 = fd, a1 = addr, a2 = count
   200                  // returns: v0 = read, v1 = err code
   201                  if (a0 == FD_STDIN) {
   202                      // Leave v0 and v1 zero: read nothing, no error
   203                  }
   204                  // pre-image oracle read
   205                  else if (a0 == FD_PREIMAGE_READ) {
   206                      // verify proof 1 is correct, and get the existing memory.
   207                      uint32 mem = readMem(a1 & 0xFFffFFfc, 1); // mask the addr to align it to 4 bytes
   208                      bytes32 preimageKey = state.preimageKey;
   209                      // If the preimage key is a local key, localize it in the context of the caller.
   210                      if (uint8(preimageKey[0]) == 1) {
   211                          preimageKey = PreimageKeyLib.localize(preimageKey, _localContext);
   212                      }
   213                      (bytes32 dat, uint256 datLen) = ORACLE.readPreimage(preimageKey, state.preimageOffset);
   214  
   215                      // Transform data for writing to memory
   216                      // We use assembly for more precise ops, and no var count limit
                            // The read never crosses the 4-byte word that contains a1
   217                      assembly {
   218                          let alignment := and(a1, 3) // the read might not start at an aligned address
   219                          let space := sub(4, alignment) // remaining space in memory word
   220                          if lt(space, datLen) { datLen := space } // if less space than data, shorten data
   221                          if lt(a2, datLen) { datLen := a2 } // if requested to read less, read less
   222                          dat := shr(sub(256, mul(datLen, 8)), dat) // right-align data
   223                          dat := shl(mul(sub(sub(4, datLen), alignment), 8), dat) // position data to insert into memory
   224                              // word
   225                          let mask := sub(shl(mul(sub(4, alignment), 8), 1), 1) // mask all bytes after start
   226                          let suffixMask := sub(shl(mul(sub(sub(4, alignment), datLen), 8), 1), 1) // mask of all bytes
   227                              // starting from end, maybe none
   228                          mask := and(mask, not(suffixMask)) // reduce mask to just cover the data we insert
   229                          mem := or(and(mem, not(mask)), dat) // clear masked part of original memory, and insert data
   230                      }
   231  
   232                      // Write memory back
   233                      writeMem(a1 & 0xFFffFFfc, 1, mem);
   234                      state.preimageOffset += uint32(datLen);
   235                      v0 = uint32(datLen);
   236                  }
   237                  // hint response
   238                  else if (a0 == FD_HINT_READ) {
   239                      // Don't read into memory, just say we read it all
   240                      // The result is ignored anyway
   241                      v0 = a2;
   242                  } else {
   243                      v0 = 0xFFffFFff;
   244                      v1 = EBADF;
   245                  }
   246              }
   247              // write: like Linux write syscall. Splits unaligned writes into aligned writes.
   248              else if (syscall_no == 4004) {
   249                  // args: a0 = fd, a1 = addr, a2 = count
   250                  // returns: v0 = written, v1 = err code
   251                  if (a0 == FD_STDOUT || a0 == FD_STDERR || a0 == FD_HINT_WRITE) {
   252                      v0 = a2; // tell program we have written everything
   253                  }
   254                  // pre-image oracle
   255                  else if (a0 == FD_PREIMAGE_WRITE) {
   256                      uint32 mem = readMem(a1 & 0xFFffFFfc, 1); // mask the addr to align it to 4 bytes
   257                      bytes32 key = state.preimageKey;
   258  
   259                      // Construct pre-image key from memory
   260                      // We use assembly for more precise ops, and no var count limit
   261                      assembly {
   262                          let alignment := and(a1, 3) // the write might not start at an aligned address
   263                          let space := sub(4, alignment) // remaining space in memory word
   264                          if lt(space, a2) { a2 := space } // if less space than data, shorten data
   265                          key := shl(mul(a2, 8), key) // shift key, make space for new info
   266                          let mask := sub(shl(mul(a2, 8), 1), 1) // mask for extracting value from memory
   267                          mem := and(shr(mul(sub(space, a2), 8), mem), mask) // align value to right, mask it
   268                          key := or(key, mem) // insert into key
   269                      }
   270  
   271                      // Write pre-image key to oracle
   272                      state.preimageKey = key;
   273                      state.preimageOffset = 0; // reset offset, to read new pre-image data from the start
   274                      v0 = a2;
   275                  } else {
   276                      v0 = 0xFFffFFff;
   277                      v1 = EBADF;
   278                  }
   279              }
   280              // fcntl: Like linux fcntl syscall, but only supports minimal file-descriptor control commands,
   281              // to retrieve the file-descriptor R/W flags.
   282              else if (syscall_no == 4055) {
   283                  // fcntl
   284                  // args: a0 = fd, a1 = cmd
   285                  if (a1 == 3) {
   286                      // F_GETFL: get file descriptor flags
   287                      if (a0 == FD_STDIN || a0 == FD_PREIMAGE_READ || a0 == FD_HINT_READ) {
   288                          v0 = 0; // O_RDONLY
   289                      } else if (a0 == FD_STDOUT || a0 == FD_STDERR || a0 == FD_PREIMAGE_WRITE || a0 == FD_HINT_WRITE) {
   290                          v0 = 1; // O_WRONLY
   291                      } else {
   292                          v0 = 0xFFffFFff;
   293                          v1 = EBADF;
   294                      }
   295                  } else {
   296                      v0 = 0xFFffFFff;
   297                      v1 = EINVAL; // cmd not recognized by this kernel
   298                  }
   299              }
   300  
                    // Any syscall number not matched above reaches this point with v0 = v1 = 0
   301              // Write the results back to the state registers
   302              state.registers[2] = v0;
   303              state.registers[7] = v1;
   304  
   305              // Update the PC and nextPC
   306              state.pc = state.nextPC;
   307              state.nextPC = state.nextPC + 4;
   308  
   309              out_ = outputState();
   310          }
   311      }
   312  
   313      /// @notice Handles a branch instruction: decides whether it is taken and schedules pc / nextPC.
   314      ///         Reverts if the branch itself sits in a delay slot (nextPC != pc + 4).
   315      /// @param _opcode The opcode of the branch instruction (1 = regimm, 4 = beq, 5 = bne, 6 = blez, 7 = bgtz).
   316      /// @param _insn The instruction word; supplies the regimm selector and the 16-bit branch offset.
   317      /// @param _rtReg The index of the RT register, which is only read for beq/bne.
   318      /// @param _rs The value compared against RT (beq/bne) or against zero (other branches).
   319      /// @return out_ The hashed MIPS state.
   320      function handleBranch(uint32 _opcode, uint32 _insn, uint32 _rtReg, uint32 _rs) internal returns (bytes32 out_) {
   321          unchecked {
   322              // The VM state sits at a fixed memory offset.
   323              State memory state;
   324              assembly {
   325                  state := 0x80
   326              }
   327  
   328              // A branch cannot itself be executed from a delay slot.
   329              if (state.nextPC != state.pc + 4) {
   330                  revert("branch in delay slot");
   331              }
   332  
   333              // Stays false for any opcode / selector that is not matched below.
   334              bool taken;
   335              int32 signedRs = int32(_rs);
   336  
   337              if (_opcode == 1) {
   338                  // regimm: bits 16..20 of the instruction select bltz (0) or bgez (1)
   339                  uint32 selector = (_insn >> 16) & 0x1F;
   340                  if (selector == 0) {
   341                      taken = signedRs < 0;
   342                  } else if (selector == 1) {
   343                      taken = signedRs >= 0;
   344                  }
   345              } else if (_opcode == 4 || _opcode == 5) {
   346                  // beq (4) is taken when rs == rt, bne (5) when rs != rt
   347                  bool equal = _rs == state.registers[_rtReg];
   348                  taken = (_opcode == 4) == equal;
   349              } else if (_opcode == 6) {
   350                  // blez: taken when rs, read as a signed value, is <= 0
   351                  taken = signedRs <= 0;
   352              } else if (_opcode == 7) {
   353                  // bgtz: taken when rs, read as a signed value, is > 0
   354                  taken = signedRs > 0;
   355              }
   356  
   357              // Remember where the branch and its delay slot are.
   358              uint32 branchPC = state.pc;
   359              uint32 delaySlotPC = state.nextPC;
   360  
   361              // The delay slot always executes next.
   362              state.pc = delaySlotPC;
   363  
   364              // After the delay slot, continue at the branch target if taken, otherwise fall through.
   365              // The target is the delay slot address plus the sign-extended 16-bit offset times 4.
   366              if (taken) {
   367                  state.nextPC = branchPC + 4 + (SE(_insn & 0xFFFF, 16) << 2);
   368              } else {
   369                  state.nextPC = delaySlotPC + 4;
   370              }
   371  
   372              // Hash (and log) the updated state.
   373              out_ = outputState();
   374          }
   375      }
   376  
   377      /// @notice Handles the instructions that read or write the HI and LO registers.
   378      ///         Covers mfhi, mthi, mflo, mtlo, mult, multu, div and divu; any other function code
   379      ///         leaves HI and LO untouched.
   380      /// @param _func The function code of the instruction.
   381      /// @param _rs The value of the RS register.
   382      /// @param _rt The value of the RT register.
   383      /// @param _storeReg The index of the destination register; 0 means no register is written.
   384      /// @return out_ The hashed MIPS state.
   385      function handleHiLo(uint32 _func, uint32 _rs, uint32 _rt, uint32 _storeReg) internal returns (bytes32 out_) {
   386          unchecked {
   387              // The VM state sits at a fixed memory offset.
   388              State memory state;
   389              assembly {
   390                  state := 0x80
   391              }
   392  
   393              // Only mfhi / mflo produce a register result; for every other function code it stays zero,
   394              // so a non-zero `_storeReg` receives zero in those cases.
   395              uint32 result;
   396  
   397              if (_func == 0x10) {
   398                  // mfhi: read the HI register
   399                  result = state.hi;
   400              } else if (_func == 0x12) {
   401                  // mflo: read the LO register
   402                  result = state.lo;
   403              } else if (_func == 0x11) {
   404                  // mthi: load the HI register from rs
   405                  state.hi = _rs;
   406              } else if (_func == 0x13) {
   407                  // mtlo: load the LO register from rs
   408                  state.lo = _rs;
   409              } else if (_func == 0x18) {
   410                  // mult: signed 32 x 32 -> 64 bit product; high word to HI, low word to LO.
   411                  // Two sign-extended 32-bit operands cannot overflow the 64-bit product.
   412                  int64 signedProduct = int64(int32(_rs)) * int64(int32(_rt));
   413                  uint64 productBits = uint64(signedProduct);
   414                  state.hi = uint32(productBits >> 32);
   415                  state.lo = uint32(productBits);
   416              } else if (_func == 0x19) {
   417                  // multu: unsigned 32 x 32 -> 64 bit product; high word to HI, low word to LO
   418                  uint64 unsignedProduct = uint64(_rs) * uint64(_rt);
   419                  state.hi = uint32(unsignedProduct >> 32);
   420                  state.lo = uint32(unsignedProduct);
   421              } else if (_func == 0x1a) {
   422                  // div: signed division; remainder to HI, quotient to LO.
   423                  // A zero divisor still reverts here: `unchecked` does not disable that check.
   424                  int32 dividend = int32(_rs);
   425                  int32 divisor = int32(_rt);
   426                  state.hi = uint32(dividend % divisor);
   427                  state.lo = uint32(dividend / divisor);
   428              } else if (_func == 0x1b) {
   429                  // divu: unsigned division; remainder to HI, quotient to LO.
   430                  // A zero divisor still reverts here: `unchecked` does not disable that check.
   431                  state.hi = _rs % _rt;
   432                  state.lo = _rs / _rt;
   433              }
   434  
   435              // Write the result to the destination register, unless that is register 0.
   436              if (_storeReg != 0) {
   437                  state.registers[_storeReg] = result;
   438              }
   439  
   440              // Move on: the scheduled instruction becomes current, the one after it is scheduled.
   441              uint32 scheduledPC = state.nextPC;
   442              state.pc = scheduledPC;
   443              state.nextPC = scheduledPC + 4;
   444  
   445              // Hash (and log) the updated state.
   446              out_ = outputState();
   447          }
   448      }
   449  
   450      /// @notice Handles a jump instruction: schedules the destination to run after the delay slot.
   451      /// @param _linkReg The register that receives the address of the instruction after the delay slot; 0 for none.
   452      /// @param _dest The destination to jump to.
   453      /// @return out_ The hashed MIPS state.
   454      function handleJump(uint32 _linkReg, uint32 _dest) internal returns (bytes32 out_) {
   455          unchecked {
   456              // The VM state sits at a fixed memory offset.
   457              State memory state;
   458              assembly {
   459                  state := 0x80
   460              }
   461  
   462              // A jump cannot itself be executed from a delay slot.
   463              uint32 jumpPC = state.pc;
   464              uint32 delaySlotPC = state.nextPC;
   465              if (delaySlotPC != jumpPC + 4) {
   466                  revert("jump in delay slot");
   467              }
   468  
   469              // Link: store the address of the instruction that follows the delay slot.
   470              if (_linkReg != 0) {
   471                  state.registers[_linkReg] = jumpPC + 8;
   472              }
   473  
   474              // The delay slot executes next; after it, execution continues at the destination.
   475              state.pc = delaySlotPC;
   476              state.nextPC = _dest;
   477              out_ = outputState();
   478          }
   479      }
   480  
   481      /// @notice Stores a value into a register (when applicable) and advances the PC.
   482      /// @param _storeReg The index of the register to store the value into; must be below 32.
   483      /// @param _val The value to store.
   484      /// @param _conditional Whether the store takes place; false skips it (used for movz, movn).
   485      /// @return out_ The hashed MIPS state.
   486      function handleRd(uint32 _storeReg, uint32 _val, bool _conditional) internal returns (bytes32 out_) {
   487          unchecked {
   488              // Reject register indices outside the register file.
   489              require(_storeReg < 32, "valid register");
   490  
   491              // The VM state sits at a fixed memory offset.
   492              State memory state;
   493              assembly {
   494                  state := 0x80
   495              }
   496  
   497              // Register 0 is never written; other registers only when the condition holds.
   498              if (_conditional && _storeReg != 0) {
   499                  state.registers[_storeReg] = _val;
   500              }
   501  
   502              // Move on: the scheduled instruction becomes current, the one after it is scheduled.
   503              uint32 scheduledPC = state.nextPC;
   504              state.pc = scheduledPC;
   505              state.nextPC = scheduledPC + 4;
   506  
   507              out_ = outputState();
   508          }
   509      }
   510  
   511      /// @notice Computes the offset of a memory proof in the calldata and checks that the proof fits.
   512      /// @param _proofIndex The index of the proof in the calldata.
   513      /// @return offset_ The offset of the proof in the calldata.
   514      function proofOffset(uint8 _proofIndex) internal pure returns (uint256 offset_) {
   515          unchecked {
   516              // Each proof is 28 words: the 32-byte leaf value followed by (32-5)=27 sibling hashes,
   517              // one per tree level of a 32-bit address space with 32-byte leaves. proof.offset == 420.
   518              uint256 proofSize = 28 * 32;
   519              offset_ = 420 + uint256(_proofIndex) * proofSize;
   520              uint256 available;
   521              assembly {
   522                  available := calldatasize()
   523              }
   524              require(offset_ + proofSize <= available, "check that there is enough calldata");
   525          }
   526      }
   527  
   528      /// @notice Reads a 32-bit value from memory.
            /// @dev Rebuilds the merkle root from the 32-byte leaf and the 27 sibling hashes of the selected
            ///      calldata proof, and compares it with the memory root held in the first word of the state
            ///      (memory offset 0x80). Reverts with empty data if `_addr` is not 4-byte aligned, and with
            ///      the word 0x0badf00d if the rebuilt root does not match. Uses the scratch space at
            ///      memory 0..64 for hashing.
   529      /// @param _addr The address to read from.
   530      /// @param _proofIndex The index of the proof in the calldata.
   531      /// @return out_ The 32-bit value at `_addr`, taken big-endian from the verified leaf.
   532      function readMem(uint32 _addr, uint8 _proofIndex) internal pure returns (uint32 out_) {
   533          unchecked {
   534              // Compute the offset of the proof in the calldata.
   535              uint256 offset = proofOffset(_proofIndex);
   536  
   537              assembly {
   538                  // Validate the address alignment.
   539                  if and(_addr, 3) { revert(0, 0) }
   540  
   541                  // Load the leaf value.
   542                  let leaf := calldataload(offset)
   543                  offset := add(offset, 32)
   544  
   545                  // Convenience function to hash two nodes together in scratch space.
   546                  function hashPair(a, b) -> h {
   547                      mstore(0, a)
   548                      mstore(32, b)
   549                      h := keccak256(0, 64)
   550                  }
   551  
   552                  // Start with the leaf node.
   553                  // Work back up by combining with siblings, to reconstruct the root.
                        // Bit i of the path tells whether the node is the right (1) or left (0) child at level i.
   554                  let path := shr(5, _addr)
   555                  let node := leaf
   556                  for { let i := 0 } lt(i, 27) { i := add(i, 1) } {
   557                      let sibling := calldataload(offset)
   558                      offset := add(offset, 32)
   559                      switch and(shr(i, path), 1)
   560                      case 0 { node := hashPair(node, sibling) }
   561                      case 1 { node := hashPair(sibling, node) }
   562                  }
   563  
   564                  // Load the memory root from the first field of state.
   565                  let memRoot := mload(0x80)
   566  
   567                  // Verify the root matches.
   568                  if iszero(eq(node, memRoot)) {
   569                      mstore(0, 0x0badf00d)
   570                      revert(0, 32)
   571                  }
   572  
   573                  // Bits to shift = (32 - 4 - (addr % 32)) * 8
   574                  let shamt := shl(3, sub(sub(32, 4), and(_addr, 31)))
   575                  out_ := and(shr(shamt, leaf), 0xFFffFFff)
   576              }
   577          }
   578      }
   579  
   580      /// @notice Writes a 32-bit value to memory.
   581      ///         This function first overwrites the part of the leaf.
   582      ///         Then it recomputes the memory merkle root.
            /// @dev The proof is not checked against the current memory root here; the new root is derived
            ///      from the calldata leaf and siblings as given. In the caller visible in this file
            ///      (handleSyscall), a readMem with the same address and proof index runs first.
            ///      Reverts with empty data if `_addr` is not 4-byte aligned. Uses the scratch space at
            ///      memory 0..64 for hashing.
   583      /// @param _addr The address to write to.
   584      /// @param _proofIndex The index of the proof in the calldata.
   585      /// @param _val The value to write.
   586      function writeMem(uint32 _addr, uint8 _proofIndex, uint32 _val) internal pure {
   587          unchecked {
   588              // Compute the offset of the proof in the calldata.
   589              uint256 offset = proofOffset(_proofIndex);
   590  
   591              assembly {
   592                  // Validate the address alignment.
   593                  if and(_addr, 3) { revert(0, 0) }
   594  
   595                  // Load the leaf value.
   596                  let leaf := calldataload(offset)
                        // Bits to shift = (32 - 4 - (addr % 32)) * 8
   597                  let shamt := shl(3, sub(sub(32, 4), and(_addr, 31)))
   598  
   599                  // Mask out 4 bytes, and OR in the value
   600                  leaf := or(and(leaf, not(shl(shamt, 0xFFffFFff))), shl(shamt, _val))
   601                  offset := add(offset, 32)
   602  
   603                  // Convenience function to hash two nodes together in scratch space.
   604                  function hashPair(a, b) -> h {
   605                      mstore(0, a)
   606                      mstore(32, b)
   607                      h := keccak256(0, 64)
   608                  }
   609  
   610                  // Start with the leaf node.
   611                  // Work back up by combining with siblings, to reconstruct the root.
                        // Bit i of the path tells whether the node is the right (1) or left (0) child at level i.
   612                  let path := shr(5, _addr)
   613                  let node := leaf
   614                  for { let i := 0 } lt(i, 27) { i := add(i, 1) } {
   615                      let sibling := calldataload(offset)
   616                      offset := add(offset, 32)
   617                      switch and(shr(i, path), 1)
   618                      case 0 { node := hashPair(node, sibling) }
   619                      case 1 { node := hashPair(sibling, node) }
   620                  }
   621  
   622                  // Store the new memory root in the first field of state.
   623                  mstore(0x80, node)
   624              }
   625          }
   626      }
   627  
   628      /// @notice Executes a single step of the vm.
   629      ///         Will revert if any required input state is missing.
            ///         The packed state is unpacked from calldata into memory, the instruction at `state.pc` is
            ///         fetched and decoded, and control is handed to the handler matching the instruction class.
   630      /// @param _stateData The encoded state witness data.
   631      /// @param _proof The encoded proof data for leaves within the MIPS VM's memory.
   632      /// @param _localContext The local key context for the preimage oracle. Optional, can be set as a constant
   633      ///                      if the caller only requires one set of local keys.
            /// @return The bytes32 value produced by `outputState()` or by the handler that finishes the instruction
            ///         (`handleJump`, `handleBranch`, `handleRd`, `handleSyscall` or `handleHiLo`).
   634      function step(bytes calldata _stateData, bytes calldata _proof, bytes32 _localContext) public returns (bytes32) {
   635          unchecked {
   636              State memory state;
   637  
   638              // Packed calldata is ~6 times smaller than state size
   639              assembly {
   640                  if iszero(eq(state, 0x80)) {
   641                      // expected state mem offset check: the fields are unpacked to 0x80 onwards below
   642                      revert(0, 0)
   643                  }
   644                  if iszero(eq(mload(0x40), shl(5, 48))) {
   645                      // expected memory check: the free memory pointer must sit at 48 words, i.e.
                            // 4 words below 0x80 + 12 state slots + 32 register slots
   646                      revert(0, 0)
   647                  }
   648                  if iszero(eq(_stateData.offset, 132)) {
   649                      // 32*4+4=132 expected state data offset
   650                      revert(0, 0)
   651                  }
   652                  if iszero(eq(_proof.offset, 420)) {
   653                      // 132+32+256=420 expected proof offset
   654                      revert(0, 0)
   655                  }
   656  
                        // Copies `size` bytes of calldata at `callOffset` into the 32-byte memory word at `memOffset`
                        // (right-aligned) and returns both cursors advanced: calldata by `size`, memory by 32.
   657                  function putField(callOffset, memOffset, size) -> callOffsetOut, memOffsetOut {
   658                      // calldata is packed, thus starting left-aligned, shift-right to pad and right-align
   659                      let w := shr(shl(3, sub(32, size)), calldataload(callOffset))
   660                      mstore(memOffset, w)
   661                      callOffsetOut := add(callOffset, size)
   662                      memOffsetOut := add(memOffset, 32)
   663                  }
   664  
   665                  // Unpack state from calldata into memory
   666                  let c := _stateData.offset // calldata offset
   667                  let m := 0x80 // mem offset
   668                  c, m := putField(c, m, 32) // memRoot
   669                  c, m := putField(c, m, 32) // preimageKey
   670                  c, m := putField(c, m, 4) // preimageOffset
   671                  c, m := putField(c, m, 4) // pc
   672                  c, m := putField(c, m, 4) // nextPC
   673                  c, m := putField(c, m, 4) // lo
   674                  c, m := putField(c, m, 4) // hi
   675                  c, m := putField(c, m, 4) // heap
   676                  c, m := putField(c, m, 1) // exitCode
   677                  c, m := putField(c, m, 1) // exited
   678                  c, m := putField(c, m, 8) // step
   679  
   680                  // Unpack register calldata into memory
   681                  mstore(m, add(m, 32)) // offset to registers: points at the word right after this slot
   682                  m := add(m, 32)
   683                  for { let i := 0 } lt(i, 32) { i := add(i, 1) } { c, m := putField(c, m, 4) }
   684              }
   685  
   686              // Don't change state once exited
   687              if (state.exited) {
   688                  return outputState();
   689              }
   690  
   691              state.step += 1;
   692  
   693              // instruction fetch
   694              uint32 insn = readMem(state.pc, 0);
   695              uint32 opcode = insn >> 26; // 6-bits
   696  
   697              // j-type j/jal
   698              if (opcode == 2 || opcode == 3) {
   699                  // Take top 4 bits of the next PC (its 256 MB region), and concatenate with the 26-bit offset
   700                  uint32 target = (state.nextPC & 0xF0000000) | (insn & 0x03FFFFFF) << 2;
                        // opcode 2 passes register index 0, opcode 3 passes register index 31
   701                  return handleJump(opcode == 2 ? 0 : 31, target);
   702              }
   703  
   704              // register fetch
   705              uint32 rs; // source register 1 value
   706              uint32 rt; // source register 2 / temp value
   707              uint32 rtReg = (insn >> 16) & 0x1F;
   708  
   709              // R-type or I-type (stores rt)
   710              rs = state.registers[(insn >> 21) & 0x1F];
   711              uint32 rdReg = rtReg;
   712  
   713              if (opcode == 0 || opcode == 0x1c) {
   714                  // R-type (stores rd)
   715                  rt = state.registers[rtReg];
   716                  rdReg = (insn >> 11) & 0x1F;
   717              } else if (opcode < 0x20) {
   718                  // rt is SignExtImm
   719                  // don't sign extend for andi, ori, xori
   720                  if (opcode == 0xC || opcode == 0xD || opcode == 0xe) {
   721                      // ZeroExtImm
   722                      rt = insn & 0xFFFF;
   723                  } else {
   724                      // SignExtImm
   725                      rt = SE(insn & 0xFFFF, 16);
   726                  }
   727              } else if (opcode >= 0x28 || opcode == 0x22 || opcode == 0x26) {
   728                  // store rt value with store
   729                  rt = state.registers[rtReg];
   730  
   731                  // store actual rt with lwl and lwr
   732                  rdReg = rtReg;
   733              }
   734  
                    // opcodes 4..7 and opcode 1 are routed to the branch handler
   735              if ((opcode >= 4 && opcode < 8) || opcode == 1) {
   736                  return handleBranch(opcode, insn, rtReg, rs);
   737              }
   738  
                    // 0xFF_FF_FF_FF is the sentinel for "no memory write pending" (checked before writeMem below)
   739              uint32 storeAddr = 0xFF_FF_FF_FF;
   740              // memory fetch (all I-type)
   741              // we do the load for stores also
   742              uint32 mem;
   743              if (opcode >= 0x20) {
   744                  // M[R[rs]+SignExtImm]
   745                  rs += SE(insn & 0xFFFF, 16);
                        // the memory access itself is word-aligned; rs keeps the unaligned effective address
   746                  uint32 addr = rs & 0xFFFFFFFC;
   747                  mem = readMem(addr, 1);
   748                  if (opcode >= 0x28 && opcode != 0x30) {
   749                      // store
   750                      storeAddr = addr;
   751                      // store opcodes don't write back to a register
   752                      rdReg = 0;
   753                  }
   754              }
   755  
   756              // ALU
   757              uint32 val = execute(insn, rs, rt, mem) & 0xffFFffFF; // swr outputs more than 4 bytes without the mask
   758  
   759              uint32 func = insn & 0x3f; // 6-bits
   760              if (opcode == 0 && func >= 8 && func < 0x1c) {
   761                  if (func == 8 || func == 9) {
   762                      // jr/jalr
   763                      return handleJump(func == 8 ? 0 : rdReg, rs);
   764                  }
   765  
   766                  if (func == 0xa) {
   767                      // movz
   768                      return handleRd(rdReg, rs, rt == 0);
   769                  }
   770                  if (func == 0xb) {
   771                      // movn
   772                      return handleRd(rdReg, rs, rt != 0);
   773                  }
   774  
   775                  // syscall (can read and write)
   776                  if (func == 0xC) {
   777                      return handleSyscall(_localContext);
   778                  }
   779  
   780                  // lo and hi registers
   781                  // can write back
   782                  if (func >= 0x10 && func < 0x1c) {
   783                      return handleHiLo(func, rs, rt, rdReg);
   784                  }
   785              }
   786  
   787              // sc (opcode 0x38): write a 1 to rt, unless rt is register 0
   788              if (opcode == 0x38 && rtReg != 0) {
   789                  state.registers[rtReg] = 1;
   790              }
   791  
   792              // write memory
   793              if (storeAddr != 0xFF_FF_FF_FF) {
   794                  writeMem(storeAddr, 1, val);
   795              }
   796  
   797              // write back the value to destination register
   798              return handleRd(rdReg, val, true);
   799          }
   800      }
   801  
   802      /// @notice Execute an instruction.
            ///         Computes the 32-bit result of `insn` from already-fetched operands; it does not touch state.
            ///         I-type arithmetic/logic opcodes (0x8..0xE) are mapped onto their SPECIAL funct equivalents.
            ///         Reverts with "invalid instruction" for encodings that are not handled.
            /// @param insn The 32-bit instruction word.
            /// @param rs First operand value. For opcodes >= 0x20 the caller (`step`) passes the effective
            ///           (unaligned) memory address here.
            /// @param rt Second operand value: a register value or the immediate, as prepared by the caller.
            /// @param mem The memory word read by the caller for load/store opcodes.
            /// @return out The computed value: the register result, or for stores the full word to write back.
   803      function execute(uint32 insn, uint32 rs, uint32 rt, uint32 mem) internal pure returns (uint32 out) {
   804          unchecked {
   805              uint32 opcode = insn >> 26; // 6-bits
   806  
   807              if (opcode == 0 || (opcode >= 8 && opcode < 0xF)) {
   808                  uint32 func = insn & 0x3f; // 6-bits
   809                  assembly {
   810                      // transform ArithLogI to SPECIAL
   811                      switch opcode
   812                      // addi
   813                      case 0x8 { func := 0x20 }
   814                      // addiu
   815                      case 0x9 { func := 0x21 }
   816                      // slti
   817                      case 0xA { func := 0x2A }
   818                      // sltiu
   819                      case 0xB { func := 0x2B }
   820                      // andi
   821                      case 0xC { func := 0x24 }
   822                      // ori
   823                      case 0xD { func := 0x25 }
   824                      // xori
   825                      case 0xE { func := 0x26 }
   826                  }
   827  
   828                  // sll
   829                  if (func == 0x00) {
   830                      return rt << ((insn >> 6) & 0x1F);
   831                  }
   832                  // srl
   833                  else if (func == 0x02) {
   834                      return rt >> ((insn >> 6) & 0x1F);
   835                  }
   836                  // sra
   837                  else if (func == 0x03) {
   838                      uint32 shamt = (insn >> 6) & 0x1F;
   839                      return SE(rt >> shamt, 32 - shamt);
   840                  }
   841                  // sllv
   842                  else if (func == 0x04) {
   843                      return rt << (rs & 0x1F);
   844                  }
   845                  // srlv
   846                  else if (func == 0x6) {
   847                      return rt >> (rs & 0x1F);
   848                  }
   849                  // srav
   850                  else if (func == 0x07) {
                            // the shift amount is only the low 5 bits of rs (as for sllv/srlv); using the whole
                            // register would shift everything out and underflow `32 - rs` whenever rs >= 32
   851                      return SE(rt >> (rs & 0x1F), 32 - (rs & 0x1F));
   852                  }
   853                  // functs in range [0x8, 0x1b] are handled specially by other functions
   854                  // Explicitly enumerate each funct in range to reduce code diff against Go Vm
   855                  // jr
   856                  else if (func == 0x08) {
   857                      return rs;
   858                  }
   859                  // jalr
   860                  else if (func == 0x09) {
   861                      return rs;
   862                  }
   863                  // movz
   864                  else if (func == 0x0a) {
   865                      return rs;
   866                  }
   867                  // movn
   868                  else if (func == 0x0b) {
   869                      return rs;
   870                  }
   871                  // syscall
   872                  else if (func == 0x0c) {
   873                      return rs;
   874                  }
   875                  // 0x0d - break not supported
   876                  // sync
   877                  else if (func == 0x0f) {
   878                      return rs;
   879                  }
   880                  // mfhi
   881                  else if (func == 0x10) {
   882                      return rs;
   883                  }
   884                  // mthi
   885                  else if (func == 0x11) {
   886                      return rs;
   887                  }
   888                  // mflo
   889                  else if (func == 0x12) {
   890                      return rs;
   891                  }
   892                  // mtlo
   893                  else if (func == 0x13) {
   894                      return rs;
   895                  }
   896                  // mult
   897                  else if (func == 0x18) {
   898                      return rs;
   899                  }
   900                  // multu
   901                  else if (func == 0x19) {
   902                      return rs;
   903                  }
   904                  // div
   905                  else if (func == 0x1a) {
   906                      return rs;
   907                  }
   908                  // divu
   909                  else if (func == 0x1b) {
   910                      return rs;
   911                  }
   912                  // The rest includes transformed R-type arith imm instructions
   913                  // add
   914                  else if (func == 0x20) {
   915                      return (rs + rt);
   916                  }
   917                  // addu
   918                  else if (func == 0x21) {
   919                      return (rs + rt);
   920                  }
   921                  // sub
   922                  else if (func == 0x22) {
   923                      return (rs - rt);
   924                  }
   925                  // subu
   926                  else if (func == 0x23) {
   927                      return (rs - rt);
   928                  }
   929                  // and
   930                  else if (func == 0x24) {
   931                      return (rs & rt);
   932                  }
   933                  // or
   934                  else if (func == 0x25) {
   935                      return (rs | rt);
   936                  }
   937                  // xor
   938                  else if (func == 0x26) {
   939                      return (rs ^ rt);
   940                  }
   941                  // nor
   942                  else if (func == 0x27) {
   943                      return ~(rs | rt);
   944                  }
   945                  // slt (also slti, via the opcode transform above)
   946                  else if (func == 0x2a) {
   947                      return int32(rs) < int32(rt) ? 1 : 0;
   948                  }
   949                  // sltu (also sltiu, via the opcode transform above)
   950                  else if (func == 0x2b) {
   951                      return rs < rt ? 1 : 0;
   952                  } else {
   953                      revert("invalid instruction");
   954                  }
   955              } else {
   956                  // SPECIAL2
   957                  if (opcode == 0x1C) {
   958                      uint32 func = insn & 0x3f; // 6-bits
   959                      // mul
   960                      if (func == 0x2) {
   961                          return uint32(int32(rs) * int32(rt));
   962                      }
   963                      // clz, clo (clz is computed as the leading-ones count of ~rs)
   964                      else if (func == 0x20 || func == 0x21) {
   965                          if (func == 0x20) {
   966                              rs = ~rs;
   967                          }
   968                          uint32 i = 0;
   969                          while (rs & 0x80000000 != 0) {
   970                              i++;
   971                              rs <<= 1;
   972                          }
   973                          return i;
   974                      }
   975                  }
   976                  // lui
   977                  else if (opcode == 0x0F) {
   978                      return rt << 16;
   979                  }
   980                  // lb
   981                  else if (opcode == 0x20) {
   982                      return SE((mem >> (24 - (rs & 3) * 8)) & 0xFF, 8);
   983                  }
   984                  // lh
   985                  else if (opcode == 0x21) {
   986                      return SE((mem >> (16 - (rs & 2) * 8)) & 0xFFFF, 16);
   987                  }
   988                  // lwl
   989                  else if (opcode == 0x22) {
   990                      uint32 val = mem << ((rs & 3) * 8);
   991                      uint32 mask = uint32(0xFFFFFFFF) << ((rs & 3) * 8);
   992                      return (rt & ~mask) | val;
   993                  }
   994                  // lw
   995                  else if (opcode == 0x23) {
   996                      return mem;
   997                  }
   998                  // lbu
   999                  else if (opcode == 0x24) {
  1000                      return (mem >> (24 - (rs & 3) * 8)) & 0xFF;
  1001                  }
  1002                  //  lhu
  1003                  else if (opcode == 0x25) {
  1004                      return (mem >> (16 - (rs & 2) * 8)) & 0xFFFF;
  1005                  }
  1006                  //  lwr
  1007                  else if (opcode == 0x26) {
  1008                      uint32 val = mem >> (24 - (rs & 3) * 8);
  1009                      uint32 mask = uint32(0xFFFFFFFF) >> (24 - (rs & 3) * 8);
  1010                      return (rt & ~mask) | val;
  1011                  }
  1012                  //  sb
  1013                  else if (opcode == 0x28) {
  1014                      uint32 val = (rt & 0xFF) << (24 - (rs & 3) * 8);
  1015                      uint32 mask = 0xFFFFFFFF ^ uint32(0xFF << (24 - (rs & 3) * 8));
  1016                      return (mem & mask) | val;
  1017                  }
  1018                  //  sh
  1019                  else if (opcode == 0x29) {
  1020                      uint32 val = (rt & 0xFFFF) << (16 - (rs & 2) * 8);
  1021                      uint32 mask = 0xFFFFFFFF ^ uint32(0xFFFF << (16 - (rs & 2) * 8));
  1022                      return (mem & mask) | val;
  1023                  }
  1024                  //  swl
  1025                  else if (opcode == 0x2a) {
  1026                      uint32 val = rt >> ((rs & 3) * 8);
  1027                      uint32 mask = uint32(0xFFFFFFFF) >> ((rs & 3) * 8);
  1028                      return (mem & ~mask) | val;
  1029                  }
  1030                  //  sw
  1031                  else if (opcode == 0x2b) {
  1032                      return rt;
  1033                  }
  1034                  //  swr
  1035                  else if (opcode == 0x2e) {
  1036                      uint32 val = rt << (24 - (rs & 3) * 8);
  1037                      uint32 mask = uint32(0xFFFFFFFF) << (24 - (rs & 3) * 8);
  1038                      return (mem & ~mask) | val;
  1039                  }
  1040                  // ll
  1041                  else if (opcode == 0x30) {
  1042                      return mem;
  1043                  }
  1044                  // sc
  1045                  else if (opcode == 0x38) {
  1046                      return rt;
  1047                  } else {
  1048                      revert("invalid instruction");
  1049                  }
  1050              }
                    // reached only by a SPECIAL2 instruction whose funct matched none of the cases above
  1051              revert("invalid instruction");
  1052          }
  1053      }
  1054  }