import torch
import math

"""
Utility functions for input normalization and shuffled mini-batch iteration.
These should not need to change.
"""

def normalize(inputs, input_means, input_stds):
    """Standardize ``inputs`` as ``(inputs - input_means) / (input_stds + 1e-6)``.

    The small constant added to the denominator keeps the division finite
    when a standard deviation is zero. Any operands supporting ``-``, ``+``
    and ``/`` (with broadcasting, e.g. tensors or plain floats) work.
    """
    centered = inputs - input_means
    denominator = input_stds + 0.000001
    return centered / denominator

def batch_iter(train_data, label_data, batch_size=8):
    """Yield shuffled, per-batch-normalized mini-batches for training.

    The row indices of ``train_data`` are shuffled with ``torch.randperm`` and
    split into consecutive chunks of at most ``batch_size`` rows (the last
    chunk may be smaller). Each chunk of inputs is normalized with its own
    column-wise (dim 0) mean and standard deviation via ``normalize``.

    Args:
        train_data: Samples indexed along dim 0.
        label_data: Labels indexed by the same sample indices as
            ``train_data``.
        batch_size: Maximum number of rows per batch; must be >= 1.

    Yields:
        ``(input_means, input_stds, inputs, labels)`` for each batch, where
        ``input_means``/``input_stds`` are the statistics used to normalize
        ``inputs`` and ``labels`` are the matching entries of ``label_data``.

    Raises:
        ValueError: If ``batch_size`` is less than 1 (raised when iteration
            starts, since this is a generator).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    train_size = train_data.shape[0]
    rand_idx = torch.randperm(train_size)
    for start in range(0, train_size, batch_size):
        indices = rand_idx[start:start + batch_size]
        inputs = train_data[indices]

        input_means = inputs.mean(dim=0)
        if inputs.shape[0] > 1:
            input_stds = inputs.std(dim=0)
        else:
            # torch.std's default unbiased (N-1) estimator is NaN for a single
            # row, which would turn the whole batch into NaNs. A lone sample
            # has no spread, so use a zero std; normalize's epsilon keeps the
            # division finite and the normalized row becomes all zeros.
            input_stds = torch.zeros_like(input_means)
        inputs = normalize(inputs, input_means, input_stds)

        labels = label_data[indices]
        yield input_means, input_stds, inputs, labels
       