Skip to content

Movement

Movement (low level)¤

view ¤

view(*shape) -> Tensor

.view is an alias for .reshape.

Source code in tinygrad/tensor.py
807
808
809
def view(self, *shape) -> Tensor:
  """
  Alias for `.reshape`.

  All positional arguments are collected into a tuple and forwarded unchanged to `reshape`.
  """
  requested = shape
  return self.reshape(requested)

reshape ¤

reshape(shape, *args) -> Tensor

Returns a tensor with the same data as the original tensor but with a different shape. shape can be passed as a tuple or as separate arguments.

t = Tensor.arange(6)
print(t.reshape(2, 3).numpy())
[[0 1 2]
 [3 4 5]]
Source code in tinygrad/tensor.py
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
def reshape(self, shape, *args) -> Tensor:
  """
  Returns a tensor with the same data as the original tensor but with a different shape.
  `shape` can be passed as a tuple or as separate arguments.
  A `None` entry keeps the size of the corresponding current dimension; at most one entry may be `-1`,
  in which case its size is inferred from the total element count.
  If the resolved shape equals the current shape, `self` is returned unchanged.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(6)
  print(t.reshape(2, 3).numpy())
  ```
  """
  requested = argfix(shape, *args)
  # `None` keeps whatever size this position currently has
  new_shape = tuple(self.shape[i] if s is None else s for i, s in enumerate(requested))
  inferred = new_shape.count(-1)
  if inferred > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
  if inferred:
    # the -1 factor flips the sign of prod(new_shape), so the numerator is negated to compensate
    new_shape = tuple(-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape)
  # avoid creating a Reshape node when nothing changes
  if new_shape != self.shape: return F.Reshape.apply(self, shape=new_shape)
  return self

expand ¤

expand(shape, *args) -> Tensor

Returns a tensor that is expanded to the shape that is specified. Expand can also increase the number of dimensions that a tensor has.

Passing a -1 or None to a dimension means that its size will not be changed.

t = Tensor([1, 2, 3])
print(t.expand(4, -1).numpy())
[[1 2 3]
 [1 2 3]
 [1 2 3]
 [1 2 3]]
Source code in tinygrad/tensor.py
828
829
830
831
832
833
834
835
836
837
838
839
840
def expand(self, shape, *args) -> Tensor:
  """
  Returns a tensor broadcast to the requested shape.
  The requested shape may have more dimensions than the tensor, in which case new leading dimensions are added.

  Passing a `-1` or `None` to a dimension means that its size will not be changed.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1, 2, 3])
  print(t.expand(4, -1).numpy())
  ```
  """
  # bring the current and requested shapes to the same rank before pairing them up
  aligned = _pad_left(self.shape, argfix(shape, *args))
  # -1 / None keep the current size of that axis
  target = tuple(cur if req == -1 or req is None else req for cur, req in zip(*aligned))
  return self._broadcast_to(target)

permute ¤

permute(order, *args) -> Tensor

Returns a tensor that is a permutation of the original tensor. The new tensor has the same data as the original tensor but with the dimensions permuted according to the order specified. order can be passed as a tuple or as separate arguments.

t = Tensor.arange(6).reshape(2, 3)
print(t.numpy())
[[0 1 2]
 [3 4 5]]
print(t.permute(1, 0).numpy())
[[0 3]
 [1 4]
 [2 5]]

Source code in tinygrad/tensor.py
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
def permute(self, order, *args) -> Tensor:
  """
  Returns a tensor whose dimensions are reordered according to `order`.
  The data is unchanged; only the axis order differs.
  `order` can be passed as a tuple or as separate arguments, and may contain negative axes.
  Raises `RuntimeError` if `order` is not a permutation of all axes.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(6).reshape(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.permute(1, 0).numpy())
  ```
  """
  resolved = tuple(map(self._resolve_dim, argfix(order, *args)))
  # every axis index must appear exactly once
  if sorted(resolved) != list(range(self.ndim)):
    raise RuntimeError(f"order is not a valid permutation, getting {resolved}")
  return F.Permute.apply(self, order=resolved)

flip ¤

