|
|
本帖最后由 fish313331 于 2025-11-21 20:15 编辑
先按照這篇完整教學安裝A卡版本的lada
https://javpcn.com/forum.php?mod ... 1327&extra=page%3D1
這次發現問題出在AMD 放出的Rocm pytorch目前使用運算加速器時有bug,跑不出顯卡該有的實力
真正正確的解法是等AMD更新
然而這邊提出一個暫時折衷的解法,原理是關閉有問題的加速庫
並且將原本模型進行的fp32運算降低精度為fp16計算,這麼做理論上會降低計算結果的精度
但是速度會快許多
修改的地方有兩個
1.lada/cli/main.py
開頭import torch下面加上
- torch.backends.cudnn.enabled = False
- print(f"enabled cudnn :{torch.backends.cudnn.enabled}")
复制代码
如果是2025/11/15後安裝lada的,新版中預設使用FP16,所以不需要修改第二步了
再次重申,第一步修改原因是AMD的AI加速底層軟體還沒幫9000系開發完,暫時先關閉
以後弄好了兩步驟都不需要,但目前還得等
2. lada/basicvrpp/inference.py 修改 def inference(model, video: list, device, max_frames=-1)函式:
def inference(model, video: list, device, max_frames=-1):
input_frame_count = len(video)
input_frame_shape = video[0].shape
if device and type(device) == str:
device = torch.device(device)
with torch.no_grad():
result = []
input = torch.stack(img2tensor(video, bgr2rgb=False, float32=True), dim=0)
input = torch.unsqueeze(input, dim=0) # TCHW -> BTCHW
input = input.to(device)
#將模型和輸入的精度降低為FP16
model = model.half()
input = input.half()
if max_frames > 0:
for i in range(0, input.shape[1], max_frames):
output = model(inputs=input[:, i:i + max_frames])
result.append(output)
result = torch.cat(result, dim=1)
else:
result = model(inputs=input)
result = torch.squeeze(result, dim=0) # BTCHW -> TCHW
result = list(torch.unbind(result, 0))
#將結果格式擴充回FP32
result = [r.float() for r in result]
output = tensor2img(result, rgb2bgr=False, out_type=np.uint8, min_max=(0, 1))
output_frame_count = len(output)
output_frame_shape = output[0].shape
assert input_frame_count == output_frame_count and input_frame_shape == output_frame_shape
return output
|
|