Questions tagged [trax]

For questions related to the Trax deep learning framework

Trax is an end-to-end library for deep learning that focuses on clear code and speed. It is actively used and maintained by the Google Brain team.

29 questions
6 votes, 2 answers

How to install trax, jax, jaxlib on M1 Mac on macOS 12?

New to Trax, I'm trying to run it locally (macOS 12.1, Apple Silicon ARM M1 processor, 8GB RAM, Anaconda), but I'm running into some issues. In an environment with Python 3.8.5, I installed trax by running pip3 install trax==1.3.9 inside an…
arturomp
5 votes, 1 answer

What is the difference between JAX, Trax, and TensorRT, in simple terms?

I have been using TensorRT and TensorFlow-TRT to accelerate the inference of my DL algorithms. Then I have heard of: JAX https://github.com/google/jax Trax https://github.com/google/trax Both seem to accelerate DL. But I am having a hard time to…
Aizzaac
3 votes, 1 answer

trax tl.Relu and tl.ShiftRight layers are nested inside Serial Combinator

I am trying to build an attention model, but the Relu and ShiftRight layers are nested inside a Serial combinator by default. This causes errors later in training. layer_block = tl.Serial( tl.Relu(), tl.LayerNorm(), ) x = np.array([[-2, -1, 0,…
1 vote, 1 answer

Is there a model.summary() in Trax?

I'm working with Trax, a framework built by the Google Brain team to work with deep learning models as an alternative to TensorFlow. As a TensorFlow developer, I'm pretty used to the model.summary() method (documented here) to display a full model…
Emiliano Viotti
1 vote, 0 answers

Can I write a Neural Network with Trax in a functional way?

I am trying to learn Trax. I have previous experience with TensorFlow, and I prefer writing neural networks with the functional API (https://www.tensorflow.org/guide/keras/functional). I was wondering if it is possible to do the same with Trax, the…
1 vote, 1 answer

No module named 'trax'

Guys, please help me with this. I installed trax with pip install trax, but it shows this error, and I still get it even after restarting the kernel.
Aravind R
1 vote, 1 answer

Import trax takes too long to load

I was stumped the first time I loaded this library. On my local computer it takes at least 40 s to load trax in a Jupyter Notebook, and more than 1 minute in a shared Colab environment: import trax I'm not sure if it's an issue…
Emiliano Viotti
1 vote, 0 answers

How to switch to predict mode in Trax after training a model?

I'm familiarizing myself with the Trax library for building deep learning models and one question that I can't find an answer to is how to switch from "train" mode to "eval" mode after model training is complete. Consider this example with a…
djvaroli
1 vote, 0 answers

How to use the multiple-heads option in the SelfAttention class?

I am playing around with the SelfAttention model from the trax library. When I set n_heads=1, everything works fine, but when I set n_heads=2, my code breaks. I use only input activations and one SelfAttention layer. Here is a minimal example: import…
Kenenbek Arzymatov
1 vote, 0 answers

Why does Trax automatically create a Serial layer over a sublayer?

I implemented a Serial layer in Trax (a deep learning library by Google). Why is an additional Serial layer created even though I already declared one? Below is the code. model = tl.Serial( tl.Dense(n_units=512), tl.Relu() …
1 vote, 1 answer

TensorBoard with Trax

Has anyone managed to log the loss with TensorBoard? I am using the trax ML library and getting this error: TypeError: 'SummaryWriter' object is not callable. I am using the SummaryWriter from jaxboard and then adding it to callbacks within…
Exa
1 vote, 1 answer

Creating a custom TFDS dataset

I would like to create a custom TensorFlow dataset for a summarization task. I have a set of reports with three gold summaries for every report. All the data is in .txt format. I would like to create a TFDS where the key is the report and the value…
Nadhem
1 vote, 1 answer

Multivariate regression using trax

How do I set up a multivariate regression problem using Trax? I get AssertionError: Invalid shape (16, 2); expected (16,). from the code below, coming from the L2Loss object. The following is my attempt to adapt the sentiment analysis example into…
Dave
  • 7,555
  • 8
  • 46
  • 88
1 vote, 1 answer

Understanding introductory example on transformers in Trax

My goal is to understand the introductory example on transformers in Trax, which can be found at https://trax-ml.readthedocs.io/en/latest/notebooks/trax_intro.html: import trax # Create a Transformer model. # Pre-trained model config in…
Sebastian Thomas
0 votes, 0 answers

Getting an error "_kwargs = spec_.kwargs.copy()" saying "AttributeError: 'NoneType' object has no attribute 'copy'" when running tensor2tensor

I am currently trying to run this Google Colab notebook. It creates a Transformer model that takes piano performances as input and generates new music. I get this error when I try importing the tensor2tensor…