https://github.com/lartpang/rotarypositionembedding
https://github.com/lartpang/rotarypositionembedding
position-embedding rope
Last synced: about 1 month ago
JSON representation
- Host: GitHub
- URL: https://github.com/lartpang/rotarypositionembedding
- Owner: lartpang
- License: mit
- Created: 2024-03-27T06:44:22.000Z (about 2 years ago)
- Default Branch: main
- Last Pushed: 2024-03-28T04:32:24.000Z (about 2 years ago)
- Last Synced: 2025-03-13T07:12:56.896Z (about 1 year ago)
- Topics: position-embedding, rope
- Language: Jupyter Notebook
- Homepage:
- Size: 59.6 KB
- Stars: 1
- Watchers: 1
- Forks: 0
- Open Issues: 0
-
Metadata Files:
- Readme: README.md
- License: LICENSE
Awesome Lists containing this project
README
# RotaryPositionEmbedding
## RoPE1D
```python
torch.manual_seed(1024)
torch.cuda.manual_seed(1024)
rope_1d = RoPE1D(dim=512).cuda()
x = torch.randn(3, 64 * 64, 512, dtype=torch.float32, device="cuda")
start = time.perf_counter()
for _ in range(100):
x1 = rope_1d.forward_v1(x)
torch.cuda.synchronize()
print(f"forward_v1 ({(time.perf_counter() - start) / 100}s): output mean: {x1.mean()}")
start = time.perf_counter()
for _ in range(100):
x2 = rope_1d.forward_v2(x)
torch.cuda.synchronize()
print(f"forward_v2 ({(time.perf_counter() - start) / 100}s): output mean: {x2.mean()}")
start = time.perf_counter()
for _ in range(100):
x3 = rope_1d.forward_v3(x)
torch.cuda.synchronize()
print(f"forward_v3 ({(time.perf_counter() - start) / 100}s): output mean: {x3.mean()}")
start = time.perf_counter()
for _ in range(100):
x4 = rope_1d.forward_v4(x)
torch.cuda.synchronize()
print(f"forward_v4 ({(time.perf_counter() - start) / 100}s): output mean: {x4.mean()}")
start = time.perf_counter()
for _ in range(100):
x_llama = rope_1d.forward_llama(x)
torch.cuda.synchronize()
print(f"forward_llama ({(time.perf_counter() - start) / 100}s): output mean: {x_llama.mean()}")
start = time.perf_counter()
for _ in range(100):
x_palm = rope_1d.forward_palm(x)
torch.cuda.synchronize()
print(f"forward_palm ({(time.perf_counter() - start) / 100}s): output mean: {x_palm.mean()}")
'''
forward_v1 (0.0007928510010242462s): output mean: 0.00037511205300688744
forward_v2 (0.0005496779992245137s): output mean: 0.00037511205300688744
forward_v3 (0.0005035390006378293s): output mean: 0.00037511205300688744
forward_v4 (0.0019484849995933472s): output mean: 0.0003751121403183788
forward_llama (0.0001582980016246438s): output mean: 0.00037511205300688744
forward_palm (0.004226168000604958s): output mean: 0.00037511205300688744
'''
```
## RoPE2D
```python
rope_2d = RoPE2D(dim=512).cuda()
x = torch.randn(3, 64, 64, 512, dtype=torch.float32, device="cuda")
start = time.perf_counter()
for _ in range(100):
x1 = rope_2d.forward_v1(x)
torch.cuda.synchronize()
print(f"forward_v1 ({(time.perf_counter() - start) / 100}s): output mean: {x1.mean()}")
start = time.perf_counter()
for _ in range(100):
x2 = rope_2d.forward_v2(x)
torch.cuda.synchronize()
print(f"forward_v2 ({(time.perf_counter() - start) / 200}s): output mean: {x2.mean()}")
"""
forward_v1 (0.00029631400015205144s): output mean: 0.00021492868836503476
forward_v2 (0.00013582399929873646s): output mean: 0.00021492868836503476
"""
```
## Reference
- https://blog.csdn.net/PennyYu123/article/details/131717323