Source code for habitat.profiling.backward
import torch

# `get_accumulate_grad_inputs` is used in `new_from` below but is not imported
# in this listing; the module path here is an assumption about the package
# layout.
from habitat.profiling.autograd import get_accumulate_grad_inputs


class BackwardHelper:
    def __init__(self, backward_runnable, ag_dict):
        # Callable that runs one full backward pass for the profiled output.
        self.run_backward = backward_runnable
        # Maps each AccumulateGrad function to a pre-generated gradient tensor.
        self._ag_dict = ag_dict

    @classmethod
    def new_from(cls, operation_outputs):
        retval, initial_grad_fn = get_grad_fn(operation_outputs)
        if initial_grad_fn is None:
            raise ValueError('No grad_fn available on the operation output.')

        # Use a gradient of ones with the same shape as the output.
        grads = torch.ones_like(retval)

        def backward_runnable():
            torch.autograd.backward(retval, grads, retain_graph=True)

        # Find the AccumulateGrad functions reached by this backward pass,
        # along with the sizes of the gradients that flow into them.
        size_dict = get_accumulate_grad_inputs(
            initial_grad_fn,
            backward_runnable,
        )

        # Pre-generate a random gradient of the right size for each
        # AccumulateGrad function so that gradient accumulation can be
        # replayed separately (see run_accumulate_grad below).
        ag_dict = {
            grad_fn: torch.randn(size, device=torch.device('cuda'))
            for grad_fn, size in size_dict.items()
        }
        return cls(backward_runnable, ag_dict)

    def run_accumulate_grad(self):
        # Replay gradient accumulation by invoking each AccumulateGrad
        # function directly with its pre-generated gradient.
        for grad_fn, grad in self._ag_dict.items():
            grad_fn(grad)


def backward_available(operation_output):
    return get_grad_fn(operation_output)[1] is not None


def flatten_operation_output(operation_output):
    if isinstance(operation_output, torch.Tensor):
        return [operation_output]
    elif not isinstance(operation_output, (tuple, list)):
        # Non-tensor, non-sequence values carry no tensors to flatten.
        return []

    # Recursively flatten nested tuples and lists of tensors.
    flattened = []
    for value in operation_output:
        flattened.extend(flatten_operation_output(value))
    return flattened
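
# Illustrative example (not in the original source): for a nested output such
# as (t1, [t2, 'aux']) where t1 and t2 are tensors, flatten_operation_output
# returns [t1, t2]; the non-tensor leaf 'aux' is dropped.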


def get_grad_fn(retval):
    if isinstance(retval, torch.Tensor) and retval.grad_fn is not None:
        return retval, retval.grad_fn
    elif isinstance(retval, (tuple, list)):
        # Search nested outputs for the first tensor that has a grad_fn.
        for inner_value in retval:
            inner_retval, grad_fn = get_grad_fn(inner_value)
            if grad_fn is not None:
                return inner_retval, grad_fn
    # Either not a tensor/sequence, or no tensor with a grad_fn was found.
    return None, None
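

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): how the helpers above fit
# together. The Conv2d operation and its input are hypothetical, and a CUDA
# device is assumed because new_from allocates its gradients on 'cuda'.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    conv = torch.nn.Conv2d(3, 64, kernel_size=3).to('cuda')
    inp = torch.randn(16, 3, 224, 224, device='cuda', requires_grad=True)
    out = conv(inp)

    if backward_available(out):
        helper = BackwardHelper.new_from(out)
        helper.run_backward()          # run the full backward pass
        helper.run_accumulate_grad()   # replay only gradient accumulation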