github.com/emmansun/gmsm@v0.29.1/sm3/gen_sm3block_ni.go (about)

     1  // Not used yet!!!
     2  // go run gen_sm3block_ni.go
     3  
     4  //go:build ignore
     5  
     6  package main
     7  
     8  import (
     9  	"bytes"
    10  	"fmt"
    11  	"log"
    12  	"os"
    13  )
    14  
    15  //SM3PARTW1 <Vd>.4S, <Vn>.4S, <Vm>.4S
    16  func sm3partw1(Vd, Vn, Vm byte) uint32 {
    17  	inst := uint32(0xce60c000) | uint32(Vd&0x1f) | uint32(Vn&0x1f)<<5 | (uint32(Vm&0x1f) << 16)
    18  	// return bits.ReverseBytes32(inst)
    19  	return inst
    20  }
    21  
    22  //SM3PARTW2 <Vd>.4S, <Vn>.4S, <Vm>.4S
    23  func sm3partw2(Vd, Vn, Vm byte) uint32 {
    24  	inst := uint32(0xce60c400) | uint32(Vd&0x1f) | uint32(Vn&0x1f)<<5 | (uint32(Vm&0x1f) << 16)
    25  	// return bits.ReverseBytes32(inst)
    26  	return inst
    27  }
    28  
    29  //SM3SS1 <Vd>.4S, <Vn>.4S, <Vm>.4S, <Va>.4S
    30  func sm3ss1(Vd, Vn, Vm, Va byte) uint32 {
    31  	inst := uint32(0xce400000) | uint32(Vd&0x1f) | uint32(Vn&0x1f)<<5 | uint32(Va&0x1f)<<10 | uint32(Vm&0x1f)<<16
    32  	// return bits.ReverseBytes32(inst)
    33  	return inst
    34  }
    35  
    36  //SM3TT1A <Vd>.4S, <Vn>.4S, <Vm>.S[<imm2>]
    37  func sm3tt1a(Vd, Vn, Vm, imm2 byte) uint32 {
    38  	inst := uint32(0xce408000) | uint32(Vd&0x1f) | uint32(Vn&0x1f)<<5 | uint32(imm2&0x3)<<12 | uint32(Vm&0x1f)<<16
    39  	// return bits.ReverseBytes32(inst)
    40  	return inst
    41  }
    42  
    43  //SM3TT1B <Vd>.4S, <Vn>.4S, <Vm>.S[<imm2>]
    44  func sm3tt1b(Vd, Vn, Vm, imm2 byte) uint32 {
    45  	inst := uint32(0xce408400) | uint32(Vd&0x1f) | uint32(Vn&0x1f)<<5 | uint32(imm2&0x3)<<12 | uint32(Vm&0x1f)<<16
    46  	// return bits.ReverseBytes32(inst)
    47  	return inst
    48  }
    49  
    50  //SM3TT2A <Vd>.4S, <Vn>.4S, <Vm>.S[<imm2>]
    51  func sm3tt2a(Vd, Vn, Vm, imm2 byte) uint32 {
    52  	inst := uint32(0xce408800) | uint32(Vd&0x1f) | uint32(Vn&0x1f)<<5 | uint32(imm2&0x3)<<12 | uint32(Vm&0x1f)<<16
    53  	// return bits.ReverseBytes32(inst)
    54  	return inst
    55  }
    56  
    57  //SM3TT2B <Vd>.4S, <Vn>.4S, <Vm>.S[<imm2>]
    58  func sm3tt2b(Vd, Vn, Vm, imm2 byte) uint32 {
    59  	inst := uint32(0xce408c00) | uint32(Vd&0x1f) | uint32(Vn&0x1f)<<5 | uint32(imm2&0x3)<<12 | uint32(Vm&0x1f)<<16
    60  	// return bits.ReverseBytes32(inst)
    61  	return inst
    62  }
    63  
    64  // Used v5 as temp register
    65  func roundA(buf *bytes.Buffer, i, t0, t1, st1, st2, w, wt byte) {
    66  	fmt.Fprintf(buf, "\tWORD $0x%08x           //SM3SS1 V%d.4S, V%d.4S, V%d.4S, V%d.4S\n", sm3ss1(5, st1, t0, st2), 5, st1, t0, st2)
    67  	fmt.Fprintf(buf, "\tVSHL $1, V%d.S4, V%d.S4\n", t0, t1)
    68  	fmt.Fprintf(buf, "\tVSRI $31, V%d.S4, V%d.S4\n", t0, t1)
    69  	fmt.Fprintf(buf, "\tWORD $0x%08x           //SM3TT1A V%dd.4S, V%d.4S, V%d.S, %d\n", sm3tt1a(st1, 5, wt, i), st1, 5, wt, i)
    70  	fmt.Fprintf(buf, "\tWORD $0x%08x           //SM3TT2A V%dd.4S, V%d.4S, V%d.S, %d\n", sm3tt2a(st2, 5, w, i), st2, 5, w, i)
    71  }
    72  
    73  // Used v5 as temp register
    74  func roundB(buf *bytes.Buffer, i, t0, t1, st1, st2, w, wt byte) {
    75  	fmt.Fprintf(buf, "\tWORD $0x%08x           //SM3SS1 V%d.4S, V%d.4S, V%d.4S, V%d.4S\n", sm3ss1(5, st1, t0, st2), 5, st1, t0, st2)
    76  	fmt.Fprintf(buf, "\tVSHL $1, V%d.S4, V%d.S4\n", t0, t1)
    77  	fmt.Fprintf(buf, "\tVSRI $31, V%d.S4, V%d.S4\n", t0, t1)
    78  	fmt.Fprintf(buf, "\tWORD $0x%08x           //SM3TT1B V%dd.4S, V%d.4S, V%d.S, %d\n", sm3tt1b(st1, 5, wt, i), st1, 5, wt, i)
    79  	fmt.Fprintf(buf, "\tWORD $0x%08x           //SM3TT2B V%dd.4S, V%d.4S, V%d.S, %d\n", sm3tt2b(st2, 5, w, i), st2, 5, w, i)
    80  }
    81  
    82  // Compress 4 words and generate 4 words, use v6, v7, v10 as temp registers
    83  // s4, used to store next 4 words
    84  // s0, W(4i) W(4i+1) W(4i+2) W(4i+3)
    85  // s1, W(4i+4) W(4i+5) W(4i+6) W(4i+7)
    86  // s2, W(4i+8) W(4i+9) W(4i+10) W(4i+11)
    87  // s3, W(4i+12) W(4i+13) W(4i+14) W(4i+15)
    88  // t, t constant
    89  // st1, st2, sm3 state
    90  func qroundA(buf *bytes.Buffer, t0, t1, st1, st2, s0, s1, s2, s3, s4 byte) {
    91  	fmt.Fprintf(buf, "\t// Extension\n")
    92  	fmt.Fprintf(buf, "\tVEXT $12, V%d.B16, V%d.B16, V%d.B16\n", s2, s1, s4)
    93  	fmt.Fprintf(buf, "\tVEXT $12, V%d.B16, V%d.B16, V%d.B16\n", s1, s0, 6)
    94  	fmt.Fprintf(buf, "\tVEXT $8, V%d.B16, V%d.B16, V%d.B16\n", s3, s2, 7)
    95  	fmt.Fprintf(buf, "\tWORD $0x%08x          //SM3PARTW1 V%d.4S, V%d.4S, V%d.4S\n", sm3partw1(s4, s0, s3), s4, s0, s3)
    96  	fmt.Fprintf(buf, "\tWORD $0x%08x          //SM3PARTW2 V%d.4S, V%d.4S, V%d.4S\n", sm3partw2(s4, 7, 6), s4, 7, 6)
    97  	fmt.Fprintf(buf, "\tVEOR V%d.B16, V%d.B16, V10.B16\n", s1, s0)
    98  	fmt.Fprintf(buf, "\t// Compression\n")
    99  	roundA(buf, 0, t0, t1, st1, st2, s0, 10)
   100  	roundA(buf, 1, t1, t0, st1, st2, s0, 10)
   101  	roundA(buf, 2, t0, t1, st1, st2, s0, 10)
   102  	roundA(buf, 3, t1, t0, st1, st2, s0, 10)
   103  	fmt.Fprintf(buf, "\n")
   104  }
   105  
   106  // Used v6, v7, v10 as temp registers
   107  func qroundB(buf *bytes.Buffer, t0, t1, st1, st2, s0, s1, s2, s3, s4 byte) {
   108  	if s4 != 0xff {
   109  		fmt.Fprintf(buf, "\t// Extension\n")
   110  		fmt.Fprintf(buf, "\tVEXT $12, V%d.B16, V%d.B16, V%d.B16\n", s2, s1, s4)
   111  		fmt.Fprintf(buf, "\tVEXT $12, V%d.B16, V%d.B16, V%d.B16\n", s1, s0, 6)
   112  		fmt.Fprintf(buf, "\tVEXT $8, V%d.B16, V%d.B16, V%d.B16\n", s3, s2, 7)
   113  		fmt.Fprintf(buf, "\tWORD $0x%08x          //SM3PARTW1 V%d.4S, V%d.4S, V%d.4S\n", sm3partw1(s4, s0, s3), s4, s0, s3)
   114  		fmt.Fprintf(buf, "\tWORD $0x%08x          //SM3PARTW2 V%d.4S, V%d.4S, V%d.4S\n", sm3partw2(s4, 7, 6), s4, 7, 6)
   115  	}
   116  	fmt.Fprintf(buf, "\tVEOR V%d.B16, V%d.B16, V10.B16\n", s1, s0)
   117  	fmt.Fprintf(buf, "\t// Compression\n")
   118  	roundB(buf, 0, t0, t1, st1, st2, s0, 10)
   119  	roundB(buf, 1, t1, t0, st1, st2, s0, 10)
   120  	roundB(buf, 2, t0, t1, st1, st2, s0, 10)
   121  	roundB(buf, 3, t1, t0, st1, st2, s0, 10)
   122  	fmt.Fprintf(buf, "\n")
   123  }
   124  
   125  func main() {
   126  	buf := new(bytes.Buffer)
   127  	fmt.Fprint(buf, `
   128  // Generated by gen_sm3block_ni.go. DO NOT EDIT.
   129  //go:build !purego
   130  
   131  #include "textflag.h"
   132  
   133  // func blockSM3NI(h []uint32, p []byte, t []uint32)
   134  TEXT ·blockSM3NI(SB), 0, $0
   135  	MOVD	h_base+0(FP), R0                           // Hash value first address
   136  	MOVD	p_base+24(FP), R1                          // message first address
   137  	MOVD	p_len+32(FP), R3                           // message length
   138  	MOVD	t_base+48(FP), R2                          // t constants first address
   139  
   140  	VLD1 (R0), [V8.S4, V9.S4]                          // load h(a,b,c,d,e,f,g,h)
   141  	VREV64  V8.S4, V8.S4
   142  	VEXT $8, V8.B16, V8.B16, V8.B16
   143  	VREV64  V9.S4, V9.S4
   144  	VEXT $8, V9.B16, V9.B16, V9.B16
   145  	LDPW	(0*8)(R2), (R5, R6)                        // load t constants
   146      
   147  blockloop:
   148  	VLD1.P	64(R1), [V0.B16, V1.B16, V2.B16, V3.B16]    // load 64bytes message
   149  	VMOV	V8.B16, V15.B16                             // backup: V8 h(dcba)
   150  	VMOV	V9.B16, V16.B16                             // backup: V9 h(hgfe)
   151  	VREV32	V0.B16, V0.B16                              // prepare for using message in Byte format
   152  	VREV32	V1.B16, V1.B16
   153  	VREV32	V2.B16, V2.B16
   154  	VREV32	V3.B16, V3.B16    
   155  	// first 16 rounds
   156  	VMOV R5, V11.S[3]
   157  `[1:])
   158  	qroundA(buf, 11, 12, 8, 9, 0, 1, 2, 3, 4)
   159  	qroundA(buf, 11, 12, 8, 9, 1, 2, 3, 4, 0)
   160  	qroundA(buf, 11, 12, 8, 9, 2, 3, 4, 0, 1)
   161  	qroundA(buf, 11, 12, 8, 9, 3, 4, 0, 1, 2)
   162  
   163  	fmt.Fprintf(buf, "\t// second 48 rounds\n")
   164  	fmt.Fprintf(buf, "\tVMOV R6, V11.S[3]\n")
   165  	qroundB(buf, 11, 12, 8, 9, 4, 0, 1, 2, 3)
   166  	qroundB(buf, 11, 12, 8, 9, 0, 1, 2, 3, 4)
   167  	qroundB(buf, 11, 12, 8, 9, 1, 2, 3, 4, 0)
   168  	qroundB(buf, 11, 12, 8, 9, 2, 3, 4, 0, 1)
   169  	qroundB(buf, 11, 12, 8, 9, 3, 4, 0, 1, 2)
   170  	qroundB(buf, 11, 12, 8, 9, 4, 0, 1, 2, 3)
   171  	qroundB(buf, 11, 12, 8, 9, 0, 1, 2, 3, 4)
   172  	qroundB(buf, 11, 12, 8, 9, 1, 2, 3, 4, 0)
   173  	qroundB(buf, 11, 12, 8, 9, 2, 3, 4, 0, 1)
   174  	qroundB(buf, 11, 12, 8, 9, 3, 4, 0xff, 0xff, 0xff)
   175  	qroundB(buf, 11, 12, 8, 9, 4, 0, 0xff, 0xff, 0xff)
   176  	qroundB(buf, 11, 12, 8, 9, 0, 1, 0xff, 0xff, 0xff)
   177  
   178  	fmt.Fprint(buf, `
   179  	SUB	$64, R3, R3                                  // message length - 64bytes, then compare with 64bytes
   180  	VEOR	V8.B16, V15.B16, V8.B16
   181  	VEOR	V9.B16, V16.B16, V9.B16
   182  	CBNZ	R3, blockloop
   183  
   184  sm3ret:
   185  	VREV64  V8.S4, V8.S4
   186  	VEXT $8, V8.B16, V8.B16, V8.B16
   187  	VREV64  V9.S4, V9.S4
   188  	VEXT $8, V9.B16, V9.B16, V9.B16
   189  	VST1	[V8.S4, V9.S4], (R0)                       // store hash value H	
   190  	RET
   191  
   192  `[1:])
   193  	src := buf.Bytes()
   194  	// fmt.Println(string(src))
   195  
   196  	err := os.WriteFile("sm3blockni_arm64.s", src, 0644)
   197  	if err != nil {
   198  		log.Fatal(err)
   199  	}
   200  }