testutils/fakekms/fakekms.go (205 lines of code) (raw):

// Copyright 2018 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package fakekms supports integration testing of kms-plugin by faking CloudKMS. package fakekms import ( "encoding/json" "errors" "fmt" "io/ioutil" "net" "net/http" "net/http/httptest" "sync" "time" "github.com/golang/glog" "github.com/google/go-cmp/cmp" "github.com/phayes/freeport" "google.golang.org/api/cloudkms/v1" ) // Server fakes CloudKMS. type Server struct { srv *httptest.Server mux sync.Mutex encryptRequestLog []*cloudkms.EncryptRequest decryptRequestLog []*cloudkms.DecryptRequest iamTestRequestLog []*cloudkms.TestIamPermissionsRequest port int } // Client returns *http.Client for the fake. func (f *Server) Client() *http.Client { return f.srv.Client() } // URL returns URL on which the fake is expecting requests. func (f *Server) URL() string { return f.srv.URL } // Close closes the underlying httptest.Server. func (f *Server) Close() { f.srv.Close() } // EncryptRequestsEqual validates that the supplied EncryptRequests are equal to all // EncryptRequests processed by the server. func (f *Server) EncryptRequestsEqual(r []*cloudkms.EncryptRequest) error { f.mux.Lock() defer f.mux.Unlock() if diff := cmp.Diff(f.encryptRequestLog, r); diff != "" { return fmt.Errorf("EncryptRequests differs from expected:(-want +got)\n%s", diff) } return nil } // DecryptRequestsEqual validates that the supplied DecryptRequests are equal to the all // DecryptRequests processed by the server. 
func (f *Server) DecryptRequestsEqual(r []*cloudkms.DecryptRequest) error {
	f.mux.Lock()
	defer f.mux.Unlock()

	// cmp.Diff(x, y) marks x with "-" and y with "+"; pass want first so the
	// "(-want +got)" legend in the message is accurate.
	if diff := cmp.Diff(r, f.decryptRequestLog); diff != "" {
		return fmt.Errorf("DecryptRequests differ from expected: (-want +got)\n%s", diff)
	}
	return nil
}

// TestIAMRequestsEqual validates that the supplied TestIamPermissionsRequests are equal to all
// TestIamPermissionsRequests processed by the server, in order.
// On mismatch the returned error carries a cmp.Diff in which "-" lines are the
// expected requests (want) and "+" lines are those the server received (got).
func (f *Server) TestIAMRequestsEqual(r []*cloudkms.TestIamPermissionsRequest) error {
	f.mux.Lock()
	defer f.mux.Unlock()

	// Compare the whole log, like the Encrypt/Decrypt variants. Diffing only the
	// last logged request against the slice compared values of different types
	// (so it could never succeed) and panicked when nothing had been logged yet.
	if diff := cmp.Diff(r, f.iamTestRequestLog); diff != "" {
		return fmt.Errorf("TestIamPermissionsRequests differ from expected: (-want +got)\n%s", diff)
	}
	return nil
}

// recordEncryptRequest appends r to the log of received EncryptRequests.
func (f *Server) recordEncryptRequest(r *cloudkms.EncryptRequest) {
	f.mux.Lock()
	defer f.mux.Unlock()
	f.encryptRequestLog = append(f.encryptRequestLog, r)
}

// recordDecryptRequest appends r to the log of received DecryptRequests.
func (f *Server) recordDecryptRequest(r *cloudkms.DecryptRequest) {
	f.mux.Lock()
	defer f.mux.Unlock()
	f.decryptRequestLog = append(f.decryptRequestLog, r)
}

// recordTestIAMRequest appends r to the log of received TestIamPermissionsRequests.
func (f *Server) recordTestIAMRequest(r *cloudkms.TestIamPermissionsRequest) {
	f.mux.Lock()
	defer f.mux.Unlock()
	f.iamTestRequestLog = append(f.iamTestRequestLog, r)
}

// NewWithPipethrough creates and returns a *Server that simply passes requests through,
// echoing the plaintext back as the ciphertext and vice versa.
// Callers are also responsible for calling Close after completing tests.
// keyName simulates CloudKMS' keyName and is taken into account when calculating expected URL endpoints.
func NewWithPipethrough(keyName string, port int) (*Server, error) {
	// echo answers every supported request by handing its payload straight back.
	echo := func(req json.Marshaler) (json.Marshaler, int, error) {
		glog.Infof("Processing request: %#v", req)

		var resp json.Marshaler
		switch typed := req.(type) {
		case *cloudkms.EncryptRequest:
			resp = &cloudkms.EncryptResponse{Name: keyName, Ciphertext: typed.Plaintext}
		case *cloudkms.DecryptRequest:
			resp = &cloudkms.DecryptResponse{Plaintext: typed.Ciphertext}
		case *cloudkms.TestIamPermissionsRequest:
			// Report both the encrypt and decrypt permissions as granted.
			resp = &cloudkms.TestIamPermissionsResponse{
				Permissions: []string{
					"cloudkms.cryptoKeyVersions.useToEncrypt",
					"cloudkms.cryptoKeyVersions.useToDecrypt",
				},
			}
		default:
			return nil, http.StatusInternalServerError, fmt.Errorf("was not expecting request type:%T", typed)
		}
		return resp, http.StatusOK, nil
	}

	// The pass-through fake answers without any artificial delay.
	return newWithCallback(keyName, port, 0, echo)
}

