Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions deploy/csi-azurefile-driver.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,7 @@ spec:
- Persistent
- Ephemeral
fsGroupPolicy: ReadWriteOnceWithFSType
requiresRepublish: true
tokenRequests:
- audience: api://AzureADTokenExchange
expirationSeconds: 3600
88 changes: 77 additions & 11 deletions pkg/azurefile/azurefile.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@ import (
"bytes"
"context"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -88,6 +91,7 @@ const (
defaultAzureFileQuota = 100
minimumAccountQuota = 100 // GB

DefaultTokenAudience = "api://AzureADTokenExchange/.default"
// key of snapshot name in metadata
snapshotNameKey = "initiator"

Expand Down Expand Up @@ -168,6 +172,7 @@ const (
runtimeClassHandlerField = "runtimeclasshandler"
defaultRuntimeClassHandler = "kata-cc"
mountWithManagedIdentityField = "mountwithmanagedidentity"
mountWithWITokenField = "mountwithworkloadidentitytoken"

accountNotProvisioned = "StorageAccountIsNotProvisioned"
// this is a workaround fix for 429 throttling issue, will update cloud provider for better fix later
Expand Down Expand Up @@ -229,6 +234,8 @@ var (
azcopyCloneVolumeOptions = []string{"--recursive", "--check-length=false", "--log-level=ERROR"}
// azcopySnapshotRestoreOptions used in smb snapshot restore and set --check-length to true because snapshot data is changeless
azcopySnapshotRestoreOptions = []string{"--recursive", "--check-length=true", "--log-level=ERROR"}

defaultAzureOAuthTokenDir = "/var/lib/kubelet/plugins/" + DefaultDriverName
)

// Driver implements all interfaces of CSI drivers
Expand Down Expand Up @@ -804,8 +811,8 @@ func IsCorruptedDir(dir string) bool {
}

// GetAccountInfo get account info
// return <rgName, accountName, accountKey, fileShareName, diskName, subsID, err>
func (d *Driver) GetAccountInfo(ctx context.Context, volumeID string, secrets, reqContext map[string]string) (string, string, string, string, string, string, error) {
// return <rgName, accountName, accountKey, fileShareName, diskName, subsID, tenantID, tokenFilePath, err>
func (d *Driver) GetAccountInfo(ctx context.Context, volumeID string, secrets, reqContext map[string]string) (string, string, string, string, string, string, string, string, error) {
rgName, accountName, fileShareName, diskName, secretNamespace, subsID, err := GetFileShareInfo(volumeID)
if err != nil {
// ignore volumeID parsing error
Expand All @@ -815,8 +822,8 @@ func (d *Driver) GetAccountInfo(ctx context.Context, volumeID string, secrets, r

var protocol, accountKey, secretName, pvcNamespace string
// getAccountKeyFromSecret indicates whether get account key only from k8s secret
var getAccountKeyFromSecret, getLatestAccountKey, mountWithManagedIdentity bool
var clientID, tenantID, serviceAccountToken string
var getAccountKeyFromSecret, getLatestAccountKey, mountWithManagedIdentity, mountWithWIToken bool
var clientID, tenantID, tokenFilePath, serviceAccountToken string

for k, v := range reqContext {
switch strings.ToLower(k) {
Expand Down Expand Up @@ -844,13 +851,17 @@ func (d *Driver) GetAccountInfo(ctx context.Context, volumeID string, secrets, r
pvcNamespace = v
case getLatestAccountKeyField:
if getLatestAccountKey, err = strconv.ParseBool(v); err != nil {
return rgName, accountName, accountKey, fileShareName, diskName, subsID, fmt.Errorf("invalid %s: %s in volume context", getLatestAccountKeyField, v)
return rgName, accountName, accountKey, fileShareName, diskName, subsID, tenantID, tokenFilePath, fmt.Errorf("invalid %s: %s in volume context", getLatestAccountKeyField, v)
}
case clientIDField:
clientID = v
case mountWithManagedIdentityField:
if mountWithManagedIdentity, err = strconv.ParseBool(v); err != nil {
return rgName, accountName, accountKey, fileShareName, diskName, subsID, fmt.Errorf("invalid %s: %s in volume context", mountWithManagedIdentityField, v)
return rgName, accountName, accountKey, fileShareName, diskName, subsID, tenantID, tokenFilePath, fmt.Errorf("invalid %s: %s in volume context", mountWithManagedIdentityField, v)
}
case mountWithWITokenField:
if mountWithWIToken, err = strconv.ParseBool(v); err != nil {
return rgName, accountName, accountKey, fileShareName, diskName, subsID, tenantID, tokenFilePath, fmt.Errorf("invalid %s: %s in volume context", mountWithWITokenField, v)
}
case tenantIDField:
tenantID = v
Expand All @@ -870,7 +881,7 @@ func (d *Driver) GetAccountInfo(ctx context.Context, volumeID string, secrets, r
}
if protocol == nfs && fileShareName != "" {
// nfs protocol does not need account key, return directly
return rgName, accountName, accountKey, fileShareName, diskName, subsID, err
return rgName, accountName, accountKey, fileShareName, diskName, subsID, tenantID, tokenFilePath, err
}

if secretNamespace == "" {
Expand All @@ -883,21 +894,51 @@ func (d *Driver) GetAccountInfo(ctx context.Context, volumeID string, secrets, r

if mountWithManagedIdentity {
klog.V(2).Infof("mountWithManagedIdentity is true, use managed identity auth")
return rgName, accountName, accountKey, fileShareName, diskName, subsID, nil
return rgName, accountName, accountKey, fileShareName, diskName, subsID, tenantID, tokenFilePath, nil
}

if mountWithWIToken {
if clientID == "" {
clientID = d.cloud.Config.AzureAuthConfig.UserAssignedIdentityID
if clientID == "" {
return rgName, accountName, accountKey, fileShareName, diskName, subsID, tenantID, tokenFilePath, fmt.Errorf("clientID is empty for workload identity auth")
}
}
klog.V(2).Infof("mountWithWorkloadIdentityToken is specified, use workload identity auth for mount, clientID: %s, tenantID: %s", clientID, tenantID)
token, err := parseServiceAccountToken(serviceAccountToken)
if err != nil {
return rgName, accountName, accountKey, fileShareName, diskName, subsID, tenantID, tokenFilePath, fmt.Errorf("failed to parse service account token: %v", err)
}
tokenFileName := clientID + "-" + accountName
if !isValidTokenFileName(tokenFileName) {
return rgName, accountName, accountKey, fileShareName, diskName, subsID, tenantID, tokenFilePath, fmt.Errorf("invalid token file name(%s) generated for clientID(%s) and accountName(%s)", tokenFileName, clientID, accountName)
}
tokenFilePath = filepath.Join(defaultAzureOAuthTokenDir, tokenFileName)
// check whether token value is the same as the one in the token file
existingToken, readErr := os.ReadFile(tokenFilePath)
if readErr == nil && string(existingToken) == token {
klog.V(4).Infof("the token file(%s) already exists and the token value is the same, no need to rewrite the token file", tokenFilePath)
return rgName, accountName, accountKey, fileShareName, diskName, subsID, tenantID, "", nil
}
// write token to a file
if err := os.WriteFile(tokenFilePath, []byte(token), 0600); err != nil {
return rgName, accountName, accountKey, fileShareName, diskName, subsID, tenantID, tokenFilePath, fmt.Errorf("failed to write azure oAuth token file(%s): %v", tokenFilePath, err)
}
return rgName, accountName, accountKey, fileShareName, diskName, subsID, tenantID, tokenFilePath, err
}

if clientID != "" {
klog.V(2).Infof("clientID(%s) is specified, use service account token to get account key", clientID)
accountKey, err := d.cloud.GetStorageAccesskeyFromServiceAccountToken(ctx, subsID, accountName, rgName, clientID, tenantID, serviceAccountToken)
return rgName, accountName, accountKey, fileShareName, diskName, subsID, err
return rgName, accountName, accountKey, fileShareName, diskName, subsID, tenantID, "", err
}

if len(secrets) == 0 {
// if request context does not contain secrets, get secrets in the following order:
// 1. get account key from cache first
cache, errCache := d.accountCacheMap.Get(ctx, accountName, azcache.CacheReadTypeDefault)
if errCache != nil {
return rgName, accountName, accountKey, fileShareName, diskName, subsID, errCache
return rgName, accountName, accountKey, fileShareName, diskName, subsID, tenantID, tokenFilePath, errCache
}
if cache != nil {
accountKey = cache.(string)
Expand Down Expand Up @@ -939,7 +980,7 @@ func (d *Driver) GetAccountInfo(ctx context.Context, volumeID string, secrets, r
if err == nil && accountKey != "" {
d.accountCacheMap.Set(accountName, accountKey)
}
return rgName, accountName, accountKey, fileShareName, diskName, subsID, err
return rgName, accountName, accountKey, fileShareName, diskName, subsID, tenantID, tokenFilePath, err
}

func isSupportedProtocol(protocol string) bool {
Expand Down Expand Up @@ -1489,3 +1530,28 @@ func (d *Driver) createFolderIfNotExists(ctx context.Context, accountName, accou
klog.V(2).Infof("Successfully ensured folder path %s exists in share %s", folderName, fileShareName)
return nil
}

// serviceAccountToken represents the service account token sent from NodePublishVolume Request.
// ref: https://kubernetes-csi.github.io/docs/token-requests.html
type serviceAccountToken struct {
APIAzureADTokenExchange struct {
Token string `json:"token"`
ExpirationTimestamp time.Time `json:"expirationTimestamp"`
} `json:"api://AzureADTokenExchange"`
}

// parseServiceAccountToken parses the bound service account token from the token passed from NodePublishVolume Request.
// ref: https://kubernetes-csi.github.io/docs/token-requests.html
func parseServiceAccountToken(tokenStr string) (string, error) {
if len(tokenStr) == 0 {
return "", fmt.Errorf("service account token is empty")
}
token := serviceAccountToken{}
if err := json.Unmarshal([]byte(tokenStr), &token); err != nil {
return "", fmt.Errorf("failed to unmarshal service account tokens, error: %w", err)
}
if token.APIAzureADTokenExchange.Token == "" {
return "", fmt.Errorf("token for audience %s not found", DefaultTokenAudience)
}
return token.APIAzureADTokenExchange.Token, nil
}
101 changes: 100 additions & 1 deletion pkg/azurefile/azurefile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -838,6 +838,20 @@ func TestGetAccountInfo(t *testing.T) {
expectFileShareName: "test_sharename",
expectDiskName: "test_diskname",
},
{
volumeID: "invalid_mountWithWITokenField_value##",
rgName: "vol_2",
secrets: emptySecret,
reqContext: map[string]string{
shareNameField: "test_sharename",
mountWithWITokenField: "invalid",
},
expectErr: true,
err: fmt.Errorf("invalid %s: %s in volume context", mountWithWITokenField, "invalid"),
expectAccountName: "",
expectFileShareName: "test_sharename",
expectDiskName: "test_diskname",
},
}

for _, test := range tests {
Expand All @@ -847,7 +861,7 @@ func TestGetAccountInfo(t *testing.T) {
d.kubeClient = clientSet
d.cloud.Environment = &azclient.Environment{StorageEndpointSuffix: "abc"}
mockStorageAccountsClient.EXPECT().ListKeys(gomock.Any(), gomock.Any(), test.rgName).Return(key, nil).AnyTimes()
rgName, accountName, _, fileShareName, diskName, _, err := d.GetAccountInfo(context.Background(), test.volumeID, test.secrets, test.reqContext)
rgName, accountName, _, fileShareName, diskName, _, _, _, err := d.GetAccountInfo(context.Background(), test.volumeID, test.secrets, test.reqContext)
if test.expectErr && err == nil {
t.Errorf("Unexpected non-error")
continue
Expand Down Expand Up @@ -2052,3 +2066,88 @@ func TestGetInfoFromSnapshotID(t *testing.T) {
})
}
}

func TestParseServiceAccountToken(t *testing.T) {
tests := []struct {
name string
tokenStr string
expectedToken string
expectedError string
}{
{
name: "Empty token string",
tokenStr: "",
expectedToken: "",
expectedError: "service account token is empty",
},
{
name: "Invalid JSON",
tokenStr: "invalid-json",
expectedToken: "",
expectedError: "failed to unmarshal service account tokens",
},
{
name: "Valid token with audience",
tokenStr: `{"api://AzureADTokenExchange":{"token":"test-token-value","expirationTimestamp":"2025-01-01T00:00:00Z"}}`,
expectedToken: "test-token-value",
expectedError: "",
},
{
name: "Token with empty token value",
tokenStr: `{"api://AzureADTokenExchange":{"token":"","expirationTimestamp":"2025-01-01T00:00:00Z"}}`,
expectedToken: "",
expectedError: "token for audience api://AzureADTokenExchange/.default not found",
},
{
name: "Token with missing api://AzureADTokenExchange field",
tokenStr: `{"someOtherField":{"token":"test-token","expirationTimestamp":"2025-01-01T00:00:00Z"}}`,
expectedToken: "",
expectedError: "token for audience api://AzureADTokenExchange/.default not found",
},
{
name: "Token with partial JSON structure",
tokenStr: `{"api://AzureADTokenExchange":{}}`,
expectedToken: "",
expectedError: "token for audience api://AzureADTokenExchange/.default not found",
},
{
name: "Malformed JSON with extra characters",
tokenStr: `{"api://AzureADTokenExchange":{"token":"test-token"}}extra`,
expectedToken: "",
expectedError: "failed to unmarshal service account tokens",
},
{
name: "Token with special characters",
tokenStr: `{"api://AzureADTokenExchange":{"token":"eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3QifQ.eyJzdWIiOiIxMjM0NTY3ODkwIn0.test","expirationTimestamp":"2025-01-01T00:00:00Z"}}`,
expectedToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3QifQ.eyJzdWIiOiIxMjM0NTY3ODkwIn0.test",
expectedError: "",
},
{
name: "Token with unicode characters",
tokenStr: `{"api://AzureADTokenExchange":{"token":"test-token-.","expirationTimestamp":"2025-01-01T00:00:00Z"}}`,
expectedToken: "test-token-.",
expectedError: "",
},
{
name: "Token with whitespace in value",
tokenStr: `{"api://AzureADTokenExchange":{"token":" test-token ","expirationTimestamp":"2025-01-01T00:00:00Z"}}`,
expectedToken: " test-token ",
expectedError: "",
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
token, err := parseServiceAccountToken(test.tokenStr)

if test.expectedError != "" {
assert.Error(t, err)
assert.Contains(t, err.Error(), test.expectedError)
assert.Equal(t, "", token)
} else {
assert.NoError(t, err)
assert.Equal(t, test.expectedToken, token)
}
})
}
}
Loading
Loading