Node Classification Using Graph Convolutional Network

This example shows how to classify nodes in a graph using a graph convolutional network (GCN).

In the node classification task, an algorithm, in this example a GCN [1], must predict the labels of unlabelled nodes in a graph. In this example, each molecule is represented by a graph: atoms in the molecule are the nodes, and the chemical bonds between atoms are the edges. Node labels are the atom types, for example, Carbon. As such, the inputs to the GCN are molecules and the outputs are predictions of the atom type of each unlabelled atom in the molecule.

To assign a categorical label to each node of a graph, the GCN models a function f(X,A) on a graph G=(V,E), where V denotes the set of nodes and E denotes the set of edges, such that f(X,A) takes as input:

  • X: A feature matrix of dimension N×C, where N=|V| is the number of nodes in G and C is the number of input channels (features) per node.

  • A: An adjacency matrix of dimension N×N representing E and describing the structure of G.

and returns an output:

  • Z: An embedding or feature matrix of dimension N×F, where F is the number of output features per node. In other words, Z contains the network predictions and F is the number of classes.

The model f(X,A) is based on spectral graph convolution, with weights/filter parameters shared over all locations in G. The model can be represented as a layer-wise propagation model, such that the output of layer l+1 is expressed as

$$Z^{(l+1)} = \sigma\left(\hat{D}^{-1/2}\,\hat{A}\,\hat{D}^{-1/2}\,Z^{(l)}\,W^{(l)}\right),$$

where

  • $\sigma$ is an activation function.

  • $Z^{(l)}$ is the activation matrix of layer $l$, with $Z^{(1)} = X$.

  • $W^{(l)}$ is the weight matrix of layer $l$.

  • $\hat{A} = A + I_N$ is the adjacency matrix of graph G with added self-connections, where $I_N$ is the identity matrix.

  • $\hat{D}$ is the degree matrix of $\hat{A}$.

The expression $\hat{D}^{-1/2}\hat{A}\hat{D}^{-1/2}$ is referred to as the normalized adjacency matrix of the graph.
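To make the normalization concrete, the following is a minimal sketch that applies the formula to a hypothetical 3-node path graph. The normalizeAdjacency function at the end of this example performs the same computation on sparse matrices.

% Minimal sketch: normalize the adjacency matrix of a hypothetical
% 3-node path graph (nodes 1-2-3).
A = [0 1 0; 1 0 1; 0 1 0];
Ahat = A + eye(3);                % add self-connections
degree = sum(Ahat, 2);            % node degrees of Ahat
Dinvsqrt = diag(1./sqrt(degree)); % D^(-1/2)
normA = Dinvsqrt * Ahat * Dinvsqrt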

The GCN model in this example is a variant of the standard GCN model described above. The variant uses residual connections between layers [1]. The residual connections enable the model to carry over information from the previous layer's input. Therefore, the output of layer $l+1$ of the GCN model in this example is

$$Z^{(l+1)} = \sigma\left(\hat{D}^{-1/2}\,\hat{A}\,\hat{D}^{-1/2}\,Z^{(l)}\,W^{(l)}\right) + Z^{(l)}.$$

See [1] for more details about the GCN model.

This example uses the QM7 dataset [2] [3], a molecular dataset consisting of 7165 molecules composed of up to 23 atoms each. Overall, the dataset contains 5 unique atom types: Carbon, Hydrogen, Nitrogen, Oxygen, and Sulphur.

Download and Load QM7 Data

Download the QM7 dataset from the following URL:

dataURL = 'http://quantum-machine.org/data/qm7.mat';
outputFolder = fullfile(tempdir,'qm7Data');
dataFile = fullfile(outputFolder,'qm7.mat');

if ~exist(dataFile, 'file')
    % Create the output folder if it does not already exist
    if ~exist(outputFolder, 'dir')
        mkdir(outputFolder);
    end
    fprintf('Downloading file ''%s'' ...\n', dataFile);
    websave(dataFile, dataURL);
end

Load QM7 data.

data = load(dataFile)
data = struct with fields:
    X: [7165×23×23 single]
    R: [7165×23×3 single]
    Z: [7165×23 single]
    T: [1×7165 single]
    P: [5×1433 int64]

The data consists of five different arrays. This example uses the arrays in fields X and Z of the struct data. The array in X contains the Coulomb matrix [3] representation of each of the 7165 molecules, and the array in Z contains the atomic number of each atom in the molecules. The adjacency matrices of the graphs representing the molecules, and the feature matrices of the graphs, are extracted from the Coulomb matrices. The categorical array of labels is extracted from the array in Z.

Note that, for any molecule with fewer than 23 atoms, the data is padded with zeros. For example, the data representing the atomic numbers of atoms in the molecule at index 1 is

data.Z(1,:)
ans = 1×23 single row vector

     6     1     1     1     1     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0     0

This shows that this molecule is composed of five atoms: one atom with atomic number 6 (carbon) and four atoms with atomic number 1 (hydrogen). The remaining 18 entries are padded zeros.
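Because zeros appear in this array only as padding, you can count the atoms in each molecule by counting its nonzero atomic numbers. As a quick, illustrative check that the largest molecule has 23 atoms:

% Count atoms per molecule; zeros in data.Z are padding
numAtoms = sum(data.Z > 0, 2);
max(numAtoms)   % expected to be 23, the maximum molecule size in QM7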

Extract and Preprocess Graph Data

To extract the graph data, get the Coulomb matrices and the atomic numbers. Permute the data representing the Coulomb matrices and change the datatype to double. Sort the data representing the atomic numbers so that it matches the data representing the Coulomb matrices.

coulombData = double(permute(data.X, [2 3 1]));
atomicNumber = sort(data.Z,2,'descend'); 

Reformat the Coulomb matrix representation of the molecules to binary adjacency matrices using the coloumb2Adjacency function attached to this example as a supporting file.

adjacencyData = coloumb2Adjacency(coulombData, atomicNumber);

Note that the coloumb2Adjacency function does not remove the padded zeros from the data. They are left in intentionally to make it easier to split the data into separate molecules for training, validation, and inference. Therefore, ignoring the padded zeros, the adjacency matrix of the graph representing the molecule at index 1, which consists of 5 atoms, is

adjacencyData(1:5,1:5,1)
ans = 5×5

     0     1     1     1     1
     1     0     0     0     0
     1     0     0     0     0
     1     0     0     0     0
     1     0     0     0     0

Before preprocessing the data, use the splitData function, provided at the end of the example, to randomly select and split the data into training, validation and test data. The function uses the ratio 80:10:10 to split the data.

The adjacencyDataSplit output of the splitData function is the adjacencyData input data split into three different arrays. Likewise, the coulombDataSplit and atomicNumberSplit outputs are the coulombData and atomicNumber input data split into three different arrays respectively.

[adjacencyDataSplit, coulombDataSplit, atomicNumberSplit] = splitData(adjacencyData, coulombData, atomicNumber);

Use the preprocessData function, provided at the end of the example, to process the adjacencyDataSplit, coulombDataSplit, and atomicNumberSplit and return the adjacency matrix adjacency, feature matrix features, and categorical array labels.

The preprocessData function builds a sparse block-diagonal matrix from the adjacency matrices of the different graph instances, such that each block in the matrix corresponds to the adjacency matrix of one graph instance. This preprocessing is required because the GCN accepts a single adjacency matrix as input, whereas this example deals with multiple graph instances. The function takes the non-zero diagonal elements of the Coulomb matrices and assigns them as features. Therefore, the number of input features per node in this example is 1.

[adjacency, features, labels] = cellfun(@preprocessData, adjacencyDataSplit, coulombDataSplit, atomicNumberSplit, 'UniformOutput', false);

View the adjacency matrices of the training, validation, and test data.

adjacency
adjacency=1×3 cell array
    {88722×88722 double}    {10942×10942 double}    {10986×10986 double}

This shows that there are 88722 nodes in the training data, 10942 nodes in the validation data, and 10986 nodes in the test data.
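To see how the block-diagonal construction combines multiple graphs into a single adjacency matrix, here is a minimal sketch with two small hypothetical graphs; the preprocessData function applies the same idea with blkdiag across all molecules in a partition.

% Minimal sketch: combine two hypothetical graphs into one
% block-diagonal adjacency matrix, as preprocessData does per molecule.
A1 = [0 1; 1 0];               % 2-node graph
A2 = [0 1 1; 1 0 0; 1 0 0];    % 3-node graph
Acombined = blkdiag(sparse(A1), sparse(A2));
full(Acombined)                % 5-by-5 matrix with one block per graph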

Normalize the feature array using the normalizeFeatures function provided at the end of the example.

features = normalizeFeatures(features);

Get the training and the validation data.

featureTrain = features{1};
adjacencyTrain = adjacency{1};
targetTrain = labels{1};

featureValidation = features{2};
adjacencyValidation = adjacency{2};
targetValidation = labels{2};

Visualize Data and Data Statistics

Sample and specify indices of molecules to visualize.

For each specified index

  • Remove the padded zeros from the unprocessed atomic number data atomicNumber and the unprocessed adjacency matrix adjacencyData of the sampled molecule. The unprocessed data is used here for easy sampling.

  • Convert the adjacency matrix to a graph using the graph function.

  • Convert the atomic numbers to symbols.

  • Plot the graph using the atomic symbols as node labels.

idx = [1 5 300 1159];
for j = 1:numel(idx)
    % Remove padded zeros from the data
    atomicNum = nonzeros(atomicNumber(idx(j),:));
    numOfNodes = numel(atomicNum);
    adj = adjacencyData(1:numOfNodes,1:numOfNodes,idx(j));
    
    % Convert adjacency matrix to graph
    compound = graph(adj);
    
    % Convert atomic numbers to symbols
    symbols = cell(numOfNodes, 1);
    for i = 1:numOfNodes
        if atomicNum(i) == 1
            symbols{i} = 'H';
        elseif atomicNum(i) == 6
            symbols{i} = 'C';
        elseif atomicNum(i) == 7
            symbols{i} = 'N';
        elseif atomicNum(i) == 8
            symbols{i} = 'O';
        else
            symbols{i} = 'S';
        end
    end
    
    % Plot graph
    subplot(2,2,j)
    plot(compound, 'NodeLabel', symbols, 'LineWidth', 0.75, ...
    'Layout', 'force')
    title("Molecule " + idx(j))
end

Get all the labels and the classes.

labelsAll = cat(1,labels{:});
classes = categories(labelsAll)
classes = 5×1 cell
    {'Hydrogen'}
    {'Carbon'  }
    {'Nitrogen'}
    {'Oxygen'  }
    {'Sulphur' }

Visualize frequency of each label category using a histogram.

figure
histogram(labelsAll)
xlabel('Category')
ylabel('Frequency')
title('Label Counts')

Define Model Function

Create the function model, provided at the end of the example, that takes the feature data dlX, the adjacency matrix A, and the model parameters parameters as input and returns the label predictions.

Initialize Model Parameters

Set the number of input features per node. This is the number of columns of the feature matrix.

numInputFeatures = size(featureTrain,2)
numInputFeatures = 1

Set the number of feature maps for the hidden layers.

numHiddenFeatureMaps = 32;

Set the number of output features as the number of categories.

numOutputFeatures = numel(classes)
numOutputFeatures = 5

Create a struct parameters containing the model weights. Initialize the weights using the initializeGlorot function attached to this example as a supporting file.

sz = [numInputFeatures numHiddenFeatureMaps];
numOut = numHiddenFeatureMaps;
numIn = numInputFeatures;
parameters.W1 = initializeGlorot(sz,numOut,numIn,'double');

sz = [numHiddenFeatureMaps numHiddenFeatureMaps];
numOut = numHiddenFeatureMaps;
numIn = numHiddenFeatureMaps;
parameters.W2 = initializeGlorot(sz,numOut,numIn,'double');

sz = [numHiddenFeatureMaps numOutputFeatures];
numOut = numOutputFeatures;
numIn = numHiddenFeatureMaps;
parameters.W3 = initializeGlorot(sz,numOut,numIn,'double');
parameters
parameters = struct with fields:
    W1: [1×32 dlarray]
    W2: [32×32 dlarray]
    W3: [32×5 dlarray]
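The initializeGlorot supporting file is not listed in this example. As a rough sketch of what a Glorot (Xavier) uniform initializer can look like, the following hypothetical function samples weights from a uniform distribution with bound sqrt(6/(numIn + numOut)); the actual supporting function may differ in detail.

function weights = initializeGlorotSketch(sz,numOut,numIn,className)
% Hypothetical sketch of a Glorot (Xavier) uniform initializer; the
% initializeGlorot supporting file attached to the example may differ.
bound = sqrt(6/(numIn + numOut));
weights = bound*(2*rand(sz,className) - 1);  % uniform in [-bound, bound]
weights = dlarray(weights);
end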

Define Model Gradients Function

Create the function modelGradients, provided at the end of the example, that takes the feature data dlX, the adjacency matrix adjacencyTrain, the one-hot encoded targets T of the labels, and the model parameters parameters as input and returns the gradients of the loss with respect to the parameters, the corresponding loss, and the network predictions.

Specify Training Options

Train for 1500 epochs and set the learning rate for the Adam solver to 0.01.

numEpochs = 1500;
learnRate = 0.01;

Validate the network after every 300 epochs.

validationFrequency = 300;

Visualize the training progress in a plot.

plots = "training-progress";

To train on a GPU if one is available, specify the execution environment "auto". Using a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Support by Release (Parallel Computing Toolbox).

executionEnvironment = "auto";

Train Model

Train the model using a custom training loop. The training uses full-batch gradient descent.

Initialize the training progress plot.

if plots == "training-progress"
    figure
    
    % Accuracy.
    subplot(2,1,1)
    lineAccuracyTrain = animatedline('Color',[0 0.447 0.741]);
    lineAccuracyValidation = animatedline( ...
        'LineStyle','--', ...
        'Marker','o', ...
        'MarkerFaceColor','black');
    ylim([0 1])
    xlabel("Epoch")
    ylabel("Accuracy")
    grid on
    
    % Loss.
    subplot(2,1,2)
    lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
    lineLossValidation = animatedline( ...
        'LineStyle','--', ...
        'Marker','o', ...
        'MarkerFaceColor','black');
    ylim([0 inf])
    xlabel("Epoch")
    ylabel("Loss")
    grid on
end

Initialize parameters for Adam.

trailingAvg = [];
trailingAvgSq = [];

Convert training and validation feature data to dlarray.

dlX = dlarray(featureTrain);
dlXValidation = dlarray(featureValidation);

For GPU training, convert data to gpuArray objects.

if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlX = gpuArray(dlX);
end

Encode training and validation label data using onehotencode.

T = onehotencode(targetTrain, 2, 'ClassNames', classes);
TValidation = onehotencode(targetValidation, 2, 'ClassNames', classes);
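For reference, onehotencode expands each categorical label into a row vector with a 1 in the column corresponding to its class. A minimal illustration with a few hypothetical labels:

% Illustration: one-hot encode three hypothetical labels along dimension 2
tmpLabels = categorical(["Carbon"; "Hydrogen"; "Hydrogen"], classes);
onehotencode(tmpLabels, 2, 'ClassNames', classes)  % 3-by-5 matrix of 0s and 1s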

Train the model.

For each epoch

  • Evaluate the model gradients and loss using dlfeval and the modelGradients function.

  • Update the network parameters using adamupdate.

  • Compute the training accuracy score using the accuracy function provided at the end of the example. The function takes the network predictions, the target containing the labels, and the categories classes as inputs and returns the accuracy score.

  • If required, validate the network by making predictions using the model function and computing the validation loss and the validation accuracy score using crossentropy and the accuracy function.

  • Update the training plot.

start = tic;
% Loop over epochs.
for epoch = 1:numEpochs
    
    % Evaluate the model gradients and loss using dlfeval and the
    % modelGradients function.
    [gradients, loss, dlYPred] = dlfeval(@modelGradients, dlX, adjacencyTrain, T, parameters);
    
    % Update the network parameters using the Adam optimizer.
    [parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ...
        trailingAvg,trailingAvgSq,epoch,learnRate);
    
    % Display the training progress.
    if plots == "training-progress"
        subplot(2,1,1)
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        title("Epoch: " + epoch + ", Elapsed: " + string(D))

        % Loss.
        addpoints(lineLossTrain,epoch,double(gather(extractdata(loss))))

        % Accuracy score.
        score = accuracy(dlYPred, targetTrain, classes);
        addpoints(lineAccuracyTrain,epoch,double(gather(score)))

        drawnow

        % Display validation metrics.
        if epoch == 1 || mod(epoch,validationFrequency) == 0
            % Loss.
            dlYPredValidation = model(dlXValidation, adjacencyValidation, parameters);
            lossValidation = crossentropy(dlYPredValidation, TValidation, 'DataFormat', 'BC');
            addpoints(lineLossValidation,epoch,double(gather(extractdata(lossValidation))))

            % Accuracy score.
            scoreValidation = accuracy(dlYPredValidation, targetValidation, classes);
            addpoints(lineAccuracyValidation,epoch,double(gather(scoreValidation)))

            drawnow
        end
    end
end

Test Model

Test the model using the test data.

featureTest = features{3};
adjacencyTest = adjacency{3};
targetTest = labels{3};

Convert the test feature data to dlarray.

dlXTest = dlarray(featureTest);

Make predictions on the data.

dlYPredTest = model(dlXTest, adjacencyTest, parameters);

Calculate the accuracy score using the accuracy function. The accuracy function also returns the decoded network predictions predTest as class labels. The network predictions are decoded using onehotdecode.

[scoreTest, predTest] = accuracy(dlYPredTest, targetTest, classes);

View the accuracy score.

scoreTest
scoreTest = 0.9053

Visualize Predictions

To visualize the accuracy score for each category, compute the class-wise accuracy scores and visualize them using a histogram.

numOfSamples = numel(targetTest);
classTarget = zeros(numOfSamples, numOutputFeatures);
classPred = zeros(numOfSamples, numOutputFeatures);
for i = 1:numOutputFeatures
    classTarget(:,i) = targetTest==categorical(classes(i));
    classPred(:,i) = predTest==categorical(classes(i));
end

% Compute class-wise accuracy score
classAccuracy = sum(classPred == classTarget)./numOfSamples;

% Visualize class-wise accuracy score
figure
[~,idx] = sort(classAccuracy,'descend');
histogram('Categories',classes(idx), ...
    'BinCounts',classAccuracy(idx), ...
    'BarWidth',0.8)
xlabel("Category")
ylabel("Accuracy")
title("Class Accuracy Score")

The class-wise accuracy scores show how the model makes correct predictions using both the true positives and the true negatives. A true positive is an outcome where the model correctly predicts a class as present in an observation. A true negative is an outcome where the model correctly predicts a class as absent in an observation.

To visualize how the model makes incorrect predictions and evaluate the model based on class-wise precision and class-wise recall, calculate the confusion matrix using confusionmat and visualize the results using confusionchart.

Class-wise precision is the ratio of true positives to total positive predictions for a class. The total positive predictions include the true positives and false positives. A false positive is an outcome where the model incorrectly predicts a class as present in an observation.

Class-wise recall, also known as the true positive rate, is the ratio of true positives to total positive observations for a class. The total positive observations include the true positives and false negatives. A false negative is an outcome where the model incorrectly predicts a class as absent in an observation.

[confusionMatrix, order] = confusionmat(targetTest, predTest);
figure
cm = confusionchart(confusionMatrix, classes, ...
    'ColumnSummary','column-normalized', ...
    'RowSummary','row-normalized', ...
    'Title', 'GCN QM7 Confusion Chart');

The class-wise precision scores are in the first row of the column summary of the chart, and the class-wise recall scores are in the first column of the row summary of the chart.
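If you want the class-wise precision and recall as numeric arrays rather than reading them off the chart, you can compute them directly from confusionMatrix. A minimal sketch:

% Compute class-wise precision and recall from the confusion matrix.
% Rows of confusionMatrix correspond to true classes and columns to
% predicted classes.
truePositives = diag(confusionMatrix);
precision = truePositives ./ sum(confusionMatrix, 1)';  % per predicted class
recall = truePositives ./ sum(confusionMatrix, 2);      % per true class
table(string(order), precision, recall, 'VariableNames', ["Class" "Precision" "Recall"])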

Split Data Function

The splitData function takes the adjacencyData, coulombData, and atomicNumber data and randomly splits them into training, validation, and test data in the ratio 80:10:10. The function returns the corresponding split data adjacencyDataSplit, coulombDataSplit, and atomicNumberSplit as cell arrays.

function [adjacencyDataSplit, coulombDataSplit, atomicNumberSplit] = splitData(adjacencyData, coulombData, atomicNumber)

adjacencyDataSplit = cell(1,3);
coulombDataSplit = cell(1,3);
atomicNumberSplit = cell(1,3);

numMolecules = size(adjacencyData, 3);

% Set initial random state for example reproducibility.
rng(0);

% Get training data
idx = randperm(size(adjacencyData, 3), floor(0.8*numMolecules));
adjacencyDataSplit{1} = adjacencyData(:,:,idx);
coulombDataSplit{1} = coulombData(:,:,idx);
atomicNumberSplit{1} = atomicNumber(idx,:);
adjacencyData(:,:,idx) = [];
coulombData(:,:,idx) = [];
atomicNumber(idx,:) = [];

% Get validation data
idx = randperm(size(adjacencyData, 3), floor(0.1*numMolecules));
adjacencyDataSplit{2} = adjacencyData(:,:,idx);
coulombDataSplit{2} = coulombData(:,:,idx);
atomicNumberSplit{2} = atomicNumber(idx,:);
adjacencyData(:,:,idx) = [];
coulombData(:,:,idx) = [];
atomicNumber(idx,:) = [];

% Get test data
adjacencyDataSplit{3} = adjacencyData;
coulombDataSplit{3} = coulombData;
atomicNumberSplit{3} = atomicNumber;

end

Preprocess Data Function

The preprocessData function preprocesses the input data as follows:

For each graph/molecule

  • Remove padded zeros from atomicNumber.

  • Concatenate the atomic number data with the atomic number data of other graph instances. It is necessary to concatenate the data since the example deals with multiple graph instances.

  • Remove padded zeros from adjacencyData.

  • Build a sparse block-diagonal matrix of the adjacency matrices of different graph instances. Each block in the matrix corresponds to the adjacency matrix of one graph instance. This step is also necessary because there are multiple graph instances in the example.

  • Extract feature array from coulombData. The feature array is the non-zero diagonal elements of the Coulomb matrix in coulombData.

  • Concatenate the feature array with feature arrays of other graph instances.

The function then converts the atomic number data to categorical arrays.

function [adjacency, features, labels] = preprocessData(adjacencyData, coulombData, atomicNumber)

adjacency = sparse([]);
features = [];
labels = [];
for i = 1:size(adjacencyData, 3)
    % Remove padded zeros from atomicNumber
    tmpLabels = nonzeros(atomicNumber(i,:));
    labels = [labels; tmpLabels];
    
    % Get the indices of the un-padded data
    validIdx = 1:numel(tmpLabels);
    
    % Use the indices for un-padded data to remove padded zeros
    % from the adjacency data
    tmpAdjacency = adjacencyData(validIdx, validIdx, i);
    
    % Build the adjacency matrix into a block diagonal matrix
    adjacency = blkdiag(adjacency, tmpAdjacency);
    
    % Remove padded zeros from coulombData and extract the
    % feature array
    tmpFeatures = diag(coulombData(validIdx, validIdx, i));
    features = [features; tmpFeatures];
end

% Convert labels to categorical array
atomicNumbers = unique(labels);
atomNames = ["Hydrogen","Carbon","Nitrogen","Oxygen","Sulphur"];
labels = categorical(labels, atomicNumbers, atomNames);

end

Normalize Features Function

The normalizeFeatures function standardizes the input training, validation, and test feature data features using the mean and variance of the training data.

function features = normalizeFeatures(features)

% Get the mean and variance from the training data
meanFeatures = mean(features{1});
varFeatures = var(features{1}, 1);

% Standardize training, validation and test data
for i = 1:3
    features{i} = (features{i} - meanFeatures)./sqrt(varFeatures);
end

end

Model Function

The model function takes the feature matrix dlX, the adjacency matrix A, and the model parameters parameters and returns the network predictions. In a preprocessing step, the model function calculates the normalized adjacency matrix described earlier using the normalizeAdjacency function, also provided at the end of the example.

function dlY = model(dlX, A, parameters)

% Normalize adjacency matrix
L = normalizeAdjacency(A);

Z1 = dlX;

Z2 = L * Z1 * parameters.W1;
Z2 = relu(Z2) + Z1;

Z3 = L * Z2 * parameters.W2;
Z3 = relu(Z3) + Z2;

Z4 = L * Z3 * parameters.W3;
dlY = softmax(Z4, 'DataFormat', 'BC');

end

Normalize Adjacency Function

The normalizeAdjacency function calculates and returns the normalized adjacency matrix normAdjacency of the input adjacency matrix adjacency.

function normAdjacency = normalizeAdjacency(adjacency)

% Add self connections to adjacency matrix
adjacency = adjacency + speye(size(adjacency));

% Compute degree of nodes
degree = sum(adjacency, 2);

% Compute inverse square root of degree
degreeInvSqrt = sparse(sqrt(1./degree));

% Normalize adjacency matrix
normAdjacency = diag(degreeInvSqrt) * adjacency * diag(degreeInvSqrt);

end

Model Gradients Function

The modelGradients function takes the feature matrix dlX, the adjacency matrix adjacencyTrain, the one-hot encoded target data T, and the model parameters parameters, and returns the gradients of the loss with respect to the model parameters, the corresponding loss, and the network predictions.

function [gradients, loss, dlYPred] = modelGradients(dlX, adjacencyTrain, T, parameters)

dlYPred = model(dlX, adjacencyTrain, parameters);

loss = crossentropy(dlYPred, T, 'DataFormat', 'BC');

gradients = dlgradient(loss, parameters);

end

Accuracy Function

The accuracy function decodes the network predictions YPred and calculates accuracy using the decoded predictions and the target data target. The function returns the computed accuracy score and the decoded predictions prediction.

function [score, prediction] = accuracy(YPred, target, classes)

% Decode probability vectors into class labels
prediction = onehotdecode(YPred, classes, 2);
score = sum(prediction == target)/numel(target);

end

References

  1. T. N. Kipf and M. Welling. Semi-Supervised Classification with Graph Convolutional Networks. In ICLR, 2017.

  2. L. C. Blum and J.-L. Reymond. 970 Million Druglike Small Molecules for Virtual Screening in the Chemical Universe Database GDB-13. J. Am. Chem. Soc., 131:8732, 2009.

  3. M. Rupp, A. Tkatchenko, K.-R. Müller, and O. A. von Lilienfeld. Fast and Accurate Modeling of Molecular Atomization Energies with Machine Learning. Physical Review Letters, 108(5):058301, 2012.

Copyright 2021, The MathWorks, Inc.
