Properties

Basic¤

shape property ¤

shape: Tuple[sint, ...]
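
The shape of the tensor as a tuple of dimension sizes. For example:

t = Tensor([[1, 2], [3, 4]])
print(t.shape)
(2, 2)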

dtype property ¤

dtype: DType
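
The data type of the tensor's elements. A quick check, assuming the default float type (float32):

t = Tensor([1.5, 2.5])
print(t.dtype == dtypes.float32)
True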

device property ¤

device: Union[str, Tuple[str, ...]]
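
The device (or tuple of devices for a sharded tensor) that this tensor lives on. The value depends on your default backend:

t = Tensor([1, 2, 3])
print(t.device)  # e.g. "CLANG", "CUDA", "METAL", ...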

ndim property ¤

ndim: int

Returns the number of dimensions in the tensor.

t = Tensor([[1, 2], [3, 4]])
print(t.ndim)
2

numel ¤

numel() -> sint

Returns the total number of elements in the tensor.

t = Tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
print(t.numel())
8
Source code in tinygrad/tensor.py, lines 3153-3162
def numel(self) -> sint:
  """
  Returns the total number of elements in the tensor.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
  print(t.numel())
  ```
  """
  return prod(self.shape)

element_size ¤

element_size() -> int

Returns the size in bytes of an individual element in the tensor.

t = Tensor([5], dtype=dtypes.int16)
print(t.element_size())
2
Source code in tinygrad/tensor.py, lines 3164-3173
def element_size(self) -> int:
  """
  Returns the size in bytes of an individual element in the tensor.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([5], dtype=dtypes.int16)
  print(t.element_size())
  ```
  """
  return self.dtype.itemsize

nbytes ¤

nbytes() -> int

Returns the total number of bytes of all elements in the tensor.

t = Tensor([8, 9], dtype=dtypes.float)
print(t.nbytes())
8
Source code in tinygrad/tensor.py, lines 3175-3184
def nbytes(self) -> int:
  """
  Returns the total number of bytes of all elements in the tensor.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([8, 9], dtype=dtypes.float)
  print(t.nbytes())
  ```
  """
  return self.numel() * self.element_size()

is_floating_point ¤

is_floating_point() -> bool

Returns True if the tensor contains floating point types, i.e. is one of dtype.float64, dtype.float32, dtype.float16, dtype.bfloat16.

t = Tensor([8, 9], dtype=dtypes.float32)
print(t.is_floating_point())
True
Source code in tinygrad/tensor.py, lines 3186-3196
def is_floating_point(self) -> bool:
  """
  Returns `True` if the tensor contains floating point types, i.e. is one of `dtype.float64`, `dtype.float32`,
  `dtype.float16`, `dtype.bfloat16`.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([8, 9], dtype=dtypes.float32)
  print(t.is_floating_point())
  ```
  """
  return dtypes.is_float(self.dtype)

size ¤

size(
    dim: Optional[int] = None,
) -> Union[sint, Tuple[sint, ...]]

Return the size of the tensor. If dim is specified, return the length along dimension dim. Otherwise return the shape of the tensor.

t = Tensor([[4, 5, 6], [7, 8, 9]])
print(t.size())
(2, 3)
print(t.size(dim=1))
3

Source code in tinygrad/tensor.py, lines 3198-3210
def size(self, dim:Optional[int]=None) -> Union[sint, Tuple[sint, ...]]:
  """
  Return the size of the tensor. If `dim` is specified, return the length along dimension `dim`. Otherwise return the shape of the tensor.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([[4, 5, 6], [7, 8, 9]])
  print(t.size())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t.size(dim=1))
  ```
  """
  return self.shape if dim is None else self.shape[dim]

Data Access¤

data ¤

data() -> memoryview

Returns the data of this tensor as a memoryview.

t = Tensor([1, 2, 3, 4])
print(np.frombuffer(t.data(), dtype=np.int32))
[1 2 3 4]
Source code in tinygrad/tensor.py, lines 260-271
def data(self) -> memoryview:
  """
  Returns the data of this tensor as a memoryview.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1, 2, 3, 4])
  print(np.frombuffer(t.data(), dtype=np.int32))
  ```
  """
  assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
  assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
  return self._data().cast(self.dtype.fmt, self.shape)

item ¤

item() -> ConstType

Returns the value of this tensor as a standard Python number.

t = Tensor(42)
print(t.item())
42
Source code in tinygrad/tensor.py, lines 273-284
def item(self) -> ConstType:
  """
  Returns the value of this tensor as a standard Python number.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor(42)
  print(t.item())
  ```
  """
  assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
  assert self.numel() == 1, "must have one element for item"
  return self._data().cast(self.dtype.fmt)[0]

tolist ¤

tolist() -> Union[Sequence[ConstType], ConstType]

Returns the value of this tensor as a nested list.

t = Tensor([1, 2, 3, 4])
print(t.tolist())
[1, 2, 3, 4]
Source code in tinygrad/tensor.py, lines 288-297
def tolist(self) -> Union[Sequence[ConstType], ConstType]:
  """
  Returns the value of this tensor as a nested list.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1, 2, 3, 4])
  print(t.tolist())
  ```
  """
  return self.data().tolist()

numpy ¤

numpy() -> ndarray

Returns the value of this tensor as a numpy.ndarray.

t = Tensor([1, 2, 3, 4])
print(repr(t.numpy()))
array([1, 2, 3, 4], dtype=int32)
Source code in tinygrad/tensor.py, lines 299-311
def numpy(self) -> np.ndarray:
  """
  Returns the value of this tensor as a `numpy.ndarray`.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1, 2, 3, 4])
  print(repr(t.numpy()))
  ```
  """
  if self.dtype == dtypes.bfloat16: return self.float().numpy()
  assert _to_np_dtype(self.dtype) is not None, f"no np dtype for {self.dtype}"
  assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
  return np.frombuffer(self._data(), dtype=_to_np_dtype(self.dtype)).reshape(self.shape)

tinygrad ops¤

schedule_with_vars ¤

schedule_with_vars(
    *lst: Tensor,
) -> Tuple[List[ScheduleItem], Dict[Variable, int]]

Creates the schedule needed to realize these Tensor(s), with Variables.

Note

A Tensor can only be scheduled once.
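
A minimal sketch of the call; since no Variables are involved here, var_vals comes back empty:

a = Tensor([1.0, 2.0, 3.0]) + 1
sched, var_vals = a.schedule_with_vars()
print(len(sched), var_vals)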

Source code in tinygrad/tensor.py, lines 194-204
def schedule_with_vars(self, *lst:Tensor) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
  """
  Creates the schedule needed to realize these Tensor(s), with Variables.

  NOTE: A Tensor can only be scheduled once.
  """
  if getenv("FUZZ_SCHEDULE"):
    from test.external.fuzz_schedule import fuzz_schedule
    fuzz_schedule(flatten([x.lazydata.lbs for x in (self,)+lst]))
  schedule, var_vals = create_schedule_with_vars(flatten([x.lazydata.lbs for x in (self,)+lst]))
  return memory_planner(schedule), var_vals

schedule ¤

schedule(*lst: Tensor) -> List[ScheduleItem]

Creates the schedule needed to realize these Tensor(s).
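
A small sketch; the number of ScheduleItems depends on how the operations fuse into kernels and on any buffer copies involved:

out = Tensor([1.0, 2.0]) + Tensor([3.0, 4.0])
print(len(out.schedule()))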

Source code in tinygrad/tensor.py, lines 206-210
def schedule(self, *lst:Tensor) -> List[ScheduleItem]:
  """Creates the schedule needed to realize these Tensor(s)."""
  schedule, var_vals = self.schedule_with_vars(*lst)
  assert len(var_vals) == 0
  return schedule

realize ¤

realize(*lst: Tensor, do_update_stats=True) -> Tensor

Triggers the computation needed to create these Tensor(s).
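
Since realize returns the tensor itself, it can be chained:

t = (Tensor([1.0, 2.0, 3.0]) * 2).realize()
print(t.tolist())
[2.0, 4.0, 6.0]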

Source code in tinygrad/tensor.py, lines 212-215
def realize(self, *lst:Tensor, do_update_stats=True) -> Tensor:
  """Triggers the computation needed to create these Tensor(s)."""
  run_schedule(*self.schedule_with_vars(*lst), do_update_stats=do_update_stats)
  return self

replace ¤

replace(x: Tensor) -> Tensor

Replaces the data of this tensor with the data of another tensor. Only the shape of the tensors must match.
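
A short example; per the source below, the replacement may even have a different device or dtype, but the shape must match:

a = Tensor([1, 2, 3])
a.replace(Tensor([4.0, 5.0, 6.0]))
print(a.tolist())
[4.0, 5.0, 6.0]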

Source code in tinygrad/tensor.py, lines 217-225
def replace(self, x:Tensor) -> Tensor:
  """
  Replaces the data of this tensor with the data of another tensor. Only the shape of the tensors must match.
  """
  # used for replacing a Tensor with a new version of it (potentially with a different device and dtype)
  assert not x.requires_grad and getattr(self, '_ctx', None) is None
  assert self.shape == x.shape, f"replace shape mismatch {self.shape} != {x.shape}"
  self.lazydata = x.lazydata
  return self

assign ¤

assign(x) -> Tensor
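
As the assertions in the source below show, assign writes the values of x into this tensor in place; the shape, dtype, and device must match. A common use is updating weights or optimizer state:

w = Tensor([1.0, 2.0, 3.0])
w.assign(w + 1)
print(w.tolist())
[2.0, 3.0, 4.0]
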
Source code in tinygrad/tensor.py, lines 227-244
def assign(self, x) -> Tensor:
  # TODO: this is a hack for writing to DISK. remove with working assign
  if isinstance(self.device, str) and self.device.startswith("DISK"):
    if x.__class__ is not Tensor: x = Tensor(x, device="NPY", dtype=self.dtype)
    self.contiguous().realize().lazydata.base.realized.copyin(x.numpy().data)
    return self
  if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
  if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
  if self.lazydata is x.lazydata: return self  # a self assign is a NOOP
  # NOTE: we allow cross device assign
  assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
  assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
  assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
  assert not isinstance(self.lazydata, MultiLazyBuffer) or self.lazydata.axis == x.lazydata.axis, "axis must match on MultiLazyBuffer"
  assert not x.requires_grad  # self requires_grad is okay?
  if not self.lazydata.is_realized(): return self.replace(x)
  self.lazydata = self.lazydata.assign(x.lazydata)
  return self

detach ¤

detach() -> Tensor

Returns a new tensor with the same data as this tensor, but detached from the autograd graph.
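
For example:

t = Tensor([1.0, 2.0], requires_grad=True)
print(t.detach().requires_grad)
False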

Source code in tinygrad/tensor.py, lines 246-250
def detach(self) -> Tensor:
  """
  Returns a new tensor with the same data as this tensor, but detached from the autograd graph.
  """
  return Tensor(self.lazydata, device=self.device, requires_grad=False)

to ¤

to(device: Optional[Union[str, Tuple[str, ...]]]) -> Tensor

Moves the tensor to the given device.
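
A short sketch, assuming the CLANG backend (seen in the shard example below) is available:

t = Tensor([1, 2, 3])
print(t.to("CLANG").device)
CLANG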

Source code in tinygrad/tensor.py, lines 313-323
def to(self, device:Optional[Union[str, Tuple[str, ...]]]) -> Tensor:
  """
  Moves the tensor to the given device.
  """
  device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
  if device == self.device: return self
  if not isinstance(device, str): return self.shard(device)
  ret = Tensor(self.lazydata, device, requires_grad=self.requires_grad)
  if self.grad is not None: ret.grad = self.grad.to(device)
  if hasattr(self, '_ctx'): ret._ctx = self._ctx
  return ret

to_ ¤

to_(device: Optional[Union[str, Tuple[str, ...]]])

Moves the tensor to the given device in place.
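
A short sketch, again assuming a CLANG backend; unlike to, the tensor itself is updated:

t = Tensor([1, 2, 3])
t.to_("CLANG")
print(t.device)
CLANG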

Source code in tinygrad/tensor.py, lines 325-332
def to_(self, device:Optional[Union[str, Tuple[str, ...]]]):
  """
  Moves the tensor to the given device in place.
  """
  real = self.to(device)
  # TODO: is this assign?
  if self.grad is not None and real.grad is not None: self.grad.lazydata = real.grad.lazydata
  self.lazydata = real.lazydata

shard ¤

shard(
    devices: Tuple[str, ...],
    axis: Optional[int] = None,
    splits: Optional[Tuple[int, ...]] = None,
) -> Tensor

Shards the tensor across the given devices. Optionally specify which axis to shard on, and how to split it across devices.

t = Tensor.empty(2, 3)
print(t.shard((t.device, t.device), axis=1, splits=(2, 1)).lazydata)
<MLB self.axis=1 self.real=[True, True] 
CLANG ShapeTracker(views=(View(shape=(2, 2), strides=(2, 1), offset=0, mask=None, contiguous=True),))
CLANG ShapeTracker(views=(View(shape=(2, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),))>
Source code in tinygrad/tensor.py, lines 334-355
def shard(self, devices:Tuple[str, ...], axis:Optional[int]=None, splits:Optional[Tuple[int, ...]]=None) -> Tensor:
  """
  Shards the tensor across the given devices. Optionally specify which axis to shard on, and how to split it across devices.

  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor.empty(2, 3)
  print(t.shard((t.device, t.device), axis=1, splits=(2, 1)).lazydata)
  ```

  """
  assert isinstance(self.lazydata, LazyBuffer), "can't shard a MultiLazyBuffer"
  canonical_devices, bounds = tuple(Device.canonicalize(x) for x in devices), None
  if axis is not None:
    if axis < 0: axis += len(self.shape)
    if splits is None:
      sz = round_up(self.shape[axis], len(devices)) // len(devices)
      splits = tuple([max(0, min(sz, self.shape[axis] - sz*i)) for i in range(len(devices))])
    assert sum(splits) == self.shape[axis], "specified splits do not sum up to axis shape"
    boundaries = tuple(itertools.accumulate(splits))
    bounds = tuple(zip((0,) + boundaries, boundaries))
  return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, canonical_devices, axis, bounds),
                device=canonical_devices, requires_grad=self.requires_grad)

shard_ ¤

shard_(
    devices: Tuple[str, ...],
    axis: Optional[int] = None,
    splits: Optional[Tuple[int, ...]] = None,
)

Shards the tensor across the given devices in place.
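
A small sketch mirroring the shard example above, but mutating the tensor in place:

t = Tensor.empty(2, 3)
t.shard_((t.device, t.device), axis=1)
print(t.shape, len(t.device))
(2, 3) 2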

Source code in tinygrad/tensor.py, lines 357-362
def shard_(self, devices:Tuple[str, ...], axis:Optional[int]=None, splits:Optional[Tuple[int, ...]]=None):
  """
  Shards the tensor across the given devices in place.
  """
  self.lazydata = self.shard(devices, axis, splits).lazydata
  return self

contiguous ¤

contiguous()

Returns a contiguous tensor.
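
contiguous does not change the values, only how they are laid out; a quick example on a permuted view:

t = Tensor([[1, 2], [3, 4]]).permute(1, 0).contiguous()
print(t.tolist())
[[1, 3], [2, 4]]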

Source code in tinygrad/tensor.py, lines 2132-2136
def contiguous(self):
  """
  Returns a contiguous tensor.
  """
  return F.Contiguous.apply(self)

contiguous_backward ¤

contiguous_backward()

Inserts a contiguous operation in the backward pass.

Source code in tinygrad/tensor.py, lines 2137-2141
def contiguous_backward(self):
  """
  Inserts a contiguous operation in the backward pass.
  """
  return F.ContiguousBackward.apply(self)

backward ¤

backward(
    gradient: Optional[Tensor] = None,
    retain_graph: bool = False,
) -> Tensor

Propagates the gradient of a tensor backwards through the computation graph. If the 'gradient' argument is not provided, the tensor must be a scalar, and the gradient is implicitly set to 1.0. If 'retain_graph' is false, the graph used to compute the grads will be freed. Otherwise, it will be kept. Keeping it can increase memory usage.

t = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
t.sum().backward()
print(t.grad.numpy())
[1. 1. 1. 1.]

Source code in tinygrad/tensor.py, lines 771-803
def backward(self, gradient:Optional[Tensor]=None, retain_graph:bool=False) -> Tensor:
  """
  Propagates the gradient of a tensor backwards through the computation graph.
  If the 'gradient' argument is not provided, the tensor must be a scalar, and the gradient is implicitly set to 1.0.
  If 'retain_graph' is false, the graph used to compute the grads will be freed. Otherwise, it will be kept. Keeping it can increase memory usage.
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
  t.sum().backward()
  print(t.grad.numpy())
  ```
  """
  toposorted = self._deepwalk()
  if gradient is None:
    assert self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
    # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
    # this is "implicit gradient creation"
    gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)

  assert self.shape == gradient.shape, f"grad shape must match tensor shape, {gradient.shape!r} != {self.shape!r}"
  self.grad = gradient
  for t0 in reversed(toposorted):
    if t0.grad is None: raise RuntimeError(f"tensor {t0} has no grad")
    token = _METADATA.set(dataclasses.replace(md, backward=True) if (md := t0._ctx.metadata) is not None else None)
    grads = t0._ctx.backward(t0.grad.lazydata)
    _METADATA.reset(token)
    grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
      for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
    for t, g in zip(t0._ctx.parents, grads):
      if g is not None and t.requires_grad:
        assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
        t.grad = g if t.grad is None else (t.grad + g)
    if not retain_graph: del t0._ctx
  return self