Skip to content

Commit a7aee7d

Browse files
committed
refactor isKataNode function and add tests
1 parent 6ff0679 commit a7aee7d

File tree

4 files changed

+137
-46
lines changed

4 files changed

+137
-46
lines changed

pkg/azurefile/azurefile.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ type Driver struct {
288288
endpoint string
289289
resolver Resolver
290290
directVolume DirectVolume
291+
isKataNode bool
291292
}
292293

293294
// NewDriver Creates a NewCSIDriver object. Assumes vendor version is equal to driver version &
@@ -329,6 +330,7 @@ func NewDriver(options *DriverOptions) *Driver {
329330
driver.endpoint = options.Endpoint
330331
driver.resolver = new(NetResolver)
331332
driver.directVolume = new(directVolume)
333+
driver.isKataNode = false
332334

333335
var err error
334336
getter := func(_ context.Context, _ string) (interface{}, error) { return nil, nil }
@@ -452,12 +454,8 @@ func (d *Driver) Run(ctx context.Context) error {
452454
csi.RegisterControllerServer(server, d)
453455
csi.RegisterNodeServer(server, d)
454456
d.server = server
455-
val, val2, err := getNodeInfoFromLabels(ctx, d.NodeID, d.kubeClient)
456-
if err != nil {
457-
klog.Warningf("failed to get node info from labels: %v", err)
458-
}
457+
d.isKataNode = isKataNode(ctx, d.NodeID, d.kubeClient)
459458

460-
klog.V(2).Infof("Node info from labels: %s, %s", val, val2)
461459
listener, err := csicommon.ListenEndpoint(d.endpoint)
462460
if err != nil {
463461
klog.Fatalf("failed to listen endpoint: %v", err)
@@ -1302,10 +1300,12 @@ func getNodeInfoFromLabels(ctx context.Context, nodeID string, kubeClient client
13021300
}
13031301

13041302
func isKataNode(ctx context.Context, nodeID string, kubeClient clientset.Interface) bool {
1305-
val, val2, err := getNodeInfoFromLabels(ctx, nodeID, kubeClient)
1303+
kataVMIsolationLabel, kataRuntimeLabel, err := getNodeInfoFromLabels(ctx, nodeID, kubeClient)
1304+
13061305
if err != nil {
1307-
klog.Warningf("get node(%s) confidential label failed with %v", nodeID, err)
1306+
klog.Warningf("failed to get node info from labels: %v", err)
13081307
return false
13091308
}
1310-
return val == "true" || val2 == "true"
1309+
1310+
return strings.EqualFold(kataVMIsolationLabel, "true") || strings.EqualFold(kataRuntimeLabel, "true")
13111311
}

pkg/azurefile/azurefile_test.go

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1632,3 +1632,97 @@ func TestGetFileShareClientForSub(t *testing.T) {
16321632
assert.Equal(t, tc.expectedError, err)
16331633
}
16341634
}
1635+
1636+
func TestGetNodeInfoFromLabels(t *testing.T) {
1637+
ctx := context.TODO()
1638+
1639+
// Test case where kubeClient is nil
1640+
_, _, err := getNodeInfoFromLabels(ctx, "test-node", nil)
1641+
if err == nil || err.Error() != "kubeClient is nil" {
1642+
t.Fatalf("expected error 'kubeClient is nil', got %v", err)
1643+
}
1644+
1645+
// Create a fake clientset
1646+
clientset := fake.NewSimpleClientset()
1647+
1648+
// Test case where the node does not exist
1649+
_, _, err = getNodeInfoFromLabels(ctx, "nonexistent-node", clientset)
1650+
if err == nil {
1651+
t.Fatalf("expected an error, got nil")
1652+
}
1653+
1654+
// Test case where node exists but has no labels
1655+
node := &v1api.Node{
1656+
ObjectMeta: metav1.ObjectMeta{
1657+
Name: "test-node",
1658+
Labels: map[string]string{},
1659+
},
1660+
}
1661+
_, err = clientset.CoreV1().Nodes().Create(ctx, node, metav1.CreateOptions{})
1662+
if err != nil {
1663+
t.Fatalf("expected no error, got %v", err)
1664+
}
1665+
1666+
_, _, err = getNodeInfoFromLabels(ctx, "test-node", clientset)
1667+
if err == nil || err.Error() != "node(test-node) label is empty" {
1668+
t.Fatalf("expected error 'node(test-node) label is empty', got %v", err)
1669+
}
1670+
1671+
// Test case where node has kata labels
1672+
node.Labels = map[string]string{
1673+
"kubernetes.azure.com/kata-mshv-vm-isolation": "true",
1674+
"katacontainers.io/kata-runtime": "false",
1675+
}
1676+
_, err = clientset.CoreV1().Nodes().Update(ctx, node, metav1.UpdateOptions{})
1677+
if err != nil {
1678+
t.Fatalf("expected no error, got %v", err)
1679+
}
1680+
1681+
kataVMIsolation, kataRuntime, err := getNodeInfoFromLabels(ctx, "test-node", clientset)
1682+
if err != nil {
1683+
t.Fatalf("expected no error, got %v", err)
1684+
}
1685+
1686+
if kataVMIsolation != "true" || kataRuntime != "false" {
1687+
t.Fatalf("expected (true, false), got (%v, %v)", kataVMIsolation, kataRuntime)
1688+
}
1689+
}
1690+
1691+
func TestIsKataNode(t *testing.T) {
1692+
ctx := context.TODO()
1693+
clientset := fake.NewSimpleClientset()
1694+
1695+
// Test case where node does not exist
1696+
if isKataNode(ctx, "nonexistent-node", clientset) {
1697+
t.Fatalf("expected false, got true")
1698+
}
1699+
1700+
// Create node without kata labels
1701+
node := &v1api.Node{
1702+
ObjectMeta: metav1.ObjectMeta{
1703+
Name: "test-node",
1704+
Labels: map[string]string{
1705+
"some-other-label": "value",
1706+
},
1707+
},
1708+
}
1709+
_, err := clientset.CoreV1().Nodes().Create(ctx, node, metav1.CreateOptions{})
1710+
if err != nil {
1711+
t.Fatalf("expected no error, got %v", err)
1712+
}
1713+
1714+
if isKataNode(ctx, "test-node", clientset) {
1715+
t.Fatalf("expected false, got true")
1716+
}
1717+
1718+
// Update node with kata labels
1719+
node.Labels["kubernetes.azure.com/kata-mshv-vm-isolation"] = "true"
1720+
_, err = clientset.CoreV1().Nodes().Update(ctx, node, metav1.UpdateOptions{})
1721+
if err != nil {
1722+
t.Fatalf("expected no error, got %v", err)
1723+
}
1724+
1725+
if !isKataNode(ctx, "test-node", clientset) {
1726+
t.Fatalf("expected true, got false")
1727+
}
1728+
}

pkg/azurefile/nodeserver.go

Lines changed: 31 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ import (
4343

4444
var getRuntimeClassForPodFunc = getRuntimeClassForPod
4545
var isConfidentialRuntimeClassFunc = isConfidentialRuntimeClass
46-
var isKataNodeFunc = isKataNode
4746

4847
// NodePublishVolume mount the volume from staging to target path
4948
func (d *Driver) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolumeRequest) (*csi.NodePublishVolumeResponse, error) {
@@ -101,42 +100,40 @@ func (d *Driver) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolu
101100
}
102101
}
103102

104-
if d.enableKataCCMount {
105-
enableKataCCMount := isKataNodeFunc(ctx, d.NodeID, d.kubeClient)
106-
if enableKataCCMount && context[podNameField] != "" && context[podNamespaceField] != "" {
107-
runtimeClass, err := getRuntimeClassForPodFunc(ctx, d.kubeClient, context[podNameField], context[podNamespaceField])
103+
enableKataCCMount := d.isKataNode && d.enableKataCCMount
104+
if enableKataCCMount && context[podNameField] != "" && context[podNamespaceField] != "" {
105+
runtimeClass, err := getRuntimeClassForPodFunc(ctx, d.kubeClient, context[podNameField], context[podNamespaceField])
106+
if err != nil {
107+
return nil, status.Errorf(codes.Internal, "failed to get runtime class for pod %s/%s: %v", context[podNamespaceField], context[podNameField], err)
108+
}
109+
klog.V(2).Infof("NodePublishVolume: volume(%s) mount on %s with runtimeClass %s", volumeID, target, runtimeClass)
110+
isConfidentialRuntimeClass, err := isConfidentialRuntimeClassFunc(ctx, d.kubeClient, runtimeClass)
111+
if err != nil {
112+
return nil, status.Errorf(codes.Internal, "failed to check if runtime class %s is confidential: %v", runtimeClass, err)
113+
}
114+
if isConfidentialRuntimeClass {
115+
klog.V(2).Infof("NodePublishVolume for volume(%s) where runtimeClass is %s", volumeID, runtimeClass)
116+
source := req.GetStagingTargetPath()
117+
if len(source) == 0 {
118+
return nil, status.Error(codes.InvalidArgument, "Staging target not provided")
119+
}
120+
// Load the mount info from staging area
121+
mountInfo, err := d.directVolume.VolumeMountInfo(source)
108122
if err != nil {
109-
return nil, status.Errorf(codes.Internal, "failed to get runtime class for pod %s/%s: %v", context[podNamespaceField], context[podNameField], err)
123+
return nil, status.Errorf(codes.Internal, "failed to load mount info from %s: %v", source, err)
124+
}
125+
if mountInfo == nil {
126+
return nil, status.Errorf(codes.Internal, "mount info is nil for volume %s", volumeID)
110127
}
111-
klog.V(2).Infof("NodePublishVolume: volume(%s) mount on %s with runtimeClass %s", volumeID, target, runtimeClass)
112-
isConfidentialRuntimeClass, err := isConfidentialRuntimeClassFunc(ctx, d.kubeClient, runtimeClass)
128+
data, err := json.Marshal(mountInfo)
113129
if err != nil {
114-
return nil, status.Errorf(codes.Internal, "failed to check if runtime class %s is confidential: %v", runtimeClass, err)
130+
return nil, status.Errorf(codes.Internal, "failed to marshal mount info %s: %v", source, err)
115131
}
116-
if isConfidentialRuntimeClass {
117-
klog.V(2).Infof("NodePublishVolume for volume(%s) where runtimeClass is %s", volumeID, runtimeClass)
118-
source := req.GetStagingTargetPath()
119-
if len(source) == 0 {
120-
return nil, status.Error(codes.InvalidArgument, "Staging target not provided")
121-
}
122-
// Load the mount info from staging area
123-
mountInfo, err := d.directVolume.VolumeMountInfo(source)
124-
if err != nil {
125-
return nil, status.Errorf(codes.Internal, "failed to load mount info from %s: %v", source, err)
126-
}
127-
if mountInfo == nil {
128-
return nil, status.Errorf(codes.Internal, "mount info is nil for volume %s", volumeID)
129-
}
130-
data, err := json.Marshal(mountInfo)
131-
if err != nil {
132-
return nil, status.Errorf(codes.Internal, "failed to marshal mount info %s: %v", source, err)
133-
}
134-
if err = d.directVolume.Add(target, string(data)); err != nil {
135-
return nil, status.Errorf(codes.Internal, "failed to save mount info %s: %v", target, err)
136-
}
137-
klog.V(2).Infof("NodePublishVolume: direct volume mount %s at %s successfully", source, target)
138-
return &csi.NodePublishVolumeResponse{}, nil
132+
if err = d.directVolume.Add(target, string(data)); err != nil {
133+
return nil, status.Errorf(codes.Internal, "failed to save mount info %s: %v", target, err)
139134
}
135+
klog.V(2).Infof("NodePublishVolume: direct volume mount %s at %s successfully", source, target)
136+
return &csi.NodePublishVolumeResponse{}, nil
140137
}
141138
}
142139
}
@@ -419,9 +416,9 @@ func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRe
419416
}
420417
klog.V(2).Infof("volume(%s) mount %s on %s succeeded", volumeID, source, cifsMountPath)
421418
}
422-
enableKataCCMount := isKataNodeFunc(ctx, d.NodeID, d.kubeClient)
419+
enableKataCCMount := d.isKataNode && d.enableKataCCMount
423420
// If runtime OS is not windows and protocol is not nfs, save mountInfo.json
424-
if d.enableKataCCMount && enableKataCCMount {
421+
if enableKataCCMount {
425422
if runtime.GOOS != "windows" && protocol != nfs {
426423
// Check if mountInfo.json is already present at the targetPath
427424
isMountInfoPresent, err := d.directVolume.VolumeMountInfo(cifsMountPath)

pkg/azurefile/nodeserver_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ func TestNodePublishVolume(t *testing.T) {
147147
mockDirectVolume := NewMockDirectVolume(ctrl)
148148
getRuntimeClassForPodFunc = mockGetRuntimeClassForPod
149149
isConfidentialRuntimeClassFunc = mockIsConfidentialRuntimeClass
150-
isKataNodeFunc = mockIsKataNodeAsFalse
150+
d.isKataNode = false
151151

152152
tests := []struct {
153153
desc string
@@ -274,7 +274,7 @@ func TestNodePublishVolume(t *testing.T) {
274274
VolumeContext: map[string]string{mountPermissionsField: "0755", podNameField: "testPod", podNamespaceField: "testNamespace"},
275275
},
276276
setup: func() {
277-
isKataNodeFunc = mockIsKataNodeAsTrue
277+
d.isKataNode = true
278278
d.directVolume = mockDirectVolume
279279
mockDirectVolume.EXPECT().VolumeMountInfo(sourceTest).Return(&volume.MountInfo{}, nil)
280280
mockDirectVolume.EXPECT().Add(targetTest, gomock.Any()).Return(nil)
@@ -470,7 +470,7 @@ func TestNodeStageVolume(t *testing.T) {
470470
defer ctrl.Finish()
471471
mockResolver := NewMockResolver(ctrl)
472472
mockDirectVolume := NewMockDirectVolume(ctrl)
473-
isKataNodeFunc = mockIsKataNodeAsFalse
473+
d.isKataNode = false
474474

475475
tests := []struct {
476476
desc string
@@ -760,7 +760,7 @@ func TestNodeStageVolume(t *testing.T) {
760760
d.resolver = mockResolver
761761
d.directVolume = mockDirectVolume
762762
if runtime.GOOS != "windows" {
763-
isKataNodeFunc = mockIsKataNodeAsTrue
763+
d.isKataNode = true
764764
mockIPAddr := &net.IPAddr{IP: net.ParseIP("192.168.1.1")}
765765
mockDirectVolume.EXPECT().VolumeMountInfo(sourceTest).Return(nil, nil)
766766
mockResolver.EXPECT().ResolveIPAddr("ip", "test_servername").Return(mockIPAddr, nil)

0 commit comments

Comments
 (0)