

java(kotlin) AI framework DJL

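DJL (Deep Java Library) is an open-source deep learning framework released by AWS. DJL supports multiple deep learning backends (engines), including but not limited to: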

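MXNet: an open-source deep learning framework developed under the Apache Software Foundation (the project has since been retired to the Apache Attic).
PyTorch: a widely used open-source machine learning library, developed by Facebook's AI research team.
TensorFlow: another popular open-source machine learning framework, developed by Google.
DJL integrates closely with the Java ecosystem and works alongside Java frameworks such as Spring Boot and Quarkus.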
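Supporting several backends means user code is written against a single abstraction, and the concrete engine is chosen by name. The plain-Java sketch below illustrates that idea. It uses a hypothetical `Backend` interface and `EngineRegistry` class, neither of which is a DJL class. In DJL itself, the corresponding calls are `Engine.getEngine("PyTorch")` for a named engine and `Engine.getInstance()` for the default one.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Supplier;

/** Hypothetical backend abstraction: user code depends on this, not on a concrete engine. */
interface Backend {
    String name();

    float[] add(float[] a, float[] b);
}

/** Hypothetical pure-Java backend; a real engine would call into native code instead. */
final class JavaBackend implements Backend {
    private final String name;

    JavaBackend(String name) {
        this.name = name;
    }

    @Override
    public String name() {
        return name;
    }

    @Override
    public float[] add(float[] a, float[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("shape mismatch");
        }
        float[] out = new float[a.length];
        for (int i = 0; i < a.length; i++) {
            out[i] = a[i] + b[i];
        }
        return out;
    }
}

/** Hypothetical registry: engines are looked up by name; the first one registered is the default. */
final class EngineRegistry {
    private final Map<String, Supplier<Backend>> providers = new LinkedHashMap<>();

    void register(String name, Supplier<Backend> provider) {
        providers.put(name, provider);
    }

    Backend get(String name) {
        Supplier<Backend> provider = providers.get(name);
        if (provider == null) {
            throw new IllegalArgumentException("Engine not found: " + name);
        }
        return provider.get();
    }

    Backend getDefault() {
        if (providers.isEmpty()) {
            throw new IllegalStateException("No engine registered");
        }
        return providers.values().iterator().next().get();
    }

    public static void main(String[] args) {
        EngineRegistry registry = new EngineRegistry();
        registry.register("PyTorch", () -> new JavaBackend("PyTorch"));
        registry.register("MXNet", () -> new JavaBackend("MXNet"));

        // Switching engines only changes the name passed to get(); the calling code stays the same
        Backend engine = registry.get("MXNet");
        float[] sum = engine.add(new float[] {1f, 2f}, new float[] {3f, 4f});
        System.out.println(engine.name() + " " + Arrays.toString(sum));
        System.out.println("default: " + registry.getDefault().name());
    }
}
```

The design choice to show here is the indirection: callers ask for an engine by name and get back the shared interface.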

Maven dependencies:

```xml
<!-- djl -->
<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>api</artifactId>
    <version>0.28.0</version>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-engine</artifactId>
    <version>0.28.0</version>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-model-zoo</artifactId>
    <version>0.28.0</version>
</dependency>
<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>basicdataset</artifactId>
    <version>0.28.0</version>
</dependency>
<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>model-zoo</artifactId>
    <version>0.28.0</version>
</dependency>
<!-- /djl -->
```

Java DJL architecture diagram

┌──────────────────────────────┐
│          ModelZoo            │
├──────────────────────────────┤
│            Model             │
└───────────────┬──────────────┘
                │
      ┌─────────▼─────────┐
      │       Engine      │
      └───────┬─┬─────────┘
              │ │
      ┌───────▼─▼─────────┐
      │     NDManager     │
      └───────┬─┬─────────┘
              │ │
    ┌─────────▼─▼───────────┐
    │        Dataset        │
    └─────────┬─────────────┘
              │
    ┌─────────▼─────────────┐
    │  Trainer / Predictor  │
    └───────────────────────┘
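At the top of the diagram, ModelZoo → Model boils down to looking up a model whose metadata matches a set of criteria. Below is a minimal plain-Java sketch of that lookup. It uses hypothetical `ModelEntry`, `SimpleCriteria` and `SimpleModelZoo` classes, which are not DJL's API. A model matches when the application and the input/output types agree and every filter key/value pair equals one of the model's properties.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/** Hypothetical metadata for one model in a zoo (not a DJL class). */
final class ModelEntry {
    final String name;
    final String application;
    final Class<?> inputType;
    final Class<?> outputType;
    final Map<String, String> properties;

    ModelEntry(String name, String application, Class<?> inputType, Class<?> outputType,
            Map<String, String> properties) {
        this.name = name;
        this.application = application;
        this.inputType = inputType;
        this.outputType = outputType;
        this.properties = properties;
    }
}

/** Hypothetical criteria: application (null = any), input/output types and property filters. */
final class SimpleCriteria {
    final String application;
    final Class<?> inputType;
    final Class<?> outputType;
    final Map<String, String> filters;

    SimpleCriteria(String application, Class<?> inputType, Class<?> outputType,
            Map<String, String> filters) {
        this.application = application;
        this.inputType = inputType;
        this.outputType = outputType;
        this.filters = filters;
    }

    boolean matches(ModelEntry model) {
        if (application != null && !application.equals(model.application)) {
            return false;
        }
        if (!inputType.equals(model.inputType) || !outputType.equals(model.outputType)) {
            return false;
        }
        // Every filter must equal the corresponding model property
        for (Map.Entry<String, String> filter : filters.entrySet()) {
            if (!filter.getValue().equals(model.properties.get(filter.getKey()))) {
                return false;
            }
        }
        return true;
    }
}

/** Hypothetical zoo: returns the first registered model that matches the criteria. */
final class SimpleModelZoo {
    private final List<ModelEntry> models = new ArrayList<>();

    void add(ModelEntry model) {
        models.add(model);
    }

    ModelEntry load(SimpleCriteria criteria) {
        for (ModelEntry model : models) {
            if (criteria.matches(model)) {
                return model;
            }
        }
        throw new IllegalStateException("No model matches the criteria");
    }

    public static void main(String[] args) {
        SimpleModelZoo zoo = new SimpleModelZoo();
        zoo.add(new ModelEntry("resnet18", "image_classification", float[].class, String.class,
                Map.of("layers", "18")));
        zoo.add(new ModelEntry("resnet50", "image_classification", float[].class, String.class,
                Map.of("layers", "50")));
        SimpleCriteria criteria = new SimpleCriteria("image_classification", float[].class,
                String.class, Map.of("layers", "50"));
        System.out.println("Loaded: " + zoo.load(criteria).name);
    }
}
```

DJL reports the no-match case with a `ModelNotFoundException`. For simplicity, the sketch throws `IllegalStateException` instead.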

Detailed description of the main components

1. ModelZoo and Model
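  • ModelZoo: provides a variety of pretrained models.

###### ModelZoo features

1. Model discovery and download:
   - ModelZoo provides a mechanism for discovering and downloading pretrained models from multiple sources (model providers, online repositories and so on).
   - For example, models can be downloaded from platforms such as AWS S3, Hugging Face and TensorFlow Hub.
2. Model loading:
   - ModelZoo provides convenient methods for loading models, so you can load the kind of model you need (image classification, object detection, natural language processing and so on).
   - When loading a model, you can specify its name, version and parameter configuration.
3. Model management:
   - ModelZoo caches downloaded models locally (by default under ~/.djl.ai), which avoids repeated downloads and wasted storage.

#### Example

```kotlin
import ai.djl.Application
import ai.djl.ModelException
import ai.djl.modality.Classifications
import ai.djl.modality.cv.Image
import ai.djl.modality.cv.ImageFactory
import ai.djl.repository.zoo.Criteria
import ai.djl.repository.zoo.ModelZoo
import ai.djl.translate.TranslateException
import java.io.IOException
import java.nio.file.Paths

object ModelZooExample {
    @Throws(IOException::class, ModelException::class, TranslateException::class)
    @JvmStatic
    fun main(args: Array<String>) {
        // Define the criteria the model must satisfy
        val criteria: Criteria<Image, Classifications> = Criteria.builder()
            .setTypes(Image::class.java, Classifications::class.java) // input/output types
            .optApplication(Application.CV.IMAGE_CLASSIFICATION)      // application: image classification
            // Filter keys are model-zoo specific ("backbone" is used e.g. by object-detection models);
            // if no model matches, loadModel throws ModelNotFoundException
            .optFilter("backbone", "resnet50")
            .build()

        // Load the model from the ModelZoo; both ZooModel and Predictor must be closed
        ModelZoo.loadModel(criteria).use { model ->
            model.newPredictor().use { predictor ->
                // Run inference with the translator that comes with the zoo model
                val img = ImageFactory.getInstance().fromFile(Paths.get("path/to/image.jpg"))
                println(predictor.predict(img))
            }
        }
    }
}
```

#### ModelZoo classes and interfaces

- **ModelZoo**: the core class; provides model download and loading.
- **Criteria**: defines the criteria and filters for loading a model, such as the application and the input/output types.
- **ModelLoader**: performs the actual download and loading of a model.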
  • Model: the interface that represents a deep learning model, with operations for loading, saving and running it (Model.newInstance(name), load(...), save(...), newPredictor(...), newTrainer(...)).

#### ai.djl.repository.zoo.ModelZoo

##### Key methods:
- static <I, O> ZooModel<I, O> loadModel(Criteria<I, O> criteria): loads a model that matches the given criteria.
- listModels() (static): lists the models available in the registered model zoos, grouped by Application.
- ModelLoader getModelLoader(String name): returns the loader for one model in this zoo.

#### Model metadata (ai.djl.repository)

Each model in a zoo is described by an Artifact (its name, version and properties) and addressed by an MRL (model resource locator). Filters set with optFilter are matched against the artifact's properties.

#### ai.djl.Application

Application describes the task a model solves. It is a class whose constants are grouped by domain in nested interfaces (CV, NLP, ...).

##### Key values:
- Application.CV.IMAGE_CLASSIFICATION
- Application.CV.OBJECT_DETECTION
- Application.NLP.TEXT_CLASSIFICATION

#### ai.djl.repository.zoo.Criteria

Criteria describes which model to load and is created with Criteria.Builder.

##### Key methods:
- static Builder<?, ?> builder(): creates a new builder instance.
- Builder<P, Q> setTypes(Class<P> input, Class<Q> output): sets the input and output types.
- Builder<I, O> optApplication(Application application): sets the application type.
- Builder<I, O> optEngine(String engine): specifies the engine to use (e.g. "MXNet", "PyTorch").
- Builder<I, O> optFilter(String key, String value): adds a filter on the model's properties.
- Criteria<I, O> build(): builds the Criteria.

###### Example

```kotlin
import ai.djl.Model
import ai.djl.ModelException
import ai.djl.modality.Classifications
import ai.djl.modality.cv.Image
import ai.djl.modality.cv.ImageFactory
import ai.djl.ndarray.NDList
import ai.djl.ndarray.types.DataType
import ai.djl.translate.TranslateException
import ai.djl.translate.Translator
import ai.djl.translate.TranslatorContext
import java.io.IOException
import java.nio.file.Paths

object DjlExample {
    @JvmStatic
    fun main(args: Array<String>) {
        // Model directory and name (the model files are expected to be prefixed with the name)
        val modelDir = Paths.get("models")
        val modelName = "resnet18"
        try {
            Model.newInstance(modelName).use { model ->
                // Load the model
                model.load(modelDir)
                // Load the input image
                val img = ImageFactory.getInstance().fromFile(Paths.get("path/to/image.jpg"))
                // Create a predictor with the custom translator and run inference
                model.newPredictor(MyTranslator()).use { predictor ->
                    val result = predictor.predict(img)
                    println(result)
                }
            }
        } catch (e: IOException) {
            e.printStackTrace()
        } catch (e: ModelException) {
            e.printStackTrace()
        } catch (e: TranslateException) {
            e.printStackTrace()
        }
    }

    // Custom Translator: Image -> NDList on the way in, NDList -> Classifications on the way out
    private class MyTranslator : Translator<Image, Classifications> {
        override fun processInput(ctx: TranslatorContext, input: Image): NDList {
            // HWC uint8 -> CHW float32 in [0, 1]; the default STACK batchifier adds the batch dimension.
            // Resizing and mean/std normalization are omitted; they depend on how the model was trained.
            val array = input.toNDArray(ctx.ndManager)
                .transpose(2, 0, 1)
                .toType(DataType.FLOAT32, false)
                .div(255f)
            return NDList(array)
        }

        override fun processOutput(ctx: TranslatorContext, list: NDList): Classifications {
            // After unbatching the output is a 1-D vector of logits, so softmax runs over axis 0
            val probabilities = list.singletonOrThrow().softmax(0)
            // Placeholder labels, one per output class; replace them with the model's real class names
            val labels: List<String> = List(probabilities.size().toInt()) { "name$it" }
            return Classifications(labels, probabilities)
        }
    }
}
```
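The custom Translator's processOutput converts the model's raw outputs (logits) into probabilities with softmax, then wraps them in Classifications. Classifications can report the most likely classes, for example via `topK`. Below is a small plain-Java sketch of that math, using hypothetical labels. Subtracting the largest logit before exponentiating keeps `Math.exp` from overflowing and does not change the result.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

final class SoftmaxTopK {
    /** Numerically stable softmax: shift by the max logit, exponentiate, normalize to sum 1. */
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            max = Math.max(max, v);
        }
        double[] out = new double[logits.length];
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            out[i] = Math.exp(logits[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    /** Indices of the k largest probabilities, highest first. */
    static List<Integer> topK(double[] probs, int k) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < probs.length; i++) {
            indices.add(i);
        }
        indices.sort(Comparator.comparingDouble((Integer idx) -> probs[idx]).reversed());
        return new ArrayList<>(indices.subList(0, Math.min(k, indices.size())));
    }

    public static void main(String[] args) {
        String[] labels = {"cat", "dog", "bird"}; // hypothetical class names
        double[] probs = softmax(new double[] {2.0, 1.0, 0.1});
        for (int i : topK(probs, 2)) {
            System.out.printf("%s: %.3f%n", labels[i], probs[i]);
        }
    }
}
```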
2. Dataset
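  • Common dataset types:
    1. RandomAccessDataset:
       - RandomAccessDataset is the basic abstract dataset class for data that can be accessed randomly by index, such as arrays or lists.
       - It supports batching and slicing (splitting into subsets) and suits most supervised learning tasks.
    2. IterableDataset:
       - IterableDataset is for data that cannot be accessed randomly, such as streaming data or data generated in real time.
       - It provides data through an iterator and suits data sources that are generated or processed dynamically.
    3. RecordDataset:
       - RecordDataset is a dataset format based on record files and is commonly used for large-scale data processing.
       - It loads and processes data records efficiently and suits distributed training and large datasets.

    The DJL dataset components provide:
    1. Data loading and preprocessing:
       - Loading data from many sources, such as local files, remote servers and databases.
       - Preprocessing such as normalization, data augmentation and feature extraction.
    2. Batching:
       - Splitting the data into small batches, which is how large datasets are trained on.
       - Flexible batching strategies that can be customized as needed.
    3. Transformations:
       - Transformations for images, text and numerical data.
       - Chained calls that combine several transformations into a data-processing pipeline.
    4. Data loading during training:
       - Dataset.getData(NDManager) and Trainer.iterateDataset(Dataset) turn a dataset into an iterable of Batch objects, which are fed to the training loop on demand.
       - Multithreaded loading is supported to speed up data processing.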
  • Dataset: the interface that defines a dataset. To implement a custom dataset, you typically extend the abstract class RandomAccessDataset and implement get(NDManager, long) and availableSize().

Inference example, loading a local model through Criteria:

```java
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.nio.file.Paths;

public class DjlExample {
    public static void main(String[] args) throws IOException, ModelException, TranslateException {
        // Load the model; the TensorFlow engine needs the ai.djl.tensorflow:tensorflow-engine
        // dependency, which is not in the Maven list above
        Criteria<Image, Classifications> criteria = Criteria.builder()
                .setTypes(Image.class, Classifications.class)
                .optEngine("TensorFlow") // choose the engine
                .optModelPath(Paths.get("path/to/model"))
                .build();

        // ZooModel (unlike the plain Model interface) has a no-argument newPredictor()
        try (ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);
                Predictor<Image, Classifications> predictor = model.newPredictor()) {
            // Load the image
            Image img = ImageFactory.getInstance().fromFile(Paths.get("path/to/image.jpg"));
            // Run inference
            Classifications result = predictor.predict(img);
            System.out.println(result);
        }
    }
}
```

Training example on the FashionMnist dataset:

```java
import ai.djl.Model;
import ai.djl.basicdataset.cv.classification.FashionMnist;
import ai.djl.basicmodelzoo.basic.Mlp;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.tracker.Tracker;
import ai.djl.translate.TranslateException;
import java.io.IOException;

public class DJLDatasetExample {
    public static void main(String[] args) throws IOException, TranslateException {
        FashionMnist fashionMnist = FashionMnist.builder()
                .optUsage(Dataset.Usage.TRAIN)
                .setSampling(32, true)    // batch size 32, random sampling
                .optLimit(Long.MAX_VALUE) // lower this to limit the number of samples
                .build();
        fashionMnist.prepare();           // downloads the data if needed

        TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                .optOptimizer(Optimizer.sgd().setLearningRateTracker(Tracker.fixed(0.1f)).build())
                .addTrainingListeners(TrainingListener.Defaults.logging());

        try (Model model = Model.newInstance("fashion-mnist-model")) {
            // A model needs a block (network) before it can be trained: a 784-128-64-10 MLP
            model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));
            try (Trainer trainer = model.newTrainer(config)) {
                trainer.initialize(new Shape(1, 28 * 28)); // shape of one flattened 28x28 image
                trainer.setMetrics(new Metrics());

                // One epoch: forward and backward pass per batch, then a parameter update
                for (Batch batch : trainer.iterateDataset(fashionMnist)) {
                    EasyTrain.trainBatch(trainer, batch);
                    trainer.step();
                    batch.close();
                }
                trainer.notifyListeners(listener -> listener.onTrainingEnd(trainer));
            }
        }
    }
}
```
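On a random-access dataset, batching with `setSampling(batchSize, shuffle)` comes down to two steps: optionally shuffle the sample indices, then cut them into fixed-size chunks. The minimal plain-Java sketch below works under that assumption. `IndexedDataset` and `SimpleBatcher` are hypothetical stand-ins, not DJL classes, and in this sketch the last batch is kept even when it is smaller than the batch size.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical random-access dataset: any sample can be fetched by its index. */
interface IndexedDataset<T> {
    long size();

    T get(long index);
}

/** Hypothetical batcher: optionally shuffles the indices, then cuts them into batches. */
final class SimpleBatcher {
    static <T> List<List<T>> batches(IndexedDataset<T> data, int batchSize, boolean shuffle, long seed) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        List<Long> order = new ArrayList<>();
        for (long i = 0; i < data.size(); i++) {
            order.add(i);
        }
        if (shuffle) {
            Collections.shuffle(order, new Random(seed)); // fixed seed keeps the order reproducible
        }
        List<List<T>> result = new ArrayList<>();
        for (int start = 0; start < order.size(); start += batchSize) {
            int end = Math.min(start + batchSize, order.size());
            List<T> batch = new ArrayList<>();
            for (int j = start; j < end; j++) {
                batch.add(data.get(order.get(j)));
            }
            result.add(batch); // the last batch may be smaller than batchSize
        }
        return result;
    }

    public static void main(String[] args) {
        IndexedDataset<String> data = new IndexedDataset<String>() {
            @Override
            public long size() {
                return 5;
            }

            @Override
            public String get(long index) {
                return "sample-" + index;
            }
        };
        for (List<String> batch : batches(data, 2, true, 7L)) {
            System.out.println(batch);
        }
    }
}
```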

3. Engine and NDManager
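  • Engine: DJL supports multiple deep learning engines, such as MXNet, PyTorch, ONNX Runtime and TensorFlow. The abstract class Engine provides a unified abstraction, so the underlying engine can be switched easily.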
  • NDManager: manages NDArrays (multi-dimensional arrays) and wraps the underlying array operations.

##### Using DJL Engine

```kotlin
import ai.djl.Model
import ai.djl.ModelException
import ai.djl.ndarray.NDArray
import ai.djl.ndarray.NDList
import ai.djl.ndarray.types.Shape
import ai.djl.translate.Batchifier
import ai.djl.translate.TranslateException
import ai.djl.translate.Translator
import ai.djl.translate.TranslatorContext
import java.io.IOException
import java.nio.file.Paths

object DJLEngineExample {
    @Throws(ModelException::class, TranslateException::class, IOException::class)
    @JvmStatic
    fun main(args: Array<String>) {
        // Create the model on the PyTorch engine; the second argument is the engine name, not a package
        Model.newInstance("model-name", "PyTorch").use { model ->
            // Load a pre-trained model (make sure the path is correct)
            model.load(Paths.get("path/to/your/model"))

            // Translator for preprocessing and postprocessing
            val translator = object : Translator<Array<Float>, Float> {
                override fun processInput(ctx: TranslatorContext, input: Array<Float>): NDList {
                    // Build a (1, n) array; batching is disabled below, so the batch dimension is added here
                    val array: NDArray = ctx.ndManager
                        .create(input.toFloatArray())
                        .reshape(Shape(1, input.size.toLong()))
                    return NDList(array)
                }

                override fun processOutput(ctx: TranslatorContext, list: NDList): Float {
                    // Assumes the model returns a single scalar value
                    return list[0].getFloat()
                }

                // null disables batching: inputs and outputs are handled one sample at a time
                override fun getBatchifier(): Batchifier? = null
            }

            model.newPredictor(translator).use { predictor ->
                val input = arrayOf(1.0f, 2.0f, 3.0f) // must match the model's expected input shape
                val output = predictor.predict(input)
                println("Prediction: $output")
            }
        }
    }
}
```

##### Overview of NDManager

###### Key Features of NDManager:
1. Memory management: automates the allocation and deallocation of memory for NDArrays.
2. Resource scope: NDArrays created by an NDManager are tied to the lifecycle of that manager. When the manager is closed, all associated NDArrays are released as well.
3. Hierarchical structure: an NDManager can create child managers, which manage their own NDArrays. This is useful for managing resources in complex workflows.

##### Using NDManager

```kotlin
import ai.djl.ndarray.NDManager
import ai.djl.ndarray.types.Shape

object NDManagerExample {
    @JvmStatic
    fun main(args: Array<String>) {
        NDManager.newBaseManager().use { manager ->
            val array = manager.create(floatArrayOf(1.0f, 2.0f, 3.0f))
            println("Array: $array")
            // Perform operations
            val result = array.add(2.0f)
            println("Result: $result")

            // A child manager frees its arrays when it is closed, while the parent stays open
            manager.newSubManager().use { sub ->
                val temp = sub.ones(Shape(3))
                println("Temporary: $temp")
            }
        } // No need to free memory explicitly; closing the manager releases all of its NDArrays
    }
}
```
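NDManager's resource scope and hierarchical structure can be pictured as scopes that own native buffers and child scopes, where closing a scope releases everything it owns. The plain-Java sketch below models that with hypothetical `ResourceScope` and `NativeBuffer` classes. Closing a buffer only sets a flag here; a real engine would free off-heap memory at that point.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical handle to native memory; close() stands in for freeing it. */
final class NativeBuffer implements AutoCloseable {
    final String name;
    boolean released;

    NativeBuffer(String name) {
        this.name = name;
    }

    @Override
    public void close() {
        released = true;
    }
}

/** Hypothetical scope that owns buffers and child scopes, in the spirit of NDManager. */
final class ResourceScope implements AutoCloseable {
    private final List<AutoCloseable> owned = new ArrayList<>();
    private boolean closed;

    NativeBuffer create(String name) {
        ensureOpen();
        NativeBuffer buffer = new NativeBuffer(name);
        owned.add(buffer);
        return buffer;
    }

    ResourceScope newSubScope() {
        ensureOpen();
        ResourceScope child = new ResourceScope();
        owned.add(child);
        return child;
    }

    private void ensureOpen() {
        if (closed) {
            throw new IllegalStateException("scope already closed");
        }
    }

    @Override
    public void close() {
        if (closed) {
            return; // closing twice is a no-op
        }
        closed = true;
        // Release everything this scope owns, newest first; child scopes release their own buffers
        for (int i = owned.size() - 1; i >= 0; i--) {
            try {
                owned.get(i).close();
            } catch (Exception e) {
                throw new IllegalStateException("failed to release resource", e);
            }
        }
        owned.clear();
    }

    public static void main(String[] args) {
        try (ResourceScope root = new ResourceScope()) {
            NativeBuffer weights = root.create("weights");
            try (ResourceScope step = root.newSubScope()) {
                NativeBuffer activations = step.create("activations");
                System.out.println(activations.name + " released: " + activations.released);
            } // activations are released here, weights are still alive
            System.out.println(weights.name + " released: " + weights.released);
        } // weights are released here
    }
}
```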
4. Trainer and Predictor
  • Trainer: the Trainer class provides the interface for training deep learning models, including the optimizer, the loss function and the training loop. It wraps the common steps of training, such as the forward pass, backpropagation and parameter updates. Main features:
- Training and validating the model
- Managing the optimizer and the loss function
- Providing an easy-to-use training loop

###### Code demo

The following example uses DJL's Trainer class to train a simple neural network:

```kotlin
import ai.djl.Model
import ai.djl.basicdataset.cv.classification.FashionMnist
import ai.djl.basicmodelzoo.basic.Mlp
import ai.djl.metric.Metrics
import ai.djl.ndarray.types.Shape
import ai.djl.training.DefaultTrainingConfig
import ai.djl.training.EasyTrain
import ai.djl.training.TrainingConfig
import ai.djl.training.dataset.Dataset
import ai.djl.training.dataset.RandomAccessDataset
import ai.djl.training.listener.LoggingTrainingListener
import ai.djl.training.listener.TrainingListener
import ai.djl.training.loss.Loss
import ai.djl.training.optimizer.Optimizer
import ai.djl.training.tracker.Tracker
import ai.djl.training.util.ProgressBar
import ai.djl.translate.TranslateException
import java.io.IOException
import java.nio.file.Paths

object DjlTrainerDemo {
    @Throws(IOException::class, TranslateException::class)
    @JvmStatic
    fun main(args: Array<String>) {
        // Load the dataset: Fashion-MNIST training split, batch size 32, shuffled
        val trainDataset: RandomAccessDataset = FashionMnist.builder()
            .optUsage(Dataset.Usage.TRAIN)
            .setSampling(32, true)
            .build()
        trainDataset.prepare(ProgressBar())

        // Training configuration: softmax cross-entropy loss, SGD with a fixed learning rate of 0.01
        val config: TrainingConfig = DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
            .optOptimizer(Optimizer.sgd().setLearningRateTracker(Tracker.fixed(0.01f)).build())
            .addTrainingListeners(LoggingTrainingListener())

        // Define the model: a 784-128-64-10 multilayer perceptron
        Model.newInstance("mlp").use { model ->
            model.block = Mlp(28 * 28, 10, intArrayOf(128, 64))
            model.newTrainer(config).use { trainer ->
                trainer.initialize(Shape(1, (28 * 28).toLong()))
                trainer.metrics = Metrics()
                for (epoch in 0..9) {
                    for (batch in trainer.iterateDataset(trainDataset)) {
                        EasyTrain.trainBatch(trainer, batch) // forward pass, loss and backward pass
                        trainer.step()                       // parameter update
                        batch.close()
                    }
                    trainer.notifyListeners { listener: TrainingListener -> listener.onEpoch(trainer) }
                }
                // EasyTrain.fit(trainer, 10, trainDataset, null) wraps the same loop
                // Save the trained parameters under model/ with the prefix "mlp"
                model.save(Paths.get("model"), "mlp")
            }
        }
    }
}
```

###### Predictor

The Predictor class runs inference on a trained model: it passes input data to the model and returns the prediction. Main features:
- Loading a model for inference
- Converting input and output data (through a Translator)

###### Code demo

```kotlin
import ai.djl.Model
import ai.djl.basicmodelzoo.basic.Mlp
import ai.djl.modality.Classifications
import ai.djl.ndarray.NDArray
import ai.djl.ndarray.NDList
import ai.djl.ndarray.NDManager
import ai.djl.ndarray.types.Shape
import ai.djl.translate.Batchifier
import ai.djl.translate.TranslateException
import ai.djl.translate.Translator
import ai.djl.translate.TranslatorContext
import java.io.IOException
import java.nio.file.Paths

object DjlPredictorDemo {
    @Throws(IOException::class, TranslateException::class)
    @JvmStatic
    fun main(args: Array<String>) {
        // Fashion-MNIST class names, in label order (one per model output)
        val labels = listOf(
            "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
            "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"
        )

        // Define the Translator
        val translator = object : Translator<NDArray, Classifications> {
            override fun processInput(ctx: TranslatorContext, input: NDArray): NDList {
                // One flattened 28x28 image; the STACK batchifier adds the batch dimension
                return NDList(input.reshape(Shape(28 * 28L)))
            }

            override fun processOutput(ctx: TranslatorContext, list: NDList): Classifications {
                // The MLP outputs logits; softmax turns them into probabilities
                val probabilities = list.singletonOrThrow().softmax(0)
                return Classifications(labels, probabilities)
            }

            override fun getBatchifier(): Batchifier = Batchifier.STACK
        }

        // Load the model: parameters saved from a block need the same block to be set first
        Model.newInstance("mlp").use { model ->
            model.block = Mlp(28 * 28, 10, intArrayOf(128, 64))
            model.load(Paths.get("model"), "mlp")

            NDManager.newBaseManager().use { manager ->
                model.newPredictor(translator).use { predictor ->
                    val input = manager.ones(Shape(1, 28 * 28L)) // dummy input image
                    val classifications = predictor.predict(input)
                    println(classifications)
                }
            }
        }
    }
}
```
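Each iteration of a training loop does four things: a forward pass, a loss computation, a gradient computation (the backward pass) and a parameter update. In DJL, `EasyTrain.trainBatch` covers the first three and `trainer.step()` covers the last. The self-contained plain-Java sketch below runs the same cycle for a hypothetical one-variable linear model (`TinyTrainer`, not a DJL class) trained with mini-batch gradient descent on mean squared error. The gradients are derived by hand. With eᵢ = w·xᵢ + b − yᵢ and L = (1/n)·Σ eᵢ², the gradients are ∂L/∂w = (2/n)·Σ eᵢ·xᵢ and ∂L/∂b = (2/n)·Σ eᵢ.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical linear model y = w * x + b, trained with mini-batch gradient descent on MSE. */
final class TinyTrainer {
    double w = 0.0;
    double b = 0.0;

    /** Forward pass; after training this is all a "predictor" needs. */
    double predict(double x) {
        return w * x + b;
    }

    /** One training step on a batch; returns the batch loss computed before the update. */
    double trainBatch(double[] xs, double[] ys, double learningRate) {
        int n = xs.length;
        double loss = 0.0;
        double gradW = 0.0;
        double gradB = 0.0;
        for (int i = 0; i < n; i++) {
            double err = predict(xs[i]) - ys[i]; // forward pass and error
            loss += err * err / n;              // mean squared error
            gradW += 2.0 * err * xs[i] / n;     // dL/dw (backward pass, by hand)
            gradB += 2.0 * err / n;             // dL/db
        }
        w -= learningRate * gradW;              // "step": parameter update
        b -= learningRate * gradB;
        return loss;
    }

    /** Runs several epochs over the data in fixed-order mini-batches (no shuffling, for simplicity). */
    static TinyTrainer fit(double[] xs, double[] ys, int epochs, int batchSize, double learningRate) {
        TinyTrainer model = new TinyTrainer();
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (int start = 0; start < xs.length; start += batchSize) {
                int end = Math.min(start + batchSize, xs.length);
                model.trainBatch(Arrays.copyOfRange(xs, start, end),
                        Arrays.copyOfRange(ys, start, end), learningRate);
            }
        }
        return model;
    }

    public static void main(String[] args) {
        // Hypothetical noise-free data for y = 2x + 1 with x in [-1, 1)
        Random rnd = new Random(0);
        double[] xs = new double[256];
        double[] ys = new double[256];
        for (int i = 0; i < xs.length; i++) {
            xs[i] = rnd.nextDouble() * 2 - 1;
            ys[i] = 2 * xs[i] + 1;
        }
        TinyTrainer model = fit(xs, ys, 50, 32, 0.1);
        System.out.printf("w=%.3f b=%.3f predict(0.5)=%.3f%n", model.w, model.b, model.predict(0.5));
    }
}
```

Because the synthetic data is noise-free, `w` and `b` converge toward 2 and 1. Once training is done, `predict` plays the role of the Predictor.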

Reposted from: https://blog.csdn.net/heeheeai/article/details/139640598
Copyright belongs to the original author, heeheeai. In case of infringement, please contact us for removal.
