diff --git a/hutool-db/src/main/java/cn/hutool/db/dialect/impl/MysqlDialect.java b/hutool-db/src/main/java/cn/hutool/db/dialect/impl/MysqlDialect.java index d10e811fb..3ce7a199a 100644 --- a/hutool-db/src/main/java/cn/hutool/db/dialect/impl/MysqlDialect.java +++ b/hutool-db/src/main/java/cn/hutool/db/dialect/impl/MysqlDialect.java @@ -1,10 +1,16 @@ package cn.hutool.db.dialect.impl; +import cn.hutool.db.Entity; import cn.hutool.db.Page; +import cn.hutool.db.StatementUtil; import cn.hutool.db.dialect.DialectName; import cn.hutool.db.sql.SqlBuilder; import cn.hutool.db.sql.Wrapper; +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.SQLException; + /** * MySQL方言 * @author loolly @@ -21,9 +27,24 @@ public class MysqlDialect extends AnsiSqlDialect{ protected SqlBuilder wrapPageSql(SqlBuilder find, Page page) { return find.append(" LIMIT ").append(page.getStartPosition()).append(", ").append(page.getPageSize()); } - + @Override public String dialectName() { return DialectName.MYSQL.toString(); } + + /** + * 构建用于upsert的PreparedStatement + * + * @param conn 数据库连接对象 + * @param entity 数据实体类(包含表名) + * @param keys 查找字段 + * @return PreparedStatement + * @throws SQLException SQL执行异常 + */ + @Override + public PreparedStatement psForUpsert(Connection conn, Entity entity, String... keys) throws SQLException { + final SqlBuilder upsert = SqlBuilder.create(wrapper).upsert(entity, this.dialectName(),keys); + return StatementUtil.prepareStatement(conn, upsert); + } } diff --git a/hutool-db/src/main/java/cn/hutool/db/dialect/impl/PostgresqlDialect.java b/hutool-db/src/main/java/cn/hutool/db/dialect/impl/PostgresqlDialect.java index d7109e3c2..82f5fe373 100644 --- a/hutool-db/src/main/java/cn/hutool/db/dialect/impl/PostgresqlDialect.java +++ b/hutool-db/src/main/java/cn/hutool/db/dialect/impl/PostgresqlDialect.java @@ -1,8 +1,15 @@ package cn.hutool.db.dialect.impl; +import cn.hutool.db.Entity; +import cn.hutool.db.StatementUtil; import cn.hutool.db.dialect.DialectName; +import cn.hutool.db.sql.SqlBuilder; import cn.hutool.db.sql.Wrapper; +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.SQLException; + /** * Postgree方言 @@ -20,4 +27,22 @@ public class PostgresqlDialect extends AnsiSqlDialect{ public String dialectName() { return DialectName.POSTGREESQL.name(); } + + /** + * 构建用于upsert的PreparedStatement + * + * @param conn 数据库连接对象 + * @param entity 数据实体类(包含表名) + * @param keys 查找字段 必须是有唯一索引的列且不能为空 + * @return PreparedStatement + * @throws SQLException SQL执行异常 + */ + @Override + public PreparedStatement psForUpsert(Connection conn, Entity entity, String... keys) throws SQLException { + if (null==keys || keys.length==0){ + throw new SQLException("keys不能为空"); + } + final SqlBuilder upsert = SqlBuilder.create(wrapper).upsert(entity, this.dialectName(),keys); + return StatementUtil.prepareStatement(conn, upsert); + } } diff --git a/hutool-db/src/main/java/cn/hutool/db/sql/SqlBuilder.java b/hutool-db/src/main/java/cn/hutool/db/sql/SqlBuilder.java index cee913a2e..ac666054b 100644 --- a/hutool-db/src/main/java/cn/hutool/db/sql/SqlBuilder.java +++ b/hutool-db/src/main/java/cn/hutool/db/sql/SqlBuilder.java @@ -202,7 +202,7 @@ public class SqlBuilder implements Builder { * * @param entity 实体 * @param dialectName 方言名,用于对特殊数据库特殊处理 - * @param keys 根据何字段来确认唯一性,不传则用主键 + * @param keys 根据何字段来确认唯一性,不传则用主键 * @return 自己 * @since 5.7.21 */ @@ -249,6 +249,12 @@ public class SqlBuilder implements Builder { // issue#1656@Github Phoenix兼容 if (DialectName.PHOENIX.match(dialectName)) { sql.append("UPSERT INTO ").append(entity.getTableName()); + } else if (DialectName.MYSQL.match(dialectName)) { + sql.append("INSERT INTO "); + sql.append(entity.getTableName()) + .append(" (").append(fieldsPart).append(") VALUES (") + .append(placeHolder).append(") on duplicate key update ") + .append(ArrayUtil.join(ArrayUtil.map(entity.keySet().toArray(), String.class, (k) -> k + "=values(" + k + ")"), ",")); } else if (DialectName.H2.match(dialectName)) { sql.append("MERGE INTO ").append(entity.getTableName()); if (null != keys && keys.length > 0) { @@ -257,6 +263,14 @@ public class SqlBuilder implements Builder { .append(placeHolder) .append(")"); } + } else if (DialectName.POSTGREESQL.match(dialectName)) { + sql.append("INSERT INTO "); + sql.append(entity.getTableName()) + .append(" (").append(fieldsPart).append(") VALUES (") + .append(placeHolder).append(") on conflict (") + .append(ArrayUtil.join(keys,",")) + .append(") do update set ") + .append(ArrayUtil.join(ArrayUtil.map(entity.keySet().toArray(), String.class, (k) -> k + "=excluded." + k ), ",")); } else { throw new RuntimeException(dialectName + " not support yet"); } diff --git a/hutool-db/src/test/java/cn/hutool/db/MySQLTest.java b/hutool-db/src/test/java/cn/hutool/db/MySQLTest.java index e3e72fba1..8ecebb74e 100644 --- a/hutool-db/src/test/java/cn/hutool/db/MySQLTest.java +++ b/hutool-db/src/test/java/cn/hutool/db/MySQLTest.java @@ -1,6 +1,9 @@ package cn.hutool.db; import cn.hutool.core.lang.Console; +import cn.hutool.core.util.ArrayUtil; +import org.junit.Assert; +import org.junit.BeforeClass; import org.junit.Ignore; import org.junit.Test; @@ -9,11 +12,16 @@ import java.util.List; /** * MySQL操作单元测试 - * - * @author looly * + * @author looly */ public class MySQLTest { + @BeforeClass + @Ignore + public static void createTable() throws SQLException { + Db db = Db.use("mysql"); + db.executeBatch("drop table if exists testuser", "CREATE TABLE if not exists `testuser` ( `id` int(11) NOT NULL, `account` varchar(255) DEFAULT NULL, `pass` varchar(255) DEFAULT NULL, PRIMARY KEY (`id`) ) ENGINE=InnoDB DEFAULT CHARSET=utf8"); + } @Test @Ignore @@ -34,13 +42,13 @@ public class MySQLTest { * * @throws SQLException SQL异常 */ - @Test(expected=SQLException.class) + @Test(expected = SQLException.class) @Ignore public void txTest() throws SQLException { Db.use("mysql").tx(db -> { int update = db.update(Entity.create("user").set("text", "描述100"), Entity.create().set("id", 100)); db.update(Entity.create("user").set("text", "描述101"), Entity.create().set("id", 101)); - if(1 == update) { + if (1 == update) { // 手动指定异常,然后测试回滚触发 throw new RuntimeException("Error"); } @@ -64,4 +72,14 @@ public class MySQLTest { Console.log(all); } + @Test + @Ignore + public void upsertTest() throws SQLException { + Db db = Db.use("mysql"); + db.insert(Entity.create("testuser").set("id", 1).set("account", "ice").set("pass", "123456")); + db.upsert(Entity.create("testuser").set("id", 1).set("account", "icefairy").set("pass", "a123456")); + Entity user = db.get(Entity.create("testuser").set("id", 1)); + System.out.println("user======="+user.getStr("account")+"___"+user.getStr("pass")); + Assert.assertEquals(user.getStr("account"), new String("icefairy")); + } } diff --git a/hutool-db/src/test/java/cn/hutool/db/PostgreTest.java b/hutool-db/src/test/java/cn/hutool/db/PostgreTest.java index a19559a7e..250930efe 100644 --- a/hutool-db/src/test/java/cn/hutool/db/PostgreTest.java +++ b/hutool-db/src/test/java/cn/hutool/db/PostgreTest.java @@ -2,6 +2,7 @@ package cn.hutool.db; import java.sql.SQLException; +import org.junit.Assert; import org.junit.Ignore; import org.junit.Test; @@ -9,9 +10,8 @@ import cn.hutool.core.lang.Console; /** * PostgreSQL 单元测试 - * - * @author looly * + * @author looly */ public class PostgreTest { @@ -34,4 +34,16 @@ public class PostgreTest { Console.log(entity.get("id")); } } + + @Test + @Ignore + public void upsertTest() throws SQLException { + Db db = Db.use("postgre"); + db.executeBatch("drop table if exists ctest", + "create table if not exists \"ctest\" ( \"id\" serial4, \"t1\" varchar(255) COLLATE \"pg_catalog\".\"default\", \"t2\" varchar(255) COLLATE \"pg_catalog\".\"default\", \"t3\" varchar(255) COLLATE \"pg_catalog\".\"default\", CONSTRAINT \"ctest_pkey\" PRIMARY KEY (\"id\") ) "); + db.insert(Entity.create("ctest").set("id", 1).set("t1", "111").set("t2", "222").set("t3", "333")); + db.upsert(Entity.create("ctest").set("id", 1).set("t1", "new111").set("t2", "new222").set("t3", "bew333"),"id"); + Entity et=db.get(Entity.create("ctest").set("id", 1)); + Assert.assertEquals("new111",et.getStr("t1")); + } }