internal/files/files.go (142 lines of code) (raw):
package files
import (
"fmt"
"io/ioutil"
"net/url"
"path/filepath"
"strings"
"os"
"github.com/Azure/run-command-handler-linux/internal/handlersettings"
"github.com/Azure/run-command-handler-linux/pkg/download"
"github.com/Azure/run-command-handler-linux/pkg/preprocess"
"github.com/Azure/run-command-handler-linux/pkg/urlutil"
"github.com/go-kit/kit/log"
"github.com/pkg/errors"
)
// UseMockSASDownloadFailure, when true, makes downloadAndProcessURL skip the
// real SAS blob download and record a failure instead, so that the fallback
// download path is taken.
// NOTE(review): this looks like a test-only hook — confirm against callers.
var UseMockSASDownloadFailure bool = false
// DownloadAndProcessArtifact downloads the given artifact into downloadDir and
// post-processes it. The artifact's FileName is used as the local file name;
// when it is empty, the name falls back to "Artifact" followed by the
// artifact's ArtifactId. It returns the path of the downloaded file.
func DownloadAndProcessArtifact(ctx *log.Context, downloadDir string, artifact *handlersettings.UnifiedArtifact) (string, error) {
	localName := artifact.FileName
	if localName == "" {
		// No explicit name was supplied: derive one from the artifact id.
		localName = fmt.Sprintf("%s%d", "Artifact", artifact.ArtifactId)
	}

	return downloadAndProcessURL(
		ctx,
		artifact.ArtifactUri,
		downloadDir,
		localName,
		artifact.ArtifactSasToken,
		artifact.ArtifactManagedIdentity,
	)
}
// DownloadAndProcessScript downloads the script at url into downloadDir and
// post-processes it. The local file name is the last path segment of url (see
// UrlToFileName); the SAS token and managed identity are taken from cfg. It
// returns the path of the downloaded file.
func DownloadAndProcessScript(ctx *log.Context, url, downloadDir string, cfg *handlersettings.HandlerSettings) (string, error) {
	localName, nameErr := UrlToFileName(url)
	if nameErr != nil {
		return "", nameErr
	}

	sasToken := cfg.ScriptSAS()
	identity := cfg.SourceManagedIdentity

	return downloadAndProcessURL(ctx, url, downloadDir, localName, sasToken, identity)
}
// downloadAndProcessURL downloads the file at url into the existing directory
// downloadDir under the name fileName, then post-processes it based on
// heuristics (see PostProcessFile).
//
// When scriptSAS is non-empty, the download is first attempted with the SAS
// token. If no SAS token was given, or the SAS download failed, the file is
// fetched with the downloaders returned by getDownloaders for url and
// sourceManagedIdentity.
//
// It returns the path of the downloaded file, or an error if the URL is
// invalid, no download succeeded, or post-processing failed.
func downloadAndProcessURL(ctx *log.Context, url, downloadDir string, fileName string, scriptSAS string, sourceManagedIdentity *handlersettings.RunCommandManagedIdentity) (string, error) {
	if !urlutil.IsValidUrl(url) {
		// url does not contain the SAS token, so we can log it. It is passed as an
		// argument instead of being used as the format string: a '%' inside the
		// URL would otherwise be interpreted as a formatting verb and garble the
		// message.
		return "", fmt.Errorf("%s is not a valid url", url)
	}

	targetFilePath := filepath.Join(downloadDir, fileName)

	var scriptSASDownloadErr error
	if scriptSAS != "" {
		var downloadedFilePath string
		if UseMockSASDownloadFailure {
			scriptSASDownloadErr = errors.New("Downloading script using SAS token failed.")
		} else {
			downloadedFilePath, scriptSASDownloadErr = download.GetSASBlob(url, scriptSAS, downloadDir)
		}

		// Download was successful using SAS, so use the path it reported.
		if scriptSASDownloadErr == nil && downloadedFilePath != "" {
			targetFilePath = downloadedFilePath
		}
	}

	// If there was an error downloading using the SAS URI, or no SAS was
	// provided, download using managed identity or publicly.
	if scriptSASDownloadErr != nil || scriptSAS == "" {
		downloaders, err := getDownloaders(url, sourceManagedIdentity, download.ProdMsiDownloader{})
		if err != nil {
			return "", err
		}

		const mode = 0500 // we assume users download scripts to execute
		if _, err := download.SaveTo(ctx, downloaders, targetFilePath, mode); err != nil {
			return "", err
		}
	}

	if err := PostProcessFile(targetFilePath); err != nil {
		return "", errors.Wrapf(err, "failed to post-process '%s'", fileName)
	}
	return targetFilePath, nil
}
// getDownloaders builds the ordered list of downloaders to try for fileURL.
//
// For an Azure storage blob URI whose MSI provider can be invoked without
// error, two downloaders are returned:
//  1. Downloader for the blob using managed identity.
//  2. Downloader for the blob using the public URI.
//
// In every other case a single public-URI downloader is returned. An error is
// returned when fileURL is empty or when managedIdentity sets both ClientId
// and ObjectId.
func getDownloaders(fileURL string, managedIdentity *handlersettings.RunCommandManagedIdentity, msiDownloader download.MsiDownloader) ([]download.Downloader, error) {
	if fileURL == "" {
		return nil, fmt.Errorf("fileURL is empty")
	}

	// Public URI - do not use MSI downloader if the uri is not azure storage blob
	if !download.IsAzureStorageBlobUri(fileURL) {
		return []download.Downloader{download.NewURLDownload(fileURL)}, nil
	}

	// If a managed identity was specified in the configuration, try to use it
	// to download the files.
	hasClientID := managedIdentity != nil && managedIdentity.ClientId != ""
	hasObjectID := managedIdentity != nil && managedIdentity.ObjectId != ""

	var provider download.MsiProvider
	switch {
	case hasClientID && hasObjectID:
		return nil, fmt.Errorf("use either ClientId or ObjectId for managed identity. Not both")
	case hasClientID:
		// uses user-managed identity
		provider = msiDownloader.GetMsiProviderByClientId(fileURL, managedIdentity.ClientId)
	case hasObjectID:
		// uses user-managed identity
		provider = msiDownloader.GetMsiProviderByObjectId(fileURL, managedIdentity.ObjectId)
	default:
		// get msi Provider for blob url implicitly (uses system managed identity)
		provider = msiDownloader.GetMsiProvider(fileURL)
	}

	// Invoke the provider once; if it fails, MSI is skipped entirely.
	if _, probeErr := provider(); probeErr != nil {
		// Try downloading the Azure storage blob as public URI
		return []download.Downloader{download.NewURLDownload(fileURL)}, nil
	}

	// Try downloading with MSI first; if that fails, attempt public download.
	return []download.Downloader{
		download.NewBlobWithMsiDownload(fileURL, provider),
		download.NewURLDownload(fileURL),
	}, nil
}
// UrlToFileName parses fileURL and returns the part of its path that follows
// the last slash, for use as a file name. An error is returned when the URL
// cannot be parsed or when that part is empty (for example, when the path is
// empty or ends with a slash).
func UrlToFileName(fileURL string) (string, error) {
	parsed, err := url.Parse(fileURL)
	if err != nil {
		return "", errors.Wrapf(err, "unable to parse URL: %q", fileURL)
	}

	// LastIndex yields -1 when the path has no slash, so the slice then starts
	// at 0 and covers the whole path.
	name := parsed.Path[strings.LastIndex(parsed.Path, "/")+1:]
	if name == "" {
		return "", fmt.Errorf("cannot extract file name from URL: %q", fileURL)
	}
	return name, nil
}
// PostProcessFile determines if path is a script (text) file based on
// heuristics and, if so, rewrites it in place with the BOM removed and DOS
// line endings converted, to make the script POSIX-friendly. Files that are
// not detected as text are left untouched.
//
// It returns an error if the file type cannot be determined, or if the file
// cannot be read or written back.
func PostProcessFile(path string) error {
	ok, err := preprocess.IsTextFile(path)
	if err != nil {
		return errors.Wrapf(err, "error determining if script is a text file")
	}
	if !ok {
		return nil
	}

	b, err := ioutil.ReadFile(path) // read the file into memory for processing
	if err != nil {
		return errors.Wrapf(err, "error reading file")
	}
	b = preprocess.RemoveBOM(b)
	b = preprocess.Dos2Unix(b)

	// The file already exists (it was just read), so WriteFile truncates it
	// without changing its permissions; the perm argument only applies when the
	// file has to be created. The write error must be reported: previously it
	// was discarded and the result of a self-rename (a no-op) was returned
	// instead, so failed writes looked like success.
	if err := ioutil.WriteFile(path, b, 0); err != nil {
		return errors.Wrap(err, "error writing file")
	}
	return nil
}
// SaveScriptFile writes content to filePath, creating the file if it does not
// exist and truncating it if it does. Newly created files get mode 0500.
//
// It returns an error if the file cannot be opened, written, or closed.
func SaveScriptFile(filePath string, content string) error {
	const mode = 0500 // scripts should have execute permissions
	file, err := os.OpenFile(filePath, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, mode)
	if err != nil {
		return errors.Wrap(err, "failed to open file for writing: "+filePath)
	}

	_, writeErr := file.WriteString(content)
	// Always close the file, and keep the close error: for a file opened for
	// writing it can be the only place a failed write is reported.
	closeErr := file.Close()

	if writeErr != nil {
		return errors.Wrap(writeErr, "failed to write to the file: "+filePath)
	}
	// errors.Wrap returns nil when closeErr is nil.
	return errors.Wrap(closeErr, "failed to close the file: "+filePath)
}