router/core/authorizer.go (163 lines of code) (raw):
package core
import (
"context"
"encoding/json"
"io"
"slices"
"sync"
nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1"
"github.com/wundergraph/cosmo/router/pkg/authentication"
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve"
)
// CosmoAuthorizerOptions configures a CosmoAuthorizer created by NewCosmoAuthorizer.
type CosmoAuthorizerOptions struct {
	// FieldConfigurations holds per-field settings. The authorizer finds the
	// scopes required for a field by matching its TypeName and FieldName.
	FieldConfigurations []*nodev1.FieldConfiguration
	// RejectOperationIfUnauthorized makes a failed field check return
	// ErrUnauthorized instead of a per-field AuthorizationDeny.
	RejectOperationIfUnauthorized bool
}
// NewCosmoAuthorizer builds a CosmoAuthorizer from opts.
// opts must be non-nil; its fields are copied into the authorizer as-is.
// The FieldConfigurations slice is shared with opts, not cloned.
func NewCosmoAuthorizer(opts *CosmoAuthorizerOptions) *CosmoAuthorizer {
	authorizer := &CosmoAuthorizer{}
	authorizer.fieldConfigurations = opts.FieldConfigurations
	authorizer.rejectUnauthorized = opts.RejectOperationIfUnauthorized
	return authorizer
}
// CosmoAuthorizer authorizes access to GraphQL fields. It compares the scopes of
// the authenticated request against the scopes configured for each field. It also
// records denied fields so they can be rendered as a response extension.
type CosmoAuthorizer struct {
	// fieldConfigurations is searched linearly by requiredScopesForField.
	fieldConfigurations []*nodev1.FieldConfiguration
	// rejectUnauthorized turns any per-field deny into ErrUnauthorized
	// (see handleRejectUnauthorized).
	rejectUnauthorized bool
}
// HasResponseExtensionData reports whether at least one field was denied for
// missing scopes during this request. It returns false when ctx was not prepared
// with WithAuthorizationExtension.
func (a *CosmoAuthorizer) HasResponseExtensionData(ctx *resolve.Context) bool {
	if ext := a.getAuthorizationExtension(ctx); ext != nil {
		return len(ext.MissingScopes) > 0
	}
	return false
}
// RenderResponseExtension writes the request's AuthorizationExtension to out as
// JSON, with no trailing newline. It writes nothing and returns nil when ctx
// carries no extension. Marshalling and write errors are returned unchanged.
func (a *CosmoAuthorizer) RenderResponseExtension(ctx *resolve.Context, out io.Writer) error {
	ext := a.getAuthorizationExtension(ctx)
	if ext == nil {
		return nil
	}
	payload, err := json.Marshal(ext)
	if err != nil {
		return err
	}
	if _, err := out.Write(payload); err != nil {
		return err
	}
	return nil
}
// getAuth extracts the authentication attached to ctx.
// It returns (true, scopes) when an authentication is present, and (false, nil)
// otherwise.
func (a *CosmoAuthorizer) getAuth(ctx context.Context) (isAuthenticated bool, scopes []string) {
	if auth := authentication.FromContext(ctx); auth != nil {
		return true, auth.Scopes()
	}
	return false, nil
}
// handleRejectUnauthorized applies the reject-operation policy to a field check
// result. An allowed field (nil result) passes through as (nil, nil). A denied
// field becomes ErrUnauthorized when rejectUnauthorized is set; otherwise the
// deny is returned unchanged.
func (a *CosmoAuthorizer) handleRejectUnauthorized(result *resolve.AuthorizationDeny) (*resolve.AuthorizationDeny, error) {
	switch {
	case result == nil:
		return nil, nil
	case a.rejectUnauthorized:
		return nil, ErrUnauthorized
	default:
		return result, nil
	}
}
// AuthorizePreFetch checks whether the current request may access coordinate.
// It returns a non-nil AuthorizationDeny when access is denied. When the
// authorizer is configured to reject the whole operation, it returns
// ErrUnauthorized instead. dataSourceID and input are not consulted by this
// implementation.
func (a *CosmoAuthorizer) AuthorizePreFetch(ctx *resolve.Context, dataSourceID string, input json.RawMessage, coordinate resolve.GraphCoordinate) (result *resolve.AuthorizationDeny, err error) {
	return a.authorizeCoordinate(ctx, coordinate)
}

