C++与Lua交互实例 – 矩阵的加减乘除(版本二)
TIPS:关于如何借助矩阵的加减乘除来测试 C++ 与 Lua 的交互,以及本文未讲述的知识点,请参阅第一版:
https://blog.csdn.net/qq135595696/article/details/128960951
下面两种方式中,矩阵数据都来源于 C++ 端;区别在于:第一种方式在 C++ 端对结果进行比较和展示,第二种方式(使用 userdata)在 Lua 端对结果进行比较和展示。
C++ 端引入了第三方开源库 Eigen,用于验证 Lua 端矩阵运算的结果是否正确。参考链接如下:
http://eigen.tuxfamily.org/index.php?title=3.4
CppToLua1
CppToLua.cpp
#include<iostream>#include<vector>#include<assert.h>#include<Dense>#include"lua.hpp"using std::cout;using std::endl;using std::cin;staticint gs_Top =0;#defineSTACK_NUM(L)\gs_Top =lua_gettop(L);\std::cout<<"stack top:"<< gs_Top <<std::endl\// 矩阵运算enumclassMATRIX_OPERATE{
ADD,
SUB,
MUL,
DIV,
NONE
};#defineLUA_SCRIPT_PATH"matrix2.0.lua"static std::vector<std::vector<double>> gs_mat1;static std::vector<std::vector<double>> gs_mat2;staticboolOutPrint(const std::vector<std::vector<double>>& data){for(int32_t i =0; i < data.size(); i++){for(int32_t j =0; j < data[0].size(); j++)
std::cout <<" "<< data[i][j];
std::cout <<'\n';}
std::cout <<"......\n";returntrue;}staticboolInit(lua_State* L){assert(NULL!= L);
gs_mat1.clear();
gs_mat2.clear();if(luaL_dofile(L, LUA_SCRIPT_PATH)){printf("%s\n",lua_tostring(L,-1));returnfalse;}returntrue;}staticboolCreateLuaArr(lua_State* L,const std::vector<std::vector<double>>& data){assert(NULL!= L);//STACK_NUM(L);lua_newtable(L);for(int32_t i =0; i < data.size(); i++){lua_newtable(L);for(int32_t j =0; j < data[0].size(); j++){lua_pushnumber(L, data[i][j]);lua_rawseti(L,-2, j +1);}lua_rawseti(L,-2, i +1);}//STACK_NUM(L);returntrue;}staticboolGetLuaArr(lua_State* L, std::vector<std::vector<double>>& outData){assert(NULL!= L);
outData.clear();bool result =false;int32_t row =0;int32_t col =0;if(LUA_TTABLE !=lua_type(L,-1)){goto Exit;}if(LUA_TTABLE !=lua_getfield(L,-1,"tbData")){goto Exit;}lua_getfield(L,-2,"nRow");
row =lua_tointeger(L,-1);lua_pop(L,1);lua_getfield(L,-2,"nColumn");
col =lua_tointeger(L,-1);lua_pop(L,1);for(int32_t i =0; i < row; i++){lua_rawgeti(L,-1, i +1);
std::vector<double> data;for(int32_t j =0; j < col; j++){lua_rawgeti(L,-1, j +1);
data.push_back(lua_tonumber(L,-1));lua_pop(L,1);}
outData.push_back(data);lua_pop(L,1);}//维持lua堆栈平衡lua_pop(L,1);
result =true;
Exit:returntrue;}staticboolMatrixOperate(lua_State* L,
std::vector<std::vector<double>>& outData, MATRIX_OPERATE type){
outData.clear();constchar* funcName =NULL;bool result =false;switch(type){case MATRIX_OPERATE::ADD:
funcName ="MatrixAdd";break;case MATRIX_OPERATE::SUB:
funcName ="MatrixSub";break;case MATRIX_OPERATE::MUL:
funcName ="MatrixMul";break;case MATRIX_OPERATE::DIV:
funcName ="MatrixDiv";break;case MATRIX_OPERATE::NONE:break;default:break;}lua_getglobal(L, funcName);luaL_checktype(L,-1, LUA_TFUNCTION);//添加形参CreateLuaArr(L, gs_mat1);CreateLuaArr(L, gs_mat2);//调用函数if(lua_pcall(L,2,1,0)){printf("error[%s]\n",lua_tostring(L,-1));goto Exit;}GetLuaArr(L, outData);
result =true;
Exit:return result;}staticboolAPIMatrixOperate(const std::vector<std::vector<double>>& data1,const std::vector<std::vector<double>>& data2, MATRIX_OPERATE type, Eigen::MatrixXd& outResMat){
Eigen::MatrixXd mat1(data1.size(), data1[0].size());
Eigen::MatrixXd mat2(data2.size(), data2[0].size());for(int i =0; i < data1.size(); i++){for(int j =0; j < data1[0].size(); j++){mat1(i, j)= data1[i][j];}}for(int i =0; i < data2.size(); i++){for(int j =0; j < data2[0].size(); j++){mat2(i, j)= data2[i][j];}}switch(type){case MATRIX_OPERATE::ADD:
outResMat = mat1 + mat2;break;case MATRIX_OPERATE::SUB:
outResMat = mat1 - mat2;break;case MATRIX_OPERATE::MUL:
outResMat = mat1 * mat2;break;case MATRIX_OPERATE::DIV:
outResMat = mat1 *(mat2.inverse());break;case MATRIX_OPERATE::NONE:break;default:break;}returntrue;}staticboolRun(lua_State* L){assert(NULL!= L);
std::vector<std::vector<double>> addData;
std::vector<std::vector<double>> subData;
std::vector<std::vector<double>> mulData;
std::vector<std::vector<double>> divData;
Eigen::MatrixXd addApiData;
Eigen::MatrixXd subApiData;
Eigen::MatrixXd mulApiData;
Eigen::MatrixXd divApiData;// 运算
gs_mat1 ={{1,2,3},{4,5,6}};
gs_mat2 ={{2,3,4},{5,6,7}};MatrixOperate(L, addData, MATRIX_OPERATE::ADD);APIMatrixOperate(gs_mat1, gs_mat2, MATRIX_OPERATE::ADD, addApiData);
gs_mat1 = addData;
gs_mat2 ={{1,1,1},{1,1,1}};MatrixOperate(L, subData, MATRIX_OPERATE::SUB);APIMatrixOperate(gs_mat1, gs_mat2, MATRIX_OPERATE::SUB, subApiData);
gs_mat1 ={{1,2,3},{4,5,6}};
gs_mat2 ={{7,8},{9,10},{11,12}};MatrixOperate(L, mulData, MATRIX_OPERATE::MUL);APIMatrixOperate(gs_mat1, gs_mat2, MATRIX_OPERATE::MUL, mulApiData);
gs_mat1 ={{41,2,3},{424,5,6},{742,8,11}};
gs_mat2 ={{1,2,1},{1,1,2},{2,1,1}};MatrixOperate(L, divData, MATRIX_OPERATE::DIV);APIMatrixOperate(gs_mat1, gs_mat2, MATRIX_OPERATE::DIV, divApiData);// 输出
cout <<"================加法:================"<< endl;OutPrint(addData);
cout <<"正确答案:\n"<< addApiData << endl;
cout <<"================减法:================"<< endl;OutPrint(subData);
cout <<"正确答案:\n"<< subApiData << endl;
cout <<"================乘法:================"<< endl;OutPrint(mulData);
cout <<"正确答案:\n"<< mulApiData << endl;
cout <<"================除法:================"<< endl;OutPrint(divData);
cout <<"正确答案:\n"<< divApiData << endl;returntrue;}staticboolUnInit(){returntrue;}intmain020811(){
lua_State* L =luaL_newstate();luaL_openlibs(L);if(Init(L)){Run(L);}UnInit();lua_close(L);return0;}
matrix2.0.lua
local _class = {}

