{"id":13439509,"url":"https://github.com/DefTruth/CUDA-Learn-Notes","last_synced_at":"2025-03-20T08:31:28.243Z","repository":{"id":103961991,"uuid":"579300230","full_name":"DefTruth/CUDA-Learn-Notes","owner":"DefTruth","description":"📚200+ Tensor/CUDA Cores Kernels, ⚡️flash-attn-mma, ⚡️hgemm with WMMA, MMA and CuTe (98%~100% TFLOPS of cuBLAS/FA2 🎉🎉).","archived":false,"fork":false,"pushed_at":"2025-03-04T04:14:08.000Z","size":231950,"stargazers_count":2844,"open_issues_count":4,"forks_count":293,"subscribers_count":22,"default_branch":"main","last_synced_at":"2025-03-14T10:41:50.840Z","etag":null,"topics":["cuda","cuda-kernels","cuda-programming","cuda-toolkit","cudnn","cutlass","flash-attention","flash-mla","gemm","gemv","hgemm"],"latest_commit_sha":null,"homepage":"","language":"Cuda","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"gpl-3.0","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/DefTruth.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2022-12-17T08:19:52.000Z","updated_at":"2025-03-14T10:21:11.000Z","dependencies_parsed_at":null,"dependency_job_id":"32340c72-0188-407e-a91e-4c255e1bac87","html_url":"https://github.com/DefTruth/CUDA-Learn-Notes","commit_stats":{"total_commits":431,"total_committers":6,"mean_commits":71.83333333333333,"dds":0.02784222737819031,"last_synced_commit":"f4d8d91671dd9b949bd705733ef48f3fc4367023"},"previous_names":["deftruth/learn-optimize-cuda-simd","deftruth/learn-cuda-optimize","deftruth/cuda-learn-notes","deftruth/cuda-learn-note"],"tags_count":51,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/DefTruth%2FCUDA-Learn-Notes","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/DefTruth%2FCUDA-Learn-Notes/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/DefTruth%2FCUDA-Learn-Notes/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/DefTruth%2FCUDA-Learn-Notes/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/DefTruth","download_url":"https://codeload.github.com/DefTruth/CUDA-Learn-Notes/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":244577768,"owners_count":20475360,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["cuda","cuda-kernels","cuda-programming","cuda-toolkit","cudnn","cutlass","flash-attention","flash-mla","gemm","gemv","hgemm"],"created_at":"2024-07-31T03:01:14.502Z","updated_at":"2025-03-20T08:31:23.234Z","avatar_url":"https://github.com/DefTruth.png","language":"Cuda","readme":"![cuda-learn-note](https://github.com/DefTruth/CUDA-Learn-Note/assets/31974251/882271fe-ab60-4b0e-9440-2e0fa3c0fb6f)   \n\n\u003cdiv 
align='center'\u003e\n  \u003cimg src=https://img.shields.io/badge/Language-CUDA-brightgreen.svg \u003e\n  \u003cimg src=https://img.shields.io/github/watchers/DefTruth/cuda-learn-note?color=9cc \u003e\n  \u003cimg src=https://img.shields.io/github/forks/DefTruth/cuda-learn-note.svg?style=social \u003e\n  \u003cimg src=https://img.shields.io/github/stars/DefTruth/cuda-learn-note.svg?style=social \u003e\n  \u003cimg src=https://img.shields.io/badge/Release-v0.6-brightgreen.svg \u003e\n  \u003cimg src=https://img.shields.io/badge/License-GPLv3.0-turquoise.svg \u003e\n \u003c/div\u003e   \n\n📒**CUDA-Learn-Notes**: 🎉CUDA/C++ 笔记 / 大模型手撕CUDA / 技术博客，更新随缘: flash_attn、sgemm、sgemv、warp reduce、block reduce、dot、elementwise、softmax、layernorm、rmsnorm、histogram、relu、sigmoid etc. 更多资料，请关注本人知乎技术博客: [DefTruth on ZhiHu](https://www.zhihu.com/people/qyjdef/posts)\n\n\n\u003e 想要我的财宝吗？想要的话可以全部给你，去找吧！我把所有财宝都放在那里！—— **哥尔·D·罗杰**\n\n## News 👇👇\nMost of my time now is focused on **LLM/VLM** Inference. Please check 📖[Awesome-LLM-Inference](https://github.com/DefTruth/Awesome-LLM-Inference)  ![](https://img.shields.io/github/stars/DefTruth/Awesome-LLM-Inference.svg?style=social), 📖[Awesome-SD-Inference](https://github.com/DefTruth/Awesome-SD-Inference)  ![](https://img.shields.io/github/stars/DefTruth/Awesome-SD-Inference.svg?style=social) and 📖[CUDA-Learn-Notes](https://github.com/DefTruth/CUDA-Learn-Notes)  ![](https://img.shields.io/github/stars/DefTruth/CUDA-Learn-Notes.svg?style=social) for more details.\n\n\n## 0x00 📖 博客目录\n\n\u003c!---\n### 📒 图解LLM推理优化\n\n- vLLM Automatic Prefix Caching: Prefix + Generated KV Caching\n\u003cimg width=\"1106\" alt=\"image\" src=\"https://github.com/user-attachments/assets/476c46c8-2427-4e6d-8808-ab168f2be101\"\u003e\n\n- vLLM Automatic Prefix Caching: Hash Prefix Tree (Part-1)\n\n\u003cimg width=\"1019\" alt=\"image\" src=\"https://github.com/user-attachments/assets/ee499627-529e-439e-9fb1-523d01c153b8\"\u003e\n\n- vLLM Automatic Prefix Caching: Hash Prefix Tree (Part-2)\n  \n\u003cimg width=\"983\" alt=\"image\" src=\"https://github.com/user-attachments/assets/9483d982-4314-4d49-a9a3-1b1dc93dbe10\"\u003e\n\n- vLLM Automatic Prefix Caching: CachedBlockAllocator   \n\u003cimg width=\"1046\" alt=\"image\" src=\"https://github.com/user-attachments/assets/60b7e82d-2668-4103-82d2-3a0ac289c3a3\"\u003e\n\n- vLLM Automatic Prefix Caching: Prefix Prefill Triton Kernel Tiling\n\n![prefix prefill](https://github.com/DefTruth/CUDA-Learn-Notes/assets/31974251/8e1fe25a-0697-408a-849a-6f0ea47012b0)\n\n\n- FlashAttenion V1/V2/V3: FlashAttenion V2 Kernel Tiling  \n\u003cimg width=\"1438\" alt=\"image\" src=\"https://github.com/user-attachments/assets/0c5e5125-586f-43fa-8e8b-e2c61c1afbbe\"\u003e\n\n- TensorRT-LLM: TensorRT-LLM部署调优-指北\n\n![image](https://github.com/user-attachments/assets/0c69d866-2a44-475f-8732-92e74d0133cc)\n--\u003e \n\n\u003cimg width=\"1438\" alt=\"image\" src=\"https://github.com/user-attachments/assets/0c5e5125-586f-43fa-8e8b-e2c61c1afbbe\"\u003e\n\n### 📒 大模型/多模态/SD 推理优化\n\n- [[InternLM/VL系列][万字]📒InternLM2/InternLM2.5/InternViT/InternVL1.5/InternVL2笔记: 核心点解析](https://zhuanlan.zhihu.com/p/702481058)\n- [[TensorRT-LLM][5w字]🔥TensorRT-LLM部署调优-指北](https://zhuanlan.zhihu.com/p/699333691)\n- [[KV Cache优化]🔥MQA/GQA/YOCO/CLA/MLKV笔记: 层内和层间KV Cache共享](https://zhuanlan.zhihu.com/p/697311739)\n- [[Prefill优化]🔥图解vLLM Prefix Prefill Triton Kernel](https://zhuanlan.zhihu.com/p/695799736)\n- [[Prefill优化][万字]🔥原理\u0026图解vLLM Automatic Prefix Cache(RadixAttention): 
首Token时延优化](https://zhuanlan.zhihu.com/p/693556044)\n- [[Attention优化][2w字]🔥原理\u0026图解: 从Online-Softmax到FlashAttention V1/V2/V3](https://zhuanlan.zhihu.com/p/668888063)\n- [[Decoding优化]🔥原理\u0026图解FlashDecoding/FlashDecoding++](https://zhuanlan.zhihu.com/p/696075602)\n- [[LLaVA系列]📒CLIP/LLaVA/LLaVA1.5/VILA笔记: 核心点解析](https://zhuanlan.zhihu.com/p/683137074)\n- [[Attention优化][万字]🔥TensorRT 9.2 MHA/Myelin Optimize vs FlashAttention-2 profile](https://zhuanlan.zhihu.com/p/678873216)\n- [[CUDA 12 PTX汇编]📒PRMT指令详解-通用模式](https://zhuanlan.zhihu.com/p/660630414)\n- [[CUDA 12 PTX汇编]📒LOP3指令详解](https://zhuanlan.zhihu.com/p/659741469)\n- [[LLM推理优化][3w字]🔥高频面试题汇总-大模型手撕CUDA](https://zhuanlan.zhihu.com/p/678903537)\n- [[LLM推理优化]🔥WINT8/4-(00): 通俗易懂讲解-快速反量化算法](https://zhuanlan.zhihu.com/p/657072856)\n- [[LLM推理优化]🔥WINT8/4-(01): PRMT指令详解及FasterTransformer源码解析](https://zhuanlan.zhihu.com/p/657070837)\n- [[LLM推理优化]🔥WINT8/4-(02): 快速反量化之INT8转BF16](https://zhuanlan.zhihu.com/p/657073159)\n- [[LLM推理优化]🔥WINT8/4-(03): LOP3指令详解及INT4转FP16/BF16分析](https://zhuanlan.zhihu.com/p/657073857)\n- [[LLM推理优化]🔥100+篇: 大模型推理各方向新发展整理](https://zhuanlan.zhihu.com/p/693680304)\n- [[LLM推理优化]🔥30+篇: LLM推理论文集-500页PDF💡](https://zhuanlan.zhihu.com/p/669777159)\n- [[LLM推理优化]🔥FlashDecoding++: 比FlashDecoding还要快！](https://zhuanlan.zhihu.com/p/665022589)\n- [[LLM推理优化]🔥速递：TensorRT-LLM开源，TensorRT 9.1 也来了🤓](https://zhuanlan.zhihu.com/p/662361469)\n- [[LLM推理优化]🔥20+篇: LLM推理论文集-300页PDF💡](https://zhuanlan.zhihu.com/p/658091768)\n- [[LLM推理优化]🔥PagedAttention论文新鲜出炉](https://zhuanlan.zhihu.com/p/617015570)\n\n### 📒 CV移动端/服务端 推理部署\n\n- [[推理部署]⚡️🔥覆盖云边端全场景，FastDeploy三行代码搞定150+ CV、NLP、Speech模型部署](https://zhuanlan.zhihu.com/p/581326442)\n- [[推理部署]💡如何在lite.ai.toolkit(3.5k+🔥stars)中增加您的模型？](https://zhuanlan.zhihu.com/p/523876625)\n- [[推理部署]🤓凑个热闹之 美团 YOLOv6 ORT/MNN/TNN/NCNN C++推理部署](https://zhuanlan.zhihu.com/p/533643238)\n- [[推理部署]🌔ONNX推理加速技术文档-杂记](https://zhuanlan.zhihu.com/p/524023964)\n- [[推理部署]👉Mac源码编译TensorFlow C++指北](https://zhuanlan.zhihu.com/p/524013615)\n- [[推理部署]👿1Mb!头部姿态估计: 来讲讲FSANet，一个小而美的模型(含ONNXRuntime/MNN C++实现)](https://zhuanlan.zhihu.com/p/447364201)\n- [[推理部署]🤓opencv+ffmpeg编译打包全解指南](https://zhuanlan.zhihu.com/p/472115312)\n- [[推理部署]🔧填坑: RobustVideoMatting(5k+🔥star)视频抠图静态ONNX模型转换](https://zhuanlan.zhihu.com/p/459088407)\n- [[推理部署]🔥190Kb!SSRNet年龄检测详细解读（含C++工程）](https://zhuanlan.zhihu.com/p/462762797)\n- [[推理部署]🔥MGMatting(CVPR2021)人像抠图C++应用记录](https://zhuanlan.zhihu.com/p/464732042)\n- [[推理部署]🍅🍅超准确人脸检测(带关键点)YOLO5Face C++工程详细记录](https://zhuanlan.zhihu.com/p/461878005)\n- [[推理部署]👋解决: ONNXRuntime(Python) GPU 部署配置记录](https://zhuanlan.zhihu.com/p/457484536)\n- [[推理部署]🍅记录SCRFD(CVPR2021)人脸检测C++工程化(含docker镜像)](https://zhuanlan.zhihu.com/p/455165568)\n- [[推理部署]👋野路子：记录一个解决onnx转ncnn时op不支持的trick](https://zhuanlan.zhihu.com/p/451446147)\n- [[推理部署]🔥升级版NanoDet-Plus MNN/TNN/NCNN/ONNXRuntime C++工程记录](https://zhuanlan.zhihu.com/p/450586647)\n- [[推理部署]📒超有用NCNN参考资料整理](https://zhuanlan.zhihu.com/p/449765328)\n- [[推理部署]📒超有用MNN参考资料整理](https://zhuanlan.zhihu.com/p/449761992)\n- [[推理部署]📒超有用TNN参考资料整理](https://zhuanlan.zhihu.com/p/449769615)\n- [[推理部署]📒超有用ONNX参考资料整理](https://zhuanlan.zhihu.com/p/449773663)\n- [[推理部署]📒超有用ONNX模型结构参考资料整理](https://zhuanlan.zhihu.com/p/449775926)\n- [[推理部署]📒超有用OpenCV-DNN参考资料整理](https://zhuanlan.zhihu.com/p/449778377)\n- [[推理部署]📒超有用Tensorflow C++工程化知识点](https://zhuanlan.zhihu.com/p/449788027)\n- [[推理部署]📒深度学习模型转换资料整理](https://zhuanlan.zhihu.com/p/449759361)\n- [[推理部署]🔥🔥超轻量级NanoDet MNN/TNN/NCNN/ONNXRuntime C++工程记录](https://zhuanlan.zhihu.com/p/443419387)\n- 
[[推理部署]🔥详细记录MGMatting(CVPR2021)🔥MNN、TNN和ONNXRuntime C++移植（长文警告!）](https://zhuanlan.zhihu.com/p/442949027)\n- [[推理部署]🔥YOLOX NCNN/MNN/TNN/ONNXRuntime C++工程简记](https://zhuanlan.zhihu.com/p/447364122)\n- [[推理部署]🔥手动修改YoloX的tnnproto记录-TNN C++](https://zhuanlan.zhihu.com/p/425668734)\n- [[推理部署]🔥🔥🔥 全网最详细 ONNXRuntime C++/Java/Python 资料！](https://zhuanlan.zhihu.com/p/414317269)\n- [[推理部署]🔥RobustVideoMatting🔥2021 ONNXRuntime C++工程化记录-实现篇](https://zhuanlan.zhihu.com/p/413280488)\n- [[推理部署]🔥RobustVideoMatting🔥2021最新视频抠图来了! C++ 工程化记录-应用篇](https://zhuanlan.zhihu.com/p/412491918)\n- [[推理部署]💡ONNXRuntime C++ CMake 工程分析及编译](https://zhuanlan.zhihu.com/p/411887386)\n- [[推理部署]🤓如何使用ONNXRuntime C++ API处理NCHW和NHWC输入？](https://zhuanlan.zhihu.com/p/524230808)\n- [[推理部署]💡tnn-convert搭建简记-YOLOP转TNN](https://zhuanlan.zhihu.com/p/431418709)\n- [[推理部署]💡YOLOP ONNXRuntime C++工程化记录](https://zhuanlan.zhihu.com/p/411651933)\n\n### 📒 C/C++/算法/技术随笔  \n\n- [[C++][CMake]👋超有用CMake参考资料整理](https://zhuanlan.zhihu.com/p/449779892)\n- [[C++][3W字]💡静态链接和静态库实践指北-原理篇](https://zhuanlan.zhihu.com/p/595527528)\n- [[C++]🤓Mac下C++内存检查指北(Valgrind VS Asan)](https://zhuanlan.zhihu.com/p/508470880)\n- [[技术随笔]🔥torchlm: 人脸关键点检测库](https://zhuanlan.zhihu.com/p/467211561)\n- [[技术随笔]📒200页PDF笔记: 《统计学习方法-李航: 笔记-从原理到实现-基于R》](https://zhuanlan.zhihu.com/p/684885595)\n- [[技术随笔]💡如何优雅地git clone和git submodule？](https://zhuanlan.zhihu.com/p/639136221)\n- [[技术随笔]📒人脸重建3D参考资料整理](https://zhuanlan.zhihu.com/p/524034741)\n- [[技术随笔]📒BlendShapes参考资料整理](https://zhuanlan.zhihu.com/p/524036145)\n- [[技术随笔]🛠🛠从源码安装Pytorch3D详细记录及学习资料](https://zhuanlan.zhihu.com/p/512347464)\n- [[技术随笔]🍅🍅200页:《统计学习方法：李航》笔记 -从原理到实现](https://zhuanlan.zhihu.com/p/461520847)\n\n\n## 0x01 📖 Kernel目录\n\u003cdiv id=\"kernellist\"\u003e\u003c/div\u003e  \n\n- [x] 📖 [sgemm_naive_f32_kernel](#sgemm)\n- [x] 📖 [sgemm_block_tile_k_tile_vec4_f32_kernel](#sgemm)\n- [x] 📖 [sgemv_k32_f32_kernel](#sgemv)\n- [x] 📖 [sgemv_k128_f32_kernel](#sgemv)\n- [x] 📖 [sgemv_k16_f32_kernel](#sgemv)\n- [x] 📖 [warp_reduce_sum/max_f32_kernel](#warpreduce)\n- [x] 📖 [block_reduce_sum/max_f32_kernel](#warpreduce)\n- [x] 📖 [block_all_reduce_f32_kernel](#blockallreduce)\n- [x] 📖 [block_all_reduce_vec4_f32_kernel](#blockallreduce)\n- [x] 📖 [dot_product_f32_kernel](#dot)\n- [x] 📖 [dot_product_vec4_f32_kernel](#dot)\n- [x] 📖 [elementwise_f32_kernel](#elementwise)\n- [x] 📖 [elementwise_vec4_f32_kernel](#elementwise)\n- [x] 📖 [histogram_i32_kernel](#histogram)\n- [x] 📖 [histogram_vec4_i32_kernel](#histogram)\n- [x] 📖 [softmax_f32_kernel (grid level memory fence)](#softmax)\n- [x] 📖 [softmax_vec4_f32_kernel (grid level memory fence)](#softmax)\n- [ ] 📖 [safe_softmax_f32_kernel (per token)](#softmax)\n- [x] 📖 [sigmoid_f32_kernel](#sigmoid)\n- [x] 📖 [sigmoid_vec4_f32_kernel](#sigmoid)\n- [ ] 📖 [safe_sigmoid_f32_kernel](#sigmoid)\n- [x] 📖 [relu_f32_kernel](#relu)\n- [x] 📖 [relu_vec4_f32_kernel](#relu)\n- [x] 📖 [layer_norm_f32_kernel (per token)](#layernorm)\n- [x] 📖 [layer_norm_vec4_f32_kernel (per token)](#layernorm)\n- [ ] 📖 [layer_norm_vec4_f16_kernel (per token)](#layernorm)\n- [x] 📖 [rms_norm_f32_kernel (per token)](#rmsnorm)\n- [x] 📖 [rms_norm_vec4_f32_kernel (per token)](#rmsnorm)\n- [ ] 📖 [rms_norm_vec4_f16_kernel (per token)](#rmsnorm)\n- [x] 📖 [flash_attn_1_fwd_f32_kernel](./flash_attn_1_fwd_f32.cu)\n- [ ] 📖 flash_attn_2_fwd_f32_kernel\n- [ ] 📖 flash_attn_2_fwd_f16_kernel\n- [ ] 📖 flash_attn_2_fwd_b16_kernel\n- [ ] 📖 flash_attn_2_fwd_f8_kernel\n- [ ] 📖 flash_attn_2_split_kv_f16_kernel\n- [ ] 📖 flash_attn_2_split_kv_b16_kernel\n- [ ] 📖 
flash_attn_2_split_kv_f8_kernel\n- [ ] 📖 online_softmax_f32_kernel\n- [ ] 📖 online_softmax_f16_kernel\n- [ ] 📖 online_softmax_b16_kernel\n- [ ] 📖 hgemm_f16_kernel\n- [ ] 📖 sgemm_dbuf_f32_kernel\n\n## 0x02 sgemm naive, sgemm + block-tile + k-tile + vec4  ([©️back👆🏻](#kernellist))  \n\u003cdiv id=\"sgemm\"\u003e\u003c/div\u003e  \n\n```c++\n#include \u003cstdio.h\u003e\n#include \u003cstdlib.h\u003e\n#include \u003cfloat.h\u003e\n#include \u003cvector\u003e\n#include \u003calgorithm\u003e\n#include \u003ccuda_runtime.h\u003e\n\n#define WARP_SIZE 32\n#define INT4(value) (reinterpret_cast\u003cint4*\u003e(\u0026(value))[0])\n#define FLOAT4(value) (reinterpret_cast\u003cfloat4*\u003e(\u0026(value))[0])\n\n// SGEMM: Block Tile + K Tile, with smem\n// Block Tile (BM, BN) + K Tile (BK=32)\n// grid((N + BN - 1) / BN, (M + BM - 1) / BM), block(BN, BM)\n// a: MxK, b: KxN, c: MxN, compute: c = a * b, all row major  \n__global__ void sgemm(float* a, float* b, float* c, int M, int N, int K) {\n  // [1] Block Tile: 32x32的block处理c上一块32x32的元素计算\n  // [2]     K Tile: 使用共享内存，并将K分块为BK大小的块\n  constexpr int BM = 32;\n  constexpr int BN = 32;\n  constexpr int BK = 32;\n  __shared__ float s_a[BM][BK], s_b[BK][BN]; \n\n  int bx = blockIdx.x;\n  int by = blockIdx.y;\n  int tx = threadIdx.x;\n  int ty = threadIdx.y;\n  int tid = threadIdx.y * blockDim.x + tx; // tid within the block\n  // load values to shared memory, 32x32 threads working together \n  // to fetch data along the row direction of a and b both for s_a \n  // and s_b 32x32x4x2=8KB, we use 32x32 threads within block to \n  // load 32x32 elements from global memory to shared memory, namely, \n  // each thread will load 1 element.\n  int load_smem_a_m = tid / 32; // 0~31, tid / 32, tid / BM, threadIdx.y\n  int load_smem_a_k = tid % 32; // 0~31, tid % 32, tid % BK, threadIdx.x\n  int load_smem_b_k = tid / 32; // 0~31, tid / 32, tid / BK, threadIdx.y\n  int load_smem_b_n = tid % 32; // 0~31, tid % 32, tid % BN, threadIdx.x\n  int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c\n  int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c\n  // if (load_gmem_a_m \u003e= M || load_gmem_b_n \u003e= N) return;\n  \n  float sum = 0.f;\n  for (int bk = 0; bk \u003c (K + BK - 1) / BK; ++bk) {\n    int load_gmem_a_k = bk * BK + load_smem_a_k;\n    int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;\n    s_a[load_smem_a_m][load_smem_a_k] = a[load_gmem_a_addr];\n    int load_gmem_b_k = bk * BK + load_smem_b_k;\n    int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;\n    s_b[load_smem_b_k][load_smem_b_n] = b[load_gmem_b_addr];\n    __syncthreads();\n    #pragma unroll\n    for (int k = 0; k \u003c BK; ++k) {\n      int comp_smem_a_m = load_smem_a_m;\n      int comp_smem_b_n = load_smem_b_n;\n      sum += s_a[comp_smem_a_m][k] * s_b[k][comp_smem_b_n];\n    }\n    __syncthreads();\n  }\n  int store_gmem_c_m = load_gmem_a_m;\n  int store_gmem_c_n = load_gmem_b_n;\n  int store_gmem_c_addr = store_gmem_c_m * N + store_gmem_c_n;\n  c[store_gmem_c_addr] = sum;\n}\n\n// SGEMM: Block Tile + Thread Tile + K Tile + Vec4, with smem\n// BK:TILE_K=8 BM=BN=128\n// TM=TN=8 增加计算密度 BM/TM=16 BN/TN=16\n// dim3 blockDim(BN/TN, BM/TM);\n// dim3 gridDim((N + BN - 1) / BN, (M + BM - 1) / BM)\n__global__ void sgemm_thread_tile_vec4(\n  float* a, float* b, float* c, int M, int N, int K) {\n  // [1]  Block Tile: 一个16x16的block处理C上大小为128X128的一个目标块\n  // [2] Thread Tile: 每个thread负责计算TM*TN(8*8)个元素，增加计算密度\n  // [3]      K Tile: 将K分块，每块BK大小，迭代(K+BK-1/BK)次，\n  
//                  每次计算TM*TN个元素各自的部分乘累加\n  // [4]   Vectorize: 减少load和store指令，使用float4\n  constexpr int BM = 128;\n  constexpr int BN = 128;\n  constexpr int BK = 8; \n  constexpr int TM = 8;\n  constexpr int TN = 8;\n\n  int bx = blockIdx.x;\n  int by = blockIdx.y;\n  int tx = threadIdx.x;\n  int ty = threadIdx.y;\n  int tid = threadIdx.y * blockDim.x + tx; // tid within the block\n  __shared__ float s_a[BM][BK], s_b[BK][BN]; // 2*128*8*4=8KB\n  \n  // 0. 先计算shared memory中的索引\n  // tid和需要加载的smem s_a[BM][BK] 之间的索引关系 BM=128 BK=8 按行读取 A行主序\n  // 对于s_a每行8个数据，每个线程读取4个，需要2个线程；总共128行，需要128x2刚好256线程\n  int load_smem_a_m = tid / 2; // tid/2 (128/8)*(128/8)=256 threads per block, tid/2-\u003e[0,128), BM=128 0~127\n  int load_smem_a_k = (tid % 2 == 0) ? 0 : 4;  // (tid%2 == 0) ? 0 : 4, col of s_a 0,4\n  // tid和需要加载的smem s_b[BK][BN] 之间的索引关系 BK=8 BN=128 按行读取 B行主序\n  // 对于s_b每行128个数据，每个线程读4个数据，需要32个线程；总共8行，需要32x8=256个线程\n  int load_smem_b_k = tid / 32; // tid/32, row of s_b 256/32=8 行 0~7\n  int load_smem_b_n = (tid % 32) * 4;  // (tid % 32) * 4, col of s_b 0,4,...,124\n  // 1. 再计算全局内存中的索引\n  // 要加载到s_a中的元素对应到A全局内存中的行数 每个block负责出C中大小为BM*BN的块\n  int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c\n  int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c\n  \n  float r_c[TM][TN] = {0.0}; // 8x8\n  // 2. 先对K进行分块，每块BK大小\n  for (int bk = 0; bk \u003c (K + BK - 1) / BK; ++bk) {\n    // 加载数据到共享内存smem s_a BM*BK 128*8 vectorize float4\n    int load_gmem_a_k = bk * BK + load_smem_a_k; // global col of a\n    int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;\n    FLOAT4(s_a[load_smem_a_m][load_smem_a_k]) = FLOAT4(a[load_gmem_a_addr]);\n    // 加载数据到共享内存smem s_b BK*BN 8*128 vectorize float4\n    int load_gmem_b_k = bk * BK + load_smem_b_k; // global row of b\n    int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n; \n    FLOAT4(s_b[load_smem_b_k][load_smem_b_n]) = FLOAT4(b[load_gmem_b_addr]); \n    __syncthreads();\n    #pragma unroll\n    for (int k = 0; k \u003c BK; k++) {\n      // 3. 
每个线程负责计算BM*BN(12x128)中的TM*TN(8x8)个元素\n      #pragma unroll\n      for (int m = 0; m \u003c TM; m++) {\n        #pragma unroll\n        for (int n = 0; n \u003c TN; n++) {\n          // k from 0~7，0 ~ BK, ty and tx range from 0 to 15, 16x8=128\n          int comp_smem_a_m = ty * TM + m;  // 128*8 128/TM(8)=16 M方向 16线程\n          int comp_smem_b_n = tx * TN + n;  // 8*128 128/TN(8)=16 N方向 16线程\n          r_c[m][n] += s_a[comp_smem_a_m][k] * s_b[k][comp_smem_b_n];\n        }\n      }\n    }\n    __syncthreads();\n  }\n\n  #pragma unroll\n  for (int m = 0; m \u003c TM; ++m) {\n    int store_gmem_c_m = by * BM + ty * TM + m;\n    #pragma unroll\n    for (int n = 0; n \u003c TN; n += 4) {\n      int store_gmem_c_n = bx * BN + tx * TN + n;\n      int store_gmem_c_addr = store_gmem_c_m * N + store_gmem_c_n;\n      FLOAT4(c[store_gmem_c_addr]) = FLOAT4(r_c[m][n]);\n    }\n  }\n}\n```\n这里gemm的实现比较简单，只使用了CUDA Cores，并且只实现Block Tile + K Tile以及Block Tile + K Tile+Thread Tile+向量化的版本。主要在于如何加载gmem中的数据到smem，也就是把全局内存中的数据索引mapping到共享内存中的。核心思维：把一个block中的线程id按照线性来理解，然后把这个线性的id和全局内存索引以及共享内存索引进行匹配。比如Block Tile + K Tile的实现，block内一共32x32个Threads，需要加载到smem的数据也是32x32，那么，最简单的做法，只需要每个线程加载一个互不重复数据即可。NOTE，本文的gemm kernel修改自：[紫气东来：CUDA（三）：通用矩阵乘法：从入门到熟练](https://zhuanlan.zhihu.com/p/657632577)\n\n\n## 0x03 warp/block reduce sum/max  ([©️back👆🏻](#kernellist))\n\u003cdiv id=\"warpreduce\"\u003e\u003c/div\u003e  \n\n```C++\n// Warp Reduce Sum\ntemplate\u003cconst int kWarpSize = WARP_SIZE\u003e\n__device__ __forceinline__ float warp_reduce_sum(float val) {\n  #pragma unroll\n  for (int mask = kWarpSize \u003e\u003e 1; mask \u003e= 1; mask \u003e\u003e= 1) {\n    val += __shfl_xor_sync(0xffffffff, val, mask);\n  }\n  return val;\n}\n\n// Warp Reduce Max\ntemplate\u003cconst int kWarpSize = WARP_SIZE\u003e\n__device__ __forceinline__ float warp_reduce_max(float val) {\n  #pragma unroll\n  for (int mask = kWarpSize \u003e\u003e 1; mask \u003e= 1; mask \u003e\u003e= 1) {\n    val = fmaxf(val, __shfl_xor_sync(0xffffffff, val, mask));\n  }\n  return val;\n}\n\n// Block reduce sum/max/min device helper for Layer/RMS Norm/Softmax etc.\n// grid 1D block 1D, grid(N/128), block(128)\ntemplate\u003cconst int NUM_THREADS=128\u003e\n__device__ __forceinline__ float block_reduce_sum(float val) {\n  // always \u003c= 32 warps per block (limited by 1024 threads per block)\n  constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE;\n  int warp = threadIdx.x / WARP_SIZE;\n  int lane = threadIdx.x % WARP_SIZE;\n  static __shared__ float shared[NUM_WARPS];\n  \n  val = warp_reduce_sum\u003cWARP_SIZE\u003e(val);\n  if (lane == 0) shared[warp] = val;\n  __syncthreads();\n  val = (lane \u003c NUM_WARPS) ? shared[lane] : 0.0f;\n  val = warp_reduce_sum\u003cNUM_WARPS\u003e(val);\n  return val;\n}\n\ntemplate\u003cconst int NUM_THREADS=128\u003e\n__device__ __forceinline__ float block_reduce_max(float val) {\n  // always \u003c= 32 warps per block (limited by 1024 threads per block)\n  constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE;\n  int warp = threadIdx.x / WARP_SIZE;\n  int lane = threadIdx.x % WARP_SIZE;\n  static __shared__ float shared[NUM_WARPS];\n  \n  val = warp_reduce_max\u003cWARP_SIZE\u003e(val);\n  if (lane == 0) shared[warp] = val;\n  __syncthreads();\n  val = (lane \u003c NUM_WARPS) ? 
shared[lane] : -FLT_MAX;\n  val = warp_reduce_max\u003cNUM_WARPS\u003e(val);\n  return val;\n}\n```\nwarp reduce几乎已经成为大部分reduce kernel的标准写法了，比如vLLM中，就是这种经典的写法。所以，先搞懂warp reduce（也就是搞懂各种warp functions的用法），再去写其他kernel，思路就会容易很多。需要注意的是，warp函数处理的是寄存器上的数据，也就是说，此时，没必要先加载数据到smem，再进行reduce，直接加载到寄存器即可（以前犯过这个小错误...）。Warp Functions建议参考：[jhang：CUDA编程入门之Warp-Level Primitives](https://zhuanlan.zhihu.com/p/572820783)\n\n## 0x04 block all reduce + vec4  ([©️back👆🏻](#kernellist))\n\u003cdiv id=\"blockallreduce\"\u003e\u003c/div\u003e  \n\n```c++\n// Block All Reduce Sum\n// grid(N/128), block(128)\n// a: Nx1, y=sum(a)\ntemplate\u003cconst int NUM_THREADS = 128\u003e\n__global__ void block_all_reduce_sum(float* a, float* y, int N) {\n  int tid = threadIdx.x;\n  int idx = blockIdx.x * NUM_THREADS + tid;\n  constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE;\n  __shared__ float reduce_smem[NUM_WARPS];\n  // keep the data in register is enougth for warp operaion.\n  float sum = (idx \u003c N) ? a[idx] : 0.0f;\n  int warp = tid / WARP_SIZE;\n  int lane = tid % WARP_SIZE;\n  // perform warp sync reduce.\n  sum = warp_reduce_sum\u003cWARP_SIZE\u003e(sum);\n  // warp leaders store the data to shared memory.\n  if (lane == 0) reduce_smem[warp] = sum;\n  __syncthreads(); // make sure the data is in shared memory.\n  // the first warp compute the final sum.\n  sum = (lane \u003c NUM_WARPS) ? reduce_smem[lane] : 0.0f;\n  if (warp == 0) sum = warp_reduce_sum\u003cNUM_WARPS\u003e(sum);\n  if (tid == 0) atomicAdd(y, sum);\n}\n\n// Block All Reduce Sum + float4\n// grid(N/128), block(128/4)\n// a: Nx1, y=sum(a)\ntemplate\u003cconst int NUM_THREADS = 128/4\u003e\n__global__ void block_all_reduce_sum_vec4(float* a, float* y, int N) {\n  int tid = threadIdx.x;\n  int idx = (blockIdx.x * NUM_THREADS + tid) * 4;\n  constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE;\n  __shared__ float reduce_smem[NUM_WARPS];\n\n  float4 reg_a = FLOAT4(a[idx]);\n  // keep the data in register is enougth for warp operaion.\n  float sum = (idx \u003c N) ? (reg_a.x + reg_a.y + reg_a.z + reg_a.w) : 0.0f;\n  int warp = tid / WARP_SIZE;\n  int lane = tid % WARP_SIZE;\n  // perform warp sync reduce.\n  sum = warp_reduce_sum\u003cWARP_SIZE\u003e(sum);\n  // warp leaders store the data to shared memory.\n  if (lane == 0) reduce_smem[warp] = sum;\n  __syncthreads(); // make sure the data is in shared memory.\n  // the first warp compute the final sum.\n  sum = (lane \u003c NUM_WARPS) ? 
reduce_smem[lane] : 0.0f;\n  if (warp == 0) sum = warp_reduce_sum\u003cNUM_WARPS\u003e(sum);\n  if (tid == 0) atomicAdd(y, sum);\n}\n```\nblock all reduce是在warp reduce的基础上进行的，reduce_smem这部分的共享内存申请无法避免，这是用来同步每个warp之间得到局部结果。注意，最后，还需要atomicAdd做一个block级别的原子操作，以得到全局的和。float4向量化优化访存，可以减缓WarpScheduler发送指令的压力。\n\n## 0x05 sgemv k32/k128/k16 kernel   ([©️back👆🏻](#kernellist))\n\u003cdiv id=\"sgemv\"\u003e\u003c/div\u003e  \n\n```C++\n// SGEMV: Warp SGEMV K32\n// 假设K为32的倍数，每个warp负责一行\n// grid(M/4), block(32,4) blockDim.x=32=K, blockDim.y=4\n// a: MxK, x: Kx1, y: Mx1, compute: y = a * x\n__global__ void sgemv_k32(float* a, float* x, float* y, int M, int K) {\n  int tx = threadIdx.x;         // 0~31\n  int ty = threadIdx.y;         // 0~4\n  int bx = blockIdx.x;          // 0~M/4\n  int lane = tx % WARP_SIZE;    // 0~31\n  int m = bx * blockDim.y + ty; // (0~M/4) * 4 + (0~3)\n  if (m \u003c M) {\n    float sum = 0.0f;\n    int NUM_WARPS = (K + WARP_SIZE - 1) / WARP_SIZE;\n    #pragma unroll\n    for (int w = 0; w \u003c NUM_WARPS; ++w) {\n      // 若NUM_WARPS\u003e=2，先将当前行的数据累加到第一个warp中\n      int k = w * WARP_SIZE + lane;\n      sum += a[m * K + k] * x[k];\n    }\n    sum = warp_reduce_sum\u003cWARP_SIZE\u003e(sum);\n    if (lane == 0) y[m] = sum;\n  }\n}\n\n// SGEMV: Warp SGEMV K128 + Vec4\n// 假设K为128的倍数 float4\n// grid(M/4), block(32,4) blockDim.x=32=K, blockDim.y=4\n// a: MxK, x: Kx1, y: Mx1, compute: y = a * x\n__global__ void sgemv_k128(float* a, float* x, float* y, int M, int K) {\n  // 每个线程负责4个元素，一个warp覆盖128个元素\n  int tx = threadIdx.x;         // 0~31\n  int ty = threadIdx.y;         // 0~3\n  int bx = blockIdx.x;          // 0~M/4\n  int lane = tx % WARP_SIZE;    // 0~31\n  int m = blockDim.y * bx + ty; // (0~M/4) * 4 + (0~3)\n  \n  if (m \u003c M) {\n    float sum = 0.0f;\n    // process 4*WARP_SIZE elements per warp.\n    int NUM_WARPS = (((K + WARP_SIZE - 1) / WARP_SIZE) + 4 - 1) / 4;\n    #pragma unroll\n    for (int w = 0; w \u003c NUM_WARPS; ++w) {\n      int k = (w * WARP_SIZE + lane) * 4;\n      float4 reg_x = FLOAT4(x[k]);\n      float4 reg_a = FLOAT4(a[m * K + k]);\n      sum += (reg_a.x * reg_x.x + reg_a.y * reg_x.y \n            + reg_a.z * reg_x.z + reg_a.w * reg_x.w);\n    }\n    sum = warp_reduce_sum\u003cWARP_SIZE\u003e(sum);\n    if(lane == 0) y[m] = sum;\n  }\n}\n\n// SGEMV: Warp SGEMV K16\n// 假设K为16 \u003c 32,每个warp负责2行，每行有16个元素\n// NUM_THREADS=128, NUM_WARPS=NUM_THREADS/WARP_SIZE;\n// NUM_ROWS=NUM_WARPS * ROW_PER_WARP, grid(M/NUM_ROWS), block(32,NUM_WARPS)\n// a: MxK, x: Kx1, y: Mx1, compute: y = a * x\ntemplate\u003cconst int ROW_PER_WARP = 2\u003e \n__global__ void sgemv_k16(float* A, float* x, float* y, int M, int K) {\n  constexpr int K_WARP_SIZE = (WARP_SIZE + ROW_PER_WARP - 1) / ROW_PER_WARP;\n  int tx = threadIdx.x;       // 0~31\n  int ty = threadIdx.y;       // 0~NUM_WARPS\n  int bx = blockIdx.x;        // 0~M/NUM_ROWS (NUM_ROWS=NUM_WARPS * ROW_PER_WARP)\n  int lane = tx % WARP_SIZE;  // 0~31\n  int k = lane % K_WARP_SIZE; // 0~15\n  // gloabl row of a: MxK and y:Mx1, blockDim.y=NUM_WARPS\n  int m = (blockDim.y * bx + ty) * ROW_PER_WARP + lane / K_WARP_SIZE;\n  if (m \u003c M) {\n    float sum = A[m * K + k] * x[k];\n    sum = warp_reduce_sum\u003cK_WARP_SIZE\u003e(sum);\n    // 注意是k == 0，而不是lane == 0\n    if(k == 0) y[m] = sum; \n  }\n}\n```\n估计有些大佬倒立都能写sgemv的各种优化版了，核心思路其实也是基于warp reduce，考虑K的不同情况进行优化。本文的sgemv kernel修改自：[有了琦琦的棍子：深入浅出GPU优化系列：gemv优化](https://zhuanlan.zhihu.com/p/494144694)\n\n## 0x06 dot product, dot product + vec4  
([©️back👆🏻](#kernellist))\n\u003cdiv id=\"dot\"\u003e\u003c/div\u003e  \n\n```c++\n// Dot Product\n// grid(N/128), block(128)\n// a: Nx1, b: Nx1, y=sum(elementwise_mul(a,b))\ntemplate\u003cconst int NUM_THREADS = 128\u003e\n__global__ void dot(float* a, float* b, float* y, int N) {\n  int tid = threadIdx.x;\n  int idx = blockIdx.x * NUM_THREADS + tid;\n  constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE;\n  __shared__ float reduce_smem[NUM_WARPS];\n\n  // keep the data in register is enougth for warp operaion.\n  float prod = (idx \u003c N) ? a[idx] * b[idx] : 0.0f;\n  int warp = tid / WARP_SIZE;\n  int lane = tid % WARP_SIZE;\n  // perform warp sync reduce.\n  prod = warp_reduce_sum\u003cWARP_SIZE\u003e(prod);\n  // warp leaders store the data to shared memory.\n  if (lane == 0) reduce_smem[warp] = prod;\n  __syncthreads(); // make sure the data is in shared memory.\n  // the first warp compute the final sum.\n  prod = (lane \u003c NUM_WARPS) ? reduce_smem[lane] : 0.0f;\n  if (warp == 0) prod = warp_reduce_sum\u003cNUM_WARPS\u003e(prod);\n  if (tid == 0) atomicAdd(y, prod);\n}\n\n// Dot Product + Vec4\n// grid(N/128), block(128/4)\n// a: Nx1, b: Nx1, y=sum(elementwise_mul(a,b))\ntemplate\u003cconst int NUM_THREADS = 128/4\u003e\n__global__ void dot_vec4(float* a, float* b, float* y, int N) {\n  int tid = threadIdx.x;\n  int idx = (blockIdx.x * NUM_THREADS + tid) * 4;\n  constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE;\n  __shared__ float reduce_smem[NUM_WARPS];\n\n  float4 reg_a = FLOAT4(a[idx]);\n  float4 reg_b = FLOAT4(b[idx]);\n  float prod = (idx \u003c N) ? (reg_a.x * reg_b.x + reg_a.y * reg_b.y \n                          + reg_a.z * reg_b.z + reg_a.w * reg_b.w) : 0.0f;\n  int warp = tid / WARP_SIZE;\n  int lane = tid % WARP_SIZE;\n  // perform warp sync reduce.\n  prod = warp_reduce_sum\u003cWARP_SIZE\u003e(prod);\n  // warp leaders store the data to shared memory.\n  if (lane == 0) reduce_smem[warp] = prod;\n  __syncthreads(); // make sure the data is in shared memory.\n  // the first warp compute the final sum.\n  prod = (lane \u003c NUM_WARPS) ? 
reduce_smem[lane] : 0.0f;\n  if (warp == 0) prod = warp_reduce_sum\u003cNUM_WARPS\u003e(prod);\n  if (tid == 0) atomicAdd(y, prod);\n}\n```\nThe core of the dot product kernel is just block reduce, so there is not much more to say.\n\n## 0x07 elementwise, elementwise + vec4  ([©️back👆🏻](#kernellist))\n\u003cdiv id=\"elementwise\"\u003e\u003c/div\u003e  \n\n```c++\n// ElementWise Add  \n// grid(N/128), block(128)\n// a: Nx1, b: Nx1, c: Nx1, c = elementwise_add(a, b)\n__global__ void elementwise_add(float* a, float* b, float* c, int N) {\n  int idx = blockIdx.x * blockDim.x + threadIdx.x;\n  if (idx \u003c N) c[idx] = a[idx] + b[idx];\n}\n\n// ElementWise Add + Vec4\n// grid(N/128), block(128/4)\n// a: Nx1, b: Nx1, c: Nx1, c = elementwise_add(a, b)\n__global__ void elementwise_add_vec4(float* a, float* b, float* c, int N) {\n  int idx = 4 * (blockIdx.x * blockDim.x + threadIdx.x);\n  if (idx \u003c N) {\n    float4 reg_a = FLOAT4(a[idx]);\n    float4 reg_b = FLOAT4(b[idx]);\n    float4 reg_c;\n    reg_c.x = reg_a.x + reg_b.x;\n    reg_c.y = reg_a.y + reg_b.y;\n    reg_c.z = reg_a.z + reg_b.z;\n    reg_c.w = reg_a.w + reg_b.w;\n    FLOAT4(c[idx]) = reg_c;\n  }\n}\n```\nFor elementwise kernels, adding some float4 vectorization is an easy memory-access optimization.\n\n## 0x08 histogram, histogram + vec4  ([©️back👆🏻](#kernellist))\n\u003cdiv id=\"histogram\"\u003e\u003c/div\u003e  \n\n```c++\n// Histogram\n// grid(N/128), block(128)\n// a: Nx1, y: count histogram\n__global__ void histogram(int* a, int* y, int N) {\n  int idx = blockIdx.x * blockDim.x + threadIdx.x;\n  if (idx \u003c N) atomicAdd(\u0026(y[a[idx]]), 1);\n}\n\n// Histogram + Vec4\n// grid(N/128), block(128/4)\n// a: Nx1, y: count histogram\n__global__ void histogram_vec4(int* a, int* y, int N) {\n  int idx = 4 * (blockIdx.x * blockDim.x + threadIdx.x);\n  if (idx \u003c N) {\n    int4 reg_a = INT4(a[idx]);\n    atomicAdd(\u0026(y[reg_a.x]), 1);\n    atomicAdd(\u0026(y[reg_a.y]), 1);\n    atomicAdd(\u0026(y[reg_a.z]), 1);\n    atomicAdd(\u0026(y[reg_a.w]), 1);\n  }\n}\n```\nA frequency histogram is simple: a couple of lines of atomicAdd and it is done.\n\n## 0x09 softmax, softmax + vec4 (grid level memory fence)   ([©️back👆🏻](#kernellist))\n\u003cdiv id=\"softmax\"\u003e\u003c/div\u003e  \n\n```c++\n// Softmax x: N, y: N\n// grid(N/128), block(K=128)\ntemplate\u003cconst int NUM_THREADS = 128\u003e\n__global__ void softmax(float* x, float* y, float* total, int N) {\n  const int tid = threadIdx.x;\n  const int idx = blockIdx.x * blockDim.x + tid; \n  constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE;\n  __shared__ float reduce_smem[NUM_WARPS];\n  \n  // keep the per-thread e^x_i in a register; the reduce below overwrites sum.\n  float exp_val = (idx \u003c N) ? expf(x[idx]) : 0.0f;\n  float sum = exp_val;\n  int warp = tid / WARP_SIZE;\n  int lane = tid % WARP_SIZE;\n  sum = warp_reduce_sum\u003cWARP_SIZE\u003e(sum);\n  if (lane == 0) reduce_smem[warp] = sum;\n  __syncthreads();\n  // compute the final block sum from the per-warp partials\n  sum = (lane \u003c NUM_WARPS) ? reduce_smem[lane] : 0.0f;\n  sum = warp_reduce_sum\u003cNUM_WARPS\u003e(sum); // sum(e^x_0,...,e^x_n-1)\n  // get the total sum of all blocks.\n  if (tid == 0) atomicAdd(total, sum);\n  __threadfence(); // grid level memory fence: a grid-wide memory sync is needed here\n  // e^x_i/sum(e^x_0,...,e^x_n-1) \n  if (idx \u003c N) y[idx] = exp_val / (*total); \n}\n\n// Softmax x: N, y: N\n// grid(N/128), block(K=128)\ntemplate\u003cconst int NUM_THREADS = 128\u003e\n__global__ void softmax_v2(float* x, float* y, float* total, int N) {\n  const int tid = threadIdx.x;\n  const int idx = blockIdx.x * blockDim.x + tid; \n  \n  float exp_val = (idx \u003c N) ? 
expf(x[idx]) : 0.0f;\n  float sum = block_reduce_sum\u003cNUM_THREADS\u003e(exp_val);\n  // get the total sum of all blocks.\n  if (tid == 0) atomicAdd(total, sum);\n  __threadfence(); // grid level memory fence: a grid-wide memory sync is needed here\n  // e^x_i/sum(e^x_0,...,e^x_n-1) \n  if (idx \u003c N) y[idx] = exp_val / (*total); \n}\n\n// Softmax Vec4 x: N, y: N\n// grid(N/128), block(128/4)\ntemplate\u003cconst int NUM_THREADS = 128/4\u003e\n__global__ void softmax_v2_vec4(float* x, float* y, float* total, int N) {\n  const int tid = threadIdx.x;\n  const int idx = (blockIdx.x * blockDim.x + tid) * 4; \n  \n  float4 reg_x = FLOAT4(x[idx]);\n  float4 reg_exp;\n  reg_exp.x = (idx \u003c N) ? expf(reg_x.x) : 0.0f;\n  reg_exp.y = (idx \u003c N) ? expf(reg_x.y) : 0.0f;\n  reg_exp.z = (idx \u003c N) ? expf(reg_x.z) : 0.0f;\n  reg_exp.w = (idx \u003c N) ? expf(reg_x.w) : 0.0f;\n  float exp_val = (reg_exp.x + reg_exp.y + reg_exp.z + reg_exp.w);\n  float sum = block_reduce_sum\u003cNUM_THREADS\u003e(exp_val);\n  // get the total sum of all blocks.\n  if (tid == 0) atomicAdd(total, sum);\n  __threadfence(); // grid level memory fence: a grid-wide memory sync is needed here\n  // e^x_i/sum(e^x_0,...,e^x_n-1) \n  if (idx \u003c N) {\n    float4 reg_y;\n    reg_y.x = reg_exp.x / (*total);\n    reg_y.y = reg_exp.y / (*total);\n    reg_y.z = reg_exp.z / (*total);\n    reg_y.w = reg_exp.w / (*total);\n    FLOAT4(y[idx]) = reg_y; \n  }\n}\n```\nThe thing to watch out for in softmax is memory synchronization: you need a grid-level sync here, not just a block-level one, otherwise you cannot get the global exp sum to use as the denominator. That is why __threadfence, a grid-level memory fence, is used. I have not measured how efficient this is; if you really need it to be fast, you probably have to restructure it as a 1-pass + online-softmax implementation like FA2. For an interview, though, don't make things too hard on yourself... but the FA1/FA2 papers are classics and well worth reading several times.\n\n## 0x0a sigmoid, sigmoid + vec4   ([©️back👆🏻](#kernellist))\n\u003cdiv id=\"sigmoid\"\u003e\u003c/div\u003e  \n\n```c++\n// Sigmoid x: N, y: N y=1/(1+exp(-x))\n// grid(N/128), block(K=128) \n__global__ void sigmoid(float* x, float* y, int N) {\n  int idx = blockIdx.x * blockDim.x + threadIdx.x;\n  if (idx \u003c N) y[idx] = 1.0f / (1.0f + expf(-x[idx]));\n}\n\n// Sigmoid x: N, y: N y=1/(1+exp(-x)) Vec4\n// grid(N/128), block(128/4)\n__global__ void sigmoid_vec4(float* x, float* y, int N) {\n  int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4;\n  if (idx \u003c N) {\n    float4 reg_x = FLOAT4(x[idx]);\n    float4 reg_y;\n    reg_y.x = 1.0f / (1.0f + expf(-reg_x.x));\n    reg_y.y = 1.0f / (1.0f + expf(-reg_x.y));\n    reg_y.z = 1.0f / (1.0f + expf(-reg_x.z));\n    reg_y.w = 1.0f / (1.0f + expf(-reg_x.w));\n    FLOAT4(y[idx]) = reg_y;\n  }\n}\n```\n\n## 0x0b relu, relu + vec4   ([©️back👆🏻](#kernellist))\n\u003cdiv id=\"relu\"\u003e\u003c/div\u003e  \n\n```c++\n// Relu x: N, y: N y=max(0,x)\n// grid(N/128), block(K=128) \n__global__ void relu(float* x, float* y, int N) {\n  int idx = blockIdx.x * blockDim.x + threadIdx.x;\n  if (idx \u003c N) y[idx] = fmaxf(0.0f, x[idx]);\n}\n\n// Relu x: N, y: N y=max(0,x) Vec4\n// grid(N/128), block(128/4) \n__global__ void relu_vec4(float* x, float* y, int N) {\n  int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4;\n  if (idx \u003c N) {\n    float4 reg_x = FLOAT4(x[idx]);\n    float4 reg_y;\n    reg_y.x = fmaxf(0.0f, reg_x.x);\n    reg_y.y = fmaxf(0.0f, reg_x.y);\n    reg_y.z = fmaxf(0.0f, reg_x.z);\n    reg_y.w = fmaxf(0.0f, reg_x.w);\n    FLOAT4(y[idx]) = reg_y;\n  }\n}\n```\n\n## 0x0c layer_norm, layer_norm + vec4   ([©️back👆🏻](#kernellist))\n\u003cdiv id=\"layernorm\"\u003e\u003c/div\u003e  \n\n```c++\n// Layer Norm: x: NxK(K=128\u003c1024), y': NxK, y'=(x-mean(x))/std(x) each row\n// mean(x) = sum(x)/K, 1/std(x) = rsqrtf( sum( (x-mean(x))^2 )/K ) 
each row\n// grid(N*K/K), block(K\u003c1024) N=batch_size*seq_len, K=hidden_size\n// y=y'*g + b (g: scale, b: bias)\ntemplate\u003cconst int NUM_THREADS=128\u003e\n__global__ void layer_norm(float* x, float* y, float g, float b, int N, int K) {\n  int tid = threadIdx.x; // 0..K-1\n  int bid = blockIdx.x; // 0..N-1\n  int idx = bid * blockDim.x + threadIdx.x;\n  const float epsilon = 1e-5f;\n\n  __shared__ float s_mean; // shared within block\n  __shared__ float s_variance; // shared within block\n  float value = (idx \u003c N * K) ? x[idx] : 0.0f; // load once only\n  float sum = block_reduce_sum\u003cNUM_THREADS\u003e(value);\n  if (tid == 0) s_mean = sum / (float) K;\n  // wait for s_mean in shared memory to be ready for all threads\n  __syncthreads();\n  float variance = (value - s_mean) * (value - s_mean);\n  variance = block_reduce_sum\u003cNUM_THREADS\u003e(variance);\n  if (tid == 0) s_variance = rsqrtf(variance / (float) K + epsilon);\n  // wait for s_variance in shared memory to be ready for all threads\n  __syncthreads();\n  if (idx \u003c N * K) y[idx] = ((value - s_mean) * s_variance) * g + b;\n}\n\n// Layer Norm Vec4: x: NxK(K=128\u003c1024), y': NxK, y'=(x-mean(x))/std(x) each row\n// mean(x) = sum(x)/K, 1/std(x) = rsqrtf( sum( (x-mean(x))^2 )/K ) each row\n// grid(N*K/K), block(K/4\u003c1024) N=batch_size*seq_len, K=hidden_size\n// y=y'*g + b (g: scale, b: bias)\ntemplate\u003cconst int NUM_THREADS=128/4\u003e\n__global__ void layer_norm_vec4(float* x, float* y, float g, float b, int N, int K) {\n  int tid = threadIdx.x; // 0..(K/4)-1\n  int bid = blockIdx.x; // 0..N-1\n  int idx = (bid * blockDim.x + threadIdx.x) * 4;\n  const float epsilon = 1e-5f;\n\n  __shared__ float s_mean; // shared within block\n  __shared__ float s_variance; // shared within block\n  float4 reg_x = FLOAT4(x[idx]);\n  // value holds the sum of the 4 elements owned by this thread, so the block\n  // reduce below accumulates all K values of the row for the mean.\n  float value = (idx \u003c N * K) ? 
(reg_x.x + reg_x.y \n                               + reg_x.z + reg_x.w) : 0.0f;\n  float sum = block_reduce_sum\u003cNUM_THREADS\u003e(value);\n  if (tid == 0) s_mean = sum / (float) K;\n  // wait for s_mean in shared memory to be ready for all threads\n  __syncthreads();\n  float4 reg_x_hat;\n  reg_x_hat.x = reg_x.x - s_mean;\n  reg_x_hat.y = reg_x.y - s_mean;\n  reg_x_hat.z = reg_x.z - s_mean;\n  reg_x_hat.w = reg_x.w - s_mean;\n  float variance = reg_x_hat.x * reg_x_hat.x + reg_x_hat.y * reg_x_hat.y \n                 + reg_x_hat.z * reg_x_hat.z + reg_x_hat.w * reg_x_hat.w;\n  variance = block_reduce_sum\u003cNUM_THREADS\u003e(variance);\n  if (tid == 0) s_variance = rsqrtf(variance / (float) K + epsilon);\n  // wait for s_variance in shared memory to be ready for all threads\n  __syncthreads();\n  float4 reg_y;\n  reg_y.x = reg_x_hat.x * s_variance * g + b;\n  reg_y.y = reg_x_hat.y * s_variance * g + b;\n  reg_y.z = reg_x_hat.z * s_variance * g + b;\n  reg_y.w = reg_x_hat.w * s_variance * g + b;\n  if (idx \u003c N * K) FLOAT4(y[idx]) = reg_y;\n}\n```\nlayer norm实现的核心同样也是block reduce和warp reduce，然后再整点向量化...\n\n## 0x0d rms_norm, rms_norm + vec4   ([©️back👆🏻](#kernellist))\n\u003cdiv id=\"rmsnorm\"\u003e\u003c/div\u003e  \n\n```c++\n// RMS Norm: x: NxK(K=128\u003c1024), y': NxK, y'=x/rms(x) each row\n// 1/rms(x) = rsqrtf( sum(x^2)/K ) each row\n// grid(N*K/K), block(K\u003c1024) N=batch_size*seq_len, K=hidden_size\n// y=y'*g (g: scale)\ntemplate\u003cconst int NUM_THREADS=128\u003e\n__global__ void rms_norm(float* x, float* y, float g, int N, int K) {\n  int tid = threadIdx.x; // 0..K-1\n  int bid = blockIdx.x; // 0..N-1\n  int idx = bid * blockDim.x + threadIdx.x;\n  const float epsilon = 1e-5f;\n\n  __shared__ float s_variance; // shared within block\n  float value = (idx \u003c N * K) ? x[idx] : 0.0f; // load once only\n  float variance = value * value;\n  variance = block_reduce_sum\u003cNUM_THREADS\u003e(variance);\n  if (tid == 0) s_variance = rsqrtf(variance / (float) K + epsilon);\n  // wait for s_variance in shared memory to be ready for all threads\n  __syncthreads(); \n  if (idx \u003c N * K) y[idx] = (value * s_variance) * g;\n}\n\n// RMS Norm Vec4: x: NxK(K=128\u003c1024), y': NxK, y'=x/rms(x) each row\n// 1/rms(x) = rsqrtf( sum(x^2)/K ) each row\n// grid(N*K/K), block(K/4\u003c1024) N=batch_size*seq_len, K=hidden_size\n// y=y'*g (g: scale)\ntemplate\u003cconst int NUM_THREADS=128/4\u003e\n__global__ void rms_norm_vec4(float* x, float* y, float g, int N, int K) {\n  int tid = threadIdx.x; // 0..K-1\n  int bid = blockIdx.x; // 0..N-1\n  int idx = (bid * blockDim.x + threadIdx.x) * 4;\n  const float epsilon = 1e-5f;\n\n  __shared__ float s_variance; // shared within block\n  float4 reg_x = FLOAT4(x[idx]);\n  float variance = (idx \u003c N * K) ? 
(reg_x.x * reg_x.x + reg_x.y * reg_x.y \n                                  + reg_x.z * reg_x.z + reg_x.w * reg_x.w) : 0.0f;\n  variance = block_reduce_sum\u003cNUM_THREADS\u003e(variance);\n  if (tid == 0) s_variance = rsqrtf(variance / (float) K + epsilon);\n  // wait for s_variance in shared memory to be ready for all threads\n  __syncthreads(); \n  float4 reg_y;\n  reg_y.x = reg_x.x * s_variance * g;\n  reg_y.y = reg_x.y * s_variance * g;\n  reg_y.z = reg_x.z * s_variance * g;\n  reg_y.w = reg_x.w * s_variance * g;\n  if (idx \u003c N * K) FLOAT4(y[idx]) = reg_y;\n}\n```\nThe core of the rms norm implementation is, once again, block reduce and warp reduce... plus a bit of float4 vectorization.\n\n## 0x0e NMS  ([©️back👆🏻](#kernellist))\n\u003cdiv id=\"NMS\"\u003e\u003c/div\u003e  \n\n```c++\n#include \u003calgorithm\u003e\n#include \u003ccmath\u003e\n#include \u003cvector\u003e\n\nstruct Box {\n  float x1, y1, x2, y2, score;\n  float area() const { return std::abs(x2 - x1 + 1.0f) * std::abs(y2 - y1 + 1.0f); }\n  float iou_of(const Box\u0026 other) const {\n    float inner_x1 = x1 \u003e other.x1 ? x1 : other.x1;\n    float inner_y1 = y1 \u003e other.y1 ? y1 : other.y1;\n    float inner_x2 = x2 \u003c other.x2 ? x2 : other.x2;\n    float inner_y2 = y2 \u003c other.y2 ? y2 : other.y2;\n    // clamp to 0 so that non-overlapping boxes get an IoU of 0\n    float inner_h = std::max(inner_y2 - inner_y1 + 1.0f, 0.0f);\n    float inner_w = std::max(inner_x2 - inner_x1 + 1.0f, 0.0f);\n    float inner_area = inner_h * inner_w;\n    return inner_area / (area() + other.area() - inner_area);\n  }\n};\n\nvoid hard_nms(std::vector\u003cBox\u003e \u0026input, std::vector\u003cBox\u003e \u0026output, float iou_threshold) {\n  if (input.empty()) return;\n  // sort by score, descending; keep the highest-scoring box of each overlapping cluster\n  std::sort(input.begin(), input.end(), [](const Box\u0026 a, const Box\u0026 b) { return a.score \u003e b.score; });\n  int box_num = input.size();\n  std::vector\u003cint\u003e merged(box_num, 0);\n  for (int i = 0; i \u003c box_num; ++i) {\n    if (merged[i]) continue;\n    merged[i] = 1;\n    for (int j = i + 1; j \u003c box_num; ++j) {\n      if (merged[j]) continue;\n      float iou = input[i].iou_of(input[j]);\n      if (iou \u003e iou_threshold) merged[j] = 1;\n    }\n    output.push_back(input[i]);\n  }\n}\n```\nCV-related interviews often ask you to hand-write NMS, so it is recorded here as well.\n\n## 0x0f Summary  ([©️back👆🏻](#kernellist))\nAs you can see, most of these kernels boil down to warp reduce and block reduce; once you are fluent with the warp functions and the scenarios they are used in, the rest should not be a problem. softmax additionally has to deal with grid-level synchronization, or be rewritten as online softmax / FlashAttention. Optimizing sgemm is a big topic in its own right and is nowhere near as simple as the example here, but for getting started the essentials are the tiling idea and how to map indices between global and shared memory. Optimizing sgemv mainly means handling different values of K (since M is 1), e.g. K=16, 64, 128, and deciding how each warp covers a row. relu, sigmoid and the like are elementwise operations that are easy to implement, optionally with vectorized memory access. layer norm and rms norm are mathematically quite simple; in the CUDA kernel you just process one token at a time, and as long as the hidden size does not exceed 1024 (a block can hold at most 1024 threads) each row fits in a single block, which makes the parallelization easy to write. The core, again, is warp reduce and block reduce. NMS is the odd one out and has no CUDA version here, please don't ask...\n\n## ©️License\nGNU General Public License v3.0\n\n## References  \n- [flash-attention-minimal](https://github.com/tspeterkim/flash-attention-minimal): Flash Attention in ~100 lines of CUDA (forward pass only)\n\n## 🎉Contribute\n🌟If you find this useful, feel free to give it a 🌟👆🏻Star to show your support~\n\n\u003cdiv align='center'\u003e\n\u003ca href=\"https://star-history.com/#DefTruth/Awesome-LLM-Inference\u0026Date\"\u003e\n  \u003cpicture align='center'\u003e\n    \u003csource media=\"(prefers-color-scheme: dark)\" srcset=\"https://api.star-history.com/svg?repos=DefTruth/cuda-learn-note\u0026type=Date\u0026theme=dark\" /\u003e\n    \u003csource media=\"(prefers-color-scheme: light)\" srcset=\"https://api.star-history.com/svg?repos=DefTruth/cuda-learn-note\u0026type=Date\" /\u003e\n    \u003cimg width=450 height=300 alt=\"Star History Chart\" src=\"https://api.star-history.com/svg?repos=DefTruth/cuda-learn-note\u0026type=Date\" /\u003e\n  \u003c/picture\u003e\n\u003c/a\u003e  
\n\u003c/div\u003e\n\n","funding_links":[],"categories":["Cuda","Learning Resources 📚","Learning Resources"],"sub_categories":["Blogs 🖋️"],"project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2FDefTruth%2FCUDA-Learn-Notes","html_url":"https://awesome.ecosyste.ms/projects/github.com%2FDefTruth%2FCUDA-Learn-Notes","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2FDefTruth%2FCUDA-Learn-Notes/lists"}