torch
The torch package contains data structures for multi-dimensional tensors and defines mathematical operations over them. Additionally, it provides many utilities for efficient serialization of Tensors and arbitrary types, and other useful utilities.
It has a CUDA counterpart that enables you to run your tensor computations on an NVIDIA GPU with compute capability >= 2.0.
Tensors
-
torch.is_tensor(obj)
Returns True if obj is a PyTorch tensor.
Parameters: obj (Object) – Object to test
-
torch.is_storage(obj)
Returns True if obj is a PyTorch storage object.
Parameters: obj (Object) – Object to test
-
torch.numel(input) → int
Returns the total number of elements in the input Tensor.
Parameters: input (Tensor) – the input Tensor
Example:
>>> a = torch.randn(1,2,3,4,5)
>>> torch.numel(a)
120
>>> a = torch.zeros(4,4)
>>> torch.numel(a)
16
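The element count returned by torch.numel() is the product of all dimension sizes. Below is a minimal pure-Python sketch of that arithmetic. The helper `numel` is hypothetical: it takes a plain tuple of sizes rather than a Tensor, and it is not the library's implementation.

```python
from math import prod  # math.prod requires Python 3.8+


def numel(sizes):
    """Hypothetical helper: number of elements in a tensor with these sizes."""
    # Each dimension multiplies the total element count.
    return prod(sizes)


print(numel((1, 2, 3, 4, 5)))  # 120
print(numel((4, 4)))           # 16
```

Both results match the torch.numel() example above.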
-
torch.set_printoptions(precision=None, threshold=None, edgeitems=None, linewidth=None, profile=None)
Sets options for printing. Items shamelessly taken from NumPy.
Parameters: - precision – Number of digits of precision for floating point output (default 8).
- threshold – Total number of array elements which trigger summarization rather than full repr (default 1000).
- edgeitems – Number of array items in summary at beginning and end of each dimension (default 3).
- linewidth – The number of characters per line for the purpose of inserting line breaks (default 80). Thresholded matrices will ignore this parameter.
- profile – Sane defaults for pretty printing. Can be overridden by any of the above options. (any one of default, short, full)
Creation Ops
-
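torch.eye(n, m=None, out=None)
Returns a 2-D tensor with ones on the diagonal and zeros elsewhere.
Parameters: - n (int) – the number of rows
- m (int, optional) – the number of columns; if None, defaults to n
- out (Tensor, optional) – the output Tensor
Returns: a 2-D tensor with ones on the diagonal and zeros elsewhere
Return type: Tensor
Example:
>>> torch.eye(3)
1 0 0
0 1 0
0 0 1
[torch.FloatTensor of size 3x3]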
-
torch.from_numpy(ndarray) → Tensor
Creates a Tensor from a numpy.ndarray.
The returned tensor and ndarray share the same memory. Modifications to the tensor will be reflected in the ndarray and vice versa. The returned tensor is not resizable.
Example:
>>> import numpy
>>> a = numpy.array([1, 2, 3])
>>> t = torch.from_numpy(a)
>>> t
torch.LongTensor([1, 2, 3])
>>> t[0] = -1
>>> a
array([-1, 2, 3])
-
torch.linspace(start, end, steps=100, out=None) → Tensor
Returns a one-dimensional Tensor of steps equally spaced points between start and end.
The output tensor is 1D of size steps.
Parameters: - start (float) – the starting point of the set of points
- end (float) – the ending point of the set of points
- steps (int) – the number of points to sample between start and end
- out (Tensor, optional) – the result Tensor
Example:
>>> torch.linspace(3, 10, steps=5)
3.0000 4.7500 6.5000 8.2500 10.0000
[torch.FloatTensor of size 5]
>>> torch.linspace(-10, 10, steps=5)
-10 -5 0 5 10
[torch.FloatTensor of size 5]
>>> torch.linspace(start=-10, end=10, steps=5)
-10 -5 0 5 10
[torch.FloatTensor of size 5]
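The points returned by torch.linspace() are start plus integer multiples of (end - start) / (steps - 1), so the first point is start and the last is end. Here is a pure-Python sketch of that arithmetic. The function `linspace` is hypothetical and returns a list. It may differ from torch in the last floating-point digit.

```python
def linspace(start, end, steps=100):
    """Hypothetical model: `steps` evenly spaced values from start to end inclusive."""
    if steps == 1:
        return [float(start)]
    # Gap between neighbouring points; steps points enclose steps - 1 gaps.
    step = (end - start) / (steps - 1)
    return [start + i * step for i in range(steps)]


print(linspace(3, 10, steps=5))    # [3.0, 4.75, 6.5, 8.25, 10.0]
print(linspace(-10, 10, steps=5))  # [-10.0, -5.0, 0.0, 5.0, 10.0]
```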
-
torch.logspace(start, end, steps=100, out=None) → Tensor
Returns a one-dimensional Tensor of steps points logarithmically spaced between \(10^{start}\) and \(10^{end}\).
The output is a 1D tensor of size steps.
Parameters: - start (float) – the exponent of the first point, \(10^{start}\)
- end (float) – the exponent of the last point, \(10^{end}\)
- steps (int) – the number of points to sample
- out (Tensor, optional) – the result Tensor
Example:
>>> torch.logspace(start=-10, end=10, steps=5)
1.0000e-10 1.0000e-05 1.0000e+00 1.0000e+05 1.0000e+10
[torch.FloatTensor of size 5]
>>> torch.logspace(start=0.1, end=1.0, steps=5)
1.2589 2.1135 3.5481 5.9566 10.0000
[torch.FloatTensor of size 5]
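torch.logspace() is equivalent to spacing the exponents linearly and raising 10 to each of them. The pure-Python sketch below assumes exactly that. The name `logspace` is a hypothetical stand-in, not the torch implementation.

```python
def logspace(start, end, steps=100):
    """Hypothetical model: 10 ** e for `steps` linearly spaced exponents e."""
    if steps == 1:
        return [10.0 ** start]
    step = (end - start) / (steps - 1)
    # Exponents go linearly from start to end; the values grow geometrically.
    return [10.0 ** (start + i * step) for i in range(steps)]


print([round(v, 4) for v in logspace(0.1, 1.0, steps=5)])
# [1.2589, 2.1135, 3.5481, 5.9566, 10.0]
```

The rounded values agree with the second torch.logspace() example above.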
-
torch.ones(*sizes, out=None) → Tensor
Returns a Tensor filled with the scalar value 1, with the shape defined by the varargs sizes.
Parameters: - sizes (int...) – a set of ints defining the shape of the output Tensor.
- out (Tensor, optional) – the result Tensor
Example:
>>> torch.ones(2, 3)
1 1 1
1 1 1
[torch.FloatTensor of size 2x3]
>>> torch.ones(5)
1 1 1 1 1
[torch.FloatTensor of size 5]
-
torch.rand(*sizes, out=None) → Tensor
Returns a Tensor filled with random numbers from a uniform distribution on the interval \([0, 1)\).
The shape of the Tensor is defined by the varargs sizes.
Parameters: - sizes (int...) – a set of ints defining the shape of the output Tensor.
- out (Tensor, optional) – the result Tensor
Example:
>>> torch.rand(4)
0.9193 0.3347 0.3232 0.7715
[torch.FloatTensor of size 4]
>>> torch.rand(2, 3)
0.5010 0.5140 0.0719
0.1435 0.5636 0.0538
[torch.FloatTensor of size 2x3]
-
torch.randn(*sizes, out=None) → Tensor
Returns a Tensor filled with random numbers from a normal distribution with zero mean and variance of one.
The shape of the Tensor is defined by the varargs sizes.
Parameters: - sizes (int...) – a set of ints defining the shape of the output Tensor.
- out (Tensor, optional) – the result Tensor
Example:
>>> torch.randn(4)
-0.1145 0.0094 -1.1717 0.9846
[torch.FloatTensor of size 4]
>>> torch.randn(2, 3)
1.4339 0.3351 -1.0999
1.5458 -0.9643 -0.3558
[torch.FloatTensor of size 2x3]
-
torch.randperm(n, out=None) → LongTensor
Returns a random permutation of integers from 0 to n - 1.
Parameters: n (int) – the upper bound (exclusive)
Example:
>>> torch.randperm(4)
2 1 3 0
[torch.LongTensor of size 4]
-
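torch.arange(start, end, step=1, out=None) → Tensor
Returns a 1D Tensor of size \(ceil((end - start) / step)\) with values from the interval [start, end) taken with step step, starting from start.
Parameters: - start (float) – the starting value for the set of points
- end (float) – the ending value for the set of points (exclusive)
- step (float) – the gap between each pair of adjacent points
- out (Tensor, optional) – the result Tensor
Example:
>>> torch.arange(1, 4)
1 2 3
[torch.FloatTensor of size 3]
>>> torch.arange(1, 2.5, 0.5)
1.0000 1.5000 2.0000
[torch.FloatTensor of size 3]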
-
torch.range(start, end, step=1, out=None) → Tensor
Returns a 1D Tensor of size \(floor((end - start) / step) + 1\) with values from start to end with step step. Step is the gap between two values in the tensor: \(x_{i+1} = x_i + step\)
Warning
This function is deprecated in favor of torch.arange().
Parameters: - start (float) – the starting value for the set of points
- end (float) – the ending value for the set of points (inclusive)
- step (float) – the gap between each pair of adjacent points
- out (Tensor, optional) – the result Tensor
Example:
>>> torch.range(1, 4)
1 2 3 4
[torch.FloatTensor of size 4]
>>> torch.range(1, 4, 0.5)
1.0000 1.5000 2.0000 2.5000 3.0000 3.5000 4.0000
[torch.FloatTensor of size 7]
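torch.arange() and torch.range() differ only in whether end is included. The pure-Python sketch below counts the elements each one produces. It assumes arange yields every start + k * step strictly below end, and that range yields every such value up to and including end. Both helper names are hypothetical.

```python
import math


def arange_len(start, end, step=1):
    """Hypothetical: element count for the half-open interval [start, end)."""
    # Every start + k*step < end is kept, i.e. ceil((end - start) / step) values.
    return max(0, math.ceil((end - start) / step))


def range_len(start, end, step=1):
    """Hypothetical: element count for the closed interval [start, end]."""
    # Every start + k*step <= end is kept, i.e. floor((end - start) / step) + 1 values.
    return max(0, math.floor((end - start) / step) + 1)


print(arange_len(1, 4), range_len(1, 4))              # 3 4
print(arange_len(1, 2.5, 0.5), range_len(1, 4, 0.5))  # 3 7
print(arange_len(0, 2.5, 1))                          # 3 (values 0, 1, 2)
```

When the step divides the span evenly, as in the examples above, the ceiling and floor-based counts for arange coincide. They diverge only for uneven spans such as arange(0, 2.5, 1).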
-
torch.zeros(*sizes, out=None) → Tensor
Returns a Tensor filled with the scalar value 0, with the shape defined by the varargs sizes.
Parameters: - sizes (int...) – a set of ints defining the shape of the output Tensor.
- out (Tensor, optional) – the result Tensor
Example:
>>> torch.zeros(2, 3)
0 0 0
0 0 0
[torch.FloatTensor of size 2x3]
>>> torch.zeros(5)
0 0 0 0 0
[torch.FloatTensor of size 5]
Indexing, Slicing, Joining, Mutating Ops
-
torch.cat(seq, dim=0) → Tensor
Concatenates the given sequence of seq Tensors in the given dimension.
torch.cat() can be seen as an inverse operation for torch.split() and torch.chunk().
cat() can be best understood via examples.
Parameters: - seq (sequence of Tensors) – Can be any Python sequence of Tensors of the same type.
- dim (int, optional) – The dimension over which the tensors are concatenated
Example:
>>> x = torch.randn(2, 3)
>>> x
0.5983 -0.0341 2.4918
1.5981 -0.5265 -0.8735
[torch.FloatTensor of size 2x3]
>>> torch.cat((x, x, x), 0)
0.5983 -0.0341 2.4918
1.5981 -0.5265 -0.8735
0.5983 -0.0341 2.4918
1.5981 -0.5265 -0.8735
0.5983 -0.0341 2.4918
1.5981 -0.5265 -0.8735
[torch.FloatTensor of size 6x3]
>>> torch.cat((x, x, x), 1)
0.5983 -0.0341 2.4918 0.5983 -0.0341 2.4918 0.5983 -0.0341 2.4918
1.5981 -0.5265 -0.8735 1.5981 -0.5265 -0.8735 1.5981 -0.5265 -0.8735
[torch.FloatTensor of size 2x9]
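For 2-D data, concatenating along dim 0 appends rows and concatenating along dim 1 joins matching rows side by side. The pure-Python sketch below models only that 2-D case, on lists of rows. `cat2d` is a hypothetical name and does not validate shapes the way torch does.

```python
def cat2d(seq, dim=0):
    """Hypothetical model of torch.cat for matrices stored as lists of rows."""
    if dim == 0:
        # Rows of each matrix follow one another; column counts should agree.
        return [list(row) for m in seq for row in m]
    if dim == 1:
        # Row i of the result joins row i of every matrix; row counts should agree.
        return [[v for m in seq for v in m[i]] for i in range(len(seq[0]))]
    raise ValueError("this sketch only models dim 0 or dim 1")


x = [[1, 2, 3], [4, 5, 6]]
print(len(cat2d((x, x, x), 0)), len(cat2d((x, x, x), 0)[0]))  # 6 3
print(len(cat2d((x, x, x), 1)), len(cat2d((x, x, x), 1)[0]))  # 2 9
```

The resulting shapes (6x3 and 2x9) mirror the torch.cat() example above.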
-
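torch.chunk(tensor, chunks, dim=0)
Splits a tensor into a number of chunks along a given dimension.
Parameters: - tensor (Tensor) – the tensor to split
- chunks (int) – the number of chunks to return
- dim (int) – the dimension along which to split the tensor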
-
torch.gather(input, dim, index, out=None) → Tensor
Gathers values along an axis specified by dim.
For a 3-D tensor the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k]  # dim=0
out[i][j][k] = input[i][index[i][j][k]][k]  # dim=1
out[i][j][k] = input[i][j][index[i][j][k]]  # dim=2
Parameters: - input (Tensor) – the source Tensor
- dim (int) – the axis along which to index
- index (LongTensor) – the indices of elements to gather
- out (Tensor, optional) – the destination Tensor
Example:
>>> t = torch.Tensor([[1,2],[3,4]])
>>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
1 1
4 3
[torch.FloatTensor of size 2x2]
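The indexing formulas for torch.gather() translate directly into loops. Below is a pure-Python sketch of the 2-D case, using nested lists in place of Tensors. `gather2d` is a hypothetical helper that follows the dim=0 and dim=1 rules above. It is not torch's implementation.

```python
def gather2d(inp, dim, index):
    """Hypothetical 2-D model of torch.gather on nested lists."""
    out = []
    for i, row in enumerate(index):
        if dim == 0:
            # out[i][j] = inp[index[i][j]][j]
            out.append([inp[k][j] for j, k in enumerate(row)])
        elif dim == 1:
            # out[i][j] = inp[i][index[i][j]]
            out.append([inp[i][k] for k in row])
        else:
            raise ValueError("this sketch only models dim 0 or dim 1")
    return out


t = [[1, 2], [3, 4]]
print(gather2d(t, 1, [[0, 0], [1, 0]]))  # [[1, 1], [4, 3]]
print(gather2d(t, 0, [[1, 0], [0, 1]]))  # [[3, 2], [1, 4]]
```

The first call reproduces the torch.gather() example above.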
-
torch.index_select(input, dim, index, out=None) → Tensor
Returns a new Tensor which indexes the input Tensor along dimension dim using the entries in index, which is a LongTensor.
The returned Tensor has the same number of dimensions as the original Tensor.
Note
The returned Tensor does not use the same storage as the original Tensor.
Parameters: - input (Tensor) – the input Tensor
- dim (int) – the dimension in which to index
- index (LongTensor) – the 1-D Tensor containing the indices to select
- out (Tensor, optional) – the output Tensor
Example:
>>> x = torch.randn(3, 4)
>>> x
1.2045 2.4084 0.4001 1.1372
0.5596 1.5677 0.6219 -0.7954
1.3635 -1.2313 -0.5414 -1.8478
[torch.FloatTensor of size 3x4]
>>> indices = torch.LongTensor([0, 2])
>>> torch.index_select(x, 0, indices)
1.2045 2.4084 0.4001 1.1372
1.3635 -1.2313 -0.5414 -1.8478
[torch.FloatTensor of size 2x4]
>>> torch.index_select(x, 1, indices)
1.2045 0.4001
0.5596 0.6219
1.3635 -0.5414
[torch.FloatTensor of size 3x2]
-
torch.masked_select(input, mask, out=None) → Tensor
Returns a new 1D Tensor which indexes the input Tensor according to the binary mask mask, which is a ByteTensor.
The mask tensor needs to have the same number of elements as input, but its shape and dimensionality are irrelevant.
Note
The returned Tensor does not use the same storage as the original Tensor.
Parameters: - input (Tensor) – the input data
- mask (ByteTensor) – the binary mask to index with
- out (Tensor, optional) – the output Tensor
Example:
>>> x = torch.randn(3, 4)
>>> x
1.2045 2.4084 0.4001 1.1372
0.5596 1.5677 0.6219 -0.7954
1.3635 -1.2313 -0.5414 -1.8478
[torch.FloatTensor of size 3x4]
>>> mask = x.ge(0.5)
>>> mask
1 1 0 1
1 1 1 0
1 0 0 0
[torch.ByteTensor of size 3x4]
>>> torch.masked_select(x, mask)
1.2045 2.4084 1.1372 0.5596 1.5677 0.6219 1.3635
[torch.FloatTensor of size 7]
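Because only the element counts of input and mask have to agree, masked selection can be modelled as flattening both in row-major order and keeping the values whose mask entry is set. The pure-Python sketch below assumes that flattening order. `masked_select` here is a hypothetical function over 2-D lists, not the torch function.

```python
def masked_select(values, mask):
    """Hypothetical model: flatten both 2-D lists, keep values where mask is truthy."""
    flat_v = [v for row in values for v in row]
    flat_m = [m for row in mask for m in row]
    if len(flat_v) != len(flat_m):
        raise ValueError("mask must have as many elements as input")
    return [v for v, m in zip(flat_v, flat_m) if m]


x = [[1.2045, 2.4084, 0.4001, 1.1372],
     [0.5596, 1.5677, 0.6219, -0.7954],
     [1.3635, -1.2313, -0.5414, -1.8478]]
mask = [[1 if v >= 0.5 else 0 for v in row] for row in x]  # like x.ge(0.5)
print(masked_select(x, mask))
# [1.2045, 2.4084, 1.1372, 0.5596, 1.5677, 0.6219, 1.3635]
print(masked_select([[1.5, -2.0], [0.7, 3.0]], [[1, 0, 0, 1]]))  # [1.5, 3.0]
```

The second call uses a 1x4 mask on a 2x2 input to show that the shapes themselves do not matter.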
-
torch.nonzero(input, out=None) → LongTensor
Returns a tensor containing the indices of all non-zero elements of input. Each row in the result contains the indices of a non-zero element in input.
If input has n dimensions, then the resulting indices Tensor out is of size z × n, where z is the total number of non-zero elements in the input Tensor.
Parameters: - input (Tensor) – the input Tensor
- out (LongTensor, optional) – The result Tensor containing indices
Example:
>>> torch.nonzero(torch.Tensor([1, 1, 1, 0, 1]))
0
1
2
4
[torch.LongTensor of size 4x1]
>>> torch.nonzero(torch.Tensor([[0.6, 0.0, 0.0, 0.0],
...                             [0.0, 0.4, 0.0, 0.0],
...                             [0.0, 0.0, 1.2, 0.0],
...                             [0.0, 0.0, 0.0,-0.4]]))
0 0
1 1
2 2
3 3
[torch.LongTensor of size 4x2]
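The z × n result of torch.nonzero() is simply one index row per non-zero entry. The pure-Python sketch below builds that structure for a 2-D input, listing entries in row-major order. `nonzero2d` is a hypothetical helper.

```python
def nonzero2d(matrix):
    """Hypothetical model: one [row, col] pair per non-zero entry, row-major order."""
    return [[i, j]
            for i, row in enumerate(matrix)
            for j, v in enumerate(row)
            if v != 0]


m = [[0.6, 0.0, 0.0, 0.0],
     [0.0, 0.4, 0.0, 0.0],
     [0.0, 0.0, 1.2, 0.0],
     [0.0, 0.0, 0.0, -0.4]]
print(nonzero2d(m))  # [[0, 0], [1, 1], [2, 2], [3, 3]]
```

The result has z = 4 rows of n = 2 indices, matching the 4x2 output above.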
-
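torch.split(tensor, split_size, dim=0)
Splits the tensor into equally sized chunks (if possible).
The last chunk will be smaller if the tensor size along the given dimension is not divisible by split_size.
Parameters: - tensor (Tensor) – the tensor to split
- split_size (int) – the size of a single chunk
- dim (int) – the dimension along which to split the tensor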
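The chunk lengths produced by torch.split() follow from integer division. The pure-Python sketch below computes them. It also models torch.chunk() under one assumption: that chunk rounds the chunk size up and then splits. Under that assumption, chunk can return fewer pieces than requested. Both helpers are hypothetical and operate on lengths only.

```python
def split_sizes(length, split_size):
    """Hypothetical: lengths of the pieces torch.split would produce along one dim."""
    if split_size <= 0:
        raise ValueError("split_size must be positive")
    sizes = [split_size] * (length // split_size)
    if length % split_size:
        sizes.append(length % split_size)  # the smaller trailing chunk
    return sizes


def chunk_sizes(length, chunks):
    """Hypothetical: assumes chunk uses split_size = ceil(length / chunks)."""
    return split_sizes(length, -(-length // chunks))


print(split_sizes(10, 3))  # [3, 3, 3, 1]
print(chunk_sizes(10, 3))  # [4, 4, 2]
print(chunk_sizes(5, 4))   # [2, 2, 1], only 3 pieces under this assumption
```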
-
torch.squeeze(input, dim=None, out=None)
Returns a Tensor with all the dimensions of input of size 1 removed.
If input is of shape \((A \times 1 \times B \times C \times 1 \times D)\), then the out Tensor will be of shape \((A \times B \times C \times D)\).
When dim is given, a squeeze operation is done only in the given dimension. If input is of shape \((A \times 1 \times B)\), squeeze(input, 0) leaves the Tensor unchanged, but squeeze(input, 1) will squeeze the tensor to the shape \((A \times B)\).
Note
The returned Tensor shares the storage with the input Tensor, so changing the contents of one will change the contents of the other.
Parameters: - input (Tensor) – the input Tensor
- dim (int, optional) – if given, the input will be squeezed only in this dimension
- out (Tensor, optional) – the result Tensor
Example:
>>> x = torch.zeros(2,1,2,1,2)
>>> x.size()
(2L, 1L, 2L, 1L, 2L)
>>> y = torch.squeeze(x)
>>> y.size()
(2L, 2L, 2L)
>>> y = torch.squeeze(x, 0)
>>> y.size()
(2L, 1L, 2L, 1L, 2L)
>>> y = torch.squeeze(x, 1)
>>> y.size()
(2L, 2L, 1L, 2L)
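The shape rules for squeeze can be stated as a function on size tuples. In the pure-Python sketch below, `squeeze_shape` is hypothetical and only handles non-negative dim values. It reproduces the three size() results above.

```python
def squeeze_shape(shape, dim=None):
    """Hypothetical model of the output shape of squeeze (non-negative dim only)."""
    if dim is None:
        # Remove every dimension of size 1.
        return tuple(s for s in shape if s != 1)
    if shape[dim] != 1:
        # Squeezing a dimension whose size is not 1 leaves the shape unchanged.
        return tuple(shape)
    return tuple(shape[:dim]) + tuple(shape[dim + 1:])


print(squeeze_shape((2, 1, 2, 1, 2)))     # (2, 2, 2)
print(squeeze_shape((2, 1, 2, 1, 2), 0))  # (2, 1, 2, 1, 2)
print(squeeze_shape((2, 1, 2, 1, 2), 1))  # (2, 2, 1, 2)
```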
-
torch.stack(sequence, dim=0, out=None)
Concatenates a sequence of tensors along a new dimension.
All tensors need to be of the same size.
Parameters: - sequence (Sequence) – sequence of tensors to concatenate.
- dim (int) – dimension to insert. Has to be between 0 and the number of dimensions of concatenated tensors (inclusive).
-
torch.t(input, out=None) → Tensor
Expects input to be a matrix (2D Tensor) and transposes dimensions 0 and 1.
Can be seen as a shorthand function for transpose(input, 0, 1).
Parameters: - input (Tensor) – the input Tensor
- out (Tensor, optional) – the result Tensor
Example:
>>> x = torch.randn(2, 3)
>>> x
0.4834 0.6907 1.3417
-0.1300 0.5295 0.2321
[torch.FloatTensor of size 2x3]
>>> torch.t(x)
0.4834 -0.1300
0.6907 0.5295
1.3417 0.2321
[torch.FloatTensor of size 3x2]
-
torch.transpose(input, dim0, dim1, out=None) → Tensor
Returns a Tensor that is a transposed version of input. The given dimensions dim0 and dim1 are swapped.
The resulting out Tensor shares its underlying storage with the input Tensor, so changing the content of one will change the content of the other.
Parameters: - input (Tensor) – the input Tensor
- dim0 (int) – the first dimension to be transposed
- dim1 (int) – the second dimension to be transposed
Example:
>>> x = torch.randn(2, 3)
>>> x
0.5983 -0.0341 2.4918
1.5981 -0.5265 -0.8735
[torch.FloatTensor of size 2x3]
>>> torch.transpose(x, 0, 1)
0.5983 1.5981
-0.0341 -0.5265
2.4918 -0.8735
[torch.FloatTensor of size 3x2]
-
torch.unbind(tensor, dim=0)
Removes a tensor dimension.
Returns a tuple of all slices along a given dimension, each with that dimension removed.
Parameters: - tensor (Tensor) – the tensor to unbind
- dim (int) – the dimension to remove
-
torch.unsqueeze(input, dim, out=None)
Returns a new tensor with a dimension of size one inserted at the specified position.
The returned tensor shares the same underlying data with the input tensor.
A negative dim value can be used and will correspond to \(dim + input.dim() + 1\).
Parameters: - input (Tensor) – the input Tensor
- dim (int) – the index at which to insert the singleton dimension
- out (Tensor, optional) – the result Tensor
Example:
>>> x = torch.Tensor([1, 2, 3, 4])
>>> torch.unsqueeze(x, 0)
1 2 3 4
[torch.FloatTensor of size 1x4]
>>> torch.unsqueeze(x, 1)
1
2
3
4
[torch.FloatTensor of size 4x1]
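The negative-dim rule, dim + input.dim() + 1, is easiest to see on shapes alone. The pure-Python sketch below applies it. `unsqueeze_shape` is a hypothetical helper on size tuples.

```python
def unsqueeze_shape(shape, dim):
    """Hypothetical model of the output shape of unsqueeze."""
    ndim = len(shape)
    if dim < 0:
        # A negative dim corresponds to dim + input.dim() + 1.
        dim = dim + ndim + 1
    if not 0 <= dim <= ndim:
        raise IndexError("dim out of range")
    return tuple(shape[:dim]) + (1,) + tuple(shape[dim:])


print(unsqueeze_shape((4,), 0))   # (1, 4)
print(unsqueeze_shape((4,), 1))   # (4, 1)
print(unsqueeze_shape((4,), -1))  # (4, 1), since -1 + 1 + 1 == 1
```

The first two results match the 1x4 and 4x1 outputs above.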
Random sampling
-
torch.manual_seed(seed)
Sets the seed for generating random numbers and returns a torch._C.Generator object.
Parameters: seed (int or long) – The desired seed.
-
torch.initial_seed()
Returns the initial seed for generating random numbers as a Python long.
-
torch.set_rng_state(new_state)
Sets the random number generator state.
Parameters: new_state (torch.ByteTensor) – The desired state
-
torch.default_generator = <torch._C.Generator object>
-
torch.bernoulli(input, out=None) → Tensor
Draws binary random numbers (0 or 1) from a Bernoulli distribution.
The input Tensor should be a tensor containing probabilities to be used for drawing the binary random number. Hence, all values in input have to be in the range \(0 \le input_i \le 1\).
The i-th element of the output tensor will draw a value 1 according to the i-th probability value given in input.
The returned out Tensor only has values 0 or 1 and is of the same shape as input.
Parameters: - input (Tensor) – the Tensor of probability values for the Bernoulli distribution
- out (Tensor, optional) – the result Tensor
Example:
>>> a = torch.Tensor(3, 3).uniform_(0, 1)  # generate a uniform random matrix with range [0, 1]
>>> a
0.7544 0.8140 0.9842
0.5282 0.0595 0.6445
0.1925 0.9553 0.9732
[torch.FloatTensor of size 3x3]
>>> torch.bernoulli(a)
1 1 1
0 0 1
0 1 1
[torch.FloatTensor of size 3x3]
>>> a = torch.ones(3, 3)  # probability of drawing "1" is 1
>>> torch.bernoulli(a)
1 1 1
1 1 1
1 1 1
[torch.FloatTensor of size 3x3]
>>> a = torch.zeros(3, 3)  # probability of drawing "1" is 0
>>> torch.bernoulli(a)
0 0 0
0 0 0
0 0 0
[torch.FloatTensor of size 3x3]
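Each output element of torch.bernoulli() is an independent draw that equals 1 with its own probability. The pure-Python sketch below models this on a flat list of probabilities using the standard random module. The name `bernoulli` and the `rng` parameter are hypothetical. The sketch does not reproduce torch's random stream.

```python
import random


def bernoulli(probs, rng=None):
    """Hypothetical model: element i is 1 with probability probs[i], else 0."""
    if rng is None:
        rng = random.Random()
    for p in probs:
        if not 0.0 <= p <= 1.0:
            raise ValueError("all probabilities must lie in [0, 1]")
    # rng.random() is uniform on [0, 1), so p == 1 always gives 1 and p == 0 never does.
    return [1 if rng.random() < p else 0 for p in probs]


print(bernoulli([1.0] * 9))  # all ones
print(bernoulli([0.0] * 9))  # all zeros
print(bernoulli([0.5] * 9, rng=random.Random(0)))  # a reproducible mix of 0s and 1s
```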
-
torch.multinomial(input, num_samples, replacement=False, out=None) → LongTensor
Returns a Tensor where each row contains num_samples indices sampled from the multinomial probability distribution located in the corresponding row of Tensor input.
Note
The rows of input do not need to sum to one (in which case we use the values as weights), but must be non-negative and have a non-zero sum.
Indices are ordered from left to right according to when each was sampled (first samples are placed in first column).
If input is a vector, out is a vector of size num_samples.
If input is a matrix with m rows, out is a matrix of shape m × num_samples.
If replacement is True, samples are drawn with replacement.
If not, they are drawn without replacement, which means that when a sample index is drawn for a row, it cannot be drawn again for that row.
This implies the constraint that num_samples must be no greater than the input length (or the number of columns of input if it is a matrix).
Parameters: - input (Tensor) – the Tensor containing probabilities or weights
- num_samples (int) – the number of samples to draw
- replacement (bool, optional) – whether to draw with replacement or not
- out (LongTensor, optional) – the result Tensor
Example:
>>> weights = torch.Tensor([0, 10, 3, 0])  # create a Tensor of weights
>>> torch.multinomial(weights, 4)
1 2 0 0
[torch.LongTensor of size 4]
>>> torch.multinomial(weights, 4, replacement=True)
1 2 1 2
[torch.LongTensor of size 4]
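Sampling without replacement can be modelled as repeated weighted draws, where each drawn index has its weight set to zero before the next draw. The pure-Python sketch below does that with random.Random.choices, which normalises the weights itself. The function name and the `rng` parameter are hypothetical. Unlike the first torch example above, this sketch raises an error once the non-zero weights are exhausted, instead of returning further indices. The library's behaviour in that case may differ.

```python
import random


def multinomial(weights, num_samples, replacement=False, rng=None):
    """Hypothetical model of sampling one row of torch.multinomial."""
    if rng is None:
        rng = random.Random()
    if any(w < 0 for w in weights) or sum(weights) <= 0:
        raise ValueError("weights must be non-negative with a non-zero sum")
    weights = list(weights)
    samples = []
    for _ in range(num_samples):
        if sum(weights) <= 0:
            raise ValueError("not enough non-zero weights to sample without replacement")
        # choices() treats the weights as relative, so they need not sum to one.
        idx = rng.choices(range(len(weights)), weights=weights)[0]
        samples.append(idx)  # samples are kept in the order they were drawn
        if not replacement:
            weights[idx] = 0  # a drawn index cannot be drawn again
    return samples


print(multinomial([0, 10, 3, 0], 2, rng=random.Random(0)))  # indices 1 and 2 in some order
print(multinomial([0, 10, 3, 0], 4, replacement=True, rng=random.Random(1)))
```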
-
torch.normal()
-
torch.normal(means, std, out=None)
Returns a Tensor of random numbers drawn from separate normal distributions whose means and standard deviations are given.
The means is a Tensor with the mean of each output element’s normal distribution.
The std is a Tensor with the standard deviation of each output element’s normal distribution.
The shapes of means and std don’t need to match. The total number of elements in each Tensor needs to be the same.
Note
When the shapes do not match, the shape of means is used as the shape for the returned output Tensor.
Parameters: - means (Tensor) – the Tensor of per-element means
- std (Tensor) – the Tensor of per-element standard deviations
- out (Tensor, optional) – the result Tensor
Example:
>>> torch.normal(means=torch.arange(1, 11), std=torch.arange(1, 0, -0.1))
1.5104 1.6955 2.4895 4.9185 4.9895 6.9155 7.3683 8.1836 8.7164 9.8916
[torch.FloatTensor of size 10]
-
torch.normal(mean=0.0, std, out=None)
Similar to the function above, but the means are shared among all drawn elements.
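Parameters: - mean (float, optional) – the mean for all distributions
- std (Tensor) – the Tensor of per-element standard deviations
- out (Tensor, optional) – the result Tensor
Example:
>>> torch.normal(mean=0.5, std=torch.arange(1, 6))
0.5723 0.0871 -0.3783 -2.5689 10.7893
[torch.FloatTensor of size 5]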
-
torch.normal(means, std=1.0, out=None)
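Similar to the function above, but the standard deviations are shared among all drawn elements.
Parameters: - means (Tensor) – the Tensor of per-element means
- std (float, optional) – the standard deviation for all distributions
- out (Tensor, optional) – the result Tensor
Example:
>>> torch.normal(means=torch.arange(1, 6))
1.1681 2.8884 3.7718 2.5616 4.2500
[torch.FloatTensor of size 5]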
Serialization
-
torch.save(obj, f, pickle_module=pickle, pickle_protocol=2)
Saves an object to a disk file.
See also: Recommended approach for saving a model
Parameters: - obj – saved object
- f – a file-like object (has to implement fileno that returns a file descriptor) or a string containing a file name
- pickle_module – module used for pickling metadata and objects
- pickle_protocol – can be specified to override the default protocol
-
torch.load(f, map_location=None, pickle_module=pickle)
Loads an object saved with torch.save() from a file.
torch.load can dynamically remap storages to be loaded on a different device using the map_location argument. If it’s a callable, it will be called with two arguments: storage and location tag. It’s expected to either return a storage that’s been moved to a different location, or None (and the location will be resolved using the default method). If this argument is a dict, it’s expected to be a mapping from location tags used in a file to location tags of the current system.
By default the location tags are ‘cpu’ for host tensors and ‘cuda:device_id’ (e.g. ‘cuda:2’) for cuda tensors. User extensions can register their own tagging and deserialization methods using register_package.
Parameters: - f – a file-like object (has to implement fileno that returns a file descriptor, and must implement seek), or a string containing a file name
- map_location – a function or a dict specifying how to remap storage locations
- pickle_module – module used for unpickling metadata and objects (has to match the pickle_module used to serialize file)
Example:
>>> torch.load('tensors.pt')
# Load all tensors onto the CPU
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
# Map tensors from GPU 1 to GPU 0
>>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
Parallelism
-
torch.get_num_threads() → int
Gets the number of OpenMP threads used for parallelizing CPU operations.
-
torch.set_num_threads(int)
Sets the number of OpenMP threads used for parallelizing CPU operations.
Math operations
Pointwise Ops
-
torch.abs(input, out=None) → Tensor
Computes the element-wise absolute value of the given input tensor.
Example:
>>> torch.abs(torch.FloatTensor([-1, -2, 3]))
FloatTensor([1, 2, 3])
-
torch.acos(input, out=None) → Tensor
Returns a new Tensor with the arccosine of the elements of input.
Parameters: - input (Tensor) – the input Tensor
- out (Tensor, optional) – the result Tensor
Example:
>>> a = torch.randn(4)
>>> a
-0.6366 0.2718 0.4469 1.3122
[torch.FloatTensor of size 4]
>>> torch.acos(a)
2.2608 1.2956 1.1075 nan
[torch.FloatTensor of size 4]
-
torch.add()
-
torch.add(input, value, out=None)
Adds the scalar value to each element of the input input and returns a new resulting tensor.
\(out = input + value\)
If input is of type FloatTensor or DoubleTensor, value must be a real number, otherwise it should be an integer.
Parameters: - input (Tensor) – the input Tensor
- value (Number) – the number to be added to each element of input
- out (Tensor, optional) – the result Tensor
Example:
>>> a = torch.randn(4)
>>> a
0.4050 -1.2227 1.8688 -0.4185
[torch.FloatTensor of size 4]
>>> torch.add(a, 20)
20.4050 18.7773 21.8688 19.5815
[torch.FloatTensor of size 4]
-
torch.add(input, value=1, other, out=None)
Each element of the Tensor other is multiplied by the scalar value and added to each element of the Tensor input. The resulting Tensor is returned.
The shapes of input and other don’t need to match. The total number of elements in each Tensor needs to be the same.
Note
When the shapes do not match, the shape of input is used as the shape for the returned output Tensor.
\(out = input + (other * value)\)
If other is of type FloatTensor or DoubleTensor, value must be a real number, otherwise it should be an integer.
Parameters: - input (Tensor) – the first input Tensor
- value (Number) – the scalar multiplier for other
- other (Tensor) – the second input Tensor
- out (Tensor, optional) – the result Tensor
Example:
>>> import torch
>>> a = torch.randn(4)
>>> a
-0.9310 2.0330 0.0852 -0.2941
[torch.FloatTensor of size 4]
>>> b = torch.randn(2, 2)
>>> b
1.0663 0.2544
-0.1513 0.0749
[torch.FloatTensor of size 2x2]
>>> torch.add(a, 10, b)
9.7322 4.5770 -1.4279 0.4552
[torch.FloatTensor of size 4]
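The three-argument form of torch.add() computes out = input + (other * value) element by element. When the shapes differ, it pairs the elements in flattened order. The pure-Python sketch below applies the formula to the values printed in the example, with b flattened row by row. The helper name `add_scaled` is hypothetical.

```python
def add_scaled(inp, other, value=1):
    """Hypothetical model: out_i = input_i + value * other_i over flattened elements."""
    if len(inp) != len(other):
        raise ValueError("input and other must have the same number of elements")
    return [a + value * b for a, b in zip(inp, other)]


a = [-0.9310, 2.0330, 0.0852, -0.2941]
b = [1.0663, 0.2544, -0.1513, 0.0749]  # the 2x2 tensor b above, flattened row by row
print([round(v, 4) for v in add_scaled(a, b, 10)])
# [9.732, 4.577, -1.4278, 0.4549]
```

These results agree with the printed output above to within the rounding of the displayed inputs.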
-
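torch.addcdiv(tensor, value=1, tensor1, tensor2, out=None) → Tensor
Performs the element-wise division of tensor1 by tensor2, multiplies the result by the scalar value and adds it to tensor.
\(out_i = tensor_i + value * tensor1_i / tensor2_i\)
The number of elements must match, but sizes do not matter.
For inputs of type FloatTensor or DoubleTensor, value must be a real number, otherwise an integer.
Parameters: - tensor (Tensor) – the tensor to be added
- value (Number, optional) – the multiplier for tensor1 / tensor2
- tensor1 (Tensor) – the numerator Tensor
- tensor2 (Tensor) – the denominator Tensor
- out (Tensor, optional) – the result Tensor
Example:
>>> t = torch.randn(2, 3)
>>> t1 = torch.randn(1, 6)
>>> t2 = torch.randn(6, 1)
>>> torch.addcdiv(t, 0.1, t1, t2)
0.0122 -0.0188 -0.2354
0.7396 -1.5721 1.2878
[torch.FloatTensor of size 2x3]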
-
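torch.addcmul(tensor, value=1, tensor1, tensor2, out=None) → Tensor
Performs the element-wise multiplication of tensor1 by tensor2, multiplies the result by the scalar value and adds it to tensor.
\(out_i = tensor_i + value * tensor1_i * tensor2_i\)
The number of elements must match, but sizes do not matter.
For inputs of type FloatTensor or DoubleTensor, value must be a real number, otherwise an integer.
Parameters: - tensor (Tensor) – the tensor to be added
- value (Number, optional) – the multiplier for tensor1 * tensor2
- tensor1 (Tensor) – the Tensor to be multiplied
- tensor2 (Tensor) – the Tensor to be multiplied
- out (Tensor, optional) – the result Tensor
Example:
>>> t = torch.randn(2, 3)
>>> t1 = torch.randn(1, 6)
>>> t2 = torch.randn(6, 1)
>>> torch.addcmul(t, 0.1, t1, t2)
0.0122 -0.0188 -0.2354
0.7396 -1.5721 1.2878
[torch.FloatTensor of size 2x3]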
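torch.addcdiv() and torch.addcmul() differ only in whether tensor1 and tensor2 are divided or multiplied before scaling by value and adding to tensor. The pure-Python sketch below works on flat lists, since only element counts must match. It uses small exact inputs so the arithmetic is easy to check by hand. Both functions are hypothetical stand-ins.

```python
def addcmul(tensor, value, t1, t2):
    """Hypothetical model: out_i = tensor_i + value * t1_i * t2_i."""
    if not len(tensor) == len(t1) == len(t2):
        raise ValueError("all inputs must have the same number of elements")
    return [t + value * a * b for t, a, b in zip(tensor, t1, t2)]


def addcdiv(tensor, value, t1, t2):
    """Hypothetical model: out_i = tensor_i + value * t1_i / t2_i."""
    if not len(tensor) == len(t1) == len(t2):
        raise ValueError("all inputs must have the same number of elements")
    return [t + value * a / b for t, a, b in zip(tensor, t1, t2)]


print(addcmul([1, 1], 0.5, [2, 4], [3, 1]))  # [4.0, 3.0]
print(addcdiv([1, 1], 0.5, [2, 4], [4, 1]))  # [1.25, 3.0]
```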
-
torch.asin(input, out=None) → Tensor
Returns a new Tensor with the arcsine of the elements of input.
Parameters: - input (Tensor) – the input Tensor
- out (Tensor, optional) – the result Tensor
Example:
>>> a = torch.randn(4)
>>> a
-0.6366 0.2718 0.4469 1.3122
[torch.FloatTensor of size 4]
>>> torch.asin(a)
-0.6900 0.2752 0.4633 nan
[torch.FloatTensor of size 4]
-
torch.atan(input, out=None) → Tensor
Returns a new Tensor with the arctangent of the elements of input.
Parameters: - input (Tensor) – the input Tensor
- out (Tensor, optional) – the result Tensor
Example:
>>> a = torch.randn(4)
>>> a
-0.6366 0.2718 0.4469 1.3122
[torch.FloatTensor of size 4]
>>> torch.atan(a)
-0.5669 0.2653 0.4203 0.9196
[torch.FloatTensor of size 4]
-
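torch.atan2(input1, input2, out=None) → Tensor
Returns a new Tensor with the element-wise arctangent of input1 / input2, using the signs of both arguments to determine the quadrant.
Parameters: - input1 (Tensor) – the first input Tensor
- input2 (Tensor) – the second input Tensor
- out (Tensor, optional) – the result Tensor
Example:
>>> a = torch.randn(4)
>>> a
-0.6366 0.2718 0.4469 1.3122
[torch.FloatTensor of size 4]
>>> torch.atan2(a, torch.randn(4))
-2.4167 2.9755 0.9363 1.6613
[torch.FloatTensor of size 4]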
-
torch.ceil(input, out=None) → Tensor
Returns a new Tensor with the ceil of the elements of input, the smallest integer greater than or equal to each element.
Parameters: - input (Tensor) – the input Tensor
- out (Tensor, optional) – the result Tensor
Example:
>>> a = torch.randn(4)
>>> a
1.3869 0.3912 -0.8634 -0.5468
[torch.FloatTensor of size 4]
>>> torch.ceil(a)
2 1 -0 -0
[torch.FloatTensor of size 4]
-
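torch.clamp(input, min, max, out=None) → Tensor
Clamps all elements in input into the range [min, max] and returns a resulting Tensor:
\(y_i = \begin{cases} min & \text{if } x_i < min \\ x_i & \text{if } min \le x_i \le max \\ max & \text{if } x_i > max \end{cases}\)
If input is of type FloatTensor or DoubleTensor, args min and max must be real numbers, otherwise they should be integers.
Parameters: - input (Tensor) – the input Tensor
- min (Number) – the lower bound of the range to be clamped to
- max (Number) – the upper bound of the range to be clamped to
- out (Tensor, optional) – the output Tensor
Example:
>>> a = torch.randn(4)
>>> a
1.3869 0.3912 -0.8634 -0.5468
[torch.FloatTensor of size 4]
>>> torch.clamp(a, min=-0.5, max=0.5)
0.5000 0.3912 -0.5000 -0.5000
[torch.FloatTensor of size 4]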
-
torch.clamp(input, *, min, out=None) → Tensor
Clamps all elements in input to be larger than or equal to min.
If input is of type FloatTensor or DoubleTensor, min should be a real number, otherwise it should be an integer.
Parameters: - input (Tensor) – the input Tensor
- min (Number) – the minimal value of each element in the output
- out (Tensor, optional) – the result Tensor
Example:
>>> a = torch.randn(4)
>>> a
1.3869 0.3912 -0.8634 -0.5468
[torch.FloatTensor of size 4]
>>> torch.clamp(a, min=0.5)
1.3869 0.5000 0.5000 0.5000
[torch.FloatTensor of size 4]
-
torch.clamp(input, *, max, out=None) → Tensor
Clamps all elements in input to be smaller than or equal to max.
If input is of type FloatTensor or DoubleTensor, max should be a real number, otherwise it should be an integer.
Parameters: - input (Tensor) – the input Tensor
- max (Number) – the maximal value of each element in the output
- out (Tensor, optional) – the result Tensor
Example:
>>> a = torch.randn(4)
>>> a
1.3869 0.3912 -0.8634 -0.5468
[torch.FloatTensor of size 4]
>>> torch.clamp(a, max=0.5)
0.5000 0.3912 -0.8634 -0.5468
[torch.FloatTensor of size 4]
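All three torch.clamp() forms follow the same piecewise rule: raise values below min, lower values above max, and keep the rest. A bound that is not given is not applied. The pure-Python sketch below uses lo/hi for torch's min/max keywords, to avoid shadowing Python built-ins. `clamp` is a hypothetical helper that reproduces the three examples above.

```python
def clamp(values, lo=None, hi=None):
    """Hypothetical model of torch.clamp; lo/hi play the role of min/max."""
    out = []
    for x in values:
        if lo is not None and x < lo:
            x = lo  # below the range: replaced by min
        if hi is not None and x > hi:
            x = hi  # above the range: replaced by max
        out.append(x)
    return out


a = [1.3869, 0.3912, -0.8634, -0.5468]
print(clamp(a, lo=-0.5, hi=0.5))  # [0.5, 0.3912, -0.5, -0.5]
print(clamp(a, lo=0.5))           # [1.3869, 0.5, 0.5, 0.5]
print(clamp(a, hi=0.5))           # [0.5, 0.3912, -0.8634, -0.5468]
```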
-
torch.cos(input, out=None) → Tensor
Returns a new Tensor with the cosine of the elements of input.
Parameters: - input (Tensor) – the input Tensor
- out (Tensor, optional) – the result Tensor
Example:
>>> a = torch.randn(4)
>>> a
-0.6366 0.2718 0.4469 1.3122
[torch.FloatTensor of size 4]
>>> torch.cos(a)
0.8041 0.9633 0.9018 0.2557
[torch.FloatTensor of size 4]
-
torch.cosh(input, out=None) → Tensor
Returns a new Tensor with the hyperbolic cosine of the elements of input.
Parameters: - input (Tensor) – the input Tensor
- out (Tensor, optional) – the result Tensor
Example:
>>> a = torch.randn(4)
>>> a
-0.6366 0.2718 0.4469 1.3122
[torch.FloatTensor of size 4]
>>> torch.cosh(a)
1.2095 1.0372 1.1015 1.9917
[torch.FloatTensor of size 4]
-
torch.div()
-
torch.div(input, value, out=None)
Divides each element of the input input by the scalar value and returns a new resulting tensor.
\(out = input / value\)
If input is of type FloatTensor or DoubleTensor, value should be a real number, otherwise it should be an integer.
Parameters: - input (Tensor) – the input Tensor
- value (Number) – the number to be divided into each element of input
- out (Tensor, optional) – the result Tensor
Example:
>>> a = torch.randn(5)
>>> a
-0.6147 -1.1237 -0.1604 -0.6853 0.1063
[torch.FloatTensor of size 5]
>>> torch.div(a, 0.5)
-1.2294 -2.2474 -0.3208 -1.3706 0.2126
[torch.FloatTensor of size 5]
-
torch.div(input, other, out=None)
Each element of the Tensor input is divided by the corresponding element of the Tensor other. The resulting Tensor is returned. The shapes of input and other don’t need to match. The total number of elements in each Tensor needs to be the same.
Note
When the shapes do not match, the shape of input is used as the shape for the returned output Tensor.
\(out_i = input_i / other_i\)
Parameters: - input (Tensor) – the numerator Tensor
- other (Tensor) – the denominator Tensor
- out (Tensor, optional) – the result Tensor
Example:
>>> a = torch.randn(4,4)
>>> a
-0.1810 0.4017 0.2863 -0.1013
0.6183 2.0696 0.9012 -1.5933
0.5679 0.4743 -0.0117 -0.1266
-0.1213 0.9629 0.2682 1.5968
[torch.FloatTensor of size 4x4]
>>> b = torch.randn(8, 2)
>>> b
0.8774 0.7650
0.8866 1.4805
-0.6490 1.1172
1.4259 -0.8146
1.4633 -0.1228
0.4643 -0.6029
0.3492 1.5270
1.6103 -0.6291
[torch.FloatTensor of size 8x2]
>>> torch.div(a, b)
-0.2062 0.5251 0.3229 -0.0684
-0.9528 1.8525 0.6320 1.9559
0.3881 -3.8625 -0.0253 0.2099
-0.3473 0.6306 0.1666 -2.5381
[torch.FloatTensor of size 4x4]
-
torch.exp(tensor, out=None) → Tensor
Computes the exponential of each element.
Example:
>>> import math
>>> torch.exp(torch.Tensor([0, math.log(2)]))
torch.FloatTensor([1, 2])
-
torch.floor(input, out=None) → Tensor
Returns a new Tensor with the floor of the elements of input, the largest integer less than or equal to each element.
Parameters: - input (Tensor) – the input Tensor
- out (Tensor, optional) – the result Tensor
Example:
>>> a = torch.randn(4)
>>> a
1.3869 0.3912 -0.8634 -0.5468
[torch.FloatTensor of size 4]
>>> torch.floor(a)
1 0 -1 -1
[torch.FloatTensor of size 4]
-
torch.fmod(input, divisor, out=None) → Tensor
Computes the element-wise remainder of division.
The dividend and divisor may contain both integer and floating point numbers. The remainder has the same sign as the dividend tensor.
Parameters: - input (Tensor) – the dividend
- divisor (Tensor or float) – the divisor, either a number or a Tensor with the same number of elements as the dividend
- out (Tensor, optional) – the output Tensor
Example:
>>> torch.fmod(torch.Tensor([-3, -2, -1, 1, 2, 3]), 2)
torch.FloatTensor([-1, -0, -1, 1, 0, 1])
>>> torch.fmod(torch.Tensor([1, 2, 3, 4, 5]), 1.5)
torch.FloatTensor([1.0, 0.5, 0.0, 1.0, 0.5])
See also
torch.remainder(), which computes the element-wise remainder of division equivalently to Python’s % operator
-
torch.frac(tensor, out=None) → Tensor
Computes the fractional portion of each element in tensor.
Example:
>>> torch.frac(torch.Tensor([1, 2.5, -3.2]))
torch.FloatTensor([0, 0.5, -0.2])
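The example's -0.2 for -3.2 shows that the fractional part keeps the sign of the input. That is the behaviour of x minus x truncated toward zero, rather than x minus floor(x). The pure-Python sketch below assumes that definition. `frac` is a hypothetical scalar helper.

```python
import math


def frac(x):
    """Hypothetical model: x minus x truncated toward zero (sign follows x)."""
    return x - math.trunc(x)


print([round(frac(v), 4) for v in (1, 2.5, -3.2)])  # [0, 0.5, -0.2]
```

Rounding is used only to hide binary floating-point noise: -3.2 - (-3) is not exactly -0.2 in floating point.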
-
torch.lerp(start, end, weight, out=None)
Does a linear interpolation of two tensors start and end based on a scalar weight and returns the resulting out Tensor.
\(out_i = start_i + weight * (end_i - start_i)\)
Parameters: - start (Tensor) – the Tensor with the starting points
- end (Tensor) – the Tensor with the ending points
- weight (float) – the weight for the interpolation formula
- out (Tensor, optional) – the result Tensor
Example:
>>> start = torch.arange(1, 5)
>>> end = torch.Tensor(4).fill_(10)
>>> start
1 2 3 4
[torch.FloatTensor of size 4]
>>> end
10 10 10 10
[torch.FloatTensor of size 4]
>>> torch.lerp(start, end, 0.5)
5.5000 6.0000 6.5000 7.0000
[torch.FloatTensor of size 4]
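The interpolation formula \(out_i = start_i + weight * (end_i - start_i)\) applies independently to each element. The pure-Python sketch below evaluates it on the example's values. `lerp` is a hypothetical list-based helper. A weight of 0 returns start and a weight of 1 returns end.

```python
def lerp(start, end, weight):
    """Hypothetical model: out_i = start_i + weight * (end_i - start_i)."""
    if len(start) != len(end):
        raise ValueError("start and end must have the same number of elements")
    return [s + weight * (e - s) for s, e in zip(start, end)]


print(lerp([1, 2, 3, 4], [10, 10, 10, 10], 0.5))  # [5.5, 6.0, 6.5, 7.0]
```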
-
torch.log(input, out=None) → Tensor
Returns a new Tensor with the natural logarithm of the elements of input.
Parameters: - input (Tensor) – the input Tensor
- out (Tensor, optional) – the result Tensor
Example:
>>> a = torch.randn(5)
>>> a
-0.4183 0.3722 -0.3091 0.4149 0.5857
[torch.FloatTensor of size 5]
>>> torch.log(a)
nan -0.9883 nan -0.8797 -0.5349
[torch.FloatTensor of size 5]
-
torch.log1p(input, out=None) → Tensor
Returns a new Tensor with the natural logarithm of (1 + input).
\(y_i = log(x_i + 1)\)
Note
This function is more accurate than torch.log() for small values of input.
Parameters: - input (Tensor) – the input Tensor
- out (Tensor, optional) – the result Tensor
Example:
>>> a = torch.randn(5)
>>> a
-0.4183 0.3722 -0.3091 0.4149 0.5857
[torch.FloatTensor of size 5]
>>> torch.log1p(a)
-0.5418 0.3164 -0.3697 0.3471 0.4611
[torch.FloatTensor of size 5]
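The accuracy note exists because computing 1 + x first discards the digits of a tiny x before the logarithm is taken. Python's standard math module has the same pair of functions, so the effect can be shown without torch. This uses math.log and math.log1p only as an analogy for torch.log and torch.log1p.

```python
import math

x = 1e-20
# 1 + 1e-20 rounds to exactly 1.0 in double precision, so the information is lost.
print(math.log(1 + x))  # 0.0
# log1p evaluates log(1 + x) without forming 1 + x, keeping the small value.
print(math.log1p(x))    # 1e-20
```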
-
torch.mul()
-
torch.mul(input, value, out=None)
Multiplies each element of the input input by the scalar value and returns a new resulting tensor.
\(out = input * value\)
If input is of type FloatTensor or DoubleTensor, value should be a real number, otherwise it should be an integer.
Parameters: - input (Tensor) – the input Tensor
- value (Number) – the number to be multiplied with each element of input
- out (Tensor, optional) – the result Tensor
Example:
>>> a = torch.randn(3)
>>> a
-0.9374 -0.5254 -0.6069
[torch.FloatTensor of size 3]
>>> torch.mul(a, 100)
-93.7411 -52.5374 -60.6908
[torch.FloatTensor of size 3]
-
torch.mul(input, other, out=None)
Each element of the Tensor input is multiplied by the corresponding element of the Tensor other. The resulting Tensor is returned. The shapes of input and other don’t need to match. The total number of elements in each Tensor needs to be the same.
Note
When the shapes do not match, the shape of input is used as the shape for the returned output Tensor.
\(out_i = input_i * other_i\)
Parameters: - input (Tensor) – the first multiplicand Tensor
- other (Tensor) – the second multiplicand Tensor
- out (Tensor, optional) – the result Tensor
Example:
>>> a = torch.randn(4,4)
>>> a
-0.7280 0.0598 -1.4327 -0.5825
-0.1427 -0.0690 0.0821 -0.3270
-0.9241 0.5110 0.4070 -1.1188
-0.8308 0.7426 -0.6240 -1.1582
[torch.FloatTensor of size 4x4]
>>> b = torch.randn(2, 8)
>>> b
0.0430 -1.0775 0.6015 1.1647 -0.6549 0.0308 -0.1670 1.0742
-1.2593 0.0292 -0.0849 0.4530 1.2404 -0.4659 -0.1840 0.5974
[torch.FloatTensor of size 2x8]
>>> torch.mul(a, b)
-0.0313 -0.0645 -0.8618 -0.6784
0.0934 -0.0021 -0.0137 -0.3513
1.1638 0.0149 -0.0346 -0.5068
-1.0304 -0.3460 0.1148 -0.6919
[torch.FloatTensor of size 4x4]
-
torch.neg(input, out=None) → Tensor
Returns a new Tensor with the negative of the elements of input.
\(out = -1 * input\)
Parameters: - input (Tensor) – the input Tensor
- out (Tensor, optional) – the result Tensor
Example:
>>> a = torch.randn(5)
>>> a
-0.4430 1.1690 -0.8836 -0.4565 0.2968
[torch.FloatTensor of size 5]
>>> torch.neg(a)
0.4430 -1.1690 0.8836 0.4565 -0.2968
[torch.FloatTensor of size 5]
-
torch.pow()
-
torch.pow(input, exponent, out=None)
Takes the power of each element in input with exponent and returns a Tensor with the result.
exponent can be either a single float number or a Tensor with the same number of elements as input.
When exponent is a scalar value, the operation applied is:
\(out_i = x_i ^ {exponent}\)
When exponent is a Tensor, the operation applied is:
\(out_i = x_i ^ {exponent_i}\)
Parameters: - input (Tensor) – the input Tensor
- exponent (float or Tensor) – the exponent value
- out (Tensor, optional) – the result Tensor
Example:
>>> a = torch.randn(4)
>>> a
-0.5274 -0.8232 -2.1128 1.7558
[torch.FloatTensor of size 4]
>>> torch.pow(a, 2)
0.2781 0.6776 4.4640 3.0829
[torch.FloatTensor of size 4]
>>> exp = torch.arange(1, 5)
>>> a = torch.arange(1, 5)
>>> a
1 2 3 4
[torch.FloatTensor of size 4]
>>> exp
1 2 3 4
[torch.FloatTensor of size 4]
>>> torch.pow(a, exp)
1 4 27 256
[torch.FloatTensor of size 4]
-
torch.pow(base, input, out=None)
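base is a scalar float value, and input is a Tensor. The returned Tensor out is of the same shape as input.
The operation applied is:
\(out_i = base ^ {input_i}\)
Parameters: - base (float) – the scalar base value for the power operation
- input (Tensor) – the exponent Tensor
- out (Tensor, optional) – the result Tensor
Example:
>>> exp = torch.arange(1, 5)
>>> base = 2
>>> torch.pow(base, exp)
2 4 8 16
[torch.FloatTensor of size 4]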
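The three torch.pow() overloads correspond to the formulas \(x_i ^ {exponent}\), \(x_i ^ {exponent_i}\) and \(base ^ {input_i}\). The pure-Python sketch below implements each one on lists, under hypothetical names. One caveat: Python's `**` returns a complex number for a negative base with a fractional exponent, whereas a float Tensor would produce nan. The sketch does not model that case.

```python
def pow_scalar_exponent(xs, exponent):
    """Hypothetical: out_i = x_i ^ exponent."""
    return [x ** exponent for x in xs]


def pow_tensor_exponent(xs, exponents):
    """Hypothetical: out_i = x_i ^ exponent_i (same number of elements required)."""
    if len(xs) != len(exponents):
        raise ValueError("input and exponent must have the same number of elements")
    return [x ** e for x, e in zip(xs, exponents)]


def pow_scalar_base(base, exponents):
    """Hypothetical: out_i = base ^ input_i."""
    return [base ** e for e in exponents]


print(pow_tensor_exponent([1, 2, 3, 4], [1, 2, 3, 4]))  # [1, 4, 27, 256]
print(pow_scalar_base(2, [1, 2, 3, 4]))                 # [2, 4, 8, 16]
print(pow_scalar_exponent([3, -2], 2))                  # [9, 4]
```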
-
torch.reciprocal(input, out=None) → Tensor
Returns a new Tensor with the reciprocal of the elements of input, i.e. \(1.0 / x\).
Parameters: - input (Tensor) – the input Tensor
- out (Tensor, optional) – the result Tensor
Example:
>>> a = torch.randn(4)
>>> a
1.3869 0.3912 -0.8634 -0.5468
[torch.FloatTensor of size 4]
>>> torch.reciprocal(a)
0.7210 2.5565 -1.1583 -1.8289
[torch.FloatTensor of size 4]
-
torch.remainder(input, divisor, out=None) → Tensor¶ Computes the element-wise remainder of division.
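The divisor and dividend may contain both integer and floating point numbers. The remainder has the same sign as the divisor.
Parameters: - input (Tensor) – the dividend
- divisor (Tensor or float) – the divisor
- out (Tensor, optional) – the result Tensor
Example: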
>>> torch.remainder(torch.Tensor([-3, -2, -1, 1, 2, 3]), 2) torch.FloatTensor([1, 0, 1, 1, 0, 1]) >>> torch.remainder(torch.Tensor([1, 2, 3, 4, 5]), 1.5) torch.FloatTensor([1.0, 0.5, 0.0, 1.0, 0.5])
See also
torch.fmod(), which computes the element-wise remainder of division equivalently to the C library function fmod().
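To make the two sign conventions concrete, here is a plain-Python sketch (scalars in lists, not PyTorch code). Python's % operator gives a result with the sign of the divisor, like torch.remainder as described above, while math.fmod follows the C fmod() rule and keeps the sign of the dividend. The helper names remainder_list and fmod_list are hypothetical.

```python
import math


def remainder_list(values, divisor):
    """Element-wise remainder whose sign follows the divisor (Python's %)."""
    return [v % divisor for v in values]


def fmod_list(values, divisor):
    """Element-wise remainder whose sign follows the dividend (C fmod)."""
    return [math.fmod(v, divisor) for v in values]


data = [-3, -2, -1, 1, 2, 3]
print(remainder_list(data, 2))               # [1, 0, 1, 1, 0, 1]
print(fmod_list(data, 2))                    # [-1.0, -0.0, -1.0, 1.0, 0.0, 1.0]
print(remainder_list([1, 2, 3, 4, 5], 1.5))  # [1.0, 0.5, 0.0, 1.0, 0.5]
```

The first and third results reproduce the torch.remainder example; the second shows where fmod() differs for negative dividends.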
-
torch.round(input, out=None) → Tensor¶ Returns a new Tensor with each of the elements of input rounded to the closest integer.
Parameters: - input (Tensor) – the input Tensor
- out (Tensor, optional) – the result Tensor
Example:
>>> a = torch.randn(4) >>> a 1.2290 1.3409 -0.5662 -0.0899 [torch.FloatTensor of size 4] >>> torch.round(a) 1 1 -1 -0 [torch.FloatTensor of size 4]
-
torch.rsqrt(input, out=None) → Tensor¶ Returns a new Tensor with the reciprocal of the square-root of each of the elements of input.
Parameters: - input (Tensor) – the input Tensor
- out (Tensor, optional) – the result Tensor
Example:
>>> a = torch.randn(4) >>> a 1.2290 1.3409 -0.5662 -0.0899 [torch.FloatTensor of size 4] >>> torch.rsqrt(a) 0.9020 0.8636 nan nan [torch.FloatTensor of size 4]
-
torch.sigmoid(input, out=None) → Tensor¶ Returns a new Tensor with the sigmoid of the elements of input.
Parameters: - input (Tensor) – the input Tensor
- out (Tensor, optional) – the result Tensor
Example:
>>> a = torch.randn(4) >>> a -0.4972 1.3512 0.1056 -0.2650 [torch.FloatTensor of size 4] >>> torch.sigmoid(a) 0.3782 0.7943 0.5264 0.4341 [torch.FloatTensor of size 4]
-
torch.sign(input, out=None) → Tensor¶ Returns a new Tensor with the sign of the elements of input.
Parameters: - input (Tensor) – the input Tensor
- out (Tensor, optional) – the result Tensor
Example:
>>> a = torch.randn(4) >>> a -0.6366 0.2718 0.4469 1.3122 [torch.FloatTensor of size 4] >>> torch.sign(a) -1 1 1 1 [torch.FloatTensor of size 4]
-
torch.sin(input, out=None) → Tensor¶ Returns a new Tensor with the sine of the elements of input.
Parameters: - input (Tensor) – the input Tensor
- out (Tensor, optional) – the result Tensor
Example:
>>> a = torch.randn(4) >>> a -0.6366 0.2718 0.4469 1.3122 [torch.FloatTensor of size 4] >>> torch.sin(a) -0.5944 0.2684 0.4322 0.9667 [torch.FloatTensor of size 4]
-
torch.sinh(input, out=None) → Tensor¶ Returns a new Tensor with the hyperbolic sine of the elements of input.
Parameters: - input (Tensor) – the input Tensor
- out (Tensor, optional) – the result Tensor
Example:
>>> a = torch.randn(4) >>> a -0.6366 0.2718 0.4469 1.3122 [torch.FloatTensor of size 4] >>> torch.sinh(a) -0.6804 0.2751 0.4619 1.7225 [torch.FloatTensor of size 4]
-
torch.sqrt(input, out=None) → Tensor¶ Returns a new Tensor with the square-root of the elements of input.
Parameters: - input (Tensor) – the input Tensor
- out (Tensor, optional) – the result Tensor
Example:
>>> a = torch.randn(4) >>> a 1.2290 1.3409 -0.5662 -0.0899 [torch.FloatTensor of size 4] >>> torch.sqrt(a) 1.1086 1.1580 nan nan [torch.FloatTensor of size 4]
-
torch.tan(input, out=None) → Tensor¶ Returns a new Tensor with the tangent of the elements of input.
Parameters: - input (Tensor) – the input Tensor
- out (Tensor, optional) – the result Tensor
Example:
>>> a = torch.randn(4) >>> a -0.6366 0.2718 0.4469 1.3122 [torch.FloatTensor of size 4] >>> torch.tan(a) -0.7392 0.2786 0.4792 3.7801 [torch.FloatTensor of size 4]
-
torch.tanh(input, out=None) → Tensor¶ Returns a new Tensor with the hyperbolic tangent of the elements of input.
Parameters: - input (Tensor) – the input Tensor
- out (Tensor, optional) – the result Tensor
Example:
>>> a = torch.randn(4) >>> a -0.6366 0.2718 0.4469 1.3122 [torch.FloatTensor of size 4] >>> torch.tanh(a) -0.5625 0.2653 0.4193 0.8648 [torch.FloatTensor of size 4]
-
torch.trunc(input, out=None) → Tensor¶ Returns a new Tensor with the truncated integer values of the elements of input.
Parameters: - input (Tensor) – the input Tensor
- out (Tensor, optional) – the result Tensor
Example:
>>> a = torch.randn(4) >>> a -0.4972 1.3512 0.1056 -0.2650 [torch.FloatTensor of size 4] >>> torch.trunc(a) -0 1 0 -0 [torch.FloatTensor of size 4]
Reduction Ops¶
-
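torch.cumprod(input, dim, out=None) → Tensor¶ Returns the cumulative product of elements of input in the dimension dim.
For example, if input is a vector of size N, the result will also be a vector of size N, with elements: \(y_i = x_1 * x_2 * x_3 * ... * x_i\)
Parameters: - input (Tensor) – the input Tensor
- dim (int) – the dimension to do the operation over
- out (Tensor, optional) – the result Tensor
Example: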
>>> a = torch.randn(10) >>> a 1.1148 1.8423 1.4143 -0.4403 1.2859 -1.2514 -0.4748 1.1735 -1.6332 -0.4272 [torch.FloatTensor of size 10] >>> torch.cumprod(a, dim=0) 1.1148 2.0537 2.9045 -1.2788 -1.6444 2.0578 -0.9770 -1.1466 1.8726 -0.8000 [torch.FloatTensor of size 10] >>> a[5] = 0.0 >>> torch.cumprod(a, dim=0) 1.1148 2.0537 2.9045 -1.2788 -1.6444 -0.0000 0.0000 0.0000 -0.0000 0.0000 [torch.FloatTensor of size 10]
-
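torch.cumsum(input, dim, out=None) → Tensor¶ Returns the cumulative sum of elements of input in the dimension dim.
For example, if input is a vector of size N, the result will also be a vector of size N, with elements: \(y_i = x_1 + x_2 + x_3 + ... + x_i\)
Parameters: - input (Tensor) – the input Tensor
- dim (int) – the dimension to do the operation over
- out (Tensor, optional) – the result Tensor
Example: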
>>> a = torch.randn(10) >>> a -0.6039 -0.2214 -0.3705 -0.0169 1.3415 -0.1230 0.9719 0.6081 -0.1286 1.0947 [torch.FloatTensor of size 10] >>> torch.cumsum(a, dim=0) -0.6039 -0.8253 -1.1958 -1.2127 0.1288 0.0058 0.9777 1.5858 1.4572 2.5519 [torch.FloatTensor of size 10]
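For a 1D input, both definitions are running reductions. The plain-Python sketch below illustrates the two formulas with itertools.accumulate; it is not how PyTorch computes them. It also shows why a single zero makes every later cumulative product zero, as in the cumprod example above. The helpers cumsum_1d and cumprod_1d are hypothetical.

```python
import operator
from itertools import accumulate


def cumsum_1d(xs):
    """y_i = x_1 + x_2 + ... + x_i for a plain Python list."""
    return list(accumulate(xs, operator.add))


def cumprod_1d(xs):
    """y_i = x_1 * x_2 * ... * x_i for a plain Python list."""
    return list(accumulate(xs, operator.mul))


print(cumsum_1d([1, 2, 3, 4]))      # [1, 3, 6, 10]
print(cumprod_1d([1, 2, 3, 4]))     # [1, 2, 6, 24]
print(cumprod_1d([2, 3, 0, 5, 7]))  # [2, 6, 0, 0, 0]
```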
-
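torch.dist(input, other, p=2, out=None) → Tensor¶ Returns the p-norm of (input - other).
Parameters: - input (Tensor) – the input Tensor
- other (Tensor) – the right-hand-side input Tensor
- p (float, optional) – the norm to be computed
- out (Tensor, optional) – the result Tensor
Example: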
>>> x = torch.randn(4) >>> x 0.2505 -0.4571 -0.3733 0.7807 [torch.FloatTensor of size 4] >>> y = torch.randn(4) >>> y 0.7782 -0.5185 1.4106 -2.4063 [torch.FloatTensor of size 4] >>> torch.dist(x, y, 3.5) 3.302832063224223 >>> torch.dist(x, y, 3) 3.3677282206393286 >>> torch.dist(x, y, 0) inf >>> torch.dist(x, y, 1) 5.560028076171875
-
torch.mean()¶ -
torch.mean(input) → float
Returns the mean value of all elements in the input Tensor.
Parameters: input (Tensor) – the input Tensor
Example:
>>> a = torch.randn(1, 3) >>> a -0.2946 -0.9143 2.1809 [torch.FloatTensor of size 1x3] >>> torch.mean(a) 0.32398951053619385
-
torch.mean(input, dim, out=None) → Tensor
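Returns the mean value of each row of the input Tensor in the given dimension dim.
The output Tensor is of the same size as input except in the dimension dim, where it is of size 1.
Parameters: - input (Tensor) – the input Tensor
- dim (int) – the dimension to reduce
- out (Tensor, optional) – the result Tensor
Example: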
>>> a = torch.randn(4, 4) >>> a -1.2738 -0.3058 0.1230 -1.9615 0.8771 -0.5430 -0.9233 0.9879 1.4107 0.0317 -0.6823 0.2255 -1.3854 0.4953 -0.2160 0.2435 [torch.FloatTensor of size 4x4] >>> torch.mean(a, 1) -0.8545 0.0997 0.2464 -0.2157 [torch.FloatTensor of size 4x1]
-
torch.median(input, dim=-1, values=None, indices=None) -> (Tensor, LongTensor)¶ Returns the median value of each row of the input Tensor in the given dimension dim. Also returns the index location of the median value as a LongTensor.
By default, dim is the last dimension of the input Tensor.
The output Tensors are of the same size as input except in the dimension dim, where they are of size 1.
Note
This function is not defined for torch.cuda.Tensor yet.
Parameters: - input (Tensor) – the input Tensor
- dim (int, optional) – the dimension to reduce
- values (Tensor, optional) – the result Tensor for the median values
- indices (LongTensor, optional) – the result LongTensor for the median indices
Example:
>>> a = torch.randn(4, 5) >>> a 0.4056 -0.3372 1.0973 -2.4884 0.4334 2.1336 0.3841 0.1404 -0.1821 -0.7646 -0.2403 1.3975 -2.0068 0.1298 0.0212 -1.5371 -0.7257 -0.4871 -0.2359 -1.1724 [torch.FloatTensor of size 4x5] >>> torch.median(a, 1) ( 0.4056 0.1404 0.0212 -0.7257 [torch.FloatTensor of size 4x1] , 0 2 4 1 [torch.LongTensor of size 4x1] )
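Here is a plain-Python sketch of the per-row semantics for a 2D list reduced along its last dimension. Each row yields a (value, index) pair, and the reduced dimension is kept with size 1. For rows of even length the sketch picks the lower of the two middle elements. That tie rule is an assumption, not documented behaviour. median_last_dim is a hypothetical helper.

```python
def median_last_dim(rows):
    """Return (values, indices), each shaped len(rows) x 1, like median(a, 1)."""
    values, indices = [], []
    for row in rows:
        # Order the positions by value, then take the middle position
        # (the lower middle one when the row length is even).
        order = sorted(range(len(row)), key=lambda j: row[j])
        mid = order[(len(row) - 1) // 2]
        values.append([row[mid]])
        indices.append([mid])
    return values, indices


a = [
    [0.4056, -0.3372, 1.0973, -2.4884, 0.4334],
    [2.1336, 0.3841, 0.1404, -0.1821, -0.7646],
]
vals, idx = median_last_dim(a)
print(vals)  # [[0.4056], [0.1404]]
print(idx)   # [[0], [2]]
```

These two rows are taken from the example above and reproduce its first two medians and indices.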
-
torch.mode(input, dim=-1, values=None, indices=None) -> (Tensor, LongTensor)¶ Returns the mode value of each row of the input Tensor in the given dimension dim. Also returns the index location of the mode value as a LongTensor.
By default, dim is the last dimension of the input Tensor.
The output Tensors are of the same size as input except in the dimension dim, where they are of size 1.
Note
This function is not defined for torch.cuda.Tensor yet.
Parameters: - input (Tensor) – the input Tensor
- dim (int, optional) – the dimension to reduce
- values (Tensor, optional) – the result Tensor for the mode values
- indices (LongTensor, optional) – the result LongTensor for the mode indices
Example:
>>> a = torch.randn(4, 5) >>> a 0.4056 -0.3372 1.0973 -2.4884 0.4334 2.1336 0.3841 0.1404 -0.1821 -0.7646 -0.2403 1.3975 -2.0068 0.1298 0.0212 -1.5371 -0.7257 -0.4871 -0.2359 -1.1724 [torch.FloatTensor of size 4x5] >>> torch.mode(a, 1) ( -2.4884 -0.7646 -2.0068 -1.5371 [torch.FloatTensor of size 4x1] , 3 4 2 0 [torch.LongTensor of size 4x1] )
-
torch.norm()¶ -
torch.norm(input, p=2) → float
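Returns the p-norm of the input Tensor.
Parameters: - input (Tensor) – the input Tensor
- p (float, optional) – the exponent value in the norm formulation
Example: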
>>> a = torch.randn(1, 3) >>> a -0.4376 -0.5328 0.9547 [torch.FloatTensor of size 1x3] >>> torch.norm(a, 3) 1.0338925067372466
-
torch.norm(input, p, dim, out=None) → Tensor
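Returns the p-norm of each row of the input Tensor in the given dimension dim.
The output Tensor is of the same size as input except in the dimension dim, where it is of size 1.
Parameters: - input (Tensor) – the input Tensor
- p (float) – the exponent value in the norm formulation
- dim (int) – the dimension to reduce
- out (Tensor, optional) – the result Tensor
Example: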
>>> a = torch.randn(4, 2) >>> a -0.6891 -0.6662 0.2697 0.7412 0.5254 -0.7402 0.5528 -0.2399 [torch.FloatTensor of size 4x2] >>> torch.norm(a, 2, 1) 0.9585 0.7888 0.9077 0.6026 [torch.FloatTensor of size 4x1] >>> torch.norm(a, 0, 1) 2 2 2 2 [torch.FloatTensor of size 4x1]
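For p > 0 the p-norm is \((\sum_i |x_i|^p)^{1/p}\). dist() uses the same norm on input - other. The p = 0 call in the norm example above instead counts the non-zero entries. The plain-Python sketch below covers both cases. p_norm is a hypothetical helper, and its p = 0 branch simply mirrors the example output.

```python
def p_norm(xs, p):
    """p-norm of a flat list; p == 0 counts the non-zero entries."""
    if p == 0:
        return float(sum(1 for x in xs if x != 0))
    return sum(abs(x) ** p for x in xs) ** (1.0 / p)


row = [-0.6891, -0.6662]          # first row of the example above
print(round(p_norm(row, 2), 4))   # 0.9585
print(p_norm(row, 0))             # 2.0
```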
-
torch.prod()¶ -
torch.prod(input) → float
Returns the product of all elements in the input Tensor.
Parameters: input (Tensor) – the input Tensor
Example:
>>> a = torch.randn(1, 3) >>> a 0.6170 0.3546 0.0253 [torch.FloatTensor of size 1x3] >>> torch.prod(a) 0.005537458061418483
-
torch.prod(input, dim, out=None) → Tensor
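Returns the product of each row of the input Tensor in the given dimension dim.
The output Tensor is of the same size as input except in the dimension dim, where it is of size 1.
Parameters: - input (Tensor) – the input Tensor
- dim (int) – the dimension to reduce
- out (Tensor, optional) – the result Tensor
Example: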
>>> a = torch.randn(4, 2) >>> a 0.1598 -0.6884 -0.1831 -0.4412 -0.9925 -0.6244 -0.2416 -0.8080 [torch.FloatTensor of size 4x2] >>> torch.prod(a, 1) -0.1100 0.0808 0.6197 0.1952 [torch.FloatTensor of size 4x1]
-
torch.std()¶ -
torch.std(input) → float
Returns the standard-deviation of all elements in the input Tensor.
Parameters: input (Tensor) – the input Tensor
Example:
>>> a = torch.randn(1, 3) >>> a -1.3063 1.4182 -0.3061 [torch.FloatTensor of size 1x3] >>> torch.std(a) 1.3782334731508061
-
torch.std(input, dim, out=None) → Tensor
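Returns the standard-deviation of each row of the input Tensor in the given dimension dim.
The output Tensor is of the same size as input except in the dimension dim, where it is of size 1.
Parameters: - input (Tensor) – the input Tensor
- dim (int) – the dimension to reduce
- out (Tensor, optional) – the result Tensor
Example: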
>>> a = torch.randn(4, 4) >>> a 0.1889 -2.4856 0.0043 1.8169 -0.7701 -0.4682 -2.2410 0.4098 0.1919 -1.1856 -1.0361 0.9085 0.0173 1.0662 0.2143 -0.5576 [torch.FloatTensor of size 4x4] >>> torch.std(a, dim=1) 1.7756 1.1025 1.0045 0.6725 [torch.FloatTensor of size 4x1]
-
torch.sum()¶ -
torch.sum(input) → float
Returns the sum of all elements in the input Tensor.
Parameters: input (Tensor) – the input Tensor
Example:
>>> a = torch.randn(1, 3) >>> a 0.6170 0.3546 0.0253 [torch.FloatTensor of size 1x3] >>> torch.sum(a) 0.9969287421554327
-
torch.sum(input, dim, out=None) → Tensor
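Returns the sum of each row of the input Tensor in the given dimension dim.
The output Tensor is of the same size as input except in the dimension dim, where it is of size 1.
Parameters: - input (Tensor) – the input Tensor
- dim (int) – the dimension to reduce
- out (Tensor, optional) – the result Tensor
Example: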
>>> a = torch.randn(4, 4) >>> a -0.4640 0.0609 0.1122 0.4784 -1.3063 1.6443 0.4714 -0.7396 -1.3561 -0.1959 1.0609 -1.9855 2.6833 0.5746 -0.5709 -0.4430 [torch.FloatTensor of size 4x4] >>> torch.sum(a, 1) 0.1874 0.0698 -2.4767 2.2440 [torch.FloatTensor of size 4x1]
-
torch.var()¶ -
torch.var(input) → float
Returns the variance of all elements in the input Tensor.
Parameters: input (Tensor) – the input Tensor
Example:
>>> a = torch.randn(1, 3) >>> a -1.3063 1.4182 -0.3061 [torch.FloatTensor of size 1x3] >>> torch.var(a) 1.899527506513334
-
torch.var(input, dim, out=None) → Tensor
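Returns the variance of each row of the input Tensor in the given dimension dim.
The output Tensor is of the same size as input except in the dimension dim, where it is of size 1.
Parameters: - input (Tensor) – the input Tensor
- dim (int) – the dimension to reduce
- out (Tensor, optional) – the result Tensor
Example: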
>>> a = torch.randn(4, 4) >>> a -1.2738 -0.3058 0.1230 -1.9615 0.8771 -0.5430 -0.9233 0.9879 1.4107 0.0317 -0.6823 0.2255 -1.3854 0.4953 -0.2160 0.2435 [torch.FloatTensor of size 4x4] >>> torch.var(a, 1) 0.8859 0.9509 0.7548 0.6949 [torch.FloatTensor of size 4x1]
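The entries above do not state the divisor. The example numbers, however, are consistent with the unbiased estimator that divides by N - 1. For [-1.3063, 1.4182, -0.3061], the squared deviations from the mean sum to about 3.799, and 3.799 / 2 ≈ 1.8995, which matches torch.var(a). The plain-Python check below uses the standard library's statistics module: variance() and stdev() use the N - 1 divisor, and pvariance() uses N.

```python
import statistics

row = [-1.3063, 1.4182, -0.3061]  # the 1x3 example used for torch.std / torch.var

# Sample (N - 1) variance and standard deviation of the row.
var = statistics.variance(row)
std = statistics.stdev(row)
# Population (N) variance, shown for contrast; it does not match the example.
pvar = statistics.pvariance(row)

print(round(var, 3), round(std, 3), round(pvar, 3))  # 1.899 1.378 1.266
```

The small gap from the printed 1.8995 and 1.3782 comes from the example values being rounded to four decimals.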
-
Comparison Ops¶
-
torch.eq(input, other, out=None) → Tensor¶ Computes element-wise equality.
The second argument can be a number or a tensor of the same shape and type as the first argument.
Parameters: - input (Tensor) – the Tensor to compare
- other (Tensor or float) – the Tensor or value to compare
- out (Tensor, optional) – the output Tensor
Returns: a torch.ByteTensor containing a 1 at each location where the tensors are equal and a 0 at every other location
Return type: Tensor
Example:
>>> torch.eq(torch.Tensor([[1, 2], [3, 4]]), torch.Tensor([[1, 1], [4, 4]])) 1 0 0 1 [torch.ByteTensor of size 2x2]
-
torch.equal(tensor1, tensor2) → bool¶ True if two tensors have the same size and elements, False otherwise.
Example:
>>> torch.equal(torch.Tensor([1, 2]), torch.Tensor([1, 2])) True
-
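torch.ge(input, other, out=None) → Tensor¶ Computes input >= other element-wise.
The second argument can be a number or a tensor of the same shape and type as the first argument.
Parameters: - input (Tensor) – the Tensor to compare
- other (Tensor or float) – the Tensor or value to compare
- out (Tensor, optional) – the output Tensor
Returns: a torch.ByteTensor containing a 1 at each location where the comparison is true
Return type: Tensor
Example: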
>>> torch.ge(torch.Tensor([[1, 2], [3, 4]]), torch.Tensor([[1, 1], [4, 4]])) 1 1 0 1 [torch.ByteTensor of size 2x2]
-
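torch.gt(input, other, out=None) → Tensor¶ Computes input > other element-wise.
The second argument can be a number or a tensor of the same shape and type as the first argument.
Parameters: - input (Tensor) – the Tensor to compare
- other (Tensor or float) – the Tensor or value to compare
- out (Tensor, optional) – the output Tensor
Returns: a torch.ByteTensor containing a 1 at each location where the comparison is true
Return type: Tensor
Example: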
>>> torch.gt(torch.Tensor([[1, 2], [3, 4]]), torch.Tensor([[1, 1], [4, 4]])) 0 1 0 0 [torch.ByteTensor of size 2x2]
-
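torch.kthvalue(input, k, dim=None, out=None) -> (Tensor, LongTensor)¶ Returns the k-th smallest element of the given input Tensor along a given dimension.
If dim is not given, the last dimension of the input is chosen.
A tuple of (values, indices) is returned, where indices holds the index of the k-th smallest element in the original input Tensor along dimension dim.
Parameters: - input (Tensor) – the input Tensor
- k (int) – the k in “k-th smallest element”
- dim (int, optional) – the dimension along which to find the k-th smallest value
- out (tuple, optional) – the output tuple of (Tensor, LongTensor) can be optionally given to be used as output buffers
Example: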
>>> x = torch.arange(1, 6) >>> x 1 2 3 4 5 [torch.FloatTensor of size 5] >>> torch.kthvalue(x, 4) ( 4 [torch.FloatTensor of size 1] , 3 [torch.LongTensor of size 1] )
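In the example above, k is 1-based and the returned index is 0-based: k = 4 gives the value 4 at index 3 of [1, 2, 3, 4, 5]. The plain-Python sketch below shows the same convention for a 1D input. kthvalue_1d is a hypothetical helper.

```python
def kthvalue_1d(xs, k):
    """Return (value, index) of the k-th smallest element (k is 1-based)."""
    # Order the positions by their values, then pick the (k - 1)-th position.
    order = sorted(range(len(xs)), key=lambda j: xs[j])
    idx = order[k - 1]
    return xs[idx], idx


print(kthvalue_1d([1, 2, 3, 4, 5], 4))   # (4, 3)
print(kthvalue_1d([3.0, -1.0, 2.0], 1))  # (-1.0, 1)
```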
-
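torch.le(input, other, out=None) → Tensor¶ Computes input <= other element-wise.
The second argument can be a number or a tensor of the same shape and type as the first argument.
Parameters: - input (Tensor) – the Tensor to compare
- other (Tensor or float) – the Tensor or value to compare
- out (Tensor, optional) – the output Tensor
Returns: a torch.ByteTensor containing a 1 at each location where the comparison is true
Return type: Tensor
Example: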
>>> torch.le(torch.Tensor([[1, 2], [3, 4]]), torch.Tensor([[1, 1], [4, 4]])) 1 0 1 1 [torch.ByteTensor of size 2x2]
-
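torch.lt(input, other, out=None) → Tensor¶ Computes input < other element-wise.
The second argument can be a number or a tensor of the same shape and type as the first argument.
Parameters: - input (Tensor) – the Tensor to compare
- other (Tensor or float) – the Tensor or value to compare
- out (Tensor, optional) – the output Tensor
Returns: a torch.ByteTensor containing a 1 at each location where the comparison is true
Return type: Tensor
Example: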
>>> torch.lt(torch.Tensor([[1, 2], [3, 4]]), torch.Tensor([[1, 1], [4, 4]])) 0 0 1 0 [torch.ByteTensor of size 2x2]
-
torch.max()¶ -
torch.max(input) → float
Returns the maximum value of all elements in the input Tensor.
Parameters: input (Tensor) – the input Tensor
Example:
>>> a = torch.randn(1, 3) >>> a 0.4729 -0.2266 -0.2085 [torch.FloatTensor of size 1x3] >>> torch.max(a) 0.4729
-
torch.max(input, dim, max=None, max_indices=None) -> (Tensor, LongTensor)
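Returns the maximum value of each row of the input Tensor in the given dimension dim. Also returns the index location of each maximum value found.
The output Tensors are of the same size as input except in the dimension dim, where they are of size 1.
Parameters: - input (Tensor) – the input Tensor
- dim (int) – the dimension to reduce
- max (Tensor, optional) – the result Tensor with maximum values in dimension dim
- max_indices (LongTensor, optional) – the result Tensor with the index locations of the maximum values in dimension dim
Example:
>>> a = torch.randn(4, 4) >>> a 0.0692 0.3142 1.2513 -0.5428 0.9288 0.8552 -0.2073 0.6409 1.0695 -0.0101 -2.4507 -1.2230 0.7426 -0.7666 0.4862 -0.6628 [torch.FloatTensor of size 4x4] >>> torch.max(a, 1) ( 1.2513 0.9288 1.0695 0.7426 [torch.FloatTensor of size 4x1] , 2 0 0 0 [torch.LongTensor of size 4x1] )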
-
torch.max(input, other, out=None) → Tensor
Each element of the Tensor input is compared with the corresponding element of the Tensor other and an element-wise max is taken.
The shapes of input and other don’t need to match, but the total number of elements in each Tensor must be the same.
Note
When the shapes do not match, the shape of input is used as the shape for the returned output Tensor.
\(out_i = \max(input_i, other_i)\)
Parameters: - input (Tensor) – the input Tensor
- other (Tensor) – the second input Tensor
- out (Tensor, optional) – the result Tensor
Example:
>>> a = torch.randn(4) >>> a 1.3869 0.3912 -0.8634 -0.5468 [torch.FloatTensor of size 4] >>> b = torch.randn(4) >>> b 1.0067 -0.8010 0.6258 0.3627 [torch.FloatTensor of size 4] >>> torch.max(a, b) 1.3869 0.3912 0.6258 0.3627 [torch.FloatTensor of size 4]
-
torch.min()¶ -
torch.min(input) → float
Returns the minimum value of all elements in the input Tensor.
Parameters: input (Tensor) – the input Tensor
Example:
>>> a = torch.randn(1, 3) >>> a 0.4729 -0.2266 -0.2085 [torch.FloatTensor of size 1x3] >>> torch.min(a) -0.22663167119026184
-
torch.min(input, dim, min=None, min_indices=None) -> (Tensor, LongTensor)
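Returns the minimum value of each row of the input Tensor in the given dimension dim. Also returns the index location of each minimum value found.
The output Tensors are of the same size as input except in the dimension dim, where they are of size 1.
Parameters: - input (Tensor) – the input Tensor
- dim (int) – the dimension to reduce
- min (Tensor, optional) – the result Tensor with minimum values in dimension dim
- min_indices (LongTensor, optional) – the result Tensor with the index locations of the minimum values in dimension dim
Example:
>>> a = torch.randn(4, 4) >>> a 0.0692 0.3142 1.2513 -0.5428 0.9288 0.8552 -0.2073 0.6409 1.0695 -0.0101 -2.4507 -1.2230 0.7426 -0.7666 0.4862 -0.6628 [torch.FloatTensor of size 4x4] >>> torch.min(a, 1) ( -0.5428 -0.2073 -2.4507 -0.7666 [torch.FloatTensor of size 4x1] , 3 2 2 1 [torch.LongTensor of size 4x1] )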
-
torch.min(input, other, out=None) → Tensor
Each element of the Tensor input is compared with the corresponding element of the Tensor other and an element-wise min is taken. The resulting Tensor is returned.
The shapes of input and other don’t need to match, but the total number of elements in each Tensor must be the same.
Note
When the shapes do not match, the shape of input is used as the shape for the returned output Tensor.
\(out_i = \min(input_i, other_i)\)
Parameters: - input (Tensor) – the input Tensor
- other (Tensor) – the second input Tensor
- out (Tensor, optional) – the result Tensor
Example:
>>> a = torch.randn(4) >>> a 1.3869 0.3912 -0.8634 -0.5468 [torch.FloatTensor of size 4] >>> b = torch.randn(4) >>> b 1.0067 -0.8010 0.6258 0.3627 [torch.FloatTensor of size 4] >>> torch.min(a, b) 1.0067 -0.8010 -0.8634 -0.5468 [torch.FloatTensor of size 4]
-
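torch.ne(input, other, out=None) → Tensor¶ Computes input != other element-wise.
The second argument can be a number or a tensor of the same shape and type as the first argument.
Parameters: - input (Tensor) – the Tensor to compare
- other (Tensor or float) – the Tensor or value to compare
- out (Tensor, optional) – the output Tensor
Returns: a torch.ByteTensor containing a 1 at each location where the comparison is true
Return type: Tensor
Example: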
>>> torch.ne(torch.Tensor([[1, 2], [3, 4]]), torch.Tensor([[1, 1], [4, 4]])) 0 1 1 0 [torch.ByteTensor of size 2x2]
-
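torch.sort(input, dim=None, descending=False, out=None) -> (Tensor, LongTensor)¶ Sorts the elements of the input Tensor along a given dimension in ascending order by value.
If dim is not given, the last dimension of the input is chosen.
If descending is True, the elements are sorted in descending order by value.
A tuple of (sorted_tensor, sorted_indices) is returned, where the sorted_indices are the indices of the elements in the original input Tensor.
Parameters: - input (Tensor) – the input Tensor
- dim (int, optional) – the dimension to sort along
- descending (bool, optional) – controls the sorting order (ascending or descending)
- out (tuple, optional) – the output tuple of (Tensor, LongTensor) can be optionally given to be used as output buffers
Example: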
>>> x = torch.randn(3, 4) >>> sorted, indices = torch.sort(x) >>> sorted -1.6747 0.0610 0.1190 1.4137 -1.4782 0.7159 1.0341 1.3678 -0.3324 -0.0782 0.3518 0.4763 [torch.FloatTensor of size 3x4] >>> indices 0 1 3 2 2 1 0 3 3 1 0 2 [torch.LongTensor of size 3x4] >>> sorted, indices = torch.sort(x, 0) >>> sorted -1.6747 -0.0782 -1.4782 -0.3324 0.3518 0.0610 0.4763 0.1190 1.0341 0.7159 1.4137 1.3678 [torch.FloatTensor of size 3x4] >>> indices 0 2 1 2 2 0 2 0 1 1 0 1 [torch.LongTensor of size 3x4]
-
torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)¶ Returns the k largest elements of the given input Tensor along a given dimension.
If dim is not given, the last dimension of the input is chosen.
If largest is False, the k smallest elements are returned.
A tuple of (values, indices) is returned, where the indices are the indices of the elements in the original input Tensor.
If the boolean option sorted is True, the returned k elements are themselves sorted.
Parameters: - input (Tensor) – the input Tensor
- k (int) – the k in “top-k”
- dim (int, optional) – The dimension to sort along
- largest (bool, optional) – Controls whether to return largest or smallest elements
- sorted (bool, optional) – Controls whether to return the elements in sorted order
- out (tuple, optional) – The output tuple of (Tensor, LongTensor) can be optionally given to be used as output buffers
Example:
>>> x = torch.arange(1, 6) >>> x 1 2 3 4 5 [torch.FloatTensor of size 5] >>> torch.topk(x, 3) ( 5 4 3 [torch.FloatTensor of size 3] , 4 3 2 [torch.LongTensor of size 3] ) >>> torch.topk(x, 3, 0, largest=False) ( 1 2 3 [torch.FloatTensor of size 3] , 0 1 2 [torch.LongTensor of size 3] )
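The plain-Python sketch below illustrates the (values, indices) contract for a 1D input. It uses heapq.nlargest and heapq.nsmallest, which return their results already ordered, so it corresponds to sorted=True. The behaviour when values tie is not modelled. topk_1d is a hypothetical helper.

```python
import heapq


def topk_1d(xs, k, largest=True):
    """Return (values, indices) of the k largest (or smallest) entries, in order."""
    pick = heapq.nlargest if largest else heapq.nsmallest
    # Each candidate is an (index, value) pair ranked by its value.
    best = pick(k, enumerate(xs), key=lambda pair: pair[1])
    return [v for _, v in best], [i for i, _ in best]


x = [1, 2, 3, 4, 5]
print(topk_1d(x, 3))                 # ([5, 4, 3], [4, 3, 2])
print(topk_1d(x, 3, largest=False))  # ([1, 2, 3], [0, 1, 2])
```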
Other Operations¶
-
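torch.cross(input, other, dim=-1, out=None) → Tensor¶ Returns the cross product of vectors in dimension dim of input and other.
input and other must have the same size, and the size of their dim dimension should be 3.
If dim is not given, it defaults to the first dimension found with the size 3.
Parameters: - input (Tensor) – the input Tensor
- other (Tensor) – the second input Tensor
- dim (int, optional) – the dimension to take the cross-product in
- out (Tensor, optional) – the result Tensor
Example: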
>>> a = torch.randn(4, 3) >>> a -0.6652 -1.0116 -0.6857 0.2286 0.4446 -0.5272 0.0476 0.2321 1.9991 0.6199 1.1924 -0.9397 [torch.FloatTensor of size 4x3] >>> b = torch.randn(4, 3) >>> b -0.1042 -1.1156 0.1947 0.9947 0.1149 0.4701 -1.0108 0.8319 -0.0750 0.9045 -1.3754 1.0976 [torch.FloatTensor of size 4x3] >>> torch.cross(a, b, dim=1) -0.9619 0.2009 0.6367 0.2696 -0.6318 -0.4160 -1.6805 -2.0171 0.2741 0.0163 -1.5304 -1.9311 [torch.FloatTensor of size 4x3] >>> torch.cross(a, b) -0.9619 0.2009 0.6367 0.2696 -0.6318 -0.4160 -1.6805 -2.0171 0.2741 0.0163 -1.5304 -1.9311 [torch.FloatTensor of size 4x3]
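For 3-element vectors, the cross product is \(a \times b = (a_2 b_3 - a_3 b_2, a_3 b_1 - a_1 b_3, a_1 b_2 - a_2 b_1)\). The plain-Python sketch below applies this formula row by row, which corresponds to torch.cross(a, b, dim=1) on n x 3 inputs. cross3 and cross_rows are hypothetical helpers.

```python
def cross3(a, b):
    """Cross product of two 3-element sequences."""
    a1, a2, a3 = a
    b1, b2, b3 = b
    return (a2 * b3 - a3 * b2,
            a3 * b1 - a1 * b3,
            a1 * b2 - a2 * b1)


def cross_rows(A, B):
    """Row-wise cross product of two n x 3 nested lists."""
    return [cross3(ra, rb) for ra, rb in zip(A, B)]


print(cross3((1, 0, 0), (0, 1, 0)))  # (0, 0, 1)
# First rows of a and b from the example above; prints roughly (-0.9619, 0.201, 0.6367)
print(tuple(round(v, 4) for v in cross3((-0.6652, -1.0116, -0.6857), (-0.1042, -1.1156, 0.1947))))
```

The last value differs from the example's 0.2009 in the fourth decimal only because the printed inputs are rounded.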
-
torch.diag(input, diagonal=0, out=None) → Tensor¶
- If input is a vector (1D Tensor), then returns a 2D square Tensor with the elements of input as the diagonal.
- If input is a matrix (2D Tensor), then returns a 1D Tensor with the diagonal elements of input.
The argument diagonal controls which diagonal to consider:
- diagonal = 0 is the main diagonal.
- diagonal > 0 is above the main diagonal.
- diagonal < 0 is below the main diagonal.
Parameters: - input (Tensor) – the input Tensor
- diagonal (int, optional) – the diagonal to consider
- out (Tensor, optional) – the result Tensor
Example:
Get the square matrix where the input vector is the diagonal:
>>> a = torch.randn(3) >>> a 1.0480 -2.3405 -1.1138 [torch.FloatTensor of size 3] >>> torch.diag(a) 1.0480 0.0000 0.0000 0.0000 -2.3405 0.0000 0.0000 0.0000 -1.1138 [torch.FloatTensor of size 3x3] >>> torch.diag(a, 1) 0.0000 1.0480 0.0000 0.0000 0.0000 0.0000 -2.3405 0.0000 0.0000 0.0000 0.0000 -1.1138 0.0000 0.0000 0.0000 0.0000 [torch.FloatTensor of size 4x4]
Get the k-th diagonal of a given matrix:
>>> a = torch.randn(3, 3) >>> a -1.5328 -1.3210 -1.5204 0.8596 0.0471 -0.2239 -0.6617 0.0146 -1.0817 [torch.FloatTensor of size 3x3] >>> torch.diag(a, 0) -1.5328 0.0471 -1.0817 [torch.FloatTensor of size 3] >>> torch.diag(a, 1) -1.3210 -0.2239 [torch.FloatTensor of size 2]
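The plain-Python sketch below covers both directions using nested lists. A vector of length n placed on diagonal d produces an (n + |d|)-square matrix, as in the 4x4 result of torch.diag(a, 1) above. For a matrix, the sketch collects the entries m[i][j] with j - i == d. diag_sketch is a hypothetical helper.

```python
def diag_sketch(x, diagonal=0):
    """Vector -> square matrix with x on the given diagonal; matrix -> that diagonal."""
    if not isinstance(x[0], list):
        # 1D input: build an (n + |diagonal|) square matrix of zeros.
        n = len(x) + abs(diagonal)
        out = [[0] * n for _ in range(n)]
        for i, v in enumerate(x):
            r, c = (i, i + diagonal) if diagonal >= 0 else (i - diagonal, i)
            out[r][c] = v
        return out
    # 2D input: collect the entries x[i][j] with j - i == diagonal.
    rows, cols = len(x), len(x[0])
    return [x[i][i + diagonal] for i in range(rows) if 0 <= i + diagonal < cols]


print(diag_sketch([1, 2, 3], 1))
# [[0, 1, 0, 0], [0, 0, 2, 0], [0, 0, 0, 3], [0, 0, 0, 0]]
print(diag_sketch([[1, 2, 3], [4, 5, 6], [7, 8, 9]], -1))  # [4, 8]
```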
-
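torch.histc(input, bins=100, min=0, max=0, out=None) → Tensor¶ Computes the histogram of a tensor.
The elements are sorted into equal-width bins between min and max. If min and max are both zero, the minimum and maximum values of the data are used.
Parameters: - input (Tensor) – the input Tensor
- bins (int) – the number of histogram bins
- min (int) – the lower end of the range (inclusive)
- max (int) – the upper end of the range (inclusive)
- out (Tensor, optional) – the result Tensor
Returns: the histogram
Return type: Tensor
Example: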
>>> torch.histc(torch.FloatTensor([1, 2, 1]), bins=4, min=0, max=3) FloatTensor([0, 2, 1, 0])
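The plain-Python sketch below shows the equal-width binning and reproduces the example above. With bins=4 over [0, 3], each bin is 0.75 wide, so the two 1s land in bin 1 and the 2 lands in bin 2. The range arguments are named lo and hi here so they do not shadow Python's builtins. The rules that out-of-range values are skipped and that the right edge belongs to the last bin are assumptions, not documented behaviour. histc_sketch is a hypothetical helper.

```python
def histc_sketch(values, bins=100, lo=0.0, hi=0.0):
    """Equal-width histogram over [lo, hi]. Assumes the resolved range has hi > lo."""
    if lo == 0 and hi == 0:
        # Fall back to the data range, as described for min = max = 0.
        lo, hi = min(values), max(values)
    counts = [0] * bins
    width = (hi - lo) / bins
    for v in values:
        if v < lo or v > hi:
            continue  # assumption: out-of-range elements are ignored
        # Assumption: the right edge hi is counted in the last bin.
        k = min(int((v - lo) / width), bins - 1)
        counts[k] += 1
    return counts


print(histc_sketch([1, 2, 1], bins=4, lo=0, hi=3))  # [0, 2, 1, 0]
```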
-
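torch.renorm(input, p, dim, maxnorm, out=None) → Tensor¶ Returns a Tensor where each sub-tensor of input along dimension dim is normalized such that the p-norm of the sub-tensor is lower than the value maxnorm.
Note
If the norm of a row is lower than maxnorm, the row is unchanged.
Parameters: - input (Tensor) – the input Tensor
- p (float) – the power for the norm computation
- dim (int) – the dimension to slice over to get the sub-tensors
- maxnorm (float) – the maximum norm to keep each sub-tensor under
- out (Tensor, optional) – the result Tensor
Example: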
>>> x = torch.ones(3, 3) >>> x[1].fill_(2) >>> x[2].fill_(3) >>> x 1 1 1 2 2 2 3 3 3 [torch.FloatTensor of size 3x3] >>> torch.renorm(x, 1, 0, 5) 1.0000 1.0000 1.0000 1.6667 1.6667 1.6667 1.6667 1.6667 1.6667 [torch.FloatTensor of size 3x3]
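For dim = 0 on a 2D input, each row is one sub-tensor. A row whose p-norm exceeds maxnorm is scaled by maxnorm / norm; any other row is left unchanged. The plain-Python sketch below reproduces the example above: the row of 2s has 1-norm 6 and is scaled by 5/6, and the row of 3s has 1-norm 9 and is scaled by 5/9. Both end up at about 1.6667 per entry. renorm_rows is a hypothetical helper.

```python
def renorm_rows(rows, p, maxnorm):
    """Rescale each row whose p-norm exceeds maxnorm so that its norm equals maxnorm."""
    out = []
    for row in rows:
        norm = sum(abs(v) ** p for v in row) ** (1.0 / p)
        scale = maxnorm / norm if norm > maxnorm else 1.0
        out.append([v * scale for v in row])
    return out


x = [[1, 1, 1], [2, 2, 2], [3, 3, 3]]
print([[round(v, 4) for v in row] for row in renorm_rows(x, 1, 5)])
# [[1.0, 1.0, 1.0], [1.6667, 1.6667, 1.6667], [1.6667, 1.6667, 1.6667]]
```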
-
torch.trace(input) → float¶ Returns the sum of the elements of the diagonal of the input 2D matrix.
Example:
>>> x = torch.arange(1, 10).view(3, 3) >>> x 1 2 3 4 5 6 7 8 9 [torch.FloatTensor of size 3x3] >>> torch.trace(x) 15.0
-
torch.tril(input, k=0, out=None) → Tensor¶ Returns the lower triangular part of the matrix (2D Tensor) input; the other elements of the result Tensor out are set to 0.
The lower triangular part of the matrix is defined as the elements on and below the diagonal.
The argument k controls which diagonal to consider:
- k = 0 is the main diagonal.
- k > 0 is above the main diagonal.
- k < 0 is below the main diagonal.
Parameters: - input (Tensor) – the input Tensor
- k (int, optional) – the diagonal to consider
- out (Tensor, optional) – the result Tensor
Example:
>>> a = torch.randn(3,3) >>> a 1.3225 1.7304 1.4573 -0.3052 -0.3111 -0.1809 1.2469 0.0064 -1.6250 [torch.FloatTensor of size 3x3] >>> torch.tril(a) 1.3225 0.0000 0.0000 -0.3052 -0.3111 0.0000 1.2469 0.0064 -1.6250 [torch.FloatTensor of size 3x3] >>> torch.tril(a, k=1) 1.3225 1.7304 0.0000 -0.3052 -0.3111 -0.1809 1.2469 0.0064 -1.6250 [torch.FloatTensor of size 3x3] >>> torch.tril(a, k=-1) 0.0000 0.0000 0.0000 -0.3052 0.0000 0.0000 1.2469 0.0064 0.0000 [torch.FloatTensor of size 3x3]
-
torch.triu(input, k=0, out=None) → Tensor¶ Returns the upper triangular part of the matrix (2D Tensor) input; the other elements of the result Tensor out are set to 0.
The upper triangular part of the matrix is defined as the elements on and above the diagonal.
The argument k controls which diagonal to consider:
- k = 0 is the main diagonal.
- k > 0 is above the main diagonal.
- k < 0 is below the main diagonal.
Parameters: - input (Tensor) – the input Tensor
- k (int, optional) – the diagonal to consider
- out (Tensor, optional) – the result Tensor
Example:
>>> a = torch.randn(3,3) >>> a 1.3225 1.7304 1.4573 -0.3052 -0.3111 -0.1809 1.2469 0.0064 -1.6250 [torch.FloatTensor of size 3x3] >>> torch.triu(a) 1.3225 1.7304 1.4573 0.0000 -0.3111 -0.1809 0.0000 0.0000 -1.6250 [torch.FloatTensor of size 3x3] >>> torch.triu(a, k=1) 0.0000 1.7304 1.4573 0.0000 0.0000 -0.1809 0.0000 0.0000 0.0000 [torch.FloatTensor of size 3x3] >>> torch.triu(a, k=-1) 1.3225 1.7304 1.4573 -0.3052 -0.3111 -0.1809 0.0000 0.0064 -1.6250 [torch.FloatTensor of size 3x3]
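In both functions, k selects a diagonal by the column-minus-row offset j - i. tril keeps the entries with j - i <= k, and triu keeps the entries with j - i >= k. The plain-Python sketch below uses nested lists of integers; tril_sketch and triu_sketch are hypothetical helpers.

```python
def tril_sketch(m, k=0):
    """Keep entries on and below the k-th diagonal (j - i <= k); zero the rest."""
    return [[v if j - i <= k else 0 for j, v in enumerate(row)]
            for i, row in enumerate(m)]


def triu_sketch(m, k=0):
    """Keep entries on and above the k-th diagonal (j - i >= k); zero the rest."""
    return [[v if j - i >= k else 0 for j, v in enumerate(row)]
            for i, row in enumerate(m)]


m = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
print(tril_sketch(m))        # [[1, 0, 0], [4, 5, 0], [7, 8, 9]]
print(triu_sketch(m, k=1))   # [[0, 2, 3], [0, 0, 6], [0, 0, 0]]
print(tril_sketch(m, k=-1))  # [[0, 0, 0], [4, 0, 0], [7, 8, 0]]
```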
BLAS and LAPACK Operations¶
-
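torch.addbmm(beta=1, mat, alpha=1, batch1, batch2, out=None) → Tensor¶ Performs a batch matrix-matrix product of matrices stored in batch1 and batch2, with a reduced add step (all matrix multiplications get accumulated along the first dimension). mat is added to the final result.
batch1 and batch2 must be 3D Tensors each containing the same number of matrices.
If batch1 is a b x n x m Tensor and batch2 is a b x m x p Tensor, out and mat will be n x p Tensors.
In other words, \(out = (beta * mat) + (alpha * \sum_{i=0}^{b-1} batch1_i @ batch2_i)\)
For inputs of type FloatTensor or DoubleTensor, args beta and alpha must be real numbers; otherwise they should be integers.
Parameters: - beta (Number, optional) – multiplier for mat
- mat (Tensor) – the matrix to be added
- alpha (Number, optional) – multiplier for the sum of batch1_i @ batch2_i
- batch1 (Tensor) – the first batch of matrices to be multiplied
- batch2 (Tensor) – the second batch of matrices to be multiplied
- out (Tensor, optional) – the output Tensor
Example: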
>>> M = torch.randn(3, 5) >>> batch1 = torch.randn(10, 3, 4) >>> batch2 = torch.randn(10, 4, 5) >>> torch.addbmm(M, batch1, batch2) -3.1162 11.0071 7.3102 0.1824 -7.6892 1.8265 6.0739 0.4589 -0.5641 -5.4283 -9.3387 -0.1794 -1.2318 -6.8841 -4.7239 [torch.FloatTensor of size 3x5]
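The reduced add step means that the b per-batch products are summed into a single n x p matrix before mat is added. The plain-Python sketch below writes that out with small integer matrices so the arithmetic can be checked by hand. matmul, addbmm_sketch, and their keyword arguments are hypothetical; they do not follow torch's positional (beta, mat, alpha, ...) argument order.

```python
def matmul(x, y):
    """Plain-Python product of an n x m and an m x p nested list."""
    return [[sum(x[i][t] * y[t][j] for t in range(len(y)))
             for j in range(len(y[0]))] for i in range(len(x))]


def addbmm_sketch(mat, batch1, batch2, beta=1, alpha=1):
    """out = beta * mat + alpha * sum_i batch1[i] @ batch2[i] (an n x p result)."""
    n, p = len(mat), len(mat[0])
    acc = [[0] * p for _ in range(n)]
    for b1, b2 in zip(batch1, batch2):
        prod = matmul(b1, b2)
        acc = [[acc[i][j] + prod[i][j] for j in range(p)] for i in range(n)]
    return [[beta * mat[i][j] + alpha * acc[i][j] for j in range(p)]
            for i in range(n)]


eye = [[1, 0], [0, 1]]
batch1 = [eye, [[2, 0], [0, 2]]]
batch2 = [[[1, 2], [3, 4]], eye]
print(addbmm_sketch(eye, batch1, batch2))  # [[4, 2], [3, 7]]
```

Here the two products are [[1, 2], [3, 4]] and [[2, 0], [0, 2]]. Their sum plus the identity gives the printed result.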
-
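torch.addmm(beta=1, mat, alpha=1, mat1, mat2, out=None) → Tensor¶ Performs a matrix multiplication of the matrices mat1 and mat2. The matrix mat is added to the final result.
If mat1 is an n x m Tensor and mat2 is an m x p Tensor, out and mat will be n x p Tensors.
alpha and beta are scaling factors on mat1 @ mat2 and mat respectively.
In other words, \(out = (beta * mat) + (alpha * mat1 @ mat2)\)
For inputs of type FloatTensor or DoubleTensor, args beta and alpha must be real numbers; otherwise they should be integers.
Parameters: - beta (Number, optional) – multiplier for mat
- mat (Tensor) – the matrix to be added
- alpha (Number, optional) – multiplier for mat1 @ mat2
- mat1 (Tensor) – the first matrix to be multiplied
- mat2 (Tensor) – the second matrix to be multiplied
- out (Tensor, optional) – the output Tensor
Example: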
>>> M = torch.randn(2, 3) >>> mat1 = torch.randn(2, 3) >>> mat2 = torch.randn(3, 3) >>> torch.addmm(M, mat1, mat2) -0.4095 -1.9703 1.3561 5.7674 -4.9760 2.7378 [torch.FloatTensor of size 2x3]
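The plain-Python sketch below checks the formula out = beta * mat + alpha * (mat1 @ mat2) with small integers. It also shows that beta = 0 discards mat entirely. addmm_sketch and its keyword argument order are hypothetical and do not follow torch's positional signature.

```python
def addmm_sketch(mat, mat1, mat2, beta=1, alpha=1):
    """out = beta * mat + alpha * (mat1 @ mat2), all as nested lists."""
    n, m, p = len(mat1), len(mat2), len(mat2[0])
    return [[beta * mat[i][j]
             + alpha * sum(mat1[i][t] * mat2[t][j] for t in range(m))
             for j in range(p)] for i in range(n)]


M = [[1, 1], [1, 1]]
A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]
print(addmm_sketch(M, A, B))                   # [[20, 23], [44, 51]]
print(addmm_sketch(M, A, B, beta=0, alpha=2))  # [[38, 44], [86, 100]]
```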
-
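torch.addmv(beta=1, tensor, alpha=1, mat, vec, out=None) → Tensor¶ Performs a matrix-vector product of the matrix mat and the vector vec. The vector tensor is added to the final result.
If mat is an n x m Tensor and vec is a 1D Tensor of size m, out and tensor will be 1D of size n.
alpha and beta are scaling factors on mat @ vec and tensor respectively.
In other words:
\(out = (beta * tensor) + (alpha * (mat @ vec))\)
For inputs of type FloatTensor or DoubleTensor, args beta and alpha must be real numbers; otherwise they should be integers.
Parameters: - beta (Number, optional) – multiplier for tensor
- tensor (Tensor) – the vector to be added
- alpha (Number, optional) – multiplier for mat @ vec
- mat (Tensor) – the matrix to be multiplied
- vec (Tensor) – the vector to be multiplied
- out (Tensor, optional) – the output Tensor
Example: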
>>> M = torch.randn(2) >>> mat = torch.randn(2, 3) >>> vec = torch.randn(3) >>> torch.addmv(M, mat, vec) -2.0939 -2.2950 [torch.FloatTensor of size 2]
-
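torch.addr(beta=1, mat, alpha=1, vec1, vec2, out=None) → Tensor¶ Performs the outer-product of vectors vec1 and vec2 and adds it to the matrix mat.
Optional values beta and alpha are scalars that multiply mat and \((vec1 \otimes vec2)\) respectively.
In other words, \(out = (beta * mat) + (alpha * vec1 \otimes vec2)\)
If vec1 is a vector of size n and vec2 is a vector of size m, then mat must be a matrix of size n x m.
For inputs of type FloatTensor or DoubleTensor, args beta and alpha must be real numbers; otherwise they should be integers.
Parameters: - beta (Number, optional) – multiplier for mat
- mat (Tensor) – the matrix to be added
- alpha (Number, optional) – multiplier for the outer product of vec1 and vec2
- vec1 (Tensor) – the first vector of the outer product
- vec2 (Tensor) – the second vector of the outer product
- out (Tensor, optional) – the output Tensor
Example: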
>>> vec1 = torch.arange(1, 4) >>> vec2 = torch.arange(1, 3) >>> M = torch.zeros(3, 2) >>> torch.addr(M, vec1, vec2) 1 2 2 4 3 6 [torch.FloatTensor of size 3x2]
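Entry (i, j) of the outer product is vec1[i] * vec2[j], so an n-vector and an m-vector produce an n x m matrix. The plain-Python sketch below reproduces the example above. addr_sketch and its keyword argument order are hypothetical.

```python
def addr_sketch(mat, vec1, vec2, beta=1, alpha=1):
    """out = beta * mat + alpha * outer(vec1, vec2); mat must be len(vec1) x len(vec2)."""
    return [[beta * mat[i][j] + alpha * vec1[i] * vec2[j]
             for j in range(len(vec2))] for i in range(len(vec1))]


M = [[0, 0], [0, 0], [0, 0]]
print(addr_sketch(M, [1, 2, 3], [1, 2]))  # [[1, 2], [2, 4], [3, 6]]
```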
-
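torch.baddbmm(beta=1, mat, alpha=1, batch1, batch2, out=None) → Tensor¶ Performs a batch matrix-matrix product of matrices in batch1 and batch2. mat is added to the final result.
batch1 and batch2 must be 3D Tensors each containing the same number of matrices.
If batch1 is a b x n x m Tensor and batch2 is a b x m x p Tensor, out and mat will be b x n x p Tensors.
In other words, \(out_i = (beta * mat_i) + (alpha * batch1_i @ batch2_i)\)
For inputs of type FloatTensor or DoubleTensor, args beta and alpha must be real numbers; otherwise they should be integers.
Parameters: - beta (Number, optional) – multiplier for mat
- mat (Tensor) – the batch of matrices to be added
- alpha (Number, optional) – multiplier for batch1_i @ batch2_i
- batch1 (Tensor) – the first batch of matrices to be multiplied
- batch2 (Tensor) – the second batch of matrices to be multiplied
- out (Tensor, optional) – the output Tensor
Example: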
>>> M = torch.randn(10, 3, 5) >>> batch1 = torch.randn(10, 3, 4) >>> batch2 = torch.randn(10, 4, 5) >>> torch.baddbmm(M, batch1, batch2).size() torch.Size([10, 3, 5])
-
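torch.bmm(batch1, batch2, out=None) → Tensor¶ Performs a batch matrix-matrix product of matrices stored in batch1 and batch2.
batch1 and batch2 must be 3D Tensors each containing the same number of matrices.
If batch1 is a b x n x m Tensor and batch2 is a b x m x p Tensor, out will be a b x n x p Tensor.
Parameters: - batch1 (Tensor) – the first batch of matrices to be multiplied
- batch2 (Tensor) – the second batch of matrices to be multiplied
- out (Tensor, optional) – the output Tensor
Example: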
>>> batch1 = torch.randn(10, 3, 4) >>> batch2 = torch.randn(10, 4, 5) >>> res = torch.bmm(batch1, batch2) >>> res.size() torch.Size([10, 3, 5])
-
torch.btrifact(A, info=None) → Tensor, IntTensor¶ Batch LU factorization.
Returns a tuple containing the LU factorization and pivots. The optional argument info reports whether the factorization succeeded for each minibatch example. The info values are from dgetrf, and a non-zero value indicates that an error occurred. The specific values come from cuBLAS if CUDA is being used, otherwise from LAPACK.
Parameters: A (Tensor) – the tensor to factor
Example:
>>> A = torch.randn(2, 3, 3) >>> A_LU = A.btrifact()
-
torch.btrisolve(b, LU_data, LU_pivots) → Tensor¶ Batch LU solve.
Returns the LU solve of the linear system Ax = b.
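Parameters: - b (Tensor) – the right-hand-side tensor
- LU_data (Tensor) – the pivoted LU factorization of A from btrifact()
- LU_pivots (IntTensor) – the pivots of the LU factorization
Example:
>>> A = torch.randn(2, 3, 3) >>> b = torch.randn(2, 3) >>> A_LU = torch.btrifact(A) >>> x = b.btrisolve(*A_LU) >>> torch.norm(A.bmm(x.unsqueeze(2)).squeeze(2) - b) 6.664001874625056e-08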
-
torch.dot(tensor1, tensor2) → float¶ Computes the dot product (inner product) of two tensors. Both tensors are treated as 1-D vectors.
Example:
>>> torch.dot(torch.Tensor([2, 3]), torch.Tensor([2, 1])) 7.0
-
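torch.eig(a, eigenvectors=False, out=None) -> (Tensor, Tensor)¶ Computes the eigenvalues and eigenvectors of a real square matrix.
Parameters: - a (Tensor) – a square matrix for which the eigenvalues and eigenvectors will be computed
- eigenvectors (bool) – True to compute both eigenvalues and eigenvectors; otherwise only eigenvalues are computed
- out (tuple, optional) – the output tensors
Returns: tuple containing
- e (Tensor): the eigenvalues of a
- v (Tensor): the eigenvectors of a if eigenvectors is True; otherwise an empty tensor
Return type: (Tensor, Tensor)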
-
torch.gels(B, A, out=None) → Tensor¶ Computes the solution to the least squares and least norm problems for a full rank \(m\) by \(n\) matrix \(A\).
If \(m >= n\),
gels()solves the least-squares problem:\[\begin{array}{ll} \mbox{minimize} & \|AX-B\|_F. \end{array}\]If \(m < n\),
gels()solves the least-norm problem:\[\begin{array}{ll} \mbox{minimize} & \|X\|_F & \mbox{subject to} & AX = B. \end{array}\]The first \(n\) rows of the returned matrix \(X\) contains the solution. The remaining rows contain residual information: the euclidean norm of each column starting at row \(n\) is the residual for the corresponding column.
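Parameters:
- B (Tensor) – the matrix \(B\)
- A (Tensor) – the \(m\) by \(n\) matrix \(A\)
- out (tuple, optional) – optional destination tensors
Returns: tuple containing:
- X (Tensor): the least-squares solution
- qr (Tensor): the details of the QR factorization
Return type: (Tensor, Tensor)
Note
The returned matrices will always be transposed, irrespective of the strides of the input matrices. That is, they will have stride (1, m) instead of (m, 1).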
Example:
>>> A = torch.Tensor([[1, 1, 1],
...                   [2, 3, 4],
...                   [3, 5, 2],
...                   [4, 2, 5],
...                   [5, 4, 3]])
>>> B = torch.Tensor([[-10, -3],
...                   [ 12, 14],
...                   [ 14, 12],
...                   [ 16, 16],
...                   [ 18, 16]])
>>> X, _ = torch.gels(B, A)
>>> X[:3]  # the first n = 3 rows hold the solution
 2.0000  1.0000
 1.0000  1.0000
 1.0000  2.0000
[torch.FloatTensor of size 3x2]
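To make the \(m \geq n\) case concrete, the following pure-Python sketch solves the example's problem through the normal equations \(A^T A X = A^T B\) with a Cholesky factorization. This is a swapped-in technique, not what gels does (its qr output points to a QR-based method), and forming \(A^T A\) squares the condition number, so it only suits well-conditioned inputs. `least_squares` is a hypothetical name.

```python
import math


def least_squares(a, b):
    """Hypothetical helper: minimize ||A X - B||_F for full-rank A with m >= n.

    Swapped-in technique: normal equations G X = C with G = A^T A and
    C = A^T B, solved through a Cholesky factorization G = L L^T.
    """
    m, n, k = len(a), len(a[0]), len(b[0])
    g = [[sum(a[r][i] * a[r][j] for r in range(m)) for j in range(n)] for i in range(n)]
    c = [[sum(a[r][i] * b[r][j] for r in range(m)) for j in range(k)] for i in range(n)]
    # Cholesky: chol is lower triangular with chol @ chol^T == g.
    chol = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = g[i][j] - sum(chol[i][t] * chol[j][t] for t in range(j))
            chol[i][j] = math.sqrt(s) if i == j else s / chol[j][j]
    x = [[0.0] * k for _ in range(n)]
    for col in range(k):
        # Forward substitution: L y = c[:, col].
        y = [0.0] * n
        for i in range(n):
            y[i] = (c[i][col] - sum(chol[i][t] * y[t] for t in range(i))) / chol[i][i]
        # Back substitution: L^T x = y.
        for i in reversed(range(n)):
            s = y[i] - sum(chol[t][i] * x[t][col] for t in range(i + 1, n))
            x[i][col] = s / chol[i][i]
    return x


# The matrices from the example above.
A = [[1, 1, 1], [2, 3, 4], [3, 5, 2], [4, 2, 5], [5, 4, 3]]
B = [[-10, -3], [12, 14], [14, 12], [16, 16], [18, 16]]
X = least_squares(A, B)
print([[round(v, 4) for v in row] for row in X])  # [[2.0, 1.0], [1.0, 1.0], [1.0, 2.0]]
```

The result matches the first three rows shown in the example; the residual rows that gels also returns are not computed here.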
-
torch.geqrf(input, out=None) -> (Tensor, Tensor)¶ This is a low-level function for calling LAPACK directly.
You’ll generally want to use torch.qr() instead.
Computes a QR decomposition of input, but without constructing Q and R as explicit separate matrices.
Rather, this directly calls the underlying LAPACK function ?geqrf, which produces a sequence of ‘elementary reflectors’.
See the LAPACK documentation for further details.
Parameters:
- input (Tensor) – the input matrix
- out (tuple, optional) – the output tuple of (Tensor, Tensor)
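To illustrate what a sequence of elementary reflectors is, here is a pure-Python Householder QR sketch in the spirit of ?geqrf. R ends up on and above the diagonal. Below the diagonal, column k stores the reflector vector \(v_k\) (its leading 1 implied). Together with the scalars \(\tau_k\), each reflector is \(H_k = I - \tau_k v_k v_k^T\), so that \(Q = H_0 H_1 \cdots H_{p-1}\). The reflector construction follows the common LAPACK-style choice \(\beta = -\mathrm{sign}(\alpha)\,\|x\|\), but bit-for-bit agreement with LAPACK is not claimed, and `householder_qr` and `form_q` are hypothetical names.

```python
import math


def householder_qr(a):
    """Hypothetical helper: Householder QR stored in a compact, geqrf-like form.

    Returns (packed, taus): R on and above the diagonal of packed, and below
    it, column k holds v_k without its implicit leading 1.
    """
    m, n = len(a), len(a[0])
    a = [[float(x) for x in row] for row in a]  # work on a float copy
    taus = []
    for k in range(min(m, n)):
        alpha = a[k][k]
        xnorm = math.sqrt(sum(a[i][k] ** 2 for i in range(k + 1, m)))
        if xnorm == 0.0:
            taus.append(0.0)  # column already reduced: H_k = I
            continue
        beta = -math.copysign(math.hypot(alpha, xnorm), alpha)
        tau = (beta - alpha) / beta
        for i in range(k + 1, m):
            a[i][k] /= alpha - beta  # scale so that v_k has a leading 1
        a[k][k] = beta
        # Apply H_k to the trailing columns: col <- col - tau * v * (v^T col).
        for j in range(k + 1, n):
            w = a[k][j] + sum(a[i][k] * a[i][j] for i in range(k + 1, m))
            a[k][j] -= tau * w
            for i in range(k + 1, m):
                a[i][j] -= tau * a[i][k] * w
        taus.append(tau)
    return a, taus


def form_q(packed, taus):
    """Hypothetical helper: multiply the reflectors out into an explicit m x m Q."""
    m = len(packed)
    q = [[float(i == j) for j in range(m)] for i in range(m)]
    for k, tau in enumerate(taus):
        v = [0.0] * k + [1.0] + [packed[i][k] for i in range(k + 1, m)]
        for row in q:  # right-multiply by H_k: row <- row - tau * (row . v) * v^T
            s = sum(x * y for x, y in zip(row, v))
            for j in range(m):
                row[j] -= tau * s * v[j]
    return q


A = [[12, -51, 4], [6, 167, -68], [-4, 24, -41]]
packed, taus = householder_qr(A)
Q = form_q(packed, taus)
R = [[packed[i][j] if j >= i else 0.0 for j in range(3)] for i in range(3)]
```

Keeping the reflectors instead of an explicit Q saves storage and lets Q be applied or formed only when it is actually needed.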
-
torch.ger(vec1, vec2, out=None) → Tensor¶ Outer product of vec1 and vec2. If vec1 is a vector of size n and vec2 is a vector of size m, then out must be a matrix of size n x m.
Parameters:
- vec1 (Tensor) – 1-D input vector
- vec2 (Tensor) – 1-D input vector
- out (Tensor, optional) – optional output matrix
Example:
>>> v1 = torch.arange(1, 5)
>>> v2 = torch.arange(1, 4)
>>> torch.ger(v1, v2)
  1   2   3
  2   4   6
  3   6   9
  4   8  12
[torch.FloatTensor of size 4x3]
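A minimal pure-Python sketch of the outer product, with `outer` as a hypothetical helper name: entry (i, j) of the n x m result is vec1[i] * vec2[j]. The call mirrors the example above.

```python
def outer(vec1, vec2):
    """Hypothetical helper: out[i][j] = vec1[i] * vec2[j] (an n x m matrix)."""
    return [[a * b for b in vec2] for a in vec1]


for row in outer([1, 2, 3, 4], [1, 2, 3]):
    print(row)
# [1, 2, 3]
# [2, 4, 6]
# [3, 6, 9]
# [4, 8, 12]
```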
-
torch.gesv(B, A, out=None) -> (Tensor, Tensor)¶ X, LU = torch.gesv(B, A) returns the solution to the system of linear equations represented by \(AX = B\).
LU contains the L and U factors of the LU factorization of A.
A has to be a square, non-singular matrix (2-D Tensor). If A is an m x m matrix and B is m x k, the result LU is m x m and X is m x k.
Note
Irrespective of the original strides, the returned matrices X and LU will be transposed, i.e. with strides (1, m) instead of (m, 1).
Parameters:
- B (Tensor) – input matrix of size m x k
- A (Tensor) – input square matrix of size m x m
- out (tuple, optional) – optional output tuple
Example:
>>> A = torch.Tensor([[6.80, -2.11, 5.66, 5.97, 8.23],
...                   [-6.05, -3.30, 5.36, -4.44, 1.08],
...                   [-0.45, 2.58, -2.70, 0.27, 9.04],
...                   [8.32, 2.71, 4.35, -7.17, 2.14],
...                   [-9.67, -5.14, -7.26, 6.08, -6.87]]).t()
>>> B = torch.Tensor([[4.02, 6.19, -8.22, -7.57, -3.03],
...                   [-1.56, 4.00, -8.67, 1.75, 2.86],
...                   [9.81, -4.09, -4.57, -8.61, 8.99]]).t()
>>> X, LU = torch.gesv(B, A)
>>> torch.dist(B, torch.mm(A, X))
9.250057093890353e-06
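For intuition, the following pure-Python sketch solves \(AX = B\) for all k columns of B at once by Gaussian elimination with partial pivoting. The same row operations are applied to A and to every column of B, and each column is then back-substituted. It returns only X, not the LU factors. `solve_multi` is a hypothetical name, and this is not torch's implementation.

```python
def solve_multi(b, a):
    """Hypothetical helper: solve A X = B, A m x m, B m x k (argument order as in gesv).

    Gaussian elimination with partial pivoting; the same row operations are
    applied to every column of B. Returns X only, not the LU factors.
    """
    m, k = len(a), len(b[0])
    a = [[float(v) for v in row] for row in a]  # float working copies
    b = [[float(v) for v in row] for row in b]
    for c in range(m):
        p = max(range(c, m), key=lambda r: abs(a[r][c]))  # partial pivoting
        if a[p][c] == 0.0:
            raise ValueError("A is singular")
        a[c], a[p] = a[p], a[c]
        b[c], b[p] = b[p], b[c]
        for r in range(c + 1, m):
            f = a[r][c] / a[c][c]
            for j in range(c, m):
                a[r][j] -= f * a[c][j]
            for j in range(k):
                b[r][j] -= f * b[c][j]
    # Back substitution on the upper triangular system, one column of B at a time.
    x = [[0.0] * k for _ in range(m)]
    for j in range(k):
        for i in reversed(range(m)):
            s = b[i][j] - sum(a[i][t] * x[t][j] for t in range(i + 1, m))
            x[i][j] = s / a[i][i]
    return x


X = solve_multi([[3, 5], [5, 10]], [[2, 1], [1, 3]])
print(X)  # approximately [[0.8, 1.0], [1.4, 3.0]]
```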
-
torch.inverse(input, out=None) → Tensor¶ Takes the inverse of the square matrix input.
Note
Irrespective of the original strides, the returned matrix will be transposed, i.e. with strides (1, m) instead of (m, 1).
Parameters:
- input (Tensor) – the input 2-D square Tensor
- out (Tensor, optional) – the optional output Tensor
Example:
>>> x = torch.rand(10, 10)
>>> x
 0.7800  0.2267  0.7855  0.9479  0.5914  0.7119  0.4437  0.9131  0.1289  0.1982
 0.0045  0.0425  0.2229  0.4626  0.6210  0.0207  0.6338  0.7067  0.6381  0.8196
 0.8350  0.7810  0.8526  0.9364  0.7504  0.2737  0.0694  0.5899  0.8516  0.3883
 0.6280  0.6016  0.5357  0.2936  0.7827  0.2772  0.0744  0.2627  0.6326  0.9153
 0.7897  0.0226  0.3102  0.0198  0.9415  0.9896  0.3528  0.9397  0.2074  0.6980
 0.5235  0.6119  0.6522  0.3399  0.3205  0.5555  0.8454  0.3792  0.4927  0.6086
 0.1048  0.0328  0.5734  0.6318  0.9802  0.4458  0.0979  0.3320  0.3701  0.0909
 0.2616  0.3485  0.4370  0.5620  0.5291  0.8295  0.7693  0.1807  0.0650  0.8497
 0.1655  0.2192  0.6913  0.0093  0.0178  0.3064  0.6715  0.5101  0.2561  0.3396
 0.4370  0.4695  0.8333  0.1180  0.4266  0.4161  0.0699  0.4263  0.8865  0.2578
[torch.FloatTensor of size 10x10]
>>> y = torch.inverse(x)
>>> z = torch.mm(x, y)
>>> z
 1.0000  0.0000  0.0000 -0.0000  0.0000  0.0000  0.0000  0.0000 -0.0000 -0.0000
 0.0000  1.0000 -0.0000  0.0000  0.0000  0.0000 -0.0000 -0.0000 -0.0000 -0.0000
 0.0000  0.0000  1.0000 -0.0000 -0.0000  0.0000  0.0000  0.0000 -0.0000 -0.0000
 0.0000  0.0000  0.0000  1.0000  0.0000  0.0000  0.0000 -0.0000 -0.0000  0.0000
 0.0000  0.0000 -0.0000 -0.0000  1.0000  0.0000  0.0000 -0.0000 -0.0000 -0.0000
 0.0000  0.0000  0.0000 -0.0000  0.0000  1.0000 -0.0000 -0.0000 -0.0000 -0.0000
 0.0000  0.0000  0.0000 -0.0000  0.0000  0.0000  1.0000  0.0000 -0.0000  0.0000
 0.0000  0.0000 -0.0000 -0.0000  0.0000  0.0000 -0.0000  1.0000 -0.0000  0.0000
-0.0000  0.0000 -0.0000 -0.0000  0.0000  0.0000 -0.0000 -0.0000  1.0000 -0.0000
-0.0000  0.0000 -0.0000 -0.0000 -0.0000  0.0000 -0.0000 -0.0000  0.0000  1.0000
[torch.FloatTensor of size 10x10]
>>> torch.max(torch.abs(z - torch.eye(10)))  # largest deviation from the identity
5.096662789583206e-07
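A minimal pure-Python sketch of one way to compute an inverse: Gauss-Jordan elimination with partial pivoting reduces the augmented matrix [A | I] to [I | A^-1]. The helper name `inverse` is hypothetical, and this is not torch's implementation.

```python
def inverse(a):
    """Hypothetical helper: invert a square matrix by Gauss-Jordan on [A | I]."""
    n = len(a)
    aug = [[float(x) for x in row] + [float(i == j) for j in range(n)]
           for i, row in enumerate(a)]
    for c in range(n):
        p = max(range(c, n), key=lambda r: abs(aug[r][c]))  # partial pivoting
        if aug[p][c] == 0.0:
            raise ValueError("matrix is singular")
        aug[c], aug[p] = aug[p], aug[c]
        pivot = aug[c][c]
        aug[c] = [x / pivot for x in aug[c]]  # put a 1 on the diagonal
        for r in range(n):  # clear column c in every other row
            if r != c and aug[r][c] != 0.0:
                f = aug[r][c]
                aug[r] = [x - f * y for x, y in zip(aug[r], aug[c])]
    return [row[n:] for row in aug]  # the right half now holds A^-1


inv = inverse([[4, 7], [2, 6]])
print(inv)  # approximately [[0.6, -0.7], [-0.2, 0.4]]
```

As in the torch example, multiplying the matrix by its computed inverse gives the identity only up to rounding error.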
-
torch.mm(mat1, mat2, out=None) → Tensor¶ Performs a matrix multiplication of the matrices mat1 and mat2.
If mat1 is an n x m Tensor and mat2 is an m x p Tensor, out will be an n x p Tensor.
Parameters:
- mat1 (Tensor) – the first matrix to be multiplied
- mat2 (Tensor) – the second matrix to be multiplied
- out (Tensor, optional) – the output Tensor
Example:
>>> mat1 = torch.randn(2, 3)
>>> mat2 = torch.randn(3, 3)
>>> torch.mm(mat1, mat2)
 0.0519 -0.3304  1.2232
 4.3910 -5.1498  2.7571
[torch.FloatTensor of size 2x3]
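A minimal pure-Python sketch of the definition, where `mm` is a hypothetical helper on nested lists: each entry of the n x p result is the dot product of a row of mat1 with a column of mat2.

```python
def mm(mat1, mat2):
    """Hypothetical helper: (n x m) @ (m x p) -> (n x p)."""
    n, m, p = len(mat1), len(mat2), len(mat2[0])
    if any(len(row) != m for row in mat1):
        raise ValueError("inner dimensions must agree")
    # out[i][j] = sum over k of mat1[i][k] * mat2[k][j]
    return [[sum(mat1[i][k] * mat2[k][j] for k in range(m)) for j in range(p)]
            for i in range(n)]


print(mm([[1, 2, 3], [4, 5, 6]], [[1, 0], [0, 1], [1, 1]]))  # [[4, 5], [10, 11]]
```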
-
torch.mv(mat, vec, out=None) → Tensor¶ Performs a matrix-vector product of the matrix mat and the vector vec.
If mat is an n x m Tensor and vec is a 1-D Tensor of size m, out will be 1-D of size n.
Parameters:
- mat (Tensor) – the matrix to be multiplied
- vec (Tensor) – the vector to be multiplied
- out (Tensor, optional) – the output Tensor
Example:
>>> mat = torch.randn(2, 3)
>>> vec = torch.randn(3)
>>> torch.mv(mat, vec)
-2.0939
-2.2950
[torch.FloatTensor of size 2]
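A matrix-vector product computes one dot product per row of mat. A minimal pure-Python sketch, with `mv` as a hypothetical helper on nested lists:

```python
def mv(mat, vec):
    """Hypothetical helper: (n x m) matrix times a length-m vector -> length-n vector."""
    if any(len(row) != len(vec) for row in mat):
        raise ValueError("matrix columns must match vector length")
    return [sum(a * b for a, b in zip(row, vec)) for row in mat]


print(mv([[1, 2, 3], [4, 5, 6]], [1, 0, -1]))  # [-2, -2]
```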
-
torch.orgqr()¶
-
torch.ormqr()¶
-
torch.potrf()¶
-
torch.potri()¶
-
torch.potrs()¶
-
torch.pstrf()¶
-
torch.qr(input, out=None) -> (Tensor, Tensor)¶ Computes the QR decomposition of a matrix input: returns matrices q and r such that \(\mathrm{input} = q r\), with q being an orthogonal matrix and r being an upper triangular matrix.
This returns the thin (reduced) QR factorization.
Note
Precision may be lost if the magnitudes of the elements of input are large.
Note
While it should always give you a valid decomposition, it may not give you the same one across platforms; the result depends on your LAPACK implementation.
Note
Irrespective of the original strides, the returned matrix q will be transposed, i.e. with strides (1, m) instead of (m, 1).
Parameters:
- input (Tensor) – the input 2-D Tensor
- out (tuple, optional) – a tuple of the q and r Tensors
Example:
>>> a = torch.Tensor([[12, -51, 4], [6, 167, -68], [-4, 24, -41]])
>>> q, r = torch.qr(a)
>>> q
-0.8571  0.3943  0.3314
-0.4286 -0.9029 -0.0343
 0.2857 -0.1714  0.9429
[torch.FloatTensor of size 3x3]
>>> r
 -14.0000  -21.0000   14.0000
   0.0000 -175.0000   70.0000
   0.0000    0.0000  -35.0000
[torch.FloatTensor of size 3x3]
>>> torch.mm(q, r).round()
  12  -51    4
   6  167  -68
  -4   24  -41
[torch.FloatTensor of size 3x3]
>>> torch.mm(q.t(), q).round()
 1 -0  0
-0  1  0
 0  0  1
[torch.FloatTensor of size 3x3]
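The QR factorization is not unique, which is what the notes above warn about. The pure-Python sketch below uses modified Gram-Schmidt, a swapped-in technique (LAPACK-based QR uses Householder reflections), with `gram_schmidt_qr` as a hypothetical name. On the example matrix it yields r with a positive diagonal, [[14, 21, -14], [0, 175, -70], [0, 0, 35]]. That is every row of the r shown above negated, with the columns of q negated to match. Both are valid factorizations.

```python
import math


def gram_schmidt_qr(a):
    """Hypothetical helper: thin QR of an m x n matrix (m >= n, full column rank).

    Modified Gram-Schmidt: each column is orthogonalized against the q columns
    found so far, and the projection coefficients fill r.
    """
    m, n = len(a), len(a[0])
    q_cols = []
    r = [[0.0] * n for _ in range(n)]
    for j in range(n):
        v = [float(a[i][j]) for i in range(m)]
        for i, q in enumerate(q_cols):
            r[i][j] = sum(x * y for x, y in zip(q, v))  # project the updated v
            v = [x - r[i][j] * y for x, y in zip(v, q)]
        r[j][j] = math.sqrt(sum(x * x for x in v))
        q_cols.append([x / r[j][j] for x in v])
    q = [[q_cols[j][i] for j in range(n)] for i in range(m)]  # columns -> rows
    return q, r


q, r = gram_schmidt_qr([[12, -51, 4], [6, 167, -68], [-4, 24, -41]])
for row in r:
    print([round(x, 4) for x in row])
# [14.0, 21.0, -14.0]
# [0.0, 175.0, -70.0]
# [0.0, 0.0, 35.0]
```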
-
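torch.svd(input, some=True, out=None) -> (Tensor, Tensor, Tensor)¶ U, S, V = torch.svd(A) returns the singular value decomposition of a real matrix A of size (n x m) such that \(A = U \operatorname{diag}(S) V^T\).
U is of shape n x n.
S is a diagonal matrix of shape n x m, returned as a vector of size min(n, m) that holds its non-negative diagonal entries (the singular values).
V is of shape m x m.
some controls the shape of the returned U and V: if some is True (the default), only min(n, m) columns of U and V are computed, so U is n x min(n, m) as in the example below; if some is False, the full U and V are returned.
Note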
Irrespective of the original strides, the returned matrix U will be transposed, i.e. with strides (1, n) instead of (n, 1).
Parameters:
- input (Tensor) – the input 2-D Tensor
- some (bool, optional) – controls the shape of the returned U and V
- out (tuple, optional) – the output tuple of Tensors
Example:
>>> a = torch.Tensor([[8.79, 6.11, -9.15, 9.57, -3.49, 9.84],
...                   [9.93, 6.91, -7.93, 1.64, 4.02, 0.15],
...                   [9.83, 5.04, 4.86, 8.83, 9.80, -8.99],
...                   [5.45, -0.27, 4.85, 0.74, 10.00, -6.02],
...                   [3.16, 7.98, 3.01, 5.80, 4.27, -5.31]]).t()
>>> a
 8.7900  9.9300  9.8300  5.4500  3.1600
 6.1100  6.9100  5.0400 -0.2700  7.9800
-9.1500 -7.9300  4.8600  4.8500  3.0100
 9.5700  1.6400  8.8300  0.7400  5.8000
-3.4900  4.0200  9.8000 10.0000  4.2700
 9.8400  0.1500 -8.9900 -6.0200 -5.3100
[torch.FloatTensor of size 6x5]
>>> u, s, v = torch.svd(a)
>>> u
-0.5911  0.2632  0.3554  0.3143  0.2299
-0.3976  0.2438 -0.2224 -0.7535 -0.3636
-0.0335 -0.6003 -0.4508  0.2334 -0.3055
-0.4297  0.2362 -0.6859  0.3319  0.1649
-0.4697 -0.3509  0.3874  0.1587 -0.5183
 0.2934  0.5763 -0.0209  0.3791 -0.6526
[torch.FloatTensor of size 6x5]
>>> s
 27.4687
 22.6432
  8.5584
  5.9857
  2.0149
[torch.FloatTensor of size 5]
>>> v
-0.2514  0.8148 -0.2606  0.3967 -0.2180
-0.3968  0.3587  0.7008 -0.4507  0.1402
-0.6922 -0.2489 -0.2208  0.2513  0.5891
-0.3662 -0.3686  0.3859  0.4342 -0.6265
-0.4076 -0.0980 -0.4932 -0.6227 -0.4396
[torch.FloatTensor of size 5x5]
>>> torch.dist(a, torch.mm(torch.mm(u, torch.diag(s)), v.t()))
8.934150226306685e-06
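As a quick cross-check of the largest singular value in the example, the pure-Python sketch below runs power iteration on \(A^T A\). The largest singular value \(\sigma_1\) is the square root of that matrix's largest eigenvalue. Power iteration is a swapped-in technique that recovers only \(\sigma_1\), not the full decomposition svd returns, and `top_singular_value` is a hypothetical name.

```python
import math


def top_singular_value(a, iters=500):
    """Hypothetical helper: largest singular value of a (list of rows).

    Power iteration on A^T A converges to its dominant eigenvector v, and
    ||A v|| for unit v then approximates the largest singular value.
    """
    n = len(a[0])
    ata = [[sum(row[i] * row[j] for row in a) for j in range(n)] for i in range(n)]
    v = [1.0] * n  # assumed not orthogonal to the dominant eigenvector
    for _ in range(iters):
        w = [sum(ata[i][j] * v[j] for j in range(n)) for i in range(n)]
        norm = math.sqrt(sum(x * x for x in w))
        v = [x / norm for x in w]
    av = [sum(row[j] * v[j] for j in range(n)) for row in a]
    return math.sqrt(sum(x * x for x in av))


rows = [[8.79, 6.11, -9.15, 9.57, -3.49, 9.84],
        [9.93, 6.91, -7.93, 1.64, 4.02, 0.15],
        [9.83, 5.04, 4.86, 8.83, 9.80, -8.99],
        [5.45, -0.27, 4.85, 0.74, 10.00, -6.02],
        [3.16, 7.98, 3.01, 5.80, 4.27, -5.31]]
a = [list(col) for col in zip(*rows)]  # 6 x 5, mirroring .t() in the example
print(round(top_singular_value(a), 4))  # expected to be close to the first entry of s
```

Convergence speed depends on the ratio \((\sigma_2/\sigma_1)^2\), which is roughly 0.68 for this matrix, so a few hundred iterations are ample.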
-
torch.symeig(input, eigenvectors=False, upper=True, out=None) -> (Tensor, Tensor)¶ e, V = torch.symeig(input) returns the eigenvalues and eigenvectors of a real symmetric matrix input.
input and V are m x m matrices and e is an m-dimensional vector.
This function calculates all eigenvalues (and vectors) of input such that \(\mathrm{input} = V \operatorname{diag}(e) V^T\).
The boolean argument eigenvectors selects whether eigenvectors are computed along with the eigenvalues: if it is False, only eigenvalues are computed; if it is True, both eigenvalues and eigenvectors are computed.
Since input is assumed to be symmetric, only its upper triangular portion is used by default.
If upper is False, the lower triangular portion is used instead.
Note: Irrespective of the original strides, the returned matrix V will be transposed, i.e. with strides (1, m) instead of (m, 1).
Parameters:
- input (Tensor) – the input symmetric matrix
- eigenvectors (bool, optional) – controls whether eigenvectors have to be computed
- upper (bool, optional) – controls whether the upper or the lower triangular region is used
- out (tuple, optional) – the output tuple of (Tensor, Tensor)
Examples:
>>> a = torch.Tensor([[ 1.96,  0.00,  0.00,  0.00,  0.00],
...                   [-6.49,  3.80,  0.00,  0.00,  0.00],
...                   [-0.47, -6.39,  4.17,  0.00,  0.00],
...                   [-7.20,  1.50, -1.51,  5.70,  0.00],
...                   [-0.65, -6.34,  2.67,  1.80, -7.10]]).t()
>>> e, v = torch.symeig(a, eigenvectors=True)
>>> e
-11.0656
 -6.2287
  0.8640
  8.8655
 16.0948
[torch.FloatTensor of size 5]
>>> v
-0.2981 -0.6075  0.4026 -0.3745  0.4896
-0.5078 -0.2880 -0.4066 -0.3572 -0.6053
-0.0816 -0.3843 -0.6600  0.5008  0.3991
-0.0036 -0.4467  0.4553  0.6204 -0.4564
-0.8041  0.4480  0.1725  0.3108  0.1622
[torch.FloatTensor of size 5x5]
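For intuition about the symmetric eigenvalue problem, the pure-Python sketch below applies the cyclic Jacobi method. This is a swapped-in technique, not the LAPACK routine torch calls. Each plane rotation G is chosen so that \(G^T A G\) has a zero in position (p, q), and repeated sweeps drive A toward a diagonal matrix with the same eigenvalues. The sketch builds the full symmetric matrix from the lower-triangle literal of the example, which holds the same entries the upper-triangular a uses after .t(). `jacobi_eigenvalues` and `matmul` are hypothetical names.

```python
import math


def matmul(x, y):
    """Hypothetical helper: plain nested-list matrix product."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*y)] for row in x]


def jacobi_eigenvalues(s, tol=1e-18, max_sweeps=50):
    """Hypothetical helper: eigenvalues of a real symmetric matrix, ascending."""
    n = len(s)
    a = [[float(x) for x in row] for row in s]
    for _ in range(max_sweeps):
        off = sum(a[i][j] ** 2 for i in range(n) for j in range(n) if i != j)
        if off < tol:  # off-diagonal part is negligible: a is (nearly) diagonal
            break
        for p in range(n - 1):
            for q in range(p + 1, n):
                if a[p][q] == 0.0:
                    continue
                # tan(2 theta) = 2 a_pq / (a_pp - a_qq) zeroes entry (p, q); |theta| <= pi/4.
                if a[p][p] == a[q][q]:
                    theta = math.pi / 4
                else:
                    theta = 0.5 * math.atan(2.0 * a[p][q] / (a[p][p] - a[q][q]))
                c, sn = math.cos(theta), math.sin(theta)
                g = [[float(i == j) for j in range(n)] for i in range(n)]
                g[p][p], g[p][q], g[q][p], g[q][q] = c, -sn, sn, c
                gt = [list(col) for col in zip(*g)]
                a = matmul(matmul(gt, a), g)  # orthogonal similarity: same eigenvalues
    return sorted(a[i][i] for i in range(n))


lower = [[1.96], [-6.49, 3.80], [-0.47, -6.39, 4.17],
         [-7.20, 1.50, -1.51, 5.70], [-0.65, -6.34, 2.67, 1.80, -7.10]]
S = [[lower[max(i, j)][min(i, j)] for j in range(5)] for i in range(5)]
print([round(x, 4) for x in jacobi_eigenvalues(S)])  # expected to be close to e above
```

Forming each rotation as a full matrix keeps the sketch short and easy to check; a practical implementation would update only rows and columns p and q.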
-
torch.trtrs()¶