Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/platform-mihomo-service/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func main() {
managementUC,
)
sharedSvc := service.NewMihomoCredentialService(ticketVerifier, managementUC)
genericSvc := service.NewGenericPlatformService(ticketVerifier, managementUC)
genericSvc := service.NewGenericPlatformService(ticketVerifier, bindUC, statusUC, managementUC)

grpcSrv := server.NewGRPCServer(&bc, mihomoSvc, sharedSvc, genericSvc)
app := kratos.New(
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module platform-mihomo-service
go 1.24.0

require (
github.com/PaiGramTeam/proto-contracts v0.0.0-20260416033414-4016470bc78d
github.com/PaiGramTeam/proto-contracts v0.0.0-20260419070719-80366e162dcf
github.com/glebarez/sqlite v1.11.0
github.com/go-kratos/kratos/v2 v2.9.2
github.com/go-sql-driver/mysql v1.8.1
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/PaiGramTeam/proto-contracts v0.0.0-20260416033414-4016470bc78d h1:sRlQ5OlTFoft20wsKyi/7pVTTflEPqPcvDhOuoPMGpE=
github.com/PaiGramTeam/proto-contracts v0.0.0-20260416033414-4016470bc78d/go.mod h1:i5FW0tHMyzubj1qI/7e04arFyrDv3YXkK1Pu5UxejBI=
github.com/PaiGramTeam/proto-contracts v0.0.0-20260419070719-80366e162dcf h1:Wa77pOdHqp/wtil+DecjMvAOZ0rlLczTu4/Wqi17Cjw=
github.com/PaiGramTeam/proto-contracts v0.0.0-20260419070719-80366e162dcf/go.mod h1:i5FW0tHMyzubj1qI/7e04arFyrDv3YXkK1Pu5UxejBI=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
Expand Down
9 changes: 5 additions & 4 deletions internal/biz/ticket.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ type ServiceTicketClaims struct {
ActorType string
ActorID string
OwnerUserID uint64
// BindingID is the first-class control-plane binding identity carried by
// service tickets and used for authorization and resource lookup.
BindingID uint64
Platform string
PlatformAccountID string
Expand All @@ -16,8 +18,7 @@ type ServiceTicketClaims struct {
PlatformServiceKey string
PlatformAccountRefID uint64

// PlatformAccountRefID is a read-only legacy alias for BindingID so
// downstream callers can migrate incrementally. New tickets should use
// BindingID; if a token still carries platform_account_ref_id, verifier
// code requires it to match binding_id.
// PlatformAccountRefID is a read-only legacy alias for BindingID. New
// tickets should use binding_id; if platform_account_ref_id is present,
// verifier code requires it to match BindingID exactly.
}
14 changes: 10 additions & 4 deletions internal/data/credential_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ func NewCredentialRepo(db *gorm.DB) *CredentialRepo {
return &CredentialRepo{db: db}
}

func (r *CredentialRepo) WithinTransaction(ctx context.Context, fn func(context.Context) error) error {
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
return fn(withTx(ctx, tx))
})
}

func (r *CredentialRepo) Save(ctx context.Context, credential *biz.Credential) error {
record := model.CredentialRecord{
BindingID: credential.BindingID,
Expand All @@ -34,15 +40,15 @@ func (r *CredentialRepo) Save(ctx context.Context, credential *biz.Credential) e
ExpiresAt: credential.ExpiresAt,
}

return r.db.WithContext(ctx).Clauses(clause.OnConflict{
return dbFromContext(ctx, r.db).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "binding_id"}},
UpdateAll: true,
}).Create(&record).Error
}

