in src/rewrite_query.c [40:154]
/*
 * rewrite_query - run the query text of a client 'Q' (simple query) or
 * 'P' (parse) packet through the Python rewrite hook and, when the hook
 * returns a replacement, splice it into the client's receive buffer in place.
 *
 * Packet layouts handled (first byte is the type, next 4 bytes the
 * big-endian length):
 *   'Q' | int32 len | str query
 *   'P' | int32 len | str stmt | str query | int16 numparams | int32 paramoid
 *   (Ref: https://www.pgcon.org/2014/schedule/attachments/330_postgres-for-the-wire.pdf)
 *
 * pkt->len is the full packet size including the type byte (the wire length
 * field is pkt->len - 1). On a rewrite, the wire length field and pkt->len
 * are updated, every byte received after the old query string is shifted to
 * follow the new one, and recv_pos is moved to the new end of data.
 *
 * Strings that are not NUL-terminated within the received packet bytes are
 * never scanned past; such packets are passed through unmodified.
 *
 * Returns true when the caller may carry on processing the packet (rewritten
 * or not), false when processing should be deferred, as decided by
 * handle_incomplete_packet() or handle_failure().
 */
bool rewrite_query(PgSocket *client, PktHdr *pkt) {
	SBuf *sbuf = &client->sbuf;
	char *buf_start, *pkt_start, *pkt_end;
	char *query_str, *query_nul;
	char *loggable_query_str, *tmp_new_query_str, *new_query_str;
	char *tail_src, *tail_dst;
	unsigned char *len_field;
	size_t avail, scan_len, query_len, new_query_len, tail_len, new_pkt_len;

	if (!is_rewrite_enabled(client)) return true;
	switch (handle_incomplete_packet(client, pkt)) {
	case kIncompletePacketDecisionDisable:
		return true;
	case kIncompletePacketDecisionDefer:
		return false;
	case kIncompletePacketDecisionContinue:
		break;
	}

	if (pkt->type != 'Q' && pkt->type != 'P')
		fatal("Invalid packet type - expected Q or P, got %c", pkt->type);

	/*
	 * Bound every scan by the packet size and by what has actually been
	 * received, so a client string missing its NUL terminator can never make
	 * us read past the packet.
	 */
	buf_start = (char *) sbuf->io->buf;
	pkt_start = buf_start + sbuf->io->parse_pos;
	avail = (size_t) (sbuf->io->recv_pos - sbuf->io->parse_pos);
	scan_len = (size_t) pkt->len < avail ? (size_t) pkt->len : avail;
	if (scan_len <= 5) {
		slog_error(client, "rewrite_query: packet too short, not rewriting");
		return true;
	}
	pkt_end = pkt_start + scan_len;

	/* extract query string from packet (body starts after type + int32 len) */
	if (pkt->type == 'Q') {
		query_str = pkt_start + 5;
	} else { /* 'P': the query follows the NUL-terminated statement name */
		char *stmt_str = pkt_start + 5;
		char *stmt_nul = memchr(stmt_str, '\0', (size_t) (pkt_end - stmt_str));
		if (stmt_nul == NULL) {
			slog_error(client, "rewrite_query: unterminated statement name, not rewriting");
			return true;
		}
		query_str = stmt_nul + 1;
	}
	query_nul = memchr(query_str, '\0', (size_t) (pkt_end - query_str));
	if (query_nul == NULL) {
		slog_error(client, "rewrite_query: unterminated query string, not rewriting");
		return true;
	}
	query_len = (size_t) (query_nul - query_str);

	/* don't process same query again */
	if (is_rewritten(query_str)) return true;

	loggable_query_str = strip_newlines(query_str);
	slog_debug(client, "rewrite_query: Username => %s", client->login_user->name);
	slog_debug(client, "rewrite_query: Orig Query=> %s", loggable_query_str);
	free(loggable_query_str);

	/* call python function to rewrite the query */
	tmp_new_query_str = pycall(client, client->login_user->name, query_str, cf_rewrite_query_py_module_file,
				   "rewrite_query");
	if (tmp_new_query_str == NULL) {
		slog_debug(client, "query unchanged");
		return true;
	}
	new_query_str = tag_rewritten(tmp_new_query_str);
	free(tmp_new_query_str);
	new_query_len = strlen(new_query_str);

	loggable_query_str = strip_newlines(new_query_str);
	slog_debug(client, "rewrite_query: New => %s", loggable_query_str);
	free(loggable_query_str);

	/*
	 * New query must fit in the buffer. Computed in size_t so an oversized
	 * replacement cannot wrap to a negative int and slip past the check.
	 * (query_len < recv_pos because the query lies inside the received data.)
	 */
	if ((size_t) sbuf->io->recv_pos - query_len + new_query_len > (size_t) cf_sbuf_len) {
		slog_error(client,
			   "Rewritten query will not fit into the allocated buffer!");
		free(new_query_str);
		/* only an explicit Defer postpones the packet; never fall through to
		 * the splice below with new_query_str already freed */
		return handle_failure(client) != kIncompletePacketDecisionDefer;
	}

	/*
	 * Splice in place. Bytes before the query (type, length, stmt name) stay
	 * where they are. Move the tail first so a longer query cannot clobber it;
	 * memmove handles the overlap in both the grow and shrink cases.
	 */
	tail_src = query_nul + 1;
	tail_len = (size_t) (buf_start + sbuf->io->recv_pos - tail_src);
	tail_dst = query_str + new_query_len + 1;
	memmove(tail_dst, tail_src, tail_len);
	memcpy(query_str, new_query_str, new_query_len + 1);

	/* rewrite the big-endian wire length (which excludes the type byte) */
	new_pkt_len = (size_t) pkt->len - 1 - query_len + new_query_len;
	len_field = (unsigned char *) pkt_start + 1;
	len_field[0] = (unsigned char) ((new_pkt_len >> 24) & 255);
	len_field[1] = (unsigned char) ((new_pkt_len >> 16) & 255);
	len_field[2] = (unsigned char) ((new_pkt_len >> 8) & 255);
	len_field[3] = (unsigned char) (new_pkt_len & 255);

	/* adjust buffer recv_pos index to the new end of received data */
	sbuf->io->recv_pos = (size_t) ((tail_dst + tail_len) - buf_start);
	/* update PktHdr structure */
	pkt->len = new_pkt_len + 1;
	iobuf_parse_all(sbuf->io, &pkt->data);

	free(new_query_str);
	return true;
}