-- Minimal class factory. class(super) returns a class table whose New(...) creates instances.
-- Constructors run from the root base class down to the class itself. Method lookups go through
-- a per-class method table registered in _class.
function class(super)
    local cls = {}
    -- Ctor, super and New are raw fields, set before the __newindex metatable is attached.
    -- A later `function Cls:Ctor()` therefore overwrites the raw field instead of going to the
    -- method table.
    cls.Ctor = false
    cls.super = super
    cls.New = function(...)
        local obj = {}

        -- Run the constructors base-first.
        local function construct(klass, ...)
            if klass.super then
                construct(klass.super, ...)
            end
            if klass.Ctor then
                klass.Ctor(obj, ...)
            end
        end
        construct(cls, ...)

        -- A Ctor may have installed its own metatable on obj; extend it instead of replacing it.
        local mt = getmetatable(obj)
        if mt then
            mt.__index = _class[cls]
        else
            setmetatable(obj, { __index = _class[cls] })
        end
        return obj
    end

    local methods = {}
    _class[cls] = methods

    -- Any other field assigned on the class (i.e. a method) is redirected into the method table.
    setmetatable(cls, {
        __newindex = function(_, key, value)
            methods[key] = value
        end,
    })

    -- Methods missing here are fetched from the base class's method table and cached on first use.
    if super then
        setmetatable(methods, {
            __index = function(_, key)
                local inherited = _class[super][key]
                methods[key] = inherited
                return inherited
            end,
        })
    end

    return cls
end
-- Matrix class: tbData is a table of row tables; nRow and nColumn hold its dimensions.
Matrix = class()

