Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
alist-org
GitHub Repository: alist-org/alist
Path: blob/main/internal/db/user.go
1560 views
1
package db
2
3
import (
4
"encoding/base64"
5
"fmt"
6
"github.com/alist-org/alist/v3/internal/model"
7
"github.com/alist-org/alist/v3/pkg/utils"
8
"github.com/go-webauthn/webauthn/webauthn"
9
"github.com/pkg/errors"
10
"gorm.io/gorm"
11
"path"
12
"slices"
13
"strings"
14
)
15
16
// GetUserByRole returns the first user, in the order produced by an
// unfiltered Find, whose role list contains role. The database error is
// returned unchanged on failure, and gorm.ErrRecordNotFound is returned
// when no loaded user has the role.
func GetUserByRole(role int) (*model.User, error) {
	var all []model.User
	err := db.Find(&all).Error
	if err != nil {
		return nil, err
	}
	for i := 0; i < len(all); i++ {
		// Return a pointer into the local slice so the caller gets this user.
		if candidate := &all[i]; candidate.Role.Contains(role) {
			return candidate, nil
		}
	}
	return nil, gorm.ErrRecordNotFound
}
28
29
// GetUsersByRole loads every user and returns those whose role list
// contains roleID. The result is nil (not an empty slice) when no user
// matches; the database error is returned unchanged on failure.
func GetUsersByRole(roleID int) ([]model.User, error) {
	var all []model.User
	err := db.Find(&all).Error
	if err != nil {
		return nil, err
	}
	var matched []model.User
	for i := range all {
		if !slices.Contains(all[i].Role, roleID) {
			continue
		}
		matched = append(matched, all[i])
	}
	return matched, nil
}
42
43
// GetUserByName returns the user whose Username equals username.
//
// An empty username is rejected up front with a wrapped
// gorm.ErrRecordNotFound. GORM ignores zero-value fields in struct
// conditions, so querying with Username "" would otherwise add no WHERE
// clause at all, and First would return the first user row.
// Any lookup error is wrapped with "failed find user".
func GetUserByName(username string) (*model.User, error) {
	if username == "" {
		return nil, errors.Wrapf(gorm.ErrRecordNotFound, "failed find user")
	}
	user := model.User{Username: username}
	if err := db.Where(user).First(&user).Error; err != nil {
		return nil, errors.Wrapf(err, "failed find user")
	}
	return &user, nil
}
50
51
// GetUserBySSOID returns the user bound to the given single-sign-on ID.
//
// An empty ssoID is rejected up front with a wrapped gorm.ErrRecordNotFound.
// GORM ignores zero-value fields in struct conditions, so SsoID "" would
// otherwise produce an unconditioned First and resolve to the first user
// row. That would bind an SSO login to an arbitrary account.
// Any lookup error is wrapped with the same message.
func GetUserBySSOID(ssoID string) (*model.User, error) {
	if ssoID == "" {
		return nil, errors.Wrapf(gorm.ErrRecordNotFound, "The single sign on platform is not bound to any users")
	}
	user := model.User{SsoID: ssoID}
	if err := db.Where(user).First(&user).Error; err != nil {
		return nil, errors.Wrapf(err, "The single sign on platform is not bound to any users")
	}
	return &user, nil
}
58
59
// GetUserById loads the user with primary key id. Lookup errors are
// wrapped with "failed get old user".
func GetUserById(id uint) (*model.User, error) {
	user := new(model.User)
	err := db.First(user, id).Error
	if err != nil {
		return nil, errors.Wrapf(err, "failed get old user")
	}
	return user, nil
}
66
67
// CreateUser inserts u. Any error is annotated with a stack trace;
// errors.WithStack(nil) is nil, so success returns nil.
func CreateUser(u *model.User) error {
	result := db.Create(u)
	return errors.WithStack(result.Error)
}
70
71
// UpdateUser persists u via GORM's Save. Any error is annotated with a
// stack trace; success returns nil.
func UpdateUser(u *model.User) error {
	result := db.Save(u)
	return errors.WithStack(result.Error)
}
74
75
// GetUsers returns one page of users ordered by id, together with the
// total number of users. pageIndex is 1-based; the offset is
// (pageIndex-1)*pageSize. The count and the page are both taken from the
// same GORM statement built on model.User.
func GetUsers(pageIndex, pageSize int) (users []model.User, count int64, err error) {
	query := db.Model(&model.User{})
	if countErr := query.Count(&count).Error; countErr != nil {
		return nil, 0, errors.Wrapf(countErr, "failed get users count")
	}
	offset := (pageIndex - 1) * pageSize
	findErr := query.Order(columnName("id")).Offset(offset).Limit(pageSize).Find(&users).Error
	if findErr != nil {
		return nil, 0, errors.Wrapf(findErr, "failed get find users")
	}
	return users, count, nil
}
85
86
func DeleteUserById(id uint) error {
87
return errors.WithStack(db.Delete(&model.User{}, id).Error)
88
}
89
90
func UpdateAuthn(userID uint, authn string) error {
91
return db.Model(&model.User{ID: userID}).Update("authn", authn).Error
92
}
93
94
// RegisterAuthn appends credential (when non-nil) to u's existing WebAuthn
// credentials and stores the JSON-encoded list via UpdateAuthn.
// It returns an error if u is nil, if encoding fails, or if the update fails.
func RegisterAuthn(u *model.User, credential *webauthn.Credential) error {
	if u == nil {
		return errors.New("user is nil")
	}
	creds := u.WebAuthnCredentials()
	if credential != nil {
		creds = append(creds, *credential)
	}
	payload, err := utils.Json.Marshal(creds)
	if err == nil {
		err = UpdateAuthn(u.ID, string(payload))
	}
	return err
}
108
109
// RemoveAuthn deletes the first WebAuthn credential of u whose ID, encoded
// with standard base64, equals id. It then stores the remaining list via
// UpdateAuthn. If no credential matches, the list is re-stored unchanged.
//
// A nil u is rejected, matching RegisterAuthn. The remaining credentials
// keep their original relative order.
func RemoveAuthn(u *model.User, id string) error {
	if u == nil {
		return errors.New("user is nil")
	}
	creds := u.WebAuthnCredentials()
	idx := slices.IndexFunc(creds, func(c webauthn.Credential) bool {
		return base64.StdEncoding.EncodeToString(c.ID) == id
	})
	if idx >= 0 {
		// Order-preserving removal; the previous swap-with-last reordered
		// the credentials that were left.
		creds = slices.Delete(creds, idx, idx+1)
	}

	res, err := utils.Json.Marshal(creds)
	if err != nil {
		return err
	}
	return UpdateAuthn(u.ID, string(res))
}
126
127
// UpdateUserBasePathPrefix rewrites user base paths that equal oldPath or
// lie beneath it (both sides compared after path.Clean) so they point at
// newPath instead, and saves each changed user with UpdateUser.
//
// If usersOpt is given, its first element is the user set to process;
// otherwise all users are loaded from the database. Each user is processed
// as a copy, so a caller-supplied slice is not modified.
//
// It returns the usernames that were changed (nil if none). On the first
// failed update it stops and returns a nil slice with the error.
func UpdateUserBasePathPrefix(oldPath, newPath string, usersOpt ...[]model.User) ([]string, error) {
	oldPathClean := path.Clean(oldPath)

	var users []model.User
	if len(usersOpt) == 0 {
		if err := db.Find(&users).Error; err != nil {
			return nil, errors.WithMessage(err, "failed to load users")
		}
	} else {
		users = usersOpt[0]
	}

	var changed []string
	for _, u := range users {
		current := path.Clean(u.BasePath)
		switch {
		case current == oldPathClean:
			u.BasePath = path.Clean(newPath)
		case strings.HasPrefix(current, oldPathClean+"/"):
			// Keep the part below oldPath and graft it onto newPath.
			u.BasePath = path.Clean(newPath + current[len(oldPathClean):])
		default:
			continue
		}
		if err := UpdateUser(&u); err != nil {
			return nil, errors.WithMessagef(err, "failed to update user ID %d", u.ID)
		}
		changed = append(changed, u.Username)
	}
	return changed, nil
}
163
164
func CountUsersByRoleAndEnabledExclude(roleID uint, excludeUserID uint) (int64, error) {
165
var count int64
166
jsonValue := fmt.Sprintf("[%d]", roleID)
167
err := db.Model(&model.User{}).
168
Where("disabled = ? AND id != ?", false, excludeUserID).
169
Where("JSON_CONTAINS(role, ?)", jsonValue).
170
Count(&count).Error
171
return count, err
172
}
173
174