tooling/secret-sync/main.go (212 lines of code) (raw):
// Copyright 2025 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"bufio"
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"fmt"
"io"
"log"
"os"
"strings"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys"
"github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azsecrets"
"github.com/Azure/ARO-HCP/tooling/templatize/pkg/azauth"
)
// chunkSizeBytes is the maximum number of plaintext bytes read into one chunk
// by readAndChunkData; each chunk is then encrypted separately with RSA-OAEP
// (SHA-256) by encryptData. RSA-OAEP limits a message to the modulus length
// minus 2*hashLen+2 bytes (66 bytes for SHA-256), so 400 is a conservative
// value that presumably targets a 4096-bit key (446-byte limit) — TODO confirm
// the key size used in practice.
var chunkSizeBytes = 400
// chunkDelemiter separates the base64-encoded encrypted chunks in the
// encrypted file. It is written after every chunk by persistEncryptedChunks
// and used as the split separator by readEncryptedChunks.
var chunkDelemiter = "\n"
// Names of the environment variables that configure the tool.
const (
// dryRunEnvKey: when the variable is "true", encrypt mode does not write the
// output file and decrypt mode does not set the secret.
dryRunEnvKey = "DRY_RUN"
// inputFileKeyEnv: path of the encrypted file read by readEncryptedChunks.
inputFileKeyEnv = "INPUT_FILE"
// outputFileEnvKey: path of the file created by persistEncryptedChunks.
outputFileEnvKey = "OUTPUT_FILE"
// publicKetFileEnvKey: path of the PEM encoded public key.
// NOTE(review): the identifier is misspelled ("Ket") and is not referenced in
// this file; encryptData reads the same variable via the literal
// "PUBLIC_KEY_FILE".
publicKetFileEnvKey = "PUBLIC_KEY_FILE"
// encryptionKeyEnvKey: name of the Key Vault key passed to the decrypt call.
encryptionKeyEnvKey = "ENCRYPTION_KEY"
// secretToSetEnvKey: name of the Key Vault secret read and set by persistSecret.
secretToSetEnvKey = "SECRET_TO_SET"
// vaultNameEnvKey: vault name used to build https://<name>.vault.azure.net.
vaultNameEnvKey = "KEYVAULT"
)
// readAndChunkData reads inputReader until EOF and returns its content as
// consecutive chunks of exactly chunkSizeBytes bytes; only the final chunk may
// be shorter. Concatenating the chunks yields the original input.
//
// An empty input yields an empty, non-nil slice. An error is returned when
// chunkSizeBytes is not positive or when reading fails for any reason other
// than reaching the end of the input.
func readAndChunkData(inputReader io.Reader) ([][]byte, error) {
	if chunkSizeBytes <= 0 {
		return nil, fmt.Errorf("invalid chunk size %d: must be positive", chunkSizeBytes)
	}

	chunks := make([][]byte, 0)
	reader := bufio.NewReader(inputReader)
	for {
		buf := make([]byte, chunkSizeBytes)
		// io.ReadFull keeps reading until the buffer is full. A single Read may
		// legitimately return fewer bytes (e.g. from a pipe), which would
		// fragment the input into many undersized chunks.
		n, err := io.ReadFull(reader, buf)
		if n > 0 {
			chunks = append(chunks, buf[:n])
		}
		switch {
		case err == nil:
			// Full chunk read, more data may follow.
		case err == io.EOF || err == io.ErrUnexpectedEOF:
			// io.EOF: nothing left; io.ErrUnexpectedEOF: short final chunk.
			return chunks, nil
		default:
			return nil, fmt.Errorf("problems reading from input: %w", err)
		}
	}
}
// persistEncryptedChunks writes encryptedChunks to the file named by the
// OUTPUT_FILE environment variable, creating or truncating it. Every chunk is
// written as standard base64 followed by chunkDelemiter, so the file contains
// one encoded chunk per delimiter-terminated record.
//
// It returns an error if the file cannot be created, written or closed.
func persistEncryptedChunks(encryptedChunks [][]byte) (err error) {
	outputFile, err := os.Create(os.Getenv(outputFileEnvKey))
	if err != nil {
		return fmt.Errorf("error creating output file %w", err)
	}
	// A failed Close on a written file can mean data never reached the disk,
	// so surface it unless an earlier error is already being returned.
	defer func() {
		if closeErr := outputFile.Close(); closeErr != nil && err == nil {
			err = fmt.Errorf("error closing output file %w", closeErr)
		}
	}()

	for _, c := range encryptedChunks {
		encodedChunk := make([]byte, base64.StdEncoding.EncodedLen(len(c)))
		base64.StdEncoding.Encode(encodedChunk, c)
		if _, err = outputFile.Write(encodedChunk); err != nil {
			return fmt.Errorf("error writing encoded chunk %w", err)
		}
		if _, err = outputFile.WriteString(chunkDelemiter); err != nil {
			return fmt.Errorf("error writing delimiter %w", err)
		}
	}
	return nil
}
// encryptData encrypts secretMessage with RSA-OAEP (SHA-256, empty label)
// using the public key stored in the file named by the PUBLIC_KEY_FILE
// environment variable. The file must contain a PEM block of type
// "PUBLIC KEY" holding a PKIX-encoded RSA public key.
//
// It returns the ciphertext, or an error if the key file cannot be read or
// parsed, if the key is not an RSA key, or if the encryption itself fails
// (for example because secretMessage is too long for the key size).
//
// NOTE(review): the key file is read and parsed on every call; callers that
// encrypt many chunks pay that cost per chunk.
func encryptData(secretMessage []byte) ([]byte, error) {
	keyFile := os.Getenv("PUBLIC_KEY_FILE")
	pubPEMData, err := os.ReadFile(keyFile)
	if err != nil {
		return nil, fmt.Errorf("error while reading public key file %s: %w", keyFile, err)
	}

	block, _ := pem.Decode(pubPEMData)
	if block == nil || block.Type != "PUBLIC KEY" {
		return nil, fmt.Errorf("failed to decode PEM block containing public key")
	}

	pub, err := x509.ParsePKIXPublicKey(block.Bytes)
	if err != nil {
		return nil, fmt.Errorf("error while parsing public key %w", err)
	}

	// ParsePKIXPublicKey also returns ECDSA, Ed25519, ... keys; an unchecked
	// type assertion would panic on those instead of reporting an error.
	rsaPub, ok := pub.(*rsa.PublicKey)
	if !ok {
		return nil, fmt.Errorf("public key in %s has type %T, expected an RSA public key", keyFile, pub)
	}

	return rsa.EncryptOAEP(sha256.New(), rand.Reader, rsaPub, secretMessage, nil)
}
// decryptData asks Key Vault to decrypt encryptedMessage with the key named
// by the ENCRYPTION_KEY environment variable, using the RSA-OAEP-256
// algorithm, and returns the resulting plaintext.
//
// The key version argument is left empty.
// NOTE(review): an empty version presumably selects the current key version —
// confirm against the azkeys documentation.
func decryptData(client *azkeys.Client, encryptedMessage []byte) ([]byte, error) {
	keyName := os.Getenv(encryptionKeyEnvKey)
	params := azkeys.KeyOperationParameters{
		Algorithm: to.Ptr(azkeys.EncryptionAlgorithmRSAOAEP256),
		Value:     encryptedMessage,
	}

	resp, err := client.Decrypt(context.Background(), keyName, "", params, &azkeys.DecryptOptions{})
	if err != nil {
		return nil, fmt.Errorf("error decoding secret %v", err)
	}
	return resp.Result, nil
}
// persistSecret makes sure the Key Vault secret named by the SECRET_TO_SET
// environment variable holds the given value.
//
// The current value is fetched first; a lookup error whose text contains
// "SecretNotFound" is treated as "no value yet", any other lookup error is
// returned. When the stored value already equals secret nothing is written.
// Otherwise the secret is set, unless DRY_RUN is "true", in which case the
// write is skipped. Progress is reported on standard output.
func persistSecret(client *azsecrets.Client, secret []byte) error {
	name := os.Getenv(secretToSetEnvKey)
	desired := string(secret)

	existing, err := client.GetSecret(context.Background(), name, "", &azsecrets.GetSecretOptions{})
	if err != nil && !strings.Contains(err.Error(), "SecretNotFound") {
		return fmt.Errorf("error getting secret %v", err)
	}

	if existing.Value != nil && *existing.Value == desired {
		fmt.Println("Secret up to date")
		return nil
	}

	fmt.Println("Secret needs update")
	if os.Getenv(dryRunEnvKey) == "true" {
		fmt.Println("Skipped due to dry run")
		return nil
	}

	params := azsecrets.SetSecretParameters{Value: to.Ptr(desired)}
	if _, err := client.SetSecret(context.Background(), name, params, nil); err != nil {
		return fmt.Errorf("error setting secret %v", err)
	}
	return nil
}
// readEncryptedChunks loads the file named by the INPUT_FILE environment
// variable and splits its content on chunkDelemiter. The pieces are returned
// exactly as stored (still base64 encoded); a file ending with the delimiter
// therefore yields a trailing empty piece.
func readEncryptedChunks() ([][]byte, error) {
	inputPath := os.Getenv(inputFileKeyEnv)

	raw, readErr := os.ReadFile(inputPath)
	if readErr != nil {
		return nil, fmt.Errorf("error reading input file %v", readErr)
	}

	separator := []byte(chunkDelemiter)
	pieces := bytes.Split(raw, separator)
	return pieces, nil
}
// main runs the tool in the mode given as the single command line argument:
//
//   - "encrypt": read plaintext from standard input, encrypt it chunk by chunk
//     with the public key and write the result to OUTPUT_FILE.
//   - "decrypt": read INPUT_FILE, decrypt every chunk with the Key Vault key
//     and store the joined plaintext as the Key Vault secret SECRET_TO_SET.
//
// Any failure is logged and terminates the process with a non-zero status.
func main() {
	if len(os.Args) != 2 {
		log.Fatal("Need to provide mode parameter encrypt/decrypt")
	}

	var err error
	switch mode := os.Args[1]; mode {
	case "encrypt":
		err = runEncrypt()
	case "decrypt":
		err = runDecrypt()
	default:
		log.Fatalf("Invalid mode %s", mode)
	}
	if err != nil {
		log.Fatal(err)
	}
}

