Skip to content

nn (Neural Networks)

Neural Network classes¤

BatchNorm ¤

BatchNorm(
    sz: int,
    eps=1e-05,
    affine=True,
    track_running_stats=True,
    momentum=0.1,
)

Applies Batch Normalization over a 2D or 3D input (not counting the leading batch and channel dimensions, as in the 4D example below).

See: Tensor.batchnorm

norm = nn.BatchNorm(3)
t = Tensor.rand(2, 3, 4, 4)
print(t.mean().item(), t.std().item())
0.5023592710494995 0.2932378053665161
t = norm(t)
print(t.mean().item(), t.std().item())
0.502356767654419 0.29323628544807434

Source code in tinygrad/nn/__init__.py
33
34
35
36
37
38
39
40
def __init__(self, sz:int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
  # Record the hyperparameters, then allocate the per-feature state of length `sz`.
  self.eps = eps
  self.track_running_stats = track_running_stats
  self.momentum = momentum

  # scale (ones) and shift (zeros) only exist for the affine variant; otherwise both are None
  if affine:
    self.weight: Optional[Tensor] = Tensor.ones(sz)
    self.bias: Optional[Tensor] = Tensor.zeros(sz)
  else:
    self.weight = None
    self.bias = None

  # the batch counter is always created; running statistics only when tracking is requested
  self.num_batches_tracked = Tensor.zeros(1, dtype='long', requires_grad=False)
  if track_running_stats:
    self.running_mean = Tensor.zeros(sz, requires_grad=False)
    self.running_var = Tensor.ones(sz, requires_grad=False)

Conv1d ¤

Conv1d(
    in_channels: int,
    out_channels: int,
    kernel_size: int,
    stride=1,
    padding: Union[int, str] = 0,
    dilation=1,
    groups=1,
    bias=True,
) -> Conv2d

Applies a 1D convolution over an input signal composed of several input planes.

See: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d

conv = nn.Conv1d(1, 1, 3)
t = Tensor.rand(1, 1, 4)
print(t.numpy())
[[[0.7451 0.3881 0.6753 0.7302]]]
t = conv(t)
print(t.numpy())
[[[0.8188 0.9575]]]

Source code in tinygrad/nn/__init__.py
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
def Conv1d(in_channels:int, out_channels:int, kernel_size:int, stride=1, padding:Union[int, str]=0, dilation=1, groups=1, bias=True) -> Conv2d:
  """
  Applies a 1D convolution over an input signal composed of several input planes.

  This is a thin factory: it returns a `Conv2d` whose kernel size is the 1-tuple `(kernel_size,)`;
  all other arguments are forwarded unchanged.

  See: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d

  ```python exec="true" source="above" session="tensor" result="python"
  conv = nn.Conv1d(1, 1, 3)
  t = Tensor.rand(1, 1, 4)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = conv(t)
  print(t.numpy())
  ```
  """
  kernel_shape = (kernel_size,)
  return Conv2d(in_channels, out_channels, kernel_shape, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)

Conv2d ¤

Conv2d(
    in_channels: int,
    out_channels: int,
    kernel_size: Union[int, Tuple[int, ...]],
    stride=1,
    padding: Union[int, str] = 0,
    dilation=1,
    groups=1,
    bias=True,
)

Applies a 2D convolution over an input signal composed of several input planes.

See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d

conv = nn.Conv2d(1, 1, 3)
t = Tensor.rand(1, 1, 4, 4)
print(t.numpy())
[[[[0.847  0.5599 0.742  0.7218]
   [0.7082 0.7137 0.9589 0.6087]
   [0.7811 0.6999 0.6388 0.8132]
   [0.6954 0.2978 0.5806 0.5769]]]]
t = conv(t)
print(t.numpy())
[[[[-0.6182 -0.6352]
   [-0.4443 -0.5687]]]]

Source code in tinygrad/nn/__init__.py
 97
 98
 99
100
101
102
103
104
105
106
107
108
def __init__(self, in_channels:int, out_channels:int, kernel_size:Union[int, Tuple[int, ...]], stride=1, padding:Union[int, str]=0,
             dilation=1, groups=1, bias=True):
  # Store the convolution geometry and create uniformly initialised weight / bias tensors.
  # Raises ValueError for a padding string other than 'same', or for 'same' combined with a non-unit stride.

  # an int kernel size means a square 2D kernel; any other sequence is taken as given
  self.kernel_size: Tuple[int, ...] = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
  ndim = len(self.kernel_size)
  if isinstance(padding, str):
    if padding.lower() != 'same': raise ValueError(f"Invalid padding string {padding!r}, only 'same' is supported")
    # stride may be given per dimension, e.g. (1, 1); only a stride that differs from 1 somewhere is unsupported
    if any(s != 1 for s in make_pair(stride, ndim)): raise ValueError("padding='same' is not supported for strided convolutions")
    # one (before, after) pair per dimension, emitted last dimension first. Dilation is reversed together with the
    # kernel size so each dimension's dilation is combined with that same dimension's kernel extent.
    dilations = tuple(make_pair(dilation, ndim))[::-1]
    self.padding: Union[int, List[int]] = [p for d,k in zip(dilations, self.kernel_size[::-1])
                                           for p in (d*(k-1)//2, d*(k-1) - d*(k-1)//2)]
  else: self.padding = padding
  self.stride, self.dilation, self.groups = stride, dilation, groups
  # weight and bias are both drawn from U(-scale, scale); weight first, then bias
  scale = 1 / math.sqrt(in_channels * prod(self.kernel_size))
  self.weight = Tensor.uniform(out_channels, in_channels//groups, *self.kernel_size, low=-scale, high=scale)
  self.bias: Optional[Tensor] = Tensor.uniform(out_channels, low=-scale, high=scale) if bias else None

ConvTranspose1d ¤

ConvTranspose1d(
    in_channels: int,
    out_channels: int,
    kernel_size: int,
    stride=1,
    padding=0,
    output_padding=0,
    dilation=1,
    groups=1,
    bias=True,
) -> ConvTranspose2d

Applies a 1D transposed convolution operator over an input signal composed of several input planes.

See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d

conv = nn.ConvTranspose1d(1, 1, 3)
t = Tensor.rand(1, 1, 4)
print(t.numpy())
[[[0.1249 0.6017 0.9536 0.0144]]]
t = conv(t)
print(t.numpy())
[[[0.1806 0.3138 0.4687 0.4093 0.5338 0.1521]]]

Source code in tinygrad/nn/__init__.py
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
def ConvTranspose1d(in_channels:int, out_channels:int, kernel_size:int, stride=1, padding=0, output_padding=0, dilation=1,
                    groups=1, bias=True) -> ConvTranspose2d:
  """
  Applies a 1D transposed convolution operator over an input signal composed of several input planes.

  This is a thin factory: it returns a `ConvTranspose2d` whose kernel size is the 1-tuple `(kernel_size,)`;
  all other arguments are forwarded unchanged.

  See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d

  ```python exec="true" source="above" session="tensor" result="python"
  conv = nn.ConvTranspose1d(1, 1, 3)
  t = Tensor.rand(1, 1, 4)
  print(t.numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  t = conv(t)
  print(t.numpy())
  ```
  """
  kernel_shape = (kernel_size,)
  return ConvTranspose2d(in_channels, out_channels, kernel_shape, stride=stride, padding=padding, output_padding=output_padding,
                         dilation=dilation, groups=groups, bias=bias)

ConvTranspose2d ¤

ConvTranspose2d(
    in_channels: int,
    out_channels: int,
    kernel_size: Union[int, Tuple[int, ...]],
    stride=1,
    padding=0,
    output_padding=0,
    dilation=1,
    groups=1,
    bias=True,
)

Bases: Conv2d

Applies a 2D transposed convolution operator over an input image.

See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d

conv = nn.ConvTranspose2d(1, 1, 3)
t = Tensor.rand(1, 1, 4, 4)
print(t.numpy())
[[[[4.0614e-02 5.9807e-02 8.4977e-02 4.8934e-01]
   [5.4422e-01 5.7134e-01 6.1290e-01 4.9531e-01]
   [2.8514e-01 9.2525e-01 4.2239e-01 7.4913e-01]
   [3.8660e-04 5.5375e-01 8.3814e-02 9.3957e-01]]]]
t = conv(t)
print(t.numpy())
[[[[ 0.0538  0.0577  0.0482  0.0934  0.0387 -0.1123]
   [ 0.1107  0.1229 -0.0643 -0.1557 -0.2484 -0.1945]
   [ 0.004  -0.0248 -0.2613 -0.4306 -0.3285 -0.2904]
   [ 0.0295 -0.124  -0.2784 -0.4334 -0.2438 -0.3948]
   [ 0.0611 -0.0307 -0.1828 -0.2436 -0.2618 -0.123 ]
   [ 0.0487  0.0729 -0.0181  0.0662 -0.0728  0.0267]]]]

Source code in tinygrad/nn/__init__.py
148
149
150
151
152
153
def __init__(self, in_channels:int, out_channels:int, kernel_size:Union[int, Tuple[int, ...]], stride=1, padding=0, output_padding=0,
             dilation=1, groups=1, bias=True):
  # Reuse the Conv2d setup, then swap in a weight laid out as (in_channels, out_channels//groups, *kernel_size).
  super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
  bound = 1 / math.sqrt(in_channels * prod(self.kernel_size))
  transposed_shape = (in_channels, out_channels//groups, *self.kernel_size)
  self.weight = Tensor.uniform(*transposed_shape, low=-bound, high=bound)
  self.output_padding = output_padding

Linear ¤

Linear(in_features: int, out_features: int, bias=True)

Applies a linear transformation to the incoming data.

See: https://pytorch.org/docs/stable/generated/torch.nn.Linear

lin = nn.Linear(3, 4)
t = Tensor.rand(2, 3)
print(t.numpy())
[[0.4849 0.3759 0.165 ]
 [0.3368 0.6523 0.2948]]
t = lin(t)
print(t.numpy())
[[ 0.1466 -0.7808  0.5711 -0.2258]
 [ 0.2718 -0.8786  0.66   -0.1921]]

Source code in tinygrad/nn/__init__.py
175
176
177
178
def __init__(self, in_features:int, out_features:int, bias=True):
  # weight (out_features, in_features) and the optional bias (out_features,) share the same uniform range
  limit = 1 / math.sqrt(in_features)
  self.weight = Tensor.uniform(out_features, in_features, low=-limit, high=limit)
  if bias:
    self.bias = Tensor.uniform(out_features, low=-limit, high=limit)
  else:
    self.bias = None

GroupNorm ¤

GroupNorm(
    num_groups: int,
    num_channels: int,
    eps=1e-05,
    affine=True,
)

Applies Group Normalization over a mini-batch of inputs.

norm = nn.GroupNorm(2, 12)
t = Tensor.rand(2, 12, 4, 4) * 2 + 1
print(t.mean().item(), t.std().item())
1.980262041091919 0.5960425734519958
t = norm(t)
print(t.mean().item(), t.std().item())
-2.588363088307233e-07 1.0012905597686768

Source code in tinygrad/nn/__init__.py
200
201
202
203
def __init__(self, num_groups:int, num_channels:int, eps=1e-5, affine=True):
  # Keep the grouping parameters; per-channel scale/shift exist only when `affine` is set.
  self.num_groups = num_groups
  self.num_channels = num_channels
  self.eps = eps
  if affine:
    self.weight: Optional[Tensor] = Tensor.ones(num_channels)
    self.bias: Optional[Tensor] = Tensor.zeros(num_channels)
  else:
    self.weight = None
    self.bias = None

InstanceNorm ¤

InstanceNorm(num_features: int, eps=1e-05, affine=True)

Applies Instance Normalization over a mini-batch of inputs.

norm = nn.InstanceNorm(3)
t = Tensor.rand(2, 3, 4, 4) * 2 + 1
print(t.mean().item(), t.std().item())
1.9248237609863281 0.5711491107940674
t = norm(t)
print(t.mean().item(), t.std().item())
-3.397539316551956e-08 1.005232572555542

Source code in tinygrad/nn/__init__.py
231
232
233
234
def __init__(self, num_features:int, eps=1e-5, affine=True):
  # Keep the feature count and epsilon; per-feature scale/shift exist only when `affine` is set.
  self.num_features = num_features
  self.eps = eps
  if affine:
    self.weight: Optional[Tensor] = Tensor.ones(num_features)
    self.bias: Optional[Tensor] = Tensor.zeros(num_features)
  else:
    self.weight = None
    self.bias = None

LayerNorm ¤

LayerNorm(
    normalized_shape: Union[int, Tuple[int, ...]],
    eps=1e-05,
    elementwise_affine=True,
)

Applies Layer Normalization over a mini-batch of inputs.

norm = nn.LayerNorm(3)
t = Tensor.rand(2, 5, 3) * 2 + 1
print(t.mean().item(), t.std().item())
1.9588143825531006 0.6108710765838623
t = norm(t)
print(t.mean().item(), t.std().item())
-4.89320619578848e-08 1.0169694423675537

Source code in tinygrad/nn/__init__.py
258
259
260
261
def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps=1e-5, elementwise_affine=True):
  # a bare int is shorthand for a one-dimensional shape
  dims = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
  self.normalized_shape: Tuple[int, ...] = dims
  # trailing axes covered by the shape: -1, -2, ..., -len(dims)
  self.axis = tuple(range(-1, -len(dims)-1, -1))
  self.eps = eps
  self.elementwise_affine = elementwise_affine
  # elementwise scale (ones) and shift (zeros) have exactly the normalized shape
  if elementwise_affine:
    self.weight, self.bias = Tensor.ones(*dims), Tensor.zeros(*dims)
  else:
    self.weight, self.bias = None, None

LayerNorm2d ¤

LayerNorm2d(
    normalized_shape: Union[int, Tuple[int, ...]],
    eps=1e-05,
    elementwise_affine=True,
)

Bases: LayerNorm

Applies Layer Normalization over a mini-batch of 2D inputs.

See: LayerNorm

norm = nn.LayerNorm2d(3)
t = Tensor.rand(2, 3, 4, 4) * 2 + 1
print(t.mean().item(), t.std().item())
2.0245089530944824 0.5549166798591614
t = norm(t)
print(t.mean().item(), t.std().item())
-2.539678689572611e-07 1.005126714706421

Source code in tinygrad/nn/__init__.py
258
259
260
261
def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps=1e-5, elementwise_affine=True):
  # a bare int is shorthand for a one-dimensional shape
  dims = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
  self.normalized_shape: Tuple[int, ...] = dims
  # trailing axes covered by the shape: -1, -2, ..., -len(dims)
  self.axis = tuple(range(-1, -len(dims)-1, -1))
  self.eps = eps
  self.elementwise_affine = elementwise_affine
  # elementwise scale (ones) and shift (zeros) have exactly the normalized shape
  if elementwise_affine:
    self.weight, self.bias = Tensor.ones(*dims), Tensor.zeros(*dims)
  else:
    self.weight, self.bias = None, None

RMSNorm ¤

RMSNorm(dim: int, eps=1e-06)

Applies Root Mean Square Normalization to input.

norm = nn.RMSNorm(4)
t = Tensor.arange(12, dtype=dtypes.float).reshape(3, 4)
print(t.numpy())
[[ 0.  1.  2.  3.]
 [ 4.  5.  6.  7.]
 [ 8.  9. 10. 11.]]
print(norm(t).numpy())
[[0.     0.5345 1.069  1.6036]
 [0.7127 0.8909 1.069  1.2472]
 [0.8363 0.9409 1.0454 1.15  ]]

Source code in tinygrad/nn/__init__.py
303
def __init__(self, dim:int, eps=1e-6):
  # epsilon plus a weight of ones with `dim` elements
  self.eps = eps
  self.weight = Tensor.ones(dim)

Embedding ¤

Embedding(vocab_size: int, embed_size: int)

A simple lookup table that stores embeddings of a fixed dictionary and size.

See: https://pytorch.org/docs/stable/generated/torch.nn.Embedding

emb = nn.Embedding(10, 3)
print(emb(Tensor([1, 2, 3, 1])).numpy())
[[-0.511  -0.1045  0.5687]
 [ 0.2601  0.6305 -0.1365]
 [ 0.6639  0.5806  0.1463]
 [-0.511  -0.1045  0.5687]]
Source code in tinygrad/nn/__init__.py
320
321
def __init__(self, vocab_size:int, embed_size:int):
  # the lookup table is a (vocab_size, embed_size) tensor with Glorot-uniform initialisation
  table = Tensor.glorot_uniform(vocab_size, embed_size)
  self.vocab_sz = vocab_size
  self.embed_sz = embed_size
  self.weight = table

LSTMCell ¤

LSTMCell(
    input_size: int, hidden_size: int, bias: bool = True
)

A long short-term memory (LSTM) cell.

Parameters:

  • input_size (int) –

    The number of expected features in the input x

  • hidden_size (int) –

    The number of features in the hidden state h

  • bias (bool, default: True ) –

    If False, then the layer does not use bias weights b_ih and b_hh

Source code in tinygrad/nn/__init__.py
339
340
341
342
343
def __init__(self, input_size:int, hidden_size:int, bias:bool=True):
  # Both weight matrices have hidden_size*4 rows and are drawn from U(-bound, bound).
  bound = 1.0 / math.sqrt(hidden_size)
  rows = hidden_size*4
  self.weight_ih = Tensor.uniform(rows, input_size, low=-bound, high=bound)
  self.weight_hh = Tensor.uniform(rows, hidden_size, low=-bound, high=bound)
  # biases start at zero, or are absent altogether when `bias` is false
  if bias:
    self.bias_ih, self.bias_hh = Tensor.zeros(rows), Tensor.zeros(rows)
  else:
    self.bias_ih, self.bias_hh = None, None

Optimizers¤

SGD ¤

SGD(
    params: List[Tensor],
    lr=0.001,
    momentum=0.0,
    weight_decay=0.0,
    nesterov=False,
    classic=False,
)

Stochastic Gradient Descent (SGD) optimizer with optional momentum and weight decay.

classic is a boolean flag that selects the momentum update rule: the classic rule when true, the popular rule otherwise.

Source code in tinygrad/nn/optim.py
57
58
59
60
61
62
63
64
65
def SGD(params: List[Tensor], lr=0.001, momentum=0.0, weight_decay=0.0, nesterov=False, classic=False):
  """
  Stochastic Gradient Descent (SGD) optimizer with optional momentum and weight decay.

  `classic` is a boolean flag that determines whether to use the popular momentum update rule or the classic momentum update rule.

  The optimizer returned is a `LARS` instance created with `tcoef=0.0`.

  - Described: https://paperswithcode.com/method/sgd
  """
  return LARS(params, lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov, classic=classic, tcoef=0.0)

LARS ¤

LARS(
    params: List[Tensor],
    lr=0.001,
    momentum=0.9,
    weight_decay=0.0001,
    nesterov=False,
    classic=True,
    tcoef=0.001,
)

Bases: Optimizer

Layer-wise Adaptive Rate Scaling (LARS) optimizer with optional momentum and weight decay.

Source code in tinygrad/nn/optim.py
74
75
76
77
def __init__(self, params:List[Tensor], lr=0.001, momentum=0.9, weight_decay=1e-4, nesterov=False, classic=True, tcoef=0.001):
  super().__init__(params, lr)
  self.momentum, self.wd = momentum, weight_decay
  self.nesterov, self.classic, self.tcoef = nesterov, classic, tcoef
  # one zero buffer per parameter (same shape, dtype and device), allocated only when momentum is non-zero
  if self.momentum:
    self.b = [Tensor.zeros(*p.shape, dtype=p.dtype, device=p.device, requires_grad=False) for p in self.params]
  else:
    self.b = []

AdamW ¤

AdamW(
    params: List[Tensor],
    lr=0.001,
    b1=0.9,
    b2=0.999,
    eps=1e-08,
    weight_decay=0.01,
)

AdamW optimizer with optional weight decay.

Source code in tinygrad/nn/optim.py
102
103
104
105
106
107
108
109
def AdamW(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.01):
  """
  AdamW optimizer with optional weight decay.

  The optimizer returned is a `LAMB` instance created with `adam=True`.

  - Described: https://paperswithcode.com/method/adamw
  - Paper: https://arxiv.org/abs/1711.05101v3
  """
  return LAMB(params, lr=lr, b1=b1, b2=b2, eps=eps, weight_decay=weight_decay, adam=True)

Adam ¤

Adam(
    params: List[Tensor],
    lr=0.001,
    b1=0.9,
    b2=0.999,
    eps=1e-08,
)

Adam optimizer.

Source code in tinygrad/nn/optim.py
110
111
112
113
114
115
116
117
def Adam(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
  """
  Adam optimizer.

  The optimizer returned is a `LAMB` instance created with `adam=True` and a weight decay of 0.0.

  - Described: https://paperswithcode.com/method/adam
  - Paper: https://arxiv.org/abs/1412.6980
  """
  return LAMB(params, lr=lr, b1=b1, b2=b2, eps=eps, weight_decay=0.0, adam=True)

LAMB ¤

LAMB(
    params: List[Tensor],
    lr=0.001,
    b1=0.9,
    b2=0.999,
    eps=1e-06,
    weight_decay=0.0,
    adam=False,
)

Bases: Optimizer

LAMB optimizer with optional weight decay.

Source code in tinygrad/nn/optim.py
126
127
128
129
130
131
def __init__(self, params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, adam=False):
  super().__init__(params, lr)
  self.b1, self.b2 = b1, b2
  self.eps, self.wd, self.adam = eps, weight_decay, adam

  # all optimizer state is float32, contiguous and excluded from autograd
  def _scalar_one(): return Tensor.ones((1,), dtype=dtypes.float32, device=self.device, requires_grad=False).contiguous()
  def _zeros_per_param():
    return [Tensor.zeros(*p.shape, dtype=dtypes.float32, device=p.device, requires_grad=False).contiguous() for p in self.params]

  # two independent one-element tensors, both starting at 1
  self.b1_t = _scalar_one()
  self.b2_t = _scalar_one()
  # two independent lists of zero buffers, one buffer per parameter in each
  self.m = _zeros_per_param()
  self.v = _zeros_per_param()

Load/Save¤

safe_load ¤

safe_load(fn: Union[Tensor, str]) -> Dict[str, Tensor]

Loads a .safetensor file from disk, returning the state_dict.

state_dict = nn.state.safe_load("test.safetensor")
Source code in tinygrad/nn/state.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
  """
  Loads a .safetensor file from disk, returning the state_dict.

  Each header entry other than `__metadata__` becomes one tensor: the byte range given by its `data_offsets`
  (counted from the end of the 8-byte length prefix and the JSON header) is bitcast to its dtype and reshaped.

  ```python
  state_dict = nn.state.safe_load("test.safetensor")
  ```
  """
  t, json_len, metadata = safe_load_metadata(fn)
  data_start = 8 + json_len  # tensor bytes begin right after the length prefix and the JSON header
  state_dict = {}
  for name, info in metadata.items():
    if name == "__metadata__": continue
    dtype = safe_dtypes[info['dtype']]
    begin, end = info['data_offsets'][0], info['data_offsets'][1]
    state_dict[name] = t[data_start+begin:data_start+end].bitcast(dtype).reshape(info['shape'])
  return state_dict

safe_save ¤

safe_save(
    tensors: Dict[str, Tensor],
    fn: str,
    metadata: Optional[Dict[str, Any]] = None,
)

Saves a state_dict to disk in a .safetensor file with optional metadata.

t = Tensor([1, 2, 3])
nn.state.safe_save({'t':t}, "test.safetensor")
Source code in tinygrad/nn/state.py
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any]]=None):
  """
  Saves a state_dict to disk in a .safetensor file with optional metadata.

  The file is laid out as an 8-byte header length, the JSON header (space-padded to a multiple of 8 characters),
  then the tensor bytes in dictionary order. Any existing file at `fn` is removed first.

  ```python
  t = Tensor([1, 2, 3])
  nn.state.safe_save({'t':t}, "test.safetensor")
  ```
  """
  headers = {}
  if metadata: headers['__metadata__'] = metadata
  total_bytes = 0
  for name, tensor in tensors.items():
    size = tensor.nbytes()
    headers[name] = {'dtype': inverse_safe_dtypes[tensor.dtype], 'shape': list(tensor.shape), 'data_offsets':[total_bytes, total_bytes+size]}
    total_bytes += size
  header_json = json.dumps(headers, separators=(',', ':'))
  # pad the header with spaces so the tensor data starts on an 8-byte boundary
  header_json += "\x20"*((8-len(header_json)%8)%8)
  header_len = len(header_json)
  pathlib.Path(fn).unlink(missing_ok=True)
  out = Tensor.empty(8+header_len+total_bytes, dtype=dtypes.uint8, device=f"disk:{fn}")
  out[0:8].bitcast(dtypes.int64).assign([header_len])
  out[8:8+header_len].assign(list(header_json.encode('utf-8')))
  # read the freshly written header back to obtain a view per tensor, then fill each view with its data
  for name, view in safe_load(out).items(): view.assign(tensors[name])

get_state_dict ¤

get_state_dict(
    obj, prefix: str = "", tensor_type=Tensor
) -> Dict[str, Tensor]

Returns a state_dict of the object, with optional prefix.

class Net:
  def __init__(self):
    self.l1 = nn.Linear(4, 5)
    self.l2 = nn.Linear(5, 6)

net = Net()
print(nn.state.get_state_dict(net).keys())
dict_keys(['l1.weight', 'l1.bias', 'l2.weight', 'l2.bias'])
Source code in tinygrad/nn/state.py
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> Dict[str, Tensor]:
  """
  Returns a state_dict of the object, with optional prefix.

  The object is walked recursively through attributes, namedtuple fields, dict values and list/tuple items.
  Every `tensor_type` instance found is stored under its dotted path; anything else contributes nothing.

  ```python exec="true" source="above" session="tensor" result="python"
  class Net:
    def __init__(self):
      self.l1 = nn.Linear(4, 5)
      self.l2 = nn.Linear(5, 6)

  net = Net()
  print(nn.state.get_state_dict(net).keys())
  ```
  """
  # leaf: the accumulated prefix, minus its trailing dot, is the key
  if isinstance(obj, tensor_type): return {prefix.strip('.'):obj}
  # wrappers: convert to a plain dict and walk that instead
  if hasattr(obj, '_asdict'): return get_state_dict(obj._asdict(), prefix, tensor_type)  # namedtuple
  if isinstance(obj, OrderedDict): return get_state_dict(dict(obj), prefix, tensor_type)
  if hasattr(obj, '__dict__'): return get_state_dict(obj.__dict__, prefix, tensor_type)
  # containers: children are labelled by index (list/tuple) or by key (dict)
  if isinstance(obj, (list, tuple)): children = enumerate(obj)
  elif isinstance(obj, dict): children = obj.items()
  else: return {}
  collected = {}
  for label, child in children:
    collected.update(get_state_dict(child, f"{prefix}{str(label)}.", tensor_type))
  return collected

get_parameters ¤

get_parameters(obj) -> List[Tensor]
class Net:
  def __init__(self):
    self.l1 = nn.Linear(4, 5)
    self.l2 = nn.Linear(5, 6)

net = Net()
print(len(nn.state.get_parameters(net)))
4
Source code in tinygrad/nn/state.py
87
88
89
90
91
92
93
94
95
96
97
98
99
def get_parameters(obj) -> List[Tensor]:
  """
  Returns the values of `get_state_dict(obj)` as a list, in the state_dict's order.

  ```python exec="true" source="above" session="tensor" result="python"
  class Net:
    def __init__(self):
      self.l1 = nn.Linear(4, 5)
      self.l2 = nn.Linear(5, 6)

  net = Net()
  print(len(nn.state.get_parameters(net)))
  ```
  """
  state = get_state_dict(obj)
  return [*state.values()]

load_state_dict ¤

load_state_dict(
    model,
    state_dict: Dict[str, Tensor],
    strict=True,
    verbose=True,
    consume=False,
) -> None

Loads a state_dict into a model.

class Net:
  def __init__(self):
    self.l1 = nn.Linear(4, 5)
    self.l2 = nn.Linear(5, 6)

net = Net()
state_dict = nn.state.get_state_dict(net)
nn.state.load_state_dict(net, state_dict)
Source code in tinygrad/nn/state.py
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
def load_state_dict(model, state_dict:Dict[str, Tensor], strict=True, verbose=True, consume=False) -> None:
  """
  Loads a state_dict into a model.

  Every tensor found by `get_state_dict(model)` is replaced by the `state_dict` entry with the same key and realized.
  The source is moved to the destination's device first, or, when the destination's lazydata is a `MultiLazyBuffer`
  and the source's is not, sharded with the destination's device and axis.

  - `strict`: when false, model tensors with no entry in `state_dict` are skipped; when true they are looked up regardless.
  - `verbose`: when false, the progress bar is disabled.
  - `consume`: when true, each entry is deleted from `state_dict` once it has been loaded.

  ```python
  class Net:
    def __init__(self):
      self.l1 = nn.Linear(4, 5)
      self.l2 = nn.Linear(5, 6)

  net = Net()
  state_dict = nn.state.get_state_dict(net)
  nn.state.load_state_dict(net, state_dict)
  ```
  """
  mem_before = GlobalCounters.mem_used
  def _summary(et_ns):
    # bytes per nanosecond equals GB per second
    loaded = GlobalCounters.mem_used - mem_before
    return f", {loaded/1e9:.2f} GB loaded at {loaded/et_ns:.2f} GB/s"

  with Timing("loaded weights in ", _summary):
    targets = get_state_dict(model)
    if DEBUG >= 1 and len(state_dict) > len(targets):
      print("WARNING: unused weights in state_dict", sorted(list(state_dict.keys() - targets.keys())))
    progress = tqdm(targets.items(), disable=CI or not verbose)
    for name, dest in progress:
      progress.desc = f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, {name:50s}: "
      if not strict and name not in state_dict:
        if DEBUG >= 1: print(f"WARNING: not loading {name}")
        continue
      dest_buf = dest.lazydata
      if not isinstance(dest_buf, MultiLazyBuffer): src = state_dict[name].to(dest.device)
      elif isinstance(state_dict[name].lazydata, MultiLazyBuffer): src = state_dict[name]
      else: src = state_dict[name].shard(dest_buf.device, dest_buf.axis)
      dest.replace(src).realize()
      if consume: del state_dict[name]

torch_load ¤

torch_load(fn: str) -> Dict[str, Tensor]

Loads a torch .pth file from disk.

state_dict = nn.state.torch_load("test.pth")
Source code in tinygrad/nn/state.py
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
def torch_load(fn:str) -> Dict[str, Tensor]:
  """
  Loads a torch .pth file from disk.

  Three layouts are handled: a zip archive, a tar archive, and otherwise a plain sequence of pickles.
  Tensors are built as slices of a `disk:` tensor spanning the whole file; tensors stored with permuted
  strides are additionally moved off the disk tensor so they can be permuted.

  WARNING: the file is parsed with `pickle`. The module whitelist used here exists for speed, not for
  security, so only load files from sources you trust.

  ```python
  state_dict = nn.state.torch_load("test.pth")
  ```
  """
  t = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")

  # byte offset / byte length of every storage in the file, keyed by storage id
  offsets: Dict[Union[str, int], int] = {}
  lens: Dict[Union[str, int], int] = {}
  def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad=None, backward_hooks=None, metadata=None):
    # storage[1] is the dtype, storage[2] the storage id, storage[4] the element count
    lens[storage[2]] = storage[4] * storage[1].itemsize
    # offset not known yet: only the length is recorded and no tensor is built
    if storage[2] not in offsets: return None
    byte_offset = offsets[storage[2]]+storage_offset*storage[1].itemsize
    ret = t[byte_offset:byte_offset+prod(size)*storage[1].itemsize].bitcast(storage[1])

    # 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
    shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]
    permute_indexes = [len(shape_strides)-1-y for y in argsort([x[1] for x in shape_strides])]
    if tuple(permute_indexes) != tuple(range(len(permute_indexes))):
      intermediate_shape = tuple([shape_strides[x][0] for x in argsort(permute_indexes)])
      assert tuple([shape_strides[i][1] for i in argsort(permute_indexes)]) == strides_for_shape(intermediate_shape), "nonpermutable strides"
      if DEBUG >= 3: print(f"WARNING: this torch load is slow. CLANG to permute {intermediate_shape} with {permute_indexes}")
      assert storage[1] != dtypes.bfloat16, "can't CLANG permute BF16"
      # TODO: find a nice way to support all shapetracker on disktensors
      ret = ret.to(None).reshape(intermediate_shape).permute(permute_indexes)

    return ret.reshape(size)

  # stand-in for torch's Parameter: keeps only the first element of the pickled state
  class Parameter:
    def __setstate__(self, state): self.tensor = state[0]

  deserialized_objects: Dict[str, Any] = {}
  # torch names the pickle may reference, mapped to local replacements
  intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16,
               "IntStorage": dtypes.int32, "BoolStorage": dtypes.bool,
               "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2, "FloatTensor": None, "Parameter": Parameter}
  whitelist = {"torch", "collections", "numpy", "_codecs"}  # NOTE: this is not for security, only speed
  class Dummy: pass
  class TorchPickle(pickle.Unpickler):
    def find_class(self, module, name):
      module_root = module.split(".")[0]
      if module_root not in whitelist:
        if DEBUG >= 2: print(f"WARNING: returning Dummy for {module} {name}")
        return Dummy
      return intercept[name] if module_root == "torch" else super().find_class(module, name)
    def persistent_load(self, pid): return deserialized_objects.get(pid, pid)

  if zipfile.is_zipfile(fn):
    # the archive handle is only needed to find the data offsets and read the pickle (tensors are slices of `t`),
    # so it is closed on exit instead of being left open
    with zipfile.ZipFile(fn, 'r') as myzip:
      base_name = myzip.namelist()[0].split('/', 1)[0]
      for n in myzip.namelist():
        if n.startswith(f'{base_name}/data/'):
          with myzip.open(n) as myfile:
            offsets[n.split("/")[-1]] = myfile._orig_compress_start # type: ignore
      with myzip.open(f'{base_name}/data.pkl') as myfile:
        return TorchPickle(myfile).load()
  elif tarfile.is_tarfile(fn):
    with tarfile.open(fn, "r") as tar:
      storages_offset = tar.getmember('storages').offset_data
      f = unwrap(tar.extractfile('storages'))
      for _ in range(TorchPickle(f).load()):  # num_storages
        (key, _, storage_type), sz = TorchPickle(f).load(), struct.unpack('<q', f.read(8))[0]
        offsets[key] = storages_offset + f.tell()
        f.seek(sz*storage_type.itemsize, 1)
      f = unwrap(tar.extractfile('tensors'))
      # NOTE(review): `storage_type` below is whatever the last storage in the loop above used - confirm this is intended
      for _ in range(TorchPickle(f).load()):  # num_tensors
        (key, storage_id, _), ndim, _ = TorchPickle(f).load(), struct.unpack('<i', f.read(4))[0], f.read(4)
        size, stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim)), struct.unpack(f'<{ndim}q', f.read(8 * ndim))
        storage_offset = struct.unpack('<q', f.read(8))[0]
        deserialized_objects[str(key)] = _rebuild_tensor_v2((None, storage_type, storage_id, None, -1), storage_offset, size, stride)
      return {k:v.tensor if isinstance(v, Parameter) else v for k,v in TorchPickle(unwrap(tar.extractfile('pickle'))).load().items()}
  else:
    with open(fn, "rb") as f:
      pkl = TorchPickle(f)
      # first pass: the main pickle is loaded only to fill `lens`; `rwd` remembers where it starts for the second pass
      _, _, _, rwd, _, ids, base_offset = pkl.load(), pkl.load(), pkl.load(), f.tell(), pkl.load(), pkl.load(), f.tell()
      for i in ids:
        # each storage sits behind an 8-byte field, so its data starts 8 bytes in
        offsets[i] = base_offset + 8
        base_offset += 8 + lens[i]
      # second pass: with the offsets known, the same pickle now yields real tensors
      f.seek(rwd)
      return TorchPickle(f).load()