Mybatis Plus樂觀鎖拋異常

Mybatis-Plus(https://github.com/baomidou/mybatis-plus)自帶的樂觀鎖插件在版本號不匹配導致更新失敗時,只會返回更新條數0,並不會拋出指定的異常。本文針對此問題對3.0版本的樂觀鎖插件進行了改造,只貼出關鍵代碼。

簡單介紹一下改造:當一次update發生時,攔截器首先判斷是否傳入了版本號字段(本代碼中是version_val,請按實際命名修改;判斷邏輯較複雜,有興趣的朋友可以看看)。如果沒有傳入版本號字段,則直接放行,按原流程執行更新;如果有,那麼當更新成功條數爲0時,會將原條件去掉版本號字段後再查詢一遍:若仍能查到記錄(查詢結果大於0),則說明是版本號不匹配導致的更新失敗,拋出樂觀鎖異常(異常類型可自行定製);若查詢結果爲0,說明按原條件本就沒有匹配的記錄,則正常返回,不拋異常。

import com.alibaba.druid.util.StringUtils;
import com.baomidou.mybatisplus.annotation.Version;
import com.baomidou.mybatisplus.core.conditions.AbstractWrapper;
import com.baomidou.mybatisplus.core.conditions.ISqlSegment;
import com.baomidou.mybatisplus.core.conditions.Wrapper;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.core.conditions.segments.MergeSegments;
import com.baomidou.mybatisplus.core.conditions.segments.NormalSegmentList;
import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
import com.baomidou.mybatisplus.core.metadata.TableFieldInfo;
import com.baomidou.mybatisplus.core.metadata.TableInfo;
import com.baomidou.mybatisplus.core.toolkit.Constants;
import com.baomidou.mybatisplus.core.toolkit.ReflectionKit;
import com.baomidou.mybatisplus.core.toolkit.StringPool;
import com.baomidou.mybatisplus.core.toolkit.TableInfoHelper;

import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.binding.MapperMethod;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.SimpleExecutor;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.*;

import java.lang.reflect.Field;
import java.math.BigDecimal;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.sql.Timestamp;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;


/**
 * MyBatis interceptor on {@link Executor#update} that applies MyBatis-Plus style optimistic
 * locking and, unlike the stock plugin, throws a {@code BizException}
 * ({@code ErrorCodeEnum.OPTIMISTICLOCKER_EXCEPTION_UPDATE_FAIL}) when an update affects zero rows
 * because the version value no longer matches.
 *
 * <p>Flow for an UPDATE whose parameter map carries an entity under {@link Constants#ENTITY}:
 * <ol>
 *   <li>Locate the original version value: the entity's {@code @Version} field, else the wrapper
 *       entity's {@code versionVal}, else a {@code version_val = #{...}} condition in the wrapper
 *       SQL. If none is found the statement runs unchanged.</li>
 *   <li>For {@code update(entity, wrapper)} a version condition is added to the wrapper; for other
 *       methods (e.g. {@code updateById}) the entity is replaced by a map that also carries the
 *       original version under the {@code MybatisConstants.MP_OPTLOCK_*} keys (presumably read by
 *       the injected SQL — confirm). The version is bumped via {@link #getUpdatedVersionVal}.</li>
 *   <li>If zero rows were updated, the conditions are re-queried without the version restriction;
 *       if a row still matches, the failure is a version conflict and the exception is thrown.</li>
 * </ol>
 *
 * <p>NOTE(review): several steps read MyBatis-Plus wrapper internals by reflection
 * ({@code expression}, {@code normal}, and a lambda's captured {@code arg$3}); this is tied to the
 * MyBatis-Plus 3.0.x implementation and should be re-checked on every upgrade.
 */
@Intercepts({@Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class})})
@Slf4j
public class OptimisticLockerExceptionInterceptor implements Interceptor {

