Notes on Intern Work
Published:
https://www.coursera.org/learn/build-a-computer
llvm::outs()«“retIsF16”«retIsF16«“\n”;
ttir -> ttgir
layout 转换的函数
MLIR_ENABLE_DUMP = 1
TODO list for this week:
- 看 triton python tutorials 怎么用的
- 看一下 gluon 到底是干什么的,里面有没有 linear 或者 blocked layouts;如果有 blocked layouts,看看能否转换为 linear
- emitIndices 函数是与 layout 转换相关的,研究一下里面的细节(https://github.com/triton-lang/triton/blob/main/lib/Conversion/TritonGPUToLLVM/Utility.cpp#L310)
- 学会使用 llvm::outs() 打印调试信息
- 学会使用 MLIR_ENABLE_DUMP = 1 来看下 python code 是如何一步步转换为 IR 的,里面是怎么用 layout 的(在哪一个层次,可能是 ttir -> ttgir 层次?)
Decorator:
add_kernel = triton.jit(add_kernel) # of type JITFunction[type(add_kernel)]
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
# equivalent to triton.jit(add_kernel)[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
JITFunction has implemented its __getitem__
method, such that triton.jit(add_kernel)[grid]
is equivalent to triton.jit(add_kernel).__getitem__(grid)
.
class KernelInterface(Generic[T]):
run: T
def __getitem__(self, grid) -> T:
"""
A JIT function is launched with: fn[grid](*args, **kwargs).
Hence JITFunction.__getitem__ returns a callable proxy that
memorizes the grid.
"""
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
class JITFunction(KernelInterface[T]): ...
Therefore, add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
is equivalent to triton.jit(add_kernel).__getitem__(grid)(x, y, output, n_elements, BLOCK_SIZE=1024)
, which is triton.jit(add_kernel).run(grid=grid, warmup=False, *(x, y, output, n_elements, BLOCK_SIZE=1024))
. The core computation is done in kernel.run
, which is too complicated that I actually haven’t fully figured out how it functions yet.
typing
is so freaking strange. It makes a python script look like a C++ template code.
Positional and keyword arguments in Python:
*args
collects all positional arguments into a tuple, while**kwargs
collects all keyword arguments into a dictionary.- In a function definition, there are five types of parameters:
- Positional-only parameters (e.g.,
def func(a, b, /)
): Arguments listed before a forward slash (/
). - Positional-or-keyword parameters (e.g.,
def func(a, b)
): The “normal” arguments we use every day. They appear after any positional-only arguments but before*args
. - Keyword-only parameters (e.g.,
def func(*, a, b)
): Any argument that appears after*args
or a bare*
. - Variable positional parameters (e.g.,
def func(*args)
). - Variable keyword parameters (e.g.,
def func(**kwargs)
).
- Positional-only parameters (e.g.,
- In a general function call:
- positional arguments are passed first, followed by keyword arguments. This is non-negotiable.
- However, when collecting variable positional arguments (
*args
), they can be passed after keyword arguments, where python will cleverly unpack and place them in the correct place. They cannot, nevertheless, be passed after variable keyword arguments (**kwargs
).
def my_func(pos_only=None, /, std_arg=None, *args, kw_only=None, **kwargs):
print(f"pos_only: {pos_only}, std_arg: {std_arg}, args: {args}, kw_only: {kw_only}, kwargs: {kwargs}")
In the example above, my_func(kw_only=10, *(30, 40, 50))
would work, but my_func(**{'kw_only': 10}, *(30, ))
would raise a SyntaxError
.
By the way, I just found out python checks the order of arguments in a function call during compilation step, even before runtime.
- The compilation step checks
SyntaxError
,IndentationError
, andTabError
.SyntaxError
:- Invalid structure: Using a keyword in the wrong place, such as
for = 10 # 'for' is a keyword, not a variable name
- Malformed expressions: Unbalanced parentheses, brackets, or quotes, such as
my_list = [1, 2, (3, 4] # Mismatched () and []
- Argument order violations: The rules we’ve discussed above.
- Invalid assignment: Trying to assign a value to something that can’t hold one, such as
"hello" = 12 # Can't assign to a string literal
- Invalid structure: Using a keyword in the wrong place, such as
IndentationError
:- Unexpected indentation: An indent that doesn’t follow a colon (
:
). - Unmatched indentation: Code that is expected to be indented but isn’t.
- Unexpected indentation: An indent that doesn’t follow a colon (
TabError
: If tabs and spaces are used interchangeably within the same block, Python can’t reliably determine the indentation level, resulting in a TabError.def example_function(): print("Indented with tabs") print("This line has a mixture of tabs and spaces")
In modern code editors like VSCode, this error is automatically fixed by code editors silently converting all tabs to spaces.
- The runtime step checks the rest of errors, such as
NameError
,TypeError
,ValueError
, etc.
The conventional try-except
block can only capture those errors that occur during the runtime step. In VSCode, these two types of errors are underlined in different colors: errors occurred during compilation time (can’t be captured by try-except
) are colored red, while errors occurred during runtime (can be captured by try-except
) are colored yellow.
July 1
我们这一个月工作的心得:
- 看了很多东西,也写了一些文档,但是并没有真正上手做什么。
- 感觉目标不够明确。在看 code 的过程中经常会纠结于一些细节,但是后来发现这些细节并不是我们需要关注的重点,导致浪费了很多时间。
- 像 triton 这样一个很大的项目,阅读的过程中通常会觉得有些 overwhelming,有很多地方细节实在是太多,不知道究竟要了解到什么程度,经常迷失在细节里面
- 希望能有更明确的一个可以执行的任务,让我上手写写代码
_cute_ops_gen.py
今天开会的时候跟我说要我重新做回 cuteDsl 的工作。大致我理解的意思是:
- CUTLASS 有两套代码,一套是 C++ 的,另一套是 Python 的
- Python 的代码执行过程是,Python 代码先被翻译为 CuTe IR,然后是 optimized CuTe IR,再到 LLVM IR,最后被编译为机器码
- CUTLASS 是开源的,但是上面所述的转换过程中,Python 到 CuTe 基本已知,CuTe 到 llvm 和 kernel binary 代码并没有开源,只有 Python 写的接口(API,不知道这样叫对不对)和生成的 IR (被叫做 dump)是开源的
- 既然他们想在 Intel 的 GPU 上弄这一套,具体弄到什么程度我也不知道,但是他们想要把这套转换过程弄出来
- 整个工程很繁杂(一个 MLIR 的代码接近几十万行),他们说一个月内弄完不太现实,所以还是让我以看文档为主,对照着 C++ 的代码搞清楚各个函数的预期表现是什么,向下转换的逻辑是什么样的,可能让我在一些最简单的 operator 操作上写一些单元测试
说实话我有些不太开心,感觉自己有点被耍,前面让我看的 linear layout 和 gluon 之类现在都不提了,上一个月花的时间好像直接被浪费掉了。关键是那几个人说的也不清楚,感觉他们自己也没想好要干什么,东一榔头西一棒的。现在又让我回去看 cuteDsl 的文档,感觉有点像是把我当成了一个文档阅读器,搞半天两个月全部都是在看些可能以后永远也不会用到的文档,也不说要看到什么程度,就只是说看看看。(这段是 Copilot 自动生成的,但我觉得说得太好了)
但说回来,这么大的一个工程项目,确实想想都繁杂,可能最简单的任务上手 ramp up 都得花好久才行,更别提我这种从来没接触过 C++、从来没写过 cuda、完全不懂 LLVM 和 IR、从来没真正做过一个项目的人了。虽然我很想要有点代码产出,但是也许现在得换个心态。一个很重要的能学习的点就是观察和阅读庞大项目的代码究竟是如何组织的,为什么要有这么多文件夹(好多都还是同名的),如何高效率地在并不清楚全部细节的情况下阅读并理解代码、不陷入细节漩涡里,这可能比直接上手写代码更重要。我们还可以在读代码的过程中获得工程代码如何书写的一手资料。这样看的话,也不能说上个月就完全一无所获,至少看了不少 Python 代码后知道了 __init__.py
到底干什么的,学会了各种奇妙的 Python 语法糖,对 typing
库里的类型注解有了更深的理解,还有些杂七杂八的 abc
、@builtin
、@aggregate
、@triton.jit
等等的装饰器的用法。如果从这个角度看,可学的可就太多了,不管是 C++ 里各种奇妙的模板元编程,还有 CMake、pyproject.toml 等等配置文件,大项目的文档组织架构、注释风格,Python 代码与 low-level IR 究竟在哪里转换的,什么是 GPUDialect,等等等等。
我也是刚刚才意识到,之前做 UROP 时其实都是当调包侠,这是我第一次需要这么关心底层的东西,所以不知道该怎么把握细节信息量。学吧……
今天实在无心工作。明天 sync 下他们究竟想要我干什么,再 get down。
July 2
大致知道了这个星期的工作。刚拿到了 vadd_ir.mlir
这个文件,里面是最简单的(带有 layout 的)加法 IR。这个 IR 有六万多行,不过模块(或者说,kernel)分得很清楚,每个 kernel 大概三百多行,都是这个加法的一个实现,只不过从上往下是依次 lower 到越来越底层(比如,从 cute
dialect 给 lower 到 llvm
)。我们的目的就是搞清楚这个向下 lower 的过程是究竟如何将 cute.xxx
这些函数进行转换的。
每个 kernel 里的 300 多行代码也不是每个都要看,主要聚焦 arith.addf
的参数是从哪里来的,这些参数在之前又经历了些什么。整个加法可能是 1024 * 1024 的大加法,但是每个 arith.addf
只负责 vector<16xf32>
的加法,这表明在之前肯定经历过切分。我们以最上层 kernel 为例,arith.addf
出现在这里:
%43 = arith.addf %38, %41 : vector<16xf32>
向上 trace %38
和 %41
:
%38 = cute.memref.load_vec %33, row_major : !memref_rmem_f32_
%41 = cute.memref.load_vec %34, row_major : !memref_rmem_f32_
这里就出现了 cute
dialect 的函数 cute.memref.load_vec
。向上继续 trace %33
和 %34
:
%33 = cute.make_fragment_like(%src_partitioned) <{elem_type = f32}> : !memref_gmem_f32_2 to !memref_rmem_f32_
%34 = cute.make_fragment_like(%src_partitioned_76) <{elem_type = f32}> : !memref_gmem_f32_2 to !memref_rmem_f32_
依此类推,我们可以一步步向上 trace。这个过程中就会涉及到大量 cute
dialect 的函数,我们要做的就是搞清楚每个过程中涉及到的 cute
函数接收什么参数、究竟是什么行为。有些函数可能很直白,比如 cute.get_layout
或者 cute.get_shape
,但是别的函数可能就没有这么直观。当函数意思不够直观的时候,我们就得借助 cute
的文档、对应的 C++ 代码,或者向下 lower 的 kernel IR 来理解。这就是我这周的任务。
July 3
今天继续看昨天看的,但除了看 mlir dump 文件外,还转过头来看看原 py 代码里的 kernel 究竟是怎么写的,好理解下这个 kernel 到底在干什么。CUTLASS C++ 里也有一些函数的用法,比如 cute.tiled.copy.partition_S
。atom
是什么意思也许也可以在相同地方找到。除此以外,CuTe 的文档里也有一些函数的用法。另外,除了保留与 arith.addf
相关的 kernel op 外,存储部分的代码也要保留下来。建议我弄个 github repo 写点向下转换的伪代码,可以用 Python 也可以用 C++。