/*******************************************************************************
* Copyright 2023 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

//@HEADER
// ***************************************************
//
// HPCG: High Performance Conjugate Gradient Benchmark
//
// Contact:
// Michael A. Heroux ( maherou@sandia.gov)
// Jack Dongarra     (dongarra@eecs.utk.edu)
// Piotr Luszczek    (luszczek@eecs.utk.edu)
//
// ***************************************************
//@HEADER

#pragma once

// Shorthand that expands to an `unroll` pragma; being a _Pragma operator it can
// also be used from inside other macros.
#define ESIMD_UNROLL _Pragma("unroll")
// NOTE(review): not referenced in the visible part of this file; presumably an
// upper bound on the number of vectors per ESB block consumed by the headers
// included below - confirm against esb_unrolls.hpp.
#define MAX_NUM_VEC_ESB 27

#include "../Helpers.hpp"
#include "../EsimdHelpers.hpp"
#include "esb_unrolls.hpp"

// Kernel selection switches. While a macro below is defined, the matching
// function in this file uses the unrolled kernel written inline here; comment
// the macro out to build the kernel from the corresponding included .hxx file
// instead.
// If not using unrolled kernels disable large grf mode in makefile
// Selects the kernel used by sparse_esb3_mv_esimd.
#define USE_MV_UNROLL_KERNELS
// Selects the kernel used by sparse_esb4_lbmv_update_esimd.
#define USE_TRMV_L_UNROLL_KERNELS
// Selects the kernel used by sparse_esb4_ubmv_esimd.
#define USE_TRMV_R_UNROLL_KERNELS

// for naming kernels in cgh.parallel_for
// These class templates are only declared, never defined: they serve purely as
// unique kernel-name types.
// SpGEMV
// Kernel name for sparse_esb3_mv_esimd; note that it takes BOTH the block size
// and the withDot flag, so every use must supply two template arguments.
template <int BLOCK_SIZE, bool withDot>
class esb3_mv_esimd_kernel;

// SpTRMV
// Kernel name for sparse_esb4_lbmv_update_esimd.
template <int BLOCK_SIZE>
class esb4_lbmv_update_esimd_kernel;

// Kernel name for sparse_esb4_ubmv_esimd.
template <int BLOCK_SIZE>
class esb4_ubmv_esimd_kernel;

// Short alias for the ESIMD LSC atomic operation enumeration (used for the
// fadd atomic update in sparse_esb3_mv_esimd).
using LSCAtomicOp = sycl::ext::intel::esimd::native::lsc::atomic_op;

//
// gemv (optionally w/ fused x*y dot product) using LSC esimd api with ESB format
//
// withDot == false:     y = A * x
// withDot == true:      y = A * x, xAx = x * A * x
//
// assuming 0 based indexing and
// FPTYPE = double, INTTYPE = local_int_t
//
// Parameters:
//   queue         queue that receives the (optional) memset and the kernel
//   nrows         number of rows; not referenced by the unrolled kernel below
//   nBlocks       number of BLOCK_SIZE-row blocks; one work-item per block
//   blockptr      blockptr[b] and blockptr[b + 1] are the vector range bounds
//                 handed to mv_unroll_dispatch for block b
//   colind/values ESB column indices and values, forwarded to the dispatch helper
//   x             input vector
//   y             output vector; rows [b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE) are
//                 written for every processed block b
//   xAx           only used when withDot is true: a single double that is zeroed
//                 and then accumulated into with atomic adds
//   reorder       block permutation, only read when applyReorder is true
//   applyReorder  when true, work-item i processes block reorder[i] instead of i
//   dependencies  events that must complete before any work submitted here starts
//
// Returns the event of the kernel submission.
//
// NOTE(review): all pointers are presumably device-accessible USM allocations
// (they are passed to queue.memset and dereferenced inside the kernel) - confirm
// against the callers.
//
template <int BLOCK_SIZE, bool withDot>
sycl::event sparse_esb3_mv_esimd(sycl::queue &queue,
                                 const local_int_t nrows,
                                 const local_int_t nBlocks,
                                 const local_int_t *blockptr,
                                 const local_int_t *colind,
                                 const double *values,
                                 const double *x,
                                 double *y,
                                 double *xAx,
                                 const local_int_t *reorder,
                                 const bool applyReorder,
                                 const std::vector<sycl::event> &dependencies)
{
    // Each thread does MV on a block of BLOCK_SIZE rows.

    //printf("calling sparse_esb3_mv_esimd: nrows = %d, nBlocks = %d\n", nrows, nBlocks); fflush(0);

    // Stays default-constructed unless the dot product is requested.
    sycl::event evt;

    if constexpr (withDot) {
        // The kernel only ever adds into *xAx, so it has to start from zero.
        evt = queue.memset(xAx, 0, sizeof(double), dependencies);
    }

    auto last = queue.submit([&](sycl::handler &cgh) {
        cgh.depends_on(dependencies);
        cgh.depends_on(evt);

#ifndef USE_MV_UNROLL_KERNELS

        // Non-unrolled path: the kernel lambda mv_esb3_esimd_kernel is expected
        // to be defined by the included .hxx file.
#undef ADD_DIAG
#undef UPDATE_Y
#include "./mv_esb3_esimd_kernel.hxx"

        // The kernel-name template is declared with two parameters
        // (BLOCK_SIZE, withDot), so both must be supplied here as well.
        cgh.parallel_for<class esb3_mv_esimd_kernel<BLOCK_SIZE, withDot>>(
            sycl::range<1>(nBlocks), mv_esb3_esimd_kernel);
#else
        const local_int_t nWG = 2; // Number of hw threads per WG

        auto mv_esb3_esimd_kernel = [=](sycl::nd_item<1> item) SYCL_ESIMD_KERNEL
            {
                // The launch size is rounded up to a multiple of nWG, so the
                // surplus work-items have nothing to do.
                local_int_t block = item.get_global_id(0);
                if (block >= nBlocks) return;

                if (applyReorder) block = reorder[block];

                // Accumulator for this block's BLOCK_SIZE rows of A * x.
                esimd::simd<double, BLOCK_SIZE> y_vec(0);

                // Deliberately left non-const: the helper signatures are not
                // visible here and may take these by reference.
                local_int_t start_row = block * BLOCK_SIZE;
                local_int_t st_vec = blockptr[block];
                local_int_t en_vec = blockptr[block + 1];

                mv_unroll_dispatch<BLOCK_SIZE>(st_vec, en_vec, y_vec, values, colind, x);

                esimd_lsc_block_store<double, local_int_t, BLOCK_SIZE, uc, uc>(y, start_row, y_vec);

                if constexpr (withDot) {
                    // Fused dot product: reduce x .* (A * x) over this block's
                    // rows and add the partial sum to the shared scalar.
                    auto x_vec = esimd_lsc_block_load<double, local_int_t, BLOCK_SIZE, ca, ca>(x, start_row);
                    y_vec = y_vec * x_vec;
                    auto res = esimd::reduce<double>(y_vec, std::plus<>());
                    sycl::ext::intel::esimd::atomic_update<LSCAtomicOp::fadd, double, 1>(xAx, 0, res);
                }

            };

        cgh.parallel_for<class esb3_mv_esimd_kernel<BLOCK_SIZE, withDot>>(
            sycl::nd_range<1>(ceil_div(nBlocks, nWG) * nWG, nWG), mv_esb3_esimd_kernel);
#endif
    });
    return last;
}

// (L+B) matrix-vector update on an ESB matrix: every work-item loads one
// BLOCK_SIZE-row slice of y, lets the dispatch helper(s) update it and stores
// the slice back to y.
//
// Parameters:
//   queue                    queue that receives the kernel
//   nrows                    number of rows, forwarded to the dispatch helper
//   nBlocks                  number of BLOCK_SIZE-row blocks
//   blockptr_st/blockptr_en  per-block vector range for the lower (L) pass
//   nonloc_st/nonloc_en      per-block vector range for the B pass; only read
//                            when HPCG_NO_MPI is not defined
//   colind/values            ESB column indices and values
//   x                        input vector
//   y                        vector updated in place
//   reorder, applyReorder    optional block permutation (work-item i handles
//                            block reorder[i] when applyReorder is true)
//   dependencies             events that must complete before the kernel starts
//
// Returns the event of the kernel submission.
template <int BLOCK_SIZE>
sycl::event sparse_esb4_lbmv_update_esimd(sycl::queue &queue,
                                           const local_int_t nrows,
                                           const local_int_t nBlocks,
                                           const local_int_t *blockptr_st, // for lower_st
                                           const local_int_t *blockptr_en, // for lower_en
                                           const local_int_t *nonloc_st,
                                           const local_int_t *nonloc_en,
                                           const local_int_t *colind,
                                           const double *values,
                                           const double *x,
                                           double *y,
                                           const local_int_t *reorder,
                                           const bool applyReorder,
                                           const std::vector<sycl::event> &dependencies)
{
    // Debug aid:
    //printf("calling sparse_esb4_lbmv_update_esimd: nrows = %d, nBlocks = %d\n", nrows, nBlocks); fflush(0);

    return queue.submit([&](sycl::handler &cgh) {
        cgh.depends_on(dependencies);

#ifndef USE_TRMV_L_UNROLL_KERNELS
        // Non-unrolled path: mv_esb4_esimd_kernel is expected to be defined by
        // the included .hxx file (configured through TRMV_LB).
        const local_int_t nWG = 16; // work-group size of the launch below
        const local_int_t unroll = 1; // divisor applied to nBlocks when sizing the launch

#define TRMV_LB
#include "./mv_esb4_esimd_kernel.hxx"
#undef TRMV_LB
#else
        const local_int_t nWG = 2; // work-group size of the launch below
        const local_int_t unroll = 1; // divisor applied to nBlocks when sizing the launch
        auto mv_esb4_esimd_kernel = [=](sycl::nd_item<1> item) SYCL_ESIMD_KERNEL
            {
                // The launch is padded to a multiple of nWG; padding work-items exit.
                local_int_t blk = item.get_global_id(0);
                if (blk >= nBlocks) {
                    return;
                }
                if (applyReorder) {
                    blk = reorder[blk];
                }

                local_int_t row0 = blk * BLOCK_SIZE;

                // Row index of every lane (row0 + {0, 1, ...}); handed to the
                // dispatch helper, which per the original note uses it to mask
                // out the elements belonging to U.
                esimd::simd<local_int_t, BLOCK_SIZE> row_ids(0,1);
                row_ids += row0;

                // Current values of y for this block; updated in place below.
                esimd::simd<double, BLOCK_SIZE> acc(0);
                acc = esimd_lsc_block_load<double, local_int_t, BLOCK_SIZE, ca, uc>(y, row0);

                // First pass over the lower range (maskL).
                local_int_t vec_lo = blockptr_st[blk];
                local_int_t vec_hi = blockptr_en[blk];
                trmv_lbmv_unroll_dispatch<BLOCK_SIZE, maskL>(
                    vec_lo, vec_hi, acc, values, colind, x, nrows, row_ids);
#ifndef HPCG_NO_MPI
                // Second pass over the non-local range (maskB), MPI builds only.
                vec_lo = nonloc_st[blk];
                vec_hi = nonloc_en[blk];
                trmv_lbmv_unroll_dispatch<BLOCK_SIZE, maskB>(
                    vec_lo, vec_hi, acc, values, colind, x, nrows, row_ids);
#endif
                esimd_lsc_block_store<double, local_int_t, BLOCK_SIZE, uc, uc>(y, row0, acc);
            };
#endif
        cgh.parallel_for<class esb4_lbmv_update_esimd_kernel<BLOCK_SIZE>>(
            sycl::nd_range<1>(ceil_div(nBlocks / unroll, nWG) * nWG, nWG), mv_esb4_esimd_kernel);
    });
}

// Upper-part matrix-vector step on an ESB matrix. Per block of BLOCK_SIZE rows
// a work-item produces two results (as described by the original comments):
//   y  = r - (U + B) * x
//   y1 = U * x
//
// Parameters:
//   queue                    queue that receives the kernel
//   nrows                    number of rows, forwarded to the dispatch helper
//   nBlocks                  number of BLOCK_SIZE-row blocks
//   blockptr_st/blockptr_en  per-block vector range handed to the dispatch helper
//   colind/values            ESB column indices and values
//   x                        input vector
//   r                        right-hand side the (U + B) product is subtracted from
//   y, y1                    output vectors (see above)
//   reorder, applyReorder    optional block permutation (work-item i handles
//                            block reorder[i] when applyReorder is true)
//   dependencies             events that must complete before the kernel starts
//
// Returns the event of the kernel submission.
template <int BLOCK_SIZE>
sycl::event sparse_esb4_ubmv_esimd(sycl::queue &queue,
                                   const local_int_t nrows,
                                   const local_int_t nBlocks,
                                   const local_int_t *blockptr_st,
                                   const local_int_t *blockptr_en,
                                   const local_int_t *colind,
                                   const double *values,
                                   const double *x,
                                   const double *r,
                                   double *y,
                                   double *y1,
                                   const local_int_t *reorder,
                                   const bool applyReorder,
                                   const std::vector<sycl::event> &dependencies)
{
    // Debug aid:
    //printf("calling sparse_esb4_ubmv_esimd: nrows = %d, nBlocks = %d\n", nrows, nBlocks); fflush(0);

    return queue.submit([&](sycl::handler &cgh) {
        cgh.depends_on(dependencies);

#ifndef USE_TRMV_R_UNROLL_KERNELS
        // Non-unrolled path: mv_esb4_esimd_kernel is expected to be defined by
        // the included .hxx file (configured through TRMV_UB).
        // NOTE(review): with unroll = 2 the launch size below uses the truncating
        // division nBlocks / unroll; verify against the .hxx kernel that a
        // trailing block is not skipped when nBlocks is odd.
        const local_int_t nWG = 16; // work-group size of the launch below
        const local_int_t unroll = 2; // divisor applied to nBlocks when sizing the launch
#define TRMV_UB
#include "./mv_esb4_esimd_kernel.hxx"
#undef TRMV_UB
#else
        const local_int_t nWG = 4; // work-group size of the launch below
        const local_int_t unroll = 1; // divisor applied to nBlocks when sizing the launch

        auto mv_esb4_esimd_kernel = [=](sycl::nd_item<1> item) SYCL_ESIMD_KERNEL
            {
                // The launch is padded to a multiple of nWG; padding work-items exit.
                local_int_t blk = item.get_global_id(0);
                if (blk >= nBlocks) {
                    return;
                }
                if (applyReorder) {
                    blk = reorder[blk];
                }

                local_int_t row0 = blk * BLOCK_SIZE;
                local_int_t vec_lo = blockptr_st[blk];
                local_int_t vec_hi = blockptr_en[blk];

                // Row index of every lane (row0 + {0, 1, ...}); handed to the
                // dispatch helper, which per the original note uses it to mask
                // out the elements belonging to L + D.
                esimd::simd<local_int_t, BLOCK_SIZE> row_ids(0,1);
                row_ids += row0;

                // ub_acc receives the (U + B) product, u_acc the U-only product.
                esimd::simd<double, BLOCK_SIZE> ub_acc(0);
                esimd::simd<double, BLOCK_SIZE> u_acc(0);

                auto rhs = esimd_lsc_block_load<double, local_int_t, BLOCK_SIZE, uc, uc>(r, row0);

                trmv_ubmv_unroll_dispatch<BLOCK_SIZE>(
                    vec_lo, vec_hi, ub_acc, u_acc, values, colind, x, nrows, blockptr_st, blk, row_ids);
                // Alternative generic dispatch, kept for reference:
                //trmv_ubmv_unroll_generic<BLOCK_SIZE>(vec_lo, vec_hi, ub_acc, u_acc, values, colind, x, nrows, row_ids);

                // y1 gets the U-only product ...
                esimd_lsc_block_store<double, local_int_t, BLOCK_SIZE, uc, uc>(y1, row0, u_acc);
                // ... and y gets r - (U + B)x.
                ub_acc = rhs - ub_acc;
                esimd_lsc_block_store<double, local_int_t, BLOCK_SIZE, uc, uc>(y, row0, ub_acc);
            };
#endif
        cgh.parallel_for<class esb4_ubmv_esimd_kernel<BLOCK_SIZE>>(
            sycl::nd_range<1>(ceil_div(nBlocks / unroll, nWG) * nWG, nWG), mv_esb4_esimd_kernel);
    });
}
