auth/token.go (105 lines of code) (raw):
package gitalyauth
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"fmt"
"strconv"
"strings"
"time"
grpcmwauth "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
var (
//nolint:gochecknoglobals
// This infrastructure is required for testing purposes and there is no
// proper place to put it instead. While we could move it into the
// config, we certainly don't want to make it configurable for now, so
// it'd be a bad fit there.
tokenValidityDuration = 30 * time.Second
errUnauthenticated = status.Errorf(codes.Unauthenticated, "authentication required")
authErrors = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "gitaly_authentication_errors_total",
Help: "Counts of Gitaly request authentication errors",
},
[]string{"version", "error"},
)
)
const tokenVersionV2 = "v2"
func newPermissionDeniedError(reason string) error {
return status.Errorf(codes.PermissionDenied, "permission denied: %s", reason)
}
// TokenValidityDuration returns the duration for which any token will be
// valid. This is currently only used by our testing infrastructure.
func TokenValidityDuration() time.Duration {
return tokenValidityDuration
}
// SetTokenValidityDuration changes the duration for which any token will be
// valid. It only applies to newly created tokens.
func SetTokenValidityDuration(d time.Duration) {
tokenValidityDuration = d
}
// AuthInfo contains the authentication information coming from a request
type AuthInfo struct {
Version string
SignedMessage []byte
Message string
}
// CheckToken checks the 'authentication' header of incoming gRPC
// metadata in ctx. It returns nil if and only if the token matches
// secret.
func CheckToken(ctx context.Context, secret string, targetTime time.Time) error {
if len(secret) == 0 {
return status.Errorf(codes.Unauthenticated, "secret must not be empty")
}
authInfo, err := ExtractAuthInfo(ctx)
if err != nil {
return errUnauthenticated
}
if authInfo.Version != tokenVersionV2 {
return newPermissionDeniedError("invalid token version")
}
return v2HmacInfoValid(authInfo.Message, authInfo.SignedMessage, []byte(secret), targetTime, tokenValidityDuration)
}
// ExtractAuthInfo returns an `AuthInfo` with the data extracted from `ctx`
func ExtractAuthInfo(ctx context.Context) (*AuthInfo, error) {
token, err := grpcmwauth.AuthFromMD(ctx, "bearer")
if err != nil {
return nil, err
}
split := strings.SplitN(token, ".", 3)
if len(split) != 3 {
return nil, fmt.Errorf("invalid token format")
}
version, sig, msg := split[0], split[1], split[2]
decodedSig, err := hex.DecodeString(sig)
if err != nil {
return nil, err
}
return &AuthInfo{Version: version, SignedMessage: decodedSig, Message: msg}, nil
}
func countV2Error(message string) { authErrors.WithLabelValues(tokenVersionV2, message).Inc() }
func v2HmacInfoValid(message string, signedMessage, secret []byte, targetTime time.Time, tokenValidity time.Duration) error {
expectedHMAC := hmacSign(secret, message)
if !hmac.Equal(signedMessage, expectedHMAC) {
const reason = "wrong hmac signature"
countV2Error(reason)
return newPermissionDeniedError(reason)
}
timestamp, err := strconv.ParseInt(message, 10, 64)
if err != nil {
const reason = "cannot parse timestamp"
countV2Error(reason)
return newPermissionDeniedError(fmt.Sprintf("%s: %s", reason, err))
}
issuedAt := time.Unix(timestamp, 0)
lowerBound := targetTime.Add(-tokenValidity)
upperBound := targetTime.Add(tokenValidity)
if issuedAt.Before(lowerBound) {
const reason = "token has expired"
countV2Error(reason)
return newPermissionDeniedError(reason)
}
if issuedAt.After(upperBound) {
const reason = "token's validity window is in future"
countV2Error(reason)
return newPermissionDeniedError(reason)
}
return nil
}
func hmacSign(secret []byte, message string) []byte {
mac := hmac.New(sha256.New, secret)
// hash.Hash never returns an error.
_, _ = mac.Write([]byte(message))
return mac.Sum(nil)
}