Source code for habitat.analysis.predictor

import functools
import logging
import operator
import os

from habitat.analysis import SPECIAL_OPERATIONS
from habitat.analysis.operation import PredictedOperation
from habitat.analysis.run_time import RunTimePrediction, RunTimePurePrediction
from habitat.analysis.wave_scaling.metadata import MetadataManager
from habitat.analysis.wave_scaling.unified import unified_wave_scaling
from habitat.data import path_to_data
from habitat.utils import ms_to_ns, name_all_arguments

from habitat.analysis.mlp.mlp import RuntimePredictor

logger = logging.getLogger(__name__)

CONV2D_PARAMS = [
    'input', 'weight', 'bias', 'stride', 'padding', 'dilation', 'groups',
]
CONVTRANSPOSE2D_PARAMS = [
    'input', 'weight', 'bias', 'stride', 'padding', 'dilation', 'groups',
]
LINEAR_PARAMS = ['input', 'weight', 'bias']
BMM_PARAMS = ['input', 'mat2', 'out']
LSTM_PARAMS_NO_BATCH_SIZES = [
    'input', 'hx', 'flat_weights', 'bias', 'num_layers', 'dropout',
    'training', 'bidirectional', 'batch_first',
]
LSTM_PARAMS = [
    'input', 'batch_sizes', 'hx', 'flat_weights', 'bias', 'num_layers',
    'dropout', 'training', 'bidirectional',
]
MATMUL_PARAMS = ['input', 'other', 'out']

class Predictor:
    def __init__(
        self,
        kernel_metadata_file=None,
        wave_scaling_strategy=unified_wave_scaling,
    ):
        self._kernel_metadata = MetadataManager(
            kernel_metadata_file if kernel_metadata_file is not None
            else path_to_data("kernels.sqlite")
        )
        self._wave_scaling_strategy = wave_scaling_strategy

        # Load MLP predictors from saved models
        self.linear_pred = RuntimePredictor(
            "linear", 8, 1024, path_to_data("linear/model.pth"),
        )
        self.lstm_pred = RuntimePredictor(
            "lstm", 8, 1024, path_to_data("lstm/model.pth"),
        )
        self.conv2d_pred = RuntimePredictor(
            "conv2d", 8, 1024, path_to_data("conv2d/model.pth"),
        )
        self.bmm_pred = RuntimePredictor(
            "bmm", 8, 1024, path_to_data("bmm/model.pth"),
        )
        self.conv_transpose2d_pred = RuntimePredictor(
            "conv_transpose2d", 8, 1024, path_to_data("conv_transpose2d/model.pth"),
        )

    def predict_operation(self, operation, dest_device):
        if operation.name not in SPECIAL_OPERATIONS:
            return PredictedOperation(
                operation,
                self._wave_scale(operation.forward, dest_device),
                (self._wave_scale(operation.backward, dest_device)
                 if operation.backward is not None else None),
                dest_device,
            )

        if operation.name == 'conv2d':
            return self._special_scale(operation, dest_device, self._conv2d_scale)
        elif operation.name == 'lstm':
            return self._special_scale(operation, dest_device, self._lstm_scale)
        elif operation.name == 'linear':
            return self._special_scale(operation, dest_device, self._linear_scale)
        elif operation.name == 'bmm':
            return self._special_scale(operation, dest_device, self._bmm_scale)
        elif operation.name == 'conv_transpose2d':
            return self._special_scale(
                operation, dest_device, self._conv_transpose2d_scale)

        logger.warning('Unhandled special operation: %s', operation.name)
        return PredictedOperation(
            operation,
            operation.forward,
            operation.backward,
            dest_device,
        )

    def _wave_scale(self, run_time, dest_device):
        # Overhead is the portion of the measured run time not attributed to
        # any kernel; it is clamped at zero and carried over to the prediction
        # unchanged.
        run_time_ns = ms_to_ns(run_time.run_time_ms)
        total_ktime_ns = sum(map(lambda k: k.run_time_ns, run_time.kernels))
        overhead_ns = run_time_ns - total_ktime_ns

        # Scale each recorded kernel from the measured device to the
        # destination device using the configured wave scaling strategy.
        predicted_kernels = list(map(
            lambda kernel: self._wave_scaling_strategy(
                kernel,
                run_time.device,
                dest_device,
                self._kernel_metadata,
            ),
            run_time.kernels,
        ))
        return RunTimePrediction(
            overhead_ns=0 if overhead_ns < 0 else overhead_ns,
            predicted_kernels=predicted_kernels,
            device=dest_device,
        )

    def _special_scale(self, operation, dest_device, scaler):
        predicted_ms = scaler(operation, dest_device)
        if predicted_ms < 0:
            logger.warning(
                'Operation %s predicted a negative run time (%.2f ms); '
                'clamping to zero.',
                operation.name,
                predicted_ms,
            )
            predicted_ms = 0.
        return PredictedOperation(
            operation,
            RunTimePurePrediction(predicted_ms, dest_device),
            None,
            dest_device,
        )

    def _conv2d_scale(self, operation, dest_device):
        # 1. Merge arguments (give them all names)
        merged = name_all_arguments(
            CONV2D_PARAMS,
            operation.arguments.args,
            operation.arguments.kwargs,
        )

        # 2. Construct arguments that the predictor expects
        arguments = dict(
            batch=merged['input'][0],
            image_size=merged['input'][2],
            in_channels=merged['input'][1],
            out_channels=merged['weight'][0],
            kernel_size=merged['weight'][2],
            stride=(
                merged['stride'][0] if isinstance(merged['stride'], tuple)
                else merged['stride']
            ),
            padding=(
                merged['padding'][0] if isinstance(merged['padding'], tuple)
                else merged['padding']
            ),
            bias=(1 if merged['bias'] is not None else 0),
        )

        # 3. Call model to make prediction
        arguments = [arguments[x] for x in self.conv2d_pred.model.features]
        pred_dest = self.conv2d_pred.predict(arguments, dest_device.name)
        pred_orig = self.conv2d_pred.predict(arguments, operation.device.name)
        return operation.run_time_ms * pred_dest / pred_orig
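
    # Illustration of the ratio-based scaling step above (the numbers are
    # hypothetical): if this conv2d call measured 4.0 ms on the original
    # device, and the MLP predicts 8.0 ms for it on the original device but
    # 2.0 ms on the destination device, the returned estimate is
    # 4.0 * 2.0 / 8.0 = 1.0 ms.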

    def _conv_transpose2d_scale(self, operation, dest_device):
        # 1. Merge arguments (give them all names)
        merged = name_all_arguments(
            CONVTRANSPOSE2D_PARAMS,
            operation.arguments.args,
            operation.arguments.kwargs,
        )

        # 2. Construct arguments that the predictor expects
        arguments = dict(
            batch=merged['input'][0],
            image_size=merged['input'][2],
            in_channels=merged['input'][1],
            out_channels=merged['weight'][0],
            kernel_size=merged['weight'][2],
            stride=(
                merged['stride'][0] if isinstance(merged['stride'], tuple)
                else merged['stride']
            ),
            padding=(
                merged['padding'][0] if isinstance(merged['padding'], tuple)
                else merged['padding']
            ),
            bias=(1 if merged['bias'] is not None else 0),
        )

        # 3. Call model to make prediction
        arguments = [
            arguments[x] for x in self.conv_transpose2d_pred.model.features
        ]
        pred_dest = self.conv_transpose2d_pred.predict(arguments, dest_device.name)
        pred_orig = self.conv_transpose2d_pred.predict(
            arguments, operation.device.name)
        return operation.run_time_ms * pred_dest / pred_orig

    def _linear_scale(self, operation, dest_device):
        merged = name_all_arguments(
            LINEAR_PARAMS,
            operation.arguments.args,
            operation.arguments.kwargs,
        )

        # The input to the linear function in PyTorch can contain an arbitrary
        # number of dimensions between the batch size and the input feature
        # dimension.
        #
        # e.g., The input can have size (32, 50, 512), where 32 is the batch
        # size and 512 is the input feature dimension.
        #
        # This means that the effective batch size is the product of all the
        # dimensions before the input feature dimension (e.g., 32 * 50 = 1600).
        # We need to take this into account when making a prediction.
        effective_batch = functools.reduce(
            operator.mul,
            merged['input'][:-1],
        )
        arguments = dict(
            batch=effective_batch,
            in_features=merged['weight'][1],
            out_features=merged['weight'][0],
            bias=(1 if merged['bias'] is not None else 0),
        )
        arguments = [arguments[x] for x in self.linear_pred.model.features]
        pred_dest = self.linear_pred.predict(arguments, dest_device.name)
        pred_orig = self.linear_pred.predict(arguments, operation.device.name)
        return operation.run_time_ms * pred_dest / pred_orig
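
    # Worked example of the effective-batch computation above (the shapes are
    # hypothetical): for an input of size (32, 50, 512),
    # functools.reduce(operator.mul, (32, 50, 512)[:-1]) multiplies the leading
    # dimensions, giving an effective batch of 32 * 50 = 1600.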

    def _bmm_scale(self, operation, dest_device):
        merged = name_all_arguments(
            BMM_PARAMS,
            operation.arguments.args,
            operation.arguments.kwargs,
        )
        # bmm multiplies a (batch, left, middle) tensor by a
        # (batch, middle, right) tensor.
        arguments = dict(
            batch=merged['input'][0],
            left=merged['input'][1],
            middle=merged['input'][2],
            right=merged['mat2'][2],
        )
        arguments = [arguments[x] for x in self.bmm_pred.model.features]
        pred_dest = self.bmm_pred.predict(arguments, dest_device.name)
        pred_orig = self.bmm_pred.predict(arguments, operation.device.name)
        return operation.run_time_ms * pred_dest / pred_orig

    def _lstm_scale(self, operation, dest_device):
        # This is hacky, but unfortunately the only way to differentiate these
        # overloaded LSTM calls.
        has_batch_sizes = isinstance(operation.arguments.args[4], bool)

        if not has_batch_sizes:
            merged = name_all_arguments(
                LSTM_PARAMS_NO_BATCH_SIZES,
                operation.arguments.args,
                operation.arguments.kwargs,
            )
            arguments = dict(
                bias=(1 if merged['bias'] is not None else 0),
                bidirectional=(1 if merged['bidirectional'] else 0),
                batch=merged['input'][1],  # We require the batch to be in position 1
                seq_len=merged['input'][0],
                input_size=merged['input'][2],
                hidden_size=merged['hx'][0][2],
                num_layers=merged['num_layers'],
            )
        else:
            merged = name_all_arguments(
                LSTM_PARAMS,
                operation.arguments.args,
                operation.arguments.kwargs,
            )
            max_batch_size = max(operation.arguments.special['batch_sizes'])
            arguments = dict(
                bias=(1 if merged['bias'] is not None else 0),
                bidirectional=(1 if merged['bidirectional'] else 0),
                batch=max_batch_size,
                seq_len=merged['input'][0] // max_batch_size,
                input_size=merged['input'][1],
                hidden_size=merged['hx'][0][2],
                num_layers=merged['num_layers'],
            )

        arguments = [arguments[x] for x in self.lstm_pred.model.features]
        pred_dest = self.lstm_pred.predict(arguments, dest_device.name)
        pred_orig = self.lstm_pred.predict(arguments, operation.device.name)
        return operation.run_time_ms * pred_dest / pred_orig
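
A minimal usage sketch (not part of the module above), assuming an `operation`
record and a destination `dest_device` have already been obtained from
habitat's profiling/tracking layer, which lives outside this module:

    from habitat.analysis.predictor import Predictor

    # Uses the bundled kernels.sqlite metadata and the per-operation MLP models.
    predictor = Predictor()

    # `operation` and `dest_device` are hypothetical placeholders here;
    # `dest_device.name` must match a device name known to the MLP runtime
    # predictors.
    predicted = predictor.predict_operation(operation, dest_device)

The returned `PredictedOperation` carries the forward (and, if available,
backward) run time estimates for `dest_device`, as constructed in
`predict_operation` above.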