001/*
002 * Copyright (C) 2014 Jörg Prante
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *      http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016package org.xbib.elasticsearch.plugin.jdbc.client;
017
018import org.elasticsearch.action.bulk.BulkProcessor;
019import org.elasticsearch.action.bulk.BulkRequest;
020import org.elasticsearch.common.logging.ESLogger;
021import org.elasticsearch.common.logging.ESLoggerFactory;
022import org.elasticsearch.common.unit.TimeValue;
023
024import java.lang.reflect.Field;
025import java.lang.reflect.Method;
026import java.util.concurrent.Semaphore;
027import java.util.concurrent.TimeUnit;
028
029public class BulkProcessorHelper {
030
031    private final static ESLogger logger = ESLoggerFactory.getLogger(BulkProcessorHelper.class.getSimpleName());
032
033    public static void flush(BulkProcessor bulkProcessor) {
034        try {
035            Field field = bulkProcessor.getClass().getDeclaredField("bulkRequest");
036            if (field != null) {
037                field.setAccessible(true);
038                BulkRequest bulkRequest = (BulkRequest) field.get(bulkProcessor);
039                if (bulkRequest.numberOfActions() > 0) {
040                    Method method = bulkProcessor.getClass().getDeclaredMethod("execute");
041                    if (method != null) {
042                        method.setAccessible(true);
043                        method.invoke(bulkProcessor);
044                    }
045                }
046            }
047        } catch (Throwable e) {
048            logger.error(e.getMessage(), e);
049        }
050    }
051
052    public static boolean waitFor(BulkProcessor bulkProcessor, TimeValue maxWait) {
053        Semaphore semaphore = null;
054        boolean acquired = false;
055        try {
056            Field field = bulkProcessor.getClass().getDeclaredField("semaphore");
057            if (field != null) {
058                field.setAccessible(true);
059                Field concurrentField = bulkProcessor.getClass().getDeclaredField("concurrentRequests");
060                concurrentField.setAccessible(true);
061                int concurrency = concurrentField.getInt(bulkProcessor);
062                // concurreny == 1 means there is no concurrency (default start value)
063                if (concurrency > 1) {
064                    semaphore = (Semaphore) field.get(bulkProcessor);
065                    acquired = semaphore.tryAcquire(concurrency, maxWait.getMillis(), TimeUnit.MILLISECONDS);
066                    return semaphore.availablePermits() == concurrency;
067                }
068            }
069        } catch (InterruptedException e) {
070            Thread.currentThread().interrupt();
071            logger.warn("interrupted");
072        } catch (Throwable e) {
073            logger.error(e.getMessage(), e);
074        } finally {
075            if (semaphore != null && acquired) {
076                semaphore.release();
077            }
078        }
079        return false;
080    }
081}