feat: 重构 batchUpdate 方法返回 BatchUpdateResult 对象

- 将 batchUpdate 方法的返回类型从 List<int[]> 改为 BatchUpdateResult
- 移除 exceptions 参数,改用 BatchUpdateResult 对象来封装执行结果
- 添加 BatchUpdateResult 类来统一管理批量更新的状态、统计信息和异常处理
- 优化批处理逻辑,提供更详细的执行状态和错误处理机制
- 更新相关接口和实现类以适配新的方法签名
- 新增测试用例,验证新的批量更新结果处理方式
This commit is contained in:
2026-05-23 00:31:14 +08:00
parent 6e0230888d
commit 96d414252f
8 changed files with 507 additions and 38 deletions

View File

@@ -0,0 +1,46 @@
/*
* Copyright 2026 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package xyz.zhouxy.jdbc;
/**
* 批量更新错误信息
*
* @author ZhouXY108 <luquanlion@outlook.com>
*/
public class BatchUpdateErrorInfo {
private final int batchIndex;
private final Throwable cause;
private final Class<? extends Throwable> errorType;
public BatchUpdateErrorInfo(int batchIndex, Throwable cause) {
this.batchIndex = batchIndex;
this.cause = cause;
this.errorType = cause.getClass();
}
public int getBatchIndex() {
return batchIndex;
}
public Throwable getCause() {
return cause;
}
public Class<? extends Throwable> getErrorType() {
return errorType;
}
}

View File

@@ -0,0 +1,191 @@
/*
* Copyright 2026 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package xyz.zhouxy.jdbc;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
public class BatchUpdateResult {
private final int total;
private final int batchCount;
private final int batchSize;
private BatchUpdateStatus status = BatchUpdateStatus.SUCCESS;
private Map<Integer, int[]> allUpdateCounts;
private Map<Integer, BatchUpdateErrorInfo> allErrorsInfo;
private int successBatchCount;
private int completeBatchCount;
public BatchUpdateResult(int total, int batchCount, int batchSize) {
this.total = total;
this.batchCount = batchCount;
this.batchSize = batchSize;
this.allUpdateCounts = new HashMap<>(batchCount);
this.allErrorsInfo = new HashMap<>(batchCount);
}
/**
* 记录成功批次
*/
public void recordSuccessBatch(int batchIndex, int[] updateCounts) {
this.completeBatchCount++;
this.allUpdateCounts.put(batchIndex, updateCounts);
this.successBatchCount++;
}
/**
* 记录失败批次
*/
public void recordErrorBatch(int batchIndex, int[] updateCounts, Throwable cause) {
this.completeBatchCount++;
this.allUpdateCounts.put(batchIndex, updateCounts);
this.allErrorsInfo.put(batchIndex, new BatchUpdateErrorInfo(batchIndex, cause));
if (this.status == BatchUpdateStatus.SUCCESS) {
this.status = BatchUpdateStatus.COMPLETED_WITH_ERRORS;
}
}
/**
* 中断
*/
public void interrupt() {
this.status = BatchUpdateStatus.INTERRUPTED;
}
/**
* 获取批次更新结果
*/
public int[] getUpdateCounts(int batchIndex) {
return this.allUpdateCounts.get(batchIndex);
}
/**
* 获取错误批次号
*/
public int[] getErrorBatchIndexes() {
return this.allErrorsInfo.keySet().stream().mapToInt(Integer::intValue).toArray();
}
/**
* 获取错误批次信息
*
* @param batchIndex 批次号
* @return 批次错误信息
*/
public BatchUpdateErrorInfo getBatchUpdateErrorInfo(int batchIndex) {
return this.allErrorsInfo.get(batchIndex);
}
/**
* 获取所有错误批次信息
*
* @return 批次错误信息
*/
public Map<Integer, BatchUpdateErrorInfo> getAllErrorsInfo() {
return Collections.unmodifiableMap(allErrorsInfo);
}
/**
* 获取总数据量
*
* @return 总数据量
*/
public int getTotal() {
return total;
}
/**
* 获取批次数量
*
* @return 批次数量
*/
public int getBatchCount() {
return batchCount;
}
/**
* 获取批次大小
*
* @return 批次大小
*/
public int getBatchSize() {
return batchSize;
}
/**
* 获取批次更新状态
*
* @return 批次更新状态
*/
public BatchUpdateStatus getStatus() {
return status;
}
/**
* 获取完成批次数量
*
* @return 完成批次数量
*/
public int getCompleteBatchCount() {
return completeBatchCount;
}
/**
* 获取成功批次数量
*
* @return 成功批次数量
*/
public int getSuccessBatchCount() {
return successBatchCount;
}
/**
* 获取错误批次数量
*
* @return 错误批次数量
*/
public int getErrorBatchCount() {
return allErrorsInfo.size();
}
/**
* 获取剩余批次数量
*
* @return 剩余批次数量
*/
public int getRemainingBatchCount() {
return batchCount - successBatchCount - getErrorBatchCount();
}
@Override
public String toString() {
return "BatchUpdateResult ["
+ "total()=" + total
+ ", batchCount()=" + batchCount
+ ", batchSize()=" + batchSize
+ ", status()=" + status
+ ", completeBatchCount()=" + completeBatchCount
+ ", successBatchCount()=" + successBatchCount
+ ", errorBatchCount()=" + getErrorBatchCount()
+ ", remainingBatchCount()=" + getRemainingBatchCount()
+ "]";
}
}

