in samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala [124:691]
/**
 * Factory method that wires together all of the runtime components (systems, serdes, storage,
 * offset management, task instances, run loop, monitors) for a single Samza container and
 * returns the assembled [[SamzaContainer]].
 *
 * Construction order matters: consumers/producers depend on system factories, the serde manager
 * depends on the serde maps, task instances depend on nearly everything built before them.
 *
 * @param containerId id of the container to build (may be a standby container id)
 * @param jobModel full job model; the container model for `containerId` is looked up from it
 * @param customReporters extra metrics reporters merged with the config-declared ones
 * @param registry metrics registry shared by all container-level metrics
 * @param taskFactory user task factory, finalized into a sync or async factory below
 * @param jobContext job-level context (also the source of the job config)
 * @param applicationContainerContextFactoryOption optional factory for app-level container context
 * @param applicationTaskContextFactoryOption optional factory for app-level task context
 * @param externalContextOption optional externally-provided context passed through to the app
 * @param localityManager optional (nullable) manager for host-affinity bookkeeping
 * @param startpointManager optional (nullable) manager for startpoint resolution
 * @param diagnosticsManager optional diagnostics stream manager
 * @param drainMonitor optional (nullable) monitor for drain notifications
 */
def apply(
  containerId: String,
  jobModel: JobModel,
  customReporters: Map[String, MetricsReporter] = Map[String, MetricsReporter](),
  // TODO SAMZA-2671: there is further room for improvement for metrics wiring in general
  registry: MetricsRegistryMap,
  taskFactory: TaskFactory[_],
  jobContext: JobContext,
  applicationContainerContextFactoryOption: Option[ApplicationContainerContextFactory[ApplicationContainerContext]],
  applicationTaskContextFactoryOption: Option[ApplicationTaskContextFactory[ApplicationTaskContext]],
  externalContextOption: Option[ExternalContext],
  localityManager: LocalityManager = null,
  startpointManager: StartpointManager = null,
  diagnosticsManager: Option[DiagnosticsManager] = Option.empty,
  drainMonitor: DrainMonitor = null) = {

  // Standby containers must keep consuming checkpoint messages indefinitely, so override the
  // "stop after first read" behavior of the checkpoint manager consumer for them only.
  val config = if (StandbyTaskUtil.isStandbyContainer(containerId)) {
    // standby containers will need to continually poll checkpoint messages
    val newConfig = new util.HashMap[String, String](jobContext.getConfig)
    newConfig.put(TaskConfig.INTERNAL_CHECKPOINT_MANAGER_CONSUMER_STOP_AFTER_FIRST_READ, java.lang.Boolean.FALSE.toString)
    new MapConfig(newConfig)
  } else {
    jobContext.getConfig
  }

  val jobConfig = new JobConfig(config)
  val systemConfig = new SystemConfig(config)

  val containerModel = jobModel.getContainers.get(containerId)
  val containerName = "samza-container-%s" format containerId
  val containerPID = ManagementFactory.getRuntimeMXBean().getName()

  info("Setting up Samza container: %s" format containerName)
  info("Samza container PID: %s" format containerPID)
  // Also echo the PID to stdout, independent of the logging framework configuration.
  println("Container PID: %s" format containerPID)
  info("Using configuration: %s" format config)
  info("Using container model: %s" format containerModel)

  val samzaContainerMetrics = new SamzaContainerMetrics(containerName, registry)
  val systemProducersMetrics = new SystemProducersMetrics(registry)
  val systemConsumersMetrics = new SystemConsumersMetrics(registry)
  val offsetManagerMetrics = new OffsetManagerMetrics(registry)
  val metricsConfig = new MetricsConfig(config)

  // When timer metrics are disabled, use a zero clock so timing overhead is avoided entirely.
  val clock = if (metricsConfig.getMetricsTimerEnabled) {
    new HighResolutionClock {
      override def nanoTime(): Long = System.nanoTime()
    }
  } else {
    new HighResolutionClock {
      override def nanoTime(): Long = 0L
    }
  }

  // All SSPs assigned to any task in this container (includes side-input SSPs at this point).
  val inputSystemStreamPartitions = containerModel
    .getTasks
    .values
    .asScala
    .flatMap(_.getSystemStreamPartitions.asScala)
    .toSet

  val storageConfig = new StorageConfig(config)
  // store name -> side-input system streams, for stores that declare side inputs.
  val sideInputStoresToSystemStreams = storageConfig.getStoreNames.asScala
    .map { storeName => (storeName, storageConfig.getSideInputs(storeName).asScala) }
    .filter { case (storeName, sideInputs) => sideInputs.nonEmpty }
    .map { case (storeName, sideInputs) => (storeName, sideInputs.map(StreamUtil.getSystemStreamFromNameOrId(config, _))) }
    .toMap

  val sideInputSystemStreams = sideInputStoresToSystemStreams.values.flatMap(sideInputs => sideInputs.toStream).toSet

  info("Got side input store system streams: %s" format sideInputSystemStreams)

  // Side-input streams are consumed by the storage layer, not by the run loop, so exclude them.
  val inputSystemStreams = inputSystemStreamPartitions
    .map(_.getSystemStream)
    .toSet.diff(sideInputSystemStreams)

  val inputSystems = inputSystemStreams
    .map(_.getSystem)
    .toSet

  val systemNames = systemConfig.getSystemNames.asScala

  info("Got system names: %s" format systemNames)

  val streamConfig = new StreamConfig(config)
  val serdeStreams = systemNames.foldLeft(Set[SystemStream]())(_ ++ streamConfig.getSerdeStreams(_).asScala)

  info("Got serde streams: %s" format serdeStreams)

  val systemFactories = systemNames.map(systemName => {
    val systemFactoryClassName = JavaOptionals.toRichOptional(systemConfig.getSystemFactory(systemName)).toOption
      .getOrElse(throw new SamzaException("A stream uses system %s, which is missing from the configuration." format systemName))
    (systemName, ReflectionUtil.getObj(systemFactoryClassName, classOf[SystemFactory]))
  }).toMap
  info("Got system factories: %s" format systemFactories.keys)

  val systemAdmins = new SystemAdmins(config, this.getClass.getSimpleName)
  info("Got system admins: %s" format systemAdmins.getSystemNames)

  val streamMetadataCache = new StreamMetadataCache(systemAdmins)
  val inputStreamMetadata = streamMetadataCache.getStreamMetadata(inputSystemStreams)

  info("Got input stream metadata: %s" format inputStreamMetadata)

  // Systems whose consumer creation fails are skipped (logged) rather than failing container start.
  val consumers = inputSystems
    .map(systemName => {
      val systemFactory = systemFactories(systemName)

      try {
        (systemName, systemFactory.getConsumer(systemName, config, samzaContainerMetrics.registry, this.getClass.getSimpleName))
      } catch {
        case e: Exception =>
          error("Failed to create a consumer for %s, so skipping." format systemName, e)
          (systemName, null)
      }
    })
    .filter(_._2 != null)
    .toMap

  info("Got system consumers: %s" format consumers.keys)

  // Producers are created for every configured system; failures are likewise skipped.
  val producers = systemFactories
    .map {
      case (systemName, systemFactory) =>
        try {
          (systemName, systemFactory.getProducer(systemName, config, samzaContainerMetrics.registry, this.getClass.getSimpleName))
        } catch {
          case e: Exception =>
            error("Failed to create a producer for %s, so skipping." format systemName, e)
            (systemName, null)
        }
    }
    .filter(_._2 != null)

  info("Got system producers: %s" format producers.keys)

  val serializerConfig = new SerializerConfig(config)
  // Serdes declared via factory class name in config.
  val serdesFromFactories = serializerConfig.getSerdeNames.asScala.map(serdeName => {
    val serdeClassName = JavaOptionals.toRichOptional(serializerConfig.getSerdeFactoryClass(serdeName)).toOption
      .getOrElse(SerializerConfig.getPredefinedSerdeFactoryName(serdeName))
    val serde = ReflectionUtil.getObj(serdeClassName, classOf[SerdeFactory[Object]])
      .getSerde(serdeName, config)
    (serdeName, serde)
  }).toMap
  info("Got serdes from factories: %s" format serdesFromFactories.keys)

  // Serdes shipped as base64-encoded serialized instances (e.g. from the high-level API planner).
  // Invalid instances are warned about and ignored rather than failing startup.
  val serializableSerde = new SerializableSerde[Serde[Object]]()
  val serdesFromSerializedInstances = config.subset(SerializerConfig.SERIALIZER_PREFIX format "").asScala
    .filter { case (key, value) => key.endsWith(SerializerConfig.SERIALIZED_INSTANCE_SUFFIX) }
    .flatMap { case (key, value) =>
      val serdeName = key.replace(SerializerConfig.SERIALIZED_INSTANCE_SUFFIX, "")
      debug(s"Trying to deserialize serde instance for $serdeName")
      try {
        val bytes = Base64.getDecoder.decode(value)
        val serdeInstance = serializableSerde.fromBytes(bytes)
        debug(s"Returning serialized instance for $serdeName")
        Some((serdeName, serdeInstance))
      } catch {
        case e: Exception =>
          warn(s"Ignoring invalid serialized instance for $serdeName: $value", e)
          None
      }
    }
  info("Got serdes from serialized instances: %s" format serdesFromSerializedInstances.keys)

  // Serialized instances win on name collision (they are appended last).
  val serdes = serdesFromFactories ++ serdesFromSerializedInstances

  /*
   * A Helper function to build a Map[String, Serde] (systemName -> Serde) for systems defined
   * in the config. This is useful to build both key and message serde maps.
   */
  val buildSystemSerdeMap = (getSerdeName: (String) => Option[String]) => {
    systemNames
      .filter(systemName => getSerdeName(systemName).isDefined)
      .flatMap(systemName => {
        val serdeName = getSerdeName(systemName).get
        val serde = serdes.getOrElse(serdeName,
          throw new SamzaException("buildSystemSerdeMap: No class defined for serde: %s." format serdeName))

        // this shouldn't happen since system level serdes can't be set programmatically using the high level
        // API, but adding this for safety.
        Option(serde)
          .filter(!_.isInstanceOf[NoOpSerde[Any]])
          .map(serde => (systemName, serde))
      }).toMap
  }

  /*
   * A Helper function to build a Map[SystemStream, Serde] for streams defined in the config.
   * This is useful to build both key and message serde maps.
   */
  val buildSystemStreamSerdeMap = (getSerdeName: (SystemStream) => Optional[String]) => {
    (serdeStreams ++ inputSystemStreamPartitions)
      .filter(systemStream => getSerdeName(systemStream).isPresent)
      .flatMap(systemStream => {
        val serdeName = getSerdeName(systemStream).get
        val serde = serdes.getOrElse(serdeName,
          throw new SamzaException("buildSystemStreamSerdeMap: No serde found for name: %s." format serdeName))

        // respect explicitly set no-op serdes in high level API
        Option(serde)
          .filter(!_.isInstanceOf[NoOpSerde[Any]])
          .map(serde => (systemStream, serde))
      }).toMap
  }

  val systemKeySerdes = buildSystemSerdeMap(systemName =>
    JavaOptionals.toRichOptional(systemConfig.getSystemKeySerde(systemName)).toOption)

  debug("Got system key serdes: %s" format systemKeySerdes)

  val systemMessageSerdes = buildSystemSerdeMap(systemName =>
    JavaOptionals.toRichOptional(systemConfig.getSystemMsgSerde(systemName)).toOption)

  debug("Got system message serdes: %s" format systemMessageSerdes)

  val systemStreamKeySerdes = buildSystemStreamSerdeMap(systemStream => streamConfig.getStreamKeySerde(systemStream))

  debug("Got system stream key serdes: %s" format systemStreamKeySerdes)

  val systemStreamMessageSerdes = buildSystemStreamSerdeMap(systemStream => streamConfig.getStreamMsgSerde(systemStream))

  debug("Got system stream message serdes: %s" format systemStreamMessageSerdes)

  val storeChangelogs = storageConfig.getStoreChangelogs

  info("Got change log system streams: %s" format storeChangelogs)

  val intermediateStreams = streamConfig
    .getStreamIds()
    .asScala
    .filter((streamId:String) => streamConfig.getIsIntermediateStream(streamId))
    .toList

  info("Got intermediate streams: %s" format intermediateStreams)

  // Control message keys on intermediate streams are always string-encoded, regardless of the
  // stream's configured key serde; the stream only needs to have *some* key serde defined.
  val controlMessageKeySerdes = intermediateStreams
    .flatMap(streamId => {
      val systemStream = streamConfig.streamIdToSystemStream(streamId)
      systemStreamKeySerdes.get(systemStream)
        .orElse(systemKeySerdes.get(systemStream.getSystem))
        .map(serde => (systemStream, new StringSerde("UTF-8")))
    }).toMap

  // Intermediate stream messages are wrapped so control messages can share the stream.
  val intermediateStreamMessageSerdes = intermediateStreams
    .flatMap(streamId => {
      val systemStream = streamConfig.streamIdToSystemStream(streamId)
      systemStreamMessageSerdes.get(systemStream)
        .orElse(systemMessageSerdes.get(systemStream.getSystem))
        .map(serde => (systemStream, new IntermediateMessageSerde(serde)))
    }).toMap

  val serdeManager = new SerdeManager(
    serdes = serdes,
    systemKeySerdes = systemKeySerdes,
    systemMessageSerdes = systemMessageSerdes,
    systemStreamKeySerdes = systemStreamKeySerdes,
    systemStreamMessageSerdes = systemStreamMessageSerdes,
    changeLogSystemStreams = storeChangelogs.asScala.values.toSet,
    controlMessageKeySerdes = controlMessageKeySerdes,
    intermediateMessageSerdes = intermediateStreamMessageSerdes)

  info("Setting up JVM metrics.")

  val jvm = new JvmMetrics(samzaContainerMetrics.registry)

  info("Setting up message chooser.")

  val taskConfig = new TaskConfig(config)
  val chooserFactoryClassName = taskConfig.getMessageChooserClass

  val chooserFactory = ReflectionUtil.getObj(chooserFactoryClassName, classOf[MessageChooserFactory])

  val chooser = DefaultChooser(inputStreamMetadata, chooserFactory, config, samzaContainerMetrics.registry, systemAdmins)

  info("Setting up metrics reporters.")

  val reporters =
    MetricsReporterLoader.getMetricsReporters(metricsConfig, containerName).asScala.toMap ++ customReporters

  info("Got metrics reporters: %s" format reporters.keys)

  val securityManager = JavaOptionals.toRichOptional(jobConfig.getSecurityManagerFactory).toOption match {
    case Some(securityManagerFactoryClassName) =>
      ReflectionUtil.getObj(securityManagerFactoryClassName, classOf[SecurityManagerFactory])
        .getSecurityManager(config)
    case _ => null
  }
  info("Got security manager: %s" format securityManager)

  val checkpointManager = taskConfig.getCheckpointManager(samzaContainerMetrics.registry).orElse(null)
  info("Got checkpoint manager: %s" format checkpointManager)

  // create a map of consumers with callbacks to pass to the OffsetManager
  val checkpointListeners = consumers.filter(_._2.isInstanceOf[CheckpointListener])
    .map { case (system, consumer) => (system, consumer.asInstanceOf[CheckpointListener])}
  info("Got checkpointListeners : %s" format checkpointListeners)

  val offsetManager = OffsetManager(inputStreamMetadata, config, checkpointManager, startpointManager, systemAdmins, checkpointListeners, offsetManagerMetrics)
  info("Got offset manager: %s" format offsetManager)

  val dropDeserializationError = taskConfig.getDropDeserializationErrors
  val dropSerializationError = taskConfig.getDropSerializationErrors

  val pollIntervalMs = taskConfig.getPollIntervalMs

  val appConfig = new ApplicationConfig(config)

  val consumerMultiplexer = new SystemConsumers(
    chooser = chooser,
    consumers = consumers,
    systemAdmins = systemAdmins,
    serdeManager = serdeManager,
    metrics = systemConsumersMetrics,
    dropDeserializationError = dropDeserializationError,
    pollIntervalMs = pollIntervalMs,
    clock = () => clock.nanoTime(),
    elasticityFactor = jobConfig.getElasticityFactor,
    runId = appConfig.getRunId)

  val producerMultiplexer = new SystemProducers(
    producers = producers,
    serdeManager = serdeManager,
    metrics = systemProducersMetrics,
    dropSerializationError = dropSerializationError)

  val storageEngineFactories = storageConfig
    .getStoreNames.asScala
    .map(storeName => {
      val storageFactoryClassName =
        JavaOptionals.toRichOptional(storageConfig.getStorageFactoryClassName(storeName)).toOption
          .getOrElse(throw new SamzaException("Missing storage factory for %s." format storeName))
      (storeName,
        ReflectionUtil.getObj(storageFactoryClassName, classOf[StorageEngineFactory[Object, Object]]))
    }).toMap

  info("Got storage engines: %s" format storageEngineFactories.keys)

  val threadPoolSize = jobConfig.getThreadPoolSize
  info("Got thread pool size: " + threadPoolSize)
  samzaContainerMetrics.containerThreadPoolSize.set(threadPoolSize)

  // A task thread pool is only created when explicitly sized; otherwise tasks run on the run loop thread.
  val taskThreadPool = if (threadPoolSize > 0) {
    val taskExecutorFactoryClassName = jobConfig.getTaskExecutorFactory
    val taskExecutorFactory = ReflectionUtil.getObj(taskExecutorFactoryClassName, classOf[TaskExecutorFactory])
    taskExecutorFactory.getTaskExecutor(config)
  } else {
    null
  }

  val finalTaskFactory = TaskFactoryUtil.finalizeTaskFactory(
    taskFactory,
    taskThreadPool)

  // executor for performing async commit operations for a task.
  // Sized to 2x task count, clamped between the configured min and max pool sizes.
  val commitThreadPoolSize =
    Math.min(
      Math.max(containerModel.getTasks.size() * 2, jobConfig.getCommitThreadPoolSize),
      jobConfig.getCommitThreadPoolMaxSize
    )
  val commitThreadPool = Executors.newFixedThreadPool(commitThreadPoolSize,
    new ThreadFactoryBuilder().setNameFormat("Samza Task Commit Thread-%d").setDaemon(true).build())

  val taskModels = containerModel.getTasks.values.asScala
  val containerContext = new ContainerContextImpl(containerModel, samzaContainerMetrics.registry, taskThreadPool)
  val applicationContainerContextOption = applicationContainerContextFactoryOption
    .map(_.create(externalContextOption.orNull, jobContext, containerContext))

  val storeWatchPaths = new util.HashSet[Path]()

  val timerExecutor = Executors.newSingleThreadScheduledExecutor

  val taskInstanceMetrics: Map[TaskName, TaskInstanceMetrics] = taskModels.map(taskModel => {
    (taskModel.getTaskName, new TaskInstanceMetrics("TaskName-%s" format taskModel.getTaskName))
  }).toMap

  val taskCollectors : Map[TaskName, TaskInstanceCollector] = taskModels.map(taskModel => {
    (taskModel.getTaskName, new TaskInstanceCollector(producerMultiplexer, taskInstanceMetrics.get(taskModel.getTaskName).get))
  }).toMap

  val defaultStoreBaseDir = new File(System.getProperty("user.dir"), "state")
  info("Got default storage engine base directory: %s" format defaultStoreBaseDir)

  val nonLoggedStorageBaseDir = getNonLoggedStorageBaseDir(jobConfig, defaultStoreBaseDir)
  info("Got base directory for non logged data stores: %s" format nonLoggedStorageBaseDir)

  val loggedStorageBaseDir = getLoggedStorageBaseDir(jobConfig, defaultStoreBaseDir)
  info("Got base directory for logged data stores: %s" format loggedStorageBaseDir)

  val backupFactoryNames = storageConfig.getBackupFactories
  val restoreFactoryNames = storageConfig.getRestoreFactories

  // Restore factories should be a subset of backup factories. If not, report exactly the restore
  // factories that have no matching backup factory (restore \ backup), computed into a fresh list
  // so neither config-derived collection is mutated.
  if (!backupFactoryNames.containsAll(restoreFactoryNames)) {
    val missingFactories = new util.ArrayList[String](restoreFactoryNames)
    missingFactories.removeAll(backupFactoryNames)
    throw new SamzaException("Restore state backend factories is not a subset of backup state backend factories, " +
      "missing factories: " + missingFactories.toString)
  }

  val stateStorageBackendBackupFactories = backupFactoryNames.asScala.map(
    ReflectionUtil.getObj(_, classOf[StateBackendFactory])
  )
  val stateStorageBackendRestoreFactories = restoreFactoryNames.asScala.map(
    factoryName => (factoryName , ReflectionUtil.getObj(factoryName, classOf[StateBackendFactory])))
    .toMap.asJava

  val containerStorageManager = new ContainerStorageManager(
    checkpointManager,
    containerModel,
    streamMetadataCache,
    systemAdmins,
    storeChangelogs,
    sideInputStoresToSystemStreams.mapValues(systemStreamSet => systemStreamSet.toSet.asJava).asJava,
    storageEngineFactories.asJava,
    systemFactories.asJava,
    serdes.asJava,
    config,
    taskInstanceMetrics.asJava,
    samzaContainerMetrics,
    jobContext,
    containerContext,
    stateStorageBackendRestoreFactories,
    taskCollectors.asJava,
    loggedStorageBaseDir,
    nonLoggedStorageBaseDir,
    serdeManager,
    SystemClock.instance())

  storeWatchPaths.addAll(containerStorageManager.getStoreDirectoryPaths)

  // Create taskInstances — only for active tasks; standby task models get no TaskInstance here.
  val taskInstances: Map[TaskName, TaskInstance] = taskModels
    .filter(taskModel => taskModel.getTaskMode.eq(TaskMode.Active)).map(taskModel => {
    debug("Setting up task instance: %s" format taskModel)

    val taskName = taskModel.getTaskName

    val task = finalTaskFactory match {
      case tf: AsyncStreamTaskFactory => tf.asInstanceOf[AsyncStreamTaskFactory].createInstance()
      case tf: StreamTaskFactory => tf.asInstanceOf[StreamTaskFactory].createInstance()
    }

    val taskSSPs = taskModel.getSystemStreamPartitions.asScala.toSet
    info("Got task SSPs: %s" format taskSSPs)

    val sideInputStoresToSSPs = sideInputStoresToSystemStreams.mapValues(sideInputSystemStreams =>
      taskSSPs.filter(ssp => sideInputSystemStreams.contains(ssp.getSystemStream)).asJava)

    val taskSideInputSSPs = sideInputStoresToSSPs.values.flatMap(_.asScala).toSet
    info ("Got task side input SSPs: %s" format taskSideInputSSPs)

    // One backup manager per configured state backend factory, keyed by factory class name.
    val taskBackupManagerMap = new util.HashMap[String, TaskBackupManager]()
    val systemAdminsMap = systemAdmins.getSystemAdmins
    stateStorageBackendBackupFactories.asJava.forEach(new Consumer[StateBackendFactory] {
      override def accept(factory: StateBackendFactory): Unit = {
        val taskMetricsRegistry =
          if (taskInstanceMetrics.contains(taskName) &&
            taskInstanceMetrics.get(taskName).isDefined) taskInstanceMetrics.get(taskName).get.registry
          else new MetricsRegistryMap
        val taskBackupManager = factory.getBackupManager(jobContext, containerModel,
          taskModel, systemAdminsMap, commitThreadPool, taskMetricsRegistry, config, SystemClock.instance,
          loggedStorageBaseDir, nonLoggedStorageBaseDir)
        taskBackupManagerMap.put(factory.getClass.getName, taskBackupManager)
      }
    })

    val commitManager = new TaskStorageCommitManager(taskName, taskBackupManagerMap,
      containerStorageManager, storeChangelogs, taskModel.getChangelogPartition, checkpointManager, config,
      commitThreadPool, new StorageManagerUtil, loggedStorageBaseDir, taskInstanceMetrics.get(taskName).get)

    val tableManager = new TableManager(config)

    info("Got table manager")

    def createTaskInstance(task: Any): TaskInstance = new TaskInstance(
      task = task,
      taskModel = taskModel,
      metrics = taskInstanceMetrics.get(taskName).get,
      systemAdmins = systemAdmins,
      consumerMultiplexer = consumerMultiplexer,
      collector = taskCollectors.get(taskName).get,
      offsetManager = offsetManager,
      commitManager = commitManager,
      containerStorageManager = containerStorageManager,
      tableManager = tableManager,
      // Side-input SSPs are handled by the storage layer, so exclude them from the task's input set.
      systemStreamPartitions = (taskSSPs -- taskSideInputSSPs).asJava,
      exceptionHandler = TaskInstanceExceptionHandler(taskInstanceMetrics.get(taskName).get, taskConfig),
      jobModel = jobModel,
      streamMetadataCache = streamMetadataCache,
      inputStreamMetadata = inputStreamMetadata,
      timerExecutor = timerExecutor,
      commitThreadPool = commitThreadPool,
      jobContext = jobContext,
      containerContext = containerContext,
      applicationContainerContextOption = applicationContainerContextOption,
      applicationTaskContextFactoryOption = applicationTaskContextFactoryOption,
      externalContextOption = externalContextOption,
      elasticityFactor = jobConfig.getElasticityFactor)

    val taskInstance = createTaskInstance(task)

    (taskName, taskInstance)
  }).toMap

  val runLoop: Runnable = RunLoopFactory.createRunLoop(
    taskInstances,
    consumerMultiplexer,
    taskThreadPool,
    samzaContainerMetrics,
    clock,
    config)

  val systemStatisticsMonitor : SystemStatisticsMonitor = new StatisticsMonitorImpl()
  systemStatisticsMonitor.registerListener(
    new SamzaContainerMonitorListener(config, samzaContainerMetrics, taskThreadPool))

  val diskQuotaBytes = config.getLong("container.disk.quota.bytes", Long.MaxValue)
  samzaContainerMetrics.diskQuotaBytes.set(diskQuotaBytes)

  val diskQuotaPolicyFactoryString = config.get("container.disk.quota.policy.factory",
    classOf[NoThrottlingDiskQuotaPolicyFactory].getName)
  val diskQuotaPolicyFactory = ReflectionUtil.getObj(diskQuotaPolicyFactoryString, classOf[DiskQuotaPolicyFactory])
  val diskQuotaPolicy = diskQuotaPolicyFactory.create(config)

  // Disk space monitoring is opt-in: a zero/unset poll interval disables it (and the quota policy).
  var diskSpaceMonitor: DiskSpaceMonitor = null
  val diskPollMillis = config.getInt(DISK_POLL_INTERVAL_KEY, 0)
  if (diskPollMillis != 0) {
    diskSpaceMonitor = new PollingScanDiskSpaceMonitor(storeWatchPaths, diskPollMillis)
    diskSpaceMonitor.registerListener(new Listener {
      override def onUpdate(diskUsageBytes: Long): Unit = {
        // Throttle the run loop proportionally to how much of the disk quota remains.
        val newWorkRate = diskQuotaPolicy.apply(1.0 - (diskUsageBytes.toDouble / diskQuotaBytes))
        runLoop.asInstanceOf[Throttleable].setWorkFactor(newWorkRate)
        samzaContainerMetrics.executorWorkFactor.set(runLoop.asInstanceOf[Throttleable].getWorkFactor)
        samzaContainerMetrics.diskUsageBytes.set(diskUsageBytes)
      }
    })

    info("Initialized disk space monitor watch paths to: %s" format storeWatchPaths)
  } else {
    info(s"Disk quotas disabled because polling interval is not set ($DISK_POLL_INTERVAL_KEY)")
  }

  info("Samza container setup complete.")

  new SamzaContainer(
    config = config,
    taskInstances = taskInstances,
    taskInstanceMetrics = taskInstanceMetrics,
    runLoop = runLoop,
    systemAdmins = systemAdmins,
    consumerMultiplexer = consumerMultiplexer,
    producerMultiplexer = producerMultiplexer,
    localityManager = localityManager,
    offsetManager = offsetManager,
    securityManager = securityManager,
    metrics = samzaContainerMetrics,
    reporters = reporters,
    jvm = jvm,
    diskSpaceMonitor = diskSpaceMonitor,
    hostStatisticsMonitor = systemStatisticsMonitor,
    taskThreadPool = taskThreadPool,
    commitThreadPool = commitThreadPool,
    timerExecutor = timerExecutor,
    containerContext = containerContext,
    applicationContainerContextOption = applicationContainerContextOption,
    externalContextOption = externalContextOption,
    containerStorageManager = containerStorageManager,
    drainMonitor = drainMonitor,
    diagnosticsManager = diagnosticsManager)
}