traffic_ops/traffic_ops_golang/user/user.go (920 lines of code) (raw):

package user /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ import ( "database/sql" "encoding/json" "errors" "fmt" "net/http" "strconv" "strings" "time" "github.com/apache/trafficcontrol/v8/lib/go-log" "github.com/apache/trafficcontrol/v8/lib/go-rfc" "github.com/apache/trafficcontrol/v8/lib/go-tc" "github.com/apache/trafficcontrol/v8/lib/go-tc/tovalidate" "github.com/apache/trafficcontrol/v8/lib/go-util" "github.com/apache/trafficcontrol/v8/traffic_ops/traffic_ops_golang/api" "github.com/apache/trafficcontrol/v8/traffic_ops/traffic_ops_golang/auth" "github.com/apache/trafficcontrol/v8/traffic_ops/traffic_ops_golang/dbhelpers" "github.com/apache/trafficcontrol/v8/traffic_ops/traffic_ops_golang/tenant" "github.com/apache/trafficcontrol/v8/traffic_ops/traffic_ops_golang/util/ims" validation "github.com/go-ozzo/ozzo-validation" "github.com/go-ozzo/ozzo-validation/is" "github.com/jmoiron/sqlx" ) type TOUser struct { api.APIInfoImpl `json:"-"` tc.User } func (user TOUser) GetKeyFieldsInfo() []api.KeyFieldInfo { return []api.KeyFieldInfo{{Field: "id", Func: api.GetIntKey}} } func (user TOUser) GetKeys() (map[string]interface{}, bool) { if user.ID == nil { return map[string]interface{}{"id": 0}, false } return map[string]interface{}{"id": *user.ID}, true } func (user TOUser) GetAuditName() string { if user.Username != nil { return *user.Username } if user.ID != nil { return strconv.Itoa(*user.ID) } return "unknown" } func (user TOUser) GetType() string { return "user" } func (user *TOUser) SetKeys(keys map[string]interface{}) { i, _ := keys["id"].(int) // non-panicking type assertion user.ID = &i } func (user *TOUser) SetLastUpdated(t tc.TimeNoMod) { user.LastUpdated = &t } func (user *TOUser) NewReadObj() interface{} { return &tc.User{} } func (user *TOUser) ParamColumns() map[string]dbhelpers.WhereColumnInfo { return map[string]dbhelpers.WhereColumnInfo{ "id": dbhelpers.WhereColumnInfo{Column: "u.id", Checker: api.IsInt}, "role": dbhelpers.WhereColumnInfo{Column: "r.name"}, "tenant": dbhelpers.WhereColumnInfo{Column: "t.name"}, "username": dbhelpers.WhereColumnInfo{Column: "u.username"}, } } func (user *TOUser) Validate() (error, error) { validateErrs := validation.Errors{ "email": validation.Validate(user.Email, validation.Required, is.Email), "fullName": validation.Validate(user.FullName, validation.Required), "role": validation.Validate(user.Role, validation.Required), "username": validation.Validate(user.Username, validation.Required), "tenantID": validation.Validate(user.TenantID, validation.Required), } // Password is not required for update if user.LocalPassword != nil { _, err := auth.IsGoodLoginPair(*user.Username, *user.LocalPassword) if err != nil { return err, nil } } return util.JoinErrs(tovalidate.ToErrors(validateErrs)), nil } func (user *TOUser) postValidate() error { validateErrs := validation.Errors{ "localPasswd": validation.Validate(user.LocalPassword, validation.Required), } return util.JoinErrs(tovalidate.ToErrors(validateErrs)) } func postValidateV40(user tc.UserV4) error { validateErrs := validation.Errors{ "localPasswd": validation.Validate(user.LocalPassword, validation.Required), } return util.JoinErrs(tovalidate.ToErrors(validateErrs)) } // Note: Not using GenericCreate because Scan also needs to scan tenant and rolename func (user *TOUser) Create() (error, error, int) { // PUT and POST validation differs slightly err := user.postValidate() if err != nil { return err, nil, http.StatusBadRequest } // make sure the user cannot create someone with a higher priv_level than themselves if usrErr, sysErr, code := user.privCheck(); code != http.StatusOK { return usrErr, sysErr, code } var caps []string if user.Role != nil { caps, err = dbhelpers.GetCapabilitiesFromRoleID(user.ReqInfo.Tx.Tx, *user.Role) } else if user.RoleName != nil { caps, err = dbhelpers.GetCapabilitiesFromRoleName(user.ReqInfo.Tx.Tx, *user.RoleName) } if err != nil { return nil, err, http.StatusInternalServerError } missing := user.ReqInfo.User.MissingPermissions(caps...) if len(missing) != 0 { return fmt.Errorf("cannot request more than assigned permissions, current user needs %s permissions", strings.Join(missing, ",")), nil, http.StatusForbidden } // Convert password to SCRYPT *user.LocalPassword, err = auth.DerivePassword(*user.LocalPassword) if err != nil { return err, nil, http.StatusBadRequest } resultRows, err := user.ReqInfo.Tx.NamedQuery(user.InsertQuery(), user) if err != nil { return api.ParseDBError(err) } defer resultRows.Close() var id int var lastUpdated tc.TimeNoMod var tenant string var rolename string rowsAffected := 0 for resultRows.Next() { rowsAffected++ if err = resultRows.Scan(&id, &lastUpdated, &tenant, &rolename); err != nil { return nil, fmt.Errorf("could not scan after insert: %s\n)", err), http.StatusInternalServerError } } if rowsAffected == 0 { return nil, fmt.Errorf("no user was inserted, nothing was returned"), http.StatusInternalServerError } else if rowsAffected > 1 { return nil, fmt.Errorf("too many rows affected from user insert"), http.StatusInternalServerError } user.ID = &id user.LastUpdated = &lastUpdated user.Tenant = &tenant user.RoleName = &rolename user.LocalPassword = nil return nil, nil, http.StatusOK } // This is not using GenericRead because of this tenancy check. Maybe we can add tenancy functionality to the generic case? func (this *TOUser) Read(h http.Header, useIMS bool) ([]interface{}, error, error, int, *time.Time) { var maxTime time.Time var runSecond bool var query string inf := this.APIInfo() api.DefaultSort(inf, "username") where, orderBy, pagination, queryValues, errs := dbhelpers.BuildWhereAndOrderByAndPagination(inf.Params, this.ParamColumns()) if len(errs) > 0 { return nil, util.JoinErrs(errs), nil, http.StatusBadRequest, nil } tenantIDs, err := tenant.GetUserTenantIDListTx(inf.Tx.Tx, inf.User.TenantID) if err != nil { return nil, nil, fmt.Errorf("getting tenant list for user: %w", err), http.StatusInternalServerError, nil } where, queryValues = dbhelpers.AddTenancyCheck(where, queryValues, "u.tenant_id", tenantIDs) if useIMS { runSecond, maxTime = ims.TryIfModifiedSinceQuery(this.APIInfo().Tx, h, queryValues, selectMaxLastUpdatedQuery(where)) if !runSecond { log.Debugln("IMS HIT") return []interface{}{}, nil, nil, http.StatusNotModified, &maxTime } log.Debugln("IMS MISS") } else { log.Debugln("Non IMS request") } groupBy := "\n" + `GROUP BY u.id, r.name, t.name` orderBy = groupBy + orderBy version := inf.Version if version == nil { return nil, nil, fmt.Errorf("TOUsers.Read called with invalid API version"), http.StatusInternalServerError, nil } if version.Major >= 4 { query = this.SelectQuery40() + where + orderBy + pagination } else { query = this.SelectQuery() + where + orderBy + pagination } rows, err := inf.Tx.NamedQuery(query, queryValues) if err != nil { return nil, nil, fmt.Errorf("querying users : %w", err), http.StatusInternalServerError, nil } defer rows.Close() type UserGet struct { RoleName *string `json:"rolename" db:"rolename"` tc.User } type UserGet40 struct { UserGet ChangeLogCount *int `json:"changeLogCount" db:"change_log_count"` LastAuthenticated *time.Time `json:"lastAuthenticated" db:"last_authenticated"` } user := &UserGet{} user40 := &UserGet40{} users := []interface{}{} for rows.Next() { if version.Major >= 4 { if err = rows.StructScan(user40); err != nil { return nil, nil, fmt.Errorf("parsing user rows: %w", err), http.StatusInternalServerError, nil } users = append(users, *user40) } else { if err = rows.StructScan(user); err != nil { return nil, nil, fmt.Errorf("parsing user rows: %w", err), http.StatusInternalServerError, nil } users = append(users, *user) } } return users, nil, nil, http.StatusOK, &maxTime } func selectMaxLastUpdatedQuery(where string) string { return `SELECT max(t) from ( SELECT max(u.last_updated) as t FROM tm_user u LEFT JOIN tenant t ON u.tenant_id = t.id LEFT JOIN role r ON u.role = r.id ` + where + ` UNION ALL select max(last_updated) as t from last_deleted l where l.table_name='tm_user') as res` } func (user *TOUser) privCheck() (error, error, int) { var requestedPrivLevel int var err error if user.Role == nil { requestedPrivLevel, _, err = dbhelpers.GetPrivLevelFromRole(user.ReqInfo.Tx.Tx, *user.RoleName) } else { requestedPrivLevel, _, err = dbhelpers.GetPrivLevelFromRoleID(user.ReqInfo.Tx.Tx, *user.Role) } if err != nil { return nil, err, http.StatusInternalServerError } if user.ReqInfo.User.PrivLevel < requestedPrivLevel { return fmt.Errorf("user cannot update a user with a role more privileged than themselves"), nil, http.StatusForbidden } return nil, nil, http.StatusOK } func (user *TOUser) Update(h http.Header) (error, error, int) { // make sure current user cannot update their own role to a new value if user.ReqInfo.User.ID == *user.ID && user.ReqInfo.User.Role != *user.Role { return fmt.Errorf("users cannot update their own role"), nil, http.StatusBadRequest } // make sure the user cannot update someone with a higher priv_level than themselves if usrErr, sysErr, code := user.privCheck(); code != http.StatusOK { return usrErr, sysErr, code } var caps []string var err error if user.Role != nil { caps, err = dbhelpers.GetCapabilitiesFromRoleID(user.ReqInfo.Tx.Tx, *user.Role) } else if user.RoleName != nil { caps, err = dbhelpers.GetCapabilitiesFromRoleName(user.ReqInfo.Tx.Tx, *user.RoleName) } if err != nil { return nil, err, http.StatusInternalServerError } missing := user.ReqInfo.User.MissingPermissions(caps...) if len(missing) != 0 { return fmt.Errorf("cannot request more than assigned permissions, current user needs %s permissions", strings.Join(missing, ",")), nil, http.StatusForbidden } if user.LocalPassword != nil { var err error *user.LocalPassword, err = auth.DerivePassword(*user.LocalPassword) if err != nil { return nil, err, http.StatusInternalServerError } } userErr, sysErr, errCode := api.CheckIfUnModified(h, user.ReqInfo.Tx, *user.ID, "tm_user") if userErr != nil || sysErr != nil { return userErr, sysErr, errCode } resultRows, err := user.ReqInfo.Tx.NamedQuery(user.UpdateQuery(), user) if err != nil { return api.ParseDBError(err) } defer resultRows.Close() var lastUpdated tc.TimeNoMod var tenant string var rolename string rowsAffected := 0 for resultRows.Next() { rowsAffected++ if err := resultRows.Scan(&lastUpdated, &tenant, &rolename); err != nil { return nil, fmt.Errorf("could not scan lastUpdated from insert: %s\n", err), http.StatusInternalServerError } } user.LastUpdated = &lastUpdated user.Tenant = &tenant user.RoleName = &rolename user.LocalPassword = nil if rowsAffected != 1 { if rowsAffected < 1 { return fmt.Errorf("no user found with this id"), nil, http.StatusNotFound } return nil, fmt.Errorf("this update affected too many rows: %d", rowsAffected), http.StatusInternalServerError } return nil, nil, http.StatusOK } func (u *TOUser) IsTenantAuthorized(user *auth.CurrentUser) (bool, error) { // Delete: only id is given // Create: only tenant id // Update: id and tenant id // id is associated with old tenant id // we need to also check new tenant id tx := u.ReqInfo.Tx.Tx if u.ID != nil { // old tenant id (only on update or delete) var tenantID int if err := tx.QueryRow(`SELECT tenant_id from tm_user WHERE id = $1`, *u.ID).Scan(&tenantID); err != nil { if err != sql.ErrNoRows { return false, err } // At this point, tenancy isn't technically 'true', but I can't return a resource not found error here. // Letting it continue will let it run into a 404 when it tries to update. return true, nil } //log.Debugf("%d with tenancy %d trying to access %d with tenancy %d", user.ID, user.TenantID, *u.ID, tenantID) authorized, err := tenant.IsResourceAuthorizedToUserTx(tenantID, user, tx) if err != nil { return false, err } if !authorized { return false, nil } } if u.TenantID != nil { // new tenant id (only on create or udpate) //log.Debugf("%d with tenancy %d trying to access %d", user.ID, user.TenantID, *u.TenantID) authorized, err := tenant.IsResourceAuthorizedToUserTx(*u.TenantID, user, tx) if err != nil { return false, err } if !authorized { return false, nil } } return true, nil } func (user *TOUser) SelectQuery() string { return ` SELECT u.id, u.username as username, u.public_ssh_key, u.role, r.name as rolename, u.company, u.email, u.full_name, u.new_user, u.address_line1, u.address_line2, u.city, u.state_or_province, u.phone_number, u.postal_code, u.country, u.registration_sent, u.tenant_id, t.name as tenant, u.last_updated FROM tm_user u LEFT JOIN tenant t ON u.tenant_id = t.id LEFT JOIN role r ON u.role = r.id` } func (user *TOUser) SelectQuery40() string { return ` SELECT u.id, u.username as username, u.public_ssh_key, u.role, r.name as rolename, u.company, u.email, u.full_name, u.new_user, u.address_line1, u.address_line2, u.city, u.state_or_province, u.phone_number, u.postal_code, u.country, u.registration_sent, u.tenant_id, t.name as tenant, u.last_updated, u.last_authenticated, (SELECT count(l.tm_user) FROM log as l WHERE l.tm_user = u.id) as change_log_count FROM tm_user u LEFT JOIN tenant t ON u.tenant_id = t.id LEFT JOIN role r ON u.role = r.id` } func (user *TOUser) UpdateQuery() string { return ` UPDATE tm_user u SET username=:username, public_ssh_key=:public_ssh_key, role=:role, company=:company, email=:email, full_name=:full_name, new_user=COALESCE(:new_user, FALSE), address_line1=:address_line1, address_line2=:address_line2, city=:city, state_or_province=:state_or_province, phone_number=:phone_number, postal_code=:postal_code, country=:country, tenant_id=:tenant_id, local_passwd=COALESCE(:local_passwd, local_passwd) WHERE id=:id RETURNING last_updated, (SELECT t.name FROM tenant t WHERE id = u.tenant_id), (SELECT r.name FROM role r WHERE id = u.role)` } func UpdateQueryV40() string { return ` UPDATE tm_user u SET username=:username, public_ssh_key=:public_ssh_key, role=(SELECT id FROM role WHERE role.name = :role), company=:company, email=:email, full_name=:full_name, new_user=COALESCE(:new_user, FALSE), address_line1=:address_line1, address_line2=:address_line2, city=:city, state_or_province=:state_or_province, phone_number=:phone_number, postal_code=:postal_code, country=:country, tenant_id=:tenant_id, local_passwd=COALESCE(:local_passwd, local_passwd), ucdn=:ucdn WHERE id=:id RETURNING last_updated, (SELECT t.name FROM tenant t WHERE id = u.tenant_id), (SELECT r.name FROM role r WHERE id = u.role)` } func InsertQueryV40() string { return ` INSERT INTO tm_user ( username, public_ssh_key, role, company, email, full_name, new_user, address_line1, address_line2, city, state_or_province, phone_number, postal_code, country, tenant_id, local_passwd, ucdn ) VALUES ( :username, :public_ssh_key, (SELECT id FROM role WHERE name = :role), :company, :email, :full_name, COALESCE(:new_user, FALSE), :address_line1, :address_line2, :city, :state_or_province, :phone_number, :postal_code, :country, :tenant_id, :local_passwd, :ucdn ) RETURNING id, last_updated, (SELECT t.name FROM tenant t WHERE id = tm_user.tenant_id), (SELECT r.name FROM role r WHERE id = tm_user.role)` } func (user *TOUser) DeleteQuery() string { return `DELETE FROM tm_user WHERE id = :id` } const readBaseQuery = ` SELECT u.id, u.username AS username, u.public_ssh_key, u.company, u.email, u.full_name, u.new_user, u.address_line1, u.address_line2, u.city, u.state_or_province, u.phone_number, u.postal_code, u.country, u.registration_sent, u.tenant_id, t.name AS tenant, u.last_updated, u.ucdn,` const readQuery = readBaseQuery + ` u.last_authenticated, (SELECT count(l.tm_user) FROM log as l WHERE l.tm_user = u.id) as change_log_count, r.name as role FROM tm_user u LEFT JOIN tenant t ON u.tenant_id = t.id LEFT JOIN role r ON u.role = r.id LEFT JOIN role_capability rc ON rc.role_id = r.id ` const legacyReadQuery = readBaseQuery + ` r.name AS rolename, u.role FROM tm_user u LEFT JOIN tenant t ON u.tenant_id = t.id LEFT JOIN role r ON u.role = r.id ` // this is necessary because tc.User doesn't read its RoleName field in sql // driver scans. type userGet struct { RoleName *string `json:"rolename" db:"rolename"` tc.User } type userGet40 struct { userGet ChangeLogCount *int `json:"changeLogCount" db:"change_log_count"` LastAuthenticated *time.Time `json:"lastAuthenticated" db:"last_authenticated"` } func read(rows *sqlx.Rows) ([]tc.UserV4, error) { if rows == nil { return nil, errors.New("cannot read from nil rows") } users := []tc.UserV4{} for rows.Next() { var user tc.UserV4 if err := rows.StructScan(&user); err != nil { return nil, fmt.Errorf("scanning UserV4 row: %w", err) } users = append(users, user) } return users, nil } func getMaxLastUpdated(where string, queryValues map[string]interface{}, tx *sqlx.Tx) (time.Time, error) { query := selectMaxLastUpdatedQuery(where) var t time.Time rows, err := tx.NamedQuery(query, queryValues) if err != nil { return t, fmt.Errorf("query for max user last updated time: %w", err) } defer rows.Close() for rows.Next() { if err = rows.Scan(&t); err != nil { return t, fmt.Errorf("scanning user max last updated time: %w", err) } } return t, nil } // Get is the handler for GET requests made to /users. func Get(w http.ResponseWriter, r *http.Request) { var query string inf, userErr, sysErr, errCode := api.NewInfo(r, nil, nil) tx := inf.Tx.Tx if userErr != nil || sysErr != nil { api.HandleErr(w, r, tx, errCode, userErr, sysErr) return } defer inf.Close() api.DefaultSort(inf, "username") params := map[string]dbhelpers.WhereColumnInfo{ "id": {Column: "u.id", Checker: api.IsInt}, "role": {Column: "r.name"}, "tenant": {Column: "t.name"}, "username": {Column: "u.username"}, } params["company"] = dbhelpers.WhereColumnInfo{Column: "u.company"} params["email"] = dbhelpers.WhereColumnInfo{Column: "u.email"} params["fullName"] = dbhelpers.WhereColumnInfo{Column: "u.full_name"} params["newUser"] = dbhelpers.WhereColumnInfo{Column: "u.new_user"} params["city"] = dbhelpers.WhereColumnInfo{Column: "u.city"} params["stateOrProvince"] = dbhelpers.WhereColumnInfo{Column: "u.state_or_province"} params["country"] = dbhelpers.WhereColumnInfo{Column: "u.country"} params["postalCode"] = dbhelpers.WhereColumnInfo{Column: "u.postal_code"} params["capability"] = dbhelpers.WhereColumnInfo{Column: "rc.cap_name"} where, orderBy, pagination, queryValues, errs := dbhelpers.BuildWhereAndOrderByAndPagination(inf.Params, params) if len(errs) != 0 { api.HandleErr(w, r, tx, http.StatusBadRequest, util.JoinErrs(errs), nil) return } tenantIDs, err := tenant.GetUserTenantIDListTx(inf.Tx.Tx, inf.User.TenantID) if err != nil { api.HandleErr(w, r, tx, http.StatusInternalServerError, nil, fmt.Errorf("getting tenant list for user: %w", err)) return } where, queryValues = dbhelpers.AddTenancyCheck(where, queryValues, "u.tenant_id", tenantIDs) if inf.Config.UseIMS { runSecond, maxTime := ims.TryIfModifiedSinceQuery(inf.Tx, r.Header, queryValues, selectMaxLastUpdatedQuery(where)) if !runSecond { log.Debugln("IMS HIT") w.Header().Add(rfc.LastModified, maxTime.Format(rfc.LastModifiedFormat)) w.WriteHeader(http.StatusNotModified) return } log.Debugln("IMS MISS") } else { log.Debugln("Non IMS request") } groupBy := "\n" + `GROUP BY u.id, r.name, t.name` orderBy = groupBy + orderBy query = readQuery + where + orderBy + pagination rows, err := inf.Tx.NamedQuery(query, queryValues) if err != nil { api.HandleErr(w, r, tx, http.StatusInternalServerError, nil, fmt.Errorf("querying Users: %w", err)) return } defer log.Close(rows, "reading in Users from the database") var response interface{} response, err = read(rows) if err != nil { api.HandleErr(w, r, tx, http.StatusInternalServerError, nil, err) return } if inf.UseIMS() { maxTime, err := getMaxLastUpdated(where, queryValues, inf.Tx) if err != nil { api.HandleErr(w, r, tx, http.StatusInternalServerError, nil, err) return } w.Header().Add(rfc.LastModified, maxTime.Format(rfc.LastModifiedFormat)) } api.WriteResp(w, r, response) } func validate(user TOUser) error { validateErrs := validation.Errors{ "email": validation.Validate(user.Email, validation.Required, is.Email), "fullName": validation.Validate(user.FullName, validation.Required), "role": validation.Validate(user.Role, validation.Required), "username": validation.Validate(user.Username, validation.Required), "tenantID": validation.Validate(user.TenantID, validation.Required), } // Password is not required for update if user.LocalPassword != nil { _, err := auth.IsGoodLoginPair(*user.Username, *user.LocalPassword) if err != nil { return err } } return util.JoinErrs(tovalidate.ToErrors(validateErrs)) } func validateUserV4(user tc.UserV4) error { validateErrs := validation.Errors{ "email": validation.Validate(user.Email, validation.Required, is.Email), "fullName": validation.Validate(user.FullName, validation.Required), "role": validation.Validate(user.Role, validation.Required), "username": validation.Validate(user.Username, validation.Required), "tenantID": validation.Validate(user.TenantID, validation.Required), } // Password is not required for update if user.LocalPassword != nil { _, err := auth.IsGoodLoginPair(user.Username, *user.LocalPassword) if err != nil { return err } } return util.JoinErrs(tovalidate.ToErrors(validateErrs)) } func Create(w http.ResponseWriter, r *http.Request) { var userV4 tc.UserV4 var err error inf, userErr, sysErr, errCode := api.NewInfo(r, nil, nil) if userErr != nil || sysErr != nil { api.HandleErr(w, r, inf.Tx.Tx, errCode, userErr, sysErr) return } defer inf.Close() tx := inf.Tx.Tx if err := json.NewDecoder(r.Body).Decode(&userV4); err != nil { api.HandleErr(w, r, tx, http.StatusBadRequest, err, nil) return } if err := validateUserV4(userV4); err != nil { api.HandleErr(w, r, tx, http.StatusBadRequest, err, nil) return } if err := postValidateV40(userV4); err != nil { api.HandleErr(w, r, tx, http.StatusBadRequest, err, nil) return } toUser := TOUser{ APIInfoImpl: api.APIInfoImpl{ReqInfo: inf}, } toUser.User = userV4.Downgrade() authorized, err := toUser.IsTenantAuthorized(inf.User) if err != nil { api.HandleErr(w, r, inf.Tx.Tx, http.StatusInternalServerError, nil, errors.New("checking tenant authorized: "+err.Error())) return } if !authorized { api.HandleErr(w, r, inf.Tx.Tx, http.StatusForbidden, errors.New("not authorized on this tenant"), nil) return } // Convert password to SCRYPT *userV4.LocalPassword, err = auth.DerivePassword(*userV4.LocalPassword) if err != nil { api.HandleErr(w, r, tx, http.StatusBadRequest, err, nil) return } var resultRows *sqlx.Rows _, ok, err := dbhelpers.GetRoleIDFromName(tx, userV4.Role) if err != nil { api.HandleErr(w, r, tx, http.StatusInternalServerError, nil, fmt.Errorf("error fetching ID from role name: %w", err)) return } else if !ok { api.HandleErr(w, r, tx, http.StatusNotFound, errors.New("role not found"), nil) return } var caps []string caps, err = dbhelpers.GetCapabilitiesFromRoleName(tx, userV4.Role) if err != nil { api.HandleErr(w, r, tx, http.StatusInternalServerError, nil, err) return } missing := inf.User.MissingPermissions(caps...) if len(missing) != 0 { api.HandleErr(w, r, tx, http.StatusForbidden, fmt.Errorf("cannot request more than assigned permissions, current user needs %s permissions", strings.Join(missing, ",")), nil) return } resultRows, err = inf.Tx.NamedQuery(InsertQueryV40(), userV4) if err != nil { userErr, sysErr, statusCode := api.ParseDBError(err) api.HandleErr(w, r, tx, statusCode, userErr, sysErr) return } defer resultRows.Close() var id int var lastUpdated time.Time var tenant string var rolename string var changeLogMsg string rowsAffected := 0 for resultRows.Next() { rowsAffected++ if err = resultRows.Scan(&id, &lastUpdated, &tenant, &rolename); err != nil { api.HandleErr(w, r, tx, http.StatusInternalServerError, nil, fmt.Errorf("could not scan after insert: %w)", err)) return } } if rowsAffected == 0 { api.HandleErr(w, r, tx, http.StatusInternalServerError, nil, fmt.Errorf("no userV4 was inserted, nothing was returned")) return } else if rowsAffected > 1 { api.HandleErr(w, r, tx, http.StatusInternalServerError, nil, fmt.Errorf("too many rows affected from userV4 insert")) return } userV4.ID = &id userV4.LastUpdated = lastUpdated userV4.Tenant = &tenant userV4.Role = rolename userV4.LocalPassword = nil userResponse := tc.UserResponseV4{ Response: userV4, Alerts: tc.CreateAlerts(tc.SuccessLevel, "user was created."), } w.Header().Set(rfc.Location, fmt.Sprintf("/api/%s/users?id=%d", inf.Version, *userV4.ID)) api.WriteAlertsObj(w, r, http.StatusCreated, userResponse.Alerts, userResponse.Response) changeLogMsg = fmt.Sprintf("USER: %s, ID: %d, ACTION: Created User", userV4.Username, *userV4.ID) api.CreateChangeLogRawTx(api.ApiChange, changeLogMsg, inf.User, tx) return } func (user *TOUser) InsertQuery() string { return ` INSERT INTO tm_user ( username, public_ssh_key, role, company, email, full_name, new_user, address_line1, address_line2, city, state_or_province, phone_number, postal_code, country, tenant_id, local_passwd ) VALUES ( :username, :public_ssh_key, :role, :company, :email, :full_name, COALESCE(:new_user, FALSE), :address_line1, :address_line2, :city, :state_or_province, :phone_number, :postal_code, :country, :tenant_id, :local_passwd ) RETURNING id, last_updated, (SELECT t.name FROM tenant t WHERE id = tm_user.tenant_id), (SELECT r.name FROM role r WHERE id = tm_user.role)` } // Update is the handler for PUT requests made to /users. func Update(w http.ResponseWriter, r *http.Request) { var userV4 tc.UserV4 var roleID int inf, userErr, sysErr, errCode := api.NewInfo(r, nil, nil) tx := inf.Tx.Tx if userErr != nil || sysErr != nil { api.HandleErr(w, r, tx, errCode, userErr, sysErr) return } defer inf.Close() idParam, ok := inf.Params["id"] if !ok { api.HandleErr(w, r, tx, http.StatusBadRequest, errors.New("no ID supplied"), nil) return } id, err := strconv.Atoi(idParam) if err != nil { api.HandleErr(w, r, tx, http.StatusBadRequest, errors.New("couldn't convert id into an int"), nil) return } if err := json.NewDecoder(r.Body).Decode(&userV4); err != nil { api.HandleErr(w, r, tx, http.StatusBadRequest, err, nil) return } if err := validateUserV4(userV4); err != nil { api.HandleErr(w, r, tx, http.StatusBadRequest, err, nil) return } userV4.ID = &id roleID, ok, err = dbhelpers.GetRoleIDFromName(inf.Tx.Tx, userV4.Role) if err != nil { api.HandleErr(w, r, inf.Tx.Tx, http.StatusInternalServerError, nil, err) return } else if !ok { api.HandleErr(w, r, tx, http.StatusNotFound, errors.New("no such role"), nil) return } // make sure current userV4 cannot update their own role to a new value if inf.User.ID == *userV4.ID && inf.User.Role != roleID { api.HandleErr(w, r, inf.Tx.Tx, http.StatusBadRequest, fmt.Errorf("users cannot update their own role"), nil) return } toUser := TOUser{ APIInfoImpl: api.APIInfoImpl{ReqInfo: inf}, } toUser.User = userV4.Downgrade() authorized, err := toUser.IsTenantAuthorized(inf.User) if err != nil { api.HandleErr(w, r, inf.Tx.Tx, http.StatusInternalServerError, nil, errors.New("checking tenant authorized: "+err.Error())) return } if !authorized { api.HandleErr(w, r, inf.Tx.Tx, http.StatusForbidden, errors.New("not authorized on this tenant"), nil) return } // make sure the userV4 cannot create someone with a higher priv_level than themselves if userErr, sysErr, code := toUser.privCheck(); code != http.StatusOK { api.HandleErr(w, r, tx, code, userErr, sysErr) return } if userV4.LocalPassword != nil { // Convert password to SCRYPT *userV4.LocalPassword, err = auth.DerivePassword(*userV4.LocalPassword) if err != nil { api.HandleErr(w, r, tx, http.StatusBadRequest, err, nil) return } } var caps []string caps, err = dbhelpers.GetCapabilitiesFromRoleName(tx, userV4.Role) if err != nil { api.HandleErr(w, r, tx, http.StatusInternalServerError, nil, err) return } missing := inf.User.MissingPermissions(caps...) if len(missing) != 0 { api.HandleErr(w, r, tx, http.StatusForbidden, fmt.Errorf("cannot request more than assigned permissions, current user needs %s permissions", strings.Join(missing, ",")), nil) return } userErr, sysErr, errCode = api.CheckIfUnModified(r.Header, inf.Tx, id, "tm_user") if userErr != nil || sysErr != nil { api.HandleErr(w, r, tx, errCode, userErr, sysErr) return } var resultRows *sqlx.Rows resultRows, err = inf.Tx.NamedQuery(UpdateQueryV40(), userV4) if err != nil { api.ParseDBError(err) return } defer resultRows.Close() var lastUpdated time.Time var tenant string var rolename string var changeLogMsg string rowsAffected := 0 for resultRows.Next() { rowsAffected++ if err := resultRows.Scan(&lastUpdated, &tenant, &rolename); err != nil { api.HandleErr(w, r, tx, http.StatusInternalServerError, nil, fmt.Errorf("could not scan lastUpdated from insert: %s\n", err)) return } } if rowsAffected != 1 { if rowsAffected < 1 { api.HandleErr(w, r, tx, http.StatusNotFound, fmt.Errorf("no user found with this id"), nil) return } api.HandleErr(w, r, tx, http.StatusInternalServerError, nil, fmt.Errorf("this update affected too many rows: %d", rowsAffected)) return } userV4.LastUpdated = lastUpdated userV4.Tenant = &tenant userV4.Role = rolename userV4.LocalPassword = nil userResponse := tc.UserResponseV4{ Response: userV4, Alerts: tc.CreateAlerts(tc.SuccessLevel, "user was updated."), } api.WriteAlertsObj(w, r, http.StatusOK, userResponse.Alerts, userResponse.Response) changeLogMsg = fmt.Sprintf("USER: %s, ID: %d, ACTION: Updated User", userV4.Username, *userV4.ID) api.CreateChangeLogRawTx(api.ApiChange, changeLogMsg, inf.User, tx) }