I am working on a regression problem using TensorFlow, where I have encountered a challenge with my loss function. My data points are structured as triples $(Y_i, G_i, X_i)$, where
- $Y_i \in \mathbb{R}$ is an outcome;
- $G_i \in \{0, 1\}$ is a binary group identifier;
- $X_i \in \mathbb{R}^d$ is a feature vector.
The goal is to train a neural network that predicts $\hat{Y}$ given $X$, using a custom loss function that is the absolute difference in Mean Squared Errors (MSE) between the two groups.
Formally, for a prediction function $f \colon \mathbb{R}^d \to \mathbb{R}$, the mean squared error in group $k \in \{0, 1\}$ is defined as
$$e_k(f) := \frac{\sum_{i \colon G_i = k} (Y_i - f(X_i))^2}{\sum_i \mathbf{1}\{G_i = k\}}.$$
The loss is then $|e_0(f) - e_1(f)|$. Alternatively, one can minimize $(e_0(f) - e_1(f))^2$, which has the same minimizers.
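As a plain-NumPy reference for the definitions above (a sketch for checking numbers, not a training-ready loss; `group_mse` and `mse_gap` are names introduced here for illustration), $e_k$ and the loss can be computed directly with a boolean mask:

```python
import numpy as np


def group_mse(y, y_hat, g, k):
    """e_k(f): mean squared error over the points with G_i == k."""
    mask = (g == k)
    return np.sum((y[mask] - y_hat[mask]) ** 2) / np.sum(mask)


def mse_gap(y, y_hat, g):
    """|e_0(f) - e_1(f)|: absolute difference of the group-wise MSEs."""
    return abs(group_mse(y, y_hat, g, 0) - group_mse(y, y_hat, g, 1))


y = np.array([1.0, 2.0, 3.0, 4.0])
y_hat = np.array([1.0, 1.0, 3.0, 2.0])
g = np.array([0, 0, 1, 1])
# group 0: squared errors 0, 1 -> e_0 = 0.5
# group 1: squared errors 0, 4 -> e_1 = 2.0
print(mse_gap(y, y_hat, g))  # 1.5
```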
I am struggling to implement this loss function in TensorFlow because it is not a sum of per-example losses: each $e_k$ is an average over all data points in group $k$, so the loss depends jointly on many data points from both groups.
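One observation that may help structure the implementation (the notation $n_k := \sum_i \mathbf{1}\{G_i = k\}$ and $w_i$ is introduced here): the *signed* difference is still a weighted sum of per-example squared errors, with weights that depend only on group membership and group sizes,
$$e_0(f) - e_1(f) = \sum_i w_i \, (Y_i - f(X_i))^2, \qquad w_i = \frac{\mathbf{1}\{G_i = 0\}}{n_0} - \frac{\mathbf{1}\{G_i = 1\}}{n_1}.$$
What couples the examples is the outer $|\cdot|$ (or square) and the fact that $n_0, n_1$ must be counted on whatever set the loss is evaluated on. On a mini-batch they are the batch counts, and because the absolute value is applied after averaging, the per-batch value of $|e_0(f) - e_1(f)|$ is in general not an unbiased estimate of the full-data value.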
My main questions are:
- How can I structure this loss function in TensorFlow, considering its dependence on group-wise calculations?
- Are there any specific TensorFlow functions or techniques that would simplify the implementation of such a group-based loss function?

Any guidance or suggestions on how to proceed would be greatly appreciated.
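Regarding the second question, one option (a sketch assuming TensorFlow 2.x in eager mode; `group_mse_gap` is a hypothetical name) is `tf.math.unsorted_segment_mean`, which computes a mean per segment id in one call. Using the group label as the segment id gives $[e_0, e_1]$ directly. Unlike `tf.reduce_mean` applied to a `tf.boolean_mask` result, which yields NaN when a batch has no members of a group, `unsorted_segment_mean` returns 0 for an empty segment.

```python
import tensorflow as tf


def group_mse_gap(y_true, y_pred, group):
    """|e_0 - e_1| for one batch, via a segment mean over the group labels."""
    y_true = tf.reshape(tf.cast(y_true, tf.float32), [-1])
    y_pred = tf.reshape(tf.cast(y_pred, tf.float32), [-1])
    seg = tf.reshape(tf.cast(group, tf.int32), [-1])  # segment id = group label (0 or 1)

    sq_err = tf.square(y_true - y_pred)
    # e[k] = mean of sq_err over points with group == k (0 if group k is empty)
    e = tf.math.unsorted_segment_mean(sq_err, seg, num_segments=2)
    return tf.abs(e[0] - e[1])


# Usage: gradients flow back to the predictions as with any differentiable loss
y_true = tf.constant([1.0, 2.0, 3.0, 4.0])
group = tf.constant([0, 0, 1, 1])
y_pred = tf.Variable([1.0, 1.0, 3.0, 2.0])
with tf.GradientTape() as tape:
    loss = group_mse_gap(y_true, y_pred, group)
print(float(loss))  # 1.5
grad = tape.gradient(loss, y_pred)
```

Two design notes: returning 0 for an empty group avoids NaNs but makes the loss silently equal to the other group's MSE, so batches should still contain both groups. And because the standard Keras loss signature is `(y_true, y_pred)`, the group labels have to be supplied another way, for example through a closure as in the code below, or inside a custom training loop.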
What I tried
```python
import numpy as np
import tensorflow as tf
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split


def custom_loss(group):
    def loss(y_true, y_pred):
        # reshape
        y_pred = tf.reshape(y_pred, [-1])
        y_true = tf.reshape(y_true, [-1])

        # Create a mask for each batch
        mask_b = tf.equal(group, 1)
        mask_r = tf.equal(group, 0)

        y_pred_b = tf.boolean_mask(y_pred, mask_b)
        y_pred_r = tf.boolean_mask(y_pred, mask_r)
        y_true_b = tf.boolean_mask(y_true, mask_b)
        y_true_r = tf.boolean_mask(y_true, mask_r)

        # Ensure same data type
        y_pred_b = tf.cast(y_pred_b, y_true.dtype)
        y_pred_r = tf.cast(y_pred_r, y_true.dtype)

        mse_b = tf.reduce_mean(tf.square(y_true_b - y_pred_b))
        mse_r = tf.reduce_mean(tf.square(y_true_r - y_pred_r))

        return tf.abs(mse_b - mse_r)

    return loss


# Since the loss depends on the group average, batch_size should be sufficiently large (?)
def train_early_stopping(model, custom_loss, X_train, y_train, g_train,
                         X_val, y_val, g_val,
                         n_epoch=500, patience=10, batch_size=1000):
    # Initialize variables for early stopping
    best_val_loss = float('inf')
    wait = 0
    best_epoch = 0

    for epoch in range(n_epoch):
        if epoch == n_epoch - 1:
            print('Not converged.')
            break

        loss_epoch_list = []
        for step in range(len(X_train) // batch_size):
            with tf.GradientTape() as tape:
                start = step * batch_size
                end = start + batch_size
                if end > X_train.shape[0]:
                    break
                else:
                    X_batch = X_train[start:end]
                    y_batch = y_train[start:end]
                    g_batch = g_train[start:end]
                y_pred = model(X_batch, training=True)
                loss_value = custom_loss(g_batch)(y_batch, y_pred)
                loss_epoch_list.append(loss_value.numpy())
            grads = tape.gradient(loss_value, model.trainable_variables)
            model.optimizer.apply_gradients(zip(grads, model.trainable_variables))

        # Calculate validation loss for EarlyStopping
        val_loss = custom_loss(g_val)(y_val, model.predict(X_val))
        print(f"Epoch {epoch+1}: Train Loss: {np.mean(loss_epoch_list)}, Validation Loss: {val_loss}")

        # Early stopping check
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_weights = model.get_weights()
            wait = 0
            best_epoch = epoch
        else:
            wait += 1
            if wait >= patience:
                print(f"Early Stopping triggered at epoch {best_epoch + 1}, Validation Loss: {best_val_loss}")
                model.set_weights(best_weights)
                break


# Create a synthetic dataset
X, y = make_regression(n_samples=20000, n_features=10, noise=0.2, random_state=42)
group = np.random.choice([0, 1], size=y.shape)  # 1 for 'b', 0 for 'r'

X_train_full, X_test, y_train_full, y_test, g_train_full, g_test = train_test_split(
    X, y, group, test_size=0.5, random_state=42)
X_train, X_val, y_train, y_val, g_train, g_val = train_test_split(
    X_train_full, y_train_full, g_train_full, test_size=0.2, random_state=42)

# main
num_unit = 64
model_fair = tf.keras.Sequential([
    tf.keras.layers.Dense(num_unit, activation='relu', input_shape=(X.shape[1],)),
    tf.keras.layers.Dense(num_unit, activation='relu'),
    tf.keras.layers.Dense(1)
])
model_fair.compile(optimizer='adam')

batch_size = X_train.shape[0] // 5
train_early_stopping(model_fair, custom_loss, X_train, y_train, g_train,
                     X_val, y_val, g_val, patience=10, batch_size=batch_size)
```
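On the `batch_size ... (?)` comment in the code: instead of relying only on a large batch, one could build each mini-batch with a fixed number of points from each group, so that neither $e_0$ nor $e_1$ is ever estimated from an empty or tiny subset. The sketch below is a hypothetical helper (`stratified_batches` is a name introduced here), and it assumes both groups are non-empty and `batch_size >= 2`.

```python
import numpy as np


def stratified_batches(g, batch_size, rng):
    """Hypothetical helper: yield index arrays in which each group keeps
    (approximately) its overall share of the data.

    Assumes g contains both labels 0 and 1 and batch_size >= 2.
    """
    idx0 = rng.permutation(np.flatnonzero(g == 0))  # shuffled indices of group 0
    idx1 = rng.permutation(np.flatnonzero(g == 1))  # shuffled indices of group 1

    # Number of group-1 points per batch, clipped to [1, batch_size - 1]
    per1 = round(batch_size * len(idx1) / len(g))
    per1 = min(max(per1, 1), batch_size - 1)
    per0 = batch_size - per1

    for b in range(len(g) // batch_size):
        part0 = idx0[b * per0:(b + 1) * per0]
        part1 = idx1[b * per1:(b + 1) * per1]
        if len(part0) < per0 or len(part1) < per1:
            break  # one group is exhausted; drop the incomplete remainder
        yield rng.permutation(np.concatenate([part0, part1]))


rng = np.random.default_rng(0)
g = np.array([0] * 60 + [1] * 40)
batches = list(stratified_batches(g, 20, rng))
print(len(batches), [int(g[b].sum()) for b in batches])  # 5 [8, 8, 8, 8, 8]
```

In `train_early_stopping`, the inner loop could then iterate `for idx in stratified_batches(g_train, batch_size, rng):` and slice `X_train[idx]`, `y_train[idx]`, and `g_train[idx]`, which also reshuffles the data every epoch.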
The code runs without errors. However, the training, validation, and test losses differ substantially, which suggests a problem with the training process.