apt/method.go (217 lines of code) (raw):

// Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package apt import ( "bufio" "context" "crypto/md5" "errors" "fmt" "io" "net/http" "net/http/httputil" "os" "strconv" "strings" "golang.org/x/oauth2" "golang.org/x/oauth2/google" ) const ( cloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform" ) // NewAptMethod returns an AptMethod. func NewAptMethod(input *bufio.Reader, output io.Writer) *Method { return &Method{ config: &aptMethodConfig{}, writer: NewAptMessageWriter(output), reader: NewAptMessageReader(input), dl: downloaderImpl{}, } } // httpClient exists to enable mocking of http.Client. type httpClient interface { Do(req *http.Request) (*http.Response, error) } // downloader exists to enable mocking of AptMethod.download. type downloader interface { download(io.ReadCloser, string) (string, error) } type downloaderImpl struct{} // Method represents the method handler. type Method struct { reader *MessageReader writer *MessageWriter config *aptMethodConfig client httpClient dl downloader } type aptMethodConfig struct { serviceAccountJSON, serviceAccountEmail string debug bool } // Run runs the method. func (m *Method) Run(ctx context.Context) error { m.writer.SendCapabilities() for { select { case <-ctx.Done(): return nil default: } msg, err := m.reader.ReadMessage(ctx) if errors.Is(err, errEmptyMessage) { continue } else if errors.Is(err, io.EOF) { return nil } else if err != nil { return err } switch msg.code { case 600: m.handleAcquire(ctx, msg) case 601: m.handleConfigure(msg) default: // TODO(hopkiw): now write a test for this. m.writer.Fail(fmt.Sprintf("Unsupported message code %d received from apt", msg.code)) } } } func (m *Method) initClient(ctx context.Context) error { if m.client != nil { return nil } var ts oauth2.TokenSource switch { case m.config.serviceAccountJSON != "": json, err := os.ReadFile(m.config.serviceAccountJSON) if err != nil { return fmt.Errorf("failed to read service account JSON file: %v", err) } creds, err := google.CredentialsFromJSON(ctx, json, cloudPlatformScope) if err != nil { return fmt.Errorf("failed to obtain creds from service account JSON: %v", err) } ts = creds.TokenSource case m.config.serviceAccountEmail != "": ts = google.ComputeTokenSource(m.config.serviceAccountEmail) default: creds, err := google.FindDefaultCredentials(ctx, cloudPlatformScope) if err != nil { return fmt.Errorf("failed to obtain default creds: %v", err) } ts = creds.TokenSource } if ts == nil { return errors.New("failed to obtain creds") } m.client = oauth2.NewClient(ctx, ts) return nil } // download performs the actual downloading to target file and returns // an MD5 hash of the downloaded file. func (r downloaderImpl) download(body io.ReadCloser, filename string) (string, error) { defer body.Close() data, err := io.ReadAll(body) if err != nil { return "", err } file, err := os.Create(filename) if err != nil { return "", err } defer file.Close() _, err = file.Write(data) return fmt.Sprintf("%x", md5.Sum(data)), err } func (m *Method) handleAcquire(ctx context.Context, msg *Message) error { uri := msg.Get("URI") if uri == "" { err := errors.New("no URI provided in Acquire message") m.writer.Fail(err.Error()) return err } filename := msg.Get("Filename") if filename == "" { err := errors.New("no filename provided in Acquire message") m.writer.FailURI(uri, err.Error()) return err } ifModifiedSince := msg.Get("Last-Modified") if err := m.initClient(ctx); err != nil { m.writer.FailURI(uri, err.Error()) return err } realuri := strings.Replace(uri, "ar+https", "https", 1) req, err := http.NewRequest("GET", realuri, nil) if err != nil { return err } if ifModifiedSince != "" { // TODO(hopkiw): validate this string is in RFC1123Z format. req.Header.Add("If-Modified-Since", ifModifiedSince) } if m.config.debug { if reqDump, dumpErr := httputil.DumpRequest(req, true); dumpErr == nil { m.writer.Log(string(reqDump)) } } resp, err := m.client.Do(req) if m.config.debug && resp != nil { if respDump, dumpErr := httputil.DumpResponse(resp, false); dumpErr == nil { m.writer.Log(string(respDump)) } } if err != nil { m.writer.FailURI(uri, err.Error()) return err } size := resp.Header.Get("Content-Length") lastModified := resp.Header.Get("Last-Modified") switch resp.StatusCode { case 200: // It's weird to send URI Start after we've already contacted // the server, but we need to know the size. m.writer.URIStart(uri, size, lastModified) md5Hash, err := m.dl.download(resp.Body, filename) if err != nil { m.writer.FailURI(uri, err.Error()) return err } m.writer.URIDone(uri, size, lastModified, md5Hash, filename, false) case 304: // Unchanged since Last-Modified. Respond with "IMS-Hit: true" to // indicate the existing file is valid. m.writer.URIDone(uri, size, lastModified, "", filename, true) default: // All other codes including 404, 403, etc. err := fmt.Errorf("error downloading: code %v", resp.StatusCode) m.writer.FailURI(uri, err.Error()) return err } return nil } // Ported from apt's `StringToBool` function // https://salsa.debian.org/apt-team/apt/-/blob/a0a76c2e20c1ddefd76a4a539a9350b96d66006e/apt-pkg/contrib/strutl.cc#L824 func stringToBool(s string) bool { if i, err := strconv.Atoi(s); err == nil { if i == 1 { return true } return false } sl := strings.ToLower(s) trueStrs := []string{"yes", "true", "with", "on", "enable"} for _, trueStr := range trueStrs { if sl == trueStr { return true } } return false } func (m *Method) handleConfigure(msg *Message) { configs, ok := msg.fields["Config-Item"] if !ok { // Nothing to set. return } for _, configItem := range configs { parts := strings.SplitN(configItem, "=", 2) if len(parts) != 2 { m.writer.Log(fmt.Sprintf("malformed config item: %v", configItem)) return } switch parts[0] { case "Acquire::gar::Service-Account-JSON": m.config.serviceAccountJSON = strings.TrimSpace(parts[1]) case "Acquire::gar::Service-Account-Email": m.config.serviceAccountEmail = strings.TrimSpace(parts[1]) case "Debug::Acquire::gar": m.config.debug = stringToBool(strings.TrimSpace(parts[1])) } } // Enforce the precedence of these two options. if m.config.serviceAccountJSON != "" { m.config.serviceAccountEmail = "" } }