github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/sm3/gen_sm3block_ni.go (about)

     1  // Not used yet!!!
     2  // go run gen_sm3block_ni.go
     3  
     4  //go:build ignore
     5  // +build ignore
     6  
     7  package main
     8  
     9  import (
    10  	"bytes"
    11  	"fmt"
    12  	"log"
    13  	"math/bits"
    14  	"os"
    15  )
    16  
    17  //SM3PARTW1 <Vd>.4S, <Vn>.4S, <Vm>.4S
    18  func sm3partw1(Vd, Vn, Vm byte) uint32 {
    19  	inst := uint32(0xce60c000) | uint32(Vd&0x1f) | uint32(Vn&0x1f)<<5 | (uint32(Vm&0x1f) << 16)
    20  	return bits.ReverseBytes32(inst)
    21  }
    22  
    23  //SM3PARTW2 <Vd>.4S, <Vn>.4S, <Vm>.4S
    24  func sm3partw2(Vd, Vn, Vm byte) uint32 {
    25  	inst := uint32(0xce60c400) | uint32(Vd&0x1f) | uint32(Vn&0x1f)<<5 | (uint32(Vm&0x1f) << 16)
    26  	return bits.ReverseBytes32(inst)
    27  }
    28  
    29  //SM3SS1 <Vd>.4S, <Vn>.4S, <Vm>.4S, <Va>.4S
    30  func sm3ss1(Vd, Vn, Vm, Va byte) uint32 {
    31  	inst := uint32(0xce400000) | uint32(Vd&0x1f) | uint32(Vn&0x1f)<<5 | uint32(Va&0x1f)<<10 | uint32(Vm&0x1f)<<16
    32  	return bits.ReverseBytes32(inst)
    33  }
    34  
    35  //SM3TT1A <Vd>.4S, <Vn>.4S, <Vm>.S[<imm2>]
    36  func sm3tt1a(Vd, Vn, Vm, imm2 byte) uint32 {
    37  	inst := uint32(0xce408000) | uint32(Vd&0x1f) | uint32(Vn&0x1f)<<5 | uint32(imm2&0x3)<<12 | uint32(Vm&0x1f)<<16
    38  	return bits.ReverseBytes32(inst)
    39  }
    40  
    41  //SM3TT1B <Vd>.4S, <Vn>.4S, <Vm>.S[<imm2>]
    42  func sm3tt1b(Vd, Vn, Vm, imm2 byte) uint32 {
    43  	inst := uint32(0xce408400) | uint32(Vd&0x1f) | uint32(Vn&0x1f)<<5 | uint32(imm2&0x3)<<12 | uint32(Vm&0x1f)<<16
    44  	return bits.ReverseBytes32(inst)
    45  }
    46  
    47  //SM3TT2A <Vd>.4S, <Vn>.4S, <Vm>.S[<imm2>]
    48  func sm3tt2a(Vd, Vn, Vm, imm2 byte) uint32 {
    49  	inst := uint32(0xce408800) | uint32(Vd&0x1f) | uint32(Vn&0x1f)<<5 | uint32(imm2&0x3)<<12 | uint32(Vm&0x1f)<<16
    50  	return bits.ReverseBytes32(inst)
    51  }
    52  
    53  //SM3TT2B <Vd>.4S, <Vn>.4S, <Vm>.S[<imm2>]
    54  func sm3tt2b(Vd, Vn, Vm, imm2 byte) uint32 {
    55  	inst := uint32(0xce408c00) | uint32(Vd&0x1f) | uint32(Vn&0x1f)<<5 | uint32(imm2&0x3)<<12 | uint32(Vm&0x1f)<<16
    56  	return bits.ReverseBytes32(inst)
    57  }
    58  
    59  // Used v5 as temp register
    60  func roundA(buf *bytes.Buffer, i, t, st1, st2, w, wt byte) {
    61  	fmt.Fprintf(buf, "\tWORD $0x%08x           //SM3SS1 V%d.4S, V%d.4S, V%d.4S, V%d.4S\n", sm3ss1(5, st1, t, st2), 5, st1, t, st2)
    62  	fmt.Fprintf(buf, "\tVSHL $1, V%d.S4, V%d.S4\n", t, t)
    63  	fmt.Fprintf(buf, "\tWORD $0x%08x           //SM3TT1A V%dd.4S, V%d.4S, V%d.S, %d\n", sm3tt1a(st1, 5, wt, i), st1, 5, wt, i)
    64  	fmt.Fprintf(buf, "\tWORD $0x%08x           //SM3TT2A V%dd.4S, V%d.4S, V%d.S, %d\n", sm3tt2a(st2, 5, w, i), st2, 5, w, i)
    65  }
    66  
    67  // Used v5 as temp register
    68  func roundB(buf *bytes.Buffer, i, t, st1, st2, w, wt byte) {
    69  	fmt.Fprintf(buf, "\tWORD $0x%08x           //SM3SS1 V%d.4S, V%d.4S, V%d.4S, V%d.4S\n", sm3ss1(5, st1, t, st2), 5, st1, t, st2)
    70  	fmt.Fprintf(buf, "\tVSHL $1, V%d.S4, V%d.S4\n", t, t)
    71  	fmt.Fprintf(buf, "\tWORD $0x%08x           //SM3TT1B V%dd.4S, V%d.4S, V%d.S, %d\n", sm3tt1b(st1, 5, wt, i), st1, 5, wt, i)
    72  	fmt.Fprintf(buf, "\tWORD $0x%08x           //SM3TT2B V%dd.4S, V%d.4S, V%d.S, %d\n", sm3tt2b(st2, 5, w, i), st2, 5, w, i)
    73  }
    74  
    75  // Compress 4 words and generate 4 words, use v6, v7, v10 as temp registers
    76  // s4, used to store next 4 words
    77  // s0, W(4i) W(4i+1) W(4i+2) W(4i+3)
    78  // s1, W(4i+4) W(4i+5) W(4i+6) W(4i+7)
    79  // s2, W(4i+8) W(4i+9) W(4i+10) W(4i+11)
    80  // s3, W(4i+12) W(4i+13) W(4i+14) W(4i+15)
    81  // t, t constant
    82  // st1, st2, sm3 state
    83  func qroundA(buf *bytes.Buffer, t, st1, st2, s0, s1, s2, s3, s4 byte) {
    84  	fmt.Fprintf(buf, "\t// Extension\n")
    85  	fmt.Fprintf(buf, "\tVEXT $3, V%d.B16, V%d.B16, V%d.B16\n", s2, s1, s4)
    86  	fmt.Fprintf(buf, "\tVEXT $3, V%d.B16, V%d.B16, V%d.B16\n", s1, s0, 6)
    87  	fmt.Fprintf(buf, "\tVEXT $2, V%d.B16, V%d.B16, V%d.B16\n", s3, s2, 7)
    88  	fmt.Fprintf(buf, "\tWORD $0x%08x          //SM3PARTW1 V%d.4S, V%d.4S, V%d.4S\n", sm3partw1(s4, s0, s3), s4, s0, s3)
    89  	fmt.Fprintf(buf, "\tWORD $0x%08x          //SM3PARTW2 V%d.4S, V%d.4S, V%d.4S\n", sm3partw2(s4, 7, 6), s4, 7, 6)
    90  	fmt.Fprintf(buf, "\tVEOR V%d.B16, V%d.B16, V10.B16\n", s1, s0)
    91  	fmt.Fprintf(buf, "\t// Compression\n")
    92  	roundA(buf, 0, t, st1, st2, s0, 10)
    93  	roundA(buf, 1, t, st1, st2, s0, 10)
    94  	roundA(buf, 2, t, st1, st2, s0, 10)
    95  	roundA(buf, 3, t, st1, st2, s0, 10)
    96  	fmt.Fprintf(buf, "\n")
    97  }
    98  
    99  // Used v6, v7, v10 as temp registers
   100  func qroundB(buf *bytes.Buffer, t, st1, st2, s0, s1, s2, s3, s4 byte) {
   101  	if s4 != 0xff {
   102  		fmt.Fprintf(buf, "\t// Extension\n")
   103  		fmt.Fprintf(buf, "\tVEXT $3, V%d.B16, V%d.B16, V%d.B16\n", s2, s1, s4)
   104  		fmt.Fprintf(buf, "\tVEXT $3, V%d.B16, V%d.B16, V%d.B16\n", s1, s0, 6)
   105  		fmt.Fprintf(buf, "\tVEXT $2, V%d.B16, V%d.B16, V%d.B16\n", s3, s2, 7)
   106  		fmt.Fprintf(buf, "\tWORD $0x%08x          //SM3PARTW1 V%d.4S, V%d.4S, V%d.4S\n", sm3partw1(s4, s0, s3), s4, s0, s3)
   107  		fmt.Fprintf(buf, "\tWORD $0x%08x          //SM3PARTW2 V%d.4S, V%d.4S, V%d.4S\n", sm3partw2(s4, 7, 6), s4, 7, 6)
   108  	}
   109  	fmt.Fprintf(buf, "\tVEOR V%d.B16, V%d.B16, V10.B16\n", s1, s0)
   110  	fmt.Fprintf(buf, "\t// Compression\n")
   111  	roundB(buf, 0, t, st1, st2, s0, 10)
   112  	roundB(buf, 1, t, st1, st2, s0, 10)
   113  	roundB(buf, 2, t, st1, st2, s0, 10)
   114  	roundB(buf, 3, t, st1, st2, s0, 10)
   115  	fmt.Fprintf(buf, "\n")
   116  }
   117  
   118  func main() {
   119  	buf := new(bytes.Buffer)
   120  	fmt.Fprint(buf, `
   121  // Generated by gen_sm3block_ni.go. DO NOT EDIT.
   122  
   123  #include "textflag.h"
   124  
   125  // func blockSM3NI(h []uint32, p []byte, t []uint32)
   126  TEXT ·blockSM3NI(SB), 0, $0
   127  	MOVD	h_base+0(FP), R0                           // Hash value first address
   128  	MOVD	p_base+24(FP), R1                          // message first address
   129  	MOVD	p_len+32(FP), R3                           // message length
   130  	MOVD	t_base+48(FP), R2                          // t constants first address
   131  
   132  	VLD1 (R0), [V8.S4, V9.S4]                          // load h(a,b,c,d,e,f,g,h)
   133  	LDPW	(0*8)(R2), (R5, R6)                        // load t constants
   134      
   135  blockloop:
   136  	VLD1.P	64(R1), [V0.B16, V1.B16, V2.B16, V3.B16]    // load 64bytes message
   137  	VMOV	V8.B16, V15.B16                             // backup: V8 h(dcba)
   138  	VMOV	V9.B16, V16.B16                             // backup: V9 h(hgfe)
   139  	VREV32	V0.B16, V0.B16                              // prepare for using message in Byte format
   140  	VREV32	V1.B16, V1.B16
   141  	VREV32	V2.B16, V2.B16
   142  	VREV32	V3.B16, V3.B16    
   143  	// first 16 rounds
   144  	VMOV R5, V11.S[3]
   145  `[1:])
   146  	qroundA(buf, 11, 8, 9, 0, 1, 2, 3, 4)
   147  	qroundA(buf, 11, 8, 9, 1, 2, 3, 4, 0)
   148  	qroundA(buf, 11, 8, 9, 2, 3, 4, 0, 1)
   149  	qroundA(buf, 11, 8, 9, 3, 4, 0, 1, 2)
   150  
   151  	fmt.Fprintf(buf, "\t// second 48 rounds\n")
   152  	fmt.Fprintf(buf, "\tVMOV R6, V11.S[3]\n")
   153  	qroundB(buf, 11, 8, 9, 4, 0, 1, 2, 3)
   154  	qroundB(buf, 11, 8, 9, 0, 1, 2, 3, 4)
   155  	qroundB(buf, 11, 8, 9, 1, 2, 3, 4, 0)
   156  	qroundB(buf, 11, 8, 9, 2, 3, 4, 0, 1)
   157  	qroundB(buf, 11, 8, 9, 3, 4, 0, 1, 2)
   158  	qroundB(buf, 11, 8, 9, 4, 0, 1, 2, 3)
   159  	qroundB(buf, 11, 8, 9, 0, 1, 2, 3, 4)
   160  	qroundB(buf, 11, 8, 9, 1, 2, 3, 4, 0)
   161  	qroundB(buf, 11, 8, 9, 2, 3, 4, 0, 1)
   162  	qroundB(buf, 11, 8, 9, 3, 4, 0xff, 0xff, 0xff)
   163  	qroundB(buf, 11, 8, 9, 4, 0, 0xff, 0xff, 0xff)
   164  	qroundB(buf, 11, 8, 9, 0, 1, 0xff, 0xff, 0xff)
   165  
   166  	fmt.Fprint(buf, `
   167  	SUB	$64, R3, R3                                  // message length - 64bytes, then compare with 64bytes
   168  	VEOR	V8.B16, V15.B16, V8.B16
   169  	VEOR	V9.B16, V16.B16, V9.B16
   170  	CBNZ	R3, blockloop
   171  
   172  sm3ret:
   173  	VST1	[V8.S4, V9.S4], (R0)                       // store hash value H	
   174  	RET
   175  `[1:])
   176  	src := buf.Bytes()
   177  	// fmt.Println(string(src))
   178  
   179  	err := os.WriteFile("sm3blockni_arm64.s", src, 0644)
   180  	if err != nil {
   181  		log.Fatal(err)
   182  	}
   183  }