API Reference#

Number formats#

class mptorch.number.Number[source]#

Bases: object

Base class for all supported number formats.

Users should always instantiate one of the derived classes.

class mptorch.number.FloatType[source]#

Bases: Number

Base class for all float-like number formats.

Similar to the Number class, users should not instantiate this class directly. It is useful as a means to determine if a number format is of float type.

class mptorch.number.FixedPoint(wl, fl, clamp=True, symmetric=False)[source]#

Bases: Number

Low-Precision Fixed Point Number Format. Defined similarly to Deep Learning with Limited Numerical Precision (https://arxiv.org/abs/1502.02551).

The representable range is \(\left[-2^{\text{wl}-\text{fl}-1}, 2^{\text{wl}-\text{fl}-1}-2^{-\text{fl}}\right]\) and the precision unit (smallest nonzero absolute value) is \(2^{-\text{fl}}\). Numbers outside of the representable range can be clamped (if clamp is true). We can also give up the smallest representable number to make the range symmetric, \(\left[-2^{\text{wl}-\text{fl}-1}+2^{-\text{fl}}, 2^{\textnormal{wl}-\text{fl}-1}-2^{-\text{fl}}\right]\) (if symmetric is true).

Define \(\lfloor x \rfloor\) to be the largest representable number (a multiple of \(2^{-\text{fl}}\)) smaller than \(x\). For numbers within the representable range, we support two kinds of fixed point quantization: round to nearest (RN) and stochastic rounding (SR). They correspond to

\[\text{RN}(x) = \begin{cases} \lfloor x \rfloor, & \text{if } \lfloor x \rfloor \leq x \leq \lfloor x \rfloor + 2^{-\text{fl}-1} \\ \lfloor x \rfloor + 2^{-\text{fl}}, & \text{if } \lfloor x \rfloor + 2^{-\text{fl}-1} < x \leq \lfloor x \rfloor + 2^{-\text{fl}} \end{cases}\]

or

\[\text{SR}(x) = \begin{cases} \lfloor x \rfloor, & \text{with probability } 1 - \frac{x - \lfloor x \rfloor}{2^{-\text{fl}}} \\ \lfloor x \rfloor + 2^{-\text{fl}}, & \text{with probability } \frac{x - \lfloor x \rfloor}{2^{-\text{fl}}} \end{cases}\]

Parameters:
  • wl (int) – word length of each fixed point number

  • fl (int) – fractional length of each fixed point number

  • clamp (bool) – whether to clamp unrepresentable numbers

  • symmetric (bool) – whether to make the representable range symmetric
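
A minimal construction sketch (illustrative values, assuming mptorch is installed):

from mptorch.number import FixedPoint

# 8-bit word with 4 fractional bits: representable range [-8.0, 7.9375],
# precision unit 2**-4 = 0.0625
fxp = FixedPoint(wl=8, fl=4, clamp=True, symmetric=False)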

class mptorch.number.FloatingPoint(exp, man, subnormals=False, saturate=True)[source]#

Bases: FloatType

Low-Precision Floating Point Format. Follows rules set out in the IEEE-754 standard, applying them in the context of custom precision formats.

We set the exponent bias to be \(2^{\text{exp}-1} - 1\). In terms of rounding mode (see available quantization functions), we offer support for round to nearest even and stochastic rounding.

Parameters:
  • exp (int) – number of bits allocated for exponent

  • man (int) – number of bits allocated for mantissa, referring to number of bits that are supposed to be stored on hardware (not counting the virtual bits)

  • subnormals (bool) – allow the use of subnormal values

  • saturate (bool) – clamp values instead of using infinities in case of overflow

property is_fp32: bool#

Returns True if the format is equivalent to the IEEE-754 binary32 format.

property is_fp16: bool#

Returns True if the format is equivalent to the IEEE-754 binary16 format.

property is_bfloat16: bool#

Returns True if the format is equivalent to the bfloat16 format.
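
A short sketch showing how custom formats map to the convenience properties (illustrative values, assuming mptorch is installed):

from mptorch.number import FloatingPoint

# bfloat16-like format: 8 exponent bits, 7 stored mantissa bits
bf16 = FloatingPoint(exp=8, man=7)
# binary16-like format: 5 exponent bits, 10 stored mantissa bits
fp16 = FloatingPoint(exp=5, man=10)

print(bf16.is_bfloat16, fp16.is_fp16)  # expected: True True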

class mptorch.number.BlockFloatingPoint(wl, dim=-1)[source]#

Bases: Number

Low-Precision Block Floating Point Format.

BlockFloatingPoint shares an exponent across a block of numbers. The shared exponent is chosen from the largest magnitude in the block.

Parameters:
  • wl (int) – word length of the tensor

  • dim (int) – block dimension to share exponent. (*, D, *) Tensor where D is at position dim will have D different exponents; use -1 if the entire tensor is treated as a single block (there is only 1 shared exponent).
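
For example (illustrative values):

from mptorch.number import BlockFloatingPoint

# 8-bit block floating point; with dim=0, a (D, *) tensor gets D shared
# exponents, one per slice along dimension 0 (use dim=-1 for a single
# exponent shared by the whole tensor)
bfp = BlockFloatingPoint(wl=8, dim=0)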

class mptorch.number.SuperNormalFloat(exp, man, binades, saturate=False)[source]#

Bases: FloatType

Low-Precision SuperNormal Floating Point Format. Described in Range Extension with Supernormals for Mixed-Precision 8-bit DNN Training (https://www.arith2025.org/proceedings/215900a017.pdf)

We set the exponent bias to be \(2^{\text{exp}-1}\). For rounding mode, we apply round to nearest even.

Parameters:
  • exp (int) – number of bits allocated for exponent

  • man (int) – number of bits allocated for mantissa, referring to number of bits that are supposed to be stored on hardware (not counting the virtual bits)

  • binades (int | tuple[int] | tuple[int, int]) – number of binades transformed into log range

  • saturate (bool) – clamp values instead of using infinities in case of overflow
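
A construction sketch (illustrative values):

from mptorch.number import SuperNormalFloat

# 8-bit supernormal format: 4 exponent bits, 3 mantissa bits,
# with one binade transformed into the log range
snf = SuperNormalFloat(exp=4, man=3, binades=1, saturate=False)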

class mptorch.number.Binary8(P, signed=True, subnormals=True, overflow_policy='saturate_maxfloat2')[source]#

Bases: FloatType

Low-Precision binary8 Format that follows the IEEE-P3109 WG specification (P3109/Public). This format tracks an evolving standard; changes are likely in the future.

binary8 is a format that takes a precision value P as input, which determines the number of mantissa and exponent bits.

Parameters:
  • P (int) – integer precision of the binary8 format

  • signed (bool) – boolean indicating whether the format is signed or unsigned

  • subnormals (bool) – allow the use of subnormal values

  • overflow_policy (Literal['saturate_infty', 'saturate_maxfloat', 'saturate_maxfloat2']) – string indicating the overflow policy, one of: saturate_maxfloat2 (no infinity and +1 normalized value), saturate_maxfloat (clamp to maxfloat), saturate_infty (use infinity)

Quantization#

mptorch.quant.fixed_point_quantize(x, wl, fl, clamp=True, symmetric=False, rounding='stochastic')[source]#

Quantize a single precision floating-point tensor into a low-precision fixed-point tensor

Parameters:
  • x (Tensor) – the single precision tensor to be quantized

  • wl (int) – word length of the fixed-point format being simulated

  • fl (int) – fractional length of the fixed-point format being simulated

  • clamp (bool) – clamp input numbers into representable range. If false, the quantization will only simulate the effect on precision

  • symmetric (bool) – discard the minimum representable number to make the representable range symmetric

  • rounding (Literal['nearest', 'stochastic']) – rounding mode, “stochastic” or “nearest” (default: “stochastic”)

Return type:

Tensor

Returns:

a quantized fixed-point representation of the input tensor
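
Usage sketch (illustrative values):

import torch
from mptorch.quant import fixed_point_quantize

x = torch.randn(4, 4)
# simulate an 8-bit fixed-point format with 4 fractional bits,
# using round-to-nearest
x_fxp = fixed_point_quantize(x, wl=8, fl=4, clamp=True, rounding="nearest")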

mptorch.quant.block_quantize(x, wl, dim=-1, rounding='stochastic')[source]#

Quantize a single precision floating-point tensor into a low-precision block floating-point representation

Parameters:
  • x (Tensor) – the single precision tensor to be quantized

  • wl (int) – word length of the block floating-point format being simulated

  • dim (int) – dimension over which to apply the block floating point representation (-1 applies it to the entire tensor)

  • rounding (Literal['nearest', 'stochastic']) – rounding mode, “stochastic” or “nearest”

Return type:

Tensor

Returns:

a quantized low-precision block floating-point representation of the input tensor
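
Usage sketch (illustrative values):

import torch
from mptorch.quant import block_quantize

x = torch.randn(16, 32)
# 8-bit block floating point with one shared exponent for the whole
# tensor, stochastic rounding
x_bfp = block_quantize(x, wl=8, dim=-1, rounding="stochastic")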

mptorch.quant.float_quantize(x, exp, man, rounding='stochastic', subnormals=True, saturate=True, prng=0)[source]#

Quantize a single precision floating-point tensor into an IEEE-754-like low-precision floating-point tensor

Parameters:
  • x (Tensor) – the single precision number to be quantized

  • exp (int) – number of bits allocated for exponent

  • man (int) – number of bits allocated for mantissa, not counting the virtual bit

  • rounding (Literal['nearest', 'stochastic']) – rounding mode, “stochastic” or “nearest”

  • subnormals (bool) – if subnormals are supported or not

  • saturate (bool) – saturate on overflow or use infinities

  • prng (int) – number of random bits to use in case of stochastic rounding

Return type:

Tensor

Returns:

a quantized low-precision floating point representation of the input tensor
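
Usage sketch (illustrative values):

import torch
from mptorch.quant import float_quantize

x = torch.randn(8, 8)
# simulate an FP8 E4M3-like format (4 exponent bits, 3 mantissa bits)
# with round-to-nearest and saturation on overflow
x_fp8 = float_quantize(x, exp=4, man=3, rounding="nearest",
                       subnormals=True, saturate=True)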

mptorch.quant.binary8_quantize(x, P, rounding='nearest', overflow_policy='saturate_maxfloat', is_signed=True, subnormals=True, prng_bits=0)[source]#

Quantize a single precision floating-point tensor into a P3109-compatible one

Parameters:
  • x (Tensor) – the single precision tensor to be quantized

  • P (int) – number of bits allocated for precision

  • is_signed (bool) – whether the format is signed or unsigned

  • rounding (Literal['nearest', 'stochastic', 'truncate']) – the quantization rounding mode

  • overflow_policy (Literal['saturate_infty', 'saturate_maxfloat', 'saturate_maxfloat2']) – overflow handling policy

  • subnormals (bool) – whether subnormal values are supported

  • prng_bits (int) – number of bits for the random generator

Overflow Policies:
  • saturate_infty: Finite input values of binary32, exceeding the maximum float value of the binary8 format, will saturate to the maximum float. Infinite inputs will still map to infinities in this mode.

  • saturate_maxfloat: Both finite and infinite input values of binary32, exceeding the maximum float value of the binary8 format, will saturate to the maximum float represented by 0x7e/0xfe. This number system has an encoding reserved for infinity (0x7f/0xff).

  • saturate_maxfloat2: Both finite and infinite input values of binary32, exceeding the maximum float value of the binary8 format, will saturate to the maximum float represented by 0x7f/0xff. This number system does not have an encoding reserved for infinity.

Return type:

Tensor

Returns:

a quantized low-precision floating point representation of the input tensor
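
Usage sketch (illustrative values):

import torch
from mptorch.quant import binary8_quantize

x = torch.randn(8, 8)
# signed binary8 format with precision P=3; finite and infinite
# overflows saturate to the maximum float
x_b8 = binary8_quantize(x, 3, rounding="nearest",
                        overflow_policy="saturate_maxfloat",
                        is_signed=True, subnormals=True)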

mptorch.quant.binaryK_quantize(x, K, P, rounding='nearest_even', saturation_mode='overflow_infinity', is_signed=True, prng_bits=0, bias=None)[source]#
Return type:

Tensor

mptorch.quant.superfp_quantize(x, exp, man, binades, rounding='nearest', saturate=False)[source]#

Quantize a single precision floating-point tensor into a low-precision supernormal floating-point one

Parameters:
  • x (Tensor) – the single precision number to be quantized

  • exp (int) – number of bits allocated for exponent

  • man (int) – number of bits allocated for mantissa, not counting the virtual bit

  • binades (int | tuple[int] | tuple[int, int]) – number of binades that will be transformed into log range

  • rounding (Literal['stochastic', 'nearest']) – rounding mode, “stochastic” or “nearest”

  • saturate (bool) – saturate on overflow or use infinities

Return type:

Tensor

Returns:

a quantized low-precision supernormal floating-point representation of the input tensor

mptorch.quant.quantizer(forward_number=None, backward_number=None, forward_rounding='stochastic', backward_rounding='stochastic', clamping_grad_zero=False, backward_hooks=[])[source]#

Creates a quantization function that supports quantizing the forward and backward passes differently.

Parameters:
  • forward_number (Number | None) – the number format used for forward quantization. If None, the quantization is an identity mapping.

  • backward_number (Number | None) – the number format used for backward quantization. If None, the quantization is an identity mapping.

  • forward_rounding (Literal['nearest', 'stochastic']) – rounding mode, "stochastic" or "nearest" (default: "stochastic")

  • backward_rounding (Literal['nearest', 'stochastic']) – rounding mode, "stochastic" or "nearest" (default: "stochastic")

  • clamping_grad_zero (bool) – zero out the gradient of numbers that are being clamped during forward propagation. Currently requires forward_number to be a fixed-point number.

  • backward_hooks – iterable of functions that will be applied to gradients before backward quantization. For example, this can be used to support custom scaling.

Returns:

A quantization function as specified
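
A sketch of building separate forward/backward quantizers (illustrative formats):

from mptorch.number import FloatingPoint
from mptorch.quant import quantizer

fp8 = FloatingPoint(exp=4, man=3)
bf16 = FloatingPoint(exp=8, man=7)

# FP8-like nearest rounding on the forward pass, bfloat16-like
# stochastic rounding on the backward (gradient) pass
quant_fn = quantizer(forward_number=fp8, backward_number=bf16,
                     forward_rounding="nearest",
                     backward_rounding="stochastic")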

Quantization Formats#

class mptorch.quant.quant_format.QAffineFormats(fwd_mac=None, bwd_mac=None, fwd_rnd=None, bwd_rnd=None, weight_quant=<function <lambda>>, bias_quant=<function <lambda>>, input_quant=<function <lambda>>, output_quant=<function <lambda>>, grad_quant=<function <lambda>>, compensated=None, use_scaling=False, weight_scaled_format=None, input_scaled_format=None, grad_scaled_format=None, scale_margin=0, rbits=0)[source]#

Bases: object

Configuration class for number formats to use during compute (forward and/or backward pass) of affine layers (e.g. linear and convolutional). One can optionally specify quantizer objects for the signals in the layer (I/O activations, weights/bias terms and weight/error gradients) to facilitate quantization-aware-training (QAT) and post-training quantization (PTQ) workloads. Format parameters can also be specified for tensor scaling operations, in a similar way to what is described in: https://arxiv.org/pdf/2309.17224

Parameters:
  • fwd_mac (Number | tuple[Number] | tuple[Number, Number] | None) – compute configuration (add and multiply) for forward MAC operations

  • bwd_mac (Number | tuple[Number] | tuple[Number, Number] | None) – compute configuration (add and multiply) for backward MAC operations

  • fwd_rnd (str | None) – rounding mode for FWD computations

  • bwd_rnd (str | None) – rounding mode for BWD computations

  • weight_quant (Union[Callable[[Tensor], Tensor], tuple[Number, str]]) – quantization function or format and rounding on the weight signal inputs

  • bias_quant (Union[Callable[[Tensor], Tensor], tuple[Number, str]]) – quantization function or format and rounding on the bias signal inputs

  • input_quant (Union[Callable[[Tensor], Tensor], tuple[Number, str]]) – quantization function or format and rounding on the input signals to the layer

  • output_quant (Union[Callable[[Tensor], Tensor], tuple[Number, str]]) – quantization function or format and rounding on the output signal from the layer

  • grad_quant (Union[Callable[[Tensor], Tensor], Tuple[Number, str]]) – quantization function or format and rounding on the gradient signals in the BWD pass

  • use_scaling (bool) – whether to use weight, input and grad scaling during forward/backward pass

  • weight_scaled_format (FloatType | None) – number format to be used during weight tensor scaling (optional, matches weight_quant if format specified)

  • input_scaled_format (FloatType | None) – number format to be used during input tensor scaling (optional, matches input_quant if format specified)

  • grad_scaled_format (FloatType | None) – number format to be used during gradient tensor scaling (optional, matches grad_quant if format specified)

  • scale_margin (int) – margin to reduce scaling bias when casting down to FP8 formats

  • rbits (int | tuple[int] | tuple[int, int]) – number of bits used for random number generation when rounding is stochastic (for add and multiply)
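
A configuration sketch (illustrative choices; only parameters documented above are used):

from mptorch.number import FloatingPoint
from mptorch.quant import float_quantize
from mptorch.quant.quant_format import QAffineFormats

fp16 = FloatingPoint(exp=5, man=10)
q_fp16 = lambda t: float_quantize(t, exp=5, man=10, rounding="nearest")

# FP16-like multiply-accumulate in both passes, with FP16-like
# quantization of weights, I/O activations and gradients
formats = QAffineFormats(
    fwd_mac=fp16, bwd_mac=fp16,
    fwd_rnd="nearest", bwd_rnd="nearest",
    weight_quant=q_fp16, input_quant=q_fp16,
    output_quant=q_fp16, grad_quant=q_fp16,
)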

class mptorch.quant.quant_format.QSoftmaxFormats(fwd_off=None, fwd_exp=None, fwd_acc=None, fwd_lse=None, bwd_add=None, bwd_mul=None, fwd_rnd='nearest', bwd_rnd='nearest', input_quant=<function <lambda>>, output_quant=<function <lambda>>, grad_quant=<function <lambda>>)[source]#

Bases: object

Configuration class for number formats to use during compute (forward and/or backward pass) of softmax layers. One can optionally specify quantizer objects for the signals in the layer (I/O activations and weight/error gradients) to facilitate quantization-aware-training (QAT) and post-training quantization (PTQ) workloads.

Two implementations of the forward softmax are provided: regular and LogSumExp-based (LSE).

Regular softmax is used when fwd_off, fwd_exp and fwd_acc are set, and is implemented as follows:

\[\textrm{softmax}(x)_i = \frac{\exp(x_i - \max x)}{ \sum_j \exp(x_j - \max x)}\]

The LogSumExp implementation is used when fwd_off and fwd_exp are set, and is implemented as follows:

\[\textrm{softmax}(x)_i = \exp(x_i - \textrm{LSE}(x_1, ..., x_n))\]

where \(\textrm{LSE}(x_1, ..., x_n)\) is computed iteratively with the relation:

\[\textrm{LSE}(x_1, ..., x_{j+1}) = \ln(\exp \textrm{LSE}(x_1, ..., x_{j}) + \exp x_{j+1})\]

with the internal part of the log being computed at full precision.

Parameters:
  • fwd_off (Number | None) – compute configuration for forward subtraction

  • fwd_exp (Number | None) – compute configuration for forward exponential operations

  • fwd_acc (Number | None) – compute configuration for forward add operations

  • fwd_lse (Number | None) – compute configuration for forward LSE iteration

  • bwd_add (Number | None) – compute configuration for backward add operations

  • bwd_mul (Number | None) – compute configuration for backward multiply operations

  • fwd_rnd (str | None) – rounding mode for forward computations

  • bwd_rnd (str | None) – rounding mode for backward computations

  • input_quant (Callable[[Tensor], Tensor]) – quantization function on the input signal

  • output_quant (Callable[[Tensor], Tensor]) – quantization function on the output signal

  • grad_quant (Callable[[Tensor], Tensor]) – quantization function on the gradients
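
A configuration sketch for the regular (non-LSE) forward implementation (illustrative formats):

from mptorch.number import FloatingPoint
from mptorch.quant.quant_format import QSoftmaxFormats

bf16 = FloatingPoint(exp=8, man=7)

# setting fwd_off, fwd_exp and fwd_acc selects the regular softmax path
softmax_formats = QSoftmaxFormats(
    fwd_off=bf16, fwd_exp=bf16, fwd_acc=bf16,
    bwd_add=bf16, bwd_mul=bf16,
    fwd_rnd="nearest", bwd_rnd="nearest",
)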

class mptorch.quant.quant_format.QLayerNormFormats(fwd_acc=None, fwd_mul=None, fwd_div=None, fwd_sqrt=None, bwd_acc=None, bwd_mul=None, bwd_div=None, fwd_rnd='nearest', bwd_rnd='nearest', input_quant=<function <lambda>>, output_quant=<function <lambda>>, grad_quant=<function <lambda>>, weight_quant=<function <lambda>>, bias_quant=<function <lambda>>)[source]#

Bases: object

Configuration class for number formats to use during compute (forward and/or backward pass) of layer normalization. One can optionally specify quantizer objects for the signals in the layer (I/O activations, weights/bias terms and weight/error gradients) to facilitate quantization-aware-training (QAT) and post-training quantization (PTQ) workloads.

Parameters:
  • fwd_acc (Number | None) – compute configuration for forward add operations

  • fwd_mul (Number | None) – compute configuration for forward multiply operations

  • fwd_div (Number | None) – compute configuration for forward divide operations

  • fwd_sqrt (Number | None) – compute configuration for forward square root operations

  • bwd_acc (Number | None) – compute configuration for backward add operations

  • bwd_mul (Number | None) – compute configuration for backward multiply operations

  • bwd_div (Number | None) – compute configuration for backward divide operations

  • fwd_rnd (str | None) – rounding mode for forward computations

  • bwd_rnd (str | None) – rounding mode for backward computations

  • input_quant (Callable[[Tensor], Tensor]) – quantization function on the input signal

  • output_quant (Callable[[Tensor], Tensor]) – quantization function on the output signal

  • grad_quant (Callable[[Tensor], Tensor]) – quantization function on the gradients

  • weight_quant (Callable[[Tensor], Tensor]) – quantization function on the weights when applied to an input

  • bias_quant (Callable[[Tensor], Tensor]) – quantization function on the bias when applied to an input

class mptorch.quant.quant_format.QGELUFormats(input_quant=<function <lambda>>, inter_quant=<function <lambda>>, output_quant=<function <lambda>>, grad_quant=<function <lambda>>)[source]#

Bases: object

Configuration class for number formats to use during compute (forward and/or backward pass) of GELU activation.

Parameters:
  • input_quant (Callable[[Tensor], Tensor]) – quantization function on the input signal

  • inter_quant (Callable[[Tensor], Tensor]) – quantization function on the intermediate computation, which depends on whether the tanh approximation is used

  • output_quant (Callable[[Tensor], Tensor]) – quantization function on the output signal

  • grad_quant (Callable[[Tensor], Tensor]) – quantization function on the gradients

Modules#

class mptorch.quant.modules.Quantizer(forward_number=None, backward_number=None, forward_rounding='nearest', backward_rounding='nearest')[source]#

Bases: Module

A quantization module that supports quantizing the forward and backward passes differently.

Parameters:
  • forward_number (Number | None) – the number format used for forward quantization. If None, the quantization is an identity mapping.

  • backward_number (Number | None) – the number format used for backward quantization. If None, the quantization is an identity mapping.

  • forward_rounding (Literal['nearest', 'stochastic']) – rounding mode, "stochastic" or "nearest" (default: "nearest")

  • backward_rounding (Literal['nearest', 'stochastic']) – rounding mode, "stochastic" or "nearest" (default: "nearest")

forward(x)[source]#

The forward pass call on the input tensor. Applies the forward quantization on the input and registers the backward quantization format.

Parameters:

x (Tensor) – the input tensor

Return type:

Tensor

Returns:

The quantized version of the input, as specified by the FWD number format and associated rounding mode
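
A sketch of using the module as a quantization point between layers (illustrative formats):

import torch.nn as nn
from mptorch.number import FloatingPoint
from mptorch.quant.modules import Quantizer

fp8 = FloatingPoint(exp=4, man=3)
bf16 = FloatingPoint(exp=8, man=7)

# FP8-like activations in the forward pass, bfloat16-like gradients
# in the backward pass
model = nn.Sequential(
    nn.Linear(64, 64),
    Quantizer(forward_number=fp8, backward_number=bf16,
              forward_rounding="nearest", backward_rounding="stochastic"),
    nn.ReLU(),
    nn.Linear(64, 10),
)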

class mptorch.quant.modules.QLinear(in_features, out_features, formats, bias=True)[source]#

Bases: Linear

Applies a linear transformation to the incoming data: \(y=xW^T + b\)

It is a subclass of torch.nn.Linear and allows one to specify if I/O signals should be quantized during inference & training (needed for instance in QAT and PTQ methods), but also the precision(s) to be used in internal GEMM computations (addition and multiplication, fused or not). This allows simulating the effect of custom precision during GEMM calls in the forward and backward pass and is helpful in studying the effect of low precision compute during inference and training (not just data quantization).

Parameters:
  • in_features (int) – size of each input sample

  • out_features (int) – size of each output sample

  • formats (QAffineFormats) – number formats used during compute (addition and multiplication) and quantization functions for signals during forward and back propagation (I/O activations, weights, biases, and neural gradients)

  • bias (bool) – If set to False, the layer will not learn an additive bias. Default: True

Shape:
  • Input: \((*, H_\text{in})\) where \(*\) means any number of dimensions including none and \(H_\text{in} = \text{in_features}\).

  • Output: \((*, H_\text{out})\) where all but the last dimension are the same shape as the input and \(H_\text{out} = \text{out_features}\).

weight#

the learnable weights of the module of shape \((\text{out_features}, \text{in_features})\). The values are initialized from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\), where \(k = \frac{1}{\text{in_features}}\)

Type:

Tensor

bias#

the learnable bias of the module of shape \((\text{out_features})\). If bias is True, the values are initialized from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{1}{\text{in_features}}\)

Type:

Tensor

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(input)[source]#

Describes the computations that get performed at every call of the module. The use of quantized elementary operations (i.e., additions and multiplications) in the FWD and BWD passes is controlled through the formats argument to the module constructor.

Parameters:

input (Tensor) – the input tensor over which to perform the layer operations. Must adhere to the input shape requirements.

Return type:

Tensor

Returns:

the result of the \(xW^T + b\) operation.
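
A minimal usage sketch (illustrative formats):

import torch
from mptorch.number import FloatingPoint
from mptorch.quant.quant_format import QAffineFormats
from mptorch.quant.modules import QLinear

bf16 = FloatingPoint(exp=8, man=7)
formats = QAffineFormats(fwd_mac=bf16, bwd_mac=bf16,
                         fwd_rnd="nearest", bwd_rnd="nearest")

layer = QLinear(in_features=128, out_features=64, formats=formats)
y = layer(torch.randn(32, 128))  # y has shape (32, 64)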

quant_function(fwd_quant, bwd_quant)[source]#

Defines a straight-through estimator-like function (see Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation (https://arxiv.org/abs/1308.3432)) that applies potentially different quantization functions in the forward and backward passes through the input and output gradient signals, respectively.

Parameters:
  • fwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply during the forward pass through the input signal

  • bwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply during the backward pass through the output gradient signal

quant_parameters()[source]#

Quantizes the module parameters weight and bias using the quantization functions specified in formats.weight_quant and formats.bias_quant, respectively.

reset_quant_function()[source]#

Sets a straight-through estimator-like function to all the quantized signals in the module (these can be the weight, bias, input, and/or output signals), depending on whether quantizers are specified in the associated QAffineFormats formats parameter.

class mptorch.quant.modules.QLazyLinear(out_features, formats, bias=True, device=None, dtype=None)[source]#

Bases: LazyModuleMixin, QLinear

A linear module where in_features is inferred.

In this module (an analogue to torch.nn.LazyLinear), the in_features parameter of the quantized linear layer is inferred from the input’s last dimension (i.e., input.shape[-1]). The weight and bias layer parameters are instances of torch.nn.UninitializedParameter. They are initialized after the first call to forward, after which the module becomes a regular QLinear module.

Parameters:
  • out_features (int) – size of each output sample

  • formats (QAffineFormats) – number formats used during compute (addition and multiplication) and quantization functions for signals during forward and back propagation (I/O activations, weights, biases, and neural gradients)

  • bias (bool) – If set to False, the layer will not learn an additive bias. Default: True

Initialize internal Module state, shared by both nn.Module and ScriptModule.

cls_to_become#

alias of QLinear

initialize_parameters(input)[source]#

Initialize parameters according to the input batch properties.

This adds an interface to isolate parameter initialization from the forward pass when doing parameter shape inference.

Parameters:

input (Tensor) – the input tensor on which to perform the layer operations. Its shape is used to determine the value of the in_features member variable.

reset_parameters()[source]#

Resets parameter values in case the parameters have already been initialized.

weight: UninitializedParameter#

The learnable weights of the module of shape \((\text{out_features}, \text{in_features})\). The values are initialized from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\), where \(k = \frac{1}{\text{in_features}}\)

bias: UninitializedParameter#

The learnable bias of the module of shape \((\text{out_features})\). If bias is True, the values are initialized from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{1}{\text{in_features}}\)
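
A minimal usage sketch (illustrative formats):

import torch
from mptorch.number import FloatingPoint
from mptorch.quant.quant_format import QAffineFormats
from mptorch.quant.modules import QLazyLinear

bf16 = FloatingPoint(exp=8, man=7)
formats = QAffineFormats(fwd_mac=bf16, bwd_mac=bf16,
                         fwd_rnd="nearest", bwd_rnd="nearest")

layer = QLazyLinear(out_features=32, formats=formats)
y = layer(torch.randn(16, 100))  # in_features is inferred as 100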

class mptorch.quant.modules.QConv1d(in_channels, out_channels, kernel_size, formats, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')[source]#

Bases: Conv1d

Applies a 1D convolution over an input signal composed of several input planes.

It is a subclass of torch.nn.Conv1d and allows one to specify if I/O signals should be quantized during inference & training (needed for instance in QAT and PTQ methods), but also the precision(s) to be used in internal GEMM computations (addition and multiplication, fused or not). This allows simulating the effect of custom precision during GEMM calls in the forward and backward pass (which are performed using the im2col and col2im algorithms implemented in the unfoldNd library) and is helpful in studying the effect of low precision compute during inference and training (not just data quantization).

In the simplest case, the output value of the layer with input size \((N, C_{\text{in}}, L)\) and output \((N, C_{\text{out}}, L_{\text{out}})\) can be precisely described as:

\[\text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) + \sum_{k = 0}^{C_\text{in} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)\]

where \(\star\) is the valid cross-correlation operator, \(N\) is a batch size, \(C\) denotes a number of channels, \(L\) is a length of signal sequence.

  • stride controls the stride for the cross-correlation, a single number or a one-element tuple.

  • padding controls the amount of padding applied to the input. It can be either a string, ‘valid’ or ‘same’, or a tuple of ints giving the amount of implicit padding applied on both sides.

  • dilation controls the spacing between the kernel points; also known as the à trous algorithm. It is harder to describe, but this link has a nice visualization of what dilation does.

  • groups controls the connections between inputs and outputs.

  • in_channels and out_channels must both be divisible by groups. For example,

    • At groups=1, all inputs are convolved to all outputs.

    • At groups=2, the operation becomes equivalent to having two conv layers side by side, each seeing half the input channels and producing half the output channels, and both subsequently concatenated.

    • At groups= in_channels, each input channel is convolved with its own set of filters (of size \(\frac{\text{out_channels}}{\text{in_channels}}\)).

Note

When groups == in_channels and out_channels == K * in_channels, where K is a positive integer, this operation is also known as a “depthwise convolution”.

In other words, for an input of size \((N, C_\text{in}, L_\text{in})\), a depthwise convolution with a depthwise multiplier K can be performed with the arguments \((C_\text{in}=C_\text{in}, C_\text{out}=C_\text{in} \times \text{K}, ..., \text{groups}=C_\text{in})\).

Shape:
  • Input: \((N, C_\text{in}, L_\text{in})\) or \((C_\text{in}, L_\text{in})\)

  • Output: \((N, C_\text{out}, L_\text{out})\) or \((C_\text{out}, L_\text{out})\), where

    \[L_\text{out} = \left\lfloor\frac{L_\text{in} + 2 \times \text{padding} - \text{dilation} \times (\text{kernel_size} - 1) - 1}{\text{stride}} + 1\right\rfloor\]
weight#

the learnable weights of the module of shape \((\text{out_channels}, \frac{\text{in_channels}}{\text{groups}}, \text{kernel_size})\). The values of these weights are sampled from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{\text{groups}}{C_\text{in} * \text{kernel_size}}\)

Type:

Tensor

bias#

the learnable bias of the module of shape (out_channels). If bias is True, then the values of these weights are sampled from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{\text{groups}}{C_\text{in} * \text{kernel_size}}\)

Type:

Tensor

Parameters:
  • in_channels (int) – Number of channels in the input image

  • out_channels (int) – Number of channels produced by the convolution

  • kernel_size (int | tuple[int]) – Size of the convolving kernel

  • formats (QAffineFormats) – Number formats used during compute (addition and multiplication) and quantization functions for signals during forward and back propagation (I/O activations, weights, biases, and neural gradients)

  • stride (int | tuple[int]) – Stride of the convolution. Default: 1

  • padding (str | int | tuple[int]) – Padding added to both sides of the input. Default: 0

  • dilation (int | tuple[int]) – Spacing between kernel elements. Default: 1

  • groups (int) – Number of blocked connections from input channels to output channels. Default: 1

  • bias (bool) – If True, adds a learnable bias to the output. Default: True

  • padding_mode (str) – 'zeros', 'reflect', 'replicate' or 'circular'. Default: 'zeros'

forward(input)[source]#

Describes the computations that get performed at every call of the module. The use of quantized elementary operations (i.e., additions and multiplications) in the FWD and BWD passes is controlled through the formats argument to the module constructor.

Parameters:

input (Tensor) – the input tensor on which to perform the layer operations. Must adhere to the input shape requirements.

Return type:

Tensor

Returns:

the result of the 1D cross-correlation operation.

quant_function(fwd_quant, bwd_quant)[source]#

Defines a straight-through estimator-like function (see Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation (https://arxiv.org/abs/1308.3432)) that applies potentially different quantization functions in the forward and backward passes through the input and output gradient signals, respectively.

Parameters:
  • fwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply during the forward pass through the input signal

  • bwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply during the backward pass through the output gradient signal

quant_parameters()[source]#

Quantizes the module parameters weight and bias using the quantization functions specified in formats.weight_quant and formats.bias_quant, respectively.

reset_quant_function()[source]#

Sets a straight-through estimator-like function to all the quantized signals in the module (these can be the weight, bias, input, and/or output signals), depending on whether quantizers are specified in the associated QAffineFormats formats parameter.

class mptorch.quant.modules.QConv2d(in_channels, out_channels, kernel_size, formats, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')[source]#

Bases: Conv2d

Applies a 2D convolution over an input signal composed of several input planes.

It is a subclass of torch.nn.Conv2d and allows one to specify if I/O signals should be quantized during inference & training (needed for instance in QAT and PTQ methods), but also the precision(s) to be used in internal GEMM computations (addition and multiplication, fused or not). This allows simulating the effect of custom precision during GEMM calls in the forward and backward pass (which are performed using the im2col and col2im algorithms implemented in the unfoldNd library) and is helpful in studying the effect of low precision compute during inference and training (not just data quantization).

In the simplest case, the output value of the layer with input size \((N, C_{\text{in}}, H, W)\) and output \((N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})\) can be precisely described as:

\[\text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) + \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)\]

where \(\star\) is the valid 2D cross-correlation operator, \(N\) is a batch size, \(C\) denotes a number of channels, \(H\) is a height of input planes in pixels, and \(W\) is width in pixels.

  • stride controls the stride for the cross-correlation, a single number or a tuple.

  • padding controls the amount of padding applied to the input. It can be either a string {‘valid’, ‘same’} or an int / a tuple of ints giving the amount of implicit padding applied on both sides.

  • dilation controls the spacing between the kernel points; also known as the à trous algorithm. It is harder to describe, but this link has a nice visualization of what dilation does.

  • groups controls the connections between inputs and outputs.

  • in_channels and out_channels must both be divisible by groups. For example,

    • At groups=1, all inputs are convolved to all outputs.

    • At groups=2, the operation becomes equivalent to having two conv layers side by side, each seeing half the input channels and producing half the output channels, and both subsequently concatenated.

    • At groups= in_channels, each input channel is convolved with its own set of filters (of size \(\frac{\text{out_channels}}{\text{in_channels}}\)).

The parameters kernel_size, stride, padding, dilation can either be:

  • a single int – in which case the same value is used for the height and width dimension

  • a tuple of two ints – in which case, the first int is used for the height dimension, and the second int for the width dimension

Note

When groups == in_channels and out_channels == K * in_channels, where K is a positive integer, this operation is also known as a “depthwise convolution”.

In other words, for an input of size \((N, C_\text{in}, L_\text{in})\), a depthwise convolution with a depthwise multiplier K can be performed with the arguments \((C_\text{in}=C_\text{in}, C_\text{out}=C_\text{in} \times \text{K}, ..., \text{groups}=C_\text{in})\).

Note

padding='valid' is the same as no padding. padding='same' pads the input so the output has the same shape as the input. However, this mode doesn’t support any stride values other than 1.

Shape:
  • Input: \((N, C_\text{in}, H_\text{in}, W_\text{in})\) or \((C_\text{in}, H_\text{in}, W_\text{in})\)

  • Output: \((N, C_\text{out}, H_\text{out}, W_\text{out})\) or \((C_\text{out}, H_\text{out}, W_\text{out})\), where

    \[H_\text{out} = \left\lfloor\frac{H_\text{in} + 2 \times \text{padding}[0] - \text{dilation}[0] \times (\text{kernel_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor\]
    \[W_\text{out} = \left\lfloor\frac{W_\text{in} + 2 \times \text{padding}[1] - \text{dilation}[1] \times (\text{kernel_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor\]
weight#

the learnable weights of the module of shape \((\text{out_channels}, \frac{\text{in_channels}}{\text{groups}},\) \(\text{kernel_size[0]}, \text{kernel_size[1]})\). The values of these weights are sampled from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{\text{groups}}{C_\text{in} * \prod_{i=0}^{1}\text{kernel_size}[i]}\)

Type:

Tensor

bias#

the learnable bias of the module of shape (out_channels). If bias is True, then the values of these weights are sampled from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{\text{groups}}{C_\text{in} * \prod_{i=0}^{1}\text{kernel_size}[i]}\)

Type:

Tensor

Parameters:
  • in_channels (int) – Number of channels in the input image

  • out_channels (int) – Number of channels produced by the convolution

  • kernel_size (int | tuple[int, int]) – Size of the convolving kernel

  • formats (QAffineFormats) – Number formats used during compute (addition and multiplication) and quantization functions for signals during forward and back propagation (I/O activations, weights, biases, and neural gradients)

  • stride (int | tuple[int, int]) – Stride of the convolution. Default: 1

  • padding (str | int | tuple[int, int]) – Padding added to all four sides of the input. Default: 0

  • dilation (int | tuple[int, int]) – Spacing between kernel elements. Default: 1

  • groups (int) – Number of blocked connections from input channels to output channels. Default: 1

  • bias (bool) – If True, adds a learnable bias to the output. Default: True

  • padding_mode (str) – 'zeros', 'reflect', 'replicate' or 'circular'. Default: 'zeros'

forward(input)[source]#

Describes the computations that get performed at every call of the module. The use of quantized elementary operations (i.e., additions and multiplications) in the FWD and BWD passes is controlled through the formats argument to the module constructor.

Parameters:

input (Tensor) – the input tensor on which to perform the layer operations. Must adhere to the input shape requirements.

Return type:

Tensor

Returns:

the result of the 2D cross-correlation operation.
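
A minimal usage sketch (illustrative formats):

import torch
from mptorch.number import FloatingPoint
from mptorch.quant.quant_format import QAffineFormats
from mptorch.quant.modules import QConv2d

fp16 = FloatingPoint(exp=5, man=10)
formats = QAffineFormats(fwd_mac=fp16, bwd_mac=fp16,
                         fwd_rnd="nearest", bwd_rnd="nearest")

conv = QConv2d(in_channels=3, out_channels=16, kernel_size=3,
               formats=formats, padding=1)
y = conv(torch.randn(8, 3, 32, 32))  # y has shape (8, 16, 32, 32)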

quant_function(fwd_quant, bwd_quant)[source]#

Defines a straight-through estimator-like function (see Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation (https://arxiv.org/abs/1308.3432)) that applies potentially different quantization functions in the forward and backward passes through the input and output gradient signals, respectively.

Parameters:
  • fwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply during the forward pass through the input signal

  • bwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply during the backward pass through the output gradient signal

quant_parameters()[source]#

Quantizes the module parameters weight and bias using the quantization functions specified in formats.weight_quant and formats.bias_quant, respectively.

reset_quant_function()[source]#

Sets a straight-through estimator-like function to all the quantized signals in the module (these can be the weight, bias, input, and/or output signals), depending on whether quantizers are specified in the associated QAffineFormats formats parameter.

class mptorch.quant.modules.QConv3d(in_channels, out_channels, kernel_size, formats, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')[source]#

Bases: Conv3d

Applies a 3D convolution over an input signal composed of several input planes.

It is a subclass of torch.nn.Conv3d and allows one to specify if I/O signals should be quantized during inference & training (needed for instance in QAT and PTQ methods), but also the precision(s) to be used in internal GEMM computations (addition and multiplication, fused or not). This allows simulating the effect of custom precision during GEMM calls in the forward and backward pass (which are performed using the im2col and col2im algorithms implemented in the unfoldNd library) and is helpful in studying the effect of low precision compute during inference and training (not just data quantization).

In the simplest case, the output value of the layer with input size \((N, C_{\text{in}}, D, H, W)\) and output \((N, C_{\text{out}}, D_{\text{out}}, H_{\text{out}}, W_{\text{out}})\) can be precisely described as:

\[\text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) + \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)\]

where \(\star\) is the valid 3D cross-correlation operator, \(N\) is a batch size, \(C\) denotes a number of channels, \(D\) is a depth of input planes in pixels, \(H\) is a height of input planes in pixels, and \(W\) is width in pixels.

  • stride controls the stride for the cross-correlation, a single number or a tuple.

  • padding controls the amount of padding applied to the input. It can be either a string {‘valid’, ‘same’} or an int / a tuple of ints giving the amount of implicit padding applied on both sides.

  • dilation controls the spacing between the kernel points; also known as the à trous algorithm. It is harder to describe, but this link has a nice visualization of what dilation does.

  • groups controls the connections between inputs and outputs.

  • in_channels and out_channels must both be divisible by groups. For example,

    • At groups=1, all inputs are convolved to all outputs.

    • At groups=2, the operation becomes equivalent to having two conv layers side by side, each seeing half the input channels and producing half the output channels, and both subsequently concatenated.

    • At groups= in_channels, each input channel is convolved with its own set of filters (of size \(\frac{\text{out_channels}}{\text{in_channels}}\)).

The parameters kernel_size, stride, padding, dilation can either be:

  • a single int – in which case the same value is used for the depth, height and width dimensions

  • a tuple of three ints – in which case, the first int is used for the depth dimension, the second int for the height dimension and the third int for the width dimension

Note

When groups == in_channels and out_channels == K * in_channels, where K is a positive integer, this operation is also known as a “depthwise convolution”.

In other words, for an input of size \((N, C_\text{in}, L_\text{in})\), a depthwise convolution with a depthwise multiplier K can be performed with the arguments \((C_\text{in}=C_\text{in}, C_\text{out}=C_\text{in} \times \text{K}, ..., \text{groups}=C_\text{in})\).

Note

padding='valid' is the same as no padding. padding='same' pads the input so the output has the same shape as the input. However, this mode doesn’t support any stride values other than 1.

Shape:
  • Input: \((N, C_\text{in}, D_\text{in}, H_\text{in}, W_\text{in})\) or \((C_\text{in}, D_\text{in}, H_\text{in}, W_\text{in})\)

  • Output: \((N, C_\text{out}, D_\text{out}, H_\text{out}, W_\text{out})\) or \((C_\text{out}, D_\text{out}, H_\text{out}, W_\text{out})\), where

    \[D_\text{out} = \left\lfloor\frac{D_\text{in} + 2 \times \text{padding}[0] - \text{dilation}[0] \times (\text{kernel_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor\]
    \[H_\text{out} = \left\lfloor\frac{H_\text{in} + 2 \times \text{padding}[1] - \text{dilation}[1] \times (\text{kernel_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor\]
    \[W_\text{out} = \left\lfloor\frac{W_\text{in} + 2 \times \text{padding}[2] - \text{dilation}[2] \times (\text{kernel_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor\]
weight#

the learnable weights of the module of shape \((\text{out_channels}, \frac{\text{in_channels}}{\text{groups}},\) \(\text{kernel_size[0]}, \text{kernel_size[1]}, \text{kernel_size[2]})\). The values of these weights are sampled from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{\text{groups}}{C_\text{in} * \prod_{i=0}^{2}\text{kernel_size}[i]}\)

Type:

Tensor

bias#

the learnable bias of the module of shape (out_channels). If bias is True, then the values of these weights are sampled from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{\text{groups}}{C_\text{in} * \prod_{i=0}^{2}\text{kernel_size}[i]}\)

Type:

Tensor

Parameters:
  • in_channels (int) – Number of channels in the input image

  • out_channels (int) – Number of channels produced by the convolution

  • kernel_size (int | tuple[int, int, int]) – Size of the convolving kernel

  • formats (QAffineFormats) – Number formats used during compute (addition and multiplication) and quantization functions for signals during forward and back propagation (I/O activations, weights, biases, and neural gradients)

  • stride (int | tuple[int, int, int]) – Stride of the convolution. Default: 1

  • padding (str | int | tuple[int, int, int]) – Padding added to all six sides of the input. Default: 0

  • dilation (int | tuple[int, int, int]) – Spacing between kernel elements. Default: 1

  • groups (int) – Number of blocked connections from input channels to output channels. Default: 1

  • bias (bool) – If True, adds a learnable bias to the output. Default: True

  • padding_mode (str) – 'zeros', 'reflect', 'replicate' or 'circular'. Default: 'zeros'

forward(input)[source]#

Describes the computations that get performed at every call of the module. The use of quantized elementary operations (i.e., additions and multiplications) in the FWD and BWD passes is controlled through the formats argument to the module constructor.

Parameters:

input (Tensor) – the input tensor on which to perform the layer operations. Must adhere to the input shape requirements.

Return type:

Tensor

Returns:

the result of the 3D cross-correlation operation.

quant_function(fwd_quant, bwd_quant)[source]#

Defines a straight-through estimator-like function (see Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation (https://arxiv.org/abs/1308.3432)) that applies potentially different quantization functions in the forward and backward passes through the input and output gradient signals, respectively.

Parameters:
  • fwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply during the forward pass through the input signal

  • bwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply during the backward pass through the output gradient signal

quant_parameters()[source]#

Quantizes the module parameters weight and bias using the quantization functions specified in formats.weight_quant and formats.bias_quant, respectively.

reset_quant_function()[source]#

Sets a straight-through estimator-like function to all the quantized signals in the module (these can be the weight, bias, input, and/or output signals), depending on whether quantizers are specified in the associated QAffineFormats formats parameter.

class mptorch.quant.modules.QConvTranspose1d(in_channels, out_channels, kernel_size, formats, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros', device=None, dtype=None)[source]#

Bases: ConvTranspose1d

Applies a 1D transposed convolution operator over an input image composed of several input planes.

This module can be seen as the gradient of Conv1d with respect to its input. It is also known as a fractionally-strided convolution or a deconvolution (although it is not an actual deconvolution operation as it does not compute a true inverse of convolution). For more information, see the visualizations here and the Deconvolutional Networks paper.

It is a subclass of torch.nn.ConvTranspose1d and allows one to specify if I/O signals should be quantized during inference & training (needed for instance in QAT and PTQ methods), but also the precision(s) to be used in internal GEMM computations (addition and multiplication, fused or not). This allows simulating the effect of custom precision during GEMM calls in the forward and backward pass (which are performed using the im2col and col2im algorithms implemented in the unfoldNd library) and is helpful in studying the effect of low precision compute during inference and training (not just data quantization).

  • stride controls the stride for the cross-correlation.

  • padding controls the amount of implicit zero padding on both sides for dilation * (kernel_size - 1) - padding number of points. See note below for details.

  • output_padding controls the additional size added to one side of the output shape. See note below for details.

  • dilation controls the spacing between the kernel points; also known as the à trous algorithm. It is harder to describe, but the link here has a nice visualization of what dilation does.

  • in_channels and out_channels must both be divisible by groups. For example,

    • At groups=1, all inputs are convolved to all outputs.

    • At groups=2, the operation becomes equivalent to having two conv layers side by side, each seeing half the input channels and producing half the output channels, and both subsequently concatenated.

    • At groups= in_channels, each input channel is convolved with its own set of filters (of size \(\frac{\text{out_channels}}{\text{in_channels}}\)).

Note

The padding argument effectively adds dilation * (kernel_size - 1) - padding amount of zero padding to both sides of the input. This is set so that when a torch.nn.Conv1d and a torch.nn.ConvTranspose1d are initialized with the same parameters, they are inverses of each other in regard to the input and output shapes. However, when stride > 1, torch.nn.Conv1d maps multiple input shapes to the same output shape. output_padding is provided to resolve this ambiguity by effectively increasing the calculated output shape on one side. Note that output_padding is only used to find the output shape, but does not actually add zero-padding to the output.

Shape:
  • Input: \((N, C_\text{in}, L_\text{in})\) or \((C_\text{in}, L_\text{in})\)

  • Output: \((N, C_\text{out}, L_\text{out})\) or \((C_\text{out}, L_\text{out})\), where

    \[L_\text{out} = (L_\text{in} - 1) \times \text{stride} - 2 \times \text{padding} + \text{dilation} \times (\text{kernel_size} - 1) + \text{output_padding} + 1\]
weight#

the learnable weights of the module of shape \((\text{in_channels}, \frac{\text{out_channels}}{\text{groups}},\) \(\text{kernel_size})\). The values of these weights are sampled from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{\text{groups}}{C_\text{out} * \text{kernel_size}}\)

Type:

Tensor

bias#

the learnable bias of the module of shape (out_channels). If bias is True, then the values of these weights are sampled from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{\text{groups}}{C_\text{out} * \text{kernel_size}}\)

Type:

Tensor

Parameters:
  • in_channels (int) – Number of channels in the input image

  • out_channels (int) – Number of channels produced by the convolution

  • kernel_size (int | tuple[int]) – Size of the convolving kernel

  • formats (QAffineFormats) – Number formats used during compute (addition and multiplication) and quantization functions for signals during forward and back propagation (I/O activations, weights, biases, and neural gradients)

  • stride (int | tuple[int]) – Stride of the convolution. Default: 1

  • padding (int | tuple[int]) – dilation * (kernel_size - 1) - padding zero-padding will be added to both sides of the input. Default: 0

  • output_padding (int | tuple[int]) – Additional size added to one side of the output shape. Default: 0

  • groups (int) – Number of blocked connections from input channels to output channels. Default: 1

  • bias (bool) – If True, adds a learnable bias to the output. Default: True

  • dilation (int | tuple[int]) – Spacing between kernel elements. Default: 1

forward(input, output_size=None)[source]#

Describes the computations that get performed at every call of the module. The use of quantized elementary operations (i.e., additions and multiplications) in the FWD and BWD passes is controlled through the formats argument to the module constructor.

Parameters:
  • input (Tensor) – the input tensor on which to perform the layer operations. Must adhere to the input shape requirements.

  • output_size (list[int] | None) – specifies how the output should be padded

Return type:

Tensor

Returns:

the result of the 1D transposed convolution operation.

quant_function(fwd_quant, bwd_quant)[source]#

Defines a straight-through estimator-like function (see Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation (https://arxiv.org/abs/1308.3432)) that applies potentially different quantization functions in the forward and backward passes through the input and output gradient signals, respectively.

Parameters:
  • fwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply during the forward pass through the input signal

  • bwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply during the backward pass through the output gradient signal

quant_parameters()[source]#

Quantizes the module parameters weight and bias using the quantization functions specified in formats.weight_quant and formats.bias_quant, respectively.

reset_quant_function()[source]#

Sets a straight-through estimator-like function to all the quantized signals in the module (these can be the weight, bias, input, and/or output signals), depending on whether quantizers are specified in the associated QAffineFormats formats parameter.

class mptorch.quant.modules.QConvTranspose2d(in_channels, out_channels, kernel_size, formats, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros', device=None, dtype=None)[source]#

Bases: ConvTranspose2d

Applies a 2D transposed convolution operator over an input image composed of several input planes.

This module can be seen as the gradient of Conv2d with respect to its input. It is also known as a fractionally-strided convolution or a deconvolution (although it is not an actual deconvolution operation as it does not compute a true inverse of convolution). For more information, see the visualizations here and the Deconvolutional Networks paper.

It is a subclass of torch.nn.ConvTranspose2d and allows one to specify if I/O signals should be quantized during inference & training (needed for instance in QAT and PTQ methods), but also the precision(s) to be used in internal GEMM computations (addition and multiplication, fused or not). This allows simulating the effect of custom precision during GEMM calls in the forward and backward pass (which are performed using the im2col and col2im algorithms implemented in the unfoldNd library) and is helpful in studying the effect of low precision compute during inference and training (not just data quantization).

  • stride controls the stride for the cross-correlation.

  • padding controls the amount of implicit zero padding on both sides for dilation * (kernel_size - 1) - padding number of points. See note below for details.

  • output_padding controls the additional size added to one side of the output shape. See note below for details.

  • dilation controls the spacing between the kernel points; also known as the à trous algorithm. It is harder to describe, but the link here has a nice visualization of what dilation does.

  • in_channels and out_channels must both be divisible by groups. For example,

    • At groups=1, all inputs are convolved to all outputs.

    • At groups=2, the operation becomes equivalent to having two conv layers side by side, each seeing half the input channels and producing half the output channels, and both subsequently concatenated.

    • At groups= in_channels, each input channel is convolved with its own set of filters (of size \(\frac{\text{out_channels}}{\text{in_channels}}\)).

The parameters kernel_size, stride, padding, output_padding can either be:

  • a single int – in which case the same value is used for the height and width dimensions

  • a tuple of two ints – in which case, the first int is used for the height dimension, and the second int for the width dimension

Note

The padding argument effectively adds dilation * (kernel_size - 1) - padding amount of zero padding to both sides of the input. This is set so that when a torch.nn.Conv2d and a torch.nn.ConvTranspose2d are initialized with the same parameters, they are inverses of each other in regard to the input and output shapes. However, when stride > 1, torch.nn.Conv2d maps multiple input shapes to the same output shape. output_padding is provided to resolve this ambiguity by effectively increasing the calculated output shape on one side. Note that output_padding is only used to find the output shape, but does not actually add zero-padding to the output.

Shape:
  • Input: \((N, C_\text{in}, H_\text{in}, W_\text{in})\) or \((C_\text{in}, H_\text{in}, W_\text{in})\)

  • Output: \((N, C_\text{out}, H_\text{out}, W_\text{out})\) or \((C_\text{out}, H_\text{out}, W_\text{out})\), where

\[H_\text{out} = (H_\text{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0] \times (\text{kernel_size}[0] - 1) + \text{output_padding}[0] + 1\]
\[W_\text{out} = (W_\text{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{dilation}[1] \times (\text{kernel_size}[1] - 1) + \text{output_padding}[1] + 1\]
weight#

the learnable weights of the module of shape \((\text{in_channels}, \frac{\text{out_channels}}{\text{groups}},\) \(\text{kernel_size[0]}, \text{kernel_size[1]})\). The values of these weights are sampled from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{\text{groups}}{C_\text{out} * \prod_{i=0}^{1}\text{kernel_size}[i]}\)

Type:

Tensor

bias#

the learnable bias of the module of shape (out_channels). If bias is True, then the values of these weights are sampled from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{\text{groups}}{C_\text{out} * \prod_{i=0}^{1}\text{kernel_size}[i]}\)

Type:

Tensor

Parameters:
  • in_channels (int) – Number of channels in the input image

  • out_channels (int) – Number of channels produced by the convolution

  • kernel_size (int | tuple[int, int]) – Size of the convolving kernel

  • formats (QAffineFormats) – Number formats used during compute (addition and multiplication) and quantization functions for signals during forward and back propagation (I/O activations, weights, biases, and neural gradients)

  • stride (int | tuple[int, int]) – Stride of the convolution. Default: 1

  • padding (int | tuple[int, int]) – dilation * (kernel_size - 1) - padding zero-padding will be added to both sides of the input. Default: 0

  • output_padding (int | tuple[int, int]) – Additional size added to one side of the output shape. Default: 0

  • groups (int) – Number of blocked connections from input channels to output channels. Default: 1

  • bias (bool) – If True, adds a learnable bias to the output. Default: True

  • dilation (int | tuple[int, int]) – Spacing between kernel elements. Default: 1
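
Example

A minimal usage sketch. The QAffineFormats keywords mirror those shown in the cuBLAS acceleration example at the end of this reference; per-signal quantizers (weights, bias, I/O activations) can additionally be configured but are omitted here, and the exact keyword set may differ:

import torch
import mptorch.quant as qt
from mptorch.number import FloatingPoint
from mptorch.quant.modules import QConvTranspose2d

# binary16-like compute format for the internal GEMM operations
mac_format = FloatingPoint(exp=5, man=10, subnormals=True, saturate=False)
formats = qt.QAffineFormats(
    fwd_mac=(mac_format,),
    bwd_mac=(mac_format,),
    fwd_rnd="nearest",
    bwd_rnd="nearest",
)
layer = QConvTranspose2d(16, 33, kernel_size=3, formats=formats, stride=2)
x = torch.randn(20, 16, 50, 100)
y = layer(x)  # shape: (20, 33, 101, 201)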

forward(input, output_size=None)[source]#

Describes the computations that get performed at every call of the module. The use of quantized elementary operations (i.e., additions and multiplications) in the FWD and BWD passes is controlled through the formats argument to the module constructor.

Parameters:
  • input (Tensor) – the input tensor on which to perform the layer operations. Must adhere to the input shape requirements.

  • output_size (list[int] | None) – specifies how the output should be padded

Return type:

Tensor

Returns:

the result of the 2D transposed convolution operation.

quant_function(fwd_quant, bwd_quant)[source]#

Defines a straight-through estimator-like function (see Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation (https://arxiv.org/abs/1308.3432)) that applies potentially different quantization functions in the forward and backward passes through the input and output gradient signals, respectively.

Parameters:
  • fwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply during the forward pass through the input signal

  • bwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply during the backward pass through the output gradient signal

quant_parameters()[source]#

Quantizes the module parameters weight and bias using the quantization functions specified in formats.weight_quant and formats.bias_quant, respectively.

reset_quant_function()[source]#

Sets a straight-through estimator-like function to all the quantized signals in the module (these can be the weight, bias, input, and/or output signals), depending on whether quantizers are specified in the associated QAffineFormats formats parameter.

class mptorch.quant.modules.QConvTranspose3d(in_channels, out_channels, kernel_size, formats, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros', device=None, dtype=None)[source]#

Bases: ConvTranspose3d

Applies a 3D transposed convolution operator over an input image composed of several input planes. The transposed convolution operator multiplies each input value element-wise by a learnable kernel, and sums over the outputs from all input feature planes.

This module can be seen as the gradient of Conv3d with respect to its input. It is also known as a fractionally-strided convolution or a deconvolution (although it is not an actual deconvolution operation as it does not compute a true inverse of convolution). For more information, see the visualizations here and the Deconvolutional Networks paper.

It is a subclass of torch.nn.ConvTranspose3d and allows one to specify whether I/O signals should be quantized during inference & training (needed, for instance, in QAT and PTQ methods), as well as the precision(s) to be used in internal GEMM computations (addition and multiplication, fused or not). This allows simulating the effect of custom precision during GEMM calls in the forward and backward passes (which are performed using the im2col and col2im algorithms implemented in the unfoldNd library) and is helpful in studying the effect of low-precision compute during inference and training (not just data quantization).

  • stride controls the stride for the cross-correlation.

  • padding controls the amount of implicit zero padding on both sides for dilation * (kernel_size - 1) - padding number of points. See note below for details.

  • output_padding controls the additional size added to one side of the output shape. See note below for details.

  • dilation controls the spacing between the kernel points; also known as the à trous algorithm. It is harder to describe, but the link here has a nice visualization of what dilation does.

  • in_channels and out_channels must both be divisible by groups. For example,

    • At groups=1, all inputs are convolved to all outputs.

    • At groups=2, the operation becomes equivalent to having two conv layers side by side, each seeing half the input channels and producing half the output channels, and both subsequently concatenated.

    • At groups= in_channels, each input channel is convolved with its own set of filters (of size \(\frac{\text{out_channels}}{\text{in_channels}}\)).

The parameters kernel_size, stride, padding, output_padding can either be:

  • a single int – in which case the same value is used for the depth, height and width dimensions

  • a tuple of three ints – in which case, the first int is used for the depth dimension, the second int for the height dimension and the third int for the width dimension

Note

The padding argument effectively adds dilation * (kernel_size - 1) - padding amount of zero padding to both sides of the input. This is set so that when a torch.nn.Conv3d and a torch.nn.ConvTranspose3d are initialized with the same parameters, they are inverses of each other in regard to the input and output shapes. However, when stride > 1, torch.nn.Conv3d maps multiple input shapes to the same output shape. output_padding is provided to resolve this ambiguity by effectively increasing the calculated output shape on one side. Note that output_padding is only used to find the output shape, but does not actually add zero-padding to the output.

Shape:
  • Input: \((N, C_\text{in}, D_\text{in}, H_\text{in}, W_\text{in})\) or \((C_\text{in}, D_\text{in}, H_\text{in}, W_\text{in})\)

  • Output: \((N, C_\text{out}, D_\text{out}, H_\text{out}, W_\text{out})\) or \((C_\text{out}, D_\text{out}, H_\text{out}, W_\text{out})\), where

\[D_\text{out} = (D_\text{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0] \times (\text{kernel_size}[0] - 1) + \text{output_padding}[0] + 1\]
\[H_\text{out} = (H_\text{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{dilation}[1] \times (\text{kernel_size}[1] - 1) + \text{output_padding}[1] + 1\]
\[W_\text{out} = (W_\text{in} - 1) \times \text{stride}[2] - 2 \times \text{padding}[2] + \text{dilation}[2] \times (\text{kernel_size}[2] - 1) + \text{output_padding}[2] + 1\]
weight#

the learnable weights of the module of shape \((\text{in_channels}, \frac{\text{out_channels}}{\text{groups}},\) \(\text{kernel_size[0]}, \text{kernel_size[1]}, \text{kernel_size[2]})\). The values of these weights are sampled from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{\text{groups}}{C_\text{out} * \prod_{i=0}^{2}\text{kernel_size}[i]}\)

Type:

Tensor

bias#

the learnable bias of the module of shape (out_channels). If bias is True, then the values of these weights are sampled from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{\text{groups}}{C_\text{out} * \prod_{i=0}^{2}\text{kernel_size}[i]}\)

Type:

Tensor

Parameters:
  • in_channels (int) – Number of channels in the input image

  • out_channels (int) – Number of channels produced by the convolution

  • kernel_size (int | tuple[int, int, int]) – Size of the convolving kernel

  • formats (QAffineFormats) – Number formats used during compute (addition and multiplication) and quantization functions for signals during forward and back propagation (I/O activations, weights, biases, and neural gradients)

  • stride (int | tuple[int, int, int]) – Stride of the convolution. Default: 1

  • padding (int | tuple[int, int, int]) – dilation * (kernel_size - 1) - padding zero-padding will be added to both sides of the input. Default: 0

  • output_padding (int | tuple[int, int, int]) – Additional size added to one side of the output shape. Default: 0

  • groups (int) – Number of blocked connections from input channels to output channels. Default: 1

  • bias (bool) – If True, adds a learnable bias to the output. Default: True

  • dilation (int | tuple[int, int, int]) – Spacing between kernel elements. Default: 1
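
Example

The configuration pattern is the same as for QConvTranspose2d; a minimal sketch (format choices are illustrative, and the QAffineFormats keyword set may differ):

import torch
import mptorch.quant as qt
from mptorch.number import FloatingPoint
from mptorch.quant.modules import QConvTranspose3d

mac_format = FloatingPoint(exp=5, man=10, subnormals=True, saturate=False)
formats = qt.QAffineFormats(
    fwd_mac=(mac_format,), bwd_mac=(mac_format,),
    fwd_rnd="nearest", bwd_rnd="nearest",
)
layer = QConvTranspose3d(16, 33, kernel_size=3, formats=formats, stride=2)
x = torch.randn(20, 16, 10, 50, 100)
y = layer(x)  # shape: (20, 33, 21, 101, 201)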

forward(input, output_size=None)[source]#

Describes the computations that get performed at every call of the module. The use of quantized elementary operations (i.e., additions and multiplications) in the FWD and BWD passes is controlled through the formats argument to the module constructor.

Parameters:
  • input (Tensor) – the input tensor on which to perform the layer operations. Must adhere to the input shape requirements.

  • output_size (list[int] | None) – specifies how the output should be padded

Return type:

Tensor

Returns:

the result of the 3D transposed convolution operation.

quant_function(fwd_quant, bwd_quant)[source]#

Defines a straight-through estimator-like function (see Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation (https://arxiv.org/abs/1308.3432)) that applies potentially different quantization functions in the forward and backward passes through the input and output gradient signals, respectively.

Parameters:
  • fwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply during the forward pass through the input signal

  • bwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply during the backward pass through the output gradient signal

quant_parameters()[source]#

Quantizes the module parameters weight and bias using the quantization functions specified in formats.weight_quant and formats.bias_quant, respectively.

reset_quant_function()[source]#

Sets a straight-through estimator-like function to all the quantized signals in the module (these can be the weight, bias, input, and/or output signals), depending on whether quantizers are specified in the associated QAffineFormats formats parameter.

class mptorch.quant.modules.QAvgPool2d(kernel_size, fwd_quant, bwd_quant, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None)[source]#

Bases: AvgPool2d

Applies a 2D average pooling over an input signal composed of several input planes. Performs the addition operations in the FWD and BWD passes using quantized operators given as parameters.

In the simplest case, the output value of the layer with input size \((N, C, H, W)\), output \((N, C, H_\text{out}, W_\text{out})\) and kernel_size \((kH, kW)\) can be precisely described as:

\[\text{out}(N_i, C_j, h, w) = \frac{1}{kH * kW} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1} \text{input}(N_i, C_j, \text{stride}[0] \times h + m, \text{stride}[1] \times w + n)\]

If padding is non-zero, then the input is implicitly zero-padded on both sides for padding number of points.

Note

When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding or the input. Sliding windows that would start in the right padded region are ignored.

The parameters kernel_size, stride, padding can either be:

  • a single int – in which case the same value is used for the height and width dimension

  • a tuple of two ints – in which case, the first int is used for the height dimension, and the second int for the width dimension

Shape:
  • Input: \((N, C, H_\text{in}, W_\text{in})\) or \((C, H_\text{in}, W_\text{in})\).

  • Output: \((N, C, H_\text{out}, W_\text{out})\) or \((C, H_\text{out}, W_\text{out})\), where

    \[H_\text{out} = \left\lfloor\frac{H_\text{in} + 2 \times \text{padding}[0] - \text{kernel_size}[0]}{\text{stride}[0]} + 1\right\rfloor\]
    \[W_\text{out} = \left\lfloor\frac{W_\text{in} + 2 \times \text{padding}[1] - \text{kernel_size}[1]}{\text{stride}[1]} + 1\right\rfloor\]

    Per the note above, if ceil_mode is True and \((H_\text{out} - 1)\times \text{stride}[0]\geq H_\text{in} + \text{padding}[0]\), we skip the last window as it would start in the bottom padded region, resulting in \(H_\text{out}\) being reduced by one.

    The same applies for \(W_\text{out}\).

Parameters:
  • kernel_size (int | tuple[int, int]) – the size of the window

  • fwd_quant (Callable[[Tensor], Tensor]) – quantization function to use during FWD addition operations

  • bwd_quant (Callable[[Tensor], Tensor]) – quantization function to use during BWD addition operations

  • stride (int | tuple[int, int] | None) – the stride of the window. Default value is kernel_size

  • padding (int | tuple[int, int]) – implicit zero padding to be added on both sides

  • ceil_mode (bool) – when True, will use ceil instead of floor to compute the output shape

  • count_include_pad (bool) – when True, will include the zero-padding in the averaging calculation

  • divisor_override (int | None) – if specified, it will be used as the divisor, otherwise the size of the pooling region will be used.
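
Example

A minimal sketch in which the FWD and BWD additions are rounded to a bfloat16-like format (the quantizer choice is illustrative):

import torch
import mptorch.quant as qt
from mptorch.quant.modules import QAvgPool2d

quant = lambda x: qt.float_quantize(x, man=7, exp=8, rounding="nearest")
pool = QAvgPool2d(kernel_size=3, fwd_quant=quant, bwd_quant=quant, stride=2)
x = torch.randn(20, 16, 50, 32)
y = pool(x)  # shape: (20, 16, 24, 15)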

forward(input)[source]#

Performs the pooling operation over the input tensor.

Parameters:

input (Tensor) – the input tensor over which to perform the pooling operation. Must adhere to the input shape requirements.

Return type:

Tensor

Returns:

the result of the pooling operation.

class mptorch.quant.modules.QBatchNorm1d(num_features, fwd_quant, bwd_quant)[source]#

Bases: QBatchNorm

Applies Batch Normalization over a 2D input.

Method described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift .

\[y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta\]

The mean and standard-deviation are calculated per-dimension over the mini-batches and \(\gamma\) and \(\beta\) are learnable parameter vectors of size C (where C is the number of features or channels of the input). By default, the elements of \(\gamma\) are set to 1 and the elements of \(\beta\) are set to 0. At train time in the forward pass, the variance is calculated via the biased estimator, equivalent to torch.var(input, unbiased=False). However, the value stored in the moving average of the variance is calculated via the unbiased estimator, equivalent to torch.var(input, unbiased=True).

Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a default momentum of 0.1.

Note

This momentum is different from the one used in optimizer classes and the conventional notion of momentum. Mathematically, the update rule for running statistics here is \(\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t\), where \(\hat{x}\) is the estimated statistic and \(x_t\) is the new observed value.

Shape:
  • Input: \((N, C)\) or \((N, C, L)\), where \(N\) is the batch size, \(C\) is the number of features or channels, and \(L\) is the sequence length

  • Output: \((N, C)\) or \((N, C, L)\) (same shape as input)

Because the Batch Normalization is done over the C dimension, computing statistics on (N, L) slices, it’s common terminology to call this Temporal Batch Normalization.

Parameters:
  • num_features (int) – number of features or channels \(C\) of the input

  • fwd_quant (Callable[[Tensor], Tensor]) – quantization function to use during FWD operations

  • bwd_quant (Callable[[Tensor], Tensor]) – quantization function to use during BWD operations
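
Example

A minimal sketch with a bfloat16-like quantizer for the FWD and BWD computations (the quantizer choice is illustrative):

import torch
import mptorch.quant as qt
from mptorch.quant.modules import QBatchNorm1d

quant = lambda x: qt.float_quantize(x, man=7, exp=8, rounding="nearest")
m = QBatchNorm1d(100, fwd_quant=quant, bwd_quant=quant)
x = torch.randn(20, 100)
y = m(x)  # shape: (20, 100)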

class mptorch.quant.modules.QBatchNorm2d(num_features, fwd_quant, bwd_quant)[source]#

Bases: QBatchNorm

Applies Batch Normalization over a 4D input.

4D is a mini-batch of 2D inputs with additional channel dimension. Method described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift .

\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta\]

The mean and standard-deviation are calculated per-dimension over the mini-batches and \(\gamma\) and \(\beta\) are learnable parameter vectors of size C (where C is the input size). By default, the elements of \(\gamma\) are set to 1 and the elements of \(\beta\) are set to 0. At train time in the forward pass, the variance is calculated via the biased estimator, equivalent to torch.var(input, unbiased=False). However, the value stored in the moving average of the variance is calculated via the unbiased estimator, equivalent to torch.var(input, unbiased=True).

Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a default momentum of 0.1.

Note

This momentum is different from the one used in optimizer classes and the conventional notion of momentum. Mathematically, the update rule for running statistics here is \(\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t\), where \(\hat{x}\) is the estimated statistic and \(x_t\) is the new observed value.

Because the Batch Normalization is done over the C dimension, computing statistics on (N, H, W) slices, it’s common terminology to call this Spatial Batch Normalization.

Shape:
  • Input: \((N, C, H, W)\)

  • Output: \((N, C, H, W)\) (same shape as input)

Parameters:
  • num_features (int) – \(C\) from an expected input of size \((N, C, H, W)\)

  • fwd_quant (Callable[[Tensor], Tensor]) – quantization function to use during FWD operations

  • bwd_quant (Callable[[Tensor], Tensor]) – quantization function to use during BWD operations
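
Example

A minimal sketch, analogous to QBatchNorm1d but on 4D activations (the quantizer choice is illustrative):

import torch
import mptorch.quant as qt
from mptorch.quant.modules import QBatchNorm2d

quant = lambda x: qt.float_quantize(x, man=7, exp=8, rounding="nearest")
m = QBatchNorm2d(100, fwd_quant=quant, bwd_quant=quant)
x = torch.randn(20, 100, 35, 45)
y = m(x)  # shape: (20, 100, 35, 45)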

class mptorch.quant.modules.QLayerNorm(normalized_shape, formats, eps=1e-05, elementwise_affine=True, bias=True)[source]#

Bases: LayerNorm

Applies Layer Normalization over a mini-batch of inputs.

This layer implements the operation as described in the paper Layer Normalization

\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta\]

The mean and standard-deviation are calculated over the last D dimensions, where D is the dimension of normalized_shape. For example, if normalized_shape is (3, 5) (a 2-dimensional shape), the mean and standard-deviation are computed over the last 2 dimensions of the input (i.e. input.mean((-2, -1))). \(\gamma\) and \(\beta\) are learnable affine transform parameters of normalized_shape if elementwise_affine is True. The variance is calculated via the biased estimator, equivalent to torch.var(input, unbiased=False).

Note

Unlike Batch Normalization and Instance Normalization, which apply a scalar scale and bias for each entire channel/plane with the affine option, Layer Normalization applies per-element scale and bias with elementwise_affine.

This layer uses statistics computed from input data in both training and evaluation modes.

weight#

the learnable weights of the module of shape \(\text{normalized_shape}\) when elementwise_affine is set to True. The values are initialized to 1.

Type:

Tensor

bias#

the learnable bias of the module of shape \(\text{normalized_shape}\) when elementwise_affine is set to True. The values are initialized to 0.

Type:

Tensor

Shape:
  • Input: \((N, *)\)

  • Output: \((N, *)\) (same shape as input)

Parameters:
  • normalized_shape (int or list or torch.Size) –

    input shape from an expected input of size

    \[[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]\]

    If a single integer is used, it is treated as a singleton list, and this module will normalize over the last dimension which is expected to be of that specific size.

  • eps (float) – a value added to the denominator for numerical stability. Default: 1e-5

  • elementwise_affine (bool) – a boolean value that when set to True, this module has learnable per-element affine parameters initialized to ones (for weights) and zeros (for biases). Default: True.

  • bias (bool) – If set to False, the layer will not learn an additive bias (only relevant if elementwise_affine is True). Default: True.
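
Example

A minimal sketch, assuming QLayerNormFormats accepts forward and backward quantizers as its first two positional arguments (as the qlayernorm() default in the Functional section suggests); the exact constructor signature may differ:

import torch
import mptorch.quant as qt
from mptorch.quant.modules import QLayerNorm

quant = lambda x: qt.float_quantize(x, man=7, exp=8, rounding="nearest")
# assumed argument order: (forward quantizer, backward quantizer)
formats = qt.QLayerNormFormats(quant, quant)
ln = QLayerNorm(normalized_shape=(10,), formats=formats)
x = torch.randn(20, 5, 10)
y = ln(x)  # shape: (20, 5, 10)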

forward(input)[source]#

Performs the layernorm operation on the input tensor, using quantized elementary operations (e.g. additions and multiplications) in the FWD and BWD passes, as specified through the formats argument to the module constructor.

Parameters:

input (Tensor) – the input tensor over which to perform the layernorm operations.

Return type:

Tensor

Returns:

the output of the layernorm operation

quant_function(fwd_quant, bwd_quant)[source]#

Defines a straight-through estimator-like function (see Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation (https://arxiv.org/abs/1308.3432)) that applies potentially different quantization functions in the forward and backward passes through the input and output gradient signals, respectively.

Parameters:
  • fwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply during the forward pass through the input signal

  • bwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply during the backward pass through the output gradient signal

quant_parameters()[source]#

Quantizes the module parameters weight and bias using the quantization functions specified in formats.weight_quant and formats.bias_quant, respectively.

reset_quant_function()[source]#

Sets a straight-through estimator-like function to all the quantized signals in the module (these can be the weight, bias, input, and/or output signals), depending on whether quantizers are specified in the associated QAffineFormats formats parameter.

class mptorch.quant.modules.QSoftmax(dim, formats)[source]#

Bases: Softmax

A quantized implementation of the Softmax activation function.

This class extends PyTorch’s torch.nn.Softmax and allows one to specify if I/O signals and internal computations should be quantized during inference & training. This allows simulating the effect of custom precision in the internal forward and backward pass computations and is helpful in studying the effect of low precision compute during inference and training (not just data quantization).

Parameters:
  • dim (int) – The dimension along which Softmax will be applied.

  • formats (QSoftmaxFormats) – Number formats specification during compute and quantization functions for signals during forward and back propagation (I/O activations and gradients).

forward(input)[source]#

Performs the softmax operation on the input tensor. The use of quantized elementary operations (i.e., additions and multiplications) in the FWD and BWD passes is controlled through the formats argument to the module constructor.

Parameters:

input – the input tensor over which to perform the softmax operations.

Return type:

Tensor

Returns:

the result of the softmax operation.

quant_function(fwd_quant, bwd_quant)[source]#

Defines a straight-through estimator-like function (see Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation (https://arxiv.org/abs/1308.3432)) that applies potentially different quantization functions in the forward and backward passes through the input and output gradient signals, respectively.

Parameters:
  • fwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply during the forward pass through the input signal

  • bwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply during the backward pass through the output gradient signal

reset_quant_function()[source]#

Sets a straight-through estimator-like function to all the quantized I/O signals in the module, depending on whether quantizers are specified in the associated QSoftmaxFormats formats parameter.

class mptorch.quant.modules.QGELU(formats, approximate='none')[source]#

Bases: GELU

Applies the Gaussian Error Linear Units function to the input \(x\):

\[\text{GELU}(x) = x * \Phi(x)\]

where \(\Phi(x)\) is the Cumulative Distribution Function for Gaussian Distribution.

When the approximate argument is 'tanh', GELU is estimated with:

\[\text{GELU}(x) = 0.5 * x * (1 + \tanh(\sqrt{2 / \pi} * (x + 0.044715 * x^3)))\]
Parameters:
  • formats (QGELUFormats) – configuration class for number formats and quantizers to use during forward and backward computations in GELU

  • approximate (Literal['none', 'tanh']) – the GELU approximation algorithm to use: 'none' | 'tanh'. Default: 'none'

forward(input)[source]#

Applies the GELU function on the input tensor element-wise.

Parameters:

input (Tensor) – the input tensor over which the GELU function is applied

Return type:

Tensor

Returns:

the result of applying the GELU function

quant_function(fwd_quant, bwd_quant)[source]#

Defines a straight-through estimator-like function (see Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation (https://arxiv.org/abs/1308.3432)) that applies potentially different quantization functions in the forward and backward passes through the input and output gradient signals, respectively.

Parameters:
  • fwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply during the forward pass through the input signal

  • bwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply during the backward pass through the output gradient signal

reset_quant_function()[source]#

Sets a straight-through estimator-like function to all the quantized I/O signals in the module, depending on whether quantizers are specified in the associated QGELUFormats formats parameter.

Functional#

mptorch.quant.functional.qlinear(input, weight, bias=None, formats=QAffineFormats(default_fwd, default_bwd, rbits_add=0, rbits_mul=0))[source]#

Applies a linear transformation to the incoming data: \(y=xW^T + b\)

The formats parameter allows one to specify whether I/O signals should be quantized during inference & training (needed, for instance, in QAT and PTQ methods), as well as the precision(s) to be used in internal GEMM computations (addition and multiplication, fused or not). This allows simulating the effect of custom precision during GEMM calls in the forward and backward passes and is helpful in studying the effect of low-precision compute during inference and training (not just data quantization).

This is the functional version of QLinear.

Parameters:
  • input (Tensor) – the input \(x\) to the linear layer of the form \((*, H_\text{in})\), where \(*\) means any number of dimensions including none and \(H_{in} = \text{in_features}\)

  • weight (Tensor) – the weight tensor \(W\) of shape \((\text{out_features}, \text{in_features})\)

  • bias (Tensor | None) – optional bias term of shape \((\text{out_features})\)

  • formats (QAffineFormats) – the configuration object for how quantization (if any!) should be handled on the matrix inputs and how the MAC and summation operations should be performed (e.g. using compensated algorithms or not)

Return type:

Tensor

Returns:

the result of the affine operation \(xW^T + b\)
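
Example

A minimal sketch using the default formats; to simulate low-precision GEMM arithmetic, pass a custom QAffineFormats instance instead:

import torch
import mptorch.quant as qt

x = torch.randn(128, 20)
w = torch.randn(30, 20)
b = torch.randn(30)
y = qt.functional.qlinear(x, w, b)  # shape: (128, 30)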

mptorch.quant.functional.qmatmul(input, other, formats=QAffineFormats(default_fwd, default_bwd, rbits_add=0, rbits_mul=0))[source]#

Simulates a mixed-precision computation pipeline for (batched) matrix multiplication of the tensors input and other.

The behavior depends on the dimensionality of the tensors as follows:

  • If both tensors are 1-dimensional, the dot product (scalar) is returned.

  • If both arguments are 2-dimensional, the matrix-matrix product is returned.

  • If the first argument is 1-dimensional and the second argument is 2-dimensional, a 1 is prepended to its dimension for the purpose of the matrix multiply. After the matrix multiply, the prepended dimension is removed.

  • If the first argument is 2-dimensional and the second argument is 1-dimensional, the matrix-vector product is returned.

  • If both arguments are at least 1-dimensional and at least one argument is N-dimensional (where 2 < N < 5), then a batched matrix multiply is returned. If the first argument is 1-dimensional, a 1 is prepended to its dimension for the purpose of the batched matrix multiply and removed after. If the second argument is 1-dimensional, a 1 is appended to its dimension for the purpose of the batched matrix multiply and removed after. The non-matrix (i.e. batch) dimensions are broadcasted (and thus must be broadcastable in the PyTorch sense). For example, if input is a \((j \times 1 \times n \times n)\) tensor and other is a \((k \times n \times n)\) tensor, out will be a \((j \times k \times n \times n)\) tensor.

    Note that the broadcasting logic only looks at the batch dimensions when determining if the inputs are broadcastable, and not the matrix dimensions. For example, if input is a \((j \times 1 \times n \times m)\) tensor and other is a \((k \times m \times p)\) tensor, these inputs are valid for broadcasting even though the final two dimensions (i.e. the matrix dimensions) are different. out will be a \((j \times k \times n \times p)\) tensor.

Parameters:
  • input (Tensor) – the first tensor to be multiplied

  • other (Tensor) – the second tensor to be multiplied

  • formats (QAffineFormats) – the configuration object for how quantization (if any!) should be handled on the tensor inputs and how the MAC and summation operations should be performed (e.g. using compensated algorithms or not)

Return type:

Tensor

Returns:

the result of the (batched) matrix multiplication between input and other
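
Example

A minimal sketch of a batched multiplication with a binary16-like compute format; the QAffineFormats keywords mirror the cuBLAS acceleration example at the end of this reference, and the exact keyword set may differ:

import torch
import mptorch.quant as qt
from mptorch.number import FloatingPoint

mac_format = FloatingPoint(exp=5, man=10, subnormals=True, saturate=False)
formats = qt.QAffineFormats(
    fwd_mac=(mac_format,), bwd_mac=(mac_format,),
    fwd_rnd="nearest", bwd_rnd="nearest",
)
a = torch.randn(10, 3, 4)
b = torch.randn(10, 4, 5)
c = qt.functional.qmatmul(a, b, formats=formats)  # shape: (10, 3, 5)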

mptorch.quant.functional.qmm(input, mat2, formats=QAffineFormats(default_fwd, default_bwd, rbits_add=0, rbits_mul=0))[source]#

Simulates a mixed-precision computation pipeline for matrix multiplication of the matrices input and mat2.

If input is a \((n \times m)\) tensor, mat2 is a \((m \times p)\) tensor, the output tensor will be a \((n \times p)\) tensor.

Note

This function does not broadcast. For broadcasting quantized matrix products, see mptorch.quant.functional.qmatmul().

Parameters:
  • input (Tensor) – the first matrix to be multiplied

  • mat2 (Tensor) – the second matrix to be multiplied

  • formats (QAffineFormats) – the configuration object for how quantization (if any!) should be handled on the matrix inputs and how the MAC and summation operations should be performed (e.g. using compensated algorithms or not)

Return type:

Tensor

Returns:

the result of the matrix multiplication between input and mat2
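
Example

A minimal sketch using the default formats:

import torch
import mptorch.quant as qt

a = torch.randn(2, 3)
b = torch.randn(3, 4)
c = qt.functional.qmm(a, b)  # shape: (2, 4)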

mptorch.quant.functional.qadd(x, y, fwd_quant, bwd_quant)[source]#

Adds x to y. Uses fwd_quant to quantize the result of the addition (e.g. can simulate the execution of the addition in low-precision, assuming the inputs are already in low precision). The bwd_quant function is used to quantize the gradients from the operator during the backward pass.

For the forward computation:

\[\text{out} = \mathcal{Q}_\text{fwd}(\text{x} + \text{y})\]

For the backward computation:

\[ \begin{align}\begin{aligned}\text{grad_x} = \mathcal{Q}_\text{bwd}(\text{grad_z} * \text{ones_like}(\text{x}))\\\text{grad_y} = \mathcal{Q}_\text{bwd}(\text{grad_z} * \text{ones_like}(\text{y}))\end{aligned}\end{align} \]
Parameters:
  • x (Tensor) – the input tensor

  • y (Tensor) – the other tensor to add to x

  • fwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply on the forward addition

  • bwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply on the gradient computations in the backward pass

Return type:

Tensor

Returns:

the quantized result of the addition operation between x and y
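
Example

A minimal sketch simulating binary16-like results in the forward pass and binary16-like gradients in the backward pass (the quantizer choices are illustrative):

import torch
import mptorch.quant as qt

quant = lambda x: qt.float_quantize(x, man=10, exp=5, rounding="nearest")
x = torch.randn(4, 4, requires_grad=True)
y = torch.randn(4, 4)
z = qt.functional.qadd(x, y, fwd_quant=quant, bwd_quant=quant)
z.sum().backward()  # x.grad is quantized with bwd_quant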

mptorch.quant.functional.qmul(x, y, fwd_quant, bwd_quant)[source]#

Multiplies x by y. Uses fwd_quant to quantize the result of the multiplication (e.g. can simulate the execution of the multiplication in low-precision, assuming the inputs are already in low precision). The bwd_quant function is used to quantize the gradients from the operator during the backward pass.

For the forward computation:

\[\text{out} = \mathcal{Q}_\text{fwd}(\text{x} * \text{y})\]

For the backward computation:

\[ \begin{align}\begin{aligned}\text{grad_x} = \mathcal{Q}_\text{bwd}(\text{grad_z} * \text{y})\\\text{grad_y} = \mathcal{Q}_\text{bwd}(\text{grad_z} * \text{x})\end{aligned}\end{align} \]
Parameters:
  • x (Tensor) – the input tensor

  • y (Tensor) – the other tensor to multiply with x

  • fwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply on the forward multiplication

  • bwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply on the gradient computations in the backward pass

Return type:

Tensor

Returns:

the quantized result of the multiplication operation between x and y
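
Example

A minimal sketch (the quantizer choice is illustrative):

import torch
import mptorch.quant as qt

quant = lambda x: qt.float_quantize(x, man=10, exp=5, rounding="nearest")
x = torch.randn(4, 4)
y = torch.randn(4, 4)
z = qt.functional.qmul(x, y, fwd_quant=quant, bwd_quant=quant)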

mptorch.quant.functional.qsqrt(x, fwd_quant, bwd_quant)[source]#

Returns a new tensor with the square-root of the elements of x.

\[\text{out}_{i} = \sqrt{\text{x}_{i}}\]
Parameters:
  • x (Tensor) – the input tensor

  • fwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply on the forward square root operation

  • bwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply on the gradient computations in the backward pass

Return type:

Tensor

Returns:

the quantized result of the square root operation on x
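
Example

A minimal sketch (the quantizer choice is illustrative):

import torch
import mptorch.quant as qt

quant = lambda x: qt.float_quantize(x, man=10, exp=5, rounding="nearest")
x = torch.rand(4) + 0.1  # strictly positive inputs
y = qt.functional.qsqrt(x, fwd_quant=quant, bwd_quant=quant)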

mptorch.quant.functional.qdiv(x, y, fwd_quant, bwd_quant)[source]#

Divides x by y. Uses fwd_quant to quantize the result of the division (e.g. can simulate the execution of the division in low-precision, assuming the inputs are already in low precision). The bwd_quant function is used to quantize the gradients from the operator during the backward pass.

For the forward computation:

\[\text{out}_i = \mathcal{Q}_\text{fwd}\left(\frac{\text{x}_i}{\text{y}_i}\right)\]

For the backward computation:

\[ \begin{align}\begin{aligned}\text{grad_x} = \mathcal{Q}_\text{bwd}\left(\frac{\text{grad_z}}{\text{y}}\right)\\\text{grad_y} = \mathcal{Q}_\text{bwd}\left(\frac{\mathcal{Q}_\text{bwd}\left(-\text{grad_z} * \text{x}\right)}{\text{y}}\right)\end{aligned}\end{align} \]
Parameters:
  • x (Tensor) – the input tensor

  • y (Tensor) – the tensor by which to divide x

  • fwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply on the forward division

  • bwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply on the gradient computations in the backward pass

Return type:

Tensor

Returns:

the quantized result of the division operation between x and y
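
Example

A minimal sketch (the quantizer choice is illustrative):

import torch
import mptorch.quant as qt

quant = lambda x: qt.float_quantize(x, man=10, exp=5, rounding="nearest")
x = torch.randn(4, 4)
y = torch.rand(4, 4) + 0.5  # keep the divisor away from zero
z = qt.functional.qdiv(x, y, fwd_quant=quant, bwd_quant=quant)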

mptorch.quant.functional.qpow(x, fwd_quant, bwd_quant, n=2.0)[source]#

Takes the power of each element in x with n and returns a tensor with the result. n can be either a single float number or a torch.Tensor with the same number of elements as x.

When n is a scalar value, the forward operation applied is:

\[\text{out}_i = \mathcal{Q}_\text{fwd}\left(\text{x}_i^\text{n}\right)\]

When n is a tensor, the operation applied is:

\[\text{out}_i = \mathcal{Q}_\text{fwd}\left(\text{x}_i^{\text{n}_i}\right)\]

When n is a tensor, the shapes of x and n must be broadcastable.

The backward operation is (applied element-wise):

\[\text{out} = \mathcal{Q}_\text{bwd}\left(\text{grad_out} *\mathcal{Q}_\text{bwd}\left(\mathcal{Q}_\text{bwd}\left(\text{x}^{\text{n}-1}\right) * \text{n}\right)\right)\]
Parameters:
  • x (Tensor) – the input tensor

  • fwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply on the forward power operation

  • bwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply on the gradient computations in the backward pass

  • n (Tensor | float) – the exponent value

Return type:

Tensor

Returns:

the quantized result of the power operation between x and n
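
Example

A minimal sketch with a scalar exponent (the quantizer choice is illustrative):

import torch
import mptorch.quant as qt

quant = lambda x: qt.float_quantize(x, man=10, exp=5, rounding="nearest")
x = torch.rand(4) + 0.1
y = qt.functional.qpow(x, fwd_quant=quant, bwd_quant=quant, n=3.0)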

mptorch.quant.functional.qsum(x, dim=None, quant=<function <lambda>>, keepdim=False)[source]#

Returns the quantized sum of all elements in the x tensor. It can simulate low precision summation if the elements of x are low precision values. The quant function specifies the accumulator output format and precision as a quantization function.

Parameters:
  • x (Tensor) – the input tensor

  • dim (int | tuple[int, ...] | None) – the dimension or dimensions to reduce. If None, all dimensions are reduced

  • quant (Callable[[Tensor], Tensor]) – the quantization function specifying how accumulation results should be stored

  • keepdim (bool) – whether the output tensor has dim retained or not

Return type:

Tensor

Returns:

the quantized result of the sum operation on x
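
Example

A minimal sketch in which partial sums are accumulated in a bfloat16-like format (the accumulator quantizer is illustrative):

import torch
import mptorch.quant as qt

quant = lambda x: qt.float_quantize(x, man=7, exp=8, rounding="nearest")
x = torch.randn(4, 5)
s = qt.functional.qsum(x, dim=1, quant=quant)  # shape: (4,)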

mptorch.quant.functional.qmean(x, fwd_quant, bwd_quant, dim=3, keepdim=False)[source]#

Returns the mean value of all the elements in the x tensor. Input must be a floating point tensor. It can simulate low precision summation if the elements of x are low precision values. The fwd_quant function specifies the accumulator (and final division) output format and precision as a quantization function in the forward pass. The bwd_quant function specifies how the arithmetic should be performed in the backward pass through this operator.

Parameters:
  • x (Tensor) – the input tensor

  • dim (int | tuple[int, ...] | None) – the dimension or dimensions to reduce. If None, all dimensions are reduced

  • fwd_quant (Callable[[Tensor], Tensor]) – the quantization function specifying how accumulation and division results should be stored in the forward pass

  • bwd_quant (Callable[[Tensor], Tensor]) – the quantization function specifying how operations should be performed in the backward pass

  • keepdim (bool) – whether the output tensor has dim retained or not

Return type:

Tensor

Returns:

the quantized result of the mean operation on x
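
Example

A minimal sketch (the quantizer choices are illustrative):

import torch
import mptorch.quant as qt

quant = lambda x: qt.float_quantize(x, man=7, exp=8, rounding="nearest")
x = torch.randn(4, 5)
m = qt.functional.qmean(x, fwd_quant=quant, bwd_quant=quant, dim=1)  # shape: (4,)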

mptorch.quant.functional.qlayernorm(x, normalized_shape, weight, bias, eps=1e-05, formats=QLayerNormFormats(default_fwd, default_bwd))[source]#

Implements the operation as described in the paper Layer Normalization (https://arxiv.org/abs/1607.06450), giving the user control over how the arithmetic is performed during the forward and backward passes through the formats parameter.

\[y = \frac{x-\mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}}\gamma + \beta\]

The mean and standard deviation are computed over the last D dimensions, where D is the dimension of normalized_shape. For instance, if normalized_shape is (3, 5) (a 2-D shape), the mean and standard deviation are computed over the last 2 dimensions of the input (i.e., x.mean((-2, -1))). In a training context, \(\gamma\) and \(\beta\) are learnable affine transform parameters of normalized_shape.

Note

Unlike Batch Normalization and Instance Normalization, which apply a scalar scale and bias for each entire channel/plane with the affine option, Layer Normalization applies per-element scale and bias with elementwise_affine.

It uses statistics computed from input data in both training and evaluation modes.

Parameters:
  • x (Tensor) – the input tensor

  • normalized_shape (int | list[int] | Size) –

    input shape from an expected input of size

    \[[* \times \text{n_shape}[0] \times \text{n_shape}[1] \times \ldots \times \text{n_shape}[-1]]\]

    If a single integer is used, it is treated as a singleton list, and this function will normalize over the last dimension which is expected to be of that specific size

  • weight (Tensor) – the learnable weights \(\gamma\) of shape \(\text{n_shape}\)

  • bias (Tensor) – the learnable bias \(\beta\) of shape \(\text{n_shape}\)

  • eps (float) – a small value added to the denominator for numerical stability. Default: 1e-5

  • formats (QLayerNormFormats) – configuration class for number formats and quantizers to use during forward and backward computations in layer normalization.

Return type:

Tensor

Returns:

the quantized result of the layer normalization operation on x; it has the same shape as x.

mptorch.quant.functional.qsoftmax(x, dim, formats)[source]#

Applies the Softmax function to an n-dimensional input tensor x. Through the formats parameter it allows one to specify if I/O signals and internal mathematical operations should be quantized during forward and backward compute chains.

Rescales the elements in the input tensor so that they lie in the range \([0, 1]\) and sum to \(1\). It is defined as:

\[\text{Softmax}(x_i) = \frac{\exp(x_i)}{\sum_j\exp(x_j)}\]
Parameters:
  • x (Tensor) – the input tensor

  • dim (int | None) – a dimension along which Softmax will be computed (so every slice along dim will sum to 1)

  • formats (QSoftmaxFormats) – configuration class for number formats and quantizers to use during forward and backward computations in Softmax

Return type:

Tensor

Returns:

a Tensor of the same dimension and shape as the input with values in the range \([0, 1]\)

Example

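A minimal sketch, assuming QSoftmaxFormats exposes input_quant and output_quant keywords analogous to the QGELUFormats example below; the actual constructor signature may differ:

import torch
import mptorch.quant as qt
from torch.testing import assert_close

quant_func = lambda x: qt.float_quantize(x, man=7, exp=8, rounding="nearest")
# assumed keyword names, mirroring QGELUFormats
formats = qt.QSoftmaxFormats(
    input_quant=quant_func,
    output_quant=quant_func,
)

x = torch.randn(3, 2)
y = torch.nn.functional.softmax(x, dim=-1)
qy = qt.functional.qsoftmax(x, dim=-1, formats=formats)
assert_close(y, qy, rtol=1e-2, atol=1e-5)
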
mptorch.quant.functional.qgelu(x, formats, approximate='none')[source]#

Applies the Gaussian Error Linear Units function to the input \(x\):

\[\text{GELU}(x) = x * \Phi(x)\]

where \(\Phi(x)\) is the Cumulative Distribution Function for Gaussian Distribution.

When the approximate argument is 'tanh', GELU is estimated with:

\[\text{GELU}(x) = 0.5 * x * (1 + \tanh(\sqrt{2 / \pi} * (x + 0.044715 * x^3)))\]
Parameters:
  • x (Tensor) – the input tensor

  • formats (QGELUFormats) – configuration class for number formats and quantizers to use during forward and backward computations in GELU

  • approximate (Literal['tanh', 'none']) – the GELU approximation algorithm to use: 'none' | 'tanh'. Default: 'none'

Return type:

Tensor

Returns:

a Tensor of the same dimension and shape as the input where the GELU function is applied element-wise

Example

import torch
import mptorch.quant as qt
from torch.testing import assert_close

quant_func = lambda x: qt.float_quantize(x, man=7, exp=8, rounding="nearest")
formats = qt.QGELUFormats(
    input_quant=quant_func,
    output_quant=quant_func,
)

x = torch.randn(3, 2)
y = torch.nn.functional.gelu(x)
qy = qt.functional.qgelu(x, formats=formats)
assert_close(y, qy, rtol=1e-2, atol=1e-5)

cuBLAS Acceleration#

class mptorch.quant.cublas.cublas_acceleration(enabled, fast_mode=None)[source]#

Bases: object

cuBLAS acceleration management.

This class allows enabling and disabling automatic cuBLAS acceleration for compatible types. When enabled, all calls to float quantized (batched) GEMMs (float_mm and float_bmm) used internally by linear and convolutional layers will use the cuBLAS GEMM functions if the floating-point compute format matches one of the format combinations supported by cuBLAS, i.e.:

  • round-to-nearest-even rounding mode

  • fused multiply-add enabled

  • subnormals enabled

  • saturation disabled

  • matching multiplication/accumulation types from one of the supported combinations (https://docs.nvidia.com/cuda/cublas/#cublasgemmex)

This feature is disabled by default.

Parameters:
  • enabled (bool) – whether to enable automatic cuBLAS acceleration.

  • fast_mode (str | None) – allow internal downcast to lower-precision for tensor cores. Currently supported are f16, bf16 and tf32

Example

cuBLAS acceleration can be enabled/disabled globally using the enable class method, or locally as a context manager:

mac_format = FloatingPoint(
    exp=5, man=10, subnormals=True, saturate=False # F16, supported by cublas
)
layer_formats = QAffineFormats(
    fwd_mac=(mac_format,),
    bwd_mac=(mac_format,),
    fwd_rnd="nearest",
    bwd_rnd="nearest",
    ...
)
layer = QLinear(in_features, out_features, formats=layer_formats)
with cublas_acceleration(True):
    x = torch.tensor(...)
    y = layer.forward(x)
enabled = False#
fast_mode = None#
classmethod enable(status, fast_mode=None)[source]#

Globally enables or disables cuBLAS acceleration.

Parameters:
  • status (bool) – whether to enable or disable cuBLAS acceleration

  • fast_mode (str | None) – use down-conversion to f16, bf16 or tf32 for faster GEMM when possible
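
Example

A short sketch of global control (the fast_mode string values follow the list above):

from mptorch.quant.cublas import cublas_acceleration

# enable cuBLAS-backed GEMMs globally, allowing a downcast to TF32 tensor cores
cublas_acceleration.enable(True, fast_mode="tf32")

# ... run quantized layers here ...

# disable the acceleration again
cublas_acceleration.enable(False)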