Movement

Movement (low level)

view

view(*shape) -> Tensor

.view is an alias for .reshape.

Source code in tinygrad/tensor.py
def view(self, *shape) -> Tensor:
  """`.view` is an alias for `.reshape`."""
  return self.reshape(shape)

reshape

reshape(shape, *args) -> Tensor

Returns a tensor with the same data as the original tensor but with a different shape. shape can be passed as a tuple or as separate arguments.

t = Tensor.arange(6)
print(t.reshape(2, 3).numpy())
[[0 1 2]
 [3 4 5]]
Source code in tinygrad/tensor.py
def reshape(self, shape, *args) -> Tensor:
  """
  Returns a tensor with the same data as the original tensor but with a different shape.
  `shape` can be passed as a tuple or as separate arguments.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(6)
  print(t.reshape(2, 3).numpy())
  ```
  """
  # resolve None and args
  new_shape = tuple([s if s is not None else self.shape[i] for i,s in enumerate(argfix(shape, *args))])
  # resolve -1
  if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
  if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
  return F.Reshape.apply(self, shape=new_shape) if new_shape != self.shape else self
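
To illustrate the -1 and None resolution in the code above, a quick sketch (assuming the usual from tinygrad import Tensor):

```python
from tinygrad import Tensor

t = Tensor.arange(12)
# at most one dimension may be -1; its size is inferred from the rest
print(t.reshape(3, -1).shape)     # (3, 4)
print(t.reshape(-1, 2, 3).shape)  # (2, 2, 3)
# None keeps the size the original tensor has at that position
print(t.reshape(2, 6).reshape(None, 3, 2).shape)  # (2, 3, 2)
```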

expand

expand(shape, *args) -> Tensor

Returns a tensor expanded to the specified shape. Expanding can also increase the number of dimensions of the tensor.

Passing a -1 or None to a dimension means that its size will not be changed.

t = Tensor([1, 2, 3])
print(t.expand(4, -1).numpy())
[[1 2 3]
 [1 2 3]
 [1 2 3]
 [1 2 3]]
Source code in tinygrad/tensor.py
def expand(self, shape, *args) -> Tensor:
  """
  Returns a tensor expanded to the specified shape.
  Expanding can also increase the number of dimensions of the tensor.

  Passing a `-1` or `None` to a dimension means that its size will not be changed.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1, 2, 3])
  print(t.expand(4, -1).numpy())
  ```
  """
  return self._broadcast_to(tuple(from_ if to == -1 or to is None else to for from_, to in zip(*(_pad_left(self.shape, argfix(shape, *args))))))
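
A small illustration of the broadcasting behavior implemented above: only size-1 axes can be expanded, and leading dimensions can be added:

```python
from tinygrad import Tensor

t = Tensor([[1], [2]])           # shape (2, 1)
print(t.expand(2, 3).numpy())    # the size-1 axis is broadcast: [[1 1 1] [2 2 2]]
print(t.expand(4, -1, 3).shape)  # a new leading dimension: (4, 2, 3)
```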

permute

permute(order, *args) -> Tensor

Returns a tensor that is a permutation of the original tensor. The new tensor has the same data as the original tensor but with the dimensions permuted according to the order specified. order can be passed as a tuple or as separate arguments.

t = Tensor.arange(6).reshape(2, 3)
print(t.numpy())
[[0 1 2]
 [3 4 5]]
print(t.permute(1, 0).numpy())
[[0 3]
 [1 4]
 [2 5]]

Source code in tinygrad/tensor.py
def permute(self, order, *args) -> Tensor:
  """
  Returns a tensor that is a permutation of the original tensor.
  The new tensor has the same data as the original tensor but with the dimensions permuted according to the order specified.
  `order` can be passed as a tuple or as separate arguments.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(6).reshape(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.permute(1, 0).numpy())
  ```
  """
  order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args))
  if sorted(order_arg) != list(range(self.ndim)): raise RuntimeError(f"order is not a valid permutation, getting {order_arg}")
  return F.Permute.apply(self, order=order_arg)
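
As a sketch of a common use of permute, reordering a (batch, channels, height, width) layout to channels-last; the shapes and Tensor.rand here are illustrative, not from the section above:

```python
from tinygrad import Tensor

x = Tensor.rand(2, 3, 8, 8)          # e.g. an NCHW batch
print(x.permute(0, 2, 3, 1).shape)   # NHWC: (2, 8, 8, 3)
# negative indices are resolved like Python indexing, so -1 means the last dim
print(x.permute(0, -1, 1, 2).shape)  # (2, 8, 3, 8)
```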

flip

flip(axis, *args) -> Tensor

Returns a tensor that reverses the order of the original tensor along the given axis. axis can be passed as a tuple or as separate arguments.

t = Tensor.arange(6).reshape(2, 3)
print(t.numpy())
[[0 1 2]
 [3 4 5]]
print(t.flip(0).numpy())
[[3 4 5]
 [0 1 2]]
print(t.flip((0, 1)).numpy())
[[5 4 3]
 [2 1 0]]

Source code in tinygrad/tensor.py
def flip(self, axis, *args) -> Tensor:
  """
  Returns a tensor that reverses the order of the original tensor along the given `axis`.
  `axis` can be passed as a tuple or as separate arguments.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(6).reshape(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.flip(0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.flip((0, 1)).numpy())
  ```
  """
  axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args))
  if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
  return F.Flip.apply(self, axis=axis_arg)
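
Since _resolve_dim is applied to each axis, negative axes also work; a minimal sketch:

```python
from tinygrad import Tensor

t = Tensor.arange(6).reshape(2, 3)
print(t.flip(-1).numpy())  # reverse the last axis: [[2 1 0] [5 4 3]]
```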

shrink

shrink(
    arg: Tuple[Optional[Tuple[sint, sint]], ...]
) -> Tensor

Returns a tensor that shrinks each axis based on the input arg. arg must have the same length as self.ndim. For each axis, the entry can be None, which means no shrink, or a tuple (start, end) that works like a Python slice.

t = Tensor.arange(9).reshape(3, 3)
print(t.numpy())
[[0 1 2]
 [3 4 5]
 [6 7 8]]
print(t.shrink(((None, (1, 3)))).numpy())
[[1 2]
 [4 5]
 [7 8]]
print(t.shrink((((0, 2), (0, 2)))).numpy())
[[0 1]
 [3 4]]

Source code in tinygrad/tensor.py
def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor:
  """
  Returns a tensor that shrinks each axis based on the input `arg`.
  `arg` must have the same length as `self.ndim`.
  For each axis, the entry can be `None`, which means no shrink, or a tuple `(start, end)` that works like a Python slice.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(9).reshape(3, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.shrink(((None, (1, 3)))).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.shrink((((0, 2), (0, 2)))).numpy())
  ```
  """
  if (shrink_arg:=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)]) == [(0,s) for s in self.shape]: return self
  return F.Shrink.apply(self, arg=tuple(shrink_arg))
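
For axes where a (start, end) pair is given, shrink behaves like basic slicing; a quick sketch of the equivalence (the .all() reduction is assumed from the broader Tensor API):

```python
from tinygrad import Tensor

t = Tensor.arange(9).reshape(3, 3)
# shrink with (start, end) per axis matches basic slicing
print((t.shrink(((1, 3), (0, 2))) == t[1:3, 0:2]).all().numpy())  # True
```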

pad

pad(
    padding: Union[
        Sequence[sint],
        Sequence[Optional[Tuple[sint, sint]]],
    ],
    mode: str = "constant",
    value: float = 0.0,
) -> Tensor

Returns a tensor with padding applied based on the input padding. padding supports two padding structures:

  1. Flat padding: (padding_left, padding_right, padding_top, padding_bottom, ...)
     - This structure matches PyTorch's pad.
     - padding length must be even.

  2. Group padding: (..., (padding_top, padding_bottom), (padding_left, padding_right))
     - This structure matches pad for jax, numpy, tensorflow, and others.
     - For each axis, padding can be None, meaning no padding, or a tuple (start, end).
     - padding must have the same length as self.ndim.

Padding values can be negative, resulting in dimension shrinks that work similarly to Python negative slices. The padding mode is selected with mode, which supports constant, reflect, and replicate.

t = Tensor.arange(9).reshape(1, 1, 3, 3)
print(t.numpy())
[[[[0 1 2]
   [3 4 5]
   [6 7 8]]]]
print(t.pad((1, 2, 0, -1)).numpy())
[[[[0 0 1 2 0 0]
   [0 3 4 5 0 0]]]]
print(t.pad(((None, None, (0, -1), (1, 2)))).numpy())
[[[[0 0 1 2 0 0]
   [0 3 4 5 0 0]]]]
print(t.pad((1, 2, 0, -1), value=-float('inf')).numpy())
[[[[-inf   0.   1.   2. -inf -inf]
   [-inf   3.   4.   5. -inf -inf]]]]

Source code in tinygrad/tensor.py
def pad(self, padding:Union[Sequence[sint], Sequence[Optional[Tuple[sint, sint]]]], mode:str="constant", value:float=0.0) -> Tensor:
  """
  Returns a tensor with padding applied based on the input `padding`.
  `padding` supports two padding structures:

  1. Flat padding: (padding_left, padding_right, padding_top, padding_bottom, ...)
     - This structure matches PyTorch's pad.
     - `padding` length must be even.

  2. Group padding: (..., (padding_top, padding_bottom), (padding_left, padding_right))
     - This structure matches pad for jax, numpy, tensorflow and others.
     - For each axis, padding can be `None`, meaning no padding, or a tuple `(start, end)`.
     - `padding` must have the same length as `self.ndim`.

  Padding values can be negative, resulting in dimension shrinks that work similarly to Python negative slices.
  The padding mode is selected with `mode`, which supports `constant`, `reflect` and `replicate`.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(9).reshape(1, 1, 3, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.pad((1, 2, 0, -1)).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.pad(((None, None, (0, -1), (1, 2)))).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.pad((1, 2, 0, -1), value=-float('inf')).numpy())
  ```
  """
  if mode not in {"constant", "reflect", "replicate"}: raise NotImplementedError(f"{mode=} is not supported")
  if (flat:=all(isinstance(p, (int,UOp)) for p in padding)) and len(padding)%2 != 0: raise ValueError("Flat padding must have even number of pads")
  # turn flat padding into group padding
  pX = ((0,0),)*(self.ndim - len(padding)//2) + tuple(zip(padding[-2::-2], padding[::-2])) if flat else padding
  if len(pX) != self.ndim: raise ValueError(f"padding length is improper, {padding=} {self.ndim=}")
  X, pX = self, cast(Tuple[Tuple[sint, sint]], tuple((0,0) if p is None else p for p in pX))
  pads = tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX)
  if mode == "constant":
    def _constant(x,px,v): return F.Pad.apply(x, arg=px) if v == 0 else F.Pad.apply(x, arg=px) + F.Pad.apply(Tensor.ones_like(x), arg=px).where(0,v)
    return _constant(X, pX, value) if all(resolve(p >= 0) for p in flatten(pX)) else \
           _constant(X.shrink(tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, X.shape))), pads, value)
  assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
  for d,(pB,pA) in enumerate(pads):
    if mode == "reflect":
      if pB >= (s:=X.shape[d]) or pA>=s: raise ValueError(f"Padding ({pB}, {pA}) should be less than the input size={s} for dim={d}.")
      slcB, slcA, = slice(pB,0,-1), slice(s-2 if s-2>=0 else None, s-2-pA if s-2-pA>=0 else None, -1)
      xB, xA = (X[[slc if i == d else slice(None) for i in range(X.ndim)]] if p > 0 else None for slc, p in ((slcB, pB), (slcA, pA)))
    if mode == "replicate":
      shrB, shrA, = tuple((0,1) if i==d else None for i in range(X.ndim)), tuple((X.shape[i]-1,X.shape[i]) if i==d else None for i in range(X.ndim))
      xB, xA = (X.shrink(shr).expand(tuple(p if i==d else None for i in range(X.ndim))) if p > 0 else None for shr, p in ((shrB, pB), (shrA, pA)))
    X = Tensor.cat(*(X_ for X_ in (xB, X, xA) if X_ is not None), dim=d)
  return X.shrink(tuple((-min(pB,0), min(pA+s,s)) for (pB,pA),s in zip(pX, X.shape)))
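
The examples above only show constant mode; a short sketch of reflect and replicate on a single-row input:

```python
from tinygrad import Tensor

t = Tensor.arange(4).reshape(1, 1, 4)
print(t.pad((2, 2), mode="reflect").numpy())    # mirrors the interior: [[[2 1 0 1 2 3 2 1]]]
print(t.pad((2, 2), mode="replicate").numpy())  # repeats the edges:    [[[0 0 0 1 2 3 3 3]]]
```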

Movement (high level)

gather

gather(dim: int, index: Tensor) -> Tensor

Gathers values along an axis specified by dim.

t = Tensor([[1, 2], [3, 4]])
print(t.numpy())
[[1 2]
 [3 4]]
print(t.gather(1, Tensor([[0, 0], [1, 0]])).numpy())
[[1 1]
 [4 3]]

Source code in tinygrad/tensor.py
def gather(self:Tensor, dim:int, index:Tensor) -> Tensor:
  """
  Gathers values along an axis specified by `dim`.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([[1, 2], [3, 4]])
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.gather(1, Tensor([[0, 0], [1, 0]])).numpy())
  ```
  """
  assert index.ndim == self.ndim, f"self.ndim must equal index.ndim, {self.ndim=}, {index.ndim=}"
  dim = self._resolve_dim(dim)
  assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim"
  index = index.to(self.device)
  x = self.shrink(tuple((0, i) if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
  return ((index.unsqueeze(-1) == Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)) * x).sum(-1, acc_dtype=self.dtype)
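
For a 2-D tensor with dim=1 the semantics are out[i][j] = self[i][index[i][j]], matching PyTorch's gather; a quick sketch:

```python
from tinygrad import Tensor

t = Tensor([[10, 20, 30], [40, 50, 60]])
idx = Tensor([[2, 0], [1, 1]])
# out[i][j] = t[i][idx[i][j]] when dim=1
print(t.gather(1, idx).numpy())  # [[30 10] [50 50]]
```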

cat

cat(*args: Tensor, dim: int = 0) -> Tensor

Concatenates self with the other Tensor(s) in args along the axis specified by dim. All tensors must have the same shape except in the concatenating dimension.

t0, t1, t2 = Tensor([[1, 2]]), Tensor([[3, 4]]), Tensor([[5, 6]])
print(t0.cat(t1, t2, dim=0).numpy())
[[1 2]
 [3 4]
 [5 6]]
print(t0.cat(t1, t2, dim=1).numpy())
[[1 2 3 4 5 6]]

Source code in tinygrad/tensor.py
def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
  """
  Concatenates self with the other `Tensor`(s) in `args` along the axis specified by `dim`.
  All tensors must have the same shape except in the concatenating dimension.

  ```python exec="true" source="above" session="tensor" result="python"
  t0, t1, t2 = Tensor([[1, 2]]), Tensor([[3, 4]]), Tensor([[5, 6]])
  print(t0.cat(t1, t2, dim=0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t0.cat(t1, t2, dim=1).numpy())
  ```
  """
  dim = self._resolve_dim(dim)
  for arg in args: assert arg.ndim==self.ndim and all(ti==ai for i,(ti,ai) in enumerate(zip(self.shape, arg.shape)) if i!=dim)
  tensors = [self, *args]
  dim_cumsum = list(itertools.accumulate([t.shape[dim] for t in tensors], initial=0))
  for i,t in enumerate(tensors): tensors[i] = t.pad([(dim_cumsum[i], dim_cumsum[-1]-dim_cumsum[i+1]) if j==dim else None for j in range(t.ndim)])
  return functools.reduce(Tensor.add, tensors)
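
The implementation above is worth noting: each input is zero-padded out to the full concatenated length along dim, and the padded tensors are summed. A hand-rolled sketch of the same idea for two row vectors:

```python
from tinygrad import Tensor

a, b = Tensor([[1, 2]]), Tensor([[3, 4]])
# pad each tensor into its slot of the output, then add the results
print((a.pad((0, 2)) + b.pad((2, 0))).numpy())  # [[1 2 3 4]], same as a.cat(b, dim=1)
```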

stack

stack(*args: Tensor, dim: int = 0) -> Tensor

Concatenates self with the other Tensor(s) in args along a new dimension specified by dim.

t0, t1, t2 = Tensor([1, 2]), Tensor([3, 4]), Tensor([5, 6])
print(t0.stack(t1, t2, dim=0).numpy())
[[1 2]
 [3 4]
 [5 6]]
print(t0.stack(t1, t2, dim=1).numpy())
[[1 3 5]
 [2 4 6]]

Source code in tinygrad/tensor.py
def stack(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
  """
  Concatenates self with the other `Tensor`(s) in `args` along a new dimension specified by `dim`.

  ```python exec="true" source="above" session="tensor" result="python"
  t0, t1, t2 = Tensor([1, 2]), Tensor([3, 4]), Tensor([5, 6])
  print(t0.stack(t1, t2, dim=0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t0.stack(t1, t2, dim=1).numpy())
  ```
  """
  # checks for shapes and number of dimensions delegated to cat
  return Tensor.cat(*[t.unsqueeze(dim) for t in [self, *args]], dim=dim)

repeat

repeat(repeats, *args) -> Tensor

Repeats the tensor along each dimension the number of times specified by repeats. repeats can be passed as a tuple or as separate arguments.

t = Tensor([1, 2, 3])
print(t.repeat(4, 2).numpy())
[[1 2 3 1 2 3]
 [1 2 3 1 2 3]
 [1 2 3 1 2 3]
 [1 2 3 1 2 3]]
print(t.repeat(4, 2, 1).shape)
(4, 2, 3)

Source code in tinygrad/tensor.py
def repeat(self, repeats, *args) -> Tensor:
  """
  Repeats the tensor along each dimension the number of times specified by `repeats`.
  `repeats` can be passed as a tuple or as separate arguments.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1, 2, 3])
  print(t.repeat(4, 2).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.repeat(4, 2, 1).shape)
  ```
  """
  repeats = argfix(repeats, *args)
  base_shape = _pad_left(self.shape, repeats)[0]
  unsqueezed_shape = flatten([[1, s] for s in base_shape])
  expanded_shape = flatten([[r, s] for r,s in zip(repeats, base_shape)])
  final_shape = [r*s for r,s in zip(repeats, base_shape)]
  return self.reshape(unsqueezed_shape).expand(expanded_shape).reshape(final_shape)
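
A sketch of how the repeat counts interact with the shape: each count multiplies its dimension, and extra leading counts add dimensions (as _pad_left does above):

```python
from tinygrad import Tensor

t = Tensor.arange(4).reshape(2, 2)
print(t.repeat(2, 3).shape)     # each dim is multiplied: (4, 6)
print(t.repeat(2, 1, 1).shape)  # a leading count adds a dimension: (2, 2, 2)
```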

repeat_interleave

repeat_interleave(
    repeats: int, dim: Optional[int] = None
) -> Tensor

Repeat elements of a tensor.

t = Tensor([1, 2, 3])
print(t.repeat_interleave(2).numpy())
[1 1 2 2 3 3]
Source code in tinygrad/tensor.py
def repeat_interleave(self, repeats:int, dim:Optional[int]=None) -> Tensor:
  """
  Repeat elements of a tensor.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1, 2, 3])
  print(t.repeat_interleave(2).numpy())
  ```
  """
  x, dim = (self.flatten(), 0) if dim is None else (self, self._resolve_dim(dim))
  shp = x.shape
  return x.reshape(*shp[:dim+1], 1, *shp[dim+1:]).expand(*shp[:dim+1], repeats, *shp[dim+1:]).reshape(*shp[:dim], shp[dim]*repeats, *shp[dim+1:])
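
With dim set, elements are repeated along that axis instead of over the flattened tensor; a minimal sketch:

```python
from tinygrad import Tensor

t = Tensor([[1, 2], [3, 4]])
print(t.repeat_interleave(2, dim=0).numpy())  # rows doubled: [[1 2] [1 2] [3 4] [3 4]]
```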

split

split(
    sizes: Union[int, List[int]], dim: int = 0
) -> Tuple[Tensor, ...]

Splits the tensor into chunks along the dimension specified by dim. If sizes is an integer, it splits into equally sized chunks if possible, otherwise the last chunk will be smaller. If sizes is a list, it splits into len(sizes) chunks whose sizes along dim are given by sizes.

t = Tensor.arange(10).reshape(5, 2)
print(t.numpy())
[[0 1]
 [2 3]
 [4 5]
 [6 7]
 [8 9]]
split = t.split(2)
print("\n".join([repr(x.numpy()) for x in split]))
array([[0, 1],
       [2, 3]], dtype=int32)
array([[4, 5],
       [6, 7]], dtype=int32)
array([[8, 9]], dtype=int32)
split = t.split([1, 4])
print("\n".join([repr(x.numpy()) for x in split]))
array([[0, 1]], dtype=int32)
array([[2, 3],
       [4, 5],
       [6, 7],
       [8, 9]], dtype=int32)

Source code in tinygrad/tensor.py
def split(self, sizes:Union[int, List[int]], dim:int=0) -> Tuple[Tensor, ...]:
  """
  Splits the tensor into chunks along the dimension specified by `dim`.
  If `sizes` is an integer, it splits into equally sized chunks if possible, otherwise the last chunk will be smaller.
  If `sizes` is a list, it splits into `len(sizes)` chunks whose sizes along `dim` are given by `sizes`.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(10).reshape(5, 2)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  split = t.split(2)
  print("\\n".join([repr(x.numpy()) for x in split]))
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  split = t.split([1, 4])
  print("\\n".join([repr(x.numpy()) for x in split]))
  ```
  """
  assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
  dim = self._resolve_dim(dim)
  if isinstance(sizes, int): sizes = [min(sizes, self.shape[dim]-i) for i in range(0, max(1, self.shape[dim]), max(1, sizes))]
  assert sum(sizes) == self.shape[dim], f"expect sizes to sum exactly to {self.shape[dim]}, but got {sum(sizes)}"
  return tuple(self[sl] for sl in [tuple([slice(None)]*dim + [slice(sum(sizes[:i]), sum(sizes[:i + 1]))]) for i in range(len(sizes))])

chunk

chunk(chunks: int, dim: int = 0) -> List[Tensor]

Splits the tensor into the specified number of chunks along the dimension dim. If the tensor size along dim is not divisible by chunks, all returned chunks will be the same size except the last one. The function may return fewer than the specified number of chunks.

chunked = Tensor.arange(11).chunk(6)
print("\n".join([repr(x.numpy()) for x in chunked]))
array([0, 1], dtype=int32)
array([2, 3], dtype=int32)
array([4, 5], dtype=int32)
array([6, 7], dtype=int32)
array([8, 9], dtype=int32)
array([10], dtype=int32)
chunked = Tensor.arange(12).chunk(6)
print("\n".join([repr(x.numpy()) for x in chunked]))
array([0, 1], dtype=int32)
array([2, 3], dtype=int32)
array([4, 5], dtype=int32)
array([6, 7], dtype=int32)
array([8, 9], dtype=int32)
array([10, 11], dtype=int32)
chunked = Tensor.arange(13).chunk(6)
print("\n".join([repr(x.numpy()) for x in chunked]))
array([0, 1, 2], dtype=int32)
array([3, 4, 5], dtype=int32)
array([6, 7, 8], dtype=int32)
array([ 9, 10, 11], dtype=int32)
array([12], dtype=int32)

Source code in tinygrad/tensor.py
def chunk(self, chunks:int, dim:int=0) -> List[Tensor]:
  """
  Splits the tensor into `chunks` number of chunks along the dimension `dim`.
  If the tensor size along `dim` is not divisible by `chunks`, all returned chunks will be the same size except the last one.
  The function may return fewer than the specified number of chunks.

  ```python exec="true" source="above" session="tensor" result="python"
  chunked = Tensor.arange(11).chunk(6)
  print("\\n".join([repr(x.numpy()) for x in chunked]))
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  chunked = Tensor.arange(12).chunk(6)
  print("\\n".join([repr(x.numpy()) for x in chunked]))
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  chunked = Tensor.arange(13).chunk(6)
  print("\\n".join([repr(x.numpy()) for x in chunked]))
  ```
  """
  assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
  assert chunks > 0, f"expect chunks to be greater than 0, got: {chunks}"
  dim = self._resolve_dim(dim)
  return list(self.split(ceildiv(self.shape[dim], chunks) if self.shape[dim] else [0]*chunks, dim=dim))

meshgrid

meshgrid(
    *args: Tensor,
    indexing: Union[Literal["ij"], Literal["xy"]] = "ij"
) -> Tuple[Tensor, ...]

Generates coordinate matrices from coordinate vectors. Input tensors can be scalars or 1D tensors.

indexing determines how the output grids are aligned. ij indexing follows matrix-style indexing and xy indexing follows Cartesian-style indexing.

x, y = Tensor([1, 2, 3]), Tensor([4, 5, 6])
grid_x, grid_y = x.meshgrid(y)
print(grid_x.numpy())
print(grid_y.numpy())
[[1 1 1]
 [2 2 2]
 [3 3 3]]
[[4 5 6]
 [4 5 6]
 [4 5 6]]
grid_x, grid_y = x.meshgrid(y, indexing="xy")
print(grid_x.numpy())
print(grid_y.numpy())
[[1 2 3]
 [1 2 3]
 [1 2 3]]
[[4 4 4]
 [5 5 5]
 [6 6 6]]

Source code in tinygrad/tensor.py
def meshgrid(self:Tensor, *args:Tensor, indexing:Union[Literal["ij"], Literal["xy"]]="ij") -> Tuple[Tensor, ...]:
  """
  Generates coordinate matrices from coordinate vectors.
  Input tensors can be scalars or 1D tensors.

  `indexing` determines how the output grids are aligned.
  `ij` indexing follows matrix-style indexing and `xy` indexing follows Cartesian-style indexing.

  ```python exec="true" source="above" session="tensor" result="python"
  x, y = Tensor([1, 2, 3]), Tensor([4, 5, 6])
  grid_x, grid_y = x.meshgrid(y)
  print(grid_x.numpy())
  print(grid_y.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  grid_x, grid_y = x.meshgrid(y, indexing="xy")
  print(grid_x.numpy())
  print(grid_y.numpy())
  ```
  """
  if indexing not in ("ij", "xy"): raise RuntimeError(f'indexing must be in ("ij", "xy"), got {indexing}')
  if len(tensors:=(self, *args)) == 1: return tensors
  basis = tuple(range(len(tensors))) if indexing == "ij" else (1, 0) + tuple(range(2, len(tensors)))
  tensors = tuple(t.reshape((-1,) + (1,)*(len(args) - i)) for i,t in zip(basis, tensors))
  output_shape = _broadcast_shape(*(t.shape for t in tensors))
  return tuple(t._broadcast_to(output_shape) for t in tensors)
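
For two 1-D inputs, the "xy" output grids are the transposes of the "ij" ones; a quick sketch (the .T alias and .all() reduction are assumed from the broader Tensor API):

```python
from tinygrad import Tensor

x, y = Tensor([1, 2, 3]), Tensor([4, 5])
gx_ij, gy_ij = x.meshgrid(y)                 # shapes (3, 2)
gx_xy, gy_xy = x.meshgrid(y, indexing="xy")  # shapes (2, 3)
print((gx_xy == gx_ij.T).all().numpy())      # True
```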

squeeze

squeeze(dim: Optional[int] = None) -> Tensor

Returns a tensor with the specified dimension of size 1 removed. If dim is not specified, all dimensions of size 1 are removed.

t = Tensor.zeros(2, 1, 2, 1, 2)
print(t.squeeze().shape)
(2, 2, 2)
print(t.squeeze(0).shape)
(2, 1, 2, 1, 2)
print(t.squeeze(1).shape)
(2, 2, 1, 2)

Source code in tinygrad/tensor.py
def squeeze(self, dim:Optional[int]=None) -> Tensor:
  """
  Returns a tensor with the specified dimension of size 1 removed.
  If `dim` is not specified, all dimensions of size 1 are removed.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.zeros(2, 1, 2, 1, 2)
  print(t.squeeze().shape)
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.squeeze(0).shape)
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.squeeze(1).shape)
  ```
  """
  if dim is None: return self.reshape(tuple(dim for dim in self.shape if dim != 1))
  dim = self._resolve_dim(dim)
  return self if not self.ndim or self.shape[dim] != 1 else self.reshape(self.shape[:dim] + self.shape[dim+1:])

unsqueeze

unsqueeze(dim: int) -> Tensor

Returns a tensor with a new dimension of size 1 inserted at the specified dim.

t = Tensor([1, 2, 3, 4])
print(t.unsqueeze(0).numpy())
[[1 2 3 4]]
print(t.unsqueeze(1).numpy())
[[1]
 [2]
 [3]
 [4]]

Source code in tinygrad/tensor.py
def unsqueeze(self, dim:int) -> Tensor:
  """
  Returns a tensor with a new dimension of size 1 inserted at the specified `dim`.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1, 2, 3, 4])
  print(t.unsqueeze(0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.unsqueeze(1).numpy())
  ```
  """
  dim = self._resolve_dim(dim, extra=True)
  return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])

T property

T: Tensor

.T is an alias for .transpose().
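
The property has no example above; a minimal sketch, noting that the transpose() defaults swap only the first two dimensions (Tensor.rand is assumed from the broader API):

```python
from tinygrad import Tensor

t = Tensor.arange(6).reshape(2, 3)
print(t.T.shape)  # (3, 2)
x = Tensor.rand(2, 3, 4)
print(x.T.shape)  # (3, 2, 4): only the first two dims are swapped
```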

transpose

transpose(dim0=1, dim1=0) -> Tensor

Returns a tensor that is a transposed version of the original tensor. The given dimensions dim0 and dim1 are swapped.

t = Tensor.arange(6).reshape(2, 3)
print(t.numpy())
[[0 1 2]
 [3 4 5]]
print(t.transpose(0, 1).numpy())
[[0 3]
 [1 4]
 [2 5]]

Source code in tinygrad/tensor.py
def transpose(self, dim0=1, dim1=0) -> Tensor:
  """
  Returns a tensor that is a transposed version of the original tensor.
  The given dimensions `dim0` and `dim1` are swapped.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(6).reshape(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.transpose(0, 1).numpy())
  ```
  """
  order = list(range(self.ndim))
  order[dim0], order[dim1] = order[dim1], order[dim0]
  return self.permute(order)

flatten

flatten(start_dim=0, end_dim=-1)

Flattens the tensor by reshaping it into a one-dimensional tensor. If start_dim or end_dim are passed, only dimensions starting with start_dim and ending with end_dim are flattened.

t = Tensor.arange(8).reshape(2, 2, 2)
print(t.flatten().numpy())
[0 1 2 3 4 5 6 7]
print(t.flatten(start_dim=1).numpy())
[[0 1 2 3]
 [4 5 6 7]]

Source code in tinygrad/tensor.py
def flatten(self, start_dim=0, end_dim=-1):
  """
  Flattens the tensor by reshaping it into a one-dimensional tensor.
  If `start_dim` or `end_dim` are passed, only dimensions starting with `start_dim` and ending with `end_dim` are flattened.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(8).reshape(2, 2, 2)
  print(t.flatten().numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.flatten(start_dim=1).numpy())
  ```
  """
  start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
  return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:])

unflatten

unflatten(dim: int, sizes: Tuple[int, ...])

Unflattens dimension dim of the tensor into multiple dimensions specified by sizes. Tensor.flatten() is the inverse of this function.

print(Tensor.ones(3, 4, 1).unflatten(1, (2, 2)).shape)
(3, 2, 2, 1)
print(Tensor.ones(3, 4, 1).unflatten(1, (-1, 2)).shape)
(3, 2, 2, 1)
print(Tensor.ones(5, 12, 3).unflatten(-2, (2, 2, 3, 1, 1)).shape)
(5, 2, 2, 3, 1, 1, 3)

Source code in tinygrad/tensor.py
def unflatten(self, dim:int, sizes:Tuple[int,...]):
  """
  Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.ones(3, 4, 1).unflatten(1, (2, 2)).shape)
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.ones(3, 4, 1).unflatten(1, (-1, 2)).shape)
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.ones(5, 12, 3).unflatten(-2, (2, 2, 3, 1, 1)).shape)
  ```
  """
  dim = self._resolve_dim(dim)
  return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:])
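
As a closing sketch, flatten and unflatten round-trip as the docstring states (the .all() reduction is assumed from the broader Tensor API):

```python
from tinygrad import Tensor

t = Tensor.arange(24).reshape(2, 12)
u = t.unflatten(1, (3, 4))                # (2, 3, 4)
print((u.flatten(1) == t).all().numpy())  # flatten undoes unflatten: True
```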