flip(axis, *args) -> Tensor

Returns a tensor that reverses the order of the original tensor along the given axis. axis can be passed as a tuple or as separate arguments.

t = Tensor.arange(6).reshape(2, 3)
print(t.numpy())
[[0 1 2]
 [3 4 5]]
print(t.flip(0).numpy())
[[3 4 5]
 [0 1 2]]
print(t.flip((0, 1)).numpy())
[[5 4 3]
 [2 1 0]]

Source code in tinygrad/tensor.py
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
def flip(self, axis, *args) -> Tensor:
  """
  Returns a tensor that reverses the order of the original tensor along the given `axis`.
  `axis` can be passed as a tuple or as separate arguments.
  Raises `RuntimeError` if the same axis (after resolving negative indices) is listed more than once.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(6).reshape(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.flip(0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.flip((0, 1)).numpy())
  ```
  """
  axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args))
  # duplicates are detected after resolving, so the check compares against the deduplicated axes
  if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
  return F.Flip.apply(self, axis=axis_arg)

shrink ¤

shrink(
    arg: Tuple[Optional[Tuple[sint, sint]], ...]
) -> Tensor

Returns a tensor that shrinks each axis based on input arg. arg must have the same length as self.ndim. For each axis, it can be None, which means no shrink, or a tuple (start, end) that works the same as Python slice.

t = Tensor.arange(9).reshape(3, 3)
print(t.numpy())
[[0 1 2]
 [3 4 5]
 [6 7 8]]
print(t.shrink(((None, (1, 3)))).numpy())
[[1 2]
 [4 5]
 [7 8]]
print(t.shrink((((0, 2), (0, 2)))).numpy())
[[0 1]
 [3 4]]

Source code in tinygrad/tensor.py
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor:
  """
  Returns a tensor that shrinks each axis based on input `arg`.
  `arg` must have the same length as `self.ndim`.
  For each axis, it can be `None`, which means no shrink, or a tuple `(start, end)` that works the same as Python slice.
  If no axis is actually shrunk, `self` is returned unchanged.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(9).reshape(3, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.shrink(((None, (1, 3)))).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.shrink((((0, 2), (0, 2)))).numpy())
  ```
  """
  unchanged = all(x is None or x == (0, s) for x, s in zip(arg, self.shape))
  if unchanged: return self
  # spell out `None` entries as the full range of that axis
  normalized = tuple((0, s) if x is None else x for x, s in zip(arg, self.shape))
  return F.Shrink.apply(self, arg=normalized)

pad ¤

pad(
    arg: Tuple[Optional[Tuple[sint, sint]], ...],
    value: float = 0.0,
) -> Tensor

Returns a tensor that pads each axis based on input arg. arg must have the same length as self.ndim. For each axis, it can be None, which means no pad, or a tuple (pad_before, pad_after). If value is specified, the tensor is padded with value instead of 0.0.

t = Tensor.arange(6).reshape(2, 3)
print(t.numpy())
[[0 1 2]
 [3 4 5]]
print(t.pad(((None, (1, 2)))).numpy())
[[0 0 1 2 0 0]
 [0 3 4 5 0 0]]
print(t.pad(((None, (1, 2))), -2).numpy())
[[-2  0  1  2 -2 -2]
 [-2  3  4  5 -2 -2]]

Source code in tinygrad/tensor.py
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
def pad(self, arg:Tuple[Optional[Tuple[sint, sint]], ...], value:float=0.0) -> Tensor:
  """
  Returns a tensor that pads each axis based on input `arg`.
  `arg` must have the same length as `self.ndim`.
  For each axis, it can be `None`, which means no pad, or a tuple `(pad_before, pad_after)`.
  If `value` is specified, the tensor is padded with `value` instead of `0.0`.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(6).reshape(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.pad(((None, (1, 2)))).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.pad(((None, (1, 2))), -2).numpy())
  ```
  """
  if all(x is None or x == (0, 0) for x in arg): return self
  narg = tuple((0, 0) if x is None else x for x in arg)
  ret = F.Pad.apply(self, arg=narg)
  if 0 == value: return ret
  # pad a tensor of ones identically: its padded region is 0, so `where` places `value` exactly there
  border = F.Pad.apply(Tensor.ones_like(self), arg=narg)
  return ret + border.where(0, value)