View File

@@ -0,0 +1,71 @@
/*
* Copyright 2026 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package xyz.zhouxy.jdbc;
import xyz.zhouxy.plusone.commons.base.IWithIntCode;
/**
* 批量更新状态
*
* @author ZhouXY108 <luquanlion@outlook.com>
*/
public enum BatchUpdateStatus implements IWithIntCode {
/**
* 成功
*/
SUCCESS(0, "成功"),
/**
* 部分成功
*/
COMPLETED_WITH_ERRORS(-1, "部分成功"),
/**
* 中断
*/
INTERRUPTED(-2, "中断"),
;
private final int code;
private final String description;
BatchUpdateStatus(int code, String description) {
this.code = code;
this.description = description;
}
@Override
public int getCode() {
return code;
}
/**
* @return the description
*/
public String getDescription() {
return description;
}
@Override
public String toString() {
return "BatchUpdateStatus ["
+ "name=" + name()
+ ", code=" + code
+ ", description=" + description
+ "]";
}
}

View File

@@ -210,29 +210,28 @@ class JdbcOperationSupport {
* @param params 参数列表
* @param batchSize 每次批量更新的数据量
* @param exceptions 空列表,用于记录异常信息
* @param quietly 静默跑批
* {@code quietly} 为 {@code true},发生异常不中断操作,将异常存入 {@code exceptions} 中
* {@code quietly} 为 {@code false},发生异常即中断操作,并将异常抛出
* @param quietly 静默分批更新
* 如果 {@code quietly} 为 {@code true}分批更新过程中发生异常不中断操作;
* 如果 {@code quietly} 为 {@code false}分批更新过程中发生异常即中断操作,并返回结果
*/
static List<int[]> batchUpdate(Connection conn,
static BatchUpdateResult batchUpdate(Connection conn,
String sql, @Nullable Collection<Object[]> params, int batchSize,
List<Exception> exceptions, boolean quietly)
boolean quietly)
throws SQLException {
assertConnectionNotNull(conn);
assertSqlNotNull(sql);
checkArgument(batchSize > 0, "The batch size must be greater than 0.");
checkArgument(!quietly || (exceptions != null && exceptions.isEmpty()),
"The list used to store exceptions should be non-null and empty.");
if (params == null || params.isEmpty()) {
return Collections.emptyList();
return new BatchUpdateResult(0, 0, batchSize);
}
int batchCount = (params.size() + batchSize - 1) / batchSize;
List<int[]> result = Lists.newArrayListWithCapacity(batchCount);
final BatchUpdateResult result = new BatchUpdateResult(params.size(), batchCount, batchSize);
try (PreparedStatement stmt = conn.prepareStatement(sql)) {
int i = 0;
int batchIndex = 0;
for (Object[] ps : params) {
i++;
fillStatement(stmt, ps);
@@ -240,28 +239,27 @@ class JdbcOperationSupport {
final int indexInBatch = i % batchSize;
if (indexInBatch == 0 || i >= params.size()) {
try {
int[] n = stmt.executeBatch();
result.add(n);
stmt.clearBatch();
int[] updateCounts = stmt.executeBatch();
result.recordSuccessBatch(batchIndex, updateCounts);
}
catch (Exception e) {
final int[] updateCounts;
if (e instanceof BatchUpdateException) {
updateCounts = ((BatchUpdateException)e).getUpdateCounts();
updateCounts = ((BatchUpdateException) e).getUpdateCounts();
}
else {
int n = (i >= params.size() && indexInBatch != 0) ? indexInBatch : batchSize;
updateCounts = new int[n];
Arrays.fill(updateCounts, UNKNOWN_COUNT);
}
result.add(updateCounts);
stmt.clearBatch();
result.recordErrorBatch(batchIndex, updateCounts, e);
if (!quietly) {
throw e;
result.interrupt();
return result;
}
// 收集异常信息
exceptions.add(e);
}
stmt.clearBatch();
batchIndex++;
}
}
return result;

View File

@@ -243,28 +243,31 @@ public interface JdbcOperations {
throws SQLException;
/**
* 批量更新,返回每条记录更新的行数
* 批量更新
*
* <p>
* 跑批过程中发生异常即中断操作,并返回结果。
*
* @param sql SQL 语句
* @param params 参数列表
* @param batchSize 每次批量更新的数据量
*/
List<int[]> batchUpdate(String sql, @Nullable Collection<Object[]> params, int batchSize)
BatchUpdateResult batchUpdate(String sql, @Nullable Collection<Object[]> params, int batchSize)
throws SQLException;
/**
* 批量更新,返回更新成功的记录行数
* 批量更新
*
* @param sql sql语句
* @param params 参数列表
* @param batchSize 每次批量更新的数据量
* @param exceptions 空列表,用于记录异常信息
* @param quietly 静默跑批
* {@code quietly} 为 {@code true},发生异常不中断操作,将异常存入 {@code exceptions} 中
* {@code quietly} 为 {@code false},发生异常即中断操作,并将异常抛出
* @param quietly 静默分批更新
* 如果 {@code quietly} 为 {@code true}分批更新过程中发生异常不中断操作;
* 如果 {@code quietly} 为 {@code false}分批更新过程中发生异常即中断操作,并返回结果
*/
List<int[]> batchUpdate(String sql, @Nullable Collection<Object[]> params,
int batchSize, List<Exception> exceptions, boolean quietly)
BatchUpdateResult batchUpdate(String sql, @Nullable Collection<Object[]> params,
int batchSize, boolean quietly)
throws SQLException;
// #endregion

View File

@@ -261,21 +261,21 @@ public class SimpleJdbcTemplate implements JdbcOperations {
/** {@inheritDoc} */
@Override
public List<int[]> batchUpdate(String sql, @Nullable Collection<Object[]> params, int batchSize)
public BatchUpdateResult batchUpdate(String sql, @Nullable Collection<Object[]> params, int batchSize)
throws SQLException {
try (Connection conn = this.dataSource.getConnection()) {
return JdbcOperationSupport.batchUpdate(conn, sql, params, batchSize, null, false);
return JdbcOperationSupport.batchUpdate(conn, sql, params, batchSize, false);
}
}
/** {@inheritDoc} */
@Override
public List<int[]> batchUpdate(String sql, @Nullable Collection<Object[]> params,
int batchSize, List<Exception> exceptions, boolean quietly)
public BatchUpdateResult batchUpdate(String sql, @Nullable Collection<Object[]> params,
int batchSize, boolean quietly)
throws SQLException {
try (Connection conn = this.dataSource.getConnection()) {
return JdbcOperationSupport
.batchUpdate(conn, sql, params, batchSize, exceptions, quietly);
.batchUpdate(conn, sql, params, batchSize, quietly);
}
}
@@ -533,20 +533,19 @@ public class SimpleJdbcTemplate implements JdbcOperations {
/** {@inheritDoc} */
@Override
public List<int[]> batchUpdate(String sql, @Nullable Collection<Object[]> params, int batchSize)
public BatchUpdateResult batchUpdate(String sql, @Nullable Collection<Object[]> params, int batchSize)
throws SQLException {
return JdbcOperationSupport.batchUpdate(this.conn, sql, params, batchSize, null, false);
return JdbcOperationSupport.batchUpdate(this.conn, sql, params, batchSize, false);
}
/** {@inheritDoc} */
@Override
public List<int[]> batchUpdate(String sql,
public BatchUpdateResult batchUpdate(String sql,
@Nullable Collection<Object[]> params,
int batchSize,
List<Exception> exceptions,
boolean quietly) throws SQLException {
return JdbcOperationSupport
.batchUpdate(this.conn, sql, params, batchSize, exceptions, quietly);
.batchUpdate(this.conn, sql, params, batchSize, quietly);
}
// #endregion

View File

@@ -32,8 +32,19 @@ public class AccountPO {
public AccountPO() {
}
public AccountPO(Long id, String username, String accountStatus, LocalDateTime createTime, Long createdBy,
LocalDateTime updateTime, Long updatedBy, Long version) {
public AccountPO(Long id, String username, String accountStatus,
Long createdBy, Long updatedBy) {
this.id = id;
this.username = username;
this.accountStatus = accountStatus;
this.createdBy = createdBy;
this.updatedBy = updatedBy;
}
public AccountPO(Long id, String username, String accountStatus,
LocalDateTime createTime, Long createdBy,
LocalDateTime updateTime, Long updatedBy,
Long version) {
this.id = id;
this.username = username;
this.accountStatus = accountStatus;

View File

@@ -0,0 +1,150 @@
/*
* Copyright 2026 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package xyz.zhouxy.jdbc.test;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static xyz.zhouxy.jdbc.ParamBuilder.buildBatchParams;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.sql.SQLException;
import java.util.List;
import java.util.Optional;
import org.h2.jdbcx.JdbcDataSource;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import com.google.common.collect.Lists;
import com.google.common.io.Resources;
import xyz.zhouxy.jdbc.BatchUpdateResult;
import xyz.zhouxy.jdbc.BatchUpdateStatus;
import xyz.zhouxy.jdbc.SimpleJdbcTemplate;
public class BatchUpdateTests {
private static SimpleJdbcTemplate jdbcTemplate;
@BeforeAll
static void initH2() throws IOException, SQLException {
JdbcDataSource dataSource = new JdbcDataSource();
dataSource.setURL("jdbc:h2:mem:testdb;DB_CLOSE_DELAY=-1;DATABASE_TO_UPPER=FALSE;MODE=MySQL");
dataSource.setUser("sa");
dataSource.setPassword("");
jdbcTemplate = new SimpleJdbcTemplate(dataSource);
// 建表
executeSqlFile("schema.sql");
}
@BeforeEach
void initData() throws SQLException {
// 初始化数据
jdbcTemplate.update("truncate table sys_account");
}
static void executeSqlFile(String filePath) throws IOException, SQLException {
String[] sqls = Resources
.toString(Resources.getResource(filePath), StandardCharsets.UTF_8)
.split(";");
for (String sql : sqls) {
jdbcTemplate.update(sql);
}
}
final List<AccountPO> accountPOs = Lists.newArrayList(
// batch 0
new AccountPO(10001L, "test_0001", "1", 1L, 1L),
new AccountPO(10002L, "test_0002", "1", 1L, 1L),
new AccountPO(10003L, "test_0003", "1", 1L, 1L),
// batch 1
new AccountPO(10004L, "test_0004", "1", 1L, 1L),
new AccountPO(10005L, "test_0005", "1", 1L, 1L),
new AccountPO(10006L, "test_0006", "1", 1L, 1L),
// batch 2
new AccountPO(10007L, "test_0007", "1", 1L, 1L),
new AccountPO(10007L, "test_*0007", "1", 1L, 1L),
// new AccountPO(10008L, "test_0008", "1", 1L, 1L),
new AccountPO(10009L, "test_0009", "1", 1L, 1L),
// batch 3
new AccountPO(10009L, "test_*0009", "1", 1L, 1L),
// new AccountPO(10010L, "test_0010", "1", 1L, 1L),
new AccountPO(10011L, "test_0011", "1", 1L, 1L),
new AccountPO(10012L, "test_0012", "1", 1L, 1L),
// batch 4
new AccountPO(10013L, "test_0013", "1", 1L, 1L)
);
@Test
void testBatchUpdate() throws SQLException {
Optional<Integer> count0 = jdbcTemplate.queryFirst("SELECT COUNT(*) FROM sys_account", (rs, i) -> rs.getInt(1));
assertEquals(0, count0.get().intValue());
BatchUpdateResult result = jdbcTemplate.batchUpdate(
"INSERT INTO sys_account (id, username, account_status, created_by, updated_by) VALUES (?, ?, ?, ?, ?)",
buildBatchParams(accountPOs,
a -> new Object[] { a.getId(), a.getUsername(), a.getAccountStatus(), a.getCreatedBy(), a.getUpdatedBy() }),
3);
assertEquals(BatchUpdateStatus.INTERRUPTED, result.getStatus());
assertEquals(13, result.getTotal());
assertEquals(5, result.getBatchCount());
assertEquals(3, result.getCompleteBatchCount());
assertEquals(2, result.getSuccessBatchCount());
assertEquals(1, result.getErrorBatchCount());
assertEquals(2, result.getRemainingBatchCount());
assertArrayEquals(new int[] { 1, 1, 1 }, result.getUpdateCounts(0));
assertArrayEquals(new int[] { 1, 1, 1 }, result.getUpdateCounts(1));
assertArrayEquals(new int[] { 1, -3, 1 }, result.getUpdateCounts(2));
assertNull(result.getUpdateCounts(3));
assertNull(result.getUpdateCounts(4));
Optional<Integer> count8 = jdbcTemplate.queryFirst("SELECT COUNT(*) FROM sys_account", (rs, i) -> rs.getInt(1));
assertEquals(8, count8.get().intValue());
}
@Test
void testBatchUpdateQuietly() throws SQLException {
Optional<Integer> count0 = jdbcTemplate.queryFirst("SELECT COUNT(*) FROM sys_account", (rs, i) -> rs.getInt(1));
assertEquals(0, count0.get().intValue());
BatchUpdateResult result = jdbcTemplate.batchUpdate(
"INSERT INTO sys_account (id, username, account_status, created_by, updated_by) VALUES (?, ?, ?, ?, ?)",
buildBatchParams(accountPOs,
a -> new Object[] { a.getId(), a.getUsername(), a.getAccountStatus(), a.getCreatedBy(), a.getUpdatedBy() }),
3,
true);
assertEquals(BatchUpdateStatus.COMPLETED_WITH_ERRORS, result.getStatus());
assertEquals(13, result.getTotal());
assertEquals(5, result.getBatchCount());
assertEquals(5, result.getCompleteBatchCount());
assertEquals(3, result.getSuccessBatchCount());
assertEquals(2, result.getErrorBatchCount());
assertEquals(0, result.getRemainingBatchCount());
assertArrayEquals(new int[] { 1, 1, 1 }, result.getUpdateCounts(0));
assertArrayEquals(new int[] { 1, 1, 1 }, result.getUpdateCounts(1));
assertArrayEquals(new int[] { 1, -3, 1 }, result.getUpdateCounts(2));
assertArrayEquals(new int[] { -3, 1, 1 }, result.getUpdateCounts(3));
assertArrayEquals(new int[] { 1 }, result.getUpdateCounts(4));
Optional<Integer> count11 = jdbcTemplate.queryFirst("SELECT COUNT(*) FROM sys_account", (rs, i) -> rs.getInt(1));
assertEquals(11, count11.get().intValue());
}
}