Efficient Batched Inference in Conditional Neural Networks
Abstract
Conditional neural networks (NNs) are networks in which the computation performed varies with the input. Many NNs of interest (such as autoregressive transformers for sequence generation tasks) are inherently conditional, since they process variable-length inputs or produce variable-length outputs. In addition, popular NN optimization techniques, such as early exit, cause the computational footprint to vary across inputs. This computational irregularity across inputs presents a challenge to batching, a technique widely used to improve hardware utilization and throughput during NN inference. To address this challenge, we propose BatchCond, an optimized batching framework for Conditional NNs that consists of two key steps: 1) computational similarity-driven batching (SimBatch) and 2) adaptive batch reorganization (ABR). SimBatch utilizes a lightweight DNN predictor to create batches of inputs that are more likely to share similar computational patterns, thereby reducing computational irregularity. ABR then addresses residual irregularity by dynamically splitting batches into computationally similar sub-batches in a hardware-aware manner. Our experiments demonstrate that BatchCond improves the overall throughput of batched inference by up to 6.6× (mean of 2.5×) across a suite of diverse Conditional NNs, including early-exit networks, dynamic slimmable networks, and autoregressive transformers. Code is available at https://github.com/surya00060/BatchCond.
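The two-step pipeline described above can be pictured with a minimal sketch. The code below is illustrative only, not the authors' implementation; the predictor, cost measure, and threshold (`predict_cost`, `hw_split_threshold`, etc.) are hypothetical stand-ins for the components in the released repository.

```python
# Illustrative sketch of the BatchCond pipeline, assuming a scalar cost
# predictor. All names here are hypothetical; see the repository at
# https://github.com/surya00060/BatchCond for the actual implementation.

def sim_batch(inputs, predict_cost, batch_size):
    """SimBatch: group inputs predicted to have similar computational cost.

    predict_cost is a lightweight learned predictor mapping an input to a
    scalar cost estimate (e.g., expected exit layer or output length).
    """
    ranked = sorted(inputs, key=predict_cost)
    return [ranked[i:i + batch_size] for i in range(0, len(ranked), batch_size)]

def adaptive_batch_reorg(batch, predict_cost, hw_split_threshold):
    """ABR: split a batch whose predicted costs still diverge too much.

    hw_split_threshold is a hardware-aware knob: splitting only pays off
    when the wasted work from irregularity exceeds the overhead of running
    a smaller batch.
    """
    costs = [predict_cost(x) for x in batch]
    if len(batch) == 1 or max(costs) - min(costs) <= hw_split_threshold:
        return [batch]
    mid = len(batch) // 2  # batch is already cost-sorted by SimBatch
    return (adaptive_batch_reorg(batch[:mid], predict_cost, hw_split_threshold)
            + adaptive_batch_reorg(batch[mid:], predict_cost, hw_split_threshold))

def batchcond(inputs, predict_cost, batch_size, hw_split_threshold):
    """Full pipeline: SimBatch followed by ABR on each resulting batch."""
    sub_batches = []
    for batch in sim_batch(inputs, predict_cost, batch_size):
        sub_batches.extend(
            adaptive_batch_reorg(batch, predict_cost, hw_split_threshold))
    return sub_batches  # each sub-batch is then run through the conditional NN
```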