-- Builds a matrix from `data`, a table of row tables. The table is stored by reference, not copied.
-- nColumn is taken from the first row (0 for an empty table).
-- Installs the +, -, *, / metamethods on the instance; New then sets __index on that same metatable.
function Matrix:Ctor(data)
    self.tbData = data
    self.nRow = #data
    if self.nRow > 0 then
        self.nColumn = #data[1]
    else
        self.nColumn = 0
    end
    setmetatable(self, {
        __add = function(tbSource, tbDest)
            return self:_ElementWise(tbSource, tbDest, function(a, b) return a + b end)
        end,
        __sub = function(tbSource, tbDest)
            return self:_ElementWise(tbSource, tbDest, function(a, b) return a - b end)
        end,
        __mul = function(tbSource, tbDest)
            return self:_MartixMul(tbSource, tbDest)
        end,
        -- A / B is computed as A * B^-1, with B^-1 = adj(B) / det(B). Returns nil if B is not invertible.
        __div = function(tbSource, tbDest)
            assert(tbSource, "tbSource not exist")
            assert(tbDest, "tbDest not exist")
            -- Only square matrices have a determinant/inverse.
            if tbDest.nRow ~= tbDest.nColumn then
                print("matrix is not square, no inverse matrix...")
                return nil
            end
            local nDet = self:_GetDetValue(tbDest)
            if nDet == 0 then
                print("matrix no inverse matrix...")
                return nil
            end
            local tbInverseDest = self:_MatrixNumMul(self:_GetCompanyMatrix(tbDest), 1 / nDet)
            return self:_MartixMul(tbSource, tbInverseDest)
        end,
    })
end

-- Prints the matrix row by row, writing each entry followed by ','.
function Matrix:Print()
    for _, tbRow in ipairs(self.tbData) do
        for _, value in ipairs(tbRow) do
            io.write(value, ',')
        end
        print('')
    end
end

-- Addition.
function Matrix:Add(matrix)
    return self + matrix
end

-- Subtraction.
function Matrix:Sub(matrix)
    return self - matrix
end

-- Multiplication.
function Matrix:Mul(matrix)
    return self * matrix
end

-- Division (multiplication by the inverse of `matrix`).
function Matrix:Div(matrix)
    return self / matrix
end

-- Shared body of + and -: returns a new Matrix whose entries are fnOp(source[i][j], dest[i][j]).
-- Prints a message and returns an empty Matrix when the dimensions differ.
function Matrix:_ElementWise(tbSource, tbDest, fnOp)
    assert(tbSource, "tbSource not exist")
    assert(tbDest, "tbDest not exist")
    local tbRes = Matrix.New({})
    if tbSource.nRow ~= tbDest.nRow or tbSource.nColumn ~= tbDest.nColumn then
        print("row or column not equal...")
        return tbRes
    end
    for i = 1, tbSource.nRow do
        local tbRow = {}
        for j = 1, tbSource.nColumn do
            tbRow[j] = fnOp(tbSource.tbData[i][j], tbDest.tbData[i][j])
        end
        tbRes.tbData[i] = tbRow
    end
    tbRes.nRow = tbSource.nRow
    tbRes.nColumn = tbSource.nColumn
    return tbRes
end

-- Returns a new matrix with row rowIndex and column colIndex removed.
function Matrix:_CutoffMatrix(tbMatrix, rowIndex, colIndex)
    assert(tbMatrix, "tbMatrix not exist")
    assert(rowIndex >= 1, "rowIndex < 1")
    assert(colIndex >= 1, "colIndex < 1")
    local tbRes = Matrix.New({})
    tbRes.nRow = tbMatrix.nRow - 1
    tbRes.nColumn = tbMatrix.nColumn - 1
    for i = 1, tbMatrix.nRow - 1 do
        tbRes.tbData[i] = {}
        -- Rows/columns at or after the removed one are shifted by one in the source.
        local nSrcRow = (i >= rowIndex) and (i + 1) or i
        for j = 1, tbMatrix.nColumn - 1 do
            local nSrcCol = (j >= colIndex) and (j + 1) or j
            tbRes.tbData[i][j] = tbMatrix.tbData[nSrcRow][nSrcCol]
        end
    end
    return tbRes
end

-- Determinant by cofactor expansion along the first row (recursive; assumes a square matrix).
function Matrix:_GetDetValue(tbMatrix)
    assert(tbMatrix, "tbMatrix not exist")
    -- The empty (0x0) matrix has determinant 1. This is required for the adjugate of a 1x1 matrix,
    -- whose only minor is 0x0.
    if tbMatrix.nRow == 0 then
        return 1
    end
    -- A 1x1 matrix: the determinant is its single element.
    if tbMatrix.nRow == 1 then
        return tbMatrix.tbData[1][1]
    end
    local nAns = 0
    for i = 1, tbMatrix.nColumn do
        local nFlag = (i % 2 ~= 0) and 1 or -1
        nAns = nAns + tbMatrix.tbData[1][i]
            * self:_GetDetValue(self:_CutoffMatrix(tbMatrix, 1, i)) * nFlag
    end
    return nAns
