Abstract
Medical imaging research is often limited by data scarcity and availability. Governance, privacy concerns and the cost of acquisition all restrict access to medical imaging data, which, compounded by the data-hungry nature of deep learning algorithms, limits progress in the field of healthcare AI. Generative models have recently been used to synthesize photorealistic natural images, presenting a potential solution to the data scarcity problem. But are current generative models synthesizing morphologically correct samples? In this work we present a three-dimensional generative model of the human brain that is trained at the necessary scale to generate diverse, realistic-looking, high-resolution and morphologically preserving samples, and that is conditioned on patient characteristics (for example, age and pathology). We show that the synthetic samples generated by the model preserve biological and disease phenotypes and are realistic enough to permit use downstream in well-established image analysis tools. While the proposed model has broad future applicability, such as anomaly detection and learning under limited data, its generative capabilities can be used to directly mitigate data scarcity and limited data availability and to improve algorithmic fairness.
Main
In computer vision, the rapid progress of deep learning methods was underpinned by huge datasets such as ImageNet1, TextOCR2 and COCO3, which contain around 1.2 million, 1 million and 170,000 samples respectively. Current volumetric medical imaging datasets pale in comparison: UK Biobank (UKB)4,5, one of the largest datasets available, has approximately 40,000 samples, the Alzheimer’s Disease Neuroimaging Initiative (ADNI)6 has roughly 3,000 images and Medical Segmentation Decathlon7 roughly 3,000 images across 10 organs. This limitation, combined with the 3D nature of medical images, results in datasets that do not cover all anatomical, pathological and signal variations. At present, state-of-the-art medical imaging algorithms are developed on highly curated datasets, leading to potential biases regarding demographics and acquisition parameters that may adversely affect particular populations. Owing to the restrictive nature and volume of medical imaging data, deep learning models are often limited in scale, further hindering the deployment of successful research in clinical environments. To mitigate the need for specialized and costly acquisition equipment, as well as the restrictions entailed by regulations and complex maintenance8, generative modelling can be seen as a viable solution. Not only might it allow the open sourcing of a large corpus of data, but it also enables the prevalence of confounding variables such as pathologies and ethnicities to be balanced.
Variational autoencoders (VAEs)9 are the classic baselines for generative modelling in computer vision, but they are known to suffer from blurry reconstructions. Generative adversarial networks (GANs) are the current state of the art in generative modelling of the brain10,11,12. The generator tries to learn the underlying distribution of images by generating realistic-looking images, aiming to fool the discriminator into labelling them as real images13. The main drawbacks of current approaches are the commonly known pitfalls of GANs, including unstable training regimes, mode collapse (always generating the same images), failure to converge14 and the lack of mechanisms for fine-grained conditioned generation. The α-Wasserstein GAN (α-WGAN)11 and cycle consistent embedding GAN (CCE-GAN)10 models rely on an encoder–decoder architecture in which images are encoded into a latent space and then decoded; additional synthetic images can then be decoded from random noise. This architecture aims to bring the real image latent representations close to random noise, such that after training randomly sampled noise will be decoded into meaningful images. On the other hand, HA-GAN12 relies on a small-resolution generator and discriminator alongside slice-based high-resolution ones. The generators and discriminators partially share weights, thus enabling high-resolution sampling during inference. α-WGAN and CCE-GAN lack any form of conditioning, while HA-GAN12 includes class-based conditioning. These models therefore offer either basic or non-existent conditioning, and none has quantified how well morphology is preserved in the synthetic samples—a paramount trait if such methods are to be used.
The most popular application of generative models in medical imaging is data augmentation, mainly for classification and segmentation tasks15, with the aim of increasing the diversity of the phenotype by grounding the generative process onto a ground truth segmentation or inpainting segmentations of different pathologies onto healthy subjects from the training dataset16. However, none of those methods truly expand the healthy phenotype manifold or add pathological phenotypes. A fully unsupervised 2D generative model of both human brain parcellations and corresponding images has recently been reported17.
A subject’s phenotype is determined by several covariates, such as demographic characteristics and the presence or absence of pathologies. When analysed together, multiple subjects with the same covariates of interest will determine a population-level morphological statistic. In line with the synthetic data desiderata from ref. 18, any usable synthetic dataset should share most, if not all, of the statistical properties of the real one. Towards that end, a synthetic model must have a controllable sampling mechanism. While current state-of-the-art generative models of natural images19,20 have demonstrated astonishing control over the generation of synthetic samples, their ability to preserve the morphology of the generated objects remains to be validated. Synthetic images are often validated using qualitative human evaluation and relatively simple classifier-based metrics such as the Fréchet inception distance (FID)21, which are insufficient when applied to medical data. In ref. 22, image moments were used to assess pixel distribution alignment and a reader study was used to assess the realism of synthetic medical samples; however, no quantitative morphological analysis was performed. Without validating the preservation of patient and disease morphology, one would not be able to trust downstream data use and subsequent analysis.
In natural image synthesis, the combination of a vector-quantized VAE (VQ-VAE) and a transformer has recently23,24 been shown to generate high-resolution realistic images. These models work by projecting an image into a quantized latent space in which the transformer then learns the conditional probability of the tokens in an autoregressive fashion. Afterwards, the trained transformer can synthesize new latent samples and pass them to the VQ-VAE to be decoded. In our previous work25,26 we showed that the pipeline depicted in Fig. 1 can be used in clinical settings for anomaly detection, which represents a much more locally restricted task of token inpainting. In this work, we show that such models can be scaled up to generate realistic-looking, morphologically preserving synthetic samples. To assess their realism, we use classic metrics such as the FID21 and maximum mean discrepancy (MMD)27, with image diversity being assessed through the multi-scale structural similarity index measure (MS-SSIM)28 and four-gradient structural similarity index measure (4-G-SSIM)29,30, as used in ref. 11. To study the morphology of the generated data, we compare the tissue segmentations and subcortical volumes generated by SynthSeg31 between the synthetic and real samples, assess the distribution alignment between the cortical thicknesses of synthetic and real samples as measured by FastSurfer32 and look at the focal differences between real and synthetic subpopulations using voxel-based morphometry (VBM)33. We will make both the trained models and code available to the research community, together with two sampled synthetic datasets: a UKB-based4 dataset of 100,000 healthy participants and an ADNI-based6 dataset of 1,000 pathological and cognitively normal participants.
The VQ-VAE and transformer two-stage training and inference pipeline is shown. During the VQ-VAE training (blue arrow), a Codebook representation is learned so as to minimise the reconstruction loss between the Input image and the Output image. For stability, a consistency loss is applied to the Codebook elements with respect to the Encoded image. For the Transformer training, the autoregressive conditional generation is learned via a cross-entropy loss applied to the Rasterized code. For inference, the Transformer is conditioned on the variables of interest and generates one token at a time; once the whole sequence is generated, it is reshaped into a Tokenized encoded image and fed through the Codebook to the Decoder to obtain the Output image.
Results
We used two datasets, the UKB dataset formed of 39,679 neurologically healthy participants and the ADNI dataset, which encompasses 765 unique participants. Further details on the datasets, preprocessing and augmentation can be found in Supplementary Section A.
The methods we compare against are the state-of-the-art medical imaging generative models Hierarchical Amortized GAN (HA-GAN)12, CCE-GAN10 and Least Squares GAN (LS-GAN)34. In the ‘Quantitative image fidelity evaluation’ section we present a quantitative evaluation of their realism, and in the ‘Morphological evaluation’ section we assess the morphological correctness of the synthetic samples. Finally, we provide a detailed ablation study of our pipeline, showcasing how each design choice and the model’s scale influenced the results. This can be found in Supplementary Section B.
All baselines underwent a 25-experiment grid search to optimize them on our datasets and were trained on one NVIDIA A100 DGX Superpod with the maximum batch size possible for 20,000 iterations. In line with ref. 12, CCE-GAN was modified to work at an image size of 128³ and its results were upsampled. The same augmentation pipeline was applied to all models, while the intensity thresholding and normalization transformations were configured according to each baseline’s official implementation. The T1-weighted VQ-VAE models were trained on one NVIDIA A100 DGX Superpod, while the FLAIR and \(T_{2}^{\,*}\) VQ-VAE models trained for the generalizability experiments in the ‘Quantitative image fidelity evaluation’ section were trained on a single NVIDIA V100 32 GB card. The small transformers were trained on four NVIDIA A100 DGX Superpods while the big transformers were trained on eight NVIDIA A100 DGX Superpods.
Quantitative image fidelity evaluation
To evaluate sample realism, we trained the models on all of the available data and, where possible, we sampled in a controlled manner. The synthetic datasets had the same number of samples as their real counterparts to guarantee a fair comparison. As shown in Table 1, the proposed model outperformed all of the baselines by a wide margin, up to two orders of magnitude in the case of FID and MMD. The diversity measured by the MS-SSIM and 4-G-SSIM was roughly the same. This should be considered together with the poor image sampling quality displayed in Extended Data Fig. 1 and Fig. 2 for the other baselines. Altogether, our model generated sharper images that better adhere to the dataset’s underlying distribution. The HA-GAN sampling quality came closest to that of the proposed model, but it had apparent artefacts within the white matter, which did not align with the real morphology, as showcased in the ‘Morphological evaluation’ section. While LS-GAN provided a better FID than CCE-GAN, it lacked distribution alignment, as shown by the MMD. This discrepancy can be attributed to CCE-GAN working in a 128³ downsampled space, which was then upsampled for the quantitative metric calculations. Furthermore, manual inspection revealed that LS-GAN showed signs of mode collapse owing to its reduced sampling diversity compared with CCE-GAN.
Random synthetic samples from the LS-GAN, CCE-GAN and HA-GAN baselines trained on ADNI, together with our proposed model trained on ADNI and a real participant from the dataset. All three planes of visualization (axial, coronal and sagittal) are presented. Additionally, an axial zoomed-in visualization of the cerebellum is shown because its higher number of cortical folds entails more high-frequency detail. All the visualization planes are from the same synthetic samples and real participant.
To gauge how well the pipeline generalized, we trained a big VQ-VAE and a small transformer on the \(T_{2}^{\,*}\) and FLAIR images from the UKB dataset. Again, we sampled the same number of synthetic samples as their real counterparts to guarantee a fair comparison. As shown in Table 2 and Extended Data Fig. 2, the pipeline showed generalizability potential on the FLAIR results but underperformed on the \(T_{2}^{\,*}\) images. The main culprit was the difference in VQ-VAE scale between the T1-weighted model and the other models. This manifested as an underutilized latent representation, as measured by the average perplexity of the VQ-VAE’s codebook elements for the \(T_{2}^{\,*}\) and FLAIR models when compared with the T1-weighted one. Quantitatively, this resulted in 4-G-SSIM values of 0.392 ± 0.019 and 0.513 ± 0.024 for \(T_{2}^{\,*}\) and FLAIR, respectively. Qualitatively, manual inspection showed that this resulted in a lower cortical fold diversity for the FLAIR model. In the case of the \(T_{2}^{\,*}\) model, as observed in Extended Data Fig. 2, the VQ-VAE was unable to model the super-resolved anisotropic images. The \(T_{2}^{\,*}\) images were originally acquired at a voxel resolution of 0.76 mm × 0.76 mm × 1 mm and super-resolved during preprocessing to 1 mm × 1 mm × 1 mm. Overall, with additional improvements to the codebook update strategy, this pipeline should yield good generalizability.
Morphological evaluation
VBM
To assess whether the focal differences between population subgroups were preserved in the synthetic data, we used VBM33 as implemented in the SPM software package35 (https://www.fil.ion.ucl.ac.uk/spm/). VBM identifies morphological differences in modulated tissue compartments between the selected groups, after spatial alignment, through a generalized linear model and associated statistical tests across all voxels. For a detailed description of how VBM works, please see ref. 36. In line with ref. 37, all t-statistics maps were corrected to minimize the spurious effects of low-variance areas. For this experiment, two datasets were created. First, for the healthy dataset, we aimed to assess the morphological differences between a small-ventricle population and a big-ventricle population, defined on the basis of the ventricular cerebrospinal fluid (CSF) segmentation. This experiment was chosen because the expected pattern of differences should be trivial and localized in the ventricular region. Specifically, the small-ventricle population was formed of 160 random participants from the first quintile, while the big-ventricle population was formed of 160 participants from the last (fifth) quintile. Second, for the pathological dataset, we evaluated the differences between 145 cognitively normal participants and 185 participants with Alzheimer’s disease, defined by their clinical diagnoses; the grey matter segmentation was chosen to visualize the difference between populations.
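For intuition, the voxel-wise group comparison at the core of VBM can be sketched as below with NumPy/SciPy on hypothetical arrays; the analysis reported here used the full SPM pipeline (segmentation, spatial normalization, modulation, smoothing, a generalized linear model with covariates and the correction of ref. 37), so this is only a conceptual illustration.

import numpy as np
from scipy import stats

# Hypothetical inputs: smoothed, modulated tissue maps already in a common space,
# stacked as (n_subjects, x, y, z) arrays (tiny spatial shapes used for illustration).
rng = np.random.default_rng(0)
small_ventricles = rng.normal(0.4, 0.05, size=(160, 16, 16, 16))
big_ventricles = rng.normal(0.5, 0.05, size=(160, 16, 16, 16))

# Voxel-wise two-sample t-test across subjects; SPM instead fits a generalized
# linear model with covariates and corrects for multiple comparisons.
t_map, p_map = stats.ttest_ind(big_ventricles, small_ventricles, axis=0)

# Keep only voxels passing a nominal uncorrected threshold.
t_thresholded = np.where(p_map < 0.001, t_map, 0.0)
print(t_thresholded.shape, float(np.abs(t_thresholded).max()))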
In contrast to the proposed method, which was conditioned on the actual ventricular volume, HA-GAN was conditioned on a discretized class of ventricular sizes, determined by a 5-quantile split of the healthy dataset, and on the cognitively normal/Alzheimer’s disease binary label for the pathological dataset. This is due to a limitation of HA-GAN, which requires class-based conditioning. While our model and HA-GAN were sampled in a controlled manner, LS-GAN and CCE-GAN do not allow for conditioning, so a separate model was explicitly trained for each population. Note that, owing to the conditional sampling capabilities of our model, its VBM analysis factors out age and sex. The template scripts used for VBM can be found in Supplementary Section D.
Extended Data Fig. 3 and Fig. 3 show that the proposed model is in significantly better agreement with the real data than competing methods. As shown in Extended Data Fig. 3, for the healthy dataset, our model captured the differences in the ventricles, subgenual area and partially in the left and right insulae. However, HA-GAN could not sample in a controlled manner and could not model the ventricular size, and also overemphasized the CSF between the meninges and the brain. LS-GAN and CCE-GAN were unable to properly model the brain’s structure, resulting in uninformative VBM maps. Our model, however, as shown in Fig. 3, was in near-perfect agreement with the real data for the pathological dataset as it captured the overall morphological differences in the putamen, hippocampal area, temporal gyrus, inferior occipital gyrus and fusiform gyrus. Even though the other baselines seem to capture the morphological differences, they were more in line with the general atrophy of the brain due to ageing, lacking the focal increase in t-statistics showcased in the real data.
Maps for all models and real data are shown. The displayed t-statistics range is [0, 8] (colour scale) and is based on the VBM of the real data. The t-statistics were corrected following the procedure used in ref. 37.
Subcortical volume analysis
To further validate the morphology of the data, we compared regional volumes between real and synthetic data as estimated by SynthSeg31 (https://github.com/BBillot/SynthSeg). All images were segmented into grey matter, white matter and CSF. Table 3 shows that the proposed model adhered better to the original tissue volume distributions than the other baselines. We also ran a two-sided Mann–Whitney U-test38 between each model’s segmentation distributions and the real ones, and computed Glass’s Δ effect size. We evaluated the distribution shift through the Wasserstein distance and Kullback–Leibler divergence. The results in Table 3 and Extended Data Table 1 show that our model’s synthetic samples adhered better to the real distribution. This can be seen by looking specifically at the Wasserstein distance and Kullback–Leibler divergence, for which our model and HA-GAN had the best values across all datasets. The UKB results show that our model outperformed the HA-GAN model, only underperforming on the Wasserstein distance of the CSF tissue. On the ADNI dataset, HA-GAN showed superior performance that could be attributed to the lack of diversity in our model, as seen in the MS-SSIM and 4-G-SSIM values in Table 1.
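The distribution-shift measures used above reduce to standard SciPy calls; a sketch on hypothetical per-subject grey matter volumes (the values in ml are placeholders):

import numpy as np
from scipy.stats import wasserstein_distance, entropy, mannwhitneyu

rng = np.random.default_rng(0)
real_vol = rng.normal(650.0, 60.0, size=1000)     # hypothetical real GM volumes (ml)
synth_vol = rng.normal(640.0, 55.0, size=1000)    # hypothetical synthetic GM volumes

# Wasserstein distance works directly on the two sets of samples.
wd = wasserstein_distance(real_vol, synth_vol)

# Kullback-Leibler divergence needs discretized densities on a shared support.
bins = np.histogram_bin_edges(np.concatenate([real_vol, synth_vol]), bins=50)
p, _ = np.histogram(real_vol, bins=bins, density=True)
q, _ = np.histogram(synth_vol, bins=bins, density=True)
kl = entropy(p + 1e-8, q + 1e-8)

# Two-sided Mann-Whitney U-test and Glass's delta effect size
# (difference in means scaled by the real group's standard deviation).
u_stat, p_val = mannwhitneyu(real_vol, synth_vol, alternative="two-sided")
glass_delta = (synth_vol.mean() - real_vol.mean()) / real_vol.std(ddof=1)
print(wd, kl, p_val, glass_delta)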
As downstream tasks do not rely solely on tissue statistics, we also analysed the individual subcortical volumes of the SynthSeg parcellations (Extended Data Table 2). We ranked the models by calculating the average rank across all subcortical regions as measured by the Wasserstein distance between the synthetic and real populations. As can be seen in Extended Data Table 2, our method had the best rank across UKB and ADNI datasets. Those results, combined with the VBM results in Extended Data Fig. 3 and Fig. 3, show that our model had better global and local morphological preservation.
Given that downstream tasks can fail silently by giving results that do not adhere to the expected behaviour, we quantified the failure rate of the SynthSeg pipeline. This was done by measuring the proportion of synthetic samples with at least one region with an absolute Z score greater than 5, meaning that they fall outside the 99.99th percentile. For any region, a failure rate above 0.01% can be attributed to the model’s performance. The results for the SynthSeg pipeline can be found in Extended Data Table 2; our model outperformed all of the baselines as it had the failure rate closest to that of the real samples. A failure rate lower than that of the real data can be explained by the model not covering the extreme modes of the real datasets that fall outside the 99.99th percentile. An extreme case of this can be seen for the CCE-GAN model, which achieved a 0% failure rate. At the other extreme is LS-GAN, which, in line with the quantitative results in Table 1, achieved a 100% failure rate.
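The failure-rate computation described above can be sketched as follows; the array shapes and reference statistics are hypothetical:

import numpy as np

def failure_rate(volumes, ref_mean, ref_std, z_thresh=5.0):
    """Fraction of samples with at least one region whose |Z| exceeds z_thresh.

    volumes:  (n_samples, n_regions) regional volumes, for example from SynthSeg.
    ref_mean, ref_std: (n_regions,) statistics of the real population.
    """
    z = (volumes - ref_mean) / ref_std
    failed = np.any(np.abs(z) > z_thresh, axis=1)
    return failed.mean()

rng = np.random.default_rng(0)
real = rng.normal(0.0, 1.0, size=(1000, 32))        # placeholder real measurements
synthetic = rng.normal(0.0, 1.1, size=(1000, 32))   # placeholder synthetic measurements
print(failure_rate(synthetic, real.mean(axis=0), real.std(axis=0)))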
Cortical thickness analysis
Cortical thickness estimation is one of the most sensitive biomarkers available for assessing a wide range of brain conditions, from ageing to neurological disorders39. We used FastSurfer32 (https://github.com/Deep-MI/FastSurfer) on 100 randomly selected cognitively normal participants from the ADNI dataset to estimate the mean cortical thickness. FastSurfer is a convolutional neural network-based tool for surface-based thickness analysis that parcellates the cortex and quantifies the cortical thickness of 32 regions per hemisphere. We used the same methodology as in the ‘Subcortical volume analysis’ section for the FastSurfer cortical parcellations. As can be seen in Extended Data Table 3, our model had the best performance and the failure rate closest to that of the real data. Having the best rank in both Extended Data Table 2 and Extended Data Table 3 could be attributed to better modelling of the variability between participants and to structural coherence. These results further reinforce the morphological preservation capabilities of our model. Interestingly, LS-GAN had a better failure rate than HA-GAN despite not performing as well morphologically or quantitatively.
Train on synthetic, test on real analysis
We trained a simple fully convolutional network (SFCN)40 to predict the age of the synthetic samples, and used the SynthSeg segmentations to calculate their ventricular size. Afterwards, we calculated the Pearson correlation coefficient between the input conditioning of the given samples and the regressed age and ventricular size. The correlations and results are illustrated in Extended Data Fig. 4. Looking at the correlations between the conditional age and the predicted age, and between the conditional ventricular size and the SynthSeg-calculated ventricular size, we found a statistically significant (P < 0.001, 0.0005 Bonferroni corrected) but still relatively weak correlation of 0.37 for age and a moderate correlation of 0.47 for ventricular size. We also found that smaller transformer models had poorer correlations during internal testing, suggesting that model capacity might limit performance.
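This conditioning-fidelity check amounts to a Pearson correlation between the values used to condition the sampler and the values recovered from the resulting synthetic images; a sketch with placeholder data:

import numpy as np
from scipy.stats import pearsonr

# cond_age: the ages passed as conditioning when sampling; pred_age: ages regressed
# from the resulting synthetic images by the SFCN (placeholder values here).
rng = np.random.default_rng(0)
cond_age = rng.uniform(45, 80, size=1000)
pred_age = cond_age + rng.normal(0, 12, size=1000)

r, p_value = pearsonr(cond_age, pred_age)
print(f"Pearson r = {r:.2f}, P = {p_value:.2e}")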
We quantified the clinical viability and usefulness of the proposed model following the train on synthetic, test on real paradigm outlined in ref. 41. The SFCNs40 were cross-validated to regress age on the UKB dataset. Each SFCN’s parameters were first optimized and evaluated on the synthetic samples. Afterwards, a new model with the same parameters was trained and evaluated on the real dataset. Each synthetically trained model was also evaluated on the full real data to assess its generalizability without any fine-tuning. While the SFCN model trained on real data and tested on real data achieved a mean absolute error (presented as mean ± s.d.) for regressing age of 3.05 ± 2.47 yr, the model trained on synthetic and tested on synthetic data achieved a mean absolute error of 4.05 ± 3.74 yr. When the synthetically trained SFCN was tested on real data, it achieved a mean absolute error of 5.34 ± 3.73 yr. The proposed model therefore shows promise for application within the train on synthetic, test on real paradigm41. The SFCN model trained on synthetic healthy UKB data generated by our model had results comparable to those of the model trained on the real dataset. The drop in performance could be attributed to the degree of compression needed for the transformer and to the imperfect correlations between the conditioning variables and the sampled images, as shown in Extended Data Fig. 4. This places a limit on the performance that a synthetically trained model can achieve on real data.
Discussion
In this work we developed a deep generative model capable of creating morphologically preserving, realistic 3D images of the brain. The quantitative analysis shows that the proposed model achieved state-of-the-art performance in image synthesis. At the same time, the voxel-based morphometry, together with the subcortical and cortical analyses, demonstrated the superior morphological preservation of our model. Our train on synthetic, test on real experiments suggest that synthetic data could one day replace real data for AI model training in privacy-sensitive fields such as healthcare. The main limitations of this work are the application of the quantitative and morphological analyses to a single organ, the brain, and to a single imaging modality, MRI. We hope that this work paves the way towards a more principled evaluation of synthetic data in the field of healthcare, and more specifically brain imaging, by incorporating morphological analyses alongside the classical quantitative ones.
Methods
First, we review how the VQ-VAE and transformer pipeline works. Following that, we introduce our VQ-VAE and transformer. Lastly, we detail the implementation of the model (see Fig. 1 for the model’s architecture).
Background
VQ-VAEs42,43 can successfully synthesize high-resolution natural images23,44. VQ-VAEs are composed of an encoder Enc that takes as input an image \(X\in {{\mathbb{R}}}^{H\times W\times D}\) and projects it to a smaller latent representation \(Z\in {{\mathbb{R}}}^{h\times w\times d\times {n}_{z}}\), where H, W, D and h, w, d are the height, width and depth of the input image and of the latent representation, respectively, and nz is the dimensionality of the latent embedding vectors. Afterwards, Z is passed through the quantization block Quant, where an element-wise quantization is performed. Each spatial code \({Z}_{ijk}\in {{\mathbb{R}}}^{{n}_{z}}\) is replaced by its nearest codebook element \({{e}}_{m}\in {{\mathbb{R}}}^{{n}_{z}},m\in 1,\ldots ,M\), where the vocabulary’s size is denoted M, thus obtaining Zq, the quantized representation. The codebook’s elements are learned in an online fashion through an exponential moving average (EMA) as part of the VQ-VAE training procedure. Given Zq, the decoder Dec tries to reconstruct the observations \(\hat{X}\in {{\mathbb{R}}}^{H\times W\times D}\). This is outlined by the blue flow path in Fig. 1.
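A minimal sketch of the quantization block Quant, assuming for illustration a codebook of 256 vectors and the 10 × 14 × 10 × 32 latent size used here (the EMA codebook update, encoder and decoder are omitted):

import torch

def quantize(z, codebook):
    """Replace each spatial code by its nearest codebook element.

    z:        (h, w, d, n_z) encoder output Z.
    codebook: (M, n_z) embedding vectors e_m.
    Returns the quantized codes Z_q and their integer indices Z_iq.
    """
    flat = z.reshape(-1, z.shape[-1])             # (h*w*d, n_z)
    dist = torch.cdist(flat, codebook)            # pairwise distances, (h*w*d, M)
    indices = dist.argmin(dim=1)                  # nearest codebook index per code
    z_q = codebook[indices].reshape(z.shape)      # Z_q
    return z_q, indices.reshape(z.shape[:-1])     # Z_iq

z = torch.randn(10, 14, 10, 32)                   # hypothetical encoder output
codebook = torch.randn(256, 32)                   # hypothetical codebook
z_q, z_iq = quantize(z, codebook)
print(z_q.shape, z_iq.shape)                      # (10, 14, 10, 32), (10, 14, 10)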
In the second stage, a generative model Gen is trained on the discrete latent representation. The representation is obtained by replacing the codebook elements of Zq with their respective indices, thus obtaining Ziq. In refs. 42,43, a PixelSNAIL45 autoregressive model was originally employed. Later, ref. 23 showed improved performance by replacing PixelSNAIL with a transformer46. As transformers work on sequences, Ziq is flattened in a row-major fashion, thereby obtaining Siq. The transformer is then used to model Siq autoregressively by learning the conditional distribution \(p({{S}_{{\mathrm{i}}q}}_{j}\,| \,{{S}_{{\mathrm{i}}q}}_{ < j},c)\), where p is the probability, \({{S}_{{\mathrm{i}}q}}_{j}\) is the jth element of Siq and c are the conditioning variables. The green flow path in Fig. 1 shows this, together with an on-the-fly augmentation in which augmented images are passed through the encoder and quantization before flattening, such that Siq benefits from the augmentation.
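A sketch of how a discrete latent becomes a transformer training example, assuming a hypothetical start-of-sequence token, an arbitrary vocabulary size and a plain next-token cross-entropy objective:

import torch
import torch.nn.functional as F

SOS = 0                                           # hypothetical start-of-sequence token id
z_iq = torch.randint(1, 257, (10, 14, 10))        # discrete latent Z_iq (token ids)

s_iq = z_iq.reshape(-1)                           # row-major flattening -> S_iq, length 1,400
inputs = torch.cat([torch.tensor([SOS]), s_iq[:-1]])   # shift right by one position
targets = s_iq

# logits would come from transformer(inputs, conditioning); random stand-in here.
logits = torch.randn(s_iq.numel(), 257)
loss = F.cross_entropy(logits, targets)           # next-token prediction objective
print(loss.item())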
The transformer46 is composed of multiple layers, each equipped with a (self-)attention mechanism. The attention mechanism can capture the interactions between the elements of Siq regardless of their relative positions. This is achieved by projecting the intermediary representation into three tensors: query, key and value; and can be written as follows:

\[{\mathrm{Attention}}(Q,K,V\,)={\mathrm{softmax}}\left(\frac{Q{K}^{T}}{\sqrt{{d}_{K}}}\right)V\qquad (1)\]

where T stands for transpose, Q is the query tensor, K is the key tensor, V is the value tensor and dK is the dimension of K.
The projection happens once per head, and each head learns to attend to different concepts within the sequence. During training, the model aims to predict each element \({{S}_{{\mathrm{i}}q}}_{j}\) based on \({{S}_{{\mathrm{i}}q}}_{ < j}\) and c. This is enforced by the autoregressive training method, in which the attention mechanism is masked such that each position only attends to the elements before it. The conditioning is applied at every other attention block, where the inputs to the key and value layers are replaced by a vector formed by concatenating the arbitrary conditioning variables. This approach is known as cross-attention and allows the network to attend to the offered conditionings, thus enabling controlled sampling during inference46,47.
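A single-head sketch of the masked self-attention and the cross-attention conditioning described above; the weight matrices, dimensions and conditioning embeddings are placeholders:

import math
import torch

def attention(x, w_q, w_k, w_v, context=None, causal=True):
    """x: (L, D) sequence; context: (C, D) conditioning used for cross-attention."""
    kv_input = x if context is None else context
    q, k, v = x @ w_q, kv_input @ w_k, kv_input @ w_v
    scores = q @ k.T / math.sqrt(k.shape[-1])                # Q K^T / sqrt(d_K)
    if causal and context is None:
        mask = torch.triu(torch.ones_like(scores), diagonal=1) > 0
        scores = scores.masked_fill(mask, float("-inf"))     # attend only to the past
    return torch.softmax(scores, dim=-1) @ v

D = 64
x = torch.randn(1400, D)           # hidden states for the 1,400 latent tokens
cond = torch.randn(4, D)           # embedded conditioning variables (age, sex, ...)
w_q, w_k, w_v = (torch.randn(D, D) * D ** -0.5 for _ in range(3))
self_out = attention(x, w_q, w_k, w_v)                       # masked self-attention
cross_out = attention(x, w_q, w_k, w_v, context=cond)        # cross-attention to conditioning
print(self_out.shape, cross_out.shape)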
At inference, the model predicts the probability distribution of a single token at a time, \(p({{S}_{{\mathrm{i}}q}}_{j}\,| \,{{S}_{{\mathrm{i}}q}}_{ < j},c)\). On the basis of the estimated probability distribution, one token is randomly picked. This procedure is applied sequentially to each token location until the full sequence has been sampled. Afterwards, Siq is reshaped back into the 3D tensor Ziq and passed through the quantizer to obtain Zq, which is then decoded into image space. This step of the pipeline is represented by the purple flow line in Fig. 1.
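The token-by-token sampling loop can be sketched as follows; here transformer stands in for the trained conditional model, the vocabulary size is arbitrary and the final reshape uses the 10 × 14 × 10 latent grid:

import torch

@torch.no_grad()
def sample_tokens(transformer, conditioning, seq_len=1400, sos=0):
    """Autoregressively sample one token at a time given the conditioning."""
    seq = torch.tensor([[sos]])                               # (1, 1) start token
    for _ in range(seq_len):
        logits = transformer(seq, conditioning)               # (1, t, vocab)
        probs = torch.softmax(logits[:, -1], dim=-1)          # distribution of the next token
        next_token = torch.multinomial(probs, num_samples=1)  # random pick per the distribution
        seq = torch.cat([seq, next_token], dim=1)
    return seq[:, 1:].reshape(1, 10, 14, 10)                  # S_iq reshaped back into Z_iq

# usage with a hypothetical stand-in returning random logits:
fake_transformer = lambda s, c: torch.randn(1, s.shape[1], 2048)
z_iq = sample_tokens(fake_transformer, conditioning=None)
print(z_iq.shape)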
Descriptive quantization for transformer use
For the transformer to synthesize meaningful samples, the VQ-VAE’s latent representation needs to meet the following two requirements: it should be small enough to meet the memory constraints of the transformer architecture and be sufficiently descriptive such that the reconstruction is both structurally coherent and realistic.
Towards the first requirement, our VQ-VAE projects the input image from a tensor \(X\in {{\mathbb{R}}}^{160\times 224\times 160}\) to \({Z}_{{\mathrm{i}}q}\in {{\mathbb{N}}}_{0}^{10\times 14\times 10}\). This compresses the image spatially by a factor of 4,096 (from 160 × 224 × 160 = 5,734,400 voxels to 10 × 14 × 10 = 1,400 tokens), or 14,564 if we take into account the data type conversion from 32-bit floating point to 8-bit integer. This results in a sequence Siq of length 1,400, which is sufficiently small to be modelled by a transformer. Zq was originally learned by gradient descent as follows:

\[{{{\mathcal{L}}}}=\| X-\hat{X}\,{\| }_{2}^{2}+\| \,{\mathrm{sg}}[Z\,]-{Z}_{q}{\| }_{2}^{2}+\beta \| Z-{\mathrm{sg}}[{Z}_{q}]{\| }_{2}^{2}\qquad (2)\]
where sg is the stop-gradient operation. As per refs. 42,43, the second component in equation (2) is replaced by the EMA update in equation (3),

\[{N}_{i}^{(t)}=\gamma {N}_{i}^{(t-1)}+(1-\gamma ){n}_{i}^{(t)},\quad {m}_{i}^{(t)}=\gamma {m}_{i}^{(t-1)}+(1-\gamma )\mathop{\sum }\limits_{j=1}^{{n}_{i}^{(t)}}{Z}_{i,\,j}^{(t)},\quad {{Z}_{q}}_{i}^{(t)}=\frac{{m}_{i}^{(t)}}{{N}_{i}^{(t)}}\qquad (3)\]

where \({n}_{i}^{(t)}\) is the number of vectors in Z that will be quantized at EMA update timestep t to the codebook element \({{Z}_{q}}_{i}\), and \({Z}_{i,\,j}^{(t)}\) are those vectors. The hyperparameters γ and β control the decay of the EMA from equation (3) and the commitment of the encoder output to a certain quantized element, respectively. This changes the learning procedure of Zq from a gradient-descent-based one to an online EMA procedure.
To meet the second requirement (that is, synthetic samples being structurally coherent and realistic), significant changes to the VQ-VAE loss were necessary to appropriately cater to and stabilize training on 3D medical data. We started with the classical mean squared error (MSE), \({{{{\mathcal{L}}}}}_{pix}={\mathrm{MSE}}(X,\hat{X}\,)\), which works purely in the pixel domain. We took inspiration from ref. 48 and used the MSE on the amplitude of the fast Fourier transformation of X and \(\hat{X}\), which can be written as \({{{{\mathcal{L}}}}}_{freq}={\mathrm{MSE}}(| {\mathrm{FFT}}(X\,)| ,| {\mathrm{FFT}}(\hat{X}\,)| )\), to improve the sharpness of the reconstructions. Following that, we added a perceptual loss based on the LPIPS49 package using AlexNet50 as the feature extractor. Note that the LPIPS pre-trained network is 2D, so owing to the 3D nature of medical images this loss was applied on a slice-wise basis. More specifically, we applied it to 50% of randomly selected slices of each axis, resulting in:

\[{{{{\mathcal{L}}}}}_{pcp}={\mathrm{LPIPS}}({X}_{ax},{\hat{X}}_{ax})+{\mathrm{LPIPS}}({X}_{sag},{\hat{X}}_{sag})+{\mathrm{LPIPS}}({X}_{cor},{\hat{X}}_{cor})\]

where ax, sag and cor represent the axial, sagittal and coronal planes, respectively, and each term is computed over the randomly selected slices of that plane.
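A sketch of the axial slice-wise perceptual term using the lpips package, assuming volumes normalized to [-1, 1]; the sagittal and coronal terms follow by permuting axes, and all shapes are illustrative:

import torch
import lpips

loss_fn = lpips.LPIPS(net="alex")            # 2D perceptual metric with AlexNet features

def axial_lpips(x, x_hat, frac=0.5):
    """x, x_hat: (1, 1, H, W, D) volumes in [-1, 1]; scores a random 50% slice subset."""
    depth = x.shape[-1]
    idx = torch.randperm(depth)[: int(frac * depth)]

    def to_slices(vol):
        # (H, W, k) -> (k, 3, H, W): LPIPS expects batches of 3-channel 2D images.
        s = vol[0, 0][:, :, idx].permute(2, 0, 1).unsqueeze(1)
        return s.repeat(1, 3, 1, 1)

    return loss_fn(to_slices(x), to_slices(x_hat)).mean()

x = torch.rand(1, 1, 160, 224, 160) * 2 - 1
print(axial_lpips(x, x.clamp(-1, 0.9)).item())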
With the LPIPS loss we aimed to increase the perceived quality as well as to speed up convergence, following reasoning similar to that presented in ref. 51. Finally, we used an adversarial loss \({{{{\mathcal{L}}}}}_{adv}\), as the images otherwise showed intensity patterns that were not realistic. This loss is based on the Patch-GAN52 discriminator as per ref. 23, paired with the LS-GAN objective34, which provided more stable and reproducible behaviour and can be written as:

\[{{{{\mathcal{L}}}}}_{adv}^{Dis}=\tfrac{1}{2}{{\mathbb{E}}}_{X}\left[{(Dis(X\,)-1)}^{2}\right]+\tfrac{1}{2}{{\mathbb{E}}}_{\hat{X}}\left[Dis{(\hat{X}\,)}^{2}\right],\qquad {{{{\mathcal{L}}}}}_{adv}^{Gen}=\tfrac{1}{2}{{\mathbb{E}}}_{\hat{X}}\left[{(Dis(\hat{X}\,)-1)}^{2}\right]\]
A quantitative assessment of the effect of and need for these loss components is presented in the ablation study in Supplementary Section B. Overall, the proposed VQ-VAE loss function combines these components.
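Assuming the components are combined as a weighted sum with the scaling factors αpix, αfreq, αpcp and αdis listed under Implementation details, plus the commitment term from equation (2), the overall objective would take the form

\[{{{\mathcal{L}}}}={\alpha }_{pix}{{{{\mathcal{L}}}}}_{pix}+{\alpha }_{freq}{{{{\mathcal{L}}}}}_{freq}+{\alpha }_{pcp}{{{{\mathcal{L}}}}}_{pcp}+{\alpha }_{dis}{{{{\mathcal{L}}}}}_{adv}+\beta \| Z-{\mathrm{sg}}[{Z}_{q}]{\| }_{2}^{2}\]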
Autoregressive modelling of the brain
Clinically usable, fully unsupervised generative modelling of medical images should preserve morphology. To ensure that our synthetic samples do so, we enhanced a baseline transformer implementation from refs. 42,43 with a series of techniques tailored to facilitate controlled sampling of sequences that also benefit from whole-sequence contextual information and a 3D inductive bias. Furthermore, one of the techniques53 has been generalized from 2D to 3D. In the rest of this section we detail each of the techniques and the reasons we chose them.
As the transformer takes the 1D sequence Siq as input, it loses the intrinsic spatial inductive bias of Ziq. To reintroduce it, we implemented a 3D generalization of the relative positional bias53 approach. It works by adding a bias term to the \(QK^{T}\) logits before the softmax in the attention mechanism. The bias term is based on the product of the directional relative positional embeddings of each voxel. As we were working on the full latent representation, we did not quantize the distance with a piecewise function as done in ref. 53, because our Ziq had 46 times fewer spatial positions than an ImageNet sample. This is due to the high compression rate achieved with the VQ-VAE.
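One way to realize such a 3D relative positional bias for the 10 × 14 × 10 latent grid is sketched below; for simplicity the per-axis biases are combined additively here, whereas the method described above combines the directional embeddings multiplicatively, so this is an approximation rather than the exact mechanism:

import torch
import torch.nn as nn

class RelPosBias3D(nn.Module):
    """Hypothetical 3D relative positional bias for a (d, h, w) token grid.

    One learnable table per axis holds a bias for every possible signed offset along
    that axis; the per-axis biases are summed into an (N, N) map per head that is
    added to the pre-softmax attention logits.
    """
    def __init__(self, dims=(10, 14, 10), num_heads=16):
        super().__init__()
        self.dims = dims
        self.tables = nn.ParameterList(
            [nn.Parameter(torch.zeros(2 * s - 1, num_heads)) for s in dims]
        )
        # Pre-compute the 3D coordinate of every flattened (row-major) token.
        coords = torch.stack(torch.meshgrid(
            *[torch.arange(s) for s in dims], indexing="ij"), dim=-1).reshape(-1, 3)
        self.register_buffer("coords", coords)                # (N, 3)

    def forward(self):
        # Signed relative offset between every pair of tokens, shifted to be >= 0.
        rel = self.coords[:, None, :] - self.coords[None, :, :]    # (N, N, 3)
        bias = 0.0
        for axis, table in enumerate(self.tables):
            idx = rel[..., axis] + (self.dims[axis] - 1)           # (N, N) table indices
            bias = bias + table[idx]                               # (N, N, heads)
        return bias.permute(2, 0, 1)    # (heads, N, N), added to Q K^T / sqrt(d_K)

rel_bias = RelPosBias3D()
print(rel_bias().shape)                 # torch.Size([16, 1400, 1400])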
In addition, the transformer’s hidden state should contain contextual information about the whole sequence. For that, we introduced an enhanced recurrence, similar to the procedure in ref. 54. It is a simple mechanism that passes the output of the next layer from the previous sampling step to the current layer at the current sampling step. Furthermore, it enhances future lower-level representations with higher-level ones.
Owing to the size of the sequence and the scaling requirements, we employed root mean square layer normalization (RMSNorm)55. This was shown to be the best-performing normalization variant for transformers in ref. 56 and has been used in large language models such as those in refs. 57,58. We noticed that this approach was of paramount importance for the models to converge in a reasonable amount of time.
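RMSNorm drops the mean-centring of standard layer normalization and rescales activations by their root mean square alone; a minimal sketch:

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-8):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))  # learnable gain
        self.eps = eps

    def forward(self, x):
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
        return self.scale * x / rms

x = torch.randn(3, 1400, 512)        # (batch, sequence, hidden)
print(RMSNorm(512)(x).shape)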
In line with ref. 59 and to enhance our dataset size, we augmented the samples fed into the transformer with the same augmentation protocol used during the VQ-VAE training. During training, we also conditioned our model on the randomly sampled augmentation parameters such that, at inference, the model could sample from the original, unaugmented distribution.
Finally, inspired by the success of the attention-gating mechanism in ref. 60, we gated our attention block outputs with the transformer’s input. This gave the model more control over each update and, in the case of cross-attention, increased the correlation between the conditioning and input.
Implementation details
VQ-VAE
To showcase the fine-tuning capabilities and the value of the pre-trained models for the wider research community, and in line with ref. 61, the UKB VQ-VAE was trained for 350 epochs for the ablation study. Following that, it was further trained for 100 epochs until convergence and then fine-tuned for another 100 epochs on the ADNI dataset. Adam62 was used as the optimizer with a learning rate of 0.000165 for the Enc and Dec networks and 0.00005 for Dis. The batch size was set to 8. The learning rates were selected via a grid search of 16 experiments. The loss component scaling factors αpix, αfreq, αpcp and αdis were set to 1, 1, 0.001 and 1, respectively. γ and β were set to 0.25 and 0.5, respectively. The loss weights were set empirically during development. The VQ-VAE had four downsamplings with convolutions that had 128 feature maps until the last level, where they had 256. The latent representation had dimensions of 10 × 14 × 10 with a feature size of 32. Each model was trained using distributed data parallelism on an NVIDIA DGX SuperPod equipped with eight A100 cards, each with 80 GB.
Transformer
The UKB transformer and the ADNI transformer were trained for 500 epochs. Adam62 was used as the optimizer with a learning rate of 0.0005 and a batch size of 3. The input sequence length was the full 1,400 tokens. The transformer had 24 layers, a latent representation of 512 and 16 heads for each attention layer. Each model was trained using distributed data parallelism on four NVIDIA A100 DGX SuperPods, each equipped with eight A100 cards with 80 GB each. For the UKB dataset, the transformers were conditioned on sex (UKB Data-Field 31-0.0), age (UKB Data-Field 21003-2.0), ventricular volume (UKB Data-Field 25004-2.0) and brain size normalized to head size (UKB Data-Field 25009-2.0), whereas for the ADNI dataset they were conditioned on sex, age and pathology as defined by the ARM field.
A depiction of the architecture can be seen in Fig. 1, further architectural descriptions can be found in Supplementary Section C and detailed ablation studies are elaborated on in Supplementary Section B.
Data availability
The UKB T1-weighted brain images used in this study are available via the UKB data access process (http://www.ukbiobank.ac.uk/register-apply/). Detailed information about the brain images from UKB is available at https://www.ukbiobank.ac.uk/enable-your-research/about-our-data/imaging-data and https://biobank.ndph.ox.ac.uk/showcase/label.cgi?id=100. Given the ongoing nature of the UKB study, the number of brain image samples currently available in UKB may differ slightly from those described in this Article. The ADNI T1-weighted brain images used in this study are available via the ADNI database access process (https://adni.loni.usc.edu/data-samples/access-data/). Detailed information about the brain images from ADNI is available at https://adni.loni.usc.edu/data-samples/data-types/ and https://adni.loni.usc.edu/methods/mri-tool/mri-analysis/.
Code availability
The VBM SPM MATLAB script templates are available in Supplementary Section D. The code is open source and available via GitHub at https://github.com/AmigoLab/BrainSynth and via Zenodo at https://doi.org/10.5281/zenodo.11583061 (ref. 63).
References
Deng, J. et al. ImageNet: A large-scale hierarchical image database. In 2009 IEEE Conference on Computer Vision and Pattern Recognition 248–255 (IEEE, 2009).
Singh, A. et al. TextOCR: towards large-scale end-to-end reasoning for arbitrary-shaped scene text. In Proc. of the IEEE/CVF Conference on Computer Vision and Pattern Recognition 8802–8812 (CVPR, 2021).
Lin, T.-Y. et al. Microsoft COCO: common objects in context. in European Conference on Computer Vision (eds Fleet, D., Pajdla, T., Schiele, B. & Tuytelaars, T.) 740–755 (Springer, 2014).
Sudlow, C. et al. UK Biobank: an open access resource for identifying the causes of a wide range of complex diseases of middle and old age. PLoS Med. 12, e1001779 (2015).
Miller, K. L. et al. Multimodal population brain imaging in the UK Biobank prospective epidemiological study. Nat. Neurosci. 19, 1523–1536 (2016).
Jack, C. et al. The Alzheimer's Disease Neuroimaging Initiative (ADNI): MRI methods. J. Magn. Reson. Imaging 27, 685–691 (2008).
Antonelli, M. et al. The medical segmentation decathlon. Nat. Commun. 13, 4128 (2022).
Rieke, N. et al. The future of digital health with federated learning. NPJ Digit. Med. 3, 119 (2020).
Kingma, D. P. & Welling, M. Auto-encoding variational Bayes. In 2nd International Conference on Learning Representations (eds Bengio, Y. & LeCun, Y.) (ICLR, 2014).
Xing, S., Sinha, H. & Hwang, S. J. Cycle consistent embedding of 3D brains with auto-encoding generative adversarial networks. In Medical Imaging with Deep Learning (2021).
Kwon, G., Han, C. & Kim, D.-s. Generation of 3D brain MRI using autoencoding generative adversarial networks. In Medical Image Computing and Computer Assisted Intervention – MICCAI 2019 (eds Shen, D. et al.) 118–126 (Springer, 2019).
Sun, L. et al. Hierarchical amortized GAN for 3D high resolution medical image synthesis. IEEE J. Biomed. Health Inform. 26, 3966–3975 (2022).
Goodfellow, I. et al. Generative adversarial nets. In Advances in Neural Information Processing Systems (eds Ghahramani, Z., Welling, M., Cortes, C., Lawrence, N. & Weinberger, K.) 27 (Curran Associates, Inc., 2014).
Chu, C., Minami, K. & Fukumizu, K. Smoothness and stability in GANs. In 8th International Conference on Learning Representations (ICLR, 2020).
Chlap, P. et al. A review of medical image data augmentation techniques for deep learning applications. J. Med. Imag. Radiat. Oncol. 65, 545–563 (2021).
Shin, H.-C. et al. Medical image synthesis for data augmentation and anonymization using generative adversarial networks. In Simulation and Synthesis in Medical Imaging (eds Gooya, A., Goksel, O., Oguz, I. & Burgos, N.) 1–11 (Springer, 2018).
Fernandez, V. et al. Can segmentation models be trained with fully synthetically generated data? In Simulation and Synthesis in Medical Imaging (eds Zhao, C., Svoboda, D., Wolterink, J. M. & Escobar, M.) 79–90 (Springer, 2022).
Jordon, J. et al. Synthetic data–what, why and how? Preprint at https://arxiv.org/abs/2205.03257 (2022).
Saharia, C. et al. Photorealistic text-to-image diffusion models with deep language understanding. In Proc. of the 36th International Conference on Neural Information Processing Systems (Curran Associates Inc., 2024).
Ramesh, A., Dhariwal, P., Nichol, A., Chu, C. & Chen, M. Hierarchical text-conditional image generation with clip latents. Preprint at https://arxiv.org/abs/2204.06125 (2022).
Heusel, M., Ramsauer, H., Unterthiner, T., Nessler, B. & Hochreiter, S. GANs trained by a two time-scale update rule converge to a local nash equilibrium. In Proc. of the 31st International Conference on Neural Information Processing Systems 6629–6640 (Curran Associates Inc., 2017).
Korkinof, D. et al. Perceived realism of high-resolution generative adversarial network–derived synthetic mammograms. Radiol. Artif. Intell. 3, e190181 (2021).
Esser, P., Rombach, R. & Ommer, B. Taming transformers for high-resolution image synthesis. In Proc. of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) 12873–12883 (IEEE, 2021).
Yu, J. et al. Vector-quantized image modeling with improved VQGAN. In The Tenth International Conference on Learning Representations (ICLR, 2022).
Pinaya, W. H. L. et al. Unsupervised brain anomaly detection and segmentation with transformers. In Proc. of the Fourth Conference on Medical Imaging with Deep Learning (eds Heinrich, M. et al.) 596–617 (PMLR, 2021).
Graham, M. S. et al. Transformer-based out-of-distribution detection for clinically safe segmentation. In Proc. of The 5th International Conference on Medical Imaging with Deep Learning (eds Konukoglu, E. et al.) 457–476 (PMLR, 2022).
Gretton, A., Borgwardt, K. M., Rasch, M. J., Scholkopf, B. & Smola, A. A kernel two-sample test. J. Mach. Learn. Res. 13, 723–773 (2012).
Wang, Z., Simoncelli, E. & Bovik, A. Multiscale structural similarity for image quality assessment. In The Thrity-Seventh Asilomar Conference on Signals, Systems Computers 1398–1402 (IEEE, 2003).
Li, C. & Bovik, A. C. Content-partitioned structural similarity index for image quality assessment. Signal Process. Image Commun. 25, 517–526 (2010).
Chen, G.-h., Yang, C.-l. & Xie, S.-l. Gradient-based structural similarity for image quality assessment. In 2006 International Conference on Image Processing 2929–2932 (IEEE, 2006).
Billot, B., Magdamo, C., Arnold, S. E., Das, S. & Iglesias, J. E. Robust segmentation of brain MRI in the wild with hierarchical CNNs and no retraining. In Medical Image Computing and Computer Assisted Intervention – MICCAI 2022 (eds Wang, L., Dou, Q., Fletcher, P. T., Speidel, S. & Li, S.) 538–548 (Springer, 2022).
Henschel, L. et al. FastSurfer - A fast and accurate deep learning based neuroimaging pipeline. NeuroImage 219, 117012 (2020).
Ashburner, J. & Friston, K. J. Voxel-Based Morphometry—the methods. NeuroImage 11, 805–821 (2000).
Mao, X. et al. Least squares generative adversarial networks. In 2017 IEEE International Conference on Computer Vision (ICCV) 2813–2821 (ICCV, 2017).
Friston, K. J. et al. Statistical parametric mapping: the analysis of functional brain images (Academic Press, 2006).
Whitwell, J. L. Voxel-based morphometry: an automated technique for assessing structural changes in the brain. J. Neurosci. 29, 9661–9664 (2009).
Ridgway, G. R., Litvak, V., Flandin, G., Friston, K. J. & Penny, W. D. The problem of low variance voxels in statistical parametric mapping; a new hat avoids a ‘haircut’. NeuroImage 59, 2131–2141 (2012).
Mann, H. B. & Whitney, D. R. On a test of whether one of two random variables is stochastically larger than the other. Ann. Math. Stat. 18, 50–60 (1947).
Velázquez, J., Mateos, J., Pasaye, E. H., Barrios, F. A. & Marquez-Flores, J. A. Cortical thickness estimation: a comparison of FreeSurfer and three voxel-based methods in a test–retest analysis and a clinical application. Brain Topogr. 34, 430–441 (2021).
Peng, H., Gong, W., Beckmann, C. F., Vedaldi, A. & Smith, S. M. Accurate brain age prediction with lightweight deep neural networks. Med. Image Anal. 68, 101871 (2021).
Esteban, C., Hyland, S. L. & Rätsch, G. Real-valued (medical) time series generation with recurrent conditional GANs. Preprint at https://arxiv.org/abs/1706.02633 (2017).
Van den Oord, A., Vinyals, O. & Kavukcuoglu, K. Neural discrete representation learning. In Proc. of the 31st International Conference on Neural Information Processing Systems 6309–6318 (Curran Associates Inc., 2017).
Razavi, A., van den Oord, A. & Vinyals, O. Generating diverse high-fidelity images with VQ-VAE-2. In Proc. of the 33rd International Conference on Neural Information Processing Systems (Curran Associates Inc., 2019).
Hu, M., Wang, Y., Cham, T.-J., Yang, J. & Suganthan, P. Global context with discrete diffusion in vector quantised modelling for image generation. In 2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) 11492–11501 (IEEE, 2022).
Chen, X., Mishra, N., Rohaninejad, M. & Abbeel, P. PixelSNAIL: an improved autoregressive generative model. In Proc. of the 35th International Conference on Machine Learning (eds Dy, J. & Krause, A.) 864–872 (PMLR, 2018)
Vaswani, A. et al. Attention is all you need. In Advances in Neural Information Processing Systems (eds Guyon, I. et al.) 30 (Curran Associates, Inc., 2017).
Lin, H., Cheng, X., Wu, X. & Shen, D. CAT: cross attention in vision transformer. In 2022 IEEE International Conference on Multimedia and Expo (ICME) 1–6 (2022).
Dhariwal, P. et al. Jukebox: A generative model for music. Preprint at https://arxiv.org/abs/2005.00341 (2020).
Zhang, R., Isola, P., Efros, A. A., Shechtman, E. & Wang, O. The unreasonable effectiveness of deep features as a perceptual metric. In 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition 586–595 (2018).
Krizhevsky, A., Sutskever, I. & Hinton, G. E. ImageNet classification with deep convolutional neural networks. In Proc. of the 25th International Conference on Neural Information Processing Systems - Volume 1 1097–1105 (Curran Associates Inc., 2012).
Johnson, J., Alahi, A. & Fei-Fei, L. Perceptual losses for real-time style transfer and super-resolution. In Computer Vision – ECCV 2016 (eds Leibe, B., Matas, J., Sebe, N. & Welling, M.) 694–711 (Springer, 2016).
Isola, P., Zhu, J.-Y., Zhou, T. & Efros, A. A. Image-to-image translation with conditional adversarial networks. In 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 5967–5976 (IEEE, 2017).
Wu, K., Peng, H., Chen, M., Fu, J. & Chao, H. Rethinking and improving relative position encoding for vision transformer. In 2021 IEEE/CVF International Conference on Computer Vision (ICCV) 10013–10021 (IEEE, 2021).
Ding, S. et al. ERNIE-Doc: A retrospective long-document modeling transformer. Preprint at https://arxiv.org/abs/2012.15688 (2020).
Zhang, B. & Sennrich, R. Root mean square layer normalization. In Advances in Neural Information Processing Systems (eds Wallach, H. et al.) 32 (Curran Associates, Inc., 2019).
Narang, S. et al. Do transformer modifications transfer across implementations and applications? in Proc. of the 2021 Conference on Empirical Methods in Natural Language Processing (eds Moens, M.-F., Huang, X., Specia, L. & Yih, S. W.-t.) 5758–5773 (Association for Computational Linguistics, 2021).
Borgeaud, S. et al. Improving language models by retrieving from trillions of tokens. In Proc. of the 39th International Conference on Machine Learning (eds Chaudhuri, K. et al.) 2206–2240 (PMLR, 2022).
Rae, J. W. et al. Scaling language models: methods, analysis & insights from training gopher. Preprint at https://arxiv.org/abs/2112.11446 (2021).
Jun, H. et al. Distribution augmentation for generative modeling. in Proc. of the 37th International Conference on Machine Learning (eds Daumé, H. & Singh, A.) 5006–5019 (PMLR, 2020).
Jumper, J. et al. Highly accurate protein structure prediction with AlphaFold. Nature 596, 583–589 (2021).
Tudosiu, P.-D. et al. Neuromorphologicaly-preserving volumetric data encoding using VQ-VAE. Preprint at https://arxiv.org/abs/2002.05692 (2020).
Kingma, D. P. & Ba, J. Adam: a method for stochastic optimization. in 3rd International Conference on Learning Representations (eds Bengio, Y. & LeCun, Y.) (ICLR, 2015).
Tudosiu, P.-D. AmigoLab/BrainSynth: Nature Machine Intelligence release. Zenodo https://doi.org/10.5281/zenodo.11583061 (2024).
Acknowledgements
P.-D.T. is supported by the EPSRC Research Council, part of the EPSRC DTP (grant number EP/R513064/1). W.H.L.P., M.S.G., P.B., M.J.C., S.O., R.J.G. and P.N. are supported by Wellcome (grant number WT213038/Z/18/Z). V.F. and A.P. are supported by Wellcome/the EPSRC Centre for Medical Engineering (grant number WT203148/Z/16/Z), the Wellcome Flagship Programme (grant number WT213038/Z/18/Z), the London AI Centre for Value-based Healthcare and GE Healthcare. P.B. is also supported by Wellcome EPSRC CME (grant number WT203148/Z/16/Z). J.D. is supported by the Intramural Research Program of the NIMH (grant numbers ZIC-MH002960 and ZIC-MH002968). P.F.D.C. is supported by the European Union’s HORIZON 2020 Research and Innovation Programme under the Marie Sklodowska-Curie grant agreement number 814302. P.N. is also supported by the UCLH NIHR Biomedical Research Centre. The models in this work were trained on NVIDIA Cambridge-1, the UK’s largest supercomputer, aimed at accelerating digital biology. We also thank D. Yang from NVIDIA for his unrelenting support and guidance in using NVIDIA Cambridge-1.
Author information
Authors and Affiliations
Contributions
P.-D.T., A.P. and M.G. implemented the model. P.-D.T. implemented, ran and analysed the VBM experiments. P.F.D.C., J.D. and W.H.L.P. implemented and ran the correlation and age prediction results. W.H.L.P. implemented and ran the SynthSeg experiments and P.-D.T. analysed the results. P.-D.T. implemented, ran and analysed the FastSurfer experiments. P.-D.T. implemented, ran and analysed the quantitative image fidelity evaluation. P.-D.T. implemented, ran and analysed the ablation studies. P.-D.T., W.H.L.P., P.N., M.J.C. and R.J.G. designed the experiments. P.-D.T., W.H.L.P., P.B., A.P. and M.J.C. wrote the manuscript and all other authors provided their feedback. P.T. and V.F. created the figures. M.J.C. and S.O. facilitated the computing capabilities.
Corresponding author
Ethics declarations
Competing interests
The authors declare no competing interests.
Peer review
Peer review information
Nature Machine Intelligence thanks Mingming Gong, Han Peng and the other, anonymous, reviewer(s) for their contribution to the peer review of this work.
Additional information
Publisher’s note Springer Nature remains neutral with regard to jurisdictional claims in published maps and institutional affiliations.
Extended data
Extended Data Fig. 1 UKB samples from each baseline model, the proposed model, and the real data.
Random synthetic sample from each model trained on UKB as well as the real dataset.
Extended Data Fig. 2 UKB samples from the proposed model trained on T2* and FLAIR data.
Random synthetic sample from our model trained on UKB’s T2* and FLAIR as well as corresponding real samples.
Extended Data Fig. 3 VBM t-statistic maps of cerebrospinal fluid for UKB healthy dataset from the baseline models, proposed model and real data.
VBM t-statistic maps of cerebrospinal fluid for the UKB healthy dataset. The displayed t-statistics range is [0, 2] and is based on the VBM of the real data. The t-statistics were corrected as per ref. 36.
Extended Data Fig. 4 Visualization of the predicted sampled conditioning vs the input conditions.
Correlation between the measured value and the conditioning variable for age and ventricular size. The Pearson correlations for ventricular size and age are 0.47 and 0.33, respectively, while the P values are 6.71E-21 and 3.51E-102.
Supplementary information
Supplementary Information
Supplementary Fig. 1, Tables 1–9 and Listings 1–3.
Rights and permissions
Open Access This article is licensed under a Creative Commons Attribution 4.0 International License, which permits use, sharing, adaptation, distribution and reproduction in any medium or format, as long as you give appropriate credit to the original author(s) and the source, provide a link to the Creative Commons licence, and indicate if changes were made. The images or other third party material in this article are included in the article’s Creative Commons licence, unless indicated otherwise in a credit line to the material. If material is not included in the article’s Creative Commons licence and your intended use is not permitted by statutory regulation or exceeds the permitted use, you will need to obtain permission directly from the copyright holder. To view a copy of this licence, visit http://creativecommons.org/licenses/by/4.0/.
About this article
Cite this article
Tudosiu, PD., Pinaya, W.H.L., Ferreira Da Costa, P. et al. Realistic morphology-preserving generative modelling of the brain. Nat Mach Intell 6, 811–819 (2024). https://doi.org/10.1038/s42256-024-00864-0
Received:
Accepted:
Published:
Issue Date:
DOI: https://doi.org/10.1038/s42256-024-00864-0