Visit our Shiny app for interactive machine learning.

This section only documents the corresponding script used for the Shiny app.



library(shiny)
library(DT)
library(rdrop2)
library(e1071)
library(MASS)
library(glmnet)
library(randomForest)
library(caret)
library(readxl)
library(writexl)
library(tidyverse)
library(ComplexHeatmap) # from bionconductor
library(circlize)

library(BiocManager) # ensure smooth use of ComplexHeatmap
options(repos = BiocManager::repositories())

# Import data from Dropbox ----
# Authenticate a single time and reuse the returned token for every Dropbox call.
token = drop_auth()
# NOTE(review): token.rds contains Dropbox credentials; presumably it is saved so
# a deployed copy of the app can readRDS() it instead of re-authenticating.
# Keep this file out of version control.
saveRDS(token, file = "token.rds")

df.content = drop_read_csv(file = "Dropbox AA Content Data.csv", dtoken = token)

# Keep the identifier columns, add the stratification key (category x site),
# and append column 23 onward of the raw file. Those columns are expected to
# hold Data.File (dropped later) plus one column per amino acid; verify against
# the CSV layout if the file changes.
df.content = df.content %>% select(Name, Category, Site) %>%
    mutate(strata.group = paste(Category, Site, sep = "_")) %>%
    cbind(df.content %>% select(23:ncol(df.content))) %>%
    as_tibble()

# categories
# Distinct AIV category labels, in order of first appearance in the data
unique.categories = unique(df.content$Category)



# Define functions ----
# Stratified sampling
# Distinct strata: every category-site combination (strata.group) in the data
unique.categories.site = unique(df.content$strata.group)

# Define function doing stratified sampling based on category-site combination
# Split the data into a training set and a test set, drawing the same fraction
# (trainingRatio) of rows within every category-site stratum (strata.group), so
# each stratum is represented proportionally in the training set.
#
# Arguments:
#   trainingRatio  fraction of each stratum's rows assigned to the training set
#   data           dataset to split (must contain strata.group and Name);
#                  defaults to the global df.content
# Returns:
#   list(training set, test set)
func.stratifiedSampling = function(trainingRatio = 0.7, data = df.content){
    
    # sample_frac() on a grouped tibble samples within each group; the grouping
    # is what makes this split stratified
    df.train = data %>%
        group_by(strata.group) %>%
        sample_frac(size = trainingRatio, replace = FALSE) %>%
        ungroup()
    
    # every row whose Name was not drawn into the training set goes to the test set
    df.test = data %>% anti_join(df.train, by = "Name")
    
    list(df.train, df.test)
}



## Normalization  ----
# Standardize the features of a train/test split using statistics of the
# TRAINING set only (its column means and standard deviations), so the test set
# is scaled exactly as a new, unseen sample would be.
#
# Argument:
#   list  two-element list, 1st: training set, 2nd: test set, i.e. the output of
#         func.stratifiedSampling(). Both must contain the columns Name,
#         Category, Site, strata.group and Data.File; all remaining columns are
#         treated as numeric features.
# Returns:
#   list of two tibbles (scaled training set, scaled test set), each holding a
#   factor Category column followed by the standardized feature columns.
# Errors:
#   if any feature has a zero or undefined standard deviation in the training set.
func.normalize.trainTest = function(list) { 
    
    mat.train = list[[1]] %>% dplyr::select(-c(Name, Category, Site, strata.group, `Data.File`)) %>% as.matrix()
    mat.test = list[[2]] %>% dplyr::select(-c(Name, Category, Site, strata.group, `Data.File`)) %>% as.matrix()
    
    # training-set column means and standard deviations
    mean.train = colMeans(mat.train)
    sd.train = apply(mat.train, 2, sd)
    
    # A constant column (or a one-row training set) cannot be standardized;
    # fail with a clear message rather than producing Inf/NaN features.
    bad.sd = is.na(sd.train) | sd.train == 0
    if (any(bad.sd)) {
        stop("Cannot standardize: zero or undefined training-set standard deviation for ",
             paste(colnames(mat.train)[bad.sd], collapse = ", "), call. = FALSE)
    }
    
    # (x - training mean) / training sd, column by column. sweep() keeps the
    # column names and works for any number of feature columns.
    standardize = function(mat) {
        sweep(sweep(mat, 2, mean.train, "-"), 2, sd.train, "/")
    }
    mat.train.scaled = standardize(mat.train)
    mat.test.scaled = standardize(mat.test)
    
    # Since R 4.0, data.frame() keeps strings as character. caret's
    # confusionMatrix(), which is applied to these Category columns, needs
    # factors sharing one level set, so build that set from both splits.
    train.category = list[[1]]$Category
    test.category = list[[2]]$Category
    category.levels = if (is.factor(train.category)) {
        levels(train.category)
    } else {
        sort(unique(c(as.character(train.category), as.character(test.category))))
    }
    
    # Complete the two matrices with category labels and convert to tibbles
    df.train.scaled = cbind(data.frame(Category = factor(train.category, levels = category.levels)),
                            mat.train.scaled) %>% as_tibble()
    df.test.scaled  = cbind(data.frame(Category = factor(test.category, levels = category.levels)),
                            mat.test.scaled) %>% as_tibble()
    
    # `list` is this function's argument name, so call base::list explicitly
    base::list(df.train.scaled, df.test.scaled)
}

# Chaining prior two functions
# Draw a stratified train/test split. When scaleData is TRUE, also standardize
# both sets with the training-set means and standard deviations.
# Returns list(training set, test set), scaled or unscaled.
func.strata.Norm.trainTest = function(trainingRatio = 0.7, scaleData = T){
    if (scaleData != TRUE) {
        return(func.stratifiedSampling(trainingRatio = trainingRatio))
    }
    split.sets = func.stratifiedSampling(trainingRatio = trainingRatio)
    func.normalize.trainTest(split.sets)
}



# Prediction table tidy up ----
# Confusion matrix
# Define functions: tidy up confusion tables (contingency table and stats) 

# 1) For contingency table
# Reshape a caret confusion-matrix table (Prediction x Reference counts) into a
# wide data frame with one row per predicted class, one count column per
# reference class, and a Model column naming the classifier.
func.tidy.cf.contigencyTable = function(inputTable, ModelName){
    inputTable %>%
        as.data.frame() %>%   # long form: Prediction, Reference, Freq
        pivot_wider(names_from = Reference, values_from = Freq) %>%
        mutate(Model = ModelName)
}

