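DJL (Deep Java Library) is an open-source deep learning framework for Java released by AWS. It supports several deep learning backends, including but not limited to:

- MXNet: an open-source deep learning framework supported by the Apache Software Foundation.
- PyTorch: a widely used open-source machine learning library developed by Facebook's AI research team.
- TensorFlow: another popular open-source machine learning framework, developed by Google.

DJL integrates closely with the Java ecosystem and can work alongside Java frameworks such as Spring Boot and Quarkus.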
Maven dependencies:

```xml
<!-- djl -->
<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>api</artifactId>
    <version>0.28.0</version>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-engine</artifactId>
    <version>0.28.0</version>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-model-zoo</artifactId>
    <version>0.28.0</version>
</dependency>
<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>basicdataset</artifactId>
    <version>0.28.0</version>
</dependency>
<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>model-zoo</artifactId>
    <version>0.28.0</version>
</dependency>
<!-- /djl -->
```
DJL architecture diagram

```text
┌──────────────────────────────┐
│           ModelZoo           │
├──────────────────────────────┤
│            Model             │
└───────────────┬──────────────┘
                │
      ┌─────────▼─────────┐
      │      Engine       │
      └─────────┬─────────┘
                │
      ┌─────────▼─────────┐
      │     NDManager     │
      └─────────┬─────────┘
                │
    ┌───────────▼───────────┐
    │        Dataset        │
    └───────────┬───────────┘
                │
    ┌───────────▼───────────┐
    │  Trainer / Predictor  │
    └───────────────────────┘
```
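Main components in detail

1. ModelZoo and Model

- ModelZoo: provides a variety of pre-trained models.

###### What ModelZoo does

1. Model discovery and download:
   - ModelZoo provides a mechanism for discovering and downloading pre-trained models from several sources (for example model providers and online repositories).
   - For example, models can be downloaded from platforms such as AWS S3, Hugging Face, and TensorFlow Hub.
2. Model loading:
   - ModelZoo provides convenient methods for loading models. Users can load different kinds of models as needed (for example image classification, object detection, and natural language processing models).
   - When loading a model, you can specify its name, version, and parameter configuration.
3. Model management:
   - ModelZoo helps users manage downloaded and loaded models, so they can be viewed, updated, and deleted.
   - This keeps local model resources under control and avoids repeated downloads and wasted storage.

#### Example

```kotlin
import ai.djl.Application
import ai.djl.ModelException
import ai.djl.modality.Classifications
import ai.djl.modality.cv.Image
import ai.djl.repository.zoo.Criteria
import ai.djl.repository.zoo.ModelZoo
import ai.djl.translate.TranslateException
import java.io.IOException

object ModelZooExample {
    @Throws(ModelException::class, TranslateException::class, IOException::class)
    @JvmStatic
    fun main(args: Array<String>) {
        // Describe the model we want
        val criteria: Criteria<Image, Classifications> = Criteria.builder()
            .optApplication(Application.CV.IMAGE_CLASSIFICATION) // use case: image classification
            .setTypes(Image::class.java, Classifications::class.java) // input and output types
            .optFilter("backbone", "resnet50") // model filter
            .build()

        // Load the model from the ModelZoo; `use` closes it when the block ends
        ModelZoo.loadModel(criteria).use { model ->
            // Run inference with the model
            // ...
        }
    }
}
```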
#### ModelZoo classes and interfaces

- **ModelZoo**: the core class; provides model download and loading.
- **Criteria**: defines the criteria and filters used to load a model, such as the application and the input and output types.
- **ModelLoader**: performs the actual download and loading of a model.
- **Model**: the interface that represents a deep learning model, with operations for loading, saving, and running it.

Note: `ModelInfo`, `ModelId`, `getModel`, and `listModels(ZooModel)` are described here conceptually and may not exist under these names in DJL 0.28.0; check the Javadoc for your version.

#### ai.djl.repository.zoo.ModelZoo

##### Key Methods:

- `static <I, O> ZooModel<I, O> loadModel(Criteria<I, O> criteria)`: Loads a model based on the provided criteria.
- `ModelInfo getModel(ModelId modelId)`: Retrieves information about a specific model using its `ModelId`.
- `Set<ModelId> listModels(ZooModel<?, ?> model)`: Lists all models in the zoo that match the given model.

#### ai.djl.ModelInfo Interface

`ModelInfo` provides metadata about a model, including its name, description, and input/output information.

##### Key Methods:

- `String getName()`: Returns the name of the model.
- `String getDescription()`: Provides a description of the model.
- `Shape getInputShape()`: Returns the shape of the input tensor.
- `Shape getOutputShape()`: Returns the shape of the output tensor.

#### ai.djl.ModelId Class

`ModelId` uniquely identifies a model in the model zoo. It includes information about the model's group, name, and version.

##### Key Fields:

- `String getGroup()`: Gets the group name of the model.
- `String getName()`: Gets the name of the model.
- `String getVersion()`: Gets the version of the model.

#### ai.djl.Application Class

`Application` identifies the types of applications supported by the model zoo, such as IMAGE_CLASSIFICATION and OBJECT_DETECTION. Its constants are grouped by domain.

##### Key Values:

- `CV.IMAGE_CLASSIFICATION`
- `CV.OBJECT_DETECTION`
- `NLP.TEXT_CLASSIFICATION`

#### ai.djl.repository.zoo.Criteria Class

`Criteria` describes the model to load. Instances are created through its builder, which is used to filter and load models.

##### Key Methods:

- `static Builder<?, ?> builder()`: Creates a new builder instance.
- `Builder<I, O> optApplication(Application application)`: Sets the application type (a builder method).
- `Builder<I, O> optEngine(String engine)`: Specifies the engine to use, e.g. "MXNet" or "PyTorch" (a builder method).

###### Example

```kotlin
import ai.djl.Model
import ai.djl.ModelException
import ai.djl.modality.Classifications
import ai.djl.modality.cv.Image
import ai.djl.modality.cv.ImageFactory
import ai.djl.modality.cv.util.NDImageUtils
import ai.djl.ndarray.NDList
import ai.djl.translate.TranslateException
import ai.djl.translate.Translator
import ai.djl.translate.TranslatorContext
import java.io.IOException
import java.nio.file.Paths

object DjlExample {
    @JvmStatic
    fun main(args: Array<String>) {
        // Model location
        val modelDir = Paths.get("models")
        val modelName = "resnet18"
        try {
            Model.newInstance(modelName).use { model ->
                // Load the model
                model.load(modelDir)
                // Load the input image
                val img = ImageFactory.getInstance().fromFile(Paths.get("path/to/image.jpg"))
                // Create the predictor; `use` closes it when the block ends
                model.newPredictor(MyTranslator()).use { predictor ->
                    // Run inference
                    val result = predictor.predict(img)
                    println(result)
                }
            }
        } catch (e: IOException) {
            e.printStackTrace()
        } catch (e: ModelException) {
            e.printStackTrace()
        } catch (e: TranslateException) {
            e.printStackTrace()
        }
    }

    // Custom Translator
    private class MyTranslator : Translator<Image, Classifications> {
        override fun processInput(ctx: TranslatorContext, input: Image): NDList {
            // Convert the HWC uint8 image to a CHW float32 tensor scaled to [0, 1].
            // Real models usually also need a resize and normalization step here.
            val array = NDImageUtils.toTensor(input.toNDArray(ctx.ndManager))
            return NDList(array)
        }

        override fun processOutput(ctx: TranslatorContext, list: NDList): Classifications {
            // The default batchifier removes the batch axis, so class scores are on axis 0
            val probabilities = list.singletonOrThrow().softmax(0)
            // Placeholder labels: one entry per output class
            val labels: List<String> = List(probabilities.size().toInt()) { "name$it" }
            return Classifications(labels, probabilities)
        }
    }
}
```
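The `processOutput` step in the example above converts raw model scores into class probabilities with softmax. As a rough illustration of what that step computes, here is a minimal plain-Java sketch. It does not use DJL; the class name `SoftmaxSketch`, the `topK` helper, and the labels are all hypothetical.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

/** Illustrative only: plain-Java softmax and top-k selection, not DJL code. */
final class SoftmaxSketch {

    /** Converts raw scores (logits) into probabilities that sum to 1. */
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            max = Math.max(max, v);
        }
        double sum = 0.0;
        double[] out = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            // Subtracting the maximum keeps exp() from overflowing.
            out[i] = Math.exp(logits[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    /** Returns the indices of the k largest probabilities, best first. */
    static List<Integer> topK(double[] probs, int k) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < probs.length; i++) {
            indices.add(i);
        }
        indices.sort(Comparator.comparingDouble((Integer i) -> -probs[i]));
        return indices.subList(0, Math.min(k, indices.size()));
    }

    public static void main(String[] args) {
        String[] labels = {"cat", "dog", "bird"}; // hypothetical labels
        double[] probs = softmax(new double[] {2.0, 1.0, 0.1});
        for (int i : topK(probs, 2)) {
            System.out.printf("%s: %.3f%n", labels[i], probs[i]);
        }
    }
}
```

In DJL the same work is done on an `NDArray` and wrapped in a `Classifications` object; the sketch only shows the arithmetic behind it.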
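2. Dataset

- Common dataset types:

1. RandomAccessDataset:
   - `RandomAccessDataset` is the basic abstract class for datasets whose items can be accessed at random, such as arrays or lists.
   - It supports batching, slicing, and similar operations, and suits most supervised learning tasks.
2. IterableDataset:
   - `IterableDataset` suits data that cannot be accessed at random, such as streams or data generated in real time.
   - It supplies data through an iterator, which fits sources that must be generated or processed on the fly.
3. RecordDataset:
   - `RecordDataset` is a dataset format based on record files, commonly used for large-scale data processing.
   - It loads and processes records efficiently, which suits distributed training and large datasets.

DJL's dataset components provide the following:

1. Data loading and preprocessing:
   - Loading data from several kinds of sources, such as local files, remote servers, and databases.
   - Preprocessing such as normalization, data augmentation, and feature extraction.
2. Batching:
   - Splitting data into small batches, which suits training on large datasets.
   - Flexible batching strategies that can be customized as needed.
3. Transformations:
   - A range of data transformations, such as image transforms, text processing, and numeric processing.
   - Chained calls that combine several transformations into a data processing pipeline.
4. DataLoader:
   - `DataLoader` packs the dataset into batches and supplies them on demand during training.
   - Multi-threaded loading improves data processing throughput.

- Dataset: the interface that defines a dataset. Users can implement it, or extend `RandomAccessDataset`, to build a custom dataset.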
```java
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.nio.file.Paths;

public class DjlExample {
    public static void main(String[] args) throws IOException, ModelException, TranslateException {
        // Describe the model to load
        Criteria<Image, Classifications> criteria = Criteria.builder()
                .optEngine("TensorFlow") // choose the engine; its dependency must be on the classpath
                .setTypes(Image.class, Classifications.class)
                .optModelPath(Paths.get("path/to/model"))
                .build();

        // ZooModel (rather than plain Model) offers the no-argument newPredictor()
        try (ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);
                Predictor<Image, Classifications> predictor = model.newPredictor()) {
            // Load the image
            Image img = ImageFactory.getInstance().fromFile(Paths.get("path/to/image.jpg"));
            // Run inference
            Classifications result = predictor.predict(img);
            System.out.println(result);
        }
    }
}
```
```java
import ai.djl.Model;
import ai.djl.basicdataset.cv.classification.FashionMnist;
import ai.djl.basicmodelzoo.basic.Mlp;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.tracker.Tracker;
import ai.djl.translate.TranslateException;
import java.io.IOException;

public class DJLDatasetExample {
    public static void main(String[] args) throws IOException, TranslateException {
        FashionMnist fashionMnist = FashionMnist.builder()
                .optUsage(Dataset.Usage.TRAIN)
                .setSampling(32, true) // batch size 32, random sampling
                .optLimit(Long.MAX_VALUE) // lower this to limit the number of samples
                .build();
        fashionMnist.prepare();

        TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                .optOptimizer(Optimizer.sgd().setLearningRateTracker(Tracker.fixed(0.1f)).build())
                .addTrainingListeners(TrainingListener.Defaults.logging());

        try (Model model = Model.newInstance("fashion-mnist-model")) {
            // A model needs a block (the network) before it can create a trainer
            model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));

            try (Trainer trainer = model.newTrainer(config)) {
                trainer.initialize(new Shape(1, 28 * 28)); // one flattened 28x28 image
                Metrics metrics = new Metrics();
                trainer.setMetrics(metrics);

                for (Batch batch : trainer.iterateDataset(fashionMnist)) {
                    EasyTrain.trainBatch(trainer, batch); // forward and backward pass
                    trainer.step(); // update the parameters
                    batch.close();
                }
                trainer.notifyListeners(listener -> listener.onTrainingEnd(trainer));
            }
        }
    }
}
```
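The batching described in this section (a fixed batch size with optional shuffling, as in `setSampling(32, true)`) can be pictured as slicing a list of sample indices. The following plain-Java sketch shows that idea only. It is not how DJL implements its samplers, and the `BatchingSketch` class and its `batches` method are hypothetical names.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Illustrative only: index-based batching over a random-access data source. */
final class BatchingSketch {

    /** Splits the indices 0..size-1 into batches, optionally shuffling them first. */
    static List<List<Integer>> batches(int size, int batchSize, boolean shuffle, long seed) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < size; i++) {
            indices.add(i);
        }
        if (shuffle) {
            // A fixed seed makes the shuffled order reproducible.
            Collections.shuffle(indices, new Random(seed));
        }
        List<List<Integer>> result = new ArrayList<>();
        for (int start = 0; start < size; start += batchSize) {
            int end = Math.min(start + batchSize, size);
            // The last batch may be smaller than batchSize.
            result.add(new ArrayList<>(indices.subList(start, end)));
        }
        return result;
    }

    public static void main(String[] args) {
        // Ten samples in batches of four, shuffled.
        for (List<Integer> batch : batches(10, 4, true, 42L)) {
            System.out.println(batch);
        }
    }
}
```

Random access matters here because a shuffled batch asks for items in arbitrary order; a purely sequential source could only produce the unshuffled variant.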
3. Engine and NDManager

- Engine: DJL supports multiple deep learning engines, such as MXNet, PyTorch, ONNX, and TensorFlow. The Engine abstraction gives them a uniform API, which makes it easy to switch the underlying engine.
- NDManager: manages NDArrays, the type used for multi-dimensional arrays, and wraps the underlying array operations.

##### Using DJL Engine

```kotlin
import ai.djl.Model
import ai.djl.ModelException
import ai.djl.ndarray.NDArray
import ai.djl.ndarray.NDList
import ai.djl.ndarray.types.Shape
import ai.djl.translate.Batchifier
import ai.djl.translate.TranslateException
import ai.djl.translate.Translator
import ai.djl.translate.TranslatorContext
import java.io.IOException
import java.nio.file.Paths

object DJLEngineExample {
    @Throws(ModelException::class, TranslateException::class, IOException::class)
    @JvmStatic
    fun main(args: Array<String>) {
        // Create the model on the PyTorch engine; "model-name" is a placeholder
        Model.newInstance("model-name", "PyTorch").use { model ->
            // Load a pre-trained model; make sure the path is correct
            model.load(Paths.get("path/to/your/model"))

            // Translator for data preprocessing and postprocessing
            val translator = object : Translator<Array<Float>, Float> {
                override fun processInput(ctx: TranslatorContext, input: Array<Float>): NDList {
                    val manager = ctx.ndManager
                    // Add a batch axis; adjust the shape to what the model expects
                    val array: NDArray = manager.create(input.toFloatArray())
                        .reshape(Shape(1, input.size.toLong()))
                    return NDList(array)
                }

                override fun processOutput(ctx: TranslatorContext, list: NDList): Float {
                    // Assumes the model outputs a single scalar value
                    return list[0].getFloat()
                }

                // No automatic batching; processInput adds the batch axis itself
                override fun getBatchifier(): Batchifier? = null
            }

            model.newPredictor(translator).use { predictor ->
                val input = arrayOf(1.0f, 2.0f, 3.0f) // must match the model's expected input shape
                val output = predictor.predict(input)
                println("Prediction: $output")
            }
        }
    }
}
```

##### Overview of NDManager

###### Key Features of NDManager:

1. Memory Management: Automates memory allocation and deallocation for NDArrays.
2. Resource Scope: NDArrays created by an NDManager are tied to the lifecycle of that manager. When the manager is closed, all associated NDArrays are also released.
3. Hierarchical Structure: NDManagers can create child managers, which in turn manage their own NDArrays. This is useful for managing resources in complex workflows.

##### Using NDManager

```kotlin
import ai.djl.ndarray.NDManager

object NDManagerExample {
    @JvmStatic
    fun main(args: Array<String>) {
        NDManager.newBaseManager().use { manager ->
            val array = manager.create(floatArrayOf(1.0f, 2.0f, 3.0f))
            println("Array: $array")
            // Perform operations
            val result = array.add(2.0f)
            println("Result: $result")
        }
        // No need to free the memory explicitly; closing the NDManager handles it
    }
}
```
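The resource-scope and hierarchy rules listed for NDManager can be modelled without any native memory. The sketch below is a plain-Java toy, not DJL code: `ScopeSketch`, `newChild`, and `allocate` are hypothetical names, and the "resources" are just strings. It only shows the ownership rule that closing a manager releases everything created under it, including children that were never closed.

```java
import java.util.ArrayList;
import java.util.List;

/** Illustrative only: a toy scope that releases its resources and child scopes on close. */
final class ScopeSketch implements AutoCloseable {
    private final String name;
    private final List<String> log; // shared record of what was released, in order
    private final List<ScopeSketch> children = new ArrayList<>();
    private final List<String> resources = new ArrayList<>();
    private boolean closed;

    ScopeSketch(String name, List<String> log) {
        this.name = name;
        this.log = log;
    }

    /** Creates a child scope whose lifetime is bounded by this scope. */
    ScopeSketch newChild(String childName) {
        ScopeSketch child = new ScopeSketch(childName, log);
        children.add(child);
        return child;
    }

    /** Registers a named stand-in for an array owned by this scope. */
    void allocate(String resourceName) {
        resources.add(resourceName);
    }

    @Override
    public void close() {
        if (closed) {
            return; // closing twice is harmless
        }
        closed = true;
        for (ScopeSketch child : children) {
            child.close(); // children are released before the parent's own resources
        }
        for (String resource : resources) {
            log.add(name + " released " + resource);
        }
    }

    public static void main(String[] args) {
        List<String> log = new ArrayList<>();
        try (ScopeSketch base = new ScopeSketch("base", log)) {
            base.allocate("weights");
            try (ScopeSketch batch = base.newChild("batch")) {
                batch.allocate("input");
            } // "input" is released here, while "weights" stays alive
            ScopeSketch forgotten = base.newChild("forgotten");
            forgotten.allocate("temp"); // never closed explicitly
        } // closing base releases "temp" and then "weights"
        log.forEach(System.out::println);
    }
}
```

Short-lived child scopes, for example one per batch, keep temporary arrays from piling up until the base manager is finally closed.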
4. Trainer and Predictor

The Trainer class is used to train deep learning models. It provides the training interface, bringing together the optimizer, the loss function, and the training loop, and it wraps common steps of the training process such as forward propagation, backpropagation, and parameter updates. Its main functions are:

- Training and validating the model
- Managing the optimizer and loss function
- Providing an easy-to-use training loop

###### Code demo

The following example uses DJL's Trainer class to train a simple neural network:

```kotlin
import ai.djl.Model
import ai.djl.basicdataset.cv.classification.FashionMnist
import ai.djl.basicmodelzoo.basic.Mlp
import ai.djl.ndarray.types.Shape
import ai.djl.training.DefaultTrainingConfig
import ai.djl.training.EasyTrain
import ai.djl.training.TrainingConfig
import ai.djl.training.dataset.Dataset
import ai.djl.training.dataset.RandomAccessDataset
import ai.djl.training.listener.LoggingTrainingListener
import ai.djl.training.listener.TrainingListener
import ai.djl.training.loss.Loss
import ai.djl.training.optimizer.Optimizer
import ai.djl.training.tracker.FixedPerVarTracker
import ai.djl.training.util.ProgressBar
import ai.djl.translate.TranslateException
import java.io.IOException
import java.nio.file.Paths

object DjlTrainerDemo {
    @Throws(IOException::class, TranslateException::class)
    @JvmStatic
    fun main(args: Array<String>) {
        // Load dataset
        val trainDataset: RandomAccessDataset = FashionMnist.builder()
            .optUsage(Dataset.Usage.TRAIN)
            .setSampling(32, true)
            .build()
        trainDataset.prepare(ProgressBar())

        // Define model
        val model = Model.newInstance("mlp")
        model.block = Mlp(28 * 28, 10, intArrayOf(128, 64))

        // Define training configuration
        val config: TrainingConfig = DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
            .optOptimizer(
                Optimizer.sgd()
                    .setLearningRateTracker(FixedPerVarTracker.builder().setDefaultValue(0.01f).build())
                    .build()
            )
            .addTrainingListeners(LoggingTrainingListener())

        model.newTrainer(config).use { trainer ->
            trainer.initialize(Shape(1, (28 * 28).toLong()))
            for (epoch in 0..9) {
                for (batch in trainer.iterateDataset(trainDataset)) {
                    // Forward and backward pass; without it step() has no gradients to apply
                    EasyTrain.trainBatch(trainer, batch)
                    trainer.step() // update the parameters
                    batch.close()
                }
                trainer.notifyListeners { listener: TrainingListener -> listener.onEpoch(trainer) }
            }
            model.save(Paths.get("model"), "mlp")
        }
        model.close()
    }
}
```
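The three steps a Trainer wraps (forward pass, gradient computation, parameter update) are easier to see on a model with a single weight. The following plain-Java sketch fits `y ≈ w * x` by full-batch gradient descent on mean squared error. It illustrates the loop only and is not DJL's implementation; the class name `GradientDescentSketch` and the sample data are made up for the example.

```java
/** Illustrative only: a one-parameter training loop written out by hand. */
final class GradientDescentSketch {

    /** Fits y ≈ w * x by gradient descent on the mean squared error and returns w. */
    static double fit(double[] xs, double[] ys, double learningRate, int epochs) {
        double w = 0.0; // initial parameter value
        for (int epoch = 0; epoch < epochs; epoch++) {
            double gradient = 0.0;
            for (int i = 0; i < xs.length; i++) {
                double prediction = w * xs[i];   // forward pass
                double error = prediction - ys[i];
                // Derivative of the mean of error^2 with respect to w.
                gradient += 2.0 * error * xs[i] / xs.length;
            }
            w -= learningRate * gradient;        // parameter update (the "step")
        }
        return w;
    }

    public static void main(String[] args) {
        double[] xs = {1.0, 2.0, 3.0, 4.0};
        double[] ys = {2.0, 4.0, 6.0, 8.0}; // generated with w = 2
        double w = fit(xs, ys, 0.05, 200);
        System.out.println("learned w = " + w); // converges toward 2.0
    }
}
```

In the DJL examples the gradient comes from the engine's automatic differentiation and the update rule from the configured optimizer, so user code only calls `EasyTrain.trainBatch` and `trainer.step()`.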
###### Predictor

The Predictor class runs inference on a trained model: it takes input data and returns the prediction. It offers a simple interface for passing input to the model and getting results back. Its main functions are:

- Loading a model for inference
- Converting input and output data

###### Code demo

```kotlin
import ai.djl.Model
import ai.djl.basicmodelzoo.basic.Mlp
import ai.djl.modality.Classifications
import ai.djl.ndarray.NDArray
import ai.djl.ndarray.NDList
import ai.djl.ndarray.NDManager
import ai.djl.ndarray.types.Shape
import ai.djl.translate.Batchifier
import ai.djl.translate.TranslateException
import ai.djl.translate.Translator
import ai.djl.translate.TranslatorContext
import java.io.IOException
import java.nio.file.Paths

object DjlPredictorDemo {
    @Throws(IOException::class, TranslateException::class)
    @JvmStatic
    fun main(args: Array<String>) {
        // Load model; the block must match the architecture that was saved
        val model = Model.newInstance("mlp")
        model.block = Mlp(28 * 28, 10, intArrayOf(128, 64))
        model.load(Paths.get("model"), "mlp")

        // Define Translator
        val translator = object : Translator<NDArray, Classifications> {
            override fun processInput(ctx: TranslatorContext, input: NDArray): NDList {
                return NDList(input.reshape(Shape(1, (28 * 28).toLong())))
            }

            override fun processOutput(ctx: TranslatorContext, list: NDList): Classifications {
                // The single output holds one raw score per class; softmax turns them into probabilities
                val probabilities = list.singletonOrThrow().softmax(0)
                // Example labels: one placeholder per output class
                val labels = List(probabilities.size().toInt()) { "Label$it" }
                return Classifications(labels, probabilities)
            }

            override fun getBatchifier(): Batchifier {
                return Batchifier.STACK
            }
        }

        model.newPredictor(translator).use { predictor ->
            NDManager.newBaseManager().use { manager ->
                val array = manager.ones(Shape(1, (28 * 28).toLong()))
                val classifications = predictor.predict(array)
                println(classifications)
            }
        }
        model.close()
    }
}
```
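A Predictor runs three stages in a fixed order: the translator turns the input into arrays, the model maps arrays to arrays, and the translator turns the result back into a typed value. The plain-Java sketch below mirrors that flow with ordinary interfaces. Nothing in it is DJL API: `Codec`, `csvSignCodec`, and `sumModel` are hypothetical stand-ins chosen to keep the example tiny.

```java
import java.util.function.Function;

/** Illustrative only: the preprocess -> model -> postprocess flow behind a predictor. */
final class PipelineSketch {

    /** Hypothetical stand-in for a translator: converts to and from raw arrays. */
    interface Codec<I, O> {
        double[] encode(I input);

        O decode(double[] raw);
    }

    /** Runs the three stages in order and returns the typed result. */
    static <I, O> O predict(Codec<I, O> codec, Function<double[], double[]> model, I input) {
        double[] features = codec.encode(input); // preprocessing
        double[] raw = model.apply(features);    // the model itself
        return codec.decode(raw);                // postprocessing
    }

    /** Parses comma-separated numbers and labels a single score by its sign. */
    static Codec<String, String> csvSignCodec() {
        return new Codec<String, String>() {
            @Override
            public double[] encode(String input) {
                String[] parts = input.split(",");
                double[] values = new double[parts.length];
                for (int i = 0; i < parts.length; i++) {
                    values[i] = Double.parseDouble(parts[i].trim());
                }
                return values;
            }

            @Override
            public String decode(double[] raw) {
                return raw[0] >= 0 ? "positive" : "negative";
            }
        };
    }

    /** A toy "model" that sums its features into one score. */
    static double[] sumModel(double[] features) {
        double sum = 0.0;
        for (double f : features) {
            sum += f;
        }
        return new double[] {sum};
    }

    public static void main(String[] args) {
        String label = predict(csvSignCodec(), PipelineSketch::sumModel, "1.5, -0.5, 2");
        System.out.println(label);
    }
}
```

Keeping the conversion logic in one object is what lets the same model serve different input and output types by swapping only the translator.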
Copyright belongs to the original author, heeheeai. If there is any infringement, please contact us for removal.