-1

For my stats class we use R for all of our computations, and we are working with numeric data that also has a categorical factor. We currently plot fitted lines by calling lm(), reading the coefficients off the summary by hand, building a grid of x values, and drawing with lines(). I would like an easier way to do this. I have seen the predict() function, but not how to use it with a categorical variable.

For example, the data set loaded below has two numeric variables and one categorical variable. I want to be able to plot the line of best fit for men and women in this set without having to extract each coefficient individually, as I do in my current code below.

bank<-read.table("http://www.uwyo.edu/crawford/datasets/bank.txt",header=TRUE)

fit <-lm(salary~years*gender,data=bank)
summary(fit)

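# Grid of x values; it should span years (the predictor), not salary
yearhat<-seq(0,max(bank$years),length=1000)
# The first gender level is the baseline; the other line adds the gender and interaction terms
salaryfemalehat=fit$coefficients[1]+fit$coefficients[2]*yearhat
salarymalehat=(fit$coefficients[1]+fit$coefficients[3])+(fit$coefficients[2]+fit$coefficients[4])*yearhat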
Jaap
Brandon Myers
  • When asking for help, you should include a simple [reproducible example](https://stackoverflow.com/questions/5963269/how-to-make-a-great-r-reproducible-example) with sample input and desired output that can be used to test and verify possible solutions. Show the code you are currently using. – MrFlick Feb 26 '18 at 17:38

2 Answers

1

Using what you have, you can get the same predicted values with

yearhat<-seq(0,max(bank$years),length=1000)  # grid over years, the predictor, rather than salary
salaryfemalehat <- predict(fit, data.frame(years=yearhat, gender="Female"))
salarymalehat <- predict(fit, data.frame(years=yearhat, gender="Male"))
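To turn those predictions into a plot, one option is to loop over the factor levels and call lines() once per level. The sketch below is self-contained, so it uses a simulated stand-in for the bank data: `bank_demo`, `fit_demo`, and the made-up salary values are assumptions for illustration only, and the level names "Female" and "Male" are assumed to match the real file.

```r
# Hypothetical stand-in for the bank data: same column names, simulated values.
set.seed(1)
bank_demo <- data.frame(
  years  = runif(60, 0, 30),
  gender = factor(rep(c("Female", "Male"), each = 30))
)
bank_demo$salary <- 30000 + 800 * bank_demo$years +
  ifelse(bank_demo$gender == "Male", 2000 + 300 * bank_demo$years, 0) +
  rnorm(60, sd = 3000)

fit_demo <- lm(salary ~ years * gender, data = bank_demo)

# Grid over the observed range of the predictor (years).
yearhat <- seq(min(bank_demo$years), max(bank_demo$years), length.out = 100)

# Scatter plot coloured by level, then one fitted line per level.
lv <- levels(bank_demo$gender)
plot(bank_demo$years, bank_demo$salary,
     col = as.integer(bank_demo$gender),
     xlab = "years", ylab = "salary")
for (g in lv) {
  lines(yearhat,
        predict(fit_demo, newdata = data.frame(years = yearhat, gender = g)),
        col = which(lv == g))
}
legend("topleft", legend = lv, col = seq_along(lv), lty = 1)
```

With the real data, replacing `bank_demo` with `bank` and `fit_demo` with `fit` should be enough, provided the level names in `levels(bank$gender)` are used as they appear there.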
MrFlick
0

To supplement MrFlick's answer: when the factor has more than two levels, you can loop over the levels and draw one line per level:

dat <- mtcars 
dat$cyl <- as.factor(dat$cyl)
fit <- lm(mpg ~ disp*cyl, data = dat)

plot(dat$disp, dat$mpg)
with(dat,
  for(i in levels(cyl)){
      lines(disp, predict(fit, newdata = data.frame(disp = disp, cyl = i))
            , col = which(levels(cyl) == i))
  }
)
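The loop above draws every line across the full range of disp, including regions where a given cyl level has no observations. A possible variation, sketched here as one option rather than a required change, restricts each line to the disp range observed for that level; the names `lv`, `rng`, and `x` are introduced only for this example.

```r
dat <- mtcars
dat$cyl <- as.factor(dat$cyl)
fit <- lm(mpg ~ disp * cyl, data = dat)

plot(dat$disp, dat$mpg, col = as.integer(dat$cyl))
for (lv in levels(dat$cyl)) {
  # Observed range of disp within this level only, to avoid extrapolating.
  rng <- range(dat$disp[dat$cyl == lv])
  x <- seq(rng[1], rng[2], length.out = 50)
  lines(x,
        predict(fit, newdata = data.frame(disp = x, cyl = lv)),
        col = which(levels(dat$cyl) == lv))
}
```

Passing the level as a character string works here because predict() matches it against the factor levels stored in the fitted model.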

J.R.