capi/capi.go (579 lines of code) (raw):

/* This Source Code Form is subject to the terms of the Mozilla Public * License, v. 2.0. If a copy of the MPL was not distributed with this * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ package main // import "github.com/mozilla/CCADB-Tools/capi" import ( "bufio" "crypto/x509" "encoding/json" "encoding/pem" "fmt" "github.com/mozilla/CCADB-Tools/capi/lib/ccadb" "github.com/mozilla/CCADB-Tools/capi/lib/certificateUtils" "github.com/mozilla/CCADB-Tools/capi/lib/lint/certlint" "github.com/mozilla/CCADB-Tools/capi/lib/lint/x509lint" "github.com/mozilla/CCADB-Tools/capi/lib/model" "github.com/mozilla/CCADB-Tools/capi/lib/service" "github.com/natefinch/lumberjack" log "github.com/sirupsen/logrus" "github.com/throttled/throttled" "github.com/throttled/throttled/store/memstore" "golang.org/x/crypto/ssh/terminal" "io" "io/ioutil" "net" "net/http" "net/http/httputil" "net/url" "os" "path" "strconv" "strings" "sync" ) func main() { InitLogging() store, err := memstore.New(65536) if err != nil { log.Fatal(err) } // 100 per minute, with a burst of 6. quota := throttled.RateQuota{MaxRate: throttled.PerMin(500), MaxBurst: 24} rateLimiter, err := throttled.NewGCRARateLimiter(store, quota) if err != nil { log.Fatal(err) } httpRateLimiter := throttled.HTTPRateLimiter{ RateLimiter: rateLimiter, VaryBy: &throttled.VaryBy{Path: true}, } verifyLimiter := httpRateLimiter.RateLimit(http.HandlerFunc(verify)) verifyCCADBLimiter := httpRateLimiter.RateLimit(http.HandlerFunc(verifyFromCCADB)) verifyFromCertificateDetailsLimiter := httpRateLimiter.RateLimit(http.HandlerFunc(verifyFromCertificateDetails)) lintCCADBLimiter := httpRateLimiter.RateLimit(http.HandlerFunc(lintFromCCADB)) lintFromCertificateDetailsLimiter := httpRateLimiter.RateLimit(http.HandlerFunc(lintFromCertificateDetails)) lintFromSubjectLimiter := httpRateLimiter.RateLimit(http.HandlerFunc(lintFromSubject)) http.Handle("/", verifyLimiter) http.Handle("/fromreport", verifyCCADBLimiter) http.Handle("/fromCertificateDetails", verifyFromCertificateDetailsLimiter) http.Handle("/lintFromReport", lintCCADBLimiter) http.Handle("/lintFromCertificateDetails", lintFromCertificateDetailsLimiter) http.Handle("/lintFromSubject", lintFromSubjectLimiter) port := Port() addr := BindingAddress() log.WithFields(log.Fields{"Binding Address": addr, "Port": port}).Info("Starting server") if err := http.ListenAndServe(addr+":"+port, nil); err != nil { log.Panicln(err) } } // The flow for verify is that, the moment that the value for desired response code and response body is known, // that those variables be set and that the function return immediately. A deferred closure then reads these values // an provides a single point of responding back to the client. func verify(resp http.ResponseWriter, req *http.Request) { var response string var responseCode = http.StatusOK defer func() { if err := recover(); err != nil { responseCode = http.StatusBadGateway response = fmt.Sprintf("a fatal error has occured\n%s", err) } switch responseCode { case http.StatusBadGateway: log.Fatal(string(response)) case http.StatusBadRequest: log.Error(responseCode) } resp.WriteHeader(responseCode) _, err := fmt.Fprintln(resp, string(response)) if err != nil { // Oh my, perhaps the client hung up. log.WithField("response", string(response)). WithError(err). Fatal("failed to respond to the remote client") // This may or may not prove to be useful. // Leave it on debug because this can be incredibly noisy. dump, err := httputil.DumpRequest(req, false) switch err == nil { case true: log.WithField("wireRepresentation", dump).Debug() default: log.WithError(err).Fatal() } } }() dump, err := httputil.DumpRequest(req, false) if err != nil { responseCode = http.StatusBadGateway response = "a fatal internal error occurred, " + err.Error() return } log.WithField("Request", string(dump)).Info("Received request") log.Info(req.URL.RawQuery) query, err := url.ParseQuery(req.URL.RawQuery) log.Info(req.ParseForm()) if err != nil { responseCode = http.StatusBadRequest response = "malformed query string, " + err.Error() return } s, ok := query["subject"] if !ok { responseCode = http.StatusBadRequest response = "'subject' query parameter is required" return } if len(s) == 0 { responseCode = http.StatusBadRequest response = "'subject' query parameter may not be empty" return } subject := s[0] rawRoot, err := ioutil.ReadAll(req.Body) if err != nil { responseCode = http.StatusBadRequest response = "failed to read request body, " + err.Error() return } e, ok := query["expect"] interpretation := service.None log.Info(e) if ok { if len(e) == 0 { responseCode = http.StatusBadRequest response = "'expect' query parameter may not be empty" return } switch strings.ToLower(e[0]) { case "valid": interpretation = service.Valid case "expired": interpretation = service.Expired case "revoked": interpretation = service.Revoked } } log.Info("Expectation is " + strconv.Itoa(int(interpretation))) if err := req.Body.Close(); err != nil { responseCode = http.StatusBadGateway response = "failed to close the request body, " + err.Error() return } if len(rawRoot) == 0 { responseCode = http.StatusBadRequest response = "The PEM of the provided trust anchor cannot be empty." } rootPEM, err := certificateUtils.NormalizePEM(rawRoot) if err != nil { responseCode = http.StatusBadRequest response = "failed to format the provided PEM" return } log.Info(string(rootPEM)) block, _ := pem.Decode(rootPEM) if block == nil { responseCode = http.StatusBadRequest response = "failed to decode the provided PEM" return } root, err := x509.ParseCertificate(block.Bytes) if err != nil { responseCode = http.StatusBadRequest response = "Bad root PEM, " + err.Error() return } result := test(subject, root, interpretation) switch r, err := json.MarshalIndent(result, "", " "); err != nil { case true: responseCode = http.StatusBadGateway response = "a fatal error occurred when serializing the response, " + err.Error() case false: response = string(r) } } func verifyFromCCADB(resp http.ResponseWriter, _ *http.Request) { defer func() { if err := recover(); err != nil { log.Error(err) } }() report, err := ccadb.NewReport() if err != nil { resp.WriteHeader(500) resp.Write([]byte(err.Error())) return } ret := make(chan model.TestWebsiteResult, 30) work := make(chan ccadb.Record, len(report.Records)) for _, record := range report.Records { work <- record } close(work) wg := sync.WaitGroup{} for i := 0; i < 10; i++ { wg.Add(1) go func() { defer wg.Done() for record := range work { root := record.Root() ret <- test(record.TestWebsiteValid(), root, service.Valid) ret <- test(record.TestWebsiteExpired(), root, service.Expired) ret <- test(record.TestWebsiteRevoked(), root, service.Revoked) } }() } go func() { wg.Wait() close(ret) }() resp.Write([]byte{'['}) jsonResp := json.NewEncoder(resp) jsonResp.SetIndent("", " ") i := 0 for answer := range ret { i++ jsonResp.Encode(answer) if i < len(report.Records)*3 { resp.Write([]byte{','}) } if flusher, ok := resp.(http.Flusher); ok { flusher.Flush() } } resp.Write([]byte{']'}) } func streamJsonArray(w io.Writer, answers chan model.TestWebsiteResult, total int) { w.Write([]byte{'['}) jsonResp := json.NewEncoder(w) jsonResp.SetIndent("", " ") i := 0 for answer := range answers { i++ jsonResp.Encode(answer) if i < total { w.Write([]byte{','}) } if flusher, ok := w.(http.Flusher); ok { flusher.Flush() } } w.Write([]byte{']'}) } func verifyFromCertificateDetails(resp http.ResponseWriter, req *http.Request) { body, err := ioutil.ReadAll(req.Body) if err != nil { //@TODO } var records model.CCADBRecords err = json.Unmarshal(body, &records) if err != nil { //@TODO } answers := make(chan model.TestWebsiteResult, len(records.CertificateDetails)) work := make(chan model.CCADBRecord, len(records.CertificateDetails)) for _, record := range records.CertificateDetails { work <- record } close(work) wg := sync.WaitGroup{} for i := 0; i < 10; i++ { wg.Add(1) go func() { defer wg.Done() for record := range work { root := record.PEM answers <- test(record.TestWebsiteValid, root, service.Valid).SetRecordID(record.RecordID) answers <- test(record.TestWebsiteExpired, root, service.Expired).SetRecordID(record.RecordID) answers <- test(record.TestWebsiteRevoked, root, service.Revoked).SetRecordID(record.RecordID) } }() } go func() { wg.Wait() close(answers) }() streamJsonArray(resp, answers, len(records.CertificateDetails)*3) } func test(subject string, root *x509.Certificate, expectation service.Expectation) model.TestWebsiteResult { result := model.NewTestWebsiteResult(subject, expectation.String()) if subject == "" { return result } // Reach out to the test website on a plain GET and extract the certificate chain from the request. chain, err := certificateUtils.GatherCertificateChain(subject) if err != nil { // Leave this as a 200 as the remote CA test website not responding // is a perfectly valid piece of information to report. result.Error = err.Error() result.Opinion.Result = model.FAIL result.Opinion.Errors = append(result.Opinion.Errors, model.Concern{ Raw: err.Error(), Interpretation: "The subject test website failed to respond within 10 seconds.", Advise: "Please check that " + subject + " is up and responding in a reasonable time.", }) return result } // The test website may include a trust anchor. If it does, then swap it out with // the one our client wants to use, if not just tack our client's trust anchor onto the end. chain = certificateUtils.EmplaceRoot(chain, root) // And, finally, fill out chain verification information. result.Chain = service.VerifyChain(chain) service.InterpretResult(&result, expectation) return result } func lintFromCCADB(resp http.ResponseWriter, _ *http.Request) { defer func() { if err := recover(); err != nil { log.Error(err) } }() report, err := ccadb.NewReport() if err != nil { resp.WriteHeader(500) resp.Write([]byte(err.Error())) return } ret := make(chan model.ChainLintResult, 30) work := make(chan ccadb.Record, len(report.Records)) for _, record := range report.Records { work <- record } close(work) wg := sync.WaitGroup{} for i := 0; i < 10; i++ { wg.Add(1) go func() { defer wg.Done() for record := range work { ret <- lintSubject(record.TestWebsiteValid()) ret <- lintSubject(record.TestWebsiteExpired()) ret <- lintSubject(record.TestWebsiteRevoked()) } }() } go func() { wg.Wait() close(ret) }() resp.Write([]byte{'['}) jsonResp := json.NewEncoder(resp) jsonResp.SetIndent("", " ") i := 0 for answer := range ret { i++ jsonResp.Encode(answer) if i < len(report.Records)*3 { resp.Write([]byte{','}) } if flusher, ok := resp.(http.Flusher); ok { flusher.Flush() } } resp.Write([]byte{']'}) } func lintFromCertificateDetails(resp http.ResponseWriter, req *http.Request) { body, err := ioutil.ReadAll(req.Body) if err != nil { //@TODO } var records model.CCADBRecords err = json.Unmarshal(body, &records) if err != nil { //@TODO } answers := make(chan model.ChainLintResult, len(records.CertificateDetails)) work := make(chan model.CCADBRecord, len(records.CertificateDetails)) for _, record := range records.CertificateDetails { work <- record } close(work) wg := sync.WaitGroup{} for i := 0; i < 10; i++ { wg.Add(1) go func() { defer wg.Done() for record := range work { answers <- lintSubject(record.TestWebsiteValid) answers <- lintSubject(record.TestWebsiteExpired) answers <- lintSubject(record.TestWebsiteRevoked) } }() } go func() { wg.Wait() close(answers) }() total := len(records.CertificateDetails) * 3 w := bufio.NewWriter(resp) w.Write([]byte{'['}) jsonResp := json.NewEncoder(w) jsonResp.SetIndent("", " ") i := 0 for answer := range answers { i++ jsonResp.Encode(answer) if i < total { w.Write([]byte{','}) } } w.Write([]byte{']'}) w.Flush() } func lintFromSubject(w http.ResponseWriter, req *http.Request) { query, err := url.ParseQuery(req.URL.RawQuery) if err != nil { w.WriteHeader(http.StatusBadRequest) w.Write([]byte("malformed query string, " + err.Error())) return } s, ok := query["subject"] if !ok { w.WriteHeader(http.StatusBadRequest) w.Write([]byte("'subject' query parameter is required")) return } if len(s) == 0 { w.WriteHeader(http.StatusBadRequest) w.Write([]byte("'subject' query parameter may not be empty")) return } result := lintSubject(s[0]) encoder := json.NewEncoder(w) encoder.SetIndent("", " ") encoder.Encode(result) w.WriteHeader(http.StatusOK) } func lintSubject(subject string) model.ChainLintResult { result := model.NewChainLintResult(subject) if subject == "" { return result } chain, err := certificateUtils.GatherCertificateChain(subject) if err != nil { result.Error = err.Error() result.Opinion.Result = model.FAIL result.Opinion.Errors = append(result.Opinion.Errors, model.Concern{ Raw: err.Error(), Interpretation: "The subject test website failed to respond within 10 seconds.", Advise: "Please check that " + subject + " is up and responding in a reasonable time.", }) return result } if len(chain) <= 1 { result.Error = fmt.Sprintf("certificate chain contains %d certificates", len(chain)) result.Opinion.Result = model.FAIL result.Opinion.Errors = append(result.Opinion.Errors, model.Concern{ Raw: result.Error, Interpretation: "The subject test website failed to provide a certificate chain with at least two certificates.", Advise: "Please check that " + subject + " is up and responding on an HTTPS endpoint and is not using a trust anchor as the sole certificate.", }) return result } chainWithoutRoot := chain[:len(chain)-1] clint, err := certlint.LintCerts(chainWithoutRoot) if err != nil { result.Error = err.Error() result.Opinion.Result = model.FAIL result.Opinion.Errors = append(result.Opinion.Errors, model.Concern{ Raw: err.Error(), Interpretation: "An internal error appears to have occurred while using certlint", Advise: "Please report this error.", }) return result } xlint, err := x509lint.LintChain(chainWithoutRoot) if err != nil { result.Error = err.Error() result.Opinion.Result = model.FAIL result.Opinion.Errors = append(result.Opinion.Errors, model.Concern{ Raw: err.Error(), Interpretation: "An internal error appears to have occurred while using x509lint", Advise: "Please report this error.", }) return result } lintResults := make([]model.CertificateLintResult, len(chainWithoutRoot)) for i := 0; i < len(lintResults); i++ { lintResults[i] = model.NewCertificateLintResult(chainWithoutRoot[i], xlint[i], clint[i]) } result.Finalize(lintResults[0], lintResults[1:]) return result } func Home() string { switch home := os.Getenv("CAPI_HOME"); home { case "": return "." default: return home } } func Port() string { return fmt.Sprintf("%d", parseIntFromEnvOrDie("PORT", 8080)) } func BindingAddress() string { switch addr := os.Getenv("ADDR"); addr { case "": return "0.0.0.0" default: _, _, err := net.ParseCIDR(addr) if err != nil { log.WithField("ADDR", addr). WithError(err). Error("failed to parse the provided ADDR to a valid CIDR") os.Exit(1) } return addr } } func LogFile() string { switch env := os.Getenv("LOG_DIR"); env { case "": return path.Join(Home(), "/logs/capi.log") default: return path.Join(env, "capi.log") } } func LogLevel() log.Level { switch lvl := os.Getenv("LOGLEVEL"); lvl { case "": return log.InfoLevel default: level, err := log.ParseLevel(lvl) if err != nil { // This is nipped straight from log.ParseLevel as // I don't see constants to refer to. If the version of logrus // included is ever bumped then this can migrate to being wrong. fmt.Printf("%s is not a valid logging level.\n", lvl) fmt.Println("Valid log levels are:") fmt.Println("> panic") fmt.Println("> fatal") fmt.Println("> error") fmt.Println("> warn OR warning") fmt.Println("> info") fmt.Println("> debug") fmt.Println("> trace") os.Exit(1) } return level } } func MaxLogSize() int { return parseIntFromEnvOrDie("MAXLOGSIZE", 12) } func MaxLogBackups() int { return parseIntFromEnvOrDie("MAXLOGBACKUPS", 12) } func MaxLogAge() int { return parseIntFromEnvOrDie("MAXLOGAGE", 31) } func Lumberjack() io.Writer { return &lumberjack.Logger{ Filename: LogFile(), MaxSize: MaxLogSize(), // megabytes MaxBackups: MaxLogBackups(), MaxAge: MaxLogAge(), //days Compress: true, } } func LogWriter() io.Writer { switch isTTY := terminal.IsTerminal(int(os.Stdout.Fd())); isTTY { case true: // People sitting in front of their screen probably want // a copy of the logs to stdout. return io.MultiWriter(os.Stdout, Lumberjack()) default: // Otherwise everything to just the file logger. return Lumberjack() } } func InitLogging() { log.SetLevel(LogLevel()) log.SetFormatter(&log.TextFormatter{}) log.SetOutput(LogWriter()) } func parseIntFromEnvOrDie(key string, defaultVal int) int { switch val := os.Getenv(key); val { case "": return defaultVal default: i, err := strconv.ParseUint(val, 10, 32) if err != nil { fmt.Printf("%s (%s) could not be parsed to an integer", val, key) os.Exit(1) } return int(i) } }