    /**
     * Alias of {@link MybatisConstants#MP_OPTLOCK_VERSION_ORIGINAL}: the key under which the
     * original (pre-update) version value is put into the entity parameter map.
     *
     * @deprecated use {@link MybatisConstants#MP_OPTLOCK_VERSION_ORIGINAL} directly
     */
    @Deprecated
    public static final String MP_OPTLOCK_VERSION_ORIGINAL = MybatisConstants.MP_OPTLOCK_VERSION_ORIGINAL;
    /**
     * Alias of {@link MybatisConstants#MP_OPTLOCK_VERSION_COLUMN}: the key under which the version
     * column name is put into the entity parameter map.
     *
     * @deprecated use {@link MybatisConstants#MP_OPTLOCK_VERSION_COLUMN} directly
     */
    @Deprecated
    public static final String MP_OPTLOCK_VERSION_COLUMN = MybatisConstants.MP_OPTLOCK_VERSION_COLUMN;
    /**
     * Alias of {@link MybatisConstants#MP_OPTLOCK_ET_ORIGINAL}: the key under which the original
     * entity object is put into the entity parameter map.
     *
     * @deprecated use {@link MybatisConstants#MP_OPTLOCK_ET_ORIGINAL} directly
     */
    @Deprecated
    public static final String MP_OPTLOCK_ET_ORIGINAL = MybatisConstants.MP_OPTLOCK_ET_ORIGINAL;

    private static final String NAME_ENTITY = Constants.ENTITY;
    private static final String NAME_ENTITY_WRAPPER = Constants.WRAPPER;
    private static final String PARAM_UPDATE_METHOD_NAME = "update";
    /** Per-entity-class cache of the {@code @Version} field; classes without one are not stored. */
    private final Map<Class<?>, EntityField> versionFieldCache = new ConcurrentHashMap<>();
    /** Per-entity-class cache of all reflected fields, used to build the non-wrapper parameter map. */
    private final Map<Class<?>, List<EntityField>> entityFieldsCache = new ConcurrentHashMap<>();

    /**
     * Prefix of the MyBatis placeholders generated for wrapper parameters.
     */
    private static final String EW_PARAMNAME_VALUE_PAIRS = "#{ew.paramNameValuePairs.";

    /**
     * Matches a wrapper parameter placeholder {@code #{ew.paramNameValuePairs.KEY[,options]}};
     * group 1 is {@code KEY}.
     */
    private static final Pattern EW_PARAM_PLACEHOLDER =
            Pattern.compile("#\\{ew\\.paramNameValuePairs\\.([^,}]+)[^}]*\\}");

    /**
     * Version column name (adjust to the actual schema).
     */
    private static final String VERSION_FIELD_NAME = "version_val";
    /**
     * Matches {@code version_val = #{...}}.
     */
    private static final Pattern VERSION_PATTERN =
            Pattern.compile(Pattern.quote(VERSION_FIELD_NAME) + " = #\\{[^\\}]+\\}");


