Channel: Active questions tagged python - Stack Overflow

How to build a VAE without a reshape error?


I built a VAE following this TensorFlow blog post: https://blog.tensorflow.org/2019/03/variational-autoencoders-with.html

PROBLEM: I think the problem is in the decoder, at the point where the distribution gets reshaped. I get this error:

    Input to reshape is a tensor with 512 values, but the requested shape has 3200
         [[{{node model/sequential_12/module_wrapper_10/independent_bernoulli_7/IndependentBernoulli/Reshape}}]] [Op:__inference_train_function_7497]
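The sizes involved can be checked by hand. The following is a plain-Python sketch (no TensorFlow required) that traces the spatial dimensions through both networks, assuming the usual output-size formulas for `padding='VALID'`: `(n - k) // s + 1` for `Conv2D`/`MaxPooling2D` and `(n - 1) * s + k` for `Conv2DTranspose`. The helpers `conv_out` and `conv_transpose_out` are hypothetical names used only for illustration; the layer settings mirror the code in this question.

```python
from math import prod


def conv_out(n, kernel, stride=1):
    """Output side length of Conv2D / MaxPooling2D with padding='VALID'."""
    return (n - kernel) // stride + 1


def conv_transpose_out(n, kernel, stride):
    """Output side length of Conv2DTranspose with padding='VALID'."""
    return (n - 1) * stride + kernel


encoded_size = 10
image_shape = (10, 10, 1)

# Encoder: Conv 3x3 -> MaxPool 2 -> Conv 3x3 (64 filters) -> MaxPool 2 -> Flatten.
side = image_shape[0]
side = conv_out(side, 3)            # 10 -> 8
side = conv_out(side, 2, stride=2)  # 8 -> 4
side = conv_out(side, 3)            # 4 -> 2
side = conv_out(side, 2, stride=2)  # 2 -> 1
encoder_features = side * side * 64
# MultivariateNormalTriL needs a loc vector (d) plus a lower-triangular scale (d*(d+1)/2).
mvn_params = encoded_size + encoded_size * (encoded_size + 1) // 2

# Decoder: (1, 1, 10) -> Conv2DTranspose 2x2 stride 2 (64) -> same again (32) -> Flatten.
side = 1
side = conv_transpose_out(side, 2, 2)  # 1 -> 2
side = conv_transpose_out(side, 2, 2)  # 2 -> 4
decoder_features = side * side * 32
# IndependentBernoulli(image_shape) needs exactly one logit per pixel.
bernoulli_params = prod(image_shape)

print(encoder_features, mvn_params)        # 64 65
print(decoder_features, bernoulli_params)  # 512 100
```

The encoder is unaffected because a `Dense` layer sits between `Flatten` and `MultivariateNormalTriL`, so 64 features can be mapped to the 65 parameters the distribution needs. In the decoder, `Flatten` feeds `IndependentBernoulli` directly, so 512 values per example have to be reshaped into a (10, 10, 1) event that only holds 100, which is the `IndependentBernoulli/Reshape` node named in the error.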

My code:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfk = tf.keras
tfkl = tf.keras.layers
tfpl = tfp.layers
tfd = tfp.distributions

encoded_size = 10
input_shape = (10, 10, 1)

# Standard-normal prior over the 10-dimensional latent space.
prior = tfd.Independent(tfd.Normal(loc=tf.zeros(encoded_size), scale=1),
                        reinterpreted_batch_ndims=1)


def get_encoder(input_shape):
    layer = tfk.Sequential([
        tfkl.Conv2D(kernel_size=(3, 3), filters=32, activation='swish',
                    padding='VALID', input_shape=input_shape),  # -> (8, 8, 32)
        tfkl.MaxPooling2D(pool_size=(2, 2)),                    # -> (4, 4, 32)
        tfkl.Conv2D(kernel_size=(3, 3), filters=64, activation='swish',
                    padding='VALID'),                           # -> (2, 2, 64)
        tfkl.MaxPooling2D(pool_size=(2, 2)),                    # -> (1, 1, 64)
        tfkl.Flatten(),                                         # -> 64 values
        tfkl.Dense(tfpl.MultivariateNormalTriL.params_size(encoded_size),
                   activation=None),
        tfpl.MultivariateNormalTriL(
            encoded_size,
            activity_regularizer=tfpl.KLDivergenceRegularizer(prior)),
    ])
    return layer


def get_decoder(input_shape):
    layer = tfk.Sequential([
        tfkl.InputLayer(input_shape=[encoded_size]),
        tfkl.Reshape([1, 1, encoded_size]),
        tfkl.Conv2DTranspose(kernel_size=(2, 2), filters=64, strides=(2, 2),
                             activation='swish', padding='VALID'),  # -> (2, 2, 64)
        tfkl.Conv2DTranspose(kernel_size=(2, 2), filters=32, strides=(2, 2),
                             activation='swish', padding='VALID'),  # -> (4, 4, 32)
        tfkl.Flatten(),                                              # -> 512 values
        tfpl.IndependentBernoulli(input_shape, tfd.Bernoulli.logits),
    ])
    return layer


encoder = get_encoder((10, 10, 1))
decoder = get_decoder((10, 10, 1))
vae = tfk.Model(inputs=encoder.inputs,
                outputs=decoder(encoder.outputs[0]))
vae.summary()


# Negative log-likelihood of the input under the decoder's output distribution.
def nll(y_true, y_pred):
    return -y_pred.log_prob(y_true)


vae.compile(optimizer=tfk.optimizers.Adam(learning_rate=1e-3),
            loss=nll)
history = vae.fit(X_train, X_train,
                  epochs=15,
                  validation_data=(X_val, X_val))
```