func (r *CredentialRepo) GetByBindingID(ctx context.Context, bindingID uint64) (*biz.Credential, error) {
var record model.CredentialRecord
if err := r.db.WithContext(ctx).Where("binding_id = ?", bindingID).Order("id asc").Take(&record).Error; err != nil {
if err := dbFromContext(ctx, r.db).Where("binding_id = ?", bindingID).Order("id asc").Take(&record).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
Expand All @@ -54,7 +60,7 @@ func (r *CredentialRepo) GetByBindingID(ctx context.Context, bindingID uint64) (

func (r *CredentialRepo) GetByPlatformAccountID(ctx context.Context, platformAccountID string) (*biz.Credential, error) {
var record model.CredentialRecord
if err := r.db.WithContext(ctx).Where("platform_account_id = ?", platformAccountID).Take(&record).Error; err != nil {
if err := dbFromContext(ctx, r.db).Where("platform_account_id = ?", platformAccountID).Take(&record).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
Expand All @@ -65,7 +71,7 @@ func (r *CredentialRepo) GetByPlatformAccountID(ctx context.Context, platformAcc
}

func (r *CredentialRepo) DeleteByPlatformAccountID(ctx context.Context, platformAccountID string) error {
return r.db.WithContext(ctx).Where("platform_account_id = ?", platformAccountID).Delete(&model.CredentialRecord{}).Error
return dbFromContext(ctx, r.db).Where("platform_account_id = ?", platformAccountID).Delete(&model.CredentialRecord{}).Error
}

func credentialFromRecord(record model.CredentialRecord) *biz.Credential {
Expand Down
6 changes: 3 additions & 3 deletions internal/data/device_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ func (r *DeviceRepo) Save(ctx context.Context, device *biz.Device) error {
LastSeenAt: device.LastSeenAt,
}

return r.db.WithContext(ctx).Clauses(clause.OnConflict{
return dbFromContext(ctx, r.db).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "platform_account_id"}, {Name: "device_id"}},
DoUpdates: clause.AssignmentColumns([]string{"device_fp", "device_name", "is_valid", "last_seen_at", "updated_at"}),
}).Create(&record).Error
}

func (r *DeviceRepo) ListByPlatformAccountID(ctx context.Context, platformAccountID string) ([]*biz.Device, error) {
var records []model.DeviceRecord
if err := r.db.WithContext(ctx).Where("platform_account_id = ?", platformAccountID).Order("id asc").Find(&records).Error; err != nil {
if err := dbFromContext(ctx, r.db).Where("platform_account_id = ?", platformAccountID).Order("id asc").Find(&records).Error; err != nil {
return nil, err
}

Expand All @@ -57,5 +57,5 @@ func (r *DeviceRepo) ListByPlatformAccountID(ctx context.Context, platformAccoun
}

func (r *DeviceRepo) DeleteByPlatformAccountID(ctx context.Context, platformAccountID string) error {
return r.db.WithContext(ctx).Where("platform_account_id = ?", platformAccountID).Delete(&model.DeviceRecord{}).Error
return dbFromContext(ctx, r.db).Where("platform_account_id = ?", platformAccountID).Delete(&model.DeviceRecord{}).Error
}
12 changes: 6 additions & 6 deletions internal/data/profile_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ func (r *ProfileRepo) Save(ctx context.Context, profile *biz.Profile) error {
DiscoveredAt: profile.DiscoveredAt,
}

return r.db.WithContext(ctx).Clauses(clause.OnConflict{
return dbFromContext(ctx, r.db).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "platform_account_id"}, {Name: "player_id"}, {Name: "region"}},
DoUpdates: clause.AssignmentColumns([]string{"game_biz", "nickname", "level", "is_default", "updated_at"}),
}).Create(&record).Error
}

func (r *ProfileRepo) ListByBindingID(ctx context.Context, bindingID uint64) ([]*biz.Profile, error) {
var records []model.AccountProfile
if err := r.db.WithContext(ctx).Where("binding_id = ?", bindingID).Order("id asc").Find(&records).Error; err != nil {
if err := dbFromContext(ctx, r.db).Where("binding_id = ?", bindingID).Order("id asc").Find(&records).Error; err != nil {
return nil, err
}

Expand All @@ -48,20 +48,20 @@ func (r *ProfileRepo) ListByBindingID(ctx context.Context, bindingID uint64) ([]

func (r *ProfileRepo) ListByPlatformAccountID(ctx context.Context, platformAccountID string) ([]*biz.Profile, error) {
var records []model.AccountProfile
if err := r.db.WithContext(ctx).Where("platform_account_id = ?", platformAccountID).Order("id asc").Find(&records).Error; err != nil {
if err := dbFromContext(ctx, r.db).Where("platform_account_id = ?", platformAccountID).Order("id asc").Find(&records).Error; err != nil {
return nil, err
}

return profilesFromRecords(records), nil
}

func (r *ProfileRepo) DeleteByPlatformAccountID(ctx context.Context, platformAccountID string) error {
return r.db.WithContext(ctx).Where("platform_account_id = ?", platformAccountID).Delete(&model.AccountProfile{}).Error
return dbFromContext(ctx, r.db).Where("platform_account_id = ?", platformAccountID).Delete(&model.AccountProfile{}).Error
}

func (r *ProfileRepo) DeleteMissingByPlatformAccountID(ctx context.Context, platformAccountID string, keep []biz.ProfileIdentity) error {
var records []model.AccountProfile
if err := r.db.WithContext(ctx).Where("platform_account_id = ?", platformAccountID).Find(&records).Error; err != nil {
if err := dbFromContext(ctx, r.db).Where("platform_account_id = ?", platformAccountID).Find(&records).Error; err != nil {
return err
}
keepSet := make(map[string]struct{}, len(keep))
Expand All @@ -73,7 +73,7 @@ func (r *ProfileRepo) DeleteMissingByPlatformAccountID(ctx context.Context, plat
if _, ok := keepSet[key]; ok {
continue
}
if err := r.db.WithContext(ctx).Delete(&model.AccountProfile{}, record.ID).Error; err != nil {
if err := dbFromContext(ctx, r.db).Delete(&model.AccountProfile{}, record.ID).Error; err != nil {
return err
}
}
Expand Down
6 changes: 4 additions & 2 deletions internal/data/ticket_verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,10 @@ func (v *TicketVerifier) Verify(raw string, expectedAudience string) (*biz.Servi
if claims.Platform == "" {
return nil, fmt.Errorf("service ticket missing platform")
}
if claims.ActorType == "consumer" && claims.Consumer == "" {
return nil, fmt.Errorf("service ticket missing consumer")
if claims.ActorType == "consumer" {
if claims.Consumer == "" {
return nil, fmt.Errorf("service ticket missing consumer")
}
}
userID := claims.OwnerUserID
if claims.UserID != 0 {
Expand Down
16 changes: 16 additions & 0 deletions internal/data/ticket_verifier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,22 @@ func TestVerifyRejectsConsumerActorWithoutConsumerClaim(t *testing.T) {
require.ErrorContains(t, err, "consumer")
}

func TestVerifyAllowsNonConsumerActorTypesWithoutConsumerClaim(t *testing.T) {
verifier := NewTicketVerifier(testTicketIssuer, testTicketSigningKey)
raw := issueTestTicket(t, map[string]any{
"actor_type": "robot",
"actor_id": "user-1",
"owner_user_id": float64(1),
"binding_id": float64(101),
"platform": "mihomo",
"scopes": []string{"mihomo.profile.read"},
})

claims, err := verifier.Verify(raw, testTicketAudience)
require.NoError(t, err)
require.Equal(t, "robot", claims.ActorType)
}

func TestVerifyRejectsMismatchedLegacyPlatformAccountRefID(t *testing.T) {
verifier := NewTicketVerifier(testTicketIssuer, testTicketSigningKey)
raw := issueTestTicket(t, map[string]any{
Expand Down
20 changes: 20 additions & 0 deletions internal/data/tx_context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package data

import (
"context"

"gorm.io/gorm"
)

type txContextKey struct{}

func withTx(ctx context.Context, tx *gorm.DB) context.Context {
return context.WithValue(ctx, txContextKey{}, tx)
}

func dbFromContext(ctx context.Context, fallback *gorm.DB) *gorm.DB {
if tx, ok := ctx.Value(txContextKey{}).(*gorm.DB); ok && tx != nil {
return tx.WithContext(ctx)
}
return fallback.WithContext(ctx)
}
21 changes: 14 additions & 7 deletions internal/service/mihomo_account.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,11 @@ func (s *MihomoAccountService) RefreshCredential(ctx context.Context, req *v1.Re
if err != nil {
return nil, err
}
if _, err := scopedGuardForPlatformAccount(claims, req.GetPlatformAccountId(), usecase.ActionStatusRead); err != nil {
guard, err := scopedGuardForPlatformAccount(claims, req.GetPlatformAccountId(), usecase.ActionCredentialRefresh)
if err != nil {
return nil, mapUsecaseError(err)
}
if err := guard.RequireBindingWide(); err != nil {
return nil, mapUsecaseError(err)
}

Expand Down Expand Up @@ -212,11 +216,12 @@ func (s *MihomoAccountService) GetCredentialSummary(ctx context.Context, req *v1
return nil, err
}

if _, err := scopedGuardForPlatformAccount(claims, req.GetPlatformAccountId(), usecase.ActionCredentialRead); err != nil {
guard, err := scopedGuard(claims, usecase.ActionCredentialRead)
if err != nil {
return nil, mapUsecaseError(err)
}

output, err := s.managementUC.GetCredentialSummary(ctx, req.GetPlatformAccountId())
output, err := s.managementUC.GetCredentialSummaryWithScope(ctx, guard, req.GetPlatformAccountId())
if err != nil {
return nil, mapUsecaseError(err)
}
Expand All @@ -233,7 +238,8 @@ func (s *MihomoAccountService) UpdateCredential(ctx context.Context, req *v1.Upd
if err != nil {
return nil, err
}
if _, err := scopedGuardForPlatformAccount(claims, req.GetPlatformAccountId(), usecase.ActionCredentialUpdate); err != nil {
guard, err := scopedGuard(claims, usecase.ActionCredentialUpdate)
if err != nil {
return nil, mapUsecaseError(err)
}
if req.GetDevice() == nil {
Expand All @@ -255,7 +261,7 @@ func (s *MihomoAccountService) UpdateCredential(ctx context.Context, req *v1.Upd
input.DeviceName = device.GetDeviceName()
}

output, err := s.managementUC.UpdateCredential(ctx, input)
output, err := s.managementUC.UpdateCredentialWithScope(ctx, guard, input)
if err != nil {
return nil, mapUsecaseError(err)
}
Expand All @@ -272,11 +278,12 @@ func (s *MihomoAccountService) DeleteCredential(ctx context.Context, req *v1.Del
if err != nil {
return nil, err
}
if _, err := scopedGuardForPlatformAccount(claims, req.GetPlatformAccountId(), usecase.ActionCredentialDelete); err != nil {
guard, err := scopedGuard(claims, usecase.ActionCredentialDelete)
if err != nil {
return nil, mapUsecaseError(err)
}

if err := s.managementUC.DeleteCredential(ctx, req.GetPlatformAccountId()); err != nil {
if err := s.managementUC.DeleteCredentialWithScope(ctx, guard, req.GetPlatformAccountId()); err != nil {
return nil, mapUsecaseError(err)
}

Expand Down
52 changes: 51 additions & 1 deletion internal/service/mihomo_account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func TestBindCredentialReturnsDiscoveredProfiles(t *testing.T) {
require.Empty(t, validateResp.ErrorCode)

refreshResp, err := svc.RefreshCredential(context.Background(), &v1.RefreshCredentialRequest{
ServiceTicket: signedServiceTicketForAccount(t, bindResp.PlatformAccountId, "mihomo.status.read"),
ServiceTicket: signedServiceTicketForAccount(t, bindResp.PlatformAccountId, "mihomo.credential.refresh"),
PlatformAccountId: bindResp.PlatformAccountId,
})
require.NoError(t, err)
Expand Down Expand Up @@ -167,6 +167,18 @@ func TestGetCredentialStatusRejectsOutOfScopePlatformAccountID(t *testing.T) {
require.Equal(t, codes.PermissionDenied, status.Code(err))
}

func TestRefreshCredentialRejectsProfileScopedTicket(t *testing.T) {
svc := newMihomoAccountServiceForTest(t)
bindResp := bindCredentialForServiceTest(t, svc)

_, err := svc.RefreshCredential(context.Background(), &v1.RefreshCredentialRequest{
ServiceTicket: signedServiceTicketForProfile(t, bindResp.PlatformAccountId, 999, "mihomo.credential.refresh"),
PlatformAccountId: bindResp.PlatformAccountId,
})
require.Error(t, err)
require.Equal(t, codes.PermissionDenied, status.Code(err))
}

func TestConfirmPrimaryProfileRejectsUnknownPlayerID(t *testing.T) {
svc := newMihomoAccountServiceForTest(t)
bindResp, err := svc.BindCredential(context.Background(), &v1.BindCredentialRequest{
Expand Down Expand Up @@ -213,6 +225,18 @@ func TestDeleteCredentialRejectsMissingScope(t *testing.T) {
require.Equal(t, codes.PermissionDenied, status.Code(err))
}

func TestDeleteCredentialRejectsProfileScopedTicket(t *testing.T) {
svc := newMihomoAccountServiceForTest(t)
bindResp := bindCredentialForServiceTest(t, svc)

_, err := svc.DeleteCredential(context.Background(), &v1.DeleteCredentialRequest{
ServiceTicket: signedServiceTicketForProfile(t, bindResp.PlatformAccountId, 999, "mihomo.credential.delete"),
PlatformAccountId: bindResp.PlatformAccountId,
})
require.Error(t, err)
require.Equal(t, codes.PermissionDenied, status.Code(err))
}

func TestBindCredentialRejectsMissingScope(t *testing.T) {
svc := newMihomoAccountServiceForTest(t)

Expand All @@ -237,6 +261,18 @@ func TestGetCredentialSummaryRejectsMissingScope(t *testing.T) {
require.Equal(t, codes.PermissionDenied, status.Code(err))
}

func TestGetCredentialSummaryRejectsProfileScopedTicket(t *testing.T) {
svc := newMihomoAccountServiceForTest(t)
bindResp := bindCredentialForServiceTest(t, svc)

_, err := svc.GetCredentialSummary(context.Background(), &v1.GetCredentialSummaryRequest{
ServiceTicket: signedServiceTicketForProfile(t, bindResp.PlatformAccountId, 999, "mihomo.credential.read_meta"),
PlatformAccountId: bindResp.PlatformAccountId,
})
require.Error(t, err)
require.Equal(t, codes.PermissionDenied, status.Code(err))
}

func TestUpdateCredentialRejectsMissingScope(t *testing.T) {
svc := newMihomoAccountServiceForTest(t)
bindResp := bindCredentialForServiceTest(t, svc)
Expand All @@ -251,6 +287,20 @@ func TestUpdateCredentialRejectsMissingScope(t *testing.T) {
require.Equal(t, codes.PermissionDenied, status.Code(err))
}

func TestUpdateCredentialRejectsProfileScopedTicket(t *testing.T) {
svc := newMihomoAccountServiceForTest(t)
bindResp := bindCredentialForServiceTest(t, svc)

_, err := svc.UpdateCredential(context.Background(), &v1.UpdateCredentialRequest{
ServiceTicket: signedServiceTicketForProfile(t, bindResp.PlatformAccountId, 999, "mihomo.credential.update"),
PlatformAccountId: bindResp.PlatformAccountId,
CookieBundleJson: `{"account_id":"10001","cookie_token":"updated"}`,
Device: &v1.DeviceInfo{DeviceId: "device-2", DeviceFp: "fp-2", DeviceName: "iPad"},
})
require.Error(t, err)
require.Equal(t, codes.PermissionDenied, status.Code(err))
}

func TestListProfilesRejectsMissingScope(t *testing.T) {
svc := newMihomoAccountServiceForTest(t)
bindResp := bindCredentialForServiceTest(t, svc)
Expand Down
5 changes: 3 additions & 2 deletions internal/service/mihomo_credential_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,12 @@ func (s *MihomoCredentialService) GetCredentialSummary(ctx context.Context, req
if err != nil {
return nil, status.Error(codes.Unauthenticated, "invalid service ticket")
}
if _, err := scopedGuardForPlatformAccount(claims, req.GetPlatformAccountId(), usecase.ActionCredentialRead); err != nil {
guard, err := scopedGuardForPlatformAccount(claims, req.GetPlatformAccountId(), usecase.ActionCredentialRead)
if err != nil {
return nil, mapUsecaseError(err)
}

output, err := s.managementUC.GetCredentialSummary(ctx, req.GetPlatformAccountId())
output, err := s.managementUC.GetCredentialSummaryWithScope(ctx, guard, req.GetPlatformAccountId())
if err != nil {
return nil, mapUsecaseError(err)
}
Expand Down
Loading
Loading