Los viejos [R]ockeros. model.matrix

R
python
causal inference
2021
Author

jlcr

Published

September 10, 2021

Nota: He cambiado la parte final para que hiciera lo mismo que el código de python, gracias a mi tocayo José Luis Hidalgo

El otro día por linkedin, mi jefe compartió el siguiente artículo recomendable por otro lado. El repo con el código y datos está aquí.

En el artículo hacen referencia a que una forma de ver el CATE (Conditional Average Treatmen Effect) cuando hay variables categóricas puede ser construirse los términos de interacción de alto orden entre las variables categóricas y calcular la diferencia entre la media de la variable de interés antes del tratamiento y después del tratamiento, en cada una de las variables de interacción consideradas.

Para eso el autor del artículo hace lo siguiente

import pandas as pd
import datetime as dt
import numpy as np
import itertools
import time
from copy import copy
df = pd.read_csv("https://raw.githubusercontent.com/kausa-ai/blog/master/how_causal_inference_lifts_augmented_analytics_beyond_flatland/dataset/ecommerce_sample.csv")

kpi_axis = 'order_value'
time_axis = 'time'

df[time_axis] = pd.to_datetime(df[time_axis],format = '%d/%m/%Y')
df.head()
        time       system  ... customer_age  customer_country
0 2019-09-08       win-pc  ...        21-24            france
1 2019-09-08  android-mob  ...        21-24            poland
2 2019-09-08  android-mob  ...        18-21            france
3 2019-09-08      ios-mob  ...        30-35           germany
4 2019-09-08   android-tv  ...        18-21            poland

[5 rows x 9 columns]
df.columns
Index(['time', 'system', 'product_category', 'order_value', 'household_income',
       'first_order_made', 'gender', 'customer_age', 'customer_country'],
      dtype='object')

Y crea una función para crear las interacciones de orden n. 

def binarize(df,cols,kpi_axis,time_axis,order):
    cols = cols.drop([kpi_axis,time_axis]) 
    features = []
    for k in range(0,order):
        features.append(cols)
    fs = []
    for f in itertools.product(*features):
      #  list(set(f)).sort()
        f = np.unique(f)
        fs.append(tuple(f))
    fs = tuple(set(i for i in fs))
    for f in fs:
        states =[]
        for d in f:
            states.append(tuple(set(df[d].astype('category'))))
        for state in itertools.product(*states):
            z = 1
            name = str()
            for d in range(0,len(f)):
                z = z*df[f[d]]==state[d]
                name +=  f[d] + " == " +str(state[d])
                if d<len(f)-1:
                   name += " AND "
            df[name] = z
    for d in cols:
        df = df.drop([d],axis = 1)
    return df

Y crea las variables, al medir cuánta tarda vemos que es en torno al minuto.

start_time = time.time()
df = binarize(df,df.columns,kpi_axis,time_axis,3)
elapsed_time = time.time() - start_time
print(elapsed_time/60)
1.0172961990038554
df.head(20)
         time  order_value  ...  customer_age == 46+  customer_age == 36-45
0  2019-09-08        52.03  ...                False                  False
1  2019-09-08        30.21  ...                False                  False
2  2019-09-08        55.15  ...                False                  False
3  2019-09-08        50.00  ...                False                  False
4  2019-09-08        71.80  ...                False                  False
5  2019-09-08        60.31  ...                False                  False
6  2019-09-08        51.94  ...                False                  False
7  2019-09-08       144.58  ...                False                  False
8  2019-09-08        47.79  ...                False                  False
9  2019-09-08        36.27  ...                False                  False
10 2019-09-08        57.49  ...                False                   True
11 2019-09-08        65.43  ...                False                  False
12 2019-09-08        34.47  ...                False                   True
13 2019-09-08        35.83  ...                False                  False
14 2019-09-08       122.96  ...                False                   True
15 2019-09-08       108.20  ...                False                  False
16 2019-09-08        57.94  ...                 True                  False
17 2019-09-08        33.60  ...                 True                  False
18 2019-09-08        41.48  ...                False                  False
19 2019-09-08        45.48  ...                False                  False

[20 rows x 2719 columns]

Y aquí es dónde vienen los viejos [R]ockeros. Cada vez que oigo hablar de interacciones pienso en R y en nuestras queridas fórmulas. En R podemos hacer lo mismo tirando de nuestro viejo amigo model.matrix

# puedo pasar de python a R con 
# df <-  py$df 
# o leer el csv igual 

library(tidyverse)
library(lubridate)
library(collapse) # for fast calculation
df <- readr::read_csv("https://raw.githubusercontent.com/kausa-ai/blog/master/how_causal_inference_lifts_augmented_analytics_beyond_flatland/dataset/ecommerce_sample.csv")

Convertimos las variables que nos interesan a tipo factor

df <- df %>% 
  mutate(time = time %>% as.character %>% dmy) %>%  
  mutate(corte_fecha = if_else(time <= '2019-09-11', "antes", "despues" )) %>% 
  mutate_if(is.character,as.factor) 


