github.com/tetratelabs/wazero@v1.7.3-0.20240513003603-48f702e154b5/internal/engine/wazevo/ssa/pass_test.go (about)

     1  package ssa
     2  
     3  import (
     4  	"testing"
     5  
     6  	"github.com/tetratelabs/wazero/internal/testing/require"
     7  )
     8  
     9  func TestBuilder_passes(t *testing.T) {
    10  	for _, tc := range []struct {
    11  		name string
    12  		// pass is the optimization pass to run.
    13  		pass,
    14  		// postPass is run after the pass is executed, and can be used to test a pass that depends on another pass.
    15  		postPass func(b *builder)
    16  		// setup creates the SSA function in the given *builder.
    17  		// TODO: when we have the text SSA IR parser, we can eliminate this `setup`,
    18  		// 	we could directly decode the *builder from the `before` string. I am still
    19  		//  constantly changing the format, so let's keep setup for now.
    20  		// `verifier` is executed after executing pass, and can be used to
    21  		// do the additional verification of the state of SSA function in addition to `after` text result.
    22  		setup func(*builder) (verifier func(t *testing.T))
    23  		// before is the expected SSA function after `setup` is executed.
    24  		before,
    25  		// after is the expected output after optimization pass.
    26  		after string
    27  	}{
    28  		{
    29  			name: "dead block",
    30  			pass: passDeadBlockEliminationOpt,
    31  			setup: func(b *builder) func(*testing.T) {
    32  				entry := b.AllocateBasicBlock()
    33  				value := entry.AddParam(b, TypeI32)
    34  
    35  				middle1, middle2 := b.AllocateBasicBlock(), b.AllocateBasicBlock()
    36  				end := b.AllocateBasicBlock()
    37  
    38  				b.SetCurrentBlock(entry)
    39  				{
    40  					brz := b.AllocateInstruction()
    41  					brz.AsBrz(value, ValuesNil, middle1)
    42  					b.InsertInstruction(brz)
    43  
    44  					jmp := b.AllocateInstruction()
    45  					jmp.AsJump(ValuesNil, middle2)
    46  					b.InsertInstruction(jmp)
    47  				}
    48  
    49  				b.SetCurrentBlock(middle1)
    50  				{
    51  					jmp := b.AllocateInstruction()
    52  					jmp.AsJump(ValuesNil, end)
    53  					b.InsertInstruction(jmp)
    54  				}
    55  
    56  				b.SetCurrentBlock(middle2)
    57  				{
    58  					jmp := b.AllocateInstruction()
    59  					jmp.AsJump(ValuesNil, end)
    60  					b.InsertInstruction(jmp)
    61  				}
    62  
    63  				{
    64  					unreachable := b.AllocateBasicBlock()
    65  					b.SetCurrentBlock(unreachable)
    66  					jmp := b.AllocateInstruction()
    67  					jmp.AsJump(ValuesNil, end)
    68  					b.InsertInstruction(jmp)
    69  				}
    70  
    71  				b.SetCurrentBlock(end)
    72  				{
    73  					jmp := b.AllocateInstruction()
    74  					jmp.AsJump(ValuesNil, middle1)
    75  					b.InsertInstruction(jmp)
    76  				}
    77  
    78  				b.Seal(entry)
    79  				b.Seal(middle1)
    80  				b.Seal(middle2)
    81  				b.Seal(end)
    82  				return nil
    83  			},
    84  			before: `
    85  blk0: (v0:i32)
    86  	Brz v0, blk1
    87  	Jump blk2
    88  
    89  blk1: () <-- (blk0,blk3)
    90  	Jump blk3
    91  
    92  blk2: () <-- (blk0)
    93  	Jump blk3
    94  
    95  blk3: () <-- (blk1,blk2,blk4)
    96  	Jump blk1
    97  
    98  blk4: ()
    99  	Jump blk3
   100  `,
   101  			after: `
   102  blk0: (v0:i32)
   103  	Brz v0, blk1
   104  	Jump blk2
   105  
   106  blk1: () <-- (blk0,blk3)
   107  	Jump blk3
   108  
   109  blk2: () <-- (blk0)
   110  	Jump blk3
   111  
   112  blk3: () <-- (blk1,blk2)
   113  	Jump blk1
   114  `,
   115  		},
   116  		{
   117  			name: "redundant phis",
   118  			pass: passRedundantPhiEliminationOpt,
   119  			setup: func(b *builder) func(*testing.T) {
   120  				entry, loopHeader, end := b.AllocateBasicBlock(), b.AllocateBasicBlock(), b.AllocateBasicBlock()
   121  
   122  				loopHeader.AddParam(b, TypeI32)
   123  				var1 := b.DeclareVariable(TypeI32)
   124  
   125  				b.SetCurrentBlock(entry)
   126  				{
   127  					constInst := b.AllocateInstruction()
   128  					constInst.AsIconst32(0xff)
   129  					b.InsertInstruction(constInst)
   130  					iConst := constInst.Return()
   131  					b.DefineVariable(var1, iConst, entry)
   132  
   133  					jmp := b.AllocateInstruction()
   134  					args := b.varLengthPool.Allocate(1)
   135  					args = args.Append(&b.varLengthPool, iConst)
   136  					jmp.AsJump(args, loopHeader)
   137  					b.InsertInstruction(jmp)
   138  				}
   139  				b.Seal(entry)
   140  
   141  				b.SetCurrentBlock(loopHeader)
   142  				{
   143  					// At this point, loop is not sealed, so PHI will be added to this header. However, the only
   144  					// input to the PHI is iConst above, so there must be an alias to iConst from the PHI value.
   145  					value := b.MustFindValue(var1)
   146  
   147  					tmpInst := b.AllocateInstruction()
   148  					tmpInst.AsIconst32(0xff)
   149  					b.InsertInstruction(tmpInst)
   150  					tmp := tmpInst.Return()
   151  
   152  					args := b.varLengthPool.Allocate(0)
   153  					args = args.Append(&b.varLengthPool, tmp)
   154  					brz := b.AllocateInstruction()
   155  					brz.AsBrz(value, args, loopHeader) // Loop to itself.
   156  					b.InsertInstruction(brz)
   157  
   158  					jmp := b.AllocateInstruction()
   159  					jmp.AsJump(ValuesNil, end)
   160  					b.InsertInstruction(jmp)
   161  				}
   162  				b.Seal(loopHeader)
   163  
   164  				b.SetCurrentBlock(end)
   165  				{
   166  					ret := b.AllocateInstruction()
   167  					ret.AsReturn(ValuesNil)
   168  					b.InsertInstruction(ret)
   169  				}
   170  				return nil
   171  			},
   172  			before: `
   173  blk0: ()
   174  	v1:i32 = Iconst_32 0xff
   175  	Jump blk1, v1, v1
   176  
   177  blk1: (v0:i32,v2:i32) <-- (blk0,blk1)
   178  	v3:i32 = Iconst_32 0xff
   179  	Brz v2, blk1, v3, v2
   180  	Jump blk2
   181  
   182  blk2: () <-- (blk1)
   183  	Return
   184  `,
   185  			after: `
   186  blk0: ()
   187  	v1:i32 = Iconst_32 0xff
   188  	Jump blk1, v1
   189  
   190  blk1: (v0:i32) <-- (blk0,blk1)
   191  	v3:i32 = Iconst_32 0xff
   192  	Brz v1, blk1, v3
   193  	Jump blk2
   194  
   195  blk2: () <-- (blk1)
   196  	Return
   197  `,
   198  		},
   199  		{
   200  			name: "dead code",
   201  			pass: passDeadCodeEliminationOpt,
   202  			setup: func(b *builder) func(*testing.T) {
   203  				entry, end := b.AllocateBasicBlock(), b.AllocateBasicBlock()
   204  
   205  				b.SetCurrentBlock(entry)
   206  				iconstRefThriceInst := b.AllocateInstruction()
   207  				iconstRefThriceInst.AsIconst32(3)
   208  				b.InsertInstruction(iconstRefThriceInst)
   209  				refThriceVal := iconstRefThriceInst.Return()
   210  
   211  				// This has side effect.
   212  				store := b.AllocateInstruction()
   213  				store.AsStore(OpcodeStore, refThriceVal, refThriceVal, 0)
   214  				b.InsertInstruction(store)
   215  
   216  				iconstDeadInst := b.AllocateInstruction()
   217  				iconstDeadInst.AsIconst32(0)
   218  				b.InsertInstruction(iconstDeadInst)
   219  
   220  				iconstRefOnceInst := b.AllocateInstruction()
   221  				iconstRefOnceInst.AsIconst32(1)
   222  				b.InsertInstruction(iconstRefOnceInst)
   223  				refOnceVal := iconstRefOnceInst.Return()
   224  
   225  				jmp := b.AllocateInstruction()
   226  				jmp.AsJump(ValuesNil, end)
   227  				b.InsertInstruction(jmp)
   228  
   229  				b.SetCurrentBlock(end)
   230  				aliasedRefOnceVal := b.allocateValue(refOnceVal.Type())
   231  				b.alias(aliasedRefOnceVal, refOnceVal)
   232  
   233  				add := b.AllocateInstruction()
   234  				add.AsIadd(aliasedRefOnceVal, refThriceVal)
   235  				b.InsertInstruction(add)
   236  
   237  				addRes := add.Return()
   238  
   239  				ret := b.AllocateInstruction()
   240  				args := b.varLengthPool.Allocate(1)
   241  				args = args.Append(&b.varLengthPool, addRes)
   242  				ret.AsReturn(args)
   243  				b.InsertInstruction(ret)
   244  				return func(t *testing.T) {
   245  					// Group IDs.
   246  					const gid0, gid1, gid2 InstructionGroupID = 0, 1, 2
   247  					require.Equal(t, gid0, iconstRefThriceInst.gid)
   248  					require.Equal(t, gid0, store.gid)
   249  					require.Equal(t, gid1, iconstDeadInst.gid)
   250  					require.Equal(t, gid1, iconstRefOnceInst.gid)
   251  					require.Equal(t, gid1, jmp.gid)
   252  					// Different blocks have different gids.
   253  					require.Equal(t, gid2, add.gid)
   254  					require.Equal(t, gid2, ret.gid)
   255  
   256  					// Dead or Alive...
   257  					require.False(t, iconstDeadInst.live)
   258  					require.True(t, iconstRefOnceInst.live)
   259  					require.True(t, iconstRefThriceInst.live)
   260  					require.True(t, add.live)
   261  					require.True(t, jmp.live)
   262  					require.True(t, ret.live)
   263  
   264  					require.Equal(t, 1, b.valueRefCounts[refOnceVal.ID()])
   265  					require.Equal(t, 1, b.valueRefCounts[addRes.ID()])
   266  					require.Equal(t, 3, b.valueRefCounts[refThriceVal.ID()])
   267  				}
   268  			},
   269  			before: `
   270  blk0: ()
   271  	v0:i32 = Iconst_32 0x3
   272  	Store v0, v0, 0x0
   273  	v1:i32 = Iconst_32 0x0
   274  	v2:i32 = Iconst_32 0x1
   275  	Jump blk1
   276  
   277  blk1: () <-- (blk0)
   278  	v4:i32 = Iadd v3, v0
   279  	Return v4
   280  `,
   281  			after: `
   282  blk0: ()
   283  	v0:i32 = Iconst_32 0x3
   284  	Store v0, v0, 0x0
   285  	v2:i32 = Iconst_32 0x1
   286  	Jump blk1
   287  
   288  blk1: () <-- (blk0)
   289  	v4:i32 = Iadd v2, v0
   290  	Return v4
   291  `,
   292  		},
   293  		{
   294  			name:     "nop elimination",
   295  			pass:     passNopInstElimination,
   296  			postPass: passDeadCodeEliminationOpt,
   297  			setup: func(b *builder) (verifier func(t *testing.T)) {
   298  				entry := b.AllocateBasicBlock()
   299  				b.SetCurrentBlock(entry)
   300  
   301  				i32Param := entry.AddParam(b, TypeI32)
   302  				i64Param := entry.AddParam(b, TypeI64)
   303  
   304  				// 32-bit shift.
   305  				moduleZeroI32 := b.AllocateInstruction().AsIconst32(32 * 245).Insert(b).Return()
   306  				nopIshl := b.AllocateInstruction().AsIshl(i32Param, moduleZeroI32).Insert(b).Return()
   307  
   308  				// 64-bit shift.
   309  				moduleZeroI64 := b.AllocateInstruction().AsIconst64(64 * 245).Insert(b).Return()
   310  				nopUshr := b.AllocateInstruction().AsUshr(i64Param, moduleZeroI64).Insert(b).Return()
   311  
   312  				// Non zero shift amount should not be eliminated.
   313  				nonZeroI32 := b.AllocateInstruction().AsIconst32(32*245 + 1).Insert(b).Return()
   314  				nonZeroIshl := b.AllocateInstruction().AsIshl(i32Param, nonZeroI32).Insert(b).Return()
   315  
   316  				nonZeroI64 := b.AllocateInstruction().AsIconst64(64*245 + 1).Insert(b).Return()
   317  				nonZeroSshr := b.AllocateInstruction().AsSshr(i64Param, nonZeroI64).Insert(b).Return()
   318  
   319  				ret := b.AllocateInstruction()
   320  				args := b.varLengthPool.Allocate(4)
   321  				args = args.Append(&b.varLengthPool, nopIshl)
   322  				args = args.Append(&b.varLengthPool, nopUshr)
   323  				args = args.Append(&b.varLengthPool, nonZeroIshl)
   324  				args = args.Append(&b.varLengthPool, nonZeroSshr)
   325  				ret.AsReturn(args)
   326  				b.InsertInstruction(ret)
   327  				return nil
   328  			},
   329  			before: `
   330  blk0: (v0:i32, v1:i64)
   331  	v2:i32 = Iconst_32 0x1ea0
   332  	v3:i32 = Ishl v0, v2
   333  	v4:i64 = Iconst_64 0x3d40
   334  	v5:i64 = Ushr v1, v4
   335  	v6:i32 = Iconst_32 0x1ea1
   336  	v7:i32 = Ishl v0, v6
   337  	v8:i64 = Iconst_64 0x3d41
   338  	v9:i64 = Sshr v1, v8
   339  	Return v3, v5, v7, v9
   340  `,
   341  			after: `
   342  blk0: (v0:i32, v1:i64)
   343  	v6:i32 = Iconst_32 0x1ea1
   344  	v7:i32 = Ishl v0, v6
   345  	v8:i64 = Iconst_64 0x3d41
   346  	v9:i64 = Sshr v1, v8
   347  	Return v0, v1, v7, v9
   348  `,
   349  		},
   350  	} {
   351  		tc := tc
   352  		t.Run(tc.name, func(t *testing.T) {
   353  			b := NewBuilder().(*builder)
   354  			verifier := tc.setup(b)
   355  			require.Equal(t, tc.before, b.Format())
   356  			tc.pass(b)
   357  			if verifier != nil {
   358  				verifier(t)
   359  			}
   360  			if tc.postPass != nil {
   361  				tc.postPass(b)
   362  			}
   363  			require.Equal(t, tc.after, b.Format())
   364  		})
   365  	}
   366  }