Skip to content

Commit 9fed09f

Browse files
author
Hope Woods
committed
Include safe NVTX through the rest of basis.py
1 parent 167ec9a commit 9fed09f

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

  • rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model

rf_diffusion/RF2-allatom/rf2aa/SE3Transformer/se3_transformer/model/basis.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,13 +178,13 @@ def get_basis(relative_pos: Tensor,
178178
compute_gradients: bool = False,
179179
use_pad_trick: bool = False,
180180
amp: bool = False) -> Dict[str, Tensor]:
181-
with nvtx_range('spherical harmonics'):
181+
with safe_nvtx_range('spherical harmonics'):
182182
spherical_harmonics = get_spherical_harmonics(relative_pos, max_degree)
183-
with nvtx_range('CB coefficients'):
183+
with safe_nvtx_range('CB coefficients'):
184184
clebsch_gordon = get_all_clebsch_gordon(max_degree, relative_pos.device)
185185

186186
with torch.autograd.set_grad_enabled(compute_gradients):
187-
with nvtx_range('bases'):
187+
with safe_nvtx_range('bases'):
188188
basis = get_basis_script(max_degree=max_degree,
189189
use_pad_trick=use_pad_trick,
190190
spherical_harmonics=spherical_harmonics,

0 commit comments

Comments
 (0)