This post covers the case where a triton + SageAttention setup that worked fine before stops working after switching to a 50-series graphics card. If you have never installed these before, follow a regular installation guide first.
Prerequisites:
1. SageAttention worked normally before, but stopped running after you changed to a 50-series card. If you have never installed it before, this guide does not apply to you.
2. ComfyUI itself can still generate images normally.
This means you need to have PyTorch 2.7 with cu128 installed first and be able to generate images (a quick check follows this list); see here:
https://github.com/comfyanonymous/ComfyUI/discussions/6643
3. I am using the Chinese ComfyUI package (the 秋葉 integrated package). For all the commands below, replace the Python path with the path to your own ComfyUI embedded Python (python_embeded).
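A quick way to verify the PyTorch side of these prerequisites (a minimal check using the path from this guide; your version strings may differ slightly, but it should report a 2.7.x+cu128 build and detect the 50-series card):
E:\ComfyUI-aki-v1.6\python\python.exe -c "import torch; print(torch.__version__, torch.version.cuda, torch.cuda.get_device_name(0))"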
Steps:
1. Delete the previously installed triton and SageAttention packages from E:\ComfyUI-aki-v1.6\python\Lib\site-packages.
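If the old copies were installed with pip rather than dropped in manually, uninstalling them from the embedded Python should have the same effect (a sketch; the exact package names on your system may vary, e.g. triton vs. triton-windows):
E:\ComfyUI-aki-v1.6\python\python.exe -m pip uninstall -y triton triton-windows sageattention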
2. Install triton 3.3
E:\ComfyUI-aki-v1.6\python\python.exe -m pip install -U --pre triton-windows
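You can quickly confirm the wheel was picked up before running the full kernel test in the next step (assuming the install above succeeded; it should print a 3.3.x version):
E:\ComfyUI-aki-v1.6\python\python.exe -c "import triton; print(triton.__version__)"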
3. Create a new script file named test_triton.py with the following contents:
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one BLOCK_SIZE-wide slice of the tensors.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    tl.store(output_ptr + offsets, output, mask=mask)

def add(x: torch.Tensor, y: torch.Tensor):
    output = torch.empty_like(x)
    n_elements = output.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output

# Compare the Triton kernel's result against plain PyTorch addition.
a = torch.rand(3, device="cuda")
b = a + a
b_compiled = add(a, a)
print(b_compiled - b)
print("If you see tensor([0., 0., 0.], device='cuda:0'), then it works")
Run in the directory where test_triton.py is saved:
E:\ComfyUI-aki-v1.6\python\python.exe .\test_triton.py
The following output indicates a successful installation:
tensor([0., 0., 0.], device='cuda:0')
If you see tensor([0., 0., 0.], device='cuda:0'), then it works
4. Clone the SageAttention repository
Go to the E:\ComfyUI-aki-v1.6 directory (any directory will do; this is just the example used here).
Open cmd:
git clone https://github.com/thu-ml/SageAttention.git
5. Go into the SageAttention directory and edit setup.py
Replace its entire contents with the following:
"""
Copyright (c) 2024 by SageAttention team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import subprocess
from packaging.version import parse, Version
from typing import List, Set
import warnings
from setuptools import setup, find_packages
import torch
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
HAS_SM80 = False
HAS_SM86 = False
HAS_SM89 = False
HAS_SM90 = False
# Supported NVIDIA GPU architectures.
# SUPPORTED_ARCHS = {"8.0", "8.6", "8.9", "9.0"}
# Compiler flags.
CXX_FLAGS = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"]
NVCC_FLAGS = [
    "-O3",
    "-std=c++17",
    "-U__CUDA_NO_HALF_OPERATORS__",
    "-U__CUDA_NO_HALF_CONVERSIONS__",
    "--use_fast_math",
    "--threads=8",
    "-Xptxas=-v",
    "-diag-suppress=174",  # suppress the specific warning
]
ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0
CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
# if CUDA_HOME is None:
# raise RuntimeError(
# "Cannot find CUDA_HOME. CUDA must be available to build the package.")
# def get_nvcc_cuda_version(cuda_dir: str) -> Version:
# """Get the CUDA version from nvcc.
# Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
# """
# nvcc_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
# universal_newlines=True)
# output = nvcc_output.split()
# release_idx = output.index("release") + 1
# nvcc_cuda_version = parse(output[release_idx].split(",")[0])
# return nvcc_cuda_version
# def get_torch_arch_list() -> Set[str]:
# # TORCH_CUDA_ARCH_LIST can have one or more architectures,
# # e.g. "8.0" or "7.5,8.0,8.6+PTX". Here, the "8.6+PTX" option asks the
# # compiler to additionally include PTX code that can be runtime-compiled
# # and executed on the 8.6 or newer architectures. While the PTX code will
# # not give the best performance on the newer architectures, it provides
# # forward compatibility.
# env_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
# if env_arch_list is None:
# return set()
# # List are separated by ; or space.
# torch_arch_list = set(env_arch_list.replace(" ", ";").split(";"))
# if not torch_arch_list:
# return set()
# # Filter out the invalid architectures and print a warning.
# valid_archs = SUPPORTED_ARCHS.union({s + "+PTX" for s in SUPPORTED_ARCHS})
# arch_list = torch_arch_list.intersection(valid_archs)
# # If none of the specified architectures are valid, raise an error.
# if not arch_list:
# raise RuntimeError(
# "None of the CUDA architectures in TORCH_CUDA_ARCH_LIST
env "
# f"variable ({env_arch_list}) is supported. "
# f"Supported CUDA architectures are: {valid_archs}.")
# invalid_arch_list = torch_arch_list - valid_archs
# if invalid_arch_list:
# warnings.warn(
# f"Unsupported CUDA architectures ({invalid_arch_list}) are "
# "excluded from the TORCH_CUDA_ARCH_LIST
env variable "
# f"({env_arch_list}). Supported CUDA architectures are: "
# f"{valid_archs}.")
# return arch_list
# # First, check the TORCH_CUDA_ARCH_LIST environment variable.
# compute_capabilities = get_torch_arch_list()
# if not compute_capabilities:
# # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
# # GPUs on the current machine.
# device_count = torch.cuda.device_count()
# for i in range(device_count):
# major, minor = torch.cuda.get_device_capability(i)
# if major < 8:
# raise RuntimeError(
# "GPUs with compute capability below 8.0 are not supported.")
# compute_capabilities.add(f"{major}.{minor}")
# nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
# if not compute_capabilities:
# raise RuntimeError("No GPUs found. Please specify the target GPU architectures or build on a machine with GPUs.")
# # Validate the NVCC CUDA version.
# if nvcc_cuda_version < Version("12.0"):
# raise RuntimeError("CUDA 12.0 or higher is required to build the package.")
# if nvcc_cuda_version < Version("12.4") and any(cc.startswith("8.9") for cc in compute_capabilities):
# raise RuntimeError(
# "CUDA 12.4 or higher is required for compute capability 8.9.")
# if nvcc_cuda_version < Version("12.3") and any(cc.startswith("9.0") for cc in compute_capabilities):
# if any(cc.startswith("9.0") for cc in compute_capabilities):
# raise RuntimeError(
# "CUDA 12.3 or higher is required for compute capability 9.0.")
# Add target compute capabilities to NVCC flags.
# for capability in compute_capabilities:
# num = capability[0] + capability[2]
# if num == "80":
# HAS_SM80 = True
# elif num == "86":
# HAS_SM86 = True
# elif num == "89":
# HAS_SM89 = True
# elif num == "90":
# HAS_SM90 = True
# num = num + "a" # convert sm90 to sm9a
# NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
# if capability.endswith("+PTX"):
# NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
NVCC_FLAGS += ["-gencode", f"arch=compute_120,code=sm_120"]
ext_modules = []
qattn_extension = CUDAExtension(
    name="sageattention._qattn_sm80",
    sources=[
        "csrc/qattn/pybind_sm80.cpp",
        "csrc/qattn/qk_int_sv_f16_cuda_sm80.cu",
    ],
    extra_compile_args={
        "cxx": CXX_FLAGS,
        "nvcc": NVCC_FLAGS,
    },
)
ext_modules.append(qattn_extension)
qattn_extension = CUDAExtension(
    name="sageattention._qattn_sm89",
    sources=[
        "csrc/qattn/pybind_sm89.cpp",
        "csrc/qattn/qk_int_sv_f8_cuda_sm89.cu",
    ],
    extra_compile_args={
        "cxx": CXX_FLAGS,
        "nvcc": NVCC_FLAGS,
    },
)
ext_modules.append(qattn_extension)
# Fused kernels.
fused_extension = CUDAExtension(
    name="sageattention._fused",
    sources=["csrc/fused/pybind.cpp", "csrc/fused/fused.cu"],
    extra_compile_args={
        "cxx": CXX_FLAGS,
        "nvcc": NVCC_FLAGS,
    },
)
ext_modules.append(fused_extension)
setup(
    name='sageattention',
    version='2.1.0',
    author='SageAttention team',
    license='Apache 2.0 License',
    description='Accurate and efficient plug-and-play low-bit attention.',
    long_description=open('README.md', encoding='utf-8').read(),
    long_description_content_type='text/markdown',
    url='https://github.com/thu-ml/SageAttention',
    packages=find_packages(),
    python_requires='>=3.9',
    ext_modules=ext_modules,
    cmdclass={"build_ext": BuildExtension},
)
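The key change compared with the upstream setup.py is that the automatic architecture detection is commented out and replaced with a hard-coded "-gencode arch=compute_120,code=sm_120" flag, because the RTX 50-series reports CUDA compute capability 12.0, which the original detection does not handle. You can confirm what your card reports with a small check (run it with the same embedded Python):
import torch
print(torch.cuda.get_device_capability(0))  # a 50-series card is expected to report (12, 0)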
6. Compile SageAttention
Open cmd in the SageAttention directory and run:
E:\ComfyUI-aki-v1.6\python\python.exe -m pip install -e .
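Before going back to ComfyUI, you can try a small smoke test of the freshly built package (a rough sketch based on the sageattn API shown in the SageAttention README; the shapes below are only illustrative, and generating an image in ComfyUI remains the real test):
import torch
from sageattention import sageattn  # API as documented in the SageAttention README

# Illustrative shapes for tensor_layout="HND": (batch, heads, seq_len, head_dim)
q = torch.randn(1, 8, 1024, 64, dtype=torch.float16, device="cuda")
k = torch.randn(1, 8, 1024, 64, dtype=torch.float16, device="cuda")
v = torch.randn(1, 8, 1024, 64, dtype=torch.float16, device="cuda")
out = sageattn(q, k, v, tensor_layout="HND", is_causal=False)
print(out.shape)  # expected: torch.Size([1, 8, 1024, 64])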
7. Modify the vararg_kernel.py file
Search for the file vararg_kernel.py in your ComfyUI installation; it is usually in E:\ComfyUI-aki-v1.6\python\Lib\site-packages\xformers\triton.
Edit this file and change
jitted_fn.src = new_src
to
jitted_fn._unsafe_update_src(new_src)
jitted_fn.hash = None
(Newer Triton versions no longer allow assigning to jitted_fn.src directly, so the kernel source has to be updated through _unsafe_update_src and the cached hash cleared so it gets recomputed.)
8. Run ComfyUI and test
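If your launcher does not already enable it, SageAttention can also be switched on globally via ComfyUI's --use-sage-attention command-line flag (an example assuming you start main.py directly with the embedded Python; launcher packages may expose this flag in their own settings UI instead):
E:\ComfyUI-aki-v1.6\python\python.exe E:\ComfyUI-aki-v1.6\main.py --use-sage-attention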
9. References
https://github.com/thu-ml/SageAttention
https://github.com/thu-ml/SageAttention/issues/107
https://github.com/thu-ml/SageAttention/issues/122