Skip to content

Commit 12fef6d

Browse files
committed
Add parameter validation to hipGraph* API functions
Adds null-pointer checks to hipGraphCreate, hipGraphAddEmptyNode, hipGraphAddMemAllocNode, hipGraphAddMemFreeNode, hipGraphExecDestroy, hipGraphNodeFindInClone, hipGraphNodeGetDependencies, hipGraphNodeGetDependentNodes. Fixes bitwise AND bug in hipGraphAddMemcpyNode dependency check. Fixes #555
1 parent ae61cce commit 12fef6d

File tree

1 file changed

+31
-1
lines changed

1 file changed

+31
-1
lines changed

src/CHIPBindings.cc

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,10 @@ hipError_t hipGraphAddMemAllocNode(hipGraphNode_t *pGraphNode, hipGraph_t graph,
200200
CHIP_TRY
201201
LOCK(ApiMtx);
202202
CHIPInitialize();
203+
if (!pGraphNode)
204+
RETURN(hipErrorInvalidValue);
205+
if (!graph)
206+
RETURN(hipErrorInvalidValue);
203207
UNIMPLEMENTED(hipErrorNotSupported);
204208
CHIP_CATCH
205209
}
@@ -210,6 +214,10 @@ hipError_t hipGraphAddMemFreeNode(hipGraphNode_t *pGraphNode, hipGraph_t graph,
210214
CHIP_TRY
211215
LOCK(ApiMtx);
212216
CHIPInitialize();
217+
if (!pGraphNode)
218+
RETURN(hipErrorInvalidValue);
219+
if (!graph)
220+
RETURN(hipErrorInvalidValue);
213221
UNIMPLEMENTED(hipErrorNotSupported);
214222
CHIP_CATCH
215223
}
@@ -1105,6 +1113,8 @@ hipError_t hipGraphCreate(hipGraph_t *pGraph, unsigned int flags) {
11051113
CHIP_TRY
11061114
LOCK(ApiMtx);
11071115
CHIPInitialize();
1116+
if (!pGraph)
1117+
RETURN(hipErrorInvalidValue);
11081118
CHIPGraph *Graph = new CHIPGraph();
11091119
*pGraph = Graph;
11101120
RETURN(hipSuccess);
@@ -1225,6 +1235,10 @@ hipError_t hipGraphNodeGetDependencies(hipGraphNode_t node,
12251235
CHIP_TRY
12261236
LOCK(ApiMtx);
12271237
CHIPInitialize();
1238+
if (!node)
1239+
RETURN(hipErrorInvalidValue);
1240+
if (!pNumDependencies)
1241+
RETURN(hipErrorInvalidValue);
12281242
auto Deps = NODE(node)->getDependencies();
12291243
*pNumDependencies = Deps.size();
12301244
if (!pDependencies)
@@ -1242,6 +1256,10 @@ hipError_t hipGraphNodeGetDependentNodes(hipGraphNode_t node,
12421256
CHIP_TRY
12431257
LOCK(ApiMtx);
12441258
CHIPInitialize();
1259+
if (!node)
1260+
RETURN(hipErrorInvalidValue);
1261+
if (!pNumDependentNodes)
1262+
RETURN(hipErrorInvalidValue);
12451263
auto Deps = NODE(node)->getDependants();
12461264
*pNumDependentNodes = Deps.size();
12471265
if (!pDependentNodes)
@@ -1343,6 +1361,12 @@ hipError_t hipGraphNodeFindInClone(hipGraphNode_t *pNode,
13431361
CHIP_TRY
13441362
LOCK(ApiMtx);
13451363
CHIPInitialize();
1364+
if (!pNode)
1365+
RETURN(hipErrorInvalidValue);
1366+
if (!originalNode)
1367+
RETURN(hipErrorInvalidValue);
1368+
if (!clonedGraph)
1369+
RETURN(hipErrorInvalidValue);
13461370
auto Node = GRAPH(clonedGraph)->getClonedNodeFromOriginal(NODE(originalNode));
13471371
*pNode = Node;
13481372
RETURN(hipSuccess);
@@ -1406,6 +1430,8 @@ hipError_t hipGraphExecDestroy(hipGraphExec_t graphExec) {
14061430
CHIP_TRY
14071431
LOCK(ApiMtx);
14081432
CHIPInitialize();
1433+
if (!graphExec)
1434+
RETURN(hipErrorInvalidValue);
14091435
delete graphExec;
14101436
RETURN(hipSuccess);
14111437
CHIP_CATCH
@@ -1640,7 +1666,7 @@ hipError_t hipGraphAddMemcpyNode(hipGraphNode_t *pGraphNode, hipGraph_t graph,
16401666
// NULLCHECK(graph, pGraphNode, pCopyParams);
16411667
if (!graph || !pGraphNode || !pCopyParams)
16421668
RETURN(hipErrorInvalidValue);
1643-
if (!pDependencies & numDependencies > 0)
1669+
if (!pDependencies && numDependencies > 0)
16441670
CHIPERR_LOG_AND_THROW(
16451671
"numDependencies is not 0 while pDependencies is null",
16461672
hipErrorInvalidValue);
@@ -2204,6 +2230,10 @@ hipError_t hipGraphAddEmptyNode(hipGraphNode_t *pGraphNode, hipGraph_t graph,
22042230
CHIP_TRY
22052231
LOCK(ApiMtx);
22062232
CHIPInitialize();
2233+
if (!pGraphNode)
2234+
RETURN(hipErrorInvalidValue);
2235+
if (!graph)
2236+
RETURN(hipErrorInvalidValue);
22072237
CHIPGraphNodeEmpty *Node = new CHIPGraphNodeEmpty();
22082238
Node->addDependencies(DECONST_NODES(pDependencies), numDependencies);
22092239
*pGraphNode = Node;

0 commit comments

Comments
 (0)