  • Analyzing MNIST data
    R/Keras 2018. 8. 8. 11:25

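    To do this, you first need to:

    1. Install Anaconda (it provides the Python environment that Keras and TensorFlow run in).
    2. Set up the related configuration, so that R (through the reticulate package) can find that Python installation.

    Then run the following code.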

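    # Keras with R
    # Install the packages (only needs to be done once)
    install.packages("devtools")
    install.packages("yaml")
    install.packages("tensorflow")
    install.packages("keras")
    install.packages("reticulate")
    library(devtools)
    library(yaml)
    library(tensorflow)
    # version = "gpu" installs the GPU build (needs an NVIDIA GPU with CUDA);
    # leave the argument out to install the CPU build instead
    tensorflow::install_tensorflow(version = "gpu")
    library(keras)
    library(reticulate)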
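    A quick way to confirm that the installation and configuration steps worked is to ask the keras package whether it can reach a working Keras/TensorFlow in Python. The helper name `keras_ready()` below is hypothetical; it only wraps `requireNamespace()` and `keras::is_keras_available()` so that it returns `FALSE` instead of failing when something is missing.

```r
# Hypothetical helper: TRUE if the R keras package is installed and can
# reach a working Keras/TensorFlow installation in Python, FALSE otherwise.
keras_ready <- function() {
  if (!requireNamespace("keras", quietly = TRUE)) {
    return(FALSE)  # the R package itself is not installed
  }
  # is_keras_available() checks the Python side; guard against errors
  # raised while reticulate initializes Python.
  tryCatch(isTRUE(keras::is_keras_available()), error = function(e) FALSE)
}

if (!keras_ready()) {
  message("Keras is not available yet: check the Anaconda/Python setup.")
}
```

    If the message appears, `reticulate::py_config()` shows which Python installation R is actually using.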



    # Import Data
    # x: 28x28 grayscale images, pixel values 0-255 (60,000 train / 10,000 test)
    # y: digit labels 0-9
    mnist   <- keras::dataset_mnist()
    x_train <- mnist$train$x
    y_train <- mnist$train$y
    x_test  <- mnist$test$x
    y_test  <- mnist$test$y



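    # Reshape: flatten each 28x28 image into a vector of 784 (= 28 * 28) values.
    # array_reshape() uses row-major (C-style) order, as NumPy-based Keras expects.
    x_train <- reticulate::array_reshape(x_train, c(nrow(x_train), 784))
    x_test  <- reticulate::array_reshape(x_test, c(nrow(x_test), 784))

    # Rescale pixel values from 0-255 to 0-1
    x_train <- x_train / 255
    x_test  <- x_test / 255

    # One-hot encode the labels (digits 0-9 -> 10 classes)
    y_train <- keras::to_categorical(y_train, num_classes = 10)
    y_test  <- keras::to_categorical(y_test, num_classes = 10)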
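    Two details of this step are easy to miss: `array_reshape()` fills the 784 columns in row-major (C-style) order, whereas base R's `matrix()` and `dim<-` use column-major order, and `to_categorical()` turns each label into a row containing a single 1. The base-R sketch below reproduces both on a tiny made-up array; `flatten_row_major()` and `one_hot()` are hypothetical helpers for illustration, not keras functions.

```r
# Toy stand-in for mnist$train$x: 2 "images" of 2 x 3 pixels (made-up values)
x <- array(0:11, dim = c(2, 2, 3))

# Hypothetical helper: flatten each image row by row (C-style order),
# which is the pixel order reticulate::array_reshape() produces.
flatten_row_major <- function(x) {
  t(apply(x, 1, function(img) as.vector(t(img))))
}

# Hypothetical helper: one-hot encode integer labels 0..(num_classes - 1),
# mirroring what keras::to_categorical() returns.
one_hot <- function(y, num_classes) {
  m <- matrix(0, nrow = length(y), ncol = num_classes)
  m[cbind(seq_along(y), y + 1)] <- 1  # label k goes in column k + 1
  m
}

x[1, , ]                    # first image: rows (0, 4, 8) and (2, 6, 10)
flatten_row_major(x)[1, ]   # row-major: 0 4 8 2 6 10
matrix(x, nrow = 2)[1, ]    # column-major: 0 2 4 6 8 10 (a different pixel order)
one_hot(c(5, 0, 9), num_classes = 10)
```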



    # Defining the Model
    model <- keras::keras_model_sequential()

    model %>%
      keras::layer_dense(units = 256, activation = 'relu', input_shape = c(784)) %>%
      keras::layer_dropout(rate = 0.4) %>%
      keras::layer_dense(units = 128, activation = 'relu') %>%
      keras::layer_dropout(rate = 0.3) %>%
      keras::layer_dense(units = 10, activation = 'softmax')

    summary(model)
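    For a dense layer, the parameter count is (inputs × units) weights plus one bias per unit, and dropout layers add no parameters. The short base-R calculation below (the helper `dense_params()` is only for illustration) gives the totals that `summary(model)` should report for this architecture.

```r
# Weights (n_in x n_out) plus one bias per output unit
dense_params <- function(n_in, n_out) n_in * n_out + n_out

params <- c(
  dense_1 = dense_params(784, 256),  # 200,960
  dense_2 = dense_params(256, 128),  #  32,896
  dense_3 = dense_params(128, 10)    #   1,290
)
params
sum(params)  # 235,146 trainable parameters (dropout layers contribute 0)
```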


    # Compiling the Model
    model %>%
      keras::compile(loss      = 'categorical_crossentropy',
                     optimizer = keras::optimizer_rmsprop(),
                     metrics   = c('accuracy'))
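    `loss = 'categorical_crossentropy'` compares the softmax output of the last layer with the one-hot labels: for each image the loss is minus the log of the probability assigned to the true class, averaged over the batch. A base-R sketch of both pieces follows; the helper names are hypothetical, and the clipping constant `eps = 1e-7` is an assumption standing in for the clipping Keras applies internally to keep `log()` finite.

```r
# Softmax: turn a vector of raw scores into probabilities that sum to 1.
# Subtracting the maximum first avoids overflow in exp().
softmax <- function(z) {
  e <- exp(z - max(z))
  e / sum(e)
}

# Categorical cross-entropy averaged over rows.
# y_true: one-hot matrix; y_pred: matrix of predicted probabilities.
# eps is an assumed clipping constant that keeps log() finite.
categorical_crossentropy <- function(y_true, y_pred, eps = 1e-7) {
  y_pred <- pmin(pmax(y_pred, eps), 1 - eps)
  mean(-rowSums(y_true * log(y_pred)))
}

p <- softmax(c(2, 1, 0.1))
y_true <- rbind(c(1, 0, 0), c(0, 1, 0))
y_pred <- rbind(c(0.7, 0.2, 0.1), c(0.1, 0.8, 0.1))
categorical_crossentropy(y_true, y_pred)  # mean of -log(0.7) and -log(0.8)
```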



    # Training the Model (the last 20% of the training data is held out for validation)
    history <- model %>%
      keras::fit(x_train,
                 y_train,
                 epochs           = 30,
                 batch_size       = 128,
                 validation_split = 0.2)

    plot(history)
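    With `validation_split = 0.2`, Keras holds out the last 20% of the 60,000 training images (taken before shuffling), so the model trains on 48,000 images in mini-batches of 128. The plain-R arithmetic below works out what that means per epoch; it does not call keras.

```r
n_train    <- 60000  # number of images in mnist$train
val_split  <- 0.2
batch_size <- 128
epochs     <- 30

n_val   <- round(n_train * val_split)  # held-out validation images
n_fit   <- n_train - n_val             # images actually used for training
val_idx <- (n_fit + 1):n_train         # the *last* rows are used for validation
steps_per_epoch <- ceiling(n_fit / batch_size)  # a final partial batch still counts

c(n_fit = n_fit, n_val = n_val,
  steps_per_epoch = steps_per_epoch,
  total_steps = steps_per_epoch * epochs)
```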



    # Evaluating the Model (returns the loss and accuracy on the test set)
    model %>%
      keras::evaluate(x_test, y_test)
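    `evaluate()` reports the test-set loss and the `'accuracy'` metric chosen at compile time. With one-hot labels, accuracy is the fraction of images whose highest predicted probability falls on the true class. The base-R sketch below computes it for a made-up 3-image example; `accuracy_from_probs()` is a hypothetical helper, not a keras function.

```r
# Fraction of rows where the most probable class equals the true class.
# y_true: one-hot labels; y_pred: predicted probabilities (same shape).
accuracy_from_probs <- function(y_true, y_pred) {
  mean(max.col(y_pred, ties.method = "first") ==
       max.col(y_true, ties.method = "first"))
}

y_true <- rbind(c(0, 1, 0), c(1, 0, 0), c(0, 0, 1))
y_pred <- rbind(c(0.1, 0.8, 0.1),   # correct
                c(0.3, 0.6, 0.1),   # wrong: predicts column 2, truth is column 1
                c(0.2, 0.1, 0.7))   # correct
accuracy_from_probs(y_true, y_pred)  # 2 of 3 correct
```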



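    # Prediction: the predicted digit (0-9) for each test image
    # Note: predict_classes() was removed in newer Keras/TensorFlow releases;
    # there, take the index of the largest predicted probability instead.
    model %>%
      keras::predict_classes(x_test)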
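    `predict_classes()` returns, for each test image, the index of the largest softmax output, which for MNIST is the predicted digit. In Keras versions where that function no longer exists, the same result can be obtained from the probability matrix returned by `predict()`. The base-R sketch below (`probs_to_digits()` is a hypothetical helper) does that conversion on made-up probabilities and tabulates predictions against true labels as a confusion matrix.

```r
# Convert a matrix of class probabilities (one row per image, one column
# per digit 0-9) into predicted digits: column k holds digit k - 1.
probs_to_digits <- function(probs) {
  max.col(probs, ties.method = "first") - 1
}

# Made-up probabilities for 4 images (stand-in for predict(model, x_test))
probs <- matrix(0.01, nrow = 4, ncol = 10)
probs[1, 8] <- 0.91   # digit 7
probs[2, 3] <- 0.91   # digit 2
probs[3, 2] <- 0.91   # digit 1
probs[4, 1] <- 0.91   # digit 0

predicted <- probs_to_digits(probs)
actual    <- c(7, 2, 1, 6)  # made-up true labels (the last image is misclassified)

predicted
table(predicted = predicted, actual = actual)
mean(predicted == actual)   # accuracy on these 4 images
```

    With the real model, the corresponding call would be `probs_to_digits(predict(model, x_test))`.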



    [Source] https://keras.rstudio.com/


Designed by Tistory.