crlVerification/main.go (252 lines of code) (raw):
/* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
package main // import "github.com/mozilla/CCADB-Tools/crlVerification"
import (
"crypto/x509/pkix"
"encoding/json"
"fmt"
"github.com/mozilla/CCADB-Tools/crlVerification/utils"
"github.com/pkg/errors"
"io/ioutil"
"log"
"math/big"
"net"
"net/http"
"os"
"strconv"
"time"
)
type Input struct {
Crls []string
Serial *big.Int
Date time.Time
Reason utils.RevocationReason
errs []error
}
func NewInput() Input {
return Input{
Crls: make([]string, 0),
Serial: nil,
Date: time.Time{},
Reason: utils.NOT_GIVEN,
errs: make([]error, 0),
}
}
func (i *Input) UnmarshalJSON(data []byte) error {
intermediate := make(map[string]interface{})
err := json.Unmarshal(data, &intermediate)
if err != nil {
i.errs = append(i.errs, err)
return nil
}
if c, ok := intermediate["crl"]; ok {
switch t := c.(type) {
case []interface{}:
for _, crlInterface := range t {
switch crl := crlInterface.(type) {
case string:
i.Crls = append(i.Crls, crl)
default:
i.errs = append(i.errs, errors.New(fmt.Sprintf(`unexpected type for "crl", got %T from value "%v"`, crl, crl)))
}
}
case string:
// The old interface was built for
// only one CRL, so let's honor that just in case
i.Crls = append(i.Crls, t)
case nil:
// The old interface allowed for a null entry
// so let's leave that in place by leaving the array empty.
default:
i.errs = append(i.errs, errors.New(fmt.Sprintf(`unexpected type for "crl", got %T from value "%v"`, t, t)))
}
}
if s, ok := intermediate["serial"]; ok {
switch t := s.(type) {
case string:
serial, err := utils.BigIntFromHexString(t)
if err != nil {
i.errs = append(i.errs, err)
} else {
i.Serial = serial
}
default:
i.errs = append(i.errs, errors.New(fmt.Sprintf(`unexpected type for "serial", got %T from value "%v"`, t, t)))
}
} else {
i.errs = append(i.errs, errors.New(`"serial" is a required field`))
}
if d, ok := intermediate["revocationDate"]; ok {
switch t := d.(type) {
case string:
date, err := utils.TimeFromString(t)
if err != nil {
i.errs = append(i.errs, err)
} else {
i.Date = date
}
default:
i.errs = append(i.errs, errors.New(fmt.Sprintf(`unexpected type for "revocationData", got %T from value "%v"`, t, t)))
}
} else {
i.errs = append(i.errs, errors.New(`"revocationDate" is a required field`))
}
if r, ok := intermediate["revocationReason"]; ok {
switch t := r.(type) {
case string:
reason, err := utils.FromString(&t)
if err != nil {
i.errs = append(i.errs, err)
} else {
i.Reason = reason
}
default:
i.errs = append(i.errs, errors.New(fmt.Sprintf(`unexpected type for "revocationReason", got %T from value "%v"`, t, t)))
}
} else {
i.Reason = utils.NOT_GIVEN
}
return nil
}
type Result string
const (
PASS Result = "PASS"
FAIL Result = "FAIL"
)
type Return struct {
Result Result
Errors []error
}
func (r Return) MarshalJSON() ([]byte, error) {
result := r.Result
errs := make([]string, len(r.Errors))
for i := 0; i < len(errs); i++ {
errs[i] = r.Errors[i].Error()
}
return json.Marshal(map[string]interface{}{
"Result": result,
"Errors": errs,
})
}
func NewReturn() Return {
return Return{
Result: FAIL,
Errors: make([]error, 0),
}
}
func Validate(i Input, crlURL string) Return {
crl, err := utils.CRLFromURL(crlURL)
if err != nil {
ret := NewReturn()
ret.Errors = append(ret.Errors, err)
return ret
}
return validate(i, crl)
}
func validate(i Input, crl *pkix.CertificateList) Return {
ret := NewReturn()
cert, err := utils.FindSerial(crl, i.Serial)
if err != nil {
ret.Errors = append(ret.Errors, err)
return ret
}
if err = utils.ValidateRevocationDate(cert, i.Date); err != nil {
ret.Errors = append(ret.Errors, err)
}
if err = utils.ValidateRevocationReason(cert, i.Reason); err != nil {
ret.Errors = append(ret.Errors, err)
}
if len(ret.Errors) == 0 {
ret.Result = PASS
}
return ret
}
func endpoint(resp http.ResponseWriter, req *http.Request) {
ret := NewReturn()
code := 200
defer func() {
resp.Header().Set("Content-Type", "application/json")
resp.WriteHeader(code)
encoder := json.NewEncoder(resp)
encoder.SetIndent("", " ")
if err := encoder.Encode(&ret); err != nil {
fmt.Println(err)
}
}()
body, err := ioutil.ReadAll(req.Body)
if err != nil {
code = 500
resp.WriteHeader(500)
ret = Return{
Result: FAIL,
Errors: []error{err},
}
return
}
_ = req.Body.Close()
i := NewInput()
err = json.Unmarshal(body, &i)
if err != nil {
code = 500
ret = Return{
Result: FAIL,
Errors: []error{err},
}
return
}
if len(i.errs) != 0 {
code = 400
ret = Return{
Result: FAIL,
Errors: i.errs,
}
return
}
if len(i.Crls) == 0 {
ret = NewReturn()
ret.Errors = append(ret.Errors, utils.CRLNotGiven{})
return
}
code = 200
allErrors := make([]error, 0)
for _, crl := range i.Crls {
result := Validate(i, crl)
if len(result.Errors) == 0 {
ret = result
return
}
allErrors = append(allErrors, result.Errors...)
}
ret = NewReturn()
ret.Errors = allErrors
}
func main() {
http.HandleFunc("/", endpoint)
port := Port()
addr := BindingAddress()
if err := http.ListenAndServe(addr+":"+port, nil); err != nil {
log.Panicln(err)
}
}
func Port() string {
return fmt.Sprintf("%d", parseIntFromEnvOrDie("PORT", 8080))
}
func BindingAddress() string {
switch addr := os.Getenv("ADDR"); addr {
case "":
return "0.0.0.0"
default:
_, _, err := net.ParseCIDR(addr)
if err != nil {
panic("failed to parse the provided ADDR to a valid CIDR")
}
return addr
}
}
func parseIntFromEnvOrDie(key string, defaultVal int) int {
switch val := os.Getenv(key); val {
case "":
return defaultVal
default:
i, err := strconv.ParseUint(val, 10, 32)
if err != nil {
fmt.Printf("%s (%s) could not be parsed to an integer", val, key)
os.Exit(1)
}
return int(i)
}
}