Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
alist-org
GitHub Repository: alist-org/alist
Path: blob/main/drivers/baidu_netdisk/driver.go
2325 views
1
package baidu_netdisk
2
3
import (
4
"context"
5
"crypto/md5"
6
"encoding/hex"
7
"errors"
8
"fmt"
9
"io"
10
"net/url"
11
"os"
12
stdpath "path"
13
"strconv"
14
"strings"
15
"sync"
16
"time"
17
18
"github.com/alist-org/alist/v3/drivers/base"
19
"github.com/alist-org/alist/v3/internal/conf"
20
"github.com/alist-org/alist/v3/internal/driver"
21
"github.com/alist-org/alist/v3/internal/errs"
22
"github.com/alist-org/alist/v3/internal/model"
23
"github.com/alist-org/alist/v3/pkg/errgroup"
24
"github.com/alist-org/alist/v3/pkg/singleflight"
25
"github.com/alist-org/alist/v3/pkg/utils"
26
"github.com/avast/retry-go"
27
"github.com/go-resty/resty/v2"
28
log "github.com/sirupsen/logrus"
29
)
30
31
// BaiduNetdisk is the driver implementation for Baidu Netdisk (百度网盘).
type BaiduNetdisk struct {
	model.Storage
	Addition

	uploadThread int // concurrent upload goroutines; clamped to [1, 32] in Init
	vipType      int // VIP type: 0 regular user (4G/4M), 1 member (10G/16M), 2 super member (20G/32M)

	upClient   *resty.Client              // HTTP client used for file uploads
	uploadUrlG singleflight.Group[string] // presumably deduplicates concurrent upload-URL lookups — confirm in getUploadUrl
	uploadUrlMu         sync.RWMutex      // NOTE(review): appears to guard uploadUrl/uploadUrlUpdateTime — confirm at use sites
	uploadUrl           string            // cached upload host
	uploadUrlUpdateTime time.Time         // last time uploadUrl was refreshed
}
44
45
// ErrUploadIDExpired is returned by uploadSlice when the server response
// indicates the uploadid is invalid/expired; Put reacts by re-running
// precreate and restarting the slice upload from scratch.
var ErrUploadIDExpired = errors.New("uploadid expired")
46
47
// Config returns the static driver configuration.
func (d *BaiduNetdisk) Config() driver.Config {
	return config
}
50
51
// GetAddition exposes the driver-specific settings for the framework.
func (d *BaiduNetdisk) GetAddition() driver.Additional {
	return &d.Addition
}
54
55
func (d *BaiduNetdisk) Init(ctx context.Context) error {
56
d.upClient = base.NewRestyClient().
57
SetTimeout(UPLOAD_TIMEOUT).
58
SetRetryCount(UPLOAD_RETRY_COUNT).
59
SetRetryWaitTime(UPLOAD_RETRY_WAIT_TIME).
60
SetRetryMaxWaitTime(UPLOAD_RETRY_MAX_WAIT_TIME)
61
d.uploadThread, _ = strconv.Atoi(d.UploadThread)
62
if d.uploadThread < 1 {
63
d.uploadThread, d.UploadThread = 1, "1"
64
} else if d.uploadThread > 32 {
65
d.uploadThread, d.UploadThread = 32, "32"
66
}
67
68
if _, err := url.Parse(d.UploadAPI); d.UploadAPI == "" || err != nil {
69
d.UploadAPI = UPLOAD_FALLBACK_API
70
}
71
72
res, err := d.get("/xpan/nas", map[string]string{
73
"method": "uinfo",
74
}, nil)
75
log.Debugf("[baidu_netdisk] get uinfo: %s", string(res))
76
if err != nil {
77
return err
78
}
79
d.vipType = utils.Json.Get(res, "vip_type").ToInt()
80
return nil
81
}
82
83
// Drop releases driver resources; this driver has nothing to clean up.
func (d *BaiduNetdisk) Drop(ctx context.Context) error {
	return nil
}
86
87
// List fetches the children of dir and converts each entry to a model.Obj.
func (d *BaiduNetdisk) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) {
	entries, err := d.getFiles(dir.GetPath())
	if err != nil {
		return nil, err
	}
	return utils.SliceConvert(entries, func(entry File) (model.Obj, error) {
		return fileToObj(entry), nil
	})
}
96
97
// Link resolves a download link for file using the configured API flavor:
// "crack", "crack_video", or the official API by default.
func (d *BaiduNetdisk) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) {
	switch d.DownloadAPI {
	case "crack":
		return d.linkCrack(file, args)
	case "crack_video":
		return d.linkCrackVideo(file, args)
	default:
		return d.linkOfficial(file, args)
	}
}
105
106
// MakeDir creates a directory named dirName under parentDir.
func (d *BaiduNetdisk) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) (model.Obj, error) {
	var created File
	// isdir=1, no size/blocklist/uploadid needed for directories.
	if _, err := d.create(stdpath.Join(parentDir.GetPath(), dirName), 0, 1, "", "", &created, 0, 0); err != nil {
		return nil, err
	}
	return fileToObj(created), nil
}
114
115
func (d *BaiduNetdisk) Move(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) {
116
data := []base.Json{
117
{
118
"path": srcObj.GetPath(),
119
"dest": dstDir.GetPath(),
120
"newname": srcObj.GetName(),
121
},
122
}
123
_, err := d.manage("move", data)
124
if err != nil {
125
return nil, err
126
}
127
if srcObj, ok := srcObj.(*model.ObjThumb); ok {
128
srcObj.SetPath(stdpath.Join(dstDir.GetPath(), srcObj.GetName()))
129
srcObj.Modified = time.Now()
130
return srcObj, nil
131
}
132
return nil, nil
133
}
134
135
func (d *BaiduNetdisk) Rename(ctx context.Context, srcObj model.Obj, newName string) (model.Obj, error) {
136
data := []base.Json{
137
{
138
"path": srcObj.GetPath(),
139
"newname": newName,
140
},
141
}
142
_, err := d.manage("rename", data)
143
if err != nil {
144
return nil, err
145
}
146
147
if srcObj, ok := srcObj.(*model.ObjThumb); ok {
148
srcObj.SetPath(stdpath.Join(stdpath.Dir(srcObj.GetPath()), newName))
149
srcObj.Name = newName
150
srcObj.Modified = time.Now()
151
return srcObj, nil
152
}
153
return nil, nil
154
}
155
156
func (d *BaiduNetdisk) Copy(ctx context.Context, srcObj, dstDir model.Obj) error {
157
data := []base.Json{
158
{
159
"path": srcObj.GetPath(),
160
"dest": dstDir.GetPath(),
161
"newname": srcObj.GetName(),
162
},
163
}
164
_, err := d.manage("copy", data)
165
return err
166
}
167
168
// Remove deletes obj from the netdisk.
func (d *BaiduNetdisk) Remove(ctx context.Context, obj model.Obj) error {
	_, err := d.manage("delete", []string{obj.GetPath()})
	return err
}
173
174
// PutRapid attempts a rapid upload using the stream's pre-computed MD5; it
// only succeeds when the server already stores content with that hash.
func (d *BaiduNetdisk) PutRapid(ctx context.Context, dstDir model.Obj, stream model.FileStreamer) (model.Obj, error) {
	md5sum := stream.GetHash().GetHash(utils.MD5)
	if len(md5sum) < utils.MD5.Width {
		return nil, errors.New("invalid hash")
	}

	size := stream.GetSize()
	dstPath := stdpath.Join(dstDir.GetPath(), stream.GetName())
	modTime := stream.ModTime().Unix()
	createTime := stream.CreateTime().Unix()
	blockList, _ := utils.Json.MarshalToString([]string{md5sum})

	var created File
	if _, err := d.create(dstPath, size, 0, "", blockList, &created, modTime, createTime); err != nil {
		return nil, err
	}
	// Overwrite the times reported by the API with the stream's own values;
	// see the note on Put for why the API's times cannot be trusted.
	created.Ctime = stream.CreateTime().Unix()
	created.Mtime = stream.ModTime().Unix()
	return fileToObj(created), nil
}
196
197
// Put uploads stream into dstDir, reporting progress via up.
//
// Flow: try rapid upload → hash the stream slice-by-slice (spooling to a temp
// file if the cache isn't seekable) → precreate (or resume a saved session) →
// upload slices concurrently with expiry-aware retry → create the final file.
//
// **Note**: as of 2024/04/20 the Baidu API always returns the current time
// instead of the file's own times, while the cloud actually stores the file
// times — so the times are overwritten here to keep the cache consistent.
func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) {
	// Baidu Netdisk does not allow uploading empty files.
	if stream.GetSize() < 1 {
		return nil, ErrBaiduEmptyFilesNotAllowed
	}

	// Rapid upload first: succeeds when the server already has this content.
	if newObj, err := d.PutRapid(ctx, dstDir, stream); err == nil {
		return newObj, nil
	}

	var (
		cache = stream.GetFile()
		tmpF  *os.File
		err   error
	)
	// Slice upload needs random access; spool to a temp file if the cached
	// stream is not an io.ReaderAt.
	if _, ok := cache.(io.ReaderAt); !ok {
		tmpF, err = os.CreateTemp(conf.Conf.TempDir, "file-*")
		if err != nil {
			return nil, err
		}
		defer func() {
			_ = tmpF.Close()
			_ = os.Remove(tmpF.Name())
		}()
		cache = tmpF
	}

	// Compute slice layout: `count` slices of sliceSize, with the final slice
	// possibly shorter (lastBlockSize).
	streamSize := stream.GetSize()
	sliceSize := d.getSliceSize(streamSize)
	count := int(streamSize / sliceSize)
	lastBlockSize := streamSize % sliceSize
	if lastBlockSize > 0 {
		count++
	} else {
		lastBlockSize = sliceSize
	}

	// MD5 of the first 256 KiB (slice-md5 for precreate).
	const SliceSize int64 = 256 * utils.KB
	// One pass over the stream computes the whole-file MD5, each slice's MD5,
	// the first-256K MD5, and (if needed) writes the temp-file copy.
	blockList := make([]string, 0, count)
	byteSize := sliceSize
	fileMd5H := md5.New()
	sliceMd5H := md5.New()
	sliceMd5H2 := md5.New()
	slicemd5H2Write := utils.LimitWriter(sliceMd5H2, SliceSize)
	writers := []io.Writer{fileMd5H, sliceMd5H, slicemd5H2Write}
	if tmpF != nil {
		writers = append(writers, tmpF)
	}
	written := int64(0)

	for i := 1; i <= count; i++ {
		if utils.IsCanceled(ctx) {
			return nil, ctx.Err()
		}
		if i == count {
			byteSize = lastBlockSize
		}
		n, err := utils.CopyWithBufferN(io.MultiWriter(writers...), stream, byteSize)
		written += n
		if err != nil && err != io.EOF {
			return nil, err
		}
		blockList = append(blockList, hex.EncodeToString(sliceMd5H.Sum(nil)))
		sliceMd5H.Reset()
	}
	if tmpF != nil {
		// The spooled copy must match the declared size and be rewound before
		// the slice readers use it.
		if written != streamSize {
			return nil, errs.NewErr(err, "CreateTempFile failed, size mismatch: %d != %d ", written, streamSize)
		}
		_, err = tmpF.Seek(0, io.SeekStart)
		if err != nil {
			return nil, errs.NewErr(err, "CreateTempFile failed, can't seek to 0 ")
		}
	}
	contentMd5 := hex.EncodeToString(fileMd5H.Sum(nil))
	sliceMd5 := hex.EncodeToString(sliceMd5H2.Sum(nil))
	blockListStr, _ := utils.Json.MarshalToString(blockList)
	path := stdpath.Join(dstDir.GetPath(), stream.GetName())
	mtime := stream.ModTime().Unix()
	ctime := stream.CreateTime().Unix()

	// Step 1: try to resume a previously saved upload session.
	precreateResp, ok := base.GetUploadProgress[*PrecreateResp](d, d.AccessToken, contentMd5)
	if !ok {
		// No saved progress — run the pre-upload handshake.
		precreateResp, err = d.precreate(ctx, path, streamSize, blockListStr, contentMd5, sliceMd5, ctime, mtime)
		if err != nil {
			return nil, err
		}
		if precreateResp.ReturnType == 2 {
			// Rapid upload succeeded: the server matched our MD5.
			// Times were already fixed by precreate; see the **Note** above.
			return fileToObj(precreateResp.File), nil
		}
	}

	// Step 2: upload the slices. At most two passes: the second pass only
	// happens after an uploadid expiry forces a fresh precreate.
uploadLoop:
	for attempt := 0; attempt < 2; attempt++ {
		// Resolve the upload host for this session.
		uploadUrl := d.getUploadUrl(path, precreateResp.Uploadid)
		// Concurrent slice uploads, bounded by d.uploadThread.
		threadG, upCtx := errgroup.NewGroupWithContext(ctx, d.uploadThread,
			retry.Attempts(1),
			retry.Delay(time.Second),
			retry.DelayType(retry.BackOffDelay))

		cacheReaderAt, okReaderAt := cache.(io.ReaderAt)
		if !okReaderAt {
			return nil, fmt.Errorf("cache object must implement io.ReaderAt interface for upload operations")
		}

		totalParts := len(precreateResp.BlockList)
		for i, partseq := range precreateResp.BlockList {
			// partseq < 0 marks a slice already uploaded in a previous pass.
			if utils.IsCanceled(upCtx) || partseq < 0 {
				continue
			}

			i, partseq := i, partseq // shadow for the closure (pre-Go-1.22 loop semantics)
			offset, size := int64(partseq)*sliceSize, sliceSize
			if partseq+1 == count {
				size = lastBlockSize
			}
			threadG.Go(func(ctx context.Context) error {
				params := map[string]string{
					"method":       "upload",
					"access_token": d.AccessToken,
					"type":         "tmpfile",
					"path":         path,
					"uploadid":     precreateResp.Uploadid,
					"partseq":      strconv.Itoa(partseq),
				}
				section := io.NewSectionReader(cacheReaderAt, offset, size)
				err := d.uploadSlice(ctx, uploadUrl, params, stream.GetName(), driver.NewLimitedUploadStream(ctx, section))
				if err != nil {
					return err
				}
				precreateResp.BlockList[i] = -1
				// This goroutine hasn't exited yet, so +1 gives the true
				// number of completed slices.
				success := threadG.Success() + 1
				progress := float64(success) * 100 / float64(totalParts)
				up(progress)
				return nil
			})
		}

		err = threadG.Wait()
		if err == nil {
			break uploadLoop
		}

		// Save progress on every error so a later attempt can resume;
		// keep only the slices that are still pending (>= 0).
		precreateResp.BlockList = utils.SliceFilter(precreateResp.BlockList, func(s int) bool { return s >= 0 })
		base.SaveUploadProgress(d, precreateResp, d.AccessToken, contentMd5)

		if errors.Is(err, context.Canceled) {
			return nil, err
		}
		if errors.Is(err, ErrUploadIDExpired) {
			log.Warn("[baidu_netdisk] uploadid expired, will restart from scratch")
			// Re-run precreate without checksums; every slice must be re-sent.
			newPre, err2 := d.precreate(ctx, path, streamSize, blockListStr, "", "", ctime, mtime)
			if err2 != nil {
				return nil, err2
			}
			if newPre.ReturnType == 2 {
				return fileToObj(newPre.File), nil
			}
			precreateResp = newPre
			// Overwrite the stale saved progress with the new session.
			base.SaveUploadProgress(d, precreateResp, d.AccessToken, contentMd5)
			continue uploadLoop
		}
		return nil, err
	}

	// Step 3: finalize — create the file from the uploaded slices.
	var newFile File
	_, err = d.create(path, streamSize, 0, precreateResp.Uploadid, blockListStr, &newFile, mtime, ctime)
	if err != nil {
		return nil, err
	}
	// Fix the times; see the **Note** in the Put doc comment.
	newFile.Ctime = ctime
	newFile.Mtime = mtime
	// Upload succeeded — clear the saved progress.
	base.SaveUploadProgress(d, nil, d.AccessToken, contentMd5)
	return fileToObj(newFile), nil
}
393
394
// precreate 执行预上传操作,支持首次上传和 uploadid 过期重试
395
func (d *BaiduNetdisk) precreate(ctx context.Context, path string, streamSize int64, blockListStr, contentMd5, sliceMd5 string, ctime, mtime int64) (*PrecreateResp, error) {
396
params := map[string]string{"method": "precreate"}
397
form := map[string]string{
398
"path": path,
399
"size": strconv.FormatInt(streamSize, 10),
400
"isdir": "0",
401
"autoinit": "1",
402
"rtype": "3",
403
"block_list": blockListStr,
404
}
405
406
// 只有在首次上传时才包含 content-md5 和 slice-md5
407
if contentMd5 != "" && sliceMd5 != "" {
408
form["content-md5"] = contentMd5
409
form["slice-md5"] = sliceMd5
410
}
411
412
joinTime(form, ctime, mtime)
413
414
var precreateResp PrecreateResp
415
_, err := d.postForm("/xpan/file", params, form, &precreateResp)
416
if err != nil {
417
return nil, err
418
}
419
420
// 修复时间,具体原因见 Put 方法注释的 **注意**
421
if precreateResp.ReturnType == 2 {
422
precreateResp.File.Ctime = ctime
423
precreateResp.File.Mtime = mtime
424
}
425
426
return &precreateResp, nil
427
}
428
429
func (d *BaiduNetdisk) uploadSlice(ctx context.Context, uploadUrl string, params map[string]string, fileName string, file io.Reader) error {
430
res, err := d.upClient.R().
431
SetContext(ctx).
432
SetQueryParams(params).
433
SetFileReader("file", fileName, file).
434
Post(uploadUrl + "/rest/2.0/pcs/superfile2")
435
if err != nil {
436
return err
437
}
438
log.Debugln(res.RawResponse.Status + res.String())
439
errCode := utils.Json.Get(res.Body(), "error_code").ToInt()
440
errNo := utils.Json.Get(res.Body(), "errno").ToInt()
441
respStr := res.String()
442
lower := strings.ToLower(respStr)
443
if strings.Contains(lower, "uploadid") &&
444
(strings.Contains(lower, "invalid") || strings.Contains(lower, "expired") || strings.Contains(lower, "not found")) {
445
return ErrUploadIDExpired
446
}
447
448
if errCode != 0 || errNo != 0 {
449
return errs.NewErr(errs.StreamIncomplete, "error uploading to baidu, response=%s", res.String())
450
}
451
return nil
452
}
453
454
// Compile-time assertion that *BaiduNetdisk satisfies driver.Driver.
var _ driver.Driver = (*BaiduNetdisk)(nil)
455
456