Skip to content
Snippets Groups Projects
Commit baea7549 authored by Carlos Muniz's avatar Carlos Muniz
Browse files

STAN: Fix problems with TranAD model

parent 2477b0b4
No related branches found
No related tags found
No related merge requests found
......@@ -134,6 +134,28 @@ chmod +x ./scripts/install_mdi.sh
cmake -D CMAKE_CXX_FLAGS=-I$EIGEN_HOME ..
```
- **`RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!`**
- Solution 1: Edit `is-it-worth-it-benchmark/lib/tranad/main.py`, line 308:
```python
# Add these two lines of code
device = next(model.parameters()).device
window, elem = window.to(device), elem.to(device)
```
- Solution 2: Edit `/home/cmcuza/btw2025/is-it-worth-it-benchmark/lib/tranad/src/dlutils.py`, line 253:
```python
x = x + self.pe.to(x.device)[pos : pos + x.size(0), :] # move self.pe to x.device
```
- Solution 3: Edit `is-it-worth-it-benchmark/lib/tranad/src/dlutils.py`, line 271:
```python
src = src.to(next(self.parameters()).device) # move src to attn layer device
```
- **`TypeError: forward() got an unexpected keyword argument 'is_causal'`**
- Solution: Edit `is-it-worth-it-benchmark/lib/tranad/src/dlutils.py`, line 270:
```python
def forward(self, src, src_mask=None, src_key_padding_mask=None, **kwargs): # add **kwargs
```
## License
This repository is licensed under the [Apache 2.0](LICENSE) license.
name: btw2025
channels:
- pytorch
- nvidia
- conda-forge
- defaults
dependencies:
......@@ -60,6 +62,13 @@ dependencies:
- contourpy=1.3.0=py39h74842e3_2
- cpython=3.9.21=py39hd8ed1ab_1
- cryptography=44.0.1=py39h7170ec2_0
- cuda-cudart=12.1.105=0
- cuda-cupti=12.1.105=0
- cuda-libraries=12.1.0=0
- cuda-nvrtc=12.1.105=0
- cuda-nvtx=12.1.105=0
- cuda-opencl=12.4.127=0
- cuda-runtime=12.1.0=0
- cuda-version=11.8=h70ddcb2_3
- cudatoolkit=11.8.0=h4ba93d1_13
- cudnn=8.9.7.29=hbc23b4c_3
......@@ -74,6 +83,7 @@ dependencies:
- entrypoints=0.4=pyhd8ed1ab_1
- exceptiongroup=1.2.2=pyhd8ed1ab_1
- executing=2.1.0=pyhd8ed1ab_1
- ffmpeg=4.3=hf484d3e_0
- filelock=3.17.0=pyhd8ed1ab_0
- flask=3.1.0=pyhff2d567_0
- flatbuffers=24.3.25=h59595ed_0
......@@ -90,6 +100,7 @@ dependencies:
- glog=0.7.1=hbabe93e_0
- gmp=6.3.0=hac33072_2
- gmpy2=2.1.5=py39h7196dd7_3
- gnutls=3.6.15=he1e5248_0
- google-auth=2.38.0=pyhd8ed1ab_0
- google-pasta=0.2.0=pyhd8ed1ab_2
- graphene=3.4.3=pyhd8ed1ab_1
......@@ -139,6 +150,7 @@ dependencies:
- keyutils=1.6.1=h166bdaf_0
- kiwisolver=1.4.7=py39h74842e3_0
- krb5=1.21.3=h659f571_0
- lame=3.100=h7b6447c_0
- lcms2=2.16=hb7c19ff_0
- ld_impl_linux-64=2.40=h12ee557_0
- lerc=4.0.0=h27087fc_0
......@@ -154,7 +166,13 @@ dependencies:
- libbrotlienc=1.1.0=hb9d3cd8_2
- libcblas=3.9.0=28_he106b2a_openblas
- libcrc32c=1.1.2=h9c3ff4c_0
- libcublas=12.1.0.26=0
- libcufft=11.0.2.4=0
- libcufile=1.9.1.3=0
- libcurand=10.3.5.147=0
- libcurl=8.8.0=hca28451_1
- libcusolver=11.4.4.55=0
- libcusparse=12.0.2.55=0
- libdeflate=1.20=hd590300_0
- libedit=3.1.20191231=he28a2e2_2
- libev=4.33=hd590300_2
......@@ -170,12 +188,17 @@ dependencies:
- libgoogle-cloud-storage=2.24.0=h3d9a0c8_0
- libgrpc=1.62.2=h15f2491_0
- libhwloc=2.11.2=default_h0d58e46_1001
- libiconv=1.16=h5eee18b_3
- libidn2=2.3.4=h5eee18b_0
- libjpeg-turbo=3.0.0=hd590300_1
- liblapack=3.9.0=28_h7ac8fdf_openblas
- liblapacke=3.9.0=28_he2f377e_openblas
- libmagma=2.7.2=h09b5827_2
- libmagma_sparse=2.7.2=h09b5827_3
- libnghttp2=1.58.0=h47da74e_1
- libnpp=12.0.2.50=0
- libnvjitlink=12.1.105=0
- libnvjpeg=12.1.1.14=0
- libopenblas=0.3.28=pthreads_h94d23a6_1
- libparquet=16.1.0=h6a7eafb_6_cpu
- libpng=1.6.43=h2797004_0
......@@ -186,9 +209,11 @@ dependencies:
- libssh2=1.11.0=h0841786_0
- libstdcxx=14.2.0=hc0a3c3a_1
- libstdcxx-ng=14.2.0=h4852527_1
- libtasn1=4.19.0=h5eee18b_0
- libthrift=0.19.0=hb90f79a_1
- libtiff=4.6.0=h1dd3fc0_3
- libtorch=2.3.1=cuda118_h7aef8b2_300
- libunistring=0.9.10=h27cfd23_0
- liburing=2.7=h434a139_0
- libutf8proc=2.8.0=hf23e847_1
- libuv=1.50.0=hb9d3cd8_0
......@@ -225,10 +250,12 @@ dependencies:
- nccl=2.25.1.1=h03a54cd_0
- ncurses=6.4=h6a678d5_0
- nest-asyncio=1.6.0=pyhd8ed1ab_1
- nettle=3.7.3=hbbd107a_1
- networkx=3.2.1=pyhd8ed1ab_0
- notebook-shim=0.2.4=pyhd8ed1ab_1
- numpy=1.26.4=py39h474f0d3_0
- omegaconf=2.3.0=pyhd8ed1ab_0
- openh264=2.1.1=h4ff587b_0
- openjpeg=2.5.2=h488ebb8_0
- openssl=3.4.1=h7b32b05_0
- opentelemetry-api=1.16.0=pyhd8ed1ab_0
......@@ -279,6 +306,9 @@ dependencies:
- python-json-logger=2.0.7=pyhd8ed1ab_0
- python-tzdata=2025.1=pyhd8ed1ab_0
- python_abi=3.9=2_cp39
- pytorch=2.3.1=cuda118_py39hd3e083d_300
- pytorch-cuda=12.1=ha16c6d3_6
- pytorch-mutex=1.0=cuda
- pytz=2024.1=pyhd8ed1ab_0
- pyu2f=0.1.5=pyhd8ed1ab_1
- pywin32-on-windows=0.1.0=pyh1179c8e_3
......@@ -327,7 +357,9 @@ dependencies:
- tinycss2=1.4.0=pyhd8ed1ab_0
- tk=8.6.14=h39e8969_0
- tomli=2.2.1=pyhd8ed1ab_1
- torchaudio=2.3.1=py39_cu121
- torchdata=0.7.1=py39hc552c7e_6
- torchvision=0.18.1=py39_cu121
- tornado=6.4.2=py39h8cd3c5a_0
- tqdm=4.67.1=pyhd8ed1ab_1
- traitlets=5.14.3=pyhd8ed1ab_1
......@@ -362,6 +394,7 @@ dependencies:
- cmaes==0.11.1
- cmd2==2.5.11
- hydra-optuna-sweeper==1.2.0
- merlin==1.0.2
- nvidia-cublas-cu12==12.4.5.8
- nvidia-cuda-cupti-cu12==12.4.127
- nvidia-cuda-nvrtc-cu12==12.4.127
......@@ -382,8 +415,5 @@ dependencies:
- rrcf==0.4.4
- stevedore==5.4.0
- sympy==1.13.1
- torch==2.6.0
- torchaudio==2.6.0
- torchvision==0.21.0
- triton==3.2.0
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment