server/httpproxy.go (230 lines of code) (raw):
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"context"
"fmt"
"io/ioutil"
"net/http"
"strings"
cwgrpc "github.com/GoogleCloudPlatform/stet/proto/confidential_wrap_go_proto"
cwpb "github.com/GoogleCloudPlatform/stet/proto/confidential_wrap_go_proto"
ssgrpc "github.com/GoogleCloudPlatform/stet/proto/secure_session_go_proto"
sspb "github.com/GoogleCloudPlatform/stet/proto/secure_session_go_proto"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
)
// URL suffixes that Handler compares against the request URL (using
// strings.HasSuffix) to decide which RPC a request is proxied to.
const (
// Suffixes routed to the session-establishment client.
beginSessionEndpoint = "/session/beginsession"
handshakeEndpoint = "/session/handshake"
negotiateAttestationEndpoint = "/session/negotiateattestation"
finalizeEndpoint = "/session/finalize"
endSessionEndpoint = "/session/endsession"
// Suffixes routed to the wrap/unwrap client. These start with ':' rather
// than '/', so they match a trailing ":confidentialwrap"-style segment.
confidentialWrapEndpoint = ":confidentialwrap"
confidentialUnwrapEndpoint = ":confidentialunwrap"
)
// ekmToken is a struct that implements credentials.PerRPCCredentials to
// store a bearer token for authenticating requests to the EKM.
type ekmToken struct {
// token is the raw token value; GetRequestMetadata prepends TokenPrefix to
// it before attaching it to a request.
token string
}
// GetRequestMetadata returns the metadata to attach to an outgoing RPC: a
// single entry keyed by TokenMetadataKey whose value is TokenPrefix followed
// by the stored token. The context and URI arguments are ignored, and the
// returned error is always nil.
func (t ekmToken) GetRequestMetadata(context.Context, ...string) (map[string]string, error) {
	md := make(map[string]string, 1)
	md[TokenMetadataKey] = fmt.Sprintf("%s%s", TokenPrefix, t.token)
	return md, nil
}
// RequireTransportSecurity reports whether these credentials demand a secure
// transport. It always returns false.
// NOTE(review): presumably this is so the token can be used on the connection
// that connectToGRPCServer dials with grpc.WithInsecure() — confirm.
func (ekmToken) RequireTransportSecurity() bool {
return false
}
// SecureSessionHTTPService is an HTTP-to-gRPC proxy for SecureSessionService, to be used for local testing only.
type SecureSessionHTTPService struct {
// sessionClient receives the proxied BeginSession, Handshake,
// NegotiateAttestation, Finalize and EndSession calls.
sessionClient ssgrpc.ConfidentialEkmSessionEstablishmentServiceClient
// wrapClient receives the proxied ConfidentialWrap and ConfidentialUnwrap calls.
wrapClient cwgrpc.ConfidentialWrapUnwrapServiceClient
}
// NewSecureSessionHTTPService builds a SecureSessionHTTPService whose gRPC
// clients are connected to the server at address. When authToken is non-empty
// it is attached to every outgoing RPC as a bearer token.
//
// It returns a nil service and a wrapped error if the gRPC connection cannot
// be created.
//
// The Caller should Close using SecureSessionHTTPService.Close() when finished.
// NOTE(review): no Close method is visible in this file, and
// connectToGRPCServer does not retain the dialed connection — confirm how the
// connection is meant to be released.
func NewSecureSessionHTTPService(address, authToken string) (*SecureSessionHTTPService, error) {
	service := new(SecureSessionHTTPService)

	err := service.connectToGRPCServer(address, authToken)
	if err != nil {
		return nil, fmt.Errorf("error initializing test server: %w", err)
	}

	return service, nil
}
// NewSecureSessionHTTPServiceWithFakeClients builds a SecureSessionHTTPService
// that uses the supplied clients instead of dialing a gRPC server.
//
// sessionClient and wrapClient must either both be set or both be nil;
// supplying exactly one yields an error. The address and authToken arguments
// are not used by this constructor.
//
// NOTE(review): when both clients are nil the returned service holds nil
// clients, so any handler invoked on it will dereference a nil interface —
// confirm whether the "neither" case was meant to dial address instead.
//
// The Caller should Close using SecureSessionHTTPService.Close() when finished.
func NewSecureSessionHTTPServiceWithFakeClients(address, authToken string, sessionClient ssgrpc.ConfidentialEkmSessionEstablishmentServiceClient, wrapClient cwgrpc.ConfidentialWrapUnwrapServiceClient) (*SecureSessionHTTPService, error) {
	hasSession := sessionClient != nil
	hasWrap := wrapClient != nil
	if hasSession != hasWrap {
		return nil, fmt.Errorf("only one fake client provided, must specify both or neither")
	}

	service := new(SecureSessionHTTPService)
	service.sessionClient = sessionClient
	service.wrapClient = wrapClient
	return service, nil
}
// processHTTPRequest decodes the JSON body of httpReq into protoReq and
// prepares the context for the outgoing gRPC call.
//
// The request body is always closed before returning. On success the returned
// context is ctx with the request's TokenMetadataKey header value appended to
// the outgoing gRPC metadata under the same key (the value is appended even
// when the header is absent, in which case it is the empty string). On failure
// the original ctx is returned together with an error describing whether
// reading or unmarshaling the body failed.
func processHTTPRequest(ctx context.Context, httpReq *http.Request, protoReq proto.Message) (context.Context, error) {
	defer httpReq.Body.Close()

	payload, readErr := ioutil.ReadAll(httpReq.Body)
	if readErr != nil {
		return ctx, fmt.Errorf("unable to read HTTP request body: %w", readErr)
	}

	if decodeErr := protojson.Unmarshal(payload, protoReq); decodeErr != nil {
		return ctx, fmt.Errorf("unable to unmarshal HTTP request body: %w", decodeErr)
	}

	// Forward the caller's token header to the gRPC backend.
	token := httpReq.Header.Get(TokenMetadataKey)
	outCtx := metadata.AppendToOutgoingContext(ctx, TokenMetadataKey, token)
	return outCtx, nil
}
// handleBeginSession proxies an HTTP request to the BeginSession RPC.
//
// The JSON request body is decoded into a BeginSessionRequest, sent through
// s.sessionClient, and the RPC response is written back as protojson.
//
// Failure responses carry the error text as the body:
//   - 400 if the body cannot be read or unmarshaled;
//   - 500 if the RPC fails or the response cannot be marshaled.
//
// The handler returns immediately after writing a failure response, so the
// RPC is never issued with an undecoded request and the status line is only
// written once.
func (s *SecureSessionHTTPService) handleBeginSession(ctx context.Context, w http.ResponseWriter, r *http.Request) {
	req := &sspb.BeginSessionRequest{}
	reqCtx, err := processHTTPRequest(ctx, r, req)
	if err != nil {
		w.WriteHeader(http.StatusBadRequest)
		w.Write([]byte(err.Error()))
		return
	}

	resp, err := s.sessionClient.BeginSession(reqCtx, req)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(err.Error()))
		return
	}

	marshaled, err := protojson.Marshal(resp)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(err.Error()))
		return
	}

	w.Write(marshaled)
}
// handleHandshake proxies an HTTP request to the Handshake RPC.
//
// The JSON request body is decoded into a HandshakeRequest, sent through
// s.sessionClient, and the RPC response is written back as protojson.
//
// Failure responses carry the error text as the body:
//   - 400 if the body cannot be read or unmarshaled;
//   - 500 if the RPC fails or the response cannot be marshaled.
//
// The handler returns immediately after writing a failure response, so the
// RPC is never issued with an undecoded request and the status line is only
// written once.
func (s *SecureSessionHTTPService) handleHandshake(ctx context.Context, w http.ResponseWriter, r *http.Request) {
	req := &sspb.HandshakeRequest{}
	reqCtx, err := processHTTPRequest(ctx, r, req)
	if err != nil {
		w.WriteHeader(http.StatusBadRequest)
		w.Write([]byte(err.Error()))
		return
	}

	resp, err := s.sessionClient.Handshake(reqCtx, req)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(err.Error()))
		return
	}

	marshaled, err := protojson.Marshal(resp)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(err.Error()))
		return
	}

	w.Write(marshaled)
}
// handleNegotiateAttestation proxies an HTTP request to the
// NegotiateAttestation RPC.
//
// The JSON request body is decoded into a NegotiateAttestationRequest, sent
// through s.sessionClient, and the RPC response is written back as protojson.
//
// Failure responses carry the error text as the body:
//   - 400 if the body cannot be read or unmarshaled;
//   - 500 if the RPC fails or the response cannot be marshaled.
//
// The handler returns immediately after writing a failure response, so the
// RPC is never issued with an undecoded request and the status line is only
// written once.
func (s *SecureSessionHTTPService) handleNegotiateAttestation(ctx context.Context, w http.ResponseWriter, r *http.Request) {
	req := &sspb.NegotiateAttestationRequest{}
	reqCtx, err := processHTTPRequest(ctx, r, req)
	if err != nil {
		w.WriteHeader(http.StatusBadRequest)
		w.Write([]byte(err.Error()))
		return
	}

	resp, err := s.sessionClient.NegotiateAttestation(reqCtx, req)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(err.Error()))
		return
	}

	marshaled, err := protojson.Marshal(resp)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(err.Error()))
		return
	}

	w.Write(marshaled)
}
// handleFinalize proxies an HTTP request to the Finalize RPC.
//
// The JSON request body is decoded into a FinalizeRequest, sent through
// s.sessionClient, and the RPC response is written back as protojson.
//
// Failure responses carry the error text as the body:
//   - 400 if the body cannot be read or unmarshaled;
//   - 500 if the RPC fails or the response cannot be marshaled.
//
// The handler returns immediately after writing a failure response, so the
// RPC is never issued with an undecoded request and the status line is only
// written once.
func (s *SecureSessionHTTPService) handleFinalize(ctx context.Context, w http.ResponseWriter, r *http.Request) {
	req := &sspb.FinalizeRequest{}
	reqCtx, err := processHTTPRequest(ctx, r, req)
	if err != nil {
		w.WriteHeader(http.StatusBadRequest)
		w.Write([]byte(err.Error()))
		return
	}

	resp, err := s.sessionClient.Finalize(reqCtx, req)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(err.Error()))
		return
	}

	marshaled, err := protojson.Marshal(resp)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(err.Error()))
		return
	}

	w.Write(marshaled)
}
// handleEndSession proxies an HTTP request to the EndSession RPC.
//
// The JSON request body is decoded into an EndSessionRequest, sent through
// s.sessionClient, and the RPC response is written back as protojson.
//
// Failure responses carry the error text as the body:
//   - 400 if the body cannot be read or unmarshaled;
//   - 500 if the RPC fails or the response cannot be marshaled.
//
// The handler returns immediately after writing a failure response, so the
// RPC is never issued with an undecoded request and the status line is only
// written once.
func (s *SecureSessionHTTPService) handleEndSession(ctx context.Context, w http.ResponseWriter, r *http.Request) {
	req := &sspb.EndSessionRequest{}
	reqCtx, err := processHTTPRequest(ctx, r, req)
	if err != nil {
		w.WriteHeader(http.StatusBadRequest)
		w.Write([]byte(err.Error()))
		return
	}

	resp, err := s.sessionClient.EndSession(reqCtx, req)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(err.Error()))
		return
	}

	marshaled, err := protojson.Marshal(resp)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(err.Error()))
		return
	}

	w.Write(marshaled)
}
// handleConfidentialWrap proxies an HTTP request to the ConfidentialWrap RPC.
//
// The JSON request body is decoded into a ConfidentialWrapRequest, sent
// through s.wrapClient, and the RPC response is written back as protojson.
//
// Failure responses carry the error text as the body:
//   - 400 if the body cannot be read or unmarshaled;
//   - 500 if the RPC fails or the response cannot be marshaled.
//
// The handler returns immediately after writing a failure response, so the
// RPC is never issued with an undecoded request and the status line is only
// written once.
func (s *SecureSessionHTTPService) handleConfidentialWrap(ctx context.Context, w http.ResponseWriter, r *http.Request) {
	req := &cwpb.ConfidentialWrapRequest{}
	reqCtx, err := processHTTPRequest(ctx, r, req)
	if err != nil {
		w.WriteHeader(http.StatusBadRequest)
		w.Write([]byte(err.Error()))
		return
	}

	resp, err := s.wrapClient.ConfidentialWrap(reqCtx, req)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(err.Error()))
		return
	}

	marshaled, err := protojson.Marshal(resp)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(err.Error()))
		return
	}

	w.Write(marshaled)
}
// handleConfidentialUnwrap proxies an HTTP request to the ConfidentialUnwrap
// RPC.
//
// The JSON request body is decoded into a ConfidentialUnwrapRequest, sent
// through s.wrapClient, and the RPC response is written back as protojson.
//
// Failure responses carry the error text as the body:
//   - 400 if the body cannot be read or unmarshaled;
//   - 500 if the RPC fails or the response cannot be marshaled.
//
// The handler returns immediately after writing a failure response, so the
// RPC is never issued with an undecoded request and the status line is only
// written once.
func (s *SecureSessionHTTPService) handleConfidentialUnwrap(ctx context.Context, w http.ResponseWriter, r *http.Request) {
	req := &cwpb.ConfidentialUnwrapRequest{}
	reqCtx, err := processHTTPRequest(ctx, r, req)
	if err != nil {
		w.WriteHeader(http.StatusBadRequest)
		w.Write([]byte(err.Error()))
		return
	}

	resp, err := s.wrapClient.ConfidentialUnwrap(reqCtx, req)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(err.Error()))
		return
	}

	marshaled, err := protojson.Marshal(resp)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(err.Error()))
		return
	}

	w.Write(marshaled)
}
// Handler acts as a HandlerFunc for HTTP servers. It routes each request to
// the proxy method whose endpoint constant is a suffix of the request URL
// string, checking the endpoints in a fixed order and passing along the
// request's context. Requests matching no endpoint receive a 400 status with
// an empty body.
func (s *SecureSessionHTTPService) Handler(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	target := r.URL.String()

	switch {
	case strings.HasSuffix(target, beginSessionEndpoint):
		s.handleBeginSession(ctx, w, r)
	case strings.HasSuffix(target, handshakeEndpoint):
		s.handleHandshake(ctx, w, r)
	case strings.HasSuffix(target, negotiateAttestationEndpoint):
		s.handleNegotiateAttestation(ctx, w, r)
	case strings.HasSuffix(target, finalizeEndpoint):
		s.handleFinalize(ctx, w, r)
	case strings.HasSuffix(target, endSessionEndpoint):
		s.handleEndSession(ctx, w, r)
	case strings.HasSuffix(target, confidentialWrapEndpoint):
		s.handleConfidentialWrap(ctx, w, r)
	case strings.HasSuffix(target, confidentialUnwrapEndpoint):
		s.handleConfidentialUnwrap(ctx, w, r)
	default:
		// Unknown endpoint: reject the request.
		w.WriteHeader(http.StatusBadRequest)
	}
}
// connectToGRPCServer dials the gRPC server at address without transport
// security and stores session and wrap/unwrap clients built on that
// connection in s. If authToken is non-empty, it is attached to every RPC as
// per-RPC credentials.
//
// It returns a wrapped error if the connection cannot be created, and nil
// otherwise.
//
// NOTE(review): the connection is held only by the generated clients and is
// not stored on s, so nothing in this file can close it — confirm whether
// that is intentional for this test-only proxy.
func (s *SecureSessionHTTPService) connectToGRPCServer(address, authToken string) error {
	dialOpts := make([]grpc.DialOption, 0, 2)
	dialOpts = append(dialOpts, grpc.WithInsecure())

	// Only attach bearer-token credentials when a token was supplied.
	if authToken != "" {
		creds := ekmToken{token: authToken}
		dialOpts = append(dialOpts, grpc.WithPerRPCCredentials(creds))
	}

	clientConn, dialErr := grpc.Dial(address, dialOpts...)
	if dialErr != nil {
		return fmt.Errorf("error creating gRPC client connection: %w", dialErr)
	}

	s.sessionClient = ssgrpc.NewConfidentialEkmSessionEstablishmentServiceClient(clientConn)
	s.wrapClient = cwgrpc.NewConfidentialWrapUnwrapServiceClient(clientConn)
	return nil
}