A general-purpose runner for TF-GNN.
class ContextLabelFn
Reads out a tfgnn.Field from the GraphTensor context.
class DatasetProvider
Helper class that provides a standard way to create an ABC using inheritance.
class DotProductLinkPrediction
Implements edge score as dot product of features of endpoint nodes.
class FitOrSkipPadding
Calculates fit-or-skip size constraints for GraphTensor padding.
class GraphBinaryClassification
Graph binary (or multi-label) classification from pooled node states.
class GraphMeanAbsoluteError
Regression from pooled node states with mean absolute error.
class GraphMeanAbsolutePercentageError
Regression from pooled node states with mean absolute percentage error.
class GraphMeanSquaredError
Regression from pooled node states with mean squared error.
class GraphMeanSquaredLogScaledError
Regression from pooled node states with mean squared log scaled error.
class GraphMeanSquaredLogarithmicError
Regression from pooled node states with mean squared logarithmic error.
class GraphMulticlassClassification
Graph multiclass classification from pooled node states.
class GraphTensorPadding
Collects GraphTensor padding helpers.
class GraphTensorProcessorFn
A class for GraphTensor processing.
class HadamardProductLinkPrediction
Implements edge score as Hadamard product of features of endpoint nodes.
class IntegratedGradientsExporter
Exports a Keras model with an additional integrated gradients signature.
class KerasModelExporter
Exports a Keras model (with Keras API) via tf.keras.models.save_model.
class KerasTrainer
Trains using the Keras training loop.
class KerasTrainerCheckpointOptions
Provides Keras checkpointing-related configuration options.
class KerasTrainerOptions
Provides Keras training-related options.
class ModelExporter
Saves a Keras model.
class NodeBinaryClassification
Node binary (or multi-label) classification via structured readout.
class NodeMeanAbsoluteError
Node regression with mean absolute error via structured readout.
class NodeMeanAbsolutePercentageError
Node regression with mean absolute percentage error via structured readout.
class NodeMeanSquaredError
Node regression with mean squared error via structured readout.
class NodeMeanSquaredLogScaledError
Node regression with mean squared log scaled error via structured readout.
class NodeMeanSquaredLogarithmicError
Node regression with mean squared logarithmic error via structured readout.
class NodeMulticlassClassification
Node multiclass classification via structured readout.
class ParameterServerStrategy
A ParameterServerStrategy convenience wrapper.
class PassthruDatasetProvider
Builds a tf.data.Dataset from a pass-through dataset.
class PassthruSampleDatasetsProvider
Builds a sampled tf.data.Dataset from multiple pass-through datasets.
class RootNodeBinaryClassification
Root node binary (or multi-label) classification.
class RootNodeLabelFn
Reads out a tfgnn.Field from the GraphTensor root (i.e. first) node.
class RootNodeMeanAbsoluteError
Root node regression with mean absolute error.
class RootNodeMeanAbsolutePercentageError
Root node regression with mean absolute percentage error.
class RootNodeMeanSquaredError
Root node regression with mean squared error.
class RootNodeMeanSquaredLogScaledError
Root node regression with mean squared log scaled error.
class RootNodeMeanSquaredLogarithmicError
Root node regression with mean squared logarithmic error.
class RootNodeMulticlassClassification
Root node multiclass classification.
class RunResult
Holds the return values of a training run.
class SampleTFRecordDatasetsProvider
Builds a sampling tf.data.Dataset from multiple filenames.
class SimpleDatasetProvider
Builds a tf.data.Dataset from a list of files.
class SimpleSampleDatasetsProvider
Builds a sampling tf.data.Dataset from multiple filenames.
class SubmoduleExporter
Exports a Keras submodule.
class TFDataServiceConfig
Provides tf.data service-related configuration options.
class TFRecordDatasetProvider
Builds a tf.data.Dataset from a list of files.
class TPUStrategy
A TPUStrategy convenience wrapper.
class Task
Defines a learning objective for a GNN.
class TightPadding
Calculates tight size constraints for GraphTensor padding.
class Trainer
A class for training and validation of a Keras model.
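The two link-prediction entries above (DotProductLinkPrediction and HadamardProductLinkPrediction) differ in how they combine the features of an edge's two endpoint nodes. The following is a minimal plain-Python sketch of the two combinations on a single candidate edge. It is not the TF-GNN implementation, which operates on GraphTensor node states inside a Keras model. The helper names `dot_product_score` and `hadamard_product` and the example node states are hypothetical.

```python
from typing import List, Sequence


def dot_product_score(source: Sequence[float], target: Sequence[float]) -> float:
    """Hypothetical helper: scalar edge score as the dot product of endpoint features."""
    if len(source) != len(target):
        raise ValueError("endpoint feature vectors must have the same length")
    return sum(s * t for s, t in zip(source, target))


def hadamard_product(source: Sequence[float], target: Sequence[float]) -> List[float]:
    """Hypothetical helper: elementwise (Hadamard) product of endpoint features.

    The result is a vector, not yet a score; a model would still reduce it to a
    scalar (for example with a learned layer), which this sketch does not show.
    """
    if len(source) != len(target):
        raise ValueError("endpoint feature vectors must have the same length")
    return [s * t for s, t in zip(source, target)]


# Hypothetical node states for the two endpoints of one candidate edge.
src_state = [1.0, 2.0, 3.0]
tgt_state = [4.0, 5.0, 6.0]

print(dot_product_score(src_state, tgt_state))  # 32.0
print(hadamard_product(src_state, tgt_state))   # [4.0, 10.0, 18.0]
```

In this sketch, summing the Hadamard vector recovers the dot-product score. The two variants differ only in whether the per-dimension products are summed with equal weight or passed on for further processing.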
Exports a Keras model without traces so that it can be loaded without TF-GNN.
Creates an incrementing model directory, given some dirname.
Integrated gradients.
Returns a mapping node_set_name: 1 for every node set in gtspec.
Runs training (and validation) of a model on task(s) with the given data.
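The integrated-gradients entries in this module (such as IntegratedGradientsExporter) refer to a standard attribution technique. For each input feature i, the technique integrates the model's gradient along a straight path from a baseline x' to the input x:

IG_i = (x_i - x'_i) * integral over a from 0 to 1 of dF/dx_i(x' + a (x - x')) da

The following is a minimal plain-Python sketch of that general technique, approximating the integral with a midpoint Riemann sum. It is not TF-GNN's implementation, which computes gradients of a Keras model with respect to GraphTensor features. The helper `approx_integrated_gradients`, the toy model, its analytic gradient `toy_grad`, the zero baseline, and `steps=50` are all illustrative assumptions.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def approx_integrated_gradients(
    grad_fn: Callable[[Sequence[float]], Vector],
    x: Sequence[float],
    baseline: Sequence[float],
    steps: int = 50,
) -> Vector:
    """Hypothetical helper: integrated gradients via a midpoint Riemann sum.

    grad_fn(point) must return the gradient of the model output F at `point`.
    """
    if len(x) != len(baseline):
        raise ValueError("x and baseline must have the same length")
    avg_grad = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps  # midpoint of the k-th sub-interval of [0, 1]
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(grad_fn(point)):
            avg_grad[i] += g / steps  # running average of the path gradient
    # Scale each averaged gradient by how far the feature moved from the baseline.
    return [(xi - b) * g for xi, b, g in zip(x, baseline, avg_grad)]


# Toy model (assumption): F(x) = x0**2 + 3 * x1, with analytic gradient (2 * x0, 3).
def toy_grad(point: Sequence[float]) -> Vector:
    return [2.0 * point[0], 3.0]


attributions = approx_integrated_gradients(toy_grad, x=[2.0, 1.0], baseline=[0.0, 0.0])
print(attributions)  # approximately [4.0, 3.0]
```

Along the straight-line path, the toy gradient is linear in a, so the midpoint sum is exact up to floating-point rounding. The attributions also sum to F(x) - F(baseline) = 7, which is the completeness property of integrated gradients.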