github.com/grafviktor/keep-my-secret@v0.9.10-0.20230908165355-19f35cce90e5/internal/model/secret.go (about) 1 package model 2 3 import ( 4 "fmt" 5 "log" 6 "reflect" 7 8 "github.com/samber/lo" 9 10 "github.com/grafviktor/keep-my-secret/internal/api/utils" 11 ) 12 13 // var shouldNotEncrypt = []string{"ID", "Type", "Title"} 14 var shouldNotEncrypt = []string{"ID", "Encryptor"} 15 16 // Encryptor is used for setting encrypting method for Secret model. This interface is used mainly for mocking 17 type Encryptor interface { 18 Encrypt(secret *Secret, key, salt string) error 19 Decrypt(secret *Secret, key, salt string) error 20 } 21 22 // Secret is a model of secret object which the application receives from the client 23 type Secret struct { 24 ID int64 `json:"id"` 25 Type string `json:"type"` 26 Title string `json:"title"` 27 Login string `json:"login"` 28 Password string `json:"password"` 29 Note string `json:"note"` 30 File []byte `json:"-"` 31 FileName string `json:"file_name"` 32 CardholderName string `json:"cardholder_name"` 33 CardNumber string `json:"card_number"` 34 Expiration string `json:"expiration"` 35 SecurityCode string `json:"security_code"` 36 Encryptor Encryptor `json:"-"` 37 } 38 39 // SetEncryptor should be used for setting concrete encryptor implementation. Currently used in unit tests 40 func (s *Secret) SetEncryptor(encryptor Encryptor) { 41 s.Encryptor = encryptor 42 } 43 44 const ( 45 typeString = "string" 46 typeBinary = "[]uint8" 47 ) 48 49 // Encrypt - encrypts object using key and salt 50 func (s *Secret) Encrypt(key, salt string) error { 51 if s.Encryptor != nil { 52 return s.Encryptor.Encrypt(s, key, salt) 53 } 54 55 v := reflect.Indirect(reflect.ValueOf(s)) 56 typeOfP := v.Type() 57 58 for i := 0; i < v.NumField(); i++ { 59 field := v.Field(i) 60 fieldName := typeOfP.Field(i).Name 61 62 // skip fields that should not be decrypted 63 if _, ok := lo.Find(shouldNotEncrypt, func(i string) bool { 64 return i == fieldName 65 }); ok { 66 continue 67 } 68 69 fieldType := field.Type().String() 70 var toEncrypt []byte 71 72 switch { 73 case fieldType == typeString: 74 fieldValue, _ := field.Interface().(string) 75 toEncrypt = []byte(salt + fieldValue) 76 case fieldType == typeBinary: 77 fieldValue, _ := field.Interface().([]byte) 78 toEncrypt = append([]byte(salt), fieldValue...) 79 default: 80 log.Printf("secret decrypt: field %s is not a string", fieldName) 81 continue 82 } 83 84 encrypted, err := utils.Encrypt(toEncrypt, key) 85 if err != nil { 86 return fmt.Errorf("secret.Encrypt: %s", err.Error()) 87 } 88 89 if fieldType == typeString { 90 v.Field(i).SetString(string(encrypted)) 91 } else { 92 v.Field(i).SetBytes(encrypted) 93 } 94 } 95 96 return nil 97 } 98 99 // Decrypt - decrypts object using key and salt 100 func (s *Secret) Decrypt(key, salt string) error { 101 if s.Encryptor != nil { 102 return s.Encryptor.Decrypt(s, key, salt) 103 } 104 105 v := reflect.Indirect(reflect.ValueOf(s)) 106 typeOfP := v.Type() 107 108 for i := 0; i < v.NumField(); i++ { 109 field := v.Field(i) 110 fieldName := typeOfP.Field(i).Name 111 112 // skip fields that should not be decrypted 113 if _, ok := lo.Find(shouldNotEncrypt, func(i string) bool { 114 return i == fieldName 115 }); ok { 116 continue 117 } 118 119 fieldType := field.Type().String() 120 var toDecrypt []byte 121 122 switch { 123 case fieldType == typeString: 124 fieldValue, _ := field.Interface().(string) 125 toDecrypt = []byte(fieldValue) 126 case fieldType == typeBinary: 127 fieldValue, _ := field.Interface().([]byte) 128 toDecrypt = fieldValue 129 default: 130 log.Printf("secret.Decrypt: field %s is not a string", fieldName) 131 continue 132 } 133 134 decrypted, err := utils.Decrypt(toDecrypt, key) 135 if err != nil { 136 return fmt.Errorf("secret encrypt: %s", err.Error()) 137 } 138 139 if fieldType == typeString { 140 decryptedStr := "" 141 if len(decrypted) > len(salt) { 142 decryptedStr = string(decrypted[len(salt):]) 143 } 144 145 v.Field(i).SetString(decryptedStr) 146 } else { 147 decryptedBytes := []byte{} 148 if len(decrypted) > len(salt) { 149 decryptedBytes = decrypted[len(salt):] 150 } 151 152 v.Field(i).SetBytes(decryptedBytes) 153 } 154 } 155 156 return nil 157 }