bool rewrite_query()

in src/rewrite_query.c [40:154]


/*
 * rewrite_query - optionally replace the query text of a client 'Q' (simple
 * query) or 'P' (parse) packet with the result of the Python rewrite hook.
 *
 * The packet is taken to start at sbuf->io->parse_pos.  If the hook returns
 * a new query, it is tagged via tag_rewritten() and spliced into the
 * client's receive buffer in place:
 *   - the packet's 4-byte length field is updated;
 *   - bytes received after the original query's NUL terminator (up to
 *     recv_pos) are shifted to follow the new query;
 *   - recv_pos, pkt->len and pkt->data are adjusted to match.
 *
 * Returns true when the caller may continue processing the packet (whether
 * or not it was rewritten), and false when processing should be deferred.
 * The defer/continue decision comes from handle_incomplete_packet() and,
 * if the rewritten query does not fit, from handle_failure().
 */
bool rewrite_query(PgSocket *client, PktHdr *pkt) {
	SBuf *sbuf = &client->sbuf;
	char *pkt_start;
	char *stmt_str;
	char *query_str = NULL;
	char *loggable_query_str, *tmp_new_query_str, *new_query_str;
	char *tail_ptr, *buf_end;
	unsigned char *len_field;
	size_t old_query_len, new_query_len, tail_len, new_end;
	int new_pkt_len;

	if (!is_rewrite_enabled(client)) return true;
	switch (handle_incomplete_packet(client, pkt)) {
	case kIncompletePacketDecisionDisable:
		return true;
	case kIncompletePacketDecisionDefer:
		return false;
	case kIncompletePacketDecisionContinue:
		;  // no-op
	}

	/* extract query string from packet */
	/* first byte is the packet type (which we already know)
	 * next 4 bytes is the packet length
	 * For packet type 'Q', the query string is next
	 * 	'Q' | int32 len | str query
	 * For packet type 'P', the query string is after the stmt string
	 * 	'P' | int32 len | str stmt | str query | int16 numparams | int32 paramoid
	 * (Ref: https://www.pgcon.org/2014/schedule/attachments/330_postgres-for-the-wire.pdf)
	 */
	pkt_start = (char *) &sbuf->io->buf[sbuf->io->parse_pos];
	if (pkt->type == 'Q') {
		query_str = pkt_start + 5;
	} else if (pkt->type == 'P') {
		stmt_str = pkt_start + 5;
		query_str = stmt_str + strlen(stmt_str) + 1;
	} else {
		fatal("Invalid packet type - expected Q or P, got %c", pkt->type);
	}

	/* don't process same query again */
	if (is_rewritten(query_str)) return true;

	loggable_query_str = strip_newlines(query_str) ;
	slog_debug(client, "rewrite_query: Username => %s", client->login_user->name);
	slog_debug(client, "rewrite_query: Orig Query=> %s", loggable_query_str);
	free(loggable_query_str);

	/* call python function to rewrite the query */
	tmp_new_query_str = pycall(client, client->login_user->name, query_str, cf_rewrite_query_py_module_file,
			"rewrite_query");
	if (tmp_new_query_str == NULL) {
		slog_debug(client, "query unchanged");
		return true;
	}
	new_query_str = tag_rewritten(tmp_new_query_str);
	free(tmp_new_query_str);
	loggable_query_str = strip_newlines(new_query_str) ;
	slog_debug(client, "rewrite_query: New => %s", loggable_query_str);
	free(loggable_query_str);

	old_query_len = strlen(query_str);
	new_query_len = strlen(new_query_str);

	/* New end-of-data offset after the splice. The original query lies
	 * before recv_pos, so the subtraction cannot underflow. Compare in size_t:
	 * the previous int cast let a huge rewritten query wrap negative and slip
	 * past this check. */
	new_end = (size_t) sbuf->io->recv_pos + new_query_len - old_query_len;

	/* new query must fit in the buffer */
	if (new_end > (size_t) cf_sbuf_len) {
		slog_error(client,
				"Rewritten query will not fit into the allocated buffer!");
		free(new_query_str);
		/* Never fall through to the splice below: new_query_str is freed.
		 * Only Defer stops processing; everything else continues unchanged. */
		switch (handle_failure(client)) {
		case kIncompletePacketDecisionDefer:
			return false;
		case kIncompletePacketDecisionDisable:
		case kIncompletePacketDecisionContinue:
		default:
			return true;
		}
	}

	/* Splice the new query into the buffer in place. The packet type, length
	 * field and (for 'P') statement name keep their positions. Only the
	 * query text and the bytes after it change. */
	tail_ptr = query_str + old_query_len + 1;
	buf_end = (char *) &sbuf->io->buf[sbuf->io->recv_pos];
	tail_len = (size_t) (buf_end - tail_ptr);
	/* Source and destination may overlap, so use memmove. Move the tail
	 * before writing the new query over the space it may share. */
	memmove(query_str + new_query_len + 1, tail_ptr, tail_len);
	memcpy(query_str, new_query_str, new_query_len + 1);

	/* pkt->len counts the type byte; the wire length field does not. */
	new_pkt_len = (int) (pkt->len - 1 + new_query_len - old_query_len);
	len_field = (unsigned char *) pkt_start + 1;
	len_field[0] = (unsigned char) ((new_pkt_len >> 24) & 255);
	len_field[1] = (unsigned char) ((new_pkt_len >> 16) & 255);
	len_field[2] = (unsigned char) ((new_pkt_len >> 8) & 255);
	len_field[3] = (unsigned char) (new_pkt_len & 255);

	/* adjust buffer recv_pos index to the new end of received data */
	sbuf->io->recv_pos = new_end;
	/* update PktHdr structure */
	pkt->len = new_pkt_len + 1;
	iobuf_parse_all(sbuf->io, &pkt->data);
	/* done */
	free(new_query_str);
	return true;
}