agent/taskengine/signature/signature.go (228 lines of code) (raw):

package signature

import (
	"crypto"
	"crypto/rsa"
	"crypto/sha256"
	"crypto/x509"
	"encoding/base64"
	"encoding/json"
	"encoding/pem"
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"sync"

	"github.com/aliyun/aliyun_assist_client/agent/taskengine/models"
	"github.com/aliyun/aliyun_assist_client/agent/util"
	"github.com/aliyun/aliyun_assist_client/common/pathutil"
	"github.com/aliyun/aliyun_assist_client/thirdparty/sirupsen/logrus"
	"github.com/tjfoc/gmsm/sm2"
	gmsm_x509 "github.com/tjfoc/gmsm/x509"
)

const (
	// certFileName is the name of the local cache file, under the
	// cross-version config directory, that persists pulled certificates.
	certFileName = "task_sign_certs"

	// Supported values of Cert.Algorithm.
	algorithm_sm3withsm2    = "SM3WithSM2"
	algorithm_sha256withrsa = "SHA256WithRSA"
)

// Cert is a task-signing certificate: a PEM public key (PublicKeyStr), the
// algorithm it is used with, and the ordered list of task fields
// (SignatureFields) that are joined with "#" to form the signed data.
//
// PublicKeyRsa / PublicKeySm2 hold the parsed key (set by parsePublicKey
// according to Algorithm) and are never serialized.
type Cert struct {
	KeypairVersion   int            `json:"keypairVersion"`
	SignatureVersion int            `json:"signatureVersion"`
	PublicKeyStr     string         `json:"publicKey"`
	Algorithm        string         `json:"algorithm"`
	SignatureFields  []string       `json:"signatureFields"`
	PublicKeyRsa     *rsa.PublicKey `json:"-"`
	PublicKeySm2     *sm2.PublicKey `json:"-"`
}

// SignatureCertificateReq identifies a certificate by its signature and key
// pair versions.
type SignatureCertificateReq struct {
	KeypairVersion   int `json:"keypairVersion"`
	SignatureVersion int `json:"signatureVersion"`
}

// Errors returned while verifying task signatures.
var (
	ErrorUnknownSignatureFormat = errors.New("unknown signature format")
	ErrorCertNotFound           = errors.New("cert not found")
	ErrorUnknownSignAlgorithm   = errors.New("unknown sign algorithm")
	ErrorParsePublicKey         = errors.New("parse public key failed")
	ErrorUnknownInstanceId      = errors.New("unknown instanceId")
)

var (
	// certs_ caches parsed certificates keyed by "<signVer>-<keypairVer>"; it
	// is read without taking certsMpLock on every verification.
	certs_ sync.Map

	// certsMp holds the same certificates and is the map persisted to the
	// cache file. It is guarded by certsMpLock.
	certsMp     map[string]*Cert
	certsMpLock sync.Mutex

	// loadCertsOnce ensures the cache file is read at most once per process.
	loadCertsOnce sync.Once
)

// parsePublicKey decodes c.PublicKeyStr (PEM) according to c.Algorithm and
// stores the result in c.PublicKeyRsa or c.PublicKeySm2.
//
// It returns ErrorParsePublicKey when there is no PEM block or the decoded
// key does not match the algorithm, ErrorUnknownSignAlgorithm for an
// unsupported algorithm, or the underlying parse error.
func (c *Cert) parsePublicKey() error {
	switch c.Algorithm {
	case algorithm_sha256withrsa:
		block, _ := pem.Decode([]byte(c.PublicKeyStr))
		if block == nil {
			return ErrorParsePublicKey
		}
		pubInterface, err := x509.ParsePKIXPublicKey(block.Bytes)
		if err != nil {
			return err
		}
		// Checked assertion: a non-RSA PKIX key (e.g. ECDSA) must be
		// rejected, not cause a panic.
		pk, ok := pubInterface.(*rsa.PublicKey)
		if !ok {
			return ErrorParsePublicKey
		}
		c.PublicKeyRsa = pk
	case algorithm_sm3withsm2:
		block, _ := pem.Decode([]byte(c.PublicKeyStr))
		if block == nil {
			return ErrorParsePublicKey
		}
		pk, err := gmsm_x509.ParseSm2PublicKey(block.Bytes)
		if err != nil {
			return err
		}
		if pk == nil {
			return ErrorParsePublicKey
		}
		c.PublicKeySm2 = pk
	default:
		return ErrorUnknownSignAlgorithm
	}
	return nil
}

