internal/mock/alloydbadmin.go (140 lines of code) (raw):

// Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package mock import ( "crypto/x509" "encoding/json" "encoding/pem" "fmt" "io" "net/http" "net/http/httptest" "sync" "cloud.google.com/go/alloydb/apiv1alpha/alloydbpb" "google.golang.org/protobuf/encoding/protojson" ) // Request represents a HTTP request for a test Server to mock responses for. // // Use NewRequest to initialize new Requests. type Request struct { sync.Mutex reqMethod string reqPath string reqCt int handle func(resp http.ResponseWriter, req *http.Request) } // matches returns true if a given http.Request should be handled by this Request. func (r *Request) matches(hR *http.Request) bool { r.Lock() defer r.Unlock() if r.reqMethod != "" && r.reqMethod != hR.Method { return false } if r.reqPath != "" && r.reqPath != hR.URL.Path { return false } if r.reqCt <= 0 { return false } r.reqCt-- return true } // InstanceGetSuccess returns a Request that responds to the `instance.get` // AlloyDB Admin API endpoint. 
func InstanceGetSuccess(i FakeAlloyDBInstance, ct int) *Request {
	p := fmt.Sprintf(
		"/v1alpha/projects/%s/locations/%s/clusters/%s/instances/%s/connectionInfo",
		i.project, i.region, i.cluster, i.name,
	)
	// Build the connectionInfo payload from whichever IP types the fake
	// instance was configured with; unrecognized IP types are skipped.
	res := map[string]string{"instanceUid": i.uid}
	for ipType, addr := range i.ipAddrs {
		switch ipType {
		case "PRIVATE":
			res["ipAddress"] = addr
		case "PUBLIC":
			res["publicIpAddress"] = addr
		case "PSC":
			res["psc_dns_name"] = addr
		}
	}
	body, err := json.Marshal(res)
	if err != nil {
		// A map[string]string always marshals; failure is a programmer error.
		panic(err)
	}
	return &Request{
		reqMethod: http.MethodGet,
		reqPath:   p,
		reqCt:     ct,
		handle: func(resp http.ResponseWriter, _ *http.Request) {
			resp.Header().Set("Content-Type", "application/json")
			resp.WriteHeader(http.StatusOK)
			// The mock has no channel to report write failures; ignore them.
			_, _ = resp.Write(body)
		},
	}
}

// CreateEphemeralSuccess returns a Request that responds to the
// `generateClientCertificate` AlloyDB Admin API endpoint.
//
// The handler expects a protojson-encoded GenerateClientCertificateRequest
// whose PublicKey field holds a PEM block containing a PKCS#1 RSA public key.
// The parsed key is passed to i.GeneratePEMCertificateChain, and the returned
// chain is sent back with its last entry reported as the CA certificate.
// Malformed input is rejected with 400 Bad Request.
func CreateEphemeralSuccess(i FakeAlloyDBInstance, ct int) *Request {
	return &Request{
		reqMethod: http.MethodPost,
		reqPath: fmt.Sprintf(
			"/v1alpha/projects/%s/locations/%s/clusters/%s:generateClientCertificate",
			i.project, i.region, i.cluster),
		reqCt: ct,
		handle: func(resp http.ResponseWriter, req *http.Request) {
			defer req.Body.Close()
			b, err := io.ReadAll(req.Body)
			if err != nil {
				http.Error(resp, fmt.Sprintf("unable to read body: %v", err), http.StatusBadRequest)
				return
			}
			var rreq alloydbpb.GenerateClientCertificateRequest
			if err := protojson.Unmarshal(b, &rreq); err != nil {
				http.Error(resp, fmt.Sprintf("invalid or unexpected json: %v", err), http.StatusBadRequest)
				return
			}
			// pem.Decode reports failure only through a nil block; there is no
			// error value to wrap here.
			bl, _ := pem.Decode([]byte(rreq.PublicKey))
			if bl == nil {
				http.Error(resp, "unable to decode CSR: no PEM data found in public key", http.StatusBadRequest)
				return
			}
			pub, err := x509.ParsePKCS1PublicKey(bl.Bytes)
			if err != nil {
				http.Error(resp, fmt.Sprintf("unable to decode CSR: %v", err), http.StatusBadRequest)
				return
			}
			chain, err := i.GeneratePEMCertificateChain(pub)
			if err != nil {
				http.Error(resp, fmt.Sprintf("unable to create certificate: %v", err), http.StatusBadRequest)
				return
			}
			// Guard the CA lookup below, which would otherwise panic.
			if len(chain) == 0 {
				http.Error(resp, "unable to create certificate: empty certificate chain", http.StatusInternalServerError)
				return
			}
			rresp := &alloydbpb.GenerateClientCertificateResponse{
				CaCert:              chain[len(chain)-1], // last entry is CA
				PemCertificateChain: chain,
			}
			// Marshal fully before writing so an encoding failure can still be
			// reported as a clean error response instead of a truncated 200.
			out, err := protojson.Marshal(rresp)
			if err != nil {
				http.Error(resp, fmt.Sprintf("unable to encode response: %v", err), http.StatusInternalServerError)
				return
			}
			resp.Header().Set("Content-Type", "application/json")
			_, _ = resp.Write(out)
		},
	}
}

// HTTPClient returns an *http.Client, a base URL, and a cleanup function. The
// client is configured to trust the test TLS server listening at the returned
// URL. The server answers requests that match the given Requests and responds
// with 501 Not Implemented to anything unexpected. The cleanup function closes
// the server and returns an error if any expected calls were not received.
func HTTPClient(requests ...*Request) (*http.Client, string, func() error) {
	// Serve each incoming request with the first Request that matches it and
	// still has calls remaining.
	s := httptest.NewTLSServer(http.HandlerFunc(
		func(resp http.ResponseWriter, req *http.Request) {
			for _, r := range requests {
				if r.matches(req) {
					r.handle(resp, req)
					return
				}
			}
			// Unexpected requests should surface as a server error.
			resp.WriteHeader(http.StatusNotImplemented)
			// TODO: follow error format better?
			fmt.Fprintf(resp, "unexpected request sent to mock client: %v", req)
		},
	))
	// cleanup stops the test server and checks for uncalled requests.
	cleanup := func() error {
		// Close blocks until outstanding requests complete, so the counts
		// read below are final.
		s.Close()
		for i, e := range requests {
			// Read under the same lock matches uses to mutate reqCt.
			e.Lock()
			left, method, path := e.reqCt, e.reqMethod, e.reqPath
			e.Unlock()
			if left > 0 {
				return fmt.Errorf("%d calls left for specified call in pos %d: %s %s", left, i, method, path)
			}
		}
		return nil
	}
	return s.Client(), s.URL, cleanup
}