はじめに
RISC-VなGPGPUであるVortexを深堀する (その6)として、main 関数をみていきます。]
pocl
Vortex で動くプログラムの main 関数は、pocl の中で下記のように生成されます。
char wrapper_cc[POCL_FILENAME_LENGTH]; char pfn_workgroup_string[WORKGROUP_STRING_LENGTH]; std::stringstream ss; snprintf (pfn_workgroup_string, WORKGROUP_STRING_LENGTH, "_pocl_kernel_%s_workgroup", kernel->name); ss << "#include <vx_spawn.h>\n" "void " << pfn_workgroup_string << "(uint8_t* args, uint8_t*, uint32_t, uint32_t, uint32_t);\n" "int main() {\n" " const context_t* ctx = (const context_t*)" << KERNEL_ARG_BASE_ADDR << ";\n" " void* args = (void*)" << (KERNEL_ARG_BASE_ADDR + ALIGNED_CTX_SIZE) << ";\n" " vx_spawn_kernel(ctx, (void*)" << pfn_workgroup_string << ", args);\n" " return 0;\n" "}";
この main 関数が pocl の中でコンパイルされるわけです。
main 関数の中では、vx_spawn_kernel 関数が呼ばれています。Vortex の runtime/src/vx_spawn.c で下記のように定義されています。
- ctx : KERNEL_ARG_BASE_ADDR => vx_spawn_kernel関数で使われる引数があるアドレス
- callback : (pocl_kernel%s_workgroup", kernel->name) => 呼び出す関数 (この関数が ctx にダウンロードされる)
- args : KERNEL_ARG_BASE_ADDR + ALIGNED_CTX_SIZE => 呼び出す関数の引数があるアドレス
KERNEL_ARG_BASE_ADDR マクロは、ここで 下記のように定義されています。この KERNEL_ARG_BASE_ADDR マクロで設定している 0x7fff0000 は FPGA に接続している DRAM の中だと思います。
// location is local memory where to store kernel parameters #define KERNEL_ARG_BASE_ADDR 0x7fff0000
ALIGNED_CTX_SIZE は、次のような値が設定されています。KERNEL_ARG_BASE_ADDR から ALIGNED_CTX_SIZE 分が Context に必要な情報が入っているんでしょうね。
static size_t ALIGNED_CTX_SIZE = 4 * ((sizeof(kernel_context_t) + 3) / 4);
void vx_spawn_kernel(context_t * ctx, vx_spawn_kernel_cb callback, void * arg) { // total number of WGs int X = ctx->num_groups[0]; int Y = ctx->num_groups[1]; int Z = ctx->num_groups[2]; int XY = X * Y; int Q = XY * Z; // device specs int NC = vx_num_cores(); int NW = vx_num_warps(); int NT = vx_num_threads();
vx_num_cores, vx_num_warps, vx_num_threads は、vx_intrinsics.h の中で定義されています。
// Return the number of threads in a warp inline int vx_num_threads() { int result; asm volatile ("csrr %0, %1" : "=r"(result) : "i"(CSR_NT)); return result; } // Return the number of warps in a core inline int vx_num_warps() { int result; asm volatile ("csrr %0, %1" : "=r"(result) : "i"(CSR_NW)); return result; } // Return the number of cores in the processsor inline int vx_num_cores() { int result; asm volatile ("csrr %0, %1" : "=r"(result) : "i"(CSR_NC)); return result; }
vx_core_id は、vx_intrinsics.h の中で定義されています。
// current core id int core_id = vx_core_id(); if (core_id >= NUM_CORES_MAX) return;
// Return processsor core id inline int vx_core_id() { int result; asm volatile ("csrr %0, %1" : "=r"(result) : "i"(CSR_GCID)); return result; }
// calculate necessary active cores int WT = NW * NT; int nC = (Q > WT) ? (Q / WT) : 1; int nc = MIN(nC, NC); if (core_id >= nc) return; // terminate extra cores // number of workgroups per core int wgs_per_core = Q / nc; int wgs_per_core0 = wgs_per_core; if (core_id == (NC-1)) { int QC_r = Q - (nc * wgs_per_core0); wgs_per_core0 += QC_r; // last core executes remaining WGs } // number of workgroups per warp int nW = wgs_per_core0 / NT; // total warps per core int rT = wgs_per_core0 - (nW * NT); // remaining threads int fW = (nW >= NW) ? (nW / NW) : 0; // full warps iterations int rW = (fW != 0) ? (nW - fW * NW) : 0; // reamining full warps if (0 == fW) fW = 1; // fast path handling char isXYpow2 = is_log2(XY); char isXpow2 = is_log2(X); char log2XY = fast_log2(XY); char log2X = fast_log2(X); //-- wspawn_kernel_args_t wspawn_args = { ctx, callback, arg, core_id * wgs_per_core, fW, rW, 0, isXYpow2, isXpow2, log2XY, log2X }; g_wspawn_args[core_id] = &wspawn_args; //-- if (nW >= 1) { int nw = MIN(nW, NW); wspawn_args.NW = nw; vx_wspawn(nw, spawn_kernel_all_cb); spawn_kernel_all_cb(); } //-- if (rT != 0) { wspawn_args.offset = wgs_per_core0 - rT; int tmask = (1 << rT) - 1; spawn_kernel_rem_cb(tmask); } }
vx_wspawn 関数は、vx_intrinsics.h の中で下記のように定義されています。vx_wspawn 関数の中では、前回説明した wspawn a0, a1 を呼び出しています。func_ptr は、spawn_kernel_all_cb です。
// Spawn warps inline void vx_wspawn(unsigned num_warps, vx_wspawn_pfn func_ptr) { asm volatile (".insn s 0x6b, 1, %1, 0(%0)" :: "r"(num_warps), "r"(func_ptr)); }
spawn_kernel_all_cb 関数は、下記のように定義されています。
static void spawn_kernel_all_cb() { // activate all threads vx_tmc(-1); // call stub routine spawn_kernel_all_stub(); // set warp0 to single-threaded and stop other warps int wid = vx_warp_id(); vx_tmc(0 == wid); }
spawn_kernel_all_stub 関数は、次のように定義されています。
static void __attribute__ ((noinline)) spawn_kernel_all_stub() { int core_id = vx_core_id(); int wid = vx_warp_id(); int tid = vx_thread_id(); int NT = vx_num_threads(); wspawn_kernel_args_t* p_wspawn_args = (wspawn_kernel_args_t*)g_wspawn_args[core_id]; int wK = (p_wspawn_args->N * wid) + MIN(p_wspawn_args->R, wid); int tK = p_wspawn_args->N + (wid < p_wspawn_args->R); int offset = p_wspawn_args->offset + (wK * NT) + (tid * tK); int X = p_wspawn_args->ctx->num_groups[0]; int Y = p_wspawn_args->ctx->num_groups[1]; int XY = X * Y; for (int wg_id = offset, N = wg_id + tK; wg_id < N; ++wg_id) { int k = p_wspawn_args->isXYpow2 ? (wg_id >> p_wspawn_args->log2XY) : (wg_id / XY); int wg_2d = wg_id - k * XY; int j = p_wspawn_args->isXpow2 ? (wg_2d >> p_wspawn_args->log2X) : (wg_2d / X); int i = wg_2d - j * X; int gid0 = p_wspawn_args->ctx->global_offset[0] + i; int gid1 = p_wspawn_args->ctx->global_offset[1] + j; int gid2 = p_wspawn_args->ctx->global_offset[2] + k; (p_wspawn_args->callback)(p_wspawn_args->arg, p_wspawn_args->ctx, gid0, gid1, gid2); } // wait for all warps to complete vx_barrier(0, p_wspawn_args->NW); }
callback は、vx_spawn_kerl 関数の callback 、つまり、OpenCLのカーネル関数、が呼び出され、すべての Warp が終了するまで、vx_barrier 関数にて待ちます。
(p_wspawn_args->callback)(p_wspawn_args->arg, p_wspawn_args->ctx, gid0, gid1, gid2); } // wait for all warps to complete vx_barrier(0, p_wspawn_args->NW);
vx_barrier 関数は vx_intrinsics.h の中で次のように定義されています。
inline void vx_barrier(unsigned barried_id, unsigned num_warps) { asm volatile (".insn s 0x6b, 4, %1, 0(%0)" :: "r"(barried_id), "r"(num_warps)); }
.insn s 0x6b, 4 は、barrier 命令であり、VX_gpu_unit.sv の中では、下記のようなコードになっています。
// barrier assign barrier.valid = is_bar; assign barrier.id = rs1_data[`NB_BITS-1:0]; assign barrier.size_m1 = (`NW_BITS)'(rs2_data - 1);
この barrier も下記のように、Vx_warp_sched に伝わります。
// pack warp ctl result assign warp_ctl_data = {tmc, wspawn, split, barrier};
VX_warp_sched の中では下記のようコードになっています。
if (warp_ctl_if.valid && warp_ctl_if.barrier.valid) begin stalled_warps[warp_ctl_if.wid] <= 0; if (reached_barrier_limit) begin barrier_masks[warp_ctl_if.barrier.id] <= 0; end else begin barrier_masks[warp_ctl_if.barrier.id][warp_ctl_if.wid] <= 1; end end
// calculate active barrier status `IGNORE_UNUSED_BEGIN wire [`NW_BITS:0] active_barrier_count; `IGNORE_UNUSED_END wire [`NUM_WARPS-1:0] barrier_mask = barrier_masks[warp_ctl_if.barrier.id]; `POP_COUNT(active_barrier_count, barrier_mask); assign reached_barrier_limit = (active_barrier_count[`NW_BITS-1:0] == warp_ctl_if.barrier.size_m1); reg [`NUM_WARPS-1:0] barrier_stalls; always @(*) begin barrier_stalls = barrier_masks[0]; for (integer i = 1; i < `NUM_BARRIERS; ++i) begin barrier_stalls |= barrier_masks[i]; end end
下記の部分で次の Warp が実行できるかどうかを決めています。
// schedule the next ready warp wire [`NUM_WARPS-1:0] ready_warps = active_warps & ~(stalled_warps | barrier_stalls);
端数は、spawn_task_rem_cb 関数の中で実行されます。
static void spawn_tasks_rem_cb(int thread_mask) { // activate threads vx_tmc(thread_mask); // call stub routine spawn_tasks_rem_stub(); // back to single-threaded vx_tmc(1); }
spawn_tasks_rem_cb 関数の中で、spawn_kernel_rem_stub 関数を呼び出しています。こちらでは、vx_barrier 関数で待つ必要はありません。
static void __attribute__ ((noinline)) spawn_kernel_rem_stub() { int core_id = vx_core_id(); int tid = vx_thread_gid(); wspawn_kernel_args_t* p_wspawn_args = (wspawn_kernel_args_t*)g_wspawn_args[core_id]; int wg_id = p_wspawn_args->offset + tid; int X = p_wspawn_args->ctx->num_groups[0]; int Y = p_wspawn_args->ctx->num_groups[1]; int XY = X * Y; int k = p_wspawn_args->isXYpow2 ? (wg_id >> p_wspawn_args->log2XY) : (wg_id / XY); int wg_2d = wg_id - k * XY; int j = p_wspawn_args->isXpow2 ? (wg_2d >> p_wspawn_args->log2X) : (wg_2d / X); int i = wg_2d - j * X; int gid0 = p_wspawn_args->ctx->global_offset[0] + i; int gid1 = p_wspawn_args->ctx->global_offset[1] + j; int gid2 = p_wspawn_args->ctx->global_offset[2] + k; (p_wspawn_args->callback)(p_wspawn_args->arg, p_wspawn_args->ctx, gid0, gid1, gid2); }
おわりに
6回に分けて、RISC-VなGPGPUである Vortex を深堀してみました。RTL だけ調べてもわからないことが、runtime や pocl の中を覗くことで、色々なことを知ることができました。
SIMTな(RISC-Vベースの)GPGPUって、こんな感じで実装できるんだー、ということがわかってよかったです。
1つの Clusterに Vortex を 4個ぐらいを ZynqMP SoC に入らないでしょうかね。
入れば、ZynqMP SoC + Vortex という面白い教材ができますよ。。。そうすれば、SIMTなRISC-V のGPGPU を含めたシステムを学べますよ。