Skip to content

Commit 9bf4f50

Browse files
authored
Merge pull request #10721 from yyforyongyu/fix-paymentdb
paymentsdb: restore sql payment store parity with kv
2 parents cb04748 + 414fcc6 commit 9bf4f50

File tree

8 files changed

+307
-233
lines changed

8 files changed

+307
-233
lines changed

payments/db/payment_test.go

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3008,6 +3008,80 @@ func TestFetchInFlightPaymentsMultipleAttempts(t *testing.T) {
30083008
require.Len(t, inFlightPayments[0].HTLCs, 2)
30093009
}
30103010

3011+
// TestFetchInFlightPaymentsIncludesRetryablePayments tests that payments with
3012+
// only failed HTLCs but no payment-level failure reason are still returned as
3013+
// in-flight. This matches the shared payment state machine used by the router.
3014+
func TestFetchInFlightPaymentsIncludesRetryablePayments(t *testing.T) {
3015+
t.Parallel()
3016+
3017+
ctx := t.Context()
3018+
3019+
paymentDB, _ := NewTestDB(t)
3020+
3021+
preimg := genPreimage(t)
3022+
rhash := sha256.Sum256(preimg[:])
3023+
info := genPaymentCreationInfo(t, rhash)
3024+
3025+
err := paymentDB.InitPayment(ctx, info.PaymentIdentifier, info)
3026+
require.NoError(t, err)
3027+
3028+
attempt := genAttemptWithHash(t, 0, genSessionKey(t), rhash)
3029+
_, err = paymentDB.RegisterAttempt(ctx, info.PaymentIdentifier, attempt)
3030+
require.NoError(t, err)
3031+
3032+
_, err = paymentDB.FailAttempt(
3033+
ctx, info.PaymentIdentifier, attempt.AttemptID,
3034+
&HTLCFailInfo{Reason: HTLCFailUnreadable},
3035+
)
3036+
require.NoError(t, err)
3037+
3038+
payment, err := paymentDB.FetchPayment(ctx, info.PaymentIdentifier)
3039+
require.NoError(t, err)
3040+
require.Equal(t, StatusInFlight, payment.Status)
3041+
3042+
inFlightPayments, err := paymentDB.FetchInFlightPayments(ctx)
3043+
require.NoError(t, err)
3044+
3045+
inFlightHashes := make(map[lntypes.Hash]struct{}, len(inFlightPayments))
3046+
for _, p := range inFlightPayments {
3047+
inFlightHashes[p.Info.PaymentIdentifier] = struct{}{}
3048+
}
3049+
3050+
require.Contains(t, inFlightHashes, info.PaymentIdentifier)
3051+
}
3052+
3053+
// TestFetchInFlightPaymentsIncludesInitiatedPayments tests that payments which
3054+
// have been initialized but have not yet registered an HTLC are still returned
3055+
// as non-terminal payments.
3056+
func TestFetchInFlightPaymentsIncludesInitiatedPayments(t *testing.T) {
3057+
t.Parallel()
3058+
3059+
ctx := t.Context()
3060+
3061+
paymentDB, _ := NewTestDB(t)
3062+
3063+
preimg := genPreimage(t)
3064+
rhash := sha256.Sum256(preimg[:])
3065+
info := genPaymentCreationInfo(t, rhash)
3066+
3067+
err := paymentDB.InitPayment(ctx, info.PaymentIdentifier, info)
3068+
require.NoError(t, err)
3069+
3070+
payment, err := paymentDB.FetchPayment(ctx, info.PaymentIdentifier)
3071+
require.NoError(t, err)
3072+
require.Equal(t, StatusInitiated, payment.Status)
3073+
3074+
inFlightPayments, err := paymentDB.FetchInFlightPayments(ctx)
3075+
require.NoError(t, err)
3076+
3077+
inFlightHashes := make(map[lntypes.Hash]struct{}, len(inFlightPayments))
3078+
for _, p := range inFlightPayments {
3079+
inFlightHashes[p.Info.PaymentIdentifier] = struct{}{}
3080+
}
3081+
3082+
require.Contains(t, inFlightHashes, info.PaymentIdentifier)
3083+
}
3084+
30113085
// TestRouteFirstHopData tests that Route.FirstHopAmount and
30123086
// Route.FirstHopWireCustomRecords are correctly stored and retrieved.
30133087
func TestRouteFirstHopData(t *testing.T) {
@@ -3118,6 +3192,41 @@ func TestRegisterAttemptWithAMP(t *testing.T) {
31183192
require.Equal(t, childIndex, finalHop.AMP.ChildIndex())
31193193
}
31203194

3195+
// TestRegisterAttemptPreservesAttemptHash tests that an attempt's own hash is
3196+
// preserved independently from the payment identifier. This is especially
3197+
// important for AMP payments where the payment identifier is the SetID and the
3198+
// individual HTLC attempts each use their own payment hash.
3199+
func TestRegisterAttemptPreservesAttemptHash(t *testing.T) {
3200+
t.Parallel()
3201+
3202+
ctx := t.Context()
3203+
3204+
paymentDB, _ := NewTestDB(t)
3205+
3206+
setID := lntypes.Hash{1, 2, 3, 4}
3207+
attemptHash := lntypes.Hash{5, 6, 7, 8}
3208+
info := genPaymentCreationInfo(t, setID)
3209+
3210+
err := paymentDB.InitPayment(ctx, info.PaymentIdentifier, info)
3211+
require.NoError(t, err)
3212+
3213+
attempt := genAttemptWithHash(t, 0, genSessionKey(t), attemptHash)
3214+
finalHopIdx := len(attempt.Route.Hops) - 1
3215+
attempt.Route.Hops[finalHopIdx].AMP = record.NewAMP(
3216+
[32]byte{9, 10, 11, 12}, setID, 99,
3217+
)
3218+
3219+
_, err = paymentDB.RegisterAttempt(ctx, info.PaymentIdentifier, attempt)
3220+
require.NoError(t, err)
3221+
3222+
payment, err := paymentDB.FetchPayment(ctx, info.PaymentIdentifier)
3223+
require.NoError(t, err)
3224+
require.Len(t, payment.HTLCs, 1)
3225+
require.NotNil(t, payment.HTLCs[0].Hash)
3226+
require.Equal(t, attemptHash, *payment.HTLCs[0].Hash)
3227+
require.NotEqual(t, info.PaymentIdentifier, *payment.HTLCs[0].Hash)
3228+
}
3229+
31213230
// TestRegisterAttemptWithBlindedRoute tests that blinded route data
31223231
// (EncryptedData, BlindingPoint, TotalAmtMsat) is correctly stored and
31233232
// retrieved.

payments/db/sql_store.go

Lines changed: 32 additions & 147 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
"errors"
88
"fmt"
99
"math"
10-
"sort"
1110
"strconv"
1211
"time"
1312

@@ -50,12 +49,12 @@ type SQLQueries interface {
5049
FilterPaymentsDesc(ctx context.Context, query sqlc.FilterPaymentsDescParams) ([]sqlc.FilterPaymentsDescRow, error)
5150
FetchPayment(ctx context.Context, paymentIdentifier []byte) (sqlc.FetchPaymentRow, error)
5251
FetchPaymentsByIDs(ctx context.Context, paymentIDs []int64) ([]sqlc.FetchPaymentsByIDsRow, error)
52+
FetchNonTerminalPayments(ctx context.Context, arg sqlc.FetchNonTerminalPaymentsParams) ([]sqlc.FetchNonTerminalPaymentsRow, error)
5353

5454
CountPayments(ctx context.Context) (int64, error)
5555

5656
FetchHtlcAttemptsForPayments(ctx context.Context, paymentIDs []int64) ([]sqlc.FetchHtlcAttemptsForPaymentsRow, error)
5757
FetchHtlcAttemptResolutionsForPayments(ctx context.Context, paymentIDs []int64) ([]sqlc.FetchHtlcAttemptResolutionsForPaymentsRow, error)
58-
FetchAllInflightAttempts(ctx context.Context, arg sqlc.FetchAllInflightAttemptsParams) ([]sqlc.PaymentHtlcAttempt, error)
5958
FetchHopsForAttempts(ctx context.Context, htlcAttemptIndices []int64) ([]sqlc.FetchHopsForAttemptsRow, error)
6059

6160
FetchPaymentDuplicates(ctx context.Context, paymentID int64) ([]sqlc.PaymentDuplicate, error)
@@ -182,88 +181,6 @@ func fetchPaymentWithCompleteData(ctx context.Context,
182181
return buildPaymentFromBatchData(dbPayment, batchData, true)
183182
}
184183

185-
// paymentsCompleteData holds the full payment data when batch loading base
186-
// payment data and all the related data for a payment.
187-
type paymentsCompleteData struct {
188-
*paymentsBaseData
189-
*paymentsDetailsData
190-
}
191-
192-
// batchLoadPayments loads the full payment data for a batch of payment IDs.
193-
func batchLoadPayments(ctx context.Context, cfg *sqldb.QueryConfig,
194-
db SQLQueries, paymentIDs []int64) (*paymentsCompleteData, error) {
195-
196-
baseData, err := batchLoadpaymentsBaseData(ctx, cfg, db, paymentIDs)
197-
if err != nil {
198-
return nil, fmt.Errorf("failed to load payment base data: %w",
199-
err)
200-
}
201-
202-
batchData, err := batchLoadPaymentDetailsData(
203-
ctx, cfg, db, paymentIDs, true,
204-
)
205-
if err != nil {
206-
return nil, fmt.Errorf("failed to load payment batch data: %w",
207-
err)
208-
}
209-
210-
return &paymentsCompleteData{
211-
paymentsBaseData: baseData,
212-
paymentsDetailsData: batchData,
213-
}, nil
214-
}
215-
216-
// paymentsBaseData holds the base payment and intent data for a batch of
217-
// payments.
218-
type paymentsBaseData struct {
219-
// paymentsAndIntents maps payment ID to its payment and intent data.
220-
paymentsAndIntents map[int64]sqlc.PaymentAndIntent
221-
}
222-
223-
// batchLoadpaymentsBaseData loads the base payment and payment intent data for
224-
// a batch of payment IDs. This complements loadPaymentsBatchData which loads
225-
// related data (attempts, hops, custom records) but not the payment table
226-
// and payment intent table data.
227-
func batchLoadpaymentsBaseData(ctx context.Context,
228-
cfg *sqldb.QueryConfig, db SQLQueries,
229-
paymentIDs []int64) (*paymentsBaseData, error) {
230-
231-
baseData := &paymentsBaseData{
232-
paymentsAndIntents: make(map[int64]sqlc.PaymentAndIntent),
233-
}
234-
235-
if len(paymentIDs) == 0 {
236-
return baseData, nil
237-
}
238-
239-
err := sqldb.ExecuteBatchQuery(
240-
ctx, cfg, paymentIDs,
241-
func(id int64) int64 { return id },
242-
func(ctx context.Context, ids []int64) (
243-
[]sqlc.FetchPaymentsByIDsRow, error) {
244-
245-
records, err := db.FetchPaymentsByIDs(
246-
ctx, ids,
247-
)
248-
249-
return records, err
250-
},
251-
func(ctx context.Context,
252-
payment sqlc.FetchPaymentsByIDsRow) error {
253-
254-
baseData.paymentsAndIntents[payment.ID] = payment
255-
256-
return nil
257-
},
258-
)
259-
if err != nil {
260-
return nil, fmt.Errorf("failed to fetch payment base "+
261-
"data: %w", err)
262-
}
263-
264-
return baseData, nil
265-
}
266-
267184
// paymentsRelatedData holds all the batch-loaded data for multiple payments.
268185
// This does not include the base payment and intent data which is fetched
269186
// separately. It includes the additional data like attempts, hops, hop custom
@@ -1047,104 +964,64 @@ func (s *SQLStore) FetchInFlightPayments(ctx context.Context) ([]*MPPayment,
1047964
var mpPayments []*MPPayment
1048965

1049966
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
1050-
// Track which payment IDs we've already processed across all
1051-
// pages to avoid loading the same payment multiple times when
1052-
// multiple inflight attempts belong to the same payment.
1053-
processedPayments := make(map[int64]*MPPayment)
967+
extractCursor := func(
968+
row sqlc.FetchNonTerminalPaymentsRow) int64 {
1054969

1055-
extractCursor := func(row sqlc.PaymentHtlcAttempt) int64 {
1056-
return row.AttemptIndex
970+
return row.ID
1057971
}
1058972

1059-
// collectFunc extracts the payment ID from each attempt row.
1060-
collectFunc := func(row sqlc.PaymentHtlcAttempt) (
973+
collectFunc := func(row sqlc.FetchNonTerminalPaymentsRow) (
1061974
int64, error) {
1062975

1063-
return row.PaymentID, nil
976+
return row.ID, nil
1064977
}
1065978

1066-
// batchDataFunc loads payment data for a batch of payment IDs,
1067-
// but only for IDs we haven't processed yet.
1068979
batchDataFunc := func(ctx context.Context,
1069-
paymentIDs []int64) (*paymentsCompleteData, error) {
1070-
1071-
// Filter out already-processed payment IDs.
1072-
uniqueIDs := make([]int64, 0, len(paymentIDs))
1073-
for _, id := range paymentIDs {
1074-
_, processed := processedPayments[id]
1075-
if !processed {
1076-
uniqueIDs = append(uniqueIDs, id)
1077-
}
1078-
}
980+
paymentIDs []int64) (*paymentsDetailsData, error) {
1079981

1080-
// If uniqueIDs is empty, the batch load will return
1081-
// empty batch data.
1082-
return batchLoadPayments(
1083-
ctx, s.cfg.QueryCfg, db, uniqueIDs,
982+
return batchLoadPaymentDetailsData(
983+
ctx, s.cfg.QueryCfg, db, paymentIDs, true,
1084984
)
1085985
}
1086986

1087-
// processAttempt processes each attempt. We only build and
1088-
// store the payment once per unique payment ID.
1089-
processAttempt := func(ctx context.Context,
1090-
row sqlc.PaymentHtlcAttempt,
1091-
batchData *paymentsCompleteData) error {
1092-
1093-
// Skip if we've already processed this payment.
1094-
_, processed := processedPayments[row.PaymentID]
1095-
if processed {
1096-
return nil
1097-
}
1098-
1099-
dbPayment := batchData.paymentsAndIntents[row.PaymentID]
987+
processPayment := func(ctx context.Context,
988+
row sqlc.FetchNonTerminalPaymentsRow,
989+
batchData *paymentsDetailsData) error {
1100990

1101-
// Build the payment from batch data.
1102-
mpPayment, err := buildPaymentFromBatchData(
1103-
dbPayment, batchData.paymentsDetailsData, true,
991+
payment, err := buildPaymentFromBatchData(
992+
row, batchData, true,
1104993
)
1105994
if err != nil {
1106995
return fmt.Errorf("failed to build payment: %w",
1107996
err)
1108997
}
1109998

1110-
// Store in our processed map.
1111-
processedPayments[row.PaymentID] = mpPayment
999+
mpPayments = append(mpPayments, payment)
11121000

11131001
return nil
11141002
}
11151003

1116-
queryFunc := func(ctx context.Context, lastAttemptIndex int64,
1117-
limit int32) ([]sqlc.PaymentHtlcAttempt,
1004+
queryFunc := func(ctx context.Context, lastPaymentID int64,
1005+
limit int32) ([]sqlc.FetchNonTerminalPaymentsRow,
11181006
error) {
11191007

1120-
return db.FetchAllInflightAttempts(ctx,
1121-
sqlc.FetchAllInflightAttemptsParams{
1122-
AttemptIndex: lastAttemptIndex,
1123-
Limit: limit,
1008+
return db.FetchNonTerminalPayments(ctx,
1009+
sqlc.FetchNonTerminalPaymentsParams{
1010+
ID: lastPaymentID,
1011+
Limit: limit,
11241012
},
11251013
)
11261014
}
11271015

11281016
err := sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
1129-
ctx, s.cfg.QueryCfg, int64(-1), queryFunc,
1017+
ctx, s.cfg.QueryCfg, int64(0), queryFunc,
11301018
extractCursor, collectFunc, batchDataFunc,
1131-
processAttempt,
1019+
processPayment,
11321020
)
11331021
if err != nil {
11341022
return err
11351023
}
11361024

1137-
// Convert map to slice and sort by sequence number to
1138-
// produce a deterministic ordering.
1139-
mpPayments = make([]*MPPayment, 0, len(processedPayments))
1140-
for _, payment := range processedPayments {
1141-
mpPayments = append(mpPayments, payment)
1142-
}
1143-
sort.Slice(mpPayments, func(i, j int) bool {
1144-
return mpPayments[i].SequenceNum <
1145-
mpPayments[j].SequenceNum
1146-
})
1147-
11481025
return nil
11491026
}, func() {
11501027
mpPayments = nil
@@ -1597,13 +1474,21 @@ func (s *SQLStore) RegisterAttempt(ctx context.Context,
15971474
// Register the plain HTLC attempt next.
15981475
sessionKey := attempt.SessionKey()
15991476
sessionKeyBytes := sessionKey.Serialize()
1477+
attemptHash := paymentHash[:]
1478+
if attempt.Hash != nil {
1479+
attemptHash = attempt.Hash[:]
1480+
} else {
1481+
log.Errorf("RegisterAttempt: attempt %d has nil hash, "+
1482+
"falling back to payment identifier %x",
1483+
attempt.AttemptID, paymentHash)
1484+
}
16001485

16011486
_, err = db.InsertHtlcAttempt(ctx, sqlc.InsertHtlcAttemptParams{
16021487
PaymentID: dbPayment.Payment.ID,
16031488
AttemptIndex: int64(attempt.AttemptID),
16041489
SessionKey: sessionKeyBytes,
16051490
AttemptTime: attempt.AttemptTime,
1606-
PaymentHash: paymentHash[:],
1491+
PaymentHash: attemptHash,
16071492
FirstHopAmountMsat: int64(
16081493
attempt.Route.FirstHopAmount.Val.Int(),
16091494
),

0 commit comments

Comments
 (0)