PyTorch Tabular 是一个用于构建和训练深度学习模型以解决各种表格数据问题的库。这个库专为表格数据设计,通过提供灵活的、易于使用的API来简化模型的构建、训练和推理过程。PyTorch Tabular 基于 PyTorch,利用了 PyTorch 的动态计算图和强大的GPU加速能力。
主要特性
- 多种模型支持:- PyTorch Tabular 提供了多种现成的模型架构,如深度神经网络(DNN)、多层感知机(MLP)、条件变换器网络(CTN)等,以及对于时间序列和分类问题的特定解决方案。
- 模块化和可扩展:- 用户可以轻松自定义模型和数据管道。库设计了模块化,使得用户可以插入自定义的组件或改变处理流程,以适应复杂的数据科学任务。
- 高级优化功能:- 包括自动特征工程、超参数调优和模型蒸馏等功能,旨在提高模型的效率和性能。
- 简化的数据处理:- 内置多种数据预处理和增强技术,如缺失值处理、特征标准化和编码,简化了从原始数据到模型训练的过程。
PyTorch Tabular 适用于任何需要利用表格数据进行预测的场景,如金融风险评估、销售预测、客户细分等。其强大的功能和灵活性使其成为数据科学家和机器学习工程师在处理表格数据时的优选工具。
开始使用
要开始使用 PyTorch Tabular,可以通过 pip 安装:
pip install pytorch-tabular
安装后,创建一个模型配置,并使用DataFrame训练和测试模型。下面是一个简单的例子:
from pytorch_tabular import TabularModel
from pytorch_tabular.config import DataConfig, ModelConfig, OptimizerConfig, TrainerConfig
data_config = DataConfig(
target=['target_column'],
continuous_cols=['feature1', 'feature2'],
categorical_cols=['category1', 'category2']
)
model_config = ModelConfig(task="regression")
trainer_config = TrainerConfig(max_epochs=10)
tabular_model = TabularModel(data_config, model_config, trainer_config)
tabular_model.fit(train_df, validation_df)
predictions = tabular_model.predict(test_df)
这个例子展示了如何配置和训练一个基本的回归模型。通过更改
ModelConfig
中的参数,你可以轻松调整模型以适应不同的任务和数据集。
可用模型
PyTorch Tabular包含的模型还是很多的,我这里做一个简单的总结
- 带类别嵌入的前馈网络:这是一个简单的前馈网络,但为分类列添加了嵌入层。
- Neural Oblivious Decision Ensembles for Deep Learning on Tabular Data:这是在 2020 年 ICLR 提出的模型,根据作者的说法,它在许多数据集上击败了经过良好调整的梯度提升模型。
- TabNet:这是谷歌研究推出的另一种模型,它在决策过程的多个步骤中使用稀疏注意力来建模输出。
- Mixture Density Networks :这是一个回归模型,使用高斯组件近似目标函数,并提供开箱即用的概率预测。
- AutoInt:通过自注意力神经网络自动学习特征交互的模型,试图以自动化的方式学习特征之间的交互,创建更好的表示,然后在下游任务中使用这种表示。
- TabTransformer:是为表格数据定制的Transformer模型,为分类特征创建上下文表示。
- FT Transformer:表格数据的深度学习模型。
- Gated Additive Tree Ensemble :是一种新颖的高性能、参数和计算效率的深度学习架构,用于表格数据。GATE 使用灵感来自 GRU 的门控机制作为具有内置特征选择机制的特征表示学习单元。将其与一组可微分的、非线性的决策树结合,通过简单的自注意力重新加权,以预测所需的输出。
- Gated Adaptive Network for Deep Automated Learning of Features (GANDALF):是 GATE 的简化版本,比 GATE 更高效和性能更佳。GANDALF 使 GFLUs 成为主要学习单元,同时在过程中引入了一些加速。由于需要调整的超参数非常少,这使得该模型易于使用和调整。
- DANETs:用于表格数据分类和回归的深度抽象网络:是一种新颖且灵活的神经组件,称为抽象层(AbstLay),它学习明确分组相关输入特征并生成更高级别的语义抽象特征。使用 AbstLays 构建一个特殊的基本块,并通过堆叠此类块,构建用于表格数据分类和回归的一系列深度抽象网络(DANets)。
总结
如果你想尝试从 CatBoost 或 LightGBM 这样的传统框架适配模型,可以尝试使用 PyTorch Tabular,它不仅简化了表格处理的过程,还内建了很多深度学习模型。如果你正在寻找一种方法来提升你的表格数据处理和模型性能,PyTorch Tabular 提供了一个强大且灵活的平台,以支持从简单到复杂的各种机器学习需求。