end

-- Adjugate matrix: the transpose of the cofactor matrix (cofactor (i, j) is stored at [j][i]).
function Matrix:_GetCompanyMatrix(tbMatrix)
    assert(tbMatrix, "tbMatrix not exist")
    local tbRes = Matrix.New({})
    tbRes.nRow = tbMatrix.nColumn
    tbRes.nColumn = tbMatrix.nRow
    for i = 1, tbMatrix.nRow do
        for j = 1, tbMatrix.nColumn do
            local nFlag = ((i + j) % 2 ~= 0) and -1 or 1
            if tbRes.tbData[j] == nil then
                tbRes.tbData[j] = {}
            end
            tbRes.tbData[j][i] = self:_GetDetValue(self:_CutoffMatrix(tbMatrix, i, j)) * nFlag
        end
    end
    return tbRes
end

-- Multiplies every entry of tbMatrix by num IN PLACE and returns tbMatrix.
function Matrix:_MatrixNumMul(tbMatrix, num)
    for i = 1, tbMatrix.nRow do
        for j = 1, tbMatrix.nColumn do
            tbMatrix.tbData[i][j] = tbMatrix.tbData[i][j] * num
        end
    end
    return tbMatrix
end

-- Matrix product tbSource * tbDest. Prints a message and returns an empty Matrix when
-- tbSource's column count differs from tbDest's row count.
function Matrix:_MartixMul(tbSource, tbDest)
    assert(tbSource, "tbSource not exist")
    assert(tbDest, "tbDest not exist")
    if tbSource.nColumn ~= tbDest.nRow then
        print("column not equal row...")
        -- Returning tbSource here would disguise the error as a valid product.
        return Matrix.New({})
    end
    local tbRes = Matrix.New({})
    for i = 1, tbSource.nRow do
        local tbRow = {}
        for j = 1, tbDest.nColumn do
            local nSum = 0
            for k = 1, tbSource.nColumn do
                nSum = nSum + (tbSource.tbData[i][k] * tbDest.tbData[k][j])
            end
            tbRow[j] = nSum
        end
        tbRes.tbData[i] = tbRow
    end
    tbRes.nRow = tbSource.nRow
    tbRes.nColumn = tbDest.nColumn
    return tbRes
end

-- Entry points called from C++: each wraps two row tables in Matrix objects and returns the result.

-- add
function MatrixAdd(data1, data2)
    assert(data1, "data1 not exist")
    assert(data2, "data2 not exist")
    local matrix1 = Matrix.New(data1)
    local matrix2 = Matrix.New(data2)
    return matrix1 + matrix2
end

-- sub
function MatrixSub(data1, data2)
    assert(data1, "data1 not exist")
    assert(data2, "data2 not exist")
    local matrix1 = Matrix.New(data1)
    local matrix2 = Matrix.New(data2)
    return matrix1 - matrix2
end

-- mul
function MatrixMul(data1, data2)
    assert(data1, "data1 not exist")
    assert(data2, "data2 not exist")
    local matrix1 = Matrix.New(data1)
    local matrix2 = Matrix.New(data2)
    return matrix1 * matrix2
end

-- div (returns nil when data2 is not invertible)
function MatrixDiv(data1, data2)
    assert(data1, "data1 not exist")
    assert(data2, "data2 not exist")
    local matrix1 = Matrix.New(data1)
    local matrix2 = Matrix.New(data2)
    return matrix1 / matrix2