// AuthorizeObjectField applies the same check as AuthorizePreFetch to coordinate.
// dataSourceID and object are not consulted by this implementation.
func (a *CosmoAuthorizer) AuthorizeObjectField(ctx *resolve.Context, dataSourceID string, object json.RawMessage, coordinate resolve.GraphCoordinate) (result *resolve.AuthorizationDeny, err error) {
	return a.authorizeCoordinate(ctx, coordinate)
}

// authorizeCoordinate is the single shared implementation of both Authorize*
// entry points. It reads the request's authentication, looks up the field's
// required scopes, validates them, and then applies the reject-operation policy.
func (a *CosmoAuthorizer) authorizeCoordinate(ctx *resolve.Context, coordinate resolve.GraphCoordinate) (*resolve.AuthorizationDeny, error) {
	isAuthenticated, actual := a.getAuth(ctx.Context())
	required := a.requiredScopesForField(coordinate)
	return a.handleRejectUnauthorized(a.validateScopes(ctx, coordinate, required, isAuthenticated, actual))
}
// validateScopes decides whether the request may access coordinate.
//
// An unauthenticated request is always denied with "not authenticated", even
// when the field lists no scopes. An authenticated request is allowed when no
// scopes are required. It is also allowed when actual contains every scope of at
// least one entry in requiredOrScopes (an OR of AND-groups); an empty AND-group
// is trivially satisfied. Otherwise the denial is recorded via addMissingScopes
// and a "missing required scopes" deny is returned.
func (a *CosmoAuthorizer) validateScopes(ctx *resolve.Context, coordinate resolve.GraphCoordinate, requiredOrScopes []*nodev1.Scopes, isAuthenticated bool, actual []string) (result *resolve.AuthorizationDeny) {
	if !isAuthenticated {
		return &resolve.AuthorizationDeny{Reason: "not authenticated"}
	}
	if len(requiredOrScopes) == 0 {
		return nil
	}

	hasAllScopes := func(group *nodev1.Scopes) bool {
		for _, scope := range group.RequiredAndScopes {
			if !slices.Contains(actual, scope) {
				return false
			}
		}
		return true
	}
	if slices.ContainsFunc(requiredOrScopes, hasAllScopes) {
		return nil
	}

	a.addMissingScopes(ctx, coordinate, requiredOrScopes, actual)
	return &resolve.AuthorizationDeny{Reason: "missing required scopes"}
}
// addMissingScopes records in the per-request authorization extension that
// coordinate was denied because actual satisfied none of requiredOrScopes.
// It is a no-op when ctx was not prepared with WithAuthorizationExtension.
//
// The actual scopes are captured only on the first recorded denial. Each
// coordinate (type name + field name) is reported at most once.
func (a *CosmoAuthorizer) addMissingScopes(ctx *resolve.Context, coordinate resolve.GraphCoordinate, requiredOrScopes []*nodev1.Scopes, actual []string) {
	extensionCtx, ok := ctx.Context().Value(authorizationExtensionKey{}).(*authorizationExtensionCtx)
	if !ok || extensionCtx == nil {
		return
	}

	// Building the error only reads the arguments, so do it before taking the lock.
	newMissingScopesError := a.missingScopesError(coordinate, requiredOrScopes)

	extensionCtx.mux.Lock()
	// defer guarantees the mutex is released even if anything below panics.
	defer extensionCtx.mux.Unlock()

	ext := &extensionCtx.extension
	if ext.ActualScopes == nil {
		if len(actual) == 0 {
			// A non-nil empty slice serializes as [] rather than null.
			ext.ActualScopes = make([]string, 0)
		} else {
			ext.ActualScopes = actual
		}
	}

	alreadyReported := slices.ContainsFunc(ext.MissingScopes, func(existing MissingScopesError) bool {
		return existing.Coordinate.TypeName == newMissingScopesError.Coordinate.TypeName &&
			existing.Coordinate.FieldName == newMissingScopesError.Coordinate.FieldName
	})
	if !alreadyReported {
		ext.MissingScopes = append(ext.MissingScopes, newMissingScopesError)
	}
}
// getAuthorizationExtension returns a pointer to the AuthorizationExtension
// stored on ctx by WithAuthorizationExtension, or nil if none is present.
// The returned pointer refers to the shared per-request state and is read
// without taking its mutex.
func (a *CosmoAuthorizer) getAuthorizationExtension(ctx *resolve.Context) *AuthorizationExtension {
	value := ctx.Context().Value(authorizationExtensionKey{})
	if value == nil {
		return nil
	}
	return &value.(*authorizationExtensionCtx).extension
}
// authorizationExtensionCtx is the per-request state stored in the context by
// WithAuthorizationExtension. addMissingScopes holds mux while mutating
// extension; getAuthorizationExtension reads extension without locking.
// It contains a sync.Mutex and is only ever handled through a pointer.
type authorizationExtensionCtx struct {
	extension AuthorizationExtension
	mux       sync.Mutex
}
// authorizationExtensionKey is the context key for the *authorizationExtensionCtx.
// The type is unexported, so only this package can store or read values under it.
type authorizationExtensionKey struct{}
// WithAuthorizationExtension returns a resolve.Context whose underlying context
// carries fresh, empty authorization-extension state. Denials recorded by
// CosmoAuthorizer under that context can later be rendered by
// RenderResponseExtension.
func WithAuthorizationExtension(ctx *resolve.Context) *resolve.Context {
	state := &authorizationExtensionCtx{}
	return ctx.WithContext(context.WithValue(ctx.Context(), authorizationExtensionKey{}, state))
}
// AuthorizationExtension is the JSON payload written by RenderResponseExtension.
type AuthorizationExtension struct {
	// MissingScopes has one entry per denied coordinate; omitted from JSON when empty.
	MissingScopes []MissingScopesError `json:"missingScopes,omitempty"`
	// ActualScopes are the scopes the request carried. addMissingScopes sets it
	// to a non-nil slice on the first denial, so it then serializes as [] rather
	// than null.
	ActualScopes []string `json:"actualScopes"`
}
// MissingScopesError describes one denied coordinate and the alternative scope
// sets that would have granted access. The request needs every scope of at
// least one inner slice of RequiredOrScopes.
type MissingScopesError struct {
	Coordinate       resolve.GraphCoordinate `json:"coordinate"`
	RequiredOrScopes [][]string              `json:"required"`
}
// RequiredAndScopes is a JSON shape for one group of scopes that must all be present.
// NOTE(review): not referenced anywhere in this file (MissingScopesError uses
// [][]string instead); confirm whether other packages still use it.
type RequiredAndScopes struct {
	RequiredAndScopes []string `json:"and"`
}
// missingScopesError builds the extension entry for a denied coordinate. It
// flattens requiredOrScopes into plain string slices for JSON output.
func (a *CosmoAuthorizer) missingScopesError(coordinate resolve.GraphCoordinate, requiredOrScopes []*nodev1.Scopes) MissingScopesError {
	return MissingScopesError{
		Coordinate:       coordinate,
		RequiredOrScopes: a.requiredAndScopes(requiredOrScopes),
	}
}
// requiredAndScopes converts the protobuf OR-list of AND-groups into [][]string.
// Inner slices are shared with the protobuf messages, not copied. It returns nil
// for empty input.
func (a *CosmoAuthorizer) requiredAndScopes(requiredOrScopes []*nodev1.Scopes) [][]string {
	if len(requiredOrScopes) == 0 {
		return nil
	}
	groups := make([][]string, len(requiredOrScopes))
	for i, scopes := range requiredOrScopes {
		groups[i] = scopes.RequiredAndScopes
	}
	return groups
}
// requiredScopesForField returns the OR-list of AND-scope groups configured for
// coordinate's type and field. It returns nil when no configuration matches, or
// when the matching configuration has no authorization settings. Only the first
// matching configuration is consulted.
func (a *CosmoAuthorizer) requiredScopesForField(coordinate resolve.GraphCoordinate) []*nodev1.Scopes {
	for _, fc := range a.fieldConfigurations {
		if fc.GetTypeName() == coordinate.TypeName && fc.GetFieldName() == coordinate.FieldName {
			// The generated getters are nil-safe. A field configuration without an
			// AuthorizationConfiguration (e.g. one that only carries argument
			// settings) yields nil instead of a nil-pointer panic.
			return fc.GetAuthorizationConfiguration().GetRequiredOrScopes()
		}
	}
	return nil
}