// NewWithResponses creates and returns a *Server that answers requests with the supplied
// responses, consuming them in order.
// It is the caller's responsibility to supply the expected number of responses; once they
// are exhausted the server replies with an error.
// Callers are also responsible for calling Close after completing tests.
// keyName simulates CloudKMS' keyName and is taken into account when calculating expected URL endpoints.
// delay allows the caller to simulate delayed responses from KMS.
func NewWithResponses(keyName string, port int, delay time.Duration, responses ...json.Marshaler) (*Server, error) {
	// The HTTP server may invoke handle from several goroutines at once, so the
	// shared responses queue must be guarded.
	var mu sync.Mutex

	handle := func(req json.Marshaler) (json.Marshaler, int, error) {
		mu.Lock()
		defer mu.Unlock()

		if len(responses) == 0 {
			return nil, http.StatusServiceUnavailable, errors.New("list of responses is empty")
		}

		status := http.StatusInternalServerError
		switch req.(type) {
		case *cloudkms.EncryptRequest:
			e, ok := responses[0].(*cloudkms.EncryptResponse)
			if !ok {
				return nil, status, errors.New("request for encrypt does not have a corresponding response of cloudkms.EncryptResponse")
			}
			status = e.HTTPStatusCode
		case *cloudkms.DecryptRequest:
			d, ok := responses[0].(*cloudkms.DecryptResponse)
			if !ok {
				return nil, status, errors.New("request for decrypt does not have a corresponding response of cloudkms.DecryptResponse")
			}
			status = d.HTTPStatusCode
		case *cloudkms.TestIamPermissionsRequest:
			t, ok := responses[0].(*cloudkms.TestIamPermissionsResponse)
			if !ok {
				return nil, status, errors.New("request for testIamPermissions does not have a corresponding response of cloudkms.TestIAMPermissionResponse")
			}
			status = t.HTTPStatusCode
		}

		// A response built without an explicit HTTPStatusCode would make
		// ResponseWriter.WriteHeader panic on code 0; treat it as success.
		if status == 0 {
			status = http.StatusOK
		}

		r := responses[0]
		responses = responses[1:]
		return r, status, nil
	}
	return newWithCallback(keyName, port, delay, handle)
}

// newWithCallback starts a fake CloudKMS server on localhost:port (a free port is picked
// when port is 0) that routes the encrypt, decrypt and testIamPermissions calls for keyName
// to handle, after sleeping for delay. Every successfully decoded request is recorded on the
// Server before handle runs. handle returns the response to encode and the HTTP status to
// send; when it returns an error, the error text is sent with that status instead.
func newWithCallback(keyName string, port int, delay time.Duration, handle func(req json.Marshaler) (json.Marshaler, int, error)) (*Server, error) {
	var err error
	if port == 0 {
		port, err = freeport.GetFreePort()
		if err != nil {
			return nil, fmt.Errorf("failed to allocate port for fake kms, error: %v", err)
		}
	}

	// Listen before building the server so that a failure here leaks nothing.
	l, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port))
	if err != nil {
		return nil, fmt.Errorf("failed to listen on port %d, error: %v", port, err)
	}

	// The endpoints only depend on keyName; compute them once rather than per request.
	encryptPath := fmt.Sprintf("/v1/%s:encrypt", keyName)
	decryptPath := fmt.Sprintf("/v1/%s:decrypt", keyName)
	iamPath := fmt.Sprintf("/v1/%s:testIamPermissions", keyName)

	s := &Server{port: port}
	s.srv = httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		time.Sleep(delay)

		body, err := ioutil.ReadAll(r.Body)
		if err != nil {
			http.Error(w, fmt.Sprintf("can't read the body of the request, error: %v", err), http.StatusBadRequest)
			return
		}

		// Select the request type for this endpoint and how to record it once decoded.
		var (
			req    json.Marshaler
			record func()
		)
		switch r.URL.EscapedPath() {
		case encryptPath:
			e := &cloudkms.EncryptRequest{}
			req, record = e, func() { s.recordEncryptRequest(e) }
		case decryptPath:
			d := &cloudkms.DecryptRequest{}
			req, record = d, func() { s.recordDecryptRequest(d) }
		case iamPath:
			t := &cloudkms.TestIamPermissionsRequest{}
			req, record = t, func() { s.recordTestIAMRequest(t) }
		default:
			http.Error(w, fmt.Sprintf("Was not expecting call to %q", r.URL.EscapedPath()), http.StatusInternalServerError)
			return
		}

		if err := json.Unmarshal(body, req); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		record()

		response, status, err := handle(req)
		if err != nil {
			http.Error(w, err.Error(), status)
			return
		}

		// Encode before writing the header: once WriteHeader has been called an
		// encoding failure can no longer be reported with a different status.
		payload, err := json.Marshal(response)
		if err != nil {
			http.Error(w, fmt.Sprintf("failed to marshal response, error %v", err), http.StatusInternalServerError)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(status)
		// The status line is already sent, so a failed write cannot be reported to the client.
		_, _ = w.Write(payload)
	}))

	// NewUnstartedServer opened its own listener; release it before swapping in ours,
	// otherwise that socket is leaked. Its Close error is irrelevant since it was never used.
	_ = s.srv.Listener.Close()
	s.srv.Listener = l
	s.srv.Start()
	return s, nil
}