end
输出结果
CppToLua2
CppToLua.cpp
#include<iostream>#include<Dense>#include<vector>#include"lua.hpp"using std::cout;using std::endl;using std::cin;#defineCPP_MATRIX"CPP_MATRIX"#defineLUA_SCRIPT_PATH"matrix2.0-lua.lua"staticint gs_Top =0;#defineSTACK_NUM(L)\gs_Top =lua_gettop(L);\std::cout<<"stack top:"<< gs_Top <<std::endl\// 矩阵运算enumclassMATRIX_OPERATE{
ADD,
SUB,
MUL,
DIV,
NONE
};static std::vector<std::vector<double>> gs_mat1;static std::vector<std::vector<double>> gs_mat2;extern"C"{staticintCreateMatrix(lua_State* L){
Eigen::MatrixXd** pp =(Eigen::MatrixXd**)lua_newuserdata(L,sizeof(Eigen::MatrixXd*));*pp =newEigen::MatrixXd();luaL_setmetatable(L, CPP_MATRIX);return1;}staticintInitMatrix(lua_State* L){assert(NULL!= L);int32_t row =0;int32_t col =0;
row =luaL_len(L,-1);lua_rawgeti(L,-1,1);
col =luaL_len(L,-1);lua_pop(L,1);
Eigen::MatrixXd** pp =(Eigen::MatrixXd**)luaL_checkudata(L,1, CPP_MATRIX);(*pp)->resize(row, col);for(int32_t i =0; i < row; i++){lua_rawgeti(L,-1, i +1);for(int32_t j =0; j < col; j++){lua_rawgeti(L,-1, j +1);(**pp)(i, j)=lua_tonumber(L,-1);lua_pop(L,1);}lua_pop(L,1);}lua_pop(L,2);return0;}staticintUnInitMatrix(lua_State* L){
Eigen::MatrixXd** pp =(Eigen::MatrixXd**)luaL_checkudata(L,1, CPP_MATRIX);
std::cout <<"auto gc"<< std::endl;if(*pp){delete*pp;}return0;}staticintAddMatrix(lua_State* L){//STACK_NUM(L);
Eigen::MatrixXd** pp1 =(Eigen::MatrixXd**)luaL_checkudata(L,1, CPP_MATRIX);
Eigen::MatrixXd** pp2 =(Eigen::MatrixXd**)luaL_checkudata(L,2, CPP_MATRIX);
Eigen::MatrixXd** pp =(Eigen::MatrixXd**)lua_newuserdata(L,sizeof(Eigen::MatrixXd*));*pp =newEigen::MatrixXd();//该部分内存由C++分配**pp =(**pp1)+(**pp2);luaL_setmetatable(L, CPP_MATRIX);//STACK_NUM(L);return1;}staticintSubMatrix(lua_State* L){//STACK_NUM(L);
Eigen::MatrixXd** pp1 =(Eigen::MatrixXd**)luaL_checkudata(L,1, CPP_MATRIX);
Eigen::MatrixXd** pp2 =(Eigen::MatrixXd**)luaL_checkudata(L,2, CPP_MATRIX);
Eigen::MatrixXd** pp =(Eigen::MatrixXd**)lua_newuserdata(L,sizeof(Eigen::MatrixXd*));*pp =newEigen::MatrixXd();//该部分内存由C++分配**pp =(**pp1)-(**pp2);luaL_setmetatable(L, CPP_MATRIX);//STACK_NUM(L);return1;}staticintMulMatrix(lua_State* L){//STACK_NUM(L);
Eigen::MatrixXd** pp1 =(Eigen::MatrixXd**)luaL_checkudata(L,1, CPP_MATRIX);
Eigen::MatrixXd** pp2 =(Eigen::MatrixXd**)luaL_checkudata(L,2, CPP_MATRIX);
Eigen::MatrixXd** pp =(Eigen::MatrixXd**)lua_newuserdata(L,sizeof(Eigen::MatrixXd*));*pp =newEigen::MatrixXd();//该部分内存由C++分配**pp =(**pp1)*(**pp2);luaL_setmetatable(L, CPP_MATRIX);//STACK_NUM(L);return1;}staticintDivMatrix(lua_State* L){//STACK_NUM(L);
Eigen::MatrixXd** pp1 =(Eigen::MatrixXd**)luaL_checkudata(L,1, CPP_MATRIX);
Eigen::MatrixXd** pp2 =(Eigen::MatrixXd**)luaL_checkudata(L,2, CPP_MATRIX);
Eigen::MatrixXd** pp =(Eigen::MatrixXd**)lua_newuserdata(L,sizeof(Eigen::MatrixXd*));*pp =newEigen::MatrixXd();//该部分内存由C++分配**pp =(**pp1)*((**pp2).inverse());luaL_setmetatable(L, CPP_MATRIX);//STACK_NUM(L);return1;}staticintPrintMatrix(lua_State* L){
Eigen::MatrixXd** pp =(Eigen::MatrixXd**)luaL_checkudata(L,1, CPP_MATRIX);
std::cout <<"正确答案:\n"<<**pp << std::endl;return0;}}staticconst luaL_Reg MatrixFuncs[]={{"InitMatrix", InitMatrix },{"__gc", UnInitMatrix},{"__add", AddMatrix },{"__sub", SubMatrix },{"__mul", MulMatrix },{"__div", DivMatrix },{"PrintMatrix",PrintMatrix },{NULL,NULL}};extern"C"{staticboolCreateMatrixMetaTable(lua_State* L){luaL_newmetatable(L, CPP_MATRIX);lua_pushvalue(L,-1);lua_setfield(L,-2,"__index");luaL_setfuncs(L, MatrixFuncs,0);//STACK_NUM(L);lua_pop(L,1);returntrue;}}boolCreateLuaArr(lua_State* L,const std::vector<std::vector<double>>& data){assert(NULL!= L);//STACK_NUM(L);lua_newtable(L);for(int32_t i =0; i < data.size(); i++){lua_newtable(L);for(int32_t j =0; j < data[0].size(); j++){lua_pushnumber(L, data[i][j]);lua_rawseti(L,-2, j +1);}lua_rawseti(L,-2, i +1);}//STACK_NUM(L);returntrue;}boolMatrixOperate(lua_State* L, MATRIX_OPERATE type){constchar* funcName =NULL;bool result =false;switch(type){case MATRIX_OPERATE::ADD:
funcName ="MatrixAdd";break;case MATRIX_OPERATE::SUB:
funcName ="MatrixSub";break;case MATRIX_OPERATE::MUL:
funcName ="MatrixMul";break;case MATRIX_OPERATE::DIV:
funcName ="MatrixDiv";break;case MATRIX_OPERATE::NONE:break;default:break;}lua_getglobal(L, funcName);luaL_checktype(L,-1, LUA_TFUNCTION);//添加形参CreateLuaArr(L, gs_mat1);CreateLuaArr(L, gs_mat2);//调用函数if(lua_pcall(L,2,0,0)){printf("error[%s]\n",lua_tostring(L,-1));goto Exit;}
result =true;
Exit:return result;}boolInit(lua_State *L){//构造一张全局元表,名为CPP_MATRIXCreateMatrixMetaTable(L);//注册第三方API构造对象方法lua_pushcfunction(L, CreateMatrix);lua_setglobal(L,"CreateMatrix");if(luaL_dofile(L, LUA_SCRIPT_PATH)){printf("%s\n",lua_tostring(L,-1));}returntrue;}boolRun(lua_State* L){assert(NULL!= L);// 运算
gs_mat1 ={{1,2,3},{4,5,6}};
gs_mat2 ={{2,3,4},{5,6,7}};MatrixOperate(L, MATRIX_OPERATE::ADD);
gs_mat1 ={{1,2,3},{4,5,6}};
gs_mat2 ={{1,1,1},{1,1,1}};MatrixOperate(L, MATRIX_OPERATE::SUB);
gs_mat1 ={{1,2,3},{4,5,6}};
gs_mat2 ={{7,8},{9,10},{11,12}};MatrixOperate(L, MATRIX_OPERATE::MUL);
gs_mat1 ={{41,2,3},{424,5,6},{742,8,11}};
gs_mat2 ={{1,2,1},{1,1,2},{2,1,1}};MatrixOperate(L, MATRIX_OPERATE::DIV);returntrue;}boolUnInit(){returntrue;}intmain(){
lua_State* L =luaL_newstate();luaL_openlibs(L);if(Init(L)){Run(L);}UnInit();lua_close(L);return0;}
matrix2.0-lua.lua
local _class = {}

