未验证 提交 926d7de2 编写于 作者: D DuanZhengqiang 提交者: GitHub

fix bug for ClassCastException when execute range query with different numeric type (#5770)

* fix bug for ClassCastException when execute range query in numeric type

* modify the javadoc format with one space

* blank lines start with 4 blank spaces
上级 82b9f33f
......@@ -19,22 +19,23 @@ package org.apache.shardingsphere.sharding.route.engine.condition.engine;
import com.google.common.collect.Range;
import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.sharding.rule.ShardingRule;
import org.apache.shardingsphere.sharding.strategy.route.value.ListRouteValue;
import org.apache.shardingsphere.sharding.strategy.route.value.RangeRouteValue;
import org.apache.shardingsphere.sharding.strategy.route.value.RouteValue;
import org.apache.shardingsphere.infra.exception.ShardingSphereException;
import org.apache.shardingsphere.sharding.route.engine.condition.AlwaysFalseRouteValue;
import org.apache.shardingsphere.sharding.route.engine.condition.AlwaysFalseShardingCondition;
import org.apache.shardingsphere.sharding.route.engine.condition.Column;
import org.apache.shardingsphere.sharding.route.engine.condition.ShardingCondition;
import org.apache.shardingsphere.sharding.route.engine.condition.generator.ConditionValueGeneratorFactory;
import org.apache.shardingsphere.sharding.rule.ShardingRule;
import org.apache.shardingsphere.sharding.strategy.route.value.ListRouteValue;
import org.apache.shardingsphere.sharding.strategy.route.value.RangeRouteValue;
import org.apache.shardingsphere.sharding.strategy.route.value.RouteValue;
import org.apache.shardingsphere.sql.parser.binder.metadata.schema.SchemaMetaData;
import org.apache.shardingsphere.sql.parser.binder.statement.SQLStatementContext;
import org.apache.shardingsphere.sql.parser.binder.type.WhereAvailable;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.predicate.AndPredicate;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.predicate.PredicateSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.predicate.WhereSegment;
import org.apache.shardingsphere.infra.exception.ShardingSphereException;
import org.apache.shardingsphere.sql.parser.sql.util.SafeRangeOperationUtils;
import java.util.ArrayList;
import java.util.Collection;
......@@ -168,13 +169,13 @@ public final class WhereClauseShardingConditionEngine {
}
private Range<Comparable<?>> mergeRangeRouteValues(final Range<Comparable<?>> value1, final Range<Comparable<?>> value2) {
return null == value2 ? value1 : value1.intersection(value2);
return null == value2 ? value1 : SafeRangeOperationUtils.safeIntersection(value1, value2);
}
private Collection<Comparable<?>> mergeListAndRangeRouteValues(final Collection<Comparable<?>> listValue, final Range<Comparable<?>> rangeValue) {
Collection<Comparable<?>> result = new LinkedList<>();
for (Comparable<?> each : listValue) {
if (rangeValue.contains(each)) {
if (SafeRangeOperationUtils.safeContains(rangeValue, each)) {
result.add(each);
}
}
......
......@@ -26,6 +26,7 @@ import org.apache.shardingsphere.sharding.route.engine.condition.generator.Condi
import org.apache.shardingsphere.sharding.route.engine.condition.generator.ConditionValueGenerator;
import org.apache.shardingsphere.sharding.route.spi.SPITimeService;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.predicate.value.PredicateBetweenRightValue;
import org.apache.shardingsphere.sql.parser.sql.util.SafeRangeOperationUtils;
import java.util.Date;
import java.util.List;
......@@ -41,7 +42,7 @@ public final class ConditionValueBetweenOperatorGenerator implements ConditionVa
Optional<Comparable> betweenRouteValue = new ConditionValue(predicateRightValue.getBetweenExpression(), parameters).getValue();
Optional<Comparable> andRouteValue = new ConditionValue(predicateRightValue.getAndExpression(), parameters).getValue();
if (betweenRouteValue.isPresent() && andRouteValue.isPresent()) {
return Optional.of(new RangeRouteValue<>(column.getName(), column.getTableName(), Range.closed(betweenRouteValue.get(), andRouteValue.get())));
return Optional.of(new RangeRouteValue<>(column.getName(), column.getTableName(), SafeRangeOperationUtils.safeClosed(betweenRouteValue.get(), andRouteValue.get())));
}
Date date = new SPITimeService().getTime();
if (!betweenRouteValue.isPresent() && ExpressionConditionUtils.isNowExpression(predicateRightValue.getBetweenExpression())) {
......
......@@ -24,6 +24,7 @@ import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.ExpressionSegme
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.complex.CommonExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.expr.simple.LiteralExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.segment.dml.predicate.value.PredicateBetweenRightValue;
import org.apache.shardingsphere.sql.parser.sql.util.SafeRangeOperationUtils;
import org.junit.Test;
import java.util.Calendar;
......@@ -58,6 +59,23 @@ public final class ConditionValueBetweenOperatorGeneratorTest {
assertTrue(rangeRouteValue.getValueRange().contains(and));
}
@SuppressWarnings("unchecked")
@Test
public void assertGenerateConditionValueWithDifferentNumericType() {
int between = 3;
long and = 3147483647L;
ExpressionSegment betweenSegment = new LiteralExpressionSegment(0, 0, between);
ExpressionSegment andSegment = new LiteralExpressionSegment(0, 0, and);
PredicateBetweenRightValue value = new PredicateBetweenRightValue(betweenSegment, andSegment);
Optional<RouteValue> routeValue = generator.generate(value, column, new LinkedList<>());
assertTrue(routeValue.isPresent());
RangeRouteValue<Comparable<?>> rangeRouteValue = (RangeRouteValue<Comparable<?>>) routeValue.get();
assertThat(rangeRouteValue.getColumnName(), is(column.getName()));
assertThat(rangeRouteValue.getTableName(), is(column.getTableName()));
assertTrue(SafeRangeOperationUtils.safeContains(rangeRouteValue.getValueRange(), between));
assertTrue(SafeRangeOperationUtils.safeContains(rangeRouteValue.getValueRange(), and));
}
@Test(expected = ClassCastException.class)
public void assertGenerateErrorConditionValue() {
int between = 1;
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.shardingsphere.sql.parser.sql.util;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Range;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.SneakyThrows;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
/**
* Safe range operation utility class.
*/
@AllArgsConstructor(access = AccessLevel.PRIVATE)
public final class SafeRangeOperationUtils {
/**
* Execute intersection method by safe mode.
*
* @param range range
* @param connectedRange connected range
* @return the intersection result of two ranges
*/
public static Range<Comparable<?>> safeIntersection(final Range<Comparable<?>> range, final Range<Comparable<?>> connectedRange) {
try {
return range.intersection(connectedRange);
} catch (final ClassCastException ex) {
Comparable<?> rangeLowerEndpoint = range.hasLowerBound() ? range.lowerEndpoint() : null;
Comparable<?> rangeUpperEndpoint = range.hasUpperBound() ? range.upperEndpoint() : null;
Comparable<?> connectedRangeLowerEndpoint = connectedRange.hasLowerBound() ? connectedRange.lowerEndpoint() : null;
Comparable<?> connectedRangeUpperEndpoint = connectedRange.hasUpperBound() ? connectedRange.upperEndpoint() : null;
Class<?> clazz = getTargetNumericType(Lists.newArrayList(rangeLowerEndpoint, rangeUpperEndpoint, connectedRangeLowerEndpoint, connectedRangeUpperEndpoint));
if (clazz == null) {
throw ex;
}
Range<Comparable<?>> newRange = createTargetNumericTypeRange(range, clazz);
Range<Comparable<?>> newConnectedRange = createTargetNumericTypeRange(connectedRange, clazz);
return newRange.intersection(newConnectedRange);
}
}
/**
* Execute closed method by safe mode.
*
* @param lowerEndpoint lower endpoint
* @param upperEndpoint upper endpoint
* @return new range
*/
public static Range<Comparable<?>> safeClosed(final Comparable<?> lowerEndpoint, final Comparable<?> upperEndpoint) {
try {
return Range.closed(lowerEndpoint, upperEndpoint);
} catch (final ClassCastException ex) {
Class<?> clazz = getTargetNumericType(Lists.newArrayList(lowerEndpoint, upperEndpoint));
if (clazz == null) {
throw ex;
}
return Range.closed(parseNumberByClazz(lowerEndpoint.toString(), clazz), parseNumberByClazz(upperEndpoint.toString(), clazz));
}
}
/**
* Execute contains method by safe mode.
*
* @param range range
* @param endpoint endpoint
* @return whether the endpoint is included in the range
*/
public static boolean safeContains(final Range<Comparable<?>> range, final Comparable<?> endpoint) {
try {
return range.contains(endpoint);
} catch (final ClassCastException ex) {
Comparable<?> rangeUpperEndpoint = range.hasUpperBound() ? range.upperEndpoint() : null;
Comparable<?> rangeLowerEndpoint = range.hasLowerBound() ? range.lowerEndpoint() : null;
Class<?> clazz = getTargetNumericType(Lists.newArrayList(rangeLowerEndpoint, rangeUpperEndpoint, endpoint));
if (clazz == null) {
throw ex;
}
Range<Comparable<?>> newRange = createTargetNumericTypeRange(range, clazz);
return newRange.contains(parseNumberByClazz(endpoint.toString(), clazz));
}
}
private static Range<Comparable<?>> createTargetNumericTypeRange(final Range<Comparable<?>> range, final Class<?> clazz) {
if (range.hasLowerBound() && range.hasUpperBound()) {
Comparable<?> lowerEndpoint = parseNumberByClazz(range.lowerEndpoint().toString(), clazz);
Comparable<?> upperEndpoint = parseNumberByClazz(range.upperEndpoint().toString(), clazz);
return Range.range(lowerEndpoint, range.lowerBoundType(), upperEndpoint, range.upperBoundType());
}
if (!range.hasLowerBound() && !range.hasUpperBound()) {
return Range.all();
}
if (range.hasLowerBound()) {
Comparable<?> lowerEndpoint = parseNumberByClazz(range.lowerEndpoint().toString(), clazz);
return Range.downTo(lowerEndpoint, range.lowerBoundType());
}
Comparable<?> upperEndpoint = parseNumberByClazz(range.upperEndpoint().toString(), clazz);
return Range.upTo(upperEndpoint, range.upperBoundType());
}
private static Class<?> getTargetNumericType(final List<Comparable<?>> endpoints) {
Preconditions.checkNotNull(endpoints, "getTargetNumericType param endpoints can not be null.");
Set<Class<?>> clazzSet = endpoints.stream().filter(Objects::nonNull).map(Comparable::getClass).collect(Collectors.toSet());
if (clazzSet.contains(BigDecimal.class)) {
return BigDecimal.class;
}
if (clazzSet.contains(Double.class)) {
return Double.class;
}
if (clazzSet.contains(Float.class)) {
return Float.class;
}
if (clazzSet.contains(BigInteger.class)) {
return BigInteger.class;
}
if (clazzSet.contains(Long.class)) {
return Long.class;
}
if (clazzSet.contains(Integer.class)) {
return Integer.class;
}
return null;
}
@SneakyThrows
private static Comparable<?> parseNumberByClazz(final String number, final Class<?> clazz) {
return (Comparable<?>) clazz.getConstructor(String.class).newInstance(number);
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.shardingsphere.sql.parser.sql.util;
import com.google.common.collect.BoundType;
import com.google.common.collect.Range;
import org.junit.Test;
import java.math.BigDecimal;
import java.math.BigInteger;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertThat;
public class SafeRangeOperationUtilsTest {
@Test
public void assertSafeIntersectionForInteger() {
Range<Comparable<?>> range = Range.closed(10, 2000);
Range<Comparable<?>> connectedRange = Range.closed(1500, 4000);
Range<Comparable<?>> newRange = SafeRangeOperationUtils.safeIntersection(range, connectedRange);
assertThat(newRange.lowerEndpoint(), is(1500));
assertThat(newRange.lowerBoundType(), is(BoundType.CLOSED));
assertThat(newRange.upperEndpoint(), is(2000));
assertThat(newRange.upperBoundType(), is(BoundType.CLOSED));
}
@Test
public void assertSafeIntersectionForLong() {
Range<Comparable<?>> range = Range.upTo(3147483647L, BoundType.OPEN);
Range<Comparable<?>> connectedRange = Range.downTo(3, BoundType.OPEN);
Range<Comparable<?>> newRange = SafeRangeOperationUtils.safeIntersection(range, connectedRange);
assertThat(newRange.lowerEndpoint(), is(3L));
assertThat(newRange.lowerBoundType(), is(BoundType.OPEN));
assertThat(newRange.upperEndpoint(), is(3147483647L));
assertThat(newRange.upperBoundType(), is(BoundType.OPEN));
}
@Test
public void assertSafeIntersectionForBigInteger() {
Range<Comparable<?>> range = Range.upTo(new BigInteger("131323233123211"), BoundType.CLOSED);
Range<Comparable<?>> connectedRange = Range.downTo(35, BoundType.OPEN);
Range<Comparable<?>> newRange = SafeRangeOperationUtils.safeIntersection(range, connectedRange);
assertThat(newRange.lowerEndpoint(), is(new BigInteger("35")));
assertThat(newRange.lowerBoundType(), is(BoundType.OPEN));
assertThat(newRange.upperEndpoint(), is(new BigInteger("131323233123211")));
assertThat(newRange.upperBoundType(), is(BoundType.CLOSED));
}
@Test
public void assertSafeIntersectionForFloat() {
Range<Comparable<?>> range = Range.closed(5.5F, 13.8F);
Range<Comparable<?>> connectedRange = Range.closed(7.14F, 11.3F);
Range<Comparable<?>> newRange = SafeRangeOperationUtils.safeIntersection(range, connectedRange);
assertThat(newRange.lowerEndpoint(), is(7.14F));
assertThat(newRange.lowerBoundType(), is(BoundType.CLOSED));
assertThat(newRange.upperEndpoint(), is(11.3F));
assertThat(newRange.upperBoundType(), is(BoundType.CLOSED));
}
@Test
public void assertSafeIntersectionForDouble() {
Range<Comparable<?>> range = Range.closed(1242.114, 31474836.12);
Range<Comparable<?>> connectedRange = Range.downTo(567.34F, BoundType.OPEN);
Range<Comparable<?>> newRange = SafeRangeOperationUtils.safeIntersection(range, connectedRange);
assertThat(newRange.lowerEndpoint(), is(1242.114));
assertThat(newRange.lowerBoundType(), is(BoundType.CLOSED));
assertThat(newRange.upperEndpoint(), is(31474836.12));
assertThat(newRange.upperBoundType(), is(BoundType.CLOSED));
}
@Test
public void assertSafeIntersectionForBigDecimal() {
Range<Comparable<?>> range = Range.upTo(new BigDecimal("2331.23211"), BoundType.CLOSED);
Range<Comparable<?>> connectedRange = Range.open(135.13F, 45343.23F);
Range<Comparable<?>> newRange = SafeRangeOperationUtils.safeIntersection(range, connectedRange);
assertThat(newRange.lowerEndpoint(), is(new BigDecimal("135.13")));
assertThat(newRange.lowerBoundType(), is(BoundType.OPEN));
assertThat(newRange.upperEndpoint(), is(new BigDecimal("2331.23211")));
assertThat(newRange.upperBoundType(), is(BoundType.CLOSED));
}
@Test
public void assertSafeClosedForInteger() {
Range<Comparable<?>> range = SafeRangeOperationUtils.safeClosed(12, 500);
assertThat(range.lowerEndpoint(), is(12));
assertThat(range.upperEndpoint(), is(500));
}
@Test
public void assertSafeClosedForLong() {
Range<Comparable<?>> range = SafeRangeOperationUtils.safeClosed(12, 5001L);
assertThat(range.lowerEndpoint(), is(12L));
assertThat(range.upperEndpoint(), is(5001L));
}
@Test
public void assertSafeClosedForBigInteger() {
Range<Comparable<?>> range = SafeRangeOperationUtils.safeClosed(12L, new BigInteger("12344"));
assertThat(range.lowerEndpoint(), is(new BigInteger("12")));
assertThat(range.upperEndpoint(), is(new BigInteger("12344")));
}
@Test
public void assertSafeClosedForFloat() {
Range<Comparable<?>> range = SafeRangeOperationUtils.safeClosed(4.5F, 11.13F);
assertThat(range.lowerEndpoint(), is(4.5F));
assertThat(range.upperEndpoint(), is(11.13F));
}
@Test
public void assertSafeClosedForDouble() {
Range<Comparable<?>> range = SafeRangeOperationUtils.safeClosed(5.12F, 13.75);
assertThat(range.lowerEndpoint(), is(5.12));
assertThat(range.upperEndpoint(), is(13.75));
}
@Test
public void assertSafeClosedForBigDecimal() {
Range<Comparable<?>> range = SafeRangeOperationUtils.safeClosed(5.1F, new BigDecimal("17.666"));
assertThat(range.lowerEndpoint(), is(new BigDecimal("5.1")));
assertThat(range.upperEndpoint(), is(new BigDecimal("17.666")));
}
@Test
public void assertSafeContainsForInteger() {
Range<Comparable<?>> range = Range.closed(12, 100);
assertThat(SafeRangeOperationUtils.safeContains(range, 500), is(false));
}
@Test
public void assertSafeContainsForLong() {
Range<Comparable<?>> range = Range.closed(12L, 1000L);
assertThat(SafeRangeOperationUtils.safeContains(range, 500), is(true));
}
@Test
public void assertSafeContainsForBigInteger() {
Range<Comparable<?>> range = Range.closed(new BigInteger("123"), new BigInteger("1000"));
assertThat(SafeRangeOperationUtils.safeContains(range, 510), is(true));
}
@Test
public void assertSafeContainsForFloat() {
Range<Comparable<?>> range = Range.closed(123.11F, 9999.123F);
assertThat(SafeRangeOperationUtils.safeContains(range, 510.12), is(true));
}
@Test
public void assertSafeContainsForDouble() {
Range<Comparable<?>> range = Range.closed(11.11, 9999.99);
assertThat(SafeRangeOperationUtils.safeContains(range, new BigDecimal("510.12")), is(true));
}
@Test
public void assertSafeContainsForBigDecimal() {
Range<Comparable<?>> range = Range.closed(new BigDecimal("123.11"), new BigDecimal("9999.123"));
assertThat(SafeRangeOperationUtils.safeContains(range, 510.12), is(true));
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册