001/*
002 * Copyright 2015-2018 Transmogrify LLC.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package com.pyranid;
018
019import javax.sql.DataSource;
020import java.sql.Connection;
021import java.sql.PreparedStatement;
022import java.sql.ResultSet;
023import java.sql.SQLException;
024import java.util.ArrayDeque;
025import java.util.ArrayList;
026import java.util.Arrays;
027import java.util.Deque;
028import java.util.List;
029import java.util.Optional;
030import java.util.logging.Logger;
031
032import static java.lang.String.format;
033import static java.lang.System.nanoTime;
034import static java.util.Collections.emptyList;
035import static java.util.Objects.requireNonNull;
036import static java.util.logging.Level.WARNING;
037
038/**
039 * @author <a href="http://revetkn.com">Mark Allen</a>
040 * @since 1.0.0
041 */
042public class Database {
043  private static final ThreadLocal<Deque<Transaction>> TRANSACTION_STACK_HOLDER = ThreadLocal
044      .withInitial(() -> new ArrayDeque<>());
045
046  private final DataSource dataSource;
047  private final InstanceProvider instanceProvider;
048  private final PreparedStatementBinder preparedStatementBinder;
049  private final ResultSetMapper resultSetMapper;
050  private final StatementLogger statementLogger;
051  private final Logger logger = Logger.getLogger(Database.class.getName());
052
053  public static Builder forDataSource(DataSource dataSource) {
054    return new Builder(requireNonNull(dataSource));
055  }
056
057  protected Database(Builder builder) {
058    requireNonNull(builder);
059    this.dataSource = requireNonNull(builder.dataSource);
060    this.instanceProvider = requireNonNull(builder.instanceProvider);
061    this.preparedStatementBinder = requireNonNull(builder.preparedStatementBinder);
062    this.resultSetMapper = requireNonNull(builder.resultSetMapper);
063    this.statementLogger = requireNonNull(builder.statementLogger);
064  }
065
066  public static class Builder {
067    private final DataSource dataSource;
068    private final DatabaseType databaseType;
069    private InstanceProvider instanceProvider;
070    private PreparedStatementBinder preparedStatementBinder;
071    private ResultSetMapper resultSetMapper;
072    private StatementLogger statementLogger;
073
074    // See build() method for explanation of why we keep track of whether these fields have changed
075    private final InstanceProvider initialInstanceProvider;
076    private final ResultSetMapper initialResultSetMapper;
077
078    private Builder(DataSource dataSource) {
079      this.dataSource = requireNonNull(dataSource);
080      this.databaseType = DatabaseType.fromDataSource(dataSource);
081
082      this.preparedStatementBinder = new DefaultPreparedStatementBinder(this.databaseType);
083      this.statementLogger = new DefaultStatementLogger();
084
085      this.instanceProvider = new DefaultInstanceProvider();
086      this.initialInstanceProvider = this.instanceProvider;
087
088      this.resultSetMapper = new DefaultResultSetMapper(this.databaseType, this.instanceProvider);
089      this.initialResultSetMapper = resultSetMapper;
090    }
091
092    public Builder instanceProvider(InstanceProvider instanceProvider) {
093      this.instanceProvider = requireNonNull(instanceProvider);
094      return this;
095    }
096
097    public Builder preparedStatementBinder(PreparedStatementBinder preparedStatementBinder) {
098      this.preparedStatementBinder = requireNonNull(preparedStatementBinder);
099      return this;
100    }
101
102    public Builder resultSetMapper(ResultSetMapper resultSetMapper) {
103      this.resultSetMapper = requireNonNull(resultSetMapper);
104      return this;
105    }
106
107    public Builder statementLogger(StatementLogger statementLogger) {
108      this.statementLogger = requireNonNull(statementLogger);
109      return this;
110    }
111
112    public Database build() {
113      // A little sleight-of-hand to make the 99% case easier for users...
114      // If at build time the InstanceProvider has been changed but the ResultSetMapper is unchanged,
115      // wire the custom InstanceProvider into the DefaultResultSetMapper
116      if (this.instanceProvider != this.initialInstanceProvider && this.resultSetMapper == this.initialResultSetMapper)
117        this.resultSetMapper = new DefaultResultSetMapper(this.databaseType, this.instanceProvider);
118
119      return new Database(this);
120    }
121  }
122
123  public Optional<Transaction> currentTransaction() {
124    Deque<Transaction> transactionStack = TRANSACTION_STACK_HOLDER.get();
125    return Optional.ofNullable(transactionStack.size() == 0 ? null : transactionStack.peek());
126  }
127
128  public void transaction(TransactionalOperation transactionalOperation) {
129    requireNonNull(transactionalOperation);
130
131    transaction(() -> {
132      transactionalOperation.perform();
133      return null;
134    });
135  }
136
137  public <T> T transaction(ReturningTransactionalOperation<T> transactionalOperation) {
138    requireNonNull(transactionalOperation);
139    return transaction(TransactionIsolation.DEFAULT, transactionalOperation);
140  }
141
142  public <T> T transaction(TransactionIsolation transactionIsolation,
143                           ReturningTransactionalOperation<T> transactionalOperation) {
144    requireNonNull(transactionIsolation);
145    requireNonNull(transactionalOperation);
146
147    Transaction transaction = new Transaction(dataSource, transactionIsolation);
148    TRANSACTION_STACK_HOLDER.get().push(transaction);
149    boolean committed = false;
150    boolean rolledBack = false;
151
152    try {
153      T returnValue = transactionalOperation.perform();
154
155      if (transaction.isRollbackOnly()) {
156        transaction.rollback();
157        rolledBack = true;
158      } else {
159        transaction.commit();
160        committed = true;
161      }
162
163      return returnValue;
164    } catch (RuntimeException e) {
165      try {
166        transaction.rollback();
167        rolledBack = true;
168      } catch (Exception rollbackException) {
169        logger.log(WARNING, "Unable to roll back transaction", rollbackException);
170      }
171
172      throw e;
173    } catch (Throwable t) {
174      try {
175        transaction.rollback();
176        rolledBack = true;
177      } catch (Exception rollbackException) {
178        logger.log(WARNING, "Unable to roll back transaction", rollbackException);
179      }
180
181      throw new RuntimeException(t);
182    } finally {
183      TRANSACTION_STACK_HOLDER.get().pop();
184
185      try {
186        try {
187          if (transaction.initialAutoCommit().isPresent() && transaction.initialAutoCommit().get())
188            // Autocommit was true initially, so restoring to true now that transaction has completed
189            transaction.setAutoCommit(true);
190        } finally {
191          if (transaction.hasConnection())
192            closeConnection(transaction.connection());
193        }
194      } finally {
195        // Execute any user-supplied post-execution hooks
196        if (committed) {
197          for (Runnable postCommitOperation : transaction.postCommitOperations())
198            postCommitOperation.run();
199        } else if (rolledBack) {
200          for (Runnable postRollbackOperation : transaction.postRollbackOperations())
201            postRollbackOperation.run();
202        }
203      }
204    }
205  }
206
207  protected void closeConnection(Connection connection) {
208    requireNonNull(connection);
209
210    try {
211      connection.close();
212    } catch (SQLException e) {
213      throw new DatabaseException("Unable to close database connection", e);
214    }
215  }
216
217  public void participate(Transaction transaction, TransactionalOperation transactionalOperation) {
218    requireNonNull(transaction);
219    requireNonNull(transactionalOperation);
220
221    participate(transaction, () -> {
222      transactionalOperation.perform();
223      return null;
224    });
225  }
226
227  public <T> T participate(Transaction transaction, ReturningTransactionalOperation<T> transactionalOperation) {
228    requireNonNull(transaction);
229    requireNonNull(transactionalOperation);
230
231    TRANSACTION_STACK_HOLDER.get().push(transaction);
232
233    try {
234      return transactionalOperation.perform();
235    } catch (RuntimeException e) {
236      transaction.setRollbackOnly(true);
237      throw e;
238    } catch (Throwable t) {
239      transaction.setRollbackOnly(true);
240      throw new RuntimeException(t);
241    } finally {
242      TRANSACTION_STACK_HOLDER.get().pop();
243    }
244  }
245
246  public <T> Optional<T> queryForObject(String sql, Class<T> objectType, Object... parameters) {
247    requireNonNull(sql);
248    requireNonNull(objectType);
249
250    return queryForObject(sql, null, objectType, parameters);
251  }
252
253  public <T> Optional<T> queryForObject(String sql, StatementMetadata statementMetadata, Class<T> objectType, Object... parameters) {
254    requireNonNull(sql);
255    requireNonNull(objectType);
256
257    List<T> list = queryForList(sql, statementMetadata, objectType, parameters);
258
259    if (list.size() > 1)
260      throw new DatabaseException(format("Expected 1 row in resultset but got %s instead", list.size()));
261
262    return Optional.ofNullable(list.size() == 0 ? null : list.get(0));
263  }
264
265  protected static class DatabaseOperationResult {
266    private final Optional<Long> executionTime;
267    private final Optional<Long> resultSetMappingTime;
268
269    public DatabaseOperationResult(Optional<Long> executionTime, Optional<Long> resultSetMappingTime) {
270      this.executionTime = requireNonNull(executionTime);
271      this.resultSetMappingTime = requireNonNull(resultSetMappingTime);
272    }
273
274    public Optional<Long> executionTime() {
275      return executionTime;
276    }
277
278    public Optional<Long> resultSetMappingTime() {
279      return resultSetMappingTime;
280    }
281  }
282
283  @FunctionalInterface
284  protected static interface DatabaseOperation {
285    DatabaseOperationResult perform(PreparedStatement preparedStatement) throws Exception;
286  }
287
288  @FunctionalInterface
289  protected static interface PreparedStatementBindingOperation {
290    void perform(PreparedStatement preparedStatement) throws Exception;
291  }
292
293  protected void performDatabaseOperation(String sql, StatementMetadata statementMetadata, Object[] parameters, DatabaseOperation databaseOperation) {
294    requireNonNull(sql);
295    requireNonNull(databaseOperation);
296
297    performDatabaseOperation(sql, statementMetadata, parameters, (preparedStatement) -> {
298      if (parameters != null && parameters.length > 0)
299        preparedStatementBinder().bind(preparedStatement, Arrays.asList(parameters));
300    }, databaseOperation);
301  }
302
303  protected void performDatabaseOperation(String sql, StatementMetadata statementMetadata, Object[] parameters,
304                                          PreparedStatementBindingOperation preparedStatementBindingOperation, DatabaseOperation databaseOperation) {
305    requireNonNull(sql);
306    requireNonNull(preparedStatementBindingOperation);
307    requireNonNull(databaseOperation);
308
309    long startTime = nanoTime();
310    Optional<Long> connectionAcquisitionTime = Optional.empty();
311    Optional<Long> preparationTime = Optional.empty();
312    Optional<Long> executionTime = Optional.empty();
313    Optional<Long> resultSetMappingTime = Optional.empty();
314    Optional<Exception> exception = Optional.empty();
315
316    Connection connection = null;
317
318    try {
319      boolean alreadyHasConnection = currentTransaction().isPresent() && currentTransaction().get().hasConnection();
320      connection = acquireConnection();
321      connectionAcquisitionTime = alreadyHasConnection ? Optional.empty() : Optional.of(nanoTime() - startTime);
322      startTime = nanoTime();
323
324      try (PreparedStatement preparedStatement = connection.prepareStatement(sql)) {
325        preparedStatementBindingOperation.perform(preparedStatement);
326        preparationTime = Optional.of(nanoTime() - startTime);
327
328        DatabaseOperationResult databaseOperationResult = databaseOperation.perform(preparedStatement);
329        executionTime = databaseOperationResult.executionTime();
330        resultSetMappingTime = databaseOperationResult.resultSetMappingTime();
331      }
332    } catch (DatabaseException e) {
333      exception = Optional.of(e);
334      throw e;
335    } catch (Exception e) {
336      exception = Optional.of(e);
337      throw new DatabaseException(e);
338    } finally {
339      try {
340        // If this was a single-shot operation (not in a transaction), close the connection
341        if (connection != null && !currentTransaction().isPresent())
342          closeConnection(connection);
343      } finally {
344        StatementLog statementLog =
345            StatementLog.forSql(sql)
346                .parameters(parameters == null ? emptyList() : Arrays.asList(parameters))
347                .connectionAcquisitionTime(connectionAcquisitionTime)
348                .preparationTime(preparationTime)
349                .executionTime(executionTime)
350                .resultSetMappingTime(resultSetMappingTime)
351                .exception(exception)
352                .statementMetadata(Optional.ofNullable(statementMetadata))
353                .build();
354
355        statementLogger().log(statementLog);
356      }
357    }
358  }
359
360  public <T> List<T> queryForList(String sql, Class<T> elementType, Object... parameters) {
361    requireNonNull(sql);
362    requireNonNull(elementType);
363
364    return queryForList(sql, null, elementType, parameters);
365  }
366
367  public <T> List<T> queryForList(String sql, StatementMetadata statementMetadata, Class<T> elementType, Object... parameters) {
368    requireNonNull(sql);
369    requireNonNull(elementType);
370
371    List<T> list = new ArrayList<>();
372
373    performDatabaseOperation(sql, statementMetadata, parameters, (PreparedStatement preparedStatement) -> {
374      long startTime = nanoTime();
375
376      try (ResultSet resultSet = preparedStatement.executeQuery()) {
377        Long executionTime = nanoTime() - startTime;
378        startTime = nanoTime();
379
380        while (resultSet.next()) {
381          T listElement = resultSetMapper().map(resultSet, elementType);
382          list.add(listElement);
383        }
384
385        Long resultSetMappingTime = nanoTime() - startTime;
386        return new DatabaseOperationResult(Optional.of(executionTime), Optional.of(resultSetMappingTime));
387      }
388    });
389
390    return list;
391  }
392
393  private static class ResultHolder<T> {
394    T value;
395  }
396
397  public long execute(String sql, Object... parameters) {
398    requireNonNull(sql);
399    return execute(sql, null, parameters);
400  }
401
402  public long execute(String sql, StatementMetadata statementMetadata, Object... parameters) {
403    requireNonNull(sql);
404
405    ResultHolder<Long> resultHolder = new ResultHolder<>();
406
407    performDatabaseOperation(sql, statementMetadata, parameters, (PreparedStatement preparedStatement) -> {
408      long startTime = nanoTime();
409      // TODO: allow users to specify that they want support for executeLargeUpdate()
410      // Not everyone implements it currently
411      resultHolder.value = (long) preparedStatement.executeUpdate();
412      return new DatabaseOperationResult(Optional.of(nanoTime() - startTime), Optional.empty());
413    });
414
415    return resultHolder.value;
416  }
417
418  public <T> Optional<T> executeReturning(String sql, Class<T> returnType, Object... parameters) {
419    requireNonNull(sql);
420    requireNonNull(returnType);
421
422    return executeReturning(sql, null, returnType, parameters);
423  }
424
425  public <T> Optional<T> executeReturning(String sql, StatementMetadata statementMetadata, Class<T> returnType, Object... parameters) {
426    requireNonNull(sql);
427    requireNonNull(returnType);
428
429    ResultHolder<T> resultHolder = new ResultHolder<>();
430
431    performDatabaseOperation(sql, statementMetadata, parameters, (PreparedStatement preparedStatement) -> {
432      long startTime = nanoTime();
433
434      try (ResultSet resultSet = preparedStatement.executeQuery()) {
435        Long executionTime = nanoTime() - startTime;
436        startTime = nanoTime();
437
438        if (resultSet.next())
439          resultHolder.value = resultSetMapper().map(resultSet, returnType);
440
441        Long resultSetMappingTime = nanoTime() - startTime;
442        return new DatabaseOperationResult(Optional.of(executionTime), Optional.of(resultSetMappingTime));
443      }
444    });
445
446    return Optional.ofNullable(resultHolder.value);
447  }
448
449  public long[] executeBatch(String sql, List<List<Object>> parameterGroups) {
450    requireNonNull(sql);
451    requireNonNull(parameterGroups);
452
453    return executeBatch(sql, null, parameterGroups);
454  }
455
456  public long[] executeBatch(String sql, StatementMetadata statementMetadata, List<List<Object>> parameterGroups) {
457    requireNonNull(sql);
458    requireNonNull(parameterGroups);
459
460    ResultHolder<long[]> resultHolder = new ResultHolder<>();
461
462    performDatabaseOperation(sql, statementMetadata, parameterGroups.toArray(), (preparedStatement) -> {
463      for (List<Object> parameterGroup : parameterGroups) {
464        if (parameterGroup != null && parameterGroup.size() > 0)
465          preparedStatementBinder().bind(preparedStatement, parameterGroup);
466
467        preparedStatement.addBatch();
468      }
469    }, (PreparedStatement preparedStatement) -> {
470      long startTime = nanoTime();
471      // TODO: allow users to specify that they want support for executeLargeBatch()
472      // Not everyone implements it currently
473      int[] result = preparedStatement.executeBatch();
474      long[] longResult = new long[result.length];
475
476      for (int i = 0; i < result.length; ++i)
477        longResult[i] = result[i];
478
479      resultHolder.value = longResult;
480      return new DatabaseOperationResult(Optional.of(nanoTime() - startTime), Optional.empty());
481    });
482
483    return resultHolder.value;
484  }
485
486  protected Connection acquireConnection() {
487    Optional<Transaction> transaction = currentTransaction();
488
489    if (transaction.isPresent())
490      return transaction.get().connection();
491
492    try {
493      return dataSource.getConnection();
494    } catch (SQLException e) {
495      throw new DatabaseException("Unable to acquire database connection", e);
496    }
497  }
498
499  protected DataSource dataSource() {
500    return dataSource;
501  }
502
503  protected InstanceProvider instanceProvider() {
504    return instanceProvider;
505  }
506
507  protected PreparedStatementBinder preparedStatementBinder() {
508    return preparedStatementBinder;
509  }
510
511  protected ResultSetMapper resultSetMapper() {
512    return resultSetMapper;
513  }
514
515  protected StatementLogger statementLogger() {
516    return statementLogger;
517  }
518}