When we perform a `groupby` across multiple columns, we often want to change how our data is stored. For instance, recall the example where we are running a chain of stores and have data about the number of sales at different locations on different days:
Location | Date | Day of Week | Total Sales |
---|---|---|---|
West Village | February 1 | W | 400 |
West Village | February 2 | Th | 450 |
Chelsea | February 1 | W | 375 |
Chelsea | February 2 | Th | 390 |
Suppose we ran a `groupby` across two different columns (`Location` and `Day of Week`) to find the average sales for each combination. This gave us results that looked like this:

Location | Day of Week | Total Sales |
---|---|---|
Chelsea | M | 300 |
Chelsea | Tu | 310 |
Chelsea | W | 320 |
Chelsea | Th | 290 |
… | … | … |
West Village | Th | 400 |
West Village | F | 390 |
West Village | Sa | 250 |
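… | … | … |

This is correct, but it is hard to compare days at a glance. It would be easier to read if we reorganized the table so that each location is a row and each day of the week is a column:

Location | M | Tu | W | Th | F | Sa | Su |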
---|---|---|---|---|---|---|---|
Chelsea | 400 | 390 | 250 | 275 | 300 | 150 | 175 |
West Village | 300 | 310 | 350 | 400 | 390 | 250 | 200 |
… | … | … | … | … | … | … | … |
Reorganizing a table in this way is called pivoting. The new table is called a pivot table.
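In pandas, the general form of the `pivot` method is:

```python
df.pivot(columns='ColumnToPivot', index='ColumnToBeRows', values='ColumnToBeValues')
```

Here `columns` names the column whose distinct values become the new column headers, `index` names the column whose distinct values become the rows, and `values` names the column that fills in the cells.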
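To see how the three parameters map onto the output, here is a small sketch that pivots the long-format rows listed in the grouped table above. Only the rows that are actually shown are used (the elided `…` rows are left out), so some cells come out as `NaN`. The name `long_df` is hypothetical.

```python
import pandas as pd

# Long format: one row per (Location, Day of Week) pair, copied from the
# rows listed in the grouped table above (elided rows omitted).
long_df = pd.DataFrame({
    'Location': ['Chelsea', 'Chelsea', 'Chelsea', 'Chelsea',
                 'West Village', 'West Village', 'West Village'],
    'Day of Week': ['M', 'Tu', 'W', 'Th', 'Th', 'F', 'Sa'],
    'Total Sales': [300, 310, 320, 290, 400, 390, 250],
})

# columns: each distinct 'Day of Week' value becomes a new column
# index:   each distinct 'Location' value becomes a row
# values:  'Total Sales' fills the cells; missing pairs become NaN
wide_df = long_df.pivot(columns='Day of Week', index='Location', values='Total Sales')
print(wide_df)
```

Note that `pivot` sorts the new column labels, so the days appear alphabetically (`F`, `M`, `Sa`, …) rather than in calendar order.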
For our specific example, we would write the command like this:
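```python
# First, use groupby to get the average sales for each location/day pair:
unpivoted = df.groupby(['Location', 'Day of Week'])['Total Sales'].mean().reset_index()

# Now pivot the table, then turn 'Location' back into a regular column:
pivoted = unpivoted.pivot(
    columns='Day of Week',
    index='Location',
    values='Total Sales').reset_index()
```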
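Just like with `groupby`, the output of `pivot` is a new DataFrame, but its indexing tends to be awkward: the values of the `index` column (here, `Location`) become the row labels instead of an ordinary column. That is why we usually follow up with `.reset_index()`, which moves those labels back into a regular column and restores a default 0, 1, 2, … row index.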
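As a concrete check, the sketch below runs the full groupby-then-pivot pipeline on the four daily rows from the first table in this section and prints the index and column-axis names before calling `.reset_index()`. It is a minimal illustration with only two days of data, so the pivot has just two day columns.

```python
import pandas as pd

# The four daily rows from the first table in this section.
df = pd.DataFrame({
    'Location': ['West Village', 'West Village', 'Chelsea', 'Chelsea'],
    'Date': ['February 1', 'February 2', 'February 1', 'February 2'],
    'Day of Week': ['W', 'Th', 'W', 'Th'],
    'Total Sales': [400, 450, 375, 390],
})

# Step 1: average sales per (Location, Day of Week) pair, as flat columns
unpivoted = df.groupby(['Location', 'Day of Week'])['Total Sales'].mean().reset_index()

# Step 2: pivot; 'Location' becomes the row index
pivoted = unpivoted.pivot(columns='Day of Week', index='Location', values='Total Sales')
print(pivoted.index.name)    # Location
print(pivoted.columns.name)  # Day of Week

# Step 3: move 'Location' back into an ordinary column with a 0..n-1 index
pivoted = pivoted.reset_index()
print(pivoted)
```

After `.reset_index()`, the column axis still carries the name `Day of Week`, which shows up in the printed header. If that is distracting, setting `pivoted.columns.name = None` removes it.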
For more on the pivot function, see the pandas documentation.
Instructions
In the previous example, you created a DataFrame with the total number of shoes of each `shoe_type`/`shoe_color` combination purchased for ShoeFly.com.
The purchasing manager complains that this DataFrame is confusing.
Make it easier for her to compare purchases of different shoe colors of the same shoe type by creating a pivot table. Save your results to the variable `shoe_counts_pivot`.
Your table should look like this:
shoe_type | black | brown | navy | red | white |
---|---|---|---|---|---|
ballet flats | … | … | … | … | … |
sandals | … | … | … | … | … |
stilettos | … | … | … | … | … |
wedges | … | … | … | … | … |
Remember to use `reset_index()` at the end of your code!

Display `shoe_counts_pivot` using `print`.
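A minimal sketch of one way to approach this, assuming that the previous step (not shown here) built `shoe_counts` by grouping a hypothetical `orders` DataFrame with one row per purchase and the columns `id`, `shoe_type`, and `shoe_color`. The handful of orders below are invented purely so the snippet runs; they are not ShoeFly.com data.

```python
import pandas as pd

# Hypothetical stand-in for the orders data: one row per purchase.
orders = pd.DataFrame({
    'id': [1, 2, 3, 4, 5, 6, 7],
    'shoe_type': ['sandals', 'sandals', 'wedges', 'stilettos',
                  'ballet flats', 'wedges', 'ballet flats'],
    'shoe_color': ['black', 'navy', 'black', 'red',
                   'white', 'black', 'brown'],
})

# Assumed result of the previous step: purchase count per shoe_type/shoe_color pair
shoe_counts = orders.groupby(['shoe_type', 'shoe_color'])['id'].count().reset_index()

# One row per shoe_type, one column per shoe_color, cells hold the counts
shoe_counts_pivot = shoe_counts.pivot(
    columns='shoe_color',
    index='shoe_type',
    values='id',
).reset_index()

print(shoe_counts_pivot)
```

Combinations with no purchases appear as `NaN`; if zeros are easier for the purchasing manager to read, `shoe_counts_pivot.fillna(0)` would replace them.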