features <- setdiff(colnames(df),c("time","order_value", "corte_fecha"))
glimpse(df)
Rows: 100,000
Columns: 10
$ time             <date> 2019-09-08, 2019-09-08, 2019-09-08, 2019-09-08, 2019…
$ system           <fct> win-pc, android-mob, android-mob, ios-mob, android-tv…
$ product_category <fct> household, electronics, electronics, electronics, ele…
$ order_value      <dbl> 52.03, 30.21, 55.15, 50.00, 71.80, 60.31, 51.94, 144.…
$ household_income <fct> medium, low, low, low, low, medium, medium, medium, l…
$ first_order_made <fct> no, no, no, yes, no, no, no, no, no, yes, no, no, yes…
$ gender           <fct> male, female, female, n.a., n.a., n.a., n.a., n.a., m…
$ customer_age     <fct> 21-24, 21-24, 18-21, 30-35, 18-21, 30-35, 30-35, 30-3…
$ customer_country <fct> france, poland, france, germany, poland, france, fran…
$ corte_fecha      <fct> antes, antes, antes, antes, antes, antes, antes, ante…

Y al utilizar model matrix R hace por defecto codificación parcial de las variables (One Hot quitando la que sobra para los modernos), así que para tener lo mismo hay que tocar un argumento de model matrix. el truco es definir para cada variable el contrasts = FALSE. Por ejemplo

Por defecto el contrasts para una variable categórica elimina la categoría redundante.

contrasts(df$product_category) 
                    electronics household sports and outdoors
books                         0         0                   0
electronics                   1         0                   0
household                     0         1                   0
sports and outdoors           0         0                   1

Pero podemos decir que no, y así nos construirá tantas variables dicotómicas como categorías tenga nuestra variable.

contrasts(df$product_category, contrasts = FALSE)
                    books electronics household sports and outdoors
books                   1           0         0                   0
electronics             0           1         0                   0
household               0           0         1                   0
sports and outdoors     0           0         0                   1

Ya podemos crear nuestra función binarize

Para crear interacciones de orden n en R basta con definir la fórmula ~ 0 + ( var1 + var2 + var3)^n

binarize <- function(df, columns, order = 3) {
  
  # creo formula  uniendo por + las variables y luego la interación del orden deseado
  features_unidas <- paste(features, collapse = " + ")

  formula_orden <- as.formula(paste0("~ 0 + (  ", features_unidas, ")^ ", order))
  
  # con model.matrix me creo el dataframe con los términos de interacción 
  df_variables <- as_tibble(model.matrix(
    formula_orden,
    df,
    # aqui está la clave 
    contrasts.arg = lapply(df[, features], contrasts, contrasts = FALSE)
  ))

  df_final <- bind_cols(
    df %>%
      select(time, order_value, corte_fecha),
    df_variables
  )


  df_final <- df_final %>%
    select(-time) %>%
    select(corte_fecha, order_value, everything())
}

Y podemos medir cuanto tarda nuestra función sobre el mismo conjunto de datos. Y vemos, que en crear las variables tarda unos pocos segundos.

tictoc::tic()
df_final <- binarize(df, features, 3)
tictoc::toc()
2.482 sec elapsed
head(df_final, 10)
# A tibble: 10 × 2,719
   corte_fecha order_v…¹ syste…² syste…³ syste…⁴ syste…⁵ syste…⁶ produ…⁷ produ…⁸
   <fct>           <dbl>   <dbl>   <dbl>   <dbl>   <dbl>   <dbl>   <dbl>   <dbl>
 1 antes            52.0       0       0       0       0       1       0       0
 2 antes            30.2       1       0       0       0       0       0       1
 3 antes            55.2       1       0       0       0       0       0       1
 4 antes            50         0       0       1       0       0       0       1
 5 antes            71.8       0       1       0       0       0       0       1
 6 antes            60.3       0       0       0       1       0       0       0
 7 antes            51.9       0       0       1       0       0       0       0
 8 antes           145.        1       0       0       0       0       0       1
 9 antes            47.8       0       1       0       0       0       0       0
10 antes            36.3       0       0       0       0       1       1       0
# … with 2,710 more variables: product_categoryhousehold <dbl>,
#   `product_categorysports and outdoors` <dbl>, household_incomehigh <dbl>,
#   household_incomelow <dbl>, household_incomemedium <dbl>,
#   first_order_madeno <dbl>, first_order_madeyes <dbl>, genderfemale <dbl>,
#   gendermale <dbl>, gendern.a. <dbl>, `customer_age18-21` <dbl>,
#   `customer_age21-24` <dbl>, `customer_age25-30` <dbl>,
#   `customer_age30-35` <dbl>, `customer_age36-45` <dbl>, …

Y ya estaría .

CATE

La parte interesante del artículo es la de calcular el CATE como la diferencia de medias de la variable order_value en cada uno de los segmentos antes de una determinada fecha y después.

En el artículo lo hacen así


start_time = time.time()