Movement (high level)¤

gather ¤

gather(dim: int, index: Tensor) -> Tensor

Gathers values along an axis specified by dim.

t = Tensor([[1, 2], [3, 4]])
print(t.numpy())
[[1 2]
 [3 4]]
print(t.gather(1, Tensor([[0, 0], [1, 0]])).numpy())
[[1 1]
 [4 3]]

Source code in tinygrad/tensor.py
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
def gather(self:Tensor, dim:int, index:Tensor) -> Tensor:
  """
  Gathers values along an axis specified by `dim`, picking the positions listed in `index`.
  `index` must have the same number of dimensions as `self`, and on every axis other than `dim`
  it may be no larger than `self`.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([[1, 2], [3, 4]])
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.gather(1, Tensor([[0, 0], [1, 0]])).numpy())
  ```
  """
  assert index.ndim == self.ndim, f"self.ndim must equal index.ndim, {self.ndim=}, {index.ndim=}"
  dim = self._resolve_dim(dim)
  assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim"
  index = index.to(self.device)
  # crop non-gather axes to the index's extent; then swap a new trailing size-1 axis with `dim`,
  # which moves the gather axis to the end
  crop = tuple(None if d == dim else (0, i) for d, i in enumerate(index.shape))
  src = self.shrink(crop).unsqueeze(-1).transpose(-1, dim)
  # one-hot mask: each index value compared against every position along `dim`
  one_hot = index.unsqueeze(-1) == Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)
  # the masked sum over the trailing axis keeps only the selected element
  return (one_hot * src).sum(-1, acc_dtype=self.dtype)

cat ¤

cat(*args: Tensor, dim: int = 0) -> Tensor

Concatenates self with the other Tensors in args along an axis specified by dim. All tensors must have the same shape except in the concatenating dimension.

t0, t1, t2 = Tensor([[1, 2]]), Tensor([[3, 4]]), Tensor([[5, 6]])
print(t0.cat(t1, t2, dim=0).numpy())
[[1 2]
 [3 4]
 [5 6]]
print(t0.cat(t1, t2, dim=1).numpy())
[[1 2 3 4 5 6]]

Source code in tinygrad/tensor.py
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
  """
  Concatenates self with the other `Tensor`s in `args` along an axis specified by `dim`.
  All tensors must have the same shape except in the concatenating dimension;
  otherwise an `AssertionError` listing all input shapes is raised.

  ```python exec="true" source="above" session="tensor" result="python"
  t0, t1, t2 = Tensor([[1, 2]]), Tensor([[3, 4]]), Tensor([[5, 6]])
  print(t0.cat(t1, t2, dim=0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t0.cat(t1, t2, dim=1).numpy())
  ```
  """
  dim = self._resolve_dim(dim)
  assert all(len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim) for y in args), \
    f"all tensors must have the same shape except in dim {dim}, got {[t.shape for t in (self, *args)]}"
  catargs = [self, *args]
  # size of each input along `dim`, and the offset at which each input starts in the output
  cat_dims = [s.shape[dim] for s in catargs]
  cat_dim_cumsum = [0, *itertools.accumulate(cat_dims)]
  # pad every input (with 0, pad's default) along `dim` so it occupies its own slot; summing stitches them together
  slc:List[List[Optional[Tuple[sint, sint]]]] = [[None for _ in self.shape] for _ in catargs]
  for d,k,s in zip(cat_dims, cat_dim_cumsum[:-1], slc): s[dim] = (k, cat_dim_cumsum[-1] - k - d)
  return functools.reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)])

stack ¤

stack(*args: Tensor, dim: int = 0) -> Tensor

Concatenates self with the other Tensors in args along a new dimension specified by dim.

t0, t1, t2 = Tensor([1, 2]), Tensor([3, 4]), Tensor([5, 6])
print(t0.stack(t1, t2, dim=0).numpy())
[[1 2]
 [3 4]
 [5 6]]
