Skip to content
109 changes: 109 additions & 0 deletions payments/db/payment_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3008,6 +3008,80 @@ func TestFetchInFlightPaymentsMultipleAttempts(t *testing.T) {
require.Len(t, inFlightPayments[0].HTLCs, 2)
}

// TestFetchInFlightPaymentsIncludesRetryablePayments tests that payments with
// only failed HTLCs but no payment-level failure reason are still returned as
// in-flight. This matches the shared payment state machine used by the router.
func TestFetchInFlightPaymentsIncludesRetryablePayments(t *testing.T) {
t.Parallel()

ctx := t.Context()

paymentDB, _ := NewTestDB(t)

preimg := genPreimage(t)
rhash := sha256.Sum256(preimg[:])
info := genPaymentCreationInfo(t, rhash)

err := paymentDB.InitPayment(ctx, info.PaymentIdentifier, info)
require.NoError(t, err)

attempt := genAttemptWithHash(t, 0, genSessionKey(t), rhash)
_, err = paymentDB.RegisterAttempt(ctx, info.PaymentIdentifier, attempt)
require.NoError(t, err)

_, err = paymentDB.FailAttempt(
ctx, info.PaymentIdentifier, attempt.AttemptID,
&HTLCFailInfo{Reason: HTLCFailUnreadable},
)
require.NoError(t, err)

payment, err := paymentDB.FetchPayment(ctx, info.PaymentIdentifier)
require.NoError(t, err)
require.Equal(t, StatusInFlight, payment.Status)

inFlightPayments, err := paymentDB.FetchInFlightPayments(ctx)
require.NoError(t, err)

inFlightHashes := make(map[lntypes.Hash]struct{}, len(inFlightPayments))
for _, p := range inFlightPayments {
inFlightHashes[p.Info.PaymentIdentifier] = struct{}{}
}

require.Contains(t, inFlightHashes, info.PaymentIdentifier)
}

// TestFetchInFlightPaymentsIncludesInitiatedPayments tests that payments which
// have been initialized but have not yet registered an HTLC are still returned
// as non-terminal payments.
func TestFetchInFlightPaymentsIncludesInitiatedPayments(t *testing.T) {
t.Parallel()

ctx := t.Context()

paymentDB, _ := NewTestDB(t)

preimg := genPreimage(t)
rhash := sha256.Sum256(preimg[:])
info := genPaymentCreationInfo(t, rhash)

err := paymentDB.InitPayment(ctx, info.PaymentIdentifier, info)
require.NoError(t, err)

payment, err := paymentDB.FetchPayment(ctx, info.PaymentIdentifier)
require.NoError(t, err)
require.Equal(t, StatusInitiated, payment.Status)

inFlightPayments, err := paymentDB.FetchInFlightPayments(ctx)
require.NoError(t, err)

inFlightHashes := make(map[lntypes.Hash]struct{}, len(inFlightPayments))
for _, p := range inFlightPayments {
inFlightHashes[p.Info.PaymentIdentifier] = struct{}{}
}

require.Contains(t, inFlightHashes, info.PaymentIdentifier)
}

