API Reference#
Number formats#
- class mptorch.number.Number[source]#
Bases: object
Base class for all supported number formats.
Users should always instantiate one of the derived classes.
- class mptorch.number.FloatType[source]#
Bases: Number
Base class for all float-like number formats.
Similar to the Number class, users should not instantiate this class directly. It is useful as a means to determine if a number format is of float type.
- class mptorch.number.FixedPoint(wl, fl, clamp=True, symmetric=False)[source]#
Bases: Number
Low-Precision Fixed Point Number Format. Defined similarly to Deep Learning with Limited Numerical Precision (https://arxiv.org/abs/1502.02551).
The representable range is \(\left[-2^{\text{wl}-\text{fl}-1}, 2^{\text{wl}-\text{fl}-1}-2^{-\text{fl}}\right]\) and the precision unit (smallest nonzero absolute value) is \(2^{-\text{fl}}\). Numbers outside of the representable range can be clamped (if clamp is true). We can also give up the smallest representable number to make the range symmetric, \(\left[-2^{\text{wl}-\text{fl}-1}+2^{-\text{fl}}, 2^{\textnormal{wl}-\text{fl}-1}-2^{-\text{fl}}\right]\) (if symmetric is true).
Define \(\lfloor x \rfloor\) to be the largest representable number (a multiple of \(2^{-\text{fl}}\)) smaller than \(x\). For numbers within the representable range, we support two kinds of fixed point quantization: round to nearest (RN) and stochastic rounding (SR). They correspond to
\[\text{RN}(x) = \Biggl \lbrace { \lfloor x \rfloor, \text{ if } \lfloor x \rfloor \leq x \leq \lfloor x \rfloor + 2^{-\text{fl}-1} \atop \lfloor x \rfloor + 2^{-\text{fl}}, \text{ if } \lfloor x \rfloor + 2^{-\text{fl}-1} < x \leq \lfloor x \rfloor + 2^{-\text{fl}} }\]
or
\[\text{SR}(x) = \Biggl \lbrace { \lfloor x \rfloor, \text{ with probability } 1 - \frac{x - \lfloor x \rfloor}{2^{-\text{fl}}} \atop \lfloor x \rfloor + 2^{-\text{fl}}, \text{ with probability } \frac{x - \lfloor x \rfloor}{2^{-\text{fl}}} }\]
- Parameters:
  - wl (int) – word length of each fixed point number
  - fl (int) – fractional length of each fixed point number
  - clamp (bool) – whether to clamp unrepresentable numbers
  - symmetric (bool) – whether to make the representable range symmetric
- class mptorch.number.FloatingPoint(exp, man, subnormals=False, saturate=True)[source]#
Bases: FloatType
Low-Precision Floating Point Format. Follows rules set out in the IEEE-754 standard, applying them in the context of custom precision formats.
We set the exponent bias to be \(2^{\text{exp}-1} - 1\). In terms of rounding mode (see available quantization functions), we offer support for round to nearest even and stochastic rounding.
- Parameters:
  - exp (int) – number of bits allocated for exponent
  - man (int) – number of bits allocated for mantissa, referring to number of bits that are supposed to be stored on hardware (not counting the virtual bits)
  - subnormals (bool) – allow the use of subnormal values
  - saturate (bool) – clamp values instead of using infinities in case of overflow
- property is_fp32: bool#
Returns whether the format is equivalent to the IEEE-754 binary32 format.
- property is_fp16: bool#
Returns whether the format is equivalent to the IEEE-754 binary16 format.
- property is_bfloat16: bool#
Returns whether the format is equivalent to the bfloat16 format.
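A short sketch of the format checks (the subnormals/saturate settings below are assumptions; the exact flag values required for the is_* properties to report True may differ):
```python
from mptorch.number import FloatingPoint

# bfloat16-like format: 8 exponent bits, 7 stored mantissa bits
bf16 = FloatingPoint(exp=8, man=7, subnormals=True, saturate=False)
print(bf16.is_bfloat16)

# IEEE-754 binary16-like format: 5 exponent bits, 10 stored mantissa bits
fp16 = FloatingPoint(exp=5, man=10, subnormals=True, saturate=False)
print(fp16.is_fp16)
```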
- class mptorch.number.BlockFloatingPoint(wl, dim=-1)[source]#
Bases: Number
Low-Precision Block Floating Point Format.
BlockFloatingPoint shares an exponent across a block of numbers. The shared exponent is chosen from the largest magnitude in the block.
- Parameters:
  - wl (int) – word length of the tensor
  - dim (int) – block dimension to share exponent. A (*, D, *) tensor where D is at position dim will have D different exponents; use -1 if the entire tensor is treated as a single block (there is only 1 shared exponent).
- class mptorch.number.SuperNormalFloat(exp, man, binades, saturate=False)[source]#
Bases: FloatType
Low-Precision SuperNormal Floating Point Format. Described in Range Extension with Supernormals for Mixed-Precision 8-bit DNN Training (https://www.arith2025.org/proceedings/215900a017.pdf)
We set the exponent bias to be \(2^{\text{exp}-1}\). For rounding mode, we apply round to nearest even.
- Parameters:
  - exp (int) – number of bits allocated for exponent
  - man (int) – number of bits allocated for mantissa, referring to number of bits that are supposed to be stored on hardware (not counting the virtual bits)
  - binades (int | tuple[int] | tuple[int, int]) – number of binades transformed into log range
  - saturate – clamp values instead of using infinities in case of overflow
- class mptorch.number.Binary8(P, signed=True, subnormals=True, overflow_policy='saturate_maxfloat2')[source]#
Bases: FloatType
Low-Precision binary8 format that follows the IEEE-P3109 WG specification (P3109/Public). This is a format showcasing an evolving standard; changes are likely in the future.
binary8 is a format that takes a value P as an input to determine the number of mantissa and exponent bits.
- Parameters:
  - P (int) – integer precision of the binary8 format
  - signed (bool) – boolean indicating whether the format is signed or unsigned
  - subnormals (bool) – allow the use of subnormal values
  - overflow_policy (Literal['saturate_infty', 'saturate_maxfloat', 'saturate_maxfloat2']) – string indicating the overflow policy, one of: saturate_maxfloat2 (no infinity and +1 normalized value), saturate_maxfloat (clamp to maxfloat), saturate_infty (use infinity)
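A hedged construction example (the choice P=3 is illustrative only):
```python
from mptorch.number import Binary8

# 8-bit P3109-style format with precision parameter P = 3; the remaining
# non-sign bits are allocated to the exponent
fmt = Binary8(P=3, signed=True, subnormals=True,
              overflow_policy="saturate_maxfloat2")
```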
Quantization#
- mptorch.quant.fixed_point_quantize(x, wl, fl, clamp=True, symmetric=False, rounding='stochastic')[source]#
Quantize a single precision floating-point tensor into a low-precision fixed-point tensor
- Parameters:
  - x (Tensor) – the single precision tensor to be quantized
  - wl (int) – word length of the fixed-point format being simulated
  - fl (int) – fractional length of the fixed-point format being simulated
  - clamp (bool) – clamp input numbers into representable range. If false, the quantization will only simulate the effect on precision
  - symmetric (bool) – discard the minimum representable number to make the representable range symmetric
  - rounding (Literal['nearest', 'stochastic']) – rounding mode, "stochastic" or "nearest" (default: "stochastic")
- Return type:
Tensor
- Returns:
a quantized fixed-point representation of the input tensor
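A minimal sketch of the call (tensor shape and format parameters are arbitrary):
```python
import torch
from mptorch.quant import fixed_point_quantize

x = torch.randn(4, 4)
# simulate an 8-bit fixed-point format with 4 fractional bits, round-to-nearest
x_q = fixed_point_quantize(x, wl=8, fl=4, clamp=True, rounding="nearest")
```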
- mptorch.quant.block_quantize(x, wl, dim=-1, rounding='stochastic')[source]#
Quantize a single precision floating-point tensor into a low-precision block floating-point representation
- Parameters:
  - x (Tensor) – the single precision tensor to be quantized
  - wl (int) – word length of the block floating-point format being simulated
  - dim (int) – dimension over which to apply the block floating point representation (-1 applies it to the entire tensor)
  - rounding (Literal['nearest', 'stochastic']) – rounding mode, "stochastic" or "nearest"
- Return type:
Tensor
- Returns:
a quantized low-precision block floating-point representation of the input tensor
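A minimal sketch (the word length and block dimension are illustrative):
```python
import torch
from mptorch.quant import block_quantize

x = torch.randn(16, 32)
# 8-bit block floating point with a single shared exponent for the whole tensor
x_q = block_quantize(x, wl=8, dim=-1, rounding="nearest")
```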
- mptorch.quant.float_quantize(x, exp, man, rounding='stochastic', subnormals=True, saturate=True, prng=0)[source]#
Quantize a single precision floating-point tensor into an IEEE-754-like low-precision floating-point tensor
- Parameters:
  - x (Tensor) – the single precision tensor to be quantized
  - exp (int) – number of bits allocated for exponent
  - man (int) – number of bits allocated for mantissa, not counting the virtual bit
  - rounding (Literal['nearest', 'stochastic']) – rounding mode, "stochastic" or "nearest"
  - subnormals (bool) – if subnormals are supported or not
  - saturate (bool) – saturate on overflow or use infinities
  - prng (int) – number of random bits to use in case of stochastic rounding
- Return type:
Tensor
- Returns:
a quantized low-precision floating point representation of the input tensor
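A minimal sketch showing two common custom formats (the exponent/mantissa choices are illustrative):
```python
import torch
from mptorch.quant import float_quantize

x = torch.randn(4, 4)
# simulate a bfloat16-like format (8 exponent bits, 7 mantissa bits)
x_bf16 = float_quantize(x, exp=8, man=7, rounding="nearest")
# simulate an FP8 E4M3-like format with stochastic rounding
x_fp8 = float_quantize(x, exp=4, man=3, rounding="stochastic")
```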
- mptorch.quant.binary8_quantize(x, P, rounding='nearest', overflow_policy='saturate_maxfloat', is_signed=True, subnormals=True, prng_bits=0)[source]#
Quantize a single precision floating-point tensor into a P3109-compatible one
- Parameters:
  - x (Tensor) – the single precision tensor to be quantized
  - P (int) – number of bits allocated for precision
  - is_signed (bool) – whether the format is signed or unsigned
  - rounding (Literal['nearest', 'stochastic', 'truncate']) – the quantization rounding mode
  - overflow_policy (Literal['saturate_infty', 'saturate_maxfloat', 'saturate_maxfloat2']) – overflow handling policy
  - subnormals (bool) – allow the use of subnormal values
  - prng_bits (int) – number of bits for the random generator
- Overflow Policies:
  - saturate_infty: Finite input values of binary32, exceeding the maximum float value of the binary8 format, will saturate to the maximum float. Infinite inputs will still map to infinities in this mode.
  - saturate_maxfloat: Both finite and infinite input values of binary32, exceeding the maximum float value of the binary8 format, will saturate to the maximum float represented by 0x7e/0xfe. This number system has an encoding reserved for infinity (0x7f/0xff).
  - saturate_maxfloat2: Both finite and infinite input values of binary32, exceeding the maximum float value of the binary8 format, will saturate to the maximum float represented by 0x7f/0xff. This number system does not have an encoding reserved for infinity.
- Return type:
Tensor
- Returns:
a quantized low-precision floating point representation of the input tensor
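A minimal sketch (the value of P and the policy choice are illustrative):
```python
import torch
from mptorch.quant import binary8_quantize

x = torch.randn(4, 4)
x_q = binary8_quantize(x, P=4, rounding="nearest",
                       overflow_policy="saturate_maxfloat",
                       is_signed=True, subnormals=True)
```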
- mptorch.quant.binaryK_quantize(x, K, P, rounding='nearest_even', saturation_mode='overflow_infinity', is_signed=True, prng_bits=0, bias=None)[source]#
- Return type:
Tensor
- mptorch.quant.superfp_quantize(x, exp, man, binades, rounding='nearest', saturate=False)[source]#
Quantize a single precision floating-point tensor into a low-precision supernormal floating-point one
- Parameters:
  - x (Tensor) – the single precision tensor to be quantized
  - exp (int) – number of bits allocated for exponent
  - man (int) – number of bits allocated for mantissa, not counting the virtual bit
  - binades (int | tuple[int] | tuple[int, int]) – number of binades that will be transformed into log range
  - rounding (Literal['stochastic', 'nearest']) – rounding mode, "stochastic" or "nearest"
  - saturate (bool) – saturate on overflow or use infinities
- Return type:
Tensor
- Returns:
a quantized low-precision supernormal floating-point representation of the input tensor
- mptorch.quant.quantizer(forward_number=None, backward_number=None, forward_rounding='stochastic', backward_rounding='stochastic', clamping_grad_zero=False, backward_hooks=[])[source]#
Creates a quantization function to support quantizing forward and backward process differently.
- Parameters:
  - forward_number (Number | None) – the number format used for forward quantization. If None, the quantization is an identity mapping.
  - backward_number (Number | None) – the number format used for backward quantization. If None, the quantization is an identity mapping.
  - forward_rounding (Literal['nearest', 'stochastic']) – rounding mode, "stochastic" or "nearest" (default: "stochastic")
  - backward_rounding (Literal['nearest', 'stochastic']) – rounding mode, "stochastic" or "nearest" (default: "stochastic")
  - clamping_grad_zero (bool) – zero out the gradient of numbers that are being clamped during forward propagation. Currently requires forward_number to be a fixed point number.
  - backward_hooks – iterable of functions that will be applied to gradients before backward quantization. For example, this can be used to support custom scaling.
- Returns:
A quantization function as specified
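A hedged sketch, assuming the returned function can be applied directly to a tensor (the FP8 format choices are illustrative):
```python
import torch
from mptorch.number import FloatingPoint
from mptorch.quant import quantizer

fp8_e4m3 = FloatingPoint(exp=4, man=3)
fp8_e5m2 = FloatingPoint(exp=5, man=2)

# quantize to E4M3 on the forward pass and to E5M2 with stochastic
# rounding on the backward pass
quant_fn = quantizer(forward_number=fp8_e4m3,
                     backward_number=fp8_e5m2,
                     forward_rounding="nearest",
                     backward_rounding="stochastic")
x_q = quant_fn(torch.randn(8, 8))
```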
Quantization Formats#
- class mptorch.quant.quant_format.QAffineFormats(fwd_mac=None, bwd_mac=None, fwd_rnd=None, bwd_rnd=None, weight_quant=<function <lambda>>, bias_quant=<function <lambda>>, input_quant=<function <lambda>>, output_quant=<function <lambda>>, grad_quant=<function <lambda>>, compensated=None, use_scaling=False, weight_scaled_format=None, input_scaled_format=None, grad_scaled_format=None, scale_margin=0, rbits=0)[source]#
Bases: object
Configuration class for number formats to use during compute (forward and/or backward pass) of affine layers (e.g. linear and convolutional). One can optionally specify quantizer objects for the signals in the layer (I/O activations, weights/bias terms and weight/error gradients) to facilitate quantization-aware-training (QAT) and post-training quantization (PTQ) workloads. Format parameters can also be specified for tensor scaling operations, in a similar way to what is described in: https://arxiv.org/pdf/2309.17224
- Parameters:
  - fwd_mac (Number | tuple[Number] | tuple[Number, Number] | None) – compute configuration (add and multiply) for forward MAC operations
  - bwd_mac (Number | tuple[Number] | tuple[Number, Number] | None) – compute configuration (add and multiply) for backward MAC operations
  - fwd_rnd (str | None) – rounding mode for FWD computations
  - bwd_rnd (str | None) – rounding mode for BWD computations
  - weight_quant (Union[Callable[[Tensor], Tensor], tuple[Number, str]]) – quantization function or format and rounding on the weight signal inputs
  - bias_quant (Union[Callable[[Tensor], Tensor], tuple[Number, str]]) – quantization function or format and rounding on the bias signal inputs
  - input_quant (Union[Callable[[Tensor], Tensor], tuple[Number, str]]) – quantization function or format and rounding on the input signals to the layer
  - output_quant (Union[Callable[[Tensor], Tensor], tuple[Number, str]]) – quantization function or format and rounding on the output signal from the layer
  - grad_quant (Union[Callable[[Tensor], Tensor], Tuple[Number, str]]) – quantization function or format and rounding on the gradient signals in the BWD pass
  - use_scaling (bool) – whether to use weight, input and grad scaling during forward/backward pass
  - weight_scaled_format (FloatType | None) – number format to be used during weight tensor scaling (optional, matches weight_quant if format specified)
  - input_scaled_format (FloatType | None) – number format to be used during input tensor scaling (optional, matches input_quant if format specified)
  - grad_scaled_format (FloatType | None) – number format to be used during output tensor scaling (optional, matches grad_quant if format specified)
  - scale_margin (int) – margin to reduce scaling bias when casting down to FP8 formats
  - rbits (int | tuple[int] | tuple[int, int]) – number of bits used for random number generation when rounding is stochastic (for add and multiply)
- class mptorch.quant.quant_format.QSoftmaxFormats(fwd_off=None, fwd_exp=None, fwd_acc=None, fwd_lse=None, bwd_add=None, bwd_mul=None, fwd_rnd='nearest', bwd_rnd='nearest', input_quant=<function <lambda>>, output_quant=<function <lambda>>, grad_quant=<function <lambda>>)[source]#
Bases: object
Configuration class for number formats to use during compute (forward and/or backward pass) of softmax layers. One can optionally specify quantizer objects for the signals in the layer (I/O activations and weight/error gradients) to facilitate quantization-aware-training (QAT) and post-training quantization (PTQ) workloads.
Two implementations of the forward softmax are provided: regular and LogSumExp-based (LSE).
Regular softmax is used when fwd_off, fwd_exp and fwd_acc are set, and is implemented as follows:
\[\textrm{softmax}(x)_i = \frac{\exp(x_i - \max x)}{ \sum_j \exp(x_j - \max x)}\]
The LogSumExp implementation is used when fwd_off and fwd_exp are set, and is implemented as follows:
\[\textrm{softmax}(x)_i = \exp(x_i - \textrm{LSE}(x_1, ..., x_n))\]
where \(\textrm{LSE}(x_1, ..., x_n)\) is computed iteratively with the relation:
\[\textrm{LSE}(x_1, ..., x_{j+1}) = \ln(\exp \textrm{LSE}(x_1, ..., x_{j}) + \exp x_{j+1})\]
with the internal part of the log being computed at full precision.
- Parameters:
  - fwd_off (Number | None) – compute configuration for forward subtraction
  - fwd_exp (Number | None) – compute configuration for forward exponential operations
  - fwd_acc (Number | None) – compute configuration for forward add operations
  - fwd_lse (Number | None) – compute configuration for forward LSE iteration
  - bwd_add (Number | None) – compute configuration for backward add operations
  - bwd_mul (Number | None) – compute configuration for backward multiply operations
  - fwd_rnd (str | None) – rounding mode for forward computations
  - bwd_rnd (str | None) – rounding mode for backward computations
  - input_quant (Callable[[Tensor], Tensor]) – quantization function on the input signal
  - output_quant (Callable[[Tensor], Tensor]) – quantization function on the output signal
  - grad_quant (Callable[[Tensor], Tensor]) – quantization function on the gradients
- class mptorch.quant.quant_format.QLayerNormFormats(fwd_acc=None, fwd_mul=None, fwd_div=None, fwd_sqrt=None, bwd_acc=None, bwd_mul=None, bwd_div=None, fwd_rnd='nearest', bwd_rnd='nearest', input_quant=<function <lambda>>, output_quant=<function <lambda>>, grad_quant=<function <lambda>>, weight_quant=<function <lambda>>, bias_quant=<function <lambda>>)[source]#
Bases: object
Configuration class for number formats to use during compute (forward and/or backward pass) of layer normalization. One can optionally specify quantizer objects for the signals in the layer (I/O activations, weights/bias terms and weight/error gradients) to facilitate quantization-aware-training (QAT) and post-training quantization (PTQ) workloads.
- Parameters:
  - fwd_acc (Number | None) – compute configuration for forward add operations
  - fwd_mul (Number | None) – compute configuration for forward multiply operations
  - fwd_div (Number | None) – compute configuration for forward divide operations
  - fwd_sqrt (Number | None) – compute configuration for forward square root operations
  - bwd_acc (Number | None) – compute configuration for backward add operations
  - bwd_mul (Number | None) – compute configuration for backward multiply operations
  - bwd_div (Number | None) – compute configuration for backward divide operations
  - fwd_rnd (str | None) – rounding mode for forward computations
  - bwd_rnd (str | None) – rounding mode for backward computations
  - input_quant (Callable[[Tensor], Tensor]) – quantization function on the input signal
  - output_quant (Callable[[Tensor], Tensor]) – quantization function on the output signal
  - grad_quant (Callable[[Tensor], Tensor]) – quantization function on the gradients
  - weight_quant (Callable[[Tensor], Tensor]) – quantization function on the weights when applied to an input
  - bias_quant (Callable[[Tensor], Tensor]) – quantization function on the bias when applied to an input
- class mptorch.quant.quant_format.QGELUFormats(input_quant=<function <lambda>>, inter_quant=<function <lambda>>, output_quant=<function <lambda>>, grad_quant=<function <lambda>>)[source]#
Bases: object
Configuration class for number formats to use during compute (forward and/or backward pass) of GELU activation.
- Parameters:
  - input_quant (Callable[[Tensor], Tensor]) – quantization function on the input signal
  - inter_quant (Callable[[Tensor], Tensor]) – quantization function on intermediate computations; depends on whether the tanh approximation is used
  - output_quant (Callable[[Tensor], Tensor]) – quantization function on the output signal
  - grad_quant (Callable[[Tensor], Tensor]) – quantization function on the gradients
Modules#
- class mptorch.quant.modules.Quantizer(forward_number=None, backward_number=None, forward_rounding='nearest', backward_rounding='nearest')[source]#
Bases: Module
A quantization module that supports quantizing forward and backward process differently.
- Parameters:
  - forward_number (Number | None) – the number format used for forward quantization. If None, the quantization is an identity mapping.
  - backward_number (Number | None) – the number format used for backward quantization. If None, the quantization is an identity mapping.
  - forward_rounding (Literal['nearest', 'stochastic']) – rounding mode, "stochastic" or "nearest" (default: "nearest")
  - backward_rounding (Literal['nearest', 'stochastic']) – rounding mode, "stochastic" or "nearest" (default: "nearest")
- forward(x)[source]#
The forward pass call on the input tensor. Applies the forward quantization on the input and registers the backward quantization format.
- Parameters:
  - x (Tensor) – the input tensor
- Return type:
Tensor
- Returns:
The quantized version of the input, as specified by the FWD number format and associated rounding mode
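A hedged usage sketch: the module can be dropped into a standard PyTorch model as a quantization point (the bfloat16-like format and layer sizes are illustrative).
```python
import torch.nn as nn
from mptorch.number import FloatingPoint
from mptorch.quant.modules import Quantizer

bf16 = FloatingPoint(exp=8, man=7)

# insert a quantization point between two full-precision layers
model = nn.Sequential(
    nn.Linear(64, 64),
    Quantizer(forward_number=bf16, backward_number=bf16,
              forward_rounding="nearest", backward_rounding="stochastic"),
    nn.ReLU(),
    nn.Linear(64, 10),
)
```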
- class mptorch.quant.modules.QLinear(in_features, out_features, formats, bias=True)[source]#
Bases: Linear
Applies a linear transformation to the incoming data: \(y=xW^T + b\)
It is a subclass of torch.nn.Linear and allows one to specify if I/O signals should be quantized during inference & training (needed for instance in QAT and PTQ methods), but also the precision(s) to be used in internal GEMM computations (addition and multiplication, fused or not). This allows simulating the effect of custom precision during GEMM calls in the forward and backward pass and is helpful in studying the effect of low precision compute during inference and training (not just data quantization).
- Parameters:
  - in_features (int) – size of each input sample
  - out_features (int) – size of each output sample
  - formats (QAffineFormats) – number formats used during compute (addition and multiplication) and quantization functions for signals during forward and back propagation (I/O activations, weights, biases, and neural gradients)
  - bias (bool) – If set to False, the layer will not learn an additive bias. Default: True
- Shape:
Input: \((*, H_\text{in})\) where \(*\) means any number of dimensions including none and \(H_\text{in} = \text{in_features}\).
Output: \((*, H_\text{out})\) where all but the last dimension are the same shape as the input and \(H_\text{out} = \text{out_features}\).
- weight#
the learnable weights of the module of shape \((\text{out_features}, \text{in_features})\). The values are initialized from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\), where \(k = \frac{1}{\text{in_features}}\)
- Type:
Tensor
- bias#
the learnable bias of the module of shape \((\text{out_features})\). If bias is True, the values are initialized from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{1}{\text{in_features}}\)
- Type:
Tensor
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(input)[source]#
Describes the computations that get performed at every call of the module. The use of quantized elementary operations (i.e., additions and multiplications) in the FWD and BWD passes is controlled through the formats argument to the module constructor.
- Parameters:
  - input (Tensor) – the input tensor over which to perform the layer operations. Must adhere to the input shape requirements.
- Return type:
Tensor
- Returns:
the result of the \(xW^T + b\) operation.
- quant_function(fwd_quant, bwd_quant)[source]#
Defines a straight-through estimator-like function (see Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation (https://arxiv.org/abs/1308.3432)) that applies potentially different quantization functions in the forward and backward passes through the input and output gradient signals, respectively.
- Parameters:
  - fwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply during the forward pass through the input signal
  - bwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply during the backward pass through the output gradient signal
- class mptorch.quant.modules.QLazyLinear(out_features, formats, bias=True, device=None, dtype=None)[source]#
Bases: LazyModuleMixin, QLinear
A linear module where in_features is inferred.
In this module (an analogue to torch.nn.LazyLinear), the in_features parameter of the quantized linear layer is inferred from the input's last dimension (i.e., input.shape[-1]). The weight and bias layer parameters are of the torch.nn.UninitializedParameter class. They are initialized after the first call to forward and the module becomes a regular torch.nn.Linear module.
- Parameters:
  - out_features (int) – size of each output sample
  - formats (QAffineFormats) – number formats used during compute (addition and multiplication) and quantization functions for signals during forward and back propagation (I/O activations, weights, biases, and neural gradients)
  - bias (bool) – If set to False, the layer will not learn an additive bias. Default: True
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- initialize_parameters(input)[source]#
Initialize parameters according to the input batch properties.
This adds an interface to isolate parameter initialization from the forward pass when doing parameter shape inference.
- Parameters:
  - input (Tensor) – the input tensor on which to perform the layer operations. Its shape is used to determine the value of the in_features member variable.
- weight: UninitializedParameter#
The learnable weights of the module of shape \((\text{out_features}, \text{in_features})\). The values are initialized from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\), where \(k = \frac{1}{\text{in_features}}\)
- class mptorch.quant.modules.QConv1d(in_channels, out_channels, kernel_size, formats, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')[source]#
Bases: Conv1d
Applies a 1D convolution over an input signal composed of several input planes.
It is a subclass of torch.nn.Conv1d and allows one to specify if I/O signals should be quantized during inference & training (needed for instance in QAT and PTQ methods), but also the precision(s) to be used in internal GEMM computations (addition and multiplication, fused or not). This allows simulating the effect of custom precision during GEMM calls in the forward and backward pass (which are performed using the im2col and col2im algorithms implemented in the unfoldNd library) and is helpful in studying the effect of low precision compute during inference and training (not just data quantization).
In the simplest case, the output value of the layer with input size \((N, C_{\text{in}}, L)\) and output \((N, C_{\text{out}}, L_{\text{out}})\) can be precisely described as:
\[\text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) + \sum_{k = 0}^{C_\text{in} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)\]where \(\star\) is the valid cross-correlation operator, \(N\) is a batch size, \(C\) denotes a number of channels, \(L\) is a length of signal sequence.
- stride controls the stride for the cross-correlation, a single number or a one-element tuple.
- padding controls the amount of padding applied to the input. It can be either a string 'valid' or 'same' or a tuple of ints giving the amount of implicit padding applied on both sides.
- dilation controls the spacing between the kernel points; also known as the à trous algorithm. It is harder to describe, but this link has a nice visualization of what dilation does.
- groups controls the connections between inputs and outputs. in_channels and out_channels must both be divisible by groups. For example,
  - At groups=1, all inputs are convolved to all outputs.
  - At groups=2, the operation becomes equivalent to having two conv layers side by side, each seeing half the input channels and producing half the output channels, and both subsequently concatenated.
  - At groups=in_channels, each input channel is convolved with its own set of filters (of size \(\frac{\text{out_channels}}{\text{in_channels}}\)).
Note
When groups == in_channels and out_channels == K * in_channels, where K is a positive integer, this operation is also known as a “depthwise convolution”.
In other words, for an input of size \((N, C_\text{in}, L_\text{in})\), a depthwise convolution with a depthwise multiplier K can be performed with the arguments \((C_\text{in}=C_\text{in}, C_\text{out}=C_\text{in} \times \text{K}, ..., \text{groups}=C_\text{in})\).
- Shape:
Input: \((N, C_\text{in}, L_\text{in})\) or \((C_\text{in}, L_\text{in})\)
Output: \((N, C_\text{out}, L_\text{out})\) or \((C_\text{out}, L_\text{out})\), where
\[L_\text{out} = \left\lfloor\frac{L_\text{in} + 2 \times \text{padding} - \text{dilation} \times (\text{kernel_size} - 1) - 1}{\text{stride}} + 1\right\rfloor\]
- weight#
the learnable weights of the module of shape \((\text{out_channels}, \frac{\text{in_channels}}{\text{groups}}, \text{kernel_size})\). The values of these weights are sampled from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{\text{groups}}{C_\text{in} * \text{kernel_size}}\)
- Type:
Tensor
- bias#
the learnable bias of the module of shape (out_channels). If bias is True, then the values of these weights are sampled from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{\text{groups}}{C_\text{in} * \text{kernel_size}}\)
- Type:
Tensor
- Parameters:
  - in_channels (int) – Number of channels in the input image
  - out_channels (int) – Number of channels produced by the convolution
  - kernel_size (int | tuple[int]) – Size of the convolving kernel
  - formats (QAffineFormats) – Number formats used during compute (addition and multiplication) and quantization functions for signals during forward and back propagation (I/O activations, weights, biases, and neural gradients)
  - stride (int | tuple[int]) – Stride of the convolution. Default: 1
  - padding (str | int | tuple[int]) – Padding added to both sides of the input. Default: 0
  - dilation (int | tuple[int]) – Spacing between kernel elements. Default: 1
  - groups (int) – Number of blocked connections from input channels to output channels. Default: 1
  - bias (bool) – If True, adds a learnable bias to the output. Default: True
  - padding_mode (str) – 'zeros', 'reflect', 'replicate' or 'circular'. Default: 'zeros'
- forward(input)[source]#
Describes the computations that get performed at every call of the module. The use of quantized elementary operations (i.e., additions and multiplications) in the FWD and BWD passes is controlled through the formats argument to the module constructor.
- Parameters:
  - input (Tensor) – the input tensor on which to perform the layer operations. Must adhere to the input shape requirements.
- Return type:
Tensor
- Returns:
the result of the 1D cross-correlation operation.
- quant_function(fwd_quant, bwd_quant)[source]#
Defines a straight-through estimator-like function (see Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation (https://arxiv.org/abs/1308.3432)) that applies potentially different quantization functions in the forward and backward passes through the input and output gradient signals, respectively.
- Parameters:
  - fwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply during the forward pass through the input signal
  - bwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply during the backward pass through the output gradient signal
- class mptorch.quant.modules.QConv2d(in_channels, out_channels, kernel_size, formats, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')[source]#
Bases: Conv2d
Applies a 2D convolution over an input signal composed of several input planes.
It is a subclass of torch.nn.Conv2d and allows one to specify if I/O signals should be quantized during inference & training (needed for instance in QAT and PTQ methods), but also the precision(s) to be used in internal GEMM computations (addition and multiplication, fused or not). This allows simulating the effect of custom precision during GEMM calls in the forward and backward pass (which are performed using the im2col and col2im algorithms implemented in the unfoldNd library) and is helpful in studying the effect of low precision compute during inference and training (not just data quantization).
In the simplest case, the output value of the layer with input size \((N, C_{\text{in}}, H, W)\) and output \((N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})\) can be precisely described as:
\[\text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) + \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)\]where \(\star\) is the valid 2D cross-correlation operator, \(N\) is a batch size, \(C\) denotes a number of channels, \(H\) is a height of input planes in pixels, and \(W\) is width in pixels.
- stride controls the stride for the cross-correlation, a single number or a tuple.
- padding controls the amount of padding applied to the input. It can be either a string {'valid', 'same'} or an int / a tuple of ints giving the amount of implicit padding applied on both sides.
- dilation controls the spacing between the kernel points; also known as the à trous algorithm. It is harder to describe, but this link has a nice visualization of what dilation does.
- groups controls the connections between inputs and outputs. in_channels and out_channels must both be divisible by groups. For example,
  - At groups=1, all inputs are convolved to all outputs.
  - At groups=2, the operation becomes equivalent to having two conv layers side by side, each seeing half the input channels and producing half the output channels, and both subsequently concatenated.
  - At groups=in_channels, each input channel is convolved with its own set of filters (of size \(\frac{\text{out_channels}}{\text{in_channels}}\)).
The parameters kernel_size, stride, padding, dilation can either be:
  - a single int – in which case the same value is used for the height and width dimension
  - a tuple of two ints – in which case, the first int is used for the height dimension, and the second int for the width dimension
Note
When groups == in_channels and out_channels == K * in_channels, where K is a positive integer, this operation is also known as a “depthwise convolution”.
In other words, for an input of size \((N, C_\text{in}, L_\text{in})\), a depthwise convolution with a depthwise multiplier K can be performed with the arguments \((C_\text{in}=C_\text{in}, C_\text{out}=C_\text{in} \times \text{K}, ..., \text{groups}=C_\text{in})\).
Note
padding='valid' is the same as no padding. padding='same' pads the input so the output has the same shape as the input. However, this mode doesn't support any stride values other than 1.
- Shape:
Input: \((N, C_\text{in}, H_\text{in}, W_\text{in})\) or \((C_\text{in}, H_\text{in}, W_\text{in})\)
Output: \((N, C_\text{out}, H_\text{out}, W_\text{out})\) or \((C_\text{out}, H_\text{out}, W_\text{out})\), where
\[H_\text{out} = \left\lfloor\frac{H_\text{in} + 2 \times \text{padding}[0] - \text{dilation}[0] \times (\text{kernel_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor\]\[W_\text{out} = \left\lfloor\frac{W_\text{in} + 2 \times \text{padding}[1] - \text{dilation}[1] \times (\text{kernel_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor\]
- weight#
the learnable weights of the module of shape \((\text{out_channels}, \frac{\text{in_channels}}{\text{groups}},\) \(\text{kernel_size[0]}, \text{kernel_size[1]})\). The values of these weights are sampled from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{\text{groups}}{C_\text{in} * \prod_{i=0}^{1}\text{kernel_size}[i]}\)
- Type:
Tensor
- bias#
the learnable bias of the module of shape (out_channels). If bias is True, then the values of these weights are sampled from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{\text{groups}}{C_\text{in} * \prod_{i=0}^{1}\text{kernel_size}[i]}\)
- Type:
Tensor
- Parameters:
  - in_channels (int) – Number of channels in the input image
  - out_channels (int) – Number of channels produced by the convolution
  - kernel_size (int | tuple[int, int]) – Size of the convolving kernel
  - formats (QAffineFormats) – Number formats used during compute (addition and multiplication) and quantization functions for signals during forward and back propagation (I/O activations, weights, biases, and neural gradients)
  - stride (int | tuple[int, int]) – Stride of the convolution. Default: 1
  - padding (str | int | tuple[int, int]) – Padding added to all four sides of the input. Default: 0
  - dilation (int | tuple[int, int]) – Spacing between kernel elements. Default: 1
  - groups (int) – Number of blocked connections from input channels to output channels. Default: 1
  - bias (bool) – If True, adds a learnable bias to the output. Default: True
  - padding_mode (str) – 'zeros', 'reflect', 'replicate' or 'circular'. Default: 'zeros'
- forward(input)[source]#
Describes the computations that get performed at every call of the module. The use of quantized elementary operations (i.e., additions and multiplications) in the FWD and BWD passes is controlled through the formats argument to the module constructor.
- Parameters:
  - input (Tensor) – the input tensor on which to perform the layer operations. Must adhere to the input shape requirements.
- Return type:
Tensor
- Returns:
the result of the 2D cross-correlation operation.
- quant_function(fwd_quant, bwd_quant)[source]#
Defines a straight-through estimator-like function (see Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation (https://arxiv.org/abs/1308.3432)) that applies potentially different quantization functions in the forward and backward passes through the input and output gradient signals, respectively.
- Parameters:
  - fwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply during the forward pass through the input signal
  - bwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply during the backward pass through the output gradient signal
- class mptorch.quant.modules.QConv3d(in_channels, out_channels, kernel_size, formats, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')[source]#
Bases: Conv3d
Applies a 3D convolution over an input signal composed of several input planes.
It is a subclass of torch.nn.Conv3d and allows one to specify if I/O signals should be quantized during inference & training (needed for instance in QAT and PTQ methods), but also the precision(s) to be used in internal GEMM computations (addition and multiplication, fused or not). This allows simulating the effect of custom precision during GEMM calls in the forward and backward pass (which are performed using the im2col and col2im algorithms implemented in the unfoldNd library) and is helpful in studying the effect of low precision compute during inference and training (not just data quantization).
In the simplest case, the output value of the layer with input size \((N, C_{\text{in}}, D, H, W)\) and output \((N, C_{\text{out}}, D_{\text{out}}, H_{\text{out}}, W_{\text{out}})\) can be precisely described as:
\[\text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) + \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)\]where \(\star\) is the valid 3D cross-correlation operator, \(N\) is a batch size, \(C\) denotes a number of channels, \(D\) is a depth of input planes in pixels, \(H\) is a height of input planes in pixels, and \(W\) is width in pixels.
- stride controls the stride for the cross-correlation, a single number or a tuple.
- padding controls the amount of padding applied to the input. It can be either a string {'valid', 'same'} or an int / a tuple of ints giving the amount of implicit padding applied on both sides.
- dilation controls the spacing between the kernel points; also known as the à trous algorithm. It is harder to describe, but this link has a nice visualization of what dilation does.
- groups controls the connections between inputs and outputs. in_channels and out_channels must both be divisible by groups. For example,
  - At groups=1, all inputs are convolved to all outputs.
  - At groups=2, the operation becomes equivalent to having two conv layers side by side, each seeing half the input channels and producing half the output channels, and both subsequently concatenated.
  - At groups=in_channels, each input channel is convolved with its own set of filters (of size \(\frac{\text{out_channels}}{\text{in_channels}}\)).
The parameters kernel_size, stride, padding, dilation can either be:
  - a single int – in which case the same value is used for the depth, height and width dimensions
  - a tuple of three ints – in which case, the first int is used for the depth dimension, the second int for the height dimension and the third int for the width dimension
Note
When groups == in_channels and out_channels == K * in_channels, where K is a positive integer, this operation is also known as a “depthwise convolution”.
In other words, for an input of size \((N, C_\text{in}, L_\text{in})\), a depthwise convolution with a depthwise multiplier K can be performed with the arguments \((C_\text{in}=C_\text{in}, C_\text{out}=C_\text{in} \times \text{K}, ..., \text{groups}=C_\text{in})\).
Note
padding='valid' is the same as no padding. padding='same' pads the input so the output has the same shape as the input. However, this mode doesn't support any stride values other than 1.
- Shape:
Input: \((N, C_\text{in}, D_\text{in}, H_\text{in}, W_\text{in})\) or \((C_\text{in}, D_\text{in}, H_\text{in}, W_\text{in})\)
Output: \((N, C_\text{out}, D_\text{out}, H_\text{out}, W_\text{out})\) or \((C_\text{out}, D_\text{out}, H_\text{out}, W_\text{out})\), where
\[D_\text{out} = \left\lfloor\frac{D_\text{in} + 2 \times \text{padding}[0] - \text{dilation}[0] \times (\text{kernel_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor\]\[H_\text{out} = \left\lfloor\frac{H_\text{in} + 2 \times \text{padding}[1] - \text{dilation}[1] \times (\text{kernel_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor\]\[W_\text{out} = \left\lfloor\frac{W_\text{in} + 2 \times \text{padding}[2] - \text{dilation}[2] \times (\text{kernel_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor\]
- weight#
the learnable weights of the module of shape \((\text{out_channels}, \frac{\text{in_channels}}{\text{groups}},\) \(\text{kernel_size[0]}, \text{kernel_size[1]}, \text{kernel_size[2]})\). The values of these weights are sampled from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{\text{groups}}{C_\text{in} * \prod_{i=0}^{2}\text{kernel_size}[i]}\)
- Type:
Tensor
- bias#
the learnable bias of the module of shape (out_channels). If bias is True, then the values of these weights are sampled from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{\text{groups}}{C_\text{in} * \prod_{i=0}^{2}\text{kernel_size}[i]}\)
- Type:
Tensor
- Parameters:
  - in_channels (int) – Number of channels in the input image
  - out_channels (int) – Number of channels produced by the convolution
  - kernel_size (int | tuple[int, int, int]) – Size of the convolving kernel
  - formats (QAffineFormats) – Number formats used during compute (addition and multiplication) and quantization functions for signals during forward and back propagation (I/O activations, weights, biases, and neural gradients)
  - stride (int | tuple[int, int, int]) – Stride of the convolution. Default: 1
  - padding (str | int | tuple[int, int, int]) – Padding added to all six sides of the input. Default: 0
  - dilation (int | tuple[int, int, int]) – Spacing between kernel elements. Default: 1
  - groups (int) – Number of blocked connections from input channels to output channels. Default: 1
  - bias (bool) – If True, adds a learnable bias to the output. Default: True
  - padding_mode (str) – 'zeros', 'reflect', 'replicate' or 'circular'. Default: 'zeros'
- forward(input)[source]#
Describes the computations that get performed at every call of the module. The use of quantized elementary operations (i.e., additions and multiplications) in the FWD and BWD passes is controlled through the formats argument to the module constructor.
- Parameters:
  - input (Tensor) – the input tensor on which to perform the layer operations. Must adhere to the input shape requirements.
- Return type:
Tensor
- Returns:
the result of the 3D cross-correlation operation.
- quant_function(fwd_quant, bwd_quant)[source]#
Defines a straight-through estimator-like function (see Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation (https://arxiv.org/abs/1308.3432)) that applies potentially different quantization functions in the forward and backward passes through the input and output gradient signals, respectively.
- Parameters:
  - fwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply during the forward pass through the input signal
  - bwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply during the backward pass through the output gradient signal
- class mptorch.quant.modules.QConvTranspose1d(in_channels, out_channels, kernel_size, formats, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros', device=None, dtype=None)[source]#
Bases: ConvTranspose1d
Applies a 1D transposed convolution operator over an input image composed of several input planes.
This module can be seen as the gradient of Conv1d with respect to its input. It is also known as a fractionally-strided convolution or a deconvolution (although it is not an actual deconvolution operation as it does not compute a true inverse of convolution). For more information, see the visualizations here and the Deconvolutional Networks paper.
It is a subclass of torch.nn.ConvTranspose1d and allows one to specify if I/O signals should be quantized during inference & training (needed for instance in QAT and PTQ methods), but also the precision(s) to be used in internal GEMM computations (addition and multiplication, fused or not). This allows simulating the effect of custom precision during GEMM calls in the forward and backward pass (which are performed using the im2col and col2im algorithms implemented in the unfoldNd library) and is helpful in studying the effect of low precision compute during inference and training (not just data quantization).
- stride controls the stride for the cross-correlation.
- padding controls the amount of implicit zero padding on both sides for dilation * (kernel_size - 1) - padding number of points. See note below for details.
- output_padding controls the additional size added to one side of the output shape. See note below for details.
- dilation controls the spacing between the kernel points; also known as the à trous algorithm. It is harder to describe, but the link here has a nice visualization of what dilation does.
- in_channels and out_channels must both be divisible by groups. For example,
  - At groups=1, all inputs are convolved to all outputs.
  - At groups=2, the operation becomes equivalent to having two conv layers side by side, each seeing half the input channels and producing half the output channels, and both subsequently concatenated.
  - At groups=in_channels, each input channel is convolved with its own set of filters (of size \(\frac{\text{out_channels}}{\text{in_channels}}\)).
Note
The padding argument effectively adds dilation * (kernel_size - 1) - padding amount of zero padding to both sides of the input. This is set so that when a torch.nn.Conv1d and a torch.nn.ConvTranspose1d are initialized with the same parameters, they are inverses of each other in regard to the input and output shapes. However, when stride > 1, torch.nn.Conv1d maps multiple input shapes to the same output shape. output_padding is provided to resolve this ambiguity by effectively increasing the calculated output shape on one side. Note that output_padding is only used to find the output shape, but does not actually add zero-padding to the output.
- Shape:
Input: \((N, C_\text{in}, L_\text{in})\) or \((C_\text{in}, L_\text{in})\)
Output: \((N, C_\text{out}, L_\text{out})\) or \((C_\text{out}, L_\text{out})\), where
\[L_\text{out} = (L_\text{in} - 1) \times \text{stride} - 2 \times \text{padding} + \text{dilation} \times (\text{kernel_size} - 1) + \text{output_padding} + 1\]
- weight#
the learnable weights of the module of shape \((\text{in_channels}, \frac{\text{out_channels}}{\text{groups}},\) \(\text{kernel_size})\). The values of these weights are sampled from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{\text{groups}}{C_\text{out} * \text{kernel_size}}\)
- Type:
Tensor
- bias#
the learnable bias of the module of shape (out_channels). If bias is True, then the values of these weights are sampled from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{\text{groups}}{C_\text{out} * \text{kernel_size}}\)
- Type:
Tensor
- Parameters:
  - in_channels (int) – Number of channels in the input image
  - out_channels (int) – Number of channels produced by the convolution
  - kernel_size (int | tuple[int]) – Size of the convolving kernel
  - formats (QAffineFormats) – Number formats used during compute (addition and multiplication) and quantization functions for signals during forward and back propagation (I/O activations, weights, biases, and neural gradients)
  - stride (int | tuple[int]) – Stride of the convolution. Default: 1
  - padding (int | tuple[int]) – dilation * (kernel_size - 1) - padding zero-padding will be added to both sides of the input. Default: 0
  - output_padding (int | tuple[int]) – Additional size added to one side of the output shape. Default: 0
  - groups (int) – Number of blocked connections from input channels to output channels. Default: 1
  - bias (bool) – If True, adds a learnable bias to the output. Default: True
  - dilation (int | tuple[int]) – Spacing between kernel elements. Default: 1
- forward(input, output_size=None)[source]#
Describes the computations that get performed at every call of the module. The use of quantized elementary operations (i.e., additions and multiplications) in the FWD and BWD passes is controlled through the formats argument to the module constructor.
- Parameters:
  - input (Tensor) – the input tensor on which to perform the layer operations. Must adhere to the input shape requirements.
  - output_size (list[int] | None) – specifies how the output should be padded
- Return type:
Tensor
- Returns:
the result of the 1D transposed convolution operation.
- quant_function(fwd_quant, bwd_quant)[source]#
Defines a straight-through estimator-like function (see Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation (https://arxiv.org/abs/1308.3432)) that applies potentially different quantization functions in the forward and backward passes through the input and output gradient signals, respectively.
- Parameters:
  - fwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply during the forward pass through the input signal
  - bwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply during the backward pass through the output gradient signal
- class mptorch.quant.modules.QConvTranspose2d(in_channels, out_channels, kernel_size, formats, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros', device=None, dtype=None)[source]#
Bases: ConvTranspose2d
Applies a 2D transposed convolution operator over an input image composed of several input planes.
This module can be seen as the gradient of Conv2d with respect to its input. It is also known as a fractionally-strided convolution or a deconvolution (although it is not an actual deconvolution operation as it does not compute a true inverse of convolution). For more information, see the visualizations here and the Deconvolutional Networks paper.
It is a subclass of torch.nn.ConvTranspose2d and allows one to specify if I/O signals should be quantized during inference & training (needed for instance in QAT and PTQ methods), but also the precision(s) to be used in internal GEMM computations (addition and multiplication, fused or not). This allows simulating the effect of custom precision during GEMM calls in the forward and backward pass (which are performed using the im2col and col2im algorithms implemented in the unfoldNd library) and is helpful in studying the effect of low precision compute during inference and training (not just data quantization).
- stride controls the stride for the cross-correlation.
- padding controls the amount of implicit zero padding on both sides for dilation * (kernel_size - 1) - padding number of points. See note below for details.
- output_padding controls the additional size added to one side of the output shape. See note below for details.
- dilation controls the spacing between the kernel points; also known as the à trous algorithm. It is harder to describe, but the link here has a nice visualization of what dilation does.
- in_channels and out_channels must both be divisible by groups. For example,
  - At groups=1, all inputs are convolved to all outputs.
  - At groups=2, the operation becomes equivalent to having two conv layers side by side, each seeing half the input channels and producing half the output channels, and both subsequently concatenated.
  - At groups=in_channels, each input channel is convolved with its own set of filters (of size \(\frac{\text{out_channels}}{\text{in_channels}}\)).
The parameters kernel_size, stride, padding, output_padding can either be:
  - a single int – in which case the same value is used for the height and width dimensions
  - a tuple of two ints – in which case, the first int is used for the height dimension, and the second int for the width dimension
Note
The padding argument effectively adds dilation * (kernel_size - 1) - padding amount of zero padding to both sides of the input. This is set so that when a torch.nn.Conv2d and a torch.nn.ConvTranspose2d are initialized with the same parameters, they are inverses of each other in regard to the input and output shapes. However, when stride > 1, torch.nn.Conv2d maps multiple input shapes to the same output shape. output_padding is provided to resolve this ambiguity by effectively increasing the calculated output shape on one side. Note that output_padding is only used to find the output shape, but does not actually add zero-padding to the output.
- Shape:
Input: \((N, C_\text{in}, H_\text{in}, W_\text{in})\) or \((C_\text{in}, H_\text{in}, W_\text{in})\)
Output: \((N, C_\text{out}, H_\text{out}, W_\text{out})\) or \((C_\text{out}, H_\text{out}, W_\text{out})\), where
\[H_\text{out} = (H_\text{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0] \times (\text{kernel_size}[0] - 1) + \text{output_padding}[0] + 1\]\[W_\text{out} = (W_\text{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{dilation}[1] \times (\text{kernel_size}[1] - 1) + \text{output_padding}[1] + 1\]
- weight#
the learnable weights of the module of shape \((\text{in_channels}, \frac{\text{out_channels}}{\text{groups}},\) \(\text{kernel_size[0]}, \text{kernel_size[1]})\). The values of these weights are sampled from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{\text{groups}}{C_\text{out} * \prod_{i=0}^{1}\text{kernel_size}[i]}\)
- Type:
Tensor
- bias#
the learnable bias of the module of shape (out_channels). If bias is True, then the values of these weights are sampled from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{\text{groups}}{C_\text{out} * \prod_{i=0}^{1}\text{kernel_size}[i]}\)
- Type:
Tensor
- Parameters:
  - in_channels (int) – Number of channels in the input image
  - out_channels (int) – Number of channels produced by the convolution
  - kernel_size (int | tuple[int, int]) – Size of the convolving kernel
  - formats (QAffineFormats) – Number formats used during compute (addition and multiplication) and quantization functions for signals during forward and back propagation (I/O activations, weights, biases, and neural gradients)
  - stride (int | tuple[int, int]) – Stride of the convolution. Default: 1
  - padding (int | tuple[int, int]) – dilation * (kernel_size - 1) - padding zero-padding will be added to both sides of the input. Default: 0
  - output_padding (int | tuple[int, int]) – Additional size added to one side of the output shape. Default: 0
  - groups (int) – Number of blocked connections from input channels to output channels. Default: 1
  - bias (bool) – If True, adds a learnable bias to the output. Default: True
  - dilation (int | tuple[int, int]) – Spacing between kernel elements. Default: 1
- forward(input, output_size=None)[source]#
Describes the computations that get performed at every call of the module. The use of quantized elementary operations (i.e., additions and multiplications) in the FWD and BWD passes is controlled through the formats argument to the module constructor.
- Parameters:
  - input (Tensor) – the input tensor on which to perform the layer operations. Must adhere to the input shape requirements.
  - output_size (list[int] | None) – specifies how the output should be padded
- Return type:
Tensor
- Returns:
the result of the 2D transposed convolution operation.
- quant_function(fwd_quant, bwd_quant)[source]#
Defines a straight-through estimator-like function (see Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation (https://arxiv.org/abs/1308.3432)) that applies potentially different quantization functions in the forward and backward passes through the input and output gradient signals, respectively.
- Parameters:
  - fwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply during the forward pass through the input signal
  - bwd_quant (Callable[[Tensor], Tensor]) – the quantization function to apply during the backward pass through the output gradient signal
- class mptorch.quant.modules.QConvTranspose3d(in_channels, out_channels, kernel_size, formats, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros', device=None, dtype=None)[source]#
Bases:
ConvTranspose3d
Applies a 3D transposed convolution operator over an input image composed of several input planes. The transposed convolution operator multiplies each input value element-wise by a learnable kernel, and sums over the outputs from all input feature planes.
This module can be seen as the gradient of Conv3d with respect to its input. It is also known as a fractionally-strided convolution or a deconvolution (although it is not an actual deconvolution operation as it does not compute a true inverse of convolution). For more information, see the visualizations here and the Deconvolutional Networks paper.
It is a subclass of
torch.nn.ConvTranspose3d
and allows one to specify if I/O signals should be quantized during inference & training (needed for instance in QAT and PTQ methods), but also the precision(s) to be used in internal GEMM computations (addition and multiplication, fused or not). This allows simulating the effect of custom precision during GEMM calls in the forward and backward pass (which are performed using the im2col and col2im algorithms implemented in the unfoldNd library) and is helpful in studying the effect of low precision compute during inference and training (not just data quantization).stride
controls the stride for the cross-correlation.padding
controls the amount of implicit zero padding on both sides fordilation * (kernel_size - 1) - padding
number of points. See note below for details.output_padding
controls the additional size added to one side of the output shape. See note below for details.dilation
controls the spacing between the kernel points; also known as the à trous algorithm. It is harder to describe, but the link here has a nice visualization of whatdilation
does.in_channels
andout_channels
must both be divisible bygroups
. For example,At groups=1, all inputs are convolved to all outputs.
At groups=2, the operation becomes equivalent to having two conv layers side by side, each seeing half the input channels and producing half the output channels, and both subsequently concatenated.
At groups=
in_channels
, each input channel is convolved with its own set of filters (of size \(\frac{\text{out_channels}}{\text{in_channels}}\)).
The parameters
kernel_size
,stride
,padding
,output_padding
can either be:a single
int
– in which case the same value is used for the depth, height and width dimensionsa
tuple
of three ints – in which case, the first int is used for the depth dimension, the second int for the height dimension and the third int for the width dimension
Note
The
padding
argument effectively addsdilation * (kernel_size - 1) - padding
amount of zero padding to both sizes of the input. This is set so that when atorch.nn.Conv3d
and atorch.nn.ConvTranspose3d
are initialized with same parameters, they are inverses of each other in regard to the input and output shapes. However, whenstride > 1
,torch.nn.Conv3d
maps multiple input shapes to the same output shape.output_padding
is provided to resolve this ambiguity by effectively increasing the calculated output shape on one side. Note that output_padding
is only used to find output shape, but does not actually add zero-padding to output.
- Shape:
Input: \((N, C_\text{in}, D_\text{in}, H_\text{in}, W_\text{in})\) or \((C_\text{in}, D_\text{in}, H_\text{in}, W_\text{in})\)
Output: \((N, C_\text{out}, D_\text{out}, H_\text{out}, W_\text{out})\) or \((C_\text{out}, D_\text{out}, H_\text{out}, W_\text{out})\), where
\[D_\text{out} = (D_\text{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0] \times (\text{kernel_size}[0] - 1) + \text{output_padding}[0] + 1\]\[H_\text{out} = (H_\text{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{dilation}[1] \times (\text{kernel_size}[1] - 1) + \text{output_padding}[1] + 1\]\[W_\text{out} = (W_\text{in} - 1) \times \text{stride}[2] - 2 \times \text{padding}[2] + \text{dilation}[2] \times (\text{kernel_size}[2] - 1) + \text{output_padding}[2] + 1\]
- weight#
the learnable weights of the module of shape \((\text{in_channels}, \frac{\text{out_channels}}{\text{groups}},\) \(\text{kernel_size[0]}, \text{kernel_size[1]}, \text{kernel_size[2]})\). The values of these weights are sampled from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{\text{groups}}{C_\text{out} * \prod_{i=0}^{2}\text{kernel_size}[i]}\)
- Type:
Tensor
- bias#
the learnable bias of the module of shape (out_channels) If
bias
isTrue
, then the values of these weights are sampled from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{\text{groups}}{C_\text{out} * \prod_{i=0}^{2}\text{kernel_size}[i]}\)- Type:
Tensor
- Parameters:
in_channels (
int
) – Number of channels in the input imageout_channels (
int
) – Number of channels produced by the convolutionkernel_size (
int
|tuple
[int
,int
,int
]) – Size of the convolving kernelformats (
QAffineFormats
) – Number formats used during compute (addition and multiplication) and quantization functions for signals during forward and back propagation (I/O activations, weights, biases, and neural gradients)stride (
int
|tuple
[int
,int
,int
]) – Stride of the convolution. Default: 1padding (
int
|tuple
[int
,int
,int
]) –dilation * (kernel_size - 1) - padding
zero-padding will be added to both sides of the input. Default: 0output_padding (
int
|tuple
[int
,int
,int
]) – Additional size added to one side of the output shape. Default: 0groups (
int
) – Number of blocked connections from input channels to output channels. Default: 1bias (
bool
) – IfTrue
, adds a learnable bias to the output. Default:True
dilation (
int
|tuple
[int
,int
,int
]) – Spacing between kernel elements. Default: 1
- forward(input, output_size=None)[source]#
Describes the computations that get performed at every call of the module. The use of quantized elementary operations (i.e., additions and multiplications) in the FWD and BWD passes are controlled through the
formats
argument to the module constructor.- Parameters:
input (
Tensor
) – the input tensor on which to perform the layer operations. Must adhere to the input shape requirements.output_size (
list
[int
] |None
) – specifies how the output should be padded
- Return type:
Tensor
- Returns:
the result of the 3D transposed convolution operation.
- quant_function(fwd_quant, bwd_quant)[source]#
Defines a straight-through estimator-like function (see Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation (https://arxiv.org/abs/1308.3432)) that applies potentially different quantization functions in the forward and backward passes through the input and output gradient signals, respectively.
- Parameters:
fwd_quant (
Callable
[[Tensor
],Tensor
]) – the quantization function to apply during the forward pass through the input signalbwd_quant (
Callable
[[Tensor
],Tensor
]) – the quantization function to apply during the backward pass through the output gradient signal
- class mptorch.quant.modules.QAvgPool2d(kernel_size, fwd_quant, bwd_quant, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None)[source]#
Bases:
AvgPool2d
Applies a 2D average pooling over an input signal composed of several input planes. Performs the addition operations in the FWD and BWD passes using quantized operators given as parameters.
In the simplest case, the output value of the layer with input size \((N, C, H, W)\), output \((N, C, H_\text{out}, W_\text{out})\) and
kernel_size
\((kH, kW)\) can be precisely described as:\[\text{out}(N_i, C_j, h, w) = \frac{1}{kH * kW} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1} \text{input}(N_i, C_j, \text{stride}[0] \times h + m, \text{stride}[1] \times w + n)\]If
padding
is non-zero, then the input is implicitly zero-padded on both sides forpadding
number of points.Note
When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding or the input. Sliding windows that would start in the right padded region are ignored.
The parameters
kernel_size
,stride
,padding
can either be:a single
int
– in which case the same value is used for the height and width dimensiona
tuple
of two ints – in which case, the first int is used for the height dimension, and the second int for the width dimension
- Shape:
Input: \((N, C, H_\text{in}, W_\text{in})\) or \((C, H_\text{in}, W_\text{in})\).
Output: \((N, C, H_\text{out}, W_\text{out})\) or \((C, H_\text{out}, W_\text{out})\), where
\[H_\text{out} = \left\lfloor\frac{H_\text{in} + 2 \times \text{padding}[0] - \text{kernel_size}[0]}{\text{stride}[0]} + 1\right\rfloor\]\[W_\text{out} = \left\lfloor\frac{W_\text{in} + 2 \times \text{padding}[1] - \text{kernel_size}[1]}{\text{stride}[1]} + 1\right\rfloor\]Per the note above, if
ceil_mode
is True and \((H_\text{out} - 1)\times \text{stride}[0]\geq H_\text{in} + \text{padding}[0]\), we skip the last window as it would start in the bottom padded region, resulting in \(H_\text{out}\) being reduced by one.The same applies for \(W_\text{out}\).
- Parameters:
kernel_size (
int
|tuple
[int
,int
]) – the size of the windowfwd_quant (
Callable
[[Tensor
],Tensor
]) – quantization function to use during FWD addition operationsbwd_quant (
Callable
[[Tensor
],Tensor
]) – quantization function to use during BWD addition operationsstride (
int
|tuple
[int
,int
] |None
) – the stride of the window. Default value iskernel_size
padding (
int
|tuple
[int
,int
]) – implicit zero padding to be added on both sidesceil_mode (
bool
) – when True, will use ceil instead of floor to compute the output shapecount_include_pad (
bool
) – when True, will include the zero-padding in the averaging calculationdivisor_override (
int
|None
) – if specified, it will be used as divisor, otherwise size of the pooling region will be used.
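Example
A minimal sketch of quantized average pooling; the float_quantize helper is used as in the qgelu example further below, and the layer is imported from its documented module path.
import torch
import mptorch.quant as qt
from mptorch.quant.modules import QAvgPool2d

# quantize the FWD/BWD accumulation results to a bfloat16-like format (illustrative choice)
quant = lambda x: qt.float_quantize(x, man=7, exp=8, rounding="nearest")

pool = QAvgPool2d(kernel_size=2, fwd_quant=quant, bwd_quant=quant)
x = torch.randn(1, 3, 8, 8, requires_grad=True)
y = pool(x)         # shape: (1, 3, 4, 4)
y.sum().backward()  # the gradient additions go through bwd_quant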
- class mptorch.quant.modules.QBatchNorm1d(num_features, fwd_quant, bwd_quant)[source]#
Bases:
QBatchNorm
Applies Batch Normalization over a 2D input.
Method described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift .
\[y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta\]The mean and standard-deviation are calculated per-dimension over the mini-batches and \(\gamma\) and \(\beta\) are learnable parameter vectors of size C (where C is the number of features or channels of the input). By default, the elements of \(\gamma\) are set to 1 and the elements of \(\beta\) are set to 0. At train time in the forward pass, the variance is calculated via the biased estimator, equivalent to
torch.var(input, unbiased=False)
. However, the value stored in the moving average of the variance is calculated via the unbiased estimator, equivalent totorch.var(input, unbiased=True)
.Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a default momentum of 0.1.
Note
This momentum is different from one used in optimizer classes and the conventional notion of momentum. Mathematically, the update rule for running statistics here is \(\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t\), where \(\hat{x}\) is the estimated statistic and \(x_t\) is the new observed value.
- Shape:
Input: \((N, C)\) or \((N, C, L)\), where \(N\) is the batch size, \(C\) is the number of features or channels, and \(L\) is the sequence length
Output: \((N, C)\) or \((N, C, L)\) (same shape as input)
Because the Batch Normalization is done over the C dimension, computing statistics on (N, L) slices, it’s common terminology to call this Temporal Batch Normalization.
- Parameters:
num_features (
int
) – number of features or channels \(C\) of the inputfwd_quant (
Callable
[[Tensor
],Tensor
]) – quantization function to use during FWD operationsbwd_quant (
Callable
[[Tensor
],Tensor
]) – quantization function to use during BWD operations
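Example
A minimal sketch of quantized batch normalization over a 2D input, reusing the float_quantize helper from the qgelu example further below.
import torch
import mptorch.quant as qt
from mptorch.quant.modules import QBatchNorm1d

# perform the normalization arithmetic in a binary16-like format (illustrative choice)
quant = lambda x: qt.float_quantize(x, man=10, exp=5, rounding="nearest")

bn = QBatchNorm1d(num_features=100, fwd_quant=quant, bwd_quant=quant)
x = torch.randn(20, 100)
y = bn(x)  # same shape as the input: (20, 100)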
- class mptorch.quant.modules.QBatchNorm2d(num_features, fwd_quant, bwd_quant)[source]#
Bases:
QBatchNorm
Applies Batch Normalization over a 4D input.
4D is a mini-batch of 2D inputs with additional channel dimension. Method described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift .
\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta\]The mean and standard-deviation are calculated per-dimension over the mini-batches and \(\gamma\) and \(\beta\) are learnable parameter vectors of size C (where C is the input size). By default, the elements of \(\gamma\) are set to 1 and the elements of \(\beta\) are set to 0. At train time in the forward pass, the variance is calculated via the biased estimator, equivalent to
torch.var(input, unbiased=False)
. However, the value stored in the moving average of the variance is calculated via the unbiased estimator, equivalent totorch.var(input, unbiased=True)
.Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a default momentum of 0.1.
Note
This momentum is different from one used in optimizer classes and the conventional notion of momentum. Mathematically, the update rule for running statistics here is \(\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t\), where \(\hat{x}\) is the estimated statistic and \(x_t\) is the new observed value.
Because the Batch Normalization is done over the C dimension, computing statistics on (N, H, W) slices, it’s common terminology to call this Spatial Batch Normalization.
- Shape:
Input: \((N, C, H, W)\)
Output: \((N, C, H, W)\) (same shape as input)
- Parameters:
num_features (
int
) – \(C\) from an expected input of size \((N, C, H, W)\)fwd_quant (
Callable
[[Tensor
],Tensor
]) – quantization function to use during FWD operationsbwd_quant (
Callable
[[Tensor
],Tensor
]) – quantization function to use during BWD operations
- class mptorch.quant.modules.QLayerNorm(normalized_shape, formats, eps=1e-05, elementwise_affine=True, bias=True)[source]#
Bases:
LayerNorm
Applies Layer Normalization over a mini-batch of inputs.
This layer implements the operation as described in the paper Layer Normalization
\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta\]The mean and standard-deviation are calculated over the last D dimensions, where D is the dimension of
normalized_shape
. For example, ifnormalized_shape
is(3, 5)
(a 2-dimensional shape), the mean and standard-deviation are computed over the last 2 dimensions of the input (i.e.input.mean((-2, -1))
). \(\gamma\) and \(\beta\) are learnable affine transform parameters ofnormalized_shape
ifelementwise_affine
isTrue
. The variance is calculated via the biased estimator, equivalent totorch.var(input, unbiased=False)
.Note
Unlike Batch Normalization and Instance Normalization, which applies scalar scale and bias for each entire channel/plane with the
affine
option, Layer Normalization applies per-element scale and bias withelementwise_affine
.This layer uses statistics computed from input data in both training and evaluation modes.
- weight#
the learnable weights of the module of shape \(\text{normalized_shape}\) when
elementwise_affine
is set toTrue
. The values are initialized to 1.- Type:
Tensor
- bias#
the learnable bias of the module of shape \(\text{normalized_shape}\) when
elementwise_affine
is set toTrue
. The values are initialized to 0.- Type:
Tensor
- Shape:
Input: \((N, *)\)
Output: \((N, *)\) (same shape as input)
- Parameters:
normalized_shape (int or list or torch.Size) –
input shape from an expected input of size
\[[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]\]If a single integer is used, it is treated as a singleton list, and this module will normalize over the last dimension which is expected to be of that specific size.
eps (
float
) – a value added to the denominator for numerical stability. Default: 1e-5elementwise_affine (
bool
) – a boolean value that when set toTrue
, this module has learnable per-element affine parameters initialized to ones (for weights) and zeros (for biases). Default:True
.bias (
bool
) – If set toFalse
, the layer will not learn an additive bias (only relevant ifelementwise_affine
isTrue
). Default:True
.
- forward(input)[source]#
Performs the layernorm operation on the input tensor, using quantized elementary operations (e.g. additions and multiplications) in the FWD and BWD passes, as specified through the
formats
argument to the module constructor.- Parameters:
input (
Tensor
) – the input tensor over which to perform the layernorm operations.- Return type:
Tensor
- Returns:
the output of the layernorm operation
- quant_function(fwd_quant, bwd_quant)[source]#
Defines a straight-through estimator-like function (see Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation (https://arxiv.org/abs/1308.3432)) that applies potentially different quantization functions in the forward and backward passes through the input and output gradient signals, respectively.
- Parameters:
fwd_quant (
Callable
[[Tensor
],Tensor
]) – the quantization function to apply during the forward pass through the input signalbwd_quant (
Callable
[[Tensor
],Tensor
]) – the quantization function to apply during the backward pass through the output gradient signal
- class mptorch.quant.modules.QSoftmax(dim, formats)[source]#
Bases:
Softmax
A quantized implementation of the Softmax activation function.
This class extends PyTorch’s
torch.nn.Softmax
and allows one to specify if I/O signals and internal computations should be quantized during inference & training. This allows simulating the effect of custom precision in the internal forward and backward pass computations and is helpful in studying the effect of low precision compute during inference and training (not just data quantization).- Parameters:
dim (
int
) – The dimension along which Softmax will be applied.formats (
QSoftmaxFormats
) – Number formats specification during compute and quantization functions for signals during forward and back propagation (I/O activations and gradients).
- forward(input)[source]#
Performs the softmax operation on the input tensor. The use of quantized elementary operations (i.e., additions and multiplications) in the FWD and BWD passes are controlled through the
formats
argument to the module constructor.- Parameters:
input – the input tensor over which to perform the softmax operations.
- Return type:
Tensor
- Returns:
the result of the softmax operation.
- quant_function(fwd_quant, bwd_quant)[source]#
Defines a straight-through estimator-like function (see Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation (https://arxiv.org/abs/1308.3432)) that applies potentially different quantization functions in the forward and backward passes through the input and output gradient signals, respectively.
- Parameters:
fwd_quant (
Callable
[[Tensor
],Tensor
]) – the quantization function to apply during the forward pass through the input signalbwd_quant (
Callable
[[Tensor
],Tensor
]) – the quantization function to apply during the backward pass through the output gradient signal
- class mptorch.quant.modules.QGELU(formats, approximate='none')[source]#
Bases:
GELU
Applies the Gaussian Error Linear Units function to the input \(x\):
\[\text{GELU}(x) = x * \Phi(x)\]where \(\Phi(x)\) is the Cumulative Distribution Function for Gaussian Distribution.
When the approximate argument is
'tanh'
, GELU is estimated with:\[\text{GELU}(x) = 0.5 * x * (1 + \tanh(\sqrt{2 / \pi} * (x + 0.044715 * x^3)))\]- Parameters:
x – the input tensor
formats (
QGELUFormats
) – configuration class for number formats and quantizers to use during forward and backward computations in GELUapproximate (
Literal
['none'
,'tanh'
]) – the GELU approximation algorithm to use:'none'
|'tanh'
. Default:'none'
- forward(input)[source]#
Applies the GELU function on the input tensor element-wise.
- Parameters:
input (
Tensor
) – the input tensor over which the GELU function is applied- Return type:
Tensor
- Returns:
the result of applying the GELU function
- quant_function(fwd_quant, bwd_quant)[source]#
Defines a straight-through estimator-like function (see Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation (https://arxiv.org/abs/1308.3432)) that applies potentially different quantization functions in the forward and backward passes through the input and output gradient signals, respectively.
- Parameters:
fwd_quant (
Callable
[[Tensor
],Tensor
]) – the quantization function to apply during the forward pass through the input signalbwd_quant (
Callable
[[Tensor
],Tensor
]) – the quantization function to apply during the backward pass through the output gradient signal
Functional#
- mptorch.quant.functional.qlinear(input, weight, bias=None, formats=QAffineFormats(default_fwd, default_bwd, rbits_add=0, rbits_mul=0))[source]#
Applies a linear transformation to the incoming data: \(y=xW^T + b\)
The
formats
parameter allows one to specify if I/O signals should be quantized during inference & training (needed for instance in QAT and PTQ methods), but also the precision(s) to be used in internal GEMM computations (addition and multiplication, fused or not). This allows simulating the effect of custom precision during GEMM calls in the forward and backward pass and is helpful in studying the effect of low precision compute during inference and training (not just data quantization).This is the functional version of
QLinear
.- Parameters:
input (
Tensor
) – the input \(x\) to the linear layer of the form \((*, H_\text{in})\), where \(*\) means any number of dimensions including none and \(H_{in} = \text{in_features}\)weight (
Tensor
) – the weight tensor \(W\) of shape \((\text{out_features}, \text{in_features})\)bias (
Tensor
|None
) – optional bias term of shape \((\text{out_features})\)formats (
QAffineFormats
) – the configuration object for how quantization (if any!) should be handled on the matrix inputs and how the MAC and summation operations should be performed (e.g. using compensated algorithms or not)
- Return type:
Tensor
- Returns:
the result of the affine operation \(xW^T + b\)
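Example
A minimal sketch using the default formats argument; import conventions follow the qgelu example further below. In practice one would pass a QAffineFormats instance configured for the target precision (see the cuBLAS acceleration example at the end of this page for one way to build such an object).
import torch
import mptorch.quant as qt

x = torch.randn(8, 32)   # (*, in_features)
w = torch.randn(64, 32)  # (out_features, in_features)
b = torch.randn(64)      # (out_features,)

# uses the default QAffineFormats; pass a custom formats object to control precision
y = qt.functional.qlinear(x, w, b)  # shape: (8, 64)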
- mptorch.quant.functional.qmatmul(input, other, formats=QAffineFormats(default_fwd, default_bwd, rbits_add=0, rbits_mul=0))[source]#
Simulates a mixed-precision computation pipeline for (batched) matrix multiplication of the tensors
input
andother
.The behavior depends on the dimensionality of the tensors as follows:
If both tensors are 1-dimensional, the dot product (scalar) is returned.
If both arguments are 2-dimensional, the matrix-matrix product is returned.
If the first argument is 1-dimensional and the second argument is 2-dimensional, a 1 is prepended to its dimension for the purpose of the matrix multiply. After the matrix multiply, the prepended dimension is removed.
If the first argument is 2-dimensional and the second argument is 1-dimensional, the matrix-vector product is returned.
If both arguments are at least 1-dimensional and at least one argument is N-dimensional (where 2 < N < 5), then a batched matrix multiply is returned. If the first argument is 1-dimensional, a 1 is prepended to its dimension for the purpose of the batched matrix multiply and removed after. If the second argument is 1-dimensional, a 1 is appended to its dimension for the purpose of the batched matrix multiply and removed after. The non-matrix (i.e. batch) dimensions are broadcasted (and thus must be broadcastable in the PyTorch sense). For example, if
input
is a \((j \times 1 \times n \times n)\) tensor andother
is a \((k \times n \times n)\) tensor,out
will be a \((j \times k \times n \times n)\) tensor.Note that the broadcasting logic only looks at the batch dimensions when determining if the inputs are broadcastable, and not the matrix dimensions. For example, if
input
is a \((j \times 1 \times n \times m)\) tensor andother
is a \((k \times m \times p)\) tensor, these inputs are valid for broadcasting even though the final two dimensions (i.e. the matrix dimensions) are different.out
will be a \((j \times k \times n \times p)\) tensor.
- Parameters:
input (
Tensor
) – the first tensor to be multipliedother (
Tensor
) – the second tensor to be multipliedformats (
QAffineFormats
) – the configuration object for how quantization (if any!) should be handled on the tensor inputs and how the MAC and summation operations should be performed (e.g. using compensated algorithms or not)
- Return type:
Tensor
- Returns:
the result of the (batched) matrix multiplication between
input
andother
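Example
A short sketch of the broadcasting behavior described above, again with the default formats argument and the import conventions of the qgelu example further below.
import torch
import mptorch.quant as qt

a = torch.randn(10, 1, 3, 4)  # batch dims (10, 1) broadcast against (5,)
b = torch.randn(5, 4, 2)
y = qt.functional.qmatmul(a, b)  # shape: (10, 5, 3, 2)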
- mptorch.quant.functional.qmm(input, mat2, formats=QAffineFormats(default_fwd, default_bwd, rbits_add=0, rbits_mul=0))[source]#
Simulates a mixed-precision computation pipeline for matrix multiplication of the matrices
input
andmat2
.If
input
is a \((n \times m)\) tensor,mat2
is a \((m \times p)\) tensor, the output tensor will be a \((n \times p)\) tensor.Note
This function does not broadcast. For broadcasting quantized matrix products, see
mptorch.quant.functional.qmatmul()
.- Parameters:
input (
Tensor
) – the first matrix to be multipliedmat2 (
Tensor
) – the second matrix to be multipliedformats (
QAffineFormats
) – the configuration object for how quantization (if any!) should be handled on the matrix inputs and how the MAC and summation operations should be performed (e.g. using compensated algorithms or not)
- Return type:
Tensor
- Returns:
the result of the matrix multiplication between
input
andmat2
- mptorch.quant.functional.qadd(x, y, fwd_quant, bwd_quant)[source]#
Adds
x
toy
. Usesfwd_quant
to quantize the result of the addition (e.g. can simulate the execution of the addition in low-precision, assuming the inputs are already in low precision). Thebwd_quant
function is used to quantize the gradients from the operator during the backward pass.For the forward computation:
\[\text{out} = \mathcal{Q}_\text{fwd}(\text{x} + \text{y})\]For the backward computation:
\[ \begin{align}\begin{aligned}\text{grad_x} = \mathcal{Q}_\text{bwd}(\text{grad_z} * \text{ones_like}(\text{x}))\\\text{grad_y} = \mathcal{Q}_\text{bwd}(\text{grad_z} * \text{ones_like}(\text{y}))\end{aligned}\end{align} \]- Parameters:
x (
Tensor
) – the input tensory (
Tensor
) – the other tensor to add tox
fwd_quant (
Callable
[[Tensor
],Tensor
]) – the quantization function to apply on the forward additionbwd_quant (
Callable
[[Tensor
],Tensor
]) – the quantization function to apply on the gradient computations in the backward pass
- Return type:
Tensor
- Returns:
the quantized result of the addition operation between x and y
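Example
A minimal sketch simulating a low-precision addition: both the forward sum and the backward gradients are rounded to a bfloat16-like format (an illustrative choice; imports as in the qgelu example further below).
import torch
import mptorch.quant as qt

quant = lambda t: qt.float_quantize(t, man=7, exp=8, rounding="nearest")

x = torch.randn(4, 4, requires_grad=True)
y = torch.randn(4, 4, requires_grad=True)
z = qt.functional.qadd(x, y, fwd_quant=quant, bwd_quant=quant)
z.sum().backward()  # x.grad and y.grad are quantized by bwd_quant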
- mptorch.quant.functional.qmul(x, y, fwd_quant, bwd_quant)[source]#
Multiplies
x
byy
. Usesfwd_quant
to quantize the result of the multiplication (e.g. can simulate the execution of the multiplication in low-precision, assuming the inputs are already in low precision). Thebwd_quant
function is used to quantize the gradients from the operator during the backward pass.For the forward computation:
\[\text{out} = \mathcal{Q}_\text{fwd}(\text{x} * \text{y})\]For the backward computation:
\[ \begin{align}\begin{aligned}\text{grad_x} = \mathcal{Q}_\text{bwd}(\text{grad_z} * \text{y})\\\text{grad_y} = \mathcal{Q}_\text{bwd}(\text{grad_z} * \text{x})\end{aligned}\end{align} \]- Parameters:
x (
Tensor
) – the input tensory (
Tensor
) – the other tensor to multiply x by
fwd_quant (
Callable
[[Tensor
],Tensor
]) – the quantization function to apply on the forward multiplicationbwd_quant (
Callable
[[Tensor
],Tensor
]) – the quantization function to apply on the gradient computations in the backward pass
- Return type:
Tensor
- Returns:
the quantized result of the multiplication operation between x and y
- mptorch.quant.functional.qsqrt(x, fwd_quant, bwd_quant)[source]#
Returns a new tensor with the square-root of the elements of
x
.\[\text{out}_{i} = \sqrt{\text{x}_{i}}\]- Parameters:
x (
Tensor
) – the input tensorfwd_quant (
Callable
[[Tensor
],Tensor
]) – the quantization function to apply on the forward square root operationbwd_quant (
Callable
[[Tensor
],Tensor
]) – the quantization function to apply on the gradient computations in the backward pass
- Return type:
Tensor
- Returns:
the quantized result of the square root operation on x
- mptorch.quant.functional.qdiv(x, y, fwd_quant, bwd_quant)[source]#
Divides
x
byy
. Usesfwd_quant
to quantize the result of the division (e.g. can simulate the execution of the division in low-precision, assuming the inputs are already in low precision). Thebwd_quant
function is used to quantize the gradients from the operator during the backward pass.For the forward computation:
\[\text{out}_i = \mathcal{Q}_\text{fwd}\left(\frac{\text{x}_i}{\text{y}_i}\right)\]For the backward computation:
\[ \begin{align}\begin{aligned}\text{grad_x} = \mathcal{Q}_\text{bwd}\left(\frac{\text{grad_z}}{\text{y}}\right)\\\text{grad_y} = \mathcal{Q}_\text{bwd}\left(\frac{\mathcal{Q}_\text{bwd}\left(-\text{grad_z} * \text{x}\right)}{\text{y}}\right)\end{aligned}\end{align} \]- Parameters:
x (
Tensor
) – the input tensory (
Tensor
) – the tensor to divide x by
fwd_quant (
Callable
[[Tensor
],Tensor
]) – the quantization function to apply on the forward divisionbwd_quant (
Callable
[[Tensor
],Tensor
]) – the quantization function to apply on the gradient computations in the backward pass
- Return type:
Tensor
- Returns:
the quantized result of the division operation between x and y
- mptorch.quant.functional.qpow(x, fwd_quant, bwd_quant, n=2.0)[source]#
Takes the power of each element in
x
withn
and returns a tensor with the result.n
can be either a singlefloat
number or a torch.Tensor with the same number of elements asx
.When
n
is a scalar value, the forward operation applied is:\[\text{out}_i = \mathcal{Q}_\text{fwd}\left(\text{x}_i^\text{n}\right)\]When
n
is a tensor, the operation applied is:\[\text{out}_i = \mathcal{Q}_\text{fwd}\left(\text{x}_i^{\text{n}_i}\right)\]When
n
is a tensor, the shapes ofx
andn
must be broadcastable.The backward operation is (applied element-wise):
\[\text{out} = \mathcal{Q}_\text{bwd}\left(\text{grad_out} *\mathcal{Q}_\text{bwd}\left(\mathcal{Q}_\text{bwd}\left(\text{x}^{\text{n}-1}\right) * \text{n}\right)\right)\]- Parameters:
x (
Tensor
) – the input tensorfwd_quant (
Callable
[[Tensor
],Tensor
]) – the quantization function to apply on the forward power operationbwd_quant (
Callable
[[Tensor
],Tensor
]) – the quantization function to apply on the gradient computations in the backward passn (
Tensor
|float
) – the exponent value
- Return type:
Tensor
- Returns:
the quantized result of the power operation between x and n
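Example
A minimal sketch with both a scalar and a tensor exponent, reusing the float_quantize helper and import conventions from the qgelu example further below.
import torch
import mptorch.quant as qt

quant = lambda t: qt.float_quantize(t, man=7, exp=8, rounding="nearest")

x = torch.rand(3, 2) + 0.1
y1 = qt.functional.qpow(x, quant, quant, n=2.0)  # scalar exponent
n = torch.full_like(x, 0.5)
y2 = qt.functional.qpow(x, quant, quant, n=n)    # element-wise exponents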
- mptorch.quant.functional.qsum(x, dim=None, quant=<function <lambda>>, keepdim=False)[source]#
Returns the quantized sum of all elements in the
x
tensor. It can simulate low precision summation if the elements ofx
are low precision values. Thequant
function specifies the accumulator output format and precision as a quantization function.- Parameters:
x (
Tensor
) – the input tensordim (
int
|tuple
[int
,...
] |None
) – the dimension or dimensions to reduce. IfNone
, all dimensions are reducedquant (
Callable
[[Tensor
],Tensor
]) – the quantization function specifying how accumulation results should be storedkeepdim (
bool
) – whether the output tensor hasdim
retained or not
- Return type:
Tensor
- Returns:
the quantized result of the sum operation on x
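Example
A minimal sketch of low-precision accumulation along one dimension; the quant callable fixes the accumulator format (imports as in the qgelu example further below).
import torch
import mptorch.quant as qt

# accumulate in a binary16-like format (illustrative choice)
acc_quant = lambda t: qt.float_quantize(t, man=10, exp=5, rounding="nearest")

x = torch.randn(128, 256)
s = qt.functional.qsum(x, dim=1, quant=acc_quant)  # shape: (128,)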
- mptorch.quant.functional.qmean(x, fwd_quant, bwd_quant, dim=3, keepdim=False)[source]#
Returns the mean value of all the elements in the
x
tensor. Input must be a floating point tensor. It can simulate low precision summation if the elements ofx
are low precision values. Thefwd_quant
function specifies the accumulator (and final division) output format and precision as a quantization function in the forward pass. Thebwd_quant
function specifies how the arithmetic should be performed in the backward pass through this operator.- Parameters:
x (
Tensor
) – the input tensordim (
int
|tuple
[int
,...
] |None
) – the dimension or dimensions to reduce. IfNone
, all dimensions are reducedfwd_quant (
Callable
[[Tensor
],Tensor
]) – the quantization function specifying how accumulation and division results should be stored in the forward passbwd_quant (
Callable
[[Tensor
],Tensor
]) – the quantization function specifying how operations should be performed in the backward passkeepdim (
bool
) – whether the output tensor hasdim
retained or not
- Return type:
Tensor
- Returns:
the quantized result of the mean operation on x
- mptorch.quant.functional.qlayernorm(x, normalized_shape, weight, bias, eps=1e-05, formats=QLayerNormFormats(default_fwd, default_bwd))[source]#
Implements the operation as described in the paper Layer Normalization (https://arxiv.org/abs/1607.06450), giving the user control over how the arithmetic is performed during the forward and backward passes through the
formats
parameter.\[y = \frac{x-\mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}}\gamma + \beta\]The mean and standard deviation are computed over the last D dimensions, where D is the dimension of
normalized_shape
. For instance, ifnormalized_shape
is(3, 5)
(a 2-D shape), the mean and standard deviation are computed over the last 2 dimensions of the input (i.e., x.mean((-2, -1))
). In a training context, \(\gamma\) and \(\beta\) are learnable affine transform parameters of normalized_shape
.Note
Unlike Batch Normalization and Instance Normalization, which applies scalar scale and bias for each entire channel/plane with the
affine
option, Layer Normalization applies per-element scale and bias withelementwise_affine
.It uses statistics computed from input data in both training and evaluation modes.
- Parameters:
x (
Tensor
) – the input tensornormalized_shape (
int
|list
[int
] |Size
) –input shape from an expected input of size
\[[* \times \text{n_shape}[0] \times \text{n_shape}[1] \times \ldots \times \text{n_shape}[-1]]\]If a single integer is used, it is treated as a singleton list, and this function will normalize over the last dimension which is expected to be of that specific size
weight (
Tensor
) – the learnable weights \(\gamma\) of shape \(\text{n_shape}\)bias (
Tensor
) – the learnable bias \(\beta\) of shape \(\text{n_shape}\)eps (
float
) – a small value added to the denominator for numerical stability. Default: 1e-5formats (
QLayerNormFormats
) – configuration class for number formats and quantizers to use during forward and backward computations in layer normalization.
- Return type:
Tensor
- Returns:
the quantized result of the layer normalization operation on
x
. Has the same shape asx
.
- mptorch.quant.functional.qsoftmax(x, dim, formats)[source]#
Applies the Softmax function to an n-dimensional input tensor
x
. Through theformats
parameter it allows one to specify if I/O signals and internal mathematical operations should be quantized during forward and backward compute chains.Rescales the elements in the input tensor so that they lie in the range \([0, 1]\) and sum to \(1\). It is defined as:
\[\text{Softmax}(x_i) = \frac{\exp(x_i)}{\sum_j\exp(x_j)}\]- Parameters:
x (
Tensor
) – the input tensordim (
int
|None
) – a dimension along which Softmax will be computed (so every slice alongdim
will sum to 1)formats (
QSoftmaxFormats
) – configuration class for number formats and quantizers to use during forward and backward computations in Softmax
- Return type:
Tensor
- Returns:
a
Tensor
of the same dimension and shape as the input with values in the range \([0, 1]\)
Example
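A minimal sketch. The QSoftmaxFormats constructor keywords (input_quant/output_quant) are assumed by analogy with QGELUFormats in the qgelu example below; consult the QSoftmaxFormats documentation for the exact signature.
import torch
import mptorch.quant as qt

quant_func = lambda x: qt.float_quantize(x, man=7, exp=8, rounding="nearest")
formats = qt.QSoftmaxFormats(  # constructor keywords assumed, see the note above
    input_quant=quant_func,
    output_quant=quant_func,
)
x = torch.randn(3, 5)
y = qt.functional.qsoftmax(x, dim=-1, formats=formats)
print(y.sum(dim=-1))  # each slice along dim sums to 1 (up to rounding)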
- mptorch.quant.functional.qgelu(x, formats, approximate='none')[source]#
Applies the Gaussian Error Linear Units function to the input \(x\):
\[\text{GELU}(x) = x * \Phi(x)\]where \(\Phi(x)\) is the Cumulative Distribution Function for Gaussian Distribution.
When the approximate argument is
'tanh'
, GELU is estimated with:\[\text{GELU}(x) = 0.5 * x * (1 + \tanh(\sqrt{2 / \pi} * (x + 0.044715 * x^3)))\]- Parameters:
x (
Tensor
) – the input tensorformats (
QGELUFormats
) – configuration class for number formats and quantizers to use during forward and backward computations in GELUapproximate (
Literal
['tanh'
,'none'
]) – the GELU approximation algorithm to use:'none'
|'tanh'
. Default:'none'
- Return type:
Tensor
- Returns:
a
Tensor
of the same dimension and shape as the input where the GELU function is applied element-wise
Example
import torch
import mptorch.quant as qt
from torch.testing import assert_close

quant_func = lambda x: qt.float_quantize(x, man=7, exp=8, rounding="nearest")
formats = qt.QGELUFormats(
    input_quant=quant_func,
    output_quant=quant_func,
)
x = torch.randn(3, 2)
y = torch.nn.functional.gelu(x)
qy = qt.functional.qgelu(x, formats=formats)
assert_close(y, qy, rtol=1e-2, atol=1e-5)
cuBLAS Acceleration#
- class mptorch.quant.cublas.cublas_acceleration(enabled, fast_mode=None)[source]#
Bases:
object
cuBLAS acceleration management.
This class allows enabling and disabling automatic cuBLAS acceleration for compatible types. When enabled, all calls to float quantized (batched) GEMMs (float_mm and float_bmm) used internally by linear and convolutional layers will use the cuBLAS GEMM functions if the floating-point computation format matches one of the combinations supported by cuBLAS, i.e.:
round-to-nearest-even rounding mode
fused multiply-add enabled
subnormals enabled
saturation disabled
matching multiplication/accumulator types from one of the supported combinations (https://docs.nvidia.com/cuda/cublas/#cublasgemmex)
This feature is disabled by default.
- Parameters:
enabled (
bool
) – whether to enable automatic cuBLAS acceleration.fast_mode (
str
|None
) – allow internal downcast to lower-precision for tensor cores. Currently supported are f16, bf16 and tf32
Example
cuBLAS acceleration can be enabled/disabled globally using the static enable method; or locally as a context manager:
mac_format = FloatingPoint(
    exp=5, man=10, subnormals=True, saturate=False  # F16, supported by cublas
)
layer_formats = QAffineFormats(
    fwd_mac=(mac_format,),
    bwd_mac=(mac_format,),
    fwd_rnd="nearest",
    bwd_rnd="nearest",
    ...
)
layer = QLinear(in_features, out_features, formats=layer_formats)

with cublas_acceleration(True):
    x = torch.tensor(...)
    y = layer.forward(x)
- enabled = False#
- fast_mode = None#