PyTorch tutorial

Data collection

In this PyTorch tutorial, we use the GTZAN dataset, which consists of 10 mutually exclusive genre classes. Please run the following script in your working directory to download the audio and the filtered train/validation/test split files.

!wget http://opihi.cs.uvic.ca/sound/genres.tar.gz
!tar -zxvf genres.tar.gz
!wget https://raw.githubusercontent.com/coreyker/dnn-mgr/master/gtzan/train_filtered.txt
!wget https://raw.githubusercontent.com/coreyker/dnn-mgr/master/gtzan/valid_filtered.txt
!wget https://raw.githubusercontent.com/coreyker/dnn-mgr/master/gtzan/test_filtered.txt

Data loader

import os
import random
import torch
import numpy as np
import soundfile as sf
from torch.utils import data
from torchaudio_augmentations import (
    RandomResizedCrop,
    RandomApply,
    PolarityInversion,
    Noise,
    Gain,
    HighLowPass,
    Delay,
    PitchShift,
    Reverb,
    Compose,
)


GTZAN_GENRES = ['blues', 'classical', 'country', 'disco', 'hiphop', 'jazz', 'metal', 'pop', 'reggae', 'rock']


class GTZANDataset(data.Dataset):
    def __init__(self, data_path, split, num_samples, num_chunks, is_augmentation):
        self.data_path = data_path if data_path else ''
        self.split = split
        self.num_samples = num_samples
        self.num_chunks = num_chunks
        self.is_augmentation = is_augmentation
        self.genres = GTZAN_GENRES
        self._get_song_list()
        if is_augmentation:
            self._get_augmentations()

    def _get_song_list(self):
        list_filename = os.path.join(self.data_path, '%s_filtered.txt' % self.split)
        with open(list_filename) as f:
            lines = f.readlines()
        self.song_list = [line.strip() for line in lines]

    def _get_augmentations(self):
        transforms = [
            RandomResizedCrop(n_samples=self.num_samples),
            RandomApply([PolarityInversion()], p=0.8),
            RandomApply([Noise(min_snr=0.3, max_snr=0.5)], p=0.3),
            RandomApply([Gain()], p=0.2),
            RandomApply([HighLowPass(sample_rate=22050)], p=0.8),
            RandomApply([Delay(sample_rate=22050)], p=0.5),
            RandomApply([PitchShift(n_samples=self.num_samples, sample_rate=22050)], p=0.4),
            RandomApply([Reverb(sample_rate=22050)], p=0.3),
        ]
        self.augmentation = Compose(transforms=transforms)

    def _adjust_audio_length(self, wav):
        if self.split == 'train':
            random_index = random.randint(0, len(wav) - self.num_samples - 1)
            wav = wav[random_index : random_index + self.num_samples]
        else:
            hop = (len(wav) - self.num_samples) // self.num_chunks
            wav = np.array([wav[i * hop : i * hop + self.num_samples] for i in range(self.num_chunks)])
        return wav

    def __getitem__(self, index):
        line = self.song_list[index]

        # get genre
        genre_name = line.split('/')[0]
        genre_index = self.genres.index(genre_name)

        # get audio
        audio_filename = os.path.join(self.data_path, 'genres', line)
        wav, fs = sf.read(audio_filename)

        # adjust audio length
        wav = self._adjust_audio_length(wav).astype('float32')

        # data augmentation
        if self.is_augmentation:
            wav = self.augmentation(torch.from_numpy(wav).unsqueeze(0)).squeeze(0).numpy()

        return wav, genre_index

    def __len__(self):
        return len(self.song_list)

def get_dataloader(data_path=None, 
                   split='train', 
                   num_samples=22050 * 29, 
                   num_chunks=1, 
                   batch_size=16, 
                   num_workers=0, 
                   is_augmentation=False):
    is_shuffle = (split == 'train')
    batch_size = batch_size if (split == 'train') else (batch_size // num_chunks)
    data_loader = data.DataLoader(dataset=GTZANDataset(data_path, 
                                                       split, 
                                                       num_samples, 
                                                       num_chunks, 
                                                       is_augmentation),
                                  batch_size=batch_size,
                                  shuffle=is_shuffle,
                                  drop_last=False,
                                  num_workers=num_workers)
    return data_loader

Let’s check the shapes of the returned data.

train_loader = get_dataloader(split='train', is_augmentation=True)
iter_train_loader = iter(train_loader)
train_wav, train_genre = next(iter_train_loader)

valid_loader = get_dataloader(split='valid')
test_loader = get_dataloader(split='test')
iter_test_loader = iter(test_loader)
test_wav, test_genre = next(iter_test_loader)
print('training data shape: %s' % str(train_wav.shape))
print('validation/test data shape: %s' % str(test_wav.shape))
print(train_genre)
training data shape: torch.Size([16, 639450])
validation/test data shape: torch.Size([16, 1, 639450])
tensor([9, 3, 4, 2, 2, 5, 2, 5, 7, 1, 1, 7, 8, 7, 4, 0])

Note

  • The data loader returns a batch of audio tensors and the corresponding genre indices at each iteration.

  • Random chunks of audio are cropped from the entire sequence during training. In the validation/test phase, however, the entire sequence is split into multiple chunks and the chunks are stacked. The stacked chunks are later fed to the trained model, and the output predictions are aggregated to make song-level predictions (see the sketch below).
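To make this concrete, here is a minimal sketch of the chunk-level aggregation used later in the validation and test loops (random tensors stand in for real audio and model outputs, and the chunk count is hypothetical):

# Minimal sketch of chunk-level aggregation; random tensors stand in for
# real audio and model outputs, and the chunk count is hypothetical.
import torch

b, c, t = 2, 4, 22050 * 29                 # 2 songs, 4 chunks each
wav = torch.randn(b, c, t)                 # stacked chunks from the loader
chunk_logits = torch.randn(b * c, 10)      # stand-in for cnn(wav.view(-1, t))
song_logits = chunk_logits.view(b, c, -1).mean(dim=1)   # average over chunks
print(song_logits.shape)                   # torch.Size([2, 10])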

Model

We are going to build a simple 2D CNN model that takes mel spectrogram inputs. First, we design a convolution module that consists of a 3x3 convolution, batch normalization, a ReLU non-linearity, 2x2 max pooling, and dropout. This module is used for each layer of the 2D CNN.

from torch import nn


class Conv_2d(nn.Module):
    def __init__(self, input_channels, output_channels, shape=3, pooling=2, dropout=0.1):
        super(Conv_2d, self).__init__()
        self.conv = nn.Conv2d(input_channels, output_channels, shape, padding=shape//2)
        self.bn = nn.BatchNorm2d(output_channels)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(pooling)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, wav):
        out = self.conv(wav)
        out = self.bn(out)
        out = self.relu(out)
        out = self.maxpool(out)
        out = self.dropout(out)
        return out

Next, we stack the convolution layers. In a PyTorch module, layers are declared in __init__ and connected in the forward function.

import torchaudio


class CNN(nn.Module):
    def __init__(self, num_channels=16, 
                       sample_rate=22050, 
                       n_fft=1024, 
                       f_min=0.0, 
                       f_max=11025.0, 
                       num_mels=128, 
                       num_classes=10):
        super(CNN, self).__init__()

        # mel spectrogram
        self.melspec = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate, 
                                                            n_fft=n_fft, 
                                                            f_min=f_min, 
                                                            f_max=f_max, 
                                                            n_mels=num_mels)
        self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
        self.input_bn = nn.BatchNorm2d(1)

        # convolutional layers
        self.layer1 = Conv_2d(1, num_channels, pooling=(2, 3))
        self.layer2 = Conv_2d(num_channels, num_channels, pooling=(3, 4))
        self.layer3 = Conv_2d(num_channels, num_channels * 2, pooling=(2, 5))
        self.layer4 = Conv_2d(num_channels * 2, num_channels * 2, pooling=(3, 3))
        self.layer5 = Conv_2d(num_channels * 2, num_channels * 4, pooling=(3, 4))

        # dense layers
        self.dense1 = nn.Linear(num_channels * 4, num_channels * 4)
        self.dense_bn = nn.BatchNorm1d(num_channels * 4)
        self.dense2 = nn.Linear(num_channels * 4, num_classes)
        self.dropout = nn.Dropout(0.5)
        self.relu = nn.ReLU()

    def forward(self, wav):
        # input preprocessing
        out = self.melspec(wav)
        out = self.amplitude_to_db(out)

        # input batch normalization
        out = out.unsqueeze(1)
        out = self.input_bn(out)

        # convolutional layers
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        
        # reshape. (batch_size, num_channels, 1, 1) -> (batch_size, num_channels)
        out = out.reshape(len(out), -1)

        # dense layers
        out = self.dense1(out)
        out = self.dense_bn(out)
        out = self.relu(out)
        out = self.dropout(out)
        out = self.dense2(out)

        return out

Note

In this example, we performed the preprocessing on-the-fly using torchaudio. This step can also be done offline, outside of the network, using other libraries such as librosa and Essentia.
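As a rough sketch (the file path is illustrative, and librosa is assumed to be installed), the offline variant could look like this:

# Rough sketch of offline preprocessing with librosa; the file path is
# illustrative, and the parameters mirror the MelSpectrogram settings above.
import librosa
import numpy as np

wav, sr = librosa.load('genres/blues/blues.00000.au', sr=22050)
mel = librosa.feature.melspectrogram(y=wav, sr=sr, n_fft=1024,
                                     n_mels=128, fmin=0.0, fmax=11025.0)
log_mel = librosa.power_to_db(mel).astype('float32')
np.save('blues.00000.npy', log_mel)  # cache to disk; load it in the dataset later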

Tip

  • There is no activation function at the last layer since nn.CrossEntropyLoss already applies the softmax internally.

  • If you want to perform multi-label binary classification, apply out = nn.Sigmoid()(out) at the last layer and use nn.BCELoss(), as sketched below.
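Here is a minimal sketch of that multi-label setup (the batch of 4 items and the 10 tags are hypothetical):

# Minimal sketch of the multi-label variant; the batch size and tag count are
# hypothetical, and random tensors stand in for model outputs and labels.
import torch
from torch import nn

logits = torch.randn(4, 10)                      # raw model outputs
targets = torch.randint(0, 2, (4, 10)).float()   # multi-hot labels
out = nn.Sigmoid()(logits)                       # sigmoid at the last layer
loss = nn.BCELoss()(out, targets)
# nn.BCEWithLogitsLoss()(logits, targets) combines the two and is more
# numerically stable.
print(loss.item())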

Training

Now we iterate the training loop. One epoch is defined as visiting all training items once. This definition can be modified in __len__ of the data loader.
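For example, a hypothetical subclass could redefine one epoch to visit each song several times:

# Hypothetical sketch: redefining an epoch by overriding __len__ so that
# each training song is visited 4 times per epoch.
class GTZANDatasetRepeated(GTZANDataset):
    def __len__(self):
        return 4 * len(self.song_list)

    def __getitem__(self, index):
        return super().__getitem__(index % len(self.song_list))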

from sklearn.metrics import accuracy_score, confusion_matrix


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
cnn = CNN().to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cnn.parameters(), lr=0.001)
valid_losses = []
num_epochs = 30

for epoch in range(num_epochs):
    losses = []

    # Train
    cnn.train()
    for (wav, genre_index) in train_loader:
        wav = wav.to(device)
        genre_index = genre_index.to(device)

        # Forward
        out = cnn(wav)
        loss = loss_function(out, genre_index)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    print('Epoch: [%d/%d], Train loss: %.4f' % (epoch+1, num_epochs, np.mean(losses)))

    # Validation
    cnn.eval()
    y_true = []
    y_pred = []
    losses = []
    for wav, genre_index in valid_loader:
        wav = wav.to(device)
        genre_index = genre_index.to(device)

        # reshape and aggregate chunk-level predictions
        b, c, t = wav.size()
        logits = cnn(wav.view(-1, t))
        logits = logits.view(b, c, -1).mean(dim=1)
        loss = loss_function(logits, genre_index)
        losses.append(loss.item())
        _, pred = torch.max(logits.data, 1)

        # append labels and predictions
        y_true.extend(genre_index.tolist())
        y_pred.extend(pred.tolist())
    accuracy = accuracy_score(y_true, y_pred)
    valid_loss = np.mean(losses)
    print('Epoch: [%d/%d], Valid loss: %.4f, Valid accuracy: %.4f' % (epoch+1, num_epochs, valid_loss, accuracy))

    # Save model
    valid_losses.append(valid_loss.item())
    if np.argmin(valid_losses) == epoch:
        print('Saving the best model at %d epochs!' % epoch)
        torch.save(cnn.state_dict(), 'best_model.ckpt')
Epoch: [1/30], Train loss: 2.4078
Epoch: [1/30], Valid loss: 2.3558, Valid accuracy: 0.1117
Saving the best model at 0 epochs!
Epoch: [2/30], Train loss: 2.3422
Epoch: [2/30], Valid loss: 2.2748, Valid accuracy: 0.1218
Saving the best model at 1 epochs!
Epoch: [3/30], Train loss: 2.2830
Epoch: [3/30], Valid loss: 2.2013, Valid accuracy: 0.1929
Saving the best model at 2 epochs!
Epoch: [4/30], Train loss: 2.2026
Epoch: [4/30], Valid loss: 2.0716, Valid accuracy: 0.2487
Saving the best model at 3 epochs!
Epoch: [5/30], Train loss: 2.1279
Epoch: [5/30], Valid loss: 1.9948, Valid accuracy: 0.2640
Saving the best model at 4 epochs!
Epoch: [6/30], Train loss: 2.1007
Epoch: [6/30], Valid loss: 1.9407, Valid accuracy: 0.3249
Saving the best model at 5 epochs!
Epoch: [7/30], Train loss: 2.0670
Epoch: [7/30], Valid loss: 1.9217, Valid accuracy: 0.3096
Saving the best model at 6 epochs!
Epoch: [8/30], Train loss: 2.0387
Epoch: [8/30], Valid loss: 1.9618, Valid accuracy: 0.2893
Epoch: [9/30], Train loss: 2.0034
Epoch: [9/30], Valid loss: 1.7882, Valid accuracy: 0.3604
Saving the best model at 8 epochs!
Epoch: [10/30], Train loss: 1.9669
Epoch: [10/30], Valid loss: 1.7608, Valid accuracy: 0.3807
Saving the best model at 9 epochs!
Epoch: [11/30], Train loss: 1.9212
Epoch: [11/30], Valid loss: 1.7428, Valid accuracy: 0.3604
Saving the best model at 10 epochs!
Epoch: [12/30], Train loss: 1.9497
Epoch: [12/30], Valid loss: 1.7381, Valid accuracy: 0.3401
Saving the best model at 11 epochs!
Epoch: [13/30], Train loss: 1.8578
Epoch: [13/30], Valid loss: 1.7946, Valid accuracy: 0.3350
Epoch: [14/30], Train loss: 1.8934
Epoch: [14/30], Valid loss: 1.6822, Valid accuracy: 0.3959
Saving the best model at 13 epochs!
Epoch: [15/30], Train loss: 1.8459
Epoch: [15/30], Valid loss: 1.6475, Valid accuracy: 0.4416
Saving the best model at 14 epochs!
Epoch: [16/30], Train loss: 1.8433
Epoch: [16/30], Valid loss: 1.6429, Valid accuracy: 0.3503
Saving the best model at 15 epochs!
Epoch: [17/30], Train loss: 1.8358
Epoch: [17/30], Valid loss: 2.0232, Valid accuracy: 0.3046
Epoch: [18/30], Train loss: 1.8106
Epoch: [18/30], Valid loss: 1.6712, Valid accuracy: 0.3655
Epoch: [19/30], Train loss: 1.7393
Epoch: [19/30], Valid loss: 2.2497, Valid accuracy: 0.2741
Epoch: [20/30], Train loss: 1.7158
Epoch: [20/30], Valid loss: 1.5637, Valid accuracy: 0.4162
Saving the best model at 19 epochs!
Epoch: [21/30], Train loss: 1.7603
Epoch: [21/30], Valid loss: 1.4845, Valid accuracy: 0.5178
Saving the best model at 20 epochs!
Epoch: [22/30], Train loss: 1.7305
Epoch: [22/30], Valid loss: 1.6282, Valid accuracy: 0.3503
Epoch: [23/30], Train loss: 1.7213
Epoch: [23/30], Valid loss: 1.4270, Valid accuracy: 0.5381
Saving the best model at 22 epochs!
Epoch: [24/30], Train loss: 1.7064
Epoch: [24/30], Valid loss: 1.6344, Valid accuracy: 0.3655
Epoch: [25/30], Train loss: 1.6306
Epoch: [25/30], Valid loss: 1.3873, Valid accuracy: 0.5330
Saving the best model at 24 epochs!
Epoch: [26/30], Train loss: 1.7458
Epoch: [26/30], Valid loss: 1.4194, Valid accuracy: 0.5076
Epoch: [27/30], Train loss: 1.6578
Epoch: [27/30], Valid loss: 1.7264, Valid accuracy: 0.3604
Epoch: [28/30], Train loss: 1.6247
Epoch: [28/30], Valid loss: 1.4872, Valid accuracy: 0.5076
Epoch: [29/30], Train loss: 1.6642
Epoch: [29/30], Valid loss: 1.3975, Valid accuracy: 0.4772
Epoch: [30/30], Train loss: 1.6681
Epoch: [30/30], Valid loss: 1.6023, Valid accuracy: 0.4213

Evaluation

We collect the trained model’s predictions for the test set. As in the validation phase, chunk-level predictions are aggregated to make song-level predictions.

# Load the best model
S = torch.load('best_model.ckpt')
cnn.load_state_dict(S)
print('loaded!')

# Run evaluation
cnn.eval()
y_true = []
y_pred = []

with torch.no_grad():
    for wav, genre_index in test_loader:
        wav = wav.to(device)
        genre_index = genre_index.to(device)

        # reshape and aggregate chunk-level predictions
        b, c, t = wav.size()
        logits = cnn(wav.view(-1, t))
        logits = logits.view(b, c, -1).mean(dim=1)
        _, pred = torch.max(logits.data, 1)

        # append labels and predictions
        y_true.extend(genre_index.tolist())
        y_pred.extend(pred.tolist())
loaded!

Finally, we can assess the performance and visualize the confusion matrix for a better understanding of the model’s behavior.

import seaborn as sns
from sklearn.metrics import confusion_matrix


accuracy = accuracy_score(y_true, y_pred)
cm = confusion_matrix(y_true, y_pred)
sns.heatmap(cm, annot=True, xticklabels=GTZAN_GENRES, yticklabels=GTZAN_GENRES, cmap='YlGnBu')
print('Accuracy: %.4f' % accuracy)
Accuracy: 0.5414
[Figure: confusion matrix heatmap of test-set predictions across the 10 GTZAN genres]

Tip

In this tutorial, we did not use any high-level library so that the implementation remains easy to follow. For a more streamlined implementation, we highly recommend checking out the following libraries: