agent/cryptdata/cryptdata.go (252 lines of code) (raw):

package cryptdata import ( "bytes" "crypto" "crypto/aes" "crypto/cipher" "crypto/rand" "crypto/rsa" "crypto/sha256" "crypto/x509" "encoding/pem" "errors" "fmt" "sort" "sync" "time" "github.com/aliyun/aliyun_assist_client/agent/log" "github.com/aliyun/aliyun_assist_client/agent/util" ) var ( keyPairs_ sync.Map ErrKeyIdNotExist = errors.New("Key id not exist") ErrKeyIdDuplicated = errors.New("Key id is duplicated") ErrParamNotExist = errors.New("Secret param not exist") // The ciphertext length cannot be less than the AES key length ErrCipherTextTooShort = errors.New("ciphertext too short") ) const ( ERR_OTHER_CODE = 1 // Error for 'agent not support' is 110 ERR_KEYID_NOTEXIST_CODE = 111 ERR_KEYID_DUPLICATED_CODE = 112 ERR_PARAM_NOTEXIST_CODE = 113 // length limit of plaintext to encrypt is 190byte, see https://crypto.stackexchange.com/questions/42097/what-is-the-maximum-size-of-the-plaintext-message-for-rsa-oaep LIMIT_PLAINTEXT_LEN = 190 ) func GenRsaKey(specifiedId string, timeout int) (*KeyInfo, error) { if specifiedId != "" { if k, _ := loadKey(specifiedId); k != nil { return nil, ErrKeyIdDuplicated } } privateKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { return nil, err } publicKey := privateKey.PublicKey var X509PublicKey []byte X509PublicKey, err = x509.MarshalPKIXPublicKey(&publicKey) if err != nil { return nil, err } publicBlock := pem.Block{Type: "PUBLIC KEY", Bytes: X509PublicKey} buf := bytes.NewBufferString("") if err = pem.Encode(buf, &publicBlock); err != nil { return nil, err } var keyId, publicKeyStr string publicKeyStr = buf.String() timestamp := time.Now().Unix() if specifiedId != "" { keyId = specifiedId } else { keyId = util.ComputeStrMd5(fmt.Sprint(timestamp, publicKeyStr)) } keyPair := &rsaKeyPair{ Id: keyId, CreatedTimestamp: timestamp, ExpiredTimestamp: timestamp + int64(timeout), PrivateKey: privateKey, PublicKey: publicKeyStr, } if err = storeKey(keyId, keyPair); err != nil { return nil, err } keyInfo := &KeyInfo{ Id: keyId, CreatedTimestamp: timestamp, ExpiredTimestamp: timestamp + int64(timeout), PublicKey: publicKeyStr, } return keyInfo, nil } func RemoveRsaKey(keyId string) error { if k, _ := loadKey(keyId); k == nil { return ErrKeyIdNotExist } deleteKey(keyId) return nil } func EncryptWithRsa(keyId, rawData string) ([]byte, error) { if privateKey, err := loadKey(keyId); err != nil { return nil, err } else { if encrypted, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, &privateKey.PrivateKey.PublicKey, []byte(rawData), nil); err != nil { return nil, err } else { return encrypted, nil } } } func DecryptWithRsa(keyId string, encrypted []byte) ([]byte, error) { if privateKey, err := loadKey(keyId); err != nil { return nil, err } else { if decrypted, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, privateKey.PrivateKey, encrypted, nil); err != nil { return nil, err } else { return decrypted, nil } } } func decryptWithAes(encrypted []byte, aesKey []byte) ([]byte, error) { var err error block, err := aes.NewCipher(aesKey) if err != nil { return nil, err } blockSize := block.BlockSize() if len(encrypted) < blockSize { return nil, ErrCipherTextTooShort } iv := encrypted[:blockSize] encrypted = encrypted[blockSize:] mode := cipher.NewCBCDecrypter(block, iv) mode.CryptBlocks(encrypted, encrypted) encrypted = pkcs7UnPadding(encrypted) return encrypted, nil } func CheckKey(keyId string) (*KeyInfo, error) { if privateKey, err := loadKey(keyId); err != nil { return nil, err } else { keyInfo := &KeyInfo{ Id: privateKey.Id, CreatedTimestamp: privateKey.CreatedTimestamp, ExpiredTimestamp: privateKey.ExpiredTimestamp, PublicKey: privateKey.PublicKey, } return keyInfo, nil } } func CheckKeyList() (keyList KeyInfos) { ks := getKeys() now := time.Now().Unix() for _, k := range ks { if k.ExpiredTimestamp <= now { continue } keyList = append(keyList, KeyInfo{ Id: k.Id, CreatedTimestamp: k.CreatedTimestamp, ExpiredTimestamp: k.ExpiredTimestamp, PublicKey: k.PublicKey, }) } sort.Sort(keyList) return } func SignData(keyId, data string) ([]byte, error) { privateKey, err := loadKey(keyId) if err != nil { return nil, err } // Hash data before signing. msgHash := crypto.SHA256.New() if _, err := msgHash.Write([]byte(data)); err != nil { return nil, err } msgHashSum := msgHash.Sum(nil) signature, err := rsa.SignPSS(rand.Reader, privateKey.PrivateKey, crypto.SHA256, msgHashSum, nil) if err != nil { return nil, err } return signature, nil } func VerifySignature(keyId, data string, signature []byte) (bool, error) { if privateKey, err := loadKey(keyId); err != nil { return false, err } else { // Hash data before verifying signature. msgHash := crypto.SHA256.New() if _, err := msgHash.Write([]byte(data)); err != nil { return false, err } msgHashSum := msgHash.Sum(nil) err = rsa.VerifyPSS(&privateKey.PrivateKey.PublicKey, crypto.SHA256, msgHashSum, signature, nil) if errors.Is(err, rsa.ErrVerification) { return false, nil } else if err != nil { return false, err } return true, nil } } func clearExpiredKey() { ks := getKeys() now := time.Now().Unix() for _, k := range ks { if k.ExpiredTimestamp <= now { log.GetLogger().Infof("KeyPair[%s] has expired for %d second, so delete it", k.Id, now-k.ExpiredTimestamp) keyPairs_.Delete(k.Id) } } } func loadKey(keyId string) (*rsaKeyPair, error) { if value, ok := keyPairs_.Load(keyId); !ok { return nil, ErrKeyIdNotExist } else { privateKey, ok := value.(*rsaKeyPair) if !ok { return nil, errors.New("Type convert failed") } now := time.Now().Unix() if privateKey.ExpiredTimestamp < now { log.GetLogger().Infof("KeyPair[%s] has expired for %d second, so delete it", privateKey.Id, now-privateKey.ExpiredTimestamp) keyPairs_.Delete(keyId) return nil, ErrKeyIdNotExist } return privateKey, nil } } func getKeys() []*rsaKeyPair { keys := []*rsaKeyPair{} keyPairs_.Range(func(k, v interface{}) bool { if privateKey, ok := v.(*rsaKeyPair); ok { keys = append(keys, privateKey) } return true }) return keys } func storeKey(keyId string, keyPair *rsaKeyPair) error { if _, ok := keyPairs_.LoadOrStore(keyId, keyPair); ok { return ErrKeyIdDuplicated } return nil } func deleteKey(keyId string) { keyPairs_.Delete(keyId) log.GetLogger().Infof("KeyPair[%s] is actively deleted", keyId) } func ErrToCode(err error) int { if errors.Is(err, ErrKeyIdDuplicated) { return ERR_KEYID_DUPLICATED_CODE } else if errors.Is(err, ErrKeyIdNotExist) { return ERR_KEYID_NOTEXIST_CODE } else if errors.Is(err, ErrParamNotExist) { return ERR_PARAM_NOTEXIST_CODE } return ERR_OTHER_CODE } func pkcs7UnPadding(origData []byte) []byte { length := len(origData) unpadding := int(origData[length-1]) return origData[:(length - unpadding)] }