Vengineerの戯言

人生は短いけど、長いです。人生を楽しみましょう!

RISC-VなGPGPUであるVortexを深堀する (その6)

はじめに

RISC-VなGPGPUであるVortexを深堀する (その6)として、main 関数をみていきます。]

pocl

Vortex で動くプログラムの main 関数は、pocl の中で下記のように生成されます

    char wrapper_cc[POCL_FILENAME_LENGTH];    
    char pfn_workgroup_string[WORKGROUP_STRING_LENGTH];
    std::stringstream ss;

    snprintf (pfn_workgroup_string, WORKGROUP_STRING_LENGTH,
              "_pocl_kernel_%s_workgroup", kernel->name);
    
    ss << "#include <vx_spawn.h>\n"
          "void " << pfn_workgroup_string << "(uint8_t* args, uint8_t*, uint32_t, uint32_t, uint32_t);\n"  
          "int main() {\n"
          "  const context_t* ctx = (const context_t*)" << KERNEL_ARG_BASE_ADDR << ";\n"
          "  void* args = (void*)" << (KERNEL_ARG_BASE_ADDR + ALIGNED_CTX_SIZE) << ";\n"
          "  vx_spawn_kernel(ctx, (void*)" << pfn_workgroup_string << ", args);\n"
          "  return 0;\n"
          "}";

この main 関数が pocl の中でコンパイルされるわけです。

main 関数の中では、vx_spawn_kernel 関数が呼ばれています。Vortex の runtime/src/vx_spawn.c で下記のように定義されています。

  • ctx : KERNEL_ARG_BASE_ADDR => vx_spawn_kernel関数で使われる引数があるアドレス
  • callback : (pocl_kernel%s_workgroup", kernel->name) => 呼び出す関数 (この関数が ctx にダウンロードされる)
  • args : KERNEL_ARG_BASE_ADDR + ALIGNED_CTX_SIZE => 呼び出す関数の引数があるアドレス

KERNEL_ARG_BASE_ADDR マクロは、ここで 下記のように定義されています。この KERNEL_ARG_BASE_ADDR マクロで設定している 0x7fff0000 は FPGA に接続している DRAM の中だと思います。

// location is local memory where to store kernel parameters
#define KERNEL_ARG_BASE_ADDR 0x7fff0000

ALIGNED_CTX_SIZE は、次のような値が設定されています。KERNEL_ARG_BASE_ADDR から ALIGNED_CTX_SIZE 分が Context に必要な情報が入っているんでしょうね。

static size_t ALIGNED_CTX_SIZE = 4 * ((sizeof(kernel_context_t) + 3) / 4);
void vx_spawn_kernel(context_t * ctx, vx_spawn_kernel_cb callback, void * arg) {  
  // total number of WGs
  int X  = ctx->num_groups[0];
  int Y  = ctx->num_groups[1];
  int Z  = ctx->num_groups[2];
  int XY = X * Y;
  int Q  = XY * Z;
  
  // device specs
  int NC = vx_num_cores();
  int NW = vx_num_warps();
  int NT = vx_num_threads();

vx_num_cores, vx_num_warps, vx_num_threads は、vx_intrinsics.h の中で定義されています。

// Return the number of threads in a warp
inline int vx_num_threads() {
    int result;
    asm volatile ("csrr %0, %1" : "=r"(result) : "i"(CSR_NT));
    return result; 
}

// Return the number of warps in a core
inline int vx_num_warps() {
    int result;
    asm volatile ("csrr %0, %1" : "=r"(result) : "i"(CSR_NW));
    return result;   
}

// Return the number of cores in the processsor
inline int vx_num_cores() {
    int result;
    asm volatile ("csrr %0, %1" : "=r"(result) : "i"(CSR_NC));
    return result;   
}

vx_core_id は、vx_intrinsics.h の中で定義されています。

  // current core id
  int core_id = vx_core_id();  
  if (core_id >= NUM_CORES_MAX)
    return;
// Return processsor core id
inline int vx_core_id() {
    int result;
    asm volatile ("csrr %0, %1" : "=r"(result) : "i"(CSR_GCID));
    return result; 
}
  // calculate necessary active cores
  int WT = NW * NT;
  int nC = (Q > WT) ? (Q / WT) : 1;
  int nc = MIN(nC, NC);
  if (core_id >= nc)
    return; // terminate extra cores

  // number of workgroups per core
  int wgs_per_core = Q / nc;
  int wgs_per_core0 = wgs_per_core;  
  if (core_id == (NC-1)) {    
    int QC_r = Q - (nc * wgs_per_core0); 
    wgs_per_core0 += QC_r; // last core executes remaining WGs
  }

  // number of workgroups per warp
  int nW = wgs_per_core0 / NT;              // total warps per core
  int rT = wgs_per_core0 - (nW * NT);       // remaining threads
  int fW = (nW >= NW) ? (nW / NW) : 0;      // full warps iterations
  int rW = (fW != 0) ? (nW - fW * NW) : 0;  // reamining full warps
  if (0 == fW)
    fW = 1;

  // fast path handling
  char isXYpow2 = is_log2(XY);
  char isXpow2  = is_log2(X);
  char log2XY   = fast_log2(XY);
  char log2X    = fast_log2(X);

  //--
  wspawn_kernel_args_t wspawn_args = { 
    ctx, callback, arg, core_id * wgs_per_core, fW, rW, 0, isXYpow2, isXpow2, log2XY, log2X 
  };
  g_wspawn_args[core_id] = &wspawn_args;

  //--
    if (nW >= 1) { 
    int nw = MIN(nW, NW);    
    wspawn_args.NW = nw;
      vx_wspawn(nw, spawn_kernel_all_cb);
    spawn_kernel_all_cb();
    }  

  //--    
  if (rT != 0) {
    wspawn_args.offset = wgs_per_core0 - rT;
    int tmask = (1 << rT) - 1;
    spawn_kernel_rem_cb(tmask);
  }
}

vx_wspawn 関数は、vx_intrinsics.h の中で下記のように定義されています。vx_wspawn 関数の中では、前回説明した wspawn a0, a1 を呼び出しています。func_ptr は、spawn_kernel_all_cb です。

// Spawn warps
inline void vx_wspawn(unsigned num_warps, vx_wspawn_pfn func_ptr) {
    asm volatile (".insn s 0x6b, 1, %1, 0(%0)" :: "r"(num_warps), "r"(func_ptr));
}

spawn_kernel_all_cb 関数は、下記のように定義されています。

static void spawn_kernel_all_cb() {  
  // activate all threads
  vx_tmc(-1);

  // call stub routine
  spawn_kernel_all_stub();

  // set warp0 to single-threaded and stop other warps
  int wid = vx_warp_id();
  vx_tmc(0 == wid);
}

spawn_kernel_all_stub 関数は、次のように定義されています。

static void __attribute__ ((noinline)) spawn_kernel_all_stub() {
  int core_id = vx_core_id();
  int wid     = vx_warp_id();
  int tid     = vx_thread_id(); 
  int NT      = vx_num_threads();
  
  wspawn_kernel_args_t* p_wspawn_args = (wspawn_kernel_args_t*)g_wspawn_args[core_id];

  int wK = (p_wspawn_args->N * wid) + MIN(p_wspawn_args->R, wid);
  int tK = p_wspawn_args->N + (wid < p_wspawn_args->R);
  int offset = p_wspawn_args->offset + (wK * NT) + (tid * tK);

  int X = p_wspawn_args->ctx->num_groups[0];
  int Y = p_wspawn_args->ctx->num_groups[1];
  int XY = X * Y;

  for (int wg_id = offset, N = wg_id + tK; wg_id < N; ++wg_id) {    
    int k = p_wspawn_args->isXYpow2 ? (wg_id >> p_wspawn_args->log2XY) : (wg_id / XY);
    int wg_2d = wg_id - k * XY;
    int j = p_wspawn_args->isXpow2 ? (wg_2d >> p_wspawn_args->log2X) : (wg_2d / X);
    int i = wg_2d - j * X;

    int gid0 = p_wspawn_args->ctx->global_offset[0] + i;
    int gid1 = p_wspawn_args->ctx->global_offset[1] + j;
    int gid2 = p_wspawn_args->ctx->global_offset[2] + k;

    (p_wspawn_args->callback)(p_wspawn_args->arg, p_wspawn_args->ctx, gid0, gid1, gid2);
  }

  // wait for all warps to complete
  vx_barrier(0, p_wspawn_args->NW);
}

callback は、vx_spawn_kerl 関数の callback 、つまり、OpenCLカーネル関数、が呼び出され、すべての Warp が終了するまで、vx_barrier 関数にて待ちます。

    (p_wspawn_args->callback)(p_wspawn_args->arg, p_wspawn_args->ctx, gid0, gid1, gid2);
  }

  // wait for all warps to complete
  vx_barrier(0, p_wspawn_args->NW);

vx_barrier 関数は vx_intrinsics.h の中で次のように定義されています。

