Learned Threshold Pruning

Kambiz Azarian, Yash Bhalgat, Jinwon Lee, Tijmen Blankevoort

Main idea

We propose an end-to-end differentiable method for learning layer-wise pruning thresholds, which yields state-of-the-art model compression ratios on AlexNet, ResNet and EfficientNet. Our method also generates a trail of checkpoints with different accuracy-efficiency operating points.


Most unstructured weight pruning methods are based on the assumption that smaller weights do not contribute as much to the model’s performance. These pruning methods iteratively prune the weights that are smaller than a certain threshold and retrain the network to regain the performance lost during pruning. A key challenge in unstructured pruning is to find an optimal setting for these pruning thresholds. Merely setting the same threshold for all layers may not be appropriate because the distribution and ranges of the weights in each layer can be very different. Also, different layers may have varying sensitivities to pruning, depending on their position in the network (initial layers versus final layers) or their type (depth-wise separable versus standard convolutional layers). The best setting of thresholds should consider these layer-wise characteristics. Many methods (Zhang et al., 2018; Ye et al., 2019; Manessi et al., 2018) propose ways to search for these layer-wise thresholds, but they become quite computationally expensive for networks with a large number of layers, such as ResNet50 or EfficientNet. In this paper, we propose Learned Threshold Pruning (LTP) to address these challenges.
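
To illustrate why a single shared threshold may not suit every layer, here is a minimal NumPy sketch of plain magnitude pruning. The layer names, weight distributions and threshold value are hypothetical and chosen only for illustration; they are not taken from the paper.

```python
import numpy as np

def prune_by_threshold(weights, threshold):
    """Zero out weights whose magnitude is below `threshold`.

    Returns the pruned weights and the fraction of weights removed.
    """
    mask = np.abs(weights) >= threshold
    return weights * mask, 1.0 - mask.mean()

# Two hypothetical layers whose weights have very different ranges.
rng = np.random.default_rng(0)
layers = {
    "conv_first": rng.normal(0.0, 0.20, size=1000),
    "conv_last": rng.normal(0.0, 0.02, size=1000),
}

# Apply the same threshold to both layers and compare the sparsity.
shared_threshold = 0.05
sparsity = {
    name: prune_by_threshold(w, shared_threshold)[1]
    for name, w in layers.items()
}

for name, s in sparsity.items():
    print(f"{name}: {s:.1%} of weights pruned")
```

With these assumed distributions, the shared threshold removes almost all of the narrow-range layer while leaving most of the wide-range layer intact, which is the mismatch that per-layer thresholds are meant to avoid.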


Our learned-threshold pruning (LTP) method learns per-layer thresholds via gradient descent, unlike conventional methods where they are set as input. Making thresholds trainable also makes LTP computationally efficient, hence scalable to deeper networks. For example, it takes 30 epochs for LTP to prune ResNet50 on ImageNet by a factor of 9.1. This is in contrast to other methods that search for per-layer thresholds via a computationally intensive iterative pruning and fine-tuning process. Additionally, with a novel differentiable L0 regularization, LTP is able to operate effectively on architectures with batch-normalization. This is important since L1 and L2 penalties lose their regularizing effect in networks with batch-normalization. Finally, LTP generates a trail of progressively sparser networks from which the desired pruned network can be picked based on sparsity and performance requirements. These features allow LTP to achieve competitive compression rates on ImageNet networks such as AlexNet (26.4× compression with 79.1% Top-5 accuracy) and ResNet50 (9.1× compression with 92.0% Top-5 accuracy). We also show that LTP effectively prunes modern compact architectures, such as EfficientNet, MobileNetV2 and MixNet.
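
The summary above does not spell out how the threshold is made trainable, so the following is only a sketch under an assumption: the hard step "keep the weight if it exceeds the threshold" is replaced by a sigmoid gate on the squared weight with a temperature parameter. The function names, the gate form and the numeric values are illustrative and may differ from the authors' formulation. The sketch shows that such a gate gives a soft count of surviving weights (a differentiable stand-in for the L0 norm) whose derivative with respect to the threshold is non-zero, which is what allows gradient descent to move the threshold.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def soft_gate(weights, threshold, temperature):
    """Smooth stand-in for the hard test `weights**2 > threshold` (assumed form)."""
    return sigmoid((weights**2 - threshold) / temperature)

def soft_prune(weights, threshold, temperature):
    """Scale each weight by its gate; small weights are pushed toward zero."""
    return weights * soft_gate(weights, threshold, temperature)

def soft_l0(weights, threshold, temperature):
    """Soft count of weights that survive pruning."""
    return soft_gate(weights, threshold, temperature).sum()

def soft_l0_grad_threshold(weights, threshold, temperature):
    """Analytic derivative of the soft count with respect to the threshold."""
    gate = soft_gate(weights, threshold, temperature)
    return (-gate * (1.0 - gate) / temperature).sum()

# Hypothetical layer with two small and two large weights.
weights = np.array([-0.5, -0.05, 0.01, 0.2])
threshold = 0.01     # acts on squared weights, i.e. |w| = 0.1
temperature = 0.002  # smaller values bring the gate closer to a hard step

count = soft_l0(weights, threshold, temperature)
grad = soft_l0_grad_threshold(weights, threshold, temperature)
pruned = soft_prune(weights, threshold, temperature)

print("soft L0 count:", count)
print("d(count)/d(threshold):", grad)
print("softly pruned weights:", pruned)
```

In this sketch, raising the threshold lowers the soft count, so adding the count to the training loss as a penalty pushes the threshold upward, while the task loss pushes back wherever pruning hurts accuracy. Because the count depends on how many weights cross the threshold rather than on their magnitudes, it is a plausible candidate for a penalty that batch-normalization cannot cancel simply by rescaling the weights.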

tags: Model-Compression  Efficient-Deep-Learning  Model-Pruning