3D image classification of lung CT scans using PaddlePaddle


Author: KHB1698

Date: September 2021

Abstract: this tutorial demonstrates how to classify 3D images of lung CT scans

1. Brief introduction

This example shows the steps required to build a 3D convolutional neural network (CNN) to predict the presence of viral pneumonia in computed tomography (CT) scans. 2D CNNs are usually used to process RGB images (3 channels). A 3D CNN is simply the 3D equivalent: its input is a 3D volume or a sequence of 2D frames (for example, the slices of a CT scan). 3D CNNs are a powerful model for learning representations of volumetric data.
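As a minimal sketch of that difference, the layouts below use NumPy arrays as stand-ins for framework tensors (NCHW and NCDHW are the default layouts Paddle's Conv2D and Conv3D expect; the shapes are illustrative only):

```python
import numpy as np

# A 2D CNN consumes batches of images laid out as NCHW:
#   (batch, channels, height, width) - e.g. one RGB image
image_2d = np.zeros((1, 3, 128, 128), dtype="float32")

# A 3D CNN consumes batches of volumes laid out as NCDHW:
#   (batch, channels, depth, height, width) - e.g. one single-channel CT scan,
# where "depth" is the stack of 2D slices.
volume_3d = np.zeros((1, 1, 64, 128, 128), dtype="float32")

# Each depth index of the volume is itself a 2D frame (one CT slice)
single_slice = volume_3d[0, 0, 30]
print(single_slice.shape)  # (128, 128)
```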

2. Environment setting

This tutorial is based on Paddle 2.1. If your environment is not this version, please first refer to the official website to install Paddle 2.1.

import os
import zipfile
import numpy as np
import paddle
from paddle.nn import functional as F
paddle.__version__
'2.1.2'

3. Dataset

In this example, we use a subset of MosMedData: Chest CT Scans with COVID-19 Related Findings. The dataset consists of lung CT scans with COVID-19 related findings as well as lung CT scans without such findings.

We will use the radiological findings associated with each CT scan as the label to build a classifier that predicts the presence of viral pneumonia. The task is therefore a binary classification problem.
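Since the task is binary classification, the training section later uses binary cross-entropy on a sigmoid output. As a reminder of what that loss measures, here is a minimal NumPy sketch (the bce helper and the sample probabilities are illustrative only):

```python
import numpy as np

def bce(p, y):
    """Binary cross-entropy for a single sigmoid output p and label y in {0, 1}."""
    return -(y * np.log(p) + (1 - y) * np.log(1 - p))

# A confident, correct prediction gives a small loss
print(round(bce(0.9, 1), 4))  # -ln(0.9) ~= 0.1054
# A confident, wrong prediction gives a large loss
print(round(bce(0.9, 0), 4))  # -ln(0.1) ~= 2.3026
```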

The dataset for this project can be downloaded directly with the wget commands below, but because the download takes a long time, the data has also been uploaded to the project's dataset storage.

# The wget download command is as follows
!wget https://github.com/hasibzunair/3D-image-classification-tutorial/releases/download/v0.2/CT-0.zip
!wget https://github.com/hasibzunair/3D-image-classification-tutorial/releases/download/v0.2/CT-23.zip
# Make a directory to store data.
os.makedirs("MosMedData", exist_ok=True)
# Unzip the data in the newly created directory.
with zipfile.ZipFile("data/data106304/CT-0.zip", "r") as z_fp:
    z_fp.extractall("./MosMedData/")

with zipfile.ZipFile("data/data106304/CT-23.zip", "r") as z_fp:
    z_fp.extractall("./MosMedData/")

3.1 loading data and preprocessing

These files are provided in NIfTI format with a .nii extension. To read the scans, we use the nibabel package, which can be installed via pip. CT scans store raw voxel intensity in Hounsfield units (HU). In this dataset, they range from -1024 to above 2000. Values above 400 correspond to bones with different radiodensity, so 400 is used as the upper bound. A threshold range of -1000 to 400 is commonly used to normalize CT scans.

To process data, we do the following:

  • We first rotate the volume 90 degrees, so the direction is fixed.
  • We scale the HU value between 0 and 1.
  • We adjust the size of width, height and depth.
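As a quick sanity check of the second step, here is the clip-and-scale operation on a toy array of HU values, using the same -1000/400 thresholds as the normalize function defined below:

```python
import numpy as np

# Toy HU values spanning the typical range of this dataset
hu = np.array([-2000.0, -1000.0, 0.0, 400.0, 1500.0])

lo, hi = -1000.0, 400.0
# Clip to the [-1000, 400] window, then rescale to [0, 1]
hu_clipped = np.clip(hu, lo, hi)
scaled = (hu_clipped - lo) / (hi - lo)
print(scaled)  # values below -1000 map to 0, above 400 map to 1
```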

Here, we define several auxiliary functions to process the data. These functions will be used when building the training and validation datasets.

# nii.gz format is a common compression format for medical images. nibabel can be used to read and save in python.
!pip install nibabel
import nibabel as nib

from scipy import ndimage

def read_nifti_file(filepath):
    """Reading and loading data"""
    # Read file
    scan = nib.load(filepath)
    # Get raw data
    scan = scan.get_fdata()
    return scan


def normalize(volume):
    """data normalization """
    min = -1000
    max = 400
    volume[volume < min] = min
    volume[volume > max] = max
    volume = (volume - min) / (max - min)
    volume = volume.astype("float32")
    return volume


def resize_volume(img):
    """Span z Axis resizing"""
    # Set the desired depth
    desired_depth = 64
    desired_width = 128
    desired_height = 128
    # Get current depth
    current_depth = img.shape[-1]
    current_width = img.shape[0]
    current_height = img.shape[1]
    # Calculate depth factor
    depth = current_depth / desired_depth
    width = current_width / desired_width
    height = current_height / desired_height
    depth_factor = 1 / depth
    width_factor = 1 / width
    height_factor = 1 / height
    # rotate
    img = ndimage.rotate(img, 90, reshape=False)
    # Resize across the z-axis
    img = ndimage.zoom(img, (width_factor, height_factor, depth_factor), order=1)
    return img


def process_scan(path):
    """Read and resize data"""
    # Read scan file
    volume = read_nifti_file(path)
    # normalization
    volume = normalize(volume)
    # Adjust width, height, and depth
    volume = resize_volume(volume)
    return volume

Let's read the path of CT scan from the class directory.

# The folder "CT-0" consists of CT scans with normal lung tissue without CT signs of viral pneumonia.
normal_scan_paths = [
    os.path.join(os.getcwd(), "MosMedData/CT-0", x)
    for x in os.listdir("MosMedData/CT-0")
]
# The folder "CT-23" includes several ground glass opaque CT scans with pulmonary parenchyma involvement.
abnormal_scan_paths = [
    os.path.join(os.getcwd(), "MosMedData/CT-23", x)
    for x in os.listdir("MosMedData/CT-23")
]

print("CT scans with normal lung tissue: " + str(len(normal_scan_paths)))
print("CT scans with abnormal lung tissue: " + str(len(abnormal_scan_paths)))
CT scans with normal lung tissue: 100
CT scans with abnormal lung tissue: 100

3.2 dividing training and validation data sets

Read the scans from the class directories and assign labels. Each scan is resized to a shape of 128x128x64 and the raw HU values are rescaled to the range 0 to 1. Finally, the dataset is split into training and validation subsets.

# Read and process the scan files. Each scan is resized across height, width, and depth, then rescaled.
abnormal_scans = np.array([process_scan(path) for path in abnormal_scan_paths])
normal_scans = np.array([process_scan(path) for path in normal_scan_paths])

# 1 for CT scans with viral pneumonia and 0 for normal.
abnormal_labels = np.array([1 for _ in range(len(abnormal_scans))])
normal_labels = np.array([0 for _ in range(len(normal_scans))])

# Split data in ratios 70-30 for training and validation.
x_train = np.concatenate((abnormal_scans[:70], normal_scans[:70]), axis=0)
y_train = np.concatenate((abnormal_labels[:70], normal_labels[:70]), axis=0)
x_val = np.concatenate((abnormal_scans[70:], normal_scans[70:]), axis=0)
y_val = np.concatenate((abnormal_labels[70:], normal_labels[70:]), axis=0)
print(
    "Number of samples in train and validation are %d and %d."
    % (x_train.shape[0], x_val.shape[0])
)
Number of samples in train and validation are 140 and 60.

3.3 dataset definition

Use the paddle.io.Dataset class from the high-level API of the PaddlePaddle framework to define a custom dataset. For details, please refer to the official documentation on custom datasets.

A custom dataset overrides __init__ and implements __getitem__ and __len__.

# Build the lung Dataset according to the usage specification of Dataset

from paddle.io import Dataset

class CTDataset(Dataset):
    # Lung scan dataset
    """
    Step 1: inherit paddle.io.Dataset class
    """
    def __init__(self, x, y, transform=None):
        """
        Step 2: implement the constructor and define the data set size
        Args:
            x: image
            y: labels corresponding to the images
            transform (callable, optional): Data processing method applied to image
        """
        self.x = x
        self.y = y
        self.transform = transform # Get transform method

    def __getitem__(self, idx):
        """
        Step 3: implement the __getitem__ method, which defines how to fetch the sample at a given index and returns a single piece of data (the training/test sample and its corresponding label)
        """
        img = self.x[idx]
        label = self.y[idx]
        # If the transform method is defined, use the transform method
        if self.transform:
            img,label = self.transform([img,label])
        # The data was already preprocessed into NumPy arrays above, so nothing more is needed here
        return img, label

    def __len__(self):
        """
        Step 4: implement the __len__ method, which returns the total size of the dataset
        """
        return len(self.y) # Returns the size of the dataset, that is, the number of pictures


3.4 data visualization

Instantiate the dataset and display some images.

import matplotlib.pyplot as plt
train_dataset = CTDataset(x_train,y_train)

images, labels = train_dataset[11]
image = images
print("Dimension of the CT scan is:", image.shape)
plt.imshow(np.squeeze(image[:, :, 30]), cmap="gray")
print("label is:",labels)
Dimension of the CT scan is: (128, 128, 64)
label is: 1

Since CT scans have many slices, let's take a look at the collection of slices.

def plot_slices(num_rows, num_columns, width, height, data):
    """Plot a montage of 20 CT slices"""
    data = np.rot90(np.array(data))
    data = np.transpose(data)
    data = np.reshape(data, (num_rows, num_columns, width, height))
    rows_data, columns_data = data.shape[0], data.shape[1]
    heights = [slc[0].shape[0] for slc in data]
    widths = [slc.shape[1] for slc in data[0]]
    fig_width = 12.0
    fig_height = fig_width * sum(heights) / sum(widths)
    f, axarr = plt.subplots(
        rows_data,
        columns_data,
        figsize=(fig_width, fig_height),
        gridspec_kw={"height_ratios": heights},
    )
    for i in range(rows_data):
        for j in range(columns_data):
            axarr[i, j].imshow(data[i][j], cmap="gray")
            axarr[i, j].axis("off")
    plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)
    plt.show()


# Visualize montage of slices.
# 4 rows and 10 columns for 40 slices of the CT scan.
plot_slices(4, 10, 128, 128, image[:, :, :40])

3.5 Transforms

During training, we also augment the CT scans by applying random angle rotations. Since the data is stored as rank-3 arrays with shape HWD (height, width, depth), we first change the layout to DHW and then add a channel dimension of size 1 at axis 0, so the data can be fed to 3D convolutions. The new shape is therefore (1, depth, height, width). Many other preprocessing and augmentation techniques exist; this example performs a few simple operations through custom transforms.
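The layout change described above can be sketched with plain NumPy on a dummy volume:

```python
import numpy as np

# A toy volume in HWD layout (height, width, depth), matching the processed scans
volume_hwd = np.zeros((128, 128, 64), dtype="float32")

# HWD -> DHW: move the depth axis to the front
volume_dhw = np.transpose(volume_hwd, (2, 0, 1))

# DHW -> CDHW: add a channel axis of size 1 for Conv3D
volume_cdhw = np.expand_dims(volume_dhw, axis=0)

print(volume_cdhw.shape)  # (1, 64, 128, 128)
```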

# Custom transform methods
# Paddle's built-in transforms can only process image data, not label data, so we define our own transforms
class TransformAPI(object):
    """
    Step 1: inherit object class
    """
    def __call__(self, data):

        """
        Step 2: in __call__ Data processing method defined in
        """
        
        processed_data = data
        return  processed_data
import random
from scipy import ndimage

# Rotate the image several degrees
class Rotate(object):

    def __call__(self, data):

        image = data[0]
        label = data[1]
        # Define some candidate rotation angles
        angles = [-20, -10, -5, 5, 10, 20]
        # Randomly select an angle
        angle = random.choice(angles)
        # Rotate the volume and clip back to [0, 1]
        image = ndimage.rotate(image, angle, reshape=False)
        image[image < 0] = 0
        image[image > 1] = 1
        return image, label

# Change the image format from HWD to CDHW
class ToCDHW(object):
    
    def __call__(self, data):

        image = data[0]
        label = data[1]
        # HWD -> DHW, then add a channel axis: DHW -> CDHW
        # (stay in NumPy here; mixing Paddle tensors with np.expand_dims can fail)
        image = np.transpose(image, (2, 0, 1))
        image = np.expand_dims(image, axis=0)
        return image, label

3.6 data definition

from paddle.vision.transforms import Compose

# create the transformed dataset
train_dataset = CTDataset(x_train,y_train,transform=Compose([Rotate(),ToCDHW()]))
valid_dataset = CTDataset(x_val,y_val,transform=Compose([ToCDHW()]))

4. Model networking

To make the model easier to understand, we build it from blocks. The 3D CNN used in this example stacks four Conv3D-ReLU-MaxPool3D-BatchNorm3D blocks, followed by global average pooling and fully connected layers.
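The spatial sizes that appear in the model summary below can be derived by hand: a Conv3D with a 3x3x3 kernel and no padding shrinks each spatial dimension by 2, and a MaxPool3D with kernel size 2 floor-divides it by 2. A small sketch of that arithmetic (the block helper is illustrative only):

```python
def block(shape, kernel=3, pool=2):
    """One Conv3D (valid padding) + MaxPool3D block:
    conv shrinks each dim by kernel-1, pool floor-divides by pool."""
    conved = tuple(d - (kernel - 1) for d in shape)
    pooled = tuple(d // pool for d in conved)
    return pooled

shape = (64, 128, 128)  # (depth, height, width) of the input volume
for _ in range(4):      # four conv/pool blocks in the model
    shape = block(shape)
print(shape)            # (2, 6, 6), matching the summary below
```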

import paddle

class Model3D(paddle.nn.Layer):
    def __init__(self):
        super(Model3D,self).__init__()
        self.layerAll = paddle.nn.Sequential(
            paddle.nn.Conv3D(1,64,(3,3,3)),
            paddle.nn.ReLU(),
            paddle.nn.MaxPool3D(kernel_size=2),
            paddle.nn.BatchNorm3D(64),

            paddle.nn.Conv3D(64,64,(3,3,3)),
            paddle.nn.ReLU(),
            paddle.nn.MaxPool3D(kernel_size=2),
            paddle.nn.BatchNorm3D(64),

            paddle.nn.Conv3D(64,128,(3,3,3)),
            paddle.nn.ReLU(),
            paddle.nn.MaxPool3D(kernel_size=2),
            paddle.nn.BatchNorm3D(128),

            paddle.nn.Conv3D(128,256,(3,3,3)),
            paddle.nn.ReLU(),
            paddle.nn.MaxPool3D(kernel_size=2),
            paddle.nn.BatchNorm3D(256),
            
            paddle.nn.AdaptiveAvgPool3D(output_size=1),
            paddle.nn.Flatten(),
            paddle.nn.Linear(256,512),
            paddle.nn.Dropout(p=0.3),

            paddle.nn.Linear(512,1),
            paddle.nn.Sigmoid()


        )

    def forward(self, inputs):
        x = self.layerAll(inputs)
        return x

model = paddle.Model(Model3D())
model.summary((-1,1,64,128,128))
-----------------------------------------------------------------------------------
   Layer (type)           Input Shape           Output Shape          Param #    
===================================================================================
     Conv3D-29      [[1, 1, 64, 128, 128]]  [1, 64, 62, 126, 126]      1,792     
      ReLU-29       [[1, 64, 62, 126, 126]] [1, 64, 62, 126, 126]        0       
   MaxPool3D-29     [[1, 64, 62, 126, 126]]  [1, 64, 31, 63, 63]         0       
  BatchNorm3D-29     [[1, 64, 31, 63, 63]]   [1, 64, 31, 63, 63]        256      
     Conv3D-30       [[1, 64, 31, 63, 63]]   [1, 64, 29, 61, 61]      110,656    
      ReLU-30        [[1, 64, 29, 61, 61]]   [1, 64, 29, 61, 61]         0       
   MaxPool3D-30      [[1, 64, 29, 61, 61]]   [1, 64, 14, 30, 30]         0       
  BatchNorm3D-30     [[1, 64, 14, 30, 30]]   [1, 64, 14, 30, 30]        256      
     Conv3D-31       [[1, 64, 14, 30, 30]]  [1, 128, 12, 28, 28]      221,312    
      ReLU-31       [[1, 128, 12, 28, 28]]  [1, 128, 12, 28, 28]         0       
   MaxPool3D-31     [[1, 128, 12, 28, 28]]   [1, 128, 6, 14, 14]         0       
  BatchNorm3D-31     [[1, 128, 6, 14, 14]]   [1, 128, 6, 14, 14]        512      
     Conv3D-32       [[1, 128, 6, 14, 14]]   [1, 256, 4, 12, 12]      884,992    
      ReLU-32        [[1, 256, 4, 12, 12]]   [1, 256, 4, 12, 12]         0       
   MaxPool3D-32      [[1, 256, 4, 12, 12]]    [1, 256, 2, 6, 6]          0       
  BatchNorm3D-32      [[1, 256, 2, 6, 6]]     [1, 256, 2, 6, 6]        1,024     
AdaptiveAvgPool3D-8   [[1, 256, 2, 6, 6]]     [1, 256, 1, 1, 1]          0       
    Flatten-49        [[1, 256, 1, 1, 1]]         [1, 256]               0       
     Linear-15            [[1, 256]]              [1, 512]            131,584    
     Dropout-8            [[1, 512]]              [1, 512]               0       
     Linear-16            [[1, 512]]               [1, 1]               513      
     Sigmoid-8             [[1, 1]]                [1, 1]                0       
===================================================================================
Total params: 1,352,897
Trainable params: 1,350,849
Non-trainable params: 2,048
-----------------------------------------------------------------------------------
Input size (MB): 4.00
Forward/backward pass size (MB): 1222.30
Params size (MB): 5.16
Estimated Total Size (MB): 1231.46
-----------------------------------------------------------------------------------






{'total_params': 1352897, 'trainable_params': 1350849}

5. Model training

We use the model network defined above and the datasets to train the model. A few practical points are worth noting along the way.
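One such point is the learning-rate schedule: ExponentialDecay multiplies the learning rate by gamma each time it is stepped, i.e. lr_t = lr_0 * gamma**t. A quick pure-Python sketch with the settings used below (the lr_at helper is illustrative only):

```python
base_lr, gamma = 0.0001, 0.96

def lr_at(step):
    """Learning rate after `step` decay steps under exponential decay."""
    return base_lr * gamma ** step

print(lr_at(0))   # 0.0001 (the initial learning rate)
print(lr_at(10))  # roughly two thirds of the initial rate
```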

import paddle.nn.functional as F

epoch_num = 100
batch_size = 2
learning_rate = 0.0001

val_acc_history = []
val_loss_history = []

def train(model):
    print('start training ... ')
    # turn into training mode
    model.train()

    #The interface provides a strategy for the learning rate to decay exponentially.
    scheduler = paddle.optimizer.lr.ExponentialDecay(learning_rate= learning_rate, gamma=0.96, verbose=True)
    opt = paddle.optimizer.Adam(learning_rate=scheduler,
                                parameters=model.parameters())

    train_loader = paddle.io.DataLoader(train_dataset,
                                        shuffle=True,
                                        batch_size=batch_size)

    valid_loader = paddle.io.DataLoader(valid_dataset, batch_size=batch_size)
    
    for epoch in range(epoch_num):
        for batch_id, data in enumerate(train_loader()):
            x_data = data[0]
            y_data = paddle.to_tensor(data[1],dtype="float32")
            y_data = paddle.unsqueeze(y_data, 1)

            logits = model(x_data)
            bce_loss = paddle.nn.BCELoss()
            loss = bce_loss(logits, y_data)
            
            if batch_id % 10 == 0:
                print("epoch: {}/{}, batch_id: {}, loss is: {}".format(epoch,epoch_num, batch_id, loss.numpy()))
            loss.backward()
            opt.step()
            opt.clear_grad()

        # evaluate model after one epoch
        model.eval()
        accuracies = []
        losses = []
        for batch_id, data in enumerate(valid_loader()):
            x_data = data[0]
            y_data = paddle.to_tensor(data[1],dtype="float32")
            y_data = paddle.unsqueeze(y_data, 1)

            logits = model(x_data)
            bce_loss = paddle.nn.BCELoss()
            loss = bce_loss(logits, y_data)
            # Note: 1 - BCE loss is only a rough proxy for accuracy, not true classification accuracy
            acc = 1-loss
            accuracies.append(acc.numpy())
            losses.append(loss.numpy())

        avg_acc, avg_loss = np.mean(accuracies), np.mean(losses)
        print("[validation] accuracy/loss: {}/{}".format(avg_acc, avg_loss))
        val_acc_history.append(avg_acc)
        val_loss_history.append(avg_loss)
        model.train()
        # Step the learning-rate scheduler once per epoch so the decay actually takes effect
        scheduler.step()

model = Model3D()
train(model)

Validation results of the last 10 epochs:

[validation] accuracy/loss: 0.9439725875854492/0.056027382612228394
[validation] accuracy/loss: 0.9338274002075195/0.06617263704538345
[validation] accuracy/loss: 0.9401693344116211/0.059830646961927414
[validation] accuracy/loss: 0.9441938400268555/0.05580609291791916
[validation] accuracy/loss: 0.9094834327697754/0.09051664918661118
[validation] accuracy/loss: 0.9318971633911133/0.06810279935598373
[validation] accuracy/loss: 0.9222384691238403/0.07776157557964325
[validation] accuracy/loss: 0.9538374543190002/0.04616249352693558
[validation] accuracy/loss: 0.9099018573760986/0.09009815007448196
[validation] accuracy/loss: 0.9140024185180664/0.08599752932786942

It is worth noting that the number of samples is very small (only 200) and no random seed is specified, so your results may differ. The complete dataset, containing more than 1000 CT scans, is also available; with the complete dataset, the accuracy becomes higher.

# Model saving
paddle.save(model.state_dict(), "net_3d.pdparams")

6. Model evaluation

The validation accuracy plotted here can reach about 0.956.

import matplotlib.pyplot as plt
plt.plot(val_acc_history, label = 'validation accuracy')

plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.ylim([0, 1])
plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7fd52c4f1fd0>

Tags: Python Deep Learning paddlepaddle

Posted on Wed, 01 Sep 2021 21:35:10 -0400 by milsaw