I am currently working on a steganalysis model using a capsule network. I used this implementation of CapsNet, which was based on Xifeng Guo's Keras code.
I replicated DRCaps, a CapsNet model used for COVID-19 diagnosis from X-ray images, for the following reasons:
- Same size and type of input images (256 × 256 grayscale images).
- The authors experimented with different hyperparameter configurations, and their final DRCaps model achieved a high accuracy of 90.0%.
- I figured the model's good performance would also translate to binary classification in steganalysis.
For your reference, here are the hyperparameters of the DRCaps model.
I have the same parameters for Input, Convolutional, and PrimaryCaps layers. For the ClassCaps layer, I changed the number of classes to 2 (cover/stego) instead of 3. My implementation of the model and its summary are shown below.
class PrimaryCaps(Layer):"""A PrimaryCaps layer. It allows to move to capsule's domain, encapsulating scalars in vectors. Args: n_caps: The number of capsules in this layer. dims_caps: The dimension of the output vector of a capsule. kernel_size: Height and width of the 2D convolution window. strides: Strides of the convolution along the height and width. padding: Type of padding used in the convolution. activation: Activation function to use.""" def __init__( self, n_caps, dims_caps, kernel_size, strides, padding, activation=None, **kwargs, ): super(PrimaryCaps, self).__init__(**kwargs) self.n_caps = n_caps self.dims_caps = dims_caps self.kernel_size = kernel_size self.strides = strides self.padding = padding self.activation = activation def get_config(self): config = super().get_config().copy() config.update( {"n_caps": self.n_caps,"dims_caps": self.dims_caps,"kernel_size": self.kernel_size,"strides": self.strides,"padding": self.padding,"activation": self.activation, } ) return config def build(self, input_shape): assert ( len(input_shape) >= 4 ), "The input Tensor of a PrimaryCaps should have shape=[None, width, height, channels]" # Apply Convolution n_caps times self.conv2d = Conv2D( filters=self.n_caps * self.dims_caps, kernel_size=self.kernel_size, strides=self.strides, padding=self.padding, activation=self.activation, name="primarycaps_conv2d", ) # Reshape the convolutional layer output feature_dims = int((input_shape[1] - self.kernel_size + 1) / self.strides) self.reshape = Reshape( (feature_dims ** 2 * self.n_caps, self.dims_caps), name="primarycaps_reshape", ) # Squash the vectors output self.squash = Lambda(squash, name="primarycaps_squash") self.built = True def call(self, inputs): x = self.conv2d(inputs) x = self.reshape(x) return self.squash(x)class DenseCaps(Layer):"""A DenseCaps layer, where the dynamic routing algorithm is executed. Args: n_caps: The number of capsules in this layer. dims_caps: The dimension of the output vector of a capsule. r_iter: Number of routing iterations. kernel_initializer: Initializer that define the way to set the initial random weights. 
shared_weights: number of input capsules that must have the same weight.""" def __init__( self, n_caps, dims_caps, r_iter=3, kernel_initializer=initializers.RandomNormal(stddev=0.1), shared_weights=1, **kwargs, ): super(DenseCaps, self).__init__(**kwargs) self.n_caps = n_caps self.dims_caps = dims_caps self.r_iter = r_iter self.kernel_initializer = kernel_initializer self.shared_weights = shared_weights def get_config(self): config = super().get_config().copy() config.update( {"n_caps": self.n_caps,"dims_caps": self.dims_caps,"r_iter": self.r_iter,"kernel_initializer": self.kernel_initializer,"shared_weights": self.shared_weights, } ) return config def build(self, input_shape): assert ( len(input_shape) == 3 ), "The input Tensor of a DenseCaps should have shape=[None, input_n_caps, input_dims_caps]" self.input_n_caps = input_shape[1] self.input_dims_caps = input_shape[2] self.W = self.add_weight( name="W", shape=( 1, self.input_n_caps // self.shared_weights, self.n_caps, self.dims_caps, self.input_dims_caps, ), initializer=self.kernel_initializer, ) self.built = True def call(self, inputs): # Calculate predictions # Note: Matmul doesn't care about batch_size (it just uses the same self.W multiple times) inputs = tf.expand_dims(tf.expand_dims(inputs, -1), 2) W_tiled = tf.tile(self.W, [1, self.shared_weights, 1, 1, 1]) predictions = tf.matmul(W_tiled, inputs) # === DYNAMIC ROUTING === raw_weights = tf.zeros([1, self.input_n_caps, self.n_caps, 1, 1]) for i in range(self.r_iter): # Line 4, computes Eq.(3) routing_weights = tf.nn.softmax(raw_weights, axis=2) # Line 5 outputs = tf.reduce_sum( routing_weights * predictions, axis=1, keepdims=True ) # Line 6 outputs = squash(outputs, axis=-2) # Line 7 if i < self.r_iter - 1: outputs_tiled = tf.tile(outputs, [1, self.input_n_caps, 1, 1, 1]) raw_weights += tf.matmul(predictions, outputs_tiled, transpose_a=True) return tf.squeeze(outputs, axis=[1, -1]), routing_weightsdef mask(inputs):"""Mask a Tensor with shape (batch_size, n_capsules, dim_vector). It can be done either by selecting the capsule with max length or by an additional input mask. The first is usually the method for testing, the second is the one for the training. Args: inputs: Either a tensor to be masked (output of the class capsules) or a tensor with both the tensor and an additional input mask""" # Mask provided? if type(inputs) is tuple or type(inputs) is list: inputs, mask = inputs[0], inputs[1] else: # Calculate the mask by the max length of capsules. x = compute_vectors_length(inputs) # Generate one-hot encoded mask mask = one_hot(indices=argmax(x, 1), num_classes=x.get_shape().as_list()[-1]) # Mask the inputs masked = batch_flatten(inputs * expand_dims(mask, -1)) return maskeddef compute_vectors_length(vecs, axis=-1):"""Compute vectors' length. This is used to compute final prediction as probabilities. Args: vecs: A tensor with shape (batch_size, n_vectors, dim_vector) Returns: A new tensor with shape (batch_size, n_vectors)""" return tf.sqrt(tf.reduce_sum(tf.square(vecs), axis) + epsilon())def squash(vectors, axis=-1):"""The non-linear activation used in Capsule, computes Eq.(1) It drives the length of a large vector to near 1 and small vector to 0. Args: vectors: The vectors to be squashed, N-dim tensor. axis: The axis to squash. 
Returns: A tensor with the same shape as input vectors, but squashed in 'vec_len' dimension.""" s_squared_norm = tf.reduce_sum(tf.square(vectors), axis, keepdims=True) scale = s_squared_norm / (1 + s_squared_norm) / tf.sqrt(s_squared_norm + epsilon()) return scale * vectors
def DRCaps(input_shape, n_class, name="CapsuleNetwork"):
    """Capsule Network model implementation, adapted from the [official paper](https://arxiv.org/abs/1710.09829).

    Arguments:
        input_shape: 3-Dimensional data shape (width, height, channels).
        n_class: Number of classes.
    """
    tf.keras.backend.clear_session()

    # --- Encoder ---
    # Input
    x = Input(shape=input_shape)

    # Layer 1
    conv1 = Conv2D(
        filters=128, kernel_size=(3, 3), strides=1, padding="valid",
        dilation_rate=(8, 8), activation="relu", name="conv1",
    )(x)

    # Layer 2
    conv2 = Conv2D(
        filters=64, kernel_size=(3, 3), strides=2, padding="valid",
        dilation_rate=(1, 1), activation="relu", name="conv2",
    )(conv1)

    # Layer 3
    conv3 = Conv2D(
        filters=128, kernel_size=(3, 3), strides=1, padding="valid",
        dilation_rate=(4, 4), activation="relu", name="conv3",
    )(conv2)

    # Layer 4
    conv4 = Conv2D(
        filters=64, kernel_size=(3, 3), strides=2, padding="valid",
        dilation_rate=(1, 1), activation="relu", name="conv4",
    )(conv3)

    # Layer 5
    conv5 = Conv2D(
        filters=64, kernel_size=(3, 3), strides=1, padding="valid",
        dilation_rate=(2, 2), activation="relu", name="conv5",
    )(conv4)

    # Layer 6
    conv6 = Conv2D(
        filters=64, kernel_size=(3, 3), strides=2, padding="valid",
        dilation_rate=(1, 1), activation="relu", name="conv6",
    )(conv5)

    # Layer 7
    conv7 = Conv2D(
        filters=64, kernel_size=(3, 3), strides=2, padding="valid",
        dilation_rate=(1, 1), activation="relu", name="conv7",
    )(conv6)

    # Layer 8: PrimaryCaps layer
    primary_caps = PrimaryCaps(
        n_caps=32, dims_caps=8, kernel_size=9, strides=2,
        padding="valid", activation="relu", name="primary_caps",
    )(conv7)

    # Layer 9: ClassCaps layer: since routing is computed only
    # between two consecutive capsule layers, it only happens here
    class_caps = DenseCaps(n_caps=n_class, dims_caps=16, name="digit_caps")(
        primary_caps
    )[0]

    # Layer 10: A convenience layer to calculate vectors' length
    vec_len = Lambda(compute_vectors_length, name="vec_len")(class_caps)

    # --- Decoder ---
    y = Input(shape=(n_class,))

    # Layer 11: Convenience layers to compute the masked capsules' output
    masked = Lambda(mask, name="masked")(
        class_caps
    )  # Mask using the capsule with maximal length. For prediction
    masked_by_y = Lambda(mask, name="masked_by_y")(
        [class_caps, y]
    )  # The true label is used to mask the output of the capsule layer. For training

    # Layers 12-15: Four Dense layers for the image reconstruction
    decoder = Sequential(name="decoder")
    decoder.add(Dense(64, activation="relu", input_dim=16 * n_class, name="dense_1"))
    decoder.add(Dense(128, activation="relu", name="dense_2"))
    decoder.add(Dense(128, activation="relu", name="dense_3"))
    decoder.add(
        Dense(tf.math.reduce_prod(input_shape), activation="sigmoid", name="dense_4")
    )

    # Layer 16: Reshape the output as the image provided in input
    decoder.add(Reshape(target_shape=input_shape, name="img_reconstructed"))

    # Models for training and evaluation (prediction)
    train_model = Model(
        inputs=[x, y], outputs=[vec_len, decoder(masked_by_y)], name=name
    )
    eval_model = Model(inputs=x, outputs=[vec_len, decoder(masked)], name=name)

    return train_model, eval_model
model_params = {
    "input_shape": X_train.shape[1:],
    "n_class": y_train.shape[1],
    "name": "model_DRCaps_S_UNIWARD_04bpp",
}

model, eval_model = DRCaps(**model_params)
model.summary()
Model: "model_DRCaps_S_UNIWARD_04bpp"__________________________________________________________________________________________________Layer (type) Output Shape Param # Connected to ==================================================================================================input_1 (InputLayer) [(None, 256, 256, 1) 0 __________________________________________________________________________________________________conv1 (Conv2D) (None, 240, 240, 128 1280 input_1[0][0] __________________________________________________________________________________________________conv2 (Conv2D) (None, 119, 119, 64) 73792 conv1[0][0] __________________________________________________________________________________________________conv3 (Conv2D) (None, 111, 111, 128 73856 conv2[0][0] __________________________________________________________________________________________________conv4 (Conv2D) (None, 55, 55, 64) 73792 conv3[0][0] __________________________________________________________________________________________________conv5 (Conv2D) (None, 51, 51, 64) 36928 conv4[0][0] __________________________________________________________________________________________________conv6 (Conv2D) (None, 25, 25, 64) 36928 conv5[0][0] __________________________________________________________________________________________________conv7 (Conv2D) (None, 12, 12, 64) 36928 conv6[0][0] __________________________________________________________________________________________________primary_caps (PrimaryCaps) (None, 128, 8) 1327360 conv7[0][0] __________________________________________________________________________________________________digit_caps (DenseCaps) ((None, 2, 16), (Non 32768 primary_caps[0][0] __________________________________________________________________________________________________input_2 (InputLayer) [(None, 2)] 0 __________________________________________________________________________________________________masked_by_y (Lambda) (None, None) 0 digit_caps[0][0] input_2[0][0] __________________________________________________________________________________________________vec_len (Lambda) (None, 2) 0 digit_caps[0][0] __________________________________________________________________________________________________decoder (Sequential) (None, 256, 256, 1) 8481088 masked_by_y[0][0] ==================================================================================================Total params: 10,174,720Trainable params: 10,174,720Non-trainable params: 0__________________________________________________________________________________________________
After replicating the model, I proceeded with the training using the hyperparameters of the DRCaps model: λ_recon = 32.768, a batch size of 32, and the Adam optimizer with the default learning rate lr = 0.001 and a learning-rate decay of 0.9 per epoch.
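For context on the value 32.768: my understanding (an assumption on my part, not something the DRCaps paper states explicitly) is that it follows the convention in Xifeng Guo's code, where the reconstruction weight is 0.0005 scaled by the number of pixels because MSE is used instead of the paper's sum of squared errors:

# Assumed origin of lam_recon: 0.0005 * number_of_pixels,
# analogous to 0.392 = 0.0005 * 784 used for 28x28 MNIST in Guo's repo
print(0.0005 * 256 * 256)  # 32.768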
The training function is adapted from the one found in this repository of GBRAS-Net, a CNN-based steganalysis model.
# Additional imports assumed for the training function
import os
import time as tm
import datetime
from tensorflow.keras import optimizers, callbacks
# path_log_base (the root directory for the logs) is set elsewhere in my notebook


def train(model, X_train, y_train, X_valid, y_valid, X_test, y_test,
          batch_size=32, epochs=50, lr=0.001, lr_decay_mul=0.9, lam_recon=32.768):
    """Train a given Capsule Network model.

    Args:
        model: The CapsuleNet model to train.
        epochs: Number of epochs for the training.
        batch_size: Size of the batch used for the training.
        lr: Initial learning rate value.
        lr_decay_mul: The value multiplied by lr at each epoch. Set a larger value for larger epochs.
        lam_recon: The coefficient for the loss of the decoder (if present).

    Returns:
        The trained model.
    """
    # Compile the model
    model.compile(
        optimizer=optimizers.Adam(lr=lr),
        loss=[margin_loss, "mse"],
        loss_weights=[1.0, lam_recon],
        metrics=["accuracy"],
    )

    start_time = tm.time()
    log_dir = path_log_base + "/" + model.name + "_" + str(
        datetime.datetime.now().isoformat()[:19].replace("T", "_").replace(":", "-")
    )
    tensorboard = tf.keras.callbacks.TensorBoard(log_dir, histogram_freq=1)
    filepath = log_dir + "/saved-model-{epoch:03d}-{val_vec_len_accuracy:.4f}-{val_decoder_accuracy:.4f}.hdf5"
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        filepath,
        monitor=['val_vec_len_accuracy', 'val_decoder_accuracy'],
        save_best_only=False,
        mode='max',
    )

    # Define a callback to reduce the learning rate
    lrcb = callbacks.LearningRateScheduler(
        schedule=lambda epoch: lr * (lr_decay_mul ** epoch)
    )

    # Define a callback to save the weights of the model
    weights_dir = os.path.join(log_dir, 'weights')
    os.makedirs(weights_dir)
    weights_filepath = weights_dir + "/saved-weights-{epoch:03d}.hdf5"
    weights_checkpoint = tf.keras.callbacks.ModelCheckpoint(
        weights_filepath,
        monitor=['val_vec_len_accuracy', 'val_decoder_accuracy'],
        save_best_only=False,
        save_weights_only=True,
        mode='max',
    )

    model.reset_states()

    # Baseline evaluation before training (stored in globals for later inspection)
    global lossTEST, vec_len_lossTEST, decoder_lossTEST, vec_len_accuracyTEST, decoder_accuracyTEST
    global lossTRAIN, vec_len_lossTRAIN, decoder_lossTRAIN, vec_len_accuracyTRAIN, decoder_accuracyTRAIN
    global lossVALID, vec_len_lossVALID, decoder_lossVALID, vec_len_accuracyVALID, decoder_accuracyVALID
    lossTEST, vec_len_lossTEST, decoder_lossTEST, vec_len_accuracyTEST, decoder_accuracyTEST = \
        model.evaluate(x=(X_test, y_test), y=(y_test, X_test), verbose=None)
    lossTRAIN, vec_len_lossTRAIN, decoder_lossTRAIN, vec_len_accuracyTRAIN, decoder_accuracyTRAIN = \
        model.evaluate(x=(X_train, y_train), y=(y_train, X_train), verbose=None)
    lossVALID, vec_len_lossVALID, decoder_lossVALID, vec_len_accuracyVALID, decoder_accuracyVALID = \
        model.evaluate(x=(X_valid, y_valid), y=(y_valid, X_valid), verbose=None)

    global history, model_Name, log_Dir, weights_Dir
    model_Name = model.name
    log_Dir = log_dir
    weights_Dir = weights_dir

    print("Starting the training...")
    history = model.fit(
        x=(X_train, y_train),
        y=(y_train, X_train),
        epochs=epochs,
        callbacks=[tensorboard, checkpoint, lrcb, weights_checkpoint],
        batch_size=batch_size,
        validation_data=((X_valid, y_valid), (y_valid, X_valid)),
        verbose=2,
    )
    metrics = model.evaluate(x=(X_test, y_test), y=(y_test, X_test), verbose=0)

    TIME = tm.time() - start_time
    print("Time " + model.name + " = %s [seconds]" % TIME)
    print("\n")
    print(log_dir)
    print(weights_dir)

    return {k: v for k, v in zip(model.metrics_names, metrics)}
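I have not shown margin_loss above; it is the standard margin loss from the CapsNet paper (Eq. 4). For reference, a typical TensorFlow version with m+ = 0.9, m− = 0.1 and λ = 0.5, along the lines of Xifeng Guo's implementation, looks like this (a sketch for completeness, not necessarily byte-for-byte what I run):

def margin_loss(y_true, y_pred):
    # L_k = T_k * max(0, m+ - ||v_k||)^2 + lambda * (1 - T_k) * max(0, ||v_k|| - m-)^2
    L = y_true * tf.square(tf.maximum(0.0, 0.9 - y_pred)) + \
        0.5 * (1 - y_true) * tf.square(tf.maximum(0.0, y_pred - 0.1))
    return tf.reduce_mean(tf.reduce_sum(L, axis=1))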
train(model, X_train, y_train, X_valid, y_valid, X_test, y_test, batch_size=32, epochs=100, lr=0.001, lr_decay_mul=0.9, lam_recon=32.768)
Starting the training...
Epoch 1/100
250/250 - 1105s - loss: 434062.7500 - vec_len_loss: 0.3945 - decoder_loss: 13246.5352 - vec_len_accuracy: 0.4999 - decoder_accuracy: 0.0018 - val_loss: 297717.2188 - val_vec_len_loss: 0.3952 - val_decoder_loss: 9085.5957 - val_vec_len_accuracy: 0.5000 - val_decoder_accuracy: 0.0055 - lr: 0.0010
Epoch 2/100
250/250 - 1104s - loss: 433786.9062 - vec_len_loss: 0.2633 - decoder_loss: 13238.1162 - vec_len_accuracy: 0.5013 - decoder_accuracy: 0.0018 - val_loss: 297716.8750 - val_vec_len_loss: 0.2208 - val_decoder_loss: 9085.5918 - val_vec_len_accuracy: 0.5000 - val_decoder_accuracy: 0.0055 - lr: 9.0000e-04
Epoch 3/100
250/250 - 1103s - loss: 433786.7500 - vec_len_loss: 0.2182 - decoder_loss: 13238.1152 - vec_len_accuracy: 0.4946 - decoder_accuracy: 0.0018 - val_loss: 297716.8438 - val_vec_len_loss: 0.2150 - val_decoder_loss: 9085.5928 - val_vec_len_accuracy: 0.5000 - val_decoder_accuracy: 0.0055 - lr: 8.1000e-04
Epoch 4/100
250/250 - 1102s - loss: 433786.7500 - vec_len_loss: 0.2152 - decoder_loss: 13238.1143 - vec_len_accuracy: 0.5036 - decoder_accuracy: 0.0018 - val_loss: 297716.8750 - val_vec_len_loss: 0.2143 - val_decoder_loss: 9085.5918 - val_vec_len_accuracy: 0.5000 - val_decoder_accuracy: 0.0055 - lr: 7.2900e-04
Epoch 5/100
250/250 - 1100s - loss: 433786.5938 - vec_len_loss: 0.2149 - decoder_loss: 13238.1094 - vec_len_accuracy: 0.5056 - decoder_accuracy: 0.0018 - val_loss: 297716.8750 - val_vec_len_loss: 0.2138 - val_decoder_loss: 9085.5918 - val_vec_len_accuracy: 0.5000 - val_decoder_accuracy: 0.0055 - lr: 6.5610e-04
Epoch 6/100
250/250 - 1100s - loss: 433786.7188 - vec_len_loss: 0.2147 - decoder_loss: 13238.1123 - vec_len_accuracy: 0.4988 - decoder_accuracy: 0.0018 - val_loss: 297716.8750 - val_vec_len_loss: 0.2212 - val_decoder_loss: 9085.5918 - val_vec_len_accuracy: 0.5000 - val_decoder_accuracy: 0.0055 - lr: 5.9049e-04
Epoch 7/100
250/250 - 1100s - loss: 433786.5625 - vec_len_loss: 0.2146 - decoder_loss: 13238.1055 - vec_len_accuracy: 0.4915 - decoder_accuracy: 0.0018 - val_loss: 297716.8750 - val_vec_len_loss: 0.2146 - val_decoder_loss: 9085.5918 - val_vec_len_accuracy: 0.5000 - val_decoder_accuracy: 0.0055 - lr: 5.3144e-04
Epoch 8/100
250/250 - 1100s - loss: 433786.6562 - vec_len_loss: 0.2145 - decoder_loss: 13238.1143 - vec_len_accuracy: 0.5064 - decoder_accuracy: 0.0018 - val_loss: 297716.8438 - val_vec_len_loss: 0.2135 - val_decoder_loss: 9085.5918 - val_vec_len_accuracy: 0.5000 - val_decoder_accuracy: 0.0055 - lr: 4.7830e-04
Epoch 9/100
250/250 - 1101s - loss: 433786.5000 - vec_len_loss: 0.2141 - decoder_loss: 13238.1064 - vec_len_accuracy: 0.4991 - decoder_accuracy: 0.0018 - val_loss: 297716.8750 - val_vec_len_loss: 0.2143 - val_decoder_loss: 9085.5918 - val_vec_len_accuracy: 0.5000 - val_decoder_accuracy: 0.0055 - lr: 4.3047e-04
Epoch 10/100
250/250 - 1102s - loss: 433786.5625 - vec_len_loss: 0.2145 - decoder_loss: 13238.1064 - vec_len_accuracy: 0.4969 - decoder_accuracy: 0.0018 - val_loss: 297716.8438 - val_vec_len_loss: 0.2135 - val_decoder_loss: 9085.5918 - val_vec_len_accuracy: 0.5000 - val_decoder_accuracy: 0.0055 - lr: 3.8742e-04
Epoch 11/100
250/250 - 1104s - loss: 433786.5938 - vec_len_loss: 0.2142 - decoder_loss: 13238.1152 - vec_len_accuracy: 0.4906 - decoder_accuracy: 0.0018 - val_loss: 297716.8438 - val_vec_len_loss: 0.2135 - val_decoder_loss: 9085.5918 - val_vec_len_accuracy: 0.5000 - val_decoder_accuracy: 0.0055 - lr: 3.4868e-04
Epoch 12/100
250/250 - 1114s - loss: 433786.5625 - vec_len_loss: 0.2142 - decoder_loss: 13238.1123 - vec_len_accuracy: 0.4956 - decoder_accuracy: 0.0018 - val_loss: 297716.8750 - val_vec_len_loss: 0.2141 - val_decoder_loss: 9085.5918 - val_vec_len_accuracy: 0.5000 - val_decoder_accuracy: 0.0055 - lr: 3.1381e-04
Epoch 13/100
250/250 - 1115s - loss: 433786.5625 - vec_len_loss: 0.2138 - decoder_loss: 13238.1113 - vec_len_accuracy: 0.4996 - decoder_accuracy: 0.0018 - val_loss: 297716.8750 - val_vec_len_loss: 0.2146 - val_decoder_loss: 9085.5918 - val_vec_len_accuracy: 0.5000 - val_decoder_accuracy: 0.0055 - lr: 2.8243e-04
Epoch 14/100
250/250 - 1115s - loss: 433786.5625 - vec_len_loss: 0.2142 - decoder_loss: 13238.1113 - vec_len_accuracy: 0.4940 - decoder_accuracy: 0.0018 - val_loss: 297716.8750 - val_vec_len_loss: 0.2142 - val_decoder_loss: 9085.5918 - val_vec_len_accuracy: 0.5000 - val_decoder_accuracy: 0.0055 - lr: 2.5419e-04
Epoch 15/100
250/250 - 1114s - loss: 433786.6562 - vec_len_loss: 0.2140 - decoder_loss: 13238.1094 - vec_len_accuracy: 0.4855 - decoder_accuracy: 0.0018 - val_loss: 297716.8750 - val_vec_len_loss: 0.2133 - val_decoder_loss: 9085.5918 - val_vec_len_accuracy: 0.5000 - val_decoder_accuracy: 0.0055 - lr: 2.2877e-04
Epoch 16/100
250/250 - 1114s - loss: 433786.6250 - vec_len_loss: 0.2140 - decoder_loss: 13238.1113 - vec_len_accuracy: 0.4985 - decoder_accuracy: 0.0018 - val_loss: 297716.8750 - val_vec_len_loss: 0.2137 - val_decoder_loss: 9085.5918 - val_vec_len_accuracy: 0.5000 - val_decoder_accuracy: 0.0055 - lr: 2.0589e-04
Epoch 17/100
250/250 - 1115s - loss: 433786.4688 - vec_len_loss: 0.2138 - decoder_loss: 13238.1074 - vec_len_accuracy: 0.4939 - decoder_accuracy: 0.0018 - val_loss: 297716.8438 - val_vec_len_loss: 0.2135 - val_decoder_loss: 9085.5918 - val_vec_len_accuracy: 0.5000 - val_decoder_accuracy: 0.0055 - lr: 1.8530e-04
Epoch 18/100
250/250 - 1116s - loss: 433786.5312 - vec_len_loss: 0.2135 - decoder_loss: 13238.1113 - vec_len_accuracy: 0.5045 - decoder_accuracy: 0.0018 - val_loss: 297716.8750 - val_vec_len_loss: 0.2136 - val_decoder_loss: 9085.5918 - val_vec_len_accuracy: 0.5000 - val_decoder_accuracy: 0.0055 - lr: 1.6677e-04
Epoch 19/100
250/250 - 1118s - loss: 433786.6250 - vec_len_loss: 0.2140 - decoder_loss: 13238.1201 - vec_len_accuracy: 0.4976 - decoder_accuracy: 0.0018 - val_loss: 297716.8750 - val_vec_len_loss: 0.2137 - val_decoder_loss: 9085.5918 - val_vec_len_accuracy: 0.5000 - val_decoder_accuracy: 0.0055 - lr: 1.5009e-04
Epoch 20/100
250/250 - 1118s - loss: 433786.6875 - vec_len_loss: 0.2139 - decoder_loss: 13238.1113 - vec_len_accuracy: 0.4969 - decoder_accuracy: 0.0018 - val_loss: 297716.8750 - val_vec_len_loss: 0.2139 - val_decoder_loss: 9085.5918 - val_vec_len_accuracy: 0.5000 - val_decoder_accuracy: 0.0055 - lr: 1.3509e-04
For the DRCaps model, the authors observed that 20 epochs were enough to obtain good performance. In my case, however, the accuracy and loss barely change: vec_len_accuracy stays around 0.50 and the losses plateau after the first epoch. The decoder loss also seems unusually high compared to the vector-length loss, which makes me believe something is wrong with my implementation.
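One thing I did check: with loss_weights=[1.0, lam_recon] from the compile call above, the reported total loss is simply the weighted sum of the two partial losses, so the huge total appears to come from the λ_recon = 32.768 weighting rather than from an extra bug in the loss computation itself:

# Sanity check on the epoch 1 training values from the log above
lam_recon = 32.768
vec_len_loss, decoder_loss = 0.3945, 13246.5352
print(vec_len_loss + lam_recon * decoder_loss)  # ~434062.9, essentially the reported 434062.75

So the scale of the total loss looks expected given the weighting; what worries me is that vec_len_accuracy never moves away from ~0.50.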
In addition, the images I used are the same as GBRAS-Net's. I also trained GBRAS-Net myself and I was able to reproduce their results.
Any input is greatly appreciated. Thank you so much.