    /**
     * Applies optimistic locking to entity updates and turns a version-conflict miss into a
     * {@code BizException}.
     *
     * @param invocation the intercepted {@code Executor.update(MappedStatement, Object)} call
     * @return the executor's result; statements this interceptor does not handle are passed
     *         through unchanged
     * @throws Throwable anything the underlying executor throws, reflection/JDBC failures of the
     *         conflict check, or {@code BizException} when a version conflict is detected
     */
    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public Object intercept(Invocation invocation) throws Throwable {
        Object[] args = invocation.getArgs();
        MappedStatement ms = (MappedStatement) args[0];
        if (SqlCommandType.UPDATE != ms.getSqlCommandType()) {
            return invocation.proceed();
        }
        Object param = args[1];
        if (!(param instanceof Map)) {
            return invocation.proceed();
        }
        Map map = (Map) param;
        // updateById(et), update(et, wrapper)
        Object et = map.getOrDefault(NAME_ENTITY, null);
        if (et == null) {
            return invocation.proceed();
        }

        String methodId = ms.getId();
        String methodName = methodId.substring(methodId.lastIndexOf(StringPool.DOT) + 1);
        Class<?> entityClass = et.getClass();
        TableInfo tableInfo = TableInfoHelper.getTableInfo(entityClass);
        if (tableInfo == null) {
            // Not a MyBatis-Plus table entity: nothing to version (avoids an NPE further down).
            return invocation.proceed();
        }
        EntityField versionField = this.getVersionField(entityClass, tableInfo);
        if (versionField == null) {
            return invocation.proceed();
        }

        Field field = versionField.getField();
        Object originalVersionVal = field.get(et);
        if (originalVersionVal == null) {
            originalVersionVal = findVersionInWrapper((Wrapper) map.getOrDefault(NAME_ENTITY_WRAPPER, null), versionField);
        }
        if (originalVersionVal == null) {
            // No version supplied anywhere: run as a plain update without optimistic locking.
            return invocation.proceed();
        }

        Object updatedVersionVal = getUpdatedVersionVal(originalVersionVal);
        Object resultObj;
        if (PARAM_UPDATE_METHOD_NAME.equals(methodName)) {
            // update(entity, wrapper)
            // mapper.update(updEntity, QueryWrapper<>(whereEntity);
            AbstractWrapper<?, ?, ?> ew = (AbstractWrapper<?, ?, ?>) map.getOrDefault(NAME_ENTITY_WRAPPER, null);
            if (ew == null) {
                UpdateWrapper<?> uw = new UpdateWrapper<>();
                uw.eq(versionField.getColumnName(), originalVersionVal);
                map.put(NAME_ENTITY_WRAPPER, uw);
            } else {
                pinVersionCondition(ew, versionField, originalVersionVal);
            }
            field.set(et, updatedVersionVal);
            resultObj = invocation.proceed();
        } else {
            map.put(NAME_ENTITY, buildVersionedEntityMap(et, entityClass, versionField, originalVersionVal, updatedVersionVal));
            resultObj = invocation.proceed();
        }

        if (resultObj instanceof Integer) {
            int effRow = (Integer) resultObj;
            if (updatedVersionVal != null && effRow != 0) {
                // updated version value set to entity.
                field.set(et, updatedVersionVal);
            } else if (effRow == 0) {
                log.debug("有樂觀鎖");
                checkVersionConflict(invocation, map, et, tableInfo);
            }
        }
        return resultObj;
    }

    /**
     * Looks for the original version value in the update wrapper when the entity does not carry
     * one.
     *
     * @param ew           the wrapper from the parameter map, may be {@code null}
     * @param versionField the entity's version field descriptor
     * @return the version value from the wrapper entity or from a {@code version_val = #{...}}
     *         condition, or {@code null} if none is found
     */
    private Object findVersionInWrapper(Wrapper<?> ew, EntityField versionField) {
        if (ew == null) {
            return null;
        }
        Object whereEntity = ew.getEntity();
        if (whereEntity instanceof BaseEntity && ((BaseEntity) whereEntity).getVersionVal() != null) {
            return ((BaseEntity) whereEntity).getVersionVal();
        }
        String sqlSegment = ew.getSqlSegment();
        if (sqlSegment == null || !sqlSegment.contains(versionField.getColumnName())) {
            return null;
        }
        Matcher versionMatcher = VERSION_PATTERN.matcher(sqlSegment);
        if (!versionMatcher.find()) {
            return null;
        }
        Matcher keyMatcher = EW_PARAM_PLACEHOLDER.matcher(versionMatcher.group());
        if (!keyMatcher.find()) {
            return null;
        }
        return ((AbstractWrapper<?, ?, ?>) ew).getParamNameValuePairs().get(keyMatcher.group(1));
    }

    /**
     * Forces the wrapper to filter on the original version. Any existing
     * {@code version_val = #{...}} segment gets its bound value overwritten, and an explicit
     * {@code <column> = original} condition is appended.
     */
    private void pinVersionCondition(AbstractWrapper<?, ?, ?> ew, EntityField versionField,
                                     Object originalVersionVal) throws ReflectiveOperationException {
        for (ISqlSegment segment : getNormalSegments(ew)) {
            if (VERSION_PATTERN.matcher(segment.getSqlSegment()).find()) {
                // NOTE(review): relies on the compiler-generated lambda field arg$3 holding the
                // apply() value array; fragile across MyBatis-Plus / JDK versions — confirm.
                Object[] capturedValues = (Object[]) readField(segment, "arg$3");
                capturedValues[0] = originalVersionVal;
            }
        }
        ew.apply(versionField.getColumnName() + " = {0}", originalVersionVal);
    }

