It's a trade-off between training time, accuracy, available RAM, data size, etc.
The complexity of the model vs. the size of the data is known as the bias/variance trade-off, and you can read more about it here. But it is normally good practice to start simple and add complexity as long as the model keeps improving.
If an input dimension has a high number of classes (like words in text), consider using an embedding (or another method) to reduce the dimensionality.
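For example (a minimal sketch; the vocabulary size, sequence length and embedding width below are made-up values), a Keras Embedding layer maps each word index to a small dense vector instead of a huge one-hot vector:

from tensorflow.keras.layers import Input, Embedding, Flatten, Dense
from tensorflow.keras.models import Model

# Hypothetical setup: 50k-word vocabulary, sequences of 100 word indices.
inputs = Input(shape=(100,), dtype='int32')
x = Embedding(input_dim=50000, output_dim=64)(inputs)  # each word -> 64-dim vector
x = Flatten()(x)                                       # (batch, 100*64)
outputs = Dense(1)(x)
model = Model(inputs=inputs, outputs=outputs)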
Assuming you have a sufficiently large data set, enough time to train, and can make your model arbitrarily complex, restricted only by RAM, then the following is a RAM analysis.
The model's RAM footprint depends on the network architecture, the dtype and the batch size.
To estimate the memory required to train your network, you can use the function from ZFTurbo's answer here, copied below.
def get_model_memory_usage(batch_size, model):
    import numpy as np
    from tensorflow.keras import backend as K

    # Memory held by the layer outputs (activations), summed over all layers.
    shapes_mem_count = 0
    for l in model.layers:
        single_layer_mem = 1
        for s in l.output_shape:
            if s is None:  # skip the unspecified batch dimension
                continue
            single_layer_mem *= s
        shapes_mem_count += single_layer_mem

    # Memory held by the weights.
    trainable_count = np.sum([K.count_params(p) for p in set(model.trainable_weights)])
    non_trainable_count = np.sum([K.count_params(p) for p in set(model.non_trainable_weights)])

    # Bytes per number, derived from the backend float type.
    number_size = 4.0
    if K.floatx() == 'float16':
        number_size = 2.0
    if K.floatx() == 'float64':
        number_size = 8.0

    # Activations scale with the batch size; the weights do not.
    total_memory = number_size * (batch_size * shapes_mem_count + trainable_count + non_trainable_count)
    gbytes = np.round(total_memory / (1024.0 ** 3), 3)
    return gbytes
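Note that this estimate only counts the weights and the layer activations; the optimizer adds state on top of that (Adam, for example, keeps two extra moment values per trainable parameter), and backpropagation stores gradients as well, so the real training footprint will be somewhat higher.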
Now we can easily test different setups and their corresponding RAM requirements. Below is an example with two hidden layers, dtype float32 and batch size 128, where we set the first hidden layer size to 1k, 10k, 100k and 1000k nodes, which results in roughly 0.006, 0.052, 0.517 and 5.16 GB of RAM required for training.
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
import tensorflow.keras.backend as K

K.set_floatx('float32')

n_hidden_nodes = [1e3, 1e4, 1e5, 1e6]
for n in n_hidden_nodes:
    inputs = Input(shape=(1000,))
    x = Dense(int(n), activation='relu')(inputs)  # Dense expects an int number of units
    x = Dense(256, activation='relu')(x)
    outputs = Dense(1)(x)
    model = Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer='adam', loss='mse')
    print(get_model_memory_usage(128, model))
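As a sanity check, the largest case can be reproduced by hand with the same formula the function uses (float32 = 4 bytes):

# Checking the n = 1e6 case by hand.
params = 1000 * 1_000_000 + 1_000_000             # first Dense: weights + biases
params += 1_000_000 * 256 + 256                   # second Dense
params += 256 + 1                                 # output Dense
activations = 128 * (1000 + 1_000_000 + 256 + 1)  # batch_size * summed layer output sizes
print(4 * (params + activations) / 1024.0 ** 3)   # ~5.16 GB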