#StackBounty: #python #machine-learning #neural-network #pytorch #training-data training by batches leads to more over-fitting

Bounty: 50

I'm training a sequence-to-sequence (seq2seq) model and want to try several values of input_sequence_length. For 10 and 15 I get acceptable results, but with 20 I run into memory errors. So I switched to training in batches, but then the model over-fits and the validation loss explodes. Even with gradient accumulation I get the same behavior.
So I'm looking for hints and leads toward a more accurate way to do the update, given the memory restriction.
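
One point worth checking with gradient accumulation is the scaling. In PyTorch, `loss.backward()` adds each batch's gradient into `.grad`, so backpropagating `k` per-batch mean losses without rescaling gives roughly `k` times the full-batch gradient, which behaves like a learning rate `k` times larger. The minimal NumPy sketch below uses a plain linear model with mean-squared error as a stand-in for the seq2seq model (an assumption for illustration; the helpers `mse_grad`, `accumulated_grad` and `summed_grad` are hypothetical). It shows that weighting each batch by its share of the data recovers the full-batch gradient exactly, while the plain sum does not:

```python
import numpy as np

# Stand-in for the seq2seq model (assumption): linear model w with MSE loss.
def mse_grad(w, X, y):
    """Gradient of mean((X @ w - y) ** 2) with respect to w."""
    residual = X @ w - y
    return 2.0 * X.T @ residual / len(y)

def accumulated_grad(w, X, y, batch_size):
    """Accumulate per-batch gradients, each weighted by its share of the data."""
    n = len(y)
    total = np.zeros_like(w)
    for start in range(0, n, batch_size):
        Xb, yb = X[start:start + batch_size], y[start:start + batch_size]
        # Weighting by len(yb) / n keeps a smaller final batch from being over-counted
        total += mse_grad(w, Xb, yb) * len(yb) / n
    return total

def summed_grad(w, X, y, batch_size):
    """Unscaled accumulation: plain sum of the per-batch mean gradients."""
    return sum(mse_grad(w, X[s:s + batch_size], y[s:s + batch_size])
               for s in range(0, len(y), batch_size))

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
y = rng.normal(size=100)
w = rng.normal(size=3)

full = mse_grad(w, X, y)
print(np.allclose(accumulated_grad(w, X, y, batch_size=20), full))  # True
print(np.allclose(summed_grad(w, X, y, batch_size=20), 5 * full))   # True: 5 batches, 5x too large
```

In PyTorch the same weighting would be `(loss * len(batch) / n_samples).backward()`, or simply `(loss / k).backward()` when all `k` batches have the same size, assuming `criterion` uses the default `reduction='mean'`.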

Here is my training function (only the batch section):

    if batch_size is not None:
        n_samples = X_train_tensor.size(0)
        # Start index of every mini-batch: range() steps through the whole
        # training set in strides of batch_size
        batch_starts = range(0, n_samples, batch_size)
        k = len(batch_starts)  # number of mini-batches per epoch
        for epoch in range(num_epochs):
            optimizer.zero_grad()
            epoch_loss = 0.0
            for i in batch_starts:
                sequence = X_train_tensor[i:i+batch_size,:,:].reshape(-1, sequence_length, input_size).to(device)
                labels = y_train_tensor[i:i+batch_size,:,:].reshape(-1, sequence_length, output_size).to(device)
                # Forward pass
                outputs = model(sequence)
                loss = criterion(outputs, labels)
                epoch_loss += loss.item()
                # Backward pass: gradients accumulate in .grad across batches;
                # dividing by k averages them instead of summing them
                # (exact when all batches have the same size)
                (loss / k).backward()

            # One optimizer update per epoch from the accumulated gradients
            optimizer.step()
            epoch_loss /= k
            model.eval()  # switch to evaluation mode (dropout, batch norm, etc.)
            validation_loss, _ = evaluate(model, X_test_hard_tensor_1, y_test_hard_tensor_1)
            model.train()
            training_loss_log.append(epoch_loss)
            print('Epoch [{}/{}], Train MSELoss: {}, Validation: {}'.format(
                epoch + 1, num_epochs, epoch_loss, validation_loss))
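
The batch start indices can also be checked on their own. The small sketch below compares the starts produced by `np.arange(0, n_samples // batch_size - 1, batch_size)`, the expression in the loop above, with `range(0, n_samples, batch_size)`, and counts how many distinct samples the slices `[i:i + batch_size]` actually reach. The helper names are hypothetical and the sizes are made up for illustration:

```python
import numpy as np

def original_batch_starts(n_samples, batch_size):
    """Start indices computed with the arange expression from the question."""
    return list(np.arange(0, n_samples // batch_size - 1, batch_size))

def full_batch_starts(n_samples, batch_size):
    """Start indices that step through every sample exactly once."""
    return list(range(0, n_samples, batch_size))

def covered(starts, n_samples, batch_size):
    """Number of distinct samples touched by the slices [i:i + batch_size]."""
    seen = set()
    for i in starts:
        seen.update(range(i, min(i + batch_size, n_samples)))
    return len(seen)

n_samples, batch_size = 10_000, 10
print(covered(original_batch_starts(n_samples, batch_size), n_samples, batch_size))  # 1000
print(covered(full_batch_starts(n_samples, batch_size), n_samples, batch_size))      # 10000
```

With these sizes the arange-based loop only ever sees the first 1,000 of 10,000 samples, about `1 / batch_size` of the training set. If the real data behaves the same way, this alone could explain a falling training loss alongside an exploding validation loss.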

