Alien-XGBoost

 view release on metacpan or search on metacpan

xgboost/cub/cub/agent/agent_spmv_row_based.cuh  view on Meta::CPAN

        ValueT&             row_start,
        OffsetT&            tile_nonzero_idx,
        OffsetT             tile_nonzero_idx_end,
        OffsetT             row_nonzero_idx,
        OffsetT             row_nonzero_idx_end)
    {
        ValueT NAN_TOKEN;
        InitNan(NAN_TOKEN);


        //
        // Gather a strip of nonzeros into shared memory
        //

        #pragma unroll
        for (int ITEM = 0; ITEM < NNZ_PER_THREAD; ++ITEM)
        {

            ValueT nonzero = 0.0;

            OffsetT                 local_nonzero_idx   = (ITEM * BLOCK_THREADS) + threadIdx.x;
            OffsetT                 nonzero_idx         = tile_nonzero_idx + local_nonzero_idx;

            bool in_range = nonzero_idx < tile_nonzero_idx_end;

            OffsetT nonzero_idx2 = (in_range) ?
                nonzero_idx :
                tile_nonzero_idx_end - 1;

            OffsetT column_idx          = wd_column_indices[nonzero_idx2];
            ValueT  value               = wd_values[nonzero_idx2];
            ValueT  vector_value        = wd_vector_x[column_idx];
            nonzero                     = value * vector_value;

            if (!in_range)
                nonzero = 0.0;

            temp_storage.nonzeros[local_nonzero_idx] = nonzero;
        }

        CTA_SYNC();

        //
        // Swap in NANs at local row start offsets
        //

        OffsetT local_row_nonzero_idx = row_nonzero_idx - tile_nonzero_idx;
        if ((local_row_nonzero_idx >= 0) && (local_row_nonzero_idx < TILE_ITEMS))
        {
            // Thread's row starts in this strip
            row_start = temp_storage.nonzeros[local_row_nonzero_idx];
            temp_storage.nonzeros[local_row_nonzero_idx] = NAN_TOKEN;
        }

        CTA_SYNC();

        //
        // Segmented scan
        //

        // Read strip of nonzeros into thread-blocked order, setup segment flags
        KeyValuePairT scan_items[NNZ_PER_THREAD];
        for (int ITEM = 0; ITEM < NNZ_PER_THREAD; ++ITEM)
        {
            int     local_nonzero_idx   = (threadIdx.x * NNZ_PER_THREAD) + ITEM;
            ValueT  value               = temp_storage.nonzeros[local_nonzero_idx];
            bool    is_nan              = (value != value);

            scan_items[ITEM].value  = (is_nan) ? 0.0 : value;
            scan_items[ITEM].key    = is_nan;
        }

        KeyValuePairT       tile_aggregate;
        KeyValuePairT       scan_items_out[NNZ_PER_THREAD];

        BlockScanT(temp_storage.scan).ExclusiveScan(scan_items, scan_items_out, scan_op, tile_aggregate, prefix_op);

        // Save the inclusive sum for the last row
        if (threadIdx.x == 0)
        {
            temp_storage.nonzeros[TILE_ITEMS] = prefix_op.running_total.value;
        }

        // Store segment totals
        for (int ITEM = 0; ITEM < NNZ_PER_THREAD; ++ITEM)
        {
            int local_nonzero_idx = (threadIdx.x * NNZ_PER_THREAD) + ITEM;

            if (scan_items[ITEM].key)
                temp_storage.nonzeros[local_nonzero_idx] = scan_items_out[ITEM].value;
        }

        CTA_SYNC();

        //
        // Update row totals
        //

        OffsetT local_row_nonzero_idx_end = row_nonzero_idx_end - tile_nonzero_idx;
        if ((local_row_nonzero_idx_end >= 0) && (local_row_nonzero_idx_end < TILE_ITEMS))
        {
            // Thread's row ends in this strip
            row_total = temp_storage.nonzeros[local_row_nonzero_idx_end];
        }

        tile_nonzero_idx += NNZ_PER_THREAD * BLOCK_THREADS;
    }



    /**
     * Consume input tile
     */
    __device__ __forceinline__ void ConsumeTile(
        int     tile_idx,
        int     rows_per_tile)
    {
        //
        // Read in tile of row ranges
        //



( run in 2.418 seconds using v1.01-cache-2.11-cpan-39bf76dae61 )