
Implementing a Custom Loss Function in TensorFlow for a Regression Problem with Group-Specific MSEs

I am working on a regression problem using TensorFlow, where I have encountered a challenge with my loss function. My data points are structured as triples $(Y_i, G_i, X_i)$, where
$Y_i \in \mathbb{R}$ represents an outcome; $G_i \in \{0,1\}$ is a binary group identifier; $X_i \in \mathbb{R}^d$ is a feature vector.

The goal is to train a neural network that predicts $\hat{Y}$ given $X$, using a custom loss function that is the absolute difference in Mean Squared Errors (MSE) between the two groups.

Formally, for a prediction function $f \colon \mathbb{R}^d \to \mathbb{R}$, define the group-wise MSE
$$e_k(f) := \frac{\sum_{i \colon G_i = k} (Y_i - f(X_i))^2}{\sum_i \mathbb{1}\{G_i = k\}}.$$
The loss is then $|e_0(f) - e_1(f)|$; minimizing $(e_0(f) - e_1(f))^2$ instead would serve the same purpose.
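To make the definition concrete, here is a tiny NumPy check on made-up numbers (purely illustrative, not part of the training code):

import numpy as np

# Toy batch: outcomes, group labels, and predictions (made-up values)
y     = np.array([1.0, 2.0, 3.0, 4.0])
g     = np.array([0, 0, 1, 1])
y_hat = np.array([1.5, 2.0, 2.0, 5.0])

e0 = np.mean((y[g == 0] - y_hat[g == 0]) ** 2)  # (0.25 + 0.00) / 2 = 0.125
e1 = np.mean((y[g == 1] - y_hat[g == 1]) ** 2)  # (1.00 + 1.00) / 2 = 1.0
loss = abs(e0 - e1)                             # 0.875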

I am struggling to implement this loss function in TensorFlow because it is not a straightforward sum of per-sample losses: it depends jointly on data points from both groups.
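One alternative I have been considering, since the standard Keras loss signature loss(y_true, y_pred) has no access to $G_i$, is to pack the group indicator into the target array and unpack it inside the loss. A rough sketch of what I mean (make_packed_loss is my own name, and I have not verified that this is the recommended pattern):

import numpy as np
import tensorflow as tf

def make_packed_loss():
    def loss(y_true_packed, y_pred):
        # y_true_packed is assumed to have shape (batch, 2): column 0 = Y, column 1 = G
        y_true = y_true_packed[:, 0]
        g = y_true_packed[:, 1]
        y_pred = tf.cast(tf.reshape(y_pred, [-1]), y_true.dtype)
        err = y_true - y_pred
        mse_0 = tf.reduce_mean(tf.square(tf.boolean_mask(err, tf.equal(g, 0.0))))
        mse_1 = tf.reduce_mean(tf.square(tf.boolean_mask(err, tf.equal(g, 1.0))))
        return tf.abs(mse_0 - mse_1)
    return loss

# Usage idea (untested): compile with loss=make_packed_loss() and fit with the
# targets packed as np.column_stack([y_train, g_train]).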

My main questions are:

  • How can I structure this loss function in TensorFlow, considering its dependence on group-wise calculations?
  • Are there any specific TensorFlow functions or techniques that would simplify the implementation of such a group-based loss function? (One idea I have been toying with is sketched below.)

Any guidance or suggestions on how to proceed would be greatly appreciated.
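For instance, I came across tf.dynamic_partition and wondered whether it could replace the pair of boolean masks. A rough sketch of what I have in mind (group_mse_gap is my own name; I have only tried it on toy tensors):

import tensorflow as tf

def group_mse_gap(y_true, y_pred, group):
    # group is assumed to hold the 0/1 group labels for the current batch
    group = tf.cast(tf.reshape(group, [-1]), tf.int32)
    y_true = tf.reshape(y_true, [-1])
    y_pred = tf.reshape(tf.cast(y_pred, y_true.dtype), [-1])
    true_parts = tf.dynamic_partition(y_true, group, num_partitions=2)
    pred_parts = tf.dynamic_partition(y_pred, group, num_partitions=2)
    mse_0 = tf.reduce_mean(tf.square(true_parts[0] - pred_parts[0]))
    mse_1 = tf.reduce_mean(tf.square(true_parts[1] - pred_parts[1]))
    return tf.abs(mse_0 - mse_1)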

What I tried

import numpy as np
import tensorflow as tf
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split

def custom_loss(group):
    def loss(y_true, y_pred):
        # reshape
        y_pred = tf.reshape(y_pred, [-1])
        y_true = tf.reshape(y_true, [-1])
        # Create a mask for each batch
        mask_b = tf.equal(group, 1)
        mask_r = tf.equal(group, 0)
        y_pred_b = tf.boolean_mask(y_pred, mask_b)
        y_pred_r = tf.boolean_mask(y_pred, mask_r)
        y_true_b = tf.boolean_mask(y_true, mask_b)
        y_true_r = tf.boolean_mask(y_true, mask_r)
        # Ensure same data type
        y_pred_b = tf.cast(y_pred_b, y_true.dtype)
        y_pred_r = tf.cast(y_pred_r, y_true.dtype)
        mse_b = tf.reduce_mean(tf.square(y_true_b - y_pred_b))
        mse_r = tf.reduce_mean(tf.square(y_true_r - y_pred_r))
        return tf.abs(mse_b - mse_r)
    return loss

# Since the loss depends on the group average, batch_size should be sufficiently large (?)
def train_early_stopping(model, custom_loss,
                         X_train, y_train, g_train, X_val, y_val, g_val,
                         n_epoch=500, patience=10, batch_size=1000):
    # Initialize variables for early stopping
    best_val_loss = float('inf')
    wait = 0
    best_epoch = 0
    for epoch in range(n_epoch):
        if epoch == n_epoch - 1:
            print('Not converged.')
            break
        loss_epoch_list = []
        for step in range(len(X_train) // batch_size):
            with tf.GradientTape() as tape:
                start = step * batch_size
                end = start + batch_size
                if end > X_train.shape[0]:
                    break
                else:
                    X_batch = X_train[start:end]
                    y_batch = y_train[start:end]
                    g_batch = g_train[start:end]
                    y_pred = model(X_batch, training=True)
                    loss_value = custom_loss(g_batch)(y_batch, y_pred)
                loss_epoch_list.append(loss_value.numpy())
                grads = tape.gradient(loss_value, model.trainable_variables)
                model.optimizer.apply_gradients(zip(grads, model.trainable_variables))
        # Calculate validation loss for EarlyStopping
        val_loss = custom_loss(g_val)(y_val, model.predict(X_val))
        print(f"Epoch {epoch+1}: Train Loss: {np.mean(loss_epoch_list)}, Validation Loss: {val_loss}")
        # Early stopping check
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_weights = model.get_weights()
            wait = 0
            best_epoch = epoch
        else:
            wait += 1
            if wait >= patience:
                print(f"Early Stopping triggered at epoch {best_epoch + 1}, Validation Loss: {best_val_loss}")
                model.set_weights(best_weights)
                break

# Create a synthetic dataset
X, y = make_regression(n_samples=20000, n_features=10, noise=0.2, random_state=42)
group = np.random.choice([0, 1], size=y.shape)  # 1 for 'b', 0 for 'r'
X_train_full, X_test, y_train_full, y_test, g_train_full, g_test = train_test_split(
    X, y, group, test_size=0.5, random_state=42)
X_train, X_val, y_train, y_val, g_train, g_val = train_test_split(
    X_train_full, y_train_full, g_train_full, test_size=0.2, random_state=42)

# main
num_unit = 64
model_fair = tf.keras.Sequential([
    tf.keras.layers.Dense(num_unit, activation='relu', input_shape=(X.shape[1],)),
    tf.keras.layers.Dense(num_unit, activation='relu'),
    tf.keras.layers.Dense(1)])
model_fair.compile(optimizer='adam')
batch_size = X_train.shape[0] // 5
train_early_stopping(model_fair, custom_loss, X_train, y_train, g_train, X_val, y_val, g_val,
                     patience=10, batch_size=batch_size)

The code executes without errors. However, the training, validation, and test losses differ substantially, which makes me suspect something is wrong with the training process.
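While debugging, one edge case I wanted to rule out is a batch that happens to contain only one of the two groups, in which case tf.reduce_mean over an empty tensor returns NaN. I am not sure this is the actual cause, but a guarded variant of the group MSE I have been experimenting with, based on tf.math.divide_no_nan, looks like this (just a sketch, names are my own):

import tensorflow as tf

def safe_group_mse(y_true, y_pred, mask):
    # MSE over one group; returns 0 instead of NaN when the group is absent
    w = tf.cast(mask, y_true.dtype)
    sq_err = tf.square(y_true - y_pred) * w
    return tf.math.divide_no_nan(tf.reduce_sum(sq_err), tf.reduce_sum(w))

def guarded_loss(y_true, y_pred, group):
    y_true = tf.reshape(y_true, [-1])
    y_pred = tf.reshape(tf.cast(y_pred, y_true.dtype), [-1])
    group = tf.reshape(group, [-1])
    mse_0 = safe_group_mse(y_true, y_pred, tf.equal(group, 0))
    mse_1 = safe_group_mse(y_true, y_pred, tf.equal(group, 1))
    return tf.abs(mse_0 - mse_1)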

