JPEG-LS 算法代码
时间: 2023-08-26 13:05:01 浏览: 142
以下是一个用 C++ 实现的 JPEG-LS 编码和解码的简单代码示例:
```cpp
#include <iostream>
#include <vector>
#include <cstring>
#include <fstream>
#include <cmath>
using namespace std;
// Size constant; it is not referenced in the visible part of this file.
// Kept as a macro so any preprocessor use elsewhere keeps working.
#define MAXN 1024
// Short aliases for the unsigned integer types used throughout this file.
using uchar = unsigned char;
using ushort = unsigned short;
using uint = unsigned int;
using ulong = unsigned long;
// Dense row-major 2-D matrix.
//
// Storage is a std::vector, so copy/move construction, copy/move assignment
// and destruction are all generated correctly by the compiler. The previous
// raw new[]/delete[] buffer had no copy operations, which made statements
// such as `m = Matrix<T>(r, c);` shallow-copy the pointer and free it twice.
//
// Element access is unchecked: callers must pass 0 <= r < rows and
// 0 <= c < cols.
template <typename T>
class Matrix {
public:
    // Empty 0 x 0 matrix; allows arrays / containers of Matrix.
    Matrix() = default;

    // r x c matrix whose elements are value-initialised (zero for arithmetic T).
    Matrix(int r, int c) : data(elementCount(r, c)), rows(r), cols(c) {}

    // r x c matrix with every element set to val.
    Matrix(int r, int c, T val) : data(elementCount(r, c), val), rows(r), cols(c) {}

    // Stores val at row r, column c.
    void set(int r, int c, T val) {
        data[offset(r, c)] = val;
    }

    // Returns a copy of the element at row r, column c.
    T get(int r, int c) const {
        return data[offset(r, c)];
    }

    int getRows() const {
        return rows;
    }

    int getCols() const {
        return cols;
    }

private:
    // Number of elements for an r x c matrix, computed in size_t to avoid
    // int overflow for large dimensions.
    static std::size_t elementCount(int r, int c) {
        return static_cast<std::size_t>(r) * static_cast<std::size_t>(c);
    }

    // Row-major linear index of (r, c).
    std::size_t offset(int r, int c) const {
        return static_cast<std::size_t>(r) * static_cast<std::size_t>(cols) +
               static_cast<std::size_t>(c);
    }

    std::vector<T> data;
    int rows = 0;
    int cols = 0;
};
// Folds data[start .. start + len) into an unsigned integer, most significant
// element first: each step doubles the accumulator and adds the next element.
// The writers in this file store one bit (0 or 1) per element.
uint binary2decimal(const uchar *data, int start, int len) {
    uint value = 0;
    const int stop = start + len;
    int idx = start;
    while (idx < stop) {
        value = (value << 1) + data[idx];
        ++idx;
    }
    return value;
}
// Writes the low `len` bits of val into data[start .. start + len), one bit
// per element, most significant bit first (the last element receives bit 0).
void decimal2binary(uint val, uchar *data, int start, int len) {
    int offset = len;
    while (offset-- > 0) {
        data[start + offset] = val & 1u;
        val >>= 1;
    }
}
// Reads the whole file `filename` into a byte vector.
//
// Returns an empty vector when the file cannot be opened (an error is printed
// to stderr), when its size cannot be determined, or when it is empty. If
// fewer bytes than the reported size could be read, the result is truncated
// to the bytes actually read.
vector<uchar> readBinaryFile(const char *filename) {
    vector<uchar> data;
    std::ifstream fin(filename, std::ios::binary);
    if (!fin) {
        std::cerr << "Error: cannot open file " << filename << std::endl;
        return data;
    }
    fin.seekg(0, std::ios::end);
    // tellg() yields -1 on failure; keep the full-width offset instead of
    // narrowing to int so large files and the error value are both handled.
    const std::streamoff fileSize = fin.tellg();
    if (fileSize <= 0) {
        return data;
    }
    fin.seekg(0, std::ios::beg);
    data.resize(static_cast<std::size_t>(fileSize));
    fin.read(reinterpret_cast<char *>(data.data()), static_cast<std::streamsize>(fileSize));
    // Drop any tail that was not actually filled by the read.
    data.resize(static_cast<std::size_t>(fin.gcount()));
    return data;
}
// Writes `data` to the file `filename` in binary mode, replacing any existing
// content. Prints an error to stderr if the file cannot be opened or the
// write fails. An empty `data` produces an empty file.
void writeBinaryFile(const char *filename, const vector<uchar> &data) {
    std::ofstream fout(filename, std::ios::binary);
    if (!fout) {
        std::cerr << "Error: cannot open file " << filename << std::endl;
        return;
    }
    // Indexing element 0 of an empty vector is undefined, so skip the write.
    if (!data.empty()) {
        fout.write(reinterpret_cast<const char *>(data.data()),
                   static_cast<std::streamsize>(data.size()));
    }
    fout.flush();
    if (!fout) {
        std::cerr << "Error: cannot write file " << filename << std::endl;
    }
}
// Splits `img` into square blockSize x blockSize tiles.
//
// Only whole tiles are produced: rows/cols that do not fill a complete tile
// at the bottom/right edge are ignored. Tile (i, j) is written to
// blocks[i * blockCols + j], where blockCols = img.getCols() / blockSize, so
// `blocks` must point to at least
// (img.getRows() / blockSize) * (img.getCols() / blockSize) Matrix objects.
//
// A non-positive blockSize is rejected with an error message (a zero value
// would otherwise divide by zero) and `blocks` is left untouched.
//
// NOTE: each slot is overwritten by assigning a freshly built Matrix, so
// Matrix must provide a deep-copying (or moving) assignment operator.
void splitImage(const Matrix<ushort> &img, Matrix<ushort> *blocks, int blockSize) {
    if (blockSize <= 0) {
        std::cerr << "Error: invalid block size" << std::endl;
        return;
    }
    const int blockRows = img.getRows() / blockSize;
    const int blockCols = img.getCols() / blockSize;
    for (int i = 0; i < blockRows; i++) {
        const int rowBase = i * blockSize;
        for (int j = 0; j < blockCols; j++) {
            const int colBase = j * blockSize;
            Matrix<ushort> &blk = blocks[i * blockCols + j];
            blk = Matrix<ushort>(blockSize, blockSize);
            for (int ii = 0; ii < blockSize; ii++) {
                for (int jj = 0; jj < blockSize; jj++) {
                    blk.set(ii, jj, img.get(rowBase + ii, colBase + jj));
                }
            }
        }
    }
}
// 将若干块合并成图像
void mergeImage(const Matrix<ushort> *blocks, int blockRows, int blockCols, int blockSize, Matrix<ushort> &img) {
int rows = blockRows * blockSize;
int cols = blockCols * blockSize;
img = Matrix<ushort>(rows, cols);
for (int i = 0; i < blockRows; i++) {
for (int j = 0; j < blockCols; j++) {
const Matrix<ushort> &blk = blocks[i * blockCols + j];
for (int ii = 0; ii < blockSize; ii++) {
for (int jj = 0; jj < blockSize; jj++) {
img.set(i * blockSize + ii, j * blockSize + jj, blk.get(ii, jj));
}
}
}
}
}
// 计算预测误差
void calcPredError(const Matrix<ushort> &blk, int predMode, Matrix<short> &err) {
int rows = blk.getRows();
int cols = blk.getCols();
err = Matrix<short>(rows, cols);
switch (predMode) {
case 0: {
for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) {
if (i == 0 && j == 0) {
err.set(i, j, blk.get(i, j));
} else if (i == 0) {
err.set(i, j, blk.get(i, j) - blk.get(i, j - 1));
} else if (j == 0) {
err.set(i, j, blk.get(i, j) - blk.get(i - 1, j));
} else {
err.set(i, j, blk.get(i, j) - (blk.get(i - 1, j) + blk.get(i, j - 1) - blk.get(i - 1, j - 1)));
}
}
}
break;
}
case 1: {
for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) {
if (i == 0 && j == 0) {
err.set(i, j, blk.get(i, j));
} else if (i == 0) {
err.set(i, j, blk.get(i, j) - (blk.get(i, j - 1) + blk.get(i + 1, j - 1) + 1) / 2);
} else if (j == 0) {
err.set(i, j, blk.get(i, j) - (blk.get(i - 1, j) + blk.get(i - 1, j + 1) + 1) / 2);
} else if (j == cols - 1) {
err.set(i, j, blk.get(i, j) - (blk.get(i, j - 1) + blk.get(i - 1, j - 1) + 1) / 2);
} else {
err.set(i, j, blk.get(i, j) - (blk.get(i - 1, j) + blk.get(i, j - 1) + blk.get(i + 1, j - 1) + blk.get(i - 1, j + 1) + 2) / 4);
}
}
}
break;
}
case 2: {
for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) {
if (i == 0 && j == 0) {
err.set(i, j, blk.get(i, j));
} else if (i == 0) {
err.set(i, j, blk.get(i, j) - blk.get(i, j - 1));
} else if (j == 0) {
err.set(i, j, blk.get(i, j) - blk.get(i - 1, j));
} else {
int a = blk.get(i - 1, j);
int b = blk.get(i, j - 1);
int c = blk.get(i - 1, j - 1);
int x = a + b - c;
int pa = abs(x - a);
int pb = abs(x - b);
int pc = abs(x - c);
if (pa <= pb && pa <= pc) {
err.set(i, j, blk.get(i, j) - a);
} else if (pb <= pc) {
err.set(i, j, blk.get(i, j) - b);
} else {
err.set(i, j, blk.get(i, j) - c);
}
}
}
}
break;
}
case 3: {
for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) {
if (i == 0 && j == 0) {
err.set(i, j, blk.get(i, j));
} else if (i == 0) {
err.set(i, j, blk.get(i, j) - blk.get(i, j - 1));
} else if (j == 0) {
err.set(i, j, blk.get(i, j) - blk.get(i - 1, j));
} else {
int a = blk.get(i - 1, j);
int b = blk.get(i, j - 1);
int c = blk.get(i - 1, j - 1);
int x = a + b - c;
int pa = abs(x - a);
int pb = abs(x - b);
int pc = abs(x - c);
if (pa <= pb && pa <= pc) {
err.set(i, j, blk.get(i, j) - a);
} else if (pb <= pc) {
err.set(i, j, blk.get(i, j) - b);
} else {
err.set(i, j, blk.get(i, j) - c);
}
err.set(i, j, -err.get(i, j));
}
}
}
break;
}
default: {
cerr << "Error: invalid prediction mode" << endl;
break;
}
}
}
// Builds a histogram of the prediction errors and derives a threshold from it.
//
// Each error e is counted in bin e + maxErr of a 65536-bin histogram; bins
// are clamped to [0, 65535] so an error of -32768 with maxErr = 32767 (index
// -1) or an oversized maxErr can no longer write outside the histogram.
//
// The bins are then scanned from the highest value downwards until the
// running count reaches 3/4 of all samples. On return:
//   thresh   = error value of the bin where the scan stopped
//   errCount = number of samples counted up to and including that bin
void calcPredErrorHistogram(const Matrix<short> &err, int maxErr, int &thresh, int &errCount) {
    const int HIST_SIZE = 65536;
    // Heap storage: 65536 ints is too large to place comfortably on the stack.
    std::vector<int> hist(HIST_SIZE, 0);

    const int rows = err.getRows();
    const int cols = err.getCols();
    for (int i = 0; i < rows; i++) {
        for (int j = 0; j < cols; j++) {
            int bin = err.get(i, j) + maxErr;
            if (bin < 0) {
                bin = 0;
            } else if (bin >= HIST_SIZE) {
                bin = HIST_SIZE - 1;
            }
            hist[bin]++;
        }
    }

    const int total = rows * cols;
    const int targetCount = total * 3 / 4;
    int count = 0;
    for (int i = HIST_SIZE - 1; i >= 0; i--) {
        count += hist[i];
        if (count >= targetCount) {
            thresh = i - maxErr;
            errCount = count;
            break;
        }
    }
}
// 将预测误差编码为差分编码
void encodePredError(const Matrix<short> &err, int thresh, uchar *data, int &len) {
int rows = err.getRows();
int cols = err.getCols();
len = 0;
short prevVal = 0;
for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) {
short val = err.get(i, j);
if (val >= thresh) {
val++;
}
if (val <= -thresh) {
val--;
}
val -= prevVal;
prevVal += val;
if (val < 0) {
val += (1 << 16);
}
decimal2binary(val, data, len, 16);
len += 16;
}
}
}
// 将差分编码解码为预测误差
void decodePredError(const uchar *data, int len, int thresh, Matrix<short> &err) {
int rows = err.getRows();
int cols = err.getCols();
err = Matrix<short>(rows, cols);
short prevVal = 0;
for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) {
uint val = binary2decimal(data, i * cols * 16 + j * 16, 16);
if (val >= (1 << 15)) {
val -= (1 << 16);
}
if (val < 0) {
val -= 1;
}
val += prevVal;
prevVal = val;
if (val > thresh) {
val--;
}
if (val < -thresh) {
val++;
}
err.set(i, j, val);
}
}
}
// 计算RLE编码
void calcRLE(const Matrix<short> &err, uchar *data, int &len) {
int rows = err.getRows();
int cols = err.getCols();
len = 0;
for (int i = 0; i < rows; i++) {
int j = 0;
while (j < cols) {
int k = j + 1;
while (k < cols && err.get(i, k) == err.get(i, j)) {
k++;
}
int cnt = k - j;
if (cnt >= 2) {
if (cnt <= 9) {
data[len++] = 0x90 + cnt - 2;
} else if (cnt <= 270) {
data[len++] = 0xFF;
data[len++] = cnt - 10;
} else {
cerr << "Error: RLE count too large" << endl;
exit(1);
}
decimal2binary(err.get(i, j), data, len, 16);
len += 16;
j = k;
} else {
decimal2binary(err.get(i, j), data, len, 16);
len += 16;
j++;
}
}
}
}
// 解码RLE编码
void decodeRLE(const uchar *data, int len, Matrix<short> &err) {
int rows = err.getRows();
int cols = err.getCols();
err = Matrix<short>(rows, cols);
int pos = 0;
for (int i = 0; i < rows; i++) {
int j = 0;
while (j < cols) {
uchar b1 = data[pos++];
if (b1 == 0xFF) {
uchar b2 = data[pos++];
int cnt = b2 + 10;
if (cnt > cols - j) {
cerr << "Error: RLE count too large" << endl;
exit(1);
}
short val = binary2decimal(data, pos * 8, 16);
pos += 2;
for (int k = 0; k < cnt; k++) {
err.set(i, j + k, val);
}
j += cnt;
} else if (b1 >= 0x90) {
int cnt = b1 - 0x8E;
if (cnt > cols - j) {
cerr << "Error: RLE count too large" << endl;
exit(1);
}
short val = binary2decimal(data, pos * 8, 16);
pos += 2;
for (int k = 0; k < cnt; k++) {
err.set(i, j + k, val);
}
j += cnt;
} else {
short val = binary2decimal(data, pos * 8, 16);
pos += 2;
err.set(i, j, val);
j++;
}
}
}
}
// 将预测误差编码为压缩后的数据(包括预测模式、预测误差直方图阈值、预测误差差分编码和RLE编码)
void encode(const Matrix<ushort> &blk, uchar *data, int &len) {
int rows = blk.getRows();
int cols = blk.getCols();
int predMode = 0;
Matrix<short> err(rows, cols);
calcPredError(blk, predMode, err);
int maxErr = (1 << 15) - 1;
int thresh = 0;
int errCount = 0;
calcPredErrorHistogram(err, maxErr, thresh, errCount);
encodePredError(err, thresh, data, len);
uchar *rleData = new uchar[rows * cols * 2 * 5];
int rleLen;
calcRLE(err, rleData, rleLen);
if (rleLen < len) {
data[0] = (data[0] & 0xFC) | predMode;
decimal2binary(thresh, data, 2, 15);
len = rleLen;
memcpy(data + 17, rleData, len);
len += 17;
} else {
len = 2 + 16 * rows * cols;
for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) {
decimal2binary(err.get(i, j), data, 2 + (i * cols + j) * 16, 16);
}
}
}
delete[] rleData;
}
// 将压缩后的数据解码为预测误差
void decode(const uchar *data, int len, Matrix<short> &err) {
int rows = err.getRows();
int cols = err.getCols();
int mode = data[0] & 0x03;
int pos = 2;
if (mode == 0) {
for (int i = 0; i < rows; i++) {
阅读全文