diff --git a/cmd/platform-mihomo-service/main.go b/cmd/platform-mihomo-service/main.go index 2668a42..e8aa5a6 100644 --- a/cmd/platform-mihomo-service/main.go +++ b/cmd/platform-mihomo-service/main.go @@ -74,7 +74,7 @@ func main() { managementUC, ) sharedSvc := service.NewMihomoCredentialService(ticketVerifier, managementUC) - genericSvc := service.NewGenericPlatformService(ticketVerifier, managementUC) + genericSvc := service.NewGenericPlatformService(ticketVerifier, bindUC, statusUC, managementUC) grpcSrv := server.NewGRPCServer(&bc, mihomoSvc, sharedSvc, genericSvc) app := kratos.New( diff --git a/go.mod b/go.mod index d8902e3..33d2b73 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module platform-mihomo-service go 1.24.0 require ( - github.com/PaiGramTeam/proto-contracts v0.0.0-20260416033414-4016470bc78d + github.com/PaiGramTeam/proto-contracts v0.0.0-20260419070719-80366e162dcf github.com/glebarez/sqlite v1.11.0 github.com/go-kratos/kratos/v2 v2.9.2 github.com/go-sql-driver/mysql v1.8.1 diff --git a/go.sum b/go.sum index 448e5b1..b849063 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/PaiGramTeam/proto-contracts v0.0.0-20260416033414-4016470bc78d h1:sRlQ5OlTFoft20wsKyi/7pVTTflEPqPcvDhOuoPMGpE= github.com/PaiGramTeam/proto-contracts v0.0.0-20260416033414-4016470bc78d/go.mod h1:i5FW0tHMyzubj1qI/7e04arFyrDv3YXkK1Pu5UxejBI= +github.com/PaiGramTeam/proto-contracts v0.0.0-20260419070719-80366e162dcf h1:Wa77pOdHqp/wtil+DecjMvAOZ0rlLczTu4/Wqi17Cjw= +github.com/PaiGramTeam/proto-contracts v0.0.0-20260419070719-80366e162dcf/go.mod h1:i5FW0tHMyzubj1qI/7e04arFyrDv3YXkK1Pu5UxejBI= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= diff --git a/internal/biz/ticket.go b/internal/biz/ticket.go index 41d1ea1..ad7fd76 100644 --- a/internal/biz/ticket.go +++ b/internal/biz/ticket.go @@ -4,6 +4,8 @@ type ServiceTicketClaims struct { ActorType string ActorID string OwnerUserID uint64 + // BindingID is the first-class control-plane binding identity carried by + // service tickets and used for authorization and resource lookup. BindingID uint64 Platform string PlatformAccountID string @@ -16,8 +18,7 @@ type ServiceTicketClaims struct { PlatformServiceKey string PlatformAccountRefID uint64 - // PlatformAccountRefID is a read-only legacy alias for BindingID so - // downstream callers can migrate incrementally. New tickets should use - // BindingID; if a token still carries platform_account_ref_id, verifier - // code requires it to match binding_id. + // PlatformAccountRefID is a read-only legacy alias for BindingID. New + // tickets should use binding_id; if platform_account_ref_id is present, + // verifier code requires it to match BindingID exactly. } diff --git a/internal/data/credential_repo.go b/internal/data/credential_repo.go index 10479ff..d117ac0 100644 --- a/internal/data/credential_repo.go +++ b/internal/data/credential_repo.go @@ -19,6 +19,12 @@ func NewCredentialRepo(db *gorm.DB) *CredentialRepo { return &CredentialRepo{db: db} } +func (r *CredentialRepo) WithinTransaction(ctx context.Context, fn func(context.Context) error) error { + return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return fn(withTx(ctx, tx)) + }) +} + func (r *CredentialRepo) Save(ctx context.Context, credential *biz.Credential) error { record := model.CredentialRecord{ BindingID: credential.BindingID, @@ -34,7 +40,7 @@ func (r *CredentialRepo) Save(ctx context.Context, credential *biz.Credential) e ExpiresAt: credential.ExpiresAt, } - return r.db.WithContext(ctx).Clauses(clause.OnConflict{ + return dbFromContext(ctx, r.db).Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "binding_id"}}, UpdateAll: true, }).Create(&record).Error @@ -42,7 +48,7 @@ func (r *CredentialRepo) Save(ctx context.Context, credential *biz.Credential) e func (r *CredentialRepo) GetByBindingID(ctx context.Context, bindingID uint64) (*biz.Credential, error) { var record model.CredentialRecord - if err := r.db.WithContext(ctx).Where("binding_id = ?", bindingID).Order("id asc").Take(&record).Error; err != nil { + if err := dbFromContext(ctx, r.db).Where("binding_id = ?", bindingID).Order("id asc").Take(&record).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil } @@ -54,7 +60,7 @@ func (r *CredentialRepo) GetByBindingID(ctx context.Context, bindingID uint64) ( func (r *CredentialRepo) GetByPlatformAccountID(ctx context.Context, platformAccountID string) (*biz.Credential, error) { var record model.CredentialRecord - if err := r.db.WithContext(ctx).Where("platform_account_id = ?", platformAccountID).Take(&record).Error; err != nil { + if err := dbFromContext(ctx, r.db).Where("platform_account_id = ?", platformAccountID).Take(&record).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil } @@ -65,7 +71,7 @@ func (r *CredentialRepo) GetByPlatformAccountID(ctx context.Context, platformAcc } func (r *CredentialRepo) DeleteByPlatformAccountID(ctx context.Context, platformAccountID string) error { - return r.db.WithContext(ctx).Where("platform_account_id = ?", platformAccountID).Delete(&model.CredentialRecord{}).Error + return dbFromContext(ctx, r.db).Where("platform_account_id = ?", platformAccountID).Delete(&model.CredentialRecord{}).Error } func credentialFromRecord(record model.CredentialRecord) *biz.Credential { diff --git a/internal/data/device_repo.go b/internal/data/device_repo.go index 838f684..f889c7f 100644 --- a/internal/data/device_repo.go +++ b/internal/data/device_repo.go @@ -28,7 +28,7 @@ func (r *DeviceRepo) Save(ctx context.Context, device *biz.Device) error { LastSeenAt: device.LastSeenAt, } - return r.db.WithContext(ctx).Clauses(clause.OnConflict{ + return dbFromContext(ctx, r.db).Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "platform_account_id"}, {Name: "device_id"}}, DoUpdates: clause.AssignmentColumns([]string{"device_fp", "device_name", "is_valid", "last_seen_at", "updated_at"}), }).Create(&record).Error @@ -36,7 +36,7 @@ func (r *DeviceRepo) Save(ctx context.Context, device *biz.Device) error { func (r *DeviceRepo) ListByPlatformAccountID(ctx context.Context, platformAccountID string) ([]*biz.Device, error) { var records []model.DeviceRecord - if err := r.db.WithContext(ctx).Where("platform_account_id = ?", platformAccountID).Order("id asc").Find(&records).Error; err != nil { + if err := dbFromContext(ctx, r.db).Where("platform_account_id = ?", platformAccountID).Order("id asc").Find(&records).Error; err != nil { return nil, err } @@ -57,5 +57,5 @@ func (r *DeviceRepo) ListByPlatformAccountID(ctx context.Context, platformAccoun } func (r *DeviceRepo) DeleteByPlatformAccountID(ctx context.Context, platformAccountID string) error { - return r.db.WithContext(ctx).Where("platform_account_id = ?", platformAccountID).Delete(&model.DeviceRecord{}).Error + return dbFromContext(ctx, r.db).Where("platform_account_id = ?", platformAccountID).Delete(&model.DeviceRecord{}).Error } diff --git a/internal/data/profile_repo.go b/internal/data/profile_repo.go index b4831ac..232dd87 100644 --- a/internal/data/profile_repo.go +++ b/internal/data/profile_repo.go @@ -31,7 +31,7 @@ func (r *ProfileRepo) Save(ctx context.Context, profile *biz.Profile) error { DiscoveredAt: profile.DiscoveredAt, } - return r.db.WithContext(ctx).Clauses(clause.OnConflict{ + return dbFromContext(ctx, r.db).Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "platform_account_id"}, {Name: "player_id"}, {Name: "region"}}, DoUpdates: clause.AssignmentColumns([]string{"game_biz", "nickname", "level", "is_default", "updated_at"}), }).Create(&record).Error @@ -39,7 +39,7 @@ func (r *ProfileRepo) Save(ctx context.Context, profile *biz.Profile) error { func (r *ProfileRepo) ListByBindingID(ctx context.Context, bindingID uint64) ([]*biz.Profile, error) { var records []model.AccountProfile - if err := r.db.WithContext(ctx).Where("binding_id = ?", bindingID).Order("id asc").Find(&records).Error; err != nil { + if err := dbFromContext(ctx, r.db).Where("binding_id = ?", bindingID).Order("id asc").Find(&records).Error; err != nil { return nil, err } @@ -48,7 +48,7 @@ func (r *ProfileRepo) ListByBindingID(ctx context.Context, bindingID uint64) ([] func (r *ProfileRepo) ListByPlatformAccountID(ctx context.Context, platformAccountID string) ([]*biz.Profile, error) { var records []model.AccountProfile - if err := r.db.WithContext(ctx).Where("platform_account_id = ?", platformAccountID).Order("id asc").Find(&records).Error; err != nil { + if err := dbFromContext(ctx, r.db).Where("platform_account_id = ?", platformAccountID).Order("id asc").Find(&records).Error; err != nil { return nil, err } @@ -56,12 +56,12 @@ func (r *ProfileRepo) ListByPlatformAccountID(ctx context.Context, platformAccou } func (r *ProfileRepo) DeleteByPlatformAccountID(ctx context.Context, platformAccountID string) error { - return r.db.WithContext(ctx).Where("platform_account_id = ?", platformAccountID).Delete(&model.AccountProfile{}).Error + return dbFromContext(ctx, r.db).Where("platform_account_id = ?", platformAccountID).Delete(&model.AccountProfile{}).Error } func (r *ProfileRepo) DeleteMissingByPlatformAccountID(ctx context.Context, platformAccountID string, keep []biz.ProfileIdentity) error { var records []model.AccountProfile - if err := r.db.WithContext(ctx).Where("platform_account_id = ?", platformAccountID).Find(&records).Error; err != nil { + if err := dbFromContext(ctx, r.db).Where("platform_account_id = ?", platformAccountID).Find(&records).Error; err != nil { return err } keepSet := make(map[string]struct{}, len(keep)) @@ -73,7 +73,7 @@ func (r *ProfileRepo) DeleteMissingByPlatformAccountID(ctx context.Context, plat if _, ok := keepSet[key]; ok { continue } - if err := r.db.WithContext(ctx).Delete(&model.AccountProfile{}, record.ID).Error; err != nil { + if err := dbFromContext(ctx, r.db).Delete(&model.AccountProfile{}, record.ID).Error; err != nil { return err } } diff --git a/internal/data/ticket_verifier.go b/internal/data/ticket_verifier.go index 01a5ad4..01588ef 100644 --- a/internal/data/ticket_verifier.go +++ b/internal/data/ticket_verifier.go @@ -69,8 +69,10 @@ func (v *TicketVerifier) Verify(raw string, expectedAudience string) (*biz.Servi if claims.Platform == "" { return nil, fmt.Errorf("service ticket missing platform") } - if claims.ActorType == "consumer" && claims.Consumer == "" { - return nil, fmt.Errorf("service ticket missing consumer") + if claims.ActorType == "consumer" { + if claims.Consumer == "" { + return nil, fmt.Errorf("service ticket missing consumer") + } } userID := claims.OwnerUserID if claims.UserID != 0 { diff --git a/internal/data/ticket_verifier_test.go b/internal/data/ticket_verifier_test.go index cb79cea..4db4f49 100644 --- a/internal/data/ticket_verifier_test.go +++ b/internal/data/ticket_verifier_test.go @@ -111,6 +111,22 @@ func TestVerifyRejectsConsumerActorWithoutConsumerClaim(t *testing.T) { require.ErrorContains(t, err, "consumer") } +func TestVerifyAllowsNonConsumerActorTypesWithoutConsumerClaim(t *testing.T) { + verifier := NewTicketVerifier(testTicketIssuer, testTicketSigningKey) + raw := issueTestTicket(t, map[string]any{ + "actor_type": "robot", + "actor_id": "user-1", + "owner_user_id": float64(1), + "binding_id": float64(101), + "platform": "mihomo", + "scopes": []string{"mihomo.profile.read"}, + }) + + claims, err := verifier.Verify(raw, testTicketAudience) + require.NoError(t, err) + require.Equal(t, "robot", claims.ActorType) +} + func TestVerifyRejectsMismatchedLegacyPlatformAccountRefID(t *testing.T) { verifier := NewTicketVerifier(testTicketIssuer, testTicketSigningKey) raw := issueTestTicket(t, map[string]any{ diff --git a/internal/data/tx_context.go b/internal/data/tx_context.go new file mode 100644 index 0000000..3aa9c59 --- /dev/null +++ b/internal/data/tx_context.go @@ -0,0 +1,20 @@ +package data + +import ( + "context" + + "gorm.io/gorm" +) + +type txContextKey struct{} + +func withTx(ctx context.Context, tx *gorm.DB) context.Context { + return context.WithValue(ctx, txContextKey{}, tx) +} + +func dbFromContext(ctx context.Context, fallback *gorm.DB) *gorm.DB { + if tx, ok := ctx.Value(txContextKey{}).(*gorm.DB); ok && tx != nil { + return tx.WithContext(ctx) + } + return fallback.WithContext(ctx) +} diff --git a/internal/service/mihomo_account.go b/internal/service/mihomo_account.go index 3cbd8c6..84c5ef8 100644 --- a/internal/service/mihomo_account.go +++ b/internal/service/mihomo_account.go @@ -119,7 +119,11 @@ func (s *MihomoAccountService) RefreshCredential(ctx context.Context, req *v1.Re if err != nil { return nil, err } - if _, err := scopedGuardForPlatformAccount(claims, req.GetPlatformAccountId(), usecase.ActionStatusRead); err != nil { + guard, err := scopedGuardForPlatformAccount(claims, req.GetPlatformAccountId(), usecase.ActionCredentialRefresh) + if err != nil { + return nil, mapUsecaseError(err) + } + if err := guard.RequireBindingWide(); err != nil { return nil, mapUsecaseError(err) } @@ -212,11 +216,12 @@ func (s *MihomoAccountService) GetCredentialSummary(ctx context.Context, req *v1 return nil, err } - if _, err := scopedGuardForPlatformAccount(claims, req.GetPlatformAccountId(), usecase.ActionCredentialRead); err != nil { + guard, err := scopedGuard(claims, usecase.ActionCredentialRead) + if err != nil { return nil, mapUsecaseError(err) } - output, err := s.managementUC.GetCredentialSummary(ctx, req.GetPlatformAccountId()) + output, err := s.managementUC.GetCredentialSummaryWithScope(ctx, guard, req.GetPlatformAccountId()) if err != nil { return nil, mapUsecaseError(err) } @@ -233,7 +238,8 @@ func (s *MihomoAccountService) UpdateCredential(ctx context.Context, req *v1.Upd if err != nil { return nil, err } - if _, err := scopedGuardForPlatformAccount(claims, req.GetPlatformAccountId(), usecase.ActionCredentialUpdate); err != nil { + guard, err := scopedGuard(claims, usecase.ActionCredentialUpdate) + if err != nil { return nil, mapUsecaseError(err) } if req.GetDevice() == nil { @@ -255,7 +261,7 @@ func (s *MihomoAccountService) UpdateCredential(ctx context.Context, req *v1.Upd input.DeviceName = device.GetDeviceName() } - output, err := s.managementUC.UpdateCredential(ctx, input) + output, err := s.managementUC.UpdateCredentialWithScope(ctx, guard, input) if err != nil { return nil, mapUsecaseError(err) } @@ -272,11 +278,12 @@ func (s *MihomoAccountService) DeleteCredential(ctx context.Context, req *v1.Del if err != nil { return nil, err } - if _, err := scopedGuardForPlatformAccount(claims, req.GetPlatformAccountId(), usecase.ActionCredentialDelete); err != nil { + guard, err := scopedGuard(claims, usecase.ActionCredentialDelete) + if err != nil { return nil, mapUsecaseError(err) } - if err := s.managementUC.DeleteCredential(ctx, req.GetPlatformAccountId()); err != nil { + if err := s.managementUC.DeleteCredentialWithScope(ctx, guard, req.GetPlatformAccountId()); err != nil { return nil, mapUsecaseError(err) } diff --git a/internal/service/mihomo_account_test.go b/internal/service/mihomo_account_test.go index 9383909..9d8683a 100644 --- a/internal/service/mihomo_account_test.go +++ b/internal/service/mihomo_account_test.go @@ -67,7 +67,7 @@ func TestBindCredentialReturnsDiscoveredProfiles(t *testing.T) { require.Empty(t, validateResp.ErrorCode) refreshResp, err := svc.RefreshCredential(context.Background(), &v1.RefreshCredentialRequest{ - ServiceTicket: signedServiceTicketForAccount(t, bindResp.PlatformAccountId, "mihomo.status.read"), + ServiceTicket: signedServiceTicketForAccount(t, bindResp.PlatformAccountId, "mihomo.credential.refresh"), PlatformAccountId: bindResp.PlatformAccountId, }) require.NoError(t, err) @@ -167,6 +167,18 @@ func TestGetCredentialStatusRejectsOutOfScopePlatformAccountID(t *testing.T) { require.Equal(t, codes.PermissionDenied, status.Code(err)) } +func TestRefreshCredentialRejectsProfileScopedTicket(t *testing.T) { + svc := newMihomoAccountServiceForTest(t) + bindResp := bindCredentialForServiceTest(t, svc) + + _, err := svc.RefreshCredential(context.Background(), &v1.RefreshCredentialRequest{ + ServiceTicket: signedServiceTicketForProfile(t, bindResp.PlatformAccountId, 999, "mihomo.credential.refresh"), + PlatformAccountId: bindResp.PlatformAccountId, + }) + require.Error(t, err) + require.Equal(t, codes.PermissionDenied, status.Code(err)) +} + func TestConfirmPrimaryProfileRejectsUnknownPlayerID(t *testing.T) { svc := newMihomoAccountServiceForTest(t) bindResp, err := svc.BindCredential(context.Background(), &v1.BindCredentialRequest{ @@ -213,6 +225,18 @@ func TestDeleteCredentialRejectsMissingScope(t *testing.T) { require.Equal(t, codes.PermissionDenied, status.Code(err)) } +func TestDeleteCredentialRejectsProfileScopedTicket(t *testing.T) { + svc := newMihomoAccountServiceForTest(t) + bindResp := bindCredentialForServiceTest(t, svc) + + _, err := svc.DeleteCredential(context.Background(), &v1.DeleteCredentialRequest{ + ServiceTicket: signedServiceTicketForProfile(t, bindResp.PlatformAccountId, 999, "mihomo.credential.delete"), + PlatformAccountId: bindResp.PlatformAccountId, + }) + require.Error(t, err) + require.Equal(t, codes.PermissionDenied, status.Code(err)) +} + func TestBindCredentialRejectsMissingScope(t *testing.T) { svc := newMihomoAccountServiceForTest(t) @@ -237,6 +261,18 @@ func TestGetCredentialSummaryRejectsMissingScope(t *testing.T) { require.Equal(t, codes.PermissionDenied, status.Code(err)) } +func TestGetCredentialSummaryRejectsProfileScopedTicket(t *testing.T) { + svc := newMihomoAccountServiceForTest(t) + bindResp := bindCredentialForServiceTest(t, svc) + + _, err := svc.GetCredentialSummary(context.Background(), &v1.GetCredentialSummaryRequest{ + ServiceTicket: signedServiceTicketForProfile(t, bindResp.PlatformAccountId, 999, "mihomo.credential.read_meta"), + PlatformAccountId: bindResp.PlatformAccountId, + }) + require.Error(t, err) + require.Equal(t, codes.PermissionDenied, status.Code(err)) +} + func TestUpdateCredentialRejectsMissingScope(t *testing.T) { svc := newMihomoAccountServiceForTest(t) bindResp := bindCredentialForServiceTest(t, svc) @@ -251,6 +287,20 @@ func TestUpdateCredentialRejectsMissingScope(t *testing.T) { require.Equal(t, codes.PermissionDenied, status.Code(err)) } +func TestUpdateCredentialRejectsProfileScopedTicket(t *testing.T) { + svc := newMihomoAccountServiceForTest(t) + bindResp := bindCredentialForServiceTest(t, svc) + + _, err := svc.UpdateCredential(context.Background(), &v1.UpdateCredentialRequest{ + ServiceTicket: signedServiceTicketForProfile(t, bindResp.PlatformAccountId, 999, "mihomo.credential.update"), + PlatformAccountId: bindResp.PlatformAccountId, + CookieBundleJson: `{"account_id":"10001","cookie_token":"updated"}`, + Device: &v1.DeviceInfo{DeviceId: "device-2", DeviceFp: "fp-2", DeviceName: "iPad"}, + }) + require.Error(t, err) + require.Equal(t, codes.PermissionDenied, status.Code(err)) +} + func TestListProfilesRejectsMissingScope(t *testing.T) { svc := newMihomoAccountServiceForTest(t) bindResp := bindCredentialForServiceTest(t, svc) diff --git a/internal/service/mihomo_credential_service.go b/internal/service/mihomo_credential_service.go index f4bce0c..0610296 100644 --- a/internal/service/mihomo_credential_service.go +++ b/internal/service/mihomo_credential_service.go @@ -32,11 +32,12 @@ func (s *MihomoCredentialService) GetCredentialSummary(ctx context.Context, req if err != nil { return nil, status.Error(codes.Unauthenticated, "invalid service ticket") } - if _, err := scopedGuardForPlatformAccount(claims, req.GetPlatformAccountId(), usecase.ActionCredentialRead); err != nil { + guard, err := scopedGuardForPlatformAccount(claims, req.GetPlatformAccountId(), usecase.ActionCredentialRead) + if err != nil { return nil, mapUsecaseError(err) } - output, err := s.managementUC.GetCredentialSummary(ctx, req.GetPlatformAccountId()) + output, err := s.managementUC.GetCredentialSummaryWithScope(ctx, guard, req.GetPlatformAccountId()) if err != nil { return nil, mapUsecaseError(err) } diff --git a/internal/service/mihomo_credential_service_test.go b/internal/service/mihomo_credential_service_test.go index d3cf296..e754438 100644 --- a/internal/service/mihomo_credential_service_test.go +++ b/internal/service/mihomo_credential_service_test.go @@ -70,6 +70,25 @@ func TestMihomoCredentialServiceRejectsMissingSummaryScope(t *testing.T) { require.Error(t, err) } +func TestMihomoCredentialServiceRejectsProfileScopedSummaryTicket(t *testing.T) { + harness := newMihomoCredentialServiceForTest(t) + + bindResp, err := harness.bindUC.BindCredential(context.Background(), usecase.BindCredentialInput{ + BindingID: 101, + CookieBundleJSON: `{"account_id":"10001","cookie_token":"abc"}`, + DeviceID: "12345678-1234-1234-1234-123456789abc", + DeviceFP: "abcdefghijklmn", + DeviceName: "iPhone", + }) + require.NoError(t, err) + + _, err = harness.service.GetCredentialSummary(context.Background(), &mihomov1.GetCredentialSummaryRequest{ + ServiceTicket: signedServiceTicketForProfile(t, bindResp.PlatformAccountID, 999, "mihomo.credential.read_meta"), + PlatformAccountId: bindResp.PlatformAccountID, + }) + require.Error(t, err) +} + func signedMihomoSummaryTicket(t *testing.T, platformAccountID string, scopes ...string) string { t.Helper() return signedServiceTicketForAccount(t, platformAccountID, scopes...) diff --git a/internal/service/platform_service_adapter.go b/internal/service/platform_service_adapter.go index 2d85f15..3e2f525 100644 --- a/internal/service/platform_service_adapter.go +++ b/internal/service/platform_service_adapter.go @@ -2,6 +2,7 @@ package service import ( "context" + "encoding/json" platformv1 "github.com/PaiGramTeam/proto-contracts/platform/v1" "google.golang.org/grpc/codes" @@ -17,11 +18,13 @@ type GenericPlatformService struct { platformv1.UnimplementedPlatformServiceServer ticketVerifier *data.TicketVerifier + bindUC *usecase.BindUsecase + statusUC *usecase.StatusUsecase managementUC *usecase.ManagementUsecase } -func NewGenericPlatformService(ticketVerifier *data.TicketVerifier, managementUC *usecase.ManagementUsecase) *GenericPlatformService { - return &GenericPlatformService{ticketVerifier: ticketVerifier, managementUC: managementUC} +func NewGenericPlatformService(ticketVerifier *data.TicketVerifier, bindUC *usecase.BindUsecase, statusUC *usecase.StatusUsecase, managementUC *usecase.ManagementUsecase) *GenericPlatformService { + return &GenericPlatformService{ticketVerifier: ticketVerifier, bindUC: bindUC, statusUC: statusUC, managementUC: managementUC} } func (s *GenericPlatformService) DescribePlatform(context.Context, *platformv1.DescribePlatformRequest) (*platformv1.DescribePlatformResponse, error) { @@ -43,7 +46,7 @@ func (s *GenericPlatformService) DescribePlatform(context.Context, *platformv1.D PlatformKey: "mihomo", DisplayName: "Mihomo", ServiceAudience: serviceTicketAudience, - SupportedActions: []string{"summary"}, + SupportedActions: []string{"summary", "put_credential", "refresh_credential", "delete_credential"}, CredentialSchema: credentialSchema, Version: "v1", }, nil @@ -58,11 +61,12 @@ func (s *GenericPlatformService) GetCredentialSummary(ctx context.Context, req * if err != nil { return nil, status.Error(codes.Unauthenticated, "invalid service ticket") } - if _, err := scopedGuardForPlatformAccount(claims, req.GetPlatformAccountId(), usecase.ActionCredentialRead); err != nil { + guard, err := scopedGuardForPlatformAccount(claims, req.GetPlatformAccountId(), usecase.ActionCredentialRead) + if err != nil { return nil, mapUsecaseError(err) } - output, err := s.managementUC.GetCredentialSummary(ctx, req.GetPlatformAccountId()) + output, err := s.managementUC.GetCredentialSummaryWithScope(ctx, guard, req.GetPlatformAccountId()) if err != nil { return nil, mapUsecaseError(err) } @@ -70,6 +74,151 @@ func (s *GenericPlatformService) GetCredentialSummary(ctx context.Context, req * return toGenericCredentialSummary(output), nil } +func (s *GenericPlatformService) PutCredential(ctx context.Context, req *platformv1.PutCredentialRequest) (*platformv1.PutCredentialResponse, error) { + if req == nil { + return nil, status.Error(codes.InvalidArgument, "request is required") + } + + claims, err := s.ticketVerifier.Verify(req.GetServiceTicket(), serviceTicketAudience) + if err != nil { + return nil, status.Error(codes.Unauthenticated, "invalid service ticket") + } + guard, err := scopedGuard(claims) + if err != nil { + return nil, mapUsecaseError(err) + } + if err := guard.RequireBindingWide(); err != nil { + return nil, mapUsecaseError(err) + } + + payload, err := decodeGenericCredentialPayload(req.GetCredentialPayloadJson()) + if err != nil { + return nil, status.Error(codes.InvalidArgument, err.Error()) + } + bindInput := usecase.BindCredentialInput{ + BindingID: claims.BindingID, + CookieBundleJSON: payload.CookieBundle, + DeviceID: payload.DeviceID, + DeviceFP: payload.DeviceFP, + DeviceName: payload.DeviceName, + RegionHint: payload.RegionHint, + } + + platformAccountID := req.GetPlatformAccountId() + if platformAccountID != "" { + if err := guard.RequireAction(usecase.ActionCredentialUpdate); err != nil { + return nil, mapUsecaseError(err) + } + if err := guard.RequirePlatformAccountID(platformAccountID); err != nil { + return nil, mapUsecaseError(err) + } + summary, err := s.managementUC.UpdateCredentialWithScope(ctx, guard, usecase.UpdateCredentialInput{ + PlatformAccountID: platformAccountID, + BindCredentialInput: bindInput, + }) + if err != nil { + return nil, mapUsecaseError(err) + } + return &platformv1.PutCredentialResponse{Summary: toGenericCredentialSummary(summary)}, nil + } + if err := guard.RequireAction(usecase.ActionCredentialBind); err != nil { + return nil, mapUsecaseError(err) + } + + bound, err := s.bindUC.BindCredential(ctx, bindInput) + if err != nil { + return nil, mapUsecaseError(err) + } + summary, err := s.managementUC.GetCredentialSummaryWithScope(ctx, guard, bound.PlatformAccountID) + if err != nil { + return nil, mapUsecaseError(err) + } + return &platformv1.PutCredentialResponse{Summary: toGenericCredentialSummary(summary)}, nil +} + +func (s *GenericPlatformService) RefreshCredential(ctx context.Context, req *platformv1.RefreshCredentialRequest) (*platformv1.RefreshCredentialResponse, error) { + if req == nil { + return nil, status.Error(codes.InvalidArgument, "request is required") + } + + claims, err := s.ticketVerifier.Verify(req.GetServiceTicket(), serviceTicketAudience) + if err != nil { + return nil, status.Error(codes.Unauthenticated, "invalid service ticket") + } + guard, err := scopedGuard(claims, usecase.ActionCredentialRefresh) + if err != nil { + return nil, mapUsecaseError(err) + } + platformAccountID := req.GetPlatformAccountId() + if platformAccountID == "" { + platformAccountID = claims.PlatformAccountID + } + if platformAccountID == "" { + return nil, status.Error(codes.InvalidArgument, "platform_account_id is required") + } + if err := guard.RequirePlatformAccountID(platformAccountID); err != nil { + return nil, mapUsecaseError(err) + } + if err := guard.RequireBindingWide(); err != nil { + return nil, mapUsecaseError(err) + } + + output, err := s.statusUC.RefreshCredential(ctx, platformAccountID) + if err != nil { + return nil, mapUsecaseError(err) + } + return &platformv1.RefreshCredentialResponse{Status: toGenericCredentialStatus(output.Status), RefreshedAt: toTimestamp(output.RefreshedAt)}, nil +} + +func (s *GenericPlatformService) DeleteCredential(ctx context.Context, req *platformv1.DeleteCredentialRequest) (*platformv1.DeleteCredentialResponse, error) { + if req == nil { + return nil, status.Error(codes.InvalidArgument, "request is required") + } + + claims, err := s.ticketVerifier.Verify(req.GetServiceTicket(), serviceTicketAudience) + if err != nil { + return nil, status.Error(codes.Unauthenticated, "invalid service ticket") + } + guard, err := scopedGuard(claims, usecase.ActionCredentialDelete) + if err != nil { + return nil, mapUsecaseError(err) + } + platformAccountID := req.GetPlatformAccountId() + if platformAccountID == "" { + platformAccountID = claims.PlatformAccountID + } + if platformAccountID == "" { + return nil, status.Error(codes.InvalidArgument, "platform_account_id is required") + } + if err := guard.RequirePlatformAccountID(platformAccountID); err != nil { + return nil, mapUsecaseError(err) + } + if err := guard.RequireBindingWide(); err != nil { + return nil, mapUsecaseError(err) + } + if err := s.managementUC.DeleteCredentialWithScope(ctx, guard, platformAccountID); err != nil { + return nil, mapUsecaseError(err) + } + return &platformv1.DeleteCredentialResponse{Success: true}, nil +} + +type genericCredentialPayload struct { + CookieBundle string `json:"cookie_bundle"` + DeviceID string `json:"device_id"` + DeviceFP string `json:"device_fp"` + DeviceName string `json:"device_name"` + RegionHint string `json:"region_hint"` +} + +func decodeGenericCredentialPayload(raw string) (*genericCredentialPayload, error) { + var payload genericCredentialPayload + if err := json.Unmarshal([]byte(raw), &payload); err != nil { + return nil, err + } + return &payload, nil +} + + func toGenericCredentialSummary(output *usecase.CredentialSummaryOutput) *platformv1.GetCredentialSummaryResponse { profiles := make([]*platformv1.ProfileSummary, 0, len(output.Profiles)) for _, profile := range output.Profiles { diff --git a/internal/service/platform_service_adapter_test.go b/internal/service/platform_service_adapter_test.go index 7ef6060..ab1659e 100644 --- a/internal/service/platform_service_adapter_test.go +++ b/internal/service/platform_service_adapter_test.go @@ -36,6 +36,8 @@ func TestGenericPlatformServiceGetCredentialSummary(t *testing.T) { ) adapter := NewGenericPlatformService( data.NewTicketVerifier(serviceTestIssuer, serviceTestSigningKey), + bindUC, + usecase.NewStatusUsecase(credentialRepo, client, serviceTestSigningKey), managementUC, ) @@ -78,6 +80,8 @@ func TestGenericPlatformServiceRejectsMissingSummaryScope(t *testing.T) { ) adapter := NewGenericPlatformService( data.NewTicketVerifier(serviceTestIssuer, serviceTestSigningKey), + bindUC, + usecase.NewStatusUsecase(credentialRepo, client, serviceTestSigningKey), managementUC, ) @@ -97,10 +101,55 @@ func TestGenericPlatformServiceRejectsMissingSummaryScope(t *testing.T) { require.Error(t, err) } +func TestGenericPlatformServiceRejectsProfileScopedSummaryTicket(t *testing.T) { + credentialRepo := newMemoryCredentialRepo() + deviceRepo := newMemoryDeviceRepo() + profileRepo := newMemoryProfileRepo() + artifactRepo := newMemoryArtifactRepo() + client := platformmihomo.StubClient{} + + bindUC := usecase.NewBindUsecase(credentialRepo, deviceRepo, profileRepo, client, serviceTestSigningKey) + profileUC := usecase.NewProfileUsecase(profileRepo) + managementUC := usecase.NewManagementUsecase( + credentialRepo, + deviceRepo, + profileRepo, + artifactRepo, + newMemoryManagementRepo(credentialRepo, deviceRepo, profileRepo, artifactRepo), + bindUC, + profileUC, + ) + adapter := NewGenericPlatformService( + data.NewTicketVerifier(serviceTestIssuer, serviceTestSigningKey), + bindUC, + usecase.NewStatusUsecase(credentialRepo, client, serviceTestSigningKey), + managementUC, + ) + + bindResp, err := bindUC.BindCredential(context.Background(), usecase.BindCredentialInput{ + BindingID: 101, + CookieBundleJSON: `{"account_id":"10001","cookie_token":"abc"}`, + DeviceID: "12345678-1234-1234-1234-123456789abc", + DeviceFP: "abcdefghijklmn", + DeviceName: "iPhone", + }) + require.NoError(t, err) + + _, err = adapter.GetCredentialSummary(context.Background(), &platformv1.GetCredentialSummaryRequest{ + ServiceTicket: signedServiceTicketForProfile(t, bindResp.PlatformAccountID, 999, "mihomo.credential.read_meta"), + PlatformAccountId: bindResp.PlatformAccountID, + }) + require.Error(t, err) +} + func TestGenericPlatformServiceDescribePlatform(t *testing.T) { + bindUC := usecase.NewBindUsecase(newMemoryCredentialRepo(), newMemoryDeviceRepo(), newMemoryProfileRepo(), platformmihomo.StubClient{}, serviceTestSigningKey) + statusUC := usecase.NewStatusUsecase(newMemoryCredentialRepo(), platformmihomo.StubClient{}, serviceTestSigningKey) adapter := NewGenericPlatformService( data.NewTicketVerifier(serviceTestIssuer, serviceTestSigningKey), - usecase.NewManagementUsecase(newMemoryCredentialRepo(), newMemoryDeviceRepo(), newMemoryProfileRepo(), newMemoryArtifactRepo(), newMemoryManagementRepo(newMemoryCredentialRepo(), newMemoryDeviceRepo(), newMemoryProfileRepo(), newMemoryArtifactRepo()), usecase.NewBindUsecase(newMemoryCredentialRepo(), newMemoryDeviceRepo(), newMemoryProfileRepo(), platformmihomo.StubClient{}, serviceTestSigningKey), usecase.NewProfileUsecase(newMemoryProfileRepo())), + bindUC, + statusUC, + usecase.NewManagementUsecase(newMemoryCredentialRepo(), newMemoryDeviceRepo(), newMemoryProfileRepo(), newMemoryArtifactRepo(), newMemoryManagementRepo(newMemoryCredentialRepo(), newMemoryDeviceRepo(), newMemoryProfileRepo(), newMemoryArtifactRepo()), bindUC, usecase.NewProfileUsecase(newMemoryProfileRepo())), ) resp, err := adapter.DescribePlatform(context.Background(), &platformv1.DescribePlatformRequest{}) @@ -108,15 +157,19 @@ func TestGenericPlatformServiceDescribePlatform(t *testing.T) { require.Equal(t, "mihomo", resp.PlatformKey) require.Equal(t, "Mihomo", resp.DisplayName) require.Equal(t, serviceTicketAudience, resp.ServiceAudience) - require.Equal(t, []string{"summary"}, resp.SupportedActions) + require.Equal(t, []string{"summary", "put_credential", "refresh_credential", "delete_credential"}, resp.SupportedActions) require.NotNil(t, resp.CredentialSchema) require.NotEmpty(t, resp.CredentialSchema.Fields) } func TestGenericPlatformServiceRegisteredOnGRPCServer(t *testing.T) { + bindUC := usecase.NewBindUsecase(newMemoryCredentialRepo(), newMemoryDeviceRepo(), newMemoryProfileRepo(), platformmihomo.StubClient{}, serviceTestSigningKey) + statusUC := usecase.NewStatusUsecase(newMemoryCredentialRepo(), platformmihomo.StubClient{}, serviceTestSigningKey) adapter := NewGenericPlatformService( data.NewTicketVerifier(serviceTestIssuer, serviceTestSigningKey), - usecase.NewManagementUsecase(newMemoryCredentialRepo(), newMemoryDeviceRepo(), newMemoryProfileRepo(), newMemoryArtifactRepo(), newMemoryManagementRepo(newMemoryCredentialRepo(), newMemoryDeviceRepo(), newMemoryProfileRepo(), newMemoryArtifactRepo()), usecase.NewBindUsecase(newMemoryCredentialRepo(), newMemoryDeviceRepo(), newMemoryProfileRepo(), platformmihomo.StubClient{}, serviceTestSigningKey), usecase.NewProfileUsecase(newMemoryProfileRepo())), + bindUC, + statusUC, + usecase.NewManagementUsecase(newMemoryCredentialRepo(), newMemoryDeviceRepo(), newMemoryProfileRepo(), newMemoryArtifactRepo(), newMemoryManagementRepo(newMemoryCredentialRepo(), newMemoryDeviceRepo(), newMemoryProfileRepo(), newMemoryArtifactRepo()), bindUC, usecase.NewProfileUsecase(newMemoryProfileRepo())), ) listener := bufconn.Listen(1024 * 1024) @@ -135,3 +188,153 @@ func TestGenericPlatformServiceRegisteredOnGRPCServer(t *testing.T) { require.NoError(t, err) require.Equal(t, "mihomo", resp.PlatformKey) } + +func TestGenericPlatformServicePutCredentialBindsWhenPlatformAccountIDUnknown(t *testing.T) { + credentialRepo := newMemoryCredentialRepo() + deviceRepo := newMemoryDeviceRepo() + profileRepo := newMemoryProfileRepo() + artifactRepo := newMemoryArtifactRepo() + client := platformmihomo.StubClient{} + + bindUC := usecase.NewBindUsecase(credentialRepo, deviceRepo, profileRepo, client, serviceTestSigningKey) + profileUC := usecase.NewProfileUsecase(profileRepo) + managementUC := usecase.NewManagementUsecase( + credentialRepo, + deviceRepo, + profileRepo, + artifactRepo, + newMemoryManagementRepo(credentialRepo, deviceRepo, profileRepo, artifactRepo), + bindUC, + profileUC, + ) + adapter := NewGenericPlatformService( + data.NewTicketVerifier(serviceTestIssuer, serviceTestSigningKey), + bindUC, + usecase.NewStatusUsecase(credentialRepo, client, serviceTestSigningKey), + managementUC, + ) + + resp, err := adapter.PutCredential(context.Background(), &platformv1.PutCredentialRequest{ + ServiceTicket: signedMihomoSummaryTicket(t, "", "mihomo.credential.bind"), + CredentialPayloadJson: `{"cookie_bundle":"{\"account_id\":\"10001\",\"cookie_token\":\"abc\"}","device_id":"12345678-1234-1234-1234-123456789abc","device_fp":"abcdefghijklmn","device_name":"iPhone","region_hint":"cn_gf01"}`, + }) + require.NoError(t, err) + require.NotNil(t, resp.GetSummary()) + require.Equal(t, "binding_101_10001", resp.GetSummary().GetPlatformAccountId()) +} + +func TestGenericPlatformServicePutCredentialRejectsCreateWithUpdateOnlyScope(t *testing.T) { + credentialRepo := newMemoryCredentialRepo() + deviceRepo := newMemoryDeviceRepo() + profileRepo := newMemoryProfileRepo() + artifactRepo := newMemoryArtifactRepo() + client := platformmihomo.StubClient{} + + bindUC := usecase.NewBindUsecase(credentialRepo, deviceRepo, profileRepo, client, serviceTestSigningKey) + profileUC := usecase.NewProfileUsecase(profileRepo) + managementUC := usecase.NewManagementUsecase( + credentialRepo, + deviceRepo, + profileRepo, + artifactRepo, + newMemoryManagementRepo(credentialRepo, deviceRepo, profileRepo, artifactRepo), + bindUC, + profileUC, + ) + adapter := NewGenericPlatformService( + data.NewTicketVerifier(serviceTestIssuer, serviceTestSigningKey), + bindUC, + usecase.NewStatusUsecase(credentialRepo, client, serviceTestSigningKey), + managementUC, + ) + + _, err := adapter.PutCredential(context.Background(), &platformv1.PutCredentialRequest{ + ServiceTicket: signedMihomoSummaryTicket(t, "", "mihomo.credential.update"), + CredentialPayloadJson: `{"cookie_bundle":"{\"account_id\":\"10001\",\"cookie_token\":\"abc\"}","device_id":"12345678-1234-1234-1234-123456789abc","device_fp":"abcdefghijklmn","device_name":"iPhone","region_hint":"cn_gf01"}`, + }) + require.Error(t, err) +} + +func TestGenericPlatformServicePutCredentialRejectsUpdateWithBindOnlyScope(t *testing.T) { + credentialRepo := newMemoryCredentialRepo() + deviceRepo := newMemoryDeviceRepo() + profileRepo := newMemoryProfileRepo() + artifactRepo := newMemoryArtifactRepo() + client := platformmihomo.StubClient{} + + bindUC := usecase.NewBindUsecase(credentialRepo, deviceRepo, profileRepo, client, serviceTestSigningKey) + profileUC := usecase.NewProfileUsecase(profileRepo) + managementUC := usecase.NewManagementUsecase( + credentialRepo, + deviceRepo, + profileRepo, + artifactRepo, + newMemoryManagementRepo(credentialRepo, deviceRepo, profileRepo, artifactRepo), + bindUC, + profileUC, + ) + adapter := NewGenericPlatformService( + data.NewTicketVerifier(serviceTestIssuer, serviceTestSigningKey), + bindUC, + usecase.NewStatusUsecase(credentialRepo, client, serviceTestSigningKey), + managementUC, + ) + + bindResp, err := bindUC.BindCredential(context.Background(), usecase.BindCredentialInput{ + BindingID: 101, + CookieBundleJSON: `{"account_id":"10001","cookie_token":"abc"}`, + DeviceID: "12345678-1234-1234-1234-123456789abc", + DeviceFP: "abcdefghijklmn", + DeviceName: "iPhone", + }) + require.NoError(t, err) + + _, err = adapter.PutCredential(context.Background(), &platformv1.PutCredentialRequest{ + ServiceTicket: signedMihomoSummaryTicket(t, bindResp.PlatformAccountID, "mihomo.credential.bind"), + PlatformAccountId: bindResp.PlatformAccountID, + CredentialPayloadJson: `{"cookie_bundle":"{\"account_id\":\"10001\",\"cookie_token\":\"updated\"}","device_id":"device-2","device_fp":"fp-2","device_name":"iPad","region_hint":"cn_gf01"}`, + }) + require.Error(t, err) +} + +func TestGenericPlatformServiceDeleteCredentialUsesDeleteScope(t *testing.T) { + credentialRepo := newMemoryCredentialRepo() + deviceRepo := newMemoryDeviceRepo() + profileRepo := newMemoryProfileRepo() + artifactRepo := newMemoryArtifactRepo() + client := platformmihomo.StubClient{} + + bindUC := usecase.NewBindUsecase(credentialRepo, deviceRepo, profileRepo, client, serviceTestSigningKey) + profileUC := usecase.NewProfileUsecase(profileRepo) + managementUC := usecase.NewManagementUsecase( + credentialRepo, + deviceRepo, + profileRepo, + artifactRepo, + newMemoryManagementRepo(credentialRepo, deviceRepo, profileRepo, artifactRepo), + bindUC, + profileUC, + ) + adapter := NewGenericPlatformService( + data.NewTicketVerifier(serviceTestIssuer, serviceTestSigningKey), + bindUC, + usecase.NewStatusUsecase(credentialRepo, client, serviceTestSigningKey), + managementUC, + ) + + bindResp, err := bindUC.BindCredential(context.Background(), usecase.BindCredentialInput{ + BindingID: 101, + CookieBundleJSON: `{"account_id":"10001","cookie_token":"abc"}`, + DeviceID: "12345678-1234-1234-1234-123456789abc", + DeviceFP: "abcdefghijklmn", + DeviceName: "iPhone", + }) + require.NoError(t, err) + + resp, err := adapter.DeleteCredential(context.Background(), &platformv1.DeleteCredentialRequest{ + ServiceTicket: signedMihomoSummaryTicket(t, bindResp.PlatformAccountID, "mihomo.credential.delete"), + PlatformAccountId: bindResp.PlatformAccountID, + }) + require.NoError(t, err) + require.True(t, resp.GetSuccess()) +} diff --git a/internal/usecase/bind_usecase.go b/internal/usecase/bind_usecase.go index 8270510..858eaac 100644 --- a/internal/usecase/bind_usecase.go +++ b/internal/usecase/bind_usecase.go @@ -27,6 +27,14 @@ type BindCredentialOutput struct { Status v1.CredentialStatus } +type bindCredentialPreparation struct { + platformAccountID string + accountID string + region string + discoveredProfiles []platformmihomo.DiscoveredProfile + encryptedBlob string +} + type BindUsecase struct { credentialRepo biz.CredentialRepository deviceRepo biz.DeviceRepository @@ -35,6 +43,10 @@ type BindUsecase struct { encryptionKey []byte } +type bindTransactioner interface { + WithinTransaction(ctx context.Context, fn func(context.Context) error) error +} + func NewBindUsecase( credentialRepo biz.CredentialRepository, deviceRepo biz.DeviceRepository, @@ -52,6 +64,34 @@ func NewBindUsecase( } func (uc *BindUsecase) BindCredential(ctx context.Context, input BindCredentialInput) (*BindCredentialOutput, error) { + prepared, err := uc.prepareBindCredential(ctx, input) + if err != nil { + return nil, err + } + + var output *BindCredentialOutput + err = uc.runInTransaction(ctx, func(txCtx context.Context) error { + result, err := uc.bindPreparedCredential(txCtx, input, prepared) + if err != nil { + return err + } + output = result + return nil + }) + if err != nil { + return nil, err + } + return output, nil +} + +func (uc *BindUsecase) runInTransaction(ctx context.Context, fn func(context.Context) error) error { + if txRepo, ok := uc.credentialRepo.(bindTransactioner); ok { + return txRepo.WithinTransaction(ctx, fn) + } + return fn(ctx) +} + +func (uc *BindUsecase) prepareBindCredential(ctx context.Context, input BindCredentialInput) (*bindCredentialPreparation, error) { if input.BindingID == 0 { return nil, errors.New("binding id is required") } @@ -67,31 +107,58 @@ func (uc *BindUsecase) BindCredential(ctx context.Context, input BindCredentialI return nil, err } + return &bindCredentialPreparation{ + platformAccountID: platformAccountID, + accountID: accountID, + region: region, + discoveredProfiles: discoveredProfiles, + encryptedBlob: encryptedBlob, + }, nil +} + +func (uc *BindUsecase) bindPreparedCredential(ctx context.Context, input BindCredentialInput, prepared *bindCredentialPreparation) (*BindCredentialOutput, error) { + if prepared == nil { + return nil, errors.New("bind preparation is required") + } + + existingCredential, err := uc.credentialRepo.GetByBindingID(ctx, input.BindingID) + if err != nil { + return nil, err + } + now := time.Now().UTC() if err := uc.credentialRepo.Save(ctx, &biz.Credential{ BindingID: input.BindingID, - PlatformAccountID: platformAccountID, + PlatformAccountID: prepared.platformAccountID, Platform: "mihomo", - AccountID: accountID, - Region: region, - CredentialBlob: encryptedBlob, + AccountID: prepared.accountID, + Region: prepared.region, + CredentialBlob: prepared.encryptedBlob, CredentialVersion: "v1", Status: "active", LastValidatedAt: &now, }); err != nil { return nil, err } + previousPlatformAccountID := "" + if existingCredential != nil && existingCredential.PlatformAccountID != "" && existingCredential.PlatformAccountID != prepared.platformAccountID { + previousPlatformAccountID = existingCredential.PlatformAccountID + } rollback := true defer func() { if rollback { - _ = uc.profileRepo.DeleteByPlatformAccountID(ctx, platformAccountID) - _ = uc.deviceRepo.DeleteByPlatformAccountID(ctx, platformAccountID) - _ = uc.credentialRepo.DeleteByPlatformAccountID(ctx, platformAccountID) + _ = uc.profileRepo.DeleteByPlatformAccountID(ctx, prepared.platformAccountID) + _ = uc.deviceRepo.DeleteByPlatformAccountID(ctx, prepared.platformAccountID) + if previousPlatformAccountID != "" { + _ = uc.credentialRepo.Save(ctx, existingCredential) + return + } + _ = uc.credentialRepo.DeleteByPlatformAccountID(ctx, prepared.platformAccountID) } }() device := &biz.Device{ - PlatformAccountID: platformAccountID, + PlatformAccountID: prepared.platformAccountID, DeviceID: input.DeviceID, DeviceFP: input.DeviceFP, IsValid: true, @@ -105,11 +172,11 @@ func (uc *BindUsecase) BindCredential(ctx context.Context, input BindCredentialI return nil, err } - outputProfiles := make([]v1.ProfileSummary, 0, len(discoveredProfiles)) - for index, discoveredProfile := range discoveredProfiles { + outputProfiles := make([]v1.ProfileSummary, 0, len(prepared.discoveredProfiles)) + for index, discoveredProfile := range prepared.discoveredProfiles { profile := &biz.Profile{ BindingID: input.BindingID, - PlatformAccountID: platformAccountID, + PlatformAccountID: prepared.platformAccountID, GameBiz: discoveredProfile.GameBiz, Region: discoveredProfile.Region, PlayerID: discoveredProfile.PlayerID, @@ -125,11 +192,23 @@ func (uc *BindUsecase) BindCredential(ctx context.Context, input BindCredentialI outputProfiles = append(outputProfiles, *toProfileSummary(profile)) } + if previousPlatformAccountID != "" { + if err := uc.profileRepo.DeleteByPlatformAccountID(ctx, previousPlatformAccountID); err != nil { + return nil, err + } + if err := uc.deviceRepo.DeleteByPlatformAccountID(ctx, previousPlatformAccountID); err != nil { + return nil, err + } + if err := uc.credentialRepo.DeleteByPlatformAccountID(ctx, previousPlatformAccountID); err != nil { + return nil, err + } + } + rollback = false return &BindCredentialOutput{ BindingID: input.BindingID, - PlatformAccountID: platformAccountID, + PlatformAccountID: prepared.platformAccountID, Profiles: outputProfiles, Status: v1.CredentialStatus_CREDENTIAL_STATUS_ACTIVE, }, nil diff --git a/internal/usecase/bind_usecase_test.go b/internal/usecase/bind_usecase_test.go index f39f528..0765f6e 100644 --- a/internal/usecase/bind_usecase_test.go +++ b/internal/usecase/bind_usecase_test.go @@ -137,6 +137,262 @@ func TestBindCredentialRollsBackWhenProfileSaveFails(t *testing.T) { require.Empty(t, uc.profileRepo.byPlatformAccountID) } +func TestBindCredentialRebindRemovesOldPlatformScopedRows(t *testing.T) { + client := &sequentialMihomoClient{ + results: []mihomoValidateResult{ + { + accountID: "10001", + region: "cn_gf01", + profiles: []platformmihomo.DiscoveredProfile{{ + GameBiz: "hk4e_cn", + Region: "cn_gf01", + PlayerID: "1008611", + Nickname: "Traveler", + Level: 60, + }}, + }, + { + accountID: "20002", + region: "cn_gf01", + profiles: []platformmihomo.DiscoveredProfile{{ + GameBiz: "hk4e_cn", + Region: "cn_gf01", + PlayerID: "2008622", + Nickname: "Rebound", + Level: 55, + }}, + }, + }, + } + uc := newBindUsecaseForTestWithClient(client) + + first, err := uc.BindCredential(context.Background(), BindCredentialInput{ + BindingID: 42, + CookieBundleJSON: `{"account_id":"10001","cookie_token":"abc"}`, + DeviceID: "device-old", + DeviceFP: "fp-old", + }) + require.NoError(t, err) + + second, err := uc.BindCredential(context.Background(), BindCredentialInput{ + BindingID: 42, + CookieBundleJSON: `{"account_id":"20002","cookie_token":"def"}`, + DeviceID: "device-new", + DeviceFP: "fp-new", + }) + require.NoError(t, err) + require.Equal(t, "binding_42_20002", second.PlatformAccountID) + + _, ok := uc.credentialRepo.byPlatformAccountID[first.PlatformAccountID] + require.False(t, ok) + require.NotContains(t, uc.deviceRepo.byPlatformAccountID, first.PlatformAccountID) + require.NotContains(t, uc.profileRepo.byPlatformAccountID, first.PlatformAccountID) + + credential, err := uc.credentialRepo.GetByBindingID(context.Background(), 42) + require.NoError(t, err) + require.NotNil(t, credential) + require.Equal(t, second.PlatformAccountID, credential.PlatformAccountID) + + profiles, err := uc.profileRepo.ListByBindingID(context.Background(), 42) + require.NoError(t, err) + require.Len(t, profiles, 1) + require.Equal(t, second.PlatformAccountID, profiles[0].PlatformAccountID) + require.Equal(t, "2008622", profiles[0].PlayerID) + + devices, err := uc.deviceRepo.ListByPlatformAccountID(context.Background(), second.PlatformAccountID) + require.NoError(t, err) + require.Len(t, devices, 1) + require.Equal(t, "device-new", devices[0].DeviceID) +} + +func TestBindCredentialRebindRollbackRestoresPreviousBindingState(t *testing.T) { + client := &sequentialMihomoClient{ + results: []mihomoValidateResult{ + { + accountID: "10001", + region: "cn_gf01", + profiles: []platformmihomo.DiscoveredProfile{{ + GameBiz: "hk4e_cn", + Region: "cn_gf01", + PlayerID: "1008611", + Nickname: "Traveler", + Level: 60, + }}, + }, + { + accountID: "20002", + region: "cn_gf01", + profiles: []platformmihomo.DiscoveredProfile{{ + GameBiz: "hk4e_cn", + Region: "cn_gf01", + PlayerID: "2008622", + Nickname: "Rebound", + Level: 55, + }}, + }, + }, + } + uc := newBindUsecaseForTestWithClient(client) + + first, err := uc.BindCredential(context.Background(), BindCredentialInput{ + BindingID: 42, + CookieBundleJSON: `{"account_id":"10001","cookie_token":"abc"}`, + DeviceID: "device-old", + DeviceFP: "fp-old", + }) + require.NoError(t, err) + + uc.profileRepo.failSave = true + _, err = uc.BindCredential(context.Background(), BindCredentialInput{ + BindingID: 42, + CookieBundleJSON: `{"account_id":"20002","cookie_token":"def"}`, + DeviceID: "device-new", + DeviceFP: "fp-new", + }) + require.Error(t, err) + + credential, err := uc.credentialRepo.GetByBindingID(context.Background(), 42) + require.NoError(t, err) + require.NotNil(t, credential) + require.Equal(t, first.PlatformAccountID, credential.PlatformAccountID) + + devices, err := uc.deviceRepo.ListByPlatformAccountID(context.Background(), first.PlatformAccountID) + require.NoError(t, err) + require.Len(t, devices, 1) + require.Equal(t, "device-old", devices[0].DeviceID) + require.NotContains(t, uc.deviceRepo.byPlatformAccountID, "binding_42_20002") + + profiles, err := uc.profileRepo.ListByBindingID(context.Background(), 42) + require.NoError(t, err) + require.Len(t, profiles, 1) + require.Equal(t, first.PlatformAccountID, profiles[0].PlatformAccountID) + require.Equal(t, "1008611", profiles[0].PlayerID) + require.NotContains(t, uc.profileRepo.byPlatformAccountID, "binding_42_20002") +} + +func TestBindCredentialRebindRollbackRestoresPreviousRowsWhenCleanupFails(t *testing.T) { + client := &sequentialMihomoClient{ + results: []mihomoValidateResult{ + { + accountID: "10001", + region: "cn_gf01", + profiles: []platformmihomo.DiscoveredProfile{{ + GameBiz: "hk4e_cn", + Region: "cn_gf01", + PlayerID: "1008611", + Nickname: "Traveler", + Level: 60, + }}, + }, + { + accountID: "20002", + region: "cn_gf01", + profiles: []platformmihomo.DiscoveredProfile{{ + GameBiz: "hk4e_cn", + Region: "cn_gf01", + PlayerID: "2008622", + Nickname: "Rebound", + Level: 55, + }}, + }, + }, + } + uc := newBindUsecaseForTestWithClient(client) + + first, err := uc.BindCredential(context.Background(), BindCredentialInput{ + BindingID: 42, + CookieBundleJSON: `{"account_id":"10001","cookie_token":"abc"}`, + DeviceID: "device-old", + DeviceFP: "fp-old", + }) + require.NoError(t, err) + + uc.deviceRepo.failDeleteByPlatformAccountID[first.PlatformAccountID] = errors.New("cleanup failed") + _, err = uc.BindCredential(context.Background(), BindCredentialInput{ + BindingID: 42, + CookieBundleJSON: `{"account_id":"20002","cookie_token":"def"}`, + DeviceID: "device-new", + DeviceFP: "fp-new", + }) + require.ErrorContains(t, err, "cleanup failed") + + credential, err := uc.credentialRepo.GetByBindingID(context.Background(), 42) + require.NoError(t, err) + require.NotNil(t, credential) + require.Equal(t, first.PlatformAccountID, credential.PlatformAccountID) + + devices, err := uc.deviceRepo.ListByPlatformAccountID(context.Background(), first.PlatformAccountID) + require.NoError(t, err) + require.Len(t, devices, 1) + require.Equal(t, "device-old", devices[0].DeviceID) + require.NotContains(t, uc.deviceRepo.byPlatformAccountID, "binding_42_20002") + + profiles, err := uc.profileRepo.ListByBindingID(context.Background(), 42) + require.NoError(t, err) + require.Len(t, profiles, 1) + require.Equal(t, first.PlatformAccountID, profiles[0].PlatformAccountID) + require.Equal(t, "1008611", profiles[0].PlayerID) + require.NotContains(t, uc.profileRepo.byPlatformAccountID, "binding_42_20002") +} + +func TestBindCredentialRebindUsesLatestBindingSnapshotInsideTransaction(t *testing.T) { + uc := newBindUsecaseForTest() + + first, err := uc.BindCredential(context.Background(), BindCredentialInput{ + BindingID: 42, + CookieBundleJSON: `{"account_id":"10001","cookie_token":"abc"}`, + DeviceID: "device-old", + DeviceFP: "fp-old", + }) + require.NoError(t, err) + + mutatedAccountID := FormatPlatformAccountID(42, "30003") + uc.client = &mutatingMihomoClient{ + repo: uc.credentialRepo, + deviceRepo: uc.deviceRepo, + profileRepo: uc.profileRepo, + bindingID: 42, + newAccountID: mutatedAccountID, + discoveredID: "20002", + discoveredUID: "2008622", + discoveredName: "Fresh", + } + + second, err := uc.BindCredential(context.Background(), BindCredentialInput{ + BindingID: 42, + CookieBundleJSON: `{"account_id":"20002","cookie_token":"def"}`, + DeviceID: "device-new", + DeviceFP: "fp-new", + }) + require.NoError(t, err) + require.Equal(t, "binding_42_20002", second.PlatformAccountID) + + _, ok := uc.credentialRepo.byPlatformAccountID[mutatedAccountID] + require.False(t, ok) + require.NotContains(t, uc.deviceRepo.byPlatformAccountID, mutatedAccountID) + require.NotContains(t, uc.profileRepo.byPlatformAccountID, mutatedAccountID) + + credential, err := uc.credentialRepo.GetByBindingID(context.Background(), 42) + require.NoError(t, err) + require.NotNil(t, credential) + require.Equal(t, second.PlatformAccountID, credential.PlatformAccountID) + require.NotEqual(t, first.PlatformAccountID, credential.PlatformAccountID) +} + +func TestBindCredentialValidatesBeforeTransaction(t *testing.T) { + client := &transactionObservingClient{} + uc := newBindUsecaseForTestWithClient(client) + + _, err := uc.BindCredential(context.Background(), BindCredentialInput{ + BindingID: 101, + CookieBundleJSON: `{"account_id":"10001","cookie_token":"abc"}`, + DeviceID: "12345678-1234-1234-1234-123456789abc", + DeviceFP: "abcdefghijklmn", + }) + require.NoError(t, err) + require.False(t, client.calledDuringTransaction) +} + var testEncryptionKey = []byte("0123456789abcdef0123456789abcdef") type bindUsecaseTestHarness struct { @@ -148,11 +404,20 @@ type bindUsecaseTestHarness struct { } func newBindUsecaseForTest() *bindUsecaseTestHarness { + return newBindUsecaseForTestWithClient(platformmihomo.StubClient{}) +} + +func newBindUsecaseForTestWithClient(client platformmihomo.Client) *bindUsecaseTestHarness { credentialRepo := newMemoryCredentialRepo() deviceRepo := newMemoryDeviceRepo() profileRepo := newMemoryProfileRepo() + credentialRepo.deviceRepo = deviceRepo + credentialRepo.profileRepo = profileRepo + if observingClient, ok := client.(*transactionObservingClient); ok { + observingClient.repo = credentialRepo + } - bindUsecase := NewBindUsecase(credentialRepo, deviceRepo, profileRepo, platformmihomo.StubClient{}, testEncryptionKey) + bindUsecase := NewBindUsecase(credentialRepo, deviceRepo, profileRepo, client, testEncryptionKey) return &bindUsecaseTestHarness{ BindUsecase: bindUsecase, @@ -163,9 +428,97 @@ func newBindUsecaseForTest() *bindUsecaseTestHarness { } } +type mihomoValidateResult struct { + accountID string + region string + profiles []platformmihomo.DiscoveredProfile + err error +} + +type sequentialMihomoClient struct { + results []mihomoValidateResult + index int +} + +func (c *sequentialMihomoClient) ValidateAndDiscover(_ context.Context, _ string, _ string) (string, string, []platformmihomo.DiscoveredProfile, error) { + result := c.results[c.index] + c.index++ + return result.accountID, result.region, result.profiles, result.err +} + +func (c *sequentialMihomoClient) IssueAuthKey(_ context.Context, _ string, _ string) (string, int64, error) { + return "stub-authkey", 300, nil +} + +type mutatingMihomoClient struct { + repo *memoryCredentialRepo + deviceRepo *memoryDeviceRepo + profileRepo *memoryProfileRepo + bindingID uint64 + newAccountID string + discoveredID string + discoveredUID string + discoveredName string + mutated bool +} + +func (c *mutatingMihomoClient) ValidateAndDiscover(_ context.Context, _ string, _ string) (string, string, []platformmihomo.DiscoveredProfile, error) { + if !c.mutated { + credential := &biz.Credential{ + BindingID: c.bindingID, + PlatformAccountID: c.newAccountID, + Platform: "mihomo", + AccountID: "30003", + Region: "cn_gf01", + CredentialBlob: "blob-30003", + CredentialVersion: "v1", + Status: "active", + } + _ = c.repo.Save(context.Background(), credential) + _ = c.deviceRepo.Save(context.Background(), &biz.Device{PlatformAccountID: c.newAccountID, DeviceID: "device-mutated", DeviceFP: "fp-mutated", IsValid: true}) + _ = c.profileRepo.Save(context.Background(), &biz.Profile{BindingID: c.bindingID, PlatformAccountID: c.newAccountID, GameBiz: "hk4e_cn", Region: "cn_gf01", PlayerID: "3008611", Nickname: "Mutated", Level: 60, IsDefault: true}) + c.mutated = true + } + + return c.discoveredID, "cn_gf01", []platformmihomo.DiscoveredProfile{{ + GameBiz: "hk4e_cn", + Region: "cn_gf01", + PlayerID: c.discoveredUID, + Nickname: c.discoveredName, + Level: 55, + }}, nil +} + +func (c *mutatingMihomoClient) IssueAuthKey(_ context.Context, _ string, _ string) (string, int64, error) { + return "stub-authkey", 300, nil +} + +type transactionObservingClient struct { + calledDuringTransaction bool + repo *memoryCredentialRepo +} + +func (c *transactionObservingClient) ValidateAndDiscover(_ context.Context, _ string, _ string) (string, string, []platformmihomo.DiscoveredProfile, error) { + c.calledDuringTransaction = c.repo.inTransaction + return "10001", "cn_gf01", []platformmihomo.DiscoveredProfile{{ + GameBiz: "hk4e_cn", + Region: "cn_gf01", + PlayerID: "1008611", + Nickname: "Traveler", + Level: 60, + }}, nil +} + +func (c *transactionObservingClient) IssueAuthKey(_ context.Context, _ string, _ string) (string, int64, error) { + return "stub-authkey", 300, nil +} + type memoryCredentialRepo struct { byPlatformAccountID map[string]*biz.Credential byBindingID map[uint64]*biz.Credential + deviceRepo *memoryDeviceRepo + profileRepo *memoryProfileRepo + inTransaction bool } func newMemoryCredentialRepo() *memoryCredentialRepo { @@ -202,18 +555,45 @@ func (r *memoryCredentialRepo) GetByBindingID(_ context.Context, bindingID uint6 func (r *memoryCredentialRepo) DeleteByPlatformAccountID(_ context.Context, platformAccountID string) error { if credential := r.byPlatformAccountID[platformAccountID]; credential != nil { - delete(r.byBindingID, credential.BindingID) + if current := r.byBindingID[credential.BindingID]; current != nil && current.PlatformAccountID == platformAccountID { + delete(r.byBindingID, credential.BindingID) + } } delete(r.byPlatformAccountID, platformAccountID) return nil } +func (r *memoryCredentialRepo) WithinTransaction(ctx context.Context, fn func(context.Context) error) error { + credentialByPlatform := cloneCredentialMapByPlatformAccountID(r.byPlatformAccountID) + credentialByBinding := cloneCredentialMapByBindingID(r.byBindingID) + deviceByPlatform := cloneDeviceMapByPlatformAccountID(r.deviceRepo.byPlatformAccountID) + profileByPlatform := cloneProfileMapByPlatformAccountID(r.profileRepo.byPlatformAccountID) + profileByBinding := cloneProfileMapByBindingID(r.profileRepo.byBindingID) + r.inTransaction = true + defer func() { + r.inTransaction = false + }() + if err := fn(ctx); err != nil { + r.byPlatformAccountID = credentialByPlatform + r.byBindingID = credentialByBinding + r.deviceRepo.byPlatformAccountID = deviceByPlatform + r.profileRepo.byPlatformAccountID = profileByPlatform + r.profileRepo.byBindingID = profileByBinding + return err + } + return nil +} + type memoryDeviceRepo struct { - byPlatformAccountID map[string][]*biz.Device + byPlatformAccountID map[string][]*biz.Device + failDeleteByPlatformAccountID map[string]error } func newMemoryDeviceRepo() *memoryDeviceRepo { - return &memoryDeviceRepo{byPlatformAccountID: make(map[string][]*biz.Device)} + return &memoryDeviceRepo{ + byPlatformAccountID: make(map[string][]*biz.Device), + failDeleteByPlatformAccountID: make(map[string]error), + } } func (r *memoryDeviceRepo) Save(_ context.Context, device *biz.Device) error { @@ -241,6 +621,9 @@ func (r *memoryDeviceRepo) ListByPlatformAccountID(_ context.Context, platformAc } func (r *memoryDeviceRepo) DeleteByPlatformAccountID(_ context.Context, platformAccountID string) error { + if err := r.failDeleteByPlatformAccountID[platformAccountID]; err != nil { + return err + } delete(r.byPlatformAccountID, platformAccountID) return nil } @@ -306,7 +689,19 @@ func (r *memoryProfileRepo) ListByBindingID(_ context.Context, bindingID uint64) func (r *memoryProfileRepo) DeleteByPlatformAccountID(_ context.Context, platformAccountID string) error { if profiles := r.byPlatformAccountID[platformAccountID]; len(profiles) > 0 { - delete(r.byBindingID, profiles[0].BindingID) + bindingID := profiles[0].BindingID + current := r.byBindingID[bindingID] + filtered := make([]*biz.Profile, 0, len(current)) + for _, profile := range current { + if profile.PlatformAccountID != platformAccountID { + filtered = append(filtered, profile) + } + } + if len(filtered) == 0 { + delete(r.byBindingID, bindingID) + } else { + r.byBindingID[bindingID] = filtered + } } delete(r.byPlatformAccountID, platformAccountID) return nil @@ -338,3 +733,60 @@ func (r *memoryProfileRepo) DeleteMissingByPlatformAccountID(_ context.Context, var _ biz.CredentialRepository = (*memoryCredentialRepo)(nil) var _ biz.DeviceRepository = (*memoryDeviceRepo)(nil) var _ biz.ProfileRepository = (*memoryProfileRepo)(nil) + +func cloneCredentialMapByPlatformAccountID(source map[string]*biz.Credential) map[string]*biz.Credential { + cloned := make(map[string]*biz.Credential, len(source)) + for key, credential := range source { + clone := *credential + cloned[key] = &clone + } + return cloned +} + +func cloneCredentialMapByBindingID(source map[uint64]*biz.Credential) map[uint64]*biz.Credential { + cloned := make(map[uint64]*biz.Credential, len(source)) + for key, credential := range source { + clone := *credential + cloned[key] = &clone + } + return cloned +} + +func cloneDeviceMapByPlatformAccountID(source map[string][]*biz.Device) map[string][]*biz.Device { + cloned := make(map[string][]*biz.Device, len(source)) + for key, devices := range source { + copied := make([]*biz.Device, 0, len(devices)) + for _, device := range devices { + clone := *device + copied = append(copied, &clone) + } + cloned[key] = copied + } + return cloned +} + +func cloneProfileMapByPlatformAccountID(source map[string][]*biz.Profile) map[string][]*biz.Profile { + cloned := make(map[string][]*biz.Profile, len(source)) + for key, profiles := range source { + copied := make([]*biz.Profile, 0, len(profiles)) + for _, profile := range profiles { + clone := *profile + copied = append(copied, &clone) + } + cloned[key] = copied + } + return cloned +} + +func cloneProfileMapByBindingID(source map[uint64][]*biz.Profile) map[uint64][]*biz.Profile { + cloned := make(map[uint64][]*biz.Profile, len(source)) + for key, profiles := range source { + copied := make([]*biz.Profile, 0, len(profiles)) + for _, profile := range profiles { + clone := *profile + copied = append(copied, &clone) + } + cloned[key] = copied + } + return cloned +} diff --git a/internal/usecase/management_usecase.go b/internal/usecase/management_usecase.go index 60b8e3b..426e23b 100644 --- a/internal/usecase/management_usecase.go +++ b/internal/usecase/management_usecase.go @@ -3,7 +3,6 @@ package usecase import ( "context" "errors" - "fmt" "time" v1 "platform-mihomo-service/api/mihomo/v1" @@ -80,6 +79,16 @@ func (uc *ManagementUsecase) GetCredentialSummary(ctx context.Context, platformA }, nil } +func (uc *ManagementUsecase) GetCredentialSummaryWithScope(ctx context.Context, guard ScopeGuard, platformAccountID string) (*CredentialSummaryOutput, error) { + if err := guard.RequirePlatformAccountID(platformAccountID); err != nil { + return nil, err + } + if err := guard.RequireBindingWide(); err != nil { + return nil, err + } + return uc.GetCredentialSummary(ctx, platformAccountID) +} + type UpdateCredentialInput struct { PlatformAccountID string BindCredentialInput @@ -94,27 +103,55 @@ func (uc *ManagementUsecase) UpdateCredential(ctx context.Context, input UpdateC return nil, ErrCredentialNotFound } - output, err := uc.bindUC.BindCredential(ctx, input.BindCredentialInput) + prepared, err := uc.bindUC.prepareBindCredential(ctx, input.BindCredentialInput) if err != nil { return nil, err } - if output.PlatformAccountID != input.PlatformAccountID { - if cleanupErr := uc.management.DeleteCredentialGraph(ctx, output.PlatformAccountID); cleanupErr != nil { - return nil, fmt.Errorf("%w: cleanup failed: %v", ErrPlatformAccountMismatch, cleanupErr) - } + if prepared.platformAccountID != input.PlatformAccountID { return nil, ErrPlatformAccountMismatch } - if err := uc.pruneStaleProfiles(ctx, input.PlatformAccountID, output.Profiles); err != nil { + + err = uc.bindUC.runInTransaction(ctx, func(txCtx context.Context) error { + result, err := uc.bindUC.bindPreparedCredential(txCtx, input.BindCredentialInput, prepared) + if err != nil { + return err + } + if err := uc.pruneStaleProfiles(txCtx, input.PlatformAccountID, result.Profiles); err != nil { + return err + } + return nil + }) + if err != nil { return nil, err } return uc.GetCredentialSummary(ctx, input.PlatformAccountID) } +func (uc *ManagementUsecase) UpdateCredentialWithScope(ctx context.Context, guard ScopeGuard, input UpdateCredentialInput) (*CredentialSummaryOutput, error) { + if err := guard.RequirePlatformAccountID(input.PlatformAccountID); err != nil { + return nil, err + } + if err := guard.RequireBindingWide(); err != nil { + return nil, err + } + return uc.UpdateCredential(ctx, input) +} + func (uc *ManagementUsecase) DeleteCredential(ctx context.Context, platformAccountID string) error { return uc.management.DeleteCredentialGraph(ctx, platformAccountID) } +func (uc *ManagementUsecase) DeleteCredentialWithScope(ctx context.Context, guard ScopeGuard, platformAccountID string) error { + if err := guard.RequirePlatformAccountID(platformAccountID); err != nil { + return err + } + if err := guard.RequireBindingWide(); err != nil { + return err + } + return uc.DeleteCredential(ctx, platformAccountID) +} + func (uc *ManagementUsecase) pruneStaleProfiles(ctx context.Context, platformAccountID string, profiles []v1.ProfileSummary) error { keep := make([]biz.ProfileIdentity, 0, len(profiles)) for i := range profiles { diff --git a/internal/usecase/management_usecase_test.go b/internal/usecase/management_usecase_test.go index 1ea80c2..67e0f17 100644 --- a/internal/usecase/management_usecase_test.go +++ b/internal/usecase/management_usecase_test.go @@ -53,8 +53,14 @@ func TestUpdateCredentialRejectsPlatformAccountMismatch(t *testing.T) { uc := newManagementUsecaseForTest(t) platformAccountID := bindCredentialForManagementTest(t, uc) uc.bindUC = NewBindUsecase(uc.credentialRepo, uc.deviceRepo, uc.profileRepo, mismatchClient{}, testEncryptionKey) + originalCredential, err := uc.credentialRepo.GetByPlatformAccountID(context.Background(), platformAccountID) + require.NoError(t, err) + originalDevices, err := uc.deviceRepo.ListByPlatformAccountID(context.Background(), platformAccountID) + require.NoError(t, err) + originalProfiles, err := uc.profileRepo.ListByPlatformAccountID(context.Background(), platformAccountID) + require.NoError(t, err) - _, err := uc.UpdateCredential(context.Background(), UpdateCredentialInput{ + _, err = uc.UpdateCredential(context.Background(), UpdateCredentialInput{ PlatformAccountID: platformAccountID, BindCredentialInput: BindCredentialInput{ BindingID: 101, @@ -65,6 +71,20 @@ func TestUpdateCredentialRejectsPlatformAccountMismatch(t *testing.T) { }, }) require.ErrorIs(t, err, ErrPlatformAccountMismatch) + require.Zero(t, uc.managementRepo.deleteCalls) + currentCredential, err := uc.credentialRepo.GetByPlatformAccountID(context.Background(), platformAccountID) + require.NoError(t, err) + require.NotNil(t, currentCredential) + require.Equal(t, originalCredential.PlatformAccountID, currentCredential.PlatformAccountID) + require.Equal(t, originalCredential.AccountID, currentCredential.AccountID) + currentDevices, err := uc.deviceRepo.ListByPlatformAccountID(context.Background(), platformAccountID) + require.NoError(t, err) + require.Len(t, currentDevices, len(originalDevices)) + require.Equal(t, originalDevices[0].DeviceID, currentDevices[0].DeviceID) + currentProfiles, err := uc.profileRepo.ListByPlatformAccountID(context.Background(), platformAccountID) + require.NoError(t, err) + require.Len(t, currentProfiles, len(originalProfiles)) + require.Equal(t, originalProfiles[0].PlayerID, currentProfiles[0].PlayerID) require.Nil(t, uc.credentialRepo.byPlatformAccountID["binding_101_20002"]) } @@ -110,6 +130,56 @@ func TestDeleteCredentialRemovesCredentialArtifactsAndRelations(t *testing.T) { require.ErrorIs(t, err, ErrCredentialNotFound) } +func TestGetCredentialSummaryWithScopeRejectsForeignBinding(t *testing.T) { + uc := newManagementUsecaseForTest(t) + platformAccountID := bindCredentialForManagementTest(t, uc) + + _, err := uc.GetCredentialSummaryWithScope(context.Background(), ScopeGuard{BindingID: 999}, platformAccountID) + require.ErrorIs(t, err, ErrBindingScopeDenied) +} + +func TestDeleteCredentialWithScopeRejectsForeignBinding(t *testing.T) { + uc := newManagementUsecaseForTest(t) + platformAccountID := bindCredentialForManagementTest(t, uc) + + err := uc.DeleteCredentialWithScope(context.Background(), ScopeGuard{BindingID: 999}, platformAccountID) + require.ErrorIs(t, err, ErrBindingScopeDenied) +} + +func TestGetCredentialSummaryWithScopeRejectsProfileScopedTicket(t *testing.T) { + uc := newManagementUsecaseForTest(t) + platformAccountID := bindCredentialForManagementTest(t, uc) + + _, err := uc.GetCredentialSummaryWithScope(context.Background(), ScopeGuard{BindingID: 101, ProfileID: 1001}, platformAccountID) + require.ErrorIs(t, err, ErrProfileScopeDenied) +} + +func TestUpdateCredentialWithScopeRejectsProfileScopedTicket(t *testing.T) { + uc := newManagementUsecaseForTest(t) + platformAccountID := bindCredentialForManagementTest(t, uc) + + _, err := uc.UpdateCredentialWithScope(context.Background(), ScopeGuard{BindingID: 101, ProfileID: 1001}, UpdateCredentialInput{ + PlatformAccountID: platformAccountID, + BindCredentialInput: BindCredentialInput{ + BindingID: 101, + CookieBundleJSON: `{"account_id":"10001","cookie_token":"updated"}`, + DeviceID: "device-2", + DeviceFP: "fp-2", + DeviceName: "iPad", + RegionHint: "cn_gf01", + }, + }) + require.ErrorIs(t, err, ErrProfileScopeDenied) +} + +func TestDeleteCredentialWithScopeRejectsProfileScopedTicket(t *testing.T) { + uc := newManagementUsecaseForTest(t) + platformAccountID := bindCredentialForManagementTest(t, uc) + + err := uc.DeleteCredentialWithScope(context.Background(), ScopeGuard{BindingID: 101, ProfileID: 1001}, platformAccountID) + require.ErrorIs(t, err, ErrProfileScopeDenied) +} + type managementUsecaseTestHarness struct { *ManagementUsecase credentialRepo *memoryCredentialRepo @@ -177,9 +247,11 @@ type memoryManagementRepo struct { deviceRepo *memoryDeviceRepo profileRepo *memoryProfileRepo artifactRepo *memoryArtifactRepo + deleteCalls int } func (r *memoryManagementRepo) DeleteCredentialGraph(_ context.Context, platformAccountID string) error { + r.deleteCalls++ requireDeleteArtifacts(r.artifactRepo, platformAccountID) delete(r.credentialRepo.byPlatformAccountID, platformAccountID) delete(r.deviceRepo.byPlatformAccountID, platformAccountID) diff --git a/internal/usecase/scope_guard.go b/internal/usecase/scope_guard.go index 3237a12..c46417c 100644 --- a/internal/usecase/scope_guard.go +++ b/internal/usecase/scope_guard.go @@ -15,6 +15,7 @@ const ( ActionDeviceUpdate = "mihomo.device.update" ActionCredentialRead = "mihomo.credential.read_meta" ActionCredentialUpdate = "mihomo.credential.update" + ActionCredentialRefresh = "mihomo.credential.refresh" ActionCredentialDelete = "mihomo.credential.delete" ) @@ -39,6 +40,13 @@ func (g ScopeGuard) RequirePlatformAccountID(platformAccountID string) error { return nil } +func (g ScopeGuard) RequireBindingWide() error { + if g.ProfileID != 0 { + return ErrProfileScopeDenied + } + return nil +} + func (g ScopeGuard) RequireProfile(bindingID, profileID uint64) error { if g.BindingID == 0 || g.BindingID != bindingID { return ErrBindingScopeDenied