-- Minimal class factory. class(super) returns a class table whose New(...) creates instances.
-- Constructors run from the root base class down to the class itself. Method lookups go through
-- a per-class method table registered in _class.
function class(super)
    local cls = {}
    -- Ctor, super and New are raw fields, set before the __newindex metatable is attached.
    -- A later `function Cls:Ctor()` therefore overwrites the raw field instead of going to the
    -- method table.
    cls.Ctor = false
    cls.super = super
    cls.New = function(...)
        local obj = {}

        -- Run the constructors base-first.
        local function construct(klass, ...)
            if klass.super then
                construct(klass.super, ...)
            end
            if klass.Ctor then
                klass.Ctor(obj, ...)
            end
        end
        construct(cls, ...)

        -- A Ctor may have installed its own metatable on obj; extend it instead of replacing it.
        local mt = getmetatable(obj)
        if mt then
            mt.__index = _class[cls]
        else
            setmetatable(obj, { __index = _class[cls] })
        end
        return obj
    end

    local methods = {}
    _class[cls] = methods

    -- Any other field assigned on the class (i.e. a method) is redirected into the method table.
    setmetatable(cls, {
        __newindex = function(_, key, value)
            methods[key] = value
        end,
    })

    -- Methods missing here are fetched from the base class's method table and cached on first use.
    if super then
        setmetatable(methods, {
            __index = function(_, key)
                local inherited = _class[super][key]
                methods[key] = inherited
                return inherited
            end,
        })
    end

    return cls
end
-- Matrix class: tbData is a table of row tables; nRow and nColumn hold its dimensions.
Matrix = class()