// VerifyTaskSign checks task.Signature against the task's data.
//
// The signature has the form "<signatureVersion>#<keypairVersion>#<base64>".
// The matching certificate is looked up in the in-memory cache (filled once
// from the local cache file) and pulled from the cert service on a miss. The
// signed data is the task values named by the cert's SignatureFields, joined
// with "#".
//
// It returns (true, nil) when the signature verifies. Otherwise it returns
// false, usually with an error: ErrorUnknownSignatureFormat for a malformed
// signature, ErrorUnknownInstanceId when the instance id is required but
// unknown, ErrorUnknownSignAlgorithm for an unsupported algorithm, or the
// underlying decode/network/crypto error. An SM2 mismatch yields (false, nil).
func VerifyTaskSign(logger logrus.FieldLogger, task models.RunTaskInfo) (bool, error) {
	loadCertsOnce.Do(func() {
		if err := loadCertMp(logger); err != nil {
			logger.WithError(err).Error("Load certs from local failed")
		} else {
			logger.Info("Load certs from local successd")
		}
	})

	signVer, keypairVer, signatureByte, err := parseTaskSignature(task.Signature)
	if err != nil {
		return false, err
	}

	key := fmt.Sprintf("%d-%d", signVer, keypairVer)
	var c *Cert
	if value, ok := certs_.Load(key); ok {
		c, _ = value.(*Cert)
	}
	if c == nil {
		if c, err = updateCertsMp(logger, key, signVer, keypairVer); err != nil {
			return false, err
		}
	}

	data, err := buildSignedData(c, task)
	if err != nil {
		return false, err
	}
	// The signed data includes the command content and user id, which may be
	// sensitive; keep it out of normal (info-level) logs.
	logger.Debug("dataList: ", data)

	switch c.Algorithm {
	case algorithm_sha256withrsa:
		return verifyWithSha256withrsa(c.PublicKeyRsa, data, signatureByte)
	case algorithm_sm3withsm2:
		return verifyWithSm3withsm2(c.PublicKeySm2, data, signatureByte)
	default:
		return false, ErrorUnknownSignAlgorithm
	}
}

// parseTaskSignature splits a "<signVer>#<keypairVer>#<base64 sig>" string
// into its numeric versions and the decoded signature bytes. Format errors
// yield ErrorUnknownSignatureFormat; a bad base64 payload yields the decode
// error.
func parseTaskSignature(raw string) (signVer, keypairVer int, sig []byte, err error) {
	fields := strings.SplitN(raw, "#", 3)
	if len(fields) != 3 || fields[0] == "" || fields[1] == "" || fields[2] == "" {
		return 0, 0, nil, ErrorUnknownSignatureFormat
	}
	if signVer, err = strconv.Atoi(fields[0]); err != nil {
		return 0, 0, nil, ErrorUnknownSignatureFormat
	}
	if keypairVer, err = strconv.Atoi(fields[1]); err != nil {
		return 0, 0, nil, ErrorUnknownSignatureFormat
	}
	if sig, err = base64.StdEncoding.DecodeString(fields[2]); err != nil {
		return 0, 0, nil, err
	}
	return signVer, keypairVer, sig, nil
}

// buildSignedData joins with "#" the task values named by c.SignatureFields,
// in order. Unrecognized field names are skipped. It returns
// ErrorUnknownInstanceId if "instanceId" is required but unknown.
func buildSignedData(c *Cert, task models.RunTaskInfo) (string, error) {
	dataList := make([]string, 0, len(c.SignatureFields))
	for _, field := range c.SignatureFields {
		switch field {
		case "userId":
			dataList = append(dataList, task.UserId)
		case "instanceId":
			instanceId := util.GetInstanceId()
			if instanceId == "unknown" {
				return "", ErrorUnknownInstanceId
			}
			dataList = append(dataList, instanceId)
		case "commandContent":
			dataList = append(dataList, task.Content)
		}
	}
	return strings.Join(dataList, "#"), nil
}

