未验证 提交 d5737751 编写于 作者: Z Zhengqiang Duan 提交者: GitHub

support subquery sharding route without sharding column (#8105)

上级 0b36da6c
......@@ -24,9 +24,9 @@ import org.apache.shardingsphere.infra.binder.statement.dml.InsertStatementConte
import org.apache.shardingsphere.infra.binder.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.config.properties.ConfigurationProperties;
import org.apache.shardingsphere.infra.hint.HintManager;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
import org.apache.shardingsphere.infra.route.SQLRouter;
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
import org.apache.shardingsphere.sharding.api.config.strategy.sharding.HintShardingStrategyConfiguration;
import org.apache.shardingsphere.sharding.constant.ShardingOrder;
import org.apache.shardingsphere.sharding.route.engine.condition.ShardingCondition;
......@@ -34,6 +34,7 @@ import org.apache.shardingsphere.sharding.route.engine.condition.ShardingConditi
import org.apache.shardingsphere.sharding.route.engine.condition.engine.ShardingConditionEngine;
import org.apache.shardingsphere.sharding.route.engine.condition.engine.ShardingConditionEngineFactory;
import org.apache.shardingsphere.sharding.route.engine.condition.value.ListShardingConditionValue;
import org.apache.shardingsphere.sharding.route.engine.condition.value.RangeShardingConditionValue;
import org.apache.shardingsphere.sharding.route.engine.condition.value.ShardingConditionValue;
import org.apache.shardingsphere.sharding.route.engine.type.ShardingRouteEngineFactory;
import org.apache.shardingsphere.sharding.route.engine.validator.ShardingStatementValidator;
......@@ -99,7 +100,6 @@ public final class ShardingSQLRouter implements SQLRouter<ShardingRule> {
return;
}
}
Preconditions.checkState(!shardingConditions.getConditions().isEmpty(), "Must have sharding column with subquery.");
if (shardingConditions.getConditions().size() > 1) {
Preconditions.checkState(isSameShardingCondition(shardingRule, shardingConditions), "Sharding value must same with subquery.");
}
......@@ -127,27 +127,39 @@ public final class ShardingSQLRouter implements SQLRouter<ShardingRule> {
for (int i = 0; i < shardingCondition1.getValues().size(); i++) {
ShardingConditionValue shardingConditionValue1 = shardingCondition1.getValues().get(i);
ShardingConditionValue shardingConditionValue2 = shardingCondition2.getValues().get(i);
if (!isSameShardingConditionValue(shardingRule, (ListShardingConditionValue) shardingConditionValue1, (ListShardingConditionValue) shardingConditionValue2)) {
if (!isSameShardingConditionValue(shardingRule, shardingConditionValue1, shardingConditionValue2)) {
return false;
}
}
return true;
}
private boolean isSameShardingConditionValue(final ShardingRule shardingRule, final ListShardingConditionValue shardingConditionValue1, final ListShardingConditionValue shardingConditionValue2) {
private boolean isSameShardingConditionValue(final ShardingRule shardingRule, final ShardingConditionValue shardingConditionValue1, final ShardingConditionValue shardingConditionValue2) {
return isSameLogicTable(shardingRule, shardingConditionValue1, shardingConditionValue2) && shardingConditionValue1.getColumnName().equals(shardingConditionValue2.getColumnName())
&& SafeNumberOperationUtils.safeEquals(shardingConditionValue1.getValues(), shardingConditionValue2.getValues());
&& isSameValue(shardingConditionValue1, shardingConditionValue2);
}
private boolean isSameLogicTable(final ShardingRule shardingRule, final ListShardingConditionValue shardingValue1, final ListShardingConditionValue shardingValue2) {
private boolean isSameLogicTable(final ShardingRule shardingRule, final ShardingConditionValue shardingValue1, final ShardingConditionValue shardingValue2) {
return shardingValue1.getTableName().equals(shardingValue2.getTableName()) || isBindingTable(shardingRule, shardingValue1, shardingValue2);
}
private boolean isBindingTable(final ShardingRule shardingRule, final ListShardingConditionValue shardingValue1, final ListShardingConditionValue shardingValue2) {
private boolean isBindingTable(final ShardingRule shardingRule, final ShardingConditionValue shardingValue1, final ShardingConditionValue shardingValue2) {
Optional<BindingTableRule> bindingRule = shardingRule.findBindingTableRule(shardingValue1.getTableName());
return bindingRule.isPresent() && bindingRule.get().hasLogicTable(shardingValue2.getTableName());
}
@SuppressWarnings({"rawtypes", "unchecked"})
private boolean isSameValue(final ShardingConditionValue shardingConditionValue1, final ShardingConditionValue shardingConditionValue2) {
if (shardingConditionValue1 instanceof ListShardingConditionValue && shardingConditionValue2 instanceof ListShardingConditionValue) {
return SafeNumberOperationUtils.safeCollectionEquals(
((ListShardingConditionValue) shardingConditionValue1).getValues(), ((ListShardingConditionValue) shardingConditionValue2).getValues());
} else if (shardingConditionValue1 instanceof RangeShardingConditionValue && shardingConditionValue2 instanceof RangeShardingConditionValue) {
return SafeNumberOperationUtils.safeRangeEquals(
((RangeShardingConditionValue) shardingConditionValue1).getValueRange(), ((RangeShardingConditionValue) shardingConditionValue2).getValueRange());
}
return false;
}
private void mergeShardingConditions(final ShardingConditions shardingConditions) {
if (shardingConditions.getConditions().size() > 1) {
ShardingCondition shardingCondition = shardingConditions.getConditions().remove(shardingConditions.getConditions().size() - 1);
......
......@@ -149,16 +149,6 @@ public final class SubqueryRouteTest extends AbstractSQLRouteTest {
assertRoute(sql, parameters);
}
@Test(expected = IllegalStateException.class)
public void assertSubqueryWithoutHint() {
List<Object> parameters = new LinkedList<>();
parameters.add(1);
parameters.add(2);
parameters.add(5);
String sql = "select count(*) from t_hint_test where user_id = (select t_hint_test from t_hint_test where user_id in (?,?,?)) ";
assertRoute(sql, parameters);
}
@Test
public void assertSubqueryWithHint() {
HintManager hintManager = HintManager.getInstance();
......
......@@ -49,12 +49,8 @@ public final class SafeNumberOperationUtils {
try {
return range.intersection(connectedRange);
} catch (final ClassCastException ex) {
Comparable<?> rangeLowerEndpoint = range.hasLowerBound() ? range.lowerEndpoint() : null;
Comparable<?> rangeUpperEndpoint = range.hasUpperBound() ? range.upperEndpoint() : null;
Comparable<?> connectedRangeLowerEndpoint = connectedRange.hasLowerBound() ? connectedRange.lowerEndpoint() : null;
Comparable<?> connectedRangeUpperEndpoint = connectedRange.hasUpperBound() ? connectedRange.upperEndpoint() : null;
Class<?> clazz = getTargetNumericType(Lists.newArrayList(rangeLowerEndpoint, rangeUpperEndpoint, connectedRangeLowerEndpoint, connectedRangeUpperEndpoint));
if (clazz == null) {
Class<?> clazz = getRangeTargetNumericType(range, connectedRange);
if (null == clazz) {
throw ex;
}
Range<Comparable<?>> newRange = createTargetNumericTypeRange(range, clazz);
......@@ -75,7 +71,7 @@ public final class SafeNumberOperationUtils {
return Range.closed(lowerEndpoint, upperEndpoint);
} catch (final ClassCastException ex) {
Class<?> clazz = getTargetNumericType(Lists.newArrayList(lowerEndpoint, upperEndpoint));
if (clazz == null) {
if (null == clazz) {
throw ex;
}
return Range.closed(parseNumberByClazz(lowerEndpoint.toString(), clazz), parseNumberByClazz(upperEndpoint.toString(), clazz));
......@@ -96,14 +92,31 @@ public final class SafeNumberOperationUtils {
Comparable<?> rangeUpperEndpoint = range.hasUpperBound() ? range.upperEndpoint() : null;
Comparable<?> rangeLowerEndpoint = range.hasLowerBound() ? range.lowerEndpoint() : null;
Class<?> clazz = getTargetNumericType(Lists.newArrayList(rangeLowerEndpoint, rangeUpperEndpoint, endpoint));
if (clazz == null) {
if (null == clazz) {
throw ex;
}
Range<Comparable<?>> newRange = createTargetNumericTypeRange(range, clazz);
return newRange.contains(parseNumberByClazz(endpoint.toString(), clazz));
}
}
/**
* Execute range equals method by safe mode.
*
* @param sourceRange source range
* @param targetRange target range
* @return whether the source range and target range are same
*/
public static boolean safeRangeEquals(final Range<Comparable<?>> sourceRange, final Range<Comparable<?>> targetRange) {
Class<?> clazz = getRangeTargetNumericType(sourceRange, targetRange);
if (null == clazz) {
return sourceRange.equals(targetRange);
}
Range<Comparable<?>> newSourceRange = createTargetNumericTypeRange(sourceRange, clazz);
Range<Comparable<?>> newTargetRange = createTargetNumericTypeRange(targetRange, clazz);
return newSourceRange.equals(newTargetRange);
}
/**
* Execute collection equals method by safe mode.
*
......@@ -111,7 +124,7 @@ public final class SafeNumberOperationUtils {
* @param targetCollection target collection
* @return whether the element in source collection and target collection are all same
*/
public static boolean safeEquals(final Collection<Comparable<?>> sourceCollection, final Collection<Comparable<?>> targetCollection) {
public static boolean safeCollectionEquals(final Collection<Comparable<?>> sourceCollection, final Collection<Comparable<?>> targetCollection) {
List<Comparable<?>> collection = Lists.newArrayList(sourceCollection);
collection.addAll(targetCollection);
Class<?> clazz = getTargetNumericType(collection);
......@@ -123,6 +136,14 @@ public final class SafeNumberOperationUtils {
return sourceClazzCollection.equals(targetClazzCollection);
}
private static Class<?> getRangeTargetNumericType(final Range<Comparable<?>> sourceRange, final Range<Comparable<?>> targetRange) {
Comparable<?> sourceRangeLowerEndpoint = sourceRange.hasLowerBound() ? sourceRange.lowerEndpoint() : null;
Comparable<?> sourceRangeUpperEndpoint = sourceRange.hasUpperBound() ? sourceRange.upperEndpoint() : null;
Comparable<?> targetRangeLowerEndpoint = targetRange.hasLowerBound() ? targetRange.lowerEndpoint() : null;
Comparable<?> targetRangeUpperEndpoint = targetRange.hasUpperBound() ? targetRange.upperEndpoint() : null;
return getTargetNumericType(Lists.newArrayList(sourceRangeLowerEndpoint, sourceRangeUpperEndpoint, targetRangeLowerEndpoint, targetRangeUpperEndpoint));
}
private static Range<Comparable<?>> createTargetNumericTypeRange(final Range<Comparable<?>> range, final Class<?> clazz) {
if (range.hasLowerBound() && range.hasUpperBound()) {
Comparable<?> lowerEndpoint = parseNumberByClazz(range.lowerEndpoint().toString(), clazz);
......
......@@ -178,44 +178,74 @@ public final class SafeNumberOperationUtilsTest {
}
@Test
public void assertSafeEqualsForInteger() {
public void assertSafeCollectionEqualsForInteger() {
List<Comparable<?>> sourceCollection = Lists.newArrayList(10, 12);
List<Comparable<?>> targetCollection = Lists.newArrayList(10, 12);
assertTrue(SafeNumberOperationUtils.safeEquals(sourceCollection, targetCollection));
assertTrue(SafeNumberOperationUtils.safeCollectionEquals(sourceCollection, targetCollection));
}
@Test
public void assertSafeEqualsForLong() {
public void assertSafeCollectionEqualsForLong() {
List<Comparable<?>> sourceCollection = Lists.newArrayList(10, 12);
List<Comparable<?>> targetCollection = Lists.newArrayList(10L, 12L);
assertTrue(SafeNumberOperationUtils.safeEquals(sourceCollection, targetCollection));
assertTrue(SafeNumberOperationUtils.safeCollectionEquals(sourceCollection, targetCollection));
}
@Test
public void assertSafeEqualsForBigInteger() {
public void assertSafeCollectionEqualsForBigInteger() {
List<Comparable<?>> sourceCollection = Lists.newArrayList(10, 12);
List<Comparable<?>> targetCollection = Lists.newArrayList(BigInteger.valueOf(10), BigInteger.valueOf(12L));
assertTrue(SafeNumberOperationUtils.safeEquals(sourceCollection, targetCollection));
assertTrue(SafeNumberOperationUtils.safeCollectionEquals(sourceCollection, targetCollection));
}
@Test
public void assertSafeEqualsForFloat() {
public void assertSafeCollectionEqualsForFloat() {
List<Comparable<?>> sourceCollection = Lists.newArrayList(10.01F, 12.01F);
List<Comparable<?>> targetCollection = Lists.newArrayList(10.01F, 12.01F);
assertTrue(SafeNumberOperationUtils.safeEquals(sourceCollection, targetCollection));
assertTrue(SafeNumberOperationUtils.safeCollectionEquals(sourceCollection, targetCollection));
}
@Test
public void assertSafeEqualsForDouble() {
public void assertSafeCollectionEqualsForDouble() {
List<Comparable<?>> sourceCollection = Lists.newArrayList(10.01, 12.01);
List<Comparable<?>> targetCollection = Lists.newArrayList(10.01F, 12.01);
assertTrue(SafeNumberOperationUtils.safeEquals(sourceCollection, targetCollection));
assertTrue(SafeNumberOperationUtils.safeCollectionEquals(sourceCollection, targetCollection));
}
@Test
public void assertSafeEqualsForBigDecimal() {
public void assertSafeCollectionEqualsForBigDecimal() {
List<Comparable<?>> sourceCollection = Lists.newArrayList(10.01, 12.01);
List<Comparable<?>> targetCollection = Lists.newArrayList(BigDecimal.valueOf(10.01), BigDecimal.valueOf(12.01));
assertTrue(SafeNumberOperationUtils.safeEquals(sourceCollection, targetCollection));
assertTrue(SafeNumberOperationUtils.safeCollectionEquals(sourceCollection, targetCollection));
}
@Test
public void assertSafeRangeEqualsForInteger() {
assertTrue(SafeNumberOperationUtils.safeRangeEquals(Range.greaterThan(1), Range.greaterThan(1L)));
}
@Test
public void assertSafeRangeEqualsForLong() {
assertTrue(SafeNumberOperationUtils.safeRangeEquals(Range.greaterThan(1L), Range.greaterThan(BigInteger.ONE)));
}
@Test
public void assertSafeRangeEqualsForBigInteger() {
assertTrue(SafeNumberOperationUtils.safeRangeEquals(Range.greaterThan(BigInteger.ONE), Range.greaterThan(1)));
}
@Test
public void assertSafeRangeEqualsForFloat() {
assertTrue(SafeNumberOperationUtils.safeRangeEquals(Range.greaterThan(1.1F), Range.greaterThan(1.1)));
}
@Test
public void assertSafeRangeEqualsForDouble() {
assertTrue(SafeNumberOperationUtils.safeRangeEquals(Range.greaterThan(1.1), Range.greaterThan(BigDecimal.valueOf(1.1))));
}
@Test
public void assertSafeRangeEqualsForBigDecimal() {
assertTrue(SafeNumberOperationUtils.safeRangeEquals(Range.greaterThan(BigDecimal.valueOf(1.1)), Range.greaterThan(1.1F)));
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册