Channel: Active questions tagged python - Stack Overflow

TF Transformer model never overfits and just plateaus: Interpretation of this training curve and suggestions for improvement


This training curve is for a Transformer model that processes a 2D (excluding batch) sequential signal. It uses the Adam optimizer, a batch size of 32, and, for the learning rate, a custom LR scheduler that replicates the warmup schedule used in the 'Attention Is All You Need' paper. The curve below plateaus, with the training loss eventually slightly lower than the validation loss, but the validation loss never starts to climb back up. I interpreted this as the model never starting to overfit and simply ceasing to re-adjust its weights after around epoch 90.

Is there a better interpretation of this curve, and what could I change to improve the model?

[Image: training curve (training vs. validation loss over 200 epochs)]

Below is my brief reproducible code:

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

x_train = np.random.normal(size=(32, 512, 512))
batch_size = 32
H, W = x_train.shape[1:]  # x_train is 3-D (batch, seq, feat); unpack the last two dims only

# Build an (H, W, W) boolean attention mask, then take one batch-sized slice
rows, cols = np.indices((H, W), sparse=True)
padding_mask_init = np.zeros((H, W, W), dtype=np.bool_)
padding_mask_init[rows, 1:, cols] = 1
padding_mask = padding_mask_init[:batch_size]

embed_dim = 512
dense_dim = 2048
num_heads = 2
shape = (batch_size, embed_dim, 512)  # (32, 512, 512)

decoder_inputs = layers.Input(batch_input_shape=shape, dtype=tf.float16)
mha_1 = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
mha_2 = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
layernorm_1 = layers.LayerNormalization()

# Decoder block: causal self-attention, residual + layer norm, then cross-attention
Z = decoder_inputs
Z = mha_1(query=Z, value=Z, key=Z, use_causal_mask=True, attention_mask=padding_mask)
Z = layernorm_1(Z + decoder_inputs)
Z = mha_2(query=Z, value=decoder_inputs, key=decoder_inputs, attention_mask=padding_mask)
outputs = layers.TimeDistributed(layers.Dense(embed_dim, activation="softmax"))(Z)

model = keras.Model(decoder_inputs, outputs)
model.compile(
    loss="mean_squared_error",
    optimizer=keras.optimizers.Adam(
        learning_rate=lr_schedule(embed_dim, 3000),  # custom warmup schedule (defined elsewhere)
        beta_1=0.9, beta_2=0.98, epsilon=1.0e-9),
    metrics=["accuracy"])
history = model.fit(dataset, epochs=200, validation_data=val_dataset)  # datasets built elsewhere
```
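The `lr_schedule` referenced above isn't shown. Since it is described as replicating the 'Attention Is All You Need' warmup schedule (`lrate = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)`), a minimal sketch of what it presumably looks like is below; the class name `WarmupSchedule` is a placeholder of mine, not the actual code:

```python
import tensorflow as tf

class WarmupSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Warmup schedule from 'Attention Is All You Need':
    lrate = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)
    LR rises linearly for warmup_steps, then decays like 1/sqrt(step)."""

    def __init__(self, d_model, warmup_steps=4000):
        super().__init__()
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        arg1 = tf.math.rsqrt(step)                     # decay branch
        arg2 = step * (self.warmup_steps ** -1.5)      # linear warmup branch
        return tf.math.rsqrt(self.d_model) * tf.minimum(arg1, arg2)
```

Note that with `d_model=512` and `warmup_steps=3000`, the peak LR at the end of warmup is only about `512**-0.5 * 3000**-0.5 ≈ 8e-4`, and it keeps shrinking as `1/sqrt(step)` afterwards; a very small effective LR late in training is one plausible reason weight updates stall around epoch 90.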
