jax.experimental.pallas.program_id
- jax.experimental.pallas.program_id(axis)
Returns the kernel execution position along the given axis of the grid.
For example, with a 2D grid, in the kernel instance executing at grid coordinates (1, 2), program_id(axis=0) returns 1 and program_id(axis=1) returns 2.
The returned value is an array of shape () and dtype int32.
- Parameters:
axis (int) – the grid axis along which to report the position of the current program.
- Return type:
jax_typing.Array
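The following is a minimal sketch of how program_id might be used, not a canonical example from the library. It assumes a (2, 3) grid, a one-element output block per program, and interpret mode (interpret=True) so the kernel can run without a GPU or TPU. The kernel name and the encoding i * 10 + j are illustrative choices only.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def record_position_kernel(o_ref):
    # Hypothetical kernel: each program instance queries its own grid position.
    i = pl.program_id(axis=0)  # scalar int32 position along grid axis 0
    j = pl.program_id(axis=1)  # scalar int32 position along grid axis 1
    # Encode both coordinates in one value and fill this program's output block.
    o_ref[...] = jnp.full(o_ref.shape, i * 10 + j, dtype=o_ref.dtype)


grid = (2, 3)  # assumed grid shape, chosen for illustration

out = pl.pallas_call(
    record_position_kernel,
    out_shape=jax.ShapeDtypeStruct(grid, jnp.int32),
    grid=grid,
    # Each program writes the single output element matching its grid coordinates.
    out_specs=pl.BlockSpec(block_shape=(1, 1), index_map=lambda i, j: (i, j)),
    interpret=True,  # assumption: run in interpret mode so no accelerator is required
)()

print(out)
```

Under these assumptions, the entry at row i and column j of out holds i * 10 + j, so each element records the grid coordinates of the program that wrote it.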