[Mirror] MueLu: Kokkosify sparse approximate inverse and GetThresholded* #112

Merged
csiefer2 merged 3 commits into develop from pr-mirror-15199
Apr 29, 2026

Conversation

@csiefer2
Owner

Automated mirror of upstream PR trilinos#15199 @trilinos/muelu

@github-actions

CDash for AT1 results [Only accessible from Sandia networks]
CDash for AT2 results [Currently only accessible from Sandia networks]

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request modernizes several MueLu components by replacing serial Teuchos-based implementations with Kokkos-parallel kernels. Key changes include the introduction of a Kokkos-based SPAI functor in InverseApproximationFactory and the parallelization of matrix and graph thresholding utilities in UtilitiesBase. Additionally, the expectedNNZperRow parameter has been removed from various constructors and function signatures to simplify the API. Feedback for this PR highlights a logic error in the diagOffset calculation within the new SPAI functor, suggests using binary search for performance optimization in local index lookups, and recommends using size_t for scratch memory size calculations to prevent potential integer overflows.

Comment on lines +155 to +172
    local_ordinal_type diagOffset = 0;
    {
      // Sort
      Kokkos::Experimental::sort_team(thread, Kokkos::subview(column_indices, Kokkos::make_pair(0, numColEntries)));
      // Merge
      if (numColEntries > 0)
        ++numUniqeColEntries;
      local_ordinal_type pos = 0;
      for (local_ordinal_type m = 1; m < numColEntries; ++m) {
        if (column_indices(pos) != column_indices(m)) {
          column_indices(pos + 1) = column_indices(m);
          ++pos;
          ++numUniqeColEntries;
          if (column_indices(pos) == rlid)
            diagOffset = pos;
        }
      }
    }

high

The logic for finding diagOffset is incorrect. It is initialized to 0, which means if the diagonal entry rlid is not found in column_indices, it will incorrectly default to the first entry in the merged list. Furthermore, the first element column_indices(0) is never checked against rlid because the loop starts at m = 1. This can lead to incorrect results in the SPAI calculation if the diagonal is missing or not at the first position.

    local_ordinal_type numUniqeColEntries = 0;
    local_ordinal_type diagOffset         = -1;
    {
      // Sort
      Kokkos::Experimental::sort_team(thread, Kokkos::subview(column_indices, Kokkos::make_pair(0, numColEntries)));
      // Merge
      if (numColEntries > 0) {
        ++numUniqeColEntries;
        if (column_indices(0) == rlid) diagOffset = 0;
      }
      local_ordinal_type pos = 0;
      for (local_ordinal_type m = 1; m < numColEntries; ++m) {
        if (column_indices(pos) != column_indices(m)) {
          column_indices(pos + 1) = column_indices(m);
          ++pos;
          ++numUniqeColEntries;
          if (column_indices(pos) == rlid) diagOffset = pos;
        }
      }
    }
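
For readers without a Kokkos build at hand, the suggested fix can be sketched in plain C++. This is only a serial stand-in, not the functor from the PR: `std::sort` replaces the team-level sort, and the names `MergeResult` and `sortMergeAndFindDiag` are hypothetical. It shows the two points made above: the sentinel `-1` for a missing diagonal, and an explicit check of the first entry.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Result of compacting a list of column indices (hypothetical helper type).
struct MergeResult {
  int numUnique;   // number of distinct column indices kept at the front
  int diagOffset;  // position of the row's own index, or -1 if absent
};

// Sorts `cols`, compacts duplicates to the front, and records where `rlid`
// ends up. Serial stand-in for the team-level sort and merge in the review.
inline MergeResult sortMergeAndFindDiag(std::vector<int>& cols, int rlid) {
  std::sort(cols.begin(), cols.end());
  MergeResult r{0, -1};  // -1 marks "diagonal not found"
  if (cols.empty()) return r;
  r.numUnique = 1;
  if (cols[0] == rlid) r.diagOffset = 0;  // the first entry is checked too
  int pos = 0;
  for (std::size_t m = 1; m < cols.size(); ++m) {
    if (cols[pos] != cols[m]) {
      cols[++pos] = cols[m];  // move the next distinct index forward
      ++r.numUnique;
      if (cols[pos] == rlid) r.diagOffset = pos;
    }
  }
  return r;
}
```

A caller can then test `diagOffset < 0` and handle the missing-diagonal case explicitly instead of silently using position 0.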

Comment on lines +187 to +192
    for (local_ordinal_type m = 0; m < numUniqeColEntries; ++m) {
      if (column_indices(m) == j) {
        localA(m, ii) = v;
        break;
      }
    }

medium

Since column_indices is sorted, a binary search should be used instead of a sequential search to find the local index. This will significantly improve performance when numUniqeColEntries is large.

        local_ordinal_type low = 0, high = numUniqeColEntries - 1;
        while (low <= high) {
          local_ordinal_type mid = low + (high - low) / 2;
          if (column_indices(mid) == j) {
            localA(mid, ii) = v;
            break;
          }
          if (column_indices(mid) < j) low = mid + 1;
          else high = mid - 1;
        }
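
On the host, the same lookup can be expressed with `std::lower_bound`. The sketch below is illustrative only: `findLocalIndex` is a hypothetical helper, and inside a device kernel a hand-written loop like the one above would likely still be needed, since standard algorithms are not generally callable there. It assumes the first `numUnique` entries are sorted and free of duplicates.

```cpp
#include <algorithm>
#include <vector>

// Returns the position of `j` in the first `numUnique` entries of the sorted,
// duplicate-free array `cols`, or -1 if `j` is not present.
inline int findLocalIndex(const std::vector<int>& cols, int numUnique, int j) {
  auto first = cols.begin();
  auto last  = cols.begin() + numUnique;
  auto it    = std::lower_bound(first, last, j);  // O(log n) comparisons
  if (it != last && *it == j) return static_cast<int>(it - first);
  return -1;
}
```

Whether this pays off depends on row lengths: for the short rows typical of sparse matrices, a linear scan may be just as fast, so the gain is worth measuring.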

using shared_vector = typename decltype(spaiFunctor)::shared_vector;
using shared_lo_vector = typename decltype(spaiFunctor)::shared_lo_vector;

int size = shared_matrix::shmem_size(maxUniqueColEntries, maxRowEntriesAinv) + shared_matrix::shmem_size(maxUniqueColEntries, 1) + 2 * shared_vector::shmem_size(maxUniqueColEntries) + shared_lo_vector::shmem_size(maxUniqueColEntries);

medium

The variable size should be of type size_t to avoid potential integer overflow when calculating the total scratch memory size, especially since shmem_size returns size_t and maxUniqueColEntries can be large.

    size_t size = shared_matrix::shmem_size(maxUniqueColEntries, maxRowEntriesAinv) + shared_matrix::shmem_size(maxUniqueColEntries, 1) + 2 * shared_vector::shmem_size(maxUniqueColEntries) + shared_lo_vector::shmem_size(maxUniqueColEntries);
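
To make the overflow concern concrete, here is a rough sketch that accumulates a byte count in `std::size_t` and checks whether it would still fit in an `int`. The helpers `scratchBytes`, `totalScratchBytes`, and `fitsInInt` are hypothetical and do not reproduce what `shmem_size` returns (alignment and padding are ignored); the sketch also assumes a 32-bit `int` and a 64-bit `std::size_t`.

```cpp
#include <cstddef>
#include <limits>

// Hypothetical per-array scratch size: `count` entries of `bytesPerEntry` each.
inline std::size_t scratchBytes(std::size_t count, std::size_t bytesPerEntry) {
  return count * bytesPerEntry;
}

// Total request for a dense rows-by-cols block plus three vectors of length
// `rows`, accumulated in std::size_t so the sum is not narrowed along the way.
inline std::size_t totalScratchBytes(std::size_t rows, std::size_t cols) {
  return scratchBytes(rows * cols, sizeof(double)) + 3 * scratchBytes(rows, sizeof(double));
}

// True if `bytes` can be stored in an int without truncation.
inline bool fitsInInt(std::size_t bytes) {
  return bytes <= static_cast<std::size_t>(std::numeric_limits<int>::max());
}
```

A guard such as `fitsInInt(total)` is one way to fail loudly if the value is later passed to an interface that takes `int`.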

cgcgcg added 3 commits April 29, 2026 09:08
Signed-off-by: Christian Glusa <caglusa@sandia.gov>
Signed-off-by: Christian Glusa <caglusa@sandia.gov>
Signed-off-by: Christian Glusa <caglusa@sandia.gov>
@cgcgcg

cgcgcg commented Apr 29, 2026

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request refactors MueLu's Sparse Approximate Inverse (SPAI) calculation and thresholding utilities to utilize Kokkos parallel kernels, replacing legacy serial loops. It also simplifies the ThresholdAFilterFactory by removing the expectedNNZperRow parameter and updates logging to include global constant computation. Critical feedback was provided regarding the parallel implementation of GetThresholdedMatrix and GetThresholdedGraph, where the parallel_scan indexing for rowptr is incorrect and the parallel_for logic for filling entries contains race conditions or improper offset handling.

Comment on lines +106 to +140
    Kokkos::parallel_scan(
        range_type(0, lclA.numRows()), KOKKOS_LAMBDA(const LocalOrdinal rlid, LocalOrdinal& my_nnz, const bool is_final) {
          auto row = lclA.rowConst(rlid);
          auto rclid = lclColmap.getLocalElement(lclRowmap.getGlobalElement(rlid));

          for (LocalOrdinal offset = 0; offset < row.length; ++offset) {
            auto clid = row.colidx(offset);
            auto val = row.value(offset);
            if ((rclid == clid) || implATS::magnitude(val) > threshold) {
              ++my_nnz;
              if (is_final && (rlid + 1 < lclA.numRows())) {
                rowptr(rlid + 2) = my_nnz;
              }
            }
          }
        },
        nnz);

    entries_type entries(Kokkos::ViewAllocateWithoutInitializing("MueLu::GetThresholdedGraph::indices"), nnz);
    values_type values(Kokkos::ViewAllocateWithoutInitializing("MueLu::GetThresholdedGraph::values"), nnz);
    Kokkos::parallel_for(
        range_type(0, lclA.numRows()), KOKKOS_LAMBDA(const LocalOrdinal rlid) {
          auto row = lclA.rowConst(rlid);
          auto rclid = lclColmap.getLocalElement(lclRowmap.getGlobalElement(rlid));

          for (LocalOrdinal offset = 0; offset < row.length; ++offset) {
            auto clid = row.colidx(offset);
            auto val = row.value(offset);
            if ((rclid == clid) || implATS::magnitude(val) > threshold) {
              entries(rowptr(rlid + 1)) = clid;
              values(rowptr(rlid + 1)) = val;
              ++rowptr(rlid + 1);
            }
          }
        });

critical

The Kokkos::parallel_scan and subsequent Kokkos::parallel_for logic for constructing the rowptr, entries, and values arrays contains critical errors.

  1. Incorrect rowptr population in parallel_scan: The line rowptr(rlid + 2) = my_nnz; within the is_final block is incorrectly indexing rowptr. For an exclusive prefix sum, rowptr(rlid) should store the accumulated count up to the previous row, and rowptr(rlid + 1) would store the accumulated count up to the current row. The +2 offset is likely an off-by-one error, leading to an improperly constructed rowptr array.
  2. Race condition in parallel_for: The parallel_for loop modifies the shared rowptr array using ++rowptr(rlid + 1);. This creates a race condition, as multiple threads will attempt to modify the same shared memory location concurrently, leading to incorrect results and potential data corruption. rowptr should be treated as read-only within the parallel_for, and each thread should use a local counter to track its current write position for its respective row.

These issues will result in an incorrectly constructed Xpetra::CrsMatrix and are critical for the correctness of the application. The suggested code below provides a standard and correct pattern for using parallel_scan to compute exclusive prefix sums for rowptr and then using parallel_for with local offsets to fill the entries and values arrays.

      rowmap_type rowptr("MueLu::GetThresholdedMatrix::rowptr", lclA.numRows() + 1);
      LocalOrdinal total_nnz = 0;
      Kokkos::parallel_scan(
          range_type(0, lclA.numRows()), KOKKOS_LAMBDA(const LocalOrdinal rlid, LocalOrdinal& update, const bool is_final) {
            LocalOrdinal current_row_nnz = 0;
            auto row   = lclA.rowConst(rlid);
            auto rclid = lclColmap.getLocalElement(lclRowmap.getGlobalElement(rlid));

            for (LocalOrdinal offset = 0; offset < row.length; ++offset) {
              auto clid = row.colidx(offset);
              auto val  = row.value(offset);
              if ((rclid == clid) || implATS::magnitude(val) > threshold) {
                current_row_nnz++;
              }
            }
            if (is_final) {
              rowptr(rlid) = update;
            }
            update += current_row_nnz;
          },
          total_nnz);
      rowptr(lclA.numRows()) = total_nnz;

      entries_type entries(Kokkos::ViewAllocateWithoutInitializing("MueLu::GetThresholdedGraph::indices"), total_nnz);
      values_type values(Kokkos::ViewAllocateWithoutInitializing("MueLu::GetThresholdedGraph::values"), total_nnz);
      Kokkos::parallel_for(
          range_type(0, lclA.numRows()), KOKKOS_LAMBDA(const LocalOrdinal rlid) {
            LocalOrdinal current_row_offset = rowptr(rlid);
            auto row   = lclA.rowConst(rlid);
            auto rclid = lclColmap.getLocalElement(lclRowmap.getGlobalElement(rlid));

            for (LocalOrdinal offset = 0; offset < row.length; ++offset) {
              auto clid = row.colidx(offset);
              auto val  = row.value(offset);
              if ((rclid == clid) || implATS::magnitude(val) > threshold) {
                entries(current_row_offset) = clid;
                values(current_row_offset)  = val;
                current_row_offset++;
              }
            }
          });
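
The count, scan, fill pattern described in this suggestion can be tried out serially with the standard library. The sketch below is an illustration under simplifying assumptions, not the MueLu code: `Csr` and `thresholdCsr` are hypothetical names, the matrix is square with local column indices equal to row indices on the diagonal, and `A.rowptr` is assumed to be a valid, non-empty CSR offset array. It needs C++17 for `std::exclusive_scan`.

```cpp
#include <cmath>
#include <cstddef>
#include <numeric>
#include <vector>

// Minimal CSR container for this sketch (hypothetical, square matrix).
struct Csr {
  std::vector<std::size_t> rowptr;  // size numRows + 1
  std::vector<int> colind;
  std::vector<double> values;
};

// Keeps the diagonal and every entry with |value| > threshold.
inline Csr thresholdCsr(const Csr& A, double threshold) {
  const std::size_t numRows = A.rowptr.size() - 1;
  auto keep = [&](std::size_t row, std::size_t k) {
    return A.colind[k] == static_cast<int>(row) || std::abs(A.values[k]) > threshold;
  };

  // Pass 1: count kept entries per row, then take an exclusive prefix sum so
  // that rowptr[r] is the first output slot of row r and rowptr[numRows] is
  // the total number of kept entries.
  Csr B;
  B.rowptr.assign(numRows + 1, 0);
  for (std::size_t r = 0; r < numRows; ++r)
    for (std::size_t k = A.rowptr[r]; k < A.rowptr[r + 1]; ++k)
      if (keep(r, k)) ++B.rowptr[r];
  std::exclusive_scan(B.rowptr.begin(), B.rowptr.end(), B.rowptr.begin(), std::size_t{0});

  // Pass 2: each row writes through its own local cursor; rowptr is only read.
  B.colind.resize(B.rowptr[numRows]);
  B.values.resize(B.rowptr[numRows]);
  for (std::size_t r = 0; r < numRows; ++r) {
    std::size_t cursor = B.rowptr[r];
    for (std::size_t k = A.rowptr[r]; k < A.rowptr[r + 1]; ++k)
      if (keep(r, k)) {
        B.colind[cursor] = A.colind[k];
        B.values[cursor] = A.values[k];
        ++cursor;
      }
  }
  return B;
}
```

Because the second pass never writes to `rowptr`, its rows are independent of each other, which is the property that makes it straightforward to run as a parallel loop.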

Comment on lines +144 to +177
    rowmap_type rowptr("MueLu::GetThresholdedMatrix::rowptr", lclA.numRows() + 1);
    LocalOrdinal nnz = 0;
    Kokkos::parallel_scan(
        range_type(0, lclA.numRows()), KOKKOS_LAMBDA(const LocalOrdinal rlid, LocalOrdinal& my_nnz, const bool is_final) {
          auto row = lclA.rowConst(rlid);

          for (LocalOrdinal offset = 0; offset < row.length; ++offset) {
            auto val = row.value(offset);
            if (implATS::magnitude(val) > threshold) {
              ++my_nnz;
              if (is_final && (rlid + 1 < lclA.numRows())) {
                rowptr(rlid + 2) = my_nnz;
              }
            }
          }
        },
        nnz);

    entries_type entries(Kokkos::ViewAllocateWithoutInitializing("MueLu::GetThresholdedGraph::indices"), nnz);
    values_type values(Kokkos::ViewAllocateWithoutInitializing("MueLu::GetThresholdedGraph::values"), nnz);
    Kokkos::parallel_for(
        range_type(0, lclA.numRows()), KOKKOS_LAMBDA(const LocalOrdinal rlid) {
          auto row = lclA.rowConst(rlid);

          for (LocalOrdinal offset = 0; offset < row.length; ++offset) {
            auto clid = row.colidx(offset);
            auto val = row.value(offset);
            if (implATS::magnitude(val) > threshold) {
              entries(rowptr(rlid + 1)) = clid;
              values(rowptr(rlid + 1)) = val;
              ++rowptr(rlid + 1);
            }
          }
        });

critical

This else branch for GetThresholdedMatrix suffers from the same critical issues as the if (keepDiagonal) branch regarding the Kokkos::parallel_scan and Kokkos::parallel_for logic for populating rowptr, entries, and values. Please apply the same correction pattern as suggested for lines 106-140 to ensure correct matrix construction.

      rowmap_type rowptr("MueLu::GetThresholdedMatrix::rowptr", lclA.numRows() + 1);
      LocalOrdinal total_nnz = 0;
      Kokkos::parallel_scan(
          range_type(0, lclA.numRows()), KOKKOS_LAMBDA(const LocalOrdinal rlid, LocalOrdinal& update, const bool is_final) {
            LocalOrdinal current_row_nnz = 0;
            auto row = lclA.rowConst(rlid);

            for (LocalOrdinal offset = 0; offset < row.length; ++offset) {
              auto val = row.value(offset);
              if (implATS::magnitude(val) > threshold) {
                current_row_nnz++;
              }
            }
            if (is_final) {
              rowptr(rlid) = update;
            }
            update += current_row_nnz;
          },
          total_nnz);
      rowptr(lclA.numRows()) = total_nnz;

      entries_type entries(Kokkos::ViewAllocateWithoutInitializing("MueLu::GetThresholdedGraph::indices"), total_nnz);
      values_type values(Kokkos::ViewAllocateWithoutInitializing("MueLu::GetThresholdedGraph::values"), total_nnz);
      Kokkos::parallel_for(
          range_type(0, lclA.numRows()), KOKKOS_LAMBDA(const LocalOrdinal rlid) {
            LocalOrdinal current_row_offset = rowptr(rlid);
            auto row = lclA.rowConst(rlid);

            for (LocalOrdinal offset = 0; offset < row.length; ++offset) {
              auto clid = row.colidx(offset);
              auto val  = row.value(offset);
              if (implATS::magnitude(val) > threshold) {
                entries(current_row_offset) = clid;
                values(current_row_offset)  = val;
                current_row_offset++;
              }
            }
          });

Comment on lines +209 to +261
    LocalOrdinal nnz = 0;
    Kokkos::parallel_scan(
        range_type(0, lclA.numRows()), KOKKOS_LAMBDA(const LocalOrdinal rlid, LocalOrdinal& my_nnz, const bool is_final) {
          auto row = lclA.rowConst(rlid);
          auto rclid = lclColmap.getLocalElement(lclRowmap.getGlobalElement(rlid));

          typename implATS::magnitudeType d = magATS::one();
          for (LocalOrdinal offset = 0; offset < row.length; ++offset) {
            auto clid = row.colidx(offset);
            if (rclid == clid) {
              auto val = implATS::magnitude(row.value(offset));
              if (val > implATS::epsilon())
                d = val;
            }
          }

          GlobalOrdinal globalRow = A->getRowMap()->getGlobalElement(row);
          LocalOrdinal col = A->getColMap()->getLocalElement(globalRow);
          for (LocalOrdinal offset = 0; offset < row.length; ++offset) {
            auto clid = row.colidx(offset);
            auto val = row.value(offset);
            if ((rclid == clid) || implATS::magnitude(val) > d * threshold) {
              ++my_nnz;
              if (is_final && (rlid + 1 < lclA.numRows())) {
                rowptr(rlid + 2) = my_nnz;
              }
            }
          }
        },
        nnz);

    const Scalar Dk = STS::magnitude(D[col]) > 0.0 ? STS::magnitude(D[col]) : 1.0;
    Array<GlobalOrdinal> indicesNew;
    entries_type entries(Kokkos::ViewAllocateWithoutInitializing("MueLu::GetThresholdedGraph::indices"), nnz);
    Kokkos::parallel_for(
        range_type(0, lclA.numRows()), KOKKOS_LAMBDA(const LocalOrdinal rlid) {
          auto row = lclA.rowConst(rlid);
          auto rclid = lclColmap.getLocalElement(lclRowmap.getGlobalElement(rlid));

          typename implATS::magnitudeType d = magATS::one();
          for (LocalOrdinal offset = 0; offset < row.length; ++offset) {
            auto clid = row.colidx(offset);
            if (rclid == clid) {
              auto val = implATS::magnitude(row.value(offset));
              if (val > implATS::epsilon())
                d = val;
            }
          }

          for (size_t i = 0; i < size_t(indices.size()); i++)
            // keep diagonal per default
            if (col == indices[i] || STS::magnitude(STS::squareroot(Dk) * vals[i] * STS::squareroot(Dk)) > STS::magnitude(threshold))
              indicesNew.append(A->getColMap()->getGlobalElement(indices[i]));
          for (LocalOrdinal offset = 0; offset < row.length; ++offset) {
            auto clid = row.colidx(offset);
            auto val = row.value(offset);
            if ((rclid == clid) || implATS::magnitude(val) > d * threshold) {
              entries(rowptr(rlid + 1)) = clid;
              ++rowptr(rlid + 1);
            }
          }

critical

The Kokkos::parallel_scan and Kokkos::parallel_for logic within GetThresholdedGraph for populating rowptr and entries suffers from the same critical issues as identified in GetThresholdedMatrix (lines 106-140).

  1. Incorrect rowptr population in parallel_scan: The rowptr(rlid + 2) = my_nnz; indexing is incorrect for computing exclusive prefix sums.
  2. Race condition in parallel_for: Modifying rowptr (++rowptr(rlid + 1);) within the parallel_for loop is a race condition.

These errors will lead to an incorrectly constructed Xpetra::CrsGraph. Please apply the same correction pattern as suggested for GetThresholdedMatrix to ensure the graph is built correctly.

    rowmap_type rowptr("MueLu::GetThresholdedGraph::rowptr", lclA.numRows() + 1);

    LocalOrdinal total_nnz = 0;
    Kokkos::parallel_scan(
        range_type(0, lclA.numRows()), KOKKOS_LAMBDA(const LocalOrdinal rlid, LocalOrdinal& update, const bool is_final) {
          LocalOrdinal current_row_nnz = 0;
          auto row   = lclA.rowConst(rlid);
          auto rclid = lclColmap.getLocalElement(lclRowmap.getGlobalElement(rlid));

          typename implATS::magnitudeType d = magATS::one();
          for (LocalOrdinal offset = 0; offset < row.length; ++offset) {
            auto clid = row.colidx(offset);
            if (rclid == clid) {
              auto val = implATS::magnitude(row.value(offset));
              if (val > implATS::epsilon())
                d = val;
            }
          }

          for (LocalOrdinal offset = 0; offset < row.length; ++offset) {
            auto clid = row.colidx(offset);
            auto val  = row.value(offset);
            if ((rclid == clid) || implATS::magnitude(val) > d * threshold) {
              current_row_nnz++;
            }
          }
          if (is_final) {
            rowptr(rlid) = update;
          }
          update += current_row_nnz;
        },
        total_nnz);
    rowptr(lclA.numRows()) = total_nnz;

    entries_type entries(Kokkos::ViewAllocateWithoutInitializing("MueLu::GetThresholdedGraph::indices"), total_nnz);
    Kokkos::parallel_for(
        range_type(0, lclA.numRows()), KOKKOS_LAMBDA(const LocalOrdinal rlid) {
          LocalOrdinal current_row_offset = rowptr(rlid);
          auto row   = lclA.rowConst(rlid);
          auto rclid = lclColmap.getLocalElement(lclRowmap.getGlobalElement(rlid));

          typename implATS::magnitudeType d = magATS::one();
          for (LocalOrdinal offset = 0; offset < row.length; ++offset) {
            auto clid = row.colidx(offset);
            if (rclid == clid) {
              auto val = implATS::magnitude(row.value(offset));
              if (val > implATS::epsilon())
                d = val;
            }
          }

          for (LocalOrdinal offset = 0; offset < row.length; ++offset) {
            auto clid = row.colidx(offset);
            auto val  = row.value(offset);
            if ((rclid == clid) || implATS::magnitude(val) > d * threshold) {
              entries(current_row_offset) = clid;
              current_row_offset++;
            }
          }
        });
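
The per-row keep rule used in both passes above (scale the threshold by the magnitude of the diagonal, falling back to 1 when the diagonal is absent or tiny) may be easier to follow in isolation. The sketch below is a plain C++ illustration; `filterRowRelative` is a hypothetical name, and it assumes that `implATS::epsilon()` corresponds to machine epsilon for the scalar type.

```cpp
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Returns the column indices of one row that survive the relative test
// |a_ij| > d * threshold, where d = |a_ii|, or 1 if the diagonal is missing
// or not larger than machine epsilon. The diagonal itself is always kept.
inline std::vector<int> filterRowRelative(int row, const std::vector<int>& cols,
                                          const std::vector<double>& vals, double threshold) {
  double d = 1.0;
  for (std::size_t k = 0; k < cols.size(); ++k)
    if (cols[k] == row) {
      const double mag = std::abs(vals[k]);
      if (mag > std::numeric_limits<double>::epsilon()) d = mag;
    }
  std::vector<int> kept;
  for (std::size_t k = 0; k < cols.size(); ++k)
    if (cols[k] == row || std::abs(vals[k]) > d * threshold) kept.push_back(cols[k]);
  return kept;
}
```

Since the counting pass and the fill pass must agree on which entries are kept, factoring the rule into one shared function like this is one way to keep the two kernels from drifting apart.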

@csiefer2 csiefer2 merged commit a43c648 into develop Apr 29, 2026
5 of 16 checks passed

2 participants