model.predict for graph classification #369
Digital-Chemist asked this question in Q&A
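Hi, that example uses DisjointLoader and graph-level predictions, so the model expects a batch index as an extra input: one entry per node, giving the index of the graph that node belongs to. For a single graph every node belongs to graph 0, so change the call to model.predict([graphX, graphA, tf.zeros(graphX.shape[0], dtype=tf.int64)])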
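To illustrate why that extra input is needed, here is a minimal NumPy sketch of what disjoint mode does. The helpers `to_disjoint_sketch` and `global_mean_pool_sketch` are hypothetical stand-ins written for this illustration, not Spektral's own functions: the first stacks graphs into one block-diagonal graph plus a batch index, and the second averages node features per graph using that index, roughly the way a global pooling layer in disjoint mode reads its second input.

```python
import numpy as np


# Hypothetical helper (not Spektral's API): mimics how disjoint mode
# stacks several graphs into one big graph plus a batch index.
def to_disjoint_sketch(x_list, a_list):
    """Stack node features, build a block-diagonal adjacency and a batch index."""
    x = np.concatenate(x_list, axis=0)
    n_nodes = [xi.shape[0] for xi in x_list]
    n_total = sum(n_nodes)
    a = np.zeros((n_total, n_total), dtype=a_list[0].dtype)
    offset = 0
    for ai, n in zip(a_list, n_nodes):
        # Each graph's adjacency occupies its own diagonal block.
        a[offset:offset + n, offset:offset + n] = ai
        offset += n
    # i[k] is the index of the graph that node k belongs to.
    i = np.repeat(np.arange(len(x_list)), n_nodes)
    return x, a, i


# Hypothetical helper: per-graph mean of node features, driven by the batch index.
def global_mean_pool_sketch(x, i):
    """Average node features per graph, as a disjoint-mode global pooling does."""
    n_graphs = int(i.max()) + 1
    sums = np.zeros((n_graphs, x.shape[1]), dtype=x.dtype)
    np.add.at(sums, i, x)  # sums[i[k]] += x[k] for every node k
    counts = np.bincount(i, minlength=n_graphs).astype(x.dtype)
    return sums / counts[:, None]


# A single graph with 3 nodes and 2 features per node.
x_single = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
a_single = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)

x, a, i = to_disjoint_sketch([x_single], [a_single])
# With only one graph, every node belongs to graph 0, so the batch index
# is all zeros: the role played by tf.zeros(graphX.shape[0]) above.
print(i)                              # [0 0 0]
print(global_mean_pool_sketch(x, i))  # [[3. 4.]]

# With two graphs, the index separates their nodes.
x2 = np.array([[10.0, 0.0], [20.0, 0.0]])
a2 = np.array([[0, 1], [1, 0]], dtype=float)
xb, ab, ib = to_disjoint_sketch([x_single, x2], [a_single, a2])
print(ib)                               # [0 0 0 1 1]
print(global_mean_pool_sketch(xb, ib))  # one row per graph: [3, 4] and [15, 0]
```

Without the index, the pooling step has no way to know which nodes form which graph, which matches the `I = inputs[1]` line failing in the traceback from the question.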
I'm not exactly sure how to use model.predict() for graph classification. I've trained a GeneralGNN model on a custom dataset of about 130 graphs, following the framework in this example: https://github.com/danielegrattarola/spektral/blob/master/examples/graph_prediction/general_gnn.py. Training and testing work without issues. If I then have a single unlabeled Graph (no y label), built exactly the same way as the training/testing graphs except without the label, shouldn't I just be able to feed it into model.predict?
This gives an error:
global_pool.py", line 30, in call *
    I = inputs[1]
IndexError: list index out of range

If I instead put the single Graph into a DisjointLoader, created the same way as the one for the training/testing set, I get an error:
Both x and a are definitely in the single dataset: print(pred_dataset[0].x) and print(pred_dataset[0].a) show the expected arrays with the expected sizes.
What am I missing here? Any help would be greatly appreciated.