inline void vx_barrier(unsigned barried_id, unsigned num_warps) {
    asm volatile (".insn s 0x6b, 4, %1, 0(%0)" :: "r"(barried_id), "r"(num_warps));
}

.insn s 0x6b, 4 は、barrier 命令であり、VX_gpu_unit.sv の中では、下記のようなコードになっています。

    // barrier
    
    assign barrier.valid   = is_bar;
    assign barrier.id      = rs1_data[`NB_BITS-1:0];
    assign barrier.size_m1 = (`NW_BITS)'(rs2_data - 1);       

この barrier も下記のように、Vx_warp_sched に伝わります。

    // pack warp ctl result
    assign warp_ctl_data = {tmc, wspawn, split, barrier};

VX_warp_sched の中では下記のようコードになっています。

            if (warp_ctl_if.valid && warp_ctl_if.barrier.valid) begin
                stalled_warps[warp_ctl_if.wid] <= 0;
                if (reached_barrier_limit) begin
                    barrier_masks[warp_ctl_if.barrier.id] <= 0;
                end else begin
                    barrier_masks[warp_ctl_if.barrier.id][warp_ctl_if.wid] <= 1;
                end
            end 

下記の部分でbarrier_stalls を計算し、

    // calculate active barrier status

`IGNORE_UNUSED_BEGIN
    wire [`NW_BITS:0] active_barrier_count;
`IGNORE_UNUSED_END
    wire [`NUM_WARPS-1:0] barrier_mask = barrier_masks[warp_ctl_if.barrier.id];
    `POP_COUNT(active_barrier_count, barrier_mask);

    assign reached_barrier_limit = (active_barrier_count[`NW_BITS-1:0] == warp_ctl_if.barrier.size_m1);

    reg [`NUM_WARPS-1:0] barrier_stalls;
    always @(*) begin
        barrier_stalls = barrier_masks[0];
        for (integer i = 1; i < `NUM_BARRIERS; ++i) begin
            barrier_stalls |= barrier_masks[i];
        end
    end

下記の部分で次の Warp が実行できるかどうかを決めています。

    // schedule the next ready warp

    wire [`NUM_WARPS-1:0] ready_warps = active_warps & ~(stalled_warps | barrier_stalls);

端数は、spawn_task_rem_cb 関数の中で実行されます。

static void spawn_tasks_rem_cb(int thread_mask) {  
  // activate threads  
  vx_tmc(thread_mask);

  // call stub routine
  spawn_tasks_rem_stub();

  // back to single-threaded
  vx_tmc(1);
}

spawn_tasks_rem_cb 関数の中で、spawn_kernel_rem_stub 関数を呼び出しています。こちらでは、vx_barrier 関数で待つ必要はありません。

static void __attribute__ ((noinline)) spawn_kernel_rem_stub() {
  int core_id = vx_core_id(); 
  int tid = vx_thread_gid();

  wspawn_kernel_args_t* p_wspawn_args = (wspawn_kernel_args_t*)g_wspawn_args[core_id];

  int wg_id = p_wspawn_args->offset + tid;

  int X = p_wspawn_args->ctx->num_groups[0];
  int Y = p_wspawn_args->ctx->num_groups[1];
  int XY = X * Y;
  
  int k = p_wspawn_args->isXYpow2 ? (wg_id >> p_wspawn_args->log2XY) : (wg_id / XY);
  int wg_2d = wg_id - k * XY;
  int j = p_wspawn_args->isXpow2 ? (wg_2d >> p_wspawn_args->log2X) : (wg_2d / X);
  int i = wg_2d - j * X;

  int gid0 = p_wspawn_args->ctx->global_offset[0] + i;
  int gid1 = p_wspawn_args->ctx->global_offset[1] + j;
  int gid2 = p_wspawn_args->ctx->global_offset[2] + k;

  (p_wspawn_args->callback)(p_wspawn_args->arg, p_wspawn_args->ctx, gid0, gid1, gid2);
}

おわりに

6回に分けて、RISC-VなGPGPUである Vortex を深堀してみました。RTL だけ調べてもわからないことが、runtime や pocl の中を覗くことで、色々なことを知ることができました。

SIMTな(RISC-Vベースの)GPGPUって、こんな感じで実装できるんだー、ということがわかってよかったです。

1つの Clusterに Vortex を 4個ぐらいを ZynqMP SoC に入らないでしょうかね。

入れば、ZynqMP SoC + Vortex という面白い教材ができますよ。。。そうすれば、SIMTなRISC-V のGPGPU を含めたシステムを学べますよ。