    /**
     * Builds the replacement parameter for non-{@code update} methods. It contains every field of
     * the entity, with the version field set to the bumped value, plus the original version, the
     * version column and the original entity under the {@code MybatisConstants.MP_OPTLOCK_*} keys.
     */
    private Map<String, Object> buildVersionedEntityMap(Object et, Class<?> entityClass, EntityField versionField,
                                                        Object originalVersionVal, Object updatedVersionVal)
            throws IllegalAccessException {
        List<EntityField> fields = entityFieldsCache.computeIfAbsent(entityClass, this::getFieldsFromClazz);
        Map<String, Object> entityMap = new HashMap<>(fields.size());
        for (EntityField ef : fields) {
            Field fd = ef.getField();
            entityMap.put(fd.getName(), fd.get(et));
        }
        entityMap.put(versionField.getField().getName(), updatedVersionVal);
        entityMap.put(MybatisConstants.MP_OPTLOCK_VERSION_ORIGINAL, originalVersionVal);
        entityMap.put(MybatisConstants.MP_OPTLOCK_VERSION_COLUMN, versionField.getColumnName());
        entityMap.put(MybatisConstants.MP_OPTLOCK_ET_ORIGINAL, et);
        return entityMap;
    }

    /**
     * Called after an update affected zero rows. It re-checks whether the target row exists when
     * the version condition is ignored, and throws {@code BizException} if it does.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private void checkVersionConflict(Invocation invocation, Map<?, ?> map, Object et, TableInfo tableInfo)
            throws ReflectiveOperationException, SQLException {
        Wrapper<?> ew = (Wrapper<?>) map.getOrDefault(NAME_ENTITY_WRAPPER, null);
        if (ew == null) {
            // Wrapper-less update such as updateById(et): look the row up again via selectById.
            BaseEntity baseEntity = (BaseEntity) et;
            if (baseEntity.getVersionVal() != null && baseEntity.selectById(baseEntity) != null) {
                throw new BizException(ErrorCodeEnum.OPTIMISTICLOCKER_EXCEPTION_UPDATE_FAIL);
            }
            return;
        }
        if (ew.getEntity() != null) {
            // Where-entity wrapper: query with the same entity minus its version.
            // NOTE: this clears versionVal on the caller's where-entity object.
            BaseEntity entity = (BaseEntity) ew.getEntity();
            entity.setVersionVal(null);
            Object selectResult = ((BaseEntity) et).selectOne(new QueryWrapper().setEntity(entity));
            if (selectResult != null) {
                throw new BizException(ErrorCodeEnum.OPTIMISTICLOCKER_EXCEPTION_UPDATE_FAIL);
            }
            return;
        }
        if (ew instanceof AbstractWrapper
                && countRowsIgnoringVersion(invocation, (AbstractWrapper<?, ?, ?>) ew, tableInfo) > 0) {
            throw new BizException(ErrorCodeEnum.OPTIMISTICLOCKER_EXCEPTION_UPDATE_FAIL);
        }
    }

    /**
     * Replays the wrapper's WHERE segments as {@code SELECT COUNT(*)}, with every version
     * condition replaced by {@code 1 = 1}. Wrapper parameters are bound through a
     * {@link PreparedStatement} rather than concatenated into the SQL.
     *
     * @return the number of matching rows, or 0 when the wrapper has no WHERE segments
     */
    private long countRowsIgnoringVersion(Invocation invocation, AbstractWrapper<?, ?, ?> ew, TableInfo tableInfo)
            throws ReflectiveOperationException, SQLException {
        NormalSegmentList segments = getNormalSegments(ew);
        if (segments.isEmpty()) {
            return 0L;
        }
        Map<String, Object> paramPairs = ew.getParamNameValuePairs();
        List<Object> bindValues = new ArrayList<>();
        StringBuilder sql = new StringBuilder("SELECT COUNT(*) FROM ")
                .append(tableInfo.getTableName())
                .append(" WHERE ");
        int versionIndex = -10;
        for (int i = 0; i < segments.size(); i++) {
            // getSqlSegment() may (re)register parameter names in paramPairs, so read it before lookup.
            String s = segments.get(i).getSqlSegment();
            if (VERSION_PATTERN.matcher(s).find()) {
                sql.append(" 1 = 1 ");
                continue;
            }
            if (VERSION_FIELD_NAME.equals(s)) {
                versionIndex = i;
            }
            // "<version column> <op> <value>" becomes "1 <op> 1" so the version is ignored.
            if (i == versionIndex || i == versionIndex + 2) {
                sql.append(1);
            } else {
                appendWithPlaceholders(sql, s, paramPairs, bindValues);
            }
            sql.append(' ');
        }

        // The connection belongs to the executor's transaction: use it, but never close it here.
        Connection connection = ((Executor) invocation.getTarget()).getTransaction().getConnection();
        try (PreparedStatement ps = connection.prepareStatement(sql.toString())) {
            for (int i = 0; i < bindValues.size(); i++) {
                ps.setObject(i + 1, toJdbcValue(bindValues.get(i)));
            }
            try (ResultSet rs = ps.executeQuery()) {
                // getLong copes with Long, Integer and BigDecimal (e.g. Oracle) count types.
                return rs.next() ? rs.getLong(1) : 0L;
            }
        }
    }

