Source code for habitat.profiling.autograd

import torch

from habitat.profiling.backward import get_grad_fn, flatten_operation_output


class AutogradEngine:
    """
    Emulates the backward pass for a given model output, for timing purposes.
    """
    def __init__(self, grad_fn_ordering, input_map, initial_inputs):
        self._grad_fn_ordering = grad_fn_ordering
        self._input_holder = {
            fn: [None] * size for fn, size in input_map.items()
        }
        self._input_holder[self._grad_fn_ordering[0]] = initial_inputs
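
    # A sketch of the engine's state (illustrative only, not from the original
    # source) for a two-node graph AddmmBackward0 -> TBackward0:
    #
    #   _grad_fn_ordering = [AddmmBackward0, TBackward0]
    #   _input_holder     = {AddmmBackward0: [<detached output>],
    #                        TBackward0:     [None]}
    #
    # Each None slot is filled in by run_backward() as gradients propagate.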

    @classmethod
    def new_from(cls, operation_output, exclude_accumulate_grad=True):
        # Traverse the autograd graph, build an input map for each grad_fn,
        # and create a topological ordering
        _, initial_grad_fn = get_grad_fn(operation_output)
        if initial_grad_fn is None:
            raise ValueError('No grad_fn available on the operation output.')

        ordering = []
        input_map = {}
        initial_inputs = [
            tensor.detach()
            for tensor in flatten_operation_output(operation_output)
        ]
        input_map[initial_grad_fn] = len(initial_inputs)

        stack = [(initial_grad_fn, 0)]
        visited = {initial_grad_fn}

        # Build a topological ordering using an iterative, two-phase
        # depth-first traversal: a node is appended to the ordering only on
        # its second visit, after all of its children have been processed
        while len(stack) > 0:
            grad_fn, visit_count = stack.pop()
            if visit_count != 0:
                ordering.append(grad_fn)
                continue

            stack.append((grad_fn, 1))
            for next_fn, input_idx in grad_fn.next_functions:
                if next_fn is None:
                    continue
                if (exclude_accumulate_grad and
                        next_fn.name() == 'torch::autograd::AccumulateGrad'):
                    continue

                # Keep track of the inputs to each grad_fn
                if next_fn not in input_map:
                    input_map[next_fn] = 1
                input_map[next_fn] = max(input_map[next_fn], input_idx + 1)

                # Determine whether to visit this grad_fn
                if next_fn in visited:
                    continue
                visited.add(next_fn)
                stack.append((next_fn, 0))

        # Reverse the post-order so that gradients flow from the model output
        # back toward the inputs
        ordering.reverse()
        return cls(ordering, input_map, initial_inputs)
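
    # For intuition, a minimal sketch (an illustrative assumption, not part of
    # this module) of the graph structure new_from() traverses. Every grad_fn
    # exposes next_functions, a tuple of (next_fn, input_idx) pairs naming the
    # function and input slot that each output gradient should be sent to:
    #
    #   x = torch.randn(4, requires_grad=True)
    #   y = (x * 2).sum()
    #   y.grad_fn                 # <SumBackward0 ...>
    #   y.grad_fn.next_functions  # ((<MulBackward0 ...>, 0),)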

    def run_backward(self):
        for grad_fn in self._grad_fn_ordering:
            # 1. Run the backward function
            outputs = grad_fn(*(self._input_holder[grad_fn]))

            # 2. Store its outputs for the next backward function(s)
            if isinstance(outputs, torch.Tensor):
                outputs = [outputs]
            for output, (next_fn, input_idx) in zip(
                    outputs, grad_fn.next_functions):
                if next_fn is None or next_fn not in self._input_holder:
                    continue
                # NOTE: If implementing to actually calculate the gradient, we
                # need to sum gradients that "flow" into the same grad function
                # input.
                self._input_holder[next_fn][input_idx] = output
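
# Usage sketch (an illustrative assumption, not part of the original module):
# one way this engine might be used to time an emulated backward pass in
# isolation. The model, tensor sizes, and timing approach below are
# placeholders chosen for the example.
if __name__ == '__main__':
    import time

    model = torch.nn.Linear(128, 64)
    output = model(torch.randn(32, 128))

    # Build the engine from the forward pass output; AccumulateGrad nodes are
    # excluded by default, so parameter gradients are not materialized
    engine = AutogradEngine.new_from(output)

    start = time.perf_counter()
    engine.run_backward()
    elapsed = time.perf_counter() - start
    print('Emulated backward pass took {:.6f} seconds'.format(elapsed))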