How to Train a Text Classification Model with trax


What is the trax deep learning framework

trax comes out of the Google Brain team and is the latest iteration after almost a decade of work on TensorFlow, machine translation, and Tensor2Tensor. Being a newcomer in a somewhat crowded space (keras, pytorch, thinc), it has been able to learn from those APIs. In particular:

  • it is very concise
  • it can run on a TensorFlow backend
  • it uses JAX to speed up tensor-based computations (as a faster drop-in for numpy)

Text Classification on the AG News Dataset

The AG's news topic classification dataset contains 120,000 training samples of news articles, each belonging to one of 4 classes.
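For reference, here is the label-to-topic mapping as I understand it from the TFDS ag_news_subset release (worth double-checking against the dataset card before relying on it):

```python
# Assumed label mapping for the TFDS 'ag_news_subset' dataset
# (verify against the dataset card before relying on it).
AG_NEWS_CLASSES = {
    0: 'World',
    1: 'Sports',
    2: 'Business',
    3: 'Sci/Tech',
}

def class_name(label_id):
    """Return the human-readable topic for an integer label."""
    return AG_NEWS_CLASSES[label_id]

print(class_name(1))  # Sports
```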

It has been used as a text classification benchmark in numerous papers, like [1509.01626]. You can see the papers using that dataset on Papers With Code under "Text Classification on AG News".

On top of being a well-understood dataset, it has the advantage of being available on TensorFlow Datasets. Therefore we will use it in the rest of this post.

Prepare the dataset using data.TFDS and data.Serial

trax needs generators of data. Each element is a tuple (input, target) or (input, target, weight); usually weight is 1 because all examples have the same importance.
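As a minimal sketch in plain Python (no trax needed), a compatible stream is just a generator yielding such tuples; the token ids and labels below are made up for illustration:

```python
def toy_stream():
    """Yield (input, target, weight) tuples, the shape trax pipelines expect.
    The token id lists and labels here are made up for illustration."""
    examples = [
        ([12, 7, 99], 3),  # token ids, class label
        ([5, 42], 0),
    ]
    for tokens, label in examples:
        yield (tokens, label, 1.0)  # weight 1: all examples equally important

for item in toy_stream():
    print(item)
```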

from trax import data

train_stream = data.TFDS('ag_news_subset',
                         keys=('description', 'label'),
                         train=True)()
eval_stream = data.TFDS('ag_news_subset',
                        keys=('description', 'label'),
                        train=False)()
data_pipeline = data.Serial(
    data.Tokenize(vocab_file='en_8k.subword', keys=[0]),
    data.FilterByLength(max_length=2048, length_keys=[0]),
    data.BucketByLength(boundaries=[  32, 128, 512, 2048],
                        batch_sizes=[512, 128,  32,    8, 1],
                        length_keys=[0]),
    data.AddLossWeights()
)
train_batches_stream = data_pipeline(train_stream)
eval_batches_stream = data_pipeline(eval_stream)
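To make the bucketing step concrete, here is a plain-Python sketch of the idea behind BucketByLength (not the trax implementation): each sequence is routed to the first length boundary that fits it, so every batch contains similar-length inputs, and longer buckets get smaller batch sizes:

```python
def bucket_for(length, boundaries=(32, 128, 512, 2048)):
    """Return the index of the first boundary that fits the sequence,
    or len(boundaries) for anything longer."""
    for i, boundary in enumerate(boundaries):
        if length <= boundary:
            return i
    return len(boundaries)

batch_sizes = [512, 128, 32, 8, 1]
# A 100-token article falls in bucket 1 and is batched 128 at a time.
print(bucket_for(100), batch_sizes[bucket_for(100)])
```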

Build the Text Classification Model using tl.Serial

trax is really concise: you can build the model directly from the library of layers available.

from trax import layers as tl

model = tl.Serial(
    tl.Embedding(vocab_size=8192, d_feature=50),
    tl.Mean(axis=1),  # Average over the sentence length.
    tl.Dense(4),      # One output per class.
    tl.LogSoftmax()   # Produce log-probabilities.
)
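A typical way to collapse the embedded sequence into a fixed-size vector is mean pooling (tl.Mean in trax); a numpy sketch of what that does to the embedded batch, with illustrative shapes:

```python
import numpy as np

batch, seq_len, d_feature = 2, 5, 50
# Stand-in for the output of the embedding layer.
embedded = np.random.rand(batch, seq_len, d_feature)
pooled = embedded.mean(axis=1)  # average over the sentence length
print(pooled.shape)  # one fixed-size vector per example
```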

Model Training using training.Loop

For training, there is the concept of a "task", which wraps the data, the optimiser, the metrics, etc.

import os
import trax
from trax.supervised import training

# Training task.
train_task = training.TrainTask(
    labeled_data=train_batches_stream,
    loss_layer=tl.CrossEntropyLoss(),
    optimizer=trax.optimizers.Adam(0.01),
    n_steps_per_checkpoint=500,
)

# Evaluation task.
eval_task = training.EvalTask(
    labeled_data=eval_batches_stream,
    metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],
    n_eval_batches=20  # For less variance in eval numbers.
)

# Training loop saves checkpoints to output_dir.
output_dir = os.path.expanduser('~/output-dir/')
!rm -rf {output_dir}
training_loop = training.Loop(model,
                              train_task,
                              eval_tasks=[eval_task],
                              output_dir=output_dir)

# Run 2000 steps (batches).
training_loop.run(2000)

Look at the Predictions

inputs, targets, weights = next(eval_batches_stream)

example_input = inputs[0]
expected_class = targets[0]
example_input_str = trax.data.detokenize(example_input, vocab_file='en_8k.subword')
print(f'example input_str: {example_input_str}')
sentiment_log_probs = model(example_input[None, :])  # Add batch dimension.
print(f'Model returned class probabilities: {np.exp(sentiment_log_probs)}')
print(f'Expected class: {expected_class}')
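Since the model outputs log-probabilities (hence the np.exp above), exponentiating recovers probabilities that sum to 1, and argmax gives the predicted class; a numpy sketch with made-up log-probabilities:

```python
import numpy as np

# Hypothetical log-probabilities for one example over 4 classes.
log_probs = np.log(np.array([0.1, 0.2, 0.6, 0.1]))
probs = np.exp(log_probs)            # back to probabilities
predicted_class = int(np.argmax(probs))
print(probs.sum())        # sums to 1 (up to floating point)
print(predicted_class)
```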


  1. google/trax: Trax — Deep Learning with Clear Code and Speed
  2. Trax Tutorials — Trax documentation
  3. ag_news_subset | TensorFlow Datasets
  4. [1509.01626] Character-level Convolutional Networks for Text Classification
  5. AG News Benchmark (Text Classification) | Papers With Code
  6. TensorFlow Datasets