# 2) Summary statistics for each category
# Turn caret's per-class statistics matrix (rows named "Class: <label>") into a
# data frame with a leading Category column holding the bare label, followed by
# the statistics and a trailing Model column.
func.tidy.cf.statsTable = function(inputTable, ModelName){
    stats.tbl = as_tibble(inputTable)
    labeled = cbind(Category = rownames(inputTable), stats.tbl)
    mutate(labeled,
           Category = str_remove(Category, pattern = "Class: "),
           Model = ModelName)
}

# 3) Summarize most votes from all models
# Majority vote: return the label that occurs most often in `vector`.
# table() ignores NA entries. When several labels tie for the top count, the
# label that comes last in table() order wins.
func.mostVotes = function(vector){
    vote.counts = table(vector)
    if (length(vote.counts) == 0) {
        return(names(vote.counts)[1])
    }
    top.positions = which(vote.counts == max(vote.counts))
    names(vote.counts)[top.positions[length(top.positions)]]
}

# 4) Overall summary statistics for each model
# Extract the overall accuracy and its confidence bounds from the named
# `overall` vector of a caret confusionMatrix, tagged with the model name.
# Returns a one-row tibble: Accuracy, AccuracyLower, AccuracyUpper, Model.
func.tidy.cf.statsOveral = function(vector, modelName){
    # Look the statistics up by name rather than by position, so the intent is
    # explicit and independent of the element order of `vector`. [[ ]] errors
    # if a statistic is missing instead of silently returning NA.
    d = data.frame(Accuracy      = vector[["Accuracy"]],
                   AccuracyLower = vector[["AccuracyLower"]],
                   AccuracyUpper = vector[["AccuracyUpper"]],
                   Model = modelName) %>% as_tibble()
    return(d)
}


# arrange model order
# Ordered factors whose level order is the display order of the models.
# Levels are given explicitly; factor() without `levels` sorts them alphabetically.
unique.models = c("LDA", "QDA", "EN", 
                  "RF", "SVM", "NB", "Most voted")
unique.models = factor(unique.models, levels = unique.models, ordered = TRUE)

unique.models2 = c("Linear discriminant analysis", "Quadratic discriminant analysis", "Elastic net", 
                   "Random forest", "Support vector machine", "Naive Bayes", "Most voted")
unique.models2 = factor(unique.models2, levels = unique.models2, ordered = TRUE)


# Entire dataset summary stats (used for tabs #2 & #3 single & batch sample(s) prediction) -----
# Per-compound mean and standard deviation over the whole dataset, one row per
# compound (sorted by compound name by the grouping).
df.summaryALL = df.content %>%
    select(-c(Name, Site, Category, strata.group, Data.File)) %>%
    pivot_longer(everything(), names_to = "compounds", values_to = "content") %>%
    group_by(compounds) %>%
    summarise(content.mean = mean(content),
              content.sd = sd(content)) 


# set up compounds as ordered factor
# Every column that is not an identifier/metadata column is a compound. Select
# by name, like the feature selection elsewhere in this script, rather than
# assuming the compounds start at column 6.
unique.compounds = setdiff(colnames(df.content),
                           c("Name", "Category", "Site", "strata.group", "Data.File"))
unique.compounds = factor(unique.compounds, levels = unique.compounds, ordered = TRUE)


# In compound order as the input matrix
# Reorder the per-compound summary to follow the compound column order
# (unique.compounds). These means and standard deviations are the centering and
# scaling vectors used with scale() to normalize the features of unknown sample(s).
df.summaryALL = df.summaryALL %>% 
    mutate(compounds = factor(compounds, levels = unique.compounds, ordered = TRUE)) %>%
    arrange(compounds)


# AIV Category color
# One fixed color per AIV category, named by the category label
unique.categories = unique(df.content$Category)
color.category = setNames(c("Black", "Steelblue", "Firebrick", "Darkgreen"),
                          unique.categories)


# the computation progress bar
# Step labels displayed, in order, by the modelling progress bar
progress = c("Dataset setup",
             "LDA", "QDA", "Elastic Net", "Random Forest", "SVM", "Naive Bayes",
             "Wrap up", "Output")