-- Builds a matrix from `data`, a table of row tables. The table is stored by reference, not copied.
-- nColumn is taken from the first row (0 for an empty table).
-- Installs the +, -, *, / metamethods on the instance; New then sets __index on that same metatable.
function Matrix:Ctor(data)
    self.tbData = data
    self.nRow = #data
    if self.nRow > 0 then
        self.nColumn = #data[1]
    else
        self.nColumn = 0
    end
    setmetatable(self, {
        __add = function(tbSource, tbDest)
            return self:_ElementWise(tbSource, tbDest, function(a, b) return a + b end)
        end,
        __sub = function(tbSource, tbDest)
            return self:_ElementWise(tbSource, tbDest, function(a, b) return a - b end)
        end,
        __mul = function(tbSource, tbDest)
            return self:_MartixMul(tbSource, tbDest)
        end,
        -- A / B is computed as A * B^-1, with B^-1 = adj(B) / det(B). Returns nil if B is not invertible.
        __div = function(tbSource, tbDest)
            assert(tbSource, "tbSource not exist")
            assert(tbDest, "tbDest not exist")
            -- Only square matrices have a determinant/inverse.
            if tbDest.nRow ~= tbDest.nColumn then
                print("matrix is not square, no inverse matrix...")
                return nil
            end
            local nDet = self:_GetDetValue(tbDest)
            if nDet == 0 then
                print("matrix no inverse matrix...")
                return nil
            end
            local tbInverseDest = self:_MatrixNumMul(self:_GetCompanyMatrix(tbDest), 1 / nDet)
            return self:_MartixMul(tbSource, tbInverseDest)
        end,
    })
end

-- Prints the matrix row by row, writing each entry followed by ','.
function Matrix:Print()
    for _, tbRow in ipairs(self.tbData) do
        for _, value in ipairs(tbRow) do
            io.write(value, ',')
        end
        print('')
    end
end

-- Addition.
function Matrix:Add(matrix)
    return self + matrix
end

-- Subtraction.
function Matrix:Sub(matrix)
    return self - matrix
end

-- Multiplication.
function Matrix:Mul(matrix)
    return self * matrix
end

-- Division (multiplication by the inverse of `matrix`).
function Matrix:Div(matrix)
    return self / matrix
end

-- Shared body of + and -: returns a new Matrix whose entries are fnOp(source[i][j], dest[i][j]).
-- Prints a message and returns an empty Matrix when the dimensions differ.
function Matrix:_ElementWise(tbSource, tbDest, fnOp)
    assert(tbSource, "tbSource not exist")
    assert(tbDest, "tbDest not exist")
    local tbRes = Matrix.New({})
    if tbSource.nRow ~= tbDest.nRow or tbSource.nColumn ~= tbDest.nColumn then
        print("row or column not equal...")
        return tbRes
    end
    for i = 1, tbSource.nRow do
        local tbRow = {}
        for j = 1, tbSource.nColumn do
            tbRow[j] = fnOp(tbSource.tbData[i][j], tbDest.tbData[i][j])
        end
        tbRes.tbData[i] = tbRow
    end
    tbRes.nRow = tbSource.nRow
    tbRes.nColumn = tbSource.nColumn
    return tbRes
end

-- Returns a new matrix with row rowIndex and column colIndex removed.
function Matrix:_CutoffMatrix(tbMatrix, rowIndex, colIndex)
    assert(tbMatrix, "tbMatrix not exist")
    assert(rowIndex >= 1, "rowIndex < 1")
    assert(colIndex >= 1, "colIndex < 1")
    local tbRes = Matrix.New({})
    tbRes.nRow = tbMatrix.nRow - 1
    tbRes.nColumn = tbMatrix.nColumn - 1
    for i = 1, tbMatrix.nRow - 1 do
        tbRes.tbData[i] = {}
        -- Rows/columns at or after the removed one are shifted by one in the source.
        local nSrcRow = (i >= rowIndex) and (i + 1) or i
        for j = 1, tbMatrix.nColumn - 1 do
            local nSrcCol = (j >= colIndex) and (j + 1) or j
            tbRes.tbData[i][j] = tbMatrix.tbData[nSrcRow][nSrcCol]
        end
    end
    return tbRes
end

-- Determinant by cofactor expansion along the first row (recursive; assumes a square matrix).
function Matrix:_GetDetValue(tbMatrix)
    assert(tbMatrix, "tbMatrix not exist")
    -- The empty (0x0) matrix has determinant 1. This is required for the adjugate of a 1x1 matrix,
    -- whose only minor is 0x0.
    if tbMatrix.nRow == 0 then
        return 1
    end
    -- A 1x1 matrix: the determinant is its single element.
    if tbMatrix.nRow == 1 then
        return tbMatrix.tbData[1][1]
    end
    local nAns = 0
    for i = 1, tbMatrix.nColumn do
        local nFlag = (i % 2 ~= 0) and 1 or -1
        nAns = nAns + tbMatrix.tbData[1][i]
            * self:_GetDetValue(self:_CutoffMatrix(tbMatrix, 1, i)) * nFlag
    end
    return nAns