print(t0.stack(t1, t2, dim=1).numpy())
[[1 3 5]
 [2 4 6]]

Source code in tinygrad/tensor.py
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
def stack(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
  """
  Joins self and the other `Tensor`s in `args` along a new dimension inserted at `dim`.

  ```python exec="true" source="above" session="tensor" result="python"
  t0, t1, t2 = Tensor([1, 2]), Tensor([3, 4]), Tensor([5, 6])
  print(t0.stack(t1, t2, dim=0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t0.stack(t1, t2, dim=1).numpy())
  ```
  """
  # give every input a size-1 axis at `dim`, then concatenate along it; `cat` validates the shapes
  lifted = [t.unsqueeze(dim) for t in (self, *args)]
  return lifted[0].cat(*lifted[1:], dim=dim)

repeat ¤

repeat(repeats, *args) -> Tensor

Repeats the tensor along each dimension the number of times specified by repeats. repeats can be passed as a tuple or as separate arguments.

t = Tensor([1, 2, 3])
print(t.repeat(4, 2).numpy())
[[1 2 3 1 2 3]
 [1 2 3 1 2 3]
 [1 2 3 1 2 3]
 [1 2 3 1 2 3]]
print(t.repeat(4, 2, 1).shape)
(4, 2, 3)

Source code in tinygrad/tensor.py
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
def repeat(self, repeats, *args) -> Tensor:
  """
  Repeats the tensor along each dimension the number of times specified by `repeats`.
  `repeats` can be passed as a tuple or as separate arguments.
  If `repeats` has more entries than the tensor has dimensions, new leading dimensions are added.
  Raises `RuntimeError` if `repeats` has fewer entries than the tensor has dimensions.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1, 2, 3])
  print(t.repeat(4, 2).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.repeat(4, 2, 1).shape)
  ```
  """
  repeats = argfix(repeats, *args)
  # without this check the zips below silently truncate and produce a wrong-rank result or an obscure error
  if len(repeats) < self.ndim:
    raise RuntimeError(f"number of repeat dims ({len(repeats)}) can not be smaller than number of dimensions of tensor ({self.ndim})")
  base_shape = (1,) * (len(repeats) - self.ndim) + self.shape
  # interleave a size-1 axis before each axis, expand it to the repeat count, then merge the pairs
  new_shape = [x for b in base_shape for x in [1, b]]
  expand_shape = [x for rs in zip(repeats, base_shape) for x in rs]
  final_shape = [r*s for r,s in zip(repeats, base_shape)]
  return self.reshape(new_shape).expand(expand_shape).reshape(final_shape)

repeat_interleave ¤

repeat_interleave(
    repeats: int, dim: Optional[int] = None
) -> Tensor

Repeat elements of a tensor.

t = Tensor([1, 2, 3])
print(t.repeat_interleave(2).numpy())
[1 1 2 2 3 3]
Source code in tinygrad/tensor.py
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
def repeat_interleave(self, repeats:int, dim:Optional[int]=None) -> Tensor:
  """
  Repeat elements of a tensor.
  Each element along `dim` is repeated `repeats` times consecutively.
  If `dim` is `None`, the tensor is flattened first; negative `dim` values count from the end.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1, 2, 3])
  print(t.repeat_interleave(2).numpy())
  ```
  """
  # resolve negative dims: the slices below (shp[:dim+1], shp[dim+1:]) are only correct for a non-negative dim
  x, dim = (self.flatten(), 0) if dim is None else (self, self._resolve_dim(dim))
  shp = x.shape
  # insert a size-1 axis right after `dim`, expand it to `repeats`, then merge it into `dim`
  return x.reshape(*shp[:dim+1], 1, *shp[dim+1:]).expand(*shp[:dim+1], repeats, *shp[dim+1:]).reshape(*shp[:dim], shp[dim]*repeats, *shp[dim+1:])

split ¤

split(
    sizes: Union[int, List[int]], dim: int = 0
) -> Tuple[Tensor, ...]

Splits the tensor into chunks along the dimension specified by dim. If sizes is an integer, it splits into equally sized chunks if possible, otherwise the last chunk will be smaller. If sizes is a list, it splits into len(sizes) chunks with size in dim according to sizes.

t = Tensor.arange(10).reshape(5, 2)
print(t.numpy())
[[0 1]
 [2 3]
 [4 5]
 [6 7]
 [8 9]]
split = t.split(2)
print("\n".join([repr(x.numpy()) for x in split]))
array([[0, 1],
       [2, 3]], dtype=int32)
array([[4, 5],
       [6, 7]], dtype=int32)
array([[8, 9]], dtype=int32)
split = t.split([1, 4])
print("\n".join([repr(x.numpy()) for x in split]))
array([[0, 1]], dtype=int32)
array([[2, 3],
       [4, 5],
       [6, 7],
       [8, 9]], dtype=int32)

Source code in tinygrad/tensor.py
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
def split(self, sizes:Union[int, List[int]], dim:int=0) -> Tuple[Tensor, ...]:
  """
  Splits the tensor into chunks along the dimension specified by `dim`.
  If `sizes` is an integer, it splits into equally sized chunks if possible, otherwise the last chunk will be smaller.
  If `sizes` is a list, it splits into `len(sizes)` chunks with size in `dim` according to `sizes`.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(10).reshape(5, 2)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  split = t.split(2)
  print("\\n".join([repr(x.numpy()) for x in split]))
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  split = t.split([1, 4])
  print("\\n".join([repr(x.numpy()) for x in split]))
  ```
  """
  assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
  dim = self._resolve_dim(dim)
  # integer sizes: max(1, ...) keeps one (empty) chunk for a zero-length axis and avoids a zero range step
  if isinstance(sizes, int): sizes = [min(sizes, self.shape[dim]-i) for i in range(0, max(1, self.shape[dim]), max(1, sizes))]
  assert sum(sizes) == self.shape[dim], f"expect sizes to sum exactly to {self.shape[dim]}, but got {sum(sizes)}"
  # chunk boundaries [0, s0, s0+s1, ...] computed in one pass instead of re-summing a prefix per chunk
  bounds = [0, *itertools.accumulate(sizes)]
  # full slices on the leading axes, the chunk's range on `dim`
  return tuple(self[tuple([slice(None)]*dim + [slice(start, end)])] for start, end in zip(bounds, bounds[1:]))

chunk ¤

chunk(chunks: int, dim: int = 0) -> List[Tensor]

Splits the tensor into chunks number of chunks along the dimension dim. If the tensor size along dim is not divisible by chunks, all returned chunks will be the same size except the last one. The function may return fewer than the specified number of chunks.

chunked = Tensor.arange(11).chunk(6)
print("\n".join([repr(x.numpy()) for x in chunked]))
array([0, 1], dtype=int32)
array([2, 3], dtype=int32)
array([4, 5], dtype=int32)
array([6, 7], dtype=int32)
array([8, 9], dtype=int32)
array([10], dtype=int32)
chunked = Tensor.arange(12).chunk(6)
print("\n".join([repr(x.numpy()) for x in chunked]))
array([0, 1], dtype=int32)
array([2, 3], dtype=int32)
array([4, 5], dtype=int32)
array([6, 7], dtype=int32)
array([8, 9], dtype=int32)
array([10, 11], dtype=int32)
chunked = Tensor.arange(13).chunk(6)
print("\n".join([repr(x.numpy()) for x in chunked]))
array([0, 1, 2], dtype=int32)
array([3, 4, 5], dtype=int32)
array([6, 7, 8], dtype=int32)
array([ 9, 10, 11], dtype=int32)
array([12], dtype=int32)

Source code in tinygrad/tensor.py
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
def chunk(self, chunks:int, dim:int=0) -> List[Tensor]:
  """
  Splits the tensor into `chunks` pieces along the dimension `dim`.
  Every piece has size `ceil(size / chunks)` along `dim` except possibly the last, which may be smaller,
  so fewer than `chunks` pieces can be returned. A zero-length axis yields `chunks` empty pieces.

  ```python exec="true" source="above" session="tensor" result="python"
  chunked = Tensor.arange(11).chunk(6)
  print("\\n".join([repr(x.numpy()) for x in chunked]))
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  chunked = Tensor.arange(12).chunk(6)
  print("\\n".join([repr(x.numpy()) for x in chunked]))
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  chunked = Tensor.arange(13).chunk(6)
  print("\\n".join([repr(x.numpy()) for x in chunked]))
  ```
  """
  assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
  assert chunks > 0, f"expect chunks to be greater than 0, got: {chunks}"
  dim = self._resolve_dim(dim)
  length = self.shape[dim]
  piece_sizes = math.ceil(length/chunks) if length else [0]*chunks
  return list(self.split(piece_sizes, dim=dim))

squeeze ¤

squeeze(dim: Optional[int] = None) -> Tensor

Returns a tensor with specified dimensions of input of size 1 removed. If dim is not specified, all dimensions with size 1 are removed.

t = Tensor.zeros(2, 1, 2, 1, 2)
print(t.squeeze().shape)
(2, 2, 2)
print(t.squeeze(0).shape)
(2, 1, 2, 1, 2)
print(t.squeeze(1).shape)
(2, 2, 1, 2)

Source code in tinygrad/tensor.py
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
def squeeze(self, dim:Optional[int]=None) -> Tensor:
  """
  Returns a tensor with size-1 dimensions removed.
  With no `dim`, every size-1 dimension is dropped; otherwise only `dim` is dropped, and only if its size is 1
  (for any other size, or a 0-d tensor, `self` is returned unchanged).

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.zeros(2, 1, 2, 1, 2)
  print(t.squeeze().shape)
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.squeeze(0).shape)
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.squeeze(1).shape)
  ```
  """
  if dim is None:
    return self.reshape(tuple(size for size in self.shape if size != 1))
  dim = self._resolve_dim(dim)
  if not self.ndim or self.shape[dim] != 1:
    return self
  return self.reshape(self.shape[:dim] + self.shape[dim+1:])

unsqueeze ¤

unsqueeze(dim: int) -> Tensor

Returns a tensor with a new dimension of size 1 inserted at the specified dim.

t = Tensor([1, 2, 3, 4])
print(t.unsqueeze(0).numpy())
[[1 2 3 4]]
print(t.unsqueeze(1).numpy())
[[1]
 [2]
 [3]
 [4]]

Source code in tinygrad/tensor.py
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
def unsqueeze(self, dim:int) -> Tensor:
  """
  Returns a tensor with a new size-1 dimension inserted at position `dim`.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1, 2, 3, 4])
  print(t.unsqueeze(0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.unsqueeze(1).numpy())
  ```
  """
  # outer=True: the insertion point may be one past the last axis (e.g. unsqueeze(1) on a 1-D tensor)
  pos = self._resolve_dim(dim, outer=True)
  head, tail = self.shape[:pos], self.shape[pos:]
  return self.reshape(head + (1,) + tail)

pad2d ¤

pad2d(padding: Sequence[int], value: float = 0.0) -> Tensor

Returns a tensor that pads the last two axes specified by padding (padding_left, padding_right, padding_top, padding_bottom). If value is specified, the tensor is padded with value instead of 0.0.

t = Tensor.arange(9).reshape(1, 1, 3, 3)
print(t.numpy())
[[[[0 1 2]
   [3 4 5]
   [6 7 8]]]]
print(t.pad2d((1, 1, 2, 0), value=-float("inf")).numpy())
[[[[-inf -inf -inf -inf -inf]
   [-inf -inf -inf -inf -inf]
   [-inf   0.   1.   2. -inf]
   [-inf   3.   4.   5. -inf]
   [-inf   6.   7.   8. -inf]]]]

Source code in tinygrad/tensor.py
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
def pad2d(self, padding:Sequence[int], value:float=0.0) -> Tensor:
  """
  Returns a tensor that pads the last two axes specified by `padding` (padding_left, padding_right, padding_top, padding_bottom).
  If `value` is specified, the tensor is padded with `value` instead of `0.0`.
  Negative padding amounts crop the corresponding side instead of padding it.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(9).reshape(1, 1, 3, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.pad2d((1, 1, 2, 0), value=-float("inf")).numpy())
  ```
  """
  befores, afters = padding[::2], padding[1::2]
  # leading axes that `padding` does not mention are left alone
  untouched = (None,) * (self.ndim - len(padding) // 2)
  # `padding` lists the last axis first, hence the reversal; only the positive part is padded here
  grow = tuple((max(b, 0), max(a, 0)) for b, a in zip(befores, afters))[::-1]
  padded = self.pad(untouched + grow, value=value)
  # the negative part is applied as a crop on the padded result
  crop = tuple((-min(b, 0), min(a + s, s)) for b, a, s in zip(befores, afters, padded.shape[::-1]))[::-1]
  return padded.shrink(untouched + crop)

T property ¤

T: Tensor

.T is an alias for .transpose().

transpose ¤

transpose(dim0=1, dim1=0) -> Tensor

Returns a tensor that is a transposed version of the original tensor. The given dimensions dim0 and dim1 are swapped.

t = Tensor.arange(6).reshape(2, 3)
print(t.numpy())
[[0 1 2]
 [3 4 5]]
print(t.transpose(0, 1).numpy())
[[0 3]
 [1 4]
 [2 5]]

Source code in tinygrad/tensor.py
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
def transpose(self, dim0=1, dim1=0) -> Tensor:
  """
  Returns a tensor with dimensions `dim0` and `dim1` swapped.
  Negative dimensions are accepted (they index the axis list from the end).

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(6).reshape(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.transpose(0, 1).numpy())
  ```
  """
  axes = [*range(self.ndim)]
  first, second = axes[dim0], axes[dim1]
  axes[dim0], axes[dim1] = second, first
  return self.permute(axes)

flatten ¤

flatten(start_dim=0, end_dim=-1)

Flattens the tensor by reshaping it into a one-dimensional tensor. If start_dim or end_dim are passed, only dimensions starting with start_dim and ending with end_dim are flattened.

t = Tensor.arange(8).reshape(2, 2, 2)
print(t.flatten().numpy())
[0 1 2 3 4 5 6 7]
print(t.flatten(start_dim=1).numpy())
[[0 1 2 3]
 [4 5 6 7]]

Source code in tinygrad/tensor.py
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
def flatten(self, start_dim=0, end_dim=-1):
  """
  Flattens the tensor by merging dimensions `start_dim` through `end_dim` (inclusive) into one.
  With the defaults, the result is one-dimensional.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(8).reshape(2, 2, 2)
  print(t.flatten().numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.flatten(start_dim=1).numpy())
  ```
  """
  lo, hi = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
  merged = prod(self.shape[lo:hi+1])
  return self.reshape(self.shape[:lo] + (merged,) + self.shape[hi+1:])

unflatten ¤

unflatten(dim: int, sizes: Tuple[int, ...])

Unflattens dimension dim of the tensor into multiple dimensions specified by sizes. Tensor.flatten() is the inverse of this function.

print(Tensor.ones(3, 4, 1).unflatten(1, (2, 2)).shape)
(3, 2, 2, 1)
print(Tensor.ones(3, 4, 1).unflatten(1, (-1, 2)).shape)
(3, 2, 2, 1)
print(Tensor.ones(5, 12, 3).unflatten(-2, (2, 2, 3, 1, 1)).shape)
(5, 2, 2, 3, 1, 1, 3)

Source code in tinygrad/tensor.py
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
def unflatten(self, dim:int, sizes:Tuple[int,...]):
  """
  Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function.
  `sizes` may be any sequence of ints (e.g. a tuple or a list) and may contain one `-1`, which `reshape` infers.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.ones(3, 4, 1).unflatten(1, (2, 2)).shape)
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.ones(3, 4, 1).unflatten(1, (-1, 2)).shape)
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.ones(5, 12, 3).unflatten(-2, (2, 2, 3, 1, 1)).shape)
  ```
  """
  dim = self._resolve_dim(dim)
  # tuple() so that list inputs concatenate with the tuple shape slices instead of raising TypeError
  return self.reshape(self.shape[:dim] + tuple(sizes) + self.shape[dim+1:])