bitsandbytes package

bitsandbytes.functional module

CPU/GPU quantization functions and stateless optimizer functions.

bitsandbytes.functional.create_dynamic_map(signed=True, n=7)

Creates the dynamic quantization map.

The dynamic data type is made up of a dynamic exponent and a fraction. As the exponent goes from 0 to -7, the number of bits available for the fraction shrinks.

This is a generalization of the dynamic type where a certain number of the bits can be reserved for the linear quantization region (the fraction). n determines the maximum number of exponent bits.

For more details, see "8-Bit Approximations for Parallelism in Deep Learning": https://arxiv.org/abs/1511.04561
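As a rough illustration of the idea, the pure-Python sketch below builds a signed 256-entry map in which each step down in exponent halves the number of fraction values. It is not the library's implementation: the names linspace and dynamic_map_sketch, the base-10 exponent scale, and the 0.1 to 1.0 fraction range are assumptions made for this example, and only the signed case with 7 exponent levels is covered.

```python
def linspace(start, stop, num):
    """Return num evenly spaced floats from start to stop inclusive (num >= 2)."""
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def dynamic_map_sketch(exponent_levels=7):
    """Hypothetical signed dynamic map; yields 256 values when exponent_levels is 7."""
    values = []
    for i in range(exponent_levels):
        # The scale runs from 10^-(levels-1) for i = 0 up to 10^0 for the last level.
        scale = 10.0 ** (-(exponent_levels - 1) + i)
        # Larger exponents get more fraction buckets: 2^i buckets at level i.
        edges = linspace(0.1, 1.0, 2 ** i + 1)
        for lo, hi in zip(edges[:-1], edges[1:]):
            mid = (lo + hi) / 2.0
            values.append(scale * mid)
            values.append(-scale * mid)
    # Add exact zero and exact one so both endpoints are representable.
    values.append(0.0)
    values.append(1.0)
    values.sort()
    return values


dynamic_map = dynamic_map_sketch()
```

The resulting map is dense near the largest magnitudes and sparse near zero in absolute terms, but keeps roughly constant relative resolution across several orders of magnitude, which is the property the dynamic type is designed for.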

bitsandbytes.functional.create_linear_map(signed=True)
bitsandbytes.functional.dequantize(A: torch.Tensor, quant_state: Tuple[torch.Tensor, torch.Tensor] = None, absmax: torch.Tensor = None, code: torch.Tensor = None, out: torch.Tensor = None) → torch.Tensor
bitsandbytes.functional.dequantize_blockwise(A: torch.Tensor, quant_state: Tuple[torch.Tensor, torch.Tensor] = None, absmax: torch.Tensor = None, code: torch.Tensor = None, out: torch.Tensor = None, blocksize: int = 4096) → torch.Tensor

Dequantizes blockwise quantized values.

Dequantizes the tensor A with maximum absolute values absmax in blocks of size 4096.

Parameters
  • A (torch.Tensor) – The input 8-bit tensor.

  • quant_state (tuple(torch.Tensor, torch.Tensor)) – Tuple of code and absmax values.

  • absmax (torch.Tensor) – The absmax values.

  • code (torch.Tensor) – The quantization map.

  • out (torch.Tensor) – Dequantized output tensor (default: float32)

Returns

Dequantized tensor (default: float32)

Return type

torch.Tensor
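The arithmetic behind blockwise dequantization can be sketched in pure Python, assuming each 8-bit value is an index into the quantization map and each block is rescaled by its own absmax. The function name dequantize_blockwise_sketch, the list-based inputs, and the evenly spaced 256-entry map are assumptions for illustration, not the library's kernel.

```python
def dequantize_blockwise_sketch(indices, code, absmax, blocksize=4096):
    """Look up each index in the map and rescale it by the absmax of its block."""
    return [code[q] * absmax[pos // blocksize] for pos, q in enumerate(indices)]


# Hypothetical map: 256 evenly spaced values in [-1, 1].
linear_code = [-1.0 + 2.0 * i / 255 for i in range(256)]

# Two blocks of size 4 (the second is partial) with absmax 2.0 and 10.0.
restored = dequantize_blockwise_sketch(
    [0, 255, 255, 0, 255], linear_code, [2.0, 10.0], blocksize=4
)
```

Here the first four values are scaled by 2.0 and the fifth by 10.0, so the same index can decode to different magnitudes depending on its block.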

bitsandbytes.functional.dequantize_no_absmax(A: torch.Tensor, code: torch.Tensor, out: torch.Tensor = None) → torch.Tensor

Dequantizes the 8-bit tensor to 32-bit.

Dequantizes the 8-bit tensor A to the 32-bit tensor out via the quantization map code.

Parameters
  • A (torch.Tensor) – The 8-bit input tensor.

  • code (torch.Tensor) – The quantization map.

  • out (torch.Tensor) – The 32-bit output tensor.

Returns

32-bit output tensor.

Return type

torch.Tensor

bitsandbytes.functional.estimate_quantiles(A: torch.Tensor, out: torch.Tensor = None, offset: float = 0.001953125) → torch.Tensor

Estimates 256 equidistant quantiles on the input tensor eCDF.

Uses the SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles via the eCDF of the input tensor A. This is a fast but approximate algorithm, and the extreme quantiles close to 0 and 1 have high variance and large estimation errors. These large errors can be avoided by using the offset argument, which trims the distribution. The default offset value of 1/512 ensures minimum entropy encoding: it trims 1/512 (about 0.2%) from each side of the distribution. An offset value of 0.01 to 0.02 usually has a much lower error but is not a minimum entropy encoding. Given an offset of 0.02, equidistant points in the range [0.02, 0.98] are used for the quantiles.

Parameters
  • A (torch.Tensor) – The input tensor. Any shape.

  • out (torch.Tensor) – Tensor with the 256 estimated quantiles.

  • offset (float) – The offset for the first and last quantile from 0 and 1. Default: 1/512

Returns

The 256 quantiles in float32 datatype.

Return type

torch.Tensor
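To make the role of offset concrete, the sketch below computes 256 quantiles at equidistant probabilities in [offset, 1 - offset] by sorting the data. This is an exact, sort-based stand-in, not the SRAM-Quantiles algorithm itself; the function name estimate_quantiles_sketch and the index rounding rule are assumptions for this example.

```python
def estimate_quantiles_sketch(values, offset=1 / 512, num_quantiles=256):
    """Read quantiles off the sorted data at equidistant probabilities."""
    ordered = sorted(values)
    n = len(ordered)
    quantiles = []
    for k in range(num_quantiles):
        # Probabilities run from offset to 1 - offset in equal steps.
        p = offset + (1 - 2 * offset) * k / (num_quantiles - 1)
        # Map the probability to a position in the sorted data.
        idx = min(n - 1, int(p * n))
        quantiles.append(ordered[idx])
    return quantiles


# A permutation of 0..999, so the sorted data is simply 0, 1, ..., 999.
data = [(i * 37) % 1000 for i in range(1000)]
quantiles = estimate_quantiles_sketch(data)
```

With the default offset the first and last quantiles sit just inside the extremes of the data; a larger offset such as 0.02 moves them further in, away from the high-variance tails.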

bitsandbytes.functional.get_ptr(A: torch.Tensor) → ctypes.c_void_p

Get the ctypes pointer from a PyTorch Tensor.

Parameters

A (torch.Tensor) – The PyTorch tensor.

Returns

Return type

ctypes.c_void_p

bitsandbytes.functional.histogram_scatter_add_2d(histogram: torch.Tensor, index1: torch.Tensor, index2: torch.Tensor, source: torch.Tensor)
bitsandbytes.functional.optimizer_update_32bit(optimizer_name: str, g: torch.Tensor, p: torch.Tensor, state1: torch.Tensor, beta1: float, eps: float, step: int, lr: float, state2: torch.Tensor = None, beta2: float = 0.0, weight_decay: float = 0.0, gnorm_scale: float = 1.0, unorm_vec: torch.Tensor = None, max_unorm: float = 0.0) → None

Performs an inplace optimizer update with one or two optimizer states.

Universal optimizer update for 32-bit state and 32/16-bit gradients/weights.

Parameters
  • optimizer_name (str) – The name of the optimizer: {adam}.

  • g (torch.Tensor) – Gradient tensor.

  • p (torch.Tensor) – Parameter tensor.

  • state1 (torch.Tensor) – Optimizer state 1.

  • beta1 (float) – Optimizer beta1.

  • eps (float) – Optimizer epsilon.

  • weight_decay (float) – Weight decay.

  • step (int) – Current optimizer step.

  • lr (float) – The learning rate.

  • state2 (torch.Tensor) – Optimizer state 2.

  • beta2 (float) – Optimizer beta2.

  • gnorm_scale (float) – The factor to rescale the gradient to the max clip value.
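For orientation, the sketch below shows how these parameters would fit together in a textbook Adam step with bias correction and decoupled weight decay. It is an assumption that the kernel follows this formulation and order of operations; the function name adam_update_sketch and the use of Python lists in place of tensors are hypothetical.

```python
import math


def adam_update_sketch(p, g, state1, state2, beta1, beta2, eps, step, lr,
                       weight_decay=0.0, gnorm_scale=1.0):
    """Update p, state1 and state2 in place using a textbook Adam step."""
    bias_correction1 = 1.0 - beta1 ** step
    bias_correction2 = 1.0 - beta2 ** step
    for i in range(len(p)):
        # Rescale the gradient first, for example after gradient clipping.
        grad = g[i] * gnorm_scale
        # Exponential moving averages of the gradient and its square.
        state1[i] = beta1 * state1[i] + (1.0 - beta1) * grad
        state2[i] = beta2 * state2[i] + (1.0 - beta2) * grad * grad
        # Decoupled (AdamW-style) weight decay, assumed here.
        if weight_decay > 0.0:
            p[i] *= 1.0 - lr * weight_decay
        m_hat = state1[i] / bias_correction1
        v_hat = state2[i] / bias_correction2
        p[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)


params, grads = [1.0], [0.5]
exp_avg, exp_avg_sq = [0.0], [0.0]
adam_update_sketch(params, grads, exp_avg, exp_avg_sq,
                   beta1=0.9, beta2=0.999, eps=1e-8, step=1, lr=0.1)
```

On the first step the bias-corrected ratio is close to the sign of the gradient, so the parameter moves by roughly lr.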

bitsandbytes.functional.optimizer_update_8bit(optimizer_name: str, g: torch.Tensor, p: torch.Tensor, state1: torch.Tensor, state2: torch.Tensor, beta1: float, beta2: float, eps: float, step: int, lr: float, qmap1: torch.Tensor, qmap2: torch.Tensor, max1: torch.Tensor, max2: torch.Tensor, new_max1: torch.Tensor, new_max2: torch.Tensor, weight_decay: float = 0.0, gnorm_scale: float = 1.0, unorm_vec: torch.Tensor = None, max_unorm: float = 0.0) → None

Performs an inplace Adam update.

Universal Adam update for 32/8-bit state and 32/16-bit gradients/weights. Uses AdamW formulation if weight decay > 0.0.

Parameters
  • optimizer_name (str) – The name of the optimizer. Choices {adam, momentum}

  • g (torch.Tensor) – Gradient tensor.

  • p (torch.Tensor) – Parameter tensor.

  • state1 (torch.Tensor) – Adam state 1.

  • state2 (torch.Tensor) – Adam state 2.

  • beta1 (float) – Adam beta1.

  • beta2 (float) – Adam beta2.

  • eps (float) – Adam epsilon.

  • weight_decay (float) – Weight decay.

  • step (int) – Current optimizer step.

  • lr (float) – The learning rate.

  • qmap1 (torch.Tensor) – Quantization map for first Adam state.

  • qmap2 (torch.Tensor) – Quantization map for second Adam state.

  • max1 (torch.Tensor) – Max value for first Adam state update.

  • max2 (torch.Tensor) – Max value for second Adam state update.

  • new_max1 (torch.Tensor) – Max value for the next Adam update of the first state.

  • new_max2 (torch.Tensor) – Max value for the next Adam update of the second state.

  • gnorm_scale (float) – The factor to rescale the gradient to the max clip value.

bitsandbytes.functional.optimizer_update_8bit_blockwise(optimizer_name: str, g: torch.Tensor, p: torch.Tensor, state1: torch.Tensor, state2: torch.Tensor, beta1: float, beta2: float, eps: float, step: int, lr: float, qmap1: torch.Tensor, qmap2: torch.Tensor, absmax1: torch.Tensor, absmax2: torch.Tensor, weight_decay: float = 0.0, gnorm_scale: float = 1.0) → None
bitsandbytes.functional.percentile_clipping(grad: torch.Tensor, gnorm_vec: torch.Tensor, step: int, percentile: int = 5)

Applies percentile clipping.

grad: torch.Tensor

The gradient tensor.

gnorm_vec: torch.Tensor

Vector of gradient norms. 100 elements expected.

step: int

The current optimization step (number of past gradient norms).
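One plausible reading of percentile clipping is sketched below: the squared gradient norm of each step is stored in a 100-element ring buffer, and the gradient is scaled down whenever its norm exceeds the norm at the chosen rank of the sorted history. The function name percentile_clipping_sketch, the returned triple, and the use of squared norms in the buffer are assumptions for this example, not a statement of what the library does internally.

```python
import math


def percentile_clipping_sketch(grad, gnorm_vec, step, percentile=5):
    """Return (current_norm, clip_value, gnorm_scale) and update gnorm_vec in place."""
    slot = step % len(gnorm_vec)
    # Store the squared L2 norm of the current gradient in the ring buffer.
    gnorm_vec[slot] = sum(x * x for x in grad)
    current_norm = math.sqrt(gnorm_vec[slot])
    # The threshold is the norm at the requested rank of the sorted history.
    clip_value = math.sqrt(sorted(gnorm_vec)[percentile])
    gnorm_scale = 1.0
    if current_norm > clip_value:
        gnorm_scale = clip_value / current_norm
    return current_norm, clip_value, gnorm_scale


# History of squared norms 1^2, 2^2, ..., 100^2; the new gradient has norm 50.
history = [float(i * i) for i in range(1, 101)]
current_norm, clip_value, gnorm_scale = percentile_clipping_sketch(
    [30.0, 40.0], history, step=0
)
```

The returned scale is the kind of factor that the gnorm_scale argument of the optimizer update functions is described as expecting.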

bitsandbytes.functional.quantize(A: torch.Tensor, code: torch.Tensor = None, out: torch.Tensor = None) → torch.Tensor
bitsandbytes.functional.quantize_blockwise(A: torch.Tensor, code: torch.Tensor = None, absmax: torch.Tensor = None, rand=None, out: torch.Tensor = None) → torch.Tensor

Quantizes tensor A in blocks of 4096 values.

Quantizes tensor A by dividing it into blocks of 4096 values. The absolute maximum value within each block is then calculated and used for the non-linear quantization.

Parameters
  • A (torch.Tensor) – The input tensor.

  • code (torch.Tensor) – The quantization map.

  • absmax (torch.Tensor) – The absmax values.

  • rand (torch.Tensor) – The tensor for stochastic rounding.

  • out (torch.Tensor) – The output tensor (8-bit).

Returns

  • torch.Tensor – The 8-bit tensor.

  • tuple(torch.Tensor, torch.Tensor) – The quantization state to undo the quantization.
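The blockwise scheme can be sketched in pure Python, assuming each block is normalized by its own absolute maximum and every normalized value is mapped to the index of the nearest map entry. The function name quantize_blockwise_sketch, the evenly spaced 256-entry map, and the brute-force nearest search are assumptions for illustration; stochastic rounding via rand is not modeled.

```python
def quantize_blockwise_sketch(values, code, blocksize=4096):
    """Return (indices, absmax): one 8-bit index per value, one absmax per block."""
    indices, absmax = [], []
    for start in range(0, len(values), blocksize):
        block = values[start:start + blocksize]
        block_max = max(abs(v) for v in block)
        absmax.append(block_max)
        for v in block:
            # Normalize into [-1, 1]; an all-zero block maps to 0.0.
            normed = v / block_max if block_max > 0 else 0.0
            # Pick the index of the closest entry in the quantization map.
            nearest = min(range(len(code)), key=lambda i: abs(code[i] - normed))
            indices.append(nearest)
    return indices, absmax


# Hypothetical map: 256 evenly spaced values in [-1, 1].
linear_code = [-1.0 + 2.0 * i / 255 for i in range(256)]

q_indices, q_absmax = quantize_blockwise_sketch(
    [0.5, -2.0, 1.0, 0.0, 10.0, -5.0], linear_code, blocksize=4
)
```

Because each block carries its own absmax, the outlier 10.0 in the second block does not reduce the resolution available to the values in the first block.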

bitsandbytes.functional.quantize_no_absmax(A: torch.Tensor, code: torch.Tensor, out: torch.Tensor = None) → torch.Tensor

Quantizes input tensor to 8-bit.

Quantizes the 32-bit input tensor A to the 8-bit output tensor out using the quantization map code.

Parameters
  • A (torch.Tensor) – The input tensor.

  • code (torch.Tensor) – The quantization map.

  • out (torch.Tensor, optional) – The output tensor. Needs to be of type byte.

Returns

Quantized 8-bit tensor.

Return type

torch.Tensor
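Without absmax scaling, quantization reduces to a nearest-entry lookup in the map, and dequantization to a plain index lookup. The sketch below shows a round trip under the assumption that the map is sorted in ascending order; the function names and the tiny 5-entry map are hypothetical (the real map has 256 entries).

```python
import bisect


def quantize_no_absmax_sketch(values, code):
    """Map each value to the index of the nearest entry in the sorted map."""
    indices = []
    for v in values:
        pos = bisect.bisect_left(code, v)
        if pos == 0:
            idx = 0                      # below the smallest entry
        elif pos == len(code):
            idx = len(code) - 1          # above the largest entry
        else:
            # Choose the closer of the two neighbouring entries.
            idx = pos if code[pos] - v < v - code[pos - 1] else pos - 1
        indices.append(idx)
    return indices


def dequantize_no_absmax_sketch(indices, code):
    """Replace each index with the map entry it points to."""
    return [code[i] for i in indices]


tiny_code = [-1.0, -0.5, 0.0, 0.5, 1.0]
encoded = quantize_no_absmax_sketch([-3.0, -0.4, 0.1, 0.26, 2.0], tiny_code)
decoded = dequantize_no_absmax_sketch(encoded, tiny_code)
```

Values outside the range of the map saturate to its first or last entry, which is why inputs are normally expected to be scaled into the map's range beforehand.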