    /**
     * Appends {@code segment} to {@code sql}, replacing every wrapper placeholder with {@code ?}
     * and collecting the corresponding value (possibly {@code null}) into {@code bindValues}.
     */
    private static void appendWithPlaceholders(StringBuilder sql, String segment,
                                               Map<String, Object> paramPairs, List<Object> bindValues) {
        Matcher m = EW_PARAM_PLACEHOLDER.matcher(segment);
        int last = 0;
        while (m.find()) {
            sql.append(segment, last, m.start()).append('?');
            bindValues.add(paramPairs.get(m.group(1)));
            last = m.end();
        }
        sql.append(segment, last, segment.length());
    }

    /**
     * Converts java.time values to their java.sql counterparts so that drivers without JDBC 4.2
     * support can bind them. Other values are returned unchanged.
     */
    private static Object toJdbcValue(Object value) {
        if (value instanceof LocalDateTime) {
            return Timestamp.valueOf((LocalDateTime) value);
        }
        if (value instanceof LocalDate) {
            return java.sql.Date.valueOf((LocalDate) value);
        }
        return value;
    }

    /**
     * Reads the wrapper's internal list of WHERE segments ({@code expression.normal}).
     */
    private NormalSegmentList getNormalSegments(Object wrapper) throws ReflectiveOperationException {
        MergeSegments expression = (MergeSegments) readField(wrapper, "expression");
        return (NormalSegmentList) readField(expression, "normal");
    }

    /**
     * Reads a possibly private field declared on {@code target}'s class or one of its superclasses.
     */
    private Object readField(Object target, String fieldName) throws ReflectiveOperationException {
        Field f = getDeclaredField(target.getClass(), fieldName);
        f.setAccessible(true);
        return f.get(target);
    }

