internal/repo/user/user_repo.go (312 lines of code) (raw):

/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package user import ( "context" "strings" "time" "github.com/apache/answer/internal/base/data" "github.com/apache/answer/internal/base/reason" "github.com/apache/answer/internal/entity" "github.com/apache/answer/internal/schema" usercommon "github.com/apache/answer/internal/service/user_common" "github.com/apache/answer/pkg/converter" "github.com/apache/answer/plugin" "github.com/segmentfault/pacman/errors" "github.com/segmentfault/pacman/log" "xorm.io/xorm" ) // userRepo user repository type userRepo struct { data *data.Data } // NewUserRepo new repository func NewUserRepo(data *data.Data) usercommon.UserRepo { return &userRepo{ data: data, } } // AddUser add user func (ur *userRepo) AddUser(ctx context.Context, user *entity.User) (err error) { _, err = ur.data.DB.Transaction(func(session *xorm.Session) (interface{}, error) { session = session.Context(ctx) userInfo := &entity.User{} exist, err := session.Where("username = ?", user.Username).Get(userInfo) if err != nil { return nil, errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } if exist { return nil, errors.InternalServer(reason.UsernameDuplicate) } _, err = session.Insert(user) if err != nil { return nil, errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } return nil, nil }) return } // IncreaseAnswerCount increase answer count func (ur *userRepo) IncreaseAnswerCount(ctx context.Context, userID string, amount int) (err error) { user := &entity.User{} _, err = ur.data.DB.Context(ctx).Where("id = ?", userID).Incr("answer_count", amount).Update(user) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } return nil } // IncreaseQuestionCount increase question count func (ur *userRepo) IncreaseQuestionCount(ctx context.Context, userID string, amount int) (err error) { user := &entity.User{} _, err = ur.data.DB.Context(ctx).Where("id = ?", userID).Incr("question_count", amount).Update(user) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } return nil } func (ur *userRepo) UpdateQuestionCount(ctx context.Context, userID string, count int64) (err error) { user := &entity.User{} user.QuestionCount = int(count) _, err = ur.data.DB.Context(ctx).Where("id = ?", userID).Cols("question_count").Update(user) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } return nil } func (ur *userRepo) UpdateAnswerCount(ctx context.Context, userID string, count int) (err error) { user := &entity.User{} user.AnswerCount = count _, err = ur.data.DB.Context(ctx).Where("id = ?", userID).Cols("answer_count").Update(user) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } return nil } // UpdateLastLoginDate update last login date func (ur *userRepo) UpdateLastLoginDate(ctx context.Context, userID string) (err error) { user := &entity.User{LastLoginDate: time.Now()} _, err = ur.data.DB.Context(ctx).Where("id = ?", userID).Cols("last_login_date").Update(user) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } return nil } // UpdateEmailStatus update email status func (ur *userRepo) UpdateEmailStatus(ctx context.Context, userID string, emailStatus int) error { cond := &entity.User{MailStatus: emailStatus} _, err := ur.data.DB.Context(ctx).Where("id = ?", userID).Cols("mail_status").Update(cond) if err != nil { return err } return nil } // UpdateNoticeStatus update notice status func (ur *userRepo) UpdateNoticeStatus(ctx context.Context, userID string, noticeStatus int) error { cond := &entity.User{NoticeStatus: noticeStatus} _, err := ur.data.DB.Context(ctx).Where("id = ?", userID).Cols("notice_status").Update(cond) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } return nil } func (ur *userRepo) UpdatePass(ctx context.Context, userID, pass string) error { _, err := ur.data.DB.Context(ctx).Where("id = ?", userID).Cols("pass").Update(&entity.User{Pass: pass}) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } return nil } func (ur *userRepo) UpdateEmail(ctx context.Context, userID, email string) (err error) { _, err = ur.data.DB.Context(ctx).Where("id = ?", userID).Update(&entity.User{EMail: email}) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } return } func (ur *userRepo) UpdateUserInterface(ctx context.Context, userID, language, colorSchema string) (err error) { session := ur.data.DB.Context(ctx).Where("id = ?", userID) _, err = session.Cols("language", "color_scheme").Update(&entity.User{Language: language, ColorScheme: colorSchema}) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } return } // UpdateInfo update user info func (ur *userRepo) UpdateInfo(ctx context.Context, userInfo *entity.User) (err error) { _, err = ur.data.DB.Context(ctx).Where("id = ?", userInfo.ID). Cols("username", "display_name", "avatar", "bio", "bio_html", "website", "location").Update(userInfo) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } return } // UpdateUserProfile update user profile func (ur *userRepo) UpdateUserProfile(ctx context.Context, userInfo *entity.User) (err error) { _, err = ur.data.DB.Context(ctx).Where("id = ?", userInfo.ID). Cols("username", "e_mail", "mail_status", "display_name").Update(userInfo) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } return } // GetByUserID get user info by user id func (ur *userRepo) GetByUserID(ctx context.Context, userID string) (userInfo *entity.User, exist bool, err error) { userInfo = &entity.User{} exist, err = ur.data.DB.Context(ctx).Where("id = ?", userID).Get(userInfo) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() return } err = tryToDecorateUserInfoFromUserCenter(ctx, ur.data, userInfo) if err != nil { return nil, false, err } return } func (ur *userRepo) BatchGetByID(ctx context.Context, ids []string) ([]*entity.User, error) { list := make([]*entity.User, 0) err := ur.data.DB.Context(ctx).In("id", ids).Find(&list) if err != nil { return nil, errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } tryToDecorateUserListFromUserCenter(ctx, ur.data, list) return list, nil } // GetByUsername get user by username func (ur *userRepo) GetByUsername(ctx context.Context, username string) (userInfo *entity.User, exist bool, err error) { userInfo = &entity.User{} exist, err = ur.data.DB.Context(ctx).Where("username = ?", username).Get(userInfo) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() return } err = tryToDecorateUserInfoFromUserCenter(ctx, ur.data, userInfo) if err != nil { return nil, false, err } return } func (ur *userRepo) GetByUsernames(ctx context.Context, usernames []string) ([]*entity.User, error) { list := make([]*entity.User, 0) err := ur.data.DB.Context(ctx).Where("status =?", entity.UserStatusAvailable).In("username", usernames).Find(&list) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() return list, err } tryToDecorateUserListFromUserCenter(ctx, ur.data, list) return list, nil } // GetByEmail get user by email func (ur *userRepo) GetByEmail(ctx context.Context, email string) (userInfo *entity.User, exist bool, err error) { userInfo = &entity.User{} exist, err = ur.data.DB.Context(ctx).Where("e_mail = ?", email). Where("status != ?", entity.UserStatusDeleted).Get(userInfo) if err != nil { err = errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } return } func (ur *userRepo) GetUserCount(ctx context.Context) (count int64, err error) { session := ur.data.DB.Context(ctx) session.Where("status = ? OR status = ?", entity.UserStatusAvailable, entity.UserStatusSuspended) count, err = session.Count(&entity.User{}) if err != nil { return count, errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } return count, nil } func (ur *userRepo) SearchUserListByName(ctx context.Context, name string, limit int, onlyStaff bool) (userList []*entity.User, err error) { userList = make([]*entity.User, 0) session := ur.data.DB.Context(ctx) if onlyStaff { session.Join("INNER", "user_role_rel", "`user`.id = `user_role_rel`.user_id AND `user_role_rel`.role_id > 1") } session.Where("status = ?", entity.UserStatusAvailable) session.Where("username LIKE ? OR display_name LIKE ?", strings.ToLower(name)+"%", name+"%") session.OrderBy("username ASC, `user`.id DESC") session.Limit(limit) err = session.Find(&userList) if err != nil { return nil, errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } tryToDecorateUserListFromUserCenter(ctx, ur.data, userList) return } func tryToDecorateUserInfoFromUserCenter(ctx context.Context, data *data.Data, original *entity.User) (err error) { if original == nil { return nil } uc, ok := plugin.GetUserCenter() if !ok { return nil } userInfo := &entity.UserExternalLogin{} session := data.DB.Context(ctx).Where("user_id = ?", original.ID) session.Where("provider = ?", uc.Info().SlugName) exist, err := session.Get(userInfo) if err != nil { return errors.InternalServer(reason.DatabaseError).WithError(err).WithStack() } if !exist { return nil } userCenterBasicUserInfo, err := uc.UserInfo(userInfo.ExternalID) if err != nil { log.Error(err) return errors.BadRequest(reason.UserNotFound).WithError(err).WithStack() } decorateByUserCenterUser(original, userCenterBasicUserInfo) return nil } func tryToDecorateUserListFromUserCenter(ctx context.Context, data *data.Data, original []*entity.User) { uc, ok := plugin.GetUserCenter() if !ok { return } ids := make([]string, 0) originalUserIDMapping := make(map[string]*entity.User, 0) for _, user := range original { originalUserIDMapping[user.ID] = user ids = append(ids, user.ID) } userExternalLoginList := make([]*entity.UserExternalLogin, 0) session := data.DB.Context(ctx).Where("provider = ?", uc.Info().SlugName) session.In("user_id", ids) err := session.Find(&userExternalLoginList) if err != nil { log.Error(err) return } userExternalIDs := make([]string, 0) originalExternalIDMapping := make(map[string]*entity.User, 0) for _, u := range userExternalLoginList { originalExternalIDMapping[u.ExternalID] = originalUserIDMapping[u.UserID] userExternalIDs = append(userExternalIDs, u.ExternalID) } if len(userExternalIDs) == 0 { return } ucUsers, err := uc.UserList(userExternalIDs) if err != nil { log.Errorf("get user list from user center failed: %v, %v", err, userExternalIDs) return } for _, ucUser := range ucUsers { decorateByUserCenterUser(originalExternalIDMapping[ucUser.ExternalID], ucUser) } } func decorateByUserCenterUser(original *entity.User, ucUser *plugin.UserCenterBasicUserInfo) { if original == nil || ucUser == nil { return } // In general, usernames should be guaranteed unique by the User Center plugin, so there are no inconsistencies. if original.Username != ucUser.Username { log.Warnf("user %s username is inconsistent with user center", original.ID) } if len(ucUser.DisplayName) > 0 { original.DisplayName = ucUser.DisplayName } if len(ucUser.Email) > 0 { original.EMail = ucUser.Email } if len(ucUser.Avatar) > 0 { original.Avatar = schema.CustomAvatar(ucUser.Avatar).ToJsonString() } if len(ucUser.Mobile) > 0 { original.Mobile = ucUser.Mobile } if len(ucUser.Bio) > 0 { original.BioHTML = converter.Markdown2HTML(ucUser.Bio) + original.BioHTML } // If plugin enable rank agent, use rank from user center. if plugin.RankAgentEnabled() { original.Rank = ucUser.Rank } if ucUser.Status != plugin.UserStatusAvailable { original.Status = int(ucUser.Status) } }