end

-- Adjugate matrix: the transpose of the cofactor matrix (cofactor (i, j) is stored at [j][i]).
function Matrix:_GetCompanyMatrix(tbMatrix)
    assert(tbMatrix, "tbMatrix not exist")
    local tbRes = Matrix.New({})
    tbRes.nRow = tbMatrix.nColumn
    tbRes.nColumn = tbMatrix.nRow
    for i = 1, tbMatrix.nRow do
        for j = 1, tbMatrix.nColumn do
            local nFlag = ((i + j) % 2 ~= 0) and -1 or 1
            if tbRes.tbData[j] == nil then
                tbRes.tbData[j] = {}
            end
            tbRes.tbData[j][i] = self:_GetDetValue(self:_CutoffMatrix(tbMatrix, i, j)) * nFlag
        end
    end
    return tbRes
end

-- Multiplies every entry of tbMatrix by num IN PLACE and returns tbMatrix.
function Matrix:_MatrixNumMul(tbMatrix, num)
    for i = 1, tbMatrix.nRow do
        for j = 1, tbMatrix.nColumn do
            tbMatrix.tbData[i][j] = tbMatrix.tbData[i][j] * num
        end
    end
    return tbMatrix
end

-- Matrix product tbSource * tbDest. Prints a message and returns an empty Matrix when
-- tbSource's column count differs from tbDest's row count.
function Matrix:_MartixMul(tbSource, tbDest)
    assert(tbSource, "tbSource not exist")
    assert(tbDest, "tbDest not exist")
    if tbSource.nColumn ~= tbDest.nRow then
        print("column not equal row...")
        -- Returning tbSource here would disguise the error as a valid product.
        return Matrix.New({})
    end
    local tbRes = Matrix.New({})
    for i = 1, tbSource.nRow do
        local tbRow = {}
        for j = 1, tbDest.nColumn do
            local nSum = 0
            for k = 1, tbSource.nColumn do
                nSum = nSum + (tbSource.tbData[i][k] * tbDest.tbData[k][j])
            end
            tbRow[j] = nSum
        end
        tbRes.tbData[i] = tbRow
    end
    tbRes.nRow = tbSource.nRow
    tbRes.nColumn = tbDest.nColumn
    return tbRes
end

-- Prints the Lua result under szTitle ("no result" if the operation returned nil). It then repeats
-- the operation on CreateMatrix() userdata built from the same input tables, via fnCppOp, and
-- prints that result for comparison.
local function _CompareWithCpp(szTitle, matrix3, data1, data2, fnCppOp)
    print(szTitle)
    if matrix3 then
        matrix3:Print()
    else
        print("no result")
    end
    local cppMatrix1 = CreateMatrix()
    cppMatrix1:InitMatrix(data1)
    local cppMatrix2 = CreateMatrix()
    cppMatrix2:InitMatrix(data2)
    local cppMatrix3 = fnCppOp(cppMatrix1, cppMatrix2)
    cppMatrix3:PrintMatrix()
end

-- Entry points called from C++: compute with the Lua Matrix class, then compare with the C++ result.

-- add
function MatrixAdd(data1, data2)
    assert(data1, "data1 not exist")
    assert(data2, "data2 not exist")
    local matrix3 = Matrix.New(data1) + Matrix.New(data2)
    _CompareWithCpp("===========加法===========", matrix3, data1, data2,
        function(a, b) return a + b end)
end

-- sub
function MatrixSub(data1, data2)
    assert(data1, "data1 not exist")
    assert(data2, "data2 not exist")
    local matrix3 = Matrix.New(data1) - Matrix.New(data2)
    _CompareWithCpp("===========减法===========", matrix3, data1, data2,
        function(a, b) return a - b end)
end

-- mul
function MatrixMul(data1, data2)
    assert(data1, "data1 not exist")
    assert(data2, "data2 not exist")
    local matrix3 = Matrix.New(data1) * Matrix.New(data2)
    _CompareWithCpp("===========乘法===========", matrix3, data1, data2,
        function(a, b) return a * b end)
end

-- div (the Lua result is nil when data2 is not invertible)
function MatrixDiv(data1, data2)
    assert(data1, "data1 not exist")
    assert(data2, "data2 not exist")
    local matrix3 = Matrix.New(data1) / Matrix.New(data2)
    _CompareWithCpp("===========除法===========", matrix3, data1, data2,
        function(a, b) return a / b end)
end
输出结果
版权归原作者 ufgnix0802 所有, 如有侵权,请联系我们删除。