0


spark之时间序列预测(商品销量预测)

项目地址见:https://github.com/jiangnanboy/spark_data_mining/tree/master/src/main/java/com/sy/dataalgorithms/advanced/time_series

一.概要

此项目将围绕一个时间序列预测任务展开。该任务是Kaggle上的一个比赛,M5 Forecasting - Accuracy(https://www.kaggle.com/c/m5-forecasting-accuracy/notebooks )。M5的赛题目标是预测沃尔玛各种商品在未来28天的销量。本案例使用前1913天的数据作为训练数据,来预测第1914天到第1941天的销量。并且,我们只对最细粒度的30490条序列进行预测。训练数据从kaggle自行下载:

  • calendar.csv - Contains information about the dates on which the products are sold.
  • sales_train_validation.csv - Contains the historical daily unit sales data per product and store [d_1 - d_1913]
  • sample_submission.csv - The correct format for submissions. Reference the Evaluation tab for more info.
  • sell_prices.csv - Contains information about the price of the products sold per store and date.
  • sales_train_evaluation.csv - Includes sales [d_1 - d_1941] (labels used for the Public leaderboard)

以上数据下载后放入resources/advanced下,并在properties.properties中配置一下文件名和路径,以供程序读取和处理数据。

1.数据处理以及特征工程利用java spark进行提取,见TimeSeries.java。

2.模型的训练及预测利用python lightgbm进行操作,见time_series.ipynb,data.7z下是spark处理好的数据。

二.特征工程代码解读,完整见项目中代码注释

  1. /**
  2. * 分析和挖掘数据
  3. * @param session
  4. */
  5. public static void analysisData(SparkSession session) {
  6. // 一.数据集
  7. /* 1.这里是历史销量sales_train_validation数据
  8. +--------------------+-------------+---------+-------+--------+--------+---+---+---+---+---+---+---+---+-
  9. | id| item_id| dept_id| cat_id|store_id|state_id|d_1|d_2|d_3|d_4|d_5|d_6|d_7|d_8|d_9|d_10|...
  10. +--------------------+-------------+---------+-------+--------+--------+---+---+---+---+---+---+---+---+---+----+
  11. |HOBBIES_1_001_CA_...|HOBBIES_1_001|HOBBIES_1|HOBBIES| CA_1| CA| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0|...
  12. |HOBBIES_1_002_CA_...|HOBBIES_1_002|HOBBIES_1|HOBBIES| CA_1| CA| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0|...
  13. |HOBBIES_1_003_CA_...|HOBBIES_1_003|HOBBIES_1|HOBBIES| CA_1| CA| 0| 0| 0| 0| 0| 0| 0| 0| 0| 0|...
  14. +--------------------+-------------+---------+-------+--------+--------+---+---+---+---+---+---+---+---+---+----+
  15. schema:
  16. |-- id: string (nullable = true)
  17. |-- item_id: string (nullable = true)
  18. |-- dept_id: string (nullable = true)
  19. |-- cat_id: string (nullable = true)
  20. |-- store_id: string (nullable = true)
  21. |-- state_id: string (nullable = true)
  22. |-- d_1: integer (nullable = true)
  23. |-- d_2: integer (nullable = true)
  24. |-- d_3: integer (nullable = true)
  25. |-- d_4: integer (nullable = true)
  26. |-- ......
  27. */
  28. String salesTrainValidationPath = TimeSeries.class.getClassLoader().getResource(PropertiesReader.get("advanced_timeseries_sales_train_validation_csv")).getPath().replaceFirst("/", "");
  29. Dataset<Row> salesTVDataset = session.read()
  30. .option("sep", ",")
  31. .option("header", true)
  32. .option("inferSchema", true)
  33. .csv(salesTrainValidationPath);
  34. /*首先,我们只留下salesTVDataset中的历史特征值,删去其他列。
  35. +---+---+---+---+---+---+---+---+---+----+
  36. |d_1|d_2|d_3|d_4|d_5|d_6|d_7|d_8|d_9|d_10|
  37. +---+---+---+---+---+---+---+---+---+----+
  38. | 0| 0| 0| 0| 0| 0| 0| 0| 0| 0|...
  39. | 0| 0| 0| 0| 0| 0| 0| 0| 0| 0|...
  40. | 0| 0| 0| 0| 0| 0| 0| 0| 0| 0|...
  41. +---+---+---+---+---+---+---+---+---+----+
  42. */
  43. Column[] columns = new Column[1913];
  44. int index = 0;
  45. for(String column : salesTVDataset.columns()) {
  46. if(column.contains("d_")) {
  47. columns[index] = functions.col(column);
  48. index++;
  49. }
  50. }
  51. Dataset<Row> xDataset = salesTVDataset.select(columns);
  52. /* 2.这里是日历信息calendar数据
  53. +----------+--------+--------+----+-----+----+---+------------+------------+------------+------------+-------+-------+-------+
  54. | date|wm_yr_wk| weekday|wday|month|year| d|event_name_1|event_type_1|event_name_2|event_type_2|snap_CA|snap_TX|snap_WI|
  55. +----------+--------+--------+----+-----+----+---+------------+------------+------------+------------+-------+-------+-------+
  56. |2011-01-29| 11101|Saturday| 1| 1|2011|d_1| null| null| null| null| 0| 0| 0|
  57. |2011-01-30| 11101| Sunday| 2| 1|2011|d_2| null| null| null| null| 0| 0| 0|
  58. |2011-01-31| 11101| Monday| 3| 1|2011|d_3| null| null| null| null| 0| 0| 0|
  59. +----------+--------+--------+----+-----+----+---+------------+------------+------------+------------+-------+-------+-------+
  60. schema:
  61. |-- date: string (nullable = true)
  62. |-- wm_yr_wk: integer (nullable = true)
  63. |-- weekday: string (nullable = true)
  64. |-- wday: integer (nullable = true)
  65. |-- month: integer (nullable = true)
  66. |-- year: integer (nullable = true)
  67. |-- d: string (nullable = true)
  68. |-- event_name_1: string (nullable = true)
  69. |-- event_type_1: string (nullable = true)
  70. |-- event_name_2: string (nullable = true)
  71. |-- event_type_2: string (nullable = true)
  72. |-- snap_CA: integer (nullable = true)
  73. |-- snap_TX: integer (nullable = true)
  74. |-- snap_WI: integer (nullable = true)
  75. */
  76. String calendarPath = TimeSeries.class.getClassLoader().getResource(PropertiesReader.get("advanced_timeseries_calendar_csv")).getPath().replaceFirst("/", "");
  77. Dataset<Row> calendarDataset = session.read()
  78. .option("sep", ",")
  79. .option("header", true)
  80. .option("inferSchema", true)
  81. .csv(calendarPath);
  82. /* 3.商品每周的价格信息sell_prices
  83. +--------+-------------+--------+----------+
  84. |store_id| item_id|wm_yr_wk|sell_price|
  85. +--------+-------------+--------+----------+
  86. | CA_1|HOBBIES_1_001| 11325| 9.58|
  87. | CA_1|HOBBIES_1_001| 11326| 9.58|
  88. | CA_1|HOBBIES_1_001| 11327| 8.26|
  89. +--------+-------------+--------+----------+
  90. schema:
  91. |-- store_id: string (nullable = true)
  92. |-- item_id: string (nullable = true)
  93. |-- wm_yr_wk: integer (nullable = true)
  94. |-- sell_price: double (nullable = true)
  95. */
  96. // String sellPricesPath = TimeSeries.class.getClassLoader().getResource(PropertiesReader.get("advanced_timeseries_sell_prices_csv")).getPath().replaceFirst("/", "");
  97. // Dataset<Row> sellPricesDataset = session.read()
  98. // .option("sep", ",")
  99. // .option("header", true)
  100. // .option("inferSchema", true)
  101. // .csv(sellPricesPath);
  102. // (1).测试集,我们只是计算了第1914天的数据的特征。这只些特征只能用来预测1914天的销量,也就是说,实际上是我们的测试数据。
  103. int targetDay = 1914;
  104. // 使用历史数据中最后的7天构造特征
  105. int localRange = 7;
  106. // 由于使用前1913天的数据预测第1914天,历史数据与预测目标的距离只有1天,因此predictDistance=1
  107. // 如果使用前1913天的数据预测第1915天,则历史数据与预测目标的距离有2天,因此predictDistance=2,以此类推
  108. int predictDistance = 1;
  109. Dataset<Row> testDataset = getTestDataset(salesTVDataset, calendarDataset, xDataset, targetDay, predictDistance);
  110. // (2).训练集,为了构造训练数据,我们对1914天之前的日期进行同样的特征计算操作,并附上它们的当天销量作为数据标签。
  111. int trainingDataDays = 7; // 为了简便,现只取7天的数据作训练集
  112. Dataset<Row> trainDataset = getTrainDataset(salesTVDataset, calendarDataset, xDataset, trainingDataDays, targetDay, predictDistance);
  113. String salesTrainEvaluationPath = TimeSeries.class.getClassLoader().getResource(PropertiesReader.get("advanced_timeseries__sales_train_evaluation_csv")).getPath().replaceFirst("/", "");
  114. Dataset<Row> labelDataset = session.read()
  115. .option("sep", ",")
  116. .option("header", true)
  117. .option("inferSchema", true)
  118. .csv(salesTrainEvaluationPath);
  119. // (3).测试集的label
  120. Dataset<Row> testLabelDataset = getTestDatasetLabel(labelDataset, targetDay);
  121. // (4).训练集的label
  122. Dataset<Row> trainLabelDataset = getTrainDatasetLabel(labelDataset, targetDay, trainingDataDays, predictDistance);
  123. // (5).保存为csv文件,供python lightgbm训练
  124. // 保存test dataset
  125. String testDatasetCsvPath = "E:\\idea_project\\spark_data_mining\\src\\main\\resources\\dataalgorithms\\advanced\\timeseries_data\\testdata.csv";
  126. saveCsv(testDataset, testDataset.columns(), testDatasetCsvPath);
  127. // 保存train dataset
  128. String trainDatasetCsvPath = "E:\\idea_project\\spark_data_mining\\src\\main\\resources\\dataalgorithms\\advanced\\timeseries_data\\traindata.csv";
  129. saveCsv(trainDataset, trainDataset.columns(), trainDatasetCsvPath);
  130. // 保存test label
  131. String testLabelCsvPath = "E:\\idea_project\\spark_data_mining\\src\\main\\resources\\dataalgorithms\\advanced\\timeseries_data\\testlabel.csv";
  132. saveCsv(testLabelDataset, testLabelDataset.columns(), testLabelCsvPath);
  133. // 保存train label
  134. String trainLabelCsvPath = "E:\\idea_project\\spark_data_mining\\src\\main\\resources\\dataalgorithms\\advanced\\timeseries_data\\trainlabel.csv";
  135. saveCsv(trainLabelDataset, trainLabelDataset.columns(), trainLabelCsvPath);
  136. }

三.模型训练


本文转载自: https://blog.csdn.net/qq_20182781/article/details/140679024
版权归原作者 石头木V2 所有, 如有侵权,请联系我们删除。

“spark之时间序列预测(商品销量预测)”的评论:

还没有评论