// loadCertMp resets certsMp and fills it, together with certs_, from the
// local cache file. Entries that are null or whose public key cannot be
// parsed are logged and skipped.
func loadCertMp(logger logrus.FieldLogger) error {
	certsMpLock.Lock()
	defer certsMpLock.Unlock()

	certsMp = make(map[string]*Cert)
	crossVersionConfDir, err := pathutil.GetCrossVersionConfigPath()
	if err != nil {
		return err
	}
	content, err := os.ReadFile(filepath.Join(crossVersionConfDir, certFileName))
	if err != nil {
		return err
	}
	certs := map[string]*Cert{}
	if err = json.Unmarshal(content, &certs); err != nil {
		return err
	}
	for k, c := range certs {
		// A JSON "null" value decodes to a nil *Cert; skip it instead of
		// panicking in parsePublicKey.
		if c == nil {
			logger.Warnf("Skip null cert[%s] in local cert file", k)
			continue
		}
		if err := c.parsePublicKey(); err != nil {
			logger.WithError(err).Errorf("Parse public key[%d-%d] failed", c.SignatureVersion, c.KeypairVersion)
			continue
		}
		certs_.Store(k, c)
		certsMp[k] = c
	}
	return nil
}

// updateCertsMp returns the certificate for key, pulling it from the cert
// service if certsMp does not already hold it. A freshly pulled certificate
// is added to both caches under the requested key (and, if different, under
// the key built from its own versions) and persisted to the cache file;
// persistence failures are only logged.
func updateCertsMp(logger logrus.FieldLogger, key string, signVer int, keypairVer int) (*Cert, error) {
	certsMpLock.Lock()
	defer certsMpLock.Unlock()

	if certsMp == nil {
		certsMp = make(map[string]*Cert)
	}
	// Another goroutine may have pulled this cert while we waited for the lock.
	if c, ok := certsMp[key]; ok {
		return c, nil
	}

	pulledKey, c, err := pullCert(signVer, keypairVer)
	if err != nil {
		return nil, err
	}
	// Cache under the requested key so later lookups for it hit instead of
	// re-pulling from the service on every task.
	certsMp[key] = c
	certs_.Store(key, c)
	if pulledKey != key {
		logger.Warnf("Pulled cert[%s] does not match requested cert[%s]", pulledKey, key)
		certsMp[pulledKey] = c
		certs_.Store(pulledKey, c)
	}
	if err := storeCertMp(); err != nil {
		logger.WithError(err).Error("Store cert to local failed")
	}
	return c, nil
}

// pullCert fetches the certificate for the given versions from the cert
// service and parses its public key. It returns the cert together with the
// cache key built from the versions reported in the response.
func pullCert(signVer, keypairVer int) (string, *Cert, error) {
	url := util.GetSignCertService()
	url += fmt.Sprintf("?signatureVersion=%d&keypairVersion=%d", signVer, keypairVer)
	err, resp := util.HttpGet(url)
	if err != nil {
		return "", nil, err
	}
	// Decode into the *Cert itself: decoding into a **Cert would turn a
	// "null" body into a nil cert and panic in parsePublicKey.
	cert := &Cert{}
	if err = json.Unmarshal([]byte(resp), cert); err != nil {
		return "", nil, err
	}
	if err := cert.parsePublicKey(); err != nil {
		return "", nil, err
	}
	key := fmt.Sprintf("%d-%d", cert.SignatureVersion, cert.KeypairVersion)
	return key, cert, nil
}

// storeCertMp writes certsMp as JSON to the local cache file. Its caller in
// this file (updateCertsMp) holds certsMpLock while calling it.
func storeCertMp() error {
	crossVersionConfDir, err := pathutil.GetCrossVersionConfigPath()
	if err != nil {
		return err
	}
	certFile := filepath.Join(crossVersionConfDir, certFileName)
	content, err := json.Marshal(certsMp)
	if err != nil {
		return err
	}
	return os.WriteFile(certFile, content, 0644)
}

// verifyWithSha256withrsa reports whether signature is a valid RSA
// PKCS #1 v1.5 signature over the SHA-256 digest of data. On failure the
// returned error describes why.
func verifyWithSha256withrsa(pk *rsa.PublicKey, data string, signature []byte) (bool, error) {
	if pk == nil {
		return false, ErrorParsePublicKey
	}
	digest := sha256.Sum256([]byte(data))
	err := rsa.VerifyPKCS1v15(pk, crypto.SHA256, digest[:], signature)
	return err == nil, err
}

// verifyWithSm3withsm2 reports whether signature is valid for data as
// checked by sm2.PublicKey.Verify. A mismatch yields (false, nil).
func verifyWithSm3withsm2(pk *sm2.PublicKey, data string, signature []byte) (bool, error) {
	if pk == nil {
		return false, ErrorParsePublicKey
	}
	return pk.Verify([]byte(data), signature), nil
}