# Define user interface ----
ui <- fluidPage(
    
    # Application title
    titlePanel(strong("African Vegetables Classification Predictor")),
    
    # Sidebar: controls for the selected tab. Each conditionalPanel below is
    # shown only while its tab (tracked via input.purpose) is active.
    sidebarLayout(
        sidebarPanel(
            
            # tab 1 ----
            conditionalPanel( 
                'input.purpose === "Machine learning simulator"',
                
                fluidRow(column(12, actionButton(inputId = "Shuffle", 
                                                 label = "  Train models and predict on test set", 
                                                 icon("paper-plane"), 
                                                 style="color: white; background-color: steelblue; border-color: black"))),
                
                # see icons at https://fontawesome.com/icons?from=io
                
                br(), br(),
                
                fluidRow(column(12, sliderInput(inputId = "SplitRatio",
                                                label = "Percent of data used for model training",
                                                min = .5, max = .9, value = .7))),
                
                fluidRow(column(12, checkboxInput(inputId = "SeedBirthday", 
                                                  label = "Set seed: generate reproducible split of train-test dataset", 
                                                  value = FALSE))), 
                
                br(),br(),br(),
                strong("Model names are abbreviated as acronyms:"), br(), br(),
                span(strong("LDA"), ", Linear Discriminant Analysis"), br(),br(),
                span(strong("QDA"), ", Quadratic Discriminant Analysis"), br(),br(),
                span(strong("EN"), ", Elastic Net (alpha = 0.5)"), br(),br(),
                span(strong("RF"), ", Random Forest"), br(),br(),
                span(strong("SVM"), ", Support Vector Machine"), br(),br(),
                span(strong("NB"), ", Naive Bayes"), br(), br()
            ),
            
            # tab 2 ----
            conditionalPanel(
                'input.purpose === "Unknown prediction: single sample"',
                
                fluidRow(column(12, actionButton(inputId = "Predict2", 
                                                 label = " Predict!",
                                                 icon("cocktail"), 
                                                 style="color: white; background-color: steelblue; border-color: black"))),
                
                br(),
                
                # One slider per amino acid. Per the tab text, the defaults are
                # the average levels in African nightshade leaves.
                fluidRow(column(4, sliderInput("alanine", "alanine", min = 0, max = 1000, value = 529)),
                         column(4, sliderInput("arginine", "arginine", min = 0, max = 1000, value = 176)),
                         column(4, sliderInput("asparagine", "asparagine", min = 0, max = 5000, value = 2477))),
                
                fluidRow(column(4, sliderInput("aspartic.acid", "aspartic acid", min = 0, max = 1000, value = 446)),
                         column(4, sliderInput("cysteine", "cysteine", min = 0, max = 100, value = 2)),
                         column(4, sliderInput("glutamic.acid", "glutamic acid", min = 0, max = 1000, value = 88))),
                
                fluidRow(column(4, sliderInput("glutamine", "glutamine", min = 0, max = 1000, value = 465)),
                         column(4, sliderInput("glycine", "glycine", min = 0, max = 1000, value = 72)),
                         column(4, sliderInput("histidine", "histidine", min = 0, max = 1000, value = 136))),
                
                fluidRow(column(4, sliderInput("hydroxyproline", "hydroxyproline", min = 0, max = 100, value = 5)),
                         column(4, sliderInput("isoleucine", "isoleucine", min = 0, max = 1000, value = 175)),
                         column(4, sliderInput("leucine", "leucine", min = 0, max = 1000, value = 136))),
                
                fluidRow(column(4, sliderInput("lysine", "lysine", min = 0, max = 1000, value = 92)),
                         column(4, sliderInput("methionine", "methionine", min = 0, max = 1000, value = 1.7)),
                         column(4, sliderInput("phenylalanine", "phenylalanine", min = 0, max = 1000, value = 312))),
                
                fluidRow(column(4, sliderInput("proline", "proline", min = 0, max = 5000, value = 1754)),
                         column(4, sliderInput("serine", "serine", min = 0, max = 1000, value = 168)),
                         column(4, sliderInput("threonine", "threonine", min = 0, max = 1000, value = 73))),
                
                fluidRow(column(4, sliderInput("tryptophan", "tryptophan", min = 0, max = 1000, value = 113)),
                         column(4, sliderInput("tyrosine", "tyrosine", min = 0, max = 1000, value = 228)),
                         column(4, sliderInput("valine", "valine", min = 0, max = 1000, value = 281)))
            ),
            
            # tab 3 ------
            conditionalPanel(
                'input.purpose === "Unknown prediction: batch sample"',
                fluidRow(
                    column(
                        11,
                        downloadButton("downloadTemplate", "Download table template for batch prediction",
                                       style="color: white; background-color: steelblue; border-color: black"))
                ), 
                
                br(),
                
                fluidRow(
                    column(
                        11,
                        fileInput(inputId = "userfile", label = NULL, buttonLabel = "Browse...Excel input",
                                  width = '700px', placeholder = "Upload free amino acid content dataset"))
                ),
                
                fluidRow(
                    column(
                        11, 
                        checkboxInput(inputId = "ArrangeCategory", 
                                      label = "Arrange heatmap predicted category by most voted results", 
                                      value = FALSE))
                ), 
                
                fluidRow(
                    column(
                        11,
                        downloadButton("downloadPredicted", "Download predicted result table",
                                       style="color: white; background-color: orange; border-color: black"))
                ),
                
                br(),br(),br(),
                
                strong("Model names are abbreviated as acronyms:"), br(), br(),
                span(strong("LDA"), ", Linear Discriminant Analysis"), br(),br(),
                span(strong("QDA"), ", Quadratic Discriminant Analysis"), br(),br(),
                span(strong("EN"), ", Elastic Net (alpha = 0.5)"), br(),br(),
                span(strong("RF"), ", Random Forest"), br(),br(),
                span(strong("SVM"), ", Support Vector Machine"), br(),br(),
                span(strong("NB"), ", Naive Bayes"), br(), br()
                
            )
            
        ),
        
        
        # Main panel: the tabs (their id 'purpose' drives the sidebar panels above)
        mainPanel( 
            tabsetPanel(
                id = 'purpose',
                
                tabPanel("Machine learning simulator", # ------
                         br(),
                         fluidRow(
                             column(
                                 11,
                                 p("The African indigenous vegetables (AIVs) free amino acid dataset (544 rows, 21 columns) contains the content of 21 free amino acids in 544 samples of AIVs, the latter comprising four basic categories, i.e., nightshades, amaranth, spider plant, and mustard. Each category was sourced from different continents or countries, harvest years and seasons, etc., including one or more subspecies and cultivars.",
                                   strong("This app predicts the AIV category based on the content of free amino acids.",
                                          style = "color: black")),
                                 br(),
                                 
                                 strong("Click", code("Train models and predict on test set"), 
                                        "button to train models using specified percent (default 70%) of the AIV amino acid dataset as the training set, and make predictions on the remaining percent of the dataset as the test set. The prediction result is summarized as the confusion matrix, showing counts of correct and incorrect predictions.", 
                                        style = "color: steelblue"),
                                 br(), br(),
                                 p("Each subsequent click of the button reshuffles the dataset to create new training and test sets, updates the model parameters, and refreshes the prediction results. After changing the training set percentage or the seed setting, click the button again to start a new round of training and prediction.", 
                                   style = "color:black"),
                                 br(), br()
                             )
                         ),
                         
                         strong("A. The following confusion matrix presents the summarized prediction results vs. the actual.",
                                style = "color:#006666"),
                         
                         fluidRow(
                             column(8, plotOutput("plt.newShuffle"))
                         ),
                         br(),br(),
                         
                         strong("B. The following heatmap presents the predicted result of each model vs. the actual on a sample-wise basis.",
                                style = "color:#006666"),
                         
                         fluidRow(
                             column(11, plotOutput("plt.newShuffle.heatmap"))
                         ),
                         
                         br(),br(),
                         strong("C. The following sprout plot presents key summary statistics of each model.",
                                style = "color:#006666"),
                         p("The overall accuracy is in bold, with the 95% confidence interval displayed on the following line."),
                         
                         fluidRow(
                             column(11, plotOutput("plt.summaryStats"))
                         )
                ),
                
                tabPanel("Unknown prediction: single sample", # ------
                         br(),
                         
                         fluidRow(
                             column(
                                 11,
                                 p("This section predicts the AIV category of an unknown sample based on its free amino acid content (in units of mg/100g dry weight), using models trained on the entire AIV amino acid dataset (544 rows X 21 columns). The current default setting is the average level of free amino acids in African nightshade leaves."), br(),
                                 
                                 strong("Use the sliders to update the amino acid content of a single unknown sample."), 
                                 br(),
                                 strong("Click", code("Predict!"), "to predict the classification based on the updated feature space.")
                             )
                         ),
                         br(),
                         plotOutput("plt.singleSamplePredict")),
                
                tabPanel("Unknown prediction: batch sample", # ------
                         br(),
                         fluidRow(
                             column(
                                 11, 
                                 p("This section predicts the categories of multiple samples using models trained on the entire AIV amino acid dataset."),
                                 strong("Simply upload the Excel sheet to make a prediction.")
                             )
                         ),
                         br(),
                         
                         strong("A: The following table presents the prediction result for each sample.", 
                                style = "color:#006666"), br(),
                         dataTableOutput("batchPrediction.output1"),
                         br(),
                         
                         strong("B: The following heatmap visualizes the above prediction table as an easy overview.", 
                                style = "color:#006666"),
                         plotOutput("batchPrediction.output2")
                ),
                
                tabPanel("About",
                         br(),
                         fluidRow(
                             column(8,
                                    # the closing "." is a sibling of a(), so it is not part of the link text
                                    strong(span("For R script of this interactive application and more associated data analysis of our publication on African vegetables amino acids,
                                please visit our", a("R markdown code documents", href = "https://yuanbofaith.github.io/AfricanVegetables_AminoAcids/index.html"), "."),
                                           style = "color: #666666")
                             )
                         )
                         
                )
            )
        )
    )
)





# Define server----
server <- function(input, output) {
    
    # tab 1: Simulator
    
    ph  <- eventReactive(input$Shuffle, {
        
        if (input$SeedBirthday == T) {set.seed(19911110)}
        
        withProgress(message = "Modelling...", value = 0, {
            
            for(i in 1:length(progress)) {
                
                # update progress of step i
                incProgress(1/length(progress), detail = progress[i])
                # Pause for 1 sec to simulate a long computation.
                Sys.sleep(1)
                
                ## 70/30 train/test split ----
                if (i == 1){
                    
                    df.learn = func.strata.Norm.trainTest(trainingRatio = input$SplitRatio, scaleData = T) 
                    df.train = df.learn[[1]]
                    df.test = df.learn[[2]]
                    
                    colnames(df.train) = make.names(colnames(df.train))
                    colnames(df.test) = make.names(colnames(df.test))
                } 
                
                ## Linear discriminant analysis (LDA)----
                if (i == 2) {
                    mdl.trained.LDA = lda(Category ~., data = df.train)
                    mdl.fitted.LDA = predict(mdl.trained.LDA, newdata = df.test, prior = rep(1/4, 4))
                    fitted.LDA = mdl.fitted.LDA$class # here we overwrite the prior fitted.LDA object
                    cf.LDA = confusionMatrix(data = fitted.LDA, reference = df.test$Category, mode = "everything")
                    
                    # confusion table and summary stats
                    cf.counts.LDA = cf.LDA$table %>% func.tidy.cf.contigencyTable(ModelName = "LDA")
                    cf.stats.LDA = cf.LDA$byClass %>% func.tidy.cf.statsTable(ModelName = "LDA") 
                }
                
                ## Quadratic discriminant analysis ----
                if (i == 3) {
                    mdl.QDA = qda(Category ~., data = df.train)
                    mdl.fitted.QDA = predict(mdl.QDA, newdata = df.test, prior = rep(1/4, 4))
                    fitted.QDA = mdl.fitted.QDA$class
                    cf.QDA = confusionMatrix(data = fitted.QDA, reference = df.test$Category, mode = "everything")
                    
                    # confusion table and summary stats
                    cf.counts.QDA = cf.QDA$table %>% func.tidy.cf.contigencyTable(ModelName = "QDA") 
                    cf.stats.QDA = cf.QDA$byClass %>% func.tidy.cf.statsTable(ModelName = "QDA")
                    
                }
                
                ## Regularized logistic regression -----
                if (i == 4) {
                    
                    # Define function for performing regularized logistic with different alpha values
                    func.regularizedLogistic = function(
                        input.alpha, # control ridge, lasso, or between
                        ModelName    # model type as extra column note in the confusion table output
                    ){        # train with 10-fold cross validation
                        cv.mdl.logistic = cv.glmnet(x = df.train[, -1] %>% as.matrix(), y = df.train$Category, 
                                                    family = "multinomial", alpha = input.alpha, nfolds = 6)
                        # predict with test set
                        fitted.logistic =
                            predict(cv.mdl.logistic, newx = df.test[, -1] %>% as.matrix(), 
                                    s = cv.mdl.logistic$lambda.1se, type = "class") %>% c() %>% 
                            factor(levels = sort(unique.categories), ordered = T) # Note: important to sort unique.categories!
                        
                        # wrap up prediction results
                        cf.logistic = confusionMatrix(data = fitted.logistic, reference = df.test$Category)
                        cf.counts.logistic = cf.logistic$table %>% func.tidy.cf.contigencyTable(ModelName = ModelName)
                        cf.stats.logistic = cf.logistic$byClass %>% func.tidy.cf.statsTable(ModelName = ModelName)
                        
                        return(list(cf.counts.logistic, cf.stats.logistic, fitted.logistic, cf.logistic))
                        
                    }
                    
                    list.logistic = func.regularizedLogistic(input.alpha = 0.5, ModelName = paste("EN") )
                    cf.counts.ElasticNet =list.logistic[[1]]
                    cf.stats.ElasticNet = list.logistic[[2]]
                    fitted.ElasticNet = list.logistic[[3]]
                    cf.EN = list.logistic[[4]]
                    
                }
                
                ## Random forest ----
                if (i == 5) {
                    mdl.randomForest = randomForest(Category ~., data = df.train, ntree = 500, mtry = 5) 
                    fitted.randomForest = predict(mdl.randomForest, newdata = df.test)
                    
                    # set up confusion table
                    cf.randomForest = confusionMatrix(
                        data = fitted.randomForest, reference = df.test$Category, mode = "everything")
                    # confusion matrix and summary stats
                    cf.counts.randomForest = cf.randomForest$table %>%
                        func.tidy.cf.contigencyTable(ModelName = "RF")
                    
                    cf.stats.randomForest = cf.randomForest$byClass %>% 
                        func.tidy.cf.statsTable(ModelName = "RF")
                    
                }
                
                ## Support vector machine ----
                if (i == 6) {
                    mdl.svm = svm(x = df.train[, -1], y = df.train$Category)
                    # predict
                    fitted.svm = predict(mdl.svm, newdata = df.test[, -1])
                    
                    # confusion matrix
                    cf.svm = confusionMatrix(
                        data = fitted.svm, reference = df.test$Category, mode = "everything")
                    
                    # confusion matrix and summary stats
                    cf.counts.svm = cf.svm$table %>%
                        func.tidy.cf.contigencyTable(ModelName = "SVM")
                    
                    cf.stats.svm = cf.svm$byClass %>% 
                        func.tidy.cf.statsTable(ModelName = "SVM")
                    
                }
                
                
                ## Naive Bayes (benchmark) ----
                if(i == 7) {
                    mdl.Bayes = naiveBayes(x = df.train[, -1], y = df.train$Category)
                    
                    fitted.Bayes = predict(mdl.Bayes, newdata = df.test, type = "class")
                    
                    # confusion matrix
                    cf.Bayes = confusionMatrix(
                        data = fitted.Bayes, reference = df.test$Category, mode = "everything")
                    
                    # confusion matrix and summary stats
                    cf.counts.Bayes = cf.Bayes$table %>%
                        func.tidy.cf.contigencyTable(ModelName = "NB")
                    cf.stats.Bayes = cf.Bayes$byClass %>% 
                        func.tidy.cf.statsTable(ModelName = "NB")
                    
                }
                
                
                ## Wrap up results
                if (i == 8) {
                    # Test results ----
                    # Summary of all machine learning techniques ----
                    
                    # Sample wise prediction of all models and most voted 
                    df.actual.vs.fit = data.frame(
                        "Actual" = df.test$Category, 
                        "LDA" = fitted.LDA, 
                        "QDA" = fitted.QDA, 
                        "EN" = fitted.ElasticNet, 
                        # "CART" = fitted.CART,
                        "RF" =  fitted.randomForest, 
                        "SVM" =  fitted.svm, 
                        "NB" = fitted.Bayes)
                    
                    df.actual.vs.fit = df.actual.vs.fit %>% as_tibble() %>%
                        mutate(Actual = factor(Actual, ordered = F))
                    
                    df.actual.vs.fit = df.actual.vs.fit %>% 
                        mutate(most.voted =  apply(df.actual.vs.fit %>% select(-Actual),
                                                   MARGIN = 1, func.mostVotes))
                    
                    
                    # Most voted confusion matrix
                    cf.mostVoted = confusionMatrix(data = df.actual.vs.fit$most.voted %>% factor(), 
                                                   reference = df.actual.vs.fit$Actual, mode = "everything")
                    
                    cf.counts.MostVoted = cf.mostVoted$table %>%
                        func.tidy.cf.contigencyTable(ModelName = "Most voted")
                    
                    cf.stats.MostVoted = cf.mostVoted$byClass %>%
                        func.tidy.cf.statsTable(ModelName = "Most voted")
                    
                    
                    
                    
                    
                    # Summary of all machine learning techniques
                    df.confusionMatrix.all = cf.counts.LDA %>% rbind(cf.counts.QDA) %>%
                        rbind(cf.counts.ElasticNet) %>% # rbind(cf.counts.CART) %>%
                        rbind(cf.counts.randomForest) %>% rbind(cf.counts.svm) %>%
                        rbind(cf.counts.Bayes) %>% rbind(cf.counts.MostVoted) %>% as_tibble()
                    
                    
                    # tidy up the confusion matrix combined
                    df.confusionMatrix.all.tidy = df.confusionMatrix.all %>%
                        # tidy up
                        gather(-c(Prediction, Model), key = reference, value = counts) %>%
                        
                        # convert AIVs category into ordered factor
                        mutate(reference = factor(reference, levels = unique.categories, ordered = T),
                               Prediction = factor(Prediction, levels = unique.categories %>% rev(), ordered = T)) %>%
                        
                        # change model order in the dataset
                        mutate(Model = factor(Model, levels = unique.models, ordered = T)) %>%
                        arrange(Model, reference, Prediction) %>%
                        mutate(Diaganol = Prediction == reference)
                    
                    
                    
                    
                    ## Plot confusion matrix
                    # Assign color to correct / incorrect prediction
                    diag = df.confusionMatrix.all.tidy %>% 
                        filter(Diaganol == F)
                    
                    df.confusionMatrix.all.tidy = diag %>% 
                        mutate(color = ifelse(counts == 0, "Grey", "Firebrick")) %>%
                        rbind(df.confusionMatrix.all.tidy %>% filter(Diaganol == T) %>%
                                  mutate(color = "steelblue"))
                  
                    
                    # Set up summary statistics
                    df.stats = cf.stats.LDA %>% rbind(cf.stats.QDA) %>%
                        rbind(cf.stats.ElasticNet) %>% rbind(cf.stats.randomForest) %>%
                        rbind(cf.stats.svm) %>% rbind(cf.stats.Bayes) %>% rbind(cf.stats.MostVoted) %>%
                        select(Category, Model, Precision, Recall, F1) %>% 
                        gather(-c(1:2), key = metrics, value = values) %>%
                        mutate(Model = factor(Model, levels = unique.models, ordered = T))
                    
                    
                    df.stats.overal = func.tidy.cf.statsOveral(cf.LDA$overall, modelName = "LDA") %>%
                        rbind(func.tidy.cf.statsOveral(cf.QDA$overall, modelName = "QDA")) %>%
                        rbind(func.tidy.cf.statsOveral(cf.EN$overall, modelName = "EN")) %>%
                        rbind(func.tidy.cf.statsOveral(cf.randomForest$overall, modelName = "RF")) %>%
                        rbind(func.tidy.cf.statsOveral(cf.svm$overall, modelName = "SVM")) %>%
                        rbind(func.tidy.cf.statsOveral(cf.Bayes$overall, modelName = "NB")) %>%
                        rbind(func.tidy.cf.statsOveral(cf.mostVoted$overall, modelName = "Most voted")) %>%
                        mutate(Model = factor(Model, levels = unique.models, ordered = T))
                    
                }
                
                ## Visualization ----
                if (i == 9) {
                    # confusion matrix -----
                    p =   df.confusionMatrix.all.tidy %>% # --------
                    ggplot(aes(x = reference, y = Prediction)) + 
                        facet_wrap(~Model, nrow = 2) +
                        
                        # off diaganol incorrect prediction
                        geom_label(data = df.confusionMatrix.all.tidy %>% filter(color == "Firebrick"),
                                   aes(label = counts), 
                                   fill = "firebrick", alpha = .3, size = 6) +
                        
                        # diaganol correct prediction
                        geom_label(data = df.confusionMatrix.all.tidy %>% filter(color == "steelblue"),
                                   aes(label = counts),
                                   fill = "Steelblue", alpha = .3, size = 6) +
                        
                        # zero counts
                        geom_label(data = df.confusionMatrix.all.tidy %>% filter(color == "Grey"),
                                   aes(label = counts), 
                                   size = 6, color = "grey")  +
                        theme_bw() +
                        
                        theme(axis.text.x = element_text(angle = 90, vjust = .8, hjust = .8, color = "black", size = 12),
                              axis.text.y = element_text(color = "black", size = 12),
                              axis.title = element_text(size = 14, colour = "black"),
                              strip.background = element_blank(),
                              strip.text = element_text(face = "bold", size = 14),
                              panel.border = element_rect(color = "black", size = 1),
                              title = element_text(face = "bold")) + 
                        labs(x = "\nReference", y = "Prediction\n") 
                    
                    
                    # sample-wise comparison heatmap -----
                    df.actual.vs.fit = data.frame(
                        "Actual" = df.test$Category, 
                        "LDA" = fitted.LDA, 
                        "QDA" = fitted.QDA, 
                        "EN" = fitted.ElasticNet, 
                        # "CART" = fitted.CART,
                        "RF" =  fitted.randomForest, 
                        "SVM" =  fitted.svm, 
                        "NB" = fitted.Bayes)
                    
                    df.actual.vs.fit = df.actual.vs.fit %>% as_tibble() %>%
                        mutate(Actual = factor(Actual, ordered = F))
                    df.actual.vs.fit
                    
                    df.actual.vs.fit = df.actual.vs.fit %>% 
                        mutate(most.voted =  apply(df.actual.vs.fit %>% select(-Actual),
                                                   MARGIN = 1, func.mostVotes))
                    
                    
                    # Heatmap of sample-wise predicted result
                    plt.heatmap.machineLearning = 
                        df.actual.vs.fit %>% arrange(Actual) %>%
                        as.matrix() %>% t() %>%
                        Heatmap(col = color.category,
                                heatmap_legend_param = list(
                                    title = "", title_position = "leftcenter",
                                    nrow = 1,
                                    labels_gp = gpar(fontsize = 15)), 
                                rect_gp = gpar(col = "white", lwd = 0.1))
                    
                    h = draw(plt.heatmap.machineLearning, heatmap_legend_side = "bottom")
                    
                    
                    
                    # Summary statistics
                    plt.summaryStats = 
                        df.stats %>% ggplot(aes(x = Category, y = values, color = metrics)) +
                        geom_segment(aes( xend = Category, y = 0.5, yend = values), 
                                     position = position_dodge(0.5)) +
                        geom_point(size = 4, position = position_dodge(.3), alpha = .9) +
                        facet_wrap(~Model, nrow = 1) +
                        theme_bw() +
                        theme(legend.position = "bottom",
                              legend.title = element_text(size = 14),
                              legend.text = element_text(size = 14),
                              
                              strip.text = element_text(face = "bold", size = 14),
                              strip.background = element_blank(),
                              
                              axis.text.x = element_text(angle = 90, hjust = 1, colour = "black", size = 12),
                              axis.text.y = element_text(colour = "black", size = 12),
                              axis.title = element_text(size = 13)) +
                        coord_cartesian(ylim = c(0.65, 1)) +
                        scale_color_brewer(palette = "Accent") +
                        
                        # overal stats
                        geom_text(data = df.stats.overal, 
                                  aes(x = 2.5, y = 0.7, label = round(Accuracy, 3) * 100), 
                                  color = "black", fontface = "bold", size = 6) +
                        
                        geom_text(data = df.stats.overal, 
                                  aes(x = 2.5, y = 0.66, 
                                      label = paste(round(AccuracyLower, 3) * 100, " ~ ",
                                                    round(AccuracyUpper, 3) * 100)), 
                                  color = "black", size = 5) 
                } 
            }
        })
        return(list(p, h, plt.summaryStats))
    })
    
    # tab 1 outputs: ph() returns list(confusion-matrix ggplot, drawn heatmap,
    # summary-statistics ggplot). Extract every element with [[ ]] so that
    # renderPlot receives the plot object itself, not a one-element list.
    output$plt.newShuffle = renderPlot({  ph()[[1]] })
    output$plt.newShuffle.heatmap = renderPlot({  ph()[[2]] })
    output$plt.summaryStats = renderPlot({ ph()[[3]] })
    
    
    
    
    
    # tab 2: Single prediction of real unknown sample ----
    # On input$Predict2, every model is retrained on the full (scaled) dataset
    # and applied to the one sample whose amino-acid contents were typed into
    # the UI. Returns a ggplot with one labelled tile per model.
    plt.singleSample.predict = eventReactive(input$Predict2, {
        
        # withProgress evaluates its block in this function's environment, so
        # objects assigned inside the loop (e.g. p2) are available afterwards.
        withProgress(message = "Modelling...", value = 0, {
            
            for (i in seq_along(progress)) {
                
                # update progress of step i
                incProgress(1/length(progress), detail = progress[i])
                # Pause for 1 sec to simulate a long computation.
                Sys.sleep(1)
                
                # Step 1: use the entire original dataset for training models
                if (i == 1) {
                    
                    # split ratio = 1 keeps every sample in the training set
                    df.train = func.strata.Norm.trainTest(trainingRatio = 1, scaleData = TRUE)[[1]]
                    
                    # collect to-predict feature space
                    df.singleInput = data.frame(
                        
                        leucine = input$leucine,             isoleucine = input$isoleucine,         tryptophan = input$tryptophan,
                        phenylalanine = input$phenylalanine, valine = input$valine,                 methionine = input$methionine, 
                        
                        tyrosine = input$tyrosine,           proline = input$proline,               alanine = input$alanine, 
                        cysteine = input$cysteine,           glycine = input$glycine,               glutamic.acid = input$glutamic.acid, 
                        
                        threonine = input$threonine,         hydroxyproline = input$hydroxyproline, glutamine = input$glutamine,
                        aspartic.acid = input$aspartic.acid, serine = input$serine,                 asparagine = input$asparagine, 
                        
                        arginine = input$arginine,           histidine = input$histidine,           lysine = input$lysine
                        
                    ) %>% as_tibble()
                    
                    # Normalize the unknown sample with the mean and standard deviation
                    # vectors computed from the entire original AIV dataset, overwriting
                    # the raw input feature space.
                    df.singleInput = df.singleInput %>% 
                        scale(center = df.summaryALL$content.mean, scale = df.summaryALL$content.sd) %>%
                        as_tibble()
                } 
                
                ## Linear discriminant analysis (LDA) ----
                if (i == 2) {
                    mdl.trained.LDA = lda(Category ~., data = df.train)
                    # equal prior over the groups the model was trained on
                    n.category = length(mdl.trained.LDA$prior)
                    mdl.fitted.LDA = predict(mdl.trained.LDA, newdata = df.singleInput,
                                             prior = rep(1/n.category, n.category))
                    fitted.LDA = mdl.fitted.LDA$class 
                }
                ## Quadratic discriminant analysis  
                if (i == 3) {
                    mdl.QDA = qda(Category ~., data = df.train)
                    # equal prior over the groups the model was trained on
                    n.category = length(mdl.QDA$prior)
                    mdl.fitted.QDA = predict(mdl.QDA, newdata = df.singleInput,
                                             prior = rep(1/n.category, n.category))
                    fitted.QDA = mdl.fitted.QDA$class
                }
                ## Regularized logistic regression 
                if (i == 4) {
                    cv.mdl.ElasticNet = cv.glmnet(x = df.train[, -1] %>% as.matrix(), y = df.train$Category, 
                                                  family = "multinomial", alpha = .5, nfolds = 6)
                    
                    fitted.ElasticNet = predict(cv.mdl.ElasticNet, newx = df.singleInput %>% as.matrix(), 
                                                s = cv.mdl.ElasticNet$lambda.1se, type = "class")
                }
                ## Random forest 
                if (i == 5) {
                    mdl.randomForest = randomForest(Category ~., data = df.train, ntree = 500, mtry = 5) 
                    fitted.randomForest = predict(mdl.randomForest, newdata = df.singleInput)
                }
                ## Support vector machine
                if (i == 6) {
                    mdl.svm = svm(x = df.train[, -1], y = df.train$Category)
                    # predict
                    fitted.svm = predict(mdl.svm, newdata = df.singleInput)
                }
                ## Naive Bayes (benchmark) ----
                if (i == 7) {
                    mdl.Bayes = naiveBayes(x = df.train[, -1], y = df.train$Category)
                    fitted.Bayes = predict(mdl.Bayes, newdata = df.singleInput, type = "class")
                }
                
                ## Wrap up results as a long table: one row per model (model, result)
                if (i == 8) {
                    # Header names cannot contain spaces, so use underscores and
                    # replace them with spaces afterwards.
                    df.result = data.frame("Linear_discriminant_analysis" = fitted.LDA, 
                                           "Quadratic_discriminant_analysis" = fitted.QDA, 
                                           # flatten glmnet's class matrix to a plain vector, as the batch tab does
                                           "Elastic_net" = fitted.ElasticNet %>% c(), 
                                           "Random_forest" = fitted.randomForest, 
                                           "Support_vector_machine" = fitted.svm, 
                                           "Naive_Bayes" = fitted.Bayes) %>% 
                        as_tibble() %>% 
                        gather(key = model, value = result) %>%
                        mutate(model = str_replace_all(model, pattern = "_", replacement = " "),
                               model = factor(model, levels = unique.models2, ordered = TRUE))
                    
                }
                
                ## Plot one labelled tile per model
                if (i == 9) {
                    # df.result is already long; gathering it again would stack the
                    # `model` and `result` columns themselves and lose the model order.
                    p2 = df.result %>%
                        ggplot(aes(x = 1, y = 1, fill = result)) + 
                        geom_label(aes(label = result),
                                   size = 7, alpha = 1, color = "white", fontface = "bold",
                                   label.padding = unit(1, "lines"), # padding size
                                   label.r = unit(1, "lines")) + # corner rounding
                        facet_wrap(~model) +
                        scale_fill_manual(values = color.category) +
                        theme_minimal() +  
                        theme(plot.background =  element_blank(), 
                              axis.title = element_blank(), axis.text = element_blank(),
                              strip.text = element_text(face = "bold", size = 16, color = "#666666"),
                              legend.position = "none") +
                        # a single coordinate system; an earlier coord_flip() would only be replaced
                        coord_fixed(.5)
                    
                }
            }
        })
        return(p2)
    })
    
    output$plt.singleSamplePredict = renderPlot({  plt.singleSample.predict() })
    
    
    
    
    
    
    # tab 3: Batch prediction of real unknown samples ----
    # Triggered by an uploaded Excel file. Every model is retrained on the full
    # (scaled) dataset and applied to each uploaded sample. Returns a
    # one-element list holding the prediction table: Sample, one column per
    # model, and most.voted.
    # NOTE(review): assumes the first column is `Sample` (dropped via [, -1])
    # and that the remaining columns follow the same order as
    # df.summaryALL$content.mean, since scale() matches by position. Confirm
    # against the download template.
    tab3.output = eventReactive(input$userfile, {
        
        # Read user-input excel file
        infile <- input$userfile
        df.batchPredict <- read_excel(infile$datapath)
        
        # withProgress evaluates its block in this function's environment, so
        # objects assigned inside the loop are available afterwards.
        withProgress(message = "Modelling...", value = 0, {
            
            for (i in seq_along(progress)) {
                
                # update progress of step i
                incProgress(1/length(progress), detail = progress[i])
                # Pause for 1 sec to simulate a long computation.
                Sys.sleep(1)
                
                # Step 1: use the entire original dataset for training models
                if (i == 1) {
                    
                    # split ratio = 1 keeps every sample in the training set
                    df.train = func.strata.Norm.trainTest(trainingRatio = 1, scaleData = TRUE)[[1]]
                    
                    # Normalize the unknown batch with the mean and standard deviation
                    # vectors computed from the entire original AIV dataset
                    df.batchPredict.featureSpace.scaled = df.batchPredict[, -1] %>% 
                        scale(center = df.summaryALL$content.mean, scale = df.summaryALL$content.sd) %>%
                        as_tibble()
                } 
                
                ## Linear discriminant analysis (LDA) ----
                if (i == 2) {
                    mdl.trained.LDA = lda(Category ~., data = df.train)
                    # equal prior over the groups the model was trained on
                    n.category = length(mdl.trained.LDA$prior)
                    mdl.fitted.LDA = predict(mdl.trained.LDA, newdata = df.batchPredict.featureSpace.scaled,
                                             prior = rep(1/n.category, n.category))
                    fitted.LDA = mdl.fitted.LDA$class 
                }
                ## Quadratic discriminant analysis  
                if (i == 3) {
                    mdl.QDA = qda(Category ~., data = df.train)
                    # equal prior over the groups the model was trained on
                    n.category = length(mdl.QDA$prior)
                    mdl.fitted.QDA = predict(mdl.QDA, newdata = df.batchPredict.featureSpace.scaled,
                                             prior = rep(1/n.category, n.category))
                    fitted.QDA = mdl.fitted.QDA$class
                }
                ## Regularized logistic regression 
                if (i == 4) {
                    cv.mdl.ElasticNet = cv.glmnet(x = df.train[, -1] %>% as.matrix(), y = df.train$Category, 
                                                  family = "multinomial", alpha = .5, nfolds = 6)
                    
                    fitted.ElasticNet = predict(cv.mdl.ElasticNet, newx = df.batchPredict.featureSpace.scaled %>% as.matrix(), 
                                                s = cv.mdl.ElasticNet$lambda.1se, type = "class")
                }
                ## Random forest 
                if (i == 5) {
                    mdl.randomForest = randomForest(Category ~., data = df.train, ntree = 500, mtry = 5) 
                    fitted.randomForest = predict(mdl.randomForest, newdata = df.batchPredict.featureSpace.scaled)
                }
                ## Support vector machine
                if (i == 6) {
                    mdl.svm = svm(x = df.train[, -1], y = df.train$Category)
                    # predict
                    fitted.svm = predict(mdl.svm, newdata = df.batchPredict.featureSpace.scaled)
                }
                ## Naive Bayes (benchmark) ----
                if (i == 7) {
                    mdl.Bayes = naiveBayes(x = df.train[, -1], y = df.train$Category)
                    fitted.Bayes = predict(mdl.Bayes, newdata = df.batchPredict.featureSpace.scaled, type = "class")
                }
                
                ## Wrap up results
                if (i == 8) {
                    # Summary table: one row per uploaded sample, one column per model
                    d.fitted.all = data.frame("LDA" = fitted.LDA, 
                                              "QDA" = fitted.QDA, 
                                              # flatten glmnet's class matrix to a plain vector
                                              "EN" = fitted.ElasticNet %>% c(), 
                                              "RF" = fitted.randomForest, 
                                              "SVM" = fitted.svm, 
                                              "NB" = fitted.Bayes) 
                    # summarize the most-voted category per sample (row-wise)
                    d.fitted.all = d.fitted.all %>% 
                        mutate(most.voted = apply(d.fitted.all, MARGIN = 1, func.mostVotes)) 
                    
                    # combine sample name with prediction result as output table
                    d.fitted.all.named.1 = cbind(Sample = df.batchPredict$Sample, d.fitted.all) 
                    
                }
            }
            
        })
        return(list(d.fitted.all.named.1))
    }) 
    
    output$batchPrediction.output1 = renderDataTable(tab3.output()[[1]] )
    
    
    
    
    # Heatmap of sample-wise predicted result -----
    # Draws the tab-3 prediction table as a heatmap. After transposing, there is
    # one column per uploaded sample and one row per model column (including
    # most.voted), with cells colored via color.category.
    tab3.heatmap = reactive({
        d.heat = tab3.output()[[1]]
        
        # Optionally order samples by most-voted category, following the order of
        # unique.categories. isTRUE(as.logical(.)) accepts a logical or a
        # "TRUE"/"FALSE" input. A missing (NULL) input counts as FALSE instead of
        # making if() fail on a zero-length condition.
        if (isTRUE(as.logical(input$ArrangeCategory))) {
            d.heat = d.heat %>% 
                mutate(most.voted = factor(most.voted, levels = unique.categories, ordered = TRUE)) %>%
                arrange(most.voted) 
        }
        
        # prediction matrix (samples x models) with sample names as row names;
        # otherwise the original row order is kept
        SampleNames = d.heat$Sample
        d.heat = d.heat %>% select(-Sample) %>% as.matrix()
        rownames(d.heat) = SampleNames
        
        plt.heatmap.batchPrediction = 
            d.heat %>% t() %>%
            Heatmap(col = color.category,
                    heatmap_legend_param = list(
                        title = "", title_position = "leftcenter",
                        nrow = 1,
                        labels_gp = gpar(fontsize = 15)), 
                    rect_gp = gpar(col = "white", lwd = 0.1))
        
        # return the drawn HeatmapList explicitly so it is the reactive's value
        kk = draw(plt.heatmap.batchPrediction, heatmap_legend_side = "bottom")
        return(kk)
    })
    
    
    output$batchPrediction.output2 = renderPlot(tab3.heatmap())
    
    
    
    # Download template ----
    # When the server function runs, read the batch-prediction template from
    # Dropbox and offer it to the user as an Excel workbook.
    df.template = drop_read_csv(file = "Template for batch prediction.csv")
    
    output$downloadTemplate = downloadHandler(
        filename = "Template for batch prediction.xlsx",
        content = function(file) write_xlsx(x = df.template, path = file)
    )
    
    # Download predicted table ----
    # Export the tab-3 prediction table (first element of tab3.output()) as an
    # Excel workbook.
    output$downloadPredicted = downloadHandler(
        filename = "Predicted results.xlsx",
        content = function(file) {
            predicted.table = tab3.output()[[1]]
            write_xlsx(x = predicted.table, path = file)
        }
    )
    
}

# Launch the app ----
shinyApp(ui = ui, server = server)