diff --git a/cryptor/crypto.go b/cryptor/crypto.go index f0ca370..11527f2 100644 --- a/cryptor/crypto.go +++ b/cryptor/crypto.go @@ -8,16 +8,20 @@ package cryptor import ( "bytes" + "crypto" "crypto/aes" "crypto/cipher" "crypto/des" "crypto/rand" "crypto/rsa" "crypto/sha256" + "crypto/sha512" "crypto/x509" "encoding/pem" + "errors" "io" "os" + "strings" ) // AesEcbEncrypt encrypt data with key use AES ECB algorithm @@ -591,6 +595,7 @@ func RsaEncrypt(data []byte, pubKeyFileName string) []byte { if err != nil { panic(err) } + return cipherText } @@ -624,6 +629,7 @@ func RsaDecrypt(data []byte, privateKeyFileName string) []byte { if err != nil { panic(err) } + return plainText } @@ -655,3 +661,150 @@ func RsaDecryptOAEP(ciphertext []byte, label []byte, key rsa.PrivateKey) ([]byte return decryptedBytes, nil } + +// RsaSign signs the data with RSA. +// Play: todo +func RsaSign(hash crypto.Hash, data []byte, privateKeyFileName string) ([]byte, error) { + privateKey, err := loadRasPrivateKey(privateKeyFileName) + if err != nil { + return nil, err + } + + hashed, err := hashData(hash, data) + if err != nil { + return nil, err + } + + return rsa.SignPKCS1v15(rand.Reader, privateKey, hash, hashed) +} + +// RsaVerifySign verifies the signature of the data with RSA. +// Play: todo +func RsaVerifySign(hash crypto.Hash, data, signature []byte, pubKeyFileName string) error { + publicKey, err := loadRsaPublicKey(pubKeyFileName) + if err != nil { + return err + } + + hashed, err := hashData(hash, data) + if err != nil { + return err + } + + return rsa.VerifyPKCS1v15(publicKey, hash, hashed, signature) +} + +// loadRsaPrivateKey loads and parses a PEM encoded private key file. +func loadRsaPublicKey(filename string) (*rsa.PublicKey, error) { + pubKeyData, err := os.ReadFile(filename) + if err != nil { + return nil, err + } + + block, _ := pem.Decode(pubKeyData) + if block == nil { + return nil, errors.New("failed to decode PEM block containing the public key") + } + + var pubKey *rsa.PublicKey + blockType := strings.ToUpper(block.Type) + + if blockType == "RSA PUBLIC KEY" { + pubKey, err = x509.ParsePKCS1PublicKey(block.Bytes) + if err != nil { + // todo: here should be a bug, should return nil, err + key, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return nil, err + } + + var ok bool + pubKey, ok = key.(*rsa.PublicKey) + if !ok { + return nil, errors.New("failed to parse RSA private key") + } + } + } else if blockType == "PUBLIC KEY" { + key, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return nil, err + } + + var ok bool + pubKey, ok = key.(*rsa.PublicKey) + if !ok { + return nil, errors.New("failed to parse RSA private key") + } + + } else { + return nil, errors.New("unsupported key type") + } + + return pubKey, nil +} + +// loadRsaPrivateKey loads and parses a PEM encoded private key file. +func loadRasPrivateKey(filename string) (*rsa.PrivateKey, error) { + priKeyData, err := os.ReadFile(filename) + if err != nil { + return nil, err + } + + block, _ := pem.Decode(priKeyData) + if block == nil { + return nil, errors.New("failed to decode PEM block containing the private key") + } + + var privateKey *rsa.PrivateKey + blockType := strings.ToUpper(block.Type) + + // PKCS#1 format + if blockType == "RSA PRIVATE KEY" { + privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, err + } + } else if blockType == "PRIVATE KEY" { // PKCS#8 format + priKey, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return nil, err + } + var ok bool + privateKey, ok = priKey.(*rsa.PrivateKey) + if !ok { + return nil, errors.New("failed to parse RSA private key") + } + } else { + return nil, errors.New("unsupported key type") + } + + return privateKey, nil +} + +// hashData returns the hash value of the data, using the specified hash function +func hashData(hash crypto.Hash, data []byte) ([]byte, error) { + if !hash.Available() { + return nil, errors.New("unsupported hash algorithm") + } + + var hashed []byte + + switch hash { + case crypto.SHA224: + h := sha256.Sum224(data) + hashed = h[:] + case crypto.SHA256: + h := sha256.Sum256(data) + hashed = h[:] + case crypto.SHA384: + h := sha512.Sum384(data) + hashed = h[:] + case crypto.SHA512: + h := sha512.Sum512(data) + hashed = h[:] + default: + return nil, errors.New("unsupported hash algorithm") + } + + return hashed, nil +} diff --git a/cryptor/crypto_example_test.go b/cryptor/crypto_example_test.go index 30405b0..1240d24 100644 --- a/cryptor/crypto_example_test.go +++ b/cryptor/crypto_example_test.go @@ -1,6 +1,7 @@ package cryptor import ( + "crypto" "fmt" ) @@ -536,3 +537,49 @@ func ExampleRsaEncryptOAEP() { // Output: // hello world } + +func ExampleRsaSign() { + data := []byte("This is a test data for RSA signing") + hash := crypto.SHA256 + + privateKey := "./rsa_private.pem" + publicKey := "./rsa_public.pem" + + signature, err := RsaSign(hash, data, privateKey) + if err != nil { + return + } + + err = RsaVerifySign(hash, data, signature, publicKey) + if err != nil { + return + } + + fmt.Println("ok") + + // Output: + // ok +} + +func ExampleRsaVerifySign() { + data := []byte("This is a test data for RSA signing") + hash := crypto.SHA256 + + privateKey := "./rsa_private.pem" + publicKey := "./rsa_public.pem" + + signature, err := RsaSign(hash, data, privateKey) + if err != nil { + return + } + + err = RsaVerifySign(hash, data, signature, publicKey) + if err != nil { + return + } + + fmt.Println("ok") + + // Output: + // ok +} diff --git a/cryptor/crypto_test.go b/cryptor/crypto_test.go index 86f4218..2da1a8a 100644 --- a/cryptor/crypto_test.go +++ b/cryptor/crypto_test.go @@ -1,6 +1,7 @@ package cryptor import ( + "crypto" "testing" "github.com/duke-git/lancet/v2/internal" @@ -170,7 +171,6 @@ func TestRsaEncryptOAEP(t *testing.T) { } func TestAesGcmEncrypt(t *testing.T) { - t.Parallel() data := "hello world" @@ -182,3 +182,53 @@ func TestAesGcmEncrypt(t *testing.T) { assert := internal.NewAssert(t, "TestAesGcmEncrypt") assert.Equal(data, string(decrypted)) } + +func TestRsaSignAndVerify(t *testing.T) { + t.Parallel() + + data := []byte("This is a test data for RSA signing") + hash := crypto.SHA256 + + t.Run("RSA Sign and Verify", func(t *testing.T) { + privateKey := "./rsa_private.pem" + publicKey := "./rsa_public.pem" + + signature, err := RsaSign(hash, data, privateKey) + if err != nil { + t.Fatalf("RsaSign failed: %v", err) + } + + err = RsaVerifySign(hash, data, signature, publicKey) + if err != nil { + t.Fatalf("RsaVerifySign failed: %v", err) + } + }) + + t.Run("RSA Sign and Verify Invalid Signature", func(t *testing.T) { + publicKey := "./rsa_public.pem" + + invalidSig := []byte("InvalidSignature") + + err := RsaVerifySign(hash, data, invalidSig, publicKey) + if err == nil { + t.Fatalf("RsaVerifySign failed: %v", err) + } + }) + + t.Run("RSA Sign and Verify With Different Hash", func(t *testing.T) { + publicKey := "./rsa_public.pem" + privateKey := "./rsa_private.pem" + hashSign := crypto.SHA256 + hashVerify := crypto.SHA512 + + signature, err := RsaSign(hashSign, data, privateKey) + if err != nil { + t.Fatalf("RsaSign failed: %v", err) + } + + err = RsaVerifySign(hashVerify, data, signature, publicKey) + if err == nil { + t.Fatalf("RsaVerifySign failed: %v", err) + } + }) +}