lib/scim/scim_group.go (448 lines of code) (raw):
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package scim
import (
"fmt"
"io"
"net/http"
"net/mail"
"regexp"
"sort"
"strconv"
"strings"
"google.golang.org/grpc/codes" /* copybara-comment */
"google.golang.org/grpc/status" /* copybara-comment */
"github.com/golang/protobuf/jsonpb" /* copybara-comment */
"github.com/golang/protobuf/proto" /* copybara-comment */
"github.com/GoogleCloudPlatform/healthcare-federated-access-services/lib/errutil" /* copybara-comment: errutil */
"github.com/GoogleCloudPlatform/healthcare-federated-access-services/lib/handlerfactory" /* copybara-comment: handlerfactory */
"github.com/GoogleCloudPlatform/healthcare-federated-access-services/lib/httputils" /* copybara-comment: httputils */
"github.com/GoogleCloudPlatform/healthcare-federated-access-services/lib/storage" /* copybara-comment: storage */
"github.com/GoogleCloudPlatform/healthcare-federated-access-services/lib/strutil" /* copybara-comment: strutil */
spb "github.com/GoogleCloudPlatform/healthcare-federated-access-services/proto/scim/v2" /* copybara-comment: go_proto */
)
var (
	// scimGroupFilterMap maps lowercase SCIM filter attribute names to accessors
	// that extract the corresponding string value from a *spb.Group.
	scimGroupFilterMap = map[string]func(p proto.Message) string{
		"displayname": func(p proto.Message) string { return groupProto(p).DisplayName },
		"id":          func(p proto.Message) string { return groupProto(p).Id },
		"$ref":        func(p proto.Message) string { return groupRef(groupProto(p)) },
	}

	// scimMemberFilterMap maps lowercase SCIM filter attribute names to accessors
	// that extract the corresponding string value from a *spb.Member. It is used
	// to filter the member listing of a single group.
	scimMemberFilterMap = map[string]func(p proto.Message) string{
		"member.display": func(p proto.Message) string { return memberProto(p).Display },
		"member.issuer":  func(p proto.Message) string { return memberProto(p).ExtensionIssuer },
		"member.subject": func(p proto.Message) string { return memberProto(p).ExtensionSubject },
		"member.type":    func(p proto.Message) string { return memberProto(p).Type },
		"member.value":   func(p proto.Message) string { return memberProto(p).Value },
		"$ref":           func(p proto.Message) string { return memberRef(memberProto(p)) },
	}

	// scimGroupsFilterMap maps lowercase SCIM filter attribute names to accessors
	// on *spb.Group. It is used when listing the groups collection.
	scimGroupsFilterMap = map[string]func(p proto.Message) string{
		"displayname": func(p proto.Message) string { return groupProto(p).DisplayName },
		"id":          func(p proto.Message) string { return groupProto(p).Id },
		"externalid":  func(p proto.Message) string { return groupProto(p).ExternalId },
	}

	// memberPathRE matches a PATCH path of the form `members[$ref eq "<value>"]`
	// and captures <value> in group 1.
	memberPathRE = regexp.MustCompile(`^members\[\$ref eq "(.*)"\]$`)
)
////////////////////////////////////////////////////////////
// GroupFactory returns handler options for the single-group endpoint rooted at
// groupPath. Each request is served by a fresh GroupHandler backed by store.
func GroupFactory(store storage.Store, groupPath string) *handlerfactory.Options {
	newService := func() handlerfactory.Service {
		return NewGroupHandler(store)
	}
	return &handlerfactory.Options{
		TypeName:            "group",
		PathPrefix:          groupPath,
		HasNamedIdentifiers: true,
		Service:             newService,
	}
}
// GroupHandler handles SCIM group requests.
//
// Per-request state:
//   - item:  the group loaded by LookupItem (and returned by Get).
//   - save:  the group record that Save writes, or nil when nothing is saved.
//   - input: the parsed request body for POST and PUT.
//   - patch: the parsed request body for PATCH.
//   - tx:    the storage transaction captured in Setup.
type GroupHandler struct {
	item, save, input *spb.Group
	patch             *spb.Patch
	scim              *Scim
	store             storage.Store
	tx                storage.Tx
}
// NewGroupHandler creates a handler for one SCIM group request, starting from an
// empty group item.
func NewGroupHandler(store storage.Store) *GroupHandler {
	s := New(store)
	h := &GroupHandler{item: &spb.Group{}}
	h.store = store
	h.scim = s
	return h
}
// Setup sets up the handler.
//
// It parses the request form and rejects it with 400 if that fails. A
// malformed query string must not be processed with partially parsed
// parameters. For POST and PUT it decodes the body into h.input as a
// spb.Group. For PATCH it decodes the body into h.patch as a spb.Patch. An
// empty body (io.EOF) is tolerated here and left for NormalizeInput to
// validate. The transaction is stored for the remaining handler methods.
func (h *GroupHandler) Setup(r *http.Request, tx storage.Tx) (int, error) {
	if err := r.ParseForm(); err != nil {
		return http.StatusBadRequest, err
	}
	switch r.Method {
	case http.MethodPost, http.MethodPut:
		h.input = &spb.Group{}
		if err := jsonpb.Unmarshal(r.Body, h.input); err != nil && err != io.EOF {
			return http.StatusBadRequest, err
		}
	case http.MethodPatch:
		h.patch = &spb.Patch{}
		if err := jsonpb.Unmarshal(r.Body, h.patch); err != nil && err != io.EOF {
			return http.StatusBadRequest, err
		}
	}
	h.tx = tx
	return http.StatusOK, nil
}
// LookupItem loads the named group for the request's realm into h.item.
// It reports false when the load fails or finds no group; the error itself is
// not surfaced.
func (h *GroupHandler) LookupItem(r *http.Request, name string, vars map[string]string) bool {
	found, err := h.scim.LoadGroup(name, getRealm(r), h.tx)
	if err == nil && found != nil {
		h.item = found
		return true
	}
	return false
}
// NormalizeInput validates request schemas and normalizes the input group.
//
// PATCH bodies must declare exactly the SCIM patch schema. POST and PUT
// bodies must declare exactly the SCIM group schema. For requests with a body
// group (POST/PUT):
//   - a missing id is filled in from the URL name;
//   - an id that differs from the URL name is rejected;
//   - every member must be non-nil and is validated and normalized in place
//     by normalizeMember.
func (h *GroupHandler) NormalizeInput(r *http.Request, name string, vars map[string]string) error {
	switch r.Method {
	case http.MethodPatch:
		if len(h.patch.Schemas) != 1 || h.patch.Schemas[0] != scimPatchSchema {
			return fmt.Errorf("PATCH requires schemas set to only be %q", scimPatchSchema)
		}
	case http.MethodPost, http.MethodPut:
		if len(h.input.Schemas) != 1 || h.input.Schemas[0] != scimGroupSchema {
			return fmt.Errorf("%s requires schemas set to only be %q", strings.ToUpper(r.Method), scimGroupSchema)
		}
	}
	if h.input == nil {
		return nil
	}
	switch {
	case h.input.Id == "":
		h.input.Id = name
	case h.input.Id != name:
		// The group name comes from the URL; the body may not rename the group.
		return errutil.NewError(codes.InvalidArgument, errutil.ErrorPath("scim", "groups", name, "id"), fmt.Sprintf("id %q must match the group name %q", h.input.Id, name))
	}
	for i, member := range h.input.Members {
		if member == nil {
			return errutil.NewIndexError(codes.InvalidArgument, errutil.ErrorPath("scim", "groups", name), i, "member must not be empty")
		}
		if err := h.normalizeMember(member, name, i); err != nil {
			return err
		}
	}
	return nil
}
// Get is a GET request.
//
// It returns the group held in h.item together with its members. Members are
// read from the separate GroupMemberDatatype records and are not taken from
// the group record. The member listing honors the SCIM query parameters
// "filter" (attributes in scimMemberFilterMap), "startIndex" (1-based) and
// "count". Returned members are sorted by value.
func (h *GroupHandler) Get(r *http.Request, name string) (proto.Message, error) {
	filters, err := storage.BuildFilters(httputils.QueryParam(r, "filter"), scimMemberFilterMap)
	if err != nil {
		return nil, err
	}
	// "startIndex" is a 1-based starting location, to be converted to an offset for the query.
	// RFC 7644 section 3.4.2.4: a value less than 1 is interpreted as 1, which also keeps the
	// offset from going negative.
	start := httputils.QueryParamInt(r, "startIndex")
	if start < 1 {
		start = 1
	}
	offset := start - 1
	// "count" is the number of results desired on this request's page. When absent the
	// storage default applies; a negative value is interpreted as 0 (RFC 7644 section 3.4.2.4).
	pageSize := httputils.QueryParamInt(r, "count")
	if len(httputils.QueryParam(r, "count")) == 0 {
		pageSize = storage.DefaultPageSize
	} else if pageSize < 0 {
		pageSize = 0
	}
	results, err := h.store.MultiReadTx(storage.GroupMemberDatatype, getRealm(r), name, storage.MatchAllIDs, filters, offset, pageSize, &spb.Member{}, h.tx)
	if err != nil {
		return nil, err
	}
	// Index members by value and sort the values so the response order is deterministic.
	members := make(map[string]*spb.Member, len(results.Entries))
	keys := make([]string, 0, len(results.Entries))
	for _, entry := range results.Entries {
		member, ok := entry.Item.(*spb.Member)
		if !ok {
			continue
		}
		member.Ref = member.Value
		members[member.Value] = member
		keys = append(keys, member.Value)
	}
	sort.Strings(keys)
	for _, key := range keys {
		h.item.Members = append(h.item.Members, members[key])
	}
	return h.item, nil
}
// Post is a POST request.
//
// It marks the request body group for saving and writes each of its members,
// after normalization, as its own GroupMemberDatatype record keyed by the
// member value. The group record itself is written later by Save. Nothing is
// returned to the caller.
func (h *GroupHandler) Post(r *http.Request, name string) (proto.Message, error) {
	h.save = h.input
	for idx, m := range h.save.Members {
		if err := h.normalizeMember(m, name, idx); err != nil {
			return nil, err
		}
		err := h.store.WriteTx(storage.GroupMemberDatatype, getRealm(r), name, m.Value, storage.LatestRev, m, nil, h.tx)
		if err != nil {
			return nil, fmt.Errorf("writing group member %q: %v", m.Value, err)
		}
	}
	return nil, nil
}
// Put is a PUT request.
//
// PUT replaces the group. It first deletes the existing group record and all
// of its members, then recreates everything from the request body exactly as
// Post does.
func (h *GroupHandler) Put(r *http.Request, name string) (proto.Message, error) {
	_, err := h.Remove(r, name)
	if err != nil {
		return nil, err
	}
	return h.Post(r, name)
}
// Patch is a PATCH request.
//
// It applies h.patch.Operations in order against a copy of the loaded group
// (h.save), which Save later writes. Supported operations:
//   - path "displayName" with op "add" or "replace": sets the display name. Op
//     "remove" and empty values are rejected.
//   - path "members" with op "add": builds a member from the operation's object
//     via patchMember and writes it immediately as a GroupMemberDatatype record
//     through h.tx.
//   - path `members[$ref eq "<value>"]` with op "remove": deletes that member's
//     record. It fails if the member does not exist.
//
// Any other path or op is rejected with an InvalidArgument index error that
// carries the operation index. The response is produced by Get. That
// re-reads members from storage through h.tx, presumably seeing the writes
// made above (depends on storage.Tx semantics — TODO confirm). Get also
// applies any "filter"/"startIndex"/"count" query parameters to that listing.
func (h *GroupHandler) Patch(r *http.Request, name string) (proto.Message, error) {
	h.save = &spb.Group{}
	proto.Merge(h.save, h.item)
	// memberCounter counts members added by this request. patchMember uses it
	// as the member index in its error paths, which differs from the operation
	// index i.
	memberCounter := 0
	for i, patch := range h.patch.Operations {
		path := patch.Path
		// Collapse `members[$ref eq "..."]` paths into the synthetic "member" case.
		if memberPathRE.MatchString(path) {
			path = "member"
		}
		// src/dst are only set for scalar attribute edits, which are applied after the switch.
		src := ""
		var dst *string
		switch path {
		case "displayName":
			src = patchSource(patch.Value)
			dst = &h.save.DisplayName
			if patch.Op == "remove" || len(src) == 0 {
				return nil, errutil.NewIndexError(codes.InvalidArgument, errutil.ErrorPath("scim", "groups", name, path), i, fmt.Sprintf("value must not be empty"))
			}
		case "members":
			if patch.Op != "add" {
				return nil, errutil.NewIndexError(codes.InvalidArgument, errutil.ErrorPath("scim", "groups", name, path), i, fmt.Sprintf("op %q is not valid", patch.Op))
			}
			member, err := h.patchMember(patch.Object, name, memberCounter)
			if err != nil {
				return nil, err
			}
			memberCounter++
			if err := h.store.WriteTx(storage.GroupMemberDatatype, getRealm(r), name, member.Value, storage.LatestRev, member, nil, h.tx); err != nil {
				return nil, errutil.NewIndexError(codes.InvalidArgument, errutil.ErrorPath("scim", "groups", name, path), i, err.Error())
			}
		case "member":
			if patch.Op != "remove" {
				return nil, errutil.NewIndexError(codes.InvalidArgument, errutil.ErrorPath("scim", "groups", name, path), i, fmt.Sprintf("op %q is not valid", patch.Op))
			}
			// Re-match on the original path to extract the member value from the $ref filter.
			match := memberPathRE.FindStringSubmatch(patch.Path)
			if len(match) < 2 {
				return nil, errutil.NewIndexError(codes.InvalidArgument, errutil.ErrorPath("scim", "groups", name, path), i, fmt.Sprintf("invalid member path %q", patch.Path))
			}
			memberName := match[1]
			if err := h.store.DeleteTx(storage.GroupMemberDatatype, getRealm(r), name, memberName, storage.LatestRev, h.tx); err != nil {
				if storage.ErrNotFound(err) {
					return nil, errutil.NewIndexError(codes.InvalidArgument, errutil.ErrorPath("scim", "groups", name, path), i, fmt.Sprintf("%q is not a member of the group", memberName))
				}
				return nil, errutil.NewIndexError(codes.InvalidArgument, errutil.ErrorPath("scim", "groups", name, path), i, err.Error())
			}
		default:
			return nil, errutil.NewIndexError(codes.InvalidArgument, errutil.ErrorPath("scim", "groups", name, path), i, fmt.Sprintf("invalid path %q", patch.Path))
		}
		// Member operations were fully handled above; only scalar edits continue.
		if dst == nil {
			continue
		}
		// For displayName, empty values and "remove" were already rejected above,
		// so this check only guards any future scalar paths.
		if patch.Op != "remove" && len(src) == 0 {
			return nil, errutil.NewIndexError(codes.InvalidArgument, errutil.ErrorPath("scim", "groups", name, path), i, fmt.Sprintf("cannot set an empty value"))
		}
		switch patch.Op {
		case "add":
			fallthrough
		case "replace":
			*dst = src
		case "remove":
			*dst = ""
		default:
			return nil, errutil.NewIndexError(codes.InvalidArgument, errutil.ErrorPath("scim", "groups", name, path), i, fmt.Sprintf("invalid op %q", patch.Op))
		}
	}
	// Get() renders h.item, so point h.item at a copy of h.save that carries the edits.
	// A deep copy is required: Save() clears h.save.Members before writing, but the
	// response returned to the client must include members.
	h.item = proto.Clone(h.save).(*spb.Group)
	return h.Get(r, name)
}
// Remove is a DELETE request.
//
// Members are stored as separate GroupMemberDatatype records under the group
// name. They are deleted first. The group record itself is deleted only if
// that succeeds.
func (h *GroupHandler) Remove(r *http.Request, name string) (proto.Message, error) {
	err := h.store.MultiDeleteTx(storage.GroupMemberDatatype, getRealm(r), name, h.tx)
	if err == nil {
		err = h.store.DeleteTx(storage.GroupDatatype, getRealm(r), name, storage.DefaultID, storage.LatestRev, h.tx)
	}
	return nil, err
}
// CheckIntegrity checks that any modifications make sense before applying them.
// Group requests currently need no cross-object checks, so this always passes.
func (h *GroupHandler) CheckIntegrity(r *http.Request) *status.Status {
	return nil
}
// Save writes the group record prepared by Post, Put or Patch (h.save).
// It does nothing when no record was prepared.
//
// Membership is persisted as separate GroupMemberDatatype records, so
// h.save.Members is cleared in place before the write. The write uses the
// transaction captured in Setup (h.tx).
func (h *GroupHandler) Save(r *http.Request, tx storage.Tx, name string, vars map[string]string, desc, typeName string) error {
	group := h.save
	if group == nil {
		return nil
	}
	group.Members = nil
	return h.store.WriteTx(storage.GroupDatatype, getRealm(r), name, storage.DefaultID, storage.LatestRev, group, nil, h.tx)
}
// patchMember builds and normalizes a group member from the object of a PATCH
// "add" operation on the "members" path.
//
// Keys read from object:
//   - "type": defaults to "User";
//   - "value": the member's email address;
//   - "issuer" and "subject".
//
// idx is the member's position among the members added by the current PATCH.
// It is used only in error paths. A missing object is reported as an
// InvalidArgument index error, the same kind of error normalizeMember returns
// for its validation failures.
func (h *GroupHandler) patchMember(object map[string]string, name string, idx int) (*spb.Member, error) {
	if object == nil {
		return nil, errutil.NewIndexError(codes.InvalidArgument, errutil.ErrorPath("scim", "groups", name, "members"), idx, "member not provided")
	}
	typ := object["type"]
	if typ == "" {
		typ = "User"
	}
	member := &spb.Member{
		Type:             typ,
		Value:            object["value"],
		ExtensionIssuer:  object["issuer"],
		ExtensionSubject: object["subject"],
	}
	if err := h.normalizeMember(member, name, idx); err != nil {
		return nil, err
	}
	return member, nil
}
// normalizeMember validates member in place and normalizes its fields. name is
// the group name and idx the member's index; both are used in error paths.
//
// Rules, in order:
//   - an empty Type becomes "User"; any other non-"User" type is rejected;
//   - Value must parse as an RFC 5322 address and is replaced by the bare
//     address. A display name in the address fills Display when Display is
//     empty;
//   - a Display containing "@" is rejected unless it equals Value, in which
//     case it is cleared;
//   - a non-empty ExtensionIssuer must be a URL of at most 256 bytes;
//   - ExtensionSubject may be at most 60 bytes.
//
// All failures are InvalidArgument index errors.
func (h *GroupHandler) normalizeMember(member *spb.Member, name string, idx int) error {
	const (
		maxIssuerLen  = 256
		maxSubjectLen = 60
	)
	invalid := func(field, msg string) error {
		return errutil.NewIndexError(codes.InvalidArgument, errutil.ErrorPath("scim", "groups", name, "members", strconv.Itoa(idx), field), idx, msg)
	}

	if member.Type == "" {
		member.Type = "User"
	} else if member.Type != "User" {
		return invalid("type", "invalid member type")
	}

	addr, err := mail.ParseAddress(member.Value)
	if err != nil {
		return invalid("value", fmt.Sprintf("%q must be an email address", member.Value))
	}
	member.Value = addr.Address
	if member.Display == "" && addr.Name != "" {
		member.Display = strings.TrimSpace(addr.Name)
	}

	// Email addresses are not accepted as display names: a different address is
	// an error, while a repeat of the value field is simply dropped.
	if strings.Contains(member.Display, "@") {
		if member.Display != member.Value {
			return invalid("display", "display name as an email address not allowed")
		}
		member.Display = ""
	}

	issuer := member.ExtensionIssuer
	if issuer != "" && !strutil.IsURL(issuer) {
		return invalid("issuer", fmt.Sprintf("invalid member issuer %q", issuer))
	}
	if len(issuer) > maxIssuerLen {
		return invalid("issuer", fmt.Sprintf("member issuer %q exceeds maximum length", issuer))
	}
	if subject := member.ExtensionSubject; len(subject) > maxSubjectLen {
		return invalid("subject", fmt.Sprintf("member subject %q exceeds maximum length", subject))
	}
	return nil
}
////////////////////////////////////////////////////////////
// GroupsFactory returns handler options for the groups collection endpoint
// rooted at path. The collection has no named identifiers, and each request is
// served by a fresh GroupsHandler backed by store.
func GroupsFactory(store storage.Store, path string) *handlerfactory.Options {
	opts := &handlerfactory.Options{
		TypeName:            "groups",
		PathPrefix:          path,
		HasNamedIdentifiers: false,
	}
	opts.Service = func() handlerfactory.Service {
		return NewGroupsHandler(store)
	}
	return opts
}
// GroupsHandler handles requests on the SCIM groups collection. Only GET
// (listing) is supported.
type GroupsHandler struct {
	scim  *Scim
	store storage.Store // backing store that groups are listed from
	tx    storage.Tx    // transaction captured in Setup
}
// NewGroupsHandler creates a handler for one SCIM groups collection request.
func NewGroupsHandler(store storage.Store) *GroupsHandler {
	s := New(store)
	return &GroupsHandler{scim: s, store: store}
}
// Setup sets up the handler.
//
// It parses the request form so query parameters such as "filter",
// "startIndex" and "count" are available. A malformed request is rejected with
// 400 rather than processed with partially parsed parameters. The
// transaction is stored for the remaining handler methods.
func (h *GroupsHandler) Setup(r *http.Request, tx storage.Tx) (int, error) {
	if err := r.ParseForm(); err != nil {
		return http.StatusBadRequest, err
	}
	h.tx = tx
	return http.StatusOK, nil
}
// LookupItem reports whether the named object exists. The groups collection
// itself always exists, so this always reports true.
func (h *GroupsHandler) LookupItem(r *http.Request, name string, vars map[string]string) bool {
	const collectionExists = true
	return collectionExists
}
// NormalizeInput sets up basic structure of request input objects if absent.
// The collection endpoint accepts no request body, so there is nothing to
// normalize.
func (h *GroupsHandler) NormalizeInput(r *http.Request, name string, vars map[string]string) error {
	var noErr error
	return noErr
}
// Get is a GET request.
//
// It lists the groups of the request's realm as a SCIM ListResponse. The
// listing honors the query parameters "filter" (attributes in
// scimGroupsFilterMap), "startIndex" (1-based) and "count". Resources are
// sorted by group id.
func (h *GroupsHandler) Get(r *http.Request, name string) (proto.Message, error) {
	filters, err := storage.BuildFilters(httputils.QueryParam(r, "filter"), scimGroupsFilterMap)
	if err != nil {
		return nil, err
	}
	// "startIndex" is a 1-based starting location, to be converted to an offset for the query.
	// RFC 7644 section 3.4.2.4: a value less than 1 is interpreted as 1. This also keeps
	// offset non-negative so the uint32 conversions below cannot wrap.
	start := httputils.QueryParamInt(r, "startIndex")
	if start < 1 {
		start = 1
	}
	offset := start - 1
	// "count" is the number of results desired on this request's page. When absent the
	// storage default applies; a negative value is interpreted as 0 (RFC 7644 section 3.4.2.4).
	pageSize := httputils.QueryParamInt(r, "count")
	if len(httputils.QueryParam(r, "count")) == 0 {
		pageSize = storage.DefaultPageSize
	} else if pageSize < 0 {
		pageSize = 0
	}
	results, err := h.store.MultiReadTx(storage.GroupDatatype, getRealm(r), storage.MatchAllGroups, storage.MatchAllIDs, filters, offset, pageSize, &spb.Group{}, h.tx)
	if err != nil {
		return nil, err
	}
	// Index groups by id and sort the ids so the response order is deterministic.
	groups := make(map[string]*spb.Group, len(results.Entries))
	names := make([]string, 0, len(results.Entries))
	for _, entry := range results.Entries {
		group, ok := entry.Item.(*spb.Group)
		if !ok {
			continue
		}
		groups[group.Id] = group
		names = append(names, group.Id)
	}
	sort.Strings(names)
	var list []*spb.Group
	for _, id := range names {
		list = append(list, groups[id])
	}
	resp := &spb.ListGroupsResponse{
		Schemas:      []string{scimListSchema},
		TotalResults: uint32(offset + results.MatchCount),
		ItemsPerPage: uint32(len(list)),
		StartIndex:   uint32(start),
		Resources:    list,
	}
	return resp, nil
}
// Post is a POST request. Creating resources on the groups collection is not
// supported; groups are created through the single-group endpoint.
func (h *GroupsHandler) Post(r *http.Request, name string) (proto.Message, error) {
	err := fmt.Errorf("POST not allowed")
	return nil, err
}
// Put is a PUT request. Replacing the groups collection is not supported.
func (h *GroupsHandler) Put(r *http.Request, name string) (proto.Message, error) {
	err := fmt.Errorf("PUT not allowed")
	return nil, err
}
// Patch is a PATCH request. Patching the groups collection is not supported.
func (h *GroupsHandler) Patch(r *http.Request, name string) (proto.Message, error) {
	err := fmt.Errorf("PATCH not allowed")
	return nil, err
}
// Remove is a DELETE request. Deleting the groups collection is not supported.
func (h *GroupsHandler) Remove(r *http.Request, name string) (proto.Message, error) {
	err := fmt.Errorf("DELETE not allowed")
	return nil, err
}
// CheckIntegrity checks that any modifications make sense before applying them.
// The collection endpoint makes no modifications, so this always passes.
func (h *GroupsHandler) CheckIntegrity(r *http.Request) *status.Status {
	return nil
}
// Save will save any modifications done for the request. The collection
// endpoint is read-only, so there is never anything to write.
func (h *GroupsHandler) Save(r *http.Request, tx storage.Tx, name string, vars map[string]string, desc, typeName string) error {
	var noErr error
	return noErr
}
////////////////////////////////////////////////////////////
// memberProto returns p as a *spb.Member. When p holds some other message
// type it returns a fresh empty Member, so filter accessors see empty fields.
func memberProto(p proto.Message) *spb.Member {
	if m, ok := p.(*spb.Member); ok {
		return m
	}
	return &spb.Member{}
}
// groupProto returns p as a *spb.Group. When p holds some other message type
// it returns a fresh empty Group, so filter accessors see empty fields.
func groupProto(p proto.Message) *spb.Group {
	if g, ok := p.(*spb.Group); ok {
		return g
	}
	return &spb.Group{}
}
// groupRef returns the "$ref" filter value for a group: "group/" followed by its id.
func groupRef(group *spb.Group) string {
	const prefix = "group/"
	return prefix + group.Id
}
// memberRef returns the "$ref" filter value for a member: "member/" followed by its value.
func memberRef(member *spb.Member) string {
	const prefix = "member/"
	return prefix + member.Value
}