Source code for habitat.profiling.kernel

import logging
import habitat.habitat_cuda as hc

from habitat.analysis import SPECIAL_OPERATIONS
from habitat.analysis.metrics import resolve_metrics
from habitat.analysis.kernels import MeasuredKernel

logger = logging.getLogger(__name__)

class KernelProfiler:
    def __init__(self, device, metrics=None, metrics_threshold_ms=0):
        self._device = device
        self._metrics = resolve_metrics(metrics, self._device)
        self._metrics_threshold_ns = metrics_threshold_ms * 1000000

    def measure_kernels(self, runnable, func_name=None):
        """
        Uses CUPTI to measure the kernels launched by runnable.

        Returns:
          A list of MeasuredKernels.
        """
        if func_name is None:
            fname = (
                runnable.__name__
                if hasattr(runnable, "__name__") else "Unnamed"
            )
        else:
            fname = func_name

        return list(map(
            lambda ks: MeasuredKernel(ks[0], ks[1], self._device),
            self._measure_kernels_raw(runnable, fname),
        ))
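
    # Illustration (hypothetical kernels, two metrics): if the raw
    # measurements come back as
    #     [(k0_time, [k0_m0, k0_m1]), (k1_time, [k1_m0, k1_m1])]
    # then the map above joins each pair into a single MeasuredKernel.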

    def _measure_kernels_raw(self, runnable, func_name):
        """
        Uses CUPTI to measure the kernels launched by runnable.

        Returns:
          A list of tuples, where
            - tuple[0] is the raw kernel measurement that should be used for
              the kernel's run time
            - tuple[1] is a list of the raw kernel measurements that contain
              the metrics requested
        """
        time_kernels = hc.profile(runnable)

        if (len(self._metrics) == 0 or
                func_name in SKIP_METRICS or
                func_name in SPECIAL_OPERATIONS or
                self._under_threshold(time_kernels)):
            return list(map(lambda tk: (tk, []), time_kernels))

        try:
            metric_kernels = [
                hc.profile(runnable, metric) for metric in self._metrics
            ]
            # Make sure the same number of kernels was recorded for each
            # metric.
            assert all(map(
                lambda ks: len(ks) == len(metric_kernels[0]),
                metric_kernels,
            ))
            # metric_kernels is originally (# metrics x # kernels in op);
            # we need to transpose it to become (# kernels in op x # metrics)
            # so that we can join kernels with their metrics.
            transposed = map(list, zip(*metric_kernels))
            # We return a list of (time kernel, [metric kernels]) pairs.
            return list(zip(time_kernels, transposed))

        except RuntimeError as ex:
            logger.warning(
                'Metrics error "%s" for function "%s".',
                str(ex),
                func_name,
            )
            return list(map(lambda tk: (tk, []), time_kernels))
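
    # Worked example of the transpose above (hypothetical values): with two
    # metrics and three kernels in the op,
    #     metric_kernels == [[k0_m0, k1_m0, k2_m0],
    #                        [k0_m1, k1_m1, k2_m1]]
    # and zip(*metric_kernels) yields
    #     (k0_m0, k0_m1), (k1_m0, k1_m1), (k2_m0, k2_m1)
    # i.e., one list of per-metric measurements per kernel, ready to be
    # paired element-wise with time_kernels.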

    def _under_threshold(self, kernels):
        # If under the threshold, don't measure metrics.
        return (
            sum(map(lambda k: k.run_time_ns, kernels)) <=
            self._metrics_threshold_ns
        )
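
    # Example of the threshold arithmetic (assumed values): constructing the
    # profiler with metrics_threshold_ms=1 sets _metrics_threshold_ns to
    # 1,000,000, so an op whose kernels total 800,000 ns skips metric
    # collection, while one totalling 2,500,000 ns is re-profiled per metric.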

SKIP_METRICS = {
    "detach_",
}
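
# Usage sketch (not part of the original module): one way this profiler might
# be driven. The torch workload and the `device` object are assumptions for
# illustration; only KernelProfiler's interface above comes from this file.
#
#     import torch
#     from habitat.profiling.kernel import KernelProfiler
#
#     profiler = KernelProfiler(device, metrics_threshold_ms=1)
#
#     def run_matmul():
#         a = torch.randn(1024, 1024, device="cuda")
#         b = torch.randn(1024, 1024, device="cuda")
#         torch.matmul(a, b)
#         torch.cuda.synchronize()
#
#     for kernel in profiler.measure_kernels(run_matmul, func_name="matmul"):
#         print(kernel)  # each element is a MeasuredKernel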