Movement

Movement (low level)

view

view(shape, *args) -> Self

.view is an alias for .reshape.

Source code in tinygrad/mixin/movement.py
def view(self, shape, *args) -> Self:
  """`.view` is an alias for `.reshape`."""
  return self.reshape(shape, *args)

reshape

reshape(shape, *args) -> Self

Returns a tensor with the same data as the original tensor but with a different shape. shape can be passed as a tuple or as separate arguments.

t = Tensor.arange(6)
print(t.reshape(2, 3).numpy())
[[0 1 2]
 [3 4 5]]
Source code in tinygrad/mixin/movement.py
def reshape(self, shape, *args) -> Self:
  """
  Returns a tensor with the same data as the original tensor but with a different shape.
  `shape` can be passed as a tuple or as separate arguments.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(6)
  print(t.reshape(2, 3).numpy())
  ```
  """
  # resolve None and args
  new_shape = tuple([s if s is not None else self.shape[i] for i, s in enumerate(argfix(shape, *args))])
  # resolve -1
  if (c := new_shape.count(-1)) > 1:
    raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
  if c:
    new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
  if prod(self.shape) != prod(new_shape):
    raise ValueError(f"size mismatch, can't reshape ({self.shape}) -> ({new_shape})")
  ret = self._mop(Ops.RESHAPE, arg=new_shape)
  return self if ret.shape == self.shape else ret
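
The source above also handles two conveniences the example does not show: `-1` infers one dimension from the total element count, and `None` keeps the original size at that position. A minimal sketch of both; the shapes in the comments follow from the resolution rules in the code:

```python
t = Tensor.arange(6)
# -1 is inferred from the element count: 6 // 3 = 2
print(t.reshape(3, -1).shape)      # (3, 2)
# None keeps the original size at that index
t2 = t.reshape(2, 3)
print(t2.reshape(None, -1).shape)  # (2, 3)
```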

expand

expand(shape, *args) -> Self

Returns a tensor that is expanded to the shape that is specified. Expand can also increase the number of dimensions that a tensor has.

Passing a -1 or None to a dimension means that its size will not be changed.

t = Tensor([1, 2, 3])
print(t.expand(4, -1).numpy())
[[1 2 3]
 [1 2 3]
 [1 2 3]
 [1 2 3]]
Source code in tinygrad/mixin/movement.py
def expand(self, shape, *args) -> Self:
  """
  Returns a tensor that is expanded to the shape that is specified.
  Expand can also increase the number of dimensions that a tensor has.

  Passing a `-1` or `None` to a dimension means that its size will not be changed.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1, 2, 3])
  print(t.expand(4, -1).numpy())
  ```
  """
  new_shape = tuple(from_ if to == -1 or to is None else to for from_, to in zip(*(_align_left(self.shape, argfix(shape, *args)))))
  return self._broadcast_to(new_shape)
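
Since the target shape is right-aligned against the current one (via `_align_left` above), `expand` can prepend new leading dimensions as well as broadcast existing size-1 axes. A small sketch of that case:

```python
t = Tensor([1, 2, 3])           # shape (3,)
# right-aligned to (1, 1, 3), then broadcast up
print(t.expand(2, 4, 3).shape)  # (2, 4, 3)
```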

permute

permute(order, *args) -> Self

Returns a tensor that is a permutation of the original tensor. The new tensor has the same data as the original tensor but with the dimensions permuted according to the order specified. order can be passed as a tuple or as separate arguments.

t = Tensor.empty(2, 3, 5)
print(t.shape)
(2, 3, 5)
print(t.permute(2, 0, 1).shape)
(5, 2, 3)

Source code in tinygrad/mixin/movement.py
def permute(self, order, *args) -> Self:
  """
  Returns a tensor that is a permutation of the original tensor.
  The new tensor has the same data as the original tensor but with the dimensions permuted according to the order specified.
  `order` can be passed as a tuple or as separate arguments.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.empty(2, 3, 5)
  print(t.shape)
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.permute(2, 0, 1).shape)
  ```
  """
  order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args))
  if sorted(order_arg) != list(range(self.ndim)):
    raise RuntimeError(f"order is not a valid permutation, getting {order_arg}")
  return self._mop(Ops.PERMUTE, arg=order_arg) if order_arg != tuple(range(self.ndim)) else self

flip

flip(axis, *args) -> Self

Returns a tensor that reverses the order of the original tensor along the given axis. axis can be passed as a tuple or as separate arguments.

t = Tensor.arange(6).reshape(2, 3)
print(t.numpy())
[[0 1 2]
 [3 4 5]]
print(t.flip(0).numpy())
[[3 4 5]
 [0 1 2]]
print(t.flip((0, 1)).numpy())
[[5 4 3]
 [2 1 0]]

Source code in tinygrad/mixin/movement.py
def flip(self, axis, *args) -> Self:
  """
  Returns a tensor that reverses the order of the original tensor along the given `axis`.
  `axis` can be passed as a tuple or as separate arguments.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(6).reshape(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.flip(0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.flip((0, 1)).numpy())
  ```
  """
  axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args))
  assert all(not isinstance(x, bool) and x >= 0 and x < self.ndim for x in axis_arg), f"flip args must be axis ints {axis_arg}"
  if len(axis_arg) != len(dedup(axis_arg)):
    raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
  flip_arg = tuple([i in axis_arg for i in range(len(self.shape))])
  return self._mop(Ops.FLIP, arg=flip_arg) if any(flip_arg) else self

shrink

shrink(arg: tuple[tuple[sint, sint] | None, ...]) -> Self

Returns a tensor that shrinks each axis based on the input arg. arg must have the same length as self.ndim. For each axis, it can be None, which means no shrink, or a tuple (start, end) that works the same as a Python slice.

t = Tensor.arange(9).reshape(3, 3)
print(t.numpy())
[[0 1 2]
 [3 4 5]
 [6 7 8]]
print(t.shrink(((None, (1, 3)))).numpy())
[[1 2]
 [4 5]
 [7 8]]
print(t.shrink((((0, 2), (0, 2)))).numpy())
[[0 1]
 [3 4]]

Source code in tinygrad/mixin/movement.py
def shrink(self, arg: tuple[tuple[sint, sint] | None, ...]) -> Self:
  """
  Returns a tensor that shrinks each axis based on the input `arg`.
  `arg` must have the same length as `self.ndim`.
  For each axis, it can be `None`, which means no shrink, or a tuple `(start, end)` that works the same as Python slice.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(9).reshape(3, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.shrink(((None, (1, 3)))).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.shrink((((0, 2), (0, 2)))).numpy())
  ```
  """
  if self.ndim != len(arg):
    raise ValueError(f"{self.ndim=} != {len(arg)=}")
  ret = self._mop(Ops.SHRINK, arg=[x if x is not None else (0, s) for x, s in zip(arg, self.shape)])
  return self if ret.shape == self.shape else ret

pad

pad(
    padding: (
        Sequence[sint] | Sequence[tuple[sint, sint] | None]
    ),
    mode: str = "constant",
    value: float = 0.0,
) -> Tensor

Returns a tensor with padding applied based on the input padding.

padding supports two padding structures:

  1. Flat padding: (padding_left, padding_right, padding_top, padding_bottom, ...)
    • This structure matches PyTorch's pad.
    • padding length must be even.

  2. Group padding: (..., (padding_top, padding_bottom), (padding_left, padding_right))
    • This structure matches pad for JAX, NumPy, TensorFlow, and others.
    • For each axis, padding can be None, meaning no padding, or a tuple (start, end).
    • padding must have the same length as self.ndim.

Padding values can be negative, resulting in dimension shrinks that work similarly to Python negative slices. The padding mode is selected with mode, which supports constant, reflect, replicate, and circular.

t = Tensor.arange(9).reshape(1, 1, 3, 3)
print(t.numpy())
[[[[0 1 2]
   [3 4 5]
   [6 7 8]]]]
print(t.pad((1, 2, 0, -1)).numpy())
[[[[0 0 1 2 0 0]
   [0 3 4 5 0 0]]]]
print(t.pad(((None, None, (0, -1), (1, 2)))).numpy())
[[[[0 0 1 2 0 0]
   [0 3 4 5 0 0]]]]
print(t.pad((1, 2, 0, -1), value=-float('inf')).numpy())
[[[[-inf   0.   1.   2. -inf -inf]
   [-inf   3.   4.   5. -inf -inf]]]]

Source code in tinygrad/tensor.py
def pad(self, padding:Sequence[sint]|Sequence[tuple[sint, sint]|None], mode:str="constant", value:float=0.0) -> Tensor:
  """
  Returns a tensor with padding applied based on the input `padding`.

  `padding` supports two padding structures:

  1. Flat padding: `(padding_left, padding_right, padding_top, padding_bottom, ...)`
      - This structure matches PyTorch's pad.
      - `padding` length must be even.

  2. Group padding: `(..., (padding_top, padding_bottom), (padding_left, padding_right))`
      - This structure matches pad for JAX, NumPy, TensorFlow, and others.
      - For each axis, padding can be `None`, meaning no padding, or a tuple `(start, end)`.
      - `padding` must have the same length as `self.ndim`.

  Padding values can be negative, resulting in dimension shrinks that work similarly to Python negative slices.
  The padding mode is selected with `mode`, which supports `constant`, `reflect`, `replicate`, and `circular`.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(9).reshape(1, 1, 3, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.pad((1, 2, 0, -1)).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.pad(((None, None, (0, -1), (1, 2)))).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.pad((1, 2, 0, -1), value=-float('inf')).numpy())
  ```
  """
  if mode not in {"constant", "reflect", "replicate", "circular"}: raise NotImplementedError(f"{mode=} is not supported")
  # flat padding
  if all(isinstance(p, (int,UOp)) for p in padding):
    if len(padding)%2 != 0: raise ValueError("Flat padding must have even number of pads")
    pX = _flat_to_grouped(tuple(cast(Sequence[sint], padding)) + (0,0)*(self.ndim - len(padding)//2))
  # group padding
  else: pX = tuple((0,0) if p is None else p for p in cast(Sequence[tuple[sint, sint]|None], padding))
  if len(pX) != self.ndim: raise ValueError(f"padding length is improper, {padding=} {self.ndim=}")
  X, pads = self, tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX)
  if mode == "constant":
    def _constant(x:Tensor,px,v) -> Tensor:
      return x._apply_uop(UOp.pad, arg=px) if v == 0 else (x._apply_uop(UOp.pad, arg=px)+Tensor.ones_like(x)._apply_uop(UOp.pad, arg=px).where(0,v))
    return _constant(X, pX, value) if all(resolve(p >= 0) for p in flatten(pX)) else \
           _constant(X.shrink(tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, X.shape))), pads, value)
  assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
  if mode == "circular":
    if any(pB>sh or pA>sh for (pB,pA),sh in zip(pX, X.shape)): raise ValueError('Padding value causes wrapping around more than once.')
    if any(pB<0 or pA<0 for pB,pA in pX): raise NotImplementedError("Negative pads with circular pads is not supported")
    orig_shape, X = X.shape, X.repeat(tuple(1 + bool(pB) + bool(pA) for pB,pA in pads))
    return X.shrink(tuple((0 if pB == 0 else osh-pB, xsh if pA == 0 else xsh-osh+pA) for (pB,pA),osh,xsh in zip(pads, orig_shape, X.shape)))
  for d,(pB,pA) in enumerate(pads):
    if mode == "reflect":
      if pB >= (s:=X.shape[d]) or pA>=s: raise ValueError(f"Padding ({pB}, {pA}) should be less than the input size={s} for dim={d}.")
      slcB, slcA, = slice(pB,0,-1), slice(s-2 if s-2>=0 else None, s-2-pA if s-2-pA>=0 else None, -1)
      xB, xA = (X[[slc if i == d else slice(None) for i in range(X.ndim)]] if p > 0 else None for slc, p in ((slcB, pB), (slcA, pA)))
    if mode == "replicate":
      shrB, shrA, = tuple((0,1) if i==d else None for i in range(X.ndim)), tuple((X.shape[i]-1,X.shape[i]) if i==d else None for i in range(X.ndim))
      xB, xA = (X.shrink(shr).expand(tuple(p if i==d else None for i in range(X.ndim))) if p > 0 else None for shr, p in ((shrB, pB), (shrA, pA)))
    X = Tensor.cat(*(X_ for X_ in (xB, X, xA) if X_ is not None), dim=d)
  return X.shrink(tuple((-min(pB,0), min(pA+s,s)) for (pB,pA),s in zip(pX, X.shape)))
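
The non-constant modes are described above but not demonstrated. A quick sketch of their behavior; the outputs in the comments are the expected extensions, and the flat padding `(2, 2)` applies to the last axis:

```python
t = Tensor([[[1, 2, 3]]])                       # shape (1, 1, 3)
print(t.pad((2, 2), mode="reflect").numpy())    # [[[3 2 1 2 3 2 1]]]
print(t.pad((2, 2), mode="replicate").numpy())  # [[[1 1 1 2 3 3 3]]]
print(t.pad((2, 2), mode="circular").numpy())   # [[[2 3 1 2 3 1 2]]]
```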

Movement (high level)

__getitem__

__getitem__(indices) -> Tensor

Retrieves a sub-tensor using indexing.

Supported Index Types: int | slice | Tensor | None | list | tuple | Ellipsis

Examples:

t = Tensor.arange(12).reshape(3, 4)
print(t.numpy())
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]]

  • Int Indexing: Select an element or sub-tensor using integers for each dimension.

    print(t[1, 2].numpy())
    
    6
    

  • Slice Indexing: Select a range of elements using slice notation (start:end:stride).

    print(t[0:2, ::2].numpy())
    
    [[0 2]
     [4 6]]
    

  • Tensor Indexing: Use another tensor as indices for advanced indexing. Using tuple or list here also works.

    print(t[Tensor([2, 0, 1]), Tensor([1, 2, 3])].numpy())
    
    [9 2 7]
    

  • None Indexing: Add a new dimension to the tensor.

    print(t[:, None].shape)
    
    (3, 1, 4)
    

Note

Out-of-bounds indexing results in a value of 0.

t = Tensor([1, 2, 3])
print(t[Tensor([4, 3, 2])].numpy())
[0 0 3]

Source code in tinygrad/tensor.py
def __getitem__(self, indices) -> Tensor:
  """
  Retrieves a sub-tensor using indexing.

  Supported Index Types: `int | slice | Tensor | None | list | tuple | Ellipsis`

  Examples:
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(12).reshape(3, 4)
  print(t.numpy())
  ```

  - Int Indexing: Select an element or sub-tensor using integers for each dimension.
    ```python exec="true" source="above" session="tensor" result="python"
    print(t[1, 2].numpy())
    ```

  - Slice Indexing: Select a range of elements using slice notation (`start:end:stride`).
    ```python exec="true" source="above" session="tensor" result="python"
    print(t[0:2, ::2].numpy())
    ```

  - Tensor Indexing: Use another tensor as indices for advanced indexing. Using `tuple` or `list` here also works.
    ```python exec="true" source="above" session="tensor" result="python"
    print(t[Tensor([2, 0, 1]), Tensor([1, 2, 3])].numpy())
    ```

  - `None` Indexing: Add a new dimension to the tensor.
    ```python exec="true" source="above" session="tensor" result="python"
    print(t[:, None].shape)
    ```

  NOTE: Out-of-bounds indexing results in a value of `0`.
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1, 2, 3])
  print(t[Tensor([4, 3, 2])].numpy())
  ```
  """
  return self._getitem(indices)
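
`Ellipsis` is listed among the supported index types but not shown above. A brief sketch, assuming the usual NumPy semantics where `...` expands to as many full slices as needed:

```python
t = Tensor.arange(24).reshape(2, 3, 4)
print(t[..., 0].shape)  # (2, 3)
print(t[0, ...].shape)  # (3, 4)
```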

gather

gather(dim: int, index: Tensor) -> Tensor

Gathers values along an axis specified by dim.

t = Tensor([[1, 2], [3, 4]])
print(t.numpy())
[[1 2]
 [3 4]]
print(t.gather(1, Tensor([[0, 0], [1, 0]])).numpy())
[[1 1]
 [4 3]]

Source code in tinygrad/tensor.py
def gather(self:Tensor, dim:int, index:Tensor) -> Tensor:
  """
  Gathers values along an axis specified by `dim`.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([[1, 2], [3, 4]])
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.gather(1, Tensor([[0, 0], [1, 0]])).numpy())
  ```
  """
  assert index.ndim == self.ndim, f"self.ndim must equal index.ndim, {self.ndim=}, {index.ndim=}"
  dim = self._resolve_dim(dim)
  assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim"
  index = index.to(self.device)
  x = self.shrink(tuple((0, i) if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
  return (index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).where(x, 0)).sum(-1, dtype=self.dtype)
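
The example above follows the usual gather rule for a 2-D tensor: `out[i][j] == t[i][index[i][j]]` when `dim=1`, and `out[i][j] == t[index[i][j]][j]` when `dim=0`. A sketch of the `dim=0` case for contrast, with the expected output in the comment:

```python
t = Tensor([[1, 2], [3, 4]])
print(t.gather(0, Tensor([[1, 0], [0, 1]])).numpy())
# [[3 2]
#  [1 4]]
```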

cat

cat(*args: Tensor, dim: int = 0) -> Tensor

Concatenates self with the other Tensors in args along the axis specified by dim. All tensors must have the same shape except in the concatenating dimension.

t0, t1, t2 = Tensor([[1, 2]]), Tensor([[3, 4]]), Tensor([[5, 6]])
print(t0.cat(t1, t2, dim=0).numpy())
[[1 2]
 [3 4]
 [5 6]]
print(t0.cat(t1, t2, dim=1).numpy())
[[1 2 3 4 5 6]]

Source code in tinygrad/tensor.py
def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
  """
  Concatenates self with the other `Tensor`s in `args` along the axis specified by `dim`.
  All tensors must have the same shape except in the concatenating dimension.

  ```python exec="true" source="above" session="tensor" result="python"
  t0, t1, t2 = Tensor([[1, 2]]), Tensor([[3, 4]]), Tensor([[5, 6]])
  print(t0.cat(t1, t2, dim=0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t0.cat(t1, t2, dim=1).numpy())
  ```
  """
  dim = self._resolve_dim(dim)
  for arg in args: assert arg.ndim==self.ndim and all(ti==ai for i,(ti,ai) in enumerate(zip(self.shape, arg.shape)) if i!=dim)
  tensors = [self, *args]
  dim_cumsum = list(itertools.accumulate([t.shape[dim] for t in tensors], initial=0))
  for i,t in enumerate(tensors): tensors[i] = t.pad([(dim_cumsum[i], dim_cumsum[-1]-dim_cumsum[i+1]) if j==dim else None for j in range(t.ndim)])
  return functools.reduce(Tensor.add, tensors)

stack

stack(*args: Tensor, dim: int = 0) -> Tensor

Concatenates self with the other Tensors in args along a new dimension specified by dim.

t0, t1, t2 = Tensor([1, 2]), Tensor([3, 4]), Tensor([5, 6])
print(t0.stack(t1, t2, dim=0).numpy())
[[1 2]
 [3 4]
 [5 6]]
print(t0.stack(t1, t2, dim=1).numpy())
[[1 3 5]
 [2 4 6]]

Source code in tinygrad/tensor.py
def stack(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
  """
  Concatenates self with the other `Tensor`s in `args` along a new dimension specified by `dim`.

  ```python exec="true" source="above" session="tensor" result="python"
  t0, t1, t2 = Tensor([1, 2]), Tensor([3, 4]), Tensor([5, 6])
  print(t0.stack(t1, t2, dim=0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t0.stack(t1, t2, dim=1).numpy())
  ```
  """
  # checks for shapes and number of dimensions delegated to cat
  return Tensor.cat(*[t.unsqueeze(dim) for t in argfix(self, *args)], dim=dim)

repeat

repeat(repeats, *args) -> Self

Repeats the tensor along each dimension the number of times specified by repeats. repeats can be passed as a tuple or as separate arguments.

t = Tensor([1, 2, 3])
print(t.repeat(4, 2).numpy())
[[1 2 3 1 2 3]
 [1 2 3 1 2 3]
 [1 2 3 1 2 3]
 [1 2 3 1 2 3]]
print(t.repeat(4, 2, 1).shape)
(4, 2, 3)

Source code in tinygrad/mixin/movement.py
def repeat(self, repeats, *args) -> Self:
  """
  Repeats the tensor along each dimension the number of times specified by `repeats`.
  `repeats` can be passed as a tuple or as separate arguments.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1, 2, 3])
  print(t.repeat(4, 2).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.repeat(4, 2, 1).shape)
  ```
  """
  repeats = argfix(repeats, *args)
  base_shape = _align_left(self.shape, repeats)[0]
  unsqueezed_shape = flatten([[s] if r == 1 else [1, s] for r, s in zip(repeats, base_shape)])
  expanded_shape = flatten([[s] if r == 1 else [r, s] for r, s in zip(repeats, base_shape)])
  final_shape = [r * s for r, s in zip(repeats, base_shape)]
  return self.reshape(unsqueezed_shape).expand(expanded_shape).reshape(final_shape)

repeat_interleave

repeat_interleave(
    repeats: int, dim: int | None = None
) -> Self

Repeats elements of a tensor.

t = Tensor([1, 2, 3])
print(t.repeat_interleave(2).numpy())
[1 1 2 2 3 3]
Source code in tinygrad/mixin/movement.py
def repeat_interleave(self, repeats: int, dim: int | None = None) -> Self:
  """
  Repeats elements of a tensor.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1, 2, 3])
  print(t.repeat_interleave(2).numpy())
  ```
  """
  x, dim = (self.flatten(), 0) if dim is None else (self, self._resolve_dim(dim))
  shp = x.shape
  x = x.reshape(*shp[: dim + 1], 1, *shp[dim + 1 :])
  x = x.expand(*shp[: dim + 1], repeats, *shp[dim + 1 :])
  x = x.reshape(*shp[:dim], shp[dim] * repeats, *shp[dim + 1 :])
  return x
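
The docstring only shows the flattening `dim=None` case. Per the implementation above, passing `dim` interleaves along just that axis; a sketch of both axes of a 2-D tensor, with expected outputs in the comments:

```python
t = Tensor([[1, 2], [3, 4]])
print(t.repeat_interleave(2, dim=0).numpy())  # [[1 2] [1 2] [3 4] [3 4]]
print(t.repeat_interleave(2, dim=1).numpy())  # [[1 1 2 2] [3 3 4 4]]
```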

split

split(
    sizes: int | Sequence[int], dim: int = 0
) -> tuple[Tensor, ...]

Splits the tensor into chunks along the dimension specified by dim. If sizes is an integer, it splits into equally sized chunks if possible, otherwise the last chunk will be smaller. If sizes is a list, it splits into len(sizes) chunks with size in dim according to sizes.

t = Tensor.arange(10).reshape(5, 2)
print(t.numpy())
[[0 1]
 [2 3]
 [4 5]
 [6 7]
 [8 9]]
split = t.split(2)
print("\n".join([repr(x.numpy()) for x in split]))
array([[0, 1],
       [2, 3]], dtype=int32)
array([[4, 5],
       [6, 7]], dtype=int32)
array([[8, 9]], dtype=int32)
split = t.split([1, 4])
print("\n".join([repr(x.numpy()) for x in split]))
array([[0, 1]], dtype=int32)
array([[2, 3],
       [4, 5],
       [6, 7],
       [8, 9]], dtype=int32)

Source code in tinygrad/tensor.py
def split(self, sizes:int|Sequence[int], dim:int=0) -> tuple[Tensor, ...]:
  """
  Splits the tensor into chunks along the dimension specified by `dim`.
  If `sizes` is an integer, it splits into equally sized chunks if possible, otherwise the last chunk will be smaller.
  If `sizes` is a list, it splits into `len(sizes)` chunks with size in `dim` according to `sizes`.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(10).reshape(5, 2)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  split = t.split(2)
  print("\\n".join([repr(x.numpy()) for x in split]))
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  split = t.split([1, 4])
  print("\\n".join([repr(x.numpy()) for x in split]))
  ```
  """
  assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
  dim = self._resolve_dim(dim)
  if isinstance(sizes, int): sizes = [min(sizes, self.shape[dim]-i) for i in range(0, max(1, self.shape[dim]), max(1, sizes))]
  assert sum(sizes) == self.shape[dim], f"expect sizes to sum exactly to {self.shape[dim]}, but got {sum(sizes)}"
  return tuple(self[sl] for sl in [tuple([slice(None)]*dim + [slice(sum(sizes[:i]), sum(sizes[:i + 1]))]) for i in range(len(sizes))])

chunk

chunk(chunks: int, dim: int = 0) -> list[Tensor]

Splits the tensor into chunks number of chunks along the dimension dim. If the tensor size along dim is not divisible by chunks, all returned chunks will be the same size except the last one. The function may return fewer than the specified number of chunks.

chunked = Tensor.arange(11).chunk(6)
print("\n".join([repr(x.numpy()) for x in chunked]))
array([0, 1], dtype=int32)
array([2, 3], dtype=int32)
array([4, 5], dtype=int32)
array([6, 7], dtype=int32)
array([8, 9], dtype=int32)
array([10], dtype=int32)
chunked = Tensor.arange(12).chunk(6)
print("\n".join([repr(x.numpy()) for x in chunked]))
array([0, 1], dtype=int32)
array([2, 3], dtype=int32)
array([4, 5], dtype=int32)
array([6, 7], dtype=int32)
array([8, 9], dtype=int32)
array([10, 11], dtype=int32)
chunked = Tensor.arange(13).chunk(6)
print("\n".join([repr(x.numpy()) for x in chunked]))
array([0, 1, 2], dtype=int32)
array([3, 4, 5], dtype=int32)
array([6, 7, 8], dtype=int32)
array([ 9, 10, 11], dtype=int32)
array([12], dtype=int32)

Source code in tinygrad/tensor.py
def chunk(self, chunks:int, dim:int=0) -> list[Tensor]:
  """
  Splits the tensor into `chunks` number of chunks along the dimension `dim`.
  If the tensor size along `dim` is not divisible by `chunks`, all returned chunks will be the same size except the last one.
  The function may return fewer than the specified number of chunks.

  ```python exec="true" source="above" session="tensor" result="python"
  chunked = Tensor.arange(11).chunk(6)
  print("\\n".join([repr(x.numpy()) for x in chunked]))
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  chunked = Tensor.arange(12).chunk(6)
  print("\\n".join([repr(x.numpy()) for x in chunked]))
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  chunked = Tensor.arange(13).chunk(6)
  print("\\n".join([repr(x.numpy()) for x in chunked]))
  ```
  """
  assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
  assert chunks > 0, f"expect chunks to be greater than 0, got: {chunks}"
  dim = self._resolve_dim(dim)
  return list(self.split(ceildiv(self.shape[dim], chunks) if self.shape[dim] else [0]*chunks, dim=dim))

unfold

unfold(dim: int, size: sint, step: int) -> Tensor

Unfolds the tensor along dimension dim into overlapping windows.

Each window has length size, and successive windows begin step elements apart along self. Returns the input tensor with dimension dim replaced by dims (n_windows, size) where n_windows = (self.shape[dim] - size) // step + 1.

unfolded = Tensor.arange(8).unfold(0,2,2)
print("\n".join([repr(x.numpy()) for x in unfolded]))
array([0, 1], dtype=int32)
array([2, 3], dtype=int32)
array([4, 5], dtype=int32)
array([6, 7], dtype=int32)
unfolded = Tensor.arange(27).reshape(3,3,3).unfold(-1,2,3)
print("\n".join([repr(x.numpy()) for x in unfolded]))
array([[[0, 1]],

       [[3, 4]],

       [[6, 7]]], dtype=int32)
array([[[ 9, 10]],

       [[12, 13]],

       [[15, 16]]], dtype=int32)
array([[[18, 19]],

       [[21, 22]],

       [[24, 25]]], dtype=int32)

Source code in tinygrad/tensor.py
def unfold(self, dim:int, size:sint, step:int) -> Tensor:
  """
  Unfolds the tensor along dimension `dim` into overlapping windows.

  Each window has length `size`, and successive windows begin `step` elements apart along `self`.
  Returns the input tensor with dimension `dim` replaced by dims `(n_windows, size)`
  where `n_windows = (self.shape[dim] - size) // step + 1`.

  ```python exec="true" source="above" session="tensor" result="python"
  unfolded = Tensor.arange(8).unfold(0,2,2)
  print("\\n".join([repr(x.numpy()) for x in unfolded]))
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  unfolded = Tensor.arange(27).reshape(3,3,3).unfold(-1,2,3)
  print("\\n".join([repr(x.numpy()) for x in unfolded]))
  ```
  """
  if size < 0: raise RuntimeError(f'size must be >= 0 but got {size=}')
  if step <= 0: raise RuntimeError(f'step must be > 0 but got {step=}')
  if size > self.shape[dim]: raise RuntimeError(f'maximum size for tensor at dimension {dim} is {self.shape[dim]} but size is {size}')
  dim = self._resolve_dim(dim)
  perm_to_last = tuple(i for i in range(self.ndim) if i != dim) + (dim,)
  return self.permute(perm_to_last)._pool((size,), step).permute(argsort(perm_to_last) + (self.ndim,))
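
Both examples above use non-overlapping windows; with `step < size` the windows overlap, which is what the `n_windows` formula accounts for. A minimal sketch giving `(5 - 3) // 1 + 1 = 3` windows:

```python
t = Tensor.arange(5)
print(t.unfold(0, 3, 1).numpy())
# [[0 1 2]
#  [1 2 3]
#  [2 3 4]]
```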

meshgrid

meshgrid(
    *args: Tensor, indexing: Literal["ij", "xy"] = "ij"
) -> tuple[Tensor, ...]

Generates coordinate matrices from coordinate vectors. Input tensors can be scalars or 1D tensors.

indexing determines how the output grids are aligned. ij indexing follows matrix-style indexing and xy indexing follows Cartesian-style indexing.

x, y = Tensor([1, 2, 3]), Tensor([4, 5, 6])
grid_x, grid_y = x.meshgrid(y)
print(grid_x.numpy())
print(grid_y.numpy())
[[1 1 1]
 [2 2 2]
 [3 3 3]]
[[4 5 6]
 [4 5 6]
 [4 5 6]]
grid_x, grid_y = x.meshgrid(y, indexing="xy")
print(grid_x.numpy())
print(grid_y.numpy())
[[1 2 3]
 [1 2 3]
 [1 2 3]]
[[4 4 4]
 [5 5 5]
 [6 6 6]]

Source code in tinygrad/tensor.py
def meshgrid(self:Tensor, *args:Tensor, indexing:Literal["ij", "xy"]="ij") -> tuple[Tensor, ...]:
  """
  Generates coordinate matrices from coordinate vectors.
  Input tensors can be scalars or 1D tensors.

  `indexing` determines how the output grids are aligned.
  `ij` indexing follows matrix-style indexing and `xy` indexing follows Cartesian-style indexing.

  ```python exec="true" source="above" session="tensor" result="python"
  x, y = Tensor([1, 2, 3]), Tensor([4, 5, 6])
  grid_x, grid_y = x.meshgrid(y)
  print(grid_x.numpy())
  print(grid_y.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  grid_x, grid_y = x.meshgrid(y, indexing="xy")
  print(grid_x.numpy())
  print(grid_y.numpy())
  ```
  """
  if indexing not in ("ij", "xy"): raise RuntimeError(f'indexing must be in ("ij", "xy"), got {indexing}')
  if len(tensors:=(self, *args)) == 1: return tensors
  basis = tuple(range(len(tensors))) if indexing == "ij" else (1, 0) + tuple(range(2, len(tensors)))
  tensors = tuple(t.reshape((-1,) + (1,)*(len(args) - i)) for i,t in zip(basis, tensors))
  output_shape = _broadcast_shape(*(t.shape for t in tensors))
  return tuple(t._broadcast_to(output_shape) for t in tensors)

squeeze

squeeze(dim: int | None = None) -> Self

Returns a tensor with the specified dimensions of size 1 removed from the input. If dim is not specified, all dimensions of size 1 are removed.

t = Tensor.zeros(2, 1, 2, 1, 2)
print(t.squeeze().shape)
(2, 2, 2)
print(t.squeeze(0).shape)
(2, 1, 2, 1, 2)
print(t.squeeze(1).shape)
(2, 2, 1, 2)

Source code in tinygrad/mixin/movement.py
def squeeze(self, dim: int | None = None) -> Self:
  """
  Returns a tensor with the specified dimensions of size 1 removed from the input.
  If `dim` is not specified, all dimensions with size 1 are removed.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.zeros(2, 1, 2, 1, 2)
  print(t.squeeze().shape)
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.squeeze(0).shape)
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.squeeze(1).shape)
  ```
  """
  if dim is None:
    return self.reshape(tuple(dim for dim in self.shape if dim != 1))
  dim = self._resolve_dim(dim)
  return self if not self.ndim or self.shape[dim] != 1 else self.reshape(self.shape[:dim] + self.shape[dim + 1 :])

unsqueeze

unsqueeze(dim: int) -> Self

Returns a tensor with a new dimension of size 1 inserted at the specified dim.

t = Tensor([1, 2, 3, 4])
print(t.unsqueeze(0).numpy())
[[1 2 3 4]]
print(t.unsqueeze(1).numpy())
[[1]
 [2]
 [3]
 [4]]

Source code in tinygrad/mixin/movement.py
def unsqueeze(self, dim: int) -> Self:
  """
  Returns a tensor with a new dimension of size 1 inserted at the specified `dim`.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1, 2, 3, 4])
  print(t.unsqueeze(0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.unsqueeze(1).numpy())
  ```
  """
  dim = self._resolve_dim(dim, extra=True)
  return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])

T property

T: Self

.T is an alias for .transpose().
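
As the transpose source below shows, the default arguments are `dim0=1, dim1=0`, so `.T` swaps only the first two dimensions even on higher-rank tensors (unlike NumPy's `.T`, which reverses all axes). A short sketch:

```python
t = Tensor.empty(2, 3, 5)
print(t.T.shape)  # (3, 2, 5)
```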

transpose

transpose(dim0=1, dim1=0) -> Self

Returns a tensor that is a transposed version of the original tensor. The given dimensions dim0 and dim1 are swapped.

t = Tensor.arange(6).reshape(2, 3)
print(t.numpy())
[[0 1 2]
 [3 4 5]]
print(t.transpose(0, 1).numpy())
[[0 3]
 [1 4]
 [2 5]]

Source code in tinygrad/mixin/movement.py
def transpose(self, dim0=1, dim1=0) -> Self:
  """
  Returns a tensor that is a transposed version of the original tensor.
  The given dimensions `dim0` and `dim1` are swapped.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(6).reshape(2, 3)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.transpose(0, 1).numpy())
  ```
  """
  order = list(range(self.ndim))
  order[dim0], order[dim1] = order[dim1], order[dim0]
  return self.permute(order)

flatten

flatten(start_dim=0, end_dim=-1) -> Self

Flattens the tensor by reshaping it into a one-dimensional tensor. If start_dim or end_dim are passed, only dimensions starting with start_dim and ending with end_dim are flattened.

t = Tensor.arange(8).reshape(2, 2, 2)
print(t.flatten().numpy())
[0 1 2 3 4 5 6 7]
print(t.flatten(start_dim=1).numpy())
[[0 1 2 3]
 [4 5 6 7]]

Source code in tinygrad/mixin/movement.py
def flatten(self, start_dim=0, end_dim=-1) -> Self:
  """
  Flattens the tensor by reshaping it into a one-dimensional tensor.
  If `start_dim` or `end_dim` are passed, only dimensions starting with `start_dim` and ending with `end_dim` are flattened.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(8).reshape(2, 2, 2)
  print(t.flatten().numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.flatten(start_dim=1).numpy())
  ```
  """
  start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
  return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim : end_dim + 1]),) + self.shape[end_dim + 1 :])

unflatten

unflatten(dim: int, sizes: tuple[int, ...]) -> Self

Unflattens dimension dim of the tensor into multiple dimensions specified by sizes. Tensor.flatten() is the inverse of this function.

print(Tensor.ones(3, 4, 1).unflatten(1, (2, 2)).shape)
(3, 2, 2, 1)
print(Tensor.ones(3, 4, 1).unflatten(1, (-1, 2)).shape)
(3, 2, 2, 1)
print(Tensor.ones(5, 12, 3).unflatten(-2, (2, 2, 3, 1, 1)).shape)
(5, 2, 2, 3, 1, 1, 3)

Source code in tinygrad/mixin/movement.py
def unflatten(self, dim: int, sizes: tuple[int, ...]) -> Self:
  """
  Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.ones(3, 4, 1).unflatten(1, (2, 2)).shape)
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.ones(3, 4, 1).unflatten(1, (-1, 2)).shape)
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor.ones(5, 12, 3).unflatten(-2, (2, 2, 3, 1, 1)).shape)
  ```
  """
  dim = self._resolve_dim(dim)
  return self.reshape(self.shape[:dim] + sizes + self.shape[dim + 1 :])

diag

diag() -> Tensor

Returns a 2-D square tensor with the elements of input as the main diagonal.

print(Tensor([1, 2, 3]).diag().numpy())
[[1 0 0]
 [0 2 0]
 [0 0 3]]
Source code in tinygrad/tensor.py
def diag(self) -> Tensor:
  """
  Returns a 2-D square tensor with the elements of input as the main diagonal.

  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([1, 2, 3]).diag().numpy())
  ```
  """
  if self.ndim != 1: raise ValueError(f"expect input to be 1-D, getting {self.ndim}-D")
  return self.unsqueeze(-1).pad((None,(0,n:=self.shape[0]))).flatten().shrink(((0,n*n),)).reshape(n,n)

roll

roll(
    shifts: int | tuple[int, ...],
    dims: int | tuple[int, ...] | None = None,
) -> Tensor

Rolls the tensor along specified dimension(s). The rolling operation is circular, meaning that elements that go beyond the edge are wrapped around to the beginning of the dimension.

t = Tensor.arange(4)
print(t.roll(shifts=1, dims=0).numpy())
[3 0 1 2]
print(t.roll(shifts=-1, dims=0).numpy())
[1 2 3 0]

Source code in tinygrad/tensor.py
def roll(self, shifts:int|tuple[int, ...], dims:int|tuple[int, ...]|None=None) -> Tensor:
  """
  Rolls the tensor along specified dimension(s).
  The rolling operation is circular, meaning that elements that go beyond the edge are wrapped around to the beginning of the dimension.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.arange(4)
  print(t.roll(shifts=1, dims=0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.roll(shifts=-1, dims=0).numpy())
  ```
  """
  if dims is None: return self.flatten().roll(shifts, 0).reshape(self.shape)
  dims, shifts, slices = tuple(self._resolve_dim(d) for d in make_tuple(dims, 1)), make_tuple(shifts, 1), [slice(None)] * self.ndim
  if len(dims) != len(shifts): raise RuntimeError(f"{len(dims)=} != {len(shifts)=}")
  for dim, shift in zip(dims, shifts): slices[dim] = slice(delta:=self.shape[dim]-shift%self.shape[dim], delta+self.shape[dim])
  return self.repeat(*tuple(2 if i in dims else 1 for i in range(self.ndim)))[slices]
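
`shifts` and `dims` can also be tuples to roll several axes at once; a sketch on a 2-D tensor, with the expected output in the comment:

```python
t = Tensor.arange(6).reshape(2, 3)
print(t.roll(shifts=(1, 1), dims=(0, 1)).numpy())
# [[5 3 4]
#  [2 0 1]]
```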

rearrange

rearrange(formula: str, **sizes) -> Self

Rearranges input according to formula.

See: https://einops.rocks/api/rearrange/

x = Tensor([[1, 2], [3, 4]])
print(Tensor.rearrange(x, "batch channel -> (batch channel)").numpy())
[1 2 3 4]
Source code in tinygrad/mixin/movement.py
def rearrange(self, formula: str, **sizes) -> Self:
  """
  Rearranges input according to `formula`.

  See: https://einops.rocks/api/rearrange/

  ```python exec="true" source="above" session="tensor" result="python"
  x = Tensor([[1, 2], [3, 4]])
  print(Tensor.rearrange(x, "batch channel -> (batch channel)").numpy())
  ```
  """

  def parse_formula(formula: str):
    tokens = f" {formula} ".replace("…", "...").replace("(", " ( ").replace(")", " ) ").replace(" ", "  ").replace(" 1 ", " ( ) ").split()
    lparens, rparens = map(lambda x: [i for i, ch in enumerate(tokens) if ch == x], ("(", ")"))
    pairs = list(zip(lparens, rparens))
    assert len(lparens) == len(rparens) and sorted(flatten(pairs)) == flatten(pairs), "bracket mismatch"
    return [name for name in tokens if name not in ("(", ")")], [(s - 2 * i, e - 1 - 2 * i) for i, (s, e) in enumerate(pairs)]

  assert formula.count("->") == 1, 'need exactly one "->" in formula'

  (lhs, unflatten_dims), (rhs, flatten_dims) = map(parse_formula, formula.split("->"))

  for name in sizes:
    assert name in lhs, f"axis {name} is not used in transform"
  assert sorted(lhs) == sorted(rhs) and len(lhs) == len(set(lhs)), f"name mismatch in {formula}"
  for name in flatten((lhs, rhs)):
    assert name == "..." or (name.isidentifier() and "_" not in (name[0], name[-1])), f"invalid axis name {name}"
  assert "..." not in flatten([lhs[s:e] for s, e in unflatten_dims]), f"cannot have collapsed ellipsis (...) in lhs of {formula}"
  assert lhs.count("...") <= 1, f"too many ellipses in {formula}"

  # resolve ellipsis
  if "..." in lhs:
    ell_len = len(self.shape) - len(lhs) + 1 + sum(e - s - 1 for s, e in unflatten_dims)
  lhs, rhs = map(lambda l: l[: (i := l.index("..."))] + [f"...{j}" for j in range(ell_len)] + l[i + 1 :] if "..." in l else l, (lhs, rhs))
  unflatten_dims = [(s + (ell_len - 1 if "...0" in lhs[:s] else 0), e + (ell_len - 1 if "...0" in lhs[:e] else 0)) for s, e in unflatten_dims]
  flatten_dims = [(s + (ell_len - 1 if "...0" in rhs[:s] else 0), e + (ell_len - 1 if "...0" in rhs[:e] else 0)) for s, e in flatten_dims]

  # apply movement ops in order unflatten -> permute -> flatten/unsqueeze
  t = functools.reduce(lambda x, dims: x.unflatten(dims[0], tuple(sizes.get(lhs[d], -1) for d in range(*dims))), unflatten_dims, self)
  for i, name in enumerate(lhs):
    assert (name not in sizes) or sizes[name] == t.shape[i], f"size provided for dimension {name} incorrect"
  t = t.permute([lhs.index(name) for name in rhs])
  return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0] < dims[1] else x.unsqueeze(dims[0]), reversed(flatten_dims), t)
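
The `**sizes` keyword arguments are not exercised above; as in einops, they resolve axes introduced by grouping on the left-hand side of the formula. A brief sketch, with the expected output in the comment:

```python
x = Tensor.arange(6)
print(x.rearrange("(h w) -> h w", h=2).numpy())
# [[0 1 2]
#  [3 4 5]]
```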