Commit bdca8b49 authored by wizardforcel

Introduce thread pool

Parent f2bf5bf3
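This commit replaces the hand-rolled thread management in upload_handle and download_handle (threading.Thread lists, a Semaphore, and a terminate Event) with concurrent.futures.ThreadPoolExecutor, moving the per-block work into module-level tr_upload / tr_download workers guarded by a lock and a shared succ flag. Below is a minimal, self-contained sketch of the submit/result pattern the new code follows; upload_block and blocks are made-up placeholders, not functions from this repository.

    # Minimal sketch of the ThreadPoolExecutor pattern adopted by this commit.
    # upload_block and blocks are hypothetical stand-ins for illustration only.
    import threading
    from concurrent.futures import ThreadPoolExecutor

    lock = threading.Lock()   # guards shared state, like the module-level lock below
    results = []

    def upload_block(i, block):
        # Stand-in for per-block work; the real workers call api.image_upload(...)
        with lock:
            results.append((i, len(block)))

    blocks = [b"a" * 4, b"b" * 8, b"c" * 2]
    pool = ThreadPoolExecutor(4)                 # like ThreadPoolExecutor(args.thread)
    futures = [pool.submit(upload_block, i, b) for i, b in enumerate(blocks)]
    for f in futures:
        f.result()    # waits for completion and re-raises any worker exception
    pool.shutdown()
    print(sorted(results))   # [(0, 4), (1, 8), (2, 2)]

Calling result() on every future both waits for the workers and surfaces their exceptions, which is what lets the new code drop the Semaphore/Event bookkeeping used by the removed version.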
@@ -16,6 +16,7 @@ import threading
import time
import traceback
import types
from concurrent.futures import ThreadPoolExecutor
from BiliDriveEx import __version__
from BiliDriveEx.bilibili import Bilibili
from BiliDriveEx.encoder import Encoder
@@ -24,6 +25,10 @@ from BiliDriveEx.util import *
encoder = Encoder()
api = Bilibili()
succ = True
nblocks = 0
lock = threading.Lock()
def fetch_meta(s):
url = api.meta2real(s)
if not url: return None
@@ -52,49 +57,29 @@ def userinfo_handle(args):
info = api.get_user_info()
if info: log(info)
else: log("用户未登录")
def tr_upload(i, block, block_dict):
global succ
if not succ: return
enco_block = encoder.encode(block)
r = api.image_upload(enco_block)
if r['code'] == 0:
url = r['data']['image_url']
with lock:
block_dict.update({
'url': url,
'size': len(block),
'sha1': calc_sha1(block),
})
print(f'分块{i + 1}/{nblocks}上传完毕')
else:
print(f"分块{i + 1}/{nblocks}上传失败:{r.get('message')}")
succ = False
def upload_handle(args):
def core(index, block):
try:
block_sha1 = calc_sha1(block)
full_block = encoder.encode(block)
full_block_sha1 = calc_sha1(full_block)
url = api.exist(full_block_sha1)
if url:
log(f"分块{index + 1}/{block_num}上传完毕")
block_dict[index] = {
'url': url,
'size': len(block),
'sha1': block_sha1,
}
else:
# log(f"分块{index + 1}/{block_num}开始上传")
for _ in range(10):
if terminate_flag.is_set():
return
response = api.image_upload(full_block)
if response:
if response['code'] == 0:
url = response['data']['image_url']
log(f"分块{index + 1}/{block_num}上传完毕")
block_dict[index] = {
'url': url,
'size': len(block),
'sha1': block_sha1,
}
return
elif response['code'] == -4:
terminate_flag.set()
log(f"分块{index + 1}/{block_num}{_ + 1}次上传失败, 请重新登录")
return
log(f"分块{index + 1}/{block_num}{_ + 1}次上传失败")
else:
terminate_flag.set()
except:
terminate_flag.set()
traceback.print_exc()
finally:
done_flag.release()
global succ
global nblocks
start_time = time.time()
file_name = args.file
@@ -118,133 +103,99 @@ def upload_handle(args):
return
log(f"线程数: {args.thread}")
done_flag = threading.Semaphore(0)
terminate_flag = threading.Event()
thread_pool = []
block_dict = {}
block_num = math.ceil(os.path.getsize(file_name) / (args.block_size * 1024 * 1024))
for index, block in enumerate(read_in_chunk(file_name, size=args.block_size * 1024 * 1024)):
if len(thread_pool) >= args.thread:
done_flag.acquire()
if not terminate_flag.is_set():
thread_pool.append(threading.Thread(target=core, args=(index, block)))
thread_pool[-1].start()
else:
log("已终止上传, 等待线程回收")
break
for thread in thread_pool:
thread.join()
if terminate_flag.is_set():
return
succ = True
nblocks = math.ceil(os.path.getsize(file_name) / (args.block_size * 1024 * 1024))
block_dicts = [{} for _ in range(nblocks)]
trpool = ThreadPoolExecutor(args.thread)
hdls = []
blocks = read_in_chunk(file_name, size=args.block_size * 1024 * 1024)
for i, block in enumerate(blocks):
hdl = trpool.submit(tr_upload, i, block, block_dicts[i])
hdls.append(hdl)
for h in hdls: h.result()
if not succ: return
sha1 = calc_sha1(read_in_chunk(file_name))
meta_dict = {
'time': int(time.time()),
'filename': os.path.basename(file_name),
'size': os.path.getsize(file_name),
'sha1': sha1,
'block': [block_dict[i] for i in range(len(block_dict))],
'block': block_dicts,
}
meta = json.dumps(meta_dict, ensure_ascii=False).encode("utf-8")
full_meta = encoder.encode(meta)
for _ in range(10):
response = api.image_upload(full_meta)
if response and response['code'] == 0:
url = response['data']['image_url']
log("元数据上传完毕")
log(f"{meta_dict['filename']} ({size_string(meta_dict['size'])}) 上传完毕, 用时{time.time() - start_time:.1f}秒, 平均速度{size_string(meta_dict['size'] / (time.time() - start_time))}/s")
log(f"META URL -> {api.real2meta(url)}")
write_history(first_4mb_sha1, meta_dict, url)
return url
log(f"元数据第{_ + 1}次上传失败")
r = api.image_upload(full_meta)
if r['code'] == 0:
url = r['data']['image_url']
log("元数据上传完毕")
log(f"{meta_dict['filename']} ({size_string(meta_dict['size'])}) 上传完毕, 用时{time.time() - start_time:.1f}秒, 平均速度{size_string(meta_dict['size'] / (time.time() - start_time))}/s")
log(f"META URL -> {api.real2meta(url)}")
write_history(first_4mb_sha1, meta_dict, url)
return url
else:
log(f"元数据上传失败:{r.get('message')}")
return
def download_handle(args):
def core(index, block_dict):
try:
# log(f"分块{index + 1}/{len(meta_dict['block'])}开始下载")
for _ in range(10):
if terminate_flag.is_set():
return
block = image_download(block_dict['url'])
if block:
block = encoder.decode(block)
if calc_sha1(block) == block_dict['sha1']:
file_lock.acquire()
f.seek(block_offset(index))
f.write(block)
file_lock.release()
log(f"分块{index + 1}/{len(meta_dict['block'])}下载完毕")
return
else:
log(f"分块{index + 1}/{len(meta_dict['block'])}校验未通过")
terminate_flag.set()
else:
log(f"分块{index + 1}/{len(meta_dict['block'])}{_ + 1}次下载失败")
terminate_flag.set()
except:
terminate_flag.set()
traceback.print_exc()
finally:
done_flag.release()
def block_offset(index):
return sum(meta_dict['block'][i]['size'] for i in range(index))
def tr_download(i, block_dict, f, offset):
global succ
def is_overwritable(file_name):
if args.force:
return True
else:
return (input("文件已存在, 是否覆盖? [y/N] ") in ["y", "Y"])
if not succ: return
url = block_dict['url']
block = image_download(url)
if not block:
log(f"分块{i + 1}/{nblocks}下载失败")
succ = False
return
block = encoder.decode(block)
if calc_sha1(block) == block_dict['sha1']:
with lock:
f.seek(offset)
f.write(block)
print(f"分块{i + 1}/{nblocks}下载完毕")
else:
print(f"分块{i + 1}/{nblocks}校验未通过")
succ = False
def download_handle(args):
global succ
global nblocks
start_time = time.time()
meta_dict = fetch_meta(args.meta)
if meta_dict:
file_name = args.file if args.file else meta_dict['filename']
log(f"下载: {os.path.basename(file_name)} ({size_string(meta_dict['size'])}), 共有{len(meta_dict['block'])}个分块, 上传于{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(meta_dict['time']))}")
else:
if not meta_dict:
log("元数据解析失败")
return
log(f"线程数: {args.thread}")
download_block_list = []
file_name = args.file if args.file else meta_dict['filename']
log(f"下载: {os.path.basename(file_name)} ({size_string(meta_dict['size'])}), 共有{len(meta_dict['block'])}个分块, 上传于{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(meta_dict['time']))}")
if os.path.exists(file_name):
if os.path.getsize(file_name) == meta_dict['size'] and calc_sha1(read_in_chunk(file_name)) == meta_dict['sha1']:
log("文件已存在, 且与服务器端内容一致")
return file_name
elif is_overwritable(file_name):
with open(file_name, "rb") as f:
for index, block_dict in enumerate(meta_dict['block']):
f.seek(block_offset(index))
if calc_sha1(f.read(block_dict['size'])) == block_dict['sha1']:
# log(f"分块{index + 1}/{len(meta_dict['block'])}校验通过")
pass
else:
# log(f"分块{index + 1}/{len(meta_dict['block'])}校验未通过")
download_block_list.append(index)
log(f"{len(download_block_list)}/{len(meta_dict['block'])}个分块待下载")
else:
return
else:
download_block_list = list(range(len(meta_dict['block'])))
done_flag = threading.Semaphore(0)
terminate_flag = threading.Event()
file_lock = threading.Lock()
thread_pool = []
with open(file_name, "r+b" if os.path.exists(file_name) else "wb") as f:
for index in download_block_list:
if len(thread_pool) >= args.thread:
done_flag.acquire()
if not terminate_flag.is_set():
thread_pool.append(threading.Thread(target=core, args=(index, meta_dict['block'][index])))
thread_pool[-1].start()
else:
log("已终止下载, 等待线程回收")
break
for thread in thread_pool:
thread.join()
if terminate_flag.is_set():
if not args.force and not ask_overwrite():
return
log(f"线程数: {args.thread}")
succ = True
nblocks = len(meta_dict['block'])
trpool = ThreadPoolExecutor(args.thread)
hdls = []
mode = "r+b" if os.path.exists(file_name) else "wb"
with open(file_name, mode) as f:
for i in range(nblocks):
offset = block_offset(meta_dict, i)
hdl = trpool.submit(tr_download, i, meta_dict['block'][i], f, offset)
hdls.append(hdl)
for h in hdls: h.result()
if not succ: return
f.truncate(sum(block['size'] for block in meta_dict['block']))
log(f"{os.path.basename(file_name)} ({size_string(meta_dict['size'])}) 下载完毕, 用时{time.time() - start_time:.1f}秒, 平均速度{size_string(meta_dict['size'] / (time.time() - start_time))}/s")
sha1 = calc_sha1(read_in_chunk(file_name))
if sha1 == meta_dict['sha1']:
@@ -72,7 +72,7 @@ def read_in_chunk(fname, size=4 * 1024 * 1024, cnt=-1):
def log(message):
print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {message}")
def request_retry(method, url, retry=5, **kwargs):
def request_retry(method, url, retry=10, **kwargs):
kwargs.setdefault('timeout', 10)
for i in range(retry):
try:
@@ -80,8 +80,8 @@ def request_retry(method, url, retry=5, **kwargs):
except Exception as ex:
if i == retry - 1: raise ex
get_retry = lambda url, retry=5, **kwargs: request_retry('GET', url, retry, **kwargs)
post_retry = lambda url, retry=5, **kwargs: request_retry('POST', url, retry, **kwargs)
get_retry = lambda url, retry=10, **kwargs: request_retry('GET', url, retry, **kwargs)
post_retry = lambda url, retry=10, **kwargs: request_retry('POST', url, retry, **kwargs)
def print_meta(meta_dict):
print(f"文件名: {meta_dict['filename']}")
@@ -91,3 +91,9 @@ def print_meta(meta_dict):
print(f"分块数: {len(meta_dict['block'])}")
for index, block_dict in enumerate(meta_dict['block']):
print(f"分块{index + 1} ({size_string(block_dict['size'])}) URL: {block_dict['url']}")
def block_offset(meta_dict, i):
return sum(meta_dict['block'][j]['size'] for j in range(i))
def ask_overwrite():
return (input(f"文件已存在, 是否覆盖? [y/N] ") in ["y", "Y"])
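A quick illustration of the new block_offset helper: the byte offset of block i is the prefix sum of the sizes of the blocks before it, which is how tr_download knows where to seek before writing. The metadata below is a made-up example, not real output from the service.

    # Made-up metadata, only to show block_offset's prefix-sum behaviour.
    def block_offset(meta_dict, i):
        return sum(meta_dict['block'][j]['size'] for j in range(i))

    meta = {'block': [{'size': 4}, {'size': 4}, {'size': 2}]}
    print([block_offset(meta, i) for i in range(3)])   # [0, 4, 8]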