public static void processSkewJoin(JoinOperator, Task<?>, ParseContext)

in ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/GenMRSkewJoinProcessor.java [110:365]

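Rewrites a reduce-side join for runtime skew handling: rows whose keys turn out to be skewed are spilled to per-table temp directories, and a conditional set of follow-up map-join tasks joins the spilled rows afterwards.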

  public static void processSkewJoin(JoinOperator joinOp,
      Task<?> currTask, ParseContext parseCtx)
      throws SemanticException {

    // We are trying to add map joins to handle skew keys, and map join
    // currently does not work with outer joins
    if (!GenMRSkewJoinProcessor.skewJoinEnabled(parseCtx.getConf(), joinOp)) {
      return;
    }

    List<Task<?>> children = currTask.getChildTasks();

    Path baseTmpDir = parseCtx.getContext().getMRTmpPath();

    JoinDesc joinDescriptor = joinOp.getConf();
    Map<Byte, List<ExprNodeDesc>> joinValues = joinDescriptor.getExprs();
    int numAliases = joinValues.size();

    Map<Byte, Path> bigKeysDirMap = new HashMap<Byte, Path>();
    Map<Byte, Map<Byte, Path>> smallKeysDirMap = new HashMap<Byte, Map<Byte, Path>>();
    Map<Byte, Path> skewJoinJobResultsDir = new HashMap<Byte, Path>();
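    // For each join input (tag) allocate: a dir for its rows with skewed
    // (big) keys, one dir per other input for the matching (small-key) rows,
    // and a dir for the results of that input's follow-up skew-join job.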
    Byte[] tags = joinDescriptor.getTagOrder();
    for (int i = 0; i < numAliases; i++) {
      Byte alias = tags[i];
      bigKeysDirMap.put(alias, getBigKeysDir(baseTmpDir, alias));
      Map<Byte, Path> smallKeysMap = new HashMap<Byte, Path>();
      smallKeysDirMap.put(alias, smallKeysMap);
      for (Byte src2 : tags) {
        if (!src2.equals(alias)) {
          smallKeysMap.put(src2, getSmallKeysDir(baseTmpDir, alias, src2));
        }
      }
      skewJoinJobResultsDir.put(alias, getBigKeysSkewJoinResultDir(baseTmpDir,
          alias));
    }

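    // Configure the original reduce-side join to spill rows whose per-key
    // row count exceeds hive.skewjoin.key into the directories above.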
    joinDescriptor.setHandleSkewJoin(true);
    joinDescriptor.setBigKeysDirMap(bigKeysDirMap);
    joinDescriptor.setSmallKeysDirMap(smallKeysDirMap);
    joinDescriptor.setSkewKeyDefinition(HiveConf.getIntVar(parseCtx.getConf(),
        HiveConf.ConfVars.HIVE_SKEWJOIN_KEY));

    HashMap<Path, Task<?>> bigKeysDirToTaskMap =
      new HashMap<Path, Task<?>>();
    List<Serializable> listWorks = new ArrayList<Serializable>();
    List<Task<?>> listTasks = new ArrayList<Task<?>>();
    MapredWork currPlan = (MapredWork) currTask.getWork();

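    // The join-key column names and types come from a clone of the reduce
    // key TableDesc; they are appended to every spilled row below.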
    TableDesc keyTblDesc = (TableDesc) currPlan.getReduceWork().getKeyDesc().clone();
    List<String> joinKeys = Utilities
        .getColumnNames(keyTblDesc.getProperties());
    List<String> joinKeyTypes = Utilities.getColumnTypes(keyTblDesc
        .getProperties());

    Map<Byte, TableDesc> tableDescList = new HashMap<Byte, TableDesc>();
    Map<Byte, RowSchema> rowSchemaList = new HashMap<Byte, RowSchema>();
    Map<Byte, List<ExprNodeDesc>> newJoinValues = new HashMap<Byte, List<ExprNodeDesc>>();
    Map<Byte, List<ExprNodeDesc>> newJoinKeys = new HashMap<Byte, List<ExprNodeDesc>>();
    // used to create the MapJoinDesc; entries must be in tag order, so the
    // list is pre-filled with placeholders and assigned by position below
    List<TableDesc> newJoinValueTblDesc = new ArrayList<TableDesc>();

    for (Byte tag : tags) {
      newJoinValueTblDesc.add(null);
    }

    for (int i = 0; i < numAliases; i++) {
      Byte alias = tags[i];
      List<ExprNodeDesc> valueCols = joinValues.get(alias);
      String colNames = "";
      String colTypes = "";
      int columnSize = valueCols.size();
      List<ExprNodeDesc> newValueExpr = new ArrayList<ExprNodeDesc>();
      List<ExprNodeDesc> newKeyExpr = new ArrayList<ExprNodeDesc>();
      ArrayList<ColumnInfo> columnInfos = new ArrayList<ColumnInfo>();

      boolean first = true;
      for (int k = 0; k < columnSize; k++) {
        TypeInfo type = valueCols.get(k).getTypeInfo();
        String newColName = i + "_VALUE_" + k; // any name will do
        ColumnInfo columnInfo = new ColumnInfo(newColName, type, alias.toString(), false);
        columnInfos.add(columnInfo);
        newValueExpr.add(new ExprNodeColumnDesc(columnInfo));
        if (!first) {
          colNames = colNames + ",";
          colTypes = colTypes + ",";
        }
        first = false;
        colNames = colNames + newColName;
        colTypes = colTypes + valueCols.get(k).getTypeString();
      }

      // join keys are appended as the last columns of the spilled table
      for (int k = 0; k < joinKeys.size(); k++) {
        if (!first) {
          colNames = colNames + ",";
          colTypes = colTypes + ",";
        }
        first = false;
        colNames = colNames + joinKeys.get(k);
        colTypes = colTypes + joinKeyTypes.get(k);
        ColumnInfo columnInfo = new ColumnInfo(joinKeys.get(k), TypeInfoFactory
            .getPrimitiveTypeInfo(joinKeyTypes.get(k)), alias.toString(), false);
        columnInfos.add(columnInfo);
        newKeyExpr.add(new ExprNodeColumnDesc(columnInfo));
      }

      newJoinValues.put(alias, newValueExpr);
      newJoinKeys.put(alias, newKeyExpr);
      tableDescList.put(alias, Utilities.getTableDesc(colNames, colTypes));
      rowSchemaList.put(alias, new RowSchema(columnInfos));

      // construct value table Desc
      String valueColNames = "";
      String valueColTypes = "";
      first = true;
      for (int k = 0; k < columnSize; k++) {
        String newColName = i + "_VALUE_" + k; // any name will do
        if (!first) {
          valueColNames = valueColNames + ",";
          valueColTypes = valueColTypes + ",";
        }
        valueColNames = valueColNames + newColName;
        valueColTypes = valueColTypes + valueCols.get(k).getTypeString();
        first = false;
      }
      newJoinValueTblDesc.set(i, Utilities.getTableDesc(
          valueColNames, valueColTypes));
    }

    joinDescriptor.setSkewKeysValuesTables(tableDescList);
    joinDescriptor.setKeyTableDesc(keyTblDesc);

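    // Build one follow-up map-join task per join input except the last: the
    // input's big-key dir is streamed through a map join against the other
    // inputs' small-key files.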
    for (int i = 0; i < numAliases - 1; i++) {
      Byte src = tags[i];
      MapWork newPlan = PlanUtils.getMapRedWork().getMapWork();

      // This code has only been added for testing
      boolean mapperCannotSpanPartns =
        parseCtx.getConf().getBoolVar(
          HiveConf.ConfVars.HIVE_MAPPER_CANNOT_SPAN_MULTIPLE_PARTITIONS);
      newPlan.setMapperCannotSpanPartns(mapperCannotSpanPartns);

      MapredWork clonePlan = SerializationUtilities.clonePlan(currPlan);

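      // Create one temporary TableScanOperator per join input; they read the
      // spilled rows back using the schemas built above and become the
      // parents of the new map-join operator.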
      Operator<? extends OperatorDesc>[] parentOps = new TableScanOperator[tags.length];
      for (int k = 0; k < tags.length; k++) {
        Operator<? extends OperatorDesc> ts =
            GenMapRedUtils.createTemporaryTableScanOperator(
                joinOp.getCompilationOpContext(), rowSchemaList.get((byte)k));
        ((TableScanOperator)ts).setTableDescSkewJoin(tableDescList.get((byte)k));
        parentOps[k] = ts;
      }
      Operator<? extends OperatorDesc> tblScan_op = parentOps[i];

      ArrayList<String> aliases = new ArrayList<String>();
      String alias = src.toString().intern();
      aliases.add(alias);
      Path bigKeyDirPath = bigKeysDirMap.get(src);
      newPlan.addPathToAlias(bigKeyDirPath, aliases);

      newPlan.getAliasToWork().put(alias, tblScan_op);
      PartitionDesc part = new PartitionDesc(tableDescList.get(src), null);

      newPlan.addPathToPartitionInfo(bigKeyDirPath, part);
      newPlan.getAliasToPartnInfo().put(alias, part);

      Operator<? extends OperatorDesc> reducer = clonePlan.getReduceWork().getReducer();
      assert reducer instanceof JoinOperator;
      JoinOperator cloneJoinOp = (JoinOperator) reducer;

      String dumpFilePrefix = "mapfile"+PlanUtils.getCountForMapJoinDumpFilePrefix();
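      // Build the map-join descriptor: position i is the big (streamed)
      // table; newJoinValueTblDesc serves as both the value and the filtered
      // value table descriptors.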
      MapJoinDesc mapJoinDescriptor = new MapJoinDesc(newJoinKeys, keyTblDesc,
          newJoinValues, newJoinValueTblDesc, newJoinValueTblDesc,joinDescriptor
          .getOutputColumnNames(), i, joinDescriptor.getConds(),
          joinDescriptor.getFilters(), joinDescriptor.getNoOuterJoin(), dumpFilePrefix,
          joinDescriptor.getMemoryMonitorInfo(), joinDescriptor.getInMemoryDataSize());
      mapJoinDescriptor.setTagOrder(tags);
      mapJoinDescriptor.setHandleSkewJoin(false);
      mapJoinDescriptor.setNullSafes(joinDescriptor.getNullSafes());
      mapJoinDescriptor.setColumnExprMap(joinDescriptor.getColumnExprMap());

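      // Local work fetches the other inputs' small-key files so the map join
      // can build its in-memory hash tables on each mapper.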
      MapredLocalWork localPlan = new MapredLocalWork(
          new LinkedHashMap<String, Operator<? extends OperatorDesc>>(),
          new LinkedHashMap<String, FetchWork>());
      Map<Byte, Path> smallTblDirs = smallKeysDirMap.get(src);

      for (int j = 0; j < numAliases; j++) {
        if (j == i) {
          continue;
        }
        Byte small_alias = tags[j];
        Operator<? extends OperatorDesc> tblScan_op2 = parentOps[j];
        localPlan.getAliasToWork().put(small_alias.toString(), tblScan_op2);
        Path tblDir = smallTblDirs.get(small_alias);
        localPlan.getAliasToFetchWork().put(small_alias.toString(),
            new FetchWork(tblDir, tableDescList.get(small_alias)));
      }

      newPlan.setMapRedLocalWork(localPlan);

      // construct a map join and set it as the child operator of tblScan_op
      MapJoinOperator mapJoinOp = (MapJoinOperator) OperatorFactory.getAndMakeChild(
          joinOp.getCompilationOpContext(), mapJoinDescriptor, (RowSchema) null, parentOps);
      // change the children of the original join operator to point to the map
      // join operator
      List<Operator<? extends OperatorDesc>> childOps = cloneJoinOp
          .getChildOperators();
      for (Operator<? extends OperatorDesc> childOp : childOps) {
        childOp.replaceParent(cloneJoinOp, mapJoinOp);
      }
      mapJoinOp.setChildOperators(childOps);

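      // Cap the mapper count and minimum split size of the follow-up job via
      // hive.skewjoin.mapjoin.map.tasks and hive.skewjoin.mapjoin.min.split.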
      HiveConf jc = new HiveConf(parseCtx.getConf(),
          GenMRSkewJoinProcessor.class);

      newPlan.setNumMapTasks(HiveConf.getIntVar(
          jc, HiveConf.ConfVars.HIVE_SKEWJOIN_MAPJOIN_NUM_MAP_TASK));
      newPlan.setMinSplitSize(HiveConf.getLongVar(
          jc, HiveConf.ConfVars.HIVE_SKEWJOIN_MAPJOIN_MIN_SPLIT));
      newPlan.setInputformat(HiveInputFormat.class.getName());

      MapredWork w = new MapredWork();
      w.setMapWork(newPlan);

      Task<?> skewJoinMapJoinTask = TaskFactory.get(w);
      skewJoinMapJoinTask.setFetchSource(currTask.isFetchSource());

      bigKeysDirToTaskMap.put(bigKeyDirPath, skewJoinMapJoinTask);
      listWorks.add(skewJoinMapJoinTask.getWork());
      listTasks.add(skewJoinMapJoinTask);
    }
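    // Re-wire the DAG: the original child tasks now run after every
    // skew-join task, are detached from currTask, and are also added to the
    // conditional task's candidate list.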
    if (children != null) {
      for (Task<?> tsk : listTasks) {
        for (Task<?> oldChild : children) {
          tsk.addDependentTask(oldChild);
        }
      }
      currTask.setChildTasks(new ArrayList<Task<?>>());
      for (Task<?> oldChild : children) {
        oldChild.getParentTasks().remove(currTask);
      }
      listTasks.addAll(children);
    }
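    // At run time ConditionalResolverSkewJoin launches only the map-join
    // tasks whose big-key directories actually received spilled rows.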
    ConditionalResolverSkewJoinCtx context =
        new ConditionalResolverSkewJoinCtx(bigKeysDirToTaskMap, children);

    ConditionalWork cndWork = new ConditionalWork(listWorks);
    ConditionalTask cndTsk = (ConditionalTask) TaskFactory.get(cndWork);
    cndTsk.setListTasks(listTasks);
    cndTsk.setResolver(new ConditionalResolverSkewJoin());
    cndTsk.setResolverCtx(context);
    currTask.setChildTasks(new ArrayList<Task<?>>());
    currTask.addDependentTask(cndTsk);

  }