Alien-XGBoost

 view release on metacpan or  search on metacpan

xgboost/cub/cub/agent/agent_segment_fixup.cuh  view on Meta::CPAN

                pairs[ITEM].value = reduction_op(pairs[ITEM - 1].value, pairs[ITEM].value);
        }

        // Flush last item if valid
        ValueT* d_scatter = d_aggregates_out + pairs[ITEMS_PER_THREAD - 1].key;
        if ((!IS_LAST_TILE) || (pairs[ITEMS_PER_THREAD - 1].key >= 0))
            atomicAdd(d_scatter, pairs[ITEMS_PER_THREAD - 1].value);
    }


    /**
     * Consume one input tile via the reduce-by-key fixup path (the variant
     * selected by the Int2Type<false> marker, i.e. without atomicAdd).
     *
     * The tile's key-value pairs are loaded and exclusive-scanned with scan_op.
     * Wherever an item's scanned key differs from its own key, the scanned value
     * is combined with d_fixup_in at the scanned key and stored to
     * d_aggregates_out at that same key.
     */
    template <bool IS_LAST_TILE>
    __device__ __forceinline__ void ConsumeTile(
        OffsetT             num_remaining,      ///< Number of global input items remaining (including this tile)
        int                 tile_idx,           ///< Tile index
        OffsetT             tile_offset,        ///< Tile offset
        ScanTileStateT&     tile_state,         ///< Global tile state descriptor
        Int2Type<false>     use_atomic_fixup)   ///< Marker whether to use atomicAdd (instead of reduce-by-key)
    {
        KeyValuePairT   loaded[ITEMS_PER_THREAD];       // This thread's input pairs
        KeyValuePairT   scanned[ITEMS_PER_THREAD];      // Exclusive-scan output for the same slots

        // Filler pair for slots beyond the end of a partial tile (only the key is assigned)
        KeyValuePairT filler;
        filler.key = -1;

        // Load pairs: only the last tile uses the guarded load with the filler
        if (!IS_LAST_TILE)
            BlockLoadPairs(temp_storage.load_pairs).Load(d_pairs_in + tile_offset, loaded);
        else
            BlockLoadPairs(temp_storage.load_pairs).Load(d_pairs_in + tile_offset, loaded, num_remaining, filler);

        CTA_SYNC();

        KeyValuePairT block_total;
        if (tile_idx != 0)
        {
            // Later tiles: exclusive scan of values and segment_flags, seeded through the tile-prefix callback
            TilePrefixCallbackOpT prefix_op(tile_state, temp_storage.prefix, scan_op, tile_idx);
            BlockScanT(temp_storage.scan).ExclusiveScan(loaded, scanned, scan_op, prefix_op);
            block_total = prefix_op.GetBlockAggregate();
        }
        else
        {
            // First tile: exclusive scan of values and segment_flags with no incoming prefix
            BlockScanT(temp_storage.scan).ExclusiveScan(loaded, scanned, scan_op, block_total);

            if (threadIdx.x == 0)
            {
                // The first scanned item is invalid from the exclusive scan; give it
                // the item's own key so that it does not trigger a flush below
                scanned[0].key = loaded[0].key;

                // Publish this tile's aggregate unless it is the last tile
                if (!IS_LAST_TILE)
                    tile_state.SetInclusive(0, block_total);
            }
        }

        // Scatter updated values
        #pragma unroll
        for (int item = 0; item < ITEMS_PER_THREAD; ++item)
        {
            // Same key on both sides: nothing to write for this slot
            if (scanned[item].key == loaded[item].key)
                continue;

            // Fold the scanned value into the fixup value stored at the scanned key
            ValueT previous = d_fixup_in[scanned[item].key];
            ValueT combined = reduction_op(previous, scanned[item].value);

            d_aggregates_out[scanned[item].key] = combined;
        }

        // Finalize the last item: only the last thread of the last tile does this, and only
        // when that tile is a whole tile (the tile aggregate then holds the accumulated
        // value reduction for the last segment)
        if (IS_LAST_TILE && (threadIdx.x == BLOCK_THREADS - 1) && (num_remaining == TILE_ITEMS))
        {
            // Update the value at the key location
            OffsetT tail_key = loaded[ITEMS_PER_THREAD - 1].key;
            d_aggregates_out[tail_key] = reduction_op(block_total.value, d_fixup_in[tail_key]);
        }
    }


    /**
     * Scan tiles of items as part of a dynamic chained scan.
     *
     * Each thread block processes the single tile whose index is derived from
     * its (blockIdx.x, blockIdx.y) coordinates.  A block whose tile starts at or
     * beyond num_items does nothing.  The num_tiles argument is accepted but not
     * read by this function.
     */
    __device__ __forceinline__ void ConsumeRange(
        int                 num_items,          ///< Total number of input items
        int                 num_tiles,          ///< Total number of input tiles
        ScanTileStateT&     tile_state)         ///< Global tile state descriptor
    {
        // Blocks are launched in increasing order, so just assign one tile per block
        int     tile_idx        = (blockIdx.x * gridDim.y) + blockIdx.y;    // Current tile index

        // Widen the tile index before multiplying so that the product is formed in
        // OffsetT rather than in int.  When OffsetT is wider than int this keeps a
        // large tile index (e.g. a surplus block of the 2D grid) from overflowing
        // the intermediate; when it is not wider the result is unchanged.
        OffsetT tile_offset     = static_cast<OffsetT>(tile_idx) * TILE_ITEMS;  // Global offset for the current tile
        OffsetT num_remaining   = num_items - tile_offset;                      // Remaining items (including this tile)

        if (num_remaining > TILE_ITEMS)
        {
            // Not the last tile (full)
            ConsumeTile<false>(num_remaining, tile_idx, tile_offset, tile_state, Int2Type<USE_ATOMIC_FIXUP>());
        }
        else if (num_remaining > 0)
        {
            // The last tile (possibly partially-full)
            ConsumeTile<true>(num_remaining, tile_idx, tile_offset, tile_state, Int2Type<USE_ATOMIC_FIXUP>());
        }
    }

};


}               // CUB namespace



( run in 0.941 second using v1.01-cache-2.11-cpan-39bf76dae61 )