Properties

Basic¤

shape property ¤

shape: tuple[sint, ...]

dtype property ¤

dtype: DType

device property ¤

device: Union[str, tuple[str, ...]]
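A quick, illustrative check of these three properties; the exact dtype and device printed depend on your defaults, so the comments below are only indicative.

```python
from tinygrad import Tensor

t = Tensor([[1, 2], [3, 4]])
print(t.shape)   # (2, 2)
print(t.dtype)   # the default integer dtype, e.g. dtypes.int
print(t.device)  # the default device string, e.g. "CLANG" or "METAL"
```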

ndim property ¤

ndim: int

Returns the number of dimensions in the tensor.

t = Tensor([[1, 2], [3, 4]])
print(t.ndim)
2

numel ¤

numel() -> sint

Returns the total number of elements in the tensor.

t = Tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
print(t.numel())
8
Source code in tinygrad/tensor.py, lines 3712-3721
def numel(self) -> sint:
  """
  Returns the total number of elements in the tensor.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
  print(t.numel())
  ```
  """
  return prod(self.shape)

element_size ¤

element_size() -> int

Returns the size in bytes of an individual element in the tensor.

t = Tensor([5], dtype=dtypes.int16)
print(t.element_size())
2
Source code in tinygrad/tensor.py, lines 3723-3732
def element_size(self) -> int:
  """
  Returns the size in bytes of an individual element in the tensor.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([5], dtype=dtypes.int16)
  print(t.element_size())
  ```
  """
  return self.dtype.itemsize

nbytes ¤

nbytes() -> int

Returns the total number of bytes of all elements in the tensor.

t = Tensor([8, 9], dtype=dtypes.float)
print(t.nbytes())
8
Source code in tinygrad/tensor.py, lines 3734-3743
def nbytes(self) -> int:
  """
  Returns the total number of bytes of all elements in the tensor.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([8, 9], dtype=dtypes.float)
  print(t.nbytes())
  ```
  """
  return self.numel() * self.element_size()

is_floating_point ¤

is_floating_point() -> bool

Returns True if the tensor holds a floating point dtype, i.e. one of dtypes.float64, dtypes.float32, dtypes.float16, or dtypes.bfloat16.

t = Tensor([8, 9], dtype=dtypes.float32)
print(t.is_floating_point())
True
Source code in tinygrad/tensor.py, lines 3745-3755
def is_floating_point(self) -> bool:
  """
  Returns `True` if the tensor contains floating point types, i.e. is one of `dtype.float64`, `dtype.float32`,
  `dtype.float16`, `dtype.bfloat16`.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([8, 9], dtype=dtypes.float32)
  print(t.is_floating_point())
  ```
  """
  return dtypes.is_float(self.dtype)

size ¤

size(
    dim: Optional[int] = None,
) -> Union[sint, tuple[sint, ...]]

Return the size of the tensor. If dim is specified, return the length along dimension dim. Otherwise return the shape of the tensor.

t = Tensor([[4, 5, 6], [7, 8, 9]])
print(t.size())
(2, 3)
print(t.size(dim=1))
3

Source code in tinygrad/tensor.py, lines 3757-3769
def size(self, dim:Optional[int]=None) -> Union[sint, tuple[sint, ...]]:
  """
  Return the size of the tensor. If `dim` is specified, return the length along dimension `dim`. Otherwise return the shape of the tensor.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([[4, 5, 6], [7, 8, 9]])
  print(t.size())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.size(dim=1))
  ```
  """
  return self.shape if dim is None else self.shape[dim]

Data Access¤

data ¤

data() -> memoryview

Returns the data of this tensor as a memoryview.

t = Tensor([1, 2, 3, 4])
print(np.frombuffer(t.data(), dtype=np.int32))
[1 2 3 4]
Source code in tinygrad/tensor.py, lines 314-326
def data(self) -> memoryview:
  """
  Returns the data of this tensor as a memoryview.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1, 2, 3, 4])
  print(np.frombuffer(t.data(), dtype=np.int32))
  ```
  """
  assert self.dtype.base.fmt is not None, f"no fmt dtype for {self.dtype.base}"
  assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
  if TYPE_CHECKING or sys.version_info < (3, 12): assert self.dtype.base.fmt != "e"
  return cast(memoryview, self._data().cast(self.dtype.base.fmt) if 0 in self.shape else self._data().cast(self.dtype.base.fmt, self.shape))

item ¤

item() -> ConstType

Returns the value of this tensor as a standard Python number.

t = Tensor(42)
print(t.item())
42
Source code in tinygrad/tensor.py, lines 328-338
def item(self) -> ConstType:
  """
  Returns the value of this tensor as a standard Python number.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor(42)
  print(t.item())
  ```
  """
  assert self.numel() == 1, "must have one element for item"
  return self.data()[(0,) * len(self.shape)]

tolist ¤

tolist() -> Union[Sequence[ConstType], ConstType]

Returns the value of this tensor as a nested list.

t = Tensor([1, 2, 3, 4])
print(t.tolist())
[1, 2, 3, 4]
Source code in tinygrad/tensor.py, lines 342-351
def tolist(self) -> Union[Sequence[ConstType], ConstType]:
  """
  Returns the value of this tensor as a nested list.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1, 2, 3, 4])
  print(t.tolist())
  ```
  """
  return self.data().tolist()

numpy ¤

numpy() -> 'np.ndarray'

Returns the value of this tensor as a numpy.ndarray.

t = Tensor([1, 2, 3, 4])
print(repr(t.numpy()))
array([1, 2, 3, 4], dtype=int32)
Source code in tinygrad/tensor.py, lines 353-366
def numpy(self) -> 'np.ndarray':  # type: ignore [name-defined] # noqa: F821
  """
  Returns the value of this tensor as a `numpy.ndarray`.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1, 2, 3, 4])
  print(repr(t.numpy()))
  ```
  """
  import numpy as np
  if self.dtype.base == dtypes.bfloat16: return self.float().numpy()
  assert _to_np_dtype(self.dtype.base) is not None, f"no np dtype for {self.dtype.base}"
  assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
  return np.frombuffer(self._data(), dtype=_to_np_dtype(self.dtype.base)).reshape(self.shape)

tinygrad ops¤

schedule_with_vars ¤

schedule_with_vars(
    *lst: Tensor,
) -> tuple[list[ScheduleItem], dict[Variable, int]]

Creates the schedule needed to realize these Tensor(s), with Variables.

Note: A Tensor can only be scheduled once.
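A minimal sketch of inspecting a schedule before running it. No Variables are used here, so var_vals should come back empty; the number of ScheduleItems depends on the device and on how the graph is fused.

```python
from tinygrad import Tensor

a = Tensor([1.0, 2.0]) + Tensor([3.0, 4.0])
sched, var_vals = a.schedule_with_vars()
print(len(sched), var_vals)  # a few ScheduleItems and, with no Variables, an empty dict
```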

Source code in tinygrad/tensor.py, lines 226-258
def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ScheduleItem], dict[Variable, int]]:
  """
  Creates the schedule needed to realize these Tensor(s), with Variables.

  NOTE: A Tensor can only be scheduled once.
  """
  schedule, var_vals, becomes_map = create_schedule_with_vars(flatten([x.lazydata.lbs for x in (self,)+lst]))

  # get all children of keys in becomes_map
  all_uops: set[UOp] = set()
  search_uops = list(becomes_map)
  while len(search_uops):
    x = search_uops.pop(0)
    if x in all_uops: continue
    all_uops.add(x)
    search_uops.extend([u for c in x.children if (u:=c()) is not None])

  # link the found UOps back to Tensors. exit early if there's no Tensors to realize
  # NOTE: this uses all_tensors, but it's fast
  fixed_tensors: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and any(x in all_uops for x in t.lazydata.lbs)]
  if len(fixed_tensors) == 0: return [], {}

  # potentially rewrite all the discovered Tensors
  sink = UOp.sink(*[UOp.sink(*t.lazydata.lbs) if isinstance(t.lazydata, MultiLazyBuffer) else t.lazydata for t in fixed_tensors])
  new_sink = sink.substitute(becomes_map)

  # set the relevant lazydata to the realized UOps
  for t,s,ns in zip(fixed_tensors, sink.src, new_sink.src):
    if s is ns: continue
    if isinstance(t.lazydata, MultiLazyBuffer): t.lazydata.lbs = list(ns.src)
    else: t.lazydata = ns

  return memory_planner(schedule), var_vals

schedule ¤

schedule(*lst: Tensor) -> list[ScheduleItem]

Creates the schedule needed to realize these Tensor(s).
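A small sketch; the exact number of ScheduleItems depends on the device and fusion, so only the call pattern is the point here.

```python
from tinygrad import Tensor

out = (Tensor.ones(4) * 2).sum()
sched = out.schedule()
print(len(sched))  # number of kernels/copies the scheduler decided on
```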

Source code in tinygrad/tensor.py, lines 260-264
def schedule(self, *lst:Tensor) -> list[ScheduleItem]:
  """Creates the schedule needed to realize these Tensor(s)."""
  schedule, var_vals = self.schedule_with_vars(*lst)
  assert len(var_vals) == 0
  return schedule

realize ¤

realize(*lst: Tensor, do_update_stats=True) -> Tensor

Triggers the computation needed to create these Tensor(s).
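A minimal sketch showing that the computation runs at the realize() call rather than at tensor creation.

```python
from tinygrad import Tensor

t = (Tensor([1.0, 2.0, 3.0]) * 2).realize()  # the multiply is executed here
print(t.numpy())  # [2. 4. 6.]
```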

Source code in tinygrad/tensor.py, lines 266-269
def realize(self, *lst:Tensor, do_update_stats=True) -> Tensor:
  """Triggers the computation needed to create these Tensor(s)."""
  run_schedule(*self.schedule_with_vars(*lst), do_update_stats=do_update_stats)
  return self

replace ¤

replace(x: Tensor) -> Tensor

Replaces the data of this tensor with the data of another tensor. Only the shapes of the two tensors must match; the device and dtype may differ.
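A small sketch; the only requirement checked is that the shapes match, so the replacement below can even change the dtype.

```python
from tinygrad import Tensor

t = Tensor([1, 2, 3])
t.replace(Tensor([4.0, 5.0, 6.0]))  # same shape; dtype and device may differ
print(t.tolist())  # [4.0, 5.0, 6.0]
```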

Source code in tinygrad/tensor.py, lines 271-279
def replace(self, x:Tensor) -> Tensor:
  """
  Replaces the data of this tensor with the data of another tensor. Only the shape of the tensors must match.
  """
  # used for replacing a Tensor with a new version of it (potentially with a different device and dtype)
  assert getattr(self, '_ctx', None) is None
  assert self.shape == x.shape, f"replace shape mismatch {self.shape} != {x.shape}"
  self.lazydata = x.lazydata
  return self

assign ¤

assign(x) -> Tensor
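assign has no docstring in the source; judging from the implementation below, it writes the data of x into this tensor in place, requiring matching shape, device, and dtype (DISK tensors take a special copy path). A hedged sketch of the common in-place update pattern:

```python
from tinygrad import Tensor

w = Tensor([1.0, 2.0, 3.0]).realize()
w.assign(w - 0.1)   # in-place update; shape, device, and dtype must all match
print(w.tolist())   # approximately [0.9, 1.9, 2.9]
```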
Source code in tinygrad/tensor.py, lines 281-298
def assign(self, x) -> Tensor:
  # TODO: this is a hack for writing to DISK. remove with working assign
  if isinstance(self.device, str) and self.device.startswith("DISK"):
    if x.__class__ is not Tensor: x = Tensor(x, device="CLANG", dtype=self.dtype)
    self.contiguous().realize().lazydata.base.realized.copyin(x._data())
    return self
  if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
  if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
  if self.lazydata is x.lazydata: return self  # a self assign is a NOOP
  # NOTE: we allow cross device assign
  assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
  assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
  assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
  assert not isinstance(self.lazydata, MultiLazyBuffer) or self.lazydata.axis == x.lazydata.axis, "axis must match on MultiLazyBuffer"
  assert not x.requires_grad  # self requires_grad is okay?
  if not self.lazydata.is_realized: return self.replace(x)
  self.lazydata = self.lazydata.assign(x.lazydata)
  return self

detach ¤

detach() -> Tensor

Returns a new tensor with the same data as this tensor, but detached from the autograd graph.
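A minimal sketch: the detached tensor shares the data but no longer requires grad.

```python
from tinygrad import Tensor

t = Tensor([1.0, 2.0], requires_grad=True)
d = t.detach()
print(d.requires_grad)  # False; gradients will not flow back through d
```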

Source code in tinygrad/tensor.py, lines 300-304
def detach(self) -> Tensor:
  """
  Returns a new tensor with the same data as this tensor, but detached from the autograd graph.
  """
  return Tensor(self.lazydata.detach(), device=self.device, requires_grad=False)

to ¤

to(device: Optional[Union[str, tuple[str, ...]]]) -> Tensor

Moves the tensor to the given device.
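A small sketch; "CLANG" is just an example device name and should be swapped for any backend available in your build.

```python
from tinygrad import Tensor

t = Tensor([1, 2, 3])
t2 = t.to("CLANG")  # example device name; the original tensor t stays on its device
print(t2.device)
```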

Source code in tinygrad/tensor.py, lines 377-387
def to(self, device:Optional[Union[str, tuple[str, ...]]]) -> Tensor:
  """
  Moves the tensor to the given device.
  """
  device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
  if device == self.device: return self
  if not isinstance(device, str): return self.shard(device)
  ret = Tensor(self.lazydata, device, requires_grad=self.requires_grad)
  if self.grad is not None: ret.grad = self.grad.to(device)
  if hasattr(self, '_ctx'): ret._ctx = self._ctx
  return ret

to_ ¤

to_(device: Optional[Union[str, tuple[str, ...]]])

Moves the tensor to the given device in place.
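Same idea as to(), but in place; the device name below is again only an example.

```python
from tinygrad import Tensor

t = Tensor([1, 2, 3])
t.to_("CLANG")   # replaces this tensor's data with the moved copy
print(t.device)
```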

Source code in tinygrad/tensor.py, lines 389-395
def to_(self, device:Optional[Union[str, tuple[str, ...]]]):
  """
  Moves the tensor to the given device in place.
  """
  real = self.to(device)
  if self.grad is not None and real.grad is not None: self.grad.replace(real.grad)
  return self.replace(real)

shard ¤

shard(
    devices: tuple[str, ...], axis: Optional[int] = None
) -> Tensor

Shards the tensor across the given devices. Optionally specify which axis to shard on.

t = Tensor.empty(2, 4)
print(t.shard((t.device, t.device), axis=1).lazydata)
<MLB self.axis=1 self.real=[True, True] 
CLANG ShapeTracker(views=(View(shape=(2, 2), strides=(2, 1), offset=0, mask=None, contiguous=True),))
CLANG ShapeTracker(views=(View(shape=(2, 2), strides=(2, 1), offset=0, mask=None, contiguous=True),))>
Source code in tinygrad/tensor.py, lines 397-418
def shard(self, devices:tuple[str, ...], axis:Optional[int]=None) -> Tensor:
  """
  Shards the tensor across the given devices. Optionally specify which axis to shard on.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.empty(2, 4)
  print(t.shard((t.device, t.device), axis=1).lazydata)
  ```
  """
  assert isinstance(self.lazydata, UOp), "can't shard a MultiLazyBuffer"
  devices = tuple(Device.canonicalize(x) for x in devices)
  if axis is None: lbs = [self.lazydata] * len(devices)
  else:
    axis = self._resolve_dim(axis)
    if self.shape[axis] % len(devices) != 0: raise RuntimeError(f"multi axis uneven: {self.shape[axis]=} {axis=} {len(devices)=}")
    sz = self.shape[axis] // len(devices)
    sizes = [max(0, min(sz, self.shape[axis] - sz*i)) for i in range(len(devices))]
    lbs = [cast(UOp, t.lazydata) for t in self.split(sizes, axis)]
  sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(lbs, devices)]
  # NOTE: this contiguous is making it impossible for the scheduler to do late const folding
  mlb = MultiLazyBuffer([lb.contiguous() for lb in sharded_lbs], axis)
  return Tensor(mlb, device=devices, requires_grad=self.requires_grad)

shard_ ¤

shard_(
    devices: tuple[str, ...], axis: Optional[int] = None
)

Shards the tensor across the given devices in place.
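A short sketch mirroring the shard() example above, but mutating the tensor instead of returning a new one.

```python
from tinygrad import Tensor

t = Tensor.empty(2, 4)
t.shard_((t.device, t.device), axis=1)  # like shard(), but replaces t in place
print(t.device)  # now a tuple of two device strings
```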

Source code in tinygrad/tensor.py, lines 420-424
def shard_(self, devices:tuple[str, ...], axis:Optional[int]=None):
  """
  Shards the tensor across the given devices in place.
  """
  return self.replace(self.shard(devices, axis))

contiguous ¤

contiguous()

Returns a contiguous tensor.
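A minimal sketch: a permuted view is materialized into its new layout by contiguous().

```python
from tinygrad import Tensor

t = Tensor([[1, 2], [3, 4]]).permute(1, 0)  # a strided view, not yet materialized
c = t.contiguous()                          # forces the data into the permuted layout
print(c.tolist())  # [[1, 3], [2, 4]]
```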

Source code in tinygrad/tensor.py, lines 2539-2543
def contiguous(self):
  """
  Returns a contiguous tensor.
  """
  return F.Contiguous.apply(self)

contiguous_backward ¤

contiguous_backward()

Inserts a contiguous operation in the backward pass.
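A hedged sketch of where this sits in a graph: it passes the value through unchanged in the forward pass, and the gradient flowing back through it is made contiguous.

```python
from tinygrad import Tensor

t = Tensor.randn(4, requires_grad=True)
loss = t.contiguous_backward().sum()
loss.backward()        # the gradient passing through this node is made contiguous
print(t.grad.shape)    # (4,)
```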

Source code in tinygrad/tensor.py, lines 2544-2548
def contiguous_backward(self):
  """
  Inserts a contiguous operation in the backward pass.
  """
  return F.ContiguousBackward.apply(self)

Gradient¤

gradient ¤

gradient(
    *targets: Tensor, gradient: Optional[Tensor] = None
) -> list[Tensor]

Compute the gradient of the targets with respect to self.

x = Tensor.eye(3)
y = Tensor([[2.0,0,-2.0]])
z = y.matmul(x).sum()
dx, dy = z.gradient(x, y)

print(dx.tolist())  # dz/dx
print(dy.tolist())  # dz/dy
[[2.0, 2.0, 2.0], [0.0, 0.0, 0.0], [-2.0, -2.0, -2.0]]
[[1.0, 1.0, 1.0]]
Source code in tinygrad/tensor.py, lines 905-933
def gradient(self, *targets:Tensor, gradient:Optional[Tensor]=None) -> list[Tensor]:
  """
  Compute the gradient of the targets with respect to self.

  ```python exec="true" source="above" session="tensor" result="python"
  x = Tensor.eye(3)
  y = Tensor([[2.0,0,-2.0]])
  z = y.matmul(x).sum()
  dx, dy = z.gradient(x, y)

  print(dx.tolist())  # dz/dx
  print(dy.tolist())  # dz/dy
  ```
  """
  assert gradient is not None or self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
  if gradient is None: gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)
  rets = []
  for i,(uop,grad) in enumerate(zip(self.lazydata.lbs, gradient.lazydata.lbs)):
    target_uops = [x.lazydata.lbs[i] for x in targets]
    grads = compute_gradient(uop, grad, set(target_uops))
    ret = []
    for x in target_uops:
      if (y:=grads.get(x)) is None: raise RuntimeError(f"{x}\n\nnot found in\n\n{uop}")
      ret.append(y)
    rets.append(ret)
  # create returned Tensors
  if isinstance(self.lazydata, UOp): return [Tensor(u, device=t.device) for t,u in zip(targets, rets[0])]
  return [Tensor(MultiLazyBuffer(list(u), cast(MultiLazyBuffer, t.lazydata).axis, cast(MultiLazyBuffer, t.lazydata).real),
                 device=t.device) for t,u in zip(targets, zip(*rets))]

backward ¤

backward(
    gradient: Optional[Tensor] = None,
    retain_graph: bool = False,
) -> Tensor

Propagates the gradient of a tensor backwards through the computation graph. If the 'gradient' argument is not provided, the tensor must be a scalar, and the gradient is implicitly set to 1.0. If 'retain_graph' is false, the graph used to compute the grads will be freed. Otherwise, it will be kept. Keeping it can increase memory usage.

t = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
t.sum().backward()
print(t.grad.numpy())
[1. 1. 1. 1.]

Source code in tinygrad/tensor.py, lines 946-982
def backward(self, gradient:Optional[Tensor]=None, retain_graph:bool=False) -> Tensor:
  """
  Propagates the gradient of a tensor backwards through the computation graph.
  If the 'gradient' argument is not provided, the tensor must be a scalar, and the gradient is implicitly set to 1.0.
  If 'retain_graph' is false, the graph used to compute the grads will be freed. Otherwise, it will be kept. Keeping it can increase memory usage.
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
  t.sum().backward()
  print(t.grad.numpy())
  ```
  """
  toposorted = self._deepwalk()
  if gradient is None:
    assert self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
    # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
    # this is "implicit gradient creation"
    gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)

  toposort_uop = self.lazydata.toposort
  assert self.shape == gradient.shape, f"grad shape must match tensor shape, {gradient.shape!r} != {self.shape!r}"
  self.grad = gradient
  for t0 in reversed(toposorted):
    if t0.grad is None: raise RuntimeError(f"tensor {t0} has no grad")
    ctx = cast(Function, t0._ctx)
    token = _METADATA.set(dataclasses.replace(md, backward=True) if (md := ctx.metadata) is not None else None)
    grads = ctx.backward(t0.grad.lazydata)
    _METADATA.reset(token)
    grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
      for g in ([grads] if len(ctx.parents) == 1 else grads)]
    for t, g in zip(ctx.parents, grads):
      if g is not None and t.requires_grad:
        assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
        assert t.lazydata in toposort_uop or (isinstance(t.lazydata, MultiLazyBuffer) and any(x in toposort_uop for x in t.lazydata.lbs)), \
          f"grad uop must have a path from self\ngrad uop: {t.lazydata}"
        t.grad = g if t.grad is None else (t.grad + g)
    if not retain_graph: del t0._ctx
  return self