
ggplot functions are not applied to all features in ALE and PDP plots

bappa10085 opened this issue.

I am trying to use ggplot2 functions with the `plot()` function of the iml package, but the additional components (e.g. `theme_bw()`, `theme()`) are applied only to the last plot. Here is an example:

# Load required packages
library("iml")
library("randomForest")
library("ggplot2")  # theme_bw(), theme() and element_text() come from ggplot2

# Train a random forest on the Boston dataset:
data("Boston", package = "MASS")
set.seed(42)
rf <- randomForest(medv ~ ., data = Boston, ntree = 50)

# Create a Predictor object that holds the model and the data
X <- Boston[which(names(Boston) != "medv")]
predictor <- Predictor$new(rf, data = X, y = Boston$medv)

# Compute the accumulated local effects for all features
ale <- FeatureEffects$new(predictor, method = "ale")

plot(ale) + theme_bw() +
  theme(panel.grid.major = element_blank(),
        panel.grid.minor = element_blank())+
  theme(text=element_text(family = "serif", size=15), 
        axis.title.x = element_text(size = 18),
        axis.title.y = element_text(size = 18),
        axis.text.x = element_text(colour="black",face="bold"), 
        axis.text.y = element_text(colour="black",face="bold"))

[Screenshot: resulting ALE plot]
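A possible explanation, which I have not verified against the iml source: if `plot()` on a `FeatureEffects` object assembles the per-feature panels with the patchwork package, then `+` modifies only the last panel by design, while patchwork's `&` operator applies a component to every panel. The sketch below uses two toy ggplot2 plots built from `mtcars` (not from iml) to show the difference; if the assumption holds, `plot(ale) & theme_bw() & theme(...)` would style all panels.

```r
library(ggplot2)
library(patchwork)

# Two toy panels standing in for the per-feature plots drawn by iml
p1 <- ggplot(mtcars, aes(wt, mpg)) + geom_point()
p2 <- ggplot(mtcars, aes(hp, mpg)) + geom_point()

# With patchwork loaded, adding two ggplots together gives a patchwork object
combined <- p1 + p2

# `+` on a patchwork modifies only the last (active) panel, here p2
last_only <- combined + theme_bw()

# `&` applies the components to every panel in the composition
all_panels <- combined & theme_bw() &
  theme(panel.grid.major = element_blank(),
        panel.grid.minor = element_blank())
```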

The same happens with the PDP plot:

# Compute the partial dependence plot for all features
pdp <- FeatureEffects$new(predictor, method = "pdp")
pdp$plot() + theme_bw() +
  theme(panel.grid.major = element_blank(),
        panel.grid.minor = element_blank())+
  theme(text=element_text(family = "serif", size=15), 
        axis.title.x = element_text(size = 18),
        axis.title.y = element_text(size = 18),
        axis.text.x = element_text(colour="black",face="bold"), 
        axis.text.y = element_text(colour="black",face="bold"))

[Screenshot: resulting PDP plot]
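Because the same theme settings are repeated for every plot in this issue, they could be stored once in an object and reused; a ggplot2 theme is an ordinary value that can be added to a plot. A minimal sketch on a toy `mtcars` plot; `my_theme` is a hypothetical name, and for a patchwork composition it would be added with `&` rather than `+`.

```r
library(ggplot2)

# The custom settings from the issue, collected into one reusable theme object
my_theme <- theme_bw() +
  theme(panel.grid.major = element_blank(),
        panel.grid.minor = element_blank(),
        text = element_text(family = "serif", size = 15),
        axis.title.x = element_text(size = 18),
        axis.title.y = element_text(size = 18),
        axis.text.x = element_text(colour = "black", face = "bold"),
        axis.text.y = element_text(colour = "black", face = "bold"))

# Apply it to a single ggplot with one `+`
p <- ggplot(mtcars, aes(wt, mpg)) + geom_point() + my_theme
```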

As a workaround, I extracted the results stored in the `ale` and `pdp` objects and plotted them myself with ggplot2. With this approach, all the ggplot2 functions work as expected:

# PDP plotting
library("purrr")  # provides imap_dfr() and re-exports the %>% pipe
df2 <- pdp$results %>% 
  imap_dfr(., ~ data.frame(name = .y, Borders = .x$.borders, Value = .x$.value))

df2 %>% 
  ggplot(aes(x = Borders, y = Value)) +
  geom_line(aes(group = name)) +
  facet_wrap(name~., scales = "free") + theme_bw() +
  theme(panel.grid.major = element_blank(),
        panel.grid.minor = element_blank())+
  theme(text=element_text(family = "serif", size=15), 
        axis.title.x = element_text(size = 18),
        axis.title.y = element_text(size = 18),
        axis.text.x = element_text(colour="black",face="bold"), 
        axis.text.y = element_text(colour="black",face="bold"))

[Screenshot: custom PDP plot with the theme applied to every panel]

#ALE plotting
df3 <- ale$results %>% 
  imap_dfr(., ~ data.frame(name = .y, Borders = .x$.borders, value = .x$.value))

df3 %>% 
  ggplot(aes(x = Borders, y = value)) +
  geom_line(aes(group = name)) +
  facet_wrap(name~., scales = "free") + theme_bw() +
  theme(panel.grid.major = element_blank(),
        panel.grid.minor = element_blank()) +
  theme(text=element_text(family = "serif", size=15), 
        axis.title.x = element_text(size = 18),
        axis.title.y = element_text(size = 18),
        axis.text.x = element_text(colour="black",face="bold"), 
        axis.text.y = element_text(colour="black",face="bold"))

[Screenshot: custom ALE plot with the theme applied to every panel]
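The PDP and ALE workarounds differ only in the object they read from, so the reshaping could be factored into one helper. The sketch below avoids purrr by using base R's `Map()` and `rbind()`; like the code above, it assumes each element of `results` is a data frame with `.borders` and `.value` columns. The helper name `effects_to_long` and the toy input (made-up numbers) are hypothetical.

```r
# Hypothetical helper: stack a named list of per-feature result data frames
# (each with `.borders` and `.value` columns) into one long data frame.
effects_to_long <- function(results) {
  pieces <- Map(function(res, feature) {
    data.frame(name = feature, Borders = res$.borders, Value = res$.value)
  }, results, names(results))
  out <- do.call(rbind, pieces)
  rownames(out) <- NULL  # drop the "feature.row" names created by rbind
  out
}

# Toy input mimicking the structure used above (made-up numbers)
toy_results <- list(
  crim = data.frame(.borders = c(0, 1, 2), .value = c(22, 21, 20)),
  rm   = data.frame(.borders = c(4, 6), .value = c(18, 25))
)
long <- effects_to_long(toy_results)
```

With iml, `effects_to_long(pdp$results)` and `effects_to_long(ale$results)` would take the place of `df2` and `df3`, and the same ggplot2 code could then plot either one.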

Could you make changes to the source code so that ggplot2 functions such as `facet_wrap()` and `theme()` work for all the panels?
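One possible shape for such a change, as a hypothetical sketch rather than iml's actual code: let the plotting function accept extra ggplot2 components and add them to every per-feature panel before the panels are combined. Here `combine_panels()` is an invented name, `patchwork::wrap_plots()` is assumed as the combining step, and the toy panels come from `mtcars`.

```r
library(ggplot2)
library(patchwork)

# Hypothetical helper (not part of iml): add the same ggplot2 components to
# every per-feature panel first, then combine the panels with patchwork.
combine_panels <- function(plots, ...) {
  extras <- list(...)
  styled <- lapply(plots, function(p) p + extras)  # ggplot2 accepts a list of components
  wrap_plots(styled)
}

# Toy per-feature panels standing in for the ones iml would create
plots <- list(
  ggplot(mtcars, aes(wt, mpg)) + geom_line(),
  ggplot(mtcars, aes(hp, mpg)) + geom_line()
)
combined <- combine_panels(plots, theme_bw(),
                           theme(panel.grid.major = element_blank(),
                                 panel.grid.minor = element_blank()))
```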

bappa10085, May 14 '23 06:05