Exploring latent regularisation bias in disentangled representation learning
Date
2022
Abstract
The field of representation learning involves learning data representations (or features) that capture the structure and relationships implicit in data, rather than engineering them by hand. Latent (hidden) variables are high-level data representations inferred indirectly from data via statistical models. Representation learning assumes data is generated by a process characterised by higher-order generative factors. When learning disentangled latent representations, the aim is for each individual latent variable to be sensitive to changes in only one generative factor. Disentangled representations are interpretable and provide insight into what a model has learnt.
The most popular disentangled representation learning model is Higgins et al. [2017]’s β VAE. It is an extension of Kingma and Welling [2013] and Rezende et al. [2014]’s VAE, which uses variational inference (VI), framing inference as optimisation of the Evidence Lower BOund (ELBO) objective. VAEs are also Bayesian models, holding a prior belief about the latent structure before data is observed, captured by the prior latent distribution. The framework fits a latent posterior distribution to data, regularising the resultant latent distribution by keeping it “close” to the prior.
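For reference, the standard per-datapoint ELBO can be written as follows (the notation is ours, following Kingma and Welling [2013], and is not taken from the dissertation itself):

\[
\mathcal{L}(\theta, \phi; x) \;=\; \mathbb{E}_{q_\phi(z \mid x)}\!\left[\log p_\theta(x \mid z)\right] \;-\; D_{\mathrm{KL}}\!\left(q_\phi(z \mid x)\,\|\,p(z)\right),
\]

where \(q_\phi(z \mid x)\) is the approximate latent posterior (encoder), \(p_\theta(x \mid z)\) the likelihood (decoder), and \(p(z)\) the latent prior; maximising the ELBO trades reconstruction quality against closeness of the posterior to the prior.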
The closeness between the latent prior and posterior is enforced by penalising the ELBO with a measure of dissimilarity between these distributions called the Kullback–Leibler (KL) divergence. The latent posterior distribution is thus encouraged not to deviate too far from the latent prior distribution. Notably, the KL divergence is asymmetric, so swapping its argument distributions yields a different measure. Furthermore, the chosen direction of the KL divergence results in different behaviour when the divergence is minimised, which in turn encourages different latent posterior solutions.
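As a reminder (again in our own notation), the KL divergence of a distribution \(q\) from a distribution \(p\) is

\[
D_{\mathrm{KL}}\!\left(q \,\|\, p\right) \;=\; \mathbb{E}_{z \sim q}\!\left[\log \frac{q(z)}{p(z)}\right],
\]

and in general \(D_{\mathrm{KL}}(q \,\|\, p) \neq D_{\mathrm{KL}}(p \,\|\, q)\), which is why the direction in which it is applied matters.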
The reverse KL divergence is typically used in VI and encourages “mode-seeking” behaviour in the latent space, favouring under-dispersed solutions relative to the latent prior. Conversely, the forward KL divergence results in “mean-matching” behaviour in the latent space, favouring over-dispersed solutions relative to the latent prior. The β VAE applies a β hyperparameter as a weighting factor to the reverse KL divergence term in the ELBO, resulting in hyperparameter-constrained reverse KL latent regularisation. It thus exhibits “mode-seeking” behaviour, favouring under-dispersed solutions relative to the latent prior.
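Concretely, with the reverse KL being \(D_{\mathrm{KL}}(q_\phi(z \mid x)\,\|\,p(z))\) and the forward KL being \(D_{\mathrm{KL}}(p(z)\,\|\,q_\phi(z \mid x))\), the β VAE objective of Higgins et al. [2017] weights the reverse direction:

\[
\mathcal{L}_{\beta}(\theta, \phi; x) \;=\; \mathbb{E}_{q_\phi(z \mid x)}\!\left[\log p_\theta(x \mid z)\right] \;-\; \beta\, D_{\mathrm{KL}}\!\left(q_\phi(z \mid x)\,\|\,p(z)\right),
\]

where β > 1 increases the pressure on the latent posterior to match the prior.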
Our study assesses the impact of the KL divergence direction on the resultant latent posterior. Since the goal of disentangled representation learning is latent variables that capture generative factors in an isotropic manner, we assess the impact multimodal generative spaces have on the resultant posterior solution when the KL divergence direction is varied. To facilitate this investigation, we extend Higgins et al. [2017]’s β VAE to include an additional hyperparameter-constrained forward KL latent regulariser, yielding our model, the βγ VAE. Furthermore, we construct a collection of supervised datasets, called mSprites, each with a different number of generative space modes. Finally, impacts in the study are assessed using information-theoretic disentangling metrics.
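The abstract does not state the βγ VAE objective explicitly; a plausible reading of the description above, with γ as our label for the weight on the additional forward KL term, would be

\[
\mathcal{L}_{\beta\gamma}(\theta, \phi; x) \;=\; \mathbb{E}_{q_\phi(z \mid x)}\!\left[\log p_\theta(x \mid z)\right] \;-\; \beta\, D_{\mathrm{KL}}\!\left(q_\phi(z \mid x)\,\|\,p(z)\right) \;-\; \gamma\, D_{\mathrm{KL}}\!\left(p(z)\,\|\,q_\phi(z \mid x)\right),
\]

i.e. the β VAE objective with a second, forward KL latent regulariser; the exact formulation used in the dissertation may differ.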
When using the reverse KL for latent regularisation, we find that multimodal generative spaces distort the overall information content captured by the learnt representation. This is related to the mode-seeking behaviour of the reverse KL, as evidenced by reduced fit as generative space modality increases. This distortion may be remedied by introducing an additional constrained forward KL for latent regularisation, as done in our βγ VAE. The impact of multimodal generative spaces on disentangled representation learning, however, is less clear. Our study provides evidence that multimodal generative spaces negatively distort axis alignment between latent and generative dimensions. However, it is not clear that this necessarily hinders disentangled representation learning. Finally, we observe that while our βγ VAE can improve some disentangled representation learning metrics, eliminating the impact of multimodal generative spaces, it does not do so consistently across all disentangling metrics, and its results are thus less robust. Our βγ VAE is therefore suitable for representation learning, with inconsistent evidence that it may also be useful in disentangled representation learning.
Our findings are summarised as follows (the related research question addressed is given in brackets):
· 6.2 Multimodal generative spaces distort the global latent fit of the β VAE, particularly at lower values of β (3.2.1)
· 6.3 Multimodal generative factors do not consistently distort disentangled representation learning, but improve axis alignment (3.2.2)
· 6.4 The βγ VAE has consistently better global latent fit than the β VAE and eliminates the negative impacts of multimodal generative spaces (3.2.3)
· 6.5 The βγ VAE does not consistently learn better disentangled representations than the β VAE, nor does it consistently eliminate the impact of multimodal generative spaces (3.2.4)
Our key contributions are: mSprites, a series of supervised datasets designed to investigate the impact of multimodal generative spaces on disentangled representation learning; the βγ VAE, a model with proven merit for representation learning that uses an additional constrained forward KL for latent regularisation; and empirical evidence validating the findings mentioned above.
Description
A dissertation submitted in fulfilment of the requirements for the degree of Master of Science in Computer Science to the Faculty of Science, University of the Witwatersrand, Johannesburg, 2022
Keywords
Representation learning, Learning data representations, Latent (hidden) variables