// runEncrypt implements the "encrypt" mode: it chunks standard input, encrypts
// every chunk and persists the result unless DRY_RUN is "true".
func runEncrypt() error {
	plainChunks, err := readAndChunkData(os.Stdin)
	if err != nil {
		return err
	}

	encryptedChunks := make([][]byte, 0, len(plainChunks))
	for _, c := range plainChunks {
		encryptedChunk, err := encryptData(c)
		if err != nil {
			return err
		}
		encryptedChunks = append(encryptedChunks, encryptedChunk)
	}

	fmt.Printf("Encrypted data, persisting to: %s\n", os.Getenv(outputFileEnvKey))
	if os.Getenv(dryRunEnvKey) == "true" {
		fmt.Println("... skiped due to dry run")
		return nil
	}
	return persistEncryptedChunks(encryptedChunks)
}

// runDecrypt implements the "decrypt" mode: it decrypts every chunk of the
// input file through Key Vault and hands the joined plaintext to persistSecret.
func runDecrypt() error {
	chain, err := azauth.GetAzureTokenCredentials()
	if err != nil {
		return fmt.Errorf("error getting credentials %w", err)
	}

	vaultURL := fmt.Sprintf("https://%s.vault.azure.net", os.Getenv(vaultNameEnvKey))
	keyClient, err := azkeys.NewClient(vaultURL, chain, nil)
	if err != nil {
		return fmt.Errorf("error getting azkeys client %w", err)
	}

	encryptedChunks, err := readEncryptedChunks()
	if err != nil {
		return err
	}

	decryptedChunks := make([][]byte, 0, len(encryptedChunks))
	for i, c := range encryptedChunks {
		// Skip the empty piece after the trailing delimiter as well as
		// whitespace-only lines (e.g. a stray "\r" from CRLF line endings).
		c = bytes.TrimSpace(c)
		if len(c) == 0 {
			continue
		}
		ciphertext, err := decodeChunk(c)
		if err != nil {
			return fmt.Errorf("error decoding chunk %d: %w", i, err)
		}
		decryptedChunk, err := decryptData(keyClient, ciphertext)
		if err != nil {
			return err
		}
		decryptedChunks = append(decryptedChunks, decryptedChunk)
	}

	secretsClient, err := azsecrets.NewClient(vaultURL, chain, nil)
	if err != nil {
		return fmt.Errorf("error getting azsecrets client %w", err)
	}

	joinedMessage := bytes.Join(decryptedChunks, nil)
	fmt.Printf("Data decrypted, persisting to: %s\n", os.Getenv(secretToSetEnvKey))
	return persistSecret(secretsClient, joinedMessage)
}

// decodeChunk base64-decodes one encoded chunk and returns exactly the decoded
// bytes. DecodedLen only reports the maximum decoded size: for padded input
// the real length is up to two bytes shorter, so the buffer must be truncated
// to the count returned by Decode, otherwise trailing zero bytes end up in the
// ciphertext.
func decodeChunk(encoded []byte) ([]byte, error) {
	decoded := make([]byte, base64.StdEncoding.DecodedLen(len(encoded)))
	n, err := base64.StdEncoding.Decode(decoded, encoded)
	if err != nil {
		return nil, err
	}
	return decoded[:n], nil
}