df_before = df[df[time_axis] <= '2019-09-11']
df_after  = df[df[time_axis] > '2019-09-11']
features = copy(df.drop([time_axis,kpi_axis], axis=1).columns)

K = 10 
subgroups=[]
score=[]
for k in range(0,K):
    CATE = []
    y_before = df_before[kpi_axis]
    y_after= df_after[kpi_axis]
    
    #compute CATEs for all subgroups
    for d in features:
        g = df_before[d] == True
        m_before = np.mean(y_before[g])
        g = df_after[d] == True
        m_after = np.mean(y_after[g])
        CATE.append(m_after-m_before)
    
    #find subgroup with biggest CATE
    index = np.argsort(-abs(np.array(CATE)))
    subgroups.append(features[index[0]])
    score.append(abs( CATE [index[0]]))
    
    #remove found subgroups from dataset
    df_before = df_before[df_before[features[index[0]]] == False]
    df_after = df_after[df_after[features[index[0]]] == False] 
    features = features.drop(features[index[0]])
    

df_nuevo = pd.DataFrame(np.array([score,subgroups]).T, columns=['CATE','features'])

elapsed_time = time.time() - start_time

print(elapsed_time)
39.18984651565552
df_nuevo
                 CATE                                           features
0   289.4008630608073  customer_age == 46+ AND first_order_made == ye...
1   8.979524530417706  customer_age == 30-35 AND customer_country == ...
2   8.690151515151513  customer_age == 36-45 AND customer_country == ...
3   8.567118700265269  customer_age == 36-45 AND customer_country == ...
4   7.811875000000015  customer_age == 30-35 AND customer_country == ...
5   7.510393162393143  customer_age == 36-45 AND customer_country == ...
6    8.40514915254235  customer_age == 36-45 AND customer_country == ...
7   7.597928321678324  customer_age == 36-45 AND customer_country == ...
8  7.4170337760987906  customer_age == 46+ AND customer_country == ca...
9  7.2043861024033475  customer_age == 21-24 AND customer_country == ...

Y tarda su ratillo, pero no está mal

En R lo podemos hacer utilizando nuestro viejo amigo el R base para poner las condiciones

CalcularCate_old <-  function(f, df){
  
  filtro_antes   = df[[f]] == 1 & df$corte_fecha == "antes"
  filtro_despues = df[[f]] == 1 & df$corte_fecha != "antes"
  
  media_antes   = mean(df$order_value[filtro_antes])
  media_despues = mean(df$order_value[filtro_despues])
  
  cate = media_despues - media_antes
  
  return(cate)
  
  
}

# usando fmean de collapse

CalcularCate <-  function(f, df){
  
  filtro_antes   = df[[f]] == 1 & df$corte_fecha == "antes"
  filtro_despues = df[[f]] == 1 & df$corte_fecha != "antes"
  
  media_antes   = fmean(df$order_value[filtro_antes])
  media_despues = fmean(df$order_value[filtro_despues])
  
  cate = media_despues - media_antes
  
  return(cate)
  
  
}
tictoc::tic()
K = 10
cate = c()
tmp <-  df_final

for ( k in 1:K) {
  
  features <- colnames(tmp)[3:ncol(tmp)]
  res <-  unlist(lapply(features, function(x) CalcularCate(x, df = tmp)))
  names(res) <- features
  ordenado <-  sort(abs(res), decreasing = TRUE)[1]
  f <-  names(ordenado)
  cate <- c(cate, ordenado)
  tmp <-  tmp[tmp[[f]]== 0, c("corte_fecha", "order_value", setdiff(features, f))]
}

 
 
tictoc::toc()
40.084 sec elapsed
cate
              systemandroid-tv:first_order_madeyes:customer_age46+ 
                                                        289.400863 
                systemios-mob:customer_age30-35:customer_countryuk 
                                                          8.979525 
     household_incomelow:customer_age36-45:customer_countrygermany 
                                                          8.690152 
            systemwin-pc:customer_age36-45:customer_countrygermany 
                                                          8.567119 
     household_incomehigh:customer_age30-35:customer_countrycanada 
                                                          7.811875 
   product_categorybooks:customer_age36-45:customer_countrygermany 
                                                          7.510393 
        systemandroid-tv:customer_age36-45:customer_countrygermany 
                                                          8.405149 
            genderfemale:customer_age36-45:customer_countrygermany 
                                                          7.597928 
product_categoryelectronics:customer_age46+:customer_countrycanada 
                                                          7.417034 
                 systemios-pc:customer_age21-24:customer_countryuk 
                                                          7.204386 

Y bueno, parece que en este caso, los viejos [R]ockeros no lo hacen mal del todo, sobre todo la parte de model.matrix es muy rápida, y usando collapse para calcular la media es aún más rápido

En resumen, model.matrix de rbase es muy rápido, y usar fmean de collapse en vez del mean de R-base mejor, con lo que con esta implementación en R es mucho más rápida que la vista en python (que seguramente se puede mejorar hasta igualar)