    /**
     * This method provides the control for version value.<BR>
     * Returned value type must be the same as original one.
     * Long/Integer are incremented by one; Date, Timestamp and LocalDateTime become "now";
     * any other type is returned unchanged.
     *
     * @param originalVersionVal the current (non-null) version value
     * @return updated version val
     */
    protected Object getUpdatedVersionVal(Object originalVersionVal) {
        Class<?> versionValClass = originalVersionVal.getClass();
        // getClass() never returns a primitive class, so only the boxed types need checking.
        if (Long.class.equals(versionValClass)) {
            return ((Long) originalVersionVal) + 1;
        } else if (Integer.class.equals(versionValClass)) {
            return ((Integer) originalVersionVal) + 1;
        } else if (Date.class.equals(versionValClass)) {
            return new Date();
        } else if (Timestamp.class.equals(versionValClass)) {
            return new Timestamp(System.currentTimeMillis());
        } else if (LocalDateTime.class.equals(versionValClass)) {
            return LocalDateTime.now();
        }
        //not supported type, return original val.
        return originalVersionVal;
    }

    /**
     * Wraps only {@link Executor} targets with this interceptor.
     */
    @Override
    public Object plugin(Object target) {
        if (target instanceof Executor) {
            return Plugin.wrap(target, this);
        }
        return target;
    }

    /**
     * No configurable properties.
     */
    @Override
    public void setProperties(Properties properties) {
        // to do nothing
    }

    /**
     * Returns the cached {@code @Version} field of {@code parameterClass}, or {@code null} if it
     * has none.
     */
    private EntityField getVersionField(Class<?> parameterClass, TableInfo tableInfo) {
        return versionFieldCache.computeIfAbsent(parameterClass, key -> getVersionFieldRegular(parameterClass, tableInfo));
    }

    /**
     * Uses reflection to find the first {@code @Version}-annotated field of the entity class,
     * walking up superclasses. The field is paired with its column name from {@code tableInfo},
     * which may be {@code null} if no matching table field exists.
     *
     * @param parameterClass entity class
     * @param tableInfo      MyBatis-Plus table metadata of the entity
     * @return the version field, or {@code null} if the class hierarchy has none
     */
    private EntityField getVersionFieldRegular(Class<?> parameterClass, TableInfo tableInfo) {
        if (Object.class.equals(parameterClass)) {
            return null;
        }
        return ReflectionKit.getFieldList(parameterClass).stream()
                .filter(e -> e.isAnnotationPresent(Version.class))
                .map(field -> {
                    field.setAccessible(true);
                    String column = tableInfo.getFieldList().stream()
                            .filter(e -> field.getName().equals(e.getProperty()))
                            .map(TableFieldInfo::getColumn)
                            .findFirst()
                            .orElse(null);
                    return new EntityField(field, true, column);
                })
                .findFirst()
                .orElseGet(() -> this.getVersionFieldRegular(parameterClass.getSuperclass(), tableInfo));
    }

    /**
     * Reflects all fields of {@code parameterClass} (made accessible), flagging the
     * {@code @Version} one.
     */
    private List<EntityField> getFieldsFromClazz(Class<?> parameterClass) {
        return ReflectionKit.getFieldList(parameterClass).stream().map(field -> {
            field.setAccessible(true);
            return new EntityField(field, field.isAnnotationPresent(Version.class));
        }).collect(Collectors.toList());
    }

    /**
     * A reflected entity field, whether it is the version field, and its column name (if known).
     * Static: it never needs the enclosing interceptor instance.
     */
    @Data
    private static class EntityField {

        private Field field;
        private boolean version;
        private String columnName;

        EntityField(Field field, boolean version) {
            this.field = field;
            this.version = version;
        }

        public EntityField(Field field, boolean version, String columnName) {
            this.field = field;
            this.version = version;
            this.columnName = columnName;
        }
    }

    /**
     * Finds a field declared on {@code clazz} or on the nearest superclass that declares it.
     *
     * @throws NoSuchFieldException if no class in the hierarchy declares {@code fieldName}
     */
    private Field getDeclaredField(Class<?> clazz, String fieldName) throws NoSuchFieldException {
        for (Class<?> current = clazz; current != null; current = current.getSuperclass()) {
            try {
                return current.getDeclaredField(fieldName);
            } catch (NoSuchFieldException ignored) {
                // not declared here; keep walking up the hierarchy
            }
        }
        throw new NoSuchFieldException(fieldName + " not found in " + clazz + " or its superclasses");
    }

}

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章