// TestRouteFirstHopData tests that Route.FirstHopAmount and
// Route.FirstHopWireCustomRecords are correctly stored and retrieved.
func TestRouteFirstHopData(t *testing.T) {
Expand Down Expand Up @@ -3118,6 +3192,41 @@ func TestRegisterAttemptWithAMP(t *testing.T) {
require.Equal(t, childIndex, finalHop.AMP.ChildIndex())
}

// TestRegisterAttemptPreservesAttemptHash tests that an attempt's own hash is
// preserved independently from the payment identifier. This is especially
// important for AMP payments where the payment identifier is the SetID and the
// individual HTLC attempts each use their own payment hash.
func TestRegisterAttemptPreservesAttemptHash(t *testing.T) {
t.Parallel()

ctx := t.Context()

paymentDB, _ := NewTestDB(t)

setID := lntypes.Hash{1, 2, 3, 4}
attemptHash := lntypes.Hash{5, 6, 7, 8}
info := genPaymentCreationInfo(t, setID)

err := paymentDB.InitPayment(ctx, info.PaymentIdentifier, info)
require.NoError(t, err)

attempt := genAttemptWithHash(t, 0, genSessionKey(t), attemptHash)
finalHopIdx := len(attempt.Route.Hops) - 1
attempt.Route.Hops[finalHopIdx].AMP = record.NewAMP(
[32]byte{9, 10, 11, 12}, setID, 99,
)

_, err = paymentDB.RegisterAttempt(ctx, info.PaymentIdentifier, attempt)
require.NoError(t, err)

payment, err := paymentDB.FetchPayment(ctx, info.PaymentIdentifier)
require.NoError(t, err)
require.Len(t, payment.HTLCs, 1)
require.NotNil(t, payment.HTLCs[0].Hash)
require.Equal(t, attemptHash, *payment.HTLCs[0].Hash)
require.NotEqual(t, info.PaymentIdentifier, *payment.HTLCs[0].Hash)
}

// TestRegisterAttemptWithBlindedRoute tests that blinded route data
// (EncryptedData, BlindingPoint, TotalAmtMsat) is correctly stored and
// retrieved.
Expand Down
92 changes: 28 additions & 64 deletions payments/db/sql_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"errors"
"fmt"
"math"
"sort"
"strconv"
"time"

Expand Down Expand Up @@ -50,6 +49,7 @@ type SQLQueries interface {
FilterPaymentsDesc(ctx context.Context, query sqlc.FilterPaymentsDescParams) ([]sqlc.FilterPaymentsDescRow, error)
FetchPayment(ctx context.Context, paymentIdentifier []byte) (sqlc.FetchPaymentRow, error)
FetchPaymentsByIDs(ctx context.Context, paymentIDs []int64) ([]sqlc.FetchPaymentsByIDsRow, error)
FetchNonTerminalPayments(ctx context.Context, arg sqlc.FetchNonTerminalPaymentsParams) ([]sqlc.FetchNonTerminalPaymentsRow, error)

CountPayments(ctx context.Context) (int64, error)

Expand Down Expand Up @@ -1047,104 +1047,64 @@ func (s *SQLStore) FetchInFlightPayments(ctx context.Context) ([]*MPPayment,
var mpPayments []*MPPayment

err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
// Track which payment IDs we've already processed across all
// pages to avoid loading the same payment multiple times when
// multiple inflight attempts belong to the same payment.
processedPayments := make(map[int64]*MPPayment)
extractCursor := func(
row sqlc.FetchNonTerminalPaymentsRow) int64 {

extractCursor := func(row sqlc.PaymentHtlcAttempt) int64 {
return row.AttemptIndex
return row.ID
}

// collectFunc extracts the payment ID from each attempt row.
collectFunc := func(row sqlc.PaymentHtlcAttempt) (
collectFunc := func(row sqlc.FetchNonTerminalPaymentsRow) (
int64, error) {

return row.PaymentID, nil
return row.ID, nil
}

// batchDataFunc loads payment data for a batch of payment IDs,
// but only for IDs we haven't processed yet.
batchDataFunc := func(ctx context.Context,
paymentIDs []int64) (*paymentsCompleteData, error) {

// Filter out already-processed payment IDs.
uniqueIDs := make([]int64, 0, len(paymentIDs))
for _, id := range paymentIDs {
_, processed := processedPayments[id]
if !processed {
uniqueIDs = append(uniqueIDs, id)
}
}
paymentIDs []int64) (*paymentsDetailsData, error) {

// If uniqueIDs is empty, the batch load will return
// empty batch data.
return batchLoadPayments(
ctx, s.cfg.QueryCfg, db, uniqueIDs,
return batchLoadPaymentDetailsData(
ctx, s.cfg.QueryCfg, db, paymentIDs, true,
)
}

// processAttempt processes each attempt. We only build and
// store the payment once per unique payment ID.
processAttempt := func(ctx context.Context,
row sqlc.PaymentHtlcAttempt,
batchData *paymentsCompleteData) error {

// Skip if we've already processed this payment.
_, processed := processedPayments[row.PaymentID]
if processed {
return nil
}

dbPayment := batchData.paymentsAndIntents[row.PaymentID]
processPayment := func(ctx context.Context,
row sqlc.FetchNonTerminalPaymentsRow,
batchData *paymentsDetailsData) error {

// Build the payment from batch data.
mpPayment, err := buildPaymentFromBatchData(
dbPayment, batchData.paymentsDetailsData, true,
payment, err := buildPaymentFromBatchData(
row, batchData, true,
)
if err != nil {
return fmt.Errorf("failed to build payment: %w",
err)
}

// Store in our processed map.
processedPayments[row.PaymentID] = mpPayment
mpPayments = append(mpPayments, payment)

return nil
}

queryFunc := func(ctx context.Context, lastAttemptIndex int64,
limit int32) ([]sqlc.PaymentHtlcAttempt,
queryFunc := func(ctx context.Context, lastPaymentID int64,
limit int32) ([]sqlc.FetchNonTerminalPaymentsRow,
error) {

return db.FetchAllInflightAttempts(ctx,
sqlc.FetchAllInflightAttemptsParams{
AttemptIndex: lastAttemptIndex,
Limit: limit,
return db.FetchNonTerminalPayments(ctx,
sqlc.FetchNonTerminalPaymentsParams{
ID: lastPaymentID,
Limit: limit,
},
)
}

err := sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
ctx, s.cfg.QueryCfg, int64(0), queryFunc,
extractCursor, collectFunc, batchDataFunc,
processAttempt,
processPayment,
)
if err != nil {
return err
}

// Convert map to slice and sort by sequence number to
// produce a deterministic ordering.
mpPayments = make([]*MPPayment, 0, len(processedPayments))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We no longer need this sorting step?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea we no longer need the old Go-side sorting step. The current path appends one payment per selector row, and the selector query already returns rows ordered by p.id ASC

for _, payment := range processedPayments {
mpPayments = append(mpPayments, payment)
}
sort.Slice(mpPayments, func(i, j int) bool {
return mpPayments[i].SequenceNum <
mpPayments[j].SequenceNum
})

return nil
}, func() {
mpPayments = nil
Expand Down Expand Up @@ -1597,13 +1557,17 @@ func (s *SQLStore) RegisterAttempt(ctx context.Context,
// Register the plain HTLC attempt next.
sessionKey := attempt.SessionKey()
sessionKeyBytes := sessionKey.Serialize()
attemptHash := paymentHash[:]
if attempt.Hash != nil {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we log an error in case the hash is nil normally this should never happen or ?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea. I added 412db8ae7 to log the nil-hash fallback path so we get a diagnostic if a newly registered attempt ever reaches it unexpectedly. The legacy fallback behavior itself stays the same.

attemptHash = attempt.Hash[:]
}

_, err = db.InsertHtlcAttempt(ctx, sqlc.InsertHtlcAttemptParams{
PaymentID: dbPayment.Payment.ID,
AttemptIndex: int64(attempt.AttemptID),
SessionKey: sessionKeyBytes,
AttemptTime: attempt.AttemptTime,
PaymentHash: paymentHash[:],
PaymentHash: attemptHash,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Worth noting that while the stored hash was wrong for AMP payments, this did not cause observable payment failures. Settlement is driven by the preimage arriving on the wire, and
failure decryption relies solely on the session key and payment path — neither path consults the stored attempt hash. The bug was a data correctness issue rather than a live payment
breakage, which also explains why it went undetected.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. I left the runtime behavior unchanged and treated this as a data-correctness fix rather than a live payment breakage. I only followed up with 412db8ae7 to add a diagnostic on the nil-hash fallback path.

FirstHopAmountMsat: int64(
attempt.Route.FirstHopAmount.Val.Int(),
),
Expand Down
29 changes: 29 additions & 0 deletions sqldb/sqlc/db_custom.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,3 +241,32 @@ func (r FetchPaymentsByIDsRow) GetPaymentIntent() PaymentIntent {
IntentPayload: r.IntentPayload,
}
}

// GetPayment returns the Payment associated with this interface.
//
// NOTE: This method is part of the PaymentAndIntent interface.
func (r FetchNonTerminalPaymentsRow) GetPayment() Payment {
return Payment{
ID: r.ID,
AmountMsat: r.AmountMsat,
CreatedAt: r.CreatedAt,
PaymentIdentifier: r.PaymentIdentifier,
FailReason: r.FailReason,
}
}

// GetPaymentIntent returns the PaymentIntent associated with this payment.
// If the payment has no intent (IntentType is NULL), this returns a zero-value
// PaymentIntent.
//
// NOTE: This method is part of the PaymentAndIntent interface.
func (r FetchNonTerminalPaymentsRow) GetPaymentIntent() PaymentIntent {
if !r.IntentType.Valid {
return PaymentIntent{}
}

return PaymentIntent{
IntentType: r.IntentType.Int16,
IntentPayload: r.IntentPayload,
}
}
Loading