func newWithCallback()

in testutils/fakekms/fakekms.go [191:274]


// newWithCallback starts a fake Cloud KMS HTTP server that serves the
// encrypt, decrypt and testIamPermissions endpoints for keyName.
//
// port selects the local TCP port to listen on; when it is 0 a port is chosen
// with freeport.GetFreePort. delay is slept at the start of every request.
//
// For each recognised endpoint the request body is decoded into the matching
// cloudkms request type, recorded on the returned Server, and then passed to
// handle. handle returns the response to encode, the HTTP status code to send,
// and an error; a non-nil error is sent to the client via http.Error with the
// returned status instead of the response.
//
// Requests to any other path, and bodies that cannot be decoded, are answered
// with 500 Internal Server Error. An unreadable body gets 400 Bad Request.
func newWithCallback(keyName string, port int, delay time.Duration, handle func(req json.Marshaler) (json.Marshaler, int, error)) (*Server, error) {
	var err error
	if port == 0 {
		// NOTE(review): the port is only reserved until net.Listen below; another
		// process could grab it in between.
		port, err = freeport.GetFreePort()
		if err != nil {
			return nil, fmt.Errorf("failed to allocate port for fake kms, error: %v", err)
		}
	}

	// Bind first so that a failure here leaves no server or socket to clean up.
	l, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port))
	if err != nil {
		return nil, fmt.Errorf("failed to listen on port %d, error: %v", port, err)
	}

	var (
		encryptPath = fmt.Sprintf("/v1/%s:encrypt", keyName)
		decryptPath = fmt.Sprintf("/v1/%s:decrypt", keyName)
		testIAMPath = fmt.Sprintf("/v1/%s:testIamPermissions", keyName)
	)

	s := &Server{port: port}
	s.srv = httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		time.Sleep(delay)

		body, err := ioutil.ReadAll(r.Body)
		if err != nil {
			http.Error(w, fmt.Sprintf("can't read the body of the request, error: %v", err), http.StatusBadRequest)
			return
		}

		// Select the request type for this endpoint and how to record it; the
		// decode/record/handle steps below are shared by all endpoints.
		var (
			req    json.Marshaler
			record func()
		)
		switch r.URL.EscapedPath() {
		case encryptPath:
			e := &cloudkms.EncryptRequest{}
			req, record = e, func() { s.recordEncryptRequest(e) }
		case decryptPath:
			d := &cloudkms.DecryptRequest{}
			req, record = d, func() { s.recordDecryptRequest(d) }
		case testIAMPath:
			t := &cloudkms.TestIamPermissionsRequest{}
			req, record = t, func() { s.recordTestIAMRequest(t) }
		default:
			http.Error(w, fmt.Sprintf("Was not expecting call to %q", r.URL.EscapedPath()), http.StatusInternalServerError)
			return
		}

		if err := json.Unmarshal(body, req); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		record()

		response, status, err := handle(req)
		if err != nil {
			http.Error(w, err.Error(), status)
			return
		}

		// Marshal before writing the status line: once WriteHeader has been
		// called the status is fixed, so an encoding failure could no longer be
		// reported as a 500.
		out, err := json.Marshal(response)
		if err != nil {
			http.Error(w, fmt.Sprintf("failed to marshal response, error %v", err), http.StatusInternalServerError)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(status)
		// The trailing newline keeps the body identical to json.Encoder.Encode
		// output. A write error means the client went away; nothing to report.
		_, _ = w.Write(append(out, '\n'))
	}))

	// NewUnstartedServer already opened its own loopback listener; close it
	// before substituting ours, otherwise that socket leaks.
	s.srv.Listener.Close()
	s.srv.Listener = l
	s.srv.Start()

	return s, nil
}