UOp ¤

UOp(
    op: UOps,
    dtype: DType = dtypes.void,
    src: Tuple[UOp, ...] = tuple(),
    arg: Any = None,
)

Bases: MathTrait

Source code in tinygrad/ops.py
def __init__(self, op: UOps, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None):
  self.op, self.dtype, self.src, self.arg = op, dtype, src, arg
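To make the four fields concrete, here is a minimal sketch that builds a tiny graph. It deliberately does not import tinygrad: `MiniOps` and `MiniUOp` are hypothetical stand-ins that mirror only the constructor signature above, and dtypes are plain strings for brevity.

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Tuple

class MiniOps(Enum):  # hypothetical stand-in for UOps
    CONST = auto()
    ALU = auto()

@dataclass(frozen=True)
class MiniUOp:  # hypothetical stand-in for UOp; mirrors only the four fields
    op: MiniOps
    dtype: str = "void"
    src: Tuple["MiniUOp", ...] = ()
    arg: Any = None

# 2 + 3 expressed as a graph: the ALU node points at its inputs through src.
two = MiniUOp(MiniOps.CONST, "int", arg=2)
three = MiniUOp(MiniOps.CONST, "int", arg=3)
add = MiniUOp(MiniOps.ALU, "int", src=(two, three), arg="ADD")

def count_nodes(root: MiniUOp) -> int:
    """Count the nodes reachable from root, visiting shared nodes once."""
    seen, stack = set(), [root]
    while stack:
        node = stack.pop()
        if id(node) in seen:
            continue
        seen.add(id(node))
        stack.extend(node.src)
    return len(seen)

print(count_nodes(add))
```

The point of the sketch is only that a UOp is a node whose meaning comes from `op` and `arg`, whose result type is `dtype`, and whose inputs are the nodes in `src`.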

UOps ¤

Bases: FastEnum

SINK ¤

SINK = auto()

Holds UOps.STORE. SINK defines the AST for a Kernel.

  • dtype: dtypes.void
  • src: Tuple[UOp, ...], only global STOREs are allowed.
  • arg: Optional[KernelInfo]

Note

ScheduleItem ASTs do not have the KernelInfo arg; Kernel inserts it into the SINK later.

EXT ¤

EXT = auto()

Holds a single MetaOp. EXT UOps do not need a Kernel.

  • dtype: Output DType
  • src: Tuple[]
  • arg: (MetaOps.CUSTOM | MetaOps.COPY | MetaOps.EMPTY | MetaOps.VIEW, LazyBuffer arg)

EXPAND ¤

EXPAND = auto()

CONTRACT ¤

CONTRACT = auto()

SHAPETRACKER ¤

SHAPETRACKER = auto()

Defines the ShapeTracker for a buffer UOp: UOps.LOAD, UOps.STORE or UOps.VALID.

  • dtype: dtypes.void
  • src: Tuple[]
  • arg: ShapeTracker

SWIZZLE ¤

SWIZZLE = auto()

Swizzle inserts a movement op between a UOp and its children. Because movement ops (reshape, expand, shrink, permute, pad) are not allowed in an AST, the scheduler rewrites SWIZZLE by pushing its ShapeTracker through reduceops or elementwise ops to the edges of the graph.

This movement op can push up to the LOADs and/or down to the STOREs.

Example:

a = Tensor.empty(32, 32)
first_reduce = a.sum()
output = (a + first_reduce).sum()

first_reduce must broadcast to (32, 32) before the ADD. This is represented in UOps as:

UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
  UOp(UOps.SWIZZLE, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=(
    UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
      UOp(UOps.LOAD, dtypes.int, arg=None, src=(
        x3:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()),
        UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),
  UOp(UOps.LOAD, dtypes.int, arg=None, src=(
     x3,
    UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),))

The scheduler rewrites this by pushing the expand in SWIZZLE through the reduce, to the LOAD:

UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
-   UOp(UOps.SWIZZLE, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=(
-     UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
-       UOp(UOps.LOAD, dtypes.int, arg=None, src=(
-         x3:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()),
-         UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),
+   UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2, 3)), src=(
+     UOp(UOps.LOAD, dtypes.int, arg=None, src=(
+       x2:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()),
+       UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32, 32, 32), strides=(0, 0, 32, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),
  UOp(UOps.LOAD, dtypes.int, arg=None, src=(
-      x3,
-     UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),))
+      x2,
+     UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),)),))

Note

Pushing a SWIZZLE through a reduce changes the axis.

Note

Pushing a SWIZZLE changes the output shape of that UOp, so every adjacent node has to be reshaped to match, e.g. the reshape of the second LOAD to (32, 32, 1, 1) above.

  • dtype: Output DType
  • src: Tuple[UOp], a single UOp to swizzle.
  • arg: ShapeTracker
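The strides in the Views above carry the whole meaning of the expand, which a few lines of plain Python can illustrate. This is a sketch under simplifying assumptions: `flat_offset` is a hypothetical helper (not a tinygrad function) that maps an index to a buffer offset using strides only, ignoring masks and multi-view ShapeTrackers.

```python
from itertools import product

def flat_offset(index, strides, offset=0):
    """Map an N-d index to a flat buffer offset (hypothetical helper, no masks)."""
    return offset + sum(i * s for i, s in zip(index, strides))

# SWIZZLE view: shape (32, 32), strides (0, 0). Every index reads offset 0,
# so the single reduced value is broadcast to all 32*32 positions.
expanded = {flat_offset(idx, (0, 0)) for idx in product(range(32), repeat=2)}

# Contiguous view: shape (32, 32), strides (32, 1). Every index is distinct.
contiguous = {flat_offset(idx, (32, 1)) for idx in product(range(32), repeat=2)}

# Rewritten LOAD: shape (32, 32, 32, 32), strides (0, 0, 32, 1). For any fixed
# output position (first two axes), the last two axes sweep the whole buffer,
# which is why the reduce moves from axes (0, 1) to axes (2, 3).
swept = {flat_offset((5, 7, i, j), (0, 0, 32, 1))
         for i, j in product(range(32), repeat=2)}

print(len(expanded), len(contiguous), len(swept))
```

A stride of 0 on an axis means "moving along this axis does not move in memory", which is how an expand is expressed without copying data.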

DEFINE_GLOBAL ¤

DEFINE_GLOBAL = auto()

DEFINE_VAR ¤

DEFINE_VAR = auto()

DEFINE_LOCAL ¤

DEFINE_LOCAL = auto()

DEFINE_ACC ¤

DEFINE_ACC = auto()

VCONST ¤

VCONST = auto()

CONST ¤

CONST = auto()

Defines a single scalar constant value.

  • dtype: The scalar DType of the value.

  • src: Tuple[]

  • arg: The value.

VALID ¤

VALID = auto()

This is the first argument in a masked CONST.

  • dtype: dtypes.bool
  • src: Tuple[UOp]
    • UOps.SHAPETRACKER
  • arg: None

A masked CONST is defined as valid.where(value, 0).
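A small sketch of what `valid.where(value, 0)` means element by element, in plain Python rather than tinygrad. `masked_const_1d` is a hypothetical helper, and treating the mask as a single `(start, end)` bound on a 1-D shape is an assumption made for illustration.

```python
def masked_const_1d(value, size, bounds):
    """Return `value` inside [start, end) and 0 elsewhere (hypothetical helper)."""
    start, end = bounds
    valid = [start <= i < end for i in range(size)]  # plays the role of VALID
    return [value if v else 0 for v in valid]        # valid.where(value, 0)

# A constant 4 of length 3, padded by one element on each side.
padded = masked_const_1d(4, 5, (1, 4))
print(padded)
```

With these inputs the three interior positions hold the constant and the two padded positions hold 0.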

SPECIAL ¤

SPECIAL = auto()

NOOP ¤

NOOP = auto()

GEP ¤

GEP = auto()

CAST ¤

CAST = auto()
  • dtype: The casted scalar DType
  • src: Tuple[UOp]
  • arg: None

BITCAST ¤

BITCAST = auto()
  • dtype: The bitcasted scalar DType
  • src: Tuple[UOp]
  • arg: None

VECTORIZE ¤

VECTORIZE = auto()
  • dtype: The upcasted vector DType
  • src: Tuple[UOp, ...]
  • arg: None

Note

Length of sources must match dtype.count

ALU ¤

ALU = auto()
  • dtype: Output DType
  • src: Tuple[UOp] | Tuple[UOp, UOp] | Tuple[UOp, UOp, UOp]
  • arg: UnaryOps | BinaryOps | TernaryOps

REDUCE ¤

REDUCE = auto()

REDUCE_AXIS ¤

REDUCE_AXIS = auto()
  • dtype: Output DType
  • src: Tuple[UOp], the input to reduce.
  • arg: (BinaryOps.ADD | BinaryOps.MUL | BinaryOps.MAX, Tuple[int, ...])
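The `arg` pairs a reduction operator with the axes to reduce. The sketch below shows that pairing on a 2-D nested list in plain Python. `reduce_axis_2d` and the string keys are hypothetical, and keeping each reduced axis as size 1 is an assumption, chosen to be consistent with the (32, 32, 1, 1) shape in the SWIZZLE example.

```python
from functools import reduce
import operator

# Hypothetical string keys standing in for BinaryOps.ADD / MUL / MAX.
REDUCERS = {"ADD": operator.add, "MUL": operator.mul, "MAX": max}

def reduce_axis_2d(rows, op, axes):
    """Reduce a non-empty 2-D list over `axes`, keeping reduced axes as size 1."""
    fn = REDUCERS[op]
    out = rows
    if 1 in axes:  # collapse each row to a single element
        out = [[reduce(fn, row)] for row in out]
    if 0 in axes:  # collapse each column to a single element
        out = [[reduce(fn, col) for col in zip(*out)]]
    return out

data = [[1, 2, 3],
        [4, 5, 6]]
print(reduce_axis_2d(data, "ADD", (0, 1)))
print(reduce_axis_2d(data, "MAX", (1,)))
print(reduce_axis_2d(data, "MUL", (0,)))
```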

WMMA ¤

WMMA = auto()

LOAD ¤

LOAD = auto()
  • dtype: Output DType
  • src:

The scheduler and Kernel create LOADs with a SHAPETRACKER UOp in src.

  • Normal LOAD: Tuple[UOp, UOp]

    • Buffer UOp UOps.DEFINE_GLOBAL.
    • SHAPETRACKER UOp.
  • Local LOAD: Tuple[UOp, UOp, UOp]

    • Buffer UOp UOps.DEFINE_LOCAL.
    • SHAPETRACKER UOp.
    • Local UOps.STORE to the same local buffer. We will barrier this later.

The Lowerer replaces the SHAPETRACKER with an indexing UOp and gates the LOAD if needed.

  • Normal LOAD: Tuple[UOp, UOp]
    • Buffer UOp UOps.DEFINE_GLOBAL.
    • Indexing UOp, can only return dtypes.int32.
  • Gated LOAD: Tuple[UOp, UOp, UOp, UOp]
    • Buffer UOp UOps.DEFINE_GLOBAL.
    • Indexing UOp, can only return dtypes.int32.
    • Gate UOp, can only return dtypes.bool.
    • Value if gate is False, can only be a UOps.CONST with arg 0, 0.0 or False.
  • Barriered LOAD: Tuple[UOp, UOp, UOp, UOp]
    • Buffer UOp UOps.DEFINE_LOCAL.
    • Indexing UOp, can only return dtypes.int32.
    • Gate UOp, can only return dtypes.bool.
    • Barrier UOp UOps.BARRIER.
  • arg: None
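The gated form reads as "load if the gate is true, otherwise use the constant". A minimal plain-Python sketch of that behaviour follows; `gated_load` is a hypothetical helper and the padded read is an illustrative scenario, not taken from tinygrad's lowering code.

```python
def gated_load(buf, idx, gate, alt=0):
    """Return buf[idx] when gate is True, else the alternative value (hypothetical)."""
    # The index is only dereferenced when the gate holds, so an out-of-range
    # idx is harmless as long as the gate is False for it.
    return buf[idx] if gate else alt

buf = [10, 20, 30]

# Read a 5-element padded view of a 3-element buffer: positions 0 and 4 fall
# outside the buffer, so their gate is False and they yield 0.
out = []
for i in range(5):
    idx = i - 1
    gate = 0 <= idx < len(buf)
    out.append(gated_load(buf, idx, gate))
print(out)
```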

STORE ¤

STORE = auto()
  • dtype: dtypes.void
  • src:

Similar to LOAD, the scheduler and Kernel create STOREs with a SHAPETRACKER UOp in src:

  • Buffer UOp UOps.DEFINE_GLOBAL or UOps.DEFINE_LOCAL.
  • SHAPETRACKER UOp.
  • Value to store.

The Lowerer replaces the SHAPETRACKER with an indexing UOp and gates the STORE if needed.

  • Normal STORE: Tuple[UOp, UOp, UOp]
    • Buffer UOp UOps.DEFINE_GLOBAL or UOps.DEFINE_LOCAL.
    • Indexing UOp, can only return dtypes.int32.
    • Value to store.
  • Gated STORE: Tuple[UOp, UOp, UOp, UOp]
    • Buffer UOp UOps.DEFINE_GLOBAL or UOps.DEFINE_LOCAL.
    • Indexing UOp, can only return dtypes.int32.
    • Value to store.
    • Gate UOp, can only return dtypes.bool. We rewrite this to an IF block in the end.
  • arg: None
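A gated STORE writes only when its gate is true, which is why it can later become an IF block around the store. The plain-Python sketch below illustrates that semantics only; `gated_store` is a hypothetical helper, and the loop over `lidx` simulates local threads sequentially as a simplifying assumption.

```python
def gated_store(buf, idx, value, gate):
    """Write value to buf[idx] only when gate is True (hypothetical helper)."""
    if gate:  # corresponds to the IF block the gate is rewritten to
        buf[idx] = value

data0 = [0]
partials = [1, 2, 3, 4]

# Every simulated thread computes the same total, but only the thread whose
# gate holds is allowed to write it out.
for lidx in range(4):
    total = sum(partials)
    gated_store(data0, 0, total, gate=(lidx == 0))
print(data0)
```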

ASSIGN ¤

ASSIGN = auto()

BARRIER ¤

BARRIER = auto()

Inserts a warp sync between local stores and local loads.

  • dtype: dtypes.void
  • src: Tuple[UOp, ...], only local STOREs are allowed.
  • arg: None

IF ¤

IF = auto()

Gates a single STORE to global memory. The IF block could also contain additional UOps the STORE depends on.

  • dtype: dtypes.void
  • src: Tuple[UOp, UOp]
    • Gate UOp, can only return dtypes.bool
    • The second UOp starts the gate block; all of its children are gated until the final STORE.
  • arg: None

For example, a local reduce must only run on one thread.

The STORE's IF gate:

UOp(UOps.IF, src=(
  UOp(UOps.ALU, dtypes.bool, (...), BinaryOps.CMPNE),
  UOp(UOps.BARRIER, dtypes.void, (...))))

The kernel:

barrier(CLK_LOCAL_MEM_FENCE);
if (lidx0!=1) {
  int acc1 = 0;
  for (int ridx1 = 0; ridx1 < 16; ridx1++) {
    int val1 = temp1[ridx1];
    acc1 = (acc1+val1);
  }
  data0[0] = acc1;
}

RANGE ¤

RANGE = auto()

ENDRANGE ¤

ENDRANGE = auto()

ENDIF ¤

ENDIF = auto()